-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathKMeans.py
106 lines (83 loc) · 3.88 KB
/
KMeans.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import numpy as np
import matplotlib.pyplot as plt
def euclidean_distance(x1, x2):
    """Return the Euclidean (L2) distance between two points.

    Parameters
    ----------
    x1, x2 : array-like
        Coordinate vectors of equal shape. Plain sequences (lists/tuples)
        are accepted and converted; previously only ndarrays worked.

    Returns
    -------
    numpy floating scalar
        sqrt(sum((x1 - x2)**2)).
    """
    # np.asarray generalizes the helper to list/tuple inputs without
    # changing the result for ndarray inputs.
    return np.linalg.norm(np.asarray(x1) - np.asarray(x2))
class KMeans:
    '''
    KMeans is an _unsupervised learning_ method that clusters data sets into
    k different clusters. Each sample is assigned to the cluster with the nearest
    mean, then the means (centroids) and clusters are updated during an iterative
    optimization process:
    1. Initialize cluster centers
    2. Repeat until converged
       2a. Update cluster labels: assign points to the nearest cluster center (centroid)
       2b. Update the cluster centers (centroids)
    '''

    def __init__(self, K=5, max_iters=100, plot_steps=False) -> None:
        """
        Parameters
        ----------
        K : int
            Number of clusters to form.
        max_iters : int
            Maximum number of assignment/update iterations before stopping
            even if the centroids have not converged.
        plot_steps : bool
            If True, draw a scatter plot of the clusters after each step.
        """
        self.K = K
        self.max_iters = max_iters
        self.plot_steps = plot_steps
        # list of sample indices for each cluster
        self.clusters = [[] for _ in range(self.K)]
        # store the centers (mean vector for each cluster)
        self.centroids = []

    def predict(self, X: np.ndarray) -> np.ndarray:  # no y value since this is unsupervised
        """Cluster the rows of X into K groups.

        Parameters
        ----------
        X : np.ndarray of shape (n_samples, n_features)
            Data matrix; one sample per row.

        Returns
        -------
        np.ndarray of shape (n_samples,), dtype int
            Cluster index (0..K-1) assigned to each sample.

        Raises
        ------
        ValueError
            If X contains fewer samples than K, so K distinct initial
            centroids cannot be drawn.
        """
        self.X = X
        self.n_samples, self.n_features = X.shape
        if self.n_samples < self.K:
            # np.random.choice with replace=False would raise an opaque
            # error here; fail early with a clear message instead.
            raise ValueError(
                f"Need at least K={self.K} samples to initialize centroids, "
                f"got {self.n_samples}."
            )
        # initialize centroids by drawing K distinct random samples
        random_sample_idxs = np.random.choice(self.n_samples, self.K, replace=False)
        self.centroids = [self.X[idx] for idx in random_sample_idxs]
        # optimize clusters
        for i in range(self.max_iters):
            # assign samples to the closest centroids (create clusters)
            self.clusters = self._create_clusters(self.centroids)
            if self.plot_steps:
                self.plot()
            # calculate new centroids from the current clusters
            centroids_old = self.centroids
            self.centroids = self._get_centroids(self.clusters)
            if self._is_converged(centroids_old, self.centroids):
                print(f"Centroids have converged to a common value after {i} steps.")
                break
            if self.plot_steps:
                self.plot()
        # classify samples as the index of their clusters
        return self._get_cluster_labels(self.clusters)

    def _get_cluster_labels(self, clusters) -> np.ndarray:
        # each sample gets the integer label of the cluster it was assigned to
        # (dtype=int: these are indices, not measurements)
        labels = np.empty(self.n_samples, dtype=int)
        for cluster_idx, cluster in enumerate(clusters):
            for sample_idx in cluster:
                labels[sample_idx] = cluster_idx
        return labels

    def _create_clusters(self, centroids):
        # assign each sample index to the cluster of its closest centroid
        clusters = [[] for _ in range(self.K)]
        for idx, sample in enumerate(self.X):
            centroid_idx = self._closest_centroid(sample, centroids)
            clusters[centroid_idx].append(idx)
        return clusters

    def _closest_centroid(self, sample, centroids):
        # vectorized Euclidean distance from the sample to every centroid
        # (same math as the per-centroid loop, one numpy call)
        distances = np.linalg.norm(np.asarray(centroids) - sample, axis=1)
        return int(np.argmin(distances))

    def _get_centroids(self, clusters):
        # assign the mean value of each cluster to its centroid
        centroids = np.zeros((self.K, self.n_features))
        for cluster_idx, cluster in enumerate(clusters):
            if len(cluster) == 0:
                # Empty cluster: np.mean over zero samples would yield NaN
                # (with a RuntimeWarning) and poison the centroid forever.
                # Keep the previous centroid for this cluster instead.
                centroids[cluster_idx] = self.centroids[cluster_idx]
            else:
                centroids[cluster_idx] = np.mean(self.X[cluster], axis=0)
        return centroids

    def _is_converged(self, centroids_old, centroids):
        # converged when no centroid moved at all between iterations
        shifts = np.linalg.norm(np.asarray(centroids_old) - np.asarray(centroids), axis=1)
        return bool(np.sum(shifts) == 0)

    def plot(self):
        # scatter each cluster's points, then mark centroids with black X's
        _, ax = plt.subplots(figsize=(12, 8))
        for i, index in enumerate(self.clusters):
            point = self.X[index].T
            ax.scatter(*point)
        for point in self.centroids:
            ax.scatter(*point, marker="x", color="black", linewidth=2)
        plt.show()