diff --git a/xtuner/configs/internlm/internlm_7b/internlm_7b_qlora_alpaca_enzh_oasst1_e3.py b/xtuner/configs/internlm/internlm_7b/internlm_7b_qlora_alpaca_enzh_oasst1_e3.py
new file mode 100644
index 000000000..438925e07
--- /dev/null
+++ b/xtuner/configs/internlm/internlm_7b/internlm_7b_qlora_alpaca_enzh_oasst1_e3.py
@@ -0,0 +1,210 @@
+import torch
+from bitsandbytes.optim import PagedAdamW32bit
+from datasets import load_dataset
+from mmengine.dataset import DefaultSampler
+from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
+                            LoggerHook, ParamSchedulerHook)
+from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR
+from peft import LoraConfig
+from transformers import (AutoModelForCausalLM, AutoTokenizer,
+                          BitsAndBytesConfig)
+
+from xtuner.datasets import ConcatDataset, process_hf_dataset
+from xtuner.datasets.collate_fns import default_collate_fn
+from xtuner.datasets.map_fns import (alpaca_map_fn, alpaca_zh_map_fn,
+                                     oasst1_map_fn, template_map_fn_factory)
+from xtuner.engine import DatasetInfoHook, EvaluateChatHook
+from xtuner.models import SupervisedFinetune
+from xtuner.utils import PROMPT_TEMPLATE
+
+#######################################################################
+#                          PART 1  Settings                           #
+#######################################################################
+# path
+pretrained_model_name_or_path = 'internlm/internlm-7b'
+alpaca_zh_path = 'silk-road/alpaca-data-gpt4-chinese'
+alpaca_en_path = 'tatsu-lab/alpaca'
+oasst1_path = 'timdettmers/openassistant-guanaco'
+
+# data
+prompt_template = PROMPT_TEMPLATE.alpaca
+batch_size = 1  # per_device
+accumulative_counts = 16
+dataloader_num_workers = 0
+max_epochs = 3
+
+# optim
+optim_type = PagedAdamW32bit
+lr = 2e-4
+betas = (0.9, 0.999)
+weight_decay = 0.01
+max_norm = 1  # grad clip
+
+# Assess the progress of the model's training via interactive dialogue.
+evaluation_freq = 500
+human_inputs = ['请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai']
+
+# other
+max_length = 2048
+pack_to_max_length = True
+
+#######################################################################
+#                      PART 2  Model & Tokenizer                      #
+#######################################################################
+tokenizer = dict(
+    type=AutoTokenizer.from_pretrained,
+    pretrained_model_name_or_path=pretrained_model_name_or_path,
+    trust_remote_code=True,
+    padding_side='right')
+
+model = dict(
+    type=SupervisedFinetune,
+    llm=dict(
+        type=AutoModelForCausalLM.from_pretrained,
+        pretrained_model_name_or_path=pretrained_model_name_or_path,
+        trust_remote_code=True,
+        torch_dtype=torch.float16,
+        quantization_config=dict(
+            type=BitsAndBytesConfig,
+            load_in_4bit=True,
+            load_in_8bit=False,
+            llm_int8_threshold=6.0,
+            llm_int8_has_fp16_weight=False,
+            bnb_4bit_compute_dtype=torch.float16,
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_quant_type='nf4')),
+    lora=dict(
+        type=LoraConfig,
+        r=64,
+        lora_alpha=16,
+        lora_dropout=0.1,
+        bias='none',
+        task_type='CAUSAL_LM'))
+
+#######################################################################
+#                     PART 3  Dataset & Dataloader                    #
+#######################################################################
+alpaca_en = dict(
+    type=process_hf_dataset,
+    dataset=dict(type=load_dataset, path=alpaca_en_path),
+    tokenizer=tokenizer,
+    max_length=max_length,
+    dataset_map_fn=alpaca_map_fn,
+    template_map_fn=dict(
+        type=template_map_fn_factory, template=prompt_template),
+    remove_unused_columns=True,
+    shuffle_before_pack=True,
+    pack_to_max_length=pack_to_max_length)
+
+alpaca_zh = dict(
+    type=process_hf_dataset,
+    dataset=dict(type=load_dataset, path=alpaca_zh_path),
+    tokenizer=tokenizer,
+    max_length=max_length,
+    dataset_map_fn=alpaca_zh_map_fn,
+    template_map_fn=dict(
+        type=template_map_fn_factory, template=prompt_template),
+    remove_unused_columns=True,
+    shuffle_before_pack=True,
+    pack_to_max_length=pack_to_max_length)
+
+oasst1 = dict(
+    type=process_hf_dataset,
+    dataset=dict(type=load_dataset, path=oasst1_path),
+    tokenizer=tokenizer,
+    max_length=max_length,
+    dataset_map_fn=oasst1_map_fn,
+    template_map_fn=dict(
+        type=template_map_fn_factory, template=prompt_template),
+    remove_unused_columns=True,
+    shuffle_before_pack=True,
+    pack_to_max_length=pack_to_max_length)
+
+train_dataset = dict(
+    type=ConcatDataset,
+    datasets_cfg=dict(alpaca_en=alpaca_en, alpaca_zh=alpaca_zh, oasst1=oasst1))
+
+train_dataloader = dict(
+    batch_size=batch_size,
+    num_workers=dataloader_num_workers,
+    dataset=train_dataset,
+    sampler=dict(type=DefaultSampler, shuffle=True),
+    collate_fn=dict(type=default_collate_fn))
+
+#######################################################################
+#                          PART 4  Scheduler                          #
+#######################################################################
+# optimizer
+optim_wrapper = dict(
+    type=AmpOptimWrapper,
+    optimizer=dict(
+        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
+    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
+    accumulative_counts=accumulative_counts,
+    loss_scale='dynamic',
+    dtype='float16')
+
+# learning policy
+# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
+param_scheduler = dict(
+    type=CosineAnnealingLR,
+    eta_min=lr * 0.1,
+    by_epoch=True,
+    T_max=max_epochs,
+    convert_to_iter_based=True)
+
+# train, val, test setting
+train_cfg = dict(by_epoch=True, max_epochs=max_epochs, val_interval=1)
+
+#######################################################################
+#                           PART 5  Runtime                           #
+#######################################################################
+# Log the dialogue periodically during the training process, optional
+custom_hooks = [
+    dict(type=DatasetInfoHook, tokenizer=tokenizer),
+    dict(
+        type=EvaluateChatHook,
+        tokenizer=tokenizer,
+        every_n_iters=evaluation_freq,
+        sample_inputs=human_inputs,
+        instruction=prompt_template.INSTRUCTION_START)
+]
+
+# configure default hooks
+default_hooks = dict(
+    # record the time of every iteration.
+    timer=dict(type=IterTimerHook),
+    # print log every 10 iterations.
+    logger=dict(type=LoggerHook, interval=10),
+    # enable the parameter scheduler.
+    param_scheduler=dict(type=ParamSchedulerHook),
+    # save checkpoint per epoch.
+    checkpoint=dict(type=CheckpointHook, interval=1),
+    # set sampler seed in distributed environment.
+    sampler_seed=dict(type=DistSamplerSeedHook),
+)
+
+# configure environment
+env_cfg = dict(
+    # whether to enable cudnn benchmark
+    cudnn_benchmark=False,
+    # set multi process parameters
+    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+    # set distributed parameters
+    dist_cfg=dict(backend='nccl'),
+)
+
+# set visualizer
+visualizer = None
+
+# set log level
+log_level = 'INFO'
+
+# load from which checkpoint
+load_from = None
+
+# whether to resume training from the loaded checkpoint
+resume = False
+
+# Defaults to use random seed and disable `deterministic`
+randomness = dict(seed=None, deterministic=False)
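
For reference, a config like the one above is launched through xtuner's standard
mmengine entry point (assuming the usual `xtuner/tools/train.py` script, which is
not part of this patch): `python xtuner/tools/train.py xtuner/configs/internlm/internlm_7b/internlm_7b_qlora_alpaca_enzh_oasst1_e3.py`.
Every `dict(type=...)` node is built lazily through xtuner's BUILDER registry, so
the file stays a plain, importable Python config.
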
diff --git a/xtuner/configs/llama/llama_7b/llama_7b_qlora_alpaca_e3.py b/xtuner/configs/llama/llama_7b/llama_7b_qlora_alpaca_e3.py
index f80ee0c1d..738791d13 100644
--- a/xtuner/configs/llama/llama_7b/llama_7b_qlora_alpaca_e3.py
+++ b/xtuner/configs/llama/llama_7b/llama_7b_qlora_alpaca_e3.py
@@ -11,8 +11,9 @@
 from xtuner.datasets import ConcatDataset, process_hf_dataset
 from xtuner.datasets.collate_fns import default_collate_fn
-from xtuner.datasets.map_fns import alpaca_map_fn, alpaca_zh_map_fn
-from xtuner.engine import LogSampleHook, SampleGenerateHook
+from xtuner.datasets.map_fns import (alpaca_map_fn, alpaca_zh_map_fn,
+                                     template_map_fn_factory)
+from xtuner.engine import DatasetInfoHook, EvaluateChatHook
 from xtuner.models import SupervisedFinetune
 from xtuner.utils import PROMPT_TEMPLATE
 
@@ -25,6 +26,7 @@
 alpaca_en_path = 'tatsu-lab/alpaca'
 
 # data
+prompt_template = PROMPT_TEMPLATE.alpaca
 batch_size = 1  # per_device
 accumulative_counts = 16
 dataloader_num_workers = 0
@@ -37,10 +39,13 @@
 weight_decay = 0.01
 max_norm = 1  # grad clip
 
+# Assess the progress of the model's training via interactive dialogue.
+evaluation_freq = 500
+human_inputs = ['请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai']
+
 # other
 max_length = 2048
 pack_to_max_length = True
-generate_test_freq = 500
 
 #######################################################################
 #                      PART 2  Model & Tokenizer                      #
@@ -83,8 +88,10 @@
     dataset=dict(type=load_dataset, path=alpaca_en_path),
     tokenizer=tokenizer,
     max_length=max_length,
-    map_fn=alpaca_map_fn,
-    remove_columns=['instruction', 'text'],
+    dataset_map_fn=alpaca_map_fn,
+    template_map_fn=dict(
+        type=template_map_fn_factory, template=prompt_template),
+    remove_unused_columns=True,
     shuffle_before_pack=True,
     pack_to_max_length=pack_to_max_length)
 
@@ -93,8 +100,10 @@
     dataset=dict(type=load_dataset, path=alpaca_zh_path),
     tokenizer=tokenizer,
     max_length=max_length,
-    map_fn=alpaca_zh_map_fn,
-    remove_columns=['instruction', 'instruction_zh', 'input_zh', 'output_zh'],
+    dataset_map_fn=alpaca_zh_map_fn,
+    template_map_fn=dict(
+        type=template_map_fn_factory, template=prompt_template),
+    remove_unused_columns=True,
     shuffle_before_pack=True,
     pack_to_max_length=pack_to_max_length)
 
@@ -139,15 +148,13 @@
 #######################################################################
 # Log the dialogue periodically during the training process, optional
 custom_hooks = [
-    dict(type=LogSampleHook, tokenizer=tokenizer),
+    dict(type=DatasetInfoHook, tokenizer=tokenizer),
     dict(
-        type=SampleGenerateHook,
+        type=EvaluateChatHook,
         tokenizer=tokenizer,
-        every_n_iters=generate_test_freq,
-        sample_inputs=[
-            '请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai'
-        ],
-        instruction=PROMPT_TEMPLATE.alpaca.INSTRUCTION_START)
+        every_n_iters=evaluation_freq,
+        sample_inputs=human_inputs,
+        instruction=prompt_template.INSTRUCTION_START)
 ]
 
 # configure default hooks
diff --git a/xtuner/configs/llama/llama_7b/llama_7b_qlora_oasst1_e3.py b/xtuner/configs/llama/llama_7b/llama_7b_qlora_oasst1_e3.py
index 86095e5af..639171e69 100644
--- a/xtuner/configs/llama/llama_7b/llama_7b_qlora_oasst1_e3.py
+++ b/xtuner/configs/llama/llama_7b/llama_7b_qlora_oasst1_e3.py
@@ -11,8 +11,8 @@
 from xtuner.datasets import process_hf_dataset
 from xtuner.datasets.collate_fns import default_collate_fn
-from xtuner.datasets.map_fns import oasst1_map_fn
-from xtuner.engine import LogSampleHook, SampleGenerateHook
+from xtuner.datasets.map_fns import oasst1_map_fn, template_map_fn_factory
+from xtuner.engine import DatasetInfoHook, EvaluateChatHook
 from xtuner.models import SupervisedFinetune
 from xtuner.utils import PROMPT_TEMPLATE
 
@@ -24,6 +24,7 @@
 data_path = 'timdettmers/openassistant-guanaco'
 
 # data
+prompt_template = PROMPT_TEMPLATE.openassistant
 batch_size = 1  # per_device
 accumulative_counts = 16
 dataloader_num_workers = 0
@@ -36,10 +37,13 @@
 weight_decay = 0.01
 max_norm = 1  # grad clip
 
+# Assess the progress of the model's training via interactive dialogue.
+evaluation_freq = 500
+human_inputs = ['请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai']
+
 # other
 max_length = 2048
 pack_to_max_length = True
-generate_test_freq = 500
 
 #######################################################################
 #                      PART 2  Model & Tokenizer                      #
 #######################################################################
@@ -81,7 +85,10 @@
     dataset=dict(type=load_dataset, path=data_path),
     tokenizer=tokenizer,
     max_length=max_length,
-    map_fn=oasst1_map_fn,
+    dataset_map_fn=oasst1_map_fn,
+    template_map_fn=dict(
+        type=template_map_fn_factory, template=prompt_template),
+    remove_unused_columns=True,
     shuffle_before_pack=True,
     pack_to_max_length=pack_to_max_length)
 
@@ -122,15 +129,13 @@
 #######################################################################
 # Log the dialogue periodically during the training process, optional
 custom_hooks = [
-    dict(type=LogSampleHook, tokenizer=tokenizer),
+    dict(type=DatasetInfoHook, tokenizer=tokenizer),
     dict(
-        type=SampleGenerateHook,
+        type=EvaluateChatHook,
         tokenizer=tokenizer,
-        every_n_iters=generate_test_freq,
-        sample_inputs=[
-            '请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai'
-        ],
-        instruction=PROMPT_TEMPLATE.openassistant.INSTRUCTION_START)
+        every_n_iters=evaluation_freq,
+        sample_inputs=human_inputs,
+        instruction=prompt_template.INSTRUCTION_START)
 ]
 
 # configure default hooks
diff --git a/xtuner/datasets/huggingface.py b/xtuner/datasets/huggingface.py
index f23c2a82a..bf03bfc6f 100644
--- a/xtuner/datasets/huggingface.py
+++ b/xtuner/datasets/huggingface.py
@@ -1,10 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import logging
 from functools import partial
 
 import numpy as np
 from datasets import DatasetDict
+from mmengine import print_log
 from mmengine.config import Config, ConfigDict
-from mmengine.config.lazy import LazyObject
 
 from xtuner.registry import BUILDER
 from .utils import Packer, encode_fn
@@ -13,14 +14,44 @@
 def process_hf_dataset(dataset,
                        tokenizer,
                        max_length,
+                       dataset_map_fn=None,
+                       template_map_fn=None,
                        max_dataset_length=None,
                        split='train',
-                       map_fn=None,
-                       remove_columns=[],
+                       remove_unused_columns=False,
                        rename_maps=[],
                        shuffle_before_pack=True,
                        pack_to_max_length=True,
                        input_ids_with_output=True):
+    """Post-process a dataset loaded from the Hugging Face Hub, or a local
+    dataset.
+
+    Args:
+        dataset: The dataset to be post-processed.
+        tokenizer: The tokenizer, which takes raw text as input and outputs
+            an Encoding.
+        max_length: Max length of the sequence.
+        dataset_map_fn: A function that maps the original dataset format to
+            the one defined by xTuner.
+        template_map_fn: A function that adds the prompt template to the
+            dataset.
+        max_dataset_length: If the dataset is too long, randomly extract
+            `max_dataset_length` samples from it.
+        split: Which split of the data to load.
+            If `None`, will return a `dict` with all splits (typically
+            `datasets.Split.TRAIN` and `datasets.Split.TEST`).
+            If given, will return a single Dataset.
+        remove_unused_columns: Whether to remove columns from the dataset
+            that are not used during training.
+        rename_maps: Rename the column names of the dataset.
+        shuffle_before_pack: Whether to shuffle the dataset before
+            packing it.
+        pack_to_max_length: Whether to pack the dataset to `max_length`.
+            This usually improves GPU utilization and therefore reduces
+            training time.
+        input_ids_with_output: Whether to put the ground-truth output
+            corresponding to the question into the dataset. Typically set
+            it to True during training and False during testing.
+    """
     dataset = BUILDER.build(dataset)
 
     if isinstance(dataset, DatasetDict):
@@ -34,20 +65,31 @@ def process_hf_dataset(dataset,
             len(dataset), max_dataset_length, replace=False)
         dataset = dataset.select(indices)
 
-    if isinstance(map_fn, str):
-        map_fn = eval(map_fn)
-    if isinstance(map_fn, list):
-        assert all(
-            [callable(fn) and isinstance(fn, LazyObject) for fn in map_fn])
-        for fn in map_fn[:-1]:
-            fn = fn.build()
-            dataset = dataset.map(fn)
-        dataset = dataset.map(
-            map_fn[-1].build(), remove_columns=remove_columns)
-    elif map_fn is not None:
-        dataset = dataset.map(map_fn, remove_columns=remove_columns)
+    # Extract the useful data for training from the original dataset.
+    if dataset_map_fn is not None:
+        dataset = dataset.map(dataset_map_fn)
+
+    # Add the prompt template, such as ### Human: xxx ### Assistant: xxx
+    if template_map_fn is not None:
+        if isinstance(template_map_fn, dict) or isinstance(
+                template_map_fn, Config) or isinstance(template_map_fn,
+                                                       ConfigDict):
+            template_map_fn = BUILDER.build(template_map_fn)
+        dataset = dataset.map(template_map_fn)
+
     for old, new in rename_maps:
         dataset = dataset.rename_column(old, new)
+
+    # remove unused columns
+    if pack_to_max_length and (not remove_unused_columns):
+        print_log(
+            'We have to remove unused columns if '
+            '`pack_to_max_length` is set to True.',
+            logger='current',
+            level=logging.WARNING)
+        remove_unused_columns = True
+
+    # tokenize
     if isinstance(tokenizer, dict) or isinstance(
             tokenizer, Config) or isinstance(tokenizer, ConfigDict):
         tokenizer = BUILDER.build(tokenizer)
@@ -56,12 +98,15 @@ def process_hf_dataset(dataset,
             encode_fn,
             tokenizer=tokenizer,
            max_length=max_length,
-            input_ids_with_output=input_ids_with_output))
+            input_ids_with_output=input_ids_with_output),
+        remove_columns=list(dataset.column_names)
+        if remove_unused_columns else None)
+
+    # pack to max length
     if pack_to_max_length and split == 'train':
         if shuffle_before_pack:
             dataset = dataset.shuffle()
             dataset = dataset.flatten_indices()
-        column_names = list(dataset.column_names)
-        dataset = dataset.map(
-            Packer(max_length), batched=True, remove_columns=column_names)
+        dataset = dataset.map(Packer(max_length), batched=True)
+
     return dataset
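
For reference, a minimal sketch of how the reworked `process_hf_dataset` pipeline
is exercised outside a full runner config. The cfg-dict style (callable `type`
keys resolved by xtuner's BUILDER) mirrors the configs above; model and dataset
names are taken from them, and the printed columns are an expectation, not
verified output:

    from datasets import load_dataset
    from transformers import AutoTokenizer

    from xtuner.datasets import process_hf_dataset
    from xtuner.datasets.map_fns import alpaca_map_fn, template_map_fn_factory
    from xtuner.utils import PROMPT_TEMPLATE

    ds = process_hf_dataset(
        # cfg dicts are resolved via BUILDER.build, exactly as in the configs
        dataset=dict(type=load_dataset, path='tatsu-lab/alpaca'),
        tokenizer=dict(
            type=AutoTokenizer.from_pretrained,
            pretrained_model_name_or_path='internlm/internlm-7b',
            trust_remote_code=True,
            padding_side='right'),
        max_length=2048,
        # 1) normalize raw columns into the unified `conversation` format
        dataset_map_fn=alpaca_map_fn,
        # 2) wrap each turn's input with the chosen prompt template
        template_map_fn=template_map_fn_factory(
            template=PROMPT_TEMPLATE.alpaca),
        # 3) drop raw columns, tokenize, then 4) pack samples to max_length
        remove_unused_columns=True,
        shuffle_before_pack=True,
        pack_to_max_length=True)
    print(ds.column_names)  # expected: ['input_ids', 'labels']
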
diff --git a/xtuner/datasets/map_fns/__init__.py b/xtuner/datasets/map_fns/__init__.py
index 3f4b7b703..4a488c53e 100644
--- a/xtuner/datasets/map_fns/__init__.py
+++ b/xtuner/datasets/map_fns/__init__.py
@@ -1,3 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .dataset_map_fn import *  # noqa: F401, F403
-from .model_map_fn import *  # noqa: F401, F403
+from .dataset_map_fns import *  # noqa: F401, F403
+from .template_map_fn import template_map_fn  # noqa: F401
+from .template_map_fn import template_map_fn_factory  # noqa: F401
diff --git a/xtuner/datasets/map_fns/dataset_map_fn/alpaca_map_fn.py b/xtuner/datasets/map_fns/dataset_map_fn/alpaca_map_fn.py
deleted file mode 100644
index 08be92056..000000000
--- a/xtuner/datasets/map_fns/dataset_map_fn/alpaca_map_fn.py
+++ /dev/null
@@ -1,29 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-def alpaca_map_fn(example):
-    PROMPT = {
-        'with_input':
-        ('Below is an instruction that describes a task, paired with an '
-         'input that provides further context. '
-         'Write a response that appropriately completes the request.\n\n'
-         '### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n'
-         '### Response: '),
-        'without_input':
-        ('Below is an instruction that describes a task. '
-         'Write a response that appropriately completes the request.\n\n'
-         '### Instruction:\n{instruction}\n\n'
-         '### Response: ')
-    }
-    if example.get('input', '') != '':
-        prompt_template = PROMPT['with_input']
-    else:
-        prompt_template = PROMPT['without_input']
-
-    if example.get('output', '') == '':
-        return {'conversation': [{'input': '', 'output': ''}]}
-    else:
-        return {
-            'conversation': [{
-                'input': prompt_template.format(**example),
-                'output': example['output']
-            }]
-        }
diff --git a/xtuner/datasets/map_fns/dataset_map_fn/alpaca_zh_map_fn.py b/xtuner/datasets/map_fns/dataset_map_fn/alpaca_zh_map_fn.py
deleted file mode 100644
index 87ea99da1..000000000
--- a/xtuner/datasets/map_fns/dataset_map_fn/alpaca_zh_map_fn.py
+++ /dev/null
@@ -1,26 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-def alpaca_zh_map_fn(example):
-    PROMPT = {
-        'with_input':
-        ('Below is an instruction that describes a task, paired with an '
-         'input that provides further context. '
-         'Write a response that appropriately completes the request.\n\n'
-         '### Instruction:\n{instruction_zh}\n\n### Input:\n{input_zh}\n\n'
-         '### Response: '),
-        'without_input':
-        ('Below is an instruction that describes a task. '
-         'Write a response that appropriately completes the request.\n\n'
-         '### Instruction:\n{instruction_zh}\n\n'
-         '### Response: ')
-    }
-    if example.get('input', '') != '':
-        prompt_template = PROMPT['with_input']
-    else:
-        prompt_template = PROMPT['without_input']
-
-    return {
-        'conversation': [{
-            'input': prompt_template.format(**example),
-            'output': example['output_zh']
-        }]
-    }
diff --git a/xtuner/datasets/map_fns/dataset_map_fn/arxiv_map_fn.py b/xtuner/datasets/map_fns/dataset_map_fn/arxiv_map_fn.py
deleted file mode 100644
index adfa869b5..000000000
--- a/xtuner/datasets/map_fns/dataset_map_fn/arxiv_map_fn.py
+++ /dev/null
@@ -1,12 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-def arxiv_map_fn(example):
-    PROMPT = ('If you are an expert in writing papers, please generate '
-              "a good paper title for this paper based on other authors' "
-              'descriptions of their abstracts.\n\n'
-              '### Descriptions:\n{abstract}\n\n### Title: ')
-    return {
-        'conversation': [{
-            'input': PROMPT.format(**example),
-            'output': example['title']
-        }]
-    }
diff --git a/xtuner/datasets/map_fns/dataset_map_fn/code_alpaca_map_fn.py b/xtuner/datasets/map_fns/dataset_map_fn/code_alpaca_map_fn.py
deleted file mode 100644
index 04c03e59b..000000000
--- a/xtuner/datasets/map_fns/dataset_map_fn/code_alpaca_map_fn.py
+++ /dev/null
@@ -1,10 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-def code_alpaca_map_fn(example):
-    return {
-        'conversation': [{
-            'input':
-            '### Human: {prompt}\n### Bot: '.format(**example),
-            'output':
-            example['completion']
-        }]
-    }
diff --git a/xtuner/datasets/map_fns/dataset_map_fn/colors_map_fn.py b/xtuner/datasets/map_fns/dataset_map_fn/colors_map_fn.py
deleted file mode 100644
index a38968fea..000000000
--- a/xtuner/datasets/map_fns/dataset_map_fn/colors_map_fn.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-
-
-def colors_map_fn(example):
-    PROMPT = ('You are a professional color designer. Please provide the '
-              'corresponding colors based on the description of Human.\n'
-              '### Human: {input}\n### Bot: ')
-    desc = ':'.join(example['description'].split(':')[1:]).strip()
-    return {
-        'conversation': [{
-            'input': PROMPT.format(input=desc),
-            'output': example['color']
-        }]
-    }
diff --git a/xtuner/datasets/map_fns/dataset_map_fn/crime_kg_assitant_map_fn.py b/xtuner/datasets/map_fns/dataset_map_fn/crime_kg_assitant_map_fn.py
deleted file mode 100644
index 8966b7e4c..000000000
--- a/xtuner/datasets/map_fns/dataset_map_fn/crime_kg_assitant_map_fn.py
+++ /dev/null
@@ -1,10 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-def crime_kg_assitant_map_fn(example):
-    return {
-        'conversation': [{
-            'input': ('你现在是一名专业的中国律师,请根据Human的问题给出准确、有理有据的回复。\n\n'
-                      '### Human: {input}\n### Bot: ').format(**example),
-            'output':
-            example['output']
-        }]
-    }
diff --git a/xtuner/datasets/map_fns/dataset_map_fn/law_reference_map_fn.py b/xtuner/datasets/map_fns/dataset_map_fn/law_reference_map_fn.py
deleted file mode 100644
index 3c16c781a..000000000
--- a/xtuner/datasets/map_fns/dataset_map_fn/law_reference_map_fn.py
+++ /dev/null
@@ -1,10 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-def law_reference_map_fn(example):
-    return {
-        'conversation': [{
-            'input': ('你现在是一名专业的中国律师,请根据Human的问题给出准确、有理有据的回复。\n\n'
-                      '### Human: {question}\n### Bot: ').format(**example),
-            'output':
-            example['answer']
-        }]
-    }
diff --git a/xtuner/datasets/map_fns/dataset_map_fn/medical_map_fn.py b/xtuner/datasets/map_fns/dataset_map_fn/medical_map_fn.py
deleted file mode 100644
index 87b1d3133..000000000
--- a/xtuner/datasets/map_fns/dataset_map_fn/medical_map_fn.py
+++ /dev/null
@@ -1,19 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-def medical_map_fn(example):
-    PROMPT = {
-        'with_input': ('如果你是一名医生,请根据患者的描述回答医学问题。\n\n'
-                       '### Input: {instruction}. {input}\n\n### Response: '),
-        'without_input': ('如果你是一名医生,请根据患者的描述回答医学问题。\n\n'
-                          '### Input: {instruction}\n\n### Response: '),
-    }
-    if example.get('input', '') != '':
-        prompt_template = PROMPT['with_input']
-    else:
-        prompt_template = PROMPT['without_input']
-
-    return {
-        'conversation': [{
-            'input': prompt_template.format(**example),
-            'output': example['output']
-        }]
-    }
diff --git a/xtuner/datasets/map_fns/dataset_map_fn/oasst1_map_fn.py b/xtuner/datasets/map_fns/dataset_map_fn/oasst1_map_fn.py
deleted file mode 100644
index e43f5338a..000000000
--- a/xtuner/datasets/map_fns/dataset_map_fn/oasst1_map_fn.py
+++ /dev/null
@@ -1,33 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-def oasst1_map_fn(example):
-    r"""Example before preprocessing: example['text'] = '### Human: Can you
-    explain xxx### Assistant: Sure!xxx### Human: I didn't understand how
-    xxx### Assistant: It has to do with a process xxx.'.
-
-    Example after preprocessing:
-        example['conversation'] = [
-            {
-                'input': '### Human: Can you explain xxx',
-                'output': '### Assistant: Sure! xxx'
-            },
-            {
-                'input': '### Human: I didn't understand how xxx',
-                'output': '### Assistant: It has to do with a process xxx.'
-            }
-        ]
-    """
-    data = [
-        '### ' + sentence.strip()
-        for sentence in example['text'].strip().split('###') if sentence != ''
-    ]
-    if len(data) % 2:
-        # The last round of conversation solely consists of input
-        # without any output.
-        # Discard the input part of the last round, as this part is ignored in
-        # the loss calculation.
-        data.pop()
-    conversation = []
-    for i in range(0, len(data), 2):
-        single_turn_conversation = {'input': data[i], 'output': data[i + 1]}
-        conversation.append(single_turn_conversation)
-    return {'conversation': conversation}
diff --git a/xtuner/datasets/map_fns/dataset_map_fn/openorca_map_fn.py b/xtuner/datasets/map_fns/dataset_map_fn/openorca_map_fn.py
deleted file mode 100644
index fb721847c..000000000
--- a/xtuner/datasets/map_fns/dataset_map_fn/openorca_map_fn.py
+++ /dev/null
@@ -1,13 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-def openorca_map_fn(example):
-    PROMPT = ('Below is an instruction that describes a task. '
-              'Write a response that appropriately completes the request.\n\n'
-              '### Instruction:\n{question}\n\n'
-              '### Response: ')
-
-    return {
-        'conversation': [{
-            'input': PROMPT.format(**example),
-            'output': example['response']
-        }]
-    }
diff --git a/xtuner/datasets/map_fns/dataset_map_fn/sql_map_fn.py b/xtuner/datasets/map_fns/dataset_map_fn/sql_map_fn.py
deleted file mode 100644
index 5a09f20dc..000000000
--- a/xtuner/datasets/map_fns/dataset_map_fn/sql_map_fn.py
+++ /dev/null
@@ -1,12 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-def sql_map_fn(example):
-    PROMPT = (
-        'If you are an expert in SQL, please generate a good SQL Query for '
-        'Question based on the CREATE TABLE statement.\n'
-        '### Question: {context}\n{question}\n### Query: ')
-    return {
-        'conversation': [{
-            'input': PROMPT.format(**example),
-            'output': example['answer']
-        }]
-    }
diff --git a/xtuner/datasets/map_fns/dataset_map_fn/__init__.py b/xtuner/datasets/map_fns/dataset_map_fns/__init__.py
similarity index 88%
rename from xtuner/datasets/map_fns/dataset_map_fn/__init__.py
rename to xtuner/datasets/map_fns/dataset_map_fns/__init__.py
index 8bd50c67a..a613c5a0b 100644
--- a/xtuner/datasets/map_fns/dataset_map_fn/__init__.py
+++ b/xtuner/datasets/map_fns/dataset_map_fns/__init__.py
@@ -8,6 +8,7 @@
 from .law_reference_map_fn import law_reference_map_fn
 from .medical_map_fn import medical_map_fn
 from .oasst1_map_fn import oasst1_map_fn
+from .openai_map_fn import openai_map_fn
 from .openorca_map_fn import openorca_map_fn
 from .sql_map_fn import sql_map_fn
 from .tiny_codes_map_fn import tiny_codes_map_fn
@@ -16,5 +17,5 @@
     'alpaca_map_fn', 'alpaca_zh_map_fn', 'oasst1_map_fn', 'arxiv_map_fn',
     'medical_map_fn', 'openorca_map_fn', 'code_alpaca_map_fn',
     'tiny_codes_map_fn', 'colors_map_fn', 'law_reference_map_fn',
-    'crime_kg_assitant_map_fn', 'sql_map_fn'
+    'crime_kg_assitant_map_fn', 'sql_map_fn', 'openai_map_fn'
 ]
diff --git a/xtuner/datasets/map_fns/dataset_map_fns/alpaca_map_fn.py b/xtuner/datasets/map_fns/dataset_map_fns/alpaca_map_fn.py
new file mode 100644
index 000000000..130fe2d30
--- /dev/null
+++ b/xtuner/datasets/map_fns/dataset_map_fns/alpaca_map_fn.py
@@ -0,0 +1,13 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+def alpaca_map_fn(example):
+    if example.get('output', '') == '':
+        return {'conversation': [{'input': '', 'output': ''}]}
+    else:
+        return {
+            'conversation': [{
+                'input':
+                '{instruction}\n{input}'.format(**example),
+                'output':
+                example['output']
+            }]
+        }
diff --git a/xtuner/datasets/map_fns/dataset_map_fn/tiny_codes_map_fn.py b/xtuner/datasets/map_fns/dataset_map_fns/alpaca_zh_map_fn.py
similarity index 53%
rename from xtuner/datasets/map_fns/dataset_map_fn/tiny_codes_map_fn.py
rename to xtuner/datasets/map_fns/dataset_map_fns/alpaca_zh_map_fn.py
index cd1f62fa2..94e4502b7 100644
--- a/xtuner/datasets/map_fns/dataset_map_fn/tiny_codes_map_fn.py
+++ b/xtuner/datasets/map_fns/dataset_map_fns/alpaca_zh_map_fn.py
@@ -1,10 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-def tiny_codes_map_fn(example):
+def alpaca_zh_map_fn(example):
     return {
         'conversation': [{
             'input':
-            '### Human: {prompt}\n### Bot: '.format(**example),
+            '{instruction_zh}\n{input_zh}'.format(**example),
             'output':
-            example['response']
+            example['output_zh']
         }]
     }
diff --git a/xtuner/datasets/map_fns/dataset_map_fns/arxiv_map_fn.py b/xtuner/datasets/map_fns/dataset_map_fns/arxiv_map_fn.py
new file mode 100644
index 000000000..f2f06a0fe
--- /dev/null
+++ b/xtuner/datasets/map_fns/dataset_map_fns/arxiv_map_fn.py
@@ -0,0 +1,8 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+def arxiv_map_fn(example):
+    return {
+        'conversation': [{
+            'input': example['abstract'],
+            'output': example['title']
+        }]
+    }
diff --git a/xtuner/datasets/map_fns/dataset_map_fns/code_alpaca_map_fn.py b/xtuner/datasets/map_fns/dataset_map_fns/code_alpaca_map_fn.py
new file mode 100644
index 000000000..cd54d4ec9
--- /dev/null
+++ b/xtuner/datasets/map_fns/dataset_map_fns/code_alpaca_map_fn.py
@@ -0,0 +1,8 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+def code_alpaca_map_fn(example):
+    return {
+        'conversation': [{
+            'input': example['prompt'],
+            'output': example['completion']
+        }]
+    }
diff --git a/xtuner/datasets/map_fns/dataset_map_fns/colors_map_fn.py b/xtuner/datasets/map_fns/dataset_map_fns/colors_map_fn.py
new file mode 100644
index 000000000..8ebc18aae
--- /dev/null
+++ b/xtuner/datasets/map_fns/dataset_map_fns/colors_map_fn.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+def colors_map_fn(example):
+    desc = ':'.join(example['description'].split(':')[1:]).strip()
+    return {'conversation': [{'input': desc, 'output': example['color']}]}
diff --git a/xtuner/datasets/map_fns/dataset_map_fns/crime_kg_assitant_map_fn.py b/xtuner/datasets/map_fns/dataset_map_fns/crime_kg_assitant_map_fn.py
new file mode 100644
index 000000000..10eefe4fc
--- /dev/null
+++ b/xtuner/datasets/map_fns/dataset_map_fns/crime_kg_assitant_map_fn.py
@@ -0,0 +1,8 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+def crime_kg_assitant_map_fn(example):
+    return {
+        'conversation': [{
+            'input': example['input'],
+            'output': example['output']
+        }]
+    }
diff --git a/xtuner/datasets/map_fns/dataset_map_fns/law_reference_map_fn.py b/xtuner/datasets/map_fns/dataset_map_fns/law_reference_map_fn.py
new file mode 100644
index 000000000..74e687927
--- /dev/null
+++ b/xtuner/datasets/map_fns/dataset_map_fns/law_reference_map_fn.py
@@ -0,0 +1,8 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+def law_reference_map_fn(example):
+    return {
+        'conversation': [{
+            'input': example['question'],
+            'output': example['answer']
+        }]
+    }
diff --git a/xtuner/datasets/map_fns/dataset_map_fns/medical_map_fn.py b/xtuner/datasets/map_fns/dataset_map_fns/medical_map_fn.py
new file mode 100644
index 000000000..497b8231f
--- /dev/null
+++ b/xtuner/datasets/map_fns/dataset_map_fns/medical_map_fn.py
@@ -0,0 +1,8 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+def medical_map_fn(example):
+    return {
+        'conversation': [{
+            'input': '{instruction}\n{input}'.format(**example),
+            'output': example['output']
+        }]
+    }
diff --git a/xtuner/datasets/map_fns/dataset_map_fns/oasst1_map_fn.py b/xtuner/datasets/map_fns/dataset_map_fns/oasst1_map_fn.py
new file mode 100644
index 000000000..e1e13a015
--- /dev/null
+++ b/xtuner/datasets/map_fns/dataset_map_fns/oasst1_map_fn.py
@@ -0,0 +1,38 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+def oasst1_map_fn(example):
+    r"""Example before preprocessing:
+        example['text'] = '### Human: Can you explain xxx'
+                          '### Assistant: Sure! xxx'
+                          '### Human: I didn't understand how xxx'
+                          '### Assistant: It has to do with a process xxx.'
+
+    Example after preprocessing:
+        example['conversation'] = [
+            {
+                'input': 'Can you explain xxx',
+                'output': 'Sure! xxx'
+            },
+            {
+                'input': 'I didn't understand how xxx',
+                'output': 'It has to do with a process xxx.'
+            }
+        ]
+    """
+    data = []
+    for sentence in example['text'].strip().split('###'):
+        sentence = sentence.strip()
+        if sentence[:6] == 'Human:':
+            data.append(sentence[6:].strip())
+        elif sentence[:10] == 'Assistant:':
+            data.append(sentence[10:].strip())
+    if len(data) % 2:
+        # The last round of conversation solely consists of input
+        # without any output.
+        # Discard the input part of the last round, as this part is ignored in
+        # the loss calculation.
+        data.pop()
+    conversation = []
+    for i in range(0, len(data), 2):
+        single_turn_conversation = {'input': data[i], 'output': data[i + 1]}
+        conversation.append(single_turn_conversation)
+    return {'conversation': conversation}
diff --git a/xtuner/datasets/map_fns/dataset_map_fns/openai_map_fn.py b/xtuner/datasets/map_fns/dataset_map_fns/openai_map_fn.py
new file mode 100644
index 000000000..1c14a106c
--- /dev/null
+++ b/xtuner/datasets/map_fns/dataset_map_fns/openai_map_fn.py
@@ -0,0 +1,41 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+def openai_map_fn(example):
+    """
+    Example before preprocessing:
+        example["messages"] = [
+            { "role": "system", "content": "You are an assistant that
+                occasionally misspells words." },
+            { "role": "user", "content": "Tell me a story." },
+            { "role": "assistant", "content": "One day a student
+                went to schoool." }
+        ]
+    Example after preprocessing:
+        example["conversation"] = [
+            {
+                "input": "You are an assistant that occasionally misspells
+                    words. Tell me a story.",
+                "output": "One day a student went to schoool."
+            }
+        ]
+    """
+    messages = example['messages']
+    if len(messages) == 0:
+        return {'conversation': [{'input': '', 'output': ''}]}
+    # Merge the system prompt into the first user turn. The length guard
+    # avoids an IndexError when a record contains only a system message.
+    if messages[0]['role'] == 'system' and len(messages) > 1:
+        messages[1][
+            'content'] = messages[0]['content'] + ' ' + messages[1]['content']
+        messages = messages[1:]
+    if len(messages) % 2:
+        # The last round of conversation solely consists of input
+        # without any output.
+        # Discard the input part of the last round, as this part is ignored in
+        # the loss calculation.
+        messages.pop()
+    conversation = []
+    for i in range(0, len(messages), 2):
+        single_turn_conversation = {
+            'input': messages[i]['content'],
+            'output': messages[i + 1]['content']
+        }
+        conversation.append(single_turn_conversation)
+    return {'conversation': conversation}
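
A worked example of the conversion `openai_map_fn` performs (an editor's sketch;
the record mirrors the docstring above):

    example = {
        'messages': [
            {'role': 'system', 'content': 'You are a helpful assistant.'},
            {'role': 'user', 'content': 'Hi!'},
            {'role': 'assistant', 'content': 'Hello! How can I help you?'},
        ]
    }
    print(openai_map_fn(example))
    # The system prompt is folded into the first user turn:
    # {'conversation': [{'input': 'You are a helpful assistant. Hi!',
    #                    'output': 'Hello! How can I help you?'}]}
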
diff --git a/xtuner/datasets/map_fns/dataset_map_fns/openorca_map_fn.py b/xtuner/datasets/map_fns/dataset_map_fns/openorca_map_fn.py
new file mode 100644
index 000000000..74c992447
--- /dev/null
+++ b/xtuner/datasets/map_fns/dataset_map_fns/openorca_map_fn.py
@@ -0,0 +1,8 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+def openorca_map_fn(example):
+    return {
+        'conversation': [{
+            'input': example['question'],
+            'output': example['response']
+        }]
+    }
diff --git a/xtuner/datasets/map_fns/dataset_map_fns/sql_map_fn.py b/xtuner/datasets/map_fns/dataset_map_fns/sql_map_fn.py
new file mode 100644
index 000000000..f74b5d654
--- /dev/null
+++ b/xtuner/datasets/map_fns/dataset_map_fns/sql_map_fn.py
@@ -0,0 +1,8 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+def sql_map_fn(example):
+    return {
+        'conversation': [{
+            'input': '{context}\n{question}'.format(**example),
+            'output': example['answer']
+        }]
+    }
diff --git a/xtuner/datasets/map_fns/dataset_map_fns/tiny_codes_map_fn.py b/xtuner/datasets/map_fns/dataset_map_fns/tiny_codes_map_fn.py
new file mode 100644
index 000000000..a3498c47a
--- /dev/null
+++ b/xtuner/datasets/map_fns/dataset_map_fns/tiny_codes_map_fn.py
@@ -0,0 +1,8 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+def tiny_codes_map_fn(example):
+    return {
+        'conversation': [{
+            'input': example['prompt'],
+            'output': example['response']
+        }]
+    }
diff --git a/xtuner/datasets/map_fns/model_map_fn/__init__.py b/xtuner/datasets/map_fns/model_map_fn/__init__.py
deleted file mode 100644
index 0b9d219b1..000000000
--- a/xtuner/datasets/map_fns/model_map_fn/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from .internlm_map_fn import internlm_map_fn
-from .llama2_map_fn import llama2_map_fn
-
-__all__ = ['internlm_map_fn', 'llama2_map_fn']
diff --git a/xtuner/datasets/map_fns/model_map_fn/internlm_map_fn.py b/xtuner/datasets/map_fns/model_map_fn/internlm_map_fn.py
deleted file mode 100644
index 31568c4de..000000000
--- a/xtuner/datasets/map_fns/model_map_fn/internlm_map_fn.py
+++ /dev/null
@@ -1,11 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-def internlm_map_fn(example):
-    user = '<|User|>'
-    eoh = '<eoh>'
-    assistant = '<|Bot|>'
-    conversation = example.get('conversation', [])
-    for single_turn_conversation in conversation:
-        input = single_turn_conversation['input']
-        single_turn_conversation['input'] = \
-            f'{user}:{input}{eoh}\n{assistant}:'
-    return {'conversation': conversation}
diff --git a/xtuner/datasets/map_fns/model_map_fn/llama2_map_fn.py b/xtuner/datasets/map_fns/model_map_fn/llama2_map_fn.py
deleted file mode 100644
index 170647611..000000000
--- a/xtuner/datasets/map_fns/model_map_fn/llama2_map_fn.py
+++ /dev/null
@@ -1,19 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-def llama2_map_fn(example):
-    B_INST, E_INST = '[INST]', '[/INST]'
-    B_SYS, E_SYS = '<<SYS>>\n', '\n<</SYS>>\n\n'
-
-    DEFAULT_SYSTEM_PROMPT = \
-        'You are a helpful, respectful and honest assistant. Always answer ' \
-        'as helpfully as possible, while being safe. Your answers should ' \
-        'not include any harmful, unethical, racist, sexist, toxic, ' \
-        'dangerous, or illegal content. Please ensure that your responses ' \
-        'are socially unbiased and positive in nature.'
-
-    conversation = example.get('conversation', [])
-    for single_turn_conversation in conversation:
-        input = single_turn_conversation['input']
-        single_turn_conversation['input'] = f'{B_INST} {B_SYS} ' \
-            f'{DEFAULT_SYSTEM_PROMPT} {E_SYS}{input} {E_INST}'
-
-    return {'conversation': conversation}
diff --git a/xtuner/datasets/map_fns/template_map_fn.py b/xtuner/datasets/map_fns/template_map_fn.py
new file mode 100644
index 000000000..b748220e5
--- /dev/null
+++ b/xtuner/datasets/map_fns/template_map_fn.py
@@ -0,0 +1,20 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from functools import partial
+
+
+def template_map_fn(example, template):
+    conversation = example.get('conversation', [])
+    for i, single_turn_conversation in enumerate(conversation):
+        input = single_turn_conversation['input']
+        if i == 0:
+            single_turn_conversation[
+                'input'] = template.INSTRUCTION_START.format(input=input)
+        else:
+            single_turn_conversation['input'] = template.INSTRUCTION.format(
+                input=input)
+
+    return {'conversation': conversation}
+
+
+def template_map_fn_factory(template):
+    return partial(template_map_fn, template=template)
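
A short sketch of what `template_map_fn_factory` does to a multi-turn
`conversation`. `DemoTemplate` below is a hypothetical stand-in; the real
templates come from `xtuner.utils.PROMPT_TEMPLATE` and expose the same two
fields:

    from xtuner.datasets.map_fns import template_map_fn_factory


    class DemoTemplate:
        INSTRUCTION_START = '### Human: {input}\n### Bot: '
        INSTRUCTION = '\n### Human: {input}\n### Bot: '


    map_fn = template_map_fn_factory(template=DemoTemplate)
    example = {'conversation': [{'input': 'Hi', 'output': 'Hello!'},
                                {'input': 'Bye', 'output': 'See you!'}]}
    print(map_fn(example))
    # The first turn is wrapped with INSTRUCTION_START, every later turn with
    # INSTRUCTION, so only turn 0 carries the system-style preamble.
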
diff --git a/xtuner/engine/__init__.py b/xtuner/engine/__init__.py
index e1ad6a87b..6e2d891c6 100644
--- a/xtuner/engine/__init__.py
+++ b/xtuner/engine/__init__.py
@@ -1,4 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .hooks import LogSampleHook, SampleGenerateHook
+from .hooks import DatasetInfoHook, EvaluateChatHook
 
-__all__ = ['SampleGenerateHook', 'LogSampleHook']
+__all__ = ['EvaluateChatHook', 'DatasetInfoHook']
diff --git a/xtuner/engine/hooks/__init__.py b/xtuner/engine/hooks/__init__.py
index 6eae2792e..a8ad7f8ec 100644
--- a/xtuner/engine/hooks/__init__.py
+++ b/xtuner/engine/hooks/__init__.py
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .log_data_sample import LogSampleHook
-from .sample_generate_hook import SampleGenerateHook
+from .dataset_info_hook import DatasetInfoHook
+from .evaluate_chat_hook import EvaluateChatHook
 
-__all__ = ['SampleGenerateHook', 'LogSampleHook']
+__all__ = ['EvaluateChatHook', 'DatasetInfoHook']
diff --git a/xtuner/engine/hooks/log_data_sample.py b/xtuner/engine/hooks/dataset_info_hook.py
similarity index 97%
rename from xtuner/engine/hooks/log_data_sample.py
rename to xtuner/engine/hooks/dataset_info_hook.py
index 46bb8e13c..2540c5ca4 100644
--- a/xtuner/engine/hooks/log_data_sample.py
+++ b/xtuner/engine/hooks/dataset_info_hook.py
@@ -4,7 +4,7 @@
 from xtuner.registry import BUILDER
 
 
-class LogSampleHook(Hook):
+class DatasetInfoHook(Hook):
 
     def __init__(self, tokenizer):
         self.tokenizer = BUILDER.build(tokenizer)
diff --git a/xtuner/engine/hooks/sample_generate_hook.py b/xtuner/engine/hooks/evaluate_chat_hook.py
similarity index 91%
rename from xtuner/engine/hooks/sample_generate_hook.py
rename to xtuner/engine/hooks/evaluate_chat_hook.py
index 78e901c09..770fb4766 100644
--- a/xtuner/engine/hooks/sample_generate_hook.py
+++ b/xtuner/engine/hooks/evaluate_chat_hook.py
@@ -7,7 +7,7 @@
 from xtuner.utils import StopWordStoppingCriteria
 
 
-class SampleGenerateHook(Hook):
+class EvaluateChatHook(Hook):
 
     def __init__(self,
                  tokenizer,
@@ -51,7 +51,7 @@ def _generate_samples(self, runner, max_new_tokens=None):
                 f'{self.tokenizer.decode(generation_output[0])}\n')
 
     def before_train(self, runner):
-        runner.logger.info('before_train in SampleGenerateHook.')
+        runner.logger.info('before_train in EvaluateChatHook.')
         self._generate_samples(runner, max_new_tokens=50)
 
     def after_train_iter(self,
@@ -62,11 +62,11 @@ def after_train_iter(self,
         if self.every_n_iters is None or (batch_idx +
                                           1) % self.every_n_iters != 0:
             return
-        runner.logger.info('after_train_iter in SampleGenerateHook.')
+        runner.logger.info('after_train_iter in EvaluateChatHook.')
         self._generate_samples(runner)
 
     def after_val(self, runner) -> None:
         if self.every_n_iters is not None:
             return
-        runner.logger.info('after_val in SampleGenerateHook.')
+        runner.logger.info('after_val in EvaluateChatHook.')
         self._generate_samples(runner)
diff --git a/xtuner/tools/chat.py b/xtuner/tools/chat.py
index 2126b5754..6f6c3f3ab 100644
--- a/xtuner/tools/chat.py
+++ b/xtuner/tools/chat.py
@@ -21,6 +21,11 @@ def parse_args():
         help='config file name or path. Note: Please use the original '
         'configs, instead of the automatically saved log configs.')
     parser.add_argument('--adapter', default=None, help='adapter model')
+    parser.add_argument(
+        '--prompt-template',
+        choices=PROMPT_TEMPLATE.keys(),
+        default=None,
+        help='Specify a prompt template')
     parser.add_argument(
         '--is-deepspeed',
         action='store_true',
@@ -34,11 +39,6 @@ def parse_args():
         '--no-streamer', action='store_true', help='Whether to with streamer')
     parser.add_argument('--command-stop-word', default=None, help='Stop key')
     parser.add_argument('--answer-stop-word', default=None, help='Stop key')
-    parser.add_argument(
-        '--prompt-template',
-        choices=PROMPT_TEMPLATE.keys(),
-        default=None,
-        help='Specify a prompt option')
     parser.add_argument(
         '--max-new-tokens',
         type=int,
diff --git a/xtuner/tools/chat_hf.py b/xtuner/tools/chat_hf.py
index e778c75a4..994bdabeb 100644
--- a/xtuner/tools/chat_hf.py
+++ b/xtuner/tools/chat_hf.py
@@ -16,6 +16,11 @@ def parse_args():
     parser.add_argument(
         'model_name_or_path', help='Hugging Face model name or path')
     parser.add_argument('--adapter', default=None, help='adapter name or path')
+    parser.add_argument(
+        '--prompt-template',
+        choices=PROMPT_TEMPLATE.keys(),
+        default=None,
+        help='Specify a prompt template')
     parser.add_argument(
         '--bot-name', type=str, default='BOT', help='Name for Bot')
     parser.add_argument(
@@ -27,11 +32,6 @@ def parse_args():
         '--no-streamer', action='store_true', help='Whether to with streamer')
     parser.add_argument('--command-stop-word', default=None, help='Stop key')
     parser.add_argument('--answer-stop-word', default=None, help='Stop key')
-    parser.add_argument(
-        '--prompt-template',
-        choices=PROMPT_TEMPLATE.keys(),
-        default=None,
-        help='Specify a prompt option')
     parser.add_argument(
         '--max-new-tokens',
         type=int,
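
With `--prompt-template` now a first-class option next to `--adapter`, an
interactive session against a QLoRA adapter would look roughly like this (a
sketch; the adapter path is hypothetical, and `alpaca` assumes that key exists
in PROMPT_TEMPLATE, as the configs above use it):

    python xtuner/tools/chat_hf.py internlm/internlm-7b \
        --adapter work_dirs/my_internlm_qlora_adapter \
        --prompt-template alpaca --max-new-tokens 256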