From 95c0ed1e2e165f7108ed175cffcb26956a819b40 Mon Sep 17 00:00:00 2001 From: gcattan Date: Wed, 4 Dec 2024 14:57:00 +0100 Subject: [PATCH] Add random seed generator to NCH (#335) * Update classification.py * [pre-commit.ci] auto fixes from pre-commit.com hooks --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- pyriemann_qiskit/classification.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pyriemann_qiskit/classification.py b/pyriemann_qiskit/classification.py index 301c8068..671b8011 100644 --- a/pyriemann_qiskit/classification.py +++ b/pyriemann_qiskit/classification.py @@ -872,6 +872,8 @@ class NearestConvexHull(BaseEstimator, ClassifierMixin, TransformerMixin): Subsampling strategy of training set to estimate distance to hulls. "min" estimates hull using the n_samples_per_hull closest matrices. "random" estimates hull using n_samples_per_hull random matrices. + seed : float, default=None + Optional random seed to use when subsampling is set to `random`. References ---------- @@ -887,6 +889,7 @@ def __init__( n_hulls_per_class=3, n_samples_per_hull=10, subsampling="min", + seed=None, ): """Init.""" self.n_jobs = n_jobs @@ -895,6 +898,7 @@ def __init__( self.matrices_per_class_ = {} self.debug = False self.subsampling = subsampling + self.seed = seed if subsampling not in ["min", "random"]: raise ValueError(f"Unknown subsampling type {subsampling}.") @@ -917,6 +921,8 @@ def fit(self, X, y): The NearestConvexHull instance. """ + self.random_generator = random.Random(self.seed) + if self.debug: print("Start NCH Train") self.classes_ = np.unique(y) @@ -970,7 +976,7 @@ def _process_sample_random_hull(self, x): if self.n_samples_per_hull == -1: # use all data per class hull_data = self.matrices_per_class_[c] else: # use a subset of the data per class - random_samples = random.sample( + random_samples = self.random_generator.sample( range(self.matrices_per_class_[c].shape[0]), k=self.n_samples_per_hull, ) @@ -1174,6 +1180,7 @@ def _init_algo(self, n_features): n_samples_per_hull=self.n_samples_per_hull, n_jobs=self.n_jobs, subsampling=self.subsampling, + seed=self.seed, ) self._optimizer = _get_docplex_optimizer_from_params_bag(