What would be the answer to the follow-up question, "What if the BST is modified (insert/delete operations) often and you need to find the kth smallest frequently? How would you optimize the kthSmallest routine?"
I know how to modify the underlying BST structure to support order statistics in O(lg n) time, but I'm not sure how to ONLY modify the actual kthSmallest routine to keep up with insertions and deletions.
Answer to the follow-up question?


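For each node, we store the number of nodes in its subtree, including itself. To find the kth smallest, we search for the node that has exactly k - 1 nodes before it in sorted order and return its value: at each node, compare k with the size of the left subtree and go left, stop, or go right accordingly. Insert and delete update these counts along the path they walk, so they stay O(h) (O(log n) on a balanced tree), and so does the search.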
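```python
class Solution(object):
    def kthSmallest(self, root, k):
        """
        :type root: TreeNode
        :type k: int
        :rtype: int
        """
        def process(node):
            # Post-order pass: annotate each node with its subtree size.
            # The size goes in a separate `count` attribute instead of
            # overwriting `val` with a tuple, so the tree's values stay intact
            # and kthSmallest can safely be called more than once.
            # Note: this pass is O(n); in the follow-up setting, insert/delete
            # would maintain the counts instead of recomputing them here.
            if not node:
                return 0
            left_count = process(node.left)
            right_count = process(node.right)
            node.count = left_count + right_count + 1
            return node.count

        def find_kth(node, k):
            # Nodes in the left subtree are exactly the ones smaller than `node`
            # within this subtree.
            left_count = node.left.count if node.left else 0
            if k <= left_count:
                return find_kth(node.left, k)  # answer is in the left subtree
            elif k == left_count + 1:
                return node.val  # exactly k - 1 smaller nodes: this is the answer
            else:
                # Skip the left subtree and this node, then search the right subtree.
                return find_kth(node.right, k - left_count - 1)

        process(root)
        return find_kth(root, k)
```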
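The solution above still recomputes every count on each call, which is O(n) and does not by itself answer the follow-up. Below is a minimal sketch of the design the answer describes, where insert and delete keep the counts current so a query only walks one root-to-leaf path. It assumes a plain, unbalanced BST with duplicates sent to the right; the names `Node`, `size`, `insert`, `delete` and `kth_smallest` and the `size` field are hypothetical, not part of any given API. Every operation is O(h); getting O(log n) guaranteed would need a self-balancing tree (e.g. AVL or red-black) whose rotations also recompute `size`, which this sketch does not do.

```python
from typing import Optional


class Node:
    """BST node augmented with its subtree size (hypothetical layout)."""

    def __init__(self, val: int) -> None:
        self.val = val
        self.left: Optional["Node"] = None
        self.right: Optional["Node"] = None
        self.size = 1  # number of nodes in this subtree, including this node


def size(node: Optional[Node]) -> int:
    return node.size if node else 0


def insert(node: Optional[Node], val: int) -> Node:
    """Insert val (duplicates go right) and return the subtree root."""
    if node is None:
        return Node(val)
    if val < node.val:
        node.left = insert(node.left, val)
    else:
        node.right = insert(node.right, val)
    node.size += 1  # exactly one node was added somewhere below this one
    return node


def delete(node: Optional[Node], val: int) -> Optional[Node]:
    """Delete one occurrence of val, if present, and return the subtree root."""
    if node is None:
        return None
    if val < node.val:
        node.left = delete(node.left, val)
    elif val > node.val:
        node.right = delete(node.right, val)
    else:
        if node.left is None:
            return node.right
        if node.right is None:
            return node.left
        # Two children: copy the in-order successor's value here,
        # then delete that value from the right subtree.
        succ = node.right
        while succ.left is not None:
            succ = succ.left
        node.val = succ.val
        node.right = delete(node.right, succ.val)
    # Recompute from the children so the count is right even if val was absent.
    node.size = size(node.left) + size(node.right) + 1
    return node


def kth_smallest(node: Optional[Node], k: int) -> int:
    """Return the kth smallest value (1-indexed) in O(h) using stored sizes."""
    while node is not None:
        left = size(node.left)
        if k <= left:
            node = node.left
        elif k == left + 1:
            return node.val
        else:
            k -= left + 1
            node = node.right
    raise IndexError("k is out of range")
```

For example, after inserting 5, 3, 6, 2, 4, 1 and then deleting 3, `kth_smallest(root, 3)` walks 5 → 2 → 4 and returns 4 without touching the rest of the tree.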