Skip to content

Commit

Permalink
Merge pull request #639 from google/google_sync
Browse files Browse the repository at this point in the history
Google sync
  • Loading branch information
rchen152 authored Aug 4, 2020
2 parents e2d73f1 + 6401b36 commit 2d8c896
Show file tree
Hide file tree
Showing 12 changed files with 201 additions and 20 deletions.
25 changes: 24 additions & 1 deletion pytype/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -1423,7 +1423,7 @@ def is_empty(self):
return not bool(self._member_map)


class Union(AtomicAbstractValue, mixin.NestedAnnotation):
class Union(AtomicAbstractValue, mixin.NestedAnnotation, mixin.HasSlots):
"""A list of types. Used for parameter matching.
Attributes:
Expand All @@ -1437,6 +1437,8 @@ def __init__(self, options, vm):
# TODO(rechen): Don't allow a mix of formal and non-formal types
self.formal = any(t.formal for t in self.options)
mixin.NestedAnnotation.init_mixin(self)
mixin.HasSlots.init_mixin(self)
self.set_slot("__getitem__", self.getitem_slot)

def __repr__(self):
return "%s[%s]" % (self.name, ", ".join(repr(o) for o in self.options))
Expand All @@ -1455,6 +1457,27 @@ def __hash__(self):
def _unique_parameters(self):
return [o.to_variable(self.vm.root_cfg_node) for o in self.options]

def _get_type_params(self):
params = self.vm.annotations_util.get_type_parameters(self)
params = [x.name for x in params]
return utils.unique_list(params)

def getitem_slot(self, node, slice_var):
"""Custom __getitem__ implementation."""
slice_content = abstract_utils.maybe_extract_tuple(slice_var)
params = self._get_type_params()
# Check that we are instantiating all the unbound type parameters
if len(params) != len(slice_content):
details = ("Union has %d type parameters but was instantiated with %d" %
(len(params), len(slice_content)))
self.vm.errorlog.invalid_annotation(
self.vm.frames, self, details=details)
return node, self.vm.new_unsolvable(node)
concrete = [x.data[0].instantiate(node) for x in slice_content]
substs = [dict(zip(params, concrete))]
new = self.vm.annotations_util.sub_one_annotation(node, self, substs)
return node, new.to_variable(node)

def instantiate(self, node, container=None):
var = self.vm.program.NewVariable()
for option in self.options:
Expand Down
3 changes: 3 additions & 0 deletions pytype/attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ def get_attribute(self, node, obj, name, valself=None):
elif isinstance(obj, abstract.SimpleAbstractValue):
return self._get_instance_attribute(node, obj, name, valself)
elif isinstance(obj, abstract.Union):
if name == "__getitem__":
# __getitem__ is implemented in abstract.Union.getitem_slot.
return node, self.vm.new_unsolvable(node)
nodes = []
ret = self.vm.program.NewVariable()
for o in obj.options:
Expand Down
12 changes: 10 additions & 2 deletions pytype/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,16 @@ def value_to_pytd_type(self, node, v, seen, view):
log.info("Using ? for %s", v.name)
return pytd.AnythingType()
elif isinstance(v, abstract.Union):
return pytd.UnionType(tuple(self.value_to_pytd_type(node, o, seen, view)
for o in v.options))
opts = []
for o in v.options:
# NOTE: Guarding printing of type parameters behind _detailed until
# round-tripping is working properly.
if self._detailed and isinstance(o, abstract.TypeParameter):
opt = self._typeparam_to_def(node, o, o.name)
else:
opt = self.value_to_pytd_type(node, o, seen, view)
opts.append(opt)
return pytd.UnionType(tuple(opts))
elif isinstance(v, special_builtins.SuperInstance):
return pytd.NamedType("__builtin__.super")
elif isinstance(v, abstract.TypeParameter):
Expand Down
7 changes: 7 additions & 0 deletions pytype/overlays/typing_overlay.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,13 @@ def call(self, node, _, args):
return self.namedtuple.call(node, None, args)

def make_class(self, node, f_locals):
# If BuildClass.call() hits max depth, f_locals will be [unsolvable]
# Since we don't support defining NamedTuple subclasses in a nested scope
# anyway, we can just return unsolvable here to prevent a crash, and let the
# invalid namedtuple error get raised later.
if f_locals.data[0].isinstance_Unsolvable():
return node, self.vm.new_unsolvable(node)

f_locals = abstract_utils.get_atomic_python_constant(f_locals)

# retrieve __qualname__ to get the name of class
Expand Down
32 changes: 30 additions & 2 deletions pytype/pytd/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1274,6 +1274,7 @@ def __init__(self):
self.function_name = None
self.constant_name = None
self.all_typeparams = set()
self.generic_level = 0

def _GetTemplateItems(self, param):
"""Get a list of template items from a parameter."""
Expand Down Expand Up @@ -1396,6 +1397,30 @@ def EnterConstant(self, node):
def LeaveConstant(self, unused_node):
self.constant_name = None

def EnterGenericType(self, unused_node):
self.generic_level += 1

def LeaveGenericType(self, unused_node):
self.generic_level -= 1

def EnterCallableType(self, node):
self.EnterGenericType(node)

def LeaveCallableType(self, node):
self.LeaveGenericType(node)

def EnterTupleType(self, node):
self.EnterGenericType(node)

def LeaveTupleType(self, node):
self.LeaveGenericType(node)

def EnterUnionType(self, node):
self.EnterGenericType(node)

def LeaveUnionType(self, node):
self.LeaveGenericType(node)

def _GetFullName(self, name):
return ".".join(n for n in [self.class_name, name] if n)

Expand All @@ -1404,10 +1429,13 @@ def _GetScope(self, name):
return self.class_name
return self._GetFullName(self.function_name)

def _IsBoundTypeParam(self, node):
in_class = self.class_name and node.name in self.class_typeparams
return in_class or self.generic_level

def VisitTypeParameter(self, node):
"""Add scopes to type parameters, track unbound params."""
if self.constant_name and (not self.class_name or
node.name not in self.class_typeparams):
if self.constant_name and not self._IsBoundTypeParam(node):
raise ContainerError("Unbound type parameter %s in %s" % (
node.name, self._GetFullName(self.constant_name)))
scope = self._GetScope(node.name)
Expand Down
12 changes: 12 additions & 0 deletions pytype/tests/py3/test_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1081,6 +1081,18 @@ class Foo:
X = List['int']
""")

def test_nested_forward_ref_to_import(self):
with file_utils.Tempdir() as d:
d.create_file("foo.pyi", """
class Foo: ...
""")
self.Check("""
import foo
from typing import Tuple
def f(x: Tuple[str, 'foo.Foo']):
pass
""", pythonpath=[d.path])


class TestAnnotationsPython3Feature(test_base.TargetPython3FeatureTest):
"""Tests for PEP 484 style inline annotations."""
Expand Down
28 changes: 27 additions & 1 deletion pytype/tests/py3/test_typevar.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,12 +413,38 @@ def f(x: T) -> T:
""")
self.assertErrorRegexes(errors, {"e": "Expected.*T.*Actual.*TypeVar"})

def test_typevar_in_union_alias(self):
ty = self.Infer("""
from typing import Dict, List, TypeVar, Union
T = TypeVar("T")
U = TypeVar("U")
Foo = Union[T, List[T], Dict[T, List[U]], complex]
def f(x: Foo[int, str]): ...
""")
self.assertTypesMatchPytd(ty, """
from typing import Dict, List, TypeVar, Union, Any
T = TypeVar("T")
U = TypeVar("U")
Foo: Any
def f(x: Union[Dict[int, List[str]], List[int], complex, int]) -> None: ...
""")

def test_typevar_in_union_alias_error(self):
err = self.CheckWithErrors("""
from typing import Dict, List, TypeVar, Union
T = TypeVar("T")
U = TypeVar("U")
Foo = Union[T, List[T], Dict[T, List[U]], complex]
def f(x: Foo[int]): ... # invalid-annotation[e]
""")
self.assertErrorRegexes(err, {"e": "Union.*2.*instantiated.*1"})

def test_use_unsupported_typevar(self):
# Test that we don't crash when using this pattern (b/162274390)
self.CheckWithErrors("""
from typing import List, TypeVar, Union
T = TypeVar("T")
Tree = Union[T, List['Tree']] # not-supported-yet # not-supported-yet
Tree = Union[T, List['Tree']] # not-supported-yet
def f(x: Tree[int]): ... # no error since Tree is set to Any
""")

Expand Down
35 changes: 35 additions & 0 deletions pytype/tests/py3/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,41 @@ def f(x: MutableSet) -> MutableSet:
return x - {0}
""")

def test_union_of_classes(self):
ty = self.Infer("""
from typing import Type, Union
class Foo(object):
def __getitem__(self, x) -> int:
return 0
class Bar(object):
def __getitem__(self, x) -> str:
return ''
def f(x: Union[Type[Foo], Type[Bar]]):
return x.__getitem__
def g(x: Type[Union[Foo, Bar]]):
return x.__getitem__
""")
# The inferred return type of `g` is technically incorrect: it is inferred
# from the type of abstract.Union.getitem_slot, which is a NativeFunction,
# so its type defaults to a plain Callable. We should instead look up
# Foo.__getitem__ and Bar.__getitem__ as we do for `f`, but it is currently
# not possible to distinguish between using Union.getitem_slot and accessing
# the actual __getitem__ method on a union's options. Inferring `Callable`
# should generally be safe, since __getitem__ is a method by convention.
self.assertTypesMatchPytd(ty, """
from typing import Any, Callable, Type, Union
class Foo:
def __getitem__(self, x) -> int: ...
class Bar:
def __getitem__(self, x) -> str: ...
def f(x: Type[Union[Foo, Bar]]) -> Callable[[Any, Any], Union[int, str]]: ...
def g(x: Type[Union[Foo, Bar]]) -> Callable: ...
""")


class CounterTest(test_base.TargetPython3BasicTest):
"""Tests for typing.Counter."""
Expand Down
13 changes: 13 additions & 0 deletions pytype/tests/py3/test_typing_namedtuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,5 +294,18 @@ class Foo(NamedTuple):
""")
self.assertErrorRegexes(errors, {"e": r"Annotation: str.*Assignment: int"})

def test_nested_namedtuple(self):
# Guard against a crash when hitting max depth (b/162619036)
self.assertNoCrash(self.Check, """
from typing import NamedTuple
def foo() -> None:
class A(NamedTuple):
x: int
def bar():
foo()
""")


test_base.main(globals(), __name__ == "__main__")
33 changes: 29 additions & 4 deletions pytype/tests/test_typevar.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,13 +402,18 @@ class Foo:
""")

def test_typevar_in_alias(self):
err = self.CheckWithErrors("""
ty = self.Infer("""
from typing import TypeVar, Union
T = TypeVar("T", int, float)
Num = Union[T, complex] # not-supported-yet[e]
Num = Union[T, complex]
x = 10 # type: Num[int]
""")
self.assertTypesMatchPytd(ty, """
from typing import Any, TypeVar, Union
T = TypeVar("T", int, float)
Num: Any
x: Union[int, complex] = ...
""")
self.assertErrorRegexes(
err, {"e": "aliases of Unions with type parameters"})

def test_recursive_alias(self):
errors = self.CheckWithErrors("""
Expand Down Expand Up @@ -439,5 +444,25 @@ def g(x): # type: (Sequence[T]) -> Type[Sequence[T]]
""")
self.assertErrorRegexes(errors, {"e": "Expected.*int.*Actual.*Sequence"})

def test_typevar_in_constant(self):
ty = self.Infer("""
from typing import TypeVar
T = TypeVar('T')
class Foo(object):
def __init__(self):
self.f1 = self.f2
def f2(self, x):
# type: (T) -> T
return x
""")
self.assertTypesMatchPytd(ty, """
from typing import Callable, TypeVar
T = TypeVar('T')
class Foo:
f1: Callable[[T], T]
def __init__(self) -> None: ...
def f2(self, x: T) -> T: ...
""")


test_base.main(globals(), __name__ == "__main__")
11 changes: 11 additions & 0 deletions pytype/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,17 @@ def invert_dict(d):
return inverted


def unique_list(xs):
"""Return a unique list from an iterable, preserving order."""
seen = set()
out = []
for x in xs:
if x not in seen:
seen.add(x)
out.append(x)
return out


class DynamicVar(object):
"""A dynamically scoped variable.
Expand Down
10 changes: 0 additions & 10 deletions pytype/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1287,14 +1287,6 @@ def store_global(self, state, name, value):
"""Same as store_local except for globals."""
return self._store_value(state, name, value, local=False)

def _check_for_aliased_type_params(self, value):
for v in value.data:
if isinstance(v, abstract.Union) and v.formal:
self.errorlog.not_supported_yet(
self.frames, "aliases of Unions with type parameters")
return True
return False

def _remove_recursion(self, node, name, value):
"""Remove any recursion in the named value."""
if not value.data or any(not isinstance(v, mixin.NestedAnnotation)
Expand Down Expand Up @@ -1376,8 +1368,6 @@ def _pop_and_store(self, state, op, name, local):
value = self._apply_annotation(
state, op, name, orig_val, annotations_dict, check_types)
value = self._remove_recursion(state.node, name, value)
if self._check_for_aliased_type_params(value):
value = self.new_unsolvable(state.node)
state = state.forward_cfg_node()
state = self._store_value(state, name, value, local)
self.trace_opcode(op, name, value)
Expand Down

0 comments on commit 2d8c896

Please sign in to comment.