```
/**
 * Returns the length of the longest contiguous subarray of {@code nums} that contains
 * as many zeros as non-zero elements.
 *
 * <p>Each {@code 0} contributes -1 and every other value contributes +1 to a running
 * balance. Two prefixes with the same balance enclose a subarray whose balance is zero.
 * So the earliest prefix length at which each balance appeared is remembered. Whenever
 * that balance recurs, the distance back to that prefix is a candidate answer.
 *
 * @param nums the input values; elements equal to 0 count as zeros, all others as ones
 * @return the length of the longest balanced contiguous subarray, or 0 if there is none
 */
public int findMaxLength(int[] nums) {
    // balance -> shortest prefix length that produced it; the empty prefix has balance 0
    Map<Integer, Integer> earliestPrefix = new HashMap<>();
    earliestPrefix.put(0, 0);
    int balance = 0;
    int longest = 0;
    for (int end = 1; end <= nums.length; end++) {
        balance += (nums[end - 1] == 0) ? -1 : 1;
        // putIfAbsent keeps the earliest occurrence and returns it if one already existed
        Integer start = earliestPrefix.putIfAbsent(balance, end);
        if (start != null) {
            longest = Math.max(longest, end - start);
        }
    }
    return longest;
}
}
```