diff --git a/torchtrain/parallelisms/parallelize_llama.py b/torchtrain/parallelisms/parallelize_llama.py
index d6db313f..399533e2 100644
--- a/torchtrain/parallelisms/parallelize_llama.py
+++ b/torchtrain/parallelisms/parallelize_llama.py
@@ -185,4 +185,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, args):
 
         rank0_log("Applied FSDP to the model...")
 
+    # redundant if FSDP is enabled, but ensure the model is on device regardless of which parallelisms were used
+    model.cuda()
     return model
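
For context, a minimal sketch of the device-placement behavior the added lines guarantee. This is not from the PR: the nn.Linear is a hypothetical stand-in for the Llama model, and parallelize_llama itself is never called. The idea is that FSDP already moves its sharded parameters onto the GPU, so the extra .cuda() is a no-op on that path, but when only tensor parallelism (or no parallelism) is applied, nothing else moves the module off the CPU.

# Sketch only, assuming a toy module in place of the real model.
import torch
import torch.nn as nn

model = nn.Linear(16, 16)  # freshly constructed modules live on CPU
assert next(model.parameters()).device.type == "cpu"

if torch.cuda.is_available():
    model.cuda()  # the pattern the diff adds just before `return model`
    assert next(model.parameters()).device.type == "cuda"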