O(n log k), 79 ms: indexed-priority-queue C++ solution with explanation (accepted).


  • 1
    V

    This solution also uses two heaps: a bottom max-heap holding the smaller half of the window, and a top min-heap holding the larger half. When k is odd, the median is the largest element of the max-heap. When k is even, it is the average of the top elements of the two heaps.
    This solution uses the heap structure described in *Algorithms* by Sedgewick and Wayne:
    http://algs4.cs.princeton.edu/home/

    Sedgewick also describes an indexed priority queue, in which each key is associated with an integer index that the client can use to refer to (change or delete) that key. Here the index is the element's position in the input array. This lets us remove the element leaving the window by its index in O(log k), since each of the two heaps holds at most ⌈k/2⌉ elements.

    The IndexPQ class is self-contained and can be reused on its own for similar problems.

    /// Sliding-window median using two indexed binary heaps.
    ///
    /// The current window is split into a "lower" max-heap and an "upper"
    /// min-heap. After seeding and after every slide:
    ///   * every key in the max-heap is <= every key in the min-heap;
    ///   * the max-heap holds ceil(k/2) elements and the min-heap floor(k/2).
    /// So the median is the max-heap top (odd k) or the mean of both tops
    /// (even k). Both heaps are keyed by position in `nums`, so the element that
    /// leaves the window can be removed by its position in O(log k).
    class Solution {
        /// Indexed binary heap in the style of Sedgewick & Wayne's IndexMinPQ.
        ///
        /// Each entry is identified by an integer index in [0, maxN] and carries an
        /// int key. `isMin == true` makes it a min-heap, otherwise a max-heap.
        ///   pq[pos]   -> index stored at heap position pos (root at position 0)
        ///   qp[idx]   -> heap position of idx, or -1 if idx is not in the heap
        ///   keys[idx] -> key associated with idx
        class IndexPQ {
            int N{0};
            std::vector<int> pq;
            std::vector<int> qp;
            std::vector<int> keys;
            bool isMin{false};

        public:
            /// @param maxN  largest index that will ever be inserted.
            /// @param isMin true for a min-heap, false for a max-heap.
            IndexPQ(int maxN, bool isMin)
                : pq(maxN + 1, -1), qp(maxN + 1, -1), keys(maxN + 1, -1), isMin(isMin) {}

            /// True if index i is currently in the heap.
            bool contains(int i) const {
                return qp[i] != -1;
            }

            /// Inserts index i with the given key. Precondition: !contains(i).
            void insert(int i, int key) {
                qp[i] = N;
                pq[N] = i;
                keys[i] = key;
                swim(N);
                N++;
            }

            /// Key at the root (min or max). Precondition: count() > 0.
            int topKey() const {
                return keys[pq[0]];
            }

            /// Removes the root and returns its index. Precondition: count() > 0.
            int deleteTop() {
                const int indexOfTop = pq[0];
                exch(0, --N);
                sink(0);
                qp[pq[N]] = -1;
                pq[N] = -1;
                return indexOfTop;
            }

            /// Number of entries currently in the heap.
            int count() const {
                return N;
            }

            /// Removes index i from the heap. Does nothing if i is not present.
            void deleteIndex(int i) {
                const int index = qp[i];
                if (index == -1) {
                    return;
                }
                exch(index, --N);
                // If i was the last entry, nothing moved into its slot, so there is
                // nothing to restore.
                if (index < N) {
                    swim(index);
                    sink(index);
                }
                pq[N] = -1;
                qp[i] = -1;
            }

        private:
            /// Moves the entry at heap position k down until heap order holds.
            void sink(int k) {
                while (2 * k + 1 < N) {
                    int j = 2 * k + 1;
                    if (j + 1 < N && less(j, j + 1)) j++;
                    if (!less(k, j)) break;
                    exch(k, j);
                    k = j;
                }
            }

            /// Moves the entry at heap position k up until heap order holds.
            void swim(int k) {
                while (k > 0 && less((k - 1) / 2, k)) {
                    exch((k - 1) / 2, k);
                    k = (k - 1) / 2;
                }
            }

            /// True if position i has lower priority than position j.
            /// "Priority" is inverted for a min-heap.
            bool less(int i, int j) const {
                return isMin ? keys[pq[i]] > keys[pq[j]]
                             : keys[pq[i]] < keys[pq[j]];
            }

            /// Swaps heap positions i and j, keeping qp as the inverse of pq.
            void exch(int i, int j) {
                const int t = pq[i];
                pq[i] = pq[j];
                pq[j] = t;
                qp[pq[i]] = i;
                qp[pq[j]] = j;
            }
        };

    public:
        /// Returns the median of every length-k window of `nums`, left to right.
        ///
        /// @param nums input values. Taken by non-const reference only to match the
        ///             LeetCode signature; it is not modified.
        /// @param k    window length. If k exceeds nums.size(), it is clamped to
        ///             nums.size(), which yields the median of the whole array.
        ///             k <= 0 yields an empty result.
        /// @return nums.size() - k + 1 medians (empty if nums is empty). For even k,
        ///         each median is the mean of the two middle values, computed in
        ///         double so int keys cannot overflow.
        std::vector<double> medianSlidingWindow(std::vector<int>& nums, int k) {
            std::vector<double> solution;
            const int n = static_cast<int>(nums.size());
            if (n == 0 || k <= 0) {
                return solution;
            }
            // Clamp before deriving anything else from k; `mid` and the k == 1
            // shortcut below both depend on the clamped value.
            if (k > n) {
                k = n;
            }
            solution.reserve(n - k + 1);

            if (k == 1) {
                for (int v : nums) {
                    solution.push_back(static_cast<double>(v));
                }
                return solution;
            }

            // Seed: find the lower median of the first window (expected linear time).
            const int mid = (k - 1) / 2;
            std::vector<int> subvector(nums.begin(), nums.begin() + k);
            std::nth_element(subvector.begin(), subvector.begin() + mid, subvector.end());
            const int median = subvector[mid];

            IndexPQ minPQ(n, true);   // upper half of the window
            IndexPQ maxPQ(n, false);  // lower half; its top is the median for odd k

            for (int i = 0; i < k; i++) {
                if (nums[i] > median) {
                    minPQ.insert(i, nums[i]);
                } else {
                    maxPQ.insert(i, nums[i]);
                }
            }

            // At least mid + 1 == ceil(k/2) values are <= median, so maxPQ can only
            // be too large, for example when many values equal the median. Move its
            // largest entries up until the sizes are ceil(k/2) and floor(k/2).
            while (maxPQ.count() > minPQ.count() + 1) {
                const int index = maxPQ.deleteTop();
                minPQ.insert(index, nums[index]);
            }

            auto currentMedian = [&]() -> double {
                if (k % 2) {
                    return maxPQ.topKey();
                }
                return (static_cast<double>(minPQ.topKey()) + maxPQ.topKey()) / 2.0;
            };

            // Slide: j leaves the window and i enters it. Each step keeps both heap
            // sizes and the "lower <= upper" ordering invariant.
            for (int j = 0, i = k; i < n; i++, j++) {
                solution.push_back(currentMedian());
                const int topMax = maxPQ.topKey();

                if (minPQ.contains(j)) {
                    minPQ.deleteIndex(j);
                    if (nums[i] > topMax) {
                        minPQ.insert(i, nums[i]);
                    } else {
                        // The new value belongs in the lower half; promote the lower
                        // half's maximum to refill the upper half.
                        maxPQ.insert(i, nums[i]);
                        const int index = maxPQ.deleteTop();
                        minPQ.insert(index, nums[index]);
                    }
                } else {
                    maxPQ.deleteIndex(j);
                    if (nums[i] <= topMax) {
                        maxPQ.insert(i, nums[i]);
                    } else {
                        // The new value belongs in the upper half; demote the upper
                        // half's minimum to refill the lower half.
                        minPQ.insert(i, nums[i]);
                        const int index = minPQ.deleteTop();
                        maxPQ.insert(index, nums[index]);
                    }
                }
            }
            solution.push_back(currentMedian());

            return solution;
        }
    };
    

Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.