VALID SEQ
ssarkar2 committed Aug 27, 2024
1 parent e093881 commit ea553dc
Showing 3 changed files with 29 additions and 10 deletions.
1 change: 1 addition & 0 deletions examples/text-generation/utils.py
@@ -577,6 +577,7 @@ def setup_generation_config(args, model, assistant_model, tokenizer):
generation_config.flash_attention_causal_mask = args.flash_attention_causal_mask
generation_config.flash_attention_fast_softmax = args.flash_attention_fast_softmax
generation_config.trust_remote_code = args.trust_remote_code
generation_config.valid_sequence_lengths = None

return generation_config

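Note: the hunk above only initializes generation_config.valid_sequence_lengths to None; nothing in this commit populates it. Below is a minimal sketch of how a caller could derive real per-row lengths for a left-padded batch from the tokenizer's attention mask (generation_config here is the object returned by setup_generation_config above; the model id and the int32 cast are illustrative assumptions, not part of the commit):

    import torch
    from transformers import AutoTokenizer

    # Illustrative left-padded batch; any tokenizer set up this way works the same.
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", padding_side="left")
    batch = tokenizer(
        ["Hello world", "A much longer prompt than the first one"],
        return_tensors="pt",
        padding=True,
    )

    # Count the non-pad tokens in each row; int32 is assumed to be what the
    # fused kernel expects for valid_sequence_lengths.
    generation_config.valid_sequence_lengths = batch["attention_mask"].sum(dim=-1).to(torch.int32)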
1 change: 1 addition & 0 deletions optimum/habana/transformers/generation/utils.py
@@ -1082,6 +1082,7 @@ def generate(
True if generation_config.flash_attention_fast_softmax else False
)
model_kwargs["num_virtual_tokens"] = num_virtual_tokens
model_kwargs["valid_sequence_lengths"] = generation_config.valid_sequence_lengths

if not self.config.is_encoder_decoder:
calculated_max_length = input_ids.shape[-1] + num_virtual_tokens
37 changes: 27 additions & 10 deletions optimum/habana/transformers/models/llama/modeling_llama.py
@@ -351,8 +351,12 @@ def __init__(self, fusedSDPA):
super().__init__()
self._hpu_kernel_fsdpa = fusedSDPA

def forward(self, query, key, value, attn_mask, dropout_p, is_causal, scale, x):
    # x packs the arguments added by this commit together with the softmax mode:
    # (softmax_mode, recompute_mode, valid_sequence_lengths, padding_side).
    softmax_mode, recompute_mode, valid_sequence_lengths, padding_side = x
    return self._hpu_kernel_fsdpa.apply(
        query, key, value, attn_mask, dropout_p, is_causal, scale,
        softmax_mode, recompute_mode, valid_sequence_lengths, padding_side,
    )
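The wrapper now receives its four extra FusedSDPA arguments packed into a single trailing tuple. A usage sketch under stated assumptions — the wrapper class name (ModuleFusedSDPA) is cut off by the hunk above, and the habana_frameworks import path is assumed:

    import torch
    from habana_frameworks.torch.hpex.kernels import FusedSDPA  # assumed import path

    fsdpa = ModuleFusedSDPA(FusedSDPA)  # class name assumed from context

    # Dummy (batch, heads, seq_len, head_dim) tensors on an HPU device.
    q = k = v = torch.randn(1, 32, 256, 128, device="hpu", dtype=torch.bfloat16)
    valid_sequence_lengths = torch.tensor([200], dtype=torch.int32, device="hpu")

    # The trailing tuple is unpacked inside forward() as
    # (softmax_mode, recompute_mode, valid_sequence_lengths, padding_side).
    out = fsdpa(q, k, v, None, 0.0, True, None, ("fast", False, valid_sequence_lengths, "left"))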



class Matmul(torch.nn.Module):
@@ -488,6 +492,7 @@ def pre_attn_forward(
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
valid_sequence_lengths: torch.Tensor = None,
cache_idx: int = None,
num_virtual_tokens: int = None,
**kwargs,
@@ -621,27 +626,28 @@ def pre_attn_forward(
import habana_frameworks.torch.hpu as ht

softmax_mode = "fast" if flash_attention_fast_softmax else "None"

if q_len == 1:
# next token
use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
with ht.sdp_kernel(enable_recompute=use_recompute):

attn_output = self.fused_scaled_dot_product_attention(
    query_states, key_states, value_states, attention_mask, 0.0, False, None,
    (softmax_mode, False, valid_sequence_lengths, "None"),
)
else:
# first token
if flash_attention_causal_mask:
# causal masking on first token requires inputs to be of the same length
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = self.fused_scaled_dot_product_attention(
    query_states, key_states, value_states, None, 0.0, True, None,
    (softmax_mode, flash_attention_recompute, valid_sequence_lengths, "left"),
)
else:
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = self.fused_scaled_dot_product_attention(
    query_states, key_states, value_states, attention_mask, 0.0, False, None,
    (softmax_mode, flash_attention_recompute, valid_sequence_lengths, "left"),
)

else:
query_states, key_states, value_states, attention_mask = gaudi_llama_repeat_kv(
@@ -772,6 +778,7 @@ def pre_attn_forward(
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
valid_sequence_lengths: torch.Tensor = None,
cache_idx: int = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
@@ -792,6 +799,7 @@ def pre_attn_forward(
flash_attention_recompute,
flash_attention_causal_mask,
flash_attention_fast_softmax,
valid_sequence_lengths,
cache_idx,
**kwargs,
)
@@ -837,6 +845,7 @@ def forward(
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
valid_sequence_lengths: torch.Tensor = None,
cache_idx: int = None,
num_virtual_tokens: int = None,
**kwargs,
@@ -870,6 +879,7 @@ def forward(
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
flash_attention_fast_softmax=flash_attention_fast_softmax,
valid_sequence_lengths=valid_sequence_lengths,
cache_idx=cache_idx,
num_virtual_tokens=num_virtual_tokens,
**kwargs,
@@ -905,6 +915,7 @@ def pre_attn(
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
valid_sequence_lengths: torch.Tensor = None,
cache_idx: int = None,
num_virtual_tokens: int = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
@@ -925,6 +936,7 @@ def pre_attn(
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
flash_attention_fast_softmax=flash_attention_fast_softmax,
valid_sequence_lengths=valid_sequence_lengths,
cache_idx=cache_idx,
num_virtual_tokens=num_virtual_tokens,
)
@@ -940,7 +952,7 @@ def post_attn_pre_mlp(self, hidden_states, residual):
residual.add_(hidden_states)
hidden_states = residual

hidden_states = self.post_attention_layernorm(hidden_states)

hidden_states = self.mlp.pre_mlp_forward(hidden_states)
return hidden_states, residual
@@ -1018,6 +1030,7 @@ def forward(
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
valid_sequence_lengths: torch.Tensor = None,
cache_idx: int = None,
lazy_mode: Optional[bool] = True,
num_virtual_tokens: int = None,
@@ -1176,6 +1189,7 @@ def forward(
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
flash_attention_fast_softmax=flash_attention_fast_softmax,
valid_sequence_lengths=valid_sequence_lengths,
cache_idx=cache_idx,
num_virtual_tokens=num_virtual_tokens,
)
@@ -1253,6 +1267,7 @@ def forward(
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
valid_sequence_lengths: torch.Tensor = None,
cache_idx: int = None,
lazy_mode: Optional[bool] = True,
num_virtual_tokens: int = None,
@@ -1285,6 +1300,7 @@ def forward(
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
flash_attention_fast_softmax=flash_attention_fast_softmax,
valid_sequence_lengths=valid_sequence_lengths,
cache_idx=cache_idx,
lazy_mode=lazy_mode,
num_virtual_tokens=num_virtual_tokens,
@@ -1394,6 +1410,7 @@ def prepare_inputs_for_generation(
"flash_attention_recompute": kwargs.get("flash_attention_recompute"),
"flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"),
"flash_attention_fast_softmax": kwargs.get("flash_attention_fast_softmax"),
"valid_sequence_lengths": kwargs.get("valid_sequence_lengths"),
"cache_idx": kwargs.get("cache_idx"),
"lazy_mode": kwargs.get("lazy_mode"),
"num_virtual_tokens": kwargs.get("num_virtual_tokens"),
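Taken together, the three files thread valid_sequence_lengths from the generation config down to the fused attention kernel: setup_generation_config() creates the attribute, generate() copies it into model_kwargs, prepare_inputs_for_generation() forwards it to the Llama attention layers, and pre_attn_forward() passes it to FusedSDPA inside the packed tuple. A hedged end-to-end sketch, reusing the tokenizer and batch from the first sketch and assuming model is an already-loaded Gaudi Llama model:

    import torch

    # Real token count per row, computed as in the first sketch.
    generation_config.valid_sequence_lengths = batch["attention_mask"].sum(dim=-1).to(torch.int32)

    # generate() places the value into model_kwargs["valid_sequence_lengths"],
    # which prepare_inputs_for_generation() then hands to every decoder layer.
    outputs = model.generate(**batch, generation_config=generation_config)
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True))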
