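Imagine a virtual vertical line placed on one of the columns.
If we move the line one column to the right, the distance to every home in the line's current column or to its left increases by 1, while the distance to every home strictly to its right decreases by 1.
So moving right reduces the total distance exactly when the homes strictly to the right outnumber the homes at or to the left of the line. As the line moves right, the count at or to the left only grows and the count to the right only shrinks, so once a move stops helping, no later move helps either. We can therefore sweep the columns from left to right and stop at the first column where moving right no longer reduces the total; that column is optimal.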
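To make the sweep argument concrete, here is a small sketch that works on a 1D histogram of homes per column. It checks that moving from column c to c + 1 changes the total by (homes at or left of c) minus (homes right of c), and that the sweep stops at a minimizing column. The class name `SweepDemo` and the helpers `costAt` and `sweep` are hypothetical names for illustration, and the histogram is made-up sample data.

```java
import java.util.Arrays;

public class SweepDemo {
    // Total distance from every home to column c, where counts[j] = number of homes in column j.
    static int costAt(int[] counts, int c) {
        int total = 0;
        for (int j = 0; j < counts.length; j++) {
            total += Math.abs(j - c) * counts[j];
        }
        return total;
    }

    // Sweep left to right: keep moving while homes strictly to the right
    // outnumber homes at or to the left of the current column.
    static int sweep(int[] counts) {
        int c = 0;
        int atOrLeft = counts[0];
        int right = Arrays.stream(counts).sum() - counts[0];
        while (atOrLeft < right) {
            c++;
            atOrLeft += counts[c];
            right -= counts[c];
        }
        return c;
    }

    public static void main(String[] args) {
        int[] counts = {1, 0, 2, 0, 0, 1}; // sample histogram (hypothetical data)
        for (int c = 0; c + 1 < counts.length; c++) {
            int atOrLeft = 0, right = 0;
            for (int j = 0; j < counts.length; j++) {
                if (j <= c) atOrLeft += counts[j]; else right += counts[j];
            }
            // Actual change in total distance vs. the change predicted by the argument above.
            int delta = costAt(counts, c + 1) - costAt(counts, c);
            System.out.println("c=" + c + " -> " + (c + 1) + ": delta=" + delta
                    + ", predicted=" + (atOrLeft - right));
        }
        int best = sweep(counts);
        System.out.println("sweep stops at column " + best + ", cost " + costAt(counts, best));
    }
}
```

For this histogram the sweep stops at column 2 with total distance 5; every printed `delta` matches its `predicted` value.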

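The same rule applies to rows. Because the Manhattan distance |i - r| + |j - c| splits into a row part and a column part, the best row and the best column can be chosen independently.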
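The independence claim can be checked directly: a brute-force search over every cell as the meeting point should give the same minimum as optimizing the row and the column separately and adding the results. This is a sketch for verification only; `SeparabilityCheck`, `bruteForce`, `best1D` and `separable` are hypothetical names, and both searches are deliberately naive.

```java
public class SeparabilityCheck {
    // Brute force: try every cell as the meeting point and sum Manhattan distances to all homes.
    static int bruteForce(int[][] grid) {
        int m = grid.length, n = grid[0].length;
        int best = Integer.MAX_VALUE;
        for (int r = 0; r < m; r++) {
            for (int c = 0; c < n; c++) {
                int total = 0;
                for (int i = 0; i < m; i++)
                    for (int j = 0; j < n; j++)
                        if (grid[i][j] == 1) total += Math.abs(i - r) + Math.abs(j - c);
                best = Math.min(best, total);
            }
        }
        return best;
    }

    // Best 1D cost for a histogram of home positions, found by trying every position.
    static int best1D(int[] counts) {
        int best = Integer.MAX_VALUE;
        for (int p = 0; p < counts.length; p++) {
            int total = 0;
            for (int k = 0; k < counts.length; k++) total += Math.abs(k - p) * counts[k];
            best = Math.min(best, total);
        }
        return best;
    }

    // Separable version: optimize the row and the column independently, then add.
    static int separable(int[][] grid) {
        int m = grid.length, n = grid[0].length;
        int[] row = new int[m], column = new int[n];
        for (int i = 0; i < m; i++)
            for (int j = 0; j < n; j++)
                if (grid[i][j] == 1) { row[i]++; column[j]++; }
        return best1D(row) + best1D(column);
    }

    public static void main(String[] args) {
        int[][] grid = {
            {1, 0, 0, 0, 1},
            {0, 0, 0, 0, 0},
            {0, 0, 1, 0, 0}
        };
        System.out.println(bruteForce(grid) + " " + separable(grid));
    }
}
```

On this grid both approaches print 6 (meeting at row 0, column 2).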
class Solution {
    public int minTotalDistance(int[][] grid) {
        int m = grid.length;
        if (m == 0) return 0;
        int n = grid[0].length;
        if (n == 0) return 0;
        int[] row = new int[m];    // number of homes in each row
        int[] column = new int[n]; // number of homes in each column
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                if (grid[i][j] == 1) {
                    row[i]++;
                    column[j]++;
                }
            }
        }
        // find the best row and the best column independently
        int r = criticalPoint(row);
        int c = criticalPoint(column);
        int dist = 0;
        for (int i = 0; i < m; i++) dist += Math.abs(i - r) * row[i];
        for (int j = 0; j < n; j++) dist += Math.abs(j - c) * column[j];
        return dist;
    }

    // Sweep left to right and stop at the first index where the homes at or before it
    // are no fewer than the homes after it; moving further cannot reduce the total.
    private int criticalPoint(int[] a) {
        int r = 0;
        int first = a[0]; // homes at index <= r
        int second = 0;   // homes at index > r
        for (int i = 1; i < a.length; i++) second += a[i];
        while (first < second) {
            r++;
            first += a[r];
            second -= a[r];
        }
        return r;
    }
}
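The critical point found by the sweep is the first index whose prefix count reaches at least half of all homes, i.e. a median of the home coordinates. An equivalent formulation (a swapped-in variant, not the code above) collects the coordinates in sorted order and measures distances to the middle element. In this sketch `MedianVariant` and `distToMedian` are hypothetical names; for an even number of homes it picks the upper median, while the sweep picks the lower one, and both give the same minimum total.

```java
import java.util.ArrayList;
import java.util.List;

public class MedianVariant {
    // Sum of |x - median| over a list of coordinates that is already sorted.
    static int distToMedian(List<Integer> sorted) {
        int median = sorted.get(sorted.size() / 2);
        int total = 0;
        for (int x : sorted) total += Math.abs(x - median);
        return total;
    }

    static int minTotalDistance(int[][] grid) {
        int m = grid.length;
        if (m == 0) return 0;
        int n = grid[0].length;
        List<Integer> rows = new ArrayList<>();
        List<Integer> cols = new ArrayList<>();
        // A row-major scan yields row indices in non-decreasing order.
        for (int i = 0; i < m; i++)
            for (int j = 0; j < n; j++)
                if (grid[i][j] == 1) rows.add(i);
        // A column-major scan yields column indices in non-decreasing order.
        for (int j = 0; j < n; j++)
            for (int i = 0; i < m; i++)
                if (grid[i][j] == 1) cols.add(j);
        if (rows.isEmpty()) return 0; // no homes at all
        return distToMedian(rows) + distToMedian(cols);
    }

    public static void main(String[] args) {
        int[][] grid = {
            {1, 0, 0, 0, 1},
            {0, 0, 0, 0, 0},
            {0, 0, 1, 0, 0}
        };
        System.out.println(minTotalDistance(grid));
    }
}
```

Scanning in row-major and then column-major order avoids an explicit sort, so the whole variant stays O(mn), like the histogram sweep.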