Python O(log(min(m, n))) solution using predicate binary search

    The idea for this is based on an article that discusses how binary search can be done on a predicate. If the predicate p is monotonic, i.e. it satisfies the property:
    If p is true for x, then it is true for all values > x

    then we can use binary search to find the smallest true value.
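    A minimal sketch of the generic technique, assuming a half-open range [lo, hi) and returning hi when the predicate is never true. The name first_true is hypothetical, not taken from the article. Using a half-open range avoids the final p(lo) check that the closed-range predicate_binary_search in the solution code needs.

```python
from typing import Callable


def first_true(lo: int, hi: int, p: Callable[[int], bool]) -> int:
    """Smallest x in [lo, hi) with p(x) true, or hi if there is none.

    Assumes p is monotonic: once p(x) is true, p(y) is true for all y > x.
    """
    while lo < hi:
        mid = (lo + hi) // 2
        if p(mid):
            hi = mid  # mid might be the answer, so keep it in range
        else:
            lo = mid + 1  # the answer, if any, is strictly right of mid
    return lo


# Smallest x in [0, 100) whose square is at least 50.
print(first_true(0, 100, lambda x: x * x >= 50))  # 8
```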

    This can be applied to the current problem. Let m = len(a) and n = len(b). Recall that the median of an array divides it into two halves. Similarly, to find the median of two arrays, we are trying to find a cut k in a and a cut l in b that split the arrays into two parts such that k + l = (m + n) // 2 (so the left part holds (m + n) // 2 elements) and

          left_part          |        right_part
    a[0], a[1], ..., a[k-1]  |  a[k], a[k+1], ..., a[m-1]
    b[0], b[1], ..., b[l-1]  |  b[l], b[l+1], ..., b[n-1]
    (diagram adapted from the current solution, with the variables renamed)

    where every element of the left part is <= every element of the right part. Since a and b are sorted, it is enough to require
    max(a[k-1], b[l-1]) <= min(a[k], b[l]) (eqn 1)
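    A minimal sketch of eqn 1 as a check, assuming Python lists and treating a missing neighbour (a cut at either end of an array) as -inf on the left and +inf on the right. The name is_valid_cut is hypothetical.

```python
import math
from typing import List


def is_valid_cut(a: List[float], b: List[float], k: int, l: int) -> bool:
    """Check eqn 1: max(a[k-1], b[l-1]) <= min(a[k], b[l]).

    Missing neighbours (a cut at either end of an array) are treated as
    -inf on the left and +inf on the right, so they never violate the check.
    """
    left_a = a[k - 1] if k > 0 else -math.inf
    left_b = b[l - 1] if l > 0 else -math.inf
    right_a = a[k] if k < len(a) else math.inf
    right_b = b[l] if l < len(b) else math.inf
    return max(left_a, left_b) <= min(right_a, right_b)


a, b = [1, 3, 5, 7], [2, 4]
print(is_valid_cut(a, b, 2, 1))  # True: left {1, 3, 2}, right {5, 7, 4}
print(is_valid_cut(a, b, 3, 0))  # False: 5 on the left exceeds 2 on the right
```

    Since each array is sorted, a[k-1] <= a[k] and b[l-1] <= b[l] hold automatically, so eqn 1 really only adds the two cross conditions a[k-1] <= b[l] and b[l-1] <= a[k].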

    As noted in other solutions, we can use binary search to find the location of the cut. Importantly, we binary search the shorter array. Suppose that b is the shorter array (n <= m). Then for every cut l in b, the matching cut k = (m + n) // 2 - l stays within 0 .. m, so we can binary search the indices 0 .. n-1 in b for the location of the cut.
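    A quick sanity check of why the search runs over the shorter array, assuming the post's convention k + l = (m + n) // 2. The hypothetical helper k_in_range_for_all_l tests whether every cut l in 0 .. n (b has n elements, a has m) leaves a valid cut k in 0 .. m.

```python
def k_in_range_for_all_l(m: int, n: int) -> bool:
    """True if k = (m + n) // 2 - l lies in [0, m] for every cut l in [0, n]."""
    half = (m + n) // 2
    return all(0 <= half - l <= m for l in range(n + 1))


# Searching the shorter array (n <= m) always keeps k in bounds...
print(all(k_in_range_for_all_l(m, n) for m in range(20) for n in range(m + 1)))  # True
# ...but cutting a 10-element array while the other has only 2 does not.
print(k_in_range_for_all_l(2, 10))  # False
```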

    To solve this problem, we find the smallest l such that a[k-1] < b[l] (<- the predicate), where k = (m + n) // 2 - l. This predicate is monotonic: as l grows, k shrinks, so a[k-1] can only decrease while b[l] can only increase, and once a[k-1] < b[l] holds it keeps holding. So this approach can be done with predicate binary search.
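    To see the monotonicity concretely, this sketch tabulates the predicate for every l in 0 .. n-1 on a small made-up example (the helper name predicate_values and the input arrays are illustrative assumptions). The values come out as a run of False followed by a run of True, which is exactly the shape predicate binary search needs.

```python
from typing import List


def predicate_values(a: List[int], b: List[int]) -> List[bool]:
    """Evaluate p(l) = a[k-1] < b[l], with k = (m + n) // 2 - l, for l in 0 .. n-1.

    Assumes len(b) <= len(a), as in the post, so every index is in bounds.
    """
    m, n = len(a), len(b)
    half = (m + n) // 2
    return [a[half - l - 1] < b[l] for l in range(n)]


vals = predicate_values([1, 2, 3, 4, 5, 6, 7, 8], [2, 4, 6])
print(vals)  # [False, False, True]
# Monotonic: once a value is True, every later value is True too.
print(all(not x or y for x, y in zip(vals, vals[1:])))  # True
```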

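    Why does this work? At the found l, a[k-1] < b[l] holds by definition. If l > 0, the predicate is false at l - 1, where the matching cut in a is k + 1, so a[k] >= b[l-1] (otherwise the found l would not be the smallest). Since a and b are sorted, a[k-1] <= a[k] and b[l-1] <= b[l] as well. Therefore equation (1) is satisfied, so l is actually the desired cut!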
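    The argument can also be checked empirically. This sketch (hypothetical helper names, arbitrary random inputs including duplicates) finds the smallest true l with a plain linear scan, deliberately bypassing binary search so only the claim itself is tested, and asserts that eqn 1 holds at the resulting cut. A passing run is evidence, not a proof.

```python
import math
import random
from typing import List


def smallest_true_cut(a: List[int], b: List[int]) -> int:
    """Smallest l in 0 .. n-1 with a[k-1] < b[l] (k = (m + n) // 2 - l), else n.

    A linear scan is used on purpose, to test the claim independently of any
    binary-search implementation. Assumes len(b) <= len(a).
    """
    m, n = len(a), len(b)
    half = (m + n) // 2
    for l in range(n):
        if a[half - l - 1] < b[l]:
            return l
    return n


def satisfies_eqn1(a: List[int], b: List[int], k: int, l: int) -> bool:
    # Missing neighbours count as -inf on the left and +inf on the right.
    left = max(a[k - 1] if k > 0 else -math.inf, b[l - 1] if l > 0 else -math.inf)
    right = min(a[k] if k < len(a) else math.inf, b[l] if l < len(b) else math.inf)
    return left <= right


random.seed(0)
for _ in range(10_000):
    b = sorted(random.choices(range(10), k=random.randint(0, 5)))
    a = sorted(random.choices(range(10), k=random.randint(len(b), 8)))
    if not a:
        continue  # both arrays empty: no median to find
    l = smallest_true_cut(a, b)
    assert satisfies_eqn1(a, b, (len(a) + len(b)) // 2 - l, l), (a, b, l)
print("eqn 1 held for every random case")
```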

    Handling edge cases: Several edge cases are handled by the way we implement predicate binary search. Note that the cut can pass before the start of b, between two elements of b, or after the end of b.

    • If the predicate is true for all of b, the whole b is in the right part, and predicate binary search naturally returns 0.
    • If the predicate is false for all of b, the whole b is in the left part. In this case we make predicate binary search return n.
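    For the edge cases above, a small sketch (the name find_cut is hypothetical) showing where the cut lands when b lies entirely right of, entirely left of, or straddles the middle of a. It searches the half-open range [0, n), so the "false for all of b" case returns n without a special check, which is the behavior the post arranges for predicate_binary_search by returning hi + 1.

```python
from typing import List


def find_cut(a: List[int], b: List[int]) -> int:
    """Smallest l in [0, n) with a[k-1] < b[l] (k = (m + n) // 2 - l), or n if none.

    Assumes len(b) <= len(a), so every index into a is in bounds.
    """
    m, n = len(a), len(b)
    half = (m + n) // 2
    lo, hi = 0, n
    while lo < hi:
        mid = (lo + hi) // 2
        if a[half - mid - 1] < b[mid]:
            hi = mid
        else:
            lo = mid + 1
    return lo


print(find_cut([5, 6, 7], [8, 9]))     # 0: all of b is in the right part
print(find_cut([5, 6, 7], [1, 2]))     # 2: all of b is in the left part
print(find_cut([1, 3, 5, 7], [2, 4]))  # 1: the cut falls inside b
```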

    So most of the pieces of the solution are here, and the idea is:

    1. Do predicate binary search to find the appropriate cut.
    2. Find the extreme values: max(left part) and min(right part).
      • If l = 0, the cut passes before the start of b, so all of b is on the right.
      • If l = n, the cut passes after the end of b, so all of b is on the left.
    3. If the total number of elements m + n is even, return avg(max(left part), min(right part)). Else, return min(right part), since the right part holds the one extra element.
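    The three steps above can be sketched end to end as follows. This is a variation on the post's code, not the author's implementation: it swaps the separate l == 0 and l == n branches for -inf/+inf sentinels (via math.inf), and uses a half-open search so that "no true index" comes back as n directly. The function name median_sorted is hypothetical.

```python
import math
from typing import List


def median_sorted(a: List[float], b: List[float]) -> float:
    """Median of two sorted lists via the smallest-true-cut idea.

    Raises ValueError if both lists are empty.
    """
    if not a and not b:
        raise ValueError("median of two empty lists is undefined")
    if len(b) > len(a):
        a, b = b, a  # keep b as the shorter list
    m, n = len(a), len(b)
    half = (m + n) // 2  # size of the left part

    # Step 1: smallest l in [0, n) with a[k-1] < b[l], where k = half - l; n if none.
    lo, hi = 0, n
    while lo < hi:
        mid = (lo + hi) // 2
        if a[half - mid - 1] < b[mid]:
            hi = mid
        else:
            lo = mid + 1
    l, k = lo, half - lo

    # Step 2: extreme values of each part; sentinels stand in for missing elements.
    max_left = max(a[k - 1] if k > 0 else -math.inf, b[l - 1] if l > 0 else -math.inf)
    min_right = min(a[k] if k < m else math.inf, b[l] if l < n else math.inf)

    # Step 3: even total -> average the middle pair; odd -> right part holds the median.
    if (m + n) % 2 == 0:
        return (max_left + min_right) / 2
    return min_right


print(median_sorted([1, 3], [2]))     # 2
print(median_sorted([1, 2], [3, 4]))  # 2.5
```

    The sentinels assume numeric elements, which the averaging step needs anyway; in exchange, steps 2 and 3 need no special cases for cuts at either end of b.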

    Since we binary search only the shorter array, the runtime is O(log(min(m, n))), which meets the required O(log(m + n)) bound.
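    To make the bound concrete, a rough sketch that counts iterations of the same closed-range `while lo < hi` loop used in predicate_binary_search over a range of n indices (n being the shorter length), and checks them against ceil(log2(n)), computed exactly as (n - 1).bit_length(). The function count_iterations and its stand-in predicate `mid >= answer` are hypothetical instrumentation, not part of the solution.

```python
def count_iterations(n: int, answer: int) -> int:
    """Count loop iterations when searching [0, n - 1] for the first true index.

    The stand-in predicate `x >= answer` is monotonic; answer == n means
    the predicate is never true in range.
    """
    lo, hi, steps = 0, n - 1, 0
    while lo < hi:
        steps += 1
        mid = (lo + hi) // 2
        if mid >= answer:
            hi = mid
        else:
            lo = mid + 1
    return steps


for n in (1, 2, 3, 10, 1000):
    worst = max(count_iterations(n, answer) for answer in range(n + 1))
    bound = (n - 1).bit_length()  # equals ceil(log2(n)) for n >= 1
    assert worst <= bound
    print(f"n={n}: worst case {worst} iterations, ceil(log2(n)) = {bound}")
```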

    # Python 3
    class Solution:
        def predicate_binary_search(self, lo, hi, p):
            """Return the smallest index in [lo, hi] for which the predicate is true.

            The predicate p must have a breakpoint; it is false for all values
            up to a certain value, after which it is true for all values. In
            other words, p satisfies "if p is true for x, p is true for all
            values > x".
            If p is false for the whole range, return hi + 1.
            Assumes lo <= hi.
            """
            # Invariant: if p is true anywhere in the original range, the
            # smallest such index stays inside [lo, hi].
            while lo < hi:
                mid = (lo + hi) // 2
                if p(mid):
                    hi = mid  # mid may be the answer, so keep it in range
                else:
                    lo = mid + 1  # the answer, if any, is right of mid
            if not p(lo):
                # hi only moves when p(mid) is true, so hi is still the original
                # upper bound: p(x) is false for every x in the range
                return hi + 1
            return lo  # lo is the least x for which p(x) is true

        def findMedianSortedArrays(self, a, b):
            if len(a) == len(b) == 0:
                raise IndexError("median of two empty arrays is undefined")
            # Make sure b is the shorter array
            if len(b) > len(a):
                a, b = b, a
            m = len(a)
            n = len(b)
            # Find cut
            if n == 0:
                # Since b is empty, we consider the cut to be before the start of b
                # So 0 works
                l = 0
            else:
                # It is left as an exercise for the reader why the index into a is
                # always in bounds
                l = self.predicate_binary_search(
                        0, n - 1,
                        lambda l: a[(m+n)//2 - l - 1] < b[l])
            # Get contenders for max elts on the left and min elts on the right
            if l == 0:
                # Use this slightly awkward syntax to add elements, since some of them
                # may have index out of bounds
                left = [a[(m+n)//2 - 1]] if (m+n)//2 - 1 >= 0 else []
                right = (([a[(m+n)//2]] if (m+n)//2 < len(a) else [])
                         + ([b[0]] if len(b) > 0 else []))
            elif l == n:
                # Here k = (m + n) // 2 - n, which equals (m - n) // 2
                left = (([a[(m-n)//2 - 1]] if (m-n)//2 - 1 >= 0 else [])
                        + ([b[n-1]] if len(b) > 0 else []))
                right = [a[(m-n)//2]]  # Must exist since a is the longer array
            else:
                # 0 < l < n, so n >= 2 and every index below is in bounds
                k = (m + n) // 2 - l
                left = [a[k-1], b[l-1]]
                right = [a[k], b[l]]
            # Depending on whether the total number of elements is even or odd,
            # return the average of the two halves or just the min of the right half
            if (m + n) % 2 == 0:
                return (max(left) + min(right)) / 2
            else:
                return min(right)
