From 9a43800c34ae05f218c12f1a8882c0589deb6a7b Mon Sep 17 00:00:00 2001 From: gcattan Date: Mon, 2 Oct 2023 21:57:28 +0200 Subject: [PATCH] fix/quantum-art-example (#184) * fix typo * rename get_qiskit_dataset to generate_qiskit_dataset and get_linearly_separable_dataset to generate_linearly_separable_dataset * [pre-commit.ci] auto fixes from pre-commit.com hooks * lint --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- doc/api.rst | 4 ++-- examples/toys_dataset/plot_classifier_comparison.py | 9 ++++++--- examples/toys_dataset/plot_quantum_art_vqc.py | 10 +++++----- pyriemann_qiskit/datasets/__init__.py | 8 ++++---- pyriemann_qiskit/datasets/utils.py | 7 +++++-- 5 files changed, 22 insertions(+), 16 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index f852a8e7..0fdf5c52 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -132,8 +132,8 @@ Datasets :toctree: generated/ get_mne_sample - get_linearly_separable_dataset - get_qiskit_dataset + generate_linearly_separable_dataset + generate_qiskit_dataset get_feature_dimension MockDataset diff --git a/examples/toys_dataset/plot_classifier_comparison.py b/examples/toys_dataset/plot_classifier_comparison.py index 94b43817..73dc5a1b 100644 --- a/examples/toys_dataset/plot_classifier_comparison.py +++ b/examples/toys_dataset/plot_classifier_comparison.py @@ -19,7 +19,10 @@ from sklearn.preprocessing import StandardScaler from sklearn.datasets import make_moons, make_circles from sklearn.svm import SVC -from pyriemann_qiskit.datasets import get_linearly_separable_dataset, get_qiskit_dataset +from pyriemann_qiskit.datasets import ( + generate_linearly_separable_dataset, + generate_qiskit_dataset, +) from pyriemann_qiskit.classification import QuanticSVM # uncomment to run comparison with QuanticVQC (disabled for CI/CD) @@ -51,8 +54,8 @@ datasets = [ make_moons(noise=0.3, random_state=0), make_circles(noise=0.2, factor=0.5, random_state=1), - get_linearly_separable_dataset(), - get_qiskit_dataset(), + generate_linearly_separable_dataset(), + generate_qiskit_dataset(), ] figure = plt.figure(figsize=(15, 9)) diff --git a/examples/toys_dataset/plot_quantum_art_vqc.py b/examples/toys_dataset/plot_quantum_art_vqc.py index 472c05a3..09c25905 100644 --- a/examples/toys_dataset/plot_quantum_art_vqc.py +++ b/examples/toys_dataset/plot_quantum_art_vqc.py @@ -13,7 +13,7 @@ from pyriemann_qiskit.utils.hyper_params_factory import gen_two_local import matplotlib.pyplot as plt -from pyriemann_qiskit.datasets import get_linearly_separable_dataset +from pyriemann_qiskit.datasets import generate_linearly_separable_dataset from pyriemann_qiskit.classification import QuanticVQC from pyriemann_qiskit.visualization import weights_spiral @@ -25,9 +25,9 @@ # the variational quantum circuit which is used by VQC. # # The idea is simple : -# - We initialize a VQC with different number of parameters and number of samples -# - We train the VQC a couple of time and we store the fitted weights. -# - We compute variability of the weight and display it in a fashion way. +# - We initialize a VQC with different number of parameters and number of samples. +# - We train the VQC a couple of times and we store the fitted weights. +# - We compute variability of the weight and display it in a fashionable way. # Let's start by defining some plot area. fig, axes = plt.subplots(2, 2) @@ -42,7 +42,7 @@ vqc = QuanticVQC(gen_var_form=gen_two_local(reps=n_reps)) # Get data. We will use a toy dataset here. - X, y = get_linearly_separable_dataset(n_samples=n_samples) + X, y = generate_linearly_separable_dataset(n_samples=n_samples) # Compute and display weight variability after training axe = axes[i, j] diff --git a/pyriemann_qiskit/datasets/__init__.py b/pyriemann_qiskit/datasets/__init__.py index 3af9f475..5361940b 100644 --- a/pyriemann_qiskit/datasets/__init__.py +++ b/pyriemann_qiskit/datasets/__init__.py @@ -1,7 +1,7 @@ from .utils import ( get_mne_sample, - get_linearly_separable_dataset, - get_qiskit_dataset, + generate_linearly_separable_dataset, + generate_qiskit_dataset, get_feature_dimension, MockDataset, ) @@ -9,8 +9,8 @@ __all__ = [ "get_mne_sample", - "get_linearly_separable_dataset", - "get_qiskit_dataset", + "generate_linearly_separable_dataset", + "generate_qiskit_dataset", "get_feature_dimension", "MockDataset", ] diff --git a/pyriemann_qiskit/datasets/utils.py b/pyriemann_qiskit/datasets/utils.py index fe634518..93b13133 100644 --- a/pyriemann_qiskit/datasets/utils.py +++ b/pyriemann_qiskit/datasets/utils.py @@ -98,7 +98,7 @@ def get_mne_sample(n_trials=10, include_auditory=False): return X, y -def get_qiskit_dataset(n_samples=30): +def generate_qiskit_dataset(n_samples=30): """Return qiskit dataset. Notes @@ -106,6 +106,7 @@ def get_qiskit_dataset(n_samples=30): .. versionadded:: 0.0.1 .. versionchanged:: 0.1.0 Added `n_samples` parameter. + Rename from `get_qiskit_dataset` to `generate_qiskit_dataset`. Parameters ---------- @@ -132,7 +133,7 @@ def get_qiskit_dataset(n_samples=30): return (X, y) -def get_linearly_separable_dataset(n_samples=100): +def generate_linearly_separable_dataset(n_samples=100): """Return a linearly separable dataset. Notes @@ -140,6 +141,8 @@ def get_linearly_separable_dataset(n_samples=100): .. versionadded:: 0.0.1 .. versionchanged:: 0.1.0 Added `n_samples` parameter. + Rename from `get_linearly_separable_dataset` to + `generate_linearly_separable_dataset`. Parameters ----------