Skip to content

Commit

Permalink
Ensure lambda_factors and extension_ratio are not None in non_uniform…
Browse files Browse the repository at this point in the history
…_interpolation
  • Loading branch information
jshuadvd committed Jun 16, 2024
1 parent 01925bb commit 7275090
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ def non_uniform_interpolation(pos_embed, extension_ratio, lambda_factors, n_hat)
if extension_ratio is None:
raise ValueError("extension_ratio cannot be None")

if lambda_factors is None:
raise ValueError("lambda_factors cannot be None")

d_model = pos_embed.shape[-1]
interpolated_pos = pos_embed.clone()

Expand All @@ -53,9 +56,8 @@ def non_uniform_interpolation(pos_embed, extension_ratio, lambda_factors, n_hat)
torch.ones_like(pos_embed[..., 0], device=pos_embed.device),
1 / (lambda_factors[i] * extension_ratio),
)
interpolated_pos[..., i * 2] *= scale
if i * 2 + 1 < d_model: # Check if the index is within bounds
interpolated_pos[..., i * 2 + 1] *= scale
interpolated_pos[..., 2 * i] *= scale
interpolated_pos[..., 2 * i + 1] *= scale

return interpolated_pos

Expand Down

0 comments on commit 7275090

Please sign in to comment.