diff --git a/ibis/common/deferred.py b/ibis/common/deferred.py index 5ed0af7f03fc..593cb3a09782 100644 --- a/ibis/common/deferred.py +++ b/ibis/common/deferred.py @@ -10,7 +10,7 @@ from ibis.common.bases import Final, FrozenSlotted, Hashable, Immutable, Slotted from ibis.common.collections import FrozenDict from ibis.common.typing import Coercible, CoercionError -from ibis.util import PseudoHashable, is_iterable +from ibis.util import PseudoHashable class Resolver(Coercible, Hashable): @@ -519,9 +519,12 @@ def resolver(obj): elif isinstance(obj, collections.abc.Mapping): # allow nesting deferred patterns in dicts return Mapping(obj) - elif is_iterable(obj): + elif isinstance(obj, collections.abc.Sequence): # allow nesting deferred patterns in tuples/lists - return Sequence(obj) + if isinstance(obj, (str, bytes)): + return Just(obj) + else: + return Sequence(obj) elif isinstance(obj, type): return Just(obj) elif callable(obj): diff --git a/ibis/common/patterns.py b/ibis/common/patterns.py index 180e8104e8ee..566b78ee8abe 100644 --- a/ibis/common/patterns.py +++ b/ibis/common/patterns.py @@ -1426,15 +1426,18 @@ class PatternList(Slotted, Pattern): @classmethod def __create__(cls, patterns, type=list): + if patterns == (): + return EqualTo(patterns) + patterns = tuple(map(pattern, patterns)) for pat in patterns: pat = _maybe_unwrap_capture(pat) if isinstance(pat, (SomeOf, SomeChunksOf)): return VariadicPatternList(patterns, type) + return super().__create__(patterns, type) def __init__(self, patterns, type): - patterns = tuple(map(pattern, patterns)) super().__init__(patterns=patterns, type=type) def describe(self, plural=False): @@ -1584,12 +1587,15 @@ def pattern(obj: AnyType) -> Pattern: return Capture(obj) elif isinstance(obj, Mapping): raise TypeError("Cannot create a pattern from a mapping") + elif isinstance(obj, Sequence): + if isinstance(obj, (str, bytes)): + return EqualTo(obj) + else: + return PatternList(obj, type=type(obj)) elif isinstance(obj, type): return InstanceOf(obj) elif get_origin(obj): return Pattern.from_typehint(obj, allow_coercion=False) - elif is_iterable(obj): - return PatternList(obj) elif callable(obj): return Custom(obj) else: diff --git a/ibis/common/tests/test_patterns.py b/ibis/common/tests/test_patterns.py index ad9d78a4077d..8a80f1412191 100644 --- a/ibis/common/tests/test_patterns.py +++ b/ibis/common/tests/test_patterns.py @@ -842,6 +842,11 @@ def test_matching_sequence_pattern(): assert match([Some(...), 2, 3, 4, Some(...)], list(range(8))) == list(range(8)) +def test_matching_sequence_pattern_keeps_original_type(): + assert match([1, 2, 3, 4, Some(...)], tuple(range(1, 9))) == list(range(1, 9)) + assert match((1, 2, 3, Some(...)), [1, 2, 3, 4, 5]) == (1, 2, 3, 4, 5) + + def test_matching_sequence_with_captures(): v = list(range(1, 9)) assert match([1, 2, 3, 4, Some(...)], v) == v