Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merge Bundle as subclass of SampledFromStrategy #4084

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
7 changes: 7 additions & 0 deletions hypothesis-python/RELEASE.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
RELEASE_TYPE: minor

This release changes :class:`hypothesis.stateful.Bundle` to use the internals of
:func:`~hypothesis.strategies.sampled_from`, improving the `filter` and `map` methods.
In addition to performance improvements, you can now ``consumes(some_bundle).filter(...)``!

Thanks to Reagan Lee for this feature (:issue:`3944`).
113 changes: 88 additions & 25 deletions hypothesis-python/src/hypothesis/stateful.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""
import collections
import inspect
import sys
from collections.abc import Iterable, Sequence
from copy import copy
from functools import lru_cache
Expand Down Expand Up @@ -56,8 +57,10 @@
Ex,
Ex_Inv,
OneOfStrategy,
SampledFromStrategy,
SearchStrategy,
check_strategy,
filter_not_satisfied,
)
from hypothesis.vendor.pretty import RepresentationPrinter

Expand Down Expand Up @@ -184,12 +187,12 @@ def output(s):
try:
data = dict(data)
for k, v in list(data.items()):
if isinstance(v, VarReference):
data[k] = machine.names_to_values[v.name]
if isinstance(v, VarReferenceMapping):
data[k] = v.value
elif isinstance(v, list) and all(
isinstance(item, VarReference) for item in v
isinstance(item, VarReferenceMapping) for item in v
):
data[k] = [machine.names_to_values[item.name] for item in v]
data[k] = [item.value for item in v]

label = f"execute:rule:{rule.function.__name__}"
start = perf_counter()
Expand Down Expand Up @@ -292,12 +295,12 @@ def __init__(self) -> None:
)

def _pretty_print(self, value):
if isinstance(value, VarReference):
return value.name
if isinstance(value, VarReferenceMapping):
return value.reference.name
elif isinstance(value, list) and all(
isinstance(item, VarReference) for item in value
isinstance(item, VarReferenceMapping) for item in value
):
return "[" + ", ".join([item.name for item in value]) + "]"
return "[" + ", ".join([item.reference.name for item in value]) + "]"
self.__stream.seek(0)
self.__stream.truncate(0)
self.__printer.output_width = 0
Expand Down Expand Up @@ -469,7 +472,7 @@ def __repr__(self) -> str:
self_strategy = st.runner()


class Bundle(SearchStrategy[Ex]):
class Bundle(SampledFromStrategy[Ex]):
"""A collection of values for use in stateful testing.

Bundles are a kind of strategy where values can be added by rules,
Expand All @@ -492,32 +495,81 @@ class MyStateMachine(RuleBasedStateMachine):
"""

def __init__(
self, name: str, *, consume: bool = False, draw_references: bool = True
self,
name: str,
*,
consume: bool = False,
draw_references: bool = True,
repr_: Optional[str] = None,
transformations: Iterable[tuple[str, Callable]] = (),
) -> None:
super().__init__(
[...],
repr_=repr_,
transformations=transformations,
) # Some random items that'll get replaced in do_draw
self.name = name
self.consume = consume
self.draw_references = draw_references

def do_draw(self, data):
machine = data.draw(self_strategy)
self.bundle = None
self.machine = None

bundle = machine.bundle(self.name)
if not bundle:
data.mark_invalid(f"Cannot draw from empty bundle {self.name!r}")
# Shrink towards the right rather than the left. This makes it easier
# to delete data generated earlier, as when the error is towards the
# end there can be a lot of hard to remove padding.
position = data.draw_integer(0, len(bundle) - 1, shrink_towards=len(bundle))
self._SHRINK_TOWARDS = sys.maxsize

def get_transformed_value(self, reference):
assert isinstance(reference, VarReference)
return self._transform(self.machine.names_to_values.get(reference.name))

def get_element(self, i):
idx = self.elements[i]
assert isinstance(idx, int)
reference = self.bundle[idx]
value = self.get_transformed_value(reference)
return idx if value is not filter_not_satisfied else filter_not_satisfied

def do_draw(self, data):
self.machine = data.draw(self_strategy)
self.bundle = self.machine.bundle(self.name)
if not self.bundle:
data.mark_invalid(f"Cannot draw from empty bundle {self.name!r}")

# We use both self.bundle and self.elements to make sure an index is used to safely pop
self.elements = range(len(self.bundle))

idx = super().do_draw(data)
reference = self.bundle[idx]
Comment on lines +540 to +544
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is missing handling for the .map() case; we need to use .get_element() for that - and I guess either return (value, index) or track the latest index as state on the class.


if self.consume:
reference = bundle.pop(
position
) # pragma: no cover # coverage is flaky here
else:
reference = bundle[position]
self.bundle.pop(idx) # pragma: no cover # coverage is flaky here

if not self.draw_references:
return self.get_transformed_value(reference)

# we need both reference and the value itself to pretty-print deterministically
# and maintain any transformations that is bundle-specific
return VarReferenceMapping(reference, self.get_transformed_value(reference))

def filter(self, condition):
return type(self)(
self.name,
consume=self.consume,
draw_references=self.draw_references,
transformations=(*self._transformations, ("filter", condition)),
repr_=self.repr_,
)

if self.draw_references:
return reference
return machine.names_to_values[reference.name]
def map(self, pack):
return type(self)(
self.name,
consume=self.consume,
draw_references=self.draw_references,
transformations=(*self._transformations, ("map", pack)),
repr_=self.repr_,
)

def __repr__(self):
consume = self.consume
Expand All @@ -539,7 +591,11 @@ def available(self, data):
def flatmap(self, expand):
if self.draw_references:
return type(self)(
self.name, consume=self.consume, draw_references=False
self.name,
consume=self.consume,
draw_references=False,
transformations=self._transformations,
repr_=self.repr_,
).flatmap(expand)
return super().flatmap(expand)

Expand All @@ -562,6 +618,7 @@ def consumes(bundle: Bundle[Ex]) -> SearchStrategy[Ex]:
return type(bundle)(
name=bundle.name,
consume=True,
transformations=bundle._transformations,
)


Expand Down Expand Up @@ -828,6 +885,12 @@ class VarReference:
name = attr.ib()


@attr.s()
class VarReferenceMapping:
reference: VarReference = attr.ib()
value: Any = attr.ib()


# There are multiple alternatives for annotating the `precond` type, all of them
# have drawbacks. See https://github.com/HypothesisWorks/hypothesis/pull/3068#issuecomment-906642371
def precondition(precond: Callable[[Any], bool]) -> Callable[[TestFunc], TestFunc]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -473,13 +473,14 @@ def is_simple_data(value):
return False


class SampledFromStrategy(SearchStrategy):
class SampledFromStrategy(SearchStrategy[Ex]):
"""A strategy which samples from a set of elements. This is essentially
equivalent to using a OneOfStrategy over Just strategies but may be more
efficient and convenient.
"""

_MAX_FILTER_CALLS = 10_000
_SHRINK_TOWARDS = 0

def __init__(self, elements, repr_=None, transformations=()):
super().__init__()
Expand Down Expand Up @@ -562,7 +563,9 @@ def do_filtered_draw(self, data):
# Start with ordinary rejection sampling. It's fast if it works, and
# if it doesn't work then it was only a small amount of overhead.
for _ in range(3):
i = data.draw_integer(0, len(self.elements) - 1)
i = data.draw_integer(
0, len(self.elements) - 1, shrink_towards=self._SHRINK_TOWARDS
)
if i not in known_bad_indices:
element = self.get_element(i)
if element is not filter_not_satisfied:
Expand All @@ -583,7 +586,9 @@ def do_filtered_draw(self, data):
# Before building the list of allowed indices, speculatively choose
# one of them. We don't yet know how many allowed indices there will be,
# so this choice might be out-of-bounds, but that's OK.
speculative_index = data.draw_integer(0, max_good_indices - 1)
speculative_index = data.draw_integer(
0, max_good_indices - 1, shrink_towards=self._SHRINK_TOWARDS
)

# Calculate the indices of allowed values, so that we can choose one
# of them at random. But if we encounter the speculatively-chosen one,
Expand All @@ -598,14 +603,21 @@ def do_filtered_draw(self, data):
if len(allowed) > speculative_index:
# Early-exit case: We reached the speculative index, so
# we just return the corresponding element.
data.draw_integer(0, len(self.elements) - 1, forced=i)
data.draw_integer(
0,
len(self.elements) - 1,
forced=i,
shrink_towards=self._SHRINK_TOWARDS,
)
return element

# The speculative index didn't work out, but at this point we've built
# and can choose from the complete list of allowed indices and elements.
if allowed:
i, element = data.choice(allowed)
data.draw_integer(0, len(self.elements) - 1, forced=i)
data.draw_integer(
0, len(self.elements) - 1, forced=i, shrink_towards=self._SHRINK_TOWARDS
)
return element
# If there are no allowed indices, the filter couldn't be satisfied.
return filter_not_satisfied
Expand Down
Loading
Loading