20-line, 1 ms, in-place Java code with explanation


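    1. Scan from right to left and find the first pair where a[i] > a[i-1]. Note that a[i:] is non-ascending.
    2. Scan from right to left again and find the first element a[j] with a[j] > a[i-1]. Since a[i:] is non-ascending, a[j] is the smallest value in a[i:] that is greater than a[i-1] (its rightmost occurrence, if that value repeats).
    3. Swap a[i-1] and a[j]. After the swap, a[i:] is still non-ascending.
    4. Reverse a[i:] so that it becomes non-descending.
    If step 1 finds no such pair, the whole array is non-ascending (the last permutation, e.g. [3,2,1]), so reverse the entire array to get the first permutation.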
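    public class Solution {
        public void nextPermutation(int[] nums) {
            // Step 1: skip the non-ascending suffix; nums[p-1] is the pivot.
            int p = nums.length-1;
            while (p>0 && nums[p]<=nums[p-1]) { --p; }
            if (p <= 0) {  // case like [3,2,1] (or empty input): wrap around to the first permutation
                reverse(nums, 0, nums.length-1);
                return;
            }
            // Step 2: find the rightmost element strictly greater than the pivot.
            int q = nums.length-1;
            while (nums[q]<=nums[p-1]) { --q; }
            // Step 3: swap it with the pivot; nums[p..] stays non-ascending.
            int temp = nums[p-1]; nums[p-1] = nums[q]; nums[q] = temp;
            // Step 4: reverse nums[p..] into non-descending order.
            reverse(nums, p, nums.length-1);
        }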
        
        // Reverses a[from..to] in place; both ends are inclusive.
        private void reverse(int[] a, int from, int to) {
            for (; from < to; ++from, --to) {
                int temp = a[from];
                a[from] = a[to];
                a[to] = temp;
            }
        }
    }
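To see the four steps in action, here is a minimal, self-contained sketch. The class name `NextPermutationDemo` and the `cycle` helper are hypothetical; `nextPermutation` repeats the linear-scan logic from the post, except that the [3,2,1] case is folded into the final reversal (when no pivot exists, the reversal simply starts at index 0). `cycle` applies it repeatedly until the array returns to its starting order. For [1, 1, 2] it prints [1, 1, 2], [1, 2, 1], [2, 1, 1], which shows that duplicates are handled and that the last permutation wraps around to the first.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical demo class; nextPermutation follows the linear-scan steps from the post.
class NextPermutationDemo {

    static void nextPermutation(int[] nums) {
        int p = nums.length - 1;
        // Step 1: walk left past the non-ascending suffix; the pivot is nums[p - 1].
        while (p > 0 && nums[p] <= nums[p - 1]) { --p; }
        if (p > 0) {
            // Step 2: rightmost element strictly greater than the pivot.
            int q = nums.length - 1;
            while (nums[q] <= nums[p - 1]) { --q; }
            // Step 3: swap; the suffix nums[p..] stays non-ascending.
            int tmp = nums[p - 1]; nums[p - 1] = nums[q]; nums[q] = tmp;
        }
        // Step 4: reverse nums[p..]; when p == 0 (last permutation) this reverses everything.
        for (int i = p, j = nums.length - 1; i < j; ++i, --j) {
            int tmp = nums[i]; nums[i] = nums[j]; nums[j] = tmp;
        }
    }

    // Applies nextPermutation until the array is back in its starting order,
    // recording each arrangement visited as a string.
    static List<String> cycle(int[] start) {
        int[] a = start.clone();
        List<String> seen = new ArrayList<>();
        do {
            seen.add(Arrays.toString(a));
            nextPermutation(a);
        } while (!Arrays.equals(a, start));
        return seen;
    }

    public static void main(String[] args) {
        for (String s : cycle(new int[]{1, 1, 2})) {
            System.out.println(s);
        }
    }
}
```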


    Great. I did almost the same thing, but I used binary search for the second scan. However, I didn't realize that I could simply reverse the non-ascending part, so I called Arrays.sort instead, which raised the complexity to O(n log n). Thanks to your solution, I improved my run time from 3 ms to 2 ms (though, for some reason, it is still slower than your linear-search version). Here is the code:

    public void nextPermutation(int[] nums) {
        if (nums.length <= 1) {
            return;
        }
        // Step 1: find the pivot, the rightmost i with nums[i] < nums[i + 1].
        int swap = -1;
        for (int i = nums.length - 2; i >= 0; --i) {
            if (nums[i] < nums[i + 1]) {
                swap = i;
                break;
            }
        }
        if (swap >= 0) {
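            // Step 2: binary search the non-ascending suffix nums[swap+1..] for the
            // rightmost element greater than the pivot value (named min here).
            int min = nums[swap];
            int l = swap + 1, r = nums.length - 1;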
            while (l <= r) {
                int m = (l + r) >>> 1;
                if (min >= nums[m]) { // "min >=" means "min + 1/2 >"
                    r = m - 1;
                } else {
                    l = m + 1;
                }
            }
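            // r is now the rightmost index with nums[r] > min: swap, then reverse the suffix.
            nums[swap] = nums[r];
            nums[r] = min;
            reverse(nums, swap + 1, nums.length);
        } else {
            reverse(nums, 0, nums.length); // whole array is non-ascending (last permutation): reverse to get the first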
        }
    }
    
    // Reverses nums[start, end) in place; end is exclusive.
    private void reverse(int[] nums, int start, int end) {
        for (int i = start, j = end - 1; i < j; ++i, --j) {
            int tmp = nums[i];
            nums[i] = nums[j];
            nums[j] = tmp;
        }
    }
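Why reversing is enough: after the swap, the suffix is non-ascending, so reversing it yields exactly the ascending order that Arrays.sort would produce, using about k/2 swaps instead of an O(k log k) sort for a suffix of length k. Below is a small check of that claim; the class `SuffixReverseCheck` and its helpers are hypothetical names for illustration. The example input is [1, 3, 5, 4, 2] right after step 3 (pivot 3 swapped with 4), i.e. [1, 4, 5, 3, 2] with the non-ascending suffix starting at index 2; `main` prints `true`.

```java
import java.util.Arrays;

// Hypothetical helper class: shows that on a non-ascending range an in-place
// reversal gives the same result as Arrays.sort on that range.
class SuffixReverseCheck {

    // Reverses nums[start, end) in place.
    static void reverse(int[] nums, int start, int end) {
        for (int i = start, j = end - 1; i < j; ++i, --j) {
            int tmp = nums[i]; nums[i] = nums[j]; nums[j] = tmp;
        }
    }

    // True if nums[start, end) is non-ascending (each element >= the next).
    static boolean isNonAscending(int[] nums, int start, int end) {
        for (int i = start + 1; i < end; ++i) {
            if (nums[i] > nums[i - 1]) return false;
        }
        return true;
    }

    // Compares reversal of nums[start..] against Arrays.sort on the same range.
    static boolean reverseMatchesSort(int[] nums, int start) {
        if (!isNonAscending(nums, start, nums.length)) {
            throw new IllegalArgumentException("suffix must be non-ascending");
        }
        int[] reversed = nums.clone();
        int[] sorted = nums.clone();
        reverse(reversed, start, reversed.length);   // linear in the suffix length
        Arrays.sort(sorted, start, sorted.length);   // O(k log k) for a suffix of length k
        return Arrays.equals(reversed, sorted);
    }

    public static void main(String[] args) {
        System.out.println(reverseMatchesSort(new int[]{1, 4, 5, 3, 2}, 2));
    }
}
```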
    


    Cool. Maybe it's because the test cases aren't large enough...
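A related observation: binary search only shortens step 2. Step 1 already scans the whole non-ascending suffix (length k), and step 4 performs k/2 swaps, so both versions stay O(k) and the difference is at most a constant factor, which millisecond timings on short arrays may not resolve. The sketch below counts comparisons and swaps per step instead of timing anything; the class `StepCostCount` is hypothetical, and this is an illustration rather than a benchmark. For [4, 5, 3, 2, 1], the worst case for the linear step 2 (the only larger element is the first one in the suffix), it prints [4, 4, 2, 2].

```java
import java.util.Arrays;

// Hypothetical instrumentation class: counts the work each step of the algorithm
// does on one input. The input array is cloned, not modified.
class StepCostCount {

    // Returns {step1Comparisons, step2LinearComparisons, step2BinaryComparisons, step4Swaps}.
    static int[] count(int[] nums) {
        int[] a = nums.clone();
        int n = a.length;

        // Step 1: one comparison per iteration of the right-to-left scan.
        int step1 = 0;
        int p = n - 1;
        while (p > 0) {
            step1++;
            if (a[p] > a[p - 1]) break;  // found the pair; the pivot is a[p - 1]
            --p;
        }
        if (p <= 0) {
            // Last permutation (or length < 2): no step 2, step 4 reverses the whole array.
            return new int[]{step1, 0, 0, n / 2};
        }
        int k = n - p;  // length of the non-ascending suffix a[p..n-1]

        // Step 2, linear scan as in the first post.
        int step2Linear = 0;
        int q = n - 1;
        while (true) {
            step2Linear++;
            if (a[q] > a[p - 1]) break;
            --q;
        }

        // Step 2, binary search as in the reply; it ends with r == q.
        int step2Binary = 0;
        int l = p, r = n - 1;
        while (l <= r) {
            int m = (l + r) >>> 1;
            step2Binary++;
            if (a[p - 1] >= a[m]) r = m - 1; else l = m + 1;
        }

        // Step 4: reversing a suffix of length k takes k / 2 swaps.
        return new int[]{step1, step2Linear, step2Binary, k / 2};
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(count(new int[]{4, 5, 3, 2, 1})));
    }
}
```

Even in this worst case, step 1 and step 4 already cost as much as the linear step 2, so replacing it with binary search cannot change the overall order of growth.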

