Skip to content

Commit

Permalink
Merge pull request #270 from AbdullahKazi500/AbdullahKazi500-patch-2
Browse files Browse the repository at this point in the history
  • Loading branch information
01110011011101010110010001101111 authored Jun 13, 2024
2 parents eda17c1 + d433bbe commit 6b30997
Show file tree
Hide file tree
Showing 5 changed files with 766 additions and 0 deletions.
74 changes: 74 additions & 0 deletions examples/QuantumGan/ README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Quantum Generative Adversarial Network (QGAN) Example

This repository contains an example implementation of a Quantum Generative Adversarial Network (QGAN) using PyTorch and TorchQuantum. The example is provided in a Jupyter Notebook for interactive exploration.

## Overview

A QGAN consists of two main components:

1. **Generator:** This network generates fake quantum data samples.
2. **Discriminator:** This network tries to distinguish between real and fake quantum data samples.

The goal is to train the generator to produce quantum data that is indistinguishable from real data, according to the discriminator. This is achieved through an adversarial training process, where the generator and discriminator are trained simultaneously in a competitive manner.

## Repository Contents

- `qgan_notebook.ipynb`: Jupyter Notebook demonstrating the QGAN implementation.
- `qgan_script.py`: Python script containing the QGAN model and a main function for initializing the model with command-line arguments.

## Installation

To run the examples, you need to have the following dependencies installed:

- Python 3
- PyTorch
- TorchQuantum
- Jupyter Notebook
- ipywidgets

You can install the required Python packages using pip:

```bash
pip install torch torchquantum jupyter ipywidgets
```


Running the Examples
Jupyter Notebook
Open the qgan_notebook.ipynb file in Jupyter Notebook.
Execute the notebook cells to see the QGAN model in action.
Python Script
You can also run the QGAN model using the Python script. The script uses argparse to handle command-line arguments.

bash
Copy code
python qgan_script.py <n_qubits> <latent_dim>
Replace <n_qubits> and <latent_dim> with the desired number of qubits and latent dimensions.

Notebook Details
The Jupyter Notebook is structured as follows:

Introduction: Provides an overview of the QGAN and its components.
Import Libraries: Imports the necessary libraries, including PyTorch and TorchQuantum.
Generator Class: Defines the quantum generator model.
Discriminator Class: Defines the quantum discriminator model.
QGAN Class: Combines the generator and discriminator into a single QGAN model.
Main Function: Initializes the QGAN model and prints its structure.
Interactive Model Creation: Uses ipywidgets to create an interactive interface for adjusting the number of qubits and latent dimensions.
Understanding QGANs
QGANs are a type of Generative Adversarial Network (GAN) that operate in the quantum domain. They leverage quantum circuits to generate and evaluate data samples. The adversarial training process involves two competing networks:

The Generator creates fake quantum data samples from a latent space.
The Discriminator attempts to distinguish these fake samples from real quantum data.
Through training, the generator improves its ability to create realistic quantum data, while the discriminator enhances its ability to identify fake data. This process results in a generator that can produce high-quality quantum data samples.


## QGAN Implementation for CIFAR-10 Dataset
This implementation trains a QGAN on the CIFAR-10 dataset to generate fake images. It follows a similar structure to the TorchQuantum QGAN, with the addition of data loading and processing specific to the CIFAR-10 dataset.
Generated images can be seen in the folder

This `README.md` file explains the purpose of the repository, the structure of the notebook, and how to run the examples, along with a brief overview of the QGAN concept for those unfamiliar with it.


## Reference
- [ ] https://arxiv.org/abs/2312.09939
84 changes: 84 additions & 0 deletions examples/QuantumGan/QGan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import torchquantum as tq

class Generator(nn.Module):
def __init__(self, n_qubits: int, latent_dim: int):
super().__init__()
self.n_qubits = n_qubits
self.latent_dim = latent_dim

# Quantum encoder
self.encoder = tq.GeneralEncoder([
{'input_idx': [i], 'func': 'rx', 'wires': [i]}
for i in range(self.n_qubits)
])

# RX gates
self.rxs = nn.ModuleList([
tq.RX(has_params=True, trainable=True) for _ in range(self.n_qubits)
])

def forward(self, x):
qdev = tq.QuantumDevice(n_wires=self.n_qubits, bsz=x.shape[0], device=x.device)
self.encoder(qdev, x)

for i in range(self.n_qubits):
self.rxs[i](qdev, wires=i)

return tq.measure(qdev)

class Discriminator(nn.Module):
def __init__(self, n_qubits: int):
super().__init__()
self.n_qubits = n_qubits

# Quantum encoder
self.encoder = tq.GeneralEncoder([
{'input_idx': [i], 'func': 'rx', 'wires': [i]}
for i in range(self.n_qubits)
])

# RX gates
self.rxs = nn.ModuleList([
tq.RX(has_params=True, trainable=True) for _ in range(self.n_qubits)
])

# Quantum measurement
self.measure = tq.MeasureAll(tq.PauliZ)

def forward(self, x):
qdev = tq.QuantumDevice(n_wires=self.n_qubits, bsz=x.shape[0], device=x.device)
self.encoder(qdev, x)

for i in range(self.n_qubits):
self.rxs[i](qdev, wires=i)

return self.measure(qdev)

class QGAN(nn.Module):
def __init__(self, n_qubits: int, latent_dim: int):
super().__init__()
self.generator = Generator(n_qubits, latent_dim)
self.discriminator = Discriminator(n_qubits)

def forward(self, z):
fake_data = self.generator(z)
fake_output = self.discriminator(fake_data)
return fake_output

def main(n_qubits, latent_dim):
model = QGAN(n_qubits, latent_dim)
print(model)

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Quantum Generative Adversarial Network (QGAN) Example")
parser.add_argument('n_qubits', type=int, help='Number of qubits')
parser.add_argument('latent_dim', type=int, help='Dimension of the latent space')

args = parser.parse_args()

main(args.n_qubits, args.latent_dim)

Loading

0 comments on commit 6b30997

Please sign in to comment.