The approach is explained in the comments of the code below.

```
/**
 * Counts the longest strictly increasing subsequences of {@code nums}.
 *
 * <p>Elements are processed left to right and grouped into "piles": pile {@code i} holds the
 * values that end an increasing subsequence of length {@code i + 1}, each mapped to how many
 * such subsequences end at that value.
 *
 * @param nums the input sequence; may be {@code null}
 * @return {@code -1} when {@code nums} is {@code null}, {@code nums.length} when it has fewer
 *     than two elements, otherwise the number of longest strictly increasing subsequences
 */
public int findNumberOfLIS(int[] nums) {
    if (nums == null) {
        return -1;
    }
    final int n = nums.length;
    if (n < 2) {
        return n;
    }

    // smallestTail[i]: smallest value so far that ends a subsequence of length i + 1.
    int[] smallestTail = new int[n];
    // tailCounts[i]: tail value -> number of subsequences of length i + 1 ending at that value.
    // Generic arrays cannot be created directly, hence the raw creation and the suppression.
    @SuppressWarnings("unchecked")
    TreeMap<Integer, Integer>[] tailCounts = new TreeMap[n];
    int piles = 0;   // number of piles in use, i.e. the longest length found so far
    int answer = 0;  // number of subsequences of that longest length

    for (int value : nums) {
        // A negative result encodes the insertion point as -(insertionPoint) - 1.
        int found = Arrays.binarySearch(smallestTail, 0, piles, value);
        int pile = found >= 0 ? found : -(found + 1);

        // Ways to end a subsequence of length pile + 1 at this value: either it starts a new
        // subsequence, or it extends any shorter-by-one subsequence whose tail is strictly smaller.
        int ways = (pile == 0)
                ? 1
                : tailCounts[pile - 1].headMap(value).values().stream()
                        .mapToInt(Integer::intValue)
                        .sum();

        if (pile == piles) {
            // A new longest length was reached: the running total restarts from this value.
            tailCounts[piles++] = new TreeMap<>();
            answer = ways;
        } else if (pile == piles - 1) {
            // Another set of subsequences matching the current longest length.
            answer += ways;
        }

        smallestTail[pile] = value;
        tailCounts[pile].merge(value, ways, Integer::sum);
    }
    return answer;
}
```