
Commit

Fix facebook/hf-seamless-m4t-medium crash (#1433)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
sywangyi authored Oct 30, 2024
1 parent ea92e8b commit 309e0c4
Showing 2 changed files with 14 additions and 4 deletions.
1 change: 1 addition & 0 deletions examples/text-to-speech/README.md
@@ -36,4 +36,5 @@ python3 run_pipeline.py \
 ```
 Models that have been validated:
 - [microsoft/speecht5_tts](https://huggingface.co/microsoft/speecht5_tts)
+- [facebook/hf-seamless-m4t-medium](https://huggingface.co/facebook/hf-seamless-m4t-medium)
 - [facebook/mms-tts-eng](https://huggingface.co/facebook/mms-tts-eng)
17 changes: 13 additions & 4 deletions optimum/habana/transformers/generation/utils.py
@@ -237,14 +237,20 @@ def _prepare_decoder_input_ids_for_generation(
             if token_idx is None:
                 decoder_input_ids = torch.cat([decoder_start_token_id, decoder_input_ids], dim=-1)
             else:
-                max_length = max_new_tokens + 2 if max_new_tokens is not None else self.generation_config.max_length
+                decoder_input_ids_len = decoder_input_ids.shape[-1]
+                max_length = (
+                    max_new_tokens + decoder_input_ids_len + 1
+                    if max_new_tokens is not None
+                    else self.generation_config.max_length
+                )
                 if max_length != decoder_start_token_id.shape[-1]:
                     decoder_start_token_id = torch.nn.functional.pad(
                         decoder_start_token_id,
                         (0, max_length - decoder_start_token_id.shape[-1]),
                         value=pad_token_id,
                     )
-                decoder_input_ids = decoder_start_token_id.index_copy(1, token_idx, decoder_input_ids)
+                decoder_start_token_id[:, 1 : 1 + decoder_input_ids_len, ...] = decoder_input_ids
+                decoder_input_ids = decoder_start_token_id
                 token_idx.add_(1)
         if "decoder_attention_mask" in model_kwargs:
             decoder_attention_mask = model_kwargs["decoder_attention_mask"]
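This hunk is the core of the fix: on the `token_idx` (static-shapes) path, the decoder prompt is written into a pre-padded buffer. The old `index_copy` at a single `token_idx` only works when `decoder_input_ids` holds exactly one token; facebook/hf-seamless-m4t-medium supplies a multi-token decoder prompt, so the call crashed. The new code sizes the buffer for the whole prompt and copies it in with a slice assignment. A minimal standalone sketch of that pattern, with made-up tensor values rather than the library's own objects:

```python
import torch

# Hypothetical values; hf-seamless-m4t-medium passes a multi-token decoder prompt.
pad_token_id = 0
decoder_start_token_id = torch.tensor([[3]])        # shape (batch=1, 1)
decoder_input_ids = torch.tensor([[3, 256047, 7]])  # 3-token decoder prompt
max_new_tokens = 5

decoder_input_ids_len = decoder_input_ids.shape[-1]
# Static buffer length: start-token slot + prompt tokens + new tokens.
max_length = max_new_tokens + decoder_input_ids_len + 1

# Pad the start-token tensor out to the full static length.
buffer = torch.nn.functional.pad(
    decoder_start_token_id,
    (0, max_length - decoder_start_token_id.shape[-1]),
    value=pad_token_id,
)

# Copy the whole prompt in after the start token. The old
# `index_copy(1, token_idx, decoder_input_ids)` expected the prompt and the
# index to have matching sizes, so a prompt longer than one token raised a
# shape error.
buffer[:, 1 : 1 + decoder_input_ids_len] = decoder_input_ids

print(buffer)  # tensor([[3, 3, 256047, 7, 0, 0, 0, 0, 0]])
```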
@@ -1103,11 +1109,14 @@ def generate(
                 )
             else:
                 assert generation_config.bucket_size <= 0, "Untested path for bucket>0"
-                token_idx = 1
+                if model_kwargs.get("decoder_input_ids", None) is None:
+                    token_idx = 1
+                else:
+                    token_idx = model_kwargs["decoder_input_ids"].shape[-1]
                 model_kwargs["token_idx"] = torch.tensor(token_idx, device=inputs_tensor.device)
             if model_kwargs.get("decoder_attention_mask", None) is None and generation_config.use_cache:
                 max_length = (
-                    generation_config.max_new_tokens + 1
+                    generation_config.max_new_tokens + token_idx
                     if generation_config.max_new_tokens is not None
                     else generation_config.max_length
                 )
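The second hunk keeps the bookkeeping in `generate()` consistent with that change: `token_idx` now starts at the length of any user-supplied `decoder_input_ids` instead of a hard-coded 1, and the static length reserved for `decoder_attention_mask` grows by the same amount. A small sketch with assumed values (the mask construction here is illustrative only, not the call the library makes downstream):

```python
import torch

max_new_tokens = 5
# Hypothetical multi-token decoder prompt, as seamless-m4t provides.
decoder_input_ids = torch.tensor([[3, 256047, 7]])

if decoder_input_ids is None:
    token_idx = 1
else:
    token_idx = decoder_input_ids.shape[-1]  # 3; was hard-coded to 1 before the fix

# Static mask length now covers the prompt already present plus the new tokens.
max_length = max_new_tokens + token_idx  # 8; previously max_new_tokens + 1

# Illustrative mask: ones over the prompt, zeros over slots not yet generated.
decoder_attention_mask = torch.zeros(1, max_length, dtype=torch.long)
decoder_attention_mask[:, :token_idx] = 1
print(token_idx, max_length, decoder_attention_mask)
```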
