diff --git a/torchtrain/parallelisms/parallelize_llama.py b/torchtrain/parallelisms/parallelize_llama.py
index d4950ef6c..c8e3adb3f 100644
--- a/torchtrain/parallelisms/parallelize_llama.py
+++ b/torchtrain/parallelisms/parallelize_llama.py
@@ -73,6 +73,9 @@ def parallelize_llama(model, world_mesh, parallel_dims, args):
             # wrap the rest layers with FSDP
             model = wrap(model.cuda())
 
+    # redundant if FSDP is used, but ensure the model is on device consistently, regardless of which parallelisms were used
+    model.cuda()
+
     rank0_log("Applied parallelisms to the model...")
 
     return model