# JAVA 117ms, beat 99.81%, merge sort

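• ``````/*
* If the number of columns is the smaller dimension, process one range of columns [i..j] at a time, for every i <= j.
* For a fixed column range [i..j], collapse each row into a single value and solve the 1-D problem with the "Count of Range Sum" merge-sort technique on the row prefix sums.
* Time complexity: O(n^2 * m log m) for m rows and n columns.
* This assumes the number of columns n is the smaller dimension (n <= m).
*/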
public class Solution {
    /**
     * Returns the largest sum of a non-empty rectangular sub-matrix of {@code matrix} that is
     * {@code <= k}.
     *
     * <p>Pairs of lines are enumerated along the smaller dimension. For each pair range
     * [lo..hi] the other dimension is collapsed into prefix sums, and {@link #mergeSort} finds
     * the largest difference of two prefix sums that is {@code <= k} (the "Count of Range Sum"
     * technique). Time: O(min(m, n)^2 * max(m, n) * log(max(m, n))) for an m x n matrix.
     *
     * @param matrix the grid; every row is assumed to have {@code matrix[0].length} entries
     * @param k      inclusive upper bound on the sub-matrix sum
     * @return the best sum {@code <= k}; {@link Integer#MIN_VALUE} if the matrix is null or
     *         empty, or if no sub-matrix sum is {@code <= k}
     */
    public int maxSumSubmatrix(int[][] matrix, int k) {
        if (matrix == null || matrix.length == 0 || matrix[0].length == 0) {
            return Integer.MIN_VALUE;
        }
        final int rows = matrix.length;
        final int cols = matrix[0].length;
        // Put the quadratic pair enumeration on the smaller dimension.
        final boolean pairColumns = cols <= rows;
        final int outer = pairColumns ? cols : rows;
        final int inner = pairColumns ? rows : cols;

        int ans = Integer.MIN_VALUE;
        long[] prefix = new long[inner + 1]; // prefix[p] = strip[0] + ... + strip[p - 1]
        for (int lo = 0; lo < outer; ++lo) {
            long[] strip = new long[inner]; // strip[p] = sum of line p over outer range [lo..hi]
            for (int hi = lo; hi < outer; ++hi) {
                // mergeSort leaves prefix sorted, so restart the prefix sums from zero.
                prefix[0] = 0;
                for (int p = 0; p < inner; ++p) {
                    strip[p] += pairColumns ? matrix[p][hi] : matrix[hi][p];
                    prefix[p + 1] = prefix[p] + strip[p];
                }
                ans = Math.max(ans, mergeSort(prefix, 0, inner + 1, k));
                if (ans == k) {
                    return k; // an exact hit cannot be beaten
                }
            }
        }
        return ans;
    }

    /**
     * Sorts {@code sum[start..end)} ascending in place and returns the largest
     * {@code sum[j] - sum[i]} with {@code start <= i < j < end} (positions before sorting)
     * that is {@code <= k}, or {@link Integer#MIN_VALUE} if there is no such pair.
     * Returns {@code k} as soon as an exact hit is found, possibly leaving the range only
     * partially sorted.
     */
    int mergeSort(long[] sum, int start, int end, int k) {
        if (end - start < 2) {
            return Integer.MIN_VALUE; // a difference needs at least two prefix sums
        }
        int mid = start + (end - start) / 2;
        int ans = mergeSort(sum, start, mid, k);
        if (ans == k) {
            return k;
        }
        ans = Math.max(ans, mergeSort(sum, mid, end, k));
        if (ans == k) {
            return k;
        }
        // Both halves are now sorted: pair each left value with the right half, then merge.
        long[] cache = new long[end - start];
        int cnt = 0;
        for (int i = start, j = mid, p = mid; i < mid; ++i) {
            // sum[i] is non-decreasing, so the first j with sum[j] - sum[i] > k only moves right.
            while (j < end && sum[j] - sum[i] <= k) {
                ++j;
            }
            if (j - 1 >= mid) {
                // sum[j - 1] is the largest right-half value with sum[j - 1] - sum[i] <= k.
                long diff = sum[j - 1] - sum[i];
                // diff <= k <= Integer.MAX_VALUE, so only clamp from below before narrowing.
                ans = Math.max(ans, (int) Math.max(diff, Integer.MIN_VALUE));
                if (ans == k) {
                    return k;
                }
            }
            while (p < end && sum[p] < sum[i]) {
                cache[cnt++] = sum[p++];
            }
            cache[cnt++] = sum[i];
        }
        // Right-half values not copied into cache are already in their final positions.
        System.arraycopy(cache, 0, sum, start, cnt);
        return ans;
    }
}
``````

• Great solution: it uses merge sort to find the range with the largest sum <= k. The time complexity is O(N^2 * M log M).

• This post is deleted!

• Great merge sort solution, thanks! Here is my C++ rewrite. Similar problems that can also be solved with merge sort:
- Count of Range Sum
- Reverse Pairs
- Count of Smaller Numbers After Self

``````class Solution {
public:
// Returns the largest sum of a non-empty sub-rectangle of `matrix` that is <= k, or INT_MIN
// if the matrix is empty or no sub-rectangle sum is <= k.
// Pairs of lines are enumerated along the smaller dimension; for each pair range [l..r] the
// other dimension is collapsed into prefix sums and mergeSort finds the best difference <= k.
// Time: O(min(m,n)^2 * max(m,n) * log(max(m,n))) for an m x n matrix.
int maxSumSubmatrix(vector<vector<int>>& matrix, int k) {
    int m = matrix.size();
    int n = m ? matrix[0].size() : 0;
    int res = INT_MIN;
    if (m == 0 || n == 0) return res;

    // Put the quadratic pair enumeration on the smaller dimension.
    bool pairCols = n <= m;
    int outer = pairCols ? n : m;
    int inner = pairCols ? m : n;
    vector<long long> sums(inner + 1, 0);  // sums[i] = strip[0] + ... + strip[i - 1]

    for (int l = 0; l < outer; ++l) {
        vector<long long> strip(inner, 0);  // strip[i] = sum of line i over outer range [l..r]
        for (int r = l; r < outer; ++r) {
            sums[0] = 0;  // mergeSort leaves sums sorted; restart the prefix sums at zero
            for (int i = 0; i < inner; ++i) {
                strip[i] += pairCols ? matrix[i][r] : matrix[r][i];
                sums[i + 1] = sums[i] + strip[i];
            }
            res = max(res, mergeSort(sums, 0, inner + 1, k));
            if (res == k) return k;  // an exact hit cannot be beaten
        }
    }

    return res;
}

// Sorts sums[start, end) ascending in place and returns the largest sums[j] - sums[i]
// (start <= i < j < end, positions before sorting) that is <= k, or INT_MIN if none exists.
// Returns k immediately on an exact hit, possibly leaving the range only partially sorted.
int mergeSort(vector<long long>& sums, int start, int end, int k) {
    if (end - start < 2) return INT_MIN;  // a difference needs at least two prefix sums

    int mid = start + (end - start) / 2;
    int res = mergeSort(sums, start, mid, k);
    if (res == k) return k;

    res = max(res, mergeSort(sums, mid, end, k));
    if (res == k) return k;

    // Heap-allocated scratch buffer: variable-length arrays are not standard C++, and a
    // large range could overflow the stack.
    vector<long long> cache;
    cache.reserve(end - start);

    int j = mid, t = mid;
    for (int i = start; i < mid; ++i) {
        // Both halves are sorted, so the first j with sums[j] - sums[i] > k only moves right.
        while (j < end && sums[j] - sums[i] <= k) ++j;
        if (j - 1 >= mid) {  // sums[j - 1] is the largest right-half value within k of sums[i]
            long long diff = sums[j - 1] - sums[i];
            // diff <= k <= INT_MAX, so only clamp from below before narrowing to int.
            res = max(res, static_cast<int>(max<long long>(diff, INT_MIN)));
            if (res == k) return k;
        }
        // Merge step: emit right-half values smaller than sums[i], then sums[i] itself.
        while (t < end && sums[t] < sums[i]) cache.push_back(sums[t++]);
        cache.push_back(sums[i]);
    }

    // Right-half values sums[t, end) are already in their final positions.
    for (int i = 0; i < static_cast<int>(cache.size()); ++i) {
        sums[start + i] = cache[i];
    }

    return res;
}
};
``````

• Thanks for such a great solution.
Many problems that ask for range sums over an unsorted collection of numbers can be solved with this technique.

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.