Skip to content

Commit

Permalink
Refactor Tests
Browse files Browse the repository at this point in the history
  • Loading branch information
justanotherariel committed Jun 10, 2024
1 parent 2c77381 commit 859cf7d
Show file tree
Hide file tree
Showing 18 changed files with 280 additions and 269 deletions.
10 changes: 5 additions & 5 deletions epochalyst/pipeline/model/training/torch_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,12 +484,12 @@ def create_datasets(
:return: The training and validation datasets.
"""
train_dataset = TensorDataset(
torch.tensor(x[train_indices]),
torch.tensor(y[train_indices]),
x[train_indices].clone().detach(),
y[train_indices].clone().detach(),
)
test_dataset = TensorDataset(
torch.tensor(x[test_indices]),
torch.tensor(y[test_indices]),
x[test_indices].clone().detach(),
y[test_indices].clone().detach(),
)

return train_dataset, test_dataset
Expand All @@ -503,7 +503,7 @@ def create_prediction_dataset(
:param x: The input data.
:return: The prediction dataset.
"""
return TensorDataset(torch.tensor(x))
return TensorDataset(x.clone().detach())

def create_dataloaders(
self,
Expand Down
90 changes: 33 additions & 57 deletions tests/_core/_caching/test__cacher.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import shutil
from epochalyst._core._caching._cacher import _Cacher
import numpy as np
from pathlib import Path

import dask.array as da
import dask.dataframe as dd
import numpy as np
import pandas as pd
import polars as pl
import dask.array as da
from tests.util import remove_cache_files
import pytest
from pathlib import Path

TEMP_DIR = Path("tests/temp")
from epochalyst._core._caching._cacher import _Cacher
from tests.constants import TEMP_DIR


class Implemented_Cacher(_Cacher):
Expand All @@ -21,15 +21,8 @@ class Test_Cacher:
cache_path = TEMP_DIR

@pytest.fixture(autouse=True)
def run_around_tests(self):
# Code that will run before each test
TEMP_DIR.mkdir(exist_ok=True)

yield

# Code that will run after each
if TEMP_DIR.exists():
shutil.rmtree(TEMP_DIR)
def run_always(self, setup_temp_dir):
pass

def test_cacher_init(self):
c = Implemented_Cacher()
Expand Down Expand Up @@ -64,13 +57,13 @@ def test__cache_exists_storage_type_npy_exists(self):
)
is True
)
remove_cache_files()

def test__cache_exists_storage_type_parquet(self):
c = Implemented_Cacher()
assert (
c.cache_exists(
"test", {"storage_type": ".parquet", "storage_path": f"{self.cache_path}"}
"test",
{"storage_type": ".parquet", "storage_path": f"{self.cache_path}"},
)
is False
)
Expand All @@ -81,11 +74,11 @@ def test__cache_exists_storage_type_parquet_exists(self):
f.write("test")
assert (
c.cache_exists(
"test", {"storage_type": ".parquet", "storage_path": f"{self.cache_path}"}
"test",
{"storage_type": ".parquet", "storage_path": f"{self.cache_path}"},
)
is True
)
remove_cache_files()

def test__cache_exists_storage_type_csv(self):
c = Implemented_Cacher()
Expand All @@ -106,13 +99,13 @@ def test__cache_exists_storage_type_csv_exists(self):
)
is True
)
remove_cache_files()

def test__cache_exists_storage_type_npy_stack(self):
c = Implemented_Cacher()
assert (
c.cache_exists(
"test", {"storage_type": ".npy_stack", "storage_path": f"{self.cache_path}"}
"test",
{"storage_type": ".npy_stack", "storage_path": f"{self.cache_path}"},
)
is False
)
Expand All @@ -123,11 +116,11 @@ def test__cache_exists_storage_type_npy_stack_exists(self):
f.write("test")
assert (
c.cache_exists(
"test", {"storage_type": ".npy_stack", "storage_path": f"{self.cache_path}"}
"test",
{"storage_type": ".npy_stack", "storage_path": f"{self.cache_path}"},
)
is True
)
remove_cache_files()

def test__cache_exists_storage_type_pkl(self):
c = Implemented_Cacher()
Expand All @@ -148,13 +141,13 @@ def test__cache_exists_storage_type_pkl_exists(self):
)
is True
)
remove_cache_files()

def test__cache_exists_storage_type_unsupported(self):
c = Implemented_Cacher()
assert (
c.cache_exists(
"test", {"storage_type": ".new_type", "storage_path": f"{self.cache_path}"}
"test",
{"storage_type": ".new_type", "storage_path": f"{self.cache_path}"},
)
is False
)
Expand All @@ -179,7 +172,9 @@ def test__store_cache_no_output_data_type(self):
c = Implemented_Cacher()
with pytest.raises(ValueError):
c._store_cache(
"test", "test", {"storage_type": ".npy", "storage_path": f"{self.cache_path}"}
"test",
"test",
{"storage_type": ".npy", "storage_path": f"{self.cache_path}"},
)

# storage type .npy
Expand All @@ -200,7 +195,6 @@ def test__store_cache_storage_type_npy_output_data_type_numpy_array(self):
)
is True
)
remove_cache_files()

def test__store_cache_storage_type_npy_output_data_type_dask_array(self):
c = Implemented_Cacher()
Expand All @@ -221,7 +215,6 @@ def test__store_cache_storage_type_npy_output_data_type_dask_array(self):
)
is True
)
remove_cache_files()

def test__store_cache_storage_type_npy_output_data_type_unsupported(self):
c = Implemented_Cacher()
Expand Down Expand Up @@ -252,11 +245,11 @@ def test__store_cache_storage_type_parquet_output_data_type_pandas_dataframe(sel
)
assert (
c.cache_exists(
"test", {"storage_type": ".parquet", "storage_path": f"{self.cache_path}"}
"test",
{"storage_type": ".parquet", "storage_path": f"{self.cache_path}"},
)
is True
)
remove_cache_files()

def test__store_cache_storage_type_parquet_output_data_type_dask_dataframe(self):
c = Implemented_Cacher()
Expand All @@ -274,11 +267,11 @@ def test__store_cache_storage_type_parquet_output_data_type_dask_dataframe(self)
)
assert (
c.cache_exists(
"test", {"storage_type": ".parquet", "storage_path": f"{self.cache_path}"}
"test",
{"storage_type": ".parquet", "storage_path": f"{self.cache_path}"},
)
is True
)
remove_cache_files()

def test__store_cache_storage_type_parquet_output_data_type_numpy_array(self):
c = Implemented_Cacher()
Expand All @@ -295,11 +288,11 @@ def test__store_cache_storage_type_parquet_output_data_type_numpy_array(self):
)
assert (
c.cache_exists(
"test", {"storage_type": ".parquet", "storage_path": f"{self.cache_path}"}
"test",
{"storage_type": ".parquet", "storage_path": f"{self.cache_path}"},
)
is True
)
remove_cache_files()

def test__store_cache_storage_type_parquet_output_data_type_dask_array(self):
c = Implemented_Cacher()
Expand All @@ -316,11 +309,11 @@ def test__store_cache_storage_type_parquet_output_data_type_dask_array(self):
)
assert (
c.cache_exists(
"test", {"storage_type": ".parquet", "storage_path": f"{self.cache_path}"}
"test",
{"storage_type": ".parquet", "storage_path": f"{self.cache_path}"},
)
is True
)
remove_cache_files()

def test__store_cache_storage_type_parquet_output_data_type_polars_dataframe(self):
c = Implemented_Cacher()
Expand All @@ -337,11 +330,11 @@ def test__store_cache_storage_type_parquet_output_data_type_polars_dataframe(sel
)
assert (
c.cache_exists(
"test", {"storage_type": ".parquet", "storage_path": f"{self.cache_path}"}
"test",
{"storage_type": ".parquet", "storage_path": f"{self.cache_path}"},
)
is True
)
remove_cache_files()

def test__store_cache_storage_type_parquet_output_data_type_unsupported(self):
c = Implemented_Cacher()
Expand Down Expand Up @@ -376,7 +369,6 @@ def test__store_cache_storage_type_csv_output_data_type_pandas_dataframe(self):
)
is True
)
remove_cache_files()

def test__store_cache_storage_type_csv_output_data_type_dask_dataframe(self):
c = Implemented_Cacher()
Expand All @@ -398,7 +390,6 @@ def test__store_cache_storage_type_csv_output_data_type_dask_dataframe(self):
)
is True
)
remove_cache_files()

def test__store_cache_storage_type_csv_output_data_type_polars_dataframe(self):
c = Implemented_Cacher()
Expand All @@ -419,7 +410,6 @@ def test__store_cache_storage_type_csv_output_data_type_polars_dataframe(self):
)
is True
)
remove_cache_files()

def test__store_cache_storage_type_csv_output_data_type_unsupported(self):
c = Implemented_Cacher()
Expand Down Expand Up @@ -450,7 +440,8 @@ def test__store_cache_storage_type_npy_stack_output_data_type_dask_array(self):
)
assert (
c.cache_exists(
"test", {"storage_type": ".npy_stack", "storage_path": f"{self.cache_path}"}
"test",
{"storage_type": ".npy_stack", "storage_path": f"{self.cache_path}"},
)
is True
)
Expand Down Expand Up @@ -501,7 +492,6 @@ def test__store_cache_storage_type_pkl_output_data_type_pandas_dataframe(self):
)
is True
)
remove_cache_files()

def test__store_cache_storage_type_pkl_output_data_type_dask_dataframe(self):
c = Implemented_Cacher()
Expand All @@ -523,7 +513,6 @@ def test__store_cache_storage_type_pkl_output_data_type_dask_dataframe(self):
)
is True
)
remove_cache_files()

# _get_cache
def test__get_cache_no_cache_args(self):
Expand Down Expand Up @@ -571,7 +560,6 @@ def test__get_cache_storage_type_npy_output_data_type_numpy_array(self):
)
== "test"
)
remove_cache_files()

def test__get_cache_storage_type_npy_output_data_type_dask_array(self):
c = Implemented_Cacher()
Expand Down Expand Up @@ -600,7 +588,6 @@ def test__get_cache_storage_type_npy_output_data_type_dask_array(self):
.all()
== x.compute().all()
)
remove_cache_files()

def test__get_cache_storage_type_npy_output_data_type_unsupported(self):
c = Implemented_Cacher()
Expand Down Expand Up @@ -636,7 +623,6 @@ def test__get_cache_storage_type_parquet_output_data_type_pandas_dataframe(self)
"output_data_type": "pandas_dataframe",
},
).equals(data)
remove_cache_files()

def test__get_cache_storage_type_parquet_output_data_type_dask_dataframe(self):
c = Implemented_Cacher()
Expand Down Expand Up @@ -664,7 +650,6 @@ def test__get_cache_storage_type_parquet_output_data_type_dask_dataframe(self):
.compute()
.equals(data.compute())
)
remove_cache_files()

def test__get_cache_storage_type_parquet_output_data_type_numpy_array(self):
c = Implemented_Cacher()
Expand All @@ -688,7 +673,6 @@ def test__get_cache_storage_type_parquet_output_data_type_numpy_array(self):
},
)
assert get_cache.all() == data.all()
remove_cache_files()

def test__get_cache_storage_type_parquet_output_data_type_dask_array(self):
c = Implemented_Cacher()
Expand All @@ -712,7 +696,6 @@ def test__get_cache_storage_type_parquet_output_data_type_dask_array(self):
},
)
assert get_cache.compute().all() == data.compute().all()
remove_cache_files()

def test__get_cache_storage_type_parquet_output_data_type_polars_dataframe(self):
c = Implemented_Cacher()
Expand All @@ -736,7 +719,6 @@ def test__get_cache_storage_type_parquet_output_data_type_polars_dataframe(self)
},
)
assert data.equals(get_cache)
remove_cache_files()

def test__get_cache_storage_type_parquet_output_data_type_unsupported(self):
c = Implemented_Cacher()
Expand Down Expand Up @@ -773,7 +755,6 @@ def test__get_cache_storage_type_csv_output_data_type_pandas_dataframe(self):
},
)
assert get_cache.equals(data)
remove_cache_files()

def test__get_cache_storage_type_csv_output_data_type_dask_dataframe(self):
c = Implemented_Cacher()
Expand All @@ -798,7 +779,6 @@ def test__get_cache_storage_type_csv_output_data_type_dask_dataframe(self):
},
)
assert get_cache.compute().reset_index(drop=True).equals(data.compute())
remove_cache_files()

def test__get_cache_storage_type_csv_output_data_type_polars_dataframe(self):
c = Implemented_Cacher()
Expand All @@ -822,7 +802,6 @@ def test__get_cache_storage_type_csv_output_data_type_polars_dataframe(self):
},
)
assert data.equals(get_cache)
remove_cache_files()

def test__get_cache_storage_type_csv_output_data_type_unsupported(self):
c = Implemented_Cacher()
Expand Down Expand Up @@ -859,7 +838,6 @@ def test__get_cache_storage_type_npy_stack_output_data_type_dask_array(self):
},
)
assert get_cache.compute().all() == data.compute().all()
remove_cache_files()

def test__get_cache_storage_type_npy_stack_output_data_type_unsupported(self):
c = Implemented_Cacher()
Expand Down Expand Up @@ -908,7 +886,6 @@ def test__get_cache_storage_type_pkl_output_data_type_pandas_dataframe(self):
},
)
assert get_cache.equals(data)
remove_cache_files()

def test__get_cache_storage_type_pkl_output_data_type_dask_dataframe(self):
c = Implemented_Cacher()
Expand All @@ -933,4 +910,3 @@ def test__get_cache_storage_type_pkl_output_data_type_dask_dataframe(self):
},
)
assert get_cache.compute().reset_index(drop=True).equals(data.compute())
remove_cache_files()
1 change: 1 addition & 0 deletions tests/_core/_logging/test__logger.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest

from epochalyst._core._logging._logger import _Logger


Expand Down
Loading

0 comments on commit 859cf7d

Please sign in to comment.