BST in Java — using rank


  • 0
    A
    /**
     * For every index {@code i}, counts how many elements to the right of {@code i} are strictly
     * smaller than {@code nums[i]}.
     *
     * <p>The array is scanned from right to left. Each value is inserted into a rank-augmented BST,
     * which is then asked how many strictly smaller occurrences it already holds. {@code rank}
     * ignores occurrences equal to the queried key, so inserting {@code nums[i]} before querying
     * does not inflate its own answer. The BST is not self-balancing, so already-sorted input
     * degrades to O(n^2) time.
     *
     * @param nums input values
     * @return a list whose {@code i}-th entry is the number of elements after index {@code i}
     *     that are smaller than {@code nums[i]}
     * @throws NullPointerException if {@code nums} is null
     */
    public List<Integer> countSmaller(int[] nums) {
        // Diamond instead of the raw type, so put/rank are type-checked against Integer.
        BST<Integer> bst = new BST<>();
        List<Integer> counts = new ArrayList<>(nums.length);

        for (int i = nums.length - 1; i >= 0; i--) {
            bst.put(nums[i]);
            counts.add(bst.rank(nums[i]));
        }

        // Answers were produced right-to-left; restore input order.
        Collections.reverse(counts);
        return counts;
    }
    
    
    /**
     * Unbalanced binary search tree over a multiset of keys. It supports rank queries: the number
     * of stored occurrences strictly smaller than a given key.
     *
     * <p>Duplicate keys share a single node whose {@code count} is incremented. Every node caches
     * {@code n}, the total number of occurrences in its subtree, so a rank query only walks one
     * root-to-leaf path. The tree is never rebalanced, so sorted insertion order yields a
     * linked-list-shaped tree. {@code put} and {@code rank} therefore use loops rather than
     * recursion, so a tall tree cannot overflow the call stack.
     *
     * @param <Key> key type; keys are compared with {@link Comparable#compareTo}, so null keys are
     *     not supported
     */
    private class BST<Key extends Comparable<Key>> {
        /** A tree node holding one distinct key together with its multiplicity. */
        private class Node {
            private Key key;
            /** Total occurrences in this node's subtree: size(left) + count + size(right). */
            private int n;
            private Node left;
            private Node right;
            /** How many times {@code key} has been inserted. */
            private int count;

            public Node(Key key, int n, int count) {
                this.key = key;
                this.n = n;
                this.count = count;
            }
        }

        private Node root;

        public BST() {
        }

        /** Returns the number of occurrences stored in {@code node}'s subtree (0 if null). */
        public int size(Node node) {
            return node == null ? 0 : node.n;
        }

        /** Inserts one occurrence of {@code key}. */
        public void put(Key key) {
            root = put(root, key);
        }

        /**
         * Inserts one occurrence of {@code key} into the subtree rooted at {@code node}.
         *
         * @return the root of the updated subtree (a new node when {@code node} is null)
         */
        private Node put(Node node, Key key) {
            if (node == null) {
                return new Node(key, 1, 1);
            }
            Node current = node;
            while (true) {
                int cmp = key.compareTo(current.key);
                // Every insertion adds exactly one occurrence to each subtree on the search path,
                // so incrementing n on the way down preserves n == size(left) + count + size(right).
                current.n++;
                if (cmp < 0) {
                    if (current.left == null) {
                        current.left = new Node(key, 1, 1);
                        return node;
                    }
                    current = current.left;
                } else if (cmp > 0) {
                    if (current.right == null) {
                        current.right = new Node(key, 1, 1);
                        return node;
                    }
                    current = current.right;
                } else {
                    current.count++; // do not replace key, increase count
                    return node;
                }
            }
        }

        /** Returns the number of stored occurrences strictly smaller than {@code key}. */
        public int rank(Key key) {
            return rank(root, key);
        }

        /** Counts occurrences strictly smaller than {@code key} within {@code node}'s subtree. */
        private int rank(Node node, Key key) {
            int smaller = 0;
            Node current = node;
            while (current != null) {
                int cmp = key.compareTo(current.key);
                if (cmp < 0) {
                    current = current.left;
                } else if (cmp > 0) {
                    // This node's key and its entire left subtree are all smaller than key.
                    smaller += size(current.left) + current.count;
                    current = current.right;
                } else {
                    // Equal occurrences are excluded; only the left subtree is smaller.
                    return smaller + size(current.left);
                }
            }
            return smaller;
        }
    }

Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.