
Commit

Remove clone() to improve performance (#1290)
yeonsily committed Aug 23, 2024
1 parent 144569c commit e093881
Showing 1 changed file with 1 addition and 3 deletions.
4 changes: 1 addition & 3 deletions optimum/habana/transformers/generation/utils.py
@@ -2271,9 +2271,7 @@ def _sample(
             next_token_logits = torch.index_select(outputs.logits, -2, token_idx - 1).squeeze(-2)
             next_token_scores = logits_processor(input_ids, next_token_logits)
         else:
-            # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
-            # (the clone itself is always small)
-            next_token_logits = outputs.logits[:, -1, :].clone()
+            next_token_logits = outputs.logits[:, -1, :]
         if token_idx is not None and self.config.is_encoder_decoder:
             # case2 (with KV caching): outputs.logits.shape: [batch_size, 1, vocab_size]
             next_token_scores = logits_processor(input_ids[:, :token_idx], next_token_logits)
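Context on the trade-off this diff makes: the deleted comment explains that in PyTorch, `outputs.logits[:, -1, :]` is a view that shares storage with the full `[batch, seq_len, vocab]` logits tensor, so holding the slice keeps the whole tensor alive; `.clone()` materializes just the last-token slice and releases that reference, at the cost of a copy. The sketch below illustrates that retention pattern with stdlib Python only (no PyTorch); `BigBuffer`, `View`, and `copy()` are illustrative stand-ins, not anything from the repository.

```python
import gc
import weakref

class BigBuffer:
    """Stands in for the full logits tensor of the first forward pass."""
    def __init__(self, n):
        self.data = bytearray(n)

class View:
    """Stands in for a slice like logits[:, -1, :]: it references its base."""
    def __init__(self, base, start, stop):
        self.base = base  # this reference keeps the whole buffer alive
        self.start, self.stop = start, stop

    def copy(self):
        # analogous to .clone(): materialize only the slice, drop the base ref
        return bytes(self.base.data[self.start:self.stop])

buf = BigBuffer(1_000_000)
alive = weakref.ref(buf)

view = View(buf, 0, 10)   # like: next_token_logits = logits[:, -1, :]
del buf
gc.collect()
print(alive() is not None)  # True: the view pins the whole 1 MB buffer

small = view.copy()          # like: ... = logits[:, -1, :].clone()
del view
gc.collect()
print(alive() is None)       # True: buffer freed once only the copy remains
```

Removing the clone, as this commit does, avoids the copy on every decode step; the cost is that the sliced logits may pin the full first-iteration logits tensor, which the upstream comment flagged as potentially large.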
