Easy to understand O(n log k) Java solution using TreeMap


  • 20

    A TreeMap that maps each value to its count is used to implement an ordered multiset.
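    As a minimal sketch of what "TreeMap as an ordered multiset" means: each key is a value and the mapped count is how many copies are stored, so duplicates are handled by counts rather than by repeated keys. The class and method names below (`TreeMapMultiset`, `removeOne`) are hypothetical, chosen for illustration; the answer's own `add`/`remove` helpers follow the same idea.

```java
import java.util.TreeMap;

// Hypothetical helper class for illustration, not part of the answer's code.
class TreeMapMultiset {
    // value -> number of copies currently stored
    private final TreeMap<Integer, Integer> counts = new TreeMap<>();
    private int size = 0; // total copies, including duplicates

    public void add(int val) {
        counts.merge(val, 1, Integer::sum); // O(log n)
        size++;
    }

    // Removes one copy of val; assumes val is present.
    public void removeOne(int val) {
        int c = counts.get(val);
        if (c == 1) counts.remove(val); // last copy: drop the key entirely
        else counts.put(val, c - 1);
        size--;
    }

    public int min() { return counts.firstKey(); }
    public int max() { return counts.lastKey(); }
    public int size() { return size; }

    public static void main(String[] args) {
        TreeMapMultiset ms = new TreeMapMultiset();
        for (int x : new int[]{3, 1, 3, 2}) ms.add(x);
        ms.removeOne(3); // one copy of 3 remains
        System.out.println(ms.min() + " " + ms.max() + " " + ms.size()); // prints "1 3 3"
    }
}
```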

    In this problem, I use two ordered multisets as heaps: one holds the lower half of the window's elements, and the other holds the upper half.

    This is faster than the usual implementation with two PriorityQueues because, unlike PriorityQueue, TreeMap can remove an arbitrary element in logarithmic time (PriorityQueue.remove(Object) takes linear time).
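    For comparison, here is a rough sketch of the usual two-PriorityQueue approach that the answer refers to (the class name `TwoHeapsBaseline` is hypothetical). It gives correct medians, but `PriorityQueue.remove(Object)` has to scan the heap to find the outgoing element, so each slide costs O(k) and the whole run is O(nk) rather than O(n log k).

```java
import java.util.Collections;
import java.util.PriorityQueue;

// Hypothetical baseline for comparison: the classic two-PriorityQueue approach.
class TwoHeapsBaseline {
    static double[] medianSlidingWindow(int[] nums, int k) {
        // lo: max-heap with the lower half; hi: min-heap with the upper half.
        // Invariant: lo.size() == hi.size() or lo.size() == hi.size() + 1.
        PriorityQueue<Integer> lo = new PriorityQueue<>(Collections.reverseOrder());
        PriorityQueue<Integer> hi = new PriorityQueue<>();
        double[] res = new double[nums.length - k + 1];
        for (int i = 0; i < nums.length; i++) {
            // Insert through lo, then hand lo's max to hi so every lo element <= every hi element.
            lo.offer(nums[i]);
            hi.offer(lo.poll());
            if (hi.size() > lo.size()) lo.offer(hi.poll());
            if (i >= k - 1) {
                res[i - k + 1] = lo.size() > hi.size()
                        ? lo.peek()
                        : ((double) lo.peek() + hi.peek()) / 2.0;
                // Remove the outgoing element. remove(Object) scans the heap: O(k), the bottleneck.
                int out = nums[i - k + 1];
                if (out <= lo.peek()) lo.remove(out);
                else hi.remove(out);
                // Restore the size invariant.
                if (lo.size() < hi.size()) lo.offer(hi.poll());
                else if (lo.size() > hi.size() + 1) hi.offer(lo.poll());
            }
        }
        return res;
    }
}
```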

    import java.util.*;

    public class Solution {
        public double[] medianSlidingWindow(int[] nums, int k) {
            double[] res = new double[nums.length-k+1];
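            // minHeap: natural order, firstKey() is the smallest element of the upper half
            TreeMap<Integer, Integer> minHeap = new TreeMap<Integer, Integer>();
            // maxHeap: reverse order, firstKey() is the largest element of the lower half
            TreeMap<Integer, Integer> maxHeap = new TreeMap<Integer, Integer>(Collections.reverseOrder());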
            
            int minHeapCap = k/2; //smaller heap when k is odd.
            int maxHeapCap = k - minHeapCap; 
            
            for(int i=0; i< k; i++){
                maxHeap.put(nums[i], maxHeap.getOrDefault(nums[i], 0) + 1);
            }
            // one-element arrays act as mutable int counters that the helper methods can update
            int[] minHeapSize = new int[]{0};
            int[] maxHeapSize = new int[]{k};
            for(int i=0; i< minHeapCap; i++){
                move1Over(maxHeap, minHeap, maxHeapSize, minHeapSize);
            }
            
            res[0] = getMedian(maxHeap, minHeap, maxHeapSize, minHeapSize);
            int resIdx = 1;
            
            for(int i=0; i< nums.length-k; i++){
                int addee = nums[i+k];
                if(addee <= maxHeap.keySet().iterator().next()){
                    add(addee, maxHeap, maxHeapSize);
                } else {
                    add(addee, minHeap, minHeapSize);
                }
                
                int removee = nums[i];
                if(removee <= maxHeap.keySet().iterator().next()){
                    remove(removee, maxHeap, maxHeapSize);
                } else {
                    remove(removee, minHeap, minHeapSize);
                }
    
                //rebalance
                if(minHeapSize[0] > minHeapCap){
                    move1Over(minHeap, maxHeap, minHeapSize, maxHeapSize);
                } else if(minHeapSize[0] < minHeapCap){
                    move1Over(maxHeap, minHeap, maxHeapSize, minHeapSize);
                }
                
                res[resIdx] = getMedian(maxHeap, minHeap, maxHeapSize, minHeapSize);
                resIdx++;
            }
            return res;
        }
    
        public double getMedian(TreeMap<Integer, Integer> bigHeap, TreeMap<Integer, Integer> smallHeap, int[] bigHeapSize, int[] smallHeapSize){
            return bigHeapSize[0] > smallHeapSize[0] ? (double) bigHeap.keySet().iterator().next() : ((double) bigHeap.keySet().iterator().next() + (double) smallHeap.keySet().iterator().next()) / 2.0;
        }
        
        //move the top element of heap1 to heap2
        public void move1Over(TreeMap<Integer, Integer> heap1, TreeMap<Integer, Integer> heap2, int[] heap1Size, int[] heap2Size){
            int peek = heap1.keySet().iterator().next();
            add(peek, heap2, heap2Size);
            remove(peek, heap1, heap1Size);
        }
        
        public void add(int val, TreeMap<Integer, Integer> heap, int[] heapSize){
            heap.put(val, heap.getOrDefault(val,0) + 1);
            heapSize[0]++;
        }
        
        public void remove(int val, TreeMap<Integer, Integer> heap, int[] heapSize){
            if(heap.put(val, heap.get(val) - 1) == 1) heap.remove(val);
            heapSize[0]--;
        }
    }
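    A brute-force reference that sorts a copy of every window is handy for cross-checking the TreeMap solution above (or any other version in this thread) on random inputs. This is a sketch intended only as a test oracle; `BruteForceMedian` is a hypothetical name, and at O(n · k log k) it is much slower than the TreeMap version.

```java
import java.util.Arrays;

// Hypothetical test oracle: sorts a copy of each window, O(n * k log k) overall.
class BruteForceMedian {
    static double[] medianSlidingWindow(int[] nums, int k) {
        double[] res = new double[nums.length - k + 1];
        for (int i = 0; i + k <= nums.length; i++) {
            int[] w = Arrays.copyOfRange(nums, i, i + k); // window nums[i .. i+k-1]
            Arrays.sort(w);
            // Odd k: middle element; even k: average of the two middle ones, in double to avoid int overflow.
            res[i] = (k % 2 == 1) ? w[k / 2] : ((double) w[k / 2 - 1] + w[k / 2]) / 2.0;
        }
        return res;
    }
}
```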
    

  • 0

    To whoever downvoted: if you could explain your concern with the code or suggest improvements, I would greatly appreciate it.


  • -8

    PriorityQueue might also have O(log N) removal complexity.


  • 2

    @wsliubw
    In this problem, we need to remove elements that are not necessarily at the top of the heap. PriorityQueue removes its top element in logarithmic time, but removing an arbitrary element takes linear time.
    I updated the post to clarify this.


  • 0

    Does TreeMap allow duplicates? How did you handle duplicates?


  • 2

    @brendon4565 I don't understand why this solution would get downvotes; I think it is fantastic. Storing each number together with its count in a TreeMap to handle duplicates is quite original, and using a balanced BST instead of a heap to avoid the O(k) removal is very smart. The implementation might be simpler if you remove the outgoing element before adding the new one. After the removal, the sizes of the two BSTs differ by at most 2, and the insertion code from "Find Median from Data Stream" can then easily bring the difference back to at most 1.

    @MitchellHe I believe I have answered your question ^_^.
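    A minimal sketch of the simplification suggested above, assuming the same TreeMap-with-counts multisets: remove the outgoing element first, then insert the new one the way "Find Median from Data Stream" does (push it through the lower half, move the lower half's maximum to the upper half, and move one element back if the upper half became larger). The class and helper names (`RemoveFirstVariant`, `inc`, `dec`) are hypothetical.

```java
import java.util.Collections;
import java.util.TreeMap;

// Hypothetical class illustrating the "remove first, then add" variant suggested in the comment.
class RemoveFirstVariant {
    // lo: lower half, reverse-ordered so firstKey() is its maximum.
    private final TreeMap<Integer, Integer> lo = new TreeMap<>(Collections.reverseOrder());
    // hi: upper half, natural order so firstKey() is its minimum.
    private final TreeMap<Integer, Integer> hi = new TreeMap<>();
    // Counts include duplicates; invariant: loSize == hiSize or loSize == hiSize + 1.
    private int loSize = 0, hiSize = 0;

    // Add one copy of v to a value -> count multiset.
    private static void inc(TreeMap<Integer, Integer> m, int v) {
        m.merge(v, 1, Integer::sum);
    }

    // Remove one copy of v; returning null from the remapping function deletes the key.
    private static void dec(TreeMap<Integer, Integer> m, int v) {
        m.computeIfPresent(v, (key, c) -> c == 1 ? null : c - 1);
    }

    private void remove(int v) {
        // Every lo element is <= every hi element, so v <= max(lo) means a copy lives in lo.
        if (v <= lo.firstKey()) { dec(lo, v); loSize--; }
        else { dec(hi, v); hiSize--; }
    }

    // "Find Median from Data Stream" insertion: push through lo, hand lo's max to hi, fix sizes.
    private void add(int v) {
        inc(lo, v);
        int top = lo.firstKey();
        dec(lo, top);
        inc(hi, top);
        hiSize++; // loSize is unchanged: one element in, one out
        if (hiSize > loSize) {
            int m = hi.firstKey();
            dec(hi, m); hiSize--;
            inc(lo, m); loSize++;
        }
    }

    private double median() {
        return loSize > hiSize ? lo.firstKey() : ((double) lo.firstKey() + hi.firstKey()) / 2.0;
    }

    public static double[] medianSlidingWindow(int[] nums, int k) {
        RemoveFirstVariant w = new RemoveFirstVariant();
        double[] res = new double[nums.length - k + 1];
        for (int i = 0; i < nums.length; i++) {
            if (i >= k) w.remove(nums[i - k]); // drop the outgoing element first
            w.add(nums[i]);
            if (i >= k - 1) res[i - k + 1] = w.median();
        }
        return res;
    }
}
```

    Because the lower half holds at most one extra element, a single insertion step of this kind restores the size invariant whichever half the removal came from.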


  • 6

    You could use map.firstKey() and map.firstEntry().getValue() instead of an iterator.

    public double[] medianSlidingWindow(int[] nums, int k) {
    
        TreeMap<Integer, Integer> minHeap = new TreeMap<Integer, Integer>();
        TreeMap<Integer, Integer> maxHeap = new TreeMap<Integer, Integer>(Collections.reverseOrder());
    
    
        double[] result = new double[nums.length-k+1];
        int i = 0, numToRemove = 0;
        int minSize = 0, maxSize = 0;
        for(int num : nums) {
            if(i > k-1) numToRemove = nums[i-k];
            // Largest value in the lower half; Double.MAX_VALUE when empty so num goes to maxHeap.
            double maxHeapTop = maxHeap.isEmpty() ? Double.MAX_VALUE : maxHeap.firstKey();
    
            if(num < maxHeapTop) {
                maxHeap.put(num, maxHeap.getOrDefault(num,0)+1);
                maxSize ++;
            }else {
                minHeap.put(num, minHeap.getOrDefault(num,0)+1);
                minSize++;
            }
    
            // heap clean up
            TreeMap<Integer, Integer> pq = null;
            if(minHeap.firstEntry() == null) pq = maxHeap;
            else pq = numToRemove >= minHeap.firstKey() ?  minHeap : maxHeap;
    
            if(i >= k && pq.containsKey(numToRemove)) {
                if(pq == minHeap) minSize--;
                else maxSize--;
                if(pq.get(numToRemove) == 1) {
                    pq.remove(numToRemove);
                }else {
                    pq.put(numToRemove, pq.get(numToRemove)-1);
                }
            }
    
            // balance
            if(minSize-1 > maxSize) {
                transferFrom(minHeap,maxHeap);
                minSize--;maxSize++;
            }else if(minSize < maxSize-1) {
                transferFrom(maxHeap,minHeap);
                maxSize--;minSize++;
            }
    
            if(i >= k-1 && minSize == maxSize) {
                result[i-k+1] = ((double)minHeap.firstKey() + (double)maxHeap.firstKey())/2.0;
            }else if(i >= k-1 && minSize > maxSize) {
                result[i-k+1] = minHeap.firstKey();
            }else if(i >= k-1) {
                result[i-k+1] = maxHeap.firstKey();
            }
            i++;
            //System.out.println(maxHeap + " " + minHeap);
    
        }
        return result;
    }
    public void transferFrom(TreeMap<Integer,Integer> src, TreeMap<Integer,Integer> dest) {
        dest.put(src.firstKey(), dest.getOrDefault(src.firstKey(),0)+1);
        if(src.firstEntry().getValue() == 1) {
            src.remove(src.firstKey());
        }else {
            src.put(src.firstKey(),src.firstEntry().getValue()-1);
        }
    }

  • 0

    brilliant solution! upvoted!


  • 0

    Actually, for PriorityQueue, removing any element other than the top (peek) element takes O(n).


  • 3

    Using TreeSet:

    import java.util.*;

    public class Solution {
        public double[] medianSlidingWindow(int[] nums, int k) {
            double[] result = new double[nums.length - k + 1];
            // Both sets store indices, ordered by (nums[index], index), so duplicates stay distinct.
            TreeSet<Integer> left = getSet(nums);   // lower half
            TreeSet<Integer> right = getSet(nums);  // upper half
            for (int i = 0; i < nums.length; i++) {
                // Insert i, then pass the boundary element across so left.size() >= right.size().
                if (left.size() <= right.size()) {
                    right.add(i);
                    int m = right.first();
                    right.remove(m);
                    left.add(m);
                } else {
                    left.add(i);
                    int m = left.last();
                    left.remove(m);
                    right.add(m);
                }

                if (left.size() + right.size() == k) {
                    double med;
                    if (left.size() == right.size())
                        med = ((double) nums[left.last()] + nums[right.first()]) / 2;
                    else if (left.size() < right.size())
                        med = nums[right.first()];
                    else
                        med = nums[left.last()];

                    int start = i - k + 1;
                    result[start] = med;

                    // Slide the window: drop the index that is leaving it.
                    if (!left.remove(start))
                        right.remove(start);
                }
            }
            return result;
        }

        private static TreeSet<Integer> getSet(int[] nums) {
            return new TreeSet<>(new Comparator<Integer>() {
                public int compare(Integer a, Integer b) {
                    return nums[a] == nums[b] ? a - b : nums[a] < nums[b] ? -1 : 1;
                }
            });
        }
    }
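    The key trick in this TreeSet answer is that the sets hold indices rather than values, and the comparator breaks ties between equal values by index. Duplicates therefore remain distinct entries, and `remove(start)` deletes exactly the copy that is leaving the window. The sketch below expresses the same ordering with `Comparator.comparingInt(...).thenComparingInt(...)`; `IndexTreeSetDemo` and `indexSet` are hypothetical names.

```java
import java.util.Comparator;
import java.util.TreeSet;

// Hypothetical demo of the (value, index) ordering used by the TreeSet answer.
class IndexTreeSetDemo {
    // Orders indices by nums value, breaking ties by index, so equal values stay distinct.
    static TreeSet<Integer> indexSet(int[] nums) {
        return new TreeSet<>(Comparator.<Integer>comparingInt(i -> nums[i])
                                       .thenComparingInt(i -> i));
    }

    public static void main(String[] args) {
        int[] nums = {5, 5, 2, 5};
        TreeSet<Integer> set = indexSet(nums);
        for (int i = 0; i < nums.length; i++) set.add(i);
        System.out.println(set); // prints "[2, 0, 1, 3]"
        set.remove(1);           // removes exactly the copy of 5 at index 1
        System.out.println(set); // prints "[2, 0, 3]"
    }
}
```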


  • 0

    Genius solution! Never thought of a TreeMap that way - it has so many uses. Thanks a lot.


  • 0

    Sharing my TreeSet solution. I also use a HashMap per TreeSet to record the frequency of each element, and a separate counter per TreeSet to track its number of elements (not treeset.size(), which only counts distinct values).

    import java.util.*;

    public class Solution {
        public double[] medianSlidingWindow(int[] nums, int k) {
            //corner case
            if (nums.length == 0) {
                return new double[0];
            }
    
            TreeSet<Integer> low = new TreeSet<>();
            TreeSet<Integer> high = new TreeSet<>();
            int lowCap = 0;
            int highCap = 0;
            Map<Integer, Integer> lowTimes = new HashMap<>();
            Map<Integer, Integer> highTimes = new HashMap<>();
    
            int size = nums.length;
            double[] res = new double[size - k + 1];
    
    
            for (int i = 0; i < size; i++) {
                //remove
                if (i > k - 1) {
                    int removeValue = nums[i - k];
                    if (high.contains(removeValue)) {
                        //can be optimized
                        int tmp = highTimes.get(removeValue);
                        if (tmp == 1) {
                            high.remove(removeValue);
                        }
                        highTimes.put(removeValue, tmp - 1);
                        highCap--;
                    } else {
                        int tmp = lowTimes.get(removeValue);
                        if (tmp == 1) {
                            low.remove(removeValue);
                        }
                        lowTimes.put(removeValue, tmp - 1);
                        lowCap--;
                    }
                }
    
                //add
                high.add(nums[i]);
                highTimes.put(nums[i], highTimes.getOrDefault(nums[i], 0) + 1);
                //highCap++;
    
                int poll = high.first();
                if (highTimes.get(poll) == 1) {
                    high.pollFirst();
                }
                highTimes.put(poll, highTimes.get(poll) - 1);
                //highCap--;
    
                low.add(poll);
                lowTimes.put(poll, lowTimes.getOrDefault(poll, 0) + 1);
                lowCap++;
    
                //keep balance
                if (lowCap > highCap + 1) {
                    int cur = low.last();
                    if (lowTimes.get(cur) == 1) {
                        low.pollLast();
                    }
                    lowTimes.put(cur, lowTimes.get(cur) - 1);
                    lowCap--;
    
                    high.add(cur);
                    highTimes.put(cur, highTimes.getOrDefault(cur, 0) + 1);
                    highCap++;
                }
    
                //get median
                if (i >= k - 1) {
                    if (lowCap == highCap) {
                        res[i - k + 1] = ((double)low.last() + (double)high.first())/2.0;
                    } else {
                        res[i - k + 1] = low.last();
                    }
                }
            }
    
            return res;
        }
    }
    
