Fix geneformer training instability bug #421

Merged (14 commits) on Nov 12, 2024
8 changes: 6 additions & 2 deletions Dockerfile
@@ -166,7 +166,9 @@ RUN <<EOF
EOF

# Transformer engine attention defaults
-ENV NVTE_FUSED_ATTN=1 NVTE_FLASH_ATTN=0
+# FIXME the following result in unstable training curves even if they are faster
+# see https://github.com/NVIDIA/bionemo-framework/pull/421
+#ENV NVTE_FUSED_ATTN=1 NVTE_FLASH_ATTN=0

FROM dev AS development

@@ -207,4 +209,6 @@ RUN chmod 777 -R /workspace/bionemo2/

# Transformer engine attention defaults
# We have to declare this again because the devcontainer splits from the release image's base.
-ENV NVTE_FUSED_ATTN=1 NVTE_FLASH_ATTN=0
+# FIXME the following results in unstable training curves even if faster.
+# See https://github.com/NVIDIA/bionemo-framework/pull/421
+#ENV NVTE_FUSED_ATTN=1 NVTE_FLASH_ATTN=0
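The two NVTE variables are left in the Dockerfile only as a commented-out reference. For anyone who still wants to reproduce the faster but unstable fused-attention behavior in a one-off run, here is a minimal sketch, assuming only that Transformer Engine reads these variables before selecting its attention backend (as the removed image defaults did):

```python
import os

# Hypothetical opt-in helper: restore the faster Transformer Engine settings that this
# PR removed from the image defaults because they produced unstable training curves.
# Set these before Transformer Engine chooses its attention backend.
def enable_fused_attention_defaults() -> None:
    os.environ["NVTE_FUSED_ATTN"] = "1"  # prefer fused attention kernels
    os.environ["NVTE_FLASH_ATTN"] = "0"  # disable the flash-attention backend

enable_fused_attention_defaults()
```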
6 changes: 0 additions & 6 deletions README.md
@@ -186,9 +186,6 @@ export MY_DATA_SOURCE="pbss"

```bash
-# The fastest transformer engine environment variables in testing were the following two
-export NVTE_FUSED_ATTN=1
-export NVTE_FLASH_ATTN=0

TEST_DATA_DIR=$(download_bionemo_data esm2/testdata_esm2_pretrain:2.0 --source $MY_DATA_SOURCE); \
ESM2_650M_CKPT=$(download_bionemo_data esm2/650m:2.0 --source $MY_DATA_SOURCE); \
python \
@@ -248,9 +245,6 @@ and DataModule types.
> ⚠️ **Warning:** This setup does NO configuration of Weights and Biases. Edit your config JSON and populate it with your WandB details.

```
-export NVTE_FUSED_ATTN=1
-export NVTE_FLASH_ATTN=0
-
bionemo-esm2-train \
--data-config-t bionemo.esm2.run.config_models.ESM2DataConfig \
--model-config-t bionemo.esm2.run.config_models.ExposedESM2PretrainConfig \
3 changes: 0 additions & 3 deletions docs/docs/user-guide/examples/bionemo-esm2/pretrain.md
@@ -280,9 +280,6 @@ llm.train(
Or simply call `esm2_pretrain.py` directly.
```bash
-# Enable fused attention in transformer engine for speed-up
-export NVTE_FUSED_ATTN=1
-export NVTE_FLASH_ATTN=0

DATA_DIR=$(download_bionemo_data esm2/testdata_esm2_pretrain:2.0 --source ngc)

python scripts/protein/esm2/esm2_pretrain.py \
7 changes: 0 additions & 7 deletions scripts/protein/esm2/test_esm2_pretrain.py
@@ -90,8 +90,6 @@ def test_main_runs(monkeypatch, tmpdir, dummy_protein_dataset, dummy_parquet_tra
result_dir = Path(tmpdir.mkdir("results"))

with megatron_parallel_state_utils.distributed_model_parallel_state():
-monkeypatch.setenv("NVTE_FUSED_ATTN", "1")
-monkeypatch.setenv("NVTE_FLASH_ATTN", "0")
main(
train_cluster_path=train_cluster_path,
train_database_path=dummy_protein_dataset,
@@ -159,8 +157,6 @@ def test_val_dataloader_in_main_runs_with_limit_val_batches(
result_dir = Path(tmpdir.mkdir("results"))

with megatron_parallel_state_utils.distributed_model_parallel_state():
-monkeypatch.setenv("NVTE_FUSED_ATTN", "1")
-monkeypatch.setenv("NVTE_FLASH_ATTN", "0")
main(
train_cluster_path=train_cluster_path,
train_database_path=dummy_protein_dataset,
@@ -239,9 +235,6 @@ def test_pretrain_cli(tmpdir, dummy_protein_dataset, dummy_parquet_train_val_inp
# a local copy of the environment
env = dict(**os.environ)
env["MASTER_PORT"] = str(open_port)
env["NVTE_FUSED_ATTN"] = "1"
env["NVTE_FLASH_ATTN"] = "0"

cmd = shlex.split(cmd_str)
result = subprocess.run(
cmd,
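With the env pinning removed, these tests now run against whatever attention backend Transformer Engine picks by default. If someone needs to compare against the old behavior locally, a hedged sketch of re-pinning the variables for a single test via pytest's monkeypatch fixture (the fixture name is hypothetical):

```python
import pytest

@pytest.fixture
def legacy_nvte_attention(monkeypatch):
    # Illustrative only: re-pin the settings this PR removed from the tests.
    monkeypatch.setenv("NVTE_FUSED_ATTN", "1")
    monkeypatch.setenv("NVTE_FLASH_ATTN", "0")
    yield
```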
@@ -97,6 +97,7 @@ class TestGeneformerStopAndGo(stop_and_go.StopAndGoHarness):
limit_val_batches: int = 2
lr: float = 1e-4
precision: Literal["16-mixed", "bf16-mixed", "32"] = MODEL_PRECISION
+train_val_output_atol: float = 2e-2

@override
@classmethod
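The new `train_val_output_atol` attribute lets the Geneformer harness accept a looser match (2e-2) between the interrupted and continuous runs, reflecting the noisier training curves this PR addresses. A minimal sketch of the override pattern, assuming `stop_and_go` is the harness module these tests already import:

```python
from typing import Literal

class TestMyModelStopAndGo(stop_and_go.StopAndGoHarness):  # hypothetical model harness
    # Per-model tolerances: train/val outputs may drift more across an interruption
    # than other callback outputs, so they get their own absolute tolerance.
    train_val_output_atol: float = 2e-2
    other_output_atol: float = 1e-4
    precision: Literal["16-mixed", "bf16-mixed", "32"] = "bf16-mixed"
```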
@@ -525,8 +525,7 @@ def configure_model(self, tokenizer: AutoTokenizer) -> MegatronBioBertModelType:
self.num_layers // p_size
) % vp_size == 0, "Make sure the number of model chunks is the same across all pipeline stages."

-# The local specs all require the standard full attention mask. For transformer engine only the NVTE_FLASH_ATTN=0
-# option requires this full attention mask.
+# The local specs all require the standard full attention mask.
use_full_attention_mask: bool = "transformer_engine" not in self.biobert_spec_option
do_next_sentence = False
if self.model_cls is None:
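The updated comment decouples the full-attention-mask requirement from NVTE_FLASH_ATTN: the mask is now needed exactly when a local (non-Transformer-Engine) layer spec is chosen. A small sketch of that check with illustrative spec names (the real options are enum members in the framework):

```python
def needs_full_attention_mask(biobert_spec_option: str) -> bool:
    # Local specs build attention without Transformer Engine and need the standard
    # full attention mask; TE specs handle masking internally.
    return "transformer_engine" not in biobert_spec_option

# Illustrative spec names only.
assert needs_full_attention_mask("esm2_bert_layer_local_spec") is True
assert needs_full_attention_mask("bert_layer_with_transformer_engine_spec") is False
```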
@@ -106,6 +106,8 @@ class StopAndGoHarness(ABC):
limit_val_batches: int
lr: float = 1e-4
precision: Literal["16-mixed", "bf16-mixed", "32"]
+train_val_output_atol: float = 1e-3
+other_output_atol: float = 1e-4

# class variables that will be setup in setUpClass
tempdir: tempfile.TemporaryDirectory
@@ -336,9 +338,9 @@ def test_stop_and_go_consistency(self, callback_type):
assert interrupted_callback.data, f"No data found for {callback_type}"

if callback_type == testing_callbacks.TrainOutputCallback:
-atol = 1e-3
+atol = self.train_val_output_atol
else:
-atol = 1e-4
+atol = self.other_output_atol

recursive_assert_approx_equal(interrupted_callback.data, continuous_callback.data, atol=atol)

@@ -388,8 +390,8 @@ def test_stop_and_go_consistency_with_uneven_validation_sizes(self, callback_typ
interrupted_data = interrupted_callback.data[-len(continuous_callback.data) :]

if callback_type == testing_callbacks.ValidOutputCallback:
-atol = 1e-3
+atol = self.train_val_output_atol
else:
-atol = 1e-4
+atol = self.other_output_atol

recursive_assert_approx_equal(interrupted_data, continuous_callback.data, atol=atol)
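Both consistency tests now pull their tolerance from the class attributes rather than hard-coded literals, so a model-specific harness like the Geneformer one above can relax only the train/val output comparison. To make the role of `atol` concrete, here is a hypothetical sketch of a recursive comparison in the spirit of `recursive_assert_approx_equal` (the real helper lives elsewhere in the framework and may differ):

```python
import torch

def recursive_assert_approx_equal_sketch(actual, expected, atol: float = 1e-4) -> None:
    # Hypothetical illustration: walk nested dicts/lists and compare tensor leaves
    # with an absolute tolerance, mirroring how the harness applies atol.
    if isinstance(actual, dict):
        assert actual.keys() == expected.keys()
        for key in actual:
            recursive_assert_approx_equal_sketch(actual[key], expected[key], atol=atol)
    elif isinstance(actual, (list, tuple)):
        assert len(actual) == len(expected)
        for a, e in zip(actual, expected):
            recursive_assert_approx_equal_sketch(a, e, atol=atol)
    elif isinstance(actual, torch.Tensor):
        torch.testing.assert_close(actual, expected, atol=atol, rtol=0.0)
    else:
        assert actual == expected
```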