diff --git a/changelog.d/13242.misc b/changelog.d/13242.misc new file mode 100644 index 000000000000..7f8ec0815f64 --- /dev/null +++ b/changelog.d/13242.misc @@ -0,0 +1 @@ +Use an asynchronous cache wrapper for the get event cache. Contributed by Nick @ Beeper (@fizzadar). diff --git a/synapse/storage/database.py b/synapse/storage/database.py index e21ab0851530..6a6d0dcd7391 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -57,7 +57,7 @@ from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.types import Connection, Cursor -from synapse.util.async_helpers import delay_cancellation +from synapse.util.async_helpers import delay_cancellation, maybe_awaitable from synapse.util.iterutils import batch_iter if TYPE_CHECKING: @@ -818,12 +818,14 @@ async def _runInteraction() -> R: ) for after_callback, after_args, after_kwargs in after_callbacks: - after_callback(*after_args, **after_kwargs) + await maybe_awaitable(after_callback(*after_args, **after_kwargs)) return cast(R, result) except Exception: - for after_callback, after_args, after_kwargs in exception_callbacks: - after_callback(*after_args, **after_kwargs) + for exception_callback, after_args, after_kwargs in exception_callbacks: + await maybe_awaitable( + exception_callback(*after_args, **after_kwargs) + ) raise # To handle cancellation, we ensure that `after_callback`s and diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 1653a6a9b694..2367ddeea3fd 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -193,7 +193,10 @@ def _invalidate_caches_for_event( relates_to: Optional[str], backfilled: bool, ) -> None: - self._invalidate_get_event_cache(event_id) + # This invalidates any local in-memory cached event objects, the original + # process triggering the invalidation is responsible for clearing any external + # cached objects. + self._invalidate_local_get_event_cache(event_id) self.have_seen_event.invalidate((room_id, event_id)) self.get_latest_event_ids_in_room.invalidate((room_id,)) @@ -208,7 +211,7 @@ def _invalidate_caches_for_event( self._events_stream_cache.entity_has_changed(room_id, stream_ordering) if redacts: - self._invalidate_get_event_cache(redacts) + self._invalidate_local_get_event_cache(redacts) # Caches which might leak edits must be invalidated for the event being # redacted. self.get_relations_for_event.invalidate((redacts,)) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 2ff3d21305cf..d41697a85069 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1669,9 +1669,9 @@ def _add_to_cache( if not row["rejects"] and not row["redacts"]: to_prefill.append(EventCacheEntry(event=event, redacted_event=None)) - def prefill() -> None: + async def prefill() -> None: for cache_entry in to_prefill: - self.store._get_event_cache.set( + await self.store._get_event_cache.set( (cache_entry.event.event_id,), cache_entry ) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index b99b10778490..dd3b28bcc8eb 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -79,7 +79,7 @@ from synapse.util import unwrapFirstError from synapse.util.async_helpers import ObservableDeferred, delay_cancellation from synapse.util.caches.descriptors import cached, cachedList -from synapse.util.caches.lrucache import LruCache +from synapse.util.caches.lrucache import AsyncLruCache from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -238,7 +238,9 @@ def __init__( 5 * 60 * 1000, ) - self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache( + self._get_event_cache: AsyncLruCache[ + Tuple[str], EventCacheEntry + ] = AsyncLruCache( cache_name="*getEvent*", max_size=hs.config.caches.event_cache_size, ) @@ -617,7 +619,7 @@ async def _get_events_from_cache_or_db( Returns: map from event id to result """ - event_entry_map = self._get_events_from_cache( + event_entry_map = await self._get_events_from_cache( event_ids, ) @@ -729,12 +731,22 @@ async def get_missing_events_from_db() -> Dict[str, EventCacheEntry]: return event_entry_map - def _invalidate_get_event_cache(self, event_id: str) -> None: - self._get_event_cache.invalidate((event_id,)) + async def _invalidate_get_event_cache(self, event_id: str) -> None: + # First we invalidate the asynchronous cache instance. This may include + # out-of-process caches such as Redis/memcache. Once complete we can + # invalidate any in memory cache. The ordering is important here to + # ensure we don't pull in any remote invalid value after we invalidate + # the in-memory cache. + await self._get_event_cache.invalidate((event_id,)) self._event_ref.pop(event_id, None) self._current_event_fetches.pop(event_id, None) - def _get_events_from_cache( + def _invalidate_local_get_event_cache(self, event_id: str) -> None: + self._get_event_cache.invalidate_local((event_id,)) + self._event_ref.pop(event_id, None) + self._current_event_fetches.pop(event_id, None) + + async def _get_events_from_cache( self, events: Iterable[str], update_metrics: bool = True ) -> Dict[str, EventCacheEntry]: """Fetch events from the caches. @@ -749,7 +761,7 @@ def _get_events_from_cache( for event_id in events: # First check if it's in the event cache - ret = self._get_event_cache.get( + ret = await self._get_event_cache.get( (event_id,), None, update_metrics=update_metrics ) if ret: @@ -771,7 +783,7 @@ def _get_events_from_cache( # We add the entry back into the cache as we want to keep # recently queried events in the cache. - self._get_event_cache.set((event_id,), cache_entry) + await self._get_event_cache.set((event_id,), cache_entry) return event_map @@ -1148,7 +1160,7 @@ async def _get_events_from_db( event=original_ev, redacted_event=redacted_event ) - self._get_event_cache.set((event_id,), cache_entry) + await self._get_event_cache.set((event_id,), cache_entry) result_map[event_id] = cache_entry if not redacted_event: @@ -1382,7 +1394,9 @@ async def _have_seen_events_dict( # if the event cache contains the event, obviously we've seen it. cache_results = { - (rid, eid) for (rid, eid) in keys if self._get_event_cache.contains((eid,)) + (rid, eid) + for (rid, eid) in keys + if await self._get_event_cache.contains((eid,)) } results = dict.fromkeys(cache_results, True) remaining = [k for k in keys if k not in cache_results] diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index 87b0d090391a..549ce07c1635 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -302,7 +302,7 @@ def _purge_history_txn( self._invalidate_cache_and_stream( txn, self.have_seen_event, (room_id, event_id) ) - self._invalidate_get_event_cache(event_id) + txn.call_after(self._invalidate_get_event_cache, event_id) logger.info("[purge] done") diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 31bc8c56011a..9c7f8ce88311 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -777,7 +777,9 @@ async def _get_joined_users_from_context( # We don't update the event cache hit ratio as it completely throws off # the hit ratio counts. After all, we don't populate the cache if we # miss it here - event_map = self._get_events_from_cache(member_event_ids, update_metrics=False) + event_map = await self._get_events_from_cache( + member_event_ids, update_metrics=False + ) missing_member_event_ids = [] for event_id in member_event_ids: diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 8ed5325c5d07..31f41fec8284 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -730,3 +730,41 @@ def __del__(self) -> None: # This happens e.g. in the sync code where we have an expiring cache of # lru caches. self.clear() + + +class AsyncLruCache(Generic[KT, VT]): + """ + An asynchronous wrapper around a subset of the LruCache API. + + On its own this doesn't change the behaviour but allows subclasses that + utilize external cache systems that require await behaviour to be created. + """ + + def __init__(self, *args, **kwargs): # type: ignore + self._lru_cache: LruCache[KT, VT] = LruCache(*args, **kwargs) + + async def get( + self, key: KT, default: Optional[T] = None, update_metrics: bool = True + ) -> Optional[VT]: + return self._lru_cache.get(key, update_metrics=update_metrics) + + async def set(self, key: KT, value: VT) -> None: + self._lru_cache.set(key, value) + + async def invalidate(self, key: KT) -> None: + # This method should invalidate any external cache and then invalidate the LruCache. + return self._lru_cache.invalidate(key) + + def invalidate_local(self, key: KT) -> None: + """Remove an entry from the local cache + + This variant of `invalidate` is useful if we know that the external + cache has already been invalidated. + """ + return self._lru_cache.invalidate(key) + + async def contains(self, key: KT) -> bool: + return self._lru_cache.contains(key) + + async def clear(self) -> None: + self._lru_cache.clear() diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index ecc7cc6461ee..e3f38fbcc5ce 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -159,7 +159,7 @@ def test_unknown_room_version(self): # Blow away caches (supported room versions can only change due to a restart). self.store.get_rooms_for_user_with_stream_ordering.invalidate_all() - self.store._get_event_cache.clear() + self.get_success(self.store._get_event_cache.clear()) self.store._event_ref.clear() # The rooms should be excluded from the sync response. diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 38963ce4a74c..46d829b062a0 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -143,7 +143,7 @@ def prepare(self, reactor, clock, hs): self.event_id = res["event_id"] # Reset the event cache so the tests start with it empty - self.store._get_event_cache.clear() + self.get_success(self.store._get_event_cache.clear()) def test_simple(self): """Test that we cache events that we pull from the DB.""" @@ -160,7 +160,7 @@ def test_event_ref(self): """ # Reset the event cache - self.store._get_event_cache.clear() + self.get_success(self.store._get_event_cache.clear()) with LoggingContext("test") as ctx: # We keep hold of the event event though we never use it. @@ -170,7 +170,7 @@ def test_event_ref(self): self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1) # Reset the event cache - self.store._get_event_cache.clear() + self.get_success(self.store._get_event_cache.clear()) with LoggingContext("test") as ctx: self.get_success(self.store.get_event(self.event_id)) @@ -345,7 +345,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): self.event_id = res["event_id"] # Reset the event cache so the tests start with it empty - self.store._get_event_cache.clear() + self.get_success(self.store._get_event_cache.clear()) @contextmanager def blocking_get_event_calls( diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index 8dfaa0559b26..9c1182ed16b5 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -115,6 +115,6 @@ def test_purge_room(self): ) # The events aren't found. - self.store._invalidate_get_event_cache(create_event.event_id) + self.store._invalidate_local_get_event_cache(create_event.event_id) self.get_failure(self.store.get_event(create_event.event_id), NotFoundError) self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)