EPVT: Environment-aware Prompt Vision Transformer for Domain Generalization in Skin Lesion Recognition
Official PyTorch implementation of the MICCAI 2023 paper and a domain generalization benchmark for skin lesion recognition.
[arXiv] [BibTex] [MICCAI paper] [Journal paper]
Abstract: Skin lesion recognition using deep learning has made remarkable progress, and there is an increasing need to deploy these systems in real-world scenarios. However, recent research has revealed that deep neural networks for skin lesion recognition may overly depend on disease-irrelevant image artifacts (e.g., dark corners, dense hairs), leading to poor generalization in unseen environments. To address this issue, we propose a novel domain generalization method called EPVT, which embeds prompts into the vision transformer to collaboratively learn knowledge from diverse domains. Concretely, EPVT leverages a set of domain prompts, each of which acts as a domain expert capturing domain-specific knowledge, and a shared prompt for general knowledge over the entire dataset. To facilitate knowledge sharing and the interaction of different prompts, we introduce a domain prompt generator that enables low-rank multiplicative updates between the domain prompts and the shared prompt. A domain mixup strategy is additionally devised to reduce the co-occurring artifacts in each domain, which allows for more flexible decision margins and mitigates the issue of incorrectly assigned domain labels. Experiments on four out-of-distribution datasets and six different biased ISIC datasets demonstrate the superior generalization ability of EPVT in skin lesion recognition across various environments.
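To make the prompt interaction concrete, here is a minimal PyTorch sketch of a domain prompt generator performing a low-rank multiplicative update between a shared prompt and per-domain prompts. All class/attribute names, shapes, and the exact update rule are illustrative assumptions, not the paper's actual implementation (see `domainbed/` for that).

```python
import torch
import torch.nn as nn

class DomainPromptGenerator(nn.Module):
    """Illustrative sketch (not the official code) of a low-rank
    multiplicative update between domain prompts and a shared prompt."""

    def __init__(self, num_domains=4, prompt_len=10, embed_dim=768, rank=4):
        super().__init__()
        # Shared prompt: general knowledge over the whole dataset.
        self.shared_prompt = nn.Parameter(torch.empty(prompt_len, embed_dim))
        # One prompt per training domain ("domain expert").
        self.domain_prompts = nn.Parameter(torch.empty(num_domains, prompt_len, embed_dim))
        # Low-rank factors: project the shared prompt down to `rank` dims and back up.
        self.down = nn.Linear(embed_dim, rank, bias=False)
        self.up = nn.Linear(rank, embed_dim, bias=False)
        nn.init.normal_(self.shared_prompt, std=0.02)
        nn.init.normal_(self.domain_prompts, std=0.02)

    def forward(self, domain_idx: int) -> torch.Tensor:
        # Multiplicative update: modulate the chosen domain prompt with a
        # low-rank transform of the shared prompt, then add the shared prompt
        # back so general knowledge is always retained.
        gate = self.up(self.down(self.shared_prompt))  # (prompt_len, embed_dim)
        return self.domain_prompts[domain_idx] * gate + self.shared_prompt
```

The returned tensor has the same shape as a single prompt and can be prepended to the ViT token sequence for samples from that domain.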
- The extended version has been released at https://github.com/SiyuanYan1/PLDG. We propose a universal domain generalization framework for medical image classification without relying on any domain labels.
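The domain mixup strategy from the abstract can also be sketched in a few lines: samples from a batch are interpolated across domains, and the hard domain labels become soft mixtures, which dilutes co-occurring artifacts and softens possibly wrong domain assignments. The function name and Beta-mixing scheme below are assumptions following standard mixup practice, not the exact EPVT code.

```python
import torch

def domain_mixup(x, domain_onehot, alpha=0.2):
    """Illustrative sketch: mix images and their one-hot domain labels
    with a Beta(alpha, alpha) coefficient, standard-mixup style."""
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    perm = torch.randperm(x.size(0))          # pair each sample with a random partner
    mixed_x = lam * x + (1 - lam) * x[perm]   # interpolated images
    mixed_d = lam * domain_onehot + (1 - lam) * domain_onehot[perm]  # soft domain labels
    return mixed_x, mixed_d
```

The soft domain labels can then supervise the domain-prompt branch in place of the original hard assignments.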
Create the environment and install packages
conda create -n env_name python=3.8 -y
conda activate env_name
pip install -r requirements.txt
ISIC2019: download ISIC2019 training dataset from here
Derm7pt: download Derm7pt Clinical and Derm7pt Dermoscopic dataset from here
PH2: download the PH2 dataset from here
PAD: download the PAD-UFES-20 dataset from here
To pre-process the ISIC2019 dataset and construct the artifact-based domain generalization training dataset, first modify the path names in the pre-processing file accordingly, then run:
python data_proc/grouping.py
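Conceptually, the grouping step sorts images into per-artifact domain folders. The sketch below shows the idea; the CSV name, column names, and label values are assumptions for illustration only, and the actual logic lives in `data_proc/grouping.py`.

```python
import csv
import shutil
from pathlib import Path

def group_by_artifact(csv_path, image_dir, out_dir):
    """Illustrative sketch: copy each image into <out_dir>/<artifact>/<class>/,
    using a (hypothetical) annotation CSV with columns image, artifact, label."""
    out = Path(out_dir)
    with open(csv_path, newline="") as f:
        for row in csv.DictReader(f):
            domain = row["artifact"] or "clean"  # e.g. dark_corner, gel_bubble
            label = "mel" if row["label"] == "melanoma" else "ben"
            dest = out / domain / label
            dest.mkdir(parents=True, exist_ok=True)
            src = Path(image_dir) / (row["image"] + ".jpg")
            if src.exists():
                shutil.copy(src, dest / src.name)
```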
Put each dataset in a folder under the domainbed/data directory as follows:
data
├── ISIC2019_train
│   ├── clean
│   │   ├── ben
│   │   ├── mel
│   ├── dark_corner
│   ├── gel_bubble
│   ├── ...
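Before training, it can help to sanity-check that the expected domain folders exist under the data root. This small helper is not part of the repo; the default path and folder names below simply mirror the tree above.

```python
from pathlib import Path

def check_layout(data_root="domainbed/data/ISIC2019_train",
                 domains=("clean", "dark_corner", "gel_bubble")):
    """Return the list of expected domain folders missing under data_root.
    An empty list means the layout matches the tree shown above."""
    root = Path(data_root)
    return [d for d in domains if not (root / d).is_dir()]
```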
The processed ISIC2019 dataset and the 4 OOD testing datasets are available on Google Drive. Please refer to our paper and its supplementary material for more details about these datasets.
You can find them in the repo https://github.com/alceubissoto/artifact-generalization-skin.
Our benchmark is built on DomainBed; please refer to the DomainBed README for more details on running jobs.
# Training EPVT on ISIC2019
CUDA_VISIBLE_DEVICES=0 python -m domainbed.scripts.train_epvt --data_dir=./domainbed/data/ --steps 1501 --dataset SKIN --test_env 0 --algorithm DoPrompt_group_decompose --output_dir results/exp \
--hparams '{"lr": 5e-6, "lr_classifier": 5e-5, "batch_size": 26, "wd_classifier": 1e-5, "prompt_dim": 10}' --exp 'prompt_final_vis' --ood_vis True
# Test EPVT on four OOD datasets
CUDA_VISIBLE_DEVICES=0 python -m domainbed.scripts.test_epvt --model_name 'prompt_final_vis.pkl'
# Training ERM baseline on ISIC2019
CUDA_VISIBLE_DEVICES=1 python -m domainbed.scripts.train_erm --data_dir=./domainbed/data/ --steps 1501 --dataset SKIN --test_env 0 --algorithm ERM \
--output_dir results/exp --hparams '{"lr": 5e-6, "lr_classifier": 5e-5, "batch_size": 26, "wd_classifier": 1e-5}' --exp 'erm_baseline'
# Test ERM on four OOD datasets
CUDA_VISIBLE_DEVICES=1 python -m domainbed.scripts.test_erm --model_name 'erm_baseline.pkl'
@inproceedings{yan2023epvt,
  title={EPVT: Environment-Aware Prompt Vision Transformer for Domain Generalization in Skin Lesion Recognition},
  author={Yan, Siyuan and Liu, Chi and Yu, Zhen and Ju, Lie and Mahapatra, Dwarikanath and Mar, Victoria and Janda, Monika and Soyer, Peter and Ge, Zongyuan},
  booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
  pages={249--259},
  year={2023},
  organization={Springer}
}
@article{yan2024prompt,
  title={Prompt-driven Latent Domain Generalization for Medical Image Classification},
  author={Yan, Siyuan and Liu, Chi and Yu, Zhen and Ju, Lie and Mahapatra, Dwarikanath and Betz-Stablein, Brigid and Mar, Victoria and Janda, Monika and Soyer, Peter and Ge, Zongyuan},
  journal={arXiv preprint arXiv:2401.03002},
  year={2024}
}
This code is built on DomainBed, DoPrompt, and DG_SKIN. We thank the authors for sharing their code.