Commit

improvise on bidirectional for multi-head learned ema
lucidrains committed Sep 24, 2022
1 parent 41b4485 commit a7a4480
Showing 3 changed files with 26 additions and 15 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -54,8 +54,8 @@ logits = mega(x) # (1, 1024, 256)

## Todo

- [ ] how did they approach bidirectionality in multi-headed EMA?
- [ ] authors mistaken about extrapolative abilities of RoPE. replace with dynamic positional bias eventually
- [x] how did they approach bidirectionality in multi-headed EMA?

## Citations

37 changes: 24 additions & 13 deletions mega_pytorch/mega_pytorch.py
@@ -147,7 +147,7 @@ def __init__(
)

def forward(self, x, v_input = None):
seq_len, dim = x.shape[-2:]
seq_len, dim, device, dtype = *x.shape[-2:], x.device, x.dtype

is_softmax_attn = not self.laplacian_attn_fn

@@ -163,8 +163,7 @@ def forward(self, x, v_input = None):
sim = sim + self.rel_pos_bias(sim)

if self.causal:
n, device, dtype = x.shape[1], x.device, x.dtype
causal_mask = torch.ones((n, n), device = device, dtype = torch.bool).triu(1)
causal_mask = torch.ones((seq_len, seq_len), device = device, dtype = torch.bool).triu(1)

if self.causal and not self.laplacian_attn_fn:
# is softmax attention and using large negative value pre-softmax
@@ -184,12 +183,11 @@ def __init__(
*,
dim,
heads,
bidirectional = False,
dim_head = None
):
super().__init__()
dim_head = default(dim_head, dim)
inner_dim = heads * dim_head
self.heads = heads
self.bidirectional = bidirectional

self.expansion = nn.Parameter(torch.randn(heads, dim))
self.reduction = nn.Parameter(torch.randn(heads, dim))
@@ -199,6 +197,10 @@ def __init__(
self.alphas = nn.Parameter(torch.randn(heads))
self.dampen_factors = nn.Parameter(torch.randn(heads))

if bidirectional:
self.reverse_alphas = nn.Parameter(torch.randn(heads))
self.reverse_dampen_factors = nn.Parameter(torch.randn(heads))

def forward(self, x):
device, seq_len = x.device, x.shape[1]

@@ -208,19 +210,27 @@ def forward(self, x):

# weights derived from alphas (learned exponential smoothing decay rate)

alphas = self.alphas.sigmoid()
dampen_factors = self.dampen_factors.sigmoid()
def apply_learned_ema_with_damping(x, alphas, dampen_factors):
alphas = alphas.sigmoid()
dampen_factors = dampen_factors.sigmoid()

reversed_powers = torch.arange(seq_len - 1, -1, -1, device = device)
K = alphas * (((1 - alphas) * dampen_factors) ** rearrange(reversed_powers, '... l -> ... l 1'))

# conv1d fft O(nlog(n))

reversed_powers = torch.arange(seq_len - 1, -1, -1, device = device)
K = alphas * (((1 - alphas) * dampen_factors) ** rearrange(reversed_powers, '... l -> ... l 1'))
return conv1d_fft(x, K, dim = -3, weight_dim = -2)

# conv1d fft O(nlog(n))
x = apply_learned_ema_with_damping(x, self.alphas, self.dampen_factors)

out = conv1d_fft(x, K, dim = -3, weight_dim = -2)
if self.bidirectional:
x = torch.flip(x, dims = (1,))
x = apply_learned_ema_with_damping(x, self.reverse_alphas, self.reverse_dampen_factors)
x = torch.flip(x, dims = (1,))

# combine heads and out

return einsum('... h d, h d -> ... d', out, self.reduction)
return einsum('... h d, h d -> ... d', x, self.reduction)

# Mega Layer
# Single headed Attention + Multi-headed EMA, then GRU-esque gating
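The kernel in the hunk above, `K = alphas * (((1 - alphas) * dampen_factors) ** reversed_powers)`, is the unrolled form of a damped exponential moving average, which is why it can be applied as a convolution over the sequence (and hence via FFT in O(n log n)). Below is a minimal standalone check, not repo code: it assumes `conv1d_fft` performs a causal convolution along the sequence dimension, reduces to a single head, and uses scalar stand-ins for `alphas.sigmoid()` and `dampen_factors.sigmoid()`.

```python
import torch

torch.manual_seed(0)

seq_len = 16
alpha = torch.rand(())   # stand-in for one head of self.alphas.sigmoid()
delta = torch.rand(())   # stand-in for one head of self.dampen_factors.sigmoid()
x = torch.randn(seq_len)

# recurrence form of the damped EMA: h_t = alpha * x_t + (1 - alpha) * delta * h_{t-1}
h = torch.zeros(())
recurrence = []
for t in range(seq_len):
    h = alpha * x[t] + (1 - alpha) * delta * h
    recurrence.append(h)
recurrence = torch.stack(recurrence)

# kernel form: K[j] = alpha * ((1 - alpha) * delta) ** j, applied as a causal convolution
powers = torch.arange(seq_len)
K = alpha * ((1 - alpha) * delta) ** powers
kernel_form = torch.stack([
    (K[:t + 1].flip(dims = (0,)) * x[:t + 1]).sum()   # y_t = sum_j K[j] * x_{t-j}
    for t in range(seq_len)
])

assert torch.allclose(recurrence, kernel_form, atol = 1e-5)
```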
@@ -250,6 +260,7 @@ def __init__(
self.multi_headed_ema = MultiHeadedEMA(
dim = dim,
heads = ema_heads,
bidirectional = not causal,
dim_head = ema_dim_head
)
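
A hedged usage sketch of the new flag, assuming this commit of the package is installed and importing straight from the module shown in this diff (sizes are arbitrary): with `bidirectional = True`, `MultiHeadedEMA` runs the learned EMA left-to-right, then flips the sequence, applies a second EMA with the separate `reverse_alphas` / `reverse_dampen_factors` parameters, and flips back. The Mega layer change above enables this automatically for non-causal models via `bidirectional = not causal`.

```python
import torch
from mega_pytorch.mega_pytorch import MultiHeadedEMA

# illustrative sizes only
ema = MultiHeadedEMA(dim = 128, heads = 16, bidirectional = True)

x = torch.randn(1, 1024, 128)   # (batch, seq, dim)
out = ema(x)                    # (1, 1024, 128), smoothed in both directions
```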

2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'Mega-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.5',
version = '0.0.6',
license='MIT',
description = 'Mega - Pytorch',
author = 'Phil Wang',
