Java Segment Tree Method


  • 0

    I'd appreciate any suggestions.
    I recursively divide the matrix into four quadrants (a 2D segment tree, i.e. a quadtree) and store the sum of each region in its node.
    ```java
    /**
     * Quadtree node covering the inclusive cell rectangle [row1..row2] x [col1..col2].
     * {@code sum} holds the total stored for that rectangle; the four child quadrants start out
     * null and are attached by whoever builds the tree.
     */
    class TreeNode {
        // Inclusive bounds of the covered rectangle.
        int row1, col1, row2, col2;
        // Cached total for the rectangle.
        int sum;
        // Child quadrants (null until assigned).
        TreeNode topLeft, topRight, bottomLeft, bottomRight;

        TreeNode(int row1, int col1, int row2, int col2, int sum) {
            this.sum = sum;
            this.row1 = row1;
            this.row2 = row2;
            this.col1 = col1;
            this.col2 = col2;
            // Child references keep their default value of null.
        }
    }

    /**
     * Mutable 2D range-sum structure backed by a quadtree: every node covers a rectangle of cells
     * and caches their sum, and every internal node splits its rectangle into up to four quadrants
     * at its middle row and middle column.
     *
     * <p>Assumes a rectangular matrix (every row as long as {@code matrix[0]}).
     */
    public class NumMatrix {

        /** Quadtree node covering the inclusive cell rectangle [row1..row2] x [col1..col2]. */
        private static final class Node {
            final int row1;
            final int col1;
            final int row2;
            final int col2;
            /** Sum of every cell inside this node's rectangle. */
            int sum;
            /** Child quadrants; a child is null when its quadrant would be empty. */
            Node topLeft;
            Node topRight;
            Node bottomLeft;
            Node bottomRight;

            Node(int row1, int col1, int row2, int col2) {
                this.row1 = row1;
                this.col1 = col1;
                this.row2 = row2;
                this.col2 = col2;
            }

            /** True when this node covers exactly one cell. */
            boolean isLeaf() {
                return row1 == row2 && col1 == col2;
            }

            /** True when cell (row, col) lies inside this node's rectangle. */
            boolean contains(int row, int col) {
                return row1 <= row && row <= row2 && col1 <= col && col <= col2;
            }
        }

        /** Root covering the whole matrix, or null when the matrix is null or empty. */
        private final Node root;

        /**
         * Builds the quadtree over {@code matrix}. The values are copied into the tree, so later
         * changes to the array are not reflected.
         *
         * @param matrix the initial values; {@code null} or empty yields an empty structure whose
         *     region sums are always 0
         */
        public NumMatrix(int[][] matrix) {
            if (matrix == null || matrix.length == 0 || matrix[0].length == 0) {
                root = null;
            } else {
                root = build(matrix, 0, 0, matrix.length - 1, matrix[0].length - 1);
            }
        }

        /**
         * Sets cell ({@code row}, {@code col}) to {@code val}. Coordinates outside the matrix are
         * ignored.
         */
        public void update(int row, int col, int val) {
            modify(root, row, col, val);
        }

        /**
         * Returns the sum of the cells in the inclusive rectangle from ({@code row1}, {@code col1})
         * to ({@code row2}, {@code col2}). Parts of the rectangle outside the matrix contribute
         * nothing; an inverted rectangle (row1 > row2 or col1 > col2) sums to 0.
         */
        public int sumRegion(int row1, int col1, int row2, int col2) {
            return query(root, row1, col1, row2, col2);
        }

        /** Recursively builds the subtree for the inclusive rectangle; null if it is empty. */
        private static Node build(int[][] matrix, int row1, int col1, int row2, int col2) {
            if (row1 > row2 || col1 > col2) {
                return null;
            }
            Node node = new Node(row1, col1, row2, col2);
            if (node.isLeaf()) {
                node.sum = matrix[row1][col1];
                return node;
            }
            int midRow = (row1 + row2) >>> 1;
            int midCol = (col1 + col2) >>> 1;
            node.topLeft = build(matrix, row1, col1, midRow, midCol);
            node.topRight = build(matrix, row1, midCol + 1, midRow, col2);
            node.bottomLeft = build(matrix, midRow + 1, col1, row2, midCol);
            node.bottomRight = build(matrix, midRow + 1, midCol + 1, row2, col2);
            node.sum = childSum(node);
            return node;
        }

        /** Writes {@code value} into the leaf for (row, col) and refreshes the sums on its path. */
        private static void modify(Node node, int row, int col, int value) {
            if (node == null || !node.contains(row, col)) {
                return; // cell is outside this subtree: nothing to do
            }
            if (node.isLeaf()) {
                node.sum = value;
                return;
            }
            // Descend into the one quadrant holding the cell, using the same split as build().
            int midRow = (node.row1 + node.row2) >>> 1;
            int midCol = (node.col1 + node.col2) >>> 1;
            boolean top = row <= midRow;
            boolean left = col <= midCol;
            Node child = top
                    ? (left ? node.topLeft : node.topRight)
                    : (left ? node.bottomLeft : node.bottomRight);
            modify(child, row, col, value);
            node.sum = childSum(node);
        }

        /** Returns the sum of the cells where the query rectangle overlaps {@code node}'s subtree. */
        private static int query(Node node, int row1, int col1, int row2, int col2) {
            if (node == null) {
                return 0;
            }
            // Clip the query to this node's rectangle so children only ever see in-bounds ranges.
            int top = Math.max(row1, node.row1);
            int left = Math.max(col1, node.col1);
            int bottom = Math.min(row2, node.row2);
            int right = Math.min(col2, node.col2);
            if (top > bottom || left > right) {
                return 0; // no overlap (this also covers inverted queries)
            }
            if (top == node.row1 && left == node.col1 && bottom == node.row2 && right == node.col2) {
                return node.sum; // node fully covered: use its cached sum
            }
            // Partial overlap can only happen at an internal node: a single-cell leaf is always
            // either fully covered or disjoint, so this recursion terminates.
            return query(node.topLeft, top, left, bottom, right)
                    + query(node.topRight, top, left, bottom, right)
                    + query(node.bottomLeft, top, left, bottom, right)
                    + query(node.bottomRight, top, left, bottom, right);
        }

        /** Sum of the cached sums of {@code node}'s children, treating missing children as 0. */
        private static int childSum(Node node) {
            return sumOf(node.topLeft) + sumOf(node.topRight)
                    + sumOf(node.bottomLeft) + sumOf(node.bottomRight);
        }

        /** Cached sum of {@code node}, or 0 for a null node. */
        private static int sumOf(Node node) {
            return node == null ? 0 : node.sum;
        }
    }

    // Your NumMatrix object will be instantiated and called as such:
    // NumMatrix numMatrix = new NumMatrix(matrix);
    // numMatrix.sumRegion(0, 1, 2, 3);
    // numMatrix.update(1, 1, 10);
    // numMatrix.sumRegion(1, 2, 3, 4);
    ```


  • 0

    Here is the same code, correctly formatted:

    /**
     * Quadtree node covering the inclusive cell rectangle [row1..row2] x [col1..col2].
     * {@code sum} holds the total stored for that rectangle; the four child quadrants start out
     * null and are attached by whoever builds the tree.
     */
    class TreeNode {
        // Inclusive bounds of the covered rectangle.
        int row1, col1, row2, col2;
        // Cached total for the rectangle.
        int sum;
        // Child quadrants (null until assigned).
        TreeNode topLeft, topRight, bottomLeft, bottomRight;

        TreeNode(int row1, int col1, int row2, int col2, int sum) {
            this.sum = sum;
            this.row1 = row1;
            this.row2 = row2;
            this.col1 = col1;
            this.col2 = col2;
            // Child references keep their default value of null.
        }
    }
    
    /**
     * Mutable 2D range-sum structure backed by a quadtree: every node covers a rectangle of cells
     * and caches their sum, and every internal node splits its rectangle into up to four quadrants
     * at its middle row and middle column.
     *
     * <p>Assumes a rectangular matrix (every row as long as {@code matrix[0]}).
     */
    public class NumMatrix {

        /** Quadtree node covering the inclusive cell rectangle [row1..row2] x [col1..col2]. */
        private static final class Node {
            final int row1;
            final int col1;
            final int row2;
            final int col2;
            /** Sum of every cell inside this node's rectangle. */
            int sum;
            /** Child quadrants; a child is null when its quadrant would be empty. */
            Node topLeft;
            Node topRight;
            Node bottomLeft;
            Node bottomRight;

            Node(int row1, int col1, int row2, int col2) {
                this.row1 = row1;
                this.col1 = col1;
                this.row2 = row2;
                this.col2 = col2;
            }

            /** True when this node covers exactly one cell. */
            boolean isLeaf() {
                return row1 == row2 && col1 == col2;
            }

            /** True when cell (row, col) lies inside this node's rectangle. */
            boolean contains(int row, int col) {
                return row1 <= row && row <= row2 && col1 <= col && col <= col2;
            }
        }

        /** Root covering the whole matrix, or null when the matrix is null or empty. */
        private final Node root;

        /**
         * Builds the quadtree over {@code matrix}. The values are copied into the tree, so later
         * changes to the array are not reflected.
         *
         * @param matrix the initial values; {@code null} or empty yields an empty structure whose
         *     region sums are always 0
         */
        public NumMatrix(int[][] matrix) {
            if (matrix == null || matrix.length == 0 || matrix[0].length == 0) {
                root = null;
            } else {
                root = build(matrix, 0, 0, matrix.length - 1, matrix[0].length - 1);
            }
        }

        /**
         * Sets cell ({@code row}, {@code col}) to {@code val}. Coordinates outside the matrix are
         * ignored.
         */
        public void update(int row, int col, int val) {
            modify(root, row, col, val);
        }

        /**
         * Returns the sum of the cells in the inclusive rectangle from ({@code row1}, {@code col1})
         * to ({@code row2}, {@code col2}). Parts of the rectangle outside the matrix contribute
         * nothing; an inverted rectangle (row1 > row2 or col1 > col2) sums to 0.
         */
        public int sumRegion(int row1, int col1, int row2, int col2) {
            return query(root, row1, col1, row2, col2);
        }

        /** Recursively builds the subtree for the inclusive rectangle; null if it is empty. */
        private static Node build(int[][] matrix, int row1, int col1, int row2, int col2) {
            if (row1 > row2 || col1 > col2) {
                return null;
            }
            Node node = new Node(row1, col1, row2, col2);
            if (node.isLeaf()) {
                node.sum = matrix[row1][col1];
                return node;
            }
            int midRow = (row1 + row2) >>> 1;
            int midCol = (col1 + col2) >>> 1;
            node.topLeft = build(matrix, row1, col1, midRow, midCol);
            node.topRight = build(matrix, row1, midCol + 1, midRow, col2);
            node.bottomLeft = build(matrix, midRow + 1, col1, row2, midCol);
            node.bottomRight = build(matrix, midRow + 1, midCol + 1, row2, col2);
            node.sum = childSum(node);
            return node;
        }

        /** Writes {@code value} into the leaf for (row, col) and refreshes the sums on its path. */
        private static void modify(Node node, int row, int col, int value) {
            if (node == null || !node.contains(row, col)) {
                return; // cell is outside this subtree: nothing to do
            }
            if (node.isLeaf()) {
                node.sum = value;
                return;
            }
            // Descend into the one quadrant holding the cell, using the same split as build().
            int midRow = (node.row1 + node.row2) >>> 1;
            int midCol = (node.col1 + node.col2) >>> 1;
            boolean top = row <= midRow;
            boolean left = col <= midCol;
            Node child = top
                    ? (left ? node.topLeft : node.topRight)
                    : (left ? node.bottomLeft : node.bottomRight);
            modify(child, row, col, value);
            node.sum = childSum(node);
        }

        /** Returns the sum of the cells where the query rectangle overlaps {@code node}'s subtree. */
        private static int query(Node node, int row1, int col1, int row2, int col2) {
            if (node == null) {
                return 0;
            }
            // Clip the query to this node's rectangle so children only ever see in-bounds ranges.
            int top = Math.max(row1, node.row1);
            int left = Math.max(col1, node.col1);
            int bottom = Math.min(row2, node.row2);
            int right = Math.min(col2, node.col2);
            if (top > bottom || left > right) {
                return 0; // no overlap (this also covers inverted queries)
            }
            if (top == node.row1 && left == node.col1 && bottom == node.row2 && right == node.col2) {
                return node.sum; // node fully covered: use its cached sum
            }
            // Partial overlap can only happen at an internal node: a single-cell leaf is always
            // either fully covered or disjoint, so this recursion terminates.
            return query(node.topLeft, top, left, bottom, right)
                    + query(node.topRight, top, left, bottom, right)
                    + query(node.bottomLeft, top, left, bottom, right)
                    + query(node.bottomRight, top, left, bottom, right);
        }

        /** Sum of the cached sums of {@code node}'s children, treating missing children as 0. */
        private static int childSum(Node node) {
            return sumOf(node.topLeft) + sumOf(node.topRight)
                    + sumOf(node.bottomLeft) + sumOf(node.bottomRight);
        }

        /** Cached sum of {@code node}, or 0 for a null node. */
        private static int sumOf(Node node) {
            return node == null ? 0 : node.sum;
        }
    }
    
    
    // Your NumMatrix object will be instantiated and called as such:
    // NumMatrix numMatrix = new NumMatrix(matrix);
    // numMatrix.sumRegion(0, 1, 2, 3);
    // numMatrix.update(1, 1, 10);
    // numMatrix.sumRegion(1, 2, 3, 4);
    

Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.