Update on "[cp] apply fsdp to model when CP is enabled without DP for…
… correct loss and lower mem usage" **Summary** Previously CP forgot to shard the model via `apply_fsdp` when DP is not combined with CP. This leads to high peak memory usage and diverging loss. **Test** 1. modify `train_configs/llama3_8b.toml` ``` steps = 20 context_parallel_degree = 8 ``` 2. run training on 8xH100 GPUs `CONFIG_FILE="./train_configs/llama3_8b.toml" NGPU=8 LOG_RANK=0,1,2,3,4,5,6,7 ./run_llama_train.sh` Before: CUDA OutOfMemory After: successful 20-steps training [ghstack-poisoned]
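The fix is roughly of the following shape: apply FSDP whenever either DP or CP is enabled, rather than only when DP is on. This is a minimal sketch, not the actual torchtitan diff; `parallel_dims`, `world_mesh`, the `apply_fsdp` helper, and the `model.layers` container are assumed, illustrative names.

```python
# Minimal sketch (assumed names, not the actual torchtitan code): shard the
# model with FSDP whenever either DP or CP is enabled, so the CP-only case no
# longer keeps a full replica of the parameters on every rank.
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard
from torch.distributed.device_mesh import DeviceMesh


def apply_fsdp(model: nn.Module, mesh: DeviceMesh) -> None:
    """Shard each transformer block, then the root module, on `mesh`."""
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16, reduce_dtype=torch.float32
    )
    for block in model.layers:  # assumed: blocks live in an nn.ModuleList
        fully_shard(block, mesh=mesh, mp_policy=mp_policy)
    fully_shard(model, mesh=mesh, mp_policy=mp_policy)


def parallelize(model: nn.Module, world_mesh: DeviceMesh, parallel_dims) -> nn.Module:
    # Before the fix, the condition was effectively `if parallel_dims.dp_enabled:`,
    # so a CP-only run skipped sharding entirely -> OOM and diverging loss.
    if parallel_dims.dp_enabled or parallel_dims.cp_enabled:
        # Shard over whichever data-like mesh dimension is in use; when DP and CP
        # are combined, they are assumed to be flattened into one mesh dim upstream.
        mesh_name = "dp" if parallel_dims.dp_enabled else "cp"
        apply_fsdp(model, world_mesh[mesh_name])
    return model
```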