Skip to content

Commit

Permalink
Merge pull request #236 from GenericP3rson/test_branch
Browse files Browse the repository at this point in the history
Test Updates + Bug Fixes (Updated Version of #206)
  • Loading branch information
Hanrui-Wang authored Feb 21, 2024
2 parents fefb10b + e4233ff commit 497320c
Show file tree
Hide file tree
Showing 19 changed files with 201 additions and 95 deletions.
33 changes: 32 additions & 1 deletion .github/workflows/functional_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v3
Expand All @@ -36,3 +36,34 @@ jobs:
- name: Test with pytest
run: |
python -m pytest -m "not skip"
- name: Install TorchQuantum
run: |
pip install --editable .
- name: Test Examples
run: |
python3 examples/qubit_rotation/qubit_rotation.py --epochs 1
python3 examples/vqe/vqe.py --epochs 1 --steps_per_epoch 1
python3 examples/train_unitary_prep/train_unitary_prep.py --epochs 1
python3 examples/train_state_prep/train_state_prep.py --epochs 1
python3 examples/superdense_coding/superdense_coding_torchquantum.py
python3 examples/regression/run_regression.py --epochs 1
python3 examples/param_shift_onchip_training/param_shift.py
python3 examples/mnist/mnist_2qubit_4class.py --epochs 1
python3 examples/hadamard_grad/circ.py
python3 examples/encoder_examples/encoder_8x2ry.py
python3 examples/converter_tq_qiskit/convert.py
python3 examples/amplitude_encoding_mnist/mnist_new.py --epochs 1
python3 examples/amplitude_encoding_mnist/mnist_example.py --epochs 1
python3 examples/PauliSumOp/pauli_sum_op.py
python3 examples/regression/new_run_regression.py --epochs 1
python3 examples/quanvolution/quanvolution_trainable_quantum_layer.py --epochs 1
python3 examples/grover/grover_example_sudoku.py
python3 examples/param_shift_onchip_training/param_shift.py
python3 examples/quanvolution/quanvolution.py --epochs 1
python3 examples/quantum_lstm/qlstm.py --epochs 1
python3 examples/qaoa/max_cut_backprop.py --steps 1
python3 examples/optimal_control/optimal_control.py --epochs 1
python3 examples/optimal_control/optimal_control_gaussian.py --epochs 1
python3 examples/optimal_control/optimal_control_multi_qubit.py --epochs 1
python3 examples/save_load_example/save_load.py
python3 examples/mnist/mnist.py --epochs 1
4 changes: 2 additions & 2 deletions examples/grover/grover_example_sudoku.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
"""

import torchquantum as tq
from torchquantum.algorithms import Grover
from torchquantum.algorithm import Grover


# To simplify the process, we can compile this set of comparisons into a list of clauses for convenience.
Expand Down Expand Up @@ -90,4 +90,4 @@ def XOR(input0, input1, output):
print("b = ", key[1])
print("c = ", key[2])
print("d = ", key[3])
print("")
print("")
66 changes: 34 additions & 32 deletions examples/mnist/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def main():
"--static", action="store_true", help="compute with " "static mode"
)
parser.add_argument("--pdb", action="store_true", help="debug with pdb")
parser.add_argument("--qiskit-simulation", action="store_true", help="run on a real quantum computer")
parser.add_argument(
"--wires-per-block", type=int, default=2, help="wires per block int static mode"
)
Expand Down Expand Up @@ -243,38 +244,39 @@ def main():
# test
valid_test(dataflow, "test", model, device, qiskit=False)

# run on Qiskit simulator and real Quantum Computers
try:
from qiskit import IBMQ
from torchquantum.plugin import QiskitProcessor

# firstly perform simulate
print(f"\nTest with Qiskit Simulator")
processor_simulation = QiskitProcessor(use_real_qc=False)
model.set_qiskit_processor(processor_simulation)
valid_test(dataflow, "test", model, device, qiskit=True)

# then try to run on REAL QC
backend_name = "ibmq_lima"
print(f"\nTest on Real Quantum Computer {backend_name}")
# Please specify your own hub group and project if you have the
# IBMQ premium plan to access more machines.
processor_real_qc = QiskitProcessor(
use_real_qc=True,
backend_name=backend_name,
hub="ibm-q",
group="open",
project="main",
)
model.set_qiskit_processor(processor_real_qc)
valid_test(dataflow, "test", model, device, qiskit=True)
except ImportError:
print(
"Please install qiskit, create an IBM Q Experience Account and "
"save the account token according to the instruction at "
"'https://github.com/Qiskit/qiskit-ibmq-provider', "
"then try again."
)
if args.qiskit_simulation:
# run on Qiskit simulator and real Quantum Computers
try:
from qiskit import IBMQ
from torchquantum.plugin import QiskitProcessor

# firstly perform simulate
print(f"\nTest with Qiskit Simulator")
processor_simulation = QiskitProcessor(use_real_qc=False)
model.set_qiskit_processor(processor_simulation)
valid_test(dataflow, "test", model, device, qiskit=True)

# then try to run on REAL QC
backend_name = "ibmq_lima"
print(f"\nTest on Real Quantum Computer {backend_name}")
# Please specify your own hub group and project if you have the
# IBMQ premium plan to access more machines.
processor_real_qc = QiskitProcessor(
use_real_qc=True,
backend_name=backend_name,
hub="ibm-q",
group="open",
project="main",
)
model.set_qiskit_processor(processor_real_qc)
valid_test(dataflow, "test", model, device, qiskit=True)
except ImportError:
print(
"Please install qiskit, create an IBM Q Experience Account and "
"save the account token according to the instruction at "
"'https://github.com/Qiskit/qiskit-ibmq-provider', "
"then try again."
)


if __name__ == "__main__":
Expand Down
19 changes: 15 additions & 4 deletions examples/optimal_control/optimal_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,22 @@
import torch.optim as optim

import torchquantum as tq
import pdb
import argparse
import numpy as np

if __name__ == "__main__":
pdb.set_trace()
parser = argparse.ArgumentParser()
parser.add_argument("--pdb", action="store_true", help="debug with pdb")
parser.add_argument(
"--epochs", type=int, default=1000, help="number of training epochs"
)

args = parser.parse_args()

if args.pdb:
import pdb
pdb.set_trace()

# target_unitary = torch.tensor([[0, 1], [1, 0]], dtype=torch.complex64)
theta = 0.6
target_unitary = torch.tensor(
Expand All @@ -41,11 +52,11 @@
dtype=torch.complex64,
)

pulse = tq.QuantumPulseDirect(n_steps=4, hamil=[[0, 1], [1, 0]])
pulse = tq.pulse.QuantumPulseDirect(n_steps=4, hamil=[[0, 1], [1, 0]])

optimizer = optim.Adam(params=pulse.parameters(), lr=5e-3)

for k in range(1000):
for k in range(args.epochs):
# loss = (abs(pulse.get_unitary() - target_unitary)**2).sum()
loss = (
1
Expand Down
19 changes: 15 additions & 4 deletions examples/optimal_control/optimal_control_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,22 @@
import torch.optim as optim

import torchquantum as tq
import pdb
import argparse
import numpy as np

if __name__ == "__main__":
pdb.set_trace()
parser = argparse.ArgumentParser()
parser.add_argument("--pdb", action="store_true", help="debug with pdb")
parser.add_argument(
"--epochs", type=int, default=1000, help="number of training epochs"
)

args = parser.parse_args()

if args.pdb:
import pdb
pdb.set_trace()

# target_unitary = torch.tensor([[0, 1], [1, 0]], dtype=torch.complex64)
theta = 1.1
target_unitary = torch.tensor(
Expand All @@ -41,11 +52,11 @@
dtype=torch.complex64,
)

pulse = tq.QuantumPulseGaussian(hamil=[[0, 1], [1, 0]])
pulse = tq.pulse.QuantumPulseGaussian(hamil=[[0, 1], [1, 0]])

optimizer = optim.Adam(params=pulse.parameters(), lr=5e-3)

for k in range(1000):
for k in range(args.epochs):
# loss = (abs(pulse.get_unitary() - target_unitary)**2).sum()
loss = (
1
Expand Down
23 changes: 17 additions & 6 deletions examples/optimal_control/optimal_control_multi_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,22 @@
import torch.optim as optim

import torchquantum as tq
import pdb
import argparse
import numpy as np

if __name__ == "__main__":
pdb.set_trace()
parser = argparse.ArgumentParser()
parser.add_argument("--pdb", action="store_true", help="debug with pdb")
parser.add_argument(
"--epochs", type=int, default=1000, help="number of training epochs"
)

args = parser.parse_args()

if args.pdb:
import pdb
pdb.set_trace()

# target_unitary = torch.tensor([[0, 1], [1, 0]], dtype=torch.complex64)
theta = 0.6
target_unitary = torch.tensor(
Expand All @@ -43,9 +54,9 @@
dtype=torch.complex64,
)

pulse_q0 = tq.QuantumPulseDirect(n_steps=10, hamil=[[0, 1], [1, 0]])
pulse_q1 = tq.QuantumPulseDirect(n_steps=10, hamil=[[0, 1], [1, 0]])
pulse_q01 = tq.QuantumPulseDirect(
pulse_q0 = tq.pulse.QuantumPulseDirect(n_steps=10, hamil=[[0, 1], [1, 0]])
pulse_q1 = tq.pulse.QuantumPulseDirect(n_steps=10, hamil=[[0, 1], [1, 0]])
pulse_q01 = tq.pulse.QuantumPulseDirect(
n_steps=10,
hamil=[
[1, 0, 0, 0],
Expand All @@ -62,7 +73,7 @@
lr=5e-3,
)

for k in range(1000):
for k in range(args.epochs):
u_0 = pulse_q0.get_unitary()
u_1 = pulse_q1.get_unitary()
u_01 = pulse_q01.get_unitary()
Expand Down
9 changes: 8 additions & 1 deletion examples/qaoa/max_cut_backprop.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import random
import numpy as np
import argparse

from torchquantum.functional import mat_dict

Expand Down Expand Up @@ -172,6 +173,12 @@ def backprop_optimize(model, n_steps=100, lr=0.1):


def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--steps", type=int, default=300, help="number of steps"
)
args = parser.parse_args()

# create a input_graph
input_graph = [(0, 1), (0, 3), (1, 2), (2, 3)]
n_wires = 4
Expand All @@ -184,7 +191,7 @@ def main():
# print("The circuit is", circ.draw(output="mpl"))
# circ.draw(output="mpl")
# use backprop
backprop_optimize(model, n_steps=300, lr=0.01)
backprop_optimize(model, n_steps=args.steps, lr=0.01)
# use parameter shift rule
# param_shift_optimize(model, n_steps=500, step_size=100000)

Expand Down
22 changes: 17 additions & 5 deletions examples/quantum_lstm/qlstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import torch.nn as nn
import torchquantum as tq
import torchquantum.functional as tqf
import argparse


class QLSTM(nn.Module):
Expand Down Expand Up @@ -358,6 +359,19 @@ def plot_history(history_classical, history_quantum):
plt.show()

def main():
parser = argparse.ArgumentParser()
parser.add_argument("--pdb", action="store_true", help="debug with pdb")
parser.add_argument("--display", action="store_true", help="display results with matplotlib")
parser.add_argument(
"--epochs", type=int, default=300, help="number of training epochs"
)

args = parser.parse_args()

if args.pdb:
import pdb
pdb.set_trace()

tag_to_ix = {"DET": 0, "NN": 1, "V": 2} # Assign each tag with a unique index
ix_to_tag = {i:k for k,i in tag_to_ix.items()}

Expand All @@ -380,7 +394,7 @@ def main():

embedding_dim = 8
hidden_dim = 6
n_epochs = 300
n_epochs = args.epochs

model_classical = LSTMTagger(embedding_dim,
hidden_dim,
Expand All @@ -404,10 +418,8 @@ def main():

print_result(model_quantum, training_data, word_to_ix, ix_to_tag)

plot_history(history_classical, history_quantum)
if args.display:
plot_history(history_classical, history_quantum)

if __name__ == "__main__":
import pdb
pdb.set_trace()

main()
Loading

0 comments on commit 497320c

Please sign in to comment.