Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix/quantum-art-example #184

Merged
merged 4 commits into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ Datasets
:toctree: generated/

get_mne_sample
get_linearly_separable_dataset
get_qiskit_dataset
generate_linearly_separable_dataset
generate_qiskit_dataset
get_feature_dimension
MockDataset

Expand Down
9 changes: 6 additions & 3 deletions examples/toys_dataset/plot_classifier_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import make_moons, make_circles
from sklearn.svm import SVC
from pyriemann_qiskit.datasets import get_linearly_separable_dataset, get_qiskit_dataset
from pyriemann_qiskit.datasets import (
generate_linearly_separable_dataset,
generate_qiskit_dataset,
)
from pyriemann_qiskit.classification import QuanticSVM

# uncomment to run comparison with QuanticVQC (disabled for CI/CD)
Expand Down Expand Up @@ -51,8 +54,8 @@
datasets = [
make_moons(noise=0.3, random_state=0),
make_circles(noise=0.2, factor=0.5, random_state=1),
get_linearly_separable_dataset(),
get_qiskit_dataset(),
generate_linearly_separable_dataset(),
generate_qiskit_dataset(),
]

figure = plt.figure(figsize=(15, 9))
Expand Down
10 changes: 5 additions & 5 deletions examples/toys_dataset/plot_quantum_art_vqc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from pyriemann_qiskit.utils.hyper_params_factory import gen_two_local
import matplotlib.pyplot as plt
from pyriemann_qiskit.datasets import get_linearly_separable_dataset
from pyriemann_qiskit.datasets import generate_linearly_separable_dataset
from pyriemann_qiskit.classification import QuanticVQC
from pyriemann_qiskit.visualization import weights_spiral

Expand All @@ -25,9 +25,9 @@
# the variational quantum circuit which is used by VQC.
#
# The idea is simple :
# - We initialize a VQC with different number of parameters and number of samples
# - We train the VQC a couple of time and we store the fitted weights.
# - We compute variability of the weight and display it in a fashion way.
# - We initialize a VQC with different number of parameters and number of samples.
# - We train the VQC a couple of times and we store the fitted weights.
# - We compute variability of the weight and display it in a fashionable way.

# Let's start by defining some plot area.
fig, axes = plt.subplots(2, 2)
Expand All @@ -42,7 +42,7 @@
vqc = QuanticVQC(gen_var_form=gen_two_local(reps=n_reps))

# Get data. We will use a toy dataset here.
X, y = get_linearly_separable_dataset(n_samples=n_samples)
X, y = generate_linearly_separable_dataset(n_samples=n_samples)

# Compute and display weight variability after training
axe = axes[i, j]
Expand Down
8 changes: 4 additions & 4 deletions pyriemann_qiskit/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from .utils import (
get_mne_sample,
get_linearly_separable_dataset,
get_qiskit_dataset,
generate_linearly_separable_dataset,
generate_qiskit_dataset,
get_feature_dimension,
MockDataset,
)


__all__ = [
"get_mne_sample",
"get_linearly_separable_dataset",
"get_qiskit_dataset",
"generate_linearly_separable_dataset",
"generate_qiskit_dataset",
"get_feature_dimension",
"MockDataset",
]
7 changes: 5 additions & 2 deletions pyriemann_qiskit/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,15 @@ def get_mne_sample(n_trials=10, include_auditory=False):
return X, y


def get_qiskit_dataset(n_samples=30):
def generate_qiskit_dataset(n_samples=30):
"""Return qiskit dataset.

Notes
-----
.. versionadded:: 0.0.1
.. versionchanged:: 0.1.0
Added `n_samples` parameter.
Rename from `get_qiskit_dataset` to `generate_qiskit_dataset`.

Parameters
----------
Expand All @@ -132,14 +133,16 @@ def get_qiskit_dataset(n_samples=30):
return (X, y)


def get_linearly_separable_dataset(n_samples=100):
def generate_linearly_separable_dataset(n_samples=100):
"""Return a linearly separable dataset.

Notes
-----
.. versionadded:: 0.0.1
.. versionchanged:: 0.1.0
Added `n_samples` parameter.
Rename from `get_linearly_separable_dataset` to
`generate_linearly_separable_dataset`.

Parameters
----------
Expand Down