Skip to content

Commit

Permalink
Add Pipeline Parallel (and 2D PP+FSDP) support
Browse files Browse the repository at this point in the history
runs PP+DP and PP+TP without issue,
runs PP+TP+DP with decreasing loss, but fails DCP save

Supports only simple schedules currently, gpipe and 1f1b.

Ads cmdline/toml arg for specifiying split points, in a unified
way between tracer or manual frontend.

  e.g. user can specifiy "layers.2,layers.4" as split points.

Currently uses manual frontend by default, but allows specifying
tracer frontend.  Tracer frontend requires working around additional
compatibility limitations, indicated by raising assertions, and is
not ready for wider use  yet.

ghstack-source-id: ecde6ef7c3453c14e78707404fb1d4f4ace89f1b
Pull Request resolved: #318
  • Loading branch information
wconstab committed May 20, 2024
1 parent d9c0a02 commit 3b05ca9
Show file tree
Hide file tree
Showing 8 changed files with 666 additions and 34 deletions.
2 changes: 1 addition & 1 deletion create_seed_checkpoint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ LOG_RANK=0
CONFIG_FILE=${CONFIG_FILE:-"./train_configs/debug_model.toml"}

seed_checkpoint="--checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint"
force_1d="--training.data_parallel_degree 1 --training.tensor_parallel_degree 1 --training.pipeline_parallel_degree 1"
force_1d="--training.data_parallel_degree 1 --training.tensor_parallel_degree 1 --experimental.pipeline_parallel_degree 1"
overrides=""
if [ $# -ne 0 ]; then
overrides="$*"
Expand Down
110 changes: 102 additions & 8 deletions test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class OverrideDefinitions:

override_args: Sequence[Sequence[str]] = tuple(tuple(" "))
test_descr: str = "default"
requires_seed_checkpoint: bool = False
ngpu: int = 4


def build_test_list(args):
Expand All @@ -35,6 +37,78 @@ def build_test_list(args):
"""
integration_tests_flavors = defaultdict(list)
integration_tests_flavors["debug_model.toml"] = [
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/pp_1f1b/",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.1",
"--experimental.pipeline_parallel_schedule 1f1b",
"--training.data_parallel_degree 1",
],
],
"PP 1D test 1f1b",
requires_seed_checkpoint=True,
ngpu=2,
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/pp_gpipe/",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.1",
"--experimental.pipeline_parallel_schedule gpipe",
"--training.data_parallel_degree 1",
],
],
"PP 1D test gpipe",
requires_seed_checkpoint=True,
ngpu=2,
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/pp_dp_1f1b/",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.1",
"--experimental.pipeline_parallel_schedule 1f1b",
"--training.data_parallel_degree 2",
],
],
"PP+DP 1f1b 2D test",
requires_seed_checkpoint=True,
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/pp_dp_gpipe/",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.1",
"--experimental.pipeline_parallel_schedule gpipe",
"--training.data_parallel_degree 2",
],
],
"PP+DP gpipe 2D test",
requires_seed_checkpoint=True,
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/pp_tp/",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.1",
"--training.tensor_parallel_degree 2",
"--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with TP
],
],
"PP+TP 2D test",
requires_seed_checkpoint=True,
),
OverrideDefinitions(
[
[
Expand Down Expand Up @@ -100,22 +174,42 @@ def build_test_list(args):
return integration_tests_flavors


def _run_cmd(cmd):
return subprocess.run(
[cmd],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
shell=True,
)


def run_test(test_flavor: OverrideDefinitions, full_path: str):
# run_test supports sequence of tests.
for override_arg in test_flavor.override_args:
cmd = f"CONFIG_FILE={full_path} NGPU=4 LOG_RANK=0,1,2,3 ./run_llama_train.sh"

cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK=0,1,2,3 ./run_llama_train.sh"
if override_arg:
cmd += " " + " ".join(override_arg)
print(
f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}====="
)
result = subprocess.run(
[cmd],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
shell=True,
)

if test_flavor.requires_seed_checkpoint:
dump_folder_arg = None
for arg in override_arg:
if "--job.dump_folder" in arg:
dump_folder_arg = arg
assert (
dump_folder_arg is not None
), "Can't use seed checkpoint if folder is not specified"
print("Creating seed checkpoint")
result = _run_cmd(
f"CONFIG_FILE={full_path} ./create_seed_checkpoint.sh {dump_folder_arg}"
)
print(result.stdout)

result = _run_cmd(cmd)
print(result.stdout)
if result.returncode != 0:
raise Exception(
Expand Down
58 changes: 56 additions & 2 deletions torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ def torch_dtype(dtype_str: str) -> torch.dtype:
return DTYPE_MAP[dtype_str]


def string_list(raw_arg):
return raw_arg.split(",")


class JobConfig:
"""
A helper class to manage the train configuration.
Expand Down Expand Up @@ -214,10 +218,56 @@ def __init__(self):
help="Whether to apply loss parallel when sequence parallel is enabled",
)
self.parser.add_argument(
"--training.pipeline_parallel_degree",
"--experimental.pipeline_parallel_degree",
type=int,
default=1,
help="Pipeline Parallelism degree. 1 means disabled.",
help="""
Pipeline Parallelism degree, or number of ranks. 1 means disabled.
If using looped schedules, this still specifies the number of physical ranks, not the number
of stages. Stages per rank are inferred from split points degree, and schedule.""",
)
self.parser.add_argument(
"--experimental.pipeline_parallel_split_points",
type=string_list,
nargs="+",
default=[],
help="""
Specify comma-separated names of modules to use as the beginning of a split point.
e.g. "layers.0,layers.2" will cause the model to be split into 3 stages,
the first containing all the layers up to layers.0,
the second containing layers.0 and up to layers.2,
the third containing layers.2 and all the remaining layers.
Note: fully-automated splitting may be enabled in the future,
but currently the split points must be specified manually for both manual and tracer.""",
)
self.parser.add_argument(
"--experimental.pipeline_parallel_schedule",
type=str,
choices=["1f1b", "gpipe"],
default="1f1b",
help="""
Specify the Pipeline Parallel schedule to use.
The schedule must be compatible with the split points and stages_per_rank.
Looped schedules are not yet supported in torchtitan.""",
)
self.parser.add_argument(
"--experimental.pipeline_parallel_split_mode",
type=str,
choices=["manual", "tracer"],
default="manual",
help="""
Specify the split method (e.g. the Pipeline Parallelism Front End)
"manual" means each rank will construct an nn.Module with the appropriate layers and .forward
implementation manually, and then wrap it in a PipelineStage.
"tracer" means the full model will be initialized (via meta device) and then traced into a graph,
split via the provided split points, unflattened into an nn.Module,
and finally wrapped in a PipelineStage. tracer frontend is currently more experimental.""",
)
self.parser.add_argument(
"--training.mixed_precision_param",
Expand Down Expand Up @@ -441,6 +491,10 @@ def parse_args_from_command_line(
aux_parser.add_argument(
"--" + arg, action="store_true" if val else "store_false"
)
elif arg == "experimental.pipeline_parallel_split_points":
# type inference breaks here, since the type is just 'list' and it ends up flattening
# e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...]
aux_parser.add_argument("--" + arg, type=string_list)
else:
aux_parser.add_argument("--" + arg, type=type(val))

Expand Down
6 changes: 5 additions & 1 deletion torchtitan/parallelisms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,16 @@

from torch.distributed.device_mesh import init_device_mesh
from torchtitan.logging_utils import logger
from torchtitan.parallelisms.parallelize_llama import parallelize_llama
from torchtitan.parallelisms.parallelize_llama import parallelize_llama, pipeline_llama

models_parallelize_fns = {
"llama2": parallelize_llama,
"llama3": parallelize_llama,
}
models_pipelining_fns = {
"llama2": pipeline_llama,
"llama3": pipeline_llama,
}


@dataclass
Expand Down
Loading

0 comments on commit 3b05ca9

Please sign in to comment.