Skip to content

Commit

Permalink
Merge pull request #627 from google/google_sync
Browse files Browse the repository at this point in the history
Google sync
  • Loading branch information
rchen152 authored Jul 24, 2020
2 parents dc1f04d + 35f4e51 commit b1f27ec
Show file tree
Hide file tree
Showing 14 changed files with 191 additions and 72 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
Version 2020.07.24
* pyi parser: allow aliases inside a class to values outside the class.
* Copy annotations instead of modifying them when adding a scope.
* Make self.__class__ return Any in __init__.
* Check object visibility before setting attributes.

Version 2020.07.20
* pyi parser: support importing TypedDict from typing_extensions.

Expand Down
2 changes: 1 addition & 1 deletion pytype/__version__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# pylint: skip-file
__version__ = '2020.07.20'
__version__ = '2020.07.24'
12 changes: 7 additions & 5 deletions pytype/annotations_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,15 @@ def add_scope(self, annot, types, module):
return new_annot
return annot
elif isinstance(annot, abstract.TupleClass):
annot.formal_type_parameters[abstract_utils.T] = self.add_scope(
params = dict(annot.formal_type_parameters)
params[abstract_utils.T] = self.add_scope(
annot.formal_type_parameters[abstract_utils.T], types, module)
return annot
return abstract.TupleClass(
annot.base_cls, params, self.vm, annot.template)
elif isinstance(annot, mixin.NestedAnnotation):
for key, typ in annot.get_inner_types():
annot.update_inner_type(key, self.add_scope(typ, types, module))
return annot
inner_types = [(key, self.add_scope(typ, types, module))
for key, typ in annot.get_inner_types()]
return annot.replace(inner_types)
return annot

def get_type_parameters(self, annot, seen=None):
Expand Down
17 changes: 14 additions & 3 deletions pytype/pyi/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1169,15 +1169,26 @@ def new_class(self, decorators, class_name, parent_args, defs):
for val in constants + aliases + methods + classes}
for val in aliases:
name = val.name
seen_names = set()
while isinstance(val, pytd.Alias):
if isinstance(val.type, pytd.NamedType):
_, _, base_name = val.type.name.rpartition(".")
if base_name in seen_names:
# This happens in cases like:
# class X:
# Y = something.Y
# Since we try to resolve aliases immediately, we don't know what
# type to fill in when the alias value comes from outside the
# class. The best we can do is Any.
val = pytd.Constant(name, pytd.AnythingType())
continue
seen_names.add(base_name)
if base_name in vals_dict:
val = vals_dict[base_name]
continue
raise ParseError(
"Illegal value for alias %r. Value must be an attribute "
"on the same class." % val.name)
# The alias value comes from outside the class. The best we can do is
# to fill in Any.
val = pytd.Constant(name, pytd.AnythingType())
if isinstance(val, _NameAndSig):
methods.append(val._replace(name=name))
else:
Expand Down
34 changes: 27 additions & 7 deletions pytype/pyi/parser_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,33 @@ def test_alias_lookup(self):
x: somewhere.Foo""")

def test_external_alias(self):
self.check("""
from somewhere import Foo
class Bar:
Baz = Foo
""", """
from typing import Any
from somewhere import Foo
class Bar:
Baz: Any
""")

def test_same_named_alias(self):
self.check("""
import somewhere
class Bar:
Foo = somewhere.Foo
""", """
from typing import Any
class Bar:
Foo: Any
""")

def test_type_params(self):
ast = self.check("""
from typing import TypeVar
Expand Down Expand Up @@ -1650,13 +1677,6 @@ class Foo:
import foo
""", 3, "syntax error")

def test_bad_alias(self):
self.check_error("""
class Foo:
if sys.version_info > (3, 4, 0):
a = b
""", 1, "Illegal value for alias 'a'")

def test_no_class(self):
self.check("""
class Foo:
Expand Down
2 changes: 1 addition & 1 deletion pytype/pytd/pytd_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ def MergeBaseClass(cls, base):
classes = cls.classes + tuple(c for c in base.classes
if c.name not in class_names)
if cls.slots:
slots = cls.clots + tuple(s for s in base.slots or () if s not in cls.slots)
slots = cls.slots + tuple(s for s in base.slots or () if s not in cls.slots)
else:
slots = base.slots
return pytd.Class(name=cls.name,
Expand Down
15 changes: 15 additions & 0 deletions pytype/tests/py3/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,21 @@ class Foo(List["Foo"]):
pass
""")

def test_type_parameter_count(self):
self.Check("""
from typing import Generic, List, TypeVar
T = TypeVar('T')
SomeAlias = List[T]
class Foo(Generic[T]):
def __init__(self, x: T, y: SomeAlias):
pass
def f(x: T) -> SomeAlias:
return [x]
""")


class GenericFeatureTest(test_base.TargetPython3FeatureTest):
"""Tests for User-defined Generic Type."""
Expand Down
38 changes: 20 additions & 18 deletions pytype/tests/test_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,24 +503,6 @@ def __getattribute__(self, name) -> bool
def f(x) -> Any
""")

@test_base.skip("TODO(b/63407497): implement strict checking for __setitem__")
def test_union_set_attribute(self):
ty, _ = self.InferWithErrors("""
class A(object):
x = "Hello world"
def f(i):
t = A()
l = [t]
l[i].x = 1 # not-writable
return l[i].x
""")
self.assertTypesMatchPytd(ty, """
from typing import Any
class A(object):
x = ... # type: str
def f(i) -> Any
""")

def test_set_class(self):
ty = self.Infer("""
def f(x):
Expand Down Expand Up @@ -846,5 +828,25 @@ def oops(self) -> None: ...
""")
self.assertErrorRegexes(errors, {"e": r"Annotation: int.*Assignment: str"})

def test_split(self):
ty = self.Infer("""
from typing import Union
class Foo:
pass
class Bar:
pass
def f(x):
# type: (Union[Foo, Bar]) -> None
if isinstance(x, Foo):
x.foo = 42
""")
self.assertTypesMatchPytd(ty, """
from typing import Union
class Foo:
foo: int
class Bar: ...
def f(x: Union[Foo, Bar]) -> None: ...
""")


test_base.main(globals(), __name__ == "__main__")
14 changes: 14 additions & 0 deletions pytype/tests/test_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1411,5 +1411,19 @@ class Foo(Any):
list(Foo())
""")

def test_instantiate_class(self):
self.Check("""
import abc
import six
@six.add_metaclass(abc.ABCMeta)
class Foo(object):
def __init__(self, x):
if x > 0:
print(self.__class__(x-1))
@abc.abstractmethod
def f(self):
pass
""")


test_base.main(globals(), __name__ == "__main__")
13 changes: 13 additions & 0 deletions pytype/tests/test_slots.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,5 +225,18 @@ def __init__(self):
""")
self.assertErrorRegexes(errors, {"e": r"__baz"})

def test_union(self):
self.Check("""
from typing import Union
class Foo(object):
pass
class Bar(object):
__slots__ = ()
def f(x):
# type: (Union[Foo, Bar]) -> None
if isinstance(x, Foo):
x.foo = 42
""")


test_base.main(globals(), __name__ == "__main__")
3 changes: 3 additions & 0 deletions pytype/tools/annotate_ast/annotate_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ def visit_Name(self, node):
def visit_Attribute(self, node):
self._maybe_annotate(node)

def visit_FunctionDef(self, node):
self._maybe_annotate(node)

def _maybe_annotate(self, node):
"""Annotates a node."""
try:
Expand Down
85 changes: 54 additions & 31 deletions pytype/tools/annotate_ast/annotate_ast_test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import ast
import itertools
import textwrap

from pytype import config
from pytype.tests import test_base
from pytype.tools.annotate_ast import annotate_ast
import six


class AnnotaterTest(test_base.TargetIndependentTest):
Expand All @@ -18,28 +16,10 @@ def annotate(self, source):
module = annotate_ast.annotate_source(source, ast_factory, pytype_options)
return module

def assert_annotations_equal(self, expected, module):
nodes = [
node for node in ast.walk(module)
if getattr(node, 'resolved_type', None)
]
actual = {}
for node in nodes:
key = self._get_node_key(node)
actual[key] = '{} :: {!r}'.format(node.resolved_annotation,
node.resolved_type)

for key in sorted(set(itertools.chain(expected, actual))):
expected_pattern = expected.get(key)
if not expected_pattern:
self.fail('Unexpected annotation: {} -> {}'.format(key, actual[key]))
actual_text = actual.get(key)
if not actual_text:
self.fail(
'Expected to find node {} annotated, but it was not.'.format(key))
msg = ('Resolved annotation value does not match {!r}: Node {} annotated '
'with {}').format(expected_pattern, key, actual_text)
six.assertRegex(self, actual_text, expected_pattern, msg=msg)
def get_annotations_dict(self, module):
return {self._get_node_key(node): node.resolved_annotation
for node in ast.walk(module)
if hasattr(node, 'resolved_type')}

def _get_node_key(self, node):
base = (node.lineno, node.__class__.__name__)
Expand All @@ -48,25 +28,27 @@ def _get_node_key(self, node):
return base + (node.id,)
elif isinstance(node, ast.Attribute):
return base + (node.attr,)
elif isinstance(node, ast.FunctionDef):
return base + (node.name,)
else:
return base

def test_annotating_name(self):
source = """
a = 1
b = {}
c = []
b = {1: 'foo'}
c = [1, 2, 3]
d = 3, 4
"""
module = self.annotate(source)

expected = {
(1, 'Name', 'a'): 'int',
(2, 'Name', 'b'): 'dict',
(3, 'Name', 'c'): 'list',
(4, 'Name', 'd'): 'tuple',
(2, 'Name', 'b'): 'Dict[int, str]',
(3, 'Name', 'c'): 'List[int]',
(4, 'Name', 'd'): 'Tuple[int, int]',
}
self.assert_annotations_equal(expected, module)
self.assertEqual(expected, self.get_annotations_dict(module))

def test_annotating_attribute(self):
source = """
Expand All @@ -84,6 +66,47 @@ def test_annotating_attribute(self):
(2, 'Attribute', 'Bar'): 'Any',
(2, 'Attribute', 'bar'): 'Any',
}
self.assert_annotations_equal(expected, module)
self.assertEqual(expected, self.get_annotations_dict(module))

def test_annotating_for(self):
source = """
for i in 1, 2, 3:
pass
"""

module = self.annotate(source)

expected = {
(1, 'Name', 'i'): 'int',
}
self.assertEqual(expected, self.get_annotations_dict(module))

def test_annotating_with(self):
source = """
with foo() as f:
pass
"""

module = self.annotate(source)

expected = {
(1, 'Name', 'foo'): 'Any',
(1, 'Name', 'f'): 'Any',
}
self.assertEqual(expected, self.get_annotations_dict(module))

def test_annotating_def(self):
source = """
def foo(a, b):
# type: (str, int) -> str
pass
"""

module = self.annotate(source)

expected = {
(1, 'FunctionDef', 'foo'): 'Callable[[str, int], str]',
}
self.assertEqual(expected, self.get_annotations_dict(module))

test_base.main(globals(), __name__ == '__main__')
Loading

0 comments on commit b1f27ec

Please sign in to comment.