From ae713b7360b814632d5cbfb5d0d215766ef72fb7 Mon Sep 17 00:00:00 2001
From: Tianyu Liu
Date: Thu, 2 Jan 2025 13:53:37 -0800
Subject: [PATCH] Update

[ghstack-poisoned]
---
 torchtitan/utils.py | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/torchtitan/utils.py b/torchtitan/utils.py
index 8a065153..88663c00 100644
--- a/torchtitan/utils.py
+++ b/torchtitan/utils.py
@@ -384,16 +384,18 @@ def clip_grad_norm_(
         grads, norm_type, error_if_nonfinite, foreach
     )
 
-    if pp_mesh is not None:
-        if isinstance(total_norm, DTensor):
-            # will reach here if PP + other parallelism is used. If only using PP, total_norm will be a local tensor
-
-            # if total_norm is a DTensor, the placements must be `torch.distributed._tensor.ops.math_ops._NormPartial`
-            # we can simply reduce the DTensor to get the total norm in this tensor's process group
-            # and then convert it to a local tensor
-            total_norm = total_norm.full_tensor()
+    # If total_norm is a DTensor, the placements must be `torch.distributed._tensor.ops.math_ops._NormPartial`.
+    # We can simply reduce the DTensor to get the total norm in this tensor's process group
+    # and then convert it to a local tensor.
+    # NOTE: It has two purposes:
+    # 1. to make sure the total norm is computed correctly when PP is used (see below)
+    # 2. to return a reduced total_norm tensor whose .item() would return the correct value
+    if isinstance(total_norm, DTensor):
+        # Will reach here if any non-PP parallelism is used.
+        # If only using PP, total_norm will be a local tensor.
+        total_norm = total_norm.full_tensor()
 
-        # TODO: cleanup maybe using DTensor
+    if pp_mesh is not None:
         if math.isinf(norm_type):
             dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group())
         else:
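
The patch keeps the cross-stage reduction of per-stage gradient norms shown in the context lines above: a MAX all-reduce over the pipeline-parallel group for the inf-norm, and (in the `else:` branch) a sum of p-th powers followed by a p-th root for a finite p-norm. Below is a minimal, single-process sketch of why that reduction reproduces the global norm; it is not the torchtitan implementation, and the helper names (stage_norm, combine_across_stages) plus the example gradient lists are hypothetical stand-ins for the per-stage gradient tensors and the all_reduce over pp_mesh.get_group().

    # Minimal sketch (assumed, not from the patch): each pipeline stage only
    # sees its own layers' gradients, so it can only compute a per-stage norm.
    # The global norm is then recovered by combining the per-stage norms:
    #   finite p:  global = (sum_i stage_norm_i ** p) ** (1/p)   <- all_reduce(SUM) path
    #   p = inf:   global = max_i stage_norm_i                   <- all_reduce(MAX) path
    import math

    def stage_norm(grads, p):
        # norm of the gradients owned by one hypothetical pipeline stage
        if math.isinf(p):
            return max(abs(g) for g in grads)
        return sum(abs(g) ** p for g in grads) ** (1.0 / p)

    def combine_across_stages(stage_norms, p):
        # stands in for the all_reduce over the pipeline-parallel process group
        if math.isinf(p):
            return max(stage_norms)
        return sum(n ** p for n in stage_norms) ** (1.0 / p)

    # gradients split across two hypothetical pipeline stages
    stage0, stage1 = [3.0, 4.0], [12.0]
    for p in (2.0, float("inf")):
        combined = combine_across_stages(
            [stage_norm(stage0, p), stage_norm(stage1, p)], p
        )
        reference = stage_norm(stage0 + stage1, p)  # norm over all grads at once
        assert math.isclose(combined, reference), (p, combined, reference)
        print(f"p={p}: combined={combined} matches reference={reference}")

For p=2 the combined value is sqrt(5**2 + 12**2) = 13, identical to the norm computed over all gradients at once, which is why reducing per-stage norms over the PP group yields the same clipping threshold as a single-device run.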