int getLeftHeight(TreeNode* root) {
int height = 0;
while(root) {
root = root>left;
height++;
}
return height;
}
// Counts the nodes of a complete binary tree in O(log^2 n) time.
//
// Assumes `root` is a complete binary tree: every level is full except
// possibly the last, which is filled left to right. Results are wrong for
// other shapes.
//
// Let hl / hr be the leftmost-path heights of the left and right subtrees.
// - hl == hr: the last level reaches into the right subtree, so the left
//   subtree is perfect with hl levels (2^hl - 1 nodes). Adding the root gives
//   2^hl, and the right subtree is counted recursively.
// - otherwise (hl == hr + 1): the right subtree is perfect with hr levels.
//   Adding the root gives 2^hr, and the left subtree is counted recursively.
// Powers of two use an integer shift rather than floating-point pow().
int countNodes(TreeNode* root) {
    if (root == nullptr) return 0;
    const int left_height = getLeftHeight(root->left);
    const int right_height = getLeftHeight(root->right);
    if (left_height == right_height)
        return (1 << left_height) + countNodes(root->right);
    return (1 << right_height) + countNodes(root->left);
}
Simple C++ recursive solution


I can't get the same approach to pass in Java; life is hard with Java.
public class Solution {
    /**
     * Counts the nodes of a complete binary tree in O(log^2 n) time.
     *
     * <p>If the left and right subtrees have the same leftmost-path height {@code h}, the left
     * subtree is perfect ({@code 2^h - 1} nodes). Otherwise the right subtree is perfect with
     * its own height. In both cases the perfect part plus the root is a power of two, and the
     * other subtree is counted recursively. The power of two is computed with an integer shift
     * rather than {@code Math.pow}, which works in double precision and needs a cast.
     *
     * @param root root of a complete binary tree; may be {@code null}
     * @return number of nodes in the tree (0 for {@code null})
     */
    public int countNodes(TreeNode root) {
        if (root == null) {
            return 0;
        }
        int leftHeight = getLeftCount(root.left);
        int rightHeight = getLeftCount(root.right);
        if (leftHeight == rightHeight) {
            // Left subtree is perfect: (2^leftHeight - 1) nodes plus the root.
            return (1 << leftHeight) + countNodes(root.right);
        }
        // Right subtree is perfect: (2^rightHeight - 1) nodes plus the root.
        return (1 << rightHeight) + countNodes(root.left);
    }

    /**
     * Returns the number of nodes on the leftmost path starting at {@code n}.
     *
     * @param n starting node; may be {@code null}
     * @return count of nodes reached by following left children (0 if {@code n} is null)
     */
    public int getLeftCount(TreeNode n) {
        int count = 0;
        while (n != null) {
            n = n.left;
            count++;
        }
        return count;
    }
}

@luchy0120 Yeah, life is hard with Java :D It looks like `Math.pow` is not fast enough in Java: it computes in double precision, whereas `1 << h` is a single integer shift. After changing it to bit manipulation, the solution passes. Check this:
public class Solution {
    /**
     * Measures how many nodes lie on the path that always steps to the left child.
     *
     * @param root node to start from; {@code null} yields 0
     * @return length, in nodes, of the leftmost descending path
     */
    int getLeftHeight(TreeNode root) {
        int depth = 0;
        for (TreeNode cursor = root; cursor != null; cursor = cursor.left) {
            depth++;
        }
        return depth;
    }

    /**
     * Counts nodes in a complete binary tree by peeling off one perfect subtree per level.
     *
     * <p>Equal leftmost depths on both sides mean the left subtree is perfect. Otherwise the
     * right subtree is perfect. The perfect subtree plus the root contributes
     * {@code 1 << depth} nodes, and recursion continues into the remaining subtree.
     *
     * @param root root of the tree; may be {@code null}
     * @return total node count
     */
    public int countNodes(TreeNode root) {
        if (root == null) {
            return 0;
        }
        int leftDepth = getLeftHeight(root.left);
        int rightDepth = getLeftHeight(root.right);
        boolean leftIsPerfect = leftDepth == rightDepth;
        int perfectDepth = leftIsPerfect ? leftDepth : rightDepth;
        TreeNode remainder = leftIsPerfect ? root.right : root.left;
        return (1 << perfectDepth) + countNodes(remainder);
    }
}