diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index fdef48ac..1c4db5de 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -387,8 +387,7 @@ def backward(ctx, grad_output): group = ctx.group use_bias = ctx.use_bias - handle_0: Optional[dist.Work] = None - handle_1: Optional[dist.Work] = None + handle: Optional[dist.Work] = None # TODO @thomasw21: gather along another dimension sharded_batch_size, *rest_size = grad_output.shape @@ -412,31 +411,69 @@ def backward(ctx, grad_output): # https://github.com/pytorch/pytorch/blob/c47cf9bc7f9e02f649ab4ed53fe4d35732c92ab6/torch/_refs/__init__.py#L2761 grad_output = grad_output.contiguous() - handle_0 = dist.all_gather_into_tensor(total_grad_output, grad_output, group=group, async_op=True) - - grad_tensor = grad_output.matmul(weight) + handle = dist.all_gather_into_tensor(total_grad_output, grad_output, group=group, async_op=True) - # wait for the first all_gather to finish before starting the second all_gather - if handle_0 is not None: - handle_0.wait() - - # TODO @thomasw21: gather along another dimension - sharded_batch_size, *rest_size = grad_tensor.shape + # total_grad_output: [b, s, h_out] + # weight: [h_out, h_in/n] + # total_grad_tensor: [b, s, h_in/n] + # grad_output: [b/n, s, h_out] + sharded_batch_size, *rest_size_grad_output = grad_output.shape + rest_size_grad_tensor = rest_size_grad_output[:-1] + [weight.shape[1]] if group.size() == 1: - total_grad_tensor = grad_tensor + total_grad_tensor = grad_output.matmul(weight) else: unsharded_batch_size = sharded_batch_size * group.size() - total_grad_tensor = torch.empty( unsharded_batch_size, - *rest_size, - device=grad_tensor.device, - dtype=grad_tensor.dtype, + *rest_size_grad_tensor, + device=grad_output.device, + dtype=grad_output.dtype, requires_grad=False, ) + before_shard_grad_tensor, same_device_shard_grad_tensor, after_shard_grad_tensor = torch.split( + total_grad_tensor, + split_size_or_sections=[ + sharded_batch_size * dist.get_rank(group), + sharded_batch_size, + sharded_batch_size * (group.size() - dist.get_rank(group) - 1), + ], + dim=0, + ) + # compute local shard + torch.mm( + input=grad_output.view(-1, grad_output.shape[-1]), + mat2=weight, + out=same_device_shard_grad_tensor.view(-1, weight.shape[1]), + ) - handle_1 = dist.all_gather_into_tensor(total_grad_tensor, grad_tensor, group=group, async_op=True) + if handle is not None: + handle.wait() + + before_shard_grad_output, _, after_shard_grad_output = torch.split( + total_grad_output, + split_size_or_sections=[ + sharded_batch_size * dist.get_rank(group), + sharded_batch_size, + sharded_batch_size * (group.size() - dist.get_rank(group) - 1), + ], + dim=0, + ) + + # before shard compute + if before_shard_grad_tensor.numel() > 0: + torch.mm( + input=before_shard_grad_output.view(-1, before_shard_grad_output.shape[-1]), + mat2=weight, + out=before_shard_grad_tensor.view(-1, weight.shape[1]), + ) + # after shard compute + if after_shard_grad_tensor.numel() > 0: + torch.mm( + input=after_shard_grad_output.view(-1, after_shard_grad_output.shape[-1]), + mat2=weight, + out=after_shard_grad_tensor.view(-1, weight.shape[1]), + ) # Convert the tensor shapes to 2D for execution compatibility tensor = tensor.contiguous() @@ -454,9 +491,6 @@ def backward(ctx, grad_output): grad_weight = total_grad_output.t().matmul(tensor) grad_bias = total_grad_output.sum(dim=0) if use_bias else None - if handle_1 is not None: - handle_1.wait() - return total_grad_tensor, grad_weight, grad_bias, None, None diff --git a/tests/test_tensor_parallel.py b/tests/test_tensor_parallel.py index 127ba2fa..f5dcaeb0 100644 --- a/tests/test_tensor_parallel.py +++ b/tests/test_tensor_parallel.py @@ -208,14 +208,19 @@ def _test_row_linear(parallel_context: ParallelContext, tp_mode: TensorParallelL random_input = torch.randn(batch_size, in_features, device="cuda") # synchronize random_input across tp dist.all_reduce(random_input, op=dist.ReduceOp.AVG, group=parallel_context.tp_pg) - + random_input.requires_grad = True # Row linear receives as input sharded input - random_sharded_input = random_input[ - :, - dist.get_rank(parallel_context.tp_pg) - * in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1) - * in_features_per_rank, - ] + random_sharded_input = ( + random_input[ + :, + dist.get_rank(parallel_context.tp_pg) + * in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1) + * in_features_per_rank, + ] + .detach() + .clone() + ) + random_sharded_input.requires_grad = True # Test that we get the same output after forward pass # TODO @kunhao: We may want to have our custom error type @@ -261,6 +266,16 @@ def _test_row_linear(parallel_context: ParallelContext, tp_mode: TensorParallelL else: assert row_linear.bias is None + torch.testing.assert_close( + random_sharded_input.grad, + random_input.grad[ + :, + dist.get_rank(parallel_context.tp_pg) + * in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1) + * in_features_per_rank, + ], + ) + parallel_context.destroy()