diff --git a/tests/test_main.py b/tests/test_main.py
index 9a89544..e8afb83 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -9,3 +9,20 @@
     progressive_extension,
     short_context_recovery,
 )
+
+
+def test_load_data():
+    """Smoke-test ``load_data`` against the enwik8 archive.
+
+    Passes when the loaded result is non-empty and its first element is a
+    ``torch.Tensor``.
+    """
+    # NOTE(review): presumably needs data/raw/enwik8.gz relative to the
+    # working directory and the pretrained "gpt2" tokenizer files; confirm.
+    gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+    # 65536 is forwarded as-is; its meaning is defined by load_data.
+    loaded = load_data("data/raw/enwik8.gz", gpt2_tokenizer, 65536)
+
+    assert len(loaded) > 0
+    first_item = loaded[0]
+    assert isinstance(first_item, torch.Tensor)