Skip to content

Commit

Permalink
Add missing condtion check in tensor creation in greedy search (#1289)
Browse files Browse the repository at this point in the history
  • Loading branch information
yeonsily authored Aug 23, 2024
1 parent dfeb8a2 commit 1548fe3
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion optimum/habana/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2195,7 +2195,8 @@ def _sample(
# keep track of which sequences are already finished
batch_size, cur_len = input_ids.shape
this_peer_finished = False
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
if not ignore_eos:
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

bucket_size = model_kwargs.get("bucket_size", -1)
Expand Down

0 comments on commit 1548fe3

Please sign in to comment.