Skip to content

Commit

Permalink
fix: Use best model checkpoint if available by default on lightning t…
Browse files Browse the repository at this point in the history
…ask test (#39)

* fix: Use best model checkpoint if available by default on lightning task test

* build: Bump version 1.1.1 -> 1.1.2

* docs: Update readme

Approved by @rcmalli
  • Loading branch information
lorenzomammana authored Jul 28, 2023
1 parent 3fe8008 commit 1de4daa
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 4 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@
# Changelog
All notable changes to this project will be documented in this file.

### [1.1.2]

#### Fixed

- Fix best checkpoint not used for testing when available in `LightningTask` class.

### [1.1.1]

#### Fixed
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "quadra"
version = "1.1.1"
version = "1.1.2"
description = "Deep Learning experiment orchestration library"
authors = [
{name = "Alessandro Polidori", email = "alessandro.polidori@orobix.com"},
Expand Down Expand Up @@ -117,7 +117,7 @@ repository = "https://github.com/orobix/quadra"

# Adapted from https://realpython.com/pypi-publish-python-package/#version-your-package
[tool.bumpver]
current_version = "1.1.1"
current_version = "1.1.2"
version_pattern = "MAJOR.MINOR.PATCH"
commit_message = "build: Bump version {old_version} -> {new_version}"
commit = true
Expand Down
2 changes: 1 addition & 1 deletion quadra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "1.1.1"
__version__ = "1.1.2"


def get_version():
Expand Down
17 changes: 16 additions & 1 deletion quadra/tasks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,22 @@ def train(self) -> None:
def test(self) -> Any:
"""Test the model."""
log.info("Starting testing!")
return self.trainer.test(model=self.module, datamodule=self.datamodule)

best_model = None
if (
self.trainer.checkpoint_callback is not None
and hasattr(self.trainer.checkpoint_callback, "best_model_path")
and self.trainer.checkpoint_callback.best_model_path is not None
):
best_model = self.trainer.checkpoint_callback.best_model_path

if best_model is None:
log.warning(
"No best checkpoint model found, using last weights for test, this might lead to worse results, "
"consider using a checkpoint callback."
)

return self.trainer.test(model=self.module, datamodule=self.datamodule, ckpt_path=best_model)

def finalize(self) -> None:
"""Finalize the experiment."""
Expand Down

0 comments on commit 1de4daa

Please sign in to comment.