@JadenPan For an interview, what you did is quite fine. I don't think any interviewer expects Manacher's algorithm. So the rest of this post is purely academic, just out of interest in what is actually possible.

Now, let's modify Manacher's algorithm.

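```python
def manachers(S):
    # Interleave '#' so every palindrome has odd length; '@' and '$' are
    # distinct sentinels that stop the while loop at both ends.
    # (Assumes S contains none of '@', '#', '$'.)
    A = '@#' + '#'.join(S) + '#$'
    Z = [0] * len(A)  # Z[i] = radius of the longest palindrome centered at A[i]
    center = right = 0
    for i in range(1, len(A) - 1):
        if i < right:
            # Reuse the answer at the mirror position 2 * center - i,
            # capped so we don't claim anything past `right`.
            Z[i] = min(right - i, Z[2 * center - i])
        # Expand around i as far as possible.
        while A[i - Z[i] - 1] == A[i + Z[i] + 1]:
            Z[i] += 1
        # Track the palindrome that reaches furthest to the right.
        if i + Z[i] > right:
            center, right = i, i + Z[i]
    return Z
```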

First, let's discuss this implementation of Manacher's algorithm further, including why it is linear.

Our loop invariants are that `center, right` describe the palindrome with the right-most right boundary among all centers before `i`: it is centered at `center` (so `center < i`) and has right boundary `right`. Also, we've already computed `Z[j]` for all `j < i`.

When `i < right`, we reflect `i` about `center` to get the mirror coordinate `j = 2 * center - i`. Then, limited to the interval with center `i` and radius `right - i`, the situation for `Z[i]` is the same as for `Z[j]`.

For example, suppose that at some point `center = 7, right = 13, i = 10`. Then for a string like `A = '@#A#B#A#A#B#A#$'`, the `center` is at the `'#'` between the two middle `'A'`s, the right boundary is at the last `'#'`, `i` is at the last `'B'`, and `j` is at the first `'B'`.

Notice that, limited to the interval `[center - (right - center), right]` (the interval with center `center` and right boundary `right`), the situation for `i` and `j` is a reflection of something we have already computed. Since we already know `Z[j] = 3`, we can immediately find `Z[i] = min(right - i, Z[j]) = 3`.
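As a quick check of the example, here is the `manachers` function from above, copied so the snippet runs on its own, applied to `S = 'ABAABA'`, which produces exactly the string `A` from the example:

```python
def manachers(S):
    A = '@#' + '#'.join(S) + '#$'
    Z = [0] * len(A)
    center = right = 0
    for i in range(1, len(A) - 1):
        if i < right:
            Z[i] = min(right - i, Z[2 * center - i])
        while A[i - Z[i] - 1] == A[i + Z[i] + 1]:
            Z[i] += 1
        if i + Z[i] > right:
            center, right = i, i + Z[i]
    return Z


Z = manachers('ABAABA')   # A = '@#A#B#A#A#B#A#$'
print(Z)                  # [0, 0, 1, 0, 3, 0, 1, 6, 1, 0, 3, 0, 1, 0, 0]
print(Z[7], Z[4], Z[10])  # 6 3 3  (center, j and i from the example)
```

`Z[7] = 6` is what sets `right = 7 + 6 = 13`, and at `i = 10` the mirrored value `Z[4] = 3` is already final: the next comparison, `A[6]` against the sentinel `A[14] = '$'`, fails immediately.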

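Now, why is this algorithm linear? The while loop can only succeed when `i >= right` (so `Z[i]` starts at 0) or when `Z[i]` starts out exactly equal to `right - i`. If `Z[j]` is strictly smaller or strictly larger than `right - i`, the reflection already determines `Z[i]`, and the first comparison in the while loop fails. In the remaining cases, each `Z[i] += 1` pushes `i + Z[i]` past the old `right`, so `right` advances by at least one for every increment. Since `right` never exceeds `len(A) - 2 = 2*N + 1` (with `N = len(S)`), the while loop increments `Z[i]` at most `2*N + 1` times over the whole run.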

Whew, that was a lot of information. Anyway, back to our original problem: knowing the *count* of distinct palindromes, in linear time.

If the while loop never increments `Z[i]`, every palindrome centered at `i` lies inside the palindrome around `center`, so it is the mirror image of a palindrome centered at `j`. A palindrome reversed is itself, so these are the same strings as the (mirrored) palindromes we have already counted, and we don't need to update our count.

Otherwise, every `Z[i] += 1` potentially finds a new palindrome, and we find only a linear number of them in `N` (where `N` is the length of the string), since we proved that `Z[i] += 1` only happens a linear number of times. In these cases, we should count the substring of `A` with indices in range `[i - Z[i], i + Z[i]]` (when it starts on a real character rather than a `'#'`). But building each substring `S[i:j]` explicitly is too expensive: it adds an `O(N)` factor, making our algorithm `O(N^2)`.
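For reference, here is a sketch of that straightforward version, which stores the actual substrings in a set. It should give the right count, but each slice can cost `O(N)`, so it is `O(N^2)` in the worst case. The names `count_distinct_palindromes_slow` and `count_distinct_palindromes_brute` are hypothetical; the brute-force function is only an `O(N^3)` checker for small inputs.

```python
def count_distinct_palindromes_slow(S):
    """Manacher-based count that stores real substrings (O(N^2) worst case)."""
    A = '@#' + '#'.join(S) + '#$'
    seen = set(S)  # every single character is a palindrome
    Z = [0] * len(A)
    center = right = 0
    for i in range(1, len(A) - 1):
        if i < right:
            Z[i] = min(right - i, Z[2 * center - i])
        while A[i - Z[i] - 1] == A[i + Z[i] + 1]:
            Z[i] += 1
            if A[i - Z[i]] != '#':
                # A[i - Z[i] : i + Z[i] + 1] starts and ends on a real character;
                # taking every second character drops the '#' separators.
                seen.add(A[i - Z[i]:i + Z[i] + 1:2])
        if i + Z[i] > right:
            center, right = i, i + Z[i]
    return len(seen)


def count_distinct_palindromes_brute(S):
    """O(N^3) reference: test every substring directly."""
    return len({S[a:b] for a in range(len(S)) for b in range(a + 1, len(S) + 1)
                if S[a:b] == S[a:b][::-1]})


print(count_distinct_palindromes_slow('ABAABA'))  # 6: A, B, AA, ABA, BAAB, ABAABA
```

On small inputs the two functions should agree; the slicing inside `seen.add(...)` is exactly the cost that string hashing removes.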

We can mitigate the problem by working with string hashes instead of substrings. For those not familiar with this advanced concept, here is a primer: let the hash of a word `W` be `hash(W) = sum(a^i * W[i] for i in range(len(W))) % P`, where `P` is some large prime, `a` is some constant coprime to `P`, and `W[i]` stands for the character's code. Then, if we compute the prefix sums `pre[i] = hash(W[:i])` in linear time, we can answer queries `hash(W[i:j]) = (pre[j] - pre[i]) * (a^{-1})^{i} % P` in constant time (given precomputed powers of `a^{-1}`), where `a^{-1}` is the modular inverse of `a` mod `P`.
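A small sketch of this scheme, assuming example values `P = 10**9 + 7` and `a = 131` (any large prime and any base coprime to it would do). It follows the primer's naming, so `P` is the modulus here, unlike in the code further down where `P` is the base. The helpers `poly_hash`, `build_tables` and `substring_hash` are hypothetical names; precomputing the powers of `a^{-1}` next to the prefix hashes is what makes each query constant time.

```python
P = 10**9 + 7  # a large prime modulus (example value)
a = 131        # base, coprime to P (example value)


def poly_hash(W):
    """hash(W) = sum(a^k * ord(W[k])) mod P, computed directly in O(len(W))."""
    return sum(pow(a, k, P) * ord(ch) for k, ch in enumerate(W)) % P


def build_tables(W):
    """Prefix hashes pre[k] = hash(W[:k]) and powers inv_pow[k] = (a^{-1})^k mod P."""
    a_inv = pow(a, P - 2, P)  # modular inverse of a via Fermat's little theorem
    pre = [0] * (len(W) + 1)
    inv_pow = [1] * (len(W) + 1)
    power = 1  # a^k mod P
    for k, ch in enumerate(W):
        pre[k + 1] = (pre[k] + power * ord(ch)) % P
        power = power * a % P
        inv_pow[k + 1] = inv_pow[k] * a_inv % P
    return pre, inv_pow


def substring_hash(pre, inv_pow, i, j):
    """hash(W[i:j]) in O(1): (pre[j] - pre[i]) * (a^{-1})^i mod P."""
    return (pre[j] - pre[i]) * inv_pow[i] % P


W = 'abracadabra'
pre, inv_pow = build_tables(W)
print(substring_hash(pre, inv_pow, 0, 4) == substring_hash(pre, inv_pow, 7, 11))  # True ('abra' twice)
print(substring_hash(pre, inv_pow, 2, 6) == poly_hash(W[2:6]))                   # True
```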

This means that (up to hash collisions) the number of unique palindromic substrings is just the number of unique hashes we've found. Note that in the implementation below I didn't precompute `pow(PinvQ, L, Q)` and the like, so each query takes `O(log N)` time rather than constant time, but we could easily precompute these powers (there are `O(N)` of them).

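```python
class StringHash(object):
    # Polynomial hash: P is the base (`a` in the primer above), and Q and R
    # are the prime moduli.
    P = 41
    Q = 10**9 + 33
    R = 10**9 + 123  # I'm using two primes Q and R to reduce the chance of collision
    PinvQ = pow(P, Q - 2, Q)  # inverse of P mod Q (Fermat's little theorem)
    PinvR = pow(P, R - 2, R)  # inverse of P mod R

    def __init__(self, S):
        self.S = S
        self.prefix = [(0, 0)]  # prefix[k] = hash pair of S[:k]
        cq = cr = 0
        q = r = 1  # current powers P^k mod Q and mod R
        for i, x in enumerate(self.S):
            cq = (cq + ord(x) * q) % StringHash.Q
            cr = (cr + ord(x) * r) % StringHash.R
            q = q * StringHash.P % StringHash.Q
            r = r * StringHash.P % StringHash.R
            self.prefix.append((cq, cr))

    def query(self, L, R):
        # Hash pair of S[L..R] inclusive: (prefix[R+1] - prefix[L]) * P^(-L).
        # pow() makes this O(log L); precomputed inverse powers would make it O(1).
        cq2, cr2 = self.prefix[R + 1]
        cq1, cr1 = self.prefix[L]
        return ((cq2 - cq1) * pow(StringHash.PinvQ, L, StringHash.Q) % StringHash.Q,
                (cr2 - cr1) * pow(StringHash.PinvR, L, StringHash.R) % StringHash.R)


def solve(S):
    A = '@#' + '#'.join(S) + '#$'
    shash = StringHash(A)
    # Every character of S is a palindrome; they sit at even indices 2..2N of A.
    # (The upper bound len(A) - 2 also keeps the empty string's answer at 0.)
    seen = set(shash.query(i, i) for i in range(2, len(A) - 2, 2))
    Z = [0] * len(A)
    center = right = 0
    for i in range(1, len(A) - 1):
        if i < right:
            Z[i] = min(right - i, Z[2 * center - i])
        while A[i - Z[i] - 1] == A[i + Z[i] + 1]:
            Z[i] += 1
            # Only record palindromes that start (and so end) on a real character.
            if A[i - Z[i]] != '#':
                seen.add(shash.query(i - Z[i], i + Z[i]))
        if i + Z[i] > right:
            center, right = i, i + Z[i]
    return len(seen)
```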
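As noted above, `query` calls `pow` every time, so it isn't constant time. Here is a sketch of one way to fix that, assuming we keep the same base and moduli: precompute the inverse powers in `__init__`, so `query` is `O(1)` and the whole count is linear (ignoring hash collisions). `FastStringHash` and `count_distinct_palindromes` are hypothetical names; the counting loop mirrors `solve`.

```python
class FastStringHash(object):
    """Double polynomial hash with precomputed inverse powers, so query() is O(1)."""
    BASE = 41                         # same base as StringHash.P
    MODS = (10**9 + 33, 10**9 + 123)  # same prime moduli as StringHash.Q and .R

    def __init__(self, S):
        n = len(S)
        self.prefix = []   # prefix[m][k]  = hash of S[:k] modulo MODS[m]
        self.inv_pow = []  # inv_pow[m][k] = BASE^(-k) modulo MODS[m]
        for mod in self.MODS:
            inv_base = pow(self.BASE, mod - 2, mod)  # Fermat inverse, mod is prime
            pre, inv = [0] * (n + 1), [1] * (n + 1)
            power = 1
            for k, ch in enumerate(S):
                pre[k + 1] = (pre[k] + ord(ch) * power) % mod
                power = power * self.BASE % mod
                inv[k + 1] = inv[k] * inv_base % mod
            self.prefix.append(pre)
            self.inv_pow.append(inv)

    def query(self, L, R):
        """Hash pair of S[L..R] (inclusive) in O(1)."""
        return tuple((pre[R + 1] - pre[L]) * inv[L] % mod
                     for pre, inv, mod in zip(self.prefix, self.inv_pow, self.MODS))


def count_distinct_palindromes(S):
    """Distinct palindromic substrings of S, up to hash collisions.

    Assumes S contains none of the marker characters '@', '#', '$'.
    """
    A = '@#' + '#'.join(S) + '#$'
    shash = FastStringHash(A)
    # Every character of S is a palindrome; they sit at even indices 2..2N of A.
    seen = set(shash.query(i, i) for i in range(2, len(A) - 2, 2))
    Z = [0] * len(A)
    center = right = 0
    for i in range(1, len(A) - 1):
        if i < right:
            Z[i] = min(right - i, Z[2 * center - i])
        while A[i - Z[i] - 1] == A[i + Z[i] + 1]:
            Z[i] += 1
            if A[i - Z[i]] != '#':  # palindrome starts and ends on a real character
                seen.add(shash.query(i - Z[i], i + Z[i]))
        if i + Z[i] > right:
            center, right = i, i + Z[i]
    return len(seen)


print(count_distinct_palindromes('ABAABA'))  # 6
```

The trade-off is memory: two extra arrays of length `len(A)` per modulus, in exchange for dropping the `log` factor from every query.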