Python DP + segment_tree, O(nlogn)


  • 0
    D

    First, calculate the length of the LIS that ends at index i; we call it dp[i].
    dp[i] = max{dp[j]} + 1 where j < i and a[j] < a[i]. We can use a segment tree keyed by value to make this dp transition faster.
    Moreover, for the counting part, we store the number of such subsequences alongside the length in the segment tree, so we can solve the problem in O(nlogn) time complexity.

    def merge(x, y):
        """Combine two ``(length, count)`` pairs into one.

        When the lengths differ, the pair with the larger length wins.
        When they are equal and non-zero, the counts are added together.
        Two zero-length pairs collapse to the neutral pair ``(0, 1)``.
        """
        first_len, second_len = x[0], y[0]
        if first_len != second_len:
            # Lengths differ, so tuple comparison picks the longer one.
            return max(x, y)
        if first_len == 0:
            # Both sides are "empty": keep the neutral element unchanged.
            return (0, 1)
        return (first_len, x[1] + y[1])
    
    class TreeNode(object):
        """Node of a lazily built segment tree over the key range [start, end].

        Each node stores a ``(max_val, max_cnt)`` pair, initialised to
        ``(0, 1)``. Child nodes are only created the first time they are
        accessed through ``lchild`` / ``rchild``.
        """

        def __init__(self, start, end):
            # Inclusive bounds of the key range covered by this node.
            self.left = start
            self.right = end
            # Children are created on demand by the properties below.
            self._lchild = self._rchild = None
            self.max_val = 0
            self.max_cnt = 1

        def push_up(self):
            """Recompute this node's pair from its two children via ``merge``."""
            self.val_tuple = merge(self.lchild.val_tuple, self.rchild.val_tuple)

        @property
        def mid(self):
            """Integer midpoint of the range; the left child covers [left, mid]."""
            # Floor division keeps the midpoint an int on Python 3 as well
            # (plain `/` would yield a float and produce fractional ranges),
            # and rounds toward -inf so negative keys split correctly.
            return (self.left + self.right) // 2

        @property
        def val_tuple(self):
            """The stored ``(max_val, max_cnt)`` pair."""
            return (self.max_val, self.max_cnt)

        @val_tuple.setter
        def val_tuple(self, val):
            self.max_val, self.max_cnt = val

        @property
        def lchild(self):
            """Left child covering [left, mid], created on first access."""
            if self._lchild is None:
                self._lchild = TreeNode(self.left, self.mid)
            return self._lchild

        @property
        def rchild(self):
            """Right child covering [mid + 1, right], created on first access."""
            if self._rchild is None:
                self._rchild = TreeNode(self.mid + 1, self.right)
            return self._rchild
    class SegmentTree(object):
        """Segment tree over the key range [start, end], rooted at a TreeNode.

        Nodes hold ``(value, count)`` pairs that are combined with ``merge``.
        """

        def __init__(self, start, end):
            self.root = TreeNode(start, end)

        def query(self, key, cur):
            """Return the merged pair for all keys ``<= key`` under ``cur``."""
            # Node lies entirely inside the queried prefix: use it as is.
            if cur.right <= key:
                return cur.val_tuple
            # Node lies entirely outside the prefix: contribute the neutral pair.
            if cur.left > key:
                return (0, 1)
            # Partial overlap: combine the answers of both halves.
            left_part = self.query(key, cur.lchild)
            right_part = self.query(key, cur.rchild)
            return merge(left_part, right_part)

        def insert(self, key, val, val_cnt, cur):
            """Merge ``(val, val_cnt)`` into the leaf for ``key`` and refresh ancestors."""
            if cur.left == cur.right:
                # Reached the leaf: fold the new pair into what is stored there.
                stored = cur.val_tuple
                cur.val_tuple = merge((val, val_cnt), stored)
                return
            # Descend into whichever half contains the key.
            child = cur.lchild if key <= cur.mid else cur.rchild
            self.insert(key, val, val_cnt, child)
            cur.push_up()
    class Solution(object):
        def findNumberOfLIS(self, nums):
            """
            Count the longest increasing subsequences of ``nums``.

            A segment tree keyed by value stores, for each value, the best
            subsequence length ending there together with how many
            subsequences reach that length.

            :type nums: List[int]
            :rtype: int
            """
            if not nums:
                return 0
            tree = SegmentTree(min(nums), max(nums))
            root = tree.root
            for value in nums:
                # Best (length, count) among strictly smaller values seen so far.
                best_len, ways = tree.query(value - 1, root)
                # Extending each of those subsequences by `value` adds one to the length.
                tree.insert(value, best_len + 1, ways, root)
            # The root aggregates the whole range; its count is the answer.
            return root.max_cnt
    

Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.