The problem can actually be extended to a more general case where we allow at most $m$ adjacent posts with the same color. If $m = 2$, it reduces exactly to the original problem. Let $T[n]$ be the number of ways to paint $n$ posts in the extended problem. Obviously, we have the initial condition $T[n] = k^n$ for any $n \le m$, since a run of equal colors can never be longer than the whole fence.
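To make the definition concrete, here is a small brute-force counter. `countBrute` is a hypothetical helper, not part of the original post: it enumerates all $k^n$ colorings and keeps those with no run longer than $m$, so for $n \le m$ it should return $k^n$, matching the initial condition. It takes exponential time and is only meant for checking small cases.

```cpp
#include <vector>

// Hypothetical brute-force helper (not from the original post): enumerates all
// k^n colorings of n posts and counts those in which no run of equal colors
// is longer than m. Exponential time, so only for small n and k.
long long countBrute(int n, int k, int m) {
    std::vector<int> color(n, 0);  // current coloring, treated as a base-k counter
    long long count = 0;
    while (true) {
        // Track the length of the current run of equal adjacent colors.
        bool ok = true;
        int run = 0;
        for (int i = 0; i < n && ok; ++i) {
            run = (i > 0 && color[i] == color[i - 1]) ? run + 1 : 1;
            if (run > m) ok = false;
        }
        if (ok) ++count;
        // Advance to the next coloring; stop once the counter wraps around.
        int pos = 0;
        while (pos < n && ++color[pos] == k) color[pos++] = 0;
        if (pos == n) break;
    }
    return count;
}
```

For example, `countBrute(2, 3, 2)` should give $3^2 = 9$, and `countBrute(3, 2, 2)` gives 6 (all 8 colorings except the two single-color ones). For $n = 0$ it counts the empty coloring once and returns 1.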
Also, as in the case $m = 2$, we can prove the recursion

$$T[n] = (k-1)\,\big(T[n-1] + T[n-2] + \dots + T[n-m]\big), \quad n > m. \qquad (*)$$
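A direct translation of (*) into a full table might look like the sketch below. `numWaysTable` is a hypothetical name, `long long` is used only to postpone overflow, and `T[0] = 1` is an assumed convention for the empty fence; for $n > m$ the recursion itself only reads $T[n-m], \dots, T[n-1]$, all with index at least 1.

```cpp
#include <vector>

// A sketch of recursion (*) with a full table: O(n*m) time, O(n) space.
// Hypothetical helper; T[0] = 1 is assumed for the empty fence.
long long numWaysTable(int n, int k, int m) {
    std::vector<long long> T(n + 1, 1);
    for (int i = 1; i <= n; ++i) {
        if (i <= m) {
            T[i] = T[i - 1] * k;   // initial condition: T[i] = k^i for i <= m
        } else {
            long long sum = 0;     // T[i-1] + ... + T[i-m]
            for (int j = 1; j <= m; ++j) sum += T[i - j];
            T[i] = (k - 1) * sum;  // recursion (*)
        }
    }
    return T[n];
}
```

For $k = 3$, $m = 3$ it gives $T[4] = 2 \cdot (27 + 9 + 3) = 78 = 3^4 - 3$, i.e. only the three single-color fences are excluded.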
Note that the restriction in this problem is that at most $m$ adjacent posts with the same color are allowed, so whether we can paint a post freely depends solely on how many of the immediately preceding posts share the same color. Let $S[j][n]$ be the number of ways to paint $n$ posts such that exactly the last $j$ posts have the same color ($1 \le j \le m$). Then we can decompose the total number of valid paintings as

$$T[n] = S[m][n] + S[m-1][n] + \dots + S[1][n].$$
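The remark that only the length of the current run matters can be seen in a depth-first enumeration whose only state is that length. The sketch below uses hypothetical helpers `paintRest` and `lastRunCounts` (exponential time, for small $n \ge 1$ only) and tallies the valid colorings by the length $j$ of their final run, so entry `S[j]` equals $S[j][n]$ and the entries sum to $T[n]$.

```cpp
#include <vector>

// Hypothetical helper: paints the remaining posts one at a time. The only state
// is `run`, the number of trailing posts that share the last color.
// When no posts are left, the coloring is tallied under its final run length.
void paintRest(int posts_left, int run, int k, int m, std::vector<long long>& S) {
    if (posts_left == 0) { ++S[run]; return; }
    // Reuse the previous color (1 way): allowed only while the run stays <= m.
    if (run < m) paintRest(posts_left - 1, run + 1, k, m, S);
    // Switch to one of the k - 1 other colors: the run restarts at length 1.
    for (int c = 0; c < k - 1; ++c) paintRest(posts_left - 1, 1, k, m, S);
}

// Returns S with S[j] = S[j][n] for 1 <= j <= m (S[0] unused). Requires n >= 1.
std::vector<long long> lastRunCounts(int n, int k, int m) {
    std::vector<long long> S(m + 1, 0);
    // The first post can take any of the k colors and starts a run of length 1.
    for (int c = 0; c < k; ++c) paintRest(n - 1, 1, k, m, S);
    return S;
}
```

For $n = 3$, $k = 2$, $m = 2$ it reports $S[1][3] = 4$ and $S[2][3] = 2$, which sum to $T[3] = 6$.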
Now we have two key observations:

1. $S[1][n] = (k-1)\,T[n-1]$ for $n \ge 2$. Given a valid painting of $n-1$ posts, painting post $n$ with any of the $k-1$ colors different from post $n-1$ yields a painting counted by $S[1][n]$ without violating the rule (the converse also holds).
2. $S[j][n] = S[j-1][n-1]$ for $2 \le j \le m$. Given a painting of $n-1$ posts whose last $j-1$ posts (exactly) share the same color, repeating the color of post $n-1$ on post $n$ yields a painting counted by $S[j][n]$ without violating the rule (the converse also holds).
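The two observations already give a DP that tracks the $m$ values $S[1][n], \dots, S[m][n]$ instead of $T$, in $O(nm)$ time and $O(m)$ space. The following sketch is a hypothetical `numWaysByLastRun` (using `long long` and returning 1 for $n = 0$ by assumption): it applies observation 2 as a shift of the array and observation 1 to refill $S[1]$.

```cpp
#include <numeric>
#include <vector>

// Hypothetical DP over the states S[j][n] using the two observations:
//   S[1][n] = (k-1) * T[n-1]   and   S[j][n] = S[j-1][n-1] for 2 <= j <= m.
// O(n*m) time, O(m) space; returns T[n], with 1 assumed for n = 0.
long long numWaysByLastRun(int n, int k, int m) {
    if (n == 0) return 1;
    std::vector<long long> S(m + 1, 0);  // S[j] for the current n; S[0] unused
    S[1] = k;                            // one post: any color, last run has length 1
    for (int i = 2; i <= n; ++i) {
        long long prevTotal = std::accumulate(S.begin() + 1, S.end(), 0LL);  // T[i-1]
        for (int j = m; j >= 2; --j) S[j] = S[j - 1];  // observation 2 (shift, high to low)
        S[1] = (k - 1) * prevTotal;                     // observation 1
    }
    return std::accumulate(S.begin() + 1, S.end(), 0LL);  // T[n] = S[1][n] + ... + S[m][n]
}
```

Shifting from high to low $j$ makes each `S[j]` read the old `S[j - 1]` before it is overwritten. For $n = 5$, $k = 3$, $m = 3$ this returns 228, in agreement with (*).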
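Combining observations 1 and 2 with the decomposition of $T[n]$ proves the recursion (*): for $n > m$, applying observation 2 repeatedly gives $S[j][n] = S[1][n-j+1] = (k-1)\,T[n-j]$ (valid because $n-j+1 \ge 2$ for every $j \le m$), and summing over $j = 1, \dots, m$ yields (*). The code is straightforward, as follows ($O(nm)$ time, $O(m)$ space).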
Btw, I think the answer for the corner case $n = 0$ should be 1 instead of 0 (not sure what the OJ's answer is), because mathematically, doing nothing is also a "valid" way to paint, since it does not violate any rule.
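```cpp
#include <vector>
using std::vector;

// Ways to paint n posts with k colors so that at most m adjacent posts share a color.
int numWays(int n, int k, int m) {
    vector<int> T(m+1, 1);
    for (int i = 1; i <= m; ++i) T[i] = T[i-1] * k; // T[i] = pow(k, i) for i <= m
    // Rolling window: T[1..m] hold the last m values of T, T[0] accumulates the next one.
    for (int i = m+1, j; i++ <= n; T[m] = T[0]) // T[n] = (k-1)*(T[n-1]+...+T[n-m])
        for (j = 1, T[0] = 0; j <= m; ++j) T[0] += (k-1)*T[j], T[j] = T[(j+1)%(m+1)];
    return T[(n <= m)*n]; // T[n] if n <= m, otherwise the last value computed in T[0]
}
```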