From e9992b32550b6f915b06603bc20b57cd17b830cc Mon Sep 17 00:00:00 2001 From: Joshua David Date: Wed, 12 Jun 2024 21:42:34 -0700 Subject: [PATCH] correct non_uniform_interpolation --- src/main.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/main.py b/src/main.py index b1d079e..7749860 100644 --- a/src/main.py +++ b/src/main.py @@ -43,9 +43,11 @@ def non_uniform_interpolation(pos_embed, extension_ratio, lambda_factors, n_hat) interpolated_pos = pos_embed.clone() for i in range(d_model // 2): - mask = torch.arange(pos_embed.shape[-2]) < n_hat + mask = torch.arange(pos_embed.shape[-2], device=pos_embed.device) < n_hat scale = torch.where( - mask, torch.ones_like(pos_embed[..., 0]), 1 / lambda_factors[i] + mask, + torch.ones_like(pos_embed[..., 0], device=pos_embed.device), + 1 / (lambda_factors[i] * extension_ratio), ) interpolated_pos[..., i * 2] *= scale if i * 2 + 1 < d_model: # Check if the index is within bounds