Java Red-Black Tree 72 ms solution



    There are BST solutions, but an unbalanced BST degrades to O(n^2) in the worst case. What's worse, the worst case (no pun intended) is a very ordinary one: when all numbers are positive, or all are negative, the prefix sums are strictly monotonic, so every insertion extends the same path. So we need to keep our tree balanced, and that immediately rings a bell: Red-Black Trees.
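    To see the degenerate case concretely, here is a minimal sketch (the class DegenerateBstDemo and its helpers are hypothetical, written only for this illustration) that inserts the prefix sums of an array into a plain, unbalanced BST and reports the tree height. With all-positive input the prefix sums are strictly increasing, so each new node becomes the right child of the previous one and the height equals the number of prefix sums; inserting the i-th sum then costs about i comparisons, which adds up to O(n^2).

```java
import java.util.Arrays;

// Hypothetical demo: a plain (unbalanced) BST built from prefix sums.
final class DegenerateBstDemo {
    static final class Node {
        final long key;
        Node left, right;
        Node(long key) { this.key = key; }
    }

    // Inserts key below root iteratively (recursion could overflow on a degenerate tree)
    // and returns the 1-based depth of the new node. Equal keys go right.
    static int insert(Node root, long key) {
        Node current = root;
        int depth = 1;
        while (true) {
            ++depth;
            if (key < current.key) {
                if (current.left == null) { current.left = new Node(key); return depth; }
                current = current.left;
            } else {
                if (current.right == null) { current.right = new Node(key); return depth; }
                current = current.right;
            }
        }
    }

    // Height of the BST holding all prefix sums of nums, including the zero-length prefix.
    static int prefixSumBstHeight(int[] nums) {
        Node root = new Node(0);
        int height = 1;
        long sum = 0;
        for (int x : nums) {
            sum += x;
            height = Math.max(height, insert(root, sum));
        }
        return height;
    }

    public static void main(String[] args) {
        int[] positives = new int[1000];
        Arrays.fill(positives, 1);
        System.out.println(prefixSumBstHeight(positives));                 // 1001: one long chain
        System.out.println(prefixSumBstHeight(new int[] {1, -2, 3, -4}));  // 3
    }
}
```

    Alternating signs keep the tree shallow; a run of same-sign numbers is exactly what turns it into a linked list.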

    The only trick is that we need to keep track of node counts in subtrees, so that we can count the number of elements less than or equal to a given value in O(log n) without walking over the elements themselves (as TreeMap.subMap().size() does). It is also important to update those node counts when performing rotations.
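    For comparison, here is a sketch of the "walking over the elements" approach, using the standard java.util.TreeMap as a multiset of prefix sums (the class name TreeMapRangeSum is hypothetical). Each query iterates over every distinct key in the range via subMap(...).values(), so a single query can cost O(n) instead of the O(log n) that subtree counts give. It assumes lower <= upper, since TreeMap.subMap throws IllegalArgumentException when the lower bound exceeds the upper one. It is still handy as a reference for cross-checking the tree on small inputs.

```java
import java.util.TreeMap;

// Baseline sketch (not the tree solution): TreeMap used as a multiset of prefix sums.
final class TreeMapRangeSum {
    // Counts pairs j < i with lower <= sums[i] - sums[j] <= upper.
    // Assumes lower <= upper, otherwise subMap throws IllegalArgumentException.
    static int countRangeSum(int[] nums, int lower, int upper) {
        TreeMap<Long, Integer> counts = new TreeMap<>(); // prefix sum -> multiplicity
        counts.put(0L, 1); // zero-length prefix
        long sum = 0;
        int count = 0;
        for (int x : nums) {
            sum += x;
            // Visits every distinct key in [sum - upper, sum - lower]: O(k) per query.
            for (int c : counts.subMap(sum - upper, true, sum - lower, true).values()) {
                count += c;
            }
            counts.merge(sum, 1, Integer::sum);
        }
        return count;
    }

    public static void main(String[] args) {
        System.out.println(countRangeSum(new int[] {-2, 5, -1}, -2, 2)); // 3
    }
}
```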

    public int countRangeSum(int[] nums, int lower, int upper) {
        long sum = 0;
        RedBlackTree sumTree = new RedBlackTree();
        sumTree.add(sum); // zero-length prefix
        int count = 0;
        for (int i = 0; i < nums.length; ++i) {
            sum += nums[i];
            // we need to count lower <= sums[i] - sums[j] <= upper, j < i, or
            // -lower >= sums[j] - sums[i] >= -upper, or sums[i] - lower >= sums[j] >= sums[i] - upper, or
            // sums[i] - lower >= sums[j] > sums[i] - upper - 1
            count += countLE(sumTree.root, sum - lower) - countLE(sumTree.root, sum - upper - 1);
            sumTree.add(sum);
        }
        return count;
    }
    
    private static int countLE(RedBlackTree.Node root, long sum) {
        RedBlackTree.Node current = root;
        int count = current.totalCount;
        while (current != RedBlackTree.Node.NIL) {
            if (current.value == sum) {
                count -= current.right.totalCount;
                break;
            } else if (sum < current.value) {
                count -= current.valueCount + current.right.totalCount;
                current = current.left;
            } else { // sum > current.value: current and its left subtree stay counted
                current = current.right;
            }
        }
        return count;
    }
    
    static class RedBlackTree {
        
        Node root = Node.NIL;
        
        void add(long value) {
            Node current = root, prev = Node.NIL;
            while (current != Node.NIL && current.value != value) {
                ++current.totalCount;
                prev = current;
                if (value < current.value) {
                    current = current.left;
                } else {
                    current = current.right;
                }
            }
            if (current != Node.NIL) { // can't test current.value == value here: NIL.value is 0, so that would falsely match when value == 0
                // exact match
                ++current.totalCount;
                ++current.valueCount;
                return;
            }
            Node node = new Node(value);
            if (prev == Node.NIL) {
                root = node;
            } else {
                if (value < prev.value) {
                    assert prev.left == Node.NIL;
                    prev.left = node;
                } else {
                    assert prev.right == Node.NIL && value > prev.value;
                    prev.right = node;
                }
                node.parent = prev;
            }
            // fix up the Red-Blackness (CLR, Introduction to Algorithms)
            while (node.parent.color == Node.Color.RED) {
                Node parent = node.parent;
                Node granddad = parent.parent;
                assert granddad.color == Node.Color.BLACK;
                boolean left = granddad.left == parent;
                Node uncle = left ? granddad.right : granddad.left;
                if (uncle.color == Node.Color.RED) { // case 1
                    granddad.color = Node.Color.RED;
                    parent.color = uncle.color = Node.Color.BLACK;
                    node = granddad;
                } else {
                    if ((left ? parent.right : parent.left) == node) { // case 2
                        node = parent;
                        rotate(node, left);
                        parent = node.parent; // the rotation moved the original node up into parent's place
                    }
                    // case 3
                    parent.color = Node.Color.BLACK;
                    granddad.color = Node.Color.RED;
                    rotate(granddad, !left);
                }
            }
            root.color = Node.Color.BLACK;
        }
        
        void rotate(Node node, boolean left) {
            Node parent = node.parent;
            Node child = left ? node.right : node.left;
            if (left) { node.right = child.left; } else { node.left = child.right; }
            node.totalCount = node.left.totalCount + node.valueCount + node.right.totalCount;
            (left ? child.left : child.right).parent = node;
            child.parent = parent;
            if (parent == Node.NIL) {
                root = child;
            } else {
                if (parent.left == node) {
                    parent.left = child;
                } else {
                    assert parent.right == node;
                    parent.right = child;
                }
            }
            if (left) { child.left = node; } else { child.right = node; }
            child.totalCount = child.left.totalCount + child.valueCount + child.right.totalCount;
            node.parent = child;
            Node.NIL.left = Node.NIL.right = Node.NIL.parent = Node.NIL; // restore the sentinel: the .parent assignment above may have written into NIL
        }
        
        static class Node {
            static final Node NIL = new Node();
            
            static { // need this because we can't initialize fields to NIL until it is created
                NIL.left = NIL.right = NIL.parent = NIL;
            }
            
            long value;
            int valueCount, totalCount;
            Node parent = NIL, left = NIL, right = NIL;
            Color color;
            
            private Node() { // NIL constructor
                this.color = Color.BLACK;
            }
            
            Node(long value) {
                this.value = value;
                this.valueCount = this.totalCount = 1;
                this.color = Color.RED; // new nodes start red; the fix-up loop in add() restores the invariants
            }
            
            enum Color {
                RED, BLACK
            }
        }
    }
    

    This solution is not terribly efficient: it runs in 72 ms, and I can't think of any way to optimize it. In fact, the Java TreeMap sources look very similar, except that they use more branches and don't use explicit NIL nodes. Maybe the fact that an RBT is not perfectly balanced also plays a role: some root-to-leaf paths can be twice as long as others. And we also need to search the tree twice on each loop iteration (two countLE calls, on top of the insertion), which doesn't make it any faster either.
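    The "twice as long" remark is the standard red-black height bound (CLRS, Lemma 13.1). Writing h for the tree height, n for the number of nodes and bh(x) for the black-height of x: a subtree rooted at x holds at least 2^bh(x) - 1 nodes, and since no red node has a red child, at least half the nodes on any root-to-leaf path below the root are black, so bh(root) >= h/2. Chaining these:

$$
\mathrm{bh}(\mathit{root}) \ge \frac{h}{2}
\quad\Longrightarrow\quad
n \ge 2^{h/2} - 1
\quad\Longrightarrow\quad
h \le 2\log_2(n + 1)
$$

    So each of the three traversals per element is still O(log n) and the whole loop is O(n log n); the factor of 2 only affects the constant.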

    The conclusion is that this solution is probably only of theoretical interest.



    @SergeyTachenov
    Give this warrior, who implemented the RBT himself, a thumbs-up!

