Use flash attention
jloveric committed Dec 10, 2023
1 parent 4c40191 commit 1491c48
Showing 1 changed file with 9 additions and 3 deletions.
language_interpolation/networks.py: 12 changes (9 additions & 3 deletions)
@@ -197,9 +197,15 @@ def forward(
  kth = kt.reshape(kt.shape[0], kt.shape[1], self.heads, -1)
  vth = vt.reshape(vt.shape[0], vt.shape[1], self.heads, -1)

- res = F.scaled_dot_product_attention(
-     query=qth, key=kth, value=vth, attn_mask=None
- )
+ with torch.backends.cuda.sdp_kernel(
+     enable_flash=True,
+     enable_math=False,
+     enable_mem_efficient=True
+ ):
+     res = F.scaled_dot_product_attention(
+         query=qth, key=kth, value=vth, attn_mask=None
+     )

  # Used the built-in attention so I can get the optimization
  # qkh = torch.nn.functional.softmax(torch.einsum('blhd,brhd->blrh',qth,kth), dim=3)
  # res = torch.einsum('blrh,brhd->blhd',qkh, vth)
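For readers outside the repository: torch.backends.cuda.sdp_kernel is a PyTorch (2.0/2.1-era) context manager that restricts which backends F.scaled_dot_product_attention may dispatch to, and disabling the math backend forces the fused flash or memory-efficient kernels. Below is a minimal, self-contained sketch of the same pattern; the tensor sizes, dtype, and the (batch, heads, seq, head_dim) layout are illustrative assumptions, not values taken from language_interpolation/networks.py.

import torch
import torch.nn.functional as F

# Illustrative sizes; these are assumptions for the sketch, not values
# from the commit above.
batch, heads, seq_len, head_dim = 2, 4, 128, 32

# Flash attention requires fp16/bf16 tensors on a CUDA device.
device = "cuda"
q = torch.randn(batch, heads, seq_len, head_dim, device=device, dtype=torch.float16)
k = torch.randn(batch, heads, seq_len, head_dim, device=device, dtype=torch.float16)
v = torch.randn(batch, heads, seq_len, head_dim, device=device, dtype=torch.float16)

# Allow only the fused backends; with enable_math=False PyTorch raises an
# error instead of silently falling back to the unfused implementation.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=True
):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=None)

print(out.shape)  # torch.Size([2, 4, 128, 32])

Enabling both the flash and memory-efficient backends, as the commit does, lets PyTorch fall back to the memory-efficient kernel when the inputs (dtype, head dimension, masking) do not meet flash attention's constraints, while still avoiding the slow math path.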
