diff --git a/tests/test_main.py b/tests/test_main.py
index ea4ecd7..6d07b92 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -16,6 +16,7 @@ def test_load_data():
     data = load_data("data/raw/enwik8.gz", tokenizer, 65536)
     assert len(data) > 0
     assert isinstance(data[0], torch.Tensor)
+    assert data[0].shape[0] <= 65536
 
 
 def test_non_uniform_interpolation():
@@ -24,6 +25,7 @@ def test_non_uniform_interpolation():
     n_hat = 50
     interpolated = non_uniform_interpolation(pos_embed, 2.0, lambda_factors, n_hat)
     assert interpolated.shape == pos_embed.shape
+    assert not torch.equal(pos_embed, interpolated)
 
 
 def test_rope_positional_encoding():
@@ -31,6 +33,7 @@ def test_rope_positional_encoding():
     positions = torch.arange(100).unsqueeze(0)
     pos_embeddings = rope(positions)
     assert pos_embeddings.shape == (1, 100, 512)
+    assert not torch.equal(positions, pos_embeddings)
 
 
 def test_longrope_model_forward():
@@ -40,6 +43,7 @@ def test_longrope_model_forward():
     input_ids = torch.randint(0, 50257, (2, 1024))
     output = model(input_ids)
     assert output.shape == (2, 1024, 512)
+    assert not torch.equal(input_ids, output)
 
 
 def test_extend_context():
@@ -58,6 +62,7 @@ def test_extend_context():
         max_iterations=10,
     )
     assert extended_model is not None
+    assert extended_model.max_len == 2048000
 
 
 def test_recover_short_context():
@@ -71,6 +76,7 @@ def test_recover_short_context():
         tokenizer=tokenizer,
    )
     assert recovered_model is not None
+    assert recovered_model.max_len == 65536
 
 
 def test_progressive_extension():
@@ -99,6 +105,7 @@ def test_progressive_extension():
     assert n_hat is not None
     assert lambda_factors_base is not None
     assert n_hat_base is not None
+    assert extended_model.max_len == 2048000
 
 
 def test_short_context_recovery():
@@ -107,3 +114,34 @@ def test_short_context_recovery():
     )
     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
     short_context_recovery(model, tokenizer)
+
+
+def test_longrope_model_initialization():
+    model = LongRoPEModel(
+        d_model=512, n_heads=8, num_layers=6, vocab_size=50257, max_len=65536
+    )
+    assert model.d_model == 512
+    assert model.n_heads == 8
+    assert model.num_layers == 6
+    assert model.vocab_size == 50257
+    assert model.max_len == 65536
+
+
+def test_longrope_model_embedding():
+    model = LongRoPEModel(
+        d_model=512, n_heads=8, num_layers=6, vocab_size=50257, max_len=65536
+    )
+    input_ids = torch.randint(0, 50257, (2, 1024))
+    embeddings = model.embedding(input_ids)
+    assert embeddings.shape == (2, 1024, 512)
+
+
+def test_longrope_model_transformers():
+    model = LongRoPEModel(
+        d_model=512, n_heads=8, num_layers=6, vocab_size=50257, max_len=65536
+    )
+    input_ids = torch.randint(0, 50257, (2, 1024))
+    embeddings = model.embedding(input_ids)
+    for transformer in model.transformers:
+        embeddings = transformer(embeddings)
+    assert embeddings.shape == (2, 1024, 512)
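
Review note on the new "output differs from input" assertions: torch.equal returns False whenever the two tensors differ in shape or dtype, so checks like `assert not torch.equal(input_ids, output)` (int64 (2, 1024) vs. float (2, 1024, 512)) and `assert not torch.equal(positions, pos_embeddings)` pass vacuously for any output. Below is a minimal sketch of a stricter variant. The LongRoPEModel constructor arguments are taken from the diff; the import path, the test name, and the finiteness/non-zero checks are assumptions for illustration, not the repo's API.

    # Sketch only: a stricter version of test_longrope_model_forward's
    # "the model did something" check. Adjust the import to wherever
    # LongRoPEModel actually lives in this repo (the path below is
    # hypothetical).
    import torch

    from src.main import LongRoPEModel  # hypothetical import path


    def test_longrope_model_forward_is_nontrivial():
        model = LongRoPEModel(
            d_model=512, n_heads=8, num_layers=6, vocab_size=50257, max_len=65536
        )
        input_ids = torch.randint(0, 50257, (2, 1024))
        output = model(input_ids)
        # input_ids and output already differ in shape and dtype, so compare
        # output against tensors of the *same* shape to get a meaningful check.
        assert output.shape == (2, 1024, 512)
        assert torch.isfinite(output).all()
        assert not torch.equal(output, torch.zeros_like(output))

Assuming the suite is pytest-based (the tests/test_main.py naming suggests it), `pytest tests/test_main.py -q` should confirm the new tests still pass.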