Merge pull request #207 from C-TC/recompute
Add layer-wise activation recomputation to llama model
3outeille authored Jul 14, 2024
2 parents f1adf52 + 7e15516 commit 4c23ed0
Showing 2 changed files with 29 additions and 5 deletions.
src/nanotron/config/parallelism_config.py (2 additions, 0 deletions)
@@ -23,6 +23,7 @@ class ParallelismArgs:
         pp_engine: Pipeline engine to use between "1f1b" and "afab"
         tp_mode: TP mode to use between "all_reduce" and "reduce_scatter": all_reduce is normal, reduce_scatter activate sequence parallelism
         tp_linear_async_communication: Whether to use async communication in TP linear layers
+        recompute_layer: Whether to recompute each Transformer layer to save memory.
     """

     dp: int
@@ -31,6 +32,7 @@ class ParallelismArgs:
     pp_engine: Optional[PipelineEngine] = None
     tp_mode: Optional[TensorParallelLinearMode] = None
     tp_linear_async_communication: Optional[bool] = None
+    recompute_layer: bool = False

     expert_parallel_size: int = 1

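The new flag defaults to False, so existing configs keep the current behavior. Below is a minimal usage sketch (not from the PR) of enabling the option programmatically; it assumes ParallelismArgs is importable from nanotron.config and that dp/pp/tp are its required fields, which may differ across nanotron versions.

# Hypothetical sketch: enable layer-wise recomputation through ParallelismArgs.
# Only `recompute_layer` comes from this PR; the other arguments mirror the
# fields visible in the class above and are assumptions about the constructor.
from nanotron.config import ParallelismArgs

parallel_config = ParallelismArgs(
    dp=2,  # data-parallel replicas
    pp=1,  # pipeline-parallel stages
    tp=4,  # tensor-parallel ranks
    recompute_layer=True,  # re-run each decoder layer in backward instead of storing its activations
)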
src/nanotron/models/llama.py (27 additions, 5 deletions)
@@ -18,6 +18,7 @@

 import torch
 from torch import nn
+from torch.utils.checkpoint import CheckpointFunction

 from nanotron import distributed as dist
 from nanotron import logging
@@ -618,12 +619,14 @@ def __init__(

         self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg)

-    def forward(
+        self.recompute_layer = parallel_config.recompute_layer
+
+    def _core_forward(
         self,
         hidden_states: Union[torch.Tensor, TensorPointer],
         sequence_mask: Union[torch.Tensor, TensorPointer],
-    ) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
+    ) -> List[Union[torch.Tensor, TensorPointer]]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)

@@ -636,12 +639,31 @@ def forward(
         hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"]
         hidden_states = hidden_states + residual

+        return hidden_states, output["sequence_mask"]
+
+    def _checkpointed_forward(
+        self,
+        hidden_states: torch.Tensor,
+        sequence_mask: torch.Tensor,
+    ) -> List[torch.Tensor]:
+        return CheckpointFunction.apply(self._core_forward, True, hidden_states, sequence_mask)
+
+    def forward(
+        self,
+        hidden_states: Union[torch.Tensor, TensorPointer],
+        sequence_mask: Union[torch.Tensor, TensorPointer],
+    ) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
+
+        if self.recompute_layer and not isinstance(hidden_states, TensorPointer):
+            hidden_states, sequence_mask = self._checkpointed_forward(hidden_states, sequence_mask)
+        else:
+            hidden_states, sequence_mask = self._core_forward(hidden_states, sequence_mask)
+
         return {
             "hidden_states": hidden_states,
-            "sequence_mask": output["sequence_mask"],
+            "sequence_mask": sequence_mask,
         }


 class Embedding(nn.Module, AttachableStore):
     def __init__(self, tp_pg: dist.ProcessGroup, config: LlamaConfig, parallel_config: Optional[ParallelismArgs]):
         super().__init__()
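The dispatch added to forward() is the heart of the change: when recompute_layer is set and the inputs are real tensors (not TensorPointer placeholders held by another pipeline rank), the layer runs through CheckpointFunction, so its intermediate activations are discarded after the forward pass and regenerated by re-running _core_forward during backward, trading extra compute for activation memory. CheckpointFunction is the reentrant autograd Function behind torch.utils.checkpoint.checkpoint, and the literal True passed as its second argument corresponds to preserve_rng_state, which replays the RNG so stochastic ops behave identically on recompute. A minimal, self-contained sketch of the same pattern using the public checkpoint API follows; CheckpointedBlock and its fields are illustrative, not nanotron code.

# Standalone sketch of per-layer activation recomputation (assumed example,
# not nanotron code). The block's forward is wrapped in torch.utils.checkpoint,
# so its intermediate activations are recomputed during the backward pass.
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class CheckpointedBlock(nn.Module):
    def __init__(self, hidden_size: int, recompute: bool = True):
        super().__init__()
        self.recompute = recompute
        self.ff = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )

    def _core_forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual feed-forward block; its intermediates are what recomputation avoids storing.
        return x + self.ff(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.recompute and x.requires_grad:
            # use_reentrant=True follows the same CheckpointFunction path as the diff above.
            return checkpoint(self._core_forward, x, use_reentrant=True)
        return self._core_forward(x)


block = CheckpointedBlock(hidden_size=128)
x = torch.randn(4, 16, 128, requires_grad=True)
out = block(x)
out.sum().backward()  # _core_forward is re-executed here to rebuild activations

The same trade-off applies in the PR: enabling recompute_layer lowers peak activation memory per decoder layer at the cost of one extra forward computation of that layer during the backward pass.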
