diff --git a/tests/test_main.py b/tests/test_main.py index 7f5dc4e..c426689 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -50,6 +50,7 @@ def test_longrope_model_initialization(): assert model.max_len == 65536 +# Testing the LongRoPEModel class with embedding def test_longrope_model_embedding(): model = LongRoPEModel( d_model=512, n_heads=8, num_layers=6, vocab_size=50257, max_len=65536 @@ -60,6 +61,7 @@ def test_longrope_model_embedding(): assert not torch.equal(input_ids, embeddings) +# Testing the LongRoPEModel class with transformers def test_longrope_model_transformers(): model = LongRoPEModel( d_model=512, n_heads=8, num_layers=6, vocab_size=50257, max_len=65536 @@ -71,6 +73,7 @@ def test_longrope_model_transformers(): assert embeddings.shape == (2, 1024, 512) +# Testing the LongRoPEModel class with forward def test_longrope_model_forward(): model = LongRoPEModel( d_model=512, n_heads=8, num_layers=6, vocab_size=50257, max_len=65536