How does one think up DP solutions for these types of problems?


  • 0

    I'm trying to improve, so I just joined, but for every DP question I've looked at I've been stuck. I can't figure out what the subproblem should be, how to model it in a matrix, or how to iterate, and I'm amazed at how other people come up with these solutions (when I can understand the solutions at all, which often I can't). Is there some trick to coming up with them, or do you just have to be smart?


  • 18

    Practice on similar questions and think carefully, and you will get a better understanding. Almost no one can solve these problems on the first day they learn DP.

    For this problem, if you have enough experience, it's natural to come up with the O(n^3) DP solution. When I first approached it, here is what I thought:

    In the first round, I could choose any number between 1 and n. The goal is to minimize the money I pay, and naturally there are one or more first guesses that correspond to the optimal strategy.

    At this point, usually either there is some formula or rule that tells you which number to pick, OR you have to try every number and compute the cost of each choice. The former case usually relies on brain teasers, math knowledge, or special insight, so it's less common than the second case (we are programmers rather than mathematicians).

    Suppose I choose number k. Then the correct number is either (1) == k, (2) > k, or (3) < k, and if the guess is wrong I pay k. Since the problem asks for the amount that guarantees a win, we have to consider the worst case, which is the worse of (2) and (3).

    If (2), we then want to know the minimum we must pay to guarantee a win when guessing a number between k+1 and n. If (3), we want the same for a number between 1 and k-1. At this point the subproblem is clear: what is the minimum amount we must pay to guarantee a win when the number is in [i, j]? Call it sol(i, j). The worst-case cost of choosing k in [i, j] is then k + max(sol(i, k-1), sol(k+1, j)).

    DP hasn't come in yet; it enters once we realize that we would repeat a lot of computations of the same sol(i, j).

    Going back to the choice step: for each choice of k we compute this cost and take the minimum. So finally we get sol(i, j) = min over k in [i, j] of (k + max(sol(i, k-1), sol(k+1, j))), with sol(i, j) = 0 when i >= j (zero or one number left, so no wrong guess is needed).
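    To make the recurrence concrete, here is a minimal sketch in Java (the language used later in this thread) that transcribes it literally: it tries every k from i to j, including the endpoints, and returns 0 when i >= j. The class name GuessRecurrence is hypothetical, and since there is no caching yet it runs in exponential time, so it is only meant for small n.

    ```java
    // Literal transcription of the recurrence:
    //   sol(i, j) = min over k in [i, j] of k + max(sol(i, k-1), sol(k+1, j))
    //   sol(i, j) = 0 when i >= j
    // GuessRecurrence is a hypothetical name used only for this sketch.
    class GuessRecurrence {
    
        static int sol(int i, int j) {
            if (i >= j) {
                return 0; // zero or one candidate left: no wrong guess is needed
            }
            int best = Integer.MAX_VALUE;
            for (int k = i; k <= j; k++) {
                // worst case: guess k is wrong, so we pay k and continue on the worse side
                int cost = k + Math.max(sol(i, k - 1), sol(k + 1, j));
                best = Math.min(best, cost);
            }
            return best;
        }
    
        public static void main(String[] args) {
            for (int n = 1; n <= 10; n++) {
                System.out.println("n = " + n + " -> " + sol(1, n));
            }
        }
    }
    ```

    For n = 10, sol(1, 10) returns 16, which matches the example in the problem statement. The repeated calls for the same (i, j) pairs are exactly the redundant work that DP removes.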

    The remaining details (base case, iteration direction) are easier once you figure out the part above and have practiced DP enough.

    As you can see, it may look daunting for beginners, but it's natural for the experienced. Practice more and I guarantee you will soon solve these questions on your own. (Try Flip Game II, Burst Balloons, and max coins on a line (not on LeetCode).)


  • 0

    @stupidbird911 Thank you so much for this detailed reply!


  • 6

    First of all, I still have a hard time solving DP problems. The previous answer is a good explanation, and it is right about the experience part: you should practice a lot. But I wanted to clarify the "trick" behind these types of questions, since people may be shocked when they first see DP solutions. I hope it helps.

    Basic recursion

    As I said, I don't have much experience with DP. So when I encounter a DP question, I first try to solve it with an easy approach, which is usually a recursive solution. For this specific question, here is the recursive solution:

    public class Solution {
    
        public int getMoneyAmount(int n) {
            return getMoneyAmount(1, n);
        }
    
        private int getMoneyAmount(int s, int e) {
            int diff = e - s;
            //Base part
            if (diff < 1) {
                return 0;
            }
            if (diff == 1) {
                return s;
            }
            if (diff == 2) {
                return s + 1;
            }
            //Recursion part
            int min = Integer.MAX_VALUE;
            for (int i = s + 1; i < e; i++) {
                min = Math.min(min, i + Math.max(getMoneyAmount(s, i - 1), getMoneyAmount(i + 1, e)));
            }
            return min;
        }
    }
    

    The tricky parts are the base conditions and the loop that calculates the minimum cost at the end. I won't explain them in detail since the previous post already did. You should be able to come up with a recursive solution after some trials. But even though this solution produces correct answers, it is far too slow: it runs in exponential time (O(3^n) is a safe upper bound), which is not practical. So let's speed it up!
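    A quick way to see why the plain recursion is impractical is to count its calls. The sketch below is an instrumented copy of the getMoneyAmount(s, e) function above; the class name CallCounter, the static counter calls, and countCalls are hypothetical additions used only for measuring.

    ```java
    // Instrumented copy of the basic recursion; only the counter is new.
    class CallCounter {
        static long calls = 0; // hypothetical counter, incremented on every call
    
        static int getMoneyAmount(int s, int e) {
            calls++;
            int diff = e - s;
            if (diff < 1) {
                return 0;
            }
            if (diff == 1) {
                return s;
            }
            if (diff == 2) {
                return s + 1;
            }
            int min = Integer.MAX_VALUE;
            for (int i = s + 1; i < e; i++) {
                min = Math.min(min, i + Math.max(getMoneyAmount(s, i - 1), getMoneyAmount(i + 1, e)));
            }
            return min;
        }
    
        // Resets the counter, solves [1, n], and reports how many calls were made.
        static long countCalls(int n) {
            calls = 0;
            getMoneyAmount(1, n);
            return calls;
        }
    
        public static void main(String[] args) {
            for (int n = 4; n <= 20; n += 4) {
                System.out.println("n = " + n + ": " + countCalls(n) + " calls");
            }
        }
    }
    ```

    With these base cases the count roughly doubles each time n grows by one (5 calls for n = 4, 65 for n = 8, 262145 for n = 20), even though there are only about n^2 / 2 distinct (s, e) pairs. That gap is what memoization closes.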

    Memoization

    For many recursive solutions, the function calculates the same result for the same inputs again and again. If we save each result and reuse it when needed, we can speed up the recursion a lot. This is called memoization. You can read more about it here.

    But how do we achieve this for our recursion? To figure it out, you just need to look at your recursive function's parameters. Our recursive solution has only two parameters: s (start) and e (end). Simply put, if you remember the result of a call like getMoneyAmount(1, 5), you can reuse it every time without calculating it again.

    As the next step, you need to understand the range of these input parameters. The start parameter goes from 1 to n, and the same holds for the end parameter. So we need to save a result for every pair start (1..n), end (1..n), and a basic two-dimensional array will do. When we call getMoneyAmount(1, 5) for the first time, we calculate the result and save it. When another recursive call needs it again, we return it from the table. This table is also called a cache.

    After memoization, our solution looks like this:

    import java.util.Arrays;
    
    public class Solution {
    
        public int getMoneyAmount(int n) {
            //Initialize cache and fill it with -1. 
            //We will use -1 to check if it is calculated before or not
            int[][] memo = new int[n + 1][n + 1];
            for (int i = 0; i < n + 1; i++) {
                Arrays.fill(memo[i], -1);
            }
            //Call cached recursion
            return getMoneyAmountMemo(1, n, memo);
        }
    
        //Cached recursion
        private int getMoneyAmountMemo(int s, int e, int memo[][]) {
            //Check if the recursion is calculated before
            if (memo[s][e] == -1) {
                //Calculate recursion and save it to cache
                memo[s][e] = getMoneyAmount(s, e, memo);
            }
            return memo[s][e];
        }
    
        //Previous recursive solution. Added cache parameter to pass it to other recursion calls 
        private int getMoneyAmount(int s, int e, int memo[][]) {
            int diff = e - s;
            if (diff < 1) {
                return 0;
            }
            if (diff == 1) {
                return s;
            }
            if (diff == 2) {
                return s + 1;
            }
            int min = Integer.MAX_VALUE;
            for (int i = s + 1; i < e; i++) {
                //Now call for cached recursive function instead of directly calling basic recursive function
                min = Math.min(i + Math.max(getMoneyAmountMemo(s, i - 1, memo), getMoneyAmountMemo(i + 1, e, memo)), min);
            }
            return min;
        }
    
    }
    

    Now our solution calculates the result for each (start, end) pair only once, and there are O(n^2) such pairs. For each pair it also checks up to n values to find the minimum cost (the for loop in the basic recursive function). So this solution has O(n^3) time complexity. It uses O(n^2) extra space for the cache, plus the stack space of the recursive calls.

    DP

    We solved the speed problem, but recursion still costs space: every pending call keeps a frame on the call stack, and if the recursion gets deep enough it can cause the famous StackOverflowError. We want to calculate every value in our memoization table without recursion.

    To remove the recursion, we need to look at where it occurs. In our case, it is in the for loop of the basic recursive function. For a given start and end pair, we read the cached results memo[s][t-1] and memo[t+1][e] for every t between them. So to find the minimum value for a pair, we need the values in earlier columns of the same row (s, t-1) and in later rows of the same column (t+1, e). This means that if we fill the table row by row from the bottom row up, and each row from left to right (from the lower-left corner toward the upper-right corner), every value we need will already be calculated.
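    The dependency argument can also be checked mechanically. This sketch (FillOrderCheck and its methods are hypothetical helpers) replays the loop order of the bottom-up solution on a 0-based n x n table, rows from n-1 up to 0 and each row from left to right, and verifies that every cell read by the t-loop is already filled. It applies the general t-loop to every cell, including the ones the solution treats as base cases. For contrast, it also tries filling the rows top to bottom.

    ```java
    // Replays a fill order and checks that dp[i][t-1] and dp[t+1][j] are ready
    // whenever cell (i, j) is computed.
    class FillOrderCheck {
    
        // Rows from bottom (n-1) to top (0), each row from left to right.
        static boolean bottomUpIsValid(int n) {
            boolean[][] filled = new boolean[n][n];
            for (int k = 0; k < n; k++) {
                filled[k][k] = true; // one-number ranges cost 0, nothing to compute
            }
            for (int i = n - 1; i >= 0; i--) {
                for (int j = i + 1; j < n; j++) {
                    if (!dependenciesReady(filled, i, j)) {
                        return false;
                    }
                    filled[i][j] = true;
                }
            }
            return true;
        }
    
        // For contrast: rows from top (0) to bottom (n-1), each row from left to right.
        static boolean topDownIsValid(int n) {
            boolean[][] filled = new boolean[n][n];
            for (int k = 0; k < n; k++) {
                filled[k][k] = true;
            }
            for (int i = 0; i < n; i++) {
                for (int j = i + 1; j < n; j++) {
                    if (!dependenciesReady(filled, i, j)) {
                        return false;
                    }
                    filled[i][j] = true;
                }
            }
            return true;
        }
    
        // Cell (i, j) reads (i, t-1) and (t+1, j) for every guess t strictly between i and j.
        private static boolean dependenciesReady(boolean[][] filled, int i, int j) {
            for (int t = i + 1; t < j; t++) {
                if (!filled[i][t - 1] || !filled[t + 1][j]) {
                    return false;
                }
            }
            return true;
        }
    
        public static void main(String[] args) {
            System.out.println("bottom-up order valid for n = 10: " + bottomUpIsValid(10));
            System.out.println("top-down order valid for n = 10: " + topDownIsValid(10));
        }
    }
    ```

    The bottom-up order passes for every n, while filling the top row first fails as soon as n >= 4: cell (0, 3) needs (2, 3), which belongs to a row that has not been processed yet.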

    Let's remove the recursion and fill the table with for loops that start from the lower-left corner and move toward the upper-right corner. Our solution will look like this:

    public class Solution {
        public int getMoneyAmount(int n) {
            //Call it dp instead of memo. Indices are 0-based here:
            //dp[i][j] is the minimum cost for the numbers i+1..j+1
            int[][] dp = new int[n][n];
            //Goes from n-1 to 0 (from lower to upper)
            for(int i=n-1; i>=0; i--) {
                //Set values for base conditions
                if(i+1<n) {
                    dp[i][i+1] = i+1; // add 1 to the values since i starts from 0, not from 1
                }
                if(i+2<n) {
                    dp[i][i+2] = i+2; 
                }
                //Goes from left to right
                for(int j=i+3; j<n ; j++) {
                    //This part is similar to the for  loop in basic recursive function
                    int min = Integer.MAX_VALUE;
                    for(int t=i+1 ; t<j ; t++) {
                        min = Math.min(t + 1 + Math.max(dp[i][t-1], dp[t+1][j]), min);
                    }
                    dp[i][j] = min;
                }
            }
            return dp[0][n-1];
        }
    }
    

    Now it uses O(n^2) extra space, its time complexity is still O(n^3), and it no longer needs the recursion stack.
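    A cheap way to gain confidence in the bottom-up version, which only tries guesses strictly between the range endpoints and hard-codes the two- and three-number ranges, is to compare it with a literal table version of sol(i, j) that tries every k. The sketch below copies the solution above with the outer loop started at n - 1; CrossCheck, postBottomUp, and fullRecurrence are hypothetical names used for illustration.

    ```java
    // Compares the bottom-up solution with a literal transcription of the recurrence.
    class CrossCheck {
    
        // The bottom-up solution from the post: 0-based indices, guesses strictly inside the range.
        static int postBottomUp(int n) {
            int[][] dp = new int[n][n];
            for (int i = n - 1; i >= 0; i--) {
                if (i + 1 < n) {
                    dp[i][i + 1] = i + 1;
                }
                if (i + 2 < n) {
                    dp[i][i + 2] = i + 2;
                }
                for (int j = i + 3; j < n; j++) {
                    int min = Integer.MAX_VALUE;
                    for (int t = i + 1; t < j; t++) {
                        min = Math.min(t + 1 + Math.max(dp[i][t - 1], dp[t + 1][j]), min);
                    }
                    dp[i][j] = min;
                }
            }
            return dp[0][n - 1];
        }
    
        // Literal recurrence on 1-based numbers: try every k in [i, j], sol = 0 when i >= j.
        static int fullRecurrence(int n) {
            int[][] sol = new int[n + 1][n + 1];
            for (int i = n; i >= 1; i--) {
                for (int j = i + 1; j <= n; j++) {
                    int best = Integer.MAX_VALUE;
                    for (int k = i; k <= j; k++) {
                        int left = (k - 1 > i) ? sol[i][k - 1] : 0;  // sol(i, k-1)
                        int right = (k + 1 < j) ? sol[k + 1][j] : 0; // sol(k+1, j)
                        best = Math.min(best, k + Math.max(left, right));
                    }
                    sol[i][j] = best;
                }
            }
            return sol[1][n];
        }
    
        public static void main(String[] args) {
            for (int n = 1; n <= 100; n++) {
                if (postBottomUp(n) != fullRecurrence(n)) {
                    System.out.println("mismatch at n = " + n);
                    return;
                }
            }
            System.out.println("both versions agree for n = 1..100");
        }
    }
    ```

    Skipping the endpoints is safe because, in a range of three or more numbers, some interior guess always costs no more than guessing an endpoint; the cross-check simply confirms this for the tested values of n.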

    This is how I deal with DP problems. After a while, you start seeing the DP solution directly, without going through these transitions. You just need to practice a lot.



  • 1

    @tambourine I have the same problem as you do, and I strongly recommend this blog to you, because the blogger's approach works like a charm.


  • 0

    @stupidbird911 Thanks for your wonderful and detailed analysis of the ideas, and for showing how to learn dynamic programming.


  • 0

    @stupidbird911 Nice reply. It really helped me, thanks!


  • 0

    @stupidbird911 This is a beautiful reply! Amazing algorithmic thinking! I will surely try to think about problems this way next time. Thank you!

    Here is my accepted solution:

    public class Solution {
        public int getMoneyAmount(int n) {
            int[][] dp = new int[n + 1][n + 1];
    
            // Ranges of two numbers [i, i+1]: guess i, pay i in the worst case
            for (int i = 1; i < n; i++) {
                dp[i][i + 1] = i;
            }
    
            // k is the length of the range [i, j]
            for (int k = 3; k <= n; k++) {
                for (int i = 1; i <= n - k + 1; i++) {
                    int j = i + k - 1;
                    int min = Integer.MAX_VALUE;
                    for (int r = i + 1; r < j; r++) {
                        min = Math.min(min, r + Math.max(dp[i][r - 1], dp[r + 1][j]));
                    }
                    dp[i][j] = min;
                }
            }
    
            return dp[1][n];
        }
    }


  • 0

    @stupidbird911 Thank you very much for providing such a detailed explanation.


  • 0

    @stupidbird911 Thanks a lot!


  • 0

    @ckarabulut
    Thank you for the detailed and straightforward answer, but I still have a question about what the variable i means in the DP solution. Does it mean the difference between start and end? And what is the relationship between i and j? I am quite confused.

