Share my C++ O(m^2+ n^2 + k^2) solution


  • 0
    M

    First step: build two DP matrices, rangeMaxIdx1 and rangeMaxIdx2. Element (i, j) of rangeMaxIdx1 stores the index of the largest number in nums1[i..j] (the leftmost such index when there are ties); rangeMaxIdx2 stores the same information for nums2. This step takes O(m^2 + n^2) time.

    Second step: loop i from 0 to k. In each iteration, two vectors are generated: "the maximum number of length i from nums1" and "the maximum number of length k - i from nums2". Each is produced by a function called maxNumber1D, which takes O(k) time because every digit is found with a single lookup in the corresponding DP matrix. Next, the two vectors (of length i and k - i, respectively) are merged into a new vector of length k. This step takes O(k) time on average. Lastly, the best number seen so far is updated by comparing it with the newly generated number. This step also takes O(k) time. Each iteration therefore costs O(k) on average, so the second step takes O(k^2) time on average.

    In conclusion, the whole function has a time complexity of O(m^2 + n^2 + k^2) and a space complexity of O(m^2 + n^2). It is worth noting that if nums1 and nums2 share many common digits and contain many duplicates (e.g. nums1={1,1,1,1,1,1,2}, nums2={1,1,1,1,1,1,3}), the worst case of the second step may become O(k^3), because each tie during the merge has to be resolved by scanning ahead through the equal digits. Additional optimizations may solve this issue.

    /**
     * Builds the lexicographically largest sequence of k digits that can be
     * formed by picking digits from two arrays while keeping the relative order
     * of the digits taken from each array.
     *
     * Approach: for every split k = i + (k - i), take the best length-i
     * subsequence of nums1 and the best length-(k - i) subsequence of nums2,
     * merge them greedily, and keep the largest merged result.
     *
     * The range-maximum tables cost O(m^2 + n^2) time and space, where m and n
     * are the input sizes.
     */
    class Solution {
    public:
        /**
         * @param nums1 first digit array (not modified)
         * @param nums2 second digit array (not modified)
         * @param k     number of digits to select in total
         * @return the largest length-k sequence; an empty vector when no
         *         selection exists (k <= 0 or k > nums1.size() + nums2.size())
         */
        std::vector<int> maxNumber(std::vector<int>& nums1, std::vector<int>& nums2, int k) {
            const int m = static_cast<int>(nums1.size());
            const int n = static_cast<int>(nums2.size());

            // Reject impossible requests up front: a negative k must never reach
            // a vector size, and an oversized k has no valid selection at all.
            if (k <= 0 || k > m + n)
                return {};

            std::vector<std::vector<int>> rangeMaxIdx1(m, std::vector<int>(m, -1));
            std::vector<std::vector<int>> rangeMaxIdx2(n, std::vector<int>(n, -1));
            createRangeMaxIdx(rangeMaxIdx1, m, nums1);
            createRangeMaxIdx(rangeMaxIdx2, n, nums2);

            // Only splits with i <= m and k - i <= n are feasible.
            const int minFromFirst = (k > n) ? k - n : 0;
            const int maxFromFirst = (k < m) ? k : m;

            std::vector<int> best;
            for (int i = minFromFirst; i <= maxFromFirst; ++i) {
                const std::vector<int> part1 = maxNumber1D(nums1, i, rangeMaxIdx1);
                const std::vector<int> part2 = maxNumber1D(nums2, k - i, rangeMaxIdx2);
                std::vector<int> candidate = mergePartial(part1, part2);

                // Every candidate has length k, so vector's lexicographic
                // operator> is exactly the digit-by-digit comparison we need.
                if (best.empty() || candidate > best)
                    best.swap(candidate);
            }

            return best;
        }

        /**
         * Fills res[i][j] (for 0 <= i <= j < dim) with the index of the largest
         * element of nums[i..j]; on ties the leftmost index is kept.
         *
         * @param res  dim x dim table to fill; entries with j < i are untouched
         * @param dim  number of leading elements of nums to cover
         * @param nums values being indexed; must hold at least dim elements
         */
        void createRangeMaxIdx(std::vector<std::vector<int>>& res, int dim, const std::vector<int>& nums)
        {
            for (int i = 0; i < dim; ++i) {
                int leader = i;
                res[i][i] = i;
                for (int j = i + 1; j < dim; ++j) {
                    // Strict '>' keeps the earliest maximum, which leaves the
                    // most elements available for later picks.
                    if (nums[j] > nums[leader])
                        leader = j;
                    res[i][j] = leader;
                }
            }
        }

        /**
         * Returns the lexicographically largest subsequence of nums of length k.
         *
         * @param nums        source digits
         * @param k           subsequence length; must satisfy k <= nums.size()
         * @param rangeMaxIdx table built by createRangeMaxIdx for nums
         * @return the chosen digits in order (empty when k <= 0)
         */
        std::vector<int> maxNumber1D(const std::vector<int>& nums, int k, const std::vector<std::vector<int>>& rangeMaxIdx)
        {
            std::vector<int> res;
            if (k <= 0)
                return res;
            res.reserve(k);

            const int n = static_cast<int>(nums.size());
            int lo = 0;
            // The t-th pick must leave k - 1 - t elements after it, so its
            // window ends at index n - k + t; it starts right after the
            // previous pick.
            for (int hi = n - k; hi < n; ++hi) {
                const int pick = rangeMaxIdx[lo][hi];
                res.push_back(nums[pick]);
                lo = pick + 1;
            }

            return res;
        }

        /**
         * Merges two digit sequences into the lexicographically largest sequence
         * that preserves the internal order of each input.
         *
         * At every step the next digit comes from whichever input has the larger
         * remaining suffix; when the suffixes are equal, or nums2's suffix is a
         * prefix of nums1's, the digit is taken from nums1.
         *
         * @return merged sequence of length nums1.size() + nums2.size()
         */
        std::vector<int> mergePartial(const std::vector<int>& nums1, const std::vector<int>& nums2)
        {
            const int m = static_cast<int>(nums1.size());
            const int n = static_cast<int>(nums2.size());
            std::vector<int> res;
            res.reserve(m + n);

            int p1 = 0, p2 = 0;
            while (p1 < m || p2 < n) {
                // Skip the common prefix of the two remaining suffixes.
                int a = p1, b = p2;
                while (a < m && b < n && nums1[a] == nums2[b]) {
                    ++a;
                    ++b;
                }

                bool takeFirst;
                if (b >= n)
                    takeFirst = true;   // nums2's suffix ran out first (or both did)
                else if (a >= m)
                    takeFirst = false;  // nums1's suffix is a proper prefix of nums2's
                else
                    takeFirst = nums1[a] > nums2[b];

                if (takeFirst)
                    res.push_back(nums1[p1++]);
                else
                    res.push_back(nums2[p2++]);
            }

            return res;
        }
    };
    

Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.