diff --git a/torchtitan/utils.py b/torchtitan/utils.py
index 8a065153..88663c00 100644
--- a/torchtitan/utils.py
+++ b/torchtitan/utils.py
@@ -384,16 +384,18 @@ def clip_grad_norm_(
         grads, norm_type, error_if_nonfinite, foreach
     )

-    if pp_mesh is not None:
-        if isinstance(total_norm, DTensor):
-            # will reach here if PP + other parallelism is used. If only using PP, total_norm will be a local tensor
-
-            # if total_norm is a DTensor, the placements must be `torch.distributed._tensor.ops.math_ops._NormPartial`
-            # we can simply reduce the DTensor to get the total norm in this tensor's process group
-            # and then convert it to a local tensor
-            total_norm = total_norm.full_tensor()
+    # If total_norm is a DTensor, the placements must be `torch.distributed._tensor.ops.math_ops._NormPartial`.
+    # We can simply reduce the DTensor to get the total norm in this tensor's process group
+    # and then convert it to a local tensor.
+    # NOTE: It has two purposes:
+    # 1. to make sure the total norm is computed correctly when PP is used (see below)
+    # 2. to return a reduced total_norm tensor whose .item() would return the correct value
+    if isinstance(total_norm, DTensor):
+        # Will reach here if any non-PP parallelism is used.
+        # If only using PP, total_norm will be a local tensor.
+        total_norm = total_norm.full_tensor()

-        # TODO: cleanup maybe using DTensor
+    if pp_mesh is not None:
         if math.isinf(norm_type):
             dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group())
         else:
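
For context, a minimal single-process usage sketch of the changed return value. It assumes the function keeps a signature mirroring torch.nn.utils.clip_grad_norm_ plus the pp_mesh keyword seen in this file, and the toy model here is illustrative, not part of the patch:

    import torch
    from torchtitan import utils

    # Toy single-process setup; no PP/TP/FSDP, so total_norm stays a plain local tensor.
    model = torch.nn.Linear(8, 8)
    model(torch.randn(2, 8)).sum().backward()

    # Assumed call signature (mirrors torch.nn.utils.clip_grad_norm_ with pp_mesh added).
    # Under other parallelisms total_norm would be a DTensor, which this patch now
    # reduces via full_tensor() before returning, so .item() reflects the global norm.
    grad_norm = utils.clip_grad_norm_(
        model.parameters(), max_norm=1.0, norm_type=2.0, pp_mesh=None
    )
    print(f"grad norm: {grad_norm.item():.4f}")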