C++ Quad-Tree (2D Segment Tree) solution. Easy to understand


  • 1
    T

    We can extend a 1D segment tree to 2D by dividing the matrix into four parts, cutting it in half along each dimension: up-left, up-right, down-left and down-right.

    // Quad-tree over a 2D grid of ints (the post calls it a "2D segment tree").
    // Each node covers the inclusive rectangle [x0, x1] x [y0, y1], where x is
    // the column and y is the row, and stores the sum of the cells inside it.
    // An internal node splits its rectangle in half along each dimension into
    // up-left / up-right / down-left / down-right children. A child whose half
    // would be empty (because the rectangle is one cell wide or tall) stays null.
    // The node owns its children and frees them in its destructor.
    class Node
    {
    public:
        // Builds the subtree for the cells of m inside [ix0, ix1] x [iy0, iy1].
        // m is indexed as m[row][col], i.e. m[y][x]. An empty rectangle
        // (ix0 > ix1 or iy0 > iy1) yields a node with sum 0 and no children.
        Node(const vector<vector<int>> &m, int ix0, int iy0, int ix1, int iy1)
            : sum(0), x0(ix0), x1(ix1), y0(iy0), y1(iy1),
              ul(nullptr), ur(nullptr), dl(nullptr), dr(nullptr)
        {
            if (ix0 > ix1 || iy0 > iy1) return;

            if (ix0 == ix1 && iy0 == iy1)
            {
                sum = m[iy0][ix0];
                return;
            }

            const int xmid = getMidX();
            const int ymid = getMidY();

            // The up-left quadrant always exists; the others only when the
            // rectangle extends past the midpoint in the relevant direction.
            ul = new Node(m, ix0, iy0, xmid, ymid);
            sum += ul->sum;
            if (ix1 > xmid)
            {
                ur = new Node(m, xmid + 1, iy0, ix1, ymid);
                sum += ur->sum;
            }
            if (iy1 > ymid)
            {
                dl = new Node(m, ix0, ymid + 1, xmid, iy1);
                sum += dl->sum;
            }
            if (iy1 > ymid && ix1 > xmid)
            {
                dr = new Node(m, xmid + 1, ymid + 1, ix1, iy1);
                sum += dr->sum;
            }
        }

        // Children were allocated with new by the constructor and are owned here
        // (deleting a null child is a no-op).
        ~Node()
        {
            delete ul;
            delete ur;
            delete dl;
            delete dr;
        }

        // The children are owning raw pointers, so a member-wise copy would
        // delete the same subtrees twice. Copying is therefore disabled.
        Node(const Node &) = delete;
        Node &operator=(const Node &) = delete;

        // Sets cell (rx, ry) to val and fixes the sums on the path to it.
        // Returns the change (new value - old value) applied to this node's sum.
        // Precondition (not checked): (rx, ry) lies inside this node's rectangle.
        long long update(int rx, int ry, long long val)
        {
            if (rx == x0 && ry == y0 && x0 == x1 && y0 == y1)
            {
                const long long d = val - sum;
                sum = val;
                return d;
            }

            const bool right = rx > getMidX();
            const bool down = ry > getMidY();

            // Exactly one quadrant contains the cell; it exists because the
            // cell lies inside this node's rectangle.
            Node *child = down ? (right ? dr : dl) : (right ? ur : ul);
            const long long d = child->update(rx, ry, val);
            sum += d;
            return d;
        }

        // Returns the sum of the cells in [rx0, rx1] x [ry0, ry1].
        // Precondition (not checked): the query is a non-empty rectangle inside
        // this node's rectangle. The query is clipped to each overlapping
        // quadrant, and the recursion stops as soon as it matches a node exactly.
        long long get(int rx0, int ry0, int rx1, int ry1) const
        {
            if (rx0 == x0 && rx1 == x1 && ry0 == y0 && ry1 == y1)
            {
                return sum;
            }

            const int xmid = getMidX();
            const int ymid = getMidY();

            long long total = 0;
            if (rx0 <= xmid && ry0 <= ymid)
            {
                total += ul->get(rx0, ry0, min(xmid, rx1), min(ymid, ry1));
            }
            if (rx1 > xmid && ry0 <= ymid)
            {
                total += ur->get(max(rx0, xmid + 1), ry0, rx1, min(ymid, ry1));
            }
            if (rx0 <= xmid && ry1 > ymid)
            {
                total += dl->get(rx0, max(ymid + 1, ry0), min(rx1, xmid), ry1);
            }
            if (rx1 > xmid && ry1 > ymid)
            {
                total += dr->get(max(rx0, xmid + 1), max(ry0, ymid + 1), rx1, ry1);
            }

            return total;
        }

    private:
        // Midpoints, written in overflow-safe form; the lower half includes the midpoint.
        int getMidX() const { return x0 + (x1 - x0) / 2; }
        int getMidY() const { return y0 + (y1 - y0) / 2; }

    private:
        long long sum;   // sum of all cells in [x0, x1] x [y0, y1]
        int x0, x1;      // inclusive column range
        int y0, y1;      // inclusive row range

        Node *ul;        // owned; up-left quadrant
        Node *ur;        // owned; up-right quadrant, null if no columns past the midpoint
        Node *dl;        // owned; down-left quadrant, null if no rows past the midpoint
        Node *dr;        // owned; down-right quadrant, null unless both halves exist
    };
    
    // Mutable 2D range-sum structure backed by a Node quad-tree.
    // Public coordinates are (row, col); Node uses (x, y) = (col, row).
    class NumMatrix
    {
        // Root of the tree, owned by this object. It stays null when the matrix
        // has no cells, which turns update/sumRegion into no-ops.
        // (Previously it was left uninitialized for an empty matrix, which is UB.)
        Node *pSeg = nullptr;
    public:
        NumMatrix(vector<vector<int>> &matrix)
        {
            // No rows or no columns: nothing to index, so keep pSeg null rather
            // than building a degenerate root that queries would dereference into.
            if (matrix.empty() || matrix[0].empty()) return;

            const int h = static_cast<int>(matrix.size());
            const int w = static_cast<int>(matrix[0].size());
            pSeg = new Node(matrix, 0, 0, w - 1, h - 1);
        }

        ~NumMatrix() { delete pSeg; }

        // pSeg is an owning raw pointer; copies would double-delete it.
        NumMatrix(const NumMatrix &) = delete;
        NumMatrix &operator=(const NumMatrix &) = delete;

        // Sets matrix[row][col] = val.
        void update(int row, int col, int val)
        {
            if (pSeg)
                pSeg->update(col, row, val);
        }

        // Sum of the cells with row in [row1, row2] and col in [col1, col2].
        // Returns 0 for an empty matrix. The result is narrowed to int to match
        // this method's return type.
        int sumRegion(int row1, int col1, int row2, int col2)
        {
            if (pSeg)
                return static_cast<int>(pSeg->get(col1, row1, col2, row2));
            return 0;
        }
    };

  • 0

    I don't think sumRegion is $O(\log_4(mn))$. Imagine $m = n = 2^e$ and sumRegion(1, 1, m-2, n-2). Aren't you going all the way down to the roughly $2m + 2n$ single-cell regions on the border of that region?


  • -1
    T

    Thanks. Yes, you are right: the worst case should be O(m + n). I have updated my post.


  • 0

    Quad-Tree != 2D segment tree.
    http://codeforces.com/blog/entry/16363


Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.