O(k) solution

• Now that I can find the kth smallest element in a sorted n×n matrix in time O(min(n, k)), I can finally solve this problem in O(k).

The idea:

1. If `nums1` or `nums2` is longer than `k`, shrink it to length `k`.
2. Build a virtual matrix of the pair sums, i.e., `matrix[i][j] = nums1[i] + nums2[j]`. Make it a square matrix by padding with "infinity" if necessary. By "virtual" I mean its entries are computed on the fly, and only those that are actually needed. This is necessary to stay within O(k) time.
3. Find the kth smallest sum `kthSum` by using that other algorithm.
4. Use a saddleback search variation to discount the pairs with sum smaller than `kthSum`. After this, `k` tells us how many pairs we still need whose sum equals `kthSum`.
5. Collect all pairs with sum smaller than `kthSum` as well as `k` pairs whose sum equals `kthSum`.

Each of those steps only takes O(k) time.
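To illustrate step 4 in isolation: the saddleback idea counts pairs with sum below a threshold in O(m + n), because the column pointer only ever moves left across the whole outer loop. A minimal sketch (the function name `count_pairs_below` is my own, assuming both lists are sorted ascending):

```python
def count_pairs_below(nums1, nums2, limit):
    # nums1 and nums2 are sorted ascending.
    # Count pairs (a, b) with a + b < limit.
    count, j = 0, len(nums2)
    for a in nums1:
        # Shrink j until a + nums2[j-1] < limit; since a only grows,
        # j never needs to move right again, giving O(m + n) total.
        while j and a + nums2[j - 1] >= limit:
            j -= 1
        count += j
    return count
```

In the solution below, the same loop runs with `j` capped at `min(k, n)` and subtracts the counts from `k` directly instead of accumulating them.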

The code (minus the code for kthSmallest, which you can copy verbatim from my solution to the other problem):

``````class Solution(object):
    def kSmallestPairs(self, nums1_, nums2_, k):

        # Use at most the first k of each, then get the sizes.
        nums1 = nums1_[:k]
        nums2 = nums2_[:k]
        m, n = len(nums1), len(nums2)

        # Gotta Catch 'Em All?
        if k >= m * n:
            return [[a, b] for a in nums1 for b in nums2]

        # Build a virtual matrix.
        N, inf = max(m, n), float('inf')
        class Row:
            def __init__(self, i):
                self.i = i
            def __getitem__(self, j):
                return nums1[self.i] + nums2[j] if self.i < m and j < n else inf
        matrix = map(Row, range(N))

        # Get the k-th sum.
        kthSum = self.kthSmallest(matrix, k)

        # Discount the pairs with sum smaller than the k-th.
        j = min(k, n)
        for a in nums1:
            while j and a + nums2[j-1] >= kthSum:
                j -= 1
            k -= j

        # Collect and return the pairs.
        pairs = []
        for a in nums1:
            for b in nums2:
                if a + b >= kthSum + (k > 0):
                    break
                pairs.append([a, b])
                k -= a + b == kthSum
        return pairs

    def kthSmallest(self, matrix, k):

        # copy & paste from https://discuss.leetcode.com/topic/53126/o-n-from-paper-yes-o-rows
``````
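To see the surrounding logic run end to end, here's a self-contained Python 3 sketch. The function name `k_smallest_pairs` is my own, and it substitutes `heapq.nsmallest` for the missing `kthSmallest` (that stand-in scans all capped pair sums, so it is not the O(min(n, k)) algorithm from the linked post; it only makes the discount and collect steps executable):

```python
import heapq

def k_smallest_pairs(nums1, nums2, k):
    # Cap both lists at k elements (copies, so the callers' lists stay intact).
    nums1, nums2 = nums1[:k], nums2[:k]
    m, n = len(nums1), len(nums2)

    # If k covers everything, just return all pairs.
    if k >= m * n:
        return [[a, b] for a in nums1 for b in nums2]

    # Stand-in for kthSmallest: k-th smallest of all capped pair sums.
    # Demonstration only -- this step alone is more than O(k) work.
    kth_sum = heapq.nsmallest(k, (a + b for a in nums1 for b in nums2))[-1]

    # Saddleback discount: subtract the count of pairs with sum < kth_sum.
    j = min(k, n)
    for a in nums1:
        while j and a + nums2[j - 1] >= kth_sum:
            j -= 1
        k -= j

    # Collect all pairs with sum < kth_sum, plus k pairs with sum == kth_sum.
    pairs = []
    for a in nums1:
        for b in nums2:
            if a + b >= kth_sum + (k > 0):
                break
            pairs.append([a, b])
            k -= a + b == kth_sum
    return pairs
```

For example, `k_smallest_pairs([1, 7, 11], [2, 4, 6], 3)` returns `[[1, 2], [1, 4], [1, 6]]`.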

Thanks to @zhiqing_xiao for pointing out that my previous way of capping the input lists might not be O(k). It was this:

``````def kSmallestPairs(self, nums1, nums2, k):
    del nums1[k:]
    del nums2[k:]``````

• wow, this is just wow....

• I think the "del" operations need to be avoided to achieve the O(k) complexity, since the original lists may be very long.

We might consider copying such as:

``````nums1 = nums1[:k]
nums2 = nums2[:k]
``````

or find other workaround.

• @zhiqing_xiao Interesting point. I don't know how `del` is implemented or what its complexity is here; I thought it would be O(k). I ran an experiment applying `del nums[10:]` to increasingly large lists, and it looks like it's O(n). See the table below: the first column is the list size n, the second column is the time for `del nums[10:]`. The third column is the time for your `nums = nums[:10]`, and you can see it also looks like O(n), about as fast as my `del`. I guessed that the time is caused not really by the shrinking itself but by Python "deallocating" the original list, and that appears to be correct: if I keep a reference to the original with `x = nums; nums = nums[:10]`, then it appears to be constant time (that's the fourth column in the table). There is a significant jump at n=524288, but it seems to remain constant afterwards; I suspect that jump is a CPU cache issue.

``````       1 0.0000024 0.0000030 0.0000024
       2 0.0000024 0.0000027 0.0000021
       4 0.0000015 0.0000015 0.0000012
       8 0.0000012 0.0000015 0.0000009
      16 0.0000015 0.0000015 0.0000009
      32 0.0000015 0.0000009 0.0000003
      64 0.0000018 0.0000009 0.0000006
     128 0.0000027 0.0000015 0.0000006
     256 0.0000018 0.0000012 0.0000012
     512 0.0000030 0.0000018 0.0000006
    1024 0.0000039 0.0000030 0.0000006
    2048 0.0000066 0.0000057 0.0000006
    4096 0.0000249 0.0000111 0.0000006
    8192 0.0000300 0.0000219 0.0000006
   16384 0.0000574 0.0000454 0.0000009
   32768 0.0001373 0.0000859 0.0000006
   65536 0.0002809 0.0002028 0.0000015
  131072 0.0005817 0.0003888 0.0000009
  262144 0.0011667 0.0013596 0.0000009
  524288 0.0035945 0.0020089 0.0000030
 1048576 0.0054009 0.0041783 0.0000021
 2097152 0.0105489 0.0084729 0.0000012
 4194304 0.0227224 0.0149213 0.0000024
 8388608 0.0459217 0.0479958 0.0000024
16777216 0.0819970 0.0620061 0.0000027
33554432 0.1597240 0.1625328 0.0000036
67108864 0.3310355 0.2623271 0.0000024
``````

Here's my test program producing that table:

``````from timeit import timeit
import gc

codes = (
    'del nums[10:]',
    'nums = nums[:10]',
    'x = nums; nums = nums[:10]',
)

for e in range(27):
    n = 2**e
    times = []
    for code in codes:
        gc.collect()
        times.append(timeit(code, 'nums = range(%d)' % n, number=1))
    print '{:8} {:.7f} {:.7f} {:.7f}'.format(n, *times)
``````

I tested it with Python 3 as well, same picture.

So... I guess I should use that third way... I'll probably just rename the arguments and copy the prefixes into new variables. (Edit: done)

• @StefanPochmann I think when you build the matrix, it is already O(mn), or if you slice nums1 and nums2 it becomes O(k^2). So it is not really O(k). By the way, I am not sure if the question requires the result to be sorted; if so, O(k) can never be achieved. The best we can do is O(k log k).

• @jerrold You'd be correct if I built the matrix explicitly. But I don't. Like I said, I use a virtual matrix: an entry is computed only when it is accessed.

And the result isn't required to be sorted.

• @StefanPochmann I see...Thanks and this is amazing!

• @zhiqing_xiao Oh... I think I just figured out why deleting elements from the tail of a list takes O(#deletedElements) time. Originally I thought the list would either...

• just reduce its size value and keep its capacity (taking O(1) time) or
• reduce its capacity by copying the remaining elements into a newly allocated space (taking O(#remainingElements) time) and then freeing the old space (taking O(1) time).

But what I forgot is that CPython will go through the deleted elements, reduce their reference counter, and possibly actually destroy the objects (not just delete them from the list). So that's a cost for each deleted element, making it need O(#deletedElements) as observed.

Similar to something else I recently noticed and asked about on StackOverflow.
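That per-slot refcount work is easy to observe directly. Here's a minimal Python 3 sketch (my own illustration, not from the discussion above): fill a list with 1000 references to a single object, delete all but 10 slots, and the object's refcount drops by exactly the number of deleted slots:

```python
import sys

sentinel = object()
nums = [sentinel] * 1000

before = sys.getrefcount(sentinel)
del nums[10:]   # CPython visits each of the 990 deleted slots...
after = sys.getrefcount(sentinel)

# ...decrementing the shared object's refcount once per slot.
print(before - after)  # 990
```

(`sys.getrefcount` itself adds one temporary reference both times, so that offset cancels out in the difference.)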
