Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Allow for ignoring some arguments when caching. (#12189)
Browse files Browse the repository at this point in the history
* `@cached` can now take an `uncached_args` which is an iterable of names to not use in the cache key.
* Requires `@cached`, `@cachedList` and `@lru_cache` to use keyword arguments for clarity.
* Asserts that keyword-only arguments in cached functions are not accepted. (I tested this briefly and I don't believe this works properly.)
  • Loading branch information
clokep authored Mar 9, 2022
1 parent 0326888 commit 690cb4f
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 21 deletions.
1 change: 1 addition & 0 deletions changelog.d/12189.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Support skipping some arguments when generating cache keys.
4 changes: 2 additions & 2 deletions synapse/storage/databases/main/events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1286,7 +1286,7 @@ async def have_seen_events(
)
return {eid for ((_rid, eid), have_event) in res.items() if have_event}

@cachedList("have_seen_event", "keys")
@cachedList(cached_method_name="have_seen_event", list_name="keys")
async def _have_seen_events_dict(
self, keys: Iterable[Tuple[str, str]]
) -> Dict[Tuple[str, str], bool]:
Expand Down Expand Up @@ -1954,7 +1954,7 @@ def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]:
get_event_id_for_timestamp_txn,
)

@cachedList("is_partial_state_event", list_name="event_ids")
@cachedList(cached_method_name="is_partial_state_event", list_name="event_ids")
async def get_partial_state_events(
self, event_ids: Collection[str]
) -> Dict[str, bool]:
Expand Down
74 changes: 58 additions & 16 deletions synapse/util/caches/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Any,
Awaitable,
Callable,
Collection,
Dict,
Generic,
Hashable,
Expand Down Expand Up @@ -69,13 +70,21 @@ def __init__(
self,
orig: Callable[..., Any],
num_args: Optional[int],
uncached_args: Optional[Collection[str]] = None,
cache_context: bool = False,
):
self.orig = orig

arg_spec = inspect.getfullargspec(orig)
all_args = arg_spec.args

# There's no reason that keyword-only arguments couldn't be supported,
# but right now they're buggy so do not allow them.
if arg_spec.kwonlyargs:
raise ValueError(
"_CacheDescriptorBase does not support keyword-only arguments."
)

if "cache_context" in all_args:
if not cache_context:
raise ValueError(
Expand All @@ -88,6 +97,9 @@ def __init__(
" named `cache_context`"
)

if num_args is not None and uncached_args is not None:
raise ValueError("Cannot provide both num_args and uncached_args")

if num_args is None:
num_args = len(all_args) - 1
if cache_context:
Expand All @@ -105,6 +117,12 @@ def __init__(
# list of the names of the args used as the cache key
self.arg_names = all_args[1 : num_args + 1]

# If there are args to not cache on, filter them out (and fix the size of num_args).
if uncached_args is not None:
include_arg_in_cache_key = [n not in uncached_args for n in self.arg_names]
else:
include_arg_in_cache_key = [True] * len(self.arg_names)

# self.arg_defaults is a map of arg name to its default value for each
# argument that has a default value
if arg_spec.defaults:
Expand All @@ -119,8 +137,8 @@ def __init__(

self.add_cache_context = cache_context

self.cache_key_builder = get_cache_key_builder(
self.arg_names, self.arg_defaults
self.cache_key_builder = _get_cache_key_builder(
self.arg_names, include_arg_in_cache_key, self.arg_defaults
)


Expand All @@ -130,8 +148,7 @@ class _LruCachedFunction(Generic[F]):


def lru_cache(
max_entries: int = 1000,
cache_context: bool = False,
*, max_entries: int = 1000, cache_context: bool = False
) -> Callable[[F], _LruCachedFunction[F]]:
"""A method decorator that applies a memoizing cache around the function.
Expand Down Expand Up @@ -186,7 +203,9 @@ def __init__(
max_entries: int = 1000,
cache_context: bool = False,
):
super().__init__(orig, num_args=None, cache_context=cache_context)
super().__init__(
orig, num_args=None, uncached_args=None, cache_context=cache_context
)
self.max_entries = max_entries

def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
Expand Down Expand Up @@ -260,6 +279,9 @@ def foo(self, key, cache_context):
num_args: number of positional arguments (excluding ``self`` and
``cache_context``) to use as cache keys. Defaults to all named
args of the function.
uncached_args: a list of argument names to not use as the cache key.
(``self`` and ``cache_context`` are always ignored.) Cannot be used
with num_args.
tree:
cache_context:
iterable:
Expand All @@ -273,12 +295,18 @@ def __init__(
orig: Callable[..., Any],
max_entries: int = 1000,
num_args: Optional[int] = None,
uncached_args: Optional[Collection[str]] = None,
tree: bool = False,
cache_context: bool = False,
iterable: bool = False,
prune_unread_entries: bool = True,
):
super().__init__(orig, num_args=num_args, cache_context=cache_context)
super().__init__(
orig,
num_args=num_args,
uncached_args=uncached_args,
cache_context=cache_context,
)

if tree and self.num_args < 2:
raise RuntimeError(
Expand Down Expand Up @@ -369,7 +397,7 @@ def __init__(
but including list_name) to use as cache keys. Defaults to all
named args of the function.
"""
super().__init__(orig, num_args=num_args)
super().__init__(orig, num_args=num_args, uncached_args=None)

self.list_name = list_name

Expand Down Expand Up @@ -530,8 +558,10 @@ def get_instance(


def cached(
*,
max_entries: int = 1000,
num_args: Optional[int] = None,
uncached_args: Optional[Collection[str]] = None,
tree: bool = False,
cache_context: bool = False,
iterable: bool = False,
Expand All @@ -541,6 +571,7 @@ def cached(
orig,
max_entries=max_entries,
num_args=num_args,
uncached_args=uncached_args,
tree=tree,
cache_context=cache_context,
iterable=iterable,
Expand All @@ -551,7 +582,7 @@ def cached(


def cachedList(
cached_method_name: str, list_name: str, num_args: Optional[int] = None
*, cached_method_name: str, list_name: str, num_args: Optional[int] = None
) -> Callable[[F], _CachedFunction[F]]:
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`.
Expand Down Expand Up @@ -590,13 +621,16 @@ def batch_do_something(self, first_arg, second_args):
return cast(Callable[[F], _CachedFunction[F]], func)


def get_cache_key_builder(
param_names: Sequence[str], param_defaults: Mapping[str, Any]
def _get_cache_key_builder(
param_names: Sequence[str],
include_params: Sequence[bool],
param_defaults: Mapping[str, Any],
) -> Callable[[Sequence[Any], Mapping[str, Any]], CacheKey]:
"""Construct a function which will build cache keys suitable for a cached function
Args:
param_names: list of formal parameter names for the cached function
include_params: list of bools of whether to include the parameter name in the cache key
param_defaults: a mapping from parameter name to default value for that param
Returns:
Expand All @@ -608,6 +642,7 @@ def get_cache_key_builder(

if len(param_names) == 1:
nm = param_names[0]
assert include_params[0] is True

def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
if nm in kwargs:
Expand All @@ -620,13 +655,18 @@ def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
else:

def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
return tuple(_get_cache_key_gen(param_names, param_defaults, args, kwargs))
return tuple(
_get_cache_key_gen(
param_names, include_params, param_defaults, args, kwargs
)
)

return get_cache_key


def _get_cache_key_gen(
param_names: Iterable[str],
include_params: Iterable[bool],
param_defaults: Mapping[str, Any],
args: Sequence[Any],
kwargs: Mapping[str, Any],
Expand All @@ -637,16 +677,18 @@ def _get_cache_key_gen(
This is essentially the same operation as `inspect.getcallargs`, but optimised so
that we don't need to inspect the target function for each call.
"""

# We loop through each arg name, looking up if its in the `kwargs`,
# otherwise using the next argument in `args`. If there are no more
# args then we try looking the arg name up in the defaults.
pos = 0
for nm in param_names:
for nm, inc in zip(param_names, include_params):
if nm in kwargs:
yield kwargs[nm]
if inc:
yield kwargs[nm]
elif pos < len(args):
yield args[pos]
if inc:
yield args[pos]
pos += 1
else:
yield param_defaults[nm]
if inc:
yield param_defaults[nm]
84 changes: 81 additions & 3 deletions tests/util/caches/test_descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,84 @@ def fn(self, arg1, arg2):
self.assertEqual(r, "chips")
obj.mock.assert_not_called()

@defer.inlineCallbacks
def test_cache_uncached_args(self):
    """
    Arguments named in uncached_args should not influence the cache key.

    This mirrors test_cache_num_args, but identifies the ignored
    argument by name instead of by position.
    """

    class Cls:
        # Deliberately skip a *middle* argument so that the
        # key-filtering logic is exercised on a non-trailing parameter.
        @descriptors.cached(uncached_args=("arg2",))
        def fn(self, arg1, arg2, arg3):
            return self.mock(arg1, arg2, arg3)

        def __init__(self):
            self.mock = mock.Mock()

    instance = Cls()
    instance.mock.return_value = "fish"
    result = yield instance.fn(1, 2, 3)
    self.assertEqual(result, "fish")
    instance.mock.assert_called_once_with(1, 2, 3)
    instance.mock.reset_mock()

    # Changing a cached-on argument misses the cache, so the mock fires again.
    instance.mock.return_value = "chips"
    result = yield instance.fn(2, 3, 4)
    self.assertEqual(result, "chips")
    instance.mock.assert_called_once_with(2, 3, 4)
    instance.mock.reset_mock()

    # Both results are now cached: varying only arg2 (the uncached
    # argument) must still return the cached values without invoking
    # the mock again.
    result = yield instance.fn(1, 4, 3)
    self.assertEqual(result, "fish")
    result = yield instance.fn(2, 5, 4)
    self.assertEqual(result, "chips")
    instance.mock.assert_not_called()

@defer.inlineCallbacks
def test_cache_kwargs(self):
    """Keyword arguments must be folded into the cache key correctly."""

    class Cls:
        def __init__(self):
            self.mock = mock.Mock()

        @descriptors.cached()
        def fn(self, arg1, kwarg1=2):
            return self.mock(arg1, kwarg1=kwarg1)

    instance = Cls()
    instance.mock.return_value = "fish"
    result = yield instance.fn(1, kwarg1=2)
    self.assertEqual(result, "fish")
    instance.mock.assert_called_once_with(1, kwarg1=2)
    instance.mock.reset_mock()

    # A different kwarg value is a different key, so the mock fires again.
    instance.mock.return_value = "chips"
    result = yield instance.fn(1, kwarg1=3)
    self.assertEqual(result, "chips")
    instance.mock.assert_called_once_with(1, kwarg1=3)
    instance.mock.reset_mock()

    # All of the following should hit the cache: the explicit kwarg,
    # the default value left implicit, and keyword arguments supplied
    # out of order.
    result = yield instance.fn(1, kwarg1=2)
    self.assertEqual(result, "fish")
    result = yield instance.fn(1)
    self.assertEqual(result, "fish")
    result = yield instance.fn(kwarg1=2, arg1=1)
    self.assertEqual(result, "fish")
    instance.mock.assert_not_called()

def test_cache_with_sync_exception(self):
"""If the wrapped function throws synchronously, things should continue to work"""

Expand Down Expand Up @@ -656,7 +734,7 @@ def __init__(self):
def fn(self, arg1, arg2):
pass

@descriptors.cachedList("fn", "args1")
@descriptors.cachedList(cached_method_name="fn", list_name="args1")
async def list_fn(self, args1, arg2):
assert current_context().name == "c1"
# we want this to behave like an asynchronous function
Expand Down Expand Up @@ -715,7 +793,7 @@ def __init__(self):
def fn(self, arg1):
pass

@descriptors.cachedList("fn", "args1")
@descriptors.cachedList(cached_method_name="fn", list_name="args1")
def list_fn(self, args1) -> "Deferred[dict]":
return self.mock(args1)

Expand Down Expand Up @@ -758,7 +836,7 @@ def __init__(self):
def fn(self, arg1, arg2):
pass

@descriptors.cachedList("fn", "args1")
@descriptors.cachedList(cached_method_name="fn", list_name="args1")
async def list_fn(self, args1, arg2):
# we want this to behave like an asynchronous function
await run_on_reactor()
Expand Down

0 comments on commit 690cb4f

Please sign in to comment.