Revert "Merge pull request #229 from huggingface/main"
This reverts commit 8d2014f, reversing
changes made to 971a46a.
3outeille committed Sep 5, 2024
1 parent 8d2014f commit eb5f112
Showing 23 changed files with 347 additions and 751 deletions.
62 changes: 29 additions & 33 deletions docs/nanoset.md
@@ -1,50 +1,49 @@
# Nanosets
Nanotron incorporates [`Nanosets`](../src/nanotron/data/nanoset.py), a dataset for processing tokenized documents with [`datatrove`](https://github.com/huggingface/datatrove). They allow reading tokens from one or multiple datasets and even specifying the weight of each dataset when building batches.
Nanotron incorporates [`Nanosets`](../src/nanotron/data/nanoset.py), a kind of datasets based on [numpy memory-mapped arrays](https://numpy.org/doc/stable/reference/generated/numpy.memmap.html). `Nanosets` are capable of serving batches from files containing pre-tokenized datasets. They allow reading tokens from one or multiple datasets and even specifying the weight of each dataset when building batches.
## Install
To use `Nanosets`, it's necessary to install Nanotron with the `nanosets` flavor.
```
pip install nanotron[nanosets]
pip install -e '.[nanosets]'
```
This will install the following dependencies:
- `datatrove`: To preprocess the datasets
- `transformers`: To tokenize the datasets
- `datasets`: To preprocess the datasets
- `numba`: To compile helper functions in order to speed up the creation of `Nanosets`
- `transformers`: For the tokenizers
## Data pre-processing
To use this dataset, we first need to preprocess the data using `datatrove`'s `DocumentTokenizer` pipeline. We invite you to take a look at `datatrove`, since it contains multiple features that allow you, for example, to filter out documents based on specific rules/criteria, extract text content from raw formats, or schedule the preprocessing on a Slurm cluster. We have also added a simple script capable of tokenizing datasets.

The preprocessing is done using the [`tools/preprocess_data.py`](../tools/preprocess_data.py) script. The input format can either be a Hugging Face Dataset, a path to a `.jsonl` file, or a path to a folder containing multiple `.jsonl` files. Below we show an example for processing a Hugging Face Dataset from the Hub with the Llama3 tokenizer.
To use these datasets, first, we need to preprocess the data. The input format can either be a column of a Hugging Face Dataset or a .json file containing a text sample per line. For example:

<pre>
python3 tools/preprocess_data.py \
--tokenizer-name-or-path meta-llama/Meta-Llama-3-8B \
--output-folder datasets/emotion \
--n-tasks 16 \
hf \
--dataset dair-ai/emotion \
{"src": "www.nvidia.com", "text": "The quick brown fox", "type": "Eng", "id": "0", "title": "First Part"}
{"src": "The Internet", "text": "jumps over the lazy dog", "type": "Eng", "id": "42", "title": "Second Part"}
</pre>

First, with `--tokenizer-name-or-path`, we specify a tokenizer in the same way as we do when using `AutoTokenizer.from_pretrained(...)`. Then we specify the `--output-folder` where the tokenized documents will be stored and the number of workers with `--n-tasks`. Finally, we indicate the type of dataset (whether it's a Hugging Face Dataset ["**hf**"] or in jsonl ["**jsonl**"] format) and the dataset that we want to preprocess. Check the different settings with `python3 tools/preprocess_data.py --help`, `python3 tools/preprocess_data.py hf --help` & `python3 tools/preprocess_data.py jsonl --help`.
The preprocessing is done using the [`tools/preprocess_data.py`](../tools/preprocess_data.py) script. Below we show an example for processing a corpus with the Llama2 tokenizer.

Every worker will store in `--output-folder` 3 different kinds of files:
- `*.ds` Containing the tokenized documents
- `*.ds.index` Containing the bounds of each tokenized document
- `*.ds.metadata` Containing the number of tokens and tokenizer used
<pre>
torchrun --nproc-per-node 16 tools/preprocess_data.py \
--input HuggingFaceH4/testing_alpaca_small \
--split train \
--column completion \
--output-prefix datasets/testing_alpaca_small \
--tokenizer-name-or-path openai-community/gpt2
</pre>

> [!IMPORTANT]
> Remember to specify the type of dataset to process, e.g. python3 tools/preprocess_data.py --tokenizer-name-or-path gpt2 --n-tasks 16 **jsonl** --dataset raw_datasets/c4-es-json-files
The preprocessing script has to be launched with `torchrun` in order to spawn `--nproc-per-node` workers that will preprocess the dataset concurrently. The `--input` dataset can be either a Hugging Face Dataset from the Hub or a `.json` file. The processed dataset will be stored in *`--output-prefix`_input_ids.npy*. With `--tokenizer-name-or-path`, we specify a tokenizer in the same way as we do when using `AutoTokenizer.from_pretrained(...)`.

The output will be one file named, in this case, `datasets/testing_alpaca_small_input_ids.npy`. We will then have to specify this file in the `dataset_path` field in the config file.
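
If you want to sanity-check the preprocessed file, it can be opened directly with NumPy. A minimal sketch (assuming the file was written with `np.save` as a flat 1-D array of token ids; the dtype and `sequence_length` value here are assumptions, match them to your setup):

```python
import numpy as np

sequence_length = 1024  # match tokens.sequence_length in your config

# Memory-map the preprocessed tokens instead of reading them all into RAM.
tokens = np.load("datasets/testing_alpaca_small_input_ids.npy", mmap_mode="r")
print(tokens.dtype, tokens.shape)  # e.g. uint16, (total_number_of_tokens,)

# Same formula the docs use for the number of usable samples.
print("samples per epoch:", (tokens.shape[0] - 1) // sequence_length)
```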

## Working with Nanosets

To work with `Nanosets`, we just need to configure 1 argument:
1. `dataset_folder`: This argument specifies the file or files that will compose the `Nanoset`. There are 3 ways to specify it:
1. `dataset_path`: This argument specifies the file or files that will compose the `Nanoset`. There are 3 ways to specify it:
1. If we specify a single path, we will create a `Nanoset` from a single dataset file.
```yaml
data_stages:
- name: General purpose training (Single dataset)
start_training_step: 1
data:
dataset:
dataset_folder: datasets/SlimPajama-6B
dataset_path: datasets/SlimPajama-6B_input_ids.npy
num_loading_workers: 0
seed: 1234
```
@@ -55,9 +54,9 @@ To work with `Nanosets`, we just need to configure 1 argument:
start_training_step: 15
data:
dataset:
dataset_folder:
- datasets/SlimPajama-6B
- datasets/testing_alpaca_small
dataset_path:
- datasets/SlimPajama-6B_input_ids.npy
- datasets/testing_alpaca_small_input_ids.npy
num_loading_workers: 0
seed: 1234
```
@@ -68,9 +67,9 @@ To work with `Nanosets`, we just need to configure 1 argument:
start_training_step: 25
data:
dataset:
dataset_folder:
datasets/SlimPajama-6B: 0.8
datasets/testing_alpaca_small: 0.2
dataset_path:
datasets/SlimPajama-6B_input_ids.npy: 0.8
datasets/testing_alpaca_small_input_ids.npy: 0.2
num_loading_workers: 0
seed: 1234
```
@@ -79,14 +78,11 @@ To work with `Nanosets`, we just need to configure 1 argument:

Finally, to use the `Nanosets`, launch the training with [`run_train.py`](../run_train.py).
```shell
torchrun --nproc-per-node 1 run_train.py --config examples/config_nanoset.yaml
torchrun --nproc-per-node 8 run_train.py --config configs/config_nanoset.yaml
```

## Under the hood
`Nanosets` are responsible for building samples of `sequence length + 1` tokens from the preprocessed dataset files. Although most of the extraction logic lies in `DatatroveFolderDataset`, `Nanosets` take care of the following:
1. Creating dataset mixtures from different dataset folder paths
2. Ensure that in each epoch, we consume each sample only once
3. Ensure that we never exhaust the `DataLoader`
`Nanosets` are responsible for building samples of `sequence length + 1` tokens from the preprocessed dataset files. The `dataset length` of each dataset is determined by `(dataset_number_of_tokens - 1) / sequence length`, discarding the last sample if its length < `sequence length`.

Based on the `dataset lengths`, the `dataset weights` and the `number of samples per epoch` (defined as the `sum(dataset lengths)`), we build the two indexes we need in order to extract samples from the `Nanoset` ([build_nanoset_index_helper](../src/nanotron/data/nanoset.py)):
- `dataset index`: Contains the index of the dataset from the list of `dataset paths` from which to extract the sample, respecting the established dataset weight.
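
As a rough illustration of the idea (this is not the numba-compiled `build_nanoset_index_helper` itself; the function name and the exact blending rule below are assumptions), the two indexes could be built like this:

```python
import numpy as np

def build_indexes_sketch(dataset_lengths, dataset_weights, num_samples):
    """Greedy weighted blending: at each step pick the dataset that is most
    behind its target share, then cycle through that dataset's samples."""
    weights = np.asarray(dataset_weights, dtype=np.float64)
    dataset_index = np.empty(num_samples, dtype=np.int64)         # which dataset to read from
    dataset_sample_index = np.empty(num_samples, dtype=np.int64)  # which sample inside it
    consumed = np.zeros(len(dataset_lengths), dtype=np.int64)
    for i in range(num_samples):
        d = int(np.argmax(weights * (i + 1) - consumed))            # most under-served dataset
        dataset_index[i] = d
        dataset_sample_index[i] = consumed[d] % dataset_lengths[d]  # wrap around each epoch
        consumed[d] += 1
    return dataset_index, dataset_sample_index

# Example: an 80/20 blend of two datasets with 1000 and 500 samples.
dataset_index, dataset_sample_index = build_indexes_sketch([1000, 500], [0.8, 0.2], 10)
```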
24 changes: 12 additions & 12 deletions examples/config_nanoset.yaml
@@ -7,25 +7,25 @@ checkpoints:
data_stages:
- data:
dataset:
dataset_folder: datasets/c4-es/tokenized
dataset_path: datasets/testing_alpaca_small_input_ids.npy
num_loading_workers: 1
seed: 42
name: General purpose training (Single dataset)
start_training_step: 1
- data:
dataset:
dataset_folder:
- datasets/SlimPajama-6B/tokenized
- datasets/c4-es/tokenized
dataset_path:
- datasets/yelp_review_full_input_ids.npy
- datasets/testing_alpaca_small_input_ids.npy
num_loading_workers: 1
seed: 42
name: Second purpose training (> 1 dataset)
start_training_step: 15
- data:
dataset:
dataset_folder:
datasets/SlimPajama-6B/tokenized: 0.8
datasets/c4-es/tokenized: 0.2
dataset_path:
datasets/testing_alpaca_small_input_ids.npy: 0.8
datasets/yelp_review_full_input_ids.npy: 0.2
num_loading_workers: 1
seed: 42
name: Third purpose training (Blended dataset)
@@ -57,7 +57,7 @@ model:
initializer_range: 0.02
intermediate_size: 64
is_llama_config: true
max_position_embeddings: 1024
max_position_embeddings: 256
num_attention_heads: 4
num_hidden_layers: 2
num_key_value_heads: 4
@@ -67,7 +67,7 @@ model:
rope_scaling: null
tie_word_embeddings: true
use_cache: true
vocab_size: 50257
vocab_size: 32000
optimizer:
accumulate_grad_in_fp32: true
clip_grad: 1.0
@@ -88,11 +88,11 @@ optimizer:
weight_decay: 0.01
zero_stage: 0
parallelism:
dp: 1
dp: 2
expert_parallel_size: 1
pp: 1
pp_engine: 1f1b
tp: 1
tp: 2
tp_linear_async_communication: true
tp_mode: REDUCE_SCATTER
profiler: null
@@ -105,6 +105,6 @@ tokens:
limit_test_batches: 0
limit_val_batches: 0
micro_batch_size: 2
sequence_length: 1024
sequence_length: 128
train_steps: 200
val_check_interval: -1
12 changes: 0 additions & 12 deletions examples/mamba/README.md
@@ -18,18 +18,6 @@ pip install -r requirements.txt

> https://wandb.ai/bouteille/test/reports/Mamba-loss--Vmlldzo2OTgwNDM5
## Bug related to nanotron
Encountered the following issue when running `train_mamba.sh`:
```
causal_conv1d_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZNK3c1017SymbolicShapeMeta18init_is_contiguousEv
```
Solved this by doing:
pip uninstall mamba-ssm
pip install causal_conv1d==1.1.1
pip install mamba-ssm --no-cache-dir
https://github.com/state-spaces/mamba/issues/169


## Credits
Credits to the following repositories from which the code was adapted:
- https://github.com/state-spaces/mamba
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -49,7 +49,7 @@ fast-modeling = [

nanosets = [
"transformers",
"datatrove[io,processing]@git+https://github.com/huggingface/datatrove",
"datasets",
"numba",
]

6 changes: 3 additions & 3 deletions run_train.py
@@ -143,17 +143,17 @@ def get_dataloader_from_data_stage(
elif isinstance(data.dataset, NanosetDatasetsArgs):
# Get tokenizer cardinality
tokenizer = AutoTokenizer.from_pretrained(trainer.config.tokenizer.tokenizer_name_or_path)
token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2
token_dtype = np.int32 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else np.uint16
del tokenizer
# Create Nanoset
from nanotron.data.nanoset import Nanoset

with main_rank_first(trainer.parallel_context.world_pg):
train_dataset = Nanoset(
dataset_folders=data.dataset.dataset_folder,
dataset_paths=data.dataset.dataset_path,
dataset_weights=data.dataset.dataset_weights,
sequence_length=trainer.sequence_length,
token_size=token_size,
token_dtype=token_dtype,
train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size,
random_seed=data.seed,
)
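
For context, the `token_size` / `token_dtype` line above only checks whether every token id produced by the tokenizer fits in 16 bits. A paraphrased sketch of that decision (not the project's code):

```python
import numpy as np

def pick_token_dtype(vocab_size: int):
    # np.uint16 covers ids 0..65535; larger vocabularies need a wider type.
    return np.uint16 if vocab_size <= np.iinfo(np.uint16).max + 1 else np.int32

assert pick_token_dtype(50257) == np.uint16    # GPT-2-sized vocabulary
assert pick_token_dtype(128256) == np.int32    # Llama-3-sized vocabulary
```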
16 changes: 10 additions & 6 deletions src/nanotron/config/config.py
@@ -93,20 +93,25 @@ def __post_init__(self):

@dataclass
class NanosetDatasetsArgs:
dataset_folder: Union[str, List[str]]
dataset_weights: Optional[List[float]] = None
dataset_path: Union[str, dict, List[str]]

def __post_init__(self):
if isinstance(self.dataset_folder, str): # Case 1: 1 Dataset folder
self.dataset_folder = [self.dataset_folder]
if isinstance(self.dataset_path, str): # Case 1: 1 Dataset file
self.dataset_path = [self.dataset_path]
self.dataset_weights = [1]
elif isinstance(self.dataset_path, List): # Case 2: > 1 Dataset file
self.dataset_weights = None # Set to None so we consume all the samples randomly
elif isinstance(self.dataset_path, dict): # Case 3: dict with > 1 dataset_path and weights
tmp_dataset_path = self.dataset_path.copy()
self.dataset_path = list(tmp_dataset_path.keys())
self.dataset_weights = list(tmp_dataset_path.values())
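
A quick usage sketch of the normalization done by this `__post_init__` (the import path follows the file shown in this diff; the dataset file names are placeholders):

```python
from nanotron.config.config import NanosetDatasetsArgs

# A {path: weight} dict is split into parallel path / weight lists;
# a single string becomes a one-element list with weight 1.
args = NanosetDatasetsArgs(dataset_path={"datasets/a_input_ids.npy": 0.8,
                                         "datasets/b_input_ids.npy": 0.2})
assert args.dataset_path == ["datasets/a_input_ids.npy", "datasets/b_input_ids.npy"]
assert args.dataset_weights == [0.8, 0.2]
```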


@dataclass
class DataArgs:
"""Arguments related to the data and data files processing"""

dataset: Optional[Union[PretrainDatasetsArgs, NanosetDatasetsArgs]]
dataset: Union[PretrainDatasetsArgs, NanosetDatasetsArgs]
seed: Optional[int]
num_loading_workers: Optional[int] = 1

@@ -140,7 +145,6 @@ class CheckpointsArgs:
checkpoints_path: Path
checkpoint_interval: int
save_initial_state: Optional[bool] = False
save_final_state: Optional[bool] = False
resume_checkpoint_path: Optional[Path] = None
checkpoints_path_is_shared_file_system: Optional[bool] = False

3 changes: 0 additions & 3 deletions src/nanotron/config/models_config.py
@@ -48,9 +48,6 @@ class LlamaConfig:
rms_norm_eps: float = 1e-6
rope_scaling: Optional[dict] = None
rope_theta: float = 10000.0
rope_interleaved: bool = (
False # The default value has been True, but for loading Llama3 checkpoints you have to set it to False
)
tie_word_embeddings: bool = False
use_cache: bool = True
vocab_size: int = 32000
2 changes: 0 additions & 2 deletions src/nanotron/config/parallelism_config.py
@@ -34,8 +34,6 @@ class ParallelismArgs:
tp_linear_async_communication: Optional[bool] = None
recompute_layer: bool = False

tp_recompute_allgather: bool = True

expert_parallel_size: int = 1

def __post_init__(self):
80 changes: 0 additions & 80 deletions src/nanotron/data/collator.py

This file was deleted.

4 changes: 2 additions & 2 deletions src/nanotron/data/dataloader_builder.py
@@ -1,7 +1,7 @@
import nanotron.distributed as dist
from nanotron import logging
from nanotron.data.collator import NanosetDataCollatorForCLM
from nanotron.dataloader import (
DataCollatorForCLM,
EmptyInfiniteDataset,
get_dataloader_worker_init,
get_sampler,
@@ -32,7 +32,7 @@ def build_nanoset_dataloader(
# No need to spawn a lot of workers, we can just use main
dataloader_num_workers = 0

data_collator = NanosetDataCollatorForCLM(
data_collator = DataCollatorForCLM(
sequence_length=sequence_length,
input_pp_rank=input_pp_rank,
output_pp_rank=output_pp_rank,
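
For reference, a causal-LM collator of this kind turns each `sequence length + 1` token window into shifted input/label pairs. A generic sketch of that step (this is not the actual `DataCollatorForCLM` signature, which also handles the pipeline-parallel ranks shown above):

```python
import numpy as np

def collate_clm_sketch(batch_tokens: np.ndarray) -> dict:
    # batch_tokens: (batch_size, sequence_length + 1) token ids
    return {
        "input_ids": batch_tokens[:, :-1],  # positions the model predicts from
        "labels": batch_tokens[:, 1:],      # next-token targets
    }

batch = np.arange(2 * 9).reshape(2, 9)  # toy batch with sequence_length = 8
out = collate_clm_sketch(batch)
assert out["input_ids"].shape == (2, 8) and out["labels"].shape == (2, 8)
```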