Sharing my generic kSum function


  • 0
    M

    This approach solves the general k-sum problem for any k >= 1; fourSum below is simply the k = 4 case.

    // Generic k-sum solver: collects every unique combination of k values
    // (values, not indices) from the input that adds up to a target.
    // fourSum() is the k == 4 instance.
    class Solution {
    public:
        // Returns every unique quadruplet of values from nums that sums to target.
        // Side effect: nums is sorted in place.
        // Values inside each quadruplet are not guaranteed to be in ascending order.
        std::vector<std::vector<int>> fourSum(std::vector<int>& nums, int target) {
            std::vector<std::vector<int>> res;
            kSum(nums, 4, target, res);
            return res;
        }

    private:
        // Appends every unique k-value combination of v summing to target to
        // output and returns how many combinations were appended.
        // For k >= 2, v is sorted in place. For k < 1, or when v has fewer than
        // k elements, nothing is appended and 0 is returned.
        int kSum(std::vector<int>& v, int k, int target, std::vector<std::vector<int>>& output)
        {
            // k == 0 used to slip through to the pair search below; reject it explicitly.
            if (k < 1 || static_cast<int>(v.size()) < k)
            {
                return 0;
            }

            if (k == 1)
            {
                // A single matching value is enough; duplicates would form the same combination.
                for (const int x : v)
                {
                    if (x == target)
                    {
                        output.push_back({target});
                        return 1;
                    }
                }
                return 0;
            }

            // The two-pointer search and duplicate skipping both require sorted input.
            std::sort(v.begin(), v.end());

            if (k == 2)
            {
                return twoSum(v, 0, target, output);
            }

            return doKSum(v, -1, k, target, output);
        }

        // Searches the sorted vector v, using only indices greater than p, for
        // unique k-value combinations summing to target.
        // Each combination found is appended to the back of output; the count is returned.
        // target is long long because subtracting several large ints from it
        // during recursion would otherwise overflow int.
        int doKSum(const std::vector<int>& v, int p, int k, long long target,
                   std::vector<std::vector<int>>& output)
        {
            if (k <= 2)
            {
                return twoSum(v, p + 1, target, output);
            }

            int count = 0;
            const int n = static_cast<int>(v.size());
            for (int i = p + 1; i < n; ++i)
            {
                // Combinations for the remaining k - 1 slots land at the back of
                // output; tag each of them with v[i].
                const int c = doKSum(v, i, k - 1, target - v[i], output);
                const int l = static_cast<int>(output.size());
                for (int j = 0; j < c; ++j)
                {
                    output[l - j - 1].push_back(v[i]);
                }

                count += c;

                // Skip repeats of v[i] so the same value is not chosen twice for this slot.
                while (i + 1 < n && v[i + 1] == v[i])
                {
                    ++i;
                }
            }

            return count;
        }

        // Two-pointer scan of the sorted v[start..end] for unique pairs summing to target.
        // Each pair is appended to output; the number of pairs is returned.
        int twoSum(const std::vector<int>& v, int start, long long target,
                   std::vector<std::vector<int>>& output)
        {
            int end = static_cast<int>(v.size()) - 1;
            int count = 0;

            while (start < end)
            {
                // Widen before adding: two ints near the int limits overflow int.
                const long long sum = static_cast<long long>(v[start]) + v[end];
                if (sum < target)
                {
                    ++start;
                }
                else if (sum > target)
                {
                    --end;
                }
                else
                {
                    const int lo = v[start];
                    const int hi = v[end];
                    output.push_back({lo, hi});
                    ++count;

                    // Move past every copy of both values to avoid duplicate pairs.
                    while (start < end && v[start] == lo)
                    {
                        ++start;
                    }

                    while (start < end && v[end] == hi)
                    {
                        --end;
                    }
                }
            }

            return count;
        }
    };
    

Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.