Commit

Merged #1148; resolved conflict.
imangohari1 committed Aug 13, 2024
2 parents 737028e + 798f99d commit 9910b98
Showing 11 changed files with 115 additions and 481 deletions.
3 changes: 3 additions & 0 deletions optimum/habana/transformers/models/__init__.py
@@ -136,6 +136,9 @@
     gaudi_conv1d_forward,
     gaudi_get_extended_attention_mask,
     gaudi_invert_attention_mask,
+    Matmul,
+    KVCache,
+    apply_customized_rope
 )
 from .mpt import (
     GaudiMptForCausalLM,
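With these three names re-exported from the package __init__, model files can import the shared helpers instead of keeping their own copies (the duplicates are deleted in the files below). A minimal, hedged sketch of the resulting usage; the variable names are illustrative only:

# Sketch: pull the shared helpers from the package shown above rather than
# redefining Matmul / KVCache / apply_customized_rope inside each model file.
from optimum.habana.transformers.models import KVCache, Matmul, apply_customized_rope

matmul_qk = Matmul()   # thin nn.Module wrapper around torch.matmul (see the removed copies below)
key_cache = KVCache()  # pre-allocated KV buffer that is updated in place during generation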
8 changes: 0 additions & 8 deletions optimum/habana/transformers/models/clip/modeling_clip.py
@@ -47,14 +47,6 @@ def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale, sof
         return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode)


-class Matmul(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, x, y):
-        return torch.matmul(x, y)
-
-
 class Softmax(nn.Module):
     def __init__(self):
         super().__init__()
80 changes: 1 addition & 79 deletions optimum/habana/transformers/models/falcon/modeling_falcon.py
@@ -19,13 +19,6 @@
 except ImportError:
     SDPContext = False

-try:
-    from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
-except ImportError:
-    print("Not using HPU fused kernel for apply_rotary_pos_emb")
-    FusedRoPE = None
-
-
 import habana_frameworks.torch.core as htcore
 from torch import nn
 from torch.nn import CrossEntropyLoss
@@ -69,19 +62,6 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training:
     residual.add_(out)
     return residual

-
-def apply_customized_rope(q, k, cos, sin, position_ids):
-    if q.device.type == "hpu" and FusedRoPE:
-        # TODO: remove `.clone()` when it is fixed in SynapseAI
-        return FusedRoPE.apply(
-            q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
-        ), FusedRoPE.apply(
-            k, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
-        )
-    else:
-        return apply_rotary_pos_emb(q, k, cos, sin, position_ids)
-
-
 def gaudi_falcon_linear_forward(self, input: torch.Tensor) -> torch.Tensor:
     hidden_states = F.linear(input, self.weight, bias=self.bias)
     return hidden_states
@@ -105,14 +85,6 @@ def forward(self, x, dim=None, invAttnHead=None):
         return torch.ops.hpu.softmax_fp8(x, dim, None, None, invAttnHead)


-class Matmul(nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, *args, **kwargs):
-        return torch.matmul(*args, **kwargs)
-
-
 # ScaledDotProductAttention is based on torch.nn.functional.scaled_dot_product_attention
 class ScaledDotProductAttention(nn.Module):
     def __init__(self, config: FalconConfig):
@@ -183,56 +155,6 @@ def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=Fa
         return attn_output


-def update(prev, cur, dim, idx, inp_seq_len):
-    orig_cur = cur
-    cur = cur.to(dtype=prev.dtype)
-
-    if prev.shape == cur.shape:
-        prev.copy_(cur)
-        return orig_cur
-
-    if cur.shape[-2] > 1 and cur.shape[-2] <= prev.shape[-2]:
-        # Initialize
-        prev[:, :, :inp_seq_len, :].copy_(cur)
-        return orig_cur
-    assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. prev:{prev.shape} cur:{cur.shape}"
-    if idx is not None:
-        prev.index_copy_(dim, idx - 1, cur)
-        prev_cast = prev.to(orig_cur.dtype)
-        return prev_cast
-    else:
-        return torch.cat((prev, cur), dim=dim)
-
-
-class KVCache(torch.nn.Module):
-    def __init__(self):
-        super(KVCache, self).__init__()
-        self.cache = None
-        self.inp_seq_len = -1
-
-    def allocate(self, inp_seq_len, dtype, device, shape):
-        if self.cache is None or self.cache.shape != shape:
-            self.inp_seq_len = inp_seq_len
-            self.cache = torch.zeros(shape, dtype=dtype, device=device)
-        else:
-            assert (
-                self.inp_seq_len == inp_seq_len
-            ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
-            self.cache.fill_(0)
-
-    def get_shape(self):
-        if self.cache is None:
-            return None
-        return self.cache.shape
-
-    def forward(self, cur, dim, idx):
-        return self.update(self.cache, cur, dim, idx, self.inp_seq_len)
-
-    @staticmethod
-    def update(prev, cur, dim, idx, inp_seq_len):
-        return update(prev, cur, dim, idx, inp_seq_len)
-
-
 class GaudiFalconAttention(FalconAttention):
     """
     Inherits from FalconAttention: https://github.com/huggingface/transformers/blob/838b87abe231fd70be5132088d0dee72a7bb8d62/src/transformers/models/falcon/modeling_falcon.py#L267
@@ -374,7 +296,7 @@ def pre_attn_forward(

         if alibi is None:
             cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len)
-            query_layer, key_layer = apply_customized_rope(query_layer, key_layer, cos, sin, position_ids)
+            query_layer, key_layer = apply_customized_rope(query_layer, key_layer, cos, sin, position_ids, self.training)

         if use_cache:
             if self.training:
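The one-line replacement above is the only functional change in this file: the call site now forwards self.training to the shared apply_customized_rope. Judging from the per-model copy removed from the GPT-NeoX file below, the flag decides whether the inference-only bfloat16 cast of cos/sin is applied. A hedged sketch of the assumed shared interface:

# Assumed interface of the consolidated helper; the exact shared body is not part of
# this diff, and the GPT-NeoX copy removed below is one concrete implementation of it.
def apply_customized_rope(q, k, cos, sin, position_ids, training=True):
    if q.device.type == "hpu" and FusedRoPE is not None:
        # Fused HPU RoPE kernel; when training is False, cos/sin may additionally be
        # cast to bfloat16 to match the activations (see the removed GPT-NeoX copy).
        return (
            FusedRoPE.apply(q, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids),
            FusedRoPE.apply(k, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids),
        )
    # Fallback without the fused kernel: stock transformers rotary embedding.
    return apply_rotary_pos_emb(q, k, cos, sin, position_ids)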
37 changes: 0 additions & 37 deletions optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -15,14 +15,6 @@

 from ...modeling_attn_mask_utils import _gaudi_prepare_4d_causal_attention_mask

-
-try:
-    from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
-except ImportError:
-    print("Not using HPU fused kernel for apply_rotary_pos_emb")
-    FusedRoPE = None
-
-
 def gaudi_gpt_neox_attention_forward(
     self,
     hidden_states: torch.FloatTensor,
@@ -434,32 +426,3 @@ def gaudi_gpt_neox_rotary_embedding_set_cos_sin_cache(self, seq_len, device, dty
     emb = torch.cat((freqs, freqs), dim=-1)
     self.cos_cached = emb.cos()
     self.sin_cached = emb.sin()
-
-
-def apply_customized_rope(q, k, cos, sin, position_ids, training=True):
-    if q.device.type == "hpu" and FusedRoPE:
-        if training:
-            rope_q = FusedRoPE.apply(q, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids)
-            rope_k = FusedRoPE.apply(k, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids)
-        else:
-            if q.dtype == torch.bfloat16:
-                rope_q = FusedRoPE.apply(
-                    q,
-                    cos.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
-                    sin.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
-                    position_ids,
-                )
-            else:
-                rope_q = FusedRoPE.apply(q, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids)
-            if k.dtype == torch.bfloat16:
-                rope_k = FusedRoPE.apply(
-                    k,
-                    cos.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
-                    sin.unsqueeze(0).unsqueeze(0).to(torch.bfloat16),
-                    position_ids,
-                )
-            else:
-                rope_k = FusedRoPE.apply(k, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids)
-        return rope_q, rope_k
-    else:
-        return apply_rotary_pos_emb(q, k, cos, sin, position_ids)
73 changes: 0 additions & 73 deletions optimum/habana/transformers/models/llama/modeling_llama.py
@@ -35,7 +35,6 @@

 try:
     from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE
-
     has_fused_rope = True
 except ImportError:
     has_fused_rope = False
@@ -348,56 +347,6 @@ def __init__(self, fusedSDPA):
     def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode):
         return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode)

-
-class Matmul(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, x, y):
-        return torch.matmul(x, y)
-
-
-class KVCache(torch.nn.Module):
-    def __init__(self):
-        super(KVCache, self).__init__()
-        self.cache = None
-        self.inp_seq_len = -1
-
-    def allocate(self, inp_seq_len, dtype, device, shape):
-        if self.cache is None or self.cache.shape != shape:
-            self.inp_seq_len = inp_seq_len
-            self.cache = torch.zeros(shape, dtype=dtype, device=device)
-        else:
-            assert (
-                self.inp_seq_len == inp_seq_len
-            ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}"
-            self.cache.fill_(0)
-
-    @staticmethod
-    def update(prev, cur, dim, idx, inp_seq_len):
-        orig_cur = cur
-        if prev.shape == cur.shape:
-            prev.copy_(cur)
-            return orig_cur
-        if idx is not None and cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]:
-            # Initialize
-            prev[:, :, :inp_seq_len, :].copy_(cur)
-            return orig_cur
-        if idx is not None:
-            prev.index_copy_(dim, idx - 1, cur)
-            return prev
-        else:
-            return torch.cat((prev, cur), dim=dim)
-
-    def get_shape(self):
-        if self.cache is None:
-            return None
-        return self.cache.shape
-
-    def forward(self, cur, dim, idx):
-        return self.update(self.cache, cur, dim, idx, self.inp_seq_len)
-
-
 class GaudiLlamaAttention(LlamaAttention):
     def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
         super().__init__(config, layer_idx)
@@ -1406,25 +1355,3 @@ def _reorder_cache(past_key_values, beam_idx):
                 tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
             )
         return reordered_past
-
-
-def apply_customized_rope(q, k, cos, sin, position_ids):
-    if q.device.type == "hpu" and has_fused_rope:
-        # TODO: remove `.clone()` when it is fixed in SynapseAI
-        if k.dtype == torch.bfloat16:
-            return FusedRoPE.apply(
-                q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
-            ), FusedRoPE.apply(
-                k,
-                cos.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16),
-                sin.unsqueeze(0).unsqueeze(0).clone().to(torch.bfloat16),
-                position_ids,
-            )
-        return FusedRoPE.apply(
-            q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
-        ), FusedRoPE.apply(
-            k, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids
-        )
-    else:
-        # keep the same implementation as Transformers v4.37.2
-        return apply_rotary_pos_emb(q, k, cos[position_ids], sin[position_ids])
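For reference, the KVCache removed above (and from the Falcon file) follows a pre-allocate-then-update-in-place pattern, so the cache tensor keeps a static shape across decoding steps. A small usage sketch with made-up shapes; "cpu" stands in for "hpu" so it runs anywhere the class is available:

import torch

bs, heads, max_len, head_dim, prompt_len = 1, 8, 128, 64, 16
cache = KVCache()
cache.allocate(inp_seq_len=prompt_len, dtype=torch.float32, device="cpu",
               shape=(bs, heads, max_len, head_dim))

# Prefill: the whole prompt is copied into the first prompt_len slots of the cache.
k_prompt = torch.randn(bs, heads, prompt_len, head_dim)
cache(k_prompt, dim=2, idx=torch.tensor([prompt_len]))

# Decode: each new token is written in place at position idx - 1 via index_copy_.
k_new = torch.randn(bs, heads, 1, head_dim)
cache(k_new, dim=2, idx=torch.tensor([prompt_len + 1]))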
(The remaining changed files are not shown here.)
