# C++ "merge sort" solution modified for sorted matrix

• The basic idea is to merge the already-sorted rows (the merge step of merge sort), with some optimizations:

• If k > n*n/2, find the (n*n-k+1)th largest element instead;
• Since the columns are also sorted, remember where each element of the previously merged row was inserted, and start scanning for the current element's insert position from the insert location of the element directly above it;
• Stop merging as soon as we have enough information (i.e., we know for sure that all remaining elements would be inserted after position k).
```cpp
class Solution {
public:
// Entry point: returns the kth smallest element (k is 1-based) of `matrix`.
// The work is delegated to whichever helper has fewer positions to merge:
// kthsmall() walks up from the smallest element, kthlarge() walks down from
// the largest.
int kthSmallest(vector<vector<int>>& matrix, int k) {
    const int rows = matrix.size();
    const int cols = matrix[0].size();
    const int total = rows * cols;

    // The kth smallest of `total` values is also the (total-k+1)th largest.
    // Past the midpoint, counting from the large end is the shorter walk.
    const bool nearerToLargeEnd = k > total / 2;
    return nearerToLargeEnd ? kthlarge(matrix, total - k + 1)
                            : kthsmall(matrix, k);
}

// Find the kth smallest element (k is 1-based) by merging the rows, top to
// bottom, into one ascending array.
//
// Assumes every row and every column of `matrix` is sorted in ascending order,
// that the matrix is non-empty, and that 1 <= k <= number of elements; none of
// this is checked here.
//
// Elements that are proven to rank after position k are never inserted, so
// only sorted[0..k-1] is guaranteed to be meaningful when the merge stops.
int kthsmall(vector<vector<int>>& matrix, int k)
{
    const int height = static_cast<int>(matrix.size());
    const int width = static_cast<int>(matrix[0].size());

    // Seed the merged array with the first row, which is already ascending.
    vector<int> sorted(matrix[0].begin(), matrix[0].end());

    // preloc[c] is the index at which the most recently merged element of
    // column c was placed. Because columns ascend, the next element of that
    // column can never belong before that index. A vector replaces the former
    // new[]/delete[] pair so nothing leaks if an insertion throws.
    vector<int> preloc(width);
    for (int w = 0; w < width; ++w)
        preloc[w] = w;

    // Merge row by row until the first k slots are settled.
    for (int h = 1; h < height; ++h)
    {
        // matrix[h][0] goes somewhere after the element directly above it.
        int insertcur = preloc[0] + 1;
        if (insertcur >= k)
            break; // this row, and every later row, lies past the kth slot

        int candcur = 0;
        // Once insertcur reaches k, the rest of the row ranks after k.
        while (candcur < width && insertcur < k)
        {
            const int value = matrix[h][candcur];

            // Not at the end and still larger than the occupant: keep scanning.
            if (insertcur < static_cast<int>(sorted.size()) && value > sorted[insertcur])
            {
                ++insertcur;
                continue;
            }

            // Insert here; when insertcur == size() this is a plain append.
            sorted.insert(sorted.begin() + insertcur, value);
            preloc[candcur] = insertcur; // recorded for appends as well
            ++candcur;
            ++insertcur; // the row ascends, so the next element goes after this one

            // Tighten the start using the element above the next column:
            // its slot, shifted by the `candcur` elements of this row already
            // inserted. Guarded so preloc is never read past its last column.
            if (candcur < width)
            {
                const int bound = preloc[candcur] + candcur;
                if (bound > insertcur)
                    insertcur = bound;
            }
        }
    }

    // The k smallest elements now occupy sorted[0..k-1] in ascending order.
    return sorted[k - 1];
}

// Find the kth largest element (k is 1-based). Mirror image of kthsmall():
// rows are merged bottom to top, each scanned right to left, into one
// descending array.
//
// Assumes every row and every column of `matrix` is sorted in ascending order,
// that the matrix is non-empty, and that 1 <= k <= number of elements; none of
// this is checked here.
//
// Elements that are proven to rank after position k are never inserted, so
// only sorted[0..k-1] is guaranteed to be meaningful when the merge stops.
int kthlarge(vector<vector<int>>& matrix, int k)
{
    const int height = static_cast<int>(matrix.size());
    const int width = static_cast<int>(matrix[0].size());

    // Seed the merged array with the last row reversed, i.e. in descending order.
    vector<int> sorted(matrix[height - 1].rbegin(), matrix[height - 1].rend());

    // preloc[c] is the index at which the most recently merged element of
    // column c was placed. The element above it in the matrix is no larger,
    // so it can never belong before that index. A vector replaces the former
    // new[]/delete[] pair so nothing leaks if an insertion throws.
    vector<int> preloc(width);
    for (int w = 0; w < width; ++w)
        preloc[w] = width - 1 - w;

    // Merge row by row, upward, until the first k slots are settled.
    for (int h = height - 2; h >= 0; --h)
    {
        // The row's largest element goes after the element directly below it.
        int insertcur = preloc[width - 1] + 1;
        if (insertcur >= k)
            break; // this row, and every row above it, lies past the kth slot

        int candcur = width - 1;
        // Once insertcur reaches k, the rest of the row ranks after k.
        while (candcur >= 0 && insertcur < k)
        {
            const int value = matrix[h][candcur];

            // Not at the end and still smaller than the occupant: keep scanning.
            if (insertcur < static_cast<int>(sorted.size()) && value < sorted[insertcur])
            {
                ++insertcur;
                continue;
            }

            // Insert here; when insertcur == size() this is a plain append.
            sorted.insert(sorted.begin() + insertcur, value);
            preloc[candcur] = insertcur; // recorded for appends as well
            --candcur;
            ++insertcur; // moving left the row descends, so the next goes after

            // Tighten the start using the element below the next column:
            // its slot, shifted by the elements of this row already inserted.
            // Guarded so preloc is never read before its first column.
            if (candcur >= 0)
            {
                const int bound = preloc[candcur] + (width - 1 - candcur);
                if (bound > insertcur)
                    insertcur = bound;
            }
        }
    }

    // The k largest elements now occupy sorted[0..k-1] in descending order.
    return sorted[k - 1];
}
};
``````

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.