Python heap that supports O(log n) delete; any suggestions and optimizations are welcome.


  • 0
    Z

    The running time is very bad because there is too much overhead...

    from collections import defaultdict
    
    
    class HeapList(object):
        """Array-backed binary heap that can also remove an arbitrary value.

        Alongside the heap array, a ``{value: {indexes}}`` map records every
        slot each value currently occupies, so ``remove(val)`` finds an
        occurrence by dictionary lookup instead of scanning the array and then
        repairs the heap with a single sift.  Duplicate values are supported;
        values must be hashable and mutually comparable.

        Public attributes (kept for backward compatibility):
            heap:     list of values in heap order.
            is_min:   True for a min-heap, False for a max-heap.
            heaplist: the value -> set-of-indexes map described above.
        """

        def __init__(self, is_min=True):
            """Create an empty heap; ``is_min=False`` gives a max-heap."""
            # values only
            self.heap = []
            self.is_min = is_min
            # {value: {indexes}}
            self.heaplist = defaultdict(set)

        def _lt(self, a, b):
            """Return True when ``a`` must sit strictly above ``b`` in this heap."""
            if self.is_min:
                return a < b
            return a > b

        def add(self, val):
            """Insert ``val`` and restore the heap order."""
            idx = len(self.heap)
            self.heap.append(val)
            self.heaplist[val].add(idx)
            self.bubble_up(idx)

        @property
        def size(self):
            """Number of values currently stored."""
            return len(self.heap)

        @property
        def top(self):
            """The root value (smallest for a min-heap, largest for a max-heap).

            Raises IndexError when the heap is empty.
            """
            return self.heap[0]

        @property
        def empty(self):
            """True when the heap holds no values."""
            return not self.heap

        def pop(self):
            """Remove and return the root value.

            Raises IndexError when the heap is empty.
            """
            if self.empty:
                raise IndexError("pop from an empty heap")
            val = self.heap[0]
            slots = self.heaplist[val]
            slots.remove(0)
            # drop exhausted entries so the index map does not grow forever
            if not slots:
                del self.heaplist[val]
            last = self.heap.pop()
            if not self.heap:
                # the root was the only element, so ``last`` is ``val``
                return last
            # ``last`` used to live at the old final index, which after the
            # pop above is exactly len(self.heap); re-home it at the root
            moved = self.heaplist[last]
            moved.remove(len(self.heap))
            moved.add(0)
            self.heap[0] = last
            self.bubble_down(0)
            return val

        def remove(self, val):
            """Remove one occurrence of ``val`` (which one is unspecified).

            Raises IndexError when ``val`` is not in the heap.
            """
            # .get() does not create an entry on the defaultdict
            slots = self.heaplist.get(val)
            if not slots:
                raise IndexError("value not in heap")
            # any occurrence will do, so take an arbitrary index from the set
            idx = slots.pop()
            if not slots:
                del self.heaplist[val]
            last = self.heap.pop()
            if idx == len(self.heap):
                # the chosen occurrence was the final slot; nothing to repair
                return
            # fill the hole with the former last element and fix its index
            self.heap[idx] = last
            moved = self.heaplist[last]
            moved.remove(len(self.heap))
            moved.add(idx)
            # the replacement can only be out of order in one direction:
            # down if the removed value ranked strictly above it, up otherwise
            if self._lt(val, last):
                self.bubble_down(idx)
            else:
                self.bubble_up(idx)

        def contains(self, val):
            """Return True when at least one occurrence of ``val`` is stored."""
            return bool(self.heaplist.get(val))

        def bubble_up(self, idx):
            """Sift the value at ``idx`` towards the root until order holds."""
            val = self.heap[idx]
            # take the index out first; the final position is recorded once,
            # at the end, instead of on every step
            self.heaplist[val].remove(idx)
            while idx > 0:
                parent = (idx - 1) >> 1
                v = self.heap[parent]
                if not self._lt(val, v):
                    break
                # pull the parent down into the hole and update its index
                self.heap[idx] = v
                slots = self.heaplist[v]
                slots.remove(parent)
                slots.add(idx)
                idx = parent
            self.heap[idx] = val
            self.heaplist[val].add(idx)

        def bubble_down(self, idx):
            """Sift the value at ``idx`` towards the leaves until order holds."""
            val = self.heap[idx]
            self.heaplist[val].remove(idx)
            size = len(self.heap)
            child = 2 * idx + 1
            # loop while there is at least a left child
            while child < size:
                right = child + 1
                # pick the child that ranks higher (right one wins ties)
                if right < size and not self._lt(self.heap[child], self.heap[right]):
                    child = right
                v = self.heap[child]
                # stop as soon as val <= v in heap order; an equal child never
                # needs to move, so ties no longer cause pointless swaps
                if not self._lt(v, val):
                    break
                # pull the child up into the hole and update its index
                self.heap[idx] = v
                slots = self.heaplist[v]
                slots.remove(child)
                slots.add(idx)
                idx = child
                child = 2 * idx + 1
            self.heap[idx] = val
            self.heaplist[val].add(idx)
    
    
    
    class Solution(object):
        """Sliding-window median built on two ``HeapList`` heaps.

        ``min_heap`` holds the upper half of the current window (its top is the
        smallest of the large values) and ``max_heap`` holds the lower half.
        ``rebalance`` keeps ``min_heap`` either the same size as ``max_heap`` or
        exactly one larger, so the median is always available from the tops.
        """

        def medianSlidingWindow(self, nums, k):
            """Return the median of every length-``k`` window of ``nums``.

            Returns an empty list when no full window exists (empty ``nums``,
            ``k <= 0`` or ``k > len(nums)``) instead of failing on an empty heap.

            :type nums: List[int]
            :type k: int
            :rtype: List[float]
            """
            if k <= 0 or k > len(nums):
                return []
            self.min_heap = HeapList(True)
            self.max_heap = HeapList(False)
            # load the first window, then split it evenly between the heaps
            for value in nums[:k]:
                self.min_heap.add(value)
            self.rebalance()
            medians = [self.get_median()]
            # slide: nums[i] leaves the window as nums[i + k] enters it
            for outgoing, incoming in zip(nums, nums[k:]):
                self.replace(outgoing, incoming)
                medians.append(self.get_median())
            return medians

        def replace(self, old, new):
            """Add ``new`` to the window, then remove one occurrence of ``old``."""
            # min_heap is never empty here: after rebalance it is at least as
            # large as max_heap and the window holds k >= 1 values
            if new < self.min_heap.top:
                self.max_heap.add(new)
            else:
                self.min_heap.add(new)
            # equal values are interchangeable, so either heap's copy may go
            if self.min_heap.contains(old):
                self.min_heap.remove(old)
            else:
                self.max_heap.remove(old)
            self.rebalance()

        def rebalance(self):
            """Move tops between heaps until min_heap has 0 or 1 more values."""
            while self.min_heap.size - self.max_heap.size not in (0, 1):
                if self.min_heap.size >= self.max_heap.size:
                    self.max_heap.add(self.min_heap.pop())
                else:
                    self.min_heap.add(self.max_heap.pop())

        def get_median(self):
            """Return the current window's median as a float.

            Relies on ``rebalance`` having run, so min_heap is never smaller
            than max_heap.
            """
            if self.min_heap.size > self.max_heap.size:
                return float(self.min_heap.top)
            return (self.min_heap.top + self.max_heap.top) / 2.
    

Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.