diff --git a/src/main.py b/src/main.py index b1d079e..7749860 100644 --- a/src/main.py +++ b/src/main.py @@ -43,9 +43,11 @@ def non_uniform_interpolation(pos_embed, extension_ratio, lambda_factors, n_hat) interpolated_pos = pos_embed.clone() for i in range(d_model // 2): - mask = torch.arange(pos_embed.shape[-2]) < n_hat + mask = torch.arange(pos_embed.shape[-2], device=pos_embed.device) < n_hat scale = torch.where( - mask, torch.ones_like(pos_embed[..., 0]), 1 / lambda_factors[i] + mask, + torch.ones_like(pos_embed[..., 0], device=pos_embed.device), + 1 / (lambda_factors[i] * extension_ratio), ) interpolated_pos[..., i * 2] *= scale if i * 2 + 1 < d_model: # Check if the index is within bounds