This repository is heavily based on Oscar Knagg's few-shot learning implementation github.com/oscarknagg/few-shot. It focuses on applying simple but strong Prototypical Networks to fine-grained classification tasks.
Main contributions of this repository:
- A practical application of a few-shot machine learning system, ready for real-world fine-grained classification problems.
- Transfer-learning ready, so training can be quick. ImageNet pre-trained models are used by default, but any network, even a non-CNN one, can be used.
- Demonstrated in a fairly difficult Kaggle competition that an ImageNet pre-trained model works well as the core model of Prototypical Networks.
Unlike the very clean original implementation, this repository contains some quick-and-dirty code to present a sample solution to the Kaggle competition "Humpback Whale Identification".
Some of the submission code borrows functions from Radek Osmulski's GitHub repository.
I'd like to express sincere appreciation to both Oscar Knagg and Radek Osmulski. Thank you.
Prototypical Networks were proposed in the paper Prototypical Networks for Few-shot Learning (Snell et al.). The method computes each class prototype as the central point (the mean of the embedded samples) of that class in Euclidean space; test samples can then be classified simply by measuring their distances to the class prototypes.
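The prototype-and-distance idea can be sketched in a few lines of NumPy. This is a minimal illustration, not this repository's code: it assumes the embeddings have already been produced by some network, uses squared Euclidean distance as in the paper, and the function names (`compute_prototypes`, `squared_euclidean`, `classify`) are hypothetical.

```python
import numpy as np


def compute_prototypes(support_embeddings: np.ndarray,
                       support_labels: np.ndarray,
                       n_classes: int) -> np.ndarray:
    """Return an (n_classes, dim) array whose row k is the mean embedding of class k."""
    dim = support_embeddings.shape[1]
    prototypes = np.zeros((n_classes, dim))
    for k in range(n_classes):
        prototypes[k] = support_embeddings[support_labels == k].mean(axis=0)
    return prototypes


def squared_euclidean(queries: np.ndarray, prototypes: np.ndarray) -> np.ndarray:
    """Pairwise squared distances, shape (n_query, n_classes), via broadcasting."""
    diff = queries[:, None, :] - prototypes[None, :, :]
    return (diff ** 2).sum(axis=2)


def classify(queries: np.ndarray, prototypes: np.ndarray) -> np.ndarray:
    """Assign each query to the class of its nearest prototype."""
    return squared_euclidean(queries, prototypes).argmin(axis=1)


# Toy 2-D "embeddings": two support samples per class.
support = np.array([[0.0, 0.0], [0.2, 0.0], [5.0, 5.0], [5.0, 4.8]])
labels = np.array([0, 0, 1, 1])
protos = compute_prototypes(support, labels, n_classes=2)
queries = np.array([[0.1, 0.3], [4.0, 4.5]])
print(classify(queries, protos).tolist())  # [0, 1]
```

In a real model the embeddings would come from the trained network, and the same distance computation would typically be done in PyTorch on the GPU.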
In Prototypical Networks, the model learns all of the non-linearity. The embedding network encapsulates everything between the non-linear inputs and the linear outputs, and the system design and training algorithm are what make this possible.
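Why the outputs can be called linear is clearest from a short worked equation, following the reinterpretation given in the original paper for squared Euclidean distance. Here $f_\phi(\mathbf{x})$ denotes the embedding of input $\mathbf{x}$ and $\mathbf{c}_k$ the prototype of class $k$ (notation assumed from the paper):

$$
-\lVert f_\phi(\mathbf{x}) - \mathbf{c}_k \rVert^2
= -f_\phi(\mathbf{x})^\top f_\phi(\mathbf{x}) + 2\,\mathbf{c}_k^\top f_\phi(\mathbf{x}) - \mathbf{c}_k^\top \mathbf{c}_k
$$

The first term is the same for every class $k$, so it does not change which class wins. What remains is a linear classifier on the embedding, with weights $\mathbf{w}_k = 2\mathbf{c}_k$ and bias $b_k = -\mathbf{c}_k^\top \mathbf{c}_k$. All of the non-linearity therefore has to be learned by $f_\phi$.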
Figure from the original paper. Colored circles: training samples,
What the Prototypical Networks scheme trains the model to learn is a metric in Euclidean space, which makes it quite a handy tool for real-world engineering.
Here is a summary of its nice traits for machine learning practitioners:
- Explainable: It discriminates classes in multi-dimensional Euclidean space, which many old-fashioned engineers are familiar with. This matters because it lets us explain the model to non-ML project stakeholders and finally bring it into real-world projects. It is not even cosine distance, just conventional Euclidean distance.
- Customizable: Any model can be used, so it is applicable to any problem; the model is simply trained to map input data points to output points in Euclidean space so that all classes can be distinguished by plain old distance.
- Few-shot ready: It works on long-tail problems where only a very small number of samples are available for some classes, and it copes with sample imbalance between classes. This is (almost, as of now) proven in the Kaggle competition "Humpback Whale Identification".
- Easy to train: (I think) it is almost free from the difficult and computationally intensive hard mining that selects harder and harder training samples as training progresses.
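The metric is learned through the objective from the original paper: class probabilities are a softmax over negative squared distances to the prototypes, and training minimizes the negative log-probability of the true class. The NumPy sketch below only shows the forward computation; in practice it would be written in PyTorch so gradients reach the embedding network. `proto_loss` is a hypothetical name, not a function from this repository.

```python
import numpy as np


def proto_loss(query_embeddings: np.ndarray,
               query_labels: np.ndarray,
               prototypes: np.ndarray) -> float:
    """Mean negative log-probability of the true class, where
    p(y = k | x) = softmax_k(-||f(x) - c_k||^2)."""
    diff = query_embeddings[:, None, :] - prototypes[None, :, :]
    logits = -(diff ** 2).sum(axis=2)                     # (n_query, n_classes)
    logits = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    rows = np.arange(len(query_labels))
    return float(-log_probs[rows, query_labels].mean())


prototypes = np.array([[0.0, 0.0], [3.0, 0.0]])
queries = np.array([[0.0, 0.0], [3.0, 0.0]])
good = proto_loss(queries, np.array([0, 1]), prototypes)  # queries sit on their own prototypes
bad = proto_loss(queries, np.array([1, 0]), prototypes)   # labels swapped on purpose
print(good < bad)  # True
```

Minimizing this loss pulls each query toward its own class prototype and away from the others, which is exactly the Euclidean metric the model ends up encoding.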
Prerequisites and setup:
This project is written in Python 3.6 and PyTorch and assumes you have a GPU.
- Install dl-cliche from GitHub (excuse me, this is my almost-private library for repeating cliché code):
  pip install git+https://github.com/daisukelab/dl-cliche.git@master --upgrade
- Install albumentations.
- Edit the DATA_PATH variable in config.py to point to the location where you downloaded your copy of the dataset from Kaggle.
- Open and run app/whale/Example_Humpback_Whale_Identification.ipynb to reproduce the whale identification solution.
- Very simple design for both the network and the training algorithm.
- All non-linearity can be learned by the model.
- Independent of model design: we can choose whichever network best fits our problem.
- Embeddings produced by the learned model are simple data points in multi-dimensional Euclidean space, where distances between points are simple to calculate.
- Training is easier compared to, for example, Siamese networks.
- Less sensitive to class imbalance, because the training algorithm always picks an equal number of samples from each of the k classes.
- Test-time augmentation (TTA) can be naturally applied both when computing prototypes and when computing test samples' embeddings.
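The balanced, episodic sampling behind the class-imbalance point can be sketched with the standard library alone. This is an illustrative k-way n-shot sampler, not the repository's sampler: the name `sample_episode` and the `q_queries` parameter are hypothetical, and it simply skips classes with too few images (how the repository handles single-image classes is not shown here).

```python
import random
from collections import defaultdict


def sample_episode(labels, k_way, n_shot, q_queries, rng=random):
    """Pick k_way classes, then n_shot support and q_queries query indices per class.

    Every class in the episode contributes the same number of samples,
    no matter how many images it has in the whole dataset.
    """
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    # Only classes with enough images for both support and query are eligible.
    eligible = [c for c, idxs in by_class.items() if len(idxs) >= n_shot + q_queries]
    classes = rng.sample(eligible, k_way)
    support, query = [], []
    for c in classes:
        picked = rng.sample(by_class[c], n_shot + q_queries)
        support += picked[:n_shot]
        query += picked[n_shot:]
    return classes, support, query


# Heavily imbalanced toy labels: class "a" has 50 images, "b" 3, "c" 2.
labels = ["a"] * 50 + ["b"] * 3 + ["c"] * 2
classes, support, query = sample_episode(labels, k_way=2, n_shot=1, q_queries=1,
                                         rng=random.Random(0))
print(len(support), len(query))  # 2 2
```

Each episode feeds k_way × (n_shot + q_queries) images through the network, which is why memory caps k: with k_way=20, n_shot=1 and, say, one query per class, that is 40 images per episode.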
But
- The number of classes ProtoNets can train on in one episode is mainly limited by memory size. For example, a single GTX 1080 Ti can handle up to 20 classes for 1-shot with 384x384 images.
- As far as I have tried, a larger k-way (more classes per episode) results in better performance, but k is limited by memory, as noted above.
- Augmentation matters.
- Image size also matters.
- TTA pushes score.
- and more...
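One simple way to apply TTA in this setting (an assumption here, not necessarily what this repository does) is to average the embeddings of several augmented views of each image, then use those averages both for computing prototypes and for embedding test samples. In the sketch below, `tta_embed`, `identity`, `hflip` and `toy_embed` are hypothetical stand-ins; `toy_embed` replaces a real CNN.

```python
import numpy as np


def tta_embed(embed_fn, image, augmentations):
    """Average the embeddings of several augmented views of one image."""
    views = [aug(image) for aug in augmentations]
    return np.mean([embed_fn(v) for v in views], axis=0)


def identity(img):
    return img


def hflip(img):
    # Horizontal flip of an (H, W) array.
    return img[:, ::-1]


def toy_embed(img):
    # Stand-in for a real network: row means followed by column means.
    return np.concatenate([img.mean(axis=1), img.mean(axis=0)])


img = np.arange(6, dtype=float).reshape(2, 3)
emb = tta_embed(toy_embed, img, [identity, hflip])
print(emb.tolist())  # [1.0, 4.0, 2.5, 2.5, 2.5]
```

Because prototypes are just means of embeddings, TTA-averaged embeddings drop into the prototype computation with no other changes.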
- Original paper: Prototypical Networks for Few-shot Learning (Snell et al.).
- Oscar Knagg's article: Theory and concepts
- Oscar Knagg's article: Discussion of implementation details
- Radek Osmulski's post on Kaggle discussion: [LB 0.760] Fastai Starter Pack
- Radek Osmulski's GitHub repository: Humpback Whale Identification Competition Starter Pack