Skip to content

Commit

Permalink
correct non_uniform_interpolation
Browse files Browse the repository at this point in the history
  • Loading branch information
jshuadvd committed Jun 13, 2024
1 parent bc209b2 commit e9992b3
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,11 @@ def non_uniform_interpolation(pos_embed, extension_ratio, lambda_factors, n_hat)
interpolated_pos = pos_embed.clone()

for i in range(d_model // 2):
mask = torch.arange(pos_embed.shape[-2]) < n_hat
mask = torch.arange(pos_embed.shape[-2], device=pos_embed.device) < n_hat
scale = torch.where(
mask, torch.ones_like(pos_embed[..., 0]), 1 / lambda_factors[i]
mask,
torch.ones_like(pos_embed[..., 0], device=pos_embed.device),
1 / (lambda_factors[i] * extension_ratio),
)
interpolated_pos[..., i * 2] *= scale
if i * 2 + 1 < d_model: # Check if the index is within bounds
Expand Down

0 comments on commit e9992b3

Please sign in to comment.