Here is my Java solution. I'm wondering if it could be made simpler.

The idea is to keep three pointers: `node`, which searches for nodes that are smaller than x; `parent`, the node just before `node`; and `before`, which marks where the next small node should be inserted (i.e. the parent of the first encountered element that is >= x). I had to add an if statement for the case where the first element is already >= x. In that case `before` is null, so the first small node has to be moved to the front by updating the head instead.

```
// ListNode is assumed to be the node type supplied by the problem (fields: val, next).
public class Solution {
    public ListNode partition(ListNode head, int x) {
        if (head == null) {
            return null;
        }
        ListNode node = head;    // search pointer
        ListNode parent = null;  // node just before the search pointer
        ListNode before = null;  // last node of the "< x" prefix, or null if there is none yet
        // search for the first node with val >= x
        while (node != null && node.val < x) {
            before = node;
            parent = node;
            node = node.next;
        }
        while (node != null) {
            if (node.val < x) {
                if (before != null) {
                    // unlink node and re-insert it right after before
                    parent.next = node.next;
                    node.next = before.next;
                    before.next = node;
                    node = parent.next;
                    before = before.next;
                } else {
                    // the first node is already >= x, so there is no "before" node yet:
                    // unlink node and make it the new head
                    parent.next = node.next;
                    node.next = head;
                    head = node;
                    before = head;
                    node = parent.next;
                }
            } else {
                parent = node;
                node = node.next;
            }
        }
        return head;
    }
}
```
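
One possible simplification, as a rough sketch rather than a definitive answer: if a sentinel (dummy) node is placed in front of the head, `before` is never null, so the extra if statement for the "first element is already >= x" case might be dropped. The `ListNode` class below is my assumption of what the problem provides, and `PartitionSketch`, `build`, and `render` are hypothetical names used only to make the example self-contained.

```java
// Assumed shape of the node type supplied by the problem.
class ListNode {
    int val;
    ListNode next;
    ListNode(int val) { this.val = val; }
}

// Hypothetical class name; the same three-pointer idea, with a sentinel in front of head.
class PartitionSketch {
    static ListNode partition(ListNode head, int x) {
        ListNode dummy = new ListNode(0); // sentinel, never part of the result
        dummy.next = head;

        // before: last node of the "< x" prefix (the sentinel if the prefix is empty)
        ListNode before = dummy;
        while (before.next != null && before.next.val < x) {
            before = before.next;
        }

        // parent: node just before the one being examined
        ListNode parent = before;
        while (parent.next != null) {
            ListNode node = parent.next;
            if (node.val < x) {
                // unlink node and re-insert it right after before
                parent.next = node.next;
                node.next = before.next;
                before.next = node;
                before = node;
            } else {
                parent = node;
            }
        }
        return dummy.next;
    }

    // Helper for trying it out: builds a list from the given values.
    static ListNode build(int... vals) {
        ListNode dummy = new ListNode(0);
        ListNode tail = dummy;
        for (int v : vals) {
            tail.next = new ListNode(v);
            tail = tail.next;
        }
        return dummy.next;
    }

    // Helper for trying it out: renders a list as space-separated values.
    static String render(ListNode head) {
        StringBuilder sb = new StringBuilder();
        for (ListNode n = head; n != null; n = n.next) {
            if (sb.length() > 0) sb.append(' ');
            sb.append(n.val);
        }
        return sb.toString();
    }
}
```

For example, `PartitionSketch.render(PartitionSketch.partition(PartitionSketch.build(1, 4, 3, 2, 5, 2), 3))` gives `1 2 2 4 3 5`. The trade-off is one extra node allocation per call in exchange for a single code path.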