A clean Python solution using a segment tree


  • 0
    S
    class SegNode(object):
        """A node of a sum segment tree covering the inclusive index range [left, right]."""

        def __init__(self, left, right):
            # Children stay None for leaf nodes (left == right).
            self.left_child, self.right_child = None, None
            # Inclusive bounds of the slice of the input array this node covers.
            self.left, self.right = left, right
            # Sum of nums[left..right]; filled in by NumArray.buildTree / update.
            self.sum = 0


    class NumArray(object):
        """Range-sum queries with point updates, backed by a segment tree.

        Construction is O(n); ``update`` and ``sumRange`` are O(log n).
        """

        def __init__(self, nums):
            """
            Initialize the data structure from ``nums``.
            :type nums: List[int]
            """
            # Keep a private copy so update() never mutates the caller's list.
            self.nums = list(nums) if nums else []
            self.root = (
                self.buildTree(self.nums, 0, len(self.nums) - 1) if self.nums else None
            )

        def buildTree(self, nums, left, right):
            """Recursively build the subtree covering nums[left..right] and return its root."""
            node = SegNode(left, right)
            if left == right:
                # leaf node
                node.sum = nums[left]
            else:
                # Floor division: '/' yields a float in Python 3 and breaks indexing.
                mid = (left + right) // 2
                left_child = self.buildTree(nums, left, mid)
                right_child = self.buildTree(nums, mid + 1, right)
                node.left_child, node.right_child = left_child, right_child
                node.sum = left_child.sum + right_child.sum
            return node

        def update(self, i, val):
            """
            Set nums[i] to val and adjust every node whose range contains i.
            :type i: int
            :type val: int
            :rtype: None
            :raises IndexError: if i is outside [0, len(nums)).
            """
            # Negative indices are rejected explicitly: list indexing would accept
            # them, but the tree walk below would then update the wrong path.
            if not 0 <= i < len(self.nums):
                raise IndexError("index out of range")
            diff = val - self.nums[i]
            self.nums[i] = val

            node = self.root
            # Walk root-to-leaf, applying the delta to each node on the path.
            while node:
                node.sum += diff
                mid = (node.left + node.right) // 2
                node = node.left_child if i <= mid else node.right_child

        def sumRange(self, i, j):
            """
            Sum of elements nums[i..j], inclusive.
            :type i: int
            :type j: int
            :rtype: int
            :raises IndexError: unless 0 <= i <= j < len(nums).
            """
            # Invalid ranges would otherwise crash deep in the recursion with an
            # opaque AttributeError on a None child (or a None root).
            if not 0 <= i <= j < len(self.nums):
                raise IndexError("range out of bounds")
            return self.sumRange2(self.root, i, j)

        def sumRange2(self, root, i, j):
            """Return the sum of nums[i..j] within the subtree ``root``; [i, j] must lie inside it."""
            if i == root.left and j == root.right:
                return root.sum
            mid = (root.left + root.right) // 2
            if j <= mid:
                # search in left child
                return self.sumRange2(root.left_child, i, j)
            if i > mid:
                # search in right child
                return self.sumRange2(root.right_child, i, j)
            # range straddles mid: split it between both children
            return self.sumRange2(root.left_child, i, mid) + self.sumRange2(root.right_child, mid + 1, j)
             ```
    
    
    
    A fairly "standard" solution.

Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.