From a7a4480331079215095effec466086f790323ff7 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Fri, 23 Sep 2022 19:32:37 -0700
Subject: [PATCH] improvise on bidirectional for multi-head learned ema

---
 README.md                    |  2 +-
 mega_pytorch/mega_pytorch.py | 37 +++++++++++++++++++++++-------------
 setup.py                     |  2 +-
 3 files changed, 26 insertions(+), 15 deletions(-)

diff --git a/README.md b/README.md
index cfb84e4..43e881c 100644
--- a/README.md
+++ b/README.md
@@ -54,8 +54,8 @@ logits = mega(x) # (1, 1024, 256)
 
 ## Todo
 
-- [ ] how did they approach bidirectionality in multi-headed EMA?
 - [ ] authors mistakened about extrapolative abilities of RoPE. replace with dynamic positional bias eventually
+- [x] how did they approach bidirectionality in multi-headed EMA?
 
 ## Citations
 
diff --git a/mega_pytorch/mega_pytorch.py b/mega_pytorch/mega_pytorch.py
index 00b7579..27b1a29 100644
--- a/mega_pytorch/mega_pytorch.py
+++ b/mega_pytorch/mega_pytorch.py
@@ -147,7 +147,7 @@ def __init__(
         )
 
     def forward(self, x, v_input = None):
-        seq_len, dim = x.shape[-2:]
+        seq_len, dim, device, dtype = *x.shape[-2:], x.device, x.dtype
 
         is_softmax_attn = not self.laplacian_attn_fn
 
@@ -163,8 +163,7 @@ def forward(self, x, v_input = None):
         sim = sim + self.rel_pos_bias(sim)
 
         if self.causal:
-            n, device = x.shape[1], x.device, x.dtype
-            causal_mask = torch.ones((n, n), device = device, dtype = torch.bool).triu(1)
+            causal_mask = torch.ones((seq_len, seq_len), device = device, dtype = torch.bool).triu(1)
 
         if self.causal and not self.laplacian_attn_fn:
             # is softmax attention and using large negative value pre-softmax
@@ -184,12 +183,11 @@ def __init__(
         *,
         dim,
         heads,
+        bidirectional = False,
         dim_head = None
     ):
         super().__init__()
-        dim_head = default(dim_head, dim)
-        inner_dim = heads * dim_head
-        self.heads = heads
+        self.bidirectional = bidirectional
 
         self.expansion = nn.Parameter(torch.randn(heads, dim))
         self.reduction = nn.Parameter(torch.randn(heads, dim))
@@ -199,6 +197,10 @@ def __init__(
         self.alphas = nn.Parameter(torch.randn(heads))
         self.dampen_factors = nn.Parameter(torch.randn(heads))
 
+        if bidirectional:
+            self.reverse_alphas = nn.Parameter(torch.randn(heads))
+            self.reverse_dampen_factors = nn.Parameter(torch.randn(heads))
+
     def forward(self, x):
         device, seq_len = x.device, x.shape[1]
 
@@ -208,19 +210,27 @@ def forward(self, x):
 
         # weights derived from alphas (learned exponential smoothing decay rate)
 
-        alphas = self.alphas.sigmoid()
-        dampen_factors = self.dampen_factors.sigmoid()
+        def apply_learned_ema_with_damping(x, alphas, dampen_factors):
+            alphas = alphas.sigmoid()
+            dampen_factors = dampen_factors.sigmoid()
+
+            reversed_powers = torch.arange(seq_len - 1, -1, -1, device = device)
+            K = alphas * (((1 - alphas) * dampen_factors) ** rearrange(reversed_powers, '... l -> ... l 1'))
+
+            # conv1d fft O(nlog(n))
 
-        reversed_powers = torch.arange(seq_len - 1, -1, -1, device = device)
-        K = alphas * (((1 - alphas) * dampen_factors) ** rearrange(reversed_powers, '... l -> ... l 1'))
+            return conv1d_fft(x, K, dim = -3, weight_dim = -2)
 
-        # conv1d fft O(nlog(n))
+        x = apply_learned_ema_with_damping(x, self.alphas, self.dampen_factors)
 
-        out = conv1d_fft(x, K, dim = -3, weight_dim = -2)
+        if self.bidirectional:
+            x = torch.flip(x, dims = (1,))
+            x = apply_learned_ema_with_damping(x, self.reverse_alphas, self.reverse_dampen_factors)
+            x = torch.flip(x, dims = (1,))
 
         # combine heads and out
 
-        return einsum('... h d, h d -> ... d', out, self.reduction)
+        return einsum('... h d, h d -> ... d', x, self.reduction)
 
 # Mega Layer
 # Single headed Attention + Multi-headed EMA, then GRU-esque gating
@@ -250,6 +260,7 @@ def __init__(
         self.multi_headed_ema = MultiHeadedEMA(
             dim = dim,
             heads = ema_heads,
+            bidirectional = not causal,
             dim_head = ema_dim_head
         )
 
diff --git a/setup.py b/setup.py
index c48c12e..482d816 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'Mega-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.5',
+  version = '0.0.6',
   license='MIT',
   description = 'Mega - Pytorch',
   author = 'Phil Wang',
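
Note (editorial sketch, not part of the patch): the bidirectional path above runs the learned damped EMA forward over the sequence, then flips the sequence, runs a second EMA with the separate reverse_alphas / reverse_dampen_factors, and flips back. The self-contained snippet below reproduces that idea with an explicit recurrence in place of the repo's conv1d_fft helper; the BidirectionalEMASketch module, its loop, and its parameter shapes are illustrative assumptions for clarity, not the library's API.

# illustrative sketch only - the real MultiHeadedEMA convolves with the kernel
# K = alpha * ((1 - alpha) * dampen) ** t via an FFT for O(n log n)

import torch
from torch import nn

class BidirectionalEMASketch(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        self.expansion = nn.Parameter(torch.randn(heads, dim))
        self.reduction = nn.Parameter(torch.randn(heads, dim))
        # forward- and reverse-direction decay parameters, mirroring the patch
        self.alphas = nn.Parameter(torch.randn(heads))
        self.dampen_factors = nn.Parameter(torch.randn(heads))
        self.reverse_alphas = nn.Parameter(torch.randn(heads))
        self.reverse_dampen_factors = nn.Parameter(torch.randn(heads))

    @staticmethod
    def ema(x, alphas, dampen_factors):
        # x: (batch, seq, heads, dim) - damped EMA scanned along the sequence,
        # equivalent (with zero initial state) to the patch's convolution kernel
        alphas = alphas.sigmoid()[None, :, None]            # (1, heads, 1)
        dampen = dampen_factors.sigmoid()[None, :, None]    # (1, heads, 1)
        state = torch.zeros_like(x[:, 0])
        out = []
        for t in range(x.shape[1]):
            state = alphas * x[:, t] + (1 - alphas) * dampen * state
            out.append(state)
        return torch.stack(out, dim = 1)

    def forward(self, x):
        # project tokens into heads: (batch, seq, dim) -> (batch, seq, heads, dim)
        x = torch.einsum('b n d, h d -> b n h d', x, self.expansion)
        # forward-direction smoothing
        x = self.ema(x, self.alphas, self.dampen_factors)
        # reverse direction: flip, smooth with its own parameters, flip back
        x = torch.flip(x, dims = (1,))
        x = self.ema(x, self.reverse_alphas, self.reverse_dampen_factors)
        x = torch.flip(x, dims = (1,))
        # combine heads back to the model dimension
        return torch.einsum('b n h d, h d -> b n d', x, self.reduction)

tokens = torch.randn(1, 16, 32)
smoothed = BidirectionalEMASketch(dim = 32, heads = 4)(tokens)
print(smoothed.shape)  # torch.Size([1, 16, 32])

The O(n) Python loop here is for exposition only; the patched module keeps the O(n log n) FFT convolution and simply composes a second, flipped pass when bidirectional = not causal.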