Unsupervised Machine Learning With Python: Clustering. Mean Shift Algorithm

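Mean shift is another popular and effective clustering algorithm used in unsupervised learning. It is non-parametric: it makes no assumptions about the shape of the clusters and does not require the number of clusters to be specified in advance. Because it works by moving candidate centroids toward the peaks (modes) of the data density, it is also described as a mode-seeking algorithm. The fundamental steps of the algorithm are as follows: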

  • First, every data point is treated as its own cluster, i.e., as a candidate centroid.
  • Next, each centroid is moved to the mean of the data points that lie within a window (the bandwidth) around it.
  • Repeating this shift moves each centroid toward the peak of its cluster, i.e., toward a region of higher density.
  • The algorithm terminates when the centroids stop moving; centroids that end up at the same peak are merged into a single cluster.
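To make the four steps concrete, here is a minimal from-scratch sketch in NumPy. It assumes a flat (uniform) kernel, so each centroid moves to the plain average of the points within `bandwidth` of it. The function name `mean_shift`, its `tol` stopping threshold, and the `bandwidth / 2` rule for merging centroids are illustrative choices, not scikit-learn's implementation.

```python
import numpy as np


def mean_shift(X, bandwidth, max_iter=300, tol=1e-3):
    """Flat-kernel mean shift sketch: returns (modes, labels)."""
    X = np.asarray(X, dtype=float)
    # Step 1: every data point starts as its own candidate centroid
    centroids = X.copy()
    for _ in range(max_iter):
        new_centroids = np.empty_like(centroids)
        for i, c in enumerate(centroids):
            # Step 2: move the centroid to the mean of the points inside its window
            in_window = np.linalg.norm(X - c, axis=1) <= bandwidth
            new_centroids[i] = X[in_window].mean(axis=0)
        # Steps 3-4: repeat until no centroid moves more than `tol`
        largest_shift = np.max(np.linalg.norm(new_centroids - centroids, axis=1))
        centroids = new_centroids
        if largest_shift < tol:
            break
    # Merge centroids that converged to (nearly) the same peak.
    # The bandwidth / 2 merge radius is an illustrative assumption.
    modes = []
    for c in centroids:
        if all(np.linalg.norm(c - m) > bandwidth / 2 for m in modes):
            modes.append(c)
    modes = np.array(modes)
    # Assign every point to its nearest mode
    dists = np.linalg.norm(X[:, None, :] - modes[None, :, :], axis=2)
    labels = np.argmin(dists, axis=1)
    return modes, labels


if __name__ == "__main__":
    X = np.array([[0, 0], [0.1, 0], [0, 0.1], [5, 5], [5.1, 5], [5, 5.1]])
    modes, labels = mean_shift(X, bandwidth=1.0)
    print(modes)
    print(labels)  # [0 0 0 1 1 1]
```

Each centroid is shifted using the original data points rather than the other moving centroids, which is the standard (non-blurring) form of mean shift.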

Let us apply this algorithm and train a model in Python.

MEAN SHIFT ALGORITHM EXAMPLE

We begin by importing all necessary packages into our Python script:

import numpy as np
from sklearn.cluster import MeanShift
import matplotlib.pyplot as plt
from matplotlib import style
from sklearn.datasets import make_blobs

We set an optional style for our matplotlib canvas:

style.use("ggplot")        

We generate dummy data:

centers = [[2, 2],
           [4, 5],
           [3, 10]]

X, _ = make_blobs(n_samples = 500,
                  centers = centers,
                  cluster_std = 1)        

We may visualize the dummy data as follows:

plt.title("Dummy Data")
plt.xlabel("X-AXIS")
plt.ylabel("Y-AXIS")
plt.scatter(X[:, 0], X[:, 1])
plt.show()        

The output of the above block of code is a scatter plot of the generated data points:

Next, we instantiate the MeanShift class and fit it to the data. Because no bandwidth is passed, scikit-learn estimates one from the data:

algorithm = MeanShift()

model = algorithm.fit(X)        
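The only parameter left at its default above is the bandwidth, i.e., the radius of the window from step 2. The sketch below, which adds a fixed `random_state=0` that the article's data does not use, calls scikit-learn's `estimate_bandwidth` with its default `quantile=0.3` (the estimate `MeanShift` falls back to when `bandwidth=None`) and then refits with half and double that value to show how the window size affects the number of clusters found.

```python
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs

# Same centers as the article; random_state=0 is added here for reproducibility
X, _ = make_blobs(n_samples=500, centers=[[2, 2], [4, 5], [3, 10]],
                  cluster_std=1, random_state=0)

# The estimate MeanShift() uses internally when bandwidth=None
bandwidth = estimate_bandwidth(X, quantile=0.3)
print(f"Estimated bandwidth: {bandwidth:.2f}")

# Smaller windows tend to find more (smaller) clusters, larger windows fewer
counts = {}
for factor in (0.5, 1.0, 2.0):
    model = MeanShift(bandwidth=bandwidth * factor).fit(X)
    counts[factor] = len(model.cluster_centers_)
    print(f"bandwidth x {factor}: {counts[factor]} clusters")
```

If the default estimate merges or splits clusters you expect to be separate, passing an explicit `bandwidth` to `MeanShift` is the usual knob to turn.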

We can then retrieve the cluster label assigned to each data point, as well as the coordinates of the cluster centers:

predictions = model.labels_

cluster_centers = model.cluster_centers_        
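Once fitted, the model can also label points it has not seen. scikit-learn documents `MeanShift.predict` as returning the closest cluster for each sample; the sketch below, with hypothetical new points and an added `random_state=0`, checks this by comparing `predict` against a manual nearest-center computation.

```python
import numpy as np
from sklearn.cluster import MeanShift
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=500, centers=[[2, 2], [4, 5], [3, 10]],
                  cluster_std=1, random_state=0)
model = MeanShift().fit(X)

# Hypothetical new observations near two of the generating centers
new_points = np.array([[2.0, 2.0], [3.0, 10.0]])
new_labels = model.predict(new_points)

# Manual check: index of the nearest learned cluster center for each point
dists = np.linalg.norm(new_points[:, None, :] - model.cluster_centers_[None, :, :], axis=2)
nearest = dists.argmin(axis=1)
print(new_labels, nearest)
```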

We can now print the coordinates of the cluster centers and the number of clusters the model detected:

n_clusters_ = len(np.unique(predictions))  # number of distinct cluster labels

print(cluster_centers)
print("Estimated Number Of Clusters:", n_clusters_)

Note that the printed output above and the graphs will differ each time you run the script, because make_blobs draws new random samples on every run. Passing a fixed random_state to make_blobs makes the results reproducible.
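To confirm that the run-to-run variation comes from the random data rather than from the clustering step, the sketch below uses a hypothetical helper, `fit_centers`, that regenerates the article's dummy data with an explicit seed and fits MeanShift with default settings. Fitting twice on data generated with the same `random_state` is expected to give identical centers, since the fitting itself should be deterministic.

```python
import numpy as np
from sklearn.cluster import MeanShift
from sklearn.datasets import make_blobs


def fit_centers(seed):
    """Hypothetical helper: regenerate the dummy data with a fixed seed and fit MeanShift."""
    X, _ = make_blobs(n_samples=500, centers=[[2, 2], [4, 5], [3, 10]],
                      cluster_std=1, random_state=seed)
    return MeanShift().fit(X).cluster_centers_


run_a = fit_centers(42)
run_b = fit_centers(42)
print("Same seed, same centers:", np.array_equal(run_a, run_b))
```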

Finally, we visualize the clusters found by our Mean Shift model using Matplotlib:

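colors = 10 * ["r.", "g.", "b.", "c.", "k.", "y.", "m."]  # one marker style per cluster label

for i in range(len(X)):
    plt.plot(X[i][0], X[i][1], colors[predictions[i]], markersize=10)

# Draw the cluster centers once, after all points, as black crosses
plt.scatter(cluster_centers[:, 0], cluster_centers[:, 1],
            marker="x", color="k", s=150, linewidths=5, zorder=10)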

plt.title("Mean Shift Algorithm Clusters")
plt.xlabel("X-AXIS")
plt.ylabel("Y-AXIS")
plt.show()        

The resulting visualization is shown below:

