Commit

valid seq
ssarkar2 committed Aug 29, 2024
1 parent 52e22cb commit 5b1eb7d
Showing 4 changed files with 72 additions and 17 deletions.
6 changes: 6 additions & 0 deletions examples/text-generation/run_generation.py
@@ -417,6 +417,12 @@ def generate(size=None, reduce_recompile=False):
max_length=args.max_input_tokens,
truncation=True,
)
def compute_valid_sequence_lengths_tensor(input_tokens):
attn_mask = input_tokens["attention_mask"]
return torch.sum(attn_mask, dim=1)

valid_sequence_lengths = compute_valid_sequence_lengths_tensor(input_tokens).to(args.device)
generation_config.valid_sequence_lengths = valid_sequence_lengths
else:
input_tokens = tokenizer.batch_encode_plus(input_sentences, return_tensors="pt", padding=True)
encode_duration = time.perf_counter() - encode_t0
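As a minimal standalone sketch (not part of this commit; the checkpoint name and prompts are placeholders, and left padding is assumed to match the "left" padding_side passed to the kernel further down), the valid lengths computed above are simply the per-row count of non-padding positions in a fixed-length batch:

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # placeholder checkpoint
tokenizer.padding_side = "left"           # assumed: prompts are left-padded for static shapes
tokenizer.pad_token = tokenizer.eos_token

batch = tokenizer(
    ["Hello world", "A much longer prompt with several more tokens"],
    return_tensors="pt",
    padding="max_length",
    max_length=16,
    truncation=True,
)
# Each row's valid length is the number of attended (non-padding) positions.
valid_sequence_lengths = torch.sum(batch["attention_mask"], dim=1)  # shape (batch_size,), int64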
1 change: 1 addition & 0 deletions examples/text-generation/utils.py
@@ -577,6 +577,7 @@ def setup_generation_config(args, model, assistant_model, tokenizer):
generation_config.flash_attention_causal_mask = args.flash_attention_causal_mask
generation_config.flash_attention_fast_softmax = args.flash_attention_fast_softmax
generation_config.trust_remote_code = args.trust_remote_code
setattr(generation_config, 'valid_sequence_lengths', None)

return generation_config

1 change: 1 addition & 0 deletions optimum/habana/transformers/generation/utils.py
@@ -1082,6 +1082,7 @@ def generate(
True if generation_config.flash_attention_fast_softmax else False
)
model_kwargs["num_virtual_tokens"] = num_virtual_tokens
model_kwargs["valid_sequence_lengths"] = generation_config.valid_sequence_lengths

if not self.config.is_encoder_decoder:
calculated_max_length = input_ids.shape[-1] + num_virtual_tokens
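A toy sketch (plain Python, no HPU required; names follow this diff, and the mock prepare_inputs_for_generation below is a stand-in, not the real method) of the propagation contract this line establishes: whatever generate() finds on the GenerationConfig lands in model_kwargs and, via prepare_inputs_for_generation() in modeling_llama.py below, reaches the attention call.

import torch
from transformers import GenerationConfig

generation_config = GenerationConfig()
# optimum-habana attaches ad-hoc attributes to the config (see setup_generation_config above).
generation_config.valid_sequence_lengths = torch.tensor([3, 5])

model_kwargs = {"valid_sequence_lengths": generation_config.valid_sequence_lengths}

def prepare_inputs_for_generation(input_ids, **kwargs):
    # Mirrors the Llama change below: the value is threaded through unchanged.
    return {"input_ids": input_ids, "valid_sequence_lengths": kwargs.get("valid_sequence_lengths")}

model_inputs = prepare_inputs_for_generation(torch.ones(2, 5, dtype=torch.long), **model_kwargs)
assert torch.equal(model_inputs["valid_sequence_lengths"], torch.tensor([3, 5]))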
81 changes: 64 additions & 17 deletions optimum/habana/transformers/models/llama/modeling_llama.py
@@ -347,12 +347,17 @@ def gaudi_llama_repeat_kv(

# FusedScaledDotProductAttention
class ModuleFusedSDPA(torch.nn.Module):
def __init__(self, fusedSDPA):
def __init__(self, fusedSDPA, scale, attention_dropout, enable_recompute, flash_attention_fp8):
super().__init__()
self._hpu_kernel_fsdpa = fusedSDPA
self.scale = scale
self.attention_dropout = attention_dropout
self.enable_recompute = enable_recompute
self.flash_attention_fp8 = flash_attention_fp8

def forward(self, query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode):
return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode)
def forward(self, query, key, value, attn_mask, dropout_p, is_causal, scale, x):
softmax_mode, recompute_mode, valid_sequence_lengths, padding_side = x
return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode, recompute_mode, valid_sequence_lengths, padding_side)
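A minimal sketch (assuming the ModuleFusedSDPA class above is in scope; the stand-in kernel below replaces habana_frameworks' FusedSDPA purely for illustration and ignores the Gaudi-only knobs) of how the new last argument is packed and unpacked:

import torch
import torch.nn.functional as F

class _FakeFusedSDPA:
    # Stand-in for the HPU kernel: accepts the extended argument list, then falls back
    # to PyTorch's scaled_dot_product_attention and drops softmax_mode, recompute_mode,
    # valid_sequence_lengths and padding_side.
    @staticmethod
    def apply(q, k, v, attn_mask, dropout_p, is_causal, scale,
              softmax_mode, recompute_mode, valid_sequence_lengths, padding_side):
        return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask,
                                              dropout_p=dropout_p, is_causal=is_causal)

q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq, head_dim)
module = ModuleFusedSDPA(_FakeFusedSDPA, scale=64 ** -0.5, attention_dropout=0.0,
                         enable_recompute=False, flash_attention_fp8=False)
# The 8th positional argument is the tuple x = (softmax_mode, recompute_mode,
# valid_sequence_lengths, padding_side), unpacked inside forward().
out = module(q, k, v, None, 0.0, True, None, ("None", False, None, "left"))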


class Matmul(torch.nn.Module):
@@ -412,7 +417,11 @@ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
self.matmul_av = Matmul()
self.k_cache = KVCache()
self.v_cache = KVCache()
self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if FusedSDPA else None
self.norm_factor = 1.0 / math.sqrt(self.head_dim)
self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA, scale=self.norm_factor,
attention_dropout=self.attention_dropout,
enable_recompute=False,
flash_attention_fp8=True,) if FusedSDPA else None
if hasattr(config, "fused_qkv") and config.fused_qkv:
self.num_heads = config.num_attention_heads
self.head_dim = config.hidden_size // self.num_heads
@@ -488,6 +497,7 @@ def pre_attn_forward(
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
valid_sequence_lengths: Optional[torch.Tensor] = None,
cache_idx: int = None,
num_virtual_tokens: int = None,
**kwargs,
@@ -624,24 +634,51 @@ def pre_attn_forward(

if q_len == 1:
# next token
use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
with ht.sdp_kernel(enable_recompute=use_recompute):
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, 0.0, False, None, "None"
)
attn_output = self.fused_scaled_dot_product_attention(
query_states,
key_states,
value_states,
attention_mask,
0.0,
False,
None,
("None",
False,
None,
"None",)
)
else:
softmax_mode = "fast" if flash_attention_fast_softmax else "None"
# first token
if flash_attention_causal_mask:
# causal masking on first token requires inputs to be of the same length
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, None, 0.0, True, None, softmax_mode
)
attn_output = self.fused_scaled_dot_product_attention(
query_states,
key_states,
value_states,
None,
0.0,
True,
None,
(softmax_mode,
flash_attention_recompute,
valid_sequence_lengths,
"left",)
)
else:
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode
)
attn_output = self.fused_scaled_dot_product_attention(
query_states,
key_states,
value_states,
attention_mask,
0.0,
False,
None,
(softmax_mode,
flash_attention_recompute,
None,
"None",)
)

else:
query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv(
@@ -837,6 +874,7 @@ def forward(
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
valid_sequence_lengths: Optional[torch.Tensor] = None,
cache_idx: int = None,
num_virtual_tokens: int = None,
**kwargs,
@@ -870,6 +908,7 @@ def forward(
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
flash_attention_fast_softmax=flash_attention_fast_softmax,
valid_sequence_lengths=valid_sequence_lengths,
cache_idx=cache_idx,
num_virtual_tokens=num_virtual_tokens,
**kwargs,
@@ -905,6 +944,7 @@ def pre_attn(
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
valid_sequence_lengths: Optional[torch.Tensor] = None,
cache_idx: int = None,
num_virtual_tokens: int = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
@@ -925,6 +965,7 @@ def pre_attn(
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
flash_attention_fast_softmax=flash_attention_fast_softmax,
valid_sequence_lengths=valid_sequence_lengths,
cache_idx=cache_idx,
num_virtual_tokens=num_virtual_tokens,
)
@@ -1018,6 +1059,7 @@ def forward(
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
valid_sequence_lengths: torch.Tensor = None,
cache_idx: int = None,
lazy_mode: Optional[bool] = True,
num_virtual_tokens: int = None,
@@ -1157,6 +1199,7 @@ def forward(
flash_attention_recompute,
flash_attention_causal_mask,
flash_attention_fast_softmax,
valid_sequence_lengths,
None,
)
else:
@@ -1176,6 +1219,7 @@ def forward(
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
flash_attention_fast_softmax=flash_attention_fast_softmax,
valid_sequence_lengths=valid_sequence_lengths,
cache_idx=cache_idx,
num_virtual_tokens=num_virtual_tokens,
)
@@ -1253,6 +1297,7 @@ def forward(
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
valid_sequence_lengths: torch.Tensor = None,
cache_idx: int = None,
lazy_mode: Optional[bool] = True,
num_virtual_tokens: int = None,
@@ -1285,6 +1330,7 @@ def forward(
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
flash_attention_fast_softmax=flash_attention_fast_softmax,
valid_sequence_lengths=valid_sequence_lengths,
cache_idx=cache_idx,
lazy_mode=lazy_mode,
num_virtual_tokens=num_virtual_tokens,
@@ -1394,6 +1440,7 @@ def prepare_inputs_for_generation(
"flash_attention_recompute": kwargs.get("flash_attention_recompute"),
"flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"),
"flash_attention_fast_softmax": kwargs.get("flash_attention_fast_softmax"),
"valid_sequence_lengths": kwargs.get("valid_sequence_lengths"),
"cache_idx": kwargs.get("cache_idx"),
"lazy_mode": kwargs.get("lazy_mode"),
"num_virtual_tokens": kwargs.get("num_virtual_tokens"),
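Conceptually, in the first-token path with flash_attention_causal_mask the kernel is called without a dense attention mask; with this change it additionally receives padding_side="left" and the per-sample valid lengths, so it can skip the left-padding positions. A rough plain-PyTorch sketch (illustrative only; the real masking happens inside the Gaudi FusedSDPA kernel) of the mask those two arguments imply:

import torch

def left_padded_causal_mask(valid_sequence_lengths: torch.Tensor, seq_len: int) -> torch.Tensor:
    # True means "may attend". With left padding, sample i's real tokens occupy
    # positions [seq_len - valid_len_i, seq_len); everything before that is padding.
    positions = torch.arange(seq_len)
    first_real = seq_len - valid_sequence_lengths                      # (batch,)
    key_is_real = positions[None, :] >= first_real[:, None]            # (batch, seq_k)
    causal = positions[:, None] >= positions[None, :]                  # (seq_q, seq_k)
    return key_is_real[:, None, :] & causal[None, :, :]                # (batch, seq_q, seq_k)

mask = left_padded_causal_mask(torch.tensor([3, 5]), seq_len=5)
print(mask[0].int())  # sample 0: only the last 3 key positions are ever attended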
