From 1889ee043e96de1ecd0be677f135acc705fbcedd Mon Sep 17 00:00:00 2001 From: mdemello Date: Fri, 31 Jul 2020 15:10:34 -0700 Subject: [PATCH 1/4] Support aliases to unions with type parameters This currently only works within a file; in the following code: Foo = Union[T, List[T]] x: Foo[int] the generated pyi code will contain Foo = Any x: Union[int, List[int]] that is, concrete instantiations of the alias will be correctly used and exported, but the alias itself will not be exported. PiperOrigin-RevId: 324293324 --- pytype/abstract.py | 25 ++++++++++++++++++++++++- pytype/output.py | 12 ++++++++++-- pytype/pytd/visitors.py | 20 ++++++++++++++++++-- pytype/tests/py3/test_typevar.py | 28 +++++++++++++++++++++++++++- pytype/tests/test_typevar.py | 13 +++++++++---- pytype/utils.py | 11 +++++++++++ pytype/vm.py | 10 ---------- 7 files changed, 99 insertions(+), 20 deletions(-) diff --git a/pytype/abstract.py b/pytype/abstract.py index 86469e880..311171a8f 100644 --- a/pytype/abstract.py +++ b/pytype/abstract.py @@ -1423,7 +1423,7 @@ def is_empty(self): return not bool(self._member_map) -class Union(AtomicAbstractValue, mixin.NestedAnnotation): +class Union(AtomicAbstractValue, mixin.NestedAnnotation, mixin.HasSlots): """A list of types. Used for parameter matching. Attributes: @@ -1437,6 +1437,8 @@ def __init__(self, options, vm): # TODO(rechen): Don't allow a mix of formal and non-formal types self.formal = any(t.formal for t in self.options) mixin.NestedAnnotation.init_mixin(self) + mixin.HasSlots.init_mixin(self) + self.set_slot("__getitem__", self.getitem_slot) def __repr__(self): return "%s[%s]" % (self.name, ", ".join(repr(o) for o in self.options)) @@ -1455,6 +1457,27 @@ def __hash__(self): def _unique_parameters(self): return [o.to_variable(self.vm.root_cfg_node) for o in self.options] + def _get_type_params(self): + params = self.vm.annotations_util.get_type_parameters(self) + params = [x.name for x in params] + return utils.unique_list(params) + + def getitem_slot(self, node, slice_var): + """Custom __getitem__ implementation.""" + slice_content = abstract_utils.maybe_extract_tuple(slice_var) + params = self._get_type_params() + # Check that we are instantiating all the unbound type parameters + if len(params) != len(slice_content): + details = ("Union has %d type parameters but was instantiated with %d" % + (len(params), len(slice_content))) + self.vm.errorlog.invalid_annotation( + self.vm.frames, self, details=details) + return node, self.vm.new_unsolvable(node) + concrete = [x.data[0].instantiate(node) for x in slice_content] + substs = [dict(zip(params, concrete))] + new = self.vm.annotations_util.sub_one_annotation(node, self, substs) + return node, new.to_variable(node) + def instantiate(self, node, container=None): var = self.vm.program.NewVariable() for option in self.options: diff --git a/pytype/output.py b/pytype/output.py index 96a3e0af6..4d5a33bfb 100644 --- a/pytype/output.py +++ b/pytype/output.py @@ -261,8 +261,16 @@ def value_to_pytd_type(self, node, v, seen, view): log.info("Using ? for %s", v.name) return pytd.AnythingType() elif isinstance(v, abstract.Union): - return pytd.UnionType(tuple(self.value_to_pytd_type(node, o, seen, view) - for o in v.options)) + opts = [] + for o in v.options: + # NOTE: Guarding printing of type parameters behind _detailed until + # round-tripping is working properly. + if self._detailed and isinstance(o, abstract.TypeParameter): + opt = self._typeparam_to_def(node, o, o.name) + else: + opt = self.value_to_pytd_type(node, o, seen, view) + opts.append(opt) + return pytd.UnionType(tuple(opts)) elif isinstance(v, special_builtins.SuperInstance): return pytd.NamedType("__builtin__.super") elif isinstance(v, abstract.TypeParameter): diff --git a/pytype/pytd/visitors.py b/pytype/pytd/visitors.py index daa1850d5..b8336bda2 100644 --- a/pytype/pytd/visitors.py +++ b/pytype/pytd/visitors.py @@ -1274,6 +1274,7 @@ def __init__(self): self.function_name = None self.constant_name = None self.all_typeparams = set() + self.generic_level = 0 def _GetTemplateItems(self, param): """Get a list of template items from a parameter.""" @@ -1396,6 +1397,18 @@ def EnterConstant(self, node): def LeaveConstant(self, unused_node): self.constant_name = None + def EnterUnionType(self, unused_node): + self.generic_level += 1 + + def LeaveUnionType(self, unused_node): + self.generic_level -= 1 + + def EnterGenericType(self, unused_node): + self.generic_level += 1 + + def LeaveGenericType(self, unused_node): + self.generic_level -= 1 + def _GetFullName(self, name): return ".".join(n for n in [self.class_name, name] if n) @@ -1404,10 +1417,13 @@ def _GetScope(self, name): return self.class_name return self._GetFullName(self.function_name) + def _IsBoundTypeParam(self, node): + in_class = self.class_name and node.name in self.class_typeparams + return in_class or self.generic_level + def VisitTypeParameter(self, node): """Add scopes to type parameters, track unbound params.""" - if self.constant_name and (not self.class_name or - node.name not in self.class_typeparams): + if self.constant_name and not self._IsBoundTypeParam(node): raise ContainerError("Unbound type parameter %s in %s" % ( node.name, self._GetFullName(self.constant_name))) scope = self._GetScope(node.name) diff --git a/pytype/tests/py3/test_typevar.py b/pytype/tests/py3/test_typevar.py index bae029853..387970874 100644 --- a/pytype/tests/py3/test_typevar.py +++ b/pytype/tests/py3/test_typevar.py @@ -413,12 +413,38 @@ def f(x: T) -> T: """) self.assertErrorRegexes(errors, {"e": "Expected.*T.*Actual.*TypeVar"}) + def test_typevar_in_union_alias(self): + ty = self.Infer(""" + from typing import Dict, List, TypeVar, Union + T = TypeVar("T") + U = TypeVar("U") + Foo = Union[T, List[T], Dict[T, List[U]], complex] + def f(x: Foo[int, str]): ... + """) + self.assertTypesMatchPytd(ty, """ + from typing import Dict, List, TypeVar, Union, Any + T = TypeVar("T") + U = TypeVar("U") + Foo: Any + def f(x: Union[Dict[int, List[str]], List[int], complex, int]) -> None: ... + """) + + def test_typevar_in_union_alias_error(self): + err = self.CheckWithErrors(""" + from typing import Dict, List, TypeVar, Union + T = TypeVar("T") + U = TypeVar("U") + Foo = Union[T, List[T], Dict[T, List[U]], complex] + def f(x: Foo[int]): ... # invalid-annotation[e] + """) + self.assertErrorRegexes(err, {"e": "Union.*2.*instantiated.*1"}) + def test_use_unsupported_typevar(self): # Test that we don't crash when using this pattern (b/162274390) self.CheckWithErrors(""" from typing import List, TypeVar, Union T = TypeVar("T") - Tree = Union[T, List['Tree']] # not-supported-yet # not-supported-yet + Tree = Union[T, List['Tree']] # not-supported-yet def f(x: Tree[int]): ... # no error since Tree is set to Any """) diff --git a/pytype/tests/test_typevar.py b/pytype/tests/test_typevar.py index 8c691a1f3..7f52f9902 100644 --- a/pytype/tests/test_typevar.py +++ b/pytype/tests/test_typevar.py @@ -402,13 +402,18 @@ class Foo: """) def test_typevar_in_alias(self): - err = self.CheckWithErrors(""" + ty = self.Infer(""" from typing import TypeVar, Union T = TypeVar("T", int, float) - Num = Union[T, complex] # not-supported-yet[e] + Num = Union[T, complex] + x = 10 # type: Num[int] + """) + self.assertTypesMatchPytd(ty, """ + from typing import Any, TypeVar, Union + T = TypeVar("T", int, float) + Num: Any + x: Union[int, complex] = ... """) - self.assertErrorRegexes( - err, {"e": "aliases of Unions with type parameters"}) def test_recursive_alias(self): errors = self.CheckWithErrors(""" diff --git a/pytype/utils.py b/pytype/utils.py index 095ce96e0..549eedd5b 100644 --- a/pytype/utils.py +++ b/pytype/utils.py @@ -322,6 +322,17 @@ def invert_dict(d): return inverted +def unique_list(xs): + """Return a unique list from an iterable, preserving order.""" + seen = set() + out = [] + for x in xs: + if x not in seen: + seen.add(x) + out.append(x) + return out + + class DynamicVar(object): """A dynamically scoped variable. diff --git a/pytype/vm.py b/pytype/vm.py index d664ead81..b6de20245 100644 --- a/pytype/vm.py +++ b/pytype/vm.py @@ -1287,14 +1287,6 @@ def store_global(self, state, name, value): """Same as store_local except for globals.""" return self._store_value(state, name, value, local=False) - def _check_for_aliased_type_params(self, value): - for v in value.data: - if isinstance(v, abstract.Union) and v.formal: - self.errorlog.not_supported_yet( - self.frames, "aliases of Unions with type parameters") - return True - return False - def _remove_recursion(self, node, name, value): """Remove any recursion in the named value.""" if not value.data or any(not isinstance(v, mixin.NestedAnnotation) @@ -1376,8 +1368,6 @@ def _pop_and_store(self, state, op, name, local): value = self._apply_annotation( state, op, name, orig_val, annotations_dict, check_types) value = self._remove_recursion(state.node, name, value) - if self._check_for_aliased_type_params(value): - value = self.new_unsolvable(state.node) state = state.forward_cfg_node() state = self._store_value(state, name, value, local) self.trace_opcode(op, name, value) From c4c06a95b0e98b18a57887575b0b6241197f9117 Mon Sep 17 00:00:00 2001 From: mdemello Date: Mon, 3 Aug 2020 11:12:13 -0700 Subject: [PATCH 2/4] FIX: Don't crash when a nested NamedTuple subclass hits max depth. PiperOrigin-RevId: 324641138 --- pytype/overlays/typing_overlay.py | 7 +++++++ pytype/tests/py3/test_typing_namedtuple.py | 13 +++++++++++++ 2 files changed, 20 insertions(+) diff --git a/pytype/overlays/typing_overlay.py b/pytype/overlays/typing_overlay.py index d69ed6b03..fa9ba7924 100644 --- a/pytype/overlays/typing_overlay.py +++ b/pytype/overlays/typing_overlay.py @@ -512,6 +512,13 @@ def call(self, node, _, args): return self.namedtuple.call(node, None, args) def make_class(self, node, f_locals): + # If BuildClass.call() hits max depth, f_locals will be [unsolvable] + # Since we don't support defining NamedTuple subclasses in a nested scope + # anyway, we can just return unsolvable here to prevent a crash, and let the + # invalid namedtuple error get raised later. + if f_locals.data[0].isinstance_Unsolvable(): + return node, self.vm.new_unsolvable(node) + f_locals = abstract_utils.get_atomic_python_constant(f_locals) # retrieve __qualname__ to get the name of class diff --git a/pytype/tests/py3/test_typing_namedtuple.py b/pytype/tests/py3/test_typing_namedtuple.py index b45529cfd..50c7858f0 100644 --- a/pytype/tests/py3/test_typing_namedtuple.py +++ b/pytype/tests/py3/test_typing_namedtuple.py @@ -294,5 +294,18 @@ class Foo(NamedTuple): """) self.assertErrorRegexes(errors, {"e": r"Annotation: str.*Assignment: int"}) + def test_nested_namedtuple(self): + # Guard against a crash when hitting max depth (b/162619036) + self.assertNoCrash(self.Check, """ + from typing import NamedTuple + + def foo() -> None: + class A(NamedTuple): + x: int + + def bar(): + foo() + """) + test_base.main(globals(), __name__ == "__main__") From d0b736a8dafd6e45aa36bba0d104e17f5a106dcd Mon Sep 17 00:00:00 2001 From: rechen Date: Mon, 3 Aug 2020 14:33:13 -0700 Subject: [PATCH 3/4] Fix a weird bug involving calling __getitem__ on unions. The CL to support Union type macros caused a bizarre bug where Union.set_slot('__getitem__') calls getattribute(Union, '__getitem__'), creating an extra node, causing convert_class_annotations for a tuple annotation containing a forward referenced imported type to be evaluated at the wrong node, leading to the generation of an extra late annotation for the imported module, finally producing "Invalid type annotation '': Not a type" errors all over the place for pb2 enum annotations, since those always have to be quoted. I resolved this by special-casing Union.__getitem__ in attribute.py, so the extra node is not generated. I also noticed that the type macro CL produced a corner case where Union.getitem_slot is indistinguishable from calling __getitem__ on a Union's options when a value is annotated as Type[Union[C1, C2, ...]] (but *not* when the annotation is Union[Type[C1], Type[C2], ...]!). The inferred type is still correct, if imprecise, so I just documented this behavior in a test rather than trying to fix it. PiperOrigin-RevId: 324684760 --- pytype/attribute.py | 3 +++ pytype/tests/py3/test_annotations.py | 12 ++++++++++ pytype/tests/py3/test_typing.py | 35 ++++++++++++++++++++++++++++ 3 files changed, 50 insertions(+) diff --git a/pytype/attribute.py b/pytype/attribute.py index f171e277a..7a06b0fe0 100644 --- a/pytype/attribute.py +++ b/pytype/attribute.py @@ -58,6 +58,9 @@ def get_attribute(self, node, obj, name, valself=None): elif isinstance(obj, abstract.SimpleAbstractValue): return self._get_instance_attribute(node, obj, name, valself) elif isinstance(obj, abstract.Union): + if name == "__getitem__": + # __getitem__ is implemented in abstract.Union.getitem_slot. + return node, self.vm.new_unsolvable(node) nodes = [] ret = self.vm.program.NewVariable() for o in obj.options: diff --git a/pytype/tests/py3/test_annotations.py b/pytype/tests/py3/test_annotations.py index 8597f9266..16953bd19 100644 --- a/pytype/tests/py3/test_annotations.py +++ b/pytype/tests/py3/test_annotations.py @@ -1081,6 +1081,18 @@ class Foo: X = List['int'] """) + def test_nested_forward_ref_to_import(self): + with file_utils.Tempdir() as d: + d.create_file("foo.pyi", """ + class Foo: ... + """) + self.Check(""" + import foo + from typing import Tuple + def f(x: Tuple[str, 'foo.Foo']): + pass + """, pythonpath=[d.path]) + class TestAnnotationsPython3Feature(test_base.TargetPython3FeatureTest): """Tests for PEP 484 style inline annotations.""" diff --git a/pytype/tests/py3/test_typing.py b/pytype/tests/py3/test_typing.py index c0ea8993d..b7e8cba26 100644 --- a/pytype/tests/py3/test_typing.py +++ b/pytype/tests/py3/test_typing.py @@ -560,6 +560,41 @@ def f(x: MutableSet) -> MutableSet: return x - {0} """) + def test_union_of_classes(self): + ty = self.Infer(""" + from typing import Type, Union + + class Foo(object): + def __getitem__(self, x) -> int: + return 0 + class Bar(object): + def __getitem__(self, x) -> str: + return '' + + def f(x: Union[Type[Foo], Type[Bar]]): + return x.__getitem__ + def g(x: Type[Union[Foo, Bar]]): + return x.__getitem__ + """) + # The inferred return type of `g` is technically incorrect: it is inferred + # from the type of abstract.Union.getitem_slot, which is a NativeFunction, + # so its type defaults to a plain Callable. We should instead look up + # Foo.__getitem__ and Bar.__getitem__ as we do for `f`, but it is currently + # not possible to distinguish between using Union.getitem_slot and accessing + # the actual __getitem__ method on a union's options. Inferring `Callable` + # should generally be safe, since __getitem__ is a method by convention. + self.assertTypesMatchPytd(ty, """ + from typing import Any, Callable, Type, Union + + class Foo: + def __getitem__(self, x) -> int: ... + class Bar: + def __getitem__(self, x) -> str: ... + + def f(x: Type[Union[Foo, Bar]]) -> Callable[[Any, Any], Union[int, str]]: ... + def g(x: Type[Union[Foo, Bar]]) -> Callable: ... + """) + class CounterTest(test_base.TargetPython3BasicTest): """Tests for typing.Counter.""" From 6401b36ba24a8736010a31f444ab755d110d71a0 Mon Sep 17 00:00:00 2001 From: rechen Date: Tue, 4 Aug 2020 14:32:56 -0700 Subject: [PATCH 4/4] Entering a CallableType or TupleType should increase the generic level. PiperOrigin-RevId: 324891454 --- pytype/pytd/visitors.py | 24 ++++++++++++++++++------ pytype/tests/test_typevar.py | 20 ++++++++++++++++++++ 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/pytype/pytd/visitors.py b/pytype/pytd/visitors.py index b8336bda2..4b41dc811 100644 --- a/pytype/pytd/visitors.py +++ b/pytype/pytd/visitors.py @@ -1397,18 +1397,30 @@ def EnterConstant(self, node): def LeaveConstant(self, unused_node): self.constant_name = None - def EnterUnionType(self, unused_node): - self.generic_level += 1 - - def LeaveUnionType(self, unused_node): - self.generic_level -= 1 - def EnterGenericType(self, unused_node): self.generic_level += 1 def LeaveGenericType(self, unused_node): self.generic_level -= 1 + def EnterCallableType(self, node): + self.EnterGenericType(node) + + def LeaveCallableType(self, node): + self.LeaveGenericType(node) + + def EnterTupleType(self, node): + self.EnterGenericType(node) + + def LeaveTupleType(self, node): + self.LeaveGenericType(node) + + def EnterUnionType(self, node): + self.EnterGenericType(node) + + def LeaveUnionType(self, node): + self.LeaveGenericType(node) + def _GetFullName(self, name): return ".".join(n for n in [self.class_name, name] if n) diff --git a/pytype/tests/test_typevar.py b/pytype/tests/test_typevar.py index 7f52f9902..35c158ac7 100644 --- a/pytype/tests/test_typevar.py +++ b/pytype/tests/test_typevar.py @@ -444,5 +444,25 @@ def g(x): # type: (Sequence[T]) -> Type[Sequence[T]] """) self.assertErrorRegexes(errors, {"e": "Expected.*int.*Actual.*Sequence"}) + def test_typevar_in_constant(self): + ty = self.Infer(""" + from typing import TypeVar + T = TypeVar('T') + class Foo(object): + def __init__(self): + self.f1 = self.f2 + def f2(self, x): + # type: (T) -> T + return x + """) + self.assertTypesMatchPytd(ty, """ + from typing import Callable, TypeVar + T = TypeVar('T') + class Foo: + f1: Callable[[T], T] + def __init__(self) -> None: ... + def f2(self, x: T) -> T: ... + """) + test_base.main(globals(), __name__ == "__main__")