Accepted (AC) Python solution using union-find (disjoint sets)

  • 0

    Video explanation of union-find: (the link was lost from the original post)

    class Solution(object):
        def validTree(self, n, edges):
            """Return True if the undirected graph on nodes 0..n-1 forms a tree.

            A graph is a tree when it is connected and has no cycles. Each edge
            is processed with union-find: an edge whose endpoints already share
            a representative closes a cycle.

            :type n: int
            :type edges: List[List[int]]
            :rtype: bool
            """
            # A tree on n nodes has exactly n - 1 edges. More edges force a cycle
            # and fewer leave the graph disconnected, so either way the answer is
            # False, which is what the union-find pass below would also conclude.
            if len(edges) != n - 1:
                return False
            ds = DisjointSet()
            # make a singleton set for each node
            for i in range(n):
                ds.make_set(i)
            for from_vert, to_vert in edges:
                parent1 = ds.find_set(from_vert)
                parent2 = ds.find_set(to_vert)
                # endpoints already connected: this edge would close a cycle
                if parent1 is parent2:
                    return False
                ds.union(from_vert, to_vert)  # merge the two components
            # a tree is a single connected component
            return ds.num_sets == 1
    class Node(object):
        """A union-find element holding a value, a parent pointer and a rank."""

        def __init__(self, data, parent=None, rank=0):
            # the user-supplied value this node represents
            self.data = data
            # parent in the union-find forest; a root points to itself once
            # registered by DisjointSet.make_set
            self.parent = parent
            # upper bound on subtree height, used for union by rank
            self.rank = rank

        def __str__(self):
            return str(self.data)

        def __repr__(self):
            return self.__str__()
    class DisjointSet(object):
        """Union-find over hashable values using union by rank and path compression."""

        def __init__(self):
            # maps each data value to its Node
            self.map = {}
            # current number of disjoint sets
            self.num_sets = 0

        def make_set(self, data):
            """Create a singleton set containing ``data``.

            Calling this again for a value already present is a no-op. Recreating
            the node would orphan its existing links and miscount ``num_sets``.
            """
            if data in self.map:
                return
            node = Node(data)
            node.parent = node  # a root is its own parent -- very important!
            self.map[data] = node
            self.num_sets += 1  # make_set increases the number of disjoint sets by one

        def union(self, data1, data2):
            """Merge the sets containing ``data1`` and ``data2``.

            Raises KeyError if either value was never passed to make_set.
            """
            # get the roots of both values' sets
            parent1 = self.find_set_util(self.map[data1])
            parent2 = self.find_set_util(self.map[data2])
            # already in the same set: nothing to do
            if parent1 is parent2:
                return
            # the root with the higher rank becomes the parent of the other
            if parent1.rank >= parent2.rank:
                # rank grows only when two equal-rank trees are joined
                if parent1.rank == parent2.rank:
                    parent1.rank += 1
                parent2.parent = parent1
            else:
                parent1.parent = parent2
            self.num_sets -= 1  # union decreases the number of disjoint sets by one

        def find_set(self, data):
            """Return the root Node of the set containing ``data``."""
            return self.find_set_util(self.map[data])

        def find_set_util(self, node):
            """Return the root of ``node``'s tree, compressing the path on the way."""
            parent = node.parent
            if parent is node:
                return parent
            node.parent = self.find_set_util(parent)  # path compression
            return node.parent

Log in to reply

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.