-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathlanguage_model.py
123 lines (102 loc) · 4.23 KB
/
language_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""
This code is modified from Hengyuan Hu's repository.
https://github.com/hengyuan-hu/bottom-up-attention-vqa
"""
import torch
import torch.nn as nn
import numpy as np
from fc import FCNet
class WordEmbedding(nn.Module):
    """Token-index to vector lookup with an optional second embedding table.

    The table holds ``ntoken + 1`` rows. The ntoken-th row is used for
    ``padding_idx``, which agrees *implicitly* with the definition in
    Dictionary.

    When ``op`` contains ``'c'`` a second table ``emb_`` of the same shape is
    created (frozen at construction time) and its lookup is concatenated to
    the lookup of ``emb``, so ``forward`` yields ``2 * emb_dim`` features
    instead of ``emb_dim``.

    Args:
        ntoken: number of real tokens (the padding row is added on top).
        emb_dim: size of each embedding vector.
        dropout: dropout probability applied to the looked-up embedding.
        op: option string; only the presence of ``'c'`` is inspected here.
    """

    def __init__(self, ntoken, emb_dim, dropout, op):
        super(WordEmbedding, self).__init__()
        self.op = op
        self.emb = nn.Embedding(ntoken + 1, emb_dim, padding_idx=ntoken)
        if 'c' in op:
            self.emb_ = nn.Embedding(ntoken + 1, emb_dim, padding_idx=ntoken)
            self.emb_.weight.requires_grad = False  # fixed
        self.dropout = nn.Dropout(dropout)
        self.ntoken = ntoken
        self.emb_dim = emb_dim

    def init_embedding(self, np_file, tfidf=None, tfidf_weights=None):
        """Load pretrained vectors into the embedding table(s).

        The array read from ``np_file`` is copied into the first ``ntoken``
        rows of ``emb``; the padding row is left untouched.

        If the second table exists (``'c' in op``) it receives a copy of the
        same vectors. When ``tfidf`` is also given, the vectors are first
        extended with the rows of ``tfidf_weights`` (if any), left-multiplied
        by ``tfidf``, and the second table is made trainable.

        Without a second table ``tfidf`` and ``tfidf_weights`` are ignored;
        previously that combination raised ``AttributeError``.

        Args:
            np_file: anything ``np.load`` accepts; must hold an array of
                shape ``(ntoken, emb_dim)``.
            tfidf: optional matrix supporting ``.matmul``, with one row per
                token and one column per row of the (extended) weights.
            tfidf_weights: optional numpy array of extra rows appended below
                the loaded weights. ``None`` or an empty array means no
                extra rows.

        Raises:
            AssertionError: if the loaded array is not
                ``(ntoken, emb_dim)``.
        """
        weight_init = torch.from_numpy(np.load(np_file))
        assert weight_init.shape == (self.ntoken, self.emb_dim)
        self.emb.weight.data[:self.ntoken] = weight_init

        # Everything below only concerns the second table.
        if 'c' not in self.op:
            return

        if tfidf is not None:
            # tfidf_weights defaults to None, so check before reading .size.
            if tfidf_weights is not None and 0 < tfidf_weights.size:
                extra_rows = torch.from_numpy(tfidf_weights)
                weight_init = torch.cat([weight_init, extra_rows], 0)
            weight_init = tfidf.matmul(weight_init)  # (N x N') x (N', F)
            self.emb_.weight.requires_grad = True
        self.emb_.weight.data[:self.ntoken] = weight_init.clone()

    def forward(self, x):
        """Look up ``x`` and apply dropout.

        Args:
            x: tensor of token indices. The concatenation below runs along
                dim 2, so the looked-up embedding must be 3-D.

        Returns:
            The embedding of ``x``; with ``'c' in op`` the lookups of both
            tables are concatenated along the last (third) dimension.
        """
        emb = self.emb(x)
        if 'c' in self.op:
            emb = torch.cat((emb, self.emb_(x)), 2)
        return self.dropout(emb)
class QuestionEmbedding(nn.Module):
    """Recurrent (GRU or LSTM) encoder over a batch-first feature sequence.

    Args:
        in_dim: size of each input feature vector.
        num_hid: hidden size of the recurrent layer(s), per direction.
        nlayers: number of stacked recurrent layers.
        bidirect: truthy to run the sequence in both directions.
        dropout: dropout value handed to the recurrent module.
        rnn_type: ``'GRU'`` (default) or ``'LSTM'``.

    Raises:
        AssertionError: if ``rnn_type`` is neither ``'LSTM'`` nor ``'GRU'``.
    """

    def __init__(self, in_dim, num_hid, nlayers, bidirect, dropout, rnn_type='GRU'):
        super(QuestionEmbedding, self).__init__()
        rnn_classes = {'LSTM': nn.LSTM, 'GRU': nn.GRU}
        assert rnn_type in rnn_classes, "rnn_type must be 'LSTM' or 'GRU'"

        self.rnn = rnn_classes[rnn_type](
            in_dim, num_hid, nlayers,
            bidirectional=bidirect,
            dropout=dropout,
            batch_first=True)

        self.in_dim = in_dim
        self.num_hid = num_hid
        self.nlayers = nlayers
        self.rnn_type = rnn_type
        self.ndirections = 1 + int(bidirect)

    def init_hidden(self, batch):
        """Return an all-zero initial state for ``batch`` sequences.

        The state has shape ``(nlayers * ndirections, batch, num_hid)`` and
        shares dtype and device with the module's first parameter. For an
        LSTM a ``(hidden, cell)`` pair of such tensors is returned.
        """
        # Only used as a template for dtype/device.
        weight = next(self.parameters()).data
        hid_shape = (self.nlayers * self.ndirections, batch, self.num_hid)
        if self.rnn_type == 'LSTM':
            return (weight.new_zeros(hid_shape),
                    weight.new_zeros(hid_shape))
        return weight.new_zeros(hid_shape)

    def _run_rnn(self, x):
        """Run the recurrent module from a zero state; return per-step output."""
        hidden = self.init_hidden(x.size(0))
        output, _ = self.rnn(x, hidden)
        return output

    def forward(self, x):
        """Encode each sequence into one vector.

        Args:
            x: ``[batch, sequence, in_dim]``

        Returns:
            Unidirectional: the output at the last step, ``[batch, num_hid]``.
            Bidirectional: the forward half of the last step concatenated
            with the backward half of the first step,
            ``[batch, 2 * num_hid]``.
        """
        output = self._run_rnn(x)
        if self.ndirections == 1:
            return output[:, -1]

        # The backward direction finishes reading at position 0.
        forward_ = output[:, -1, :self.num_hid]
        backward = output[:, 0, self.num_hid:]
        return torch.cat((forward_, backward), dim=1)

    def forward_all(self, x):
        """Return the recurrent output at every step.

        Args:
            x: ``[batch, sequence, in_dim]``

        Returns:
            ``[batch, sequence, ndirections * num_hid]``
        """
        return self._run_rnn(x)
class RnnQuestionEmbedding(nn.Module):
    """Word embedding followed by a single-layer, unidirectional GRU.

    Args:
        ntoken: vocabulary size handed to ``WordEmbedding``.
        w_dim: word-embedding size; the recurrent input is twice as wide
            when ``op`` contains ``'c'``.
        q_dim: hidden size of the recurrent encoder.
        op: option string forwarded to ``WordEmbedding``.
    """

    def __init__(self, ntoken, w_dim, q_dim, op):
        super(RnnQuestionEmbedding, self).__init__()
        self.w_emb = WordEmbedding(ntoken, w_dim, 0.0, op)

        # With 'c' the word embedding concatenates two tables.
        if 'c' in op:
            rnn_in_dim = w_dim * 2
        else:
            rnn_in_dim = w_dim
        self.rnn = QuestionEmbedding(rnn_in_dim, q_dim, 1, False, 0.0)

    def forward(self, q):
        """Embed the token indices ``q`` and return the per-step RNN output."""
        return self.rnn.forward_all(self.w_emb(q))
class BertRnnQuestionEmbedding(nn.Module):
    """Encode questions with a supplied ``bert`` module, a projection and a GRU.

    Args:
        bert: callable module invoked as
            ``bert(q, output_all_encoded_layers=False)``; it must return a
            pair whose first item is the per-token features.
        rot_dim: output width requested from the ``FCNet`` projection.
        q_dim: hidden size of the recurrent encoder.
        op: option string; ``'c'`` doubles the recurrent input width.
    """

    def __init__(self, bert, rot_dim, q_dim, op):
        super(BertRnnQuestionEmbedding, self).__init__()
        # presumably the hidden size of the BERT model passed in - confirm
        bert_dim = 768
        self.w_emb = bert
        self.w_emb_ = FCNet([bert_dim, rot_dim], 'ReLU', 0.)

        # NOTE(review): with 'c' in op the RNN expects 2 * rot_dim inputs,
        # while forward() feeds it the projection output only. If FCNet
        # emits rot_dim features these widths disagree - confirm.
        if 'c' in op:
            rnn_in_dim = rot_dim * 2
        else:
            rnn_in_dim = rot_dim
        self.rnn = QuestionEmbedding(rnn_in_dim, q_dim, 1, False, 0.0)

    def forward(self, q):
        """Return the per-step RNN output for the questions ``q``."""
        # The second item returned by bert is not used.
        token_feats, _ = self.w_emb(q, output_all_encoded_layers=False)
        projected = self.w_emb_(token_feats)
        return self.rnn.forward_all(projected)