Skip to content

Latest commit

 

History

History
75 lines (47 loc) · 5.32 KB

README.md

File metadata and controls

75 lines (47 loc) · 5.32 KB

ARES 2.0: Robust Training for Image Classification

Abstract

This repository contains the code for adversarial training on classification models, which is derived from A Comprehensive Study on Robustness of Image Classification Models: Benchmarking and Rethinking, a Python library for adversarial machine learning research focusing on benchmarking adversarial robustness on image classification correctly and comprehensively. The project incorporates the widely adopted timm as its default classification library.

Major features

  • Integration with timm

    • Leverage various classification models from timm for adversarial training to achieve robustness across diverse model architectures.
  • State-of-the-art Models Available

    • Some of the SOTA models are available from the model zoo, which are trained with the corresponding settings.
  • Multiple Augmentations

    • Multiple augmentations are supported, including Mixup, Label Smoothing, EMA and so on.
  • Distributed training and testing

    • Pytorch distributed data-parallel training and testing are supported for faster training and testing.

Preparation

Dataset

  • We train our models with ImageNet dataset. Please download ImageNet dataset first. The directories to the training and evaluation dataset should be assigned to train_dir and eval_dir in the train_configs files.

Classification Model

  • Train classification models using timm or from your own model class.

Getting Started

  • We provide a command line interface to run adversarial training. For example, you can train a robust model of ResNet50 with the corresponding configuration:

    python -m torch.distributed.launch --nproc_per_node=<num-of-gpus-to-use> adversarial_training.py --configs=./train_configs/resnet50.yaml
  • For distributed training and testing, you can also refer to the run_train.sh for details.

Results

  • Evaluation of some classification models

    Attack settings: adversarial attack using PGD and autoattack with eps=4/255 under L $\infty$ norm.

    Dataset settings: randomly sampling 1000 data from ImageNet validation set.

    Model settings: adversarially trained on ImageNet training set.

    Model Name Clean Accuracy FGSM PGD100 AutoAttack RobustBench Checkpoints
    ResNet50 67.0 44.5 38.7 34.1 - Download
    ResNet101 71.0 51.3 46.5 42.2 - Download
    ResNet152 72.4 54.6 49.6 46.7 - Download
    Wide-ResNet50 70.5 51.8 44.6 39.3 - Download
    ConvNextS 77.3 60.3 56.9 54.3 - Download
    ConvNextB 77.2 62.2 59.0 56.8 55.82 Download
    ConvNextL 78.8 63.9 61.7 60.1 58.48 Download
    ViTS 70.7 51.3 47.5 43.7 - Download
    ViTB 74.7 55.9 52.2 49.7 - Download
    SwinS 76.6 61.5 58.4 55.6 - Download
    SwinB 76.6 63.2 60.2 57.3 56.16 Download
    SwinL 79.7 65.9 63.9 62.3 59.56 Download

Acknowledgement

Many thanks to these excellent open-source projects: