From 7cb09a8eac6747fc15054968559cd728431a9f9c Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Fri, 5 Jul 2024 17:32:34 +0800
Subject: [PATCH] optimize qwen2 memory usage again (#11520)

---
 .../src/ipex_llm/transformers/models/qwen2.py | 68 +++++++------------
 1 file changed, 25 insertions(+), 43 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2.py b/python/llm/src/ipex_llm/transformers/models/qwen2.py
index 779ccd4b5e7..de679a2266e 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2.py
@@ -73,43 +73,6 @@ def qwen2_model_forward(
     output_attentions: Optional[bool] = None,
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
-):
-    use_cache = use_cache if use_cache is not None else self.config.use_cache
-    input = input_ids if input_ids is not None else inputs_embeds
-    use_quantize_kv = (
-        self.config.hidden_size != 3584  # disable quantize kv in specific model
-        and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input)
-    )
-    if use_cache:
-        if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
-            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        if not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
-            past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
-    return qwen2_model_forward_internal(
-        self=self,
-        input_ids=input_ids,
-        attention_mask=attention_mask,
-        position_ids=position_ids,
-        past_key_values=past_key_values,
-        inputs_embeds=inputs_embeds,
-        use_cache=use_cache,
-        output_attentions=output_attentions,
-        output_hidden_states=output_hidden_states,
-        return_dict=return_dict,
-    )
-
-
-def qwen2_model_forward_internal(
-    self,
-    input_ids: torch.LongTensor = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[List[torch.FloatTensor]] = None,
-    inputs_embeds: Optional[torch.FloatTensor] = None,
-    use_cache: Optional[bool] = None,
-    output_attentions: Optional[bool] = None,
-    output_hidden_states: Optional[bool] = None,
-    return_dict: Optional[bool] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
     output_attentions = output_attentions if output_attentions is not None else \
         self.config.output_attentions
@@ -144,11 +107,21 @@ def qwen2_model_forward_internal(

     past_key_values_length = 0

+    # ipex-llm changes start
+    # IPEX-LLM OPT: kv cache and quantize kv cache
+    inputs = input_ids if input_ids is not None else inputs_embeds
+    use_quantize_kv = (
+        self.config.hidden_size != 3584  # disable quantize kv in specific model
+        and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs)
+    )
+
     if use_cache:
-        use_legacy_cache = not isinstance(past_key_values, Cache)
-        if use_legacy_cache:
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+        if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
+            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
+        if not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
+            past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
         past_key_values_length = past_key_values.get_usable_length(seq_length)
+    # ipex-llm changes end

     if position_ids is None:
         device = input_ids.device if input_ids is not None else inputs_embeds.device
@@ -176,7 +149,15 @@ def qwen2_model_forward_internal(
                 "the input. "
             )

-    if self._attn_implementation == "flash_attention_2":
+    # ipex-llm changes start: don't generate `attention_mask` in specific cases
+    if seq_length == 1 or batch_size == 1 and use_sdp_causal(
+        seq_length, seq_length + past_key_values_length,
+        self.config.hidden_size // self.config.num_attention_heads,
+        inputs_embeds, self.training
+    ):
+        attention_mask = None
+    # ipex-llm changes end
+    elif self._attn_implementation == "flash_attention_2":
         # 2d mask is passed through the layers
         attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
@@ -251,10 +232,11 @@ def qwen2_model_forward_internal(
     if output_hidden_states:
         all_hidden_states += (hidden_states,)

+    # ipex-llm changes start: remove `to_legacy_cache`
     next_cache = None
     if use_cache:
-        next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else \
-            next_decoder_cache
+        next_cache = next_decoder_cache
+    # ipex-llm changes end

     if not return_dict:
         return tuple(v for v in [hidden_states, next_cache,
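
Note on the quantize-KV hunk: the FP8/normal cache selection that used to live in the
qwen2_model_forward wrapper now runs inside the merged forward, and the FP8 KV cache stays
disabled for the hidden_size == 3584 variant. As a rough back-of-envelope sketch of why an
FP8 cache halves KV memory versus FP16 -- all shapes below are illustrative assumptions,
not values taken from this patch:

# Illustrative arithmetic only: an FP16 KV cache (2 bytes/element,
# DynamicNormalCache-style) versus an FP8 one (1 byte/element,
# DynamicFp8Cache-style), for assumed Qwen2-like shapes.
num_layers = 28       # assumed decoder layer count
num_kv_heads = 4      # assumed KV head count (GQA)
head_dim = 128        # assumed head dimension
seq_len = 8192        # assumed cached sequence length
batch = 1

# K and V each store a (batch, num_kv_heads, seq_len, head_dim) tensor per layer.
elems = 2 * num_layers * batch * num_kv_heads * seq_len * head_dim
print(f"FP16 KV cache: {elems * 2 / 2**20:.0f} MiB")  # ~448 MiB
print(f"FP8  KV cache: {elems * 1 / 2**20:.0f} MiB")  # ~224 MiB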
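Note on the attention-mask hunk: for single-token decode (seq_length == 1), or a batch-of-1
input that use_sdp_causal accepts, the forward now passes attention_mask=None instead of
materializing a 4D additive mask. A minimal sketch of the allocation this skips, under
assumed shapes (torch is the only dependency; the sizes are illustrative, not measured
from this patch):

import torch

# A 4D additive mask has shape (batch, 1, q_len, kv_len) in the model dtype;
# these shapes are assumptions chosen for illustration.
batch, q_len, kv_len = 1, 8192, 8192
mask = torch.zeros(batch, 1, q_len, kv_len, dtype=torch.float16)
print(f"full additive mask: {mask.numel() * mask.element_size() / 2**20:.0f} MiB")  # 128 MiB

With attention_mask=None, the causal-attention kernel enforces causality itself, so this
buffer is never allocated.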