
Spectra456/MLT_Pytorch


Multi-task learning in PyTorch

A simple multi-task learning example on two different datasets. All experiments use EfficientNet-B0 with an input size of 3x32x32. P.S. Sorry for my code style, I didn't have enough time.

CIFAR-10 (Single model)

Best top-1 accuracy: 88.7%

Top-1 accuracy and loss curves: [figures: Top-1 Accuracy, Loss]

Confusion matrix: [figure]
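Top-1 accuracy and the confusion matrix can both be computed from raw model scores, as in this minimal pure-Python sketch. The function names are hypothetical, and the repository's own evaluation code may differ. Rows are true classes and columns are predicted classes, which is the same convention scikit-learn's confusion_matrix uses.

```python
from typing import List, Sequence


def argmax(scores: Sequence[float]) -> int:
    """Index of the highest score, i.e. the top-1 prediction."""
    return max(range(len(scores)), key=lambda i: scores[i])


def top1_accuracy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Fraction of samples whose top-1 prediction equals the true label."""
    correct = sum(argmax(row) == y for row, y in zip(logits, labels))
    return correct / len(labels)


def confusion_matrix(preds: Sequence[int], labels: Sequence[int], num_classes: int) -> List[List[int]]:
    """cm[true][pred] = number of samples of class `true` predicted as `pred`."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for p, y in zip(preds, labels):
        cm[y][p] += 1
    return cm


logits = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]
labels = [1, 0, 0]
preds = [argmax(row) for row in logits]
print(top1_accuracy(logits, labels))       # 0.6666666666666666
print(confusion_matrix(preds, labels, 2))  # [[1, 1], [0, 1]]
```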

FashionRGB (Single model)

This dataset comes from one of my previous test tasks. It has 5 classes (Blouse, Dress, Jeans, Skirt, Tank), with 5k RGB images in the train set and 5k images in the validation set.

Sample images: Blouse, Dress, Jeans, Skirt, Tank [figures]

Best top-1 accuracy: 77.4%

Top-1 accuracy and loss curves: [figures: Accuracy (top-1), Loss]

Confusion matrix: [figure]

Multi-Task Model

Best top-1 accuracy: FashionRGB 80.3%, CIFAR-10 80.6%

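I tried several multi-task learning methods, but there was always some imbalance between the datasets, because CIFAR-10 has 10x more images in its train set. Here I used my own solution, where each training step flips two independent coins that decide whether the CIFAR-10 head and the FashionRGB head are trained: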

import random

for i in range(len(train_loader_cifar)):
    # Fair coin flip: train the Cifar10 head on this step or skip it
    if random.getrandbits(1):
        ...  # train Cifar10 layer
    # Independent coin flip for the FashionRGB head
    if random.getrandbits(1):
        ...  # train FashionRGB layer

This helps keep the metrics of the two datasets balanced.
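The random per-step schedule can also be written as a runnable, framework-free sketch. Several assumptions here are not taken from the repository. A batch is drawn from a loader only when its head is trained. Both loaders are restarted when exhausted, since the smaller FashionRGB loader runs out long before the step budget. train_cifar_step and train_fashion_step are hypothetical callbacks that would run the forward pass, loss, backward pass and optimizer step for the corresponding head.

```python
import random
from typing import Any, Callable, Iterable, Iterator, Tuple


def endless(batches: Iterable) -> Iterator:
    """Yield batches forever, restarting a re-iterable source (list, DataLoader) each pass."""
    while True:
        produced = False
        for batch in batches:
            produced = True
            yield batch
        if not produced:  # empty or one-shot source: stop instead of spinning forever
            return


def coin_flip_epoch(
    cifar_loader: Iterable,
    fashion_loader: Iterable,
    train_cifar_step: Callable[[Any], None],
    train_fashion_step: Callable[[Any], None],
    num_steps: int,
    rng: random.Random,
) -> Tuple[int, int]:
    """One epoch of the coin-flip schedule; returns (cifar_updates, fashion_updates)."""
    cifar_iter, fashion_iter = endless(cifar_loader), endless(fashion_loader)
    cifar_updates = fashion_updates = 0
    for _ in range(num_steps):  # num_steps plays the role of len(train_loader_cifar)
        if rng.getrandbits(1):  # fair coin for the CIFAR-10 head
            train_cifar_step(next(cifar_iter))
            cifar_updates += 1
        if rng.getrandbits(1):  # independent fair coin for the FashionRGB head
            train_fashion_step(next(fashion_iter))
            fashion_updates += 1
    return cifar_updates, fashion_updates


seen_cifar, seen_fashion = [], []
counts = coin_flip_epoch(
    cifar_loader=range(100),   # stand-in for 100 CIFAR-10 batches
    fashion_loader=range(10),  # stand-in for 10 FashionRGB batches (10x fewer)
    train_cifar_step=seen_cifar.append,
    train_fashion_step=seen_fashion.append,
    num_steps=100,
    rng=random.Random(0),
)
print(counts)
```

With these stand-ins, each head gets roughly half of the 100 steps. Under this sketch, the 10-batch FashionRGB loader is therefore replayed about five times per epoch, while CIFAR-10 contributes about half of its batches.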

Top-1 accuracy and loss curves: [figures: Accuracy (top-1), Loss]

Confusion matrices: [figures: FashionRGB, CIFAR-10]

Conclusion

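| Model | Top-1 accuracy | Change vs. single model |
|---|---|---|
| CIFAR-10 (single model) | 88.7% | baseline |
| FashionRGB (single model) | 77.4% | baseline |
| Multi-task (CIFAR-10) | 80.6% | -8.1 pp |
| Multi-task (FashionRGB) | 80.3% | +2.9 pp |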
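The changes listed for the multi-task results are plain subtraction in percentage points (multi-task minus single model). Here is a small check using the accuracies reported above.

```python
# Top-1 accuracy (%) as reported in the Conclusion.
single = {"CIFAR-10": 88.7, "FashionRGB": 77.4}
multi = {"CIFAR-10": 80.6, "FashionRGB": 80.3}


def delta_pp(task: str) -> float:
    """Multi-task minus single-model accuracy, in percentage points."""
    return round(multi[task] - single[task], 1)


for task in single:
    print(f"{task}: {delta_pp(task):+.1f} pp")
# CIFAR-10: -8.1 pp
# FashionRGB: +2.9 pp
```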

Run

  1. Download the FashionRGB dataset from Google Drive.

  2. Unzip the archive into the dataset folder of this project.

  3. Train the single-model FashionRGB baseline: python train_simple.py --dataset FashionRGB

  4. Train the single-model CIFAR-10 baseline: python train_simple.py --dataset Cifar10

  5. Train the multi-task model: python train_multi.py
     All available arguments are defined inside these two Python scripts.

  6. (Optional) Launch TensorBoard to browse my experiments:

tensorboard --logdir logs/old
