Skip to content

Commit

Permalink
Add support for conditionally defined overloads (#10712)
Browse files Browse the repository at this point in the history
### Description

This PR allows users to define overloads conditionally, e.g., based on the Python version. At the moment this is only possible if all overloads are contained in the same block which requires duplications.

```py
from typing import overload, Any
import sys

class A: ...
class B: ...

if sys.version_info >= (3, 9):
    class C: ...


@overload
def func(g: int) -> A: ...

@overload
def func(g: bytes) -> B: ...

if sys.version_info >= (3, 9):
    @overload
    def func(g: str) -> C: ...

def func(g: Any) -> Any: ...
```

Closes #9744

## Test Plan

Unit tests have been added.

## Limitations
Only `if` is supported. Support for `elif` and `else` might be added in the future. However, I believe that the single if as shown in the example is also the most common use case.

The change itself is fully backwards compatible, i.e. the current workaround (see below) will continue to function as expected.

~~**Update**: Seems like support for `elif` and `else` is required for the tests to pass.~~

**Update**: Added support for `elif` and `else`.

## Current workaround

```py
from typing import overload, Any
import sys

class A: ...
class B: ...

if sys.version_info >= (3, 9):
    class C: ...


if sys.version_info >= (3, 9):
    @overload
    def func(g: int) -> A: ...

    @overload
    def func(g: bytes) -> B: ...

    @overload
    def func(g: str) -> C: ...

    def func(g: Any) -> Any: ...

else:
    @overload
    def func(g: int) -> A: ...

    @overload
    def func(g: bytes) -> B: ...

    def func(g: Any) -> Any: ...
```
  • Loading branch information
cdce8p authored Mar 3, 2022
1 parent 68b3b27 commit 0777c10
Show file tree
Hide file tree
Showing 3 changed files with 1,185 additions and 3 deletions.
108 changes: 108 additions & 0 deletions docs/source/more_types.rst
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,114 @@ with ``Union[int, slice]`` and ``Union[T, Sequence]``.
to returning ``Any`` only if the input arguments also contain ``Any``.


Conditional overloads
---------------------

Sometimes it is useful to define overloads conditionally.
Common use cases include types that are unavailable at runtime or that
only exist in a certain Python version. All existing overload rules still apply.
For example, there must be at least two overloads.

.. note::

Mypy can only infer a limited number of conditions.
Supported ones currently include :py:data:`~typing.TYPE_CHECKING`, ``MYPY``,
:ref:`version_and_platform_checks`, and :option:`--always-true <mypy --always-true>`
and :option:`--always-false <mypy --always-false>` values.

.. code-block:: python
from typing import TYPE_CHECKING, Any, overload
if TYPE_CHECKING:
class A: ...
class B: ...
if TYPE_CHECKING:
@overload
def func(var: A) -> A: ...
@overload
def func(var: B) -> B: ...
def func(var: Any) -> Any:
return var
reveal_type(func(A())) # Revealed type is "A"
.. code-block:: python
# flags: --python-version 3.10
import sys
from typing import Any, overload
class A: ...
class B: ...
class C: ...
class D: ...
if sys.version_info < (3, 7):
@overload
def func(var: A) -> A: ...
elif sys.version_info >= (3, 10):
@overload
def func(var: B) -> B: ...
else:
@overload
def func(var: C) -> C: ...
@overload
def func(var: D) -> D: ...
def func(var: Any) -> Any:
return var
reveal_type(func(B())) # Revealed type is "B"
reveal_type(func(C())) # No overload variant of "func" matches argument type "C"
# Possible overload variants:
# def func(var: B) -> B
# def func(var: D) -> D
# Revealed type is "Any"
.. note::

In the last example, mypy is executed with
:option:`--python-version 3.10 <mypy --python-version>`.
Therefore, the condition ``sys.version_info >= (3, 10)`` will match and
the overload for ``B`` will be added.
The overloads for ``A`` and ``C`` are ignored!
The overload for ``D`` is not defined conditionally and thus is also added.

When mypy cannot infer a condition to be always True or always False, an error is emitted.

.. code-block:: python
from typing import Any, overload
class A: ...
class B: ...
def g(bool_var: bool) -> None:
if bool_var: # Condition can't be inferred, unable to merge overloads
@overload
def func(var: A) -> A: ...
@overload
def func(var: B) -> B: ...
def func(var: Any) -> Any: ...
reveal_type(func(A())) # Revealed type is "Any"
.. _advanced_self:

Advanced uses of self-types
Expand Down
199 changes: 196 additions & 3 deletions mypy/fastparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from mypy import message_registry, errorcodes as codes
from mypy.errors import Errors
from mypy.options import Options
from mypy.reachability import mark_block_unreachable
from mypy.reachability import infer_reachability_of_if_statement, mark_block_unreachable
from mypy.util import bytes_to_human_readable_repr

try:
Expand Down Expand Up @@ -344,9 +344,19 @@ def fail(self,
msg: str,
line: int,
column: int,
blocker: bool = True) -> None:
blocker: bool = True,
code: codes.ErrorCode = codes.SYNTAX) -> None:
if blocker or not self.options.ignore_errors:
self.errors.report(line, column, msg, blocker=blocker, code=codes.SYNTAX)
self.errors.report(line, column, msg, blocker=blocker, code=code)

def fail_merge_overload(self, node: IfStmt) -> None:
self.fail(
"Condition can't be inferred, unable to merge overloads",
line=node.line,
column=node.column,
blocker=False,
code=codes.MISC,
)

def visit(self, node: Optional[AST]) -> Any:
if node is None:
Expand Down Expand Up @@ -476,12 +486,93 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
ret: List[Statement] = []
current_overload: List[OverloadPart] = []
current_overload_name: Optional[str] = None
last_if_stmt: Optional[IfStmt] = None
last_if_overload: Optional[Union[Decorator, FuncDef, OverloadedFuncDef]] = None
last_if_stmt_overload_name: Optional[str] = None
last_if_unknown_truth_value: Optional[IfStmt] = None
skipped_if_stmts: List[IfStmt] = []
for stmt in stmts:
if_overload_name: Optional[str] = None
if_block_with_overload: Optional[Block] = None
if_unknown_truth_value: Optional[IfStmt] = None
if (
isinstance(stmt, IfStmt)
and len(stmt.body[0].body) == 1
and (
isinstance(stmt.body[0].body[0], (Decorator, OverloadedFuncDef))
or current_overload_name is not None
and isinstance(stmt.body[0].body[0], FuncDef)
)
):
# Check IfStmt block to determine if function overloads can be merged
if_overload_name = self._check_ifstmt_for_overloads(stmt)
if if_overload_name is not None:
if_block_with_overload, if_unknown_truth_value = \
self._get_executable_if_block_with_overloads(stmt)

if (current_overload_name is not None
and isinstance(stmt, (Decorator, FuncDef))
and stmt.name == current_overload_name):
if last_if_stmt is not None:
skipped_if_stmts.append(last_if_stmt)
if last_if_overload is not None:
# Last stmt was an IfStmt with same overload name
# Add overloads to current_overload
if isinstance(last_if_overload, OverloadedFuncDef):
current_overload.extend(last_if_overload.items)
else:
current_overload.append(last_if_overload)
last_if_stmt, last_if_overload = None, None
if last_if_unknown_truth_value:
self.fail_merge_overload(last_if_unknown_truth_value)
last_if_unknown_truth_value = None
current_overload.append(stmt)
elif (
current_overload_name is not None
and isinstance(stmt, IfStmt)
and if_overload_name == current_overload_name
):
# IfStmt only contains stmts relevant to current_overload.
# Check if stmts are reachable and add them to current_overload,
# otherwise skip IfStmt to allow subsequent overload
# or function definitions.
skipped_if_stmts.append(stmt)
if if_block_with_overload is None:
if if_unknown_truth_value is not None:
self.fail_merge_overload(if_unknown_truth_value)
continue
if last_if_overload is not None:
# Last stmt was an IfStmt with same overload name
# Add overloads to current_overload
if isinstance(last_if_overload, OverloadedFuncDef):
current_overload.extend(last_if_overload.items)
else:
current_overload.append(last_if_overload)
last_if_stmt, last_if_overload = None, None
if isinstance(if_block_with_overload.body[0], OverloadedFuncDef):
current_overload.extend(if_block_with_overload.body[0].items)
else:
current_overload.append(
cast(Union[Decorator, FuncDef], if_block_with_overload.body[0])
)
else:
if last_if_stmt is not None:
ret.append(last_if_stmt)
last_if_stmt_overload_name = current_overload_name
last_if_stmt, last_if_overload = None, None
last_if_unknown_truth_value = None

if current_overload and current_overload_name == last_if_stmt_overload_name:
# Remove last stmt (IfStmt) from ret if the overload names matched
# Only happens if no executable block had been found in IfStmt
skipped_if_stmts.append(cast(IfStmt, ret.pop()))
if current_overload and skipped_if_stmts:
# Add bare IfStmt (without overloads) to ret
# Required for mypy to be able to still check conditions
for if_stmt in skipped_if_stmts:
self._strip_contents_from_if_stmt(if_stmt)
ret.append(if_stmt)
skipped_if_stmts = []
if len(current_overload) == 1:
ret.append(current_overload[0])
elif len(current_overload) > 1:
Expand All @@ -495,17 +586,119 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]:
if isinstance(stmt, Decorator) and not unnamed_function(stmt.name):
current_overload = [stmt]
current_overload_name = stmt.name
elif (
isinstance(stmt, IfStmt)
and if_overload_name is not None
):
current_overload = []
current_overload_name = if_overload_name
last_if_stmt = stmt
last_if_stmt_overload_name = None
if if_block_with_overload is not None:
last_if_overload = cast(
Union[Decorator, FuncDef, OverloadedFuncDef],
if_block_with_overload.body[0]
)
last_if_unknown_truth_value = if_unknown_truth_value
else:
current_overload = []
current_overload_name = None
ret.append(stmt)

if current_overload and skipped_if_stmts:
# Add bare IfStmt (without overloads) to ret
# Required for mypy to be able to still check conditions
for if_stmt in skipped_if_stmts:
self._strip_contents_from_if_stmt(if_stmt)
ret.append(if_stmt)
if len(current_overload) == 1:
ret.append(current_overload[0])
elif len(current_overload) > 1:
ret.append(OverloadedFuncDef(current_overload))
elif last_if_stmt is not None:
ret.append(last_if_stmt)
return ret

def _check_ifstmt_for_overloads(self, stmt: IfStmt) -> Optional[str]:
"""Check if IfStmt contains only overloads with the same name.
Return overload_name if found, None otherwise.
"""
# Check that block only contains a single Decorator, FuncDef, or OverloadedFuncDef.
# Multiple overloads have already been merged as OverloadedFuncDef.
if not (
len(stmt.body[0].body) == 1
and isinstance(stmt.body[0].body[0], (Decorator, FuncDef, OverloadedFuncDef))
):
return None

overload_name = stmt.body[0].body[0].name
if stmt.else_body is None:
return overload_name

if isinstance(stmt.else_body, Block) and len(stmt.else_body.body) == 1:
# For elif: else_body contains an IfStmt itself -> do a recursive check.
if (
isinstance(stmt.else_body.body[0], (Decorator, FuncDef, OverloadedFuncDef))
and stmt.else_body.body[0].name == overload_name
):
return overload_name
if (
isinstance(stmt.else_body.body[0], IfStmt)
and self._check_ifstmt_for_overloads(stmt.else_body.body[0]) == overload_name
):
return overload_name

return None

def _get_executable_if_block_with_overloads(
self, stmt: IfStmt
) -> Tuple[Optional[Block], Optional[IfStmt]]:
"""Return block from IfStmt that will get executed.
Return
0 -> A block if sure that alternative blocks are unreachable.
1 -> An IfStmt if the reachability of it can't be inferred,
i.e. the truth value is unknown.
"""
infer_reachability_of_if_statement(stmt, self.options)
if (
stmt.else_body is None
and stmt.body[0].is_unreachable is True
):
# always False condition with no else
return None, None
if (
stmt.else_body is None
or stmt.body[0].is_unreachable is False
and stmt.else_body.is_unreachable is False
):
# The truth value is unknown, thus not conclusive
return None, stmt
if stmt.else_body.is_unreachable is True:
# else_body will be set unreachable if condition is always True
return stmt.body[0], None
if stmt.body[0].is_unreachable is True:
# body will be set unreachable if condition is always False
# else_body can contain an IfStmt itself (for elif) -> do a recursive check
if isinstance(stmt.else_body.body[0], IfStmt):
return self._get_executable_if_block_with_overloads(stmt.else_body.body[0])
return stmt.else_body, None
return None, stmt

def _strip_contents_from_if_stmt(self, stmt: IfStmt) -> None:
"""Remove contents from IfStmt.
Needed to still be able to check the conditions after the contents
have been merged with the surrounding function overloads.
"""
if len(stmt.body) == 1:
stmt.body[0].body = []
if stmt.else_body and len(stmt.else_body.body) == 1:
if isinstance(stmt.else_body.body[0], IfStmt):
self._strip_contents_from_if_stmt(stmt.else_body.body[0])
else:
stmt.else_body.body = []

def in_method_scope(self) -> bool:
return self.class_and_function_stack[-2:] == ['C', 'F']

Expand Down
Loading

0 comments on commit 0777c10

Please sign in to comment.