What is scikit-learn clustering?
scikit-learn is one of the most popular open-source machine learning libraries in the Python ecosystem.
It contains supervised and unsupervised machine learning algorithms for use in regression, classification, and clustering.
What is clustering?
Clustering, also known as cluster analysis, is an unsupervised machine learning approach used to identify data points with similar characteristics to create distinct groups or clusters from the data.
Clustering algorithms fall into the unsupervised machine learning category because they use data that is not pre-labeled.
They are widely used with applications in customer segmentation, anomaly detection, market segmentation, text analysis, and many more fields, as they help reveal relationships within the data.
Clustering algorithms
Clustering algorithms can be grouped into four broad categories, namely:
- Hierarchical clustering algorithms: These organize data points into a tree of nested clusters, built either bottom-up (agglomerative, which repeatedly merges the closest clusters) or top-down (divisive, which repeatedly splits clusters). They are best suited to data with an inherent hierarchy. For example, the agglomerative hierarchical clustering algorithm.
- Centroid-based clustering algorithms: These are widely used because they are easy to implement. They group data points around cluster centers known as centroids, using a distance metric such as Euclidean distance to assign each point to its nearest centroid. Their downsides are that outliers are absorbed into clusters rather than flagged, and that the number of centroids (clusters) is arbitrary and needs tuning. For example, K-means, mean-shift clustering, and mini-batch K-means clustering.
- Density-based clustering algorithms: These use the density of the data, as opposed to distance, to create clusters, so clusters can take any shape. Their advantage is that they do not assign outliers to any group, which makes them especially useful for detecting anomalies in the data. For example, DBSCAN and OPTICS.
- Distribution-based clustering algorithms: These group data points based on the probability that they belong to the same statistical distribution, such as a normal or binomial distribution. For example, the Gaussian mixture model.
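As a rough sketch of how these categories look in practice, scikit-learn exposes one representative estimator from each through the same `fit_predict`-style API. The dataset and parameter values below are arbitrary illustrations, not tuned choices:

```python
from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans, AgglomerativeClustering, DBSCAN
from sklearn.mixture import GaussianMixture

# Toy dataset: 300 points drawn around 3 centers (arbitrary choice)
X, _ = make_blobs(n_samples=300, centers=3, random_state=0)

# Centroid-based: K-means with a fixed number of clusters
kmeans_labels = KMeans(n_clusters=3, n_init=10, random_state=0).fit_predict(X)

# Hierarchical: agglomerative clustering merges points bottom-up
agglo_labels = AgglomerativeClustering(n_clusters=3).fit_predict(X)

# Density-based: DBSCAN infers the clusters; label -1 marks noise points
dbscan_labels = DBSCAN(eps=0.8, min_samples=10).fit_predict(X)

# Distribution-based: Gaussian mixture with 3 components
gmm_labels = GaussianMixture(n_components=3, random_state=0).fit_predict(X)
```

Note that only DBSCAN needs no cluster count up front; the other three take the number of clusters (or components) as a parameter.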
Code example
Let’s see an example in which we use DBSCAN to form clusters from a randomly generated dataset and then plot these clusters using Matplotlib.
```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.cluster import DBSCAN

X, _ = make_classification(
    n_samples=1000,
    n_features=2,
    n_informative=2,
    n_redundant=0,
    n_clusters_per_class=1,
    random_state=4
)

df = pd.DataFrame(X)
print(df.shape)

# define the model
dbscan_model = DBSCAN(eps=0.35, min_samples=20)

# train the model
dbscan_model.fit(df)

# visualize the clusters
plt.figure(figsize=(10, 10))
plt.scatter(df[0], df[1], c=dbscan_model.labels_, s=15)
plt.title('DBSCAN Clustering', fontsize=20)
plt.xlabel('Feature 1', fontsize=14)
plt.ylabel('Feature 2', fontsize=14)
plt.show()
```
Code explanation
Let’s go through the code presented above:
- Lines 1–5: We import the necessary libraries.
- Lines 7–14: We create a random dataset with 1000 samples and 2 features.
- Lines 16–17: We convert the dataset output `X` into a data frame and print the shape of the data frame.
- Line 20: We initialize the DBSCAN model with `eps=0.35` and `min_samples=20`, both of which need to be tuned to obtain the optimal number of clusters and to detect noise better.
- Line 23: We fit the model to the dataset to generate the clusters.
- Lines 26–31: We visualize the clusters using a scatter plot.
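As a follow-up sketch reusing the same dataset and parameters as above, DBSCAN reports its assignments through the `labels_` attribute, where the label -1 marks points treated as noise. This lets us count the clusters and outliers directly:

```python
import numpy as np
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.cluster import DBSCAN

# Same randomly generated dataset as in the example above
X, _ = make_classification(n_samples=1000, n_features=2, n_informative=2,
                           n_redundant=0, n_clusters_per_class=1, random_state=4)
df = pd.DataFrame(X)

model = DBSCAN(eps=0.35, min_samples=20).fit(df)
labels = model.labels_

# DBSCAN labels noise points as -1; every other value is a cluster index
n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
n_noise = int(np.sum(labels == -1))
print(f"clusters: {n_clusters}, noise points: {n_noise}")
```

Rerunning this while varying `eps` and `min_samples` is a simple way to see how those two parameters trade off cluster count against the number of points flagged as noise.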