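The key observation is that a node's tilt is the absolute difference between the sum of its left subtree and the sum of its right subtree. A post-order traversal therefore has to track two things for every subtree: the sum of its node values and the total tilt accumulated inside it. One option is to keep the running tilt in a global (instance) variable. I chose to bundle both values into a small wrapper class, so that a single recursive call can return them together. Here is the code: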

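```java
/*
 * Assumes LeetCode's TreeNode definition:
 * class TreeNode { int val; TreeNode left, right; }
 */
class Solution {
    public int findTilt(TreeNode root) {
        if (root == null) {
            return 0; // an empty tree has no tilt
        }
        return helper(root).tilt;
    }

    // Post-order traversal: returns the subtree's node sum and accumulated tilt together.
    private Wrapper helper(TreeNode root) {
        if (root.left == null && root.right == null) {
            return new Wrapper(root.val, 0); // leaf node: sum is its value, tilt is 0
        }
        int leftSum = 0;
        int leftTilt = 0;
        if (root.left != null) {
            Wrapper left = helper(root.left);
            leftTilt = left.tilt;
            leftSum = left.sum;
        }
        int rightSum = 0;
        int rightTilt = 0;
        if (root.right != null) {
            Wrapper right = helper(root.right);
            rightTilt = right.tilt;
            rightSum = right.sum;
        }
        // This node's tilt is |leftSum - rightSum|; add it to the tilts found below it.
        return new Wrapper(leftSum + rightSum + root.val,
                leftTilt + rightTilt + Math.abs(leftSum - rightSum));
    }

    // Bundles the two values each recursive call needs to return.
    private static class Wrapper {
        int sum;
        int tilt;

        Wrapper(int sum, int tilt) {
            this.sum = sum;
            this.tilt = tilt;
        }
    }
}
```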
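For comparison, here is a rough sketch of the global-variable option mentioned at the start. Instead of returning a wrapper, the recursive helper returns only the subtree sum and adds each node's tilt to an instance field as a side effect. The class name `TiltWithField` and the helper `subtreeSum` are hypothetical names chosen for illustration, and the `TreeNode` class below is an assumed stand-in that mirrors LeetCode's definition so the snippet compiles on its own.

```java
// Assumed stand-in for LeetCode's TreeNode, included only so this sketch is self-contained.
class TreeNode {
    int val;
    TreeNode left, right;

    TreeNode(int val) { this.val = val; }

    TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}

// Hypothetical alternative: the running tilt lives in a field instead of a wrapper object.
class TiltWithField {
    private int tilt = 0; // total tilt accumulated during the traversal

    public int findTilt(TreeNode root) {
        tilt = 0; // reset so the same instance can be reused
        subtreeSum(root);
        return tilt;
    }

    // Returns the sum of all values in the subtree rooted at node and,
    // as a side effect, adds that node's tilt to the field.
    private int subtreeSum(TreeNode node) {
        if (node == null) {
            return 0; // an empty subtree contributes nothing
        }
        int left = subtreeSum(node.left);
        int right = subtreeSum(node.right);
        tilt += Math.abs(left - right);
        return left + right + node.val;
    }
}
```

As a worked check, take the tree `[4,2,9,3,5,null,7]`: node 2 has tilt |3 - 5| = 2, node 9 has tilt |0 - 7| = 7, and the root has tilt |10 - 16| = 6, so `findTilt` returns 2 + 7 + 6 = 15. Both versions do the same O(n) post-order work; the field version avoids allocating a wrapper per node, at the cost of relying on mutable state.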