```
/*
 * Prototype for the traversal helper used below. Declared here so the call
 * sees a matching declaration; calling an undeclared function has been
 * invalid since C99.
 */
int dfs(TreeNode *root, TreeNode *p, TreeNode *q, TreeNode **s);

/**
 * Find the lowest common ancestor of nodes p and q in the tree rooted at root.
 *
 * The search itself is done by dfs(), which reports the answer through its
 * out-parameter; its integer return value is not needed here.
 *
 * @param root  root of the tree to search (may be NULL).
 * @param p     first target node, compared by address.
 * @param q     second target node, compared by address.
 * @return the node stored by dfs(), or NULL if dfs() never stored one.
 */
TreeNode *lowestCommonAncestor(TreeNode *root, TreeNode *p, TreeNode *q)
{
    TreeNode *ancestor = NULL;

    dfs(root, p, q, &ancestor);
    return ancestor;
}
/**
 * Post-order search that counts how many of the targets p and q lie in the
 * subtree rooted at root and records their lowest common ancestor.
 *
 * @param root  subtree to search (may be NULL).
 * @param p     first target node, compared by address.
 * @param q     second target node, compared by address.
 * @param s     out-parameter; set to the first node (in post-order) whose
 *              subtree contains both targets. Left untouched otherwise.
 * @return -1 once the ancestor has been stored in *s (signals callers to
 *         stop searching), otherwise the number of targets found in this
 *         subtree (0 or 1).
 */
int dfs(TreeNode *root, TreeNode *p, TreeNode *q, TreeNode **s)
{
    if (root == NULL) {
        return 0;
    }

    int left = dfs(root->left, p, q, s);
    if (left == -1) {
        return -1; /* answer already recorded deeper in the left subtree */
    }

    int right = dfs(root->right, p, q, s);
    if (right == -1) {
        return -1; /* answer already recorded deeper in the right subtree */
    }

    /*
     * Count each target separately so that p == q is handled: the node then
     * matches both targets and is its own lowest common ancestor.
     */
    int found = left + right + (root == p) + (root == q);
    if (found == 2) {
        *s = root;
        return -1;
    }
    return found;
}
```