Skip to content

Commit

Permalink
Explicit tolerances for cdist tests
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinfriede committed Jul 25, 2023
1 parent 38e58f0 commit 698bc86
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions tests/test_utils/test_cdist.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,28 @@

@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
def test_all(dtype: torch.dtype) -> None:
tol = torch.finfo(dtype).eps * 10

x = torch.randn(2, 3, 4, dtype=dtype)

d1 = util.cdist(x)
d2 = util.distance.cdist_direct_expansion(x, x, p=2)
d3 = util.distance.euclidean_dist_quadratic_expansion(x, x)

assert pytest.approx(d1) == d2
assert pytest.approx(d2) == d3
assert pytest.approx(d3) == d1
assert pytest.approx(d1, abs=tol) == d2
assert pytest.approx(d2, abs=tol) == d3
assert pytest.approx(d3, abs=tol) == d1


@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
@pytest.mark.parametrize("p", [2, 3, 4, 5])
def test_ps(dtype: torch.dtype, p: int) -> None:
tol = torch.finfo(dtype).eps * 10

x = torch.randn(2, 4, 5, dtype=dtype)
y = torch.randn(2, 4, 5, dtype=dtype)

d1 = util.cdist(x, y, p=p)
d2 = torch.cdist(x, y, p=p)

assert pytest.approx(d1) == d2
assert pytest.approx(d1, abs=tol) == d2

0 comments on commit 698bc86

Please sign in to comment.