```
/**
 * Flattens the binary tree rooted at `root` in place into a right-leaning
 * chain in pre-order (node, then its left subtree, then its right subtree).
 * On return every node's `left` is nullptr and `right` points to its
 * pre-order successor (nullptr for the last node).
 *
 * Iterative splice: for each node on the chain that still has a left child,
 * the left subtree is moved into `right`, and the node's original right
 * subtree is hung off the right-most node of that left subtree. A node that
 * has been walked over is already on the main right chain and is never part
 * of a left subtree again, so every node is walked at most once: O(n) time,
 * O(1) extra space, and no recursion (a very deep, degenerate tree cannot
 * overflow the call stack).
 *
 * @param root  Root of the tree to flatten; may be nullptr.
 * @return      `root` itself (the pointer is unchanged), or nullptr when the
 *              tree is empty.
 */
TreeNode *flat(TreeNode *root)
{
    for (TreeNode *cur = root; cur != nullptr; cur = cur->right) {
        if (cur->left == nullptr) {
            continue;  // nothing to splice; just advance along the chain
        }
        // Right-most node of the left subtree. Once cur's right subtree is
        // hung off it, it is followed (after its own left subtree, which is
        // spliced later) by cur's original right subtree, as pre-order requires.
        TreeNode *tail = cur->left;
        while (tail->right != nullptr) {
            tail = tail->right;
        }
        tail->right = cur->right;
        cur->right = cur->left;
        cur->left = nullptr;
    }
    return root;
}
/// Entry point: flattens the tree rooted at `root` in place into pre-order,
/// linked only through `right` pointers (every `left` becomes nullptr).
class Solution {
public:
    void flatten(TreeNode* root) {
        // flat() rewires the nodes in place and already handles nullptr.
        // Its return value is just `root`, so there is nothing to keep.
        // Assigning it back to this by-value parameter would have no effect
        // on the caller.
        flat(root);
    }
};
```