Commit

throw in full mega architecture
lucidrains committed Sep 24, 2022
1 parent 4809dd5 commit eff465d
Showing 4 changed files with 77 additions and 2 deletions.
29 changes: 29 additions & 0 deletions README.md
@@ -12,6 +12,8 @@ $ pip install mega-pytorch

## Usage

The Mega layer, combining attention with a learned exponential moving average (EMA)

```python
import torch
from mega_pytorch import MegaLayer
@@ -25,9 +27,36 @@ layer = MegaLayer(
)

x = torch.randn(1, 1024, 128) # (batch, seq, dim)

out = layer(x) # (1, 1024, 128)
```

Full Mega (with layernorm for now)

```python
import torch
from mega_pytorch import Mega

mega = Mega(
    num_tokens = 256,           # number of tokens
    dim = 128,                  # model dimension
    depth = 6,                  # number of layers
    causal = False,             # autoregressive or not
    ema_heads = 16,             # number of EMA heads
    attn_dim_qk = 64,           # dimension of queries / keys in attention
    attn_dim_value = 256,       # dimension of values in attention
    laplacian_attn_fn = True,   # whether to use the Laplacian attention function (improved ReLU squared) in place of softmax (False)
)

x = torch.randint(0, 256, (1, 1024))

logits = mega(x) # (1, 1024, 256)
```
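
For the autoregressive case, a minimal training-step sketch, assuming the same hyperparameters as above with `causal = True` and ordinary next-token cross-entropy:

```python
import torch
import torch.nn.functional as F
from mega_pytorch import Mega

model = Mega(
    num_tokens = 256,
    dim = 128,
    depth = 6,
    causal = True,              # assumed: causal attention for next-token prediction
    ema_heads = 16,
    attn_dim_qk = 64,
    attn_dim_value = 256,
    laplacian_attn_fn = True
)

ids = torch.randint(0, 256, (1, 1025))
inp, labels = ids[:, :-1], ids[:, 1:]   # shift by one position for next-token targets

logits = model(inp)                      # (1, 1024, 256)

loss = F.cross_entropy(
    logits.transpose(1, 2),              # cross_entropy expects (batch, classes, seq)
    labels
)
loss.backward()
```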

## Todo

- [ ] how did they approach bidirectionality in multi-headed EMA?

## Citations

```bibtex
2 changes: 1 addition & 1 deletion mega_pytorch/__init__.py
@@ -1 +1 @@
-from mega_pytorch.mega_pytorch import MegaLayer
+from mega_pytorch.mega_pytorch import MegaLayer, Mega
46 changes: 46 additions & 0 deletions mega_pytorch/mega_pytorch.py
@@ -287,3 +287,49 @@ def forward(self, x):
        # update gate

        return update_gate * H + (1 - update_gate) * x

# Mega

def FeedForward(dim, ff_mult):
    dim_hidden = int(dim * ff_mult)
    return nn.Sequential(
        nn.Linear(dim, dim_hidden),
        nn.GELU(),
        nn.Linear(dim_hidden, dim)
    )

class Mega(nn.Module):
    def __init__(
        self,
        *,
        dim,
        num_tokens,
        depth,
        ff_mult = 2,
        **kwargs
    ):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, dim)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                MegaLayer(dim = dim, **kwargs),  # forward dim so the layer width matches the embedding
                nn.LayerNorm(dim),
                FeedForward(dim = dim, ff_mult = ff_mult),
                nn.LayerNorm(dim)
            ]))

        self.to_logits = nn.Linear(dim, num_tokens)

    def forward(self, x):
        x = self.token_emb(x)

        for mega_layer, post_mega_norm, ff, post_ff_norm in self.layers:
            # MegaLayer carries its own gated residual internally (see the update gate above)
            x = mega_layer(x)
            x = post_mega_norm(x)

            x = ff(x) + x
            x = post_ff_norm(x)

        return self.to_logits(x)
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'Mega-pytorch',
packages = find_packages(exclude=[]),
-version = '0.0.3',
+version = '0.0.4',
license='MIT',
description = 'Mega - Pytorch',
author = 'Phil Wang',
