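# Segment Tree Solution for Count of Range Sum

Build a segment tree over the sorted distinct prefix sums, including the empty prefix `0`. Scan the array from left to right. For each running prefix sum `S`, count the previously inserted prefix sums that lie in `[S - upper, S - lower]`, then insert `S`.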

``````java
/** Sorted, de-duplicated keys of the segment tree: leaf i stands for arr[i]. Assigned in countRangeSum. */
Long[] arr;
public class Node {
    /** How many values recorded by updateTree passed through (or ended at) this node. */
    int count;

    /** First index into arr covered by this node (inclusive). */
    int start;

    /** Last index into arr covered by this node (inclusive). */
    int end;

    /** Children covering the lower and upper halves of [start, end]; both stay null for a leaf. */
    Node left;
    Node right;

    Node(int start, int end) {
        this.end = end;
        this.start = start;
    }
}

/**
 * Builds a balanced segment tree over the index range [start, end] of arr.
 *
 * <p>A node with start == end is a leaf; every other node is split at the index
 * (start + end) / 2. All counts start at zero.
 *
 * @param start first covered index (inclusive)
 * @param end   last covered index (inclusive). It must not be smaller than start, because an
 *              inverted range never reaches the start == end base case.
 * @return the root of the subtree covering [start, end]
 */
private Node buildTree(int start, int end) {
    Node node = new Node(start, end);
    if (start != end) {
        int split = (start + end) / 2;
        node.left = buildTree(start, split);
        node.right = buildTree(split + 1, end);
    }
    return node;
}

/**
 * Records one occurrence of val in the tree.
 *
 * <p>Walks from root down to a leaf. At each internal node it goes left when val is at most
 * the key at the midpoint index, and right otherwise. It increments count on every node it
 * visits, the final leaf included. If val is not one of arr's keys, it is credited to
 * whichever leaf those comparisons reach.
 *
 * @param root root of the tree (or subtree) to update
 * @param val  the value to record
 */
private void updateTree(Node root, long val) {
    Node node = root;
    while (node.start != node.end) {
        long pivot = arr[(node.start + node.end) / 2];
        node.count++;
        node = (val <= pivot) ? node.left : node.right;
    }
    node.count++;
}

/**
 * Counts the values recorded by updateTree whose leaf key lies in [lower, upper], looking only
 * inside the subtree rooted at root.
 *
 * @param root  subtree to search; its keys span arr[root.start] .. arr[root.end]
 * @param lower inclusive lower bound on the key
 * @param upper inclusive upper bound on the key
 * @return number of recorded values whose key is in range
 */
private int getCount(Node root, long lower, long upper) {
    long minKey = arr[root.start];
    long maxKey = arr[root.end];
    // The whole subtree sits inside the query, so its stored total is the answer.
    if (minKey >= lower && maxKey <= upper) {
        return root.count;
    }
    // The subtree is disjoint from the query, or it is a leaf that is not fully inside.
    if (minKey > upper || maxKey < lower || root.start == root.end) {
        return 0;
    }
    int midIdx = (root.start + root.end) / 2;
    long midKey = arr[midIdx];
    if (midKey >= upper) {
        // Every key in the right half exceeds midKey >= upper.
        return getCount(root.left, lower, upper);
    }
    if (midKey < lower) {
        // Every key in the left half is at most midKey < lower.
        return getCount(root.right, lower, upper);
    }
    // The query straddles the split, so clip it to each half's boundary key.
    return getCount(root.left, lower, midKey) + getCount(root.right, arr[midIdx + 1], upper);
}
/**
 * Counts the subarrays nums[i..j] (i <= j) whose sum lies in [lower, upper].
 *
 * <p>Let P[0] = 0 and P[k] = nums[0] + ... + nums[k-1]. Then sum(nums[i..j]) = P[j+1] - P[i].
 * For each k from 1 to n, every earlier prefix P[i] (i < k) with
 * P[k] - upper <= P[i] <= P[k] - lower contributes one subarray. The segment tree is keyed
 * by the sorted distinct prefix sums, so each lookup and each insertion walks a balanced
 * tree instead of scanning all earlier prefixes.
 *
 * @param nums  input values (may be empty)
 * @param lower inclusive lower bound on a subarray sum
 * @param upper inclusive upper bound on a subarray sum
 * @return the number of qualifying subarrays; the total is accumulated in an int and is
 *         assumed to fit
 */
public int countRangeSum(int[] nums, int lower, int upper) {
    int res = 0;
    if (nums.length == 0) return res;

    // updateTree credits a value to the leaf holding that exact key only if the key exists.
    // So every value inserted below must be a key: the empty prefix 0 and each running
    // prefix sum. An empty key set would also make buildTree(0, -1) recurse forever.
    long sum = 0;
    Set<Long> set = new HashSet<>();
    set.add(sum);
    for (int num : nums) {
        sum += num;
        set.add(sum);
    }
    arr = set.toArray(new Long[0]);
    Arrays.sort(arr);
    Node root = buildTree(0, arr.length - 1);

    sum = 0;
    updateTree(root, sum);
    for (int num : nums) {
        sum += num;
        // sum is a long, so sum - upper and sum - lower cannot overflow.
        res += getCount(root, sum - upper, sum - lower);
        updateTree(root, sum);
    }
    return res;
}
``````

Looks like your connection to LeetCode Discuss was lost; please wait while we try to reconnect.