Skip to content

Commit

Permalink
Merge pull request #151 from dtlics/hadamard
Browse files Browse the repository at this point in the history
Hadamard
  • Loading branch information
Hanrui-Wang authored Jun 14, 2023
2 parents bb297ad + 92e7894 commit 094991a
Show file tree
Hide file tree
Showing 7 changed files with 440 additions and 61 deletions.
107 changes: 107 additions & 0 deletions examples/hadamard_grad/circ.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import torchquantum as tq

op_types=(
tq.Hadamard,
tq.SHadamard,
tq.PauliX,
tq.PauliY,
tq.PauliZ,
tq.S,
tq.T,
tq.SX,
tq.CNOT,
)

class Circ1(tq.QuantumModule):
def __init__(self):
super().__init__()
self.n_wires = 4
self.gate1 = tq.operator.OpPauliExp(coeffs=[1.0], paulis=["YXIX"], theta=0.5, trainable=True)

def forward(self):
qdev = tq.QuantumDevice(
n_wires=self.n_wires, bsz=1, device='cpu', record_op=True
)
self.gate1(qdev, wires=[0, 1, 2, 3])
expval = tq.measurement.expval_joint_analytical(qdev, observable="ZZZZ")

return expval, qdev

class Circ2(tq.QuantumModule):
def __init__(self):
super().__init__()
self.n_wires = 4
self.gate1 = tq.operator.OpPauliExp(coeffs=[1.0, 0.5], paulis=["YXIX", "ZIXZ"], theta=0.5, trainable=True)

def forward(self):
qdev = tq.QuantumDevice(
n_wires=self.n_wires, bsz=1, device='cpu', record_op=True
)
self.gate1(qdev, wires=[0, 1, 2, 3])
expval = tq.measurement.expval_joint_analytical(qdev, observable="ZZZZ")

return expval, qdev

class Circ3(tq.QuantumModule):
def __init__(self):
super().__init__()
self.n_wires = 4
self.gate1 = tq.operator.OpPauliExp(coeffs=[1.0], paulis=["XIII"], theta=0.5, trainable=True) # Note this is RX gate
self.gate2 = tq.operator.OpPauliExp(coeffs=[0.722], paulis=["YXIX"], theta=0.5, trainable=True)
self.gate3 = tq.operator.OpPauliExp(coeffs=[1.0, 0.5], paulis=["YXIX", "ZIXZ"], theta=0.2, trainable=True)
self.gate4 = tq.operator.OpPauliExp(coeffs=[1.0], paulis=["YYYY"], theta=0.5, trainable=True)

self.random_layer1 = tq.RandomLayer(op_types=op_types, n_ops=50, wires=list(range(self.n_wires)))
self.random_layer2 = tq.RandomLayer(op_types=op_types, n_ops=50, wires=list(range(self.n_wires)))
self.random_layer3 = tq.RandomLayer(op_types=op_types, n_ops=50, wires=list(range(self.n_wires)))
self.random_layer4 = tq.RandomLayer(op_types=op_types, n_ops=50, wires=list(range(self.n_wires)))
self.random_layer5 = tq.RandomLayer(op_types=op_types, n_ops=50, wires=list(range(self.n_wires)))


def forward(self):
qdev = tq.QuantumDevice(
n_wires=self.n_wires, bsz=1, device='cpu', record_op=True
)

qdev.h(wires=1)
qdev.h(wires=3)
qdev.cnot(wires=[1, 0])
qdev.cnot(wires=[3, 2])

self.random_layer1(qdev)
self.gate1(qdev, wires=[0, 1, 2, 3])
self.random_layer2(qdev)
self.gate2(qdev, wires=[0, 1, 2, 3])
self.random_layer3(qdev)
self.gate3(qdev, wires=[0, 1, 2, 3])
self.random_layer4(qdev)
self.gate4(qdev, wires=[0, 1, 2, 3])
self.random_layer5(qdev)

expval = tq.measurement.expval_joint_analytical(qdev, observable="ZZZZ")

return expval, qdev


if __name__ == '__main__':

circ1 = Circ1()
expval1, qdev1 = circ1()
print('expval:')
print(expval1)
print('op_history:')
print(qdev1.op_history)

circ2 = Circ2()
expval2, qdev2 = circ2()
print('expval:')
print(expval2)
print('op_history:')
print(qdev2.op_history)

circ3 = Circ3()
expval3, qdev3 = circ3()
print('expval:')
print(expval3)
print('op_history:')
print(qdev3.op_history)
144 changes: 144 additions & 0 deletions examples/hadamard_grad/example.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tutorial on Hadamard Test Based Gradient Estimation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Set Up\n",
"\n",
"We define three types of example circuits in ./circ.py and here we use Circ1 as an example"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from hadamard_grad import hadamard_grad\n",
"from circ import Circ1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Run Circuit\n",
"\n",
"We need to run the circuit first to record its operations in the q_device"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"expval:\n",
"tensor([0.8776], grad_fn=<SelectBackward0>)\n",
"op_history:\n",
"[{'name': 'OpPauliExp', 'wires': [0, 1, 2, 3], 'coeffs': [1.0], 'paulis': ['YXIX'], 'inverse': False, 'trainable': True, 'params': 0.5}]\n"
]
}
],
"source": [
"circ1 = Circ1()\n",
"expval1, qdev1 = circ1()\n",
"print('expval:')\n",
"print(expval1)\n",
"print('op_history:')\n",
"print(qdev1.op_history)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Gradient Estimation\n",
"\n",
"To facilitate torchquantum v0.1.7, we need to manually extrac the following three information from the q_device and circuit:\n",
"\n",
"- op_history: The history of quantum operations applied on the q_device.\n",
"- n_wires: The number of wires (quantum bits) in the q_device.\n",
"- observable: The observable of the original circuit, for which the gradients are computed."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"op_history = qdev1.op_history\n",
"n_wires = qdev1.n_wires\n",
"observable = 'ZZZZ'"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/vast/palmer/home.grace/dl2276/AllProjects/torchquantum/torchquantum/functional/func_controlled_unitary.py:99: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" unitary = torch.tensor(torch.zeros(2**n_wires, 2**n_wires, dtype=C_DTYPE))\n"
]
},
{
"data": {
"text/plain": [
"[tensor(-0.4794)]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hadamard_grad(op_history, n_wires, observable)\n",
"# -0.47942554"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10.11 ('tq')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "0e29a0e2839e50afa2f3a774fbdc59fa5f031cf090cdc0c9e6a9a24240713eb0"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
84 changes: 84 additions & 0 deletions examples/hadamard_grad/hadamard_grad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import torch
import torchquantum as tq
from torchquantum.measurement import expval_joint_analytical

def gradient_circuit(left, target, right, n_wires, observable):
'''
Compute the gradient for the target gate.
Parameters:
- left: The gates to the left of the target gate.
- target: The target gate for which the gradient is computed.
- right: The gates to the right of the target gate.
- n_wires: The number of wires (quantum bits) in the circuit.
- observable: The observable of the original circuit, for which the gradients are computed.
Returns:
- The computed gradient for the target gate, analytical
'''
ancilla_qubit = n_wires

gradient = 0
for coeff, pauli in zip(target['coeffs'], target['paulis']):

dev = tq.QuantumDevice(n_wires=n_wires+1, bsz=1, device="cpu") # use device='cuda' for GPU
# ancilla
dev.h(wires=ancilla_qubit)
# left
for op_info in left:
if 'coeffs' in op_info:
# Evolution
op = tq.operator.OpPauliExp(coeffs=op_info['coeffs'], paulis=op_info['paulis'], theta=op_info['params'], trainable=False)
op(dev, wires=op_info['wires'])
else:
# other gates
op = tq.QuantumModule.from_op_history([op_info])
op(dev)
# target
generator = tq.algorithm.Hamiltonian(coeffs=[1.0], paulis=[pauli]) # ZZZZ
dev.controlled_unitary(params=generator.matrix, c_wires=[ancilla_qubit], t_wires=target['wires'])
# ancilla
dev.h(wires=ancilla_qubit)
# right
for op_info in right:
if 'coeffs' in op_info:
# Evolution
op = tq.operator.OpPauliExp(coeffs=op_info['coeffs'], paulis=op_info['paulis'], theta=op_info['params'], trainable=False)
op(dev, wires=op_info['wires'])
else:
# other gates
op = tq.QuantumModule.from_op_history([op_info])
op(dev)

# measurement
original_measurement = observable # 'ZZZZ'
expval = expval_joint_analytical(dev, original_measurement+'Y')
gradient -= coeff * torch.mean(expval)

return gradient

def hadamard_grad(op_history, n_wires, observable):
'''
Return the gradients for parameters in the q_device.
Parameters:
- op_history: The history of quantum operations applied on the q_device.
- n_wires: The number of wires (quantum bits) in the q_device.
- observable: The observable of the original circuit, for which the gradients are computed.
Returns:
- A list of gradients, ordered as the list the operations in op_history
'''
gradient_list = []
for i, op in enumerate(op_history):
if not op['trainable']:
gradient_list.append(None)
else:
left = op_history[:i+1]
target = op
right = op_history[i+1:]
gradient_list.append(
gradient_circuit(left, target, right, n_wires, observable)
)

return gradient_list
54 changes: 0 additions & 54 deletions examples/hadamard_test/circ.py

This file was deleted.

Loading

0 comments on commit 094991a

Please sign in to comment.