[Feature] Support processing internlm-style datasets (#232)
Support processing internlm-style datasets.
HIT-cwh authored Nov 17, 2023
1 parent 7975d07 commit 2e12134
Showing 2 changed files with 84 additions and 1 deletion.
3 changes: 2 additions & 1 deletion xtuner/dataset/__init__.py
@@ -1,10 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .concat_dataset import ConcatDataset
 from .huggingface import process_hf_dataset
+from .internlm import process_internlm_dataset
 from .modelscope import process_ms_dataset
 from .moss_sft import MOSSSFTDataset
 
 __all__ = [
     'process_hf_dataset', 'ConcatDataset', 'MOSSSFTDataset',
-    'process_ms_dataset'
+    'process_ms_dataset', 'process_internlm_dataset'
 ]
82 changes: 82 additions & 0 deletions xtuner/dataset/internlm.py
@@ -0,0 +1,82 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
from functools import partial

import numpy as np
from datasets import concatenate_datasets, load_dataset, load_from_disk
from mmengine import print_log
from torch import distributed as dist
from tqdm import tqdm

from xtuner.utils import IGNORE_INDEX
from .utils import Packer


def add_labels(example, max_length):
    # Negative token ids mark positions that must not contribute to the loss.
    tokens = example['tokens'][:max_length]
    labels = copy.deepcopy(tokens)
    tokens = list(np.abs(np.array(tokens)))
    labels = np.array(labels)
    labels[labels < 0] = IGNORE_INDEX
    labels = list(labels)
    return {'input_ids': tokens, 'labels': labels}
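

# Illustration with hypothetical values: an input sample
#   {'tokens': [1, -333, -352, 4575, 2]}
# is mapped to
#   {'input_ids': [1, 333, 352, 4575, 2],
#    'labels': [1, -100, -100, 4575, 2]}
# `input_ids` takes the absolute value of every id, while `labels` replaces
# the originally negative positions with IGNORE_INDEX (-100), so they are
# skipped by the loss.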


def process(dataset_folder=None,
            cached_folder=None,
            max_length=2048,
            split='train',
            shuffle_before_pack=True,
            pack_to_max_length=False,
            num_proc=32):
    # Reuse a previously processed dataset when a cache already exists.
    if cached_folder is not None:
        try:
            return load_from_disk(cached_folder)
        except FileNotFoundError:
            pass

    assert dataset_folder is not None
    # Gather every `.bin` file (JSON lines) under `dataset_folder`.
    ds = []
    for root, dirs, files in os.walk(dataset_folder, followlinks=True):
        for fn in tqdm(sorted(files), total=len(files), leave=False):
            if fn.endswith('.bin'):
                fp = os.path.join(root, fn)
                ds.append(load_dataset('json', data_files=fp)[split])
    dataset = concatenate_datasets(ds)
    print_log(f'Found {len(dataset)} samples.', 'current')
    dataset = dataset.map(
        partial(add_labels, max_length=max_length),
        remove_columns=list(dataset.column_names),
        num_proc=num_proc)

    # Pack multiple samples into sequences of `max_length` tokens.
    if pack_to_max_length:
        if shuffle_before_pack:
            dataset = dataset.shuffle()
            dataset = dataset.flatten_indices()
        dataset = dataset.map(
            Packer(max_length), batched=True, num_proc=num_proc)
        print_log(
            f'After packing to {max_length}, '
            f'the length of dataset is {len(dataset)}.', 'current')

    # Persist the result so that later calls (and other ranks) can reuse it.
    if cached_folder is not None:
        dataset.save_to_disk(cached_folder)
        print_log(f'Processed dataset has been saved to {cached_folder}.',
                  'current')

    return dataset
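

# A minimal usage sketch (paths are hypothetical, not part of this commit):
# the first call walks `dataset_folder` and caches the result; any later call
# with the same `cached_folder` returns the cache via `load_from_disk`.
#
#   dataset = process(dataset_folder='/data/internlm_bins',
#                     cached_folder='/data/internlm_cached',
#                     pack_to_max_length=True)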


def process_internlm_dataset(*args, **kwargs):
    # Single-process case: just run the pipeline directly.
    if not (dist.is_available() and dist.is_initialized()):
        return process(*args, **kwargs)

    # Distributed case: rank 0 processes and caches the dataset while the
    # other ranks wait at the barrier, then load the cached result.
    if dist.get_rank() == 0:
        dataset = process(*args, **kwargs)

    dist.barrier()
    if dist.get_rank() != 0:
        # Loads the processed dataset from `cached_folder`.
        dataset = process(*args, **kwargs)
    return dataset
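
For reference, here is a minimal sketch of how the new entry point might be called; the folder paths are assumptions for illustration, not part of this commit. In a distributed launch (e.g. with torchrun), rank 0 builds and caches the dataset while the remaining ranks wait at the barrier and then load the cache, so `cached_folder` must be set in that case:

    from xtuner.dataset import process_internlm_dataset

    # Hypothetical locations; any folder containing internlm-style
    # .bin files (JSON lines with a 'tokens' field) should work.
    dataset = process_internlm_dataset(
        dataset_folder='./data/internlm_bins',
        cached_folder='./data/internlm_cached',
        max_length=2048,
        shuffle_before_pack=True,
        pack_to_max_length=True)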
