C++ segment-tree implementation based on another user's solution


  • 2
    X

    I implemented a C++ version based on this Java solution:

    https://leetcode.com/discuss/79073/java-segmenttree-solution-36ms

    // Count of Range Sum, segment-tree solution.
    // Let sum[k] = nums[0] + ... + nums[k] for k = 0..n-1, and sum[-1] = 0.
    // The range nums[j+1..i] has a sum in [lower, upper] exactly when
    //     lower <= sum[i] - sum[j] <= upper,  i > j,
    // which can be rewritten as
    //     lower + sum[j] <= sum[i] <= upper + sum[j],  i > j,
    // where i ranges from 0 to n-1 and j ranges from -1 to n-2.
    // To enforce i > j during each segment-tree query, the sum[i] values are inserted into
    // the tree one at a time, from i = n-1 down to 0.
    // One node of a segment tree built over a sorted array of values.
    //   min, max : the smallest and largest values covered by this node.
    //   count    : number of values inserted so far that fall in [min, max].
    //   left, right : child nodes, or nullptr for a leaf.
    // The node does not own or free its children; whoever builds the tree with
    // `new` is responsible for deleting every node.
    class SegmentTreeNode {
    public:
        long min;
        long max;
        int count; // number of inserted values in [min, max]
        SegmentTreeNode *left = nullptr;
        SegmentTreeNode *right = nullptr;

        SegmentTreeNode(long min, long max, int count)
            : min(min), max(max), count(count) {}
    };
    
    class Solution {
        SegmentTreeNode *root;
    public:
        int countRangeSum(vector<int>& nums, int lower, int upper) {
            if (nums.size() == 0) return 0;
            unordered_set<long> ssum;
            long tmpSum = 0;
            for (auto n: nums) {
                tmpSum += n;
                ssum.insert(tmpSum);
            }
            // copy the set to a vector
            vector<long> sum;
            for (auto s: ssum) {
                sum.push_back(s);
            }
            sort(sum.begin(), sum.end());
            // build a segment tree using the sorted sums
            root = build(sum, 0, sum.size() - 1);
            // next count the number of sums based on low + sum[j] <= sum[i] <= high + sum[j], i > j
            int count = 0;
            for (int i = nums.size() - 1; i >= 0; i--) {
                modify(root, tmpSum);
                tmpSum -= nums[i];
                count += query(root, lower + tmpSum, upper + tmpSum);
            }
            return count;
        }
        
        SegmentTreeNode *build(vector<long> &sum, int start, int end) {
            SegmentTreeNode *root = new SegmentTreeNode(sum[start], sum[end], 0);
            if (start == end) {
                return root;
            } else {
                int middle = (start + end) / 2;
                root->left = build(sum, start, middle);
                root->right = build(sum, middle + 1, end);
                return root;
            }
        }
        
        void modify(SegmentTreeNode *root, long value) {
            if (root == NULL || value < root->min || value > root->max) return;
            // now value is between root->min and root->max
            root->count += 1;
            if (root->min == root->max) {
                return;
            } else {
                modify(root->left, value);
                modify(root->right, value);
            }
        }
        
        int query(SegmentTreeNode *root, long min, long max) {
            if (root == NULL || max < root->min || min > root->max) return 0;
            if (min <= root->min && max >= root->max) return root->count;
            return query(root->left, min, max) + query(root->right, min, max);
        }
    };

Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.