diff --git a/scripts/estimate/estimation.py b/scripts/estimate/estimation.py index 09b1ce4f..c4d1e9c9 100644 --- a/scripts/estimate/estimation.py +++ b/scripts/estimate/estimation.py @@ -70,7 +70,7 @@ def estimate_memory(job_config: JobConfig): tp=job_config.training.tensor_parallel_degree, pp=job_config.experimental.pipeline_parallel_degree, world_size=world_size, - enable_loss_parallel=job_config.training.enable_loss_parallel, + enable_loss_parallel=not job_config.training.disable_loss_parallel, ) device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")