# O(lg(m+n)) C++ solution using kth smallest number

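• ``````class Solution {
public:
    // Returns the k-th smallest (1-based) element of two sorted arrays.
    int kth(int a[], int m, int b[], int n, int k) {
        if (m < n) return kth(b, n, a, m, k);  // keep a as the longer array
        if (n == 0) return a[k - 1];           // b is exhausted
        if (k == 1) return min(a[0], b[0]);

        int j = min(n, k / 2);  // prefix length taken from b
        int i = k - j;          // prefix length taken from a, so i + j == k
        // a[i-1] > b[j-1]: the first j elements of b rank below k, drop them.
        if (a[i - 1] > b[j - 1]) return kth(a, i, b + j, n - j, k - j);
        // Otherwise the first i elements of a rank below k, drop them.
        return kth(a + i, m - i, b, j, k - i);
    }

    double findMedianSortedArrays(int a[], int m, int b[], int n) {
        int k = (m + n) / 2;
        int m1 = kth(a, m, b, n, k + 1);
        if ((m + n) % 2 == 0) {
            // Even total: average the two middle elements.
            int m2 = kth(a, m, b, n, k);
            return ((double)m1 + m2) / 2.0;
        }
        return m1;
    }
};``````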

• Can you explain the kth function? I don't get why array B always needs to be no longer than A, and I also don't understand `int j = min(n,k/2);`.

If you can explain it, that would be great.

• Think about a case where A = {3,4,5,6,7} and B = {1,2}.

1. `if (m < n)`: we swap so that a is always the longer array, because we are trying to maintain the invariant i + j = k. With b as the shorter array, i = k - j never exceeds m.
2. `j = min(n,k/2)`: j can't be more than the size of b.

Try to think about the corner cases and you will be able to come up with this. :)
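To see why the swap matters on the example above, here is a minimal sketch that computes i and j the same way kth does and checks that a[i-1] and b[j-1] are valid reads. The names `Split`, `chooseSplit`, and `splitInBounds` are hypothetical helpers for illustration, not part of the posted solution.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical helper type: the prefix lengths taken from a and b.
struct Split {
    int i;  // prefix length taken from a
    int j;  // prefix length taken from b
};

// Computes the split the way kth() does, for lengths m (a) and n (b)
// and a 1-based rank k. Assumes 1 <= k <= m + n.
Split chooseSplit(int m, int n, int k) {
    assert(k >= 1 && k <= m + n);
    int j = std::min(n, k / 2);
    int i = k - j;  // invariant: i + j == k
    return {i, j};
}

// True when both a[i - 1] and b[j - 1] would be valid reads.
bool splitInBounds(int m, int n, int k) {
    Split s = chooseSplit(m, n, k);
    return s.i >= 1 && s.i <= m && s.j >= 1 && s.j <= n;
}
```

With m = 5 and n = 2 (A and B above), every k from 2 to 7 gives an in-bounds split. With the lengths swapped (m = 2, n = 5) and k = 6, we get j = 3 and i = 3 > m, so a[i-1] would read past the end of the shorter array. That is the case the `if (m < n)` swap avoids.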

• kth smallest number seems like a more general solution.

• I am still a little confused about how it works. Could you explain it more specifically?

• Can anyone tell me the logic behind this line: `if (k==1) return min(a[0],b[0]);`?

Thanks

• Actually, I also like to use find-kth-smallest for this question. Here we can get O(lg k) complexity for finding the kth smallest, since each call discards k/2 elements.

``````// [i1, i2] and [j1, j2] are inclusive index ranges; k is 1-based.
int findKthSmallest(vector<int> &nums1, int i1, int i2, vector<int> &nums2, int j1, int j2, int k){
    // Keep nums1 as the shorter range.
    if (i2 - i1 > j2 - j1)  return findKthSmallest(nums2, j1, j2, nums1, i1, i2, k);
    if (i1 > i2)  return nums2[j1 + k - 1];
    if (k == 1)  return std::min(nums1[i1], nums2[j1]);

    // Drop k/2 elements from nums2 if nums1 is too short or its candidate is larger.
    if (i2 - i1 + 1 < k/2 || nums1[i1 + k/2 - 1] > nums2[j1 + k/2 - 1])
        return findKthSmallest(nums1, i1, i2, nums2, j1 + k/2, j2, k - k/2);
    else
        return findKthSmallest(nums1, i1 + k/2, i2, nums2, j1, j2, k - k/2);
}``````

• It's a good solution but I think time complexity should be O(logM+logN) not O(log(M+N)). Any comment? Thanks.

• Let me add an interpretation of the find-kth function based on my understanding.

We have two arrays:

nums1[0], nums1[1]....nums1[m - 1];

nums2[0], nums2[1]....nums2[n - 1];

the result after merging:

num[0],num[1],num[2]...num[m + n - 1];

Let's compare `nums1[k / 2 - 1]` and `nums2[k / 2 - 1]`.

If `nums1[k / 2 - 1] < nums2[k / 2 - 1]`,

then `nums1[k / 2 - 1]` and the `elements to its left` must come before the `k`th number in the num array (`num[k - 1]`).
Why?
Assume that `nums1[k / 2 - 1] == num[k - 1]`.

Let's count the number of elements that are smaller than nums1[k / 2 - 1].

Consider the extreme case: nums1[0]....nums1[k / 2 - 2] and nums2[0]...nums2[k / 2 - 2] are all smaller than nums1[k / 2 - 1]. No other element of nums2 can be smaller, because nums2[k / 2 - 1] is already larger.

Even in this case, we have at most k / 2 - 1 + k / 2 - 1 ≤ `k - 2` elements smaller than nums1[k / 2 - 1], so nums1[k / 2 - 1] can be at most the `(k - 1)`th smallest number (num[k - 2]).
This contradicts our assumption.

So now we can say: `nums1[k / 2 - 1] and the elements to its left must come before the kth smallest number.`
We can therefore remove the elements in this range and shrink the problem.
The same idea applies when nums1[k / 2 - 1] > nums2[k / 2 - 1]: we can remove the elements in nums2.

Correct me if I'm wrong. Thanks.

Here is my AC code:

``````double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
    int m = nums1.size(), n = nums2.size();
    int k = (m + n) / 2;
    int num1 = findKth(nums1, 0, m, nums2, 0, n, k + 1);
    if ((n + m) % 2 == 0)
    {
        int num2 = findKth(nums1, 0, m, nums2, 0, n, k);
        // Convert before adding so the sum cannot overflow int.
        return (static_cast<double>(num1) + num2) / 2.0;
    }
    else return num1;
}
// Ranges are half-open: [nums1_left, nums1_right) and [nums2_left, nums2_right); k is 1-based.
int findKth(vector<int> & nums1, int nums1_left, int nums1_right, vector<int> & nums2, int nums2_left, int nums2_right, int k)
{
    int m = nums1_right - nums1_left;
    int n = nums2_right - nums2_left;
    // Keep nums1 as the shorter range.
    if (m > n) return findKth(nums2, nums2_left, nums2_right, nums1, nums1_left, nums1_right, k);
    else if (m == 0)
        return nums2[nums2_left + k - 1];
    else if (k == 1)
        return min(nums1[nums1_left], nums2[nums2_left]);
    else {
        int s1LeftCount = min (k / 2, m);
        int s2LeftCount = k - s1LeftCount;
        if (nums1[nums1_left + s1LeftCount - 1] == nums2[nums2_left + s2LeftCount - 1])
            return nums1[nums1_left + s1LeftCount - 1];
        else if (nums1[nums1_left + s1LeftCount - 1] < nums2[nums2_left + s2LeftCount - 1])
            return findKth(nums1, nums1_left + s1LeftCount, nums1_right, nums2, nums2_left, nums2_right, k - s1LeftCount);
        else
            return findKth(nums1, nums1_left, nums1_right, nums2, nums2_left + s2LeftCount, nums2_right, k - s2LeftCount);
    }
}
``````
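One way to sanity-check the discard argument above is to compare it against a plain merge. The sketch below is only a brute-force check, not part of the accepted solution; `discardRuleHolds` is a hypothetical helper name. It assumes both inputs are sorted and only tests ranks k for which both arrays have at least k / 2 elements.

```cpp
#include <algorithm>
#include <cassert>
#include <iterator>
#include <vector>

// Hypothetical helper (not from the thread): checks the discard rule for
// every rank k where both prefixes of length k / 2 exist.
bool discardRuleHolds(const std::vector<int>& nums1, const std::vector<int>& nums2) {
    assert(std::is_sorted(nums1.begin(), nums1.end()));
    assert(std::is_sorted(nums2.begin(), nums2.end()));

    // Reference: merge both arrays, so merged[k - 1] is the k-th smallest.
    std::vector<int> merged;
    std::merge(nums1.begin(), nums1.end(), nums2.begin(), nums2.end(),
               std::back_inserter(merged));

    const int total = static_cast<int>(merged.size());
    for (int k = 2; k <= total; ++k) {
        const int half = k / 2;
        if (half > static_cast<int>(nums1.size()) ||
            half > static_cast<int>(nums2.size())) {
            continue;  // the rule as stated needs both prefixes to exist
        }
        // Precondition of the rule: nums1[k/2 - 1] < nums2[k/2 - 1].
        if (nums1[half - 1] < nums2[half - 1]) {
            // Conclusion: nums1[k/2 - 1] is among the first k - 1 merged
            // elements, so it is no larger than merged[k - 2].
            if (nums1[half - 1] > merged[k - 2]) return false;
        }
    }
    return true;
}
```

Calling it with the arguments swapped checks the symmetric case, where the prefix of nums2 is the one being removed. Note that with duplicate values the discarded element can be equal to the kth number, which is why the check uses "no larger than" rather than "strictly smaller than".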

• That one is straightforward: k = 1 means we want the smallest element. a[0] is the smallest in array a and b[0] is the smallest in array b, so the smaller of a[0] and b[0] is the smallest in a U b.

• Did you consider the case where one of the arrays has fewer than k/2 elements?

• I think it's O(log(M+N)): each recursive call discards about k/2 of the (M+N) elements (or exhausts the shorter array), and k is at most (M+N).
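One rough way to look at the complexity question is to count the discard steps. The sketch below is an iterative variant of the discard rule discussed in this thread, with a step counter added; `KthResult` and `kthWithSteps` are hypothetical names, k is 1-based, and both inputs are assumed to be sorted.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical result type: the k-th smallest value and the number of
// discard steps it took to find it.
struct KthResult {
    int value;
    int steps;
};

// Finds the k-th smallest (1-based) element of two sorted vectors,
// discarding a prefix of one of them on every iteration.
KthResult kthWithSteps(const std::vector<int>& a, const std::vector<int>& b,
                       std::size_t k) {
    assert(k >= 1 && k <= a.size() + b.size());
    std::size_t pa = 0, pb = 0;  // number of elements already discarded
    int steps = 0;
    while (k > 1 && pa < a.size() && pb < b.size()) {
        std::size_t ta = std::min(k / 2, a.size() - pa);   // prefix of a
        std::size_t tb = std::min(k - ta, b.size() - pb);  // prefix of b
        if (a[pa + ta - 1] <= b[pb + tb - 1]) {
            pa += ta;  // the whole prefix of a ranks below k
            k -= ta;
        } else {
            pb += tb;  // the whole prefix of b ranks below k
            k -= tb;
        }
        ++steps;
    }
    if (pa == a.size()) return {b[pb + k - 1], steps};
    if (pb == b.size()) return {a[pa + k - 1], steps};
    return {std::min(a[pa], b[pb]), steps};  // here k == 1
}
```

For two arrays of 65536 elements each (the even and the odd numbers) and k = 65536, this takes 16 steps, which is log2(k), because k halves on every iteration. That is one data point rather than a proof, but it matches the O(log k) view, and k is at most M+N.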

• Here is an iterative findKth, which is IMHO much cleaner: "The cleanest findKth", which I wrote myself.
Let me excerpt it here:

``````#include <algorithm>
#include <cstddef>
#include <iterator>

template <typename RandomIt, typename T = typename std::iterator_traits<RandomIt>::value_type>
T findKth(RandomIt first1, RandomIt last1, RandomIt first2, RandomIt last2, size_t k) {
    // k is 0-based. The key point is to choose two prefix sub-arrays whose total
    // length is as large as possible (up to k + 1), and skip the one whose last
    // element is smaller.
    size_t n1 = last1 - first1, n2 = last2 - first2;
    ++k;  // switch to a 1-based rank inside the loop
    for (size_t a1, a2; k > 1 && n1 > 0 && n2 > 0; ) {
        a1 = std::min(k / 2, n1);
        a2 = std::min(k - a1, n2);
        if (first1[a1-1] <= first2[a2-1]) { first1 += a1, n1 -= a1, k -= a1; }
        else { first2 += a2, n2 -= a2, k -= a2; }
    }
    --k;  // back to 0-based
    if (n1 == 0) return first2[k];
    if (n2 == 0) return first1[k];
    return std::min(*first1, *first2);
}``````

• Here is my 1-to-1 translated C++ version:

``````class Solution {
    typedef decltype(vector<int>().cbegin()) random_it;

    int kth(random_it it_a, int size_a, random_it it_b, int size_b,
            int offset) {      // offset is 1-based
        if (size_a < size_b) { // size_a always >= size_b
            return kth(it_b, size_b, it_a, size_a, offset);
        }
        if (size_b == 0) { // obvious case
            return *(it_a + (offset - 1));
        }
        if (offset == 1) { // cannot reduce more
            return min(*it_a, *it_b);
        }

        // the 'n' prefix indicates a 1-based count
        int nguess_b = min(size_b, offset / 2);
        // in order to add up to the offset, nguess_a has to be this
        int nguess_a = offset - nguess_b;
        // we can safely say that the vector_b part is definitely ahead of the
        // kth element, since part a ends with a bigger value than part b
        if (*(it_a + (nguess_a - 1)) > *(it_b + (nguess_b - 1))) {
            // remove part b and change offset
            return kth(it_a, size_a, it_b + nguess_b, size_b - nguess_b,
                       offset - nguess_b);
        }
        // similarly, remove part a
        return kth(it_a + nguess_a, size_a - nguess_a, it_b, size_b,
                   offset - nguess_a);
    }

public:
    double findMedianSortedArrays(vector<int> &nums1, vector<int> &nums2) {
        int target = (nums1.size() + nums2.size()) / 2;
        int a = kth(nums1.cbegin(), nums1.size(), nums2.cbegin(), nums2.size(),
                    target + 1);
        if ((nums1.size() + nums2.size()) % 2 == 0) {
            int b = kth(nums1.cbegin(), nums1.size(), nums2.cbegin(),
                        nums2.size(), target);
            // convert before adding so the sum cannot overflow int
            return (static_cast<double>(a) + b) / 2.0;
        }
        return a;
    }
};
``````
