This is a solution I can actually come up with on the fly if someone asks me to do the problem on a whiteboard :).

It works because the Manhattan distance splits into two independent parts:
- a horizontal part that depends only on the meeting column;
- a vertical part that depends only on the meeting row.

Each part is a convex function of its coordinate, so every local minimum is also a global minimum. There may be a flat stretch of equally good positions, e.g. when the number of points is even. Because of that, we can choose an arbitrary starting point (I use the middle of the grid) and keep stepping in whichever direction reduces the total distance until no direction does.

See the comments for the details on how to decide to go left, right, up, or down. It helps to go through it with a simple example.

```
// Best meeting point: return the minimum total Manhattan distance from every
// point in `grid` to a single meeting cell.
//
// Each cell's value is counted as the number of points located there
// (presumably 0 or 1 — confirm against callers). The grid is assumed to be
// rectangular: every row has grid[0].size() entries.
//
// Manhattan distance separates into a horizontal part (depends only on the
// meeting column) and a vertical part (depends only on the meeting row), so the
// two axes are optimized independently:
//   1. For every column, count the points strictly to its left / right.
//      For every row, count the points strictly above / below.
//   2. Start from the middle of the grid.
//   3. Step in whichever direction strictly lowers the total distance, until no
//      step helps. The per-axis cost is convex, so this stops at a global minimum.
//   4. Turn the counts into distances with prefix / suffix sums.
//
// Returns 0 for an empty grid (no rows or no columns).
int minTotalDistance(vector<vector<int>>& grid) {
    // Without this guard grid[0] (no rows) or left[curCol] (no columns) would
    // be out of bounds.
    if (grid.empty() || grid[0].empty())
        return 0;
    const int m = static_cast<int>(grid.size());
    const int n = static_cast<int>(grid[0].size());

    // left[c] / right[c]: number of points in columns strictly left / right of column c.
    // up[r] / down[r]:   number of points in rows strictly above / below row r.
    vector<int> left(n, 0), right(n, 0), up(m, 0), down(m, 0);
    for (int col = 1; col < n; col++) {
        left[col] = left[col - 1];
        for (int row = 0; row < m; row++)
            left[col] += grid[row][col - 1];
    }
    for (int col = n - 2; col >= 0; col--) {
        right[col] = right[col + 1];
        for (int row = 0; row < m; row++)
            right[col] += grid[row][col + 1];
    }
    for (int row = 1; row < m; row++) {
        up[row] = up[row - 1];
        for (int col = 0; col < n; col++)
            up[row] += grid[row - 1][col];
    }
    for (int row = m - 2; row >= 0; row--) {
        down[row] = down[row + 1];
        for (int col = 0; col < n; col++)
            down[row] += grid[row + 1][col];
    }

    // Start from the middle of the grid.
    int curRow = m / 2, curCol = n / 2;

    // Going left (curCol -> curCol - 1):
    // - every point in columns >= curCol gets 1 step farther; there are
    //   right[curCol - 1] of them;
    // - every point in columns < curCol gets 1 step closer; there are
    //   left[curCol] of them.
    // So the total changes by right[curCol - 1] - left[curCol]; move only while
    // that is negative.
    while (curCol > 0 && right[curCol - 1] - left[curCol] < 0)
        curCol--;
    // Going right (curCol -> curCol + 1) mirrors this: the change is
    // left[curCol + 1] - right[curCol].
    while (curCol < n - 1 && left[curCol + 1] - right[curCol] < 0)
        curCol++;
    // Going up (curRow -> curRow - 1): the change is down[curRow - 1] - up[curRow].
    while (curRow > 0 && down[curRow - 1] - up[curRow] < 0)
        curRow--;
    // Going down (curRow -> curRow + 1): the change is up[curRow + 1] - down[curRow].
    while (curRow < m - 1 && up[curRow + 1] - down[curRow] < 0)
        curRow++;

    // (curRow, curCol) now minimizes the total distance.
    //
    // A point in column c' < curCol is counted once in left[c] for every
    // c in (c', curCol], i.e. (curCol - c') times. Summing left[0..curCol]
    // therefore gives the total horizontal distance from the points on the left.
    //
    // Example: if left is (0,1,1,2,2,2) and curCol is 3, the prefix sum turns
    // left into (0,1,2,4,2,2), and 4 is that distance.
    //
    // The suffix sums over right / down and the prefix sum over up work the
    // same way.
    partial_sum(left.begin(), left.begin() + curCol + 1, left.begin());
    for (int i = n - 2; i >= curCol; i--)
        right[i] += right[i + 1];
    partial_sum(up.begin(), up.begin() + curRow + 1, up.begin());
    for (int i = m - 2; i >= curRow; i--)
        down[i] += down[i + 1];
    return left[curCol] + right[curCol] + up[curRow] + down[curRow];
}
```