public class Solution {
public int countNodes(TreeNode root) {
int k = 0;
if(root == null){
return 0;
}
int h1 = leftH(root.left);
while(root != null){
int h2 = leftH(root.right);
if(h1 == h2){
root = root.right;
}else{
root = root.left;
}
k += 1 << h2;
h1;
}
return k;
}
private int leftH(TreeNode curr){
int h = 0;
while(curr != null){
curr = curr.left;
h++;
}
return h;
}
}
Java Iterative Solution


My solution is similar to yours, but it may be a little more concise.
public class Solution {
    /**
     * Counts the nodes of a binary tree by descending one root-to-leaf path.
     *
     * <p>For each visited node, the left-spine lengths of its two subtrees are compared. The
     * subtree that is treated as perfect always has {@code rightDepth} levels, so it and the
     * current node together add {@code 2^rightDepth} nodes. The walk then moves into the other
     * subtree: left when the left spine is longer, right otherwise. The shortcut assumes a
     * complete binary tree.
     *
     * @param root root of the tree; may be {@code null}
     * @return the number of nodes counted, or 0 for an empty tree
     */
    public int countNodes(TreeNode root) {
        int total = 0;
        TreeNode node = root;
        while (node != null) {
            int leftDepth = leftSpineLength(node.left);
            int rightDepth = leftSpineLength(node.right);
            total += 1 << rightDepth;
            if (leftDepth > rightDepth) {
                node = node.left;
            } else {
                node = node.right;
            }
        }
        return total;
    }

    /**
     * Returns how many nodes lie on the path that starts at {@code node} and keeps following
     * left children.
     *
     * @param node starting node; may be {@code null}
     * @return the left-spine length, or 0 when {@code node} is {@code null}
     */
    private int leftSpineLength(TreeNode node) {
        int length = 0;
        for (TreeNode cursor = node; cursor != null; cursor = cursor.left) {
            length++;
        }
        return length;
    }
}