From 698bc86b32df888990a279ad8afbffab59f44824 Mon Sep 17 00:00:00 2001 From: marvinfriede <51965259+marvinfriede@users.noreply.github.com> Date: Tue, 25 Jul 2023 13:44:35 +0200 Subject: [PATCH] Explicit tolerances for cdist tests --- tests/test_utils/test_cdist.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/test_utils/test_cdist.py b/tests/test_utils/test_cdist.py index 9486a0c..16a04d9 100644 --- a/tests/test_utils/test_cdist.py +++ b/tests/test_utils/test_cdist.py @@ -24,24 +24,28 @@ @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) def test_all(dtype: torch.dtype) -> None: + tol = torch.finfo(dtype).eps * 10 + x = torch.randn(2, 3, 4, dtype=dtype) d1 = util.cdist(x) d2 = util.distance.cdist_direct_expansion(x, x, p=2) d3 = util.distance.euclidean_dist_quadratic_expansion(x, x) - assert pytest.approx(d1) == d2 - assert pytest.approx(d2) == d3 - assert pytest.approx(d3) == d1 + assert pytest.approx(d1, abs=tol) == d2 + assert pytest.approx(d2, abs=tol) == d3 + assert pytest.approx(d3, abs=tol) == d1 @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) @pytest.mark.parametrize("p", [2, 3, 4, 5]) def test_ps(dtype: torch.dtype, p: int) -> None: + tol = torch.finfo(dtype).eps * 10 + x = torch.randn(2, 4, 5, dtype=dtype) y = torch.randn(2, 4, 5, dtype=dtype) d1 = util.cdist(x, y, p=p) d2 = torch.cdist(x, y, p=p) - assert pytest.approx(d1) == d2 + assert pytest.approx(d1, abs=tol) == d2