C++ Fenwick Tree solution (23 lines)


  • 1

    This solution uses a 2D Fenwick tree (binary indexed tree).
    For an m x n matrix, both update and sumRegion run in O(log m * log n) time.

    /**
     * 2D range-sum query structure with point updates, backed by a 2D Fenwick
     * tree (binary indexed tree).
     *
     * For an m x n matrix: construction is O(m * n); update and sumRegion are
     * O(log m * log n). Sums are accumulated in int, so totals are assumed to
     * fit in int.
     */
    class NumMatrix {
    private:
        std::vector<std::vector<int>> matrix;   // current cell values, used to turn an assignment into a delta
        std::vector<std::vector<int>> fenwick;  // 1-based tree, size (m + 1) x (n + 1); row 0 / col 0 unused

        // Lowest set bit of i (e.g. 12 -> 4): the span of cells tree index i covers.
        static int lsb(int i) { return i & (-i); }

        // Adds diff to cell (row, col), given in 0-based matrix coordinates.
        void add(int row, int col, int diff) {
            const int rows = static_cast<int>(fenwick.size());
            if (rows == 0) { return; }                   // empty matrix: nothing to update
            const int cols = static_cast<int>(fenwick[0].size());
            for (int r = row + 1; r < rows; r += lsb(r))
                for (int c = col + 1; c < cols; c += lsb(c))
                    fenwick[r][c] += diff;
        }

        // Sum of submatrix (0, 0)-(row, col) inclusive; returns 0 if row or col is -1.
        int getSum(int row, int col) const {
            int sum = 0;
            for (int r = row + 1; r > 0; r -= lsb(r))
                for (int c = col + 1; c > 0; c -= lsb(c))
                    sum += fenwick[r][c];
            return sum;
        }

    public:
        /**
         * Builds the tree from a rectangular matrix (every row assumed to have
         * matrix[0].size() columns). An empty matrix yields an empty structure.
         */
        NumMatrix(const std::vector<std::vector<int>> &matrix) {
            if (matrix.empty() || matrix[0].empty()) { return; }
            const int rows = static_cast<int>(matrix.size());
            const int cols = static_cast<int>(matrix[0].size());
            this->matrix = matrix;                       // keep original values for update()
            fenwick.assign(rows + 1, std::vector<int>(cols + 1, 0));
            for (int r = 0; r < rows; ++r)
                for (int c = 0; c < cols; ++c)
                    fenwick[r + 1][c + 1] = matrix[r][c];

            // Linear-time build. The 2D tree is the 1D Fenwick transform applied
            // along each axis, so push every node into its parent along columns
            // first, then along rows. Children have smaller indices than their
            // parent, so each node is final before it is pushed upward.
            for (int r = 1; r <= rows; ++r)
                for (int c = 1; c <= cols; ++c) {
                    const int parent = c + lsb(c);
                    if (parent <= cols) { fenwick[r][parent] += fenwick[r][c]; }
                }
            for (int c = 1; c <= cols; ++c)
                for (int r = 1; r <= rows; ++r) {
                    const int parent = r + lsb(r);
                    if (parent <= rows) { fenwick[parent][c] += fenwick[r][c]; }
                }
        }

        // Sets cell (row, col) to val by adding the difference to the tree.
        void update(int row, int col, int val) {
            add(row, col, val - matrix[row][col]);
            matrix[row][col] = val;
        }

        // Sum of the inclusive rectangle with corners (row1, col1) and (row2, col2),
        // computed by inclusion-exclusion over four prefix sums.
        int sumRegion(int row1, int col1, int row2, int col2) {
            const int whole  = getSum(row2, col2);
            const int above  = getSum(row1 - 1, col2);
            const int left   = getSum(row2, col1 - 1);
            const int corner = getSum(row1 - 1, col1 - 1);   // subtracted twice above, add back once
            return whole - above - left + corner;
        }
    };
    

Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.