Skip to content

Commit

Permalink
Fix TypeError in non_uniform_interpolation
Browse files Browse the repository at this point in the history
  • Loading branch information
jshuadvd committed Jun 14, 2024
1 parent def4bf1 commit 01925bb
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ def non_uniform_interpolation(pos_embed, extension_ratio, lambda_factors, n_hat)
Returns:
torch.Tensor: Interpolated position embeddings.
"""

if extension_ratio is None:
raise ValueError("extension_ratio cannot be None")

d_model = pos_embed.shape[-1]
interpolated_pos = pos_embed.clone()

Expand Down

0 comments on commit 01925bb

Please sign in to comment.