Diffusion Transformer

[Image: DiT CIFAR10 samples]

This repo contains the Diffusion Transformer (DiT) from the paper Scalable Diffusion Models with Transformers. [arxiv] [code]. It was created out of interest in combining diffusion models with transformers. The network code is largely based on the official implementation from Meta AI, with several changes that add newer techniques and tricks.

Setup

You can recreate the conda environment using the provided environment.yml.

conda env create -f environment.yml
conda activate dit

Sampling

You can sample from the model using the Sampler class in sampler.py.

from sampler import Sampler
sampler = Sampler(model) # A trained DiT model
samples = sampler.sample(10) # Sample 10 images
# [10, C, H, W]

For visualization, I use moviepy to generate GIFs from the intermediate sampling steps.

[Animation: CIFAR10 sampling GIF]
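A rough sketch of how intermediate samples can be turned into a GIF with moviepy (here, frames is assumed to be a list of HxWx3 uint8 arrays collected during sampling; the actual hook in the repo may differ):

from moviepy.editor import ImageSequenceClip

# frames: list of HxWx3 uint8 numpy arrays, one per sampling step
clip = ImageSequenceClip(frames, fps=10)
clip.write_gif("samples.gif")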

Training

To train a DiT, you can use the train.py script.

python train.py

I use wandb for logging. If you don't want to use it, you can remove the logger in train.py. You can also check the log of the latest training run on wandb here. I trained the model on the CIFAR10 dataset for 200k steps. Judging by the FID, the model has not converged yet; more training may lead to better results.

Techniques and Tricks

Flow Matching

Instead of the standard Gaussian diffusion, I use flow matching from the paper Flow Matching for Generative Modeling, via the torchcfm library. This makes sampling faster and more efficient.
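As a rough sketch of what the torchcfm training objective looks like (the exact wiring in train.py may differ; model, images, and labels are placeholders):

import torch
from torchcfm.conditional_flow_matching import ConditionalFlowMatcher

fm = ConditionalFlowMatcher(sigma=0.0)
x1 = images                                   # a batch of real images
x0 = torch.randn_like(x1)                     # Gaussian noise
t, xt, ut = fm.sample_location_and_conditional_flow(x0, x1)
vt = model(xt, t, labels)                     # DiT predicts the velocity field
loss = torch.mean((vt - ut) ** 2)             # regress onto the target velocity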

Tailored SNR Samplers

I implement logit-normal sampling of the timesteps, a technique used in the Stable Diffusion 3 research paper. It biases training toward the intermediate timesteps of the diffusion process.
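A minimal sketch of logit-normal timestep sampling (the mean and std values are illustrative, not necessarily what the repo uses):

import torch

def sample_t_logit_normal(batch_size, mean=0.0, std=1.0):
    # t = sigmoid(z) with z ~ N(mean, std^2): most mass lands on intermediate timesteps
    z = torch.randn(batch_size) * std + mean
    return torch.sigmoid(z)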

ViT with registers

I add register tokens to the transformer, following the paper Vision Transformers Need Registers, which shows that adding extra register tokens can improve model performance.
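The core idea, as a sketch (class name and number of registers are illustrative, not the repo's actual code):

import torch
import torch.nn as nn

class RegisterTokens(nn.Module):
    # learnable registers prepended to the patch tokens and dropped before the output head
    def __init__(self, dim, num_registers=4):
        super().__init__()
        self.reg = nn.Parameter(torch.zeros(1, num_registers, dim))
        nn.init.normal_(self.reg, std=0.02)

    def prepend(self, x):            # x: [B, N, D] patch tokens
        return torch.cat([self.reg.expand(x.shape[0], -1, -1), x], dim=1)

    def strip(self, x):              # remove the register tokens before the output head
        return x[:, self.reg.shape[1]:]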

Classifier-Free Guidance

For the DiT model, I implement a separate forward function for classifier-free guidance (CFG), which is useful at sampling time. Samples with CFG scales of 1.0 and 2.5:

[Image: CIFAR10 samples, cfg=1.0]

[Image: CIFAR10 samples, cfg=2.5]
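A sketch of what such a guided forward pass typically looks like (the actual signature in the repo may differ; null_y stands for the unconditional "null" class label):

import torch

@torch.no_grad()
def forward_with_cfg(model, x, t, y, null_y, cfg_scale=2.5):
    # run conditional and unconditional predictions, then blend them
    cond = model(x, t, y)
    uncond = model(x, t, null_y)
    return uncond + cfg_scale * (cond - uncond)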

EMA

I keep an Exponential Moving Average (EMA) of the model weights, which stabilizes training and improves sample quality.
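A minimal sketch of an EMA update step (the decay value is illustrative):

import torch

@torch.no_grad()
def ema_update(ema_model, model, decay=0.999):
    # move each EMA parameter a small step toward the current model parameter
    for p_ema, p in zip(ema_model.parameters(), model.parameters()):
        p_ema.mul_(decay).add_(p, alpha=1.0 - decay)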
