```
// Count(n, p, lc): follows the `p` link starting from the pointer lvalue `n`,
// incrementing `lc` once for every non-null node visited. Both `n` and `lc`
// are modified in place; on exit `n` is null.
// The body is wrapped in do { ... } while (0) so that `Count(...);` expands to
// exactly one statement (safe inside an unbraced if/else), and the `n`/`lc`
// arguments are parenthesised against operator-precedence surprises.
#define Count(n, p, lc)       \
    do {                      \
        while (n) {           \
            ++(lc);           \
            (n) = (n)->p;     \
        }                     \
    } while (0)
class Solution {
public:
int countNodes(TreeNode* root) {
int leftCount = 0, rightCount = 0;
TreeNode *n = root;
Count(n, left, leftCount);
n = root;
Count(n, right, rightCount);
if (leftCount == rightCount)
return pow(2, leftCount) - 1;
n = root;
return 1 + countNodes(n->left) + countNodes(n->right);
}
};
```