enable data loading for data parallel training
ghstack-source-id: 08d335e3151097a273742be7cab615a75015d4dd
Pull Request resolved: #49
tianyu-l committed Feb 8, 2024
1 parent 15a001f commit 4f9eba6
Showing 3 changed files with 21 additions and 12 deletions.
30 changes: 19 additions & 11 deletions torchtrain/datasets/alpaca.py
@@ -17,6 +17,8 @@ class AlpacaDataset(IterableDataset):
     Args:
         tokenizer (Tokenizer): Tokenizer used to encode data. Tokenize must implement an `encode` and `decode` method.
         seq_len (int): max sequence length
+        world_size (int): number of data parallel processes participating in training
+        rank (int): rank of the current data parallel process
 
     Data input format:
     {
@@ -34,11 +36,20 @@ class AlpacaDataset(IterableDataset):
     Batch size: 8
     """
 
-    def __init__(self, tokenizer: TokenizerIf, seq_len: int = 2048, **kwargs) -> None:
+    def __init__(
+        self,
+        tokenizer: TokenizerIf,
+        seq_len: int = 2048,
+        world_size: int = 1,
+        rank: int = 0,
+        **kwargs
+    ) -> None:
         self._data = load_dataset("tatsu-lab/alpaca", split="train")
         self._tokenizer = tokenizer
         self.data_iterator = iter(self._data)
         self.seq_len = seq_len
+        self.world_size = world_size
+        self.rank = rank
         self.response_tag = "\n\n### Response:\n"
 
     def __len__(self):
@@ -48,7 +59,12 @@ def __iter__(self):
         max_buffer_token_len = 1 + self.seq_len
         all_tokens: List[int] = []
 
-        for sample in self.data_iterator:
+        for idx, sample in enumerate(self.data_iterator):
+            # select samples to pack in a round-robin fashion
+            # TODO: This is a temporary solution for small datasets like Alpaca.
+            # For larger datasets we need to use a more scalable approach.
+            if idx % self.world_size != self.rank:
+                continue
             sample_text = sample["text"]
             sample_tokens = self._tokenizer.encode(sample_text, bos=True, eos=True)
             all_tokens.extend(sample_tokens)
@@ -66,14 +82,6 @@ def __iter__(self):
 def build_alpaca_data_loader(
     tokenizer: TokenizerIf, batch_size: int, seq_len: int, world_size, rank
 ):
-    alpaca_ds = AlpacaDataset(tokenizer=tokenizer, seq_len=seq_len)
-    # TOOD: sampler can't work with iterable dataset, figure out a way
-    # to sample in a distributed manner
-    # dist_sampler = DistributedSampler(
-    #     alpaca_ds,
-    #     world_size,
-    #     rank,
-    #     shuffle=True,
-    # )
+    alpaca_ds = AlpacaDataset(tokenizer, seq_len, world_size, rank)
 
     return DataLoader(alpaca_ds, batch_size=batch_size)
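
The round-robin selection above can be illustrated with a minimal, self-contained sketch (hypothetical class and sample data, not the Alpaca implementation): each data-parallel rank walks the full stream but keeps only every world_size-th sample, so ranks consume disjoint data without a DistributedSampler, which, as the removed TODO notes, cannot be used with an IterableDataset.

from typing import Iterator, List

from torch.utils.data import DataLoader, IterableDataset


class ShardedIterableDataset(IterableDataset):
    """Toy dataset that keeps only the samples assigned to this rank."""

    def __init__(self, samples: List[str], world_size: int = 1, rank: int = 0) -> None:
        self.samples = samples
        self.world_size = world_size
        self.rank = rank

    def __iter__(self) -> Iterator[str]:
        for idx, sample in enumerate(self.samples):
            # round-robin assignment: rank r keeps samples r, r + world_size, ...
            if idx % self.world_size != self.rank:
                continue
            yield sample


# rank 0 of 2 keeps samples 0, 2, 4, 6; rank 1 would keep 1, 3, 5, 7
ds = ShardedIterableDataset([f"sample {i}" for i in range(8)], world_size=2, rank=0)
loader = DataLoader(ds, batch_size=2)
for batch in loader:
    print(batch)  # ['sample 0', 'sample 2'], then ['sample 4', 'sample 6']
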
1 change: 1 addition & 0 deletions torchtrain/parallelisms/__init__.py
@@ -46,6 +46,7 @@ def build_mesh(self, device_type):
             if d > 1:
                 dims.append(d)
                 names.append(name)
+        names = tuple(names)
         logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
         return init_device_mesh(device_type, dims, mesh_dim_names=names)
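
The tuple conversion matters because DeviceMesh exposes mesh_dim_names as a tuple, so equality checks against a list (like the old assert in parallelize_llama.py below) never pass. A minimal sketch of that behavior, assuming a torch build with a public init_device_mesh and a process group set up via torchrun; the mesh shape and dimension names here are illustrative, not the repo's config:

# launch with: torchrun --nproc_per_node=8 this_script.py (requires 8 GPUs)
from torch.distributed.device_mesh import init_device_mesh

# e.g. 8 GPUs split into 2-way data parallel x 4-way sequence parallel
mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "sp"))

print(mesh.mesh_dim_names)                  # ('dp', 'sp'): always a tuple
assert mesh.mesh_dim_names == ("dp", "sp")  # comparing against ["dp", "sp"] would fail

dp_mesh = mesh["dp"]                        # 1-D sub-mesh for the data parallel group
print(dp_mesh.mesh_dim_names)               # ('dp',)
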

2 changes: 1 addition & 1 deletion torchtrain/parallelisms/parallelize_llama.py
@@ -156,7 +156,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, args):
 
     if parallel_dims.dp_enabled:
         dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
-        assert dp_mesh.mesh_dim_names == ["dp"], dp_mesh.mesh_dim_names
+        assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
         fsdp_config = {
             "mixed_precision": MixedPrecision(
                 param_dtype=torch.bfloat16,
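
To show where the corrected tuple assert sits in practice, here is a hedged sketch, not the repo's actual parallelize_llama body, of handing a dp sub-mesh and a bf16 MixedPrecision policy to FSDP. It assumes an FSDP version whose constructor accepts a device_mesh argument and a job launched with torchrun on world_size GPUs:

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision


def apply_data_parallel(model: nn.Module, world_size: int) -> FSDP:
    # 1-D world mesh: every rank belongs to the data parallel group
    world_mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("dp",))
    dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
    assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names  # tuple, not list

    fsdp_config = {
        "mixed_precision": MixedPrecision(param_dtype=torch.bfloat16),
        "device_mesh": dp_mesh,
    }
    return FSDP(model, **fsdp_config)
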
