Tree Deserializer and Visualizer for Python

  • 74

    Wrote some tools for my own local testing. For example deserialize('[1,2,3,null,null,4,null,null,5]') will turn that into a tree and return the root as explained in the FAQ. I also wrote a visualizer. Two examples:


    [image: first example tree drawn by the visualizer]


    [image: second example tree drawn by the visualizer]

    Here's the code. If you save it as a Python script and run it, it should show the above two pictures as a demo in turtle windows (one after the other). You can of course also import it from other scripts; then it only provides the class and functions and doesn't show the demo.

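    class TreeNode:
        def __init__(self, val, left=None, right=None):
            self.val = val
            self.left = left
            self.right = right
        def __repr__(self):
            return 'TreeNode({})'.format(self.val)

    def deserialize(string):
        # An empty tree is written as '[]' (or '{}' in the old format).
        if string in ('[]', '{}'):
            return None
        nodes = [None if val == 'null' else TreeNode(int(val))
                 for val in string.strip('[]{}').split(',')]
        # Reversed copy, so that pop() hands out the nodes in level order.
        kids = nodes[::-1]
        root = kids.pop()
        for node in nodes:
            if node:
                if kids: node.left  = kids.pop()
                if kids: node.right = kids.pop()
        return root

    def drawtree(root):
        def height(root):
            return 1 + max(height(root.left), height(root.right)) if root else -1
        def jumpto(x, y):
            # Move without drawing a line.
            t.penup()
            t.goto(x, y)
            t.pendown()
        def draw(node, x, y, dx):
            if node:
                t.goto(x, y)
                jumpto(x, y-20)
                t.write(node.val, align='center', font=('Arial', 12, 'normal'))
                draw(node.left, x-dx, y-60, dx/2)
                jumpto(x, y-20)
                draw(node.right, x+dx, y-60, dx/2)
        import turtle
        t = turtle.Turtle()
        t.speed(0); turtle.delay(0)
        h = height(root)
        jumpto(0, 30*h)
        draw(root, 0, 30*h, 40*h)
        t.hideturtle()
        turtle.mainloop()  # keep the window open until it is closed

    if __name__ == '__main__':
        # Demo. The first tree is the example from the text; the second is
        # assumed to be the tree string reused further down in this thread.
        drawtree(deserialize('[1,2,3,null,null,4,null,null,5]'))
        drawtree(deserialize('[2,1,3,0,7,9,1,2,null,1,0,null,null,8,8,null,null,null,null,7]'))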
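
    To make the pairing of parents and children in deserialize easier to follow, here is a minimal sketch of the same level-order idea using an iterator instead of a reversed list. The names Node and deserialize_levelorder are hypothetical stand-ins so the snippet runs on its own; it is not the code from the post above, just one way the idea might be written.

```python
class Node:
    """Hypothetical stand-in for TreeNode, so this sketch is self-contained."""
    def __init__(self, val):
        self.val, self.left, self.right = val, None, None

def deserialize_levelorder(string):
    tokens = string.strip('[]').split(',')
    if tokens == ['']:          # '[]' means an empty tree
        return None
    nodes = [None if tok.strip() == 'null' else Node(int(tok)) for tok in tokens]
    kids = iter(nodes)          # hands out nodes in level order
    root = next(kids)
    for node in nodes:
        if node is not None:
            # Each real node claims the next two entries as its children;
            # once the list is exhausted, the children default to None.
            node.left = next(kids, None)
            node.right = next(kids, None)
    return root

root = deserialize_levelorder('[1,2,3,null,null,4,null,null,5]')
print(root.right.left.val, root.right.left.right.val)  # prints: 4 5
```

    Null entries never claim children, which is why the 4 ends up under the 3 and the 5 under the 4.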

  • 1

    Thanks for this very useful tool. I'm always learning new things about Python from you!

  • 0

    Great things! I have the same feeling as what totolipton expresses in the answer below :-)

  • 0

    Great tool. I need to put this in the FAQ or somewhere more visible :)

  • 0

    You are my hero

  • 0

    Hi Stefan, do you mind updating your code from # to null since we're no longer using # in binary tree representation? I plan to link your article to the FAQ page so people can use your nice tool. Thanks!

  • 1

    @1337c0d3r Ok, I changed to "null" and "[]" and linked to the FAQ. Thanks for the link from there.

  • 0

    Awesome, thanks!

  • 0

    Hi StefanPochmann,
    I want to ask how you would serialize a tree into a string. I tried it using level-order traversal and got a string with a lot of trailing nulls, which I then have to delete. It works but it's ugly. I wonder if you have any cool methods?

  • 0

    @totolipton Hmm, for this tool I'd probably be lazy and just do what you did :-). But I'll think about it. Btw, you just gave me the idea that this could become a new LeetCode problem or two...

  • 0

    @totolipton Wrote a serializer now. Four, actually. Not really happy with any of them. What do you think? And how does yours look?

    Note: The first two modify the list while iterating it, which as far as I know isn't guaranteed to work.

    from collections import deque  # needed by the third and fourth versions

    def serialize(root):
        queue = [root]
        vals = (queue.extend((node.left, node.right)) or str(node.val) if node else 'null'
                for node in queue)
        return '[' + ','.join(vals).strip('nul,') + ']'

    def serialize(root):
        s = ''
        queue = [root]
        for node in queue:
            if node:
                s += str(node.val) + ','
                queue += node.left, node.right
            else:
                s += 'null,'
        return '[' + s.strip('nul,') + ']'

    def serialize(root):
        s = ''
        queue = deque([root])
        while queue:
            node = queue.popleft()
            if node:
                s += str(node.val) + ','
                queue += node.left, node.right
            else:
                s += 'null,'
        return '[' + s.strip('nul,') + ']'

    def serialize(root):
        vals = []
        queue = deque([root])
        todo = 1 if root else 0  # real (non-null) nodes still waiting in the queue
        while todo:
            node = queue.popleft()
            if node:
                vals.append(str(node.val))
                todo -= 1
                for kid in node.left, node.right:
                    queue.append(kid)
                    if kid:
                        todo += 1
            else:
                vals.append('null')
        return '[' + ','.join(vals) + ']'
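
    For comparison, the "collect everything, then delete the trailing nulls" approach from the question can also trim the token list before joining, instead of stripping characters from the finished string. This is only a sketch under that assumption; Node and serialize_trimmed are hypothetical names, and Node stands in for TreeNode so the snippet runs on its own.

```python
from collections import deque

class Node:
    """Hypothetical stand-in for TreeNode, so this sketch is self-contained."""
    def __init__(self, val, left=None, right=None):
        self.val, self.left, self.right = val, left, right

def serialize_trimmed(root):
    vals = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            vals.append('null')
        else:
            vals.append(str(node.val))
            queue.append(node.left)
            queue.append(node.right)
    # Drop the trailing 'null' tokens, which carry no information.
    while vals and vals[-1] == 'null':
        vals.pop()
    return '[' + ','.join(vals) + ']'

tree = Node(1, Node(2), Node(3, Node(4, None, Node(5))))
print(serialize_trimmed(tree))  # prints: [1,2,3,null,null,4,null,null,5]
```

    An empty tree comes out as '[]', because the single 'null' token is trimmed as well.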

  • 0

    Hi StefanPochmann! Nice solutions as always! I'm not comfortable with modifying the list while iterating over it either. My solution looks like a longer version of your third solution, and I like your 4th solution best: it uses todo to keep count of the non-empty nodes, so it doesn't have to strip all the nulls at the end.

    A preliminary test using a larger tree generated from [1]*(2**16-1) shows that your methods 1 and 2 perform faster than 3 and 4, which makes sense since you are not popping anything, only appending.
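
    A comparison along these lines could be set up roughly as follows. This is a sketch, not the test that was actually run: build_complete, serialize_list and serialize_deque are hypothetical names, and serialize_list walks the growing list by index rather than iterating over it while it is modified, so it is a stand-in for methods 1 and 2 rather than a copy of them. Timings will vary by machine and Python version.

```python
import timeit
from collections import deque

class Node:
    """Hypothetical stand-in for TreeNode, so this sketch is self-contained."""
    def __init__(self, val):
        self.val, self.left, self.right = val, None, None

def build_complete(levels):
    """Build a complete binary tree with 2**levels - 1 nodes, all valued 1."""
    nodes = [Node(1) for _ in range(2**levels - 1)]
    for i, node in enumerate(nodes):
        if 2*i + 1 < len(nodes):
            node.left = nodes[2*i + 1]
        if 2*i + 2 < len(nodes):
            node.right = nodes[2*i + 2]
    return nodes[0]

def serialize_list(root):
    # Plain list used as a queue: only appends, never pops.
    vals, queue, i = [], [root], 0
    while i < len(queue):
        node = queue[i]
        i += 1
        if node:
            vals.append(str(node.val))
            queue += node.left, node.right
        else:
            vals.append('null')
    return '[' + ','.join(vals).strip('nul,') + ']'

def serialize_deque(root):
    # deque used as a queue: every node is popped from the left.
    vals, queue = [], deque([root])
    while queue:
        node = queue.popleft()
        if node:
            vals.append(str(node.val))
            queue += node.left, node.right
        else:
            vals.append('null')
    return '[' + ','.join(vals).strip('nul,') + ']'

root = build_complete(16)  # 2**16 - 1 nodes
assert serialize_list(root) == serialize_deque(root)
for fn in (serialize_list, serialize_deque):
    print(fn.__name__, timeit.timeit(lambda: fn(root), number=5))
```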

  • 1

    Great! BTW, I think the first version is the most Stefanic :-)

  • 1

    Thanks @StefanPochmann

    I have created a Node.js package like this.

    var List = require('leetcode').List;
    // { val: 1, next: { val: 2, next: { val: 3, next: null } } }  
    var l = List.create([1, 2, 3]);
    // [1, 2, 3]
    var tree = require('leetcode').Tree;
    var eq = require('assert').equal;
    var t1 = tree.create([1,null,2,3]);
    var t2 = tree.create([1,null,2,3]);
    // Notice: you should implement isSameTree yourself
    eq(isSameTree(t1, t2), true);

  • 1

    I wrote a tree auto generator:

    import random

    class TreeNode(object):
        def __init__(self, x):
            self.val = x
            self.left = None
            self.right = None

    def generateCharSetX():
        return [x for x in range(1000)]

    def generateNodesFromX(x, min_elements, max_elements):
        # Pick a random number of distinct values from x.
        num_nodes = random.randrange(min_elements, max_elements)
        set_nodes = set([])
        while len(set_nodes) < num_nodes:
            set_nodes.add(x[random.randrange(0, len(x))])
        return set_nodes

    #Data structure involved functions
    def getTree(root, s):
        # Preorder dump into the list s, writing 'null' for missing children.
        if root is None:
            s.append('null')
            return
        s.append(root.val)
        getTree(root.left, s)
        getTree(root.right, s)

    def generateTree(s):
        # Insert the values one by one into a binary search tree.
        val = s.pop()
        root = TreeNode(val)
        while len(s) > 0:
            val = s.pop()
            curr = root
            while 1:
                if curr.val > val and curr.left is None:
                    curr.left = TreeNode(val)
                    break
                elif curr.val < val and curr.right is None:
                    curr.right = TreeNode(val)
                    break
                if curr.val > val:
                    curr = curr.left
                elif curr.val < val:
                    curr = curr.right
        return root

    How to use:

    root = generateTree(generateNodesFromX(generateCharSetX(), 20, 40))
    res = []
    getTree(root, res)
    print(res)


    [17, 12, 'null', 'null', 934, 49, 'null', 584, 552, 551, 460, 372, 86, 'null', 122, 'null', 'null', 437, 'null', 440, 'null', 'null', 507, 'null', 541, 'null', 'null', 'null', 'null', 658, 627, 'null', 'null', 788, 725, 'null', 'null', 857, 'null', 'null', 936, 'null', 'null']

  • 0

    Visualize the tree with pygraphviz. 1 to 1 mapping to the serialization.

    import pygraphviz
    from collections import deque

    def build_graph(root):
        G = pygraphviz.AGraph(strict=True, directed=True)
        G.graph_attr["rankdir"] = "TB"
        #G.graph_attr["splines"] = "ortho"
        G.graph_attr["splines"] = "curved"
        #G.graph_attr["size"] = "10,3"
        G.graph_attr["ordering"] = "out"
        def _id(node, _static={"seq":0}):
            # Real nodes are keyed by id(); every null gets a fresh "_N" name.
            if node: return id(node)
            _static["seq"] += 1
            return "_%d" % _static["seq"]
        def graphviz_customize(node):
            return {
                "label": node.val if node else "#",
                "shape": "box" if node else "oval",
                "color": "blue", # if node else "black",
                "fontsize": 8,
                "width": 0.05 if node else 0.02, # in inches
                "height": 0.02 if node else 0.02,
            }
        # mimic the serialization
        if not root:
            id_root = _id(root)
            G.add_node(id_root, **graphviz_customize(root))
            return G
        wq = deque([(root, _id(root), "", 0)])
        count = 1
        while count and wq:
            node, id_node, id_parent, right_child = wq.popleft()
            G.add_node(id_node, **graphviz_customize(node))
            if id_parent:
                if right_child:
                    G.add_edge(id_parent, id_node, color='black', style='dashed')
                else:
                    G.add_edge(id_parent, id_node, color='red', style='dashed')
            if not node: continue
            count -= 1
            for idx, nd in enumerate((node.left, node.right)):
                if nd:
                    count += 1
                wq.append((nd, _id(nd), id_node, idx))
        return G

    from io import BytesIO
    from IPython.display import Image, display

    def display_graph(G):
        imgbuf = BytesIO()
        G.draw(imgbuf, format='png', prog='dot')
        img = Image(imgbuf.getvalue())
        display(img)

    ## main ##
    # deserialize is the function from the first post in this thread
    root = deserialize('[2,1,3,0,7,9,1,2,null,1,0,null,null,8,8,null,null,null,null,7]')
    G = build_graph(root)
    display_graph(G)
  • 1

    Here's another utility to graph leetcode trees

    $ npm install -g treevis
    $ treevis <array>

    will do the trick.
