Update base for Update on "[cp] apply fsdp to model when CP is enabled without DP for correct loss and lower mem usage"


**Summary**
Previously, when CP was enabled without DP, the model was not sharded via `apply_fsdp`. This led to high peak memory usage and a diverging loss.
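
To make the intent concrete, here is a minimal standalone sketch, not torchtitan's implementation: it assumes 8 GPUs all devoted to context parallelism, a `torchrun` launch, and a recent PyTorch that provides FSDP2's `fully_shard` and `init_device_mesh`. The point is that even with no data parallelism, the model still needs to be sharded over the CP ranks.

```python
import os
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh

# Each rank binds to its own GPU; launch with: torchrun --nproc_per_node=8 sketch.py
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

cp_degree = 8  # assumption: all 8 GPUs are used for context parallelism, no DP
mesh = init_device_mesh("cuda", (cp_degree,), mesh_dim_names=("cp",))

model = nn.Linear(1024, 1024, device="cuda")
# Shard parameters across the CP ranks with FSDP2's fully_shard; without this,
# every CP rank keeps a full replica of the model, which is what caused the
# high peak memory (and the diverging loss) described above.
fully_shard(model, mesh=mesh)
```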

**Test**
1. Modify `train_configs/llama3_8b.toml`:
```
steps = 20
context_parallel_degree = 8
```
2. Run training on 8x H100 GPUs:
`CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh`
Before: CUDA out-of-memory error
After: 20 training steps complete successfully


[ghstack-poisoned]
XilunWu committed Dec 5, 2024
1 parent 5c48f38 commit 9955242
Showing 1 changed file with 1 addition and 1 deletion.
`torchtitan/parallelisms/parallel_dims.py` (1 addition, 1 deletion):

```diff
@@ -83,7 +83,7 @@ def build_mesh(self, device_type):
         if self.dp_replicate > 1 and self.dp_shard > 1:  # HSDP
             mesh["dp_replicate", "dp_shard", "cp"]._flatten(mesh_dim_name="dp_cp")
         elif self.dp_shard > 1:  # FSDP
-            mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")
+            mesh["dp_shard", "cp"]._flatten(mesh_dim_name="dp_cp")
 
         return mesh
```
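
For reference, a minimal standalone sketch of the mesh flattening the diff touches, assuming 8 GPUs split as dp_shard=4 and cp=2, a `torchrun` launch, and a PyTorch version that supports multi-dim `DeviceMesh` slicing and the private `_flatten()` used above. The dim names passed to the slice must match the names the mesh was actually built with, which is what the one-line fix corrects.

```python
import os
import torch
from torch.distributed.device_mesh import init_device_mesh

# Launch with: torchrun --nproc_per_node=8 sketch.py
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# Build a 2-D mesh whose dims are literally named "dp_shard" and "cp".
mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp_shard", "cp"))

# Select both dims by their real names and flatten them into a single "dp_cp"
# dim, so collectives (e.g. the loss all-reduce) span both the FSDP shard
# ranks and the context-parallel ranks. Indexing by a dim name the mesh was
# not built with (e.g. "dp" when only "dp_shard" exists) would not refer to
# the intended group.
dp_cp_mesh = mesh["dp_shard", "cp"]._flatten(mesh_dim_name="dp_cp")
print(dp_cp_mesh.size())  # 8: all ranks participate in the flattened group
```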

