Skip to content

Commit

Permalink
fix: target shape of dummy (#10)
Browse files Browse the repository at this point in the history
target shape of dummy should be (len,) not (len, 1)
  • Loading branch information
tilman151 authored Dec 8, 2022
1 parent 5c2848b commit d6d51a9
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
4 changes: 2 additions & 2 deletions rul_datasets/reader/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,14 +131,14 @@ def _generate_targets(self, rng):
t = np.clip(np.arange(length, 1, -1), a_min=0, a_max=self.max_rul)
t = t.astype(np.float)

return t[:, None]
return t

def _generate_features(self, rng, targets):
steady = -0.05 * targets + self._OFFSET[self.fd] + rng.normal() * 0.01
noise = rng.normal(size=targets.shape) * self._NOISE_FACTOR[self.fd]
f = np.exp(steady) + noise

return f
return f[:, None]

def _truncate_test_split(self, rng, features, targets):
"""Extract a single window from a random position of the time series."""
Expand Down
1 change: 1 addition & 0 deletions tests/reader/test_dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def _assert_run_correct(run, run_target, win):
assert win == run.shape[1]
assert 1 == run.shape[2]
assert len(run) == len(run_target)
assert run_target.shape == (len(run_target),)
assert np.float == run.dtype
assert np.float == run_target.dtype

Expand Down

0 comments on commit d6d51a9

Please sign in to comment.