Update on "Configure RNGs appropriately for Pipeline + SPMD"
DTensor has existing RNG management. It requires a shared seed for every
rank in its 'world' (the SPMD world). It then manages per-rank offsets
with its own RNG tracker, so ranks produce the same or different random
values depending on the device mesh and the type of sharding of the
operation being performed. (TODO: link to docs)

When used together with pipeline parallelism, it is important to use a
different seed for each separate SPMD world. E.g. if the user specifies
seed 1234, we can use 1234 as-is for all ranks on PP rank 0, but ranks on
PP rank 1 should use a different seed (e.g. 1234 + 1). This partitions
the world into PP separate SPMD worlds and gives each SPMD world a unique
seed, as sketched below.
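
For illustration only, a minimal sketch of that seed partitioning (the
helper `seed_for_pp_stage` and the variable names are hypothetical, not
the code in this commit; only standard `torch` calls are used):

```python
import torch

def seed_for_pp_stage(base_seed: int, pp_rank: int) -> int:
    # Every rank within one SPMD world (one pipeline stage) shares this seed;
    # DTensor's RNG tracker then applies its own per-rank offsets within that world.
    return base_seed + pp_rank

# With a user-specified seed of 1234:
#   ranks on PP rank 0 seed with 1234, ranks on PP rank 1 with 1235, and so on.
torch.manual_seed(seed_for_pp_stage(1234, pp_rank=0))
```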

Control 'deterministic' mode separately from rng seed

'Deterministic' mode is mostly useful for debugging, while users may want
to control the RNG seed for real runs, so the two are configured
independently.
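
A minimal sketch of that separation, for illustration only (the function
name `configure_rng` is hypothetical, not torchtitan's API):

```python
import torch

def configure_rng(seed: int, deterministic: bool) -> None:
    # Seeding always happens, using the user-provided seed.
    torch.manual_seed(seed)
    # Deterministic algorithms are an independent, opt-in debugging knob;
    # enabling them can be slower, as the --training.deterministic help says.
    if deterministic:
        torch.use_deterministic_algorithms(True)
```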

[ghstack-poisoned]
wconstab committed Dec 9, 2024
2 parents 738d5cf + 66da249 commit 2c90b39
Showing 2 changed files with 1 addition and 4 deletions.
3 changes: 1 addition & 2 deletions torchtitan/config_manager.py
@@ -394,8 +394,7 @@ def __init__(self):
         )
         self.parser.add_argument(
             "--training.deterministic",
-            type=bool,
-            default=False,
+            action="store_true",
             help="Use deterministic algorithms wherever possible, may be slower",
         )
         # checkpointing configs
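As a side note on the config_manager.py change above: with argparse,
`type=bool` calls `bool()` on the raw string, so even
`--training.deterministic False` would parse as True, whereas
`action="store_true"` makes the flag value-less. A small standalone
illustration (not torchtitan's actual parser setup):

```python
import argparse

p = argparse.ArgumentParser()
p.add_argument("--training.deterministic", action="store_true",
               help="Use deterministic algorithms wherever possible, may be slower")

print(vars(p.parse_args([])))                            # {'training.deterministic': False}
print(vars(p.parse_args(["--training.deterministic"])))  # {'training.deterministic': True}
```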
2 changes: 0 additions & 2 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -118,7 +118,6 @@ def parallelize_llama(
             dp_mesh,
             param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
             reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
-            tp_enabled=parallel_dims.tp_enabled,
             pp_enabled=parallel_dims.pp_enabled,
             cpu_offload=job_config.training.enable_cpu_offload,
         )
@@ -338,7 +337,6 @@ def apply_fsdp(
     dp_mesh: DeviceMesh,
     param_dtype: torch.dtype,
     reduce_dtype: torch.dtype,
-    tp_enabled: bool,
     pp_enabled: bool,
     cpu_offload: bool = False,
 ):
