This repository contains the code for the paper "Benchmarking Dependence Measures to Prevent Shortcut Learning in Medical Imaging".

berenslab/dependence-measures-medical-imaging

Benchmarking Dependence Measures to Prevent Shortcut Learning in Medical Imaging

This repository contains the code to reproduce the results from the paper "Benchmarking Dependence Measures to Prevent Shortcut Learning in Medical Imaging", which was accepted to the 15th International Workshop on Machine Learning in Medical Imaging (MLMI 2024).

We present a comprehensive performance comparison of dependency measures to prevent shortcut learning in medical imaging.

Installation

Set up a Python environment with Python 3.10. Then download the repository, activate the environment, and install the dependencies with

cd dependence-measures-medical-imaging
pip install --editable . 

This installs the code in src as an editable package and all the dependencies in requirements.txt.

Organization of the repo

  • configs: Configuration files for all experiments.
  • scripts: Slurm scripts for model training and hyperparameter sweeps.
  • src: Main source code to run the experiments.
    • data: PyTorch datasets and scripts/info to download data.
    • models: PyTorch Lightning module to train models that prevent shortcut learning with different methods.
    • eval: Model evaluation with kNN classifiers and embedding plots.
  • train.py: Main training script for k-fold cross-validation (and optional hyperparameter sweeps).
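
As a rough illustration of what k-fold cross-validation means for the training script, the sketch below splits sample indices into k folds and uses each fold once for validation. It is a minimal standalone example, not the repository's implementation; the function name `k_fold_indices` and its parameters are hypothetical.

```python
import random


def k_fold_indices(n_samples, k, seed=0):
    """Yield (train_idx, val_idx) pairs for k-fold cross-validation.

    Hypothetical helper for illustration; not part of this repository.
    """
    if not 2 <= k <= n_samples:
        raise ValueError("k must be between 2 and n_samples")
    indices = list(range(n_samples))
    # Shuffle once with a fixed seed so the folds are reproducible.
    random.Random(seed).shuffle(indices)
    # Spread the remainder so fold sizes differ by at most one.
    fold_sizes = [n_samples // k + (1 if i < n_samples % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        val_idx = indices[start:start + size]
        train_idx = indices[:start] + indices[start + size:]
        yield train_idx, val_idx
        start += size


if __name__ == "__main__":
    for fold, (train_idx, val_idx) in enumerate(k_fold_indices(10, k=5)):
        print(f"fold {fold}: {len(train_idx)} train, {len(val_idx)} val")
```

Each sample lands in exactly one validation fold, so every model in the cross-validation is evaluated on data it did not see during training.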

Usage

Download public datasets

First, download the two datasets, Morpho-MNIST and CheXpert. For Morpho-MNIST, we provide a download script:

python src/data/download_data/load_morpho_mnist.py -d path-to-dataset-directory -v True

CheXpert requires registration, so we provide instructions on how to register and download the dataset in load_chexpert.txt.

Training

To run k-fold cross-validation for one method, pass a config file to the training script. For example, the command-line call for MINE on the Morpho-MNIST dataset is

python src/train.py -tc configs/morpho-mnist/mine.yaml

Note: Adjust dataset_path in the config file to point to your local copy of the dataset.
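
For orientation, MINE (mutual information neural estimation) is commonly formulated as maximizing the Donsker-Varadhan lower bound on mutual information with a neural network. The notation below (variables X and Z, statistics network T_θ) is introduced here for illustration; this README does not state which variables the repository passes to the estimator or which variant of the bound it uses.

```latex
I(X; Z) \;\ge\; \sup_{\theta} \;
  \mathbb{E}_{P_{XZ}}\!\left[ T_\theta(x, z) \right]
  - \log \mathbb{E}_{P_X \otimes P_Z}\!\left[ e^{T_\theta(x, z)} \right]
```

The first expectation is taken over jointly drawn pairs and the second over pairs drawn independently from the marginals, so the bound grows as X and Z become more dependent.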

To run the code on a Slurm cluster, we provide a bash script:

sbatch scripts/train.sh configs/morpho-mnist/mine.yaml

Run hyperparameter sweeps (wandb)

Initialize the sweep with

python src/utils/sweep_init.py -sc configs/example_sweep.yaml

This prints the sweep_id, which you can pass to the sweep script to start multiple runs (10 in this case) on a Slurm cluster:

sh scripts/sweep.sh 10 configs/morpho-mnist/mine.yaml sweep_id

Evaluation

To evaluate a trained model and compute the confusion matrix of kNN classifier accuracy, run

python src/eval/knn_classifier.py -c model_config -ckpts model_checkpoints
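
To make the idea of a matrix of kNN accuracies concrete, the sketch below fits a kNN classifier on each named group of embedding dimensions and scores it against each label. This is an assumption about the general shape of such an evaluation, not the code in src/eval/knn_classifier.py; the names `knn_predict`, `knn_accuracy`, `accuracy_matrix`, and the "task"/"attr" keys are hypothetical.

```python
import math
from collections import Counter


def knn_predict(train_x, train_y, query, k=3):
    """Return the majority label among the k nearest training embeddings."""
    nearest = sorted(range(len(train_x)), key=lambda i: math.dist(train_x[i], query))[:k]
    votes = Counter(train_y[i] for i in nearest)
    return votes.most_common(1)[0][0]


def knn_accuracy(train_x, train_y, test_x, test_y, k=3):
    """Fraction of test embeddings whose kNN prediction matches the label."""
    correct = sum(knn_predict(train_x, train_y, x, k) == y for x, y in zip(test_x, test_y))
    return correct / len(test_y)


def accuracy_matrix(train_sub, train_labels, test_sub, test_labels, k=3):
    """Accuracy for every (subspace, label) pair.

    train_sub / test_sub: dict mapping a subspace name to a list of vectors.
    train_labels / test_labels: dict mapping a label name to a list of labels.
    """
    return {
        (s, l): knn_accuracy(train_sub[s], train_labels[l], test_sub[s], test_labels[l], k)
        for s in train_sub
        for l in train_labels
    }


if __name__ == "__main__":
    # Toy data (made up): each 1-D subspace encodes exactly one of the labels.
    train_sub = {"task": [[0.0], [0.1], [1.0], [1.1]], "attr": [[0.0], [1.0], [0.1], [1.1]]}
    train_labels = {"task": [0, 0, 1, 1], "attr": [0, 1, 0, 1]}
    test_sub = {"task": [[0.02], [1.02]], "attr": [[1.02], [0.02]]}
    test_labels = {"task": [0, 1], "attr": [1, 0]}
    for pair, acc in accuracy_matrix(train_sub, train_labels, test_sub, test_labels, k=1).items():
        print(pair, acc)
```

In such a matrix, high off-diagonal accuracy would indicate that a group of embedding dimensions still carries information about a label it is not supposed to encode.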

To generate the embedding plots from the paper, run

python src/eval/embeddings.py -cfgs list_of_model_configs -ckpts list_of_model_checkpoints

Cite

If you find our code or paper useful, please consider citing this work. Until the MLMI 2024 proceedings are published, please cite the preprint:

@misc{mueller2024benchmarking,
    title = {Benchmarking Dependence Measures to Prevent Shortcut Learning in Medical Imaging},
    author = {M\"uller, Sarah and Fay, Louisa and Koch, Lisa M. and Gatidis, Sergios and K\"ustner, Thomas and Berens, Philipp},
    year={2024},
    eprint={},
    archivePrefix={arXiv},
}
