Skip to content

Commit

Permalink
update main test file
Browse files Browse the repository at this point in the history
  • Loading branch information
jshuadvd committed Jun 10, 2024
1 parent 042d120 commit 84b81bc
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,21 @@ def test_longrope_model_forward():
input_ids = torch.randint(0, 50257, (2, 1024))
output = model(input_ids)
assert output.shape == (2, 1024, 512)


def test_extend_context():
model = LongRoPEModel(
d_model=512, n_heads=8, num_layers=6, vocab_size=50257, max_len=65536
)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
extended_model = model.extend_context(
data_path="data/raw/enwik8.gz",
target_length=2048000,
max_sequence_length=65536,
tokenizer=tokenizer,
population_size=64,
num_mutations=16,
num_crossovers=16,
max_iterations=10,
)
assert extended_model is not None

0 comments on commit 84b81bc

Please sign in to comment.