Share my Python Union-find solution


  • 1
    D

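    The idea is based on the two-pass connected-component labeling algorithm. The first pass gives every land cell a provisional label copied from its north or west land neighbour (or a fresh label if it has none), and records in a union-find structure whenever two different labels touch. The second pass collects the distinct union-find roots of all labels in a set; the size of that set is the number of islands.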

    from collections import defaultdict
    
    # disjoint-set node
    class DsNode:
        """A single element of a disjoint-set forest.

        A new node forms a singleton set: it is its own parent and starts
        with rank 0. The rank is used by union-by-rank to decide which of
        two roots becomes the parent of the other.
        """

        def __init__(self):
            self.parent = self  # a node that is its own parent is a root
            self.rank = 0
    
    class DisjointSets:
        """Union-find keyed by arbitrary hashable labels.

        Nodes are created lazily: looking up a label that has not been seen
        before creates a fresh singleton ``DsNode`` for it (``defaultdict``).
        """

        def __init__(self):
            # label -> DsNode; unseen labels get a new singleton node on access
            self._sets = defaultdict(DsNode)

        def find(self, x):
            """Return the root node of the tree that contains node *x*.

            Uses path halving: every node visited on the way up is re-pointed
            at its grandparent, shortening the path for later lookups.
            """
            node = x
            while node.parent is not node:
                grandparent = node.parent.parent
                node.parent = grandparent
                node = grandparent
            return node

        def findByLabel(self, label):
            """Return the root node of the set containing *label*."""
            node = self._sets[label]
            return self.find(node)

        def unionByLabel(self, labelA, labelB):
            """Merge the sets containing *labelA* and *labelB* (union by rank).

            If the root of *labelA*'s set has the strictly larger rank it
            becomes the parent; otherwise the root of *labelB*'s set does,
            and that root's rank is incremented when the two ranks were equal.
            """
            root_a = self.find(self._sets[labelA])
            root_b = self.find(self._sets[labelB])
            if root_a is root_b:
                return  # already in the same set; nothing to merge
            if root_a.rank > root_b.rank:
                root_b.parent = root_a
                return
            root_a.parent = root_b
            if root_a.rank == root_b.rank:
                root_b.rank += 1
    
    class Solution(object):
        """Count islands with two-pass connected-component labeling."""

        def numIslands(self, grid):
            """
            Return the number of islands in *grid*.

            An island is a group of '1' cells connected horizontally or
            vertically. First pass: each land cell takes a provisional label
            from its north or west land neighbour (or a fresh label), and any
            two different labels that meet are merged in a union-find.
            Second pass: the number of distinct union-find roots over all
            labels is the number of islands.

            :type grid: List[List[str]]
            :rtype: int
            """
            if not grid or not grid[0]:
                return 0

            rows, cols = len(grid), len(grid[0])
            ds = DisjointSets()
            # 0 means "water / unlabeled"; land labels start at 1.
            labels = [[0] * cols for _ in range(rows)]
            next_label = 1

            # range (not xrange) so this runs on both Python 2 and Python 3.
            for row in range(rows):
                for col in range(cols):
                    if grid[row][col] != '1':
                        continue
                    # North neighbour's label; water cells keep label 0.
                    label = labels[row - 1][col] if row > 0 else 0
                    if col > 0 and grid[row][col - 1] == '1':
                        west_label = labels[row][col - 1]
                        if label == 0:
                            label = west_label
                        elif label != west_label:
                            # North and west belong to the same island.
                            ds.unionByLabel(label, west_label)
                    if label == 0:
                        # No labeled land neighbour yet: start a new label.
                        label = next_label
                        next_label += 1
                    labels[row][col] = label

            # Each distinct root node corresponds to one island.
            return len({ds.findByLabel(lbl) for lbl in range(1, next_label)})
    

Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.