diff --git a/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
index e35b4cac52..ca0d06aebc 100644
--- a/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
+++ b/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -1,6 +1,8 @@
+import math
 from typing import List, Optional, Tuple, Union
 
 import torch
+import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
@@ -9,6 +11,141 @@
 from ...modeling_attn_mask_utils import GaudiAttentionMaskConverter
 
 
+try:
+    from habana_frameworks.torch.hpex.kernels import FusedSDPA
+except ImportError:
+    print("Not using HPU fused scaled dot-product attention kernel.")
+    FusedSDPA = None
+
+import habana_frameworks.torch.core as htcore
+
+
+def gaudi_flash_attn_v1(
+    query_layer, key_layer, value_layer, attention_mask, dropout_rate, is_causal, scale, softmax_mode, q_block_size
+):
+    """
+    Gaudi version of Flash Attention V1 to support long sequences in the prompt phase.
+    The causal mask is not supported in this optimization.
+    """
+    if is_causal:
+        raise ValueError("Causal mask is not supported for long input sequences")
+
+    q_len = query_layer.size(-2)
+    q_tiles = (q_len // q_block_size) if (q_len % q_block_size == 0) else math.ceil(q_len / q_block_size)
+    q_padding = q_tiles * q_block_size - q_len
+    query_layer = F.pad(query_layer, (0, 0, 0, q_padding), "constant", 0)
+    if attention_mask is not None:
+        attention_mask = F.pad(attention_mask, (0, 0, 0, q_padding), "constant", -10000.0)
+    row_o_list = []
+    for i in range(q_tiles):
+        s, e = i * q_block_size, (i + 1) * q_block_size
+        row_q = query_layer[:, :, s:e, :]
+        row_mask = attention_mask[:, :, s:e, :] if attention_mask is not None else None
+        attn_output_partial = FusedSDPA.apply(
+            row_q, key_layer, value_layer, row_mask, dropout_rate, is_causal, scale, softmax_mode
+        )
+        row_o_list.append(attn_output_partial)
+    attn_output = torch.cat(row_o_list, dim=-2)
+    if q_padding != 0:
+        attn_output = attn_output[:, :, :-q_padding, :]
+    return attn_output
+
+
+def apply_FusedSDPA(
+    self,
+    query,
+    key,
+    value,
+    attention_mask=None,
+    flash_attention_recompute=False,
+    flash_attention_fast_softmax=False,
+    flash_attention_causal_mask=False,
+):
+    """
+    Copied from GPTBigCodeSdpaAttention._attn: https://github.com/huggingface/transformers/blob/v4.40-release/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
+    The only differences are:
+    - replaced torch.nn.functional.scaled_dot_product_attention with Habana's FusedSDPA
+    - removed the workaround that expands the key and value tensors over the heads dimension; it also works but dramatically drops throughput
+    - added args flash_attention_recompute, flash_attention_fast_softmax and flash_attention_causal_mask to control the parameters of FusedSDPA
+    - added special-case handling of inputs longer than 8192 tokens with the function gaudi_flash_attn_v1
+    """
+
+    scale = None
+    if not self.scale_attn_weights:
+        scale = 1
+
+    # MQA models: (batch_size, query_length, num_heads * head_dim)
+    # MHA models: (batch_size, num_heads, query_length, head_dim)
+    query_shape = query.shape
+    batch_size = query_shape[0]
+
+    if self.multi_query:
+        query_length = query_shape[1]
+
+        # SDPA requires the dimension [..., sequence_length, head_dim].
+        query = query.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
+
+        # Without these unsqueeze, SDPA complains as the query and key/value have a different number of dimensions.
+        key = key.unsqueeze(1)
+        value = value.unsqueeze(1)
+
+    else:
+        query_length = query_shape[-1]
+
+    if attention_mask is not None:
+        query = query.contiguous()
+        key = key.contiguous()
+        value = value.contiguous()
+
+    sdpa_result = None
+    enable_recompute = flash_attention_recompute and query_length > 1
+
+    if query_length > 1 and flash_attention_causal_mask:
+        attention_mask = None
+        use_causal_mask = True
+    else:
+        use_causal_mask = self.is_causal and attention_mask is None and query_length > 1
+
+    import habana_frameworks.torch.hpu as ht
+
+    with ht.sdp_kernel(enable_recompute=enable_recompute):
+        if query_length > 8192:
+            sdpa_result = gaudi_flash_attn_v1(
+                query,
+                key,
+                value,
+                attention_mask,
+                self.attn_pdrop if self.training else 0.0,
+                use_causal_mask,
+                scale,
+                "fast" if flash_attention_fast_softmax else "None",
+                4096,
+            )
+            htcore.mark_step()
+        else:
+            sdpa_result = FusedSDPA.apply(
+                query,
+                key,
+                value,
+                attention_mask,
+                self.attn_pdrop if self.training else 0.0,
+                use_causal_mask,
+                scale,
+                "fast" if flash_attention_fast_softmax else "None",
+            )
+
+    if self.multi_query:
+        # (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim)
+        sdpa_result = sdpa_result.transpose(1, 2)
+
+        # Reshape is kind of expensive here, as it does a memory copy,
+        # but I did not manage to make away without it (logits do not match when using view)
+        # (batch_size, seq_len, num_heads, head_dim) --> (batch_size, seq_len, num_heads * head_dim)
+        sdpa_result = sdpa_result.reshape(query_shape)
+
+    return sdpa_result, None
+
+
 def gaudi_gpt_bigcode_attention_forward(
     self,
     hidden_states: torch.Tensor,
@@ -20,14 +157,18 @@ def gaudi_gpt_bigcode_attention_forward(
     use_cache: Optional[bool] = False,
     output_attentions: Optional[bool] = False,
     token_idx: Optional[torch.Tensor] = None,
+    use_flash_attention: Optional[bool] = False,
+    flash_attention_recompute: Optional[bool] = False,
+    flash_attention_fast_softmax: Optional[bool] = False,
+    flash_attention_causal_mask: Optional[bool] = False,
 ) -> Union[
     Tuple[torch.Tensor, Optional[torch.Tensor]],
     Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
 ]:
     """
-    Copied from GPTBigCodeAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
+    Copied from GPTBigCodeAttention.forward: https://github.com/huggingface/transformers/blob/v4.40-release/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
     The only differences are:
-    - add new args token_idx
+    - add new args token_idx, use_flash_attention, flash_attention_recompute, flash_attention_fast_softmax, flash_attention_causal_mask
     - optimize KV cache
     """
     if encoder_hidden_states is not None:
@@ -65,7 +206,21 @@ def gaudi_gpt_bigcode_attention_forward(
             value = torch.cat((past_value, value), dim=-2)
     present = torch.cat((key, value), dim=-1) if use_cache else None
 
-    attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask)
+    if not output_attentions and head_mask is None and use_flash_attention:
+        # Difference with the original implementation: there is no need to transpose the key here,
+        # as SDPA expects seq_length to be at index -2 for the key as well
+        attn_output, attn_weights = apply_FusedSDPA(
+            self,
+            query,
+            key,
+            value,
+            attention_mask,
+            flash_attention_recompute,
+            flash_attention_fast_softmax,
+            flash_attention_causal_mask,
+        )
+    else:
+        attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask)
 
     if not self.multi_query:
         attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
@@ -93,11 +248,15 @@ def gaudi_gpt_bigcode_block_forward(
     use_cache: Optional[bool] = False,
     output_attentions: Optional[bool] = False,
     token_idx: Optional[torch.Tensor] = None,
+    use_flash_attention: Optional[bool] = False,
+    flash_attention_recompute: Optional[bool] = False,
+    flash_attention_fast_softmax: Optional[bool] = False,
+    flash_attention_causal_mask: Optional[bool] = False,
 ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
     """
-    Copied from GPTBigCodeBlock.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
+    Copied from GPTBigCodeBlock.forward: https://github.com/huggingface/transformers/blob/v4.40-release/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
     The only differences are:
-    - add new args token_idx
+    - add new args token_idx, use_flash_attention, flash_attention_recompute, flash_attention_fast_softmax, flash_attention_causal_mask
     """
     residual = hidden_states
     hidden_states = self.ln_1(hidden_states)
@@ -109,6 +268,10 @@
         use_cache=use_cache,
         output_attentions=output_attentions,
         token_idx=token_idx,
+        use_flash_attention=use_flash_attention,
+        flash_attention_recompute=flash_attention_recompute,
+        flash_attention_fast_softmax=flash_attention_fast_softmax,
+        flash_attention_causal_mask=flash_attention_causal_mask,
     )
     attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
     outputs = attn_outputs[1:]
@@ -167,13 +330,21 @@ def gaudi_gpt_bigcode_model_forward(
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
     token_idx: Optional[torch.Tensor] = None,
+    use_flash_attention: Optional[bool] = False,
+    flash_attention_recompute: Optional[bool] = False,
+    flash_attention_fast_softmax: Optional[bool] = False,
+    flash_attention_causal_mask: Optional[bool] = False,
 ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
     """
-    Copied from GPTBigCodeModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
+    Copied from GPTBigCodeModel.forward: https://github.com/huggingface/transformers/blob/v4.40-release/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
     The only differences are:
-    - add new args token_idx
+    - add new args token_idx, use_flash_attention, flash_attention_recompute, flash_attention_fast_softmax, flash_attention_causal_mask
     - if token_idx and past_key_values are passed, set self_attention_mask based on the static shape of past_key_values
     """
+
+    # This flag is used to reshape the tensors correctly for the attention kernel
+    self._use_sdpa = use_flash_attention
+
     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
     output_hidden_states = (
         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -322,6 +493,10 @@
                 use_cache=use_cache,
                 output_attentions=output_attentions,
                 token_idx=token_idx,
+                use_flash_attention=use_flash_attention,
+                flash_attention_recompute=flash_attention_recompute,
+                flash_attention_fast_softmax=flash_attention_fast_softmax,
+                flash_attention_causal_mask=flash_attention_causal_mask,
             )
 
         hidden_states = outputs[0]
@@ -358,10 +533,10 @@
 
 class GaudiGPTBigCodeForCausalLM(GPTBigCodeForCausalLM):
     """
-    Inherits from GPTBigCodeForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
+    Inherits from GPTBigCodeForCausalLM: https://github.com/huggingface/transformers/blob/v4.40-release/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
     The only differences are:
-    - add new args token_idx
-    - add token_idx into model_inputs
+    - add new args token_idx, use_flash_attention, flash_attention_recompute, flash_attention_fast_softmax, flash_attention_causal_mask
+    - add token_idx, use_flash_attention, flash_attention_recompute, flash_attention_fast_softmax, flash_attention_causal_mask into model_inputs
     - when KV cache is enabled, slice next_input_ids from input_ids based on the token_idx
     - when KV cache is enabled, slice next_position_ids from position_ids based on the token_idx
     """
@@ -422,6 +597,10 @@ def prepare_inputs_for_generation(
                 "attention_mask": attention_mask,
                 "token_type_ids": token_type_ids,
                 "token_idx": token_idx,
+                "use_flash_attention": kwargs.get("use_flash_attention", False),
+                "flash_attention_recompute": kwargs.get("flash_attention_recompute", False),
+                "flash_attention_fast_softmax": kwargs.get("flash_attention_fast_softmax", False),
+                "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask", False),
             }
         )
         return model_inputs
@@ -443,6 +622,10 @@ def forward(
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        token_idx: Optional[torch.Tensor] = None,
+        use_flash_attention: Optional[bool] = False,
+        flash_attention_recompute: Optional[bool] = False,
+        flash_attention_fast_softmax: Optional[bool] = False,
+        flash_attention_causal_mask: Optional[bool] = False,
     ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
         r"""
         labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -467,6 +650,10 @@
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             token_idx=token_idx,
+            use_flash_attention=use_flash_attention,
+            flash_attention_recompute=flash_attention_recompute,
+            flash_attention_fast_softmax=flash_attention_fast_softmax,
+            flash_attention_causal_mask=flash_attention_causal_mask,
         )
 
         hidden_states = transformer_outputs[0]
diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py
index 313434490b..4e116242f5 100644
--- a/tests/test_text_generation_example.py
+++ b/tests/test_text_generation_example.py
@@ -25,7 +25,7 @@
         ("EleutherAI/gpt-neox-20b", 1, False, 50.67672679310354),
         ("meta-llama/Llama-2-7b-hf", 1, True, 141.25776956002076),
         ("tiiuae/falcon-40b", 1, True, 25.202450111088346),
-        ("bigcode/starcoder", 1, False, 65.58632640700114),
+        ("bigcode/starcoder", 256, False, 4329.754794647058),
         ("Salesforce/codegen2-1B", 1, False, 446.4029486883532),
         ("mosaicml/mpt-30b", 1, False, 36.06464336116623),
         ("mistralai/Mistral-7B-v0.1", 1, True, 130.2172236767782),
@@ -142,6 +142,9 @@ def _test_text_generation(
     if "falcon" in model_name.lower() or "starcoder2" in model_name.lower():
         command += ["--use_flash_attention", "--flash_attention_causal_mask"]
 
+    if "starcoder" in model_name.lower() and "starcoder2" not in model_name.lower():
+        command += ["--use_flash_attention"]
+
     if "starcoder2" in model_name.lower():
["--flash_attention_recompute"]