Added CB support for Codegen, falcon, gpt2, gpt-j, mpt (quic#99)
* Rebase with main

Signed-off-by: Rishin Raj <rishinr@qti.qualcomm.com>

* Format change

Signed-off-by: Rishin Raj <rishinr@qti.qualcomm.com>

* Adding causal_lm testing for CB I

Signed-off-by: Abukhoyer Shaik <quic_abukhoye@quicinc.com>

* Updated causalLm test to use same set of inputs

Signed-off-by: Rishin Raj <rishinr@qti.qualcomm.com>

* Updated causallm test method for CB

Signed-off-by: Rishin Raj <rishinr@qti.qualcomm.com>

* Changed Num of Cores in causal_lm test

Signed-off-by: Abukhoyer Shaik <quic_abukhoye@quicinc.com>

* Making cloud tests compatible with CB

Signed-off-by: Abukhoyer Shaik <quic_abukhoye@quicinc.com>

* Fixed conflicts in tests

Signed-off-by: Abukhoyer Shaik <quic_abukhoye@quicinc.com>

* Updated supported CB model in readme

Signed-off-by: Rishin Raj <rishinr@qti.qualcomm.com>

* Changed end of line sequence for some test modules

Signed-off-by: Rishin Raj <rishinr@qti.qualcomm.com>

* Changed end of line sequence for some test module

Signed-off-by: Rishin Raj <rishinr@qti.qualcomm.com>

* Added AWQ test model back

Signed-off-by: Rishin Raj <rishinr@qti.qualcomm.com>

---------

Signed-off-by: Rishin Raj <rishinr@qti.qualcomm.com>
Signed-off-by: Abukhoyer Shaik <quic_abukhoye@quicinc.com>
Co-authored-by: Abukhoyer Shaik <quic_abukhoye@quicinc.com>
quic-rishinr and abukhoy authored Sep 20, 2024
1 parent afb4645 commit 22a6b36
Showing 11 changed files with 386 additions and 28 deletions.
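
For orientation: CB here stands for continuous batching. The recurring change across the model files below is a new optional `batch_index` input, threaded from each model's forward down to the attention cache update, so that a decode step can write its keys and values into the cache row owned by a particular sequence. A minimal sketch of that idea follows; the helper and the shapes are illustrative assumptions, not QEfficient's actual cache code.

```python
import torch

# Illustrative sketch only: the helper and shapes below are assumptions,
# not the library's real cache implementation.
def write_kv_for_sequences(kv_cache, new_kv, batch_index, position_ids):
    # kv_cache:     [full_batch, n_heads, ctx_len, head_dim], one row per sequence slot
    # new_kv:       [decode_batch, n_heads, seq_len, head_dim], this step's keys or values
    # batch_index:  [decode_batch, 1], cache row owned by each decoding sequence
    # position_ids: [decode_batch, seq_len], positions to write within ctx_len
    for b in range(new_kv.shape[0]):
        row = int(batch_index[b, 0])
        for t in range(new_kv.shape[2]):
            pos = int(position_ids[b, t])
            kv_cache[row, :, pos, :] = new_kv[b, :, t, :]
    return kv_cache

# Toy decode step: two sequences occupying cache rows 3 and 0.
cache = torch.zeros(4, 2, 8, 16)
step_keys = torch.randn(2, 2, 1, 16)
rows = torch.tensor([[3], [0]])
positions = torch.tensor([[5], [2]])
write_kv_for_sequences(cache, step_keys, rows, positions)
```
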
8 changes: 8 additions & 0 deletions QEfficient/transformers/modeling_utils.py
@@ -11,6 +11,7 @@
import torch.nn as nn
from transformers.models.codegen.modeling_codegen import (
CodeGenAttention,
CodeGenBlock,
CodeGenForCausalLM,
CodeGenModel,
)
@@ -58,6 +59,7 @@

from .models.codegen.modeling_codegen import (
QEffCodeGenAttention,
QeffCodeGenBlock,
QEffCodeGenForCausalLM,
QEffCodeGenModel,
)
@@ -111,6 +113,11 @@
Qwen2ForCausalLM.__name__,
Phi3ForCausalLM.__name__,
PhiForCausalLM.__name__,
CodeGenForCausalLM.__name__,
GPT2LMHeadModel.__name__,
GPTJForCausalLM.__name__,
MptForCausalLM.__name__,
FalconForCausalLM.__name__,
]
)
# Create an instance of the named tuple
@@ -158,6 +165,7 @@
CodeGenAttention: QEffCodeGenAttention,
CodeGenModel: QEffCodeGenModel,
CodeGenForCausalLM: QEffCodeGenForCausalLM,
CodeGenBlock: QeffCodeGenBlock,
# Mistral model layers
MistralAttention: QEffMistralAttention,
MistralDecoderLayer: QEffMistralDecoderLayer,
65 changes: 63 additions & 2 deletions QEfficient/transformers/models/codegen/modeling_codegen.py
@@ -13,10 +13,11 @@
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.cache_utils import DynamicCache
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.codegen.modeling_codegen import (
CodeGenAttention,
CodeGenBlock,
CodeGenForCausalLM,
CodeGenModel,
apply_rotary_pos_emb,
@@ -78,6 +79,7 @@ def forward(
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
batch_index: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = False,
) -> Union[
Tuple[torch.Tensor, Tuple[torch.Tensor]],
@@ -126,7 +128,10 @@ def forward(
if layer_past is not None:
# Update the cache_kwargs with position_ids for Cloud AI 100
past_key_value = layer_past
cache_kwargs = {"position_ids": position_ids}
cache_kwargs = {
"position_ids": position_ids,
"batch_index": batch_index,
}
pkv = DynamicCache()
pkv.key_cache.append(past_key_value[0])
pkv.value_cache.append(past_key_value[1])
@@ -167,6 +172,7 @@ def forward(
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
@@ -279,6 +285,17 @@ def forward(
use_cache,
output_attentions,
)
elif batch_index is not None:
outputs = block(
hidden_states=hidden_states,
layer_past=layer_past,
batch_index=batch_index,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
)
else:
outputs = block(
hidden_states=hidden_states,
@@ -331,6 +348,7 @@ def forward(
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
batch_index=batch_index,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
@@ -393,3 +412,45 @@ def forward(
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)


class QeffCodeGenBlock(CodeGenBlock):
# Ignore copy

def forward(
self,
hidden_states: Optional[torch.FloatTensor],
layer_past: Optional[Cache] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_outputs = self.attn(
hidden_states=hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
position_ids=position_ids,
batch_index=batch_index,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
)
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
outputs = attn_outputs[1:]

feed_forward_hidden_states = self.mlp(hidden_states)
hidden_states = attn_output + feed_forward_hidden_states + residual

if use_cache:
outputs = (hidden_states,) + outputs
else:
outputs = (hidden_states,) + outputs[1:]

return outputs # hidden_states, present, (attentions)
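
The attention-side pattern above (repeated in the Falcon and GPT-2 files that follow) wraps the legacy per-layer `(key, value)` tuple in a `DynamicCache` and passes `position_ids` and `batch_index` through `cache_kwargs`. A hedged sketch of that flow is below; with stock transformers the update simply concatenates, and the position/batch-aware in-place write is assumed to come from QEfficient's patched cache, which is not part of this diff.

```python
import torch
from transformers.cache_utils import DynamicCache

# Sketch of the wrap-and-update flow seen in the attention diffs. Stock
# DynamicCache.update ignores cache_kwargs and concatenates along the
# sequence axis; the position_ids/batch_index-aware in-place update is
# assumed to come from QEfficient's patched cache, not shown here.
def update_layer_cache(layer_past, key, value, position_ids, batch_index):
    pkv = DynamicCache()
    pkv.key_cache.append(layer_past[0])
    pkv.value_cache.append(layer_past[1])
    cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index}
    # layer_idx=0: this wrapper holds a single layer's cache
    key, value = pkv.update(key, value, 0, cache_kwargs)
    return key, value

# Example: a cached layer with 4 past tokens receiving 1 new token.
past = (torch.zeros(1, 2, 4, 16), torch.zeros(1, 2, 4, 16))
k_new = torch.randn(1, 2, 1, 16)
v_new = torch.randn(1, 2, 1, 16)
pos = torch.tensor([[4]])
idx = torch.tensor([[0]])
k, v = update_layer_cache(past, k_new, v_new, pos, idx)
```
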
98 changes: 96 additions & 2 deletions QEfficient/transformers/models/falcon/modeling_falcon.py
@@ -15,19 +15,21 @@
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from transformers.cache_utils import DynamicCache
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
)
from transformers.models.falcon.modeling_falcon import (
FalconAttention,
FalconDecoderLayer,
FalconForCausalLM,
FalconModel,
_prepare_4d_causal_attention_mask_for_sdpa,
apply_rotary_pos_emb,
build_alibi_tensor,
dropout_add,
logger,
)

@@ -85,6 +87,7 @@ def forward(
attention_mask: torch.Tensor,
alibi: Optional[torch.Tensor],
position_ids: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
@@ -116,7 +119,10 @@

if layer_past is not None:
past_key_value = layer_past
cache_kwargs = {"position_ids": position_ids}
cache_kwargs = {
"position_ids": position_ids,
"batch_index": batch_index,
}
pkv = DynamicCache()
pkv.key_cache.append(past_key_value[0])
pkv.value_cache.append(past_key_value[1])
@@ -227,6 +233,7 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
@@ -373,6 +380,18 @@ def forward(
use_cache,
output_attentions,
)
elif batch_index is not None:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
position_ids=position_ids,
batch_index=batch_index,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
alibi=alibi,
)
else:
outputs = block(
hidden_states,
@@ -422,6 +441,7 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
past_key_values=past_key_values,
attention_mask=attention_mask,
position_ids=position_ids,
batch_index=batch_index,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
@@ -482,3 +503,76 @@ def forward(
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)


class QEffFalconDecoderLayer(FalconDecoderLayer):
def forward(
self,
hidden_states: torch.Tensor,
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
layer_past: Optional[Union[Cache, Tuple[torch.Tensor, torch.Tensor]]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
):
residual = hidden_states

if self.config.new_decoder_architecture and self.config.num_ln_in_parallel_attn == 2:
attention_layernorm_out = self.ln_attn(hidden_states)
mlp_layernorm_out = self.ln_mlp(hidden_states)
else:
attention_layernorm_out = self.input_layernorm(hidden_states)

# Self attention.
attn_outputs = self.self_attention(
attention_layernorm_out,
layer_past=layer_past,
attention_mask=attention_mask,
position_ids=position_ids,
batch_index=batch_index,
alibi=alibi,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
)

attention_output = attn_outputs[0]

if not self.config.new_decoder_architecture:
if self.config.parallel_attn:
mlp_layernorm_out = attention_layernorm_out
else:
residual = dropout_add(
attention_output, residual, self.config.attention_dropout, training=self.training
)
mlp_layernorm_out = self.post_attention_layernorm(residual)

if (
self.config.new_decoder_architecture
and self.config.parallel_attn
and self.config.num_ln_in_parallel_attn == 1
):
mlp_layernorm_out = attention_layernorm_out

outputs = attn_outputs[1:]

# MLP.
mlp_output = self.mlp(mlp_layernorm_out)

if self.config.new_decoder_architecture or self.config.parallel_attn:
mlp_output += attention_output

output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)

if use_cache:
outputs = (output,) + outputs
else:
outputs = (output,) + outputs[1:]

return outputs # hidden_states, past_kv, attentions
21 changes: 20 additions & 1 deletion QEfficient/transformers/models/gpt2/modeling_gpt2.py
@@ -34,6 +34,7 @@ def forward(
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
batch_index=batch_index,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
@@ -117,6 +119,7 @@ def forward(
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
layer_past=layer_past,
attention_mask=attention_mask,
position_ids=position_ids,
batch_index=batch_index,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
@@ -192,6 +196,7 @@ def forward(
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
@@ -330,6 +335,19 @@ def forward(
use_cache,
output_attentions,
)
elif batch_index is not None:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
position_ids=position_ids,
batch_index=batch_index,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
else:
outputs = block(
hidden_states,
@@ -434,6 +452,7 @@ def forward(
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
if layer_past is not None:
# Added for optimized GPT Attention for AI 100 KV Retention
# Update the cache_kwargs with position_ids for Cloud AI 100
cache_kwargs = {"position_ids": position_ids}
cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index}
pkv = DynamicCache()
pkv.key_cache.append(layer_past[0])
pkv.value_cache.append(layer_past[1])
(Diffs for the remaining changed files are not shown here.)
