optimize qwen2 memory usage again (#11520)
MeouSker77 authored Jul 5, 2024
1 parent 8f376e5 commit 7cb09a8
Showing 1 changed file with 25 additions and 43 deletions.
68 changes: 25 additions & 43 deletions python/llm/src/ipex_llm/transformers/models/qwen2.py
@@ -73,43 +73,6 @@ def qwen2_model_forward(
     output_attentions: Optional[bool] = None,
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
-):
-    use_cache = use_cache if use_cache is not None else self.config.use_cache
-    input = input_ids if input_ids is not None else inputs_embeds
-    use_quantize_kv = (
-        self.config.hidden_size != 3584  # disable quantize kv in specific model
-        and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input)
-    )
-    if use_cache:
-        if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
-            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        if not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
-            past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
-    return qwen2_model_forward_internal(
-        self=self,
-        input_ids=input_ids,
-        attention_mask=attention_mask,
-        position_ids=position_ids,
-        past_key_values=past_key_values,
-        inputs_embeds=inputs_embeds,
-        use_cache=use_cache,
-        output_attentions=output_attentions,
-        output_hidden_states=output_hidden_states,
-        return_dict=return_dict,
-    )
-
-
-def qwen2_model_forward_internal(
-    self,
-    input_ids: torch.LongTensor = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[List[torch.FloatTensor]] = None,
-    inputs_embeds: Optional[torch.FloatTensor] = None,
-    use_cache: Optional[bool] = None,
-    output_attentions: Optional[bool] = None,
-    output_hidden_states: Optional[bool] = None,
-    return_dict: Optional[bool] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
     output_attentions = output_attentions if output_attentions is not None else \
         self.config.output_attentions
@@ -144,11 +107,21 @@ def qwen2_model_forward_internal(

     past_key_values_length = 0

+    # ipex-llm changes start
+    # IPEX-LLM OPT: kv cache and quantize kv cache
+    inputs = input_ids if input_ids is not None else inputs_embeds
+    use_quantize_kv = (
+        self.config.hidden_size != 3584  # disable quantize kv in specific model
+        and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs)
+    )
+
     if use_cache:
-        use_legacy_cache = not isinstance(past_key_values, Cache)
-        if use_legacy_cache:
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+        if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
+            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
+        if not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
+            past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
         past_key_values_length = past_key_values.get_usable_length(seq_length)
+    # ipex-llm changes end

     if position_ids is None:
         device = input_ids.device if input_ids is not None else inputs_embeds.device
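
The hunk above moves the quantize-KV decision, which previously lived in the deleted wrapper, directly into the model forward. The following is a minimal, self-contained sketch of just that selection logic; pick_kv_cache and quantize_ok are illustrative stand-ins (not part of ipex-llm), while use_quantize_kv_cache, DynamicFp8Cache and DynamicNormalCache are the real names shown in the diff.

# Illustrative sketch only: mirrors the selection logic in the hunk above with
# the ipex-llm helpers stubbed out. `quantize_ok` stands in for
# use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs).
def pick_kv_cache(hidden_size: int, quantize_ok: bool) -> str:
    # FP8 quantized KV cache is disabled for the model with hidden_size 3584,
    # matching the `self.config.hidden_size != 3584` guard in the diff.
    use_quantize_kv = hidden_size != 3584 and quantize_ok
    # The real code converts past_key_values with DynamicFp8Cache.from_legacy_cache(...)
    # or DynamicNormalCache.from_legacy_cache(...); here we just name the choice.
    return "DynamicFp8Cache" if use_quantize_kv else "DynamicNormalCache"

assert pick_kv_cache(4096, quantize_ok=True) == "DynamicFp8Cache"
assert pick_kv_cache(3584, quantize_ok=True) == "DynamicNormalCache"
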
@@ -176,7 +149,15 @@ def qwen2_model_forward_internal(
                 "the input. "
             )

-    if self._attn_implementation == "flash_attention_2":
+    # ipex-llm changes start: don't generate `attention_mask` in specific cases
+    if seq_length == 1 or batch_size == 1 and use_sdp_causal(
+        seq_length, seq_length + past_key_values_length,
+        self.config.hidden_size // self.config.num_attention_heads,
+        inputs_embeds, self.training
+    ):
+        attention_mask = None
+    # ipex-llm changes end
+    elif self._attn_implementation == "flash_attention_2":
         # 2d mask is passed through the layers
         attention_mask = attention_mask if (attention_mask is not None and
                                             0 in attention_mask) else None
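
The new mask-skip branch above relies on Python operator precedence: `and` binds tighter than `or`, so the mask is dropped for single-token decode steps, or for batch-size-1 prefill when the causal SDP path applies. A small sketch of just that condition, with sdp_causal_ok standing in for the real use_sdp_causal(...) check (whose internals are not shown in this diff):

# Illustrative sketch: the guard reads as
#   seq_length == 1  or  (batch_size == 1 and use_sdp_causal(...))
def should_skip_attention_mask(seq_length: int, batch_size: int, sdp_causal_ok: bool) -> bool:
    return seq_length == 1 or (batch_size == 1 and sdp_causal_ok)

assert should_skip_attention_mask(seq_length=1, batch_size=8, sdp_causal_ok=False)   # decode step
assert should_skip_attention_mask(seq_length=512, batch_size=1, sdp_causal_ok=True)  # bs=1 prefill
assert not should_skip_attention_mask(seq_length=512, batch_size=4, sdp_causal_ok=True)
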
@@ -251,10 +232,11 @@ def qwen2_model_forward_internal(
     if output_hidden_states:
         all_hidden_states += (hidden_states,)

+    # ipex-llm changes start: remove `to_legacy_cache`
     next_cache = None
     if use_cache:
-        next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else \
-            next_decoder_cache
+        next_cache = next_decoder_cache
+    # ipex-llm changes end

     if not return_dict:
         return tuple(v for v in [hidden_states, next_cache,
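
A rough sketch of what the last hunk means for the returned cache: with to_legacy_cache removed, the cache object built by the decoder layers is handed back unchanged instead of being flattened to the legacy tuple-of-tuples format. DummyCache below is only a stand-in for DynamicNormalCache / DynamicFp8Cache, which are not defined in this excerpt.

# Illustrative sketch of the cache handoff after removing `to_legacy_cache`.
class DummyCache:
    """Stand-in for an ipex-llm cache object (DynamicNormalCache / DynamicFp8Cache)."""
    def __init__(self):
        self.key_cache, self.value_cache = [], []

def finalize_cache(next_decoder_cache, use_cache: bool):
    # Mirrors the edited block: no `use_legacy_cache` branch and no
    # `.to_legacy_cache()` call -- the cache object is returned as-is.
    return next_decoder_cache if use_cache else None

cache = DummyCache()
assert finalize_cache(cache, use_cache=True) is cache
assert finalize_cache(cache, use_cache=False) is None
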
