diff --git a/pyannote/audio/pipelines/clustering.py b/pyannote/audio/pipelines/clustering.py index cd4b38935..15d61b719 100644 --- a/pyannote/audio/pipelines/clustering.py +++ b/pyannote/audio/pipelines/clustering.py @@ -471,6 +471,79 @@ def cluster( return clusters +class KMeansGPU(BaseClustering): + def __init__( + self, + metric: str = "", + max_num_embeddings: int = np.inf, + constrained_assignment: bool = False, + ): + """KMeans clustering + + Parameters + ---------- + metric : {""}, optional + Distance metric to use. KMeansGPU only supports the default value. + """ + super().__init__( + metric=metric, + max_num_embeddings=max_num_embeddings, + constrained_assignment=constrained_assignment, + ) + + def cluster( + self, embeddings, min_clusters: int, max_clusters: int, num_clusters: int = None + ): + try: + import cuml + import cupy as cp + from cuml.metrics.cluster import silhouette_score + except ImportError: + raise ImportError( + "KMeansGPU requires cuML. You can install it with 'https://docs.rapids.ai/install'." + ) + + assert max_clusters >= min_clusters > 0 + + num_embeddings = len(embeddings) + + may_single = False + if max_clusters > 1 and min_clusters == 1: + min_clusters = 2 + may_single = True + elif max_clusters == 1: + return np.zeros((num_embeddings,)) + + if num_embeddings <= min_clusters or num_embeddings == num_clusters: + return np.arange(num_embeddings) + + if num_clusters is not None: + agg_clust = cuml.cluster.KMeans(n_clusters=num_clusters) + clusters = agg_clust.fit_predict(embeddings) + return clusters.get() + + embeddings = cp.asarray(embeddings) + + best_score = -1 + best_clusters = None + + for num_clusters in range(min_clusters, min(max_clusters + 1, num_embeddings)): + agg_clust = cuml.cluster.KMeans(n_clusters=num_clusters) + clusters = agg_clust.fit_predict(embeddings) + + score = silhouette_score(embeddings, clusters) + + if score > best_score: + best_score = score + best_clusters = clusters + + if may_single: + if num_clusters == 2 and best_score < 0.25: + return np.zeros((num_embeddings,)) + + return best_clusters.get() + + class OracleClustering(BaseClustering): """Oracle clustering""" @@ -558,4 +631,5 @@ def __call__( class Clustering(Enum): AgglomerativeClustering = AgglomerativeClustering + KMeansGPU = KMeansGPU OracleClustering = OracleClustering diff --git a/tests/test_clustering.py b/tests/test_clustering.py index 535da47de..f064b59ad 100644 --- a/tests/test_clustering.py +++ b/tests/test_clustering.py @@ -1,6 +1,6 @@ import numpy as np -from pyannote.audio.pipelines.clustering import AgglomerativeClustering +from pyannote.audio.pipelines.clustering import AgglomerativeClustering, KMeansGPU def test_agglomerative_clustering_num_cluster(): @@ -26,4 +26,83 @@ def test_agglomerative_clustering_num_cluster(): clusters = clustering.cluster( embeddings=embeddings, min_clusters=2, max_clusters=2, num_clusters=2 ) + print(clusters) assert np.array_equal(clusters, np.array([0, 1])) + + +def test_kmeans_clustering_num_cluster_gpu_too_small(): + clustering = KMeansGPU().instantiate({}) + + embeddings = np.array([[1.0, 1.0, 1.0, 1.0], [1.0, 2.0, 1.0, 2.0]]) + + # request 2 clusters + clusters = clustering.cluster( + embeddings=embeddings, min_clusters=2, max_clusters=2, num_clusters=2 + ) + + assert np.array_equal(clusters, np.array([0, 1])) + + # generate a 256-dimensional random vector + v = np.random.rand(256) + + # define the range and standard deviation of the generated cluster center + cluster_center_std_dev = 2.0 + + # generate 8 cluster centers randomly + num_clusters = 8 + cluster_centers = np.random.normal( + np.mean(v), cluster_center_std_dev, size=(num_clusters, 256) + ) + + a, b, c = clustering.set_num_clusters(num_clusters, 10, 1, 10) + assert a == 8 + + a, b, c = clustering.set_num_clusters(num_clusters, None, 1, 10) + assert a is None and b == 1 and c == 8 + + a, b, c = clustering.set_num_clusters(num_clusters, None, 8, 10) + assert a == 8 + + a, b, c = clustering.set_num_clusters(num_clusters, None, 7, 10) + assert a is None and b == 7 and c == 8 + + clustering.cluster( + embeddings=cluster_centers, num_clusters=a, min_clusters=b, max_clusters=c + ) + + +def test_kmeans_clustering_num_cluster_gpu_large(): + clustering = KMeansGPU().instantiate({}) + + # generate a 256-dimensional random vector + v = np.random.rand(256) + + # define the range and standard deviation of the generated cluster center + cluster_center_std_dev = 2.0 + vector_std_dev = 1 + + # generate 5 cluster centers randomly + num_clusters = 5 + cluster_centers = np.random.normal( + np.mean(v), cluster_center_std_dev, size=(num_clusters, 256) + ) + + # generate 2000 * 32 vectors + num_vectors_per_cluster = int(2000 * 32 / num_clusters) + all_vectors = [] + + for center in cluster_centers: + vectors = np.random.normal( + center, vector_std_dev, size=(num_vectors_per_cluster, 256) + ) + all_vectors.append(vectors) + + # stack all vectors + all_vectors = np.vstack(all_vectors) + + np.random.shuffle(all_vectors) + + clusters = clustering.cluster( + embeddings=all_vectors, min_clusters=2, max_clusters=10 + ) + assert np.unique(clusters).shape[0] == num_clusters