Skip to content

Commit

Permalink
fix test_main_runs for esm2 and geneformer
Browse files Browse the repository at this point in the history
  • Loading branch information
sichu2023 committed Nov 5, 2024
1 parent abbf4bd commit 42138d2
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 9 deletions.
4 changes: 2 additions & 2 deletions scripts/protein/esm2/esm2_pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,11 +243,11 @@ def main(
# Configure our custom Checkpointer
checkpoint_callback = nl_callbacks.ModelCheckpoint(
save_last=save_last_checkpoint,
monitor=metric_to_monitor_for_checkpoints, # "val_loss",
monitor=metric_to_monitor_for_checkpoints,
save_top_k=save_top_k,
every_n_train_steps=val_check_interval,
always_save_context=True, # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
filename="{epoch}-{step}-{val_loss:.2f}",
filename="{epoch}-{step}-{" + metric_to_monitor_for_checkpoints + ":.2f}",
)

# Setup the logger and train the model
Expand Down
4 changes: 1 addition & 3 deletions scripts/protein/esm2/test_esm2_pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from bionemo.testing import megatron_parallel_state_utils


@pytest.mark.skip("duplicate unittest")
@pytest.fixture
def dummy_protein_dataset(tmp_path):
"""Create a mock protein dataset."""
Expand Down Expand Up @@ -62,7 +61,6 @@ def dummy_protein_dataset(tmp_path):
return db_file


@pytest.mark.skip("duplicate unittest")
@pytest.fixture
def dummy_parquet_train_val_inputs(tmp_path):
"""Create a mock protein train and val cluster parquet."""
Expand Down Expand Up @@ -104,7 +102,7 @@ def test_main_runs(monkeypatch, tmpdir, dummy_protein_dataset, dummy_parquet_tra
result_dir=result_dir,
wandb_project=None,
wandb_offline=True,
num_steps=55,
num_steps=10,
warmup_steps=5,
limit_val_batches=1,
val_check_interval=1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _train_model_get_ckpt(
every_n_train_steps=5,
always_save_context=True, # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
# async_save=False, # Tries to save asynchronously, previously led to race conditions.
filename="{epoch}-{step}-{val_loss:.2f}"
filename="{epoch}-{step}-{val_loss:.2f}",
)
save_dir = root_dir / name
tb_logger = TensorBoardLogger(save_dir=save_dir, name=name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,11 +311,11 @@ def main(
# Configure our custom Checkpointer
checkpoint_callback = nl_callbacks.ModelCheckpoint(
save_last=save_last_checkpoint,
monitor=metric_to_monitor_for_checkpoints, # "val_loss",
monitor=metric_to_monitor_for_checkpoints,
save_top_k=save_top_k,
every_n_train_steps=val_check_interval,
always_save_context=True, # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
# filename="{epoch}-{step}-{val_loss:.2f}",
filename="{epoch}-{step}-{" + metric_to_monitor_for_checkpoints + ":.2f}",
)

# Setup the logger and train the model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def test_bionemo2_rootdir():
assert data_path.is_dir(), "Test data directory is supposed to be a directory."


@pytest.mark.skip("duplicate unittest")
def test_main_runs(tmpdir):
result_dir = Path(tmpdir.mkdir("results"))

Expand Down

0 comments on commit 42138d2

Please sign in to comment.