Java O(n log n) solution with sorting & binary search


    To begin, there is a naive O(n^2) solution: for each interval, compare its end point with the start points of all the other intervals in the input array and pick the one, if any, whose start point is the smallest among those no less than the given end point. In other words, each interval requires a search of the whole array for an interval satisfying this criterion.
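The naive approach can be sketched in Java as follows. This is a minimal illustration rather than part of the solution itself; the `Interval` class is a hypothetical stand-in with the same `start`/`end` fields the solution uses, and `NaiveRightInterval` is an illustrative name. The inner loop also visits `j == i`, which can only match if an interval's start equals its end.

```java
import java.util.Arrays;

// Hypothetical minimal Interval class with the start/end fields the solution relies on.
class Interval {
    int start, end;
    Interval(int start, int end) { this.start = start; this.end = end; }
}

class NaiveRightInterval {
    // O(n^2): for every interval, scan all intervals for the smallest
    // start point that is no less than its end point.
    static int[] findRightInterval(Interval[] intervals) {
        int n = intervals.length;
        int[] res = new int[n];
        for (int i = 0; i < n; i++) {
            int best = -1; // index of the best candidate found so far, -1 if none
            for (int j = 0; j < n; j++) {
                boolean qualifies = intervals[j].start >= intervals[i].end;
                if (qualifies && (best == -1 || intervals[j].start < intervals[best].start)) {
                    best = j;
                }
            }
            res[i] = best;
        }
        return res;
    }

    public static void main(String[] args) {
        Interval[] in = { new Interval(3, 4), new Interval(2, 3), new Interval(1, 2) };
        System.out.println(Arrays.toString(findRightInterval(in))); // [-1, 0, 1]
    }
}
```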

    A plain linear search scans the whole array for each of the n intervals, which is where the O(n^2) time comes from. If the input array is instead sorted by the intervals' start points, each search becomes a binary search costing O(log n); together with the O(n log n) sort, the total time drops to O(n log n).
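Finding "the first start point no less than an end point" in a sorted array is a lower-bound binary search. Here is a minimal standalone sketch, assuming a plain ascending `int[]` of start points; `LowerBound`, `lowerBound`, and the `from` parameter (which lets the search begin partway through the array) are illustrative names.

```java
class LowerBound {
    // Returns the smallest index k in [from, starts.length) with starts[k] >= target,
    // or starts.length if there is no such index. Assumes starts is sorted ascending.
    static int lowerBound(int[] starts, int from, int target) {
        int l = from, r = starts.length - 1;
        while (l <= r) {
            int m = l + ((r - l) >>> 1); // overflow-safe midpoint
            if (starts[m] >= target) {
                r = m - 1; // m qualifies; look for an earlier qualifying index
            } else {
                l = m + 1; // m is too small; discard it and everything to its left
            }
        }
        return l;
    }

    public static void main(String[] args) {
        int[] starts = {1, 2, 3, 7};
        System.out.println(lowerBound(starts, 0, 3)); // 2
        System.out.println(lowerBound(starts, 0, 4)); // 3
        System.out.println(lowerBound(starts, 0, 8)); // 4 (no qualifying start)
    }
}
```

Because start points are distinct in this problem, `java.util.Arrays.binarySearch` could serve as well: a negative return value `r` encodes the insertion point as `-(r + 1)`, which is exactly the lower bound.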

    One tricky point is that the result must refer to indices in the original input array, but sorting destroys the original order. The fix is to bind each interval's original index to its start point, so the index information survives sorting. This can be done either with a TreeMap (start point as key, original index as value), which works because no two intervals share a start point, or simply with a new n-by-2 array (start point as the first element, original index as the second).
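For comparison, here is a minimal sketch of the TreeMap variant, assuming the same hypothetical `Interval` class as a stand-in; `TreeMapRightInterval` is an illustrative name. `TreeMap.ceilingEntry(key)` returns the entry with the least key greater than or equal to `key`, or `null` if there is none, which is precisely the lookup needed here. Each insertion and lookup costs O(log n), so this version is also O(n log n) overall.

```java
import java.util.Arrays;
import java.util.Map;
import java.util.TreeMap;

// Hypothetical minimal Interval class with the start/end fields the solution relies on.
class Interval {
    int start, end;
    Interval(int start, int end) { this.start = start; this.end = end; }
}

class TreeMapRightInterval {
    static int[] findRightInterval(Interval[] intervals) {
        // Key: start point, value: original index. Relies on start points being distinct.
        TreeMap<Integer, Integer> startToIndex = new TreeMap<>();
        for (int i = 0; i < intervals.length; i++) {
            startToIndex.put(intervals[i].start, i);
        }

        int[] res = new int[intervals.length];
        for (int i = 0; i < intervals.length; i++) {
            // Smallest start point >= this interval's end, or null if none exists.
            Map.Entry<Integer, Integer> e = startToIndex.ceilingEntry(intervals[i].end);
            res[i] = (e == null) ? -1 : e.getValue();
        }
        return res;
    }

    public static void main(String[] args) {
        Interval[] in = { new Interval(1, 4), new Interval(2, 3), new Interval(3, 4) };
        System.out.println(Arrays.toString(findRightInterval(in))); // [-1, 2, -1]
    }
}
```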

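    The following solution uses an n-by-2 array. One advantage is that the binary search for each interval can start from that interval's own position in the sorted array instead of from the beginning: because an interval's end point is never smaller than its start point, any qualifying start point is at least the interval's own start point and therefore lies at or after its position. Here is a quick explanation: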

    a. res is the result array; arr is an auxiliary array in which each element stores the start point and the original index of one interval from the input array.

    b. First populate arr and sort it by start point. Then, for each interval in arr, binary-search for the smallest position whose start point is no less than that interval's end point. Set the interval's result to the original index stored at that position if one exists, or to -1 otherwise.

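    import java.util.Arrays;

    // Interval is the class supplied by the problem, with int fields start and end.
    class Solution {
        public int[] findRightInterval(Interval[] intervals) {
            int n = intervals.length;
            int[] res = new int[n];
            // arr[k][0] is a start point; arr[k][1] is the original index of that interval.
            int[][] arr = new int[n][2];

            for (int i = 0; i < n; i++) {
                arr[i][0] = intervals[i].start;
                arr[i][1] = i;
            }

            // Sort by start point; each original index moves along with its start point.
            Arrays.sort(arr, (a, b) -> Integer.compare(a[0], b[0]));

            for (int i = 0; i < n; i++) {
                int end = intervals[arr[i][1]].end;
                // Since start <= end, no right interval can precede position i,
                // so search positions i..n-1 for the first start point >= end.
                int l = i, r = n - 1;

                while (l <= r) {
                    int m = l + ((r - l) >>> 1);

                    if (end <= arr[m][0]) {
                        r = m - 1;   // arr[m] qualifies; look for an earlier one
                    } else {
                        l = m + 1;   // arr[m] starts before end; discard it
                    }
                }

                // l is the first qualifying position, or n if there is none.
                res[arr[i][1]] = (l < n ? arr[l][1] : -1);
            }

            return res;
        }
    }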