Share my Java 2-D Binary Indexed Tree Solution



    Based on the 1-D solution to the problem Range Sum Query - Mutable, we can extend the binary indexed tree to two dimensions to solve this problem.
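For reference, here is a minimal sketch of the 1-D binary indexed tree that the 2-D version extends. It assumes the usual formulation: a 1-indexed tree array, point updates applied as a difference, and prefix-sum queries. The class name `BIT1D` and the method `prefixSum` are hypothetical names for this illustration, not necessarily what the Range Sum Query - Mutable solution uses. The 2-D tree keeps the same structure but nests a second loop over the column index inside each loop shown here.

```java
// Hypothetical 1-D binary indexed tree (Fenwick tree) for illustration.
final class BIT1D {
    private final int[] nums; // copy of the current values, used to compute diffs
    private final int[] tree; // 1-indexed; tree[0] is unused

    public BIT1D(int[] input) {
        nums = new int[input.length];
        tree = new int[input.length + 1];
        // Build by one point update per element: O(n log n).
        for (int i = 0; i < input.length; i++) update(i, input[i]);
    }

    // Set nums[i] = val by adding the difference to every node that covers i.
    public void update(int i, int val) {
        int diff = val - nums[i];
        nums[i] = val;
        for (int k = i + 1; k < tree.length; k += k & -k) tree[k] += diff;
    }

    // Sum of nums[0..i], inclusive; returns 0 when i == -1.
    public int prefixSum(int i) {
        int sum = 0;
        for (int k = i + 1; k > 0; k -= k & -k) sum += tree[k];
        return sum;
    }

    // Sum of nums[i..j], inclusive.
    public int sumRange(int i, int j) {
        return prefixSum(j) - prefixSum(i - 1);
    }

    public static void main(String[] args) {
        BIT1D bit = new BIT1D(new int[]{1, 3, 5});
        System.out.println(bit.sumRange(0, 2)); // 9
        bit.update(1, 2);
        System.out.println(bit.sumRange(0, 2)); // 8
    }
}
```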

    Initializing the binary indexed tree takes O(mn*logm*logn) time, because the constructor calls update() once per cell; both update() and getSum() take O(logm*logn) time. The arr[][] array keeps a copy of matrix[][] so that update() can compute the difference between the new and the old value of an element and add that difference to the binary indexed tree. sumRegion() uses the same idea as Range Sum Query 2D - Immutable: the sum of a rectangle is obtained from four prefix sums by inclusion-exclusion.
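To make the inclusion-exclusion step concrete, the sketch below applies it to a static 2-D prefix-sum array in the style of Range Sum Query 2D - Immutable. It assumes `pre[i][j]` holds the sum of `matrix[0..i-1][0..j-1]`, so `pre[i + 1][j + 1]` plays the role of `getSum(i, j)` in the tree code. The class name `PrefixSum2D` and the sample matrix are illustrative choices.

```java
// Hypothetical static 2-D prefix sums, showing the same four-term formula as sumRegion().
final class PrefixSum2D {
    private final int[][] pre; // pre[i][j] = sum of matrix[0..i-1][0..j-1]

    public PrefixSum2D(int[][] matrix) {
        int m = matrix.length;
        int n = m == 0 ? 0 : matrix[0].length;
        pre = new int[m + 1][n + 1];
        for (int i = 1; i <= m; i++)
            for (int j = 1; j <= n; j++)
                pre[i][j] = matrix[i - 1][j - 1] + pre[i - 1][j] + pre[i][j - 1] - pre[i - 1][j - 1];
    }

    // Sum over rows row1..row2 and columns col1..col2, inclusive:
    // whole prefix, minus the strip above, minus the strip to the left,
    // plus the top-left corner that was subtracted twice.
    public int sumRegion(int row1, int col1, int row2, int col2) {
        return pre[row2 + 1][col2 + 1] - pre[row1][col2 + 1]
             - pre[row2 + 1][col1] + pre[row1][col1];
    }

    public static void main(String[] args) {
        int[][] matrix = {
            {3, 0, 1, 4, 2},
            {5, 6, 3, 2, 1},
            {1, 2, 0, 1, 5},
            {4, 1, 0, 1, 7},
            {1, 0, 3, 0, 5}
        };
        PrefixSum2D ps = new PrefixSum2D(matrix);
        System.out.println(ps.sumRegion(2, 1, 4, 3)); // 8
        System.out.println(ps.sumRegion(1, 1, 2, 2)); // 11
    }
}
```

The only difference in the mutable version is that each prefix sum comes from getSum() on the tree instead of a precomputed table.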

    public class NumMatrix {
      int m, n;
      int[][] arr;    // stores matrix[][]
      int[][] BITree; // 2-D binary indexed tree
      
      public NumMatrix(int[][] matrix) {
        if (matrix.length == 0 || matrix[0].length == 0) {
            return;
        }
        
        m = matrix.length;
        n = matrix[0].length;
        
        arr = new int[m][n];
        BITree = new int[m + 1][n + 1];
        
        for (int i = 0; i < m; i++) {
          for (int j = 0; j < n; j++) {
            update(i, j, matrix[i][j]); // init BITree[][]; update() also copies the value into arr[][]
          }
        }
      }
      
      public void update(int i, int j, int val) {
        int diff = val - arr[i][j];  // get the diff
        arr[i][j] = val;             // update arr[][]
          
        i++; j++;
        while (i <= m) {
          int k = j;
          while (k <= n) {
            BITree[i][k] += diff; // update BITree[][]
            k += k & (-k); // update column index to that of parent
          }
          i += i & (-i);   // update row index to that of parent
        }
      }
      
      int getSum(int i, int j) {
        int sum = 0;
        
        i++; j++;
        while (i > 0) {
          int k = j;
          while (k > 0) {
            sum += BITree[i][k]; // accumulate the sum
            k -= k & (-k); // move column index to parent node
          }
          i -= i & (-i);   // move row index to parent node
        }
        return sum;
      }
      
      public int sumRegion(int i1, int j1, int i2, int j2) {
        return getSum(i2, j2) - getSum(i1-1, j2) - getSum(i2, j1-1) + getSum(i1-1, j1-1);
      }
    }
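The constructor above performs one update() per cell, which is where the O(mn*logm*logn) initialization cost comes from. One possible refinement, not part of the original post, is a linear-time build: copy the matrix into the 1-indexed tree, let every node in each row add its value into its immediate parent `j + (j & -j)`, then repeat the same pass down each column. The sketch below uses hypothetical names (`LinearBuild2D`, `build`, `buildByUpdates`), assumes a non-empty rectangular matrix, and compares the result with the tree produced by per-cell updates.

```java
import java.util.Arrays;

// Hypothetical O(m*n) construction of the same 2-D binary indexed tree.
final class LinearBuild2D {
    public static int[][] build(int[][] matrix) {
        int m = matrix.length, n = matrix[0].length;
        int[][] tree = new int[m + 1][n + 1];
        for (int i = 1; i <= m; i++)
            for (int j = 1; j <= n; j++)
                tree[i][j] = matrix[i - 1][j - 1];
        // Pass 1: 1-D build along each row. When j is reached, all of its
        // children (which are smaller indices) have already been added in.
        for (int i = 1; i <= m; i++)
            for (int j = 1; j <= n; j++) {
                int parent = j + (j & -j);
                if (parent <= n) tree[i][parent] += tree[i][j];
            }
        // Pass 2: the same 1-D build down each column.
        for (int j = 1; j <= n; j++)
            for (int i = 1; i <= m; i++) {
                int parent = i + (i & -i);
                if (parent <= m) tree[parent][j] += tree[i][j];
            }
        return tree;
    }

    // Reference build, mirroring the posted constructor: one point update per cell.
    public static int[][] buildByUpdates(int[][] matrix) {
        int m = matrix.length, n = matrix[0].length;
        int[][] tree = new int[m + 1][n + 1];
        for (int r = 0; r < m; r++)
            for (int c = 0; c < n; c++)
                for (int i = r + 1; i <= m; i += i & -i)
                    for (int j = c + 1; j <= n; j += j & -j)
                        tree[i][j] += matrix[r][c];
        return tree;
    }

    public static void main(String[] args) {
        int[][] matrix = {{3, 0, 1, 4, 2}, {5, 6, 3, 2, 1}, {1, 2, 0, 1, 5}};
        System.out.println(Arrays.deepEquals(build(matrix), buildByUpdates(matrix))); // true
    }
}
```

If this were used inside NumMatrix, arr[][] would still have to be filled separately, since update() is no longer called during construction.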
    


    Thank you so much for posting this. I spent several hours trying to understand binary indexed trees, and your code is the easiest to follow. I modified it slightly, though:

    public class NumMatrix {
        private int[][] arrs;
        private int[][] Bindex;
    
        public NumMatrix(int[][] matrix) {
            if (matrix == null || matrix.length == 0) return;
            int row = matrix.length, col = matrix[0].length;
            this.arrs = new int[row][col];
            this.Bindex = new int[row + 1][col + 1];
            for (int i = 0; i < row; i++) {
                for (int j = 0; j < col; j++) {
                    update(i, j, matrix[i][j]);
                    arrs[i][j] = matrix[i][j];
                }
            }
        }
        public void update(int row, int col, int val) {
            int diff = val - arrs[row][col];
            arrs[row][col] = val;
            row++;
            col++;
            for (;row < Bindex.length; row += (row & -row)) {
                for (int j = col; j < Bindex[0].length; j += (j & -j)) {
                    Bindex[row][j] += diff;
                }
            }
        }
        public int getSum(int row, int col) {
            int sum = 0;
            row++;
            col++;
            for (; row > 0; row -= (row & -row)) {
                for (int j = col; j > 0; j -= (j & -j)) {
                    sum += Bindex[row][j];
                }
            }
            return sum;
        }
        public int sumRegion(int row1, int col1, int row2, int col2) {
            return getSum(row2,col2) - getSum(row1-1, col2) - getSum(row2, col1-1) + getSum(row1-1, col1-1);
        }
    }
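Both versions move through the tree with `i & -i`, which in two's complement isolates the lowest set bit of `i`. The sketch below (hypothetical class `LowbitTrace`) records the 1-based indices that a single-dimension update loop and prefix-sum loop visit. Each path has at most about log2(n) + 1 entries, and the 2-D code walks the column path once for every row on the row path, which gives the O(logm*logn) cost per operation.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper that traces the index sequences used by a binary indexed tree.
final class LowbitTrace {
    // Indices visited when updating 1-based index i in a tree of size n (climb to parents).
    public static List<Integer> updatePath(int i, int n) {
        List<Integer> path = new ArrayList<>();
        for (; i <= n; i += i & -i) path.add(i);
        return path;
    }

    // Indices visited when computing the prefix sum up to 1-based index i.
    public static List<Integer> queryPath(int i) {
        List<Integer> path = new ArrayList<>();
        for (; i > 0; i -= i & -i) path.add(i);
        return path;
    }

    public static void main(String[] args) {
        System.out.println(5 & -5);           // 1  (binary 101 -> lowest set bit 1)
        System.out.println(12 & -12);         // 4  (binary 1100 -> lowest set bit 4)
        System.out.println(updatePath(3, 8)); // [3, 4, 8]
        System.out.println(queryPath(7));     // [7, 6, 4]
    }
}
```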
