-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdefault_sets.py
92 lines (72 loc) · 2.47 KB
/
default_sets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from collections import defaultdict
import random
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import torch
import numpy as np
import time
import json
import itertools
import torch.multiprocessing
from config import Config
from logger import is_main_process
import sys
torch.multiprocessing.set_sharing_strategy('file_system')
# Only the main process echoes the raw JSON configuration file named by the
# first command-line argument; the Config object itself is built
# unconditionally from that same file.
if is_main_process():
    with open(sys.argv[1], 'r') as fin:
        cfg = json.loads(fin.read())
    print(cfg)

config = Config(config_file=sys.argv[1])
# Optional free-text suffix for the run identifier. If reading
# `config.train.note` fails (e.g. the field is absent), use an empty string.
try:
    note = config.train.note
except Exception:
    # The exception type Config raises for a missing field is not visible
    # from this file, so the handler stays broad. Unlike a bare `except:`,
    # it no longer swallows KeyboardInterrupt / SystemExit.
    note = ''

# `date` is, despite its name, a run identifier assembled from the model,
# backbone and dataset names (plus graph depth for DKEC) joined by '_'.
if config.model == 'DKEC':
    _run_fields = (
        config.model,
        config.train.backbone,
        config.dataset,
        str(config.train.graph_layer),
        note,
    )
elif config.model in ('BERT', 'BERT_LA'):
    # Both BERT variants share the same naming scheme.
    _run_fields = (config.model, config.train.backbone, config.dataset, note)
else:
    _run_fields = (config.model, config.dataset)
date = '_'.join(_run_fields)
# Frequently used settings lifted out of the config object.
num_train_epochs = config.train.epochs
dataset = config.dataset
task = config.task
ROOT = config.root_dir
# All dataset files live under <root_dir>/dataset/<dataset name>.
DIR = os.path.join(ROOT, 'dataset', dataset)


def _load_json(filename):
    """Return the parsed contents of the JSON file *filename* located in DIR.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    with open(os.path.join(DIR, filename), 'r') as handle:
        return json.load(handle)


# Dataset-specific label definitions.
if dataset == "RAA":
    # The RAA dataset requires window_size to be 1 or unset.
    assert config.train.window_size == 1 or config.train.window_size is None, \
        'RAA requires config.train.window_size to be 1 or None'
    label = _load_json('EMS_Protocol.json')
elif "MIMIC3" in dataset:
    ICD9_description = _load_json('ICD9_descriptions.json')
    ICD9_DIAG = _load_json('ICD9CODES.json')
    # The label set is the collection of keys of the ICD9 code table.
    label = list(ICD9_DIAG)
else:
    raise Exception('check the dataset in config')

# Both supported datasets ship the same pair of mapping files, so they are
# loaded once here (after the label files, as before).
hier2label = _load_json('hier2p.json')
label2hier = _load_json('p2hier.json')
def seed_everything(seed=42):
    """Seed all random number generators used here with one value.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and CUDA
    generators), stores the seed in the ``PYTHONHASHSEED`` environment
    variable, and configures cuDNN for deterministic behaviour.

    Args:
        seed: Seed value applied everywhere. Defaults to 42.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

    # The remaining seeders all take the seed as their single argument.
    for reseed in (
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        reseed(seed)

    # Some cuDNN routines can remain random even with fixed seeds unless
    # cuDNN is explicitly told to be deterministic.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True