5
5
import contextlib
6
6
import types
7
7
import importlib
8
+ import inspect
8
9
import warnings
10
+ import itertools
9
11
10
12
from typing import Union , Optional , cast
11
13
from .abc import ResourceReader , Traversable
@@ -22,12 +24,9 @@ def package_to_anchor(func):
22
24
23
25
Other errors should fall through.
24
26
25
- >>> files()
26
- Traceback (most recent call last):
27
- TypeError: files() missing 1 required positional argument: 'anchor'
28
27
>>> files('a', 'b')
29
28
Traceback (most recent call last):
30
- TypeError: files() takes 1 positional argument but 2 were given
29
+ TypeError: files() takes from 0 to 1 positional arguments but 2 were given
31
30
"""
32
31
undefined = object ()
33
32
@@ -50,7 +49,7 @@ def wrapper(anchor=undefined, package=undefined):
50
49
51
50
52
51
@package_to_anchor
53
- def files (anchor : Anchor ) -> Traversable :
52
+ def files (anchor : Optional [ Anchor ] = None ) -> Traversable :
54
53
"""
55
54
Get a Traversable resource for an anchor.
56
55
"""
@@ -74,7 +73,7 @@ def get_resource_reader(package: types.ModuleType) -> Optional[ResourceReader]:
74
73
75
74
76
75
@functools .singledispatch
77
- def resolve (cand : Anchor ) -> types .ModuleType :
76
+ def resolve (cand : Optional [ Anchor ] ) -> types .ModuleType :
78
77
return cast (types .ModuleType , cand )
79
78
80
79
@@ -83,6 +82,28 @@ def _(cand: str) -> types.ModuleType:
83
82
return importlib .import_module (cand )
84
83
85
84
85
+ @resolve .register
86
+ def _ (cand : None ) -> types .ModuleType :
87
+ return resolve (_infer_caller ().f_globals ['__name__' ])
88
+
89
+
90
+ def _infer_caller ():
91
+ """
92
+ Walk the stack and find the frame of the first caller not in this module.
93
+ """
94
+
95
+ def is_this_file (frame_info ):
96
+ return frame_info .filename == __file__
97
+
98
+ def is_wrapper (frame_info ):
99
+ return frame_info .function == 'wrapper'
100
+
101
+ not_this_file = itertools .filterfalse (is_this_file , inspect .stack ())
102
+ # also exclude 'wrapper' due to singledispatch in the call stack
103
+ callers = itertools .filterfalse (is_wrapper , not_this_file )
104
+ return next (callers ).frame
105
+
106
+
86
107
def from_package (package : types .ModuleType ):
87
108
"""
88
109
Return a Traversable object for the given package.
0 commit comments