Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Allow for ignoring some arguments when caching. #12189

Merged
merged 9 commits into from
Mar 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12189.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Support skipping some arguments when generating cache keys.
4 changes: 2 additions & 2 deletions synapse/storage/databases/main/events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1286,7 +1286,7 @@ async def have_seen_events(
)
return {eid for ((_rid, eid), have_event) in res.items() if have_event}

@cachedList("have_seen_event", "keys")
@cachedList(cached_method_name="have_seen_event", list_name="keys")
async def _have_seen_events_dict(
self, keys: Iterable[Tuple[str, str]]
) -> Dict[Tuple[str, str], bool]:
Expand Down Expand Up @@ -1954,7 +1954,7 @@ def get_event_id_for_timestamp_txn(txn: LoggingTransaction) -> Optional[str]:
get_event_id_for_timestamp_txn,
)

@cachedList("is_partial_state_event", list_name="event_ids")
@cachedList(cached_method_name="is_partial_state_event", list_name="event_ids")
async def get_partial_state_events(
self, event_ids: Collection[str]
) -> Dict[str, bool]:
Expand Down
74 changes: 58 additions & 16 deletions synapse/util/caches/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Any,
Awaitable,
Callable,
Collection,
Dict,
Generic,
Hashable,
Expand Down Expand Up @@ -69,13 +70,21 @@ def __init__(
self,
orig: Callable[..., Any],
num_args: Optional[int],
uncached_args: Optional[Collection[str]] = None,
cache_context: bool = False,
):
self.orig = orig

arg_spec = inspect.getfullargspec(orig)
all_args = arg_spec.args

# There's no reason that keyword-only arguments couldn't be supported,
# but right now they're buggy so do not allow them.
if arg_spec.kwonlyargs:
raise ValueError(
"_CacheDescriptorBase does not support keyword-only arguments."
)

if "cache_context" in all_args:
if not cache_context:
raise ValueError(
Expand All @@ -88,6 +97,9 @@ def __init__(
" named `cache_context`"
)

if num_args is not None and uncached_args is not None:
raise ValueError("Cannot provide both num_args and uncached_args")

if num_args is None:
num_args = len(all_args) - 1
if cache_context:
Expand All @@ -105,6 +117,12 @@ def __init__(
# list of the names of the args used as the cache key
self.arg_names = all_args[1 : num_args + 1]

# Build a mask over self.arg_names marking which args contribute to the cache key (args named in uncached_args are excluded).
if uncached_args is not None:
include_arg_in_cache_key = [n not in uncached_args for n in self.arg_names]
else:
include_arg_in_cache_key = [True] * len(self.arg_names)

# self.arg_defaults is a map of arg name to its default value for each
# argument that has a default value
if arg_spec.defaults:
Expand All @@ -119,8 +137,8 @@ def __init__(

self.add_cache_context = cache_context

self.cache_key_builder = get_cache_key_builder(
self.arg_names, self.arg_defaults
self.cache_key_builder = _get_cache_key_builder(
self.arg_names, include_arg_in_cache_key, self.arg_defaults
)


Expand All @@ -130,8 +148,7 @@ class _LruCachedFunction(Generic[F]):


def lru_cache(
max_entries: int = 1000,
cache_context: bool = False,
*, max_entries: int = 1000, cache_context: bool = False
) -> Callable[[F], _LruCachedFunction[F]]:
"""A method decorator that applies a memoizing cache around the function.

Expand Down Expand Up @@ -186,7 +203,9 @@ def __init__(
max_entries: int = 1000,
cache_context: bool = False,
):
super().__init__(orig, num_args=None, cache_context=cache_context)
super().__init__(
orig, num_args=None, uncached_args=None, cache_context=cache_context
)
self.max_entries = max_entries

def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
Expand Down Expand Up @@ -260,6 +279,9 @@ def foo(self, key, cache_context):
num_args: number of positional arguments (excluding ``self`` and
``cache_context``) to use as cache keys. Defaults to all named
args of the function.
uncached_args: a list of argument names to not use as the cache key.
(``self`` and ``cache_context`` are always ignored.) Cannot be used
with num_args.
tree:
cache_context:
iterable:
Expand All @@ -273,12 +295,18 @@ def __init__(
orig: Callable[..., Any],
max_entries: int = 1000,
num_args: Optional[int] = None,
uncached_args: Optional[Collection[str]] = None,
tree: bool = False,
cache_context: bool = False,
iterable: bool = False,
prune_unread_entries: bool = True,
):
super().__init__(orig, num_args=num_args, cache_context=cache_context)
super().__init__(
orig,
num_args=num_args,
uncached_args=uncached_args,
cache_context=cache_context,
)

if tree and self.num_args < 2:
raise RuntimeError(
Expand Down Expand Up @@ -369,7 +397,7 @@ def __init__(
but including list_name) to use as cache keys. Defaults to all
named args of the function.
"""
super().__init__(orig, num_args=num_args)
super().__init__(orig, num_args=num_args, uncached_args=None)

self.list_name = list_name

Expand Down Expand Up @@ -530,8 +558,10 @@ def get_instance(


def cached(
*,
max_entries: int = 1000,
num_args: Optional[int] = None,
uncached_args: Optional[Collection[str]] = None,
tree: bool = False,
cache_context: bool = False,
iterable: bool = False,
Expand All @@ -541,6 +571,7 @@ def cached(
orig,
max_entries=max_entries,
num_args=num_args,
uncached_args=uncached_args,
tree=tree,
cache_context=cache_context,
iterable=iterable,
Expand All @@ -551,7 +582,7 @@ def cached(


def cachedList(
cached_method_name: str, list_name: str, num_args: Optional[int] = None
*, cached_method_name: str, list_name: str, num_args: Optional[int] = None
) -> Callable[[F], _CachedFunction[F]]:
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`.

Expand Down Expand Up @@ -590,13 +621,16 @@ def batch_do_something(self, first_arg, second_args):
return cast(Callable[[F], _CachedFunction[F]], func)


def get_cache_key_builder(
param_names: Sequence[str], param_defaults: Mapping[str, Any]
def _get_cache_key_builder(
param_names: Sequence[str],
include_params: Sequence[bool],
param_defaults: Mapping[str, Any],
) -> Callable[[Sequence[Any], Mapping[str, Any]], CacheKey]:
"""Construct a function which will build cache keys suitable for a cached function

Args:
param_names: list of formal parameter names for the cached function
include_params: list of bools, one per entry in param_names, saying whether that parameter's value is included in the cache key
param_defaults: a mapping from parameter name to default value for that param

Returns:
Expand All @@ -608,6 +642,7 @@ def get_cache_key_builder(

if len(param_names) == 1:
nm = param_names[0]
assert include_params[0] is True

def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
if nm in kwargs:
Expand All @@ -620,13 +655,18 @@ def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
else:

def get_cache_key(args: Sequence[Any], kwargs: Mapping[str, Any]) -> CacheKey:
return tuple(_get_cache_key_gen(param_names, param_defaults, args, kwargs))
return tuple(
_get_cache_key_gen(
param_names, include_params, param_defaults, args, kwargs
)
)

return get_cache_key


def _get_cache_key_gen(
param_names: Iterable[str],
include_params: Iterable[bool],
param_defaults: Mapping[str, Any],
args: Sequence[Any],
kwargs: Mapping[str, Any],
Expand All @@ -637,16 +677,18 @@ def _get_cache_key_gen(
This is essentially the same operation as `inspect.getcallargs`, but optimised so
that we don't need to inspect the target function for each call.
"""

# We loop through each arg name, looking up if it's in the `kwargs`,
# otherwise using the next argument in `args`. If there are no more
# args then we try looking the arg name up in the defaults.
pos = 0
for nm in param_names:
for nm, inc in zip(param_names, include_params):
if nm in kwargs:
yield kwargs[nm]
if inc:
yield kwargs[nm]
elif pos < len(args):
yield args[pos]
if inc:
yield args[pos]
pos += 1
else:
yield param_defaults[nm]
if inc:
yield param_defaults[nm]
84 changes: 81 additions & 3 deletions tests/util/caches/test_descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,84 @@ def fn(self, arg1, arg2):
self.assertEqual(r, "chips")
obj.mock.assert_not_called()

@defer.inlineCallbacks
def test_cache_uncached_args(self):
"""
Only the arguments not named in uncached_args should matter to the cache

Note that this is identical to test_cache_num_args, but provides the
arguments differently.
"""

class Cls:
# Note that it is important that this is not the last argument to
# test behaviour of skipping arguments properly.
@descriptors.cached(uncached_args=("arg2",))
def fn(self, arg1, arg2, arg3):
return self.mock(arg1, arg2, arg3)

def __init__(self):
self.mock = mock.Mock()

obj = Cls()
obj.mock.return_value = "fish"
r = yield obj.fn(1, 2, 3)
self.assertEqual(r, "fish")
obj.mock.assert_called_once_with(1, 2, 3)
obj.mock.reset_mock()

# a call with different params should call the mock again
obj.mock.return_value = "chips"
r = yield obj.fn(2, 3, 4)
self.assertEqual(r, "chips")
obj.mock.assert_called_once_with(2, 3, 4)
obj.mock.reset_mock()

# the two values should now be cached; we should be able to vary
# the second argument and still get the cached result.
r = yield obj.fn(1, 4, 3)
self.assertEqual(r, "fish")
r = yield obj.fn(2, 5, 4)
self.assertEqual(r, "chips")
obj.mock.assert_not_called()

@defer.inlineCallbacks
def test_cache_kwargs(self):
"""Test that keyword arguments are treated properly"""

class Cls:
def __init__(self):
self.mock = mock.Mock()

@descriptors.cached()
def fn(self, arg1, kwarg1=2):
return self.mock(arg1, kwarg1=kwarg1)

obj = Cls()
obj.mock.return_value = "fish"
r = yield obj.fn(1, kwarg1=2)
self.assertEqual(r, "fish")
obj.mock.assert_called_once_with(1, kwarg1=2)
obj.mock.reset_mock()

# a call with different params should call the mock again
obj.mock.return_value = "chips"
r = yield obj.fn(1, kwarg1=3)
self.assertEqual(r, "chips")
obj.mock.assert_called_once_with(1, kwarg1=3)
obj.mock.reset_mock()

# the values should now be cached.
r = yield obj.fn(1, kwarg1=2)
self.assertEqual(r, "fish")
# We should be able to not provide kwarg1 and get the cached value back.
r = yield obj.fn(1)
self.assertEqual(r, "fish")
# Keyword arguments can be in any order.
r = yield obj.fn(kwarg1=2, arg1=1)
self.assertEqual(r, "fish")
obj.mock.assert_not_called()

def test_cache_with_sync_exception(self):
"""If the wrapped function throws synchronously, things should continue to work"""

Expand Down Expand Up @@ -656,7 +734,7 @@ def __init__(self):
def fn(self, arg1, arg2):
pass

@descriptors.cachedList("fn", "args1")
@descriptors.cachedList(cached_method_name="fn", list_name="args1")
async def list_fn(self, args1, arg2):
assert current_context().name == "c1"
# we want this to behave like an asynchronous function
Expand Down Expand Up @@ -715,7 +793,7 @@ def __init__(self):
def fn(self, arg1):
pass

@descriptors.cachedList("fn", "args1")
@descriptors.cachedList(cached_method_name="fn", list_name="args1")
def list_fn(self, args1) -> "Deferred[dict]":
return self.mock(args1)

Expand Down Expand Up @@ -758,7 +836,7 @@ def __init__(self):
def fn(self, arg1, arg2):
pass

@descriptors.cachedList("fn", "args1")
@descriptors.cachedList(cached_method_name="fn", list_name="args1")
async def list_fn(self, args1, arg2):
# we want this to behave like an asynchronous function
await run_on_reactor()
Expand Down