
Overview

SubpopBench is a living PyTorch benchmark suite containing datasets and algorithms for subpopulation shift, as introduced in Change is Hard: A Closer Look at Subpopulation Shift (Yang et al., ICML 2023).

Contents

Currently we support 13 datasets and ~20 algorithms that span different learning strategies. Feel free to send us a PR to add your algorithm / dataset for subpopulation shift.

Available Algorithms

The currently available algorithms are:

Send us a PR to add your algorithm! Our implementations use the hyper-parameter grids described here.

Available Datasets

The currently available datasets are:

Send us a PR to add your dataset! You can follow the dataset format described here.

Model Architectures & Pretraining Methods

The supported image architectures are:

The supported text architectures are:

Note that text architectures are only compatible with CivilComments.

Subpopulation Shift Scenarios

We characterize four basic types of subpopulation shift using our framework, and categorize each dataset into its most dominant shift type (a brief formalization follows the list below).

  • Spurious Correlations (SC): a certain attribute $a$ is spuriously correlated with the label $y$ in training but not in testing.
  • Attribute Imbalance (AI): certain attributes are sampled with a much smaller probability than others in $p_{\text{train}}$, but not in $p_{\text{test}}$.
  • Class Imbalance (CI): certain (minority) classes are underrepresented in $p_{\text{train}}$, but not in $p_{\text{test}}$.
  • Attribute Generalization (AG): certain attributes can be totally missing in $p_{\text{train}}$, but present in $p_{\text{test}}$.
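
As a rough formalization (our shorthand, not notation taken verbatim from the paper), write $p_{\text{train}}$ and $p_{\text{test}}$ for the joint distributions over labels $y$ and attributes $a$:

  • SC: $a$ and $y$ are dependent under $p_{\text{train}}$ but (approximately) independent under $p_{\text{test}}$.
  • AI: $p_{\text{train}}(a)$ is heavily skewed toward some attribute values, while $p_{\text{test}}(a)$ is not.
  • CI: $p_{\text{train}}(y)$ is heavily skewed toward some classes, while $p_{\text{test}}(y)$ is not.
  • AG: some attribute $a'$ has $p_{\text{train}}(a') = 0$ but $p_{\text{test}}(a') > 0$.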

Evaluation Metrics

We include a variety of metrics for a thorough evaluation from different aspects (a minimal NumPy sketch of two of them follows the list):

  • Average Accuracy & Worst Accuracy
  • Average Precision & Worst Precision
  • Average F1-score & Worst F1-score
  • Adjusted Accuracy
  • Balanced Accuracy
  • AUROC & AUPRC
  • Expected Calibration Error (ECE)
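
As a concrete illustration, worst-group accuracy and balanced accuracy can be computed from per-sample labels, predictions, and group indices. This is a minimal sketch in plain NumPy, not the benchmark's own evaluation code:

import numpy as np

def worst_group_accuracy(y_true, y_pred, groups):
    # Lowest per-group accuracy; a group is typically a (label, attribute) pair.
    accs = [(y_pred[groups == g] == y_true[groups == g]).mean() for g in np.unique(groups)]
    return float(min(accs))

def balanced_accuracy(y_true, y_pred):
    # Mean of per-class recalls, so minority classes weigh as much as majority ones.
    recalls = [(y_pred[y_true == c] == c).mean() for c in np.unique(y_true)]
    return float(np.mean(recalls))

# Toy example with made-up predictions:
y_true = np.array([0, 0, 1, 1, 1, 0])
y_pred = np.array([0, 1, 1, 1, 0, 0])
groups = np.array([0, 0, 1, 1, 2, 2])
print(worst_group_accuracy(y_true, y_pred, groups))  # 0.5
print(balanced_accuracy(y_true, y_pred))             # ~0.67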

Model Selection Criteria

We highlight the impact of whether attributes are known in (1) the training set and (2) the validation set, where the former is specified by --train_attr in train.py and the latter by the model selection criterion. A few important selection criteria are listed below, with a small sketch contrasting the validation-based ones after the list:

  • OracleWorstAcc: Picks the best test-set worst-group accuracy (oracle)
  • ValWorstAccAttributeYes: Picks the best val-set worst-group accuracy (attributes known in validation)
  • ValWorstAccAttributeNo: Picks the best val-set worst-class accuracy (attributes unknown in validation; group degenerates to class)
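
The contrast between the two validation-based criteria can be seen in a small sketch (our illustration with hypothetical numbers, not the repo's selection code), assuming each checkpoint has been summarized by its per-group and per-class validation accuracies:

# Hypothetical per-checkpoint validation summaries (illustrative numbers only).
checkpoints = [
    {"step": 1000, "val_group_accs": [0.92, 0.60, 0.88, 0.70], "val_class_accs": [0.90, 0.82]},
    {"step": 2000, "val_group_accs": [0.85, 0.72, 0.83, 0.74], "val_class_accs": [0.84, 0.78]},
]

# ValWorstAccAttributeYes: attributes known in validation -> rank by worst *group* accuracy.
best_with_attrs = max(checkpoints, key=lambda c: min(c["val_group_accs"]))

# ValWorstAccAttributeNo: attributes unknown -> groups degenerate to classes,
# so rank by worst *class* accuracy instead.
best_without_attrs = max(checkpoints, key=lambda c: min(c["val_class_accs"]))

print(best_with_attrs["step"], best_without_attrs["step"])  # 2000 1000

Note that the two criteria can select different checkpoints, which is exactly the gap the paper highlights when validation attributes are unavailable.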

Getting Started

Installation

Prerequisites

Run the following commands to clone this repo and create the Conda environment:

git clone git@github.com:YyzHarry/SubpopBench.git
cd SubpopBench/
conda env create -f environment.yml
conda activate subpop_bench

Downloading Data

Download the original datasets and generate corresponding metadata in your data_path:

python -m subpopbench.scripts.download --data_path <data_path> --download

For MIMICNoFinding, CheXpertNoFinding, CXRMultisite, and MIMICNotes, see MedicalData.md for instructions on downloading these datasets manually.

Code Overview

Main Files

  • train.py: main training script
  • sweep.py: launch a sweep with all selected algorithms (provided in subpopbench/learning/algorithms.py) on all subpopulation shift datasets
  • collect_results.py: collect sweep results to automatically generate result tables (as in the paper)

Main Arguments

  • train.py:
    • --dataset: name of chosen subpopulation dataset
    • --algorithm: algorithm to run
    • --train_attr: whether attributes are known during training (yes or no)
    • --data_dir: data path
    • --output_dir: output path
    • --output_folder_name: output folder name (under output_dir) for the current run
    • --hparams_seed: seed for different hyper-parameters
    • --seed: seed for different runs
    • --stage1_folder & --stage1_algo: arguments for two-stage algorithms
    • --image_arch & --text_arch: model architecture and source of initial model weights (text architectures only compatible with CivilComments)
  • sweep.py:
    • --n_hparams: how many hparams to run for each <dataset, algorithm> pair
    • --best_hp & --n_trials: after sweeping hparams, fix the best hparams and run multiple trials with different seeds

Usage

Train a single model (with unknown attributes)

python -m subpopbench.train \
       --algorithm <algo> \
       --dataset <dset> \
       --train_attr no \
       --data_dir <data_path> \
       --output_dir <output_path> \
       --output_folder_name <output_folder_name>

Train a model using 2-stage methods, e.g., DFR (with known attributes)

python -m subpopbench.train \
       --algorithm DFR \
       --dataset <dset> \
       --train_attr yes \
       --data_dir <data_path> \
       --output_dir <output_path> \
       --output_folder_name <output_folder_name> \
       --stage1_folder <stage1_model_folder> \
       --stage1_algo <stage1_algo>

Launch a sweep with different hparams (with unknown attributes)

python -m subpopbench.sweep launch \
       --algorithms <...> \
       --dataset <...> \
       --train_attr no \
       --n_hparams <num_of_hparams> \
       --n_trials 1

Launch a sweep after fixing hparam with different seeds (with unknown attributes)

python -m subpopbench.sweep launch \
       --algorithms <...> \
       --dataset <...> \
       --train_attr no \
       --best_hp \
       --input_folder <...> \
       --n_trials <num_of_trials>

Collect the results of your sweep

python -m subpopbench.scripts.collect_results --input_dir <...>

Updates

Acknowledgements

This code is partly based on the open-source implementations from DomainBed, spurious_feature_learning, and multi-domain-imbalance.

Citation

If you find this code or idea useful, please cite our work:

@inproceedings{yang2023change,
  title={Change is Hard: A Closer Look at Subpopulation Shift},
  author={Yang, Yuzhe and Zhang, Haoran and Katabi, Dina and Ghassemi, Marzyeh},
  booktitle={International Conference on Machine Learning},
  year={2023}
}

Contact

If you have any questions, feel free to contact us through email (yuzhe@mit.edu & haoranz@mit.edu) or GitHub issues. Enjoy!