Custom SDPA in attention
dvorjackz committed Jan 2, 2025
1 parent 8145cda commit 079bab3
Showing 1 changed file with 21 additions and 1 deletion.
extension/llm/modules/attention.py: 22 changes (21 additions, 1 deletion)
@@ -10,6 +10,7 @@
 import torch
 import torchtune.modules.attention as TorchTuneAttention
 from executorch.extension.llm.modules.kv_cache import KVCache as InferenceKVCache
+from executorch.extension.llm.custom_ops import custom_ops
 from torch import nn
 from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
 from torchtune.modules.kv_cache import KVCache
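For context, an assumption on my part rather than something stated in the diff: the new import appears to be there purely for its side effect of registering ExecuTorch's llama custom operators with the PyTorch dispatcher, so that the torch.ops.llama.custom_sdpa call added further down can resolve. A minimal sketch of that assumption:

# Sketch only: assumes the ExecuTorch custom-ops extension is built/installed,
# and that importing custom_ops registers the llama ops as a side effect.
import torch
from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401

# If registration succeeded, the op is reachable through the dispatcher.
assert hasattr(torch.ops.llama, "custom_sdpa")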
@@ -146,6 +147,7 @@ def __init__(
         # Use flex attention if supported and we are sample packing
         self._attention_call = _sdpa_or_flex_attention()
         self._sdpa = SDPA(
+            max_seq_len=self.max_seq_len,
             num_kv_heads=self.num_kv_heads,
             num_heads=self.num_heads,
             head_dim=self.head_dim,
@@ -310,7 +312,7 @@ def false_fn(y):
             self.kv_cache.v_cache.copy_(v)
             self.kv_cache.cache_pos.copy_(cache_pos)
 
-        output = self._sdpa(q, k, v, b, s_x, mask=mask)
+        output = self._sdpa(q, k, v, b, s_x, mask=mask, input_pos=input_pos)
         return self.output_proj(output)
 
 
@@ -322,6 +324,7 @@ class SDPA(nn.Module):
 
     def __init__(
         self,
+        max_seq_len: int,
         num_kv_heads: int,
         num_heads: int,
         head_dim: int,
@@ -331,6 +334,7 @@ def __init__(
         kv_cache,
     ) -> None:
         super().__init__()
+        self.max_seq_len = max_seq_len
         self.num_kv_heads = num_kv_heads
         self.num_heads = num_heads
         self.head_dim = head_dim
@@ -348,7 +352,23 @@ def forward(
         bsz: int,
         seq_len: int,
         mask: Optional[_MaskType] = None,
+        # Below args are only used for ET custom sdpa op.
+        input_pos: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        start_pos = input_pos[0][-1].item() - seq_len + 1
+        torch._check_is_size(start_pos)
+        torch._check(start_pos <= self.max_seq_len)
+        output = torch.ops.llama.custom_sdpa(
+            q,
+            k,
+            v,
+            start_pos,
+            None,  # Attention mask
+            0,  # dropout probability. Ignored by the code
+            True,  # is_causal TODO: flip to false if kv cache is enabled???
+        )
+        return output.view(bsz, seq_len, -1)
+
         # View + expand + reshape bring num_kv_heads to num_heads for k and v
         # to match q.
 
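Illustrative sanity check (not part of the commit) of the start_pos arithmetic in the new forward: input_pos is assumed to hold the absolute KV-cache positions of the tokens in the current chunk, so the last entry minus (seq_len - 1) recovers the position of the first token.

import torch

# Assumed layout: input_pos has shape [bsz, seq_len] and holds absolute
# cache positions for the current tokens, e.g. positions 10..13.
seq_len = 4
input_pos = torch.arange(10, 10 + seq_len).unsqueeze(0)

# Same arithmetic as the added forward() body above.
start_pos = input_pos[0][-1].item() - seq_len + 1
assert start_pos == 10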
