Skip to content

Commit

Permalink
fixed aliasing issue and added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
01110011011101010110010001101111 committed Nov 18, 2023
1 parent a24f0cf commit 43109a9
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 9 deletions.
54 changes: 54 additions & 0 deletions test/layers/test_rotgate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import torchquantum as tq
import qiskit
from qiskit import Aer, execute

from torchquantum.util import (
switch_little_big_endian_matrix,
find_global_phase,
)

from qiskit.circuit.library import GR, GRX, GRY
import numpy as np

all_pairs = [
# {"qiskit": GR, "tq": tq.layer.GlobalR, "params": 2},
{"qiskit": GRX, "tq": tq.layer.GlobalRX, "params": 1},
{"qiskit": GRY, "tq": tq.layer.GlobalRY, "params": 1},
]

ITERATIONS = 10

# test each pair
for pair in all_pairs:
# test 2-5 wires
for num_wires in range(2, 5):
# try multiple random parameters
for _ in range(ITERATIONS):
# generate random parameters
params = [
np.random.uniform(-2 * np.pi, 2 * np.pi) for _ in range(pair["params"])
]

# create the qiskit circuit
qiskit_circuit = pair["qiskit"](num_wires, *params)

# get the unitary from qiskit
backend = Aer.get_backend("unitary_simulator")
result = execute(qiskit_circuit, backend).result()
unitary_qiskit = result.get_unitary(qiskit_circuit)

# create tq circuit
qdev = tq.QuantumDevice(num_wires)
tq_circuit = pair["tq"](num_wires, *params)
tq_circuit(qdev)

# get the unitary from tq
unitary_tq = tq_circuit.get_unitary(qdev)
unitary_tq = switch_little_big_endian_matrix(unitary_tq.data.numpy())

# phase?
phase = find_global_phase(unitary_tq, unitary_qiskit, 1e-4)

assert np.allclose(
unitary_tq * phase, unitary_qiskit, atol=1e-6
), f"{pair} not equal with {params=}!"
14 changes: 5 additions & 9 deletions torchquantum/layer/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@
Op2QAllLayer,
RandomOp1All,
)
from torchquantum.operator.operators import R

__all__ = [
"GlobalR",
"GlobalRX",
"GlobalRY",
"GlobalRZ",
]


Expand All @@ -51,17 +51,13 @@ def __init__(
):
"""Create the layer"""
super().__init__()
self.ops_all = tq.QuantumModuleList()
self.n_wires = n_wires
self.ops_list = [
{"name": "rot", "params": [phi, theta, 0], "wires": k}
for k in range(self.n_wires)
]
self.params = torch.tensor([[theta, phi]])

@tq.static_support
def forward(self, q_device):
qmodule = tq.QuantumModule.from_op_history(self.ops_list)
qmodule(q_device)
def forward(self, q_device, x=None):
for k in range(self.n_wires):
R()(q_device, wires=k, params=self.params)


class GlobalRX(GlobalR):
Expand Down

0 comments on commit 43109a9

Please sign in to comment.