-
Notifications
You must be signed in to change notification settings - Fork 5
/
path_handler.py
63 lines (52 loc) · 1.9 KB
/
path_handler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
import pathlib
def _model_filename(config):
    """Return the checkpoint filename encoding the run described by *config*.

    The name is the dataset followed by a sequence of ``_<key>_<value>``
    fragments and the ``.pt`` extension. Optional fragments are included only
    when the corresponding setting applies (e.g. ``_momentum_`` only for SGD).
    """
    parts = [f"{config.dataset}"]

    # Dataset-specific keys.
    if config.dataset in ["CIFAR10"]:
        parts.append(f"_augm_{config.augment}")

    # Model-specific keys.
    parts.append(f"_model_{config.model}")
    # Substring test: any model name containing "sa" gets the attention keys.
    if "sa" in config.model:
        parts.append(f"_type_{config.attention_type}")
        if config.attention_type == "Local":
            # Patch size is encoded only for local attention.
            parts.append(f"_patch_{config.patch_size}")
        # NOTE(review): these keys are assumed to be encoded for self-attention
        # models only — confirm this matches the intended nesting.
        parts.append(
            f"_dpatt_{config.dropout_att}"
            f"_dpval_{config.dropout_values}"
            f"_activ_{config.activation_function}"
            f"_norm_{config.norm_type}"
            f"_white_{config.whitening_scale}"
        )

    # Optimization keys.
    parts.append(f"_optim_{config.optimizer}")
    if config.optimizer == "SGD":
        parts.append(f"_momentum_{config.optimizer_momentum}")
    parts.append(
        f"_lr_{config.lr}_bs_{config.batch_size}_ep_{config.epochs}"
        f"_wd_{config.weight_decay}_seed_{config.seed}_sched_{config.scheduler}"
    )
    # The decay factor is not encoded for the constant and
    # linear_warmup_cosine schedules.
    if config.scheduler not in ["constant", "linear_warmup_cosine"]:
        parts.append(f"_schdec_{config.sched_decay_factor}")
    if config.scheduler == "multistep":
        parts.append(f"_schsteps_{config.sched_decay_steps}")

    # Free-form user comment, omitted when it is the empty string.
    if config.comment != "":
        parts.append(f"_comment_{config.comment}")

    # Add the checkpoint extension.
    parts.append(".pt")
    return "".join(parts)


def model_path(config, root="./saved"):
    """Compute the checkpoint path for *config*, store it on *config* and return it.

    Args:
        config: Run configuration object. Always reads ``dataset``, ``model``,
            ``optimizer``, ``lr``, ``batch_size``, ``epochs``, ``weight_decay``,
            ``seed``, ``scheduler``, ``comment`` and ``train``; conditionally
            reads ``augment``, ``attention_type``, ``patch_size``,
            ``dropout_att``, ``dropout_values``, ``activation_function``,
            ``norm_type``, ``whitening_scale``, ``optimizer_momentum``,
            ``sched_decay_factor`` and ``sched_decay_steps``.
            ``config.path`` is set to the resulting path as a string.
        root: Directory holding the checkpoint; created (with parents) if
            it does not exist.

    Returns:
        The checkpoint path as a string (the same value stored in
        ``config.path``).
    """
    root = pathlib.Path(root)
    filename = _model_filename(config)
    # Make sure the target directory exists before anything is saved there.
    os.makedirs(root, exist_ok=True)
    path = root / filename
    config.path = str(path)
    # Warn the user if training is about to overwrite an existing checkpoint.
    if config.train and path.exists():
        print("WARNING! The model exists in directory and will be overwritten")
    return config.path