```
/* Forward declaration: flatten() calls dfs(), which is defined after it. */
TreeNode *dfs(TreeNode *root);

/*
 * Flattens the binary tree rooted at `root` in place into a right-linked
 * list in preorder: each node's right pointer leads to the next node in
 * preorder and every left pointer is set to NULL. A NULL root is a no-op.
 */
void flatten(TreeNode *root) {
    dfs(root);
}
/*
 * Flattens the subtree rooted at `root` in place into a right-linked list
 * in preorder (node, then its flattened left subtree, then its flattened
 * right subtree). Every left pointer is set to NULL; the last node's right
 * pointer is NULL.
 *
 * Returns `root` (the head of the list), or NULL when `root` is NULL.
 *
 * Iterative splice: for each node on the growing right chain, its left
 * subtree is inserted between the node and its old right subtree. Each node
 * is walked by the inner loop at most once, so this runs in O(n) time with
 * O(1) extra space. There is no recursion, so deep (degenerate) trees
 * cannot exhaust the call stack.
 */
TreeNode *dfs(TreeNode *root) {
    for (TreeNode *cur = root; cur != NULL; cur = cur->right) {
        if (cur->left == NULL) {
            continue;
        }

        /*
         * Find the end of the left subtree's right spine. Hanging the old
         * right subtree there keeps it after every left-subtree node in
         * preorder; any left children along the way are spliced in when
         * `cur` reaches them.
         */
        TreeNode *pred = cur->left;
        while (pred->right != NULL) {
            pred = pred->right;
        }

        pred->right = cur->right;
        cur->right = cur->left;
        cur->left = NULL;
    }
    return root;
}
```