K-means

Resources:

K-means on MNIST

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml

%config InlineBackend.figure_format = 'svg'

X, y = fetch_openml('mnist_784', return_X_y=True, parser='auto')
X = X.astype(float).values / 255
y = y.astype(int).values

Apply mini-batch K-means:

from sklearn.cluster import MiniBatchKMeans
kmeans_mini = MiniBatchKMeans(n_clusters=10, n_init='auto')
%time kmeans_mini.fit(X)
print("Inertia:", kmeans_mini.inertia_)
print("Cluster labels:", kmeans_mini.labels_)
CPU times: user 1.48 s, sys: 320 ms, total: 1.8 s
Wall time: 1.13 s
Inertia: 2789522.3601457523
Cluster labels: [8 5 1 ... 3 9 7]
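
The inertia printed above is a sum of squared Euclidean distances from each sample to the center of its assigned cluster. The following is a minimal NumPy sketch of that quantity on toy data; the helper name `inertia_from_labels` and the toy arrays are illustrative assumptions, not part of scikit-learn.

```python
import numpy as np

def inertia_from_labels(X, centers, labels):
    """Sum of squared Euclidean distances from each sample to its assigned center.

    Hypothetical helper for illustration; not a scikit-learn function.
    """
    diffs = X - centers[labels]          # offset of every sample from its own center
    return float(np.sum(diffs ** 2))

# Toy data: two clusters of two points each
X_toy = np.array([[0.0, 0.0], [0.0, 2.0], [10.0, 0.0], [10.0, 4.0]])
centers_toy = np.array([[0.0, 1.0], [10.0, 2.0]])
labels_toy = np.array([0, 0, 1, 1])

toy_inertia = inertia_from_labels(X_toy, centers_toy, labels_toy)
print(toy_inertia)  # 1 + 1 + 4 + 4 = 10.0
```

On the notebook data the same helper could be called as `inertia_from_labels(X, kmeans_mini.cluster_centers_, kmeans_mini.labels_)`, which should roughly reproduce the value reported by `inertia_`.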

Calculate the silhouette score:

from sklearn.metrics import silhouette_score
%time silhouette_score(X, kmeans_mini.labels_, metric='euclidean')

Now plot the cluster centers:

plt.figure(figsize=(10, 10))
for i in range(10):
    plt.subplot(3, 4, i+1)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(kmeans_mini.cluster_centers_[i].reshape(28, 28), cmap='gray')
    plt.title(f"Label index: {i}", size=15)
../_images/fd480fd893937dbfdfaab228d27d3c023de5d79e4607bb17826cb9f57132f3d2.svg

Can you guess which digit each cluster center represents?
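
One way to check a guess is to look at the true labels: for every cluster, take the digit that occurs most often among its members. Below is a small sketch of this majority vote on toy arrays; the function name `majority_digit_per_cluster` is an assumption made for illustration.

```python
import numpy as np

def majority_digit_per_cluster(cluster_labels, true_labels, n_clusters=10, n_classes=10):
    """Return, for each cluster index, the most frequent true label among its members.

    Hypothetical helper for illustration. An empty cluster is mapped to 0.
    """
    mapping = np.zeros(n_clusters, dtype=int)
    for k in range(n_clusters):
        members = true_labels[cluster_labels == k]
        counts = np.bincount(members, minlength=n_classes)  # votes per digit
        mapping[k] = counts.argmax()
    return mapping

# Toy example with 3 clusters
cluster_toy = np.array([0, 0, 0, 1, 1, 2, 2, 2])
digits_toy = np.array([7, 7, 1, 3, 3, 0, 4, 4])

mapping_toy = majority_digit_per_cluster(cluster_toy, digits_toy, n_clusters=3)
# Fraction of samples whose true label matches the majority digit of their cluster
purity_toy = float(np.mean(mapping_toy[cluster_toy] == digits_toy))
print(mapping_toy, purity_toy)
```

With the arrays from this notebook the call would be `majority_digit_per_cluster(kmeans_mini.labels_, y)`. Several clusters may map to the same digit, since K-means never sees the labels.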

Now fit the full-batch K-means, which uses the whole dataset at every iteration:

from sklearn.cluster import KMeans
kmeans = KMeans(n_clusters=10, n_init='auto')
%time kmeans.fit(X)
print("Inertia:", kmeans.inertia_)
print("Cluster labels:", kmeans.labels_)
CPU times: user 17.4 s, sys: 578 ms, total: 18 s
Wall time: 10.4 s
Inertia: 2744522.4342046715
Cluster labels: [3 6 5 ... 7 3 0]
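
For intuition about what the full-batch fit does, here is a minimal NumPy sketch of the classical Lloyd iteration: assign every sample to its nearest center, then move each center to the mean of its members. It is a simplified illustration, not scikit-learn's implementation; the function name `lloyd_kmeans`, the fixed iteration count, and the explicit initial centers are all assumptions.

```python
import numpy as np

def assign(X, centers):
    """Index of the nearest center (squared Euclidean distance) for every sample."""
    d2 = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
    return d2.argmin(axis=1)

def lloyd_kmeans(X, init_centers, n_iter=10):
    """Plain Lloyd iterations from given initial centers (illustrative sketch)."""
    centers = np.asarray(init_centers, dtype=float).copy()
    for _ in range(n_iter):
        labels = assign(X, centers)                 # assignment step
        for k in range(len(centers)):               # update step
            if np.any(labels == k):                 # leave empty clusters where they are
                centers[k] = X[labels == k].mean(axis=0)
    return centers, assign(X, centers)

X_toy = np.array([[0.0, 0.0], [0.0, 1.0], [5.0, 5.0], [5.0, 6.0]])
centers_toy, labels_toy = lloyd_kmeans(X_toy, init_centers=[[0.0, 0.0], [5.0, 5.0]])
print(centers_toy, labels_toy)
```

The broadcasted distance computation allocates an array of shape `(n_samples, n_clusters, n_features)`, so this sketch is meant for toy data rather than for the full MNIST matrix.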

Silhouette score of K-means:

from sklearn.metrics import silhouette_score
%time silhouette_score(X, kmeans.labels_, metric='euclidean')
CPU times: user 4min 54s, sys: 22.9 s, total: 5min 17s
Wall time: 3min 21s
0.05599485888756957
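
The exact silhouette score needs pairwise distances between all samples, which explains the running time of several minutes. `silhouette_score` accepts the `sample_size` and `random_state` parameters to estimate the score on a random subset instead. The sketch below shows the call on synthetic blobs, used here as a stand-in for MNIST so that it runs quickly; the blob centers and sizes are arbitrary assumptions.

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import silhouette_score

# Synthetic stand-in for the MNIST matrix X (assumption: three well-separated blobs)
X_demo, _ = make_blobs(
    n_samples=3000,
    centers=[[0, 0], [10, 10], [-10, 10]],
    cluster_std=0.5,
    random_state=0,
)
labels_demo = KMeans(n_clusters=3, n_init=10, random_state=0).fit_predict(X_demo)

# sample_size: the score is computed on a random subset of 500 samples
approx_score = silhouette_score(
    X_demo, labels_demo, metric='euclidean', sample_size=500, random_state=0
)
print(approx_score)
```

On the notebook data the analogous call would be `silhouette_score(X, kmeans.labels_, metric='euclidean', sample_size=10000, random_state=0)`, where 10000 is only an illustrative choice. The result is an estimate and varies with the random subset.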

Once again, plot the cluster centers:

plt.figure(figsize=(10, 10))
for i in range(10):
    plt.subplot(3, 4, i+1)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(kmeans.cluster_centers_[i].reshape(28, 28), cmap='gray')
    plt.title(f"Label index: {i}", size=15)
../_images/e4914ef53c18d393a7df86493e10507ee7de3c038a877b83e3066a34d449cca7.svg