diff --git a/rul_datasets/reader/ncmapss.py b/rul_datasets/reader/ncmapss.py index 67a6915..696cf04 100644 --- a/rul_datasets/reader/ncmapss.py +++ b/rul_datasets/reader/ncmapss.py @@ -15,6 +15,7 @@ import tempfile import warnings import zipfile +from pathlib import Path from typing import Tuple, List, Optional, Union, Dict import h5py # type: ignore[import] @@ -23,8 +24,7 @@ from rul_datasets import utils from rul_datasets.reader.data_root import get_data_root -from rul_datasets.reader import AbstractReader, scaling - +from rul_datasets.reader import AbstractReader, scaling, saving NCMAPSS_DRIVE_ID = "1X9pHm2E3U0bZZbXIhJubVGSL3rtzqFkn" @@ -206,23 +206,37 @@ def fds(self) -> List[int]: """Indices of the available sub-datasets.""" return list(self._WINDOW_SIZES) - def prepare_data(self) -> None: + def prepare_data(self, cache: bool = True) -> None: """ - Prepare the N-C-MAPSS dataset. This function needs to be called before using the - dataset for the first time. + Prepare the N-C-MAPSS dataset. This function needs to be called before using + the dataset for the first time. The dataset is cached for faster loading in + the future. This behavior can be disabled to save disk space by setting + `cache` to `False`. The dataset is assumed to be present in the data root directory. The training data is then split into development and validation set. Afterward, a scaler is fit on the development features if it was not already done previously. + + Args: + cache: Whether to cache the data for faster loading in the future. """ if not os.path.exists(self._NCMAPSS_ROOT): _download_ncmapss(self._NCMAPSS_ROOT) + if cache and not self._cache_exists(): + self._cache_data() if not os.path.exists(self._get_scaler_path()): features, _, _ = self._load_data("dev") scaler = scaling.fit_scaler(features, MinMaxScaler(self.scaling_range)) scaling.save_scaler(scaler, self._get_scaler_path()) - def _get_scaler_path(self): + def _cache_data(self) -> None: + os.makedirs(self._get_cache_path(), exist_ok=True) + features, targets, auxiliary = self._load_raw_data() + features, targets, auxiliary = self._split_by_unit(features, targets, auxiliary) + for i, (f, t, a) in enumerate(zip(features, targets, auxiliary)): + saving.save(str(self._get_cache_path() / f"{i}.npy"), f, t, a) + + def _get_scaler_path(self) -> str: file_name = ( f"scaler_{self.fd}_{self.run_split_dist['dev']}_{self.scaling_range}.pkl" ) @@ -264,6 +278,21 @@ def load_complete_split( def _load_data( self, split: str ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]: + if self._cache_exists(): + features, targets, auxiliary = self._load_cached_data(split) + else: + features, targets, auxiliary = self._load_original_data(split) + + return features, targets, auxiliary + + def _load_cached_data(self, split: str): + unit_idx = self.run_split_dist[split] + save_paths = [str(self._get_cache_path() / f"{i}.npy") for i in unit_idx] + features, targets, auxiliary = saving.load_multiple(save_paths) + + return features, targets, auxiliary + + def _load_original_data(self, split): features, targets, auxiliary = self._load_raw_data() features, targets, auxiliary = self._split_by_unit(features, targets, auxiliary) features = self._select_units(features, split) @@ -368,6 +397,12 @@ def _calc_default_window_size(self): return max(*max_window_size) + def _cache_exists(self) -> bool: + return saving.exists(str(self._get_cache_path() / "0.npy")) + + def _get_cache_path(self): + return Path(self._NCMAPSS_ROOT) / f"DS{self.fd:02d}" + def _download_ncmapss(data_root): with tempfile.TemporaryDirectory() as tmp_path: diff --git a/rul_datasets/reader/saving.py b/rul_datasets/reader/saving.py index 8b15b7b..0bde77e 100644 --- a/rul_datasets/reader/saving.py +++ b/rul_datasets/reader/saving.py @@ -8,7 +8,7 @@ from tqdm import tqdm # type: ignore -def save(save_path: str, features: np.ndarray, targets: np.ndarray) -> None: +def save(save_path: str, features: np.ndarray, *targets: np.ndarray) -> None: """ Save features and targets of a run to .npy files. @@ -21,15 +21,20 @@ def save(save_path: str, features: np.ndarray, targets: np.ndarray) -> None: Args: save_path: The path including file name to save the arrays to. features: The feature array to save. - targets: The targets array to save. + targets: The targets arrays to save. """ feature_path = _get_feature_path(save_path) np.save(feature_path, features, allow_pickle=False) - target_path = _get_target_path(save_path) - np.save(target_path, targets, allow_pickle=False) + if len(targets) == 1: # keeps backward compat for when only one target was allowed + target_path = _get_target_path(save_path, None) + np.save(target_path, targets[0], allow_pickle=False) + else: + for i, target in enumerate(targets): + target_path = _get_target_path(save_path, i) + np.save(target_path, target, allow_pickle=False) -def load(save_path: str, memmap: bool = False) -> Tuple[np.ndarray, np.ndarray]: +def load(save_path: str, memmap: bool = False) -> Tuple[np.ndarray, ...]: """ Load features and targets of a run from .npy files. @@ -50,15 +55,24 @@ def load(save_path: str, memmap: bool = False) -> Tuple[np.ndarray, np.ndarray]: memmap_mode: Optional[Literal["r"]] = "r" if memmap else None feature_path = _get_feature_path(save_path) features = np.load(feature_path, memmap_mode, allow_pickle=False) - target_path = _get_target_path(save_path) - targets = np.load(target_path, memmap_mode, allow_pickle=False) + if os.path.exists(_get_target_path(save_path, None)): + # keeps backward compat for when only one target was allowed + target_path = _get_target_path(save_path, None) + targets = [np.load(target_path, memmap_mode, allow_pickle=False)] + else: + i = 0 + targets = [] + while os.path.exists(_get_target_path(save_path, i)): + target_path = _get_target_path(save_path, i) + targets.append(np.load(target_path, memmap_mode, allow_pickle=False)) + i += 1 - return features, targets + return features, *targets def load_multiple( save_paths: List[str], memmap: bool = False -) -> Tuple[List[np.ndarray], List[np.ndarray]]: +) -> Tuple[List[np.ndarray], ...]: """ Load multiple runs with the [load][rul_datasets.reader.saving.load] function. @@ -72,11 +86,11 @@ def load_multiple( """ if save_paths: runs = [load(save_path, memmap) for save_path in save_paths] - features, targets = [list(x) for x in zip(*runs)] + features, *targets = [list(x) for x in zip(*runs)] else: - features, targets = [], [] + features, targets = [], [[]] - return features, targets + return features, *targets def exists(save_path: str) -> bool: @@ -90,13 +104,14 @@ def exists(save_path: str) -> bool: Returns: Whether the files exist """ - feature_path = _get_feature_path(save_path) - target_path = _get_target_path(save_path) + feature_exists = os.path.exists(_get_feature_path(save_path)) + target_no_index_exists = os.path.exists(_get_target_path(save_path, None)) + target_index_exists = os.path.exists(_get_target_path(save_path, 0)) - return os.path.exists(feature_path) and os.path.exists(target_path) + return feature_exists and (target_no_index_exists or target_index_exists) -def _get_feature_path(save_path): +def _get_feature_path(save_path: str) -> str: if save_path.endswith(".npy"): save_path = save_path[:-4] feature_path = f"{save_path}_features.npy" @@ -104,10 +119,11 @@ def _get_feature_path(save_path): return feature_path -def _get_target_path(save_path): +def _get_target_path(save_path: str, target_index: Optional[int]) -> str: if save_path.endswith(".npy"): save_path = save_path[:-4] - target_path = f"{save_path}_targets.npy" + suffix = "" if target_index is None else f"_{target_index}" + target_path = f"{save_path}_targets{suffix}.npy" return target_path diff --git a/tests/reader/test_ncmapss.py b/tests/reader/test_ncmapss.py index fc4477c..333ffd0 100644 --- a/tests/reader/test_ncmapss.py +++ b/tests/reader/test_ncmapss.py @@ -1,13 +1,18 @@ +import os +import shutil +from pathlib import Path + import numpy as np import pytest +import rul_datasets from rul_datasets.reader.ncmapss import NCmapssReader @pytest.fixture() def prepared_ncmapss(): for fd in range(1, 8): - NCmapssReader(fd).prepare_data() + NCmapssReader(fd).prepare_data(cache=False) @pytest.mark.parametrize("fd", list(range(1, 8))) @@ -166,3 +171,31 @@ def test_scaling_range_is_tuple(scaling_range): assert isinstance(reader.scaling_range, tuple) assert reader.scaling_range == (0, 1) + + +@pytest.mark.needs_data +def test_cache(prepared_ncmapss, tmp_path): + ncmapss_files = Path(NCmapssReader._NCMAPSS_ROOT).rglob("*") + linked_files = [] + for file in ncmapss_files: + if str(file).endswith(".h5"): + linked_file = tmp_path / file.name + os.symlink(file, linked_file) + linked_files.append(linked_file) + + reader = NCmapssReader(1) + cached_reader = NCmapssReader(1) + cached_reader._NCMAPSS_ROOT = tmp_path + cached_reader.prepare_data(cache=True) + + # remove linked files so that only cached files can be used + for file in linked_files: + os.remove(file) + + org_features, org_targets = reader.load_split("dev") + cached_features, cached_targets = cached_reader.load_split("dev") + + for org_feat, cached_feat in zip(org_features, cached_features): + np.testing.assert_almost_equal(org_feat, cached_feat) + for org_targ, cached_targ in zip(org_targets, cached_targets): + np.testing.assert_almost_equal(org_targ, cached_targ) diff --git a/tests/reader/test_saving.py b/tests/reader/test_saving.py index 38df780..5cf439d 100644 --- a/tests/reader/test_saving.py +++ b/tests/reader/test_saving.py @@ -23,6 +23,23 @@ def test_save(tmp_path, file_name): npt.assert_equal(loaded_targets, targets) +@pytest.mark.parametrize("file_name", ["run", "run.npy", "run.foo"]) +def test_save_multi_target(tmp_path, file_name): + features = np.empty((10, 2, 5)) + targets_0 = np.empty((10,)) + targets_1 = np.empty((10,)) + save_path = os.path.join(tmp_path, file_name) + saving.save(save_path, features, targets_0, targets_1) + + exp_save_path = save_path.replace(".npy", "") + loaded_features = np.load(exp_save_path + "_features.npy") + loaded_targets_0 = np.load(exp_save_path + "_targets_0.npy") + loaded_targets_1 = np.load(exp_save_path + "_targets_1.npy") + npt.assert_equal(loaded_features, features) + npt.assert_equal(loaded_targets_0, targets_0) + npt.assert_equal(loaded_targets_1, targets_1) + + @pytest.mark.parametrize("file_name", ["run", "run.npy", "run.foo"]) def test_load(tmp_path, file_name): features = np.empty((10, 2, 5)) @@ -37,6 +54,23 @@ def test_load(tmp_path, file_name): npt.assert_equal(loaded_targets, targets) +@pytest.mark.parametrize("file_name", ["run", "run.npy", "run.foo"]) +def test_load_multi_target(tmp_path, file_name): + features = np.empty((10, 2, 5)) + targets_0 = np.empty((10,)) + targets_1 = np.empty((10,)) + exp_file_name = file_name.replace(".npy", "") + np.save(os.path.join(tmp_path, f"{exp_file_name}_features.npy"), features) + np.save(os.path.join(tmp_path, f"{exp_file_name}_targets_0.npy"), targets_0) + np.save(os.path.join(tmp_path, f"{exp_file_name}_targets_1.npy"), targets_1) + + save_path = os.path.join(tmp_path, file_name) + loaded_features, loaded_targets_0, loaded_targets_1 = saving.load(save_path) + npt.assert_equal(loaded_features, features) + npt.assert_equal(loaded_targets_0, targets_0) + npt.assert_equal(loaded_targets_1, targets_1) + + @mock.patch("rul_datasets.reader.saving.load", return_value=(None, None)) @pytest.mark.parametrize("file_names", [["run1", "run2"], []]) def test_load_multiple(mock_load, file_names): @@ -59,6 +93,18 @@ def test_exists(tmp_path, file_name): assert saving.exists(save_path) +@pytest.mark.parametrize("file_name", ["run", "run.npy"]) +def test_exists_multi_target(tmp_path, file_name): + save_path = os.path.join(tmp_path, file_name) + assert not saving.exists(save_path) + + Path(os.path.join(tmp_path, "run_features.npy")).touch() + assert not saving.exists(save_path) + + Path(os.path.join(tmp_path, "run_targets_0.npy")).touch() + assert saving.exists(save_path) + + @pytest.mark.parametrize("columns", [[0], [0, 1]]) @pytest.mark.parametrize( "file_name", ["raw_features.csv", "raw_features_corrupted.csv"]