O(n) from paper. Yes, O(#rows).


  • 33

    It's O(n), where n is the number of rows (and columns), not the number of elements, so it's very efficient. The algorithm is from the paper "Selection in X + Y and matrices with sorted rows and columns", which I first saw mentioned by @elmirap (thanks).

    The basic idea: Consider the submatrix you get by removing every second row and every second column. This has about a quarter of the elements of the original matrix, and the k-th element (k-th smallest, I mean) of the original matrix is roughly the (k/4)-th element of the submatrix. So roughly get the (k/4)-th element of the submatrix and then use that to find the k-th element of the original matrix in O(n) time. It's recursive, going down to smaller and smaller submatrices until reaching a trivial 2×2 matrix, and since the number of rows halves at each level, the total time is O(n + n/2 + n/4 + ...) = O(n). For more details I suggest checking out the paper; the first half is easy to read and explains things well. Or see @zhiqing_xiao's solution+explanation.

    Cool: It uses variants of saddleback search that you might know for example from the Search a 2D Matrix II problem. And it uses the median of medians algorithm for linear-time selection.
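
    The "use the submatrix result to find the k-th element in O(n)" step can be sketched on its own. This is a simplified stand-alone version, not the exact code of the solution: the hypothetical helper kth_within_bracket assumes it is already known that lo <= answer <= hi, counts the entries <= lo and < hi with two saddleback pointers that only move left, and sorts the few candidates strictly between lo and hi where the real solution uses median-of-medians.

```python
def kth_within_bracket(matrix, k, lo, hi):
    """k-th smallest (1-based) entry of an n x n matrix with ascending rows
    and columns, assuming it is already known that lo <= answer <= hi.
    """
    n = len(matrix)
    n_le_lo = n_lt_hi = 0  # counts of entries <= lo and < hi
    between = []           # entries x with lo < x < hi
    j_lo = j_hi = n        # per row: row[:j_lo] <= lo, row[:j_hi] < hi
    for row in matrix:
        # Going down a column values only grow, so both pointers move left.
        while j_lo and row[j_lo - 1] > lo:
            j_lo -= 1
        while j_hi and row[j_hi - 1] >= hi:
            j_hi -= 1
        n_le_lo += j_lo
        n_lt_hi += j_hi
        between.extend(row[j_lo:j_hi])
    if n_le_lo >= k:   # answer <= lo, and answer >= lo by assumption
        return lo
    if n_lt_hi < k:    # answer >= hi, and answer <= hi by assumption
        return hi
    # The answer lies strictly between lo and hi; sorting stands in for the
    # linear-time median-of-medians pick of the real solution.
    return sorted(between)[k - n_le_lo - 1]
```

    The pointer walk costs O(n) in total because each pointer moves left at most n times over all rows.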

    Optimization: If k is less than n, we only need to consider the top-left k×k matrix. Similarly, if k is almost n^2, we only need a small bottom-right corner. So it's even O(min(n, k, n^2-k)); I just didn't mention that in the title because I wanted to keep it simple and because those few very small or very large k are unlikely. Most of the time k will be "medium" (n^2/2 on average).
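
    A sketch of that corner reduction, assuming 1 <= k <= n^2. The hypothetical helper restrict repeats the arithmetic of the set-up lines at the end of the solution and returns which rows/columns to keep plus the new rank. Entries outside the kept corner are provably on the wrong side of the answer, so the k2-th smallest of the corner equals the k-th smallest of the whole matrix.

```python
def restrict(n, k):
    """For the k-th smallest (1-based, 1 <= k <= n*n) of an n x n matrix with
    sorted rows and columns, return (start, stop, k2): the answer is the
    k2-th smallest of the square corner with rows and columns start..stop-1.
    Same arithmetic as the set-up lines of kthSmallest.
    """
    # Large k: the answer is the (n*n - k + 1)-th largest, so only the
    # bottom-right corner of that size can contain it.
    start = max(k - n * n + n - 1, 0)
    # Every entry outside the bottom-right (n - start) x (n - start) square
    # is <= the answer, so skip over all of them.
    k2 = k - (n * n - (n - start) ** 2)
    # Small k: only the top-left k x k corner can contain the answer.
    stop = min(n, start + k2)
    return start, stop, k2
```

    For example, restrict(5, 3) keeps the top-left 3×3 corner with rank 3, and restrict(5, 24) returns (3, 5, 3): keep the bottom-right 2×2 corner and take its 3rd smallest, i.e. its 2nd largest, which is also the 2nd largest overall.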

    Implementation: I implemented the submatrix by using an index list through which the actual matrix data gets accessed. If [0, 1, 2, ..., n-1] is the index list of the original matrix, then [0, 2, 4, ...] (plus the last index, which is always kept) is the index list of the submatrix, [0, 4, 8, ...] is the index list of the subsubmatrix, and so on. This also covers the above optimization by starting with [0, 1, 2, ..., k-1] when applicable.
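
    To see which indices survive at each level, here is a small hypothetical helper index_levels that applies the same slicing the solution uses, index[::2] + index[n-1+n%2:]. That slicing always keeps the last index too, so every submatrix includes the last row and column.

```python
def index_levels(n):
    """Index lists biselect would recurse through, starting from range(n)."""
    levels = [list(range(n))]
    while len(levels[-1]) > 2:  # biselect's base case is n <= 2
        index = levels[-1]
        m = len(index)
        # Every second index, plus the last one when m is even
        # (for odd m, index[::2] already ends with the last index).
        levels.append(index[::2] + index[m - 1 + m % 2:])
    return levels
```

    For n = 8 the levels are [0..7], [0, 2, 4, 6, 7], [0, 4, 7] and [0, 7]. The sizes roughly halve, which is why O(size) work per level adds up to O(n).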

    Application: I believe it can be used to easily solve the Find K Pairs with Smallest Sums problem in time O(k) instead of O(k log n), which I think is the best posted so far. I might try that later if nobody beats me to it (if you do, let me know :-). Update: I did that now.
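
    A rough sketch of how such a reduction could look, not necessarily how the update does it: the pair sums form an implicit matrix A[i][j] = nums1[i] + nums2[j] with sorted rows and columns, so once the k-th smallest sum s is known, the answer is every pair with sum < s plus enough pairs with sum == s. The function name is hypothetical, s is found by brute-force sorting as a placeholder for the O(n) matrix selection, and the collection loops are kept simple rather than O(k).

```python
def k_smallest_pairs(nums1, nums2, k):
    """k pairs [u, v] (u from nums1, v from nums2, both ascending) with the
    smallest sums, found via the k-th smallest sum s."""
    # Only the first k entries of each list can be needed.
    xs, ys = nums1[:k], nums2[:k]
    k = min(k, len(xs) * len(ys))
    if k == 0:
        return []
    # Placeholder selection: the sums form an implicit matrix
    # A[i][j] = xs[i] + ys[j] with sorted rows and columns, which is where
    # the O(n) matrix selection would be used instead of sorting.
    s = sorted(x + y for x in xs for y in ys)[k - 1]
    pairs = []
    # Every pair with sum < s is in the answer (rows are sorted, so stop early).
    for x in xs:
        for y in ys:
            if x + y >= s:
                break
            pairs.append([x, y])
    # Fill up to k with pairs whose sum is exactly s.
    for x in xs:
        for y in ys:
            if len(pairs) == k or x + y > s:
                break
            if x + y == s:
                pairs.append([x, y])
    return pairs
```

    For nums1 = [1, 7, 11], nums2 = [2, 4, 6] and k = 3 this gives [[1, 2], [1, 4], [1, 6]].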

    class Solution(object):
        def kthSmallest(self, matrix, k):
    
            # The median-of-medians selection function.
            def pick(a, k):
                if k == 1:
                    return min(a)
                groups = (a[i:i+5] for i in range(0, len(a), 5))
                medians = [sorted(group)[len(group) // 2] for group in groups]
                pivot = pick(medians, len(medians) // 2 + 1)
                smaller = [x for x in a if x < pivot]
                if k <= len(smaller):
                    return pick(smaller, k)
                k -= len(smaller) + a.count(pivot)
                return pivot if k < 1 else pick([x for x in a if x > pivot], k)
    
            # Find the k1-th and k2-th smallest entries in the submatrix.
            def biselect(index, k1, k2):
    
                # Provide the submatrix.
                n = len(index)
                def A(i, j):
                    return matrix[index[i]][index[j]]
                
                # Base case.
                if n <= 2:
                    nums = sorted(A(i, j) for i in range(n) for j in range(n))
                    return nums[k1-1], nums[k2-1]
    
                # Solve the subproblem.
                index_ = index[::2] + index[n-1+n%2:]
                k1_ = (k1 + 2*n) // 4 + 1 if n % 2 else n + 1 + (k1 + 3) // 4
                k2_ = (k2 + 3) // 4
                a, b = biselect(index_, k1_, k2_)
    
                # Prepare ra_less, rb_more and L with saddleback search variants.
                ra_less = rb_more = 0
                L = []
                jb = n   # jb is the first where A(i, jb) is larger than b.
                ja = n   # ja is the first where A(i, ja) is larger than or equal to a.
                for i in range(n):
                    while jb and A(i, jb - 1) > b:
                        jb -= 1
                    while ja and A(i, ja - 1) >= a:
                        ja -= 1
                    ra_less += ja
                    rb_more += n - jb
                    L.extend(A(i, j) for j in range(jb, ja))
                    
                # Compute and return x and y.
                x = a if ra_less <= k1 - 1 else \
                    b if k1 + rb_more - n*n <= 0 else \
                    pick(L, k1 + rb_more - n*n)
                y = a if ra_less <= k2 - 1 else \
                    b if k2 + rb_more - n*n <= 0 else \
                    pick(L, k2 + rb_more - n*n)
                return x, y
    
            # Set up and run the search.
            n = len(matrix)
            start = max(k - n*n + n-1, 0)
            k -= n*n - (n - start)**2
            return biselect(list(range(start, min(n, start+k))), k, k)[0]
    

  • 1

    @StefanPochmann This is so cool! Thank you for sharing Stefan! But it is way beyond what is expected at an interview I guess.


  • 11

    @agave said in O(n) from paper. Yes, O(#rows).:

    @StefanPochmann But it is way beyond what is expected at an interview I guess.

    But if I ever do happen to get asked this question... then I'll get totally excited and giddy and... won't remember all those formulas or how to create them.


  • 0

    @StefanPochmann said in O(n) from paper. Yes, O(#rows).:

    ... won't remember all those formulas or how to create them.

    Haha would be awesome if you did!


  • 6

    @StefanPochmann I didn't even know the median-of-medians algorithm existed, so Problem 215 Kth Largest Element in an Array does have a worst-case linear solution after all! Awesome!

    Does the O(wn) solution seem more practical to you in real life?

    class Solution(object):
        def kthSmallest(self, matrix, k):
            lo, hi = matrix[0][0], matrix[-1][-1]
            while lo < hi:
                mid, count, j = (lo+hi)//2, 0, len(matrix[0])
                for row in matrix:
                    while j>=1 and row[j-1] > mid:
                        j -= 1
                    count += j
                if count < k:
                    lo = mid+1
                else:
                    hi = mid
            return lo
    

  • 0

    @o_sharp Definitely more practical in real interview life :-). I had already tried that and compared those two, with matrices up to 1000000×1000000 or so, asking for a median element. I think they were about equally fast. But I suspect it depends on how I fill the matrices, and I didn't try to optimize either (Maybe instead of median-of-medians, use sort+grab? Would be O(n log n) instead of O(n), but could still be better in practice since sort is pretty fast).
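
    A sketch of a middle ground, assuming the goal is to keep the linear worst case while sorting small inputs: it is the median-of-medians pick from the solution, except that lists of at most cutoff items are simply sorted and indexed. The default of 64 is an arbitrary placeholder, not a measured crossover; cutoff must be at least 1, or the recursion never ends.

```python
def pick(a, k, cutoff=64):
    """Return the k-th smallest (1-based) item of list a."""
    if len(a) <= cutoff:
        return sorted(a)[k - 1]  # sort+grab for small lists
    # Median of medians of groups of five, as in the original pick.
    groups = (a[i:i + 5] for i in range(0, len(a), 5))
    medians = [sorted(g)[len(g) // 2] for g in groups]
    pivot = pick(medians, len(medians) // 2 + 1, cutoff)
    smaller = [x for x in a if x < pivot]
    if k <= len(smaller):
        return pick(smaller, k, cutoff)
    k -= len(smaller) + a.count(pivot)
    return pivot if k < 1 else pick([x for x in a if x > pivot], k, cutoff)
```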


  • 0

    Good solution. However, just like the median-of-medians algorithm, it's only asymptotically efficient. In real life the overhead is so large that it's probably even slower. It's definitely interesting to think about, though, how recursively subdividing solutions like this one can bring the time bound down even further.


  • 0

    @oneraynyday What do you think are the most "real-life-efficient" algorithms for this problem?


  • 0

    Thanks so much for providing such an interesting idea!

    We can see in the function biselect, when the number of indices (denoted as n) is even, we update (k1, k2) as:
    k1 <- floor(k1 + 2n/4) + 1
    k2 <- ceil(k2/4)
    while n is odd,
    k1 <- n + 1 + ceil(k1/4)
    k2 <- ceil(k2/4)
    Could you please explain the meaning of k1 and k2, and why update them like this? Thanks!


  • 0

    @zhiqing_xiao You got the first one wrong. In all cases, k1 and k2 are shrunk to about one quarter. That's what I referred to with "So roughly get the (k/4)-th element of the submatrix and then use that to find the k-th element of the original matrix in O(n) time". It's actually two elements, the k1-th and the k2-th smallest, and they define a small range. And biselect actually returns exactly the k1-th and k2-th smallest elements. For more details I suggest checking out the paper; the first half is easy to read and explains things well.


  • 0

    @StefanPochmann Thank you very much for your prompt reply!


  • 2

    @o_sharp Great code that you provided. Excellent and practical. However, when (count < k), you increase "lo" to "mid+1". I tried to find the next value for "lo" from the matrix, rather than incrementing by one at each iteration. Here is my code which brought runtime from 14 ms down to 9 ms. Thanks again for the excellent idea.

       public int kthSmallest(int[][] matrix, int k) {
            int n = matrix.length;
            int lo = matrix[0][0], hi = matrix[n - 1][n - 1];
            while(lo < hi) {
                int count = 0, mid = lo + (hi - lo) / 2, nextBiggerMid = Integer.MAX_VALUE;  // rounds down even for negative values
                for(int i = 0; i < n; i++) {
                    int j = 0; 
                    while(j < n){
                        if(matrix[i][j] > mid) {
                            nextBiggerMid = (nextBiggerMid < matrix[i][j]) ? nextBiggerMid : matrix[i][j];
                            break;
                        }
                        j++;
                    }
                    if(j == 0) {
                        break;
                    }
                    count+= j;
                }
                if(count < k) {
                    lo = nextBiggerMid;
                }
                else {
                    hi = mid;
                }
            }
            return lo;
        }
    

  • 0

    @eyeabhi said in O(n) from paper. Yes, O(#rows).:

    Here is my code which brought runtime from 14 ms down to 9 ms.

    What does your 14 ms version look like? I tried it as well and it gets 2 ms:

    public class Solution {
        public int kthSmallest(int[][] matrix, int k) {
            int n = matrix.length;
            int lo = matrix[0][0], hi = matrix[n-1][n-1];
            while (lo < hi) {
                int mid = lo + (hi - lo) / 2, count = 0, j = n;  // rounds down even for negative values
                for (int[] row : matrix) {
                    while (j >= 1 && row[j-1] > mid)
                        j--;
                    count += j;
                }
                if (count < k)
                    lo = mid + 1;
                else
                    hi = mid;
            }
            return lo;
        }
    }
    

    Also, you replaced the O(n) counting algorithm with an O(n^2) algorithm. So that's not good. Maybe you can integrate your idea better into the original, so that it remains O(n)?


  • 2

    @StefanPochmann Yeah... thanks for pointing out... I fixed the counting algorithm to O(n) and now it runs in 1 ms.

        public int kthSmallest(int[][] matrix, int k) {
            int n = matrix.length;
            int lo = matrix[0][0], hi = matrix[n - 1][n - 1];
            while(lo < hi) {
                int count = 0, mid = lo + (hi - lo) / 2, nextBiggerMid = Integer.MAX_VALUE;  // rounds down even for negative values
                int i = n - 1, j = 0;
                while(i >= 0 && j < n) {
                    if(matrix[i][j] > mid) {
                        nextBiggerMid = (nextBiggerMid < matrix[i][j]) ? nextBiggerMid : matrix[i][j];
                        i--;
                    }
                    else {
                        count += i + 1;
                        j++;
                    }
                }
                if(count == k - 1) {
                    return nextBiggerMid;
                }
                else if(count < k) {
                    lo = nextBiggerMid;
                }
                else {
                    hi = mid;
                }
            }
            return lo;
        }
    

    Overall complexity: O(nlog(n)).

    However, with

                if (count < k)
                    lo = mid + 1;
    

    the overall complexity would be O(n log(max - min)), where max and min are the largest and smallest elements in the matrix, respectively.
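
    A Python port of the jump idea, for comparison with @o_sharp's Python version: one staircase pass from the bottom-left corner counts the entries <= mid and also finds the smallest entry > mid, so lo can jump straight to that entry. It assumes integer entries ((lo + hi) // 2 is an integer midpoint that rounds down, also for negative values) and leaves out the count == k - 1 early exit; the function name is hypothetical.

```python
def kth_smallest_jump(matrix, k):
    """k-th smallest (1-based) of an n x n integer matrix with ascending
    rows and columns, by binary search over values with jumps."""
    n = len(matrix)
    lo, hi = matrix[0][0], matrix[-1][-1]
    while lo < hi:
        mid = (lo + hi) // 2
        count, next_bigger = 0, None
        i, j = n - 1, 0  # start at the bottom-left corner
        while i >= 0 and j < n:
            if matrix[i][j] > mid:
                # Smallest entry > mid seen so far on the staircase.
                if next_bigger is None or matrix[i][j] < next_bigger:
                    next_bigger = matrix[i][j]
                i -= 1
            else:
                count += i + 1  # matrix[0..i][j] are all <= mid
                j += 1
        if count < k:
            lo = next_bigger  # the answer is > mid, hence >= next_bigger
        else:
            hi = mid
    return lo
```

    When count < k, some entry is > mid, and the staircase visits the smallest such entry of every column, so next_bigger is never None at that point.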

    Thanks.


  • 0

    I wanted to see if I could improve the saddleback section of the algorithm by using binary search when looking for the column matching ja/jb (a.k.a. the "rank" function). On paper it could improve a corner case, taking the worst case from O(2n) down to O(n + log n), but in reality, implementing it in Python makes it slower than just scanning the list in linear time! (I also realized this method's worst case would be O(n log n), so that's not good.)

    With that in mind, I replaced the pick() function with pick = lambda L, x: sorted(L)[x-1], and the bugger runs faster on a 5000×5000 matrix I gave it.
    In practice, I guess this means there's some |L| below which sorted(L)[x-1] is simply faster than the median-of-medians approach (just because there's less overhead), and the pick function could switch between the two as necessary.

    Thoughts?
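
    For reference, the two ways of counting the entries <= x, side by side: one binary search per row with the standard library's bisect.bisect_right (O(n log n) in total) versus the saddleback walk (O(n) in total). Both function names are hypothetical, and the sketch assumes an n x n matrix.

```python
from bisect import bisect_right


def count_le_bisect(matrix, x):
    """Entries <= x, via one binary search per row: O(n log n) in total."""
    return sum(bisect_right(row, x) for row in matrix)


def count_le_saddleback(matrix, x):
    """Entries <= x, via a saddleback walk: O(n) in total, because the
    column pointer j only ever moves left as we go down the rows."""
    j, count = len(matrix), 0
    for row in matrix:
        while j and row[j - 1] > x:
            j -= 1
        count += j
    return count
```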

    P.S. I'm not generating a very well randomized matrix, just filling diagonals with increasing numbers. What's a better way of doing it?

    P.P.S. I see, I wasn't reading the comments closely enough :P This had already been mentioned.
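
    One option for test data (an assumption, not something used in the thread): build each entry as the larger of its upper and left neighbours plus a random non-negative step, which satisfies the row and column ordering by construction and still produces ties. Both helpers are hypothetical; brute_kth is the obvious sort-everything reference to compare against.

```python
import random


def random_sorted_matrix(n, max_step=10, seed=None):
    """Random n x n matrix with ascending rows and columns (ties allowed)."""
    rng = random.Random(seed)
    m = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            # Must be >= the entry above and the entry to the left.
            base = max(m[i - 1][j] if i else 0, m[i][j - 1] if j else 0)
            m[i][j] = base + rng.randint(0, max_step)
    return m


def brute_kth(matrix, k):
    """Reference k-th smallest (1-based): sort all entries, O(n^2 log n)."""
    return sorted(x for row in matrix for x in row)[k - 1]
```

    For example, Solution().kthSmallest(m, k) can be compared against brute_kth(m, k) for m = random_sorted_matrix(50, seed=s) over several seeds and ranks.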


  • 0

    @o_sharp What if the data aren't integers? Then this can't work.
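
    The binary search relies on integer midpoints. A different technique that only needs comparisons is a k-way merge of the rows with a min-heap: O(k log n) instead of O(n log(max - min)), and it works for floats or any other ordered type. A minimal sketch with a hypothetical name:

```python
import heapq


def kth_smallest_heap(matrix, k):
    """k-th smallest (1-based) of a matrix with ascending rows and columns,
    by merging the rows with a min-heap of (value, row, col) items."""
    n = len(matrix)
    # Only the first k rows can hold the answer; start with their first entries.
    heap = [(matrix[i][0], i, 0) for i in range(min(n, k))]
    heapq.heapify(heap)
    for _ in range(k - 1):
        _, i, j = heapq.heappop(heap)  # drop the current smallest
        if j + 1 < len(matrix[i]):
            heapq.heappush(heap, (matrix[i][j + 1], i, j + 1))
    return heap[0][0]
```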


  • 0

    Another cool property of the median-of-medians algorithm is that it can be implemented in place.
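
    A sketch of one possible in-place arrangement, assuming a plain Python list: each group of five is sorted in place and its median swapped to the front of the range, the pivot is selected recursively among those medians, and a three-way partition around the pivot narrows the range. Apart from sorting groups of at most five items and the recursion stack, it allocates no extra lists. The name select_inplace is hypothetical, and k is 0-based here, unlike the 1-based pick in the solution.

```python
def select_inplace(a, k, lo=0, hi=None):
    """Return the k-th smallest (0-based) of a[lo:hi], rearranging a in place."""
    if hi is None:
        hi = len(a)
    while True:
        if hi - lo <= 5:
            a[lo:hi] = sorted(a[lo:hi])  # at most five items
            return a[lo + k]
        # Sort each group of five and swap its median into a[lo], a[lo+1], ...
        m = lo
        for g in range(lo, hi, 5):
            end = min(g + 5, hi)
            a[g:end] = sorted(a[g:end])
            med = g + (end - g - 1) // 2
            a[m], a[med] = a[med], a[m]
            m += 1
        # Pivot = median of the group medians stored in a[lo:m].
        pivot = select_inplace(a, (m - lo - 1) // 2, lo, m)
        # Three-way partition: a[lo:lt] < pivot, a[lt:gt] == pivot, a[gt:hi] > pivot.
        lt, i, gt = lo, lo, hi
        while i < gt:
            if a[i] < pivot:
                a[lt], a[i] = a[i], a[lt]
                lt += 1
                i += 1
            elif a[i] > pivot:
                gt -= 1
                a[gt], a[i] = a[i], a[gt]
            else:
                i += 1
        # Continue in the part that contains position lo + k.
        if lo + k < lt:
            hi = lt
        elif lo + k >= gt:
            k -= gt - lo
            lo = gt
        else:
            return pivot
```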


  • 0

    @o_sharp Hi there.
    Could you please explain the meaning of

    while j > 0 and row[j-1] > mid:
        j -= 1
    count += j
    

    Why compare the elements of the matrix with mid one by one, in order? And when exactly does lo end up equal to the k-th smallest element?
    Thanks!


  • 0

    @ian34 Please look here https://discuss.leetcode.com/topic/52912/binary-search-heap-and-sorting-comparison-with-concise-code-and-1-liners-python-72-ms. In "binary search", that code is clearer but slower. It simply counts how many elements in the matrix are bigger.

    The code here is faster, but the idea is the same.
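
    To make the loop concrete, here is @o_sharp's code with a trace added (the name and the steps list are additions for illustration). The invariant is that the answer always lies in [lo, hi]: if fewer than k entries are <= mid, the answer is > mid; otherwise it is <= mid. So the loop ends at the smallest value v with at least k entries <= v, and that is the k-th smallest. It is always an entry of the matrix, because the count only changes at entry values.

```python
def kth_smallest_traced(matrix, k):
    """@o_sharp's binary search, also returning (lo, hi, mid, count) per step,
    where count is the number of entries <= mid."""
    lo, hi = matrix[0][0], matrix[-1][-1]
    steps = []
    while lo < hi:
        mid, count, j = (lo + hi) // 2, 0, len(matrix[0])
        for row in matrix:
            # Columns are sorted, so j only ever moves left going down:
            # afterwards, row[:j] are exactly this row's entries <= mid.
            while j >= 1 and row[j - 1] > mid:
                j -= 1
            count += j
        steps.append((lo, hi, mid, count))
        if count < k:
            lo = mid + 1  # fewer than k entries <= mid: answer > mid
        else:
            hi = mid      # at least k entries <= mid: answer <= mid
    return lo, steps
```

    On [[1, 5, 9], [10, 11, 13], [12, 13, 15]] with k = 8 it takes four steps (mid = 8, 12, 14, 13) and returns 13.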

