From 1640df6f048217553d189962490aa196b8960ad3 Mon Sep 17 00:00:00 2001
From: Lorenzo Mammana
Date: Tue, 9 Jul 2024 11:55:15 +0200
Subject: [PATCH] fix: Improve sklearn automatic batch size (#125)

* refactor: Improve safety of automatic batch size computation

* build: Upgrade version, update changelog
---
 CHANGELOG.md                   |  4 +++
 pyproject.toml                 |  2 +-
 quadra/__init__.py             |  2 +-
 quadra/tasks/classification.py | 16 ++++-----
 quadra/utils/evaluation.py     | 59 +++++++++++++++++++++++-----------
 5 files changed, 53 insertions(+), 30 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 122eea12..5b753d1a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,10 @@
 # Changelog
 All notable changes to this project will be documented in this file.
 
+### [2.1.13]
+
+- Improve safe batch size computation for sklearn based classification tasks
+
 ### [2.1.12]
 
 #### Fixed
diff --git a/pyproject.toml b/pyproject.toml
index 0fd7fe3b..8690288c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "quadra"
-version = "2.1.12"
+version = "2.1.13"
 description = "Deep Learning experiment orchestration library"
 authors = [
     "Federico Belotti ",
diff --git a/quadra/__init__.py b/quadra/__init__.py
index 67acf36f..02993008 100644
--- a/quadra/__init__.py
+++ b/quadra/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "2.1.12"
+__version__ = "2.1.13"
 
 
 def get_version():
diff --git a/quadra/tasks/classification.py b/quadra/tasks/classification.py
index e885c2c6..b646d391 100644
--- a/quadra/tasks/classification.py
+++ b/quadra/tasks/classification.py
@@ -41,7 +41,6 @@
 from quadra.trainers.classification import SklearnClassificationTrainer
 from quadra.utils import utils
 from quadra.utils.classification import (
-    automatic_batch_size_computation,
     get_results,
     save_classification_result,
 )
@@ -539,15 +538,6 @@ def prepare(self) -> None:
         self.datamodule.prepare_data()
         self.datamodule.setup(stage="fit")
 
-        if not self.automatic_batch_size.disable and self.device != "cpu":
-            self.datamodule.batch_size = automatic_batch_size_computation(
-                datamodule=self.datamodule,
-                backbone=self.backbone,
-                starting_batch_size=self.automatic_batch_size.starting_batch_size,
-            )
-
-        self.train_dataloader_list = list(self.datamodule.train_dataloader())
-        self.test_dataloader_list = list(self.datamodule.val_dataloader())
         self.trainer = self.config.trainer
 
     @property
@@ -601,6 +591,7 @@ def trainer(self, trainer_config: DictConfig) -> None:
         self._trainer = trainer
 
     @typing.no_type_check
+    @automatic_datamodule_batch_size(batch_size_attribute_name="batch_size")
     def train(self) -> None:
         """Train the model."""
         log.info("Starting training...!")
@@ -609,6 +600,9 @@ def train(self) -> None:
 
         class_to_keep = None
 
+        self.train_dataloader_list = list(self.datamodule.train_dataloader())
+        self.test_dataloader_list = list(self.datamodule.val_dataloader())
+
         if hasattr(self.datamodule, "class_to_keep_training") and self.datamodule.class_to_keep_training is not None:
             class_to_keep = self.datamodule.class_to_keep_training
 
@@ -729,6 +723,7 @@ def extract_model_summary(
 
                 break
 
+    @automatic_datamodule_batch_size(batch_size_attribute_name="batch_size")
     def train_full_data(self):
         """Train the model on train + validation."""
         # Reinit classifier
@@ -743,6 +738,7 @@ def test(self) -> None:
         # train module to handle cross validation
 
     @typing.no_type_check
+    @automatic_datamodule_batch_size(batch_size_attribute_name="batch_size")
     def test_full_data(self) -> None:
         """Test model trained on full dataset."""
         self.config.datamodule.class_to_idx = self.datamodule.full_dataset.class_to_idx
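The classification task above now relies entirely on the `automatic_datamodule_batch_size` decorator instead of an upfront `automatic_batch_size_computation` call, and it builds its train/validation dataloader lists inside the decorated `train` method rather than in `prepare`. A minimal plain-PyTorch sketch of why the loaders have to be rebuilt after the decorator lowers `datamodule.batch_size` (the `SimpleNamespace` datamodule here is an illustrative stand-in, not quadra's datamodule API):

```python
from types import SimpleNamespace

import torch
from torch.utils.data import DataLoader, TensorDataset

# A DataLoader snapshots batch_size at construction time, so a loader built in
# prepare() would keep the original value even after an OOM retry shrinks
# datamodule.batch_size; rebuilding it inside the decorated train() picks up
# the reduced value.
datamodule = SimpleNamespace(batch_size=32)  # illustrative stand-in, not quadra's datamodule
dataset = TensorDataset(torch.zeros(64, 3))

stale_loader = DataLoader(dataset, batch_size=datamodule.batch_size)
datamodule.batch_size = 8  # what the decorator does after an out-of-memory failure
fresh_loader = DataLoader(dataset, batch_size=datamodule.batch_size)

print(next(iter(stale_loader))[0].shape[0])  # 32 -> still the old batch size
print(next(iter(fresh_loader))[0].shape[0])  # 8  -> rebuilt loader sees the reduction
```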
diff --git a/quadra/utils/evaluation.py b/quadra/utils/evaluation.py
index 35a3e04a..fbb2deb7 100644
--- a/quadra/utils/evaluation.py
+++ b/quadra/utils/evaluation.py
@@ -418,6 +418,28 @@ def decorator(func: Callable):
         def wrapper(self, *args, **kwargs):
             """Wrapper function."""
             is_func_finished = False
+            starting_batch_size = None
+            automatic_batch_size_completed = False
+
+            if hasattr(self, "automatic_batch_size_completed"):
+                automatic_batch_size_completed = self.automatic_batch_size_completed
+
+            if hasattr(self, "automatic_batch_size"):
+                if not hasattr(self.automatic_batch_size, "disable") or not hasattr(
+                    self.automatic_batch_size, "starting_batch_size"
+                ):
+                    raise ValueError(
+                        "The automatic_batch_size attribute should have the disable and starting_batch_size attributes"
+                    )
+                starting_batch_size = (
+                    self.automatic_batch_size.starting_batch_size if not self.automatic_batch_size.disable else None
+                )
+
+            if starting_batch_size is not None and not automatic_batch_size_completed:
+                # If we already tried to reduce the batch size, we will start from the last batch size
+                log.info("Performing automatic batch size scaling from %d", starting_batch_size)
+                setattr(self.datamodule, batch_size_attribute_name, starting_batch_size)
+
             while not is_func_finished:
                 valid_exceptions = (RuntimeError,)
 
@@ -426,25 +448,26 @@ def wrapper(self, *args, **kwargs):
 
                 try:
                     func(self, *args, **kwargs)
+                    is_func_finished = True
+                    self.automatic_batch_size_completed = True
+                    if torch.cuda.is_available():
+                        torch.cuda.empty_cache()
                 except valid_exceptions as e:
-                    if "out of memory" in str(e) or "Failed to allocate memory" in str(e):
-                        current_batch_size = getattr(self.datamodule, batch_size_attribute_name)
-                        setattr(self.datamodule, batch_size_attribute_name, current_batch_size // 2)
-                        log.warning(
-                            "The function %s went out of memory, trying to reduce the batch size to %d",
-                            func.__name__,
-                            self.datamodule.batch_size,
-                        )
-
-                        if self.datamodule.batch_size == 0:
-                            raise RuntimeError(
-                                f"Unable to run {func.__name__} with batch size 1, the program will exit"
-                            ) from e
-                        continue
-
-                    raise e
-
-                is_func_finished = True
+                    current_batch_size = getattr(self.datamodule, batch_size_attribute_name)
+                    setattr(self.datamodule, batch_size_attribute_name, current_batch_size // 2)
+                    log.warning(
+                        "The function %s went out of memory, trying to reduce the batch size to %d",
+                        func.__name__,
+                        self.datamodule.batch_size,
+                    )
+
+                    if self.datamodule.batch_size == 0:
+                        raise RuntimeError(
+                            f"Unable to run {func.__name__} with batch size 1, the program will exit"
+                        ) from e
+
+                    if torch.cuda.is_available():
+                        torch.cuda.empty_cache()
 
         return wrapper
 
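For readers who want the retry behaviour in isolation, below is a standalone simplification of the pattern introduced in `evaluation.py`: run the wrapped task method, halve the datamodule batch size on a RuntimeError, and give up once the batch size cannot be reduced further. `halve_batch_size_on_failure` and `ToyTask` are illustrative names, not quadra APIs, and the sketch deliberately omits the `automatic_batch_size` config handling, the completion flag, and the CUDA cache clearing shown in the diff:

```python
import functools
import logging
from types import SimpleNamespace

logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)


def halve_batch_size_on_failure(batch_size_attribute_name: str = "batch_size"):
    """Simplified stand-in for the quadra decorator: retry the wrapped method,
    halving ``self.datamodule.<batch_size_attribute_name>`` after each RuntimeError."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            while True:
                try:
                    return func(self, *args, **kwargs)
                except RuntimeError as e:
                    current = getattr(self.datamodule, batch_size_attribute_name)
                    if current <= 1:
                        raise RuntimeError(f"Unable to run {func.__name__} with batch size 1") from e
                    setattr(self.datamodule, batch_size_attribute_name, current // 2)
                    log.warning("%s failed, retrying with batch size %d", func.__name__, current // 2)

        return wrapper

    return decorator


class ToyTask:
    """Toy task: only the datamodule.batch_size attribute matters for the retry logic."""

    def __init__(self):
        self.datamodule = SimpleNamespace(batch_size=128)

    @halve_batch_size_on_failure(batch_size_attribute_name="batch_size")
    def train(self):
        if self.datamodule.batch_size > 16:
            raise RuntimeError("CUDA out of memory")  # simulate an OOM failure
        log.info("Training succeeded with batch size %d", self.datamodule.batch_size)


ToyTask().train()  # fails at 128, 64, 32; succeeds at 16
```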