This repo contains the Diffusion Transformer (DiT) from the paper Scalable Diffusion Models with Transformers. [arxiv] [code]. It was created out of interest in combining diffusion models with transformers. The network code is largely based on the official implementation from Meta AI; I made several changes to the model, adding newer techniques and tricks.
You can recreate the conda environment from the provided `environment.yml`:

```bash
conda env create -f environment.yml
conda activate dit
```
You can sample from the model using the `Sampler` class in `sampler.py`:

```python
from sampler import Sampler

sampler = Sampler(model)      # model: a trained DiT
samples = sampler.sample(10)  # sample 10 images, shape [10, C, H, W]
```
For visualization, I use `moviepy` to generate GIFs from the intermediate steps of the sampling process.
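A minimal sketch of how such a GIF could be produced, assuming the intermediate steps are collected as a list of `[C, H, W]` float tensors in `[0, 1]` (the function name and tensor layout here are illustrative, not the repo's actual code):

```python
import torch
from moviepy.editor import ImageSequenceClip  # moviepy 1.x import path

def save_sampling_gif(steps, path="sampling.gif", fps=10):
    # steps: list of [C, H, W] float tensors in [0, 1] (assumed layout)
    frames = [
        (s.clamp(0, 1) * 255).byte().permute(1, 2, 0).cpu().numpy()
        for s in steps
    ]  # convert each step to an [H, W, C] uint8 array
    ImageSequenceClip(frames, fps=fps).write_gif(path)
```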
To train a DiT, you can use the `train.py` script:

```bash
python train.py
```
I use `wandb` for logging. If you don't want to use it, you can remove the logger in `train.py`. You can also check the log of the latest training run on wandb here. I have trained the model on the CIFAR-10 dataset for 200k steps. Judging by the FID, the model has not converged yet; more training may lead to better results.
Instead of the standard Gaussian diffusion, I use the Flow Matching technique from the paper Flow Matching for Generative Modeling, via the `torchcfm` library. This technique helps sample images faster and more efficiently.
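As an illustration, here is a minimal flow-matching training step built on `torchcfm`'s `ConditionalFlowMatcher` (the model signature and variable names are assumptions for this sketch; the repo's actual training loop may differ):

```python
import torch
from torchcfm.conditional_flow_matching import ConditionalFlowMatcher

fm = ConditionalFlowMatcher(sigma=0.0)

def flow_matching_loss(model, x1, y):
    # x1: batch of real images; x0: pure noise at t = 0 (assumed convention)
    x0 = torch.randn_like(x1)
    # Sample a time t, a point x_t on the probability path from x0 to x1,
    # and the target velocity u_t for that point.
    t, xt, ut = fm.sample_location_and_conditional_flow(x0, x1)
    vt = model(xt, t, y)              # model predicts the velocity field
    return torch.mean((vt - ut) ** 2)
```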
I implement logit-normal sampling for the timesteps, a technique used in the Stable Diffusion 3 research paper. It biases the sampled timesteps toward intermediate values during training of the diffusion model.
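In code, logit-normal sampling is just a Gaussian squashed through a sigmoid; a minimal sketch (the mean/std defaults below follow the SD3 paper's baseline, not necessarily this repo's settings):

```python
import torch

def sample_timesteps(batch_size, mean=0.0, std=1.0, device="cpu"):
    # Draw n ~ N(mean, std), then map through a sigmoid so t lies in (0, 1).
    # Probability mass concentrates around t = 0.5, biasing training
    # toward intermediate timesteps.
    n = torch.randn(batch_size, device=device) * std + mean
    return torch.sigmoid(n)
```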
I add register tokens to the transformer, following the paper Vision Transformers Need Registers, which shows that adding a few extra learnable tokens can improve the performance of the model.
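A minimal sketch of the idea, assuming a standard token-based transformer (the class name and number of registers are illustrative):

```python
import torch
import torch.nn as nn

class RegisterTokens(nn.Module):
    # Prepend learnable register tokens to the patch tokens before the
    # transformer blocks and discard their outputs afterwards.
    def __init__(self, dim, num_registers=4):
        super().__init__()
        self.registers = nn.Parameter(torch.zeros(1, num_registers, dim))
        nn.init.normal_(self.registers, std=0.02)

    def prepend(self, x):                 # x: [B, N, dim] patch tokens
        r = self.registers.expand(x.shape[0], -1, -1)
        return torch.cat([r, x], dim=1)   # [B, R + N, dim]

    def strip(self, x):                   # drop the register outputs
        return x[:, self.registers.shape[1]:]
```

The registers attend like ordinary tokens but carry no patch content, and they are stripped before the final output projection.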
For the DiT model, I implement a second forward function for classifier-free guidance (CFG), which is useful when sampling from the model. Samples were generated with CFG scales of 1.0 and 2.5.
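A minimal sketch of such a CFG forward pass under flow matching (the null-class handling and function signature are assumptions; see the model code for the actual implementation):

```python
import torch

def forward_with_cfg(model, x, t, y, y_null, cfg_scale=2.5):
    # Run the model conditionally and unconditionally (null class), then
    # push the conditional prediction away from the unconditional one.
    v_cond = model(x, t, y)
    v_uncond = model(x, t, y_null)
    return v_uncond + cfg_scale * (v_cond - v_uncond)
```

At `cfg_scale=1.0` this reduces to the plain conditional prediction.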
I use an Exponential Moving Average (EMA) of the model weights, which helps stabilize training and improve the performance of the model.
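A minimal sketch of the EMA update, applied after each optimizer step (the decay value is a common default, not necessarily this repo's setting):

```python
import torch

@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    # Move each EMA weight a small step toward the current weight:
    # ema = decay * ema + (1 - decay) * current
    for p_ema, p in zip(ema_model.parameters(), model.parameters()):
        p_ema.lerp_(p, 1.0 - decay)
```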