Skip to content

Commit

Permalink
complete the main ideas of the Perceiver
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Mar 5, 2021
1 parent 0a52000 commit 2e2a9cd
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 10 deletions.
31 changes: 31 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,37 @@

Implementation of <a href="https://arxiv.org/abs/2103.03206">Perceiver</a>, General Perception with Iterative Attention, in Pytorch

## Install

```bash
$ pip install perceiver-pytorch
```

## Usage

```python
import torch
from perceiver_pytorch.perceiver_pytorch import Perceiver

model = Perceiver(
num_fourier_features = 6, # number of fourier features, with original value (2 * K + 1)
depth = 48, # depth of net, in paper, they went deep, making up for lack of attention
num_latents = 6, # number of latents, or induced set points, or centroids. different papers giving it different names
cross_dim = 512, # cross attention dimension
latent_dim = 512, # latent dimension
cross_heads = 1, # number of heads for cross attention. paper said 1
latent_heads = 8, # number of heads for latent self attention, 8
cross_dim_head = 64,
latent_dim_head = 64,
num_classes = 1000, # output number of classes
attn_dropout = 0.,
ff_dropout = 0.
)

img = torch.randn(1, 224 * 224) # 1 imagenet image, pixelized

model(img)
```
## Citations

```bibtex
Expand Down
90 changes: 80 additions & 10 deletions perceiver_pytorch/perceiver_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
def exists(val):
return val is not None

def default(val, d):
return val if exists(val) else d

def fourier_encode(x, num_encodings = 4):
x = x.unsqueeze(-1)
device, dtype, orig_x = x.device, x.dtype, x
Expand All @@ -21,13 +24,20 @@ def fourier_encode(x, num_encodings = 4):
# helper classes

class PreNorm(nn.Module):
def __init__(self, dim, fn):
def __init__(self, dim, fn, context_dim = None):
super().__init__()
self.fn = fn
self.norm = nn.LayerNorm(dim)
self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None

def forward(self, x, **kwargs):
x = self.norm(x)

if exists(self.norm_context):
context = kwargs['context']
normed_context = self.norm_context(context)
kwargs.update(context = normed_context)

return self.fn(x, **kwargs)

class FeedForward(nn.Module):
Expand All @@ -43,25 +53,85 @@ def __init__(self, dim, mult = 4, dropout = 0.):
def forward(self, x):
return self.net(x)

class Attention(nn.Module):
def __init__(self, query_dim, context_dim = None, heads = 8, dim_head = 75, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)

self.scale = dim_head ** -0.5
self.heads = heads

self.to_q = nn.Linear(query_dim, inner_dim, bias = False)
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)

self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout)
)

def forward(self, x, context = None):
h = self.heads

q = self.to_q(x)
context = default(context, x)
k, v = self.to_kv(context).chunk(2, dim = -1)

q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))

sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
attn = sim.softmax(dim = -1)
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
return self.to_out(out)

# main class

class Perceiver(nn.Module):
def __init__(
self,
input_dim,
*,
num_fourier_features,
depth,
cross_attn_dim = 512,
num_latents = 6,
cross_attn_heads = 1,
cross_attn_dim_head = 64,
latent_attn_dim = 512,
latent_attn_heads = 8,
latent_attn_dim_head = 64,
cross_dim = 512,
latent_dim = 512,
cross_heads = 1,
latent_heads = 8,
cross_dim_head = 64,
latent_dim_head = 64,
num_classes = 1000,
attn_dropout = 0.,
ff_dropout = 0.
):
super().__init__()

def forward(self, x):
return x
self.num_fourier_features = num_fourier_features
input_dim = (num_fourier_features * 2) + 1
self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))

self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(latent_dim, Attention(latent_dim, input_dim, dropout = attn_dropout), context_dim = input_dim),
PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout)),
PreNorm(latent_dim, Attention(latent_dim, dropout = attn_dropout)),
PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout))
]))

self.to_logits = nn.Linear(latent_dim, num_classes)

def forward(self, data):
b = data.shape[0]
data = fourier_encode(data, self.num_fourier_features)

x = repeat(self.latents, 'n d -> b n d', b = b)

for cross_attn, cross_ff, latent_attn, latent_ff in self.layers:
x = cross_attn(x, context = data) + x
x = cross_ff(x) + x
x = latent_attn(x) + x
x = latent_ff(x) + x

x = x.mean(dim = -2)
return self.to_logits(x)

0 comments on commit 2e2a9cd

Please sign in to comment.