Commit abbf4bd
add step in checkpoint_dir to avoid name clashing
sichu2023 committed Nov 5, 2024
1 parent 5e82404 commit abbf4bd
Showing 4 changed files with 6 additions and 2 deletions.
1 change: 1 addition & 0 deletions scripts/protein/esm2/esm2_pretrain.py
@@ -247,6 +247,7 @@ def main(
         save_top_k=save_top_k,
         every_n_train_steps=val_check_interval,
         always_save_context=True, # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
+        filename="{epoch}-{step}-{val_loss:.2f}",
     )

     # Setup the logger and train the model
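A minimal sketch of how the new filename template behaves, shown with plain PyTorch Lightning's ModelCheckpoint rather than the NeMo nl_callbacks.ModelCheckpoint used above (the interval is a hypothetical value for illustration): adding {step} keeps periodic checkpoints written during the same epoch from clashing on one name.

from lightning.pytorch.callbacks import ModelCheckpoint

# Sketch only: illustrates the filename template, not the exact NeMo callback above.
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    save_top_k=2,
    every_n_train_steps=100,  # hypothetical checkpointing interval
    filename="{epoch}-{step}-{val_loss:.2f}",
)
# Successive saves within one epoch now get distinct names, e.g.
# epoch=0-step=100-val_loss=2.31.ckpt and epoch=0-step=200-val_loss=2.05.ckpt
# (Lightning inserts the metric names by default) instead of colliding.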
@@ -52,10 +52,11 @@ def _train_model_get_ckpt(
     checkpoint_callback = nl_callbacks.ModelCheckpoint(
         save_last=True,
         save_on_train_epoch_end=True,
-        monitor="reduced_train_loss", # TODO find out how to get val_loss logged and use "val_loss",
+        monitor="val_loss",
         every_n_train_steps=5,
         always_save_context=True, # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
         # async_save=False, # Tries to save asynchronously, previously led to race conditions.
+        filename="{epoch}-{step}-{val_loss:.2f}"
     )
     save_dir = root_dir / name
     tb_logger = TensorBoardLogger(save_dir=save_dir, name=name)
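Switching monitor from "reduced_train_loss" to "val_loss" (and formatting it into the filename) only works if a "val_loss" metric is actually logged during validation. A hedged sketch of that wiring with a hypothetical LightningModule, not the module this test uses:

import torch
import lightning.pytorch as pl


class TinyRegressor(pl.LightningModule):
    """Hypothetical module, only to show where the "val_loss" key comes from."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        # This logged key is what monitor="val_loss" and the {val_loss:.2f}
        # placeholder in the checkpoint filename read.
        self.log("val_loss", loss, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)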
@@ -315,6 +315,7 @@ def main(
         save_top_k=save_top_k,
         every_n_train_steps=val_check_interval,
         always_save_context=True, # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
+        # filename="{epoch}-{step}-{val_loss:.2f}",
     )

     # Setup the logger and train the model
@@ -219,10 +219,11 @@ def make_callbacks() -> Dict[Type[pl.Callback], pl.Callback]:
         testing_callbacks.ValidLossCallback: testing_callbacks.ValidLossCallback(),
         nl_callbacks.ModelCheckpoint: nl_callbacks.ModelCheckpoint(
             save_last=True,
-            monitor="reduced_train_loss",
+            monitor="val_loss",
             save_top_k=2,
             every_n_train_steps=cls.val_check_interval,
             always_save_context=True,
+            filename="{epoch}-{step}-{val_loss:.2f}",
         ),
     }

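Formatting val_loss into the checkpoint name (or monitoring it) requires that validation runs often enough to produce a fresh value at each save; the pretrain hunks above set every_n_train_steps=val_check_interval, which keeps the two cadences aligned. A hedged sketch of that wiring in plain PyTorch Lightning, with hypothetical numbers:

import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint

val_check_interval = 100  # hypothetical value
trainer = pl.Trainer(
    max_steps=1000,
    val_check_interval=val_check_interval,  # run validation every 100 training steps
    callbacks=[
        ModelCheckpoint(
            monitor="val_loss",
            save_top_k=2,
            every_n_train_steps=val_check_interval,  # checkpoint on the same cadence
            filename="{epoch}-{step}-{val_loss:.2f}",
        )
    ],
)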
