C++ greedy solution using two Binary Indexed Trees (Fenwick trees), O(n log^3 n), written to show that a BIT can be as general as a segment tree


  • 0
    M
    /**
     * "Course Schedule III": take as many courses (duration t, lastDay e) as
     * possible. Courses are tried shortest-duration first (ties: later lastDay
     * first), and a course is taken iff it keeps the chosen set finishable.
     *
     * The chosen set is finishable iff, for every day d, the total duration of
     * chosen courses with lastDay <= d (call it prefix(d)) is at most d. Taking
     * (t, e) raises prefix(d) only for d >= e, so it is admissible iff
     *     min over d in [e, N] of (d - prefix(d)) >= t.
     *
     * Two Fenwick trees (both 1-indexed, slot 0 unused) maintain this:
     *  - delta  : sum-BIT over days; get_sum(d) == prefix(d).
     *  - minarr : min-BIT over the reversed index ml = N + 1 - day, so a BIT
     *             prefix over ml is a suffix of days. Node ml covers days
     *             [dl, dl + lowbit(ml) - 1] with dl = N + 1 - ml. It stores
     *             min(d - prefix(d)) over that range PLUS prefix(dl).
     *             Storing it relative to prefix(dl) keeps a node valid when a
     *             later update_sum covers its whole range.
     *
     * Testing a course costs O(log^2 N). Accepting one costs O(log^3 N),
     * where N is the largest lastDay.
     */
    class Solution {
    public:
        // Fenwick storage, resized to N + 1 slots by init().
        std::vector<int> minarr = std::vector<int>(10001, 0);
        std::vector<int> delta = std::vector<int>(10001, 0);
        int N = 10000;  // largest day index covered by both trees

        /// Resets both trees to "nothing taken" for days 1..n.
        void init(int n) {
            N = std::max(n, 0);
            // Size the trees from the input rather than a fixed 10001 slots.
            // Clear delta as well: it used to keep sums from a previous call,
            // or hold garbage when the object was not value-initialized.
            delta.assign(N + 1, 0);
            minarr.assign(N + 1, 0);
            // With prefix(d) == 0 everywhere, the minimum of d over node ml's
            // range is its first day, dl = N + 1 - ml.
            for (int ml = 1; ml <= N; ++ml) {
                minarr[ml] = N + 1 - ml;
            }
        }

        /// Adds x to the duration total at day l (point update of the sum-BIT).
        void update_sum(int l, int x) {
            if (l < 1) return;  // lowbit(0) == 0 would never advance
            for (; l <= N; l += lowbit(l)) {
                delta[l] += x;
            }
        }

        /// Returns prefix(l): the total duration of taken courses with
        /// lastDay <= l. Returns 0 for l < 1; days past N count as day N.
        int get_sum(int l) {
            if (l > N) l = N;
            int total = 0;
            for (; l > 0; l -= lowbit(l)) {
                total += delta[l];
            }
            return total;
        }

        /// Returns min over days d in [dl, N] of (d - prefix(d)), i.e. the
        /// largest duration a new course with lastDay dl may have. Returns
        /// INT_MAX when the range is empty (dl > N). Days below 1 are clamped
        /// to 1.
        int get_min(int dl) {
            int ret = INT_MAX;
            for (int ml = std::min(N + 1 - dl, N); ml > 0; ml -= lowbit(ml)) {
                ret = std::min(ret, minarr[ml] - get_sum(N + 1 - ml));
            }
            return ret;
        }

        /// Repairs the min-BIT after update_sum(dl, x) with x >= 0.
        /// Only the ancestors of node ml = N + 1 - dl have ranges that straddle
        /// day dl. Their offset prefix(first day) is unchanged because that day
        /// is < dl, and their true minimum can only have decreased. So each
        /// ancestor keeps min(old value, children's current values).
        void update_min(int dl) {
            if (dl < 1 || dl > N) return;
            const int ml = N + 1 - dl;
            for (int i = ml + lowbit(ml); i <= N; i += lowbit(i)) {
                const int base = get_sum(N + 1 - i);
                int best = minarr[i] - base;
                // Children of node i are i - 1, i - 2, i - 4, ..., i - lowbit(i)/2.
                // They sit lower in the tree, so they were already repaired.
                for (int j = 1; j < lowbit(i); j <<= 1) {
                    best = std::min(best, minarr[i - j] - get_sum(N + 1 - (i - j)));
                }
                minarr[i] = best + base;
            }
        }

        /// Returns the maximum number of courses that can be completed.
        /// Note: sorts `course` in place.
        int scheduleCourse(std::vector<std::vector<int>>& course) {
            std::sort(course.begin(), course.end(),
                      [](const std::vector<int>& a, const std::vector<int>& b) {
                          if (a[0] == b[0]) return a[1] > b[1];
                          return a[0] < b[0];
                      });
            int maxDay = 0;
            for (const auto& c : course) {
                maxDay = std::max(maxDay, c[1]);
            }
            init(maxDay);

            int taken = 0;
            for (const auto& c : course) {
                const int t = c[0];
                const int e = c[1];
                // A course due before day 1 has no tree slot. With positive
                // durations (as the problem guarantees) it can never be taken.
                if (e < 1) continue;
                if (get_min(e) >= t) {
                    update_sum(e, t);
                    update_min(e);
                    ++taken;
                }
            }
            return taken;
        }

    private:
        /// Lowest set bit of x (x > 0): the span of Fenwick node x.
        static int lowbit(int x) { return x & -x; }
    };
    

Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.