C++ Quad tree (736ms ) and indexed tree (492ms) based solutions


  • 9
    W

    I have written both the Quad tree based solution and the indexed tree based solution for c++.

    Both are very straightforward. I made a mistake in my earlier analysis of the quad-tree solution; the indexed tree solution is more efficient in general.

    Method 1: Quad-tree based solution. Essentially, it is a divide-and-conquer algorithm that recursively divides the whole matrix into 4 sub-matrices. It can be shown that the algorithm is O(max(m, n)) per update/query.

    // Mutable 2D range-sum structure backed by a quad tree.
    //
    // Every node covers an inclusive rectangle of the matrix and stores the sum
    // of the cells inside it; a node covering more than one cell splits its
    // rectangle into up to four quadrants. Nodes live in one std::vector and
    // refer to each other by index, so the class owns no raw pointers: nothing
    // leaks, and copying a NumMatrix yields an independent tree.
    class NumMatrix {
        // One quad-tree node covering the inclusive rectangle
        // [leftTop, rightBottom]; both corners are (row, col) pairs.
        struct TreeNode {
            int val = 0;  // sum of all cells inside the rectangle
            // Indices into `nodes` for the quadrants, in the order top-left,
            // top-right, bottom-left, bottom-right; -1 marks an empty quadrant.
            int child[4] = {-1, -1, -1, -1};
            std::pair<int, int> leftTop = std::make_pair(0, 0);
            std::pair<int, int> rightBottom = std::make_pair(0, 0);
        };
    public:
        // Builds the tree over `matrix`. A matrix with no rows, or whose first
        // row has no columns, produces an empty structure whose sums are all 0.
        NumMatrix(std::vector<std::vector<int>> &matrix) {
            nums = matrix;
            if (matrix.empty() || matrix[0].empty()) return;
            const int row = static_cast<int>(matrix.size());
            const int col = static_cast<int>(matrix[0].size());
            root = createTree(matrix, std::make_pair(0, 0), std::make_pair(row - 1, col - 1));
        }

        // Sets the cell (row, col) to `val`. Coordinates outside the matrix
        // are ignored.
        void update(int row, int col, int val) {
            if (root < 0 || row < 0 || row >= static_cast<int>(nums.size()) ||
                col < 0 || col >= static_cast<int>(nums[row].size()))
                return;
            const int diff = val - nums[row][col];
            if (diff == 0) return;
            nums[row][col] = val;
            updateTree(row, col, diff);
        }

        // Returns the sum of the cells in the inclusive rectangle
        // (row1, col1)..(row2, col2). Parts of the rectangle that fall outside
        // the matrix contribute nothing.
        int sumRegion(int row1, int col1, int row2, int col2) {
            if (root < 0) return 0;
            return sumNode(root, row1, col1, row2, col2);
        }

    private:
        std::vector<TreeNode> nodes;         // node pool; owns every tree node
        int root = -1;                       // index of the root node, -1 when the tree is empty
        std::vector<std::vector<int>> nums;  // current cell values, used to turn an update into a delta

        // True when the cell (row, col) lies inside the rectangle of `node`.
        static bool contains(const TreeNode &node, int row, int col) {
            return row >= node.leftTop.first && row <= node.rightBottom.first &&
                   col >= node.leftTop.second && col <= node.rightBottom.second;
        }

        // Builds the subtree for the inclusive rectangle [start, end] and
        // returns its index in `nodes`, or -1 when the rectangle is empty.
        int createTree(const std::vector<std::vector<int>> &matrix,
                       std::pair<int, int> start, std::pair<int, int> end) {
            if (start.first > end.first || start.second > end.second)
                return -1;
            const int cur = static_cast<int>(nodes.size());
            nodes.emplace_back();
            nodes[cur].leftTop = start;
            nodes[cur].rightBottom = end;
            if (start == end) {
                nodes[cur].val = matrix[start.first][start.second];
                return cur;
            }

            const int midx = start.first + (end.first - start.first) / 2;
            const int midy = start.second + (end.second - start.second) / 2;
            // Build the quadrants into locals first: the recursive calls grow
            // `nodes`, which may reallocate, so no reference into the pool may
            // be held across them.
            int kids[4];
            kids[0] = createTree(matrix, start, std::make_pair(midx, midy));
            kids[1] = createTree(matrix, std::make_pair(start.first, midy + 1),
                                 std::make_pair(midx, end.second));
            kids[2] = createTree(matrix, std::make_pair(midx + 1, start.second),
                                 std::make_pair(end.first, midy));
            kids[3] = createTree(matrix, std::make_pair(midx + 1, midy + 1), end);

            int total = 0;
            for (int i = 0; i < 4; ++i) {
                nodes[cur].child[i] = kids[i];
                if (kids[i] >= 0)
                    total += nodes[kids[i]].val;
            }
            nodes[cur].val = total;
            return cur;
        }

        // Sum of the cells of node `idx` that fall inside the query rectangle.
        int sumNode(int idx, int row1, int col1, int row2, int col2) const {
            const TreeNode &node = nodes[idx];
            // No overlap between the node and the query: nothing to add.
            if (std::min(node.rightBottom.first, row2) < std::max(node.leftTop.first, row1))
                return 0;
            if (std::min(node.rightBottom.second, col2) < std::max(node.leftTop.second, col1))
                return 0;
            // Node completely inside the query: its stored sum is the answer.
            if (row1 <= node.leftTop.first && col1 <= node.leftTop.second &&
                row2 >= node.rightBottom.first && col2 >= node.rightBottom.second)
                return node.val;

            int total = 0;
            for (int i = 0; i < 4; ++i) {
                if (node.child[i] >= 0)
                    total += sumNode(node.child[i], row1, col1, row2, col2);
            }
            return total;
        }

        // Adds `diff` to every node whose rectangle contains (row, col).
        // Quadrants are disjoint, so there is exactly one such node per level
        // and the walk follows a single root-to-leaf path.
        void updateTree(int row, int col, int diff) {
            int idx = root;
            while (idx >= 0) {
                nodes[idx].val += diff;
                int next = -1;
                for (int i = 0; i < 4; ++i) {
                    const int c = nodes[idx].child[i];
                    if (c >= 0 && contains(nodes[c], row, col)) {
                        next = c;
                        break;
                    }
                }
                idx = next;
            }
        }
    };
    

    Method 2: the 2D indexed-tree solution. It is a simple generalization of the 1D indexed tree solution. The complexity should be O(log(m)log(n)).

    // Mutable 2D range-sum structure backed by a 2D binary indexed tree.
    class NumMatrix {
    public:
        // Builds the tree by adding every cell of `matrix` one at a time.
        // A matrix with no rows, or whose first row has no columns, leaves the
        // object in its default (empty) state.
        NumMatrix(std::vector<std::vector<int>> &matrix) {
            if (matrix.empty() || matrix[0].empty()) return;
            nrow = matrix.size();
            ncol = matrix[0].size();
            nums = matrix;
            BIT = std::vector<std::vector<int>>(nrow + 1, std::vector<int>(ncol + 1, 0));
            for (int r = 0; r < nrow; ++r) {
                for (int c = 0; c < ncol; ++c) {
                    add(r, c, matrix[r][c]);
                }
            }
        }

        // Sets the cell (row, col) to `val` by pushing the difference from the
        // stored value into the tree.
        void update(int row, int col, int val) {
            const int delta = val - nums[row][col];
            nums[row][col] = val;
            add(row, col, delta);
        }

        // Sum of the inclusive rectangle (row1, col1)..(row2, col2), obtained
        // by inclusion-exclusion over up to four prefix sums.
        int sumRegion(int row1, int col1, int row2, int col2) {
            const bool hasRowsAbove = row1 > 0;
            const bool hasColsLeft = col1 > 0;

            const int whole = region(row2, col2);
            const int above = hasRowsAbove ? region(row1 - 1, col2) : 0;
            const int leftSide = hasColsLeft ? region(row2, col1 - 1) : 0;
            const int corner = (hasRowsAbove && hasColsLeft) ? region(row1 - 1, col1 - 1) : 0;

            return whole - above - leftSide + corner;
        }
    private:
        std::vector<std::vector<int>> nums;  // current cell values
        std::vector<std::vector<int>> BIT;   // 1-based tree cells, (nrow+1) x (ncol+1)
        int nrow = 0;
        int ncol = 0;

        // Adds `val` to the 0-based cell (row, col). Each step moves to the
        // next tree cell by adding the lowest set bit of the 1-based index.
        void add(int row, int col, int val) {
            for (int r = row + 1; r <= nrow; r += (r & (-r))) {
                for (int c = col + 1; c <= ncol; c += (c & (-c))) {
                    BIT[r][c] += val;
                }
            }
        }

        // Prefix sum of the inclusive rectangle (0, 0)..(row, col), 0-based.
        // Each step strips the lowest set bit of the 1-based index.
        int region(int row, int col) {
            int total = 0;
            for (int r = row + 1; r > 0; r -= (r & (-r))) {
                for (int c = col + 1; c > 0; c -= (c & (-c))) {
                    total += BIT[r][c];
                }
            }
            return total;
        }
    };

  • 2
    M

    Thanks for the solution sharing. It really helps for understanding 2D segment tree and binary index tree.

    Here is the python implementation of Binary index tree.

    class NumMatrix(object):
        """Mutable 2D range-sum structure backed by a 2D binary indexed tree."""

        def __init__(self, matrix):
            """
            initialize your data structure here.
            :type matrix: List[List[int]]
            """
            # Copy the rows so update() never mutates the caller's matrix.
            self.matrix = [list(r) for r in matrix]
            self.m = len(self.matrix)
            # Always define n and bit, even for an empty matrix, so the other
            # methods never hit a missing attribute.
            self.n = len(self.matrix[0]) if self.m else 0
            self.bit = [[0] * (self.n + 1) for _ in range(self.m + 1)]
            for i in range(self.m):
                for j in range(self.n):
                    self.add(i, j, self.matrix[i][j])

        def add(self, row, col, val):
            """
            add val to the tree cells that cover matrix[row,col] (0-based).
            :type row: int
            :type col: int
            :type val: int
            :rtype: void
            """
            row += 1
            col += 1
            while row <= self.m:
                col_t = col
                while col_t <= self.n:
                    self.bit[row][col_t] += val
                    col_t += (col_t & -col_t)
                row += (row & -row)

        def update(self, row, col, val):
            """
            update the element at matrix[row,col] to val.
            :type row: int
            :type col: int
            :type val: int
            :rtype: void
            """
            diff = val - self.matrix[row][col]
            self.matrix[row][col] = val
            self.add(row, col, diff)

        def _prefix_sum(self, row, col):
            """
            sum of elements matrix[(0,0)..(row,col)], inclusive; 0 when row or col is -1.
            :type row: int
            :type col: int
            :rtype: int
            """
            ret = 0
            row += 1
            col += 1
            while row > 0:
                col_t = col
                while col_t > 0:
                    ret += self.bit[row][col_t]
                    col_t -= (col_t & -col_t)
                row -= (row & -row)
            return ret

        def sumRegion(self, row1, col1, row2, col2):
            """
            sum of elements matrix[(row1,col1)..(row2,col2)], inclusive.
            :type row1: int
            :type col1: int
            :type row2: int
            :type col2: int
            :rtype: int
            """
            # An empty matrix has no tree cells to read.
            if not self.m or not self.n:
                return 0
            # Inclusion-exclusion over prefix sums anchored at (0, 0).
            ret = self._prefix_sum(row2, col2)
            if row1 > 0 and col1 > 0:
                ret += self._prefix_sum(row1 - 1, col1 - 1)
            if col1 > 0:
                ret -= self._prefix_sum(row2, col1 - 1)
            if row1 > 0:
                ret -= self._prefix_sum(row1 - 1, col2)
            return ret
    

  • 0
    Y

    no need to store the original data matrix. we only need the BINARY INDEX TREE array.


  • 0
    W

    Do you mean that we can get the original data by an extra query? So I guess this is a trade-off for space and time complexity :) Let me know if you think we can do it in a single query for update without an extra query. Thanks.


  • 4

    Your "segment tree" implementation is Quad-Tree. Quad-Tree != 2D segment tree.
    http://codeforces.com/blog/entry/16363

    For Quad-Tree, the worst case time complexity is O(max(n, m)), not O(log(mn)).
    http://apps.topcoder.com/forums/?module=Thread&threadID=633075

    A real 2D segment tree implementation will have time complexity O(logm * logn), which is the same as 2D binary indexed tree.
    http://e-maxx.ru/algo/segment_tree


  • 0
    W

    Thanks for your answer. For the first question, yes, maybe I should call it a Quad-Tree, but I think the worst-case analysis is incorrect. Although we need O(mn) to construct the tree, consider the proof for the 1D segment tree:
    http://cs.stackexchange.com/questions/37669/time-complexity-proof-for-segment-tree-implementation-of-the-ranged-sum-problem
    You can see that at most a constant number of nodes at each level (I believe 7 for a Quad-Tree) need to be split during a query, even in the worst case. The analysis is tricky because you need to use the fact that every query is a rectangle that is compact in both dimensions.

    I am not sure whether 2D segment tree is O(log(n) log(m)), but if so, then the quad-tree may be more efficient. This I am not sure, and I need to look into it.


  • 0

    The time complexity for 1D segment tree does not extend to QuadTree. For instance, consider a 2^n x 2^n grid. If you query the rectangle (0, 0) - (2^n - 1, 0), or any single-row slice, you end up having to look at 2^n different 1x1 squares. Thus, the time complexity for Quad-Tree is linear, O(max(n, m)).

    The time complexity of O(max(n, m)) is greater than O(logn * logm). Thus, 2d segment tree is more efficient.


  • 0
    W

    I think you are right! Thanks for the comment; I will correct my post and give you an upvote!


  • 0

    Thanks! I also give you an upvote. You've done a great job on implementing them. I've been searching for 2D segment tree for days. The only detailed source that seems to be correct is http://e-maxx.ru/algo/segment_tree.
    However, the website is in Russian, and I haven't fully understood all the ideas and implementations. Do you have any shareable information on 2D segment trees?


  • 1
    W

    To address douglasleer's question, I am adding an extra method using 2D-segment tree.
    Method 3: 2D segment tree (each node is another segment tree). Although conceptually simple, it is a little tricky to implement. It should have the same complexity as the 2D indexed tree. I am sharing one possible implementation that uses DFS for the construction and layered recursion for query/update. The runtime is around 596ms. Comments on how to speed up the implementation are welcome.

    // Mutable 2D range-sum structure backed by a 2D segment tree.
    //
    // segTree2D[x][y] holds the sum of the matrix cells whose row lies in the
    // row range of row-tree node x and whose column lies in the column range
    // of column-tree node y. Both trees are stored implicitly: the children of
    // node i are 2*i+1 and 2*i+2, and node 0 covers the whole dimension.
    class NumMatrix {
    public:
        // Builds the tree over `matrix`. A matrix with no rows, or whose first
        // row has no columns, produces an empty structure whose sums are all 0.
        NumMatrix(std::vector<std::vector<int>> &matrix) {
            if (matrix.empty() || matrix[0].empty()) return;
            data = matrix;
            m = static_cast<int>(matrix.size());
            n = static_cast<int>(matrix[0].size());
            const int rowSize = getTreeSize(m);
            const int colSize = getTreeSize(n);
            segTree2D = std::vector<std::vector<int>>(rowSize, std::vector<int>(colSize, 0));
            std::vector<std::vector<bool>> visited(rowSize, std::vector<bool>(colSize, false));
            constructTree(0, 0, m - 1, n - 1, 0, 0, visited);
        }

        // Number of slots needed for an implicit segment tree over `num`
        // leaves: 2 * p - 1, where p is the smallest power of two >= num.
        // Returns 0 for num <= 0. Computed with integers only, so there is no
        // floating-point rounding and no log2(0).
        int getTreeSize(int num) {
            if (num <= 0) return 0;
            int leaves = 1;
            while (leaves < num) leaves <<= 1;
            return 2 * leaves - 1;
        }

        // Sets the cell (row, col) to `val`. Coordinates outside the matrix
        // are ignored.
        void update(int row, int col, int val) {
            if (row < 0 || row >= m || col < 0 || col >= n) return;
            const int diff = val - data[row][col];
            if (diff == 0) return;
            data[row][col] = val;
            updateTree(row, col, 0, m - 1, 0, 0, diff);
        }

        // Adds `diff` for the cell (row, col) to every row-tree node whose row
        // range [x1, x2] contains `row`. `x` is the index of the current
        // row-tree node; `y` is the column-tree node where the column walk
        // starts.
        void updateTree(int row, int col, int x1, int x2, int x, int y, int diff) {
            if (row < x1 || row > x2) return;
            updateCol(col, 0, n - 1, x, y, diff);
            if (x1 < x2) {
                const int midX = x1 + ((x2 - x1) >> 1);
                updateTree(row, col, x1, midX, x * 2 + 1, y, diff);
                updateTree(row, col, midX + 1, x2, x * 2 + 2, y, diff);
            }
        }

        // Within row-tree node `x`, adds `diff` to every column-tree node
        // whose column range [y1, y2] contains `col`. `y` is the index of the
        // current column-tree node.
        void updateCol(int col, int y1, int y2, int x, int y, int diff) {
            if (col < y1 || col > y2) return;
            segTree2D[x][y] += diff;
            if (y1 < y2) {
                const int midY = y1 + ((y2 - y1) >> 1);
                updateCol(col, y1, midY, x, y * 2 + 1, diff);
                updateCol(col, midY + 1, y2, x, y * 2 + 2, diff);
            }
        }

        // Returns the sum of the inclusive rectangle (row1, col1)..(row2, col2).
        // An empty structure always yields 0.
        int sumRegion(int row1, int col1, int row2, int col2) {
            if (m == 0 || n == 0) return 0;
            return queryRegion(row1, col1, row2, col2, 0, m - 1, 0, 0);
        }

        // Sum of the query rectangle restricted to the rows [x1, x2] covered
        // by row-tree node `x`.
        int queryRegion(int row1, int col1, int row2, int col2, int x1, int x2, int x, int y) {
            if (x1 >= row1 && x2 <= row2) {
                // Row range fully inside the query: finish in the column tree.
                return queryCol(col1, col2, 0, n - 1, x, y);
            }
            if (row2 < x1 || x2 < row1) {
                return 0;
            }
            const int midX = x1 + ((x2 - x1) >> 1);
            const int upper = queryRegion(row1, col1, row2, col2, x1, midX, x * 2 + 1, y);
            const int lower = queryRegion(row1, col1, row2, col2, midX + 1, x2, x * 2 + 2, y);
            return upper + lower;
        }

        // Within row-tree node `x`, sum of the columns [col1, col2] restricted
        // to the columns [y1, y2] covered by column-tree node `y`.
        int queryCol(int col1, int col2, int y1, int y2, int x, int y) {
            if (col1 <= y1 && col2 >= y2) {
                return segTree2D[x][y];
            }
            if (col2 < y1 || col1 > y2) {
                return 0;
            }
            const int midY = y1 + ((y2 - y1) >> 1);
            const int leftPart = queryCol(col1, col2, y1, midY, x, y * 2 + 1);
            const int rightPart = queryCol(col1, col2, midY + 1, y2, x, y * 2 + 2);
            return leftPart + rightPart;
        }

        // Fills segTree2D[row][col] with the sum of data[x1..x2][y1..y2],
        // filling every node reachable from it on the way, and returns that
        // sum. The same node can be reached through different orders of row
        // and column splits, so `visited` makes sure each one is computed once.
        int constructTree(int x1, int y1, int x2, int y2, int row, int col,
                          std::vector<std::vector<bool>> &visited) {
            if (visited[row][col])
                return segTree2D[row][col];
            if (x1 == x2 && y1 == y2) {
                segTree2D[row][col] = data[x1][y1];
                visited[row][col] = true;
                return data[x1][y1];
            }

            // Both splits must be walked even though they give the same sum:
            // the row split fills the nodes (2*row+1, col) and (2*row+2, col),
            // the column split fills (row, 2*col+1) and (row, 2*col+2).
            int sum = segTree2D[row][col];
            if (x1 < x2) {
                const int midX = x1 + ((x2 - x1) >> 1);
                const int upper = constructTree(x1, y1, midX, y2, row * 2 + 1, col, visited);
                const int lower = constructTree(midX + 1, y1, x2, y2, row * 2 + 2, col, visited);
                sum = upper + lower;
            }
            if (y1 < y2) {
                const int midY = y1 + ((y2 - y1) >> 1);
                const int leftPart = constructTree(x1, y1, x2, midY, row, col * 2 + 1, visited);
                const int rightPart = constructTree(x1, midY + 1, x2, y2, row, col * 2 + 2, visited);
                sum = leftPart + rightPart;
            }
            segTree2D[row][col] = sum;
            visited[row][col] = true;
            return sum;
        }
    private:
        std::vector<std::vector<int>> data;       // current cell values
        std::vector<std::vector<int>> segTree2D;  // [row-tree node][column-tree node] -> sum
        int m = 0, n = 0;                         // matrix dimensions (rows, columns)
    };

  • 0
    W

    Hi, I just added a possible implementation for the 2D-segment tree. It works, but a little verbose to code.


  • 0
    L

    Hi, I think the update should be O(logn * logm) and the query should be O(max(m, n)), is that correct?


  • 0
    L

    @douglasleer A fourth point to mention: 2D segment tree has O(nlog(n)) complexity for range update, because the lazy propagation technique for segment tree can only be applied to one of the dimensions. Point update still has O(log(m) * log(n)) complexity, though.


Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.