diff --git a/src/main.py b/src/main.py index 743fa54..fb8e2d1 100644 --- a/src/main.py +++ b/src/main.py @@ -43,6 +43,9 @@ def non_uniform_interpolation(pos_embed, extension_ratio, lambda_factors, n_hat) if extension_ratio is None: raise ValueError("extension_ratio cannot be None") + if lambda_factors is None: + raise ValueError("lambda_factors cannot be None") + d_model = pos_embed.shape[-1] interpolated_pos = pos_embed.clone() @@ -53,9 +56,8 @@ def non_uniform_interpolation(pos_embed, extension_ratio, lambda_factors, n_hat) torch.ones_like(pos_embed[..., 0], device=pos_embed.device), 1 / (lambda_factors[i] * extension_ratio), ) - interpolated_pos[..., i * 2] *= scale - if i * 2 + 1 < d_model: # Check if the index is within bounds - interpolated_pos[..., i * 2 + 1] *= scale + interpolated_pos[..., 2 * i] *= scale + interpolated_pos[..., 2 * i + 1] *= scale return interpolated_pos