diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py
index 99ef65c4e4..88886dbd8a 100644
--- a/optimum/habana/transformers/models/__init__.py
+++ b/optimum/habana/transformers/models/__init__.py
@@ -136,6 +136,9 @@
     gaudi_conv1d_forward,
     gaudi_get_extended_attention_mask,
     gaudi_invert_attention_mask,
+    Matmul,
+    KVCache,
+    apply_customized_rope
 )
 from .mpt import (
     GaudiMptForCausalLM,
diff --git a/optimum/habana/transformers/models/clip/modeling_clip.py b/optimum/habana/transformers/models/clip/modeling_clip.py
index b22c61972d..d6be81b809 100644
--- a/optimum/habana/transformers/models/clip/modeling_clip.py
+++ b/optimum/habana/transformers/models/clip/modeling_clip.py
@@ -47,14 +47,6 @@ def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode):
         return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode)
 
 
-class Matmul(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, x, y):
-        return torch.matmul(x, y)
-
-
 class Softmax(nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/optimum/habana/transformers/models/falcon/modeling_falcon.py b/optimum/habana/transformers/models/falcon/modeling_falcon.py
index a7a0c0e920..0ea46ad89e 100644
--- a/optimum/habana/transformers/models/falcon/modeling_falcon.py
+++ b/optimum/habana/transformers/models/falcon/modeling_falcon.py
@@ -19,13 +19,6 @@
 except ImportError:
     SDPContext = False
 
-try:
-    from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
-except ImportError:
-    print("Not using HPU fused kernel for apply_rotary_pos_emb")
-    FusedRoPE = None
-
-
 import habana_frameworks.torch.core as htcore
 from torch import nn
 from torch.nn import CrossEntropyLoss
@@ -69,19 +62,6 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training:
     residual.add_(out)
     return residual
 
-
-def apply_customized_rope(q, k, cos, sin, position_ids):
-    if q.device.type == "hpu" and FusedRoPE:
-        # TODO: remove `.clone()` when it is fixed in SynapseAI
-        return FusedRoPE.apply(
-            q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
-        ), FusedRoPE.apply(
-            k, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
-        )
-    else:
-        return apply_rotary_pos_emb(q, k, cos, sin, position_ids)
-
-
 def gaudi_falcon_linear_forward(self, input: torch.Tensor) -> torch.Tensor:
     hidden_states = F.linear(input, self.weight, bias=self.bias)
     return hidden_states
@@ -105,14 +85,6 @@ def forward(self, x, dim=None, invAttnHead=None):
         return torch.ops.hpu.softmax_fp8(x, dim, None, None, invAttnHead)
 
 
-class Matmul(nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, *args, **kwargs):
-        return torch.matmul(*args, **kwargs)
-
-
 # ScaledDotProductAttention is based on torch.nn.functional.scaled_dot_product_attention
 class ScaledDotProductAttention(nn.Module):
     def __init__(self, config: FalconConfig):
@@ -183,56 +155,6 @@ def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=Fa
         return attn_output
 
 
-def update(prev, cur, dim, idx, inp_seq_len):
-    orig_cur = cur
-    cur = cur.to(dtype=prev.dtype)
-
-    if prev.shape == cur.shape:
-        prev.copy_(cur)
-        return orig_cur
-
-    if cur.shape[-2] > 1 and cur.shape[-2] <= prev.shape[-2]:
-        # Initialize
-        prev[:, :, :inp_seq_len, :].copy_(cur)
-        return orig_cur
-    assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}"
-    if idx is not None:
-        prev.index_copy_(dim, idx - 1, cur)
-        prev_cast = prev.to(orig_cur.dtype)
-        return prev_cast
-    else:
-        return torch.cat((prev, cur), dim=dim)
-
-
-class KVCache(torch.nn.Module):
-    def __init__(self):
-        super(KVCache, self).__init__()
-        self.cache = None
-        self.inp_seq_len = -1
-
-    def allocate(self, inp_seq_len, dtype, device, shape):
-        if self.cache is None or self.cache.shape != shape:
-            self.inp_seq_len = inp_seq_len
-            self.cache = torch.zeros(shape, dtype=dtype, device=device)
-        else:
-            assert (
-                self.inp_seq_len == inp_seq_len
-            ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
-            self.cache.fill_(0)
-
-    def get_shape(self):
-        if self.cache is None:
-            return None
-        return self.cache.shape
-
-    def forward(self, cur, dim, idx):
-        return self.update(self.cache, cur, dim, idx, self.inp_seq_len)
-
-    @staticmethod
-    def update(prev, cur, dim, idx, inp_seq_len):
-        return update(prev, cur, dim, idx, inp_seq_len)
-
-
 class GaudiFalconAttention(FalconAttention):
     """
     Inherits from FalconAttention: https://github.com/huggingface/transformers/blob/838b87abe231fd70be5132088d0dee72a7bb8d62/src/transformers/models/falcon/modeling_falcon.py#L267
@@ -374,7 +296,7 @@ def pre_attn_forward(
 
         if alibi is None:
             cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len)
-            query_layer, key_layer = apply_customized_rope(query_layer, key_layer, cos, sin, position_ids)
+            query_layer, key_layer = apply_customized_rope(query_layer, key_layer, cos, sin, position_ids, self.training)
 
             if use_cache:
                 if self.training:
diff --git a/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py b/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py
index aa6423d2b1..681ec6dcb2 100644
--- a/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py
+++ b/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -15,14 +15,6 @@
 
 from ...modeling_attn_mask_utils import _gaudi_prepare_4d_causal_attention_mask
 
-
-try:
-    from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
-except ImportError:
-    print("Not using HPU fused kernel for apply_rotary_pos_emb")
-    FusedRoPE = None
-
-
 def gaudi_gpt_neox_attention_forward(
     self,
     hidden_states: torch.FloatTensor,
@@ -434,32 +426,3 @@ def gaudi_gpt_neox_rotary_embedding_set_cos_sin_cache(self, seq_len, device, dty
     emb = torch.cat((freqs, freqs), dim=-1)
     self.cos_cached = emb.cos()
     self.sin_cached = emb.sin()
-
-
-def apply_customized_rope(q, k, cos, sin, position_ids, training=True):
-    if q.device.type == "hpu" and FusedRoPE:
-        if training:
-            rope_q = FusedRoPE.apply(q, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids)
-            rope_k = FusedRoPE.apply(k, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids)
-        else:
-            if q.dtype == torch.bfloat16:
-                rope_q = FusedRoPE.apply(
-                    q,
-                    cos.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
-                    sin.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
-                    position_ids,
-                )
-            else:
-                rope_q = FusedRoPE.apply(q, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids)
-            if k.dtype == torch.bfloat16:
-                rope_k = FusedRoPE.apply(
-                    k,
-                    cos.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
-                    sin.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
-                    position_ids,
-                )
-            else:
-                rope_k = FusedRoPE.apply(k, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids)
-        return rope_q, rope_k
-    else:
-        return apply_rotary_pos_emb(q, k, cos, sin, position_ids)
diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py
index d25e71b16b..0b13132bc1 100755
--- a/optimum/habana/transformers/models/llama/modeling_llama.py
+++ b/optimum/habana/transformers/models/llama/modeling_llama.py
@@ -35,7 +35,6 @@
 
 try:
     from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
-
     has_fused_rope = True
 except ImportError:
     has_fused_rope = False
@@ -348,56 +347,6 @@ def __init__(self, fusedSDPA):
     def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode):
         return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode)
 
-
-class Matmul(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, x, y):
-        return torch.matmul(x, y)
-
-
-class KVCache(torch.nn.Module):
-    def __init__(self):
-        super(KVCache, self).__init__()
-        self.cache = None
-        self.inp_seq_len = -1
-
-    def allocate(self, inp_seq_len, dtype, device, shape):
-        if self.cache is None or self.cache.shape != shape:
-            self.inp_seq_len = inp_seq_len
-            self.cache = torch.zeros(shape, dtype=dtype, device=device)
-        else:
-            assert (
-                self.inp_seq_len == inp_seq_len
-            ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
-            self.cache.fill_(0)
-
-    @staticmethod
-    def update(prev, cur, dim, idx, inp_seq_len):
-        orig_cur = cur
-        if prev.shape == cur.shape:
-            prev.copy_(cur)
-            return orig_cur
-        if idx is not None and cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
-            # Initialize
-            prev[:, :, :inp_seq_len, :].copy_(cur)
-            return orig_cur
-        if idx is not None:
-            prev.index_copy_(dim, idx - 1, cur)
-            return prev
-        else:
-            return torch.cat((prev, cur), dim=dim)
-
-    def get_shape(self):
-        if self.cache is None:
-            return None
-        return self.cache.shape
-
-    def forward(self, cur, dim, idx):
-        return self.update(self.cache, cur, dim, idx, self.inp_seq_len)
-
-
 class GaudiLlamaAttention(LlamaAttention):
     def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
         super().__init__(config, layer_idx)
@@ -1406,25 +1355,3 @@ def _reorder_cache(past_key_values, beam_idx):
                 tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
             )
         return reordered_past
-
-
-def apply_customized_rope(q, k, cos, sin, position_ids):
-    if q.device.type == "hpu" and has_fused_rope:
-        # TODO: remove `.clone()` when it is fixed in SynapseAI
-        if k.dtype == torch.bfloat16:
-            return FusedRoPE.apply(
-                q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
-            ), FusedRoPE.apply(
-                k,
-                cos.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16),
-                sin.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16),
-                position_ids,
-            )
-        return FusedRoPE.apply(
-            q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
-        ), FusedRoPE.apply(
-            k, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
-        )
-    else:
-        # keep the same implementation as Transformers v4.37.2
-        return apply_rotary_pos_emb(q, k, cos[position_ids], sin[position_ids])
diff --git a/optimum/habana/transformers/models/mistral/modeling_mistral.py b/optimum/habana/transformers/models/mistral/modeling_mistral.py
index 7d95e548ce..c46fec7cc9 100644
--- a/optimum/habana/transformers/models/mistral/modeling_mistral.py
+++ b/optimum/habana/transformers/models/mistral/modeling_mistral.py
@@ -53,7 +53,6 @@
 
 try:
     from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
-
     has_fused_rope = True
 except ImportError:
     has_fused_rope = False
@@ -73,48 +72,6 @@
 
 logger = logging.get_logger(__name__)
 
-
-class KVCache(torch.nn.Module):
-    def __init__(self):
-        super(KVCache, self).__init__()
-        self.cache = None
-        self.inp_seq_len = -1
-
-    def allocate(self, inp_seq_len, dtype, device, shape):
-        if self.cache is None or self.cache.shape != shape:
-            self.inp_seq_len = inp_seq_len
-            self.cache = torch.zeros(shape, dtype=dtype, device=device)
-        else:
-            assert (
-                self.inp_seq_len == inp_seq_len
-            ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
-            self.cache.fill_(0)
-
-    def update(self, prev, cur, dim, idx, inp_seq_len):
-        orig_cur = cur
-        if prev.shape == cur.shape:
-            prev.copy_(cur)
-            return orig_cur
-        if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
-            # Initialize
-            prev[:, :, :inp_seq_len, :].copy_(cur)
-            return orig_cur
-        assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}"
-        if idx is not None:
-            prev.index_copy_(dim, idx - 1, cur)
-            return prev
-        else:
-            return torch.cat((prev, cur), dim=dim)
-
-    def get_shape(self):
-        if self.cache is None:
-            return None
-        return self.cache.shape
-
-    def forward(self, cur, dim, idx):
-        return self.update(self.cache, cur, dim, idx, self.inp_seq_len)
-
-
 class ModuleFusedSDPA(torch.nn.Module):
     def __init__(self, fusedSDPA):
         super().__init__()
@@ -124,12 +81,32 @@ def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale):
         return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale)
 
 
-class Matmul(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
+# Copy from GaudiMixtralAttentionLongSequence
+class GaudiMistralAttentionLongSequence:
+    @staticmethod
+    def forward(q, k, v, mask, causal, q_block_size):
+        """
+        Support long sequence at prompt phase
+        """
+        q_len = q.size(-2)
+        q_tiles = (q_len // q_block_size) if (q_len % q_block_size == 0) else math.ceil(q_len / q_block_size)
+        q_padding = q_tiles * q_block_size - q_len
+        q = F.pad(q, (0, 0, 0, q_padding), "constant", 0)
+        if mask is not None:
+            mask = F.pad(mask, (0, 0, 0, q_padding), "constant", -10000.0)
+        attn_output = torch.zeros_like(q)
+
+        for i in range(q_tiles):
+            s, e = i * q_block_size, (i + 1) * q_block_size
+            row_q = q[:, :, s:e, :]
+            row_mask = mask[:, :, s:e, :]
+            row_o = attn_output[:, :, s:e, :]
+            row_o.fill_(FusedSDPA.apply(row_q, k, v, row_mask, 0.0, causal, None))
+
+        if q_padding != 0:
+            attn_output = attn_output[:, :, :-q_padding, :]
 
-    def forward(self, x, y):
-        return torch.matmul(x, y)
+        return attn_output
 
 
 def gaudi_mistral_repeat_kv(
@@ -315,7 +292,7 @@ def forward(
             else:
                 kv_seq_len += kv_shape
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids)
+        query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids, self.training)
 
         if use_cache:
             # reuse k, v, self_attention
@@ -852,23 +829,3 @@ def prepare_inputs_for_generation(
         )
         return model_inputs
 
-
-def apply_customized_rope(q, k, cos, sin, position_ids):
-    if q.device.type == "hpu" and has_fused_rope:
-        # TODO: remove `.clone()` when SynapseAI v1.15 is released
-        if k.dtype == torch.bfloat16:
-            return FusedRoPE.apply(
-                q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
-            ), FusedRoPE.apply(
-                k,
-                cos.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16),
-                sin.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16),
-                position_ids,
-            )
-        return FusedRoPE.apply(
-            q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
-        ), FusedRoPE.apply(
-            k, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
-        )
-    else:
-        return apply_rotary_pos_emb(q, k, cos, sin)
diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py
index b6c750fa00..1215edb456 100644
--- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py
+++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py
@@ -53,13 +53,6 @@
 )
 from .configuration_mixtral import MixtralConfig
 
-
-try:
-    from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
-except ImportError:
-    print("Not using HPU fused kernel for apply_rotary_pos_emb")
-    FusedRoPE = None
-
 try:
     from habana_frameworks.torch.hpex.normalization import FusedRMSNorm
 except ImportError:
@@ -82,15 +75,6 @@
 logger = logging.get_logger(__name__)
 
 
-def apply_customized_rope(q, k, cos, sin, position_ids):
-    if q.device.type == "hpu" and FusedRoPE:
-        return FusedRoPE.apply(
-            q, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids
-        ), FusedRoPE.apply(k, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids)
-    else:
-        return apply_rotary_pos_emb(q, k, cos, sin, position_ids)
-
-
 def gaudi_mixtral_rmsnorm_forward(self, hidden_states):
     """
     Copied from MixtralRMSNorm.forward: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py
@@ -148,47 +132,6 @@
     return query_states, key_states, value_states, attention_mask
 
 
-class KVCache(torch.nn.Module):
-    def __init__(self):
-        super(KVCache, self).__init__()
-        self.cache = None
-        self.inp_seq_len = -1
-
-    def allocate(self, inp_seq_len, dtype, device, shape):
-        if self.cache is None or self.cache.shape != shape:
-            self.inp_seq_len = inp_seq_len
-            self.cache = torch.zeros(shape, dtype=dtype, device=device)
-        else:
-            assert (
-                self.inp_seq_len == inp_seq_len
-            ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
-            self.cache.fill_(0)
-
-    def update(self, prev, cur, dim, idx, inp_seq_len):
-        orig_cur = cur
-        if prev.shape == cur.shape:
-            prev.copy_(cur)
-            return orig_cur
-        if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
-            # Initialize
-            prev[:, :, :inp_seq_len, :].copy_(cur)
-            return orig_cur
-        assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}"
-        if idx is not None:
-            prev.index_copy_(dim, idx - 1, cur)
-            return prev
-        else:
-            return torch.cat((prev, cur), dim=dim)
-
-    def get_shape(self):
-        if self.cache is None:
-            return None
-        return self.cache.shape
-
-    def forward(self, cur, dim, idx):
-        return self.update(self.cache, cur, dim, idx, self.inp_seq_len)
-
-
 class GaudiMixtralAttentionLongSequence:
     @staticmethod
     def forward(q, k, v, mask, causal, q_block_size):
@@ -317,7 +260,7 @@ def forward(
             else:
                 kv_seq_len = past_key_value[0].shape[-2]
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids)
+        query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids, self.training)
 
         if use_cache:
             if reuse_cache:
diff --git a/optimum/habana/transformers/models/modeling_all_models.py b/optimum/habana/transformers/models/modeling_all_models.py
index 808e0f012a..8ae01fd73e 100644
--- a/optimum/habana/transformers/models/modeling_all_models.py
+++ b/optimum/habana/transformers/models/modeling_all_models.py
@@ -21,6 +21,89 @@
 from transformers.modeling_utils import ModuleUtilsMixin, PretrainedConfig
 from transformers.utils.import_utils import is_torch_sdpa_available
 
+try:
+    from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
+except ImportError:
+    print("Not using HPU fused kernel for apply_rotary_pos_emb")
+    FusedRoPE = None
+
+
+class Matmul(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, *args, **kwargs):
+        return torch.matmul(*args, **kwargs)
+
+
+class KVCache(torch.nn.Module):
+    def __init__(self):
+        super(KVCache, self).__init__()
+        self.cache = None
+        self.inp_seq_len = -1
+
+    def allocate(self, inp_seq_len, dtype, device, shape):
+        if self.cache is None or self.cache.shape != shape:
+            self.inp_seq_len = inp_seq_len
+            self.cache = torch.zeros(shape, dtype=dtype, device=device)
+        else:
+            assert (
+                self.inp_seq_len == inp_seq_len
+            ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
+            self.cache.fill_(0)
+
+    @staticmethod
+    def update(prev, cur, dim, idx, inp_seq_len):
+        orig_cur = cur
+        if prev.shape == cur.shape:
+            prev.copy_(cur)
+            return orig_cur
+        if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
+            # Initialize
+            prev[:, :, :inp_seq_len, :].copy_(cur)
+            return orig_cur
+        assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}"
+        if idx is not None:
+            prev.index_copy_(dim, idx - 1, cur)
+            return prev
+        else:
+            return torch.cat((prev, cur), dim=dim)
+
+    def get_shape(self):
+        if self.cache is None:
+            return None
+        return self.cache.shape
+
+    def forward(self, cur, dim, idx):
+        return self.update(self.cache, cur, dim, idx, self.inp_seq_len)
+
+def apply_customized_rope(q, k, cos, sin, position_ids, training=True):
+    if q.device.type == "hpu" and FusedRoPE:
+        if training:
+            rope_q = FusedRoPE.apply(q, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids)
+            rope_k = FusedRoPE.apply(k, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids)
+        else:
+            if q.dtype == torch.bfloat16:
+                rope_q = FusedRoPE.apply(
+                    q,
+                    cos.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
+                    sin.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
+                    position_ids,
+                )
+            else:
+                rope_q = FusedRoPE.apply(q, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids)
+            if k.dtype == torch.bfloat16:
+                rope_k = FusedRoPE.apply(
+                    k,
+                    cos.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
+                    sin.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
+                    position_ids,
+                )
+            else:
+                rope_k = FusedRoPE.apply(k, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids)
+        return rope_q, rope_k
+    else:
+        return apply_rotary_pos_emb(q, k, cos, sin, position_ids)
 
 def gaudi_invert_attention_mask(self, encoder_attention_mask: torch.Tensor) -> torch.Tensor:
     """
diff --git a/optimum/habana/transformers/models/phi/modeling_phi.py b/optimum/habana/transformers/models/phi/modeling_phi.py
index 07f4d0cd71..43917d9fd7 100644
--- a/optimum/habana/transformers/models/phi/modeling_phi.py
+++ b/optimum/habana/transformers/models/phi/modeling_phi.py
@@ -79,55 +79,6 @@ def gaudi_phi_repeat_kv(
     return query_states, key_states, value_states, attention_mask
 
 
-class Matmul(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, x, y):
-        return torch.matmul(x, y)
-
-
-class KVCache(torch.nn.Module):
-    def __init__(self):
-        super(KVCache, self).__init__()
-        self.cache = None
-        self.inp_seq_len = -1
-
-    def allocate(self, inp_seq_len, dtype, device, shape):
-        if self.cache is None or self.cache.shape != shape:
-            self.inp_seq_len = inp_seq_len
-            self.cache = torch.zeros(shape, dtype=dtype, device=device)
-        else:
-            assert (
-                self.inp_seq_len == inp_seq_len
-            ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
-            self.cache.fill_(0)
-
-    def update(self, prev, cur, dim, idx, inp_seq_len):
-        orig_cur = cur
-        if prev.shape == cur.shape:
-            prev.copy_(cur)
-            return orig_cur
-        if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
-            # Initialize
-            prev[:, :, :inp_seq_len, :].copy_(cur)
-            return orig_cur
-        assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}"
-        if idx is not None:
-            prev.index_copy_(dim, idx - 1, cur)
-            return prev
-        else:
-            return torch.cat((prev, cur), dim=dim)
-
-    def get_shape(self):
-        if self.cache is None:
-            return None
-        return self.cache.shape
-
-    def forward(self, cur, dim, idx):
-        return self.update(self.cache, cur, dim, idx, self.inp_seq_len)
-
-
 class GaudiPhiAttention(PhiAttention):
     def __init__(self, config: PhiConfig, layer_idx: Optional[int] = None):
         super().__init__(config, layer_idx)
diff --git a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py
index 0c8970dd88..a587f63875 100644
--- a/optimum/habana/transformers/models/qwen2/modeling_qwen2.py
+++ b/optimum/habana/transformers/models/qwen2/modeling_qwen2.py
@@ -40,13 +40,6 @@
     _gaudi_prepare_4d_causal_attention_mask,
 )
 
-
-try:
-    from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
-except ImportError:
-    print("Not using HPU fused kernel for apply_rotary_pos_emb")
-    FusedRoPE = None
-
 try:
     from habana_frameworks.torch.hpex.normalization import FusedRMSNorm as FusedRMSNorm
 except ImportError:
@@ -132,56 +125,6 @@ def __init__(self, fusedSDPA):
     def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode):
         return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode)
 
-
-class Matmul(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, x, y):
-        return torch.matmul(x, y)
-
-
-class KVCache(torch.nn.Module):
-    def __init__(self):
-        super(KVCache, self).__init__()
-        self.cache = None
-        self.inp_seq_len = -1
-
-    def allocate(self, inp_seq_len, dtype, device, shape):
-        if self.cache is None or self.cache.shape != shape:
-            self.inp_seq_len = inp_seq_len
-            self.cache = torch.zeros(shape, dtype=dtype, device=device)
-        else:
-            assert (
-                self.inp_seq_len == inp_seq_len
-            ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
-            self.cache.fill_(0)
-
-    def update(self, prev, cur, dim, idx, inp_seq_len):
-        orig_cur = cur
-        if prev.shape == cur.shape:
-            prev.copy_(cur)
-            return orig_cur
-        if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
-            # Initialize
-            prev[:, :, :inp_seq_len, :].copy_(cur)
-            return orig_cur
-        assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}"
-        if idx is not None:
-            prev.index_copy_(dim, idx - 1, cur)
-            return prev
-        else:
-            return torch.cat((prev, cur), dim=dim)
-
-    def get_shape(self):
-        if self.cache is None:
-            return None
-        return self.cache.shape
-
-    def forward(self, cur, dim, idx):
-        return self.update(self.cache, cur, dim, idx, self.inp_seq_len)
-
-
 class GaudiQwen2Attention(Qwen2Attention):
     def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
         super().__init__(config, layer_idx)
@@ -303,7 +246,7 @@ def pre_attn_forward(
                 kv_seq_len = past_key_value[0].shape[-2]
 
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids)
+        query_states, key_states = apply_customized_rope(query_states, key_states, cos, sin, position_ids, self.training)
 
         if use_cache:
             # reuse k, v, self_attention
@@ -911,24 +854,3 @@ def prepare_inputs_for_generation(
             }
         )
         return model_inputs
-
-
-def apply_customized_rope(q, k, cos, sin, position_ids):
-    if q.device.type == "hpu" and FusedRoPE:
-        # TODO: remove `.clone()` when it is fixed in SynapseAI
-        if k.dtype == torch.bfloat16:
-            return FusedRoPE.apply(
-                q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
-            ), FusedRoPE.apply(
-                k,
-                cos.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16),
-                sin.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16),
-                position_ids,
-            )
-        return FusedRoPE.apply(
-            q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
-        ), FusedRoPE.apply(
-            k, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
-        )
-    else:
-        return apply_rotary_pos_emb(q, k, cos, sin, position_ids)
diff --git a/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py b/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py
index 36d5379e4f..e184d4fea2 100644
--- a/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py
+++ b/optimum/habana/transformers/models/starcoder2/modeling_starcoder2.py
@@ -39,13 +39,6 @@
     _gaudi_prepare_4d_causal_attention_mask,
 )
 
-
-try:
-    from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
-except ImportError:
-    print("Not using HPU fused kernel for apply_rotary_pos_emb")
-    FusedRoPE = None
-
 try:
     from habana_frameworks.torch.hpex.kernels import FusedSDPA
 except ImportError:
@@ -877,25 +870,3 @@ def prepare_inputs_for_generation(
             }
         )
         return model_inputs
-
-
-def apply_customized_rope(q, k, cos, sin, position_ids, is_training):
-    if q.device.type == "hpu" and FusedRoPE:
-        if not is_training and (q.dtype == torch.bfloat16 or k.dtype == torch.bfloat16):
-            return FusedRoPE.apply(
-                q,
-                cos.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
-                sin.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
-                position_ids,
-            ), FusedRoPE.apply(
-                k,
-                cos.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
-                sin.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
-                position_ids,
-            )
-        else:
-            return FusedRoPE.apply(
-                q, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids
-            ), FusedRoPE.apply(k, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids)
-    else:
-        return apply_rotary_pos_emb(q, k, cos, sin, position_ids)
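
Note (illustrative, not part of the patch): after this refactor the shared Matmul, KVCache and apply_customized_rope helpers live in optimum/habana/transformers/models/modeling_all_models.py and are re-exported from optimum.habana.transformers.models, and apply_customized_rope now receives the module's training flag so the bfloat16 cast of cos/sin is only applied at inference. The sketch below shows how the consolidated KVCache is typically driven (prompt-phase copy, then per-token index_copy_ at position idx - 1). It is a minimal sketch assuming optimum-habana with this patch applied is importable; the batch/head/sequence sizes and the CPU device are made-up example values, not something the diff prescribes.

# Illustrative sketch only -- not part of the diff above.
import torch

from optimum.habana.transformers.models.modeling_all_models import KVCache

batch, heads, max_seq, head_dim, prompt_len = 1, 8, 128, 64, 16

cache = KVCache()
# Allocate the full-length cache once; a later call with the same shape just zeroes it.
cache.allocate(
    inp_seq_len=prompt_len, dtype=torch.bfloat16, device="cpu", shape=(batch, heads, max_seq, head_dim)
)

# Prompt phase: a multi-token tensor is copied into cache[:, :, :prompt_len, :].
prompt_keys = torch.randn(batch, heads, prompt_len, head_dim, dtype=torch.bfloat16)
cache(prompt_keys, dim=2, idx=None)

# Decode phase: a single-token update is written in place at slot idx - 1 via index_copy_.
new_key = torch.randn(batch, heads, 1, head_dim, dtype=torch.bfloat16)
token_idx = torch.tensor([prompt_len + 1])  # 1-based position of the next token
full_keys = cache(new_key, dim=2, idx=token_idx)
print(full_keys.shape)  # torch.Size([1, 8, 128, 64])

The same pre-allocated, in-place update pattern is what lets each model file drop its private KVCache copy and keep static shapes on HPU; when idx is None and the cache is already full, the consolidated update falls back to torch.cat, matching the per-model implementations removed above.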