Quickselect using the median of three as the pivot


  • 0
    S

    Using std::nth_element takes 12 ms.
    The solution below takes 13 ms.

    // Finds the k-th largest element with quickselect. The pivot is the median
    // of the first, middle and last elements of the current range, and the
    // partition is three-way so runs of equal values are settled in one pass.
    // Iterative, O(1) extra space; nums is reordered in place.
    class Solution {
        // Swaps the median of nums[beg], nums[mid] and nums[end] (mid being the
        // midpoint of [beg, end]) into nums[beg] so it can serve as the pivot.
        // Safe for ranges of length 1 or 2: there mid == beg and the swap is a
        // no-op.
        void move_median_to_begin(std::vector<int>& nums, int beg, int end) {
            const int mid = beg + (end - beg) / 2;
            const int first = nums[beg];
            const int middle = nums[mid];
            const int last = nums[end];
            if ((first <= middle && middle <= last) || (last <= middle && middle <= first)) {
                std::swap(nums[mid], nums[beg]);  // middle is the median
            } else if ((first <= last && last <= middle) || (middle <= last && last <= first)) {
                std::swap(nums[end], nums[beg]);  // last is the median
            }
            // Otherwise nums[beg] already holds the median.
        }

        // Three-way (Dutch national flag) partition of nums[beg..end] around a
        // median-of-three pivot. On return:
        //   nums[beg..lt-1] < pivot, nums[lt..gt] == pivot, nums[gt+1..end] > pivot
        // and lt <= gt (the band always contains the pivot itself).
        // Keeping every copy of the pivot in one band keeps quickselect linear
        // on duplicate-heavy input; a two-way Lomuto partition shrinks an
        // all-equal range by only one element per pass.
        void partition(std::vector<int>& nums, int beg, int end, int& lt, int& gt) {
            move_median_to_begin(nums, beg, end);
            const int pivot = nums[beg];
            lt = beg;
            gt = end;
            int i = beg;
            while (i <= gt) {
                if (nums[i] < pivot) {
                    std::swap(nums[lt++], nums[i++]);
                } else if (nums[i] > pivot) {
                    // Do not advance i: the element swapped in is still unexamined.
                    std::swap(nums[i], nums[gt--]);
                } else {
                    ++i;
                }
            }
        }

        // Returns the element that belongs at index nums.size() - k in ascending
        // order. It narrows [beg, end] until that index falls inside the pivot
        // band. Requires beg <= nums.size() - k <= end. Written as a loop so the
        // stack depth does not grow with the input size.
        int findKth(std::vector<int>& nums, int beg, int end, int k) {
            const int target = static_cast<int>(nums.size()) - k;
            for (;;) {
                int lt = beg;
                int gt = end;
                partition(nums, beg, end, lt, gt);
                if (target < lt) {
                    end = lt - 1;
                } else if (target > gt) {
                    beg = gt + 1;
                } else {
                    return nums[target];
                }
            }
        }

    public:
        // Returns the k-th largest value in nums (1-based; equal values count
        // separately), reordering nums in place. If k is not in
        // [1, nums.size()], it throws std::out_of_range (via vector::at)
        // instead of reading out of bounds.
        int findKthLargest(std::vector<int>& nums, int k) {
            const int n = static_cast<int>(nums.size());
            if (k < 1 || k > n) {
                return nums.at(nums.size());  // index is always invalid: reports the bad k
            }
            return findKth(nums, 0, n - 1, k);
        }
    };

Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.