Commit 7173fa1: update main test file
jshuadvd committed Jun 11, 2024
1 parent b8609b3 commit 7173fa1
Showing 1 changed file with 38 additions and 0 deletions.
tests/test_main.py (38 additions, 0 deletions)
@@ -16,6 +16,7 @@ def test_load_data():
data = load_data("data/raw/enwik8.gz", tokenizer, 65536)
assert len(data) > 0
assert isinstance(data[0], torch.Tensor)
assert data[0].shape[0] <= 65536


def test_non_uniform_interpolation():
@@ -24,13 +25,15 @@ def test_non_uniform_interpolation():
n_hat = 50
interpolated = non_uniform_interpolation(pos_embed, 2.0, lambda_factors, n_hat)
assert interpolated.shape == pos_embed.shape
assert not torch.equal(pos_embed, interpolated)


def test_rope_positional_encoding():
rope = RoPEPositionalEncoding(d_model=512, max_len=100)
positions = torch.arange(100).unsqueeze(0)
pos_embeddings = rope(positions)
assert pos_embeddings.shape == (1, 100, 512)
assert not torch.equal(positions, pos_embeddings)


def test_longrope_model_forward():
@@ -40,6 +43,7 @@ def test_longrope_model_forward():
input_ids = torch.randint(0, 50257, (2, 1024))
output = model(input_ids)
assert output.shape == (2, 1024, 512)
assert not torch.equal(input_ids, output)


def test_extend_context():
@@ -58,6 +62,7 @@ def test_extend_context():
max_iterations=10,
)
assert extended_model is not None
assert extended_model.max_len == 2048000


def test_recover_short_context():
@@ -71,6 +76,7 @@ def test_recover_short_context():
tokenizer=tokenizer,
)
assert recovered_model is not None
assert recovered_model.max_len == 65536


def test_progressive_extension():
@@ -99,6 +105,7 @@ def test_progressive_extension():
assert n_hat is not None
assert lambda_factors_base is not None
assert n_hat_base is not None
assert extended_model.max_len == 2048000


def test_short_context_recovery():
@@ -107,3 +114,34 @@ def test_short_context_recovery():
)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
short_context_recovery(model, tokenizer)


def test_longrope_model_initialization():
model = LongRoPEModel(
d_model=512, n_heads=8, num_layers=6, vocab_size=50257, max_len=65536
)
assert model.d_model == 512
assert model.n_heads == 8
assert model.num_layers == 6
assert model.vocab_size == 50257
assert model.max_len == 65536


def test_longrope_model_embedding():
model = LongRoPEModel(
d_model=512, n_heads=8, num_layers=6, vocab_size=50257, max_len=65536
)
input_ids = torch.randint(0, 50257, (2, 1024))
embeddings = model.embedding(input_ids)
assert embeddings.shape == (2, 1024, 512)


def test_longrope_model_transformers():
model = LongRoPEModel(
d_model=512, n_heads=8, num_layers=6, vocab_size=50257, max_len=65536
)
input_ids = torch.randint(0, 50257, (2, 1024))
embeddings = model.embedding(input_ids)
for transformer in model.transformers:
embeddings = transformer(embeddings)
assert embeddings.shape == (2, 1024, 512)
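
For readers skimming the diff: below is a minimal sketch of the kind of non_uniform_interpolation these tests exercise. It is an illustration only, not this repository's implementation; the tensor shapes (pos_embed as (seq_len, d_model), lambda_factors as (d_model,)) and the rescaling rule are assumptions based on LongRoPE's idea of per-dimension interpolation factors that leave the first n_hat positions unscaled.

import torch

def non_uniform_interpolation(pos_embed, extension_ratio, lambda_factors, n_hat):
    # Hypothetical sketch, not the repository's code.
    # Assumes pos_embed: (seq_len, d_model), lambda_factors: (d_model,).
    seq_len = pos_embed.shape[0]
    scale = torch.ones_like(pos_embed)
    # Positions past n_hat are compressed by per-dimension factors, so the
    # interpolated embeddings differ from the originals, which is what the
    # new `assert not torch.equal(...)` check verifies.
    mask = torch.arange(seq_len) >= n_hat
    scale[mask] = 1.0 / (lambda_factors * extension_ratio)
    return pos_embed * scale

# Usage mirroring the test setup (shapes assumed):
pos_embed = torch.randn(100, 512)
lambda_factors = torch.ones(512) * 1.5
interpolated = non_uniform_interpolation(pos_embed, 2.0, lambda_factors, 50)
assert interpolated.shape == pos_embed.shape
assert not torch.equal(pos_embed, interpolated)

The new tests themselves run with a standard pytest invocation from the repository root, e.g. pytest tests/test_main.py.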
