-
Notifications
You must be signed in to change notification settings - Fork 2
/
Encoder.py
34 lines (25 loc) · 1.11 KB
/
Encoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch.nn as nn
class EncoderRNN(nn.Module):
    """Bidirectional GRU encoder over a batch of padded word-index sequences.

    The embedding layer is injected (rather than created here) so its
    weights can be shared with a decoder. The forward and backward GRU
    outputs are summed, so the output feature size stays ``hidden_size``
    despite the encoder being bidirectional.
    """

    def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
        """
        Args:
            hidden_size: GRU hidden size; must equal the embedding dim,
                because the embedded input feeds the GRU directly.
            embedding: an ``nn.Embedding`` (or compatible) module mapping
                word indices to vectors of ``hidden_size`` features.
            n_layers: number of stacked GRU layers.
            dropout: inter-layer GRU dropout probability. Ignored when
                ``n_layers == 1`` (PyTorch applies dropout only between
                layers, and warns otherwise), so the default preserves
                the original behavior.
        """
        super(EncoderRNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embedding = embedding
        # input_size and hidden_size are both 'hidden_size' because the
        # input is a word embedding with number of features == hidden_size.
        self.gru = nn.GRU(hidden_size,
                          hidden_size,
                          n_layers,
                          dropout=(0 if n_layers == 1 else dropout),
                          bidirectional=True)

    def forward(self, input_seq, input_lengths, hidden=None):
        """Encode a padded batch of sequences.

        Args:
            input_seq: LongTensor of word indices, shape (max_len, batch).
            input_lengths: lengths of each sequence in the batch, sorted
                in decreasing order (required by pack_padded_sequence with
                its default ``enforce_sorted=True``).
            hidden: optional initial hidden state,
                shape (n_layers * 2, batch, hidden_size).

        Returns:
            outputs: Tensor of shape (max_len, batch, hidden_size) —
                the sum of the forward- and backward-direction features.
            hidden: final hidden state, shape (n_layers * 2, batch,
                hidden_size).
        """
        embedded = self.embedding(input_seq)
        # Pack so the GRU skips the padding positions.
        packed = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
        outputs, hidden = self.gru(packed, hidden)
        # Unpack back to a padded tensor; feature dim is 2 * hidden_size.
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs)
        # Sum bidirectional outputs so the feature size stays hidden_size.
        outputs = (outputs[:, :, :self.hidden_size]
                   + outputs[:, :, self.hidden_size:])
        return outputs, hidden