# My JAVA solution with explanation which beats 99%

• Basically, my solution has 2 steps:
(1) First, we find the height of the complete binary tree and count the nodes above the last level (those levels are all full).
(2) Then we count the nodes on the last level.
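
Step (1) can be sketched on its own as below. This is a minimal sketch, assuming a hand-written LeetCode-style `TreeNode` and a non-null root; `StepOne`, `height` and `nodesAboveLastLevel` are hypothetical names. In a complete tree the leftmost path always reaches the last level, and every level above it is full, so those levels hold `2^h - 1` nodes.

```java
// Minimal LeetCode-style node (assumed; LeetCode supplies its own TreeNode).
class TreeNode {
    int val;
    TreeNode left, right;
    TreeNode(int val) { this.val = val; }
}

class StepOne {
    // Height = number of edges on the leftmost path. In a complete tree this
    // path always reaches the last level. Assumes root != null.
    static int height(TreeNode root) {
        int h = 0;
        for (TreeNode curr = root; curr.left != null; curr = curr.left) {
            h++;
        }
        return h;
    }

    // Levels 0..h-1 are full: 1 + 2 + ... + 2^(h-1) = 2^h - 1 nodes.
    static int nodesAboveLastLevel(int h) {
        return (1 << h) - 1;
    }

    public static void main(String[] args) {
        // Complete tree with 6 nodes (last level: 4, 5, 6):
        //         1
        //       /   \
        //      2     3
        //     / \   /
        //    4   5 6
        TreeNode root = new TreeNode(1);
        root.left = new TreeNode(2);
        root.right = new TreeNode(3);
        root.left.left = new TreeNode(4);
        root.left.right = new TreeNode(5);
        root.right.left = new TreeNode(6);
        int h = height(root);
        System.out.println(h + " " + nodesAboveLastLevel(h)); // prints "2 3"
    }
}
```

For the 6-node tree in `main`, step (2) then only has to find the 3 nodes on the last level.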

Here I used a kind of binary search. We define the "midNode" of the last level as the node reached by the path "root->left->right->right->...->last level", i.e. the rightmost last-level position in root's left subtree.
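
A minimal sketch of the midNode walk as a standalone helper, assuming a hand-written `TreeNode`; `MidNodeDemo.findMidNode` is a hypothetical name, and `height` means the number of edges from `root` down to the last level. The extra `mid != null` check is an addition that only matters for trees that are not complete.

```java
// Minimal LeetCode-style node (assumed; LeetCode supplies its own TreeNode).
class TreeNode {
    int val;
    TreeNode left, right;
    TreeNode(int val) { this.val = val; }
}

class MidNodeDemo {
    // Go to root.left once, then keep going right until depth `height`
    // (edges below root). The result is the rightmost last-level slot of the
    // left subtree; null means that slot is empty.
    // Assumes height >= 1 and root.left != null.
    static TreeNode findMidNode(TreeNode root, int height) {
        TreeNode mid = root.left;
        for (int depth = 1; depth < height && mid != null; depth++) {
            mid = mid.right;
        }
        return mid;
    }

    public static void main(String[] args) {
        TreeNode root = new TreeNode(1);
        root.left = new TreeNode(2);
        root.right = new TreeNode(3);
        root.left.left = new TreeNode(4);
        // Left subtree's last level is not full yet.
        System.out.println(findMidNode(root, 2)); // prints "null"
        root.left.right = new TreeNode(5);
        System.out.println(findMidNode(root, 2).val); // prints "5"
    }
}
```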

If midNode is null, the left subtree's last level is not full, so the right subtree has no nodes on the last level; we only need to count the last-level nodes in the left subtree.

If midNode is not null, the left subtree's last level is full, so we add half of the last level's capacity (2^(height-1) nodes) to our result and then count the last-level nodes in the right subtree.

Of course, I used some stop conditions to end the recursion early, e.g. when a subtree has height 1, there are only 3 cases: 1. it has a right child (2 last-level nodes); 2. it only has a left child (1); 3. it has no children (0).
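
The halving described above can also be written as a loop instead of recursion. This is a sketch of the same idea under stated assumptions (complete tree, non-null `root`, `height >= 1`), not the code from this post; `IterativeCount.countLastLevelIterative` is a hypothetical name. Instead of an explicit height-1 stop condition, it checks after the loop whether the final position holds a last-level node.

```java
// Minimal LeetCode-style node (assumed; LeetCode supplies its own TreeNode).
class TreeNode {
    int val;
    TreeNode left, right;
    TreeNode(int val) { this.val = val; }
}

class IterativeCount {
    // Counts the last-level nodes of a complete tree whose last level lies
    // `height` edges below root. Assumes root != null and height >= 1.
    static int countLastLevelIterative(TreeNode root, int height) {
        int count = 0;
        TreeNode node = root;
        for (int h = height; h >= 1; h--) {
            // Walk to the rightmost depth-h slot of node's left subtree.
            TreeNode mid = node.left;
            for (int d = 1; d < h && mid != null; d++) {
                mid = mid.right;
            }
            if (mid == null) {
                // The last level ends inside the left subtree.
                node = node.left;
            } else {
                // The left subtree's last level is full: 2^(h-1) nodes.
                count += 1 << (h - 1);
                node = node.right;
            }
            if (node == null) {
                break; // nothing left to examine
            }
        }
        // If the walk ended on a node, that node is one more last-level node.
        return node == null ? count : count + 1;
    }

    public static void main(String[] args) {
        // 6-node complete tree: the last level holds 4, 5, 6.
        TreeNode root = new TreeNode(1);
        root.left = new TreeNode(2);
        root.right = new TreeNode(3);
        root.left.left = new TreeNode(4);
        root.left.right = new TreeNode(5);
        root.right.left = new TreeNode(6);
        System.out.println(countLastLevelIterative(root, 2)); // prints "3"
    }
}
```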

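```java
// Methods of LeetCode's Solution class; TreeNode is supplied by LeetCode.
public int countNodes(TreeNode root) {
    if (root == null) return 0;
    if (root.left == null) return 1;  // complete tree: no left child means no children
    // Follow the leftmost path to get the height; every level above the last
    // is full, so level i contributes 2^i nodes.
    int height = 0;
    int nodesSum = 0;
    TreeNode curr = root;
    while (curr.left != null) {
        nodesSum += (1 << height);
        height++;
        curr = curr.left;
    }
    return nodesSum + countLastLevel(root, height);
}

// Counts the nodes on the last level, which lies `height` edges below root.
private int countLastLevel(TreeNode root, int height) {
    if (height == 1) {
        // Only three cases: two children, only a left child, or none.
        if (root.right != null) return 2;
        if (root.left != null) return 1;
        return 0;
    }
    // midNode: root.left, then keep going right down to the last level,
    // i.e. the rightmost last-level slot of the left subtree.
    TreeNode midNode = root.left;
    int currHeight = 1;
    while (currHeight < height) {
        currHeight++;
        midNode = midNode.right;
    }
    if (midNode == null) {
        // The last level ends inside the left subtree.
        return countLastLevel(root.left, height - 1);
    }
    // The left subtree's last level is full: 2^(height-1) nodes, then go right.
    return (1 << (height - 1)) + countLastLevel(root.right, height - 1);
}
```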
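
To try the solution outside LeetCode, here is a minimal harness, assuming a hand-written `TreeNode` (LeetCode normally supplies one). The two counting methods are the ones from the post, lightly restyled and made static; `buildComplete` and `bruteForceCount` are hypothetical test helpers, and `main` compares the fast count against an O(n) traversal for every size from 0 to 1000.

```java
// Minimal LeetCode-style node (assumed; LeetCode supplies its own TreeNode).
class TreeNode {
    int val;
    TreeNode left, right;
    TreeNode(int val) { this.val = val; }
}

class CountNodesHarness {
    // The post's solution, made static so it can be called without an instance.
    static int countNodes(TreeNode root) {
        if (root == null) return 0;
        if (root.left == null) return 1;
        int height = 0;
        int nodesSum = 0;
        TreeNode curr = root;
        while (curr.left != null) {
            nodesSum += (1 << height);
            height++;
            curr = curr.left;
        }
        return nodesSum + countLastLevel(root, height);
    }

    static int countLastLevel(TreeNode root, int height) {
        if (height == 1) {
            if (root.right != null) return 2;
            if (root.left != null) return 1;
            return 0;
        }
        TreeNode midNode = root.left;
        for (int currHeight = 1; currHeight < height; currHeight++) {
            midNode = midNode.right;
        }
        if (midNode == null) return countLastLevel(root.left, height - 1);
        return (1 << (height - 1)) + countLastLevel(root.right, height - 1);
    }

    // Hypothetical helper: builds a complete tree with n nodes using heap
    // indexing (the children of index i are 2i+1 and 2i+2).
    static TreeNode buildComplete(int n) {
        if (n == 0) return null;
        TreeNode[] nodes = new TreeNode[n];
        for (int i = 0; i < n; i++) nodes[i] = new TreeNode(i + 1);
        for (int i = 0; i < n; i++) {
            if (2 * i + 1 < n) nodes[i].left = nodes[2 * i + 1];
            if (2 * i + 2 < n) nodes[i].right = nodes[2 * i + 2];
        }
        return nodes[0];
    }

    // O(n) reference count used to check the fast version.
    static int bruteForceCount(TreeNode root) {
        return root == null ? 0 : 1 + bruteForceCount(root.left) + bruteForceCount(root.right);
    }

    public static void main(String[] args) {
        for (int n = 0; n <= 1000; n++) {
            TreeNode root = buildComplete(n);
            if (countNodes(root) != bruteForceCount(root)) {
                throw new AssertionError("mismatch at n = " + n);
            }
        }
        System.out.println("all sizes 0..1000 match");
    }
}
```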

• Brilliant Idea! Thanks for sharing.

• Thx for sharing. This showed me how the "binary search" tag applies to this problem.

• Oh!! This is much easier for me to understand, thanks for sharing!

• This idea is spectacular!

• So brilliant!!!

• What's the runtime?

• what is the runtime?
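
A rough bound for the posted solution, assuming a complete tree with $n$ nodes, so its height is $h = \lfloor \log_2 n \rfloor$. `countNodes` walks $h$ left pointers, and each call `countLastLevel(·, k)` with $k \ge 2$ walks $k$ pointers to reach midNode before recursing once with $k - 1$; the height-1 case is $O(1)$:

$$
T(h) \;=\; \underbrace{h}_{\text{left walk}} \;+\; \sum_{k=2}^{h} k \;+\; O(1) \;=\; O(h^2) \;=\; O(\log^2 n)
$$

The recursion is at most $h$ calls deep, so the extra stack space is $O(\log n)$.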

•
```java
public int countNodes(TreeNode root) {
    if (root == null) return 0;
    if (root.left == null) return 1;
    int height = 0;
    int nodesSum = 0;
    TreeNode curr = root;
    while (curr.left != null) {
        nodesSum += (1 << height);
        height++;
        curr = curr.left;
        System.out.println(nodesSum+" "+height);
    }
    return nodesSum + countLastLevel(root, height);
}
```

Why does adding `System.out.println(nodesSum+" "+height);` cause a TLE?
