diff --git a/torchtrain/parallelisms/parallelize_llama.py b/torchtrain/parallelisms/parallelize_llama.py
index d6db313fd..1ef4dd99b 100644
--- a/torchtrain/parallelisms/parallelize_llama.py
+++ b/torchtrain/parallelisms/parallelize_llama.py
@@ -185,4 +185,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, args):
 
         rank0_log("Applied FSDP to the model...")
 
+    # redundant if FSDP is used, but ensures the model is on device consistently regardless of which parallelisms were used
+    model.cuda()
     return model
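
Note: a minimal sketch of why the trailing model.cuda() matters, assuming a run where no
parallelism is enabled. The stand-in function parallelize_noop and the toy nn.Linear model
are illustrative assumptions, not part of the repo; the FSDP path already moves sharded
parameters to GPU, so this call only changes behavior on the no-parallelism path.

    import torch
    import torch.nn as nn

    def parallelize_noop(model, world_mesh=None, parallel_dims=None, args=None):
        # Stand-in for parallelize_llama when neither FSDP nor TP is applied:
        # without the explicit .cuda() the model would remain on CPU here,
        # while the FSDP branch would already have placed parameters on GPU.
        model.cuda()
        return model

    if torch.cuda.is_available():
        model = parallelize_noop(nn.Linear(8, 8))
        # Callers can now allocate inputs with device="cuda" unconditionally,
        # since the model is on device regardless of which path was taken.
        assert next(model.parameters()).is_cuda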