Bump 3rdparty/Megatron-LM from 0bda578 to b3375a0 #215

Closed
2 changes: 1 addition & 1 deletion 3rdparty/Megatron-LM
Submodule Megatron-LM updated 150 files
98 changes: 15 additions & 83 deletions sub-packages/bionemo-esm2/src/bionemo/esm2/model/attention.py
@@ -14,35 +14,25 @@
# limitations under the License.


import os
from typing import Callable, Optional, Sequence, Union

import torch
from megatron.core import parallel_state, tensor_parallel
from megatron.core.extensions.transformer_engine import TEDotProductAttention
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.parallel_state import (
    get_context_parallel_global_ranks,
    get_context_parallel_group,
    get_tensor_model_parallel_group,
)
from megatron.core.tensor_parallel import get_cuda_rng_tracker
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.transformer_config import TransformerConfig
from pkg_resources import packaging
from torch import Tensor


__all__: Sequence[str] = ("ESM2DotProductAttention", "ESM2TEDotProductAttention")

from megatron.core.extensions.transformer_engine import _te_version


class ESM2TEDotProductAttention(TEDotProductAttention):
"""ESM2-Specific transformer engine core attention.

Override the softmax_scale to 1.0 to match the ESM2 implementation while keeping the rest from the original TEDotProductAttention.
Override the softmax_scale default to 1.0 to match the ESM2 implementation.
"""

    def __init__(
@@ -51,79 +41,21 @@ def __init__(
        layer_number: int,
        attn_mask_type: AttnMaskType,
        attention_type: str,
        attention_dropout: float | None = None,
        attention_dropout: Optional[float] = None,
        softmax_scale: float = 1.0,
        k_channels: Optional[int] = None,
        v_channels: Optional[int] = None,
    ):
"""Initialize ESM2TEDotProductAttention."""
self.config = config
self.te_forward_mask_type = False
self.qkv_format: str = "sbhd"

if self.config.apply_query_key_layer_scaling != bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0"))):
raise ValueError(
f"apply_query_key_layer_scaling is {self.config.apply_query_key_layer_scaling} "
f"but environment variable NVTE_APPLY_QK_LAYER_SCALING is "
f"{os.getenv('NVTE_APPLY_QK_LAYER_SCALING')}. Transformer Engine does not support "
f"setting query key layer scaling via argument, so these two must match."
)

extra_kwargs = {}
if _te_version >= packaging.version.Version("0.11.0"):
extra_kwargs["num_gqa_groups"] = self.config.num_query_groups
elif self.config.num_query_groups != self.config.num_attention_heads:
raise ValueError(
f"Transformer Engine v{_te_version} does not support Grouped Query Attention, "
f"use a newer version of Transformer Engine. "
f"(num_query_groups ({self.config.num_query_groups}) != "
f"num_attention_heads ({self.config.num_attention_heads}))"
)

if _te_version >= packaging.version.Version("0.10.0"):
extra_kwargs["attention_type"] = attention_type
# older version don't need attention_type

if _te_version > packaging.version.Version("0.12.0"):
self.te_forward_mask_type = True

# Only Transformer-Engine version >= 1.0.0 supports context parallelism
if _te_version >= packaging.version.Version("1.0.0"):
if getattr(TEDotProductAttention, "cp_stream") is None:
TEDotProductAttention.cp_stream = torch.cuda.Stream()
extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False)
extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks(check_initialized=False)
extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream
else:
assert (
self.config.context_parallel_size == 1
), "Only Transformer-Engine version >= 1.0.0 supports context parallelism!"

if self.config.deterministic_mode:
if int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")) != 0:
raise RuntimeError(
"deterministic_mode is on and we are using DotProductAttention from "
"Transformer Engine, but NVTE_ALLOW_NONDETERMINISTIC_ALGO is not 0. "
f"Currently set to: {os.getenv('NVTE_ALLOW_NONDETERMINISTIC_ALGO', 'not set')}."
)

if config.window_size is not None:
# Check version
assert _te_version >= packaging.version.Version("1.2.0"), (
f"Transformer-Engine version ({str(_te_version)}) must be >= 1.2.0 to support"
"sliding window attention."
)
extra_kwargs["window_size"] = config.window_size

super(TEDotProductAttention, self).__init__(
num_attention_heads=self.config.num_attention_heads,
kv_channels=self.config.kv_channels,
attention_dropout=(self.config.attention_dropout if attention_dropout is None else attention_dropout),
attn_mask_type=attn_mask_type.name,
sequence_parallel=self.config.sequence_parallel,
tp_size=self.config.tensor_model_parallel_size,
get_rng_state_tracker=(get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None),
tp_group=get_tensor_model_parallel_group(check_initialized=False),
layer_number=layer_number,
softmax_scale=1.0, # TODO subclassing only changes softmax_scale from None to 1.0. Upstream to make this exposed without subclassing
**extra_kwargs,
"""Initialize ESM2TEDotProductAttention with softmax_scale default to 1.0."""
super().__init__(
config,
layer_number,
attn_mask_type,
attention_type,
attention_dropout,
softmax_scale,
k_channels,
v_channels,
)
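
As context for the simplification above: after the Megatron-LM bump, the subclass exists only to flip the softmax_scale default from the conventional 1/sqrt(d_head) to 1.0, presumably because ESM2 scales the queries before the attention product rather than inside it. The standalone PyTorch sketch below illustrates that equivalence; the dot_product_attention helper and the pre-scaled-query setup are illustrative assumptions, not code taken from this repository or from Transformer Engine.

import math

import torch


def dot_product_attention(q, k, v, softmax_scale=None):
    """Minimal attention over [batch, heads, seq, d_head] tensors."""
    d_head = q.shape[-1]
    if softmax_scale is None:
        # Conventional default when no explicit scale is given.
        softmax_scale = 1.0 / math.sqrt(d_head)
    scores = torch.matmul(q, k.transpose(-2, -1)) * softmax_scale
    return torch.matmul(torch.softmax(scores, dim=-1), v)


q = torch.randn(1, 2, 8, 16)
k = torch.randn(1, 2, 8, 16)
v = torch.randn(1, 2, 8, 16)

# Default behaviour: scores are scaled by 1/sqrt(d_head) inside the attention.
out_default = dot_product_attention(q, k, v)

# ESM2-style behaviour: softmax_scale=1.0, with the queries scaled beforehand.
out_prescaled = dot_product_attention(q / math.sqrt(16), k, v, softmax_scale=1.0)

assert torch.allclose(out_default, out_prescaled, atol=1e-5)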

