Skip to content

Commit

Permalink
Added QCBM algorithm with example
Browse files Browse the repository at this point in the history
  • Loading branch information
Gopal-Dahale committed May 31, 2024
1 parent c8a1ab6 commit b0d8446
Show file tree
Hide file tree
Showing 5 changed files with 199 additions and 0 deletions.
8 changes: 8 additions & 0 deletions examples/QCBM/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Quantum Circuit Born Machine

Quantum Circuit Born Machine (QCBM) [1] is a generative modeling algorithm which uses Born rule from quantum mechanics to sample from a quantum state $|\psi \rangle$ learned by training an ansatz $U(\theta)$ [1][2]. In this tutorial we show how `torchquantum` can be used to model a Gaussian mixture with QCBM.

## References

1. Liu, Jin-Guo, and Lei Wang. “Differentiable learning of quantum circuit born machines.” Physical Review A 98.6 (2018): 062324.
2. Gili, Kaitlin, et al. "Do quantum circuit born machines generalize?." Quantum Science and Technology 8.3 (2023): 035021.
61 changes: 61 additions & 0 deletions examples/QCBM/qcbm_gaussian_mixture.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torchquantum.algorithm import QCBM, MMDLoss
import torchquantum as tq


# Function to create a gaussian mixture
def gaussian_mixture_pdf(x, mus, sigmas):
mus, sigmas = np.array(mus), np.array(sigmas)
vars = sigmas**2
values = [
(1 / np.sqrt(2 * np.pi * v)) * np.exp(-((x - m) ** 2) / (2 * v))
for m, v in zip(mus, vars)
]
values = np.sum([val / sum(val) for val in values], axis=0)
return values / np.sum(values)

# Create a gaussian mixture
n_wires = 6
x_max = 2**n_wires
x_input = np.arange(x_max)
mus = [(2 / 8) * x_max, (5 / 8) * x_max]
sigmas = [x_max / 10] * 2
data = gaussian_mixture_pdf(x_input, mus, sigmas)

# This is the target distribution that the QCBM will learn
target_probs = torch.tensor(data, dtype=torch.float32)

# Ansatz
layers = tq.RXYZCXLayer0({"n_blocks": 6, "n_wires": n_wires, "n_layers_per_block": 1})

qcbm = QCBM(n_wires, layers)

# To train QCBMs, we use MMDLoss with radial basis function kernel.
bandwidth = torch.tensor([0.25, 60])
space = torch.arange(2**n_wires)
mmd = MMDLoss(bandwidth, space)

# Optimization
optimizer = torch.optim.Adam(qcbm.parameters(), lr=0.01)
for i in range(100):
optimizer.zero_grad(set_to_none=True)
pred_probs = qcbm()
loss = mmd(pred_probs, target_probs)
loss.backward()
optimizer.step()
print(i, loss.item())

# Visualize the results
with torch.no_grad():
pred_probs = qcbm()

plt.plot(x_input, target_probs, linestyle="-.", label=r"$\pi(x)$")
plt.bar(x_input, pred_probs, color="green", alpha=0.5, label="samples")
plt.xlabel("Samples")
plt.ylabel("Prob. Distribution")

plt.legend()
plt.show()
33 changes: 33 additions & 0 deletions test/algorithm/test_qcbm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from torchquantum.algorithm.qcbm import QCBM, MMDLoss
import torchquantum as tq
import torch
import pytest


def test_qcbm_forward():
n_wires = 3
n_layers = 3
ops = []
for l in range(n_layers):
for q in range(n_wires):
ops.append({"name": "rx", "wires": q, "params": 0.0, "trainable": True})
for q in range(n_wires - 1):
ops.append({"name": "cnot", "wires": [q, q + 1]})

data = torch.ones(2**n_wires)
qmodule = tq.QuantumModule.from_op_history(ops)
qcbm = QCBM(n_wires, qmodule)
probs = qcbm()
expected = torch.tensor([1.0, 0, 0, 0, 0, 0, 0, 0])
assert torch.allclose(probs, expected)


def test_mmd_loss():
n_wires = 2
bandwidth = torch.tensor([0.1, 1.0])
space = torch.arange(2**n_wires)

mmd = MMDLoss(bandwidth, space)
loss = mmd(torch.zeros(4), torch.zeros(4))
print(loss)
assert torch.isclose(loss, torch.tensor(0.0), rtol=1e-5)
1 change: 1 addition & 0 deletions torchquantum/algorithm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@
from .hamiltonian import *
from .qft import *
from .grover import *
from .qcbm import *
96 changes: 96 additions & 0 deletions torchquantum/algorithm/qcbm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import torch
import torch.nn as nn

import torchquantum as tq

__all__ = ["QCBM", "MMDLoss"]


class MMDLoss(nn.Module):
"""Squared maximum mean discrepancy with radial basis function kerne"""

def __init__(self, scales, space):
"""
Initialize MMDLoss object. Calculates and stores the kernel matrix.
Args:
scales: Bandwidth parameters.
space: Basis input space.
"""
super().__init__()

gammas = 1 / (2 * (scales**2))

# squared Euclidean distance
sq_dists = torch.abs(space[:, None] - space[None, :]) ** 2

# Kernel matrix
self.K = sum(torch.exp(-gamma * sq_dists) for gamma in gammas) / len(scales)
self.scales = scales

def k_expval(self, px, py):
"""
Kernel expectation value
Args:
px: First probability distribution
py: Second probability distribution
Returns:
Expectation value of the RBF Kernel.
"""

return px @ self.K @ py

def forward(self, px, py):
"""
Squared MMD loss.
px: First probability distribution
py: Second probability distribution
Returns:
Squared MMD loss.
"""
pxy = px - py
return self.k_expval(pxy, pxy)


class QCBM(nn.Module):
"""
Quantum Circuit Born Machine (QCBM)
Attributes:
ansatz: An Ansatz object
n_wires: Number of wires in the ansatz used.
Methods:
__init__: Initialize the QCBM object.
forward: Returns the probability distribution (output from measurement).
"""

def __init__(self, n_wires, ansatz):
"""
Initialize QCBM object
Args:
ansatz (Ansatz): An Ansatz object
n_wires (int): Number of wires in the ansatz used.
"""
super().__init__()

self.ansatz = ansatz
self.n_wires = n_wires

def forward(self):
"""
Execute and obtain the probability distribution
Returns:
Probabilities (torch.Tensor)
"""
qdev = tq.QuantumDevice(n_wires=self.n_wires, bsz=1, device="cpu")
self.ansatz(qdev)
probs = torch.abs(qdev.states.flatten()) ** 2
return probs

0 comments on commit b0d8446

Please sign in to comment.