simclr_finetune.py
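"""Fine-tune a pretrained SimCLR backbone with pl_bolts' SSLFineTuner.

Loads SimCLR weights from a checkpoint and trains a classification head on
CIFAR-10, CIFAR-100, STL-10, or ImageNet. Relies on the local modules
``simclr_module``, ``transforms``, and ``cifar100_datamodule``.
"""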
import os
from argparse import ArgumentParser

from pytorch_lightning import Trainer, seed_everything
from pl_bolts.models.self_supervised.ssl_finetuner import SSLFineTuner
from pl_bolts.transforms.dataset_normalizations import (
    cifar10_normalization,
    imagenet_normalization,
    stl10_normalization,
)
from pl_bolts.utils.stability import under_review

from cifar100_datamodule import CIFAR100DataModule
from simclr_module import SimCLR
from transforms import SimCLRFinetuneTransform

@under_review()
def cli_main():  # pragma: no cover
    from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule

    seed_everything(1234)

    parser = ArgumentParser()
    parser.add_argument("--dataset", type=str, help="cifar10, cifar100, stl10, imagenet", default="cifar10")
    parser.add_argument("--ckpt_path", type=str, help="path to the pretrained SimCLR checkpoint")
    parser.add_argument("--data_dir", type=str, help="path to the dataset", default=os.getcwd())
    parser.add_argument("--batch_size", default=64, type=int, help="batch size per GPU")
    parser.add_argument("--num_workers", default=8, type=int, help="number of dataloader workers per GPU")
    parser.add_argument(
        "--feat_dim", default=128, type=int, help="feature dimension (256 for product loss, 128 for others)"
    )
    parser.add_argument("--gpus", default=4, type=int, help="number of GPUs")
    parser.add_argument("--num_epochs", default=100, type=int, help="number of epochs")

    # fine-tuner params
    parser.add_argument("--in_features", type=int, default=2048)
    parser.add_argument("--dropout", type=float, default=0.0)
    parser.add_argument("--learning_rate", type=float, default=0.3)
    parser.add_argument("--weight_decay", type=float, default=1e-6)
    # `type=bool` treats any non-empty string (including "False") as True,
    # so nesterov is exposed as a store_true flag instead
    parser.add_argument("--nesterov", action="store_true", default=False)
    parser.add_argument("--scheduler_type", type=str, default="cosine")
    parser.add_argument("--gamma", type=float, default=0.1)
    parser.add_argument("--final_lr", type=float, default=0.0)

    args = parser.parse_args()

    if args.dataset in ("cifar10", "cifar100"):
        dm_cls = CIFAR10DataModule if args.dataset == "cifar10" else CIFAR100DataModule
        dm = dm_cls(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)

        # CIFAR-10 normalization statistics are reused for CIFAR-100
        dm.train_transforms = SimCLRFinetuneTransform(
            normalize=cifar10_normalization(), input_height=dm.dims[-1], eval_transform=False
        )
        dm.val_transforms = SimCLRFinetuneTransform(
            normalize=cifar10_normalization(), input_height=dm.dims[-1], eval_transform=True
        )
        dm.test_transforms = SimCLRFinetuneTransform(
            normalize=cifar10_normalization(), input_height=dm.dims[-1], eval_transform=True
        )

        args.maxpool1 = False
        args.first_conv = False
        args.num_samples = 1
    elif args.dataset == "stl10":
        dm = STL10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)

        # fine-tune on the labeled split of STL-10
        dm.train_dataloader = dm.train_dataloader_labeled
        dm.val_dataloader = dm.val_dataloader_labeled
        args.num_samples = 1

        dm.train_transforms = SimCLRFinetuneTransform(
            normalize=stl10_normalization(), input_height=dm.dims[-1], eval_transform=False
        )
        dm.val_transforms = SimCLRFinetuneTransform(
            normalize=stl10_normalization(), input_height=dm.dims[-1], eval_transform=True
        )
        dm.test_transforms = SimCLRFinetuneTransform(
            normalize=stl10_normalization(), input_height=dm.dims[-1], eval_transform=True
        )

        args.maxpool1 = False
        args.first_conv = True
    elif args.dataset == "imagenet":
        dm = ImagenetDataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers)

        dm.train_transforms = SimCLRFinetuneTransform(
            normalize=imagenet_normalization(), input_height=dm.dims[-1], eval_transform=False
        )
        dm.val_transforms = SimCLRFinetuneTransform(
            normalize=imagenet_normalization(), input_height=dm.dims[-1], eval_transform=True
        )
        dm.test_transforms = SimCLRFinetuneTransform(
            normalize=imagenet_normalization(), input_height=dm.dims[-1], eval_transform=True
        )

        args.num_samples = 1
        args.maxpool1 = True
        args.first_conv = True
    else:
        raise NotImplementedError(f"dataset {args.dataset} is not supported")
    # load the pretrained SimCLR weights from the checkpoint
    # (strict=False tolerates keys that do not match the current model)
    backbone = SimCLR(
        gpus=args.gpus,
        nodes=1,
        num_samples=args.num_samples,
        batch_size=args.batch_size,
        maxpool1=args.maxpool1,
        first_conv=args.first_conv,
        dataset=args.dataset,
        feat_dim=args.feat_dim,
    ).load_from_checkpoint(args.ckpt_path, strict=False)

    # attach a classification head on top of the pretrained encoder
    tuner = SSLFineTuner(
        backbone,
        in_features=args.in_features,
        num_classes=dm.num_classes,
        epochs=args.num_epochs,
        hidden_dim=None,
        dropout=args.dropout,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        nesterov=args.nesterov,
        scheduler_type=args.scheduler_type,
        gamma=args.gamma,
        final_lr=args.final_lr,
    )

    trainer = Trainer(
        gpus=args.gpus,
        num_nodes=1,
        precision=16,
        max_epochs=args.num_epochs,
        accelerator="gpu",
        sync_batchnorm=args.gpus > 1,
    )

    trainer.fit(tuner, dm)
    trainer.test(datamodule=dm)


if __name__ == "__main__":
    cli_main()
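
# Illustrative invocations (the checkpoint path is a placeholder, not a file shipped with this repo):
#   python simclr_finetune.py --dataset cifar100 --ckpt_path /path/to/simclr.ckpt --gpus 1
#   python simclr_finetune.py --dataset stl10 --ckpt_path /path/to/simclr.ckpt --gpus 2 --batch_size 128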