From 2e2a9cd721202d300e046e179e7d9b20e4a321c5 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Thu, 4 Mar 2021 19:30:40 -0800
Subject: [PATCH] complete the main ideas of the Perceiver

---
 README.md                              | 31 +++++++++
 perceiver_pytorch/perceiver_pytorch.py | 90 +++++++++++++++++++++++---
 2 files changed, 111 insertions(+), 10 deletions(-)

diff --git a/README.md b/README.md
index 1091676..0c16bef 100644
--- a/README.md
+++ b/README.md
@@ -2,6 +2,37 @@
 
 Implementation of Perceiver, General Perception with Iterative Attention, in Pytorch
 
+## Install
+
+```bash
+$ pip install perceiver-pytorch
+```
+
+## Usage
+
+```python
+import torch
+from perceiver_pytorch.perceiver_pytorch import Perceiver
+
+model = Perceiver(
+    num_fourier_features = 6,  # number of fourier feature bands K. with the original value appended, each input scalar becomes (2 * K + 1) channels
+    depth = 48,                # depth of net. in the paper, they went deep, making up for the lack of full self attention on the input
+    num_latents = 6,           # number of latents, or induced set points, or centroids. different papers give it different names
+    cross_dim = 512,           # cross attention dimension
+    latent_dim = 512,          # latent dimension
+    cross_heads = 1,           # number of heads for cross attention. paper said 1
+    latent_heads = 8,          # number of heads for latent self attention. paper said 8
+    cross_dim_head = 64,       # dimension per cross attention head
+    latent_dim_head = 64,      # dimension per latent self attention head
+    num_classes = 1000,        # output number of classes
+    attn_dropout = 0.,
+    ff_dropout = 0.
+)
+
+img = torch.randn(1, 224 * 224) # 1 imagenet image, pixelized
+
+model(img) # (1, 1000)
+```
 ## Citations
 
 ```bibtex
diff --git a/perceiver_pytorch/perceiver_pytorch.py b/perceiver_pytorch/perceiver_pytorch.py
index 5f43fcb..741e1d8 100644
--- a/perceiver_pytorch/perceiver_pytorch.py
+++ b/perceiver_pytorch/perceiver_pytorch.py
@@ -9,6 +9,9 @@
 def exists(val):
     return val is not None
 
+def default(val, d):
+    return val if exists(val) else d
+
 def fourier_encode(x, num_encodings = 4):
     x = x.unsqueeze(-1)
     device, dtype, orig_x = x.device, x.dtype, x
@@ -21,13 +24,20 @@
 # helper classes
 
 class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
+    def __init__(self, dim, fn, context_dim = None):
         super().__init__()
         self.fn = fn
         self.norm = nn.LayerNorm(dim)
+        self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None
 
     def forward(self, x, **kwargs):
         x = self.norm(x)
+
+        if exists(self.norm_context):
+            context = kwargs['context']
+            normed_context = self.norm_context(context)
+            kwargs.update(context = normed_context)
+
         return self.fn(x, **kwargs)
 
 class FeedForward(nn.Module):
@@ -43,25 +53,85 @@ def __init__(self, dim, mult = 4, dropout = 0.):
     def forward(self, x):
         return self.net(x)
 
+class Attention(nn.Module):
+    def __init__(self, query_dim, context_dim = None, heads = 8, dim_head = 64, dropout = 0.):
+        super().__init__()
+        inner_dim = dim_head * heads
+        context_dim = default(context_dim, query_dim)
+
+        self.scale = dim_head ** -0.5
+        self.heads = heads
+
+        self.to_q = nn.Linear(query_dim, inner_dim, bias = False)
+        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
+
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, query_dim),
+            nn.Dropout(dropout)
+        )
+
+    def forward(self, x, context = None):
+        h = self.heads
+
+        q = self.to_q(x)
+        context = default(context, x)
+        k, v = self.to_kv(context).chunk(2, dim = -1)
+
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))
+
+        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+        attn = sim.softmax(dim = -1)
+        out = einsum('b i j, b j d -> b i d', attn, v)
+        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
+        return self.to_out(out)
+
 # main class
 
 class Perceiver(nn.Module):
     def __init__(
         self,
-        input_dim,
+        *,
+        num_fourier_features,
         depth,
-        cross_attn_dim = 512,
         num_latents = 6,
-        cross_attn_heads = 1,
-        cross_attn_dim_head = 64,
-        latent_attn_dim = 512,
-        latent_attn_heads = 8,
-        latent_attn_dim_head = 64,
+        cross_dim = 512,
+        latent_dim = 512,
+        cross_heads = 1,
+        latent_heads = 8,
+        cross_dim_head = 64,
+        latent_dim_head = 64,
         num_classes = 1000,
         attn_dropout = 0.,
         ff_dropout = 0.
     ):
         super().__init__()
-
-    def forward(self, x):
-        return x
+        self.num_fourier_features = num_fourier_features
+        input_dim = (num_fourier_features * 2) + 1
+
+        self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
+
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(nn.ModuleList([
+                PreNorm(latent_dim, Attention(latent_dim, input_dim, heads = cross_heads, dim_head = cross_dim_head, dropout = attn_dropout), context_dim = input_dim),
+                PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout)),
+                PreNorm(latent_dim, Attention(latent_dim, heads = latent_heads, dim_head = latent_dim_head, dropout = attn_dropout)),
+                PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout))
+            ]))
+
+        self.to_logits = nn.Linear(latent_dim, num_classes)
+
+    def forward(self, data):
+        b = data.shape[0]
+        data = fourier_encode(data, self.num_fourier_features)
+
+        x = repeat(self.latents, 'n d -> b n d', b = b)
+
+        for cross_attn, cross_ff, latent_attn, latent_ff in self.layers:
+            x = cross_attn(x, context = data) + x
+            x = cross_ff(x) + x
+            x = latent_attn(x) + x
+            x = latent_ff(x) + x
+
+        x = x.mean(dim = -2)
+        return self.to_logits(x)
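Note: the hunk at `@@ -9,6 +9,9 @@` only shows the first two lines of `fourier_encode`, but the new `input_dim = (num_fourier_features * 2) + 1` in `Perceiver.__init__` pins down its output width: K sine channels, K cosine channels, plus the original value. Below is a minimal sketch of the full helper consistent with that count; the geometric `scales` construction is an assumption on our part, not shown in the patch.

```python
import torch

def fourier_encode(x, num_encodings = 4):
    # (..., n) -> (..., n, 2 * num_encodings + 1)
    x = x.unsqueeze(-1)
    device, dtype, orig_x = x.device, x.dtype, x

    # assumed: geometric frequency scales 2^0 .. 2^(num_encodings - 1)
    scales = 2 ** torch.arange(num_encodings, device = device, dtype = dtype)
    x = x / scales

    # num_encodings sines + num_encodings cosines + the raw value
    x = torch.cat([x.sin(), x.cos()], dim = -1)
    x = torch.cat((x, orig_x), dim = -1)
    return x
```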
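The new `Attention` module serves both roles in the patch: with a `context` argument it cross-attends from the latent array to the byte array, and without one it falls back to self attention (via `default(context, x)`). A quick shape check of the cross-attention path, assuming the patch is applied; the dimensions are illustrative, with 13 being what `input_dim` works out to for `num_fourier_features = 6`:

```python
import torch
from perceiver_pytorch.perceiver_pytorch import Attention

attn = Attention(query_dim = 512, context_dim = 13, heads = 1, dim_head = 64)

latents = torch.randn(1, 6, 512)     # queries: the latent array
pixels  = torch.randn(1, 50176, 13)  # keys/values: fourier-encoded 224x224 image

out = attn(latents, context = pixels)
assert out.shape == (1, 6, 512)      # output stays in latent space, whatever the input length
```

This asymmetry is the core trick: attention cost scales with `num_latents * n` rather than `n * n`, which is what lets the model consume a raw pixel array directly.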
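End to end, the new `forward` mean-pools the latents and projects to class logits, so the output shape depends only on the batch size and `num_classes`. A smoke-test sketch under the same assumption that the patch is applied; `depth = 2` here is only to keep the test fast, not a recommended setting:

```python
import torch
from perceiver_pytorch.perceiver_pytorch import Perceiver

model = Perceiver(
    num_fourier_features = 6,  # input_dim becomes (2 * 6) + 1 = 13
    depth = 2,                 # shallow, for test speed only
    num_latents = 6,
    latent_dim = 512,
    num_classes = 1000
)

img = torch.randn(1, 224 * 224)   # one flattened image
logits = model(img)
assert logits.shape == (1, 1000)  # (batch, num_classes)
```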