O(k) solution


  • 9

    Now that I can find the kth smallest element in a sorted n×n matrix in time O(min(n, k)), I can finally solve this problem in O(k).

    The idea:

    1. If nums1 or nums2 is longer than k, shrink it to its first k elements.
    2. Build a virtual matrix of the pair sums, i.e., matrix[i][j] = nums1[i] + nums2[j]. Make it a square matrix by padding with "infinity" if necessary. By "virtual" I mean that its entries are computed on the fly, and only the ones that are needed. This is necessary to stay within O(k) time.
    3. Find the kth smallest sum kthSum by using that other algorithm.
    4. Use a saddleback search variation to discount the pairs with sum smaller than kthSum. After this, k tells how many pairs we need whose sum equals kthSum.
    5. Collect all pairs with sum smaller than kthSum as well as k pairs whose sum equals kthSum.

    Each of those steps only takes O(k) time.
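
    Step 4 can be sketched in isolation as follows. This is a minimal illustration, not the code from the post: the function name count_pairs_below and the sample lists are assumptions made for the example. It counts the pairs whose sum is below a threshold, and because the index j only ever moves left, the whole scan takes O(m + n) steps.

```python
def count_pairs_below(nums1, nums2, threshold):
    """Count pairs (a, b) with a + b < threshold.

    Both lists are assumed to be sorted in ascending order.
    """
    count = 0
    j = len(nums2)  # nums2[:j] are the partners that are still small enough
    for a in nums1:
        # As a grows, the cutoff in nums2 can only move left, never right.
        while j and a + nums2[j - 1] >= threshold:
            j -= 1
        count += j
    return count


# Hypothetical sample input: the pair sums below 9 are 3, 5 and 7.
nums1, nums2 = [1, 7, 11], [2, 4, 6]
below = count_pairs_below(nums1, nums2, 9)
```

    With k = 4 and kthSum = 9 in this sample, k - below tells how many pairs with sum exactly 9 are still needed.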

    The code (minus the code for kthSmallest, which you can copy verbatim from my solution to the other problem):

    class Solution(object):
        def kSmallestPairs(self, nums1_, nums2_, k):
    
            # Use at most the first k of each, then get the sizes.
            nums1 = nums1_[:k]
            nums2 = nums2_[:k]
            m, n = len(nums1), len(nums2)
    
            # Gotta Catch 'Em All?
            if k >= m * n:
                return [[a, b] for a in nums1 for b in nums2]
            
            # Build a virtual matrix.
            N, inf = max(m, n), float('inf')
            class Row:
                def __init__(self, i):
                    self.i = i
                def __getitem__(self, j):
                    return nums1[self.i] + nums2[j] if self.i < m and j < n else inf
            matrix = [Row(i) for i in range(N)]  # a real list, so it is indexable on Python 3 too
    
            # Get the k-th sum.
            kthSum = self.kthSmallest(matrix, k)
    
            # Discount the pairs with sum smaller than the k-th.
            j = min(k, n)
            for a in nums1:
                while j and a + nums2[j-1] >= kthSum:
                    j -= 1
                k -= j
    
            # Collect and return the pairs.
            pairs = []
            for a in nums1:
                for b in nums2:
                    if a + b >= kthSum + (k > 0):
                        break
                    pairs.append([a, b])
                    k -= a + b == kthSum
            return pairs
    
        def kthSmallest(self, matrix, k):
            
            # copy & paste from https://discuss.leetcode.com/topic/53126/o-n-from-paper-yes-o-rows
    

    Thanks to @zhiqing_xiao for pointing out that my previous way of capping the input lists might not be O(k). It was this:

    def kSmallestPairs(self, nums1, nums2, k):
        del nums1[k:]
        del nums2[k:]
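
    To try the solution above without the linked kthSmallest code, a simple stand-in can be used. The sketch below is NOT the O(n) algorithm from the linked post: it is a plain heap-based search that takes O(k log k) time, so it gives up the O(k) bound and is only meant for testing. The name kth_smallest_standin is an assumption for this example. It relies only on len(matrix) and matrix[i][j], so it should also work with the virtual matrix.

```python
import heapq


def kth_smallest_standin(matrix, k):
    """Return the k-th smallest entry (1-based) of a square matrix whose
    rows and columns are sorted in ascending order.

    Heap-based stand-in, O(k log k); assumes 1 <= k <= n * n.
    """
    n = len(matrix)
    heap = [(matrix[0][0], 0, 0)]
    for _ in range(k - 1):
        _, i, j = heapq.heappop(heap)
        # The right neighbour becomes a candidate once this cell is consumed.
        if j + 1 < n:
            heapq.heappush(heap, (matrix[i][j + 1], i, j + 1))
        # A new row is opened only from the first column, so each cell
        # is pushed at most once.
        if j == 0 and i + 1 < n:
            heapq.heappush(heap, (matrix[i + 1][0], i + 1, 0))
    return heap[0][0]


# Hypothetical sample matrix with sorted rows and columns.
sample = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
eighth = kth_smallest_standin(sample, 8)
```

    Inside the Solution class, the body of kthSmallest could simply return kth_smallest_standin(matrix, k) while experimenting.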

  • 0

    wow, this is just wow....


  • 2

    I think the "del" operations need to be avoided to achieve the O(k) complexity, since the original lists may be very long.

    We might consider copying such as:

    nums1 = nums1[:k]
    nums2 = nums2[:k]
    

    or find another workaround.


  • 0

    @zhiqing_xiao Interesting point. I don't know how del is implemented or what its complexity is in this case; I thought it would be O(k). I ran an experiment where I applied del nums[10:] to increasingly large lists, and it looks like it's O(n) instead. See the table below. The first column is the list size n, the second is the time for del nums[10:], and the third is the time for your nums = nums[:10], which also looks like O(n) and is about as fast as my del. I guessed that the time isn't really caused by the shrinking itself but by Python deallocating the original list. That appears to be correct: if I keep a reference to the original with x = nums; nums = nums[:10], the time appears to be constant (fourth column). There is a noticeable increase at n=524288, but it seems to remain constant afterwards; I suspect that jump is a CPU cache issue.

           1 0.0000024 0.0000030 0.0000024
           2 0.0000024 0.0000027 0.0000021
           4 0.0000015 0.0000015 0.0000012
           8 0.0000012 0.0000015 0.0000009
          16 0.0000015 0.0000015 0.0000009
          32 0.0000015 0.0000009 0.0000003
          64 0.0000018 0.0000009 0.0000006
         128 0.0000027 0.0000015 0.0000006
         256 0.0000018 0.0000012 0.0000012
         512 0.0000030 0.0000018 0.0000006
        1024 0.0000039 0.0000030 0.0000006
        2048 0.0000066 0.0000057 0.0000006
        4096 0.0000249 0.0000111 0.0000006
        8192 0.0000300 0.0000219 0.0000006
       16384 0.0000574 0.0000454 0.0000009
       32768 0.0001373 0.0000859 0.0000006
       65536 0.0002809 0.0002028 0.0000015
      131072 0.0005817 0.0003888 0.0000009
      262144 0.0011667 0.0013596 0.0000009
      524288 0.0035945 0.0020089 0.0000030
     1048576 0.0054009 0.0041783 0.0000021
     2097152 0.0105489 0.0084729 0.0000012
     4194304 0.0227224 0.0149213 0.0000024
     8388608 0.0459217 0.0479958 0.0000024
    16777216 0.0819970 0.0620061 0.0000027
    33554432 0.1597240 0.1625328 0.0000036
    67108864 0.3310355 0.2623271 0.0000024
    

    Here's my test program producing that table:

    from timeit import timeit
    import gc
    
    codes = (
        'del nums[10:]',
        'nums = nums[:10]',
        'x = nums; nums = nums[:10]'
        )
        
    for e in range(27):
        n = 2**e
        times = []
        for code in codes:
            gc.collect()
            # list(range(...)) makes nums a real list on Python 3 as well.
            times.append(timeit(code, 'nums = list(range(%d))' % n, number=1))
        print('{:8} {:.7f} {:.7f} {:.7f}'.format(n, *times))
    

    I tested it with Python 3 as well, same picture.

    So... I guess I should use that third way... I'll probably just rename the arguments and copy the prefixes into new variables. (Edit: done)


  • 0

    @StefanPochmann I think that when you build the matrix, it is already O(mn), or, if you slice nums1 and nums2, O(k^2). So it is not really O(k). By the way, I am not sure whether the question requires a sorted result. If it does, O(k) can never be achieved; the best we can do is O(k log k).


  • 1

    @jerrold You'd be correct if I built the matrix explicitly. But I don't. Like I said, I use a virtual matrix. Whenever I access an element of the matrix, I then compute that element.

    And the result isn't required to be sorted.
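
    The virtual matrix idea can be illustrated with a small read counter. This is a sketch under assumptions, not the Row class from the solution: the names VirtualSumMatrix, _VirtualRow and reads are made up for the example, and rows are created on access instead of being stored in a list. Setting it up does no per-entry work, and each entry is computed only when it is read.

```python
class VirtualSumMatrix:
    """Square matrix of pair sums whose entries are computed only when read."""

    def __init__(self, nums1, nums2):
        self.nums1, self.nums2 = nums1, nums2
        self.size = max(len(nums1), len(nums2))
        self.reads = 0  # number of entries actually computed so far

    def __len__(self):
        return self.size

    def __getitem__(self, i):
        # Creating a row view is O(1); no sums are computed here.
        return _VirtualRow(self, i)


class _VirtualRow:
    def __init__(self, owner, i):
        self.owner, self.i = owner, i

    def __getitem__(self, j):
        owner = self.owner
        owner.reads += 1
        if self.i < len(owner.nums1) and j < len(owner.nums2):
            return owner.nums1[self.i] + owner.nums2[j]
        return float('inf')  # padding outside the real m x n area


# Hypothetical sample input: a 3 x 4 sum matrix padded to 4 x 4.
matrix = VirtualSumMatrix([1, 7, 11], [2, 4, 6, 8])
corner = matrix[0][0]  # 1 + 2
padded = matrix[3][0]  # row 3 does not exist in nums1, so this is padding
reads_so_far = matrix.reads
```

    After the two reads, reads_so_far is 2 even though the padded matrix has 16 cells, which is why the matrix itself does not cost O(mn).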


  • 0

    @StefanPochmann I see...Thanks and this is amazing!


  • 0

    @zhiqing_xiao Oh... I think I just figured out why deleting elements from the tail of a list takes O(#deletedElements) time. Originally I thought the list would either...

    • just reduce its size value and keep its capacity (taking O(1) time) or
    • reduce its capacity by copying the remaining elements into a newly allocated space (taking O(#remainingElements) time) and then freeing the old space (taking O(1) time).

    But what I forgot is that CPython will go through the deleted elements, decrement their reference counts, and possibly destroy the objects themselves (not just remove them from the list). So there's a cost for each deleted element, which makes the operation take O(#deletedElements) time, as observed.
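
    The per-element cost can be made visible with a small experiment. This is a sketch that assumes CPython's reference counting, where an object is destroyed as soon as its last reference goes away; other Python implementations may behave differently. The class name Tracked and its destroyed counter are made up for the example.

```python
class Tracked:
    """Object that records its own destruction in a class-level counter."""

    destroyed = 0

    def __del__(self):
        Tracked.destroyed += 1


items = [Tracked() for _ in range(1000)]

# The list holds the only reference to each object, so deleting the tail
# drops every removed object's reference count to zero, one by one.
del items[10:]
destroyed_after_del = Tracked.destroyed
```

    On CPython, destroyed_after_del equals the number of deleted elements (990 here), so the deletion has to touch every removed element even though only 10 remain.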

    Similar to something else I recently noticed and asked about on StackOverflow.

