Very Short and Clear MergeSort & BST Java Solutions


  • 37

    MergeSort

    Explanation: In each round, we divide the array into two halves and sort each of them recursively. After "int cnt = mergeSort(nums, s, mid) + mergeSort(nums, mid+1, e); ", both the left part and the right part are sorted, and the only remaining job is to count how many pairs (leftPart[i], rightPart[j]) satisfy leftPart[i] > 2*rightPart[j].
    For example,
    left: 4 6 8 right: 1 2 3
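    We use two pointers to walk the left and right parts. For each leftPart[i], we keep moving j forward while j<=e && nums[i]/2.0 > nums[j], i.e. while rightPart[j] is still small enough to pair with leftPart[i]. j stops at the first element that is too large, so every element in mid+1..j-1 pairs with leftPart[i]. Because the left part is sorted in ascending order, j never has to move back for the next i. In our example, left's 4 matches only 1 (4 > 2, but 4 is not > 4); left's 6 matches 1 and 2; and left's 8 matches 1, 2 and 3. So in this particular round 1 + 2 + 3 = 6 pairs are found, and we increase our total by 6.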
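    The two-pointer count for a single round can be isolated and checked on the example above. This is a minimal sketch. The class and method names (CrossPairCount, countCross) are hypothetical. The comparison uses 2L * nums[j] in long arithmetic, which is equivalent to the post's nums[i]/2.0 > nums[j] test and also avoids int overflow.

```java
class CrossPairCount {
    // Counts pairs with i in the sorted left run nums[s..mid] and j in the sorted
    // right run nums[mid+1..e] such that nums[i] > 2 * nums[j] (two-pointer scan).
    static int countCross(int[] nums, int s, int mid, int e) {
        int cnt = 0;
        int j = mid + 1; // never reset: the left run is ascending
        for (int i = s; i <= mid; i++) {
            // advance j while nums[i] is still more than twice nums[j]
            while (j <= e && nums[i] > 2L * nums[j]) j++;
            // right elements mid+1 .. j-1 all pair with nums[i]
            cnt += j - (mid + 1);
        }
        return cnt;
    }
}
```

    With left = 4 6 8 and right = 1 2 3 stored as {4, 6, 8, 1, 2, 3} (s = 0, mid = 2, e = 5), countCross returns 6.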

    import java.util.Arrays;

    public class Solution {
        public int reversePairs(int[] nums) {
            return mergeSort(nums, 0, nums.length-1);
        }
        private int mergeSort(int[] nums, int s, int e){
            if(s>=e) return 0; 
            int mid = s + (e-s)/2; 
            int cnt = mergeSort(nums, s, mid) + mergeSort(nums, mid+1, e); 
            for(int i = s, j = mid+1; i<=mid; i++){
                while(j<=e && nums[i]/2.0 > nums[j]) j++; 
                cnt += j-(mid+1); 
            }
            Arrays.sort(nums, s, e+1); 
            return cnt; 
        }
    }
    

    Or:
    Because the left and right parts are already sorted, you can replace the Arrays.sort() call with an actual merge step. The previous version is easier to write, while this one is faster: the merge takes O(n) per call instead of O(n log n).

    public class Solution {
        int[] helper;
        public int reversePairs(int[] nums) {
            this.helper = new int[nums.length];
            return mergeSort(nums, 0, nums.length-1);
        }
        private int mergeSort(int[] nums, int s, int e){
            if(s>=e) return 0; 
            int mid = s + (e-s)/2; 
            int cnt = mergeSort(nums, s, mid) + mergeSort(nums, mid+1, e); 
            for(int i = s, j = mid+1; i<=mid; i++){
                while(j<=e && nums[i]/2.0 > nums[j]) j++; 
                cnt += j-(mid+1); 
            }
            //Arrays.sort(nums, s, e+1); 
            myMerge(nums, s, mid, e);
            return cnt; 
        }
        
        private void myMerge(int[] nums, int s, int mid, int e){
            for(int i = s; i<=e; i++) helper[i] = nums[i];
            int p1 = s;//pointer for left part
            int p2 = mid+1;//pointer for right part
            int i = s;//pointer for sorted array
            while(p1<=mid || p2<=e){
                if(p1>mid || (p2<=e && helper[p1] >= helper[p2])){
                    nums[i++] = helper[p2++];
                }else{
                    nums[i++] = helper[p1++];
                }
            }
        }
    }
    

    BST
    The BST solution is no longer accepted, because its performance can be very bad (O(n^2), in fact) for extreme cases like [1,2,3,4......49999], where the tree becomes unbalanced. I am still providing it below, just FYI.
    We build the binary search tree from right to left. For each nums[i], we first search the partially built tree (which holds nums[i+1..n-1]) with nums[i]/2.0 to count how many stored values are strictly less than it, and then insert nums[i]. The code below should be clear enough.

    public class Solution {
        public int reversePairs(int[] nums) {
            Node root = null;
            int[] cnt = new int[1];
            for(int i = nums.length-1; i>=0; i--){
                search(cnt, root, nums[i]/2.0);//search and count the partially built tree
                root = build(nums[i], root);//add nums[i] to BST
            }
            return cnt[0];
        }
        
        private void search(int[] cnt, Node node, double target){
            if(node==null) return; 
            else if(target == node.val) cnt[0] += node.less;
            else if(target < node.val) search(cnt, node.left, target);
            else{
                cnt[0]+=node.less + node.same; 
                search(cnt, node.right, target);
            }
        }
        
        private Node build(int val, Node n){
            if(n==null) return new Node(val);
            else if(val == n.val) n.same+=1;
            else if(val > n.val) n.right = build(val, n.right);
            else{
                n.less += 1;
                n.left = build(val, n.left);
            }
            return n;
        }
        
        class Node{
            int val, less = 0, same = 1;//less: number of values in this node's left subtree (all smaller than val); same: copies of val
            Node left, right;
            public Node(int v){
                this.val = v;
            }
        }
    }
    

    This is similar to https://leetcode.com/problems/count-of-smaller-numbers-after-self/. The main difference is that here the number to add and the number to search are different (add nums[i], but search nums[i]/2.0), so it is not a good idea to combine build and search into one function.


  • 2

    This is nice; I find the BST a clear approach. I wonder if anybody did it in Python? Sadly, mine got Time Limit Exceeded, and even the direct translation of your code into Python below still gets TLE.

    class TreeNode:
        def __init__(self, val):
            self.val = val
            self.less = 0
            self.same = 1
            self.left = None
            self.right = None
    
    class Solution(object):
        def reversePairs(self, nums):
            root = None
            cnt = [0]
            for i in range(len(nums)-1, -1, -1):
                self.search(cnt, root, nums[i]/2.0)
                root = self.build(nums[i], root)
            return cnt[0]
    
        def search(self, cnt, node, target):
            if not node:
                return
            if target == node.val:
                cnt[0] += node.less
            elif target < node.val:
                self.search(cnt, node.left, target)
            else:
                cnt[0] += (node.less + node.same)
                self.search(cnt, node.right, target)
    
        def build(self, val, n):
            if not n:
                return TreeNode(val)
            elif val == n.val:
                n.same += 1
            elif val > n.val:
                n.right = self.build(val, n.right)
            else:
                n.less += 1
                n.left = self.build(val, n.left)
            return n
    

  • 0

    Impressive! Your solution reminds me how much practice I still need. I had solved Count of Smaller Numbers After Self before, but I couldn't even recall it during the contest.

    Below is my solution using merge sort.

    public class Solution {
        public int reversePairs(int[] nums) {
            if (nums == null || nums.length == 0) return 0;
            return mergeSort(nums, 0, nums.length - 1);
        }
        private int mergeSort(int[] nums, int l, int r) {
            if (l >= r) return 0;
            int mid = l + (r - l)/2;
            int count = mergeSort(nums, l, mid) + mergeSort(nums, mid + 1, r);
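            int[] cache = new int[r - l + 1]; // merged copy of nums[l..r]
            // i: counting pointer, first left index with nums[i] > 2*nums[j]
            // t: next left element to merge; c: next free slot in cache
            int i = l, t = l, c = 0;
            for (int j = mid + 1; j <= r; j++, c++) {
                while (i <= mid && nums[i] <= 2 * (long)nums[j]) i++;      // skip left values <= 2*nums[j]
                while (t <= mid && nums[t] < nums[j]) cache[c++] = nums[t++]; // merge smaller left values first
                cache[c] = nums[j];
                count += mid - i + 1; // nums[i..mid] all exceed 2*nums[j]
            }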
            while (t <= mid) cache[c++] = nums[t++];
            System.arraycopy(cache, 0, nums, l, r - l + 1);
            return count;
        }
    }
    

  • 0

    @yorkshire
    Try not to use recursion. I had the same issue with a recursive method, so I rewrote it using while loops. That sped the code up by about 1.5x, and now it finishes within the time limit.

    class TreeNode(object):
        def __init__(self, val):
            self.val = val #the value at this node
            self.same = 1 #number of keys with this value
            self.less = 0 #number of keys in the left subtree
            self.left = None #left subtree (nodes with less value)
            self.right = None #right subtree (nodes with higher value)
    
    def search(node, val):
        n_temp = 0
        while node != None:
            if node.val == val:
                n_temp += node.less #keys with the same value don't count
                node = None
            elif val > node.val:
                n_temp += node.less + node.same
                node = node.right
            else:
                node = node.left
        return n_temp
    
    def insert(node, val):
        while True:
            if node.val == val:
                node.same += 1
                return
            elif val > node.val:
                if node.right == None:
                    node.right = TreeNode(val)
                    return
                else:
                    node = node.right
            else:
                node.less += 1
                if node.left == None:
                    node.left = TreeNode(val)
                    return
                else:
                    node = node.left
    
    class Solution(object):
        def reversePairs(self, nums):
            length = len(nums)
            if length <= 1: return 0
            n = 0
            root = TreeNode(2*nums[-1])
            for i in reversed(range(length - 1)):
                a = nums[i]
                n += search(root, a)
                insert(root, 2*a)
            return n
    

  • 1

    @louis925 The Python code you posted still gets TLE. I wrote my own code and also tested your posted code directly; neither passed. Thus, the BST method should not be considered the standard solution.


  • 0

    @tbjc The original Java code from @Chidong now also gets Time Limit Exceeded for the test case nums = [i for i in range(5000)]. In this situation the tree is effectively linear, so the time complexity is O(n**2).
    At least the results are standardised now, so neither language is accepted and we have to use something like merge sort, which is O(n log n). It's just a shame that Java was accepted in the contest and Python was not.
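    The claim that the tree becomes linear can be checked directly. The sketch below is a hypothetical helper (BstDepthDemo is not from any post). It inserts values from right to left into the same kind of unbalanced BST, iteratively to avoid deep recursion, and reports the resulting height. For an ascending input such as range(5000), every new value goes to the left of the previous one, so the height equals n and each search walks O(n) nodes.

```java
class BstDepthDemo {
    // Hypothetical minimal node: same shape as the top post's Node, counters omitted.
    static final class Node {
        final int val;
        Node left, right;
        Node(int v) { val = v; }
    }

    // Inserts nums from right to left (the order the BST solution uses) without
    // recursion and returns the height of the resulting unbalanced tree.
    static int heightAfterRightToLeftInsert(int[] nums) {
        Node root = null;
        int height = 0;
        for (int i = nums.length - 1; i >= 0; i--) {
            int v = nums[i];
            if (root == null) {
                root = new Node(v);
                height = 1;
                continue;
            }
            Node cur = root;
            int depth = 1; // depth of cur; the root is at depth 1
            while (cur.val != v) { // an equal value would only bump a counter
                Node next = v < cur.val ? cur.left : cur.right;
                if (next == null) { // attach a new leaf below cur
                    next = new Node(v);
                    if (v < cur.val) cur.left = next; else cur.right = next;
                }
                cur = next;
                depth++;
            }
            height = Math.max(height, depth);
        }
        return height;
    }
}
```

    For the ascending array 0..999 the height is 1000. For {1, 3, 2}, the root is 2 with children 1 and 3, so the height is 2.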


  • 0

    @tbjc @yorkshire
    Thanks for letting me know! I think they changed the test cases.
    In the contest, the person who got second place actually used only binary sort and passed all the test cases simply because it was written in C++. I tried to do the same thing in Python and got TLE.
    Anyway, I will try the merge sort method next.


  • 0

    My code is the same as yours; why does it get TLE?


  • 0

    @tbjc Neither does the C++ version. I think the unbalanced binary tree is not the standard solution!


  • 0

    @Chidong Hi, thanks for posting your solution. I'm confused about one thing, and maybe I'm missing something: where is the merge in your mergeSort? I see Arrays.sort(). Isn't that O(n log n), so the total time complexity would be O(n log n log n)?


  • 0

    @jonathan82
    The process first partitions [s, e] into [s, mid] and [mid+1, e]. Then the entire range [s, e] is "merged", i.e. sorted; the merging step in merge sort is essentially a sorting step. Yes, you can use Arrays.sort() to do the job, which is easier to write but takes more time to run. Or you can implement this part yourself: since the left and right parts are both sorted, it only takes O(n) to finish the merge. I added that to my post.


  • 0

    @Chidong Oh, got it, thanks. I also thought about using a Fenwick tree but couldn't get it to fit this problem.
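    For reference, a Fenwick tree (binary indexed tree) can be made to fit with coordinate compression. This is a sketch of that swapped-in technique, not something from the posts above, and the class and helper names are hypothetical. Values are mapped to ranks 1..m, and we scan from right to left. For each nums[i], we query how many already-inserted values v satisfy 2L * v < nums[i], then insert nums[i] itself. Unlike the unbalanced BST, each step costs O(log n) regardless of input order, so the whole scan is O(n log n).

```java
import java.util.Arrays;

class FenwickReversePairs {
    // Standard Fenwick (binary indexed) tree over positions 1..n holding counts.
    static final class Fenwick {
        private final int[] tree;
        Fenwick(int n) { tree = new int[n + 1]; }
        void add(int pos) { // add 1 at position pos (1-based)
            for (; pos < tree.length; pos += pos & -pos) tree[pos]++;
        }
        int prefix(int pos) { // total count at positions 1..pos (pos may be 0)
            int sum = 0;
            for (; pos > 0; pos -= pos & -pos) sum += tree[pos];
            return sum;
        }
    }

    // Number of entries u in the ascending array vals with 2L * u < x.
    private static int countHalfBelow(int[] vals, int x) {
        int lo = 0, hi = vals.length;
        while (lo < hi) {
            int mid = (lo + hi) >>> 1;
            if (2L * vals[mid] < x) lo = mid + 1; else hi = mid;
        }
        return lo;
    }

    static int reversePairs(int[] nums) {
        // coordinate compression: distinct values in ascending order, rank = index + 1
        int[] vals = Arrays.stream(nums).distinct().sorted().toArray();
        Fenwick bit = new Fenwick(vals.length);
        int count = 0;
        for (int i = nums.length - 1; i >= 0; i--) {
            // already-inserted values v (index j > i) with 2 * v < nums[i] have ranks 1..k
            count += bit.prefix(countHalfBelow(vals, nums[i]));
            bit.add(Arrays.binarySearch(vals, nums[i]) + 1);
        }
        return count;
    }
}
```

    As with the BST, the value being searched (half of nums[i]) differs from the value being inserted (nums[i]). That is why the query and the update use two different lookups into the compressed values.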


  • 0

    Why is it Arrays.sort(nums, s, e+1); instead of Arrays.sort(nums, s, e); ?

    Where does that +1 come from? I tried removing it but got an error. I could debug it to find out why, but how did you know before debugging? Thank you very much!


  • 0

    @coder2 The Java API docs say: "public static void sort(short[] a, int fromIndex, int toIndex)
    Sorts the specified range of the array into ascending order. The range to be sorted extends from the index fromIndex, inclusive, to the index toIndex, exclusive. If fromIndex == toIndex, the range to be sorted is empty. .... "
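    A small check of that exclusive toIndex, written as a hypothetical demo class: to sort the inclusive range nums[s..e] you pass e + 1, while passing e leaves nums[e] out of the sort.

```java
import java.util.Arrays;

class SortRangeDemo {
    // Returns a copy of nums with the half-open range [from, to) sorted,
    // exactly as Arrays.sort(int[], fromIndex, toIndex) defines it.
    static int[] sortedCopy(int[] nums, int from, int to) {
        int[] copy = nums.clone();
        Arrays.sort(copy, from, to);
        return copy;
    }
}
```

    For nums = {5, 4, 3, 2, 1} with s = 1 and e = 3, sortedCopy(nums, 1, 4) gives [5, 2, 3, 4, 1]. sortedCopy(nums, 1, 3) gives [5, 3, 4, 2, 1], so nums[3] = 2 was never included.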


  • 1

    My Python merge sort solution:

    class Solution(object):
        def reversePairs(self, nums):
            def mergeSort(s,e):
                if s >= e: return 0
                mid = (s+e)//2
                cnt,j = mergeSort(s,mid) + mergeSort(mid+1,e),mid+1
                for i in range(s,mid+1):
                    while j <= e and nums[i] > 2*nums[j]:
                        j += 1
                    cnt += j-(mid+1)
                nums[s:e+1] = sorted(nums[s:e+1])
                return cnt
            return mergeSort(0,len(nums)-1)

  • 0

    Thanks for the explanations. Here is my 239 ms C++ version:

    #include <algorithm>
    #include <iterator>
    #include <vector>
    using namespace std;

    class Solution {
    public:
        int reversePairs(vector<int>& nums) {
            return mergeSort(nums, 0, (int)nums.size() - 1);
        }
        
        int mergeSort(vector<int>& nums, int l, int r) {
            if (l >= r) return 0;
            int mid = l + (r - l) / 2;
            int res = 0;
            res = mergeSort(nums, l, mid) + mergeSort(nums, mid+1, r);
            for (int i = mid+1; i <= r; i++) {
                auto it = upper_bound(nums.begin()+l, nums.begin()+mid+1, nums[i] * 2L);
                int dis = distance(nums.begin()+l, it);
                if (dis > mid-l) break;
                res += mid-l+1 - dis;
            }
            inplace_merge(nums.begin()+l, nums.begin()+mid+1, nums.begin()+r+1);
            return res;
        }
    };
    

  • 0

    @Chidong Thanks for your post,
    I have a question: when you worked on this solution, how did you know that sorting would not cause double counting or missed pairs, since sorting changes the positions of numbers in the original array?

    Thanks in advance.
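    One way to see why sorting does not break the count: every pair (i, j) with i < j is counted exactly once, in the single call where i lands in the left half and j in the right half. Sorting a half only reorders elements inside that half and never moves an element across the split, so the set of cross pairs at each level is unchanged. As an empirical check (not a proof), the hypothetical harness below compares a condensed copy of the top post's first version against an O(n^2) brute force on random arrays.

```java
import java.util.Arrays;
import java.util.Random;

class ReversePairsCrossCheck {
    // O(n^2) reference: count i < j with nums[i] > 2 * nums[j] directly.
    static int bruteForce(int[] nums) {
        int cnt = 0;
        for (int i = 0; i < nums.length; i++)
            for (int j = i + 1; j < nums.length; j++)
                if (nums[i] > 2L * nums[j]) cnt++;
        return cnt;
    }

    // Condensed copy of the Arrays.sort variant from the top post.
    static int mergeSortCount(int[] nums, int s, int e) {
        if (s >= e) return 0;
        int mid = s + (e - s) / 2;
        int cnt = mergeSortCount(nums, s, mid) + mergeSortCount(nums, mid + 1, e);
        for (int i = s, j = mid + 1; i <= mid; i++) {
            while (j <= e && nums[i] / 2.0 > nums[j]) j++;
            cnt += j - (mid + 1);
        }
        Arrays.sort(nums, s, e + 1);
        return cnt;
    }

    // Returns true if both methods agree on `trials` random arrays (fixed seed).
    static boolean agreeOnRandomInputs(int trials, long seed) {
        Random rnd = new Random(seed);
        for (int t = 0; t < trials; t++) {
            int n = rnd.nextInt(30);
            int[] nums = new int[n];
            for (int k = 0; k < n; k++) nums[k] = rnd.nextInt(41) - 20; // values in [-20, 20]
            int expected = bruteForce(nums);
            int actual = mergeSortCount(nums.clone(), 0, n - 1);
            if (expected != actual) return false;
        }
        return true;
    }
}
```

    Small ranges with negatives and duplicates are chosen so that ties like nums[i] == 2 * nums[j] come up often. Those ties are exactly where an off-by-one in the comparison would show.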


  • 0

    @Chidong For the MergeSort Solution, could you explain the complexity? I think it's not O(nlogn), since the recurrence seems to be T(n) = 2T(n/2) + O(n^2), instead of T(n) = 2T(n/2) + O(n). Thanks.
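    For what it's worth, the O(n^2) term does not seem to arise. Within one call, i moves over the left half once, and j only ever moves forward over the right half (it is never reset). So the counting loop takes at most (mid - s + 1) + (e - mid) steps, which is linear in n = e - s + 1. Assuming Arrays.sort costs O(n log n) on a range of length n, the recurrences for the two versions in the top post are then:

```latex
\begin{aligned}
\text{merge version:}\quad & T(n) = 2T(n/2) + O(n) \;\Rightarrow\; T(n) = O(n \log n) \\
\text{Arrays.sort version:}\quad & T(n) = 2T(n/2) + O(n \log n) \;\Rightarrow\; T(n) = O(n \log^2 n)
\end{aligned}
```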


  • 0

Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.