pick removeslice on top of prepinp change
ssarkar2 committed Aug 29, 2024
1 parent 19b4d4e commit ff91438
Showing 2 changed files with 20 additions and 5 deletions.
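What the diff does, read end to end: building on the earlier "prepinp" change, generate() now keeps unsliced copies of the prompt tensors (inputs_tensor_ORIG and attention_mask_ORIG), threads them through _sample() into prep_inp(), and the Llama input-preparation path uses them on the first token instead of slicing to token_idx. A minimal sketch of the generate-side capture, assuming only the names visible in the diff (the toy tensors are illustrative, not repository code):

import torch

def capture_originals(inputs_tensor, model_kwargs):
    # Mirror of what generate() now does: keep unsliced copies of the prompt
    # tensors before the decoding loop starts mutating model_kwargs.
    inputs_tensor_orig = inputs_tensor
    attention_mask_orig = model_kwargs["attention_mask"]
    return inputs_tensor_orig, attention_mask_orig

# Toy prompt: batch of 2, padded to length 6.
input_ids = torch.ones(2, 6, dtype=torch.long)
model_kwargs = {"attention_mask": torch.ones(2, 6, dtype=torch.long)}
orig_ids, orig_mask = capture_originals(input_ids, model_kwargs)
# generate() then hands both copies to _sample(), which forwards them to
# prep_inp() on every decoding step.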
11 changes: 10 additions & 1 deletion optimum/habana/transformers/generation/utils.py
@@ -809,6 +809,7 @@ def generate(
else:
synced_gpus = False


# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
self._validate_model_class()
tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
@@ -839,6 +840,8 @@ def generate(
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
inputs, generation_config.bos_token_id, model_kwargs
)
inputs_tensor_ORIG = inputs_tensor
attention_mask_ORIG = model_kwargs["attention_mask"]
batch_size = inputs_tensor.shape[0]

device = inputs_tensor.device
@@ -972,6 +975,7 @@ def generate(
else:
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")


if generation_config.token_healing:
input_ids = self.heal_tokens(input_ids, tokenizer)

@@ -1307,6 +1311,8 @@ def generate(
profiling_steps=profiling_steps,
hb_gen_time=hb_gen_time,
profiling_record_shapes=profiling_record_shapes,
inputs_tensor_ORIG=inputs_tensor_ORIG,
attention_mask_ORIG=attention_mask_ORIG,
**model_kwargs,
)

@@ -2123,6 +2129,8 @@ def _sample(
profiling_steps: Optional[int] = 0,
hb_gen_time: Optional[HabanaGenerationtime] = None,
profiling_record_shapes: Optional[bool] = False,
inputs_tensor_ORIG=None,
attention_mask_ORIG=None,
**model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
r"""
@@ -2248,7 +2256,8 @@ def _sample(

# prepare model inputs
#model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
model_inputs = self.prep_inp(input_ids, **model_kwargs)
model_inputs = self.prep_inp(input_ids, inputs_tensor_ORIG, attention_mask_ORIG, **model_kwargs)
# model_inputs = self.prepare_inputs_for_generation(input_ids, inputs_tensor_ORIG, attention_mask_ORIG, **model_kwargs)

# prepare variable output controls (note: some models won't accept all output controls)
model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
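The utils.py side only captures and forwards the originals; the behavioural change lands in the Llama file below. For context, a toy illustration of what the now-bypassed slicing did: with a pre-allocated KV cache (reuse_cache) or internal bucketing, token_idx marks the real prompt length inside a padded buffer, and the old prefill path cut input_ids and attention_mask down to that length. A hedged sketch with made-up shapes, not repository code:

import torch

token_idx = torch.tensor(4)                  # real prompt length inside the padded buffer
input_ids = torch.arange(12).reshape(2, 6)   # batch=2, padded/bucketed length=6
attention_mask = (input_ids % 6 < 4).long()  # first 4 positions are real tokens

sliced_ids = input_ids[:, :token_idx]        # old prefill input
sliced_mask = attention_mask[:, :token_idx]
print(sliced_ids.shape, input_ids.shape)     # torch.Size([2, 4]) torch.Size([2, 6])
# After this commit the first forward pass sees the full (2, 6) tensors instead
# of the (2, 4) slices.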
14 changes: 10 additions & 4 deletions optimum/habana/transformers/models/llama/modeling_llama.py
@@ -963,6 +963,8 @@ def __init__(self):
def forward(
self,
input_ids,
inputs_tensor_ORIG,
attention_mask_ORIG,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
@@ -987,8 +989,8 @@ def forward(
elif (reuse_cache or bucket_internal) and token_idx is not None:
# KV cache is pre allocated with reuse cache or will be padded with bucket internal
# hence for the 1st token we can slice the inputs till token idx for the fwd pass.
input_ids = input_ids[:, :token_idx]
attention_mask = attention_mask[:, :token_idx]
input_ids = inputs_tensor_ORIG  # input_ids[:, :token_idx]
attention_mask = attention_mask_ORIG  # attention_mask[:, :token_idx]

if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
@@ -1412,6 +1414,8 @@ def forward(
def prepare_inputs_for_generation(
self,
input_ids,
inputs_tensor_ORIG,
attention_mask_ORIG,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
@@ -1423,6 +1427,7 @@ def prepare_inputs_for_generation(
):
reuse_cache = kwargs.get("reuse_cache")
bucket_internal = kwargs.get("bucket_internal")
#breakpoint()
if past_key_values is not None:
if token_idx is not None:
input_ids = torch.index_select(input_ids, 1, token_idx - 1)
@@ -1436,8 +1441,9 @@ def prepare_inputs_for_generation(
elif (reuse_cache or bucket_internal) and token_idx is not None:
# KV cache is pre allocated with reuse cache or will be padded with bucket internal
# hence for the 1st token we can slice the inputs till token idx for the fwd pass.
input_ids = input_ids[:, :token_idx]
attention_mask = attention_mask[:, :token_idx]
#breakpoint()
input_ids = inputs_tensor_ORIG  # input_ids[:, :token_idx]
attention_mask = attention_mask_ORIG  # attention_mask[:, :token_idx]

if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
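A condensed, hedged reading of the changed branch in prepare_inputs_for_generation (and the parallel forward above): when reuse_cache or bucket_internal is set and token_idx is available, the prefill inputs are now the original tensors captured in generate() rather than slices up to token_idx. The helper below is illustrative only and omits the rest of the method:

import torch

def prefill_inputs(input_ids, attention_mask, token_idx,
                   inputs_tensor_orig, attention_mask_orig,
                   past_key_values=None, reuse_cache=True, bucket_internal=False):
    if past_key_values is None and (reuse_cache or bucket_internal) and token_idx is not None:
        # Old behaviour: slice the padded prompt down to the real length.
        #   input_ids = input_ids[:, :token_idx]
        #   attention_mask = attention_mask[:, :token_idx]
        # This commit: reuse the unsliced originals captured in generate().
        input_ids = inputs_tensor_orig
        attention_mask = attention_mask_orig
    return input_ids, attention_mask

ids = torch.ones(2, 6, dtype=torch.long)
mask = torch.ones(2, 6, dtype=torch.long)
out_ids, out_mask = prefill_inputs(ids, mask, torch.tensor(4), ids, mask)
print(out_ids.shape)  # torch.Size([2, 6]), unsliced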
