Skip to content

Commit

Permalink
Take joint expval in Measure functions, add test
Browse files Browse the repository at this point in the history
  • Loading branch information
yannick-couzinie committed Nov 20, 2024
1 parent 6ff80a8 commit 1da8350
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 9 deletions.
57 changes: 57 additions & 0 deletions test/measurement/test_measuremultipletimes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import numpy as np
import torchquantum as tq


def test_non_trivial_pauli_expectation():
class TestCircuit(tq.QuantumModule):
def __init__(self):
super().__init__()

self.meas = tq.measurement.MeasureMultipleTimes([
{'wires': range(2), 'observables': 'xx'},
{'wires': range(2), 'observables': 'yy'},
{'wires': range(2), 'observables': 'zz'},
])

def forward(self, qdev: tq.QuantumDevice):
"""
Prepare and measure the expexctation value of the state
exp(-i pi/8)/sqrt(2) * (cos pi/12,
-i sin pi/12,
-i sin pi/12 * exp(i pi/4),
-i sin pi/12 * exp(i pi/4))
"""
# prepare bell state
tq.h(qdev, 0)
tq.cnot(qdev, [0, 1])

# add some phases
tq.rz(qdev, wires=0, params=np.pi / 4)
tq.rx(qdev, wires=1, params=np.pi / 6)
return self.meas(qdev)

test_circuit = TestCircuit()
qdev = tq.QuantumDevice(bsz=1, n_wires=2) # Batch size 1 for testing

# Run the circuit
meas_results = test_circuit(qdev)[0]

# analytical results for XX, YY, ZZ expval respectively
expval_xx = np.cos(np.pi / 4)
expval_yy = -np.cos(np.pi / 4) * np.cos(np.pi / 6)
expval_zz = np.cos(np.pi / 6)

atol = 1e-6

assert np.isclose(meas_results[0].item(), expval_xx, atol=atol), \
f"Expected {expval_xx}, got {meas_results[0].item()}"
assert np.isclose(meas_results[1].item(), expval_yy, atol=atol), \
f"Expected {expval_yy}, got {meas_results[1].item()}"
assert np.isclose(meas_results[2].item(), expval_zz, atol=atol), \
f"Expected {expval_zz}, got {meas_results[2].item()}"

print("Test passed!")


if __name__ == "__main__":
test_non_trivial_pauli_expectation()
20 changes: 11 additions & 9 deletions torchquantum/measurement/measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,15 +367,14 @@ def forward(self, qdev: tq.QuantumDevice):

observables = []
for wire in range(qdev.n_wires):
observables.append(tq.I())
observables.append("I")

for wire, observable in zip(layer["wires"], layer["observables"]):
observables[wire] = tq.op_name_dict[observable]()
observables[wire] = observable

res = expval(
res = expval_joint_analytical(
qdev_new,
wires=list(range(qdev.n_wires)),
observables=observables,
observable="".join(observables),
)

if self.v_c_reg_mapping is not None:
Expand All @@ -390,7 +389,8 @@ def forward(self, qdev: tq.QuantumDevice):
res = res[:, perm]
res_all.append(res)

return torch.cat(res_all)

return torch.stack(res_all, dim=-1)

def set_v_c_reg_mapping(self, mapping):
self.v_c_reg_mapping = mapping
Expand Down Expand Up @@ -421,7 +421,8 @@ def __init__(self, obs_list, v_c_reg_mapping=None):
)

def forward(self, qdev: tq.QuantumDevice):
res_all = self.measure_multiple_times(qdev).prod(-1)
# returns batch x len(obs_list) object, return sum
res_all = self.measure_multiple_times(qdev)

return res_all.sum(-1)

Expand Down Expand Up @@ -449,8 +450,9 @@ def __init__(self, obs_list, v_c_reg_mapping=None):
)

def forward(self, qdev: tq.QuantumDevice):
res_all = self.measure_multiple_times(qdev).prod(-1)

# returns batch x len(obs_list) object, return sum times coefficient
res_all = self.measure_multiple_times(qdev)

return (res_all * torch.tensor(self.obs_list[0]["coefficient"])).sum(-1)


Expand Down

0 comments on commit 1da8350

Please sign in to comment.