From 6dbfb3c85c174fe14365546eb6abf3d41086cd96 Mon Sep 17 00:00:00 2001 From: muhtasham Date: Sat, 7 Dec 2024 22:22:17 +0100 Subject: [PATCH] remove unused tp param --- torchtitan/parallelisms/parallelize_llama.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 55047c92..150312f5 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -118,7 +118,6 @@ def parallelize_llama( dp_mesh, param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], - tp_enabled=parallel_dims.tp_enabled, pp_enabled=parallel_dims.pp_enabled, cpu_offload=job_config.training.enable_cpu_offload, ) @@ -338,7 +337,6 @@ def apply_fsdp( dp_mesh: DeviceMesh, param_dtype: torch.dtype, reduce_dtype: torch.dtype, - tp_enabled: bool, pp_enabled: bool, cpu_offload: bool = False, ):