Skip to content

Commit

Permalink
Update on "[cp] apply fsdp to model when CP is enabled without DP for…
Browse files Browse the repository at this point in the history
… correct loss and lower mem usage"


**Summary**
Previously CP forgot to shard the model via `apply_fsdp` when DP is not combined with CP. This leads to high peak memory usage and diverging loss.

**Test**
1. modify `train_configs/llama3_8b.toml`
```
steps = 20
context_parallel_degree = 8
```
2.  run training on 8xH100 GPUs
`CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh`
Before: CUDA OutOfMemory 
After: successful 20-steps training


[ghstack-poisoned]
  • Loading branch information
XilunWu committed Dec 11, 2024
2 parents 666a885 + 24a798b commit 6c353cd
Show file tree
Hide file tree
Showing 9 changed files with 160 additions and 90 deletions.
2 changes: 1 addition & 1 deletion .ci/docker/ubuntu/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
ARG OS_VERSION

FROM nvidia/cuda:12.1.0-cudnn8-runtime-ubuntu${OS_VERSION}
FROM nvidia/cuda:12.4.1-cudnn-runtime-ubuntu${OS_VERSION}

ARG OS_VERSION

Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/integration_test_4gpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ jobs:
pip config --user set global.progress_bar off
python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu124
# install torchtitan to test the files in ./scripts, currently just for memory estimation
python -m pip install -e .
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/integration_test_8gpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,6 @@ jobs:
pip config --user set global.progress_bar off
python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu124
mkdir artifacts-to-be-uploaded
python ./test_runner.py artifacts-to-be-uploaded --ngpu 8
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ You may want to see how the model is defined or how parallelism techniques are a
4. `torch.compile` support
5. [Float8](https://discuss.pytorch.org/t/distributed-w-torchtitan-enabling-float8-all-gather-in-fsdp2/209323) support ([how-to](docs/float8.md))
6. DDP and HSDP
7. Checkpointable data-loading, with the C4 dataset pre-configured (144M entries)
7. Checkpointable data-loading, with the C4 dataset pre-configured (144M entries) and support for [custom datasets](docs/datasets.md)
8. Learning rate scheduler, meta-init, (optional) fused RMSNorm kernel
9. Loss, GPU memory, throughput (tokens/sec), and MFU displayed and logged via [Tensorboard or Weights & Biases](/docs/metrics.md)
10. Debugging tools including CPU/GPU profiling, [memory profiling](docs/memory_profiler.md), [Flight Recorder](#debugging), etc.
Expand All @@ -73,7 +73,7 @@ We report our [Performance](docs/performance.md) verified on 64/128 GPUs.
git clone https://github.com/pytorch/torchtitan
cd torchtitan
pip install -r requirements.txt
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 --force-reinstall # or cu118
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu124 --force-reinstall
```

### Downloading a tokenizer
Expand Down
3 changes: 2 additions & 1 deletion docs/checkpoint.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ export_dtype = "bfloat16"
enable_checkpoint = true
folder = "checkpoint"
interval_type = "steps"
interval = 5
interval = 10
load_step = 5
model_weights_only = true
export_dtype = "bfloat16"
```
Expand Down
74 changes: 74 additions & 0 deletions docs/datasets.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Custom Datasets in TorchTitan

TorchTitan is designed to work seamlessly with most HuggingFace datasets. While we provide the C4 dataset for numerics and convergence testing, you can easily add support for your own datasets. Here's how to do it using Wikipedia as an example.

## Quick Start
Locate the dataset configuration file:
```
torchtitan/datasets/hf_datasets/hf_datasets.py
```

## Adding Your Dataset
You'll need to add three components:
1. A dataset loader function
2. A sample processor function
3. A dataset configuration entry

### 1. Define Dataset Loader
Create a function that specifies how to load your dataset:

```python
def load_wikipedia_dataset(dataset_path: str, **kwargs):
"""Load Wikipedia dataset with specific configuration."""
logger.info("Loading Wikipedia dataset...")
return load_dataset(
dataset_path,
name="20220301.en",
split="train",
streaming=True,
trust_remote_code=True,
)
```

### 2. Define Sample Processor
Create a function that processes individual samples from your dataset:

```python
def process_wikipedia_text(sample: Dict[str, Any]) -> str:
"""Process Wikipedia dataset sample text."""
return f"{sample['title']}\n\n{sample['text']}"
```

### 3. Register Your Dataset
Add your dataset configuration to the DATASETS dictionary:

```python
DATASETS = {
# ... existing datasets ...
"wikipedia": DatasetConfig(
path="wikipedia", # default HuggingFace dataset path
loader=load_wikipedia_dataset,
text_processor=process_wikipedia_text,
),
}
```

### 4. Configure Your Training
In your training configuration file (`.toml`), set your dataset:

```toml
dataset = "wikipedia"
```

That's it! Your custom dataset is now ready to use with TorchTitan.

## Key Points
- The DatasetConfig contains all necessary components for a dataset:
- `path`: The default path to the dataset (can be overridden during training)
- `loader`: Function to load the dataset
- `text_processor`: Function to process individual samples
- The loader function should return a HuggingFace dataset object
- The processor function should return a string that combines the relevant fields from your dataset
- Use `streaming=True` for large datasets to manage memory efficiently

Now you can start training with your custom dataset!
7 changes: 6 additions & 1 deletion torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,12 @@ def __init__(self):
0 is the default value.
""",
)

self.parser.add_argument(
"--checkpoint.load_step",
type=int,
default=-1,
help="Load the checkpoint at the specified step. If -1, load the latest checkpoint.",
)
# activation checkpointing configs
self.parser.add_argument(
"--activation_checkpoint.mode",
Expand Down
153 changes: 72 additions & 81 deletions torchtitan/datasets/hf_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
# LICENSE file in the root directory of this source tree.

import pickle
from typing import Any, Dict, List, Optional
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional

import torch
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset

from torchdata.stateful_dataloader import StatefulDataLoader

from torchtitan.datasets.tokenizer import Tokenizer
Expand All @@ -19,49 +19,56 @@
from datasets import Dataset, load_dataset
from datasets.distributed import split_dataset_by_node

# map from dataset name to a local directory, or
# a dataset repository on the HF hub
_supported_datasets = {
"c4_test": "test/assets/c4_test",
"c4": "allenai/c4",

def _load_c4_dataset(dataset_path: str):
"""Load C4 dataset with default configuration."""
return load_dataset(dataset_path, name="en", split="train", streaming=True)


def _process_c4_text(sample: Dict[str, Any]) -> str:
"""Process C4 dataset sample text."""
return sample["text"]


@dataclass
class DatasetConfig:
path: str
loader: Callable
text_processor: Callable


# Add your dataset here here - more information at docs/datasets.md
DATASETS = {
"c4": DatasetConfig(
path="allenai/c4",
loader=_load_c4_dataset,
text_processor=_process_c4_text,
),
"c4_test": DatasetConfig(
path="test/assets/c4_test",
loader=lambda path: load_dataset(path, split="train"),
text_processor=_process_c4_text,
),
}


class HuggingFaceDataset(IterableDataset, Stateful):
"""PyTorch Representation of the HuggingFace Dataset.
Args:
dataset_name (str): name of the dataset to load
dataset_path (Optional[str]):
Path to the dataset in the file system. If provided, data will be loaded
from this path instead of downloaded.
tokenizer (Tokenizer):
Tokenizer used to encode data. Tokenize must implement an `encode` and `decode` method.
seq_len (int): max sequence length
world_size (int): number of data parallel processes participating in training
rank (int): rank of the current data parallel process
infinite (bool): whether to loop infinitely over the dataset
We currently support the c4 dataset, and a subset of it for testing purposes:
c4_test (2K training entries)
c4 (177M training entries - this dataset is streamed due to the size)
>> c4 (EN) <<:
c4 cleaned, English version
Data input format (c4):
{
'url': 'https://klyq.com/beginners-bbq-class-taking-place-in-missoula/',
'text': 'Beginners BBQ Class Taking Place in Missoula!\nDo you want to get better at ...',
'timestamp': '2019-04-25T12:57:54Z'
}
Example use (c4):
>>> ds = HuggingFaceDataset(dataset_name="c4", dataset_path=None, tokenizer=tokenizer)
>>> for batch in Dataloader(ds, batch_size=8):
print(f"Batch size: {len(batch)}")
Batch size: 8
"""
def _validate_dataset(
dataset_name: str, dataset_path: str = None
) -> tuple[str, Callable, Callable]:
"""Validate dataset name and path."""
if dataset_name not in DATASETS:
raise ValueError(
f"Dataset {dataset_name} is not supported. "
f"Supported datasets are: {list(DATASETS.keys())}"
)

config = DATASETS[dataset_name]
path = dataset_path or config.path
logger.info(f"Preparing {dataset_name} dataset from {path}")
return path, config.loader, config.text_processor


class HuggingFaceDataset(IterableDataset, Stateful):
def __init__(
self,
dataset_name: str,
Expand All @@ -72,47 +79,41 @@ def __init__(
rank: int = 0,
infinite: bool = False,
) -> None:
# allow user to pass in a (local or HF hub) path to use unsupported datasets
if dataset_name not in _supported_datasets:
if dataset_path:
logger.warning(
f"Dataset {dataset_name} is not tested or verfied. "
f"Recommended datasets are: {list(_supported_datasets.keys())}"
)
else:
raise ValueError(
f"Dataset {dataset_name} is not supported. "
f"Supported datasets are: {list(_supported_datasets.keys())}"
)

if not dataset_path:
dataset_path = _supported_datasets[dataset_name]
logger.info(f"Preparing {dataset_name} dataset from {dataset_path}")

if dataset_name == "c4":
# c4 is huge, and requires both streaming and language selection
# (we default to en)
ds = load_dataset(dataset_path, name="en", split="train", streaming=True)
else:
ds = load_dataset(dataset_path, split="train")

# TODO: support shuffling
# Force lowercase for consistent comparison
dataset_name = dataset_name.lower()

path, dataset_loader, text_processor = _validate_dataset(
dataset_name, dataset_path
)
ds = dataset_loader(path)

self.dataset_name = dataset_name
self._data = split_dataset_by_node(ds, rank, world_size)
self._tokenizer = tokenizer
self.seq_len = seq_len
self.infinite = infinite
self._text_processor = text_processor

# variables for checkpointing
# Variables for checkpointing
self._sample_idx = 0
self._all_tokens: List[int] = []

def _get_data_iter(self):
if self._sample_idx == 0:
return iter(self._data)

if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
return iter([])

return iter(self._data.skip(self._sample_idx))

def __iter__(self):
max_buffer_token_len = 1 + self.seq_len

while True:
for sample in self._get_data_iter():
sample_text = sample["text"]
# Use the dataset-specific text processor
sample_text = self._text_processor(sample)
sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
self._all_tokens.extend(sample_tokens)
self._sample_idx += 1
Expand All @@ -133,16 +134,6 @@ def __iter__(self):
self._sample_idx = 0
logger.warning(f"Dataset {self.dataset_name} is being re-looped")

def _get_data_iter(self):
if self._sample_idx == 0:
return iter(self._data)

# As skipping to the end throws an error in case of map-style dataset, return an empty iterator
if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
return iter([])

return iter(self._data.skip(self._sample_idx))

def load_state_dict(self, state_dict):
self._sample_idx = state_dict["sample_idx"]
self._all_tokens = state_dict["token_buffer"]
Expand Down Expand Up @@ -184,12 +175,12 @@ def build_hf_data_loader(
tokenizer: Tokenizer,
batch_size: int,
seq_len: int,
world_size,
rank,
world_size: int,
rank: int,
infinite: bool = True,
):
"""Build a data loader for HuggingFace datasets."""
hf_ds = HuggingFaceDataset(
dataset_name, dataset_path, tokenizer, seq_len, world_size, rank, infinite
)

return DPAwareDataLoader(rank, hf_ds, batch_size=batch_size)
3 changes: 1 addition & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ def main(job_config: JobConfig):
# build tokenizer
tokenizer_type = model_name_to_tokenizer[model_name]
tokenizer = build_tokenizer(tokenizer_type, job_config.model.tokenizer_path)

# build dataloader
data_loader = build_hf_data_loader(
job_config.training.dataset,
Expand Down Expand Up @@ -206,7 +205,7 @@ def loss_fn(pred, labels):
logger.info("Created seed checkpoint")
return

checkpoint_loaded = checkpoint.load()
checkpoint_loaded = checkpoint.load(step=job_config.checkpoint.load_step)

if parallel_dims.pp_enabled and not checkpoint_loaded:
# TODO: fix this by allowing each rank to set their own seed
Expand Down

0 comments on commit 6c353cd

Please sign in to comment.