In a complete binary tree, every level except possibly the last is completely filled, and the nodes in the last level are as far left as possible. So my approach to this problem is:

- See how deep I can go by always going left.
- See how deep I can go by always going right.
- If the two depths are equal, the tree is perfect (every level is completely filled), so the total number of nodes follows directly from its height: a tree of height `h` (counted in edges) has `2^(h+1) - 1` nodes.
- If they differ, the last level is only partially filled. Use binary search to find the right-most node in the last level, which tells us how many nodes that level holds.
- Return the sum of the full part and the last-level count.
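The first three steps can be sketched in Java as below. This is an illustrative sketch, not part of the solution: the class `DepthSketch` and the helpers `leftDepth`, `rightDepth`, and `perfectCount` are hypothetical names, and depths are counted in edges (the root is at depth 0), matching the solution code further down.

```java
// Illustrative sketch of steps 1-3; all names here are hypothetical.
class DepthSketch {
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int x) { val = x; }
    }

    // Number of edges on the path that always goes left.
    static int leftDepth(TreeNode root) {
        int d = 0;
        for (TreeNode p = root; p.left != null; p = p.left) d++;
        return d;
    }

    // Number of edges on the path that always goes right.
    static int rightDepth(TreeNode root) {
        int d = 0;
        for (TreeNode p = root; p.right != null; p = p.right) d++;
        return d;
    }

    // A perfect tree of height h has 2^0 + 2^1 + ... + 2^h = 2^(h+1) - 1 nodes.
    static int perfectCount(int h) {
        return (1 << (h + 1)) - 1;
    }

    public static void main(String[] args) {
        // Perfect tree with 3 nodes: both depths are 1.
        TreeNode root = new TreeNode(1);
        root.left = new TreeNode(2);
        root.right = new TreeNode(3);
        System.out.println(leftDepth(root) + " " + rightDepth(root) + " " + perfectCount(1)); // prints "1 1 3"
        // One node on a new last level makes the two depths differ.
        root.left.left = new TreeNode(4);
        System.out.println(leftDepth(root) + " " + rightDepth(root)); // prints "2 1"
    }
}
```

When the depths differ by one, only the levels down to the smaller depth are guaranteed full, which is why the solution uses `Math.min(lH, rH)` in the formula.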

To elaborate on the 4th step:

The binary search itself is standard, but checking whether a specific node exists is a little tricky.
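For reference, the standard part can be written as a "last true" binary search over a monotone predicate. In this sketch, `LastTrueSearch`, `lastTrue`, and its predicate argument are hypothetical; the predicate stands in for the node-existence check described next. It assumes existence is monotone along the level (all existing slots come before all empty ones), which holds for the last level of a complete tree because its nodes are as far left as possible.

```java
import java.util.function.IntPredicate;

// Hypothetical helper: returns the largest k in [lo, hi] with exists(k) == true,
// assuming exists is monotone (true ... true, false ... false).
// Returns lo - 1 if no k qualifies.
class LastTrueSearch {
    static int lastTrue(int lo, int hi, IntPredicate exists) {
        while (lo <= hi) {
            int mid = lo + (hi - lo) / 2;       // avoids int overflow of (lo + hi)
            if (exists.test(mid)) lo = mid + 1; // mid exists: answer is mid or to its right
            else hi = mid - 1;                  // mid missing: answer is to its left
        }
        return hi; // last index for which exists was true
    }

    public static void main(String[] args) {
        // Pretend the last level holds 5 of 8 possible slots (indices 0..4 exist).
        int count = lastTrue(0, 7, k -> k < 5) + 1;
        System.out.println(count); // prints 5
    }
}
```

The returned `hi` plays the role of `end` in the solution's `helper`, which is why that method returns `end + 1` as the node count.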

Rotate the tree 90 degrees to the left and treat each level of branches as one binary digit, read from the root downward (most significant digit first). Going left sets that digit to `0`, and going right sets it to `1`, so every path from the root down to the last level spells out a binary number `k`. That number means the node is the `k`-th node on its level (counting from 0), with `0 <= k < 2^h`, where `h` is the depth of the last level. So finding the largest `k` whose node exists tells us there are `k+1` nodes in the last level.
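A small illustration of this encoding, assuming a hypothetical helper `pathFor(k, h)` that spells out the path as a string of `L`/`R` moves. It reads the bits of `k` with the same shifting mask that the solution's `helper` uses.

```java
// Sketch: translate a last-level index k into the left/right path from the root,
// reading the h bits of k from most significant to least significant.
class PathDecoder {
    static String pathFor(int k, int h) {
        StringBuilder sb = new StringBuilder();
        for (int mask = 1 << (h - 1); mask > 0; mask >>= 1) {
            sb.append((k & mask) == 0 ? 'L' : 'R'); // bit 0 -> left, bit 1 -> right
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        int h = 3; // last level is 3 edges below the root, so it has 2^3 = 8 slots
        for (int k = 0; k < (1 << h); k++) {
            System.out.println(k + " -> " + pathFor(k, h));
        }
    }
}
```

For `h = 3`, index `6` (binary `110`) maps to `RRL`: right, right, then left.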

Java implementation:

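```java
/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
public class Solution {
    public int countNodes(TreeNode root) {
        if (root == null) return 0;
        // Depths of the leftmost and rightmost paths, counted in edges.
        int lH = 0, rH = 0;
        TreeNode p = root;
        while (p.left != null) {
            lH++;
            p = p.left;
        }
        p = root;
        while (p.right != null) {
            rH++;
            p = p.right;
        }
        // Levels 0..min(lH, rH) are full: 2^0 + 2^1 + ... + 2^n = 2^(n+1) - 1
        int complete = (1 << (Math.min(lH, rH) + 1)) - 1;
        int lastLevel = 0;
        if (lH != rH) {
            // The last level (depth max(lH, rH)) is only partially filled.
            lastLevel = helper(root, Math.max(lH, rH));
        }
        return lastLevel + complete;
    }

    // Binary search for the number of nodes present on level h.
    private int helper(TreeNode node, int h) {
        int start = 0;
        int end = (1 << h) - 1;
        while (start <= end) {
            int mid = start + (end - start) / 2;
            // Follow the bits of mid from the most significant one:
            // 0 means go left, 1 means go right. Levels above h are full,
            // so p can only become null on the final step.
            int mask = (1 << (h - 1));
            TreeNode p = node;
            while (mask > 0) {
                if ((mask & mid) == 0) p = p.left;
                else p = p.right;
                mask >>= 1;
            }
            if (p == null) end = mid - 1; // slot mid is empty: search the left half
            else start = mid + 1;         // slot mid exists: search the right half
        }
        // end is now the index of the right-most existing node on level h.
        return end + 1;
    }
}
// 82 ms
```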
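One way to sanity-check the approach is to build complete trees of every size and compare the result against the known size. In this sketch the class `CountNodesCheck` and the heap-index builder `build(i, n)` (children of node `i` are `2i+1` and `2i+2`) are assumptions for testing only, and the counting logic is a slightly condensed copy of the solution above so that the file compiles on its own.

```java
// Sanity-check sketch: the counting logic mirrors the Solution above.
class CountNodesCheck {
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int x) { val = x; }
    }

    // Hypothetical builder: node i of a complete tree with n nodes (heap indexing).
    static TreeNode build(int i, int n) {
        if (i >= n) return null;
        TreeNode node = new TreeNode(i);
        node.left = build(2 * i + 1, n);
        node.right = build(2 * i + 2, n);
        return node;
    }

    static int countNodes(TreeNode root) {
        if (root == null) return 0;
        int lH = 0, rH = 0;
        for (TreeNode p = root; p.left != null; p = p.left) lH++;
        for (TreeNode p = root; p.right != null; p = p.right) rH++;
        int complete = (1 << (Math.min(lH, rH) + 1)) - 1;
        return lH == rH ? complete : complete + lastLevel(root, lH);
    }

    // Binary search over the 2^h slots of level h, as in the solution's helper.
    static int lastLevel(TreeNode root, int h) {
        int start = 0, end = (1 << h) - 1;
        while (start <= end) {
            int mid = start + (end - start) / 2;
            TreeNode p = root;
            for (int mask = 1 << (h - 1); mask > 0; mask >>= 1) {
                p = (mask & mid) == 0 ? p.left : p.right;
            }
            if (p == null) end = mid - 1;
            else start = mid + 1;
        }
        return end + 1;
    }

    public static void main(String[] args) {
        for (int n = 0; n <= 200; n++) {
            int got = countNodes(build(0, n));
            if (got != n) throw new AssertionError("n=" + n + " got=" + got);
        }
        System.out.println("all sizes 0..200 match");
    }
}
```

Each `countNodes` call walks O(log n) paths of length O(log n), so the check stays fast even though it covers every size in the range.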