diff --git a/tests/test_main.py b/tests/test_main.py
index 827b9f7..7f5dc4e 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -10,7 +10,7 @@
     short_context_recovery,
 )
 
-
+# Testing the load_data function
 def test_load_data():
     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
     data = load_data("data/raw/enwik8.gz", tokenizer, 65536)
@@ -19,6 +19,7 @@ def test_load_data():
     assert data[0].shape[0] <= 65536
 
 
+# Testing the non_uniform_interpolation function
 def test_non_uniform_interpolation():
     pos_embed = torch.randn(1, 100, 512)
     lambda_factors = torch.ones(256)
@@ -28,6 +29,7 @@
     assert not torch.equal(pos_embed, interpolated)
 
 
+# Testing the RoPEPositionalEncoding class
 def test_rope_positional_encoding():
     rope = RoPEPositionalEncoding(d_model=512, max_len=100)
     positions = torch.arange(100).unsqueeze(0)
@@ -36,6 +38,7 @@
     assert not torch.equal(positions, pos_embeddings)
 
 
+# Testing the LongRoPEModel class
 def test_longrope_model_initialization():
     model = LongRoPEModel(
         d_model=512, n_heads=8, num_layers=6, vocab_size=50257, max_len=65536