We can solve the one-dimensional problem first.

Let's say the array is `[1,0,0,1,1]`.

To solve it, we first compute the total distance if the meeting point were at index -1, i.e. just to the left of the array. We also count how many 1s are in the array.

```
[1,0,0,1,1] -> 1 + 4 + 5 = 10   (the 1s at indices 0, 3, 4 are 1, 4, 5 steps from index -1)
Total number of 1s is 3
```
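A minimal Java sketch of this first pass, assuming the input is an `int[]` of 0s and 1s. The class and method names (`BaselineDistance`, `baseline`) are illustrative and not part of the original solution:

```java
// Illustrative helper: computes the baseline distance and the count of 1s.
public class BaselineDistance {
    // Returns {total distance if everyone meets at index -1, number of 1s}.
    static int[] baseline(int[] nums) {
        int sum = 0;   // total distance to index -1
        int count = 0; // number of 1s in the array
        for (int i = 0; i < nums.length; i++) {
            // A 1 at index i is (i + 1) steps away from index -1.
            sum += (i + 1) * nums[i];
            count += nums[i];
        }
        return new int[] {sum, count};
    }

    public static void main(String[] args) {
        int[] result = baseline(new int[] {1, 0, 0, 1, 1});
        System.out.println(result[0] + " " + result[1]); // prints "10 3"
    }
}
```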

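Now we sweep the index `i` from left to right and recompute the distance from the one at `i - 1`, using how many 1s lie to the left of `i` (indices `< i`) and how many lie at or to the right of `i` (indices `>= i`). Moving the meeting point one step to the right puts every left-side 1 one step farther away and every right-side 1 one step closer, so the new distance is `previous - rightCount + leftCount`. The smallest distance seen during the sweep is the answer.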
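The sweep can be sketched in Java as follows. This is an illustrative variant (the names `SweepDistances` and `allDistances` are hypothetical) that records the distance at every index instead of keeping only the minimum, which makes the step-by-step update easy to inspect:

```java
import java.util.Arrays;

// Illustrative sketch of the left/right sweep; returns the distance at every index.
public class SweepDistances {
    static int[] allDistances(int[] nums) {
        int sum = 0, count = 0;
        for (int i = 0; i < nums.length; i++) {
            sum += (i + 1) * nums[i]; // baseline: everyone walks to index -1
            count += nums[i];
        }
        int leftCount = 0;      // 1s at indices < i
        int rightCount = count; // 1s at indices >= i
        int[] dist = new int[nums.length];
        for (int i = 0; i < nums.length; i++) {
            // Moving the meeting point from i - 1 to i: each 1 on the left gets
            // one step farther, each 1 at or right of i gets one step closer.
            sum = sum - rightCount + leftCount;
            dist[i] = sum;
            leftCount += nums[i];
            rightCount -= nums[i];
        }
        return dist;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(allDistances(new int[] {1, 0, 0, 1, 1})));
        // prints "[7, 6, 5, 4, 5]"
    }
}
```

For the example array the smallest value, 4, is at index 3, the position of the middle 1.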

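The two-dimensional problem can be reduced to two one-dimensional problems by compressing the matrix. This works because the Manhattan distance `|x1 - x2| + |y1 - y2|` splits into a row part and a column part that can be minimized independently. For example, the matrix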

```
[1,0,0,0,1]
[0,0,0,0,0]
[0,0,1,0,0]
```

will be converted to

```
Row: [1,0,1,0,1] (one entry per column: columns 0, 2 and 4 each contain one 1)
Col: [2,0,1] (one entry per row: the first row contains two 1s, the third row one)
```
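A small Java sketch of the compression step, assuming a non-empty rectangular grid. The names `GridCompression`, `columnSums` and `rowSums` are illustrative: `columnSums` produces what the post labels `Row`, and `rowSums` produces `Col`.

```java
import java.util.Arrays;

// Illustrative sketch: projects a 0/1 grid onto each axis.
public class GridCompression {
    // One entry per column: how many 1s that column contains.
    static int[] columnSums(int[][] grid) {
        int[] sums = new int[grid[0].length];
        for (int[] r : grid)
            for (int j = 0; j < r.length; j++) sums[j] += r[j];
        return sums;
    }

    // One entry per row: how many 1s that row contains.
    static int[] rowSums(int[][] grid) {
        int[] sums = new int[grid.length];
        for (int i = 0; i < grid.length; i++)
            for (int v : grid[i]) sums[i] += v;
        return sums;
    }

    public static void main(String[] args) {
        int[][] grid = {
            {1, 0, 0, 0, 1},
            {0, 0, 0, 0, 0},
            {0, 0, 1, 0, 0}
        };
        System.out.println(Arrays.toString(columnSums(grid))); // [1, 0, 1, 0, 1]
        System.out.println(Arrays.toString(rowSums(grid)));    // [2, 0, 1]
    }
}
```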

Code:

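```java
class Solution {
    public int minTotalDistance(int[][] grid) {
        int n = grid.length;
        if (n == 0) return 0;
        int m = grid[0].length;
        if (m == 0) return 0;
        // col[i]: number of 1s in row i (projection onto the vertical axis).
        // row[j]: number of 1s in column j (projection onto the horizontal axis).
        int[] col = new int[n];
        int[] row = new int[m];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                row[j] += grid[i][j];
                col[i] += grid[i][j];
            }
        }
        // Manhattan distance separates, so each axis is solved independently.
        return minTotalDistance(row) + minTotalDistance(col);
    }

    // Minimum over all indices i of the sum of nums[k] * |i - k|.
    private int minTotalDistance(int[] nums) {
        int sum = 0;   // total distance if everyone meets at index -1
        int count = 0; // total number of 1s
        for (int i = 0; i < nums.length; i++) {
            sum += (i + 1) * nums[i];
            count += nums[i];
        }
        int leftCount = 0;      // 1s at indices < i
        int rightCount = count; // 1s at indices >= i
        int min = sum;
        for (int i = 0; i < nums.length; i++) {
            // Shift the meeting point from i - 1 to i.
            sum = sum - rightCount + leftCount;
            if (sum < min) min = sum;
            leftCount += nums[i];
            rightCount -= nums[i];
        }
        return min;
    }
}
```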
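To sanity-check the approach on small inputs, one option is a brute force that tries every cell as the meeting point and sums the Manhattan distances directly, in O(n·m·k) time for k cells containing 1. This is a hypothetical cross-check (`BruteForceMeetingPoint` and `bruteForce` are illustrative names), not part of the original solution. For the example matrix it should agree with `minTotalDistance(grid)`: both give 6, e.g. meeting at row 0, column 2.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical brute-force reference for cross-checking on small grids.
public class BruteForceMeetingPoint {
    static int bruteForce(int[][] grid) {
        if (grid.length == 0 || grid[0].length == 0) return 0;
        // Collect the coordinates of every 1.
        List<int[]> ones = new ArrayList<>();
        for (int i = 0; i < grid.length; i++)
            for (int j = 0; j < grid[i].length; j++)
                if (grid[i][j] == 1) ones.add(new int[] {i, j});
        int best = Integer.MAX_VALUE;
        // Try every cell as the meeting point and sum the Manhattan distances.
        for (int x = 0; x < grid.length; x++) {
            for (int y = 0; y < grid[x].length; y++) {
                int total = 0;
                for (int[] p : ones) total += Math.abs(x - p[0]) + Math.abs(y - p[1]);
                best = Math.min(best, total);
            }
        }
        return best;
    }

    public static void main(String[] args) {
        int[][] grid = {
            {1, 0, 0, 0, 1},
            {0, 0, 0, 0, 0},
            {0, 0, 1, 0, 0}
        };
        System.out.println(bruteForce(grid)); // prints 6
    }
}
```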