Skip to content

Commit

Permalink
Update on "[cp] apply fsdp to model when CP is enabled without DP for…
Browse files Browse the repository at this point in the history
… correct loss and lower mem usage"


**Summary**
Previously CP forgot to shard the model via `apply_fsdp` when DP is not combined with CP. This leads to high peak memory usage and diverging loss.

**Test**
1. modify `train_configs/llama3_8b.toml`
```
steps = 20
context_parallel_degree = 8
```
2.  run training on 8xH100 GPUs
`CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh`
Before: CUDA OutOfMemory 
After: successful 20-steps training


[ghstack-poisoned]
  • Loading branch information
XilunWu committed Dec 5, 2024
2 parents b168683 + 5c48f38 commit 82f0cd6
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion torchtitan/parallelisms/parallel_dims.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ def build_mesh(self, device_type):
if self.dp_shard_enabled:
dp_mesh_dim_names.append("dp_shard")

mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp")
if dp_mesh_dim_names != []:
mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp")

# Mesh for param sharding
dp_shard_cp_mesh_dim_name = []
Expand Down

0 comments on commit 82f0cd6

Please sign in to comment.