# Java Max Heap Solution w/ Inline Explanation

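• Since every row and every column of the matrix is sorted in ascending order, the element at 1-based position (i, j) is no smaller than any element in the i × j submatrix above and to the left of it. It is therefore greater than or equal to the (i * j)-th smallest element. So for each row i, we only need to find the last candidate j for which i * j does not exceed k, and we can discard every element after it.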
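The bound above can be checked by brute force on a small matrix. The sketch below is only an illustration: the class `RankBoundCheck`, the method `boundHolds` and the 3 × 3 example matrix are hypothetical. For every 0-based cell (i, j) it counts the elements that are not larger than that cell and checks that there are at least (i + 1) * (j + 1) of them, which is the 1-based i * j from the text.

```java
// Hypothetical brute-force check of the rank bound; O(n^4), tiny inputs only.
class RankBoundCheck {
    // Returns true if every 0-based cell (i, j) has at least (i + 1) * (j + 1)
    // elements <= m[i][j] (itself included), as the sorted-matrix argument claims.
    static boolean boundHolds(int[][] m) {
        int n = m.length;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                int notLarger = 0; // elements <= m[i][j], including m[i][j] itself
                for (int[] row : m) {
                    for (int v : row) {
                        if (v <= m[i][j]) notLarger++;
                    }
                }
                // The (i + 1) x (j + 1) top-left block is entirely <= m[i][j].
                if (notLarger < (i + 1) * (j + 1)) return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        int[][] m = {{1, 5, 9}, {10, 11, 13}, {12, 13, 15}};
        System.out.println(boundHolds(m)); // prints "true"
    }
}
```

For an unsorted matrix such as `{{5, 1}, {1, 1}}` the check fails at the bottom-right cell, so the argument really does depend on the sorting.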

In the following solution, we keep the last candidate of each row in a max heap and track the number of remaining valid elements in a counter c. Our target is to decrease c from n * n to k, first by discarding invalid entries and then by polling the largest remaining element from the heap. The element polled while c equals k is the k-th smallest.

Note that although this solution may save some time by discarding elements based solely on their indices, we wouldn't recommend it for interviews that demand a flawless implementation. Careless interviewees can easily make off-by-one mistakes when converting between the 1-based indices of the reasoning (the rank k and the i * j bound) and the 0-based indices used by languages such as Java.
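To make the conversion concrete, here is a minimal sketch that isolates it. The names `CandidateIndex`, `lastCandidate` and `keptCount` are hypothetical. Rows and columns are 0-based, k stays 1-based, and 0-based row i corresponds to 1-based row i + 1, so the largest 1-based column with (i + 1) * j <= k is `k / (i + 1)`.

```java
// Hypothetical helper isolating the 1-based -> 0-based index conversion.
class CandidateIndex {
    // Last 0-based column of 0-based row i that may still hold the k-th smallest
    // element, or -1 if the whole row can be discarded.
    static int lastCandidate(int n, int k, int i) {
        // Largest 1-based column j with (i + 1) * j <= k is k / (i + 1)
        // (integer division); subtract 1 to turn it into a 0-based index.
        int j = k / (i + 1) - 1;
        return Math.min(j, n - 1); // never past the last column
    }

    // Number of cells that survive the discarding step, i.e. the value of the
    // counter c before any element is polled from the heap.
    static int keptCount(int n, int k) {
        int kept = 0;
        for (int i = 0; i < n; i++) {
            kept += lastCandidate(n, k, i) + 1; // a discarded row (-1) adds 0
        }
        return kept;
    }

    public static void main(String[] args) {
        System.out.println(lastCandidate(3, 8, 2)); // prints 1
        System.out.println(keptCount(3, 8));        // prints 8
    }
}
```

With n = 3 and k = 8, only one cell (the last one of the bottom row) is discarded, so the heap phase needs a single poll.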

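```java
import java.util.Comparator;
import java.util.PriorityQueue;

public class Solution implements Comparator<Solution.Cell> {
    public int kthSmallest(int[][] matrix, int k) {
        final int n = matrix.length;
        // Max heap holding the current last (largest) candidate of each row
        final PriorityQueue<Cell> heap = new PriorityQueue<>(n, this);
        // Number of cells that are still valid candidates
        int c = n * n;
        for (int i = 0; i < n; i++) {
            // Find the last candidate for each row: the largest 1-based column j
            // with (i + 1) * j <= k, converted to a 0-based index
            int j = k / (i + 1) - 1;
            if (j >= n) j = n - 1;
            if (j >= 0) heap.offer(new Cell(i, j, matrix[i][j]));
            // ...and skip all cells after it (the whole row when j == -1)
            c -= n - j - 1;
        }
        Cell r = null;
        // Poll the largest remaining cell until c drops below k;
        // the last polled cell is the k-th smallest
        while (c >= k) {
            r = heap.poll();
            // For each polled cell, list its previous cell as candidate
            if (r.j >= 1) heap.offer(new Cell(r.i, r.j - 1, matrix[r.i][r.j - 1]));
            c--;
        }
        return r.val;
    }

    static class Cell {
        final int i;
        final int j;
        final int val;

        Cell(int i, int j, int val) {
            this.i = i;
            this.j = j;
            this.val = val;
        }
    }

    @Override
    public int compare(Cell c1, Cell c2) {
        // We need a max heap; Integer.compare avoids the overflow of c2.val - c1.val
        return Integer.compare(c2.val, c1.val);
    }
}
```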
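Given the warning about off-by-one errors, it may help to cross-check the heap solution against a brute-force reference on small random inputs. The sketch below assumes nothing beyond the JDK; the class `KthSmallestOracle` and its methods are hypothetical. One way to use it is to assert `new Solution().kthSmallest(m, k) == kthBySorting(m, k)` for every k from 1 to n * n on a few random matrices.

```java
import java.util.Arrays;
import java.util.Random;

// Hypothetical brute-force oracle for testing kthSmallest on small inputs.
class KthSmallestOracle {
    // Flattens the matrix, sorts it, and returns the element of 1-based rank k.
    static int kthBySorting(int[][] matrix, int k) {
        int n = matrix.length;
        int[] all = new int[n * n];
        for (int i = 0; i < n; i++) {
            System.arraycopy(matrix[i], 0, all, i * n, n);
        }
        Arrays.sort(all);
        return all[k - 1]; // k is 1-based
    }

    // Builds a random n x n matrix whose rows and columns are non-decreasing:
    // each cell is max(top, left) plus a step of 0, 1 or 2 (0 creates duplicates).
    static int[][] randomSortedMatrix(int n, Random rnd) {
        int[][] m = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                int base = 0;
                if (i > 0) base = Math.max(base, m[i - 1][j]);
                if (j > 0) base = Math.max(base, m[i][j - 1]);
                m[i][j] = base + rnd.nextInt(3);
            }
        }
        return m;
    }

    public static void main(String[] args) {
        int[][] example = {{1, 5, 9}, {10, 11, 13}, {12, 13, 15}};
        System.out.println(kthBySorting(example, 8)); // prints 13
        System.out.println(Arrays.deepToString(randomSortedMatrix(4, new Random(42))));
    }
}
```

Small steps deliberately produce many duplicates, which is where rank-based pruning is most likely to go wrong.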
