Commit ac35df0

make style. some fixes

imangohari1 committed Sep 11, 2024
1 parent e570b02 commit ac35df0

Showing 6 changed files with 40 additions and 24 deletions.
8 changes: 5 additions & 3 deletions optimum/habana/transformers/modeling_utils.py
@@ -45,11 +45,11 @@
GaudiFalconForCausalLM,
GaudiFalconMLP,
GaudiFalconModel,
GaudiGemmaMLP,
GaudiGemmaAttention,
GaudiGemmaDecoderLayer,
GaudiGemmaModel,
GaudiGemmaForCausalLM,
GaudiGemmaMLP,
GaudiGemmaModel,
GaudiGPT2Attention,
GaudiGPT2Block,
GaudiGPT2DoubleHeadsModel,
@@ -361,7 +361,9 @@ def adapt_transformers_to_gaudi():
transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM = GaudiGPTBigCodeForCausalLM
transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeBlock.forward = gaudi_gpt_bigcode_block_forward
transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeModel.forward = gaudi_gpt_bigcode_model_forward
transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBIGCODE_ATTENTION_CLASSES.update({"eager": GaudiGPTBigCodeAttention})
transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBIGCODE_ATTENTION_CLASSES.update(
{"eager": GaudiGPTBigCodeAttention}
)

# Optimization for gpt-neox generation on Gaudi
transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXForCausalLM = GaudiGPTNeoXForCausalLM
4 changes: 2 additions & 2 deletions optimum/habana/transformers/models/__init__.py
@@ -65,11 +65,11 @@
gaudi_falcon_linear_forward,
)
from .gemma import (
GaudiGemmaAttention,
GaudiGemmaDecoderLayer,
GaudiGemmaForCausalLM,
GaudiGemmaAttention,
GaudiGemmaModel,
GaudiGemmaMLP,
GaudiGemmaModel,
)
from .gpt2 import (
GaudiGPT2Attention,
4 changes: 2 additions & 2 deletions optimum/habana/transformers/models/gemma/__init__.py
@@ -1,7 +1,7 @@
from .modeling_gemma import (
GaudiGemmaAttention,
GaudiGemmaDecoderLayer,
GaudiGemmaForCausalLM,
GaudiGemmaAttention,
GaudiGemmaMLP,
GaudiGemmaModel,
GaudiGemmaMLP
)
26 changes: 16 additions & 10 deletions optimum/habana/transformers/models/gemma/modeling_gemma.py
@@ -19,30 +19,30 @@
# limitations under the License.
"""PyTorch Gemma model."""

import math
from typing import List, Optional, Tuple, Union

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
import math
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.gemma.modeling_gemma import (
GemmaAttention,
GemmaConfig,
GemmaMLP,
GemmaDecoderLayer,
GemmaForCausalLM,
GemmaMLP,
GemmaModel,
apply_rotary_pos_emb,
repeat_kv,
)
from transformers.utils import logging

from ...modeling_attn_mask_utils import (
_gaudi_prepare_4d_causal_attention_mask,
)


try:
from habana_frameworks.torch.hpex.kernels import FusedSDPA
except ImportError:
@@ -51,8 +51,10 @@

import habana_frameworks.torch.core as htcore


logger = logging.get_logger(__name__)


def gaudi_gemma_repeat_kv(
query_states: torch.Tensor,
key_states: torch.Tensor,
@@ -78,6 +80,7 @@ def gaudi_gemma_repeat_kv(

return query_states, key_states, value_states, attention_mask


class Matmul(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -126,6 +129,7 @@ def get_shape(self):
def forward(self, cur, dim, idx):
return self.update(self.cache, cur, dim, idx, self.inp_seq_len)


class GaudiGemmaAttention(GemmaAttention):
def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None):
super().__init__(config, layer_idx)
@@ -166,7 +170,7 @@ def reorder_kv_cache(self, beam_idx: torch.LongTensor):
self.reorder(self.k_cache.cache, beam_idx, seq_length, head_dim)
self.reorder(self.v_cache.cache, beam_idx, seq_length, head_dim)
return (self.k_cache.cache.shape, self.v_cache.cache.shape)

def gaudi_flash_attn_v1(self, query_layer, key_layer, value_layer, attention_mask, dropout_rate, q_block_size):
"""
Gaudi version of Flash Attention V1 to support long sequence at prompt phase
@@ -192,7 +196,7 @@ def gaudi_flash_attn_v1(self, query_layer, key_layer, value_layer, attention_mask, dropout_rate, q_block_size):
attn_output = attn_output[:, :, :-q_padding, :]

return attn_output

def pre_attn_forward(
self,
hidden_states: torch.Tensor,
@@ -221,7 +225,7 @@ def pre_attn_forward(
- add new arg flash_attention_recompute
"""
if "padding_mask" in kwargs:
warnings.warn(
logger.warning_once(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)

@@ -354,6 +358,7 @@ def post_attn_forward(self, attn_output):
self.o_proj.post_all_reduce(attn_output)
return attn_output


class GaudiGemmaMLP(GemmaMLP):
def pre_mlp_forward(self, x):
inputs = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
@@ -368,7 +373,8 @@ def post_mlp_forward(self, x):
if hasattr(self.down_proj, "post_all_reduce"):
return self.down_proj.post_all_reduce(x)
return x



class GaudiGemmaDecoderLayer(GemmaDecoderLayer):
def __init__(self, config: GemmaConfig, layer_idx: int):
super().__init__(config, layer_idx)
@@ -478,7 +484,7 @@ def forward(
outputs += (present_key_value,)

return outputs

def post_attn_pre_mlp(self, hidden_states, residual):
hidden_states = self.self_attn.post_attn_forward(hidden_states)

@@ -586,7 +592,7 @@ def forward(
past_seen_tokens = past_key_values.get_usable_length(seq_length)
else:
past_seen_tokens = past_key_values[0][0].shape[2]

cache_position = None

if position_ids is None:
17 changes: 11 additions & 6 deletions optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -27,7 +27,10 @@ def __init__(self, fusedSDPA):
self._hpu_kernel_fsdpa = fusedSDPA

def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode, enable_recompute):
return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode, enable_recompute)
return self._hpu_kernel_fsdpa.apply(
query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode, enable_recompute
)


class GaudiGPTBigCodeAttention(GPTBigCodeAttention):
def __init__(self, config, is_cross_attention=False, layer_idx=None):
@@ -68,15 +71,14 @@ def gaudi_flash_attn_v1(
row_q = query_layer[:, :, s:e, :]
row_mask = attention_mask[:, :, s:e, :]
attn_output_partial = self.fused_scaled_dot_product_attention(
row_q, key_layer, value_layer, row_mask, dropout_rate, is_causal, scale, softmax_mode, enable_recompute
row_q, key_layer, value_layer, row_mask, dropout_rate, is_causal, scale, softmax_mode, enable_recompute
)
row_o_list.append(attn_output_partial)
attn_output = torch.cat(row_o_list, dim=-2)
if q_padding != 0:
attn_output = attn_output[:, :, :-q_padding, :]
return attn_output


def apply_FusedSDPA(
self,
query,
@@ -170,7 +172,6 @@ def apply_FusedSDPA(

return sdpa_result, None


def forward(
self,
hidden_states: torch.Tensor,
@@ -197,7 +198,9 @@ def forward(
- optimize KV cache
"""
if use_flash_attention:
assert self.fused_scaled_dot_product_attention is not None, "Can't load HPU fused scaled dot-product attention kernel. Please retry without flash attention"
assert (
self.fused_scaled_dot_product_attention is not None
), "Can't load HPU fused scaled dot-product attention kernel. Please retry without flash attention"

if encoder_hidden_states is not None:
if not hasattr(self, "q_attn") or not self.is_cross_attention:
@@ -229,7 +232,9 @@ def forward(
if token_idx is not None:
# Using out of place version of index_add_() to ensure the intermediate tensors are not lost when HPU graphs are enabled.
key = past_key.index_add(1, token_idx - 1, key - torch.index_select(past_key, 1, token_idx - 1))
value = past_value.index_add(1, token_idx - 1, value - torch.index_select(past_value, 1, token_idx - 1))
value = past_value.index_add(
1, token_idx - 1, value - torch.index_select(past_value, 1, token_idx - 1)
)
else:
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
5 changes: 4 additions & 1 deletion tests/test_text_generation_example.py
@@ -42,7 +42,7 @@
("google/gemma-7b", 1, False, 109.70751574382221),
("state-spaces/mamba-130m-hf", 1536, False, 5385.511100161605),
("Deci/DeciLM-7B", 1, False, 120),
("Qwen/Qwen1.5-7B", 4, False, 488.82855464593257),
("Qwen/Qwen2-7B", 512, False, 9669.45787),
("Qwen/Qwen1.5-MoE-A2.7B", 1, True, 40),
],
"fp8": [
@@ -169,6 +169,9 @@ def _test_text_generation(
if "starcoder" in model_name.lower() and "starcoder2" not in model_name.lower():
command += ["--use_flash_attention"]

if "gemma" in model_name.lower():
command += ["--use_flash_attention"]

if "starcoder2" in model_name.lower():
command += ["--flash_attention_recompute"]
