Pr/tj solergibert/189 #212

Merged · 8 commits · Jul 31, 2024
62 changes: 33 additions & 29 deletions docs/nanoset.md
@@ -1,49 +1,50 @@
# Nanosets
Nanotron incorporates [`Nanosets`](../src/nanotron/data/nanoset.py), a kind of datasets based on [numpy memory-mapped arrays](https://numpy.org/doc/stable/reference/generated/numpy.memmap.html). `Nanosets` are capable of serving batches from files containing pre-tokenized datasets. They allow reading tokens from one or multiple datasets and even specifying the weight of each dataset when building batches.
Nanotron incorporates [`Nanosets`](../src/nanotron/data/nanoset.py), a dataset for working with documents tokenized with [`datatrove`](https://github.com/huggingface/datatrove). `Nanosets` allow reading tokens from one or multiple datasets and even specifying the weight of each dataset when building batches.
## Install
To use `Nanosets`, it's necessary to install Nanotron with the `nanosets` flavor.
```
pip install -e '.[nanosets]'
pip install nanotron[nanosets]
```
This will install the following dependencies:
- `transformers`: To tokenize the datasets
- `datasets`: To preprocess the datasets
- `datatrove`: To preprocess the datasets
- `numba`: To compile helper functions in order to speed up the creation of `Nanosets`
- `transformers`: For the tokenizers
## Data pre-processing
To use these datasets, first, we need to preprocess the data. The input format can either be a column of a Hugging Face Dataset or a .json file containing a text sample per line. For example:
To use this dataset, first, we need to preprocess the data using `datatrove`'s `DocumentTokenizer` pipeline. We invite you to take a look at `datatrove`, since it contains multiple features that allow you, for example, to filter out documents based on specific rules/criteria, extract text content from raw formats, or schedule the preprocessing on a Slurm cluster. We have also added a simple script capable of tokenizing datasets.

<pre>
{"src": "www.nvidia.com", "text": "The quick brown fox", "type": "Eng", "id": "0", "title": "First Part"}
{"src": "The Internet", "text": "jumps over the lazy dog", "type": "Eng", "id": "42", "title": "Second Part"}
</pre>

The preprocessing is done using the [`tools/preprocess_data.py`](../tools/preprocess_data.py) script. Below we show an example for processing a corpus with the Llama2 tokenizer.
The preprocessing is done using the [`tools/preprocess_data.py`](../tools/preprocess_data.py) script. The input format can either be a Hugging Face Dataset, a path to a `.jsonl` file, or a path to a folder containing multiple `.jsonl` files. Below we show an example for processing a Hugging Face Dataset from the Hub with the Llama3 tokenizer.

<pre>
torchrun --nproc-per-node 16 tools/preprocess_data.py \
--input HuggingFaceH4/testing_alpaca_small \
--split train \
--column completion \
--output-prefix datasets/testing_alpaca_small \
--tokenizer-name-or-path openai-community/gpt2
python3 tools/preprocess_data.py \
--tokenizer-name-or-path meta-llama/Meta-Llama-3-8B \
--output-folder datasets/emotion \
--n-tasks 16 \
hf \
--dataset dair-ai/emotion
</pre>

The preprocessing script has to be launched with `torchrun` in order to spawn `--nproc-per-node` workers that will preprocess the dataset concurrently. The `--input` dataset can be either a Hugging Face Dataset from the Hub or a `.json` file. The processed dataset will be stored in *`--output-prefix`_input_ids.npy*. In `--tokenizer-name-or-path`, we will have to specify a tokenizer in the same way as we do when using `AutoTokenizers.from_pretrained(...)`.
First, with `--tokenizer-name-or-path` we specify a tokenizer in the same way as we do when using `AutoTokenizer.from_pretrained(...)`. Then we specify with `--output-folder` where the tokenized documents will be stored, and the number of workers with `--n-tasks`. Finally, we indicate the type of dataset (either a Hugging Face Dataset ["**hf**"] or jsonl ["**jsonl**"] format) and the dataset that we want to preprocess. Check the different settings with `python3 tools/preprocess_data.py --help`, `python3 tools/preprocess_data.py hf --help` & `python3 tools/preprocess_data.py jsonl --help`.

The output will be one file named, in this case, `datasets/testing_alpaca_small_input_ids.npy`. We will then have to specify this file in the `dataset_path` field in the config file.
Every worker will store 3 different kinds of files in `--output-folder`:
- `*.ds` Containing the tokenized documents
- `*.ds.index` Containing the bounds of each tokenized document
- `*.ds.metadata` Containing the number of tokens and tokenizer used

> [!IMPORTANT]
> Remember to specify the type of dataset to process, e.g. python3 tools/preprocess_data.py --tokenizer-name-or-path gpt2 --n-tasks 16 **jsonl** --dataset raw_datasets/c4-es-json-files
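
For reference, the tokenization that [`tools/preprocess_data.py`](../tools/preprocess_data.py) performs can also be assembled directly as a `datatrove` pipeline. The following is a minimal sketch only; the reader class, the `DocumentTokenizer` parameters, and the executor arguments are assumptions based on the `datatrove` version available at the time of writing, so check `datatrove`'s documentation for the exact signatures.

```python
# Hypothetical sketch of a local datatrove tokenization pipeline (assumed API;
# verify class and parameter names against your installed datatrove version).
from datatrove.executor import LocalPipelineExecutor
from datatrove.pipeline.readers import JsonlReader
from datatrove.pipeline.tokens import DocumentTokenizer

pipeline = [
    # Read raw documents; "text" is the column that holds the content to tokenize
    JsonlReader("raw_datasets/c4-es-json-files", text_key="text"),
    # Tokenize the documents and write the *.ds / *.ds.index / *.ds.metadata shards
    DocumentTokenizer(
        output_folder="datasets/c4-es/tokenized",
        tokenizer_name_or_path="meta-llama/Meta-Llama-3-8B",
        eos_token="<|end_of_text|>",
    ),
]

# Run locally with 16 tasks, mirroring --n-tasks 16 in the script above
LocalPipelineExecutor(pipeline=pipeline, tasks=16, workers=16).run()
```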

## Working with Nanosets

To work with `Nanosets`, we just need to configure 1 argument:
1. `dataset_path`: This argument specifies the file or files that will compose the `Nanoset`. There are 3 ways to specify it:
1. `dataset_folder`: This argument specifies the folder or folders that will compose the `Nanoset`. There are 3 ways to specify it:
1. If we specify a single path, we will create a `Nanoset` from a single dataset folder.
```yaml
data_stages:
  - name: General purpose training (Single dataset)
    start_training_step: 1
    data:
      dataset:
        dataset_path: datasets/SlimPajama-6B_input_ids.npy
        dataset_folder: datasets/SlimPajama-6B
      num_loading_workers: 0
      seed: 1234
```
@@ -54,9 +55,9 @@ To work with `Nanosets`, we just need to configure 1 argument:
    start_training_step: 15
    data:
      dataset:
        dataset_path:
        - datasets/SlimPajama-6B_input_ids.npy
        - datasets/testing_alpaca_small_input_ids.npy
        dataset_folder:
        - datasets/SlimPajama-6B
        - datasets/testing_alpaca_small
      num_loading_workers: 0
      seed: 1234
```
@@ -67,9 +68,9 @@ To work with `Nanosets`, we just need to configure 1 argument:
    start_training_step: 25
    data:
      dataset:
        dataset_path:
          datasets/SlimPajama-6B_input_ids.npy: 0.8
          datasets/testing_alpaca_small_input_ids.npy: 0.2
        dataset_folder:
          datasets/SlimPajama-6B: 0.8
          datasets/testing_alpaca_small: 0.2
      num_loading_workers: 0
      seed: 1234
```
@@ -78,11 +79,14 @@ To work with `Nanosets`, we just need to configure 1 argument:

Finally, to use the `Nanosets`, launch the training with [`run_train.py`](../run_train.py).
```shell
torchrun --nproc-per-node 8 run_train.py --config configs/config_nanoset.yaml
torchrun --nproc-per-node 1 run_train.py --config examples/config_nanoset.yaml
```

## Under the hood
`Nanosets` are responsible of building samples of `sequence length + 1` tokens from the preprocessed dataset files. The `dataset lengths` of each dataset will be determined by the `(dataset_number_of_tokens - 1) / sequence length`, discarding the last sample if its length < `sequence length`.
`Nanosets` are responsible for building samples of `sequence length + 1` tokens from the preprocessed dataset files. Although most of the extraction logic lies in `DatatroveFolderDataset`, `Nanosets` take care of the following:
1. Creating dataset mixtures from different dataset folder paths
2. Ensuring that in each epoch, we consume each sample only once
3. Ensuring that we never exhaust the `DataLoader`

Based on the `dataset lengths`, the `dataset weights` and the `number of samples per epoch` (defined as the `sum(dataset lengths)`), we build the two indexes we need in order to extract samples from the `Nanoset` ([build_nanoset_index_helper](../src/nanotron/data/nanoset.py)):
- `dataset index`: Contains the index of the dataset from the list of `dataset paths` from which to extract the sample, respecting the established dataset weight.
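
Alongside the `dataset index`, a companion `dataset sample index` selects which sample inside the chosen dataset to read. As a rough illustration of the idea only (not the actual [`build_nanoset_index_helper`](../src/nanotron/data/nanoset.py) implementation), the two indexes could be built like this:

```python
import numpy as np

def build_index_sketch(dataset_lengths, dataset_weights, n_samples):
    """Illustrative sketch: pair each global sample with (dataset, sample-in-dataset).

    dataset_index[i]        -> which dataset sample i is drawn from (follows the weights)
    dataset_sample_index[i] -> which sample inside that dataset to read; the modulo wrap
                               means a dataset is only revisited once it is exhausted
    """
    n_datasets = len(dataset_lengths)
    dataset_index = np.empty(n_samples, dtype=np.int64)
    dataset_sample_index = np.empty(n_samples, dtype=np.int64)
    consumed = [0] * n_datasets  # samples already drawn from each dataset

    for i in range(n_samples):
        # Greedily pick the dataset that is furthest behind its target weight
        deficits = [consumed[d] / max(i, 1) - dataset_weights[d] for d in range(n_datasets)]
        d = int(np.argmin(deficits))
        dataset_index[i] = d
        dataset_sample_index[i] = consumed[d] % dataset_lengths[d]
        consumed[d] += 1

    return dataset_index, dataset_sample_index

# e.g. two datasets of 1000 and 250 samples blended with weights 0.8 / 0.2
dataset_index, dataset_sample_index = build_index_sketch([1000, 250], [0.8, 0.2], 1250)
```
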
24 changes: 12 additions & 12 deletions examples/config_nanoset.yaml
@@ -7,25 +7,25 @@ checkpoints:
data_stages:
- data:
dataset:
dataset_path: datasets/testing_alpaca_small_input_ids.npy
dataset_folder: datasets/c4-es/tokenized
num_loading_workers: 1
seed: 42
name: General purpose training (Single dataset)
start_training_step: 1
- data:
dataset:
dataset_path:
- datasets/yelp_review_full_input_ids.npy
- datasets/testing_alpaca_small_input_ids.npy
dataset_folder:
- datasets/SlimPajama-6B/tokenized
- datasets/c4-es/tokenized
num_loading_workers: 1
seed: 42
name: Second purpose training (> 1 dataset)
start_training_step: 15
- data:
dataset:
dataset_path:
datasets/testing_alpaca_small_input_ids.npy: 0.8
datasets/yelp_review_full_input_ids.npy: 0.2
dataset_folder:
datasets/SlimPajama-6B/tokenized: 0.8
datasets/c4-es/tokenized: 0.2
num_loading_workers: 1
seed: 42
name: Third purpose training (Blended dataset)
@@ -57,7 +57,7 @@ model:
initializer_range: 0.02
intermediate_size: 64
is_llama_config: true
max_position_embeddings: 256
max_position_embeddings: 1024
num_attention_heads: 4
num_hidden_layers: 2
num_key_value_heads: 4
@@ -67,7 +67,7 @@ model:
rope_scaling: null
tie_word_embeddings: true
use_cache: true
vocab_size: 32000
vocab_size: 50257
optimizer:
accumulate_grad_in_fp32: true
clip_grad: 1.0
@@ -88,11 +88,11 @@ optimizer:
weight_decay: 0.01
zero_stage: 0
parallelism:
dp: 2
dp: 1
expert_parallel_size: 1
pp: 1
pp_engine: 1f1b
tp: 2
tp: 1
tp_linear_async_communication: true
tp_mode: REDUCE_SCATTER
profiler: null
@@ -105,6 +105,6 @@ tokens:
limit_test_batches: 0
limit_val_batches: 0
micro_batch_size: 2
sequence_length: 128
sequence_length: 1024
train_steps: 200
val_check_interval: -1
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -49,7 +49,7 @@ fast-modeling = [

nanosets = [
"transformers",
"datasets",
"datatrove[io,processing]@git+https://github.com/huggingface/datatrove",
"numba",
]

6 changes: 3 additions & 3 deletions run_train.py
@@ -143,17 +143,17 @@ def get_dataloader_from_data_stage(
elif isinstance(data.dataset, NanosetDatasetsArgs):
# Get tokenizer cardinality
tokenizer = AutoTokenizer.from_pretrained(trainer.config.tokenizer.tokenizer_name_or_path)
token_dtype = np.int32 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else np.uint16
token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2
del tokenizer
# Create Nanoset
from nanotron.data.nanoset import Nanoset

with main_rank_first(trainer.parallel_context.world_pg):
train_dataset = Nanoset(
dataset_paths=data.dataset.dataset_path,
dataset_folders=data.dataset.dataset_folder,
dataset_weights=data.dataset.dataset_weights,
sequence_length=trainer.sequence_length,
token_dtype=token_dtype,
token_size=token_size,
train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size,
random_seed=data.seed,
)
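
The switch from `token_dtype` to `token_size` keeps the same rule: tokens are stored as 2-byte values unless the tokenizer vocabulary cannot fit in `uint16`, in which case 4 bytes per token are used. A small illustrative check of that rule (not part of the diff):

```python
import numpy as np

def bytes_per_token(vocab_size: int) -> int:
    # np.iinfo(np.uint16).max == 65535, so up to 65536 distinct ids fit in 2 bytes
    return 4 if vocab_size > np.iinfo(np.uint16).max + 1 else 2

assert bytes_per_token(50257) == 2   # GPT-2 vocabulary fits in uint16
assert bytes_per_token(128256) == 4  # Llama 3 vocabulary needs 4-byte tokens
```
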
16 changes: 8 additions & 8 deletions src/nanotron/config/config.py
@@ -93,18 +93,18 @@ def __post_init__(self):

@dataclass
class NanosetDatasetsArgs:
dataset_path: Union[str, dict, List[str]]
dataset_folder: Union[str, dict, List[str]]

def __post_init__(self):
if isinstance(self.dataset_path, str): # Case 1: 1 Dataset file
self.dataset_path = [self.dataset_path]
if isinstance(self.dataset_folder, str): # Case 1: 1 Dataset folder
self.dataset_folder = [self.dataset_folder]
self.dataset_weights = [1]
elif isinstance(self.dataset_path, List): # Case 2: > 1 Dataset file
elif isinstance(self.dataset_folder, List): # Case 2: > 1 Dataset folder
self.dataset_weights = None # Set to None so we consume all the samples randomly
elif isinstance(self.dataset_path, dict): # Case 3: dict with > 1 dataset_path and weights
tmp_dataset_path = self.dataset_path.copy()
self.dataset_path = list(tmp_dataset_path.keys())
self.dataset_weights = list(tmp_dataset_path.values())
elif isinstance(self.dataset_folder, dict): # Case 3: dict with > 1 dataset_folder and weights
tmp_dataset_folder = self.dataset_folder.copy()
self.dataset_folder = list(tmp_dataset_folder.keys())
self.dataset_weights = list(tmp_dataset_folder.values())


@dataclass
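
For illustration, here is how the three accepted `dataset_folder` forms normalize after `__post_init__` (a sketch only; the import path is an assumption based on nanotron's config package layout):

```python
# Hypothetical usage sketch of the dataclass shown above.
from nanotron.config import NanosetDatasetsArgs  # assumed import path

# Case 1: a single folder becomes a one-element list with weight 1
single = NanosetDatasetsArgs(dataset_folder="datasets/c4-es/tokenized")
assert single.dataset_folder == ["datasets/c4-es/tokenized"]
assert single.dataset_weights == [1]

# Case 2: a list of folders leaves the weights as None (consume all samples randomly)
several = NanosetDatasetsArgs(
    dataset_folder=["datasets/SlimPajama-6B/tokenized", "datasets/c4-es/tokenized"]
)
assert several.dataset_weights is None

# Case 3: a dict of folder -> weight is split into parallel lists
blended = NanosetDatasetsArgs(
    dataset_folder={"datasets/SlimPajama-6B/tokenized": 0.8, "datasets/c4-es/tokenized": 0.2}
)
assert blended.dataset_folder == ["datasets/SlimPajama-6B/tokenized", "datasets/c4-es/tokenized"]
assert blended.dataset_weights == [0.8, 0.2]
```
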
80 changes: 80 additions & 0 deletions src/nanotron/data/collator.py
@@ -0,0 +1,80 @@
import dataclasses
from typing import Dict, List, Union

import numpy as np
import torch
from nanotron import distributed as dist
from nanotron.parallel.context import ParallelContext
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer


@dataclasses.dataclass
class NanosetDataCollatorForCLM:
"""
Data collator used for causal language modeling with Nanosets dataset.

- input_pp_rank: Discards last input id token
- output_pp_rank: Discards first label id token
- other pp ranks: Don't have data. Instead, we use `TensorPointer` to point to the rank having the data.
"""

sequence_length: int
input_pp_rank: int
output_pp_rank: int
parallel_context: ParallelContext

def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
# Process the case when current rank doesn't require data. We return `TensorPointer` that points to ranks having the data.
current_pp_rank = dist.get_rank(self.parallel_context.pp_pg)
if current_pp_rank not in [
self.input_pp_rank,
self.output_pp_rank,
]:
assert all(len(example) == 0 for example in examples)
return {
"input_ids": TensorPointer(group_rank=self.input_pp_rank),
"input_mask": TensorPointer(group_rank=self.input_pp_rank),
"label_ids": TensorPointer(group_rank=self.output_pp_rank),
"label_mask": TensorPointer(group_rank=self.output_pp_rank),
}

# Make sure we load only what's necessary, ie we only load a `input_ids` column.
assert all(list(example.keys()) == ["input_ids"] for example in examples)

# TODO @nouamanetazi: Is it better to have examples as np.array or torch.Tensor?
input_ids = torch.vstack([examples[i]["input_ids"] for i in range(len(examples))]) # (b, s)
batch_size, expanded_input_length = input_ids.shape

result: Dict[str, Union[torch.LongTensor, TensorPointer]] = {}

result["input_ids"] = TensorPointer(group_rank=self.input_pp_rank)
result["input_mask"] = TensorPointer(group_rank=self.input_pp_rank)
result["label_ids"] = TensorPointer(group_rank=self.output_pp_rank)
result["label_mask"] = TensorPointer(group_rank=self.output_pp_rank)

assert (
expanded_input_length == self.sequence_length + 1
), f"Samples should be of length {self.sequence_length + 1} (seq_len+1), but got {expanded_input_length}"

# Process inputs: last token is the label
if current_pp_rank == self.input_pp_rank:
result["input_ids"] = input_ids[:, :-1]
result["input_mask"] = torch.ones((batch_size, self.sequence_length), dtype=torch.bool)

# Process labels: shift them to the left
if current_pp_rank == self.output_pp_rank:
result["label_ids"] = input_ids[:, 1:]
result["label_mask"] = torch.ones((batch_size, self.sequence_length), dtype=torch.bool)

if isinstance(result["input_ids"], torch.Tensor) and result["input_ids"].shape[-1] != self.sequence_length:
raise ValueError(
f"`labels` are incorrectly preprocessed. `labels` length is {result['input_ids'].shape[-1]}, but should be"
f" {self.sequence_length}."
)
if isinstance(result["label_ids"], torch.Tensor) and result["label_ids"].shape[-1] != self.sequence_length:
raise ValueError(
f"`labels` are incorrectly preprocessed. `labels` length is {result['label_ids'].shape[-1]}, but should be"
f" {self.sequence_length}."
)

return result
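
A small usage note on the shift performed above (illustrative, outside of any pipeline-parallel setup): for a sample of `sequence_length + 1` tokens, the inputs drop the last token and the labels drop the first one, so position `i` of the labels is the next-token target for position `i` of the inputs.

```python
import torch

sequence_length = 4
tokens = torch.tensor([[10, 11, 12, 13, 14]])  # (batch, sequence_length + 1)

input_ids = tokens[:, :-1]  # [[10, 11, 12, 13]] -> fed to the model
label_ids = tokens[:, 1:]   # [[11, 12, 13, 14]] -> next-token targets

assert input_ids.shape[-1] == sequence_length
assert label_ids.shape[-1] == sequence_length
```
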
4 changes: 2 additions & 2 deletions src/nanotron/data/dataloader_builder.py
@@ -1,7 +1,7 @@
import nanotron.distributed as dist
from nanotron import logging
from nanotron.data.collator import NanosetDataCollatorForCLM
from nanotron.dataloader import (
DataCollatorForCLM,
EmptyInfiniteDataset,
get_dataloader_worker_init,
get_sampler,
@@ -32,7 +32,7 @@ def build_nanoset_dataloader(
# No need to spawn a lot of workers, we can just use main
dataloader_num_workers = 0

data_collator = DataCollatorForCLM(
data_collator = NanosetDataCollatorForCLM(
sequence_length=sequence_length,
input_pp_rank=input_pp_rank,
output_pp_rank=output_pp_rank,