early all-reduce total_norm in non-PP grad norm clipping #769

Merged · 1 commit · Jan 2, 2025
20 changes: 11 additions & 9 deletions torchtitan/utils.py
@@ -384,16 +384,18 @@ def clip_grad_norm_(
         grads, norm_type, error_if_nonfinite, foreach
     )
 
-    if pp_mesh is not None:
-        if isinstance(total_norm, DTensor):
-            # will reach here if PP + other parallelism is used. If only using PP, total_norm will be a local tensor
-
-            # if total_norm is a DTensor, the placements must be `torch.distributed._tensor.ops.math_ops._NormPartial`
-            # we can simply reduce the DTensor to get the total norm in this tensor's process group
-            # and then convert it to a local tensor
-            total_norm = total_norm.full_tensor()
+    # If total_norm is a DTensor, the placements must be `torch.distributed._tensor.ops.math_ops._NormPartial`.
+    # We can simply reduce the DTensor to get the total norm in this tensor's process group
+    # and then convert it to a local tensor.
+    # NOTE: It has two purposes:
+    # 1. to make sure the total norm is computed correctly when PP is used (see below)
+    # 2. to return a reduced total_norm tensor whose .item() would return the correct value
+    if isinstance(total_norm, DTensor):
+        # Will reach here if any non-PP parallelism is used.
+        # If only using PP, total_norm will be a local tensor.
+        total_norm = total_norm.full_tensor()
 
+    # TODO: cleanup maybe using DTensor
+    if pp_mesh is not None:
         if math.isinf(norm_type):
             dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group())
         else:
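
For intuition, the `pp_mesh` branch above combines stage-local norms across pipeline stages: for a finite `norm_type` p it raises the local norm to the p-th power, all-reduces the sum, and takes the 1/p root (and uses MAX for the inf-norm). Below is a minimal, self-contained sketch of that arithmetic, not code from the PR; the names (`stage_grads`, `stage_norms`, `combined`, `reference`) and the values are invented for illustration, and no distributed setup is involved.

```python
import torch

norm_type = 2.0

# Pretend each pipeline stage holds a disjoint shard of the model's gradients.
stage_grads = [torch.tensor([3.0, 4.0]), torch.tensor([12.0])]

# Stage-local norms, analogous to what each PP rank has before the all-reduce.
stage_norms = [g.norm(norm_type) for g in stage_grads]

# The all_reduce(SUM) pattern from the diff: raise to p, sum across stages, take the 1/p root.
combined = sum(n ** norm_type for n in stage_norms) ** (1.0 / norm_type)

# Reference: the norm computed over all gradients at once.
reference = torch.cat(stage_grads).norm(norm_type)

assert torch.isclose(combined, reference)  # both are 13.0
```

Calling `total_norm.full_tensor()` before this step means the value entering the PP all-reduce is already a plain, fully reduced local tensor, so the same pattern works whether or not other parallelisms produced a `DTensor`, and `.item()` on the returned `total_norm` gives the correct value.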