diff --git a/tests/test_utils/test_cdist.py b/tests/test_utils/test_cdist.py index 16a04d9..dced45a 100644 --- a/tests/test_utils/test_cdist.py +++ b/tests/test_utils/test_cdist.py @@ -24,7 +24,7 @@ @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) def test_all(dtype: torch.dtype) -> None: - tol = torch.finfo(dtype).eps * 10 + tol = 1e-6 if dtype == torch.float else 1e-14 x = torch.randn(2, 3, 4, dtype=dtype) @@ -40,7 +40,7 @@ def test_all(dtype: torch.dtype) -> None: @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) @pytest.mark.parametrize("p", [2, 3, 4, 5]) def test_ps(dtype: torch.dtype, p: int) -> None: - tol = torch.finfo(dtype).eps * 10 + tol = 1e-6 if dtype == torch.float else 1e-14 x = torch.randn(2, 4, 5, dtype=dtype) y = torch.randn(2, 4, 5, dtype=dtype)