Skip to content

Commit

Permalink
update hf dataset test.
Browse files Browse the repository at this point in the history
  • Loading branch information
shibing624 committed Mar 1, 2022
1 parent 65f6c31 commit 0e19206
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 4 deletions.
6 changes: 2 additions & 4 deletions examples/training_sup_text_matching_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def main():
help='Transformers model model or path')
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
parser.add_argument("--do_predict", action="store_true", help="Whether to run predict.")
parser.add_argument('--output_dir', default='./outputs/STS-B-model-v1', type=str, help='Model output directory')
parser.add_argument('--output_dir', default='./outputs/STS-B-model', type=str, help='Model output directory')
parser.add_argument('--max_seq_length', default=64, type=int, help='Max sequence length')
parser.add_argument('--num_epochs', default=10, type=int, help='Number of training epochs')
parser.add_argument('--batch_size', default=64, type=int, help='Batch size')
Expand Down Expand Up @@ -88,9 +88,7 @@ def main():
else:
model = BertMatchModel(model_name_or_path=args.output_dir, encoder_type=args.encoder_type,
max_seq_length=args.max_seq_length)
test_dataset = HFTextMatchingTestDataset(model.tokenizer, args.task_name, max_len=args.max_seq_length,
split="test")

test_dataset = load_dataset("shibing624/nli_zh", args.task_name, split="test")
# Predict embeddings
srcs = []
trgs = []
Expand Down
35 changes: 35 additions & 0 deletions tests/test_hf_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com)
@description:
"""
import sys
import unittest

sys.path.append('..')

from datasets import load_dataset


class DatasetTestCase(unittest.TestCase):
    """Smoke test for loading the HuggingFace `shibing624/nli_zh` dataset."""

    def test_data_diff(self):
        """Load the STS-B test split and collect a bounded sample of sentence pairs.

        Verifies the split loads and that each record exposes the expected
        `sentence1` / `sentence2` / `label` fields.
        """
        # NOTE: downloads from the HuggingFace Hub on first run — requires network.
        test_dataset = load_dataset("shibing624/nli_zh", "STS-B", split="test")

        srcs = []
        trgs = []
        labels = []
        for terms in test_dataset:
            src, trg, label = terms['sentence1'], terms['sentence2'], terms['label']
            srcs.append(src)
            trgs.append(trg)
            labels.append(label)
            # Bug fix: the original checked `len(src) > 100` — the character
            # length of one sentence — instead of the number of collected
            # samples, so the cap rarely fired and the loop could walk the
            # entire split. Stop after ~100 sampled pairs.
            if len(srcs) > 100:
                break
        print(f'{test_dataset[0]}')
        print(f'{srcs[0]}')
print(f'{test_dataset[0]}')
print(f'{srcs[0]}')


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()

0 comments on commit 0e19206

Please sign in to comment.