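```
// Assumes LeetCode's TreeNode: int val; TreeNode left, right;
class Solution {
    // Depth along the left spine: a lone node has depth 0.
    int depth(TreeNode root) {
        if (root.left != null)
            return depth(root.left) + 1;
        else
            return 0;
    }

    // Returns true if last-level position mid exists, where [low, high] are the
    // last-level positions below root. Each call halves the range, so this is O(depth).
    boolean f(int low, int high, int mid, TreeNode root) {
        if (low + 1 == high) {
            // root is the parent of positions low (left child) and high (right child).
            if (mid == low && root.left != null) {
                return true;
            } else if (mid == high && root.right != null) {
                return true;
            }
            return false;
        }
        // Split [low, high] into the halves covered by the left and right subtrees.
        int diff = (high - low) / 2;
        int leftLow = low, leftHigh = leftLow + diff;
        int rightLow = leftHigh + 1, rightHigh = rightLow + diff;
        if (leftLow <= mid && leftHigh >= mid) {
            return f(leftLow, leftHigh, mid, root.left);
        } else {
            return f(rightLow, rightHigh, mid, root.right);
        }
    }

    public int countNodes(TreeNode root) {
        if (root == null) return 0;
        int d = depth(root);
        if (d == 0) return 1;
        if (d == 1) {
            if (root.right == null) return 2;
            else return 3;
        }
        // Every level above the last is full: 1 + 2 + ... + 2^(d-1) = 2^d - 1 nodes.
        int internalNodes = 1;
        for (int i = 1; i < d; i++) internalNodes += (1 << i);
        // The last level has 2^d slots, numbered 1..2^d from left to right.
        int leaves = 1 << d;
        int left = 1, right = leaves;
        int leftRange = 1, rightRange = leaves;
        // Slot 1 always exists, because d was measured along the left spine.
        int mid = 1, lastSuccess = 1;
        // Binary search for the rightmost existing slot.
        while (leftRange <= rightRange) {
            mid = leftRange + (rightRange - leftRange) / 2;
            boolean ret = f(left, right, mid, root);
            if (ret) {
                lastSuccess = mid;
                leftRange = mid + 1;
            } else {
                rightRange = mid - 1;
            }
        }
        return internalNodes + lastSuccess;
    }
}
```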
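The range-halving that `f` does can also be expressed by reading the path to a slot directly off the bits of its index. Below is a minimal sketch of that alternative. It is a swapped-in technique, not the code above. It assumes a LeetCode-style `TreeNode` with `left`/`right` fields (redefined here only so the block compiles on its own), and the names `CountCompleteTree`, `leafExists` and `build` are hypothetical. It uses 0-based slot indices, so slot i here corresponds to slot i + 1 in the numbering above.

```java
// Minimal stand-in for LeetCode's TreeNode (assumed shape, not the official class).
class TreeNode {
    int val;
    TreeNode left, right;
    TreeNode(int val) { this.val = val; }
}

class CountCompleteTree {
    // Height along the left spine: a single node has height 0 (same convention as depth()).
    static int height(TreeNode root) {
        int h = 0;
        while (root.left != null) {
            root = root.left;
            h++;
        }
        return h;
    }

    // Hypothetical helper: true if last-level slot i (0 <= i < 2^h) holds a node.
    // Bit h-1 of i picks the first child, bit h-2 the next, and so on (0 = left, 1 = right).
    // Walks at most h edges, so each call is O(h).
    static boolean leafExists(TreeNode root, int h, int i) {
        TreeNode node = root;
        for (int bit = h - 1; bit >= 0 && node != null; bit--) {
            node = ((i >> bit) & 1) == 0 ? node.left : node.right;
        }
        return node != null;
    }

    // Binary search for the rightmost existing slot: about h probes of O(h) each.
    static int countNodes(TreeNode root) {
        if (root == null) return 0;
        int h = height(root);
        int lo = 0, hi = (1 << h) - 1; // slot 0 exists because h was measured on the left spine
        while (lo < hi) {
            int mid = lo + (hi - lo + 1) / 2; // round up so that lo = mid always makes progress
            if (leafExists(root, h, mid)) lo = mid;
            else hi = mid - 1;
        }
        // 2^h - 1 nodes above the last level, plus lo + 1 nodes on it.
        return (1 << h) - 1 + lo + 1;
    }

    // Hypothetical test helper: builds a complete tree with n nodes in level order.
    static TreeNode build(int n) {
        if (n == 0) return null;
        TreeNode[] nodes = new TreeNode[n + 1];
        for (int k = 1; k <= n; k++) nodes[k] = new TreeNode(k);
        for (int k = 1; k <= n; k++) {
            if (2 * k <= n) nodes[k].left = nodes[2 * k];
            if (2 * k + 1 <= n) nodes[k].right = nodes[2 * k + 1];
        }
        return nodes[1];
    }

    public static void main(String[] args) {
        // Compare against every complete tree with 0..100 nodes.
        for (int n = 0; n <= 100; n++) {
            int got = countNodes(build(n));
            if (got != n) throw new AssertionError("n=" + n + " got " + got);
        }
        System.out.println("ok");
    }
}
```

Slot i on the last level is node number 2^h + i in level-order (heap) numbering, and the binary digits of that number after the leading 1 are exactly the left/right turns from the root, which is why the bit walk lands on the right slot.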

My approach is to binary search over the range [1, 2^h], where 1 through 2^h number the slots on the last level from left to right. For any slot i (1 <= i <= 2^h), we can descend from the root to the position where leaf i should be, and check whether it exists, in O(log n) time. Am I making a mistake in my analysis?
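For reference, here is the cost of the whole search under this approach. h is the value `depth(root)` returns and n is the number of nodes. A complete tree of depth h has at least 2^h nodes, because every level above the last is full and the last level is non-empty:

$$
\begin{aligned}
2^h \le n &\;\Rightarrow\; h \le \log_2 n \\
\#\text{probes} &\le \lfloor \log_2 2^h \rfloor + 1 = h + 1 \\
T(n) &= \underbrace{O(h)}_{\texttt{depth}} + (h+1)\cdot\underbrace{O(h)}_{\texttt{f}} = O(h^2) = O(\log^2 n)
\end{aligned}
$$

So each individual descent is O(log n), while the binary search over all 2^h slots costs O(log^2 n) in total.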