57 ms O(h^2) simple Java solution



    Since the tree is a complete tree, there are some properties we can make use of. For example, we can get the height of a complete tree in O(h) time by going straight down to the leftmost node and counting the levels. Also, the left and right subtrees of a complete tree are themselves complete trees.
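A small sketch of the leftmost-path height measurement, assuming a minimal `TreeNode` shaped like LeetCode's. The `build` helper is hypothetical and not part of the solution: it builds a complete tree with n nodes from 1-based array indices and stores each node's index in `val`. The height uses the same convention as the solution's `getHeight` (a single node has height 0), so for a complete tree of n nodes it comes out to floor(log2 n).

```java
// Minimal binary tree node, assumed to match LeetCode's TreeNode.
class TreeNode {
    int val;
    TreeNode left, right;

    TreeNode(int val) { this.val = val; }
}

class LeftmostHeight {
    // Hypothetical helper: builds a complete tree with n nodes. The node at
    // 1-based array index i has children at indices 2i and 2i + 1.
    static TreeNode build(int i, int n) {
        if (i > n) return null;
        TreeNode node = new TreeNode(i);
        node.left = build(2 * i, n);
        node.right = build(2 * i + 1, n);
        return node;
    }

    // Height of a complete tree: walk straight down the left spine, O(h).
    // Empty tree -> -1, single node -> 0.
    static int height(TreeNode root) {
        int h = -1;
        while (root != null) {
            root = root.left;
            h++;
        }
        return h;
    }

    public static void main(String[] args) {
        for (int n = 1; n <= 8; n++) {
            // Heights for n = 1..8 are 0, 1, 1, 2, 2, 2, 2, 3.
            System.out.println("n=" + n + " height=" + height(build(1, n)));
        }
    }
}
```

Only the left spine is visited, which is why this works for complete trees but would undercount for arbitrary binary trees.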

    Suppose we know the heights of a node's left and right subtrees. In a complete tree, the left subtree's height is either equal to the right subtree's height or exactly one greater. If the heights are equal, the left subtree is a full (perfect) tree, and the last node of the last level is in the right subtree. If the left subtree is one level taller, the right subtree is full, and the last node is in the left subtree.
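The claim about subtree heights can be checked by brute force. This sketch assumes the same minimal `TreeNode` and the hypothetical `build` helper (node values are 1-based array indices, so node n is the last node). For each n it compares the left and right heights at the root, confirms that the subtree claimed to be full really has 2^(h+1) - 1 nodes, and confirms which side contains node n. `size` and `contains` are hypothetical O(n) helpers used only for checking.

```java
// Minimal binary tree node, assumed to match LeetCode's TreeNode.
class TreeNode {
    int val;
    TreeNode left, right;

    TreeNode(int val) { this.val = val; }
}

class SubtreeHeights {
    // Hypothetical helper: complete tree with n nodes laid out by 1-based index.
    static TreeNode build(int i, int n) {
        if (i > n) return null;
        TreeNode node = new TreeNode(i);
        node.left = build(2 * i, n);
        node.right = build(2 * i + 1, n);
        return node;
    }

    // Leftmost-path height: empty -> -1, single node -> 0.
    static int height(TreeNode root) {
        int h = -1;
        while (root != null) {
            root = root.left;
            h++;
        }
        return h;
    }

    // Brute-force node count, for checking only.
    static int size(TreeNode root) {
        return root == null ? 0 : 1 + size(root.left) + size(root.right);
    }

    // Brute-force search for a value, for checking only.
    static boolean contains(TreeNode root, int v) {
        return root != null && (root.val == v || contains(root.left, v) || contains(root.right, v));
    }

    // Checks the height claim at the root of the complete tree with n >= 2 nodes.
    static boolean claimHolds(int n) {
        TreeNode root = build(1, n);
        int hl = height(root.left), hr = height(root.right);
        boolean lastInRight = contains(root.right, n); // node n is the last node
        if (hl == hr) {
            // Left subtree full; last node on the right.
            return lastInRight && size(root.left) == (1 << (hl + 1)) - 1;
        }
        if (hl == hr + 1) {
            // Right subtree full (possibly empty when hr == -1); last node on the left.
            return !lastInRight && size(root.right) == (1 << (hr + 1)) - 1;
        }
        return false; // any other height difference would contradict completeness
    }

    public static void main(String[] args) {
        boolean ok = true;
        for (int n = 2; n <= 64; n++) ok &= claimHolds(n);
        System.out.println("claim holds for n = 2..64: " + ok);
    }
}
```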

    This is inspired by how binary trees are stored in an array: with 1-based indexing, the node stored in arr[i] has its left child in arr[2 * i] and its right child in arr[2 * i + 1]. If we can find the index of the last element in the array representation of this complete binary tree, that index is the node count.

    Back to the subtree heights: if they are equal, the last node is in the right subtree; otherwise, it is in the left subtree. We use a pointer to walk down from the root, moving to the right child when the heights are equal and to the left child otherwise. At the same time, we maintain a variable count, initialized to 1, which holds the index of the current node in the array representation. Each time the pointer moves left, count *= 2; each time it moves right, count = 2 * count + 1. When the current node has no left child, it is a leaf (in a complete tree, a node without a left child has no right child either). Because the last node always lies in the subtree rooted at the pointer, that leaf is the last node itself, and count is the final answer.
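To see that count really tracks the array index, here is a tracing sketch. It assumes a minimal `TreeNode` and a hypothetical `build` helper that stores each node's 1-based array index in `val`; `trace` is a hypothetical variant of the solution's loop that records count at every step and checks it against `val`. For n = 10 the path is root (1), left to 2, right to 5, left to 10, and the final index 10 is the node count.

```java
import java.util.ArrayList;
import java.util.List;

// Minimal binary tree node, assumed to match LeetCode's TreeNode.
class TreeNode {
    int val;
    TreeNode left, right;

    TreeNode(int val) { this.val = val; }
}

class IndexTrace {
    // Hypothetical helper: complete tree with n nodes; val = 1-based array index.
    static TreeNode build(int i, int n) {
        if (i > n) return null;
        TreeNode node = new TreeNode(i);
        node.left = build(2 * i, n);
        node.right = build(2 * i + 1, n);
        return node;
    }

    // Leftmost-path height: empty -> -1, single node -> 0.
    static int height(TreeNode root) {
        int h = -1;
        while (root != null) {
            root = root.left;
            h++;
        }
        return h;
    }

    // Follows the solution's left/right rule on a non-empty tree and records
    // count at each step, verifying it equals the node's stored array index.
    static List<Integer> trace(TreeNode root) {
        List<Integer> path = new ArrayList<>();
        TreeNode cur = root;
        int count = 1;
        while (true) {
            if (cur.val != count) throw new IllegalStateException("index mismatch");
            path.add(count);
            if (cur.left == null) break; // leaf: this is the last node
            if (height(cur.left) > height(cur.right)) {
                cur = cur.left;
                count = 2 * count;     // left child of arr[i] is arr[2i]
            } else {
                cur = cur.right;
                count = 2 * count + 1; // right child of arr[i] is arr[2i + 1]
            }
        }
        return path;
    }

    public static void main(String[] args) {
        System.out.println(trace(build(1, 10))); // prints [1, 2, 5, 10]
    }
}
```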

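    /**
     * Definition for a binary tree node (provided by LeetCode):
     * public class TreeNode {
     *     int val;
     *     TreeNode left;
     *     TreeNode right;
     *     TreeNode(int x) { val = x; }
     * }
     */
    public class Solution {
        public int countNodes(TreeNode root) {
            if (root == null) return 0;
            TreeNode cur = root;
            int count = 1; // 1-based array index of cur
            // In a complete tree, a node with no left child is a leaf,
            // and here that leaf is the last node.
            while (cur.left != null) {
                if (getHeight(cur.left) > getHeight(cur.right)) {
                    // Right subtree is full, so the last node is on the left.
                    cur = cur.left;
                    count *= 2;
                } else {
                    // Left subtree is full, so the last node is on the right.
                    cur = cur.right;
                    count = count * 2 + 1;
                }
            }
            return count;
        }

        // Height of a complete tree along its leftmost path
        // (null -> -1, single node -> 0).
        private int getHeight(TreeNode root) {
            int height = -1;
            while (root != null) {
                root = root.left;
                height++;
            }
            return height;
        }
    }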
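To try the solution outside LeetCode, a harness like the following can compare it against a brute-force count. It assumes a minimal `TreeNode` matching LeetCode's definition; `build`, `bruteForce`, and `firstMismatch` are hypothetical test helpers, and `Solution` is the post's code with the class made package-private so everything fits in one file.

```java
// Minimal binary tree node, assumed to match LeetCode's TreeNode.
class TreeNode {
    int val;
    TreeNode left, right;

    TreeNode(int val) { this.val = val; }
}

// The solution from this post; only the class visibility differs.
class Solution {
    public int countNodes(TreeNode root) {
        if (root == null) return 0;
        TreeNode cur = root;
        int count = 1;
        while (cur.left != null) {
            if (getHeight(cur.left) > getHeight(cur.right)) {
                cur = cur.left;
                count *= 2;
            } else {
                cur = cur.right;
                count = count * 2 + 1;
            }
        }
        return count;
    }

    private int getHeight(TreeNode root) {
        int height = -1;
        while (root != null) {
            root = root.left;
            height++;
        }
        return height;
    }
}

class CountNodesCheck {
    // Hypothetical helper: complete tree with n nodes laid out by 1-based index.
    static TreeNode build(int i, int n) {
        if (i > n) return null;
        TreeNode node = new TreeNode(i);
        node.left = build(2 * i, n);
        node.right = build(2 * i + 1, n);
        return node;
    }

    // Brute-force O(n) count, used only as a reference answer.
    static int bruteForce(TreeNode root) {
        return root == null ? 0 : 1 + bruteForce(root.left) + bruteForce(root.right);
    }

    // Returns the first n in [0, maxN] where the two counts disagree, or -1.
    static int firstMismatch(int maxN) {
        Solution s = new Solution();
        for (int n = 0; n <= maxN; n++) {
            TreeNode root = build(1, n);
            if (s.countNodes(root) != bruteForce(root)) return n;
        }
        return -1;
    }

    public static void main(String[] args) {
        System.out.println("first mismatch: " + firstMismatch(1000));
    }
}
```

Each loop iteration makes two getHeight calls of O(h) each, and the pointer descends at most h levels, which is where the O(h^2) = O(log^2 n) bound in the title comes from.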
    
