Skip to content

Commit

Permalink
add dtype in axis_pos creation to allow precision=16 training. Used s…
Browse files Browse the repository at this point in the history
…ame pattern as in fourier_encode.
  • Loading branch information
lsisoft committed Jan 24, 2022
1 parent e5a81bd commit cfe01c7
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions perceiver_pytorch/perceiver_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,13 @@ def forward(
mask = None,
return_embeddings = False
):
b, *axis, _, device = *data.shape, data.device
b, *axis, _, device, dtype = *data.shape, data.device, data.dtype
assert len(axis) == self.input_axis, 'input data must have the right number of axis'

if self.fourier_encode_data:
# calculate fourier encoded positions in the range of [-1, 1], for all axis

axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps = size, device = device), axis))
axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps=size, device=device, dtype=dtype), axis))
pos = torch.stack(torch.meshgrid(*axis_pos, indexing = 'ij'), dim = -1)
enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands)
enc_pos = rearrange(enc_pos, '... n d -> ... (n d)')
Expand Down

0 comments on commit cfe01c7

Please sign in to comment.