Update NeMo/Megatron #302

Open: wants to merge 38 commits into base: main

Commits (38)
0930452 - bump commit hash (sichu2023, Oct 10, 2024)
a7e7b36 - update ESM2TEDotProductAttention (sichu2023, Oct 10, 2024)
e44f931 - bump te version (sichu2023, Oct 10, 2024)
db1b6e5 - update 3rd party commit hash (sichu2023, Oct 12, 2024)
2907fc7 - remove extra state from load_weights_sharded_inplace_nemo2_to_mcore (sichu2023, Oct 12, 2024)
ff5b98a - mark test_mixin_strategy_contract_get_loss_reduction xfail (sichu2023, Oct 12, 2024)
de2e9c2 - switch Megatron-LM commit hash to main (sichu2023, Oct 14, 2024)
ff12671 - update 3rd party commit hash (sichu2023, Oct 30, 2024)
622a914 - reuse TEDotProductAttention __init__ (sichu2023, Oct 31, 2024)
40540f1 - fix moe_token_dispatcher_type when variable_seq_lengths (sichu2023, Oct 31, 2024)
5b0eca6 - fix rotary_pos_emb get_rotary_seq_len call (sichu2023, Oct 31, 2024)
c826b86 - revert te version (sichu2023, Oct 31, 2024)
279aa74 - Revert "reuse TEDotProductAttention __init__" (sichu2023, Nov 1, 2024)
454271e - support cp_comm_type in ESM2TEDotProductAttention __init__ (sichu2023, Nov 1, 2024)
184ccde - pump NeMo/Megatron/TE commit hash (sichu2023, Nov 1, 2024)
b9615c5 - fix test_tokenizer_serialization (sichu2023, Nov 4, 2024)
69e2899 - update iomixin test - nemo only captures non-default arguments to __i… (sichu2023, Nov 4, 2024)
70f9707 - increase limit_val_batches and val_check_interval to avoid duplicated… (sichu2023, Nov 4, 2024)
c7d0b6b - add checkpoint callback to every mode in stop-and-go-test for nemo up… (sichu2023, Nov 4, 2024)
44189e4 - Revert "update iomixin test - nemo only captures non-default argument… (sichu2023, Nov 5, 2024)
9a0c976 - add notes on IOMixin behavior (sichu2023, Nov 5, 2024)
32136de - Revert "increase limit_val_batches and val_check_interval to avoid du… (sichu2023, Nov 5, 2024)
cb8856f - add step in checkpoint_dir to avoid name clashing (sichu2023, Nov 5, 2024)
2b0c87a - fix test_main_runs for esm2 and geneformer (sichu2023, Nov 5, 2024)
4645768 - update test_iomixin_utils.py (sichu2023, Nov 5, 2024)
b1ef6c6 - revert ModelCheckpoint move in stopandgo (pstjohn, Nov 5, 2024)
522e849 - disable ckpt_async_save (sichu2023, Nov 6, 2024)
7eb1e8a - use trainer.should_stop to interrupt training, remove uneven checks i… (pstjohn, Nov 7, 2024)
35c9bb0 - drop every_n_train_steps in ModelCheckpoint (sichu2023, Nov 7, 2024)
c9a421f - update min_lr in esm2 scheduler (sichu2023, Nov 7, 2024)
9159a25 - bump nemo version (sichu2023, Nov 7, 2024)
430c228 - mark validation stop and go test xfail (sichu2023, Nov 7, 2024)
ced60e6 - drop every_n_train_steps (sichu2023, Nov 7, 2024)
f1621e5 - update geneformer output tolerance (sichu2023, Nov 7, 2024)
a7e4f39 - ruff (sichu2023, Nov 7, 2024)
e53a5ac - bump megatron version (sichu2023, Nov 7, 2024)
1fd15ba - disable ckpt_async_save (sichu2023, Nov 8, 2024)
fb81158 - revert to original loss thresholds from main (pstjohn, Nov 8, 2024)

Changes from all commits

2 changes: 1 addition & 1 deletion 3rdparty/Megatron-LM
Submodule Megatron-LM updated 576 files
2 changes: 1 addition & 1 deletion 3rdparty/NeMo
Submodule NeMo updated 633 files
2 changes: 1 addition & 1 deletion Dockerfile
@@ -17,7 +17,7 @@ RUN git clone https://github.com/NVIDIA/apex.git && \
--config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam --group_norm"

# Transformer Engine pre-1.7.0. 1.7 standardizes the meaning of bits in the attention mask to match
ARG TE_COMMIT=7d576ed25266a17a7b651f2c12e8498f67e0baea
ARG TE_COMMIT=c27ee60ec746210bcea4ec33958dbbff06706506
Collaborator comment: Attempting to split this as a separate PR in #399.

RUN git clone https://github.com/NVIDIA/TransformerEngine.git && \
cd TransformerEngine && \
git fetch origin ${TE_COMMIT} && \
1 change: 0 additions & 1 deletion docs/docs/user-guide/examples/bionemo-esm2/pretrain.md
@@ -253,7 +253,6 @@ checkpoint_callback = nl_callbacks.ModelCheckpoint(
save_last=True,
monitor="val_loss",
save_top_k=1,
every_n_train_steps=100,
always_save_context=True,
)

2 changes: 1 addition & 1 deletion scripts/gpt-pretrain.py
@@ -187,7 +187,7 @@ def main() -> None:
devices, seq_length = 1, 2048

strategy = nl.MegatronStrategy(
tensor_model_parallel_size=1, pipeline_model_parallel_size=1, pipeline_dtype=torch.float32
tensor_model_parallel_size=1, pipeline_model_parallel_size=1, pipeline_dtype=torch.float32, ckpt_async_save=False,
)
trainer = nl.Trainer(
devices=devices,
5 changes: 3 additions & 2 deletions scripts/protein/esm2/esm2_pretrain.py
@@ -143,6 +143,7 @@ def main(
ddp="megatron",
find_unused_parameters=True,
ckpt_include_optimizer=True,
ckpt_async_save=False,
)

# for wandb integration
@@ -243,10 +244,10 @@ def main(
# Configure our custom Checkpointer
checkpoint_callback = nl_callbacks.ModelCheckpoint(
save_last=save_last_checkpoint,
monitor=metric_to_monitor_for_checkpoints, # "val_loss",
monitor=metric_to_monitor_for_checkpoints,
save_top_k=save_top_k,
every_n_train_steps=val_check_interval,
always_save_context=True, # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
filename="{epoch}-{step}-{" + metric_to_monitor_for_checkpoints + ":.2f}",
)

# Setup the logger and train the model
4 changes: 1 addition & 3 deletions scripts/protein/esm2/test_esm2_pretrain.py
@@ -30,7 +30,6 @@
from bionemo.testing import megatron_parallel_state_utils


@pytest.mark.skip("duplicate unittest")
@pytest.fixture
def dummy_protein_dataset(tmp_path):
"""Create a mock protein dataset."""
@@ -62,7 +61,6 @@ def dummy_protein_dataset(tmp_path):
return db_file


@pytest.mark.skip("duplicate unittest")
@pytest.fixture
def dummy_parquet_train_val_inputs(tmp_path):
Collaborator comment: Now that you're at it, you may want to consider importing these fixtures from bionemo.testing.

"""Create a mock protein train and val cluster parquet."""
@@ -104,7 +102,7 @@ def test_main_runs(monkeypatch, tmpdir, dummy_protein_dataset, dummy_parquet_tra
result_dir=result_dir,
wandb_project=None,
wandb_offline=True,
num_steps=55,
num_steps=10,
warmup_steps=5,
limit_val_batches=1,
val_check_interval=1,
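Following up on the reviewer suggestion above about importing these fixtures from bionemo.testing: a minimal sketch of what the shared-fixture approach could look like, assuming a hypothetical bionemo.testing fixtures module. The module path, table schema, and sequences below are illustrative and not the actual bionemo.testing API.

```python
# Hypothetical shared module, e.g. bionemo/testing/fixtures.py (illustrative only).
import sqlite3

import pytest


@pytest.fixture
def dummy_protein_dataset(tmp_path):
    """Create a tiny mock protein database on disk and return its path."""
    db_file = tmp_path / "protein_dataset.db"
    with sqlite3.connect(db_file) as conn:
        conn.execute("CREATE TABLE protein (id TEXT PRIMARY KEY, sequence TEXT)")
        conn.executemany(
            "INSERT INTO protein VALUES (?, ?)",
            [
                ("UniRef90_A0A000", "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"),
                ("UniRef90_A0A001", "ACDEFGHIKLMNPQRSTVWY"),
            ],
        )
    return db_file


# In the test package's conftest.py, the shared fixtures could then be re-exported once,
# so modules such as test_esm2_pretrain.py no longer need local copies:
# from bionemo.testing.fixtures import dummy_protein_dataset, dummy_parquet_train_val_inputs  # noqa: F401
```

Re-exporting fixtures through conftest.py keeps pytest's fixture discovery intact without changing any test signatures.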
43 changes: 30 additions & 13 deletions sub-packages/bionemo-esm2/src/bionemo/esm2/model/attention.py
@@ -30,14 +30,12 @@
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.transformer_config import TransformerConfig
from pkg_resources import packaging
from megatron.core.utils import get_te_version, is_te_min_version
from torch import Tensor


__all__: Sequence[str] = ("ESM2DotProductAttention", "ESM2TEDotProductAttention")

from megatron.core.extensions.transformer_engine import _te_version


class ESM2TEDotProductAttention(TEDotProductAttention):
"""ESM2-Specific transformer engine core attention.
@@ -52,6 +50,10 @@ def __init__(
attn_mask_type: AttnMaskType,
attention_type: str,
attention_dropout: float | None = None,
softmax_scale: float = 1.0,
k_channels: int | None = None,
v_channels: int | None = None,
cp_comm_type: str = "p2p",
):
"""Initialize ESM2TEDotProductAttention."""
self.config = config
@@ -67,30 +69,35 @@
)

extra_kwargs = {}
if _te_version >= packaging.version.Version("0.11.0"):
if is_te_min_version("0.11.0"):
extra_kwargs["num_gqa_groups"] = self.config.num_query_groups
elif self.config.num_query_groups != self.config.num_attention_heads:
raise ValueError(
f"Transformer Engine v{_te_version} does not support Grouped Query Attention, "
f"Transformer Engine v{get_te_version()} does not support Grouped Query Attention, "
f"use a newer version of Transformer Engine. "
f"(num_query_groups ({self.config.num_query_groups}) != "
f"num_attention_heads ({self.config.num_attention_heads}))"
)

if _te_version >= packaging.version.Version("0.10.0"):
if is_te_min_version("0.10.0"):
extra_kwargs["attention_type"] = attention_type
# older version don't need attention_type

if _te_version > packaging.version.Version("0.12.0"):
if is_te_min_version("0.12.0", check_equality=False):
self.te_forward_mask_type = True

# Only Transformer-Engine version >= 1.0.0 supports context parallelism
if _te_version >= packaging.version.Version("1.0.0"):
if is_te_min_version("1.0.0"):
if getattr(TEDotProductAttention, "cp_stream") is None:
TEDotProductAttention.cp_stream = torch.cuda.Stream()
extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False)
extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks(check_initialized=False)
extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream
if is_te_min_version("1.10.0"):
if cp_comm_type is None:
extra_kwargs["cp_comm_type"] = "p2p"
else:
extra_kwargs["cp_comm_type"] = cp_comm_type
else:
assert (
self.config.context_parallel_size == 1
@@ -106,23 +113,33 @@

if config.window_size is not None:
# Check version
assert _te_version >= packaging.version.Version("1.2.0"), (
f"Transformer-Engine version ({str(_te_version)}) must be >= 1.2.0 to support"
"sliding window attention."
assert is_te_min_version("1.2.0"), (
f"Transformer-Engine v{get_te_version()} must be >= 1.2.0 to support" "sliding window attention."
)
extra_kwargs["window_size"] = config.window_size

if is_te_min_version("1.10.0"):
# TE 1.10.0 introduces the ability to set the different k and v channels
kv_channels = (
(k_channels, v_channels)
if k_channels is not None and v_channels is not None
else self.config.kv_channels
)
else:
kv_channels = self.config.kv_channels

extra_kwargs["softmax_scale"] = softmax_scale

super(TEDotProductAttention, self).__init__(
num_attention_heads=self.config.num_attention_heads,
kv_channels=self.config.kv_channels,
kv_channels=kv_channels,
attention_dropout=(self.config.attention_dropout if attention_dropout is None else attention_dropout),
attn_mask_type=attn_mask_type.name,
sequence_parallel=self.config.sequence_parallel,
tp_size=self.config.tensor_model_parallel_size,
get_rng_state_tracker=(get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None),
tp_group=get_tensor_model_parallel_group(check_initialized=False),
layer_number=layer_number,
softmax_scale=1.0, # TODO subclassing only changes softmax_scale from None to 1.0. Upstream to make this exposed without subclassing
**extra_kwargs,
)

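For context on the helper swap in this file: the PR replaces direct packaging.version comparisons against the private _te_version with Megatron core's get_te_version/is_te_min_version utilities (both appear in the diff above). A minimal sketch of the gating idiom, assuming megatron.core is installed; the gated values here are illustrative:

```python
from megatron.core.utils import get_te_version, is_te_min_version

extra_kwargs: dict = {}

# Equivalent of the old `_te_version >= packaging.version.Version("0.10.0")` check.
if is_te_min_version("0.10.0"):
    extra_kwargs["attention_type"] = "self"

# check_equality=False makes this a strict `>` comparison, mirroring the old
# `_te_version > packaging.version.Version("0.12.0")` check.
use_forward_mask_type = is_te_min_version("0.12.0", check_equality=False)

print(f"TE v{get_te_version()}: {extra_kwargs}, forward mask type: {use_forward_mask_type}")
```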
@@ -70,8 +70,7 @@ def train_model(
checkpoint_callback = nl_callbacks.ModelCheckpoint(
save_last=True,
save_on_train_epoch_end=True,
monitor="reduced_train_loss", # TODO find out how to get val_loss logged and use "val_loss",
every_n_train_steps=n_steps_train // 2,
monitor="val_loss",
always_save_context=True, # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
)

11 changes: 10 additions & 1 deletion sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py
@@ -324,7 +324,16 @@ class ESM2GenericConfig(BioBertConfig[ESM2ModelT, MegatronLossType]):

def __post_init__(self):
# TODO, as a validator?
"""Check compatibility between biobert_spec_option and apply_query_key_layer_scaling post initialization."""
"""Check configuration compatibility."""
# reset moe_token_dispatcher_type when variable_seq_lengths is True.
# must be performed before super().__post_init__()
if self.variable_seq_lengths and self.moe_token_dispatcher_type in ["allgather", "alltoall_seq"]:
logging.warning(
"MoE token dispatcher type 'allgather' and 'alltoall_seq' are not supported with variable sequence lengths. Setting moe_token_dispatcher_type to 'alltoall'."
)
self.moe_token_dispatcher_type = "alltoall"

# reset apply_query_key_layer_scaling based on biobert_spec_option
super().__post_init__()
if self.biobert_spec_option == BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec:
self.apply_query_key_layer_scaling = False
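The dispatcher guard above runs before super().__post_init__() so the parent configuration never sees an unsupported combination. A stripped-down illustration of that dataclass pattern; the class and field names are invented for the example and are not the bionemo config:

```python
import logging
from dataclasses import dataclass


@dataclass
class ParentConfig:
    def __post_init__(self):
        # Parent-level validation would run here.
        pass


@dataclass
class ExampleMoEConfig(ParentConfig):
    variable_seq_lengths: bool = False
    moe_token_dispatcher_type: str = "allgather"

    def __post_init__(self):
        # Coerce incompatible settings *before* the parent validates them.
        if self.variable_seq_lengths and self.moe_token_dispatcher_type in ("allgather", "alltoall_seq"):
            logging.warning("Falling back to the 'alltoall' dispatcher for variable sequence lengths.")
            self.moe_token_dispatcher_type = "alltoall"
        super().__post_init__()


config = ExampleMoEConfig(variable_seq_lengths=True)
assert config.moe_token_dispatcher_type == "alltoall"
```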
@@ -98,7 +98,7 @@ def test_tokenize_with_empty_string(tokenizer):


def test_tokenizer_serialization(tokenizer, tmp_path):
tokenizer.io_dump(tmp_path / "tokenizer")
tokenizer.io_dump(tmp_path / "tokenizer", yaml_attrs=[]) # BioNeMoESMTokenizer takes no __init__ arguments
deserialized_tokenizer = io.load(tmp_path / "tokenizer", tokenizer.__class__)

our_tokens = deserialized_tokenizer.encode("K A <mask> I S Q", add_special_tokens=False)
@@ -92,7 +92,7 @@ def setup_model(cls, mode: Mode) -> tuple[pl.LightningModule, pl.LightningDataMo
adam_beta2=0.98,
),
lr_scheduler=WarmupAnnealDecayHoldScheduler(
warmup_steps=50, max_steps=cls.num_steps, max_lr=cls.lr, min_lr=cls.lr / 10.0, anneal_percentage=0.10
warmup_steps=50, max_steps=cls.num_steps, max_lr=cls.lr, min_lr=0.0, anneal_percentage=0.10
),
)

@@ -649,8 +649,7 @@ def loss_reduction_class(self) -> Type[MegatronLossReduction]:
checkpoint_callback = nl_callbacks.ModelCheckpoint(
save_last=True,
save_on_train_epoch_end=True,
monitor="reduced_train_loss",
every_n_train_steps=25,
monitor="val_loss",
always_save_context=True, # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
)

@@ -44,8 +44,7 @@ def run_finetune(checkpoint_dir: str, name: str, directory_name: str):
checkpoint_callback = nl_callbacks.ModelCheckpoint(
save_last=True,
save_on_train_epoch_end=True,
monitor="reduced_train_loss",
every_n_train_steps=25,
monitor="val_loss",
always_save_context=True, # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
)

@@ -52,10 +52,10 @@ def _train_model_get_ckpt(
checkpoint_callback = nl_callbacks.ModelCheckpoint(
save_last=True,
save_on_train_epoch_end=True,
monitor="reduced_train_loss", # TODO find out how to get val_loss logged and use "val_loss",
every_n_train_steps=5,
monitor="val_loss",
always_save_context=True, # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
# async_save=False, # Tries to save asynchronously, previously led to race conditions.
filename="{epoch}-{step}-{val_loss:.2f}",
)
save_dir = root_dir / name
tb_logger = TensorBoardLogger(save_dir=save_dir, name=name)
@@ -311,10 +311,10 @@ def main(
# Configure our custom Checkpointer
checkpoint_callback = nl_callbacks.ModelCheckpoint(
save_last=save_last_checkpoint,
monitor=metric_to_monitor_for_checkpoints, # "val_loss",
monitor=metric_to_monitor_for_checkpoints,
save_top_k=save_top_k,
every_n_train_steps=val_check_interval,
always_save_context=True, # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
filename="{epoch}-{step}-{" + metric_to_monitor_for_checkpoints + ":.2f}",
)

# Setup the logger and train the model
@@ -37,7 +37,6 @@ def test_bionemo2_rootdir():
assert data_path.is_dir(), "Test data directory is supposed to be a directory."


@pytest.mark.skip("duplicate unittest")
def test_main_runs(tmpdir):
result_dir = Path(tmpdir.mkdir("results"))

@@ -821,8 +821,7 @@ def _train_model_get_ckpt(
checkpoint_callback = nl_callbacks.ModelCheckpoint(
save_last=True,
save_on_train_epoch_end=True,
monitor="reduced_train_loss", # TODO find out how to get val_loss logged and use "val_loss",
every_n_train_steps=n_steps_train // 2,
monitor="val_loss",
always_save_context=True, # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
)
save_dir = root_dir / name
@@ -375,7 +375,11 @@ def forward(
rotary_pos_emb = None
if self.position_embedding_type == "rope":
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
inference_params, self.encoder, encoder_input, self.config
inference_params,
self.encoder,
encoder_input,
self.config,
packed_seq_params=None, # TODO @sichu: upstream to Megatron-LM
)
rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)

1 change: 1 addition & 0 deletions sub-packages/bionemo-llm/src/bionemo/llm/train.py
@@ -109,6 +109,7 @@ def setup_trainer(
ddp="megatron",
find_unused_parameters=True,
ckpt_include_optimizer=True,
ckpt_async_save=False,
)
if callbacks is None:
callbacks = [
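Taken together, the strategy and checkpoint changes in this PR converge on a configuration along these lines. This is a summary sketch, assuming NeMo 2.0's nemo.lightning namespace; the import paths and values are taken from or inferred from the diffs above, not copied from a single file:

```python
from nemo import lightning as nl
from nemo.lightning.pytorch import callbacks as nl_callbacks

strategy = nl.MegatronStrategy(
    ddp="megatron",
    find_unused_parameters=True,
    ckpt_include_optimizer=True,
    ckpt_async_save=False,  # async checkpoint saving disabled; it previously led to race conditions
)

checkpoint_callback = nl_callbacks.ModelCheckpoint(
    save_last=True,
    monitor="val_loss",  # monitor validation loss rather than reduced_train_loss
    save_top_k=1,
    always_save_context=True,  # .nemo-style checkpointing where all IOMixins are under SerDe
    filename="{epoch}-{step}-{val_loss:.2f}",  # including step avoids checkpoint name clashes
    # every_n_train_steps is intentionally dropped; checkpointing follows val_check_interval
)
```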
@@ -145,7 +145,7 @@ def load_weights_sharded_inplace_nemo2_to_mcore(
sharded_state_dict = {
_munge_key_megatron_to_nemo2(k): _munge_sharded_tensor_key_megatron_to_nemo2(v)
for k, v in model.sharded_state_dict().items()
if not _key_in_filter(k, skip_keys_with_these_prefixes)
if not _key_in_filter(k, skip_keys_with_these_prefixes) and "_extra_state" not in k
}
dist_checkpointing.load(
sharded_state_dict=sharded_state_dict,
@@ -169,6 +169,8 @@ def forward(self, x):
return self.other(x)


# TODO rewrite unittest and potentially LightningPassthroughPredictionMixin
@pytest.mark.xfail(reason="MegatronStrategy no longer has '_get_loss_reduction' attribute")
def test_mixin_strategy_contract_get_loss_reduction():
with megatron_parallel_state_utils.clean_parallel_state_context():
strategy = nl.MegatronStrategy(
@@ -45,6 +45,12 @@ class OverrideModelDataClass2(BaseDataClass, iom.IOMixinWithGettersSetters):


class TestIOMixin:
"""TestCase on IOMixin.

Notes:
IOMixin only captures non-default __init__ arguments into self.__io__, which avoids compatibility issues when loading older mcore configs in newer versions.
"""

def test_dataclasses_two_versions(self):
_ = OverrideModelDataClass1(b=2)
v1 = OverrideModelDataClass2(b=4)
@@ -80,10 +86,10 @@ def test_dataclass_out_of_sync(self):
with pytest.raises(KeyError):
v1.get_hparam("q")

# Make sure we can get all hyper-parameters that are not defaultfactory objects
assert v1.get_hparams() == {"b": 7, "c": 3}
# Make sure we can get all hyper-parameters that are **non-default** non-defaultfactory objects
assert v1.get_hparams() == {"b": 7}

# Make sure by default we can change botht he hyper-parameter and the attribute.
# Make sure by default we can change both the hyper-parameter and the attribute.
v1_copy.set_hparam("b", 8)
assert v1_copy.b == 8
assert v1_copy.get_hparam("b") == 8
@@ -92,8 +98,8 @@ def test_dataclass_hparam_modify_parent_default(self):
v1 = OverrideModelDataClass1()
v1.set_hparam("a", 7)
assert v1.a == 7
# Make sure we can get all hyper-parameters
assert v1.get_hparams() == {"a": 7, "b": 3, "c": 3}
# Make sure we can get all **non-default** **non-defaultfactory** hyper-parameters
assert v1.get_hparams() == {"a": 7}

v1_copy = io.reinit(v1)
assert v1_copy.a == 7, "V1 should re-initialize with the updated hyper-parameter."