Skip to content

Commit

Permalink
Merge Bundle as subclass of SampledFromStrategy
Browse files Browse the repository at this point in the history
  • Loading branch information
reaganjlee committed Oct 11, 2024
1 parent 228437f commit cefcea8
Show file tree
Hide file tree
Showing 4 changed files with 201 additions and 23 deletions.
7 changes: 7 additions & 0 deletions hypothesis-python/RELEASE.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
RELEASE_TYPE: minor

This release changes :class:`hypothesis.stateful.Bundle` to use the internals of
:func:`~hypothesis.strategies.sampled_from`, improving the ``filter`` and ``map`` methods.
In addition to performance improvements, you can now use ``consumes(some_bundle).filter(...)``!

Thanks to Reagan Lee for this feature (:issue:`3944`).
89 changes: 72 additions & 17 deletions hypothesis-python/src/hypothesis/stateful.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,25 @@
"""
import collections
import inspect
import sys
from collections.abc import Iterable, Sequence
from copy import copy
from functools import lru_cache
from functools import lru_cache, partial
from io import StringIO
from time import perf_counter
from typing import Any, Callable, ClassVar, Optional, Union, overload
from typing import (
Any,
Callable,
ClassVar,
Dict,
Iterable,
List,
Optional,
Sequence,
Tuple,
Union,
overload,
)
from unittest import TestCase

import attr
Expand Down Expand Up @@ -56,6 +69,7 @@
Ex,
Ex_Inv,
OneOfStrategy,
SampledFromStrategy,
SearchStrategy,
check_strategy,
)
Expand Down Expand Up @@ -469,7 +483,7 @@ def __repr__(self) -> str:
self_strategy = st.runner()


class Bundle(SearchStrategy[Ex]):
class Bundle(SampledFromStrategy[Ex]):
"""A collection of values for use in stateful testing.
Bundles are a kind of strategy where values can be added by rules,
Expand All @@ -492,32 +506,72 @@ class MyStateMachine(RuleBasedStateMachine):
"""

def __init__(
self, name: str, *, consume: bool = False, draw_references: bool = True
self,
name: str,
*,
consume: bool = False,
draw_references: bool = True,
**kwargs,
) -> None:
super().__init__(
[...], **kwargs
) # Some random items that'll get replaced in do_draw
self.name = name
self.consume = consume
self.draw_references = draw_references

self.bundle = None

# Shrink towards the right rather than the left. This makes it easier
# to delete data generated earlier, as when the error is towards the
# end there can be a lot of hard to remove padding.
self._SHRINK_TOWARDS = sys.maxsize

def reference_to_val_func(self, dic, item):
    """Resolve a bundle position to the value stored for its reference.

    ``item`` must be an integer index into ``self.bundle``. The entry found
    at that index must be a ``VarReference``; its ``name`` is looked up in
    ``dic`` with ``dict.get``, so a missing name yields ``None``.
    """
    assert isinstance(item, int)

    reference = self.bundle[item]
    assert isinstance(reference, VarReference)

    name = reference.name
    return dic.get(name)

def do_draw(self, data):
machine = data.draw(self_strategy)

bundle = machine.bundle(self.name)
if not bundle:
data.mark_invalid(f"Cannot draw from empty bundle {self.name!r}")
# Shrink towards the right rather than the left. This makes it easier
# to delete data generated earlier, as when the error is towards the
# end there can be a lot of hard to remove padding.
position = data.draw_integer(0, len(bundle) - 1, shrink_towards=len(bundle))

# We use both self.bundle and self.elements to make sure an index is used to safely pop
self.bundle = bundle
self.elements = range(len(bundle))

self.reference_to_value = partial(
self.reference_to_val_func, machine.names_to_values
)

idx = super().do_draw(data)
reference = bundle[idx]
if self.consume:
reference = bundle.pop(
position
) # pragma: no cover # coverage is flaky here
else:
reference = bundle[position]
bundle.pop(idx) # pragma: no cover # coverage is flaky here
return reference

def filter(self, condition):
    """Return a new strategy of the same class with a filter step appended.

    The new instance keeps this bundle's ``name``, ``consume`` and
    ``draw_references`` settings, and extends the existing transformation
    chain with ``("filter", condition)``.
    """
    return type(self)(
        self.name,
        consume=self.consume,
        draw_references=self.draw_references,
        transformations=(*self._transformations, ("filter", condition)),
        # The constructor forwards extra keywords to
        # SampledFromStrategy.__init__, whose parameter is named ``repr_``;
        # passing ``repr=`` would raise TypeError (unexpected keyword).
        repr_=self.repr_,
    )

if self.draw_references:
return reference
return machine.names_to_values[reference.name]
def map(self, pack):
    """Return a new strategy of the same class with a map step appended.

    The new instance keeps this bundle's ``name``, ``consume`` and
    ``draw_references`` settings, and extends the existing transformation
    chain with ``("map", pack)``.
    """
    return type(self)(
        self.name,
        consume=self.consume,
        draw_references=self.draw_references,
        transformations=(*self._transformations, ("map", pack)),
        # The constructor forwards extra keywords to
        # SampledFromStrategy.__init__, whose parameter is named ``repr_``;
        # passing ``repr=`` would raise TypeError (unexpected keyword).
        repr_=self.repr_,
    )

def __repr__(self):
consume = self.consume
Expand Down Expand Up @@ -562,6 +616,7 @@ def consumes(bundle: Bundle[Ex]) -> SearchStrategy[Ex]:
return type(bundle)(
name=bundle.name,
consume=True,
transformations=bundle._transformations,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -473,20 +473,22 @@ def is_simple_data(value):
return False


class SampledFromStrategy(SearchStrategy):
class SampledFromStrategy(SearchStrategy[Ex]):
"""A strategy which samples from a set of elements. This is essentially
equivalent to using a OneOfStrategy over Just strategies but may be more
efficient and convenient.
"""

_MAX_FILTER_CALLS = 10_000
_SHRINK_TOWARDS = 0

def __init__(self, elements, repr_=None, transformations=()):
super().__init__()
self.elements = cu.check_sample(elements, "sampled_from")
assert self.elements
self.repr_ = repr_
self._transformations = transformations
self.reference_to_value = lambda x: x

def map(self, pack):
return type(self)(
Expand Down Expand Up @@ -552,7 +554,11 @@ def do_draw(self, data):
return result

def get_element(self, i):
return self._transform(self.elements[i])
element = self.elements[i]
value = self._transform(self.reference_to_value(element))
if is_identity_function(self.reference_to_value):
return value
return element if value is not filter_not_satisfied else filter_not_satisfied

def do_filtered_draw(self, data):
# Set of indices that have been tried so far, so that we never test
Expand All @@ -562,7 +568,9 @@ def do_filtered_draw(self, data):
# Start with ordinary rejection sampling. It's fast if it works, and
# if it doesn't work then it was only a small amount of overhead.
for _ in range(3):
i = data.draw_integer(0, len(self.elements) - 1)
i = data.draw_integer(
0, len(self.elements) - 1, shrink_towards=self._SHRINK_TOWARDS
)
if i not in known_bad_indices:
element = self.get_element(i)
if element is not filter_not_satisfied:
Expand All @@ -583,7 +591,9 @@ def do_filtered_draw(self, data):
# Before building the list of allowed indices, speculatively choose
# one of them. We don't yet know how many allowed indices there will be,
# so this choice might be out-of-bounds, but that's OK.
speculative_index = data.draw_integer(0, max_good_indices - 1)
speculative_index = data.draw_integer(
0, max_good_indices - 1, shrink_towards=self._SHRINK_TOWARDS
)

# Calculate the indices of allowed values, so that we can choose one
# of them at random. But if we encounter the speculatively-chosen one,
Expand All @@ -598,14 +608,21 @@ def do_filtered_draw(self, data):
if len(allowed) > speculative_index:
# Early-exit case: We reached the speculative index, so
# we just return the corresponding element.
data.draw_integer(0, len(self.elements) - 1, forced=i)
data.draw_integer(
0,
len(self.elements) - 1,
forced=i,
shrink_towards=self._SHRINK_TOWARDS,
)
return element

# The speculative index didn't work out, but at this point we've built
# and can choose from the complete list of allowed indices and elements.
if allowed:
i, element = data.choice(allowed)
data.draw_integer(0, len(self.elements) - 1, forced=i)
data.draw_integer(
0, len(self.elements) - 1, forced=i, shrink_towards=self._SHRINK_TOWARDS
)
return element
# If there are no allowed indices, the filter couldn't be satisfied.
return filter_not_satisfied
Expand Down
99 changes: 99 additions & 0 deletions hypothesis-python/tests/cover/test_stateful.py
Original file line number Diff line number Diff line change
Expand Up @@ -1320,6 +1320,105 @@ def rule1(self, data):
TestLotsOfEntropyPerStepMachine = LotsOfEntropyPerStepMachine.TestCase


def test_filter():
    """Check the falsifying-example report when a rule draws from filtered bundles.

    The rule always fails, so the test only pins the exact text of the
    notes attached to the raised AssertionError.
    """

    class Machine(RuleBasedStateMachine):
        a = Bundle("a")

        @initialize(target=a)
        def initialize(self):
            return multiple(1, 2, 3)

        @rule(
            a1=a.filter(lambda x: x < 2),
            a2=a.filter(lambda x: x > 2),
            a3=a,
        )
        def fail_fast(self, a1, a2, a3):
            raise AssertionError

    expected = """
Falsifying example:
state = Machine()
a_0, a_1, a_2 = state.initialize()
state.fail_fast(a1=a_0, a2=a_2, a3=a_2)
state.teardown()
""".strip()

    Machine.TestCase.settings = NO_BLOB_SETTINGS
    with pytest.raises(AssertionError) as excinfo:
        run_state_machine_as_test(Machine)

    report = "\n".join(excinfo.value.__notes__)
    assert report == expected


def test_consumes_filter():
    """Check the falsifying-example report when ``consumes(...)`` is filtered.

    The rule always fails, so the test only pins the exact text of the
    notes attached to the raised AssertionError.
    """

    class Machine(RuleBasedStateMachine):
        a = Bundle("a")

        @initialize(target=a)
        def initialize(self):
            return multiple(1, 2, 3)

        @rule(
            a1=consumes(a).filter(lambda x: x < 2),
            a2=consumes(a).filter(lambda x: x > 2),
            a3=consumes(a),
        )
        def fail_fast(self, a1, a2, a3):
            raise AssertionError

    expected = """
Falsifying example:
state = Machine()
a_0, a_1, a_2 = state.initialize()
state.fail_fast(a1=a_0, a2=a_2, a3=a_1)
state.teardown()
""".strip()

    Machine.TestCase.settings = NO_BLOB_SETTINGS
    with pytest.raises(AssertionError) as excinfo:
        run_state_machine_as_test(Machine)

    report = "\n".join(excinfo.value.__notes__)
    assert report == expected


def test_map_with_filter():
    """Check the falsifying-example report when ``map`` and ``filter`` are chained.

    The rule always fails, so the test only pins the exact text of the
    notes attached to the raised AssertionError.
    """

    class Machine(RuleBasedStateMachine):
        a = Bundle("a")

        @initialize(target=a)
        def initialize(self):
            return multiple(2, 4)

        @rule(
            a1=a.map(lambda x: x**2).filter(lambda x: x < 3**2),
            a2=consumes(a).map(lambda x: x**2).filter(lambda x: x > 3**2),
            a3=consumes(a),
        )
        def fail_fast(self, a1, a2, a3):
            raise AssertionError

    expected = """
Falsifying example:
state = Machine()
a_0, a_1 = state.initialize()
state.fail_fast(a1=a_0, a2=a_1, a3=a_0)
state.teardown()
""".strip()

    Machine.TestCase.settings = NO_BLOB_SETTINGS
    with pytest.raises(AssertionError) as excinfo:
        run_state_machine_as_test(Machine)

    report = "\n".join(excinfo.value.__notes__)
    assert report == expected


def test_flatmap():
class Machine(RuleBasedStateMachine):
buns = Bundle("buns")
Expand Down

0 comments on commit cefcea8

Please sign in to comment.