From 84b81bc3633fc968cd79cfb38080280f2300f506 Mon Sep 17 00:00:00 2001 From: Joshua David Date: Sun, 9 Jun 2024 21:36:30 -0700 Subject: [PATCH] update main test file --- tests/test_main.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/test_main.py b/tests/test_main.py index 9ab9041..eebdc00 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -40,3 +40,21 @@ def test_longrope_model_forward(): input_ids = torch.randint(0, 50257, (2, 1024)) output = model(input_ids) assert output.shape == (2, 1024, 512) + + +def test_extend_context(): + model = LongRoPEModel( + d_model=512, n_heads=8, num_layers=6, vocab_size=50257, max_len=65536 + ) + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + extended_model = model.extend_context( + data_path="data/raw/enwik8.gz", + target_length=2048000, + max_sequence_length=65536, + tokenizer=tokenizer, + population_size=64, + num_mutations=16, + num_crossovers=16, + max_iterations=10, + ) + assert extended_model is not None