forked from nikopj/DGCN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
make_args.py
executable file
·91 lines (79 loc) · 1.88 KB
/
make_args.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#!/usr/bin/env python3
import itertools
import json
import sys
from os.path import join
from pprint import pprint

import numpy as np
def write_args(arg_dict, name):
    """Write ``arg_dict`` as pretty-printed JSON to ``args/<name>.json``.

    Any existing file with that name is overwritten. Keys are sorted and
    indented by 4 spaces so generated configs diff cleanly against each
    other. The ``args`` directory must already exist; ``open`` raises
    ``FileNotFoundError`` otherwise.

    Args:
        arg_dict: JSON-serializable configuration mapping.
        name: Base file name, without the ``.json`` extension.
    """
    # Plain "w" (truncate + write) is all that is needed here; the previous
    # "+w" also requested read access that was never used.
    with open(join("args", name + ".json"), "w", encoding="utf-8") as outfile:
        json.dump(arg_dict, outfile, indent=4, sort_keys=True)
# Load the base configuration from disk. Below, the "model" and "train"
# sections are replaced wholesale and the "type" / "paths" entries are
# overwritten; any other top-level keys come from the file unchanged.
with open("args_gcdlnet.json", encoding="utf-8") as args_file:
    args = json.load(args_file)

# Hyper-parameter grid to sweep. The sweep loop looks each key up in
# args["model"] and then args["train"]["fit"], and generates one config per
# combination of the listed values.
loop_args = {
    "nf": [32, 64],
    "topK": [None, 8],
}

args["model"] = {
    "nic": 1,
    "nf": 64,
    "ks": 7,
    "iters": 10,
    "window_size": 32,
    "topK": 8,
    "rank": 11,
    "circ_rows": None,
    "leak": 0.2
}

args["train"] = {
    "loaders": {
        "batch_size": 4,
        "crop_size": 32,
        "load_color": False,
        "trn_path_list": ["CBSD432"],
        "val_path_list": ["Set12"],
        "tst_path_list": ["CBSD68"]
    },
    "fit": {
        "epochs": 3000,
        "noise_std": 25,
        "val_freq": 25,
        "save_freq": 5,
        "backtrack_thresh": 0.5,
        "verbose": False,
        "clip_grad": 5e-2
    },
    "opt": {
        "lr": 1e-3
    },
    "sched": {
        "gamma": 0.95,
        "step_size": 25
    }
}

args['type'] = "GCDLNet"
# Clear the checkpoint path (serialized as JSON null). setdefault keeps this
# working if the base file has no "paths" section. NOTE(review): presumably
# this makes every generated run start without resuming from a checkpoint.
args.setdefault('paths', {})['ckpt'] = None

# Running counter appended to each generated config's version name.
vnum = 0
# Sweep name used in the summary file name and in every version name.
name = "nf_topK"
def product(*iterables, repeat=1):
    """Yield the Cartesian product of the input iterables as tuples.

    Thin generator wrapper around :func:`itertools.product`, kept so existing
    callers of this module-level name keep working.

    Examples:
        product('ABCD', 'xy') --> Ax Ay Bx By Cx Cy Dx Dy
        product(range(2), repeat=3) --> 000 001 010 011 100 101 110 111

    Args:
        *iterables: Input iterables. Each is fully consumed before the first
            tuple is produced.
        repeat: Number of times the sequence of iterables is repeated. Must
            be >= 0; ``itertools.product`` raises ``ValueError`` otherwise.

    Yields:
        Tuples in lexicographic order of input position, with the rightmost
        element advancing fastest.
    """
    # Delegate to the C implementation. It also avoids building the whole
    # product as a list of lists before the first tuple is yielded.
    yield from itertools.product(*iterables, repeat=repeat)
keys = list(loop_args)

# Validate sweep keys up front. A key that names neither a model nor a fit
# parameter used to be ignored silently, which wrote several identical
# configs under different version names.
unknown = [
    k for k in keys
    if k not in args['model'] and k not in args['train']['fit']
]
if unknown:
    raise KeyError(
        f"loop_args keys not found in args['model'] or "
        f"args['train']['fit']: {unknown}"
    )

# Write one config file per grid point. Each config is also recorded in an
# append-mode summary file, so re-running the script appends new lines.
with open(f"Models/{args['type']}-{name}.summary", "a", encoding="utf-8") as summary:
    for items in product(*(loop_args[k] for k in keys)):
        for key, value in zip(keys, items):
            # A key present in both sections goes to the model section,
            # matching the original lookup order.
            if key in args['model']:
                args['model'][key] = value
            else:
                args['train']['fit'][key] = value
        version = f"{args['type']}-{name}-{vnum}"
        args['paths']['save'] = "Models/" + version
        # args is mutated in place on each iteration and written out
        # immediately, so each file captures that iteration's values.
        write_args(args, version)
        print(f'{version}: {items}')
        summary.write(f'{version}: {items}\n')
        vnum += 1