diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index ff94cce66b..e3da8547c6 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -2271,9 +2271,7 @@ def _sample( next_token_logits = torch.index_select(outputs.logits, -2, token_idx - 1).squeeze(-2) next_token_scores = logits_processor(input_ids, next_token_logits) else: - # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration - # (the clone itself is always small) - next_token_logits = outputs.logits[:, -1, :].clone() + next_token_logits = outputs.logits[:, -1, :] if token_idx is not None and self.config.is_encoder_decoder: # case2 (with KV caching): outputs.logits.shape: [batch_size, 1, vocab_size] next_token_scores = logits_processor(input_ids[:, :token_idx], next_token_logits)