Skip to content

Commit

Permalink
Nit on tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Louis Faury committed Oct 27, 2024
1 parent 53c2c80 commit 9a587df
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,7 @@ def test_correct_range(self, num_categories: int) -> None:
assert actions.min() >= 0
assert actions.max() < num_categories

def test_bounded_gradients(self, distribution: type) -> None:
def test_bounded_gradients(self) -> None:
logits = torch.tensor(
[[1.0, 0.0, torch.finfo().max], [1.0, 0.0, torch.finfo().min]],
requires_grad=True,
Expand Down Expand Up @@ -750,7 +750,7 @@ def test_generate_ordinal_logits_numerical(self) -> None:

class TestOneHotOrdinal:
@pytest.mark.parametrize("dtype", [torch.float, torch.double])
@pytest.mark.parametrize("device", ("cpu", "meta"))
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("logit_shape", [(10,), (10, 10), (5, 10, 20)])
def test_correct_sampling_shape(
self, logit_shape: tuple[int, ...], dtype: torch.dtype, device: str
Expand Down Expand Up @@ -780,7 +780,7 @@ def test_correct_range(self, num_categories: int) -> None:
assert torch.all(actions.sum(-1))
assert actions.shape[-1] == num_categories

def test_bounded_gradients(self, distribution: type) -> None:
def test_bounded_gradients(self) -> None:
logits = torch.tensor(
[[1.0, 0.0, torch.finfo().max], [1.0, 0.0, torch.finfo().min]],
requires_grad=True,
Expand Down

0 comments on commit 9a587df

Please sign in to comment.