diff --git a/src/perceiver/attention.py b/src/perceiver/attention.py
index cea8df3..a9fcdd2 100644
--- a/src/perceiver/attention.py
+++ b/src/perceiver/attention.py
@@ -128,6 +128,7 @@ def __init__(
         """
         super().__init__()
         self.layer_norm = nn.LayerNorm(hidden_dim)
+        self.qkv_layer_norm = nn.LayerNorm(hidden_dim)
         self.attention = MultiHeadAttention(
             kv_dim=hidden_dim,
             q_dim=hidden_dim,
@@ -157,7 +158,7 @@ def forward(
         )
         attention = self.dropout(attention)
         x = x + attention
-        x = x + self.mlp(x)
+        x = x + self.mlp(self.qkv_layer_norm(x))
         return x
 
 
@@ -195,6 +196,7 @@ def __init__(
         self.use_query_residual = use_query_residual
         self.kv_layer_norm = nn.LayerNorm(kv_dim)
         self.q_layer_norm = nn.LayerNorm(q_dim)
+        self.qkv_layer_norm = nn.LayerNorm(q_dim)
         self.attention = MultiHeadAttention(
             kv_dim=kv_dim,
             q_dim=q_dim,
@@ -219,8 +221,8 @@ def forward(
                 in [0, 1]. Defaults to None.
         """
         attention = self.attention(
-            inputs_kv=inputs_kv,
-            inputs_q=inputs_q,
+            inputs_kv=self.kv_layer_norm(inputs_kv),
+            inputs_q=self.q_layer_norm(inputs_q),
             attention_mask=attention_mask
         )
         attention = self.dropout(attention)
@@ -228,5 +230,5 @@
             x = inputs_q + attention
         else:
             x = attention
-        x = x + self.mlp(x)
+        x = x + self.mlp(self.qkv_layer_norm(x))
         return x
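
The diff moves both attention blocks to a pre-norm residual layout: the cross-attention block now applies its existing kv/q LayerNorms to the inputs before attending, and both blocks gain a `qkv_layer_norm` that normalises the residual stream before the MLP. The listing below is a minimal, self-contained sketch of that layout only; it is not the repository's code. It substitutes `torch.nn.MultiheadAttention` for the project's `MultiHeadAttention`, uses a plain two-layer MLP, and the class name, hyperparameters, and the `__main__` shapes are illustrative assumptions.

# Sketch of the pre-norm pattern introduced by the diff (hypothetical stand-in,
# not src/perceiver/attention.py itself).
import torch
import torch.nn as nn


class PreNormCrossAttentionBlock(nn.Module):
    """Illustrative block: LayerNorm before attention and again before the MLP."""

    def __init__(self, q_dim: int, kv_dim: int, num_heads: int = 8, dropout: float = 0.0,
                 use_query_residual: bool = True):
        super().__init__()
        self.kv_layer_norm = nn.LayerNorm(kv_dim)
        self.q_layer_norm = nn.LayerNorm(q_dim)
        self.qkv_layer_norm = nn.LayerNorm(q_dim)  # applied before the MLP, as in the diff
        # torch.nn.MultiheadAttention stands in for the repo's MultiHeadAttention.
        self.attention = nn.MultiheadAttention(
            embed_dim=q_dim, kdim=kv_dim, vdim=kv_dim,
            num_heads=num_heads, batch_first=True,
        )
        self.dropout = nn.Dropout(dropout)
        self.mlp = nn.Sequential(
            nn.Linear(q_dim, 4 * q_dim), nn.GELU(), nn.Linear(4 * q_dim, q_dim),
        )
        self.use_query_residual = use_query_residual

    def forward(self, inputs_q: torch.Tensor, inputs_kv: torch.Tensor) -> torch.Tensor:
        # Pre-norm attention: normalise queries and keys/values before attending.
        normed_q = self.q_layer_norm(inputs_q)
        normed_kv = self.kv_layer_norm(inputs_kv)
        attention, _ = self.attention(normed_q, normed_kv, normed_kv)
        attention = self.dropout(attention)
        x = inputs_q + attention if self.use_query_residual else attention
        # Pre-norm MLP: normalise the residual stream before the feed-forward sub-layer.
        return x + self.mlp(self.qkv_layer_norm(x))


if __name__ == "__main__":
    block = PreNormCrossAttentionBlock(q_dim=64, kv_dim=32)
    latents = torch.randn(2, 16, 64)   # (batch, num_latents, q_dim)
    inputs = torch.randn(2, 100, 32)   # (batch, num_inputs, kv_dim)
    print(block(latents, inputs).shape)  # torch.Size([2, 16, 64])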