diff --git a/test/layers/test_rotgate.py b/test/layers/test_rotgate.py new file mode 100644 index 00000000..beec6d72 --- /dev/null +++ b/test/layers/test_rotgate.py @@ -0,0 +1,54 @@ +import torchquantum as tq +import qiskit +from qiskit import Aer, execute + +from torchquantum.util import ( + switch_little_big_endian_matrix, + find_global_phase, +) + +from qiskit.circuit.library import GR, GRX, GRY +import numpy as np + +all_pairs = [ + # {"qiskit": GR, "tq": tq.layer.GlobalR, "params": 2}, + {"qiskit": GRX, "tq": tq.layer.GlobalRX, "params": 1}, + {"qiskit": GRY, "tq": tq.layer.GlobalRY, "params": 1}, +] + +ITERATIONS = 10 + +# test each pair +for pair in all_pairs: + # test 2-5 wires + for num_wires in range(2, 5): + # try multiple random parameters + for _ in range(ITERATIONS): + # generate random parameters + params = [ + np.random.uniform(-2 * np.pi, 2 * np.pi) for _ in range(pair["params"]) + ] + + # create the qiskit circuit + qiskit_circuit = pair["qiskit"](num_wires, *params) + + # get the unitary from qiskit + backend = Aer.get_backend("unitary_simulator") + result = execute(qiskit_circuit, backend).result() + unitary_qiskit = result.get_unitary(qiskit_circuit) + + # create tq circuit + qdev = tq.QuantumDevice(num_wires) + tq_circuit = pair["tq"](num_wires, *params) + tq_circuit(qdev) + + # get the unitary from tq + unitary_tq = tq_circuit.get_unitary(qdev) + unitary_tq = switch_little_big_endian_matrix(unitary_tq.data.numpy()) + + # phase? + phase = find_global_phase(unitary_tq, unitary_qiskit, 1e-4) + + assert np.allclose( + unitary_tq * phase, unitary_qiskit, atol=1e-6 + ), f"{pair} not equal with {params=}!" diff --git a/torchquantum/layer/general.py b/torchquantum/layer/general.py index 0a61aa29..50b2b0c1 100644 --- a/torchquantum/layer/general.py +++ b/torchquantum/layer/general.py @@ -31,12 +31,12 @@ Op2QAllLayer, RandomOp1All, ) +from torchquantum.operator.operators import R __all__ = [ "GlobalR", "GlobalRX", "GlobalRY", - "GlobalRZ", ] @@ -51,17 +51,13 @@ def __init__( ): """Create the layer""" super().__init__() - self.ops_all = tq.QuantumModuleList() self.n_wires = n_wires - self.ops_list = [ - {"name": "rot", "params": [phi, theta, 0], "wires": k} - for k in range(self.n_wires) - ] + self.params = torch.tensor([[theta, phi]]) @tq.static_support - def forward(self, q_device): - qmodule = tq.QuantumModule.from_op_history(self.ops_list) - qmodule(q_device) + def forward(self, q_device, x=None): + for k in range(self.n_wires): + R()(q_device, wires=k, params=self.params) class GlobalRX(GlobalR):