Java, 117 ms (faster than 99.81% of submissions), merge sort


  • 19
    I
    /*
     * Max Sum of Rectangle No Larger Than K.
     *
     * Fix a band of the smaller dimension (columns i..j when there are no more
     * columns than rows, otherwise rows i..j). Collapse the band into a 1-D array
     * of strip sums along the larger dimension. Then solve the 1-D problem
     * "largest subarray sum <= k" on its prefix sums with a merge sort, the same
     * way as in "Count of Range Sum".
     *
     * Time: O(s^2 * L log L), where s = min(rows, cols) and L = max(rows, cols).
     */
    public class Solution {
        /**
         * Returns the largest sum of a non-empty rectangular submatrix that is <= k.
         *
         * @param matrix rectangular matrix; every row is assumed to have
         *               matrix[0].length entries
         * @param k      inclusive upper bound for the rectangle sum
         * @return the best sum <= k, or Integer.MIN_VALUE if the matrix has no
         *         cells or no rectangle sum lies in [Integer.MIN_VALUE, k]
         */
        public int maxSumSubmatrix(int[][] matrix, int k) {
            if (matrix.length == 0) return Integer.MIN_VALUE; // matrix[0] would not exist
            int rows = matrix.length, cols = matrix[0].length, ans = Integer.MIN_VALUE;
            // Pair up the smaller dimension so the quadratic band loop runs over it.
            boolean pairColumns = cols <= rows;
            int outer = pairColumns ? cols : rows;
            int inner = pairColumns ? rows : cols;
            long[] sum = new long[inner + 1]; // prefix sums of the current band's strip sums
            for (int i = 0; i < outer; ++i) {
                long[] strip = new long[inner]; // strip[p] = sum of band i..j at inner position p
                for (int j = i; j < outer; ++j) {
                    // mergeSort reorders sum in place, so re-anchor the empty prefix at 0.
                    sum[0] = 0;
                    for (int p = 0; p < inner; ++p) {
                        strip[p] += pairColumns ? matrix[p][j] : matrix[j][p];
                        sum[p + 1] = sum[p] + strip[p];
                    }
                    ans = Math.max(ans, mergeSort(sum, 0, inner + 1, k));
                    if (ans == k) return k; // nothing can beat k itself
                }
            }
            return ans;
        }

        /**
         * Sorts sum[start..end) ascending in place. Returns the largest
         * difference sum[b] - sum[a] that is <= k, over index pairs a < b in the
         * range's original order.
         *
         * The result never goes below Integer.MIN_VALUE. That value is also
         * returned when the range holds fewer than two values or no pair
         * qualifies. When k itself is found, the method returns immediately and
         * leaves the range only partially sorted; the caller stops at that point.
         */
        int mergeSort(long[] sum, int start, int end, int k) {
            if (end - start < 2) return Integer.MIN_VALUE; // need at least 2 values for a difference
            int mid = start + (end - start) / 2, cnt = 0;
            int ans = mergeSort(sum, start, mid, k);
            if (ans == k) return k;
            ans = Math.max(ans, mergeSort(sum, mid, end, k));
            if (ans == k) return k;
            long[] cache = new long[end - start];
            // Both halves are sorted now. Left values sum[i] are visited in
            // ascending order, so j only ever moves forward. It stops at the first
            // right value with sum[j] - sum[i] > k, which makes sum[j - 1] the best
            // partner for sum[i].
            for (int i = start, j = mid, p = mid; i < mid; ++i) {
                while (j < end && sum[j] - sum[i] <= k) ++j;
                if (j - 1 >= mid) {
                    // Compare in long: casting a difference below Integer.MIN_VALUE
                    // to int would wrap to a bogus value. The winner is <= k and
                    // >= ans, so it always fits in an int.
                    ans = (int) Math.max((long) ans, sum[j - 1] - sum[i]);
                    if (ans == k) return k;
                }
                // Standard merge: emit the right values smaller than sum[i], then sum[i].
                while (p < end && sum[p] < sum[i]) cache[cnt++] = sum[p++];
                cache[cnt++] = sum[i];
            }
            // Right values that were not copied into cache are already in their final slots.
            System.arraycopy(cache, 0, sum, start, cnt);
            return ans;
        }
    }
    

  • 0

    Great solution: it uses merge sort to find the largest range sum that is at most k. The time complexity is $O(N^2 \cdot M \log M)$.


  • 0
    F
    This post is deleted!

  • 0

    Great merge sort solution, thanks! Here is my C++ rewrite. Similar problems that can be solved with merge sort:
    Count of Range Sum
    Reverse Pairs
    Count of Smaller Numbers After Self

    class Solution {
    public:
        // Largest sum of a non-empty rectangular submatrix that is <= k.
        // The quadratic band loop runs over the smaller dimension. Each band is
        // collapsed into strip sums along the larger dimension and solved as a
        // 1-D "largest range sum <= k" problem with a merge sort over prefix sums.
        // Returns INT_MIN for an empty matrix.
        int maxSumSubmatrix(vector<vector<int>>& matrix, int k) {
            int m = matrix.size();
            int n = m ? matrix[0].size() : 0;
            bool pairColumns = n <= m;         // pair up the smaller dimension
            int outer = pairColumns ? n : m;
            int inner = pairColumns ? m : n;
            int res = INT_MIN;
            vector<long long> sums(inner + 1, 0);

            for (int l = 0; l < outer; ++l) {
                vector<long long> strip(inner, 0); // strip[i] = sum of band l..r at inner position i
                for (int r = l; r < outer; ++r) {
                    sums[0] = 0; // mergeSort reorders sums in place; re-anchor the empty prefix
                    for (int i = 0; i < inner; ++i) {
                        strip[i] += pairColumns ? matrix[i][r] : matrix[r][i];
                        sums[i + 1] = sums[i] + strip[i];
                    }
                    res = max(res, mergeSort(sums, 0, inner + 1, k));
                    if (res == k) return k; // nothing can beat k itself
                }
            }

            return res;
        }

        // Sorts sums[start, end) ascending in place. Returns the largest
        // sums[b] - sums[a] <= k over pairs a < b in the range's original order,
        // never below INT_MIN. Returns early (range partially sorted) once k is found.
        int mergeSort(vector<long long>& sums, int start, int end, int k) {
            if (end - start < 2) return INT_MIN; // need at least 2 values for a difference

            int mid = start + (end - start) / 2;
            int res = mergeSort(sums, start, mid, k);
            if (res == k) return k;

            res = max(res, mergeSort(sums, mid, end, k));
            if (res == k) return k;

            // Heap-backed buffer: variable-length arrays are not standard C++ and live on the stack.
            vector<long long> cache;
            cache.reserve(end - start);

            int j = mid, t = mid;
            for (int i = start; i < mid; ++i) {
                while (j < end && sums[j] - sums[i] <= k) ++j; /* first j with sums[j] - sums[i] > k */
                if (j - 1 >= mid) { /* sums[j - 1] - sums[i] <= k, and j - 1 is still in the right half */
                    // Compare in long long: an int cast of a difference below INT_MIN
                    // would wrap. The winner is <= k and >= res, so it fits in an int.
                    res = static_cast<int>(max<long long>(res, sums[j - 1] - sums[i]));
                    if (res == k) return k;
                }
                while (t < end && sums[t] < sums[i]) {
                    cache.push_back(sums[t++]);
                }
                cache.push_back(sums[i]);
            }

            // Right values that were not copied into cache are already in their final slots.
            copy(cache.begin(), cache.end(), sums.begin() + start);

            return res;
        }
    };
    

Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.