Attention scale bug (#10)
* Fix attention scale

* Change version

* Cleanup encoder docstring
esceptico authored Nov 18, 2021
1 parent 6b65073 commit ea7f82c
Showing 3 changed files with 2 additions and 5 deletions.
setup.py: 2 changes (1 addition & 1 deletion)
@@ -3,7 +3,7 @@

 setup(
     name='perceiver-io-pytorch',
-    version='0.1.3rc1',
+    version='0.1.4',
     packages=['perceiver_io'],
     package_dir={'': 'src'},
     url='https://github.com/esceptico/perceiver-io',
src/perceiver_io/attention.py: 2 changes (1 addition & 1 deletion)
@@ -71,7 +71,7 @@ def forward(
         k = rearrange(k, 'b s (n h) -> b n s h', h=self.qk_head_dim)
         q = rearrange(q, 'b s (n h) -> b n s h', h=self.qk_head_dim)
         v = rearrange(v, 'b s (n h) -> b n s h', h=self.v_head_dim)
-        attention = (q @ k.transpose(-2, -1) / self.scale)
+        attention = (q @ k.transpose(-2, -1) * self.scale)
         if attention_mask is not None:
             min_value = torch.finfo(attention.dtype).min
             extended_mask = (1 - attention_mask) * min_value
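For reference, the fix replaces a division by self.scale with a multiplication. Below is a minimal, self-contained sketch of scaled dot-product attention under the common convention that the scale factor equals head_dim ** -0.5 (that is, 1/sqrt(d_k)); under that convention the attention logits are multiplied by the scale. The function name and signature here are illustrative only, not the repository's API.

import torch

def scaled_dot_product_attention(q, k, v, scale=None):
    # Illustrative sketch, assuming scale = head_dim ** -0.5 as in standard attention.
    if scale is None:
        scale = q.size(-1) ** -0.5
    # Multiplying by the sub-unit scale damps the logits before the softmax;
    # dividing by it, as the pre-fix code did, would inflate them instead.
    logits = q @ k.transpose(-2, -1) * scale
    weights = logits.softmax(dim=-1)
    return weights @ v

With the scale defined as a reciprocal square root, multiplication and division are not interchangeable, which is why the one-character change above matters.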
src/perceiver_io/encoder.py: 3 changes (0 additions & 3 deletions)
@@ -42,9 +42,6 @@ def __init__(
                 Defaults to None.
             v_out_dim: Size of Value matrix last dimension.
                 Defaults to None.
-            cross_attn_head_dim: Size of cross-attention head. If None,this
-                value will be calculated as latent_dim / num_cross_attn_heads.
-                Defaults to None.
             num_cross_attn_heads: Number of cross-attention heads.
                 Defaults to 1.
             num_self_attn_heads: Number of self-attention heads.
