From 9955242f8d30153b02753acbf2327465f3bce751 Mon Sep 17 00:00:00 2001
From: Xilun Wu <12968408+XilunWu@users.noreply.github.com>
Date: Wed, 4 Dec 2024 17:23:16 -0800
Subject: [PATCH] Update base for Update on "[cp] apply fsdp to model when CP
 is enabled without DP for correct loss and lower mem usage"

**Summary**
Previously, when CP was enabled without DP, the model was not sharded via `apply_fsdp`. This led to high peak memory usage and a diverging loss.

**Test**
1. Modify `train_configs/llama3_8b.toml`:
```
steps = 20
context_parallel_degree = 8
```
2. Run training on 8x H100 GPUs:
`CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh`

Before: CUDA OutOfMemory error
After: training completes all 20 steps successfully

[ghstack-poisoned]
---
 torchtitan/parallelisms/parallel_dims.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchtitan/parallelisms/parallel_dims.py b/torchtitan/parallelisms/parallel_dims.py
index f609e6f5..15ec8fbd 100644
--- a/torchtitan/parallelisms/parallel_dims.py
+++ b/torchtitan/parallelisms/parallel_dims.py
@@ -83,7 +83,7 @@ def build_mesh(self, device_type):
 
         if self.dp_replicate > 1 and self.dp_shard > 1:  # HSDP
             mesh["dp_replicate", "dp_shard", "cp"]._flatten(mesh_dim_name="dp_cp")
         elif self.dp_shard > 1:  # FSDP
-            mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")
+            mesh["dp_shard", "cp"]._flatten(mesh_dim_name="dp_cp")
 
         return mesh
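
For context, here is a minimal sketch (not part of the patch) of the mesh layout this one-line change produces, standalone rather than inside torchtitan. It assumes an 8-GPU job launched with `torchrun --nproc_per_node=8`, a PyTorch version that provides the private `DeviceMesh._flatten` helper the patched code already calls, and hypothetical degrees `dp_shard=2, cp=4` chosen so the FSDP branch of `build_mesh` is exercised; the test above instead sets `context_parallel_degree = 8` through torchtitan's config.

```python
# Minimal sketch, not part of the patch: build a 2D (dp_shard, cp) mesh and
# flatten it into the "dp_cp" dim that FSDP is expected to shard over.
# Assumes 8 GPUs, launched with: torchrun --nproc_per_node=8 sketch.py
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dp_shard, cp = 2, 4  # hypothetical degrees; the test config above uses cp=8

# Mirrors the mesh_dim_names used in ParallelDims.build_mesh.
mesh = init_device_mesh(
    "cuda", (dp_shard, cp), mesh_dim_names=("dp_shard", "cp")
)

# The fixed line: flatten the sharding dim together with the CP dim, so that
# applying FSDP over the flattened "dp_cp" dim shards parameters across both
# dp_shard and cp ranks (lower peak memory) and reduces gradients across all
# of them (consistent loss).
mesh["dp_shard", "cp"]._flatten(mesh_dim_name="dp_cp")

if dist.get_rank() == 0:
    print(mesh)
```

This mirrors what the HSDP branch already does for `["dp_replicate", "dp_shard", "cp"]`; the visible change is only that the FSDP branch now flattens the dim by its actual name, `"dp_shard"`, instead of `"dp"`.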