Share one divide and conquer O(log(m+n)) method with clear description


  • 71
    S
    // using divide and conquer idea, each time find the mid of both arrays
    
    // Median of the values stored in two sorted arrays.
    //   A, m : first array and its element count  (valid indices 0 .. m-1)
    //   B, n : second array and its element count (valid indices 0 .. n-1)
    // Returns the middle value when m + n is odd, otherwise the mean of the
    // two middle values. Needs m + n >= 1: with no elements at all the rank
    // below is 0 and FindKth would index before the start of B.
    double findMedianSortedArrays(int A[], int m, int B[], int n) {
        const int total = m + n;

        // 1-based rank of the median (odd total) or of the lower of the two
        // middle elements (even total).
        const int lowerRank = (total + 1) / 2;
        const double lower =
            static_cast<double>(FindKth(A, 0, m - 1, B, 0, n - 1, lowerRank));

        if (total % 2 != 0) {
            return lower;
        }

        // Even total: the upper middle element is the next rank up.
        const double upper =
            static_cast<double>(FindKth(A, 0, m - 1, B, 0, n - 1, lowerRank + 1));
        return (lower + upper) / 2;
    }
        
        // Returns the k-th smallest value (k is 1-based) among the two sorted
        // ranges A[aL..aR] and B[bL..bR]; both ranges are inclusive and a range
        // with L > R is empty. k must lie between 1 and the combined element
        // count, otherwise the final lookup reads outside the ranges.
        //
        // Each round compares the two midpoints. Let x = aMid - aL and
        // y = bMid - bL, and suppose A[aMid] <= B[bMid] (the other case is
        // symmetric). Then:
        //   (1) at least x + 1 + y elements can be ordered before B[bMid], so
        //       when k <= x + y + 1 the answer is neither B[bMid] nor anything
        //       after it: drop B[bMid..bR];
        //   (2) at least (m - x - 1) + (n - y) elements can be ordered after
        //       A[aMid] (m, n being the current range sizes), so when
        //       k > x + y + 1 the answer is neither A[aMid] nor anything before
        //       it: drop A[aL..aMid] and look for rank k - (x + 1) instead.
        int FindKth(int A[], int aL, int aR, int B[], int bL, int bR, int k) {
            while (aL <= aR && bL <= bR) {
                const int aMid = (aL + aR) / 2;
                const int bMid = (bL + bR) / 2;
                const int leftCount = (aMid - aL) + (bMid - bL) + 1;
                const bool aMidNotLarger = A[aMid] <= B[bMid];

                if (k <= leftCount) {
                    // The larger midpoint and everything after it rank above k.
                    if (aMidNotLarger) {
                        bR = bMid - 1;
                    } else {
                        aR = aMid - 1;
                    }
                } else if (aMidNotLarger) {
                    // A[aL..aMid] all rank below k: skip them and shrink k.
                    k -= (aMid - aL) + 1;
                    aL = aMid + 1;
                } else {
                    // B[bL..bMid] all rank below k: skip them and shrink k.
                    k -= (bMid - bL) + 1;
                    bL = bMid + 1;
                }
            }

            // One range is exhausted; the answer sits in the other at offset k - 1.
            return (aL > aR) ? B[bL + k - 1] : A[aL + k - 1];
        }

  • 0
    M

    Neat solution. How long did it take to come up with?


  • 0
    S

    really a long time, ;)


  • -13
    S

    ..Really slow..But better than mine..My solution is out the range of "Accepted Solutions Runtime Distribution"..


  • 0
    G

    How did you come up with this solution?


  • 0
    M

    Brilliant solution!


  • 1
    I

    Wait so are we seriously supposed to come up with this sort of solution in the interview? it seems pretty hard..


  • 0
    Z

    Nice implementation! I tried a couple of rounds myself and kept failing on all kinds of corner cases. After checking out yours, I realized that calculating k relative to each array's start position keeps the calculation consistent.


  • 0
    R

    Hi All,

    I am not able to understand the time complexity: how is it O(log(m+n))? It looks like only 1/4 of the elements are eliminated each time, not 1/2. Can anyone clarify?

    Thanks


  • 0
    R

    @ranhar It is still O(log(m+n)); the base of the logarithm is just 4/3 instead of 2, and changing the base only changes a constant factor.


  • 0
    K

    @shichaotan - Can you explain your algorithm using examples? For instance, why do you start with k = (m+n+1)/2? Also, how do you derive the complexity O(log(m+n))? Thanks in advance.


  • 3

    I revised the code from @shichaotan to match the current C++ signature.

    I will explain the algorithm after the code.

    class Solution {
    public:
        // Median of the combined values of two sorted vectors.
        // Returns the middle value when the total element count is odd,
        // otherwise the mean of the two middle values.
        // Precondition: the vectors are not both empty (k would be 0 and
        // findKth would index before the start of nums2).
        double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
            const int len1 = static_cast<int>(nums1.size());
            const int len2 = static_cast<int>(nums2.size());
            const int k = (len1 + len2 + 1) / 2; // for odd total it is the mid one, for even it is the left mid
            const int num1 = findKth(nums1, 0, len1 - 1, nums2, 0, len2 - 1, k);
            if ((len1 + len2) & 1) return num1; // the sum of lengths is odd
    
            const int num2 = findKth(nums1, 0, len1 - 1, nums2, 0, len2 - 1, k + 1);
            // Add as doubles: num1 + num2 may not fit in an int.
            return (static_cast<double>(num1) + num2) / 2.0;
        }
    
    private:
        // Returns the k-th smallest value (1-based) among nums1[L1..R1] and
        // nums2[L2..R2]; ranges are inclusive and L > R means "empty".
        // Each call compares the two midpoints and discards either the larger
        // midpoint together with everything after it (k lies in the left part)
        // or the smaller midpoint together with everything before it (k lies
        // beyond the left part, so k is reduced by the number dropped).
        int findKth(const vector<int>& nums1, int L1, int R1, const vector<int>& nums2, int L2, int R2, int k) {
            if (L1 > R1) return nums2[L2 + k - 1];
            if (L2 > R2) return nums1[L1 + k - 1];
            // Both midpoints must halve their range; using R2 itself as mid2
            // would discard a single element per call and make the search linear.
            const int mid1 = L1 + (R1 - L1) / 2;
            const int mid2 = L2 + (R2 - L2) / 2;
    
            if (nums1[mid1] <= nums2[mid2]) {
                if (k <= (mid1 - L1) + (mid2 - L2) + 1) return findKth(nums1, L1, R1, nums2, L2, mid2 - 1, k);
                else return findKth(nums1, mid1 + 1, R1, nums2, L2, R2, k - (mid1 - L1) - 1);
            } else {
                if (k <= (mid1 - L1) + (mid2 - L2) + 1) return findKth(nums1, L1, mid1 - 1, nums2, L2, R2, k);
                else return findKth(nums1, L1, R1, nums2, mid2 + 1, R2, k - (mid2 - L2) - 1);
            }
        }
    };
    

    1 key point in findMedianSortedArrays()

    It is k = (m + n + 1) / 2, the 1-based position of the element we look for in the merged order.
    Here are two examples.
    If the total length of the two arrays is odd, the median is the ((m + n + 1) / 2)-th element. For a total length of 5: (5 + 1) / 2 = 3.

    1 2 3 4 5
        m
    

    If the total length of the two arrays is even, the ((m + n + 1) / 2)-th element is the left one of the two middle elements. For a total length of 4: (4 + 1) / 2 = 2 (integer division). We also need the (k + 1)-th element to calculate the median.

    1 2 3 4 
      m
    

    2 key points in findKth()

    The first one is the logic for doing binary search in both arrays.
    0_1480292099915_Screen Shot 2016-11-27 at 16.13.28.png

    The second one is how to define the size of the left part.
    The expression (mid1 - L1) + (mid2 - L2) + 1 actually means that there are two pointers, one in each array: one is the kth element and the other is the one that makes the recursion end.

    array1: 
    x x x x x x
      n
    
    array2:
    x x x x x x x x
          k
    
    "merged one array"
    x x x x x x x x x x x x x x
            k
    

    So we count the elements to the left of both pointers, plus the one where k is.

    I think the code is simpler if we test the condition on k first, and then compare nums1[mid1] with nums2[mid2].

    class Solution {
    public:
        // Median of the combined values of two sorted vectors.
        // Returns the middle value when the total element count is odd,
        // otherwise the mean of the two middle values.
        // Precondition: the vectors are not both empty (k would be 0 and
        // findKth would index before the start of nums2).
        double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
            const int len1 = static_cast<int>(nums1.size());
            const int len2 = static_cast<int>(nums2.size());
            const int k = (len1 + len2 + 1) / 2; // for odd total it is the mid one, for even it is the left mid
            const int num1 = findKth(nums1, 0, len1 - 1, nums2, 0, len2 - 1, k);
            if ((len1 + len2) & 1) return num1;
    
            const int num2 = findKth(nums1, 0, len1 - 1, nums2, 0, len2 - 1, k + 1);
            // Add as doubles: num1 + num2 may not fit in an int.
            return (static_cast<double>(num1) + num2) / 2.0;
        }
    
    private:
        // Returns the k-th smallest value (1-based) among nums1[L1..R1] and
        // nums2[L2..R2]; ranges are inclusive and L > R means "empty".
        // The outer test decides on which side of the left part k falls; the
        // inner comparison of the two midpoints decides which array gets cut.
        int findKth(const vector<int>& nums1, int L1, int R1, const vector<int>& nums2, int L2, int R2, int k) {
            if (L1 > R1) return nums2[L2 + k - 1];
            if (L2 > R2) return nums1[L1 + k - 1];
            // Both midpoints must halve their range; using R2 itself as mid2
            // would discard a single element per call and make the search linear.
            const int mid1 = L1 + (R1 - L1) / 2;
            const int mid2 = L2 + (R2 - L2) / 2;
    
            if (k <= (mid1 - L1) + (mid2 - L2) + 1) {
                // k is inside the left part: the larger midpoint and everything after it are too big.
                if (nums1[mid1] <= nums2[mid2]) return findKth(nums1, L1, R1, nums2, L2, mid2 - 1, k);
                else return findKth(nums1, L1, mid1 - 1, nums2, L2, R2, k);
            } else {
                // k is beyond the left part: the smaller midpoint and everything before it are too small.
                if (nums1[mid1] <= nums2[mid2]) return findKth(nums1, mid1 + 1, R1, nums2, L2, R2, k - (mid1 - L1) - 1);
                else return findKth(nums1, L1, R1, nums2, mid2 + 1, R2, k - (mid2 - L2) - 1);
            }
        }
    };
    

    Complexity

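    We do binary search in both arrays, so for divide and conquer the complexity is generally stated as O(log(m + n)). But actually it is a bit less than that: each call discards about half of one of the two arrays, so, as long as both mid1 and mid2 are true midpoints, there are at most about log2(m) + log2(n) calls.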


Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.