Fix tp mem cache #203

Merged (20 commits, Aug 2, 2024)
Changes from 12 commits
2 changes: 1 addition & 1 deletion src/nanotron/config/config.py
@@ -111,7 +111,7 @@ def __post_init__(self):
class DataArgs:
"""Arguments related to the data and data files processing"""

dataset: Union[PretrainDatasetsArgs, NanosetDatasetsArgs]
dataset: Optional[Union[PretrainDatasetsArgs, NanosetDatasetsArgs]]
seed: Optional[int]
num_loading_workers: Optional[int] = 1

2 changes: 2 additions & 0 deletions src/nanotron/config/parallelism_config.py
@@ -32,6 +32,8 @@ class ParallelismArgs:
tp_mode: Optional[TensorParallelLinearMode] = None
tp_linear_async_communication: Optional[bool] = None

tp_recompute_allgather: bool = True

expert_parallel_size: int = 1

def __post_init__(self):
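For reviewers who want to try the new flag, here is a minimal usage sketch. It assumes the usual dp/pp/tp fields on ParallelismArgs, which are not part of this diff; only tp_recompute_allgather comes from the change above.

```python
from nanotron.config.parallelism_config import ParallelismArgs

# Hypothetical sketch: keep the gathered activations for backward instead of
# re-running the all-gather (more activation memory, less communication).
parallel_config = ParallelismArgs(
    dp=1,  # assumed fields, not shown in this diff
    pp=1,
    tp=4,
    tp_recompute_allgather=False,  # new flag added above; defaults to True
)
```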
8 changes: 5 additions & 3 deletions src/nanotron/models/llama.py
@@ -14,7 +14,7 @@
# limitations under the License.
"""PyTorch LLaMa model."""

from typing import Dict, Optional, Union, List
from typing import Dict, Optional, Union

import torch
from torch import nn
@@ -154,6 +154,7 @@ def __init__(
bias=False,
async_communication=tp_linear_async_communication,
contiguous_chunks=gate_up_contiguous_chunks,
tp_recompute_allgather=parallel_config.tp_recompute_allgather,
)
self.down_proj = TensorParallelRowLinear(
config.intermediate_size,
@@ -163,8 +164,7 @@ def __init__(
bias=False,
async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER,
)
# TODO @nouamane: why can't we torch.jit.script GLUActivation?
self.split_silu_mul = GLUActivation(config.hidden_act)
self.split_silu_mul = torch.compile(GLUActivation(config.hidden_act))

def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim]
merged_states = self.gate_up_proj(hidden_states)
@@ -315,6 +315,7 @@ def __init__(
bias=False,
async_communication=tp_linear_async_communication,
contiguous_chunks=qkv_contiguous_chunks,
tp_recompute_allgather=parallel_config.tp_recompute_allgather,
)
# TODO(kunhao): We want to have only one version per device and not one version per layer.
self.rotary_embedding = RotaryEmbedding(
@@ -743,6 +744,7 @@ def __init__(
# TODO @thomasw21: refactor so that we store that default in a single place.
"mode": self.tp_mode,
"async_communication": tp_linear_async_communication,
"tp_recompute_allgather": parallel_config.tp_recompute_allgather,
},
module_input_keys={"x"},
module_output_keys={"logits"},
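Since the jit.script TODO is replaced by torch.compile here, a small self-contained sketch of the same pattern may help; _ToyGLU below is a hypothetical stand-in, not nanotron's GLUActivation.

```python
import torch
from torch import nn
from torch.nn import functional as F


class _ToyGLU(nn.Module):
    # Stand-in for GLUActivation: split the merged gate_up projection into two
    # halves along the last dim and return silu(gate) * up.
    def forward(self, merged_states: torch.Tensor) -> torch.Tensor:
        gate, up = torch.chunk(merged_states, 2, dim=-1)
        return F.silu(gate) * up


# torch.compile wraps the module but keeps the eager call signature, so the
# surrounding forward() does not need to change.
split_silu_mul = torch.compile(_ToyGLU())
out = split_silu_mul(torch.randn(16, 2, 8))  # [seq_length, batch_size, 2 * intermediate]
print(out.shape)  # torch.Size([16, 2, 4])
```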
src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
@@ -85,7 +85,8 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]):
@staticmethod
def backward(ctx, grad_output):
group = ctx.group
return DifferentiableReduceScatterSum.apply(grad_output, group), None
out = DifferentiableReduceScatterSum.apply(grad_output, group)
return out, None


class DifferentiableReduceScatterSum(torch.autograd.Function):
@@ -113,7 +114,7 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]):
*rest_size,
device=tensor.device,
dtype=tensor.dtype,
requires_grad=tensor.requires_grad,
requires_grad=False,
)
dist.reduce_scatter_tensor(sharded_tensor, tensor, group=group, op=dist.ReduceOp.SUM)
return sharded_tensor
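A note on the requires_grad=False change above: the output of a custom autograd Function gets its grad tracking from the Function machinery, not from the flag on the pre-allocated buffer, so creating the reduce-scatter output with requires_grad=False is safe. A minimal, non-distributed sketch (the Double Function is hypothetical, only there to show the behaviour):

```python
import torch


class Double(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Output buffer created with requires_grad=False, like the
        # reduce-scatter output buffer in the diff above.
        out = torch.empty_like(x, requires_grad=False)
        torch.mul(x, 2, out=out)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        return 2 * grad_output


x = torch.randn(3, requires_grad=True)
y = Double.apply(x)
assert y.requires_grad  # autograd still tracks the output
y.sum().backward()
assert torch.allclose(x.grad, torch.full((3,), 2.0))
```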
83 changes: 79 additions & 4 deletions src/nanotron/parallel/tensor_parallel/functional.py
@@ -19,8 +19,8 @@
from torch.nn import functional as F

import nanotron.distributed as dist
from nanotron.parallel.utils import MemoryBuffer
from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import (
differentiable_all_gather,
differentiable_all_reduce_sum,
differentiable_identity,
differentiable_reduce_scatter_sum,
@@ -89,10 +89,10 @@ def forward(

@staticmethod
def backward(ctx, grad_output):
# Retreive tensors from the forward path.
# Retrieve tensors from the forward path.
softmax, target_mask, masked_target_1d = ctx.saved_tensors

# All the inputs have softmax as thier gradient.
# All the inputs have softmax as their gradient.
grad_input = softmax
# For simplicity, work with the 2D gradient.
sharded_hidden_size = softmax.size()[-1]
@@ -338,21 +338,96 @@ def backward(ctx, grad_output):
raise ValueError(f"Got unexpected mode: {tp_mode}.")


class _ColumnLinearContextParallelNoAsync(torch.autograd.Function):
"""
Column linear with memory_buffer for the allgather, context parallel
enabled (i.e. tp_mode = TensorParallelLinearMode.REDUCE_SCATTER) and
async communication disabled.
"""

@staticmethod
def forward(
ctx,
input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
group: dist.ProcessGroup,
tp_recompute_allgather: bool,
):

# Do allgather.
sharded_batch_size, *rest_size = input.shape
unsharded_batch_size = sharded_batch_size * group.size()
if tp_recompute_allgather:
total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype)
else:
total_input = torch.empty(unsharded_batch_size, *rest_size, dtype=input.dtype, device=input.device)
dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)

# Prepare context.
ctx.group = group
ctx.tp_recompute_allgather = tp_recompute_allgather
ctx.input_size = input.shape
if tp_recompute_allgather:
ctx.save_for_backward(input, weight, bias)
else:
ctx.save_for_backward(total_input, weight, bias)

# Get linear output.
out = F.linear(total_input, weight, bias)
return out

@staticmethod
def backward(ctx, grad_output: torch.Tensor):
# Either allgather the inputs again or get them from context.
group = ctx.group
tp_recompute_allgather = ctx.tp_recompute_allgather
input_size = ctx.input_size
if tp_recompute_allgather:
input, weight, bias = ctx.saved_tensors
sharded_batch_size, *rest_size = input.shape
unsharded_batch_size = sharded_batch_size * group.size()
total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype)
dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)
else:
total_input, weight, bias = ctx.saved_tensors

# Get the grad_output and total_input on the correct views to be able to transpose them below.
grad_output = grad_output.contiguous()
assert grad_output.dim() == 3
grad_output = grad_output.view(grad_output.size(0) * grad_output.size(1), grad_output.size(2))
total_input = total_input.view(total_input.size(0) * total_input.size(1), total_input.size(2))

# Compute gradients.
grad_weight = grad_output.T @ total_input
grad_input = grad_output @ weight
sub_grad_input = torch.empty(
input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False
)
dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM)
grad_bias = torch.sum(grad_output, dim=0) if bias is not None else None

return sub_grad_input, grad_weight, grad_bias, None, None



def column_linear(
input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
group: dist.ProcessGroup,
tp_mode: TensorParallelLinearMode,
async_communication: bool,
tp_recompute_allgather: bool = True,
):
if async_communication:
return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode)

if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
input = differentiable_identity(input, group=group)
elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
input = differentiable_all_gather(input, group=group)
return _ColumnLinearContextParallelNoAsync.apply(input, weight, bias, group, tp_recompute_allgather)
else:
raise ValueError(f"Got unexpected mode: {tp_mode}.")

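To make the trade-off behind tp_recompute_allgather concrete, a back-of-the-envelope sketch of the activation memory each column-parallel linear would otherwise keep alive for its backward pass. The sizes are illustrative assumptions, not measurements from this PR.

```python
# Activation memory cached per column-parallel linear for its backward pass.
seq_len, micro_batch, hidden, tp = 4096, 1, 4096, 8
bytes_per_elem = 2  # bf16

shard_elems = (seq_len // tp) * micro_batch * hidden  # what each TP rank holds
full_elems = seq_len * micro_batch * hidden           # after the all-gather

# tp_recompute_allgather=False: the gathered input is saved for backward.
cached_full_mib = full_elems * bytes_per_elem / 2**20
# tp_recompute_allgather=True: only the shard is saved; the gather is redone in
# backward into the shared MemoryBuffer, which is reused across layers.
cached_shard_mib = shard_elems * bytes_per_elem / 2**20

print(f"{cached_full_mib:.0f} MiB vs {cached_shard_mib:.0f} MiB per linear")
```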
3 changes: 3 additions & 0 deletions src/nanotron/parallel/tensor_parallel/nn.py
@@ -51,6 +51,7 @@ def __init__(
dtype=None,
async_communication: bool = False,
contiguous_chunks: Optional[Tuple[int, ...]] = None,
tp_recompute_allgather: bool = True,
):
self.pg = pg
self.world_size = pg.size()
@@ -59,6 +60,7 @@

self.in_features = in_features
self.out_features = out_features // self.world_size
self.tp_recompute_allgather = tp_recompute_allgather

super().__init__(
in_features=self.in_features,
@@ -91,6 +93,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
group=self.pg,
tp_mode=self.mode,
async_communication=self.async_communication,
tp_recompute_allgather=self.tp_recompute_allgather,
)

def extra_repr(self) -> str:
20 changes: 20 additions & 0 deletions src/nanotron/parallel/utils.py
@@ -1,11 +1,31 @@
import functools
import operator
import os

import torch
from torch import nn

from nanotron import distributed as dist
from nanotron.parallel import ParallelContext
from nanotron.parallel.tied_parameters import get_tied_id_to_param
from nanotron.utils import Singleton


class MemoryBuffer(metaclass=Singleton):
"""
Global memory buffer to store intermediate activations that do not need to be cached for the backward pass.
"""

def __init__(self):
self.buffer = {}

def get(self, name: str, shape: tuple[int], dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
required_numel = functools.reduce(operator.mul, shape, 1)
if (name, dtype) not in self.buffer or self.buffer[name, dtype].numel() < required_numel:
self.buffer[name, dtype] = torch.empty(
required_numel, dtype=dtype, device=torch.cuda.current_device(), requires_grad=False
)
return self.buffer[name, dtype][:required_numel].view(shape)


def assert_cuda_max_connections_set_to_1(func):
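A short usage sketch of the new buffer (it assumes a CUDA device, since get() allocates on torch.cuda.current_device()): requests under the same name and dtype reuse one allocation, which is what lets every layer's all-gather share a single scratch tensor.

```python
import torch

from nanotron.parallel.utils import MemoryBuffer

# MemoryBuffer is a singleton keyed by (name, dtype): the second, smaller
# request below is served as a view over the storage of the first one.
buf_a = MemoryBuffer().get("allgather", (1024, 4096), dtype=torch.bfloat16)
buf_b = MemoryBuffer().get("allgather", (512, 4096), dtype=torch.bfloat16)
assert buf_a.data_ptr() == buf_b.data_ptr()  # same underlying storage
assert buf_b.shape == (512, 4096)
```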
25 changes: 22 additions & 3 deletions src/nanotron/utils.py
@@ -1,11 +1,10 @@
import functools
import inspect
import math
import os
import random
import socket
from contextlib import ExitStack, contextmanager
from typing import Callable, ContextManager, List, Optional
from typing import ContextManager, List, Optional

import torch
from packaging import version
@@ -15,6 +14,25 @@
from nanotron import distributed as dist


class Singleton(type):
"""
Singleton metaclass.
Create objects using this class as the metaclass to enable singleton behaviour.
For instance:
```
class Logger(metaclass=Singleton):
...
```
"""

_instances = {}

def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]


class ContextManagers:
"""
Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers`
@@ -52,7 +70,7 @@ def main_rank_first(group: dist.ProcessGroup):
@contextmanager
def local_ranks_zero_first(group: Optional[dist.ProcessGroup] = None):
"""Context manager that executes the code in the context with all the local rank zero of the group going first.
Usefull to run only once per node first (e.g. to create local files, etc)
Useful to run only once per node first (e.g. to create local files, etc)
"""
is_main = int(os.environ.get("LOCAL_RANK", 0)) == 0
if is_main:
@@ -123,6 +141,7 @@ def get_untyped_storage(tensor: torch.Tensor) -> torch.UntypedStorage:
else:
return tensor.storage().untyped()


def tensor_from_untyped_storage(untyped_storage: torch.UntypedStorage, dtype: torch.dtype):
# TODO @thomasw21: Figure out what's the best Pytorch way of building a tensor from a storage.
device = untyped_storage.device
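To round out the docstring's Logger example, a tiny check of the metaclass behaviour; Logger is only the docstring's illustrative name, not a real nanotron class.

```python
from nanotron.utils import Singleton


class Logger(metaclass=Singleton):
    def __init__(self):
        self.messages = []


a, b = Logger(), Logger()
assert a is b  # every call returns the same instance
a.messages.append("hello")
assert b.messages == ["hello"]  # shared state, as expected for a singleton
```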