-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Create other light benchmarks to parallelize the work (#320)
* create other light benchmarks to parralelize the work * missing scripts * correct cache_key * fix wrong file in main branch * factorize code in base module * [pre-commit.ci] auto fixes from pre-commit.com hooks * flake8 * remove relative import * fix for light_benchmark workflow --------- Co-authored-by: Gregoire Cattan <gregoire.cattan@ibm.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
6050e7f
commit 541ce27
Showing
7 changed files
with
325 additions
and
109 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
name: Light Benchmark | ||
|
||
on: | ||
pull_request: | ||
paths: | ||
- 'pyriemann_qiskit/**' | ||
- 'examples/**' | ||
- '.github/workflows/light_benchmark_nch_min_hull.yml' | ||
- 'benchmarks/light_benchmark_nch_min_hull.py' | ||
- 'setup.py' | ||
|
||
jobs: | ||
light_benchmark_nch_min_hull: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v4 | ||
- name: Cache dependencies | ||
uses: actions/cache@v4 | ||
with: | ||
path: ~/.cache/pip | ||
key: light_benchmark_nch_min_hull.yml | ||
- name: Install dependencies | ||
run: | | ||
pip install .[docs] | ||
- name: Run benchmark script (PR) | ||
id: run-benchmark-pr | ||
run: | | ||
python benchmarks/light_benchmark_nch_min_hull.py pr | ||
- uses: actions/checkout@v4 | ||
with: | ||
ref: 'main' | ||
- name: Install dependencies | ||
run: | | ||
pip install .[docs] | ||
- name: Run benchmark script (main) | ||
id: run-benchmark-main | ||
run: | | ||
python benchmarks/light_benchmark_nch_min_hull.py ${{steps.run-benchmark-pr.outputs.NCH_MIN_HULL}} | ||
- name: Compare performance | ||
run: | | ||
echo ${{steps.run-benchmark-main.outputs.success}} | ||
if [[ "${{steps.run-benchmark-main.outputs.success}}" == "0" ]]; then | ||
exit 1 | ||
fi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
name: Light Benchmark | ||
|
||
on: | ||
pull_request: | ||
paths: | ||
- 'pyriemann_qiskit/**' | ||
- 'examples/**' | ||
- '.github/workflows/light_benchmark_nch_qaoacv.yml' | ||
- 'benchmarks/light_benchmark_qaoacv.py' | ||
- 'setup.py' | ||
|
||
jobs: | ||
light_benchmark_nch_qaoa: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v4 | ||
- name: Cache dependencies | ||
uses: actions/cache@v4 | ||
with: | ||
path: ~/.cache/pip | ||
key: light_benchmark_nch_qaoacv.yml | ||
- name: Install dependencies | ||
run: | | ||
pip install .[docs] | ||
- name: Run benchmark script (PR) | ||
id: run-benchmark-pr | ||
run: | | ||
python benchmarks/light_benchmark_nch_qaoacv.py pr | ||
- uses: actions/checkout@v4 | ||
with: | ||
ref: 'main' | ||
- name: Install dependencies | ||
run: | | ||
pip install .[docs] | ||
- name: Run benchmark script (main) | ||
id: run-benchmark-main | ||
run: | | ||
python benchmarks/light_benchmark_nch_qaoacv.py ${{steps.run-benchmark-pr.outputs.NCH_MIN_HULL_QAOACV}} | ||
- name: Compare performance | ||
run: | | ||
echo ${{steps.run-benchmark-main.outputs.success}} | ||
if [[ "${{steps.run-benchmark-main.outputs.success}}" == "0" ]]; then | ||
exit 1 | ||
fi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
""" | ||
==================================================================== | ||
Light Benchmark | ||
==================================================================== | ||
Common script to run light benchmarks | ||
""" | ||
# Author: Gregoire Cattan | ||
# Modified from plot_classify_P300_bi.py of pyRiemann | ||
# License: BSD (3-clause) | ||
|
||
import sys | ||
import warnings | ||
|
||
from moabb import set_log_level | ||
from moabb.datasets import bi2012 | ||
from moabb.paradigms import P300 | ||
from sklearn.metrics import balanced_accuracy_score | ||
from sklearn.model_selection import train_test_split | ||
from sklearn.preprocessing import LabelEncoder | ||
|
||
|
||
print(__doc__) | ||
|
||
############################################################################## | ||
# getting rid of the warnings about the future | ||
warnings.simplefilter(action="ignore", category=FutureWarning) | ||
warnings.simplefilter(action="ignore", category=RuntimeWarning) | ||
|
||
warnings.filterwarnings("ignore") | ||
|
||
set_log_level("info") | ||
|
||
############################################################################## | ||
# Prepare data | ||
# ------------- | ||
# | ||
############################################################################## | ||
|
||
|
||
def _set_output(key: str, value: str): | ||
print(f"::set-output name={key}::{value}") # noqa: E231 | ||
|
||
|
||
def run(pipelines): | ||
paradigm = P300(resample=128) | ||
|
||
dataset = bi2012() # MOABB provides several other P300 datasets | ||
|
||
X, y, _ = paradigm.get_data(dataset, subjects=[1]) | ||
|
||
# Reduce the dataset size for Ci | ||
_, X, _, y = train_test_split(X, y, test_size=0.7, random_state=42, stratify=y) | ||
|
||
y = LabelEncoder().fit_transform(y) | ||
|
||
# Separate into train and test | ||
X_train, X_test, y_train, y_test = train_test_split( | ||
X, y, test_size=0.33, random_state=42, stratify=y | ||
) | ||
|
||
# Compute scores | ||
scores = {} | ||
|
||
for key, pipeline in pipelines.items(): | ||
pipeline.fit(X_train, y_train) | ||
y_pred = pipeline.predict(X_test) | ||
score = balanced_accuracy_score(y_test, y_pred) | ||
scores[key] = score | ||
|
||
print("Scores: ", scores) | ||
|
||
# Compare scores between PR and main branches | ||
is_pr = sys.argv[1] == "pr" | ||
|
||
if is_pr: | ||
for key, score in scores.items(): | ||
_set_output(key, score) | ||
else: | ||
success = True | ||
i = 0 | ||
for key, score in scores.items(): | ||
i = i + 1 | ||
pr_score = sys.argv[i] | ||
pr_score_trun = int(float(pr_score) * 100) | ||
score_trun = int(score * 100) | ||
better_pr_score = pr_score_trun >= score_trun | ||
success = success and better_pr_score | ||
print( | ||
f"{key}: {pr_score_trun} (PR) >= {score_trun} (main): {better_pr_score}" | ||
) | ||
_set_output("success", "1" if success else "0") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.