O(n log max(nums)) solution using Wavelet Matrix


  • 0
    H

    Do you know the "wavelet matrix", a data structure for finding the k-th number in an array?

    I'll introduce my code using the wavelet matrix, because very few of you have mentioned this nice data structure. A wavelet matrix can compute the k-th number in an arbitrary interval [l, r) in O(log max(nums)) time. Obviously, the k-th number and the median go together.

    The solution beats only 20% of C++ submissions. However, I believe this code is flexible enough to deal with other problems, because a wavelet matrix can support a wide variety of other queries. I hope my post will help your learning.

    Note that the running time of this solution does NOT depend on k.

    Thank you.

    #include <bits/stdc++.h>
    #include <sys/time.h>
    using namespace std;
    
    using ll = long long; using vll = vector<ll>; 
    
    /*****************/
    // Dictionary
    /*****************/
    /**
     * Static bit vector with constant-time rank ("fully indexable dictionary").
     *
     * Stores up to N bits and answers count(val, r): how many of the first r
     * bits are equal to val. Three tables are kept:
     *   - B[j]  : number of set bits in [0, j*bucket)
     *   - b[j]  : number of set bits between the start of the bucket that
     *             contains bit j*block and bit j*block itself
     *   - bs[j] : the raw bits of block j, bit i stored at position i%block
     * so a rank query is B + b + popcount of one masked block word.
     */
    template<int N> class FID {
        static const int bucket = 512, block = 16;
        // Every table is zero-initialised so that a default-constructed FID is a
        // valid empty dictionary and copying one never reads indeterminate values.
        int n = 0;
        int B[N/bucket+10] = {};
        unsigned short bs[N/block+10] = {};
        unsigned short b[N/block+10] = {};

    public:
        /** Creates an empty dictionary (no bits stored). */
        FID() {}

        /**
         * Builds the dictionary from the first n entries of s.
         *
         * @param n number of bits; must satisfy 0 <= n <= N
         * @param s bit values; s[0..n) is read
         * @throws std::length_error if n is outside [0, N]
         */
        FID(int n, bool s[]) : n(n) {
            if (n < 0 || n > N) throw std::length_error("FID: bit count outside [0, N]");

            int total = 0;     // set bits seen so far
            int inBucket = 0;  // set bits seen since the current bucket started
            // The loop runs up to i == n inclusive so that the table entries read
            // by count(val, n) are written as well.
            for (int i = 0; i <= n; i++) {
                if (i%bucket == 0) { B[i/bucket] = total; inBucket = 0; }
                if (i%block == 0) b[i/block] = static_cast<unsigned short>(inBucket);
                if (i == n) break;
                if (s[i]) {
                    bs[i/block] = static_cast<unsigned short>(bs[i/block] | (1u << (i%block)));
                    ++total;
                    ++inBucket;
                }
            }
        }

        /**
         * Rank query: number of bits equal to val among the first r bits.
         *
         * @param val bit value to count
         * @param r   prefix length; the tables are only filled for 0 <= r <= n
         * @return number of positions i in [0, r) whose bit equals val
         */
        int count(bool val, int r) const {
            // Keep only the bits of the current block that lie below position r.
            const unsigned below = bs[r/block] & ((1u << (r%block)) - 1u);
            const int ones = B[r/bucket] + b[r/block] + __builtin_popcount(below);
            return val ? ones : r - ones;
        }
    };
    
    /*****************/
    // Wavelet Matrix
    /*****************/
    /**
     * Wavelet matrix over a sequence of at most N values, each treated as a
     * D-bit number (bits D-1 .. 0 of the value are used, most significant first).
     *
     * Level d stores, as an FID bit vector, bit (D-d-1) of every element in the
     * order produced by the previous levels; the elements are then stably
     * rearranged so that those with a 0 bit come first. zs[d] is the number of
     * elements with a 0 bit at level d.
     */
    template<class T, int N, int D> class wavelet {
        int n, zs[D];
        FID<N> dat[D];
    public:
        /**
         * Builds the matrix from seq[0..n).
         *
         * @param n   number of elements; must satisfy 0 <= n <= N
         * @param seq input values; not modified
         * @throws std::length_error if n is outside [0, N]
         */
        wavelet(int n, T seq[]) : n(n) {
            if (n < 0 || n > N) throw std::length_error("wavelet: sequence length outside [0, N]");

            // Scratch buffers live on the heap and are sized by n, so a large N
            // neither consumes stack space nor costs O(N) work per level.
            std::vector<T> cur(seq, seq + n);
            std::vector<T> ones;
            ones.reserve(n);
            const std::unique_ptr<bool[]> bits(new bool[n]);

            for (int d = 0; d < D; d++) {
                const int shift = D-d-1;
                int lh = 0;  // number of elements with a 0 bit so far
                ones.clear();
                for (int i = 0; i < n; i++) {
                    const T v = cur[i];
                    const bool k = (v>>shift)&1;
                    bits[i] = k;
                    if(k) ones.push_back(v);
                    else cur[lh++] = v;  // safe in place: lh <= i and cur[i] was already read
                }
                dat[d] = FID<N>(n, bits.get());
                zs[d] = lh;
                // Zeros now occupy cur[0..lh); append the ones behind them.
                std::copy(ones.begin(), ones.end(), cur.begin() + lh);
            }
        }

        /**
         * Number of elements in positions [l, r) whose D low bits equal those of val.
         * Runs in O(D). The caller must pass 0 <= l <= r <= n.
         */
        int count(T val, int l, int r) {
            for (int d = 0; d < D; d++) {
                const bool bit = (val>>(D-d-1))&1;
                // Map both ends of the range to their positions on the next level.
                l = dat[d].count(bit,l)+bit*zs[d];
                r = dat[d].count(bit,r)+bit*zs[d];
            }
            return r-l;
        }

        /** Same as count(val, 0, r): occurrences of val among the first r elements. */
        int count(T val, int r) { return count(val,0,r); }

        /**
         * k-th LARGEST value among positions [l, r), with k 0-indexed (k == 0 is
         * the maximum): at every level the elements whose bit is 1 are ranked
         * before those whose bit is 0. Runs in O(D).
         *
         * @return the selected value, or -1 if the range is not inside [0, n]
         *         or k is not in [0, r-l)
         */
        T kth_number(int l, int r, int k) {
            if(l < 0 || r > n || r-l <= k || k < 0) return -1;
            T ret = 0;
            for (int d = 0; d < D; d++) {
                const int lc = dat[d].count(1,l), rc = dat[d].count(1,r);
                if(rc-lc > k) {
                    // Enough elements have this bit set: the answer is among them.
                    l = lc+zs[d], r = rc+zs[d];
                    ret |= 1ULL<<(D-d-1);
                } else {
                    // Skip every element with the bit set and continue among the zeros.
                    k -= rc-lc;
                    l -= lc, r -= rc;
                }
            }
            return ret;
        }
    };
    
    #define MAXN 10000
    #define OFFSET 10000000000ll
    #define MAXK 34
    class Solution {
        public:
            vector<double> medianSlidingWindow(vector<int>& nums_, int k) {
                ll n = nums_.size();
    
                vector<ll> nums(n); 
                for (int i = 0; i < n; i++) nums[i] = nums_[i];
                for (int i = 0; i < n; i++) nums[i] += OFFSET;
    
                vector<double> ret(n-k+1);
                wavelet<ll, MAXN, MAXK> w(n, nums.data());
                for (int i = 0; i < n - k + 1; i++)
                    ret[i] = (k % 2 ?
                            w.kth_number(i, i+k, k/2) :
                            (w.kth_number(i, i+k, k/2) + w.kth_number(i, i+k, k/2-1)) / 2.0) - OFFSET;
    
                return ret;
            }
    };

Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.