Hello, everyone.

**EDIT:** here, **n** is the total number of elements in the matrix. Thanks @erikandre for pointing this out!

Here's my solution based on a "max" heap of size **k**. I believe it runs in O(n log k) time and uses O(k) extra space.
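A quick back-of-the-envelope check of those bounds. It assumes that `add` and `dequeue` each cost O(log k) because the heap never holds more than **k** elements, that `max()` is O(1), and that k ≤ n. Each matrix element triggers at most one `max()`, one `dequeue()` and one `add()`, and the heap is backed by an array of k + 1 ints:

```latex
\begin{aligned}
T(n, k) &= n \cdot \big( \underbrace{O(1)}_{\texttt{max()}} + \underbrace{O(\log k)}_{\texttt{dequeue()}} + \underbrace{O(\log k)}_{\texttt{add()}} \big) = O(n \log k) \\
S(k) &= \underbrace{(k + 1)}_{\text{array slots}} \cdot O(1) = O(k)
\end{aligned}
```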

The idea is to iterate through the matrix once, adding its elements to the "max" heap of size **k** until the heap becomes full. Once the heap is full, compare each remaining matrix element to the heap's current maximum. If the element is less than the heap's maximum, remove the maximum and add the element to the heap (so the heap never holds more than **k** elements).

When you reach the end of the matrix, the heap holds the **k** smallest elements seen so far, so its maximum is the k-th smallest element of the matrix.
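For comparison, the same bounded "max" heap idea can be sketched with the standard library's `java.util.PriorityQueue` instead of a hand-written heap. `PriorityQueue` is a min-heap by default, so this sketch passes `Collections.reverseOrder()` to turn it into a max-heap. The class name `KthSmallestStdHeap` is just an illustrative choice, and the sketch assumes a non-empty matrix with 1 ≤ k ≤ number of elements:

```java
import java.util.Collections;
import java.util.PriorityQueue;

public class KthSmallestStdHeap
{
    // Bounded max-heap of size k built on java.util.PriorityQueue
    public static int kthSmallest(int[][] matrix, int k)
    {
        // reverseOrder() turns the default min-heap into a max-heap
        PriorityQueue<Integer> heap = new PriorityQueue<>(k, Collections.reverseOrder());
        for (int[] row : matrix)
        {
            for (int element : row)
            {
                if (heap.size() < k)
                {
                    heap.add(element);   // fill phase: heap not full yet
                }
                else if (element < heap.peek())
                {
                    heap.poll();         // drop the current maximum
                    heap.add(element);   // heap stays at exactly k elements
                }
            }
        }
        // The heap now holds the k smallest elements; its root is the k-th smallest
        return heap.peek();
    }

    public static void main(String[] args)
    {
        int[][] matrix = {{1, 5, 9}, {10, 11, 13}, {12, 13, 15}};
        System.out.println(kthSmallest(matrix, 8)); // prints 13
    }
}
```

The trade-off is boxing: `PriorityQueue<Integer>` stores `Integer` objects, while the hand-written `MaxPQ` works directly on an `int[]`.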

Please let me know if you find any mistakes or logic flaws.

P.S. This solution doesn't use the fact that all rows and columns are sorted.

```java
public class Solution
{
    public int kthSmallest(int[][] matrix, int k)
    {
        final int ROWS = matrix.length;
        final int COLS = matrix[0].length;
        // "Max" heap that never holds more than k elements
        MaxPQ queue = new MaxPQ(k);
        for (int i = 0; i < ROWS; ++i)
        {
            for (int j = 0; j < COLS; ++j)
            {
                final int element = matrix[i][j];
                if (queue.size < k)
                {
                    // Heap not full yet: just add the element
                    queue.add(element);
                }
                else if (element < queue.max())
                {
                    // Heap is full: replace the current maximum with the smaller element
                    queue.dequeue();
                    queue.add(element);
                }
            }
        }
        // The heap holds the k smallest elements, so its maximum is the k-th smallest
        return queue.max();
    }

    // Array-based binary max-heap: data[1] is the root, data[0] is unused
    static class MaxPQ
    {
        private final int[] data;
        private int size;

        public MaxPQ(int size)
        {
            data = new int[size + 1];
        }

        public void add(int e)
        {
            ++size;
            data[size] = e;
            swim(size);
        }

        // Removes and returns the maximum; assumes the heap is not empty
        public int dequeue()
        {
            int max = data[1];
            swap(1, size--);
            sink(1);
            return max;
        }

        public int max()
        {
            return data[1];
        }

        public boolean empty()
        {
            return size == 0;
        }

        // Moves the element at index k up while it is greater than its parent
        private void swim(int k)
        {
            int parent;
            while ((parent = k / 2) >= 1 && greater(k, parent))
            {
                swap(k, parent);
                k = parent;
            }
        }

        // Moves the element at index k down while a child is not smaller than it
        private void sink(int k)
        {
            int child;
            while ((child = 2 * k) <= size)
            {
                if (child < size && greater(child + 1, child)) ++child;
                if (greater(k, child)) break;
                swap(k, child);
                k = child;
            }
        }

        private void swap(int i, int j)
        {
            int temp = data[i];
            data[i] = data[j];
            data[j] = temp;
        }

        private boolean greater(int i, int j)
        {
            return data[i] > data[j];
        }
    }
}
```
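Since the heap solution ignores the sorted rows and columns, here is a sketch of a different, commonly used technique that does rely on them: binary search over the value range, where each step counts the elements ≤ a candidate value with a "staircase" walk from the bottom-left corner. This is not the heap approach above, and the class and method names (`KthSmallestBinarySearch`, `countLessOrEqual`) are hypothetical. It assumes a non-empty R×C matrix sorted ascending along every row and every column, with 1 ≤ k ≤ R·C. Under those assumptions it runs in O((R + C) · log(max − min)) time with O(1) extra space:

```java
public class KthSmallestBinarySearch
{
    // Counts elements <= target by walking from the bottom-left corner.
    // Relies on rows and columns both being sorted in ascending order.
    private static int countLessOrEqual(int[][] matrix, int target)
    {
        int row = matrix.length - 1;
        int col = 0;
        int count = 0;
        while (row >= 0 && col < matrix[0].length)
        {
            if (matrix[row][col] <= target)
            {
                // Every element at or above (row, col) in this column is also <= target
                count += row + 1;
                ++col;
            }
            else
            {
                --row;
            }
        }
        return count;
    }

    // Returns the smallest value v with countLessOrEqual(v) >= k
    public static int kthSmallest(int[][] matrix, int k)
    {
        int lo = matrix[0][0];
        int hi = matrix[matrix.length - 1][matrix[0].length - 1];
        while (lo < hi)
        {
            // long arithmetic avoids overflow when hi - lo exceeds Integer.MAX_VALUE
            int mid = (int) (lo + ((long) hi - lo) / 2);
            if (countLessOrEqual(matrix, mid) < k)
            {
                lo = mid + 1;
            }
            else
            {
                hi = mid;
            }
        }
        return lo;
    }

    public static void main(String[] args)
    {
        int[][] matrix = {{1, 5, 9}, {10, 11, 13}, {12, 13, 15}};
        System.out.println(kthSmallest(matrix, 8)); // prints 13
    }
}
```

The returned value is always an actual matrix element: the count jumps from below k to at least k exactly at some value that appears in the matrix.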