Skip to content

Commit

Permalink
added comments
Browse files Browse the repository at this point in the history
  • Loading branch information
jshuadvd committed Jun 12, 2024
1 parent a1bf3e3 commit f2ab4d6
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from train import CustomDataset, collate_fn, preprocess_data, validate_targets, train
from src.main import LongRoPEModel


# Testing the CustomDataset class
def test_custom_dataset():
sequences = [[1, 2, 3], [4, 5, 6]]
targets = [[2, 3, 4], [5, 6, 7]]
Expand All @@ -14,13 +14,15 @@ def test_custom_dataset():
assert dataset[0] == (sequences[0], targets[0])


# Testing the CustomDataset class with empty sequences and targets
def test_custom_dataset_empty():
    """A CustomDataset built from empty sequence and target lists has length 0."""
    empty_dataset = CustomDataset([], [])
    assert len(empty_dataset) == 0


# Testing the collate_fn function
def test_collate_fn():
batch = [([1, 2, 3], [2, 3, 4]), ([4, 5], [5, 6])]
inputs, targets = collate_fn(batch)
Expand All @@ -30,13 +32,15 @@ def test_collate_fn():
assert torch.equal(targets[0], torch.tensor([2, 3, 4]))


# Testing the collate_fn function with an empty batch
def test_collate_fn_empty():
    """collate_fn on an empty batch yields inputs and targets of shape (0,)."""
    expected_shape = (0,)
    empty_inputs, empty_targets = collate_fn([])
    assert empty_inputs.shape == expected_shape
    assert empty_targets.shape == expected_shape


# Testing the preprocess_data function
def test_preprocess_data():
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
data = "This is a test."
Expand All @@ -45,19 +49,22 @@ def test_preprocess_data():
assert all(len(seq) <= 10 for seq in sequences)


# Testing the preprocess_data function with an empty string
def test_preprocess_data_empty():
    """preprocess_data on an empty string produces no sequences."""
    gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    produced = preprocess_data("", gpt2_tokenizer, max_length=10, overlap=5)
    assert len(produced) == 0


# Testing the validate_targets function
def test_validate_targets():
    """validate_targets reports success when every target id is below vocab_size."""
    targets = [[1, 2, 3], [4, 5, 6]]
    vocab_size = 10
    # Assert the result's truthiness directly; comparing with `== True` is
    # discouraged by PEP 8 (pycodestyle E712) and adds nothing to the check.
    assert validate_targets(targets, vocab_size)


# Testing the validate_targets function with invalid targets
def test_validate_targets_invalid():
targets = [[1, 2, 3], [4, 5, 10]]
vocab_size = 10
Expand Down

0 comments on commit f2ab4d6

Please sign in to comment.