gpt_bigcode: added FusedSDPA kernel (#1138)
mgonchar committed Jul 29, 2024
1 parent 2cd2e68 commit 59d182d
Showing 2 changed files with 201 additions and 11 deletions.
207 changes: 197 additions & 10 deletions optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -1,6 +1,8 @@
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
@@ -9,6 +11,141 @@
from ...modeling_attn_mask_utils import GaudiAttentionMaskConverter


try:
from habana_frameworks.torch.hpex.kernels import FusedSDPA
except ImportError:
print("Not using HPU fused scaled dot-product attention kernel.")
FusedSDPA = None

import habana_frameworks.torch.core as htcore


def gaudi_flash_attn_v1(
query_layer, key_layer, value_layer, attention_mask, dropout_rate, is_causal, scale, softmax_mode, q_block_size
):
"""
Gaudi version of Flash Attention V1 to support long sequences at the prompt phase.
Causal mask is not supported in this optimization.
"""
if is_causal:
raise ValueError("Causal mask is not supported for long input sequences")

q_len = query_layer.size(-2)
q_tiles = (q_len // q_block_size) if (q_len % q_block_size == 0) else math.ceil(q_len / q_block_size)
q_padding = q_tiles * q_block_size - q_len
query_layer = F.pad(query_layer, (0, 0, 0, q_padding), "constant", 0)
if attention_mask is not None:
attention_mask = F.pad(attention_mask, (0, 0, 0, q_padding), "constant", -10000.0)
row_o_list = []
for i in range(q_tiles):
s, e = i * q_block_size, (i + 1) * q_block_size
row_q = query_layer[:, :, s:e, :]
row_mask = attention_mask[:, :, s:e, :]
attn_output_partial = FusedSDPA.apply(
row_q, key_layer, value_layer, row_mask, dropout_rate, is_causal, scale, softmax_mode
)
row_o_list.append(attn_output_partial)
attn_output = torch.cat(row_o_list, dim=-2)
if q_padding != 0:
attn_output = attn_output[:, :, :-q_padding, :]
return attn_output
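
# Editorial sketch (not part of this commit): the tile/padding arithmetic above, traced
# for a hypothetical 10_000-token prompt with q_block_size = 4096, the value passed in
# apply_FusedSDPA below. Pure CPU math; it exercises only the block bookkeeping, not the
# HPU kernel itself.
sketch_q_len, sketch_block = 10_000, 4096
sketch_tiles = (sketch_q_len // sketch_block) if (sketch_q_len % sketch_block == 0) else math.ceil(sketch_q_len / sketch_block)
sketch_padding = sketch_tiles * sketch_block - sketch_q_len
assert (sketch_tiles, sketch_padding) == (3, 2288)  # three row blocks, the last one zero-padded
# The padded query rows are masked with -10000.0 and the matching output rows are
# sliced off before gaudi_flash_attn_v1 returns.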


def apply_FusedSDPA(
self,
query,
key,
value,
attention_mask=None,
flash_attention_recompute=False,
flash_attention_fast_softmax=False,
flash_attention_causal_mask=False,
):
"""
Copied from GPTBigCodeSdpaAttention._attn: https://github.com/huggingface/transformers/blob/v4.40-release/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
The only differences are:
- replaced torch.nn.functional.scaled_dot_product_attention with Habana's FusedSDPA
- removed the workaround of expanding the key and value tensors over the heads dimension; that workaround also works but dramatically drops throughput
- added args flash_attention_recompute, flash_attention_fast_softmax, flash_attention_causal_mask to control FusedSDPA parameters
- added special-case handling via gaudi_flash_attn_v1 for inputs longer than 8192 tokens
"""

scale = None
if not self.scale_attn_weights:
scale = 1

# MQA models: (batch_size, query_length, num_heads * head_dim)
# MHA models: (batch_size, num_heads, query_length, head_dim)
query_shape = query.shape
batch_size = query_shape[0]

if self.multi_query:
query_length = query_shape[1]

# SDPA requires the dimension [..., sequence_length, head_dim].
query = query.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)

# Without these unsqueezes, SDPA complains because the query and key/value have a different number of dimensions.
key = key.unsqueeze(1)
value = value.unsqueeze(1)

else:
query_length = query_shape[-1]

if attention_mask is not None:
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()

sdpa_result = None
enable_recompute = flash_attention_recompute and query_length > 1

if query_length > 1 and flash_attention_causal_mask:
attention_mask = None
use_causal_mask = True
else:
use_causal_mask = self.is_causal and attention_mask is None and query_length > 1

import habana_frameworks.torch.hpu as ht

with ht.sdp_kernel(enable_recompute=enable_recompute):
if query_length > 8192:
sdpa_result = gaudi_flash_attn_v1(
query,
key,
value,
attention_mask,
self.attn_pdrop if self.training else 0.0,
use_causal_mask,
scale,
"fast" if flash_attention_fast_softmax else "None",
4096,
)
htcore.mark_step()
else:
sdpa_result = FusedSDPA.apply(
query,
key,
value,
attention_mask,
self.attn_pdrop if self.training else 0.0,
use_causal_mask,
scale,
"fast" if flash_attention_fast_softmax else "None",
)

if self.multi_query:
# (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim)
sdpa_result = sdpa_result.transpose(1, 2)

# Reshape is kind of expensive here, as it does a memory copy,
# but I did not manage to do without it (logits do not match when using view)
# (batch_size, seq_len, num_heads, head_dim) --> (batch_size, seq_len, num_heads * head_dim)
sdpa_result = sdpa_result.reshape(query_shape)

return sdpa_result, None
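
# Editorial sketch (not part of this commit): the shape flow of the multi-query branch
# above, traced on CPU with plain matmul attention standing in for FusedSDPA. Only the
# tensor layouts matter here; the sizes (batch 2, 5 query tokens, 7 KV tokens, 4 heads,
# head_dim 8) are hypothetical.
def sketch_mqa_shape_flow():
    b, q_len, kv_len, heads, hd = 2, 5, 7, 4, 8
    query = torch.randn(b, q_len, heads * hd)  # MQA layout: heads folded into the last dim
    key = torch.randn(b, kv_len, hd)  # one shared key head
    value = torch.randn(b, kv_len, hd)  # one shared value head
    q = query.view(b, q_len, heads, hd).transpose(1, 2)  # -> (b, heads, q_len, hd)
    k = key.unsqueeze(1)  # -> (b, 1, kv_len, hd); broadcast over heads, no expand needed
    v = value.unsqueeze(1)
    scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(hd)  # (b, heads, q_len, kv_len)
    out = torch.matmul(scores.softmax(dim=-1), v)  # (b, heads, q_len, hd)
    out = out.transpose(1, 2).reshape(query.shape)  # back to (b, q_len, heads * hd)
    assert out.shape == query.shape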


def gaudi_gpt_bigcode_attention_forward(
self,
hidden_states: torch.Tensor,
@@ -20,14 +157,18 @@ def gaudi_gpt_bigcode_attention_forward(
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
token_idx: Optional[torch.Tensor] = None,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
) -> Union[
Tuple[torch.Tensor, Optional[torch.Tensor]],
Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
]:
"""
Copied from GPTBigCodeAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
Copied from GPTBigCodeAttention.forward: https://github.com/huggingface/transformers/blob/v4.40-release/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
The only differences are:
- add new args token_idx
- add new args token_idx, use_flash_attention, flash_attention_recompute, flash_attention_fast_softmax, flash_attention_causal_mask
- optimize KV cache
"""
if encoder_hidden_states is not None:
@@ -65,7 +206,21 @@ def gaudi_gpt_bigcode_attention_forward(
value = torch.cat((past_value, value), dim=-2)
present = torch.cat((key, value), dim=-1) if use_cache else None

attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask)
if not output_attentions and head_mask is None and use_flash_attention:
# Difference with the original implementation: there is no need to transpose the key here,
# as SDPA expects seq_length to be at index -2 for the key as well
attn_output, attn_weights = apply_FusedSDPA(
self,
query,
key,
value,
attention_mask,
flash_attention_recompute,
flash_attention_fast_softmax,
flash_attention_causal_mask,
)
else:
attn_output, attn_weights = self._attn(query, key.transpose(-1, -2), value, attention_mask, head_mask)

if not self.multi_query:
attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
@@ -93,11 +248,15 @@ def gaudi_gpt_bigcode_block_forward(
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
token_idx: Optional[torch.Tensor] = None,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""
Copied from GPTBigCodeBlock.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
Copied from GPTBigCodeBlock.forward: https://github.com/huggingface/transformers/blob/v4.40-release/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
The only differences are:
- add new args token_idx
- add new args token_idx, use_flash_attention, flash_attention_recompute, flash_attention_fast_softmax, flash_attention_causal_mask
"""
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
@@ -109,6 +268,10 @@ def gaudi_gpt_bigcode_block_forward(
use_cache=use_cache,
output_attentions=output_attentions,
token_idx=token_idx,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_fast_softmax=flash_attention_fast_softmax,
flash_attention_causal_mask=flash_attention_causal_mask,
)
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
outputs = attn_outputs[1:]
@@ -167,13 +330,21 @@ def gaudi_gpt_bigcode_model_forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
token_idx: Optional[torch.Tensor] = None,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
"""
Copied from GPTBigCodeModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
Copied from GPTBigCodeModel.forward: https://github.com/huggingface/transformers/blob/v4.40-release/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
The only differences are:
- add new args token_idx
- add new args token_idx, use_flash_attention, flash_attention_recompute, flash_attention_fast_softmax, flash_attention_causal_mask
- if token_idx and past_key_values are passed, set self_attention_mask based on the static shape of past_key_values
"""

# This flag is used so that tensors are reshaped correctly for the attention kernel
self._use_sdpa = use_flash_attention

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -322,6 +493,10 @@ def gaudi_gpt_bigcode_model_forward(
use_cache=use_cache,
output_attentions=output_attentions,
token_idx=token_idx,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_fast_softmax=flash_attention_fast_softmax,
flash_attention_causal_mask=flash_attention_causal_mask,
)

hidden_states = outputs[0]
@@ -358,10 +533,10 @@ def gaudi_gpt_bigcode_model_forward(

class GaudiGPTBigCodeForCausalLM(GPTBigCodeForCausalLM):
"""
Inherits from GPTBigCodeForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
Inherits from GPTBigCodeForCausalLM: https://github.com/huggingface/transformers/blob/v4.40-release/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
The only differences are:
- add new args token_idx
- add token_idx into model_inputs
- add new args token_idx, use_flash_attention, flash_attention_recompute, flash_attention_fast_softmax, flash_attention_causal_mask
- add token_idx, use_flash_attention, flash_attention_recompute, flash_attention_fast_softmax, flash_attention_causal_mask into model_inputs
- when KV cache is enabled, slice next_input_ids from input_ids based on the token_idx
- when KV cache is enabled, slice next_position_ids from position_ids based on the token_idx
"""
@@ -422,6 +597,10 @@ def prepare_inputs_for_generation(
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
"token_idx": token_idx,
"use_flash_attention": kwargs.get("use_flash_attention", False),
"flash_attention_recompute": kwargs.get("flash_attention_recompute", False),
"flash_attention_fast_softmax": kwargs.get("flash_attention_fast_softmax", False),
"flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask", False),
}
)
return model_inputs
@@ -443,6 +622,10 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
token_idx: Optional[torch.Tensor] = None,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -467,6 +650,10 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
token_idx=token_idx,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_fast_softmax=flash_attention_fast_softmax,
flash_attention_causal_mask=flash_attention_causal_mask,
)
hidden_states = transformer_outputs[0]

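For context, a hedged usage sketch (not part of this diff): once the Gaudi-patched GPT-BigCode classes are in place, the four new flags are ordinary generate() kwargs, picked up by prepare_inputs_for_generation above and forwarded down to the attention layers. The model name, prompt, dtype, and the adapt_transformers_to_gaudi() entry point are assumptions of this sketch and presume a working Gaudi/SynapseAI environment.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

adapt_transformers_to_gaudi()  # swap in the Gaudi GPT-BigCode forward functions

tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder")
model = AutoModelForCausalLM.from_pretrained("bigcode/starcoder", torch_dtype=torch.bfloat16).to("hpu")

inputs = tokenizer("def fibonacci(n):", return_tensors="pt").to("hpu")
outputs = model.generate(
    **inputs,
    max_new_tokens=32,
    use_flash_attention=True,            # route attention through FusedSDPA
    flash_attention_fast_softmax=True,   # "fast" softmax mode instead of "None"
    flash_attention_recompute=False,
    flash_attention_causal_mask=False,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

The updated test below exercises the same path through the example script's --use_flash_attention flag.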
5 changes: 4 additions & 1 deletion tests/test_text_generation_example.py
@@ -25,7 +25,7 @@
("EleutherAI/gpt-neox-20b", 1, False, 50.67672679310354),
("meta-llama/Llama-2-7b-hf", 1, True, 141.25776956002076),
("tiiuae/falcon-40b", 1, True, 25.202450111088346),
("bigcode/starcoder", 1, False, 65.58632640700114),
("bigcode/starcoder", 256, False, 4329.754794647058),
("Salesforce/codegen2-1B", 1, False, 446.4029486883532),
("mosaicml/mpt-30b", 1, False, 36.06464336116623),
("mistralai/Mistral-7B-v0.1", 1, True, 130.2172236767782),
@@ -142,6 +142,9 @@ def _test_text_generation(
if "falcon" in model_name.lower() or "starcoder2" in model_name.lower():
command += ["--use_flash_attention", "--flash_attention_causal_mask"]

if "starcoder" in model_name.lower() and "starcoder2" not in model_name.lower():
command += ["--use_flash_attention"]

if "starcoder2" in model_name.lower():
command += ["--flash_attention_recompute"]

