Graphviz scripts to create simple visualisations of neural networks


Simple feed-forward neural networks

When working with neural networks in deep learning and machine learning, it is often useful to condense the network architecture into a concise diagram that conveys a lot of information at a glance. What follows are some scripts inspired by martisak's dotnets GitHub repository, which contains a rudimentary script that generates a DOT language description interpretable by Graphviz. That script was later ported to the graphviz Python package and extended with more functionality: annotations for activation functions and a representation of recurrent neural networks.

First, the graphviz module is imported and some global variables are set:

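try:
    import graphviz as G
except ImportError as e:
    # everything below depends on graphviz, so stop here with an actionable message
    raise ImportError('The "graphviz" package is not available; install it with "pip install graphviz". '
                      'Rendering additionally needs the Graphviz system binaries (e.g. "dot").') from e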


# boolean variables to denote dense or sparse connections between layers
DENSE = True
SPARSE = False

PENWIDTH = '15'
FONT = 'Hilda 10'

Next, the network architecture from which the graph will be drawn is defined: the number of nodes (perceptrons) in each layer and the type of connection between each pair of adjacent layers. Because connections sit between layers, the connections list must be exactly one element shorter than the layer_nodes list. A SPARSE (one-to-one) connection additionally requires the two layers it joins to have the same number of nodes.

layer_nodes = [6, 4, 4, 4]
connections = [DENSE, DENSE, SPARSE] 
assert len(connections) == (len(layer_nodes) - 1), '"connections" array should be 1 less than the #layers'
for i, type_of_connections in enumerate(connections):
    if type_of_connections == SPARSE:
        assert layer_nodes[i] == layer_nodes[i+1], "If connection type is SPARSE then the number of nodes in the adjacent layers must be equal"
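The asserts above are skipped entirely when Python runs with the `-O` flag, so for reuse it may be sturdier to have a function that reports the problems it finds. A minimal sketch, assuming the same `DENSE`/`SPARSE` flags; `architecture_errors` is a hypothetical name, not part of the original scripts:

```python
DENSE = True
SPARSE = False


def architecture_errors(layer_nodes, connections):
    """Return a list of problems with the architecture (empty if it is valid)."""
    errors = []
    if len(connections) != len(layer_nodes) - 1:
        errors.append(f'expected {len(layer_nodes) - 1} connection types, got {len(connections)}')
        return errors  # the per-connection check below would index out of range
    for i, kind in enumerate(connections):
        # a one-to-one (SPARSE) connection needs equally sized neighbouring layers
        if kind == SPARSE and layer_nodes[i] != layer_nodes[i + 1]:
            errors.append(f'SPARSE connection {i}->{i + 1} needs equal layer sizes, '
                          f'got {layer_nodes[i]} and {layer_nodes[i + 1]}')
    return errors


errors = architecture_errors([6, 4, 4, 4], [DENSE, DENSE, SPARSE])
if errors:
    raise ValueError('; '.join(errors))
```

Returning a list instead of raising on the first problem lets a caller report every inconsistency at once.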

A graph in graphviz has three main components: nodes, edges, and the graph itself. Much as an object is created by instantiating a class, the graphviz library provides a generic Digraph class that is instantiated with the desired attribute presets. The following code creates a directed graph with default graph, node, and edge attributes that control how the drawing will look.

dot = G.Digraph(comment='Neural Network', 
                graph_attr={'nodesep':'0.04', 'ranksep':'0.05', 'bgcolor':'white', 'splines':'line', 'rankdir':'LR', 'fontname':FONT},
                node_attr={'fixedsize':'true', 'label':"", 'style':'filled', 'color':'none', 'fillcolor':'gray', 'shape':'circle', 'penwidth':PENWIDTH, 'width':'0.4', 'height':'0.4'},
                edge_attr={'color':'gray30', 'arrowsize':'.4'})
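Under the hood, these three dictionaries become DOT default-attribute statements (`graph [...]`, `node [...]`, `edge [...]`) at the top of the generated digraph. The sketch below builds such a skeleton by hand to show the idea; it is an approximation, since the graphviz package may quote and order attributes differently, and `attr_statement`/`dot_skeleton` are hypothetical helpers:

```python
def attr_statement(kind, attrs):
    """Format a DOT default-attribute statement, e.g. node [shape="circle"]."""
    # every value is double-quoted so values such as "Hilda 10" stay one token;
    # no escaping is done, so values must not contain double quotes
    pairs = ' '.join(f'{key}="{value}"' for key, value in attrs.items())
    return f'{kind} [{pairs}]'


def dot_skeleton(graph_attr, node_attr, edge_attr, comment='Neural Network'):
    """Return a digraph that only sets defaults; nodes and edges would follow."""
    lines = [f'// {comment}', 'digraph {']
    for kind, attrs in (('graph', graph_attr), ('node', node_attr), ('edge', edge_attr)):
        lines.append('\t' + attr_statement(kind, attrs))
    lines.append('}')
    return '\n'.join(lines)


print(dot_skeleton({'rankdir': 'LR', 'splines': 'line'},
                   {'shape': 'circle', 'label': ''},
                   {'color': 'gray30', 'arrowsize': '.4'}))
```

Every node and edge added later inherits these defaults unless it overrides them, which is why the per-node calls below only pass what differs (label, fill colour, font size).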

Create nodes

for layer_no in range(len(layer_nodes)):
    # each layer goes into its own cluster subgraph so that it can carry a label
    with dot.subgraph(name='cluster_'+str(layer_no)) as c:
        c.attr(color='transparent')  # comment this out to show an outline around each layer
        if layer_no == 0:                          # first layer
            c.attr(label='Input')
        elif layer_no == len(layer_nodes)-1:       # last layer
            c.attr(label='Output')
        else:                                      # layers in between
            c.attr(label='Hidden')
        for a in range(layer_nodes[layer_no]):
            if layer_no == 0:                      # input nodes: plain black circles
                c.node('l'+str(layer_no)+str(a), '', fillcolor='black')
            elif layer_no == len(layer_nodes)-1:   # output nodes: plain black circles
                c.node('l'+str(layer_no)+str(a), '', fontcolor='white', fillcolor='black')
            else:
                # unicode characters can be placed inside the nodes as follows;
                # for a list of unicode characters see https://pythonforundergradengineers.com/unicode-characters-in-python.html
                c.node('l'+str(layer_no)+str(a), '\u03C3', fontsize='12')  # "sigma" inside the hidden nodes

                # for plain text labels such as 'relu' or 'tanh', use:
                # c.node('l'+str(layer_no)+str(a), 'relu', fontsize='12')
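The node IDs concatenate the layer index and the node index, which is unambiguous for the small networks here. It can break once both indices reach two digits: layer 1/node 11 and layer 11/node 1 both become `l111`, and Graphviz treats identical IDs as one node. A minimal sketch of the problem and one possible fix; the helper names are hypothetical, and the edge loop would have to use the same ID function:

```python
def original_id(layer_no, node_no):
    # scheme used in the script: 'l' + layer index + node index
    return 'l' + str(layer_no) + str(node_no)


def safe_id(layer_no, node_no):
    # hypothetical alternative: an underscore keeps the two indices apart
    return f'l{layer_no}_{node_no}'


print(original_id(1, 11), original_id(11, 1))  # l111 l111  (same ID, Graphviz merges the nodes)
print(safe_id(1, 11), safe_id(11, 1))          # l1_11 l11_1
```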

Create edges

for layer_no in range(len(layer_nodes)-1):
    for node_no in range(layer_nodes[layer_no]):
        if connections[layer_no] == DENSE:
            # fully connected: this node feeds every node of the next layer
            for b in range(layer_nodes[layer_no+1]):
                dot.edge('l'+str(layer_no)+str(node_no), 'l'+str(layer_no+1)+str(b))
        elif connections[layer_no] == SPARSE:
            # one-to-one: node i feeds node i of the next layer
            dot.edge('l'+str(layer_no)+str(node_no), 'l'+str(layer_no+1)+str(node_no))
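The connection rules can be checked without Graphviz by collecting the edges into a plain list: a DENSE connection between layers of m and n nodes contributes m × n edges, a SPARSE one contributes m. A minimal sketch with a hypothetical `edge_list` function that mirrors the loop above and its node IDs:

```python
DENSE = True
SPARSE = False


def edge_list(layer_nodes, connections):
    """Return (tail, head) node-ID pairs following the DENSE/SPARSE rules."""
    edges = []
    for layer_no, kind in enumerate(connections):
        for node_no in range(layer_nodes[layer_no]):
            if kind == DENSE:
                # fully connected: every node feeds every node of the next layer
                for b in range(layer_nodes[layer_no + 1]):
                    edges.append((f'l{layer_no}{node_no}', f'l{layer_no + 1}{b}'))
            else:
                # one-to-one: node i feeds node i of the next layer
                edges.append((f'l{layer_no}{node_no}', f'l{layer_no + 1}{node_no}'))
    return edges


edges = edge_list([6, 4, 4, 4], [DENSE, DENSE, SPARSE])
print(len(edges))  # 6*4 + 4*4 + 4 = 44
```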

Render

dot  # in Jupyter, evaluating the graph object as the last expression of a cell displays it inline
[Rendered graph: feed-forward network with 6, 4, 4 and 4 nodes in clusters labelled Input, Hidden, Hidden and Output]

Save/Export

dot.format = 'jpeg'  # other formats include 'pdf', 'svg', 'png', etc.
dot.render('./example_network')
# returns the path of the rendered file: './example_network.jpeg'
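`render()` only works when the Graphviz executables (such as `dot`) are installed and on the PATH; the Python package itself writes DOT text and calls them. As a fallback, the DOT text (available as `dot.source`, or written straight to disk with `dot.save()`) can be stored in a `.gv` file and converted later with the `dot` command-line tool. A minimal sketch; `save_dot_source` is a hypothetical helper and the short DOT string stands in for `dot.source`:

```python
import shutil
import tempfile
from pathlib import Path


def save_dot_source(source, path):
    """Write DOT text to `path`; return the CLI command that renders it as PNG."""
    path = Path(path)
    path.write_text(source, encoding='utf-8')
    # `dot -Tpng input.gv -o output.png` is the usual Graphviz CLI invocation
    return ['dot', '-Tpng', str(path), '-o', str(path.with_suffix('.png'))]


# stand-in for dot.source from the script above
source = 'digraph {\n\tl00 -> l10\n}\n'
with tempfile.TemporaryDirectory() as tmp:
    cmd = save_dot_source(source, Path(tmp) / 'example_network.gv')
    print(' '.join(cmd))
    # shutil.which returns None when the executable is not on the PATH
    print('dot found' if shutil.which('dot') else 'dot not found; run the command elsewhere')
```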

Recurrent neural network

The code above is easy to modify: it can be adapted to other architectures, extended with additional details, or customized in many other ways. Below, the same code is modified to represent recurrent neural networks instead.

layer_nodes = [6, 4, 4, 4]
connections = [DENSE, DENSE, DENSE]
# additional variable to denote which layers consist of recurrent units
recurrent = [False, True, True, False]

assert len(connections) == (len(layer_nodes) - 1), '"connections" array should be 1 less than the #layers'
assert len(recurrent) == len(layer_nodes), '"recurrent" array needs one flag per layer'
for i, type_of_connections in enumerate(connections):
    if type_of_connections == SPARSE:
        assert layer_nodes[i] == layer_nodes[i+1], "If connection type is SPARSE then the number of nodes in the adjacent layers must be equal"

dot = G.Digraph(comment='Neural Network', 
                graph_attr={'nodesep':'0.04', 'ranksep':'0.05', 'bgcolor':'white', 'splines':'line', 'rankdir':'LR', 'fontname':FONT},
                node_attr={'fixedsize':'true', 'label':"", 'style':'filled', 'color':'none', 'fillcolor':'gray', 'shape':'circle', 'penwidth':PENWIDTH, 'width':'0.4', 'height':'0.4'},
                edge_attr={'color':'gray30', 'arrowsize':'.4'})

for layer_no in range(len(layer_nodes)):
    # each layer goes into its own cluster subgraph so that it can carry a label
    with dot.subgraph(name='cluster_'+str(layer_no)) as c:
        c.attr(color='transparent')  # comment this out to show an outline around each layer
        if layer_no == 0:                          # first layer
            c.attr(label='Input')
        elif layer_no == len(layer_nodes)-1:       # last layer
            c.attr(label='Output')
        else:                                      # layers in between
            c.attr(label='Hidden')
        for a in range(layer_nodes[layer_no]):
            if layer_no == 0:                      # input nodes: plain black circles
                c.node('l'+str(layer_no)+str(a), '', fillcolor='black')
            elif layer_no == len(layer_nodes)-1:   # output nodes: plain black circles
                c.node('l'+str(layer_no)+str(a), '', fontcolor='white', fillcolor='black')
            else:
                # unicode characters can be placed inside the nodes as follows;
                # for a list of unicode characters see https://pythonforundergradengineers.com/unicode-characters-in-python.html
                # c.node('l'+str(layer_no)+str(a), '\u03C3', fontsize='12')  # "sigma" inside the hidden nodes

                # for plain text labels such as 'relu' or 'tanh', use:
                c.node('l'+str(layer_no)+str(a), 'tanh', fontsize='12')
                
for layer_no in range(len(layer_nodes)):
    for node_no in range(layer_nodes[layer_no]):
        node = 'l'+str(layer_no)+str(node_no)
        # recurrent units get a blue self-loop; change the label 'x10' to the
        # number of time steps into which the recurrent unit unrolls
        if recurrent[layer_no]:
            dot.edge(node, node, xlabel='x10', color='blue', fontcolor='blue')
        if layer_no == len(layer_nodes)-1:
            continue  # the last layer has no outgoing connections
        if connections[layer_no] == DENSE:
            for b in range(layer_nodes[layer_no+1]):
                dot.edge(node, 'l'+str(layer_no+1)+str(b))
        elif connections[layer_no] == SPARSE:
            dot.edge(node, 'l'+str(layer_no+1)+str(node_no))
dot
[Rendered graph: recurrent network with clusters labelled Input, Hidden, Hidden and Output; each hidden node carries a blue self-loop labelled "x10"]

Save/Export

dot.format = 'pdf'  # other formats include 'svg', 'jpeg', 'png', etc.
dot.render('./example_recurrent_network')
# returns './example_recurrent_network.pdf'
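The self-loop rule can also be inspected without Graphviz. The sketch below lists, as plain DOT edge statements, the blue loops that each recurrent layer receives, including the last layer if it is flagged. The function name and the `steps` parameter (which fills the `x10` label) are illustrative assumptions; the node IDs follow the script's concatenation scheme:

```python
def self_loop_lines(layer_nodes, recurrent, steps=10):
    """DOT edge statements for the self-loops drawn on recurrent layers."""
    if len(recurrent) != len(layer_nodes):
        raise ValueError('"recurrent" needs one flag per layer')
    lines = []
    for layer_no, (n, is_recurrent) in enumerate(zip(layer_nodes, recurrent)):
        if not is_recurrent:
            continue
        for node_no in range(n):
            node = f'l{layer_no}{node_no}'
            # a self-loop is just an edge whose tail and head are the same node
            lines.append(f'{node} -> {node} [xlabel="x{steps}" color="blue" fontcolor="blue"]')
    return lines


lines = self_loop_lines([6, 4, 4, 4], [False, True, True, False])
print(len(lines))  # 8 (two recurrent layers of 4 nodes)
print(lines[0])    # l10 -> l10 [xlabel="x10" color="blue" fontcolor="blue"]
```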

Some additional useful scripts


1. Unrolled representation of a single recurrent unit

TIMESTEPS = 6
TIME_OFFSET = 3

unrolled = G.Digraph(node_attr={'shape':'circle', 'fixedsize':'true'}, graph_attr={'style':'invis', 'rankdir':'BT', 'color':'transparent'})
for step in range(TIMESTEPS+2):
    # one cluster per time step
    with unrolled.subgraph(name='cluster_'+str(step)) as c:
        if step == 0 or step == TIMESTEPS+1:
            # padding steps at both ends: an ellipsis with hidden edges
            c.node('a'+str(step), '', color='transparent')
            c.node('b'+str(step), '...', color='transparent')
            c.node('c'+str(step), '', color='transparent')
            c.edge('a'+str(step), 'b'+str(step), style='invis')
            c.edge('b'+str(step), 'c'+str(step), style='invis')
        else:
            # labels run t+2, t+1, t, t-1, ... relative to TIME_OFFSET
            c.node('a'+str(step), '', color='transparent')
            c.node('b'+str(step), 't'+'{:=+d}'.format(TIME_OFFSET-step) if TIME_OFFSET-step else 't')
            c.node('c'+str(step), '', color='transparent')
            c.edge('a'+str(step), 'b'+str(step))
            c.edge('b'+str(step), 'c'+str(step))
# blue recurrent connections between consecutive time steps
for step in range(1, TIMESTEPS+2):
    unrolled.edge('b'+str(step-1), 'b'+str(step), constraint='false', dir='back', color='blue')
unrolled
[Rendered graph: a row of unrolled steps labelled ..., t+2, t+1, t, t-1, t-2, t-3, ... linked by blue arrows]
unrolled.render('./unrolled')
# returns './unrolled.pdf' (PDF is the default output format)
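The step labels come from a signed format specification: `'{:+d}'` always prints the sign, and without a width it gives the same result as the `'{:=+d}'` used above. A small sketch that isolates the labelling logic (`time_labels` is a hypothetical helper) makes it easy to check the labels before rendering:

```python
def time_labels(timesteps, time_offset):
    """Labels of the visible (non-padding) unrolled steps."""
    labels = []
    for step in range(1, timesteps + 1):
        delta = time_offset - step
        # 2 -> 't+2', -1 -> 't-1'; a zero offset is shown as plain 't'
        labels.append('t' + '{:+d}'.format(delta) if delta else 't')
    return labels


print(time_labels(6, 3))  # ['t+2', 't+1', 't', 't-1', 't-2', 't-3']
```

Changing `TIME_OFFSET` shifts which step is labelled `t`.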

2. Single recurrent unit

ru = G.Digraph(node_attr={'shape':'circle', 'fixedsize':'true'}, graph_attr={'style':'invis', 'rankdir':'LR'})
ru.node('a', '', color='transparent')
ru.node('b', 'N')
ru.node('c', '', color='transparent')
ru.edge('a', 'b')
ru.edge('b', 'c')
ru.edge('b', 'b', color='blue')
ru
[Rendered graph: a single node N with an incoming arrow, an outgoing arrow and a blue self-loop]
ru.render('./rnn')
# returns './rnn.pdf'
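For reference, the single-unit script corresponds roughly to the DOT text below, which can be pasted into any Graphviz tool. This is a hand-written sketch, not the exact output of the graphviz package (which may order and quote attributes differently); `recurrent_unit_dot` and its parameters are hypothetical:

```python
def recurrent_unit_dot(label='N', loop_color='blue'):
    """DOT text for one node with an input stub, an output stub and a self-loop."""
    return '\n'.join([
        'digraph {',
        '\tgraph [rankdir="LR" style="invis"]',
        '\tnode [shape="circle" fixedsize="true"]',
        '\ta [label="" color="transparent"]',  # invisible input stub
        f'\tb [label="{label}"]',              # the recurrent unit itself
        '\tc [label="" color="transparent"]',  # invisible output stub
        '\ta -> b',
        '\tb -> c',
        f'\tb -> b [color="{loop_color}"]',    # recurrent self-loop
        '}',
    ])


print(recurrent_unit_dot())
```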