
Official PyTorch implementation for Diffusion Rejection Sampling (DiffRS) in ICML 2024.


Diffusion Rejection Sampling (DiffRS) (ICML 2024)


This repo contains an official PyTorch implementation for the paper "Diffusion Rejection Sampling" in ICML 2024.

Byeonghu Na, Yeongmin Kim, Minsang Park, Donghyeok Shin, Wanmo Kang, and Il-Chul Moon


This paper introduces Diffusion Rejection Sampling (DiffRS), a new diffusion sampling approach that ensures alignment between the reverse transition and the true transition at each timestep.
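The per-timestep alignment described above builds on standard rejection sampling. As background, the textbook rule for a single reverse step, together with the density-ratio trick that lets a discriminator stand in for the unknown ratio, can be written as below. The notation (q for the true transition kernel, p_θ for the pre-trained kernel, M for the bound, d for the discriminator) is chosen here for illustration and is not necessarily the paper's; the exact acceptance probability used by DiffRS may differ.

```latex
% One reverse step from x_t: propose from the pre-trained kernel, then accept or reject.
x_{t-1} \sim p_\theta(x_{t-1} \mid x_t), \qquad
A(x_{t-1}, x_t) = \min\!\left(1, \frac{q(x_{t-1} \mid x_t)}{M \, p_\theta(x_{t-1} \mid x_t)}\right)

% If q(\cdot \mid x_t) \le M \, p_\theta(\cdot \mid x_t) everywhere, accepted samples are
% distributed exactly as q(\cdot \mid x_t), with M proposals per acceptance on average.

% Density-ratio trick: a Bayes-optimal discriminator d between samples of q (label 1)
% and samples of p_\theta (label 0), with equal class priors, satisfies
d = \frac{q}{q + p_\theta} \quad\Longrightarrow\quad \frac{q}{p_\theta} = \frac{d}{1 - d}
```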

Illustration of the sampling process for DiffRS. The path with the green background represents the DiffRS sampling process, and the rightmost images are those obtained by running the base sampler, without rejection, from each intermediate image. Timesteps are expressed as the noise level σ from the EDM scheme.

Overview of DiffRS. We sequentially apply rejection sampling to the pre-trained transition kernel (red) so that it aligns with the true transition kernel (blue). The acceptance probability is estimated by a time-dependent discriminator.
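The loop below is a toy, one-dimensional sketch of a single discriminator-driven rejection step, written with NumPy rather than PyTorch. Everything in it is an assumption for illustration and is not the code in generate_diffrs.py: the Gaussian "pre-trained" and "true" kernels, the analytic `toy_discriminator` standing in for the trained time-dependent discriminator, and the choice of M as a percentile of estimated density ratios (a guess at the kind of setting a flag like `--rej_percentile` controls). The helper names are hypothetical.

```python
import numpy as np


def gaussian_pdf(x, mean, std=1.0):
    """Density of a 1-D normal distribution N(mean, std^2)."""
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2.0 * np.pi))


def toy_discriminator(x, mean_p, mean_q):
    """Hypothetical stand-in for a trained discriminator.

    Returns q / (q + p): the Bayes-optimal probability that x came from the
    'true' kernel q = N(mean_q, 1) rather than the 'pre-trained' kernel
    p = N(mean_p, 1), assuming equal class priors.
    """
    q = gaussian_pdf(x, mean_q)
    p = gaussian_pdf(x, mean_p)
    return q / (q + p)


def density_ratio(d, eps=1e-6):
    """Density-ratio trick: q / p = d / (1 - d), clipped for numerical safety."""
    d = np.clip(d, eps, 1.0 - eps)
    return d / (1.0 - d)


def rejection_step(rng, mean_p, mean_q, n, percentile=0.75, max_iter=105):
    """Propose from p and accept each proposal with probability min(1, ratio / M).

    M is set to the given percentile of the ratios on a pilot batch of
    proposals (an assumption of this sketch). Rows still unaccepted after
    max_iter rounds keep their last proposal.
    """
    pilot = rng.normal(mean_p, 1.0, size=4096)
    M = np.quantile(density_ratio(toy_discriminator(pilot, mean_p, mean_q)), percentile)

    samples = np.empty(n)
    accepted = np.zeros(n, dtype=bool)
    for _ in range(max_iter):
        todo = np.flatnonzero(~accepted)
        if todo.size == 0:
            break
        x = rng.normal(mean_p, 1.0, size=todo.size)  # propose from the base kernel p
        ratio = density_ratio(toy_discriminator(x, mean_p, mean_q))
        accept = rng.uniform(size=todo.size) < np.minimum(1.0, ratio / M)
        samples[todo] = x  # keep the latest proposal for every pending row
        accepted[todo[accept]] = True  # accepted rows are frozen from now on
    return samples, accepted


# Defaults mirror the values in the README example command (0.75 and 105).
rng = np.random.default_rng(0)
samples, accepted = rejection_step(rng, mean_p=0.0, mean_q=0.5, n=20000)
print(f"accepted: {accepted.mean():.3f}, sample mean: {samples.mean():.3f}")
```

Because M here is a percentile rather than a true upper bound on the ratio, proposals whose ratio exceeds M are always accepted, so the output only approximates q: in this toy the sample mean moves from 0 toward, but not all the way to, 0.5.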

Requirements

The requirements for this code are the same as those for DG.

In our experiments, we used CUDA 11.4 and PyTorch 1.12.

Diffusion Rejection Sampling

  1. Download the pre-trained diffusion network and the trained discriminator network from DG.
      • Download 'edm-cifar10-32x32-uncond-vp.pkl' from EDM.
      • Download 'DG/checkpoints/ADM_classifier/32x32_classifier.pt' from DG.
      • Download '32x32_classifier.pt' from ADM.
  2. Generate DiffRS samples using generate_diffrs.py. For example:
python3 generate_diffrs.py \
    --network checkpoints/pretrained_score/edm-cifar10-32x32-uncond-vp.pkl \
    --outdir=samples/cifar10/diffrs --rej_percentile=0.75 --max_iter=105
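To compare several rejection settings, the example command can be scripted. The sketch below reuses only the four flags shown in the example (`--network`, `--outdir`, `--rej_percentile`, `--max_iter`); the percentile values, the output-directory naming, and the helper names `build_command` and `sweep` are hypothetical. With `dry_run=True` it only prints the commands instead of running them.

```python
import shlex
import subprocess


def build_command(network, outdir, rej_percentile, max_iter):
    """Assemble a generate_diffrs.py call using only the flags from the README example."""
    return [
        "python3", "generate_diffrs.py",
        "--network", network,
        f"--outdir={outdir}",
        f"--rej_percentile={rej_percentile}",
        f"--max_iter={max_iter}",
    ]


def sweep(percentiles,
          network="checkpoints/pretrained_score/edm-cifar10-32x32-uncond-vp.pkl",
          base_outdir="samples/cifar10",
          max_iter=105,
          dry_run=True):
    """Run (or, with dry_run=True, just print) one sampling job per percentile."""
    commands = []
    for p in percentiles:
        # Hypothetical naming scheme: one output directory per percentile.
        cmd = build_command(network, f"{base_outdir}/diffrs_p{p}", p, max_iter)
        commands.append(cmd)
        if dry_run:
            print(shlex.join(cmd))
        else:
            subprocess.run(cmd, check=True)
    return commands


# Hypothetical percentile values; 0.75 matches the README example.
sweep([0.5, 0.75, 0.9])
```

Keeping each setting in its own output directory avoids one run overwriting the samples of another.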

Acknowledgements

This work builds heavily on code from:

Citation

@inproceedings{na2024diffusion,
  title={Diffusion Rejection Sampling},
  author={Na, Byeonghu and Kim, Yeongmin and Park, Minsang and Shin, Donghyeok and Kang, Wanmo and Moon, Il-Chul},
  booktitle={International Conference on Machine Learning},
  year={2024},
  organization={PMLR}
}
