Share my solution


  • 211

    First of all, let's look at the naive solution. Precompute the prefix sums S[i] = S(0, i); then S(i, j) = S[j] - S[i]. Here S(i, j) denotes the sum over the half-open range [i, j), i.e. j is exclusive and j > i. With these prefix sums, we can check every S(i, j) against [lower, upper] in O(n^2) time.

    Java - Naive Solution

    public int countRangeSum(int[] nums, int lower, int upper) {
        int n = nums.length;
        long[] sums = new long[n + 1];
        for (int i = 0; i < n; ++i)
            sums[i + 1] = sums[i] + nums[i];
        int ans = 0;
        for (int i = 0; i < n; ++i)
            for (int j = i + 1; j <= n; ++j)
                if (sums[j] - sums[i] >= lower && sums[j] - sums[i] <= upper)
                    ans++;
        return ans;
    }
    

    However, the test cases are intentionally set so that the naive solution gets TLE (Time Limit Exceeded).

    Now let's do better than this.

    Recall "Count of Smaller Numbers After Self", where we encountered the problem

    • count[i] = count of nums[j] - nums[i] < 0 with j > i

    Here, after the preprocessing, we need to solve the problem

    • count[i] = count of lower <= S[j] - S[i] <= upper with j > i
    • ans = sum(count[:])

    Therefore the two problems are almost the same, and we can use the same techniques here. One solution is based on merge sort; another is based on a balanced BST. Both run in O(n log n) time.

    The merge sort based solution counts the answer while doing the merge. During the merge stage, we have already sorted the left half [start, mid) and right half [mid, end). We then iterate through the left half with index i. For each i, we need to find two indices k and j in the right half where

    • j is the first index satisfying sums[j] - sums[i] > upper, and
    • k is the first index satisfying sums[k] - sums[i] >= lower.

    Then the number of right-half sums with sums[·] - sums[i] in [lower, upper] is j - k. We also use another index t to copy the elements satisfying sums[t] < sums[i] into a cache in order to complete the merge sort.

    Despite the nested loops, the "merge & count" stage still takes linear time: the indices k, j, t only ever increase, so each of them traverses the right half at most once. The total time complexity of this divide-and-conquer solution is therefore O(n log n).

    One other concern is that the sums may overflow int, so we use long instead.
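    To make the overflow concern concrete, here is a minimal sketch (the class and method names are hypothetical, not part of the solution) that accumulates the same two values once in an int and once in a long:

```java
// Illustrative only: PrefixSumOverflowDemo and its methods are made-up names.
class PrefixSumOverflowDemo {
    // Accumulates in int, so the total silently wraps around on overflow.
    static int intTotal(int[] nums) {
        int s = 0;
        for (int x : nums) s += x;
        return s;
    }

    // Accumulates in long, which is wide enough for these inputs.
    static long longTotal(int[] nums) {
        long s = 0;
        for (int x : nums) s += x;
        return s;
    }

    public static void main(String[] args) {
        int[] nums = {Integer.MAX_VALUE, 1};
        System.out.println(intTotal(nums));  // prints -2147483648 (wrapped)
        System.out.println(longTotal(nums)); // prints 2147483648
    }
}
```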

    Java - Merge Sort Solution

    public int countRangeSum(int[] nums, int lower, int upper) {
        int n = nums.length;
        long[] sums = new long[n + 1];
        for (int i = 0; i < n; ++i)
            sums[i + 1] = sums[i] + nums[i];
        return countWhileMergeSort(sums, 0, n + 1, lower, upper);
    }
    
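    private int countWhileMergeSort(long[] sums, int start, int end, int lower, int upper) {
        // A range of length <= 1 contains no (i, j) pair.
        if (end - start <= 1) return 0;
        int mid = (start + end) / 2;
        // Count pairs inside each half; both halves come back sorted.
        int count = countWhileMergeSort(sums, start, mid, lower, upper) 
                  + countWhileMergeSort(sums, mid, end, lower, upper);
        // k, j, t all scan the right half [mid, end) and never move backwards.
        int j = mid, k = mid, t = mid;
        long[] cache = new long[end - start];
        for (int i = start, r = 0; i < mid; ++i, ++r) {
            // k: first index with sums[k] - sums[i] >= lower
            while (k < end && sums[k] - sums[i] < lower) k++;
            // j: first index with sums[j] - sums[i] > upper
            while (j < end && sums[j] - sums[i] <= upper) j++;
            // Merge: right-half elements smaller than sums[i] go first.
            while (t < end && sums[t] < sums[i]) cache[r++] = sums[t++];
            cache[r] = sums[i];
            count += j - k;
        }
        // Elements in [t, end) are already in their final positions.
        System.arraycopy(cache, 0, sums, start, t - start);
        return count;
    }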

  • 0
    D

    Excellent solution! The key point is that we do not need to restart counting from mid again and again; that is, k and j only ever increase. A simple mathematical property stands behind it. Your solution also represents a general method. I am curious how the merge-sort solution came to mind when you analyzed the question?


  • 14

    Yes, you are right. The general method is related to the so-called "two pointers" technique, although we have more than two here :P. The property behind two pointers is monotonicity, and to exploit monotonicity you have to sort. That is why merge sort comes to mind.

    The difference between this kind of problem and easy two-pointer problems is that these also have relative-position constraints (here j > i), which sorting would destroy. Thus we have to do the counting during the sort, while the array is only partially sorted and some relative positions are still preserved.
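    The monotonicity argument can be sketched in isolation. The helper below is a hypothetical illustration, not part of the posted solution: it assumes two arrays that are already sorted in ascending order and counts the pairs (a, b), a from left and b from right, with lower <= b - a <= upper. Because a only grows, neither pointer ever has to move back.

```java
// Illustrative only: TwoPointerCount and countPairs are made-up names.
class TwoPointerCount {
    // Both arrays must be sorted ascending.
    static long countPairs(long[] left, long[] right, long lower, long upper) {
        long count = 0;
        int k = 0, j = 0;
        for (long a : left) {
            // k: first index with right[k] - a >= lower
            while (k < right.length && right[k] - a < lower) k++;
            // j: first index with right[j] - a > upper
            while (j < right.length && right[j] - a <= upper) j++;
            // Indices in [k, j) are exactly the partners of a.
            count += j - k;
        }
        return count;
    }

    public static void main(String[] args) {
        long[] left = {0, 1};
        long[] right = {2, 3};
        // Valid pairs: (0,2), (1,2), (1,3)
        System.out.println(countPairs(left, right, 1, 2)); // prints 3
    }
}
```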


  • 0

    As I mentioned, although there are two nested loops, the time is still linear, because the index of each inner loop only increases and is never reset to mid. Thus each inner loop advances at most (end - mid) times in total.


  • 0
    W

    Got it, thanks.


  • 13

    C++ implementation of your ideas with rich comments
    I think the key idea is that you do not need to loop over all the sum pairs: while merge sorting, you just need to find the two bounds and add count += upper_bound - lower_bound.

    #include <vector>
    using namespace std;

    class Solution {
    public:
        int countRangeSum(vector<int>& nums, int lower, int upper) {
            int size=nums.size();
            if(size==0)  return 0;
            /*** long long: plain long is only 32 bits on some platforms ***/
            vector<long long> sums(size+1, 0);
            for(int i=0; i<size; i++)  sums[i+1]=sums[i]+nums[i];
            return help(sums, 0, size+1, lower, upper);
        }
        
        /*** [start, end)  ***/
        int help(vector<long long>& sums, int start, int end, int lower, int upper){
            /*** only-one-element, so the count-pair=0 ***/
            if(end-start<=1)  return 0;
            int mid=(start+end)/2;
            int count=help(sums, start, mid, lower, upper)
                    + help(sums, mid, end, lower, upper);
            
            int m=mid, n=mid, t=mid, len=0;
            /*** cache stores the sorted-merged-2-list ***/
            /*** so we use the "len" to record the merged length ***/
            vector<long long> cache(end-start, 0);
            for(int i=start, s=0; i<mid; i++, s++){
                /*** wrong code: while(m<end && sums[m++]-sums[i]<lower);  ***/
                while(m<end && sums[m]-sums[i]<lower) m++;
                while(n<end && sums[n]-sums[i]<=upper) n++;
                count+=n-m;
                /*** cache will merge-in-the-smaller-part-of-list2 ***/
                while(t<end && sums[t]<sums[i]) cache[s++]=sums[t++];
                cache[s]=sums[i];
                len=s;
            }
            
            for(int i=0; i<=len; i++)  sums[start+i]=cache[i];
            return count;
        }
    };

  • 0
    N

    really nice and clean code. Thank you for sharing!


  • 0
    W

    I'm wondering whether "count += j - k;" should be guarded by an if() such as:

    if(k>=mid && k<end && sums[k]-sums[i]>=lower && j-1>=mid && j-1<end && sums[j-1]-sums[i]<=upper && j-1>=k)     
    {
          count += j - k;
    }
    

    It seems there would be corner cases not covered by the direct addition, or did I miss something here? I appreciate any feedback, and thanks for this great solution!


  • 1

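    Yes, you missed the fact that j is always >= k.
    The right half is sorted, and every index that k skips (sums[k] - sums[i] < lower) also satisfies sums[k] - sums[i] <= upper as long as lower <= upper, so j skips it too. When no sum falls in the range, j == k and nothing is added.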
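    If you want more assurance that no guard is needed, one option is to cross-check the unguarded merge sort count against the naive O(n^2) count on small random inputs. The harness below is only a sketch: the class name, the random ranges, and the seed are arbitrary choices for illustration, and it only generates inputs with lower <= upper.

```java
import java.util.Random;

// Illustrative only: RangeSumCrossCheck is a made-up test harness.
class RangeSumCrossCheck {
    static int naive(int[] nums, int lower, int upper) {
        int n = nums.length;
        long[] sums = new long[n + 1];
        for (int i = 0; i < n; ++i) sums[i + 1] = sums[i] + nums[i];
        int ans = 0;
        for (int i = 0; i < n; ++i)
            for (int j = i + 1; j <= n; ++j) {
                long s = sums[j] - sums[i];
                if (s >= lower && s <= upper) ans++;
            }
        return ans;
    }

    static int mergeSortCount(int[] nums, int lower, int upper) {
        int n = nums.length;
        long[] sums = new long[n + 1];
        for (int i = 0; i < n; ++i) sums[i + 1] = sums[i] + nums[i];
        return count(sums, 0, n + 1, lower, upper);
    }

    // Same merge-and-count step as in the post, with no guard around count += j - k.
    private static int count(long[] sums, int start, int end, int lower, int upper) {
        if (end - start <= 1) return 0;
        int mid = (start + end) / 2;
        int count = count(sums, start, mid, lower, upper)
                  + count(sums, mid, end, lower, upper);
        int j = mid, k = mid, t = mid;
        long[] cache = new long[end - start];
        for (int i = start, r = 0; i < mid; ++i, ++r) {
            while (k < end && sums[k] - sums[i] < lower) k++;
            while (j < end && sums[j] - sums[i] <= upper) j++;
            while (t < end && sums[t] < sums[i]) cache[r++] = sums[t++];
            cache[r] = sums[i];
            count += j - k;
        }
        System.arraycopy(cache, 0, sums, start, t - start);
        return count;
    }

    // Returns true if both methods agree on every generated input.
    static boolean agreeOnRandomInputs(long seed, int trials) {
        Random rnd = new Random(seed);
        for (int trial = 0; trial < trials; ++trial) {
            int n = rnd.nextInt(8);                 // length 0..7
            int[] nums = new int[n];
            for (int i = 0; i < n; ++i) nums[i] = rnd.nextInt(11) - 5;
            int lower = rnd.nextInt(11) - 5;
            int upper = lower + rnd.nextInt(6);     // keeps lower <= upper
            if (naive(nums, lower, upper) != mergeSortCount(nums, lower, upper)) return false;
        }
        return true;
    }

    public static void main(String[] args) {
        System.out.println(mergeSortCount(new int[]{-2, 5, -1}, -2, 2)); // prints 3
        System.out.println(agreeOnRandomInputs(42L, 1000));
    }
}
```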


  • 0
    W

    thanks pepsi, great solution


  • 4
    Y

    Why is the cache needed? Can anyone help me?


  • 0

    Because we need to actually merge the two sorted parts in order to complete the merge sort.


  • 0
    N

    Hi, can someone explain intuitively what the two while loops do?

    while (k < end && sums[k] - sums[i] < lower) k++;
    while (j < end && sums[j] - sums[i] <= upper) j++;

    What is their purpose? Why do we even check whether sums[k] - sums[i] < lower? I'm having a tough time understanding it.


  • 0

    Then you might want to read the problem statement again.


  • 0
    J

    Brilliant solution.


  • 0
    M

    One small question: after the for loop

    for (int i = start, r = 0; i < mid; ++i, ++r) {
        while (k < end && sums[k] - sums[i] < lower) k++;
        while (j < end && sums[j] - sums[i] <= upper) j++;
        while (t < end && sums[t] < sums[i]) cache[r++] = sums[t++];
        cache[r] = sums[i];
        count += j - k;
    }

    why do you not copy the rest of the array into the cache?
    I use C++, so I have to copy it.


  • 0

    Because after you copy it to the cache, you would just copy it back to sums. You don't need to copy the rest of the array because it doesn't change.

    For example, suppose you are merging sums = [1 3 5 6 2 4 7 8].
    The result should be [1 2 3 4 5 6 7 8], and [7 8] doesn't move. Thus after you have copied 1 2 3 4 5 6 to the cache, you don't need to copy [7 8].
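    The example above can be run as a small sketch. The class below is a hypothetical extraction of just the merge part of the loop (the counting is left out), and it returns how many elements were copied back from the cache:

```java
import java.util.Arrays;

// Illustrative only: PartialMergeDemo isolates the merge part of the loop.
class PartialMergeDemo {
    // Merges sorted sums[start, mid) and sums[mid, end) in place.
    // Returns the number of elements copied back from the cache.
    static int merge(long[] sums, int start, int mid, int end) {
        long[] cache = new long[end - start];
        int t = mid;
        for (int i = start, r = 0; i < mid; ++i, ++r) {
            // Right-half elements smaller than sums[i] go into the cache first.
            while (t < end && sums[t] < sums[i]) cache[r++] = sums[t++];
            cache[r] = sums[i];
        }
        // Everything in [t, end) is >= the largest left element and stays put.
        System.arraycopy(cache, 0, sums, start, t - start);
        return t - start;
    }

    public static void main(String[] args) {
        long[] sums = {1, 3, 5, 6, 2, 4, 7, 8};
        int copied = merge(sums, 0, 4, 8);
        System.out.println(copied);                // prints 6
        System.out.println(Arrays.toString(sums)); // prints [1, 2, 3, 4, 5, 6, 7, 8]
    }
}
```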

