Skip to content

Commit

Permalink
Add random seed generator to NCH (#335)
Browse files Browse the repository at this point in the history
* Update classification.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
gcattan and pre-commit-ci[bot] authored Dec 4, 2024
1 parent ba1fef2 commit 95c0ed1
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion pyriemann_qiskit/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,8 @@ class NearestConvexHull(BaseEstimator, ClassifierMixin, TransformerMixin):
Subsampling strategy of training set to estimate distance to hulls.
"min" estimates hull using the n_samples_per_hull closest matrices.
"random" estimates hull using n_samples_per_hull random matrices.
seed : float, default=None
Optional random seed to use when subsampling is set to `random`.
References
----------
Expand All @@ -887,6 +889,7 @@ def __init__(
n_hulls_per_class=3,
n_samples_per_hull=10,
subsampling="min",
seed=None,
):
"""Init."""
self.n_jobs = n_jobs
Expand All @@ -895,6 +898,7 @@ def __init__(
self.matrices_per_class_ = {}
self.debug = False
self.subsampling = subsampling
self.seed = seed

if subsampling not in ["min", "random"]:
raise ValueError(f"Unknown subsampling type {subsampling}.")
Expand All @@ -917,6 +921,8 @@ def fit(self, X, y):
The NearestConvexHull instance.
"""

self.random_generator = random.Random(self.seed)

if self.debug:
print("Start NCH Train")
self.classes_ = np.unique(y)
Expand Down Expand Up @@ -970,7 +976,7 @@ def _process_sample_random_hull(self, x):
if self.n_samples_per_hull == -1: # use all data per class
hull_data = self.matrices_per_class_[c]
else: # use a subset of the data per class
random_samples = random.sample(
random_samples = self.random_generator.sample(
range(self.matrices_per_class_[c].shape[0]),
k=self.n_samples_per_hull,
)
Expand Down Expand Up @@ -1174,6 +1180,7 @@ def _init_algo(self, n_features):
n_samples_per_hull=self.n_samples_per_hull,
n_jobs=self.n_jobs,
subsampling=self.subsampling,
seed=self.seed,
)

self._optimizer = _get_docplex_optimizer_from_params_bag(
Expand Down

0 comments on commit 95c0ed1

Please sign in to comment.