8ms Simple Memoized Java Solution


  • 2
    /**
     * Returns the largest total of node values that can be taken from the tree
     * without taking both a node and its direct child.
     *
     * @param root root of the tree; may be null
     * @return the best achievable total (0 for an empty tree)
     */
    public int rob(TreeNode root) {
        // A fresh memo per call: node -> best total for that node's subtree.
        return robBoth(root, new HashMap<TreeNode, Integer>());
    }
    

    // For each node, take the better of two options: robbing the node or skipping it.

    /**
     * Returns the best total obtainable from the subtree rooted at {@code root},
     * choosing whichever is larger: robbing {@code root} or skipping it.
     * Each node's answer is cached in {@code map} so it is computed only once.
     *
     * @param root subtree root; may be null
     * @param map memo of already-computed subtree answers
     * @return the best total for this subtree (0 for null)
     */
    public int robBoth(TreeNode root, HashMap<TreeNode, Integer> map) {
        if (root == null) {
            return 0;
        }
        if (!map.containsKey(root)) {
            // Include-first, then exclude: same evaluation order as Math.max(a, b) args.
            map.put(root, Math.max(robInclude(root, map), robExclude(root, map)));
        }
        return map.get(root);
    }
    

    // Rob the given node; its children must then be skipped.

    /**
     * Returns the best total for the subtree at {@code root} when {@code root}
     * itself is robbed. The node's value is added to each child's best total
     * with that child skipped.
     *
     * @param root subtree root; may be null
     * @param map memo shared with {@link #robBoth}
     * @return the total with {@code root} robbed (0 for null)
     */
    public int robInclude(TreeNode root, HashMap<TreeNode, Integer> map) {
        if (root != null) {
            int fromLeft = robExclude(root.left, map);
            int fromRight = robExclude(root.right, map);
            return fromLeft + fromRight + root.val;
        }
        return 0;
    }
    

    // Skip the given node; each child may then be either robbed or skipped.

    /**
     * Returns the best total for the subtree at {@code root} when {@code root}
     * is skipped: each child subtree independently contributes its own best.
     *
     * @param root subtree root; may be null
     * @param map memo shared with {@link #robBoth}
     * @return the total with {@code root} skipped (0 for null)
     */
    public int robExclude(TreeNode root, HashMap<TreeNode, Integer> map) {
        return root == null
                ? 0
                : robBoth(root.left, map) + robBoth(root.right, map);
    }

  • 0
    L

    public class Solution {

        /**
         * Returns the largest total of node values that can be taken from the tree
         * without taking both a node and its direct child.
         *
         * @param root root of the tree; may be null
         * @return the best achievable total (0 for an empty tree)
         */
        public int rob(TreeNode root) {
            // func(root) already is max(rob root, skip root), so a separate
            // "children only" sum here would be redundant work.
            return func(root, new HashMap<TreeNode, Integer>());
        }

        /**
         * Returns the best total for the subtree rooted at {@code node}, memoized
         * per node in {@code map}.
         *
         * <p>The same recurrence applies to every node, leaves included. A leaf
         * therefore yields {@code max(val, 0)}, consistent with allowing any node
         * to be skipped. For non-negative values this equals {@code val}.
         *
         * @param node subtree root; may be null
         * @param map memo of already-computed subtree answers
         * @return the best total for this subtree (0 for null)
         */
        public int func(TreeNode node, Map<TreeNode, Integer> map) {
            if (node == null) {
                return 0;
            }
            Integer cached = map.get(node);
            if (cached != null) {
                return cached;
            }
            // Rob this node: its children are off-limits, grandchildren are free.
            int take = node.val + sumOfChildren(node.left, map) + sumOfChildren(node.right, map);
            // Skip this node: each child subtree contributes its own best.
            int skip = sumOfChildren(node, map);
            int best = Math.max(take, skip);
            map.put(node, best);
            return best;
        }

        /** Returns the sum of the best totals of {@code n}'s two children (0 if {@code n} is null). */
        private int sumOfChildren(TreeNode n, Map<TreeNode, Integer> map) {
            return n == null ? 0 : func(n.left, map) + func(n.right, map);
        }
    }


Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.