This repository contains the implementation of the paper POODLE: Improving Few-shot Learning via Penalizing Out-of-Distribution Samples.
Duong H. Le*, Khoi D. Nguyen*, Khoi Nguyen, Quoc-Huy Tran, Rang Nguyen, Binh-Son Hua (NeurIPS 2021)
TLDR: We leverage samples from distractor classes or randomly generated noise to improve the generalization of few-shot learners.
If you find our paper/code helpful, please cite our paper:
@inproceedings{
le2021poodle,
title={{POODLE}: Improving Few-shot Learning via Penalizing Out-of-Distribution Samples},
author={Duong Hoang Le and Khoi Duc Nguyen and Khoi Nguyen and Quoc-Huy Tran and Rang Nguyen and Binh-Son Hua},
booktitle={Thirty-Fifth Conference on Neural Information Processing Systems},
year={2021},
url={https://openreview.net/forum?id=wEvO8BCqZcm}
}
In this work, we propose to use out-of-distribution samples, i.e., unlabeled samples coming from outside the target classes, to improve few-shot learning. Specifically, we exploit the easily available out-of-distribution samples to drive the classifier to avoid irrelevant features by maximizing the distance from prototypes to out-of-distribution samples while minimizing that of in-distribution samples (i.e., support, query data). Our approach is simple to implement, agnostic to feature extractors, lightweight without any additional cost for pre-training, and applicable to both inductive and transductive settings. Extensive experiments on various standard benchmarks demonstrate that the proposed method consistently improves the performance of pretrained networks with different architectures.
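As a rough illustration of the pull/push idea described above, here is a minimal pure-Python sketch. It is not the objective used in the paper or in this repository. The prototype-as-class-mean, the squared Euclidean distance, the hinge margin on out-of-distribution samples, and the helper names (`prototypes`, `pull_push_loss`) are all assumptions chosen for illustration.

```python
from typing import List, Sequence

Vector = Sequence[float]


def sq_dist(a: Vector, b: Vector) -> float:
    """Squared Euclidean distance between two feature vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def prototypes(features: List[Vector], labels: List[int], n_classes: int) -> List[List[float]]:
    """Class prototypes as the mean support feature per class (assumes every class has >= 1 sample)."""
    dim = len(features[0])
    sums = [[0.0] * dim for _ in range(n_classes)]
    counts = [0] * n_classes
    for feat, label in zip(features, labels):
        counts[label] += 1
        for i, value in enumerate(feat):
            sums[label][i] += value
    return [[s / counts[c] for s in sums[c]] for c in range(n_classes)]


def pull_push_loss(
    protos: List[List[float]],
    in_feats: List[Vector],
    in_labels: List[int],
    ood_feats: List[Vector],
    margin: float = 1.0,
) -> float:
    """Hypothetical illustrative loss, not the paper's exact objective."""
    # Pull: mean squared distance from each in-distribution sample to its own prototype.
    pull = sum(sq_dist(f, protos[y]) for f, y in zip(in_feats, in_labels)) / len(in_feats)
    # Push: hinge penalty whenever an OOD sample lies within `margin` of its nearest prototype,
    # so minimizing the loss moves prototypes away from OOD samples (up to the margin).
    push = sum(
        max(0.0, margin - min(sq_dist(f, p) for p in protos)) for f in ood_feats
    ) / len(ood_feats)
    return pull + push
```

With prototypes at (0, 1) and (4, 1), an OOD sample at (2, 1) sits outside the margin and adds nothing, while one at (0.5, 1) is close to a prototype and increases the loss. The hinge keeps the push term bounded; an unbounded "maximize distance" term would let the loss diverge.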
- Download these zipped files and put them into ./data.
- (Optional) Download pretrained checkpoints here and extract them to ./results.
- Run init.sh to preprocess all data.
After these steps, your folder should be organized as follows:
results/
├── cub/
│   ├── resnet12/
│   │   ├── student_0/
│   │   ├── student_1/
│   │   ├── checkpoint.pth.tar
│   │   └── model_best.pth.tar
│   └── resnet12_ssl/
├── mini/
│   ├── resnet12/
│   ├── resnet12_ssl/
│   └── ...
└── tiered/
    ├── resnet12/
    ├── resnet12_ssl/
    └── ...
data/
├── images/
│   ├── n0153282900000005.jpg
│   ├── n0153282900000006.jpg
│   └── ...
├── tiered-imagenet/
│   ├── data/
│   ├── class_names.txt
│   └── ...
├── CUB_200_100/
│   ├── attributes/
│   ├── images/
│   └── ...
└── split/
    ├── mini/
    │   ├── train.csv
    │   ├── val.csv
    │   └── test.csv
    ├── tiered/
    │   ├── train.csv
    │   ├── val.csv
    │   └── test.csv
    └── cub/
        ├── train.csv
        ├── val.csv
        └── test.csv
assets/
configs/
src/
...
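Before training, it can help to confirm that ./data matches the layout above. The helper below is a hypothetical sketch, not part of the repository: it only checks the concrete paths listed in the data/ tree and skips the "..." entries.

```python
from pathlib import Path
from typing import List

# Paths copied from the data/ tree above; "..." entries are not checked.
EXPECTED_DATA_PATHS: List[str] = [
    "images",
    "tiered-imagenet/data",
    "tiered-imagenet/class_names.txt",
    "CUB_200_100/attributes",
    "CUB_200_100/images",
] + [
    f"split/{dataset}/{split}.csv"
    for dataset in ("mini", "tiered", "cub")
    for split in ("train", "val", "test")
]


def missing_data_paths(data_root: str = "./data") -> List[str]:
    """Return the expected entries that do not exist under data_root."""
    root = Path(data_root)
    return [rel for rel in EXPECTED_DATA_PATHS if not (root / rel).exists()]


if __name__ == "__main__":
    missing = missing_data_paths()
    if missing:
        print("Missing under ./data:")
        for rel in missing:
            print("  " + rel)
    else:
        print("./data layout looks complete.")
```

Running it from the repository root prints any missing entries; an empty result only means the listed paths exist, not that their contents are valid.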
To run the code:
- Reconfigure the arguments in run.sh (please read the comments to adjust the dataset/architecture). Quick guideline:
  - To train the model, remove the option --evaluate.
  - To train the model with rotation loss, add --do-ssl.
  - Note that knowledge distillation is done automatically after training finishes.
- Set save_path in [resnet12 | mobilenet | wideres | ...].config to different checkpoints for the simple, rot, and rot+kd baselines, for example:
  - simple: set save_path to ./results/mini/resnet12.
  - rot: set save_path to ./results/mini/resnet12_ssl.
  - rot+kd: set save_path to ./results/mini/resnet12_ssl/student_1.
- Run bash run.sh
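The two settings from the guideline above (the mode flags and the save_path per baseline) can be summarized in a small helper. The function names are hypothetical, and extending the save_path pattern from the mini/resnet12 examples to other datasets and architectures is an assumption based on the results/ layout shown earlier. Whether --do-ssl is also needed at evaluation time is not stated, so this sketch only adds it when training.

```python
from typing import Dict, List

# Hypothetical mapping from baseline name to the subfolder under ./results/<dataset>/,
# following the mini/resnet12 examples in the guideline (assumed to generalize).
BASELINE_SUBDIRS: Dict[str, str] = {
    "simple": "{arch}",
    "rot": "{arch}_ssl",
    "rot+kd": "{arch}_ssl/student_1",
}


def save_path(baseline: str, dataset: str = "mini", arch: str = "resnet12") -> str:
    """Return the save_path value for a baseline (assumed pattern, see lead-in)."""
    if baseline not in BASELINE_SUBDIRS:
        raise ValueError(f"unknown baseline: {baseline!r}")
    return f"./results/{dataset}/" + BASELINE_SUBDIRS[baseline].format(arch=arch)


def mode_flags(train: bool, rotation_loss: bool = False) -> List[str]:
    """Extra run.sh flags implied by the guideline: --evaluate only when not training,
    --do-ssl only when training with rotation loss."""
    if not train:
        return ["--evaluate"]
    return ["--do-ssl"] if rotation_loss else []


if __name__ == "__main__":
    for name in BASELINE_SUBDIRS:
        print(f"{name:7s} -> {save_path(name)}")
```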
Our implementation is based on the official implementations of SimpleShot and TIM.