Skip to content
This repository was archived by the owner on Jul 11, 2022. It is now read-only.

Commit

Permalink
Improve get_future_imports implementation.
Browse files Browse the repository at this point in the history
  • Loading branch information
zsol committed Jul 2, 2018
1 parent 3bdd423 commit dd8bde6
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 12 deletions.
29 changes: 19 additions & 10 deletions black.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Callable,
Collection,
Dict,
Generator,
Generic,
Iterable,
Iterator,
Expand Down Expand Up @@ -2910,7 +2911,23 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf

def get_future_imports(node: Node) -> Set[str]:
"""Return a set of __future__ imports in the file."""
imports = set()
imports: Set[str] = set()

def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
for child in children:
if isinstance(child, Leaf):
if child.type == token.NAME:
yield child.value
elif child.type == syms.import_as_name:
orig_name = child.children[0]
assert isinstance(orig_name, Leaf), "Invalid syntax parsing imports"
assert orig_name.type == token.NAME, "Invalid syntax parsing imports"
yield orig_name.value
elif child.type == syms.import_as_names:
yield from get_imports_from_children(child.children)
else:
assert False, "Invalid syntax parsing imports"

for child in node.children:
if child.type != syms.simple_stmt:
break
Expand All @@ -2929,15 +2946,7 @@ def get_future_imports(node: Node) -> Set[str]:
module_name = first_child.children[1]
if not isinstance(module_name, Leaf) or module_name.value != "__future__":
break
for import_from_child in first_child.children[3:]:
if isinstance(import_from_child, Leaf):
if import_from_child.type == token.NAME:
imports.add(import_from_child.value)
else:
assert import_from_child.type == syms.import_as_names
for leaf in import_from_child.children:
if isinstance(leaf, Leaf) and leaf.type == token.NAME:
imports.add(leaf.value)
imports |= set(get_imports_from_children(first_child.children[3:]))
else:
break
return imports
Expand Down
8 changes: 6 additions & 2 deletions tests/data/python2_unicode_literals.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#!/usr/bin/env python2
from __future__ import unicode_literals
from __future__ import unicode_literals as _unicode_literals
from __future__ import absolute_import
from __future__ import print_function as lol, with_function

u'hello'
U"hello"
Expand All @@ -9,7 +11,9 @@


#!/usr/bin/env python2
from __future__ import unicode_literals
from __future__ import unicode_literals as _unicode_literals
from __future__ import absolute_import
from __future__ import print_function as lol, with_function

"hello"
"hello"
Expand Down
8 changes: 8 additions & 0 deletions tests/test_black.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,6 +735,14 @@ def test_get_future_imports(self) -> None:
self.assertEqual(set(), black.get_future_imports(node))
node = black.lib2to3_parse("from some.module import black\n")
self.assertEqual(set(), black.get_future_imports(node))
node = black.lib2to3_parse(
"from __future__ import unicode_literals as _unicode_literals"
)
self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
node = black.lib2to3_parse(
"from __future__ import unicode_literals as _lol, print"
)
self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))

def test_debug_visitor(self) -> None:
source, _ = read_data("debug_visitor.py")
Expand Down

0 comments on commit dd8bde6

Please sign in to comment.