Java Concise Binary Search


  • 14

    If we are not allowed to use TreeMap:

    1. Sort the starts.
    2. For each end, binary-search for the leftmost start that is >= end.
    3. To recover the original index of that start, we need a map from start to index.
    public int[] findRightInterval(Interval[] intervals) {
        // start value -> original index (starts are distinct)
        Map<Integer, Integer> map = new HashMap<>();
        List<Integer> starts = new ArrayList<>();
        for (int i = 0; i < intervals.length; i++) {
            map.put(intervals[i].start, i);
            starts.add(intervals[i].start);
        }
        
        Collections.sort(starts);
        int[] res = new int[intervals.length];
        for (int i = 0; i < intervals.length; i++) {
            int end = intervals[i].end;
            // leftmost start >= end, or the largest start if none qualifies
            int start = binarySearch(starts, end);
            if (start < end) {
                res[i] = -1; // no start >= end exists
            } else {
                res[i] = map.get(start);
            }
        }
        return res;
    }
    
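    // Lower-bound search on a sorted list: returns the leftmost value >= x,
    // or the last (largest) value when every element is < x.
    public int binarySearch(List<Integer> list, int x) {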
        int left = 0, right = list.size() - 1;
        while (left < right) {
            int mid = left + (right - left) / 2;
            if (list.get(mid) < x) { 
                left = mid + 1;
            } else {
                right = mid;
            }
        }
        return list.get(left);
    }
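    The opening line mentions TreeMap as the alternative. For comparison, here is a minimal sketch of that variant, assuming a stand-in `Interval` class with `start`/`end` fields (LeetCode supplied its own) and relying on the problem's guarantee that starts are distinct. `TreeMap.ceilingEntry(end)` returns the entry with the smallest key >= `end`, or `null`, which replaces both the hand-written binary search and the separate sorted list. The class name `RightIntervalTreeMap` is only for illustration.

```java
import java.util.Arrays;
import java.util.Map;
import java.util.TreeMap;

class RightIntervalTreeMap {
    // Stand-in for LeetCode's Interval class (assumed: int start, int end).
    static class Interval {
        int start, end;
        Interval(int start, int end) { this.start = start; this.end = end; }
    }

    static int[] findRightInterval(Interval[] intervals) {
        // Starts are distinct, so each start maps to exactly one original index.
        TreeMap<Integer, Integer> startToIndex = new TreeMap<>();
        for (int i = 0; i < intervals.length; i++) {
            startToIndex.put(intervals[i].start, i);
        }
        int[] res = new int[intervals.length];
        for (int i = 0; i < intervals.length; i++) {
            // Entry with the smallest start >= end, or null if there is none.
            Map.Entry<Integer, Integer> e = startToIndex.ceilingEntry(intervals[i].end);
            res[i] = (e == null) ? -1 : e.getValue();
        }
        return res;
    }

    public static void main(String[] args) {
        Interval[] in = { new Interval(3, 4), new Interval(2, 3), new Interval(1, 2) };
        System.out.println(Arrays.toString(findRightInterval(in))); // [-1, 0, 1]
    }
}
```

    Both versions are O(n log n); the TreeMap one simply delegates the ordered lookup to the tree.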

  • 1

    For step 2, do you actually mean "find the rightmost start"?
    For example, when the input is [[1,12],[2,9],[3,10],[13,14],[15,16],[16,17]], for the interval [16,17] the end is 17, and we want to find the rightmost start in the list {1,2,3,13,15,16}, which is 16. Since 16 < 17, we set res[i] = -1.
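    A quick trace of the first post's helper on this input may settle the question. The sketch below copies `binarySearch` as posted and, for brevity, uses `int[][]` pairs instead of the `Interval` class (the class name `TraceLeftmostStart` and the method `findRight` are illustrative only). It suggests the helper does return the leftmost start >= end; 16 only comes back for end 17 because no start is >= 17, so `left` ends at the last index and the caller's `start < end` check turns that into -1.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class TraceLeftmostStart {
    // Copied from the first post: returns the leftmost value >= x,
    // or the last element when every value is < x.
    static int binarySearch(List<Integer> list, int x) {
        int left = 0, right = list.size() - 1;
        while (left < right) {
            int mid = left + (right - left) / 2;
            if (list.get(mid) < x) {
                left = mid + 1;
            } else {
                right = mid;
            }
        }
        return list.get(left);
    }

    // Same steps as the first post, with {start, end} pairs instead of Interval.
    static int[] findRight(int[][] intervals) {
        Map<Integer, Integer> startToIndex = new HashMap<>();
        List<Integer> starts = new ArrayList<>();
        for (int i = 0; i < intervals.length; i++) {
            startToIndex.put(intervals[i][0], i);
            starts.add(intervals[i][0]);
        }
        Collections.sort(starts);
        int[] res = new int[intervals.length];
        for (int i = 0; i < intervals.length; i++) {
            int end = intervals[i][1];
            int start = binarySearch(starts, end);
            res[i] = (start < end) ? -1 : startToIndex.get(start);
        }
        return res;
    }

    public static void main(String[] args) {
        int[][] in = {{1, 12}, {2, 9}, {3, 10}, {13, 14}, {15, 16}, {16, 17}};
        for (int[] iv : in) {
            // Sorted starts are {1, 2, 3, 13, 15, 16}.
            System.out.println("end " + iv[1] + " -> helper returns "
                + binarySearch(Arrays.asList(1, 2, 3, 13, 15, 16), iv[1]));
        }
        System.out.println(Arrays.toString(findRight(in))); // [3, 3, 3, 4, 5, -1]
    }
}
```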


  • 0

    Great solution. I have basically the same; the only difference is this line of code:

    Interval right = FindRight(intervals, interval.end, i + 1); 
    

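    I take advantage of the fact that the binary search can start at the element right after the one being processed (i + 1): the intervals are sorted by start, and, assuming each end is greater than its start, a right interval can only appear later in the sorted order. The overall complexity is still O(n log n); this only shrinks the search range for each later interval.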
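    One way to gauge how much starting at i + 1 saves is to count loop iterations directly. The sketch below is hypothetical: the class `SearchRangeCost` and its synthetic input (starts 0, 2, 4, ... with end = start + 1) are made up for illustration, and it uses a Java lower-bound search rather than the C# `FindRight`. For n = 1000 the restricted searches still take well over half as many iterations as full-range ones, which matches the math: summing log(n - i) over all i gives n log n minus only a term linear in n.

```java
class SearchRangeCost {
    // Counts loop iterations of a lower-bound binary search over sorted[lo, hi).
    static int lowerBoundSteps(int[] sorted, int lo, int hi, int target) {
        int steps = 0;
        while (lo < hi) {
            steps++;
            int mid = lo + (hi - lo) / 2;
            if (sorted[mid] < target) lo = mid + 1;
            else hi = mid;
        }
        return steps;
    }

    // Total iterations for n searches, starting either at index 0 or at i + 1.
    static long totalSteps(int n, boolean startAfterCurrent) {
        int[] starts = new int[n];
        for (int i = 0; i < n; i++) starts[i] = 2 * i; // sorted, distinct starts
        long total = 0;
        for (int i = 0; i < n; i++) {
            int end = starts[i] + 1; // end > start, as assumed above
            int lo = startAfterCurrent ? i + 1 : 0;
            total += lowerBoundSteps(starts, lo, n, end);
        }
        return total;
    }

    public static void main(String[] args) {
        int n = 1000;
        System.out.println("search [0, n):   " + totalSteps(n, false));
        System.out.println("search [i+1, n): " + totalSteps(n, true));
    }
}
```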

    C#

        public int[] FindRightInterval(Interval[] intervals) 
        {
            // capture indexes in hash
            Dictionary<Interval,int> map = new Dictionary<Interval,int>();
            for (int i = 0; i < intervals.Length; i++) map[intervals[i]] = i;
            
            // sort by start val to enable binary search
            Array.Sort(intervals, (a,b) => a.start.CompareTo(b.start));
            
            int[] res = new int[intervals.Length];
            for (int i = 0; i < intervals.Length; i++)
            {
                Interval interval = intervals[i];
                int originalIndex = map[interval];
                Interval right = FindRight(intervals, interval.end, i + 1);
                int rightIndex = right != null ? map[right] : -1;
                
                res[originalIndex] = rightIndex;
            }
            
            return res;
        }
        
        public Interval FindRight(Interval[] intervals, int minStart, int minIndex)
        {
            int left = minIndex;
            int right = intervals.Length - 1;
            Interval best = null;
            
            while (left <= right)
            {
                int mid = (left + right)/2;
                int start = intervals[mid].start;
                
                if (start == minStart)
                {
                    return intervals[mid];
                }
                else if (start > minStart)
                {
                    if (best == null || start < best.start) best = intervals[mid];
                    right = mid - 1;
                }
                else
                {
                    left = mid + 1;
                }
            }
            
            return best;
        }
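    For readers following along in Java, a rough counterpart of the C# approach might look like the sketch below. It is not the poster's code: it sorts an array of indices instead of the intervals themselves (so the input stays untouched and no map keyed by `Interval` is needed), uses a half-open search range, and assumes every interval's `end` is strictly greater than its `start`, which is what makes starting the search at position k + 1 safe. `RightIntervalFromNext` and the stand-in `Interval` class are illustrative names.

```java
import java.util.Arrays;

class RightIntervalFromNext {
    // Stand-in for LeetCode's Interval class (assumed: int start, int end).
    static class Interval {
        int start, end;
        Interval(int start, int end) { this.start = start; this.end = end; }
    }

    static int[] findRightInterval(Interval[] intervals) {
        int n = intervals.length;
        // Sort original indices by start instead of sorting (and mutating) the input.
        Integer[] order = new Integer[n];
        for (int i = 0; i < n; i++) order[i] = i;
        Arrays.sort(order, (a, b) -> Integer.compare(intervals[a].start, intervals[b].start));

        int[] res = new int[n];
        for (int k = 0; k < n; k++) {
            Interval cur = intervals[order[k]];
            // Assumes end > start, so any match lies strictly after position k.
            int lo = k + 1, hi = n; // search [k + 1, n); lo == n means "not found"
            while (lo < hi) {
                int mid = lo + (hi - lo) / 2;
                if (intervals[order[mid]].start < cur.end) lo = mid + 1;
                else hi = mid;
            }
            res[order[k]] = lo < n ? order[lo] : -1;
        }
        return res;
    }

    public static void main(String[] args) {
        Interval[] in = { new Interval(1, 4), new Interval(2, 3), new Interval(3, 4) };
        System.out.println(Arrays.toString(findRightInterval(in))); // [-1, 2, -1]
    }
}
```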
    

  • 0

    Nice! I came up with the same idea, except mine uses more space.

    public int[] findRightInterval(Interval[] intervals) {
        Map<Interval, Integer> indexMap = new LinkedHashMap<>();
        int[] result = new int[intervals.length];
        for(int i = 0; i < intervals.length; i++) {
            indexMap.put(intervals[i], i); // saving positions
        }
        
        Arrays.sort(intervals, new Comparator<Interval>() {
            @Override
            public int compare(Interval a, Interval b) {
                return Integer.compare(a.start, b.start); // avoids int overflow of a.start - b.start
            }
        });
        
        int i = 0;
        for(Interval each : indexMap.keySet()) {
            int targetIndex = binarySearch(intervals, each);
            if(targetIndex < 0) result[i++] = -1;
            else result[i++] = indexMap.get(intervals[targetIndex]);
        }
        return result;
    }
    public int binarySearch(Interval[] intervals, Interval each) {
        int low = 0, high = intervals.length-1;
        while(low < high) {
            int mid = low + (high - low) / 2;
            if(intervals[mid].start < each.end) {
                low = mid + 1;
            }else {
                high = mid;
            }
        }
        if(intervals[low].start < each.end) return -1;
        return low;
    }
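    The `result[i++]` loop works only because `LinkedHashMap` iterates its keys in insertion order, i.e. the original positions, even after `intervals` has been sorted; a plain `HashMap` keyed by `Interval` gives no such ordering guarantee. A small demonstration, using a hypothetical `InsertionOrderDemo` class with a stand-in `Interval` that overrides `toString` for readability:

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;

class InsertionOrderDemo {
    // Stand-in for LeetCode's Interval class (assumed: int start, int end).
    static class Interval {
        int start, end;
        Interval(int start, int end) { this.start = start; this.end = end; }
        @Override public String toString() { return "[" + start + "," + end + "]"; }
    }

    // Lists the map's entries, in iteration order, after the array has been sorted.
    static String keyOrderAfterSort(Interval[] intervals) {
        Map<Interval, Integer> indexMap = new LinkedHashMap<>();
        for (int i = 0; i < intervals.length; i++) indexMap.put(intervals[i], i);
        Arrays.sort(intervals, (a, b) -> Integer.compare(a.start, b.start));
        StringBuilder sb = new StringBuilder();
        for (Map.Entry<Interval, Integer> e : indexMap.entrySet()) {
            sb.append(e.getKey()).append("->").append(e.getValue()).append(' ');
        }
        return sb.toString().trim();
    }

    public static void main(String[] args) {
        Interval[] in = { new Interval(3, 4), new Interval(2, 3), new Interval(1, 2) };
        System.out.println(keyOrderAfterSort(in)); // [3,4]->0 [2,3]->1 [1,2]->2
        System.out.println(Arrays.toString(in));   // [[1,2], [2,3], [3,4]]
    }
}
```

    The map still walks the intervals in their original order while the array is now sorted by start, which is exactly what lets `result[i++]` line up with the input positions.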

  • 0

    Same idea. The question itself hints at the solution:
    "You may assume none of these intervals have the same start point." => each start can serve as a hash key that maps back to its index.
    A question about non-overlapping intervals => it probably needs some kind of sorting.

    class Solution {
        public int[] findRightInterval(Interval[] intervals) {
            if (intervals == null || intervals.length < 1) {
                return new int[0];
            }
            int n = intervals.length;
            int[] out = new int[n];
            Map<Integer, Integer> map = new HashMap<>();
            for (int i = 0; i < n; ++i) {
                map.put(intervals[i].start, i);
            }
            Arrays.sort(intervals, (a, b) -> Integer.compare(a.start, b.start)); // no overflow, unlike a.start - b.start
            for (int i = 0; i < n; ++i) {
                out[map.get(intervals[i].start)] = -1;
                int l = i + 1;
                int r = n - 1;
                while (l < r) {
                    int m = l + (r - l) / 2;
                    if (intervals[m].start < intervals[i].end) {
                        l = m + 1;
                    } else {
                        r = m;
                    }
                }
                if (l < n && intervals[l].start >= intervals[i].end) {
                    out[map.get(intervals[i].start)] = map.get(intervals[l].start);
                }
            }
            return out;
        }
    }
    
