-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
101 lines (86 loc) · 3.66 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
import json
import tqdm
import torch
import torch.utils.data
from typing import List
from collections import OrderedDict
from doc import collate, Example, Dataset
from config import args
from models import build_model
from utils import AttrDict, move_to_cuda
from dict_hub import build_tokenizer
from logger_config import logger
class BertPredictor:
    """Inference-time wrapper around a trained checkpoint.

    Typical usage::

        predictor = BertPredictor()
        predictor.load(ckt_path)
        hr_vecs, tail_vecs = predictor.predict_by_examples(examples)
    """

    def __init__(self):
        self.model = None             # built lazily by load()
        self.train_args = AttrDict()  # args the checkpoint was trained with
        self.use_cuda = False         # True once the model lives on GPU

    def load(self, ckt_path, use_data_parallel=False):
        """Restore model weights and training-time args from a checkpoint.

        :param ckt_path: path to a torch checkpoint with 'args' and 'state_dict' keys
        :param use_data_parallel: wrap the model in nn.DataParallel when more
            than one GPU is visible
        :raises FileNotFoundError: if ckt_path does not exist
        """
        # Explicit raise instead of assert: asserts are stripped under `python -O`.
        if not os.path.exists(ckt_path):
            raise FileNotFoundError('Checkpoint does not exist: {}'.format(ckt_path))
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        ckt_dict = torch.load(ckt_path, map_location=lambda storage, loc: storage)
        self.train_args.__dict__ = ckt_dict['args']
        self._setup_args()
        build_tokenizer(self.train_args)
        self.model = build_model(self.train_args)

        # Training with DataParallel prefixes every state-dict key with
        # 'module.'; strip it so the weights load into a bare model.
        state_dict = ckt_dict['state_dict']
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            if k.startswith('module.'):
                k = k[len('module.'):]
            new_state_dict[k] = v
        self.model.load_state_dict(new_state_dict, strict=True)
        self.model.eval()

        if use_data_parallel and torch.cuda.device_count() > 1:
            logger.info('Use data parallel predictor')
            self.model = torch.nn.DataParallel(self.model).cuda()
            self.use_cuda = True
        elif torch.cuda.is_available():
            self.model.cuda()
            self.use_cuda = True
        logger.info('Load model from {} successfully'.format(ckt_path))

    def _setup_args(self):
        """Back-fill train_args with attributes added to the global args after
        the checkpoint was saved, then propagate the flags evaluation reads
        from the global args (use_link_graph, is_test)."""
        for k, v in args.__dict__.items():
            if k not in self.train_args.__dict__:
                logger.info('Set default attribute: {}={}'.format(k, v))
                self.train_args.__dict__[k] = v
        logger.info('Args used in training: {}'.format(
            json.dumps(self.train_args.__dict__, ensure_ascii=False, indent=4)))
        args.use_link_graph = self.train_args.use_link_graph
        args.is_test = True

    @torch.no_grad()
    def predict_by_examples(self, examples: List[Example]):
        """Encode a list of triple examples.

        :param examples: examples to encode
        :return: (hr_vectors, tail_vectors) — two tensors concatenated along
            dim 0 over all batches
        """
        data_loader = torch.utils.data.DataLoader(
            Dataset(path='', examples=examples, task=args.task),
            num_workers=1,
            # inference batches are forced to at least 512 for throughput
            batch_size=max(args.batch_size, 512),
            collate_fn=collate,
            shuffle=False)

        hr_tensor_list, tail_tensor_list = [], []
        for batch_dict in data_loader:
            if self.use_cuda:
                batch_dict = move_to_cuda(batch_dict)
            outputs = self.model(**batch_dict)
            hr_tensor_list.append(outputs['hr_vector'])
            tail_tensor_list.append(outputs['tail_vector'])

        return torch.cat(hr_tensor_list, dim=0), torch.cat(tail_tensor_list, dim=0)

    @torch.no_grad()
    def predict_by_entities(self, entity_exs) -> torch.Tensor:
        """Compute an embedding for every entity in *entity_exs*.

        Each entity is wrapped in an Example whose tail is the entity itself
        (head/relation left empty) so the model's entity-only path encodes it.

        :param entity_exs: iterable of objects exposing an `entity_id` attribute
        :return: the per-batch 'ent_vectors' outputs concatenated along dim 0
        """
        examples = [Example(head_id='', relation='', tail_id=entity_ex.entity_id)
                    for entity_ex in entity_exs]
        data_loader = torch.utils.data.DataLoader(
            Dataset(path='', examples=examples, task=args.task),
            num_workers=2,
            # inference batches are forced to at least 1024 for throughput
            batch_size=max(args.batch_size, 1024),
            collate_fn=collate,
            shuffle=False)

        ent_tensor_list = []
        for batch_dict in tqdm.tqdm(data_loader):
            # request the entity-only embedding path from the model
            batch_dict['only_ent_embedding'] = True
            if self.use_cuda:
                batch_dict = move_to_cuda(batch_dict)
            outputs = self.model(**batch_dict)
            ent_tensor_list.append(outputs['ent_vectors'])

        return torch.cat(ent_tensor_list, dim=0)