What would be the answer to the follow-up question ("What if the BST is modified (insert/delete operations) often and you need to find the kth smallest frequently? How would you optimize the kthSmallest routine?")?
I know how to modify the underlying BST structure to support order statistics in O(lg n) time, but I'm not sure how to ONLY modify the actual kthSmallest routine to keep up with insertions and deletions.
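For each node, we store the number of nodes in its subtree, including the node itself. To find the kth smallest, walk down from the root. If the left subtree holds at least k nodes, the answer is in the left subtree. If it holds exactly k - 1 nodes, the current node is the answer. Otherwise, subtract (left count + 1) from k and continue in the right subtree. Insert and delete update the counts along the path they touch, so every operation costs O(h), which is O(log n) as long as the tree is kept balanced.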
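```python
# TreeNode is the class LeetCode provides (attributes: val, left, right).
class Solution(object):
    def kthSmallest(self, root, k):
        """
        :type root: TreeNode
        :type k: int
        :rtype: int
        """
        def process(node):
            # Post-order pass: annotate every node with the size of its subtree.
            # The size goes in a separate attribute so node.val stays intact.
            # In the follow-up scenario, insert/delete would maintain these
            # counts, so this O(n) pass would not be repeated per query.
            if not node:
                return 0
            total = process(node.left) + process(node.right) + 1
            node.count = total
            return total

        def find_kth(node, k):
            # Number of nodes in this subtree that are smaller than node.val.
            left_count = node.left.count if node.left else 0
            if k <= left_count:
                return find_kth(node.left, k)
            elif k == left_count + 1:
                return node.val
            else:
                # Skip the whole left subtree and the current node.
                return find_kth(node.right, k - left_count - 1)

        process(root)
        return find_kth(root, k)
```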
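The answer's point that insert and delete keep the subtree counts up to date can be sketched as a standalone augmented BST. This is a minimal sketch, assuming distinct keys and a plain (unbalanced) BST; `Node`, `size`, `insert`, `delete`, and `kth_smallest` are hypothetical names, not part of LeetCode's API. Every operation costs O(h). A self-balancing tree (AVL, red-black) would keep h = O(log n), but it would also have to recompute `size` for the nodes involved in each rotation.

```python
from typing import Optional


class Node:
    """Hypothetical BST node augmented with the size of its subtree."""

    def __init__(self, val: int) -> None:
        self.val = val
        self.left: Optional["Node"] = None
        self.right: Optional["Node"] = None
        self.size = 1  # nodes in this subtree, including this node


def size(node: Optional[Node]) -> int:
    return node.size if node else 0


def insert(root: Optional[Node], val: int) -> Node:
    """Insert val; sizes are recomputed on the way back up the search path."""
    if root is None:
        return Node(val)
    if val < root.val:
        root.left = insert(root.left, val)
    else:
        root.right = insert(root.right, val)
    root.size = 1 + size(root.left) + size(root.right)
    return root


def delete(root: Optional[Node], val: int) -> Optional[Node]:
    """Delete val if present; sizes are recomputed on the way back up."""
    if root is None:
        return None
    if val < root.val:
        root.left = delete(root.left, val)
    elif val > root.val:
        root.right = delete(root.right, val)
    else:
        if root.left is None:
            return root.right
        if root.right is None:
            return root.left
        # Two children: copy the in-order successor, then delete it on the right.
        succ = root.right
        while succ.left:
            succ = succ.left
        root.val = succ.val
        root.right = delete(root.right, succ.val)
    root.size = 1 + size(root.left) + size(root.right)
    return root


def kth_smallest(root: Optional[Node], k: int) -> int:
    """Return the k-th smallest value (1-indexed) using only the stored sizes."""
    node = root
    while node:
        left = size(node.left)
        if k <= left:
            node = node.left
        elif k == left + 1:
            return node.val
        else:
            k -= left + 1
            node = node.right
    raise IndexError("k is out of range")


if __name__ == "__main__":
    root = None
    for v in [5, 3, 6, 2, 4, 1]:
        root = insert(root, v)
    print(kth_smallest(root, 3))  # 3
    root = delete(root, 2)
    print(kth_smallest(root, 3))  # 4
```

Because `insert` and `delete` keep `size` correct, a query never needs a full pass over the tree like `process` in the answer's code; it only follows one root-to-leaf path.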