K-Means in Machine Learning

Many clustering algorithms are available in Scikit-Learn and elsewhere, but perhaps the simplest to understand is an algorithm known as k-means clustering, which is implemented in sklearn.cluster.KMeans.


The k-means algorithm searches for a pre-determined number of clusters within an unlabeled multidimensional dataset. It accomplishes this using a simple conception of what the optimal clustering looks like:

  • The “cluster centre” is the arithmetic mean of all the points belonging to the cluster.
  • Each point is closer to its cluster centre than to other cluster centres.

Those two assumptions are the basis of the k-means model. We will soon dive into exactly how the algorithm reaches this solution, but for now, let’s take a look at a simple dataset and see the k-means result.

First, let’s generate a two-dimensional dataset containing four distinct blobs. To emphasize that this is an unsupervised algorithm, we will leave the labels out of the visualization.

import matplotlib.pyplot as plt
import seaborn as sns; sns.set()  # for plot styling
import numpy as np
from sklearn.datasets.samples_generator import make_blobs
X, y_true = make_blobs(n_samples=300, centers=4,
                       cluster_std=0.60, random_state=0)
plt.scatter(X[:, 0], X[:, 1], s=50)Code language: Python (python)
image for post

By eye, it is relatively easy to pick out the four clusters. The k-means algorithm does this automatically, and in Scikit-Learn uses the standard estimator API:

from sklearn.cluster import KMeans
kmeans = KMeans(n_clusters=4)
y_kmeans = kmeans.predict(X)Code language: Python (python)

Let’s visualize the results by plotting the data coloured by these labels. We will also plot the cluster centers as determined by the k-means estimator:

plt.scatter(X[:, 0], X[:, 1], c=y_kmeans, s=50, cmap='viridis')

centers = kmeans.cluster_centers_
plt.scatter(centers[:, 0], centers[:, 1], c='black', s=200, alpha=0.5)Code language: Python (python)

The good news is that the k-means algorithm (at least in this simple case) assigns the points to clusters very similarly to how we might assign them by eye.

But you might wonder how this algorithm finds these clusters so quickly! After all, the number of possible combinations of cluster assignments is exponential in the name of data points—an exhaustive search would be very, very costly. Fortunately for us, such a thorough search is not necessary: instead, the typical approach to k-means involves an intuitive, iterative approach known as expectation–maximization.

k-Means Algorithm: Expectation–Maximization

Expectation–maximization (E–M) is a robust algorithm that comes up in a variety of contexts within data science. k-means is a particularly easy-to-understand and straightforward application of the algorithm, and we will walk through it briefly here. In short, the expectation-maximization approach here consists of the following procedure:

  1. Guess some cluster centers
  2. Repeat until converged
    1. E-Step: assign points to the nearest cluster center
    2. M-Step: set the cluster centers to the mean

Here the “E-step” or “Expectation step” is so-named because it involves updating our expectation of which cluster each point belongs to. The “M-step” or “Maximization step” is so-named because it consists in maximizing some fitness function that defines the location of the cluster centers—in this case, that maximization is accomplished by taking a simple mean of the data in each cluster.

The k-Means algorithm is simple enough that we can write it in a few lines of code. The following is a very basic implementation:

from sklearn.metrics import pairwise_distances_argmin

def find_clusters(X, n_clusters, rseed=2):
    # 1. Randomly choose clusters
    rng = np.random.RandomState(rseed)
    i = rng.permutation(X.shape[0])[:n_clusters]
    centers = X[i]
    while True:
        # 2a. Assign labels based on closest center
        labels = pairwise_distances_argmin(X, centers)
        # 2b. Find new centers from means of points
        new_centers = np.array([X[labels == i].mean(0)
                                for i in range(n_clusters)])
        # 2c. Check for convergence
        if np.all(centers == new_centers):
        centers = new_centers
    return centers, labels

centers, labels = find_clusters(X, 4)
plt.scatter(X[:, 0], X[:, 1], c=labels,
            s=50, cmap='viridis')Code language: Python (python)

Example: K-Means Algorithm on Digits

To start, let’s take a look at applying k-means on the same simple digits data. Here we will attempt to use k-means to try to identify similar digits without using the original label information; this might be similar to a first step in extracting meaning from a new dataset about which you don’t have any a priori label information.

Also, read – 10 Machine Learning Projects to Boost your Portfolio.

We will start by loading the digits and then finding the KMeans clusters. Recall that the numbers consist of 1,797 samples with 64 features, where each of the 64 elements is the brightness of one pixel in an 8×8 image:

from sklearn.datasets import load_digits
digits = load_digits()
digits.data.shapeCode language: Python (python)
(1797, 64)

Now let’s do the clustering:

kmeans = KMeans(n_clusters=10, random_state=0)
clusters = kmeans.fit_predict(digits.data)
kmeans.cluster_centers_.shapeCode language: Python (python)
(10, 64)

The result is 10 clusters in 64 dimensions. Notice that the cluster centers themselves are 64-dimensional points, and can themselves be interpreted as the “typical” digit within the cluster. Let’s see what these cluster centers look like:

fig, ax = plt.subplots(2, 5, figsize=(8, 3))
centers = kmeans.cluster_centers_.reshape(10, 8, 8)
for axi, center in zip(ax.flat, centers):
    axi.set(xticks=[], yticks=[])
    axi.imshow(center, interpolation='nearest', cmap=plt.cm.binary)Code language: Python (python)

We see that even without the labels, KMeans can find clusters whose centers are recognizable digits, with perhaps the exception of 1 and 8.

Because k-means knows nothing about the identity of the cluster, the 0–9 labels may be permuted. We can fix this by matching each learned cluster label with the correct names found in them:

from scipy.stats import mode

labels = np.zeros_like(clusters)
for i in range(10):
    mask = (clusters == i)
    labels[mask] = mode(digits.target[mask])[0]Code language: Python (python)

Now we can check how accurate our unsupervised clustering was in finding similar digits within the data:

from sklearn.metrics import accuracy_score
accuracy_score(digits.target, labels)Code language: Python (python)

With just a simple k-means algorithm, we discovered the correct grouping for 80% of the input digits! Let’s check the confusion matrix for this:

from sklearn.metrics import confusion_matrix
mat = confusion_matrix(digits.target, labels)
sns.heatmap(mat.T, square=True, annot=True, fmt='d', cbar=False,
plt.xlabel('true label')
plt.ylabel('predicted label')Code language: Python (python)
confusion matrix

As we might expect from the cluster centers we visualized before, the main point of confusion is between the eights and ones. But this still shows that using k-means, we can essentially build a digit classifier without reference to any known labels.

Also, read – Employee Turnover Prediction with Machine Learning.

I hope you liked this article on k-means algorithm. Feel free to ask your valuable questions in the comments section. Don’t forget to subscribe for my daily newsletters below to get email notifications if you like my work.

Follow Us:

Aman Kharwal
Aman Kharwal

I'm a writer and data scientist on a mission to educate others about the incredible power of data📈.

Articles: 1536

Leave a Reply