Fastest Python solution 180 ms


    "Your runtime beats 100.00% of python submissions." I might never see that message again when I submit a solution, so it seemed worth posting about.

    First I made a dictionary to count the occurrences of each input value. Observe that a single triplet can never use more than two instances of any given nonzero value (three copies of x sum to 3x, which is not 0), so additional instances can be discarded. A triplet can use three zeros, but if we only have two zeros to work with, we can't use them together (the third value would also have to be 0), so if we have exactly two zeros we can discard one of them.
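A minimal sketch of this counting step, assuming the hypothetical helper name `capped_counts` and using `collections.Counter` in place of the hand-built dictionary. Unlike the code below, it folds the "two zeros are no better than one" rule directly into the counting pass.

```python
from collections import Counter


def capped_counts(nums):
    """Count each value, keeping only as many instances as a triplet could use."""
    capped = {}
    for n, c in Counter(nums).items():
        if n == 0:
            # Zeros are only useful all together as (0, 0, 0), so a pair of
            # zeros is worth no more than a single zero.
            capped[n] = 3 if c >= 3 else 1
        else:
            # A nonzero value can appear at most twice in one triplet.
            capped[n] = min(c, 2)
    return capped


print(capped_counts([-1, 0, 1, 2, -1, -4, 0, 2, 2]))
```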

    On a second pass through the dictionary, I checked every value x that occurs twice to see whether -2x also occurs. If not, the two copies of x can never be used together, so I discarded one of them.
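The second pass can be sketched on its own as follows; `drop_unusable_pairs` is a hypothetical name, and the input is assumed to be a dictionary of counts already capped at 2 (or 3 for zero).

```python
def drop_unusable_pairs(counts):
    """Keep two copies of x only if -2*x is also present."""
    result = dict(counts)
    for x, c in counts.items():
        # A pair (x, x) needs -2*x as its third value. For x == 0 that third
        # value is 0 itself, which only works with three zeros, so a pair of
        # zeros is reduced to one as well.
        if c == 2 and (x == 0 or -2 * x not in counts):
            result[x] = 1
    return result


print(drop_unusable_pairs({-1: 2, 2: 1, 3: 2, 0: 2}))
```

Here -1 keeps both copies because 2 is present, while 3 (no -6) and 0 (only two zeros) drop to one copy each.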

    Then I used the dictionary to build a sorted list of values, with duplicates representing values that occur twice (for nonzero values) or three times (for zero). Every value in a triplet equals the negated sum of two other entries in the list, so the sum of the smallest two entries implies a ceiling on the largest useful value, while the sum of the largest two entries implies a floor on the smallest useful value: -(largest + second largest) <= v <= -(smallest + second smallest). Values outside those bounds can be removed, and I repeated this process until an iteration removed nothing. Example of successive iterations:
    [-10,-9,-5,-2,0,1,1,3,6,21,52] floor = -73, ceiling = 19 (remove 21,52)
    [-10,-9,-5,-2,0,1,1,3,6] floor = -9, ceiling = 19 (remove -10)
    [-9,-5,-2,0,1,1,3,6] floor = -9, ceiling = 14 (done)
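The pruning loop can be sketched compactly, assuming a hypothetical `prune` helper that takes the sorted list (with its allowed duplicates). It filters with a list comprehension instead of scanning for the cut indices as the full solution does, which gives the same result on a sorted list.

```python
def prune(values):
    """Repeatedly drop entries of a sorted list that lie outside [floor, ceiling]."""
    while len(values) >= 3:
        floor = -(values[-1] + values[-2])   # smallest value that can still sum to 0
        ceiling = -(values[0] + values[1])   # largest value that can still sum to 0
        kept = [v for v in values if floor <= v <= ceiling]
        if len(kept) == len(values):
            break                            # nothing removed: bounds are stable
        values = kept
    return values


print(prune([-10, -9, -5, -2, 0, 1, 1, 3, 6, 21, 52]))
```

On the example above this removes 21 and 52, then -10, and stops at [-9, -5, -2, 0, 1, 1, 3, 6].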

    Finally, I looped through the list of values to find pairs (v1, v2) with v1 <= v2 and looked up the corresponding v3 = -(v1 + v2) in the dictionary, keeping only triplets with v1 <= v2 <= v3. Except for (0,0,0), which I treated as a special case, this ordering implies v1 < 0, so we break out of the outer loop as soon as v1 >= 0. Similarly, we break out of the inner loop as soon as v3 < v2, since v3 only decreases as v2 increases. When v3 == v2, a second instance of that value must be available. Also, since the list may contain duplicates, both the outer and inner loops skip any iteration whose value repeats the one just examined, which keeps duplicate triplets out of the result.
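The search step on its own might look like the sketch below. `find_triplets` is a hypothetical name; it assumes `values` is the sorted, pruned list and `counts` maps each value to the number of instances a triplet may use.

```python
def find_triplets(values, counts):
    """Return every [v1, v2, v3] with v1 <= v2 <= v3 and v1 + v2 + v3 == 0."""
    result = []
    if counts.get(0, 0) >= 3:
        result.append([0, 0, 0])          # the only triplet with v1 >= 0
    n = len(values)
    for i in range(n - 2):
        v1 = values[i]
        if v1 >= 0:
            break                         # v1 <= v2 <= v3 with sum 0 forces v1 < 0
        if i > 0 and v1 == values[i - 1]:
            continue                      # same v1 as the previous iteration
        for j in range(i + 1, n - 1):
            v2 = values[j]
            if j > i + 1 and v2 == values[j - 1]:
                continue                  # same v2 as the previous iteration
            v3 = -(v1 + v2)
            if v3 < v2:
                break                     # v3 only shrinks as v2 grows
            # v3 == v2 requires a second instance of that value.
            if v3 in counts and (v3 > v2 or counts[v3] > 1):
                result.append([v1, v2, v3])
    return result


print(find_triplets([-1, -1, 0, 1, 2], {-1: 2, 0: 1, 1: 1, 2: 1, -4: 1}))
```

For the input [-1, 0, 1, 2, -1, -4], pruning leaves [-1, -1, 0, 1, 2], and the search yields [-1, -1, 2] and [-1, 0, 1].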

    class Solution(object):
        def threeSum(self, nums):
            """
            :type nums: List[int]
            :rtype: List[List[int]]
            """
            # It's pointless to keep more than two instances of any number
            # other than 0, or more than three instances of 0, so the extra
            # instances are never counted.
            instances = {}
            for n in nums:
                if n in instances:
                    count = instances[n]
                    if count == 1 or (n == 0 and count == 2):
                        instances[n] += 1
                else:
                    instances[n] = 1

            # Remove useless duplicates. Three 0's are always useful, but two
            # 0's are not, because no third value sums with them to 0. When
            # count == 3 the value must be 0, so leave it alone; otherwise 0
            # gets no exception. For any other value n, count == 2 is only
            # useful when -2n is also available.
            for n, count in list(instances.items()):
                if count == 2 and (n == 0 or -2 * n not in instances):
                    instances[n] = 1

            # Create a sorted list of values, repeating each value as many
            # times as a triplet may use it.
            values = []
            for n, count in sorted(instances.items()):
                values.extend([n] * count)

            # Repeatedly discard values outside [floor, ceiling] until an
            # iteration removes nothing.
            nvalues = len(values)
            while nvalues >= 4:
                floor = -(values[nvalues - 1] + values[nvalues - 2])
                ceiling = -(values[0] + values[1])
                if floor > ceiling:
                    return []
                iLeft = nvalues
                iRight = -1
                for i in range(nvalues):
                    if values[i] >= floor:
                        iLeft = i
                        break
                for i in range(nvalues - 1, -1, -1):
                    if values[i] <= ceiling:
                        iRight = i
                        break
                if iLeft == 0 and iRight == nvalues - 1:
                    break
                values = values[iLeft:iRight + 1]
                nvalues = len(values)
            if nvalues < 3:
                return []

            result = []
            # Special case for (0, 0, 0); otherwise v1 must be negative.
            if instances.get(0) == 3:
                result.append([0, 0, 0])
            for i in range(nvalues - 2):
                v1 = values[i]
                if v1 >= 0:
                    break
                if i > 0 and v1 == values[i - 1]:
                    continue  # same v1 as the previous iteration
                for j in range(i + 1, nvalues - 1):
                    v2 = values[j]
                    if j > i + 1 and v2 == values[j - 1]:
                        continue  # same v2 as the previous iteration
                    v3 = -(v1 + v2)
                    if v3 < v2:
                        break
                    # v3 == v2 needs a second instance of that value.
                    if v3 in instances:
                        if v3 > v2 or instances[v3] > 1:
                            result.append([v1, v2, v3])
            return result

    I cleaned up the code a little and got exactly the same 180 ms runtime, but I think this version is an improvement. Now, on the second pass through the dictionary (in sorted order), I immediately copy each value, without duplicates, into an ordered list, and I add any sums involving duplicate elements to my result right away. This eliminates an intermediate pass through the dictionary. Also, the dictionary now stores the full count of each value from the original list without modification: once all sums involving duplicates are handled, subsequent queries only care about whether a key is present, and the associated counts no longer matter. If nothing else, the code should be a little easier to follow now.
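The new single sorted pass can be sketched in isolation. The helper name `duplicate_triplets` is hypothetical, and `collections.Counter` stands in for the hand-built count dictionary; it returns the distinct sorted values together with every triplet that reuses a value.

```python
from collections import Counter


def duplicate_triplets(nums):
    """Return (sorted distinct values, triplets that use some value twice or more)."""
    counts = Counter(nums)
    values, result = [], []
    for n, count in sorted(counts.items()):
        values.append(n)
        if n == 0 and count >= 3:
            result.append([0, 0, 0])
        elif n != 0 and count >= 2 and -2 * n in counts:
            # Pair (n, n) plus its only possible partner -2n, in sorted order.
            result.append(sorted([n, n, -2 * n]))
    return values, result


print(duplicate_triplets([0, 0, 0, 3, 3, -6, -6]))
```

After this pass, every remaining triplet consists of three distinct values, so the later search can require v1 < v2 < v3 strictly.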

    class Solution(object):
        def threeSum(self, nums):
            """
            :type nums: List[int]
            :rtype: List[List[int]]
            """
            # Count every value, keeping the full counts.
            instances = {}
            for n in nums:
                instances[n] = instances.get(n, 0) + 1

            # Build the sorted list of distinct values, and record right away
            # every triplet that uses some value more than once.
            values = []
            result = []
            for n, count in sorted(instances.items()):
                values.append(n)
                if n == 0 and count >= 3:
                    result.append([0, 0, 0])
                elif n != 0 and count >= 2:
                    third = -2 * n
                    if third in instances:
                        if n < third:
                            result.append([n, n, third])
                        else:
                            result.append([third, n, n])
            # Any sums involving duplicate values were handled above, so from
            # here on only the presence of a key in `instances` matters.
            nvalues = len(values)
            while nvalues >= 3:
                floor = -(values[nvalues - 1] + values[nvalues - 2])
                ceiling = -(values[0] + values[1])
                if floor > ceiling:
                    # Keep the triplets already found with duplicate values.
                    return result
                iLeft = nvalues
                iRight = -1
                for i in range(nvalues):
                    if values[i] >= floor:
                        iLeft = i
                        break
                for i in range(nvalues - 1, -1, -1):
                    if values[i] <= ceiling:
                        iRight = i
                        break
                if iLeft == 0 and iRight == nvalues - 1:
                    break
                values = values[iLeft:iRight + 1]
                nvalues = len(values)
            if nvalues < 3:
                return result

            # Values are distinct now, so only v1 < v2 < v3 is possible.
            for i in range(nvalues - 2):
                v1 = values[i]
                if v1 >= 0:
                    break
                for j in range(i + 1, nvalues - 1):
                    v2 = values[j]
                    v3 = -(v1 + v2)
                    if v3 <= v2:
                        break
                    if v3 in instances:
                        result.append([v1, v2, v3])
            return result

    Thanks for your code. To be honest, it took me a lot of time to understand your method.