```
# Counts the nodes of a complete binary tree.
#
# The heights of the leftmost and rightmost root-to-leaf paths are measured
# first. If they are equal, the shortcut assumes the tree is perfect and
# returns 2**height - 1 without visiting the interior. Otherwise the root is
# counted and both subtrees are counted recursively. In a complete tree, at
# least one of those subtrees is perfect, so it resolves immediately.
#
# NOTE: the equal-height shortcut is only valid when root is a complete
# binary tree; callers are assumed to guarantee that.
#
# @param {TreeNode} root  root of the tree, or nil for an empty tree
# @return {Integer}       total number of nodes (0 when root is nil)
def count_nodes(root)
  return 0 if root.nil?

  left_height = 0
  node = root
  while node
    left_height += 1
    node = node.left
  end

  right_height = 0
  node = root
  while node
    right_height += 1
    node = node.right
  end

  if left_height == right_height
    (1 << left_height) - 1
  else
    1 + count_nodes(root.left) + count_nodes(root.right)
  end
end
```