Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: spead up loading NCMAPSS #62

Merged
merged 2 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 41 additions & 6 deletions rul_datasets/reader/ncmapss.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import tempfile
import warnings
import zipfile
from pathlib import Path
from typing import Tuple, List, Optional, Union, Dict

import h5py # type: ignore[import]
Expand All @@ -23,8 +24,7 @@

from rul_datasets import utils
from rul_datasets.reader.data_root import get_data_root
from rul_datasets.reader import AbstractReader, scaling

from rul_datasets.reader import AbstractReader, scaling, saving

NCMAPSS_DRIVE_ID = "1X9pHm2E3U0bZZbXIhJubVGSL3rtzqFkn"

Expand Down Expand Up @@ -206,23 +206,37 @@ def fds(self) -> List[int]:
"""Indices of the available sub-datasets."""
return list(self._WINDOW_SIZES)

def prepare_data(self) -> None:
def prepare_data(self, cache: bool = True) -> None:
"""
Prepare the N-C-MAPSS dataset. This function needs to be called before using the
dataset for the first time.
Prepare the N-C-MAPSS dataset. This function needs to be called before using
the dataset for the first time. The dataset is cached for faster loading in
the future. This behavior can be disabled to save disk space by setting
`cache` to `False`.

The dataset is assumed to be present in the data root directory. The training
data is then split into development and validation set. Afterward, a scaler
is fit on the development features if it was not already done previously.

Args:
cache: Whether to cache the data for faster loading in the future.
"""
if not os.path.exists(self._NCMAPSS_ROOT):
_download_ncmapss(self._NCMAPSS_ROOT)
if cache and not self._cache_exists():
self._cache_data()
if not os.path.exists(self._get_scaler_path()):
features, _, _ = self._load_data("dev")
scaler = scaling.fit_scaler(features, MinMaxScaler(self.scaling_range))
scaling.save_scaler(scaler, self._get_scaler_path())

def _get_scaler_path(self):
def _cache_data(self) -> None:
os.makedirs(self._get_cache_path(), exist_ok=True)
features, targets, auxiliary = self._load_raw_data()
features, targets, auxiliary = self._split_by_unit(features, targets, auxiliary)
for i, (f, t, a) in enumerate(zip(features, targets, auxiliary)):
saving.save(str(self._get_cache_path() / f"{i}.npy"), f, t, a)

def _get_scaler_path(self) -> str:
file_name = (
f"scaler_{self.fd}_{self.run_split_dist['dev']}_{self.scaling_range}.pkl"
)
Expand Down Expand Up @@ -264,6 +278,21 @@ def load_complete_split(
def _load_data(
self, split: str
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
if self._cache_exists():
features, targets, auxiliary = self._load_cached_data(split)
else:
features, targets, auxiliary = self._load_original_data(split)

return features, targets, auxiliary

def _load_cached_data(self, split: str):
unit_idx = self.run_split_dist[split]
save_paths = [str(self._get_cache_path() / f"{i}.npy") for i in unit_idx]
features, targets, auxiliary = saving.load_multiple(save_paths)

return features, targets, auxiliary

def _load_original_data(self, split):
features, targets, auxiliary = self._load_raw_data()
features, targets, auxiliary = self._split_by_unit(features, targets, auxiliary)
features = self._select_units(features, split)
Expand Down Expand Up @@ -368,6 +397,12 @@ def _calc_default_window_size(self):

return max(*max_window_size)

def _cache_exists(self) -> bool:
return saving.exists(str(self._get_cache_path() / "0.npy"))

def _get_cache_path(self):
return Path(self._NCMAPSS_ROOT) / f"DS{self.fd:02d}"


def _download_ncmapss(data_root):
with tempfile.TemporaryDirectory() as tmp_path:
Expand Down
52 changes: 34 additions & 18 deletions rul_datasets/reader/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from tqdm import tqdm # type: ignore


def save(save_path: str, features: np.ndarray, targets: np.ndarray) -> None:
def save(save_path: str, features: np.ndarray, *targets: np.ndarray) -> None:
"""
Save features and targets of a run to .npy files.

Expand All @@ -21,15 +21,20 @@ def save(save_path: str, features: np.ndarray, targets: np.ndarray) -> None:
Args:
save_path: The path including file name to save the arrays to.
features: The feature array to save.
targets: The targets array to save.
targets: The targets arrays to save.
"""
feature_path = _get_feature_path(save_path)
np.save(feature_path, features, allow_pickle=False)
target_path = _get_target_path(save_path)
np.save(target_path, targets, allow_pickle=False)
if len(targets) == 1: # keeps backward compat for when only one target was allowed
target_path = _get_target_path(save_path, None)
np.save(target_path, targets[0], allow_pickle=False)
else:
for i, target in enumerate(targets):
target_path = _get_target_path(save_path, i)
np.save(target_path, target, allow_pickle=False)


def load(save_path: str, memmap: bool = False) -> Tuple[np.ndarray, np.ndarray]:
def load(save_path: str, memmap: bool = False) -> Tuple[np.ndarray, ...]:
"""
Load features and targets of a run from .npy files.

Expand All @@ -50,15 +55,24 @@ def load(save_path: str, memmap: bool = False) -> Tuple[np.ndarray, np.ndarray]:
memmap_mode: Optional[Literal["r"]] = "r" if memmap else None
feature_path = _get_feature_path(save_path)
features = np.load(feature_path, memmap_mode, allow_pickle=False)
target_path = _get_target_path(save_path)
targets = np.load(target_path, memmap_mode, allow_pickle=False)
if os.path.exists(_get_target_path(save_path, None)):
# keeps backward compat for when only one target was allowed
target_path = _get_target_path(save_path, None)
targets = [np.load(target_path, memmap_mode, allow_pickle=False)]
else:
i = 0
targets = []
while os.path.exists(_get_target_path(save_path, i)):
target_path = _get_target_path(save_path, i)
targets.append(np.load(target_path, memmap_mode, allow_pickle=False))
i += 1

return features, targets
return features, *targets


def load_multiple(
save_paths: List[str], memmap: bool = False
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
) -> Tuple[List[np.ndarray], ...]:
"""
Load multiple runs with the [load][rul_datasets.reader.saving.load] function.

Expand All @@ -72,11 +86,11 @@ def load_multiple(
"""
if save_paths:
runs = [load(save_path, memmap) for save_path in save_paths]
features, targets = [list(x) for x in zip(*runs)]
features, *targets = [list(x) for x in zip(*runs)]
else:
features, targets = [], []
features, targets = [], [[]]

return features, targets
return features, *targets


def exists(save_path: str) -> bool:
Expand All @@ -90,24 +104,26 @@ def exists(save_path: str) -> bool:
Returns:
Whether the files exist
"""
feature_path = _get_feature_path(save_path)
target_path = _get_target_path(save_path)
feature_exists = os.path.exists(_get_feature_path(save_path))
target_no_index_exists = os.path.exists(_get_target_path(save_path, None))
target_index_exists = os.path.exists(_get_target_path(save_path, 0))

return os.path.exists(feature_path) and os.path.exists(target_path)
return feature_exists and (target_no_index_exists or target_index_exists)


def _get_feature_path(save_path):
def _get_feature_path(save_path: str) -> str:
if save_path.endswith(".npy"):
save_path = save_path[:-4]
feature_path = f"{save_path}_features.npy"

return feature_path


def _get_target_path(save_path):
def _get_target_path(save_path: str, target_index: Optional[int]) -> str:
if save_path.endswith(".npy"):
save_path = save_path[:-4]
target_path = f"{save_path}_targets.npy"
suffix = "" if target_index is None else f"_{target_index}"
target_path = f"{save_path}_targets{suffix}.npy"

return target_path

Expand Down
35 changes: 34 additions & 1 deletion tests/reader/test_ncmapss.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import os
import shutil
from pathlib import Path

import numpy as np
import pytest

import rul_datasets
from rul_datasets.reader.ncmapss import NCmapssReader


@pytest.fixture()
def prepared_ncmapss():
for fd in range(1, 8):
NCmapssReader(fd).prepare_data()
NCmapssReader(fd).prepare_data(cache=False)


@pytest.mark.parametrize("fd", list(range(1, 8)))
Expand Down Expand Up @@ -166,3 +171,31 @@ def test_scaling_range_is_tuple(scaling_range):

assert isinstance(reader.scaling_range, tuple)
assert reader.scaling_range == (0, 1)


@pytest.mark.needs_data
def test_cache(prepared_ncmapss, tmp_path):
ncmapss_files = Path(NCmapssReader._NCMAPSS_ROOT).rglob("*")
linked_files = []
for file in ncmapss_files:
if str(file).endswith(".h5"):
linked_file = tmp_path / file.name
os.symlink(file, linked_file)
linked_files.append(linked_file)

reader = NCmapssReader(1)
cached_reader = NCmapssReader(1)
cached_reader._NCMAPSS_ROOT = tmp_path
cached_reader.prepare_data(cache=True)

# remove linked files so that only cached files can be used
for file in linked_files:
os.remove(file)

org_features, org_targets = reader.load_split("dev")
cached_features, cached_targets = cached_reader.load_split("dev")

for org_feat, cached_feat in zip(org_features, cached_features):
np.testing.assert_almost_equal(org_feat, cached_feat)
for org_targ, cached_targ in zip(org_targets, cached_targets):
np.testing.assert_almost_equal(org_targ, cached_targ)
46 changes: 46 additions & 0 deletions tests/reader/test_saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,23 @@ def test_save(tmp_path, file_name):
npt.assert_equal(loaded_targets, targets)


@pytest.mark.parametrize("file_name", ["run", "run.npy", "run.foo"])
def test_save_multi_target(tmp_path, file_name):
features = np.empty((10, 2, 5))
targets_0 = np.empty((10,))
targets_1 = np.empty((10,))
save_path = os.path.join(tmp_path, file_name)
saving.save(save_path, features, targets_0, targets_1)

exp_save_path = save_path.replace(".npy", "")
loaded_features = np.load(exp_save_path + "_features.npy")
loaded_targets_0 = np.load(exp_save_path + "_targets_0.npy")
loaded_targets_1 = np.load(exp_save_path + "_targets_1.npy")
npt.assert_equal(loaded_features, features)
npt.assert_equal(loaded_targets_0, targets_0)
npt.assert_equal(loaded_targets_1, targets_1)


@pytest.mark.parametrize("file_name", ["run", "run.npy", "run.foo"])
def test_load(tmp_path, file_name):
features = np.empty((10, 2, 5))
Expand All @@ -37,6 +54,23 @@ def test_load(tmp_path, file_name):
npt.assert_equal(loaded_targets, targets)


@pytest.mark.parametrize("file_name", ["run", "run.npy", "run.foo"])
def test_load_multi_target(tmp_path, file_name):
features = np.empty((10, 2, 5))
targets_0 = np.empty((10,))
targets_1 = np.empty((10,))
exp_file_name = file_name.replace(".npy", "")
np.save(os.path.join(tmp_path, f"{exp_file_name}_features.npy"), features)
np.save(os.path.join(tmp_path, f"{exp_file_name}_targets_0.npy"), targets_0)
np.save(os.path.join(tmp_path, f"{exp_file_name}_targets_1.npy"), targets_1)

save_path = os.path.join(tmp_path, file_name)
loaded_features, loaded_targets_0, loaded_targets_1 = saving.load(save_path)
npt.assert_equal(loaded_features, features)
npt.assert_equal(loaded_targets_0, targets_0)
npt.assert_equal(loaded_targets_1, targets_1)


@mock.patch("rul_datasets.reader.saving.load", return_value=(None, None))
@pytest.mark.parametrize("file_names", [["run1", "run2"], []])
def test_load_multiple(mock_load, file_names):
Expand All @@ -59,6 +93,18 @@ def test_exists(tmp_path, file_name):
assert saving.exists(save_path)


@pytest.mark.parametrize("file_name", ["run", "run.npy"])
def test_exists_multi_target(tmp_path, file_name):
save_path = os.path.join(tmp_path, file_name)
assert not saving.exists(save_path)

Path(os.path.join(tmp_path, "run_features.npy")).touch()
assert not saving.exists(save_path)

Path(os.path.join(tmp_path, "run_targets_0.npy")).touch()
assert saving.exists(save_path)


@pytest.mark.parametrize("columns", [[0], [0, 1]])
@pytest.mark.parametrize(
"file_name", ["raw_features.csv", "raw_features_corrupted.csv"]
Expand Down
Loading