diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py
index 10f0359a2..4b156eda9 100644
--- a/optimum/intel/openvino/modeling_decoder.py
+++ b/optimum/intel/openvino/modeling_decoder.py
@@ -120,6 +120,7 @@ def __init__(
         self._original_model = self.model.clone()  # keep original model for serialization
         self._pkv_precision = Type.f32
         self.next_beam_idx = None
+        self._past_length = 0
         self.update_pkv_precision()
         if self.is_dynamic:
             self.model = self._reshape(self.model, -1, -1)
@@ -356,19 +357,14 @@ def prepare_inputs(
         position_ids: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Dict:
-        if self.use_cache and past_key_values is not None:
-            input_ids = input_ids[:, -1:]
-
         batch_size = input_ids.shape[0]
         if self.config.model_type == "bloom":
             batch_size *= self.config.num_attention_heads
 
         inputs = {}
-        past_len = 0
         if not self.stateful:
             if past_key_values is not None:
                 if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
-                    past_len = past_key_values[0][1].shape[-2]
                     if self._pkv_precision == Type.bf16:
                         # numpy does not support bf16, pretending f16, should change to bf16
                         past_key_values = tuple(
@@ -381,8 +377,6 @@ def prepare_inputs(
                     past_key_values = tuple(
                         past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer
                     )
-                else:
-                    past_len = past_key_values[0].shape[-2]
 
                 # Add the past_key_values to the decoder inputs
                 inputs = dict(zip(self.key_value_input_names, past_key_values))
@@ -411,6 +405,8 @@ def prepare_inputs(
                     # Set initial value for the next beam_idx input that will be used at the current iteration
                     # and will be optionally updated by _reorder_cache at the next iterations if beam_search is used
                     self.next_beam_idx = np.arange(batch_size, dtype=int)
+                    self._past_length = 0
+        past_len = self._get_past_length(past_key_values)
 
         inputs["input_ids"] = np.array(input_ids)
         # Add the attention_mask inputs when needed
@@ -432,7 +428,7 @@ def prepare_inputs(
                 position_ids = np.cumsum(attention_mask, axis=1) - 1
                 position_ids[attention_mask == 0] = 1
                 if past_key_values:
-                    position_ids = np.expand_dims(position_ids[:, -1], axis=-1)
+                    position_ids = position_ids[:, -input_ids.shape[1] :]
 
             inputs["position_ids"] = position_ids
 
@@ -470,6 +466,7 @@ def forward(
             # the first condition at the function beginning above.
             # It should be something that is not None and it should be True when converted to Boolean.
             past_key_values = ((),)
+            self._past_length += input_ids.shape[1]
 
         if not self.stateful:
             if self.use_cache:
@@ -485,19 +482,32 @@
 
         return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
 
-    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
+    # Adapted from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
         attention_mask = kwargs.get("attention_mask", None)
         use_cache = kwargs.get("use_cache", None)
 
+        if past_key_values is not None:
+            past_len = self._get_past_length(past_key_values)
+            # Keep only the unprocessed tokens:
+            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
+            # input)
+            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+                input_ids = input_ids[:, -(attention_mask.shape[1] - past_len) :]
+            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+            # input_ids based on the past_length.
+            elif past_len < input_ids.shape[1]:
+                input_ids = input_ids[:, past_len:]
+            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens
         position_ids = kwargs.get("position_ids", None)
-        if attention_mask is not None and position_ids is None:
+        if attention_mask is not None and position_ids is None and "position_ids" in self.input_names:
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
+                position_ids = position_ids[:, -input_ids.shape[1] :]
 
         return {
             "input_ids": input_ids,
@@ -507,6 +517,24 @@ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwarg
             "attention_mask": attention_mask,
         }
 
+    def _get_past_length(self, past_key_values=None):
+        if past_key_values is None:
+            return 0
+        if self.stateful:
+            return self._past_length
+        if self.config.model_type in MULTI_QUERY_ATTN_MODELS:
+            return past_key_values[0].shape[-2]
+        seq_length_dim = -2
+        if self.config.model_type == "chatglm":
+            seq_length_dim = 0
+        elif self.config.model_type == "qwen":
+            seq_length_dim = 1
+        # input is tuple of pairs
+        if isinstance(past_key_values[0], (tuple, list)):
+            return past_key_values[0][1].shape[seq_length_dim]
+        # past key values comes after flattening
+        return past_key_values[1].shape[seq_length_dim]
+
     # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
     def _reorder_cache(
         self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
@@ -573,10 +601,6 @@ def _from_pretrained(
         model_type = config.model_type.replace("_", "-")
         if model_type == "bloom":
             init_cls = OVBloomForCausalLM
-        elif model_type == "mpt":
-            init_cls = OVMPTForCausalLM
-        elif model_type == "opt":
-            init_cls = OVOPTForCausalLM
         elif model_type == "gpt-bigcode":
             init_cls = OVGPTBigCodeForCausalLM
         else:
@@ -630,22 +654,12 @@
 class OVBloomForCausalLM(OVModelForCausalLM):
     # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM.prepare_inputs_for_generation
     def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
-        attention_mask = kwargs.get("attention_mask", None)
-        use_cache = kwargs.get("use_cache", None)
-
         # only last token for input_ids if past is not None
         if past_key_values and not self.stateful:
             # the cache may be in the stardard format (e.g. in contrastive search), convert to bloom's format if needed
             if past_key_values[0][0].shape[0] == input_ids.shape[0]:
                 past_key_values = self._convert_to_bloom_cache(past_key_values)
-
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": use_cache,
-            "position_ids": None,
-            "attention_mask": attention_mask,
-        }
+        return super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, **kwargs)
 
     # Adapted from transformers.models.bloom.modeling_bloom.BloomForCausalLM._reorder_cache
     def _reorder_cache(
@@ -712,36 +726,6 @@ def _convert_to_standard_cache(
         )
 
 
-class OVOPTForCausalLM(OVModelForCausalLM):
-    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
-        attention_mask = kwargs.get("attention_mask", None)
-        use_cache = kwargs.get("use_cache", None)
-
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": use_cache,
-            "position_ids": None,
-            "attention_mask": attention_mask,
-        }
-
-
-class OVMPTForCausalLM(OVModelForCausalLM):
-    # Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
-        attention_mask = kwargs.get("attention_mask", None)
-        use_cache = kwargs.get("use_cache", None)
-
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "use_cache": use_cache,
-            "position_ids": None,
-            "attention_mask": attention_mask,
-        }
-
-
 class OVGPTBigCodeForCausalLM(OVModelForCausalLM):
     # Adapted from transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTBigCodeForCausalLM._reorder_cache
     def _reorder_cache(
diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py
index 65094ae22..f54305113 100644
--- a/tests/openvino/test_modeling.py
+++ b/tests/openvino/test_modeling.py
@@ -632,6 +632,11 @@ def test_multiple_inputs(self, model_arch):
         outputs = model.generate(**tokens, generation_config=generation_config)
         self.assertIsInstance(outputs, torch.Tensor)
         self.assertEqual(outputs.shape[0], 3)
+        # test that generation result is reproducible
+        outputs2 = model.generate(**tokens, generation_config=generation_config)
+        self.assertIsInstance(outputs2, torch.Tensor)
+        self.assertEqual(outputs2.shape[0], 3)
+        self.assertTrue(torch.allclose(outputs2, outputs))
         del model
         gc.collect()
 