# Super easy to understand Java iterative merge sort using O(1) space

• I found most iterative merge sort solutions hard to read, so I decided to write my own and tried to make it as readable as possible.

The basic idea is simple: bottom-up merge sort. Plenty of threads have covered this so I won't elaborate.
However, a few observations make it easier to build your code.

1. At the very beginning, every node is a sorted sublist of length 1. After each pass the sorted sublists double in length: 2, 4, 8, ..., 2^i. Once 2^i >= the length of the input list, we are done. This gives us the condition for the main loop.

2. At the start of the i-th pass, every sorted sublist has length exactly 2^{i-1}, except possibly the last one, which may be shorter. We can use this fact to find the heads of all sublists.
I wrote a function getListTail(head, len). Given the head of a sublist, we walk (len - 1) steps to reach its tail, stopping early if the list ends. The head of the next sublist is just tail.next.
Why do we want the tail? Because before merging two sublists we have to cut them out of the list by setting tail.next = null.

3. When the number of sublists is even, you can always pair them up. When it is odd, just leave the last one alone; it will eventually be merged with another sublist in a later pass. In the last pass we always merge exactly two sorted sublists, although the second one may not have length 2^k for any k.

4. So within each pass, we repeatedly take two sorted sublists and merge them until either (1) only one sublist is left, or (2) nothing is left.

5. When merging two adjacent sublists, first set the next pointers of their tails to null. After merging, connect the head of the merged list to the previously merged pairs and its tail to the rest of the unmerged sublists. So keep pointers to both ends, and have your merge function return both the head and the tail of the merged list.
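
Observations 1 to 3 can be checked on paper before touching any pointers. The sketch below is only an illustration: `PassShape`, `sublistLengths`, and `passes` are hypothetical names that do not appear in the solution, and it works on plain lengths rather than on list nodes. For each pass it lists the lengths of the sorted sublists that the pass starts with.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper for illustration only; not part of the solution.
class PassShape {
    // Lengths of the sorted sublists at the start of a pass with the given width.
    static List<Integer> sublistLengths(int n, int width) {
        List<Integer> lengths = new ArrayList<>();
        for (int start = 0; start < n; start += width) {
            // Only the last sublist can be shorter than width.
            lengths.add(Math.min(width, n - start));
        }
        return lengths;
    }

    // One entry per pass, using the same loop condition as the main loop:
    // keep doubling the width while it is smaller than the list length.
    static List<List<Integer>> passes(int n) {
        List<List<Integer>> result = new ArrayList<>();
        for (int width = 1; width < n; width *= 2) {
            result.add(sublistLengths(n, width));
        }
        return result;
    }
}
```

For a list of 5 nodes, `PassShape.passes(5)` gives `[[1, 1, 1, 1, 1], [2, 2, 1], [4, 1]]`. In the first two passes the number of sublists is odd, so the last one is left alone; in the last pass the two remaining sublists (lengths 4 and 1) are merged.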

The iterative solution uses O(1) extra space since it doesn't use collections or recursion.
My implementation is slower than my recursive version. I think that's reasonable given the tradeoff between space and speed: recursion uses O(log n) stack memory, and part of that memory implicitly marks the boundaries of the sublists. Here we have to find the boundaries "manually", cutting sublists by setting pointers to null and reconnecting them afterwards.
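
The recursive version is not shown in this post. For comparison, here is a minimal sketch of what a typical top-down merge sort on a linked list looks like. It is an assumption about the general approach, not the author's actual code; `RecursiveSort` is a made-up name, and `ListNode` is defined here only so the block compiles on its own.

```java
// Assumed to match the usual LeetCode node shape.
class ListNode {
    int val;
    ListNode next;
    ListNode(int val) { this.val = val; }
}

// Hypothetical top-down version, for comparison only.
class RecursiveSort {
    static ListNode sortList(ListNode head) {
        if (head == null || head.next == null) return head;
        // Find the node before the middle with slow/fast pointers.
        ListNode slow = head, fast = head.next;
        while (fast != null && fast.next != null) {
            slow = slow.next;
            fast = fast.next.next;
        }
        ListNode second = slow.next;
        slow.next = null;               // cut the list into two halves
        // The call stack remembers where each half belongs: O(log n) depth.
        return merge(sortList(head), sortList(second));
    }

    static ListNode merge(ListNode l1, ListNode l2) {
        ListNode fakehead = new ListNode(0);
        ListNode tail = fakehead;
        while (l1 != null && l2 != null) {
            if (l1.val <= l2.val) { tail.next = l1; l1 = l1.next; }
            else                  { tail.next = l2; l2 = l2.next; }
            tail = tail.next;
        }
        tail.next = (l1 != null) ? l1 : l2;   // no need to walk to the tail here
        return fakehead.next;
    }
}
```

Note that this merge can stop as soon as one input runs out, while the iterative merge also has to walk to the tail of the result so that the next pair can be linked after it. That extra walking may account for part of the speed difference.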

```java
public class Solution {
    // Holds both ends of a merged list so it can be linked on either side.
    class ListNodePair{
        ListNode head;
        ListNode tail;
        ListNodePair(ListNode h, ListNode t){   head = h;   tail = t;   }
    }
    public ListNode sortList(ListNode head) {
        int len = 0;
        for(ListNode runner = head; runner != null; runner = runner.next)   len++;
        int lenOfList = 1;              //length of sorted sublists: at first we start with 1
        while(lenOfList < len){         //once lenOfList >= len, the whole list is sorted
            head = mergeLayer(head, lenOfList);
            lenOfList *= 2;             //each pass doubles the length of the sorted sublists
        }
        return head;
    }
    //merge every two sublists of length lenOfList, assuming each sublist is already sorted
    private ListNode mergeLayer(ListNode head, int lenOfList){
        ListNode fakehead = new ListNode(0);
        ListNode merge_tail = fakehead;   //merge_tail points to the tail of the merged part of this layer
        ListNode first;
        ListNode first_tail;
        ListNode second;
        ListNode second_tail;
        ListNodePair pair;
        while(head != null){                        //head points to the unmerged rest
            first = head;
            first_tail = getListTail(first, lenOfList);
            second = first_tail.next;
            if(second == null){                     //we have only 1 sorted sublist left,
                merge_tail.next = head;             //link the sorted part to the last sorted sublist
                break;
            }
            second_tail = getListTail(second, lenOfList);
            head = second_tail.next;    //now we have two suitable sublists, point head to the rest

            first_tail.next = null;     //cut both sublists out of the list
            second_tail.next = null;
            pair = merge(first, second);
            merge_tail.next = pair.head;  //link the old sorted part to the newly sorted part
            merge_tail = pair.tail;       //update the end of sorted part
        }
        return fakehead.next;
    }
    //get the tail of the list with head 'head' and length 'len' (or at most len)
    private ListNode getListTail(ListNode head, int len){
        while(len > 1 && head.next != null){
            head = head.next;
            len --;
        }
        return head;
    }
    //merge two sorted lists, return both the head and tail of the new list
    private ListNodePair merge(ListNode l1, ListNode l2){
        ListNode fakehead = new ListNode(0);
        ListNode tail = fakehead;
        while(l1 != null && l2 != null){
            if(l1.val < l2.val){
                tail.next = l1;
                l1 = l1.next;
            }else{
                tail.next = l2;
                l2 = l2.next;
            }
            tail = tail.next;
        }
        tail.next = (l1 == null) ? l2 : l1;         //append whatever is left
        while(tail.next != null)    tail = tail.next;   //walk to the real tail
        return new ListNodePair(fakehead.next, tail);
    }
}
```
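
To try the solution outside the online judge you need a `ListNode` class and some way to build and inspect lists. The sketch below is one possible harness, not something from the original post: `ListHarness`, `fromArray`, `toArray`, and `check` are hypothetical names, and `ListNode` is assumed to have the usual LeetCode shape (`val` and `next`).

```java
import java.util.Arrays;
import java.util.function.UnaryOperator;

// Assumption: same shape as the ListNode the judge normally supplies.
class ListNode {
    int val;
    ListNode next;
    ListNode(int val) { this.val = val; }
}

// Hypothetical local test helper.
class ListHarness {
    // Build a linked list holding the values in the given order.
    static ListNode fromArray(int[] values) {
        ListNode fakehead = new ListNode(0);
        ListNode tail = fakehead;
        for (int v : values) {
            tail.next = new ListNode(v);
            tail = tail.next;
        }
        return fakehead.next;
    }

    // Copy the list values into an array (assumes the list has no cycle).
    static int[] toArray(ListNode head) {
        int len = 0;
        for (ListNode n = head; n != null; n = n.next) len++;
        int[] out = new int[len];
        int i = 0;
        for (ListNode n = head; n != null; n = n.next) out[i++] = n.val;
        return out;
    }

    // True if the sorter returns exactly the input values in ascending order.
    static boolean check(UnaryOperator<ListNode> sorter, int[] values) {
        int[] expected = values.clone();
        Arrays.sort(expected);
        int[] actual = toArray(sorter.apply(fromArray(values)));
        return Arrays.equals(expected, actual);
    }
}
```

Assuming the class above exposes the usual LeetCode entry point `sortList(ListNode head)`, a call such as `ListHarness.check(new Solution()::sortList, new int[]{4, 2, 1, 3})` should return `true`. Lists whose length is not a power of two are worth trying, since they exercise the leftover-sublist case.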

• Thanks for the detailed explanation!

• Thank you for the easy-to-understand code!

• Thanks for the thorough explanation! Your code is already pretty easy to read, but I wrote out the Python version for others:

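```python
class Solution(object):
    def sortList(self, head):
        # Count the nodes first.
        length = 0
        cur = head
        while cur:
            cur = cur.next
            length += 1
        lenOfList = 1                       # length of the sorted sublists
        while lenOfList < length:
            head = self.mergeLayer(head, lenOfList)
            lenOfList *= 2
        return head

    # Merge every two sublists of length lenOfList.
    def mergeLayer(self, head, lenOfList):
        fakehead = ListNode(0)
        merge_tail = fakehead               # tail of the merged part of this layer
        while head:
            first = head
            first_tail = self.getListTail(first, lenOfList)
            second = first_tail.next
            if not second:                  # only one sorted sublist left
                merge_tail.next = head
                break
            second_tail = self.getListTail(second, lenOfList)
            head = second_tail.next         # head now points to the unmerged rest

            first_tail.next = None
            second_tail.next = None
            pair = self.merge(first, second)
            merge_tail.next = pair[0]
            merge_tail = pair[1]
        return fakehead.next

    # Tail of the sublist starting at head with at most 'length' nodes.
    def getListTail(self, head, length):
        while length > 1 and head.next:
            head = head.next
            length -= 1
        return head

    # Merge two sorted lists; return (head, tail) of the result.
    def merge(self, l1, l2):
        fakehead = ListNode(0)
        tail = fakehead
        while l1 and l2:
            if l1.val < l2.val:
                tail.next = l1
                l1 = l1.next
            else:
                tail.next = l2
                l2 = l2.next
            tail = tail.next
        tail.next = l2 if not l1 else l1
        while tail.next:
            tail = tail.next
        return (fakehead.next, tail)
```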