# Share my Java solution (82 ms): Using binary search to count the last level

In a complete binary tree, every level except possibly the last is completely filled, and the nodes of the last level are as far left as possible. So my approach to this problem is:

1. See how far I can go if I keep going left.
2. See how far I can go if I keep going right.
3. If both depths are equal, the tree is completely filled. The total number of nodes then follows from the height `h` alone: `2^(h+1) - 1`.
4. If they are different, the last level is partially filled. Use binary search to find the right-most node in the last level, which tells us how many nodes that level holds.
5. Add the complete part and the last level, and return the sum.
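As a quick check of the formula in step 3, here is a minimal sketch. The class and method names (`PerfectTreeSize`, `size`) are made up for illustration and are not part of the solution below; `h` is counted in edges, so a single node has `h = 0`.

```java
class PerfectTreeSize {
    // Number of nodes in a completely filled binary tree whose deepest
    // level is h edges below the root: 2^0 + 2^1 + ... + 2^h = 2^(h+1) - 1.
    // (Hypothetical helper, for illustration only.)
    static int size(int h) {
        return (1 << (h + 1)) - 1;
    }
}
```

For example, take a 6-node complete tree: going left reaches depth 2, going right reaches depth 1. The complete part is `size(1) = 3` nodes, the last level holds 3 nodes, and the total is 6.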

Elaborating on the 4th step:
The binary search itself is standard, but checking whether a specific node exists is a little tricky.
Rotate the tree 90 degrees to the left. Imagine each level of branches as one binary digit: going left is `0` on that digit and going right is `1`, so a path from the root down to the last level spells out a binary number. That number, say `k`, means the node is the `k-th` node on that level (counting from 0), with `0 <= k < 2^h`. So if the right-most existing node in the last level has index `k`, there are `k+1` nodes in the last level.
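The existence check described above can be sketched on its own, separate from the binary search. This is only an illustration under my own naming (`LastLevelProbe` and `exists` are hypothetical), with `h` being the depth of the last level in edges and `k` the 0-based index on that level:

```java
class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;
    TreeNode(int x) { val = x; }
}

class LastLevelProbe {
    // Follows the h bits of k from the most significant one down:
    // a 0 bit means go left, a 1 bit means go right.
    // Returns true if the k-th node of the level h edges below root exists.
    static boolean exists(TreeNode root, int h, int k) {
        TreeNode p = root;
        for (int bit = h - 1; bit >= 0 && p != null; bit--) {
            p = ((k >> bit) & 1) == 0 ? p.left : p.right;
        }
        return p != null;
    }
}
```

On a 6-node complete tree with `h = 2`, the paths `00`, `01` and `10` (`k = 0, 1, 2`) reach a node, while `11` (`k = 3`) does not, so the right-most `k` is 2 and the last level has 3 nodes.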

Java implementation:

``````
/**
* Definition for a binary tree node.
* public class TreeNode {
*     int val;
*     TreeNode left;
*     TreeNode right;
*     TreeNode(int x) { val = x; }
* }
*/
public class Solution {
    public int countNodes(TreeNode root) {
        if (root == null) return 0;

        // Depth (in edges) of the left-most and right-most paths.
        int lH = 0, rH = 0;
        TreeNode p = root;
        while (p.left != null) {
            lH++;
            p = p.left;
        }
        p = root;
        while (p.right != null) {
            rH++;
            p = p.right;
        }
        // 2^0 + 2^1 + ... + 2^n = 2^(n+1) - 1
        int complete = (1 << (Math.min(lH, rH) + 1)) - 1;

        // Only a partially filled last level needs the binary search.
        int lastLevel = 0;
        if (lH != rH) {
            lastLevel = helper(root, Math.max(lH, rH));
        }

        return lastLevel + complete;
    }

    // binary search
    private int helper(TreeNode node, int h) {
        int start = 0;
        int end = (1 << h) - 1;
        while (start <= end) {
            int mid = start + (end - start) / 2;
            // The highest of the h bits of mid picks the first branch.
            int mask = (1 << (h - 1));
            TreeNode p = node;
            if ((mask & mid) == 0) p = p.left;
            else p = p.right;