O(lg(m+n)) C++ solution using the kth smallest number


  • 62
    P
    class Solution {
    public:
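        int kth(int a[], int m, int b[], int n, int k) {  // k is 1-based
            if (m < n) return kth(b,n,a,m,k);   // keep a as the longer array
            if (n==0) return a[k-1];            // b is empty: the answer is in a
            if (k==1) return min(a[0],b[0]);    // smallest element of both arrays
    
            int j = min(n,k/2);   // take up to k/2 from b, never more than its length n
            int i = k-j;          // take the rest from a, so i + j == k
            // b[0..j-1] all rank below k: drop them; a beyond a[i-1] is not needed
            if (a[i-1] > b[j-1]) return kth(a,i,b+j,n-j,k-j);
            // otherwise a[0..i-1] all rank below k: drop them; b beyond b[j-1] is not needed
            return kth(a+i,m-i,b,j,k-i);
        }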
    
        double findMedianSortedArrays(int a[], int m, int b[], int n) {
            int k = (m+n)/2;
            int m1 = kth(a,m,b,n,k+1);
            if ((m+n)%2==0) {
                int m2 = kth(a,m,b,n,k);
                return ((double)m1+m2)/2.0;
            }
            return m1;
        }
    };

  • 4
    K

    Can you explain the kth function? I don't get why the B array always has to be shorter than or equal to A in length, and I also can't understand "int j = min(n,k/2);".

    If you could explain it, that would be great.


  • 0
    P

    Think about a case where A = {3,4,5,6,7} and B = {1,2}.

    1. if (m < n): we swap so that b is always the shorter array. Then only j has to be clamped to b's length, and i = k - j is guaranteed to fit inside a, which keeps the invariant i + j = k.
    2. j = min(n,k/2): j can't be larger than b's length.

    Try to think through the corner cases; you will be able to come up with this. :)


  • 0
    Z

    Finding the kth smallest number seems like a more general solution.


  • 0
    T

    I am still a little confused about how it works. Could you explain it more specifically?


  • 0
    Y

    Can anyone tell me the logic behind this line: "if (k==1) return min(a[0],b[0]);"?

    Thanks.


  • 0
    Q

    Actually, I also like to use find-kth-smallest for this question, but here we can get O(lg k) complexity for finding the kth smallest.

    int findKthSmallest(vector<int> &nums1, int i1, int i2, vector<int> &nums2, int j1, int j2, int k){
    	if (i2 - i1 > j2 - j1)  return findKthSmallest(nums2, j1, j2, nums1, i1, i2, k);
    	if (i1>i2)  return nums2[j1 + k - 1];
    	if (k == 1)  return std::min(nums1[i1], nums2[j1]);
    
    	if (i2-i1+1<k/2 || nums1[i1 + k/2 - 1]>nums2[j1 + k/2 - 1])
    		return findKthSmallest(nums1, i1, i2, nums2, j1 + k/2, j2, k-k/2);
    	else
    		return findKthSmallest(nums1, i1 + k/2, i2, nums2, j1, j2, k-k/2);
    }

  • 0
    C

    It's a good solution, but I think the time complexity should be O(logM+logN), not O(log(M+N)). Any comments? Thanks.


  • 13
    N

    Let me add an interpretation of the findKth function, based on my understanding.

    We have two arrays:

    nums1[0], nums1[1]....nums1[m - 1];

    nums2[0], nums2[1]....nums2[n - 1];

    The result after merging:

    num[0],num[1],num[2]...num[m + n - 1];

    Let's compare nums1[k / 2 - 1] and nums2[k / 2 - 1].

    If nums1[k / 2 - 1] < nums2[k / 2 - 1],

    then nums1[k / 2 - 1] and the elements to its left all come before the kth number of the merged array (num[k - 1]).
    Why?
    Assume that nums1[k / 2 - 1] == num[k - 1].

    Let's count the elements that are smaller than nums1[k / 2 - 1].

    Consider the extreme case: nums1[0]....nums1[k / 2 - 2] and nums2[0]...nums2[k / 2 - 2] are all smaller than nums1[k / 2 - 1]. Nothing else can be: the rest of nums1 comes after it, and nums2[k / 2 - 1] onward is larger.

    Even in this case, at most k / 2 - 1 + k / 2 - 1 <= k - 2 elements are smaller than nums1[k / 2 - 1], so nums1[k / 2 - 1] can be at most the (k - 1)th smallest number (num[k - 2]).
    That contradicts our assumption.

    So nums1[k / 2 - 1] and the elements to its left all come before the kth smallest number,
    and we can remove them to shrink the problem.
    The same idea applies when nums1[k / 2 - 1] > nums2[k / 2 - 1]: we remove elements from nums2 instead. (This assumes both arrays have at least k / 2 elements; the code below takes s1LeftCount = min(k / 2, m) elements from the shorter array and the remaining k - s1LeftCount from the other.)

    Correct me if I'm wrong. Thanks!

    Here is my accepted (AC) code:

    double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
    	int m = nums1.size(), n = nums2.size();
    	int k = (m + n) / 2;
    	int num1 = findKth(nums1, 0, m, nums2, 0, n, k + 1);
    	if ((n + m) % 2 == 0)
    	{
    		int num2 = findKth(nums1, 0, m, nums2, 0, n, k);
    		return (num1 + num2) / 2.0;
    	}
    	else return num1;
    }
    int findKth(vector<int> & nums1, int nums1_left, int nums1_right, vector<int> & nums2, int nums2_left, int nums2_right, int k)
    {
    	int m = nums1_right - nums1_left;
    	int n = nums2_right - nums2_left;
    	if (m > n) return findKth(nums2, nums2_left, nums2_right, nums1, nums1_left, nums1_right, k);
    	else if (m == 0)
    		return nums2[nums2_left + k - 1];
    	else if (k == 1)
    		return min(nums1[nums1_left], nums2[nums2_left]);
    	else {
    		int s1LeftCount = min (k / 2, m);
    		int s2LeftCount = k - s1LeftCount;
    		if (nums1[nums1_left + s1LeftCount - 1] == nums2[nums2_left + s2LeftCount - 1])
    			return nums1[nums1_left + s1LeftCount - 1];
    		else if (nums1[nums1_left + s1LeftCount - 1] < nums2[nums2_left + s2LeftCount - 1])
    			return findKth(nums1, nums1_left + s1LeftCount, nums1_right, nums2, nums2_left, nums2_right, k - s1LeftCount);
    		else
    			return findKth(nums1, nums1_left, nums1_right, nums2, nums2_left + s2LeftCount, nums2_right, k - s2LeftCount);
    	}
    }
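One step the explanation above does not cover is the early return when the two probed elements are equal. A short justification (my reading of the code, with s_1 = s1LeftCount, s_2 = s2LeftCount, and indices taken relative to the current left ends):

```latex
\text{Let } s_1 + s_2 = k,\quad v = \texttt{nums1}[s_1 - 1] = \texttt{nums2}[s_2 - 1].
\begin{aligned}
&\#\{x \le v\} \ge s_1 + s_2 = k
  &&\Rightarrow\ \text{the } k\text{th smallest} \le v,\\
&\#\{x < v\} \le (s_1 - 1) + (s_2 - 1) = k - 2
  &&\Rightarrow\ \text{the } k\text{th smallest} \ge v,\\
&\therefore\ \text{the } k\text{th smallest} = v.
\end{aligned}
```

The second line holds because both arrays are sorted, so any element strictly smaller than v must sit before position s_1 - 1 in nums1 or before position s_2 - 1 in nums2.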
    

  • 0
    C

    That follows directly: k = 1 means we want the smallest element. a[0] is the smallest in array a and b[0] is the smallest in array b, so the smaller of a[0] and b[0] is the smallest in a ∪ b.


  • 0
    C

    Did you consider the case where one array's length is less than k/2?


  • 1
    A

    I think it's O(log(M+N)): in each recursion we either discard about k/2 elements (so k roughly halves) or empty the shorter array, which ends the recursion, and k is at most M+N.
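A short worked comparison of the two bounds discussed above (logs base 2, assuming M, N ≥ 2 so both logs are at least 1), plus the recurrence behind the reply's argument for the original `kth`:

```latex
\begin{aligned}
\log M + \log N &= \log(MN) \le \log\big((M+N)^2\big) = 2\log(M+N),\\
\log(M+N) &\le \log\big(2\max(M,N)\big) = 1 + \log\max(M,N) \le 1 + \log M + \log N,\\
T(k) &\le T(\lceil k/2 \rceil) + O(1) \;\Rightarrow\; T(k) = O(\log k) \subseteq O(\log(M+N)).
\end{aligned}
```

So O(logM+logN) and O(log(M+N)) describe the same growth. The third line uses the fact that each call of `kth` either empties the shorter array or leaves a new k of at most ⌈k/2⌉, and k ≤ M + N.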


  • 0
    Q

    Here is an iterative findKth, which is IMHO much cleaner (from my post "The cleanest findKth", written by me).
    Let me excerpt it here:

    #include <algorithm>
    #include <cstddef>
    #include <iterator>
    
    template <typename RandomIt, typename T = typename std::iterator_traits<RandomIt>::value_type>
    T findKth(RandomIt first1, RandomIt last1, RandomIt first2, RandomIt last2, size_t k) {
        // k is 0-based. The key idea: choose a prefix of each array so that the sum of
        // their lengths is as large as possible (up to k + 1), and skip the smaller one.
        size_t n1 = last1 - first1, n2 = last2 - first2;
        ++k;
        for (size_t a1, a2; k > 1 && n1 > 0 && n2 > 0; ) {
            a1 = std::min(k / 2, n1);
            a2 = std::min(k - a1, n2);
            if (first1[a1-1] <= first2[a2-1]) { first1 += a1, n1 -= a1, k -= a1; }
            else { first2 += a2, n2 -= a2, k -= a2; }
        }
        --k;
        if (n1 == 0) return first2[k];
        if (n2 == 0) return first1[k];
        return std::min(*first1, *first2);
    }

  • 0
    J

    Here is my one-to-one C++ translation:

    class Solution {
        typedef decltype(vector<int>().cbegin()) random_it;
    
        int kth(random_it it_a, int size_a, random_it it_b, int size_b,
                int offset) {      // offset is 1-based
            if (size_a < size_b) { // size_a always >= size_b
                return kth(it_b, size_b, it_a, size_a, offset);
            }
            if (size_b == 0) { // obvious case
                return *(it_a + (offset - 1));
            }
            if (offset == 1) { // cannot reduce more
                return min(*it_a, *it_b);
            }
    
            // the 'n' prefix means the count is 1-based
            int nguess_b = min(size_b, offset / 2);
            // to add up to offset, nguess_a has to be this
            int nguess_a = offset - nguess_b;
            // we can safely say that the b prefix lies before the kth element,
            // since the a prefix ends with a larger element
            if (*(it_a + (nguess_a - 1)) > *(it_b + (nguess_b - 1))) {
                // remove part b and change offset
                return kth(it_a, size_a, it_b + nguess_b, size_b - nguess_b,
                           offset - nguess_b);
            }
            // similarly, remove part a
            return kth(it_a + nguess_a, size_a - nguess_a, it_b, size_b,
                       offset - nguess_a);
        }
    
      public:
        double findMedianSortedArrays(vector<int> &nums1, vector<int> &nums2) {
            int target = (nums1.size() + nums2.size()) / 2;
            int a = kth(nums1.cbegin(), nums1.size(), nums2.cbegin(), nums2.size(),
                        target + 1);
            if ((nums1.size() + nums2.size()) % 2 == 0) {
                int b = kth(nums1.cbegin(), nums1.size(), nums2.cbegin(),
                            nums2.size(), target);
                return (a + b) / 2.0;
            }
            return a;
        }
    };
    
