Skip to content

Commit

Permalink
Merge pull request #59 from lsisoft/fix_precision_16
Browse files Browse the repository at this point in the history
Add dtype in axis_pos creation to allow precision=16 training
  • Loading branch information
lucidrains authored Jan 25, 2022
2 parents e5a81bd + cfe01c7 commit ac0bd4a
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions perceiver_pytorch/perceiver_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,13 @@ def forward(
mask = None,
return_embeddings = False
):
b, *axis, _, device = *data.shape, data.device
b, *axis, _, device, dtype = *data.shape, data.device, data.dtype
assert len(axis) == self.input_axis, 'input data must have the right number of axis'

if self.fourier_encode_data:
# calculate fourier encoded positions in the range of [-1, 1], for all axis

axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps = size, device = device), axis))
axis_pos = list(map(lambda size: torch.linspace(-1., 1., steps=size, device=device, dtype=dtype), axis))
pos = torch.stack(torch.meshgrid(*axis_pos, indexing = 'ij'), dim = -1)
enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands)
enc_pos = rearrange(enc_pos, '... n d -> ... (n d)')
Expand Down

0 comments on commit ac0bd4a

Please sign in to comment.