Skip to content

Commit

Permalink
removing gmms from plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
kvarada committed Oct 31, 2024
1 parent 96661ed commit 2d2617a
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 28 deletions.
21 changes: 20 additions & 1 deletion lectures/code/plotting_functions_unsup.py
Original file line number Diff line number Diff line change
Expand Up @@ -1275,7 +1275,26 @@ def plot_original_clustered(X, model, labels):
discrete_scatter(
model.cluster_centers_[:, 0], model.cluster_centers_[:, 1], y=np.arange(0,k), s=15,
markers='*', markeredgewidth=1.0, ax=ax[1])


def plot_kmeans(X, k):
fig, ax = plt.subplots(1, 2, figsize=(10, 4))
ax[0].set_title("Original dataset")
ax[0].set_xlabel("Feature 0")
ax[0].set_ylabel("Feature 1")
discrete_scatter(X[:, 0], X[:, 1], ax=ax[0]);
# cluster the data into three clusters
# plot the cluster assignments and cluster centers

kmeans = KMeans(n_clusters=k, n_init='auto', random_state=42)
kmeans.fit(X)
ax[1].set_title(f"KMeans clusters n_clusters={k}")
ax[1].set_xlabel("Feature 0")
ax[1].set_ylabel("Feature 1")
discrete_scatter(X[:, 0], X[:, 1], kmeans.labels_, markers='o', ax=ax[1])
discrete_scatter(
kmeans.cluster_centers_[:, 0], kmeans.cluster_centers_[:, 1], range(0,k), s=15,
markers='*', markeredgewidth=1.0, ax=ax[1])

def plot_kmeans_gmm(X, k):
fig, ax = plt.subplots(1, 3, figsize=(16, 4))
ax[0].set_title("Original dataset")
Expand Down
Loading

0 comments on commit 2d2617a

Please sign in to comment.