I think the solution below is very easy to understand:

```
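#include <cstddef>  // NULL

// Assumed node definition: the usual TreeLinkNode from the problem statement.
struct TreeLinkNode {
    int val;
    TreeLinkNode *left, *right, *next;
    TreeLinkNode(int x) : val(x), left(NULL), right(NULL), next(NULL) {}
};

// r1 and r2 are adjacent nodes on the same level, r1 to the left of r2.
void fix(TreeLinkNode *r1, TreeLinkNode *r2) {
    if (r1 != NULL && r2 != NULL) {
        r1->next = r2;              // link the pair itself
        fix(r1->right, r2->left);   // link across the gap between the two subtrees
        fix(r1->left, r1->right);   // solve r1's subtree
        fix(r2->left, r2->right);   // solve r2's subtree
    }
}

void connect(TreeLinkNode *root) {
    if (root != NULL)               // root->next stays NULL
        fix(root->left, root->right);
}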
```
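To try the solution out, here is a small self-contained harness. It is a sketch that assumes the usual TreeLinkNode layout, since the snippet above relies on it without showing it, and it repeats fix and connect so it compiles on its own. build, destroy and connectedLevels are hypothetical helpers: they build a perfect tree numbered 1, 2, 3, ... in level order, connect it, and then read each level back using only the next pointers.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Assumed node layout (TreeLinkNode as in the original problem); next starts as NULL.
struct TreeLinkNode {
    int val;
    TreeLinkNode *left, *right, *next;
    TreeLinkNode(int x) : val(x), left(NULL), right(NULL), next(NULL) {}
};

// The solution from the post, unchanged apart from formatting.
void fix(TreeLinkNode *r1, TreeLinkNode *r2) {
    if (r1 != NULL && r2 != NULL) {
        r1->next = r2;
        fix(r1->right, r2->left);
        fix(r1->left, r1->right);
        fix(r2->left, r2->right);
    }
}

void connect(TreeLinkNode *root) {
    if (root != NULL)
        fix(root->left, root->right);
}

// Hypothetical helper: builds a perfect tree with `depth` levels, numbering
// nodes in level order (1, 2, 3, ...) like a binary heap.
TreeLinkNode *build(int idx, int depth) {
    assert(depth >= 0);
    if (depth == 0) return NULL;
    TreeLinkNode *n = new TreeLinkNode(idx);
    n->left = build(2 * idx, depth - 1);
    n->right = build(2 * idx + 1, depth - 1);
    return n;
}

// Frees every node of the tree.
void destroy(TreeLinkNode *n) {
    if (n == NULL) return;
    destroy(n->left);
    destroy(n->right);
    delete n;
}

// Hypothetical helper: builds a tree, connects it, then reads every level
// back by starting at its leftmost node and following only next pointers.
std::vector<std::vector<int>> connectedLevels(int depth) {
    TreeLinkNode *root = build(1, depth);
    connect(root);
    std::vector<std::vector<int>> out;
    for (TreeLinkNode *first = root; first != NULL; first = first->left) {
        std::vector<int> row;
        for (TreeLinkNode *cur = first; cur != NULL; cur = cur->next)
            row.push_back(cur->val);
        out.push_back(row);
    }
    destroy(root);
    return out;
}
```

For example, connectedLevels(3) returns {{1}, {2, 3}, {4, 5, 6, 7}}: every level comes back complete and left to right, and the last node on each level keeps next == NULL.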

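Looking at a parent, we link **parent->left->next** to **parent->right**, which is the **r1->next = r2** line. The right child of r1 and the left child of r2 are neighbours too, so **fix(r1->right, r2->left)** links them, and recursively every pair further down along the gap between the two subtrees. That alone does not work, because the nodes inside each subtree also have to be linked, so we also call fix on each node's own children: **fix(r1->left, r1->right); fix(r2->left, r2->right);**

This relies on the tree being perfect (every non-leaf node has two children and all leaves are on the same level). In any other tree some of these children are NULL, and the links that should skip over them are never made.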

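Note that r1 and r2 are the root's left and right children only in the first call. In general they are two adjacent nodes on the same level, with r1 to the left of r2; in **fix(r1->right, r2->left)**, for example, they are cousins rather than siblings.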
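One caveat: the recursion does more work than it looks. A node's two children get linked once for every adjacent pair the node belongs to (inside fix(r1, r2), both fix(r1->left, r1->right) and fix(r1->right, r2->left) go on to call fix on r1->right's children). So each level runs the body of fix exactly three times as often as the level above. On a perfect tree with D levels that is (3^(D-1) - 1)/2 link writes, roughly n^1.58 for n nodes, even though only 2^D - D - 1 pairs need linking. The sketch below counts the writes. For comparison it includes a different, commonly used recursion (connectOnce, a swapped-in technique, not the author's method) that links only each node's own children and reaches the neighbouring subtree through the parent's already-set next pointer, so every link is written exactly once. countWrites, build and destroy are hypothetical helpers, and the TreeLinkNode layout is assumed.

```cpp
#include <cassert>
#include <cstddef>

// Assumed node layout (TreeLinkNode as in the original problem); next starts as NULL.
struct TreeLinkNode {
    int val;
    TreeLinkNode *left, *right, *next;
    TreeLinkNode(int x) : val(x), left(NULL), right(NULL), next(NULL) {}
};

static long writes = 0;  // how many times a next pointer has been assigned

// The solution from the post, instrumented with a counter.
void fix(TreeLinkNode *r1, TreeLinkNode *r2) {
    if (r1 != NULL && r2 != NULL) {
        r1->next = r2;
        ++writes;
        fix(r1->right, r2->left);
        fix(r1->left, r1->right);
        fix(r2->left, r2->right);
    }
}

void connect(TreeLinkNode *root) {
    if (root != NULL)
        fix(root->left, root->right);
}

// Swapped-in alternative (not from the post): each node links only its own
// children, using its already-set next pointer to reach the neighbouring subtree.
void connectOnce(TreeLinkNode *root) {
    if (root == NULL || root->left == NULL) return;  // empty tree or leaf
    root->left->next = root->right;
    ++writes;
    if (root->next != NULL) {
        root->right->next = root->next->left;
        ++writes;
    }
    connectOnce(root->left);   // pre-order: children run after their parent's next is set
    connectOnce(root->right);
}

// Hypothetical helper: perfect tree with `depth` levels, numbered 1, 2, 3, ... in level order.
TreeLinkNode *build(int idx, int depth) {
    assert(depth >= 0);
    if (depth == 0) return NULL;
    TreeLinkNode *n = new TreeLinkNode(idx);
    n->left = build(2 * idx, depth - 1);
    n->right = build(2 * idx + 1, depth - 1);
    return n;
}

// Frees every node of the tree.
void destroy(TreeLinkNode *n) {
    if (n == NULL) return;
    destroy(n->left);
    destroy(n->right);
    delete n;
}

// Hypothetical helper: runs `algo` on a fresh tree and returns the number of
// next writes, or -1 if walking the next pointers level by level does not
// visit 1, 2, 3, ..., 2^depth - 1 in order.
long countWrites(void (*algo)(TreeLinkNode *), int depth) {
    TreeLinkNode *root = build(1, depth);
    writes = 0;
    algo(root);
    long result = writes;
    int expected = 1;
    for (TreeLinkNode *first = root; first != NULL; first = first->left)
        for (TreeLinkNode *cur = first; cur != NULL; cur = cur->next)
            if (cur->val != expected++) result = -1;
    if (expected != (1 << depth)) result = -1;
    destroy(root);
    return result;
}
```

With 10 levels, countWrites(connect, 10) returns 9841 while countWrites(connectOnce, 10) returns 1013, and both pass the level-order check. The original is still correct; it just repeats work that the parent's next pointer makes unnecessary.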