How to plot (visualize) a neural network in Python using Graphviz?


To illustrate a research project that used a neural network, I needed a simple visualization tool. Here are some results based on the Graphviz library:

Install Graphviz in Python

To install Graphviz with Anaconda, two commands are needed. The first one installs the Graphviz software itself:

conda install -c anaconda graphviz

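Note: at this stage, running import graphviz in Python fails with the error message ModuleNotFoundError: No module named 'graphviz'. The package above only provides the Graphviz software (including the dot program); the graphviz Python module comes from python-graphviz, which must be installed as well: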

conda install -c conda-forge python-graphviz
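Since python-graphviz renders images by calling the dot program of the Graphviz software, both components must be present. A minimal, standard-library-only sketch to check them (the function name check_graphviz_setup is my own, not part of either package):

```python
import importlib.util
import shutil


def check_graphviz_setup():
    """Report which of the two Graphviz components can be found."""
    return {
        # provided by python-graphviz: the importable 'graphviz' Python module
        "python_package": importlib.util.find_spec("graphviz") is not None,
        # provided by the Graphviz software: the 'dot' executable used for rendering
        "dot_executable": shutil.which("dot") is not None,
    }


if __name__ == "__main__":
    for component, found in check_graphviz_setup().items():
        print(f"{component}: {'found' if found else 'MISSING'}")
```

If python_package is found but dot_executable is missing, importing works but rendering will fail.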

Plot a simple graph with Graphviz

Now we can plot a simple graph with graphviz (see, for example, the User Guide):

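>>> from graphviz import Digraph
>>> dot = Digraph(comment='A simple Graph')  # directed graph; the comment appears in the DOT source
>>> dot.node('A', 'Cloudy')  # node name 'A', displayed label 'Cloudy'
>>> dot.node('B', 'Sunny')
>>> dot.node('C', 'Rainy')
>>> dot.edges(['AB', 'AC'])  # each two-character string is one edge: A -> B and A -> C
>>> dot.edge('B', 'C', constraint='false')  # this edge is ignored when ranking the nodes
>>> dot.format = 'png'
>>> dot.render('my_graph', view=False)  # saves the DOT source to 'my_graph' and the image to 'my_graph.png'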
'my_graph.png'

which produces the following image:

Example of a simple graph with graphviz

Note: dot.source returns the DOT source code that describes the graph. It can be saved to a text file and used later to build the graph again:

>>> print(dot.source) 
// A simple Graph
digraph {
    A [label=Cloudy]
    B [label=Sunny]
    C [label=Rainy]
    A -> B
    A -> C
    B -> C [constraint=false]
}
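For example, the source printed above can be written to a .gv file and rendered with the dot command-line tool (dot -Tpng my_graph.gv -o my_graph.png). The sketch below does this from Python with the standard library only; the helper name save_and_render and the .gv file name are assumptions for illustration, and it assumes the dot executable is on the PATH (otherwise only the text file is written):

```python
import shutil
import subprocess
import tempfile
from pathlib import Path

# DOT source as printed by print(dot.source) above
DOT_SOURCE = """// A simple Graph
digraph {
    A [label=Cloudy]
    B [label=Sunny]
    C [label=Rainy]
    A -> B
    A -> C
    B -> C [constraint=false]
}
"""


def save_and_render(source, directory, stem="my_graph"):
    """Write DOT source to <stem>.gv and, if 'dot' is on the PATH, render <stem>.png."""
    directory = Path(directory)
    gv_path = directory / f"{stem}.gv"
    gv_path.write_text(source, encoding="utf-8")

    dot = shutil.which("dot")
    if dot is None:
        # Graphviz software not installed: only the text file is produced
        return gv_path, None

    png_path = directory / f"{stem}.png"
    # Same as running: dot -Tpng my_graph.gv -o my_graph.png
    subprocess.run([dot, "-Tpng", str(gv_path), "-o", str(png_path)], check=True)
    return gv_path, png_path


if __name__ == "__main__":
    print(save_and_render(DOT_SOURCE, tempfile.mkdtemp()))
```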

Plot a neural network with graphviz

Then, I used the template provided by Zeyuan Hu (see Zeyuan Hu's page) to plot a simple neural network:

>>> graph = '''
... digraph G {
... 
...      graph[ fontname = "Helvetica-Oblique",
...             fontsize = 12,
...             label = "",
...             size = "7.75,10.25" ];
... 
...     rankdir = LR;
...     splines=false;
...     edge[style=invis];  // edges declared from here on are invisible and only used for layout
...     ranksep= 1.4;
...     {
...         // input layer
...         node [shape=circle, color=chartreuse, style=filled, fillcolor=chartreuse];
...         x1 [label=<x1>];
...         x2 [label=<x2>];
...     }
...     {
...         // hidden layers
...         node [shape=circle, color=dodgerblue, style=filled, fillcolor=dodgerblue];
...         a12 [label=<a<sub>1</sub><sup>(2)</sup>>];
...         a22 [label=<a<sub>2</sub><sup>(2)</sup>>];
...         a32 [label=<a<sub>3</sub><sup>(2)</sup>>];
...         a42 [label=<a<sub>4</sub><sup>(2)</sup>>];
...         a52 [label=<a<sub>5</sub><sup>(2)</sup>>];
...         a13 [label=<a<sub>1</sub><sup>(3)</sup>>];
...         a23 [label=<a<sub>2</sub><sup>(3)</sup>>];
...         a33 [label=<a<sub>3</sub><sup>(3)</sup>>];
...         a43 [label=<a<sub>4</sub><sup>(3)</sup>>];
...         a53 [label=<a<sub>5</sub><sup>(3)</sup>>];
...     }
...     {
...         // output layer
...         node [shape=circle, color=coral1, style=filled, fillcolor=coral1];
...         O1 [label=<y1>];
...         O2 [label=<y2>];
...         O3 [label=<y3>];
...     }
...     {
...         rank=same;
...         x1->x2;
...     }
...     {
...         rank=same;
...         a12->a22->a32->a42->a52;
...     }
...     {
...         rank=same;
...         a13->a23->a33->a43->a53;
...     }
...     {
...         rank=same;
...         O1->O2->O3;
...     }
...     l0 [shape=plaintext, label="layer 1 (input layer)"];
...     l0->x1;
...     {rank=same; l0;x1};
...     l1 [shape=plaintext, label="layer 2 (hidden layer)"];
...     l1->a12;
...     {rank=same; l1;a12};
...     l2 [shape=plaintext, label="layer 3 (hidden layer)"];
...     l2->a13;
...     {rank=same; l2;a13};
...     l3 [shape=plaintext, label="layer 4 (output layer)"];
...     l3->O1;
...     {rank=same; l3;O1};
...     edge[style=solid, tailport=e, headport=w];  // visible edges between layers, from the east side to the west side of the nodes
...     {x1; x2} -> {a12;a22;a32;a42;a52};
...     {a12;a22;a32;a42;a52} -> {a13;a23;a33;a43;a53};
...     {a13;a23;a33;a43;a53} -> {O1;O2;O3};
... }'''
>>> from graphviz import Source
>>> dot = Source(graph)
>>> dot.format = 'png'
>>> dot.render('neural_network_01', view=False) 
'neural_network_01.png'

which produces:

Neural network plotted with Graphviz (neural_network_01.png)
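The template hard-codes the number of neurons in each layer. As a sketch, the same kind of DOT source can be generated for any layer sizes; the function neural_network_dot and its node names (n0_1, n0_2, ...) are hypothetical, and the output follows the template's approach (invisible edges inside rank=same groups to stack each layer, then solid edges between consecutive layers) rather than reproducing it character for character:

```python
def neural_network_dot(layer_sizes, fontsize=12, size="7.75,10.25"):
    """Return DOT source for a fully connected network, e.g. layer_sizes=[2, 5, 5, 3]."""
    n_layers = len(layer_sizes)
    if n_layers < 2:
        raise ValueError("need at least an input layer and an output layer")
    colors = {"input": "chartreuse", "hidden": "dodgerblue", "output": "coral1"}

    def kind(layer):
        if layer == 0:
            return "input"
        return "output" if layer == n_layers - 1 else "hidden"

    def node_id(layer, i):
        return f"n{layer}_{i}"

    def node_label(layer, i):
        if kind(layer) == "input":
            return f"<x{i}>"
        if kind(layer) == "output":
            return f"<y{i}>"
        # HTML-like label a_i^(k), with layers numbered from 1 as in the template
        return f"<a<sub>{i}</sub><sup>({layer + 1})</sup>>"

    lines = [
        "digraph G {",
        f'    graph [fontname="Helvetica-Oblique", fontsize={fontsize}, label="", size="{size}"];',
        "    rankdir=LR;",
        "    splines=false;",
        "    ranksep=1.4;",
        "    edge [style=invis];",  # edges inside a layer only fix the vertical order
    ]
    for layer, n in enumerate(layer_sizes):
        color = colors[kind(layer)]
        title = f"l{layer}"
        ids = [node_id(layer, i) for i in range(1, n + 1)]
        lines += [
            "    {",
            "        rank=same;",
            # declared before the node defaults below, so it stays plain text
            f'        {title} [shape=plaintext, label="layer {layer + 1} ({kind(layer)} layer)"];',
            f"        node [shape=circle, color={color}, style=filled, fillcolor={color}];",
        ]
        lines += [f"        {nid} [label={node_label(layer, i)}];" for i, nid in enumerate(ids, start=1)]
        # invisible chain: layer title on top, then the neurons in order
        lines.append("        " + " -> ".join([title] + ids) + ";")
        lines.append("    }")
    # visible edges: each neuron connects to every neuron of the next layer
    lines.append("    edge [style=solid, tailport=e, headport=w];")
    for layer in range(n_layers - 1):
        src = "; ".join(node_id(layer, i) for i in range(1, layer_sizes[layer] + 1))
        dst = "; ".join(node_id(layer + 1, i) for i in range(1, layer_sizes[layer + 1] + 1))
        lines.append(f"    {{{src}}} -> {{{dst}}};")
    lines.append("}")
    return "\n".join(lines)


if __name__ == "__main__":
    print(neural_network_dot([2, 5, 5, 3]))
```

The returned string can be passed to Source and rendered exactly like the template above, for example Source(neural_network_dot([2, 5, 5, 3])) for the network shown here.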

Note 1: to change the image size or the font size, edit these lines (size gives the maximum width and height of the drawing, in inches):

...      graph[ fontname = "Helvetica-Oblique",
...             fontsize = 12,
...             label = "",
...             size = "7.75,10.25" ];
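Alternatively, the values can be substituted in Python before the string is passed to Source. A minimal sketch, assuming the graph[...] block is written as in the template (set_graph_size is a hypothetical helper, and only the first size and fontsize assignments in the string are replaced):

```python
import re


def set_graph_size(dot_source, size=None, fontsize=None):
    """Return a copy of dot_source with new size and/or fontsize values."""
    if size is not None:
        # \b keeps 'size' from matching inside 'fontsize'; value is "width,height" in inches
        dot_source = re.sub(r'\bsize\s*=\s*"[^"]*"', lambda m: f'size = "{size}"', dot_source, count=1)
    if fontsize is not None:
        dot_source = re.sub(r"\bfontsize\s*=\s*[\d.]+", lambda m: f"fontsize = {fontsize}", dot_source, count=1)
    return dot_source
```

For example, set_graph_size(graph, size="4,6", fontsize=18) returns a modified copy of the template that can then be given to Source.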

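Note 2: this also works well in a Jupyter notebook, where a graph object evaluated as the last expression of a cell is displayed inline (see the image below).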

Neural network plotted with Graphviz in a Jupyter notebook

References
