From 03aca7a10fbe2b8ed23e54a24d9bc17ae31edc99 Mon Sep 17 00:00:00 2001 From: Joshua David Date: Sun, 9 Jun 2024 21:34:48 -0700 Subject: [PATCH] update main test file --- tests/test_main.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_main.py b/tests/test_main.py index e8afb83..4866642 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -16,3 +16,11 @@ def test_load_data(): data = load_data("data/raw/enwik8.gz", tokenizer, 65536) assert len(data) > 0 assert isinstance(data[0], torch.Tensor) + + +def test_non_uniform_interpolation(): + pos_embed = torch.randn(1, 100, 512) + lambda_factors = torch.ones(256) + n_hat = 50 + interpolated = non_uniform_interpolation(pos_embed, 2.0, lambda_factors, n_hat) + assert interpolated.shape == pos_embed.shape