Skip to content

Commit

Permalink
Modernize MosaicBERT
Browse files Browse the repository at this point in the history
  • Loading branch information
Skylion007 committed Jan 2, 2024
1 parent 7003793 commit 20d3725
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 42 deletions.
4 changes: 3 additions & 1 deletion examples/benchmarks/bert/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,9 @@ def main(cfg: DictConfig,
load_path=cfg.get('load_path', None),
load_weights_only=cfg.get('load_weights_only', False),
python_log_level=cfg.get('python_log_level', None),
)
autoresume=cfg.get('autoresume', None),
fsdp_config=cfg.get('fsdp_config', None),
compile_config=cfg.get('compile_config', None))

print('Logging config...')
log_config(cfg)
Expand Down
10 changes: 5 additions & 5 deletions examples/benchmarks/bert/requirements-cpu.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
einops==0.5.0
torch==1.13.1
mosaicml[nlp,wandb]>=0.14.0,<0.15
mosaicml-streaming==0.4.1
omegaconf==2.2.3
transformers==4.28.1
torch==2.1.1
composer[nlp,wandb]>=0.17.0,<0.18
mosaicml-streaming<=0.7
omegaconf==2.3.0
transformers==4.36.2
14 changes: 8 additions & 6 deletions examples/benchmarks/bert/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
einops==0.5.0
torch==1.13.1
mosaicml[nlp,wandb]>=0.14.0,<0.15
mosaicml-streaming==0.4.1
omegaconf==2.2.3
transformers==4.28.1
torch==2.1.1
composer[nlp,wandb]>=0.17.0,<0.18
mosaicml-streaming<=0.7
omegaconf==2.3.0
transformers== 4.36.2
# need a newer version of FA2
flash_attn>=2.4.2
# need a newer version of triton
triton==2.0.0.dev20221103
#triton==2.0.0.dev20221103
96 changes: 74 additions & 22 deletions examples/benchmarks/bert/src/bert_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,26 @@
SequenceClassifierOutput)
from transformers.models.bert.modeling_bert import BertPreTrainedModel

IMPL_USE_FLASH2 = False
try:
import flash_attn_triton as flash_attn_triton
flash_attn_qkvpacked_func = flash_attn_triton.flash_attn_qkvpacked_func
import importlib

from flash_attn import flash_attn_qkvpacked_func
installed_version = importlib.metadata.version('flash_attn')
if installed_version < '2.4.2':
raise ImportError('newer version of flash_attn required (>= 2.4.2)')
IMPL_USE_FLASH2 = True
except ImportError as e:
flash_attn_qkvpacked_func = None
warnings.warn(
f'Failed to import flash_attn. Will try to import triton implementation: {e}',
stacklevel=2)
try:
import flash_attn_triton as flash_attn_triton
flash_attn_qkvpacked_func = flash_attn_triton.flash_attn_qkvpacked_func
except ImportError as e:
flash_attn_qkvpacked_func = None
warnings.warn(f'Failed to import flash_attn_triton as a fallback: {e}',
stacklevel=2)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -183,7 +198,8 @@ def __init__(self, config):

def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
max_seqlen_in_batch: int, indices: torch.Tensor,
attn_mask: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
attn_mask: torch.Tensor, bias: torch.Tensor,
slopes: torch.Tensor) -> torch.Tensor:
"""Perform self-attention.
If dropout is zero, then we can use the Triton kernel, so we do that. However, if not, we send through a standard PyTorch
Expand All @@ -201,6 +217,7 @@ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
indices: (total_nnz,)
attn_mask: (batch, max_seqlen_in_batch)
bias: (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
slopes: (heads) or (batch, heads)
Returns:
attention: (total_nnz, dim)
Expand All @@ -213,7 +230,8 @@ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
'b s (t h d) -> b s t h d',
t=3,
h=self.num_attention_heads)
if self.p_dropout or flash_attn_qkvpacked_func is None:
if (not IMPL_USE_FLASH2 and
self.p_dropout) or flash_attn_qkvpacked_func is None:
# if we have nonzero attention dropout (e.g. during fine-tuning) or no Triton, compute attention in PyTorch
q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3) # b h s d
k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1) # b h d s
Expand All @@ -226,19 +244,41 @@ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor,
attention = torch.matmul(attention_probs, v).permute(0, 2, 1,
3) # b s h d
else:
# Triton implementation only supports 0 attention dropout
convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
if convert_dtype:
# Triton implementation only supports fp16 and bf16
orig_dtype = qkv.dtype
qkv = qkv.to(torch.float16)
bias_dtype = bias.dtype
bias = bias.to(torch.float16)
attention = flash_attn_qkvpacked_func(qkv, bias)
attention = attention.to(orig_dtype)
bias = bias.to(bias_dtype)
if IMPL_USE_FLASH2:
assert 1 <= len(slopes.shape) <= 2, f'{slopes=}'
assert slopes.shape[
-1] == self.num_attention_heads, f'{slopes=}'

# Triton implementation only supports 0 attention dropout
convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
if convert_dtype:
# Triton implementation only supports fp16 and bf16
orig_dtype = qkv.dtype
qkv = qkv.to(torch.float16)
bias_dtype = bias.dtype
bias = bias.to(torch.float16)

attention = flash_attn_qkvpacked_func(
qkv, dropout_p=self.p_dropout, alibi_slopes=slopes)
attention = attention.to(orig_dtype)
bias = bias.to(bias_dtype)
else:
attention = flash_attn_qkvpacked_func(
qkv, dropout_p=self.p_dropout, alibi_slopes=slopes)
else:
attention = flash_attn_qkvpacked_func(qkv, bias)
# Triton implementation only supports 0 attention dropout
convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16]
if convert_dtype:
# Triton implementation only supports fp16 and bf16
orig_dtype = qkv.dtype
qkv = qkv.to(torch.float16)
bias_dtype = bias.dtype
bias = bias.to(torch.float16)
attention = flash_attn_qkvpacked_func(qkv, bias)
attention = attention.to(orig_dtype)
bias = bias.to(bias_dtype)
else:
attention = flash_attn_qkvpacked_func(qkv, bias)

# attn_mask is 1 for attend and 0 for don't
attention = bert_padding_module.unpad_input_only(
Expand Down Expand Up @@ -291,6 +331,7 @@ def forward(
indices: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
slopes: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass for scaled self-attention without padding.
Expand All @@ -303,9 +344,11 @@ def forward(
indices: None or (total_nnz,)
attn_mask: None or (batch, max_seqlen_in_batch)
bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
slopes: None or (batch, heads) or (heads,)
"""
assert (bias is None) == (slopes is None), f'{bias=}, {slopes=}'
self_output = self.self(input_tensor, cu_seqlens, max_s, indices,
attn_mask, bias)
attn_mask, bias, slopes)
if subset_idx is not None:
return self.output(
bert_padding_module.index_first_axis(self_output, subset_idx),
Expand Down Expand Up @@ -379,6 +422,7 @@ def forward(
indices: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
slopes: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass for a BERT layer, including both attention and MLP.
Expand All @@ -391,9 +435,12 @@ def forward(
indices: None or (total_nnz,)
attn_mask: None or (batch, max_seqlen_in_batch)
bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch)
slopes: None or (batch, heads) or (heads,)
"""
assert (bias is None) == (slopes is None), f'{bias=}, {slopes=}'
attention_output = self.attention(hidden_states, cu_seqlens, seqlen,
subset_idx, indices, attn_mask, bias)
subset_idx, indices, attn_mask, bias,
slopes)
layer_output = self.mlp(attention_output)
return layer_output

Expand Down Expand Up @@ -463,6 +510,7 @@ def get_slopes_power_of_2(n_heads: int) -> List[float]:
relative_position = relative_position.unsqueeze(0).expand(
n_heads, -1, -1)
slopes = torch.Tensor(_get_alibi_head_slopes(n_heads)).to(device)
self.slopes = slopes
alibi = slopes.unsqueeze(1).unsqueeze(1) * -relative_position
# [1, n_heads, max_token_length, max_token_length]
alibi = alibi.unsqueeze(0)
Expand Down Expand Up @@ -504,6 +552,7 @@ def forward(
elif self.alibi.device != hidden_states.device:
# Device catch-up
self.alibi = self.alibi.to(hidden_states.device)
self.slopes = self.slopes.to(hidden_states.device)
alibi_bias = self.alibi[:, :, :seqlen, :seqlen]
attn_bias = extended_attention_mask[:, :, :seqlen, :seqlen]
alibi_attn_mask = attn_bias + alibi_bias
Expand All @@ -517,7 +566,8 @@ def forward(
None,
indices,
attn_mask=attention_mask,
bias=alibi_attn_mask)
bias=alibi_attn_mask,
slopes=self.slopes)
if output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
# Pad inputs and mask. It will insert back zero-padded tokens.
Expand All @@ -536,7 +586,8 @@ def forward(
None,
indices,
attn_mask=attention_mask,
bias=alibi_attn_mask)
bias=alibi_attn_mask,
slopes=self.slopes)
if output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
subset_idx = torch.nonzero(subset_mask[attention_mask_bool],
Expand All @@ -547,7 +598,8 @@ def forward(
subset_idx=subset_idx,
indices=indices,
attn_mask=attention_mask,
bias=alibi_attn_mask)
bias=alibi_attn_mask,
slopes=self.slopes)

if not output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
Expand Down
8 changes: 0 additions & 8 deletions examples/benchmarks/bert/src/text_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,6 @@ class StreamingTextDataset(StreamingDataset):
keep_zip (bool): Whether to keep or delete the compressed form when decompressing
downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to
`False``.
keep_raw (bool): Whether to keep or delete the decompressed form (or only form)
of shards after all their samples have been yielded this epoch. If ``False``, keep iff
remote is local or no remote and no compression. Defaults to ``True``.
samples_per_epoch (int, optional): Provide this field iff you are weighting sub-datasets
proportionally. Defaults to ``None``.
predownload (int, optional): Target number of samples ahead to download the shards of while
Expand Down Expand Up @@ -99,7 +96,6 @@ def __init__(self,
download_timeout: float = 60,
validate_hash: Optional[str] = None,
keep_zip: bool = False,
keep_raw: bool = True,
samples_per_epoch: Optional[int] = None,
predownload: int = 100_000,
partition_algo: str = 'orig',
Expand Down Expand Up @@ -140,7 +136,6 @@ def __init__(self,
download_timeout=download_timeout,
validate_hash=validate_hash,
keep_zip=keep_zip,
keep_raw=keep_raw,
samples_per_epoch=samples_per_epoch,
predownload=predownload,
partition_algo=partition_algo,
Expand Down Expand Up @@ -266,8 +261,6 @@ def build_text_dataloader(
cfg.dataset.get('validate_hash', None),
keep_zip=stream.get('keep_zip', None) or
cfg.dataset.get('keep_zip', False),
keep_raw=stream.get('keep_raw', None) or
cfg.dataset.get('keep_raw', True),
))

# build dataset potentially with streams
Expand All @@ -282,7 +275,6 @@ def build_text_dataloader(
download_timeout=cfg.dataset.get('download_timeout', 60),
validate_hash=cfg.dataset.get('validate_hash', None),
keep_zip=cfg.dataset.get('keep_zip', False),
keep_raw=cfg.dataset.get('keep_raw', True),
samples_per_epoch=cfg.dataset.get('samples_per_epoch', None),
predownload=cfg.dataset.get('predownload', 100_000),
partition_algo=cfg.dataset.get('partition_algo', 'orig'),
Expand Down

0 comments on commit 20d3725

Please sign in to comment.