Skip to content

Commit a65c29a

Browse files
committed
Prioritize valid dists to invalid dists when retrieving by name.
Closes #489
1 parent 48f6b14 commit a65c29a

File tree

4 files changed

+111
-3
lines changed

4 files changed

+111
-3
lines changed

importlib_metadata/__init__.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
install,
2626
)
2727
from ._functools import method_cache, pass_none
28-
from ._itertools import always_iterable, unique_everseen
28+
from ._itertools import always_iterable, bucket, unique_everseen
2929
from ._meta import PackageMetadata, SimplePath
3030

3131
from contextlib import suppress
@@ -388,7 +388,7 @@ def from_name(cls, name: str) -> Distribution:
388388
if not name:
389389
raise ValueError("A distribution name is required.")
390390
try:
391-
return next(iter(cls.discover(name=name)))
391+
return next(iter(cls._prefer_valid(cls.discover(name=name))))
392392
except StopIteration:
393393
raise PackageNotFoundError(name)
394394

@@ -412,6 +412,16 @@ def discover(
412412
resolver(context) for resolver in cls._discover_resolvers()
413413
)
414414

415+
@staticmethod
416+
def _prefer_valid(dists: Iterable[Distribution]) -> Iterable[Distribution]:
417+
"""
418+
Prefer (move to the front) distributions that have metadata.
419+
420+
Ref python/importlib_resources#489.
421+
"""
422+
buckets = bucket(dists, lambda dist: bool(dist.metadata))
423+
return itertools.chain(buckets[True], buckets[False])
424+
415425
@staticmethod
416426
def at(path: str | os.PathLike[str]) -> Distribution:
417427
"""Return a Distribution for the indicated metadata path.

importlib_metadata/_itertools.py

+98
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections import defaultdict, deque
12
from itertools import filterfalse
23

34

@@ -71,3 +72,100 @@ def always_iterable(obj, base_type=(str, bytes)):
7172
return iter(obj)
7273
except TypeError:
7374
return iter((obj,))
75+
76+
77+
# Copied from more_itertools 10.3
78+
class bucket:
79+
"""Wrap *iterable* and return an object that buckets the iterable into
80+
child iterables based on a *key* function.
81+
82+
>>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3']
83+
>>> s = bucket(iterable, key=lambda x: x[0]) # Bucket by 1st character
84+
>>> sorted(list(s)) # Get the keys
85+
['a', 'b', 'c']
86+
>>> a_iterable = s['a']
87+
>>> next(a_iterable)
88+
'a1'
89+
>>> next(a_iterable)
90+
'a2'
91+
>>> list(s['b'])
92+
['b1', 'b2', 'b3']
93+
94+
The original iterable will be advanced and its items will be cached until
95+
they are used by the child iterables. This may require significant storage.
96+
97+
By default, attempting to select a bucket to which no items belong will
98+
exhaust the iterable and cache all values.
99+
If you specify a *validator* function, selected buckets will instead be
100+
checked against it.
101+
102+
>>> from itertools import count
103+
>>> it = count(1, 2) # Infinite sequence of odd numbers
104+
>>> key = lambda x: x % 10 # Bucket by last digit
105+
>>> validator = lambda x: x in {1, 3, 5, 7, 9} # Odd digits only
106+
>>> s = bucket(it, key=key, validator=validator)
107+
>>> 2 in s
108+
False
109+
>>> list(s[2])
110+
[]
111+
112+
"""
113+
114+
def __init__(self, iterable, key, validator=None):
115+
self._it = iter(iterable)
116+
self._key = key
117+
self._cache = defaultdict(deque)
118+
self._validator = validator or (lambda x: True)
119+
120+
def __contains__(self, value):
121+
if not self._validator(value):
122+
return False
123+
124+
try:
125+
item = next(self[value])
126+
except StopIteration:
127+
return False
128+
else:
129+
self._cache[value].appendleft(item)
130+
131+
return True
132+
133+
def _get_values(self, value):
134+
"""
135+
Helper to yield items from the parent iterator that match *value*.
136+
Items that don't match are stored in the local cache as they
137+
are encountered.
138+
"""
139+
while True:
140+
# If we've cached some items that match the target value, emit
141+
# the first one and evict it from the cache.
142+
if self._cache[value]:
143+
yield self._cache[value].popleft()
144+
# Otherwise we need to advance the parent iterator to search for
145+
# a matching item, caching the rest.
146+
else:
147+
while True:
148+
try:
149+
item = next(self._it)
150+
except StopIteration:
151+
return
152+
item_value = self._key(item)
153+
if item_value == value:
154+
yield item
155+
break
156+
elif self._validator(item_value):
157+
self._cache[item_value].append(item)
158+
159+
def __iter__(self):
160+
for item in self._it:
161+
item_value = self._key(item)
162+
if self._validator(item_value):
163+
self._cache[item_value].append(item)
164+
165+
yield from self._cache.keys()
166+
167+
def __getitem__(self, value):
168+
if not self._validator(value):
169+
return iter(())
170+
171+
return self._get_values(value)

newsfragments/489.feature.rst

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Prioritize valid dists to invalid dists when retrieving by name.

tests/test_main.py

-1
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,6 @@ def make_pkg(name, files=dict(METADATA="VERSION: 1.0")):
140140
f'{name}.dist-info': files,
141141
}
142142

143-
@__import__('pytest').mark.xfail(reason="#489")
144143
def test_valid_dists_preferred(self):
145144
"""
146145
Dists with metadata should be preferred when discovered by name.

0 commit comments

Comments
 (0)