Java Segment Tree Solution with Explanation


  • 4
    A

    First, build a segment tree that implements the build, update, and range-sum operations.
    The tree is organized over the indices of the data array: each tree node holds the index range it covers and the sum of the elements in that range.
    For the build operation, use divide and conquer: split the range into start~mid and mid+1~end and calculate the sums recursively. This performs N/2 + N/4 + N/8 + ... + N/N = N-1 merge steps (one per internal node) to build a binary segment tree. A leaf node has the same start and end index.
    For the update operation, locate the leaf the same way you would search a binary search tree, then update the sum values along the path back to the root; the time complexity is O(log N).
    For the range operation, use divide and conquer with 4 cases:
    1. The range exactly matches the node's range: return the node's sum.
    2. The range lies entirely within node.left (end <= node.mid): operate on node.left recursively.
    3. The range lies entirely within node.right (start > node.mid): operate on node.right recursively.
    4. The range spans both node.left and node.right: return sum(node.left, start, mid) + sum(node.right, mid+1, end), computed recursively.

    public class NumArray {
        public class TreeNode {
            int start = 0;
            int end = 0;
            int sum = 0;
    
            TreeNode left = null;
            TreeNode right = null;
        }
    
        private TreeNode root = null;
        private int[] data = null;
    
        public void init(int[] data) {
            if (data == null || data.length == 0) return;
            this.data = data;
            this.root = build(0, data.length - 1);
        }
    
        private TreeNode build(int start, int end) {
            TreeNode node = new TreeNode();
            node.start = start;
            node.end = end;
    
            if (start == end) {
                node.sum = data[start];
                return node;
            }
    
            int mid = start + (end - start) / 2;
            node.left = build(start, mid);
            node.right = build(mid + 1, end);
    
            node.sum = node.left.sum + node.right.sum;
            return node;
        }
    
        private void update(TreeNode node, int index, int num) {
            if (node == null) return;
            if (node.start == node.end) {
                node.sum = num;
                return;
            }
            int mid = node.start + (node.end - node.start) / 2;
            if (index <= mid) {
                update(node.left, index, num);
            } else {
                update(node.right, index, num);
            }
            node.sum = node.left.sum + node.right.sum;
        }
    
        private int getSum(TreeNode node, int start, int end) {
             if (start < node.start || end > node.end) {
                return -1;
            }
    
            if (start == node.start && end == node.end) {
                return node.sum;
            }
    
            int mid = node.start + (node.end - node.start) / 2;
    
            if (start > mid) {
                return getSum(node.right, start, end);
            }
            if (end <= mid) {
                return getSum(node.left, start, end);
            }
            
            return getSum(node.left, start, mid) + getSum(node.right, mid+1, end);
        }
    
        public int getSum(int start, int end) {
            return getSum(root, start, end);
        }
        
        public NumArray(int[] nums) {
           init(nums);
        }
    
        void update(int i, int val) {
            update(root, i, val);
        }
    
        public int sumRange(int i, int j) {
            return getSum(i, j);
        }
    }

Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.