PyTorch implementation of triplet networks for metric learning
This package requires PyTorch 1.4.0 and TorchVision 0.5.0.
- GPU implementation of online triplet loss, written in the style of a PyTorch loss module
- Implements the 1-1 sampling strategy defined in [1]
- Random semi-hard and fixed semi-hard sampling
- UMAP visualization of the results
- Training strategy for fitting a classifier after the embeddings have been learned
- Stratified sampling strategy for the batches
- MNIST dataset used as an example
networks.py
- ConvNet - base network that embeds images into vectors and predicts labels
loss.py
- OnlineTripletLoss - triplet loss class for embeddings
- NegativeTripletSelector - class for selecting the negative sample from the batch based on the sampling strategy.
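The idea behind online triplet loss with random semi-hard negatives can be sketched as follows. This is a minimal illustration, not the code in loss.py: the function names `select_semi_hard_triplets` and `online_triplet_loss`, the Euclidean distance, and the default margin are all assumptions made for the example. A negative counts as semi-hard here when it is farther from the anchor than the positive but still inside the margin.

```python
import torch
import torch.nn.functional as F


def select_semi_hard_triplets(embeddings, labels, margin):
    """Return a (T, 3) tensor of (anchor, positive, negative) batch indices.

    Hypothetical helper: for each anchor-positive pair, pick one random
    negative with d(a, p) < d(a, n) < d(a, p) + margin, if any exists.
    """
    with torch.no_grad():  # triplet selection itself needs no gradient
        dist = torch.cdist(embeddings, embeddings)  # (B, B) Euclidean distances
    same = labels.unsqueeze(0) == labels.unsqueeze(1)
    not_self = ~torch.eye(len(labels), dtype=torch.bool, device=labels.device)
    triplets = []
    for a, p in (same & not_self).nonzero(as_tuple=False).tolist():
        d_ap = dist[a, p]
        mask = (~same[a]) & (dist[a] > d_ap) & (dist[a] < d_ap + margin)
        candidates = mask.nonzero(as_tuple=False).flatten()
        if candidates.numel() == 0:
            continue  # no semi-hard negative for this pair
        pick = torch.randint(candidates.numel(), (1,))
        triplets.append([a, p, candidates[pick].item()])
    return torch.tensor(triplets, dtype=torch.long).reshape(-1, 3)


def online_triplet_loss(embeddings, labels, margin=1.0):
    """Mean triplet margin loss over the triplets mined from the batch."""
    triplets = select_semi_hard_triplets(embeddings, labels, margin)
    if triplets.numel() == 0:
        return embeddings.sum() * 0.0  # zero loss that stays on the graph
    anchor = embeddings[triplets[:, 0]]
    positive = embeddings[triplets[:, 1]]
    negative = embeddings[triplets[:, 2]]
    return F.triplet_margin_loss(anchor, positive, negative, margin=margin)
```

Mining inside `torch.no_grad()` and then indexing the original embeddings keeps the selection step cheap while gradients still flow through the chosen triplets.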
train.py
- TripletTrainer - class for training the network with triplet loss and, if required, a classifier afterwards.
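One way to train a classifier after the embeddings are learned is to freeze the embedding network and fit a small head on top of it. The sketch below assumes that setup. The function name `train_classifier_head`, the SGD optimizer, and the cross-entropy loss are illustrative choices and may differ from what TripletTrainer actually does.

```python
import torch
from torch import nn


def train_classifier_head(embedding_net, head, loader, epochs=1, lr=1e-2):
    """Fit `head` on top of a frozen `embedding_net` (hypothetical helper)."""
    embedding_net.eval()
    for param in embedding_net.parameters():
        param.requires_grad_(False)  # assumption: embeddings stay fixed

    optimizer = torch.optim.SGD(head.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(epochs):
        for images, targets in loader:
            with torch.no_grad():
                features = embedding_net(images)
            loss = loss_fn(head(features), targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    return head
```

Any iterable of `(images, targets)` batches works as `loader`, including a `torch.utils.data.DataLoader`.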
utils.py
- make_weights_for_balanced_classes - assigns a weight to every sample in the dataset for balanced batch sampling.
- save_embedding_umap - saves UMAP plots of the training set and test set embeddings.
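Per-sample weights for balanced batches are commonly computed as the inverse of each class frequency. The sketch below assumes that scheme. The name `balanced_sample_weights` is hypothetical, and the function in utils.py may compute its weights differently.

```python
from collections import Counter


def balanced_sample_weights(labels):
    """Return one weight per sample, equal to 1 / (size of its class).

    Hypothetical helper: with these weights every class has the same
    total probability of being drawn, regardless of its size.
    """
    counts = Counter(labels)
    return [1.0 / counts[label] for label in labels]


# Example: class 0 has three samples, class 1 has one.
weights = balanced_sample_weights([0, 0, 0, 1])
```

Weights of this kind can be passed to `torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)` and the sampler handed to a `DataLoader`.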
config.json
- hyperparameters for the training
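A JSON file of hyperparameters can be read into a typed object so that typos in key names fail early. The key names below (`margin`, `batch_size`, `learning_rate`, `epochs`) and their defaults are hypothetical placeholders, not the actual contents of config.json.

```python
import json
from dataclasses import dataclass, fields


@dataclass
class TrainConfig:
    # Hypothetical hyperparameters; the real config.json may use other keys.
    margin: float = 1.0
    batch_size: int = 64
    learning_rate: float = 1e-3
    epochs: int = 10


def parse_config(text):
    """Parse a JSON string into a TrainConfig, rejecting unknown keys."""
    raw = json.loads(text)
    known = {f.name for f in fields(TrainConfig)}
    unknown = set(raw) - known
    if unknown:
        raise KeyError("unknown config keys: " + ", ".join(sorted(unknown)))
    return TrainConfig(**raw)


# Keys that are left out fall back to the dataclass defaults.
config = parse_config('{"margin": 0.5, "batch_size": 128}')
```

To load from disk, read the file first, for example `parse_config(open("config.json").read())`.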
[1] Theoretical Guarantees of Deep Embedding Losses Under Label Noise