```
/// Flattens a binary tree in place into a right-leaning chain.
///
/// After flattening, every node's `left` pointer is null and following the
/// `right` pointers visits the nodes in pre-order (node, then its former left
/// subtree, then its former right subtree). No nodes are allocated or freed;
/// only the `left` / `right` links of existing nodes are rewritten.
class Solution {
public:
    /// Flattens the tree rooted at `root` in place.
    ///
    /// @param root Root of the tree to flatten; may be null, in which case
    ///             nothing happens.
    void flatten(TreeNode* root) {
        recursiveFlatten(root);
    }

    /// Flattens the subtree rooted at `root` and returns the tail of the chain.
    ///
    /// Recursion depth is proportional to the height of the subtree.
    ///
    /// @param root Root of the subtree to flatten; may be null.
    /// @return The last node of the flattened chain (its `left` and `right`
    ///         are both null), or nullptr when `root` is null.
    TreeNode* recursiveFlatten(TreeNode* root) {
        if (root == nullptr) {
            return nullptr;
        }

        // Capture both children up front: root->right is overwritten below.
        TreeNode* const leftChild = root->left;
        TreeNode* const rightChild = root->right;

        // A leaf is already a one-element chain, so it is its own tail.
        TreeNode* tail = root;

        if (leftChild != nullptr) {
            // Move the left subtree to the right side, flatten it, then hang
            // the original right subtree (possibly null) off the end of it.
            root->left = nullptr;
            root->right = leftChild;
            tail = recursiveFlatten(leftChild);
            tail->right = rightChild;
        }

        if (rightChild != nullptr) {
            // The original right subtree now follows everything flattened so
            // far; its tail is the tail of the whole chain.
            tail = recursiveFlatten(rightChild);
        }

        return tail;
    }
};
```