# Complicated segment tree solution — hoping to find a better one

• ``````public class Solution {
static class segmentTreeNode {
    // Inclusive range [start, end] this node covers, and the count aggregated over that range
    // (the update helpers keep an internal node's count equal to the sum of its children's).
    int start, end, count;
    // Lower-half and upper-half children; build() leaves both null for a leaf.
    segmentTreeNode left, right;

    /** Creates a childless node covering [lo, hi] holding initialCount. */
    segmentTreeNode(int lo, int hi, int initialCount) {
        start = lo;
        end = hi;
        count = initialCount;
        left = right = null;
    }
}
/**
 * Returns, for every index i, how many elements to the right of nums[i] are strictly smaller.
 *
 * Values are first compressed to ranks 0..m-1 (m = number of distinct values), so the segment
 * tree has m leaves no matter how large or negative the values are. This keeps memory
 * proportional to the input length and needs no offset arithmetic, which could overflow for
 * negative inputs. The caller's array is not modified.
 *
 * @param nums input values (may be empty)
 * @return a list with the same length as nums holding the right-side smaller counts
 */
public static List<Integer> countSmaller(int[] nums) {
    List<Integer> result = new ArrayList<Integer>();
    if (nums.length == 0) return result;

    // Sorted distinct values live in distinct[0..m); a value's rank is its index there.
    int[] distinct = nums.clone();
    Arrays.sort(distinct);
    int m = 0;
    for (int v : distinct) {
        // Writes only go to index m <= current position, so later reads are unaffected.
        if (m == 0 || distinct[m - 1] != v) distinct[m++] = v;
    }
    int[] rank = new int[nums.length];
    for (int i = 0; i < nums.length; i++) {
        rank[i] = Arrays.binarySearch(distinct, 0, m, nums[i]);
    }

    segmentTreeNode root = build(0, m - 1);
    // Insert every element, then sweep left to right. Removing nums[i] leaves exactly the elements
    // to its right in the tree; count those whose rank is strictly lower.
    for (int r : rank) updateAdd(root, r);
    for (int r : rank) {
        updateDel(root, r);
        result.add(r == 0 ? 0 : query(root, 0, r - 1));
    }
    return result;
}
/**
 * Recursively builds a segment tree over the inclusive range [start, end] with every count 0.
 * Returns null for an empty range (start > end); a single-point range yields a childless leaf.
 */
public static segmentTreeNode build(int start, int end) {
    if (end < start) {
        return null;
    }
    segmentTreeNode node = new segmentTreeNode(start, end, 0);
    if (start < end) {
        int mid = (start + end) / 2;
        node.left = build(start, mid);
        node.right = build(mid + 1, end);
        node.count = node.left.count + node.right.count;
    }
    return node;
}

/**
 * Returns the sum of counts over the inclusive range [start, end] within root's subtree.
 * Returns 0 for a null subtree or an empty range (start > end). A non-empty [start, end] is
 * assumed to lie inside [root.start, root.end].
 */
public static int query(segmentTreeNode root, int start, int end) {
    if (root == null || start > end) return 0;
    if (root.start == start && root.end == end) return root.count;
    // Same split point as build(), so the children cover [root.start, mid] and [mid + 1, root.end].
    int mid = (root.start + root.end) / 2;
    if (end <= mid) {
        // Entirely inside the left child (the old `end < mid` test sent end == mid down the split path).
        return query(root.left, start, end);
    } else if (start > mid) {
        // Entirely inside the right child (the old `start > end` test never fired here, so the
        // split path widened the right part to [mid + 1, end]).
        return query(root.right, start, end);
    } else {
        return query(root.left, start, mid) + query(root.right, mid + 1, end);
    }
}

/**
 * Increments the count of the leaf for {@code val} and recomputes the sums of its ancestors.
 * Does nothing when root is null or val lies outside [root.start, root.end].
 */
public static void updateAdd(segmentTreeNode root, int val) {
    if (root == null || root.start > val || root.end < val) return;
    if (root.start == val && root.end == val) {
        root.count ++;
        return;
    }
    int mid = (root.start + root.end) / 2;
    // Previously both branches were empty, so the leaf was never reached and counts never grew.
    if (val <= mid) {
        updateAdd(root.left, val);
    } else {
        updateAdd(root.right, val);
    }
    // Internal nodes produced by build() have both children, so both are non-null here.
    root.count = root.left.count + root.right.count;
}

/**
 * Decrements the count of the leaf for {@code val} and recomputes the sums of its ancestors.
 * Does nothing when root is null or val lies outside [root.start, root.end].
 */
public static void updateDel(segmentTreeNode root, int val) {
    boolean inRange = root != null && val >= root.start && val <= root.end;
    if (!inRange) {
        return;
    }
    if (root.start == root.end) {
        // A single-point node that contains val is the leaf for val.
        root.count--;
        return;
    }
    int mid = (root.start + root.end) / 2;
    segmentTreeNode child = val <= mid ? root.left : root.right;
    updateDel(child, val);
    root.count = root.left.count + root.right.count;
}
}``````

• Using a Binary Indexed Tree (Fenwick tree) can shorten the code a lot. :P

``````public class Solution {

/** Adds val at 1-based position i of the Fenwick tree {@code bit} and at every later node whose range covers i. */
private void add(int[] bit, int i, int val) {
    while (i < bit.length) {
        bit[i] += val;
        i += Integer.lowestOneBit(i); // same as i & -i
    }
}

/** Returns the Fenwick-tree prefix sum over positions 1..i (0 when i <= 0). */
private int query(int[] bit, int i) {
    int total = 0;
    while (i > 0) {
        total += bit[i];
        i &= i - 1; // clears the lowest set bit, same as i -= i & -i
    }
    return total;
}

/**
 * For each index i, counts the elements to the right of nums[i] that are strictly smaller.
 *
 * Each value gets a 1-based rank from a binary search over a sorted copy, and a Fenwick tree
 * over ranks is filled while scanning from right to left. The caller's array is left unchanged.
 *
 * @param nums input values (may be empty)
 * @return a list with the same length as nums holding the right-side smaller counts
 */
public List<Integer> countSmaller(int[] nums) {
    int n = nums.length;
    int[] sorted = nums.clone();
    Arrays.sort(sorted);
    int[] bit = new int[n + 1];
    Integer[] ans = new Integer[n];
    for (int i = n - 1; i >= 0; i--) {
        // binarySearch is deterministic for a given array and key, so equal values share a rank
        // and distinct values keep their order; rank lies in [1, n].
        int rank = Arrays.binarySearch(sorted, nums[i]) + 1;
        ans[i] = query(bit, rank - 1); // already-inserted (right-hand) elements with a lower rank
        add(bit, rank, 1);             // was missing: make nums[i] visible to the elements on its left
    }
    return Arrays.asList(ans);
}
}``````

• It is better to discretize first (replace each number by its rank) so that every value lies in [0, n). Otherwise the segment tree can be as large as the maximum key in the array, e.g. 2147483647.

• This is just amazing! Thanks for sharing this solution!

• Could you explain the idea? Thanks!

• @zjuzhanxf, you can see here for a detailed tutorial on BIT (Fenwick Tree).

• Impressed, BIT is awesome!

• If the input values are widely spread, your solution will use a lot of memory, because the tree is sized by the largest value. You should remove duplicates and build the tree from the sorted array of distinct values.

``````class Solution(object):

def countSmaller(self, nums):
    """
    :type nums: List[int]
    :rtype: List[int]

    For each position, count how many later elements are strictly smaller.
    Values are compressed to ranks 0..k-1 (k = number of distinct values), and a
    sum segment tree over those ranks is filled while scanning right to left.
    """
    rank_of = {value: rank for rank, value in enumerate(sorted(set(nums)))}
    tree = SegmentTree()
    root = tree.build(0, len(rank_of) - 1)
    counts = [0] * len(nums)
    for i in reversed(range(len(nums))):
        rank = rank_of[nums[i]]
        tree.modify(root, rank, 1)
        # Rank itself is excluded, so only strictly smaller values already inserted count.
        counts[i] = tree.query(root, 0, rank - 1)
    return counts
``````

class SegmentTreeNode:
    """A node of a sum segment tree covering the inclusive index range [start, end]."""

    def __init__(self, start, end, sum = 0):
        # `sum` keeps its name (it shadows the builtin) so keyword callers keep working.
        self.start = start
        self.end = end
        self.sum = sum
        # Children; SegmentTree.build leaves both as None for a leaf.
        self.left = None
        self.right = None

class SegmentTree:
    """Sum segment tree over an integer index range, made of SegmentTreeNode objects.

    Midpoints use floor division (//). With `/` on Python 3 they become floats, so
    build() never reaches start == end and recurses until RecursionError.
    """

    def build(self, start, end):
        """Return the root of a tree over [start, end] with all sums 0, or None if start > end."""
        if start > end:
            return None
        node = SegmentTreeNode(start, end)
        if start == end:
            return node
        middle = (start + end) // 2
        node.left = self.build(start, middle)
        node.right = self.build(middle + 1, end)
        return node

    def modify(self, root, index, value):
        """Add `value` at leaf `index` and refresh the sums of its ancestors.

        Indices outside the tree's range leave every sum unchanged.
        """
        if root is None:
            return
        if root.start == root.end:
            # Leaf: update only if it is the requested index, then stop descending.
            if root.start == index:
                root.sum += value
            return
        middle = (root.start + root.end) // 2
        if index > middle:
            self.modify(root.right, index, value)
        else:
            self.modify(root.left, index, value)
        if root.left and root.right:
            root.sum = root.left.sum + root.right.sum

    def query(self, root, start, end):
        """Return the sum over [start, end]; 0 for an empty range or a missing subtree.

        A non-empty [start, end] is assumed to lie within [root.start, root.end].
        """
        if start > end or root is None:
            return 0
        if start == root.start and end == root.end:
            return root.sum
        middle = (root.start + root.end) // 2
        # interval entirely in the left half
        if end <= middle:
            return self.query(root.left, start, end)
        # interval entirely in the right half
        if start > middle:
            return self.query(root.right, start, end)
        # interval straddles the midpoint: split it in two
        return (self.query(root.left, start, middle)
                + self.query(root.right, middle + 1, end))

`

• Why did you build the segment tree over values rather than over array indices?

segmentTreeNode root = build(0, max);

change to :
segmentTreeNode root = build(0, nums.length - 1);

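If you size the tree by the array length instead, you don't need to worry about negative numbers. However, each value must then be mapped to its rank (its position among the sorted distinct values) before it is inserted or queried.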

• @mickeyliu6 Recently, I learned a more elegant way of writing a BIT (Fenwick tree) without discretization.

``````void add(Map<Long, Integer> bit, long x) {
// Fenwick-tree increment by 1 at value x, stored sparsely: only indices that are touched get map keys.
// x is shifted by 2^31 so an int value maps into [0, 2^32); indices reaching 2^32 are not stored.
// NOTE(review): the shifted index must be >= 1. For x == Integer.MIN_VALUE it is 0, and since
// 0 & -0 == 0 the loop never advances (infinite loop), so callers must pass x > Integer.MIN_VALUE.
for (x += 1L << 31; x < 1L << 32; x += x & -x) bit.put(x, bit.getOrDefault(x, 0) + 1);
}

int sum(Map<Long, Integer> bit, long x) {
    // Fenwick prefix sum over shifted indices 1..(x + 2^31); keys absent from the map count as 0.
    int total = 0;
    long idx = x + (1L << 31);
    while (idx > 0) {
        total += bit.getOrDefault(idx, 0);
        idx -= idx & -idx;
    }
    return total;
}

/**
 * For each index i, counts the elements to the right of nums[i] that are strictly smaller,
 * using a sparse Fenwick tree keyed directly by (shifted) value, so no discretization is needed.
 */
public List<Integer> countSmaller(int[] nums) {
    Map<Long, Integer> bit = new HashMap<>();
    Integer[] res = new Integer[nums.length];
    for (int i = nums.length - 1; i >= 0; i--) {
        // Right-hand values v are stored at key v + 1, so this counts v + 1 <= nums[i], i.e. v < nums[i].
        res[i] = sum(bit, (long) nums[i]);
        // Was missing: record nums[i]. The +1 keeps add()'s shifted index >= 1 even for
        // Integer.MIN_VALUE (index 0 would loop forever there). For Integer.MAX_VALUE the index
        // is 2^32, which add() skips; that is harmless because no value exceeds MAX_VALUE.
        add(bit, (long) nums[i] + 1);
    }
    return Arrays.asList(res);
}
``````

``````class Solution {
public:
vector<int> countSmaller(vector<int>& nums) {
    // Visit positions in increasing order of value, ties broken by position, marking each visited
    // position in a bottom-up segment tree over indices. When position p is visited, every marked
    // position to its right holds a strictly smaller value; equal values to its right come later.
    const int N = nums.size();
    vector<int> res(N);
    const auto by_value_then_index = [](std::pair<size_t, int> lhs, std::pair<size_t, int> rhs) {
        return lhs.second != rhs.second ? lhs.second < rhs.second : lhs.first < rhs.first;
    };
    // Leaves live at tree[N .. 2N); node k holds tree[2k] + tree[2k + 1].
    std::vector<int> tree(N * 2);
    for (const auto& entry : sorted(enumerate(nums), by_value_then_index)) {
        const int pos = static_cast<int>(entry.first);
        // Mark this position, then refresh each ancestor's sum.
        int node = N + pos;
        ++tree[node];
        while (node > 1) {
            tree[node >> 1] = tree[node] + tree[node ^ 1];
            node >>= 1;
        }
        // Sum the marked leaves at positions (pos, N).
        int smaller = 0;
        int lo = N + pos + 1;
        int hi = 2 * N;
        while (lo < hi) {
            if (lo & 1) smaller += tree[lo++];
            if (hi & 1) smaller += tree[--hi];
            lo >>= 1;
            hi >>= 1;
        }
        res[pos] = smaller;
    }
    return res;
}

// enumerate(): pairs every element of [first, last) with its zero-based position and returns
// the resulting (index, value) pairs in iteration order.
template <typename Iterator>
std::vector<std::pair<size_t, typename Iterator::value_type>> enumerate(Iterator first, Iterator last) {
    using Entry = std::pair<size_t, typename Iterator::value_type>;
    std::vector<Entry> indexed;
    size_t position = 0;
    while (first != last) {
        indexed.push_back(Entry(position, *first));
        ++position;
        ++first;
    }
    return indexed;
}
template <typename Container>
std::vector<std::pair<size_t, typename Container::value_type>> enumerate(const Container &c) {
    // Whole-container convenience overload: enumerate from begin to end.
    return enumerate(c.begin(), c.end());
}

// sorted(): copies [first, last) into a new vector and returns it ordered by comp
// (std::less by default); the source range is left untouched.
template <typename Iterator, typename Compare = std::less<typename Iterator::value_type>>
std::vector<typename Iterator::value_type> sorted(Iterator first, Iterator last, Compare comp = Compare()) {
    using Value = typename Iterator::value_type;
    std::vector<Value> out;
    out.assign(first, last);
    std::sort(out.begin(), out.end(), comp);
    return out;
}
template <typename Container, typename Compare = std::less<typename Container::value_type>>
std::vector<typename Container::value_type> sorted(const Container &c, Compare comp = Compare()) {
    // Whole-container convenience overload: sort a copy of every element.
    return sorted(c.begin(), c.end(), comp);
}
};``````

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.