ImageNet training in PyTorch - Residual Attention Network

This repository implements training of the Residual Attention Network on the ImageNet dataset and provides pretrained weights.

Install

pip install 'git+ssh://git@github.com/phamquiluan/ResidualAttentionNetwork.git@v0.2.0'

Quickstart

import torch
from resattnet import resattnet56

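m = resattnet56(in_channels=3, num_classes=10)  # pretrained weights are loaded automatically

tensor = torch.randn(1, 3, 224, 224)  # dummy input batch: (batch, channels, height, width)

m.eval()  # inference mode (affects layers such as batch norm and dropout)
with torch.no_grad():  # gradients are not needed for a forward pass only
    output = m(tensor)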

print(output.shape)  # torch.Size([1, 10])
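The model returns one row of class scores per image, shape [1, num_classes]. Assuming these are unnormalized logits (typical for classifiers trained with cross-entropy loss, though this README does not say so), the sketch below turns one row, for example `output[0].tolist()`, into class probabilities and a predicted class. `softmax` and `predict` are hypothetical helpers, not part of `resattnet`.

```python
import math
from typing import List, Tuple


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax: subtract the largest logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict(logits: List[float]) -> Tuple[int, float]:
    """Return (class index, probability) of the most likely class.

    Assumes `logits` are raw, unnormalized scores for a single image.
    """
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return best, probs[best]


if __name__ == "__main__":
    # Hypothetical scores standing in for output[0].tolist() from the Quickstart.
    print(predict([0.1, 2.0, -1.0, 0.5]))
```

Subtracting the maximum logit leaves the result unchanged mathematically but keeps `math.exp` from overflowing on large scores.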

Pretrained Download

Download the resattnet56 weights pretrained on ImageNet-1K: link

Evaluation: Acc@1 (top-1) 77.024, Acc@5 (top-5) 93.574
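Acc@1 and Acc@5 are top-1 and top-5 accuracy: the share of evaluation images whose ground-truth label is the highest-scoring class, or among the five highest-scoring classes. The sketch below computes that metric in plain Python, assuming (not confirmed by this README) that the figures above are percentages, as in the PyTorch ImageNet example. `topk_accuracy` is a hypothetical helper, not the repository's evaluation code.

```python
from typing import Sequence


def topk_accuracy(scores: Sequence[Sequence[float]], targets: Sequence[int], k: int) -> float:
    """Percentage of samples whose true label is among the k highest-scoring classes."""
    if len(scores) != len(targets) or not targets:
        raise ValueError("scores and targets must be non-empty and the same length")
    correct = 0
    for row, target in zip(scores, targets):
        # Class indices sorted by score, highest first; keep the first k.
        topk = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        if target in topk:
            correct += 1
    return 100.0 * correct / len(targets)
```

For example, with scores `[[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5]]` and labels `[1, 1, 0]`, top-1 accuracy is about 33.3 and top-2 accuracy about 66.7.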

Training

To train a model, run main.py with the desired model architecture and the path to the ImageNet dataset:

python main.py -a resattnet56 [imagenet-folder with train and val folders]
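The dataset folder must contain `train` and `val` subfolders. Assuming `main.py` loads each split with torchvision's `ImageFolder`, as the PyTorch ImageNet example does (not confirmed here), each split must in turn hold one subfolder per class. A hypothetical pre-flight check for that layout:

```python
import os
import sys
from typing import Dict


def check_imagenet_layout(root: str) -> Dict[str, int]:
    """Count class subfolders in root/train and root/val.

    Hypothetical helper, not part of this repository. Raises
    FileNotFoundError if a split folder is missing and ValueError if the
    two splits disagree on the number of classes.
    """
    counts: Dict[str, int] = {}
    for split in ("train", "val"):
        split_dir = os.path.join(root, split)
        if not os.path.isdir(split_dir):
            raise FileNotFoundError(f"missing split folder: {split_dir}")
        # Assumed ImageFolder layout: one subfolder per class, e.g. train/<class>/<image>.
        with os.scandir(split_dir) as entries:
            counts[split] = sum(1 for entry in entries if entry.is_dir())
    if counts["train"] != counts["val"]:
        raise ValueError(f"class count mismatch between splits: {counts}")
    return counts


if __name__ == "__main__" and len(sys.argv) > 1:
    # Usage: python check_layout.py /path/to/imagenet
    print(check_imagenet_layout(sys.argv[1]))
```

For ImageNet-1K both counts should be 1000.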

Multi-processing Distributed Data Parallel Training

Use the NCCL backend for multi-processing distributed training; it currently provides the best distributed training performance.

Single node, multiple GPUs:

python main.py -a resattnet56 --dist-url 'tcp://127.0.0.1:FREEPORT' --dist-backend 'nccl' --multiprocessing-distributed --world-size 1 --rank 0 [imagenet-folder with train and val folders]

Replace FREEPORT with an unused TCP port on this machine.
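One common way to find an unused port for `FREEPORT` is to bind a socket to port 0 and let the operating system choose. The sketch below does that and also assembles the single-node command from this section; `find_free_port` and `single_node_command` are hypothetical helper names, not part of the repository.

```python
import shlex
import socket


def find_free_port(host: str = "127.0.0.1") -> int:
    """Ask the OS for an unused TCP port by binding to port 0."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind((host, 0))
        return s.getsockname()[1]


def single_node_command(data_dir: str, arch: str = "resattnet56") -> str:
    """Build the single-node, multi-GPU launch command with a free port filled in."""
    port = find_free_port()
    args = [
        "python", "main.py", "-a", arch,
        "--dist-url", f"tcp://127.0.0.1:{port}",
        "--dist-backend", "nccl",
        "--multiprocessing-distributed",
        "--world-size", "1", "--rank", "0",
        data_dir,
    ]
    # shlex.join (Python 3.8+) quotes arguments only when the shell needs it.
    return shlex.join(args)
```

The port is released when the socket closes, so another process could in principle claim it before training starts; if initialization fails with an address-in-use error, pick a new port.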