- main branch: a stable implementation of time-dependent variational scores.
- static_module branch: implementation of time-invariant variational scores, which is faster.
- time_series branch: applications of VSDM to time series forecasting.
- image branch: the EDM-based implementation of both FB-SDE and VSDM for images.
@inproceedings{VSDM,
title={Variational Schr\"odinger Diffusion Models},
author={Wei Deng and Weijian Luo and Yixin Tan and Marin Bilo\v s and Yu Chen and Yuriy Nevmyvaka and Ricky T. Q. Chen},
booktitle={International Conference on Machine Learning},
year={2024}
}
Following the link, we can install the environment vsd using Anaconda as follows:
conda env create --file requirements.yaml python=3
conda activate vsd
We set beta-r to 0 to fix the hyperparameters of the VP-SDE. We choose
python main.py --problem-name gaussian --num-stage 20 --forward-net Linear --dir gaussian_vsdm_4 --beta-max 4 --beta-r 0. --interact-coef 1
python main.py --problem-name spiral --num-itr-dsm 100000 --dir spiral_8y_dsm_10 --y-scalar 8 --beta-max 10 --DSM-baseline
python main.py --problem-name checkerboard --num-itr-dsm 100000 --dir check_6x_dsm_10 --x-scalar 6 --beta-max 10 --DSM-baseline
python main.py --problem-name spiral --num-itr-dsm 100000 --dir spiral_8y_dsm_20 --y-scalar 8 --beta-max 20 --DSM-baseline
python main.py --problem-name checkerboard --num-itr-dsm 100000 --dir check_6x_dsm_20 --x-scalar 6 --beta-max 20 --DSM-baseline
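The four DSM baseline commands above differ only in the problem name, its scaling flag, and beta-max, so they can be generated rather than typed by hand. The following is a minimal sketch, not part of the repository: the helper name `dsm_baseline_command` and the `SCALE_FLAGS` table are hypothetical, and only the flags and values already shown above are used. It prints the commands as a dry run.

```python
import itertools
import shlex

# Per-problem scaling flag, value, and output-directory prefix,
# copied from the DSM baseline commands above (hypothetical table name).
SCALE_FLAGS = {
    "spiral": ("--y-scalar", "8", "spiral_8y"),
    "checkerboard": ("--x-scalar", "6", "check_6x"),
}
BETA_MAX_VALUES = (10, 20)


def dsm_baseline_command(problem, beta_max):
    """Build the argument list for one DSM baseline run (hypothetical helper)."""
    flag, value, prefix = SCALE_FLAGS[problem]
    return [
        "python", "main.py",
        "--problem-name", problem,
        "--num-itr-dsm", "100000",
        "--dir", f"{prefix}_dsm_{beta_max}",
        flag, value,
        "--beta-max", str(beta_max),
        "--DSM-baseline",
    ]


# Iterate beta-max in the outer loop so the order matches the commands above.
commands = [
    dsm_baseline_command(problem, beta_max)
    for beta_max, problem in itertools.product(BETA_MAX_VALUES, SCALE_FLAGS)
]

if __name__ == "__main__":
    for argv in commands:
        # Dry run: only print each command line.
        print(shlex.join(argv))
```

To launch the runs instead of printing them, each `argv` list could be passed to `subprocess.run(argv, check=True)` from the repository root.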
python main.py --problem-name spiral --num-itr-dsm 500 --num-stage 200 --forward-net Linear \
--dir spiral_8y_vsdm_10 --y-scalar 8 --beta-max 10
python main.py --problem-name checkerboard --num-itr-dsm 500 --num-stage 200 --forward-net Linear \
--dir check_6x_vsdm_10 --x-scalar 6 --beta-max 10
The current code only supports NFE=6 (interval 108) and NFE=8 (interval 128).
python main.py --problem-name spiral --num-itr-dsm 100000 --dir spiral_dsm_nfe_6 --y-scalar 8 --DSM-baseline --nfe 6
python main.py --problem-name checkerboard --num-itr-dsm 100000 --dir check_dsm_nfe_6 --x-scalar 6 --DSM-baseline --nfe 6
python main.py --problem-name spiral --num-itr-dsm 500 --num-stage 200 --forward-net Linear --dir spiral_vsdm_nfe_6 --y-scalar 8 --interact-coef 0.85 --nfe 6
python main.py --problem-name checkerboard --num-itr-dsm 500 --num-stage 200 --forward-net Linear --dir check_vsdm_nfe_6 --x-scalar 6 --interact-coef 0.85 --nfe 6