
Commit

Add all
hjdeheer committed Apr 16, 2024
2 parents b03404b + 1fc1a95 commit 4740878
Showing 10 changed files with 167 additions and 65 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -37,6 +37,7 @@ python -m pytest --cov=epochalyst --cov-branch --cov-report=html:coverage_re
### Caching

Some imports are required only for caching; install them manually when needed (a quick availability check is sketched after the list):

- dask >= 2023.12.0 & dask-expr
- pandas >= 1.3.0
- polars
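
A minimal way to confirm these optional backends are importable before enabling caching (a sketch; check only the packages you actually plan to use, and note that dask-expr installs as the module dask_expr):

import importlib.util

for pkg in ("dask", "dask_expr", "pandas", "polars"):
    if importlib.util.find_spec(pkg) is None:
        print(f"Optional caching dependency '{pkg}' is not installed")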
80 changes: 51 additions & 29 deletions epochalyst/_core/_caching/_cacher.py
@@ -1,6 +1,7 @@
import glob
import os
import pickle
import sys
from typing import Any, TypedDict, Literal

try:
@@ -26,8 +27,13 @@

from epochalyst._core._logging._logger import _Logger

if sys.version_info < (3, 11):
from typing_extensions import NotRequired
else:
from typing import NotRequired

class _CacheArgs(TypedDict):

class CacheArgs(TypedDict):
"""The cache arguments.
Currently, only the listed cache_args are supported. If more are required, create a new GitHub issue.
@@ -44,12 +50,16 @@ class _CacheArgs(TypedDict):
- ".parquet": The storage type is a Parquet file.
- ".csv": The storage type is a CSV file.
- ".npy_stack": The storage type is a NumPy stack.
- ".pkl": The storage type is a pickle file
- ".pkl": The storage type is a pickle file.
- storage_path: The path to the storage.
- read_args: The arguments for reading the data.
- store_args: The arguments for storing the data.
:param output_data_type: The type of the output data.
:param storage_type: The type of the storage.
:param storage_path: The path to the storage.
:param read_args: The optional additional arguments for reading the data.
:param store_args: The optional additional arguments for storing the data.
"""

output_data_type: Literal[
@@ -61,6 +71,8 @@ class _CacheArgs(TypedDict):
]
storage_type: Literal[".npy", ".parquet", ".csv", ".npy_stack", ".pkl"]
storage_path: str # TODO(Jeffrey) Allow str | bytes | os.PathLike[str] | os.PathLike[bytes] instead of just str
read_args: NotRequired[dict[str, Any]]
store_args: NotRequired[dict[str, Any]]
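
For illustration, a CacheArgs dict for a Parquet-backed pandas cache might look like the sketch below; the storage path and the specific read/store keyword arguments are assumptions, not library defaults. Both read_args and store_args are NotRequired and can be omitted.

from epochalyst._core._caching._cacher import CacheArgs

cache_args: CacheArgs = {
    "output_data_type": "pandas_dataframe",
    "storage_type": ".parquet",
    "storage_path": "data/cache/",               # hypothetical cache directory
    # Optional extras forwarded to pd.read_parquet / DataFrame.to_parquet:
    "read_args": {"columns": ["id", "target"]},  # assumed column names
    "store_args": {"compression": "snappy"},
}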


class _Cacher(_Logger):
@@ -78,7 +90,7 @@ def _get_cache(name: str, cache_args: _CacheArgs | None = None) -> Any: # Load t
def _store_cache(name: str, data: Any, cache_args: _CacheArgs | None = None) -> None: # Store data
"""

def _cache_exists(self, name: str, cache_args: _CacheArgs | None = None) -> bool:
def _cache_exists(self, name: str, cache_args: CacheArgs | None = None) -> bool:
"""Check if the cache exists.
:param cache_args: The cache arguments.
@@ -126,7 +138,7 @@ def _cache_exists(self, name: str, cache_args: _CacheArgs | None = None) -> bool

return path_exists

def _get_cache(self, name: str, cache_args: _CacheArgs | None = None) -> Any:
def _get_cache(self, name: str, cache_args: CacheArgs | None = None) -> Any:
"""Load the cache.
:param name: The name of the cache.
@@ -151,6 +163,7 @@ def _get_cache(self, name: str, cache_args: _CacheArgs | None = None) -> Any:
storage_type = cache_args["storage_type"]
storage_path = cache_args["storage_path"]
output_data_type = cache_args["output_data_type"]
read_args = cache_args.get("read_args", {})

# If storage path does not end in a slash, add it
if storage_path[-1] != "/":
@@ -161,9 +174,9 @@ def _get_cache(self, name: str, cache_args: _CacheArgs | None = None) -> Any:
# Check if output_data_type is supported and load cache to output_data_type
self.log_to_debug(f"Loading .npy file from {storage_path + name}")
if output_data_type == "numpy_array":
return np.load(storage_path + name + ".npy")
return np.load(storage_path + name + ".npy", **read_args)
elif output_data_type == "dask_array":
return da.from_array(np.load(storage_path + name + ".npy"))
return da.from_array(np.load(storage_path + name + ".npy"), **read_args)
else:
self.log_to_debug(
f"Invalid output data type: {output_data_type}, for loading .npy file."
@@ -175,15 +188,19 @@ def _get_cache(self, name: str, cache_args: _CacheArgs | None = None) -> Any:
# Check if output_data_type is supported and load cache to output_data_type
self.log_to_debug(f"Loading .parquet file from {storage_path + name}")
if output_data_type == "pandas_dataframe":
return pd.read_parquet(storage_path + name + ".parquet")
return pd.read_parquet(storage_path + name + ".parquet", **read_args)
elif output_data_type == "dask_dataframe":
return dd.read_parquet(storage_path + name + ".parquet")
return dd.read_parquet(storage_path + name + ".parquet", **read_args)
elif output_data_type == "numpy_array":
return pd.read_parquet(storage_path + name + ".parquet").to_numpy()
return pd.read_parquet(
storage_path + name + ".parquet", **read_args
).to_numpy()
elif output_data_type == "dask_array":
return dd.read_parquet(storage_path + name + ".parquet").to_dask_array()
return dd.read_parquet(
storage_path + name + ".parquet", **read_args
).to_dask_array()
elif output_data_type == "polars_dataframe":
return pl.read_parquet(storage_path + name + ".parquet")
return pl.read_parquet(storage_path + name + ".parquet", **read_args)
else:
self.log_to_debug(
f"Invalid output data type: {output_data_type}, for loading .parquet file."
@@ -195,11 +212,11 @@ def _get_cache(self, name: str, cache_args: _CacheArgs | None = None) -> Any:
# Check if output_data_type is supported and load cache to output_data_type
self.log_to_debug(f"Loading .csv file from {storage_path + name}")
if output_data_type == "pandas_dataframe":
return pd.read_csv(storage_path + name + ".csv")
return pd.read_csv(storage_path + name + ".csv", **read_args)
elif output_data_type == "dask_dataframe":
return dd.read_csv(storage_path + name + "/*.part")
return dd.read_csv(storage_path + name + "/*.part", **read_args)
elif output_data_type == "polars_dataframe":
return pl.read_csv(storage_path + name + ".csv")
return pl.read_csv(storage_path + name + ".csv", **read_args)
else:
self.log_to_debug(
f"Invalid output data type: {output_data_type}, for loading .csv file."
@@ -211,7 +228,7 @@ def _get_cache(self, name: str, cache_args: _CacheArgs | None = None) -> Any:
# Check if output_data_type is supported and load cache to output_data_type
self.log_to_debug(f"Loading .npy_stack file from {storage_path + name}")
if output_data_type == "dask_array":
return da.from_npy_stack(storage_path + name)
return da.from_npy_stack(storage_path + name, **read_args)
else:
self.log_to_debug(
f"Invalid output data type: {output_data_type}, for loading .npy_stack file."
@@ -224,15 +241,15 @@ def _get_cache(self, name: str, cache_args: _CacheArgs | None = None) -> Any:
self.log_to_debug(
f"Loading pickle file from {storage_path + name + '.pkl'}"
)
return pickle.load(open(storage_path + name + ".pkl", "rb"))
return pickle.load(open(storage_path + name + ".pkl", "rb"), **read_args)
else:
self.log_to_debug(f"Invalid storage type: {storage_type}")
raise ValueError(
"storage_type must be .npy, .parquet, .csv, or .npy_stack, other types not supported yet"
)

def _store_cache(
self, name: str, data: Any, cache_args: _CacheArgs | None = None
self, name: str, data: Any, cache_args: CacheArgs | None = None
) -> None:
"""Store one set of data.
@@ -258,6 +275,7 @@ def _store_cache(
storage_type = cache_args["storage_type"]
storage_path = cache_args["storage_path"]
output_data_type = cache_args["output_data_type"]
store_args = cache_args.get("store_args", {})

# If storage path does not end in a slash, add it
if storage_path[-1] != "/":
@@ -268,9 +286,9 @@ def _store_cache(
# Check if output_data_type is supported and store cache to output_data_type
self.log_to_debug(f"Storing .npy file to {storage_path + name}")
if output_data_type == "numpy_array":
np.save(storage_path + name + ".npy", data)
np.save(storage_path + name + ".npy", data, **store_args)
elif output_data_type == "dask_array":
np.save(storage_path + name + ".npy", data.compute())
np.save(storage_path + name + ".npy", data.compute(), **store_args)
else:
self.log_to_debug(
f"Invalid output data type: {output_data_type}, for storing .npy file."
@@ -282,19 +300,21 @@ def _store_cache(
# Check if output_data_type is supported and store cache to output_data_type
self.log_to_debug(f"Storing .parquet file to {storage_path + name}")
if output_data_type == "pandas_dataframe":
data.to_parquet(storage_path + name + ".parquet")
data.to_parquet(storage_path + name + ".parquet", **store_args)
elif output_data_type == "dask_dataframe":
data.to_parquet(storage_path + name + ".parquet")
data.to_parquet(storage_path + name + ".parquet", **store_args)
elif output_data_type == "numpy_array":
pd.DataFrame(data).to_parquet(storage_path + name + ".parquet")
pd.DataFrame(data).to_parquet(
storage_path + name + ".parquet", **store_args
)
elif output_data_type == "dask_array":
new_dd = dd.from_dask_array(data)
new_dd = new_dd.rename(
columns={col: str(col) for col in new_dd.columns}
)
new_dd.to_parquet(storage_path + name + ".parquet")
new_dd.to_parquet(storage_path + name + ".parquet", **store_args)
elif output_data_type == "polars_dataframe":
data.write_parquet(storage_path + name + ".parquet")
data.write_parquet(storage_path + name + ".parquet", **store_args)
else:
self.log_to_debug(
f"Invalid output data type: {output_data_type}, for storing .parquet file."
@@ -306,11 +326,13 @@ def _store_cache(
# Check if output_data_type is supported and store cache to output_data_type
self.log_to_debug(f"Storing .csv file to {storage_path + name}")
if output_data_type == "pandas_dataframe":
data.to_csv(storage_path + name + ".csv", index=False)
data.to_csv(
storage_path + name + ".csv", **({"index": False} | store_args)
)
elif output_data_type == "dask_dataframe":
data.to_csv(storage_path + name, index=False)
data.to_csv(storage_path + name, **({"index": False} | store_args))
elif output_data_type == "polars_dataframe":
data.write_csv(storage_path + name + ".csv")
data.write_csv(storage_path + name + ".csv", **store_args)
else:
self.log_to_debug(
f"Invalid output data type: {output_data_type}, for storing .csv file."
@@ -322,7 +344,7 @@ def _store_cache(
# Check if output_data_type is supported and store cache to output_data_type
self.log_to_debug(f"Storing .npy_stack file to {storage_path + name}")
if output_data_type == "dask_array":
da.to_npy_stack(storage_path + name, data)
da.to_npy_stack(storage_path + name, data, **store_args)
else:
self.log_to_debug(
f"Invalid output data type: {output_data_type}, for storing .npy_stack file."
@@ -336,7 +358,7 @@ def _store_cache(
pickle.dump(
data,
open(storage_path + name + ".pkl", "wb"),
protocol=pickle.HIGHEST_PROTOCOL,
**({"protocol": pickle.HIGHEST_PROTOCOL} | store_args),
)
else:
self.log_to_debug(f"Invalid storage type: {storage_type}")
6 changes: 3 additions & 3 deletions epochalyst/pipeline/ensemble.py
@@ -1,6 +1,6 @@
from agogos.training import ParallelTrainingSystem
from typing import Any
from epochalyst._core._caching._cacher import _CacheArgs
from epochalyst._core._caching._cacher import CacheArgs


class EnsemblePipeline(ParallelTrainingSystem):
@@ -9,7 +9,7 @@ class EnsemblePipeline(ParallelTrainingSystem):
:param steps: Trainers to ensemble
"""

def get_x_cache_exists(self, cache_args: _CacheArgs) -> bool:
def get_x_cache_exists(self, cache_args: CacheArgs) -> bool:
"""Get status of x
:param cache_args: Cache arguments
@@ -24,7 +24,7 @@ def get_x_cache_exists(self, cache_args: _CacheArgs) -> bool:

return True

def get_y_cache_exists(self, cache_args: _CacheArgs) -> bool:
def get_y_cache_exists(self, cache_args: CacheArgs) -> bool:
"""Get status of y cache
:param cache_args: Cache arguments
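
A hedged sketch of using these checks before retraining; the ensemble instance and the concrete cache_args values below are assumptions, not taken from this diff:

from epochalyst._core._caching._cacher import CacheArgs

cache_args: CacheArgs = {
    "output_data_type": "numpy_array",
    "storage_type": ".npy",
    "storage_path": "data/cache/",   # hypothetical cache directory
}
# "ensemble" is assumed to be an already-constructed EnsemblePipeline.
if ensemble.get_x_cache_exists(cache_args) and ensemble.get_y_cache_exists(cache_args):
    print("Cached x and y found for this ensemble")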
6 changes: 3 additions & 3 deletions epochalyst/pipeline/model/model.py
@@ -1,6 +1,6 @@
from typing import Any
from agogos.training import Pipeline
from epochalyst._core._caching._cacher import _CacheArgs
from epochalyst._core._caching._cacher import CacheArgs


class ModelPipeline(Pipeline):
@@ -36,7 +36,7 @@ def predict(self, x: Any, **pred_args: Any) -> Any:
"""
return super().predict(x, **pred_args)

def get_x_cache_exists(self, cache_args: _CacheArgs) -> bool:
def get_x_cache_exists(self, cache_args: CacheArgs) -> bool:
"""Get status of x
:param cache_args: Cache arguments
@@ -46,7 +46,7 @@ def get_x_cache_exists(self, cache_args: _CacheArgs) -> bool:
return False
return self.x_sys._cache_exists(self.x_sys.get_hash(), cache_args)

def get_y_cache_exists(self, cache_args: _CacheArgs) -> bool:
def get_y_cache_exists(self, cache_args: CacheArgs) -> bool:
"""Get status of y cache
:param cache_args: Cache arguments