From d6e3cda8abfbadfc24c3092bb9babfaa97dca8cd Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 3 Jan 2023 14:36:25 -0800 Subject: [PATCH] remove freq base from all perceiver variants --- perceiver_pytorch/experimental.py | 4 +--- perceiver_pytorch/gated.py | 4 +--- perceiver_pytorch/mixed_latents.py | 4 +--- setup.py | 2 +- 4 files changed, 4 insertions(+), 10 deletions(-) diff --git a/perceiver_pytorch/experimental.py b/perceiver_pytorch/experimental.py index 31b8bcc..55eac78 100644 --- a/perceiver_pytorch/experimental.py +++ b/perceiver_pytorch/experimental.py @@ -53,7 +53,6 @@ def __init__( num_freq_bands, depth, max_freq, - freq_base = 2, input_channels = 3, input_axis = 2, num_latents = 512, @@ -71,7 +70,6 @@ def __init__( self.input_axis = input_axis self.max_freq = max_freq self.num_freq_bands = num_freq_bands - self.freq_base = freq_base input_dim = input_axis * ((num_freq_bands * 2) + 1) + input_channels @@ -119,7 +117,7 @@ def forward(self, data, mask = None): axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps = size, device = device), axis)) pos = torch.stack(torch.meshgrid(*axis_pos, indexing = 'ij'), dim = -1) - enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands, base = self.freq_base) + enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands) enc_pos = rearrange(enc_pos, '... n d -> ... (n d)') enc_pos = repeat(enc_pos, '... -> b ...', b = b) diff --git a/perceiver_pytorch/gated.py b/perceiver_pytorch/gated.py index 5c5f1ea..dc7d8da 100644 --- a/perceiver_pytorch/gated.py +++ b/perceiver_pytorch/gated.py @@ -44,7 +44,6 @@ def __init__( num_freq_bands, depth, max_freq, - freq_base = 2, input_channels = 3, input_axis = 2, num_latents = 512, @@ -62,7 +61,6 @@ def __init__( self.input_axis = input_axis self.max_freq = max_freq self.num_freq_bands = num_freq_bands - self.freq_base = freq_base input_dim = input_axis * ((num_freq_bands * 2) + 1) + input_channels @@ -100,7 +98,7 @@ def forward(self, data, mask = None): axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps = size, device = device), axis)) pos = torch.stack(torch.meshgrid(*axis_pos, indexing = 'ij'), dim = -1) - enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands, base = self.freq_base) + enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands) enc_pos = rearrange(enc_pos, '... n d -> ... (n d)') enc_pos = repeat(enc_pos, '... -> b ...', b = b) diff --git a/perceiver_pytorch/mixed_latents.py b/perceiver_pytorch/mixed_latents.py index 6b35688..c80362a 100644 --- a/perceiver_pytorch/mixed_latents.py +++ b/perceiver_pytorch/mixed_latents.py @@ -25,7 +25,6 @@ def __init__( num_freq_bands, depth, max_freq, - freq_base = 2, input_channels = 3, input_axis = 2, num_latents = 512, @@ -44,7 +43,6 @@ def __init__( self.input_axis = input_axis self.max_freq = max_freq self.num_freq_bands = num_freq_bands - self.freq_base = freq_base input_dim = input_axis * ((num_freq_bands * 2) + 1) + input_channels @@ -82,7 +80,7 @@ def forward(self, data, mask = None): axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps = size, device = device), axis)) pos = torch.stack(torch.meshgrid(*axis_pos, indexing = 'ij'), dim = -1) - enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands, base = self.freq_base) + enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands) enc_pos = rearrange(enc_pos, '... n d -> ... (n d)') enc_pos = repeat(enc_pos, '... -> b ...', b = b) diff --git a/setup.py b/setup.py index 6b372af..8e4434a 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'perceiver-pytorch', packages = find_packages(), - version = '0.8.6', + version = '0.8.7', license='MIT', description = 'Perceiver - Pytorch', long_description_content_type = 'text/markdown',