This is essentially a median-of-medians selection, inspired by @StefanPochmann.

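Each row keeps an active range that may still contain the answer. Since each row is sorted, the median of a row's active range is simply its middle element. The median of those row medians is then selected with the STL `nth_element` function. The key step is to find, for each row, the `upper_bound` of the median of medians, i.e. the number of active elements in that row that are less than or equal to it, and to sum these counts.

- If the sum equals k, the median of medians is the kth smallest element.
- If the sum is smaller than k, the kth element is larger than the median of medians. All elements less than or equal to it are discarded, and k is reduced by the sum.
- Otherwise, the kth element lies among the elements less than or equal to the median of medians, so everything larger is discarded.

The process repeats on the remaining ranges. When the narrowing stops making progress (which can happen with duplicate values), the remaining elements are collected and the kth one is found directly with `nth_element`.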

```cpp
class Solution {
public:
    // Returns the k-th smallest element (1-based, duplicates counted) among
    // all elements of `matrix`.
    //
    // Approach: median of medians over per-row active ranges. Row i keeps an
    // inclusive active slice [lo[i], hi[i]] that may still contain the answer.
    // Each round:
    //   1. takes the middle element of every non-empty active slice;
    //   2. picks the median of those (the pivot) with nth_element;
    //   3. counts the active elements <= pivot via upper_bound on each slice.
    // Then:
    //   - count == k: the pivot is the answer;
    //   - count <  k: every element <= pivot is too small; drop them and
    //                 reduce k by count;
    //   - count >  k: the answer is <= pivot; keep only elements <= pivot.
    // If a round cannot shrink the active set (the pivot is the maximum of
    // everything still active, e.g. many duplicates), the remaining elements
    // are gathered and the answer is selected with nth_element.
    //
    // Only the order within each row is used, so rows may have different
    // lengths (including zero) and columns need not be sorted.
    //
    // Preconditions: every row is sorted in non-decreasing order and
    // 1 <= k <= total number of elements; otherwise behavior is undefined.
    int kthSmallest(vector<vector<int>>& matrix, int k) {
        const int rows = static_cast<int>(matrix.size());

        // Inclusive active range [lo[i], hi[i]] of row i (empty when lo > hi).
        // Bounds come from each row's own length rather than assuming n x n.
        vector<int> lo(rows, 0);
        vector<int> hi(rows);
        int active = 0;  // total number of elements across all active ranges
        for (int i = 0; i < rows; ++i) {
            hi[i] = static_cast<int>(matrix[i].size()) - 1;
            active += hi[i] + 1;
        }

        vector<int> medians;
        medians.reserve(rows);
        vector<int> cut(rows, 0);  // per row: index one past the last element <= pivot

        for (;;) {
            // 1. Middle element of each non-empty active slice.
            medians.clear();
            for (int i = 0; i < rows; ++i) {
                if (lo[i] <= hi[i]) {
                    medians.push_back(matrix[i][lo[i] + (hi[i] - lo[i]) / 2]);
                }
            }

            // 2. Median of those medians. The invariant 1 <= k <= active keeps
            //    `medians` non-empty here.
            const auto mid = medians.begin() + medians.size() / 2;
            nth_element(medians.begin(), mid, medians.end());
            const int pivot = *mid;

            // 3. Count active elements <= pivot. The pivot itself is an active
            //    element, so this count is at least 1.
            int not_greater = 0;
            for (int i = 0; i < rows; ++i) {
                if (lo[i] <= hi[i]) {
                    const auto row_begin = matrix[i].begin();
                    cut[i] = static_cast<int>(
                        upper_bound(row_begin + lo[i], row_begin + hi[i] + 1, pivot) - row_begin);
                    not_greater += cut[i] - lo[i];
                }
            }

            if (not_greater == k) {
                // The k smallest active elements are exactly those <= pivot,
                // and the pivot (an active element) is the largest of them.
                return pivot;
            }

            if (not_greater < k) {
                // Answer is larger than pivot: discard everything <= pivot.
                for (int i = 0; i < rows; ++i) {
                    if (lo[i] <= hi[i]) {
                        lo[i] = cut[i];
                    }
                }
                k -= not_greater;
                active -= not_greater;
            } else {
                // Answer is <= pivot. If that already covers every active
                // element, narrowing would make no progress: fall back.
                if (not_greater == active) {
                    break;
                }
                for (int i = 0; i < rows; ++i) {
                    if (lo[i] <= hi[i]) {
                        hi[i] = cut[i] - 1;
                    }
                }
                active = not_greater;
            }
        }

        // Fallback: select directly among the remaining active elements.
        vector<int> rest;
        rest.reserve(active);
        for (int i = 0; i < rows; ++i) {
            if (lo[i] <= hi[i]) {
                rest.insert(rest.end(), matrix[i].begin() + lo[i], matrix[i].begin() + hi[i] + 1);
            }
        }
        nth_element(rest.begin(), rest.begin() + (k - 1), rest.end());
        return rest[k - 1];
    }
};
```