[BE] improve argparser
ghstack-source-id: b3fe728d567c7204ba2e9be4d8b34ae850f00ac0
Pull Request resolved: #759
tianyu-l committed Dec 23, 2024
1 parent ba24697 commit f6a9daa
Showing 7 changed files with 23 additions and 23 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -49,7 +49,7 @@ You may want to see how the model is defined or how parallelism techniques are a
1. Multi-dimensional composable parallelisms
- [FSDP2](docs/fsdp.md) with per-parameter sharding
- [Tensor Parallel](https://pytorch.org/docs/stable/distributed.tensor.parallel.html) (including [async TP](https://discuss.pytorch.org/t/distributed-w-torchtitan-introducing-async-tensor-parallelism-in-pytorch/209487))
- - Pipeline Parallel
+ - [Pipeline Parallel](https://discuss.pytorch.org/t/distributed-w-torchtitan-training-with-zero-bubble-pipeline-parallelism/214420)
- Context Parallel
2. Selective layer and operator activation checkpointing
3. [Distributed checkpointing](https://discuss.pytorch.org/t/distributed-w-torchtitan-optimizing-checkpointing-efficiency-with-pytorch-dcp/211250) (including async checkpointing)
2 changes: 1 addition & 1 deletion scripts/estimate/estimation.py
@@ -70,7 +70,7 @@ def estimate_memory(job_config: JobConfig):
tp=job_config.training.tensor_parallel_degree,
pp=job_config.experimental.pipeline_parallel_degree,
world_size=world_size,
- enable_loss_parallel=job_config.training.enable_loss_parallel,
+ enable_loss_parallel=not job_config.training.disable_loss_parallel,
)

device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
2 changes: 1 addition & 1 deletion tests/integration_tests.py
@@ -367,7 +367,7 @@ def build_test_list():
[
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 2",
"--training.enable_cpu_offload True",
"--training.enable_cpu_offload",
"--optimizer.early_step_in_backward",
],
],
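The dropped "True" value in this test follows from the argparser change below: --training.enable_cpu_offload becomes an argparse store_true switch, which takes no value, replacing the old type=bool form (where bool("False") is still truthy). A minimal sketch of the behavior, using a stand-alone parser rather than torchtitan's JobConfig:

import argparse

parser = argparse.ArgumentParser()

# Old pattern: type=bool converts the raw string, and bool("False") is True,
# so "--enable_cpu_offload False" would still turn offloading on.
# parser.add_argument("--enable_cpu_offload", type=bool, default=False)

# New pattern: a value-less switch; present means True, absent means False.
parser.add_argument("--enable_cpu_offload", action="store_true")

print(parser.parse_args([]).enable_cpu_offload)                        # False
print(parser.parse_args(["--enable_cpu_offload"]).enable_cpu_offload)  # True
# parser.parse_args(["--enable_cpu_offload", "True"]) now errors out:
# store_true options do not accept a value.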
27 changes: 10 additions & 17 deletions torchtitan/config_manager.py
@@ -78,10 +78,14 @@ def __init__(self):
)
self.parser.add_argument(
"--job.use_for_integration_test",
- default=False,
action="store_true",
help="Add this config to the integration test suite",
)
+ self.parser.add_argument(
+ "--job.print_args",
+ action="store_true",
+ help="Print the args to terminal",
+ )

# profiling configs
self.parser.add_argument(
@@ -104,7 +108,6 @@ def __init__(self):
self.parser.add_argument(
"--profiling.enable_memory_snapshot",
action="store_true",
- default=False,
help="Whether to dump memory snapshot",
)
self.parser.add_argument(
@@ -124,21 +127,20 @@ def __init__(self):
self.parser.add_argument(
"--metrics.enable_tensorboard",
action="store_true",
- default=False,
help="Whether to log metrics to TensorBoard",
)
self.parser.add_argument(
"--metrics.enable_color_printing",
"--metrics.disable_color_printing",
action="store_true",
- default=True,
- help="Whether to enable color printing in logs",
+ help="Whether to disable color printing in logs",
)
self.parser.add_argument(
"--metrics.save_tb_folder",
type=str,
default="tb",
help="Folder to dump TensorBoard states",
)
+ # TODO: store_true & default=True make impossible for cmd to set it to False
self.parser.add_argument(
"--metrics.rank_0_only",
action="store_true",
@@ -152,7 +154,6 @@ def __init__(self):
self.parser.add_argument(
"--metrics.enable_wandb",
action="store_true",
- default=False,
help="Whether to log metrics to Weights & Biases",
)

@@ -191,13 +192,11 @@ def __init__(self):
)
self.parser.add_argument(
"--optimizer.fused",
- default=False,
action="store_true",
help="Whether the fused implementation(CUDA only) is used.",
)
self.parser.add_argument(
"--optimizer.early_step_in_backward",
- default=False,
action="store_true",
help="""
Whether to apply optimizer in the backward. Caution, optimizer_in_backward
@@ -270,8 +269,7 @@ def __init__(self):
)
self.parser.add_argument(
"--training.enable_cpu_offload",
- type=bool,
- default=False,
+ action="store_true",
help="""
Whether to apply CPU offloading of parameters, gradients, and optimizer states in FSDP""",
)
@@ -282,14 +280,12 @@ def __init__(self):
help="Tensor Parallelism degree. 1 means disabled.",
)
self.parser.add_argument(
"--training.enable_loss_parallel",
default=True,
"--training.disable_loss_parallel",
action="store_true",
help="Whether to apply loss parallel when sequence parallel is enabled",
)
self.parser.add_argument(
"--experimental.enable_async_tensor_parallel",
- default=False,
action="store_true",
help="Whether to apply async tensor parallel (currently only effective when compile is enabled)",
)
@@ -545,13 +541,11 @@ def __init__(self):
self.parser.add_argument(
"--float8.enable_fsdp_float8_all_gather",
action="store_true",
- default=False,
help="Whether enable float8 all-gather in FSDP",
)
self.parser.add_argument(
"--float8.precompute_float8_dynamic_scale_for_fsdp",
action="store_true",
- default=False,
help="Whether precompute float8 scales dynamically for FSDP",
)
self.parser.add_argument(
@@ -607,7 +601,6 @@ def __init__(self):
self.parser.add_argument(
"--memory_estimation.disable_fake_mode",
help="Whether to estimate memory under FakeTensorMode",
- default=False,
action="store_true",
)

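Most of the edits in config_manager.py apply two argparse facts that the new TODO comment above also records: action="store_true" already implies default=False, so the explicit defaults were redundant, and combining store_true with default=True makes an option impossible to turn off from the command line, which is why default-on flags are inverted into disable_* switches. A small illustration with stand-in flag names (not torchtitan's actual JobConfig):

import argparse

parser = argparse.ArgumentParser()

# Broken pattern: the switch can only ever set the value to True, and it
# already starts at True, so no command line can produce False.
parser.add_argument("--enable_color_printing", action="store_true", default=True)

# Fixed pattern: the default stays "color on"; passing the switch turns it off.
parser.add_argument("--disable_color_printing", action="store_true")

args = parser.parse_args([])
print(args.enable_color_printing)       # True, and never False from the CLI
print(not args.disable_color_printing)  # True: color on by default

args = parser.parse_args(["--disable_color_printing"])
print(not args.disable_color_printing)  # False: color off, now expressible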
3 changes: 3 additions & 0 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -107,6 +107,9 @@ def parallelize_llama(

if parallel_dims.cp_enabled:
logger.info("Applied Context Parallel to the model")

+ if job_config.training.enable_cpu_offload:
+ logger.info("Applied CPU Offloading to the model")
elif parallel_dims.dp_replicate_enabled:
if world_mesh.ndim > 1:
raise RuntimeError("DDP has not supported > 1D parallelism")
7 changes: 5 additions & 2 deletions train.py
@@ -36,8 +36,11 @@ def main(job_config: JobConfig):
init_logger()
logger.info(f"Starting job: {job_config.job.description}")

+ if job_config.job.print_args:
+ logger.info(f"Running with args: {job_config.to_dict()}")

# used for colorful printing
- color = utils.Color if job_config.metrics.enable_color_printing else utils.NoColor
+ color = utils.NoColor if job_config.metrics.disable_color_printing else utils.Color

# take control of garbage collection to avoid stragglers
gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq)
@@ -51,7 +54,7 @@ def main(job_config: JobConfig):
tp=job_config.training.tensor_parallel_degree,
pp=job_config.experimental.pipeline_parallel_degree,
world_size=world_size,
- enable_loss_parallel=job_config.training.enable_loss_parallel,
+ enable_loss_parallel=not job_config.training.disable_loss_parallel,
)
device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
device_module.set_device(device)
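With the flags inverted, the call sites in estimation.py and train.py negate the new switches, so the default behavior (loss parallel on, color printing on) is unchanged; only the opt-out spelling moves to the disable_* form. A runnable stand-in for the parsed config, purely illustrative:

from types import SimpleNamespace

# Hypothetical stand-in for job_config.training after parsing; by default the
# disable switch is not passed, i.e. False.
training = SimpleNamespace(disable_loss_parallel=False)

# Old call site: enable_loss_parallel=job_config.training.enable_loss_parallel
# New call site: negate the disable switch, keeping loss parallel on by default.
enable_loss_parallel = not training.disable_loss_parallel
print(enable_loss_parallel)  # True unless --training.disable_loss_parallel is passed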
3 changes: 2 additions & 1 deletion train_configs/debug_model.toml
@@ -3,6 +3,7 @@
[job]
dump_folder = "./outputs"
description = "Llama 3 debug training"
+ print_args = false
use_for_integration_test = true

[profiling]
@@ -14,7 +15,7 @@ save_memory_snapshot_folder = "memory_snapshot"

[metrics]
log_freq = 1
- enable_color_printing = true
+ disable_color_printing = false
enable_tensorboard = false
save_tb_folder = "tb"
enable_wandb = false
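The TOML defaults mirror the renamed flags: disable_color_printing = false keeps color printing on (matching the old enable_color_printing = true), and the new print_args key defaults to off. A config file can still set a store_true option to either value if the loader assigns TOML entries onto the parsed namespace; the sketch below shows that idea with simplified flag names and is not torchtitan's actual loading code:

import argparse
import tomllib  # Python 3.11+; older interpreters can use the tomli package

parser = argparse.ArgumentParser()
parser.add_argument("--disable_color_printing", action="store_true")
parser.add_argument("--print_args", action="store_true")
args = parser.parse_args([])  # the CLI can only flip these to True

config = tomllib.loads("""
[metrics]
disable_color_printing = false
[job]
print_args = false
""")
for table in config.values():
    for key, value in table.items():
        setattr(args, key, value)  # a file can set either True or False

print(args.disable_color_printing, args.print_args)  # False False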