K-Means Walk-Through Example
Understand how K-means clustering partitions data by iteratively assigning points to clusters and updating centroids. Learn step by step with a Python example using synthetic data and sklearn. Gain hands-on experience visualizing clusters and observing how the algorithm converges to minimize within-cluster variance.
In the previous lesson, we discussed that K-means clustering is an algorithm designed to partition a dataset into distinct clusters by minimizing the total variance within each cluster. In this lesson, we will move from theory to practice by performing a dry run of the K-means algorithm on a small synthetic dataset. We will follow Lloyd’s algorithm step-by-step, using Python and the sklearn library to visualize how data points are assigned to clusters and how centroids are iteratively updated.
The K-means algorithm
For a given dataset and value of K, K-means clustering has the following steps:
- Choose K: Select the number of clusters, K, such that K ≤ n, where n is the number of data points.
- Initialize centroids: Choose K centroids, typically at random, as initial cluster centers.
- Calculate dissimilarity: Compute the dissimilarity score (a distance measure, e.g., Euclidean) of each data point with respect to each centroid.
- Assign clusters: Based on the dissimilarity score, assign each data point to the cluster whose centroid is the closest (least distant).
- Update centroids: From these new groupings, compute new centroids by taking the mean of all data points assigned to each cluster.
- Repeat: Repeat steps 3 to 5 until the difference between the old and new centroids is negligible (convergence).
If the steps above seem unclear, don’t worry. We will illustrate each step with an example.
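The steps above can be sketched in a few lines of NumPy. This is a minimal illustration of Lloyd's algorithm, not the lesson's widget code; the function name, tolerance, and seeded initialization are our own choices:

```python
import numpy as np

def lloyd_kmeans(X, k, n_iter=100, tol=1e-6, seed=0):
    """Minimal Lloyd's algorithm: assign points to nearest centroid, update, repeat."""
    rng = np.random.default_rng(seed)
    # Step 2: initialize centroids by picking k distinct data points at random.
    centroids = X[rng.choice(len(X), size=k, replace=False)]
    for _ in range(n_iter):
        # Steps 3-4: Euclidean distance of every point to every centroid,
        # then assign each point to the closest one.
        dists = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
        labels = dists.argmin(axis=1)
        # Step 5: recompute each centroid as the mean of its assigned points
        # (an empty cluster keeps its old centroid for this iteration).
        new_centroids = np.array([
            X[labels == j].mean(axis=0) if np.any(labels == j) else centroids[j]
            for j in range(k)
        ])
        # Step 6: stop once the centroids barely move.
        shift = np.linalg.norm(new_centroids - centroids)
        centroids = new_centroids
        if shift < tol:
            break
    return centroids, labels
```

Note the empty-cluster guard: if no points land in a cluster, its centroid simply stays put for that iteration.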
Dry running the example
Let’s say we have the following dataset of 15 two-dimensional points:
Step 1: Plotting the data
Run the following code widget to plot the data. Here, `x` and `y` contain the x and y-coordinates of our synthetic data.
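The code widget isn't reproduced here, but a minimal plotting sketch using the 15 points from the dissimilarity table below would look something like this (the Agg backend and file output are our additions so the script runs headlessly):

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; drop this line to view interactively
import matplotlib.pyplot as plt

# x and y coordinates of the 15 synthetic data points.
x = [1, 2, 2, 2.5, 3, 4, 4, 5, 5, 5.5, 6, 6, 6, 6.5, 7]
y = [2, 1, 1.5, 3.5, 4, 3.5, 7.5, 6, 7, 2, 1.5, 3, 5.5, 5, 2.5]

plt.scatter(x, y)
plt.xlabel("x")
plt.ylabel("y")
plt.title("Synthetic dataset")
plt.savefig("data.png")
```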
Let’s start with the first step of K-means clustering and decide how many clusters we want, if the number isn’t given already. Let the number of clusters be three, which means K = 3.
Step 2: Assigning values to centroids
The second step is to initialize K centroids with random values. Since K is 3, we’ll get three centroids: C1, C2, and C3. Assigning them random values yields C1 = (1, 1), C2 = (7, 2), and C3 = (5, 6.5).
In the following code, `Cx` and `Cy` represent the x and y coordinates of the centroids:
Step 3: Calculating the dissimilarity score
The third step is to find the dissimilarity score of each data point (15 in total) with respect to each centroid. We’ll use the Euclidean distance as the dissimilarity score. The `euclidean_distances` function takes two arrays, where each array is an array of points. Let’s see how to calculate the dissimilarity score using sklearn:
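Since the widget isn't rendered here, the following is a sketch of this step, with variable names matching the explanation that follows (line numbers may differ from the original widget):

```python
from sklearn.metrics.pairwise import euclidean_distances

# x and y coordinates of the 15 data points.
x = [1, 2, 2, 2.5, 3, 4, 4, 5, 5, 5.5, 6, 6, 6, 6.5, 7]
y = [2, 1, 1.5, 3.5, 4, 3.5, 7.5, 6, 7, 2, 1.5, 3, 5.5, 5, 2.5]
# x and y coordinates of the three initial centroids.
Cx = [1, 7, 5]
Cy = [1, 2, 6.5]

# Restructure into matrices of [x, y] points, as euclidean_distances expects.
points = [[px, py] for px, py in zip(x, y)]
centroids = [[cx, cy] for cx, cy in zip(Cx, Cy)]

# Each row is a data point, each column its distance to one centroid.
distances = euclidean_distances(points, centroids)
print(distances)
```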
Here is the explanation for the code above:
- Lines 3–7: We define two lists `x` and `y` representing the x and y-coordinates of the data points. Similarly, `Cx` and `Cy` represent the x and y-coordinates of the initial centroids.
- Lines 10–11: We convert the data points and the centroid coordinates into arrays of 2-D points because `euclidean_distances` expects its inputs to be matrices, where each row represents one point. The function computes pairwise distances between every point in the first matrix and every point in the second. By restructuring the data into lists of `[x, y]` pairs, we ensure the function receives properly formatted inputs and can correctly compute the distances.
- Line 13: We use the `euclidean_distances` function to calculate the Euclidean distances between the data points and the centroids and print the resulting array.
The code output will be a 2D array where each row represents a data point and each column represents a centroid. The value at position (i, j) in the array is the Euclidean distance between the i-th data point and the j-th centroid.
The dissimilarity scores were calculated using sklearn and are also given below:
Dissimilarity Scores
| Data Points | Centroid_1 (1, 1) | Centroid_2 (7, 2) | Centroid_3 (5, 6.5) |
| --- | --- | --- | --- |
| 1, 2 | 1 | 6 | 6.020797289 |
| 2, 1 | 1 | 5.099019514 | 6.264982043 |
| 2, 1.5 | 1.118033989 | 5.024937811 | 5.830951895 |
| 2.5, 3.5 | 2.915475947 | 4.74341649 | 3.905124838 |
| 3, 4 | 3.605551275 | 4.472135955 | 3.201562119 |
| 4, 3.5 | 3.905124838 | 3.354101966 | 3.16227766 |
| 4, 7.5 | 7.158910532 | 6.264982043 | 1.414213562 |
| 5, 6 | 6.403124237 | 4.472135955 | 0.5 |
| 5, 7 | 7.211102551 | 5.385164807 | 0.5 |
| 5.5, 2 | 4.609772229 | 1.5 | 4.527692569 |
| 6, 1.5 | 5.024937811 | 1.118033989 | 5.099019514 |
| 6, 3 | 5.385164807 | 1.414213562 | 3.640054945 |
| 6, 5.5 | 6.726812024 | 3.640054945 | 1.414213562 |
| 6.5, 5 | 6.800735254 | 3.041381265 | 2.121320344 |
| 7, 2.5 | 6.184658438 | 0.5 | 4.472135955 |
Now, don’t panic seeing the giant table. All this table tells us is the distance of each data point from each centroid. For example, let’s look at the first data point (1, 2) (let’s call it P1) with respect to the first centroid C1 = (1, 1):

distance(P1, C1) = √((1 − 1)² + (2 − 1)²) = √1 = 1

This matches the first entry in the table above.
Step 4: Assigning the clusters
After calculating the distances of each point from the centroids, the fourth step is to assign each point to the relevant cluster. This is done by selecting the centroid that is least distant from the data point and assigning the point to that centroid’s cluster.
This code creates a pandas `DataFrame` that stores each point’s distance to every centroid along with its assigned cluster:
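As the widget isn't shown here, the following sketch reproduces the idea (the short column names `C1`, `C2`, `C3` are illustrative, and line numbers differ from the original widget):

```python
import pandas as pd
from sklearn.metrics.pairwise import euclidean_distances

x = [1, 2, 2, 2.5, 3, 4, 4, 5, 5, 5.5, 6, 6, 6, 6.5, 7]
y = [2, 1, 1.5, 3.5, 4, 3.5, 7.5, 6, 7, 2, 1.5, 3, 5.5, 5, 2.5]
points = list(zip(x, y))
centroids = [(1, 1), (7, 2), (5, 6.5)]

# Distance of every point to every centroid.
distances = euclidean_distances(points, centroids)

# Rows = data points (labeled "x,y"), columns = centroids.
labels = [f"{px},{py}" for px, py in points]
df = pd.DataFrame(distances, index=labels, columns=["C1", "C2", "C3"])
df.index.name = "Data Points"

# Each point joins the cluster of its nearest centroid.
df["Cluster"] = df[["C1", "C2", "C3"]].idxmin(axis=1)
print(df)
```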
Here is the explanation for the code above:
- Line 15: Computes the Euclidean distance between every data point and every centroid.
- Line 18: Creates string labels such as “1,2” and “2,1” to use as DataFrame row names.
- Lines 19–21: Stores the distance matrix in a clean table format:
- Rows = data points
- Columns = distance to each centroid
- Line 22: Labels the index column for readability.
- Line 25: Assigns each point to the nearest centroid. `idxmin(axis=1)` finds the column with the smallest distance for each row.
- Line 26: The `DataFrame` is printed to the console using the `print` statement.
The `df` DataFrame can also be visualized in tabular form, as seen below:
Cluster Assignment
| Data Points | Centroid_1 (1, 1) | Centroid_2 (7, 2) | Centroid_3 (5, 6.5) | Cluster |
| --- | --- | --- | --- | --- |
| 1, 2 | 1 | 6 | 6.020797289 | C1 |
| 2, 1 | 1 | 5.099019514 | 6.264982043 | C1 |
| 2, 1.5 | 1.118033989 | 5.024937811 | 5.830951895 | C1 |
| 2.5, 3.5 | 2.915475947 | 4.74341649 | 3.905124838 | C1 |
| 3, 4 | 3.605551275 | 4.472135955 | 3.201562119 | C3 |
| 4, 3.5 | 3.905124838 | 3.354101966 | 3.16227766 | C3 |
| 4, 7.5 | 7.158910532 | 6.264982043 | 1.414213562 | C3 |
| 5, 6 | 6.403124237 | 4.472135955 | 0.5 | C3 |
| 5, 7 | 7.211102551 | 5.385164807 | 0.5 | C3 |
| 5.5, 2 | 4.609772229 | 1.5 | 4.527692569 | C2 |
| 6, 1.5 | 5.024937811 | 1.118033989 | 5.099019514 | C2 |
| 6, 3 | 5.385164807 | 1.414213562 | 3.640054945 | C2 |
| 6, 5.5 | 6.726812024 | 3.640054945 | 1.414213562 | C3 |
| 6.5, 5 | 6.800735254 | 3.041381265 | 2.121320344 | C3 |
| 7, 2.5 | 6.184658438 | 0.5 | 4.472135955 | C2 |
The table above shows us which data point got assigned to which cluster. For example, the data point (5, 6) got assigned to the C3 cluster. This means that (5, 6) was closer to the centroid of C3 than to either of the other centroids. Let’s see this visually:
Let’s code this step in Python now:
Here is the explanation for the code above:
- Lines 33–35: Encodes the cluster labels (`'C1(1,1)'`, `'C2(7,2)'`, `'C3(5,6.5)'`) as the integers 0, 1, and 2, and converts the encoded column to integers for easy use in plotting (`plt.scatter(c=...)`). This ensures that each cluster can be mapped consistently to a color.
- Lines 38–44: Defines a function to assign a color to each cluster based on the encoded integer value.
- Line 46: Applies the color mapping to all data points to create a list of colors for plotting.
- Lines 49–51: Plots the data points with colors corresponding to their assigned clusters and overlays the centroids as squares using the same cluster colors. The `alpha=0.4` parameter makes the points semi-transparent, while `s=75` ensures that the centroid markers are clearly visible.
Step 5: Recomputing the centroids
Alright, now we’re getting somewhere. The image above is somewhat clustered. Here comes our fifth step, which will recompute the centroids of each cluster.
The C1 cluster consists of four data points: (1, 2), (2, 1), (2, 1.5), and (2.5, 3.5); that is, |C1| = 4. The new centroid of C1 can be calculated by averaging these points:

C1 = ((1 + 2 + 2 + 2.5) / 4, (2 + 1 + 1.5 + 3.5) / 4) = (1.875, 2)
Similarly, the centroid of the C2 cluster, whose members are (5.5, 2), (6, 1.5), (6, 3), and (7, 2.5), can be calculated as follows:

C2 = ((5.5 + 6 + 6 + 7) / 4, (2 + 1.5 + 3 + 2.5) / 4) = (6.125, 2.25)
Finally, the centroid of the C3 cluster, whose members are (3, 4), (4, 3.5), (4, 7.5), (5, 6), (5, 7), (6, 5.5), and (6.5, 5), can be calculated as follows:

C3 = ((3 + 4 + 4 + 5 + 5 + 6 + 6.5) / 7, (4 + 3.5 + 7.5 + 6 + 7 + 5.5 + 5) / 7) ≈ (4.79, 5.5)
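These three means can be verified with a quick NumPy check, using the cluster memberships from the table above:

```python
import numpy as np

# Cluster memberships after the first assignment step (from the table above).
c1_points = np.array([(1, 2), (2, 1), (2, 1.5), (2.5, 3.5)])
c2_points = np.array([(5.5, 2), (6, 1.5), (6, 3), (7, 2.5)])
c3_points = np.array([(3, 4), (4, 3.5), (4, 7.5), (5, 6), (5, 7), (6, 5.5), (6.5, 5)])

# Each updated centroid is the mean of its cluster's member points.
for name, pts in [("C1", c1_points), ("C2", c2_points), ("C3", c3_points)]:
    print(name, pts.mean(axis=0))
```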
Now, let’s see how our centroids have moved.
The above illustration shows the new position of the updated centroids. This looks promising as these updated centroid locations truly represent the center of their clusters.
Step 6: Repeating the steps
Lastly, moving on to the sixth step. This step says that if the old centroids and the updated centroids differ, we must perform steps 3–5 again. For simplicity’s sake, we’ll fast-forward this process via Python code.
Let’s put this all together. In the following coding widget, `update_clusters()` is responsible for assigning clusters to each data point, and `update_centroids()` will update the centroids for each cluster by taking the mean of the data points within that cluster. Furthermore, we can control the number of iterations performed by K-means by updating `iterations` in line 67.
Following is the explanation for the code above:
- Lines 13–39: The first function, `update_clusters`, takes the dataset and centroid positions as inputs, computes the dissimilarity scores (using the Euclidean distance measure) between each data point and each centroid, assigns each data point to the cluster whose centroid is the closest, and encodes the clusters with colors.
- Lines 42–55: The second function, `update_centroids`, is responsible for updating the positions of the centroids.
  - Line 42: Defines the function `update_centroids`, which takes a DataFrame `df` as input. This DataFrame contains the data points and their assigned cluster labels (encoded as integers in `Clusters_encoded`).
  - Line 43: Initializes a NumPy array `means` with zeros to store the updated centroid positions. Its shape is `[number of clusters, 2]` because each centroid has an x and a y coordinate; `len(df['Clusters_encoded'].unique())` gives the number of unique clusters.
  - Line 44: Loops over each cluster ID to compute the mean position of the points assigned to that cluster.
  - Line 45: Initializes a counter `count` to track the number of data points belonging to the current cluster.
  - Line 46: Iterates over the index of `df` (tuples representing the x, y coordinates of the data points). `df.index.where(df['Clusters_encoded']==cluster_id)` selects the points that belong to the current cluster.
  - Line 47: Checks that the point is not `NaN`. This is necessary because `where()` fills non-matching entries with `NaN`.
  - Lines 48–49: Adds the x-coordinate (`pt[0]`) and y-coordinate (`pt[1]`) of the point to the running sum for the current cluster centroid.
  - Line 50: Increments the counter for the number of points in the current cluster.
  - Lines 52–53: Divides the summed coordinates by `count` to calculate the mean x and y positions of all points in the cluster. This gives the updated centroid location for the current cluster.
  - Line 55: Returns the array `means` containing the updated centroid coordinates for all clusters.
- Lines 58–64: The third function, `map_color`, maps an integer value to a color so the clusters can be visualized consistently.
- Lines 67–77: Finally, the code runs a loop for the specified number of iterations. It updates the clusters and the centroids and plots the result at each iteration using the `matplotlib` library. The output is a set of subplots showing how the clusters evolve over the iterations.
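The loop-and-counter bookkeeping in `update_centroids` can also be expressed with a pandas `groupby`. The following compact alternative is our own sketch, not the widget's code; it assumes the DataFrame stores coordinates in columns `x` and `y` alongside `Clusters_encoded`:

```python
import numpy as np
import pandas as pd

def update_centroids_compact(df):
    """Mean x, y per cluster; df has columns 'x', 'y', 'Clusters_encoded'."""
    # groupby sorts by cluster id, so row 0 is cluster 0, row 1 is cluster 1, ...
    return df.groupby("Clusters_encoded")[["x", "y"]].mean().to_numpy()

# Example with the dry run's assignment: ids 0, 1, 2 stand for C1, C2, C3.
df = pd.DataFrame({
    "x": [1, 2, 2, 2.5, 5.5, 6, 6, 7, 3, 4, 4, 5, 5, 6, 6.5],
    "y": [2, 1, 1.5, 3.5, 2, 1.5, 3, 2.5, 4, 3.5, 7.5, 6, 7, 5.5, 5],
    "Clusters_encoded": [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2],
})
print(update_centroids_compact(df))
```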
Now, we’ll see how sklearn performs the same job.
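The sklearn widget isn't reproduced here; a minimal version of the same workflow might look like this (the `n_init` and `random_state` values are our own choices, and the plotting portion is omitted):

```python
import numpy as np
from sklearn.cluster import KMeans

x = np.array([1, 2, 2, 2.5, 3, 4, 4, 5, 5, 5.5, 6, 6, 6, 6.5, 7])
y = np.array([2, 1, 1.5, 3.5, 4, 3.5, 7.5, 6, 7, 2, 1.5, 3, 5.5, 5, 2.5])
X = np.column_stack((x, y))  # 15 points as rows of [x, y]

# Partition into three clusters; fit() runs Lloyd's algorithm internally.
kmeans = KMeans(n_clusters=3, n_init=10, random_state=0)
kmeans.fit(X)
labels = kmeans.predict(X)

print(labels)
print(kmeans.cluster_centers_)
```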
Following is the explanation for the code above:
- Lines 6–12: The code first defines the dataset by creating arrays `x` and `y` and concatenating them to form the 2D array `X`.
- Line 15: Next, the code uses the `KMeans()` class from the `sklearn.cluster` package to partition the dataset into three clusters.
- Line 16: The `fit()` method of the `KMeans` class is called on the data `X`, which performs the clustering and assigns each data point to one of the three clusters.
- Line 17: The `predict()` method is called on the fitted model to obtain the predicted cluster assignment for each point in the dataset.
- Lines 32–40: To visualize the clustering results, the code defines a color mapping function `map_color()` that maps the cluster assignments to colors. It then plots the actual dataset and the clustered dataset side by side using `plt.subplots()`, drawing both with `scatter()` and the colors produced by `map_color()`. The `cluster_centers_` attribute of the `KMeans` object is used to plot the cluster centers as squares on the clustered plot. Finally, `plt.show()` displays the figure.
Conclusion
This walk-through demonstrated the core mechanism of the K-means algorithm using Lloyd’s iterative approach.
We observed that by iteratively calculating the distance (or dissimilarity), assigning points to the closest centroid, and then recomputing the centroid as the mean of the new cluster members, the algorithm minimizes the total within-cluster variance.
This process enables the initial, randomly selected cluster centers to converge into stable positions that represent the true means of the data groups.