C++ O(k)-time O(k)-space solution


  • 0

    This solution combines the approaches of two other discussion threads (the hyperlinks to them did not survive in this copy of the post).

    Time Complexity: O(k)
    Space Complexity: O(k)

    C++ Accepted Code (13ms):

    class Solution
    {
    private:
        // All sums are computed as long long: nums1[i] + nums2[j] may overflow
        // int (signed overflow is undefined behaviour), and a long long sentinel
        // can never collide with the sum of two ints.
        using Sum = long long;

        // Selects the ks[0]-th and ks[1]-th smallest entries (0-based) of the
        // virtual n x n matrix
        //     M(r, c) = nums1[indices[r]] + nums2[indices[c]],   n = indices.size().
        // Entries whose column index indices[c] is >= nums2.size() do not exist;
        // they are treated as +infinity, i.e. larger than every real entry.
        // With nums1/nums2 sorted and indices strictly increasing, every row and
        // every column of M is non-decreasing.
        //
        // The search recurses on the sub-matrix formed by every other row/column
        // (always keeping the last one) to obtain two values bracketing the
        // wanted ones, then resolves them with a single staircase scan of M.
        //
        // Preconditions: n >= 1 and ks[0] <= ks[1].  A rank that falls among the
        // non-existent entries yields std::numeric_limits<Sum>::max().
        std::array<Sum, 2> biSelect(
            const std::vector<int> & nums1,
            const std::vector<int> & nums2,
            const std::vector<std::size_t> & indices,
            const std::array<std::size_t, 2> & ks)
        {
            const std::size_t n = indices.size();
            if (n <= 2u) // base case of the recursion; n == 1 occurs when k == 1
            {
                return biSelectNative(nums1, nums2, indices, ks);
            }

            // Sub-matrix rows/columns: indices[0], indices[2], ... plus the last one.
            std::vector<std::size_t> indices_;
            for (std::size_t idx = 0; idx < n; idx += 2)
            {
                indices_.push_back(indices[idx]);
            }
            if (n % 2 == 0) // index n - 1 is odd and was skipped by the loop above
            {
                indices_.push_back(indices.back());
            }

            // Ranks in the sub-matrix whose values bracket the wanted ones:
            // ks_[0] is the largest rank that still guarantees xs_[0] <= xs[0],
            // ks_[1] is the smallest rank that still guarantees xs_[1] >= xs[1].
            // (Each sub-matrix entry stands for at most 4 entries of M; the
            // additive terms account for the first/last rows and columns.)
            std::array<std::size_t, 2> ks_ = { ks[0] / 4, 0 };
            ks_[1] = (n % 2 == 0) ? ks[1] / 4 + n + 1
                                  : (ks[1] + 2 * n + 1) / 4;

            const std::array<Sum, 2> xs_ = biSelect(nums1, nums2, indices_, ks_);

            // Partition the entries of M into three parts:
            //   Part 1: e <  xs_[0]            -- only counted
            //   Part 2: xs_[0] <= e < xs_[1]   -- collected in elementsBetween
            //   Part 3: e >= xs_[1]            -- discarded
            std::array<std::size_t, 2> numbersOfElementsLessThanX = { 0, 0 };
            std::vector<Sum> elementsBetween;

            // cols[i] = first column of the current row whose entry is >= xs_[i].
            // Columns are non-decreasing downwards, so cols[i] only ever moves left.
            std::array<std::size_t, 2> cols = { n, n };
            for (std::size_t row = 0; row < n; ++row)
            {
                const Sum rowValue = nums1[indices[row]];
                for (std::size_t idx = 0; idx < 2; ++idx)
                {
                    while (cols[idx] > 0)
                    {
                        const std::size_t colIndex = indices[cols[idx] - 1];
                        // Non-existent (+infinity) entries are >= any threshold.
                        const bool atLeastThreshold = (colIndex >= nums2.size())
                            || (rowValue + nums2[colIndex] >= xs_[idx]);
                        if (!atLeastThreshold)
                        {
                            break;
                        }
                        --cols[idx];
                    }
                    numbersOfElementsLessThanX[idx] += cols[idx];
                }
                // Entries in [cols[0], cols[1]) are < xs_[1], hence real entries.
                for (std::size_t col = cols[0]; col < cols[1]; ++col)
                {
                    elementsBetween.push_back(rowValue + nums2[indices[col]]);
                }
            }

            std::array<Sum, 2> xs; // the return value
            for (std::size_t idx = 0; idx < 2; ++idx)
            {
                const std::size_t k = ks[idx];
                if (k < numbersOfElementsLessThanX[0])
                {
                    // Part 1: should not happen when xs_[0] <= xs[0] as intended.
                    xs[idx] = xs_[0];
                }
                else if (k < numbersOfElementsLessThanX[1])
                {
                    // Part 2: pick the wanted rank among the collected entries.
                    const std::vector<Sum>::iterator nth = std::next(
                        elementsBetween.begin(), k - numbersOfElementsLessThanX[0]);
                    std::nth_element(elementsBetween.begin(), nth, elementsBetween.end());
                    xs[idx] = *nth;
                }
                else
                {
                    // Part 3: the wanted value is >= xs_[1] and bounded by it.
                    xs[idx] = xs_[1];
                }
            }
            return xs;
        }

        // Base case: materialises the (at most 2 x 2) matrix, sorts it and
        // reads both ranks directly.  Columns past the end of nums2 are skipped
        // (+infinity); a rank beyond the existing entries is reported as
        // std::numeric_limits<Sum>::max().
        std::array<Sum, 2> biSelectNative(
            const std::vector<int> & nums1,
            const std::vector<int> & nums2,
            const std::vector<std::size_t> & indices,
            const std::array<std::size_t, 2> & ks)
        {
            std::vector<Sum> allElements;
            for (std::size_t idx1 : indices)
            {
                for (std::size_t idx2 : indices)
                {
                    if (idx2 >= nums2.size())
                    {
                        break; // indices is increasing: the rest are past the end too
                    }
                    allElements.push_back(static_cast<Sum>(nums1[idx1]) + nums2[idx2]);
                }
            }
            std::sort(allElements.begin(), allElements.end());
            std::array<Sum, 2> results;
            for (std::size_t idx = 0; idx < 2; ++idx)
            {
                results[idx] = (ks[idx] < allElements.size())
                    ? allElements[ks[idx]]
                    : std::numeric_limits<Sum>::max();
            }
            return results;
        }
    public:
        // Returns the k pairs (u, v), u taken from nums1 and v from nums2, with
        // the smallest sums u + v.  Both inputs are assumed to be sorted in
        // non-decreasing order (the LeetCode 373 contract; not checked here).
        // If at most k pairs exist, all of them are returned; k <= 0 yields an
        // empty result.  The pairs are not ordered by sum.  nums1 and nums2 are
        // left unmodified.
        std::vector<std::pair<int, int>> kSmallestPairs(
            std::vector<int>& nums1, std::vector<int>& nums2, int k) {
            std::vector<std::pair<int, int>> results;
            if (k <= 0) {
                return results; // nothing requested
            }
            const std::size_t wanted = static_cast<std::size_t>(k);

            if (nums1.size() * nums2.size() <= wanted) {
                // not enough pairs: return all of them
                for (int num1 : nums1) {
                    for (int num2 : nums2) {
                        results.emplace_back(num1, num2);
                    }
                }
                return results;
            }
            results.reserve(wanted);

            // Use the longer array as the matrix rows.  Binding references instead
            // of swapping the vectors keeps the caller's arguments untouched.
            const bool swapped = nums1.size() < nums2.size();
            const std::vector<int> & rows = swapped ? nums2 : nums1;
            const std::vector<int> & cols = swapped ? nums1 : nums2;

            // Sum of matrix entry (r, c), computed without int overflow.
            auto sumAt = [&](std::size_t r, std::size_t c) {
                return static_cast<Sum>(rows[r]) + cols[c];
            };
            // Appends entry (r, c) as a pair in the caller's (nums1, nums2) order.
            auto emit = [&](std::size_t r, std::size_t c) {
                if (swapped) {
                    results.emplace_back(cols[c], rows[r]);
                } else {
                    results.emplace_back(rows[r], cols[c]);
                }
            };

            // The k-th smallest sum of the full matrix equals that of its top-left
            // k x k corner, so the selection only needs n = min(k, rows) rows/columns.
            const std::size_t n = std::min(wanted, rows.size());
            std::vector<std::size_t> indices(n);
            std::iota(indices.begin(), indices.end(), std::size_t{0});
            const std::array<std::size_t, 2> ks = { wanted - 1, wanted - 1 }; // zero-based rank
            const Sum kthValue = biSelect(rows, cols, indices, ks)[0];

            // Pass 1: every pair whose sum is < kthValue (there are fewer than k).
            std::vector<std::size_t> firstNotLess(n);
            for (std::size_t r = 0; r < n; ++r) {
                std::size_t c = 0;
                for (; c < cols.size() && sumAt(r, c) < kthValue; ++c) {
                    emit(r, c);
                }
                firstNotLess[r] = c;
            }

            // Pass 2: top up to exactly k with pairs whose sum equals kthValue.
            for (std::size_t r = 0; r < n && results.size() < wanted; ++r) {
                for (std::size_t c = firstNotLess[r];
                     c < cols.size() && results.size() < wanted && sumAt(r, c) == kthValue;
                     ++c) {
                    emit(r, c);
                }
            }
            return results;
        }
    };

Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.