Skip to content

Commit

Permalink
Merge pull request #4138 from tybug/integer-weights-simple
Browse files Browse the repository at this point in the history
Fold integer endpoint upweighting into `weights=`
  • Loading branch information
tybug authored Oct 14, 2024
2 parents 760373e + 84dbaee commit 7145c74
Show file tree
Hide file tree
Showing 12 changed files with 146 additions and 147 deletions.
5 changes: 5 additions & 0 deletions hypothesis-python/RELEASE.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
RELEASE_TYPE: patch

This release improves integer shrinking by folding the endpoint upweighting for :func:`~hypothesis.strategies.integers` into the ``weights`` parameter of our IR (:issue:`3921`).

If you maintain an alternative backend as part of our (for now explicitly unstable) :ref:`alternative-backends`, this release changes the type of the ``weights`` parameter to ``draw_integer`` and may be a breaking change for you.
73 changes: 37 additions & 36 deletions hypothesis-python/src/hypothesis/internal/conjecture/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def wrapper(tp):
class IntegerKWargs(TypedDict):
min_value: Optional[int]
max_value: Optional[int]
weights: Optional[Sequence[float]]
weights: Optional[dict[int, float]]
shrink_towards: int


Expand Down Expand Up @@ -1287,7 +1287,7 @@ def draw_integer(
max_value: Optional[int] = None,
*,
# weights are for choosing an element index from a bounded range
weights: Optional[Sequence[float]] = None,
weights: Optional[dict[int, float]] = None,
shrink_towards: int = 0,
forced: Optional[int] = None,
fake_forced: bool = False,
Expand Down Expand Up @@ -1456,8 +1456,7 @@ def draw_integer(
min_value: Optional[int] = None,
max_value: Optional[int] = None,
*,
# weights are for choosing an element index from a bounded range
weights: Optional[Sequence[float]] = None,
weights: Optional[dict[int, float]] = None,
shrink_towards: int = 0,
forced: Optional[int] = None,
fake_forced: bool = False,
Expand All @@ -1475,22 +1474,31 @@ def draw_integer(
assert min_value is not None
assert max_value is not None

sampler = Sampler(weights, observe=False)
gap = max_value - shrink_towards

forced_idx = None
if forced is not None:
if forced >= shrink_towards:
forced_idx = forced - shrink_towards
else:
forced_idx = shrink_towards + gap - forced
idx = sampler.sample(self._cd, forced=forced_idx, fake_forced=fake_forced)
# format of weights is a mapping of ints to p, where sum(p) < 1.
# The remaining probability mass is uniformly distributed over
# *all* ints (not just the unmapped ones; this is somewhat undesirable,
# but simplifies things).
#
# We assert that sum(p) is strictly less than 1 because it simplifies
# handling forced values when we can force into the unmapped probability
# mass. We should eventually remove this restriction.
sampler = Sampler(
[1 - sum(weights.values()), *weights.values()], observe=False
)
# if we're forcing, it's easiest to force into the unmapped probability
# mass and then force the drawn value after.
idx = sampler.sample(
self._cd, forced=None if forced is None else 0, fake_forced=fake_forced
)

# For range -2..2, interpret idx = 0..4 as [0, 1, 2, -1, -2]
if idx <= gap:
return shrink_towards + idx
else:
return shrink_towards - (idx - gap)
return self._draw_bounded_integer(
min_value,
max_value,
# implicit reliance on dicts preserving insertion order for determinism
forced=forced if idx == 0 else list(weights)[idx - 1],
center=shrink_towards,
fake_forced=fake_forced,
)

if min_value is None and max_value is None:
return self._draw_unbounded_integer(forced=forced, fake_forced=fake_forced)
Expand Down Expand Up @@ -2116,8 +2124,7 @@ def draw_integer(
min_value: Optional[int] = None,
max_value: Optional[int] = None,
*,
# weights are for choosing an element index from a bounded range
weights: Optional[Sequence[float]] = None,
weights: Optional[dict[int, float]] = None,
shrink_towards: int = 0,
forced: Optional[int] = None,
fake_forced: bool = False,
Expand All @@ -2127,9 +2134,14 @@ def draw_integer(
if weights is not None:
assert min_value is not None
assert max_value is not None
width = max_value - min_value + 1
assert width <= 255 # arbitrary practical limit
assert len(weights) == width
assert len(weights) <= 255 # arbitrary practical limit
# We can and should eventually support total weights. But this
# complicates shrinking as we can no longer assume we can force
# a value to the unmapped probability mass if that mass might be 0.
assert sum(weights.values()) < 1
# similarly, things get simpler if we assume every value is possible.
# we'll want to drop this restriction eventually.
assert all(w != 0 for w in weights.values())

if forced is not None and (min_value is None or max_value is None):
# We draw `forced=forced - shrink_towards` here internally, after clamping.
Expand Down Expand Up @@ -2365,18 +2377,7 @@ def _pooled_kwargs(self, ir_type, kwargs):
if self.provider.avoid_realization:
return kwargs

key = []
for k, v in kwargs.items():
if ir_type == "float" and k in ["min_value", "max_value"]:
# handle -0.0 vs 0.0, etc.
v = float_to_int(v)
elif ir_type == "integer" and k == "weights":
# make hashable
v = v if v is None else tuple(v)
key.append((k, v))

key = (ir_type, *sorted(key))

key = (ir_type, *ir_kwargs_key(ir_type, kwargs))
try:
return POOLED_KWARGS_CACHE[key]
except KeyError:
Expand Down
61 changes: 25 additions & 36 deletions hypothesis-python/src/hypothesis/internal/conjecture/shrinker.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
ConjectureData,
ConjectureResult,
Status,
bits_to_bytes,
ir_value_equal,
ir_value_key,
ir_value_permitted,
Expand Down Expand Up @@ -681,7 +680,7 @@ def greedy_shrink(self):
"reorder_examples",
"minimize_duplicated_nodes",
"minimize_individual_nodes",
"redistribute_block_pairs",
"redistribute_integer_pairs",
"lower_blocks_together",
]
)
Expand Down Expand Up @@ -1227,42 +1226,32 @@ def minimize_duplicated_nodes(self, chooser):
self.minimize_nodes(nodes)

@defines_shrink_pass()
def redistribute_block_pairs(self, chooser):
def redistribute_integer_pairs(self, chooser):
"""If there is a sum of generated integers that we need their sum
to exceed some bound, lowering one of them requires raising the
other. This pass enables that."""
# TODO_SHRINK let's extend this to floats as well.

node = chooser.choose(
# look for a pair of nodes (node1, node2) which are both integers and
# aren't separated by too many other nodes. We'll decrease node1 and
# increase node2 (note that the other way around doesn't make sense as
# it's strictly worse in the ordering).
node1 = chooser.choose(
self.nodes, lambda node: node.ir_type == "integer" and not node.trivial
)
node2 = chooser.choose(
self.nodes,
lambda node: node.ir_type == "integer"
# Note that it's fine for node2 to be trivial, because we're going to
# explicitly make it *not* trivial by adding to its value.
and not node.was_forced
# to avoid quadratic behavior, scan ahead only a small amount for
# the related node.
and node1.index < node.index <= node1.index + 4,
)

# The preconditions for this pass are that the two integer draws are only
# separated by non-integer nodes, and have the same size value in bytes.
#
# This isn't particularly principled. For instance, this wouldn't reduce
# e.g. @given(integers(), integers(), integers()) where the sum property
# involves the first and last integers.
#
# A better approach may be choosing *two* such integer nodes arbitrarily
# from the list, instead of conditionally scanning forward.

for j in range(node.index + 1, len(self.nodes)):
next_node = self.nodes[j]
if next_node.ir_type == "integer" and bits_to_bytes(
node.value.bit_length()
) == bits_to_bytes(next_node.value.bit_length()):
break
else:
return

if next_node.was_forced:
# avoid modifying a forced node. Note that it's fine for next_node
# to be trivial, because we're going to explicitly make it *not*
# trivial by adding to its value.
return

m = node.value
n = next_node.value
m = node1.value
n = node2.value

def boost(k):
if k > m:
Expand All @@ -1272,11 +1261,11 @@ def boost(k):
next_node_value = n + k

return self.consider_new_tree(
self.nodes[: node.index]
+ [node.copy(with_value=node_value)]
+ self.nodes[node.index + 1 : next_node.index]
+ [next_node.copy(with_value=next_node_value)]
+ self.nodes[next_node.index + 1 :]
self.nodes[: node1.index]
+ [node1.copy(with_value=node_value)]
+ self.nodes[node1.index + 1 : node2.index]
+ [node2.copy(with_value=next_node_value)]
+ self.nodes[node2.index + 1 :]
)

find_integer(boost)
Expand Down
19 changes: 8 additions & 11 deletions hypothesis-python/src/hypothesis/strategies/_internal/numbers.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,24 +66,21 @@ def __repr__(self):

def do_draw(self, data):
    """Draw an integer in [self.start, self.end] (either bound may be None).

    For bounded ranges wider than 128 values, upweight the endpoints and
    near-endpoints via the ``weights`` parameter so that boundary values
    (classic off-by-one bug territory) are generated more often than a
    uniform draw would produce them.
    """
    weights = None
    if (
        self.end is not None
        and self.start is not None
        and self.end - self.start > 127
    ):
        # 2/128 mass on each exact endpoint, 1/128 on each neighbour;
        # the remaining ~122/128 is distributed by draw_integer itself.
        weights = {
            self.start: (2 / 128),
            self.start + 1: (1 / 128),
            self.end - 1: (1 / 128),
            self.end: (2 / 128),
        }

    return data.draw_integer(
        min_value=self.start, max_value=self.end, weights=weights
    )

def filter(self, condition):
Expand Down
49 changes: 15 additions & 34 deletions hypothesis-python/tests/conjecture/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,9 @@ def integer_kwargs(
draw(st.booleans()) if (use_min_value and use_max_value) else False
)

# this generation is complicated to deal with maintaining any combination of
# the following invariants, depending on which parameters are passed:
#
# Invariants:
# (1) min_value <= forced <= max_value
# (2) max_value - min_value + 1 == len(weights)
# (2) sum(weights.values()) < 1
# (3) len(weights) <= 255

if use_shrink_towards:
Expand All @@ -158,39 +156,22 @@ def integer_kwargs(
if use_weights:
assert use_max_value
assert use_min_value
# handle the weights case entirely independently from the non-weights case.
# We'll treat the weights as our "key" draw and base all other draws on that.

# weights doesn't play well with super small floats, so exclude <.01
min_value = draw(st.integers(max_value=forced))
min_val = max(min_value, forced) if forced is not None else min_value
max_value = draw(st.integers(min_value=min_val))

# Sampler doesn't play well with super small floats, so exclude them
weights = draw(
st.lists(st.just(0) | st.floats(0.01, 1), min_size=1, max_size=255)
st.dictionaries(st.integers(), st.floats(0.001, 1), max_size=255)
)
# zero is allowed, but it can't be all zeroes
assume(sum(weights) > 0)

# we additionally pick a central value (if not forced), and then the index
# into the weights at which it can be found - aka the min-value offset.
center = forced if use_forced else draw(st.integers())
min_value = center - draw(st.integers(0, len(weights) - 1))
max_value = min_value + len(weights) - 1

if use_forced:
# can't force a 0-weight index.
# we avoid clamping the returned shrink_towards to maximize
# bug-finding power.
_shrink_towards = clamped_shrink_towards(
{
"shrink_towards": shrink_towards,
"min_value": min_value,
"max_value": max_value,
}
)
forced_idx = (
forced - _shrink_towards
if forced >= _shrink_towards
else max_value - forced
)
assume(weights[forced_idx] > 0)
# invalid to have a weighting that disallows all possibilities
assume(sum(weights.values()) != 0)
target = draw(st.floats(0.001, 0.999))
# re-normalize probabilities to sum to some arbitrary value < 1
weights = {k: v / target for k, v in weights.items()}
# float rounding error can cause this to fail.
assume(sum(weights.values()) == target)
else:
if use_min_value:
min_value = draw(st.integers(max_value=forced))
Expand Down
4 changes: 1 addition & 3 deletions hypothesis-python/tests/conjecture/test_alt_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import math
import sys
from collections.abc import Sequence
from contextlib import contextmanager
from random import Random
from typing import Optional
Expand Down Expand Up @@ -67,8 +66,7 @@ def draw_integer(
min_value: Optional[int] = None,
max_value: Optional[int] = None,
*,
# weights are for choosing an element index from a bounded range
weights: Optional[Sequence[float]] = None,
weights: Optional[dict[int, float]] = None,
shrink_towards: int = 0,
forced: Optional[int] = None,
fake_forced: bool = False,
Expand Down
28 changes: 26 additions & 2 deletions hypothesis-python/tests/conjecture/test_forced.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def test_forced_many(data):
"min_value": -1,
"max_value": 1,
"shrink_towards": 1,
"weights": [0.1] * 3,
"weights": {-1: 0.2, 0: 0.2, 1: 0.2},
"forced": 0,
},
)
Expand All @@ -80,11 +80,35 @@ def test_forced_many(data):
"min_value": -1,
"max_value": 1,
"shrink_towards": -1,
"weights": [0.1] * 3,
"weights": {-1: 0.2, 0: 0.2, 1: 0.2},
"forced": 0,
},
)
)
@example(
(
"integer",
{
"min_value": 10,
"max_value": 1_000,
"shrink_towards": 17,
"weights": {20: 0.1},
"forced": 15,
},
)
)
@example(
(
"integer",
{
"min_value": -1_000,
"max_value": -10,
"shrink_towards": -17,
"weights": {-20: 0.1},
"forced": -15,
},
)
)
@example(("float", {"forced": 0.0}))
@example(("float", {"forced": -0.0}))
@example(("float", {"forced": 1.0}))
Expand Down
Loading

0 comments on commit 7145c74

Please sign in to comment.