If you make use of this code, please cite the following paper (and give us a star ✨):
```
@InProceedings{sgada2021,
    author    = {Akkaya, Ibrahim Batuhan and Altinel, Fazil and Halici, Ugur},
    title     = {Self-Training Guided Adversarial Domain Adaptation for Thermal Imagery},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
    month     = {June},
    year      = {2021},
    pages     = {4322-4331}
}
```
This repository contains the official implementation of the "Self-training Guided Adversarial Domain Adaptation For Thermal Imagery" paper (accepted to the CVPR 2021 Perception Beyond the Visible Spectrum (PBVS) workshop).
- Python 3.8.5
- PyTorch 1.6.0
To install the environment using Conda:
```
$ conda env create -f requirements_conda.yml
```
This command creates a Conda environment named `sgada`. The environment includes all packages necessary for training the SGADA method. After installing the environment, activate it using the command below:
```
$ conda activate sgada
```
Before running the training code, make sure that the `DATASETDIR` environment variable is set to your dataset directory:
```
$ export DATASETDIR="/path/to/dataset/dir"
```
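For reference, a minimal sketch of how a training script can pick this variable up (illustrative only; the repository's own data-loading code performs the actual lookup):
```python
import os

# Fail fast with a clear message instead of a confusing
# file-not-found error later in training.
dataset_dir = os.environ.get("DATASETDIR")
if dataset_dir is None:
    raise RuntimeError('Set DATASETDIR first, e.g. export DATASETDIR="/path/to/dataset/dir"')

sgada_root = os.path.join(dataset_dir, "sgada_data")
```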
- Download the FLIR ADAS dataset: Link
- Download the MS-COCO dataset:
- After you have downloaded the datasets, extract them to `DATASETDIR`.
- Crop the annotated objects (bicycle, car, and person classes only) using the command below:
```
(sgada) $ python utils/prepare_dataset.py
```
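`utils/prepare_dataset.py` performs this cropping for you. For intuition only, here is a hedged sketch of the idea, assuming COCO-style annotation files; the annotation path, image paths, and output layout below are illustrative, not the script's actual interface:
```python
from PIL import Image
from pycocotools.coco import COCO

# Illustrative sketch: keep only the three classes SGADA uses.
KEEP = {"bicycle", "car", "person"}

coco = COCO("annotations/instances.json")  # hypothetical annotation file
cat_ids = coco.getCatIds(catNms=sorted(KEEP))
for ann in coco.loadAnns(coco.getAnnIds(catIds=cat_ids)):
    img_info = coco.loadImgs(ann["image_id"])[0]
    x, y, w, h = ann["bbox"]  # COCO boxes are [x, y, width, height]
    image = Image.open(img_info["file_name"]).convert("RGB")
    crop = image.crop((int(x), int(y), int(x + w), int(y + h)))
    cls = coco.loadCats(ann["category_id"])[0]["name"]
    crop.save(f"{cls}/{ann['id']}.jpg")  # hypothetical output layout
```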
After the preparation steps, your dataset folder should have the following structure:
```
DATASETDIR
└── sgada_data
    ├── flir
    │   ├── train
    │   │   ├── bicycle
    │   │   ├── car
    │   │   └── person
    │   ├── val
    │   │   ├── bicycle
    │   │   ├── car
    │   │   └── person
    │   ├── test_wconf_wdomain_weights.txt
    │   └── validation_wconf_wdomain_weights.txt
    └── mscoco
        ├── train
        │   ├── bicycle
        │   ├── car
        │   └── person
        └── val
            ├── bicycle
            ├── car
            └── person
```
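Since the `train`/`val` directories use a class-per-folder layout, they can be loaded with torchvision's `ImageFolder`. A minimal sketch, with placeholder transforms (not necessarily the exact ones SGADA uses):
```python
import os
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # placeholder size, not SGADA's exact setting
    transforms.ToTensor(),
])

root = os.path.join(os.environ["DATASETDIR"], "sgada_data")
# Source (MS-COCO, visible spectrum) and target (FLIR, thermal) domains.
source_train = datasets.ImageFolder(os.path.join(root, "mscoco", "train"), transform)
target_train = datasets.ImageFolder(os.path.join(root, "flir", "train"), transform)
print(source_train.classes)  # ['bicycle', 'car', 'person']
```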
`test_wconf_wdomain_weights.txt` and `validation_wconf_wdomain_weights.txt` files can be found here. Place them under `DATASETDIR/sgada_data/flir/`. These files have the fields below:
```
filePath, classifierPrediction, classifierConfidence, discriminatorPrediction, discriminatorConfidence, sampleWeight
```
If you want to generate the pseudo-labelling files yourself, they should follow the same field order. To obtain the confidences and predictions, you can follow the training scheme in ADDA.PyTorch-resnet.
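A hedged parsing sketch for these files, assuming comma-separated lines in the field order above (the `PseudoLabel` helper is hypothetical, not part of this repository):
```python
from typing import List, NamedTuple

class PseudoLabel(NamedTuple):
    file_path: str
    cls_pred: int
    cls_conf: float
    disc_pred: int
    disc_conf: float
    weight: float

def read_pseudo_labels(path: str) -> List[PseudoLabel]:
    # Assumes one comma-separated record per line, in the documented order.
    rows = []
    with open(path) as f:
        for line in f:
            if not line.strip():
                continue
            fp, cp, cc, dp, dc, w = [s.strip() for s in line.split(",")]
            rows.append(PseudoLabel(fp, int(cp), float(cc),
                                    int(dp), float(dc), float(w)))
    return rows
```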
1. [Optional] Follow the source-only training scheme in ADDA.PyTorch-resnet and save the model file. If you want to use this source-only model file, skip to Step 4.
2. Download the model file trained on the source-only dataset: Link
3. Extract the compressed file.
4. To train SGADA, run the command below.
```
(sgada) $ python core/sgada_domain.py \
          --trained [PATH] \
          --lr [LR] --d_lr [D_LR] --batch_size [BS] \
          --lam [LAM] --thr [THR] --thr_domain [THR_DOMAIN] \
          --device cuda:[GPU_ID]
```
| Parameter Name | Type | Definition |
|---|---|---|
| `[PATH]` | str | Path to the source-only model file generated in Step 1 or downloaded in Step 2 |
| `[LR]` | float | Learning rate |
| `[D_LR]` | float | Discriminator learning rate |
| `[BS]` | int | Batch size |
| `[LAM]` | float | Trade-off parameter |
| `[THR]` | float | Classifier threshold |
| `[THR_DOMAIN]` | float | Discriminator threshold |
| `[GPU_ID]` | int | GPU device ID |
For example:
```
(sgada) $ python core/sgada_domain.py \
          --trained /mnt/sgada_model_files/best_model.pt \
          --lr 1e-5 --d_lr 1e-3 --batch_size 32 \
          --lam 0.25 --thr 0.79 --thr_domain 0.87 \
          --device cuda:3
```
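The two thresholds gate which target samples are treated as reliable pseudo-labels: a sample is kept only when both the classifier and the discriminator are sufficiently confident. A simplified sketch of such a selection rule, reusing the hypothetical `PseudoLabel` records from the parsing sketch above (the repository's actual criterion in `core/sgada_domain.py` may combine the confidences differently):
```python
def select_pseudo_labels(rows, thr=0.79, thr_domain=0.87):
    # Keep target samples whose classifier confidence passes THR and
    # whose discriminator confidence passes THR_DOMAIN. Illustrative only.
    return [r for r in rows
            if r.cls_conf >= thr and r.disc_conf >= thr_domain]

confident = select_pseudo_labels(
    read_pseudo_labels("sgada_data/flir/validation_wconf_wdomain_weights.txt"))
```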
This repo is mostly based on: