@angelvivienne

Thanks for sharing your three solutions.

As for the time complexity, the worst case is not O(n log n) as you explained.

Think about a left-skewed binary search tree: the root is 9, its left child is 8, and so on down to the leaf 1, so every node has only a left child.

At each node, countNodes() on its left subtree takes O(m) time, where m is the number of nodes in that subtree. Suppose k = 1. Then kthSmallest() is called on every node, from 9 down to 1, and each call runs countNodes() on the current node's left subtree. That visits 8 + 7 + ... + 1 + 0 = 36 nodes in total. For a tree with n nodes the sum is (n - 1) + (n - 2) + ... + 0 = n(n - 1) / 2, so the overall time complexity is O(n^2).
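To check the count, here is a Python sketch of the count-based approach as I understand it. Your first solution's exact code isn't quoted here, so `count_nodes`, `kth_smallest`, `left_skewed`, and `measure` are hypothetical stand-ins for countNodes() and kthSmallest(). It counts how many non-null nodes `count_nodes` touches when k = 1 on a left-skewed tree.

```python
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def count_nodes(node, visits):
    """Size of the subtree rooted at node; visits[0] counts every non-null node touched."""
    if node is None:
        return 0
    visits[0] += 1
    return 1 + count_nodes(node.left, visits) + count_nodes(node.right, visits)


def kth_smallest(root, k, visits):
    """Count-based lookup: recounts the left subtree at every step (no caching)."""
    count = count_nodes(root.left, visits)
    if k <= count:
        return kth_smallest(root.left, k, visits)
    if k > count + 1:
        return kth_smallest(root.right, k - 1 - count, visits)
    return root.val


def left_skewed(n):
    """Build root n -> left child n-1 -> ... -> leaf 1 (every node has only a left child)."""
    root = None
    for v in range(1, n + 1):
        root = TreeNode(v, left=root)
    return root


def measure(n):
    """Return (answer, node visits) for k = 1 on a left-skewed tree of n nodes."""
    visits = [0]
    answer = kth_smallest(left_skewed(n), 1, visits)
    return answer, visits[0]


# Keep n small: both functions recurse once per level, and a skewed tree is n levels deep.
for n in (9, 20):
    print(n, measure(n))
```

measure(9) returns (1, 36) and measure(20) returns (1, 190), matching n(n - 1) / 2.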

One fix is to memoize the number of nodes in each subtree, but that increases the space complexity to O(n). Hence, I think the inorder traversal is better than your first solution.
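For comparison, here is a Python sketch of both alternatives. The function names are my own and the TreeNode class is assumed to look like LeetCode's. `kth_smallest_memo` caches each subtree's size in a dict, so every node is counted at most once (O(n) time, plus O(n) extra space for the cache). `kth_smallest_inorder` walks the tree in order with an explicit stack and stops at the k-th node, using O(h) extra space where h is the tree height (h can still be n on a fully skewed tree).

```python
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def kth_smallest_memo(root, k):
    """Count-based lookup with memoized subtree sizes (cache costs O(n) space)."""
    sizes = {}  # node -> size of its subtree; nodes hash by identity

    def size(node):
        if node is None:
            return 0
        if node not in sizes:
            sizes[node] = 1 + size(node.left) + size(node.right)
        return sizes[node]

    node = root
    while node is not None:
        left = size(node.left)
        if k <= left:
            node = node.left          # answer is in the left subtree
        elif k > left + 1:
            k -= left + 1             # skip the left subtree and this node
            node = node.right
        else:
            return node.val           # this node is the k-th smallest
    raise ValueError("k is out of range")


def kth_smallest_inorder(root, k):
    """Iterative inorder traversal; stops as soon as the k-th node is popped."""
    stack = []
    node = root
    while stack or node is not None:
        while node is not None:       # walk to the leftmost unvisited node
            stack.append(node)
            node = node.left
        node = stack.pop()
        k -= 1
        if k == 0:
            return node.val
        node = node.right
    raise ValueError("k is out of range")
```

The inorder version never stores subtree sizes, which is why it avoids the extra cache that memoization needs.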