From d6d51a921c13f6fdfca7521a98afa44e98a64ffa Mon Sep 17 00:00:00 2001 From: Tilman Krokotsch Date: Thu, 8 Dec 2022 15:44:31 +0100 Subject: [PATCH] fix: target shape of dummy (#10) target shape of dummy should be (len,) not (len, 1) --- rul_datasets/reader/dummy.py | 4 ++-- tests/reader/test_dummy.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/rul_datasets/reader/dummy.py b/rul_datasets/reader/dummy.py index 5bd4c0a..7a04d2f 100644 --- a/rul_datasets/reader/dummy.py +++ b/rul_datasets/reader/dummy.py @@ -131,14 +131,14 @@ def _generate_targets(self, rng): t = np.clip(np.arange(length, 1, -1), a_min=0, a_max=self.max_rul) t = t.astype(np.float) - return t[:, None] + return t def _generate_features(self, rng, targets): steady = -0.05 * targets + self._OFFSET[self.fd] + rng.normal() * 0.01 noise = rng.normal(size=targets.shape) * self._NOISE_FACTOR[self.fd] f = np.exp(steady) + noise - return f + return f[:, None] def _truncate_test_split(self, rng, features, targets): """Extract a single window from a random position of the time series.""" diff --git a/tests/reader/test_dummy.py b/tests/reader/test_dummy.py index 18278d5..1997719 100644 --- a/tests/reader/test_dummy.py +++ b/tests/reader/test_dummy.py @@ -22,6 +22,7 @@ def _assert_run_correct(run, run_target, win): assert win == run.shape[1] assert 1 == run.shape[2] assert len(run) == len(run_target) + assert run_target.shape == (len(run_target),) assert np.float == run.dtype assert np.float == run_target.dtype