Java O(N^2) solution (138 ms runtime) with explanation


  • 0
    C
    /**
     * Returns the number of distinct non-empty palindromic subsequences of {@code S}, modulo
     * 1_000_000_007. An empty string yields 0.
     *
     * <p>First builds three per-letter lookup tables over {@code S}, then evaluates the memoized
     * recurrence {@code f} on the whole range {@code S[0..len-1]}.
     *
     * <p>The tables have one row per lowercase letter {@code 'a'..'z'}; {@code f} indexes them by
     * {@code S.charAt(j) - 'a'}. Characters outside {@code 'a'..'z'} are not tracked, so they are
     * not supported.
     *
     * @param S the input string, expected to consist of lowercase letters {@code 'a'..'z'}
     * @return the count of distinct non-empty palindromic subsequences, modulo 1_000_000_007
     */
    public int countPalindromicSubsequences(String S) {
        // One row per letter 'a'..'z'. The recurrence is alphabet-independent, so there is no need
        // to restrict it to 'a'..'d'.
        final int alphabet = 26;
        int len = S.length();

        // countFrom[c][p]: occurrences of letter c in S[p..]. Column len is the empty suffix (0).
        int[][] countFrom = new int[alphabet][len + 1];
        // firstFrom[c][p]: index of the first letter c in S[p..], or -1 if there is none.
        int[][] firstFrom = new int[alphabet][len + 1];
        for (int c = 0; c < alphabet; c++) {
            firstFrom[c][len] = -1;  // the empty suffix contains no letters
        }
        for (int p = len - 1; p >= 0; p--) {
            // Start from the values for the shorter suffix S[p+1..] ...
            for (int c = 0; c < alphabet; c++) {
                countFrom[c][p] = countFrom[c][p + 1];
                firstFrom[c][p] = firstFrom[c][p + 1];
            }
            // ... then account for S[p] itself, if it is a tracked letter.
            int here = S.charAt(p) - 'a';
            if (here >= 0 && here < alphabet) {
                countFrom[here][p]++;
                firstFrom[here][p] = p;
            }
        }

        // lastFrom[c][p]: index of the last letter c strictly before position p, or -1 if none.
        int[][] lastFrom = new int[alphabet][len + 1];
        for (int c = 0; c < alphabet; c++) {
            lastFrom[c][0] = -1;  // nothing precedes S[0]
        }
        for (int p = 0; p < len; p++) {
            for (int c = 0; c < alphabet; c++) {
                lastFrom[c][p + 1] = lastFrom[c][p];
            }
            int here = S.charAt(p) - 'a';
            if (here >= 0 && here < alphabet) {
                lastFrom[here][p + 1] = p;  // S[p] is now the latest occurrence before p + 1
            }
        }

        int[][] dp = new int[len][len];  // memo table for f, filled lazily by f itself
        return f(S, 0, len - 1, countFrom, firstFrom, lastFrom, dp);
    }
    
    /**
     * Counts the distinct non-empty palindromic subsequences of {@code S[i..j]} (both ends
     * inclusive), modulo 1_000_000_007.
     *
     * <p>The answer for {@code S[i..j]} is built from the answer for {@code S[i..j-1]} by adding
     * the character {@code c = S[j]}. Let {@code k} be the number of occurrences of {@code c} in
     * {@code S[i..j]}. Let {@code j1} be its first occurrence at or after {@code i}, and
     * {@code j2} its last occurrence strictly before {@code j}. The new palindromes are the ones
     * that need {@code S[j]} as their last character:
     * <ul>
     *   <li>{@code k == 1}: only the one-character palindrome "c" is new, so
     *       {@code f(i, j) = f(i, j-1) + 1}.
     *   <li>{@code k == 2}: "cc" is new. So is {@code c + P + c} for every palindrome {@code P}
     *       of {@code S[j1+1..j-1]}. Hence {@code f(i, j) = f(i, j-1) + f(j1+1, j-1) + 1}.
     *   <li>{@code k > 2}: "c" and "cc" already exist. Every {@code c + P + c} with {@code P} in
     *       {@code S[j1+1..j2-1]} was already formed by the pair {@code (j1, j2)}, so those are
     *       subtracted: {@code f(i, j) = f(i, j-1) + f(j1+1, j-1) - f(j1+1, j2-1)}.
     * </ul>
     *
     * @param S the input string
     * @param i inclusive start index
     * @param j inclusive end index; {@code i > j} denotes the empty range (result 0)
     * @param countFrom {@code countFrom[c][p]} = occurrences of letter {@code c} in {@code S[p..]}
     * @param firstFrom {@code firstFrom[c][p]} = index of the first letter {@code c} at or after
     *     {@code p}
     * @param lastFrom {@code lastFrom[c][p]} = index of the last letter {@code c} strictly before
     *     {@code p}
     * @param dp memo table, zero-initialised by the caller. A filled entry holds
     *     {@code answer + 1}, so 0 unambiguously means "not yet computed" even when the answer
     *     itself is 0 modulo 1_000_000_007.
     * @return the count for {@code S[i..j]}, modulo 1_000_000_007
     */
    int f(String S, int i, int j, int[][] countFrom, int[][] firstFrom, int[][] lastFrom, int[][] dp) {
        if (i > j) return 0;  // empty range
        if (i == j) return 1;  // a single character is one palindrome
        if (dp[i][j] != 0) return dp[i][j] - 1;  // memo hit; entries are stored offset by one
        final int MOD = 1_000_000_007;
        int c = S.charAt(j) - 'a';
        int occurrences = countFrom[c][i] - countFrom[c][j + 1];  // count of S[j] in S[i..j]

        // Every case starts from the palindromes that do not use S[j] at all.
        long sum = f(S, i, j - 1, countFrom, firstFrom, lastFrom, dp);
        if (occurrences == 1) {
            sum += 1;  // S[j] is unique in the range: the one-character palindrome S[j] is new
        } else {
            int j1 = firstFrom[c][i];  // first occurrence of S[j] in S[i..j]
            // Each palindrome strictly between j1 and j, wrapped by S[j1] ... S[j].
            sum += f(S, j1 + 1, j - 1, countFrom, firstFrom, lastFrom, dp);
            if (occurrences == 2) {
                sum += 1;  // the two-character palindrome S[j1] S[j]
            } else {
                int j2 = lastFrom[c][j];  // last occurrence of S[j] before j (j1 < j2 < j)
                // Wrappings of palindromes inside (j1, j2) were already produced by S[j1] ... S[j2].
                sum -= f(S, j1 + 1, j2 - 1, countFrom, firstFrom, lastFrom, dp);
            }
        }

        // The subtraction of residues can go negative; normalise into [0, MOD).
        sum %= MOD;
        if (sum < 0) sum += MOD;
        int result = (int) sum;
        dp[i][j] = result + 1;
        return result;
    }

Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.