My BST solution with detailed explanation



    Basically, the idea is simple. We traverse the array backwards while building a BST (AVL, red-black tree, whatever). The key is how to get the number of nodes already in the tree that are smaller than the current element. The code below says everything.

        public List<Integer> countSmaller(int[] nums) {
            Integer[] counter = new Integer[nums.length];
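            // Comparator never returns 0, so equal values are kept as separate elements.
            // This is a hack: it deliberately breaks the Comparator contract.
            TreeSet<Integer> tree = new TreeSet<>(
                (a, b) -> (Integer.compare(a, b) == 0) ? -1 : Integer.compare(a, b));
            for (int i = nums.length - 1; i >= 0; i--) {
                tree.add(nums[i]);
                counter[i] = tree.headSet(nums[i]).size(); // Key step: stored elements smaller than nums[i]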
            }
            return Arrays.asList(counter);
        }
    

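    Of course, this solution will TLE. Although the comparator trick lets the TreeSet store duplicates, TreeSet offers no O(log n) way to get the size of a headSet (headSet(x).size() iterates over the view), so the whole loop is O(n^2) in the worst case. Instead, we build a simple BST whose nodes carry an extra field counting the nodes in their left subtree. I prefer a separate dup field for duplicates to keep it clear. Note that this BST is not self-balancing, so it is O(n log n) on average but still O(n^2) on sorted input.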

        public List<Integer> countSmaller(int[] nums) {
            if (nums.length == 0) return new ArrayList<>();
            
            BSTNode root = new BSTNode(nums[nums.length - 1]);
            Integer[] counter = new Integer[nums.length];
            counter[nums.length - 1] = 0;
            for (int i = nums.length - 2; i >= 0; i--)
                counter[i] = insert(root, nums[i]);
            return Arrays.asList(counter);  
        }
        
        private int insert(BSTNode root, int newval) {
            if (newval < root.val) {
                root.leftsum++;
                if (root.left != null) 
                    return insert(root.left, newval);
                root.left = new BSTNode(newval);
                return 0;
            } else if (root.val < newval) {
                int smaller = root.leftsum + root.dup;
                if (root.right != null)
                    return insert(root.right, newval) + smaller;
                root.right = new BSTNode(newval);
                return smaller;
            } else {
                root.dup++;
                return root.leftsum;
            }
        }
        
        private static class BSTNode {
            BSTNode left, right;
            int leftsum, dup = 1, val;
            BSTNode(int val) { this.val = val; }
        }
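    The recursive insert above goes one call deeper per tree level, and because the tree is unbalanced, a sorted input makes it n levels deep, which may overflow the Java call stack for large n. Below is a minimal iterative sketch of the same leftsum/dup insert, assuming the same node layout; the class name CountSmallerIterative is hypothetical and only there to make the snippet runnable on its own.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class CountSmallerIterative {
    // Same layout as above: leftsum = nodes in the left subtree, dup = copies of val.
    private static class BSTNode {
        BSTNode left, right;
        int leftsum, dup = 1, val;
        BSTNode(int val) { this.val = val; }
    }

    public static List<Integer> countSmaller(int[] nums) {
        if (nums.length == 0) return new ArrayList<>();
        int n = nums.length;
        Integer[] counter = new Integer[n];
        BSTNode root = new BSTNode(nums[n - 1]);
        counter[n - 1] = 0;
        for (int i = n - 2; i >= 0; i--) {
            counter[i] = insert(root, nums[i]);
        }
        return Arrays.asList(counter);
    }

    // Walks down from root with a loop instead of recursion;
    // returns how many stored values are strictly smaller than newval.
    private static int insert(BSTNode root, int newval) {
        int smaller = 0;
        BSTNode cur = root;
        while (true) {
            if (newval < cur.val) {
                cur.leftsum++; // newval ends up in cur's left subtree
                if (cur.left == null) { cur.left = new BSTNode(newval); return smaller; }
                cur = cur.left;
            } else if (cur.val < newval) {
                smaller += cur.leftsum + cur.dup; // cur's left subtree and all copies of cur are smaller
                if (cur.right == null) { cur.right = new BSTNode(newval); return smaller; }
                cur = cur.right;
            } else {
                cur.dup++; // duplicate: only cur's left subtree is smaller
                return smaller + cur.leftsum;
            }
        }
    }

    public static void main(String[] args) {
        System.out.println(countSmaller(new int[]{5, 2, 6, 1})); // prints [2, 1, 1, 0]
    }
}
```

    The loop returns exactly the same counts as the recursive version; only the call-stack usage changes. The worst case is still O(n^2) time on sorted input.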
    

    And here is an alternative solution based on the order-statistic tree ("Rank BST") from CLRS: every node stores the size of its subtree, and the rank of a value is accumulated along the search path.

        public List<Integer> countSmaller(int[] nums) {
            if (nums.length == 0) return new ArrayList<>();
            int n = nums.length;
            Integer[] ret = new Integer[n];
            ret[n - 1] = 0;
            
            Node root = new Node(nums[n - 1]);
            for (int i = n - 2; i >= 0; i--) {
                int rank = 0;
                Node par = root;
                for (Node cur = root; cur != null; ) {
                    par = cur;
                    par.size++;
                    if (nums[i] < cur.val) { // cur and its right subtree are larger: add nothing here, go left
                        cur = cur.left;
                    } else {
                        if (cur.left != null) rank += cur.left.size; // left subtree holds values < cur.val <= nums[i], all smaller
                        if (cur.val < nums[i]) rank++; // count cur itself only if strictly smaller (skip duplicates)
                        cur = cur.right;
                    }
                }
                if (nums[i] < par.val) par.left = new Node(nums[i]);
                else par.right = new Node(nums[i]);
                ret[i] = rank;
            }
            return Arrays.asList(ret);
        }
        
        class Node {
            int val, size = 1; // itself
            Node left, right;
            Node(int val) { this.val = val; }
        }
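    To sanity-check the size-field version (or any variant of it), one option is to compare it against the O(n^2) definition of the problem on many small random arrays. The harness below is a minimal sketch: it copies the method above into a hypothetical class CountSmallerCheck as a static method, and the array lengths and value range are arbitrary choices (a narrow range forces plenty of duplicates).

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

public class CountSmallerCheck {
    static class Node {
        int val, size = 1; // size of the subtree rooted here, including the node itself
        Node left, right;
        Node(int val) { this.val = val; }
    }

    // Order-statistic BST version from above, made static for the harness.
    static List<Integer> countSmaller(int[] nums) {
        if (nums.length == 0) return new ArrayList<>();
        int n = nums.length;
        Integer[] ret = new Integer[n];
        ret[n - 1] = 0;
        Node root = new Node(nums[n - 1]);
        for (int i = n - 2; i >= 0; i--) {
            int rank = 0;
            Node par = root;
            for (Node cur = root; cur != null; ) {
                par = cur;
                par.size++;
                if (nums[i] < cur.val) {
                    cur = cur.left;
                } else {
                    if (cur.left != null) rank += cur.left.size;
                    if (cur.val < nums[i]) rank++;
                    cur = cur.right;
                }
            }
            if (nums[i] < par.val) par.left = new Node(nums[i]);
            else par.right = new Node(nums[i]);
            ret[i] = rank;
        }
        return Arrays.asList(ret);
    }

    // O(n^2) reference taken straight from the problem definition.
    static List<Integer> bruteForce(int[] nums) {
        List<Integer> out = new ArrayList<>();
        for (int i = 0; i < nums.length; i++) {
            int c = 0;
            for (int j = i + 1; j < nums.length; j++) if (nums[j] < nums[i]) c++;
            out.add(c);
        }
        return out;
    }

    public static void main(String[] args) {
        Random rnd = new Random(42); // fixed seed so any failure is reproducible
        for (int t = 0; t < 1000; t++) {
            int[] nums = new int[rnd.nextInt(30)];
            for (int k = 0; k < nums.length; k++) nums[k] = rnd.nextInt(11) - 5; // values in [-5, 5]
            if (!countSmaller(nums).equals(bruteForce(nums)))
                throw new AssertionError("mismatch on " + Arrays.toString(nums));
        }
        System.out.println("all random tests passed");
    }
}
```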
    
