Java SegmentTree Solution, 36ms


  • 19

    I understand my segment tree implementation is not optimized.
    Please feel free to give me suggestions.

    import java.util.*;

    public class Solution {
        // A node covers the sorted prefix sums valArr[low..high], i.e. the value range [min, max].
        class SegmentTreeNode {
            SegmentTreeNode left;
            SegmentTreeNode right;
            int count;  // how many inserted prefix sums fall inside [min, max]
            long min;
            long max;
            public SegmentTreeNode(long min, long max) {
                this.min = min;
                this.max = max;
            }
        }
        // Builds a balanced tree over the sorted, distinct values valArr[low..high].
        private SegmentTreeNode buildSegmentTree(Long[] valArr, int low, int high) {
            if(low > high) return null;
            SegmentTreeNode stn = new SegmentTreeNode(valArr[low], valArr[high]);
            if(low == high) return stn;
            int mid = (low + high)/2;
            stn.left = buildSegmentTree(valArr, low, mid);
            stn.right = buildSegmentTree(valArr, mid+1, high);
            return stn;
        }
        // Inserts one occurrence of val: increments count on every node whose range contains val.
        private void updateSegmentTree(SegmentTreeNode stn, Long val) {
            if(stn == null) return;
            if(val >= stn.min && val <= stn.max) {
                stn.count++;
                updateSegmentTree(stn.left, val);
                updateSegmentTree(stn.right, val);
            }
        }
        // Returns how many inserted values lie in [min, max].
        private int getCount(SegmentTreeNode stn, long min, long max) {
            if(stn == null) return 0;
            if(min > stn.max || max < stn.min) return 0;            // no overlap
            if(min <= stn.min && max >= stn.max) return stn.count;  // node fully inside the query
            return getCount(stn.left, min, max) + getCount(stn.right, min, max);
        }
    
        public int countRangeSum(int[] nums, int lower, int upper) {
    
            if(nums == null || nums.length == 0) return 0;
            int ans = 0;
            Set<Long> valSet = new HashSet<Long>();
            long sum = 0;
            for(int i = 0; i < nums.length; i++) {
                sum += (long) nums[i];
                valSet.add(sum);
            }
    
            Long[] valArr = valSet.toArray(new Long[0]);
    
            Arrays.sort(valArr);
            SegmentTreeNode root = buildSegmentTree(valArr, 0, valArr.length-1);
    
            for(int i = nums.length-1; i >=0; i--) {
                updateSegmentTree(root, sum);
                sum -= (long) nums[i];
                ans += getCount(root, (long)lower+sum, (long)upper+sum);
            }
            return ans;
        }
        
    }

  • 0

    Could you give a more detailed explanation?
    Thanks in advance!


  • 0

    Could you give some explanation, please?


  • 18

    This is a really great way to solve this problem. However, without any comments, this code took me ages to figure out.
    Let's start from the beginning and understand this code step by step.
    The problem asks for the number of range sums sum(i, j) that lie in [lower, upper]. The straightforward approach checks every pair (i, j):

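    int count = 0;
    for (int i = 0; i < nums.length; i++) {
        long sum = 0;                            // sum(i, j), extended one element at a time
        for (int j = i; j < nums.length; j++) {
            sum += nums[j];
            if (sum >= lower && sum <= upper) count++;
        }
    }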
    

    This method solves the problem easily but gets TLE because of its O(n^2) time complexity.
    To avoid this, the solution uses a segment tree to skip unnecessary checks, which brings the time down to O(n log n). Why does n^2 become n log n? The segment tree is a balanced binary tree over the sorted distinct prefix sums, so each of the n updates and n range-count queries touches only O(log n) nodes instead of scanning all n values.
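    The speed-up rests on one identity that both solutions in this thread rely on but never state outright. With prefix sums P[k] = nums[0] + ... + nums[k - 1] (so P[0] = 0 and P[j + 1] = sum(0, j) in the notation above), every range sum is sum(i, j) = P[j + 1] - P[i]. Here is a minimal O(n^2) sketch of the same count written with that identity; the class and method names are made up for illustration. The segment tree replaces its inner loop with a single O(log n) query.

```java
import java.util.Arrays;

// Illustrative sketch: O(n^2) counting via prefix sums, P[k] = nums[0] + ... + nums[k - 1].
class PrefixSumBruteForce {
    static long[] prefixSums(int[] nums) {
        long[] p = new long[nums.length + 1]; // p[0] = 0 is the empty prefix
        for (int k = 0; k < nums.length; k++) {
            p[k + 1] = p[k] + nums[k];         // long arithmetic avoids int overflow
        }
        return p;
    }

    static int countRangeSum(int[] nums, int lower, int upper) {
        long[] p = prefixSums(nums);
        int count = 0;
        for (int i = 0; i < nums.length; i++) {
            for (int j = i; j < nums.length; j++) {
                long rangeSum = p[j + 1] - p[i]; // sum(i, j)
                if (rangeSum >= lower && rangeSum <= upper) count++;
            }
        }
        return count;
    }

    public static void main(String[] args) {
        int[] nums = {-2, 5, -1};
        System.out.println(Arrays.toString(prefixSums(nums))); // [0, -2, 3, 2]
        System.out.println(countRangeSum(nums, -2, 2));         // 3
    }
}
```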

    Before going any further, you need to know what a segment tree is. It is a brilliant data structure commonly used to answer range queries, such as finding the minimum value within a range, in O(log n) time. In this solution, each node stores a count instead of a minimum.
    It is hard to explain in a few lines, so check this link out. Believe me, it is really fun to watch, and the instructor in the video is excellent at explaining things you do not know.
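    Since the original post asks for suggestions, here is one possible alternative layout. It is a sketch based on my own assumptions, not code from the post. It stores the counts in a flat int array indexed like a heap (the children of node k are 2k and 2k + 1) instead of linked nodes, and it maps each prefix sum to its position in the sorted distinct array with binary search. The class name ArrayCountSegmentTree and its methods are hypothetical. It also inserts the empty prefix P[0] = 0, so ranges starting at index 0 need no special case.

```java
import java.util.Arrays;

// Alternative sketch (not the original poster's code): an array-based count segment tree.
// Leaf k stands for vals[k], the k-th smallest distinct prefix sum; tree[node] holds how many
// inserted prefix sums fall inside that node's range of leaves.
class ArrayCountSegmentTree {
    private final long[] vals;
    private final int[] tree;

    ArrayCountSegmentTree(long[] sortedDistinct) {
        vals = sortedDistinct;
        tree = new int[4 * Math.max(1, vals.length)]; // 4n slots suffice for a recursive tree
    }

    // Insert one occurrence of v; v must be an element of vals.
    void insert(long v) {
        insert(1, 0, vals.length - 1, Arrays.binarySearch(vals, v));
    }

    private void insert(int node, int lo, int hi, int idx) {
        tree[node]++;                                   // this node's range contains leaf idx
        if (lo == hi) return;
        int mid = (lo + hi) >>> 1;
        if (idx <= mid) insert(2 * node, lo, mid, idx); // descend into exactly one child
        else insert(2 * node + 1, mid + 1, hi, idx);
    }

    // Number of inserted values v with min <= v <= max.
    int count(long min, long max) {
        int from = firstIndex(min, false);  // first k with vals[k] >= min
        int to = firstIndex(max, true) - 1; // last k with vals[k] <= max
        return from > to ? 0 : count(1, 0, vals.length - 1, from, to);
    }

    // Binary search: first k with vals[k] > x (strict) or vals[k] >= x (non-strict).
    private int firstIndex(long x, boolean strict) {
        int lo = 0, hi = vals.length;
        while (lo < hi) {
            int mid = (lo + hi) >>> 1;
            boolean goRight = strict ? vals[mid] <= x : vals[mid] < x;
            if (goRight) lo = mid + 1; else hi = mid;
        }
        return lo;
    }

    private int count(int node, int lo, int hi, int from, int to) {
        if (to < lo || hi < from) return 0;            // no overlap
        if (from <= lo && hi <= to) return tree[node]; // node fully inside the query
        int mid = (lo + hi) >>> 1;
        return count(2 * node, lo, mid, from, to) + count(2 * node + 1, mid + 1, hi, from, to);
    }

    // Forward loop over all prefix sums P[0..n], including the empty prefix P[0] = 0.
    static int countRangeSum(int[] nums, int lower, int upper) {
        long[] p = new long[nums.length + 1];
        for (int k = 0; k < nums.length; k++) p[k + 1] = p[k] + nums[k];
        ArrayCountSegmentTree st = new ArrayCountSegmentTree(Arrays.stream(p).sorted().distinct().toArray());
        int ans = 0;
        for (long cur : p) {
            ans += st.count(cur - upper, cur - lower); // earlier q with lower <= cur - q <= upper
            st.insert(cur);
        }
        return ans;
    }
}
```

    Each insert walks one root-to-leaf path, and each count visits O(log n) nodes, so the total stays O(n log n). The flat array mainly avoids allocating one object per node.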

    OK, let's start with the main function:

    public int countRangeSum(int[] nums, int lower, int upper) {
    
            if(nums == null || nums.length == 0) return 0;
            int ans = 0;
    // Question 1: why are we using a set here?
            Set<Long> valSet = new HashSet<Long>();
            /**
             * The set only collects the distinct prefix-sum values, which become the leaves of the tree.
             * Duplicates are not lost: updateSegmentTree is called once per prefix sum, so a value that
             * occurs twice increments the counts twice.
             */
            long sum = 0; 
    // Use long to prevent overflow. 
            for(int i = 0; i < nums.length; i++) {
                sum += (long) nums[i];
    // The (long) cast is optional because sum is already a long; declaring sum as long is what prevents overflow (an int accumulator would give wrong answers).
                valSet.add(sum);
            }
    // valSet now contains every prefix sum sum(0, j) for j from 0 to nums.length - 1.
    
            Long[] valArr = valSet.toArray(new Long[0]);
    // Use the boxed type Long here: Set<Long>.toArray(T[]) only produces an object array, so long[] does not work.
    
            Arrays.sort(valArr);
    // You must sort here: buildSegmentTree splits the sorted array, so each node covers a contiguous value range [min, max]. Without sorting, the node ranges are wrong.
    
            SegmentTreeNode root = buildSegmentTree(valArr, 0, valArr.length-1);
            /**
             * Before diving into "buildSegmentTree", you can imagine the tree like this:
             * it is a binary tree in which each node covers a range of values [min, max].
             * The "min" of a parent node is the smallest "min" among its children, and
             * its "max" is the largest "max" among its children.
             * Every boundary value is one of the prefix sums sum(0, j); values strictly between
             * min and max do not necessarily correspond to any prefix sum.
             * Each node also has a "count" field: how many of the prefix sums inserted so far fall in [min, max].
             */
    
            for(int i = nums.length-1; i >=0; i--) {
                updateSegmentTree(root, sum);
    /**
                 * Core part 1: "updateSegmentTree" adds 1 to the count of every node whose range contains sum(0, i).
                 * How?
                 * Each leaf of the segment tree holds a single value [sum(0, j), sum(0, j)] for one distinct prefix sum,
                 * with j between 0 and nums.length - 1, so searching from the root always reaches the leaf for sum(0, i).
                 * On the way down, every node's count is increased by 1, because by definition
                 * each node on that path contains the leaf's value.
                 */
                sum -= (long) nums[i];
    /**
                 * Core part 2: why subtract nums[i] here?
                 * After the subtraction, sum is sum(0, i - 1) (or 0 when i == 0), which the query below uses as its base.
                 */
                ans += getCount(root, (long)lower+sum, (long)upper+sum);
    /**
                 * Core part 3: why sum + lower and sum + upper?
                 * After core part 2, sum is sum(0, i - 1), and it serves as a base.
                 * At this point the tree holds sum(0, j) for every j >= i (inserted in this or an earlier iteration).
                 * Any range starting at i can be written as sum(i, j) = sum(0, j) - sum(0, i - 1), so
                 *     lower <= sum(i, j) <= upper   <=>   sum + lower <= sum(0, j) <= sum + upper.
                 * getCount therefore returns the number of valid ranges (i, j) that start at index i.
                 * To see it, imagine sum is 0 (the case i == 0): getCount(root, 0 + lower, 0 + upper)
                 * counts the valid ranges (0, j).
                 * For i > 0, adding the base shifts the window so that only ranges starting at i are counted.
                 * Summing over all i counts every valid range exactly once.
                 */
            }
            return ans;
        }
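    To make core parts 1 to 3 concrete, here is a small trace of the backward loop on nums = [-2, 5, -1], lower = -2, upper = 2, for which the expected answer is 3. To keep the output readable, this sketch swaps the segment tree for a plain list and counts matches by scanning it. That makes it O(n^2), so it is only meant for illustration. BackwardLoopTrace is a made-up name.

```java
import java.util.ArrayList;
import java.util.List;

// Illustrative sketch: replays the backward loop of countRangeSum, with a plain list
// standing in for the segment tree so each step's query window and count are visible.
class BackwardLoopTrace {
    static List<String> trace(int[] nums, int lower, int upper) {
        List<String> steps = new ArrayList<>();
        List<Long> inserted = new ArrayList<>(); // stands in for the segment tree counts
        long sum = 0;
        for (int num : nums) sum += num;         // sum = sum(0, n - 1)
        for (int i = nums.length - 1; i >= 0; i--) {
            inserted.add(sum);                   // like updateSegmentTree(root, sum): adds sum(0, i)
            sum -= nums[i];                      // sum is now sum(0, i - 1), or 0 when i == 0
            long lo = lower + sum, hi = upper + sum;
            int found = 0;
            for (long v : inserted) if (v >= lo && v <= hi) found++; // like getCount(root, lo, hi)
            steps.add("i=" + i + " base=" + sum + " window=[" + lo + ", " + hi + "] count=" + found);
        }
        return steps;
    }

    public static void main(String[] args) {
        trace(new int[]{-2, 5, -1}, -2, 2).forEach(System.out::println);
    }
}
```

    Running main prints:
    i=2 base=3 window=[1, 5] count=1
    i=1 base=-2 window=[-4, 0] count=0
    i=0 base=0 window=[-2, 2] count=2
    The counts add up to 3, namely the ranges (2, 2), (0, 0) and (0, 2).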
    

    You will understand the following two functions once you understand how a segment tree is implemented.

    private SegmentTreeNode buildSegmentTree(Long[] valArr, int low, int high) {
        // body as in the original post above
    }
    private int getCount(SegmentTreeNode stn, long min, long max) {
        // body as in the original post above
    }
    

    The following function updates the "count" field.
    Nothing tricky here.

        private void updateSegmentTree(SegmentTreeNode stn, Long val) {
            // body as in the original post above
        }

  • 2

    Please correct me if I am wrong, but I feel like this segment tree doesn't improve the time complexity.
    In updateSegmentTree and getCount, we still traverse all n nodes, so each call is O(n) instead of O(log n).


  • 1

    @yu.xiaoxixi The update function will not traverse all n nodes. It only visits about lg(n) nodes, the height of the tree. This is because the interval of the left child and the interval of the right child do not overlap, so at each node the value only propagates into the left child or the right child.


  • 1

    @fangxuke19 Yes, a segment tree normally works as you said: it goes to the left child OR the right child at each node.
    But in his segment tree update function, he goes down BOTH the left and right sides:
    private void updateSegmentTree(SegmentTreeNode stn, Long val) {
        if(stn == null) return;
        if(val >= stn.min && val <= stn.max) {
            stn.count++;
            updateSegmentTree(stn.left, val);
            updateSegmentTree(stn.right, val);
        }
    }


  • 1

    @yu.xiaoxixi But in the next recursion, either the right half or the left half terminates immediately. Because the interval for the left child and the interval for the right child do not overlap, the value falls into either the right child or the left child. Correct me if I am wrong.


  • 0

    @fangxuke19 Yes, I think you are right. I was wrong...
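    To check the conclusion of this exchange empirically, here is a sketch that copies the build and update logic from the original post into a standalone class and counts every call to update, including calls that return immediately on a null child or an out-of-range child. UpdateVisitCounter and maxCallsPerUpdate are hypothetical names added only for this measurement.

```java
// Illustrative sketch: the original poster's build/update logic plus a call counter.
class UpdateVisitCounter {
    static class Node {
        Node left, right;
        int count;
        final long min, max;
        Node(long min, long max) { this.min = min; this.max = max; }
    }

    private int calls; // every invocation, including ones that return immediately

    Node build(long[] vals, int low, int high) {
        if (low > high) return null;
        Node n = new Node(vals[low], vals[high]);
        if (low == high) return n;
        int mid = (low + high) / 2;
        n.left = build(vals, low, mid);
        n.right = build(vals, mid + 1, high);
        return n;
    }

    void update(Node n, long val) {
        calls++;
        if (n == null) return;
        if (val >= n.min && val <= n.max) {
            n.count++;
            update(n.left, val);  // the children's ranges are disjoint, so at most one of
            update(n.right, val); // these two calls passes the range check and recurses further
        }
    }

    // Largest number of update calls needed to insert any one of n sorted, distinct values.
    static int maxCallsPerUpdate(int n) {
        long[] vals = new long[n];
        for (int k = 0; k < n; k++) vals[k] = 3L * k - 500;
        UpdateVisitCounter counter = new UpdateVisitCounter();
        Node root = counter.build(vals, 0, n - 1);
        int worst = 0;
        for (long v : vals) {
            counter.calls = 0;
            counter.update(root, v);
            worst = Math.max(worst, counter.calls);
        }
        return worst;
    }

    public static void main(String[] args) {
        for (int n : new int[]{1, 1000, 1024, 1 << 16}) {
            System.out.println("n = " + n + ": at most " + maxCallsPerUpdate(n) + " calls per update");
        }
    }
}
```

    For n = 1, 1024 and 65536, this reports 3, 23 and 35 calls, i.e. 2 log2(n) + 3 when n is a power of two. The cost grows with the height of the tree, not with n.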


  • 3

    I find your backward last loop a little hard to understand.
    Here I modified it to loop forward, which is more straightforward:

    Because we need to find every previous prefix sum such that

    lower <= currentPrefixSum - previousPrefixSum <= upper
    ========>
    currentPrefixSum-upper <= previousPrefixSum <= currentPrefixSum-lower

    updateSegmentTree(root, sum) adds the previous prefix sum to the counts.
    Then we query the segment tree to get how many previous prefix sums are within the required range.

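    // Replaces the backward loop in countRangeSum (sum, root and ans are declared there).
    ans += nums[0] >= lower && nums[0] <= upper ? 1 : 0;          // range (0, 0)
    sum = nums[0];
    for(int i = 1; i<nums.length; i++) {
        updateSegmentTree(root, sum);                              // insert sum(0, i - 1) as a previous prefix sum
        sum += nums[i];                                            // sum is now sum(0, i)
        ans += sum >= lower && sum <= upper ? 1 : 0;               // range (0, i)
        ans += getCount(root, (long)sum-upper, (long)sum-lower);   // ranges (k, i) with 1 <= k <= i
    }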
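    For anyone who wants to run the forward version end to end, here is a sketch that combines the original poster's node-based tree with the forward loop from this reply. The class name is hypothetical, and the tree code is condensed from the original post (update returns early when the value is outside a node's range, which behaves the same).

```java
import java.util.TreeSet;

// Illustrative sketch: original node-based segment tree plus the forward loop, in one class.
class ForwardSegmentTreeSolution {
    static class Node {
        Node left, right;
        int count;            // inserted prefix sums inside [min, max]
        final long min, max;
        Node(long min, long max) { this.min = min; this.max = max; }
    }

    private Node build(long[] vals, int low, int high) {
        if (low > high) return null;
        Node n = new Node(vals[low], vals[high]);
        if (low == high) return n;
        int mid = (low + high) / 2;
        n.left = build(vals, low, mid);
        n.right = build(vals, mid + 1, high);
        return n;
    }

    private void update(Node n, long val) {
        if (n == null || val < n.min || val > n.max) return;
        n.count++;
        update(n.left, val);
        update(n.right, val);
    }

    private int getCount(Node n, long min, long max) {
        if (n == null || min > n.max || max < n.min) return 0;
        if (min <= n.min && max >= n.max) return n.count;
        return getCount(n.left, min, max) + getCount(n.right, min, max);
    }

    public int countRangeSum(int[] nums, int lower, int upper) {
        if (nums == null || nums.length == 0) return 0;
        TreeSet<Long> distinct = new TreeSet<>(); // sorted distinct prefix sums sum(0, j)
        long sum = 0;
        for (int num : nums) {
            sum += num;
            distinct.add(sum);
        }
        long[] vals = distinct.stream().mapToLong(Long::longValue).toArray();
        Node root = build(vals, 0, vals.length - 1);

        int ans = (nums[0] >= lower && nums[0] <= upper) ? 1 : 0; // range (0, 0)
        sum = nums[0];
        for (int i = 1; i < nums.length; i++) {
            update(root, sum);                                 // insert sum(0, i - 1)
            sum += nums[i];                                    // sum is now sum(0, i)
            if (sum >= lower && sum <= upper) ans++;           // range (0, i)
            ans += getCount(root, sum - upper, sum - lower);   // ranges (k, i) with 1 <= k <= i
        }
        return ans;
    }
}
```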
