diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py
index 149fc58b2d..a43ebf6375 100755
--- a/optimum/habana/transformers/generation/utils.py
+++ b/optimum/habana/transformers/generation/utils.py
@@ -1878,9 +1878,10 @@ def _greedy_search(
                 and not model_kwargs.get("reuse_cache", False)
                 and bucket_internal
             ):
-                # Pad the returned pask key values tensors from prefill phase forward run to maximum length
+                # Pad the returned past key values tensors from prefill phase forward run to maximum length
                 # before starting the decode phase.
-                self._pad_past_key_values(model_kwargs)
+                if outputs.past_key_values[0][0].shape[2] == model_inputs["input_ids"].shape[1]:
+                    self._pad_past_key_values(model_kwargs)
                 model_kwargs["pad_done"] = True

             if (
@@ -2297,9 +2298,10 @@ def _sample(
                 and not model_kwargs.get("reuse_cache", False)
                 and bucket_internal
             ):
-                # Pad the returned pask key values tensors from prefill phase forward run to maximum length
+                # Pad the returned past key values tensors from prefill phase forward run to maximum length
                 # before starting the decode phase.
-                self._pad_past_key_values(model_kwargs)
+                if outputs.past_key_values[0][0].shape[2] == model_inputs["input_ids"].shape[1]:
+                    self._pad_past_key_values(model_kwargs)
                 model_kwargs["pad_done"] = True

             if (
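
Note (not part of the patch): below is a minimal, self-contained sketch of the guard introduced above, assuming the usual (batch, num_heads, seq_len, head_dim) layout of Hugging Face key/value cache tensors. The helper name maybe_pad_past_key_values and its arguments are hypothetical stand-ins for the guarded call to self._pad_past_key_values in Optimum Habana; the point it illustrates is that padding to the bucket length is applied only when the cache returned by the prefill forward still has the prompt's sequence length, i.e. it has not been padded already.

import torch
import torch.nn.functional as F

def maybe_pad_past_key_values(past_key_values, input_ids, max_length, pad_value=0.0):
    """Hypothetical stand-in for the guarded self._pad_past_key_values call.

    Pads every (key, value) pair along the sequence axis (dim 2) up to
    max_length, but only when the cache length still equals the prompt
    length, i.e. the prefill output has not been padded yet.
    """
    kv_len = past_key_values[0][0].shape[2]   # current cache sequence length
    prompt_len = input_ids.shape[1]           # prefill prompt length
    if kv_len != prompt_len:
        # Cache was already padded (or produced by another code path): leave it alone.
        return past_key_values

    padded = []
    for key, value in past_key_values:
        # F.pad pads from the last dim backwards: (0, 0) for head_dim,
        # (0, max_length - kv_len) on the right of the sequence dim.
        pad = (0, 0, 0, max_length - kv_len)
        padded.append((F.pad(key, pad, value=pad_value),
                       F.pad(value, pad, value=pad_value)))
    return tuple(padded)

# Example: a 5-token prompt padded to a bucket of 16; calling the helper again
# is a no-op because the cache length (16) no longer matches the prompt length (5).
pkv = tuple((torch.randn(1, 2, 5, 4), torch.randn(1, 2, 5, 4)) for _ in range(3))
ids = torch.ones(1, 5, dtype=torch.long)
pkv = maybe_pad_past_key_values(pkv, ids, max_length=16)
print(pkv[0][0].shape)  # torch.Size([1, 2, 16, 4])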