Python DP+Memoization gets TLE but Java passes. Can this be fixed?


  •
    class Solution(object):
        def helper(self, i, nums, m, cache):
            if i == len(nums):
                return 0
            elif m == 1:
                return sum(nums[i:])
            else:
                if i in cache and m in cache[i]:
                    return cache[i][m]
                cache.setdefault(i, {})
                cache[i][m] = float('inf')
                for j in range(1, len(nums) - i + 1):  # j = length of the first part; stop at the end of nums
                    left, right = sum(nums[i:i+j]), self.helper(i+j, nums, m-1, cache)
                    cache[i][m] = min(cache[i][m], max(left, right))
                return cache[i][m]
        
        def splitArray(self, nums, m):
            """
            :type nums: List[int]
            :type m: int
            :rtype: int
            """
            cache = {}
            return self.helper(0, nums, m, cache)
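Here is a rough count of why the helper above times out. It is a back-of-the-envelope estimate, not a measurement, with n = len(nums) and m the number of parts. There are at most n·m cached (i, m) states. Each state loops over up to n values of j, and each iteration rebuilds `sum(nums[i:i+j])` from scratch in O(j) time:

```latex
\underbrace{n\,m}_{\text{states }(i,\,m)}
\;\times\;
\underbrace{\sum_{j=1}^{n} O(j)}_{\text{loop over } j \text{ with slice sums}}
\;=\; n\,m \cdot O(n^{2}) \;=\; O(n^{3} m)
```

With prefix sums, each subarray sum becomes an O(1) lookup, which brings the total down to O(n^2 m).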
    

  •

    @tarun6 I tried your solution but it doesn't finish the test case within 10 seconds. Could you please provide the equivalent Java solution here?


  •

    The DP solution in this thread uses exactly the same idea: https://discuss.leetcode.com/topic/61405/dp-java. I use memoization, while that one uses bottom-up DP.
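To compare with the memoized helper, here is one way a bottom-up table for the same recurrence could look. This is my own illustration, not the code from the linked post. `split_array_bottom_up` is a hypothetical name, and the sketch assumes 1 <= m <= len(nums) and non-negative nums. It also uses prefix sums so that each subarray sum is O(1).

```python
from typing import List


def split_array_bottom_up(nums: List[int], m: int) -> int:
    """Minimize the largest subarray sum when splitting nums into m parts.

    dp[k][j] = smallest possible largest-part sum when the first j
    elements are split into k non-empty contiguous parts.
    Assumes 1 <= m <= len(nums).
    """
    n = len(nums)
    # prefix[j] = nums[0] + ... + nums[j - 1]
    prefix = [0] * (n + 1)
    for idx, x in enumerate(nums):
        prefix[idx + 1] = prefix[idx] + x

    inf = float("inf")
    dp = [[inf] * (n + 1) for _ in range(m + 1)]
    dp[0][0] = 0  # zero elements split into zero parts
    for k in range(1, m + 1):
        for j in range(k, n + 1):  # need at least k elements for k parts
            best = inf
            # i = end of the first k - 1 parts; the last part is nums[i:j]
            for i in range(k - 1, j):
                candidate = max(dp[k - 1][i], prefix[j] - prefix[i])
                if candidate < best:
                    best = candidate
            dp[k][j] = best
    return dp[m][n]
```

For example, `split_array_bottom_up([7, 2, 5, 10, 8], 2)` splits the input as `[7, 2, 5] | [10, 8]` and returns 18. The table has (m + 1)·(n + 1) cells and each cell scans O(n) split points, so the total cost is O(n^2 m), the same as memoization with prefix sums.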


  •
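    1. Precompute prefix (partial) sums instead of summing the subarray every time.
    2. Use binary search in the inner loop. Because all nums are non-negative, dp(i, m - 1) is non-decreasing in the split point i and sum(i, j) is non-increasing in i, so max(dp(i, m - 1), sum(i, j)) is smallest where the two cross.
    Here's the code: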
    class Solution(object):
        def splitArray(self, nums, m):
            sums = [0]
            for x in nums: sums.append(sums[-1] + x)
            
            def dp(nums, j, m, cache):
                if m == 1: return sums[j]
                state = (j, m)
                if state in cache: return cache[state]
                res = 0x7fffffff
                l, r = m - 1, j
                while r - l > 1:
                    mid = l + (r - l) // 2  # integer division; / gives a float in Python 3
                    lval, rval = dp(nums, mid, m - 1, cache), sums[j] - sums[mid]
                    if lval < rval:
                        l = mid
                    else:
                        r = mid
                        res = min(res, lval)
                res = min(res, max(dp(nums, l, m - 1, cache), sums[j] - sums[l]))
                cache[state] = res
                return res
                
            cache = {}
            res = dp(nums, len(nums), m, cache)
            return res
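To see how much tip 1 buys on its own, here is a minimal sketch of the question's top-down memoization with the slice sums replaced by prefix-sum lookups. It targets Python 3 and uses `functools.lru_cache`. `split_array_memo` is a hypothetical name, and the sketch assumes 1 <= m <= len(nums). It removes one factor of n, but it still scans every split point. The binary search from tip 2 is what removes most of that remaining inner-loop work.

```python
from functools import lru_cache
from typing import List


def split_array_memo(nums: List[int], m: int) -> int:
    """Top-down memoization over (i, parts), with O(1) subarray sums.

    Assumes 1 <= m <= len(nums).
    """
    n = len(nums)
    # prefix[j] - prefix[i] == sum(nums[i:j]), computed once up front
    prefix = [0]
    for x in nums:
        prefix.append(prefix[-1] + x)

    @lru_cache(maxsize=None)
    def helper(i: int, parts: int) -> int:
        # Best largest-part sum for nums[i:] split into `parts` pieces.
        if parts == 1:
            return prefix[n] - prefix[i]
        best = float("inf")
        # The first piece is nums[i:j]; leave at least parts - 1 elements
        # so that every remaining piece is non-empty.
        for j in range(i + 1, n - parts + 2):
            left = prefix[j] - prefix[i]
            best = min(best, max(left, helper(j, parts - 1)))
        return best

    return helper(0, m)
```

Using `lru_cache` instead of a hand-rolled dict of dicts keeps the cache lookup out of the recursion logic. The recursion depth is at most m, so Python's default recursion limit is not a concern here.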
    
