diff --git a/perceiver_pytorch/perceiver_pytorch.py b/perceiver_pytorch/perceiver_pytorch.py index deeec5c..0a3fe68 100644 --- a/perceiver_pytorch/perceiver_pytorch.py +++ b/perceiver_pytorch/perceiver_pytorch.py @@ -221,13 +221,13 @@ def forward( mask = None, return_embeddings = False ): - b, *axis, _, device = *data.shape, data.device + b, *axis, _, device, dtype = *data.shape, data.device, data.dtype assert len(axis) == self.input_axis, 'input data must have the right number of axis' if self.fourier_encode_data: # calculate fourier encoded positions in the range of [-1, 1], for all axis - axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps = size, device = device), axis)) + axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps=size, device=device, dtype=dtype), axis)) pos = torch.stack(torch.meshgrid(*axis_pos, indexing = 'ij'), dim = -1) enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands) enc_pos = rearrange(enc_pos, '... n d -> ... (n d)')