17 ms O(n) Java prefix sum method


  • 109

    The idea is similar to Two Sum: use a HashMap to store (key: the prefix sum, value: how many times this prefix sum has been reached on the current path). Whenever we reach a node, we check whether prefix sum - target exists in the HashMap. If it does, we add the count stored for prefix sum - target to res.
    For instance: if one path is 1, 2, -1, -1, 2, the prefix sums are 1, 3, 2, 1, 3. If the target sum is 2, we get four paths: {2}, {1, 2, -1}, {2, -1, -1, 2} and {2}.
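To make the analogy concrete, here is a minimal sketch that treats the single path from this example as a plain array. The class and method names are hypothetical, and this is only the array version of the idea, not the tree solution itself; it should count the four runs listed above.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical helper: prefix-sum counting applied to one path stored as an array.
class PathAsArray {
    // Counts contiguous runs of nums whose values add up to target.
    static int countSubarrays(int[] nums, int target) {
        Map<Integer, Integer> preSum = new HashMap<>();
        preSum.put(0, 1); // empty prefix, so a run may start at index 0
        int currSum = 0;
        int res = 0;
        for (int x : nums) {
            currSum += x;
            // every earlier prefix equal to currSum - target starts a run that ends here
            res += preSum.getOrDefault(currSum - target, 0);
            preSum.put(currSum, preSum.getOrDefault(currSum, 0) + 1);
        }
        return res;
    }

    public static void main(String[] args) {
        System.out.println(countSubarrays(new int[]{1, 2, -1, -1, 2}, 2)); // prints 4
    }
}
```

In the tree version, the loop over the array becomes the walk down from the root, which is why the map must also be undone on the way back up.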

    I used a global variable count, but we can avoid it by returning the count from the bottom up. The time complexity is O(n), assuming O(1) average-time HashMap operations. This is my first post in Discuss, open to any improvement or criticism. :)

        public int pathSum(TreeNode root, int sum) {
            // key: prefix sum from the root, value: how many times it occurs on the current path
            HashMap<Integer, Integer> preSum = new HashMap<>();
            preSum.put(0, 1); // empty prefix, so paths that start at the root are counted
            helper(root, 0, sum, preSum);
            return count;
        }
        int count = 0;
        public void helper(TreeNode root, int currSum, int target, HashMap<Integer, Integer> preSum) {
            if (root == null) {
                return;
            }

            currSum += root.val;

            // every earlier prefix equal to currSum - target starts a path that ends here
            if (preSum.containsKey(currSum - target)) {
                count += preSum.get(currSum - target);
            }

            if (!preSum.containsKey(currSum)) {
                preSum.put(currSum, 1);
            } else {
                preSum.put(currSum, preSum.get(currSum) + 1);
            }

            helper(root.left, currSum, target, preSum);
            helper(root.right, currSum, target, preSum);
            // backtrack: this prefix sum leaves the path once both subtrees are done
            preSum.put(currSum, preSum.get(currSum) - 1);
        }
    

    Thanks for your advice, @StefanPochmann. Here is the modified version, more concise and shorter:

        public int pathSum(TreeNode root, int sum) {
            HashMap<Integer, Integer> preSum = new HashMap<>();
            preSum.put(0,1);
            return helper(root, 0, sum, preSum);
        }
        
        public int helper(TreeNode root, int currSum, int target, HashMap<Integer, Integer> preSum) {
            if (root == null) {
                return 0;
            }
            
            currSum += root.val;
            int res = preSum.getOrDefault(currSum - target, 0);
            preSum.put(currSum, preSum.getOrDefault(currSum, 0) + 1);
            
            res += helper(root.left, currSum, target, preSum) + helper(root.right, currSum, target, preSum);
            preSum.put(currSum, preSum.get(currSum) - 1);
            return res;
        }
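For anyone who wants to run the modified version locally, here is a self-contained sketch. The TreeNode class is an assumption that mirrors LeetCode's usual definition, PathSumDemo is a hypothetical driver, and the tree is Example 1 from the problem statement (root = [10,5,-3,3,2,null,11,3,-2,null,1], sum = 8), whose expected answer is 3 (5->3, 5->2->1 and -3->11).

```java
import java.util.HashMap;

// Assumption: mirrors LeetCode's usual TreeNode definition.
class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;

    TreeNode(int val) {
        this.val = val;
    }

    TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}

class Solution {
    public int pathSum(TreeNode root, int sum) {
        HashMap<Integer, Integer> preSum = new HashMap<>();
        preSum.put(0, 1);
        return helper(root, 0, sum, preSum);
    }

    public int helper(TreeNode root, int currSum, int target, HashMap<Integer, Integer> preSum) {
        if (root == null) {
            return 0;
        }
        currSum += root.val;
        int res = preSum.getOrDefault(currSum - target, 0);
        preSum.put(currSum, preSum.getOrDefault(currSum, 0) + 1);
        res += helper(root.left, currSum, target, preSum) + helper(root.right, currSum, target, preSum);
        preSum.put(currSum, preSum.get(currSum) - 1); // backtrack
        return res;
    }
}

// Hypothetical driver building Example 1: [10,5,-3,3,2,null,11,3,-2,null,1]
class PathSumDemo {
    static TreeNode exampleTree() {
        return new TreeNode(10,
                new TreeNode(5,
                        new TreeNode(3, new TreeNode(3), new TreeNode(-2)),
                        new TreeNode(2, null, new TreeNode(1))),
                new TreeNode(-3, null, new TreeNode(11)));
    }

    public static void main(String[] args) {
        System.out.println(new Solution().pathSum(exampleTree(), 8)); // prints 3
    }
}
```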
    

  • 3

    Why do you subtract one in the last line of code?
    preSum.put(sum, preSum.get(sum) - 1);


  • 8

    The helper could be shorter by using getOrDefault:

        count += preSum.getOrDefault(sum - target, 0);
        preSum.put(sum, preSum.getOrDefault(sum, 0) + 1);
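Going one step further, Java 8's Map.merge can do the increment (and the backtracking decrement) in a single call each. This is an alternative to the getOrDefault/put pair, not something from the original post; the class and method names below are hypothetical.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical demo of Map.merge as an alternative to getOrDefault + put.
class MergeDemo {
    static Map<Integer, Integer> demo() {
        Map<Integer, Integer> preSum = new HashMap<>();
        preSum.put(0, 1);
        // each call has the same effect as preSum.put(5, preSum.getOrDefault(5, 0) + 1)
        preSum.merge(5, 1, Integer::sum);
        preSum.merge(5, 1, Integer::sum);
        // backtracking decrement; the key stays because Integer::sum never returns null
        preSum.merge(5, -1, Integer::sum);
        return preSum;
    }

    public static void main(String[] args) {
        System.out.println(demo().get(5)); // prints 1
    }
}
```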

  • 1

    @DreamSeason

    I think that for a node, once you have processed all of its child nodes, the number of ways to reach the current prefix sum should no longer affect the rest of the nodes, in case they have the same prefix sum.
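A quick way to see why the decrement matters: the hypothetical sketch below runs the same recursion with and without the final decrement on a three-node tree (root 1, left child 1, right child 2, target 1). The correct answer is 2 ([root] and [left]). Without the decrement, the left child's prefix sum 2 is still in the map when the right child (prefix sum 3) is visited, so a nonexistent "path" from the left sibling is counted. TreeNode is assumed to match LeetCode's definition.

```java
import java.util.HashMap;
import java.util.Map;

// Assumption: mirrors LeetCode's usual TreeNode definition.
class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;

    TreeNode(int val) {
        this.val = val;
    }

    TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}

// Hypothetical comparison: same recursion, with the backtracking step switchable.
class BacktrackDemo {
    static int count(TreeNode node, int currSum, int target,
                     Map<Integer, Integer> preSum, boolean backtrack) {
        if (node == null) {
            return 0;
        }
        currSum += node.val;
        int res = preSum.getOrDefault(currSum - target, 0);
        preSum.put(currSum, preSum.getOrDefault(currSum, 0) + 1);
        res += count(node.left, currSum, target, preSum, backtrack)
                + count(node.right, currSum, target, preSum, backtrack);
        if (backtrack) {
            preSum.put(currSum, preSum.get(currSum) - 1); // remove this node's prefix sum
        }
        return res;
    }

    static int pathSum(TreeNode root, int target, boolean backtrack) {
        Map<Integer, Integer> preSum = new HashMap<>();
        preSum.put(0, 1);
        return count(root, 0, target, preSum, backtrack);
    }

    static TreeNode smallTree() {
        return new TreeNode(1, new TreeNode(1), new TreeNode(2));
    }

    public static void main(String[] args) {
        System.out.println(pathSum(smallTree(), 1, true));  // prints 2
        System.out.println(pathSum(smallTree(), 1, false)); // prints 3 (wrong)
    }
}
```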


  • 0

    Clever way to avoid copying the hash map.


  • 1

    Wow, one of the most ingenious uses of an existing solution. Cheers!!


  • 3

    I feel like the idea behind the code is:

    When I come to a node, I want to find all paths ending at the current node whose sum equals target.

    Is this your idea?
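That reading can be checked against a deliberately simple version which, at each node, literally walks back over the current root-to-node path and counts the suffixes summing to target. This is a swapped-in brute force (O(n·h) time, where h is the tree height), not the hash map solution from the post, and is mainly useful as a cross-check. The class names are hypothetical and TreeNode is assumed to match LeetCode's definition.

```java
import java.util.ArrayList;
import java.util.List;

// Assumption: mirrors LeetCode's usual TreeNode definition.
class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;

    TreeNode(int val) {
        this.val = val;
    }

    TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}

// Hypothetical brute force: count paths ending at each node by scanning the path backwards.
class EndingAtNodeBruteForce {
    static int pathSum(TreeNode root, int target) {
        return dfs(root, target, new ArrayList<>());
    }

    private static int dfs(TreeNode node, int target, List<Integer> path) {
        if (node == null) {
            return 0;
        }
        path.add(node.val);
        int res = 0;
        int suffix = 0;
        // every suffix of the root-to-node path is a downward path ending at this node
        for (int i = path.size() - 1; i >= 0; i--) {
            suffix += path.get(i);
            if (suffix == target) {
                res++;
            }
        }
        res += dfs(node.left, target, path) + dfs(node.right, target, path);
        path.remove(path.size() - 1); // backtrack: drop this node from the path
        return res;
    }

    // Example 1 from the problem statement: [10,5,-3,3,2,null,11,3,-2,null,1]
    static TreeNode exampleTree() {
        return new TreeNode(10,
                new TreeNode(5,
                        new TreeNode(3, new TreeNode(3), new TreeNode(-2)),
                        new TreeNode(2, null, new TreeNode(1))),
                new TreeNode(-3, null, new TreeNode(11)));
    }

    public static void main(String[] args) {
        System.out.println(pathSum(exampleTree(), 8)); // prints 3
    }
}
```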


  • 1

    @van92 Yes, you are right!


  • 5

    "we will have {2}, {1, 2, -1, -1, 2}, { 2, -1, -1, 2} ways"

    Can you elaborate on this? Wouldn't it be {2}, {1,2,-1}, and {2,-1,-1,2}?


  • 3

    Could you explain what exactly your hashmap saves? What does the key mean in the hashmap? Why can you store the number of paths as an Integer keyed by the prefix sum? I think you must associate the number of paths with a certain node, instead of with the prefix sum value. If two nodes on two different branches have the same prefix value, the number of paths to them may not be the same, and if you call preSum.get(prefix), then res will be wrong.

    And why do you need to first put (sum, 1) into preSum, and at the end decrement preSum(sum) by 1?


  • 0

    Could you tell us what the Two Sum problem we are talking about here is, and what its solution is?
    I am curious how to solve it using prefix sums.

    Thanks
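For reference, Two Sum (LeetCode problem 1: return the indices of two numbers in an array that add up to a target) uses a hash map in the same way: for each element, look up the complement that would complete the target among the elements already seen. A minimal sketch; the names are my own, not from the post.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

// Hypothetical sketch of the classic one-pass hash map Two Sum.
class TwoSum {
    // Returns indices {i, j} with i < j and nums[i] + nums[j] == target, or null if none exist.
    static int[] twoSum(int[] nums, int target) {
        Map<Integer, Integer> seen = new HashMap<>(); // value -> index
        for (int j = 0; j < nums.length; j++) {
            Integer i = seen.get(target - nums[j]); // complement seen earlier?
            if (i != null) {
                return new int[]{i, j};
            }
            seen.put(nums[j], j);
        }
        return null;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(twoSum(new int[]{2, 7, 11, 15}, 9))); // prints [0, 1]
    }
}
```

In the path-sum solution, the "complement" is currSum - target, and it is looked up among the prefix sums seen so far on the current path; counting occurrences instead of storing an index is what turns "find one pair" into "count all paths".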


  • 37

    This solution confused me at first, and it seems others are having the same problem.

    The idea is based on the path.
    Suppose the hash table preSum stores the prefix sums of the whole path from the root. After adding the current node's val to pathsum, if (pathsum - target) is in preSum, then we know that some node on the path has prefix sum (pathsum - target), hence there is a path that sums to target: the one that starts just below that node and ends at the current node.

    Now, how do we maintain this preSum table? Since each path's prefix sums differ from the others', we have to update it. However, notice that we can reuse most of the preSum table. When we are done with the current node, we just need to remove the current pathsum from preSum and leave all the other prefix sums in it. Then, in higher layers, we can forget everything about this node (and its descendants).
    That's why we have

    preSum.put(sum, preSum.get(sum) - 1);
    // this removes one occurrence of the current pathsum and leaves all previous sums
    

    After running the algorithm, the preSum table should contain keys for all possible path sums starting from the root, but all their values are 0, except for key 0. For instance, for the example in the problem statement we should have:

    {0: 1, 7: 0, 10: 0, 15: 0, 16: 0, 17: 0, 18: 0, 21: 0}
    

    Hope it helps.
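This can be checked by passing a TreeMap (so the keys print in sorted order) into the same recursion and inspecting it after the run. The sketch assumes the example tree from the problem statement, [10,5,-3,3,2,null,11,3,-2,null,1] with target 8, and a LeetCode-style TreeNode; the class and method names are hypothetical.

```java
import java.util.Map;
import java.util.TreeMap;

// Assumption: mirrors LeetCode's usual TreeNode definition.
class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;

    TreeNode(int val) {
        this.val = val;
    }

    TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}

// Hypothetical check of the table left behind after the recursion finishes.
class FinalMapDemo {
    static int helper(TreeNode node, int currSum, int target, Map<Integer, Integer> preSum) {
        if (node == null) {
            return 0;
        }
        currSum += node.val;
        int res = preSum.getOrDefault(currSum - target, 0);
        preSum.put(currSum, preSum.getOrDefault(currSum, 0) + 1);
        res += helper(node.left, currSum, target, preSum) + helper(node.right, currSum, target, preSum);
        preSum.put(currSum, preSum.get(currSum) - 1); // backtrack
        return res;
    }

    static TreeNode exampleTree() {
        return new TreeNode(10,
                new TreeNode(5,
                        new TreeNode(3, new TreeNode(3), new TreeNode(-2)),
                        new TreeNode(2, null, new TreeNode(1))),
                new TreeNode(-3, null, new TreeNode(11)));
    }

    // Runs the algorithm on the example and returns the map it leaves behind.
    static Map<Integer, Integer> finalMap() {
        Map<Integer, Integer> preSum = new TreeMap<>(); // TreeMap only for sorted printing
        preSum.put(0, 1);
        helper(exampleTree(), 0, 8, preSum);
        return preSum;
    }

    public static void main(String[] args) {
        System.out.println(finalMap()); // prints {0=1, 7=0, 10=0, 15=0, 16=0, 17=0, 18=0, 21=0}
    }
}
```

Key 18 is reached three times (10->5->3, 10->5->2->1 and 10->-3->11), yet it still ends at 0 because every increment is matched by a decrement on the way back up.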



  • 0

    @kvwang Thank you so much for your explanation! I had read this solution many times and there were still some spots I couldn't figure out. Now it's clear.


  • 2

    I implemented the same idea in C++:

    class Solution {
    public:
        int pathSum(TreeNode* root, int sum) {
            int res = 0;
            unordered_map<int, int> prev_sums;
            prev_sums[0] = 1;
            dfs(root, prev_sums, 0, res, sum);
            return res;
        }
        
        void dfs(TreeNode* root, unordered_map<int, int> &prev_sums, int cur_sum, int &res, int sum){
            if(root == NULL) return;
            cur_sum = cur_sum + root->val;
            if(prev_sums.find(cur_sum - sum) != prev_sums.end()){
                res += prev_sums[cur_sum - sum];
            }
            prev_sums[cur_sum]++;
            dfs(root->left, prev_sums, cur_sum, res, sum);
            dfs(root->right, prev_sums, cur_sum, res, sum);
            prev_sums[cur_sum]--;
        }
    };
    

  • 79

    This is an excellent idea, and it took me some time to figure out the logic behind it.
    I hope my comment here helps with understanding this solution.

    1. The prefix stores the sum from the root to the current node in the recursion.
    2. The map stores <prefix sum, frequency> pairs for the nodes before the current node. Imagine a path from the root to the current node. The sum of the nodes below some middle node of the path, down to and including the current node, equals the difference between the sum from the root to the current node and the prefix sum of that middle node.
    3. We are looking for consecutive nodes that sum up to the given target value, which means the difference discussed in 2. should equal the target value. In addition, we need to know how many such differences equal the target value.
    4. Here comes the map. The map stores the frequency of all possible sums on the path to the current node. If the difference between the current sum and the target value exists in the map, there must be a node in the middle of the path such that the nodes below it, down to the current node, sum to the target value.
    5. Note that multiple nodes in the middle might satisfy what is discussed in 4. The frequency in the map handles this.
    6. Therefore, in each recursion, the map stores all the information we need to calculate the number of ranges that sum up to target. Note that each range starts just below a middle node and ends at the current node.
    7. To get the total path count, we add up the number of valid paths ending at EACH node in the tree.
    8. Each recursion returns the total count of valid paths in the subtree rooted at the current node. This sum can be divided into three parts:
      - the total number of valid paths in the subtree rooted at the current node's left child
      - the total number of valid paths in the subtree rooted at the current node's right child
      - the number of valid paths ending at the current node
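Points 2 to 5 can be written compactly. Let a_1, ..., a_k be the root-to-current path (a_1 the root, a_k the current node), let P_i be the prefix sum of its first i values, and let P_0 = 0 be the entry seeded by map.put(0, 1). At the moment of the lookup for a_k, the map holds the frequencies of P_0, ..., P_{k-1}:

```latex
\begin{aligned}
P_0 &= 0, \qquad P_i = P_{i-1} + \mathrm{val}(a_i) \quad (1 \le i \le k) \\
\mathrm{val}(a_{i+1}) + \cdots + \mathrm{val}(a_k) &= P_k - P_i \quad (0 \le i < k) \\
\mathrm{count}(a_k) &= \bigl|\{\, i : 0 \le i < k,\ P_i = P_k - \mathrm{target} \,\}\bigr| \\
\mathrm{total} &= \sum_{v \in \mathrm{tree}} \mathrm{count}(v)
\end{aligned}
```

The lookup map.getOrDefault(sum - target, 0) computes count(a_k), and point 8 simply regroups the total sum over the current node and its two subtrees.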

    The interesting part of this solution is that the prefix is counted from the top(root) to the bottom(leaves), and the result of total count is calculated from the bottom to the top :D

    The code below takes 16 ms, which is super fast.

        public int pathSum(TreeNode root, int sum) {
            if (root == null) {
                return 0;
            }
            Map<Integer, Integer> map = new HashMap<>();
            map.put(0, 1);
            return findPathSum(root, 0, sum, map);
        }

        private int findPathSum(TreeNode curr, int sum, int target, Map<Integer, Integer> map) {
            if (curr == null) {
                return 0;
            }
            // update the prefix sum by adding the current val
            sum += curr.val;
            // get the number of valid paths ending at the current node
            int numPathToCurr = map.getOrDefault(sum - target, 0);
            // update the map with the current sum, so the map is ready to be passed to the next recursion
            map.put(sum, map.getOrDefault(sum, 0) + 1);
            // add the 3 parts discussed in 8. together
            int res = numPathToCurr + findPathSum(curr.left, sum, target, map)
                                    + findPathSum(curr.right, sum, target, map);
            // restore the map, as the recursion goes from the bottom to the top
            map.put(sum, map.get(sum) - 1);
            return res;
        }
    

  • 0

    @DreamSeason As far as I understand, the key in the HashMap is a prefix sum along the current path, and the value is how many times that sum appears. Take this example: "For instance: in one path we have 1,2,-1,-1,2, then the prefix sum will be: 1, 3, 2, 1, 3". For this path, the map would contain map(1,2), map(2,1), map(3,2). When you finish searching the last node in this path, i.e. 2 (both its left and right subtrees are done), you should release the current path BUT NOT change the other (i.e. previous) paths. So the change is from map(3,2) to map(3,1), where the remaining 3 represents the path 1->2.

    Correct me if I have misunderstood. Thanks.
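That description matches a quick simulation. The hypothetical sketch below walks the single path 1, 2, -1, -1, 2 as if the recursion were going straight down, records the map at the deepest node, and then releases that node. It uses a TreeMap only so the keys print in order, and it includes the 0 -> 1 entry that the original code seeds.

```java
import java.util.Map;
import java.util.TreeMap;

// Hypothetical simulation of the map along one root-to-leaf path.
class SinglePathDemo {
    // Returns {map at the deepest node, map after releasing that node}.
    static String[] snapshots(int[] path) {
        Map<Integer, Integer> preSum = new TreeMap<>();
        preSum.put(0, 1);
        int currSum = 0;
        for (int x : path) {
            currSum += x;
            preSum.put(currSum, preSum.getOrDefault(currSum, 0) + 1);
        }
        String atDeepest = preSum.toString();
        // release the last node once both of its subtrees are done
        preSum.put(currSum, preSum.get(currSum) - 1);
        return new String[]{atDeepest, preSum.toString()};
    }

    public static void main(String[] args) {
        String[] s = snapshots(new int[]{1, 2, -1, -1, 2});
        System.out.println(s[0]); // prints {0=1, 1=2, 2=1, 3=2}
        System.out.println(s[1]); // prints {0=1, 1=2, 2=1, 3=1}
    }
}
```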


  • 0

    Nice solution!


  • 0

    Same idea, C++ version:

    class Solution {
        unordered_map<int,int> m;
    public:
        int pathSum(TreeNode* root, int sum) {
             m.emplace(0,1);
             return backtrack(root, 0, sum);
        }
        int backtrack(TreeNode* root, int sum, int target){
            if(!root) return 0;
            sum += root->val;
            int res = (m.find(sum-target) != m.end())? m.at(sum-target) : 0;
            if(m.find(sum) != m.end()){
                m.at(sum)++;
            }else{
                m.emplace(sum,1);
            }
            res += backtrack(root->left, sum, target) + backtrack(root->right, sum, target);
            m.at(sum)--;
            return res;
        }
    

    };


  • 1

    @marcusgao94 For anybody who's having trouble understanding this problem, look at this post on Stack Overflow: http://stackoverflow.com/questions/14948258/given-an-input-array-find-all-subarrays-with-given-sum-k

    The stackoverflow link is about finding all subarrays of an array that add up to a target value.

    Now, think of every distinct path in the tree as an array, and it'll make more sense!

