[Feature] Support to process internlm-style datasets (#232)
Support processing internlm-style datasets
Showing 2 changed files with 84 additions and 1 deletion.
File 1 of 2: __init__.py of the dataset package
@@ -1,10 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .concat_dataset import ConcatDataset
 from .huggingface import process_hf_dataset
+from .internlm import process_internlm_dataset
 from .modelscope import process_ms_dataset
 from .moss_sft import MOSSSFTDataset

 __all__ = [
     'process_hf_dataset', 'ConcatDataset', 'MOSSSFTDataset',
-    'process_ms_dataset'
+    'process_ms_dataset', 'process_internlm_dataset'
 ]
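With the export added to __all__, the new helper can be imported from the dataset package alongside the existing processors, e.g. (assuming the package is importable as xtuner.dataset):

    from xtuner.dataset import process_internlm_dataset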
File 2 of 2: internlm.py (new file in the dataset package)
@@ -0,0 +1,82 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
from functools import partial

import numpy as np
from datasets import concatenate_datasets, load_dataset, load_from_disk
from mmengine import print_log
from torch import distributed as dist
from tqdm import tqdm

from xtuner.utils import IGNORE_INDEX
from .utils import Packer


def add_labels(example, max_length):
    tokens = example['tokens'][:max_length]
    labels = copy.deepcopy(tokens)
    tokens = list(np.abs(np.array(tokens)))
    labels = np.array(labels)
    labels[labels < 0] = IGNORE_INDEX
    labels = list(labels)
    return {'input_ids': tokens, 'labels': labels}
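Internlm-style samples carry a tokens list in which a negative token id marks a position that should be excluded from the loss. add_labels therefore restores positive ids for input_ids (via np.abs) and masks the negative positions with IGNORE_INDEX in labels. An illustrative, made-up sample:

    add_labels({'tokens': [1, -2, 3]}, max_length=2048)
    # -> {'input_ids': [1, 2, 3], 'labels': [1, IGNORE_INDEX, 3]}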


def process(dataset_folder=None,
            cached_folder=None,
            max_length=2048,
            split='train',
            shuffle_before_pack=True,
            pack_to_max_length=False,
            num_proc=32):
    if cached_folder is not None:
        try:
            return load_from_disk(cached_folder)
        except FileNotFoundError:
            pass

    assert dataset_folder is not None
    ds = []
    for root, dirs, files in os.walk(dataset_folder, followlinks=True):
        for fn in tqdm(sorted(files), total=len(files), leave=False):
            if fn.endswith('.bin'):
                fp = os.path.join(root, fn)
                ds.append(load_dataset('json', data_files=fp)[split])
    dataset = concatenate_datasets(ds)
    print_log(f'Find {len(dataset)} samples.', 'current')
    dataset = dataset.map(
        partial(add_labels, max_length=max_length),
        remove_columns=list(dataset.column_names),
        num_proc=num_proc)

    # pack to max length
    if pack_to_max_length:
        if shuffle_before_pack:
            dataset = dataset.shuffle()
            dataset = dataset.flatten_indices()
        dataset = dataset.map(
            Packer(max_length), batched=True, num_proc=num_proc)
        print_log(
            f'After packing to {max_length}, '
            f'the length of dataset is {len(dataset)}.', 'current')

    dataset.save_to_disk(cached_folder)
    print_log(f'Processed dataset has been saved in {cached_folder}.',
              'current')

    return dataset
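Each .bin file that process() picks up is expected to be JSON lines readable by load_dataset('json', ...), one tokenized sample per line with a tokens field; a made-up example line:

    {"tokens": [1, 1128, -333, 2]}

Note that cached_folder is effectively required: the processed dataset is always written there via save_to_disk, and the distributed wrapper below relies on that cached copy.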


def process_internlm_dataset(*args, **kwargs):
    if not (dist.is_available() and dist.is_initialized()):
        return process(*args, **kwargs)

    if dist.get_rank() == 0:
        dataset = process(*args, **kwargs)

    dist.barrier()
    if dist.get_rank() != 0:
        # load processed dataset from `cached_folder`
        dataset = process(*args, **kwargs)
    return dataset
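When torch.distributed is initialized, only rank 0 runs the expensive preprocessing; the remaining ranks wait at dist.barrier() and then call process() again, which returns the dataset rank 0 saved to cached_folder via the load_from_disk fast path. A minimal sketch of how this could be driven (the import path, script name and folders are illustrative, not part of this commit):

    # launch with: torchrun --nproc_per_node=8 prepare_internlm_data.py
    import torch.distributed as dist

    from xtuner.dataset import process_internlm_dataset

    dist.init_process_group(backend='gloo')
    dataset = process_internlm_dataset(
        dataset_folder='/data/internlm_tokenized',
        cached_folder='/data/internlm_cached',
        max_length=2048,
        pack_to_max_length=True)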