Memory optimization in async tp-linear #208

Merged: 29 commits, Aug 5, 2024
Commits (29), changes shown from all commits
bcf405d  Implemented global memory buffer to reduce activation memory of diffe… (AleHD, Jun 27, 2024)
ed1ca7d  GLU fusion (AleHD, Jun 27, 2024)
9b0de5b  precommit (AleHD, Jun 27, 2024)
bbc259f  Merge branch 'main' into fix_tp_mem_cache (AleHD, Jul 11, 2024)
803b6da  Wrong backward fixed (AleHD, Jul 16, 2024)
59bfb6b  Removed useless prints (AleHD, Jul 16, 2024)
2c69e9a  Minor fixes (AleHD, Jul 17, 2024)
30439fd  precommit (AleHD, Jul 17, 2024)
1e02a9c  Added tp_recompute_allgather option (AleHD, Jul 17, 2024)
9cc81bb  Changed recompute default (AleHD, Jul 17, 2024)
956fbfd  Changed recompute default (AleHD, Jul 17, 2024)
b9e9201  Moved ColumnLinearNoAsync module for consistency (AleHD, Jul 17, 2024)
cd2ff64  Merge branch 'fix-row-parallel' of github.com:C-TC/nanotron into asyn… (AleHD, Jul 17, 2024)
25acc0e  Merge branch 'async_fix' into mem_fix_async (AleHD, Jul 17, 2024)
7cc6653  memory efficient async linear (AleHD, Jul 18, 2024)
cb0f260  precommit (AleHD, Jul 18, 2024)
6d85d03  Added no_recompute_allgather mode to async (AleHD, Jul 18, 2024)
49633df  Merge branch 'main' into fix_tp_mem_cache (AleHD, Jul 23, 2024)
2afd007  Fixed List not found (AleHD, Jul 23, 2024)
81e7a54  Merge branch 'main' into mem_fix_async (AleHD, Jul 23, 2024)
7e758db  Fixed tp=1 case (AleHD, Jul 23, 2024)
ce2a96b  Merge branch 'main' into fix_tp_mem_cache (AleHD, Jul 30, 2024)
cd84d4f  Fixed column parallel (AleHD, Jul 30, 2024)
d3db06a  Added tp_recompute_allgather test (AleHD, Jul 30, 2024)
6f82050  Merge branch 'fix_tp_mem_cache' into mem_fix_async (AleHD, Jul 30, 2024)
4c94b99  Added tp_recompute_allgather test (AleHD, Jul 30, 2024)
7daa186  Minor restyling (AleHD, Aug 2, 2024)
31c3c5a  Fixed names (AleHD, Aug 2, 2024)
0adb368  Merge pull request #1 from AleHD/fix_tp_mem_cache (AleHD, Aug 2, 2024)
2 changes: 1 addition & 1 deletion src/nanotron/config/config.py
@@ -111,7 +111,7 @@ def __post_init__(self):
class DataArgs:
"""Arguments related to the data and data files processing"""

dataset: Union[PretrainDatasetsArgs, NanosetDatasetsArgs]
dataset: Optional[Union[PretrainDatasetsArgs, NanosetDatasetsArgs]]
seed: Optional[int]
num_loading_workers: Optional[int] = 1

2 changes: 2 additions & 0 deletions src/nanotron/config/parallelism_config.py
@@ -34,6 +34,8 @@ class ParallelismArgs:
tp_linear_async_communication: Optional[bool] = None
recompute_layer: bool = False

tp_recompute_allgather: bool = True

expert_parallel_size: int = 1

def __post_init__(self):
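The new flag defaults to True, so all-gather recomputation is on unless explicitly disabled. A minimal sketch of toggling it from Python; the dp/pp/tp fields are assumed to be the usual required ParallelismArgs fields and are not part of this diff:

```python
from nanotron.config.parallelism_config import ParallelismArgs

# Sketch only: dp/pp/tp are assumed required fields of ParallelismArgs.
# tp_recompute_allgather=True (the default) re-gathers sharded activations in the
# backward pass instead of caching the full gathered tensor from the forward pass.
memory_saving = ParallelismArgs(dp=1, pp=1, tp=4, tp_recompute_allgather=True)

# Trade memory for speed: keep the gathered activations around and skip the
# extra all-gather in backward.
faster_backward = ParallelismArgs(dp=1, pp=1, tp=4, tp_recompute_allgather=False)
```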
6 changes: 4 additions & 2 deletions src/nanotron/models/llama.py
@@ -155,6 +155,7 @@ def __init__(
bias=False,
async_communication=tp_linear_async_communication,
contiguous_chunks=gate_up_contiguous_chunks,
tp_recompute_allgather=parallel_config.tp_recompute_allgather,
)
self.down_proj = TensorParallelRowLinear(
config.intermediate_size,
@@ -164,8 +165,7 @@ def __init__(
bias=False,
async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER,
)
# TODO @nouamane: why can't we torch.jit.script GLUActivation?
self.split_silu_mul = GLUActivation(config.hidden_act)
self.split_silu_mul = torch.compile(GLUActivation(config.hidden_act))

def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim]
merged_states = self.gate_up_proj(hidden_states)
@@ -316,6 +316,7 @@ def __init__(
bias=False,
async_communication=tp_linear_async_communication,
contiguous_chunks=qkv_contiguous_chunks,
tp_recompute_allgather=parallel_config.tp_recompute_allgather,
)
# TODO(kunhao): We want to have only one version per device and not one version per layer.
self.rotary_embedding = RotaryEmbedding(
@@ -765,6 +766,7 @@ def __init__(
# TODO @thomasw21: refactor so that we store that default in a single place.
"mode": self.tp_mode,
"async_communication": tp_linear_async_communication,
"tp_recompute_allgather": parallel_config.tp_recompute_allgather,
},
module_input_keys={"x"},
module_output_keys={"logits"},
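The llama MLP change swaps the old TODO about torch.jit.script for torch.compile, so the split, SiLU, and multiply of the GLU run as one fused kernel. A rough sketch of what gets compiled, assuming GLUActivation simply splits the merged gate_up projection in half along the last dimension (class name and exact behavior below are assumptions, not copied from nanotron):

```python
import torch
import torch.nn.functional as F
from torch import nn


class GLUActivationSketch(nn.Module):
    """Stand-in for nanotron's GLUActivation: act(gate) * up on a merged projection."""

    def forward(self, merged_states: torch.Tensor) -> torch.Tensor:
        gate, up = torch.chunk(merged_states, 2, dim=-1)
        return F.silu(gate) * up


# torch.compile fuses the chunk/activation/multiply into a single kernel.
split_silu_mul = torch.compile(GLUActivationSketch())
out = split_silu_mul(torch.randn(8, 2, 512))  # [seq, batch, 2*intermediate] -> [seq, batch, intermediate]
```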
src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
@@ -85,7 +85,8 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]):
@staticmethod
def backward(ctx, grad_output):
group = ctx.group
return DifferentiableReduceScatterSum.apply(grad_output, group), None
out = DifferentiableReduceScatterSum.apply(grad_output, group)
return out, None


class DifferentiableReduceScatterSum(torch.autograd.Function):
@@ -113,7 +114,7 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]):
*rest_size,
device=tensor.device,
dtype=tensor.dtype,
requires_grad=tensor.requires_grad,
requires_grad=False,
)
dist.reduce_scatter_tensor(sharded_tensor, tensor, group=group, op=dist.ReduceOp.SUM)
return sharded_tensor
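The reduce-scatter output buffer is now always allocated with requires_grad=False. That is safe because the tensor is created inside an autograd.Function's forward, where autograd attaches the graph to the Function's output rather than relying on the flag of the internal buffer. A tiny self-contained illustration (not nanotron code):

```python
import torch


class Double(torch.autograd.Function):
    """The output buffer is created with requires_grad=False, yet gradients still
    flow, because autograd wires the graph through the Function itself."""

    @staticmethod
    def forward(ctx, x):
        out = torch.empty_like(x, requires_grad=False)
        out.copy_(x).mul_(2.0)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        return 2.0 * grad_output


x = torch.randn(4, requires_grad=True)
Double.apply(x).sum().backward()
assert torch.allclose(x.grad, torch.full_like(x, 2.0))
```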
183 changes: 135 additions & 48 deletions src/nanotron/parallel/tensor_parallel/functional.py
@@ -20,13 +20,12 @@

import nanotron.distributed as dist
from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import (
differentiable_all_gather,
differentiable_all_reduce_sum,
differentiable_identity,
differentiable_reduce_scatter_sum,
)
from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
from nanotron.parallel.utils import assert_cuda_max_connections_set_to_1
from nanotron.parallel.utils import MemoryBuffer, assert_cuda_max_connections_set_to_1


class _ShardedCrossEntropy(torch.autograd.Function):
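The central piece of the memory optimization is MemoryBuffer, imported here from nanotron.parallel.utils: instead of allocating a fresh gathered tensor for every layer, all all-gathers share one process-wide scratch buffer keyed by name and dtype. The sketch below only illustrates the idea (singleton, grow-on-demand, return a correctly shaped view); it is not the PR's implementation:

```python
import torch


class MemoryBufferSketch:
    """Illustrative only: a process-wide scratch buffer reused for all-gather outputs."""

    _instance = None

    def __new__(cls):
        # Singleton: MemoryBufferSketch() always returns the same object.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.buffers = {}
        return cls._instance

    def get(self, name: str, shape, dtype=torch.bfloat16) -> torch.Tensor:
        needed = int(torch.Size(shape).numel())
        key = (name, dtype)
        buffer = self.buffers.get(key)
        if buffer is None or buffer.numel() < needed:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            buffer = torch.empty(needed, dtype=dtype, device=device, requires_grad=False)
            self.buffers[key] = buffer
        return buffer[:needed].view(*shape)


# Two layers asking for the same ("allgather", dtype) key share the same storage,
# so the gathered activations of layer N overwrite those of layer N-1.
a = MemoryBufferSketch().get("allgather", (4, 8), dtype=torch.float32)
b = MemoryBufferSketch().get("allgather", (2, 8), dtype=torch.float32)
assert a.data_ptr() == b.data_ptr()
```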
@@ -89,10 +88,10 @@ def forward(

@staticmethod
def backward(ctx, grad_output):
# Retreive tensors from the forward path.
# Retrieve tensors from the forward path.
softmax, target_mask, masked_target_1d = ctx.saved_tensors

# All the inputs have softmax as thier gradient.
# All the inputs have softmax as their gradient.
grad_input = softmax
# For simplicity, work with the 2D gradient.
sharded_hidden_size = softmax.size()[-1]
@@ -121,10 +120,12 @@ class _ColumnLinearAsyncCommunication(torch.autograd.Function):

@staticmethod
@assert_cuda_max_connections_set_to_1
def forward(ctx, tensor, weight, bias, group, tp_mode):
def forward(ctx, tensor, weight, bias, group, tp_mode, tp_recompute_allgather):
ctx.use_bias = bias is not None
ctx.tp_mode = tp_mode
ctx.group = group
ctx.tp_recompute_allgather = tp_recompute_allgather
ctx.tensor_shape = tensor.size()

if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
gathered_tensor = tensor
@@ -141,22 +142,27 @@ def forward(ctx, tensor, weight, bias, group, tp_mode):
# `tensor` can sometimes not be contiguous
# https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L317
tensor = tensor.contiguous()
ctx.save_for_backward(tensor, weight)
# ctx.save_for_backward(tensor, weight)
Review comment (Member): remove comments
# TODO @thomasw21: gather along another dimension
sharded_batch_size, *intermediate_size, hidden_size = tensor.shape
if group is None:
group = dist.distributed_c10d._get_default_group()
gathered_batch_size = sharded_batch_size * group.size()

gathered_tensor = torch.empty(
gathered_batch_size,
*intermediate_size,
hidden_size,
device=tensor.device,
dtype=tensor.dtype,
requires_grad=tensor.requires_grad,
)
if tp_recompute_allgather:
gathered_tensor = MemoryBuffer().get(
"allgather", (gathered_batch_size, *intermediate_size, hidden_size), dtype=tensor.dtype
)
else:
gathered_tensor = torch.empty(
gathered_batch_size,
*intermediate_size,
hidden_size,
device=tensor.device,
dtype=tensor.dtype,
requires_grad=False,
)

handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group=group, async_op=True)

@@ -204,6 +210,10 @@ def forward(ctx, tensor, weight, bias, group, tp_mode):

# Wait communication
handle.wait()
if tp_recompute_allgather:
ctx.save_for_backward(tensor, weight)
else:
ctx.save_for_backward(gathered_tensor, weight)

# Compute all the other shards that are obtained from AllGather
# weights: w0 w1 w2 w3
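This is the crux of the option: with tp_recompute_allgather=True the forward saves only the sharded input and re-gathers it on demand in the backward pass (into the shared MemoryBuffer), while with False it keeps the full gathered tensor. Back-of-the-envelope saving per column-linear call, with made-up but typical shapes:

```python
# Illustrative numbers only (bf16 activations, sequence sharded across tp ranks).
seq_len, micro_batch, hidden, tp = 4096, 1, 4096, 8
bytes_per_elem = 2  # bf16

gathered = seq_len * micro_batch * hidden * bytes_per_elem  # kept when tp_recompute_allgather=False
sharded = gathered // tp                                    # kept when tp_recompute_allgather=True
print(f"kept for backward: {gathered / 2**20:.0f} MiB vs {sharded / 2**20:.0f} MiB")  # 32 MiB vs 4 MiB
```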
@@ -261,8 +271,8 @@ def backward(ctx, grad_output):
use_bias = ctx.use_bias
tp_mode = ctx.tp_mode

handle: Optional[dist.Work] = None
if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
handle1: Optional[dist.Work] = None
if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER and ctx.tp_recompute_allgather:
# TODO @thomasw21: gather along another dimension
sharded_batch_size, *rest_size = tensor.shape
if group is None:
@@ -273,14 +283,10 @@ def backward(ctx, grad_output):
else:
unsharded_batch_size = sharded_batch_size * group.size()

unsharded_tensor = torch.empty(
unsharded_batch_size,
*rest_size,
device=tensor.device,
dtype=tensor.dtype,
requires_grad=False,
unsharded_tensor = MemoryBuffer().get(
"allgather", (unsharded_batch_size, *rest_size), dtype=tensor.dtype
)
handle = dist.all_gather_into_tensor(unsharded_tensor, tensor, group=group, async_op=True)
handle1 = dist.all_gather_into_tensor(unsharded_tensor, tensor, group=group, async_op=True)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# gather is scheduled before the tensor gradient computation
total_tensor = unsharded_tensor
@@ -289,9 +295,6 @@ def backward(ctx, grad_output):

grad_tensor = grad_output.matmul(weight)

if handle is not None:
handle.wait()

# Doing gather + slicing during the NeMo forward pass can make this tensor
# not be contiguous. PyTorch only checks if the tensor is contiguous, and only
# clones it if it's not contiguous:
@@ -303,60 +306,148 @@ def backward(ctx, grad_output):
grad_output = grad_output.view(math.prod(grad_output_first_dims), grad_output_last_dim)
total_tensor = total_tensor.view(math.prod(total_tensor_first_dims), total_tensor_last_dim)

handle: Optional[dist.Work] = None
handle2: Optional[dist.Work] = None
if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
if group.size() == 1:
sub_grad_tensor = grad_tensor
else:
sub_grad_tensor = torch.empty(
tensor.shape, dtype=grad_tensor.dtype, device=grad_tensor.device, requires_grad=False
ctx.tensor_shape, dtype=grad_tensor.dtype, device=grad_tensor.device, requires_grad=False
)
# reduce_scatter
handle = dist.reduce_scatter_tensor(sub_grad_tensor, grad_tensor, group=group, async_op=True)
handle2 = dist.reduce_scatter_tensor(sub_grad_tensor, grad_tensor, group=group, async_op=True)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# reduce scatter is scheduled before the weight gradient computation
elif tp_mode is TensorParallelLinearMode.ALL_REDUCE:
# Asynchronous all-reduce
handle = dist.all_reduce(grad_tensor, group=group, async_op=True)
handle2 = dist.all_reduce(grad_tensor, group=group, async_op=True)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# all-reduce is scheduled before the weight gradient computation
else:
raise ValueError()

grad_bias = grad_output.sum(dim=0) if use_bias else None

if handle1 is not None:
handle1.wait()

# TODO @thomasw21: This sounds like we don't have the optimal physical layout
grad_weight = grad_output.t().matmul(total_tensor)
grad_bias = grad_output.sum(dim=0) if use_bias else None

if handle is not None:
handle.wait()
if handle2 is not None:
handle2.wait()

if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
return sub_grad_tensor, grad_weight, grad_bias, None, None
return sub_grad_tensor, grad_weight, grad_bias, None, None, None
elif tp_mode is TensorParallelLinearMode.ALL_REDUCE:
return grad_tensor, grad_weight, grad_bias, None, None
return grad_tensor, grad_weight, grad_bias, None, None, None
else:
raise ValueError(f"Got unexpected mode: {tp_mode}.")


class _ColumnLinearNoAsyncCommunicationReduceScatterMode(torch.autograd.Function):
"""
Column linear with memory_buffer for the allgather, context parallel
enabled (i.e. tp_mode = TensorParallelLinearMode.REDUCE_SCATTER) and
async communication disabled.
"""

@staticmethod
def forward(
ctx,
input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
group: dist.ProcessGroup,
tp_recompute_allgather: bool,
):

# Do allgather.
sharded_batch_size, *rest_size = input.shape
unsharded_batch_size = sharded_batch_size * group.size()
if group.size() == 1:
total_input = input.contiguous()
elif tp_recompute_allgather:
total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype)
dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)
else:
total_input = torch.empty(unsharded_batch_size, *rest_size, dtype=input.dtype, device=input.device)
dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)

# Prepare context.
ctx.group = group
ctx.tp_recompute_allgather = tp_recompute_allgather
ctx.input_size = input.shape
if tp_recompute_allgather:
ctx.save_for_backward(input, weight, bias)
else:
ctx.save_for_backward(total_input, weight, bias)

# Get linear output.
out = F.linear(total_input, weight, bias)
return out

@staticmethod
def backward(ctx, grad_output: torch.Tensor):
# Either allgather the inputs again or get them from context.
group = ctx.group
tp_recompute_allgather = ctx.tp_recompute_allgather
input_size = ctx.input_size
if group.size() == 1 or not tp_recompute_allgather:
total_input, weight, bias = ctx.saved_tensors
else:
input, weight, bias = ctx.saved_tensors
sharded_batch_size, *rest_size = input.shape
unsharded_batch_size = sharded_batch_size * group.size()
total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype)
dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)

# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.contiguous()
grad_output_first_dims, grad_output_last_dim = grad_output.shape[:-1], grad_output.shape[-1]
total_input_first_dims, total_input_last_dim = total_input.shape[:-1], total_input.shape[-1]
grad_output = grad_output.view(math.prod(grad_output_first_dims), grad_output_last_dim)
total_input = total_input.view(math.prod(total_input_first_dims), total_input_last_dim)

# Compute gradients.
grad_weight = grad_output.T @ total_input
grad_input = grad_output @ weight
if group.size() == 1:
sub_grad_input = grad_input
else:
# Seems that `reduce_scatter` need contiguous tensors: https://github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L305
# We set grad_input to be contiguous in case it isn't already.
grad_input = grad_input.contiguous()
sub_grad_input = torch.empty(
input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False
)
dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM)
grad_bias = torch.sum(grad_output, dim=0) if bias is not None else None

return sub_grad_input, grad_weight, grad_bias, None, None


def column_linear(
input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
group: dist.ProcessGroup,
tp_mode: TensorParallelLinearMode,
async_communication: bool,
tp_recompute_allgather: bool = True,
):
if async_communication:
return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode)
return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather)

if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
input = differentiable_identity(input, group=group)
elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
input = differentiable_all_gather(input, group=group)
else:
raise ValueError(f"Got unexpected mode: {tp_mode}.")

return F.linear(input, weight, bias)
return F.linear(input, weight, bias)
if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply(
input, weight, bias, group, tp_recompute_allgather
)
raise ValueError(f"Got unexpected mode: {tp_mode}.")
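A minimal sketch of how the reworked column_linear is meant to be called in REDUCE_SCATTER mode, assuming a CUDA machine and a launch such as `torchrun --nproc_per_node=2 example.py` (shapes, sizes, and the standalone setup are made up for illustration):

```python
import torch
import torch.distributed as dist

from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
from nanotron.parallel.tensor_parallel.functional import column_linear

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())

hidden, out_per_rank, sharded_seq = 256, 128, 16
x = torch.randn(sharded_seq, 1, hidden, device="cuda", requires_grad=True)  # sequence-sharded input
w = torch.randn(out_per_rank, hidden, device="cuda", requires_grad=True)    # column-sharded weight

y = column_linear(
    x,
    w,
    bias=None,
    group=dist.group.WORLD,
    tp_mode=TensorParallelLinearMode.REDUCE_SCATTER,
    async_communication=False,    # routes to _ColumnLinearNoAsyncCommunicationReduceScatterMode
    tp_recompute_allgather=True,  # re-gather x in backward instead of caching the gathered copy
)
y.sum().backward()  # gradients flow back through the reduce-scatter
```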


class _RowLinearAsyncCommunication(torch.autograd.Function):
@@ -397,12 +488,8 @@ def backward(ctx, grad_output):
else:
unsharded_batch_size = sharded_batch_size * group.size()

total_grad_output = torch.empty(
unsharded_batch_size,
*rest_size,
device=grad_output.device,
dtype=grad_output.dtype,
requires_grad=False,
total_grad_output = MemoryBuffer().get(
"allgather2", (unsharded_batch_size, *rest_size), dtype=tensor.dtype
)

# Doing gather + slicing during the NeMo forward pass can make this tensor