Given a binary search tree, find the distance between two nodes


  • O

    Write a function that, given a BST, returns the distance between two nodes.

    For example, given this tree

             A
            / \
           B   C
          / \   \
         D   E   F
        /         \
       G           H
    

    The distance between G and F is 3: [G -> D -> B -> E]

    The distance between E and H is 6: [G -> D -> B -> A -> C -> F -> H]


  • Y

    The examples seem a bit wrong. The path from G to F should be G->D->B->A->C->F, and the path from E to H should be E->B->A->C->F->H, so both distances are 5 edges.
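To double-check the corrected distances, here is a minimal brute-force sketch in Python (not code from the thread): it treats the example tree as an undirected graph and runs a breadth-first search from one node to the other, counting edges. The names `TreeNode` and `bfs_distance`, and the `-1` sentinel for a missing node, are illustrative assumptions.

```python
from collections import deque


class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def bfs_distance(root, p, q):
    """Number of edges between p and q, treating the tree as an undirected graph.

    Returns -1 if either node is not in the tree (a sentinel chosen for this sketch).
    """
    if root is None:
        return -1
    # Build undirected adjacency lists from the parent/child links.
    adjacency = {root: []}
    stack = [root]
    while stack:
        node = stack.pop()
        for child in (node.left, node.right):
            if child is not None:
                adjacency[node].append(child)
                adjacency[child] = [node]  # each child has exactly one parent
                stack.append(child)
    if p not in adjacency or q not in adjacency:
        return -1
    # Breadth-first search from p; the first time q is dequeued, dist is minimal.
    seen = {p}
    queue = deque([(p, 0)])
    while queue:
        node, dist = queue.popleft()
        if node is q:
            return dist
        for neighbor in adjacency[node]:
            if neighbor not in seen:
                seen.add(neighbor)
                queue.append((neighbor, dist + 1))
    return -1


# The tree from the question.
g, h = TreeNode("G"), TreeNode("H")
d, e = TreeNode("D", left=g), TreeNode("E")
f = TreeNode("F", right=h)
b, c = TreeNode("B", d, e), TreeNode("C", right=f)
a = TreeNode("A", b, c)

print(bfs_distance(a, g, f))  # 5
print(bfs_distance(a, e, h))  # 5
```

This is O(n) time and memory, but it is only a reference check; the recursive answers below avoid building the adjacency map.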

    The obvious solution uses a bottom-up approach, which takes O(n) time in the worst case. A similar problem on LeetCode is, I think, "Binary Tree Maximum Path Sum".

    Pseudocode (without much verification):

    def findDistance(root, p, q):
        if root is None or p is q:
            return 0

        distance_left = findDistance(root.left, p, q)
        distance_right = findDistance(root.right, p, q)
        is_target = root is p or root is q

        if distance_left > 0 and distance_right > 0:
            return distance_left + distance_right
        if distance_left > 0 and is_target:
            return distance_left
        if distance_right > 0 and is_target:
            return distance_right
        if distance_left == 0 and distance_right == 0:
            return 1 if is_target else 0
        return max(distance_left, distance_right) + 1
    

  • A

    A couple of flaws in your algorithm: 1) The last line wrongly assumes that the current node (the root argument) is always on the resulting path, so your algorithm returns 3 for inputs D and E in the tree above. 2) It cannot handle the case where one of the target nodes is not in the tree. The fix: you also need to record the number of matches found.
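A minimal Python sketch of that fix (my own, not code from the thread): each recursive call returns a pair (number of targets found in the subtree, distance), so ancestors above the LCA stop adding edges once both nodes are found, and a missing node is reported instead of a partial count. Returning `-1` for a missing node is an assumption of this sketch.

```python
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_distance(root, p, q):
    """Edges between p and q in a binary tree, or -1 if either node is missing."""

    def walk(node):
        # Returns (matches, dist) for the subtree rooted at node:
        #   matches == 0: no target below, dist is 0;
        #   matches == 1: one target below, dist = edges from node down to it;
        #   matches == 2: both targets below, dist is the final answer.
        if node is None:
            return 0, 0
        left_matches, left_dist = walk(node.left)
        if left_matches == 2:
            return left_matches, left_dist  # answer already complete
        right_matches, right_dist = walk(node.right)
        if right_matches == 2:
            return right_matches, right_dist
        here = (node is p) + (node is q)  # 2 when p and q are the same node
        dist = 0
        if left_matches:
            dist += left_dist + 1  # one edge from node to its left child
        if right_matches:
            dist += right_dist + 1  # one edge from node to its right child
        return left_matches + right_matches + here, dist

    matches, dist = walk(root)
    return dist if matches == 2 else -1


# The tree from the question.
g, h = TreeNode("G"), TreeNode("H")
d, e = TreeNode("D", left=g), TreeNode("E")
f = TreeNode("F", right=h)
b, c = TreeNode("B", d, e), TreeNode("C", right=f)
a = TreeNode("A", b, c)

print(find_distance(a, d, e))  # 2, not 3
print(find_distance(a, g, f))  # 5
```

Because the pair with `matches == 2` is passed up unchanged, the root no longer adds an edge for D and E, and the case where one node is an ancestor of the other (for example B and G) is handled by the same addition.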


  • D

    You have to modify the solution for the lowest common ancestor (LCA) problem.
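One reading of this suggestion, as a minimal Python sketch (not code from the thread): reuse the classic LCA recursion for a general binary tree, then add the two downward distances from the LCA. Like the classic LCA routine, it assumes both nodes are in the tree; `lowest_common_ancestor` and `depth_below` are illustrative names.

```python
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def lowest_common_ancestor(root, p, q):
    """Classic LCA for a binary tree; assumes p and q are both in the tree."""
    if root is None or root is p or root is q:
        return root
    left = lowest_common_ancestor(root.left, p, q)
    right = lowest_common_ancestor(root.right, p, q)
    if left is not None and right is not None:
        return root  # p and q are in different subtrees
    return left if left is not None else right


def depth_below(node, target, depth=0):
    """Edges from node down to target, or -1 if target is not in this subtree."""
    if node is None:
        return -1
    if node is target:
        return depth
    found = depth_below(node.left, target, depth + 1)
    if found != -1:
        return found
    return depth_below(node.right, target, depth + 1)


def find_distance(root, p, q):
    lca = lowest_common_ancestor(root, p, q)
    return depth_below(lca, p) + depth_below(lca, q)


# The tree from the question.
g, h = TreeNode("G"), TreeNode("H")
d, e = TreeNode("D", left=g), TreeNode("E")
f = TreeNode("F", right=h)
b, c = TreeNode("B", d, e), TreeNode("C", right=f)
a = TreeNode("A", b, c)

print(find_distance(a, g, f))  # 5
print(find_distance(a, d, e))  # 2
```

This makes up to three O(n) passes; the trade-off is that each function stays as simple as the textbook LCA code.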


  • Y

    The last line returns the current tentative distance when only one of the nodes has been found. Since I'm using a bottom-up approach, if only one node is found in a subtree and the other is not, the path must eventually go through the current node (root). I don't see a problem with returning 3 if the input is D and E (D->B->E). My solution does not handle the case where one or both nodes are not in the tree.
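To make the disagreement concrete, here is a minimal sketch that transcribes the pseudocode above literally into Python (the `is_target` variable is my own shorthand) and runs it on the example tree. The result for D and E depends on where the recursion starts: from B it is 2, but from the real root A it is 3, because A is not on the D-E path yet still adds an edge.

```python
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_distance_pseudocode(root, p, q):
    """Literal transcription of the pseudocode above, including its last line."""
    if root is None or p is q:
        return 0
    distance_left = find_distance_pseudocode(root.left, p, q)
    distance_right = find_distance_pseudocode(root.right, p, q)
    is_target = root is p or root is q
    if distance_left > 0 and distance_right > 0:
        return distance_left + distance_right
    if distance_left > 0 and is_target:
        return distance_left
    if distance_right > 0 and is_target:
        return distance_right
    if distance_left == 0 and distance_right == 0:
        return 1 if is_target else 0
    # Reached whenever exactly one side is positive and root is not a target,
    # which also happens at every ancestor above the LCA.
    return max(distance_left, distance_right) + 1


# The tree from the question.
g, h = TreeNode("G"), TreeNode("H")
d, e = TreeNode("D", left=g), TreeNode("E")
f = TreeNode("F", right=h)
b, c = TreeNode("B", d, e), TreeNode("C", right=f)
a = TreeNode("A", b, c)

print(find_distance_pseudocode(b, d, e))  # 2: D->B->E has 2 edges
print(find_distance_pseudocode(a, d, e))  # 3: A adds one edge above the LCA
print(find_distance_pseudocode(a, g, f))  # 5: here the LCA is the root itself
```

When the LCA is the root, as for G and F, there is no ancestor left to add the extra edge, so that case comes out right.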


  • J

    C# solution: first find the LCA, then calculate the levels (depths from the root) of the LCA, node1 and node2; the distance is level(node1) + level(node2) - 2 * level(LCA).

        class MyTreeNode
        {
            public int Data { get; set; }
            public MyTreeNode Left { get; set; }
            public MyTreeNode Right { get; set; }
            public MyTreeNode(int data)
            {
                this.Data = data;
            }
        }
    
        class QTwoNodeDis
        {
            public static int Distance(MyTreeNode root, MyTreeNode node1, MyTreeNode node2)
            {
                var node = FindLCA(root, node1, node2);
                int distLCA = FindLevel(root, node);
                int dist1 = FindLevel(root, node1);
                int dist2 = FindLevel(root, node2);
    
                return dist1 + dist2 - 2 * distLCA;
            }
    
            private static MyTreeNode FindLCA(MyTreeNode root, MyTreeNode node1, MyTreeNode node2)
            {
                if (root == null) return null;
                // Return as soon as either of the two nodes is found
                if (root.Data == node1.Data || root.Data== node2.Data)
                    return root;
                // Search for the two nodes in the left and right subtrees
    
                MyTreeNode left_lca = FindLCA(root.Left, node1, node2);
                MyTreeNode right_lca = FindLCA(root.Right, node1, node2);
    
                if (left_lca != null && right_lca != null)
                    // The two nodes are in different subtrees, so the current node must be the LCA
                    return root;
                return left_lca != null ? left_lca : right_lca;
            }
    
            private static int FindLevel(MyTreeNode root, MyTreeNode node)
            {
                if (root == null)
                    return -1;
                if(root.Data == node.Data)
                    return 0;
    
                int level = FindLevel(root.Left, node);
    
                if (level == -1)
                    level = FindLevel(root.Right, node);
    
                if(level != -1)
                    return level + 1;
    
                return -1;
            }
        }
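A worked check of this formula on the tree from the question, taking the root A as level 0 (levels read off the drawing: B and C at 1, D, E and F at 2, G and H at 3):

```latex
\begin{aligned}
d(u, v) &= \mathrm{level}(u) + \mathrm{level}(v) - 2\,\mathrm{level}(\mathrm{LCA}(u, v)) \\
d(G, F) &= 3 + 2 - 2 \cdot 0 = 5 && (\mathrm{LCA} = A) \\
d(E, H) &= 2 + 3 - 2 \cdot 0 = 5 && (\mathrm{LCA} = A) \\
d(D, E) &= 2 + 2 - 2 \cdot 1 = 2 && (\mathrm{LCA} = B)
\end{aligned}
```

The subtraction removes the root-to-LCA segment, which both root-to-node paths share and which is not part of the path between the two nodes.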
    

  • M

    I don't think the example is a BST...

    If the input is a BST, the solution above is OK, but what should we do if it is just a binary tree?


  • P

    @mikealive
    I think this will work if it's a BST. Did you think of any case where this code will fail?

    
    public class DistanceBetweenBSTNode {

        // Minimal node type so the snippet compiles (not shown in the original post).
        static class TreeNode {
            int val;
            TreeNode left, right;
            TreeNode(int val) { this.val = val; }
        }

        // Distance between node1 and node2; assumes node1.val <= node2.val.
        public int distance(TreeNode root, TreeNode node1, TreeNode node2) {
            if (root == null || node1 == node2)
                return 0;
            // root lies between the two keys, so it is their LCA
            if (root.val >= node1.val && root.val <= node2.val)
                return distance(root, node1) + distance(root, node2);
            else if (root.val < node1.val && root.val < node2.val)
                return distance(root.right, node1, node2);
            else
                return distance(root.left, node1, node2);
        }

        // Number of edges from root down to node1, following BST ordering.
        public int distance(TreeNode root, TreeNode node1) {
            if (root == node1)
                return 0;
            if (root.val < node1.val)
                return 1 + distance(root.right, node1);
            else
                return 1 + distance(root.left, node1);
        }

        public static void main(String[] args) {
            DistanceBetweenBSTNode dbb = new DistanceBetweenBSTNode();
            TreeNode root = new TreeNode(5);
            root.left = new TreeNode(3);
            root.right = new TreeNode(10);
            root.right.right = new TreeNode(11);
            root.right.left = new TreeNode(8);
            System.out.println(dbb.distance(root, root, root.right.right)); // prints 2
        }
    }
     

  • L

    @prashantkadam88-gmail.com

    I like your solution, but there is an error in this line:

    if(root.val >= node1.val && root.val <= node2.val)

    The fix is:

    if ((root.val >= node1.val && root.val <= node2.val) || (root.val <= node1.val && root.val >= node2.val))

    You should include both cases (node1 smaller than node2, and node1 larger than node2).

    Greetings!


  • P

    @lupinekaupz
    Thanks.
    Actually, we can use Math.min to decide between the left and right paths.
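A minimal Python sketch of the `min`/`max` idea (my own, not code from the thread; it assumes distinct keys and raises `ValueError` if a node is missing): ordering the two keys first means a single range check handles both argument orders, which is the case the reply above points out. Walking down until the current key lies between the two keys finds the LCA without recursion.

```python
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def depth_below(node, target):
    """Edges from node down to target, steering by BST ordering."""
    steps = 0
    while node is not None and node is not target:
        node = node.right if target.val > node.val else node.left
        steps += 1
    if node is None:
        raise ValueError("target is not below this node")
    return steps


def bst_distance(root, p, q):
    """Edges between p and q in a BST with distinct keys."""
    low, high = min(p.val, q.val), max(p.val, q.val)
    node = root
    while node is not None:
        if node.val < low:
            node = node.right  # both keys are larger: the LCA is on the right
        elif node.val > high:
            node = node.left  # both keys are smaller: the LCA is on the left
        else:
            # low <= node.val <= high: node splits the keys, so it is the LCA.
            return depth_below(node, p) + depth_below(node, q)
    raise ValueError("p or q is not in the tree")


# The BST from the Java main() above: 5 -> (3, 10 -> (8, 11)).
root = TreeNode(5, TreeNode(3), TreeNode(10, TreeNode(8), TreeNode(11)))
n3, n8, n11 = root.left, root.right.left, root.right.right
print(bst_distance(root, root, n11))  # 2
print(bst_distance(root, n11, n3))  # 3, argument order does not matter
```

Each step moves one level down, so this runs in O(h) time for a tree of height h, which is the advantage of using the BST property at all.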


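  • H

    // Minimal Node type and wrapper class (names assumed) so the snippet compiles.
    class Node {
        Node left, right;
    }

    class NodeDistance {
        // Edges between two distinct nodes n1 and n2, or 0 if either is missing.
        static int findDistanceForNodes(Node n1, Node n2, Node root) {
            if (root == null || n1 == null || n2 == null)
                return 0;
            Temp temp = new Temp();
            // findPathForNodes counts the nodes on the path; edges = nodes - 1.
            int n = findPathForNodes(root, n1, n2, temp) - 1;
            return (temp.foundNode1 && temp.foundNode2) ? n : 0;
        }

        private static int findPathForNodes(Node node, Node n1, Node n2, Temp temp) {
            if (node == null || temp.foundLen)
                return 0;
            if (node == n1)
                temp.foundNode1 = true;
            if (node == n2)
                temp.foundNode2 = true;
            int left = findPathForNodes(node.left, n1, n2, temp);
            int right = findPathForNodes(node.right, n1, n2, temp);
            // The two partial paths meet here (or one target is above the other): path complete.
            if ((left > 0 && right > 0) || ((left > 0 || right > 0) && (node == n1 || node == n2))) {
                temp.foundLen = true;
                return left + right + 1;
            }
            // Start or extend a partial path upward.
            if ((left > 0 || right > 0 || node == n1 || node == n2) && !temp.foundLen)
                return left + right + 1;
            // Pass a completed length (or 0) up unchanged.
            return Math.max(left, right);
        }
    }

    class Temp {
        boolean foundNode1;
        boolean foundNode2;
        boolean foundLen;
    }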

  • K

    Got inspired by @yuhao5 and fixed a couple of flaws.

    #!/usr/bin/env python
    # -*- coding: utf-8 -*-
    # Find distance between two given keys of a Binary Tree
    # Author: zhang.xiaoye@gmail.com
    
    class TreeNode(object):
        def __init__(self, val, left=None, right=None):
            self.val = val
            self.left = left
            self.right = right
    
    class BTree(object):
    
        def __init__(self):
            self.root = TreeNode(1)
            node = TreeNode(2)
            self.root.left = node
            node = TreeNode(3)
            self.root.right = node
            node = TreeNode(4)
            self.root.left.left = node
            node = TreeNode(5)
            self.root.left.right = node
            node = TreeNode(6)
            self.root.right.left = node
            node = TreeNode(7)
            self.root.right.right = node
            node = TreeNode(8)
            self.root.right.left.right = node
    
        def getBTHead(self):
            return self.root
    
    class BTreeDistance(object):
    
        dis = 0
    
        def findDistance(self, root, p, q):
            if not root:
                return 0
    
            ldepth = self.findDistance(root.left, p, q)
            rdepth = self.findDistance(root.right, p, q)
    
            if root.val == p or root.val == q:
                if max(ldepth, rdepth) != 0:
                    # The given nodes are ancestor and descendant, e.g. 4 and 2
                    self.dis = max(ldepth, rdepth)
                # level + 1
                return 1
            if ldepth != 0 and rdepth == 0 or ldepth == 0 and rdepth != 0:
                # If exactly one side has a nonzero level, add 1 and return it
                return max(ldepth, rdepth) + 1
    
            if ldepth > 0 and rdepth > 0:
                # If both the left and right levels are positive, add them and save the result
                self.dis = ldepth + rdepth
    
            return 0
        
        def getResult(self, root, p, q):
            self.findDistance(root, p, q)
            return self.dis
    
    tree = BTree()
    root = tree.getBTHead()
    
    bd = BTreeDistance()
    print(bd.getResult(root, 4, 5))  # prints 2
    

  •

    I thought about the algorithm; it should fit in a 60-minute time frame, including making the code work and passing a few test cases.

    First, write a recursive function instead of an iterative one; the code is simple.

    Second, recursion is a depth-first search, so treat the problem as a graph search.

    Third, I would argue that "find the lowest common ancestor" may be a more complicated algorithm, so it is better not to build on it.

    To find the path from the root to a search node, the function looks for one node at a time. Running it twice has the same time complexity as finding both nodes in one pass.

    Time complexity is O(n) node visits, where n is the total number of nodes in the binary tree. Use a preorder traversal: visit the root first, then the left and right children. (The code below copies the path list at every node it visits, which adds a factor of the tree height h, so as written it runs in O(n*h) time.)

    Here is my C# practice code.

    using System;
    using System.Collections.Generic;
    using System.Diagnostics;
    using System.Linq;
    using System.Text;
    using System.Threading.Tasks;
    
    namespace BinarySearchTreeTwoNodesDistance
    {
        class Program
        {
            internal class Node
            {
                public int Value { get; set; }
                public Node Left { get; set; }
                public Node Right { get; set; }
    
                public Node(int number)
                {
                    Value = number;
                }
            }
    
            static void Main(string[] args)
            {
                // calculate two nodes distance
                RunTestcase();
            }
    
            /// <summary>
            /// inorder traversal - 1 2 3 4 5 6 7
            /// </summary>
            public static void RunTestcase()
            {
                var root = new Node(4);
                root.Left = new Node(2);
                root.Left.Left = new Node(1);
                root.Left.Right = new Node(3);
                root.Right = new Node(6);
                root.Right.Left = new Node(5);
                root.Right.Right = new Node(7);
    
                // distance should be 4 
                var distance = FindDistance(root, root.Left.Right, root.Right.Right);
                Debug.Assert(distance == 4);
    
                var distance2 = FindDistance(root, root.Left.Right, root.Left.Left);
                Debug.Assert(distance2 == 2); 
            }
    
            public static int FindDistance(Node root, Node p, Node q)
            {
                IList<Node> possiblePath_1 = new List<Node>();
                IList<Node> possiblePath_2 = new List<Node>();
    
                IList<Node> searchPath_1 = new List<Node>();
                IList<Node> searchPath_2 = new List<Node>();
    
                FindPath(root, p, possiblePath_1,ref searchPath_1);
                FindPath(root, q, possiblePath_2,ref searchPath_2);
    
                if (searchPath_1.Count == 0 || searchPath_2.Count == 0)
                {
                    return 0; 
                }
    
                // find first node not equal 
                int index = 0;
                int length1 = searchPath_1.Count;
                int length2 = searchPath_2.Count;
    
                while (index < Math.Min(length1, length2) &&
                    searchPath_1[index] == searchPath_2[index])
                {
                    index++;
                }
    
                //(length1 - 1) + (length2 - 1) - (2 * (index - 1))
                return length1 + length2 - 2 * index;
            }
    
            /// <summary>
            /// Do a preorder search for the node
            /// </summary>
            /// <param name="root"></param>
            /// <param name="search"></param>
            /// <param name="possiblePath"></param>
            /// <param name="searchPath"></param>
            public static void FindPath(Node root, Node search, IList<Node> possiblePath, ref IList<Node> searchPath)
            {
                if (root == null || searchPath.Count > 0)
                {
                    return;
                }
    
                if (root == search)
                {
                    searchPath = possiblePath;
                    searchPath.Add(search);
                    return;
                }
    
                possiblePath.Add(root);
                IList<Node> leftBranch  = new List<Node>(possiblePath);
                IList<Node> rightBranch = new List<Node>(possiblePath);
    
                FindPath(root.Left, search, leftBranch, ref searchPath);
                FindPath(root.Right, search, rightBranch,ref  searchPath);
            }
        }
    }
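A minimal Python sketch of the same root-to-node path idea (a transcription under my own assumptions, not the poster's code): it appends to a single list and pops on the way back (backtracking) instead of copying the path at each node, so each search is O(n) even on a skewed tree. Returning 0 for a missing node mirrors the C# version above; `find_path` and `find_distance` are illustrative names.

```python
class TreeNode:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_path(node, target, path):
    """Preorder search. On success returns True and path holds root..target."""
    if node is None:
        return False
    path.append(node)
    if node is target:
        return True
    if find_path(node.left, target, path) or find_path(node.right, target, path):
        return True
    path.pop()  # backtrack: target is not in this subtree
    return False


def find_distance(root, p, q):
    path_p, path_q = [], []
    if not find_path(root, p, path_p) or not find_path(root, q, path_q):
        return 0  # same convention as the C# code when a node is missing
    # Count the shared prefix: the common ancestors, ending with the LCA.
    common = 0
    shorter = min(len(path_p), len(path_q))
    while common < shorter and path_p[common] is path_q[common]:
        common += 1
    return len(path_p) + len(path_q) - 2 * common


# Same test tree as RunTestcase(): inorder traversal 1 2 3 4 5 6 7.
root = TreeNode(4,
                TreeNode(2, TreeNode(1), TreeNode(3)),
                TreeNode(6, TreeNode(5), TreeNode(7)))
print(find_distance(root, root.left.right, root.right.right))  # 4
print(find_distance(root, root.left.right, root.left.left))  # 2
```

The returned value is the same expression as the C# `length1 + length2 - 2 * index`, since both paths include the root and the target node.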
