Commit
enable flash and mem attention modules by default
thayeral committed Jan 13, 2025
1 parent c19fc43 commit e06a44f
Showing 1 changed file with 1 addition and 4 deletions.
src/transformer.py (1 addition, 4 deletions)

```diff
@@ -48,10 +48,7 @@ def forward(self, x, return_attention=False):
         k = self.k_norm(k)
 
         try:
-            if return_attention:
-                raise NotImplementedError
-
-            with torch.backends.cuda.sdp_kernel():
+            with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=True, enable_math=False):
                 x = F.scaled_dot_product_attention(
                     q, k, v,
                     dropout_p=self.att_drop.p if self.training else 0.,
```
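
For context, here is a minimal standalone sketch of what the new kernel selection does: it restricts `F.scaled_dot_product_attention` to the fused flash and memory-efficient CUDA backends, exactly as the added line does. The tensor shapes, dtype handling, and the explicit fallback to default backend selection are illustrative assumptions, not taken from src/transformer.py (whose `except` clause lies outside the shown hunk).

```python
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Hypothetical shapes: (batch, heads, seq_len, head_dim); not from src/transformer.py.
q = torch.randn(2, 8, 128, 64, device=device, dtype=dtype)
k = torch.randn(2, 8, 128, 64, device=device, dtype=dtype)
v = torch.randn(2, 8, 128, 64, device=device, dtype=dtype)

try:
    # Restrict the CUDA backend to the fused flash / memory-efficient kernels,
    # as in the commit; the plain math kernel is disabled inside this block.
    with torch.backends.cuda.sdp_kernel(enable_flash=True,
                                        enable_mem_efficient=True,
                                        enable_math=False):
        out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
except RuntimeError:
    # If neither fused kernel supports these inputs, fall back to the default
    # backend selection (an assumed fallback; the diff's except clause is not shown).
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)

print(out.shape)  # torch.Size([2, 8, 128, 64])
```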
