Add a new benchmark over many datasets (#264)
* Initial version of NearestConvexHull.

* Added script for testing.

* First version that runs.

* Improved code.

* Added support for parallel processing.
It gives an error: AttributeError: Pipeline has none of the following attributes: decision_function.

* renamed

* New version that uses a new class that implements a NCH classifier.

* small update

* Updated to newest code - the new version of the distance function.
Added an example that runs on a small number of test samples, so that we can get results quicker.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* reinforce constraint on weights

* - remove constraints on weights
- limit size of training set
- change to SLSQP optimizer

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Added n_max_hull parameter. MOABB support tested.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* added multiple hulls.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Code cleanups.
Added a second parameter that specifies the number of hulls.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Improved code.
Added support for transform().
Added a new pipeline [NCH+LDA].
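A hypothetical sketch of what such an [NCH+LDA] pipeline could look like; it assumes QuanticNCH.transform() exposes the per-class hull distances as features, the constructor arguments mirror the benchmark script added in this commit, and the nch_lda name is illustrative:

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.pipeline import make_pipeline
from pyriemann.estimation import XdawnCovariances
from pyriemann_qiskit.classification import QuanticNCH

# Xdawn covariances -> NCH distances (via transform) -> LDA on the distances
nch_lda = make_pipeline(
    XdawnCovariances(nfilter=3, estimator="lwf", xdawn_estimator="scm"),
    QuanticNCH(n_hulls_per_class=1, n_samples_per_hull=3, quantum=False),
    LDA(),
)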

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* updated default parameters

* General improvements.
Improvements requested by GC.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* removed commented code

* Small adjustments.

* Better class separation.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Added support for n_samples_per_hull = -1, which takes all the samples of a class.
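A hedged illustration of this option, reusing the constructor signature from the benchmark script added in this commit (the clf name is illustrative):

from pyriemann_qiskit.classification import QuanticNCH

# n_samples_per_hull=-1: build each hull from every sample of the class,
# instead of a fixed-size subset
clf = QuanticNCH(n_hulls_per_class=1, n_samples_per_hull=-1, quantum=False)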

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Update pyriemann_qiskit/classification.py

Set of SPD matrices.

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

* Update pyriemann_qiskit/classification.py

Added new lines before Parameters

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

* Update pyriemann_qiskit/classification.py

[y == c, :, :] => [y == c]

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

* Update pyriemann_qiskit/classification.py

NearestConvexHull text change

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

* Improvements proposed by Quentin.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Added comment for the optimizer.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Added some comments in classification.
Changed the global optimizer code so that it is more evident that a global optimizer is used.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Implemented min hull.
Added support for both "min-hull" and "random-hull" using the constructor parameter "hull-type".

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Reverted to previous version as requested by Gregoire.

* fix lint issues

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* This is the new benchmark that uses a large number of P300 and Motor Imagery databases.
It automatically handles paradigm and pipeline differences.
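A minimal sketch of the call pattern, condensed from the script added below (pipeline contents and parameter values are illustrative):

from sklearn.pipeline import make_pipeline
from pyriemann.estimation import XdawnCovariances
from pyriemann.classification import MDM
from heavy_benchmark import benchmark_alpha, plot_stat

pipelines = {"XD+MDM": make_pipeline(XdawnCovariances(nfilter=3), MDM())}
results = benchmark_alpha(
    pipelines,           # dict of scikit-learn pipelines to compare
    max_n_subjects=3,    # limit the number of subjects, for a quicker run
    n_jobs=12,
    overwrite=False,
)
plot_stat(results)       # statistical comparison plots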

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* added n_jobs

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* update cache settings

* Better configuration.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* enable parallel

* Update heavy_benchmark.py

Added documentation.
Improved statistical plots.
Added parameters.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Heavy benchmark has been updated to its latest version. Improved documentation.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Improved.

* Update heavy_benchmark.py

* Update heavy_benchmark.py

flake8

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* Update heavy_benchmark.py

* Update heavy_benchmark.py

flake8

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: gcattan <gcattan@hotmail.fr>
Co-authored-by: Gregoire Cattan <gregoire.cattan@ibm.com>
Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>
5 people authored Jun 24, 2024
1 parent 28686ff commit 74bacb6
Showing 2 changed files with 583 additions and 0 deletions.
92 changes: 92 additions & 0 deletions benchmarks/hb_nch.py
@@ -0,0 +1,92 @@
# -*- coding: utf-8 -*-
"""
A demo on how to use benchmark_alpha.
Performs a benchmark of several variations of the NCH algorithm.
@author: anton andreev
"""

from pyriemann.estimation import XdawnCovariances
from sklearn.pipeline import make_pipeline
from pyriemann.classification import MDM
import os

from pyriemann_qiskit.classification import QuanticNCH
from heavy_benchmark import benchmark_alpha, plot_stat

# start configuration
hb_max_n_subjects = 3
hb_n_jobs = 12
hb_overwrite = False
# end configuration

labels_dict = {"Target": 1, "NonTarget": 0}
pipelines = {}

pipelines["NCH+RANDOM_HULL"] = make_pipeline(
# applies XDawn and calculates the covariance matrix, output it matrices
XdawnCovariances(
nfilter=3,
classes=[labels_dict["Target"]],
estimator="lwf",
xdawn_estimator="scm",
),
QuanticNCH(
n_hulls_per_class=1,
n_samples_per_hull=3,
n_jobs=12,
subsampling="random",
quantum=False,
),
)

pipelines["NCH+MIN_HULL"] = make_pipeline(
# applies XDawn and calculates the covariance matrix, output it matrices
XdawnCovariances(
nfilter=3,
classes=[labels_dict["Target"]],
estimator="lwf",
xdawn_estimator="scm",
),
QuanticNCH(
n_hulls_per_class=1,
n_samples_per_hull=3,
n_jobs=12,
subsampling="min",
quantum=False,
),
)

# this is a non-quantum pipeline
pipelines["XD+MDM"] = make_pipeline(
    XdawnCovariances(
        nfilter=3,
        classes=[labels_dict["Target"]],
        estimator="lwf",
        xdawn_estimator="scm",
    ),
    MDM(),
)

results = benchmark_alpha(
    pipelines,
    max_n_subjects=hb_max_n_subjects,
    n_jobs=hb_n_jobs,
    overwrite=hb_overwrite,
)

print("Results:")
print(results)

print("Averaging the session performance:")
print(results.groupby("pipeline").mean("score")[["score", "time"]])

# save results
save_path = os.path.join(
    os.path.dirname(os.path.realpath(__file__)), "results_dataframe.csv"
)
results.to_csv(save_path, index=True)

# Provides statistics
plot_stat(results)
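A hedged follow-up sketch for reloading the saved results later, assuming the CSV keeps the pipeline, score and time columns used above:

import pandas as pd

# reload the results written by the script and average per pipeline
results = pd.read_csv("results_dataframe.csv", index_col=0)
print(results.groupby("pipeline")[["score", "time"]].mean())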