-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexperiment.py
62 lines (54 loc) · 2.59 KB
/
experiment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from __future__ import print_function
import argparse
import math
import matplotlib
import sys
import train
if __name__ == '__main__':
    # CLI entry point: parse experiment hyper-parameters, merge them into the
    # JSON training config, then hand off to train.start().
    parser = argparse.ArgumentParser()
    parser.add_argument('--nb-clusters', required=True, type=int)
    parser.add_argument('--dataset', dest='dataset_selected',
                        choices=['isic', 'mura', 'hkvasir'], required=True)
    parser.add_argument('--nb-epochs', type=int, default=300)
    parser.add_argument('--finetune-epoch', type=int, default=250)
    parser.add_argument('--mod-epoch', type=int, default=2)
    parser.add_argument('--num-workers', default=4, type=int)
    parser.add_argument('--sz-batch', type=int, default=32)
    parser.add_argument('--sz-embedding', default=128, type=int)
    parser.add_argument('--cuda-device', default=0, type=int)
    parser.add_argument('--exp', default='run1', type=str, help='experiment identifier')
    parser.add_argument('--dir', default='default', type=str)
    parser.add_argument('--backend', default='faiss',
                        choices=('torch+sklearn', 'faiss', 'faiss-gpu'))
    parser.add_argument('--random-seed', default=0, type=int)
    parser.add_argument('--backbone-wd', default=1e-4, type=float)
    parser.add_argument('--backbone-lr', default=1e-5, type=float)
    parser.add_argument('--embedding-lr', default=1e-5, type=float)
    parser.add_argument('--embedding-wd', default=1e-4, type=float)
    parser.add_argument('--verbose', action='store_true')
    parser.add_argument('-dK', '--dyn_learner', dest='dyn_learner', action='store_true',
                        help='Enable dynamic K training')
    args = vars(parser.parse_args())

    config = train.load_config(config_name='config.json')

    # Route nested options into their config sub-sections. pop() removes each
    # one from args so the generic top-level copy below skips them.
    config['dataloader']['batch_size'] = args.pop('sz_batch')
    config['dataloader']['num_workers'] = args.pop('num_workers')
    config['recluster']['mod_epoch'] = args.pop('mod_epoch')
    config['opt']['backbone']['lr'] = args.pop('backbone_lr')
    config['opt']['backbone']['weight_decay'] = args.pop('backbone_wd')
    config['opt']['embedding']['lr'] = args.pop('embedding_lr')
    config['opt']['embedding']['weight_decay'] = args.pop('embedding_wd')

    # Copy the remaining flat options onto matching top-level config keys;
    # anything without a counterpart in config is ignored.
    for key, value in args.items():
        if key in config:
            config[key] = value

    # With a single cluster there is nothing to re-assign, so re-clustering
    # is switched off.
    if config['nb_clusters'] == 1:
        config['recluster']['enabled'] = False

    # Run identifier and output directory for logging.
    config['log'] = {
        'name': '{}-K-{}-M-{}-exp-{}'.format(
            config['dataset_selected'],
            config['nb_clusters'],
            config['recluster']['mod_epoch'],
            args['exp']
        ),
        'path': 'log/{}'.format(args['dir'])
    }

    train.start(config)