diff --git a/QEfficient/base/pytorch_transforms.py b/QEfficient/base/pytorch_transforms.py index 25d31d5b..b787cebc 100644 --- a/QEfficient/base/pytorch_transforms.py +++ b/QEfficient/base/pytorch_transforms.py @@ -4,7 +4,6 @@ # SPDX-License-Identifier: BSD-3-Clause # # ---------------------------------------------------------------------------- - from typing import Dict, Tuple, Type from torch import nn @@ -43,6 +42,9 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: for module in model.modules(): if repl_module := cls._module_mapping.get(type(module)): module.__class__ = repl_module + # Handling the __init__ calls in the models + if hasattr(module, "__qeff_init__"): + module.__qeff_init__() transformed = True return model, transformed diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py index ed7a8060..035ee014 100644 --- a/QEfficient/transformers/modeling_utils.py +++ b/QEfficient/transformers/modeling_utils.py @@ -108,6 +108,9 @@ MistralForCausalLM.__name__, MixtralForCausalLM.__name__, Starcoder2ForCausalLM.__name__, + Qwen2ForCausalLM.__name__, + Phi3ForCausalLM.__name__, + PhiForCausalLM.__name__, ] ) # Create an instance of the named tuple diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py index e33a357b..5f4aa2e5 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -20,17 +20,91 @@ ) from transformers.models.llama.modeling_llama import ( LlamaAttention, + LlamaConfig, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, - apply_rotary_pos_emb, + LlamaRotaryEmbedding, logger, repeat_kv, + rotate_half, ) from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +class QEffLlamaRotaryEmbedding(LlamaRotaryEmbedding): + """ + Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py + The only differences are: + - Add static sin/cos computations. + """ + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super(LlamaRotaryEmbedding, self).__init__() # Initialize nn.Module + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. 
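# A minimal, self-contained sketch of the class-swap pattern added in pytorch_transforms.py above:
# the transform replaces a module's __class__ from a mapping and, when the replacement defines
# __qeff_init__, calls it so any re-initialisation (e.g. building a QEff rotary embedding) runs.
# DummyAttention/PatchedAttention and apply_transform are illustrative names, not the library API.
import torch
from torch import nn

class DummyAttention(nn.Module):
    def forward(self, x):
        return x

class PatchedAttention(DummyAttention):
    def __qeff_init__(self):
        # state created only after the class swap, analogous to the __qeff_init__ hooks below
        self.register_buffer("scale", torch.tensor(1.0), persistent=False)

    def forward(self, x):
        return x * self.scale

_module_mapping = {DummyAttention: PatchedAttention}

def apply_transform(model: nn.Module) -> nn.Module:
    for module in model.modules():
        if repl_module := _module_mapping.get(type(module)):
            module.__class__ = repl_module
            if hasattr(module, "__qeff_init__"):  # mirrors the __init__ handling added above
                module.__qeff_init__()
    return model

# usage: apply_transform(nn.Sequential(DummyAttention()))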
+ self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + + freqs = torch.outer(t, self.inv_freq) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + # Apply rotation + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + # Cast back to original dtype + return q_embed.to(q.dtype), k_embed.to(k.dtype) + + class QEffLlamaAttention(LlamaAttention): """ Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py @@ -38,6 +112,19 @@ class QEffLlamaAttention(LlamaAttention): - add new args cache idx for the kv retention """ + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + # Define the general __qeff_init__() for any changes in the init calls + # Set the init in the module mapping pytorch transforms + self.__qeff_init__() + + def __qeff_init__(self): + self.rotary_emb = QEffLlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + def forward( self, hidden_states: torch.Tensor, @@ -78,8 +165,18 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache diff --git a/QEfficient/transformers/models/mistral/modeling_mistral.py b/QEfficient/transformers/models/mistral/modeling_mistral.py index c183024e..c506e3b4 100644 --- a/QEfficient/transformers/models/mistral/modeling_mistral.py +++ b/QEfficient/transformers/models/mistral/modeling_mistral.py @@ -15,27 +15,97 @@ from torch import nn from torch.nn import CrossEntropyLoss from transformers.cache_utils import Cache, DynamicCache -from transformers.modeling_attn_mask_utils import ( - _prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, -) from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, ) from transformers.models.mistral.modeling_mistral import ( MistralAttention, + MistralConfig, MistralDecoderLayer, MistralForCausalLM, MistralModel, - apply_rotary_pos_emb, + MistralRotaryEmbedding, logger, repeat_kv, + rotate_half, ) from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +class QEffMistralRotaryEmbedding(MistralRotaryEmbedding): + """ + Copied from MistralForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py + The only differences are: + - Add static sin/cos computations. 
+ """ + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super(MistralRotaryEmbedding, self).__init__() # Initialize nn.Module + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + + freqs = torch.outer(t, self.inv_freq) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
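# Self-contained sketch of the rotary pattern introduced above for Llama/Mistral (and repeated for
# Mixtral/Phi3 below): cos/sin tables are precomputed once and qeff_apply_rotary_pos_emb gathers
# rows by position_ids, so the exported graph sees static sin/cos lookups. Shapes are illustrative.
import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def build_cos_sin(dim, max_positions, base=10000.0):
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    freqs = torch.outer(torch.arange(max_positions).float(), inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos(), emb.sin()                       # each [max_positions, dim]

def apply_rope(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

# usage: 3 query tokens written at cache positions 5..7
cos, sin = build_cos_sin(dim=64, max_positions=128)
q, k = torch.randn(1, 4, 3, 64), torch.randn(1, 4, 3, 64)
q_rot, k_rot = apply_rope(q, k, cos, sin, position_ids=torch.tensor([[5, 6, 7]]))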
+ """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + # Apply rotation + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + # Cast back to original dtype + return q_embed.to(q.dtype), k_embed.to(k.dtype) + + class QEffMistralAttention(MistralAttention): """ Copied from MistralForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py @@ -43,6 +113,19 @@ class QEffMistralAttention(MistralAttention): - add new args cache idx for the kv retention """ + def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + # Define the general __qeff_init__() for any changes in the init calls + # Set the init in the module mapping pytorch transforms + self.__qeff_init__() + + def __qeff_init__(self): + self.rotary_emb = QEffMistralRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + def forward( self, hidden_states: torch.Tensor, @@ -52,6 +135,7 @@ def forward( batch_index: Optional[torch.LongTensor] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -73,7 +157,7 @@ def forward( ) kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} @@ -85,18 +169,7 @@ def forward( attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - + if attention_mask is not None: # no matter the length, we just slice it attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) # upcast attention to fp32 @@ -111,6 +184,7 @@ def forward( ) attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) @@ -140,6 +214,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -150,82 +225,43 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and 
decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - past_key_values_length = 0 - - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() + use_cache = False if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: - is_padding_right = attention_mask[:, -1].sum().item() != batch_size - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) - - if self._attn_implementation == "flash_attention_2": - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._attn_implementation == "sdpa" and not output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache) and not self.training: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + return_legacy_cache = True + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. 
" + "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" ) - elif attention_mask is None: - # Causal mask with # --- Rolling buffer --- and # Sliding window mask - # Change for Cloud AI 100 (vbaddi) - attention_mask = _create_causal_mask( - position_ids=position_ids, - target_length=past_key_values_length, - sliding_window=self.config.sliding_window, - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + causal_mask = _create_causal_mask( + position_ids=position_ids, target_length=target_length, sliding_window=self.config.sliding_window + ) + hidden_states = inputs_embeds # decoder layers @@ -241,30 +277,23 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, + causal_mask, position_ids, past_key_values, output_attentions, use_cache, - ) - elif batch_index is not None: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - batch_index=batch_index, - output_attentions=output_attentions, - use_cache=use_cache, + cache_position, ) else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=causal_mask, position_ids=position_ids, + batch_index=batch_index, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] @@ -281,9 +310,9 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = None - if use_cache: - next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) @@ -311,6 +340,8 @@ def forward( batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -339,6 +370,8 @@ def forward( batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, + **kwargs, ) hidden_states = residual + hidden_states @@ -379,6 +412,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -424,6 +458,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + 
cache_position=cache_position, ) # Cast to int32 to avoid ONNXRT issue diff --git a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py index bc7970c4..79031b02 100644 --- a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py +++ b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py @@ -16,26 +16,96 @@ from torch import nn from torch.nn import CrossEntropyLoss from transformers.cache_utils import Cache, DynamicCache -from transformers.modeling_attn_mask_utils import ( - _prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, -) from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast from transformers.models.mixtral.modeling_mixtral import ( MixtralAttention, + MixtralConfig, MixtralDecoderLayer, MixtralForCausalLM, MixtralModel, + MixtralRotaryEmbedding, MixtralSparseMoeBlock, - apply_rotary_pos_emb, load_balancing_loss_func, logger, repeat_kv, + rotate_half, ) from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +class QEffMixtralRotaryEmbedding(MixtralRotaryEmbedding): + """ + Copied from MixtralForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py + The only differences are: + - Add static sin/cos computations. + """ + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super(MixtralRotaryEmbedding, self).__init__() # Initialize nn.Module + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + + freqs = torch.outer(t, self.inv_freq) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. 
+ unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + # Apply rotation + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + # Cast back to original dtype + return q_embed.to(q.dtype), k_embed.to(k.dtype) + + class QEffMixtralAttention(MixtralAttention): """ Copied from MixtralAttention: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py @@ -43,6 +113,19 @@ class QEffMixtralAttention(MixtralAttention): - add new args position idx for the cache_kwargs for kv retention """ + def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + # Define the general __qeff_init__() for any changes in the init calls + # Set the init in the module mapping pytorch transforms + self.__qeff_init__() + + def __qeff_init__(self): + self.rotary_emb = QEffMixtralRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + def forward( self, hidden_states: torch.Tensor, @@ -52,6 +135,7 @@ def forward( batch_index: Optional[torch.LongTensor] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -73,7 +157,7 @@ def forward( ) kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: cache_kwargs = { @@ -201,6 +285,7 @@ def forward( output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, MoeModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_router_logits = ( @@ -213,17 +298,10 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify 
either decoder_input_ids or decoder_inputs_embeds") - - past_key_values_length = 0 + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) if self.gradient_checkpointing and self.training: if use_cache: @@ -232,63 +310,30 @@ def forward( ) use_cache = False - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + use_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache) and not self.training: + use_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " + "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: - is_padding_right = attention_mask[:, -1].sum().item() != batch_size - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) - - if self._attn_implementation == "flash_attention_2": - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._attn_implementation == "sdpa" and not output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. 
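# The hunks in this file (and in the Mistral/Phi/Phi3 files) replace HF's
# _prepare_4d_causal_attention_mask* helpers with QEfficient's
# _create_causal_mask(position_ids, target_length, sliding_window). The function below is only an
# illustrative approximation of such a mask (True = key position blocked), not the library's
# implementation; the attention forward applies it via torch.where(mask, -10000.0, attn_weights).
import torch

def sketch_causal_mask(position_ids, target_length, sliding_window=None):
    # position_ids: [batch, q_len]; the KV cache exposes `target_length` key slots.
    kv_pos = torch.arange(target_length, device=position_ids.device)  # [target_length]
    q_pos = position_ids.unsqueeze(-1)                                # [batch, q_len, 1]
    blocked = kv_pos > q_pos                                          # future keys
    if sliding_window is not None:
        blocked = blocked | (kv_pos <= q_pos - sliding_window)        # keys outside the window
    return blocked.unsqueeze(1)                                       # [batch, 1, q_len, target_length]

# e.g. one decode token at position 5 against a 16-slot cache:
# sketch_causal_mask(torch.tensor([[5]]), target_length=16, sliding_window=4)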
- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, - ) - elif attention_mask is None: - # Causal mask with # --- Rolling buffer --- and # Sliding window mask - # Change for Cloud AI 100 (vbaddi) - attention_mask = _create_causal_mask( - position_ids=position_ids, - target_length=past_key_values_length, - sliding_window=self.config.sliding_window, - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + causal_mask = _create_causal_mask( + position_ids=position_ids, target_length=target_length, sliding_window=self.config.sliding_window + ) hidden_states = inputs_embeds @@ -306,33 +351,25 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, + causal_mask, position_ids, past_key_values, output_attentions, output_router_logits, use_cache, - ) - elif batch_index is not None: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - batch_index=batch_index, - output_attentions=output_attentions, - output_router_logits=output_router_logits, - use_cache=use_cache, + cache_position, ) else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=causal_mask, position_ids=position_ids, + batch_index=batch_index, past_key_value=past_key_values, output_attentions=output_attentions, output_router_logits=output_router_logits, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] @@ -388,6 +425,7 @@ def forward( output_attentions: Optional[bool] = False, output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -425,6 +463,7 @@ def forward( batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = residual + hidden_states @@ -470,6 +509,7 @@ def forward( output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r""" Args: @@ -520,6 +560,7 @@ def forward( output_hidden_states=output_hidden_states, output_router_logits=output_router_logits, return_dict=return_dict, + cache_position=cache_position, ) # Cast to int32 to avoid ONNXRT issue diff --git a/QEfficient/transformers/models/phi/modeling_phi.py b/QEfficient/transformers/models/phi/modeling_phi.py index f4a75308..188e603e 100644 --- a/QEfficient/transformers/models/phi/modeling_phi.py +++ b/QEfficient/transformers/models/phi/modeling_phi.py @@ -15,13 +15,10 @@ 
from torch import nn from torch.nn import CrossEntropyLoss from transformers.cache_utils import Cache, DynamicCache -from transformers.modeling_attn_mask_utils import ( - _prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, -) from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.models.phi.modeling_phi import ( PhiAttention, + PhiDecoderLayer, PhiForCausalLM, PhiModel, apply_rotary_pos_emb, @@ -47,8 +44,10 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -93,7 +92,7 @@ def forward( if past_key_value is not None: # Update the cache_kwargs with position_ids for Cloud AI 100 - cache_kwargs = {"sin": sin, "cos": cos, "position_ids": position_ids} + cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -154,11 +153,13 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -168,17 +169,10 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape[:2] - elif inputs_embeds is not None: - batch_size, seq_length = inputs_embeds.shape[:2] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - past_key_values_length = 0 + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) if self.gradient_checkpointing and self.training: if use_cache: @@ -187,44 +181,30 @@ def forward( ) use_cache = False - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + use_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache) and not self.training: + use_legacy_cache = 
True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " + "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" ) - position_ids = position_ids.unsqueeze(0) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - inputs_embeds = self.embed_dropout(inputs_embeds) - - # Attention mask. - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif attention_mask is None: - # update attention mask for Cloud Ai 100 - attention_mask = _create_causal_mask(position_ids, past_key_values_length) - elif self._use_sdpa and not output_attentions: - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + causal_mask = _create_causal_mask(position_ids=position_ids, target_length=target_length) + inputs_embeds = self.embed_dropout(inputs_embeds) hidden_states = inputs_embeds # decoder layers @@ -240,19 +220,33 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, + causal_mask, position_ids, - past_key_values, output_attentions, + use_cache, + past_key_values, + cache_position, + ) + elif batch_index is not None: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + batch_index=batch_index, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, ) else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] @@ -282,6 +276,79 @@ def forward( ) +class QEffPhiDecoderLayer(PhiDecoderLayer): + """ + Copied from PhiForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/phi/modeling_phi.py + The only differences are: + - add new args position idx for the cache_kwargs for kv retention + - update the hidden_states, and fix for onnx model + """ + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> 
Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): + input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range + `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + attn_outputs, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + batch_index=batch_index, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + attn_outputs = self.resid_dropout(attn_outputs) + + feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states)) + hidden_states = attn_outputs + feed_forward_hidden_states + residual + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + class QEffPhiForCausalLM(PhiForCausalLM): """ Copied from PhiForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/phi/modeling_phi.py @@ -296,12 +363,14 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -340,6 +409,7 @@ def forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, + batch_index=batch_index, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/phi3/modeling_phi3.py b/QEfficient/transformers/models/phi3/modeling_phi3.py index 17f1a8bf..68b01506 100644 --- a/QEfficient/transformers/models/phi3/modeling_phi3.py +++ b/QEfficient/transformers/models/phi3/modeling_phi3.py @@ -15,20 +15,94 @@ from torch import nn from torch.nn import CrossEntropyLoss from transformers.cache_utils import Cache, DynamicCache 
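# Several of the rewritten forwards accept past_key_values either as the legacy tuple format or as
# a Cache object, converting with DynamicCache.from_legacy_cache()/to_legacy_cache(). A minimal
# round-trip sketch with the stock transformers DynamicCache (the QEfficient cache additionally
# consumes the position_ids/batch_index cache_kwargs shown in the attention hunks):
import torch
from transformers.cache_utils import DynamicCache

bsz, heads, past_len, head_dim = 1, 4, 6, 64
legacy = tuple(
    (torch.randn(bsz, heads, past_len, head_dim), torch.randn(bsz, heads, past_len, head_dim))
    for _ in range(2)                                  # two layers
)

cache = DynamicCache.from_legacy_cache(legacy)         # tuple -> Cache
assert cache.get_seq_length() == past_len              # used above to derive cache_position

# appending one decode step returns the full (past + new) key/value states for that layer
k_all, v_all = cache.update(torch.randn(bsz, heads, 1, head_dim),
                            torch.randn(bsz, heads, 1, head_dim), layer_idx=0)
assert k_all.shape[-2] == past_len + 1

legacy_again = cache.to_legacy_cache()                 # Cache -> tuple, returned when the caller passed a legacy cache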
-from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.models.phi3.modeling_phi3 import ( Phi3Attention, + Phi3Config, + Phi3DecoderLayer, Phi3ForCausalLM, Phi3Model, - apply_rotary_pos_emb, + Phi3RotaryEmbedding, logger, repeat_kv, + rotate_half, ) from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +class QEffPhi3RotaryEmbedding(Phi3RotaryEmbedding): + """ + Copied from Phi3ForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py + The only differences are: + - Add static sin/cos computations. + """ + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super(Phi3RotaryEmbedding, self).__init__() # Initialize nn.Module + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + + freqs = torch.outer(t, self.inv_freq) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + # Apply rotation + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + # Cast back to original dtype + return q_embed.to(q.dtype), k_embed.to(k.dtype) + + class QEffPhi3Attention(Phi3Attention): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -38,14 +112,29 @@ class QEffPhi3Attention(Phi3Attention): - add new args position idx for the cache_kwargs for kv retention """ + def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + # Define the general __qeff_init__() for any changes in the init calls + # Set the init in the module mapping pytorch transforms + self.__qeff_init__() + + def __qeff_init__(self): + self.rotary_emb = QEffPhi3RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: logger.warning_once("You are not running the flash-attention implementation, expect numerical differences.") @@ -70,13 +159,13 @@ def forward( "with a layer index." ) kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: # Update the cache_kwargs with position_ids for Cloud AI 100 - cache_kwargs = {"sin": sin, "cos": cos, "position_ids": position_ids} + cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # repeat k/v heads if n_kv_heads < n_heads @@ -135,11 +224,13 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -149,17 +240,10 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape[:2] - elif inputs_embeds is not None: - batch_size, seq_length = 
inputs_embeds.shape[:2] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - past_key_values_length = 0 + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) if self.gradient_checkpointing and self.training: if use_cache: @@ -168,48 +252,30 @@ def forward( ) use_cache = False - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + use_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache) and not self.training: + use_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " + "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: - is_padding_right = attention_mask[:, -1].sum().item() != batch_size - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. 
" - ) - - if self._attn_implementation == "flash_attention_2": - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif attention_mask is None: - # update attention mask for Cloud Ai 100 - attention_mask = _create_causal_mask(position_ids, past_key_values_length, self.config.sliding_window) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + causal_mask = _create_causal_mask( + position_ids=position_ids, target_length=target_length, sliding_window=self.config.sliding_window + ) hidden_states = inputs_embeds @@ -226,20 +292,33 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, + causal_mask, position_ids, past_key_values, output_attentions, use_cache, + cache_position, + ) + elif batch_index is not None: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + batch_index=batch_index, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, ) else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] @@ -269,6 +348,83 @@ def forward( ) +class QEffPhi3DecoderLayer(Phi3DecoderLayer): + """ + Copied from Phi3ForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/phi3/modeling_phi3.py + The only differences are: + - add new args position idx for the cache_kwargs for kv retention + - update the hidden_states, and fix for onnx model + """ + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): + input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range + `[0, config.n_positions - 1]`. 
[What are position IDs?](../glossary#position-ids) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + attn_outputs, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + batch_index=batch_index, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = residual + self.resid_attn_dropout(attn_outputs) + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.resid_mlp_dropout(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + class QEffPhi3ForCausalLM(Phi3ForCausalLM): """ Copied from Phi3ForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/phi3/modeling_phi3.py @@ -283,12 +439,14 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -321,6 +479,7 @@ def forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, + batch_index=batch_index, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/qwen2/modeling_qwen2.py b/QEfficient/transformers/models/qwen2/modeling_qwen2.py index d5dc1b00..039cb569 100644 --- a/QEfficient/transformers/models/qwen2/modeling_qwen2.py +++ b/QEfficient/transformers/models/qwen2/modeling_qwen2.py @@ -15,16 +15,13 @@ from torch import nn from torch.nn import CrossEntropyLoss from transformers.cache_utils import Cache, DynamicCache -from transformers.modeling_attn_mask_utils import ( - _prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, -) from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, ) from transformers.models.qwen2.modeling_qwen2 import ( Qwen2Attention, + Qwen2DecoderLayer, Qwen2ForCausalLM, Qwen2Model, apply_rotary_pos_emb, @@ -48,8 +45,10 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: 
Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -75,7 +74,7 @@ def forward( if past_key_value is not None: # Update the cache_kwargs with position_ids for Cloud AI 100 - cache_kwargs = {"sin": sin, "cos": cos, "position_ids": position_ids} + cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # repeat k/v heads if n_kv_heads < n_heads @@ -129,11 +128,13 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -143,15 +144,10 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) if self.gradient_checkpointing and self.training: if use_cache: @@ -160,60 +156,30 @@ def forward( ) use_cache = False - past_key_values_length = 0 - - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + use_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache) and not self.training: + use_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. 
" + "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: - is_padding_right = attention_mask[:, -1].sum().item() != batch_size - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) - - if self._attn_implementation == "flash_attention_2": - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._attn_implementation == "sdpa" and not output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, - ) - elif attention_mask is None: - # update attention mask for Cloud Ai 100 - attention_mask = _create_causal_mask(position_ids, past_key_values_length, self.config.sliding_window) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + causal_mask = _create_causal_mask( + position_ids=position_ids, target_length=target_length, sliding_window=self.config.sliding_window + ) hidden_states = inputs_embeds @@ -230,20 +196,33 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, + causal_mask, position_ids, past_key_values, output_attentions, use_cache, + cache_position, + ) + elif batch_index is not None: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + batch_index=batch_index, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, ) else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] @@ -274,6 +253,79 @@ def forward( ) +class QEffQwen2DecoderLayer(Qwen2DecoderLayer): + """ + Copied from Qwen2ForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2/modeling_qwen2.py + The only differences are: + 
- add new args position idx for the cache_kwargs for kv retention + - update the hidden_states, and fix for onnx model + """ + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + batch_index=batch_index, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + class QEffQwen2ForCausalLM(Qwen2ForCausalLM): """ Copied from Qwen2ForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2/modeling_qwen2.py @@ -288,12 +340,14 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -333,6 +387,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, diff --git a/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py 
b/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py index eb92c7da..9bff542a 100644 --- a/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py +++ b/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py @@ -15,13 +15,10 @@ from torch import nn from torch.nn import CrossEntropyLoss from transformers.cache_utils import Cache, DynamicCache -from transformers.modeling_attn_mask_utils import ( - _prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, -) from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.models.starcoder2.modeling_starcoder2 import ( Starcoder2Attention, + Starcoder2DecoderLayer, Starcoder2ForCausalLM, Starcoder2Model, apply_rotary_pos_emb, @@ -50,6 +47,7 @@ def forward( batch_index: Optional[torch.LongTensor] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -141,6 +139,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -150,15 +149,11 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( @@ -166,60 +161,30 @@ def forward( ) use_cache = False - past_key_values_length = 0 - - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + use_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache) and not self.training: + use_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. 
" + "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: - is_padding_right = attention_mask[:, -1].sum().item() != batch_size - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Starcoder2. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) - - if self._attn_implementation == "flash_attention_2": - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._attn_implementation == "sdpa" and not output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, - ) - elif attention_mask is None: - # update attention mask for Cloud Ai 100 - attention_mask = _create_causal_mask(position_ids, past_key_values_length, self.config.sliding_window) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + causal_mask = _create_causal_mask( + position_ids=position_ids, target_length=target_length, sliding_window=self.config.sliding_window + ) hidden_states = inputs_embeds hidden_states = nn.functional.dropout(hidden_states, p=self.embedding_dropout, training=self.training) @@ -237,30 +202,23 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, + causal_mask, position_ids, past_key_values, output_attentions, use_cache, - ) - elif batch_index is not None: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - batch_index=batch_index, - output_attentions=output_attentions, - use_cache=use_cache, + cache_position, ) else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] @@ -291,8 +249,14 @@ def forward( ) -class QEFFStarcoder2DecoderLayer(nn.Module): - # Copied from 
transformers.models.qwen2.modeling_qwen2.Qwen2DecoderLayer.forward +class QEFFStarcoder2DecoderLayer(Starcoder2DecoderLayer): + """ + Copied from Starcoder2DecoderLayer: https://github.com/huggingface/transformers/blob/main/src/transformers/models/starcoder2/modeling_starcoder2.py + The only differences are: + - add new args position idx for the cache_kwargs for kv retention + - update causal attention mask + """ + def forward( self, hidden_states: torch.Tensor, @@ -337,6 +301,7 @@ def forward( batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = residual + hidden_states @@ -378,6 +343,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -423,6 +389,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) # Cast to INT32 to avoid issue while running in ONNXRT diff --git a/QEfficient/transformers/pytorch_transforms.py b/QEfficient/transformers/pytorch_transforms.py index d2b1498c..21390bd2 100644 --- a/QEfficient/transformers/pytorch_transforms.py +++ b/QEfficient/transformers/pytorch_transforms.py @@ -44,9 +44,21 @@ MixtralSparseMoeBlock, ) from transformers.models.mpt.modeling_mpt import MptAttention, MptBlock, MptForCausalLM, MptModel -from transformers.models.phi.modeling_phi import PhiAttention, PhiForCausalLM, PhiModel -from transformers.models.phi3.modeling_phi3 import Phi3Attention, Phi3ForCausalLM, Phi3Model, Phi3RMSNorm -from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2ForCausalLM, Qwen2Model, Qwen2RMSNorm +from transformers.models.phi.modeling_phi import PhiAttention, PhiDecoderLayer, PhiForCausalLM, PhiModel +from transformers.models.phi3.modeling_phi3 import ( + Phi3Attention, + Phi3DecoderLayer, + Phi3ForCausalLM, + Phi3Model, + Phi3RMSNorm, +) +from transformers.models.qwen2.modeling_qwen2 import ( + Qwen2Attention, + Qwen2DecoderLayer, + Qwen2ForCausalLM, + Qwen2Model, + Qwen2RMSNorm, +) from transformers.models.starcoder2.modeling_starcoder2 import ( Starcoder2Attention, Starcoder2DecoderLayer, @@ -99,9 +111,24 @@ QEffMptForCausalLM, QEFfMptModel, ) -from QEfficient.transformers.models.phi.modeling_phi import QEffPhiAttention, QEffPhiForCausalLM, QEffPhiModel -from QEfficient.transformers.models.phi3.modeling_phi3 import QEffPhi3Attention, QEffPhi3ForCausalLM, QEffPhi3Model -from QEfficient.transformers.models.qwen2.modeling_qwen2 import QEffQwen2Attention, QEffQwen2ForCausalLM, QEffQwen2Model +from QEfficient.transformers.models.phi.modeling_phi import ( + QEffPhiAttention, + QEffPhiDecoderLayer, + QEffPhiForCausalLM, + QEffPhiModel, +) +from QEfficient.transformers.models.phi3.modeling_phi3 import ( + QEffPhi3Attention, + QEffPhi3DecoderLayer, + QEffPhi3ForCausalLM, + QEffPhi3Model, +) +from QEfficient.transformers.models.qwen2.modeling_qwen2 import ( + QEffQwen2Attention, + QEffQwen2DecoderLayer, + QEffQwen2ForCausalLM, + QEffQwen2Model, +) from QEfficient.transformers.models.starcoder2.modeling_starcoder2 import ( QEffStarcoder2Attention, QEFFStarcoder2DecoderLayer, @@ -194,4 +221,10 @@ class CBTransform(KVCacheTransform): MixtralDecoderLayer: QeffMixtralDecoderLayer, # Starcoder2 Starcoder2DecoderLayer: QEFFStarcoder2DecoderLayer, + # Phi + PhiDecoderLayer: 
QEffPhiDecoderLayer, + # Phi3 + Phi3DecoderLayer: QEffPhi3DecoderLayer, + # Qwen2 + Qwen2DecoderLayer: QEffQwen2DecoderLayer, } diff --git a/README.md b/README.md index a3f823d9..65032397 100644 --- a/README.md +++ b/README.md @@ -54,12 +54,12 @@ source qeff_env/bin/activate pip install -U pip # Clone and Install the QEfficient Repo. -pip install git+https://github.com/quic/efficient-transformers +pip install git+https://github.com/quic/efficient-transformers --extra-index-url https://download.pytorch.org/whl/cpu # Or build wheel package using the below command. pip install build wheel python -m build --wheel --outdir dist -pip install dist/QEfficient-0.0.1.dev0-py3-none-any.whl +pip install dist/QEfficient-0.0.1.dev0-py3-none-any.whl --extra-index-url https://download.pytorch.org/whl/cpu ``` diff --git a/pyproject.toml b/pyproject.toml index ab2d5e94..a45798d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,8 +19,9 @@ classifiers = [ ] requires-python = "==3.8.*" dependencies = [ - "transformers==4.41.2", - "torch@https://download.pytorch.org/whl/cpu/torch-2.0.0%2Bcpu-cp38-cp38-linux_x86_64.whl#sha256=354f281351cddb590990089eced60f866726415f7b287db5105514aa3c5f71ca", + "transformers==4.44.2", + "torch==2.0.0+cpu", + "peft==0.12.0", "datasets==2.7.0", "fsspec==2023.6.0", "multidict==6.0.4", diff --git a/scripts/Jenkinsfile b/scripts/Jenkinsfile index 48ca48ca..79ccc292 100644 --- a/scripts/Jenkinsfile +++ b/scripts/Jenkinsfile @@ -19,7 +19,7 @@ pipeline python3.8 -m venv preflight_qeff . preflight_qeff/bin/activate pip install --upgrade pip setuptools - pip install .[test] + pip install .[test] --extra-index-url https://download.pytorch.org/whl/cpu pip install junitparser pytest-xdist rm -rf QEfficient ''' diff --git a/tests/transformers/test_transformer_pytorch_transforms.py b/tests/transformers/test_transformer_pytorch_transforms.py index a7ded6ea..a458ebeb 100644 --- a/tests/transformers/test_transformer_pytorch_transforms.py +++ b/tests/transformers/test_transformer_pytorch_transforms.py @@ -175,6 +175,7 @@ def test_kv_cache_transform( num_attention_heads=num_attention_heads, hidden_size=hidden_size, use_cache=True, + cache_position=None, ) hf_model = AutoModelForCausalLM.from_config(config=config, attn_implementation="eager")
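
The Phi3, Qwen2, and Starcoder2 model forwards in this patch all drop the `_prepare_4d_causal_attention_mask*` helpers in favour of a single `_create_causal_mask(position_ids=..., target_length=..., sliding_window=...)` call driven purely by `position_ids`. The snippet below is a minimal sketch of that idea only, not the QEfficient implementation: the function name, return convention, and boolean-mask encoding are illustrative assumptions. It builds a `[batch, 1, query_len, target_length]` mask by comparing each key position against the query's position id, optionally masking keys that fall outside a sliding window.

```python
from typing import Optional

import torch


def sketch_causal_mask(
    position_ids: torch.Tensor, target_length: int, sliding_window: Optional[int] = None
) -> torch.Tensor:
    """Illustrative only: True marks key positions a query must NOT attend to."""
    key_pos = torch.arange(target_length).view(1, 1, 1, target_length)  # [1, 1, 1, T]
    query_pos = position_ids.unsqueeze(1).unsqueeze(-1)                 # [B, 1, Q, 1]
    mask = key_pos > query_pos                                          # mask future tokens
    if sliding_window is not None:
        mask = mask | (key_pos <= query_pos - sliding_window)           # and tokens outside the window
    return mask                                                         # [B, 1, Q, T]


# Prefill of a 4-token prompt: each token may attend to itself and everything before it.
print(sketch_causal_mask(torch.arange(4).unsqueeze(0), target_length=4))
```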
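
Each updated `forward` also derives `cache_position`, and a default `position_ids`, from the number of tokens already stored in the KV cache, wrapping a legacy tuple of past key/values into a `Cache` object first. A small sketch of that flow using the `transformers` `DynamicCache` API follows; the tensor shapes are made up for illustration.

```python
import torch
from transformers.cache_utils import DynamicCache

# Hypothetical shapes: 1 layer, batch 1, 2 KV heads, 3 cached tokens, head_dim 4.
legacy_past = ((torch.zeros(1, 2, 3, 4), torch.zeros(1, 2, 3, 4)),)

past_key_values = DynamicCache.from_legacy_cache(legacy_past)  # tuple -> Cache object
past_seen_tokens = past_key_values.get_seq_length()            # 3

# cache_position indexes the new tokens fed in this forward pass (2 here); when the
# caller passes no position_ids, they default to cache_position with a batch dimension.
new_tokens = 2
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + new_tokens)
position_ids = cache_position.unsqueeze(0)

print(cache_position.tolist(), tuple(position_ids.shape))      # [3, 4] (1, 2)

# After the decoder layers run, a legacy-style tuple can be returned again.
legacy_again = past_key_values.to_legacy_cache()
```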
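
Finally, the attention modules now pass a `batch_index` entry through `cache_kwargs` so the cache update can target individual rows during continuous batching. The helper below is purely hypothetical, not QEfficient's cache update, but it shows the kind of scatter that `batch_index` plus `position_ids` make possible when only a subset of the resident sequences decodes in a given step.

```python
import torch


def sketch_kv_scatter(k_cache, new_k, batch_index, position_ids):
    """Hypothetical helper: write one decoded token per active sequence into a
    pre-allocated cache, selecting rows with batch_index and slots with position_ids."""
    rows = batch_index.view(-1)   # which resident sequences are decoding this step
    cols = position_ids.view(-1)  # where each sequence's new token belongs
    # k_cache: [full_batch, kv_heads, max_ctx, head_dim]; new_k: [decode_batch, kv_heads, 1, head_dim]
    k_cache[rows, :, cols, :] = new_k[:, :, 0, :]
    return k_cache


cache = torch.zeros(4, 2, 8, 16)  # 4 resident sequences, 2 KV heads, context 8, head_dim 16
new_k = torch.randn(2, 2, 1, 16)  # only sequences 0 and 3 decode this step
sketch_kv_scatter(cache, new_k, batch_index=torch.tensor([[0], [3]]), position_ids=torch.tensor([[5], [2]]))
```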