Diffusion Transformer

Implementation of the Diffusion Transformer model from the paper:

Scalable Diffusion Models with Transformers.

See https://github.com/facebookresearch/DiT for the official PyTorch implementation.

Dependencies

  • Python 3.9
  • PyTorch 2.1.1

Training Diffusion Transformer

Use --data_dir=<data_dir> to specify the dataset path.

python train.py --data_dir=./data/

Samples

Sample output from minDiT (39.89M parameters) on CIFAR-10:

Sample output from minDiT on CelebA:

Hyperparameter settings

Adjust hyperparameters in the config.py file.
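
As a rough illustration, a config for a model of this size typically collects values like the ones below. These names and numbers are hypothetical placeholders, not the repository's actual settings; consult config.py for the real ones.

```python
# Illustrative sketch only -- the real hyperparameter names and values
# are defined in the repository's config.py.
img_size = 32      # CIFAR-10 resolution (hypothetical value)
batch_size = 128   # sized for a single RTX 3080 Ti (hypothetical)
lr = 1e-4          # optimizer learning rate (hypothetical)
embed_dim = 384    # transformer width (hypothetical)
depth = 12         # number of DiT blocks (hypothetical)
num_heads = 6      # attention heads (hypothetical)
```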

Implementation notes:

  • minDiT is designed to offer reasonable performance on a single GPU (RTX 3080 Ti).
  • minDiT largely follows the original DiT model.
  • DiT block with adaLN-Zero (a minimal sketch appears after this list).
  • Diffusion Transformer with Linformer attention.
  • EDM sampler (a minimal sketch appears after this list).
  • FID evaluation.

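An adaLN-Zero block conditions each LayerNorm and residual branch on the timestep embedding, with the modulation projection zero-initialized so every block starts as the identity. Below is a minimal PyTorch sketch of the idea from the DiT paper; it uses standard multi-head attention in place of Linformer for brevity, and all names are illustrative rather than this repository's.

```python
import torch
import torch.nn as nn

class DiTBlock(nn.Module):
    """Transformer block with adaLN-Zero conditioning (DiT-style sketch).

    A small MLP maps the conditioning vector c (e.g. the timestep
    embedding) to six signals: scale/shift for each LayerNorm and a
    gate for each residual branch. Zero-initializing the projection
    makes every block the identity at the start of training.
    """
    def __init__(self, dim, num_heads, mlp_ratio=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * mlp_ratio),
            nn.GELU(),
            nn.Linear(dim * mlp_ratio, dim),
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 6 * dim),
        )
        # adaLN-Zero: zero modulation => block starts as the identity
        nn.init.zeros_(self.adaLN_modulation[1].weight)
        nn.init.zeros_(self.adaLN_modulation[1].bias)

    def forward(self, x, c):
        # x: (B, N, dim) tokens, c: (B, dim) conditioning vector
        shift1, scale1, gate1, shift2, scale2, gate2 = \
            self.adaLN_modulation(c).chunk(6, dim=-1)
        h = self.norm1(x) * (1 + scale1.unsqueeze(1)) + shift1.unsqueeze(1)
        h, _ = self.attn(h, h, h)
        x = x + gate1.unsqueeze(1) * h
        h = self.norm2(x) * (1 + scale2.unsqueeze(1)) + shift2.unsqueeze(1)
        x = x + gate2.unsqueeze(1) * self.mlp(h)
        return x
```
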
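The EDM sampler integrates the probability-flow ODE over the Karras sigma schedule with a second-order (Heun) correction. The sketch below assumes a `denoise(x, sigma)` callable returning the denoised image; the function name and default values are illustrative, not taken from this repository.

```python
import torch

@torch.no_grad()
def edm_sampler(denoise, x, num_steps=18, sigma_min=0.002,
                sigma_max=80.0, rho=7.0):
    """Deterministic 2nd-order (Heun) EDM sampler (Karras et al., 2022).

    `denoise(x, sigma)` is assumed to return the denoised image
    D(x; sigma). `x` should be unit-variance Gaussian noise.
    """
    # Karras sigma schedule: rho-spaced between sigma_max and sigma_min
    steps = torch.arange(num_steps, device=x.device)
    t = (sigma_max ** (1 / rho)
         + steps / (num_steps - 1)
         * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t = torch.cat([t, torch.zeros(1, device=x.device)])  # append t_N = 0

    x = x * t[0]  # scale the initial noise to the largest sigma
    for i in range(num_steps):
        t_cur, t_next = t[i], t[i + 1]
        # Euler step along the probability-flow ODE derivative
        d_cur = (x - denoise(x, t_cur)) / t_cur
        x_next = x + (t_next - t_cur) * d_cur
        # Heun (2nd-order) correction, skipped on the final step
        if t_next > 0:
            d_next = (x_next - denoise(x_next, t_next)) / t_next
            x_next = x + (t_next - t_cur) * 0.5 * (d_cur + d_next)
        x = x_next
    return x
```
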
Todo

  • Add Classifier-Free Diffusion Guidance and conditional pipeline.
  • Add Latent Diffusion and Autoencoder training.
  • Add generate.py file.

License

MIT