diff --git a/py-polars/polars/testing/asserts/frame.py b/py-polars/polars/testing/asserts/frame.py index 40e52fe0df55..ff2f8fc04c39 100644 --- a/py-polars/polars/testing/asserts/frame.py +++ b/py-polars/polars/testing/asserts/frame.py @@ -85,6 +85,8 @@ def assert_frame_equal( ... AssertionError: values for column 'a' are different """ + __tracebackhide__ = True + lazy = _assert_correct_input_type(left, right) objects = "LazyFrames" if lazy else "DataFrames" @@ -132,6 +134,8 @@ def assert_frame_equal( def _assert_correct_input_type( left: DataFrame | LazyFrame, right: DataFrame | LazyFrame ) -> bool: + __tracebackhide__ = True + if isinstance(left, DataFrame) and isinstance(right, DataFrame): return False elif isinstance(left, LazyFrame) and isinstance(right, LazyFrame): @@ -153,6 +157,8 @@ def _assert_frame_schema_equal( check_column_order: bool, objects: str, ) -> None: + __tracebackhide__ = True + left_schema, right_schema = left.schema, right.schema # Fast path for equal frames @@ -253,6 +259,8 @@ def assert_frame_not_equal( ... AssertionError: frames are equal """ + __tracebackhide__ = True + try: assert_frame_equal( left=left, diff --git a/py-polars/polars/testing/asserts/series.py b/py-polars/polars/testing/asserts/series.py index 69d7e0aec44f..5bf691037ea9 100644 --- a/py-polars/polars/testing/asserts/series.py +++ b/py-polars/polars/testing/asserts/series.py @@ -83,6 +83,8 @@ def assert_series_equal( [left]: [1, 2, 3] [right]: [1, 5, 3] """ + __tracebackhide__ = True + if not (isinstance(left, Series) and isinstance(right, Series)): # type: ignore[redundant-expr] raise_assertion_error( "inputs", @@ -119,6 +121,8 @@ def _assert_series_values_equal( atol: float, categorical_as_str: bool, ) -> None: + __tracebackhide__ = True + """Assert that the values in both Series are equal.""" # Handle categoricals if categorical_as_str: @@ -191,6 +195,8 @@ def _assert_series_nested_values_equal( atol: float, categorical_as_str: bool, ) -> None: + __tracebackhide__ = True + # compare nested lists element-wise if _comparing_lists(left.dtype, right.dtype): for s1, s2 in zip(left, right): @@ -221,6 +227,7 @@ def _assert_series_nested_values_equal( def _assert_series_null_values_match(left: Series, right: Series) -> None: + __tracebackhide__ = True null_value_mismatch = left.is_null() != right.is_null() if null_value_mismatch.any(): raise_assertion_error( @@ -229,6 +236,7 @@ def _assert_series_null_values_match(left: Series, right: Series) -> None: def _assert_series_nan_values_match(left: Series, right: Series) -> None: + __tracebackhide__ = True if not _comparing_floats(left.dtype, right.dtype): return nan_value_mismatch = left.is_nan() != right.is_nan() @@ -270,6 +278,8 @@ def _assert_series_values_within_tolerance( rtol: float, atol: float, ) -> None: + __tracebackhide__ = True + left_unequal, right_unequal = left.filter(unequal), right.filter(unequal) difference = (left_unequal - right_unequal).abs() @@ -339,6 +349,8 @@ def assert_series_not_equal( ... AssertionError: Series are equal """ + __tracebackhide__ = True + try: assert_series_equal( left=left, diff --git a/py-polars/tests/unit/testing/test_assert_frame_equal.py b/py-polars/tests/unit/testing/test_assert_frame_equal.py index a5d00abc6eb9..bf8727c178a1 100644 --- a/py-polars/tests/unit/testing/test_assert_frame_equal.py +++ b/py-polars/tests/unit/testing/test_assert_frame_equal.py @@ -10,6 +10,7 @@ from polars.testing import assert_frame_equal, assert_frame_not_equal nan = float("nan") +pytest_plugins = ["pytester"] @pytest.mark.parametrize( @@ -366,3 +367,66 @@ def test_assert_frame_not_equal() -> None: df = pl.DataFrame({"a": [1, 2]}) with pytest.raises(AssertionError, match="frames are equal"): assert_frame_not_equal(df, df) + + +def test_tracebackhide(testdir: pytest.Testdir) -> None: + testdir.makefile( + ".py", + test_path="""\ +import polars as pl +from polars.testing import assert_frame_equal, assert_frame_not_equal + +def test_frame_equal_fail(): + df1 = pl.DataFrame({"a": [1, 2]}) + df2 = pl.DataFrame({"a": [1, 3]}) + assert_frame_equal(df1, df2) + +def test_frame_not_equal_fail(): + df1 = pl.DataFrame({"a": [1, 2]}) + df2 = pl.DataFrame({"a": [1, 2]}) + assert_frame_not_equal(df1, df2) + +def test_frame_data_type_fail(): + df1 = pl.DataFrame({"a": [1, 2]}) + df2 = {"a": [1, 2]} + assert_frame_equal(df1, df2) + +def test_frame_schema_fail(): + df1 = pl.DataFrame({"a": [1, 2]}, {"a": pl.Int64}) + df2 = pl.DataFrame({"a": [1, 2]}, {"a": pl.Int32}) + assert_frame_equal(df1, df2) +""", + ) + result = testdir.runpytest() + result.assert_outcomes(passed=0, failed=4) + stdout = "\n".join(result.outlines) + + assert "polars/py-polars/polars/testing" not in stdout + + # The above should catch any polars testing functions that appear in the + # stack trace. But we keep the following checks (for specific function + # names) just to double-check. + + assert "def assert_frame_equal" not in stdout + assert "def assert_frame_not_equal" not in stdout + assert "def _assert_correct_input_type" not in stdout + assert "def _assert_frame_schema_equal" not in stdout + + assert "def assert_series_equal" not in stdout + assert "def assert_series_not_equal" not in stdout + assert "def _assert_series_values_equal" not in stdout + assert "def _assert_series_nested_values_equal" not in stdout + assert "def _assert_series_null_values_match" not in stdout + assert "def _assert_series_nan_values_match" not in stdout + assert "def _assert_series_values_within_tolerance" not in stdout + + # Make sure the tests are failing for the expected reason (e.g. not because + # an import is missing or something like that): + + assert ( + "AssertionError: DataFrames are different (value mismatch for column 'a')" + in stdout + ) + assert "AssertionError: frames are equal" in stdout + assert "AssertionError: inputs are different (unexpected input types)" in stdout + assert "AssertionError: DataFrames are different (dtypes do not match)" in stdout diff --git a/py-polars/tests/unit/testing/test_assert_series_equal.py b/py-polars/tests/unit/testing/test_assert_series_equal.py index 4b21a921edcc..e676be77b1ac 100644 --- a/py-polars/tests/unit/testing/test_assert_series_equal.py +++ b/py-polars/tests/unit/testing/test_assert_series_equal.py @@ -11,6 +11,7 @@ from polars.testing import assert_series_equal, assert_series_not_equal nan = float("nan") +pytest_plugins = ["pytester"] def test_compare_series_value_mismatch() -> None: @@ -636,3 +637,81 @@ def test_assert_series_equal_w_large_integers_12328() -> None: right = pl.Series([1577840521123543]) with pytest.raises(AssertionError): assert_series_equal(left, right) + + +def test_tracebackhide(testdir: pytest.Testdir) -> None: + testdir.makefile( + ".py", + test_path="""\ +import polars as pl +from polars.testing import assert_series_equal, assert_series_not_equal + +nan = float("nan") + +def test_series_equal_fail(): + s1 = pl.Series([1, 2]) + s2 = pl.Series([1, 3]) + assert_series_equal(s1, s2) + +def test_series_not_equal_fail(): + s1 = pl.Series([1, 2]) + s2 = pl.Series([1, 2]) + assert_series_not_equal(s1, s2) + +def test_series_nested_fail(): + s1 = pl.Series([[1, 2], [3, 4]]) + s2 = pl.Series([[1, 2], [3, 5]]) + assert_series_equal(s1, s2) + +def test_series_null_fail(): + s1 = pl.Series([1, 2]) + s2 = pl.Series([1, None]) + assert_series_equal(s1, s2) + +def test_series_nan_fail(): + s1 = pl.Series([1.0, 2.0]) + s2 = pl.Series([1.0, nan]) + assert_series_equal(s1, s2) + +def test_series_float_tolerance_fail(): + s1 = pl.Series([1.0, 2.0]) + s2 = pl.Series([1.0, 2.1]) + assert_series_equal(s1, s2) + +def test_series_schema_fail(): + s1 = pl.Series([1, 2], dtype=pl.Int64) + s2 = pl.Series([1, 2], dtype=pl.Int32) + assert_series_equal(s1, s2) + +def test_series_data_type_fail(): + s1 = pl.Series([1, 2]) + s2 = [1, 2] + assert_series_equal(s1, s2) +""", + ) + result = testdir.runpytest() + result.assert_outcomes(passed=0, failed=8) + stdout = "\n".join(result.outlines) + + assert "polars/py-polars/polars/testing" not in stdout + + # The above should catch any polars testing functions that appear in the + # stack trace. But we keep the following checks (for specific function + # names) just to double-check. + + assert "def assert_series_equal" not in stdout + assert "def assert_series_not_equal" not in stdout + assert "def _assert_series_values_equal" not in stdout + assert "def _assert_series_nested_values_equal" not in stdout + assert "def _assert_series_null_values_match" not in stdout + assert "def _assert_series_nan_values_match" not in stdout + assert "def _assert_series_values_within_tolerance" not in stdout + + # Make sure the tests are failing for the expected reason (e.g. not because + # an import is missing or something like that): + + assert "AssertionError: Series are different (exact value mismatch)" in stdout + assert "AssertionError: Series are equal" in stdout + assert "AssertionError: Series are different (nan value mismatch)" in stdout + assert "AssertionError: Series are different (dtype mismatch)" in stdout + assert "AssertionError: inputs are different (unexpected input types)" in stdout