Add logging for storing cache in pipeline and blocks
schobbejak committed May 15, 2024
1 parent 53ac91b commit c3f6206
Showing 4 changed files with 25 additions and 16 deletions.
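All four files converge on the same pattern: guard on `cache_args`, log the destination `storage_path`, then store under a hash-derived key. Below is a minimal, self-contained sketch of that pattern; the `CacheArgs` shape (beyond the `storage_path` key the diff actually uses), the `CachingBlock` class, and its stub methods are assumptions for illustration, not the project's real implementations.

from typing import Any, TypedDict


class CacheArgs(TypedDict, total=False):
    # Only "storage_path" is used by this commit; the real CacheArgs
    # may define more fields.
    storage_path: str


class CachingBlock:
    """Illustrative stand-in for the pipeline/block classes touched here."""

    def get_hash(self) -> str:
        # Placeholder; the real hash identifies the block's configuration.
        return "1a2b3c"

    def log_to_terminal(self, message: str) -> None:
        # Stand-in for the project's terminal logger.
        print(message)

    def _store_cache(self, name: str, data: Any, cache_args: CacheArgs) -> None:
        # Stand-in: persist `data` under `name` at cache_args["storage_path"].
        ...

    def train(self, x: Any, y: Any, cache_args: CacheArgs | None = None) -> tuple[Any, Any]:
        # The guard-log-store sequence this commit standardizes.
        if cache_args:
            self.log_to_terminal(f"Storing cache for x and y to {cache_args['storage_path']}")
            self._store_cache(name=self.get_hash() + "x", data=x, cache_args=cache_args)
            self._store_cache(name=self.get_hash() + "y", data=y, cache_args=cache_args)
        return x, y


# Usage: any truthy cache_args triggers the log line before the store.
CachingBlock().train([1, 2], [0, 1], cache_args={"storage_path": "tmp/cache"})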
5 changes: 4 additions & 1 deletion epochalyst/pipeline/model/training/training.py
@@ -63,6 +63,7 @@ def train(self, x: Any, y: Any, cache_args: CacheArgs | None = None, **train_arg
        x, y = super().train(x, y, **train_args)

        if cache_args:
+           self.log_to_terminal(f"Storing cache for x and y to {cache_args['storage_path']}")
            self._store_cache(name=self.get_hash() + "x", data=x, cache_args=cache_args)
            self._store_cache(name=self.get_hash() + "y", data=y, cache_args=cache_args)
@@ -115,7 +116,9 @@ def predict(self, x: Any, cache_args: CacheArgs | None = None, **pred_args: Any)

        x = super().predict(x, **pred_args)

-       self._store_cache(self.get_hash() + "p", x, cache_args) if cache_args else None
+       if cache_args:
+           self.log_to_terminal(f"Storing cache for x to {cache_args['storage_path']}")
+           self._store_cache(self.get_hash() + "p", x, cache_args)

        # Set steps to original in case class is called again
        self.steps = self.all_steps
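Note the cache-key convention these hunks rely on: `get_hash()` identifies the block, and a one-letter suffix distinguishes the cached artifacts ("x"/"y" for training inputs and labels, "p" for predictions). A tiny sketch with a placeholder hash:

# Hypothetical illustration of the cache-key convention; the hash value
# is a placeholder for self.get_hash().
block_hash = "1a2b3c"

train_x_key = block_hash + "x"      # cached training input
train_y_key = block_hash + "y"      # cached training labels
predictions_key = block_hash + "p"  # cached predictions

print(train_x_key, train_y_key, predictions_key)
# -> 1a2b3cx 1a2b3cy 1a2b3cp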
34 changes: 19 additions & 15 deletions epochalyst/pipeline/model/training/training_block.py
@@ -76,16 +76,18 @@ def train(self, x: Any, y: Any, cache_args: CacheArgs | None = None, **train_arg

        x, y = self.custom_train(x, y, **train_args)

-       self._store_cache(
-           name=self.get_hash() + "x",
-           data=x,
-           cache_args=cache_args,
-       ) if cache_args else None
-       self._store_cache(
-           name=self.get_hash() + "y",
-           data=y,
-           cache_args=cache_args,
-       ) if cache_args else None
+       if cache_args:
+           self.log_to_terminal(f"Storing cache for x and y to {cache_args['storage_path']}")
+           self._store_cache(
+               name=self.get_hash() + "x",
+               data=x,
+               cache_args=cache_args,
+           )
+           self._store_cache(
+               name=self.get_hash() + "y",
+               data=y,
+               cache_args=cache_args,
+           )

        return x, y

@@ -116,11 +118,13 @@ def predict(self, x: Any, cache_args: CacheArgs | None = None, **pred_args: Any)

        x = self.custom_predict(x, **pred_args)

-       self._store_cache(
-           name=self.get_hash() + "p",
-           data=x,
-           cache_args=cache_args,
-       ) if cache_args else None
+       if cache_args:
+           self.log_to_terminal(f"Storing cache for predictions to {cache_args['storage_path']}")
+           self._store_cache(
+               name=self.get_hash() + "p",
+               data=x,
+               cache_args=cache_args,
+           )

        return x
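Besides adding the log lines, this file replaces conditional expressions that were evaluated purely for their side effect (`self._store_cache(...) if cache_args else None`) with plain `if` blocks: the idiomatic statement form, and one guard now covers both the log call and the store. The pattern in isolation:

def do_side_effect() -> None:
    print("storing...")

enabled = True

# Before: a ternary evaluated only for its side effect; the `else None`
# branch exists just to satisfy the expression syntax.
do_side_effect() if enabled else None

# After: a plain statement-level guard, which also scales to multiple
# statements (log, then store) without repeating the condition.
if enabled:
    do_side_effect()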

1 change: 1 addition & 0 deletions epochalyst/pipeline/model/transformation/transformation.py
@@ -105,6 +105,7 @@ def transform(self, data: Any, cache_args: CacheArgs | None = None, **transform_
        data = super().transform(data, **transform_args)

        if cache_args:
+           self.log_to_terminal(f"Storing cache for pipeline to {cache_args['storage_path']}")
            self._store_cache(self.get_hash(), data, cache_args)

        # Set steps to original in case class is called again
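A hedged sketch of how the new log line could be exercised in a test; the `MagicMock` stand-in and the values here are hypothetical, not taken from the project's test suite:

from unittest.mock import MagicMock

# Placeholder object standing in for a transformation pipeline/block.
block = MagicMock()
cache_args = {"storage_path": "tmp/cache"}

# Reproduce the guarded call sequence from the hunk above.
if cache_args:
    block.log_to_terminal(f"Storing cache for pipeline to {cache_args['storage_path']}")
    block._store_cache("somehash", {"data": []}, cache_args)

block.log_to_terminal.assert_called_once_with("Storing cache for pipeline to tmp/cache")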
1 change: 1 addition & 0 deletions (file path not shown)
@@ -73,6 +73,7 @@ def transform(self, data: Any, cache_args: CacheArgs | None = None, **transform_

        data = self.custom_transform(data, **transform_args)
        if cache_args:
+           self.log_to_terminal(f"Storing cache to {cache_args['storage_path']}")
            self._store_cache(name=self.get_hash(), data=data, cache_args=cache_args)
        return data
