From 3b05ca99227b479f080a03e63cffac4024ffe40c Mon Sep 17 00:00:00 2001 From: Will Constable Date: Mon, 20 May 2024 11:17:27 -0700 Subject: [PATCH] Add Pipeline Parallel (and 2D PP+FSDP) support runs PP+DP and PP+TP without issue, runs PP+TP+DP with decreasing loss, but fails DCP save Supports only simple schedules currently, gpipe and 1f1b. Ads cmdline/toml arg for specifiying split points, in a unified way between tracer or manual frontend. e.g. user can specifiy "layers.2,layers.4" as split points. Currently uses manual frontend by default, but allows specifying tracer frontend. Tracer frontend requires working around additional compatibility limitations, indicated by raising assertions, and is not ready for wider use yet. ghstack-source-id: ecde6ef7c3453c14e78707404fb1d4f4ace89f1b Pull Request resolved: https://github.com/pytorch/torchtitan/pull/318 --- create_seed_checkpoint.sh | 2 +- test_runner.py | 110 ++++++++- torchtitan/config_manager.py | 58 ++++- torchtitan/parallelisms/__init__.py | 6 +- torchtitan/parallelisms/parallelize_llama.py | 204 +++++++++++++++- torchtitan/parallelisms/pipelining_utils.py | 242 +++++++++++++++++++ train.py | 74 +++++- train_configs/debug_model.toml | 4 +- 8 files changed, 666 insertions(+), 34 deletions(-) create mode 100644 torchtitan/parallelisms/pipelining_utils.py diff --git a/create_seed_checkpoint.sh b/create_seed_checkpoint.sh index 38bab219f..1abc77ec5 100755 --- a/create_seed_checkpoint.sh +++ b/create_seed_checkpoint.sh @@ -25,7 +25,7 @@ LOG_RANK=0 CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"} seed_checkpoint="--checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint" -force_1d="--training.data_parallel_degree 1 --training.tensor_parallel_degree 1 --training.pipeline_parallel_degree 1" +force_1d="--training.data_parallel_degree 1 --training.tensor_parallel_degree 1 --experimental.pipeline_parallel_degree 1" overrides="" if [ $# -ne 0 ]; then overrides="$*" diff --git a/test_runner.py b/test_runner.py index ca9d13209..f7a4c7a44 100755 --- a/test_runner.py +++ b/test_runner.py @@ -25,6 +25,8 @@ class OverrideDefinitions: override_args: Sequence[Sequence[str]] = tuple(tuple(" ")) test_descr: str = "default" + requires_seed_checkpoint: bool = False + ngpu: int = 4 def build_test_list(args): @@ -35,6 +37,78 @@ def build_test_list(args): """ integration_tests_flavors = defaultdict(list) integration_tests_flavors["debug_model.toml"] = [ + OverrideDefinitions( + [ + [ + "--checkpoint.enable_checkpoint", + f"--job.dump_folder {args.output_dir}/pp_1f1b/", + "--experimental.pipeline_parallel_degree 2", + "--experimental.pipeline_parallel_split_points layers.1", + "--experimental.pipeline_parallel_schedule 1f1b", + "--training.data_parallel_degree 1", + ], + ], + "PP 1D test 1f1b", + requires_seed_checkpoint=True, + ngpu=2, + ), + OverrideDefinitions( + [ + [ + "--checkpoint.enable_checkpoint", + f"--job.dump_folder {args.output_dir}/pp_gpipe/", + "--experimental.pipeline_parallel_degree 2", + "--experimental.pipeline_parallel_split_points layers.1", + "--experimental.pipeline_parallel_schedule gpipe", + "--training.data_parallel_degree 1", + ], + ], + "PP 1D test gpipe", + requires_seed_checkpoint=True, + ngpu=2, + ), + OverrideDefinitions( + [ + [ + "--checkpoint.enable_checkpoint", + f"--job.dump_folder {args.output_dir}/pp_dp_1f1b/", + "--experimental.pipeline_parallel_degree 2", + "--experimental.pipeline_parallel_split_points layers.1", + "--experimental.pipeline_parallel_schedule 1f1b", + "--training.data_parallel_degree 2", + ], + ], + "PP+DP 1f1b 2D test", + requires_seed_checkpoint=True, + ), + OverrideDefinitions( + [ + [ + "--checkpoint.enable_checkpoint", + f"--job.dump_folder {args.output_dir}/pp_dp_gpipe/", + "--experimental.pipeline_parallel_degree 2", + "--experimental.pipeline_parallel_split_points layers.1", + "--experimental.pipeline_parallel_schedule gpipe", + "--training.data_parallel_degree 2", + ], + ], + "PP+DP gpipe 2D test", + requires_seed_checkpoint=True, + ), + OverrideDefinitions( + [ + [ + "--checkpoint.enable_checkpoint", + f"--job.dump_folder {args.output_dir}/pp_tp/", + "--experimental.pipeline_parallel_degree 2", + "--experimental.pipeline_parallel_split_points layers.1", + "--training.tensor_parallel_degree 2", + "--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with TP + ], + ], + "PP+TP 2D test", + requires_seed_checkpoint=True, + ), OverrideDefinitions( [ [ @@ -100,22 +174,42 @@ def build_test_list(args): return integration_tests_flavors +def _run_cmd(cmd): + return subprocess.run( + [cmd], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + shell=True, + ) + + def run_test(test_flavor: OverrideDefinitions, full_path: str): # run_test supports sequence of tests. for override_arg in test_flavor.override_args: - cmd = f"CONFIG_FILE={full_path} NGPU=4 LOG_RANK=0,1,2,3 ./run_llama_train.sh" + + cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK=0,1,2,3 ./run_llama_train.sh" if override_arg: cmd += " " + " ".join(override_arg) print( f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}=====" ) - result = subprocess.run( - [cmd], - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - text=True, - shell=True, - ) + + if test_flavor.requires_seed_checkpoint: + dump_folder_arg = None + for arg in override_arg: + if "--job.dump_folder" in arg: + dump_folder_arg = arg + assert ( + dump_folder_arg is not None + ), "Can't use seed checkpoint if folder is not specified" + print("Creating seed checkpoint") + result = _run_cmd( + f"CONFIG_FILE={full_path} ./create_seed_checkpoint.sh {dump_folder_arg}" + ) + print(result.stdout) + + result = _run_cmd(cmd) print(result.stdout) if result.returncode != 0: raise Exception( diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 99cc0746a..81d9bc621 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -29,6 +29,10 @@ def torch_dtype(dtype_str: str) -> torch.dtype: return DTYPE_MAP[dtype_str] +def string_list(raw_arg): + return raw_arg.split(",") + + class JobConfig: """ A helper class to manage the train configuration. @@ -214,10 +218,56 @@ def __init__(self): help="Whether to apply loss parallel when sequence parallel is enabled", ) self.parser.add_argument( - "--training.pipeline_parallel_degree", + "--experimental.pipeline_parallel_degree", type=int, default=1, - help="Pipeline Parallelism degree. 1 means disabled.", + help=""" + Pipeline Parallelism degree, or number of ranks. 1 means disabled. + If using looped schedules, this still specifies the number of physical ranks, not the number + of stages. Stages per rank are inferred from split points degree, and schedule.""", + ) + self.parser.add_argument( + "--experimental.pipeline_parallel_split_points", + type=string_list, + nargs="+", + default=[], + help=""" + Specify comma-separated names of modules to use as the beginning of a split point. + + e.g. "layers.0,layers.2" will cause the model to be split into 3 stages, + the first containing all the layers up to layers.0, + the second containing layers.0 and up to layers.2, + the third containing layers.2 and all the remaining layers. + + Note: fully-automated splitting may be enabled in the future, + but currently the split points must be specified manually for both manual and tracer.""", + ) + self.parser.add_argument( + "--experimental.pipeline_parallel_schedule", + type=str, + choices=["1f1b", "gpipe"], + default="1f1b", + help=""" + Specify the Pipeline Parallel schedule to use. + + The schedule must be compatible with the split points and stages_per_rank. + + Looped schedules are not yet supported in torchtitan.""", + ) + self.parser.add_argument( + "--experimental.pipeline_parallel_split_mode", + type=str, + choices=["manual", "tracer"], + default="manual", + help=""" + Specify the split method (e.g. the Pipeline Parallelism Front End) + + "manual" means each rank will construct an nn.Module with the appropriate layers and .forward + implementation manually, and then wrap it in a PipelineStage. + + "tracer" means the full model will be initialized (via meta device) and then traced into a graph, + split via the provided split points, unflattened into an nn.Module, + and finally wrapped in a PipelineStage. tracer frontend is currently more experimental.""", ) self.parser.add_argument( "--training.mixed_precision_param", @@ -441,6 +491,10 @@ def parse_args_from_command_line( aux_parser.add_argument( "--" + arg, action="store_true" if val else "store_false" ) + elif arg == "experimental.pipeline_parallel_split_points": + # type inference breaks here, since the type is just 'list' and it ends up flattening + # e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...] + aux_parser.add_argument("--" + arg, type=string_list) else: aux_parser.add_argument("--" + arg, type=type(val)) diff --git a/torchtitan/parallelisms/__init__.py b/torchtitan/parallelisms/__init__.py index e791b832a..7e1b21c79 100644 --- a/torchtitan/parallelisms/__init__.py +++ b/torchtitan/parallelisms/__init__.py @@ -9,12 +9,16 @@ from torch.distributed.device_mesh import init_device_mesh from torchtitan.logging_utils import logger -from torchtitan.parallelisms.parallelize_llama import parallelize_llama +from torchtitan.parallelisms.parallelize_llama import parallelize_llama, pipeline_llama models_parallelize_fns = { "llama2": parallelize_llama, "llama3": parallelize_llama, } +models_pipelining_fns = { + "llama2": pipeline_llama, + "llama3": pipeline_llama, +} @dataclass diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 4f1525879..909cd8d32 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -8,7 +8,7 @@ # llama model, i.e. activation checkpointing, etc. from collections import defaultdict -from typing import Tuple +from typing import Dict, Tuple import torch @@ -18,6 +18,11 @@ checkpoint_wrapper as ptd_checkpoint_wrapper, CheckpointImpl, ) +from torch.distributed.pipelining import pipeline, SplitPoint +from torch.distributed.pipelining._PipelineStage import ( + _PipelineStage, + ManualPipelineStage, +) from torch.distributed.tensor.parallel import ( ColwiseParallel, parallelize_module, @@ -30,7 +35,7 @@ from torchtitan.config_manager import JobConfig from torchtitan.logging_utils import logger - +from torchtitan.parallelisms.pipelining_utils import split_stage_fqns # for selective AC no_recompute_list = { @@ -129,15 +134,191 @@ def get_tp_parallel_strategy( return RowwiseParallel, ColwiseParallel +def _llama_fqns(num_layers): + return ( + [ + "tok_embeddings", + ] + + [f"layers.{i}" for i in range(num_layers)] + + [ + "norm", + "output", + ] + ) + + +def pipeline_llama( + model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict +): + if job_config.experimental.pipeline_parallel_split_mode == "manual": + return pipeline_llama_manual( + model, world_mesh, parallel_dims, job_config, device, model_config + ) + elif job_config.experimental.pipeline_parallel_split_mode == "tracer": + return pipeline_llama_tracer( + model, world_mesh, parallel_dims, job_config, device, model_config + ) + else: + raise NotImplementedError( + f"{job_config.experimental.pipeline_parallel_split_mode} is not a valid split mode" + ) + + +def _llama_trace_input(job_config, model_config, device="meta"): + """Get meta tensors with the right input shapes used for tracing""" + tokens_shape = (job_config.training.batch_size, job_config.training.seq_len) + tokens = torch.randint( + model_config.vocab_size, tokens_shape, dtype=torch.int64, device=device + ) + return (tokens,) + + +def pipeline_llama_manual( + model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict +): + """ + This API gets individual torch.nn.Module objects for each pipeline stage (including virtual stages). + + The SPMD parallelisms should be applied to + """ + pp_mesh = world_mesh["pp"] + pp_rank = pp_mesh.get_local_rank() + pp_size = pp_mesh.size() + # heuristically == PP dim but should be a config + microbatches = parallel_dims.pp + stage_idx = pp_rank # TODO support virtual stages + this_stage_layer_names = split_stage_fqns( + _llama_fqns(len(model.layers)), + job_config.experimental.pipeline_parallel_split_points, + pp_rank, + ) + + if pp_rank < pp_size - 1: + model.norm = None + model.output = None + if pp_rank > 0: + model.tok_embeddings = None + names = list(model.layers.keys()) + for name in names: + if f"layers.{name}" not in this_stage_layer_names: + del model.layers[name] + + logger.info(f"PP rank {pp_rank} is using this model chunk\n{model}") + + # TODO(whc) once ManualPipelineStage supports lazy shape inference, we can leave model on meta device longer and + # get rid of the input shape hardcoded here. For now, it should not be a big deal since we only materialize the + # layers of the model that map to this stage, not the whole model. + + if pp_rank == 0: + # first layer + input = torch.randint( + model_config.vocab_size, + size=(job_config.training.batch_size, job_config.training.seq_len), + dtype=torch.int64, + device=device, + ) + else: + # later layers (assume all start w/ a transformer layer) + input = torch.rand( + size=( + job_config.training.batch_size, + int(job_config.training.seq_len // parallel_dims.tp), + model_config.dim, + ), + dtype=job_config.training.mixed_precision_param + if parallel_dims.dp_enabled + else torch.float32, + device=device, + ) + + if pp_rank == pp_size - 1: + # last layer + output = torch.rand( + size=( + job_config.training.batch_size, + int(job_config.training.seq_len // parallel_dims.tp), + model_config.vocab_size, + ), + dtype=torch.float32, + device=device, + ) + else: + # earlier layers (assume all end in a transformer layer) + output = torch.rand( + size=( + job_config.training.batch_size, + int(job_config.training.seq_len // parallel_dims.tp), + model_config.dim, + ), + dtype=job_config.training.mixed_precision_param + if parallel_dims.dp_enabled + else torch.float32, + device=device, + ) + + model.to_empty(device=device) + stage = ManualPipelineStage( + model, + pp_rank, + pp_size, + device, + microbatches, + input_args=input.chunk(microbatches)[0], + output_args=output.chunk(microbatches)[0], + group=pp_mesh.get_group("pp"), + ) + return (stage, model) + + +def pipeline_llama_tracer( + model, world_mesh, parallel_dims, job_config: JobConfig, device, model_config: Dict +): + if job_config.model.norm_type == "fused_rmsnorm": + # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode + # coming from ` if dy.stride(-1) != 1:` in fused_rmsnorm + raise NotImplementedError( + "fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm." + ) + + # TODO(whc) maybe we can just fix this by feeding bf16 into the tracer for its input shapes? + raise NotImplementedError( + "pipeline tracer doesn't work with fsdp mixed precision currently. " + "To work around, edit fsdp mixed precision config to use fp32." + ) + pp_mesh = world_mesh["pp"] + pp_rank = pp_mesh.get_local_rank() + stage_idx = pp_mesh.get_local_rank() + layers_per_rank = len(model.layers) // parallel_dims.pp + split_spec = { + f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING + for i in range(1, parallel_dims.pp) + } + + # Create a pipeline representation from the model + pipe = pipeline( + model, + parallel_dims.pp, + example_args=_llama_trace_input(job_config, model_config), + split_spec=split_spec, + ) + model = pipe.get_stage_module(stage_idx) + stage = _PipelineStage( + stage_module=model, + stage_index=pp_rank, + pipe_info=pipe.pipe_info, + device=device, + group=pp_mesh.get_group(), + ) + return (stage, model) + + def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): """ - Apply parallelisms and activation checkpointing to the model. + Apply SPMD parallelisms and activation checkpointing to the model. NOTE: The passed-in model preferably should be on meta device. Otherwise, the model must fit on GPU or CPU memory. """ - if parallel_dims.pp_enabled: - raise NotImplementedError("PP not implemented yet.") if parallel_dims.tp_enabled: if job_config.model.norm_type == "fused_rmsnorm": @@ -221,15 +402,22 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): transformer_block, job_config.activation_checkpoint ) # As an optimization, do not reshard after forward for the last - # transformer block since FSDP would prefetch it immediately - reshard_after_forward = int(layer_id) < len(model.layers) - 1 + # transformer block since FSDP would prefetch it immediately. + # When using Pipeline Parallelism, generally zero-2 is best so as to avoid repeated reshardings + # per microbatch. + reshard_after_forward = ( + int(layer_id) < len(model.layers) - 1 and not parallel_dims.pp_enabled + ) fully_shard( transformer_block, **fsdp_config, reshard_after_forward=reshard_after_forward, ) model.layers[layer_id] = transformer_block - model = fully_shard(model, **fsdp_config) + + model = fully_shard( + model, **fsdp_config, reshard_after_forward=not parallel_dims.pp_enabled + ) if ac_mode in ("full", "selective"): logger.info(f"Applied {ac_mode} activation checkpointing to the model") logger.info("Applied FSDP to the model") diff --git a/torchtitan/parallelisms/pipelining_utils.py b/torchtitan/parallelisms/pipelining_utils.py new file mode 100644 index 000000000..e5ce425a7 --- /dev/null +++ b/torchtitan/parallelisms/pipelining_utils.py @@ -0,0 +1,242 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# from torch.distributed.pipelining import Schedule1F1B, ScheduleGPipe +from collections import defaultdict + +from typing import Dict, List, Optional + +import torch.distributed as dist +from torch.distributed.pipelining import ScheduleGPipe + +# imports related to local copy of Schedule1F1B with local fix +from torch.distributed.pipelining.PipelineSchedule import ( + PipelineScheduleSingle, + # sorted_batch_p2p, +) +from torch.profiler import record_function + +from torchtitan.logging_utils import logger + +# haven't landed these yet in core +def batch_p2p(p2p_ops: List[dist.P2POp], desc: Optional[str] = None): + desc_str = f"{desc}, " if desc else "" + logger.debug(f"batch_p2p {desc_str}{p2p_ops}") # noqa: G004 + return dist.batch_isend_irecv(p2p_ops).pop() + + +def sorted_batch_p2p( + p2p_ops: List[dist.P2POp], desc: Optional[str] = None +) -> Dict[int, dist.Work]: + """ + Sorts the list of P2P ops by the peer rank, and then calls + batch_isend_irecv. Return a dictionary of works by peer rank. This function + helps us avoid hangs in case of skip connections. + """ + # Arrange p2p_ops by peer rank: + # int is the peer rank; + # List is the list of ops towards the peer + ops_by_peer: Dict[int, List[dist.P2POp]] = defaultdict(list) + work_by_peer: Dict[int, dist.Work] = {} + if len(p2p_ops) == 0: + return work_by_peer + + # Classify the ops by peer rank + for op in p2p_ops: + ops_by_peer[op.peer].append(op) + + # Call batch_isend_irecv per peer, in sorted order of the peers (to avoid hangs) + for peer, ops in sorted(ops_by_peer.items()): + work_by_peer[peer] = batch_p2p(ops, desc=desc) + + return work_by_peer + + +class Schedule1F1B(PipelineScheduleSingle): + """ + The 1F1B schedule. + Will perform one forward and one backward on the microbatches in steady state. + """ + + def _step_microbatches( + self, + arg_mbs: Optional[List] = None, + kwarg_mbs: Optional[List] = None, + target_mbs: Optional[List] = None, + losses: Optional[List] = None, + ): + """ + Run one iteration of the pipeline schedule with list of microbatches. + Will go through all the microbatches according to the 1F1B schedule. + + Args: + microbatches: list of microbatch args. + """ + arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) + + # forward for num_microbatches + backward for num_microbatches + total_ops = self._n_microbatches * 2 + + # Example, 4 GPUs, 8 microbatches + # Stage 0: 6 warmup, 2 1f1b, 6 cooldown + # Stage 1: 4 warmup, 4 1f1b, 4 cooldown + # Stage 2: 2 warmup, 6 1f1b, 2 cooldown + # Stage 3: 0 warmup, 8 1f1b, 0 cooldown + # fwd only + warmup_steps = min( + self._n_microbatches, + 2 * (self._num_stages - self._stage.stage_index - 1), + ) + # fwd + bwd + main_1f1b_steps = self._n_microbatches - warmup_steps + # bwd only + cooldown_steps = total_ops - (warmup_steps + (2 * main_1f1b_steps)) + total_steps = warmup_steps + main_1f1b_steps + cooldown_steps + logger.debug( + f"Stage {self._stage.stage_index}: " # noqa: G004 + f"Warmup steps: {warmup_steps}, " + f"Main 1F1B steps: {main_1f1b_steps}, " + f"Cooldown steps: {cooldown_steps}, " + f"Total steps: {total_steps}" + ) + + # Delay send waits + fwd_sends_to_wait: List[dist.Work] = [] + bwd_sends_to_wait: List[dist.Work] = [] + + def is_forward_step(i): + assert i >= 0, i + return i < self._n_microbatches + + def is_backward_step(i): + assert i < total_steps, i + return i >= warmup_steps and self._has_backward + + def is_1f1b_step(i): + return is_forward_step(i) and is_backward_step(i) + + def is_warmup_step(i): + return is_forward_step(i) and not is_backward_step(i) + + def is_cooldown_step(i): + return not is_forward_step(i) and is_backward_step(i) + + def should_coalesce_fwd_send_bwd_recv(fwd_send_i): + return ( + is_1f1b_step(fwd_send_i) + or (is_warmup_step(fwd_send_i) and is_cooldown_step(fwd_send_i + 1)) + or ( + fwd_send_i >= 1 + and is_warmup_step(fwd_send_i - 1) + and is_cooldown_step(fwd_send_i) + ) + ) + + def should_coalesce_bwd_send_fwd_recv(bwd_send_i): + # The backward send to prev stage should be coalesced with the fwd recv from the previous stage + return bwd_send_i >= warmup_steps and is_1f1b_step(bwd_send_i + 1) + + # bwd chunk counter + bwd_mb_index = 0 + self._stage._configure_data_parallel_mode(last_backward=False) + for i in range(total_steps): + if is_forward_step(i): + with record_function(f"Forward {i}"): + ops = self._stage.get_fwd_recv_ops() + desc = "fwd_recv" + if should_coalesce_bwd_send_fwd_recv(i - 1): + desc += "_bwd_send" + ops.extend(self._stage.get_bwd_send_ops()) + + works = sorted_batch_p2p(ops, desc=desc) + for work in works.values(): + work.wait() + + output = self._stage.forward_one_chunk(arg_mbs[i], kwarg_mbs[i]) # type: ignore[index] + + if not should_coalesce_fwd_send_bwd_recv(i): + ops = self._stage.get_fwd_send_ops() + works = sorted_batch_p2p(ops, desc="fwd_send") + fwd_sends_to_wait.extend(works.values()) + + self._maybe_compute_loss(self._stage, output, target_mbs, i) + + if is_backward_step(i): + self._stage._configure_data_parallel_mode( + last_backward=(i == total_steps - 1) + ) + with record_function(f"Backward {bwd_mb_index}"): + ops = self._stage.get_bwd_recv_ops() + desc = "bwd_recv" + if should_coalesce_fwd_send_bwd_recv(i): + ops.extend(self._stage.get_fwd_send_ops()) + desc += "_fwd_send" + + works = sorted_batch_p2p(ops, desc=desc) + for work in works.values(): + work.wait() + + loss = self._maybe_get_loss(self._stage, bwd_mb_index) + self._stage.backward_one_chunk(loss=loss) + + if not should_coalesce_bwd_send_fwd_recv(i): + # see Note: coalesced bwd-send/fwd-recv + ops = self._stage.get_bwd_send_ops() + works = sorted_batch_p2p(ops, desc="bwd_send") + bwd_sends_to_wait.extend(works.values()) + + bwd_mb_index += 1 + + # Wait for all forward sends to finish + for work in fwd_sends_to_wait: + work.wait() + + # Wait for all backward sends to finish + for work in bwd_sends_to_wait: + work.wait() + + # Return losses if there is a container passed in + self._update_losses(self._stage, losses) + + +def build_pipeline_schedule(job_config, parallel_dims, stage, loss_fn): + if job_config.experimental.pipeline_parallel_schedule == "1f1b": + schedule_class = Schedule1F1B + elif job_config.experimental.pipeline_parallel_schedule == "gpipe": + schedule_class = ScheduleGPipe + else: + raise NotImplementedError( + f"{job_config.experimental.pipeline_parallel_schedule} is not implemented" + ) + logger.info( + f"Using pipeline schedule {job_config.experimental.pipeline_parallel_schedule}" + ) + return schedule_class( + stage, + n_microbatches=parallel_dims.pp, + loss_fn=loss_fn, + ) + + +def split_stage_fqns(fqns, split_points, stage_id): + """Helper for splitting ordered list of layer names into layers per stage. + + split_points is a list of layer names, each layer will be the first layer in a stage + """ + stages = [] + cur = [] + + for name in fqns: + if name in split_points: + assert len( + cur + ), f"{name} is not a valid split point, do not specify the first layer of stage 0" + stages.append(cur) + cur = [] + cur.append(name) + + stages.append(cur) + return stages[stage_id] diff --git a/train.py b/train.py index 318c7174e..90a745e5a 100644 --- a/train.py +++ b/train.py @@ -32,7 +32,12 @@ from torchtitan.lr_scheduling import get_lr_scheduler from torchtitan.metrics import build_gpu_memory_monitor, build_metric_logger from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config -from torchtitan.parallelisms import models_parallelize_fns, ParallelDims +from torchtitan.parallelisms import ( + models_parallelize_fns, + models_pipelining_fns, + ParallelDims, +) +from torchtitan.parallelisms.pipelining_utils import build_pipeline_schedule from torchtitan.profiling import maybe_enable_profiling from torchtitan.utils import ( Color, @@ -122,11 +127,12 @@ def main(job_config: JobConfig): parallel_dims = ParallelDims( dp=job_config.training.data_parallel_degree, tp=job_config.training.tensor_parallel_degree, - pp=job_config.training.pipeline_parallel_degree, + pp=job_config.experimental.pipeline_parallel_degree, world_size=world_size, enable_loss_parallel=job_config.training.enable_loss_parallel, ) - torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") + torch.cuda.set_device(device) init_distributed(job_config) world_mesh = parallel_dims.build_mesh(device_type="cuda") @@ -144,6 +150,10 @@ def main(job_config: JobConfig): dp_rank = dp_mesh.get_local_rank() else: dp_degree, dp_rank = 1, 0 + + if parallel_dims.pp_enabled: + pp_mesh = world_mesh["pp"] + data_loader = build_hf_data_loader( job_config.training.dataset, job_config.training.dataset_path, @@ -201,13 +211,26 @@ def loss_fn(pred, labels): # obtain the peak flops of bf16 type for MFU calculation gpu_peak_flops = get_peak_flops(gpu_memory_monitor.device_name) - # apply PT-D parallelisms and activation checkpointing + if parallel_dims.pp_enabled: + stage, model = models_pipelining_fns[model_name]( + model, world_mesh, parallel_dims, job_config, device, model_config + ) + + # apply PT-D DP/TP parallelisms and activation checkpointing model = models_parallelize_fns[model_name]( model, world_mesh, parallel_dims, job_config ) - # allocate sharded model on GPU and initialize weights via DTensor + model.to_empty(device="cuda") - model.init_weights() + + if parallel_dims.pp_enabled: + pp_schedule = build_pipeline_schedule(job_config, parallel_dims, stage, loss_fn) + else: + # If PP is enabled, we can't rely on init_weights, because some layers are missing. + # In the future, we may make init_weights handle missing layers, but also have to consider RNG seed propagation. + + # allocate sharded model on GPU and initialize weights via DTensor + model.init_weights() gpu_mem_stats = gpu_memory_monitor.get_peak_stats() logger.info( @@ -257,7 +280,13 @@ def loss_fn(pred, labels): logger.info("Created seed checkpoint") return - checkpoint.load() + checkpoint_loaded = checkpoint.load() + + if parallel_dims.pp_enabled and not checkpoint_loaded: + raise RuntimeError( + "Pipeline Parallelism requires meta-initialization and loading seed checkpoint. " + "Please run `./create_seed_checkpoint.sh` and rerun training with `--checkpoint.enable_checkpoint`" + ) # plot losses loaded from checkpoint (if any) to TensorBoard # NOTE: Loss info after the last log step before checkpoint saving will not be ploted. @@ -299,14 +328,33 @@ def loss_fn(pred, labels): input_ids = input_ids.cuda() labels = labels.cuda() - optimizer.zero_grad() - # forward / backward - with loss_parallel_ctx(): - pred = model(input_ids) - loss = loss_fn(pred, labels) - loss.backward() + if parallel_dims.pp_enabled: + # pipeline parallel forward / backward inside step() call + is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1 + + with loss_parallel_ctx(): + if pp_mesh.get_local_rank() == 0: + pp_schedule.step(input_ids) + elif is_last_stage: + losses = [] + pp_schedule.step(target=labels, losses=losses) + else: + pp_schedule.step() + + # accumulate losses across pipeline microbatches + loss = ( + torch.mean(torch.stack(losses)) + if is_last_stage + else torch.Tensor([-1.0]) + ) + else: + # Non-PP forward / backward + with loss_parallel_ctx(): + pred = model(input_ids) + loss = loss_fn(pred, labels) + loss.backward() # clip gradients torch.nn.utils.clip_grad_norm_( diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index 4541fec7b..009348b5c 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -36,11 +36,13 @@ max_norm = 1.0 # grad norm clipping steps = 10 data_parallel_degree = -1 tensor_parallel_degree = 1 -pipeline_parallel_degree = 1 fp8_linear = "" compile = false dataset = "c4_mini" # supported datasets: c4_mini (45K), c4 (177M) +[experimental] +pipeline_parallel_degree = 1 + [checkpoint] enable_checkpoint = false folder = "checkpoint"