44ms Balanced BST C++ solution


  • 5
    // Node of a counting binary search tree keyed by a prefix-sum value.
    // A freshly constructed node has both counters at zero and no children.
    struct MyTreeNode {
        MyTreeNode(long long value)
            : val(value), count(0), less(0), left(nullptr), right(nullptr) {}

        long long val;      // key stored at this node
        int count;          // number of times `val` itself has been inserted
        int less;           // number of inserted values that descended into the left subtree
        MyTreeNode* left;   // subtree with keys smaller than `val`
        MyTreeNode* right;  // subtree with keys greater than `val`
    };
    
    class Solution {
    public:
        int countRangeSum(vector<int>& nums, int lower, int upper) {
            int result = 0;
            long long sum = 0;
    
            // remove duplicated
            unordered_set<long long> hash = {0};
            for (int n : nums) {
                sum += n;
                hash.insert(sum);
            }
            // sort
            vector<long long> orderedNums(hash.begin(), hash.end());
            sort(orderedNums.begin(), orderedNums.end());
    
            auto* tree = buildBalancedTree(orderedNums.begin(), orderedNums.end());
            
            // lower <= sum[i] - sum[x] <= upper      (i > x)
            // sum[i] - upper <= sum[x] <= sum[i] - lower;
            sum = 0;
            insert(tree, 0);
            for (int n : nums) {
                sum += n;
                int loCount = countLessThanValue(tree, sum - upper);
                int hiCount = countLessThanValue(tree, sum - lower + 1);
                result += hiCount - loCount;
                insert(tree, sum);
            }
            
            return result;
        }
    private:
        MyTreeNode* buildBalancedTree(vector<long long>::iterator begin, vector<long long>::iterator end) {
            if (begin == end) return nullptr;
            auto mid = begin + (end - begin) / 2;
            auto* node = new MyTreeNode(*mid);
            node->left = buildBalancedTree(begin, mid);
            node->right = buildBalancedTree(mid + 1, end);
            return node;
        }
        int countLessThanValue(MyTreeNode* pNode, long long value) {
            int count = 0;
            while (pNode != nullptr) {
                if (value < pNode->val) {
                    pNode = pNode->left;
                } else if (value > pNode->val) {
                    count += pNode->count + pNode->less;
                    pNode = pNode->right;
                } else {
                    count += pNode->less;
                    break;
                }
            }
            
            return count;
        }
        void insert(MyTreeNode* pNode, long long value) {
            while (value != pNode->val) {
                if (value < pNode->val) {
                    ++(pNode->less);
                    pNode = pNode->left;
                } else {
                    pNode = pNode->right;
                }
            }
            ++(pNode->count);
        }
    };
    

  • 0
    D

    Why did you insert "0" before building the tree?


Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.