update
sgwhat committed Nov 6, 2024
1 parent a95c128 commit c978885
Showing 2 changed files with 127 additions and 5 deletions.
15 changes: 11 additions & 4 deletions python/llm/src/ipex_llm/transformers/npu_models/convert_mp.py
@@ -173,7 +173,8 @@ def convert_llama(
intra_pp=None,
transpose_value_cache=True,
):
from ipex_llm.transformers.npu_models.llama_mp import gen_llama_fused_model_forward
from ipex_llm.transformers.npu_models.llama_mp import gen_llama_fused_model_forward,\
gen_llama_32_fused_model_forward
from ipex_llm.transformers.npu_models.llama_mp import DecodeRunner, PrefillRunner
from transformers.models.llama.modeling_llama import LlamaModel

@@ -193,9 +194,15 @@ def convert_llama(
max_prompt_len=max_prompt_len,
transpose_value_cache=transpose_value_cache,
)
llama_model_forward = gen_llama_fused_model_forward(
prefill_runner=prefill_runner, decode_runner=decode_runner
)
if model.config.num_hidden_layers == 28 or model.config.num_hidden_layers == 16:
# llama-3.2-3B & llama-3.2-1B
llama_model_forward = gen_llama_32_fused_model_forward(
prefill_runner=prefill_runner, decode_runner=decode_runner
)
else:
llama_model_forward = gen_llama_fused_model_forward(
prefill_runner=prefill_runner, decode_runner=decode_runner
)
convert_forward(model, LlamaModel, llama_model_forward)
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from ipex_llm.transformers.npu_models.llama_mp import llama2_casullm_forward
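For reference, the new branch keys off the hidden-layer count, which is enough to tell the Llama-3.2 checkpoints apart from earlier Llama generations: Llama-3.2-3B has 28 hidden layers and Llama-3.2-1B has 16, while e.g. Llama-2-7B and Llama-3-8B have 32. A minimal standalone sketch of the dispatch (the helper name is hypothetical; the gen_* functions are the ones imported in the diff above):

    def _select_llama_forward(config, prefill_runner, decode_runner):
        # 28 and 16 hidden layers identify Llama-3.2-3B and Llama-3.2-1B;
        # every other supported Llama keeps the original fused forward.
        if config.num_hidden_layers in (28, 16):
            gen = gen_llama_32_fused_model_forward
        else:
            gen = gen_llama_fused_model_forward
        return gen(prefill_runner=prefill_runner, decode_runner=decode_runner)
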
117 changes: 116 additions & 1 deletion python/llm/src/ipex_llm/transformers/npu_models/llama_mp.py
@@ -1041,6 +1041,121 @@ def llama_fused_model_forward(
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

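        # Exactly one of input_ids / inputs_embeds may be supplied: the XOR
        # below is True for the two invalid cases (neither given, or both
        # given), and invalidInputError raises whenever its first argument
        # is False.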
if (input_ids is None) ^ (inputs_embeds is not None):
msg = (
"You cannot specify both input_ids and inputs_embeds at the same time,"
" and must specify either one"
)
invalidInputError(False, msg)

if self.gradient_checkpointing and self.training and use_cache:
use_cache = False

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

past_seen_tokens = 0

# ipex-llm changes start
from ipex_llm.transformers.npu_models.kv import DynamicFusedNormalCache

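        # DynamicFusedNormalCache is ipex-llm's fused KV cache for the NPU
        # path; tuple-style legacy caches from transformers are converted
        # once here, then reused across steps.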
if use_cache and not isinstance(past_key_values, DynamicFusedNormalCache):
past_key_values = DynamicFusedNormalCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()

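        # cache_position holds the absolute token positions for this call:
        # 0..seq_len-1 on prefill (past_seen_tokens == 0), then a single
        # index per step on decode (e.g. tensor([5]) after a 5-token prompt).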
if cache_position is None:
cache_position = torch.arange(
past_seen_tokens,
past_seen_tokens + inputs_embeds.shape[1],
device=inputs_embeds.device,
)
# ipex-llm changes end

if position_ids is None:
position_ids = cache_position.unsqueeze(0)

causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_seen_tokens
)

# embed positions
hidden_states = inputs_embeds

# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None

seq_len = hidden_states.size(1)

if seq_len == 1:
layers_runner = decode_runner
else:
layers_runner = prefill_runner
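        # seq_len == 1 only happens on incremental decode (one new token with
        # a populated KV cache), so single-token steps take the decode runner
        # and full prompts take the prefill runner.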
layer_outputs = layers_runner.forward(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
hidden_states = layer_outputs[0]

next_decoder_cache = layer_outputs[1]
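        # The runner returns (hidden_states, kv_cache); per-layer attention
        # weights are not produced on this fused path, so all_self_attns
        # stays empty.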

hidden_states = self.norm(hidden_states)

# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)

# ipex-llm changes start
next_cache = next_decoder_cache if use_cache else None
# ipex-llm changes end
if not return_dict:
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)

return llama_fused_model_forward


def gen_llama_32_fused_model_forward(prefill_runner, decode_runner):

def llama_32_fused_model_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions if output_attentions is not None else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if (input_ids is None) ^ (inputs_embeds is not None):
msg = (
"You cannot specify both input_ids and inputs_embeds at the same time,"
Expand Down Expand Up @@ -1131,7 +1246,7 @@ def llama_fused_model_forward(
attentions=all_self_attns,
)

return llama_fused_model_forward
return llama_32_fused_model_forward


def llama2_casullm_forward(
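End to end, these fused forwards are installed when a model is loaded with ipex-llm's NPU optimizations. The sketch below follows the NPU examples in the ipex-llm repository around the time of this commit; the exact keyword arguments (load_in_low_bit, max_context_len, max_prompt_len, transpose_value_cache) are taken from those examples and may differ between versions:

    import torch
    from ipex_llm.transformers.npu_model import AutoModelForCausalLM

    # optimize_model=True routes Llama checkpoints through convert_llama()
    # above, which now selects llama_32_fused_model_forward for the
    # 16/28-layer Llama-3.2 models.
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.2-3B-Instruct",
        torch_dtype=torch.float16,
        optimize_model=True,
        load_in_low_bit="sym_int4",
        max_context_len=1024,
        max_prompt_len=512,
        transpose_value_cache=True,
        attn_implementation="eager",
        trust_remote_code=True,
    )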
