A C++ binary search tree solution


  • 0
    C

    I guess most of you have solved problem 303, Range Sum Query - Immutable. The idea there is to first build a prefix sum array (that is the usual name for it), where prefix_sum[ k ] is the sum of the first k elements. Then we have

    sum[ i ~ j ] = prefix_sum[ j ] - prefix_sum[ i ],

    which is the sum of nums[ i ] through nums[ j - 1 ].
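
    To make the prefix sum idea concrete, here is a minimal sketch. The helper names buildPrefixSums and rangeSum are hypothetical, and it assumes the convention that prefix[ k ] holds the sum of the first k elements, so prefix[ 0 ] is 0.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper: prefix[k] holds the sum of the first k elements,
// so prefix[0] == 0 and prefix.size() == nums.size() + 1.
std::vector<long long> buildPrefixSums(const std::vector<int>& nums)
{
    std::vector<long long> prefix(nums.size() + 1, 0);
    for (std::size_t k = 0; k < nums.size(); ++k)
        prefix[k + 1] = prefix[k] + nums[k];
    return prefix;
}

// Sum of nums[i] through nums[j - 1], computed as prefix[j] - prefix[i].
long long rangeSum(const std::vector<long long>& prefix, std::size_t i, std::size_t j)
{
    return prefix[j] - prefix[i];
}
```

    Under this convention every range sum is a difference of two prefix sums, which is what the rest of the post relies on.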

    Back to this problem: if we keep the previous prefix sums sorted, then for each new prefix sum we can count how many of the range sums ending there are less than 'lower', say num1, and how many are less than 'upper + 1', say num2. Then 'num2 - num1' is the number of those range sums that lie within [lower, upper].

    That is to say, for the current prefix sum s_i, we need to find how many of the previous prefix sums satisfy the inequality

    s_i - x < lower,
    

    and

    s_i - x < upper + 1,
    

    where x is one of the previous prefix sums. Rearranging gives

    x > s_i - lower,
    

    and

    x > s_i - upper - 1.
    

    That is where 'greaterThan' comes from: it counts the previous prefix sums that are strictly greater than a given value. Please always remember to destroy the tree you build.
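
    Before the tree itself, the two inequalities can be tried out with something simpler. The sketch below is not the tree solution: it swaps in a sorted std::vector and std::upper_bound, so insertion is O(n), but the counting is the same. The names countGreater and countRangeSumSorted are hypothetical.

```cpp
#include <algorithm>
#include <vector>

// Number of elements in the sorted vector that are strictly greater than v.
int countGreater(const std::vector<long long>& sorted, long long v)
{
    return static_cast<int>(sorted.end() - std::upper_bound(sorted.begin(), sorted.end(), v));
}

// Hypothetical stand-in for the tree: the previous prefix sums are kept
// in a sorted vector instead of a binary search tree.
int countRangeSumSorted(const std::vector<int>& nums, int lower, int upper)
{
    std::vector<long long> seen{0};  // the empty prefix
    long long sum = 0;
    int res = 0;
    for (int v : nums)
    {
        sum += v;
        int belowLower = countGreater(seen, sum - lower);       // x with sum - x < lower
        int atMostUpper = countGreater(seen, sum - upper - 1);  // x with sum - x < upper + 1
        res += atMostUpper - belowLower;
        // keep the vector sorted after adding the new prefix sum
        seen.insert(std::upper_bound(seen.begin(), seen.end(), sum), sum);
    }
    return res;
}
```

    The tree below plays the role of the sorted vector here, with the aim of making both the insertion and the count cheaper than a linear scan.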

    #include <vector>
    using namespace std;

    class Tree
    {
    private:
        struct Node 
        {
            long long val;
            int cnt; // number of values stored in this subtree, duplicates included
            Node* left, *right;
            Node(long long v): val(v), cnt(1), left(nullptr), right(nullptr){}
        };
    
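        void insert(Node *&root, long long val)
        {
            if (!root) 
            {
                root = new Node(val);
                return;
            }
            // val lands somewhere in this subtree, so its count grows by one
            root->cnt++;
            // duplicate: only the counts change, no new node is created
            if (root->val == val) return;
            else if (root->val < val) insert(root->right, val);
            else insert(root->left, val);
        }
    
        // counts the stored values strictly greater than val; res accumulates the answer
        int greaterThan(Node *root, long long val, int res)
        {
            if (!root) return res;
            if (root->val > val)
            {
                // this node's copies plus its whole right subtree are greater than val
                int tmp = root->cnt - (root->left ? root->left->cnt : 0);
                return greaterThan(root->left, val, res + tmp);
            }
            else if (root->val == val) return res + (root->right ? root->right->cnt : 0);
            else return greaterThan(root->right, val, res);
        }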
    
        void destroyTree(Node *root)
        {
            if (!root) return;
            destroyTree(root->left);
            destroyTree(root->right);
            delete root;
        }
    public:
        Tree(): root(nullptr) {}
        ~Tree() {destroyTree(root);}
        void insert(long long val) {insert(root, val);}    
        int greaterThan(long long val) {return greaterThan(root, val, 0);}
    private:
        Node* root;
    };
    
    class Solution 
    {
    public:
        int countRangeSum(vector<int>& nums, int lower, int upper) 
        {
            Tree tree;
            tree.insert(0);
            long long sum = 0;
            int res = 0;
            for(int &i : nums)
            {
                sum += i;
                int lo = tree.greaterThan(sum - lower);
                int hi = tree.greaterThan(sum - upper - 1);
                res += hi - lo;
                tree.insert(sum);
            }
            return res;
        }
    };
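
    As a cross-check for small inputs, a straightforward O(n^2) count over all ranges can be compared against the result of the solution above. This is only a reference sketch, and the name countRangeSumBrute is hypothetical.

```cpp
#include <cstddef>
#include <vector>

// Reference count: tries every range [i, j] and checks its sum directly.
int countRangeSumBrute(const std::vector<int>& nums, int lower, int upper)
{
    int res = 0;
    for (std::size_t i = 0; i < nums.size(); ++i)
    {
        long long sum = 0;  // running sum of nums[i..j]
        for (std::size_t j = i; j < nums.size(); ++j)
        {
            sum += nums[j];
            if (sum >= lower && sum <= upper) ++res;
        }
    }
    return res;
}
```

    For nums = [-2, 5, -1] with lower = -2 and upper = 2, the qualifying ranges are [0, 0], [2, 2] and [0, 2], so the count is 3.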

  • 0
    X

    Thank you for sharing! I just have a question: what does root->cnt mean exactly? From the implementation, it represents the number of times a node is visited during BST construction. However, I don't see how that relates to the counting in greaterThan(). Again, what does it mean?


  • 0
    C

    Sorry for the ambiguity. It simply means the number of values stored in the subtree rooted at a node, say node 'A', including A itself. When a duplicate arrives, I only increment the counts along the path instead of inserting a new node, so duplicates are still counted.

    So when the val being queried is equal to the val of the node, we have 'node->right->cnt' values (everything stored in the right subtree) that are greater than this val.
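
    A small trace may help here. The sketch below is a compact, hypothetical variant of the same idea (iterative, with std::unique_ptr so nothing has to be freed by hand), not the code from the post. The name CountingBst is an assumption for illustration.

```cpp
#include <memory>

// Hypothetical compact variant, for illustration only.
struct CountingBst
{
    struct Node
    {
        long long val;
        int cnt = 1;  // values stored in this subtree, duplicates included
        std::unique_ptr<Node> left, right;
        explicit Node(long long v) : val(v) {}
    };

    std::unique_ptr<Node> root;

    void insert(long long val)
    {
        std::unique_ptr<Node>* cur = &root;
        while (*cur)
        {
            (*cur)->cnt++;                   // val ends up somewhere in this subtree
            if ((*cur)->val == val) return;  // duplicate: count only, no new node
            cur = ((*cur)->val < val) ? &(*cur)->right : &(*cur)->left;
        }
        *cur = std::make_unique<Node>(val);
    }

    // Number of stored values strictly greater than val.
    int greaterThan(long long val) const
    {
        int res = 0;
        const Node* cur = root.get();
        while (cur)
        {
            int rightCnt = cur->right ? cur->right->cnt : 0;
            if (cur->val > val)
            {
                int leftCnt = cur->left ? cur->left->cnt : 0;
                res += cur->cnt - leftCnt;   // this node's copies plus its right subtree
                cur = cur->left.get();
            }
            else if (cur->val == val) return res + rightCnt;
            else cur = cur->right.get();
        }
        return res;
    }
};
```

    After inserting 5, 3, 8, 5, 9 the root holds 5 with cnt = 5, and the node holding 8 has cnt = 2. A query greaterThan(5) stops at the root and returns the right subtree's count, 2 (for 8 and 9). A query greaterThan(4) adds 5 - 1 = 4 at the root (both copies of 5, plus 8 and 9) and finds nothing more on the left.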

