Moved ColumnLinearNoAsync module for consistency
AleHD committed Jul 17, 2024
1 parent 956fbfd commit b9e9201
Showing 2 changed files with 76 additions and 92 deletions.
90 changes: 0 additions & 90 deletions src/nanotron/parallel/tensor_parallel/column_linear.py

This file was deleted.

78 changes: 76 additions & 2 deletions src/nanotron/parallel/tensor_parallel/functional.py
@@ -19,7 +19,7 @@
from torch.nn import functional as F

import nanotron.distributed as dist
from nanotron.parallel.tensor_parallel.column_linear import column_linear_context_parallel
from nanotron.parallel.utils import MemoryBuffer
from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import (
differentiable_all_reduce_sum,
differentiable_identity,
@@ -338,6 +338,80 @@ def backward(ctx, grad_output):
raise ValueError(f"Got unexpected mode: {tp_mode}.")


class _ColumnLinearContextParallelNoAsync(torch.autograd.Function):
"""
Column linear with memory_buffer for the allgather, context parallel
enabled (i.e. tp_mode = TensorParallelLinearMode.REDUCE_SCATTER) and
async communication disabled.
"""

@staticmethod
def forward(
ctx,
input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
group: dist.ProcessGroup,
tp_recompute_allgather: bool,
):

# Do allgather.
sharded_batch_size, *rest_size = input.shape
unsharded_batch_size = sharded_batch_size * group.size()
if tp_recompute_allgather:
total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype)
else:
total_input = torch.empty(unsharded_batch_size, *rest_size, dtype=input.dtype, device=input.device)
dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)

# Prepare context.
ctx.group = group
ctx.tp_recompute_allgather = tp_recompute_allgather
ctx.input_size = input.shape
if tp_recompute_allgather:
ctx.save_for_backward(input, weight, bias)
else:
ctx.save_for_backward(total_input, weight, bias)

# Get linear output.
out = F.linear(total_input, weight, bias)
return out

@staticmethod
def backward(ctx, grad_output: torch.Tensor):
# Either allgather the inputs again or get them from context.
group = ctx.group
tp_recompute_allgather = ctx.tp_recompute_allgather
input_size = ctx.input_size
if tp_recompute_allgather:
input, weight, bias = ctx.saved_tensors
sharded_batch_size, *rest_size = input.shape
unsharded_batch_size = sharded_batch_size * group.size()
total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype)
dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)
else:
total_input, weight, bias = ctx.saved_tensors

# Reshape grad_output and total_input to 2D views so they can be transposed below.
grad_output = grad_output.contiguous()
assert grad_output.dim() == 3
grad_output = grad_output.view(grad_output.size(0) * grad_output.size(1), grad_output.size(2))
total_input = total_input.view(total_input.size(0) * total_input.size(1), total_input.size(2))

# Compute gradients.
grad_weight = grad_output.T @ total_input
grad_input = grad_output @ weight
sub_grad_input = torch.empty(
input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False
)
dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM)
grad_bias = torch.sum(grad_output, dim=0) if bias is not None else None

return sub_grad_input, grad_weight, grad_bias, None, None



def column_linear(
input: torch.Tensor,
weight: torch.Tensor,
@@ -353,7 +427,7 @@ def column_linear(
if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
input = differentiable_identity(input, group=group)
elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
return column_linear_context_parallel(input, weight, bias, group, tp_recompute_allgather)
return _ColumnLinearContextParallelNoAsync.apply(input, weight, bias, group, tp_recompute_allgather)
else:
raise ValueError(f"Got unexpected mode: {tp_mode}.")

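For reference, a minimal sketch of how the relocated code path can be exercised after this commit. The process-group setup, shapes, dtype, and launch details below are illustrative assumptions (not part of the diff); the script would need to be launched with torchrun on a node with CUDA devices.

import torch
import torch.distributed as torch_dist

from nanotron.parallel.tensor_parallel.functional import _ColumnLinearContextParallelNoAsync

# Assumed setup (not part of the commit): NCCL backend, all ranks in a single tensor-parallel group.
torch_dist.init_process_group(backend="nccl")
group = torch_dist.new_group(ranks=list(range(torch_dist.get_world_size())))
device = torch.device("cuda", torch_dist.get_rank() % torch.cuda.device_count())

# Illustrative shapes: the first dim is the locally sharded (sequence) dim;
# weight is [out_features, in_features] as expected by F.linear.
sharded_input = torch.randn(128, 4, 1024, device=device, requires_grad=True)
weight = torch.randn(4096, 1024, device=device, requires_grad=True)

# Same positional arguments as the call added to column_linear in this commit:
# (input, weight, bias, group, tp_recompute_allgather).
out = _ColumnLinearContextParallelNoAsync.apply(sharded_input, weight, None, group, True)
out.sum().backward()  # grad w.r.t. sharded_input comes back reduce-scattered to the local shard

torch_dist.destroy_process_group()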
