From 4a0cd8d05e70dd5f28372fe2014c3afa6ff5c775 Mon Sep 17 00:00:00 2001 From: Alessandro Polidori <61737239+AlessandroPolidori@users.noreply.github.com> Date: Wed, 20 Sep 2023 17:54:44 +0200 Subject: [PATCH] feat: Align classification results with sklearn classification ones (#63) * feat: Align classification results with sklearn classification ones * build: Bump version 1.2.3 -> 1.2.4 * docs: Add changelog * feat: Print also task config during experiment starting phase (#62) * fix: Add->Added in changelog Approved By: @lorenzomammana --- CHANGELOG.md | 5 +++++ pyproject.toml | 4 ++-- quadra/__init__.py | 2 +- quadra/modules/classification/base.py | 2 +- quadra/tasks/classification.py | 9 ++++++--- quadra/utils/utils.py | 1 + 6 files changed, 16 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 13136b07..80387391 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,11 @@ # Changelog All notable changes to this project will be documented in this file. +### [1.2.4] +#### Added + +- Return also probabilities in Classification's module predict step and add them to `self.res`. + ### [1.2.3] diff --git a/pyproject.toml b/pyproject.toml index 7ba246d2..e6f70947 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "quadra" -version = "1.2.3" +version = "1.2.4" description = "Deep Learning experiment orchestration library" authors = [ { name = "Alessandro Polidori", email = "alessandro.polidori@orobix.com" }, @@ -118,7 +118,7 @@ repository = "https://github.com/orobix/quadra" # Adapted from https://realpython.com/pypi-publish-python-package/#version-your-package [tool.bumpver] -current_version = "1.2.3" +current_version = "1.2.4" version_pattern = "MAJOR.MINOR.PATCH" commit_message = "build: Bump version {old_version} -> {new_version}" commit = true diff --git a/quadra/__init__.py b/quadra/__init__.py index 0837dde2..639b40c2 100644 --- a/quadra/__init__.py +++ b/quadra/__init__.py @@ -1,4 +1,4 @@ -__version__ = "1.2.3" +__version__ = "1.2.4" def get_version(): diff --git a/quadra/modules/classification/base.py b/quadra/modules/classification/base.py index affc7bbf..d1678155 100644 --- a/quadra/modules/classification/base.py +++ b/quadra/modules/classification/base.py @@ -219,7 +219,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A grayscale_cam = ndimage.zoom(grayscale_cam_low_res, zoom_factors, order=1) else: grayscale_cam = None - return predicted_classes, grayscale_cam + return predicted_classes, grayscale_cam, torch.max(probs, dim=1)[0].tolist() class MultilabelClassificationModule(BaseLightningModule): diff --git a/quadra/tasks/classification.py b/quadra/tasks/classification.py index 3ff6c356..b43707dd 100644 --- a/quadra/tasks/classification.py +++ b/quadra/tasks/classification.py @@ -311,10 +311,12 @@ def generate_report(self) -> None: log.warning("There is no prediction to generate the report. Skipping report generation.") return all_outputs = [x[0] for x in predictions_outputs] - if not all_outputs: + all_probs = [x[2] for x in predictions_outputs] + if not all_outputs or not all_probs: log.warning("There is no prediction to generate the report. Skipping report generation.") return all_outputs = [item for sublist in all_outputs for item in sublist] + all_probs = [item for sublist in all_probs for item in sublist] all_targets = [target.tolist() for im, target in self.datamodule.test_dataloader()] all_targets = [item for sublist in all_targets for item in sublist] @@ -335,16 +337,17 @@ def generate_report(self) -> None: output_folder_test = "test" test_dataloader = self.datamodule.test_dataloader() test_dataset = cast(ImageClassificationListDataset, test_dataloader.dataset) - res = pd.DataFrame( + self.res = pd.DataFrame( { "sample": list(test_dataset.x), "real_label": all_targets, "pred_label": all_outputs, + "probability": all_probs, } ) os.makedirs(output_folder_test, exist_ok=True) save_classification_result( - results=res, + results=self.res, output_folder=output_folder_test, confmat=self.report_confmat, accuracy=accuracy, diff --git a/quadra/utils/utils.py b/quadra/utils/utils.py index bc226933..ac5ed02c 100644 --- a/quadra/utils/utils.py +++ b/quadra/utils/utils.py @@ -84,6 +84,7 @@ def extras(config: DictConfig) -> None: def print_config( config: DictConfig, fields: Sequence[str] = ( + "task", "trainer", "model", "datamodule",