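The basic idea is to keep checking whether the left subtree of the current node is a perfect (fully grown) binary tree. Because the input is a complete binary tree, this holds exactly when the right subtree reaches the last layer, i.e. when the right subtree's leftmost path is one layer shorter than the current node's. If it does, we add the number of last-layer nodes in the left subtree and repeat on the right subtree. Otherwise, we repeat directly on the left subtree. There are h steps and each step measures a height in O(h), so the time complexity is O(h^2), that is, O(log^2 n).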
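
The description above speaks of examining subtrees recursively. A direct recursive version of the same check could look like the sketch below. It is not the author's code: `TreeNode`, `left_height`, `count_nodes` and `build_complete` are hypothetical names chosen for this example, and instead of counting only last-layer nodes it adds the size of whichever subtree is perfect (plus the current node) at each step.

```python
# Minimal sketch; all names here are hypothetical. TreeNode mimics a judge-provided
# node class that exposes .left and .right.
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def left_height(node):
    """Number of layers on the leftmost path; equals the height of a complete tree."""
    h = 0
    while node:
        h += 1
        node = node.left
    return h


def count_nodes(node):
    h = left_height(node)
    if h == 0:
        return 0
    if left_height(node.right) == h - 1:
        # Right subtree reaches the last layer, so the left subtree is perfect
        # with h-1 layers: (2**(h-1) - 1) nodes, plus 1 for `node` itself.
        return 2 ** (h - 1) + count_nodes(node.right)
    # Otherwise the right subtree is perfect with h-2 layers:
    # (2**(h-2) - 1) nodes, plus 1 for `node` itself.
    return 2 ** (h - 2) + count_nodes(node.left)


def build_complete(n):
    """Test helper: complete tree with n nodes; node i has children 2i and 2i+1."""
    nodes = [None] + [TreeNode(i) for i in range(1, n + 1)]
    for i in range(1, n + 1):
        if 2 * i <= n:
            nodes[i].left = nodes[2 * i]
        if 2 * i + 1 <= n:
            nodes[i].right = nodes[2 * i + 1]
    return nodes[1] if n else None


print(count_nodes(build_complete(6)))  # → 6
```

Each call descends one level and walks two leftmost paths, so the recursive form keeps the same O(h^2) bound as the loop.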

```
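class Solution:
    # @param {TreeNode} root
    # @return {integer}
    def getTreeHeight(self, root):
        '''
        Return the height (number of layers) of a complete binary tree
        by walking down its leftmost path.
        '''
        height = 0
        while root:
            height += 1
            root = root.left
        return height

    def countNodes(self, root):
        height = self.getTreeHeight(root)
        if height == 0:
            # Empty tree: without this guard the return below gives 2 ** -1 == 0.5
            return 0
        count = 0  # Last-layer nodes lying to the left of the current path
        for i in range(1, height):
            # `root` is at depth i - 1. If its right subtree reaches the last
            # layer, the left subtree is perfect and its last layer holds
            # 2 ** (height - 1 - i) nodes: count them and go right.
            if self.getTreeHeight(root.right) == height - i:
                count += 2 ** (height - 1 - i)
                root = root.right
            else:
                # The right subtree stops one layer short, so the rest of the
                # last layer lies in the left subtree
                root = root.left
        # The loop stops on the rightmost last-layer node (not yet in count),
        # and the upper height-1 layers hold 2 ** (height-1) - 1 nodes
        return count + 2 ** (height - 1)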
```
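The last line adds 2^{h-1} rather than 2^{h-1} - 1 because the loop stops on the rightmost node of the last layer, which `count` does not include. With h = `height`:

$$
n = \underbrace{\left(2^{h-1} - 1\right)}_{\text{upper } h-1 \text{ layers}} + \underbrace{(\text{count} + 1)}_{\text{last layer}} = \text{count} + 2^{h-1}
$$

For example, take a complete tree with 6 nodes numbered 1 to 6 in level order (the numbering is only for illustration), so h = 3. At i = 1 the root's right subtree (node 3, then node 6) has height 2 = h - 1, so count becomes 2^1 = 2 (nodes 4 and 5) and we move to node 3. At i = 2 node 3 has no right child, so its height is 0, not h - 2 = 1, and we move to node 6. The result is 2 + 2^2 = 6.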