Skip to content

Commit

Permalink
added comments to main test file
Browse files Browse the repository at this point in the history
  • Loading branch information
jshuadvd committed Jun 12, 2024
1 parent 9974a1d commit 232fa37
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
short_context_recovery,
)


# Testing the load_data function
def test_load_data():
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
data = load_data("data/raw/enwik8.gz", tokenizer, 65536)
Expand All @@ -19,6 +19,7 @@ def test_load_data():
assert data[0].shape[0] <= 65536


# Testing the non_uniform_interpolation function
def test_non_uniform_interpolation():
pos_embed = torch.randn(1, 100, 512)
lambda_factors = torch.ones(256)
Expand All @@ -28,6 +29,7 @@ def test_non_uniform_interpolation():
assert not torch.equal(pos_embed, interpolated)


# Testing the RoPEPositionalEncoding class
def test_rope_positional_encoding():
rope = RoPEPositionalEncoding(d_model=512, max_len=100)
positions = torch.arange(100).unsqueeze(0)
Expand All @@ -36,6 +38,7 @@ def test_rope_positional_encoding():
assert not torch.equal(positions, pos_embeddings)


# Testing the LongRoPEModel class
def test_longrope_model_initialization():
model = LongRoPEModel(
d_model=512, n_heads=8, num_layers=6, vocab_size=50257, max_len=65536
Expand Down

0 comments on commit 232fa37

Please sign in to comment.