Fix layer norm usage (#5)
esceptico authored Sep 26, 2021
1 parent 628c9cd · commit 361559a
Showing 1 changed file with 6 additions and 4 deletions.
src/perceiver/attention.py (10 changes: 6 additions & 4 deletions)
@@ -128,6 +128,7 @@ def __init__(
         """
         super().__init__()
         self.layer_norm = nn.LayerNorm(hidden_dim)
+        self.qkv_layer_norm = nn.LayerNorm(hidden_dim)
         self.attention = MultiHeadAttention(
             kv_dim=hidden_dim,
             q_dim=hidden_dim,
@@ -157,7 +158,7 @@ def forward(
         )
         attention = self.dropout(attention)
         x = x + attention
-        x = x + self.mlp(x)
+        x = x + self.mlp(self.qkv_layer_norm(x))
         return x
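
For context, a minimal sketch of the self-attention block this hunk produces: both the attention input and the MLP input are layer-normalized (the pre-LN pattern), whereas before the commit the MLP received the raw residual stream. The class name SelfAttentionBlock, the num_heads/dropout defaults, the MLP shape, and torch.nn.MultiheadAttention standing in for the repository's MultiHeadAttention are assumptions for illustration, not the repository's actual code.

import torch
from torch import nn


class SelfAttentionBlock(nn.Module):
    """Pre-LayerNorm self-attention block (hypothetical reconstruction)."""

    def __init__(self, hidden_dim: int, num_heads: int = 8, dropout: float = 0.0):
        super().__init__()
        self.layer_norm = nn.LayerNorm(hidden_dim)      # normalizes the attention input
        self.qkv_layer_norm = nn.LayerNorm(hidden_dim)  # added by this commit: normalizes the MLP input
        # Stand-in for the repository's MultiHeadAttention(kv_dim=..., q_dim=...).
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_norm = self.layer_norm(x)
        attention, _ = self.attention(x_norm, x_norm, x_norm)
        attention = self.dropout(attention)
        x = x + attention
        # The fix: the MLP sees a normalized copy of the residual stream
        # (previously x = x + self.mlp(x)).
        x = x + self.mlp(self.qkv_layer_norm(x))
        return x


# Shape check: (batch, sequence, hidden) in, same shape out.
block = SelfAttentionBlock(hidden_dim=256)
out = block(torch.randn(2, 16, 256))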


@@ -195,6 +196,7 @@ def __init__(
         self.use_query_residual = use_query_residual
         self.kv_layer_norm = nn.LayerNorm(kv_dim)
         self.q_layer_norm = nn.LayerNorm(q_dim)
+        self.qkv_layer_norm = nn.LayerNorm(q_dim)
         self.attention = MultiHeadAttention(
             kv_dim=kv_dim,
             q_dim=q_dim,
@@ -219,14 +221,14 @@ def forward(
                 in [0, 1]. Defaults to None.
         """
         attention = self.attention(
-            inputs_kv=inputs_kv,
-            inputs_q=inputs_q,
+            inputs_kv=self.kv_layer_norm(inputs_kv),
+            inputs_q=self.q_layer_norm(inputs_q),
             attention_mask=attention_mask
         )
         attention = self.dropout(attention)
         if self.use_query_residual:
             x = inputs_q + attention
         else:
             x = attention
-        x = x + self.mlp(x)
+        x = x + self.mlp(self.qkv_layer_norm(x))
         return x
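
And a corresponding sketch of the cross-attention block after the fix: the key/value and query inputs are normalized before attention, and the new qkv_layer_norm normalizes the MLP input as well. The class name CrossAttentionBlock, the num_heads/dropout defaults, the MLP shape, and torch.nn.MultiheadAttention (with kdim/vdim in place of the repository's MultiHeadAttention) are assumptions for illustration; the attention_mask handling from the original forward() is omitted for brevity.

import torch
from torch import nn


class CrossAttentionBlock(nn.Module):
    """Pre-LayerNorm cross-attention block (hypothetical reconstruction)."""

    def __init__(self, kv_dim: int, q_dim: int, num_heads: int = 1,
                 dropout: float = 0.0, use_query_residual: bool = True):
        super().__init__()
        self.use_query_residual = use_query_residual
        self.kv_layer_norm = nn.LayerNorm(kv_dim)   # normalizes the key/value input
        self.q_layer_norm = nn.LayerNorm(q_dim)     # normalizes the query input
        self.qkv_layer_norm = nn.LayerNorm(q_dim)   # added by this commit: normalizes the MLP input
        # Stand-in for the repository's MultiHeadAttention(kv_dim=..., q_dim=...).
        self.attention = nn.MultiheadAttention(
            embed_dim=q_dim, num_heads=num_heads,
            kdim=kv_dim, vdim=kv_dim, batch_first=True,
        )
        self.dropout = nn.Dropout(dropout)
        self.mlp = nn.Sequential(
            nn.Linear(q_dim, q_dim),
            nn.GELU(),
            nn.Linear(q_dim, q_dim),
        )

    def forward(self, inputs_kv: torch.Tensor, inputs_q: torch.Tensor) -> torch.Tensor:
        kv = self.kv_layer_norm(inputs_kv)
        q = self.q_layer_norm(inputs_q)
        attention, _ = self.attention(q, kv, kv)
        attention = self.dropout(attention)
        if self.use_query_residual:
            x = inputs_q + attention
        else:
            x = attention
        # The fix: normalize before the MLP (previously x = x + self.mlp(x)).
        x = x + self.mlp(self.qkv_layer_norm(x))
        return x


# A byte array (batch, M, kv_dim) is attended into a latent array (batch, N, q_dim).
block = CrossAttentionBlock(kv_dim=64, q_dim=256)
latents = block(inputs_kv=torch.randn(2, 100, 64), inputs_q=torch.randn(2, 8, 256))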
