Skip to content

Commit

Permalink
remove freq base from all perceiver variants
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jan 3, 2023
1 parent c8c5f57 commit d6e3cda
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 10 deletions.
4 changes: 1 addition & 3 deletions perceiver_pytorch/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def __init__(
num_freq_bands,
depth,
max_freq,
freq_base = 2,
input_channels = 3,
input_axis = 2,
num_latents = 512,
Expand All @@ -71,7 +70,6 @@ def __init__(
self.input_axis = input_axis
self.max_freq = max_freq
self.num_freq_bands = num_freq_bands
self.freq_base = freq_base

input_dim = input_axis * ((num_freq_bands * 2) + 1) + input_channels

Expand Down Expand Up @@ -119,7 +117,7 @@ def forward(self, data, mask = None):

axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps = size, device = device), axis))
pos = torch.stack(torch.meshgrid(*axis_pos, indexing = 'ij'), dim = -1)
enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands, base = self.freq_base)
enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands)
enc_pos = rearrange(enc_pos, '... n d -> ... (n d)')
enc_pos = repeat(enc_pos, '... -> b ...', b = b)

Expand Down
4 changes: 1 addition & 3 deletions perceiver_pytorch/gated.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def __init__(
num_freq_bands,
depth,
max_freq,
freq_base = 2,
input_channels = 3,
input_axis = 2,
num_latents = 512,
Expand All @@ -62,7 +61,6 @@ def __init__(
self.input_axis = input_axis
self.max_freq = max_freq
self.num_freq_bands = num_freq_bands
self.freq_base = freq_base

input_dim = input_axis * ((num_freq_bands * 2) + 1) + input_channels

Expand Down Expand Up @@ -100,7 +98,7 @@ def forward(self, data, mask = None):

axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps = size, device = device), axis))
pos = torch.stack(torch.meshgrid(*axis_pos, indexing = 'ij'), dim = -1)
enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands, base = self.freq_base)
enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands)
enc_pos = rearrange(enc_pos, '... n d -> ... (n d)')
enc_pos = repeat(enc_pos, '... -> b ...', b = b)

Expand Down
4 changes: 1 addition & 3 deletions perceiver_pytorch/mixed_latents.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def __init__(
num_freq_bands,
depth,
max_freq,
freq_base = 2,
input_channels = 3,
input_axis = 2,
num_latents = 512,
Expand All @@ -44,7 +43,6 @@ def __init__(
self.input_axis = input_axis
self.max_freq = max_freq
self.num_freq_bands = num_freq_bands
self.freq_base = freq_base

input_dim = input_axis * ((num_freq_bands * 2) + 1) + input_channels

Expand Down Expand Up @@ -82,7 +80,7 @@ def forward(self, data, mask = None):

axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps = size, device = device), axis))
pos = torch.stack(torch.meshgrid(*axis_pos, indexing = 'ij'), dim = -1)
enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands, base = self.freq_base)
enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands)
enc_pos = rearrange(enc_pos, '... n d -> ... (n d)')
enc_pos = repeat(enc_pos, '... -> b ...', b = b)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'perceiver-pytorch',
packages = find_packages(),
version = '0.8.6',
version = '0.8.7',
license='MIT',
description = 'Perceiver - Pytorch',
long_description_content_type = 'text/markdown',
Expand Down

0 comments on commit d6e3cda

Please sign in to comment.