Add a 3-stage PP config
Pipelining is unique in that there is no need to stick to power-of-2 numbers of stages, and there may be reasons an odd number of stages is optimal, depending on how you divide up your cluster.
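
Purely for illustration (not part of the commit or the torchtitan API), here is a tiny sketch of how a fixed GPU count factors into (pipeline_parallel_degree, data_parallel_degree) pairs; on a 24-GPU cluster, a 3-stage pipeline with 8-way data parallelism is a perfectly reasonable non-power-of-2 choice.

# Illustrative sketch only: enumerate the ways a world size can be factored
# into (pipeline_parallel_degree, data_parallel_degree). Not torchtitan code.
def pp_dp_splits(world_size: int) -> list[tuple[int, int]]:
    return [
        (pp, world_size // pp)
        for pp in range(1, world_size + 1)
        if world_size % pp == 0
    ]

# On 24 GPUs, (3, 8) is a valid split even though 3 is not a power of 2.
print(pp_dp_splits(24))  # [(1, 24), (2, 12), (3, 8), (4, 6), (6, 4), (8, 3), (12, 2), (24, 1)]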

In any case, I only discovered this was not working when I tried to use a 3-stage setup to validate the 1f1b schedule: slightly more complicated than 2 stages, but simpler than 4.
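
As a rough mental model (a sketch of the general 1F1B idea, not the actual PyTorch or torchtitan scheduler), each stage runs a short warmup of forwards, then alternates one forward with one backward, then drains the remaining backwards; with 3 stages the warmup depth differs per stage, which is exactly the kind of asymmetry this test exercises.

# Sketch of the per-stage action sequence in a 1F1B pipeline schedule.
# Illustrative pseudologic only, not the real scheduler implementation.
def one_f_one_b(stage: int, num_stages: int, num_microbatches: int):
    warmup = min(num_stages - 1 - stage, num_microbatches)
    actions = [("F", mb) for mb in range(warmup)]          # warmup forwards
    for k in range(num_microbatches - warmup):             # steady state: one F, one B
        actions += [("F", warmup + k), ("B", k)]
    actions += [("B", mb) for mb in range(num_microbatches - warmup, num_microbatches)]  # cooldown
    return actions

# With 3 stages and 4 microbatches, stage 0 warms up with 2 forwards,
# stage 1 with 1, and the last stage alternates F/B immediately.
for s in range(3):
    print(s, one_f_one_b(s, num_stages=3, num_microbatches=4))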

I ran into issues with DCP loading my initial seed checkpoint.
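
For context, the seed-checkpoint flow is roughly: save the full model state once before any parallelism is applied, then let each pipeline stage load only the keys it owns. The sketch below uses the public torch.distributed.checkpoint (DCP) save/load API, but the helper names and the "seed_ckpt" path are made up for illustration; this is not the torchtitan code path.

# Rough sketch of the seed-checkpoint idea using the public DCP API.
# Helper names and the "seed_ckpt" path are illustrative, not torchtitan's.
import torch.distributed.checkpoint as dcp

def save_seed_checkpoint(model, path="seed_ckpt"):
    # Save the full, un-parallelized model state once (typically from 1 rank).
    dcp.save({"model": model.state_dict()}, checkpoint_id=path)

def load_stage_from_seed(stage_module, path="seed_ckpt"):
    # A pipeline stage holds only a subset of the parameters; DCP loads just
    # the keys present in the provided state_dict, matching them by name.
    state = {"model": stage_module.state_dict()}
    dcp.load(state, checkpoint_id=path)
    stage_module.load_state_dict(state["model"])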

ghstack-source-id: 4d6072eb3e8adc1431afa27fb552c72ba4c26967
Pull Request resolved: #345
wconstab committed May 18, 2024
1 parent edb25d6 commit a1aea74
Showing 2 changed files with 17 additions and 2 deletions.
15 changes: 15 additions & 0 deletions test_runner.py
@@ -40,6 +40,21 @@ def build_test_list(args):
"""
integration_tests_flavors = defaultdict(list)
integration_tests_flavors["debug_model.toml"] = [
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
f"--job.dump_folder {args.output_dir}/pp_1f1b_3stage/",
"--experimental.pipeline_parallel_degree 3",
"--experimental.pipeline_parallel_split_points layers.1, layers.2",
"--experimental.pipeline_parallel_schedule 1f1b",
"--training.data_parallel_degree 1",
],
],
"PP 1D test 1f1b with 3 PP stages",
requires_seed_checkpoint=True,
ngpu=3,
),
OverrideDefinitions(
[
[
4 changes: 2 additions & 2 deletions torchtitan/models/llama/__init__.py
@@ -12,7 +12,7 @@
 __all__ = ["Transformer"]
 
 llama2_configs = {
-    "debugmodel": ModelArgs(dim=256, n_layers=2, n_heads=16),
+    "debugmodel": ModelArgs(dim=256, n_layers=3, n_heads=16),
     "271M": ModelArgs(dim=1024, n_layers=16, n_heads=8),
     "1B": ModelArgs(dim=2048, n_layers=18, n_heads=16),
     "7B": ModelArgs(dim=4096, n_layers=32, n_heads=32),
@@ -29,7 +29,7 @@
 }
 
 llama3_configs = {
-    "debugmodel": ModelArgs(dim=256, n_layers=2, n_heads=16, rope_theta=500000),
+    "debugmodel": ModelArgs(dim=256, n_layers=3, n_heads=16, rope_theta=500000),
     "8B": ModelArgs(
         dim=4096,
         n_layers=32,
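
The debugmodel bump from 2 to 3 layers goes hand in hand with the split points in the new test: assuming each split point names the first module of the next stage (the convention the layers.1,layers.2 flags suggest), 3 stages need at least 3 transformer layers or the last stage ends up empty. A toy sketch of that partitioning, not the actual torchtitan splitting code:

# Toy sketch of FQN-based layer splitting, assuming each split point names the
# first layer of the next stage. Not the real torchtitan splitting logic.
def stages_from_split_points(n_layers: int, split_points: list[str]) -> list[list[int]]:
    boundaries = sorted(int(name.split(".")[-1]) for name in split_points)
    stages, start = [], 0
    for end in boundaries + [n_layers]:
        stages.append(list(range(start, end)))  # layer indices owned by this stage
        start = end
    return stages

print(stages_from_split_points(3, ["layers.1", "layers.2"]))  # [[0], [1], [2]]
print(stages_from_split_points(2, ["layers.1", "layers.2"]))  # [[0], [1], []]  <- empty last stage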
