This repository provides code to classify pairs of images into two categories, Similar (1) and Dissimilar (0). The task is framed as image similarity and solved with a contrastive-learning approach that uses a custom contrastive loss. The implementation uses a Siamese network in an n-way k-shot setting.
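The custom loss in this repository is not reproduced here. As an illustration of the general idea, the sketch below implements the standard margin-based contrastive loss in PyTorch, using this README's label convention (1 = Similar, 0 = Dissimilar). The class name `ContrastiveLoss` and the default `margin=2.0` are illustrative assumptions, not values taken from the repository.

```python
import torch
import torch.nn.functional as F


class ContrastiveLoss(torch.nn.Module):
    """Standard margin-based contrastive loss (illustrative, not the repo's custom loss).

    Label convention: 1 = Similar, 0 = Dissimilar.
    The default margin is an assumed value chosen for illustration.
    """

    def __init__(self, margin: float = 2.0):
        super().__init__()
        self.margin = margin

    def forward(self, emb1: torch.Tensor, emb2: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
        # Euclidean distance between the two embeddings of each pair, shape (batch,)
        dist = F.pairwise_distance(emb1, emb2)
        label = label.float()
        # Similar pairs (label 1) are pulled together with a squared-distance penalty
        pos = label * dist.pow(2)
        # Dissimilar pairs (label 0) are penalised only while closer than the margin
        neg = (1 - label) * torch.clamp(self.margin - dist, min=0).pow(2)
        return (pos + neg).mean()
```

Dissimilar pairs that are already farther apart than the margin contribute zero loss, so the network is not pushed to separate them indefinitely.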
Requirements:

- Python 3.9
- PyTorch 1.10.2
- TorchVision 0.11.3
- numpy 1.22.3
- matplotlib 3.5.1
The Omniglot dataset is used. It is a collection of 1623 hand-drawn characters from 50 different alphabets. There are just 20 examples of each character, each drawn by a different person, and every image is a 105x105 grayscale image.

To prepare the data, clone this repo and extract the `images_background` and `images_evaluation` folders. Then run `DataGeneration.py` to create the `train.pickle` and `val.pickle` files and store them in the `data` folder. `train.pickle` contains characters from 30 alphabets, whereas `val.pickle` contains characters from the remaining 20 alphabets.
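The sketch below shows one way balanced Similar (1) / Dissimilar (0) training pairs could be drawn from such a split. It assumes, hypothetically, that one split's images are available as a NumPy array of shape `(n_classes, n_examples, height, width)`, for example `(n_chars, 20, 105, 105)`; the layout actually written by `DataGeneration.py` may differ, and `make_pairs` is a hypothetical helper, not part of this repository.

```python
import numpy as np


def make_pairs(images: np.ndarray, n_pairs: int, rng: np.random.Generator):
    """Sample alternating Similar (1) and Dissimilar (0) image pairs.

    Hypothetical layout: images.shape == (n_classes, n_examples, height, width).
    Requires at least 2 classes and 2 examples per class.
    """
    n_classes, n_examples = images.shape[:2]
    left, right, labels = [], [], []
    for i in range(n_pairs):
        c1 = rng.integers(n_classes)
        if i % 2 == 0:
            # Similar pair: two different drawings of the same character
            e1, e2 = rng.choice(n_examples, size=2, replace=False)
            c2, label = c1, 1
        else:
            # Dissimilar pair: an offset in [1, n_classes - 1] guarantees c2 != c1
            c2 = (c1 + rng.integers(1, n_classes)) % n_classes
            e1, e2 = rng.integers(n_examples, size=2)
            label = 0
        left.append(images[c1, e1])
        right.append(images[c2, e2])
        labels.append(label)
    return np.stack(left), np.stack(right), np.array(labels, dtype=np.int64)
```

Alternating the two cases keeps the classes balanced, which matters for a contrastive loss because Dissimilar pairs would otherwise vastly outnumber Similar ones when pairs are drawn uniformly.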
- The `SiameseNetwork` model class for n-way k-shot learning is defined in `Model.py`.
- To train the network, run `Training.py`.
- The average training loss is printed after every epoch.
- All hyperparameters that control training and testing of the model are set in `Training.py`.
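The model in `Model.py` and the loop in `Training.py` are not shown here. As a rough sketch of the structure these files describe, the code below defines a small twin-branch encoder with shared weights and a training epoch that prints the average loss. The layer sizes, `embedding_dim=64`, and the helper `train_one_epoch` are illustrative assumptions, not the repository's architecture or API.

```python
import torch
import torch.nn as nn


class SiameseNetwork(nn.Module):
    """Illustrative Siamese encoder for 1x105x105 images (NOT the architecture in Model.py)."""

    def __init__(self, embedding_dim: int = 64):
        super().__init__()
        # One encoder shared by both branches
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=5), nn.ReLU(), nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d(4), nn.Flatten(),
            nn.Linear(32 * 4 * 4, embedding_dim),
        )

    def forward(self, x1: torch.Tensor, x2: torch.Tensor):
        # Both images pass through the same weights: this sharing makes the network "Siamese"
        return self.encoder(x1), self.encoder(x2)


def train_one_epoch(model, loader, loss_fn, optimizer, epoch: int) -> float:
    """Hypothetical helper: one pass over (x1, x2, label) batches, printing the average loss."""
    model.train()
    total, batches = 0.0, 0
    for x1, x2, label in loader:
        optimizer.zero_grad()
        emb1, emb2 = model(x1, x2)
        loss = loss_fn(emb1, emb2, label)
        loss.backward()
        optimizer.step()
        total += loss.item()
        batches += 1
    avg = total / max(batches, 1)
    print(f"Epoch {epoch}: average loss {avg:.4f}")
    return avg
```

Passing `loss_fn` in as an argument keeps the loop independent of the particular contrastive loss, so a custom loss can be swapped in without touching the training code.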
[Figure: Image Similarity Scores, Image Comparisons 1 to 10]

[Figure: Results for Image Classification, Image Comparisons 1 to 3]
Among the 10 comparisons in the Image Similarity Scores sub-section, the image pairs in comparisons 1, 6, and 8 are the most similar and are therefore assigned the predicted label 1, as shown in the Results for Image Classification sub-section. In this way, the implementation frames the image similarity task as an image classification task.
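How the repository converts similarity scores into labels is not specified here. A common approach, sketched below under that assumption, is to threshold the embedding distance for pair classification and, for n-way k-shot evaluation, to assign each query to the class whose k support embeddings are closest on average. `predict_similar`, `n_way_k_shot_predict`, and the threshold value are hypothetical.

```python
import torch
import torch.nn.functional as F


def predict_similar(emb1: torch.Tensor, emb2: torch.Tensor, threshold: float) -> torch.Tensor:
    """Label each pair 1 (Similar) if its embedding distance is below `threshold`, else 0."""
    dist = F.pairwise_distance(emb1, emb2)
    return (dist < threshold).long()


def n_way_k_shot_predict(query: torch.Tensor, support: torch.Tensor) -> torch.Tensor:
    """Nearest-class prediction by mean distance to the k support embeddings.

    query:   (q, d) embeddings of the query images
    support: (n, k, d) embeddings, k examples for each of n classes
    returns: (q,) predicted class indices in [0, n)
    """
    n, k, d = support.shape
    # Distances from every query to every support example: (q, n*k) reshaped to (q, n, k)
    dists = torch.cdist(query, support.reshape(n * k, d)).reshape(-1, n, k)
    # Average over the k shots, then choose the closest class
    return dists.mean(dim=2).argmin(dim=1)
```

With k = 1 the second function reduces to one-shot nearest-neighbour classification; averaging distances over the shots is one simple choice, and alternatives such as the minimum distance are equally plausible.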