# Java O(N^2) 138ms solution with explanation

```java
// Assumes S contains only the characters 'a'..'d' (the tables below have 4 rows).
public int countPalindromicSubsequences(String S) {
    int len = S.length();
    // countFrom[c][i]: number of occurrences of char c in S[i:]
    int[][] countFrom = new int[4][len + 1];
    // firstFrom[c][i]: index of the first c in S[i:], or -1 if there is none
    int[][] firstFrom = new int[4][len + 1];
    for (int j = 0; j < 4; j++) {
        firstFrom[j][len] = -1;  // S[len:] is empty, so nothing can be found
    }
    for (int i = len - 1; i >= 0; i--) {
        char c = S.charAt(i);
        for (int j = 0; j < 4; j++) {
            if (c - 'a' == j) {  // S[i] is char j
                firstFrom[j][i] = i;
                countFrom[j][i] = countFrom[j][i + 1] + 1;
            } else {
                firstFrom[j][i] = firstFrom[j][i + 1];
                countFrom[j][i] = countFrom[j][i + 1];
            }
        }
    }
    // lastFrom[c][j]: index of the last c in S[:j] (strictly before S[j]), or -1
    int[][] lastFrom = new int[4][len + 1];
    for (int j = 0; j < 4; j++) {
        lastFrom[j][0] = -1;  // there is no char before S[0]
    }
    for (int i = 0; i < len; i++) {
        char c = S.charAt(i);
        for (int j = 0; j < 4; j++) {
            if (c - 'a' == j) {  // found char j at position i
                lastFrom[j][i + 1] = i;  // the last char j before position i+1 is at i
            } else {
                lastFrom[j][i + 1] = lastFrom[j][i];
            }
        }
    }
    int[][] dp = new int[len][len];  // dp[i][j] caches f(i, j); 0 means "not computed yet"
    return f(S, 0, len - 1, countFrom, firstFrom, lastFrom, dp);
}

int f(String S, int i, int j, int[][] countFrom, int[][] firstFrom, int[][] lastFrom, int[][] dp) {
    /* f(i, j) is the number of distinct palindromic subsequences of S[i:j+1]
       (i inclusive, j+1 exclusive).
       Given f(i, j-1), the answer for S[i:j], we extend the string with S[j].
       Let countSj be the number of occurrences of S[j] in S[i:j+1].
       - countSj == 1: S[j] is unique, so f(i,j) = f(i,j-1) + 1.
         The "+1" is the one-character palindrome S[j].
       - countSj == 2: S[i:j+1] looks like S[i]...S[j1]...S[j] with S[j1] == S[j], so
         f(i,j) = f(i,j-1) + f(j1+1, j-1) + 1.
         f(j1+1, j-1) counts the palindromes in S[j1+1:j], each of which can be
         wrapped in S[j1] and S[j]. The "+1" is the two-character palindrome S[j1]S[j].
       - countSj > 2: S[i:j+1] looks like S[i]...S[j1]...S[j2]...S[j] with
         S[j1] == S[j2] == S[j], where j1 is the first index of that char in S[i:j+1]
         and j2 is its last index in S[i:j]. Then
         f(i,j) = f(i,j-1) + f(j1+1, j-1) - f(j1+1, j2-1).
         f(j1+1, j-1) counts the palindromes in S[j1+1:j] that can be wrapped in
         S[j1] and S[j], but those of the form S[j1] + (palindrome in S[j1+1:j2]) + S[j2]
         are already included in f(i,j-1), so f(j1+1, j2-1) is subtracted.
    */
    if (i == j) return 1;
    if (i > j) return 0;
    if (dp[i][j] > 0) return dp[i][j];
    int MOD = 1000000007;
    char sj = S.charAt(j);
    int countSj = countFrom[sj - 'a'][i] - countFrom[sj - 'a'][j + 1];  // count of sj in S[i:j+1]
    int sum = 0;
    if (countSj == 1) {
        // sj is unique in S[i:j+1]; the +1 is the one-character palindrome
        sum = f(S, i, j - 1, countFrom, firstFrom, lastFrom, dp) + 1;
    } else if (countSj == 2) {
        // i ... j1 ... j
        // every palindrome in S[j1+1:j] can be wrapped in sj on both sides to form a new one
        int j1 = firstFrom[sj - 'a'][i];  // index of the first sj at or after i
        sum = f(S, i, j - 1, countFrom, firstFrom, lastFrom, dp)
            + f(S, j1 + 1, j - 1, countFrom, firstFrom, lastFrom, dp) + 1;
        // the +1 is the two-character palindrome sj sj
    } else {
        // i .X. j1 ... *Y* ... j2 .Z. j
        // j1_Y_j2 is part of j1_Y_Z_j
        int j1 = firstFrom[sj - 'a'][i];  // index of the first sj at or after i
        int j2 = lastFrom[sj - 'a'][j];   // index of the last sj before j
        sum = f(S, i, j - 1, countFrom, firstFrom, lastFrom, dp)
            + f(S, j1 + 1, j - 1, countFrom, firstFrom, lastFrom, dp)
            - f(S, j1 + 1, j2 - 1, countFrom, firstFrom, lastFrom, dp);
    }
    if (sum < 0) sum += MOD;  // the subtraction above can go negative after earlier reductions
    sum %= MOD;
    dp[i][j] = sum;
    return sum;
}
```
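
For reference, the three cases described in the comment inside `f` can be written as a single recurrence. This is only a restatement under the post's own definitions: $c$ is the number of occurrences of $S[j]$ in $S[i..j]$, $j_1$ is the first index of that character at or after $i$, and $j_2$ is its last index before $j$. Results are reduced modulo $10^9+7$.

$$
f(i,j) =
\begin{cases}
0 & \text{if } i > j \\
1 & \text{if } i = j \\
f(i,j-1) + 1 & \text{if } c = 1 \\
f(i,j-1) + f(j_1+1,\, j-1) + 1 & \text{if } c = 2 \\
f(i,j-1) + f(j_1+1,\, j-1) - f(j_1+1,\, j_2-1) & \text{if } c > 2
\end{cases}
$$

A small trace on `S = "bccb"` (0-based indices):

$$
\begin{aligned}
f(0,1) &= f(0,0) + 1 = 2 \\
f(0,2) &= f(0,1) + f(2,1) + 1 = 2 + 0 + 1 = 3 \\
f(1,2) &= f(1,1) + f(2,1) + 1 = 1 + 0 + 1 = 2 \\
f(0,3) &= f(0,2) + f(1,2) + 1 = 3 + 2 + 1 = 6
\end{aligned}
$$

The six palindromes are `b`, `c`, `bb`, `cc`, `bcb` and `bccb`.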

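A quick way to sanity-check the DP on short inputs is to enumerate every subsequence directly. The sketch below is not part of the original solution: `PalindromeBruteForce`, `countDistinct` and `isPalindrome` are hypothetical names, and the enumeration is exponential in the string length, so it is only meant for short strings (roughly 20 characters or fewer).

```java
import java.util.HashSet;
import java.util.Set;

// Hypothetical cross-check, not from the original post: counts distinct
// non-empty palindromic subsequences by trying every subset of positions.
class PalindromeBruteForce {

    static int countDistinct(String s) {
        int n = s.length();
        Set<String> seen = new HashSet<>();
        // Each bit mask selects one non-empty subset of positions of s.
        for (int mask = 1; mask < (1 << n); mask++) {
            StringBuilder sb = new StringBuilder();
            for (int k = 0; k < n; k++) {
                if ((mask & (1 << k)) != 0) {
                    sb.append(s.charAt(k));
                }
            }
            String candidate = sb.toString();
            if (isPalindrome(candidate)) {
                seen.add(candidate);  // the set removes duplicates
            }
        }
        return seen.size();
    }

    static boolean isPalindrome(String t) {
        for (int l = 0, r = t.length() - 1; l < r; l++, r--) {
            if (t.charAt(l) != t.charAt(r)) {
                return false;
            }
        }
        return true;
    }
}
```

For `"bccb"`, `PalindromeBruteForce.countDistinct` returns 6 (`b`, `c`, `bb`, `cc`, `bcb`, `bccb`), which is the value `countPalindromicSubsequences("bccb")` should also produce. Comparing the two on random short strings over `a`..`d` is a cheap way to catch an off-by-one in `j1` or `j2`.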