Official implementation of the paper Unifying GANs and Score-Based Diffusion as Generative Particle Models (Jean-Yves Franceschi, Mike Gartrell, Ludovic Dos Santos, Thibaut Issenhuth, Emmanuel de Bézenac, Mickaël Chen, Alain Rakotomamonjy), to appear at NeurIPS 2023.
All models were trained with Python 3.10.4 and PyTorch 1.13.1 using CUDA 11.8. The `requirements.txt` file lists Python package dependencies.
To launch an experiment, you need:
- a path to the data `$DATA_PATH`;
- a path to a YAML config file `$CONFIG_FILE` (whose format depends on the chosen model, but examples are shown in the `config` folder);
- a path `$LOG_PATH` where the logs and checkpoints should be saved;
- the chosen name of the experiment `$NAME`.
You can then run the following command from the root folder of the project:
```bash
bash launch.sh --data_path $DATA_PATH --configs $CONFIG_FILE --save_path $LOG_PATH --save_name $NAME
```
There are some optional arguments to this command, which you can find with `bash launch.sh --help`. In particular, `--device $DEVICE1 $DEVICE2 ...` launches training on GPUs `$DEVICE1`, `$DEVICE2`, etc. (e.g. `--device 0 1 2 3`). By default, multi-GPU training uses PyTorch's `DistributedDataParallel` and PyTorch's launching utility `torchrun`. As of now, multi-GPU training only supports single-node parallelism, but extending it to multi-node should be straightforward.
Please do not include any log files or data in the code folders, as this may prevent the code from working properly: the code automatically saves the content of the project root folder in the log directory.
By default, a test is performed at the end of training on the model from the last iteration only. If you want to test the best saved model instead, first load the checkpoint using the `--load` option, then load the best model by additionally using the `--load_best` option. This will resume training if it was not completed; if you only want to test the model, add the `--only_test` option.
Before loading a model, please create a backup of the experiment folder as some files may be overwritten in the process.
Implemented models are listed in `gpm/models/models.py`. In particular: a basic GAN implementation, a score-based model (a reimplementation of EDM), our Discriminator Flow, and our ScoreGAN model, which requires a pretrained data score network.
The `config` folder lists configuration files with suggested parameters that were used in the paper. A documented example can be found at `config/discr_flow/celeba/config.yaml`. All possible parameter dictionaries are documented in the code (cf. the discussion below).
This code is meant to be generic enough that it can be used to create nearly any PyTorch deep model, and for example to extend the proposed models. On the downside, it leaves the user responsible for many technical details and for complying with the general organization.
All training initialization and iterations (seed, devices, parallelism on a single node, model creation, data loading, training and testing loops, etc.) are handled in `gpm/train.py`; options for the launch command are in `gpm/utils/args.py`. All neural networks used in the models are implemented in the different files of `gpm/networks`; models are coded separately in `gpm/models`.
To keep the code generic, model configurations are provided as YAML files, with examples in the `config` folder. The expected arguments differ depending on the chosen models and networks. To specify which arguments are needed for, e.g., a network, we use the following kind of code (here is an example for an MLP).
```python
class MLPDict(DotDict):
    def __init__(self, *kargs, **kwargs):
        super().__init__(*kargs, **kwargs)
        assert {'depth', 'width', 'activation'}.issubset(self.keys())
        self.depth: int  # Number of hidden layers
        self.width: int  # Width of hidden layers
        self.activation: ActivationDict | list[ActivationDict]  # Activations of hidden layers
        if isinstance(self.activation, list):
            self.activation = list(map(ActivationDict, self.activation))
        else:
            self.activation = ActivationDict(self.activation)
        self.batch_norm: bool  # Whether to incorporate batch norm (default, False)
        if 'batch_norm' not in self:
            self.batch_norm = False
        self.spectral_norm: bool  # Whether to perform spectral norm (default, False)
        if 'spectral_norm' not in self:
            self.spectral_norm = False
```
The `assert` ensures that all required arguments are present in the provided configuration file. The rest of the class attributes are just syntactic sugar to provide typing information on these arguments, or specify default values for some parameters.
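For illustration, here is a minimal, hypothetical sketch of how such a parameter dictionary could be exercised on its own. The `DotDict` and `ActivationDict` classes below are simplified stand-ins written for this example only, not the actual classes from `gpm`; define them first, then paste the `MLPDict` class from the excerpt above, then run the usage at the bottom.

```python
# Simplified stand-ins, written for this example only (not the actual gpm classes).
class DotDict(dict):
    """Dictionary with attribute-style access."""
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__


class ActivationDict(DotDict):
    """Stand-in; the real class validates the activation parameters."""


# Paste the MLPDict class from the excerpt above here, after the stand-ins.

# A configuration as it could be parsed from a YAML file.
raw_config = {
    'depth': 3,
    'width': 256,
    'activation': {'name': 'relu'},  # 'name' is a hypothetical field for this sketch
    # 'batch_norm' and 'spectral_norm' are omitted and fall back to their defaults.
}

mlp_config = MLPDict(raw_config)
print(mlp_config.depth, mlp_config.width, mlp_config.batch_norm)  # prints: 3 256 False
```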
Model coding is based on object-oriented programming: each model is a single `torch.nn.Module` containing all the model parameters and inheriting from the abstract class `BaseModel`, which indicates which methods any model should reimplement. You can follow the documentation for more details. Datasets follow the same philosophy, inheriting from PyTorch's `Dataset` class and from the custom abstract class `BaseDataset`.
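As a rough illustration of this pattern, here is a minimal sketch of a dataset following the standard PyTorch `Dataset` interface; it is hypothetical and does not implement the additional methods required by `BaseDataset`, which are documented in the code.

```python
import torch
from torch.utils.data import Dataset


class ToyDataset(Dataset):
    """Hypothetical toy dataset; a real dataset in this codebase would also inherit from gpm's BaseDataset."""

    def __init__(self, size: int = 1000, dim: int = 2):
        # Random samples standing in for actual data loaded from $DATA_PATH.
        self.samples = torch.randn(size, dim)

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> torch.Tensor:
        return self.samples[index]
```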
Log folders are organized as follows, as coded by the internal logger:
- a `chkpt` folder containing checkpoints needed to resume training, including the last checkpoint and the best checkpoint w.r.t. the evaluation metric;
- an `eval` and a `test` folder containing evaluations on the validation and test sets (there can be several evaluation configurations at once);
- a `config.json` file containing the training configuration;
- `logs.*.json` files containing training, validation and testing logs, including all logged metrics;
- `result*.json` files containing information on the saved checkpoints (last and best);
- a `source.zip` file containing the launched source code;
- a `test.yaml` file containing the original YAML configuration file.
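For instance, here is a minimal, hypothetical sketch of how one might inspect the saved training configuration of a finished run; the experiment path is a placeholder, and the exact JSON contents depend on the training configuration.

```python
import json
from pathlib import Path

# Placeholder path: the experiment folder created by the logger under $LOG_PATH.
experiment_dir = Path("/path/to/logs/my_experiment")

# config.json stores the training configuration of the run.
with open(experiment_dir / "config.json") as f:
    config = json.load(f)
print(config)
```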
This code is open-source. We share most of it under the Apache 2.0 License.
However, we reuse code from FastGAN and EDM, which were released under more restrictive licenses (respectively, GNU GPLv3 and CC BY-NC-SA 4.0) that require redistribution under the same license or an equivalent one. Hence, the corresponding parts of our code (respectively, `gpm/networks/conv/fastgan` and `gpm/networks/score`) are open-sourced under the original licenses of these works and not under Apache. See the corresponding folders for details.