import argparse
import torch
import torch.nn.functional as F
from tqdm import tqdm
import data_loader.data_loaders as module_data
import model.loss as module_loss
import model.metric as module_metric
import model.model as module_arch
from parse_config import ConfigParser
from functools import partial
import dgl
import numpy as np
import itertools


def rearrange(energy_scores, candidate_position_idx, parent_position_idx):
    """Reorder scores so true-parent positions come first, and return the
    matching binary labels (1 = true parent, 0 = other position)."""
    tmp = np.isin(candidate_position_idx, parent_position_idx)
    correct = np.where(tmp)[0]
    incorrect = np.where(~tmp)[0]
    labels = torch.cat((torch.ones(len(correct)), torch.zeros(len(incorrect)))).int()
    energy_scores = torch.cat((energy_scores[correct, :], energy_scores[incorrect, :]))
    return energy_scores, labels


def encode_graph(model, bg, h, pos):
    """Propagate node features through the batched graph and read out one
    representation per graph."""
    bg.ndata['h'] = model.graph_propagate(bg, h)
    hg = model.readout(bg, pos)
    return hg


def main(config, args_outer):
    logger = config.get_logger('test')

    # whether to save case study results
    need_case_study = (args_outer.case != "")
    if need_case_study:
        logger.info(f"save case study results to {args_outer.case}")
    else:
        logger.info("no need to save case study results")

    # setup multiprocessing instance
    torch.multiprocessing.set_sharing_strategy('file_system')
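    # ('file_system' sharing avoids the "too many open files" errors that the
    # default file-descriptor strategy can hit with many DataLoader workers.)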

    # setup data_loader instances
    if args_outer.test_data == "":
        test_data_path = config['test_data_loader']['args']['data_path']
    else:
        test_data_path = args_outer.test_data
    test_data_loader = module_data.MaskedGraphDataLoader(
        mode="test",
        data_path=test_data_path,
        sampling_mode=0,
        batch_size=1,
        expand_factor=config['test_data_loader']['args']['expand_factor'],
        shuffle=True,
        num_workers=8,
        batch_type="large_batch",
        cache_refresh_time=config['test_data_loader']['args']['cache_refresh_time'],
        normalize_embed=config['test_data_loader']['args']['normalize_embed'],
        test_topk=args_outer.topk
    )
    logger.info(test_data_loader)

    # build model architecture
    model = config.initialize('arch', module_arch)
    logger.info(model)

    # get function handles of loss and metrics
    metric_fns = [getattr(module_metric, met) for met in config['metrics']]
    if config['loss'].startswith("info_nce"):
        pre_metric = partial(module_metric.obtain_ranks, mode=1)  # info_nce_loss
    else:
        pre_metric = partial(module_metric.obtain_ranks, mode=0)
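    # (mode=1 ranks by descending score, i.e. higher similarity is better for
    # InfoNCE-style losses; mode=0 ranks by ascending energy, i.e. lower is
    # better, consistent with the top-5 parent selection logic further below.)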

    # prepare device first, then load the checkpoint onto it so that
    # CPU-only machines can restore GPU-trained weights
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info('Loading checkpoint: {} ...'.format(config.resume))
    checkpoint = torch.load(config.resume, map_location=device)
    state_dict = checkpoint['state_dict']
    if config['n_gpu'] > 1:
        model = torch.nn.DataParallel(model)
    model.load_state_dict(state_dict)

    # prepare model for testing
    model = model.to(device)
    model.eval()

    test_dataset = test_data_loader.dataset
    kv = test_dataset.kv
    vocab = test_dataset.node_list
    if need_case_study:
        indice2word = test_dataset.vocab
    node2parents = test_dataset.node2parents
    candidate_positions = sorted(list(test_dataset.all_positions))
    logger.info(f"Number of queries: {len(vocab)}")

    anchor2subgraph = {}
    for anchor in tqdm(candidate_positions):
        anchor2subgraph[anchor] = test_dataset._get_subgraph(-1, anchor, 0)
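
    # Two inference paths follow: a single-pass path that encodes every
    # candidate ego-net in one dgl.batch (small datasets), and a chunked path
    # that encodes candidates in groups of args_outer.batch_size and caches
    # the representations on CPU (large datasets such as MAG-Full).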
    if args_outer.batch_size == -1:  # small dataset with only one batch
        logger.info('Small batch mode')
        # obtain graph representation
        bg = dgl.batch([v for k, v in anchor2subgraph.items()])
        h = bg.ndata.pop('x').to(device)
        candidate_position_idx = bg.ndata['_id'][bg.ndata['pos'] == 1].tolist()
        n_position = len(candidate_position_idx)
        pos = bg.ndata['pos'].to(device)
        with torch.no_grad():
            hg = encode_graph(model, bg, h, pos)

        # start per query prediction
        total_metrics = torch.zeros(len(metric_fns))
        if need_case_study:
            all_cases = []
            all_cases.append(["Test node index", "True parents", "Predicted parents"] + [fn.__name__ for fn in metric_fns])
        with torch.no_grad():
            for i, query in tqdm(enumerate(vocab)):
                if need_case_study:
                    cur_case = [indice2word[query]]
                    true_parents = ", ".join([indice2word[ele] for ele in node2parents[query]])
                    cur_case.append(true_parents)
                nf = torch.tensor(kv[str(query)], dtype=torch.float32).to(device)
                expanded_nf = nf.expand(n_position, -1)
                energy_scores = model.match(hg, expanded_nf)
                if need_case_study:  # select top-5 predicted parents
                    predicted_scores = energy_scores.cpu().squeeze_().tolist()
                    if config['loss'].startswith("info_nce"):
                        predict_parent_idx_list = [candidate_position_idx[ele[0]] for ele in sorted(enumerate(predicted_scores), key=lambda x: -x[1])[:5]]
                    else:
                        predict_parent_idx_list = [candidate_position_idx[ele[0]] for ele in sorted(enumerate(predicted_scores), key=lambda x: x[1])[:5]]
                    predict_parents = ", ".join([indice2word[ele] for ele in predict_parent_idx_list])
                    cur_case.append(predict_parents)
                energy_scores, labels = rearrange(energy_scores, candidate_position_idx, node2parents[query])
                all_ranks = pre_metric(energy_scores, labels)
                for j, metric in enumerate(metric_fns):
                    tmp = metric(all_ranks)
                    total_metrics[j] += tmp
                    if need_case_study:
                        cur_case.append(str(tmp))
                if need_case_study:
                    all_cases.append(cur_case)

        # save case study results to file
        if need_case_study:
            with open(args_outer.case, "w") as fout:
                for ele in all_cases:
                    fout.write("\t".join(ele))
                    fout.write("\n")
    else:  # large dataset with many batches
        # obtain graph representation
        logger.info(f'Large batch mode with batch_size = {args_outer.batch_size}')
        batched_hg = []  # save the CPU graph representation
        batched_positions = []
        bg = []
        positions = []
        with torch.no_grad():
            for i, (anchor, egonet) in tqdm(enumerate(anchor2subgraph.items()), desc="Generating graph encoding ..."):
                positions.append(anchor)
                bg.append(egonet)
                if (i + 1) % args_outer.batch_size == 0:
                    bg = dgl.batch(bg)
                    h = bg.ndata.pop('x').to(device)
                    pos = bg.ndata['pos'].to(device)
                    hg = encode_graph(model, bg, h, pos)
                    assert hg.shape[0] == len(positions), f"mismatch between hg.shape[0]: {hg.shape[0]} and len(positions): {len(positions)}"
                    batched_hg.append(hg.cpu())
                    batched_positions.append(positions)
                    bg = []
                    positions = []
                    del h
            # encode the final partial batch, if any
            if len(bg) != 0:
                bg = dgl.batch(bg)
                h = bg.ndata.pop('x').to(device)
                pos = bg.ndata['pos'].to(device)
                hg = encode_graph(model, bg, h, pos)
                assert hg.shape[0] == len(positions), f"mismatch between hg.shape[0]: {hg.shape[0]} and len(positions): {len(positions)}"
                batched_hg.append(hg.cpu())
                batched_positions.append(positions)
                del h

        # start per query prediction
        total_metrics = torch.zeros(len(metric_fns))
        if need_case_study:
            all_cases = []
            all_cases.append(["Test node index", "True parents", "Predicted parents"] + [fn.__name__ for fn in metric_fns])
        candidate_position_idx = list(itertools.chain(*batched_positions))
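        # (flattening batched_positions keeps candidate_position_idx aligned
        # row-by-row with the concatenated energy scores computed below)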
        batched_hg = [hg.to(device) for hg in batched_hg]
        with torch.no_grad():
            for i, query in tqdm(enumerate(vocab)):
                if need_case_study:
                    cur_case = [indice2word[query]]
                    true_parents = ", ".join([indice2word[ele] for ele in node2parents[query]])
                    cur_case.append(true_parents)
                nf = torch.tensor(kv[str(query)], dtype=torch.float32).to(device)
                batched_energy_scores = []
                for hg, positions in zip(batched_hg, batched_positions):
                    n_position = len(positions)
                    expanded_nf = nf.expand(n_position, -1)
                    energy_scores = model.match(hg, expanded_nf)  # a tensor of size (n_position, 1)
                    batched_energy_scores.append(energy_scores)
                batched_energy_scores = torch.cat(batched_energy_scores)
                if need_case_study:  # select top-5 predicted parents
                    predicted_scores = batched_energy_scores.cpu().squeeze_().tolist()
                    if config['loss'].startswith("info_nce"):
                        predict_parent_idx_list = [candidate_position_idx[ele[0]] for ele in sorted(enumerate(predicted_scores), key=lambda x: -x[1])[:5]]
                    else:
                        predict_parent_idx_list = [candidate_position_idx[ele[0]] for ele in sorted(enumerate(predicted_scores), key=lambda x: x[1])[:5]]
                    predict_parents = ", ".join([indice2word[ele] for ele in predict_parent_idx_list])
                    cur_case.append(predict_parents)
                batched_energy_scores, labels = rearrange(batched_energy_scores, candidate_position_idx, node2parents[query])
                all_ranks = pre_metric(batched_energy_scores, labels)
                for j, metric in enumerate(metric_fns):
                    tmp = metric(all_ranks)
                    total_metrics[j] += tmp
                    if need_case_study:
                        cur_case.append(str(tmp))
                if need_case_study:
                    all_cases.append(cur_case)

        # save case study results to file
        if need_case_study:
            with open(args_outer.case, "w") as fout:
                for ele in all_cases:
                    fout.write("\t".join(ele))
                    fout.write("\n")

    n_samples = test_data_loader.n_samples
    log = {}
    log.update({
        met.__name__: total_metrics[i].item() / n_samples for i, met in enumerate(metric_fns)
    })
    log.update({
        "test_topk": test_data_loader.dataset.test_topk
    })
    logger.info(log)


if __name__ == '__main__':
    args = argparse.ArgumentParser(description='Testing taxonomy expansion model')
    args.add_argument('-td', '--test_data', default="", type=str, help='test data path; if not provided, we assume the test data is specified in the config file')
    args.add_argument('-r', '--resume', required=True, type=str, help='path to latest checkpoint')
    args.add_argument('-d', '--device', default=None, type=str, help='indices of GPUs to enable (default: all)')
    args.add_argument('-k', '--topk', default=-1, type=int, help='topk retrieved instances for testing; -1 means no retrieval stage (default: -1)')
    args.add_argument('-b', '--batch_size', default=-1, type=int, help='batch size; -1 for small datasets (default: -1), 30000 for the larger MAG-Full data')
    args.add_argument('-c', '--case', default="", type=str, help='case study output file; if "", no case studies are saved (default: "")')
    args_outer = args.parse_args()
    config = ConfigParser(args)
    main(config, args_outer)
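
# Example invocation (paths and values are illustrative only):
#   python test_fast.py --resume saved/models/best_checkpoint.pth \
#       --topk -1 --batch_size -1 --case case_study.tsv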