diff --git a/hypothesis-python/RELEASE.rst b/hypothesis-python/RELEASE.rst new file mode 100644 index 0000000000..c8739170f4 --- /dev/null +++ b/hypothesis-python/RELEASE.rst @@ -0,0 +1,7 @@ +RELEASE_TYPE: minor + +This release changes :class:`hypothesis.stateful.Bundle` to use the internals of +:func:`~hypothesis.strategies.sampled_from`, improving the `filter` and `map` methods. +In addition to performance improvements, you can now ``consumes(some_bundle).filter(...)``! + +Thanks to Reagan Lee for this feature (:issue:`3944`). \ No newline at end of file diff --git a/hypothesis-python/src/hypothesis/stateful.py b/hypothesis-python/src/hypothesis/stateful.py index 7c60d2752f..512117e0b5 100644 --- a/hypothesis-python/src/hypothesis/stateful.py +++ b/hypothesis-python/src/hypothesis/stateful.py @@ -17,6 +17,7 @@ """ import collections import inspect +import sys from collections.abc import Iterable, Sequence from copy import copy from functools import lru_cache @@ -56,8 +57,10 @@ Ex, Ex_Inv, OneOfStrategy, + SampledFromStrategy, SearchStrategy, check_strategy, + filter_not_satisfied, ) from hypothesis.vendor.pretty import RepresentationPrinter @@ -184,12 +187,12 @@ def output(s): try: data = dict(data) for k, v in list(data.items()): - if isinstance(v, VarReference): - data[k] = machine.names_to_values[v.name] + if isinstance(v, VarReferenceMapping): + data[k] = v.value elif isinstance(v, list) and all( - isinstance(item, VarReference) for item in v + isinstance(item, VarReferenceMapping) for item in v ): - data[k] = [machine.names_to_values[item.name] for item in v] + data[k] = [item.value for item in v] label = f"execute:rule:{rule.function.__name__}" start = perf_counter() @@ -292,12 +295,12 @@ def __init__(self) -> None: ) def _pretty_print(self, value): - if isinstance(value, VarReference): - return value.name + if isinstance(value, VarReferenceMapping): + return value.reference.name elif isinstance(value, list) and all( - isinstance(item, VarReference) for item in value + isinstance(item, VarReferenceMapping) for item in value ): - return "[" + ", ".join([item.name for item in value]) + "]" + return "[" + ", ".join([item.reference.name for item in value]) + "]" self.__stream.seek(0) self.__stream.truncate(0) self.__printer.output_width = 0 @@ -469,7 +472,7 @@ def __repr__(self) -> str: self_strategy = st.runner() -class Bundle(SearchStrategy[Ex]): +class Bundle(SampledFromStrategy[Ex]): """A collection of values for use in stateful testing. Bundles are a kind of strategy where values can be added by rules, @@ -492,32 +495,81 @@ class MyStateMachine(RuleBasedStateMachine): """ def __init__( - self, name: str, *, consume: bool = False, draw_references: bool = True + self, + name: str, + *, + consume: bool = False, + draw_references: bool = True, + repr_: Optional[str] = None, + transformations: Iterable[tuple[str, Callable]] = (), ) -> None: + super().__init__( + [...], + repr_=repr_, + transformations=transformations, + ) # Some random items that'll get replaced in do_draw self.name = name self.consume = consume self.draw_references = draw_references - def do_draw(self, data): - machine = data.draw(self_strategy) + self.bundle = None + self.machine = None - bundle = machine.bundle(self.name) - if not bundle: - data.mark_invalid(f"Cannot draw from empty bundle {self.name!r}") # Shrink towards the right rather than the left. This makes it easier # to delete data generated earlier, as when the error is towards the # end there can be a lot of hard to remove padding. - position = data.draw_integer(0, len(bundle) - 1, shrink_towards=len(bundle)) + self._SHRINK_TOWARDS = sys.maxsize + + def get_transformed_value(self, reference): + assert isinstance(reference, VarReference) + return self._transform(self.machine.names_to_values.get(reference.name)) + + def get_element(self, i): + idx = self.elements[i] + assert isinstance(idx, int) + reference = self.bundle[idx] + value = self.get_transformed_value(reference) + return idx if value is not filter_not_satisfied else filter_not_satisfied + + def do_draw(self, data): + self.machine = data.draw(self_strategy) + self.bundle = self.machine.bundle(self.name) + if not self.bundle: + data.mark_invalid(f"Cannot draw from empty bundle {self.name!r}") + + # We use both self.bundle and self.elements to make sure an index is used to safely pop + self.elements = range(len(self.bundle)) + + idx = super().do_draw(data) + reference = self.bundle[idx] + if self.consume: - reference = bundle.pop( - position - ) # pragma: no cover # coverage is flaky here - else: - reference = bundle[position] + self.bundle.pop(idx) # pragma: no cover # coverage is flaky here + + if not self.draw_references: + return self.get_transformed_value(reference) + + # we need both reference and the value itself to pretty-print deterministically + # and maintain any transformations that is bundle-specific + return VarReferenceMapping(reference, self.get_transformed_value(reference)) + + def filter(self, condition): + return type(self)( + self.name, + consume=self.consume, + draw_references=self.draw_references, + transformations=(*self._transformations, ("filter", condition)), + repr_=self.repr_, + ) - if self.draw_references: - return reference - return machine.names_to_values[reference.name] + def map(self, pack): + return type(self)( + self.name, + consume=self.consume, + draw_references=self.draw_references, + transformations=(*self._transformations, ("map", pack)), + repr_=self.repr_, + ) def __repr__(self): consume = self.consume @@ -539,7 +591,11 @@ def available(self, data): def flatmap(self, expand): if self.draw_references: return type(self)( - self.name, consume=self.consume, draw_references=False + self.name, + consume=self.consume, + draw_references=False, + transformations=self._transformations, + repr_=self.repr_, ).flatmap(expand) return super().flatmap(expand) @@ -562,6 +618,7 @@ def consumes(bundle: Bundle[Ex]) -> SearchStrategy[Ex]: return type(bundle)( name=bundle.name, consume=True, + transformations=bundle._transformations, ) @@ -828,6 +885,12 @@ class VarReference: name = attr.ib() +@attr.s() +class VarReferenceMapping: + reference: VarReference = attr.ib() + value: Any = attr.ib() + + # There are multiple alternatives for annotating the `precond` type, all of them # have drawbacks. See https://github.com/HypothesisWorks/hypothesis/pull/3068#issuecomment-906642371 def precondition(precond: Callable[[Any], bool]) -> Callable[[TestFunc], TestFunc]: diff --git a/hypothesis-python/src/hypothesis/strategies/_internal/strategies.py b/hypothesis-python/src/hypothesis/strategies/_internal/strategies.py index c040f748a5..a6fc51cefc 100644 --- a/hypothesis-python/src/hypothesis/strategies/_internal/strategies.py +++ b/hypothesis-python/src/hypothesis/strategies/_internal/strategies.py @@ -473,13 +473,14 @@ def is_simple_data(value): return False -class SampledFromStrategy(SearchStrategy): +class SampledFromStrategy(SearchStrategy[Ex]): """A strategy which samples from a set of elements. This is essentially equivalent to using a OneOfStrategy over Just strategies but may be more efficient and convenient. """ _MAX_FILTER_CALLS = 10_000 + _SHRINK_TOWARDS = 0 def __init__(self, elements, repr_=None, transformations=()): super().__init__() @@ -562,7 +563,9 @@ def do_filtered_draw(self, data): # Start with ordinary rejection sampling. It's fast if it works, and # if it doesn't work then it was only a small amount of overhead. for _ in range(3): - i = data.draw_integer(0, len(self.elements) - 1) + i = data.draw_integer( + 0, len(self.elements) - 1, shrink_towards=self._SHRINK_TOWARDS + ) if i not in known_bad_indices: element = self.get_element(i) if element is not filter_not_satisfied: @@ -583,7 +586,9 @@ def do_filtered_draw(self, data): # Before building the list of allowed indices, speculatively choose # one of them. We don't yet know how many allowed indices there will be, # so this choice might be out-of-bounds, but that's OK. - speculative_index = data.draw_integer(0, max_good_indices - 1) + speculative_index = data.draw_integer( + 0, max_good_indices - 1, shrink_towards=self._SHRINK_TOWARDS + ) # Calculate the indices of allowed values, so that we can choose one # of them at random. But if we encounter the speculatively-chosen one, @@ -598,14 +603,21 @@ def do_filtered_draw(self, data): if len(allowed) > speculative_index: # Early-exit case: We reached the speculative index, so # we just return the corresponding element. - data.draw_integer(0, len(self.elements) - 1, forced=i) + data.draw_integer( + 0, + len(self.elements) - 1, + forced=i, + shrink_towards=self._SHRINK_TOWARDS, + ) return element # The speculative index didn't work out, but at this point we've built # and can choose from the complete list of allowed indices and elements. if allowed: i, element = data.choice(allowed) - data.draw_integer(0, len(self.elements) - 1, forced=i) + data.draw_integer( + 0, len(self.elements) - 1, forced=i, shrink_towards=self._SHRINK_TOWARDS + ) return element # If there are no allowed indices, the filter couldn't be satisfied. return filter_not_satisfied diff --git a/hypothesis-python/tests/cover/test_stateful.py b/hypothesis-python/tests/cover/test_stateful.py index 2dd51081cc..23b5c9a82e 100644 --- a/hypothesis-python/tests/cover/test_stateful.py +++ b/hypothesis-python/tests/cover/test_stateful.py @@ -1320,7 +1320,106 @@ def rule1(self, data): TestLotsOfEntropyPerStepMachine = LotsOfEntropyPerStepMachine.TestCase -def test_flatmap(): +def test_filter(): + class Machine(RuleBasedStateMachine): + a = Bundle("a") + + @initialize(target=a) + def initialize(self): + return multiple(1, 2, 3) + + @rule( + a1=a.filter(lambda x: x < 2), + a2=a.filter(lambda x: x > 2), + a3=a, + ) + def fail_fast(self, a1, a2, a3): + raise AssertionError + + Machine.TestCase.settings = NO_BLOB_SETTINGS + with pytest.raises(AssertionError) as err: + run_state_machine_as_test(Machine) + + result = "\n".join(err.value.__notes__) + assert ( + result + == """ +Falsifying example: +state = Machine() +a_0, a_1, a_2 = state.initialize() +state.fail_fast(a1=a_0, a2=a_2, a3=a_2) +state.teardown() +""".strip() + ) + + +def test_consumes_filter(): + class Machine(RuleBasedStateMachine): + a = Bundle("a") + + @initialize(target=a) + def initialize(self): + return multiple(1, 2, 3) + + @rule( + a1=consumes(a).filter(lambda x: x < 2), + a2=consumes(a).filter(lambda x: x > 2), + a3=consumes(a), + ) + def fail_fast(self, a1, a2, a3): + raise AssertionError + + Machine.TestCase.settings = NO_BLOB_SETTINGS + with pytest.raises(AssertionError) as err: + run_state_machine_as_test(Machine) + + result = "\n".join(err.value.__notes__) + assert ( + result + == """ +Falsifying example: +state = Machine() +a_0, a_1, a_2 = state.initialize() +state.fail_fast(a1=a_0, a2=a_2, a3=a_1) +state.teardown() +""".strip() + ) + + +def test_consumes_map_with_filter(): + class Machine(RuleBasedStateMachine): + a = Bundle("a") + + @initialize(target=a) + def initialize(self): + return multiple(2, 4) + + @rule( + a1=a.map(lambda x: x**2).filter(lambda x: x < 3**2), + a2=consumes(a).map(lambda x: x**2).filter(lambda x: x > 3**2), + a3=consumes(a), + ) + def fail_fast(self, a1, a2, a3): + raise AssertionError + + Machine.TestCase.settings = NO_BLOB_SETTINGS + with pytest.raises(AssertionError) as err: + run_state_machine_as_test(Machine) + + result = "\n".join(err.value.__notes__) + assert ( + result + == """ +Falsifying example: +state = Machine() +a_0, a_1 = state.initialize() +state.fail_fast(a1=a_0, a2=a_1, a3=a_0) +state.teardown() +""".strip() + ) + + +def test_flatmap_with_combinations(): class Machine(RuleBasedStateMachine): buns = Bundle("buns") @@ -1333,9 +1432,133 @@ def use_flatmap(self, bun): assert isinstance(bun, int) return bun + @rule( + target=buns, + bun=buns.flatmap(lambda x: just(-x)).filter(lambda x: x < -1), + ) + def use_flatmap_filtered(self, bun): + assert isinstance(bun, int) + assert bun < -1 + return -bun + + @rule( + target=buns, + bun=buns.flatmap(lambda x: just(x + 1)).map(lambda x: -x), + ) + def use_flatmap_mapped(self, bun): + assert isinstance(bun, int) + assert bun < 0 + return -bun + + @rule(bun=buns) + def use_directly(self, bun): + assert isinstance(bun, int) + assert bun >= 0 + + Machine.TestCase.settings = Settings(stateful_step_count=5, max_examples=10) + run_state_machine_as_test(Machine) + + +def test_map_with_combinations(): + class Machine(RuleBasedStateMachine): + buns = Bundle("buns") + + @initialize(target=buns) + def create_bun(self): + return 1 + + @rule(bun=buns.map(lambda x: -x)) + def use_map_base(self, bun): + assert isinstance(bun, int) + assert bun < 0 + + @rule( + bun=buns.map(lambda x: -x).filter(lambda x: x < -1), + ) + def use_flatmap_filtered(self, bun): + assert isinstance(bun, int) + assert bun < -1 + + @rule( + bun=buns.map(lambda x: -x).flatmap(lambda x: just(abs(x) + 1)), + ) + def use_flatmap_mapped(self, bun): + assert isinstance(bun, int) + assert bun > 0 + + @rule(bun=buns) + def use_directly(self, bun): + assert isinstance(bun, int) + assert bun > 0 + + Machine.TestCase.settings = Settings(stateful_step_count=5, max_examples=10) + run_state_machine_as_test(Machine) + + +def test_filter_with_combinations(): + class Machine(RuleBasedStateMachine): + buns = Bundle("buns") + + @initialize(target=buns) + def create_bun(self): + return multiple(0, -1, -2) + + @rule(bun=buns.filter(lambda x: x > 0)) + def use_filter_base(self, bun): + assert isinstance(bun, int) + assert bun > 0 + + @rule( + bun=buns.filter(lambda x: x > 0).flatmap(lambda x: just(-x)), + ) + def use_filter_flatmapped(self, bun): + assert isinstance(bun, int) + assert bun < 0 + + @rule( + bun=buns.filter(lambda x: x < 0).map(lambda x: -x), + ) + def use_flatmap_mapped(self, bun): + assert isinstance(bun, int) + assert bun > 0 + @rule(bun=buns) def use_directly(self, bun): assert isinstance(bun, int) Machine.TestCase.settings = Settings(stateful_step_count=5, max_examples=10) run_state_machine_as_test(Machine) + + +def test_mapped_values_assigned_properly(): + class Machine(RuleBasedStateMachine): + a = Bundle("a") + + @initialize(target=a) + def initialize(self): + return multiple("ret1", "ret2") + + @rule( + a1=a, + a2=a.map(lambda x: x + x), + a3=consumes(a).map(lambda x: x + x), + a4=a, + ) + def fail_fast(self, a1, a2, a3, a4): + raise AssertionError + + Machine.TestCase.settings = NO_BLOB_SETTINGS + with pytest.raises(AssertionError) as err: + run_state_machine_as_test(Machine) + + result = "\n".join(err.value.__notes__) + assert ( + result + == """ +Falsifying example: +state = Machine() +a_0, a_1 = state.initialize() +state.fail_fast(a1=a_1, a2=a_1, a3=a_1, a4=a_0) +state.teardown() +""".strip() + )