Skip to content

Commit

Permalink
Fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
Louis Faury committed Jan 10, 2025
1 parent f93e330 commit d9a4dab
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -12630,9 +12630,7 @@ def test_transform_env(self, num_rewards, weights):

expected = sum(
w * r
for w, r in zip(
weights, rollout.get(("next", "reward")).split(1, dim=-1), strict=True
)
for w, r in zip(weights, rollout.get(("next", "reward")).split(1, dim=-1))
)
torch.testing.assert_close(scalar_reward, expected)

Expand All @@ -12654,7 +12652,7 @@ def test_transform_model(self, num_rewards, weights):
td = TensorDict({"reward": torch.randn(num_rewards)}, [])
model(td)

expected = sum(w * r for w, r in zip(weights, td["reward"], strict=True))
expected = sum(w * r for w, r in zip(weights, td["reward"]))
torch.testing.assert_close(td["scalar_reward"], expected)

@pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
Expand Down

0 comments on commit d9a4dab

Please sign in to comment.