C++ Solution with O(n + klgn) time using Max Heap and Stack


  • 79
    Y

    We can find all adjacent valley/peak pairs and calculate the profits easily. Instead of accumulating all these profits like Buy&Sell Stock II, we need the highest k ones.

    The key point is when there are two v/p pairs (v1, p1) and (v2, p2), satisfying v1 <= v2 and p1 <= p2, we can either make one transaction at [v1, p2], or make two at both [v1, p1] and [v2, p2]. The trick is to treat [v1, p2] as the first transaction, and [v2, p1] as the second. Then we can guarantee the right max profits in both situations, p2 - v1 for one transaction and p1 - v1 + p2 - v2 for two.
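
    A quick numeric check of the identity above may help. This is only a sketch: the function names `oneTransaction`, `twoTransactions`, and `imaginaryProfit` and the sample prices are made up for illustration and are not part of the solution.

```cpp
#include <cassert>

// Profit of the single merged transaction [v1, p2].
int oneTransaction(int v1, int p2) { return p2 - v1; }

// Profit of the two separate transactions [v1, p1] and [v2, p2].
int twoTransactions(int v1, int p1, int v2, int p2) {
    assert(v1 <= v2 && p1 <= p2);  // the case discussed in the paragraph above
    return (p1 - v1) + (p2 - v2);
}

// "Profit" of the imaginary transaction [v2, p1]: exactly what the second
// real transaction adds on top of the merged one.
int imaginaryProfit(int p1, int v2) { return p1 - v2; }
```

    With v1 = 1, p1 = 4, v2 = 2, p2 = 7, the merged transaction earns 6 and the imaginary one earns 2. Their sum, 8, equals the 3 + 5 earned by the two real transactions.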

    Finding all v/p pairs and calculating the profits takes O(n) since there are up to n/2 such pairs. And extracting k maximums from the heap consumes another O(klgn).

    class Solution {
    public:
        int maxProfit(int k, vector<int> &prices) {
            int n = (int)prices.size(), ret = 0, v, p = 0;
            priority_queue<int> profits;
            stack<pair<int, int> > vp_pairs;
            while (p < n) {
                // find next valley/peak pair
                for (v = p; v < n - 1 && prices[v] >= prices[v+1]; v++);
                for (p = v + 1; p < n && prices[p] >= prices[p-1]; p++);
                // save profit of 1 transaction at last v/p pair, if current v is lower than last v
                while (!vp_pairs.empty() && prices[v] < prices[vp_pairs.top().first]) {
                    profits.push(prices[vp_pairs.top().second-1] - prices[vp_pairs.top().first]);
                    vp_pairs.pop();
                }
                // save profit difference between 1 transaction (last v and current p) and 2 transactions (last v/p + current v/p),
                // if current v is higher than last v and current p is higher than last p
                while (!vp_pairs.empty() && prices[p-1] >= prices[vp_pairs.top().second-1]) {
                    profits.push(prices[vp_pairs.top().second-1] - prices[v]);
                    v = vp_pairs.top().first;
                    vp_pairs.pop();
                }
                vp_pairs.push(pair<int, int>(v, p));
            }
            // save profits of the rest v/p pairs
            while (!vp_pairs.empty()) {
                profits.push(prices[vp_pairs.top().second-1] - prices[vp_pairs.top().first]);
                vp_pairs.pop();
            }
            // sum up first k highest profits
            for (int i = 0; i < k && !profits.empty(); i++) {
                ret += profits.top();
                profits.pop();
            }
            return ret;
        }
    };

  • 8
    Z

    The trick is a genius idea. But since the program keeps pushing the new profit into the priority queue (i.e., the max heap), its time complexity should be O(nlog(n) + klog(n)) given that each insertion of a heap costs O(log(n)). IMHO, a better way would be to push the new profit into a vector and then heapify the vector, which will reduce the total time complexity to O(n + k*log(n)).
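
    For reference, the "collect into a vector, then heapify" idea could look roughly like the sketch below. The helper name `sumOfKLargest` is hypothetical and the function is written as a standalone piece, not as part of the class that follows.

```cpp
#include <algorithm>
#include <vector>

// Sum of the k largest values in `profits`, using a heap built in one pass.
// std::make_heap is linear in the number of elements and each std::pop_heap
// is logarithmic, so the total is O(n + k log n).
int sumOfKLargest(std::vector<int> profits, int k) {
    std::make_heap(profits.begin(), profits.end());  // max-heap by default
    int sum = 0;
    auto heapEnd = profits.end();
    for (int i = 0; i < k && heapEnd != profits.begin(); ++i) {
        std::pop_heap(profits.begin(), heapEnd);  // moves the current max to heapEnd - 1
        --heapEnd;
        sum += *heapEnd;
    }
    return sum;
}
```

    The vector is taken by value so the caller's data is left untouched. If k exceeds the number of profits, everything is summed.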

    class Solution 
    {
    public:
        // We can find all adjacent valley/peak pairs and calculate the profits easily. 
        // Instead of accumulating all these profits like Buy&Sell Stock II, we need 
        // the highest k ones.
        // 
        // The key point is when there are two v/p pairs (v1, p1) and (v2, p2), satisfying 
        // v1 <= v2 and p1 <= p2, we can either make one transaction at [v1, p2], or make 
        // two at both [v1, p1] and [v2, p2]. The trick is to treat [v1, p2] as the first 
        // transaction, and [v2, p1] as the second. Then we can guarantee the right max 
        // profits in both situations, p2 - v1 for one transaction and p1 - v1 + p2 - v2 
        // for two.
        // 
        // Finding all v/p pairs and calculating the profits takes O(n) since there are 
        // up to n/2 such pairs. And extracting k maximums from the heap consumes another O(k*log(n)).
        int maxProfit(int k, vector<int> &prices) 
        {
            int ret = 0;
            int n = prices.size(); 
            int v = 0;  // valley index 
            int p = 0;  // peak index + 1;
            
            vector<int> profits;
            stack<pair<int, int>> vp_pairs;
            while (p < n) 
            {
                // Find next valley/peak pair.
                for (v = p; (v < n - 1) && (prices[v] >= prices[v + 1]); v++);
                for (p = v + 1; (p < n) && (prices[p] >= prices[p - 1]); p++);
                
                // Save profit of 1 transaction at last v/p pair, if current v is lower than last v.
                while (!vp_pairs.empty() && (prices[v] < prices[vp_pairs.top().first]))
                {
                    profits.push_back(prices[vp_pairs.top().second - 1] - prices[vp_pairs.top().first]);
                    vp_pairs.pop();
                }
                
                // Save profit difference between 1 transaction (last v and current p) and 2 transactions 
                // (last v/p + current v/p), if current v is higher than last v and current p is higher 
                // than last p.
                while (!vp_pairs.empty() && (prices[p - 1] >= prices[vp_pairs.top().second - 1])) 
                {
                    profits.push_back(prices[vp_pairs.top().second - 1] - prices[v]);
                    v = vp_pairs.top().first;
                    vp_pairs.pop();
                }
                
                vp_pairs.push(pair<int, int>(v, p));
            }
            
            // Save profits of the remaining v/p pairs.
            while (!vp_pairs.empty()) 
            {
                profits.push_back(prices[vp_pairs.top().second - 1] - prices[vp_pairs.top().first]);
                vp_pairs.pop();
            }
            
            if (k >= profits.size())
            {
                // Since we have no more than k profit pairs, the result is the sum of all pairs.
                ret = accumulate(profits.begin(), profits.end(), 0);
            }
            else
            {
                // Move the k highest profits to the end; nth_element takes O(n) time on average.
                nth_element(profits.begin(), profits.begin() + profits.size() - k, profits.end());
                // Sum up the k highest profits.
                ret = accumulate(profits.begin() + profits.size() - k, profits.end(), 0);
            }
            
            return ret;
        }
    };

  • 0
    Y

    Test cases are not strong enough. This algorithm fails for k = 1, prices = {1,7,2,8,3,9}: the answer should be 8, but the output is 7. Only the case that requires a single merge of intervals is handled, whereas here all three v/p pairs should be merged.


  • 0
    Y

    You are right. Building a heap by inserting elements one by one costs O(nlgn) in the worst case. Thank you for pointing it out. Fantastic use of std::make_heap and std::pop_heap, by the way :-)


  • 0
    Y

    As I tested, my code has the right output (8) in your case.
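
    For anyone who wants to reproduce this, here is a condensed standalone restatement of the approach from this thread. It is a sketch, not the exact code posted above: the name `maxProfitSketch` is made up, and the final selection step uses std::nth_element instead of a heap.

```cpp
#include <algorithm>
#include <functional>
#include <numeric>
#include <stack>
#include <utility>
#include <vector>

int maxProfitSketch(int k, const std::vector<int>& prices) {
    const int n = static_cast<int>(prices.size());
    std::vector<int> profits;
    std::stack<std::pair<int, int>> vp;  // (valley index, peak index + 1)
    int v = 0, p = 0;
    while (p < n) {
        for (v = p; v < n - 1 && prices[v] >= prices[v + 1]; ++v) {}
        for (p = v + 1; p < n && prices[p] >= prices[p - 1]; ++p) {}
        // Current valley is lower: the pair on top can no longer be merged.
        while (!vp.empty() && prices[v] < prices[vp.top().first]) {
            profits.push_back(prices[vp.top().second - 1] - prices[vp.top().first]);
            vp.pop();
        }
        // Current peak is at least as high: merge and record the imaginary profit.
        while (!vp.empty() && prices[p - 1] >= prices[vp.top().second - 1]) {
            profits.push_back(prices[vp.top().second - 1] - prices[v]);
            v = vp.top().first;
            vp.pop();
        }
        vp.push({v, p});
    }
    while (!vp.empty()) {
        profits.push_back(prices[vp.top().second - 1] - prices[vp.top().first]);
        vp.pop();
    }
    if (k <= 0) return 0;
    if (k < static_cast<int>(profits.size())) {
        // Partition so the k largest profits occupy the first k slots.
        std::nth_element(profits.begin(), profits.begin() + k, profits.end(),
                         std::greater<int>());
        profits.resize(k);
    }
    return std::accumulate(profits.begin(), profits.end(), 0);
}
```

    On {1,7,2,8,3,9} the collected profits are 5, 5, and 8: two merges happen, each recording an imaginary profit of 5, and the fully merged transaction [1, 9] contributes 8. So k = 1 gives 8, k = 2 gives 13, and k = 3 gives 18.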


  • 0
    Y

    Sorry, I was too hasty last time. You're right :P


  • 0
    C

    Here is another solution for the second part. Instead of a heap, I use a selection (rank) algorithm.

    Since the LeetCode OJ doesn't support the Random class, the pivot is fixed, so the algorithm is O(N) on average and O(N^2) in the worst case. With a random pivot, the expected running time would be O(N).

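    import java.awt.Point;      // assumed source of Point (public int x, y)
    import java.util.ArrayList;
    import java.util.Stack;

    public class Solution {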
    public int maxProfit(int k, int[] prices) {
        Stack<Point> vp_pairs = new Stack<Point>();
        ArrayList<Integer> profits = new ArrayList<Integer>();
        int n = prices.length, v = 0, p = 0;
        while(p < n){
            for(v = p; (v < n - 1) && (prices[v] >= prices[v + 1]); v++);
            for(p = v + 1; (p < n) && (prices[p] >= prices[p - 1]); p++);
            while(!vp_pairs.isEmpty() && (prices[v] < prices[vp_pairs.peek().x])){
                profits.add(prices[vp_pairs.peek().y - 1] - prices[vp_pairs.peek().x]);
                vp_pairs.pop();
            }
            while(!vp_pairs.isEmpty() && (prices[p - 1] >= prices[vp_pairs.peek().y - 1])){
                profits.add(prices[vp_pairs.peek().y - 1] - prices[v]);
                v = vp_pairs.peek().x;
                vp_pairs.pop();
            }
            vp_pairs.push(new Point(v, p));
        }
        while(!vp_pairs.isEmpty()){
            Point pt = vp_pairs.pop();
            profits.add(prices[pt.y - 1] - prices[pt.x]);
        }
    
        if(k < profits.size())
            selectionRank(profits, k);
        int res = 0;
        for(int i = 0; i < k && i < profits.size(); i++)
            res += profits.get(i);
        return res;
    }
    
    public int Partition(ArrayList<Integer> profits, int head, int end){
        int pivot = end - 1;
        int p = head - 1;
        for(int i = head; i < end; i++){
            if(profits.get(i) > profits.get(pivot)){
                p++;
                int tmp = profits.get(i);
                profits.set(i, profits.get(p));
                profits.set(p, tmp);
            }
        }
        p++;
        int tmp = profits.get(pivot);
        profits.set(pivot, profits.get(p));
        profits.set(p, tmp);
        return p;
    }
    
    public void selectionRank(ArrayList<Integer> profits, int k){
        int head = 0, tail = profits.size();
        int cur_index = -1;
        while(cur_index != k){
            cur_index = Partition(profits, head, tail);
            if(cur_index > k){
                tail = cur_index;
            }else if(cur_index < k)
                head = cur_index + 1;
        }
    }
    

    }
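
    For comparison with the C++ posts in this thread, a randomized-pivot version of the same selection idea might look like the sketch below. It is written in C++ rather than Java, and the names `selectKLargest` and `sumTopK` and the `seed` parameter are hypothetical.

```cpp
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <random>
#include <utility>
#include <vector>

// Rearranges `a` so that its k largest values occupy a[0..k-1] (in any order).
// Picking the pivot at random gives expected linear time.
void selectKLargest(std::vector<int>& a, std::size_t k, std::mt19937& rng) {
    if (k == 0 || k >= a.size()) return;
    std::size_t lo = 0, hi = a.size();  // active range is [lo, hi)
    while (hi - lo > 1) {
        std::uniform_int_distribution<std::size_t> pick(lo, hi - 1);
        std::swap(a[pick(rng)], a[hi - 1]);  // move the random pivot to the end
        const int pivot = a[hi - 1];
        std::size_t store = lo;
        for (std::size_t i = lo; i + 1 < hi; ++i) {
            if (a[i] > pivot) std::swap(a[i], a[store++]);  // larger values go first
        }
        std::swap(a[store], a[hi - 1]);  // pivot lands at its final index
        if (store == k || store + 1 == k) return;  // first k slots hold the k largest
        if (store > k) hi = store; else lo = store + 1;
    }
}

// Convenience wrapper: sum of the k largest values.
int sumTopK(std::vector<int> a, std::size_t k, unsigned seed = 12345u) {
    std::mt19937 rng(seed);
    selectKLargest(a, k, rng);
    const auto count = static_cast<std::ptrdiff_t>(std::min<std::size_t>(k, a.size()));
    return std::accumulate(a.begin(), a.begin() + count, 0);
}
```

    The partition step mirrors the Java `Partition` above (values greater than the pivot are moved to the front); only the pivot choice differs.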


  • 2

    How about using median-of-three to select the pivot?
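
    A minimal sketch of what a median-of-three pivot choice could look like, in C++ and with a hypothetical helper name `medianOfThreeIndex`. It only picks the pivot index; plugging it into a partition routine is left out.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Returns the index holding the median of the first, middle, and last values
// of the half-open range [lo, hi).
std::size_t medianOfThreeIndex(const std::vector<int>& a, std::size_t lo, std::size_t hi) {
    assert(lo < hi && hi <= a.size());
    std::size_t x = lo, y = lo + (hi - lo) / 2, z = hi - 1;
    if (a[x] > a[y]) std::swap(x, y);
    if (a[y] > a[z]) std::swap(y, z);
    if (a[x] > a[y]) std::swap(x, y);
    return y;  // after the three compare-and-swaps, a[x] <= a[y] <= a[z]
}
```

    This avoids the quadratic behavior of a fixed last-element pivot on already sorted input, but the worst case is still O(N^2) on adversarial inputs.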


  • 0
    L

    Can someone explain why this is a DP problem? What's the subproblem? Thanks.
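
    One common way to see the DP, sketched here as the textbook O(n * k) formulation rather than the approach used in this thread: the subproblem is "best profit so far, given how many transactions have been used and whether a share is currently held". The function name `maxProfitDP` is hypothetical.

```cpp
#include <algorithm>
#include <climits>
#include <vector>

// cash[j]: best profit with at most j completed transactions, holding no stock.
// hold[j]: best profit while holding a share bought as part of transaction j.
int maxProfitDP(int k, const std::vector<int>& prices) {
    if (k <= 0 || prices.empty()) return 0;
    // More than n / 2 transactions can never help.
    k = std::min(k, static_cast<int>(prices.size() / 2));
    if (k == 0) return 0;
    std::vector<int> cash(k + 1, 0), hold(k + 1, INT_MIN);
    for (int price : prices) {
        for (int j = 1; j <= k; ++j) {
            hold[j] = std::max(hold[j], cash[j - 1] - price);  // buy for transaction j
            cash[j] = std::max(cash[j], hold[j] + price);      // sell transaction j
        }
    }
    return cash[k];
}
```

    Each state depends only on the states from the previous day, which is the overlapping-subproblem structure. The valley/peak solution in this thread gets the same answers without filling that table.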


  • 0
    T

    I was thinking along the same lines: basically trying to merge all the existing continuous increasing sequences into K sequences. But then there is the case where some of the raw continuous increasing sequences could be ignored...

    Could you elaborate on the idea a bit more? I'd hate to read too many "spoilers" in the code and lose this chance to learn. Thanks!


  • 0
    U

    Awesome! Thank you for sharing such a good solution!

    I think it will be better to use http://en.cppreference.com/w/cpp/algorithm/nth_element.


  • 0

    I suggest using nth_element() and accumulate() rather than make_heap() and pop_heap() to compute the sum of the k largest elements.


  • 0
    Z

    Makes sense. Thanks, zhiqing, for the tips! I have updated the code to use nth_element() and accumulate().


  • 0

    I think this is a greedy approach, not DP.


  • 0
    P

    How can (v2, p1) be considered a transaction opportunity? I think v2 always happens after p1 (we cannot sell stock before buying it). Can you please explain?


  • 0
    S

    Brilliant answer!!!
    nth_element() is a good trick!!


  • 0
    S

    The trick is a beautiful idea. I saw it here and wrote my own solution.

    class Solution {
    public:
        int maxProfit(int k, vector<int>& prices) {
            vector<int> main_prices;
            int direct = -1, cl = 1;
            for (int i = 0; i + 1 < prices.size(); i += cl) {
                cl = 1;
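                while (i + cl < prices.size() && prices[i + cl] == prices[i]) cl++;  // skip a plateau without reading past the end
                if (i + cl >= prices.size()) break;  // trailing plateau: nothing left to compare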
                if ((prices[i + cl] - prices[i]) * direct < 0) {
                    main_prices.push_back(prices[i]);
                    direct *= -1;
                }
            }
            if (direct == 1)
                main_prices.push_back(prices[prices.size() - 1]);
            if (main_prices.empty())
                return 0;
            vector<int> profits;
            stack<int> l, h;
            for (int i = 0; i < main_prices.size(); i += 2) {
                while (l.size() && l.top() > main_prices[i]) {
                    profits.push_back(h.top() - l.top());
                    h.pop();
                    l.pop();
                }
    
                while (l.size() && h.top() < main_prices[i + 1]) {
                    profits.push_back(h.top() - main_prices[i]);
                    h.pop();
                    main_prices[i] = l.top();
                    l.pop();
                }
                l.push(main_prices[i]);
                h.push(main_prices[i + 1]);
            }
            while (l.size()) {
                profits.push_back(h.top() - l.top());
                h.pop();
                l.pop();
            }
            if (k >= profits.size())
                return accumulate(profits.begin(), profits.end(), 0);
            nth_element(profits.begin(), profits.end() - k, profits.end());
            return accumulate(profits.end() - k, profits.end(), 0);        
        }
    };
    

    Main differences: average time O(n), 8 ms, and different methods for finding the k highest profits and the peaks and valleys.
    Can anybody please explain how these two lines work? Why v+1 but p-1? Must they be a local min and max?

    for (v = p; v < n - 1 && prices[v] >= prices[v+1]; v++);
    for (p = v + 1; p < n && prices[p] >= prices[p-1]; p++);
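
    One way to read those two loops, as far as I can tell: the first looks ahead (v+1) to decide whether to keep walking downhill, and the second looks back (p-1) because p ends up one past the peak. The sketch below just lists the pairs they find; the function name `valleyPeakPairs` is made up for illustration.

```cpp
#include <utility>
#include <vector>

// Lists (valley index, peak index) pairs found by the two scanning loops.
std::vector<std::pair<int, int>> valleyPeakPairs(const std::vector<int>& prices) {
    const int n = static_cast<int>(prices.size());
    std::vector<std::pair<int, int>> pairs;
    int v = 0, p = 0;
    while (p < n) {
        // Step forward while the NEXT price is not higher: v stops at a local minimum.
        for (v = p; v < n - 1 && prices[v] >= prices[v + 1]; ++v) {}
        // Step forward while the CURRENT price is not lower than the previous one:
        // p stops one past the local maximum, so the peak itself is at p - 1.
        for (p = v + 1; p < n && prices[p] >= prices[p - 1]; ++p) {}
        pairs.push_back({v, p - 1});
    }
    return pairs;
}
```

    For {1,7,2,8,3,9} this yields (0,1), (2,3), (4,5). For {2,4,1} the last pair is (2,2), a zero-profit pair where valley and peak coincide.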

  • 0

    LeetCode does support the Random class. You just either need to address it as java.util.Random or add import java.util.Random; before the class declaration.


  • 1

    @praneeth1209, took me a while to figure it out too. (v2, p1) is an imaginary transaction. Its “profit” is p1 - v2, which, if added to the profit of the (v1, p2) transaction, gives exactly p1 - v2 + p2 - v1 == p1 - v1 + p2 - v2. That is, the same profit we get from two regular transactions performed in sequence. Of course it doesn't make sense without the first transaction. But because of the v1 <= v2 && p1 <= p2 condition, we have p1 - v2 <= p2 - v1 (draw a picture to understand it better) and therefore, it almost guarantees that we perform the first transaction beforehand (unless the profits are equal, but then the order obviously doesn't matter).


  • 2

    A really awesome idea. My own implementation in Java runs in 3-4 ms (using arrays for stacks and randomized quickselect for calculating the profit). I have written a very detailed explanation of it in my blog (I describe three solutions there, scroll down to the third one) for those who have a hard time figuring out all the tricks and cases and loop invariants.

