Commit
Fixed tp=1 case
AleHD committed Jul 23, 2024
1 parent 2afd007 commit 7e758db
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions src/nanotron/parallel/tensor_parallel/functional.py
@@ -358,11 +358,14 @@ def forward(
         # Do allgather.
         sharded_batch_size, *rest_size = input.shape
         unsharded_batch_size = sharded_batch_size * group.size()
-        if tp_recompute_allgather:
+        if group.size() == 1:
+            total_input = input.contiguous()
+        elif tp_recompute_allgather:
             total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype)
+            dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)
         else:
             total_input = torch.empty(unsharded_batch_size, *rest_size, dtype=input.dtype, device=input.device)
-        dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)
+            dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)

         # Prepare context.
         ctx.group = group
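A minimal sketch of the shape reasoning behind the new group.size() == 1 branch, with assumed tensor sizes (this is an illustration, not the nanotron implementation): with a tensor-parallel group of size 1 the "unsharded" batch is identical to the local shard, so the forward pass can reuse input.contiguous() instead of allocating a buffer and launching an all-gather.

import torch

# Assumed local shard on the single tensor-parallel rank (shapes are made up).
input = torch.randn(4, 7, 16)
group_size = 1  # stands in for group.size()

sharded_batch_size, *rest_size = input.shape
unsharded_batch_size = sharded_batch_size * group_size
assert unsharded_batch_size == sharded_batch_size  # nothing to gather when tp=1

# What the new `group.size() == 1` branch does: skip the collective entirely.
total_input = input.contiguous()
assert total_input.shape == (unsharded_batch_size, *rest_size)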
@@ -383,21 +386,22 @@ def backward(ctx, grad_output: torch.Tensor):
         group = ctx.group
         tp_recompute_allgather = ctx.tp_recompute_allgather
         input_size = ctx.input_size
-        if tp_recompute_allgather:
+        if group.size() == 1 or not tp_recompute_allgather:
+            total_input, weight, bias = ctx.saved_tensors
+        else:
             input, weight, bias = ctx.saved_tensors
             sharded_batch_size, *rest_size = input.shape
-            total_input = sharded_batch_size * group.size()
+            unsharded_batch_size = sharded_batch_size * group.size()
             total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype)
             dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)
-        else:
-            total_input, weight, bias = ctx.saved_tensors

         # Get the grad_output and total_input on the correct views to be able to transpose them below.
         # Convert the tensor shapes to 2D for execution compatibility
         grad_output = grad_output.contiguous()
-        assert grad_output.dim() == 3
-        grad_output = grad_output.view(grad_output.size(0) * grad_output.size(1), grad_output.size(2))
-        total_input = total_input.view(total_input.size(0) * total_input.size(1), total_input.size(2))
+        grad_output_first_dims, grad_output_last_dim = grad_output.shape[:-1], grad_output.shape[-1]
+        total_tensor_first_dims, total_tensor_last_dim = total_tensor.shape[:-1], total_tensor.shape[-1]
+        grad_output = grad_output.view(math.prod(grad_output_first_dims), grad_output_last_dim)
+        total_tensor = total_tensor.view(math.prod(total_tensor_first_dims), total_tensor_last_dim)

         # Compute gradients.
         grad_weight = grad_output.T @ total_input
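Beyond the saved-tensor branching, the backward hunk renames the stray total_input = sharded_batch_size * group.size() to unsharded_batch_size = ..., which the MemoryBuffer lookup on the following line expects, and replaces the hard-coded 3-D reshape (guarded by assert grad_output.dim() == 3) with a flattening of all leading dimensions via math.prod, which handles tensors of any rank. A small self-contained sketch of that equivalence, with assumed shapes:

import math
import torch

for shape in [(4, 7, 16), (2, 3, 5, 16), (8, 16)]:
    t = torch.randn(*shape)
    first_dims, last_dim = t.shape[:-1], t.shape[-1]
    # Flatten every leading dimension into one; for the 3-D case this is
    # exactly the old t.view(t.size(0) * t.size(1), t.size(2)).
    flat = t.view(math.prod(first_dims), last_dim)
    assert flat.shape == (math.prod(first_dims), last_dim)
    if len(shape) == 3:
        assert torch.equal(flat, t.view(t.size(0) * t.size(1), t.size(2)))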

0 comments on commit 7e758db
