-
Notifications
You must be signed in to change notification settings - Fork 7
/
post_trainer.py
73 lines (54 loc) · 2.76 KB
/
post_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
from torch import optim
import numpy as np
from utils import get_posttrain_train_valid_dataset
from torch.utils.data import DataLoader
from datasets import KGETrainDataset, KGEEvalDataset
from trainer import Trainer
class PostTrainer(Trainer):
def __init__(self, args):
super(PostTrainer, self).__init__(args)
self.args = args
self.load_metatrain()
# dataloader
train_dataset, valid_dataset = get_posttrain_train_valid_dataset(args)
self.train_dataloader = DataLoader(train_dataset, batch_size=self.args.posttrain_bs,
collate_fn=KGETrainDataset.collate_fn)
self.valid_dataloader = DataLoader(valid_dataset, batch_size=args.indtest_eval_bs,
collate_fn=KGEEvalDataset.collate_fn)
self.optimizer = optim.Adam(list(self.ent_init.parameters()) + list(self.rgcn.parameters())
+ list(self.kge_model.parameters()), lr=self.args.posttrain_lr)
def load_metatrain(self):
state = torch.load(self.args.metatrain_state, map_location=self.args.gpu)
self.ent_init.load_state_dict(state['ent_init'])
self.rgcn.load_state_dict(state['rgcn'])
self.kge_model.load_state_dict(state['kge_model'])
def get_ent_emb(self, sup_g_bidir):
self.ent_init(sup_g_bidir)
ent_emb = self.rgcn(sup_g_bidir)
return ent_emb
def train(self):
self.logger.info('start fine-tuning')
# print epoch test rst
self.evaluate_indtest_test_triples(num_cand=50)
for i in range(1, self.args.posttrain_num_epoch + 1):
losses = []
for batch in self.train_dataloader:
pos_triple, neg_tail_ent, neg_head_ent = [b.to(self.args.gpu) for b in batch]
ent_emb = self.get_ent_emb(self.indtest_train_g)
loss = self.get_loss(pos_triple, neg_tail_ent, neg_head_ent, ent_emb)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
losses.append(loss.item())
self.logger.info('epoch: {} | loss: {:.4f}'.format(i, np.mean(losses)))
if i % self.args.posttrain_check_per_epoch == 0:
self.evaluate_indtest_test_triples(num_cand=50)
def evaluate_indtest_valid_triples(self, num_cand='all'):
ent_emb = self.get_ent_emb(self.indtest_train_g)
results = self.evaluate(ent_emb, self.valid_dataloader, num_cand)
self.logger.info('valid on ind-test-graph')
self.logger.info('mrr: {:.4f}, hits@1: {:.4f}, hits@5: {:.4f}, hits@10: {:.4f}'.format(
results['mrr'], results['hits@1'],
results['hits@5'], results['hits@10']))
return results