Here's my implementation.

```
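// Assumes LeetCode's standard TreeNode: int val; TreeNode left, right; TreeNode(int x)
public int kthSmallest(TreeNode root, int k) {
    // Reuse a TreeNode as a mutable holder: result.val counts down from k,
    // and result.left will point to the k-th smallest node once it is found.
    TreeNode result = new TreeNode(k);
    recurse(root, result);
    return result.left.val;
}

// In-order traversal visits BST nodes in ascending order.
public void recurse(TreeNode root, TreeNode result) {
    // Stop at empty subtrees, or as soon as the k-th node has been found.
    if (root == null || result.left != null) return;
    recurse(root.left, result);
    result.val = result.val - 1;
    if (result.val == 0) {
        result.left = root; // this is the k-th smallest element
    }
    recurse(root.right, result);
}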
```
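For comparison, the same in-order idea can be written iteratively with an explicit stack, which avoids repurposing a `TreeNode`'s fields as a counter and result slot. This is a sketch, not the approach above: it assumes LeetCode's standard `TreeNode` (reproduced at the bottom so the file compiles on its own) and `1 <= k <= number of nodes`, and the class name `KthSmallest` is hypothetical.

```java
import java.util.ArrayDeque;
import java.util.Deque;

// Hypothetical wrapper class; declared first so `java KthSmallest.java` finds main().
class KthSmallest {
    // Iterative in-order traversal: BST values come off the stack in ascending
    // order, so the k-th node popped is the k-th smallest.
    // Assumes 1 <= k <= number of nodes in the tree.
    public int kthSmallest(TreeNode root, int k) {
        Deque<TreeNode> stack = new ArrayDeque<>();
        TreeNode node = root;
        while (node != null || !stack.isEmpty()) {
            // Descend as far left as possible, remembering the path.
            while (node != null) {
                stack.push(node);
                node = node.left;
            }
            node = stack.pop();
            if (--k == 0) {
                return node.val; // k-th node visited in order
            }
            node = node.right; // continue with the right subtree
        }
        throw new IllegalArgumentException("k exceeds the number of nodes");
    }

    public static void main(String[] args) {
        //     3
        //    / \
        //   1   4
        //    \
        //     2
        TreeNode root = new TreeNode(3, new TreeNode(1, null, new TreeNode(2)), new TreeNode(4));
        KthSmallest s = new KthSmallest();
        for (int k = 1; k <= 4; k++) {
            System.out.println(s.kthSmallest(root, k)); // prints 1, 2, 3, 4 on separate lines
        }
    }
}

// Assumed to match LeetCode's standard TreeNode definition.
class TreeNode {
    int val;
    TreeNode left, right;

    TreeNode(int val) { this.val = val; }

    TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}
```

This stops as soon as the k-th node is popped, so it runs in O(h + k) time and O(h) extra space, where h is the height of the tree.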