[Inference] T5 Tensor Parallel support #361

Closed · wants to merge 4 commits
Changes from all commits
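Note: a minimal sketch of the per-rank KV-cache sizing the change below relies on. The concrete numbers (`tp_degree = 2`, toy config values) are illustrative assumptions, not taken from this PR.

```python
# Hypothetical illustration of tensor-parallel KV-cache sizing (not from the PR).
num_heads = 8        # T5 config.num_heads
d_kv = 64            # T5 config.d_kv
num_beams = 4
max_length = 128
tp_degree = 2        # assumed number of tensor-parallel ranks

# Each rank caches only num_heads / tp_degree attention heads.
heads_per_partition = num_heads // tp_degree
self_attn_cache_shape = (num_beams, heads_per_partition, max_length - 1, d_kv)
cross_attn_cache_shape = (num_beams, heads_per_partition, max_length, d_kv)
print(self_attn_cache_shape, cross_attn_cache_shape)
# (4, 4, 127, 64) (4, 4, 128, 64)
```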
120 changes: 98 additions & 22 deletions optimum/exporters/neuron/model_wrappers.py
@@ -16,9 +16,12 @@

from typing import TYPE_CHECKING, List, Optional

import neuronx_distributed
import torch
from transformers.models.t5.modeling_t5 import T5LayerCrossAttention

from ...utils import NormalizedConfigManager


if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel
@@ -129,17 +132,60 @@ def __init__(
super().__init__()
self.model = model
self.config = model.config
self.normalized_config = NormalizedConfigManager.get_normalized_config_class(self.config.model_type)(self.config)
self.num_beams = num_beams
self.device = device
self.tp_degree = tp_degree
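# When a tensor parallel degree is given, the attention heads are sharded across the
# tensor-parallel ranks, so each rank pre-allocates only its slice of the KV cache.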
if self.tp_degree is not None:
self.num_attention_heads_per_partition = (
self.normalized_config.num_attention_heads
// neuronx_distributed.parallel_layers.parallel_state.get_tensor_model_parallel_size()
)
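# Frozen, fixed-shape buffers for the per-rank KV cache: two parameters (key and value)
# per decoder layer, with max_length - 1 positions for self-attention and max_length
# positions for cross-attention.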
self.past_key_values_sa = torch.nn.ParameterList(
[
torch.nn.Parameter(
torch.ones(
(
self.num_beams,
self.num_attention_heads_per_partition,
self.max_length - 1,
self.normalized_config.key_value_dim,
),
dtype=torch.float32,
),
requires_grad=False,
)
for _ in range(self.config.num_decoder_layers * 2)
]
)
self.past_key_values_ca = torch.nn.ParameterList(
[
torch.nn.Parameter(
torch.ones(
(
self.num_beams,
self.num_attention_heads_per_partition,
self.max_length,
self.normalized_config.key_value_dim,
),
dtype=torch.float32,
),
requires_grad=False,
)
for _ in range(self.config.num_decoder_layers * 2)
]
)

def forward(self, input_ids, attention_mask):
# Infer shapes
batch_size = input_ids.shape[0]
sequence_length = input_ids.shape[1]

encoder_output = self.model.encoder(
    input_ids=input_ids,
    attention_mask=attention_mask,
    output_attentions=False,
    output_hidden_states=False,
)

last_hidden_state = encoder_output["last_hidden_state"]
@@ -151,7 +197,7 @@ def forward(self, input_ids, attention_mask):
present_key_value_states_sa = []
present_key_value_states_ca = []

for i, block in enumerate(decoder_blocks):
# Cross attention has to be initialized with the encoder hidden state
cross_attention: T5LayerCrossAttention = block.layer[1]
attention = cross_attention.EncDecAttention
Expand All @@ -165,28 +211,58 @@ def shape(states):
key_states = shape(attention.k(encoder_hidden_states))
value_states = shape(attention.v(encoder_hidden_states))

if self.tp_degree is None:
    # cross_attn_kv_state
    present_key_value_states_ca.append(key_states)
    present_key_value_states_ca.append(value_states)

    # Self attention kv states are initialized to zeros. This is done to keep the size of the kv cache tensor constant.
    # The kv cache is padded here to keep a fixed shape.
    # [key states]
    present_key_value_states_sa.append(
        torch.zeros(
            (self.num_beams * batch_size, self.config.num_heads, sequence_length - 1, self.config.d_kv),
            dtype=torch.float32,
            device=self.device,
        )
    )
    # [value states]
    present_key_value_states_sa.append(
        torch.zeros(
            (self.num_beams * batch_size, self.config.num_heads, sequence_length - 1, self.config.d_kv),
            dtype=torch.float32,
            device=self.device,
        )
    )
else:
    # With tensor parallelism, reuse the pre-allocated per-rank cache buffers:
    # multiplying by zero keeps the fixed buffer shape while adding the fresh
    # cross-attention key/value states computed from the encoder output.
    present_key_value_states_ca.append((self.past_key_values_ca[i * 2] * 0) + key_states)
    present_key_value_states_ca.append((self.past_key_values_ca[i * 2 + 1] * 0) + value_states)
    # Self attention kv states are zero-initialized and padded to a fixed shape.
    present_key_value_states_sa.append(
        self.past_key_values_sa[i * 2]
        * torch.zeros(
            (
                self.batch_size,
                self.num_attention_heads_per_partition,
                self.max_length - 1,
                self.config.d_kv,
            ),
            dtype=torch.float32,
            device="xla",
        )
    )
    present_key_value_states_sa.append(
        self.past_key_values_sa[i * 2 + 1]
        * torch.zeros(
            (
                self.batch_size,
                self.num_attention_heads_per_partition,
                self.max_length - 1,
                self.config.d_kv,
            ),
            dtype=torch.float32,
            device="xla",
        )
    )

return present_key_value_states_sa + present_key_value_states_ca
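
For reference, a small sketch (not part of this PR) of how the flat list returned above could be regrouped into per-layer KV tuples; `regroup_kv_cache` and its output layout are hypothetical illustrations of the ordering used by the wrapper.

```python
from typing import List, Tuple

import torch


def regroup_kv_cache(
    flat_outputs: List[torch.Tensor], num_decoder_layers: int
) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], ...]:
    """Hypothetical helper: the wrapper returns
    present_key_value_states_sa + present_key_value_states_ca, with two tensors
    (key, value) per decoder layer in each list. Regroup them into per-layer
    (sa_key, sa_value, ca_key, ca_value) tuples.
    """
    sa = flat_outputs[: 2 * num_decoder_layers]
    ca = flat_outputs[2 * num_decoder_layers :]
    return tuple(
        (sa[2 * i], sa[2 * i + 1], ca[2 * i], ca[2 * i + 1])
        for i in range(num_decoder_layers)
    )
```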
