[Model Updates]: Update the rope calculations (quic#87)
* [Llama]: Update the rope calculations

- Update PyTorch base transforms to include the InitMapping
- Separate the InitMapping module and apply the corresponding transform

Signed-off-by: vbaddi <quic_vbaddi@quicinc.com>

* Remove init mapping and initialize at module mapping

1. Initialize the modified QEff rotary embedding in the module-mapping transform
2. Set the attribute based on the instance check

Signed-off-by: vbaddi <quic_vbaddi@quicinc.com>

* update transformers to v4.44.2

Signed-off-by: vbaddi <quic_vbaddi@quicinc.com>

* update the InitMapping transform to __qeff_init__() calls

Signed-off-by: vbaddi <quic_vbaddi@quicinc.com>

* pass cache_position as None in the KVTransform tests

Signed-off-by: vbaddi <quic_vbaddi@quicinc.com>

* update packages, torch and peft

Signed-off-by: vbaddi <quic_vbaddi@quicinc.com>

* update models for transformers v4.44.2

1. Also add support for continuous batching (CB) on Phi, Phi3, and Qwen2

Signed-off-by: vbaddi <quic_vbaddi@quicinc.com>

* remove bool from the attention_mask

Signed-off-by: vbaddi <quic_vbaddi@quicinc.com>

* Add a condition in the decoder layer for non-CB execution

Signed-off-by: Rishin Raj <rishinr@qti.qualcomm.com>

* Add model to get_lists_of_cb_qeff_models

Signed-off-by: Rishin Raj <rishinr@qti.qualcomm.com>

* update setup details to include the CPU-only torch version

Signed-off-by: vbaddi <quic_vbaddi@quicinc.com>

* add peft to the setup

Signed-off-by: vbaddi <quic_vbaddi@quicinc.com>

* update lint format in the toml

Signed-off-by: vbaddi <quic_vbaddi@quicinc.com>

* update lint format in the toml, v1

Signed-off-by: vbaddi <quic_vbaddi@quicinc.com>

* update lint format in the toml, v2

Signed-off-by: vbaddi <quic_vbaddi@quicinc.com>

* fix linter

Signed-off-by: vbaddi <quic_vbaddi@quicinc.com>

* fix license

Signed-off-by: vbaddi <quic_vbaddi@quicinc.com>

* add license headers to the setup

Signed-off-by: vbaddi <quic_vbaddi@quicinc.com>

* update toml

Signed-off-by: vbaddi <quic_vbaddi@quicinc.com>

* fix setup

Signed-off-by: vbaddi <quic_vbaddi@quicinc.com>

* fix setup issues, modify toml and remove setup

Signed-off-by: vbaddi <quic_vbaddi@quicinc.com>

---------

Signed-off-by: vbaddi <quic_vbaddi@quicinc.com>
Signed-off-by: Rishin Raj <rishinr@qti.qualcomm.com>
Co-authored-by: Rishin Raj <quic_rishinr@quicinc.com>
vbaddi and quic-rishinr authored Sep 11, 2024
1 parent 34e29f8 commit 67922d7
Showing 14 changed files with 903 additions and 439 deletions.
4 changes: 3 additions & 1 deletion QEfficient/base/pytorch_transforms.py
@@ -4,7 +4,6 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------

from typing import Dict, Tuple, Type

from torch import nn
@@ -43,6 +42,9 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
        for module in model.modules():
            if repl_module := cls._module_mapping.get(type(module)):
                module.__class__ = repl_module
                # Handling the __init__ calls in the models
                if hasattr(module, "__qeff_init__"):
                    module.__qeff_init__()
                transformed = True
        return model, transformed

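For context, the transform above swaps a module's class in place and then calls its __qeff_init__() hook, so any QEff-specific re-initialization (such as rebuilding the rotary embedding) happens during module mapping instead of in a separate InitMapping pass. Below is a minimal, self-contained sketch of that pattern; ToyAttention and QEffToyAttention are hypothetical stand-ins, not classes from this repository.

# Illustrative sketch of the class-swap + __qeff_init__() pattern (not the repo's exact API)
import torch.nn as nn

class ToyAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

class QEffToyAttention(ToyAttention):
    def __qeff_init__(self):
        # Rebuild whatever needs QEff-specific initialization after the class swap
        self.proj = nn.Linear(8, 8, bias=False)

mapping = {ToyAttention: QEffToyAttention}
model = nn.Sequential(ToyAttention())

for module in model.modules():
    if repl := mapping.get(type(module)):
        module.__class__ = repl            # swap the class in place
        if hasattr(module, "__qeff_init__"):
            module.__qeff_init__()         # run the QEff init hook

print(type(model[0]).__name__)             # QEffToyAttention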
3 changes: 3 additions & 0 deletions QEfficient/transformers/modeling_utils.py
@@ -108,6 +108,9 @@
        MistralForCausalLM.__name__,
        MixtralForCausalLM.__name__,
        Starcoder2ForCausalLM.__name__,
        Qwen2ForCausalLM.__name__,
        Phi3ForCausalLM.__name__,
        PhiForCausalLM.__name__,
    ]
)
# Create an instance of the named tuple
103 changes: 100 additions & 3 deletions QEfficient/transformers/models/llama/modeling_llama.py
@@ -20,24 +20,111 @@
)
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaConfig,
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaModel,
    apply_rotary_pos_emb,
    LlamaRotaryEmbedding,
    logger,
    repeat_kv,
    rotate_half,
)

from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask


class QEffLlamaRotaryEmbedding(LlamaRotaryEmbedding):
    """
    Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
    The only differences are:
    - Add static sin/cos computations.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super(LlamaRotaryEmbedding, self).__init__()  # Initialize nn.Module
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

        freqs = torch.outer(t, self.inv_freq)

        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )


def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)

    # Apply rotation
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    # Cast back to original dtype
    return q_embed.to(q.dtype), k_embed.to(k.dtype)


class QEffLlamaAttention(LlamaAttention):
    """
    Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
    The only differences are:
    - add new args cache idx for the kv retention
    """

    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
        super().__init__(config, layer_idx)
        # Define the general __qeff_init__() for any changes in the init calls
        # Set the init in the module mapping pytorch transforms
        self.__qeff_init__()

    def __qeff_init__(self):
        self.rotary_emb = QEffLlamaRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
@@ -78,8 +165,18 @@ def forward(
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
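For context, the rotary-embedding change above precomputes full sin/cos tables up to max_position_embeddings and lets qeff_apply_rotary_pos_emb index them with explicit position_ids, which keeps shapes static and supports offset positions when a KV cache is present. Here is a rough, self-contained sketch of that indexing scheme under assumed shapes; the helper names and dimensions are illustrative, not the repository's exact API.

# Illustrative sketch of static-cache RoPE with explicit position_ids (assumed shapes/names)
import torch

def rotate_half(x):
    # Standard RoPE helper: rotate the last dimension by half
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def build_rope_cache(head_dim, max_pos=2048, base=10000.0):
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(max_pos).float()
    freqs = torch.outer(t, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos(), emb.sin()                # each: [max_pos, head_dim]

def apply_rope(q, k, cos, sin, position_ids):
    cos = cos[position_ids].unsqueeze(1)       # [bs, 1, seq_len, head_dim]
    sin = sin[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

bs, heads, seq_len, head_dim = 1, 4, 5, 16
q = torch.randn(bs, heads, seq_len, head_dim)
k = torch.randn(bs, heads, seq_len, head_dim)
cos, sin = build_rope_cache(head_dim)
# Offset position ids, e.g. decoding after 10 cached tokens
position_ids = torch.arange(10, 10 + seq_len).unsqueeze(0)
q_rot, k_rot = apply_rope(q, k, cos, sin, position_ids)
print(q_rot.shape, k_rot.shape)                # torch.Size([1, 4, 5, 16]) twice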
