Cast to bfloat16 if possible
Skylion007 committed Jan 4, 2024
1 parent 6a8d85b commit 4cbdec1
Showing 1 changed file with 18 additions and 6 deletions.
examples/benchmarks/bert/src/bert_layers.py: 24 changes (18 additions, 6 deletions)
@@ -39,6 +39,7 @@
 import os
 import sys
 import warnings
+from functools import lru_cache
 from typing import List, Optional, Tuple, Union
 
 # Add folder root to path to allow us to use relative imports regardless of what directory the script is run from
@@ -79,6 +80,13 @@
 logger = logging.getLogger(__name__)
 
 
+@lru_cache
+def _get_half_dtype() -> torch.dtype:
+    if torch.cuda.is_bf16_supported():
+        return torch.bfloat16
+    return torch.float16
+
+
 class BertEmbeddings(nn.Module):
     """Construct the embeddings for words, ignoring position.
@@ -250,13 +258,15 @@ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
             assert slopes.shape[
                 -1] == self.num_attention_heads, f'{slopes=}'
 
-            convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
+            convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
             if convert_dtype:
                 # FA2 implementation only supports fp16 and bf16
+                # If FA2 is supported, bfloat16 must be supported
+                # as of FA2 2.4.2. (Turing GPUs not supported)
                 orig_dtype = qkv.dtype
-                qkv = qkv.to(torch.float16)
+                qkv = qkv.to(torch.bfloat16)
                 bias_dtype = bias.dtype
-                bias = bias.to(torch.float16)
+                bias = bias.to(torch.bfloat16)
 
                 attention = flash_attn_qkvpacked_func(
                     qkv, dropout_p=self.p_dropout, alibi_slopes=slopes)
@@ -267,13 +277,15 @@ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
                     qkv, dropout_p=self.p_dropout, alibi_slopes=slopes)
         else:
             # Triton implementation only supports 0 attention dropout
-            convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
+            convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16)
             if convert_dtype:
+                half = _get_half_dtype()
+
                 # Triton implementation only supports fp16 and bf16
                 orig_dtype = qkv.dtype
-                qkv = qkv.to(torch.float16)
+                qkv = qkv.to(half)
                 bias_dtype = bias.dtype
-                bias = bias.to(torch.float16)
+                bias = bias.to(half)
                 attention = flash_attn_qkvpacked_func(qkv, bias)
                 attention = attention.to(orig_dtype)
                 bias = bias.to(bias_dtype)
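Both attention paths follow the same cast-and-restore pattern: inputs that are not already fp16/bf16 are cast down for the kernel call, and the output is cast back to the caller's dtype. Below is a rough sketch of that pattern under stated assumptions; the run_in_half wrapper and the element-wise stand-in kernel are illustrative only and do not appear in the commit.

import torch


def run_in_half(kernel, qkv: torch.Tensor, bias: torch.Tensor,
                half: torch.dtype = torch.bfloat16) -> torch.Tensor:
    # Illustrative wrapper: kernels such as flash_attn_qkvpacked_func accept
    # only fp16/bf16 inputs, so fp32 tensors are cast down for the call and
    # the result is cast back afterwards. The Triton path in the diff picks
    # `half` via _get_half_dtype() instead of hard-coding it.
    orig_dtype = qkv.dtype
    if orig_dtype not in (torch.float16, torch.bfloat16):
        qkv = qkv.to(half)
        bias = bias.to(half)
    out = kernel(qkv, bias)
    return out.to(orig_dtype)


# Example with an element-wise stand-in instead of flash attention:
qkv = torch.randn(2, 3)                      # float32 inputs
bias = torch.zeros(2, 3)
out = run_in_half(lambda q, b: q + b, qkv, bias)
print(out.dtype)                             # torch.float32 again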
