This repository contains the official PyTorch implementation of *DataFreeShield: Defending Adversarial Attacks without Training Data*, presented at ICML 2024.
```shell
pip install -r requirements.txt
```
For the biomedical datasets (MedMNIST-v2), we train our own teacher models, which we provide in the table below.
Dataset | ResNet-18 | Acc (%) | ResNet-50 | Acc (%) |
---|---|---|---|---|
TissueMNIST | link | 67.62 | link | 68.29 |
BloodMNIST | link | 95.53 | link | 95.00 |
PathMNIST | link | 92.19 | link | 91.41 |
OrganCMNIST | link | 90.74 | link | 91.06 |
Note that you can also train your own teacher with our script:
```shell
python3 train_biomedical.py --data_flag tissuemnist --as_rgb --gpu [gpu_id] --model resnet50
```
For general domain datasets (SVHN, CIFAR-10, CIFAR-100), we use pretrained weights from PytorchCV.
The code for sample synthesis builds upon DeepInversion. For faster generation, we run multiple jobs in parallel on different GPUs and merge the generated sets afterwards.
```shell
python3 generate.py --model resnet20_cifar10 --save_root datasets/rn20_cifar10/ --num_total_images 10000 --seed [random_seed] --gpu [gpu_id]
python3 generate.py --model resnet56_cifar10 --save_root datasets/rn56_cifar10/ --num_total_images 10000 --seed [random_seed] --gpu [gpu_id]
python3 generate.py --model wrn28_10_cifar10 --save_root datasets/wrn28_cifar10/ --num_total_images 10000 --seed [random_seed] --gpu [gpu_id]
python3 generate.py --model resnet20_svhn --save_root datasets/rn20_svhn/ --num_total_images 10000 --seed [random_seed] --gpu [gpu_id]
python3 generate.py --model resnet56_svhn --save_root datasets/rn56_svhn/ --num_total_images 10000 --seed [random_seed] --gpu [gpu_id]
python3 generate.py --model wrn28_10_svhn --save_root datasets/wrn28_svhn/ --num_total_images 10000 --seed [random_seed] --gpu [gpu_id]
python3 generate_biomedical.py --model resnet18 --data_flag tissuemnist --save_root datasets/rn18_tissue/ --num_total_images 10000 --seed [random_seed] --gpu [gpu_id]
python3 generate_biomedical.py --model resnet50 --data_flag tissuemnist --save_root datasets/rn50_tissue/ --num_total_images 10000 --seed [random_seed] --gpu [gpu_id]
```
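The parallel launch described above can be sketched in Python as follows. This is only an illustration: the `(seed, gpu_id)` pairing and the per-seed `save_root` layout are our own choices, not fixed by the repository.

```python
import subprocess

# (seed, gpu_id) pairs; both the pairing and the per-seed save_root
# layout below are illustrative, not prescribed by the repository.
jobs = [(0, 0), (1, 1), (2, 2), (3, 3)]

cmds = [
    ["python3", "generate.py",
     "--model", "resnet20_cifar10",
     "--save_root", f"datasets/rn20_cifar10/seed{seed}/",
     "--num_total_images", "10000",
     "--seed", str(seed),
     "--gpu", str(gpu)]
    for seed, gpu in jobs
]

# Uncomment to actually launch all jobs in parallel and wait for them:
# procs = [subprocess.Popen(c) for c in cmds]
# for p in procs:
#     p.wait()

for c in cmds:
    print(" ".join(c))
```

Each job writes to its own `save_root`, so the outputs can be merged afterwards with the merge script below.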
This script will merge individual sets of generated data into a single .pt file.
```shell
python3 merge_dataset.py --root [/path/to/save/dataset] --model resnet20 --dataset cifar10
```
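Conceptually, the merge step concatenates the per-GPU chunks into one file. A minimal self-contained sketch, assuming each chunk `.pt` stores an `(images, labels)` tensor pair (see `merge_dataset.py` for the actual on-disk format and naming):

```python
import os
import tempfile
import torch

def merge_chunks(chunk_paths, out_path):
    """Concatenate per-GPU chunk files into a single dataset file.

    Assumes each chunk .pt stores an (images, labels) tensor pair;
    the real merge_dataset.py may use a different layout.
    """
    images, labels = [], []
    for p in chunk_paths:
        x, y = torch.load(p)
        images.append(x)
        labels.append(y)
    merged = (torch.cat(images), torch.cat(labels))
    torch.save(merged, out_path)
    return merged

# Demo with two synthetic chunks standing in for per-GPU outputs.
tmp = tempfile.mkdtemp()
paths = []
for i in range(2):
    p = os.path.join(tmp, f"chunk{i}.pt")
    torch.save((torch.randn(5, 3, 32, 32), torch.randint(0, 10, (5,))), p)
    paths.append(p)

x, y = merge_chunks(paths, os.path.join(tmp, "merged.pt"))
print(x.shape)  # torch.Size([10, 3, 32, 32])
```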
We also provide pre-generated datasets for those who want to proceed to training without running generation themselves.
Dataset | ResNet-18 | ResNet-50 |
---|---|---|
TissueMNIST | TBR | TBR |
BloodMNIST | TBR | TBR |
PathMNIST | TBR | TBR |
OrganCMNIST | TBR | TBR |
Dataset | ResNet-20 | ResNet-56 | WRN-28-10 |
---|---|---|---|
SVHN | link | link | link |
CIFAR-10 | link | link | link |
The training script and the required files are placed under `train/`. Example usage:
```shell
python3 main.py --conf_path configs/cifar10_robust_student.hocon --advloss DFShieldLoss --model resnet20 --train_eps 4 --train_step_size 1 --eps 4 --step_size 1 --exp_name rn20_cifar10_dfshield --data_pth [path/to/dataset] --p_thresh 0.5 --agg_iter 20
python3 main.py --conf_path configs/svhn_robust_student.hocon --advloss DFShieldLoss --model resnet56 --train_eps 8 --train_step_size 2 --eps 8 --step_size 2 --exp_name rn56_svhn_dfshield --data_pth [path/to/dataset] --p_thresh 0.5 --agg_iter 20
```
We provide evaluation code for AutoAttack under `evaluation/`.
`run_AA.py` runs AutoAttack given the necessary information (path to the checkpoint, dataset, epsilon, etc.). Example usage can be found in `eval.sh`.
However, we suggest using `run_parse.py`, which automates this process by reading all the necessary information from the checkpoint path. We especially encourage this when you have multiple checkpoints waiting to be evaluated. It makes life easier :)
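The idea of recovering metadata from a checkpoint path can be illustrated with a toy snippet. The naming scheme below mimics the `--exp_name` values used earlier (e.g. `rn56_svhn_dfshield`); the actual convention parsed by `run_parse.py` may differ.

```python
import re

def parse_ckpt_path(path):
    """Toy version of the path-parsing idea behind run_parse.py.

    Assumes experiment names like 'rn56_svhn_dfshield'; the real
    script's naming convention may differ.
    """
    # cifar100 must precede cifar10 so the longer name wins.
    m = re.search(r"(?P<model>wrn\d+_\d+|rn\d+)_(?P<dataset>cifar100|cifar10|svhn)", path)
    if m is None:
        raise ValueError(f"unrecognized checkpoint path: {path}")
    return m.group("model"), m.group("dataset")

print(parse_ckpt_path("checkpoints/rn56_svhn_dfshield/best.pt"))
# ('rn56', 'svhn')
```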
```bibtex
@inproceedings{lee2024datafreeshield,
  title     = {DataFreeShield: Defending Adversarial Attacks without Training Data},
  author    = {Lee, Hyeyoon and Choi, Kanghyun and Kwon, Dain and Park, Sunjong and Jaiswal, Mayoore Selvarasa and Park, Noseong and Choi, Jonghyun and Lee, Jinho},
  booktitle = {International Conference on Machine Learning (ICML)},
  year      = {2024}
}
```