The idea for this is based on https://www.topcoder.com/community/data-science/data-science-tutorials/binary-search/. The article discusses how binary search can be done on a predicate. If the predicate p is monotonic, i.e. it satisfies the property:
If p is true for x, then it is true for all values > x
then we can use binary search to find the smallest value for which p is true.
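As a minimal sketch of that idea (the name smallest_true and the example predicate are illustrative, not taken from the article), binary search over a monotonic predicate could look like this:

```python
def smallest_true(lo, hi, p):
    """Return the smallest x in [lo, hi] with p(x) true, or hi + 1 if none.

    Assumes lo <= hi and that p is monotonic: once true, it stays true
    for every larger x.
    """
    while lo < hi:
        mid = (lo + hi) // 2
        if p(mid):
            hi = mid        # mid may be the answer, so keep it in range
        else:
            lo = mid + 1    # the answer, if any, lies to the right of mid
    return lo if p(lo) else hi + 1


# Example: smallest integer x in [0, 100] whose square is at least 50.
print(smallest_true(0, 100, lambda x: x * x >= 50))  # 8
```

The final check on p(lo) is what distinguishes "the last index is the answer" from "the predicate is false everywhere".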
This can be applied to the current problem. Recall that the median of an array divides it into two equal parts. Similarly, to find the median of two arrays, we are trying to find indices k and l that cut the arrays into two parts such that k + l = (m + n) // 2 and

      left_part          |        right_part
A[0], A[1], ..., A[k-1]  |  A[k], A[k+1], ..., A[m-1]
B[0], B[1], ..., B[l-1]  |  B[l], B[l+1], ..., B[n-1]

(diagram copied from the current solution, but variables changed)

where the left part <= the right part. In other words we would like
max(a[k-1], b[l-1]) <= min(a[k], b[l]) (eqn 1)
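To make the cut condition concrete, here is a small sketch that checks eqn 1 for a given pair (k, l). The helper name is_valid_cut and the sample arrays are made up for illustration. It uses slices for readability, although only the four elements next to the cut actually matter.

```python
def is_valid_cut(a, b, k, l):
    """Check eqn 1 for a cut that puts a[:k] and b[:l] in the left part."""
    left = a[:k] + b[:l]
    right = a[k:] + b[l:]
    # An empty side imposes no constraint.
    if not left or not right:
        return True
    return max(left) <= min(right)


a = [1, 3, 8, 9]
b = [2, 7]
# Here (m + n) // 2 = 3, so every candidate cut has k + l = 3.
print(is_valid_cut(a, b, 2, 1))  # True:  left {1, 3, 2}, right {8, 9, 7}
print(is_valid_cut(a, b, 3, 0))  # False: 8 is on the left, 2 is on the right
print(is_valid_cut(a, b, 1, 2))  # False: 7 is on the left, 3 is on the right
```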
As noted in other solutions, we can use binary search to find the location of the cut. An important note is that we do binary search on the shorter array. Suppose that b is the shorter array. Then, we can binary search the indices 0 .. n-1 in b for the location of the cut.
To solve this problem, we can find the smallest l such that a[k-1] < b[l] (<- the predicate). Note that this predicate is monotonic. (Think of k as a function of l: k = (m + n) // 2 - l, so as l increases, a[k-1] can only decrease while b[l] can only increase.) So, this approach can be done with predicate binary search.
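To see the monotonicity concretely, here is a small sketch on made-up arrays (the arrays and the names half and predicate are illustrative) that evaluates the predicate at every index of b:

```python
a = [1, 4, 6, 10, 12]   # longer array, length m
b = [2, 5, 8]           # shorter array, length n
m, n = len(a), len(b)
half = (m + n) // 2     # size of the left part


def predicate(l):
    k = half - l        # k is determined by l
    return a[k - 1] < b[l]


values = [predicate(l) for l in range(n)]
print(values)  # [False, False, True]

# Monotonic means: a run of False followed by a run of True.
assert values == sorted(values)
```

The first True is at l = 2, so k = 2: the left part is {1, 4, 2, 5} and the right part is {6, 10, 12, 8}.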
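Why does this work? Once such an l is found, we know that a[k-1] < b[l] because the predicate holds. We also know that a[k] >= b[l-1], else the predicate would already be true at l - 1 (where the cut in a moves to k + 1) and the found l would not be the smallest. Therefore equation (1) is satisfied, so l is actually the desired cut!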
Handling edge cases: Several edge cases are handled by the way we implement predicate binary search. Note that the cut can pass before the start of b, between two elements of b, or after the end of b.
- If the predicate is true for all of b, the whole b is in the right part, and predicate binary search naturally returns 0.
- If the predicate is false for all of b, the whole b is in the left part. In this case we make predicate binary search return n.
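The three cut positions can be sketched with a stripped-down search. find_cut is a hypothetical helper written for this illustration; it assumes both lists are sorted and len(a) >= len(b) >= 1, and it returns hi + 1 (which is n here) when the predicate is false everywhere.

```python
def find_cut(a, b):
    """Return l, the number of elements of b that fall in the left part."""
    m, n = len(a), len(b)
    half = (m + n) // 2
    lo, hi = 0, n - 1
    while lo < hi:
        mid = (lo + hi) // 2
        if a[half - mid - 1] < b[mid]:
            hi = mid
        else:
            lo = mid + 1
    # If the predicate fails even at the last candidate, it is false everywhere.
    return lo if a[half - lo - 1] < b[lo] else n


print(find_cut([1, 2, 3], [10, 11]))            # 0: all of b is on the right
print(find_cut([10, 11, 12], [1, 2]))           # 2: all of b is on the left
print(find_cut([1, 4, 6, 10, 12], [2, 5, 8]))   # 2: cut between 5 and 8
```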
So most of the pieces of the solution are here, and the idea is:
- Do predicate binary search to find the appropriate cut.
- Find the most extreme values of the left and right parts.
  - If l = 0, the cut passes before the start of b, so all of b is on the right.
  - If l = n, the cut passes after the end of b, so all of b is on the left.
- If the total number of elements is even, return avg(max(left part), min(right part)). Else, return min(right part).
Since we binary search over the shorter array, the runtime is O(log(min(m, n))), which is within the desired O(log(m + n)).
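As a rough illustration of the logarithmic cost, the sketch below counts how often the predicate is evaluated when searching indices 0 .. n-1. count_predicate_calls is a hypothetical helper, and it uses an always-false predicate so that the search keeps moving right until the range is exhausted.

```python
def count_predicate_calls(n):
    """Count predicate evaluations for a search over indices 0 .. n-1."""
    calls = 0

    def p(i):
        nonlocal calls
        calls += 1
        return False    # always false: the search never finds a true index

    lo, hi = 0, n - 1
    while lo < hi:
        mid = (lo + hi) // 2
        if p(mid):
            hi = mid
        else:
            lo = mid + 1
    p(lo)               # the final check after the loop
    return calls


for n in (1, 10, 1000, 1000000):
    print(n, count_predicate_calls(n))
# 1 1
# 10 4
# 1000 10
# 1000000 20
```

For this particular predicate the count is floor(log2(n)) + 1, so a million candidate indices need only 20 evaluations.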
# Python 3
class Solution:
    def predicate_binary_search(self, lo, hi, p):
        """
        Return the smallest index for which the predicate is true.

        The predicate p must have a breakpoint; it is false for all values
        up to a certain value, after which it is true for all values. In
        other words, p satisfies "if p is true for x, p is true for all
        values > x". If p is false for the whole range, return hi + 1.
        """
        # Invariant: [lo, hi] contains the smallest i such that p(i) is
        # true, if such an i exists.
        while lo < hi:
            mid = (lo + hi) // 2
            if p(mid):
                hi = mid
            else:
                lo = mid + 1
        if not p(lo):
            return hi + 1  # p(x) is false for all x in the range!
        return lo  # lo is the least x for which p(x) is true

    def findMedianSortedArrays(self, a, b):
        if len(a) == len(b) == 0:
            raise IndexError
        # Assume b is shorter
        if len(b) > len(a):
            a, b = b, a
        m = len(a)
        n = len(b)
        # Find cut
        if n == 0:
            # Since b is empty, we consider the cut to be before the start
            # of b. So 0 works
            l = 0
        else:
            # It is left as an exercise for the reader why the index into a
            # is always in bounds
            l = self.predicate_binary_search(
                0, n - 1, lambda l: a[(m+n)//2 - l - 1] < b[l])
        # Get contenders for max elts on the left and min elts on the right
        if l == 0:
            # Use this slightly awkward syntax to add elements, since some
            # of them may have index out of bounds
            left = [a[(m+n)//2 - 1]] if (m+n)//2 - 1 >= 0 else []
            right = (([a[(m+n)//2]] if (m+n)//2 < len(a) else [])
                     + ([b[0]] if len(b) > 0 else []))
        elif l == n:
            left = (([a[(m-n)//2 - 1]] if (m-n)//2 - 1 >= 0 else [])
                    + ([b[n-1]] if len(b) > 0 else []))
            right = [a[(m-n)//2]]  # Must exist since a is the longer array
        else:  # n >= 2
            k = (m + n) // 2 - l
            left = [a[k-1], b[l-1]]
            right = [a[k], b[l]]
        # Depending on whether the total number of elements is even or odd,
        # return the average of the two halves or just the min of the
        # right half
        if (m + n) % 2 == 0:
            return (max(left) + min(right)) / 2
        else:
            return min(right)
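One way to gain confidence in a solution like this is to compare it against a slow but obviously correct reference. The sketch below is not part of the original solution: median_by_merging, random_sorted and check are hypothetical helpers, and the reference simply sorts the combined list and uses statistics.median from the standard library.

```python
import random
import statistics


def median_by_merging(a, b):
    """Slow reference: sort all elements and take the middle."""
    return statistics.median(sorted(a + b))


def random_sorted(max_len, rng):
    """Return a sorted list of up to max_len small random integers."""
    return sorted(rng.randint(-20, 20) for _ in range(rng.randint(0, max_len)))


def check(solve, trials=1000, seed=0):
    """Compare solve(a, b) with the reference on random sorted inputs."""
    rng = random.Random(seed)
    for _ in range(trials):
        a, b = random_sorted(6, rng), random_sorted(6, rng)
        if not a and not b:
            continue  # the median of no elements is undefined
        assert solve(a, b) == median_by_merging(a, b), (a, b)
    return True


print(median_by_merging([1, 3], [2]))     # 2
print(median_by_merging([1, 2], [3, 4]))  # 2.5
```

With the Solution class above in scope, a call such as check(Solution().findMedianSortedArrays) would exercise empty arrays, duplicates, and both edge cases for l.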