From f6a9daa13e04f01adde911ab412c617e8dee6a9d Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Fri, 20 Dec 2024 15:57:01 -0800 Subject: [PATCH] [BE] improve argparser ghstack-source-id: b3fe728d567c7204ba2e9be4d8b34ae850f00ac0 Pull Request resolved: https://github.com/pytorch/torchtitan/pull/759 --- README.md | 2 +- scripts/estimate/estimation.py | 2 +- tests/integration_tests.py | 2 +- torchtitan/config_manager.py | 27 ++++++++------------ torchtitan/parallelisms/parallelize_llama.py | 3 +++ train.py | 7 +++-- train_configs/debug_model.toml | 3 ++- 7 files changed, 23 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index 82469f4d..d8c75efc 100644 --- a/README.md +++ b/README.md @@ -49,7 +49,7 @@ You may want to see how the model is defined or how parallelism techniques are a 1. Multi-dimensional composable parallelisms - [FSDP2](docs/fsdp.md) with per-parameter sharding - [Tensor Parallel](https://pytorch.org/docs/stable/distributed.tensor.parallel.html) (including [async TP](https://discuss.pytorch.org/t/distributed-w-torchtitan-introducing-async-tensor-parallelism-in-pytorch/209487)) - - Pipeline Parallel + - [Pipeline Parallel](https://discuss.pytorch.org/t/distributed-w-torchtitan-training-with-zero-bubble-pipeline-parallelism/214420) - Context Parallel 2. Selective layer and operator activation checkpointing 3. [Distributed checkpointing](https://discuss.pytorch.org/t/distributed-w-torchtitan-optimizing-checkpointing-efficiency-with-pytorch-dcp/211250) (including async checkpointing) diff --git a/scripts/estimate/estimation.py b/scripts/estimate/estimation.py index 09b1ce4f..c4d1e9c9 100644 --- a/scripts/estimate/estimation.py +++ b/scripts/estimate/estimation.py @@ -70,7 +70,7 @@ def estimate_memory(job_config: JobConfig): tp=job_config.training.tensor_parallel_degree, pp=job_config.experimental.pipeline_parallel_degree, world_size=world_size, - enable_loss_parallel=job_config.training.enable_loss_parallel, + enable_loss_parallel=not job_config.training.disable_loss_parallel, ) device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") diff --git a/tests/integration_tests.py b/tests/integration_tests.py index 6ef7fe65..41c9d209 100755 --- a/tests/integration_tests.py +++ b/tests/integration_tests.py @@ -367,7 +367,7 @@ def build_test_list(): [ "--checkpoint.enable_checkpoint", "--experimental.pipeline_parallel_degree 2", - "--training.enable_cpu_offload True", + "--training.enable_cpu_offload", "--optimizer.early_step_in_backward", ], ], diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index b92d00a4..d59e34bc 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -78,10 +78,14 @@ def __init__(self): ) self.parser.add_argument( "--job.use_for_integration_test", - default=False, action="store_true", help="Add this config to the integration test suite", ) + self.parser.add_argument( + "--job.print_args", + action="store_true", + help="Print the args to terminal", + ) # profiling configs self.parser.add_argument( @@ -104,7 +108,6 @@ def __init__(self): self.parser.add_argument( "--profiling.enable_memory_snapshot", action="store_true", - default=False, help="Whether to dump memory snapshot", ) self.parser.add_argument( @@ -124,14 +127,12 @@ def __init__(self): self.parser.add_argument( "--metrics.enable_tensorboard", action="store_true", - default=False, help="Whether to log metrics to TensorBoard", ) self.parser.add_argument( - "--metrics.enable_color_printing", + "--metrics.disable_color_printing", 
action="store_true", - default=True, - help="Whether to enable color printing in logs", + help="Whether to disable color printing in logs", ) self.parser.add_argument( "--metrics.save_tb_folder", @@ -139,6 +140,7 @@ def __init__(self): default="tb", help="Folder to dump TensorBoard states", ) + # TODO: store_true & default=True make impossible for cmd to set it to False self.parser.add_argument( "--metrics.rank_0_only", action="store_true", @@ -152,7 +154,6 @@ def __init__(self): self.parser.add_argument( "--metrics.enable_wandb", action="store_true", - default=False, help="Whether to log metrics to Weights & Biases", ) @@ -191,13 +192,11 @@ def __init__(self): ) self.parser.add_argument( "--optimizer.fused", - default=False, action="store_true", help="Whether the fused implementation(CUDA only) is used.", ) self.parser.add_argument( "--optimizer.early_step_in_backward", - default=False, action="store_true", help=""" Whether to apply optimizer in the backward. Caution, optimizer_in_backward @@ -270,8 +269,7 @@ def __init__(self): ) self.parser.add_argument( "--training.enable_cpu_offload", - type=bool, - default=False, + action="store_true", help=""" Whether to apply CPU offloading of parameters, gradients, and optimizer states in FSDP""", ) @@ -282,14 +280,12 @@ def __init__(self): help="Tensor Parallelism degree. 1 means disabled.", ) self.parser.add_argument( - "--training.enable_loss_parallel", - default=True, + "--training.disable_loss_parallel", action="store_true", help="Whether to apply loss parallel when sequence parallel is enabled", ) self.parser.add_argument( "--experimental.enable_async_tensor_parallel", - default=False, action="store_true", help="Whether to apply async tensor parallel (currently only effective when compile is enabled)", ) @@ -545,13 +541,11 @@ def __init__(self): self.parser.add_argument( "--float8.enable_fsdp_float8_all_gather", action="store_true", - default=False, help="Whether enable float8 all-gather in FSDP", ) self.parser.add_argument( "--float8.precompute_float8_dynamic_scale_for_fsdp", action="store_true", - default=False, help="Whether precompute float8 scales dynamically for FSDP", ) self.parser.add_argument( @@ -607,7 +601,6 @@ def __init__(self): self.parser.add_argument( "--memory_estimation.disable_fake_mode", help="Whether to estimate memory under FakeTensorMode", - default=False, action="store_true", ) diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 53f00686..fce22c48 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -107,6 +107,9 @@ def parallelize_llama( if parallel_dims.cp_enabled: logger.info("Applied Context Parallel to the model") + + if job_config.training.enable_cpu_offload: + logger.info("Applied CPU Offloading to the model") elif parallel_dims.dp_replicate_enabled: if world_mesh.ndim > 1: raise RuntimeError("DDP has not supported > 1D parallelism") diff --git a/train.py b/train.py index 8dbe80f5..0b69690a 100644 --- a/train.py +++ b/train.py @@ -36,8 +36,11 @@ def main(job_config: JobConfig): init_logger() logger.info(f"Starting job: {job_config.job.description}") + if job_config.job.print_args: + logger.info(f"Running with args: {job_config.to_dict()}") + # used for colorful printing - color = utils.Color if job_config.metrics.enable_color_printing else utils.NoColor + color = utils.NoColor if job_config.metrics.disable_color_printing else utils.Color # take control of garbage collection to avoid stragglers 
gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq) @@ -51,7 +54,7 @@ def main(job_config: JobConfig): tp=job_config.training.tensor_parallel_degree, pp=job_config.experimental.pipeline_parallel_degree, world_size=world_size, - enable_loss_parallel=job_config.training.enable_loss_parallel, + enable_loss_parallel=not job_config.training.disable_loss_parallel, ) device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}") device_module.set_device(device) diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index 07fcd338..733bc0ae 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -3,6 +3,7 @@ [job] dump_folder = "./outputs" description = "Llama 3 debug training" +print_args = false use_for_integration_test = true [profiling] @@ -14,7 +15,7 @@ save_memory_snapshot_folder = "memory_snapshot" [metrics] log_freq = 1 -enable_color_printing = true +disable_color_printing = false enable_tensorboard = false save_tb_folder = "tb" enable_wandb = false
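
Note on the argparse cleanup above: most of the diff removes redundant default=False from action="store_true" arguments (store_true already defaults to False) and flips always-true enable_* options (--metrics.enable_color_printing, --training.enable_loss_parallel) into disable_* options, since a store_true argument whose default is True can never be switched off from the command line; the one remaining case is called out by the new TODO on --metrics.rank_0_only. Below is a minimal standalone sketch of that behavior and of the negation used at the call sites in train.py and estimation.py; the flag names are simplified stand-ins for the real --metrics.* / --training.* options, not torchtitan code.

    # Illustrative sketch only; flag names are simplified stand-ins.
    import argparse

    parser = argparse.ArgumentParser()

    # Problem: store_true combined with default=True cannot be turned off from
    # the CLI. Passing the flag sets True, omitting it leaves the default True.
    parser.add_argument("--rank_0_only", action="store_true", default=True)

    # Pattern adopted by this patch: expose the negated flag (defaulting to
    # False) and negate it wherever the positive meaning is needed.
    parser.add_argument("--disable_loss_parallel", action="store_true")

    args = parser.parse_args([])             # no flags passed
    assert args.rank_0_only is True          # no CLI spelling can make this False
    enable_loss_parallel = not args.disable_loss_parallel
    assert enable_loss_parallel is True      # enabled by default; pass the flag to disable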
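
A related gotcha is likely behind the --training.enable_cpu_offload change from type=bool to action="store_true" (and the matching integration-test change from "--training.enable_cpu_offload True" to a bare "--training.enable_cpu_offload"): argparse applies type=bool to the raw string value, and any non-empty string, including "False", is truthy. A small before/after sketch, again with a simplified flag name standing in for the real option:

    # Illustrative sketch only; "--enable_cpu_offload" stands in for
    # torchtitan's "--training.enable_cpu_offload".
    import argparse

    # Old style: type=bool calls bool() on the string, so "False" parses as True.
    old = argparse.ArgumentParser()
    old.add_argument("--enable_cpu_offload", type=bool, default=False)
    assert old.parse_args(["--enable_cpu_offload", "False"]).enable_cpu_offload is True

    # New style: a plain presence flag, passed without a value.
    new = argparse.ArgumentParser()
    new.add_argument("--enable_cpu_offload", action="store_true")
    assert new.parse_args([]).enable_cpu_offload is False
    assert new.parse_args(["--enable_cpu_offload"]).enable_cpu_offload is True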