From 5b1eb7ddb36f508dfda7d07c6408c3714a6bb3f6 Mon Sep 17 00:00:00 2001
From: Sayantan Sarkar
Date: Wed, 28 Aug 2024 17:05:45 -0700
Subject: [PATCH] valid seq

---
 examples/text-generation/run_generation.py  |  6 ++
 examples/text-generation/utils.py           |  1 +
 .../habana/transformers/generation/utils.py |  1 +
 .../models/llama/modeling_llama.py          | 81 +++++++++++++++----
 4 files changed, 72 insertions(+), 17 deletions(-)

diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py
index 0a16543c2a..75f050b273 100755
--- a/examples/text-generation/run_generation.py
+++ b/examples/text-generation/run_generation.py
@@ -417,6 +417,12 @@ def generate(size=None, reduce_recompile=False):
                     max_length=args.max_input_tokens,
                     truncation=True,
                 )
+                def compute_valid_sequence_lengths_tensor(input_tokens):
+                    attn_mask = input_tokens["attention_mask"]
+                    return torch.sum(attn_mask, dim=1)
+
+                valid_sequence_lengths = compute_valid_sequence_lengths_tensor(input_tokens).to(args.device)
+                generation_config.valid_sequence_lengths = valid_sequence_lengths
             else:
                 input_tokens = tokenizer.batch_encode_plus(input_sentences, return_tensors="pt", padding=True)
             encode_duration = time.perf_counter() - encode_t0
diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py
index c535acba0a..fa64fcfec6 100644
--- a/examples/text-generation/utils.py
+++ b/examples/text-generation/utils.py
@@ -577,6 +577,7 @@ def setup_generation_config(args, model, assistant_model, tokenizer):
     generation_config.flash_attention_causal_mask = args.flash_attention_causal_mask
     generation_config.flash_attention_fast_softmax = args.flash_attention_fast_softmax
     generation_config.trust_remote_code = args.trust_remote_code
+    generation_config.valid_sequence_lengths = None
     return generation_config
 
 
diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py
index e3da8547c6..82ac689003 100755
--- a/optimum/habana/transformers/generation/utils.py
+++ b/optimum/habana/transformers/generation/utils.py
@@ -1082,6 +1082,7 @@ def generate(
                 True if generation_config.flash_attention_fast_softmax else False
             )
             model_kwargs["num_virtual_tokens"] = num_virtual_tokens
+            model_kwargs["valid_sequence_lengths"] = generation_config.valid_sequence_lengths
 
             if not self.config.is_encoder_decoder:
                 calculated_max_length = input_ids.shape[-1] + num_virtual_tokens
diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py
index 7d41126390..f7e66a38c9 100755
--- a/optimum/habana/transformers/models/llama/modeling_llama.py
+++ b/optimum/habana/transformers/models/llama/modeling_llama.py
@@ -347,12 +347,17 @@ def gaudi_llama_repeat_kv(
 
 # FusedScaledDotProductAttention
 class ModuleFusedSDPA(torch.nn.Module):
-    def __init__(self, fusedSDPA):
+    def __init__(self, fusedSDPA, scale, attention_dropout, enable_recompute, flash_attention_fp8):
         super().__init__()
         self._hpu_kernel_fsdpa = fusedSDPA
+        self.scale = scale
+        self.attention_dropout = attention_dropout
+        self.enable_recompute = enable_recompute
+        self.flash_attention_fp8 = flash_attention_fp8
 
-    def forward(self, query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode):
-        return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode)
+    def forward(self, query, key, value, attn_mask, dropout_p, is_causal, scale, x):
+        softmax_mode, recompute_mode, valid_sequence_lengths, padding_side = x
+        return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode, recompute_mode, valid_sequence_lengths, padding_side)
 
 
 class Matmul(torch.nn.Module):
@@ -412,7 +417,11 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
         self.matmul_av = Matmul()
         self.k_cache = KVCache()
         self.v_cache = KVCache()
-        self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None
+        self.norm_factor = 1.0 / math.sqrt(self.head_dim)
+        self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA, scale=self.norm_factor,
+                                                                  attention_dropout=self.attention_dropout,
+                                                                  enable_recompute=False,
+                                                                  flash_attention_fp8=True,) if FusedSDPA else None
         if hasattr(config, "fused_qkv") and config.fused_qkv:
             self.num_heads = config.num_attention_heads
             self.head_dim = config.hidden_size // self.num_heads
@@ -488,6 +497,7 @@ def pre_attn_forward(
         flash_attention_recompute: Optional[bool] = False,
         flash_attention_causal_mask: Optional[bool] = False,
         flash_attention_fast_softmax: Optional[bool] = False,
+        valid_sequence_lengths: Optional[torch.Tensor] = None,
         cache_idx: int = None,
         num_virtual_tokens: int = None,
         **kwargs,
@@ -624,24 +634,51 @@ def pre_attn_forward(
 
             if q_len == 1:
                 # next token
-                use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
-                with ht.sdp_kernel(enable_recompute=use_recompute):
-                    attn_output = self.fused_scaled_dot_product_attention(
-                        query_states, key_states, value_states, attention_mask, 0.0, False, None, "None"
-                    )
+                attn_output = self.fused_scaled_dot_product_attention(
+                    query_states,
+                    key_states,
+                    value_states,
+                    attention_mask,
+                    0.0,
+                    False,
+                    None,
+                    ("None",
+                     False,
+                     None,
+                     "None",)
+                )
             else:
+                softmax_mode = "fast" if flash_attention_fast_softmax else "None"
                 # first token
                 if flash_attention_causal_mask:
                     # causal masking on first token requires inputs to be of the same length
-                    with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
-                        attn_output = self.fused_scaled_dot_product_attention(
-                            query_states, key_states, value_states, None, 0.0, True, None, softmax_mode
-                        )
+                    attn_output = self.fused_scaled_dot_product_attention(
+                        query_states,
+                        key_states,
+                        value_states,
+                        None,
+                        0.0,
+                        True,
+                        None,
+                        (softmax_mode,
+                         flash_attention_recompute,
+                         valid_sequence_lengths,
+                         "left",)
+                    )
                 else:
-                    with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
-                        attn_output = self.fused_scaled_dot_product_attention(
-                            query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode
-                        )
+                    attn_output = self.fused_scaled_dot_product_attention(
+                        query_states,
+                        key_states,
+                        value_states,
+                        attention_mask,
+                        0.0,
+                        False,
+                        None,
+                        (softmax_mode,
+                         flash_attention_recompute,
+                         None,
+                         "None",)
+                    )
 
         else:
             query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv(
@@ -837,6 +874,7 @@ def forward(
         flash_attention_recompute: Optional[bool] = False,
         flash_attention_causal_mask: Optional[bool] = False,
         flash_attention_fast_softmax: Optional[bool] = False,
+        valid_sequence_lengths: Optional[torch.Tensor] = None,
         cache_idx: int = None,
         num_virtual_tokens: int = None,
         **kwargs,
@@ -870,6 +908,7 @@ def forward(
             flash_attention_recompute=flash_attention_recompute,
             flash_attention_causal_mask=flash_attention_causal_mask,
             flash_attention_fast_softmax=flash_attention_fast_softmax,
+            valid_sequence_lengths=valid_sequence_lengths,
             cache_idx=cache_idx,
             num_virtual_tokens=num_virtual_tokens,
             **kwargs,
@@ -905,6 +944,7 @@ def pre_attn(
         flash_attention_recompute: Optional[bool] = False,
         flash_attention_causal_mask: Optional[bool] = False,
         flash_attention_fast_softmax: Optional[bool] = False,
+        valid_sequence_lengths: Optional[torch.Tensor] = None,
         cache_idx: int = None,
         num_virtual_tokens: int = None,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
@@ -925,6 +965,7 @@ def pre_attn(
             flash_attention_recompute=flash_attention_recompute,
             flash_attention_causal_mask=flash_attention_causal_mask,
             flash_attention_fast_softmax=flash_attention_fast_softmax,
+            valid_sequence_lengths=valid_sequence_lengths,
             cache_idx=cache_idx,
             num_virtual_tokens=num_virtual_tokens,
         )
@@ -1018,6 +1059,7 @@ def forward(
         flash_attention_recompute: Optional[bool] = False,
         flash_attention_causal_mask: Optional[bool] = False,
         flash_attention_fast_softmax: Optional[bool] = False,
+        valid_sequence_lengths: Optional[torch.Tensor] = None,
         cache_idx: int = None,
         lazy_mode: Optional[bool] = True,
         num_virtual_tokens: int = None,
@@ -1157,6 +1199,7 @@ def forward(
                     flash_attention_recompute,
                     flash_attention_causal_mask,
                     flash_attention_fast_softmax,
+                    valid_sequence_lengths,
                     None,
                 )
             else:
@@ -1176,6 +1219,7 @@ def forward(
                     flash_attention_recompute=flash_attention_recompute,
                     flash_attention_causal_mask=flash_attention_causal_mask,
                     flash_attention_fast_softmax=flash_attention_fast_softmax,
+                    valid_sequence_lengths=valid_sequence_lengths,
                     cache_idx=cache_idx,
                     num_virtual_tokens=num_virtual_tokens,
                 )
@@ -1253,6 +1297,7 @@ def forward(
         flash_attention_recompute: Optional[bool] = False,
         flash_attention_causal_mask: Optional[bool] = False,
         flash_attention_fast_softmax: Optional[bool] = False,
+        valid_sequence_lengths: Optional[torch.Tensor] = None,
         cache_idx: int = None,
         lazy_mode: Optional[bool] = True,
         num_virtual_tokens: int = None,
@@ -1285,6 +1330,7 @@ def forward(
             flash_attention_recompute=flash_attention_recompute,
             flash_attention_causal_mask=flash_attention_causal_mask,
             flash_attention_fast_softmax=flash_attention_fast_softmax,
+            valid_sequence_lengths=valid_sequence_lengths,
             cache_idx=cache_idx,
             lazy_mode=lazy_mode,
             num_virtual_tokens=num_virtual_tokens,
@@ -1394,6 +1440,7 @@ def prepare_inputs_for_generation(
             "flash_attention_recompute": kwargs.get("flash_attention_recompute"),
             "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"),
             "flash_attention_fast_softmax": kwargs.get("flash_attention_fast_softmax"),
+            "valid_sequence_lengths": kwargs.get("valid_sequence_lengths"),
             "cache_idx": kwargs.get("cache_idx"),
             "lazy_mode": kwargs.get("lazy_mode"),
             "num_virtual_tokens": kwargs.get("num_virtual_tokens"),
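
Usage sketch (illustrative, not part of the diff): the new plumbing expects the caller to attach per-sample valid lengths to the generation config, computed as the row sum of the attention mask over a left-padded batch, exactly as compute_valid_sequence_lengths_tensor does in run_generation.py above. The checkpoint id and generation settings below are placeholders, and the snippet stops short of calling generate() since that requires a Gaudi device.

import torch
from transformers import AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder checkpoint
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # the causal first-token branch above passes padding_side="left"

batch = tokenizer(
    ["Hello world", "A much longer prompt that forces padding of the shorter one"],
    return_tensors="pt",
    padding="max_length",
    max_length=32,
    truncation=True,
)

# Count non-pad tokens per row, mirroring compute_valid_sequence_lengths_tensor.
valid_sequence_lengths = torch.sum(batch["attention_mask"], dim=1)

generation_config = GenerationConfig(max_new_tokens=32, use_cache=True)
generation_config.valid_sequence_lengths = valid_sequence_lengths
# model.generate(**batch, generation_config=generation_config) then carries the tensor
# through model_kwargs["valid_sequence_lengths"] into ModuleFusedSDPA.forward.

Only the causal first-token path consumes the tensor; the single-token decode path and the non-causal first-token path pass None, as shown in the modeling_llama.py hunks above.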