How to optimize the O(N^3) DP solution?


  • 5

    This is my first solution: an O(N^3) dynamic-programming (DP) approach.

    The recurrence is:

           dp[k][i] = max{ dp[k][i-1], max over 0<=j<i-1 of ( dp[k-1][j] + prices[i-1]-prices[j] ) }
    

    Here is the most naive implementation.

      class Solution {
        public:
            /***
             *  Reference O(k * n^2) DP ("most naive implementation").
             *
             *  dp[t][i]: the max profit we get during prices[0...i-1] with at most t transactions
             *  dp[t][i] = max{ dp[t][i-1], max over 0<=j<i-1 of ( dp[t-1][j] + prices[i-1]-prices[j] ) }
             *             (either skip day i-1, or sell on day i-1 after buying on an earlier day j)
             *
             *  start :  dp[0][i]=0;  dp[t][0]=0
             ***/
            int maxProfit(int k, vector<int>& prices) {
                int n=prices.size();
                if(n<=1)  return 0;

                /** to deal with the biggest cases: with k > n/2 every upswing can be its own transaction **/
                if(k>n/2){
                    int result=0;
                    for(int i=1; i<n; i++)
                        result+=max(prices[i]-prices[i-1], 0);
                    return result;
                }

                vector<vector<int>> dp(k+1, vector<int>(n+1, 0));

                for(int kk=1; kk<=k; kk++){
                    for(int i=1; i<=n; i++){
                        // Option 1: no transaction ends on day i-1.
                        int best=dp[kk][i-1];
                        // Option 2: buy on day j, sell on day i-1. The maximum must be taken
                        // over ALL j; overwriting dp[kk][i] per j would keep only j = i-2.
                        for(int j=0; j<i-1; j++)
                            best=max(best, dp[kk-1][j]+prices[i-1]-prices[j]);
                        dp[kk][i]=best;
                    }
                }
                return dp[k][n];
            }
        };

  • 11

    Here is a possible optimization idea:

    For a recurrence like this:

       dp[k][i] = max{ dp[k][i-1], max over j of ( dp[k-1][j] + prices[i-1]-prices[j] ) }
    

    Notice that the second term can be rewritten as

            prices[i-1] + (dp[k-1][j]-prices[j])
    

    So we can keep a running maximum in a variable temp:

           temp = max(temp, dp[k-1][j]-prices[j])  for  j<i
    

    Here is the optimized part of the code:

             for(int kk=1; kk<=k; kk++){
                 // temp = best dp[kk-1][j+1] - prices[j] over buy days j seen so far
                 // (INT_MIN until the first buy day has been added).
                 int temp=INT_MIN;
                 for(int i=1; i<=n; i++){
                    // Either skip day i-1, or sell on day i-1 after the best earlier buy.
                    dp[kk][i]=max(dp[kk][i-1], temp+prices[i-1]);
                    // Only now add day i-1 as a buy candidate, so it is used for later sell days.
                    temp=max(temp, dp[kk-1][i]-prices[i-1]);
                 }
             }
    

    Here is the final accepted (AC) code:

    class Solution {
    public:
        /**
         * Best total profit from at most k buy/sell transactions on `prices`.
         *
         *  best[t][d]: max profit within prices[0..d-1] using at most t transactions.
         *  best[t][d] = max( best[t][d-1], prices[d-1] + bestBuy ), where bestBuy is the
         *  largest best[t-1][j+1] - prices[j] over buy days j <= d-2.
         *  Base: best[0][*] = 0, best[*][0] = 0.
         */
        int maxProfit(int k, vector<int>& prices) {
            const int days = prices.size();
            if (days < 2) {
                return 0;
            }

            // With k > days/2 the transaction limit cannot bind: collect every positive daily gain.
            if (k > days / 2) {
                int total = 0;
                for (int d = 1; d < days; ++d) {
                    const int gain = prices[d] - prices[d - 1];
                    if (gain > 0) {
                        total += gain;
                    }
                }
                return total;
            }

            vector<vector<int>> best(k + 1, vector<int>(days + 1, 0));
            for (int t = 1; t <= k; ++t) {
                int bestBuy = INT_MIN;  // no buy day considered yet
                for (int d = 1; d <= days; ++d) {
                    const int price = prices[d - 1];
                    // Skip day d-1, or sell on it after the best earlier buy.
                    best[t][d] = max(best[t][d - 1], bestBuy + price);
                    // Day d-1 becomes a buy candidate only for later sell days.
                    bestBuy = max(bestBuy, best[t - 1][d] - price);
                }
            }
            return best[k][days];
        }
    };

  • 0
    2

    Here we can change the order of the temp update and the dp recurrence.

    Doing so makes the code clearer: temp is updated with day i-1 before it is used, so it covers every buy day j <= i-1.

    The key lines are:

                // Add day i-1 as a buy candidate first: temp = max over j <= i-1 of dp[count-1][j] - prices[j].
                temp = max(temp, dp[count-1][i-1]-prices[i-1]);
                // Either skip day i-1, or sell on it after the best buy (buying and selling on the same day adds 0).
                dp[count][i] = max(dp[count][i-1], temp + prices[i-1]);
    

    Here is my implementation:

    class Solution {
    public:
        // Max profit with at most k transactions, DP over (transactions allowed, day prefix).
        int maxProfit(int k, vector<int>& prices) {
            const int len = prices.size();
            if (len < 2) {
                return 0;
            }

            // k is large enough to take every upswing: sum each positive daily delta.
            if (k > len / 2) {
                int sum = 0;
                for (int d = 1; d < len; ++d) {
                    sum += max(prices[d] - prices[d - 1], 0);
                }
                return sum;
            }

            // table[t][d]: best profit within prices[0..d-1] using at most t transactions.
            vector<vector<int>> table(k + 1, vector<int>(len + 1, 0));
            for (int t = 1; t <= k; ++t) {
                const vector<int>& prev = table[t - 1];
                vector<int>& cur = table[t];
                // holdBest: max over buy days j <= d-1 of prev[j] - prices[j].
                int holdBest = INT_MIN;
                for (int d = 1; d <= len; ++d) {
                    const int price = prices[d - 1];
                    holdBest = max(holdBest, prev[d - 1] - price);
                    cur[d] = max(cur[d - 1], holdBest + price);
                }
            }
            return table[k][len];
        }
    };

  • 0
    F

    The DP recurrence is based on whether day i is part of the k-th transaction.


  • 0

    Here is an easier-to-understand solution:

    
    class Solution {
    public:
        /**
         * Maximum profit from at most k buy/sell transactions; O(n * k) time, O(k) space.
         *
         * local[j]:  best profit with at most j transactions where the last one sells on the current day.
         * global[j]: best profit with at most j transactions up to the current day.
         * Returns 0 when there are fewer than two prices or k <= 0.
         */
        int maxProfit(int k, vector<int>& prices) {
            // Work in signed int: comparing int k with the unsigned prices.size() would turn a
            // negative k into a huge value and wrongly take the unlimited-transaction path.
            const int n = static_cast<int>(prices.size());
            if (n < 2 || k <= 0) return 0;
            // With more than n/2 transactions every upswing can be taken independently.
            if (k > n / 2) return help(prices);
            vector<int> local(k + 1, 0);
            vector<int> global(k + 1, 0);
            for (int i = 0; i + 1 < n; i++) {
                const int diff = prices[i + 1] - prices[i];
                // Iterate j downward so global[j - 1] still holds the previous day's value.
                for (int j = k; j >= 1; j--) {
                    // Either start a new last transaction (on top of global[j - 1]) that captures
                    // today's gain if positive, or extend yesterday's last transaction by one day.
                    local[j] = max(global[j - 1] + max(diff, 0), local[j] + diff);
                    global[j] = max(local[j], global[j]);
                }
            }
            return global[k];
        }

        /** Sum of all positive day-to-day price increases (profit with unlimited transactions). */
        int help(vector<int> &prices) {
            int res = 0;
            for (size_t i = 1; i < prices.size(); ++i) {
                res += max(prices[i] - prices[i - 1], 0);
            }
            return res;
        }
    };

  • 0
    Z

    said in How to optimize the O(N^3) DP solution ?:

    class Solution {
    public:
    int maxProfit(int k, vector<int>& prices) {
    /***
    * dp[k][i]:the max profit we get during prices[0...i-1] with at most k transactions
    * dp[k][i] = max{ dp[k][i-1], dp[k-1][j] + prices[i-1]-prices[j] } 0<=j<i-1
    *
    * start : dp[0][i]=0; dp[k][0]=0
    *
    * NOTE(review): quoted verbatim from the original post. The j-loop below does not
    * take the maximum over all j (see the comment there); that is the bug this reply reports.
    ***/
    int n=prices.size();
    if(n<=1) return 0;

             /** to deal with the biggest cases **/
             if(k>n/2){
                 int result=0;
                 for(int i=1; i<n; i++)
                    result+=max(prices[i]-prices[i-1], 0);
                 return result;
             }
             
             vector<vector<int>> dp(k+1, vector<int>(n+1, 0));
             
             for(int kk=1; kk<=k; kk++){
                 for(int i=1; i<=n; i++){
                     // BUG (kept as quoted): each j overwrites dp[kk][i] with
                     // max(dp[kk][i-1], candidate_j), so only j = i-2 survives.
                     // Fix: start from dp[kk][i-1] and take max(dp[kk][i], candidate_j) for every j.
                     for(int j=0; j<i-1; j++)
                        dp[kk][i]=max(dp[kk][i-1], dp[kk-1][j]+prices[i-1]-prices[j]);
                 }
             }
             return dp[k][n];
        }
    };
    

    Your naive implementation is wrong. Try these two examples (k, then prices):
    3
    [1,3,2,5,7,4,9,11]
    2
    [6,1,3,2,4,7]
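    It returns 10 and 5 instead of 14 and 7, because the inner j-loop overwrites dp[kk][i] on every iteration instead of keeping the maximum over all j.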


Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.