From ea553dca5c6e8d05861f0de03b050013d9e88a95 Mon Sep 17 00:00:00 2001
From: Sayantan Sarkar
Date: Tue, 27 Aug 2024 18:24:45 +0000
Subject: [PATCH] Pass valid_sequence_lengths through to the Llama fused SDPA call

---
 examples/text-generation/utils.py               |  1 +
 .../habana/transformers/generation/utils.py     |  1 +
 .../models/llama/modeling_llama.py              | 37 ++++++++++++++-----
 3 files changed, 29 insertions(+), 10 deletions(-)

diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py
index c535acba0a..fa64fcfec6 100644
--- a/examples/text-generation/utils.py
+++ b/examples/text-generation/utils.py
@@ -577,6 +577,7 @@ def setup_generation_config(args, model, assistant_model, tokenizer):
     generation_config.flash_attention_causal_mask = args.flash_attention_causal_mask
     generation_config.flash_attention_fast_softmax = args.flash_attention_fast_softmax
     generation_config.trust_remote_code = args.trust_remote_code
+    generation_config.valid_sequence_lengths = None
 
     return generation_config
 
diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py
index e3da8547c6..82ac689003 100755
--- a/optimum/habana/transformers/generation/utils.py
+++ b/optimum/habana/transformers/generation/utils.py
@@ -1082,6 +1082,7 @@ def generate(
             True if generation_config.flash_attention_fast_softmax else False
         )
         model_kwargs["num_virtual_tokens"] = num_virtual_tokens
+        model_kwargs["valid_sequence_lengths"] = generation_config.valid_sequence_lengths
 
         if not self.config.is_encoder_decoder:
             calculated_max_length = input_ids.shape[-1] + num_virtual_tokens
diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py
index 7d41126390..60c19a9b90 100755
--- a/optimum/habana/transformers/models/llama/modeling_llama.py
+++ b/optimum/habana/transformers/models/llama/modeling_llama.py
@@ -351,8 +351,12 @@ def __init__(self, fusedSDPA):
         super().__init__()
         self._hpu_kernel_fsdpa = fusedSDPA
 
-    def forward(self, query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode):
-        return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode)
+    def forward(
+        self, query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode, recompute_mode, valid_sequence_lengths, padding_side="left"
+    ):
+        return self._hpu_kernel_fsdpa.apply(
+            query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode, recompute_mode, valid_sequence_lengths, padding_side
+        )
 
 
 class Matmul(torch.nn.Module):
@@ -488,6 +492,7 @@ def pre_attn_forward(
         flash_attention_recompute: Optional[bool] = False,
         flash_attention_causal_mask: Optional[bool] = False,
         flash_attention_fast_softmax: Optional[bool] = False,
+        valid_sequence_lengths: torch.Tensor = None,
         cache_idx: int = None,
         num_virtual_tokens: int = None,
         **kwargs,
@@ -621,27 +626,28 @@ def pre_attn_forward(
             import habana_frameworks.torch.hpu as ht
 
             softmax_mode = "fast" if flash_attention_fast_softmax else "None"
 
             if q_len == 1:
                 # next token
                 use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
                 with ht.sdp_kernel(enable_recompute=use_recompute):
                     attn_output = self.fused_scaled_dot_product_attention(
-                        query_states, key_states, value_states, attention_mask, 0.0, False, None, "None"
+                        query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode, False, valid_sequence_lengths, "None"
                     )
             else:
                 # first token
                 if flash_attention_causal_mask:
                     # causal masking on first token requires inputs to be of the same length
                     with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
                         attn_output = self.fused_scaled_dot_product_attention(
-                            query_states, key_states, value_states, None, 0.0, True, None, softmax_mode
+                            query_states, key_states, value_states, None, 0.0, True, None, softmax_mode, flash_attention_recompute, valid_sequence_lengths, "left"
                         )
                 else:
                     with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
                         attn_output = self.fused_scaled_dot_product_attention(
-                            query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode
+                            query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode, flash_attention_recompute, valid_sequence_lengths, "left"
                         )
 
         else:
             query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv(
@@ -772,6 +778,7 @@ def pre_attn_forward(
         flash_attention_recompute: Optional[bool] = False,
         flash_attention_causal_mask: Optional[bool] = False,
         flash_attention_fast_softmax: Optional[bool] = False,
+        valid_sequence_lengths: torch.Tensor = None,
         cache_idx: int = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
@@ -792,6 +799,7 @@ def pre_attn_forward(
             flash_attention_recompute,
             flash_attention_causal_mask,
             flash_attention_fast_softmax,
+            valid_sequence_lengths,
             cache_idx,
             **kwargs,
         )
@@ -837,6 +845,7 @@ def forward(
         flash_attention_recompute: Optional[bool] = False,
         flash_attention_causal_mask: Optional[bool] = False,
         flash_attention_fast_softmax: Optional[bool] = False,
+        valid_sequence_lengths: torch.Tensor = None,
         cache_idx: int = None,
         num_virtual_tokens: int = None,
         **kwargs,
@@ -870,6 +879,7 @@ def forward(
             flash_attention_recompute=flash_attention_recompute,
             flash_attention_causal_mask=flash_attention_causal_mask,
             flash_attention_fast_softmax=flash_attention_fast_softmax,
+            valid_sequence_lengths=valid_sequence_lengths,
             cache_idx=cache_idx,
             num_virtual_tokens=num_virtual_tokens,
             **kwargs,
@@ -905,6 +915,7 @@ def pre_attn(
         flash_attention_recompute: Optional[bool] = False,
         flash_attention_causal_mask: Optional[bool] = False,
         flash_attention_fast_softmax: Optional[bool] = False,
+        valid_sequence_lengths: torch.Tensor = None,
         cache_idx: int = None,
         num_virtual_tokens: int = None,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
@@ -925,6 +936,7 @@ def pre_attn(
            flash_attention_recompute=flash_attention_recompute,
            flash_attention_causal_mask=flash_attention_causal_mask,
            flash_attention_fast_softmax=flash_attention_fast_softmax,
+           valid_sequence_lengths=valid_sequence_lengths,
            cache_idx=cache_idx,
            num_virtual_tokens=num_virtual_tokens,
        )
@@ -1018,6 +1030,7 @@ def forward(
         flash_attention_recompute: Optional[bool] = False,
         flash_attention_causal_mask: Optional[bool] = False,
         flash_attention_fast_softmax: Optional[bool] = False,
+        valid_sequence_lengths: torch.Tensor = None,
         cache_idx: int = None,
         lazy_mode: Optional[bool] = True,
         num_virtual_tokens: int = None,
@@ -1176,6 +1189,7 @@ def forward(
                     flash_attention_recompute=flash_attention_recompute,
                     flash_attention_causal_mask=flash_attention_causal_mask,
                     flash_attention_fast_softmax=flash_attention_fast_softmax,
+                    valid_sequence_lengths=valid_sequence_lengths,
                     cache_idx=cache_idx,
                     num_virtual_tokens=num_virtual_tokens,
                 )
@@ -1253,6 +1267,7 @@ def forward(
         flash_attention_recompute: Optional[bool] = False,
         flash_attention_causal_mask: Optional[bool] = False,
         flash_attention_fast_softmax: Optional[bool] = False,
+        valid_sequence_lengths: torch.Tensor = None,
         cache_idx: int = None,
         lazy_mode: Optional[bool] = True,
         num_virtual_tokens: int = None,
@@ -1285,6 +1300,7 @@ def forward(
             flash_attention_recompute=flash_attention_recompute,
             flash_attention_causal_mask=flash_attention_causal_mask,
             flash_attention_fast_softmax=flash_attention_fast_softmax,
+            valid_sequence_lengths=valid_sequence_lengths,
             cache_idx=cache_idx,
             lazy_mode=lazy_mode,
             num_virtual_tokens=num_virtual_tokens,
@@ -1394,6 +1410,7 @@ def prepare_inputs_for_generation(
                 "flash_attention_recompute": kwargs.get("flash_attention_recompute"),
                 "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"),
                 "flash_attention_fast_softmax": kwargs.get("flash_attention_fast_softmax"),
+                "valid_sequence_lengths": kwargs.get("valid_sequence_lengths"),
                 "cache_idx": kwargs.get("cache_idx"),
                 "lazy_mode": kwargs.get("lazy_mode"),
                 "num_virtual_tokens": kwargs.get("num_virtual_tokens"),
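
Usage note (not part of the patch): the diff only plumbs valid_sequence_lengths from the generation config down to the HPU fused SDPA kernel; it never computes the value. Below is a minimal sketch of how a caller might derive it from the tokenizer's attention mask before generate(). The checkpoint name, the int32 dtype, and the max_new_tokens value are illustrative assumptions, not requirements stated by the patch; the sketch also assumes the patched optimum-habana stack and the usual Gaudi setup (e.g. adapt_transformers_to_gaudi()) have already been applied.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model_name = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name)

batch = tokenizer(
    ["Hello world", "A much longer prompt that forces padding of the shorter one"],
    return_tensors="pt",
    padding=True,
)

# One entry per sequence: the number of real (non-padding) tokens.
# int32 is an assumption about what the HPU fused SDPA kernel expects.
valid_sequence_lengths = batch["attention_mask"].sum(dim=-1).to(torch.int32)

generation_config = GenerationConfig.from_pretrained(model_name)
# Read by the patched GaudiGenerationMixin.generate() and forwarded through
# model_kwargs["valid_sequence_lengths"] to GaudiLlamaAttention.pre_attn_forward().
generation_config.valid_sequence_lengths = valid_sequence_lengths

outputs = model.generate(**batch, generation_config=generation_config, max_new_tokens=16)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

The left padding_side chosen here mirrors the "left" literal the patch passes as padding_side in the first-token fused SDPA calls.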