8-line multiset C++ solution (100ms), also binary search tree (180ms) + mergesort(52ms)


  •

    The basic idea is to use a multiset to store the prefix sums, where sum[i] = nums[0] + ... + nums[i], plus an initial 0 for the empty prefix. At each i, an earlier prefix sum sum[j] gives a valid range [j+1, i] exactly when lower <= sum[i] - sum[j] <= upper. So we only need to count how many earlier prefix sums (j < i, including the initial 0) satisfy sum[i] - upper <= sum[j] <= sum[i] - lower. The STL multiset keeps the sums sorted and finds those bounds with lower_bound and upper_bound. Since multiset is usually implemented as a red-black tree, insertion and both bound lookups take O(log N), so the total complexity is O(N log N), except for the std::distance call, which is linear in the number of elements between the two iterators. At least it looks neat.

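    #include <iterator>  // std::distance
    #include <set>       // std::multiset
    #include <vector>
    using namespace std;

    class Solution {
    public:
        int countRangeSum(vector<int>& nums, int lower, int upper) {
            multiset<long long> pSum;  // prefix sums seen so far
            int res = 0;
            long long sum = 0;
            pSum.insert(0);            // empty prefix, so ranges starting at index 0 are counted
            for (int x : nums)
            {
                sum += x;
                // count earlier prefix sums in [sum - upper, sum - lower]
                res += distance(pSum.lower_bound(sum - upper), pSum.upper_bound(sum - lower));
                pSum.insert(sum);
            }
            return res;
        }
    };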
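    To cross-check any of the versions in this thread on small inputs, a brute-force counter straight from the definition can help. The sketch below is O(N^2) and only meant for testing; the function name countRangeSumBrute is hypothetical and not part of the original post.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical reference implementation: counts ranges [i, j] with
// lower <= nums[i] + ... + nums[j] <= upper by trying every pair of prefix sums.
// O(N^2) time, intended only for testing the faster versions.
int countRangeSumBrute(const std::vector<int>& nums, int lower, int upper) {
    // prefix[k] = nums[0] + ... + nums[k-1], with prefix[0] = 0 for the empty prefix
    std::vector<long long> prefix(nums.size() + 1, 0);
    for (std::size_t k = 0; k < nums.size(); ++k) prefix[k + 1] = prefix[k] + nums[k];

    int res = 0;
    for (std::size_t j = 1; j < prefix.size(); ++j)   // range ends at nums[j-1]
        for (std::size_t i = 0; i < j; ++i) {         // range starts at nums[i]
            long long s = prefix[j] - prefix[i];
            if (lower <= s && s <= upper) ++res;
        }
    return res;
}
```

    For example, countRangeSumBrute({-2, 5, -1}, -2, 2) returns 3: the ranges [0, 0], [2, 2] and [0, 2] have sums -2, -1 and 2.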
    

    In the comments, StefanPochmann raised the concern that the STL distance function increases the total complexity to O(N^2), which is true. The following version shows one possible way to fix that if we implement the binary search tree ourselves: each node stores the size of its left subtree, so the number of stored sums below a bound is computed during a single root-to-leaf search instead of by walking iterators. Of course, this is not a balanced binary search tree, so the worst case is still O(N^2) (for example, when the prefix sums arrive in sorted order); for random input the average complexity is O(N log N).

    #include <vector>
    using namespace std;

    class Solution {
    private:
        class BSTNode{ // Binary search tree implementation
        public:    
            long long val;
            int cnt; // how many times "val" has been inserted
            int lCnt; // how many values (counting duplicates) are stored in its left subtree
            BSTNode *left;
            BSTNode *right;
            
            BSTNode(long long x)
            {
                val = x;
                cnt = 1;
                lCnt = 0;
                left = right = nullptr;
            }
        };
        
        int getBound(BSTNode *root, long long x, bool includeSelf)
        { // count the stored values with val < x (includeSelf = false) or val <= x (includeSelf = true)
            if(!root) return 0;
            if(root->val == x) return  root->lCnt + (includeSelf?root->cnt:0);
            else if(root->val > x) return getBound(root->left, x, includeSelf);
            else return root->cnt + root->lCnt + getBound(root->right, x, includeSelf);
        }
        void insert(BSTNode*& root, long long x)
        { // insert x into the tree (a duplicate only increments cnt)
            if(!root) root = new BSTNode(x);
            else if(root->val == x) (root->cnt)++;
            else if(root->val < x) 
                insert(root->right,x);
            else{
                ++(root->lCnt);
                insert(root->left,x);
            }
        }
        void deleteTree(BSTNode*root)
        { //destroy the tree
            if(!root) return;
            deleteTree(root->left);
            deleteTree(root->right);
            delete root;
        }
        
        
    public:
        int countRangeSum(vector<int>& nums, int lower, int upper) { // same idea as the multiset version
            BSTNode *root = new BSTNode(0); // empty prefix
            int res = 0;
            long long sum = 0;
            for (int x : nums)
            {
                sum += x;
                // stored sums in [sum - upper, sum - lower] = (#values <= sum - lower) - (#values < sum - upper)
                res += getBound(root, sum - lower, true) - getBound(root, sum - upper, false);
                insert(root, sum);
            }
            deleteTree(root);
            return res;
         }
    };
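    The unbalanced BST above degenerates into a chain when the prefix sums arrive in sorted order (for example, when every number is positive), which is where the O(N^2) worst case comes from. If a guaranteed O(N log N) bound matters, one option, which swaps in a different data structure than this thread uses, is a Fenwick tree (binary indexed tree) over coordinate-compressed prefix sums. The counting logic mirrors getBound; the function name countRangeSumFenwick is hypothetical, and this is a sketch rather than a timed LeetCode submission.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Sketch: count ranges with lower <= range sum <= upper using a Fenwick tree
// over coordinate-compressed prefix sums. Worst case O(N log N).
// The function name is hypothetical.
int countRangeSumFenwick(const std::vector<int>& nums, int lower, int upper) {
    // sums[k] = nums[0] + ... + nums[k-1]; sums[0] = 0 is the empty prefix
    std::vector<long long> sums(nums.size() + 1, 0);
    for (std::size_t k = 0; k < nums.size(); ++k) sums[k + 1] = sums[k] + nums[k];

    // Coordinate compression: sorted distinct prefix sums.
    std::vector<long long> vals(sums);
    std::sort(vals.begin(), vals.end());
    vals.erase(std::unique(vals.begin(), vals.end()), vals.end());
    const int n = static_cast<int>(vals.size());

    std::vector<int> tree(n + 1, 0);  // 1-based Fenwick tree of counts per rank
    auto add = [&](int pos) {         // record one value at 1-based rank pos
        for (; pos <= n; pos += pos & -pos) ++tree[pos];
    };
    auto countUpTo = [&](int pos) {   // number of recorded values with rank <= pos
        int c = 0;
        for (; pos > 0; pos -= pos & -pos) c += tree[pos];
        return c;
    };
    // Number of distinct values < x (lower_bound) or <= x (upper_bound).
    auto ranksBelow = [&](long long x) {
        return static_cast<int>(std::lower_bound(vals.begin(), vals.end(), x) - vals.begin());
    };
    auto ranksAtMost = [&](long long x) {
        return static_cast<int>(std::upper_bound(vals.begin(), vals.end(), x) - vals.begin());
    };

    int res = 0;
    for (long long s : sums) {
        // earlier prefix sums p with s - upper <= p <= s - lower
        res += countUpTo(ranksAtMost(s - lower)) - countUpTo(ranksBelow(s - upper));
        add(ranksBelow(s) + 1);  // s is in vals, so this is its 1-based rank
    }
    return res;
}
```

    Compressing the sums first costs an extra O(N log N) sort, but it keeps the tree at N + 1 slots no matter how large the sums themselves get.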
    

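    Another option is to modify merge sort to do the counting. After both halves of the prefix-sum array are sorted, for each sum[i] in the left half the valid sum[j] in the right half (lower <= sum[j] - sum[i] <= upper) form a contiguous window, and because the left half is sorted, both ends of that window only move forward. Counting therefore costs O(N) per level on top of the merge, so the complexity is O(N log N) (52ms). The code is below.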

    #include <vector>
    using namespace std;

    class Solution {
    private:    
        int mergeSort(vector<long long>&sum, int left, int right, int lower, int upper)
        {
            int mid, i, res, j, k;
            if(left>right) return 0;
            if(left==right) return ( (sum[left]>=lower) && (sum[left]<=upper) )?1:0;
            else
            {
                vector<long long> temp(right-left+1,0);
                mid = (left+right)/2;
                res = mergeSort(sum, left,mid, lower, upper) + mergeSort(sum, mid+1,right, lower, upper); // sort both halves first; be careful to split into [left, mid] and [mid+1, right]
                for(i=left, j=k=mid+1; i<=mid; ++i)
                { // count pairs (i, j), i in the left half and j in the right half, with lower <= sum[j] - sum[i] <= upper
                    while(j<=right && sum[j]-sum[i]<lower)  ++j;
                    while(k<=right && sum[k]-sum[i]<=upper) ++k;
                    res +=k-j;
                }
                for(i=k=left, j=mid+1; k<=right; ++k) // merge the two sorted halves
                    temp[k-left] = (i<=mid) && (j>right || sum[i]<sum[j])?sum[i++]:sum[j++]; 
                for(k=left; k<=right; ++k) // copy the sorted results back to sum
                    sum[k] = temp[k-left]; 
                return res;
            }
        }
    public:
        int countRangeSum(vector<int>& nums, int lower, int upper) {
             int len = nums.size(), i;
             vector<long long> sum(len+1, 0);
             for(i=1; i<=len; ++i) sum[i] = sum[i-1]+nums[i-1];
             return mergeSort(sum, 1, len, lower, upper);
        }
    };
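    The cross-half counting loop is where the O(N) per level comes from. Pulled out on its own it looks like the sketch below; the helper name countCrossPairs and its vector-based signature are hypothetical (the code above works in place on sum[left..right]). It assumes both inputs are already sorted, as they are after the two recursive calls.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper isolating the counting step of the merge sort version:
// given the sorted left half `a` and sorted right half `b` of the prefix sums,
// count pairs (x in a, y in b) with lower <= y - x <= upper.
// Because `a` is sorted, the window [j, k) of valid y only moves right,
// so the whole step is O(|a| + |b|). Assumes lower <= upper.
long long countCrossPairs(const std::vector<long long>& a, const std::vector<long long>& b,
                          long long lower, long long upper) {
    long long res = 0;
    std::size_t j = 0, k = 0;
    for (long long x : a) {
        while (j < b.size() && b[j] - x < lower) ++j;   // first y with y - x >= lower
        while (k < b.size() && b[k] - x <= upper) ++k;  // first y with y - x > upper
        res += static_cast<long long>(k - j);
    }
    return res;
}
```

    For a = {0, 1}, b = {1, 2, 3}, lower = 1, upper = 2, each x has two partners (1 and 2 for x = 0; 2 and 3 for x = 1), so the helper returns 4.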

  •

    I believe it's O(N^2), not O(N log N), and thus not acceptable. If I'm not mistaken, multiset iterators are bidirectional iterators, and std::distance takes linear time for them.

    Also, if you insert

        nums = vector<int>(5000, 1);
        lower = 0;
        upper = 99999;
    

    at the top of your function and click "Run Code", it takes about 156 ms. Doubling the 5000 to 10000 more than quadruples that to about 650 ms, as expected from quadratic run time.

    Edit: Since it fairly consistently more than quadrupled, I'm wondering whether it actually is O(N^2) or might be even worse. I thought of iterator steps as O(1), at least amortized, but maybe that's not true?

    Average of five attempts with 5000: (152 + 156 + 164 + 156 + 152) / 5 = 156 ms
    Average of five attempts with 10000: (644 + 636 + 648 + 644 + 680) / 5 ≈ 650 ms

    Edit 2: Apparently it's amortized O(1) after all.
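    The quadratic timings match what the standard specifies for std::distance: for iterators that are not random access, it advances first until it reaches last, so each call performs as many increments as the count it returns, and the total number of increments across all calls equals the final answer. With the all-ones test above, that answer is N(N+1)/2. The sketch below checks the iterator category at compile time and repeats the multiset loop as a free function (the name countWithMultiset is hypothetical).

```cpp
#include <iterator>
#include <set>
#include <type_traits>
#include <vector>

// multiset iterators are bidirectional, not random access, so std::distance
// has to step from first to last one element at a time.
static_assert(std::is_same<std::iterator_traits<std::multiset<long long>::iterator>::iterator_category,
                           std::bidirectional_iterator_tag>::value,
              "std::multiset iterators are bidirectional");

// Same loop as the multiset solution (hypothetical free-function name).
// Each std::distance call increments as many times as the value it returns,
// so the total iterator work equals the returned count.
long long countWithMultiset(const std::vector<int>& nums, int lower, int upper) {
    std::multiset<long long> pSum{0};  // start with the empty prefix
    long long sum = 0, res = 0;
    for (int x : nums) {
        sum += x;
        res += std::distance(pSum.lower_bound(sum - upper), pSum.upper_bound(sum - lower));
        pSum.insert(sum);
    }
    return res;
}
```

    For N = 5000 ones the returned count, and therefore the number of increments inside std::distance, is 12,502,500; for N = 10000 it is 50,005,000, about four times as many.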


  •

    Thanks for your comments. I agree that std::distance may be an issue, which is why I wrote "except the distance part" in my post. But if we can modify the multiset implementation (or build our own binary search tree that stores subtree sizes), computing the distance in O(log N) on a balanced tree is not difficult, which brings the total complexity to O(N log N). As I said, I posted it here just because it is neat.


  •

    @dong.wang.1694 said in 8-line multiset C++ solution (100ms), also binary search tree (180ms) + mergesort(52ms):

    res += getBound(root, sum-lower, true) - getBound(root, sum-upper, false);

    We can calculate the number of valid sum[j] in a simpler way:

    res += cntBelowBound(bstRoot, sum-lower+1)-cntBelowBound(bstRoot, sum-upper);
    

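    Here cntBelowBound(root, val) returns the number of stored sums strictly less than val. Because all sums are integers, counting sums <= sum - lower is the same as counting sums < sum - lower + 1, so one function covers both bounds. The code is very similar: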

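    // Node layout inferred from the fields used below (not shown in the original comment).
    struct node {
        long long val;
        int copy;      // how many times val was inserted
        int leftcnt;   // number of values in the left subtree
        node *left, *right;
    };

    // Number of stored sums strictly less than val.
    int cntBelowBound(node* root, long long val) {
        if (root == nullptr) return 0;
        if (root->val == val) return root->leftcnt;
        else if (root->val > val) {
            return cntBelowBound(root->left, val);
        }
        else {
            return root->leftcnt + root->copy + cntBelowBound(root->right, val);
        }
    }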
    
