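The idea is simple and intuitive: make two dynamic-programming passes over the matrix, first from the top-left to the bottom-right, then from the bottom-right back to the top-left. A non-zero cell's distance to the nearest 0 is one more than the smallest distance among its four neighbors. The first pass propagates distances from zeros above and to the left, and the second pass merges in distances from zeros below and to the right.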

```
class Solution {
public:
    /// Replaces every cell with its Manhattan (4-neighbour) distance to the
    /// nearest 0 cell, using two dynamic-programming sweeps.
    ///
    /// @param matrix  Input grid; cells equal to 0 are sources, every other
    ///                value is treated as "non-zero". Not modified.
    /// @return        A grid of the same shape holding the distances. If the
    ///                grid contains no 0 at all, the values for non-zero cells
    ///                are meaningless (they only reflect the sentinel).
    vector<vector<int>> updateMatrix(vector<vector<int>>& matrix) {
        vector<vector<int>> res(matrix);
        const int m = static_cast<int>(res.size());
        if (m == 0) return res;
        const int n = static_cast<int>(res[0].size());

        // Any real distance is at most (m-1)+(n-1), so m+n is a safe "infinity".
        // (The previous m*n could overflow int on large grids.)
        const int kInf = m + n;

        // Pass 1 (top-left -> bottom-right): best distance via the neighbours
        // above and to the left, which have already been finalised for this pass.
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < n; ++j) {
                if (matrix[i][j] == 0) continue;  // zeros keep distance 0
                const int left = j > 0 ? res[i][j - 1] : kInf;
                const int up = i > 0 ? res[i - 1][j] : kInf;
                res[i][j] = min(left, up) + 1;
            }
        }

        // Pass 2 (bottom-right -> top-left): merge in the neighbours below and
        // to the right, keeping whichever route is shorter.
        for (int i = m - 1; i >= 0; --i) {
            for (int j = n - 1; j >= 0; --j) {
                if (matrix[i][j] == 0) continue;
                const int right = j < n - 1 ? res[i][j + 1] : kInf;
                const int down = i < m - 1 ? res[i + 1][j] : kInf;
                res[i][j] = min(res[i][j], min(right, down) + 1);
            }
        }
        return res;
    }
};
```