From 4f9eba631f9ab6a420871bf898190b1d65eebfae Mon Sep 17 00:00:00 2001
From: Tianyu Liu
Date: Wed, 7 Feb 2024 17:11:01 -0800
Subject: [PATCH] enable data loading for data parallel training

ghstack-source-id: 08d335e3151097a273742be7cab615a75015d4dd
Pull Request resolved: https://github.com/pytorch-labs/torchtrain/pull/49
---
 torchtrain/datasets/alpaca.py                | 30 +++++++++++++-------
 torchtrain/parallelisms/__init__.py          |  1 +
 torchtrain/parallelisms/parallelize_llama.py |  2 +-
 3 files changed, 21 insertions(+), 12 deletions(-)

diff --git a/torchtrain/datasets/alpaca.py b/torchtrain/datasets/alpaca.py
index 28e847f1..f52d2112 100644
--- a/torchtrain/datasets/alpaca.py
+++ b/torchtrain/datasets/alpaca.py
@@ -17,6 +17,8 @@ class AlpacaDataset(IterableDataset):
     Args:
         tokenizer (Tokenizer): Tokenizer used to encode data. Tokenize must implement an `encode` and `decode` method.
         seq_len (int): max sequence length
+        world_size (int): number of data parallel processes participating in training
+        rank (int): rank of the current data parallel process
 
     Data input format:
     {
@@ -34,11 +36,20 @@ class AlpacaDataset(IterableDataset):
         Batch size: 8
     """
 
-    def __init__(self, tokenizer: TokenizerIf, seq_len: int = 2048, **kwargs) -> None:
+    def __init__(
+        self,
+        tokenizer: TokenizerIf,
+        seq_len: int = 2048,
+        world_size: int = 1,
+        rank: int = 0,
+        **kwargs
+    ) -> None:
         self._data = load_dataset("tatsu-lab/alpaca", split="train")
         self._tokenizer = tokenizer
         self.data_iterator = iter(self._data)
         self.seq_len = seq_len
+        self.world_size = world_size
+        self.rank = rank
         self.response_tag = "\n\n### Response:\n"
 
     def __len__(self):
@@ -48,7 +59,12 @@ def __iter__(self):
         max_buffer_token_len = 1 + self.seq_len
         all_tokens: List[int] = []
 
-        for sample in self.data_iterator:
+        for idx, sample in enumerate(self.data_iterator):
+            # select samples to pack in a round-robin fashion
+            # TODO: This is a temporary solution for small datasets like Alpaca.
+            # For larger datasets we need to use a more scalable approach.
+            if idx % self.world_size != self.rank:
+                continue
             sample_text = sample["text"]
             sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
             all_tokens.extend(sample_tokens)
@@ -66,14 +82,6 @@ def __iter__(self):
 def build_alpaca_data_loader(
     tokenizer: TokenizerIf, batch_size: int, seq_len: int, world_size, rank
 ):
-    alpaca_ds = AlpacaDataset(tokenizer=tokenizer, seq_len=seq_len)
-    # TOOD: sampler can't work with iterable dataset, figure out a way
-    # to sample in a distributed manner
-    # dist_sampler = DistributedSampler(
-    #     alpaca_ds,
-    #     world_size,
-    #     rank,
-    #     shuffle=True,
-    # )
+    alpaca_ds = AlpacaDataset(tokenizer, seq_len, world_size, rank)
 
     return DataLoader(alpaca_ds, batch_size=batch_size)
diff --git a/torchtrain/parallelisms/__init__.py b/torchtrain/parallelisms/__init__.py
index 57d42687..464397fa 100644
--- a/torchtrain/parallelisms/__init__.py
+++ b/torchtrain/parallelisms/__init__.py
@@ -46,6 +46,7 @@ def build_mesh(self, device_type):
             if d > 1:
                 dims.append(d)
                 names.append(name)
+        names = tuple(names)
         logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
         return init_device_mesh(device_type, dims, mesh_dim_names=names)
 
diff --git a/torchtrain/parallelisms/parallelize_llama.py b/torchtrain/parallelisms/parallelize_llama.py
index dbf418ea..d6db313f 100644
--- a/torchtrain/parallelisms/parallelize_llama.py
+++ b/torchtrain/parallelisms/parallelize_llama.py
@@ -156,7 +156,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, args):
 
     if parallel_dims.dp_enabled:
         dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
-        assert dp_mesh.mesh_dim_names == ["dp"], dp_mesh.mesh_dim_names
+        assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
         fsdp_config = {
             "mixed_precision": MixedPrecision(
                 param_dtype=torch.bfloat16,
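
Note (not part of the patch): the round-robin selection added in __iter__ can be checked in isolation. The following is a minimal, self-contained sketch of the same idea, in which each data parallel rank keeps every world_size-th sample offset by its rank. ToyDataset and the values below are illustrative stand-ins, not code from this repository.

from torch.utils.data import DataLoader, IterableDataset


class ToyDataset(IterableDataset):
    """Illustrative stand-in for AlpacaDataset's per-rank sharding logic."""

    def __init__(self, data, world_size: int = 1, rank: int = 0):
        self.data = data
        self.world_size = world_size
        self.rank = rank

    def __iter__(self):
        for idx, sample in enumerate(self.data):
            # keep every world_size-th sample, offset by this rank
            if idx % self.world_size != self.rank:
                continue
            yield sample


# rank 1 of 2 sees samples 1, 3, 5, 7, 9 out of range(10)
loader = DataLoader(ToyDataset(range(10), world_size=2, rank=1), batch_size=2)
print([batch.tolist() for batch in loader])  # [[1, 3], [5, 7], [9]]

With 2 ranks the shards are disjoint and together cover the dataset, which is why the per-rank DataLoader in build_alpaca_data_loader can drop the commented-out DistributedSampler.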