From c3f62061ebfd1282c8b3ecff8360032411c0ddd4 Mon Sep 17 00:00:00 2001
From: Schobbejak
Date: Wed, 15 May 2024 15:23:01 +0200
Subject: [PATCH] Add logging for storing cache in pipeline and blocks

---
 .../pipeline/model/training/training.py       |  5 ++-
 .../pipeline/model/training/training_block.py | 34 +++++++++++--------
 .../model/transformation/transformation.py    |  1 +
 .../transformation/transformation_block.py    |  1 +
 4 files changed, 25 insertions(+), 16 deletions(-)

diff --git a/epochalyst/pipeline/model/training/training.py b/epochalyst/pipeline/model/training/training.py
index 8125a1d..8c10504 100644
--- a/epochalyst/pipeline/model/training/training.py
+++ b/epochalyst/pipeline/model/training/training.py
@@ -63,6 +63,7 @@ def train(self, x: Any, y: Any, cache_args: CacheArgs | None = None, **train_arg
         x, y = super().train(x, y, **train_args)
 
         if cache_args:
+            self.log_to_terminal(f"Storing cache for x and y to {cache_args['storage_path']}")
             self._store_cache(name=self.get_hash() + "x", data=x, cache_args=cache_args)
             self._store_cache(name=self.get_hash() + "y", data=y, cache_args=cache_args)
 
@@ -115,7 +116,9 @@ def predict(self, x: Any, cache_args: CacheArgs | None = None, **pred_args: Any)
 
         x = super().predict(x, **pred_args)
 
-        self._store_cache(self.get_hash() + "p", x, cache_args) if cache_args else None
+        if cache_args:
+            self.log_to_terminal(f"Storing cache for x to {cache_args['storage_path']}")
+            self._store_cache(self.get_hash() + "p", x, cache_args)
 
         # Set steps to original in case class is called again
         self.steps = self.all_steps
diff --git a/epochalyst/pipeline/model/training/training_block.py b/epochalyst/pipeline/model/training/training_block.py
index cb27831..914bb79 100644
--- a/epochalyst/pipeline/model/training/training_block.py
+++ b/epochalyst/pipeline/model/training/training_block.py
@@ -76,16 +76,18 @@ def train(self, x: Any, y: Any, cache_args: CacheArgs | None = None, **train_arg
 
         x, y = self.custom_train(x, y, **train_args)
 
-        self._store_cache(
-            name=self.get_hash() + "x",
-            data=x,
-            cache_args=cache_args,
-        ) if cache_args else None
-        self._store_cache(
-            name=self.get_hash() + "y",
-            data=y,
-            cache_args=cache_args,
-        ) if cache_args else None
+        if cache_args:
+            self.log_to_terminal(f"Storing cache for x and y to {cache_args['storage_path']}")
+            self._store_cache(
+                name=self.get_hash() + "x",
+                data=x,
+                cache_args=cache_args,
+            )
+            self._store_cache(
+                name=self.get_hash() + "y",
+                data=y,
+                cache_args=cache_args,
+            )
 
         return x, y
 
@@ -116,11 +118,13 @@ def predict(self, x: Any, cache_args: CacheArgs | None = None, **pred_args: Any)
 
         x = self.custom_predict(x, **pred_args)
 
-        self._store_cache(
-            name=self.get_hash() + "p",
-            data=x,
-            cache_args=cache_args,
-        ) if cache_args else None
+        if cache_args:
+            self.log_to_terminal(f"Storing cache for predictions to {cache_args['storage_path']}")
+            self._store_cache(
+                name=self.get_hash() + "p",
+                data=x,
+                cache_args=cache_args,
+            )
 
         return x
 
diff --git a/epochalyst/pipeline/model/transformation/transformation.py b/epochalyst/pipeline/model/transformation/transformation.py
index 479cc94..623d152 100644
--- a/epochalyst/pipeline/model/transformation/transformation.py
+++ b/epochalyst/pipeline/model/transformation/transformation.py
@@ -105,6 +105,7 @@ def transform(self, data: Any, cache_args: CacheArgs | None = None, **transform_
         data = super().transform(data, **transform_args)
 
         if cache_args:
+            self.log_to_terminal(f"Storing cache for pipeline to {cache_args['storage_path']}")
             self._store_cache(self.get_hash(), data, cache_args)
 
         # Set steps to original in case class is called again
diff --git a/epochalyst/pipeline/model/transformation/transformation_block.py b/epochalyst/pipeline/model/transformation/transformation_block.py
index 8ec5abe..1f1bf36 100644
--- a/epochalyst/pipeline/model/transformation/transformation_block.py
+++ b/epochalyst/pipeline/model/transformation/transformation_block.py
@@ -73,6 +73,7 @@ def transform(self, data: Any, cache_args: CacheArgs | None = None, **transform_
         data = self.custom_transform(data, **transform_args)
 
         if cache_args:
+            self.log_to_terminal(f"Storing cache to {cache_args['storage_path']}")
             self._store_cache(name=self.get_hash(), data=data, cache_args=cache_args)
 
         return data
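
Every hunk above applies the same pattern: replace the postfix `self._store_cache(...) if cache_args else None` expression with an explicit `if cache_args:` block that logs the storage path before storing. Below is a minimal, self-contained sketch of that pattern. The names `CacheArgs`, `log_to_terminal`, `_store_cache`, and `get_hash` mirror the patch, but this standalone class and its method bodies are hypothetical stand-ins, not the epochalyst implementation.

from typing import Any, TypedDict


class CacheArgs(TypedDict, total=False):
    """Assumed shape: the patch only shows that a 'storage_path' key exists."""

    storage_path: str


class Block:
    """Hypothetical stand-in for the pipeline/block classes touched by the patch."""

    def get_hash(self) -> str:
        # The real hash identifies the block's configuration; a constant suffices here.
        return "deadbeef"

    def log_to_terminal(self, message: str) -> None:
        print(message)

    def _store_cache(self, name: str, data: Any, cache_args: CacheArgs) -> None:
        # Assumed behavior: persist `data` under `name` in cache_args["storage_path"].
        pass

    def transform(self, data: Any, cache_args: CacheArgs | None = None) -> Any:
        data = [d * 2 for d in data]  # placeholder for the block's custom_transform
        # The pattern the patch applies everywhere: guard once, log, then store,
        # instead of the old `self._store_cache(...) if cache_args else None`.
        if cache_args:
            self.log_to_terminal(f"Storing cache to {cache_args['storage_path']}")
            self._store_cache(name=self.get_hash(), data=data, cache_args=cache_args)
        return data


if __name__ == "__main__":
    Block().transform([1, 2, 3], cache_args={"storage_path": "tmp/cache"})

One incidental fix folded into the hunks: the added f-strings originally nested double quotes (f"...{cache_args["storage_path"]}"), which is a SyntaxError on Python versions before 3.12; single-quoted keys inside the f-string work on all supported versions.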