Skip to content

Commit

Permalink
fix/quantum-art-example (#184)
Browse files Browse the repository at this point in the history
* fix typo

* rename get_qiskit_dataset to generate_qiskit_dataset
and get_linearly_separable_dataset to generate_linearly_separable_dataset

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* lint

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
gcattan and pre-commit-ci[bot] authored Oct 2, 2023
1 parent aa7f15a commit 9a43800
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 16 deletions.
4 changes: 2 additions & 2 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ Datasets
:toctree: generated/

get_mne_sample
get_linearly_separable_dataset
get_qiskit_dataset
generate_linearly_separable_dataset
generate_qiskit_dataset
get_feature_dimension
MockDataset

Expand Down
9 changes: 6 additions & 3 deletions examples/toys_dataset/plot_classifier_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import make_moons, make_circles
from sklearn.svm import SVC
from pyriemann_qiskit.datasets import get_linearly_separable_dataset, get_qiskit_dataset
from pyriemann_qiskit.datasets import (
generate_linearly_separable_dataset,
generate_qiskit_dataset,
)
from pyriemann_qiskit.classification import QuanticSVM

# uncomment to run comparison with QuanticVQC (disabled for CI/CD)
Expand Down Expand Up @@ -51,8 +54,8 @@
datasets = [
make_moons(noise=0.3, random_state=0),
make_circles(noise=0.2, factor=0.5, random_state=1),
get_linearly_separable_dataset(),
get_qiskit_dataset(),
generate_linearly_separable_dataset(),
generate_qiskit_dataset(),
]

figure = plt.figure(figsize=(15, 9))
Expand Down
10 changes: 5 additions & 5 deletions examples/toys_dataset/plot_quantum_art_vqc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from pyriemann_qiskit.utils.hyper_params_factory import gen_two_local
import matplotlib.pyplot as plt
from pyriemann_qiskit.datasets import get_linearly_separable_dataset
from pyriemann_qiskit.datasets import generate_linearly_separable_dataset
from pyriemann_qiskit.classification import QuanticVQC
from pyriemann_qiskit.visualization import weights_spiral

Expand All @@ -25,9 +25,9 @@
# the variational quantum circuit which is used by VQC.
#
# The idea is simple :
# - We initialize a VQC with different number of parameters and number of samples
# - We train the VQC a couple of time and we store the fitted weights.
# - We compute variability of the weight and display it in a fashion way.
# - We initialize a VQC with different number of parameters and number of samples.
# - We train the VQC a couple of times and we store the fitted weights.
# - We compute variability of the weight and display it in a fashionable way.

# Let's start by defining some plot area.
fig, axes = plt.subplots(2, 2)
Expand All @@ -42,7 +42,7 @@
vqc = QuanticVQC(gen_var_form=gen_two_local(reps=n_reps))

# Get data. We will use a toy dataset here.
X, y = get_linearly_separable_dataset(n_samples=n_samples)
X, y = generate_linearly_separable_dataset(n_samples=n_samples)

# Compute and display weight variability after training
axe = axes[i, j]
Expand Down
8 changes: 4 additions & 4 deletions pyriemann_qiskit/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from .utils import (
get_mne_sample,
get_linearly_separable_dataset,
get_qiskit_dataset,
generate_linearly_separable_dataset,
generate_qiskit_dataset,
get_feature_dimension,
MockDataset,
)


__all__ = [
"get_mne_sample",
"get_linearly_separable_dataset",
"get_qiskit_dataset",
"generate_linearly_separable_dataset",
"generate_qiskit_dataset",
"get_feature_dimension",
"MockDataset",
]
7 changes: 5 additions & 2 deletions pyriemann_qiskit/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,15 @@ def get_mne_sample(n_trials=10, include_auditory=False):
return X, y


def get_qiskit_dataset(n_samples=30):
def generate_qiskit_dataset(n_samples=30):
"""Return qiskit dataset.
Notes
-----
.. versionadded:: 0.0.1
.. versionchanged:: 0.1.0
Added `n_samples` parameter.
Rename from `get_qiskit_dataset` to `generate_qiskit_dataset`.
Parameters
----------
Expand All @@ -132,14 +133,16 @@ def get_qiskit_dataset(n_samples=30):
return (X, y)


def get_linearly_separable_dataset(n_samples=100):
def generate_linearly_separable_dataset(n_samples=100):
"""Return a linearly separable dataset.
Notes
-----
.. versionadded:: 0.0.1
.. versionchanged:: 0.1.0
Added `n_samples` parameter.
Rename from `get_linearly_separable_dataset` to
`generate_linearly_separable_dataset`.
Parameters
----------
Expand Down

0 comments on commit 9a43800

Please sign in to comment.