MoPro: Webly Supervised Learning with Momentum Prototypes (Salesforce Research)

This is a PyTorch implementation of the MoPro paper (Blog post):

@article{MoPro,
	title={MoPro: Webly Supervised Learning with Momentum Prototypes},
	author={Junnan Li and Caiming Xiong and Steven C.H. Hoi},
	journal={ICLR},
	year={2021}
}

Requirements:

Training

This implementation currently supports only multi-GPU DistributedDataParallel training, which is faster and simpler.

To perform webly-supervised training of a ResNet-50 model on WebVision V1.0 using a 4-GPU or 8-GPU machine, run:

python train.py \
  --data [WebVision folder] \
  --exp-dir experiment/MoPro \
  --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0
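The launch flags above share their names with those in PyTorch's ImageNet training example. In that convention, `--world-size` counts nodes, `--rank` identifies the node, and `--multiprocessing-distributed` spawns one process per local GPU. Assuming train.py follows the same convention (this README does not confirm it), the per-process values can be sketched as follows. `process_group_args` is a hypothetical helper used only for illustration.

```python
from typing import Tuple


def process_group_args(node_rank: int, num_nodes: int,
                       ngpus_per_node: int, gpu: int) -> Tuple[int, int]:
    """Return (world_size, rank) for one worker process.

    Assumes the PyTorch ImageNet-example convention: --world-size counts
    nodes, --rank identifies the node, and --multiprocessing-distributed
    spawns one process per local GPU. (Hypothetical helper.)
    """
    if not 0 <= gpu < ngpus_per_node:
        raise ValueError(f"gpu index {gpu} out of range for {ngpus_per_node} GPUs")
    world_size = num_nodes * ngpus_per_node  # total number of processes
    rank = node_rank * ngpus_per_node + gpu  # globally unique process id
    return world_size, rank


# A single 8-GPU machine launched with --world-size 1 --rank 0
for gpu in range(8):
    print(gpu, process_group_args(node_rank=0, num_nodes=1, ngpus_per_node=8, gpu=gpu))
```

Under this assumption, each worker would then pass these values, together with `--dist-url`, to `torch.distributed.init_process_group(backend="nccl", init_method=dist_url, world_size=world_size, rank=rank)`.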

Download MoPro Pre-trained ResNet-50 Models

WebVision V1, WebVision V2

Noise Cleaning

python noise_cleaning.py --data [WebVision folder] --resume [pre-trained model path] --annotation pseudo_label.json
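As we read the MoPro paper, noise correction averages the classifier's softmax prediction `p` with a softmax over the sample's similarities to the momentum prototypes `q`. The sample is relabelled to the top class if that average is confident enough. Otherwise the web label is kept if it still scores above uniform (1/K), and if neither holds the sample is treated as out-of-distribution. The sketch below illustrates that rule in plain Python. Whether noise_cleaning.py implements exactly this rule, and which threshold and temperature it uses, is not confirmed here. The function names and the default `temperature` are illustrative assumptions.

```python
import math
from typing import List, Optional, Sequence


def softmax(scores: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable softmax with a temperature."""
    m = max(scores)
    exps = [math.exp((s - m) / temperature) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def correct_label(class_probs: Sequence[float], proto_sims: Sequence[float],
                  web_label: int, threshold: float,
                  temperature: float = 0.1) -> Optional[int]:
    """Return a corrected label, or None if the sample looks out-of-distribution.

    class_probs: classifier softmax output p (sums to 1).
    proto_sims:  similarities between the sample embedding and each class prototype.
    threshold and temperature are placeholders, not values taken from the repo.
    """
    q = softmax(proto_sims, temperature)             # prototype-based prediction
    s = [(pk + qk) / 2 for pk, qk in zip(class_probs, q)]  # soft pseudo-label
    num_classes = len(s)
    best = max(range(num_classes), key=s.__getitem__)
    if s[best] > threshold:                          # confident: relabel
        return best
    if s[web_label] > 1.0 / num_classes:             # web label still plausible: keep
        return web_label
    return None                                      # discard as out-of-distribution


print(correct_label([0.9, 0.05, 0.05], [1.0, 0.0, 0.0], web_label=2, threshold=0.8))
```

Here the web label 2 is overridden by class 0 because both the classifier and the prototypes agree strongly on class 0.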

Classifier Retraining on WebVision

python classifier_retrain.py --data [WebVision folder] --imagenet [ImageNet folder] \
  --resume [pre-trained model path] --annotation pseudo_label.json --exp-dir experiment/cRT \
  --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0
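The experiment name cRT suggests classifier re-training in the sense of Kang et al. (ICLR 2020). In that approach, the learned representation is kept fixed and the classifier is retrained with class-balanced sampling, which counteracts the class imbalance of web data. Assuming classifier_retrain.py does something similar (this README does not say), class-balanced sampling can be sketched by weighting each sample by the inverse size of its class. Both helpers below are hypothetical.

```python
import random
from collections import Counter
from typing import List, Sequence


def class_balanced_weights(labels: Sequence[int]) -> List[float]:
    """Weight each sample by 1 / (size of its class), so every class has equal total weight."""
    counts = Counter(labels)
    return [1.0 / counts[y] for y in labels]


def sample_class_balanced(labels: Sequence[int], num_samples: int, seed: int = 0) -> List[int]:
    """Draw sample indices with replacement so that classes appear about equally often."""
    rng = random.Random(seed)
    weights = class_balanced_weights(labels)
    return rng.choices(range(len(labels)), weights=weights, k=num_samples)


# A toy imbalanced label list: 90 images of class 0, 10 of class 1
labels = [0] * 90 + [1] * 10
drawn = sample_class_balanced(labels, num_samples=10_000)
share_class1 = sum(labels[i] for i in drawn) / len(drawn)
print(f"class 1 share after balancing: {share_class1:.2f}")  # near 0.5 rather than 0.1
```

In PyTorch, the same per-sample weights could be handed to `torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)` instead of sampling by hand.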

Fine-tuning on ImageNet (1% of labeled data)

python finetune_imagenet.py \
  --data [ImageNet path] \
  --model-path [pre-trained model path] \
  --exp-dir experiment/Finetune \
  --low-resource 0.01 \
  --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 
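With `--low-resource 0.01`, fine-tuning uses only 1% of the labeled ImageNet training images (roughly 13 images per class). How finetune_image­net.py selects that subset is not described here. One common way to draw such a subset is stratified sampling per class with a fixed seed, sketched below. `low_resource_subset` is a hypothetical helper and not the script's code.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence


def low_resource_subset(labels: Sequence[int], fraction: float, seed: int = 0) -> List[int]:
    """Return indices of a class-stratified random subset keeping `fraction` of each class."""
    if not 0.0 < fraction <= 1.0:
        raise ValueError("fraction must be in (0, 1]")
    by_class: Dict[int, List[int]] = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    rng = random.Random(seed)  # fixed seed -> the same subset on every run
    subset: List[int] = []
    for y in sorted(by_class):
        indices = by_class[y]
        k = max(1, round(fraction * len(indices)))  # keep at least one image per class
        subset.extend(rng.sample(indices, k))
    return sorted(subset)


# Toy stand-in for ImageNet: 10 classes with 1,300 images each
labels = [c for c in range(10) for _ in range(1300)]
subset = low_resource_subset(labels, fraction=0.01)
print(len(subset))  # 13 images per class -> 130
```

Stratifying per class keeps every class represented even at 1%, which a uniform random 1% draw over the whole dataset would not guarantee.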

Result for WebVision-V1 pre-trained model:

| Labeled data | 1% | 10% |
|---|---|---|
| Accuracy (%) | 71.2 | 74.8 |

Linear SVM Evaluation on VOC or Places

python lowshot_svm.py --model_path [your pretrained model] --dataset VOC --voc-path [VOC data path]

Result for WebVision-V1 pre-trained model:

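| VOC | k=1 | k=2 | k=4 | k=8 | k=16 |
|---|---|---|---|---|---|
| mAP (%) | 59.5 | 71.3 | 76.5 | 81.4 | 83.7 |

| Places | k=1 | k=2 | k=4 | k=8 | k=16 |
|---|---|---|---|---|---|
| Accuracy (%) | 16.9 | 23.2 | 29.2 | 34.5 | 38.7 |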
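The VOC results above are reported as mAP. Because VOC is multi-label, this is presumably the average precision of each class's linear SVM scores, averaged over classes. The sketch below uses the non-interpolated variant of AP (mean precision at the rank of each positive) and ignores VOC's "difficult" flags. Whether lowshot_svm.py uses this variant or the 11-point interpolated VOC2007 metric is not stated here, so treat this as an illustration of the metric rather than a reproduction of the numbers.

```python
from typing import List, Sequence


def average_precision(scores: Sequence[float], targets: Sequence[int]) -> float:
    """Non-interpolated AP: mean of the precision values at the rank of each positive."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    hits = 0
    precisions: List[float] = []
    for rank, i in enumerate(order, start=1):
        if targets[i] == 1:
            hits += 1
            precisions.append(hits / rank)
    if not precisions:
        raise ValueError("AP is undefined for a class with no positives")
    return sum(precisions) / len(precisions)


def mean_average_precision(score_rows: Sequence[Sequence[float]],
                           target_rows: Sequence[Sequence[int]]) -> float:
    """mAP over classes; row c holds the scores and 0/1 targets for class c."""
    aps = [average_precision(s, t) for s, t in zip(score_rows, target_rows)]
    return sum(aps) / len(aps)


# Class 0: positives ranked 1st and 3rd -> AP = (1/1 + 2/3) / 2
print(average_precision([0.9, 0.8, 0.7, 0.6], [1, 0, 1, 0]))
```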