From 9f6d94122e53f1d5f66d4ba871e27e5ffc679f05 Mon Sep 17 00:00:00 2001 From: Tilman Krokotsch Date: Thu, 11 Apr 2024 15:23:41 +0200 Subject: [PATCH] fix: apply max RUL for NCMAPSS correctly (#60) --- rul_datasets/reader/ncmapss.py | 2 ++ tests/reader/test_ncmapss.py | 10 ++++++++++ 2 files changed, 12 insertions(+) diff --git a/rul_datasets/reader/ncmapss.py b/rul_datasets/reader/ncmapss.py index 44bfe08..43db771 100644 --- a/rul_datasets/reader/ncmapss.py +++ b/rul_datasets/reader/ncmapss.py @@ -256,6 +256,8 @@ def load_complete_split( self._window_by_cycle(*unit) for unit in zip(features, targets, auxiliary) ] features, targets = zip(*windowed) + if self.max_rul is not None: + targets = [np.clip(t, 0, self.max_rul) for t in targets] return list(features), list(targets) diff --git a/tests/reader/test_ncmapss.py b/tests/reader/test_ncmapss.py index c476574..57d9307 100644 --- a/tests/reader/test_ncmapss.py +++ b/tests/reader/test_ncmapss.py @@ -87,6 +87,16 @@ def test_scaling(fd, prepared_ncmapss): assert np.all(np.min(feat, axis=(0, 1)).round(6) >= 0) +@pytest.mark.needs_data +@pytest.mark.parametrize("max_rul", [65, None]) +def test_max_rul(max_rul, prepared_ncmapss): + reader = NCmapssReader(1, max_rul=max_rul) + _, targets = reader.load_split("dev") + + for targ in targets: + assert np.all(targ <= (max_rul or np.inf)) + + @pytest.mark.needs_data def test__split_by_unit(prepared_ncmapss): reader = NCmapssReader(1)