Super easy to understand Java iterative merge sort using O(1) space

  • 5

    I found most iterative merge sort solutions hard to read, so I decided to write my own and make it as readable as I could.

    The basic idea is simple: bottom-up merge sort. Plenty of threads have covered this so I won't elaborate.
    However, a few observations make it easier to build your code.

    1. At the very beginning, we have sorted sublists of length 1. After each pass we have sorted sublists of length 2, 4, 8, ..., 2^i. When 2^i >= length of the input list, we are done. So we can use this fact for our main loop.

    2. At the ith pass, only the last sorted sublist may have length less than 2^{i-1}. All the other sorted sublists have length exactly 2^{i-1}. So we can use this fact to get the heads of all sublists.
      I wrote a function getListTail(head, len). Given the head of a sublist, we can walk (len - 1) steps to get to its tail. The head of the next sublist is just tail.next.
      Why do we want the tail? Because when we merge two sublists, we want to set tail.next = null.

    3. When you have an even number of sublists, you can always pair them. When the number is odd, you can just leave the last one in place; it will eventually be merged with another sublist in a later pass. At the last pass, we always merge two sorted sublists, although the second one may not have length 2^k for some k.

    4. So in each pass, we repeatedly take two sorted sublists and merge them until either (1) we are left with only one sublist, or (2) we are left with nothing.

    5. When merging two adjacent sublists, you want to set the next pointers of their tails to null. Then after merging, you want to connect the merged list's head to the previously merged pairs and its tail to the rest of the unmerged sublists. So keep pointers to both ends. Also, your merge function should return both the head and tail of the merged list.
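To make observations 1 to 3 concrete, the sketch below computes only the lengths of the sorted sublists at the start of each pass, without touching any list nodes. The class and method names (`SublistLayoutSketch`, `runLengths`, `allPasses`) are made up for illustration and are not part of the solution.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper: shows how a list of n nodes is carved into sorted sublists per pass.
class SublistLayoutSketch {

    // Lengths of the sorted sublists at the start of a pass that merges runs of `width` nodes.
    // Every run has length `width` except possibly the last one.
    static List<Integer> runLengths(int n, int width) {
        List<Integer> lengths = new ArrayList<>();
        for (int start = 0; start < n; start += width) {
            lengths.add(Math.min(width, n - start));
        }
        return lengths;
    }

    // One entry per pass, for widths 1, 2, 4, ... while width < n (the main-loop condition).
    // Assumes n is small enough that doubling `width` does not overflow an int.
    static List<List<Integer>> allPasses(int n) {
        List<List<Integer>> passes = new ArrayList<>();
        for (int width = 1; width < n; width *= 2) {
            passes.add(runLengths(n, width));
        }
        return passes;
    }
}
```

For n = 5, `allPasses(5)` gives `[[1, 1, 1, 1, 1], [2, 2, 1], [4, 1]]`: three passes, an odd leftover sublist in the first two, and exactly two sublists (of lengths 4 and 1) in the last.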

    The iterative solution consumes O(1) extra space since it doesn't use collections or recursion.
    My implementation is slower than my recursive version. I think that's reasonable because of the tradeoff between space and speed: recursion uses O(log n) stack memory, part of which marks the boundaries of the sublists, whereas here we have to find the boundaries "manually", setting pointers to null and reconnecting the pieces.
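For comparison, a top-down recursive merge sort on a linked list could look roughly like the sketch below. This is not the recursive version mentioned above (that code is not shown in the post); it is a common formulation, with a made-up `Node` class standing in for LeetCode's `ListNode`. Each recursive call remembers where its half starts, which is the O(log n) stack memory the iterative version trades for extra pointer walking.

```java
// Hypothetical sketch of a recursive (top-down) merge sort, for comparison only.
class RecursiveSortSketch {

    // Stand-in for LeetCode's ListNode, so that the sketch is self-contained.
    static class Node {
        int val;
        Node next;
        Node(int val) { this.val = val; }
    }

    static Node sort(Node head) {
        if (head == null || head.next == null) return head;
        // Find the middle with slow/fast pointers; `slow` ends on the tail of the first half.
        Node slow = head, fast = head.next;
        while (fast != null && fast.next != null) {
            slow = slow.next;
            fast = fast.next.next;
        }
        Node second = slow.next;
        slow.next = null;                      // cut the list into two halves
        return merge(sort(head), sort(second));
    }

    // Merge two sorted lists and return the head of the result.
    static Node merge(Node a, Node b) {
        Node dummy = new Node(0);
        Node tail = dummy;
        while (a != null && b != null) {
            if (a.val <= b.val) { tail.next = a; a = a.next; }
            else                { tail.next = b; b = b.next; }
            tail = tail.next;
        }
        tail.next = (a != null) ? a : b;       // append whatever is left
        return dummy.next;
    }
}
```

Here the call stack keeps track of the sublist boundaries; the iterative code that follows has to recompute them on every pass with getListTail.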

    public class Solution {
        class ListNodePair {
            ListNode head;
            ListNode tail;
            ListNodePair(ListNode h, ListNode t) { head = h; tail = t; }
        }

        public ListNode sortList(ListNode head) {
            if (head == null || head.next == null) return head;
            int len = 0;
            for (ListNode runner = head; runner != null; runner = runner.next) len++;
            int lenOfList = 1;              //length of sorted sublists: at first we start with 1
            while (lenOfList < len) {       //once lenOfList >= len, the whole list is sorted
                head = mergeLayer(head, lenOfList);
                lenOfList *= 2;             //each iteration doubles the length of sorted sublists
            }
            return head;
        }

        //merge every two sublists of length lenOfList, assuming each sublist is already sorted
        private ListNode mergeLayer(ListNode head, int lenOfList) {
            ListNode fakehead = new ListNode(0);
            ListNode merge_tail = fakehead;   //merge_tail points to the tail of the merged part of this layer
            ListNode first;
            ListNode first_tail;
            ListNode second;
            ListNode second_tail;
            ListNodePair pair;
            while (head != null) {
                first = head;
                first_tail = getListTail(first, lenOfList);
                second = first_tail.next;
                if (second == null) {             //we have only 1 sorted sublist left
                    merge_tail.next = head;       //link the sorted part to the last sorted sublist
                    break;
                }
                second_tail = getListTail(second, lenOfList);
                head = second_tail.next;          //now we have two suitable sublists, point head to the rest
                first_tail.next = null;
                second_tail.next = null;
                pair = merge(first, second);
                merge_tail.next = pair.head;      //link the old sorted part to the newly sorted part
                merge_tail = pair.tail;           //update the end of the sorted part
            }
            return fakehead.next;
        }

        //get the tail of the list with head 'head' and length 'len' (or at most len)
        private ListNode getListTail(ListNode head, int len) {
            while (len > 1 && head.next != null) {
                head = head.next;
                len--;
            }
            return head;
        }

        //merge two sorted lists, return both the head and tail of the new list
        private ListNodePair merge(ListNode l1, ListNode l2) {
            ListNode fakehead = new ListNode(0);
            ListNode tail = fakehead;
            while (l1 != null && l2 != null) {
                if (l1.val < l2.val) {
                    tail.next = l1;
                    l1 = l1.next;
                } else {
                    tail.next = l2;
                    l2 = l2.next;
                }
                tail = tail.next;
            }
            tail.next = (l1 == null) ? l2 : l1;           //append the remaining nodes
            while (tail.next != null) tail = tail.next;   //walk to the real tail
            return new ListNodePair(fakehead.next, tail);
        }
    }
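The solution relies on the `ListNode` class that LeetCode supplies. To try it locally, a minimal stand-in along the lines below may help. The `ListNode` fields mirror the usual LeetCode definition; `ListNodeUtil`, `fromArray` and `toArray` are hypothetical helper names introduced here only for testing.

```java
// Minimal stand-in for the ListNode class that LeetCode provides.
class ListNode {
    int val;
    ListNode next;
    ListNode(int x) { val = x; }
}

// Hypothetical test helpers, not part of the solution.
class ListNodeUtil {

    // Build a linked list holding the given values in order; returns null for no values.
    static ListNode fromArray(int... values) {
        ListNode dummy = new ListNode(0);
        ListNode tail = dummy;
        for (int v : values) {
            tail.next = new ListNode(v);
            tail = tail.next;
        }
        return dummy.next;
    }

    // Copy the values of a linked list into an array, for easy comparison in tests.
    static int[] toArray(ListNode head) {
        int n = 0;
        for (ListNode p = head; p != null; p = p.next) n++;
        int[] out = new int[n];
        int i = 0;
        for (ListNode p = head; p != null; p = p.next) out[i++] = p.val;
        return out;
    }
}
```

With these in place, the result of a call such as `new Solution().sortList(ListNodeUtil.fromArray(4, 2, 1, 3))` can be passed to `toArray` and compared against the expected order.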

  • 0

    Thanks for the detailed explanation!

  • 0

    Thank you for your easy-to-understand code!

  • 0

    Thanks for the thorough explanation! Your code is already pretty easy to read, but I wrote out the Python version for others:

    class Solution(object):
        def sortList(self, head):
            if not head or not head.next:
                return head
            cur = head
            length = 0
            while cur:
                cur = cur.next
                length += 1
            lenOfList = 1
            while lenOfList < length:
                head = self.mergeLayer(head, lenOfList)
                lenOfList *= 2
            return head

        def mergeLayer(self, head, lenOfList):
            fakehead = ListNode(0)
            merge_tail = fakehead
            while head:
                first = head
                first_tail = self.getListTail(first, lenOfList)
                second = first_tail.next
                if not second:
                    # only one sublist left: link it to the merged part and stop
                    merge_tail.next = head
                    break
                second_tail = self.getListTail(second, lenOfList)
                head = second_tail.next
                first_tail.next = None
                second_tail.next = None
                pair = self.merge(first, second)
                merge_tail.next = pair[0]
                merge_tail = pair[1]
            return fakehead.next

        def getListTail(self, head, length):
            while length > 1 and head.next:
                head = head.next
                length -= 1
            return head

        def merge(self, l1, l2):
            fakehead = ListNode(0)
            tail = fakehead
            while l1 and l2:
                if l1.val < l2.val:
                    tail.next = l1
                    l1 = l1.next
                else:
                    tail.next = l2
                    l2 = l2.next
                tail = tail.next
            tail.next = l2 if not l1 else l1
            # walk to the real tail of the merged list
            while tail.next:
                tail = tail.next
            return fakehead.next, tail
