Python heap solution for kSmallestPairs


  • 0
    H
    
    class pair(object):
        """A candidate pair that compares purely by the sum of its elements.

        Attributes (set in __init__):
            pair: the sequence of numbers passed in (a two-element list here).
            ind:  caller-supplied tag; Solution uses it as the nums1 index.
            sm:   sum(pair), cached once and used by every comparison.

        Defining __eq__ without __hash__ leaves instances unhashable in
        Python 3.
        """
        def __init__(self, pair, ind):
            self.pair = pair
            self.ind = ind
            total = sum(pair)
            self.sm = total

        def __eq__(self, other):
            return self.sm == other.sm

        def __lt__(self, other):
            return self.sm < other.sm

        def __gt__(self, other):
            return self.sm > other.sm

        def __le__(self, other):
            # "strictly less, or equal" on the cached sums
            return self.sm < other.sm or self.sm == other.sm

        def __ge__(self, other):
            # "strictly greater, or equal" on the cached sums
            return self.sm > other.sm or self.sm == other.sm
    
    
    class heap(object):
        """Array-backed binary min-heap.

        Stored items only need to support ``<``; no other comparison is used.
        ``array`` holds the heap in level order: the children of index i live
        at 2*i+1 and 2*i+2, and the parent of index i at (i-1)//2.
        """
        def __init__(self):
            self.array = []

        def __len__(self):
            return len(self.array)

        def insert(self, pair):
            """Append *pair* and sift it up to restore the heap property."""
            self.array.append(pair)
            self.bubbleup(len(self) - 1)

        def bubbleup(self, index):
            """Move the item at *index* up while it is smaller than its parent."""
            # Iterative form of the original recursion; performs the same swaps.
            while index:
                parent = (index - 1) // 2
                if not self.array[index] < self.array[parent]:
                    return
                self.swap(index, parent)
                index = parent

        def swap(self, a, b):
            """Exchange the items at positions *a* and *b*."""
            self.array[a], self.array[b] = self.array[b], self.array[a]

        def deletemin(self):
            """Remove and return the smallest item.

            Raises:
                IndexError: if the heap is empty.
            """
            if not self.array:
                # Fail with a clear message instead of an opaque error from swap().
                raise IndexError("deletemin from empty heap")
            # Move the root to the end so it can be popped in O(1), then repair.
            self.swap(0, -1)
            res = self.array.pop()
            self.bubbledown(0)
            return res

        def bubbledown(self, index):
            """Move the item at *index* down while a child is smaller than it."""
            size = len(self)
            while True:
                left = 2 * index + 1
                if left >= size:
                    return
                right = left + 1
                # Pick the smaller child; ties favour the left child.
                target = left
                if right < size and self.array[right] < self.array[left]:
                    target = right
                if not self.array[target] < self.array[index]:
                    return
                self.swap(target, index)
                index = target

        def __bool__(self):
            """True when the heap holds at least one item."""
            return bool(self.array)

        # Python 2 name for the truthiness hook; kept for backward compatibility.
        __nonzero__ = __bool__
    
    
    class Solution(object):
        def kSmallestPairs(self, nums1, nums2, k):
            """Return up to k pairs [nums1[i], nums2[j]] in order of increasing sum.

            Think of row i as the pairs (nums1[i], nums2[0..]). Every started row
            keeps exactly one candidate in the heap. Popping a pair from row i
            does two things: it pushes row i's next column, and it starts row
            i+1 at column 0 if that row has not been started yet.

            The frontier logic is correct only when nums1 and nums2 are sorted
            in non-decreasing order; that precondition is not checked here.

            :type nums1: List[int]
            :type nums2: List[int]
            :type k: int
            :rtype: List[List[int]]
            """
            ans = []
            # k <= 0 (not just k == 0) must yield no pairs.
            if not nums1 or not nums2 or k <= 0:
                return ans
            # There are only len(nums1) * len(nums2) distinct pairs.
            k = min(k, len(nums1) * len(nums2))

            # curpos[i]: column of the most recent row-i pair pushed; -1 = row unstarted.
            curpos = [-1] * len(nums1)
            curpos[0] = 0
            h = heap()
            h.insert(pair([nums1[0], nums2[0]], 0))

            while len(ans) < k:
                p = h.deletemin()
                ans.append(p.pair)
                row = p.ind
                # Replace the popped candidate with the next column of its row.
                curpos[row] += 1
                if curpos[row] < len(nums2):
                    h.insert(pair([nums1[row], nums2[curpos[row]]], row))
                # Start the next row the first time anything in this row is used.
                nxt = row + 1
                if nxt < len(nums1) and curpos[nxt] < 0:
                    curpos[nxt] = 0
                    h.insert(pair([nums1[nxt], nums2[0]], nxt))
            return ans
            
                
            
            
            
            
            
            
            
    

Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.