N^2 DP Python with Explanation


  •

    Let dp(i, j) be the answer for the string T = S[i:j+1], counting the empty sequence. Then dp(i, j) is 1 for the empty sequence, plus the number of distinct characters in T, plus dp(next('a', i) + 1, prev('a', j) - 1) representing palindromes of the form "a_a" where _ is zero or more characters (this term applies only when next('a', i) < prev('a', j)), plus dp(next('b', i) + 1, prev('b', j) - 1) representing "b_b", and so on.

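    Here, next('a', i) is the first index at or after i where S[next('a', i)] == 'a', and prev('a', j) is the last index at or before j where S[prev('a', j)] == 'a'; similarly for the other letters. Using this outermost pair of a's never loses or double-counts anything: any palindrome "a_a" inside T can be built from those two a's, because the range strictly between them contains every other possible inner range, and different inner palindromes give different outer ones.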
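Written as a single formula, the recurrence reads as follows. The shorthand n_x(i) = next(x, i) and p_x(j) = prev(x, j) is introduced here only for brevity; it is not notation from the post.

```latex
dp(i, j) =
\begin{cases}
1, & i > j,\\[4pt]
1 + \bigl|\{x : x \text{ occurs in } S[i..j]\}\bigr|
  + \displaystyle\sum_{x \,:\, n_x(i) < p_x(j)} dp\bigl(n_x(i) + 1,\ p_x(j) - 1\bigr), & i \le j,
\end{cases}
\qquad
\text{answer} = \bigl(dp(0, N-1) - 1\bigr) \bmod (10^9 + 7).
```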

    class Solution(object):
        def countPalindromicSubsequences(self, S):
            N = len(S)
            A = [ord(c) - ord('a') for c in S]  # letters 'a'..'d' -> 0..3
            prv = [None] * N  # prv[i][x]: last index <= i holding letter x
            nxt = [None] * N  # nxt[i][x]: first index >= i holding letter x

            last = [None] * 4
            for i in range(N):
                last[A[i]] = i
                prv[i] = tuple(last)

            last = [None] * 4
            for i in range(N - 1, -1, -1):
                last[A[i]] = i
                nxt[i] = tuple(last)

            MOD = 10**9 + 7
            memo = [[None] * N for _ in range(N)]

            def dp(i, j):
                if memo[i][j] is not None:
                    return memo[i][j]
                ans = 1  # The empty-string palindrome
                if i <= j:
                    for x in range(4):  # For letters a, b, c, d
                        i0 = nxt[i][x]
                        j0 = prv[j][x]
                        if i0 is not None and i0 <= j:
                            ans += 1  # The letter x exists in [i, j]
                        if i0 is not None and j0 is not None and i0 < j0:
                            ans += dp(i0 + 1, j0 - 1)  # Counting palindromes "x_x"
                ans %= MOD
                memo[i][j] = ans
                return ans

            return (dp(0, N - 1) - 1) % MOD  # Subtract the empty string
    

  •

    @awice said in N^2 DP Python with Explanation:

    last[A[i]] = i

    Could you please explain the role of last? last = [None] * 4 is a list of length 4, and then last[A[i]] = i.


  •

    OK, that's because the string only contains 'a', 'b', 'c', and 'd', so last[x] holds the most recent index of letter x seen so far in the scan.
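To see what last ends up producing, here is the table-building part of the solution pulled out on its own and run on "abca". The function name build_tables is made up for this illustration, and like the original it assumes S uses only 'a' to 'd'.

```python
def build_tables(S):
    """Return (prv, nxt) as built in the solution, for S over 'abcd'.

    prv[i][x]: last index <= i where letter x occurs (None if none).
    nxt[i][x]: first index >= i where letter x occurs (None if none).
    """
    N = len(S)
    A = [ord(c) - ord('a') for c in S]
    prv = [None] * N
    nxt = [None] * N

    last = [None] * 4  # last[x]: most recent index of letter x seen so far
    for i in range(N):
        last[A[i]] = i
        prv[i] = tuple(last)  # snapshot, so later updates don't leak back

    last = [None] * 4  # reset for the right-to-left scan
    for i in range(N - 1, -1, -1):
        last[A[i]] = i
        nxt[i] = tuple(last)
    return prv, nxt


prv, nxt = build_tables("abca")
print(nxt[0])  # (0, 1, 2, None)
print(prv[3])  # (3, 1, 2, None)
```

Because last is indexed by letter rather than by position, it always has length 4 no matter how long S is; the tuple() copy is what freezes each row.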


  •

    I implemented a bottom-up version with the same idea but got TLE. Interesting.


  •

    @8939123 Can I see your solution? This solution passes in 1300 ms. I set the time limit to 4000 ms, and I had tested two other variants that pass in about 2000 ms.
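For comparison, here is a sketch of what a bottom-up version of the same recurrence might look like. It is a reconstruction, not @8939123's code, and count_palindromic_subsequences_bottom_up is a hypothetical name. It fills the table by interval length so that dp[i0 + 1][j0 - 1] is always ready when it is needed, and it assumes S is non-empty and uses only 'a' to 'd'.

```python
def count_palindromic_subsequences_bottom_up(S):
    """Tabulated form of dp(i, j); assumes S is non-empty over 'a'..'d'."""
    MOD = 10**9 + 7
    N = len(S)
    A = [ord(c) - ord('a') for c in S]

    # prv[i][x] / nxt[i][x]: last index <= i / first index >= i of letter x.
    prv, nxt = [None] * N, [None] * N
    last = [None] * 4
    for i in range(N):
        last[A[i]] = i
        prv[i] = tuple(last)
    last = [None] * 4
    for i in range(N - 1, -1, -1):
        last[A[i]] = i
        nxt[i] = tuple(last)

    # Every cell starts at 1; cells with i > j are never overwritten, so they
    # keep the value 1 (only the empty string) that dp(i, j) has for i > j.
    dp = [[1] * N for _ in range(N)]
    for length in range(1, N + 1):
        for i in range(N - length + 1):
            j = i + length - 1
            ans = 1  # empty string
            for x in range(4):
                i0, j0 = nxt[i][x], prv[j][x]
                if i0 is not None and i0 <= j:
                    ans += 1  # single letter x occurs in [i, j]
                if i0 is not None and j0 is not None and i0 < j0:
                    ans += dp[i0 + 1][j0 - 1]  # shorter interval, already filled
            dp[i][j] = ans % MOD
    return (dp[0][N - 1] - 1) % MOD
```

One plausible reason such a version is slower in Python: it visits every interval with i <= j, roughly N^2 / 2 of them, each with a 4-letter inner loop, while the memoized version only computes the states reachable from dp(0, N - 1).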


  •

    It seems you can easily make an O(n) time and space solution out of this (assuming a fixed alphabet size) just by using a hash map for memo. Since each dp call reduces the string by at least 2, you only get about n/2 dp calls.

    Your description doesn't make it completely clear why it yields the right numbers. (For one thing, you don't explicitly count the number of unique characters; this seems to be handled implicitly in dp by the condition i <= i0 <= j, which leads to an extra count of a_a with switched indexes.) How can one show this? Via induction?
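This isn't a proof, but one way to gain confidence in the recurrence (including the single-letter i <= i0 <= j term and the worry about duplicates) is to compare it against brute-force enumeration on many small random strings. count_dp and count_brute are hypothetical helper names; count_dp is the post's recurrence with a dict as the memo, and the brute force simply collects every distinct non-empty palindromic subsequence in a set.

```python
from itertools import combinations
import random


def count_dp(S):
    """The post's top-down dp(i, j), memoized in a dict; S over 'a'..'d'."""
    MOD = 10**9 + 7
    N = len(S)
    A = [ord(c) - ord('a') for c in S]
    prv, nxt = [None] * N, [None] * N
    last = [None] * 4
    for i in range(N):
        last[A[i]] = i
        prv[i] = tuple(last)
    last = [None] * 4
    for i in range(N - 1, -1, -1):
        last[A[i]] = i
        nxt[i] = tuple(last)

    memo = {}

    def dp(i, j):
        if (i, j) in memo:
            return memo[(i, j)]
        ans = 1  # empty string
        if i <= j:
            for x in range(4):
                i0, j0 = nxt[i][x], prv[j][x]
                if i0 is not None and i0 <= j:
                    ans += 1  # single letter x
                if i0 is not None and j0 is not None and i0 < j0:
                    ans += dp(i0 + 1, j0 - 1)  # palindromes "x_x"
        memo[(i, j)] = ans % MOD
        return memo[(i, j)]

    return (dp(0, N - 1) - 1) % MOD


def count_brute(S):
    """Enumerate all subsequences; keep the distinct non-empty palindromes."""
    seen = set()
    for r in range(1, len(S) + 1):
        for idx in combinations(range(len(S)), r):
            t = ''.join(S[k] for k in idx)
            if t == t[::-1]:
                seen.add(t)
    return len(seen)


random.seed(0)
for _ in range(200):
    s = ''.join(random.choice('abcd') for _ in range(random.randint(1, 10)))
    assert count_dp(s) == count_brute(s), s
print("dp matches brute force")
```

Since the brute force deduplicates by string value, any double counting in dp would show up as a mismatch. For an actual proof, induction on the interval length j - i, using the outermost-pair argument from the post, is the natural route.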


  •

    The Python solution inspired me to write this Java solution. It should be O(n), and it passes in about 260 ms.

    import java.util.Arrays;
    import java.util.HashMap;
    import java.util.Map;

    class Solution {

      private static final String alphabet = "abcd";
      private static final int mod = 1_000_000_007;
      private Map<Integer, Integer> dp;
      private int[][] prev;
      private int[][] next;

      public int countPalindromicSubsequences(String S) {
        dp = new HashMap<>(S.length());
        prev = new int[S.length()][alphabet.length()];
        next = new int[S.length()][alphabet.length()];

        final int[] nextIndexes = new int[alphabet.length()];
        Arrays.fill(nextIndexes, -1);
        for (int i = 0; i < S.length(); i++) {
          nextIndexes[S.charAt(i) - 'a'] = i;
          prev[i] = nextIndexes.clone();
        }
        Arrays.fill(nextIndexes, -1);
        for (int i = S.length() - 1; i >= 0; i--) {
          nextIndexes[S.charAt(i) - 'a'] = i;
          next[i] = nextIndexes.clone();
        }

        // Subtract the empty string, keeping the result non-negative.
        return (countPalindromicSubsequences(S, 0, S.length() - 1) - 1 + mod) % mod;
      }

      private int countPalindromicSubsequences(String S, int start, int end) {
        if (start > end) {
          return 1; // only the empty string
        }
        final int key = start * S.length() + end;
        final Integer cached = dp.get(key);
        if (cached != null) {
          return cached;
        }
        // Plain get/put instead of computeIfAbsent: the mapping function must not
        // modify the map, and the recursive version can throw
        // ConcurrentModificationException on Java 9+.
        int sum = 1; // the empty string
        for (int c = 0; c < alphabet.length(); c++) {
          final int nextStart = next[start][c];
          final int nextEnd = prev[end][c];
          if (start <= nextStart && nextStart <= end) {
            sum++; // the single letter c occurs in [start, end]
          }
          if (nextStart != -1 && nextStart < nextEnd) {
            sum += countPalindromicSubsequences(S, nextStart + 1, nextEnd - 1);
            sum %= mod;
          }
        }
        sum %= mod;
        dp.put(key, sum);
        return sum;
      }
    }
    
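If you want to test the O(n) estimate on your own inputs, one option is to count how many distinct (i, j) pairs the memo actually stores and compare that number with N and N^2. This is a measurement sketch only: count_states is a hypothetical helper, not part of either solution, and it assumes S is non-empty and uses only 'a' to 'd'.

```python
import random


def count_states(S):
    """Run the memoized dp(i, j); return (answer, number of memoized states)."""
    MOD = 10**9 + 7
    N = len(S)
    A = [ord(c) - ord('a') for c in S]
    prv, nxt = [None] * N, [None] * N
    last = [None] * 4
    for i in range(N):
        last[A[i]] = i
        prv[i] = tuple(last)
    last = [None] * 4
    for i in range(N - 1, -1, -1):
        last[A[i]] = i
        nxt[i] = tuple(last)

    memo = {}  # (i, j) -> dp(i, j); len(memo) counts distinct states visited

    def dp(i, j):
        if (i, j) not in memo:
            ans = 1  # empty string
            if i <= j:
                for x in range(4):
                    i0, j0 = nxt[i][x], prv[j][x]
                    if i0 is not None and i0 <= j:
                        ans += 1
                    if i0 is not None and j0 is not None and i0 < j0:
                        ans += dp(i0 + 1, j0 - 1)
            memo[(i, j)] = ans % MOD
        return memo[(i, j)]

    answer = (dp(0, N - 1) - 1) % MOD
    return answer, len(memo)


# Usage: watch how the state count grows as the length doubles.
random.seed(1)
for n in (100, 200, 400):
    s = ''.join(random.choice('abcd') for _ in range(n))
    print(n, count_states(s)[1])
```

Note that the "at least 2 per call" observation bounds the recursion depth by about n/2; the total work is governed by the number of distinct states, which is what this helper reports.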

  •

    Hi awice,
    can you explain why you do the following addition?

        if i <= i0 <= j:
            ans += 1

    The code runs correctly, but I don't quite understand why you add 1 here.


  •

    @reliveinfire I count whether the letter represented by x actually exists in S[i:j+1]. It does if the next occurrence i0 of the letter x in S[i:] occurs inside the interval [i, j].
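A quick way to convince yourself of this: for every interval, the number of letters passing the i <= i0 <= j test equals the number of distinct characters in S[i:j+1]. The checker below (distinct_letters_via_nxt is a made-up name) verifies that for all intervals of a given string over 'a' to 'd'.

```python
def distinct_letters_via_nxt(S):
    """True if, for every [i, j], counting letters x with nxt[i][x] <= j
    gives the same result as len(set(S[i:j+1]))."""
    N = len(S)
    A = [ord(c) - ord('a') for c in S]
    nxt = [None] * N  # nxt[i][x]: first index >= i holding letter x
    last = [None] * 4
    for i in range(N - 1, -1, -1):
        last[A[i]] = i
        nxt[i] = tuple(last)
    for i in range(N):
        for j in range(i, N):
            # nxt[i][x] >= i always holds, so only the upper bound needs checking.
            via_nxt = sum(1 for x in range(4)
                          if nxt[i][x] is not None and nxt[i][x] <= j)
            if via_nxt != len(set(S[i:j + 1])):
                return False
    return True


print(distinct_letters_via_nxt("abcadbbd"))  # True
```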


  •

    @awice Could you please add comments in the code for "ans = 1" and "ans += 1" in the function dp? I don't understand what the 1 is added for.
    As far as I understand, ans = 1 seems to count the empty string, and ans += 1 is for the subsequence consisting of S[i0] and/or S[j0] only. But if so, dp(i0+1, j0-1) may count duplicate subsequences.
    Thank you.


  •

    @dree I added some comments.


  •

    @awice I read your post and article several times. I am still confused by

                        if i <= i0 <= j:
                            ans += 1 # The letter x exists in [i, j]
    

    and how your program solves "aaa".

    Can anyone help explain? Thanks.
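For the "aaa" question, here is a stripped-down tracing version of dp (no memo and no modulo, so only suitable for tiny inputs; trace is a hypothetical helper) that prints each call and where each increment comes from.

```python
def trace(S):
    """Print every dp(i, j) call for S over 'a'..'d', indented by depth."""
    N = len(S)
    A = [ord(c) - ord('a') for c in S]
    prv, nxt = [None] * N, [None] * N
    last = [None] * 4
    for i in range(N):
        last[A[i]] = i
        prv[i] = tuple(last)
    last = [None] * 4
    for i in range(N - 1, -1, -1):
        last[A[i]] = i
        nxt[i] = tuple(last)

    def dp(i, j, depth=0):
        pad = '  ' * depth
        ans = 1  # empty string
        print(f"{pad}dp({i}, {j}) on {S[i:j + 1]!r}: +1 empty")
        if i <= j:
            for x in range(4):
                ch = 'abcd'[x]
                i0, j0 = nxt[i][x], prv[j][x]
                if i0 is not None and i0 <= j:
                    ans += 1
                    print(f"{pad}  +1 single {ch!r}")
                if i0 is not None and j0 is not None and i0 < j0:
                    inner = dp(i0 + 1, j0 - 1, depth + 1)
                    ans += inner
                    print(f"{pad}  +{inner} palindromes {ch}_{ch}")
        print(f"{pad}dp({i}, {j}) = {ans}")
        return ans

    return dp(0, N - 1) - 1  # drop the empty string


trace("aaa")
```

For "aaa", the inner call dp(1, 1) returns 2 (the empty string and "a"). Wrapping each of those in the outer pair of a's gives "aa" and "aaa", the single-letter check adds "a", and the empty string adds 1, so dp(0, 2) = 4 and the final answer is 4 - 1 = 3: "a", "aa", "aaa".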

