Skip to content

Commit

Permalink
fix dropout in original perceiver
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Dec 7, 2021
1 parent 2a1b039 commit 2d59df4
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 11 deletions.
2 changes: 1 addition & 1 deletion perceiver_pytorch/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def forward(self, data, mask = None):
# calculate fourier encoded positions in the range of [-1, 1], for all axis

axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps = size, device = device), axis))
pos = torch.stack(torch.meshgrid(*axis_pos), dim = -1)
pos = torch.stack(torch.meshgrid(*axis_pos, indexing = 'ij'), dim = -1)
enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands, base = self.freq_base)
enc_pos = rearrange(enc_pos, '... n d -> ... (n d)')
enc_pos = repeat(enc_pos, '... -> b ...', b = b)
Expand Down
2 changes: 1 addition & 1 deletion perceiver_pytorch/gated.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def forward(self, data, mask = None):
# calculate fourier encoded positions in the range of [-1, 1], for all axis

axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps = size, device = device), axis))
pos = torch.stack(torch.meshgrid(*axis_pos), dim = -1)
pos = torch.stack(torch.meshgrid(*axis_pos, indexing = 'ij'), dim = -1)
enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands, base = self.freq_base)
enc_pos = rearrange(enc_pos, '... n d -> ... (n d)')
enc_pos = repeat(enc_pos, '... -> b ...', b = b)
Expand Down
2 changes: 1 addition & 1 deletion perceiver_pytorch/mixed_latents.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def forward(self, data, mask = None):
# calculate fourier encoded positions in the range of [-1, 1], for all axis

axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps = size, device = device), axis))
pos = torch.stack(torch.meshgrid(*axis_pos), dim = -1)
pos = torch.stack(torch.meshgrid(*axis_pos, indexing = 'ij'), dim = -1)
enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands, base = self.freq_base)
enc_pos = rearrange(enc_pos, '... n d -> ... (n d)')
enc_pos = repeat(enc_pos, '... -> b ...', b = b)
Expand Down
13 changes: 6 additions & 7 deletions perceiver_pytorch/perceiver_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ def __init__(self, dim, mult = 4, dropout = 0.):
self.net = nn.Sequential(
nn.Linear(dim, dim * mult * 2),
GEGLU(),
nn.Dropout(dropout),
nn.Linear(dim * mult, dim)
nn.Linear(dim * mult, dim),
nn.Dropout(dropout)
)

def forward(self, x):
Expand All @@ -90,10 +90,8 @@ def __init__(self, query_dim, context_dim = None, heads = 8, dim_head = 64, drop
self.to_q = nn.Linear(query_dim, inner_dim, bias = False)
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)

self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout)
)
self.dropout = nn.Dropout(dropout)
self.to_out = nn.Linear(inner_dim, query_dim)

def forward(self, x, context = None, mask = None):
h = self.heads
Expand All @@ -114,6 +112,7 @@ def forward(self, x, context = None, mask = None):

# attention, what we cannot get enough of
attn = sim.softmax(dim = -1)
attn = self.dropout(attn)

out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
Expand Down Expand Up @@ -228,7 +227,7 @@ def forward(
# calculate fourier encoded positions in the range of [-1, 1], for all axis

axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps = size, device = device), axis))
pos = torch.stack(torch.meshgrid(*axis_pos), dim = -1)
pos = torch.stack(torch.meshgrid(*axis_pos, indexing = 'ij'), dim = -1)
enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands)
enc_pos = rearrange(enc_pos, '... n d -> ... (n d)')
enc_pos = repeat(enc_pos, '... -> b ...', b = b)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'perceiver-pytorch',
packages = find_packages(),
version = '0.7.5',
version = '0.8.0',
license='MIT',
description = 'Perceiver - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 2d59df4

Please sign in to comment.