Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support generics in rtd_github_links #160

Merged
merged 1 commit into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/scanpydoc/_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from __future__ import annotations

from typing import Generic, TypeVar


_GenericAlias: type = type(Generic[TypeVar("_")])
6 changes: 2 additions & 4 deletions src/scanpydoc/elegant_typehints/_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
import sys
import inspect
from types import GenericAlias
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast, get_args, get_origin
from typing import TYPE_CHECKING, Any, cast, get_args, get_origin

from sphinx_autodoc_typehints import format_annotation

from scanpydoc import elegant_typehints
from scanpydoc._types import _GenericAlias


if TYPE_CHECKING:
Expand All @@ -20,9 +21,6 @@
UnionType = None


_GenericAlias: type = type(Generic[TypeVar("_")])


def typehints_formatter(annotation: type[Any], config: Config) -> str | None:
"""Generate reStructuredText containing links to the types.

Expand Down
8 changes: 5 additions & 3 deletions src/scanpydoc/rtd_github_links/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@

import sys
import inspect
from types import ModuleType
from types import ModuleType, GenericAlias
from typing import TYPE_CHECKING
from pathlib import Path, PurePosixPath
from importlib import import_module
Expand All @@ -70,6 +70,7 @@
from jinja2.defaults import DEFAULT_FILTERS # type: ignore[attr-defined]

from scanpydoc import metadata, _setup_sig
from scanpydoc._types import _GenericAlias


if TYPE_CHECKING:
Expand Down Expand Up @@ -158,7 +159,7 @@ def _get_obj_module(qualname: str) -> tuple[Any, ModuleType]:
raise e from None
if isinstance(thing, ModuleType): # pragma: no cover
mod = thing
elif is_dataclass(obj):
elif is_dataclass(obj) or isinstance(thing, (GenericAlias, _GenericAlias)):
obj = thing
else:
obj = thing
Expand Down Expand Up @@ -186,7 +187,8 @@ def _module_path(obj: _SourceObjectType, module: ModuleType) -> PurePosixPath:
try:
file = Path(inspect.getabsfile(obj))
except TypeError:
file = Path(module.__file__ or "")
# Some don’t have the attribute, some have it set to None
file = Path(getattr(module, "__file__", None) or "")
offset = -1 if file.name == "__init__.py" else 0
parts = module.__name__.split(".")
return PurePosixPath(*file.parts[offset - len(parts) :])
Expand Down
20 changes: 20 additions & 0 deletions src/scanpydoc/rtd_github_links/_testdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,31 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Generic, TypeVar
from dataclasses import field, dataclass

from legacy_api_wrap import legacy_api


if TYPE_CHECKING:
from typing import TypeAlias


_T = TypeVar("_T")


class _G(Generic[_T]):
pass


# make sure that TestGenericClass keeps its __module__
_G.__module__ = "somewhere_else"


TestGenericBuiltin: TypeAlias = list[str]
TestGenericClass: TypeAlias = _G[int]


@dataclass
class TestDataCls:
test_attr: dict[str, str] = field(default_factory=dict)
Expand Down
14 changes: 14 additions & 0 deletions tests/test_rtd_github_links.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,20 @@ def test_get_github_url_error() -> None:
"scanpydoc/rtd_github_links/_testdata.py",
id="anno",
),
pytest.param(
"scanpydoc.rtd_github_links._testdata.TestGenericBuiltin",
_testdata.TestGenericBuiltin,
_testdata,
"scanpydoc/rtd_github_links/_testdata.py",
id="generic_builtin",
),
pytest.param(
"scanpydoc.rtd_github_links._testdata.TestGenericClass",
_testdata.TestGenericClass,
_testdata,
"scanpydoc/rtd_github_links/_testdata.py",
id="generic_class",
),
],
)
def test_get_obj_module_path(
Expand Down
Loading