```
// Counts the nodes met while following `left` links from `root` until a
// null link is reached. An empty tree (root == null) yields 0.
// In a complete binary tree this leftmost path is the longest root-to-leaf
// path, so the result is the tree's height measured in nodes.
int leftHeight(TreeNode* root) {
    int depth;
    for (depth = 0; root; root = root->left) {
        ++depth;
    }
    return depth;
}
// Counts the nodes of a binary tree that is assumed to be complete (every
// level full except possibly the last, which is filled from the left); the
// shortcut below is only valid under that assumption.
//
// At each node, compare the leftmost-path heights of its two subtrees:
//  - Equal heights: the last level extends into the right subtree, so the
//    left subtree is perfect with 2^h - 1 nodes. Adding the current node
//    gives 2^h, and the walk continues into the right subtree.
//  - Different heights: the last level ends inside the left subtree, so the
//    right subtree is perfect with one level less. It contributes
//    2^h - 1 nodes, plus the current node gives 2^h, and the walk continues
//    into the left subtree.
// Each step descends one level and does O(log n) work, for O(log^2 n) total.
// Returns 0 for an empty tree.
int countNodes(TreeNode* root) {
    int total = 0;
    for (TreeNode* node = root; node; ) {
        int leftDepth = leftHeight(node->left);
        int rightDepth = leftHeight(node->right);
        if (leftDepth != rightDepth) {
            // Right subtree is perfect; the ragged part lies on the left.
            total += 1 << rightDepth;
            node = node->left;
        } else {
            // Left subtree is perfect; the ragged part lies on the right.
            total += 1 << leftDepth;
            node = node->right;
        }
    }
    return total;
}
```