diff --git a/.gitignore b/.gitignore index f9a95ef82..9204ac688 100644 --- a/.gitignore +++ b/.gitignore @@ -104,7 +104,6 @@ venv.bak/ .mypy_cache/ # custom -/data .vscode .idea .DS_Store diff --git a/data/toy_custom_incremental_data.json b/data/toy_custom_incremental_data.json new file mode 100644 index 000000000..bfd4950f0 --- /dev/null +++ b/data/toy_custom_incremental_data.json @@ -0,0 +1,16 @@ +[{ + "conversation":[ + { + "input": "", + "output": "I am an artificial intelligence (AI) assistant named Puyu. I was created by the Shanghai AI Laboratory and my purpose is to assist users with various tasks through natural language processing technology." + } + ] +}, +{ + "conversation":[ + { + "input": "", + "output": "I am an artificial intelligence programmed to assist with various types of tasks, including answering questions, providing information, and performing automated processes." + } + ] +}] diff --git a/data/toy_custom_multi_turn_data.json b/data/toy_custom_multi_turn_data.json new file mode 100644 index 000000000..a5c8a47fb --- /dev/null +++ b/data/toy_custom_multi_turn_data.json @@ -0,0 +1,32 @@ +[{ + "conversation":[ + { + "input": "Hello?", + "output": "Hello! How can I help you?" + }, + { + "input": "What's the date today?", + "output": "Today is Monday, August 14, 2023." + }, + { + "input": "Thank you!", + "output": "You are welcome." + } + ] +}, +{ + "conversation":[ + { + "input": "Hello?", + "output": "Hello! How can I help you?" + }, + { + "input": "How's the weather today in Rosso?", + "output": "The weather in Rosso on Wednesday, August 16th, is going to be cloudy for most of the day, together with moderate rain around noon." + }, + { + "input": "Thank you!", + "output": "You are welcome." + } + ] +}] diff --git a/data/toy_custom_single_turn_data.json b/data/toy_custom_single_turn_data.json new file mode 100644 index 000000000..5785a1d33 --- /dev/null +++ b/data/toy_custom_single_turn_data.json @@ -0,0 +1,18 @@ +[{ + "conversation": + [ + { + "input": "Give three tips for staying healthy.", + "output": "1.Eat a balanced diet. 2. Exercise regularly. 3. Get enough sleep." + } + ] +}, +{ + "conversation": + [ + { + "input": "How to study English?", + "output": "1. Set clear goals. 2. Create a study plan. 3. Build vocabulary. 4. Practice speaking." + } + ] +}] diff --git a/docs/en/user_guides/incremental_pretraining.md b/docs/en/user_guides/incremental_pretraining.md index 1b4ece29d..607e91b3e 100644 --- a/docs/en/user_guides/incremental_pretraining.md +++ b/docs/en/user_guides/incremental_pretraining.md @@ -95,7 +95,9 @@ The following modifications need to be made to the config file copied in Step 3: from xtuner.dataset import process_hf_dataset from datasets import load_dataset - from xtuner.dataset.map_fns import oasst1_map_fn, template_map_fn_factory -+ from map_fn import oasst1_incremental_map_fn ++ from mmengine.config import read_base ++ with read_base(): ++ from .map_fn import oasst1_incremental_map_fn ... ####################################################################### # PART 1 Settings # @@ -127,8 +129,33 @@ train_dataloader = dict( sampler=dict(type=DefaultSampler, shuffle=True), collate_fn=dict(type=default_collate_fn)) ... 
+####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook, tokenizer=tokenizer), + dict( + type=EvaluateChatHook, + tokenizer=tokenizer, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, +- instruction=prompt_template.INSTRUCTION_START) ++ ) +] +... +``` + +#### Step 5, Log Processed Dataset (Optional) + +After modifying the config file, you can print the first data of the processed dataset to verify whether the dataset has been constructed correctly. + +```bash +xtuner log-dataset $CONFIG ``` +`$CONFIG` represents the file path of the modified configuration file in Step 4. + ### Using Custom Datasets When using custom datasets for incremental pre-training, we recommend constructing the dataset according to the [incremental pre-training data format](./dataset_format.md#incremental-pre-training-dataset-format) defined by XTuner. If the custom dataset is in other formats such as oasst1, refer to the section on [Using Dataset in HuggingFace Hub](#using-dataset-in-huggingface-hub). @@ -191,18 +218,20 @@ from datasets import load_dataset ####################################################################### - data_path = 'timdettmers/openassistant-guanaco' - prompt_template = PROMPT_TEMPLATE.openassistant -+ data_path = 'path/to/your/data' ++ data_path = 'path/to/your/json/data' ... ####################################################################### # STEP 3 Dataset & Dataloader # ####################################################################### train_dataset = dict( type=process_hf_dataset, - dataset=dict(type=load_dataset, path=data_path), +- dataset=dict(type=load_dataset, path=data_path), ++ dataset=dict( ++ type=load_dataset, path='json', data_files=dict(train=data_path)), tokenizer=tokenizer, max_length=max_length, - dataset_map_fn=oasst1_map_fn, -+ dataset_map_fn=oasst1_incremental_map_fn, ++ dataset_map_fn=None, - template_map_fn=dict( - type=template_map_fn_factory, template=prompt_template), + template_map_fn=None, @@ -217,4 +246,29 @@ train_dataloader = dict( sampler=dict(type=DefaultSampler, shuffle=True), collate_fn=dict(type=default_collate_fn)) ... +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook, tokenizer=tokenizer), + dict( + type=EvaluateChatHook, + tokenizer=tokenizer, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, +- instruction=prompt_template.INSTRUCTION_START) ++ ) +] +... +``` + +#### Step 5, Check Custom Dataset (Optional) + +After modifying the config file, you can execute the `xtuner/tools/check_custom_dataset.py` script to verify the correct construction of the dataset. + +```bash +xtuner check-custom-dataset $CONFIG +``` + +`$CONFIG` represents the file path of the modified configuration file in Step 4.
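A minimal sketch of how the custom JSON configured above can be loaded and checked by hand, assuming the toy file `data/toy_custom_incremental_data.json` added in this diff; it mirrors the `load_dataset(path='json', data_files=...)` call in the config and the format rules that `xtuner check-custom-dataset` enforces:

```python
# Minimal sketch: load a custom incremental pre-training JSON file the same
# way the modified config does, and confirm it follows the XTuner-defined
# format (a 'conversation' list of {'input', 'output'} dicts, with an empty
# 'input' field for incremental pre-training).
from datasets import load_dataset

data_path = 'data/toy_custom_incremental_data.json'  # toy file from this diff
dataset = load_dataset('json', data_files=dict(train=data_path))['train']

for sample in dataset:
    for turn in sample['conversation']:
        assert turn['input'] == '' and isinstance(turn['output'], str)

print(dataset[0]['conversation'][0]['output'])
```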
diff --git a/docs/en/user_guides/multi_turn_conversation.md b/docs/en/user_guides/multi_turn_conversation.md index 45c1f4aed..e265af305 100644 --- a/docs/en/user_guides/multi_turn_conversation.md +++ b/docs/en/user_guides/multi_turn_conversation.md @@ -155,7 +155,9 @@ from xtuner.dataset import process_hf_dataset from datasets import load_dataset - from xtuner.dataset.map_fns import oasst1_map_fn, template_map_fn_factory + from xtuner.dataset.map_fns import template_map_fn_factory -+ from .map_fn import oasst1_multi_turns_map_fn ++ from mmengine.config import read_base ++ with read_base(): ++ from .map_fn import oasst1_multi_turns_map_fn ... ####################################################################### # PART 1 Settings # @@ -189,6 +191,16 @@ train_dataloader = dict( ... ``` +#### Step 6, Log Processed Dataset (Optional) + +After modifying the config file, you can print the first data of the processed dataset to verify whether the dataset has been constructed correctly. + +```bash +xtuner log-dataset $CONFIG +``` + +`$CONFIG` represents the file path of the modified configuration file in Step 5. + ## Using Custom Datasets When using a custom multi-turn dialogue dataset for command fine-tuning, we recommend constructing the dataset in the [multi-turn dialogue data format](./dataset_format.md#multi-turn-dialogue-dataset-format) as defined by XTuner. If the custom dataset format is oasst1 or other formats, you can refer to the section on [Using Datasets in HuggingFace Hub](#using-dataset-in-huggingface-hub). @@ -260,7 +272,7 @@ from datasets import load_dataset # PART 1 Settings # ####################################################################### - data_path = 'timdettmers/openassistant-guanaco' -+ data_path = 'path/to/your/data' ++ data_path = 'path/to/your/json/data' + prompt_template = PROMPT_TEMPLATE.openassistant ... @@ -269,7 +281,9 @@ from datasets import load_dataset ####################################################################### train_dataset = dict( type=process_hf_dataset, - dataset=dict(type=load_dataset, path=data_path), +- dataset=dict(type=load_dataset, path=data_path), ++ dataset=dict( ++ type=load_dataset, path='json', data_files=dict(train=data_path)), tokenizer=tokenizer, max_length=max_length, + dataset_map_fn=None, @@ -287,3 +301,13 @@ train_dataloader = dict( collate_fn=dict(type=default_collate_fn)) ... ``` + +#### Step 6, Check Processed Dataset (Optional) + +After modifying the config file, you can execute the 'xtuner/tools/check_custom_dataset.py' script to verify the correct construction of the dataset. + +```bash +xtuner check-custom-dataset $CONFIG +``` + +`$CONFIG` represents the file path of the modified configuration file in Step 5. diff --git a/docs/en/user_guides/single_turn_conversation.md b/docs/en/user_guides/single_turn_conversation.md index b52a194e3..28e477b60 100644 --- a/docs/en/user_guides/single_turn_conversation.md +++ b/docs/en/user_guides/single_turn_conversation.md @@ -129,7 +129,9 @@ from xtuner.dataset import process_hf_dataset from datasets import load_dataset - from xtuner.dataset.map_fns import alpaca_map_fn, template_map_fn_factory + from xtuner.dataset.map_fns import template_map_fn_factory -+ from .map_fn import alpaca_map_fn ++ from mmengine.config import read_base ++ with read_base(): ++ from .map_fn import alpaca_map_fn ... ####################################################################### # PART 1 Settings # @@ -164,6 +166,16 @@ train_dataloader = dict( ... 
``` +#### Step 6, Log Processed Dataset (Optional) + +After modifying the config file, you can print the first data of the processed dataset to verify whether the dataset has been constructed correctly. + +```bash +xtuner log-dataset $CONFIG +``` + +`$CONFIG` represents the file path of the modified configuration file in Step 5. + ## Using Custom Datasets When using a custom single-turn dialogue dataset for command fine-tuning, we recommend constructing the dataset in the [single-turn dialogue data format](./dataset_format.md#single-turn-dialogue-dataset-format) as defined by XTuner. If the custom dataset format is oasst1 or other formats, you can refer to the section on [Using Datasets in HuggingFace Hub](#using-dataset-in-huggingface-hub). @@ -228,7 +240,7 @@ from datasets import load_dataset ####################################################################### - alpaca_zh_path = 'silk-road/alpaca-data-gpt4-chinese' - alpaca_en_path = 'tatsu-lab/alpaca' -+ data_path = 'path/to/your/data' ++ data_path = 'path/to/your/json/data' + prompt_template = PROMPT_TEMPLATE.alpaca ####################################################################### @@ -236,7 +248,9 @@ from datasets import load_dataset ####################################################################### train_dataset = dict( type=process_hf_dataset, - dataset=dict(type=load_dataset, path=data_path), +- dataset=dict(type=load_dataset, path=data_path), ++ dataset=dict( ++ type=load_dataset, path='json', data_files=dict(train=data_path)), tokenizer=tokenizer, max_length=max_length, + dataset_map_fn=None, @@ -254,3 +268,13 @@ train_dataloader = dict( collate_fn=dict(type=default_collate_fn)) ... ``` + +#### Step 6, Check Processed Dataset (Optional) + +After modifying the config file, you can execute the 'xtuner/tools/check_custom_dataset.py' script to verify the correct construction of the dataset. + +```bash +xtuner check-custom-dataset $CONFIG +``` + +`$CONFIG` represents the file path of the modified configuration file in Step 5. diff --git a/docs/zh_cn/user_guides/incremental_pretraining.md b/docs/zh_cn/user_guides/incremental_pretraining.md index 4964396ea..852d35915 100644 --- a/docs/zh_cn/user_guides/incremental_pretraining.md +++ b/docs/zh_cn/user_guides/incremental_pretraining.md @@ -95,7 +95,9 @@ xtuner copy-cfg internlm_7b_qlora_oasst1_e3 . from xtuner.dataset import process_hf_dataset from datasets import load_dataset - from xtuner.dataset.map_fns import oasst1_map_fn, template_map_fn_factory -+ from map_fn import oasst1_incremental_map_fn ++ from mmengine.config import read_base ++ with read_base(): ++ from .map_fn import oasst1_incremental_map_fn ... ####################################################################### # PART 1 Settings # @@ -127,8 +129,33 @@ train_dataloader = dict( sampler=dict(type=DefaultSampler, shuffle=True), collate_fn=dict(type=default_collate_fn)) ... +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook, tokenizer=tokenizer), + dict( + type=EvaluateChatHook, + tokenizer=tokenizer, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, +- instruction=prompt_template.INSTRUCTION_START) ++ ) +] +... 
``` +#### Step 5, 打印数据集(可选) + +在修改配置文件后,可以打印处理后数据集的第一条数据,以验证数据集是否正确构建。 + +```bash +xtuner log-dataset $CONFIG +``` + +其中 `$CONFIG` 是 Step 4 修改过的 config 的文件路径。 + ### 使用自定义数据集 在使用自定义数据集进行增量预训练时,我们推荐将数据集构造为 XTuner 定义的[增量预训练数据格式](./dataset_format.md#增量预训练数据集格式)。若自定义数据集格式为 `oasst1` 等其他格式,可参考[使用HuggingFace Hub数据集](#使用huggingface-hub数据集)一节。 @@ -204,7 +231,7 @@ train_dataset = dict( tokenizer=tokenizer, max_length=max_length, - dataset_map_fn=oasst1_map_fn, -+ dataset_map_fn=oasst1_incremental_map_fn, ++ dataset_map_fn=None, - template_map_fn=dict( - type=template_map_fn_factory, template=prompt_template), + template_map_fn=None, @@ -219,4 +246,29 @@ train_dataloader = dict( sampler=dict(type=DefaultSampler, shuffle=True), collate_fn=dict(type=default_collate_fn)) ... +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook, tokenizer=tokenizer), + dict( + type=EvaluateChatHook, + tokenizer=tokenizer, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, +- instruction=prompt_template.INSTRUCTION_START) ++ ) +] +... +``` + +#### Step 5, 检查数据集(可选) + +在修改配置文件后,可以运行`xtuner/tools/check_custom_dataset.py`脚本验证数据集是否正确构建。 + +```bash +xtuner check-custom-dataset $CONFIG ``` + +其中 `$CONFIG` 是 Step 4 修改过的 config 的文件路径。 diff --git a/docs/zh_cn/user_guides/multi_turn_conversation.md b/docs/zh_cn/user_guides/multi_turn_conversation.md index a05213aaf..b52dd5ea8 100644 --- a/docs/zh_cn/user_guides/multi_turn_conversation.md +++ b/docs/zh_cn/user_guides/multi_turn_conversation.md @@ -155,7 +155,9 @@ from xtuner.dataset import process_hf_dataset from datasets import load_dataset - from xtuner.dataset.map_fns import oasst1_map_fn, template_map_fn_factory + from xtuner.dataset.map_fns import template_map_fn_factory -+ from .map_fn import oasst1_multi_turns_map_fn ++ from mmengine.config import read_base ++ with read_base(): ++ from .map_fn import oasst1_multi_turns_map_fn ... ####################################################################### # PART 1 Settings # @@ -189,6 +191,16 @@ train_dataloader = dict( ... ``` +#### Step 6, 打印数据集(可选) + +在修改配置文件后,可以打印处理后数据集的第一条数据,以验证数据集是否正确构建。 + +```bash +xtuner log-dataset $CONFIG +``` + +其中 `$CONFIG` 是 Step 5 修改过的 config 的文件路径。 + ## 使用自定义数据集 在使用自定义多轮对话数据集进行指令微调时,我们推荐将数据集构造为 XTuner 定义的[多轮对话数据格式](./dataset_format.md#多轮对话数据集格式)。若自定义数据集格式为 `oasst1` 等其他格式,可参考[使用 HuggingFace Hub 数据集](#使用-huggingface-hub-数据集)一节。 @@ -289,3 +301,13 @@ train_dataloader = dict( collate_fn=dict(type=default_collate_fn)) ... ``` + +#### Step 6, 检查数据集(可选) + +在修改配置文件后,可以运行`xtuner/tools/check_custom_dataset.py`脚本验证数据集是否正确构建。 + +```bash +xtuner check-custom-dataset $CONFIG +``` + +其中 `$CONFIG` 是 Step 5 修改过的 config 的文件路径。 diff --git a/docs/zh_cn/user_guides/single_turn_conversation.md b/docs/zh_cn/user_guides/single_turn_conversation.md index 7336a3335..deb6f2d95 100644 --- a/docs/zh_cn/user_guides/single_turn_conversation.md +++ b/docs/zh_cn/user_guides/single_turn_conversation.md @@ -129,7 +129,9 @@ from xtuner.dataset import process_hf_dataset from datasets import load_dataset - from xtuner.dataset.map_fns import alpaca_map_fn, template_map_fn_factory + from xtuner.dataset.map_fns import template_map_fn_factory -+ from .map_fn import alpaca_map_fn ++ from mmengine.config import read_base ++ with read_base(): ++ from .map_fn import alpaca_map_fn ... 
####################################################################### # PART 1 Settings # ####################################################################### @@ -164,6 +166,16 @@ train_dataloader = dict( ... ``` +#### Step 6, 打印数据集(可选) + +在修改配置文件后,可以打印处理后数据集的第一条数据,以验证数据集是否正确构建。 + +```bash +xtuner log-dataset $CONFIG +``` + +其中 `$CONFIG` 是 Step 5 修改过的 config 的文件路径。 + ## 使用自定义数据集 在使用自定义单轮对话数据集进行指令微调时,我们推荐将数据集构造为XTuner定义的[单轮对话数据格式](./dataset_format.md#单轮对话数据集格式)。若自定义数据集格式为 `alpaca` 等其他格式,可参考[使用 HuggingFace Hub 数据集](#使用-huggingface-hub-数据集)一节。 @@ -256,3 +268,13 @@ train_dataloader = dict( collate_fn=dict(type=default_collate_fn)) ... ``` + +#### Step 6, 检查数据集(可选) + +在修改配置文件后,可以运行`xtuner/tools/check_custom_dataset.py`脚本验证数据集是否正确构建。 + +```bash +xtuner check-custom-dataset $CONFIG +``` + +其中 `$CONFIG` 是 Step 5 修改过的 config 的文件路径。 diff --git a/xtuner/dataset/huggingface.py b/xtuner/dataset/huggingface.py index 132d762ab..051193e6a 100644 --- a/xtuner/dataset/huggingface.py +++ b/xtuner/dataset/huggingface.py @@ -7,7 +7,7 @@ from mmengine import print_log from mmengine.config import Config, ConfigDict -from xtuner.registry import BUILDER +from xtuner.registry import BUILDER, MAP_FUNC from .utils import Packer, encode_fn @@ -71,6 +71,9 @@ def process_hf_dataset(dataset, # Extract the useful data for training from the original dataset. if dataset_map_fn is not None: + if isinstance(dataset_map_fn, str): + dataset_map_fn = MAP_FUNC.get(dataset_map_fn) + dataset = dataset.map(dataset_map_fn) # Add prompt template, such as ### Human: xxx ###Assistant: xxx diff --git a/xtuner/entry_point.py b/xtuner/entry_point.py index f36f2cc39..b7e4edf91 100644 --- a/xtuner/entry_point.py +++ b/xtuner/entry_point.py @@ -8,13 +8,14 @@ from mmengine.logging import print_log import xtuner -from xtuner.tools import chat, copy_cfg, list_cfg, test, train +from xtuner.tools import (chat, check_custom_dataset, copy_cfg, list_cfg, + log_dataset, test, train) from xtuner.tools.data_preprocess import arxiv as arxiv_preprocess from xtuner.tools.model_converters import merge, pth_to_hf, split # Define valid modes -MODES = ('list-cfg', 'copy-cfg', 'train', 'test', 'chat', 'convert', - 'preprocess') +MODES = ('list-cfg', 'copy-cfg', 'log-dataset', 'check-custom-dataset', + 'train', 'test', 'chat', 'convert', 'preprocess') CLI_HELP_MSG = \ f""" @@ -46,6 +47,10 @@ xtuner chat $NAME_OR_PATH_TO_LLM --adapter $NAME_OR_PATH_TO_ADAPTER --prompt-template $PROMPT_TEMPLATE 6-1. Preprocess arxiv dataset: xtuner preprocess arxiv $SRC_FILE $DST_FILE --start-date $START_DATE --categories $CATEGORIES + 7-1. Log processed dataset: + xtuner log-dataset $CONFIG + 7-2. Verify the correctness of the config file for the custom dataset: + xtuner check-custom-dataset $CONFIG Run special commands: @@ -112,6 +117,8 @@ modes = { 'list-cfg': list_cfg.__file__, 'copy-cfg': copy_cfg.__file__, + 'log-dataset': log_dataset.__file__, + 'check-custom-dataset': check_custom_dataset.__file__, 'train': train.__file__, 'test': test.__file__, 'chat': chat.__file__, diff --git a/xtuner/registry.py b/xtuner/registry.py index 0835da79d..7c8907e0b 100644 --- a/xtuner/registry.py +++ b/xtuner/registry.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved.
from mmengine.registry import Registry -__all__ = ['BUILDER'] +__all__ = ['BUILDER', 'MAP_FUNC'] BUILDER = Registry('builder') +MAP_FUNC = Registry('map_fn') diff --git a/xtuner/tools/check_custom_dataset.py b/xtuner/tools/check_custom_dataset.py new file mode 100644 index 000000000..53fd070c3 --- /dev/null +++ b/xtuner/tools/check_custom_dataset.py @@ -0,0 +1,160 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +from functools import partial + +import numpy as np +from datasets import DatasetDict +from mmengine.config import Config + +from xtuner.dataset.utils import Packer, encode_fn +from xtuner.registry import BUILDER + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Verify the correctness of the config file for the ' + 'custom dataset.') + parser.add_argument( + 'config', + help='config file name or path. Note: Please use the original ' + 'configs, instead of the automatically saved log configs.') + args = parser.parse_args() + return args + + +def is_standard_format(dataset): + example = next(iter(dataset)) + if 'conversation' not in example: + return False + conversation = example['conversation'] + if not isinstance(conversation, list): + return False + for item in conversation: + if (not isinstance(item, dict)) or ('input' + not in item) or ('output' + not in item): + return False + input, output = item['input'], item['output'] + if (not isinstance(input, str)) or (not isinstance(output, str)): + return False + return True + + +def main(): + args = parse_args() + + cfg = Config.fromfile(args.config) + + tokenizer = BUILDER.build(cfg.tokenizer) + if cfg.get('framework', 'mmengine').lower() == 'huggingface': + train_dataset = cfg.train_dataset + else: + train_dataset = cfg.train_dataloader.dataset + + dataset = train_dataset.dataset + max_length = train_dataset.max_length + dataset_map_fn = train_dataset.get('dataset_map_fn', None) + template_map_fn = train_dataset.get('template_map_fn', None) + max_dataset_length = train_dataset.get('max_dataset_length', 10) + split = train_dataset.get('split', 'train') + remove_unused_columns = train_dataset.get('remove_unused_columns', False) + rename_maps = train_dataset.get('rename_maps', []) + shuffle_before_pack = train_dataset.get('shuffle_before_pack', True) + pack_to_max_length = train_dataset.get('pack_to_max_length', True) + input_ids_with_output = train_dataset.get('input_ids_with_output', True) + + if dataset.get('path', '') != 'json': + raise ValueError( + 'You are using custom datasets for SFT. ' + 'The custom datasets should be in json format. To load your JSON ' + 'file, you can use the following code snippet: \n' + '"""\nfrom datasets import load_dataset \n' + 'dataset = dict(type=load_dataset, path=\'json\', ' + 'data_files=\'your_json_file.json\')\n"""\n' + 'For more details, please refer to Step 5 in the ' + '`Using Custom Datasets` section of the documentation found at' + ' docs/zh_cn/user_guides/single_turn_conversation.md.') + + try: + dataset = BUILDER.build(dataset) + except RuntimeError: + raise RuntimeError( + 'Unable to load the custom JSON file using ' + '`datasets.load_dataset`. Your data-related config is ' + f'{train_dataset}. 
Please refer to the official documentation on' + ' `load_dataset` (https://huggingface.co/docs/datasets/loading) ' + 'for more details.') + + if isinstance(dataset, DatasetDict): + dataset = dataset[split] + + if not is_standard_format(dataset) and dataset_map_fn is None: + raise ValueError( + 'If the custom dataset is not in the XTuner-defined ' + 'format, please utilize `dataset_map_fn` to map the original data' + ' to the standard format. For more details, please refer to ' + 'Step 1 and Step 5 in the `Using Custom Datasets` section of the ' + 'documentation found at ' + '`docs/zh_cn/user_guides/single_turn_conversation.md`.') + + if is_standard_format(dataset) and dataset_map_fn is not None: + raise ValueError( + 'If the custom dataset is already in the XTuner-defined format, ' + 'please set `dataset_map_fn` to None.' + 'For more details, please refer to Step 1 and Step 5 in the ' + '`Using Custom Datasets` section of the documentation found at' + ' docs/zh_cn/user_guides/single_turn_conversation.md.') + + max_dataset_length = min(max_dataset_length, len(dataset)) + indices = np.random.choice(len(dataset), max_dataset_length, replace=False) + dataset = dataset.select(indices) + + if dataset_map_fn is not None: + dataset = dataset.map(dataset_map_fn) + + print('#' * 20 + ' dataset after `dataset_map_fn` ' + '#' * 20) + print(dataset[0]['conversation']) + + if template_map_fn is not None: + template_map_fn = BUILDER.build(template_map_fn) + dataset = dataset.map(template_map_fn) + + print('#' * 20 + ' dataset after adding templates ' + '#' * 20) + print(dataset[0]['conversation']) + + for old, new in rename_maps: + dataset = dataset.rename_column(old, new) + + if pack_to_max_length and (not remove_unused_columns): + raise ValueError('We have to remove unused columns if ' + '`pack_to_max_length` is set to True.') + + dataset = dataset.map( + partial( + encode_fn, + tokenizer=tokenizer, + max_length=max_length, + input_ids_with_output=input_ids_with_output), + remove_columns=list(dataset.column_names) + if remove_unused_columns else None) + + print('#' * 20 + ' encoded input_ids ' + '#' * 20) + print(dataset[0]['input_ids']) + print('#' * 20 + ' encoded labels ' + '#' * 20) + print(dataset[0]['labels']) + + if pack_to_max_length and split == 'train': + if shuffle_before_pack: + dataset = dataset.shuffle() + dataset = dataset.flatten_indices() + dataset = dataset.map(Packer(max_length), batched=True) + + print('#' * 20 + ' input_ids after packed to max_length ' + + '#' * 20) + print(dataset[0]['input_ids']) + print('#' * 20 + ' labels after packed to max_length ' + '#' * 20) + print(dataset[0]['labels']) + + +if __name__ == '__main__': + main() diff --git a/xtuner/tools/log_dataset.py b/xtuner/tools/log_dataset.py new file mode 100644 index 000000000..e6c24eb37 --- /dev/null +++ b/xtuner/tools/log_dataset.py @@ -0,0 +1,39 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse + +from mmengine.config import Config + +from xtuner.registry import BUILDER + + +def parse_args(): + parser = argparse.ArgumentParser(description='Log processed dataset.') + parser.add_argument( + 'config', + help='config file name or path. 
Note: Please use the original ' + 'configs, instead of the automatically saved log configs.') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + cfg = Config.fromfile(args.config) + + tokenizer = BUILDER.build(cfg.tokenizer) + if cfg.get('framework', 'mmengine').lower() == 'huggingface': + train_dataset = BUILDER.build(cfg.train_dataset) + else: + train_dataset = BUILDER.build(cfg.train_dataloader.dataset) + + print('#' * 20 + ' text ' + '#' * 20) + print(tokenizer.decode(train_dataset[0]['input_ids'])) + print('#' * 20 + ' input_ids ' + '#' * 20) + print(train_dataset[0]['input_ids']) + print('#' * 20 + ' labels ' + '#' * 20) + print(train_dataset[0]['labels']) + + +if __name__ == '__main__': + main() diff --git a/xtuner/tools/test.py b/xtuner/tools/test.py index 8fdfa339f..543e318b0 100644 --- a/xtuner/tools/test.py +++ b/xtuner/tools/test.py @@ -2,12 +2,14 @@ import argparse import os import os.path as osp +from types import FunctionType from mmengine.config import Config, DictAction from mmengine.registry import RUNNERS from mmengine.runner import Runner from xtuner.configs import cfgs_name_path +from xtuner.registry import MAP_FUNC def parse_args(): @@ -42,6 +44,21 @@ def parse_args(): return args +def register_function(cfg_dict): + if isinstance(cfg_dict, dict): + for key, value in dict.items(cfg_dict): + if isinstance(value, FunctionType): + value_str = str(value) + if value_str not in MAP_FUNC: + MAP_FUNC.register_module(module=value, name=value_str) + cfg_dict[key] = value_str + else: + register_function(value) + elif isinstance(cfg_dict, (list, tuple)): + for value in cfg_dict: + register_function(value) + + def main(): args = parse_args() @@ -58,6 +75,10 @@ def main(): if args.cfg_options is not None: cfg.merge_from_dict(args.cfg_options) + # register FunctionType object in cfg to `MAP_FUNC` Registry and + # change these FunctionType object to str + register_function(cfg._cfg_dict) + # work_dir is determined in this priority: CLI > segment in file > filename if args.work_dir is not None: # update configs according to CLI args if args.work_dir is not None diff --git a/xtuner/tools/train.py b/xtuner/tools/train.py index 52ca9919a..90d297727 100644 --- a/xtuner/tools/train.py +++ b/xtuner/tools/train.py @@ -5,6 +5,7 @@ import os import os.path as osp from functools import partial +from types import FunctionType from mmengine.config import Config, DictAction from mmengine.logging import print_log @@ -17,7 +18,7 @@ from xtuner.dataset.collate_fns import default_collate_fn from xtuner.model.modules import dispatch_modules from xtuner.model.utils import LoadWoInit, find_all_linear_names, traverse_dict -from xtuner.registry import BUILDER +from xtuner.registry import BUILDER, MAP_FUNC def parse_args(): @@ -57,6 +58,21 @@ def parse_args(): return args +def register_function(cfg_dict): + if isinstance(cfg_dict, dict): + for key, value in dict.items(cfg_dict): + if isinstance(value, FunctionType): + value_str = str(value) + if value_str not in MAP_FUNC: + MAP_FUNC.register_module(module=value, name=value_str) + cfg_dict[key] = value_str + else: + register_function(value) + elif isinstance(cfg_dict, (list, tuple)): + for value in cfg_dict: + register_function(value) + + def main(): args = parse_args() @@ -73,6 +89,10 @@ def main(): if args.cfg_options is not None: cfg.merge_from_dict(args.cfg_options) + # register FunctionType object in cfg to `MAP_FUNC` Registry and + # change these FunctionType object to str + register_function(cfg._cfg_dict) + if 
cfg.get('framework', 'mmengine').lower() == 'huggingface': # set default training_args if cfg.get('training_args', None) is None:
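For clarity on how the new `register_function` helpers in `train.py`/`test.py` interact with the `MAP_FUNC` lookup added to `process_hf_dataset`, here is a small sketch of the round-trip; `my_map_fn` and its return value are made-up illustrations, while `MAP_FUNC` is the registry introduced in `xtuner/registry.py` above:

```python
# Sketch of the MAP_FUNC round-trip: register_function() replaces a plain
# function found in the config with its str() representation (so the dumped
# config stays serializable), and process_hf_dataset() later resolves that
# string back to the callable via MAP_FUNC.get().
from xtuner.registry import MAP_FUNC


def my_map_fn(example):
    # Hypothetical user-defined dataset_map_fn, for illustration only.
    return {'conversation': [{'input': '', 'output': example['text']}]}


value = my_map_fn                  # FunctionType object found in the config
value_str = str(value)             # e.g. "<function my_map_fn at 0x7f...>"
if value_str not in MAP_FUNC:
    MAP_FUNC.register_module(module=value, name=value_str)
dataset_map_fn = value_str         # the config now holds a plain string

# What process_hf_dataset does with the string before calling dataset.map():
fn = MAP_FUNC.get(dataset_map_fn)
assert fn is my_map_fn
```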