Given a binary search tree, find the distance between two nodes


The examples seem a bit wrong. The path from G to F should be G->D->B->A->C->F, and the path from E to H should be E->B->A->C->F->H.
The obvious solution uses a bottom-up approach, which takes O(n) time in the worst case. A similar problem on LeetCode is called "maximum path sum", I think.
Pseudocode (without much verification):
    def findDistance(TreeNode root, TreeNode p, TreeNode q):
        if root == null or p == q:
            return 0
        distance_left = findDistance(root.left, p, q)
        distance_right = findDistance(root.right, p, q)
        if distance_left > 0 and distance_right > 0:
            return distance_left + distance_right
        if distance_left > 0 and (root is p or root is q):
            return distance_left
        if distance_right > 0 and (root is p or root is q):
            return distance_right
        if distance_left == 0 and distance_right == 0:
            return 0 if (root is not p and root is not q) else 1
        else:
            return max(distance_left, distance_right) + 1

A couple of flaws in your algorithm:
1) The last line wrongly assumes that the current node (i.e. root in the method arguments) is always on the resulting path, so your algorithm will yield 3 if the input is D and E in the above tree.
2) Your algorithm cannot handle the case where one of the target nodes is not in the tree.
The solution: you also need to record the number of matches found.


The last line returns the current tentative distance when only one of the nodes has been found. Since I'm using a bottom-up approach, if only one node is found in a subtree and the other is not, the path must eventually go through the current node (root). I don't see a problem with returning 3 if the input is D and E (D->B->E). My solution does not handle the case where one or both nodes are not in the tree.
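The fix suggested in the reply above (record how many of the two targets each subtree has matched) can be sketched in Java roughly as follows. It is still a single bottom-up pass in O(n) time, but once both targets have been met the finished distance is passed upward unchanged instead of being incremented again, and a result of -1 signals that a target is missing. The `Node` class, its `label` field, `Result`, `walk` and `BottomUpDistance` are hypothetical names for this sketch, and the sample tree is only inferred from the paths quoted at the top of the thread (which side each child hangs on is a guess; the distances do not depend on it).

```java
public class BottomUpDistance {
    // Minimal binary-tree node (hypothetical; the thread never shows one).
    static final class Node {
        final String label;
        Node left, right;
        Node(String label) { this.label = label; }
    }

    // What a subtree reports upward: how many of the two targets it contains
    // (0, 1 or 2) and, depending on that count, either the distance from the
    // subtree's root down to the single target found, or the finished answer.
    private static final class Result {
        final int found;
        final int dist;
        Result(int found, int dist) { this.found = found; this.dist = dist; }
    }

    // Number of edges between p and q, or -1 if either node is not in the tree.
    public static int distance(Node root, Node p, Node q) {
        Result res = walk(root, p, q);
        return res.found == 2 ? res.dist : -1;
    }

    private static Result walk(Node node, Node p, Node q) {
        if (node == null) return new Result(0, 0);
        Result l = walk(node.left, p, q);
        if (l.found == 2) return l;             // finished below: pass it up unchanged
        Result r = walk(node.right, p, q);
        if (r.found == 2) return r;
        int self = (node == p ? 1 : 0) + (node == q ? 1 : 0);
        if (self == 2) return new Result(2, 0); // p and q are the same node
        // Distance from this node to the lone target in each child subtree (0 if none).
        int viaLeft = l.found == 1 ? l.dist + 1 : 0;
        int viaRight = r.found == 1 ? r.dist + 1 : 0;
        // found == 2: this node is the LCA; if it is itself a target, the side
        // without the other target contributes 0.
        // found == 1: dist is the distance from this node to the target seen so far.
        return new Result(self + l.found + r.found, viaLeft + viaRight);
    }

    public static void main(String[] args) {
        Node a = new Node("A"), b = new Node("B"), c = new Node("C"), d = new Node("D");
        Node e = new Node("E"), f = new Node("F"), g = new Node("G"), h = new Node("H");
        a.left = b; a.right = c; b.left = d; b.right = e; d.left = g; c.right = f; f.left = h;
        System.out.println(distance(a, g, f));              // 5
        System.out.println(distance(a, e, h));              // 5
        System.out.println(distance(a, d, e));              // 2
        System.out.println(distance(a, d, new Node("X")));  // -1
    }
}
```

With D and E as input this returns 2 rather than 3, because the value computed at B is no longer incremented on the way up to A.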

C# solution: first find the LCA, then compute the depths (levels) of the LCA, node1 and node2 measured from the root. The distance is then dist(root, node1) + dist(root, node2) - 2 * dist(root, LCA).
    class MyTreeNode
    {
        public int Data { get; set; }
        public MyTreeNode Left { get; set; }
        public MyTreeNode Right { get; set; }
        public MyTreeNode(int data) { this.Data = data; }
    }

    class QTwoNodeDis
    {
        public static int Distance(MyTreeNode root, MyTreeNode node1, MyTreeNode node2)
        {
            var node = FindLCA(root, node1, node2);
            int distLCA = FindLevel(root, node);
            int dist1 = FindLevel(root, node1);
            int dist2 = FindLevel(root, node2);
            return dist1 + dist2 - 2 * distLCA;
        }

        // Assumes both nodes are in the tree and Data values are unique.
        private static MyTreeNode FindLCA(MyTreeNode root, MyTreeNode node1, MyTreeNode node2)
        {
            if (root == null)
                return null;
            // Return as soon as either of the two nodes is found
            if (root.Data == node1.Data || root.Data == node2.Data)
                return root;
            // Look for the two nodes in the left and right subtrees
            MyTreeNode left_lca = FindLCA(root.Left, node1, node2);
            MyTreeNode right_lca = FindLCA(root.Right, node1, node2);
            if (left_lca != null && right_lca != null)
                // The two nodes are in different subtrees, so the current node must be the LCA
                return root;
            return left_lca != null ? left_lca : right_lca;
        }

        // Depth of node below root (number of edges), or -1 if node is not in the tree
        private static int FindLevel(MyTreeNode root, MyTreeNode node)
        {
            if (root == null)
                return -1;
            if (root.Data == node.Data)
                return 0;
            int level = FindLevel(root.Left, node);
            if (level == -1)
                level = FindLevel(root.Right, node);
            if (level != -1)
                return level + 1;
            return -1;
        }
    }
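As a worked check of that formula, take the G-to-F path quoted at the top of the thread, G->D->B->A->C->F. Its topmost node A is the LCA of G and F; measuring depth from A gives depth(G) = 3 and depth(F) = 2:

```latex
\[
d(G,F) = \operatorname{depth}(G) + \operatorname{depth}(F) - 2\,\operatorname{depth}(A)
       = 3 + 2 - 2 \cdot 0 = 5
\]
```

This matches the five edges of that path. If A is not the root of the whole tree, every depth grows by the same amount and the extra terms cancel.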


@mikealive
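I think this will work if it's a BST. Did you think of any case where this code will fail?

    public class DistanceBetweenBSTNode {
        // Minimal node type so the snippet compiles on its own (assumed; not in the original post).
        static class TreeNode {
            int val;
            TreeNode left, right;
            TreeNode(int val) { this.val = val; }
        }

        // Distance between node1 and node2.
        // Assumes node1.val <= node2.val and that both nodes are in the tree.
        public int distance(TreeNode root, TreeNode node1, TreeNode node2) {
            if (root == null || node1 == node2) return 0;
            if (root.val >= node1.val && root.val <= node2.val)
                // root lies between the two keys, so it is their lowest common ancestor
                return distance(root, node1) + distance(root, node2);
            else if (root.val < node1.val && root.val < node2.val)
                return distance(root.right, node1, node2);  // both keys are in the right subtree
            else {
                return distance(root.left, node1, node2);   // both keys are in the left subtree
            }
        }

        // Number of edges from root down to node1, following the BST ordering.
        public int distance(TreeNode root, TreeNode node1) {
            if (root == node1) return 0;
            if (root.val < node1.val) {
                return 1 + distance(root.right, node1);
            } else {
                return 1 + distance(root.left, node1);
            }
        }

        public static void main(String[] args) {
            DistanceBetweenBSTNode dbb = new DistanceBetweenBSTNode();
            TreeNode root = new TreeNode(5);
            root.left = new TreeNode(3);
            root.right = new TreeNode(10);
            root.right.right = new TreeNode(11);
            root.right.left = new TreeNode(8);
            //dbb.inorder(root);
            System.out.println(dbb.distance(root, root, root.right.right));  // prints 2
        }
    }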
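One case where the code above fails is when the arguments arrive in the other order: in its own sample tree, dbb.distance(root, root.right.right, root) drifts into the left subtree and returns 0 instead of 2. A possible way to drop that precondition, sketched in Java under the assumption of distinct keys: sort the two keys first, descend until the keys split (that node is the LCA), then measure each key's depth below it, returning -1 if either key is absent. It runs in O(h) time. Taking keys instead of node references is a design choice of this sketch, and `BstDistance` and `depthFrom` are hypothetical names.

```java
public class BstDistance {
    // Minimal BST node (assumed, mirroring the TreeNode used in the post above).
    static final class TreeNode {
        final int val;
        TreeNode left, right;
        TreeNode(int val) { this.val = val; }
    }

    // Edges between the nodes holding keys a and b, or -1 if either key is absent.
    // Works for either argument order; assumes all keys are distinct.
    public static int distance(TreeNode root, int a, int b) {
        int lo = Math.min(a, b), hi = Math.max(a, b);
        TreeNode split = root;
        // Descend while both keys lie on the same side; the first node whose
        // key falls in [lo, hi] is the lowest common ancestor.
        while (split != null && (hi < split.val || lo > split.val)) {
            split = hi < split.val ? split.left : split.right;
        }
        if (split == null) return -1;
        int da = depthFrom(split, a), db = depthFrom(split, b);
        return (da < 0 || db < 0) ? -1 : da + db;
    }

    // Number of edges from `from` down to the key, or -1 if the key is absent.
    private static int depthFrom(TreeNode from, int key) {
        int depth = 0;
        TreeNode cur = from;
        while (cur != null) {
            if (key == cur.val) return depth;
            cur = key < cur.val ? cur.left : cur.right;
            depth++;
        }
        return -1;
    }

    public static void main(String[] args) {
        TreeNode root = new TreeNode(5);
        root.left = new TreeNode(3);
        root.right = new TreeNode(10);
        root.right.left = new TreeNode(8);
        root.right.right = new TreeNode(11);
        System.out.println(distance(root, 5, 11));  // 2
        System.out.println(distance(root, 11, 5));  // 2 (argument order no longer matters)
        System.out.println(distance(root, 4, 11));  // -1 (4 is not in the tree)
    }
}
```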