22ms Java solution using binary tree, beats 99.82% of submissions



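    Some notes: for simplicity's sake this solution uses an ordinary binary search tree with no rebalancing, so it can become unbalanced. On bad input, such as numbers arriving in sorted order, it degenerates into a linked list and addNum() takes O(n). Given enough time, one could use a balanced binary search tree instead to guarantee O(log n) runtime for addNum(). findMedian() runs in O(1), since it only reads two nodes that are already being tracked.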
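    As a rough illustration of the balanced-tree option, here is a minimal sketch (not the code benchmarked in this post) that keeps the same two-tracker idea but stores the numbers in java.util.TreeSet, which is backed by a red-black tree, so add(), lower() and higher() each run in O(log n). The class name BalancedMedianFinder and the {value, insertionId} entry encoding are assumptions of this sketch; the id exists only so that duplicate values can coexist in a set.

```java
import java.util.Comparator;
import java.util.TreeSet;

// Sketch: the same medianLeft/medianRight tracking on java.util.TreeSet,
// which is backed by a red-black tree, so addNum() is O(log n) for any input order.
class BalancedMedianFinder {
    // Assumed encoding: each entry is {value, insertionId}. The id only breaks ties,
    // so duplicates can coexist in the set; a newer duplicate sorts after older ones,
    // like the original tree sending data >= this.data to the right.
    private static final Comparator<long[]> ORDER =
            Comparator.<long[]>comparingLong(e -> e[0]).thenComparingLong(e -> e[1]);

    private final TreeSet<long[]> tree = new TreeSet<>(ORDER);
    private long[] medianLeft;
    private long[] medianRight;
    private long nextId = 0;

    public void addNum(int num) {
        long[] entry = {num, nextId++};
        boolean sizeWasEven = tree.size() % 2 == 0;
        tree.add(entry);
        if (tree.size() == 1) {
            medianLeft = entry;
            medianRight = entry;
            return;
        }
        if (sizeWasEven) {
            // Trackers are adjacent; collapse both onto the new middle entry.
            if (ORDER.compare(entry, medianLeft) < 0) {
                medianRight = medianLeft;
            } else if (ORDER.compare(entry, medianRight) < 0) {
                medianLeft = entry;   // landed between the two trackers
                medianRight = entry;
            } else {
                medianLeft = medianRight;
            }
        } else {
            // Trackers share one entry; widen by one step toward the new entry.
            if (ORDER.compare(entry, medianLeft) < 0) {
                medianLeft = tree.lower(medianLeft);    // plays the role of predecessor()
            } else {
                medianRight = tree.higher(medianRight); // plays the role of successor()
            }
        }
    }

    public double findMedian() {
        // The stored values are longs, so this sum cannot overflow for int inputs.
        return (medianLeft[0] + medianRight[0]) / 2.0;
    }
}
```

    Here lower() and higher() stand in for predecessor() and successor(); because the set rebalances itself, sorted input no longer degrades addNum() to O(n).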

    By using a binary search tree, we can easily keep the input numbers in nondecreasing order: an in-order traversal visits them sorted. Observe that whenever a number is added, each number used to calculate the median moves by at most 1 position to the left or to the right (thinking of the numbers as a sorted array). Let's see an example:
    [2]: the number used to calculate the median is 2
    [2,3]: use 2,3 (expanding 1 to the right)
    [0,2,3]: use 2 (shrinking 1 toward the left)
    [0,1,2,3]: use 1,2 (expanding 1 to the left)
    [0,1,2,3,4]: use 2 (shrinking 1 toward the right)
    ...and so on.
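    To check a trace like the one above mechanically, a brute-force reference can keep every number in a sorted ArrayList and read the middle one or two entries. This is a hypothetical helper for comparison only (the name SortedListMedian is made up); its O(n) insertion is exactly what the tree-based approach avoids. Equal values are inserted after existing ones, mirroring the tree sending `>=` to the right.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical brute-force reference: all numbers kept in a sorted ArrayList.
class SortedListMedian {
    private final List<Integer> sorted = new ArrayList<>();

    // O(n) per call: the binary search is O(log n), but ArrayList.add(index, x)
    // shifts every later element one slot to the right.
    public void addNum(int num) {
        // Upper-bound search: first index whose value is > num, so equal values
        // end up after the existing ones (like "data >= this.data goes right").
        int lo = 0, hi = sorted.size();
        while (lo < hi) {
            int mid = (lo + hi) >>> 1;
            if (sorted.get(mid) <= num) {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        sorted.add(lo, num);
    }

    // Sorted-order positions of the numbers used for the median:
    // the same index when the size is odd, two adjacent indices when it is even.
    public int[] medianIndices() {
        int n = sorted.size();
        return new int[] {(n - 1) / 2, n / 2};
    }

    public double findMedian() {
        int[] idx = medianIndices();
        // Add as long so two large ints cannot overflow.
        return (sorted.get(idx[0]).longValue() + sorted.get(idx[1])) / 2.0;
    }

    public static void main(String[] args) {
        SortedListMedian m = new SortedListMedian();
        for (int num : new int[] {2, 3, 0, 1, 4}) {
            m.addNum(num);
            int[] idx = m.medianIndices();
            System.out.println(m.sorted + ": use " + m.sorted.get(idx[0])
                    + (idx[0] == idx[1] ? "" : "," + m.sorted.get(idx[1])));
        }
    }
}
```

    Running main inserts 2, 3, 0, 1, 4 in that order and prints the same numbers used as the trace above, one line per insertion. It is also handy for cross-checking a tree-based MedianFinder on random input.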

    With this observation, MedianFinder keeps two node references, medianLeft and medianRight, pointing at the numbers needed to calculate the median. When the size is odd, they point to the same node; otherwise they point to two adjacent nodes, medianLeft being the in-order predecessor of medianRight.

    When adding a node, we only need to check whether the current size is even or odd, and whether the new number lands to the left of, in between, or to the right of the two median trackers; from that we move medianLeft and medianRight to the correct nodes. Because each tracker moves by at most 1 position, a single call to predecessor() or successor() on the appropriate node is enough to update it. Insertion and those two methods each take O(h) time, where h is the height of the tree, which is O(log n) when the tree is balanced; hence the O(log n) runtime of addNum() on a balanced tree.

    Hope this helps!

    public class MedianFinder {
        private Node root;
        private Node medianLeft;
        private Node medianRight;
        private int size;
        
        public MedianFinder() {
        }
    
        // Adds a number into the data structure.
        public void addNum(int num) {
            if (root == null) {
                root = new Node(num);
                medianLeft = root;
                medianRight = root;
            }
            else {
                root.addNode(num);
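                // Even size: medianLeft and medianRight are two adjacent nodes, and the new
                // middle is medianLeft, the new node (if it landed between them), or medianRight.
                if (size % 2 == 0) {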
                    if (num < medianLeft.data) {
                        medianRight = medianLeft;
                    }
                    else if (medianLeft.data <= num && num < medianRight.data) {
                        medianLeft = medianLeft.successor();
                        medianRight = medianRight.predecessor();
                    }
                    else if (medianRight.data <= num) {
                        medianLeft = medianRight;
                    }
                }
                else {
                    // Odd size: both trackers share the middle node; widen by one toward the new node.
                    if (num < medianLeft.data) {
                        medianLeft = medianLeft.predecessor();
                    }
                    else {
                        medianRight = medianRight.successor();
                    }
                }
            }
            size++;
        }
    
        // Returns the median of the current data stream.
        public double findMedian() {
            // Widen to long before adding so two large ints cannot overflow.
            return ((long) medianLeft.data + medianRight.data) / 2.0;
        }
        
        private static class Node {
            private Node parent;
            private Node left;
            private Node right;
            private int data;
            
            public Node(int data) {
                this.data = data;
            }
            
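            // Inserts data below this node; values equal to an existing key go to the right.
            // Iterative, so a degenerate (list-like) tree cannot overflow the call stack.
            public void addNode(int data) {
                Node current = this;
                while (true) {
                    if (data >= current.data) {
                        if (current.right == null) {
                            current.right = new Node(data);
                            current.right.parent = current;
                            return;
                        }
                        current = current.right;
                    }
                    else {
                        if (current.left == null) {
                            current.left = new Node(data);
                            current.left.parent = current;
                            return;
                        }
                        current = current.left;
                    }
                }
            }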
            
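            // In-order predecessor: the rightmost node of the left subtree if there is one,
            // otherwise the nearest ancestor whose right subtree contains this node.
            public Node predecessor() {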
                if (left != null)
                    return left.rightMost();
                
                Node predecessor = parent;
                Node child = this;
                
                while (predecessor != null && child != predecessor.right) {
                    child = predecessor;
                    predecessor = predecessor.parent;
                }
                
                return predecessor;
            }
            
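            // In-order successor: the leftmost node of the right subtree if there is one,
            // otherwise the nearest ancestor whose left subtree contains this node.
            public Node successor() {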
                if (right != null)
                    return right.leftMost();
                
                Node successor = parent;
                Node child = this;
                
                while (successor != null && child != successor.left) {
                    child = successor;
                    successor = successor.parent;
                }
                
                return successor;
            }
            
            private Node leftMost() {
                Node node = this;
                while (node.left != null)
                    node = node.left;
                return node;
            }
            
            private Node rightMost() {
                Node node = this;
                while (node.right != null)
                    node = node.right;
                return node;
            }
            
        }
    }
    

  • Reply:

    [Screenshot: median from stream]
    Actually got it to 19ms and beats 99.91% now.


  • Reply:

    Excellent solution!


  • Reply:

    I just ran your code and found that the runtime is 220ms. Were the test cases changed?

