Skip to content

Commit

Permalink
feat(python): Hide polars.testing.* in pytest stack traces (#14399)
Browse files Browse the repository at this point in the history
  • Loading branch information
kalekundert authored Feb 11, 2024
1 parent 84d5920 commit 81bd89c
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 0 deletions.
8 changes: 8 additions & 0 deletions py-polars/polars/testing/asserts/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ def assert_frame_equal(
...
AssertionError: values for column 'a' are different
"""
__tracebackhide__ = True

lazy = _assert_correct_input_type(left, right)
objects = "LazyFrames" if lazy else "DataFrames"

Expand Down Expand Up @@ -132,6 +134,8 @@ def assert_frame_equal(
def _assert_correct_input_type(
left: DataFrame | LazyFrame, right: DataFrame | LazyFrame
) -> bool:
__tracebackhide__ = True

if isinstance(left, DataFrame) and isinstance(right, DataFrame):
return False
elif isinstance(left, LazyFrame) and isinstance(right, LazyFrame):
Expand All @@ -153,6 +157,8 @@ def _assert_frame_schema_equal(
check_column_order: bool,
objects: str,
) -> None:
__tracebackhide__ = True

left_schema, right_schema = left.schema, right.schema

# Fast path for equal frames
Expand Down Expand Up @@ -253,6 +259,8 @@ def assert_frame_not_equal(
...
AssertionError: frames are equal
"""
__tracebackhide__ = True

try:
assert_frame_equal(
left=left,
Expand Down
12 changes: 12 additions & 0 deletions py-polars/polars/testing/asserts/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ def assert_series_equal(
[left]: [1, 2, 3]
[right]: [1, 5, 3]
"""
__tracebackhide__ = True

if not (isinstance(left, Series) and isinstance(right, Series)): # type: ignore[redundant-expr]
raise_assertion_error(
"inputs",
Expand Down Expand Up @@ -119,6 +121,8 @@ def _assert_series_values_equal(
atol: float,
categorical_as_str: bool,
) -> None:
__tracebackhide__ = True

"""Assert that the values in both Series are equal."""
# Handle categoricals
if categorical_as_str:
Expand Down Expand Up @@ -191,6 +195,8 @@ def _assert_series_nested_values_equal(
atol: float,
categorical_as_str: bool,
) -> None:
__tracebackhide__ = True

# compare nested lists element-wise
if _comparing_lists(left.dtype, right.dtype):
for s1, s2 in zip(left, right):
Expand Down Expand Up @@ -221,6 +227,7 @@ def _assert_series_nested_values_equal(


def _assert_series_null_values_match(left: Series, right: Series) -> None:
__tracebackhide__ = True
null_value_mismatch = left.is_null() != right.is_null()
if null_value_mismatch.any():
raise_assertion_error(
Expand All @@ -229,6 +236,7 @@ def _assert_series_null_values_match(left: Series, right: Series) -> None:


def _assert_series_nan_values_match(left: Series, right: Series) -> None:
__tracebackhide__ = True
if not _comparing_floats(left.dtype, right.dtype):
return
nan_value_mismatch = left.is_nan() != right.is_nan()
Expand Down Expand Up @@ -270,6 +278,8 @@ def _assert_series_values_within_tolerance(
rtol: float,
atol: float,
) -> None:
__tracebackhide__ = True

left_unequal, right_unequal = left.filter(unequal), right.filter(unequal)

difference = (left_unequal - right_unequal).abs()
Expand Down Expand Up @@ -339,6 +349,8 @@ def assert_series_not_equal(
...
AssertionError: Series are equal
"""
__tracebackhide__ = True

try:
assert_series_equal(
left=left,
Expand Down
64 changes: 64 additions & 0 deletions py-polars/tests/unit/testing/test_assert_frame_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from polars.testing import assert_frame_equal, assert_frame_not_equal

nan = float("nan")
pytest_plugins = ["pytester"]


@pytest.mark.parametrize(
Expand Down Expand Up @@ -366,3 +367,66 @@ def test_assert_frame_not_equal() -> None:
df = pl.DataFrame({"a": [1, 2]})
with pytest.raises(AssertionError, match="frames are equal"):
assert_frame_not_equal(df, df)


def test_tracebackhide(testdir: pytest.Testdir) -> None:
testdir.makefile(
".py",
test_path="""\
import polars as pl
from polars.testing import assert_frame_equal, assert_frame_not_equal
def test_frame_equal_fail():
df1 = pl.DataFrame({"a": [1, 2]})
df2 = pl.DataFrame({"a": [1, 3]})
assert_frame_equal(df1, df2)
def test_frame_not_equal_fail():
df1 = pl.DataFrame({"a": [1, 2]})
df2 = pl.DataFrame({"a": [1, 2]})
assert_frame_not_equal(df1, df2)
def test_frame_data_type_fail():
df1 = pl.DataFrame({"a": [1, 2]})
df2 = {"a": [1, 2]}
assert_frame_equal(df1, df2)
def test_frame_schema_fail():
df1 = pl.DataFrame({"a": [1, 2]}, {"a": pl.Int64})
df2 = pl.DataFrame({"a": [1, 2]}, {"a": pl.Int32})
assert_frame_equal(df1, df2)
""",
)
result = testdir.runpytest()
result.assert_outcomes(passed=0, failed=4)
stdout = "\n".join(result.outlines)

assert "polars/py-polars/polars/testing" not in stdout

# The above should catch any polars testing functions that appear in the
# stack trace. But we keep the following checks (for specific function
# names) just to double-check.

assert "def assert_frame_equal" not in stdout
assert "def assert_frame_not_equal" not in stdout
assert "def _assert_correct_input_type" not in stdout
assert "def _assert_frame_schema_equal" not in stdout

assert "def assert_series_equal" not in stdout
assert "def assert_series_not_equal" not in stdout
assert "def _assert_series_values_equal" not in stdout
assert "def _assert_series_nested_values_equal" not in stdout
assert "def _assert_series_null_values_match" not in stdout
assert "def _assert_series_nan_values_match" not in stdout
assert "def _assert_series_values_within_tolerance" not in stdout

# Make sure the tests are failing for the expected reason (e.g. not because
# an import is missing or something like that):

assert (
"AssertionError: DataFrames are different (value mismatch for column 'a')"
in stdout
)
assert "AssertionError: frames are equal" in stdout
assert "AssertionError: inputs are different (unexpected input types)" in stdout
assert "AssertionError: DataFrames are different (dtypes do not match)" in stdout
79 changes: 79 additions & 0 deletions py-polars/tests/unit/testing/test_assert_series_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from polars.testing import assert_series_equal, assert_series_not_equal

nan = float("nan")
pytest_plugins = ["pytester"]


def test_compare_series_value_mismatch() -> None:
Expand Down Expand Up @@ -636,3 +637,81 @@ def test_assert_series_equal_w_large_integers_12328() -> None:
right = pl.Series([1577840521123543])
with pytest.raises(AssertionError):
assert_series_equal(left, right)


def test_tracebackhide(testdir: pytest.Testdir) -> None:
testdir.makefile(
".py",
test_path="""\
import polars as pl
from polars.testing import assert_series_equal, assert_series_not_equal
nan = float("nan")
def test_series_equal_fail():
s1 = pl.Series([1, 2])
s2 = pl.Series([1, 3])
assert_series_equal(s1, s2)
def test_series_not_equal_fail():
s1 = pl.Series([1, 2])
s2 = pl.Series([1, 2])
assert_series_not_equal(s1, s2)
def test_series_nested_fail():
s1 = pl.Series([[1, 2], [3, 4]])
s2 = pl.Series([[1, 2], [3, 5]])
assert_series_equal(s1, s2)
def test_series_null_fail():
s1 = pl.Series([1, 2])
s2 = pl.Series([1, None])
assert_series_equal(s1, s2)
def test_series_nan_fail():
s1 = pl.Series([1.0, 2.0])
s2 = pl.Series([1.0, nan])
assert_series_equal(s1, s2)
def test_series_float_tolerance_fail():
s1 = pl.Series([1.0, 2.0])
s2 = pl.Series([1.0, 2.1])
assert_series_equal(s1, s2)
def test_series_schema_fail():
s1 = pl.Series([1, 2], dtype=pl.Int64)
s2 = pl.Series([1, 2], dtype=pl.Int32)
assert_series_equal(s1, s2)
def test_series_data_type_fail():
s1 = pl.Series([1, 2])
s2 = [1, 2]
assert_series_equal(s1, s2)
""",
)
result = testdir.runpytest()
result.assert_outcomes(passed=0, failed=8)
stdout = "\n".join(result.outlines)

assert "polars/py-polars/polars/testing" not in stdout

# The above should catch any polars testing functions that appear in the
# stack trace. But we keep the following checks (for specific function
# names) just to double-check.

assert "def assert_series_equal" not in stdout
assert "def assert_series_not_equal" not in stdout
assert "def _assert_series_values_equal" not in stdout
assert "def _assert_series_nested_values_equal" not in stdout
assert "def _assert_series_null_values_match" not in stdout
assert "def _assert_series_nan_values_match" not in stdout
assert "def _assert_series_values_within_tolerance" not in stdout

# Make sure the tests are failing for the expected reason (e.g. not because
# an import is missing or something like that):

assert "AssertionError: Series are different (exact value mismatch)" in stdout
assert "AssertionError: Series are equal" in stdout
assert "AssertionError: Series are different (nan value mismatch)" in stdout
assert "AssertionError: Series are different (dtype mismatch)" in stdout
assert "AssertionError: inputs are different (unexpected input types)" in stdout

0 comments on commit 81bd89c

Please sign in to comment.