
Commit

Switch back to pytest_runtest_protocol
kenodegard committed Mar 26, 2024
1 parent 26b5d09 commit f2595e8
Showing 2 changed files with 175 additions and 61 deletions.
172 changes: 116 additions & 56 deletions src/pytest_codspeed/plugin.py
@@ -1,6 +1,6 @@
from __future__ import annotations

-import functools
+import contextlib
import gc
import os
import pkgutil
@@ -10,14 +10,15 @@

import pytest
from _pytest.fixtures import FixtureManager
+from _pytest.runner import runtestprotocol

from pytest_codspeed.utils import get_git_relative_uri

from . import __version__
from ._wrapper import get_lib

if TYPE_CHECKING:
-    from typing import Any, Callable, Iterator, ParamSpec, TypeVar
+    from typing import Callable, Iterator, ParamSpec, TypeVar

from ._wrapper import LibType

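The ParamSpec/TypeVar pair imported under TYPE_CHECKING is what lets the helpers in this commit forward a test function's exact signature through a wrapper. A minimal, self-contained sketch of that typing pattern (Python 3.10+; the names P and T mirror the type variables the plugin presumably declares elsewhere in the module, and run_with_timer is illustrative, not the plugin's code):

```python
from __future__ import annotations

import time
from typing import Callable, ParamSpec, TypeVar

P = ParamSpec("P")
T = TypeVar("T")


def run_with_timer(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
    """Call func while preserving its original signature for type checkers."""
    start = time.perf_counter()
    try:
        return func(*args, **kwargs)
    finally:
        print(f"{func.__name__} took {time.perf_counter() - start:.6f}s")


def add(a: int, b: int) -> int:
    return a + b


result: int = run_with_timer(add, 1, b=2)  # type checkers see (int, int) -> int
```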
@@ -172,60 +173,122 @@ def pytest_collection_modifyitems(
items[:] = selected


-def wrap_pyfunc_with_instrumentation(
+def _run_with_instrumentation(
    lib: LibType,
    nodeid: str,
    config: pytest.Config,
    testfunction: Callable[P, T],
-) -> Callable[P, T]:
-    @functools.wraps(testfunction)
-    def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
-        def __codspeed_root_frame__():
-            return testfunction(*args, **kwargs)
-
-        is_gc_enabled = gc.isenabled()
-        if is_gc_enabled:
-            gc.collect()
-            gc.disable()
-        try:
-            if SUPPORTS_PERF_TRAMPOLINE:
-                # Warmup CPython performance map cache
-                __codspeed_root_frame__()
-
-            lib.zero_stats()
-            lib.start_instrumentation()
-            try:
-                return __codspeed_root_frame__()
-            finally:
-                lib.stop_instrumentation()
-                uri = get_git_relative_uri(nodeid, config.rootpath)
-                lib.dump_stats_at(uri.encode("ascii"))
-        finally:
-            if is_gc_enabled:
-                gc.enable()
-
-    return wrapper
-
-
-@pytest.hookimpl(hookwrapper=True)
-def pytest_pyfunc_call(pyfuncitem: pytest.Function) -> Iterator[None]:
-    plugin = get_plugin(pyfuncitem.config)
-    if (
-        plugin.is_codspeed_enabled
-        and should_benchmark_item(pyfuncitem)
-        and not has_benchmark_fixture(pyfuncitem)
-    ):
-        plugin.benchmark_count += 1
-        if plugin.lib is not None and plugin.should_measure:
-            pyfuncitem.obj = wrap_pyfunc_with_instrumentation(
-                plugin.lib,
-                pyfuncitem.nodeid,
-                pyfuncitem.config,
-                pyfuncitem.obj,
-            )
+    *args: P.args,
+    **kwargs: P.kwargs,
+) -> T:
lib.zero_stats()
lib.start_instrumentation()

try:
return testfunction(*args, **kwargs)
finally:
lib.stop_instrumentation()
uri = get_git_relative_uri(nodeid, config.rootpath)
lib.dump_stats_at(uri.encode("ascii"))

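For context, the helper above always drives the instrumentation library through the same sequence: reset counters, start measuring, run the test, then stop and dump stats even if the test raised. A standalone sketch of that sequence against a dummy stand-in for the instrumentation library (the stub class below is illustrative only; the real LibType comes from pytest_codspeed._wrapper):

```python
from typing import Any, Callable


class FakeInstrumentLib:
    """Illustrative stand-in exposing the same calls the plugin uses."""

    def zero_stats(self) -> None:
        print("stats zeroed")

    def start_instrumentation(self) -> None:
        print("instrumentation started")

    def stop_instrumentation(self) -> None:
        print("instrumentation stopped")

    def dump_stats_at(self, uri: bytes) -> None:
        print(f"stats dumped for {uri!r}")


def run_measured(lib: FakeInstrumentLib, uri: str, func: Callable[..., Any]) -> Any:
    lib.zero_stats()
    lib.start_instrumentation()
    try:
        return func()  # only the benchmarked call sits between start and stop
    finally:
        # stop/dump run even when the call raises, mirroring the try/finally above
        lib.stop_instrumentation()
        lib.dump_stats_at(uri.encode("ascii"))


if __name__ == "__main__":
    run_measured(FakeInstrumentLib(), "tests/test_foo.py::test_bar", lambda: 1 + 1)
```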

@contextlib.contextmanager
def collect_garbage(is_benchmarking: bool) -> Iterator[None]:
if toggle := is_benchmarking and gc.isenabled():
gc.collect()
gc.disable()

try:
yield
finally:
# Re-enable garbage collection if it was enabled previously
if toggle:
gc.enable()

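collect_garbage suspends the cyclic garbage collector around a benchmarked test so a collection pause cannot land inside the measured region, then restores it only if it was enabled to begin with. A minimal standalone illustration of the same pattern:

```python
import contextlib
import gc
from typing import Iterator


@contextlib.contextmanager
def gc_paused() -> Iterator[None]:
    was_enabled = gc.isenabled()
    if was_enabled:
        gc.collect()  # flush pending garbage before measuring
        gc.disable()  # no collector pauses inside the measured region
    try:
        yield
    finally:
        if was_enabled:
            gc.enable()  # only restore if it was on to begin with


with gc_paused():
    sum(i * i for i in range(10_000))  # stand-in for the benchmarked work
print("gc enabled again:", gc.isenabled())
```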

@contextlib.contextmanager
def prime_cache(
is_benchmarking: bool,
item: pytest.Item,
nextitem: pytest.Item | None,
) -> Iterator[None]:
if SUPPORTS_PERF_TRAMPOLINE and is_benchmarking:
runtestprotocol(item, log=False, nextitem=nextitem)

        # Clear the item's cached results from the warmup run
_remove_cached_results_from_fixtures(item)
_remove_setup_state_from_session(item)

yield

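When the perf trampoline is available, prime_cache runs the whole test protocol once, unlogged, before the measured run so CPython has already emitted perf-map entries for every frame involved, then clears fixture and setup-state caches so the measured run starts fresh. As a hedged aside, SUPPORTS_PERF_TRAMPOLINE presumably corresponds to CPython 3.12+ perf support along these lines (the detection logic and map path below are assumptions, not the plugin's code):

```python
import os
import sys

# CPython 3.12+ can emit /tmp/perf-<pid>.map entries so native profilers
# can resolve Python frames by name.
SUPPORTS_PERF_TRAMPOLINE = sys.version_info[:2] >= (3, 12)

if SUPPORTS_PERF_TRAMPOLINE:
    try:
        sys.activate_stack_trampoline("perf")  # start writing the perf map
        print("perf map:", f"/tmp/perf-{os.getpid()}.map")
    except (ValueError, OSError):
        print("perf trampoline unsupported on this platform")
else:
    print("perf trampoline not available on", sys.version)
```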

@contextlib.contextmanager
def add_instrumentation(plugin: CodSpeedPlugin, item: pytest.Item) -> Iterator[None]:
if wrapped := (plugin.lib is not None and plugin.should_measure):
orig_obj = item.obj
item.obj = lambda *args, **kwargs: _run_with_instrumentation(
plugin.lib,
item.nodeid,
item.config,
orig_obj,
*args,
**kwargs,
)

try:
yield
finally:
        # Restore the original, unwrapped test function
if wrapped:
item.obj = orig_obj

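add_instrumentation temporarily replaces the collected item's underlying test callable (item.obj) with a wrapper that routes the call through _run_with_instrumentation, then puts the original back so nothing leaks into later runs. The swap-and-restore pattern in isolation, using a plain object rather than a pytest item:

```python
import contextlib
from typing import Any, Callable, Iterator


class Item:
    """Toy stand-in for a collected test item with a callable obj."""

    def __init__(self, obj: Callable[..., Any]) -> None:
        self.obj = obj


@contextlib.contextmanager
def wrapped_obj(
    item: Item, wrap: Callable[[Callable[..., Any]], Callable[..., Any]]
) -> Iterator[None]:
    original = item.obj
    item.obj = wrap(original)  # swap in the instrumented wrapper
    try:
        yield
    finally:
        item.obj = original  # always restore, even if the call raised


item = Item(lambda: 1 + 1)
with wrapped_obj(item, lambda f: (lambda: print("measuring...") or f())):
    item.obj()  # prints "measuring..." then runs the original callable
assert item.obj() == 2  # original restored
```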

def _remove_cached_results_from_fixtures(item: pytest.Item) -> None:
"""Borrowed from pytest_rerunfailures._remove_cached_results_from_failed_fixtures"""
fixtureinfo = getattr(item, "_fixtureinfo", None)
name2fixturedefs = getattr(fixtureinfo, "name2fixturedefs", {})
for fixture_defs in name2fixturedefs.values():
for fixture_def in fixture_defs:
setattr(fixture_def, "cached_result", None)


def _remove_setup_state_from_session(item: pytest.Item) -> None:
"""Borrowed from pytest_rerunfailures._remove_failed_setup_state_from_session"""
item.session._setupstate.stack = {}


@pytest.hookimpl(tryfirst=True)
def pytest_runtest_protocol(
item: pytest.Item,
nextitem: pytest.Item | None,
) -> bool | None:
plugin = get_plugin(item.config)
if not plugin.is_codspeed_enabled or not should_benchmark_item(item):
return None # Defer to default test protocol since no benchmarking is needed

if has_benchmark_fixture(item):
return None # Instrumentation is handled by the fixture

plugin.benchmark_count += 1
if not plugin.should_measure:
return None # Benchmark counted but will be run in the default protocol

ihook = item.ihook
ihook.pytest_runtest_logstart(nodeid=item.nodeid, location=item.location)

is_benchmarking = plugin.is_codspeed_enabled and should_benchmark_item(item)
with (
collect_garbage(is_benchmarking),
prime_cache(is_benchmarking, item, nextitem),
add_instrumentation(plugin, item),
):
# Run the test
runtestprotocol(item, log=True, nextitem=nextitem)

ihook.pytest_runtest_logfinish(nodeid=item.nodeid, location=item.location)
return True

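The hook above follows pytest's run-protocol contract: returning None defers to the default protocol, while running the item yourself (logstart, runtestprotocol, logfinish) and returning True tells pytest the item has already been handled. A pared-down conftest.py sketch of that contract without the benchmarking pieces (the marker name is illustrative, not from this commit):

```python
# conftest.py -- minimal sketch of taking over the run protocol for marked items
from __future__ import annotations

import pytest
from _pytest.runner import runtestprotocol


@pytest.hookimpl(tryfirst=True)
def pytest_runtest_protocol(
    item: pytest.Item, nextitem: pytest.Item | None
) -> bool | None:
    if not item.get_closest_marker("takeover"):  # hypothetical marker name
        return None  # defer to pytest's default protocol

    ihook = item.ihook
    ihook.pytest_runtest_logstart(nodeid=item.nodeid, location=item.location)
    runtestprotocol(item, log=True, nextitem=nextitem)  # setup/call/teardown
    ihook.pytest_runtest_logfinish(nodeid=item.nodeid, location=item.location)
    return True  # tell pytest this item has already been run
```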

class BenchmarkFixture:
"""The fixture that can be used to benchmark a function."""

@@ -234,7 +297,7 @@ def __init__(self, request: pytest.FixtureRequest):

self._request = request

-    def __call__(self, func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
+    def __call__(self, func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
config = self._request.config
plugin = get_plugin(config)
plugin.benchmark_count += 1
@@ -243,12 +306,9 @@ def __call__(self, func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
and plugin.lib is not None
and plugin.should_measure
):
-            return wrap_pyfunc_with_instrumentation(
-                plugin.lib,
-                self._request.node.nodeid,
-                config,
-                func,
-            )(*args, **kwargs)
+            return _run_with_instrumentation(
+                plugin.lib, self._request.node.nodeid, config, func, *args, **kwargs
+            )
else:
return func(*args, **kwargs)

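For completeness, a sketch of both ways the benchmark fixture is exercised by the tests below: called directly with a function and its arguments, or used as a decorator on a nested function (the decorator form works because __call__ invokes the function immediately):

```python
def test_direct_call(benchmark):
    # benchmark(func, *args, **kwargs) runs func once under instrumentation
    # and returns its result
    result = benchmark(sum, [1, 2, 3])
    assert result == 6


def test_decorator_form(benchmark):
    @benchmark  # equivalent to benchmark(nested); runs immediately
    def nested():
        return 1 + 1
```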
64 changes: 59 additions & 5 deletions tests/test_pytest_plugin.py
@@ -289,12 +289,12 @@ def test_perf_maps_generation(pytester: pytest.Pytester, codspeed_env) -> None:
@pytest.mark.benchmark
def test_some_addition_marked():
-    return 1 + 1
+    assert 1 + 1
def test_some_addition_fixtured(benchmark):
@benchmark
def fixtured_child():
-        return 1 + 1
+        assert 1 + 1
"""
)
with codspeed_env():
@@ -308,9 +308,7 @@ def fixtured_child():
lines = perf_file.readlines()

assert any(
"py::wrap_pyfunc_with_instrumentation.<locals>.wrapper.<locals>.__codspeed_root_frame__"
in line
for line in lines
"py::_run_with_instrumentation" in line for line in lines
), "No root frame found in perf map"
assert any(
"py::test_some_addition_marked" in line for line in lines
@@ -397,3 +395,59 @@ def test_capsys(capsys):
result.stdout.fnmatch_lines(["*1 benchmarked*"])
result.stdout.no_fnmatch_line("*print to stdout*")
result.stderr.no_fnmatch_line("*print to stderr*")


@skip_without_valgrind
@skip_with_pytest_benchmark
def test_benchmark_marker_tmp_path(pytester: pytest.Pytester, codspeed_env) -> None:
pytester.makepyfile(
"""
import pytest
@pytest.mark.benchmark
def test_tmp_path(tmp_path):
(tmp_path / "random").mkdir()
"""
)
with codspeed_env():
result = pytester.runpytest("--codspeed")
assert result.ret == 0, "the run should have succeeded"


@skip_without_valgrind
@skip_with_pytest_benchmark
def test_benchmark_fixture_tmp_path(pytester: pytest.Pytester, codspeed_env) -> None:
pytester.makepyfile(
"""
import pytest
def test_tmp_path(benchmark, tmp_path):
@benchmark
def _():
(tmp_path / "random").mkdir()
"""
)
with codspeed_env():
result = pytester.runpytest("--codspeed")
assert result.ret == 0, "the run should have succeeded"


@skip_without_valgrind
@skip_with_pytest_benchmark
def test_benchmark_fixture_warmup(pytester: pytest.Pytester, codspeed_env) -> None:
pytester.makepyfile(
"""
def test_bench(benchmark):
called_once = False
@benchmark
def _():
nonlocal called_once
if not called_once:
called_once = True
else:
raise Exception("called twice")
"""
)
with codspeed_env():
result = pytester.runpytest("--codspeed")
assert result.ret == 0, "the run should have succeeded"
