C++ solution using a 2D Binary Indexed Tree (Fenwick tree)


  • 0
    G
    /**
     * Mutable 2D range-sum structure backed by a 2D Binary Indexed (Fenwick) tree.
     *
     * For an R x C matrix, update() and sumRegion() each touch
     * O(log R * log C) tree nodes. Construction inserts every cell through
     * update(), i.e. O(R * C * log R * log C).
     */
    class NumMatrix {
    public:
        // sum[r][c] is the CURRENT value of cell (r, c) (not a prefix sum). It is
        // kept so update(), which assigns an absolute value, can compute the delta
        // to push into the tree.
        vector<vector<int> > sum;
        // 1-based Fenwick tree of size (R + 1) x (C + 1); row 0 and column 0 are unused.
        vector<vector<int> > bit;

        /**
         * Builds the structure from `matrix` (taken by const reference to avoid
         * copying the whole input). The column count is taken from the first row.
         * Cells of longer rows beyond that width are ignored.
         * An empty matrix yields an empty structure: every query returns 0 and
         * every update is a no-op.
         */
        NumMatrix(const vector<vector<int> >& matrix) {
            if (matrix.empty()) {
                return;
            }

            const int rows = static_cast<int>(matrix.size());
            const int cols = static_cast<int>(matrix[0].size());
            sum.assign(rows, vector<int>(cols, 0));
            bit.assign(rows + 1, vector<int>(cols + 1, 0));

            for (int r = 0; r < rows; ++r) {
                const int rowLen = static_cast<int>(matrix[r].size());
                for (int c = 0; c < rowLen; ++c) {
                    // update() bounds-checks, so over-long rows cannot write out of range.
                    update(r, c, matrix[r][c]);
                }
            }
        }

        /**
         * Sets cell (row, col) to `val`.
         * Coordinates outside the matrix are ignored. Without this check they
         * would index outside `sum`/`bit`.
         */
        void update(int row, int col, int val) {
            if (row < 0 || col < 0 || row >= rowCount() || col >= colCount()) {
                return;
            }

            const int delta = val - sum[row][col];
            sum[row][col] = val;

            const int treeRows = static_cast<int>(bit.size());
            const int treeCols = static_cast<int>(bit[0].size());
            // Walk upward through every Fenwick node whose range covers (row, col).
            for (int i = row + 1; i < treeRows; i += i & -i) {
                for (int j = col + 1; j < treeCols; j += j & -j) {
                    bit[i][j] += delta;
                }
            }
        }

        /**
         * Returns the sum of all cells (r, c) with 0 <= r <= row and 0 <= c <= col.
         * A negative row or column gives 0. Indices past the last row/column are
         * clamped to the matrix, so the call never reads outside `bit`.
         */
        int getSum(int row, int col) const {
            const int rows = rowCount();
            const int cols = colCount();
            const int iStart = (row + 1 < rows) ? row + 1 : rows;
            const int jStart = (col + 1 < cols) ? col + 1 : cols;

            int total = 0;
            // Walk downward, stripping the lowest set bit each step.
            for (int i = iStart; i > 0; i -= i & -i) {
                for (int j = jStart; j > 0; j -= j & -j) {
                    total += bit[i][j];
                }
            }
            return total;
        }

        /**
         * Returns the sum of the rectangle with corners (row1, col1) and
         * (row2, col2), both inclusive, via inclusion-exclusion on prefix sums.
         */
        int sumRegion(int row1, int col1, int row2, int col2) const {
            // getSum() already returns 0 for index -1, so no edge special cases are needed.
            return getSum(row2, col2)
                 - getSum(row1 - 1, col2)
                 - getSum(row2, col1 - 1)
                 + getSum(row1 - 1, col1 - 1);
        }

    private:
        // Number of matrix rows (0 for an empty matrix).
        int rowCount() const { return static_cast<int>(sum.size()); }
        // Number of matrix columns (0 for an empty matrix).
        int colCount() const { return sum.empty() ? 0 : static_cast<int>(sum[0].size()); }
    };
    

Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.