Java code based on @yfcheng's solution.


  • 0
    M
    /**
     * Finds the k-th smallest value in a binary search tree using a size-augmented copy of it.
     *
     * <p>Every node of the copy records how many nodes its subtree contains. A node's 1-based rank
     * within its own subtree is therefore {@code size(left) + 1}. Descending from the root and
     * comparing k against that rank finds the answer in O(height) steps. Building the copy costs
     * O(n) time and space on every call to {@link #kthSmallest}.
     */
    public class Solution {

        /**
         * A copy of a tree node that also stores the size of the subtree rooted at it.
         *
         * <p>Declared {@code static} so instances do not hold a hidden reference to the enclosing
         * {@code Solution} instance.
         */
        public static class AugTreeNode {
            int val;
            AugTreeNode left;
            AugTreeNode right;
            /** Number of nodes in the subtree rooted here, including this node. */
            int subTreeSize;

            AugTreeNode(int x) {
                val = x;
                subTreeSize = 1;
            }

            /**
             * Recomputes {@link #subTreeSize} from the children's stored sizes.
             *
             * <p>The children's {@code subTreeSize} values must already be correct.
             */
            public void calcSubTreeSize() {
                subTreeSize = sizeOf(left) + sizeOf(right) + 1;
            }
        }

        /** Returns the stored subtree size of {@code node}, or 0 for an empty subtree. */
        private static int sizeOf(AugTreeNode node) {
            return node == null ? 0 : node.subTreeSize;
        }

        /**
         * Recursively builds a size-augmented copy of the subtree rooted at {@code node}.
         *
         * @param node root of the subtree to copy; may be {@code null}
         * @return the root of the copy, or {@code null} if {@code node} is {@code null}
         */
        public AugTreeNode copy(TreeNode node) {
            if (node == null) {
                return null;
            }
            AugTreeNode augNode = new AugTreeNode(node.val);
            augNode.left = copy(node.left);
            augNode.right = copy(node.right);
            // Both children are fully built here, so their sizes are final.
            augNode.calcSubTreeSize();
            return augNode;
        }

        /**
         * Returns the k-th smallest value (1-based) in the binary search tree rooted at
         * {@code root}.
         *
         * @param root root of the BST; may be {@code null} (treated as an empty tree)
         * @param k 1-based rank of the value to return
         * @return the k-th smallest value
         * @throws IllegalArgumentException if {@code k} is not in {@code [1, number of nodes]}
         */
        public int kthSmallest(TreeNode root, int k) {
            AugTreeNode augRoot = copy(root);
            return getKth(augRoot, k);
        }

        /**
         * Returns the value of rank {@code k} (1-based, in-order) within the augmented subtree
         * rooted at {@code root}.
         *
         * @param root root of a size-augmented subtree; may be {@code null} (an empty subtree)
         * @param k 1-based in-order rank to look up
         * @return the value of the node with in-order rank {@code k}
         * @throws IllegalArgumentException if {@code k} is not in {@code [1, root.subTreeSize]}
         */
        public int getKth(AugTreeNode root, int k) {
            int total = sizeOf(root);
            if (k < 1 || k > total) {
                throw new IllegalArgumentException("k=" + k + " is outside [1, " + total + "]");
            }
            AugTreeNode node = root;
            int remaining = k;
            // Invariant: 1 <= remaining <= node.subTreeSize, so node is never null inside the loop.
            while (true) {
                int currentRank = sizeOf(node.left) + 1;
                if (remaining == currentRank) {
                    return node.val;
                }
                if (remaining < currentRank) {
                    node = node.left;
                } else {
                    // Skip the left subtree and this node; continue ranking within the right.
                    remaining -= currentRank;
                    node = node.right;
                }
            }
        }
    }


Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.