# Java merge sort solution, O(n log n)

• Similar to Count of Smaller Numbers After Self: once both halves are sorted, scan them once to count the pairs, then merge.
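
For reference, the quantity being counted matches the condition `nums[l] > 2 * nums[r]` in the code below: the number of index pairs i < j with nums[i] > 2 * nums[j]. Here is a brute-force O(n^2) sketch, handy for checking the merge sort versions on small inputs. The class name `BruteForceReversePairs` is made up for illustration.

```java
class BruteForceReversePairs {
    // Counts pairs (i, j) with i < j and nums[i] > 2 * nums[j] in O(n^2).
    // The long casts keep 2 * nums[j] from overflowing int.
    static int reversePairs(int[] nums) {
        int count = 0;
        for (int i = 0; i < nums.length; i++) {
            for (int j = i + 1; j < nums.length; j++) {
                if ((long) nums[i] > 2 * (long) nums[j]) {
                    count++;
                }
            }
        }
        return count;
    }

    public static void main(String[] args) {
        System.out.println(reversePairs(new int[] {1, 3, 2, 3, 1})); // 2
        System.out.println(reversePairs(new int[] {2, 4, 3, 5, 1})); // 3
    }
}
```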

```java
public class Solution {

    public int ret;

    public int reversePairs(int[] nums) {
        ret = 0;
        mergeSort(nums, 0, nums.length - 1);
        return ret;
    }

    public void mergeSort(int[] nums, int left, int right) {
        if (right <= left) {
            return;
        }
        int middle = left + (right - left) / 2;
        mergeSort(nums, left, middle);
        mergeSort(nums, middle + 1, right);

        // count elements: both halves are sorted now. When l advances, count is the
        // number of right-half elements nums[r] with nums[l] > 2 * nums[r]; r never
        // moves backwards. The long casts avoid int overflow in 2 * nums[r].
        int count = 0;
        for (int l = left, r = middle + 1; l <= middle;) {
            if (r > right || (long) nums[l] <= 2 * (long) nums[r]) {
                l++;
                ret += count;
            } else {
                r++;
                count++;
            }
        }

        // merge sort: merge the two sorted halves into temp, then copy back
        int[] temp = new int[right - left + 1];
        for (int l = left, r = middle + 1, k = 0; l <= middle || r <= right;) {
            if (l <= middle && ((r > right) || nums[l] < nums[r])) {
                temp[k++] = nums[l++];
            } else {
                temp[k++] = nums[r++];
            }
        }
        for (int i = 0; i < temp.length; i++) {
            nums[left + i] = temp[i];
        }
    }
}
```
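
The "scan before merge" step can be looked at on its own. Once both halves are sorted, the pointer into the right half never moves backwards, so counting the cross pairs takes one linear pass. Below is a minimal standalone sketch on two already-sorted arrays; the names `CrossPairs` and `countCrossPairs` are hypothetical and not part of the solution above.

```java
class CrossPairs {
    // Counts pairs (a, b) with a taken from left, b taken from right, and a > 2 * b.
    // Both arrays must be sorted ascending. r only moves forward: O(left.length + right.length).
    static int countCrossPairs(int[] left, int[] right) {
        int total = 0;
        int r = 0; // right[0..r) are the elements b with left[l] > 2 * b
        for (int l = 0; l < left.length; l++) {
            // left[l] >= left[l - 1], so every b counted for left[l - 1] still counts here
            while (r < right.length && (long) left[l] > 2 * (long) right[r]) {
                r++;
            }
            total += r;
        }
        return total;
    }

    public static void main(String[] args) {
        // Sorted halves of [1, 3, 2, 3, 1] at the top-level merge: [1, 2, 3] and [1, 3]
        System.out.println(countCrossPairs(new int[] {1, 2, 3}, new int[] {1, 3})); // 1
    }
}
```

At the top level of `[1, 3, 2, 3, 1]` this finds the one cross pair (3 > 2 * 1); the other pair lies entirely inside the right half `[3, 1]` and is counted by the recursive call.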

A clearer and simpler version, but slower: after counting, it sorts the range with `Arrays.sort` instead of merging the two halves. I got the idea from another solution.

```java
import java.util.Arrays;

public class Solution {

    public int ret;

    public int reversePairs(int[] nums) {
        ret = 0;
        mergeSort(nums, 0, nums.length - 1);
        return ret;
    }

    public void mergeSort(int[] nums, int left, int right) {
        if (right <= left) {
            return;
        }
        int middle = left + (right - left) / 2;
        mergeSort(nums, left, middle);
        mergeSort(nums, middle + 1, right);

        // count elements (same two-pointer scan as the first version)
        int count = 0;
        for (int l = left, r = middle + 1; l <= middle;) {
            if (r > right || (long) nums[l] <= 2 * (long) nums[r]) {
                l++;
                ret += count;
            } else {
                r++;
                count++;
            }
        }

        // sort: Arrays.sort on nums[left..right] (toIndex is exclusive) replaces the manual merge
        Arrays.sort(nums, left, right + 1);
    }
}
```
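
Where the slowdown comes from: merging a range of k elements costs O(k), while `Arrays.sort` on the same range costs O(k log k). With the linear counting scan at every level, the recurrences are:

$$
T_{\text{merge}}(n) = 2\,T_{\text{merge}}\!\left(\tfrac{n}{2}\right) + O(n) = O(n \log n)
$$

$$
T_{\text{sort}}(n) = 2\,T_{\text{sort}}\!\left(\tfrac{n}{2}\right) + O(n \log n) = O(n \log^2 n)
$$

Each of the roughly log n levels of the recursion sorts ranges totaling n elements, so every level costs up to O(n log n) instead of O(n).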

• Nice. Interesting

• @yanzhan2

I did the same thing, just wrote it slightly differently:

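```java
public class Solution {

    public int reversePairs(int[] nums) {
        return countWhileMergeSort(nums, 0, nums.length);
    }

    // Sorts nums[start, end) and returns the number of reverse pairs inside it.
    private int countWhileMergeSort(int[] nums, int start, int end) {
        if (end - start <= 1) return 0;
        int mid = (start + end) >>> 1; // unsigned shift avoids overflow of start + end
        int count = count(nums, start, mid, end);
        merge(nums, start, mid, end);
        return count;
    }

    // Recursively sorts and counts both halves, then counts the cross pairs.
    private int count(int[] nums, int start, int mid, int end) {
        int count = countWhileMergeSort(nums, start, mid)
                + countWhileMergeSort(nums, mid, end);

        // Walk the left half from its largest element down; hi only moves left.
        int hi = end - 1;
        for (int lo = mid - 1; lo >= start && hi >= mid; lo--) {
            // skip right-half elements too large to pair with nums[lo]
            while (hi >= mid && (long) nums[lo] <= 2 * (long) nums[hi]) hi--;
            // every nums[j] with mid <= j <= hi satisfies nums[lo] > 2 * nums[j]
            count += hi - mid + 1;
        }

        return count;
    }

    // Merges the sorted halves nums[start, mid) and nums[mid, end).
    private void merge(int[] nums, int start, int mid, int end) {
        int right = mid, t = 0;
        int[] temp = new int[end - start];

        for (int left = start; left < mid; left++) {
            while (right < end && nums[right] < nums[left]) temp[t++] = nums[right++];
            temp[t++] = nums[left];
        }

        // nums[right, end) is already in its final place, so only right - start slots are copied
        System.arraycopy(temp, 0, nums, start, right - start);
    }
}
```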
