diff --git a/torchtrain/datasets/__init__.py b/torchtrain/datasets/__init__.py
index d555aca6..8d91e584 100644
--- a/torchtrain/datasets/__init__.py
+++ b/torchtrain/datasets/__init__.py
@@ -1,7 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
 from torchtrain.datasets.alpaca import build_alpaca_data_loader
-from torchtrain.datasets.tokenizer import create_tokenizer
 from torchtrain.datasets.pad_batch_sequence import pad_batch_to_longest_seq
+from torchtrain.datasets.tokenizer import create_tokenizer
+__all__ = ["build_alpaca_data_loader", "create_tokenizer", "pad_batch_to_longest_seq"]
 
 
 dataloader_fn = {
     "alpaca": build_alpaca_data_loader,
diff --git a/torchtrain/datasets/alpaca.py b/torchtrain/datasets/alpaca.py
index c87a240a..28e847f1 100644
--- a/torchtrain/datasets/alpaca.py
+++ b/torchtrain/datasets/alpaca.py
@@ -1,18 +1,15 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
-from typing import List, Tuple
+from typing import List
 
 import torch
-
-from datasets import load_dataset
-from torch.utils.data import IterableDataset, DataLoader, DistributedSampler
+from torch.utils.data import DataLoader, IterableDataset
 
 from torchtrain.datasets.tokenizer import TokenizerIf
 
+from datasets import load_dataset
+
 
 class AlpacaDataset(IterableDataset):
     """PyTorch Representation of the Alpaca Dataset from Hugging Face.
@@ -37,11 +34,7 @@ class AlpacaDataset(IterableDataset):
         Batch size: 8
     """
 
-    def __init__(self,
-                 tokenizer: TokenizerIf,
-                 seq_len: int = 2048,
-                 **kwargs
-                 ) -> None:
+    def __init__(self, tokenizer: TokenizerIf, seq_len: int = 2048, **kwargs) -> None:
         self._data = load_dataset("tatsu-lab/alpaca", split="train")
         self._tokenizer = tokenizer
         self.data_iterator = iter(self._data)
@@ -52,7 +45,7 @@ def __len__(self):
         return len(self._data)
 
     def __iter__(self):
-        max_buffer_token_len = (1 + self.seq_len)
+        max_buffer_token_len = 1 + self.seq_len
         all_tokens: List[int] = []
 
         for sample in self.data_iterator:
@@ -71,11 +64,7 @@ def __iter__(self):
 
 
 def build_alpaca_data_loader(
-    tokenizer: TokenizerIf,
-    batch_size: int,
-    seq_len: int,
-    world_size,
-    rank
+    tokenizer: TokenizerIf, batch_size: int, seq_len: int, world_size, rank
 ):
     alpaca_ds = AlpacaDataset(tokenizer=tokenizer, seq_len=seq_len)
     # TODO: sampler can't work with iterable dataset, figure out a way
diff --git a/torchtrain/datasets/download_tokenizer.py b/torchtrain/datasets/download_tokenizer.py
index 16d1e70b..e3e92cac 100644
--- a/torchtrain/datasets/download_tokenizer.py
+++ b/torchtrain/datasets/download_tokenizer.py
@@ -1,8 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
 import os
 from typing import Optional
@@ -11,20 +12,38 @@ def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -> None:
     from huggingface_hub import hf_hub_download
 
+    os.makedirs(f"checkpoints/{repo_id}", exist_ok=True)
     try:
-        hf_hub_download(repo_id, "tokenizer.model", local_dir=f"torchtrain/datasets/tokenizer/", local_dir_use_symlinks=False, token=hf_token)
+        hf_hub_download(
+            repo_id,
+            "tokenizer.model",
+            local_dir="torchtrain/datasets/tokenizer/",
+            local_dir_use_symlinks=False,
+            token=hf_token,
+        )
     except HTTPError as e:
         if e.response.status_code == 401:
-            print("You need to pass a valid `--hf_token=...` to download private checkpoints.")
+            print(
+                "You need to pass a valid `--hf_token=...` to download private checkpoints."
+            )
         else:
             raise e
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     import argparse
-    parser = argparse.ArgumentParser(description='Download tokenizer from HuggingFace.')
-    parser.add_argument('--repo_id', type=str, default="meta-llama/llama-2-70b", help='Repository ID to download from.')
-    parser.add_argument('--hf_token', type=str, default=None, help='HuggingFace API token.')
+
+    parser = argparse.ArgumentParser(description="Download tokenizer from HuggingFace.")
+    parser.add_argument(
+        "--repo_id",
+        type=str,
+        default="meta-llama/llama-2-70b",
+        help="Repository ID to download from.",
+    )
+    parser.add_argument(
+        "--hf_token", type=str, default=None, help="HuggingFace API token."
+    )
     args = parser.parse_args()
     hf_download(args.repo_id, args.hf_token)
diff --git a/torchtrain/models/llama/__init__.py b/torchtrain/models/llama/__init__.py
index 1f1cb776..8b70ce3d 100644
--- a/torchtrain/models/llama/__init__.py
+++ b/torchtrain/models/llama/__init__.py
@@ -1,8 +1,20 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
 from torchtrain.models.llama.model import ModelArgs, Transformer
 
+__all__ = ["Transformer"]
+
 llama_configs = {
     "debugmodel": ModelArgs(dim=256, n_layers=1, n_heads=16),
     "7B": ModelArgs(dim=4096, n_layers=32, n_heads=32),
     "13B": ModelArgs(dim=5120, n_layers=40, n_heads=40),
-    "70B": ModelArgs(dim=8192, n_layers=80, n_heads=64, n_kv_heads=8, ffn_dim_multiplier=1.3, multiple_of=4096),
+    "70B": ModelArgs(
+        dim=8192,
+        n_layers=80,
+        n_heads=64,
+        n_kv_heads=8,
+        ffn_dim_multiplier=1.3,
+        multiple_of=4096,
+    ),
 }
diff --git a/torchtrain/models/llama/model.py b/torchtrain/models/llama/model.py
index 2cbefc5f..ee504f7f 100644
--- a/torchtrain/models/llama/model.py
+++ b/torchtrain/models/llama/model.py
@@ -1,9 +1,8 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
-import math
 from dataclasses import dataclass
-from typing import Dict, Optional, Tuple
+from typing import Optional, Tuple
 
 import torch
 import torch.nn.functional as F
@@ -26,25 +25,26 @@ class ModelArgs:
 
 
 class RMSNorm(torch.nn.Module):
-    def __init__(self, dim: int, eps: float = 1e-6):
-        """
-        Initialize the RMSNorm normalization layer.
+    """
+    Initialize the RMSNorm normalization layer.
 
-        Args:
-            dim (int): The dimension of the input tensor.
-            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+    Args:
+        dim (int): The dimension of the input tensor.
+        eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
-        Attributes:
-            eps (float): A small value added to the denominator for numerical stability.
-            weight (nn.Parameter): Learnable scaling parameter.
+    Attributes:
+        eps (float): A small value added to the denominator for numerical stability.
+        weight (nn.Parameter): Learnable scaling parameter.
 
-        """
+    """
+
+    def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
         self.eps = eps
         self.weight = nn.Parameter(torch.empty(dim))
         self.reset_parameters()
 
-    def _norm(self, x):
+    def _norm(self, x: torch.Tensor):
         """
         Apply the RMSNorm normalization to the input tensor.
@@ -57,7 +57,7 @@ def _norm(self, x):
         """
         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor):
         """
         Forward pass through the RMSNorm layer.
@@ -111,10 +111,6 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
 
     Returns:
         torch.Tensor: Reshaped frequency tensor.
-
-    Raises:
-        AssertionError: If the frequency tensor doesn't match the expected shape.
-        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
     """
     ndim = x.ndim
     assert 0 <= 1 < ndim
@@ -165,28 +161,29 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
 
 
 class Attention(nn.Module):
-    """Multi-head attention module."""
-    def __init__(self, args: ModelArgs):
-        """
-        Initialize the Attention module.
+    """
+    Multi-head attention module.
 
-        Args:
-            args (ModelArgs): Model configuration parameters.
-
-        Attributes:
-            n_kv_heads (int): Number of key and value heads.
-            n_heads (int): Number of query heads.
-            n_local_kv_heads (int): Number of local key and value heads.
-            n_rep (int): Number of repetitions for local heads.
-            head_dim (int): Dimension size of each attention head.
-            wq (Linear): Linear transformation for queries.
-            wk (Linear): Linear transformation for keys.
-            wv (Linear): Linear transformation for values.
-            wo (Linear): Linear transformation for output.
-            cache_k (torch.Tensor): Cached keys for attention.
-            cache_v (torch.Tensor): Cached values for attention.
+    Args:
+        args (ModelArgs): Model configuration parameters.
+
+    Attributes:
+        n_kv_heads (int): Number of key and value heads.
+        n_heads (int): Number of query heads.
+        n_local_kv_heads (int): Number of local key and value heads.
+        n_rep (int): Number of repetitions for local heads.
+        head_dim (int): Dimension size of each attention head.
+        wq (Linear): Linear transformation for queries.
+        wk (Linear): Linear transformation for keys.
+        wv (Linear): Linear transformation for values.
+        wo (Linear): Linear transformation for output.
+        cache_k (torch.Tensor): Cached keys for attention.
+        cache_v (torch.Tensor): Cached values for attention.
+
+    """
+
+    def __init__(self, args: ModelArgs):
-        """
         super().__init__()
         self.n_heads = args.n_heads
         self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
@@ -224,23 +221,43 @@ def forward(
         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
 
         # repeat k/v heads if n_kv_heads < n_heads
-        keys = repeat_kv(xk, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
-        values = repeat_kv(xv, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
+        keys = repeat_kv(
+            xk, self.n_rep
+        )  # (bs, cache_len + seqlen, n_local_heads, head_dim)
+        values = repeat_kv(
+            xv, self.n_rep
+        )  # (bs, cache_len + seqlen, n_local_heads, head_dim)
 
         xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        xk = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
-        xv = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
+        xk = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
+        xv = values.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
 
         # we use causal mask for training
-        output = F.scaled_dot_product_attention(
-            xq, xk, xv, is_causal=True
-        )
-        output = output.transpose(1, 2).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
+        output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
+        output = output.transpose(
+            1, 2
+        ).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
         output = output.view(bsz, seqlen, -1)
         return self.wo(output)
 
 
 class FeedForward(nn.Module):
+    """
+    FeedForward module
+
+    Args:
+        dim (int): Input dimension.
+        hidden_dim (int): Hidden dimension of the feedforward layer.
+        multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
+        ffn_dim_multiplier (Optional[float]): Custom multiplier for hidden dimension. Defaults to None.
+
+    Attributes:
+        w1 (Linear): Linear transformation for the first layer.
+        w2 (Linear): Linear transformation for the second layer.
+        w3 (Linear): Linear transformation for the third layer.
+
+    """
+
     def __init__(
         self,
         dim: int,
@@ -248,21 +265,7 @@ def __init__(
         multiple_of: int,
         ffn_dim_multiplier: Optional[float],
     ):
-        """
-        Initialize the FeedForward module.
-
-        Args:
-            dim (int): Input dimension.
-            hidden_dim (int): Hidden dimension of the feedforward layer.
-            multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
-            ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.
-        Attributes:
-            w1 (Linear): Linear transformation for the first layer.
-            w2 (Linear): Linear transformation for the second layer.
-            w3 (Linear): Linear transformation for the third layer.
-
-        """
         super().__init__()
         hidden_dim = int(2 * hidden_dim / 3)
         # custom dim factor multiplier
@@ -279,20 +282,22 @@ def forward(self, x):
 
 
 class RotaryEmbedding(nn.Module):
+    """
+    RotaryEmbedding Module
+    """
+
     def __init__(self, params: ModelArgs):
-        """
-        Initialize the embedding module.
-        """
         super().__init__()
         self.params = params
-        self.tok_embeddings = nn.Embedding(
-            params.vocab_size, params.dim
-        )
+        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
 
         self.freqs_cis = precompute_freqs_cis(
-            # Note that self.params.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation of models is 4096.
-            # Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training or fine-tuning.
-            self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
+            # Note that self.params.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation
+            # of models is 4096.
+            # Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training
+            # or fine-tuning.
+            self.params.dim // self.params.n_heads,
+            self.params.max_seq_len * 2,
         )
 
     def forward(self, tokens: torch.Tensor):
         """
@@ -308,30 +313,32 @@ def forward(self, tokens: torch.Tensor):
         _bsz, seqlen = tokens.shape
         h = self.tok_embeddings(tokens)
         self.freqs_cis = self.freqs_cis.to(h.device)
-        freqs_cis = self.freqs_cis[0 : seqlen]
+        freqs_cis = self.freqs_cis[0:seqlen]
         return h, freqs_cis
 
 
 class TransformerBlock(nn.Module):
-    def __init__(self, layer_id: int, args: ModelArgs):
-        """
-        Initialize a TransformerBlock.
+    """
+    TransformerBlock Module
 
-        Args:
-            layer_id (int): Identifier for the layer.
-            args (ModelArgs): Model configuration parameters.
-
-        Attributes:
-            n_heads (int): Number of attention heads.
-            dim (int): Dimension size of the model.
-            head_dim (int): Dimension size of each attention head.
-            attention (Attention): Attention module.
-            feed_forward (FeedForward): FeedForward module.
-            layer_id (int): Identifier for the layer.
-            attention_norm (RMSNorm): Layer normalization for attention output.
-            ffn_norm (RMSNorm): Layer normalization for feedforward output.
+    Args:
+        layer_id (int): Identifier for the layer.
+        args (ModelArgs): Model configuration parameters.
+
+    Attributes:
+        n_heads (int): Number of attention heads.
+        dim (int): Dimension size of the model.
+        head_dim (int): Dimension size of each attention head.
+        attention (Attention): Attention module.
+        feed_forward (FeedForward): FeedForward module.
+        layer_id (int): Identifier for the layer.
+        attention_norm (RMSNorm): Layer normalization for attention output.
+        ffn_norm (RMSNorm): Layer normalization for feedforward output.
+
+    """
+
+    def __init__(self, layer_id: int, args: ModelArgs):
-        """
         super().__init__()
         self.n_heads = args.n_heads
         self.dim = args.dim
@@ -363,32 +370,32 @@ def forward(
             torch.Tensor: Output tensor after applying attention and feedforward layers.
 
         """
-        h = x + self.attention(
-            self.attention_norm(x), freqs_cis
-        )
+        h = x + self.attention(self.attention_norm(x), freqs_cis)
         out = h + self.feed_forward(self.ffn_norm(h))
         return out
 
 
 class Transformer(nn.Module):
-    def __init__(self, params: ModelArgs):
-        """
-        Initialize a Transformer model.
+    """
+    Transformer Module
 
-        Args:
-            params (ModelArgs): Model configuration parameters.
-
-        Attributes:
-            params (ModelArgs): Model configuration parameters.
-            vocab_size (int): Vocabulary size.
-            n_layers (int): Number of layers in the model.
-            tok_embeddings (ParallelEmbedding): Token embeddings.
-            layers (torch.nn.ModuleList): List of Transformer blocks.
-            norm (RMSNorm): Layer normalization for the model output.
-            output (ColumnParallelLinear): Linear layer for final output.
-            freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
+    Args:
+        params (ModelArgs): Model configuration parameters.
+
+    Attributes:
+        params (ModelArgs): Model configuration parameters.
+        vocab_size (int): Vocabulary size.
+        n_layers (int): Number of layers in the model.
+        tok_embeddings (ParallelEmbedding): Token embeddings.
+        layers (torch.nn.ModuleList): List of Transformer blocks.
+        norm (RMSNorm): Layer normalization for the model output.
+        output (ColumnParallelLinear): Linear layer for final output.
+        freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
+
+    """
+
+    def __init__(self, params: ModelArgs):
-        """
         super().__init__()
         self.params = params
         self.vocab_size = params.vocab_size
@@ -401,9 +408,7 @@ def __init__(self, params: ModelArgs):
             self.layers.append(TransformerBlock(layer_id, params))
 
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
-        self.output = nn.Linear(
-            params.dim, params.vocab_size, bias=False
-        )
+        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
 
     def forward(self, tokens: torch.Tensor):
         """
diff --git a/torchtrain/parallelisms/__init__.py b/torchtrain/parallelisms/__init__.py
index 4e4d344e..57d42687 100644
--- a/torchtrain/parallelisms/__init__.py
+++ b/torchtrain/parallelisms/__init__.py
@@ -1,4 +1,6 @@
-import os
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
 import logging
 from dataclasses import dataclass
 
@@ -31,14 +33,16 @@ def _validate(self):
         assert dp >= 1, dp
         assert sp >= 1, sp
         assert pp >= 1, pp
-        assert dp * sp * pp == self.world_size, (
-            f"Invalid parallel dims: dp({dp}) * sp({sp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
-        )
+        assert (
+            dp * sp * pp == self.world_size
+        ), f"Invalid parallel dims: dp({dp}) * sp({sp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
 
     def build_mesh(self, device_type):
         dims = []
         names = []
-        for d, name in zip([self.dp, self.sp, self.pp], ["dp", "sp", "pp"]):
+        for d, name in zip(
+            [self.dp, self.sp, self.pp], ["dp", "sp", "pp"], strict=True
+        ):
             if d > 1:
                 dims.append(d)
                 names.append(name)
diff --git a/torchtrain/parallelisms/parallelize_llama.py b/torchtrain/parallelisms/parallelize_llama.py
index 39f7355f..d4950ef6 100644
--- a/torchtrain/parallelisms/parallelize_llama.py
+++ b/torchtrain/parallelisms/parallelize_llama.py
@@ -1,10 +1,13 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
 # this file applies the PTD parallelisms and various training techniques to the
 # llama model, i.e. activation checkpoint, etc.
 
-import os
-import torch
 import logging
 
+import torch
+
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper as ptd_checkpoint_wrapper,
     CheckpointImpl,
@@ -18,14 +21,14 @@
 from torch.distributed.fsdp.wrap import enable_wrap, wrap
 
 from torchtrain.logging_utils import rank0_log
-from typing import Dict
 
 logger = logging.getLogger(__name__)
 
 
 # Uses PTD FSDP AC wrapper
 def checkpoint_wrapper(module, config):
-    return ptd_checkpoint_wrapper(module, checkpoint_impl=CheckpointImpl.NO_REENTRANT, preserve_rng_state=False)
-
+    return ptd_checkpoint_wrapper(
+        module, checkpoint_impl=CheckpointImpl.NO_REENTRANT, preserve_rng_state=False
+    )
 
 
 def parallelize_llama(model, world_mesh, parallel_dims, args):
@@ -65,11 +68,11 @@ def parallelize_llama(model, world_mesh, parallel_dims, args):
                 transformer_block = checkpoint_wrapper(transformer_block, args)
 
             # Wraps each layer with FSDP
-            model.layers[layer_id]= wrap(transformer_block)
+            model.layers[layer_id] = wrap(transformer_block)
 
         # wrap the rest layers with FSDP
         model = wrap(model.cuda())
 
-    rank0_log(f"Applied parallelisms to the model...")
+    rank0_log("Applied parallelisms to the model...")
 
     return model
diff --git a/train.py b/train.py
index 65282deb..e2872725 100644
--- a/train.py
+++ b/train.py
@@ -1,3 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
 import argparse
 import os
 from dataclasses import dataclass, field
@@ -6,22 +9,19 @@
 # torch imports
 import torch
 import torch.nn.functional as F
-from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.utils.data import DataLoader
-
-from torchtrain.profiling import maybe_run_profiler
-from torchtrain.logging_utils import init_logger, rank0_log
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 
 # torchtrain related
-from torchtrain.datasets import (
-    create_tokenizer,
-    dataloader_fn,
-)
-from torchtrain.models import models_config, model_name_to_cls, model_name_to_tokenizer
-from torchtrain.parallelisms import ParallelDims, models_parallelize_fns
+from torchtrain.datasets import create_tokenizer, dataloader_fn
+from torchtrain.logging_utils import init_logger, rank0_log
 from torchtrain.lr_scheduling import get_lr_scheduler
+from torchtrain.models import model_name_to_cls, model_name_to_tokenizer, models_config
+from torchtrain.parallelisms import models_parallelize_fns, ParallelDims
+
+from torchtrain.profiling import maybe_run_profiler
+
 
 @dataclass
 class TrainState:
@@ -58,10 +58,11 @@ def main(args):
     init_logger()
     # init world mesh
     world_size = int(os.environ["WORLD_SIZE"])
-    parallel_dims = ParallelDims(dp=args.dp_degree, sp=args.sp_degree, pp=args.pp_degree, world_size=world_size)
+    parallel_dims = ParallelDims(
+        dp=args.dp_degree, sp=args.sp_degree, pp=args.pp_degree, world_size=world_size
+    )
     world_mesh = parallel_dims.build_mesh(device_type="cuda")
-
     model_name = args.model
     # build tokenizer
     tokenizer_type = model_name_to_tokenizer[model_name]
@@ -90,12 +91,7 @@ def main(args):
     model = model_cls.from_model_args(model_config)
 
     # apply PTD parallelisms + AC
-    model = models_parallelize_fns[model_name](
-        model,
-        world_mesh,
-        parallel_dims,
-        args
-    )
+    model = models_parallelize_fns[model_name](model, world_mesh, parallel_dims, args)
 
     # to use FSDP-customized gradient scaler and gradient clipping solutions
     assert isinstance(model, FSDP)