From 58eceac9f1111d7aef47a4bfb19af02669a6e010 Mon Sep 17 00:00:00 2001 From: Joshua David Date: Mon, 10 Jun 2024 20:52:48 -0700 Subject: [PATCH] update main test file --- tests/test_main.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/test_main.py b/tests/test_main.py index 624ae11..15365a1 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -71,3 +71,31 @@ def test_recover_short_context(): tokenizer=tokenizer, ) assert recovered_model is not None + + +def test_progressive_extension(): + model = LongRoPEModel( + d_model=512, n_heads=8, num_layers=6, vocab_size=50257, max_len=65536 + ) + data = [torch.randint(0, 50257, (65536,)) for _ in range(10)] + ( + extended_model, + lambda_factors, + n_hat, + lambda_factors_base, + n_hat_base, + ) = progressive_extension( + model, + data, + base_length=65536, + target_length=2048000, + population_size=64, + num_mutations=16, + num_crossovers=16, + max_iterations=10, + ) + assert extended_model is not None + assert lambda_factors is not None + assert n_hat is not None + assert lambda_factors_base is not None + assert n_hat_base is not None