# Share my Java 2-D Binary Indexed Tree Solution

• We can extend the 1-D binary indexed tree solution from the problem Range Sum Query - Mutable to solve this 2-D problem.
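For reference, here is a minimal sketch of the 1-D binary indexed tree (Fenwick tree) that the 2-D version generalizes. The class name `NumArray1D` and the helper `prefixSum()` are illustrative names chosen here, and the 0-indexed `update`/`sumRange` interface is an assumption, not necessarily the exact code from the 1-D discussion.

```java
// Hypothetical 1-D Fenwick tree sketch with a 0-indexed public API.
public class NumArray1D {
    private final int[] nums;  // copy of the current values, used to compute diffs
    private final int[] tree;  // 1-based Fenwick array; tree[0] is unused

    public NumArray1D(int[] input) {
        nums = new int[input.length];
        tree = new int[input.length + 1];
        for (int i = 0; i < input.length; i++) {
            update(i, input[i]); // nums[i] starts at 0, so diff == input[i]
        }
    }

    // Set position i to val and push the difference to every node covering i.
    public void update(int i, int val) {
        int diff = val - nums[i];
        nums[i] = val;
        for (int k = i + 1; k < tree.length; k += k & (-k)) {
            tree[k] += diff;
        }
    }

    // Sum of nums[0..i]; returns 0 when i == -1.
    public int prefixSum(int i) {
        int sum = 0;
        for (int k = i + 1; k > 0; k -= k & (-k)) {
            sum += tree[k];
        }
        return sum;
    }

    // Sum of nums[i..j], inclusive.
    public int sumRange(int i, int j) {
        return prefixSum(j) - prefixSum(i - 1);
    }

    public static void main(String[] args) {
        NumArray1D a = new NumArray1D(new int[] {1, 3, 5});
        System.out.println(a.sumRange(0, 2)); // 9
        a.update(1, 2);
        System.out.println(a.sumRange(0, 2)); // 8
    }
}
```

The 2-D class runs these same two loops once per dimension: the outer loop walks row indices and the inner loop walks column indices.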

Initializing the binary indexed tree takes `O(mn*logm*logn)` time, because the constructor calls `update()` once for each of the `mn` cells, and both `update()` and `getSum()` take `O(logm*logn)` time. The `arr[][]` array keeps a copy of `matrix[][]`, so on an update we know the difference between the new and old value of the element and can add that difference to the binary indexed tree. `sumRegion()` uses the same idea as in Range Sum Query 2D - Immutable: it combines four prefix sums by inclusion-exclusion.
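The inclusion-exclusion step in `sumRegion()` is easiest to see with plain (immutable) prefix sums. The sketch below uses a hypothetical `PrefixSum2D` class; `pre[r][c]` is defined here as the sum of the top-left `r` x `c` block, which corresponds to `getSum(r - 1, c - 1)` in the binary indexed tree version.

```java
// Hypothetical sketch of the Range Sum Query 2D - Immutable idea.
// pre[r][c] holds the sum of matrix[0..r-1][0..c-1].
public class PrefixSum2D {
    private final int[][] pre;

    public PrefixSum2D(int[][] matrix) {
        int m = matrix.length;
        int n = m == 0 ? 0 : matrix[0].length;
        pre = new int[m + 1][n + 1];
        for (int r = 1; r <= m; r++) {
            for (int c = 1; c <= n; c++) {
                pre[r][c] = matrix[r - 1][c - 1]
                        + pre[r - 1][c] + pre[r][c - 1] - pre[r - 1][c - 1];
            }
        }
    }

    // Sum of matrix[i1..i2][j1..j2]: the whole prefix, minus the strip above,
    // minus the strip to the left, plus the top-left corner subtracted twice.
    public int sumRegion(int i1, int j1, int i2, int j2) {
        return pre[i2 + 1][j2 + 1] - pre[i1][j2 + 1] - pre[i2 + 1][j1] + pre[i1][j1];
    }

    public static void main(String[] args) {
        int[][] matrix = {
            {3, 0, 1, 4, 2},
            {5, 6, 3, 2, 1},
            {1, 2, 0, 1, 5},
            {4, 1, 0, 1, 7},
            {1, 0, 3, 0, 5}
        };
        PrefixSum2D p = new PrefixSum2D(matrix);
        System.out.println(p.sumRegion(2, 1, 4, 3)); // 8
    }
}
```

The binary indexed tree keeps exactly this four-term formula; it only replaces the fixed `pre[][]` table with `getSum()`, so the prefix sums stay correct after updates.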

```java
public class NumMatrix {
    int m, n;
    int[][] arr;    // stores matrix[][]
    int[][] BITree; // 2-D binary indexed tree

    public NumMatrix(int[][] matrix) {
        if (matrix.length == 0 || matrix[0].length == 0) {
            return;
        }

        m = matrix.length;
        n = matrix[0].length;

        arr = new int[m][n];
        BITree = new int[m + 1][n + 1];

        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                update(i, j, matrix[i][j]); // init BITree[][]
                arr[i][j] = matrix[i][j];   // init arr[][]
            }
        }
    }

    public void update(int i, int j, int val) {
        int diff = val - arr[i][j];  // get the diff
        arr[i][j] = val;             // update arr[][]

        i++; j++;
        while (i <= m) {
            int k = j;
            while (k <= n) {
                BITree[i][k] += diff; // update BITree[][]
                k += k & (-k);        // update column index to that of parent
            }
            i += i & (-i);            // update row index to that of parent
        }
    }

    int getSum(int i, int j) {
        int sum = 0;

        i++; j++;
        while (i > 0) {
            int k = j;
            while (k > 0) {
                sum += BITree[i][k]; // accumulate the sum
                k -= k & (-k);       // move column index to parent node
            }
            i -= i & (-i);           // move row index to parent node
        }
        return sum;
    }

    public int sumRegion(int i1, int j1, int i2, int j2) {
        return getSum(i2, j2) - getSum(i1 - 1, j2) - getSum(i2, j1 - 1) + getSum(i1 - 1, j1 - 1);
    }
}
```
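The `k += k & (-k)` and `k -= k & (-k)` steps add or strip the lowest set bit of the 1-based index. The hypothetical helper `LowbitTrace` below just lists the indices that one dimension of `update()` and `getSum()` visits, which makes the logarithmic cost visible.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper that lists the 1-based indices the BIT loops visit.
public class LowbitTrace {
    // Indices touched by the update loop, starting at 1-based index k, for size n.
    static List<Integer> updatePath(int k, int n) {
        List<Integer> path = new ArrayList<>();
        for (; k <= n; k += k & (-k)) { // add the lowest set bit
            path.add(k);
        }
        return path;
    }

    // Indices read by the query loop, starting at 1-based index k.
    static List<Integer> queryPath(int k) {
        List<Integer> path = new ArrayList<>();
        for (; k > 0; k -= k & (-k)) {  // strip the lowest set bit
            path.add(k);
        }
        return path;
    }

    public static void main(String[] args) {
        System.out.println(updatePath(3, 8)); // [3, 4, 8]
        System.out.println(queryPath(7));     // [7, 6, 4]
    }
}
```

In the 2-D tree, the row loop follows one such path and, for each visited row, the column loop follows another, so `update()` and `getSum()` each touch at most (floor(log2 m) + 1) * (floor(log2 n) + 1) cells. That is where the `O(logm*logn)` bound comes from.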

• Thank you so much for posting this. I spent several hours trying to understand binary indexed trees, and your code is the easiest to understand. I modified it slightly, though:

```java
public class NumMatrix {
    private int[][] arrs;
    private int[][] Bindex;

    public NumMatrix(int[][] matrix) {
        if (matrix == null || matrix.length == 0) return;
        int row = matrix.length, col = matrix[0].length;
        this.arrs = new int[row][col];
        this.Bindex = new int[row + 1][col + 1];
        for (int i = 0; i < row; i++) {
            for (int j = 0; j < col; j++) {
                update(i, j, matrix[i][j]);
                arrs[i][j] = matrix[i][j];
            }
        }
    }

    public void update(int row, int col, int val) {
        int diff = val - arrs[row][col];
        arrs[row][col] = val;
        row++;
        col++;
        for (; row < Bindex.length; row += (row & -row)) {
            for (int j = col; j < Bindex[0].length; j += (j & -j)) {
                Bindex[row][j] += diff;
            }
        }
    }

    public int getSum(int row, int col) {
        int sum = 0;
        row++;
        col++;
        for (; row > 0; row -= (row & -row)) {
            for (int j = col; j > 0; j -= (j & -j)) {
                sum += Bindex[row][j];
            }
        }
        return sum;
    }

    public int sumRegion(int row1, int col1, int row2, int col2) {
        return getSum(row2, col2) - getSum(row1 - 1, col2) - getSum(row2, col1 - 1) + getSum(row1 - 1, col1 - 1);
    }
}
```
