How to create a dendrogram in Python using scipy and matplotlib?

Published: September 06, 2024

Updated: September 07, 2024

Tags: Matplotlib;

Introduction

A dendrogram is a tree-like diagram that visualizes the arrangement of clusters produced by hierarchical clustering. It shows how individual data points (or clusters) are merged step by step according to their similarity or distance. The height at which two branches are joined is the distance at which those clusters were merged, so lower joins indicate more similar clusters. Dendrograms are commonly used in data analysis to explore the underlying structure of a dataset and to help choose a suitable number of clusters.
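
To make the step-by-step merging concrete, here is a deliberately naive sketch of agglomerative clustering with single linkage. The helper naive_single_linkage and the five sample points are hypothetical, chosen only for illustration. Each recorded merge (cluster a, cluster b, distance, size) is compared with the linkage matrix that SciPy computes for the same points; the distance column is the height at which the dendrogram joins the two branches.

```python
import numpy as np
from scipy.cluster.hierarchy import linkage


def naive_single_linkage(points):
    """Hypothetical helper: naive O(n^3) agglomerative clustering with single linkage."""
    n = len(points)
    clusters = {i: [i] for i in range(n)}  # cluster id -> indices of member points
    next_id = n  # new clusters get ids n, n+1, ... as in SciPy's linkage matrix
    merges = []
    while len(clusters) > 1:
        best = None
        ids = sorted(clusters)
        for pos, a in enumerate(ids):
            for b in ids[pos + 1:]:
                # Single linkage: distance between the closest pair of members
                d = min(np.linalg.norm(points[i] - points[j])
                        for i in clusters[a] for j in clusters[b])
                if best is None or d < best[0]:
                    best = (d, a, b)
        d, a, b = best
        clusters[next_id] = clusters.pop(a) + clusters.pop(b)  # merge the closest pair
        merges.append((a, b, d, len(clusters[next_id])))
        next_id += 1
    return np.array(merges, dtype=float)


points = np.array([[0.0, 0.0], [0.0, 1.0], [4.0, 0.0], [4.0, 1.5], [10.0, 0.0]])
print(naive_single_linkage(points))
print(linkage(points, method="single"))
```

The sketch recomputes every pairwise cluster distance at each step, so it only suits tiny inputs; SciPy's linkage uses much more efficient algorithms.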

In Python, you can build a dendrogram with the scipy library, whose scipy.cluster.hierarchy module provides the hierarchical clustering tools, and draw it with matplotlib. Here's a step-by-step guide:

Install required libraries

If you haven't already installed scipy and matplotlib, you can install them using:

pip install scipy matplotlib

Basic Code

Import necessary modules

You will need to import the linkage and dendrogram functions from scipy.cluster.hierarchy, and matplotlib.pyplot for plotting:

import numpy as np
import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import dendrogram, linkage

Prepare your data

For this example, let's create a small sample dataset by hand:

# Sample data (for example, 5 data points with 2 features each)
X = np.array([[1, 2],
              [3, 4],
              [5, 6],
              [7, 8],
              [9, 10]])

Perform hierarchical clustering

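Use the linkage function to compute the hierarchical clustering. Its method argument sets how the distance between two clusters is measured; common choices are ward, single, complete, and average. Here's an example with the ward method, which at each step merges the pair of clusters whose union least increases the total within-cluster variance: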

# Perform hierarchical clustering
Z = linkage(X, method='ward')
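
To see how the choice of method affects the merge heights, the sketch below computes the linkage for the same five points with each method; merge_heights is a hypothetical helper. Because the points are evenly spaced along a line, single linkage joins every pair of neighbours at the same distance, √8 ≈ 2.83, while the final complete-linkage merge happens at the distance between the two extreme points, √128 ≈ 11.31.

```python
import numpy as np
from scipy.cluster.hierarchy import linkage

X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])


def merge_heights(data, methods=("single", "complete", "average", "ward")):
    """Hypothetical helper: merge distances (column 2 of the linkage matrix) per method."""
    return {method: linkage(data, method=method)[:, 2] for method in methods}


for method, heights in merge_heights(X).items():
    print(f"{method:>8}: {np.round(heights, 2)}")
```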

Plot the dendrogram

Finally, pass the linkage matrix Z to the dendrogram function and show the plot:

# Plot dendrogram
plt.figure(figsize=(10, 7))
dendrogram(Z)
plt.title("Dendrogram")
plt.xlabel("Data points")
plt.ylabel("Ward linkage distance")
plt.show()

Running this displays the dendrogram of your dataset, showing which data points are grouped together and at what distance each merge happens.
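
Cutting the tree at a chosen height turns the hierarchy into flat clusters. Here is a minimal sketch with scipy.cluster.hierarchy.fcluster on a hypothetical toy dataset of two well-separated groups: criterion="maxclust" asks for at most a given number of clusters, and criterion="distance" cuts at a height threshold, here placed between the last two merge heights.

```python
import numpy as np
from scipy.cluster.hierarchy import fcluster, linkage

# Hypothetical toy data: two tight groups far apart
X = np.array([[0, 0], [0, 1], [1, 0], [10, 10], [10, 11], [11, 10]], dtype=float)
Z = linkage(X, method="ward")

# Ask for (at most) two flat clusters
labels_k = fcluster(Z, t=2, criterion="maxclust")

# Cut at a height between the second-to-last and the last merge
cut = 0.5 * (Z[-2, 2] + Z[-1, 2])
labels_h = fcluster(Z, t=cut, criterion="distance")

print(labels_k)
print(labels_h)
```

Both calls return one integer label per point, starting at 1.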

Full Code

import numpy as np
import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import dendrogram, linkage

# Sample data (for example, 5 data points with 2 features each)
X = np.array([[1, 2],
              [3, 4],
              [5, 6],
              [7, 8],
              [9, 10]])

# Perform hierarchical clustering
Z = linkage(X, method='ward')

# Plot dendrogram
plt.figure(figsize=(10, 7))
dendrogram(Z)
plt.title("Dendrogram")
plt.xlabel("Data points")
plt.ylabel("Ward linkage distance")
plt.show()

Customizing the Dendrogram

The following example customizes the dendrogram by annotating each merge with its distance, truncating the tree, and optionally showing leaf counts. Here's a step-by-step explanation of the code:

1. Importing Libraries:

from scipy.cluster.hierarchy import dendrogram, linkage
import matplotlib.pyplot as plt
import numpy as np
  • scipy.cluster.hierarchy.dendrogram: Used to generate the dendrogram plot.
  • scipy.cluster.hierarchy.linkage: Performs the hierarchical clustering.
  • matplotlib.pyplot: Used for plotting.
  • numpy: Used to generate the random data points for clustering.

2. Custom Dendrogram Function:

def custom_dendrogram(*args, **kwargs):
    dendro_data = dendrogram(*args, **kwargs)

    if not kwargs.get('no_plot', False):
        for icoord, dcoord in zip(dendro_data['icoord'], dendro_data['dcoord']):
            x_coord = 0.5 * sum(icoord[1:3])  # Midpoint of the horizontal link
            height = dcoord[1]  # Distance at which the clusters merge
            plt.plot(x_coord, height, 'ro')  # Plot a red dot at the merge point
            plt.annotate(f"{height:.3g}", (x_coord, height), xytext=(0, -8),
                         textcoords='offset points', va='top', ha='center')  # Annotate height

    return dendro_data
  • This function wraps scipy's dendrogram and, at each merge point, adds a red dot labeled with the merge height, making the dendrogram easier to read. It assumes the default orientation='top', in which the heights are the y-coordinates stored in dcoord.
  • The annotations show the exact distances at which clusters merge, providing more detail on the hierarchical structure.
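
The wrapper relies on the layout data that dendrogram returns: each link is drawn as a U shape whose four x-values are in icoord and four y-values in dcoord, so dcoord[1] (equal to dcoord[2]) is the merge height. The check below computes the layout with no_plot=True, so nothing is drawn, and confirms that these heights are exactly the distances in column 2 of the linkage matrix. The random data is an assumption for illustration.

```python
import numpy as np
from scipy.cluster.hierarchy import dendrogram, linkage

rng = np.random.default_rng(0)
points = rng.normal(size=(20, 2))  # hypothetical sample data
Z = linkage(points, method="single")

layout = dendrogram(Z, no_plot=True)  # compute coordinates without plotting

# The top bar of every U-shaped link is horizontal: dcoord[1] == dcoord[2]
tops = [(dc[1], dc[2]) for dc in layout["dcoord"]]
link_heights = np.sort([dc[1] for dc in layout["dcoord"]])
# x midpoint of each top bar, where custom_dendrogram places its red dot
x_mids = [0.5 * (ic[1] + ic[2]) for ic in layout["icoord"]]

print(np.allclose(link_heights, np.sort(Z[:, 2])))
```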

3. Generating Random Data:

np.random.seed(12312)
num_points = 100
data = np.random.multivariate_normal([0, 0], np.array([[4.0, 2.5], [2.5, 2.0]]), size=num_points)
  • Generates 100 random 2D points from a multivariate normal distribution with mean (0, 0) and the given covariance matrix; setting the seed makes the data reproducible.
  • This data serves as the input for hierarchical clustering.
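
numpy.random.multivariate_normal expects a symmetric positive semi-definite covariance matrix and, by default, only warns when it is not, in which case the samples cannot have the requested covariance. A quick eigenvalue check, using a hypothetical helper is_valid_covariance, shows for example that [[4.0, 2.5], [2.5, 1.4]] is not a valid covariance matrix (its determinant, 4.0 * 1.4 - 2.5**2 = -0.65, is negative), while raising the second variance to 2.0 makes it valid.

```python
import numpy as np


def is_valid_covariance(cov, tol=1e-12):
    """Hypothetical helper: True if cov is symmetric positive semi-definite."""
    cov = np.asarray(cov, dtype=float)
    if not np.allclose(cov, cov.T):
        return False
    # eigvalsh returns the (real) eigenvalues of a symmetric matrix
    return bool(np.all(np.linalg.eigvalsh(cov) >= -tol))


print(is_valid_covariance([[4.0, 2.5], [2.5, 1.4]]))  # False: one eigenvalue is negative
print(is_valid_covariance([[4.0, 2.5], [2.5, 2.0]]))  # True
```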

4. Scatter Plot of Data:

plt.figure(figsize=(6, 5))
plt.scatter(data[:, 0], data[:, 1])
plt.title("Scatter Plot of Data Points")
plt.axis('equal')
plt.grid(True)
plt.savefig('scatter_plot.png')
plt.show()

  • A scatter plot is created to visualize the generated data points before performing the clustering.
  • This helps you see the distribution and structure of the data before clustering.

5. Hierarchical Clustering:

linkage_matrix = linkage(data, method="single")
  • The hierarchical clustering is performed using the "single" linkage method, which merges clusters based on the minimum distance between points in different clusters.
  • The result, linkage_matrix, has one row per merge (n - 1 rows for n points) and is the input for both dendrograms below.
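
Each row of the linkage matrix records one merge: the indices of the two clusters joined, the distance between them, and the number of original points in the new cluster. Original points keep indices 0 to n - 1, and the cluster created in row i gets index n + i. A small worked example with three hypothetical 1D points at 0, 1 and 5:

```python
import numpy as np
from scipy.cluster.hierarchy import linkage

points = np.array([[0.0], [1.0], [5.0]])  # three 1D observations
Z = linkage(points, method="single")
print(Z)
# Row 0: points 0 and 1 merge at distance 1 into cluster 3 (2 points).
# Row 1: point 2 joins cluster 3 at distance 4 (single linkage: |5 - 1|), 3 points in total.
```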

6. First Dendrogram (Without Leaf Counts):

plt.figure(figsize=(10, 4))
dendro_data = custom_dendrogram(linkage_matrix, color_threshold=1, p=6, truncate_mode='lastp', show_leaf_counts=False)
plt.title("Dendrogram (Without Leaf Counts)")
plt.xlabel("Cluster Index")
plt.ylabel("Distance")
plt.savefig('dendrogram_without_leaf_counts.png')
plt.show()

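  • The first dendrogram uses truncate_mode='lastp' with p=6, so the tree is cut back to its top 6 leaves and everything merged below them is collapsed into those leaves. With show_leaf_counts=False, the collapsed leaves are not labeled with their sizes.
  • color_threshold=1 gives each group of links below distance 1 its own color; links at or above that distance are all drawn in one default color.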
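
The dendrogram's return value includes color_list, one color per link, so the threshold rule can be checked without drawing anything (no_plot=True). Passing an explicit above_threshold_color (here "grey", an arbitrary choice that is not in SciPy's cluster palette), the number of links not drawn in that color should equal the number of merges below color_threshold. The random data and the threshold of 0.5 are hypothetical.

```python
import numpy as np
from scipy.cluster.hierarchy import dendrogram, linkage

rng = np.random.default_rng(1)
points = rng.normal(size=(30, 2))  # hypothetical sample data
Z = linkage(points, method="single")

threshold = 0.5
layout = dendrogram(Z, color_threshold=threshold,
                    above_threshold_color="grey", no_plot=True)

# Links below the threshold get cluster colors; the rest get above_threshold_color
colored_links = sum(color != "grey" for color in layout["color_list"])
merges_below = int(np.sum(Z[:, 2] < threshold))
print(colored_links, merges_below)
```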

7. Second Dendrogram (With Leaf Counts):

plt.figure(figsize=(10, 4))
dendro_data = custom_dendrogram(linkage_matrix, color_threshold=1, p=6, truncate_mode='lastp', show_leaf_counts=True)
plt.title("Dendrogram (With Leaf Counts)")
plt.xlabel("Cluster Index")
plt.ylabel("Distance")
plt.savefig('dendrogram_with_leaf_counts.png')
plt.show()

  • The second dendrogram is identical except that each collapsed leaf is labeled with the number of original data points it contains, shown in parentheses.
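
The leaf labels are returned in the ivl entry of the dendrogram's result: with show_leaf_counts=True, a collapsed leaf is labeled "(k)" for k original points, while a leaf that is still a single point keeps its index as its label. A quick check on hypothetical random data is that these sizes add up to the total number of points.

```python
import numpy as np
from scipy.cluster.hierarchy import dendrogram, linkage

rng = np.random.default_rng(2)
points = rng.normal(size=(100, 2))  # hypothetical sample data
Z = linkage(points, method="single")

layout = dendrogram(Z, truncate_mode="lastp", p=6,
                    show_leaf_counts=True, no_plot=True)
labels = layout["ivl"]  # one label per displayed leaf, left to right

# "(k)" marks a collapsed leaf holding k points; a plain index is a single point
sizes = [int(label[1:-1]) if label.startswith("(") else 1 for label in labels]
print(labels, sum(sizes))
```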

Summary:

  • Custom Dendrograms: The custom_dendrogram function enhances the standard dendrogram by adding visual markers and annotations.
  • Scatter Plot: Displays the randomly generated 2D data points.
  • Hierarchical Clustering: Uses single linkage to create a clustering hierarchy, which is visualized in two dendrograms—one without and one with leaf counts.

Full Code

from scipy.cluster.hierarchy import dendrogram, linkage
import matplotlib.pyplot as plt
import numpy as np

# Custom function for generating a dendrogram with distance annotations
def custom_dendrogram(*args, **kwargs):
    # Create the standard dendrogram
    dendro_data = dendrogram(*args, **kwargs)

    # Add annotations for cluster heights if no_plot is False
    if not kwargs.get('no_plot', False):
        # Loop through the clusters to add custom red dots and distance annotations
        for icoord, dcoord in zip(dendro_data['icoord'], dendro_data['dcoord']):
            x_coord = 0.5 * sum(icoord[1:3])  # Midpoint of the horizontal link joining the two clusters
            height = dcoord[1]  # Distance (height) at which the clusters are merged
            plt.plot(x_coord, height, 'ro')  # Plot a red dot at the merge point
            plt.annotate(f"{height:.3g}", (x_coord, height), xytext=(0, -8),
                         textcoords='offset points', va='top', ha='center')  # Annotate the height

    return dendro_data

# Generate random 2D data points for hierarchical clustering
np.random.seed(12312)  # Set seed for reproducibility
num_points = 100  # Number of points
data = np.random.multivariate_normal([0, 0], np.array([[4.0, 2.5], [2.5, 2.0]]), size=num_points)

# Scatter plot of the generated data points
plt.figure(figsize=(6, 5))
plt.scatter(data[:, 0], data[:, 1])
plt.title("Scatter Plot of Data Points")
plt.axis('equal')  # Ensure equal scaling on both axes
plt.grid(True)
plt.savefig('scatter_plot.png')
plt.show()

# Perform hierarchical clustering using the 'single' linkage method
linkage_matrix = linkage(data, method="single")

# Plot the first dendrogram (without leaf counts)
plt.figure(figsize=(10, 4))
dendro_data = custom_dendrogram(linkage_matrix, 
                                color_threshold=1, 
                                p=6, 
                                truncate_mode='lastp', 
                                show_leaf_counts=False)
plt.title("Dendrogram (Without Leaf Counts)")
plt.xlabel("Cluster Index")
plt.ylabel("Distance")
plt.savefig('dendrogram_without_leaf_counts.png')
plt.show()

# Plot the second dendrogram (with leaf counts)
plt.figure(figsize=(10, 4))
dendro_data = custom_dendrogram(linkage_matrix, 
                                color_threshold=1, 
                                p=6, 
                                truncate_mode='lastp', 
                                show_leaf_counts=True)
plt.title("Dendrogram (With Leaf Counts)")
plt.xlabel("Cluster Index")
plt.ylabel("Distance")
plt.savefig('dendrogram_with_leaf_counts.png')
plt.show()

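Visualizing Hierarchical Clustering with Dendrograms Alongside a Distance Matrix

This example clusters 40 random 1D values twice, once with centroid linkage and once with single linkage. It draws one dendrogram to the left of the plot and one above it, and displays the pairwise distance matrix with its rows and columns reordered to follow the two dendrograms' leaf orders.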

import numpy as np  # NumPy for random numbers and array math
import matplotlib.pyplot as plt  # pyplot (preferred over the legacy pylab interface)
import scipy.cluster.hierarchy as sch  # Hierarchical clustering methods from SciPy
from scipy.spatial.distance import squareform  # Converts a square distance matrix to condensed form

# Generate 40 random 1D features and build their pairwise distance matrix.
x = np.random.rand(40)  # 40 random values in [0, 1)
D = np.abs(x[:, None] - x[None, :])  # 40x40 matrix of absolute differences |x[i] - x[j]|

# linkage() treats a square 2D array as observations, so pass the distances in condensed form.
condensed = squareform(D)

# Create the figure and plot the first dendrogram on the left.
fig = plt.figure(figsize=(8, 8))  # Create an 8x8 inch figure
ax1 = fig.add_axes([0.09, 0.1, 0.2, 0.6])  # Axes for the left dendrogram
Y = sch.linkage(condensed, method='centroid')  # Centroid linkage (valid here: |x[i] - x[j]| is Euclidean in 1D)
Z1 = sch.dendrogram(Y, orientation='right', ax=ax1)  # Draw with the root on the left
ax1.set_xticks([])  # Remove x-axis ticks
ax1.set_yticks([])  # Remove y-axis ticks

# Compute and plot the second dendrogram on top.
ax2 = fig.add_axes([0.3, 0.71, 0.6, 0.2])  # Axes for the top dendrogram
Y = sch.linkage(condensed, method='single')  # Single linkage
Z2 = sch.dendrogram(Y, ax=ax2)  # Default orientation ('top')
ax2.set_xticks([])  # Remove x-axis ticks
ax2.set_yticks([])  # Remove y-axis ticks

# Reorder and plot the distance matrix according to the dendrograms' leaf orders.
axmatrix = fig.add_axes([0.3, 0.1, 0.6, 0.6])  # Axes for the reordered distance matrix
idx1 = Z1['leaves']  # Leaf order of the left dendrogram (rows)
idx2 = Z2['leaves']  # Leaf order of the top dendrogram (columns)
D = D[idx1, :]  # Reorder rows based on the first dendrogram
D = D[:, idx2]  # Reorder columns based on the second dendrogram
im = axmatrix.matshow(D, aspect='auto', origin='lower', cmap='YlGnBu')  # Plot the reordered matrix

# Remove ticks for the matrix plot.
axmatrix.set_xticks([])  # Remove x-axis ticks
axmatrix.set_yticks([])  # Remove y-axis ticks

# Add a colorbar to show the scale of the distances.
axcolor = fig.add_axes([0.91, 0.1, 0.02, 0.6])  # Axes for the colorbar
fig.colorbar(im, cax=axcolor)  # Create and add the colorbar

# Save the figure as an image and display it.
fig.savefig('dendrogram_example_02.png')  # Save the figure as a PNG file
plt.show()  # Display the figure (blocks until the window is closed)
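
scipy.cluster.hierarchy.linkage accepts either a 1D condensed distance matrix or a 2D array of observations. A square 2D distance matrix is treated as observations, so each row of distances becomes a feature vector, and SciPy emits a ClusterWarning when the input looks like an uncondensed distance matrix. scipy.spatial.distance.squareform converts the square matrix to condensed form. The check below, on hypothetical random values, confirms that for 1D features, linkage on squareform(D) gives the same merge heights as clustering the raw values directly.

```python
import numpy as np
from scipy.cluster.hierarchy import linkage
from scipy.spatial.distance import squareform

rng = np.random.default_rng(3)
x = rng.random(40)  # hypothetical 1D features
D = np.abs(x[:, None] - x[None, :])  # square matrix of |x[i] - x[j]|

Z_condensed = linkage(squareform(D), method="single")  # distances given explicitly
Z_raw = linkage(x.reshape(-1, 1), method="single")  # observations; SciPy computes Euclidean distances

print(np.allclose(Z_condensed[:, 2], Z_raw[:, 2]))
```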
