# How to plot (visualize) a neural network in Python using Graphviz?

Published: September 18, 2020

To illustrate a research project that used a neural network, I needed a simple visualization tool. This post shows what I obtained with the Graphviz library.

### Install Graphviz in Python

To install Graphviz with Anaconda, start with the following command:

```
conda install -c anaconda graphviz
```

Note: at this stage, `import graphviz` still fails with the error message `ModuleNotFoundError: No module named 'graphviz'`, because the command above installs the Graphviz software but not the Python package. It is then necessary to install python-graphviz as well:

```
conda install -c conda-forge python-graphviz
```
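To check that both installs worked, a quick test like the one below can help. It is only a sketch using the standard library: the helper name `check_graphviz_setup` is my own (not part of Graphviz), and it only checks that the `graphviz` Python package can be found and that a `dot` executable is on the PATH.

```python
import importlib.util
import shutil


def check_graphviz_setup():
    """Return (python_package_found, path_to_dot_or_None).

    Hypothetical helper, not part of the graphviz package.
    """
    # The Python package is importable only once python-graphviz is installed.
    has_package = importlib.util.find_spec("graphviz") is not None
    # Rendering also needs the Graphviz `dot` program on the PATH.
    dot_path = shutil.which("dot")
    return has_package, dot_path


has_package, dot_path = check_graphviz_setup()
print("Python package 'graphviz' found:", has_package)
print("'dot' executable:", dot_path or "not found on PATH")
```

If the first line prints `False`, the python-graphviz install is probably missing; if `dot` is not found, the Graphviz software itself is probably missing or not on the PATH.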

### Plot a simple graph with graphviz

Now we can plot a simple graph with graphviz (see, for example, the User Guide):

```
>>> from graphviz import Digraph
>>> dot = Digraph(comment='A simple Graph')
>>> dot.node('A', 'Cloudy')
>>> dot.node('B', 'Sunny')
>>> dot.node('C', 'Rainy')
>>> dot.edges(['AB', 'AC'])
>>> dot.edge('B', 'C', constraint='false')
>>> dot.format = 'png'
>>> dot.render('my_graph', view=False)
'my_graph.png'
```

This renders the graph to the image file `my_graph.png`.

Note: `dot.source` returns the DOT statements required to build the graph (they can also be saved in a text file and used to build the graph from there):

```
>>> print(dot.source)
// A simple Graph
digraph {
    A [label=Cloudy]
    B [label=Sunny]
    C [label=Rainy]
    A -> B
    A -> C
    B -> C [constraint=false]
}
```
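Saving the DOT text to a file and building the graph from it could look like the sketch below. It calls the `dot` command-line program directly (`dot -Tpng input -o output`) instead of the Python package. The function name `save_and_render` and the file names `my_graph.gv` and `my_graph.png` are assumptions chosen for illustration.

```python
import shutil
import subprocess
from pathlib import Path

# Same DOT text as printed by dot.source above.
DOT_SOURCE = """// A simple Graph
digraph {
    A [label=Cloudy]
    B [label=Sunny]
    C [label=Rainy]
    A -> B
    A -> C
    B -> C [constraint=false]
}
"""


def save_and_render(source, dot_file="my_graph.gv", png_file="my_graph.png"):
    """Write DOT text to dot_file, then try to render it to png_file.

    Hypothetical helper. Returns True only if the `dot` program was
    found and exited successfully; the text file is written either way.
    """
    Path(dot_file).write_text(source, encoding="utf-8")
    if shutil.which("dot") is None:
        return False  # DOT file saved, but nothing was rendered
    result = subprocess.run(["dot", "-Tpng", dot_file, "-o", png_file])
    return result.returncode == 0
```

Calling `save_and_render(DOT_SOURCE)` should leave `my_graph.gv` in the current directory and, if Graphviz is installed, a PNG next to it.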

### Plot a neural network with graphviz

Then I used the template provided by Zeyuan Hu (see Zeyuan Hu's page) to plot a simple neural network:

```
>>> graph = temp = '''
... digraph G {
...
...      graph[ fontname = "Helvetica-Oblique",
...             fontsize = 12,
...             label = "",
...             size = "7.75,10.25" ];
...
...     rankdir = LR;
...     splines=false;
...     edge[style=invis];
...     ranksep= 1.4;
...     {
...         node [shape=circle, color=chartreuse, style=filled, fillcolor=chartreuse];
...         x1 [label=<x1>];
...         x2 [label=<x2>];
...     }
...     {
...         node [shape=circle, color=dodgerblue, style=filled, fillcolor=dodgerblue];
...         a12 [label=<a<sub>1</sub><sup>(2)</sup>>];
...         a22 [label=<a<sub>2</sub><sup>(2)</sup>>];
...         a32 [label=<a<sub>3</sub><sup>(2)</sup>>];
...         a42 [label=<a<sub>4</sub><sup>(2)</sup>>];
...         a52 [label=<a<sub>5</sub><sup>(2)</sup>>];
...         a13 [label=<a<sub>1</sub><sup>(3)</sup>>];
...         a23 [label=<a<sub>2</sub><sup>(3)</sup>>];
...         a33 [label=<a<sub>3</sub><sup>(3)</sup>>];
...         a43 [label=<a<sub>4</sub><sup>(3)</sup>>];
...         a53 [label=<a<sub>5</sub><sup>(3)</sup>>];
...     }
...     {
...         node [shape=circle, color=coral1, style=filled, fillcolor=coral1];
...         O1 [label=<y1>];
...         O2 [label=<y2>];
...         O3 [label=<y3>];
...     }
...     {
...         rank=same;
...         x1->x2;
...     }
...     {
...         rank=same;
...         a12->a22->a32->a42->a52;
...     }
...     {
...         rank=same;
...         a13->a23->a33->a43->a53;
...     }
...     {
...         rank=same;
...         O1->O2->O3;
...     }
...     l0 [shape=plaintext, label="layer 1 (input layer)"];
...     l0->x1;
...     {rank=same; l0;x1};
...     l1 [shape=plaintext, label="layer 2 (hidden layer)"];
...     l1->a12;
...     {rank=same; l1;a12};
...     l2 [shape=plaintext, label="layer 3 (hidden layer)"];
...     l2->a13;
...     {rank=same; l2;a13};
...     l3 [shape=plaintext, label="layer 4 (output layer)"];
...     l3->O1;
...     {rank=same; l3;O1};
...     edge[style=solid, tailport=e, headport=w];
...     {x1; x2} -> {a12;a22;a32;a42;a52};
...     {a12;a22;a32;a42;a52} -> {a13;a23;a33;a43;a53};
...     {a13;a23;a33;a43;a53} -> {O1;O2;O3};
... }'''
>>> from graphviz import Source
>>> dot = Source(graph)
>>> dot.format = 'png'
>>> dot.render('neural_network_01', view=False)
'neural_network_01.png'
```

This renders the network to the image file `neural_network_01.png`.
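Writing every node by hand gets tedious for other layer sizes, so the DOT text can also be generated. The function below is a simplified sketch of that idea and not Zeyuan Hu's template: the name `layered_network_dot` and the node names `n<layer>_<unit>` are my own, and it leaves out the layer captions, the HTML-like labels and the invisible edges that fix the vertical order.

```python
def layered_network_dot(layer_sizes, rankdir="LR"):
    """Return DOT text for a fully connected network.

    layer_sizes lists the number of units per layer, e.g. [2, 5, 5, 3].
    Hypothetical helper: node names and layout are simplified.
    """
    lines = [
        "digraph G {",
        f"    rankdir = {rankdir};",
        "    splines = false;",
        "    ranksep = 1.4;",
        "    node [shape=circle, style=filled];",
    ]
    layers = []
    last = len(layer_sizes)
    for layer_index, size in enumerate(layer_sizes, start=1):
        names = [f"n{layer_index}_{unit}" for unit in range(1, size + 1)]
        layers.append(names)
        # Same colors as the template: input, hidden and output layers.
        if layer_index == 1:
            color = "chartreuse"
        elif layer_index == last:
            color = "coral1"
        else:
            color = "dodgerblue"
        members = "; ".join(names)
        lines.append(
            f"    {{ node [color={color}, fillcolor={color}]; {members}; }}"
        )
    # Connect every unit of one layer to every unit of the next one.
    for left, right in zip(layers, layers[1:]):
        left_group = "{" + "; ".join(left) + "}"
        right_group = "{" + "; ".join(right) + "}"
        lines.append(f"    {left_group} -> {right_group};")
    lines.append("}")
    return "\n".join(lines)


print(layered_network_dot([2, 5, 5, 3]))
```

The returned string can be passed to `Source(...)` exactly like the `graph` variable above.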

Note 1: to change the image size or the font size, edit the following lines:

```
...      graph[ fontname = "Helvetica-Oblique",
...             fontsize = 12,
...             label = "",
...             size = "7.75,10.25" ];
```
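To avoid editing the template by hand each time, these lines can be produced from parameters. The sketch below uses `string.Template` from the standard library; the function name `graph_attributes` and its parameters are my own. In Graphviz, `size` is given in inches (width,height) and `fontsize` in points. As far as I understand, the graph-level `fontsize` applies to the graph label, so changing the text inside the nodes would probably need a `fontsize` in the `node [...]` lines instead.

```python
from string import Template

# Same four lines as in the template, with placeholders for the values.
GRAPH_ATTRIBUTES = Template(
    'graph[ fontname = "Helvetica-Oblique",\n'
    '       fontsize = $fontsize,\n'
    '       label = "",\n'
    '       size = "$width,$height" ];'
)


def graph_attributes(fontsize=12, width=7.75, height=10.25):
    """Return the graph[...] statement (hypothetical helper).

    fontsize is in points; width and height are in inches.
    """
    return GRAPH_ATTRIBUTES.substitute(
        fontsize=fontsize, width=width, height=height
    )


print(graph_attributes(fontsize=16, width=5, height=5))
```

The result can be pasted into the DOT text in place of the original `graph[...]` statement.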

Note 2: this also works well in a Jupyter notebook.
