Memory optimization in async tp-linear #208

Merged · 29 commits · Aug 5, 2024
Commits
bcf405d  Implemented global memory buffer to reduce activation memory of diffe… (AleHD, Jun 27, 2024)
ed1ca7d  GLU fusion (AleHD, Jun 27, 2024)
9b0de5b  precommit (AleHD, Jun 27, 2024)
bbc259f  Merge branch 'main' into fix_tp_mem_cache (AleHD, Jul 11, 2024)
803b6da  Wrong backward fixed (AleHD, Jul 16, 2024)
59bfb6b  Removed useless prints (AleHD, Jul 16, 2024)
2c69e9a  Minor fixes (AleHD, Jul 17, 2024)
30439fd  precommit (AleHD, Jul 17, 2024)
1e02a9c  Added tp_recompute_allgather option (AleHD, Jul 17, 2024)
9cc81bb  Changed recompute default (AleHD, Jul 17, 2024)
956fbfd  Changed recompute default (AleHD, Jul 17, 2024)
b9e9201  Moved ColumnLinearNoAsync module for consistency (AleHD, Jul 17, 2024)
cd2ff64  Merge branch 'fix-row-parallel' of github.com:C-TC/nanotron into asyn… (AleHD, Jul 17, 2024)
25acc0e  Merge branch 'async_fix' into mem_fix_async (AleHD, Jul 17, 2024)
7cc6653  memory efficient async linear (AleHD, Jul 18, 2024)
cb0f260  precommit (AleHD, Jul 18, 2024)
6d85d03  Added no_recompute_allgather mode to async (AleHD, Jul 18, 2024)
49633df  Merge branch 'main' into fix_tp_mem_cache (AleHD, Jul 23, 2024)
2afd007  Fixed List not found (AleHD, Jul 23, 2024)
81e7a54  Merge branch 'main' into mem_fix_async (AleHD, Jul 23, 2024)
7e758db  Fixed tp=1 case (AleHD, Jul 23, 2024)
ce2a96b  Merge branch 'main' into fix_tp_mem_cache (AleHD, Jul 30, 2024)
cd84d4f  Fixed column parallel (AleHD, Jul 30, 2024)
d3db06a  Added tp_recompute_allgather test (AleHD, Jul 30, 2024)
6f82050  Merge branch 'fix_tp_mem_cache' into mem_fix_async (AleHD, Jul 30, 2024)
4c94b99  Added tp_recompute_allgather test (AleHD, Jul 30, 2024)
7daa186  Minor restyling (AleHD, Aug 2, 2024)
31c3c5a  Fixed names (AleHD, Aug 2, 2024)
0adb368  Merge pull request #1 from AleHD/fix_tp_mem_cache (AleHD, Aug 2, 2024)
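The commits above center on one idea: instead of keeping the all-gathered (full-sequence) input of a tensor-parallel column linear alive between forward and backward, only the local shard is saved and the all-gather is recomputed during backward. The sketch below illustrates that pattern under stated assumptions; it is not this PR's actual implementation, and the class name, 2-D input layout, and shapes are illustrative only.

import torch
import torch.distributed as dist


class _ColumnLinearRecomputeAllGather(torch.autograd.Function):
    # Hypothetical sketch: a sequence-sharded input x of shape [local_seq, hidden]
    # is all-gathered across the TP group, multiplied by this rank's column shard
    # of the weight, and the gathered copy is dropped instead of saved for backward.

    @staticmethod
    def forward(ctx, x, weight, group):
        world_size = dist.get_world_size(group)
        gathered = torch.empty(
            world_size * x.shape[0], x.shape[1], device=x.device, dtype=x.dtype
        )
        dist.all_gather_into_tensor(gathered, x, group=group)
        out = gathered @ weight.t()        # [full_seq, out_features_per_rank]
        ctx.save_for_backward(x, weight)   # only the small shard stays alive
        ctx.group = group
        return out

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        group = ctx.group
        world_size = dist.get_world_size(group)
        # Recompute the all-gather: one extra collective in exchange for not
        # holding the full-sequence activation between forward and backward.
        gathered = torch.empty(
            world_size * x.shape[0], x.shape[1], device=x.device, dtype=x.dtype
        )
        dist.all_gather_into_tensor(gathered, x, group=group)
        grad_weight = grad_out.t() @ gathered   # dL/dW for this rank's columns
        grad_gathered = grad_out @ weight       # dL/d(gathered input)
        # The weight is column-sharded, so partial input grads must be summed
        # across ranks and scattered back to this rank's sequence shard.
        grad_x = torch.empty_like(x)
        dist.reduce_scatter_tensor(grad_x, grad_gathered, group=group)
        return grad_x, grad_weight, None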
2 changes: 1 addition & 1 deletion src/nanotron/config/config.py
@@ -111,7 +111,7 @@ def __post_init__(self):
 class DataArgs:
     """Arguments related to the data and data files processing"""

-    dataset: Union[PretrainDatasetsArgs, NanosetDatasetsArgs]
+    dataset: Optional[Union[PretrainDatasetsArgs, NanosetDatasetsArgs]]
     seed: Optional[int]
     num_loading_workers: Optional[int] = 1
2 changes: 2 additions & 0 deletions src/nanotron/config/parallelism_config.py
@@ -32,6 +32,8 @@ class ParallelismArgs:
     tp_mode: Optional[TensorParallelLinearMode] = None
     tp_linear_async_communication: Optional[bool] = None

+    tp_recompute_allgather: bool = True
+
     expert_parallel_size: int = 1

     def __post_init__(self):
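For reference, this is roughly how the new flag would be toggled when building the parallelism config in Python. The import path follows the file shown above; the dp/pp/tp fields are assumptions about the rest of ParallelismArgs and are not part of this diff.

from nanotron.config.parallelism_config import ParallelismArgs

parallel_config = ParallelismArgs(
    dp=1,   # assumed field, not shown in the diff
    pp=1,   # assumed field, not shown in the diff
    tp=4,   # assumed field, not shown in the diff
    tp_linear_async_communication=True,
    # True (default): re-all-gather in backward to save activation memory;
    # False: keep the gathered activation and skip the recompute.
    tp_recompute_allgather=True,
)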
8 changes: 5 additions & 3 deletions src/nanotron/models/llama.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 """PyTorch LLaMa model."""

-from typing import Dict, Optional, Union, List
+from typing import Dict, Optional, Union

 import torch
 from torch import nn
@@ -154,6 +154,7 @@ def __init__(
             bias=False,
             async_communication=tp_linear_async_communication,
             contiguous_chunks=gate_up_contiguous_chunks,
+            tp_recompute_allgather=parallel_config.tp_recompute_allgather,
         )
         self.down_proj = TensorParallelRowLinear(
             config.intermediate_size,
@@ -163,8 +164,7 @@
             bias=False,
             async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER,
         )
-        # TODO @nouamane: why can't we torch.jit.script GLUActivation?
-        self.split_silu_mul = GLUActivation(config.hidden_act)
+        self.split_silu_mul = torch.compile(GLUActivation(config.hidden_act))

     def forward(self, hidden_states):  # [seq_length, batch_size, hidden_dim]
         merged_states = self.gate_up_proj(hidden_states)
@@ -315,6 +315,7 @@ def __init__(
             bias=False,
             async_communication=tp_linear_async_communication,
             contiguous_chunks=qkv_contiguous_chunks,
+            tp_recompute_allgather=parallel_config.tp_recompute_allgather,
         )
         # TODO(kunhao): We want to have only one version per device and not one version per layer.
         self.rotary_embedding = RotaryEmbedding(
@@ -743,6 +744,7 @@ def __init__(
                 # TODO @thomasw21: refactor so that we store that default in a single place.
                 "mode": self.tp_mode,
                 "async_communication": tp_linear_async_communication,
+                "tp_recompute_allgather": parallel_config.tp_recompute_allgather,
             },
             module_input_keys={"x"},
             module_output_keys={"logits"},
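For context, the removed torch.jit.script TODO and the new torch.compile call target the same elementwise pattern: the fused gate_up projection produces a tensor whose last dimension holds [gate; up], and the activation splits it and computes silu(gate) * up. A minimal sketch of that pattern, assuming SiLU and not nanotron's exact GLUActivation class:

import torch
from torch import nn
from torch.nn import functional as F


class GLUActivationSketch(nn.Module):
    # Splits the merged gate_up projection and applies the gated activation.
    def forward(self, merged_states: torch.Tensor) -> torch.Tensor:
        gate, up = merged_states.chunk(2, dim=-1)
        return F.silu(gate) * up


# torch.compile can fuse the chunk/silu/mul into a single kernel, avoiding the
# extra memory traffic of materializing silu(gate) separately.
split_silu_mul = torch.compile(GLUActivationSketch())
out = split_silu_mul(torch.randn(16, 2, 2 * 128))  # [seq_length, batch_size, 2 * intermediate_size]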
@@ -85,7 +85,8 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]):
     @staticmethod
     def backward(ctx, grad_output):
         group = ctx.group
-        return DifferentiableReduceScatterSum.apply(grad_output, group), None
+        out = DifferentiableReduceScatterSum.apply(grad_output, group)
+        return out, None


 class DifferentiableReduceScatterSum(torch.autograd.Function):
@@ -113,7 +114,7 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]):
             *rest_size,
             device=tensor.device,
             dtype=tensor.dtype,
-            requires_grad=tensor.requires_grad,
+            requires_grad=False,
         )
         dist.reduce_scatter_tensor(sharded_tensor, tensor, group=group, op=dist.ReduceOp.SUM)
         return sharded_tensor
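Both hunks touch the differentiable collectives. The requires_grad=False change is safe because, inside a torch.autograd.Function, gradient flow is defined by backward() rather than by the flags of buffers allocated in forward(); inheriting tensor.requires_grad only adds autograd bookkeeping for an output the Function already tracks. A reduced sketch of the pattern (assumed class name, not the repository's exact code):

import torch
import torch.distributed as dist


class ReduceScatterSumSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, group):
        ctx.group = group
        world_size = dist.get_world_size(group)
        first, *rest = tensor.shape
        sharded = torch.empty(
            first // world_size,
            *rest,
            device=tensor.device,
            dtype=tensor.dtype,
            requires_grad=False,  # autograd is handled by backward(), not this flag
        )
        dist.reduce_scatter_tensor(sharded, tensor, group=group, op=dist.ReduceOp.SUM)
        return sharded

    @staticmethod
    def backward(ctx, grad_output):
        # The adjoint of reduce-scatter(sum) is an all-gather of the incoming grad.
        group = ctx.group
        world_size = dist.get_world_size(group)
        gathered = torch.empty(
            world_size * grad_output.shape[0],
            *grad_output.shape[1:],
            device=grad_output.device,
            dtype=grad_output.dtype,
        )
        dist.all_gather_into_tensor(gathered, grad_output.contiguous(), group=group)
        return gathered, None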