From e7dd8eb22bd4bff87632f9ca14f2f712ccc5d6c8 Mon Sep 17 00:00:00 2001 From: Joshua David Date: Sun, 9 Jun 2024 21:39:20 -0700 Subject: [PATCH] update main test file --- tests/test_main.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test_main.py b/tests/test_main.py index eebdc00..624ae11 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -58,3 +58,16 @@ def test_extend_context(): max_iterations=10, ) assert extended_model is not None + + +def test_recover_short_context(): + model = LongRoPEModel( + d_model=512, n_heads=8, num_layers=6, vocab_size=50257, max_len=65536 + ) + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + recovered_model = model.recover_short_context( + data_path="data/raw/enwik8.gz", + max_sequence_length=65536, + tokenizer=tokenizer, + ) + assert recovered_model is not None