From b0cc16a91931553adac3586d4f182711c0ba298c Mon Sep 17 00:00:00 2001
From: JingyaHuang
Date: Mon, 4 Dec 2023 21:22:37 +0000
Subject: [PATCH] Modify encoder wrapper

---
 optimum/exporters/neuron/model_wrappers.py | 120 +++++++++++++++++----
 1 file changed, 98 insertions(+), 22 deletions(-)

diff --git a/optimum/exporters/neuron/model_wrappers.py b/optimum/exporters/neuron/model_wrappers.py
index 0b1ae4504..86aa3693f 100644
--- a/optimum/exporters/neuron/model_wrappers.py
+++ b/optimum/exporters/neuron/model_wrappers.py
@@ -16,9 +16,12 @@
 from typing import TYPE_CHECKING, List, Optional

+import neuronx_distributed
 import torch
 from transformers.models.t5.modeling_t5 import T5LayerCrossAttention

+from ...utils import NormalizedConfigManager
+

 if TYPE_CHECKING:
     from transformers.modeling_utils import PreTrainedModel

@@ -71,9 +74,49 @@ def __init__(
         super().__init__()
         self.model = model
         self.config = model.config
+        self.normalized_config = NormalizedConfigManager.get_normalized_config_class(self.config.model_type)(self.config)
         self.num_beams = num_beams
         self.device = device
         self.tp_degree = tp_degree
+        if self.tp_degree is not None:
+            self.num_attention_heads_per_partition = (
+                self.normalized_config.num_attention_heads
+                // neuronx_distributed.parallel_layers.parallel_state.get_tensor_model_parallel_size()
+            )
+            self.past_key_values_sa = torch.nn.ParameterList(
+                [
+                    torch.nn.Parameter(
+                        torch.ones(
+                            (
+                                self.num_beams,
+                                self.num_attention_heads_per_partition,
+                                self.max_length - 1,
+                                self.normalized_config.key_value_dim,
+                            ),
+                            dtype=torch.float32,
+                        ),
+                        requires_grad=False,
+                    )
+                    for _ in range(self.config.num_decoder_layers * 2)
+                ]
+            )
+            self.past_key_values_ca = torch.nn.ParameterList(
+                [
+                    torch.nn.Parameter(
+                        torch.ones(
+                            (
+                                self.num_beams,
+                                self.num_attention_heads_per_partition,
+                                self.max_length,
+                                self.normalized_config.key_value_dim,
+                            ),
+                            dtype=torch.float32,
+                        ),
+                        requires_grad=False,
+                    )
+                    for _ in range(self.config.num_decoder_layers * 2)
+                ]
+            )

     def forward(self, input_ids, attention_mask):
         # Infer shapes
@@ -81,7 +124,10 @@
         sequence_length = input_ids.shape[1]

         encoder_output = self.model.encoder(
-            input_ids=input_ids, attention_mask=attention_mask, output_attentions=False, output_hidden_states=False
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            output_attentions=False,
+            output_hidden_states=False,
         )

         last_hidden_state = encoder_output["last_hidden_state"]
@@ -93,7 +139,7 @@
         present_key_value_states_sa = []
         present_key_value_states_ca = []

-        for block in decoder_blocks:
+        for i, block in enumerate(decoder_blocks):
             # Cross attention has to be initialized with the encoder hidden state
             cross_attention: T5LayerCrossAttention = block.layer[1]
             attention = cross_attention.EncDecAttention
@@ -107,28 +153,58 @@ def shape(states):
             key_states = shape(attention.k(encoder_hidden_states))
             value_states = shape(attention.v(encoder_hidden_states))

-            # cross_attn_kv_state
-            present_key_value_states_ca.append(key_states)
-            present_key_value_states_ca.append(value_states)
-
-            # Self attention kv states are initialized to zeros. This is done to keep the size of the kv cache tensor constant.
-            # The kv cache is padded here to keep a fixed shape.
-            # [key states]
-            present_key_value_states_sa.append(
-                torch.zeros(
-                    (self.num_beams * batch_size, self.config.num_heads, sequence_length - 1, self.config.d_kv),
-                    dtype=torch.float32,
-                    device=self.device,
-                )
-            )
-            # [value states]
-            present_key_value_states_sa.append(
-                torch.zeros(
-                    (self.num_beams * batch_size, self.config.num_heads, sequence_length - 1, self.config.d_kv),
-                    dtype=torch.float32,
-                    device=self.device,
-                )
-            )
+            if self.tp_degree is None:
+                # cross_attn_kv_state
+                present_key_value_states_ca.append(key_states)
+                present_key_value_states_ca.append(value_states)
+
+                # Self attention kv states are initialized to zeros. This is done to keep the size of the kv cache tensor constant.
+                # The kv cache is padded here to keep a fixed shape.
+                # [key states]
+                present_key_value_states_sa.append(
+                    torch.zeros(
+                        (self.num_beams * batch_size, self.config.num_heads, sequence_length - 1, self.config.d_kv),
+                        dtype=torch.float32,
+                        device=self.device,
+                    )
+                )
+                # [value states]
+                present_key_value_states_sa.append(
+                    torch.zeros(
+                        (self.num_beams * batch_size, self.config.num_heads, sequence_length - 1, self.config.d_kv),
+                        dtype=torch.float32,
+                        device=self.device,
+                    )
+                )
+            else:
+                present_key_value_states_ca.append((self.past_key_values_ca[i * 2] * 0) + key_states)
+                present_key_value_states_ca.append((self.past_key_values_ca[i * 2 + 1] * 0) + value_states)
+                present_key_value_states_sa.append(
+                    self.past_key_values_sa[i * 2]
+                    * torch.zeros(
+                        (
+                            self.batch_size,
+                            self.num_attention_heads_per_partition,
+                            self.max_length - 1,
+                            self.config.d_kv,
+                        ),
+                        dtype=torch.float32,
+                        device="xla",
+                    )
+                )
+                present_key_value_states_sa.append(
+                    self.past_key_values_sa[i * 2 + 1]
+                    * torch.zeros(
+                        (
+                            self.batch_size,
+                            self.num_attention_heads_per_partition,
+                            self.max_length - 1,
+                            self.config.d_kv,
+                        ),
+                        dtype=torch.float32,
+                        device="xla",
+                    )
+                )

         return present_key_value_states_sa + present_key_value_states_ca
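Note on the tensor-parallel branch above: the shape bookkeeping amounts to giving each tensor-parallel rank its own slice of the attention heads and pre-allocating fixed-shape KV-cache buffers, so the traced graph only ever sees constant tensor sizes. The standalone sketch below reproduces that arithmetic in plain PyTorch. The config values (8 heads, d_kv=64, 6 decoder layers, roughly t5-small), num_beams, max_length, and tp_degree are illustrative assumptions, and a plain integer stands in for neuronx_distributed.parallel_layers.parallel_state.get_tensor_model_parallel_size() so the snippet runs without a Neuron device.

import torch

# Illustrative values only (roughly t5-small); the wrapper in the patch reads them
# from the model config and from neuronx_distributed's parallel state.
num_attention_heads = 8    # config.num_heads
d_kv = 64                  # config.d_kv, the per-head key/value dimension
num_decoder_layers = 6     # config.num_decoder_layers
num_beams = 4
max_length = 128
tp_degree = 2              # stand-in for get_tensor_model_parallel_size()

# Each tensor-parallel rank only holds its share of the attention heads.
num_heads_per_partition = num_attention_heads // tp_degree

# One key tensor and one value tensor per decoder layer, hence 2 * num_decoder_layers
# entries per cache. The self-attention cache is padded to max_length - 1 positions
# and the cross-attention cache to max_length positions, mirroring the shapes in the
# patch, so every cache tensor keeps a constant shape.
past_key_values_sa = torch.nn.ParameterList(
    [
        torch.nn.Parameter(
            torch.ones(num_beams, num_heads_per_partition, max_length - 1, d_kv),
            requires_grad=False,
        )
        for _ in range(num_decoder_layers * 2)
    ]
)
past_key_values_ca = torch.nn.ParameterList(
    [
        torch.nn.Parameter(
            torch.ones(num_beams, num_heads_per_partition, max_length, d_kv),
            requires_grad=False,
        )
        for _ in range(num_decoder_layers * 2)
    ]
)

# The forward pass then "initializes" these fixed buffers instead of returning
# dynamically shaped tensors: multiplying a buffer by zero and adding the fresh
# cross-attention key/value states keeps the buffer in the computation while
# producing exactly those states, and multiplying by zeros yields an all-zero
# self-attention cache of the same fixed shape.
fresh_key_states = torch.randn(num_beams, num_heads_per_partition, max_length, d_kv)
ca_cache_entry = (past_key_values_ca[0] * 0) + fresh_key_states
sa_cache_entry = past_key_values_sa[0] * torch.zeros(
    num_beams, num_heads_per_partition, max_length - 1, d_kv
)

print(num_heads_per_partition)                            # 4
print(ca_cache_entry.shape)                               # torch.Size([4, 4, 128, 64])
print(sa_cache_entry.abs().sum().item())                  # 0.0
print(len(past_key_values_sa), len(past_key_values_ca))   # 12 12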