-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathrun_training.py
90 lines (83 loc) · 3.27 KB
/
run_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import argparse
from train import train_relic
# Command-line interface of the training script. The parser is built at
# module level; main() parses sys.argv with it and forwards the resulting
# namespace to train_relic.
parser = argparse.ArgumentParser(description='ReLIC')

# --- Dataset -----------------------------------------------------------
parser.add_argument('--dataset_path',
                    default='./data',
                    help='Path where datasets will be saved')
parser.add_argument('--dataset_name',
                    default='stl10',
                    help='Dataset name',
                    choices=['stl10', 'cifar10', 'tiny_imagenet', 'food101',
                             'imagenet1k'])

# --- Model -------------------------------------------------------------
parser.add_argument(
    '-m',
    '--encoder_model_name',
    default='resnet18',
    choices=['resnet18', 'resnet50', 'efficientnet'],
    help='model architecture: resnet18, resnet50 or efficientnet '
         '(default: resnet18)')
# The option used to be registered only as '-save_model_dir' (single dash),
# so the conventional '--save_model_dir' spelling was rejected. Both are
# accepted now; the legacy spelling is kept so existing command lines keep
# working. The destination attribute stays 'save_model_dir'.
parser.add_argument('--save_model_dir',
                    '-save_model_dir',
                    default='./models',
                    help='Path where models will be saved')
parser.add_argument('--proj_out_dim',
                    default=64,
                    type=int,
                    help='Projector MLP out dimension')
parser.add_argument('--proj_hidden_dim',
                    default=512,
                    type=int,
                    help='Projector MLP hidden dimension')

# --- Optimisation ------------------------------------------------------
parser.add_argument('--num_epochs',
                    default=100,
                    type=int,
                    help='Number of epochs for training')
parser.add_argument('-b',
                    '--batch_size',
                    default=256,
                    type=int,
                    help='Batch size')
parser.add_argument('-lr',
                    '--learning_rate',
                    default=3e-4,
                    type=float,
                    help='Learning rate')
parser.add_argument('-wd',
                    '--weight_decay',
                    default=1e-5,
                    type=float,
                    help='Weight decay')
parser.add_argument('--fp16_precision',
                    action='store_true',
                    help='Whether to use 16-bit precision for GPU training')

# --- Logging / checkpoints ---------------------------------------------
parser.add_argument('--log_every_n_steps',
                    default=400,
                    type=int,
                    help='Log every n steps')
parser.add_argument('--ckpt_path',
                    default=None,
                    type=str,
                    help='Specify path to relic_model.pth to resume training')

# --- EMA and loss ------------------------------------------------------
parser.add_argument('--gamma',
                    default=0.995,
                    type=float,
                    help='Initial EMA coefficient')
parser.add_argument('--alpha',
                    default=1.0,
                    type=float,
                    help='Regularization loss factor')
parser.add_argument('--update_gamma_after_step',
                    default=1,
                    type=int,
                    help='Update EMA gamma after this step')
parser.add_argument('--update_gamma_every_n_steps',
                    default=1,
                    type=int,
                    help='Update EMA gamma after this many steps')
parser.add_argument('--use_siglip',
                    action='store_true',
                    help='Whether to use siglip loss')

# --- Augmentation views ------------------------------------------------
parser.add_argument(
    '--num_global_views',
    default=2,
    type=int,
    help='Number of global (large) views to generate through augmentation')
parser.add_argument(
    '--num_local_views',
    default=4,
    type=int,
    help='Number of local (small) views to generate through augmentation')
def main():
    """Parse the command-line arguments and pass them to train_relic."""
    train_relic(parser.parse_args())


# Only run when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()