diff --git a/extension/llm/modules/attention.py b/extension/llm/modules/attention.py index 60183801b4..695e5efa72 100644 --- a/extension/llm/modules/attention.py +++ b/extension/llm/modules/attention.py @@ -10,6 +10,7 @@ import torch import torchtune.modules.attention as TorchTuneAttention from executorch.extension.llm.modules.kv_cache import KVCache as InferenceKVCache +from executorch.extension.llm.custom_ops import custom_ops from torch import nn from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention from torchtune.modules.kv_cache import KVCache @@ -146,6 +147,7 @@ def __init__( # Use flex attention if supported and we are sample packing self._attention_call = _sdpa_or_flex_attention() self._sdpa = SDPA( + max_seq_len=self.max_seq_len, num_kv_heads=self.num_kv_heads, num_heads=self.num_heads, head_dim=self.head_dim, @@ -310,7 +312,7 @@ def false_fn(y): self.kv_cache.v_cache.copy_(v) self.kv_cache.cache_pos.copy_(cache_pos) - output = self._sdpa(q, k, v, b, s_x, mask=mask) + output = self._sdpa(q, k, v, b, s_x, mask=mask, input_pos=input_pos) return self.output_proj(output) @@ -322,6 +324,7 @@ class SDPA(nn.Module): def __init__( self, + max_seq_len: int, num_kv_heads: int, num_heads: int, head_dim: int, @@ -331,6 +334,7 @@ def __init__( kv_cache, ) -> None: super().__init__() + self.max_seq_len = max_seq_len self.num_kv_heads = num_kv_heads self.num_heads = num_heads self.head_dim = head_dim @@ -348,7 +352,23 @@ def forward( bsz: int, seq_len: int, mask: Optional[_MaskType] = None, + # Below args are only used for ET custom sdpa op. + input_pos: Optional[torch.Tensor] = None, ) -> torch.Tensor: + start_pos = input_pos[0][-1].item() - seq_len + 1 + torch._check_is_size(start_pos) + torch._check(start_pos <= self.max_seq_len) + output = torch.ops.llama.custom_sdpa( + q, + k, + v, + start_pos, + None, # Attention mask + 0, # dropout probability. Ignored by the code + True, # is_causal TODO: flip to false if kv cache is enabled??? + ) + return output.view(bsz, seq_len, -1) + # View + expand + reshape bring num_kv_heads to num_heads for k and v # to match q.