Skip to content

Commit

Permalink
Add 3D support
Browse files Browse the repository at this point in the history
ghstack-source-id: 95635f146ab321cc7e6683106def62c15494890a
Pull Request resolved: #344
  • Loading branch information
wconstab committed May 20, 2024
1 parent 3b05ca9 commit 4c62d1f
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/unit_test_4gpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,4 @@ jobs:
python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
mkdir artifacts-to-be-uploaded
python ./test_runner.py artifacts-to-be-uploaded
python ./test_runner.py artifacts-to-be-uploaded --ngpu 4
35 changes: 35 additions & 0 deletions .github/workflows/unit_test_8gpu.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
name: 8 GPU Unit Test

on:
push:
branches: [ main ]
pull_request:

concurrency:
group: unit-test${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }}
cancel-in-progress: true

jobs:
build-test:
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
with:
runner: linux.g5.48xlarge.nvidia.gpu
gpu-arch-type: cuda
gpu-arch-version: "12.1"
# This image is faster to clone than the default, but it lacks CC needed by triton
# (1m25s vs 2m37s).
docker-image: torchtitan-ubuntu-20.04-clang12
repository: pytorch/torchtitan
upload-artifact: outputs
script: |
set -eux
# The generic Linux job chooses to use base env, not the one setup by the image
CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
conda activate "${CONDA_ENV}"
pip config --user set global.progress_bar off
python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
mkdir artifacts-to-be-uploaded
python ./test_runner.py artifacts-to-be-uploaded --ngpu 8
28 changes: 26 additions & 2 deletions test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ class OverrideDefinitions:
requires_seed_checkpoint: bool = False
ngpu: int = 4

def __repr__(self) -> str:
    """Return the human-readable test description so test flavors print
    legibly in logs and in the test runner's progress output."""
    return self.test_descr


def build_test_list(args):
"""
Expand Down Expand Up @@ -170,6 +173,22 @@ def build_test_list(args):
],
"Checkpoint Integration Test - Save Model Weights Only bf16",
),
# 8-GPU 3D-parallelism integration test: pipeline (2) x data (2) x tensor (2)
# parallel degrees multiply to the 8 ranks requested via ngpu=8 below.
OverrideDefinitions(
    [
        [
            "--checkpoint.enable_checkpoint",
            f"--job.dump_folder {args.output_dir}/pp_dp_tp/",
            "--experimental.pipeline_parallel_degree 2",
            # single split point -> two pipeline stages, cut after layers.1
            "--experimental.pipeline_parallel_split_points layers.1",
            "--training.data_parallel_degree 2",
            "--training.tensor_parallel_degree 2",
            "--model.norm_type rmsnorm",  # fused_rmsnorm not yet compatible with TP
        ],
    ],
    "PP+DP+TP 3D test",
    # NOTE(review): PP presumably needs a seed checkpoint so every stage
    # starts from identical weights — confirm against the PP docs/runner.
    requires_seed_checkpoint=True,
    ngpu=8,
),
]
return integration_tests_flavors

Expand All @@ -188,7 +207,8 @@ def run_test(test_flavor: OverrideDefinitions, full_path: str):
# run_test supports sequence of tests.
for override_arg in test_flavor.override_args:

cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK=0,1,2,3 ./run_llama_train.sh"
all_ranks = ",".join(map(str, range(test_flavor.ngpu)))
cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} ./run_llama_train.sh"
if override_arg:
cmd += " " + " ".join(override_arg)
print(
Expand Down Expand Up @@ -229,13 +249,17 @@ def run_tests(args):
)
if is_integration_test:
for test_flavor in integration_tests_flavors[config_file]:
run_test(test_flavor, full_path)
if (args.ngpu == 8 and test_flavor.ngpu == 8) or (
args.ngpu == 4 and test_flavor.ngpu <= 4
):
run_test(test_flavor, full_path)


def main():
parser = argparse.ArgumentParser()
parser.add_argument("output_dir")
parser.add_argument("--config_dir", default="./train_configs")
parser.add_argument("--ngpu", default=4, type=int)
args = parser.parse_args()

if not os.path.exists(args.output_dir):
Expand Down

0 comments on commit 4c62d1f

Please sign in to comment.