```
class Solution {
public:
    // Deepest node found so far whose subtree contains both p and q.
    // Reset by lowestCommonAncestor before each search.
    TreeNode *ans = nullptr;

    // Counts how many of the target nodes {p, q} occur in the subtree rooted
    // at r (including r itself). A node equal to both p and q counts once.
    // As a side effect, records in `ans` the first node (in post-order, hence
    // the deepest) whose count reaches 2, unless `ans` is already set.
    // Returns 0 for an empty subtree.
    int find(TreeNode *r, TreeNode *p, TreeNode *q) {
        // Check for an empty subtree first: comparing a null `r` against a
        // null target would otherwise count every null child as a match.
        if (r == nullptr) {
            return 0;
        }
        int ret = (r == p || r == q) ? 1 : 0;
        ret += find(r->left, p, q);
        // Each tree node is reached only once, so once both targets have been
        // seen the right subtree cannot contain either; skipping it leaves
        // the returned count unchanged.
        if (ret < 2) {
            ret += find(r->right, p, q);
        }
        if (ret == 2 && ans == nullptr) {
            ans = r;
        }
        return ret;
    }

    // Returns the lowest common ancestor of p and q in the tree rooted at
    // root, where a node counts as an ancestor of itself. Returns nullptr if
    // root is null or a target does not occur in the tree.
    TreeNode* lowestCommonAncestor(TreeNode* root, TreeNode* p, TreeNode* q) {
        ans = nullptr;
        const int found = find(root, p, q);
        // When p == q, find counts the shared node only once, so `ans` is never
        // set; that node is its own lowest common ancestor.
        if (ans == nullptr && p == q && found == 1) {
            ans = p;
        }
        return ans;
    }
};
```