From fbd07ecb35feb4dd08064ceee5fe06399e911304 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 4 Dec 2024 12:45:30 -0800 Subject: [PATCH 1/4] Migrate CI to CUDA 12.4 Pytorch stopped releasing cu121 nightlies. ghstack-source-id: 39850c42c5ec0a8898a208718f35392e98a427f9 Pull Request resolved: https://github.com/pytorch/torchtitan/pull/718 --- .ci/docker/ubuntu/Dockerfile | 2 +- .github/workflows/integration_test_4gpu.yaml | 2 +- .github/workflows/integration_test_8gpu.yaml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.ci/docker/ubuntu/Dockerfile b/.ci/docker/ubuntu/Dockerfile index ba276c29..deb69a50 100644 --- a/.ci/docker/ubuntu/Dockerfile +++ b/.ci/docker/ubuntu/Dockerfile @@ -1,6 +1,6 @@ ARG OS_VERSION -FROM nvidia/cuda:12.1.0-cudnn8-runtime-ubuntu${OS_VERSION} +FROM nvidia/cuda:12.4.1-cudnn-runtime-ubuntu${OS_VERSION} ARG OS_VERSION diff --git a/.github/workflows/integration_test_4gpu.yaml b/.github/workflows/integration_test_4gpu.yaml index 72cdb8af..6c506887 100644 --- a/.github/workflows/integration_test_4gpu.yaml +++ b/.github/workflows/integration_test_4gpu.yaml @@ -37,7 +37,7 @@ jobs: pip config --user set global.progress_bar off - python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 + python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu124 # install torchtitan to test the files in ./scripts, currently just for memory estimation python -m pip install -e . diff --git a/.github/workflows/integration_test_8gpu.yaml b/.github/workflows/integration_test_8gpu.yaml index 0d8c79db..0b8f2a1f 100644 --- a/.github/workflows/integration_test_8gpu.yaml +++ b/.github/workflows/integration_test_8gpu.yaml @@ -36,6 +36,6 @@ jobs: pip config --user set global.progress_bar off - python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 + python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu124 mkdir artifacts-to-be-uploaded python ./test_runner.py artifacts-to-be-uploaded --ngpu 8 From 0ed31f55977dd9f6bb476736b432d00750c741c0 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 4 Dec 2024 13:24:56 -0800 Subject: [PATCH 2/4] Update readme with cu124 version ghstack-source-id: 6eb7a87df8b3585e53993684d0b9682aeb99cfe5 Pull Request resolved: https://github.com/pytorch/torchtitan/pull/719 --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 2468d4a0..714d2ae2 100644 --- a/README.md +++ b/README.md @@ -73,7 +73,7 @@ We report our [Performance](docs/performance.md) verified on 64/128 GPUs. git clone https://github.com/pytorch/torchtitan cd torchtitan pip install -r requirements.txt -pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 --force-reinstall # or cu118 +pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu124 --force-reinstall ``` ### Downloading a tokenizer From afa82294f633e4a430d5c4c0c63ae1328e7193cc Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Thu, 5 Dec 2024 11:56:22 -0800 Subject: [PATCH 3/4] Add checkpoint load step (#716) Fixes https://github.com/pytorch/torchtitan/issues/662 followed @fegin advice to test this and indeed things are working https://gist.github.com/msaroufim/2925b3f17b631bf370a49f185b6e169d ``` [checkpoint] enable_checkpoint = true folder = "checkpoint" interval_type = "steps" interval = 10 model_weights_only = false export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] load_step = 10 ``` --- docs/checkpoint.md | 3 ++- torchtitan/config_manager.py | 7 ++++++- train.py | 2 +- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/docs/checkpoint.md b/docs/checkpoint.md index 3f2c8cb2..72e6a021 100644 --- a/docs/checkpoint.md +++ b/docs/checkpoint.md @@ -49,7 +49,8 @@ export_dtype = "bfloat16" enable_checkpoint = true folder = "checkpoint" interval_type = "steps" -interval = 5 +interval = 10 +load_step = 5 model_weights_only = true export_dtype = "bfloat16" ``` diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index e7bca6f1..d17e263d 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -480,7 +480,12 @@ def __init__(self): 0 is the default value. """, ) - + self.parser.add_argument( + "--checkpoint.load_step", + type=int, + default=-1, + help="Load the checkpoint at the specified step. If -1, load the latest checkpoint.", + ) # activation checkpointing configs self.parser.add_argument( "--activation_checkpoint.mode", diff --git a/train.py b/train.py index 9e8b1fa8..53c813f1 100644 --- a/train.py +++ b/train.py @@ -206,7 +206,7 @@ def loss_fn(pred, labels): logger.info("Created seed checkpoint") return - checkpoint_loaded = checkpoint.load() + checkpoint_loaded = checkpoint.load(step=job_config.checkpoint.load_step) if parallel_dims.pp_enabled and not checkpoint_loaded: # TODO: fix this by allowing each rank to set their own seed From 7281e0be8feeb607f3c3f12cc3ceaafed87912c9 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Thu, 5 Dec 2024 15:40:25 -0800 Subject: [PATCH 4/4] Custom Dataset refactoring + docs (#715) EDIT: removed the specific new functions in hf_datasets.py and kept most of the doc changes and will not go for a registration based API Fixes https://github.com/pytorch/torchtitan/issues/311 This PR describes the status quo of how new datasets should be registered today, in that there's the implicit assumption that people are installing torchtitan from source and updating hf_datasets.py to support new datasets. As an example I passed in the wikipedia dataset The main "nice" thing about this PR is that `class HuggingFaceDataset` is now agnostic to the c4 dataset which makes it easier for new people to add datasets without reading the rest of the file There's another direction this PR could have went in which was to allow custom dataset registration, the benefit is people can support new datasets without installing titan from source but registration apis can feel kinda "bureaucratic" and presumably people would need to register the dataset somewhere, probably `train.py`? Not totally sure which is more in line with the repo's goals so opening this PR to discuss ```python def register_dataset( name: str, loader: Callable[[str, Dict[str, Any]], Any], processor: Callable[[Dict[str, Any]], str], path: Optional[str] = None, ) -> None: DATASET_LOADERS[name] = loader DATASET_TEXT_PROCESSORS[name] = processor def wikipedia_loader(dataset_path: str, **kwargs): return load_dataset( dataset_path, name="20220301.en", split="train", streaming=True, trust_remote_code=True, ) def wikipedia_processor(sample: Dict[str, Any]) -> str: return f"{sample['title']}\n\n{sample['text']}" register_dataset( name="wikipedia", loader=wikipedia_loader, processor=wikipedia_processor, path="wikipedia" ) ``` --- README.md | 2 +- docs/datasets.md | 74 ++++++++++++++ torchtitan/datasets/hf_datasets.py | 153 ++++++++++++++--------------- train.py | 1 - 4 files changed, 147 insertions(+), 83 deletions(-) create mode 100644 docs/datasets.md diff --git a/README.md b/README.md index 714d2ae2..40a0b4a9 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,7 @@ You may want to see how the model is defined or how parallelism techniques are a 4. `torch.compile` support 5. [Float8](https://discuss.pytorch.org/t/distributed-w-torchtitan-enabling-float8-all-gather-in-fsdp2/209323) support ([how-to](docs/float8.md)) 6. DDP and HSDP -7. Checkpointable data-loading, with the C4 dataset pre-configured (144M entries) +7. Checkpointable data-loading, with the C4 dataset pre-configured (144M entries) and support for [custom datasets](docs/datasets.md) 8. Learning rate scheduler, meta-init, (optional) fused RMSNorm kernel 9. Loss, GPU memory, throughput (tokens/sec), and MFU displayed and logged via [Tensorboard or Weights & Biases](/docs/metrics.md) 10. Debugging tools including CPU/GPU profiling, [memory profiling](docs/memory_profiler.md), [Flight Recorder](#debugging), etc. diff --git a/docs/datasets.md b/docs/datasets.md new file mode 100644 index 00000000..e13da2dd --- /dev/null +++ b/docs/datasets.md @@ -0,0 +1,74 @@ +# Custom Datasets in TorchTitan + +TorchTitan is designed to work seamlessly with most HuggingFace datasets. While we provide the C4 dataset for numerics and convergence testing, you can easily add support for your own datasets. Here's how to do it using Wikipedia as an example. + +## Quick Start +Locate the dataset configuration file: +``` +torchtitan/datasets/hf_datasets/hf_datasets.py +``` + +## Adding Your Dataset +You'll need to add three components: +1. A dataset loader function +2. A sample processor function +3. A dataset configuration entry + +### 1. Define Dataset Loader +Create a function that specifies how to load your dataset: + +```python +def load_wikipedia_dataset(dataset_path: str, **kwargs): + """Load Wikipedia dataset with specific configuration.""" + logger.info("Loading Wikipedia dataset...") + return load_dataset( + dataset_path, + name="20220301.en", + split="train", + streaming=True, + trust_remote_code=True, + ) +``` + +### 2. Define Sample Processor +Create a function that processes individual samples from your dataset: + +```python +def process_wikipedia_text(sample: Dict[str, Any]) -> str: + """Process Wikipedia dataset sample text.""" + return f"{sample['title']}\n\n{sample['text']}" +``` + +### 3. Register Your Dataset +Add your dataset configuration to the DATASETS dictionary: + +```python +DATASETS = { + # ... existing datasets ... + "wikipedia": DatasetConfig( + path="wikipedia", # default HuggingFace dataset path + loader=load_wikipedia_dataset, + text_processor=process_wikipedia_text, + ), +} +``` + +### 4. Configure Your Training +In your training configuration file (`.toml`), set your dataset: + +```toml +dataset = "wikipedia" +``` + +That's it! Your custom dataset is now ready to use with TorchTitan. + +## Key Points +- The DatasetConfig contains all necessary components for a dataset: + - `path`: The default path to the dataset (can be overridden during training) + - `loader`: Function to load the dataset + - `text_processor`: Function to process individual samples +- The loader function should return a HuggingFace dataset object +- The processor function should return a string that combines the relevant fields from your dataset +- Use `streaming=True` for large datasets to manage memory efficiently + +Now you can start training with your custom dataset! diff --git a/torchtitan/datasets/hf_datasets.py b/torchtitan/datasets/hf_datasets.py index 9db036b0..745cf40f 100644 --- a/torchtitan/datasets/hf_datasets.py +++ b/torchtitan/datasets/hf_datasets.py @@ -5,12 +5,12 @@ # LICENSE file in the root directory of this source tree. import pickle -from typing import Any, Dict, List, Optional +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional import torch from torch.distributed.checkpoint.stateful import Stateful from torch.utils.data import IterableDataset - from torchdata.stateful_dataloader import StatefulDataLoader from torchtitan.datasets.tokenizer import Tokenizer @@ -19,49 +19,56 @@ from datasets import Dataset, load_dataset from datasets.distributed import split_dataset_by_node -# map from dataset name to a local directory, or -# a dataset repository on the HF hub -_supported_datasets = { - "c4_test": "test/assets/c4_test", - "c4": "allenai/c4", + +def _load_c4_dataset(dataset_path: str): + """Load C4 dataset with default configuration.""" + return load_dataset(dataset_path, name="en", split="train", streaming=True) + + +def _process_c4_text(sample: Dict[str, Any]) -> str: + """Process C4 dataset sample text.""" + return sample["text"] + + +@dataclass +class DatasetConfig: + path: str + loader: Callable + text_processor: Callable + + +# Add your dataset here here - more information at docs/datasets.md +DATASETS = { + "c4": DatasetConfig( + path="allenai/c4", + loader=_load_c4_dataset, + text_processor=_process_c4_text, + ), + "c4_test": DatasetConfig( + path="test/assets/c4_test", + loader=lambda path: load_dataset(path, split="train"), + text_processor=_process_c4_text, + ), } -class HuggingFaceDataset(IterableDataset, Stateful): - """PyTorch Representation of the HuggingFace Dataset. - - Args: - dataset_name (str): name of the dataset to load - dataset_path (Optional[str]): - Path to the dataset in the file system. If provided, data will be loaded - from this path instead of downloaded. - tokenizer (Tokenizer): - Tokenizer used to encode data. Tokenize must implement an `encode` and `decode` method. - seq_len (int): max sequence length - world_size (int): number of data parallel processes participating in training - rank (int): rank of the current data parallel process - infinite (bool): whether to loop infinitely over the dataset - - We currently support the c4 dataset, and a subset of it for testing purposes: - c4_test (2K training entries) - c4 (177M training entries - this dataset is streamed due to the size) - - >> c4 (EN) <<: - c4 cleaned, English version - Data input format (c4): - { - 'url': 'https://klyq.com/beginners-bbq-class-taking-place-in-missoula/', - 'text': 'Beginners BBQ Class Taking Place in Missoula!\nDo you want to get better at ...', - 'timestamp': '2019-04-25T12:57:54Z' - } - - Example use (c4): - >>> ds = HuggingFaceDataset(dataset_name="c4", dataset_path=None, tokenizer=tokenizer) - >>> for batch in Dataloader(ds, batch_size=8): - print(f"Batch size: {len(batch)}") - Batch size: 8 - """ +def _validate_dataset( + dataset_name: str, dataset_path: str = None +) -> tuple[str, Callable, Callable]: + """Validate dataset name and path.""" + if dataset_name not in DATASETS: + raise ValueError( + f"Dataset {dataset_name} is not supported. " + f"Supported datasets are: {list(DATASETS.keys())}" + ) + + config = DATASETS[dataset_name] + path = dataset_path or config.path + logger.info(f"Preparing {dataset_name} dataset from {path}") + return path, config.loader, config.text_processor + +class HuggingFaceDataset(IterableDataset, Stateful): def __init__( self, dataset_name: str, @@ -72,47 +79,41 @@ def __init__( rank: int = 0, infinite: bool = False, ) -> None: - # allow user to pass in a (local or HF hub) path to use unsupported datasets - if dataset_name not in _supported_datasets: - if dataset_path: - logger.warning( - f"Dataset {dataset_name} is not tested or verfied. " - f"Recommended datasets are: {list(_supported_datasets.keys())}" - ) - else: - raise ValueError( - f"Dataset {dataset_name} is not supported. " - f"Supported datasets are: {list(_supported_datasets.keys())}" - ) - - if not dataset_path: - dataset_path = _supported_datasets[dataset_name] - logger.info(f"Preparing {dataset_name} dataset from {dataset_path}") - - if dataset_name == "c4": - # c4 is huge, and requires both streaming and language selection - # (we default to en) - ds = load_dataset(dataset_path, name="en", split="train", streaming=True) - else: - ds = load_dataset(dataset_path, split="train") - - # TODO: support shuffling + # Force lowercase for consistent comparison + dataset_name = dataset_name.lower() + + path, dataset_loader, text_processor = _validate_dataset( + dataset_name, dataset_path + ) + ds = dataset_loader(path) + self.dataset_name = dataset_name self._data = split_dataset_by_node(ds, rank, world_size) self._tokenizer = tokenizer self.seq_len = seq_len self.infinite = infinite + self._text_processor = text_processor - # variables for checkpointing + # Variables for checkpointing self._sample_idx = 0 self._all_tokens: List[int] = [] + def _get_data_iter(self): + if self._sample_idx == 0: + return iter(self._data) + + if isinstance(self._data, Dataset) and self._sample_idx == len(self._data): + return iter([]) + + return iter(self._data.skip(self._sample_idx)) + def __iter__(self): max_buffer_token_len = 1 + self.seq_len while True: for sample in self._get_data_iter(): - sample_text = sample["text"] + # Use the dataset-specific text processor + sample_text = self._text_processor(sample) sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True) self._all_tokens.extend(sample_tokens) self._sample_idx += 1 @@ -133,16 +134,6 @@ def __iter__(self): self._sample_idx = 0 logger.warning(f"Dataset {self.dataset_name} is being re-looped") - def _get_data_iter(self): - if self._sample_idx == 0: - return iter(self._data) - - # As skipping to the end throws an error in case of map-style dataset, return an empty iterator - if isinstance(self._data, Dataset) and self._sample_idx == len(self._data): - return iter([]) - - return iter(self._data.skip(self._sample_idx)) - def load_state_dict(self, state_dict): self._sample_idx = state_dict["sample_idx"] self._all_tokens = state_dict["token_buffer"] @@ -184,12 +175,12 @@ def build_hf_data_loader( tokenizer: Tokenizer, batch_size: int, seq_len: int, - world_size, - rank, + world_size: int, + rank: int, infinite: bool = True, ): + """Build a data loader for HuggingFace datasets.""" hf_ds = HuggingFaceDataset( dataset_name, dataset_path, tokenizer, seq_len, world_size, rank, infinite ) - return DPAwareDataLoader(rank, hf_ds, batch_size=batch_size) diff --git a/train.py b/train.py index 53c813f1..c1e8fffe 100644 --- a/train.py +++ b/train.py @@ -86,7 +86,6 @@ def main(job_config: JobConfig): # build tokenizer tokenizer_type = model_name_to_tokenizer[model_name] tokenizer = build_tokenizer(tokenizer_type, job_config.model.tokenizer_path) - # build dataloader data_loader = build_hf_data_loader( job_config.training.dataset,