diff --git a/tests/test_clustering.py b/tests/test_clustering.py
index 3ce10106..c8ea040e 100644
--- a/tests/test_clustering.py
+++ b/tests/test_clustering.py
@@ -6,7 +6,7 @@
 @pytest.mark.parametrize("reducer,cluster_reduced", [("tsne", True),
                                                      ("umap", True),
                                                      ("umap", False)])
-def test_full_pipeline(reducer, cluster_reduced):
+def test_full_pipeline(reducer, cluster_reduced, tmp_path):
     cluster = TextClustering(reducer=reducer,
                              cluster_reduced=cluster_reduced,
                              embedding_random_state=42,
                              reducer_random_state=43,
@@ -23,6 +23,17 @@ def test_full_pipeline(reducer, cluster_reduced):
 
     assert len(cluster.cluster_kws) == len(cluster.cluster_ids) == 6
 
+    cluster.save(folder=tmp_path)
+
+    cluster_new = TextClustering()
+    cluster_new.load(folder=tmp_path)
+
+    # Assert all coordinates of the loaded points are equal to the originals
+    assert (cluster_new.embedded_points != cluster.embedded_points).sum() == 0
+    assert (cluster_new.reduced_points != cluster.reduced_points).sum() == 0
+    assert cluster_new.reducer_class.__class__ == cluster.reducer_class.__class__
+    assert cluster_new.clustering_class.__class__ == cluster.clustering_class.__class__
+
 
 @pytest.mark.parametrize("reducer", ["tsne", "umap"])
 def test_parameter_search(reducer):
diff --git a/wellcomeml/ml/clustering.py b/wellcomeml/ml/clustering.py
index 6c6b36c5..1943c7f0 100644
--- a/wellcomeml/ml/clustering.py
+++ b/wellcomeml/ml/clustering.py
@@ -1,6 +1,7 @@
 from collections import defaultdict
 import logging
 import os
+import pickle
 
 from wellcomeml.ml import vectorizer
 from wellcomeml.logger import logger
@@ -39,6 +40,7 @@ class TextClustering(object):
         cluster_names: Names of the clusters
         cluster_kws: Keywords for the clusters (only if embedding=tf-idf)
     """
+
     def __init__(self, embedding='tf-idf', reducer='umap', clustering='dbscan',
                  cluster_reduced=True, n_kw=10, params={},
                  embedding_random_state=None, reducer_random_state=None,
@@ -113,6 +115,8 @@ class that is a sklearn.base.ClusterMixin
                        'random_state'):
             self.clustering_class.random_state = clustering_random_state
 
+        self.embedded_points = None
+        self.reduced_points = None
         self.cluster_ids = None
         self.cluster_names = None
         self.cluster_kws = None
@@ -120,6 +124,12 @@ class that is a sklearn.base.ClusterMixin
         self.silhouette = None
         self.optimise_results = {}
 
+        self.embedded_points_filename = 'embedded_points.npy'
+        self.reduced_points_filename = 'reduced_points.npy'
+        self.vectorizer_filename = 'vectorizer.pkl'
+        self.reducer_filename = 'reducer.pkl'
+        self.clustering_filename = 'clustering.pkl'
+
     def fit(self, X, *_):
         """
         Fits all clusters in the pipeline
@@ -131,22 +141,28 @@ def fit(self, X, *_):
             A TextClustering object
         """
 
-        self._fit_step(X, step='vectorizer')
-        self._fit_step(step='reducer')
-        self._fit_step(step='clustering')
+        self.fit_step(X, step='vectorizer')
+        self.fit_step(step='reducer')
+        self.fit_step(step='clustering')
 
         if self.embedding == 'tf-idf' and self.n_kw:
             self._find_keywords(self.embedded_points.toarray(), n_kw=self.n_kw)
 
         return self
 
-    def _fit_step(self, X=None, step='vectorizer'):
-        """Internal function for partial fitting only a certain step"""
+    def fit_step(self, X=None, y=None, step='vectorizer'):
+        """Fits a single step of the pipeline (vectorizer, reducer or clustering)"""
         if step == 'vectorizer':
             self.embedded_points = self.vectorizer.fit_transform(X)
         elif step == 'reducer':
-            self.reduced_points = \
-                self.reducer_class.fit_transform(self.embedded_points)
+            if X is None:
+                if self.embedded_points is None:
+                    raise ValueError(
+                        'You must embed/vectorise the points before reducing dimensionality'
+                    )
+                X = self.embedded_points
+            self.reduced_points = \
+                self.reducer_class.fit_transform(X=X, y=y)
         elif step == 'clustering':
             points = (
                 self.reduced_points if self.cluster_reduced else
@@ -260,7 +276,9 @@ def optimise(self, X, param_grid, n_cluster_range=None, max_noise=0.2,
         # Prunes result to actually optimise under constraints
         best_silhouette = 0
         best_params = {}
+        grid.fit(X, y=None)
+
         for params, silhouette, noise, n_clusters in zip(
                 grid.cv_results_['params'],
                 grid.cv_results_['mean_test_silhouette'],
                 grid.cv_results_['mean_test_noise'],
@@ -292,6 +310,74 @@ def optimise(self, X, param_grid, n_cluster_range=None, max_noise=0.2,
 
         return best_params
 
+    def save(self, folder, components='all', create_folder=True):
+        """
+        Saves the different steps of the pipeline
+
+        Args:
+            folder(str): path to folder
+            components(list or 'all'): List of components to save. Options are:
+                'embedded_points', 'reduced_points', 'vectorizer', 'reducer', and
+                'clustering_model'. By default, saves 'all' (you can get all
+                components by listing the class param TextClustering.components)
+
+        """
+        if create_folder:
+            os.makedirs(folder, exist_ok=True)
+
+        if components == 'all' or 'embedded_points' in components:
+            np.save(os.path.join(folder, self.embedded_points_filename), self.embedded_points)
+
+        if components == 'all' or 'reduced_points' in components:
+            np.save(os.path.join(folder, self.reduced_points_filename), self.reduced_points)
+
+        if components == 'all' or 'vectorizer' in components:
+            with open(os.path.join(folder, self.vectorizer_filename), 'wb') as f:
+                pickle.dump(self.vectorizer, f)
+
+        if components == 'all' or 'reducer' in components:
+            with open(os.path.join(folder, self.reducer_filename), 'wb') as f:
+                pickle.dump(self.reducer_class, f)
+
+        if components == 'all' or 'clustering_model' in components:
+            with open(os.path.join(folder, self.clustering_filename), 'wb') as f:
+                pickle.dump(self.clustering_class, f)
+
+    def load(self, folder, components='all'):
+        """
+        Loads the different steps of the pipeline
+
+        Args:
+            folder(str): path to folder
+            components(list or 'all'): List of components to load. Options are:
+                'embedded_points', 'reduced_points', 'vectorizer', 'reducer', and
+                'clustering_model'. By default, loads 'all' (you can get all
+                components by listing the class param TextClustering.components)
+
+        """
+        if components == 'all' or 'embedded_points' in components:
+            self.embedded_points = np.load(os.path.join(folder, self.embedded_points_filename),
+                                           allow_pickle=True)
+            # np.save stores a sparse matrix as a 0-d object array; unwrap it on load
+            if not self.embedded_points.shape:
+                self.embedded_points = self.embedded_points[()]
+
+        if components == 'all' or 'reduced_points' in components:
+            self.reduced_points = np.load(os.path.join(folder, self.reduced_points_filename),
+                                          allow_pickle=True)
+
+        if components == 'all' or 'vectorizer' in components:
+            with open(os.path.join(folder, self.vectorizer_filename), 'rb') as f:
+                self.vectorizer = pickle.load(f)
+
+        if components == 'all' or 'reducer' in components:
+            with open(os.path.join(folder, self.reducer_filename), 'rb') as f:
+                self.reducer_class = pickle.load(f)
+
+        if components == 'all' or 'clustering_model' in components:
+            with open(os.path.join(folder, self.clustering_filename), 'rb') as f:
+                self.clustering_class = pickle.load(f)
+
     def stability(self):
         """Function to calculate how stable the clusters are"""
         raise NotImplementedError
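Usage note: a minimal sketch of the save/load round trip introduced by this patch, mirroring the test above. The toy corpus and the 'clustering-model' folder name are illustrative, not part of the patch:

```python
from wellcomeml.ml.clustering import TextClustering

# Illustrative toy corpus; any list of strings works
texts = [
    "Malaria is a disease transmitted by mosquitoes",
    "Mosquitoes transmit malaria in tropical regions",
    "Deep learning improves text classification",
    "Text classification benefits from neural networks",
    "Vaccines reduce the burden of infectious disease",
    "Infectious disease burden falls where vaccines are used",
]

cluster = TextClustering(embedding_random_state=42,
                         reducer_random_state=43,
                         clustering_random_state=44)
cluster.fit(texts)

# Persist every component: points, vectorizer, reducer and clustering model
cluster.save(folder='clustering-model')

# Restore them into a fresh, unfitted instance
cluster_new = TextClustering()
cluster_new.load(folder='clustering-model')
assert (cluster_new.reduced_points != cluster.reduced_points).sum() == 0
```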
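Both save and load also accept an explicit component list, so the potentially large point arrays can be skipped. A sketch under the same assumptions as above:

```python
# Persist only the fitted models, skipping the point arrays
models = ['vectorizer', 'reducer', 'clustering_model']
cluster.save(folder='clustering-model', components=models)

cluster_new = TextClustering()
cluster_new.load(folder='clustering-model', components=models)
```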
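Finally, the newly public fit_step makes it possible to re-run a single stage of an already-fitted pipeline. The y argument is simply forwarded to the reducer's fit_transform; it could, for example, carry labels for supervised UMAP, though that use is an assumption and nothing in this patch exercises it:

```python
# Re-run only the dimensionality reduction on the stored embedding, then re-cluster
cluster.fit_step(step='reducer')
cluster.fit_step(step='clustering')

# Or pass an explicit matrix (and optional labels) straight to the reducer
cluster.fit_step(X=cluster.embedded_points, y=None, step='reducer')
```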