|
| 1 | +from collections import defaultdict, deque |
1 | 2 | from itertools import filterfalse
|
2 | 3 |
|
3 | 4 |
|
@@ -71,3 +72,100 @@ def always_iterable(obj, base_type=(str, bytes)):
|
71 | 72 | return iter(obj)
|
72 | 73 | except TypeError:
|
73 | 74 | return iter((obj,))
|
| 75 | + |
| 76 | + |
# Adapted from more_itertools 10.3 (key listing only reports non-empty buckets)
class bucket:
    """Wrap *iterable* and return an object that buckets the iterable into
    child iterables based on a *key* function.

        >>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3']
        >>> s = bucket(iterable, key=lambda x: x[0])  # Bucket by 1st character
        >>> sorted(list(s))  # Get the keys
        ['a', 'b', 'c']
        >>> a_iterable = s['a']
        >>> next(a_iterable)
        'a1'
        >>> next(a_iterable)
        'a2'
        >>> list(s['b'])
        ['b1', 'b2', 'b3']

    The original iterable will be advanced and its items will be cached until
    they are used by the child iterables. This may require significant storage.

    By default, attempting to select a bucket to which no items belong will
    exhaust the iterable and cache all values.
    If you specify a *validator* function, selected buckets will instead be
    checked against it.

        >>> from itertools import count
        >>> it = count(1, 2)  # Infinite sequence of odd numbers
        >>> key = lambda x: x % 10  # Bucket by last digit
        >>> validator = lambda x: x in {1, 3, 5, 7, 9}  # Odd digits only
        >>> s = bucket(it, key=key, validator=validator)
        >>> 2 in s
        False
        >>> list(s[2])
        []

    Iterating over the bucket object exhausts the wrapped iterable and yields
    the keys of the buckets that still hold unconsumed items, so the result
    agrees with ``key in bucket``.
    """

    def __init__(self, iterable, key, validator=None):
        """Store the source iterator, the *key* function and the *validator*.

        When *validator* is omitted (or falsy) every key is considered valid.
        """
        self._it = iter(iterable)
        self._key = key
        # Maps a key to the items already pulled from ``_it`` for that key
        # that no child iterable has consumed yet.
        self._cache = defaultdict(deque)
        self._validator = validator or (lambda x: True)

    def __contains__(self, value):
        """Return True if at least one unconsumed item belongs to *value*.

        This may advance the wrapped iterable; the item found while probing
        is pushed back so that child iterables still see it first.
        """
        if not self._validator(value):
            return False

        try:
            item = next(self[value])
        except StopIteration:
            return False
        else:
            # Put the probed item back at the front to preserve ordering.
            self._cache[value].appendleft(item)

        return True

    def _get_values(self, value):
        """
        Helper to yield items from the parent iterator that match *value*.
        Items that don't match are stored in the local cache as they
        are encountered.
        """
        while True:
            # Use ``.get`` rather than indexing: indexing the defaultdict
            # would insert an empty deque for every key that is merely
            # looked up, which would then be reported as a key by __iter__.
            cached = self._cache.get(value)
            if cached:
                # Emit the oldest cached match and evict it from the cache.
                yield cached.popleft()
                continue

            # Otherwise advance the parent iterator to search for a matching
            # item, caching the (valid) rest.
            for item in self._it:
                item_value = self._key(item)
                if item_value == value:
                    yield item
                    break
                if self._validator(item_value):
                    self._cache[item_value].append(item)
            else:
                # Parent iterator exhausted without another match.
                return

    def __iter__(self):
        """Exhaust the wrapped iterable and yield each non-empty bucket's key."""
        for item in self._it:
            item_value = self._key(item)
            if self._validator(item_value):
                self._cache[item_value].append(item)

        # Yield from a snapshot so that bucket lookups performed while the
        # caller is iterating cannot invalidate a live dict iterator, and
        # skip buckets whose items have all been consumed.
        yield from [key for key, items in self._cache.items() if items]

    def __getitem__(self, value):
        """Return an iterator over the items whose key equals *value*.

        Keys rejected by the validator yield an empty iterator without
        touching the wrapped iterable.
        """
        if not self._validator(value):
            return iter(())

        return self._get_values(value)
0 commit comments