# Stop Throwing Away Discriminators! Re-using Adversaries for Test-Time Training

Code for the paper:

Valvano G., Leo A. and Tsaftaris S. A. (DART, 2021), Stop Throwing Away Discriminators! Re-using Adversaries for Test-Time Training.

and the extended version:

Valvano G., Leo A. and Tsaftaris S. A. (arXiv, 2021), Re-using Adversarial Mask Discriminators for Test-time Training under Distribution Shifts.

The official project page is here.
An online version of the DART paper can be found here.
The extended version can be found here.

## Citation

@incollection{valvano2021stop,
  title={Stop Throwing Away Discriminators! Re-using Adversaries for Test-Time Training},
  author={Valvano, Gabriele and Leo, Andrea and Tsaftaris, Sotirios A},
  booktitle={Domain Adaptation and Representation Transfer, and Affordable Healthcare and AI for Resource Diverse Global Health},
  pages={68--78},
  year={2021},
  publisher={Springer}
}

@article{valvano2021re,
  title={Re-using Adversarial Mask Discriminators for Test-time Training under Distribution Shifts},
  author={Valvano, Gabriele and Leo, Andrea and Tsaftaris, Sotirios A},
  journal={arXiv preprint arXiv:2108.11926},
  year={2021}
}

[Figure: adversarial_ttt]


## Notes

For the experiments, refer to the files:

  - experiments/base_gan_ttt.py. This file contains the model and all the code needed for training. It defines the base class that the class Experiment() inside experiments/acdc/exp_gan_ttt.py inherits from. Refer to the class method define_model() to see how we build the CNN architectures. The structures of the segmentor, discriminator, and adaptor can be found under the architectures folder.
  - experiments/acdc/exp_gan_ttt.py. This file defines a child class inheriting from the base class in experiments/base_gan_ttt.py. It sets the directories and file names needed for the logs, and it defines the get_data() method, which connects the experiment to the dataset used for the experiments (see the sketch after this list).
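
As a rough sketch of this inheritance structure, a dataset-specific child class typically only overrides where the data comes from and where the logs go. The import alias, constructor signature, and attribute names below are illustrative assumptions, not the repository's actual code:

```python
# Illustrative sketch only: names other than Experiment, define_model() and
# get_data() are hypothetical and may differ from the repository code.
from experiments.base_gan_ttt import Experiment as BaseExperiment  # assumed alias


class Experiment(BaseExperiment):
    """ACDC-specific experiment: reuses the GAN training and TTT logic of the base class."""

    def __init__(self, run_id):
        super().__init__(run_id)
        # directories and file names for logs and checkpoints (illustrative values)
        self.log_dir = 'results/acdc/graphs/{0}'.format(run_id)
        self.checkpoint_dir = 'results/acdc/checkpoints/{0}'.format(run_id)

    def get_data(self):
        # return the dataset interface expected by the base class, wrapped
        # through data_interface/interfaces/dataset_wrapper.py
        raise NotImplementedError
```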

Once you have downloaded the ACDC dataset, you can pre-process it using the code in data_interface/utils_acdc/prepare_dataset.py. You can also train on custom datasets, but you must adhere to the template required by data_interface/interfaces/dataset_wrapper.py, which assumes access to the data through a TensorFlow dataset iterator. You will also need to modify the get_data() method inside experiments/acdc/exp_gan_ttt.py.
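
For reference, a minimal TensorFlow 1.x dataset iterator looks like the sketch below; the array names and shapes are made-up placeholders, not the interface actually expected by dataset_wrapper.py:

```python
import numpy as np
import tensorflow as tf  # the repository targets TensorFlow 1.14

# placeholder arrays: (N, H, W, 1) images and (N, H, W, num_classes) one-hot masks
images = np.random.rand(10, 224, 224, 1).astype(np.float32)
masks = np.zeros((10, 224, 224, 4), dtype=np.float32)

dataset = tf.data.Dataset.from_tensor_slices((images, masks))
dataset = dataset.shuffle(buffer_size=10).batch(2).repeat()

# one-shot iterator, as used in TF 1.x graph mode
iterator = dataset.make_one_shot_iterator()
next_images, next_masks = iterator.get_next()

with tf.Session() as sess:
    batch_images, batch_masks = sess.run([next_images, next_masks])
    print(batch_images.shape, batch_masks.shape)  # (2, 224, 224, 1) (2, 224, 224, 4)
```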

You can start training by following the guidelines in run.sh. To run the training on GPU #0 (where 0 is the GPU number), type in the shell:

sh run.sh 0
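
GPU selection of this kind usually works by exporting CUDA_VISIBLE_DEVICES before TensorFlow is imported, so the framework only sees the requested device. The snippet below is a generic sketch of that mechanism, not necessarily the exact content of run.sh:

```python
import os

# make only GPU #0 visible; must be set before TensorFlow touches the devices
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import tensorflow as tf

config = tf.ConfigProto()
config.gpu_options.allow_growth = True  # allocate GPU memory on demand
with tf.Session(config=config) as sess:
    print(sess.run(tf.constant('GPU #0 visible to TensorFlow')))
```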

The training will run both experiments:

  - semi-supervised learning (Non-Identifiable Distribution Shift between train and test set), splitting the dataset into 40-20-40% of samples for the train, validation, and test sets (with training annotations available for only 25% of the training data);
  - training on images acquired with 1.5T MRI scanners and testing on images from 3T scanners (Identifiable Distribution Shift). After training, the script also runs the test phase using Adversarial Test-Time Training, both in its standard formulation and in a continual learning setting.

After you run the script, you can monitor the training process with TensorBoard:

tensorboard --logdir=results/acdc/graphs/

and then navigate with your browser to the returned HTTP address (defaults to localhost:6006).

## Requirements

This code was implemented using TensorFlow 1.14 and the libraries detailed in requirements.txt.
You can install these libraries with:

pip install -r requirements.txt

or using conda (see this).

We tested the code on a TITAN Xp GPU and on a GeForce GTX 1080, using CUDA 10.2.