This is a simplified from-scratch PyTorch implementation of the Vision Transformer (ViT) with detailed steps (refer to `model.py`).
- The default network is a scaled-down version of the original ViT architecture from the ViT Paper.
- Has only 200k-800k parameters, depending on the embedding dimension (the original ViT-Base has 86 million).
- Tested on MNIST, FashionMNIST, SVHN, CIFAR10, and CIFAR100 datasets.
- Uses a smaller patch size of 4 (a minimal patch-embedding sketch is shown after this list).
- Can be scaled to larger datasets by increasing the model size (embedding dimension, depth) and the patch size.
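For reference, here is a minimal sketch of the patch-embedding step a ViT of this size typically uses. The class name `PatchEmbedding` and its constructor arguments are illustrative and may not match `model.py` exactly.

```python
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and project each to embed_dim.
    Illustrative sketch; names and arguments may differ from model.py."""
    def __init__(self, in_channels=1, patch_size=4, embed_dim=64):
        super().__init__()
        # A conv with kernel_size == stride == patch_size is equivalent to
        # flattening each patch and applying a shared linear projection.
        self.proj = nn.Conv2d(in_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):                      # x: (B, C, H, W)
        x = self.proj(x)                       # (B, embed_dim, H/ps, W/ps)
        x = x.flatten(2).transpose(1, 2)       # (B, num_patches, embed_dim)
        return x

# Example: a 28x28 MNIST image with patch size 4 -> 7*7 = 49 tokens of size 64
tokens = PatchEmbedding(1, 4, 64)(torch.randn(2, 1, 28, 28))
print(tokens.shape)  # torch.Size([2, 49, 64])
```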
Run commands (also available in `scripts.sh`):
| Dataset | Run command | Test Acc (%) |
|---|---|---|
| MNIST | `python main.py --dataset mnist --epochs 100` | 99.5 |
| Fashion MNIST | `python main.py --dataset fmnist` | 92.3 |
| SVHN | `python main.py --dataset svhn --n_channels 3 --image_size 32 --embed_dim 128` | 96.2 |
| CIFAR10 | `python main.py --dataset cifar10 --n_channels 3 --image_size 32 --embed_dim 128` | 86.3 (82.5 w/o RandAug) |
| CIFAR100 | `python main.py --dataset cifar100 --n_channels 3 --image_size 32 --embed_dim 128` | 59.6 (55.8 w/o RandAug) |
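The gap between the CIFAR numbers with and without RandAug comes from RandAugment during training. Below is a minimal sketch of a CIFAR training transform with torchvision's `RandAugment`; the exact augmentation parameters and normalization statistics used by `main.py` may differ.

```python
import torchvision.transforms as T

# Illustrative CIFAR training pipeline; num_ops, magnitude, and the crop
# padding are assumptions, not necessarily what main.py uses.
train_transform = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.RandAugment(num_ops=2, magnitude=9),
    T.ToTensor(),
    T.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])
```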
| Config | MNIST and FMNIST | SVHN and CIFAR |
|---|---|---|
| Input Size | 1 × 28 × 28 | 3 × 32 × 32 |
| Patch Size | 4 | 4 |
| Sequence Length | 7 × 7 = 49 | 8 × 8 = 64 |
| Embedding Size | 64 | 128 |
| Parameters | 210k | 820k |
| Number of Layers | 6 | 6 |
| Number of Heads | 4 | 4 |
| Forward Multiplier | 2 | 2 |
| Dropout | 0.1 | 0.1 |
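As a quick sanity check for the sequence lengths in the table: splitting an image into non-overlapping patches yields (image_size / patch_size)² tokens. The helper below is purely illustrative.

```python
def num_patch_tokens(image_size: int, patch_size: int) -> int:
    """Number of patch tokens when an image is split into non-overlapping patches."""
    assert image_size % patch_size == 0, "image size must be divisible by patch size"
    return (image_size // patch_size) ** 2

print(num_patch_tokens(28, 4))  # 49  (MNIST / FashionMNIST)
print(num_patch_tokens(32, 4))  # 64  (SVHN / CIFAR)
```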