From 5c48f38b3e346ac8762bc4319abc1e28fb4eaafe Mon Sep 17 00:00:00 2001 From: Xilun Wu <12968408+XilunWu@users.noreply.github.com> Date: Wed, 4 Dec 2024 16:51:15 -0800 Subject: [PATCH] Update base for Update on "[cp] apply fsdp to model when CP is enabled without DP for correct loss and lower mem usage" **Summary** Previously CP forgot to shard the model via `apply_fsdp` when DP is not combined with CP. This leads to high peak memory usage and diverging loss. **Test** 1. modify `train_configs/llama3_8b.toml` ``` steps = 20 context_parallel_degree = 8 ``` 2. run training on 8xH100 GPUs `CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` Before: CUDA OutOfMemory After: successful 20-steps training [ghstack-poisoned] --- torchtitan/parallelisms/parallel_dims.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchtitan/parallelisms/parallel_dims.py b/torchtitan/parallelisms/parallel_dims.py index 87d982fa..f609e6f5 100644 --- a/torchtitan/parallelisms/parallel_dims.py +++ b/torchtitan/parallelisms/parallel_dims.py @@ -76,7 +76,8 @@ def build_mesh(self, device_type): if self.dp_shard_enabled: dp_mesh_dim_names.append("dp_shard") - mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") + if dp_mesh_dim_names != []: + mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") if self.cp > 1: if self.dp_replicate > 1 and self.dp_shard > 1: # HSDP