Java O(n^4) slow and complicated solution with memoization



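    For an input with boxes[0] != boxes[n-1], let k be the index of the last box removed in the same round as boxes[0] (so boxes[0] == boxes[k]). Every box between 0 and k must be removed before that round, and no box after k can be grouped with a box before it, so the input breaks down into 2 independent smaller problems (we take the maximum over all valid k):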

    1. subproblem from 0 to some k where boxes[0] == boxes[k]
    2. subproblem from k+1 to n-1

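    Subproblem 1 starts and ends with the same value, and subproblem 2 can be broken down further with the same rule, so we only need to handle the case where the first and last elements are the same. In this case, the first and last elements are always removed together, in the last round (possibly along with other equal elements in the middle of the array). To see why: if they were removed in two different rounds, with groups of sizes a and b, every other box could still be removed first in the same groups, after which both groups are adjacent and can be removed at once for (a+b)^2 > a^2 + b^2 points.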
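The decomposition and the "first and last are removed together" claim can be sanity-checked on small inputs against an exhaustive search. This is a minimal sketch, not part of the solution below: it assumes tiny arrays (the search is exponential) and assumes it is enough to try removing whole maximal runs of equal boxes. BruteForce and best are hypothetical names.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical checker: tries every order of removing maximal runs.
// Exponential time, so only use it on short arrays.
public class BruteForce {
    // Cache from remaining box sequence to best achievable score.
    private final Map<List<Integer>, Integer> memo = new HashMap<>();

    public int best(List<Integer> boxes) {
        if (boxes.isEmpty()) return 0;
        Integer cached = memo.get(boxes);
        if (cached != null) return cached;
        int result = 0;
        int i = 0;
        while (i < boxes.size()) {
            // Find the maximal run of equal boxes starting at i.
            int j = i;
            while (j < boxes.size() && boxes.get(j).equals(boxes.get(i))) j++;
            int len = j - i;
            // Remove the run [i, j) and recurse on the boxes that remain.
            List<Integer> rest = new ArrayList<>(boxes.subList(0, i));
            rest.addAll(boxes.subList(j, boxes.size()));
            result = Math.max(result, len * len + best(rest));
            i = j;
        }
        memo.put(boxes, result);
        return result;
    }

    public static void main(String[] args) {
        BruteForce bf = new BruteForce();
        System.out.println(bf.best(List.of(1, 3, 2, 2, 2, 3, 4, 3, 1))); // prints 23
    }
}
```

Comparing its output with removeBoxes on random short arrays is a cheap way to gain confidence in the recurrences.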

    Let reward(i, j) be the maximum score for the elements from i to j inclusive, and reward(i, j, k) be the maximum score for the same range when boxes[i] == boxes[j] and the last round of removal removes exactly k elements (including boxes[i] and boxes[j]). To compute reward(i, j, k), we just need to iterate over the second-to-last element removed in that last round.

    reward(s, e) =: (excluding the base cases: s == e simply returns 1, and an empty range s > e returns 0)

    1. boxes[s] != boxes[e]: max{reward(s, i) + reward(i+1, e)} over all i with s <= i < e and boxes[i] == boxes[s]
    2. boxes[s] == boxes[e]: max{reward(s, e, k)} over all k with 2 <= k <= the number of appearances of boxes[s] between s and e inclusive

    reward(s, e, k) =:

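    1. k = 2: 4 + reward(s+1, e-1)
    2. k > 2: max{reward(s, j, k-1) - (k-1)*(k-1) + k*k + reward(j+1, e-1)} over all j with s < j < e, boxes[j] == boxes[s], and at least k-1 appearances of boxes[s] between s and j inclusive. Here j is the second-to-last element removed in the last round: swapping the (k-1)*(k-1) points of the smaller group for k*k accounts for adding boxes[e] to it, and reward(j+1, e-1) covers the boxes strictly between j and e, which must be cleared first.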
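As a worked check of the k > 2 rule, take the small illustrative input boxes = [1, 2, 1, 3, 1] with s = 0 and e = 4. The value 1 appears 3 times, so k ranges over {2, 3}, and for k = 3 the only valid second-to-last position is j = 2:

```latex
\begin{aligned}
\text{reward}(0,2,2) &= 4 + \text{reward}(1,1) = 4 + 1 = 5\\
\text{reward}(0,4,3) &= \text{reward}(0,2,2) - 2^2 + 3^2 + \text{reward}(3,3) = 5 - 4 + 9 + 1 = 11\\
\text{reward}(0,4,2) &= 4 + \text{reward}(1,3) = 4 + 3 = 7\\
\text{reward}(0,4) &= \max\{\text{reward}(0,4,2),\ \text{reward}(0,4,3)\} = 11
\end{aligned}
```

This matches playing it by hand: remove the 2 (1 point), the 3 (1 point), then all three 1s together (9 points).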

    A preprocessing step records, for every pair s <= e with boxes[s] == boxes[e], the number of appearances of that value between s and e inclusive (repeat[s][e] in the code). I also use 2 memoization arrays to store reward(s, e) and reward(s, e, k).
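The preprocessing in the solution finds each previous occurrence with a backward scan, which costs O(n^3) in the worst case. A minimal alternative sketch, assuming the same meaning of repeat (the count of b[s] in b[s..e] when b[s] == b[e], and 0 otherwise): one forward scan per start index fills the table in O(n^2). RepeatTable and buildRepeat are hypothetical names.

```java
import java.util.Arrays;

// Hypothetical O(n^2) way to build the same repeat table as the solution.
public class RepeatTable {
    // repeat[i][j] = occurrences of b[i] in b[i..j] if b[i] == b[j], else 0.
    public static int[][] buildRepeat(int[] b) {
        int n = b.length;
        int[][] repeat = new int[n][n];
        for (int i = 0; i < n; i++) {
            int count = 0; // occurrences of b[i] seen so far in b[i..j]
            for (int j = i; j < n; j++) {
                if (b[j] == b[i]) {
                    count++;
                    repeat[i][j] = count;
                }
            }
        }
        return repeat;
    }

    public static void main(String[] args) {
        int[][] r = buildRepeat(new int[] {1, 2, 1, 3, 1});
        System.out.println(Arrays.toString(r[0])); // prints [1, 0, 2, 0, 3]
    }
}
```

The asymptotic cost of the whole solution is still dominated by the O(n^4) memoized recursion, so this only trims the setup.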

    public class Solution {
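        int[][][] memo = null;   // memo[s][e][k] caches reward(s, e, k); 0 means not computed yet
        int[][] memo2 = null;    // memo2[s][e] caches reward(s, e); 0 means not computed yet
        int[] b = null;          // the input boxes
        int n = 0;
        int[][] repeat = null;   // repeat[s][e]: occurrences of b[s] in b[s..e] if b[s] == b[e], else 0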
        public int removeBoxes(int[] boxes) {
            b = boxes;
            n = b.length;
            memo = new int [n][n][n+1];
            memo2 = new int [n][n];
            repeat = new int [n][n];
            for (int i = 0; i < n; i++) repeat[i][i] = 1;
            for (int i = 0; i < n; i++) {
                for (int j = i+1; j < n; j++) {
                    if (b[i] != b[j]) continue;
                    // find the previous occurrence k of b[j] in [i, j); b[i] == b[j] guarantees one exists
                    for (int k = j-1; k >= i; k--) {
                        if (b[k] == b[j]) {
                            repeat[i][j] = repeat[i][k]+1;
                            break;
                        }
                    }
                }
            }
            return reward(0, n-1);
        }
        int reward(int s, int e) {
            if (s > e) return 0;
            if (s == e) return 1;
            if (memo2[s][e] > 0) return memo2[s][e];
            if (b[s] != b[e]) {
                int max = 0;
                for (int i = s; i < e; i++) {
                    if (b[i] == b[s]) {
                        max = Math.max(max, reward(s, i) + reward(i+1, e));
                    }
                }
                memo2[s][e] = max;
                return max;
            }
            // b[s] == b[e]: try every possible size k of the final group
            int max = 0;
            for (int k = 2; k <= repeat[s][e]; k++) {
                max = Math.max(max, reward(s, e, k));
            }
            memo2[s][e] = max;
            return max;
        }
        int reward(int s, int e, int k) {
            if (memo[s][e][k] > 0) return memo[s][e][k];
            if (k == 2) {
                memo[s][e][k] = 4 + reward(s+1, e-1);
                return memo[s][e][k];
            }
            // k >= 3
            int max = 0;
            // j = second-to-last box removed in the final round; b[j+1..e-1] is cleared first
            for (int j = e-1; j > s; j--) {
                if (b[j] == b[s] && repeat[s][j] >= k-1) {
                    max = Math.max(max, reward(s, j, k-1) - (k-1)*(k-1) + k*k + reward(j+1, e-1));
                }
            }
            memo[s][e][k] = max;
            return max;
        }
    }
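The O(n^4) in the title comes from counting memoized states and the work per state. A rough tally, with n = boxes.length and each memoized call counted once:

```latex
\begin{aligned}
\text{reward}(s,e,k)&:\ O(n^3)\ \text{states} \times O(n)\ \text{choices of } j = O(n^4)\\
\text{reward}(s,e)&:\ O(n^2)\ \text{states} \times O(n)\ \text{distinct choices of } i \text{ or } k = O(n^3)\\
\text{building repeat}&:\ O(n^3)\ \text{(nested loops with a backward scan)}\\
\text{memory}&:\ O(n^3)\ \text{for the } n \times n \times (n+1)\ \text{memo array}
\end{aligned}
```

The first line dominates, giving O(n^4) time overall.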
    
