Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Allow for ignoring some arguments when caching. #12189

Merged
merged 9 commits into from
Mar 9, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 24 additions & 9 deletions synapse/util/caches/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,9 @@ def __init__(

# If there are args to not cache on, filter them out (and fix the size of num_args).
if uncached_args is not None:
self.num_args -= len(uncached_args)
self.arg_names = [n for n in self.arg_names if n not in uncached_args]
include_arg_in_cache_key = [n not in uncached_args for n in self.arg_names]
else:
include_arg_in_cache_key = [True] * len(self.arg_names)

# self.arg_defaults is a map of arg name to its default value for each
# argument that has a default value
Expand All @@ -137,7 +138,7 @@ def __init__(
self.add_cache_context = cache_context

self.cache_key_builder = _get_cache_key_builder(
self.arg_names, self.arg_defaults
self.arg_names, include_arg_in_cache_key, self.arg_defaults
)


Expand Down Expand Up @@ -621,12 +622,15 @@ def batch_do_something(self, first_arg, second_args):


def _get_cache_key_builder(
param_names: Sequence[str], param_defaults: Mapping[str, Any]
param_names: Sequence[str],
include_params: Sequence[bool],
param_defaults: Mapping[str, Any],
) -> Callable[[Sequence[Any], Mapping[str, Any]], CacheKey]:
"""Construct a function which will build cache keys suitable for a cached function

Args:
param_names: list of formal parameter names for the cached function
include_params: list of bools of whether to include the parameter name in the cache key
param_defaults: a mapping from parameter name to default value for that param

Returns:
Expand All @@ -638,6 +642,7 @@ def _get_cache_key_builder(

if len(param_names) == 1:
nm = param_names[0]
assert include_params[0] is True

def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
if nm in kwargs:
Expand All @@ -650,13 +655,18 @@ def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
else:

def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
return tuple(_get_cache_key_gen(param_names, param_defaults, args, kwargs))
return tuple(
_get_cache_key_gen(
param_names, include_params, param_defaults, args, kwargs
)
)

return get_cache_key


def _get_cache_key_gen(
param_names: Iterable[str],
include_params: Iterable[bool],
param_defaults: Mapping[str, Any],
args: Sequence[Any],
kwargs: Mapping[str, Any],
Expand All @@ -667,16 +677,21 @@ def _get_cache_key_gen(
This is essentially the same operation as `inspect.getcallargs`, but optimised so
that we don't need to inspect the target function for each call.
"""
if param_names == ():
pass

# We loop through each arg name, looking up if its in the `kwargs`,
# otherwise using the next argument in `args`. If there are no more
# args then we try looking the arg name up in the defaults.
pos = 0
for nm in param_names:
for nm, inc in zip(param_names, include_params):
if nm in kwargs:
yield kwargs[nm]
if inc:
yield kwargs[nm]
elif pos < len(args):
yield args[pos]
if inc:
yield args[pos]
pos += 1
else:
yield param_defaults[nm]
if inc:
yield param_defaults[nm]
18 changes: 10 additions & 8 deletions tests/util/caches/test_descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,32 +151,34 @@ def test_cache_uncached_args(self):
"""

class Cls:
# Note that it is important that this is not the last argument to
# test behaviour of skipping arguments properly.
@descriptors.cached(uncached_args=("arg2",))
def fn(self, arg1, arg2):
return self.mock(arg1, arg2)
def fn(self, arg1, arg2, arg3):
return self.mock(arg1, arg2, arg3)

def __init__(self):
self.mock = mock.Mock()

obj = Cls()
obj.mock.return_value = "fish"
r = yield obj.fn(1, 2)
r = yield obj.fn(1, 2, 3)
self.assertEqual(r, "fish")
obj.mock.assert_called_once_with(1, 2)
obj.mock.assert_called_once_with(1, 2, 3)
obj.mock.reset_mock()

# a call with different params should call the mock again
obj.mock.return_value = "chips"
r = yield obj.fn(2, 3)
r = yield obj.fn(2, 3, 4)
self.assertEqual(r, "chips")
obj.mock.assert_called_once_with(2, 3)
obj.mock.assert_called_once_with(2, 3, 4)
obj.mock.reset_mock()

# the two values should now be cached; we should be able to vary
# the second argument and still get the cached result.
r = yield obj.fn(1, 4)
r = yield obj.fn(1, 4, 3)
self.assertEqual(r, "fish")
r = yield obj.fn(2, 5)
r = yield obj.fn(2, 5, 4)
self.assertEqual(r, "chips")
obj.mock.assert_not_called()

Expand Down