From f27733842d2c3342ba7073cd72234d4d0171d85e Mon Sep 17 00:00:00 2001 From: Tilman Krokotsch Date: Fri, 3 Feb 2023 13:40:49 +0100 Subject: [PATCH] fix: adjust targets when extractor can change number of samples (#22) --- rul_datasets/core.py | 16 ++++++++++++---- tests/test_core.py | 20 +++++++++++++------- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/rul_datasets/core.py b/rul_datasets/core.py index c928768..f117292 100644 --- a/rul_datasets/core.py +++ b/rul_datasets/core.py @@ -61,7 +61,7 @@ def __init__( self, reader: AbstractReader, batch_size: int, - feature_extractor: Optional[Callable[[np.ndarray], np.ndarray]] = None, + feature_extractor: Optional[Callable] = None, window_size: Optional[int] = None, ): """ @@ -236,12 +236,20 @@ def _setup_split(self, split: str) -> Tuple[torch.Tensor, torch.Tensor]: def _apply_feature_extractor_per_run( self, features: List[np.ndarray], targets: List[np.ndarray] ) -> Tuple[List[np.ndarray], List[np.ndarray]]: + extracted = (self._extract_and_window(f, t) for f, t in zip(features, targets)) + features, targets = zip(*extracted) + + return list(features), list(targets) + + def _extract_and_window( + self, features: np.ndarray, targets: np.ndarray + ) -> Tuple[np.ndarray, np.ndarray]: if self.feature_extractor is not None: - features = [self.feature_extractor(f) for f in features] + features, targets = self.feature_extractor(features, targets) if self.window_size is not None: cutoff = self.window_size - 1 - features = [utils.extract_windows(f, self.window_size) for f in features] - targets = [t[cutoff:] for t in targets] + features = utils.extract_windows(features, self.window_size) + targets = targets[cutoff:] return features, targets diff --git a/tests/test_core.py b/tests/test_core.py index f67019a..27991d1 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -201,7 +201,7 @@ def test_feature_extractor(self, mock_loader): [np.zeros((8, 30, 14)) + np.arange(8)[:, None, None]], [np.arange(8)], ) - fe = lambda x: np.mean(x, axis=1) + fe = lambda x, y: (np.mean(x, axis=1), y) dataset = core.RulDataModule(mock_loader, 16, fe, window_size=2) dataset.setup() @@ -217,16 +217,22 @@ def test_feature_extractor_no_rewindowing(self, mock_loader): [np.zeros((8, 30, 14)) + np.arange(8)[:, None, None]], [np.arange(8)], ) - fe = lambda x: np.tile(x, (1, 2, 1)) # repeats window two times + fe = lambda x, y: ( + np.repeat(x, 2, axis=0), + np.repeat(y, 2), + ) # repeats window two times dataset = core.RulDataModule(mock_loader, 16, fe, window_size=None) dataset.setup() dev_data = dataset.to_dataset("dev") - assert len(dev_data) == 8 - for i, (feat, targ) in enumerate(dev_data): - assert feat.shape == torch.Size([14, 60]) - assert torch.dist(feat[:, :30], feat[:, 30:]) == 0.0 # fe applied correctly - assert targ == i + assert len(dev_data) == 16 + for i in range(0, len(dev_data), 2): + f0, t0 = dev_data[i] + f1, t1 = dev_data[i + 1] + assert f0.shape == torch.Size([14, 30]) + assert torch.dist(f0, f1) == 0 # each window is repeated twice + assert t0 == i // 2 # both windows share a label + assert t1 == i // 2 class DummyRul(reader.AbstractReader):