O(n log^2 n) Java 2 ms solution using a Fenwick tree (just for fun)



    OK, I admit that this is pointless for such small inputs, and for large inputs our main trouble would be computing the factorial anyway. Still, the same technique can be used in other problems where you need to keep a list of items, delete items by index, and perform relatively fast searches.

    The idea is that we keep an array of indexes alongside the digits. Or rather, let me use letters to avoid confusion between indexes and digits. Imagine we need to compute the 3rd permutation of "abcd", which is "acbd". Initially, our array looks like this (1-based indexes):

    1 2 3 4
    a b c d
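    For comparison, here is a minimal sketch of the usual approach that the Fenwick tree is meant to speed up: keep the remaining letters in a list, pick the one at 0-based position (k - 1) / (n - 1)!, and remove it by index. Each removal costs O(n), so the whole thing is O(n^2). The names ListBaseline and kthPermutation are hypothetical, and the sketch assumes the input letters are distinct and sorted.

```java
import java.util.ArrayList;
import java.util.List;

// Plain O(n^2) baseline (not the Fenwick approach): remaining letters live in a
// list and the chosen one is removed by index. Names are hypothetical.
class ListBaseline {
    // k-th (1-based) lexicographic permutation of s, assuming s is sorted,
    // has distinct letters, and 1 <= k <= s.length()!.
    static String kthPermutation(String s, int k) {
        int n = s.length();
        List<Character> remaining = new ArrayList<>();
        for (char c : s.toCharArray()) {
            remaining.add(c);
        }
        int f = 1; // (n - 1)!
        for (int i = 2; i <= n - 1; ++i) {
            f *= i;
        }
        StringBuilder sb = new StringBuilder();
        int kk = k - 1; // 0-based rank
        for (int pos = 0; pos < n; ++pos) {
            int idx = kk / f;                 // 0-based position among remaining letters
            sb.append(remaining.remove(idx)); // O(n) removal
            kk %= f;
            if (pos < n - 1) {
                f /= (n - 1 - pos);           // (n - 1 - pos)! -> (n - 2 - pos)!
            }
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        System.out.println(kthPermutation("abcd", 3)); // acbd
    }
}
```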
    

    (n - 1)! is 6, and (k - 1) / 6 = 0, so we need index 0 + 1 = 1, the first letter (a). We take it and “delete” it by decrementing the indexes to its right:

    1 1 2 3
    a b c d
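    On a plain array, the deletion step can be sketched like this (hypothetical names; 1-based positions with slot 0 unused). Every deletion touches all indexes to the right, so it is O(n) per deletion, and that loop is exactly what the Fenwick tree replaces.

```java
import java.util.Arrays;

// Illustration only: deletion by "decrementing the indexes to the right" on a
// plain array, O(n) per deletion. Positions are 1-based; slot 0 is unused.
// Names are hypothetical.
class ShiftDeletion {
    // Initial indexes 1..n stored at positions 1..n.
    static int[] initial(int n) {
        int[] index = new int[n + 1];
        for (int m = 1; m <= n; ++m) {
            index[m] = m;
        }
        return index;
    }

    // Marks position p as deleted by decrementing every index to its right.
    static void delete(int[] index, int p) {
        for (int m = p + 1; m < index.length; ++m) {
            --index[m];
        }
    }

    public static void main(String[] args) {
        int[] index = initial(4); // a b c d -> 1 2 3 4
        delete(index, 1);         // delete a
        System.out.println(Arrays.toString(Arrays.copyOfRange(index, 1, 5))); // [1, 1, 2, 3]
    }
}
```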
    

    Now we update k = (k - 1) % 6 + 1 = 3 and compute (k - 1) / (n - 2)! = 2 / 2 = 1, so we need index 2, the second remaining letter (c). Delete it and we get

    1 1 2 2
    a b c d
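    The quotients in this walk-through are the factorial-number-system digits of k - 1, each plus one. A minimal sketch that computes the 1-based target index for every step (the name FactorialTargets.targets is hypothetical); like kk in the solution, it tracks k - 1 directly, which avoids the repeated -1/+1 adjustments. For n = 4 and k = 3 it yields 1, 2, 1, 1.

```java
import java.util.Arrays;

// Sketch: the 1-based target index used at each step of the walk-through,
// i.e. the factorial-number-system digits of k - 1 plus one. Name is hypothetical.
class FactorialTargets {
    static int[] targets(int n, int k) {
        int f = 1; // (n - 1)!
        for (int i = 2; i <= n - 1; ++i) {
            f *= i;
        }
        int[] t = new int[n];
        int kk = k - 1; // same role as kk in the solution
        for (int pos = 0; pos < n; ++pos) {
            t[pos] = kk / f + 1; // which remaining letter to take (1-based)
            kk %= f;
            if (pos < n - 1) {
                f /= (n - 1 - pos); // (n - 1 - pos)! -> (n - 2 - pos)!
            }
        }
        return t;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(targets(4, 3))); // [1, 2, 1, 1]
    }
}
```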
    

    Now to the interesting part. k = (k - 1) % (n - 2)! + 1 = 1, so we need index 1 again. But two letters now have index 1 (a and b), so we should pick the last one: in a run of equal indexes, only the last letter is not deleted. So when performing the binary search for the next letter, we look for the “insertion point” of 1, which falls between the indexes 1 and 2 (between b and c). The letter we're looking for is just to the left of that insertion point, namely b. Deleting b turns the indexes into 1 1 1 1, and the same search then picks d, giving "acbd".
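    The “insertion point” search can be tried on a plain array of indexes before any Fenwick tree is involved. The helper below (hypothetical name, 1-based positions with slot 0 unused) returns the last position whose index is at most the target, which is the surviving letter in a run of equal indexes; its comparison mirrors the `index >= i` test in the solution.

```java
// Sketch of the "insertion point" search on a plain, non-decreasing index array
// (1-based, slot 0 unused). Returns the last position whose index is <= target,
// i.e. the surviving letter in a run of equal indexes. Name is hypothetical.
class LastAtMost {
    static int lastAtMost(int[] index, int target) {
        int l = 1, r = index.length - 1;
        while (l <= r) {
            int m = (l + r) >>> 1;
            if (target >= index[m]) {
                l = m + 1; // position m is at or before the insertion point
            } else {
                r = m - 1;
            }
        }
        return l - 1; // l is the insertion point; the answer is just to its left
    }

    public static void main(String[] args) {
        int[] index = {0, 1, 1, 2, 2}; // a b c d after deleting a and c
        int p = lastAtMost(index, 1);
        System.out.println(p + " " + "abcd".charAt(p - 1)); // 2 b
    }
}
```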

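    To perform fast deletion, we need to update all indexes to the right of a position at once. A Fenwick tree handles this well if it stores the changes to the indexes: decrementing every index to the right of p becomes a single O(log n) point update at p + 1, and the current index of a position is a prefix sum. The catch is that each lookup is also O(log n), so the binary search becomes O(log^2 n), and the whole algorithm O(n log^2 n).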
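    A minimal sketch of that idea, assuming 1-based positions: the tree holds only the changes, so deleting position p is one point update at p + 1, and index(m) = m + prefixSum(m). The class and method names are hypothetical; the update and query loops follow the same pattern as the ones in the solution.

```java
// Sketch, assuming 1-based positions: a Fenwick tree over the *changes* to the
// indexes, so "decrement every index right of p" is a single point update at
// p + 1. Class and method names are hypothetical.
class ShiftFenwick {
    private final int[] tree; // tree[j - 1] holds Fenwick node j; all changes start at 0

    ShiftFenwick(int n) {
        tree = new int[n];
    }

    // Adds delta to the change at position pos, O(log n).
    void add(int pos, int delta) {
        for (int j = pos; j <= tree.length; j += j & -j) {
            tree[j - 1] += delta;
        }
    }

    // Sum of the changes at positions 1..m, O(log n).
    int prefixSum(int m) {
        int sum = 0;
        for (int j = m; j > 0; j -= j & -j) {
            sum += tree[j - 1];
        }
        return sum;
    }

    // Current index of position m: its original index plus all changes so far.
    int index(int m) {
        return m + prefixSum(m);
    }

    // Deletes position p by shifting every index to its right down by one.
    void delete(int p) {
        add(p + 1, -1); // the loop does nothing when p is the last position
    }

    public static void main(String[] args) {
        ShiftFenwick f = new ShiftFenwick(4); // a b c d
        f.delete(1); // a
        f.delete(3); // c
        StringBuilder sb = new StringBuilder();
        for (int m = 1; m <= 4; ++m) {
            sb.append(f.index(m)).append(' ');
        }
        System.out.println(sb.toString().trim()); // 1 1 2 2
    }
}
```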

    Here is the code in Java. To avoid populating the initial tree with all 1s (so that the mth prefix sum equals the index m), I chose to initialize it with zeros; then, when summing, I start the sum not at zero but at the “supposed” index m.

    class Solution {
        public String getPermutation(int n, int k) {
            int n1f = 1; // (n - 1)!
            for (int i = 2; i <= n - 1; ++i) {
                n1f *= i;
            }
            int[] bit = new int[n]; // Fenwick tree of index changes; bit[j - 1] is node j, all zeros initially
            char[] buf = new char[n];
            for (int kk = k - 1, pos = 0; ; kk %= n1f, n1f /= (n - pos)) {
                int index = kk / n1f + 1; // 1-based index of the digit we need among the remaining ones
                int l = 1, r = n;
                while (l <= r) { // do binary search for the digit with the needed index
                    int m = (l + r) >>> 1, i = m; // start the sum at m, the "supposed" index
                    for (int j = m; j > 0; j -= j & -j) {
                        i += bit[j - 1]; // compute the index of mth digit
                    }
                    if (index >= i) { // greater than or equal to because we need to skip deleted digits
                        l = m + 1;
                    } else {
                        r = m - 1;
                    }
                }
                buf[pos++] = (char) ('0' + l - 1); // l is the insertion point; the digit is just to its left
                if (pos == n) {
                    break; // need to break here to avoid div by zero in the `for` statement
                }
                for (int j = l; j <= n; j += j & -j) {
                    --bit[j - 1]; // delete the digit by shifting the indexes to its right
                }
            }
            return new String(buf);
        }
    }
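    To see why starting the sum at m is equivalent to a tree filled with 1s, here is a small check (hypothetical names; the deleted positions are arbitrary). For an all-1s array, Fenwick node j covers j & -j positions, so it simply holds j & -j. Applying the same deletions to both trees gives the same index for every position, because both receive identical changes and differ only in where the sum starts.

```java
// Compares two ways of getting "index = prefix sum" in a Fenwick tree
// (illustration only; names are hypothetical):
//   (1) every position holds 1, so node j stores (j & -j), and sums start at 0;
//   (2) every position holds 0, as in the solution, and the sum for m starts at m.
class InitTrick {
    static int[] onesTree(int n) {
        int[] t = new int[n];
        for (int j = 1; j <= n; ++j) {
            t[j - 1] = j & -j; // node j covers (j & -j) positions, each holding 1
        }
        return t;
    }

    static void add(int[] t, int pos, int delta) {
        for (int j = pos; j <= t.length; j += j & -j) {
            t[j - 1] += delta;
        }
    }

    static int sum(int[] t, int m, int start) {
        int s = start;
        for (int j = m; j > 0; j -= j & -j) {
            s += t[j - 1];
        }
        return s;
    }

    // Deletes the given positions in both trees and checks that every index agrees.
    static boolean agree(int n, int[] deleted) {
        int[] ones = onesTree(n);
        int[] zeros = new int[n];
        for (int p : deleted) {
            add(ones, p + 1, -1); // shift everything right of p down by one
            add(zeros, p + 1, -1);
        }
        for (int m = 1; m <= n; ++m) {
            if (sum(ones, m, 0) != sum(zeros, m, m)) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        System.out.println(agree(8, new int[] {2, 5, 6})); // true
        System.out.println(agree(8, new int[] {}));        // true
    }
}
```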

  • "when summing, initialize the sum not with zero": could you please explain how to implement this? I am still confused about this part. Thanks.


  • I got it. It is really a good solution, but a little complex!

