# 8ms Java solution with comments

• Surprisingly, deleting a node from a BST is not trivial: there are edge cases that can break the deletion logic.

The deletion itself is iterative; finding the node is recursive.

PS: The solution below doesn't delete by swapping values; it relinks pointers instead. So it should handle duplicates (depending on how duplicates are handled in the findNode method).

``````java
TreeNode node = null;
// Parent of the node recorded by findNode; remains null when the match is the
// node findNode was first called with (deleteNode passes the tree root there).
TreeNode parent = null;

/**
 * Deletes the node whose value equals {@code key} from the BST rooted at
 * {@code root} by relinking pointers; no values are copied between nodes.
 *
 * @param root root of the tree; may be {@code null}
 * @param key  value of the node to remove
 * @return the root of the tree after removal: the same {@code root} when the
 *         key is not found, or {@code null} when the only node was removed
 */
public TreeNode deleteNode(TreeNode root, int key) {
    // Clear whatever an earlier call left behind. Without this, searching for
    // a key that is absent would leave the previous match in place and the
    // code below would unlink the wrong node.
    node = null;
    parent = null;

    // After this call, node and parent are populated if the key exists.
    findNode(root, key, null);
    if (node == null) {
        return root;
    }

    // Two subtrees: the leftmost node of the right subtree takes over.
    if (node.left != null && node.right != null) {
        return remove2ChildNode(root);
    }

    // At most one child: that child (or null for a leaf) replaces the node.
    TreeNode child = node.left != null ? node.left : node.right;
    if (parent == null) {
        // The removed node was the root, so its child becomes the new root.
        return child;
    }
    if (parent.left == node) {
        parent.left = child;
    } else {
        parent.right = child;
    }
    return root;
}

/**
 * Removes the current {@code node} (which has two children) by putting the
 * leftmost node of its right subtree in its place.
 *
 * @param root current root of the tree
 * @return {@code root}, or the replacement node when the removed node had no parent
 */
TreeNode remove2ChildNode(TreeNode root) {
    TreeNode replacement = getLeftMostOfRight();

    // The replacement adopts the removed node's children, except where it is
    // itself that child (then it keeps its own subtree on that side).
    if (node.right != replacement) {
        replacement.right = node.right;
    }
    if (node.left != replacement) {
        replacement.left = node.left;
    }

    // Hook the replacement in where the removed node used to hang.
    TreeNode resultRoot = root;
    if (parent == null) {
        resultRoot = replacement;
    } else if (parent.left == node) {
        parent.left = replacement;
    } else {
        parent.right = replacement;
    }

    // Detach the removed node from the tree.
    node.right = null;
    node.left = null;
    return resultRoot;
}

/**
 * Finds the leftmost node in the right subtree of the current {@code node}.
 * When that node is not the right child itself, it is unlinked from its
 * parent, which takes over the found node's right subtree.
 *
 * @return the leftmost node of the right subtree
 */
TreeNode getLeftMostOfRight() {
    TreeNode rightChild = node.right;

    // No left subtree under the right child: the right child is the answer
    // and nothing needs to be unlinked.
    if (rightChild.left == null) {
        return rightChild;
    }

    // Walk left as far as possible, remembering the node above.
    TreeNode above = rightChild;
    TreeNode current = rightChild.left;
    while (current.left != null) {
        above = current;
        current = current.left;
    }

    // Cut the leftmost node out; its right subtree moves up into its slot.
    above.left = current.right;
    return current;
}

/**
 * Recursively searches for a node whose value equals {@code key} and, on a
 * match, stores it and its parent in the {@code node} and {@code parent}
 * fields. The fields are left untouched when no match is found.
 *
 * @param node   subtree to search; may be {@code null}
 * @param key    value to look for
 * @param parent parent of {@code node}, or {@code null} if it has none
 */
void findNode(TreeNode node, int key, TreeNode parent) {
    if (node == null) {
        return;
    }
    if (node.val == key) {
        // Parameters shadow the fields, hence the explicit this.
        this.node = node;
        this.parent = parent;
        return;
    }
    // Larger keys live to the right, smaller ones to the left.
    TreeNode next = key > node.val ? node.right : node.left;
    findNode(next, key, node);
}
``````

