Segment Tree, Binary Indexed Tree, and a simple buffer-based approach in C++ — all quite efficient


  • 8

    Segment Tree

    // One node of a pointer-based segment tree: covers the index range [start, end]
    // and caches the sum of the elements in that range. Nodes are allocated and
    // freed by the NumArray below, which owns the whole tree.
    struct SegmentTreeNode 
    {
        int start, end, sum;
        SegmentTreeNode* left;
        SegmentTreeNode* right;
        SegmentTreeNode(int a, int b):start(a),end(b),sum(0),left(nullptr),right(nullptr){}
    };
    
    // Range-sum queries with point updates, backed by a pointer-based segment tree.
    // Build is O(n); update and sumRange are O(log n). The tree is owned by this
    // object and released in the destructor, so instances are non-copyable.
    class NumArray 
    {
    private:
        SegmentTreeNode* root;

        // Split point shared by build, update and query so all three agree on
        // which child covers which half of a node's range.
        static int midOf(int start, int end)
        {
            return start + (end - start) / 2;
        }

        // Builds the subtree covering nums[start..end]; returns nullptr for an empty range.
        static SegmentTreeNode* buildTree(const std::vector<int>& nums, int start, int end) 
        {
            if (start > end) return nullptr;
            SegmentTreeNode* node = new SegmentTreeNode(start, end);
            if (start == end) 
            {
                node->sum = nums[start];
                return node;
            }
            int mid = midOf(start, end);
            node->left = buildTree(nums, start, mid);
            node->right = buildTree(nums, mid + 1, end);
            node->sum = node->left->sum + node->right->sum;
            return node;
        }

        // Frees every node of the subtree rooted at node (children before parent).
        static void destroyTree(SegmentTreeNode* node)
        {
            if (node == nullptr) return;
            destroyTree(node->left);
            destroyTree(node->right);
            delete node;
        }
    
        // Sets element i to val inside the subtree and returns how much that element
        // changed, so every ancestor on the way back up adjusts its sum by the same delta.
        // An index outside the subtree's range reaches a missing child and returns 0.
        static int modifyTree(int i, int val, SegmentTreeNode* node) 
        {
            if (node == nullptr) return 0;
            if (node->start == i && node->end == i) 
            {
                int diff = val - node->sum;
                node->sum = val;
                return diff;
            }
            int diff = (i > midOf(node->start, node->end))
                           ? modifyTree(i, val, node->right)
                           : modifyTree(i, val, node->left);
            node->sum += diff;
            return diff;
        }
    
        // Sum of elements in [i, j]; assumes [i, j] lies within the node's range.
        static int queryTree(int i, int j, const SegmentTreeNode* node) 
        {
            if (node == nullptr) return 0;
            if (node->start == i && node->end == j) return node->sum;
            int mid = midOf(node->start, node->end);
            if (i > mid) return queryTree(i, j, node->right);
            if (j <= mid) return queryTree(i, j, node->left);
            return queryTree(i, mid, node->left) + queryTree(mid + 1, j, node->right);
        }
    
    public:
        // Takes the input by const reference so temporaries are accepted too.
        // The int cast avoids size_t underflow when nums is empty (root stays nullptr).
        NumArray(const std::vector<int>& nums)
            : root(buildTree(nums, 0, static_cast<int>(nums.size()) - 1))
        {
        }

        ~NumArray()
        {
            destroyTree(root);
        }

        // The raw tree is uniquely owned; a copy would double-free it.
        NumArray(const NumArray&) = delete;
        NumArray& operator=(const NumArray&) = delete;
    
        // Replaces the element at index i with val.
        void update(int i, int val) 
        {
            modifyTree(i, val, root);
        }
    
        //AC - 56ms - Segment Tree;
        // Sum of the elements at indices i..j inclusive.
        int sumRange(int i, int j) const
        {
            return queryTree(i, j, root);
        }
    };
    

    Fenwick Tree or Binary Indexed Tree

    // Range-sum queries with point updates, backed by a Fenwick (binary indexed) tree.
    // Internally 1-based: BIT[k] holds the sum of the (k & -k) elements ending at
    // original index k-1. BIT[0] is unused. update and sumRange are O(log n).
    class NumArray {
    private:
        std::vector<int> BIT;
        int size;

        // Prefix sum of the first x elements (original indices 0..x-1).
        int sum(int x) const
        {
            int ret = 0;
            for (; x > 0; x -= (x & -x))
                ret += BIT[x];
            return ret;
        }
    public:
        // Builds the tree in O(n): each slot adds its finished partial sum to the
        // next slot that covers it. The const reference also accepts temporaries.
        NumArray(const std::vector<int>& nums)
            : BIT(nums.size() + 1, 0), size(static_cast<int>(nums.size()))
        {
            for (int k = 1; k <= size; ++k)
            {
                BIT[k] += nums[k - 1];
                int parent = k + (k & -k);
                if (parent <= size) BIT[parent] += BIT[k];
            }
        }
    
        // Replaces the element at index i with val; out-of-range indices are ignored.
        void update(int i, int val) 
        {
            if (i < 0 || i >= size) return;
            int k = i + 1;
            // sum(k) - sum(k-1) is the element's current value, so delta is new - old.
            int delta = val - (sum(k) - sum(k - 1));
            for (; k <= size; k += (k & -k))
                BIT[k] += delta;
        }
        
        //AC - 52ms - Fenwick Tree or Binary Indexed Tree;
        // Sum of the elements at indices i..j inclusive.
        int sumRange(int i, int j) const
        {
            return sum(j + 1) - sum(i);
        }
    };
    

    Regular prefix-sum method, using a buffer to accelerate updates

    // Range-sum queries with point updates using a prefix-sum array plus a buffer of
    // pending (index, delta) corrections. Prefix sums are rebuilt only when the buffer
    // grows past kFlushThreshold, so sumRange costs O(1 + pending updates).
    class NumArray {
    private:    
        // Rebuild the prefix sums once more than this many updates are pending.
        static constexpr std::size_t kFlushThreshold = 300;

        std::vector<long long> sums;                  // sums[k] = nums[0] + ... + nums[k-1] at last rebuild
        std::vector<int> nums;                        // current element values
        std::vector<std::pair<int, long long>> buffer; // (index, delta) recorded since last rebuild

        // Recomputes sums from nums and clears the buffer. Accumulates in long long:
        // std::partial_sum would accumulate in int (the input value type) and could overflow.
        void rebuild()
        {
            long long running = 0;
            for (std::size_t k = 0; k < nums.size(); ++k)
            {
                running += nums[k];
                sums[k + 1] = running;
            }
            buffer.clear();
        }
    public:
        // The const reference also accepts temporaries.
        NumArray(const std::vector<int>& values) : sums(values.size() + 1, 0), nums(values)
        {
            rebuild();
        }

        // Replaces the element at index i with val; out-of-range indices are ignored.
        void update(int i, int val) 
        {
            if (i < 0 || i >= static_cast<int>(nums.size())) return;
            long long delta = static_cast<long long>(val) - nums[i]; // may not fit in int
            nums[i] = val;
            if (delta == 0) return; // nothing to correct
            buffer.emplace_back(i, delta);
            if (buffer.size() > kFlushThreshold)
                rebuild();
        }
    
        //AC - 80ms - just using a buffer, dramatically reduce the time cost;
        // Sum of the elements at indices i..j inclusive.
        int sumRange(int i, int j) const
        {
            long long result = sums[j + 1] - sums[i];
            for (const auto& p : buffer)
                if (p.first >= i && p.first <= j) result += p.second;
            return static_cast<int>(result);
        }
    };

  • 0

    As someone new to Fenwick trees, I have a question about the update function. Does it just add "val" to the number at index i, or does it replace the original value at index i with val? What does "val -= sum(i) - sum(i-1);" mean in your update function? Thanks


  • 0

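    @fittaoee `update` here replaces the value (changing it from a to b), but the tree only needs the difference b - a. Since sum(i) - sum(i-1) is the value currently stored at that position, the line "val -= sum(i) - sum(i-1);" turns val into exactly that difference, which is then added to every affected tree node. Think it over and it should become clear.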


  • 0
    S

    Thanks @LHearen for your terrific answer.
    I updated the Segment Tree method a bit, based on this post on GeeksforGeeks.
    Basically, I used an array to store the segment tree (see Chapter 6, Heapsort, of CLRS), because the tree is nearly complete: every level is full except the lowest one. This saves us the hassle of managing pointers.

    // Range-sum queries with point updates using a segment tree stored in a flat
    // array (heap layout: node k has children 2k+1 and 2k+2), so no node pointers
    // or manual memory management are needed. update and sumRange are O(log n).
    class NumArray {
      public:
        // The const reference also accepts temporaries.
        NumArray(const std::vector<int>& nums) : N(static_cast<int>(nums.size())) {
          if (N) {
            int capacity = 1;
            while (capacity < N) {
              capacity <<= 1;
            }
            // Midpoint splits over <= capacity leaves give depth <= log2(capacity),
            // so heap indices stay below 2 * capacity - 1.
            segment_tree.assign(capacity * 2 - 1, 0);
            construct_segment_util(nums, 0, N - 1, 0);
          }
        }
    
        // Replaces the element at index i with val; out-of-range indices are ignored
        // (previously an empty array recursed forever and a bad index hit the wrong leaf).
        void update(int i, int val) {
          if (i < 0 || i >= N) return;
          update_util(0, 0, N - 1, i, val);
        }
    
        // Sum of the elements at indices i..j inclusive (0 for an empty array).
        int sumRange(int i, int j) const {
          if (!N) return 0;
          return sumRange_util(0, 0, N - 1, i, j);
        }
    
      private:
        int N;
        std::vector<int> segment_tree;

        static int get_mid(int lo, int hi) { return lo + (hi - lo) / 2; }
        static int left_child(int k) { return 2 * k + 1; }
        static int right_child(int k) { return 2 * k + 2; }

        // True when [a_lo, a_hi] and [b_lo, b_hi] overlap.
        static bool intersect(int a_lo, int a_hi, int b_lo, int b_hi) {
          return a_lo <= b_hi && b_lo <= a_hi;
        }

        // True when [a_lo, a_hi] lies entirely inside [b_lo, b_hi].
        static bool inclusive(int a_lo, int a_hi, int b_lo, int b_hi) {
          return a_lo >= b_lo && a_hi <= b_hi;
        }

        // Fills node iseg (covering nums[istart..iend]) and its subtree; returns its sum.
        int construct_segment_util(const std::vector<int>& nums, int istart, int iend, int iseg) {
          if (istart == iend) {
            segment_tree[iseg] = nums[istart];
          } else {
            int imid = get_mid(istart, iend);
            segment_tree[iseg] = construct_segment_util(nums, istart, imid, left_child(iseg)) +
                                 construct_segment_util(nums, imid + 1, iend, right_child(iseg));
          }
          return segment_tree[iseg];
        }

        // Sum of the part of [i, j] that falls inside node iseg's range [istart, iend].
        int sumRange_util(int iseg, int istart, int iend, int i, int j) const {
          if (!intersect(istart, iend, i, j)) return 0;
          if (inclusive(istart, iend, i, j)) return segment_tree[iseg];
          int imid = get_mid(istart, iend);
          return sumRange_util(left_child(iseg), istart, imid, i, j) +
                 sumRange_util(right_child(iseg), imid + 1, iend, i, j);
        }

        // Sets element i to val and returns the change, which each ancestor on the
        // path back to the root adds to its own sum. Assumes istart <= i <= iend.
        int update_util(int iseg, int istart, int iend, int i, int val) {
          int diff;
          if (istart == iend) {
            diff = val - segment_tree[iseg];
          } else {
            int imid = get_mid(istart, iend);
            diff = (i <= imid) ? update_util(left_child(iseg), istart, imid, i, val)
                               : update_util(right_child(iseg), imid + 1, iend, i, val);
          }
          segment_tree[iseg] += diff;
          return diff;
        }
    };
    

  • 0
    B

    Good buffer solution!


  • 0
    R

    The Segment Tree solution is not getting accepted.
    It gives this error: Line 122: invalid initialization of non-const reference of type 'std::__debug::vector<int>&' from an rvalue of type 'std::__debug::vector<int>'


Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.