Java 9ms heap queue solution, k log(k)


  • 53

    First, we pair the first k elements of nums1 with nums2[0] as the starting pairs, so the heap holds (0,0), (1,0), (2,0), ..., (k-1,0).
    Each time we pick the pair with the minimum sum, we push a new pair with the second index incremented by 1; e.g., after picking (0,0), we put back (0,1). Therefore, the heap always holds at most min(k, len(nums1)) elements.

    import java.util.*;

    public class Solution {
        class Pair{
            int[] pair;
            int idx; // current index to nums2
            long sum;
            Pair(int idx, int n1, int n2){
                this.idx = idx;
                pair = new int[]{n1, n2};
                sum = (long) n1 + (long) n2;
            }
        }
        class CompPair implements Comparator<Pair> {
            public int compare(Pair p1, Pair p2){
                return Long.compare(p1.sum, p2.sum);
            }
        }
        public List<int[]> kSmallestPairs(int[] nums1, int[] nums2, int k) {
            List<int[]> ret = new ArrayList<>();
            if (nums1==null || nums2==null || nums1.length ==0 || nums2.length ==0) return ret;
            int len1 = nums1.length, len2=nums2.length;  
    
            // diamond avoids the raw type; initial capacity must be at least 1
            PriorityQueue<Pair> q = new PriorityQueue<>(Math.max(1, Math.min(k, len1)), new CompPair());
            for (int i=0; i<nums1.length && i<k ; i++) { // only need first k number in nums1 to start  
                q.offer( new Pair(0, nums1[i],nums2[0]) );
            }
            for (int i=1; i<=k && !q.isEmpty(); i++) { // get the first k sums
                Pair p = q.poll(); 
                ret.add( p.pair );
                if (p.idx < len2 -1 ) { // get to next value in nums2
                    int next = p.idx+1;
                    q.offer( new Pair(next, p.pair[0], nums2[next]) );
                }
            }
            return ret;
        }
    }
    

  • 0

    Well done. I made a mistake in how I operated the queue: each time I fetched the next minimum from the queue, I put two pairs back, (i + 1, j) and (i, j + 1). That led to duplicates, which I had to filter out with a Set, which was not good. It passed, though :)
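
    For comparison, the two-neighbour variant described above can be sketched roughly as follows. This is only an illustration, assuming both arrays are sorted in ascending order; the class name PairsWithVisitedSet and the helper key are made up for the example and are not from the original post.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Set;

// Hypothetical class name, used only for this illustration.
class PairsWithVisitedSet {
    static List<int[]> kSmallestPairs(int[] nums1, int[] nums2, int k) {
        List<int[]> result = new ArrayList<>();
        if (nums1.length == 0 || nums2.length == 0 || k <= 0) return result;

        // Heap entries are index pairs {i, j}, ordered by nums1[i] + nums2[j].
        PriorityQueue<int[]> heap = new PriorityQueue<>((a, b) ->
                Long.compare((long) nums1[a[0]] + nums2[a[1]],
                             (long) nums1[b[0]] + nums2[b[1]]));
        Set<Long> seen = new HashSet<>();
        heap.offer(new int[]{0, 0});
        seen.add(key(0, 0, nums2.length));

        while (result.size() < k && !heap.isEmpty()) {
            int[] cur = heap.poll();
            int i = cur[0], j = cur[1];
            result.add(new int[]{nums1[i], nums2[j]});
            // Both neighbours are candidates. Without the set, (i + 1, j + 1)
            // would be pushed twice: once via (i + 1, j) and once via (i, j + 1).
            if (i + 1 < nums1.length && seen.add(key(i + 1, j, nums2.length))) {
                heap.offer(new int[]{i + 1, j});
            }
            if (j + 1 < nums2.length && seen.add(key(i, j + 1, nums2.length))) {
                heap.offer(new int[]{i, j + 1});
            }
        }
        return result;
    }

    // Encodes an index pair as a single long so it can live in a HashSet.
    private static long key(int i, int j, int width) {
        return (long) i * width + j;
    }
}
```

    With nums1 = [1,7,11], nums2 = [2,4,6] and k = 3 this returns [1,2], [1,4], [1,6]. The extra Set is the cost of pushing both neighbours; the solution above avoids it by only ever advancing the nums2 index.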


  • 19

    @anton4 To make the idea clear, you can think of this as the problem of merging k sorted arrays.
    array1 = (0,0),(0,1),(0,2),....
    array2 = (1,0),(1,1),(1,2),....
    ....
    arrayk = (k-1,0),(k-1,1),(k-1,2),....
    So each time the array with the smallest current sum is chosen, you only advance that array's index to its next element.
    Make sense?
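
    A minimal sketch of that merge view, assuming both inputs are sorted ascending: the heap holds one cursor {row, column} per array, and the method returns index pairs in the same (i,j) notation as above. The class and method names are hypothetical, chosen only for this example.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;

// Hypothetical class name, used only for this illustration.
class MergeRowsView {
    // Row i is the sorted sequence nums1[i] + nums2[0], nums1[i] + nums2[1], ...
    // Returns the first k index pairs {i, j} in merge order.
    static List<int[]> kSmallestIndexPairs(int[] nums1, int[] nums2, int k) {
        List<int[]> order = new ArrayList<>();
        if (nums1.length == 0 || nums2.length == 0 || k <= 0) return order;

        // One cursor {row, column} per row, ordered by the sum it points at.
        PriorityQueue<int[]> cursors = new PriorityQueue<>((a, b) ->
                Long.compare((long) nums1[a[0]] + nums2[a[1]],
                             (long) nums1[b[0]] + nums2[b[1]]));
        int rows = Math.min(k, nums1.length); // later rows cannot be reached within k picks
        for (int i = 0; i < rows; i++) {
            cursors.offer(new int[]{i, 0});
        }

        while (order.size() < k && !cursors.isEmpty()) {
            int[] c = cursors.poll();
            order.add(new int[]{c[0], c[1]}); // copy before the cursor is advanced
            if (c[1] + 1 < nums2.length) {
                c[1]++;               // advance only within the chosen row
                cursors.offer(c);
            }
        }
        return order;
    }
}
```

    For nums1 = [1,7,11], nums2 = [2,4,6] and k = 4 the picks come out as (0,0), (0,1), (0,2), (1,0): the first row is drained before the second row's head, with sum 9, becomes the minimum.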


  • 1

    Great solution. Below is my implementation (7 ms, 96.5%):

    import java.util.*;

    public class Solution {
        private class Pair implements Comparable<Pair>{
            int[] a1, a2;
            int i1, i2;
            
            public Pair(int j1, int j2, int[] arr1, int[] arr2) {
                a1 = arr1;
                a2 = arr2;
                i1 = j1;
                i2 = j2;
            }
            
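            public long sum() {
                // widen to long so large values cannot overflow
                return (long) a1[i1] + a2[i2];
            }
            
            public int compareTo(Pair p) {
                // Long.compare avoids the overflow that subtraction can cause
                return Long.compare(this.sum(), p.sum());
            }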
        }
        
        public List<int[]> kSmallestPairs(int[] nums1, int[] nums2, int k) {
            int n1 = Math.min(nums1.length, k), n2 = nums2.length;
            List<int[]> ans = new ArrayList<>();
            if (n1 <= 0 || n2 == 0) return ans;
            PriorityQueue<Pair> pq = new PriorityQueue<>();
            for (int i = 0; i < n1; pq.offer(new Pair(i++, 0, nums1, nums2)));
            
            for (int t = 0; t < k && !pq.isEmpty(); t++) {
                Pair p = pq.poll();
                ans.add(new int[]{nums1[p.i1], nums2[p.i2]});
                if (++p.i2 < n2) pq.offer(p);
            }
            return ans;
        }
    }
    

  • 1

    Well done, simple and effective. Here is my C++ version:

    vector<pair<int, int>> kSmallestPairs(vector<int>& a1, vector<int>& a2, int k)
    {
    	auto compare = [&a1, &a2](pair<int, int>i, pair<int, int>j)
    	{
    	return a1[i.first] + a2[i.second] < a1[j.first] + a2[j.second];
    	};
    	multiset<pair<int, int>, decltype(compare)> minSet(compare);
    	vector<pair<int, int>> rtn;
    	int len = k > a1.size() ?  a1.size() : k;
    	for (int i = 0; i < len; i++)
    		minSet.insert(make_pair(i, 0));
    
    	while(!minSet.empty() && rtn.size() < k)
    	{
    		auto ans = *minSet.begin();
    		rtn.push_back(make_pair(a1[ans.first], a2[ans.second]));
    		minSet.erase(minSet.begin());
    		if (ans.second + 1 < a2.size())
    			minSet.insert(make_pair(ans.first, ans.second + 1));
    	}
    	return rtn;
    }
    

  • 1

    Great solution to reduce the size of the heap to just k. Thanks for sharing!


  • 0

    Any idea how to prove the correctness? Intuitively it makes perfect sense to me, but I have no idea how to rigorously prove it. Thx.


  • 0

    @yubad2000 It's awesome that you can think in this way. Very inspiring!


  • 0

    C++ version:

    vector<pair<int, int>> kSmallestPairs(vector<int>& nums1, vector<int>& nums2, int k) {
    	vector<pair<int, int>> res;
    	if(nums1.empty()||nums2.empty()) return res;
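    	// capture by reference: capturing by value copies both vectors into the comparator
    	auto comp = [&nums1, &nums2](const pair<int, int>& a, const pair<int, int>& b) {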
    		return nums1[a.first] + nums2[a.second] > nums1[b.first] + nums2[b.second];
    	};
    	priority_queue<pair<int, int>, vector<pair<int, int>>, decltype(comp)> q(comp);
    	for (int i = 0;i < nums1.size() && i < k;i++) {
    		q.push({ i, 0 });
    	}
    	for (int i = 0;i < k&&(!q.empty());i++) {
    		pair<int, int> temp = q.top();
    		q.pop();
    		res.push_back({nums1[temp.first], nums2[temp.second]});
    		if (temp.second < nums2.size() - 1) q.push({ temp.first,temp.second + 1 });
    	}
    	return res;
    }
    

  • 0

    This solution and the top-rated one rely on an important assumption, which is correct:

    The k smallest-sum pairs must lie within nums1[0..min(nums1.len, k)-1] and nums2[0..min(nums2.len, k)-1].
    

    Previously, I wondered whether nums1[0] + nums2[k+1] could be a candidate. It cannot, because each of nums1[0] + nums2[0..k] is no larger than nums1[0] + nums2[k+1], so k+1 pairs with sums no larger than it already exist and it is never needed. This is not a formal proof, but you can generalise it.
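
    The claim can also be spot-checked by brute force. The sketch below is not a proof and not part of either solution; it assumes sorted inputs small enough to enumerate, and the class and method names are invented for the example. It compares the k smallest sums taken over the full arrays with those taken over the length-k prefixes only.

```java
import java.util.Arrays;

// Hypothetical class name, used only for this illustration.
class PrefixClaimCheck {
    // k smallest sums a[i] + b[j] with i < lenA and j < lenB, by full enumeration.
    static long[] kSmallestSums(int[] a, int lenA, int[] b, int lenB, int k) {
        long[] sums = new long[lenA * lenB];
        int n = 0;
        for (int i = 0; i < lenA; i++) {
            for (int j = 0; j < lenB; j++) {
                sums[n++] = (long) a[i] + b[j];
            }
        }
        Arrays.sort(sums);
        return Arrays.copyOf(sums, Math.max(0, Math.min(k, sums.length)));
    }

    // True if restricting both arrays to their first k elements leaves the
    // k smallest sums unchanged (compared as sorted sums, so ties are fine).
    static boolean prefixesSuffice(int[] nums1, int[] nums2, int k) {
        int cut1 = Math.max(0, Math.min(k, nums1.length));
        int cut2 = Math.max(0, Math.min(k, nums2.length));
        long[] full = kSmallestSums(nums1, nums1.length, nums2, nums2.length, k);
        long[] cut = kSmallestSums(nums1, cut1, nums2, cut2, k);
        return Arrays.equals(full, cut);
    }
}
```

    For example, prefixesSuffice(new int[]{1,2,3,4,5}, new int[]{1,2,3,4,5}, 2) is true: the two smallest sums, 2 and 3, already appear among pairs drawn from the first two elements of each array.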

