generalRnn.py
import torch.nn as nn
class BaseCoder(nn.Module):
    def __init__(self, vocab_size, hidden_size, embedding_size, input_dropout,
                 output_dropout, n_layers, rnn, vocab, embeddings):
        super(BaseCoder, self).__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.embedding_size = embedding_size
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        # Copy pre-trained vectors into the embedding table, if provided.
        self.load_pretrained_embeddings(vocab, embeddings, trainable=True)
        # Keep a single dropout module; previously the raw rate was assigned
        # to the same attribute and immediately overwritten.
        self.input_dropout = nn.Dropout(p=input_dropout)
        self.output_dropout = output_dropout
        # Store the RNN class itself so subclasses can instantiate it.
        if rnn.lower() == "lstm":
            self.baseModel = nn.LSTM
        elif rnn.lower() == "gru":
            self.baseModel = nn.GRU
        else:
            raise ValueError("Unsupported RNN cell: %s" % rnn)
    def forward(self, *args, **kwargs):
        # Subclasses must implement their own forward pass.
        raise NotImplementedError
    def load_pretrained_embeddings(self, vocab, embeddings, trainable=True):
        if not vocab or not embeddings:
            return
        vocab_size = len(vocab)
        missing = 0
        # Freeze the weights while copying so autograd does not track the
        # in-place row assignments.
        self.embedding.weight.requires_grad = False
        for i in range(vocab_size):
            word = vocab.id2word[i]
            try:
                self.embedding.weight[i] = embeddings['vectors'][embeddings['dico'].index(word)]
            except (KeyError, ValueError):
                # Word absent from the pre-trained vocabulary; keep its
                # random initialization.
                missing += 1
        print('missing embeddings: %d / %d' % (missing, vocab_size))
        if trainable:
            self.embedding.weight.requires_grad = True
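
# --- Usage sketch (illustrative addition, not part of the original file) ---
# A minimal example of wiring BaseCoder into a concrete encoder, assuming the
# interface above: `self.baseModel` holds the RNN class, and passing
# vocab=None / embeddings=None skips the pre-trained embedding copy. The
# `Encoder` subclass and all hyperparameter values below are hypothetical.
import torch

class Encoder(BaseCoder):
    def __init__(self, *args, **kwargs):
        super(Encoder, self).__init__(*args, **kwargs)
        # Instantiate the RNN class selected in BaseCoder.__init__.
        self.rnn = self.baseModel(self.embedding_size, self.hidden_size,
                                  num_layers=self.n_layers, batch_first=True)

    def forward(self, input_ids):
        # (batch, seq_len) -> (batch, seq_len, embedding_size)
        embedded = self.input_dropout(self.embedding(input_ids))
        return self.rnn(embedded)  # outputs, hidden state(s)

encoder = Encoder(vocab_size=10000, hidden_size=256, embedding_size=300,
                  input_dropout=0.2, output_dropout=0.2, n_layers=2,
                  rnn="gru", vocab=None, embeddings=None)
outputs, hidden = encoder(torch.randint(0, 10000, (8, 20)))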