Java Easy to understand solution



    If the current node's value is in the correct place, then all values in its left branch must be smaller than it, and all values in its right branch must be greater. Based on that, we can search for a wrong value in the right branch and in the left branch separately.
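
    To make the "correct place" condition concrete, here is a minimal sketch of it as a check. The class and method names (PlacementCheck, hasGreater, hasSmaller, inCorrectPlace) are made up for illustration, and the nested TreeNode only assumes the usual LeetCode shape (val, left, right).

```java
// Illustrative sketch only; not part of the solution posted here.
class PlacementCheck {
    // Assumed to mirror LeetCode's TreeNode (val, left, right).
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int val) { this.val = val; }
    }

    // True if any node in the subtree holds a value greater than limit.
    static boolean hasGreater(TreeNode node, int limit) {
        if (node == null) return false;
        return node.val > limit
                || hasGreater(node.left, limit)
                || hasGreater(node.right, limit);
    }

    // True if any node in the subtree holds a value smaller than limit.
    static boolean hasSmaller(TreeNode node, int limit) {
        if (node == null) return false;
        return node.val < limit
                || hasSmaller(node.left, limit)
                || hasSmaller(node.right, limit);
    }

    // A node is in the correct place when nothing in its left branch is
    // greater than it and nothing in its right branch is smaller than it.
    static boolean inCorrectPlace(TreeNode node) {
        return !hasGreater(node.left, node.val)
                && !hasSmaller(node.right, node.val);
    }
}
```

    For a tree with root 3, left child 1, right child 4, and 2 as the left child of 4, inCorrectPlace(root) is false because 2 sits in the right branch of 3.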

    • If wrong nodes are found on both sides, swap them with each other.
    • If a wrong node is found on only one side, then the root node itself is in the wrong place, so swap the wrong node with the root.
    • If no wrong node is found in either branch, the current node is in its proper place, so continue the search recursively from its left and right children.

    That's it! This algorithm runs in O(n^2) time in the worst case, so it is slower than most other solutions, but it is easy to understand.

    public class Solution {
        public void recoverTree(TreeNode root) {
            if(root == null) return;
            if(root.left == null && root.right == null) return;
            TreeNode rightWrong = getRightWrongTreeNode(root.right, root.val);
            TreeNode leftWrong = getLeftWrongTreeNode(root.left, root.val);
            
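            // Wrong nodes on both sides: they are the swapped pair.
            if(rightWrong != null && leftWrong != null){
                swap(rightWrong, leftWrong);
                return;
            }
            
            // Wrong node only on the left: the root itself is misplaced.
            if(leftWrong != null){
                swap(root, leftWrong);
                return;
            }
            
            // Wrong node only on the right: the root itself is misplaced.
            if(rightWrong != null){
                swap(root, rightWrong);
                return;
            }
            
            // Root is in its proper place, so search both children.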
            
            recoverTree(root.left);
            recoverTree(root.right);
        }
        
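        // Searches the subtree for a node whose value is smaller than val.
        // When one is found, an even smaller value in its right subtree is preferred.
        public TreeNode getRightWrongTreeNode(TreeNode node, int val){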
            if(node == null) return null;
            if(node.val < val){
                TreeNode x = getRightWrongTreeNode(node.right, node.val);
                if(x != null) return x;
                return node;
            }
            TreeNode right = getRightWrongTreeNode(node.right, val);
            if(right != null) return right;
            TreeNode left = getRightWrongTreeNode(node.left, val);
            return left;
        }
        
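        // Mirror image: searches the subtree for a node whose value is greater than val.
        // When one is found, an even greater value in its left subtree is preferred.
        public TreeNode getLeftWrongTreeNode(TreeNode node, int val){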
            if(node == null) return null;
            if(node.val > val){
                TreeNode x = getLeftWrongTreeNode(node.left, node.val);
                if(x != null) return x;
                return node;
            }
            TreeNode right = getLeftWrongTreeNode(node.right, val);
            if(right != null) return right;
            TreeNode left = getLeftWrongTreeNode(node.left, val);
            return left;
        }
        
        public void swap(TreeNode node1, TreeNode node2){
            int tmp = node1.val;
            node1.val = node2.val;
            node2.val = tmp;
        }
    }
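
    For comparison with the O(n^2) approach above, here is a rough sketch of the common linear-time alternative, which is a different technique from the one in this post: walk the tree in-order and remember the places where a value is smaller than the one before it. The names (InorderRecover, prev, first, second) are made up for illustration, and the nested TreeNode only assumes the usual LeetCode shape.

```java
// Illustrative sketch of a different technique (in-order scan), not the code above.
class InorderRecover {
    // Assumed to mirror LeetCode's TreeNode (val, left, right).
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int val) { this.val = val; }
    }

    private TreeNode prev;   // previously visited node in in-order sequence
    private TreeNode first;  // larger node of the first out-of-order pair
    private TreeNode second; // smaller node of the last out-of-order pair

    void recoverTree(TreeNode root) {
        prev = first = second = null;
        inorder(root);
        if (first != null && second != null) {
            int tmp = first.val;
            first.val = second.val;
            second.val = tmp;
        }
    }

    // In a valid BST the in-order sequence is increasing, so every drop
    // (prev.val > node.val) points at one of the two swapped nodes.
    private void inorder(TreeNode node) {
        if (node == null) return;
        inorder(node.left);
        if (prev != null && prev.val > node.val) {
            if (first == null) first = prev; // first drop: earlier node is wrong
            second = node;                   // any drop: later node is a candidate
        }
        prev = node;
        inorder(node.right);
    }
}
```

    This visits each node once, so it takes O(n) time plus recursion stack space proportional to the tree height, at the cost of being a bit less direct than the branch-by-branch search.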
