-
Notifications
You must be signed in to change notification settings - Fork 12
/
set_up.py
97 lines (77 loc) · 3.05 KB
/
set_up.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import json
import os
import pprint as pp
import random
from datetime import date
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import wandb
def fix_random_seed_as(random_seed):
if random_seed == -1:
random_seed = np.random.randint(100000)
print("RANDOM SEED: {}".format(random_seed))
random.seed(random_seed)
torch.manual_seed(random_seed)
torch.cuda.manual_seed_all(random_seed)
np.random.seed(random_seed)
cudnn.deterministic = True
cudnn.benchmark = False
return random_seed
def _get_experiment_index(experiment_path):
idx = 0
while os.path.exists(experiment_path + "_" + str(idx)):
idx += 1
return idx
def create_experiment_export_folder(experiment_dir, experiment_description):
print(os.path.abspath(experiment_dir))
if not os.path.exists(experiment_dir):
os.mkdir(experiment_dir)
experiment_path = get_name_of_experiment_path(experiment_dir, experiment_description)
print(os.path.abspath(experiment_path))
os.mkdir(experiment_path)
print("folder created: " + os.path.abspath(experiment_path))
return experiment_path
def get_name_of_experiment_path(experiment_dir, experiment_description):
experiment_path = os.path.join(experiment_dir, (experiment_description + "_" + str(date.today())))
idx = _get_experiment_index(experiment_path)
experiment_path = experiment_path + "_" + str(idx)
return experiment_path
def export_config_as_json(config, experiment_path):
with open(os.path.join(experiment_path, 'config.json'), 'w') as outfile:
json.dump(config, outfile, indent=2)
def generate_tags(config):
tags = []
tags.append(config.get('generator', config.get('text_encoder')))
tags.append(config.get('trainer'))
tags = [tag for tag in tags if tag is not None]
return tags
def set_up_gpu(device_idx):
if device_idx:
os.environ['CUDA_VISIBLE_DEVICES'] = device_idx
return {
'num_gpu': len(device_idx.split(","))
}
else:
idxs = os.environ['CUDA_VISIBLE_DEVICES']
return {
'num_gpu': len(idxs.split(","))
}
def setup_experiment(config):
device_info = set_up_gpu(config['device_idx'])
config.update(device_info)
random_seed = fix_random_seed_as(config['random_seed'])
config['random_seed'] = random_seed
export_root = create_experiment_export_folder(config['experiment_dir'], config['experiment_description'])
export_config_as_json(config, export_root)
config['export_root'] = export_root
pp.pprint(config, width=1)
os.environ['WANDB_SILENT'] = "true"
tags = generate_tags(config)
project_name = config['wandb_project_name']
wandb_account_name = config['wandb_account_name']
experiment_name = config['experiment_description']
experiment_name = experiment_name if config['random_seed'] != -1 else experiment_name + "_{}".format(random_seed)
wandb.init(config=config, name=experiment_name, project=project_name,
entity=wandb_account_name, tags=tags)
return export_root, config