Skip to content

Commit

Permalink
Merge pull request #4287 from tybug/filter-typing
Browse files Browse the repository at this point in the history
More correct type hints for filter and map
  • Loading branch information
tybug authored Mar 5, 2025
2 parents 0ce28fa + 8d59b93 commit 49986c1
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 27 deletions.
8 changes: 8 additions & 0 deletions hypothesis-python/RELEASE.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
RELEASE_TYPE: patch

Fix a type-hinting regression from :ref:`version 6.125.1 <v6.125.1>`, where we would no longer guarantee the type of the argument to ``.filter`` predicates (:issue:`4269`).

.. code-block:: python
# x was previously Unknown, but is now correctly guaranteed to be int
st.integers().filter(lambda x: x > 0)
50 changes: 25 additions & 25 deletions hypothesis-python/src/hypothesis/strategies/_internal/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,17 +66,6 @@
MappedFrom = TypeVar("MappedFrom")
MappedTo = TypeVar("MappedTo")
RecurT: "TypeAlias" = Callable[["SearchStrategy"], Any]
# These PackT and PredicateT aliases can only be used when you don't want to
# specify a relationship between the generic Ts and some other function param
# / return value. If you do - like the actual map definition in SearchStrategy -
# you'll need to write Callable[[Ex], T] (replacing Ex/T as appropriate) instead.
# TypeAlias is *not* simply a macro that inserts the text. it has different semantics.
PackT: "TypeAlias" = Callable[[T], T3]
PredicateT: "TypeAlias" = Callable[[T], object]
TransformationsT: "TypeAlias" = tuple[
Union[tuple[Literal["filter"], PredicateT], tuple[Literal["map"], PackT]], ...
]

calculating = UniqueIdentifier("calculating")

MAPPED_SEARCH_STRATEGY_DO_DRAW_LABEL = calc_label_from_name(
Expand Down Expand Up @@ -390,7 +379,15 @@ def flatmap(

return FlatMapStrategy(expand=expand, strategy=self)

def filter(self, condition: PredicateT) -> "SearchStrategy[Ex]":
# Note that we previously had condition extracted to a type alias as
# PredicateT. However, that was only useful when not specifying a relationship
# between the generic Ts and some other function param / return value.
# If we do want to - like here, where we want to say that the Ex arg to condition
# is of the same type as the strategy's Ex - then you need to write out the
# entire Callable[[Ex], Any] expression rather than use a type alias.
# TypeAlias is *not* simply a macro that inserts the text. TypeAlias will not
# reference the local TypeVar context.
def filter(self, condition: Callable[[Ex], Any]) -> "SearchStrategy[Ex]":
"""Returns a new strategy that generates values from this strategy
which satisfy the provided condition. Note that if the condition is too
hard to satisfy this might result in your tests failing with
Expand All @@ -400,7 +397,9 @@ def filter(self, condition: PredicateT) -> "SearchStrategy[Ex]":
"""
return FilteredStrategy(conditions=(condition,), strategy=self)

def _filter_for_filtered_draw(self, condition: PredicateT) -> "SearchStrategy[Ex]":
def _filter_for_filtered_draw(
self, condition: Callable[[Ex], Any]
) -> "SearchStrategy[Ex]":
# Hook for parent strategies that want to perform fallible filtering
# on one of their internal strategies (e.g. UniqueListStrategy).
# The returned object must have a `.do_filtered_draw(data)` method
Expand Down Expand Up @@ -502,7 +501,10 @@ def __init__(
self,
elements: Sequence[Ex],
repr_: Optional[str] = None,
transformations: TransformationsT = (),
transformations: tuple[
tuple[Literal["filter", "map"], Callable[[Ex], Any]],
...,
] = (),
):
super().__init__()
self.elements = cu.check_sample(elements, "sampled_from")
Expand All @@ -519,7 +521,7 @@ def map(self, pack: Callable[[Ex], T]) -> SearchStrategy[T]:
# guaranteed by the ("map", pack) transformation
return cast(SearchStrategy[T], s)

def filter(self, condition: PredicateT) -> SearchStrategy[Ex]:
def filter(self, condition: Callable[[Ex], Any]) -> SearchStrategy[Ex]:
return type(self)(
self.elements,
repr_=self.repr_,
Expand Down Expand Up @@ -557,14 +559,12 @@ def _transform(
# Used in UniqueSampledListStrategy
for name, f in self._transformations:
if name == "map":
f = cast(PackT, f)
result = f(element)
if build_context := _current_build_context.value:
build_context.record_call(result, f, [element], {})
element = result
else:
assert name == "filter"
f = cast(PredicateT, f)
if not f(element):
return filter_not_satisfied
return element
Expand Down Expand Up @@ -731,7 +731,7 @@ def branches(self) -> Sequence[SearchStrategy[Ex]]:
else:
return [self]

def filter(self, condition: PredicateT) -> SearchStrategy[Ex]:
def filter(self, condition: Callable[[Ex], Any]) -> SearchStrategy[Ex]:
return FilteredStrategy(
OneOfStrategy([s.filter(condition) for s in self.original_strategies]),
conditions=(),
Expand Down Expand Up @@ -960,12 +960,12 @@ def _collection_ish_functions() -> Sequence[Any]:

class FilteredStrategy(SearchStrategy[Ex]):
def __init__(
self, strategy: SearchStrategy[Ex], conditions: tuple[PredicateT, ...]
self, strategy: SearchStrategy[Ex], conditions: tuple[Callable[[Ex], Any], ...]
):
super().__init__()
if isinstance(strategy, FilteredStrategy):
# Flatten chained filters into a single filter with multiple conditions.
self.flat_conditions: tuple[PredicateT, ...] = (
self.flat_conditions: tuple[Callable[[Ex], Any], ...] = (
strategy.flat_conditions + conditions
)
self.filtered_strategy: SearchStrategy[Ex] = strategy.filtered_strategy
Expand All @@ -976,7 +976,7 @@ def __init__(
assert isinstance(self.flat_conditions, tuple)
assert not isinstance(self.filtered_strategy, FilteredStrategy)

self.__condition: Optional[PredicateT] = None
self.__condition: Optional[Callable[[Ex], Any]] = None

def calc_is_empty(self, recur: RecurT) -> Any:
return recur(self.filtered_strategy)
Expand Down Expand Up @@ -1017,7 +1017,7 @@ def do_validate(self) -> None:
# an in-place method so we still just re-initialize the strategy!
FilteredStrategy.__init__(self, fresh, ())

def filter(self, condition: PredicateT) -> "FilteredStrategy[Ex]":
def filter(self, condition: Callable[[Ex], Any]) -> "FilteredStrategy[Ex]":
# If we can, it's more efficient to rewrite our strategy to satisfy the
# condition. We therefore exploit the fact that the order of predicates
# doesn't matter (`f(x) and g(x) == g(x) and f(x)`) by attempting to apply
Expand All @@ -1033,16 +1033,16 @@ def filter(self, condition: PredicateT) -> "FilteredStrategy[Ex]":
return FilteredStrategy(out, self.flat_conditions)

@property
def condition(self) -> PredicateT:
def condition(self) -> Callable[[Ex], Any]:
if self.__condition is None:
if len(self.flat_conditions) == 1:
# Avoid an extra indirection in the common case of only one condition.
self.__condition = self.flat_conditions[0]
elif len(self.flat_conditions) == 0:
# Possible, if unlikely, due to filter predicate rewriting
self.__condition = lambda _: True
self.__condition = lambda _: True # type: ignore # covariant type param
else:
self.__condition = lambda x: all(
self.__condition = lambda x: all( # type: ignore # covariant type param
cond(x) for cond in self.flat_conditions
)
return self.__condition
Expand Down
4 changes: 4 additions & 0 deletions whole_repo_tests/revealed_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@
"tuples(text(), text(), text(), text(), text(), text())",
"tuple[Any, ...]",
),
("lists(none())", "list[None]"),
("integers().filter(lambda x: x > 0)", "int"),
("booleans().filter(lambda x: x)", "bool"),
("integers().map(bool).filter(lambda x: x)", "bool"),
]

NUMPY_REVEALED_TYPES = [
Expand Down
1 change: 0 additions & 1 deletion whole_repo_tests/test_mypy.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ def convert_lines():
"val,expect",
[
*REVEALED_TYPES, # shared with Pyright
("lists(none())", "list[None]"),
("data()", "hypothesis.strategies._internal.core.DataObject"),
("none() | integers()", "Union[None, int]"),
("recursive(integers(), lists)", "Union[list[Any], int]"),
Expand Down
1 change: 0 additions & 1 deletion whole_repo_tests/test_pyright.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,6 @@ def test_numpy_arrays_strategy(tmp_path: Path):
"val,expect",
[
*REVEALED_TYPES, # shared with Mypy
("lists(none())", "list[None]"),
("dictionaries(integers(), datetimes())", "dict[int, datetime]"),
("data()", "DataObject"),
("none() | integers()", "int | None"),
Expand Down

0 comments on commit 49986c1

Please sign in to comment.