A "modular" solution inspired by Skiena's Algorithm Design Manual


  • 0
    S

    This is obviously a straightforward DFS (it could also be done with BFS). I was interested to see whether I could keep exactly the same Graph and DFS "template" and solve a bunch of other problems with it. The "Edge" and "Graph" classes and the "dfs" method are very generic, and have been used as-is to solve other graph problems such as Course Schedule (I and II). Simply by changing the "Traversal" class, other behaviours can be induced.

    Note that some of the boilerplate (like the Edge class) is not even required for this specific problem.

    
    
    class Solution:
        def cloneGraph(self, node):
            """
            Return a deep copy of the undirected graph reachable from ``node``.

            Every node is identified by its ``label``. The copy is produced by
            running the generic ``dfs`` with a cloning ``Traversal``. A falsy
            ``node`` (e.g. None) yields None.
            """
            if not node:
                return None
            graph = Graph(
                neighbors_func=lambda current: [(nbr, None) for nbr in current.neighbors],
                key_func=lambda current: current.label,
            )
            cloner = Traversal(graph)
            dfs(node, cloner)
            # The clone of the start node, looked up by its label.
            return cloner.copiedNodes[node.label]
    
    class Edge(object):
        """
        A link from ``source`` to ``target`` with an optional ``data`` payload.

        The endpoints are exposed through read-only properties; ``data`` is a
        plain, writable attribute.
        """

        def __init__(self, source, target, data = None):
            self._source = source
            self._target = target
            self.data = data

        def __repr__(self):
            return "Edge<%r <-> %r>" % (self.source, self.target)

        @property
        def source(self):
            """The node this edge starts from."""
            return self._source

        @property
        def target(self):
            """The node this edge points to."""
            return self._target
    
    class Graph(object):
        """
        A lightweight graph that either stores adjacency explicitly or wraps an
        existing node structure.

        Explicit storage lives in ``self.nodes`` as ``{node: {neighbor: edge}}``.
        When ``neighbors_func`` is supplied, ``neighbors(node)`` delegates to it
        instead and ``self.nodes`` is not consulted.

        Args:
            multi: Recorded and exposed via ``is_multi``. NOTE: ``add_raw_edge``
                does not consult it, so a second edge between the same pair of
                nodes replaces the first.
            directed: When False, ``add_raw_edge`` stores each edge under both
                endpoints.
            key_func: Maps a node to a hashable key; defaults to the identity.
            neighbors_func: Optional callable ``node -> neighbors``. It may return
                a dict ``{neighbor: edge_data}`` or an iterable of
                ``(neighbor, edge_data)`` tuples.
        """
        def __init__(self, multi = False, directed = False, key_func = None, neighbors_func = None):
            self.nodes = {}
            self._is_directed = directed
            self._is_multi = multi
            self.neighbors_func = neighbors_func
            self.key_func = key_func or (lambda x: x)

        @property
        def is_directed(self):
            return self._is_directed

        @property
        def is_multi(self):
            return self._is_multi

        def get_edge(self, source, target):
            """Return the stored edge from ``source`` to ``target``, or None if absent."""
            return self.nodes.get(source, {}).get(target, None)

        def add_nodes(self, *nodes):
            """Add each node; return the list of their adjacency dicts."""
            return [self.add_node(node) for node in nodes]

        def add_node(self, node):
            """
            Ensure ``node`` (any hashable) is present and return its adjacency
            dict ``{neighbor: edge}``. An existing node is left untouched.
            """
            if node not in self.nodes:
                self.nodes[node] = {}
            return self.nodes[node]

        def neighbors(self, node):
            """
            Return the neighbors of a node: whatever ``neighbors_func`` returns,
            or the node's adjacency dict (``{}`` for an unknown node).
            """
            if self.neighbors_func:
                return self.neighbors_func(node)
            return self.nodes.get(node, {})

        def iter_neighbors(self, node, reverse = False):
            """
            Return an iterable of ``(neighbor, edge_data)`` pairs for ``node``.

            For dict-shaped neighbors the pairs are ``dict.items()``. Otherwise the
            neighbors are assumed to already be ``(neighbor, edge_data)`` tuples.
            With ``reverse=True`` the order is reversed.
            Override this method for custom node storage and inspection strategies.
            """
            # Fetch once: neighbors_func may be costly or return a one-shot iterator.
            neighbors = self.neighbors(node)
            if isinstance(neighbors, dict):
                # items() works on Python 2 and 3 (iteritems is Python 2 only). The
                # snapshot also survives edges being added while a caller iterates.
                pairs = list(neighbors.items())
                return reversed(pairs) if reverse else iter(pairs)
            if reverse:
                # reversed() needs a sequence; materialize generators/iterators first.
                return reversed(list(neighbors))
            return neighbors

        def add_raw_edge(self, edge):
            """
            Store ``edge`` (any object with ``source``/``target``) and return it.

            Both endpoints are added as nodes. For undirected graphs the same edge
            object is also stored under ``target -> source``.
            """
            self.add_nodes(edge.source, edge.target)
            source, target = edge.source, edge.target
            self.nodes[source][target] = edge
            if not self.is_directed:
                self.nodes[target][source] = edge
            return edge

        def add_edge(self, source, target):
            """Create an ``Edge(source, target)``, store it and return it."""
            return self.add_raw_edge(Edge(source, target))

        def add_edges(self, *edges):
            """Add each ``(source, target)`` pair; return the created edges."""
            return [self.add_edge(*e) for e in edges]
    
    import itertools
    from collections import deque, defaultdict
    
    # Per-node states that dfs writes into Traversal.node_state. A node that has not
    # been discovered yet has no real entry; its defaultdict value is None. Because
    # DISCOVERED is 0 (falsy), compare state against None explicitly, never by truthiness.
    DISCOVERED = 0
    PROCESSED = 1
    
    class Traversal(object):
        """
        Delegate object that drives ``dfs``. This variant clones every visited node.

        ``dfs`` reads and writes the bookkeeping attributes ``parents``,
        ``node_state``, ``curr_time``, ``entry_times``, ``exit_times`` and
        ``terminated``. ``copiedNodes`` maps a node's ``label`` to its clone.
        """
        def __init__(self, graph):
            self.graph = graph

            # The parent nodes for each of the nodes
            self.parents = defaultdict(lambda: None)

            # Marks a node's state - can be missing (undiscovered), discovered (0) and processed (1)
            self.node_state = defaultdict(lambda: None)

            self.curr_time, self.entry_times, self.exit_times = 0, {}, {}

            # Set this flag to true if you want the traversal to stop
            self.terminated = False

            # label -> cloned node
            self.copiedNodes = {}

        def should_process_children(self, node):
            """
            Create the clone of ``node`` the first time it is seen.

            Returns None (not False), so ``dfs`` always goes on to visit children.
            """
            if node.label not in self.copiedNodes:
                # UndirectedGraphNode is not defined in this file; presumably the
                # judge harness supplies it.
                self.copiedNodes[node.label] = UndirectedGraphNode(node.label)

        def process_node(self, node):
            """Always allow ``dfs`` to mark the node as processed."""
            return True

        def process_edge(self, source, target, edge_data):
            """Link the clone of ``target`` into the neighbor list of the clone of ``source``."""
            self.should_process_children(source)
            self.should_process_children(target)
            # dfs requests children with reverse=True, so edges arrive last-to-first.
            # Prepending therefore rebuilds each clone's neighbors in the original order.
            self.copiedNodes[source.label].neighbors.insert(0, self.copiedNodes[target.label])

        def select_children(self, node, reverse = False):
            """Delegate child selection and ordering to the graph."""
            return self.graph.iter_neighbors(node, reverse = reverse)
    
    
    def dfs(node, traversal):
        """
        Recursive depth-first traversal of a graph starting at ``node``.

        The traversal object supplies the following hooks and state:

            should_process_children(node):
                Called once, right after a node is discovered. If it returns False
                (exactly False; None means "continue"), the node's children are not
                visited, process_node is not called, and the node stays DISCOVERED.

            process_node(node):
                Called after all of a node's children have been handled. Unless it
                returns False, the node is marked PROCESSED and given an exit time.

            process_edge(source, target, edge_data):
                Called for every (child, edge_data) pair yielded by select_children,
                whether or not the child has been seen before. If it returns False,
                dfs does not descend into the child along this edge.

            select_children(node, reverse = False):
                Returns an iterable of (child, edge_data) tuples and thereby decides
                the visiting order. dfs always calls it with reverse=True.

            parents[node_key -> node]     - the node through which each key was first reached.
            node_state[node_key -> state] - None (undiscovered), DISCOVERED or PROCESSED.
                                            Must yield None for keys never seen.
            curr_time (int)               - running clock used for entry/exit stamps.
            entry_times[node_key -> int]  - time at which a node was discovered.
            exit_times[node_key -> int]   - time at which a node finished (after all of
                                            its children were handled).
            terminated (bool)             - any hook may set it to abort the whole walk.
                                            It is checked on entry and after every edge.

        All state maps are keyed by ``traversal.graph.key_func(node)``.

        NOTE: recursion depth grows with the length of the DFS path, so very deep
        graphs may exceed Python's recursion limit (see sys.getrecursionlimit).
        """
        if traversal.terminated:
            return

        g = traversal.graph
        node_key = g.key_func(node)
        traversal.node_state[node_key] = DISCOVERED
        traversal.entry_times[node_key] = traversal.curr_time
        traversal.curr_time += 1

        if traversal.should_process_children(node) is False:
            return

        # Materialize first so that hooks which add edges during the walk cannot
        # invalidate the iterator we are looping over.
        children = list(traversal.select_children(node, reverse = True))
        for child, edge_data in children:
            child_key = g.key_func(child)
            follow = traversal.process_edge(node, child, edge_data) is not False
            if traversal.terminated:
                return
            # Compare with None explicitly: DISCOVERED is 0, which is falsy.
            if follow and traversal.node_state[child_key] is None:
                traversal.parents[child_key] = node
                dfs(child, traversal)
                if traversal.terminated:
                    return

        if traversal.process_node(node) is not False:
            traversal.node_state[node_key] = PROCESSED
            traversal.curr_time += 1
            traversal.exit_times[node_key] = traversal.curr_time
    

Log in to reply
 

Looks like your connection to LeetCode Discuss was lost, please wait while we try to reconnect.