Implementation of the Diffusion Transformer (DiT) model from the paper:
See here for the official PyTorch implementation.
- Python 3.9
- PyTorch 2.1.1
Use --data_dir=<data_dir> to specify the dataset path, for example:
python train.py --data_dir=./data/
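A minimal sketch of how a training script could expose this flag with the standard-library argparse module. Only --data_dir comes from this README; build_parser is a hypothetical helper, and the "./data/" default is an assumption that mirrors the command above, not necessarily what train.py does.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser: --data_dir is the only flag documented in this README.
    parser = argparse.ArgumentParser(description="Train minDiT (sketch)")
    parser.add_argument(
        "--data_dir",
        type=str,
        default="./data/",  # assumed default, mirroring the example command
        help="Path to the training dataset",
    )
    return parser


# Usage inside a script: args = build_parser().parse_args(); args.data_dir
```

Both `--data_dir=./data/` and `--data_dir ./data/` are accepted by argparse for a flag defined this way.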
Sample output from minDiT (39.89M parameters) on CIFAR-10:
Sample output from minDiT on CelebA:
Adjust hyperparameters in the config.py file.
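The sketch below shows how image size and patch size determine the DiT sequence length, which is the main driver of memory use on a single GPU. The field names and default values are hypothetical (they are not taken from config.py, and they are not claimed to reproduce the 39.89M-parameter model); check config.py for the real settings.

```python
from dataclasses import dataclass


@dataclass
class Config:
    # Hypothetical fields and defaults; config.py may use different names and values.
    image_size: int = 32   # CIFAR-10 images are 32x32
    in_channels: int = 3
    patch_size: int = 2
    hidden_dim: int = 384
    depth: int = 12
    num_heads: int = 6

    @property
    def num_patches(self) -> int:
        # A DiT splits the image into non-overlapping patch_size x patch_size tokens.
        assert self.image_size % self.patch_size == 0, "image_size must be divisible by patch_size"
        return (self.image_size // self.patch_size) ** 2

    @property
    def head_dim(self) -> int:
        # Each attention head gets an equal slice of the hidden dimension.
        assert self.hidden_dim % self.num_heads == 0, "hidden_dim must be divisible by num_heads"
        return self.hidden_dim // self.num_heads
```

Halving patch_size quadruples the number of tokens, and the cost of standard self-attention grows quadratically with the number of tokens, so patch_size is usually the first knob to turn when a run does not fit in GPU memory.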
Implementation notes:
- minDiT is designed to deliver reasonable results on a single GPU (RTX 3080 Ti).
- minDiT largely follows the original DiT model.
- DiT Block with adaLN-Zero.
- Diffusion Transformer with Linformer attention.
- EDM sampler.
- FID evaluation.
- Add Classifier-Free Diffusion Guidance and conditional pipeline.
- Add Latent Diffusion and Autoencoder training.
- Add generate.py file.
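adaLN-Zero conditions each DiT block by regressing a shift, scale, and gate from the conditioning vector (e.g. the timestep embedding), with the regression layer initialized to zero so every residual block starts as the identity. The NumPy sketch below shows one residual sublayer; the random linear map is a hypothetical stand-in for the attention or MLP branch, and minDiT's PyTorch module is not reproduced here.

```python
import numpy as np


def layer_norm(x, eps=1e-6):
    # LayerNorm without learned affine parameters; adaLN supplies shift/scale instead.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def modulate(x, shift, scale):
    # Per-channel affine transform driven by the conditioning vector.
    return x * (1 + scale) + shift


class AdaLNZeroSublayer:
    """One residual sublayer of a DiT block with adaLN-Zero conditioning (sketch)."""

    def __init__(self, dim, cond_dim, rng):
        # Hypothetical inner branch: a random linear map standing in for attention/MLP.
        self.w_inner = rng.standard_normal((dim, dim)) / np.sqrt(dim)
        # adaLN-Zero: the layer producing shift/scale/gate is initialized to zero.
        self.w_ada = np.zeros((cond_dim, 3 * dim))
        self.b_ada = np.zeros(3 * dim)

    def __call__(self, x, c):
        # x: (tokens, dim); c: (cond_dim,), e.g. a timestep embedding.
        silu_c = c / (1 + np.exp(-c))  # SiLU(c) = c * sigmoid(c)
        shift, scale, gate = np.split(silu_c @ self.w_ada + self.b_ada, 3)
        h = modulate(layer_norm(x), shift, scale) @ self.w_inner
        # The gate scales the residual branch; at init gate == 0, so the output equals x.
        return x + gate * h
```

A full DiT block regresses six vectors (shift, scale, and gate for both the attention and the MLP sublayers); the sketch keeps one sublayer to show the zero-initialized gate.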
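Linformer attention replaces the n x n attention matrix with an n x k one by compressing the keys and values along the sequence axis with learned k x n projections, so cost grows linearly in the number of tokens for fixed k. The single-head NumPy sketch below illustrates the idea; the names e_proj and f_proj are illustrative, and minDiT's multi-head layer (projection sharing, head layout) may differ.

```python
import numpy as np


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def linformer_attention(q, k, v, e_proj, f_proj):
    """Single-head Linformer-style attention (sketch).

    q, k, v: (n, d) queries, keys, values for n tokens.
    e_proj, f_proj: (k_len, n) projections compressing the sequence axis
    of the keys and values from n down to k_len.
    """
    k_low = e_proj @ k                            # (k_len, d)
    v_low = f_proj @ v                            # (k_len, d)
    scores = q @ k_low.T / np.sqrt(q.shape[-1])   # (n, k_len) instead of (n, n)
    weights = softmax(scores, axis=-1)
    return weights @ v_low, weights               # (n, d), (n, k_len)
```

Because e_proj has a fixed shape (k_len, n), the projections tie the model to one sequence length, i.e. one image size and patch size.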
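The EDM sampler (Karras et al., 2022) integrates the probability-flow ODE dx/dsigma = (x - D(x, sigma)) / sigma over a noise schedule spaced uniformly in sigma^(1/rho), using Heun's method and a final Euler step to sigma = 0. The sketch below is the deterministic variant (no stochastic churn) with the paper's default sigma_min = 0.002, sigma_max = 80, rho = 7; the denoise callable is a hypothetical stand-in for the trained DiT, and minDiT's sampler settings may differ.

```python
import numpy as np


def edm_sigmas(num_steps=18, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    # Interpolate linearly in sigma^(1/rho) space, then append a final sigma = 0.
    ramp = np.linspace(0.0, 1.0, num_steps)
    min_inv, max_inv = sigma_min ** (1 / rho), sigma_max ** (1 / rho)
    sigmas = (max_inv + ramp * (min_inv - max_inv)) ** rho
    return np.append(sigmas, 0.0)


def edm_heun_sample(denoise, shape, num_steps=18, rng=None):
    """Deterministic EDM sampler; denoise(x, sigma) returns the clean-sample estimate."""
    rng = np.random.default_rng() if rng is None else rng
    sigmas = edm_sigmas(num_steps)
    x = rng.standard_normal(shape) * sigmas[0]  # start from pure noise at sigma_max
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        d = (x - denoise(x, sigma)) / sigma      # slope of the probability-flow ODE
        x_next = x + (sigma_next - sigma) * d    # Euler step
        if sigma_next > 0:
            # Heun correction: average the slopes at both ends of the step.
            d_next = (x_next - denoise(x_next, sigma_next)) / sigma_next
            x_next = x + (sigma_next - sigma) * 0.5 * (d + d_next)
        x = x_next
    return x
```

Each Heun step costs two denoiser calls and the last step costs one, so num_steps steps use 2 * num_steps - 1 network evaluations.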
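FID compares Gaussian fits to Inception features of real and generated images using the Fréchet distance ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 (S1 S2)^(1/2)). The NumPy sketch below covers only this distance; the Inception feature extractor is not shown, and minDiT's evaluation may rely on an existing FID library. It computes the trace term through the symmetric matrix S1^(1/2) S2 S1^(1/2), which has the same eigenvalues as S1 S2.

```python
import numpy as np


def _sqrtm_psd(a):
    # Square root of a symmetric positive semi-definite matrix via eigendecomposition.
    vals, vecs = np.linalg.eigh(a)
    return (vecs * np.sqrt(np.clip(vals, 0.0, None))) @ vecs.T


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2)."""
    diff = mu1 - mu2
    root1 = _sqrtm_psd(sigma1)
    # Tr((S1 S2)^{1/2}) == Tr((S1^{1/2} S2 S1^{1/2})^{1/2}), and the inner matrix is symmetric PSD.
    middle = root1 @ sigma2 @ root1
    tr_covmean = np.sqrt(np.clip(np.linalg.eigvalsh(middle), 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)
```

Identical statistics give a distance of 0, and with equal covariances the distance reduces to the squared distance between the means.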
License: MIT