The basic idea is to use a multiset to store the prefix sums, where sum[i] = nums[0] + ... + nums[i] (plus an initial 0 for the empty prefix). At each i, only those earlier prefix sums sum[j] that satisfy lower <= sum[i] - sum[j] <= upper produce a valid range (j, i]. So we only need to count how many j (j < i) satisfy sum[i] - upper <= sum[j] <= sum[i] - lower. The STL multiset keeps the values sorted and provides lower_bound/upper_bound to locate that window of j. Since multiset is usually implemented as a red-black tree, those operations are O(log N), so the total complexity is O(N log N), except for the std::distance part, which walks the iterators one element at a time. At least it looks neat.

```
class Solution {
public:
    // Counts the ranges [j, i] of nums whose sum lies in [lower, upper].
    //
    // Keeps every prefix sum seen so far in a sorted multiset. For the current
    // prefix sum `sum`, an earlier prefix p closes a valid range exactly when
    // sum - upper <= p <= sum - lower, so the answer for this position is the
    // number of multiset elements in that closed window.
    //
    // lower_bound/upper_bound/insert are O(log N), but std::distance on
    // multiset (bidirectional) iterators is linear in the window size, so the
    // worst case is O(N^2).
    //
    // Returns 0 when lower > upper (empty interval); without this guard the
    // window's end iterator would precede its begin, which is undefined for
    // std::distance.
    int countRangeSum(vector<int>& nums, int lower, int upper) {
        if (lower > upper) return 0;

        // Seed with the empty prefix so ranges starting at index 0 are counted.
        std::multiset<long long> prefixSums{0};
        long long sum = 0;  // 64-bit so adding int elements cannot overflow
        int res = 0;
        for (const int num : nums) {
            sum += num;
            const auto first = prefixSums.lower_bound(sum - upper);
            const auto last = prefixSums.upper_bound(sum - lower);
            res += static_cast<int>(std::distance(first, last));
            prefixSums.insert(sum);
        }
        return res;
    }
};
```

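In the comments, StefanPochmann raised the concern that the STL distance function increases the total complexity to O(N^2), which is true: std::distance on multiset iterators steps through the elements one at a time. The following version shows one possible way to fix that if we implement the binary search tree ourselves. Each node stores the size of its left subtree, so the count of values below a bound comes from a single root-to-leaf walk (O(tree height)) instead of iterating over the elements. Of course, this tree is not balanced, so the worst case is still O(N^2) (for example, when the prefix sums are monotonic). If the input is random, the average complexity is O(N log N).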

```
class Solution {
private:
    // Node of an unbalanced binary search tree of prefix sums. Equal values share
    // a single node; `cnt` records their multiplicity.
    class BSTNode {
    public:
        long long val;
        int cnt;   // how many inserted values equal `val`
        int lCnt;  // how many values (counting multiplicity) are in the left subtree
        BSTNode *left;
        BSTNode *right;
        explicit BSTNode(long long x)
            : val(x), cnt(1), lCnt(0), left(nullptr), right(nullptr) {}
    };

    // Returns how many stored values are < x (includeSelf == false) or <= x
    // (includeSelf == true). Iterative, so a chain-shaped (degenerate) tree does
    // not use stack space proportional to its height.
    int getBound(BSTNode *root, long long x, bool includeSelf)
    {
        int below = 0;
        for (BSTNode *node = root; node != nullptr; ) {
            if (node->val == x)
                return below + node->lCnt + (includeSelf ? node->cnt : 0);
            if (node->val > x) {
                node = node->left;
            } else {
                // This node and its whole left subtree are strictly below x.
                below += node->cnt + node->lCnt;
                node = node->right;
            }
        }
        return below;
    }

    // Inserts x into the tree rooted at `root` (creating the root if it is null),
    // bumping `lCnt` on every node whose left subtree receives the new value.
    void insert(BSTNode*& root, long long x)
    {
        BSTNode **link = &root;  // the child pointer that x will be attached to
        while (*link != nullptr) {
            BSTNode *node = *link;
            if (node->val == x) {
                ++node->cnt;
                return;
            }
            if (node->val < x) {
                link = &node->right;
            } else {
                ++node->lCnt;  // x ends up somewhere in this node's left subtree
                link = &node->left;
            }
        }
        *link = new BSTNode(x);
    }

    // Frees every node of the tree. Uses right rotations to flatten the tree while
    // deleting, so it needs neither recursion nor an auxiliary stack.
    void deleteTree(BSTNode *root)
    {
        while (root != nullptr) {
            if (root->left != nullptr) {
                BSTNode *pivot = root->left;
                root->left = pivot->right;
                pivot->right = root;
                root = pivot;
            } else {
                BSTNode *next = root->right;
                delete root;
                root = next;
            }
        }
    }

public:
    // Counts the ranges [j, i] of nums whose sum lies in [lower, upper]; same idea
    // as the multiset version. For each prefix sum `sum`, the number of earlier
    // prefixes p with sum - upper <= p <= sum - lower is
    // (#values <= sum - lower) - (#values < sum - upper).
    // Returns 0 when lower > upper (empty interval).
    int countRangeSum(vector<int>& nums, int lower, int upper) {
        if (lower > upper) return 0;

        // Seed with the empty prefix so ranges starting at index 0 are counted.
        BSTNode *root = new BSTNode(0);
        int res = 0;
        long long sum = 0;  // 64-bit so adding int elements cannot overflow
        try {
            for (const int num : nums) {
                sum += num;
                res += getBound(root, sum - lower, true) - getBound(root, sum - upper, false);
                insert(root, sum);
            }
        } catch (...) {
            // An allocation failure in insert() must not leak the tree built so far.
            deleteTree(root);
            throw;
        }
        deleteTree(root);
        return res;
    }
};
```

Another option is to modify merge sort to do the counting. The code is below; its complexity is O(N log N) (52 ms).

```
class Solution {
private:
    // Sorts sum[left..right] ascending in place and returns how many ranges end
    // inside that span with a sum in [lower, upper]. It counts:
    //   * single indices k with lower <= sum[k] <= upper. These are ranges that
    //     start at nums[0], since the implicit prefix sum[0] == 0 is never part of
    //     the sorted span;
    //   * pairs i < k (original positions) with lower <= sum[k] - sum[i] <= upper.
    // `scratch` must be at least as large as `sum`. It is reused as the merge
    // buffer so no allocation happens during the recursion.
    int mergeSort(std::vector<long long>& sum, std::vector<long long>& scratch,
                  int left, int right, int lower, int upper)
    {
        if (left > right) return 0;
        if (left == right) return (sum[left] >= lower && sum[left] <= upper) ? 1 : 0;

        const int mid = left + (right - left) / 2;
        // Halves are [left, mid] and [mid+1, right]; every left-half index precedes
        // every right-half index in the original order, even after sorting.
        int res = mergeSort(sum, scratch, left, mid, lower, upper)
                + mergeSort(sum, scratch, mid + 1, right, lower, upper);

        // Both halves are sorted: for each i in the left half, [lo, hi) is the
        // window of right-half k with lower <= sum[k] - sum[i] <= upper. Both ends
        // only move forward as sum[i] grows.
        int lo = mid + 1;
        int hi = mid + 1;
        for (int i = left; i <= mid; ++i) {
            while (lo <= right && sum[lo] - sum[i] < lower) ++lo;
            while (hi <= right && sum[hi] - sum[i] <= upper) ++hi;
            res += hi - lo;
        }

        // Merge the two sorted halves into scratch, then copy back.
        int a = left;
        int b = mid + 1;
        int out = left;
        while (a <= mid && b <= right)
            scratch[out++] = (sum[a] <= sum[b]) ? sum[a++] : sum[b++];
        while (a <= mid) scratch[out++] = sum[a++];
        while (b <= right) scratch[out++] = sum[b++];
        for (int k = left; k <= right; ++k) sum[k] = scratch[k];
        return res;
    }

public:
    // Counts the ranges [j, i] of nums whose sum lies in [lower, upper] by merge
    // sorting the prefix sums sum[1..n] (sum[0] == 0 is the empty prefix).
    // Returns 0 when lower > upper (empty interval).
    int countRangeSum(vector<int>& nums, int lower, int upper) {
        if (lower > upper) return 0;

        const int len = static_cast<int>(nums.size());
        std::vector<long long> sum(len + 1, 0);  // 64-bit: int prefix sums can overflow
        for (int i = 1; i <= len; ++i) sum[i] = sum[i - 1] + nums[i - 1];
        std::vector<long long> scratch(len + 1);
        return mergeSort(sum, scratch, 1, len, lower, upper);
    }
};
```