Given a Binary Search Tree, Find the distance between 2 nodes


  • 1
    O

    Write a function that, given a BST, returns the distance (number of edges) between 2 nodes.

    For example, given this tree

             A
            / \
           B   C
          / \   \
         D   E   F
        /         \
       G           H
    

    The distance between G and F is 3: [G -> D -> B -> E]

    The distance between E and H is 6: [G -> D -> B -> A -> C -> F -> H]


  • 1
    Y

    The examples seem a bit wrong. The path from G to F should be: G->D->B->A->C->F, and the path from E to H should be E->B->A->C->F->H.

    The obvious solution uses a bottom-up approach, which takes O(n) time in the worst case. I think a similar problem on LeetCode is the one called "maximum path sum".

    Pseudocode (without much verification):

    def findDistance(TreeNode root, TreeNode p, TreeNode q):
        if root == null or p == q:
            return 0

        distance_left = findDistance(root.left, p, q)
        distance_right = findDistance(root.right, p, q)

        if (distance_left > 0 and distance_right > 0):
            return distance_left + distance_right
        if (distance_left > 0 and root is p or q):
            return distance_left
        if (distance_right > 0 and root is p or q):
            return distance_right
        if (distance_left == 0 and distance_right == 0):
            return 0 if root is not p and not q
            return 1 if root is p or q
        else:
            return max(distance_left, distance_right) + 1
    

  • 0
    A

    A couple of flaws in your algorithm: 1) The last line wrongly assumes that the current node (i.e. root in the method argument) is always on the resulting path, hence your algorithm will yield 3 if the input is D and E in the above tree. 2) Your algorithm cannot handle the case where one of the target nodes is not in the tree. The solution: you need to record the number of matches found as well.
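
    The "record the number of matches" idea can be sketched roughly as follows. This is only an illustration in Python, not the code from the post above: `TreeNode`, `find_distance`, and `build_example` are hypothetical names, nodes are compared by identity, and returning -1 for a missing node is an assumed convention.

```python
class TreeNode:
    """Hypothetical node type for illustration; nodes are compared by identity."""

    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def find_distance(root, p, q):
    """Return the number of edges between p and q, or -1 if either is absent."""

    def visit(node):
        # Returns (matches, dist). With one match, dist is the number of edges
        # from node down to that match; with two, dist is the final answer.
        if node is None:
            return 0, 0
        lc, ld = visit(node.left)
        if lc == 2:
            return lc, ld
        rc, rd = visit(node.right)
        if rc == 2:
            return rc, rd
        here = int(node is p) + int(node is q)
        count = lc + rc + here
        if count == 2:
            # The two partial paths meet at this node.
            return 2, (ld + 1 if lc else 0) + (rd + 1 if rc else 0)
        if count == 1:
            if here:
                return 1, 0
            return 1, (ld if lc else rd) + 1
        return 0, 0

    matches, dist = visit(root)
    return dist if matches == 2 else -1


def build_example():
    """The tree from the original post, as a dict of nodes keyed by letter."""
    n = {c: TreeNode(c) for c in "ABCDEFGH"}
    n["A"].left, n["A"].right = n["B"], n["C"]
    n["B"].left, n["B"].right = n["D"], n["E"]
    n["C"].right = n["F"]
    n["D"].left = n["G"]
    n["F"].right = n["H"]
    return n
```

    On the example tree this gives 2 for D and E (D->B->E) and 5 for G and F, and a node that is not in the tree yields -1 because only one match is ever counted.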


  • 0
    D

    You have to modify the solution for the lowest common ancestor (LCA).


  • 0
    Y

    Last line returns the current tentative distance if only one of the nodes is found. Since I'm using a bottom-up approach, if only one node is found in a subtree and the other is not found, in the end the path must go through the current node (root). I don't see a problem returning 3 if the input is D and E (D->B->E). My solution does not handle the case when a node or both nodes are not in the tree.


  • 0
    J

    C# solution: first find the LCA, then compute the depths root->LCA, root->node1, and root->node2; the distance is root->node1 + root->node2 - 2 * root->LCA.

        class MyTreeNode
        {
            public int Data { get; set; }
            public MyTreeNode Left { get; set; }
            public MyTreeNode Right { get; set; }
            public MyTreeNode(int data)
            {
                this.Data = data;
            }
        }
    
        class QTwoNodeDis
        {
            public static int Distance(MyTreeNode root, MyTreeNode node1, MyTreeNode node2)
            {
                var node = FindLCA(root, node1, node2);
                int distLCA = FindLevel(root, node);
                int dist1 = FindLevel(root, node1);
                int dist2 = FindLevel(root, node2);
    
                return dist1 + dist2 - 2 * distLCA;
            }
    
            private static MyTreeNode FindLCA(MyTreeNode root, MyTreeNode node1, MyTreeNode node2)
            {
                if (root == null) return null;
                // Return as soon as one of the two nodes is found
                if (root.Data == node1.Data || root.Data== node2.Data)
                    return root;
                // Search for the two nodes in the left and right subtrees
    
                MyTreeNode left_lca = FindLCA(root.Left, node1, node2);
                MyTreeNode right_lca = FindLCA(root.Right, node1, node2);
    
                if (left_lca != null && right_lca != null)
                    // Here the two nodes must be in the left and right subtrees respectively, so the current node is the LCA
                    return root;
                return left_lca != null ? left_lca : right_lca;
            }
    
            private static int FindLevel(MyTreeNode root, MyTreeNode node)
            {
                if (root == null)
                    return -1;
                if(root.Data == node.Data)
                    return 0;
    
                int level = FindLevel(root.Left, node);
    
                if (level == -1)
                    level = FindLevel(root.Right, node);
    
                if(level != -1)
                    return level + 1;
    
                return -1;
            }
        }
    

  • 0
    M

    I don't think the example is a BST...

    If the problem input is a BST, the solution above is OK, but if it is just a Binary Tree, what should we do?


  • 0
    P

    @mikealive
    I think this will work if it's a BST. Can you think of any case where this code will fail?

    
    import java.util.*;
    public class DistanceBetweenBSTNode{
    
    
    public int distance(TreeNode root, TreeNode node1,TreeNode node2){
    	if(root == null || node1 == node2)
    		return 0;
    	if(root.val >= node1.val && root.val <= node2.val)
    		return distance(root,node1) + distance(root,node2);
    	else if(root.val < node1.val && root.val < node2.val)
    		return distance(root.right,node1,node2);
    	else{
    		return distance(root.left,node1,node2);
    	}
    }
    
    public int distance(TreeNode root, TreeNode node1){
    	if(root == node1)
    		return 0;
    	if(root.val < node1.val){
    		return 1+ distance(root.right,node1);
    	}
    	else{
    		return 1+ distance(root.left,node1);
    	}
    }
    
    
    public static void main(String[] args){
    	DistanceBetweenBSTNode dbb = new DistanceBetweenBSTNode();
    	TreeNode root = new TreeNode(5);
    	root.left = new TreeNode(3);
    	root.right =  new TreeNode(10);
    	root.right.right = new TreeNode(11);
    	root.right.left = new TreeNode(8);
    	//dbb.inorder(root);
    	System.out.println(dbb.distance(root,root,root.right.right));
    	
    }
    }
     

  • 0
    L

    @prashantkadam88-gmail.com

    I like your solution, but there is an error in this line:

    if(root.val >= node1.val && root.val <= node2.val)

    The fix is:

    if((root.val >= node1.val && root.val <= node2.val) || (root.val <= node1.val && root.val >= node2.val))

    You should include both cases.

    Greetings!


  • 0
    P

    @lupinekaupz
    Thanks.
    Actually we can use Math.min to decide on left or right path.
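
    One way to read the `Math.min` suggestion is to order the two keys first, so a single range test finds the node where the two search paths split. The sketch below is in Python rather than Java and works on key values instead of node references; `TreeNode`, `depth_below`, and `bst_distance` are hypothetical names, and it assumes distinct keys and raises `ValueError` when a key is missing.

```python
class TreeNode:
    """Hypothetical BST node for illustration; keys are assumed distinct."""

    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def depth_below(node, key):
    """Edges from node down to the node holding key, following BST order."""
    steps = 0
    while node is not None and node.val != key:
        node = node.right if node.val < key else node.left
        steps += 1
    if node is None:
        raise ValueError(f"key {key!r} is not in the tree")
    return steps


def bst_distance(root, a, b):
    """Edges between the nodes holding keys a and b in a BST."""
    lo, hi = min(a, b), max(a, b)  # ordering the keys covers both cases
    node = root
    while node is not None:
        if node.val < lo:
            node = node.right      # both keys lie in the right subtree
        elif node.val > hi:
            node = node.left       # both keys lie in the left subtree
        else:
            # lo <= node.val <= hi: the search paths split here (the LCA)
            return depth_below(node, lo) + depth_below(node, hi)
    raise ValueError("neither key is in the tree")


def build_example():
    """The test tree from the Java post: 5 -> (3, 10 -> (8, 11))."""
    return TreeNode(5, TreeNode(3), TreeNode(10, TreeNode(8), TreeNode(11)))
```

    With the tree from the Java `main` above, `bst_distance(root, 5, 11)` and `bst_distance(root, 11, 5)` both give 2, so the argument order no longer matters.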


  • 0
    H
    	
    	static int findDistanceForNodes(Node n1, Node n2, Node root) {
    		if(root==null || n1 ==null || n2==null)
    			return 0;
    		Temp temp = new Temp();
    		int n = findPathForNodes(root,n1,n2,temp)-1;
    		return (temp.foundNode1&&temp.foundNode2)?n:0;
    	}
    
    	private static int findPathForNodes(Node node, Node n1, Node n2,Temp temp) {
    		if(node==null || temp.foundLen)
    			return 0;
    		if(node==n1) 
    			temp.foundNode1 = true;
    		if(node==n2)
    			temp.foundNode2 = true;
    		int left = findPathForNodes(node.left,n1,n2,temp);
    		int right = findPathForNodes(node.right,n1,n2,temp);
    		if((left>0 && right>0) || ((left>0 || right>0) && (node==n1 || node==n2))) {
    			temp.foundLen = true;
    			return left+right+1;
    		}
    		if((left>0 || right>0 || node==n1 || node==n2) && !temp.foundLen)
    			return left+right+1;
    		
    		return Math.max(left, right);
    	}
    }
    
    class Temp {
    	boolean foundNode1;
    	boolean foundNode2;
    	boolean foundLen;
    }

  • 0
    K

    Inspired by @yuhao5; I fixed a couple of flaws.

    #!/usr/bin/env python
    # -*- coding: utf-8 -*-
    # Find distance between two given keys of a Binary Tree
    # Author: zhang.xiaoye@gmail.com
    
    class TreeNode(object):
        def __init__(self, val, left=None, right=None):
            self.val = val
            self.left = left
            self.right = right
    
    class BTree(object):
    
        def __init__(self):
            self.root = TreeNode(1)
            node = TreeNode(2)
            self.root.left = node
            node = TreeNode(3)
            self.root.right = node
            node = TreeNode(4)
            self.root.left.left = node
            node = TreeNode(5)
            self.root.left.right = node
            node = TreeNode(6)
            self.root.right.left = node
            node = TreeNode(7)
            self.root.right.right = node
            node = TreeNode(8)
            self.root.right.left.right = node
    
        def getBTHead(self):
            return self.root
    
    class BTreeDistance(object):
    
        dis = 0
    
        def findDistance(self, root, p, q):
            if not root:
                return 0
    
            ldepth = self.findDistance(root.left, p, q)
            rdepth = self.findDistance(root.right, p, q)
    
            if root.val == p or root.val == q:
                if max(ldepth, rdepth) != 0:
                    # The given nodes are ancestor and descendant, e.g. 4 and 2
                    self.dis = max(ldepth, rdepth)
                # level + 1
                return 1
            if ldepth != 0 and rdepth == 0 or ldepth == 0 and rdepth != 0:
                # If exactly one of the left/right subtrees has a nonzero level, add 1 and return
                return max(ldepth, rdepth) + 1
    
            if ldepth > 0 and rdepth > 0:
                # If both the left and right levels are nonzero, add them and store the result
                self.dis = ldepth + rdepth
    
            return 0
        
        def getResult(self, root, p, q):
            self.findDistance(root, p, q)
            return self.dis
    
    tree = BTree()
    root = tree.getBTHead()
    
    bd = BTreeDistance()
    print(bd.getResult(root, 4, 5))
    

  • 0

    I looked for an algorithm that fits within a 60-minute time limit, including time to make the code work and pass a few test cases.

    First, write a recursive function instead of an iterative one; the code is simpler.

    Second, the recursion is a depth-first search, so treat the problem as a graph search.

    Third, I would argue that "find lowest common ancestor" may be a more complicated algorithm, so it is better not to build on it.

    To find the path from the root node to a search node, the function is designed to find one node at a time. The time complexity should be the same as finding both nodes in one pass.

    Time complexity is O(n), where n is the total number of nodes in the binary tree. Use preorder traversal: visit the root first, then the left and right children.

    Here is my C# practice code.

    using System;
    using System.Collections.Generic;
    using System.Diagnostics;
    using System.Linq;
    using System.Text;
    using System.Threading.Tasks;
    
    namespace BinarySearchTreeTwoNodesDistance
    {
        class Program
        {
            internal class Node
            {
                public int Value { get; set; }
                public Node Left { get; set; }
                public Node Right { get; set; }
    
                public Node(int number)
                {
                    Value = number;
                }
            }
    
            static void Main(string[] args)
            {
                // calculate two nodes distance
                RunTestcase();
            }
    
            /// <summary>
            /// inorder traversal - 1 2 3 4 5 6 7
            /// </summary>
            public static void RunTestcase()
            {
                var root = new Node(4);
                root.Left = new Node(2);
                root.Left.Left = new Node(1);
                root.Left.Right = new Node(3);
                root.Right = new Node(6);
                root.Right.Left = new Node(5);
                root.Right.Right = new Node(7);
    
                // distance should be 4 
                var distance = FindDistance(root, root.Left.Right, root.Right.Right);
                Debug.Assert(distance == 4);
    
                var distance2 = FindDistance(root, root.Left.Right, root.Left.Left);
                Debug.Assert(distance2 == 2); 
            }
    
            public static int FindDistance(Node root, Node p, Node q)
            {
                IList<Node> possiblePath_1 = new List<Node>();
                IList<Node> possiblePath_2 = new List<Node>();
    
                IList<Node> searchPath_1 = new List<Node>();
                IList<Node> searchPath_2 = new List<Node>();
    
                FindPath(root, p, possiblePath_1,ref searchPath_1);
                FindPath(root, q, possiblePath_2,ref searchPath_2);
    
                if (searchPath_1.Count == 0 || searchPath_2.Count == 0)
                {
                    return 0; 
                }
    
                // find first node not equal 
                int index = 0;
                int length1 = searchPath_1.Count;
                int length2 = searchPath_2.Count;
    
                while (index < Math.Min(length1, length2) &&
                    searchPath_1[index] == searchPath_2[index])
                {
                    index++;
                }
    
                //(length1 - 1) + (length2 - 1) - (2 * (index - 1))
                return length1 + length2 - 2 * index;
            }
    
            /// <summary>
            /// Do a preorder search for the node
            /// </summary>
            /// <param name="root"></param>
            /// <param name="search"></param>
            /// <param name="possiblePath"></param>
            /// <param name="searchPath"></param>
            public static void FindPath(Node root, Node search, IList<Node> possiblePath, ref IList<Node> searchPath)
            {
                if (root == null || searchPath.Count > 0)
                {
                    return;
                }
    
                if (root == search)
                {
                    searchPath = possiblePath;
                    searchPath.Add(search);
                    return;
                }
    
                possiblePath.Add(root);
                IList<Node> leftBranch  = new List<Node>(possiblePath);
                IList<Node> rightBranch = new List<Node>(possiblePath);
    
                FindPath(root.Left, search, leftBranch, ref searchPath);
                FindPath(root.Right, search, rightBranch,ref  searchPath);
            }
        }
    }

  • 1
    B

    dist(a,b) = Dist(root , a)+Dist(root,b) - 2* Dist(root , lca)

    // Edges from root down to the node holding a, or -1 if a is absent.
    int find_dist(Node *root, int a)
    {
        if (root == NULL) return -1;
        int dist = -1;
        if (root->data == a ||
            (dist = find_dist(root->left, a)) >= 0 ||
            (dist = find_dist(root->right, a)) >= 0)
            return dist + 1;

        return dist;
    }

    // Lowest common ancestor of the nodes holding a and b (assumes both are present).
    Node *lca(Node *root, int a, int b)
    {
        if (root == NULL)
            return root;
        if (root->data == a || root->data == b)
            return root;

        Node *L = lca(root->left, a, b);
        Node *R = lca(root->right, a, b);

        if (L != NULL && R != NULL)
            return root;
        return L == NULL ? R : L;
    }

    int findDist(Node *root, int a, int b)
    {
        int n1 = find_dist(root, a);
        int n2 = find_dist(root, b);
        if (n1 < 0 || n2 < 0)
            return -1;  // at least one value is not in the tree
        Node *temp = lca(root, a, b);
        int n3 = find_dist(root, temp->data);
        return n1 + n2 - 2 * n3;
    }
    

  • 1

    Since it is a BST, we can find both nodes in O(h) time, where h is the tree height (O(log n) when the tree is balanced), and record their search paths. Then compare the recorded paths to locate the LCA. The time and space complexity are both O(h), given that the binary search tree is already built.
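
    A minimal sketch of this path-recording idea, in Python. The names `TreeNode`, `search_path`, and `distance_from_paths` are hypothetical, and the code assumes distinct keys and uses -1 as the result when a key is missing.

```python
class TreeNode:
    """Hypothetical BST node for illustration; keys are assumed distinct."""

    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def search_path(root, key):
    """Keys visited by a BST search for key, or an empty list if key is absent."""
    path = []
    node = root
    while node is not None:
        path.append(node.val)
        if key == node.val:
            return path
        node = node.left if key < node.val else node.right
    return []


def distance_from_paths(root, a, b):
    """Edges between the nodes holding a and b, or -1 if either is absent."""
    path_a, path_b = search_path(root, a), search_path(root, b)
    if not path_a or not path_b:
        return -1
    # The shared prefix of the two paths ends at the LCA.
    common = 0
    for x, y in zip(path_a, path_b):
        if x != y:
            break
        common += 1
    # What remains of each path after the shared prefix lies below the LCA.
    return (len(path_a) - common) + (len(path_b) - common)


def build_example():
    """A small BST whose inorder traversal is 1 2 3 4 5 6 7."""
    return TreeNode(4,
                    TreeNode(2, TreeNode(1), TreeNode(3)),
                    TreeNode(6, TreeNode(5), TreeNode(7)))
```

    Each search walks a single root-to-node path, so both the work and the stored paths are bounded by the tree height.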


  • 0
    D

    @JohnsonJiang said in Given a Binary Search Tree, Find the distance between 2 nodes:

    then distance = root->node1 + root->node2 - 2 * root->LCA;

    Why don't you just calculate the distance from LCA to node1 and node2, which means distance = LCA->node1 + LCA->node2?
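
    That variant could look roughly like the sketch below, written in Python for a general binary tree. `TreeNode`, `lca`, `depth_from`, and `distance_via_lca` are hypothetical names, nodes are compared by identity, and -1 for a missing node is an assumed convention.

```python
class TreeNode:
    """Hypothetical node type for illustration; nodes are compared by identity."""

    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def lca(root, p, q):
    """Lowest common ancestor of p and q, assuming both are in the tree."""
    if root is None or root is p or root is q:
        return root
    left = lca(root.left, p, q)
    right = lca(root.right, p, q)
    if left is not None and right is not None:
        return root
    return left if left is not None else right


def depth_from(node, target):
    """Edges from node down to target, or -1 if target is not below node."""
    if node is None:
        return -1
    if node is target:
        return 0
    for child in (node.left, node.right):
        d = depth_from(child, target)
        if d >= 0:
            return d + 1
    return -1


def distance_via_lca(root, p, q):
    """distance = LCA->p + LCA->q, or -1 if either node is missing."""
    ancestor = lca(root, p, q)
    d1 = depth_from(ancestor, p)
    d2 = depth_from(ancestor, q)
    if d1 < 0 or d2 < 0:
        return -1
    return d1 + d2


def build_example():
    """The tree from the original post, as a dict of nodes keyed by letter."""
    n = {c: TreeNode(c) for c in "ABCDEFGH"}
    n["A"].left, n["A"].right = n["B"], n["C"]
    n["B"].left, n["B"].right = n["D"], n["E"]
    n["C"].right = n["F"]
    n["D"].left = n["G"]
    n["F"].right = n["H"]
    return n
```

    Both depth searches start at the LCA, so only its subtree is searched and the separate root->LCA pass is not needed.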

