```
class Solution:
    """Flatten a binary tree into a right-linked chain, in place."""

    # @param root, a tree node
    # @return nothing, do it in place
    def flatten(self, root):
        """Rewire the tree rooted at ``root`` into a preorder linked list.

        After the call, every node's ``left`` is ``None`` and following
        ``right`` pointers from ``root`` visits the nodes in the tree's
        original preorder. Nodes only need ``left`` and ``right``
        attributes.

        The work is done iteratively (no recursion, constant extra space),
        so arbitrarily deep or skewed trees do not hit the interpreter's
        recursion limit.

        Args:
            root: The root node of the tree, or ``None`` for an empty tree.

        Returns:
            None. The tree is modified in place.
        """
        node = root
        while node is not None:
            if node.left is not None:
                # The rightmost node of the left subtree is where the
                # original right subtree must be attached, so that it
                # follows everything from the left subtree.
                tail = node.left
                while tail.right is not None:
                    tail = tail.right
                tail.right = node.right
                # Move the left subtree to the right and clear the left
                # link; the moved subtree is flattened as the walk
                # continues down the right pointers.
                node.right = node.left
                node.left = None
            node = node.right
```