Reviewer suggestions & import refactor
Skylion007 committed Jan 2, 2024
1 parent b809a7b commit dbe8d64
Showing 1 changed file with 3 additions and 3 deletions.
examples/benchmarks/bert/src/bert_layers.py (3 additions, 3 deletions)
@@ -44,6 +44,8 @@
 # Add folder root to path to allow us to use relative imports regardless of what directory the script is run from
 sys.path.append(os.path.dirname(os.path.realpath(__file__)))
 
+import importlib
+
 import bert_padding as bert_padding_module
 import torch
 import torch.nn as nn
@@ -56,7 +58,6 @@
 
 IMPL_USE_FLASH2 = False
 try:
-    import importlib
 
     from flash_attn import flash_attn_qkvpacked_func
     installed_version = importlib.metadata.version('flash_attn')
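
Side note on the hunk above: it hoists the importlib import to module level and leaves the flash_attn feature gate relying on it. A minimal, self-contained sketch of that optional-dependency gate follows; it is not the repository's exact code, and the explicit import importlib.metadata, the packaging dependency, and the 2.0.0 threshold are illustrative assumptions.

import importlib.metadata

from packaging import version

IMPL_USE_FLASH2 = False
try:
    # Importing the kernel entry point proves flash-attn is installed at all.
    from flash_attn import flash_attn_qkvpacked_func  # noqa: F401

    # Read the installed distribution's version from package metadata rather
    # than from a __version__ attribute.
    installed_version = importlib.metadata.version('flash_attn')

    # Hypothetical threshold: only enable the FA2 path for flash-attn >= 2.0.0.
    IMPL_USE_FLASH2 = version.parse(installed_version) >= version.parse('2.0.0')
except ImportError:
    # flash-attn is optional; fall back to the non-FA2 attention path.
    pass

The sketch imports importlib.metadata explicitly because a bare import importlib does not, by itself, guarantee that the metadata submodule is loaded.
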
@@ -249,10 +250,9 @@ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
                 assert slopes.shape[
                     -1] == self.num_attention_heads, f'{slopes=}'
 
-                # Triton implementation only supports 0 attention dropout
                 convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
                 if convert_dtype:
-                    # Triton implementation only supports fp16 and bf16
+                    # FA2 implementation only supports fp16 and bf16
                     orig_dtype = qkv.dtype
                     qkv = qkv.to(torch.float16)
                     bias_dtype = bias.dtype
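
The third hunk only corrects a comment: the cast exists because the FlashAttention-2 kernels accept fp16/bf16 inputs only. A hedged sketch of that convert-and-remember pattern follows; the helper name and the bias handling are illustrative assumptions, not the file's actual code.

import torch


def _cast_for_fa2(qkv: torch.Tensor, bias: torch.Tensor):
    """Hypothetical helper showing the dtype guard from the hunk above."""
    # FA2 implementation only supports fp16 and bf16, so cast anything else
    # (typically fp32) down before calling the kernel.
    convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
    orig_dtype = qkv.dtype
    bias_dtype = bias.dtype
    if convert_dtype:
        qkv = qkv.to(torch.float16)
        bias = bias.to(torch.float16)
    return qkv, bias, convert_dtype, orig_dtype, bias_dtype

Saving orig_dtype and bias_dtype, as the diff does, implies the surrounding code casts the kernel output (and bias) back to the caller's original precision once convert_dtype is true.
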
