# My BST solution with detailed explanation

Basically, the idea is simple: traverse the array backwards while building a BST (AVL, red-black tree, whatever). The key is how to count the nodes already in the tree that are smaller than the current value. The code below says everything.

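```java
import java.util.*;

class Solution {
    public List<Integer> countSmaller(int[] nums) {
        Integer[] counter = new Integer[nums.length];
        // The comparator never returns 0, so equal values are kept as separate elements.
        TreeSet<Integer> tree = new TreeSet<>(
            (a, b) -> (Integer.compare(a, b) == 0) ? -1 : Integer.compare(a, b));
        for (int i = nums.length - 1; i >= 0; i--) {
            counter[i] = tree.headSet(nums[i]).size(); // Here is the key: count the strictly smaller values
            tree.add(nums[i]); // make nums[i] visible to the elements on its left
        }
        return Arrays.asList(counter);
    }
}
```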
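The duplicate trick in the comparator is easy to check in isolation. The sketch below (class and method names are made up for illustration) builds a `TreeSet` with the same never-return-0 comparator. With OpenJDK's `TreeMap`, duplicates are kept and `headSet(x)` holds only the elements strictly smaller than `x`, but plain lookups break: `contains(1)` returns false even after adding 1, because the comparator violates the `Comparator` contract. Also, `size()` on a `headSet` view is computed by iterating over the view, so each count costs time proportional to the number of elements counted.

```java
import java.util.Comparator;
import java.util.TreeSet;

// Hypothetical demo class: isolates the "never return 0" comparator trick.
class DuplicateTreeSetDemo {
    // Breaks the Comparator contract on purpose: equal values compare as "less",
    // so TreeSet.add never sees a duplicate and keeps every copy.
    static final Comparator<Integer> NEVER_EQUAL =
        (a, b) -> Integer.compare(a, b) == 0 ? -1 : Integer.compare(a, b);

    // Builds a TreeSet that keeps all copies of the given values.
    static TreeSet<Integer> build(int... values) {
        TreeSet<Integer> set = new TreeSet<>(NEVER_EQUAL);
        for (int v : values) set.add(v);
        return set;
    }
}
```

For example, `build(1, 1, 2)` has size 3 and `headSet(2)` contains both 1s, while `contains(1)` on the same set is false.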

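Of course, this solution gets TLE (Time Limit Exceeded). The comparator trick lets the TreeSet store duplicates, but TreeSet has no way to get the size of a headSet in O(log n): `size()` on the view counts its elements one by one, so each query is O(n) and the whole loop is O(n²). So we need a simple BST whose nodes carry an extra field holding the number of nodes in their left subtree. I prefer a separate `dup` field for repeated values to keep the counting clear.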

```java
import java.util.*;

class Solution {
    public List<Integer> countSmaller(int[] nums) {
        if (nums.length == 0) return new ArrayList<>();

        BSTNode root = new BSTNode(nums[nums.length - 1]);
        Integer[] counter = new Integer[nums.length];
        counter[nums.length - 1] = 0;
        for (int i = nums.length - 2; i >= 0; i--)
            counter[i] = insert(root, nums[i]);
        return Arrays.asList(counter);
    }

    // Inserts newval below root and returns how many values already in
    // that subtree are strictly smaller than newval.
    private int insert(BSTNode root, int newval) {
        if (newval < root.val) {
            root.leftsum++; // newval goes into the left subtree
            if (root.left != null)
                return insert(root.left, newval);
            root.left = new BSTNode(newval);
            return 0;
        } else if (root.val < newval) {
            // The whole left subtree and every copy of root.val are smaller.
            int smaller = root.leftsum + root.dup;
            if (root.right != null)
                return insert(root.right, newval) + smaller;
            root.right = new BSTNode(newval);
            return smaller;
        } else {
            root.dup++; // duplicate: count it here instead of adding a node
            return root.leftsum;
        }
    }

    private static class BSTNode {
        BSTNode left, right;
        int leftsum;  // number of values stored in the left subtree
        int dup = 1;  // number of copies of val
        int val;
        BSTNode(int val) { this.val = val; }
    }
}
```
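This BST is not self-balancing, so a sorted input (ascending or descending) degenerates it into a chain: each insert walks O(n) nodes, and the recursive `insert` needs one stack frame per level, which may overflow the default thread stack for large inputs. Below is a hedged sketch of the same `leftsum`/`dup` bookkeeping with a loop instead of recursion (the class name `IterativeDupBST` is hypothetical). It removes the stack-depth risk, but the worst case is still O(n²).

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical class: same leftsum/dup counting as the recursive version,
// but the descent is a loop, so a degenerate tree cannot exhaust the call stack.
class IterativeDupBST {
    private static final class Node {
        Node left, right;
        int leftsum;  // number of values stored in the left subtree
        int dup = 1;  // number of copies of val
        final int val;
        Node(int val) { this.val = val; }
    }

    static List<Integer> countSmaller(int[] nums) {
        Integer[] counter = new Integer[nums.length];
        if (nums.length == 0) return Arrays.asList(counter);
        Node root = new Node(nums[nums.length - 1]);
        counter[nums.length - 1] = 0;
        for (int i = nums.length - 2; i >= 0; i--) {
            counter[i] = insert(root, nums[i]);
        }
        return Arrays.asList(counter);
    }

    // Inserts v and returns how many previously inserted values are strictly smaller.
    private static int insert(Node root, int v) {
        int smaller = 0;
        Node cur = root;
        while (true) {
            if (v < cur.val) {
                cur.leftsum++; // v will live in cur's left subtree
                if (cur.left == null) { cur.left = new Node(v); return smaller; }
                cur = cur.left;
            } else if (cur.val < v) {
                smaller += cur.leftsum + cur.dup; // left subtree and cur's copies are all < v
                if (cur.right == null) { cur.right = new Node(v); return smaller; }
                cur = cur.right;
            } else {
                cur.dup++; // equal value: one more copy at this node
                return smaller + cur.leftsum;
            }
        }
    }
}
```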

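And here is an alternative solution based on the "rank" BST from CLRS, i.e. an order-statistic tree (here on a plain, unbalanced BST): every node stores the size of its subtree, so the rank of a value (how many stored values are smaller than it) can be accumulated on the way down from the root.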

```java
import java.util.*;

class Solution {
    public List<Integer> countSmaller(int[] nums) {
        if (nums.length == 0) return new ArrayList<>();
        int n = nums.length;
        Integer[] ret = new Integer[n];
        ret[n - 1] = 0;

        Node root = new Node(nums[n - 1]);
        for (int i = n - 2; i >= 0; i--) {
            int rank = 0;
            Node par = root;
            for (Node cur = root; cur != null; ) {
                par = cur;
                par.size++; // nums[i] will end up somewhere in cur's subtree
                if (nums[i] < cur.val) {
                    // cur and its right subtree are all > nums[i]: nothing to add
                    cur = cur.left;
                } else {
                    // the left subtree holds only values < cur.val <= nums[i]: safe to add to rank
                    if (cur.left != null) rank += cur.left.size;
                    if (cur.val < nums[i]) rank++; // count cur itself unless it is a duplicate
                    cur = cur.right;
                }
            }
            if (nums[i] < par.val) par.left = new Node(nums[i]);
            else par.right = new Node(nums[i]); // duplicates go right
            ret[i] = rank;
        }
        return Arrays.asList(ret);
    }

    static class Node {
        int val, size = 1; // size of the subtree rooted here, counting itself
        Node left, right;
        Node(int val) { this.val = val; }
    }
}
```
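To test any of the versions above on small inputs, it helps to have the definition written out directly: `counts[i]` is the number of indices `j > i` with `nums[j] < nums[i]`. The O(n²) brute force below (class name `BruteForceCountSmaller` is hypothetical) is meant only as a reference for cross-checking, not as a submission.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical reference implementation: counts[i] = #{ j > i : nums[j] < nums[i] }.
class BruteForceCountSmaller {
    static List<Integer> countSmaller(int[] nums) {
        List<Integer> counts = new ArrayList<>(nums.length);
        for (int i = 0; i < nums.length; i++) {
            int c = 0;
            for (int j = i + 1; j < nums.length; j++) {
                if (nums[j] < nums[i]) c++; // only strictly smaller values to the right count
            }
            counts.add(c);
        }
        return counts;
    }
}
```

For example, `BruteForceCountSmaller.countSmaller(new int[]{5, 2, 6, 1})` returns `[2, 1, 1, 0]`; comparing it with a tree version on random arrays is a quick sanity check.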
