diff --git a/CHANGELOG b/CHANGELOG index d0af7297f..adb184bc3 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,3 +1,9 @@ +Version 2020.07.24 +* pyi parser: allow aliases inside a class to values outside the class. +* Copy annotations instead of modifying them when adding a scope. +* Make self.__class__ return Any in __init__. +* Check object visibility before setting attributes. + Version 2020.07.20 * pyi parser: support importing TypedDict from typing_extensions. diff --git a/pytype/__version__.py b/pytype/__version__.py index 4c8803d96..43466e32b 100644 --- a/pytype/__version__.py +++ b/pytype/__version__.py @@ -1,2 +1,2 @@ # pylint: skip-file -__version__ = '2020.07.20' +__version__ = '2020.07.24' diff --git a/pytype/annotations_util.py b/pytype/annotations_util.py index 35d8e2380..3124f5d37 100644 --- a/pytype/annotations_util.py +++ b/pytype/annotations_util.py @@ -79,13 +79,15 @@ def add_scope(self, annot, types, module): return new_annot return annot elif isinstance(annot, abstract.TupleClass): - annot.formal_type_parameters[abstract_utils.T] = self.add_scope( + params = dict(annot.formal_type_parameters) + params[abstract_utils.T] = self.add_scope( annot.formal_type_parameters[abstract_utils.T], types, module) - return annot + return abstract.TupleClass( + annot.base_cls, params, self.vm, annot.template) elif isinstance(annot, mixin.NestedAnnotation): - for key, typ in annot.get_inner_types(): - annot.update_inner_type(key, self.add_scope(typ, types, module)) - return annot + inner_types = [(key, self.add_scope(typ, types, module)) + for key, typ in annot.get_inner_types()] + return annot.replace(inner_types) return annot def get_type_parameters(self, annot, seen=None): diff --git a/pytype/pyi/parser.py b/pytype/pyi/parser.py index 1411756c3..f5ece562b 100644 --- a/pytype/pyi/parser.py +++ b/pytype/pyi/parser.py @@ -1169,15 +1169,26 @@ def new_class(self, decorators, class_name, parent_args, defs): for val in constants + aliases + methods + classes} for val in aliases: name = val.name + seen_names = set() while isinstance(val, pytd.Alias): if isinstance(val.type, pytd.NamedType): _, _, base_name = val.type.name.rpartition(".") + if base_name in seen_names: + # This happens in cases like: + # class X: + # Y = something.Y + # Since we try to resolve aliases immediately, we don't know what + # type to fill in when the alias value comes from outside the + # class. The best we can do is Any. + val = pytd.Constant(name, pytd.AnythingType()) + continue + seen_names.add(base_name) if base_name in vals_dict: val = vals_dict[base_name] continue - raise ParseError( - "Illegal value for alias %r. Value must be an attribute " - "on the same class." % val.name) + # The alias value comes from outside the class. The best we can do is + # to fill in Any. + val = pytd.Constant(name, pytd.AnythingType()) if isinstance(val, _NameAndSig): methods.append(val._replace(name=name)) else: diff --git a/pytype/pyi/parser_test.py b/pytype/pyi/parser_test.py index b457645d7..0c041d032 100644 --- a/pytype/pyi/parser_test.py +++ b/pytype/pyi/parser_test.py @@ -374,6 +374,33 @@ def test_alias_lookup(self): x: somewhere.Foo""") + def test_external_alias(self): + self.check(""" + from somewhere import Foo + + class Bar: + Baz = Foo + """, """ + from typing import Any + + from somewhere import Foo + + class Bar: + Baz: Any + """) + + def test_same_named_alias(self): + self.check(""" + import somewhere + class Bar: + Foo = somewhere.Foo + """, """ + from typing import Any + + class Bar: + Foo: Any + """) + def test_type_params(self): ast = self.check(""" from typing import TypeVar @@ -1650,13 +1677,6 @@ class Foo: import foo """, 3, "syntax error") - def test_bad_alias(self): - self.check_error(""" - class Foo: - if sys.version_info > (3, 4, 0): - a = b - """, 1, "Illegal value for alias 'a'") - def test_no_class(self): self.check(""" class Foo: diff --git a/pytype/pytd/pytd_utils.py b/pytype/pytd/pytd_utils.py index 2f4f962e5..a43bc43bc 100644 --- a/pytype/pytd/pytd_utils.py +++ b/pytype/pytd/pytd_utils.py @@ -565,7 +565,7 @@ def MergeBaseClass(cls, base): classes = cls.classes + tuple(c for c in base.classes if c.name not in class_names) if cls.slots: - slots = cls.clots + tuple(s for s in base.slots or () if s not in cls.slots) + slots = cls.slots + tuple(s for s in base.slots or () if s not in cls.slots) else: slots = base.slots return pytd.Class(name=cls.name, diff --git a/pytype/tests/py3/test_generic.py b/pytype/tests/py3/test_generic.py index c20747f86..06d633fb0 100644 --- a/pytype/tests/py3/test_generic.py +++ b/pytype/tests/py3/test_generic.py @@ -579,6 +579,21 @@ class Foo(List["Foo"]): pass """) + def test_type_parameter_count(self): + self.Check(""" + from typing import Generic, List, TypeVar + + T = TypeVar('T') + SomeAlias = List[T] + + class Foo(Generic[T]): + def __init__(self, x: T, y: SomeAlias): + pass + + def f(x: T) -> SomeAlias: + return [x] + """) + class GenericFeatureTest(test_base.TargetPython3FeatureTest): """Tests for User-defined Generic Type.""" diff --git a/pytype/tests/test_attributes.py b/pytype/tests/test_attributes.py index 59584b9fc..ef48ef651 100644 --- a/pytype/tests/test_attributes.py +++ b/pytype/tests/test_attributes.py @@ -503,24 +503,6 @@ def __getattribute__(self, name) -> bool def f(x) -> Any """) - @test_base.skip("TODO(b/63407497): implement strict checking for __setitem__") - def test_union_set_attribute(self): - ty, _ = self.InferWithErrors(""" - class A(object): - x = "Hello world" - def f(i): - t = A() - l = [t] - l[i].x = 1 # not-writable - return l[i].x - """) - self.assertTypesMatchPytd(ty, """ - from typing import Any - class A(object): - x = ... # type: str - def f(i) -> Any - """) - def test_set_class(self): ty = self.Infer(""" def f(x): @@ -846,5 +828,25 @@ def oops(self) -> None: ... """) self.assertErrorRegexes(errors, {"e": r"Annotation: int.*Assignment: str"}) + def test_split(self): + ty = self.Infer(""" + from typing import Union + class Foo: + pass + class Bar: + pass + def f(x): + # type: (Union[Foo, Bar]) -> None + if isinstance(x, Foo): + x.foo = 42 + """) + self.assertTypesMatchPytd(ty, """ + from typing import Union + class Foo: + foo: int + class Bar: ... + def f(x: Union[Foo, Bar]) -> None: ... + """) + test_base.main(globals(), __name__ == "__main__") diff --git a/pytype/tests/test_classes.py b/pytype/tests/test_classes.py index 963eda595..57bc16054 100644 --- a/pytype/tests/test_classes.py +++ b/pytype/tests/test_classes.py @@ -1411,5 +1411,19 @@ class Foo(Any): list(Foo()) """) + def test_instantiate_class(self): + self.Check(""" + import abc + import six + @six.add_metaclass(abc.ABCMeta) + class Foo(object): + def __init__(self, x): + if x > 0: + print(self.__class__(x-1)) + @abc.abstractmethod + def f(self): + pass + """) + test_base.main(globals(), __name__ == "__main__") diff --git a/pytype/tests/test_slots.py b/pytype/tests/test_slots.py index 073b9b91c..698b18476 100644 --- a/pytype/tests/test_slots.py +++ b/pytype/tests/test_slots.py @@ -225,5 +225,18 @@ def __init__(self): """) self.assertErrorRegexes(errors, {"e": r"__baz"}) + def test_union(self): + self.Check(""" + from typing import Union + class Foo(object): + pass + class Bar(object): + __slots__ = () + def f(x): + # type: (Union[Foo, Bar]) -> None + if isinstance(x, Foo): + x.foo = 42 + """) + test_base.main(globals(), __name__ == "__main__") diff --git a/pytype/tools/annotate_ast/annotate_ast.py b/pytype/tools/annotate_ast/annotate_ast.py index b569ed1ec..99f09b304 100644 --- a/pytype/tools/annotate_ast/annotate_ast.py +++ b/pytype/tools/annotate_ast/annotate_ast.py @@ -59,6 +59,9 @@ def visit_Name(self, node): def visit_Attribute(self, node): self._maybe_annotate(node) + def visit_FunctionDef(self, node): + self._maybe_annotate(node) + def _maybe_annotate(self, node): """Annotates a node.""" try: diff --git a/pytype/tools/annotate_ast/annotate_ast_test.py b/pytype/tools/annotate_ast/annotate_ast_test.py index a188d18ff..e921439e1 100644 --- a/pytype/tools/annotate_ast/annotate_ast_test.py +++ b/pytype/tools/annotate_ast/annotate_ast_test.py @@ -1,11 +1,9 @@ import ast -import itertools import textwrap from pytype import config from pytype.tests import test_base from pytype.tools.annotate_ast import annotate_ast -import six class AnnotaterTest(test_base.TargetIndependentTest): @@ -18,28 +16,10 @@ def annotate(self, source): module = annotate_ast.annotate_source(source, ast_factory, pytype_options) return module - def assert_annotations_equal(self, expected, module): - nodes = [ - node for node in ast.walk(module) - if getattr(node, 'resolved_type', None) - ] - actual = {} - for node in nodes: - key = self._get_node_key(node) - actual[key] = '{} :: {!r}'.format(node.resolved_annotation, - node.resolved_type) - - for key in sorted(set(itertools.chain(expected, actual))): - expected_pattern = expected.get(key) - if not expected_pattern: - self.fail('Unexpected annotation: {} -> {}'.format(key, actual[key])) - actual_text = actual.get(key) - if not actual_text: - self.fail( - 'Expected to find node {} annotated, but it was not.'.format(key)) - msg = ('Resolved annotation value does not match {!r}: Node {} annotated ' - 'with {}').format(expected_pattern, key, actual_text) - six.assertRegex(self, actual_text, expected_pattern, msg=msg) + def get_annotations_dict(self, module): + return {self._get_node_key(node): node.resolved_annotation + for node in ast.walk(module) + if hasattr(node, 'resolved_type')} def _get_node_key(self, node): base = (node.lineno, node.__class__.__name__) @@ -48,25 +28,27 @@ def _get_node_key(self, node): return base + (node.id,) elif isinstance(node, ast.Attribute): return base + (node.attr,) + elif isinstance(node, ast.FunctionDef): + return base + (node.name,) else: return base def test_annotating_name(self): source = """ a = 1 - b = {} - c = [] + b = {1: 'foo'} + c = [1, 2, 3] d = 3, 4 """ module = self.annotate(source) expected = { (1, 'Name', 'a'): 'int', - (2, 'Name', 'b'): 'dict', - (3, 'Name', 'c'): 'list', - (4, 'Name', 'd'): 'tuple', + (2, 'Name', 'b'): 'Dict[int, str]', + (3, 'Name', 'c'): 'List[int]', + (4, 'Name', 'd'): 'Tuple[int, int]', } - self.assert_annotations_equal(expected, module) + self.assertEqual(expected, self.get_annotations_dict(module)) def test_annotating_attribute(self): source = """ @@ -84,6 +66,47 @@ def test_annotating_attribute(self): (2, 'Attribute', 'Bar'): 'Any', (2, 'Attribute', 'bar'): 'Any', } - self.assert_annotations_equal(expected, module) + self.assertEqual(expected, self.get_annotations_dict(module)) + + def test_annotating_for(self): + source = """ + for i in 1, 2, 3: + pass + """ + + module = self.annotate(source) + + expected = { + (1, 'Name', 'i'): 'int', + } + self.assertEqual(expected, self.get_annotations_dict(module)) + + def test_annotating_with(self): + source = """ + with foo() as f: + pass + """ + + module = self.annotate(source) + + expected = { + (1, 'Name', 'foo'): 'Any', + (1, 'Name', 'f'): 'Any', + } + self.assertEqual(expected, self.get_annotations_dict(module)) + + def test_annotating_def(self): + source = """ + def foo(a, b): + # type: (str, int) -> str + pass + """ + + module = self.annotate(source) + + expected = { + (1, 'FunctionDef', 'foo'): 'Callable[[str, int], str]', + } + self.assertEqual(expected, self.get_annotations_dict(module)) test_base.main(globals(), __name__ == '__main__') diff --git a/pytype/tools/traces/traces.py b/pytype/tools/traces/traces.py index 7018e1af0..8c4efcd99 100644 --- a/pytype/tools/traces/traces.py +++ b/pytype/tools/traces/traces.py @@ -131,7 +131,7 @@ def __init__(self, matches): def match(self, symbol): for match in self._matches: if isinstance(match, self._PATTERN_TYPE): - if match.match(symbol): + if match.match(str(symbol)): return True elif match == symbol: return True @@ -216,14 +216,21 @@ def match_Call(self, node): for tr in self._get_traces( node.lineno, _CALL_OPS, name, maxmatch=1, num_lines=5)] - def match_Ellipsis(self, node): - return self._match_constant(node, Ellipsis) - def match_Constant(self, node): # As of Python 3.8, bools, numbers, bytes, strings, ellipsis etc are # all constants instead of individual ast nodes. return self._match_constant(node, node.s) + def match_Ellipsis(self, node): + return self._match_constant(node, Ellipsis) + + def match_FunctionDef(self, node): + symbol = _SymbolMatcher.from_regex(r"(%s|None)" % node.name) + return [ + (self._get_match_location(node, tr.symbol), tr) + for tr in self._get_traces(node.lineno, ["MAKE_FUNCTION"], symbol, 1) + ] + def match_Import(self, node): return list(self._match_import(node, is_from=False)) diff --git a/pytype/vm.py b/pytype/vm.py index ed4ec47f9..692a28939 100644 --- a/pytype/vm.py +++ b/pytype/vm.py @@ -1391,6 +1391,9 @@ def _del_name(self, op, state, name, local): def _retrieve_attr(self, node, obj, attr): """Load an attribute from an object.""" assert isinstance(obj, cfg.Variable), obj + if (attr == "__class__" and self.callself_stack and + obj.data == self.callself_stack[-1].data): + return node, self.new_unsolvable(node), [] # Resolve the value independently for each value of obj result = self.program.NewVariable() log.debug("getting attr %s from %r", attr, obj) @@ -1508,12 +1511,12 @@ def store_attr(self, state, obj, attr, value): log.info("Ignoring setattr on %r", obj) return state nodes = [] - for val in obj.bindings: + for val in obj.Filter(state.node): # TODO(b/159038991): Check whether val.data is a descriptor (i.e. has # "__set__") nodes.append(self.attribute_handler.set_attribute( state.node, val.data, attr, value)) - return state.change_cfg_node(self.join_cfg_nodes(nodes)) + return state.change_cfg_node(self.join_cfg_nodes(nodes)) if nodes else state def del_attr(self, state, obj, attr): """Delete an attribute."""