EncoderLayer.py
#!/usr/bin/python3
# -*- coding:utf-8 -*-
from tensorflow import keras

from MultiHeadAttention import MultiHeadAttention
from FFN import feed_forward_network


# Encoder Layer
class EncoderLayer(keras.layers.Layer):
    """
    x -> self attention -> dropout -> add & normalize
      -> feed forward   -> dropout -> add & normalize
    """

    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super(EncoderLayer, self).__init__()
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = feed_forward_network(d_model, dff)
        self.layer_norm1 = keras.layers.LayerNormalization(epsilon=1e-6)
        self.layer_norm2 = keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = keras.layers.Dropout(rate)
        self.dropout2 = keras.layers.Dropout(rate)

    def call(self, x, training, encoder_padding_mask):
        # x.shape: (batch_size, seq_len, dim=d_model)
        # attn_output.shape: (batch_size, seq_len, d_model)
        # out1.shape: (batch_size, seq_len, d_model)
        attn_output, _ = self.mha(x, x, x, encoder_padding_mask)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layer_norm1(x + attn_output)

        # ffn_output.shape: (batch_size, seq_len, d_model)
        # out2.shape: (batch_size, seq_len, d_model)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        out2 = self.layer_norm2(out1 + ffn_output)
        return out2
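

# ---------------------------------------------------------------------------
# Usage sketch (an illustrative assumption, not part of the original file).
# It assumes the local MultiHeadAttention layer returns a tuple
# (output, attention_weights) and that feed_forward_network(d_model, dff)
# returns a Keras model mapping (batch_size, seq_len, d_model) back to
# (batch_size, seq_len, d_model), which is how both are used in call() above.
# Passing encoder_padding_mask=None also assumes the attention implementation
# accepts a null mask.
if __name__ == "__main__":
    import tensorflow as tf

    batch_size, seq_len, d_model = 4, 10, 128
    sample_layer = EncoderLayer(d_model=d_model, num_heads=8, dff=512, rate=0.1)

    # Random input batch; no padding positions are masked out here.
    x = tf.random.uniform((batch_size, seq_len, d_model))
    out = sample_layer(x, training=False, encoder_padding_mask=None)

    # The encoder layer preserves the input shape: (4, 10, 128)
    print(out.shape)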