diff --git a/examples/training_sup_text_matching_model.py b/examples/training_sup_text_matching_model.py
index 5b5eb4f..788589f 100644
--- a/examples/training_sup_text_matching_model.py
+++ b/examples/training_sup_text_matching_model.py
@@ -48,7 +48,7 @@ def main():
                         help='Transformers model model or path')
     parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
     parser.add_argument("--do_predict", action="store_true", help="Whether to run predict.")
-    parser.add_argument('--output_dir', default='./outputs/STS-B-model-v1', type=str, help='Model output directory')
+    parser.add_argument('--output_dir', default='./outputs/STS-B-model', type=str, help='Model output directory')
     parser.add_argument('--max_seq_length', default=64, type=int, help='Max sequence length')
     parser.add_argument('--num_epochs', default=10, type=int, help='Number of training epochs')
     parser.add_argument('--batch_size', default=64, type=int, help='Batch size')
@@ -88,9 +88,7 @@ def main():
     else:
         model = BertMatchModel(model_name_or_path=args.output_dir, encoder_type=args.encoder_type,
                                max_seq_length=args.max_seq_length)
-        test_dataset = HFTextMatchingTestDataset(model.tokenizer, args.task_name, max_len=args.max_seq_length,
-                                                 split="test")
-
+        test_dataset = load_dataset("shibing624/nli_zh", args.task_name, split="test")
     # Predict embeddings
     srcs = []
     trgs = []
diff --git a/tests/test_hf_dataset.py b/tests/test_hf_dataset.py
new file mode 100644
index 0000000..ecd4aee
--- /dev/null
+++ b/tests/test_hf_dataset.py
@@ -0,0 +1,35 @@
+# -*- coding: utf-8 -*-
+"""
+@author:XuMing(xuming624@qq.com)
+@description:
+"""
+import sys
+import unittest
+
+sys.path.append('..')
+
+from datasets import load_dataset
+
+
+class DatasetTestCase(unittest.TestCase):
+
+    def test_data_diff(self):
+        test_dataset = load_dataset("shibing624/nli_zh", "STS-B", split="test")
+
+        # Predict embeddings
+        srcs = []
+        trgs = []
+        labels = []
+        for terms in test_dataset:
+            src, trg, label = terms['sentence1'], terms['sentence2'], terms['label']
+            srcs.append(src)
+            trgs.append(trg)
+            labels.append(label)
+            if len(srcs) > 100:
+                break
+        print(f'{test_dataset[0]}')
+        print(f'{srcs[0]}')
+
+
+if __name__ == '__main__':
+    unittest.main()