Skip to content

Commit

Permalink
🧪 Improve flexibility of tests in test_datasets.py
Browse files Browse the repository at this point in the history
  • Loading branch information
arxyzan committed Feb 5, 2024
1 parent c95c4f5 commit 19574b6
Showing 1 changed file with 33 additions and 20 deletions.
53 changes: 33 additions & 20 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,32 @@
from hezar.data import Dataset


TASK_TO_HUB_MAPPING = {
"text-classification": "hezarai/sentiment-dksf",
"sequence-labeling": "hezarai/lscp-pos-500k",
"ocr": "hezarai/persian-license-plate-v1",
"image-captioning": "hezarai/flickr30k-fa",
"text-summarization": "hezarai/xlsum-fa",
"speech-recognition": "hezarai/common-voice-13-fa"
}
TASK_TO_TOKENIZER_MAPPING = {
"text-classification": "hezarai/bert-base-fa",
"sequence-labeling": "hezarai/bert-base-fa",
"ocr": "hezarai/crnn-fa-printed-96-long",
"image-captioning": "hezarai/roberta-base-fa",
"text-summarization": "hezarai/t5-base-fa",
"speech-recognition": "hezarai/whisper-small-fa"
# Per-task dataset specifications used by the loading tests. Every value is
# the keyword-argument set handed to `Dataset.load`: the Hub repo `path`
# plus the preprocessor repo ids the task needs (a tokenizer for all tasks,
# and additionally a feature extractor for speech recognition).
# NOTE(review): `test_load_dataset` pops "path" out of these inner dicts,
# mutating this shared constant in place — copy the spec before popping if
# other tests come to depend on it.
DATASETS_MAPPING = {
    "text-classification": dict(path="hezarai/sentiment-dksf", tokenizer_path="hezarai/bert-base-fa"),
    "sequence-labeling": dict(path="hezarai/lscp-pos-500k", tokenizer_path="hezarai/bert-base-fa"),
    "ocr": dict(path="hezarai/persian-license-plate-v1", tokenizer_path="hezarai/crnn-fa-printed-96-long"),
    "image-captioning": dict(path="hezarai/flickr30k-fa", tokenizer_path="hezarai/roberta-base-fa"),
    "text-summarization": dict(path="hezarai/xlsum-fa", tokenizer_path="hezarai/t5-base-fa"),
    "speech-recognition": dict(
        path="hezarai/common-voice-13-fa",
        tokenizer_path="hezarai/whisper-small-fa",
        feature_extractor_path="hezarai/whisper-small-fa",
    ),
}

TASK_TO_REQUIRED_FIELDS = {
Expand All @@ -42,14 +53,16 @@ def create_dataloader(dataset, batch_size, shuffle, collate_fn):
return dataloader


@pytest.mark.parametrize("task", TASK_TO_HUB_MAPPING.keys())
@pytest.mark.parametrize("task", DATASETS_MAPPING.keys())
def test_load_dataset(task):
required_fields = TASK_TO_REQUIRED_FIELDS[task]

path = DATASETS_MAPPING[task].pop("path")

train_dataset = Dataset.load(
TASK_TO_HUB_MAPPING[task],
path,
split="train",
tokenizer_path=TASK_TO_TOKENIZER_MAPPING[task]
**DATASETS_MAPPING[task]
)
assert isinstance(train_dataset, Dataset), INVALID_DATASET_TYPE.format(type(train_dataset))
sample = train_dataset[0]
Expand All @@ -58,9 +71,9 @@ def test_load_dataset(task):
assert field in sample, INVALID_DATASET_FIELDS.format(field)

test_dataset = Dataset.load(
TASK_TO_HUB_MAPPING[task],
path,
split="test",
tokenizer_path=TASK_TO_TOKENIZER_MAPPING[task]
**DATASETS_MAPPING[task]
)
assert isinstance(test_dataset, Dataset), INVALID_DATASET_TYPE.format(type(test_dataset))
sample = test_dataset[0]
Expand Down

0 comments on commit 19574b6

Please sign in to comment.