From c9afe9bf3b76cc9b89c3e487971375f2db756d2d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 16 Dec 2020 12:00:22 -0500 Subject: [PATCH 01/12] Add additional type hints to state storage. --- mypy.ini | 1 + synapse/handlers/initial_sync.py | 4 +--- synapse/handlers/sync.py | 2 +- synapse/storage/__init__.py | 9 ++++++-- synapse/storage/state.py | 35 +++++++++++++++++++++++--------- 5 files changed, 35 insertions(+), 16 deletions(-) diff --git a/mypy.ini b/mypy.ini index 0518d3f1af14..2fac850dd7b2 100644 --- a/mypy.ini +++ b/mypy.ini @@ -68,6 +68,7 @@ files = synapse/server_notices, synapse/spam_checker_api, synapse/state, + synapse/storage/__init__.py, synapse/storage/databases/main/appservice.py, synapse/storage/databases/main/events.py, synapse/storage/databases/main/pusher.py, diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index cb11754bf878..fbd8df9dcc09 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -323,9 +323,7 @@ async def _room_initial_sync_parted( member_event_id: str, is_peeking: bool, ) -> JsonDict: - room_state = await self.state_store.get_state_for_events([member_event_id]) - - room_state = room_state[member_event_id] + room_state = await self.state_store.get_state_for_event(member_event_id) limit = pagin_config.limit if pagin_config else None if limit is None: diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 9827c7eb8dfc..5c7590f38e4d 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -554,7 +554,7 @@ async def get_state_after_event( event.event_id, state_filter=state_filter ) if event.is_state(): - state_ids = state_ids.copy() + state_ids = dict(state_ids) state_ids[(event.type, event.state_key)] = event.event_id return state_ids diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index bbff3c8d5b4c..c0d9d1240f16 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -27,6 +27,7 @@ data stores associated with them (e.g. the schema version tables), which are stored in `synapse.storage.schema`. """ +from typing import TYPE_CHECKING from synapse.storage.databases import Databases from synapse.storage.databases.main import DataStore @@ -34,14 +35,18 @@ from synapse.storage.purge_events import PurgeEventsStorage from synapse.storage.state import StateGroupStorage -__all__ = ["DataStores", "DataStore"] +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + + +__all__ = ["Databases", "DataStore"] class Storage: """The high level interfaces for talking to various storage layers. """ - def __init__(self, hs, stores: Databases): + def __init__(self, hs: "HomeServer", stores: Databases): # We include the main data store here mainly so that we don't have to # rewrite all the existing code to split it into high vs low level # interfaces. diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 08a69f2f9667..56faca33e1f9 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -12,16 +12,29 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
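# (Illustration, not part of the upstream module: the TYPE_CHECKING guard
# introduced below -- and reused across this series -- lets annotations name
# types whose imports would otherwise be circular at runtime. The guarded
# import only exists for mypy, and the signature quotes the name so it is
# never resolved at runtime. A self-contained sketch of the pattern, using a
# placeholder class:)
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by the type checker, never at runtime.
    from synapse.app.homeserver import HomeServer

class ExampleStorage:
    def __init__(self, hs: "HomeServer"):
        # "HomeServer" is a forward reference; no runtime import is needed.
        self.hs = hs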
- import logging -from typing import Awaitable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar +from typing import ( + TYPE_CHECKING, + Awaitable, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + TypeVar, +) import attr from synapse.api.constants import EventTypes from synapse.events import EventBase +from synapse.storage.databases import Databases from synapse.types import MutableStateMap, StateMap +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) # Used for generic functions below @@ -330,10 +343,12 @@ class StateGroupStorage: """High level interface to fetching state for event. """ - def __init__(self, hs, stores): + def __init__(self, hs: "HomeServer", stores: Databases): self.stores = stores - async def get_state_group_delta(self, state_group: int): + async def get_state_group_delta( + self, state_group: int + ) -> Tuple[Optional[int], Optional[StateMap[str]]]: """Given a state group try to return a previous group and a delta between the old and the new. @@ -341,8 +356,8 @@ async def get_state_group_delta(self, state_group: int): state_group: The state group used to retrieve state deltas. Returns: - Tuple[Optional[int], Optional[StateMap[str]]]: - (prev_group, delta_ids) + A tuple of the previous group and a state map of the event IDs which + make up the delta between the old ad new state groups. """ return await self.stores.state.get_state_group_delta(state_group) @@ -436,7 +451,7 @@ def _get_state_groups_from_groups( async def get_state_for_events( self, event_ids: List[str], state_filter: StateFilter = StateFilter.all() - ): + ) -> Dict[str, StateMap[EventBase]]: """Given a list of event_ids and type tuples, return a list of state dicts for each event. @@ -472,7 +487,7 @@ async def get_state_for_events( async def get_state_ids_for_events( self, event_ids: List[str], state_filter: StateFilter = StateFilter.all() - ): + ) -> Dict[str, StateMap[str]]: """ Get the state dicts corresponding to a list of events, containing the event_ids of the state events (as opposed to the events themselves) @@ -500,7 +515,7 @@ async def get_state_ids_for_events( async def get_state_for_event( self, event_id: str, state_filter: StateFilter = StateFilter.all() - ): + ) -> StateMap[EventBase]: """ Get the state dict corresponding to a particular event @@ -516,7 +531,7 @@ async def get_state_for_event( async def get_state_ids_for_event( self, event_id: str, state_filter: StateFilter = StateFilter.all() - ): + ) -> StateMap[str]: """ Get the state dict corresponding to a particular event From 2ebe3b413d128a103afd7db1d0f72f85d950dd9a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 16 Dec 2020 14:06:12 -0500 Subject: [PATCH 02/12] Add type hints to base storage. 
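One annotation worth spelling out: _invalidate_state_caches takes
members_changed as Iterable[str] because the user IDs are consumed exactly
once, being reduced to the set of affected homeserver domains. A standalone
sketch of that reduction (the helper name and sample IDs are made up; the
real code uses synapse.types.get_domain_from_id):

    from typing import Iterable, Set

    def _domains_of(members_changed: Iterable[str]) -> Set[str]:
        # A Matrix user ID has the form @localpart:domain; the domain is
        # everything after the first colon.
        return {user_id.split(":", 1)[1] for user_id in members_changed}

    assert _domains_of(["@a:one.example", "@b:two.example", "@c:one.example"]) == {
        "one.example",
        "two.example",
    }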
--- mypy.ini | 1 + synapse/storage/_base.py | 32 +++++++++++++++++++++----------- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/mypy.ini b/mypy.ini index 2fac850dd7b2..3d740e10b782 100644 --- a/mypy.ini +++ b/mypy.ini @@ -69,6 +69,7 @@ files = synapse/spam_checker_api, synapse/state, synapse/storage/__init__.py, + synapse/storage/_base.py, synapse/storage/databases/main/appservice.py, synapse/storage/databases/main/events.py, synapse/storage/databases/main/pusher.py, diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 2b196ded1bd0..f1cc33d55b96 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -17,14 +17,18 @@ import logging import random from abc import ABCMeta -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Iterable, Optional, Union from synapse.storage.database import LoggingTransaction # noqa: F401 from synapse.storage.database import make_in_list_sql_clause # noqa: F401 from synapse.storage.database import DatabasePool -from synapse.types import Collection, get_domain_from_id +from synapse.storage.types import Connection +from synapse.types import Collection, StreamToken, get_domain_from_id from synapse.util import json_decoder +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) @@ -36,24 +40,27 @@ class SQLBaseStore(metaclass=ABCMeta): per data store (and not one per physical database). """ - def __init__(self, database: DatabasePool, db_conn, hs): + def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): self.hs = hs self._clock = hs.get_clock() self.database_engine = database.engine self.db_pool = database self.rand = random.SystemRandom() - def process_replication_rows(self, stream_name, instance_name, token, rows): + def process_replication_rows( + self, stream_name: str, instance_name: str, token: StreamToken, rows + ) -> None: pass - def _invalidate_state_caches(self, room_id, members_changed): + def _invalidate_state_caches( + self, room_id: str, members_changed: Iterable[str] + ) -> None: """Invalidates caches that are based on the current state, but does not stream invalidations down replication. Args: - room_id (str): Room where state changed - members_changed (iterable[str]): The user_ids of members that have - changed + room_id: Room where state changed + members_changed: The user_ids of members that have changed """ for host in {get_domain_from_id(u) for u in members_changed}: self._attempt_to_invalidate_cache("is_host_joined", (room_id, host)) @@ -64,7 +71,7 @@ def _invalidate_state_caches(self, room_id, members_changed): def _attempt_to_invalidate_cache( self, cache_name: str, key: Optional[Collection[Any]] - ): + ) -> None: """Attempts to invalidate the cache of the given name, ignoring if the cache doesn't exist. Mainly used for invalidating caches on workers, where they may not have the cache. @@ -88,12 +95,15 @@ def _attempt_to_invalidate_cache( cache.invalidate(tuple(key)) -def db_to_json(db_content): +def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any: """ Take some data from a database row and return a JSON-decoded object. Args: - db_content (memoryview|buffer|bytes|bytearray|unicode) + db_content: The JSON-encoded contents from the database. + + Returns: + The object decoded from JSON. 
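+
+    Example (illustrative; the values are not from the patch)::
+
+        >>> db_to_json(b'{"foo": "bar"}')
+        {'foo': 'bar'}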
""" # psycopg2 on Python 3 returns memoryview objects, which we need to # cast to bytes to decode From b15e0dfe09cff7a3045b203c6af32978883c6b3d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 16 Dec 2020 14:07:23 -0500 Subject: [PATCH 03/12] Add type hints to background updates. --- mypy.ini | 1 + synapse/storage/background_updates.py | 111 +++++++++++++++----------- 2 files changed, 65 insertions(+), 47 deletions(-) diff --git a/mypy.ini b/mypy.ini index 3d740e10b782..4e82b5eae4c8 100644 --- a/mypy.ini +++ b/mypy.ini @@ -70,6 +70,7 @@ files = synapse/state, synapse/storage/__init__.py, synapse/storage/_base.py, + synapse/storage/background_updates.py, synapse/storage/databases/main/appservice.py, synapse/storage/databases/main/events.py, synapse/storage/databases/main/pusher.py, diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 810721ebe975..95a9af8079c2 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -12,29 +12,34 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging -from typing import Optional +from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Iterable, Optional from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.types import Connection +from synapse.types import JsonDict from synapse.util import json_encoder from . import engines +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + from synapse.storage.database import DatabasePool, LoggingTransaction + logger = logging.getLogger(__name__) class BackgroundUpdatePerformance: """Tracks the how long a background update is taking to update its items""" - def __init__(self, name): + def __init__(self, name: str): self.name = name self.total_item_count = 0 - self.total_duration_ms = 0 - self.avg_item_count = 0 - self.avg_duration_ms = 0 + self.total_duration_ms = 0.0 + self.avg_item_count = 0.0 + self.avg_duration_ms = 0.0 - def update(self, item_count, duration_ms): + def update(self, item_count: int, duration_ms: float) -> None: """Update the stats after doing an update""" self.total_item_count += item_count self.total_duration_ms += duration_ms @@ -44,7 +49,7 @@ def update(self, item_count, duration_ms): self.avg_item_count += 0.1 * (item_count - self.avg_item_count) self.avg_duration_ms += 0.1 * (duration_ms - self.avg_duration_ms) - def average_items_per_ms(self): + def average_items_per_ms(self) -> Optional[float]: """An estimate of how long it takes to do a single update. Returns: A duration in ms as a float @@ -58,7 +63,7 @@ def average_items_per_ms(self): # changes in how long the update process takes. return float(self.avg_item_count) / float(self.avg_duration_ms) - def total_items_per_ms(self): + def total_items_per_ms(self) -> Optional[float]: """An estimate of how long it takes to do a single update. Returns: A duration in ms as a float @@ -83,21 +88,25 @@ class BackgroundUpdater: BACKGROUND_UPDATE_INTERVAL_MS = 1000 BACKGROUND_UPDATE_DURATION_MS = 100 - def __init__(self, hs, database): + def __init__(self, hs: "HomeServer", database: "DatabasePool"): self._clock = hs.get_clock() self.db_pool = database # if a background update is currently running, its name. 
self._current_background_update = None # type: Optional[str] - self._background_update_performance = {} - self._background_update_handlers = {} + self._background_update_performance = ( + {} + ) # type: Dict[str, BackgroundUpdatePerformance] + self._background_update_handlers = ( + {} + ) # type: Dict[str, Callable[[JsonDict, int], Awaitable[int]]] self._all_done = False - def start_doing_background_updates(self): + def start_doing_background_updates(self) -> None: run_as_background_process("background_updates", self.run_background_updates) - async def run_background_updates(self, sleep=True): + async def run_background_updates(self, sleep: bool = True) -> None: logger.info("Starting background schema updates") while True: if sleep: @@ -148,7 +157,7 @@ async def has_completed_background_updates(self) -> bool: return False - async def has_completed_background_update(self, update_name) -> bool: + async def has_completed_background_update(self, update_name: str) -> bool: """Check if the given background update has finished running. """ if self._all_done: @@ -173,8 +182,7 @@ async def do_next_background_update(self, desired_duration_ms: float) -> bool: Returns once some amount of work is done. Args: - desired_duration_ms(float): How long we want to spend - updating. + desired_duration_ms: How long we want to spend updating. Returns: True if we have finished running all the background updates, otherwise False """ @@ -220,6 +228,7 @@ def get_background_updates_txn(txn): return False async def _do_background_update(self, desired_duration_ms: float) -> int: + assert self._current_background_update is not None update_name = self._current_background_update logger.info("Starting update batch on background update '%s'", update_name) @@ -273,7 +282,11 @@ async def _do_background_update(self, desired_duration_ms: float) -> int: return len(self._background_update_performance) - def register_background_update_handler(self, update_name, update_handler): + def register_background_update_handler( + self, + update_name: str, + update_handler: Callable[[JsonDict, int], Awaitable[int]], + ): """Register a handler for doing a background update. The handler should take two arguments: @@ -287,12 +300,12 @@ def register_background_update_handler(self, update_name, update_handler): The handler is responsible for updating the progress of the update. Args: - update_name(str): The name of the update that this code handles. - update_handler(function): The function that does the update. + update_name: The name of the update that this code handles. + update_handler: The function that does the update. """ self._background_update_handlers[update_name] = update_handler - def register_noop_background_update(self, update_name): + def register_noop_background_update(self, update_name: str) -> None: """Register a noop handler for a background update. This is useful when we previously did a background update, but no @@ -302,10 +315,10 @@ def register_noop_background_update(self, update_name): also be called to clear the update. 
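+
+        For example (the update name here is illustrative)::
+
+            self.db_pool.updates.register_noop_background_update(
+                "event_search_postgres_gist"
+            )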
Args: - update_name (str): Name of update + update_name: Name of update """ - async def noop_update(progress, batch_size): + async def noop_update(progress: JsonDict, batch_size: int) -> int: await self._end_background_update(update_name) return 1 @@ -313,14 +326,14 @@ async def noop_update(progress, batch_size): def register_background_index_update( self, - update_name, - index_name, - table, - columns, - where_clause=None, - unique=False, - psql_only=False, - ): + update_name: str, + index_name: str, + table: str, + columns: Iterable[str], + where_clause: Optional[str] = None, + unique: bool = False, + psql_only: bool = False, + ) -> None: """Helper for store classes to do a background index addition To use: @@ -332,19 +345,19 @@ def register_background_index_update( 2. In the Store constructor, call this method Args: - update_name (str): update_name to register for - index_name (str): name of index to add - table (str): table to add index to - columns (list[str]): columns/expressions to include in index - unique (bool): true to make a UNIQUE index + update_name: update_name to register for + index_name: name of index to add + table: table to add index to + columns: columns/expressions to include in index + unique: true to make a UNIQUE index psql_only: true to only create this index on psql databases (useful for virtual sqlite tables) """ - def create_index_psql(conn): + def create_index_psql(conn: Connection) -> None: conn.rollback() # postgres insists on autocommit for the index - conn.set_session(autocommit=True) + conn.set_session(autocommit=True) # type: ignore try: c = conn.cursor() @@ -371,9 +384,9 @@ def create_index_psql(conn): logger.debug("[SQL] %s", sql) c.execute(sql) finally: - conn.set_session(autocommit=False) + conn.set_session(autocommit=False) # type: ignore - def create_index_sqlite(conn): + def create_index_sqlite(conn: Connection) -> None: # Sqlite doesn't support concurrent creation of indexes. # # We don't use partial indices on SQLite as it wasn't introduced @@ -399,7 +412,7 @@ def create_index_sqlite(conn): c.execute(sql) if isinstance(self.db_pool.engine, engines.PostgresEngine): - runner = create_index_psql + runner = create_index_psql # type: Optional[Callable[[Connection], None]] elif psql_only: runner = None else: @@ -433,7 +446,9 @@ async def _end_background_update(self, update_name: str) -> None: "background_updates", keyvalues={"update_name": update_name} ) - async def _background_update_progress(self, update_name: str, progress: dict): + async def _background_update_progress( + self, update_name: str, progress: dict + ) -> None: """Update the progress of a background update Args: @@ -441,20 +456,22 @@ async def _background_update_progress(self, update_name: str, progress: dict): progress: The progress of the update. """ - return await self.db_pool.runInteraction( + await self.db_pool.runInteraction( "background_update_progress", self._background_update_progress_txn, update_name, progress, ) - def _background_update_progress_txn(self, txn, update_name, progress): + def _background_update_progress_txn( + self, txn: LoggingTransaction, update_name: str, progress: JsonDict + ) -> None: """Update the progress of a background update Args: - txn(cursor): The transaction. - update_name(str): The name of the background update task - progress(dict): The progress of the update. + txn: The transaction. + update_name: The name of the background update task + progress: The progress of the update. 
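+
+        For example (the update name and progress keys are illustrative)::
+
+            self._background_update_progress_txn(
+                txn, "example_update", {"last_event_id": "$last"}
+            )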
""" progress_json = json_encoder.encode(progress) From cbc2f667c4fa0403669578c3db5562e15ff56bfa Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 18 Dec 2020 15:45:08 -0500 Subject: [PATCH 04/12] Add type hints to relations. --- mypy.ini | 4 ++++ synapse/storage/relations.py | 44 +++++++++++++++++++----------------- 2 files changed, 27 insertions(+), 21 deletions(-) diff --git a/mypy.ini b/mypy.ini index 4e82b5eae4c8..85752c6928ef 100644 --- a/mypy.ini +++ b/mypy.ini @@ -80,7 +80,11 @@ files = synapse/storage/database.py, synapse/storage/engines, synapse/storage/persist_events.py, + synapse/storage/push_rule.py, + synapse/storage/relations.py, + synapse/storage/roommember.py, synapse/storage/state.py, + synapse/storage/types.py, synapse/storage/util, synapse/streams, synapse/types.py, diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index cec96ad6a72e..03d554734fa5 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -14,10 +14,12 @@ # limitations under the License. import logging +from typing import Any, Dict, List, Optional, Tuple import attr from synapse.api.errors import SynapseError +from synapse.types import JsonDict logger = logging.getLogger(__name__) @@ -27,18 +29,18 @@ class PaginationChunk: """Returned by relation pagination APIs. Attributes: - chunk (list): The rows returned by pagination - next_batch (Any|None): Token to fetch next set of results with, if + chunk: The rows returned by pagination + next_batch: Token to fetch next set of results with, if None then there are no more results. - prev_batch (Any|None): Token to fetch previous set of results with, if + prev_batch: Token to fetch previous set of results with, if None then there are no previous results. """ - chunk = attr.ib() - next_batch = attr.ib(default=None) - prev_batch = attr.ib(default=None) + chunk = attr.ib(type=List[JsonDict]) + next_batch = attr.ib(type=Optional[Any], default=None) + prev_batch = attr.ib(type=Optional[Any], default=None) - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: d = {"chunk": self.chunk} if self.next_batch: @@ -59,25 +61,25 @@ class RelationPaginationToken: boundaries of the chunk as pagination tokens. Attributes: - topological (int): The topological ordering of the boundary event - stream (int): The stream ordering of the boundary event. + topological: The topological ordering of the boundary event + stream: The stream ordering of the boundary event. """ - topological = attr.ib() - stream = attr.ib() + topological = attr.ib(type=int) + stream = attr.ib(type=int) @staticmethod - def from_string(string): + def from_string(string: str) -> "RelationPaginationToken": try: t, s = string.split("-") return RelationPaginationToken(int(t), int(s)) except ValueError: raise SynapseError(400, "Invalid token") - def to_string(self): + def to_string(self) -> str: return "%d-%d" % (self.topological, self.stream) - def as_tuple(self): + def as_tuple(self) -> Tuple[Any, ...]: return attr.astuple(self) @@ -89,23 +91,23 @@ class AggregationPaginationToken: aggregation groups, we can just use them as our pagination token. Attributes: - count (int): The count of relations in the boundar group. - stream (int): The MAX stream ordering in the boundary group. + count: The count of relations in the boundar group. + stream: The MAX stream ordering in the boundary group. 
""" - count = attr.ib() - stream = attr.ib() + count = attr.ib(type=int) + stream = attr.ib(type=int) @staticmethod - def from_string(string): + def from_string(string: str) -> "AggregationPaginationToken": try: c, s = string.split("-") return AggregationPaginationToken(int(c), int(s)) except ValueError: raise SynapseError(400, "Invalid token") - def to_string(self): + def to_string(self) -> str: return "%d-%d" % (self.count, self.stream) - def as_tuple(self): + def as_tuple(self) -> Tuple[Any, ...]: return attr.astuple(self) From 44ada012a82770ba6ba01018a047eecd07688b99 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 18 Dec 2020 15:50:22 -0500 Subject: [PATCH 05/12] Add type hints to keys. --- mypy.ini | 2 ++ synapse/storage/keys.py | 5 +++-- synapse/storage/purge_events.py | 11 ++++++++--- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/mypy.ini b/mypy.ini index 85752c6928ef..6a2a9a9379de 100644 --- a/mypy.ini +++ b/mypy.ini @@ -79,7 +79,9 @@ files = synapse/storage/databases/main/ui_auth.py, synapse/storage/database.py, synapse/storage/engines, + synapse/storage/keys.py, synapse/storage/persist_events.py, + synapse/storage/purge_events.py, synapse/storage/push_rule.py, synapse/storage/relations.py, synapse/storage/roommember.py, diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index afd10f7bae1c..c03871f3933e 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -17,11 +17,12 @@ import logging import attr +from signedjson.types import VerifyKey logger = logging.getLogger(__name__) @attr.s(slots=True, frozen=True) class FetchKeyResult: - verify_key = attr.ib() # VerifyKey: the key itself - valid_until_ts = attr.ib() # int: how long we can use this key for + verify_key = attr.ib(type=VerifyKey) # the key itself + valid_until_ts = attr.ib(type=int) # how long we can use this key for diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py index bfa0a9fd065a..6c359c1aaedb 100644 --- a/synapse/storage/purge_events.py +++ b/synapse/storage/purge_events.py @@ -15,7 +15,12 @@ import itertools import logging -from typing import Set +from typing import TYPE_CHECKING, Set + +from synapse.storage.databases import Databases + +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer logger = logging.getLogger(__name__) @@ -24,10 +29,10 @@ class PurgeEventsStorage: """High level interface for purging rooms and event history. """ - def __init__(self, hs, stores): + def __init__(self, hs: "HomeServer", stores: Databases): self.stores = stores - async def purge_room(self, room_id: str): + async def purge_room(self, room_id: str) -> None: """Deletes all record of a room """ From d972c78e2b4459a5bf84d3f2ec6e6306791192d6 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 18 Dec 2020 16:17:41 -0500 Subject: [PATCH 06/12] Check prepare database. 
--- mypy.ini | 1 + synapse/storage/prepare_database.py | 111 ++++++++++++++++------------ 2 files changed, 64 insertions(+), 48 deletions(-) diff --git a/mypy.ini b/mypy.ini index 6a2a9a9379de..ead87f7585f4 100644 --- a/mypy.ini +++ b/mypy.ini @@ -81,6 +81,7 @@ files = synapse/storage/engines, synapse/storage/keys.py, synapse/storage/persist_events.py, + synapse/storage/prepare_database.py, synapse/storage/purge_events.py, synapse/storage/push_rule.py, synapse/storage/relations.py, diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 459754feabf7..09ef7991f909 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -18,7 +18,15 @@ import os import re from collections import Counter -from typing import Optional, TextIO +from typing import ( + Counter as CounterType, + Generator, + Iterable, + List, + Optional, + TextIO, + Tuple, +) import attr @@ -70,7 +78,7 @@ def prepare_database( db_conn: LoggingDatabaseConnection, database_engine: BaseDatabaseEngine, config: Optional[HomeServerConfig], - databases: Collection[str] = ["main", "state"], + databases: Collection[str] = ("main", "state"), ): """Prepares a physical database for usage. Will either create all necessary tables or upgrade from an older schema version. @@ -155,7 +163,9 @@ def prepare_database( raise -def _setup_new_database(cur, database_engine, databases): +def _setup_new_database( + cur: Cursor, database_engine: BaseDatabaseEngine, databases: Collection[str] +) -> None: """Sets up the physical database by finding a base set of "full schemas" and then applying any necessary deltas, including schemas from the given data stores. @@ -188,10 +198,9 @@ def _setup_new_database(cur, database_engine, databases): folder as well those in the data stores specified. Args: - cur (Cursor): a database cursor - database_engine (DatabaseEngine) - databases (list[str]): The names of the databases to instantiate - on the given physical database. + cur: a database cursor + database_engine + databases: The names of the databases to instantiate on the given physical database. """ # We're about to set up a brand new database so we check that its @@ -199,12 +208,11 @@ def _setup_new_database(cur, database_engine, databases): database_engine.check_new_database(cur) current_dir = os.path.join(dir_path, "schema", "full_schemas") - directory_entries = os.listdir(current_dir) # First we find the highest full schema version we have valid_versions = [] - for filename in directory_entries: + for filename in os.listdir(current_dir): try: ver = int(filename) except ValueError: @@ -237,7 +245,7 @@ def _setup_new_database(cur, database_engine, databases): for database in databases ) - directory_entries = [] + directory_entries = [] # type: List[_DirectoryListing] for directory in directories: directory_entries.extend( _DirectoryListing(file_name, os.path.join(directory, file_name)) @@ -275,15 +283,15 @@ def _setup_new_database(cur, database_engine, databases): def _upgrade_existing_database( - cur, - current_version, - applied_delta_files, - upgraded, - database_engine, - config, - databases, - is_empty=False, -): + cur: Cursor, + current_version: int, + applied_delta_files: List[str], + upgraded: bool, + database_engine: BaseDatabaseEngine, + config: Optional[HomeServerConfig], + databases: Collection[str], + is_empty: bool = False, +) -> None: """Upgrades an existing physical database. 
Delta files can either be SQL stored in *.sql files, or python modules @@ -323,21 +331,20 @@ def _upgrade_existing_database( for a version before applying those in the next version. Args: - cur (Cursor) - current_version (int): The current version of the schema. - applied_delta_files (list): A list of deltas that have already been - applied. - upgraded (bool): Whether the current version was generated by having + cur + current_version: The current version of the schema. + applied_delta_files: A list of deltas that have already been applied. + upgraded: Whether the current version was generated by having applied deltas or from full schema file. If `True` the function will never apply delta files for the given `current_version`, since the current_version wasn't generated by applying those delta files. - database_engine (DatabaseEngine) - config (synapse.config.homeserver.HomeServerConfig|None): + database_engine + config: None if we are initialising a blank database, otherwise the application config - databases (list[str]): The names of the databases to instantiate + databases: The names of the databases to instantiate on the given physical database. - is_empty (bool): Is this a blank database? I.e. do we need to run the + is_empty: Is this a blank database? I.e. do we need to run the upgrade portions of the delta scripts. """ if is_empty: @@ -358,6 +365,7 @@ def _upgrade_existing_database( if not is_empty and "main" in databases: from synapse.storage.databases.main import check_database_before_upgrade + assert config is not None check_database_before_upgrade(cur, database_engine, config) start_ver = current_version @@ -388,10 +396,10 @@ def _upgrade_existing_database( ) # Used to check if we have any duplicate file names - file_name_counter = Counter() + file_name_counter = Counter() # type: CounterType[str] # Now find which directories have anything of interest. - directory_entries = [] + directory_entries = [] # type: List[_DirectoryListing] for directory in directories: logger.debug("Looking for schema deltas in %s", directory) try: @@ -445,11 +453,11 @@ def _upgrade_existing_database( module_name = "synapse.storage.v%d_%s" % (v, root_name) with open(absolute_path) as python_file: - module = imp.load_source(module_name, absolute_path, python_file) + module = imp.load_source(module_name, absolute_path, python_file) # type: ignore logger.info("Running script %s", relative_path) - module.run_create(cur, database_engine) + module.run_create(cur, database_engine) # type: ignore if not is_empty: - module.run_upgrade(cur, database_engine, config=config) + module.run_upgrade(cur, database_engine, config=config) # type: ignore elif ext == ".pyc" or file_name == "__pycache__": # Sometimes .pyc files turn up anyway even though we've # disabled their generation; e.g. 
from distribution package @@ -497,14 +505,15 @@ def _upgrade_existing_database( logger.info("Schema now up to date") -def _apply_module_schemas(txn, database_engine, config): +def _apply_module_schemas( + txn: Cursor, database_engine: BaseDatabaseEngine, config: HomeServerConfig +) -> None: """Apply the module schemas for the dynamic modules, if any Args: cur: database cursor - database_engine: synapse database engine class - config (synapse.config.homeserver.HomeServerConfig): - application config + database_engine: + config: application config """ for (mod, _config) in config.password_providers: if not hasattr(mod, "get_db_schema_files"): @@ -515,15 +524,19 @@ def _apply_module_schemas(txn, database_engine, config): ) -def _apply_module_schema_files(cur, database_engine, modname, names_and_streams): +def _apply_module_schema_files( + cur: Cursor, + database_engine: BaseDatabaseEngine, + modname: str, + names_and_streams: Iterable[Tuple[str, TextIO]], +) -> None: """Apply the module schemas for a single module Args: cur: database cursor database_engine: synapse database engine class - modname (str): fully qualified name of the module - names_and_streams (Iterable[(str, file)]): the names and streams of - schemas to be applied + modname: fully qualified name of the module + names_and_streams: the names and streams of schemas to be applied """ cur.execute( "SELECT file FROM applied_module_schemas WHERE module_name = ?", (modname,), @@ -549,7 +562,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams) ) -def get_statements(f): +def get_statements(f: Iterable[str]) -> Generator[str, None, None]: statement_buffer = "" in_comment = False # If we're in a /* ... */ style comment @@ -594,17 +607,19 @@ def get_statements(f): statement_buffer = statements[-1].strip() -def executescript(txn, schema_path): +def executescript(txn: Cursor, schema_path: str) -> None: with open(schema_path, "r") as f: execute_statements_from_stream(txn, f) -def execute_statements_from_stream(cur: Cursor, f: TextIO): +def execute_statements_from_stream(cur: Cursor, f: TextIO) -> None: for statement in get_statements(f): cur.execute(statement) -def _get_or_create_schema_state(txn, database_engine): +def _get_or_create_schema_state( + txn: Cursor, database_engine: BaseDatabaseEngine +) -> Optional[Tuple[int, List[str], bool]]: # Bluntly try creating the schema_version tables. schema_path = os.path.join(dir_path, "schema", "schema_version.sql") executescript(txn, schema_path) @@ -612,7 +627,6 @@ def _get_or_create_schema_state(txn, database_engine): txn.execute("SELECT version, upgraded FROM schema_version") row = txn.fetchone() current_version = int(row[0]) if row else None - upgraded = bool(row[1]) if row else None if current_version: txn.execute( @@ -620,6 +634,7 @@ def _get_or_create_schema_state(txn, database_engine): (current_version,), ) applied_deltas = [d for d, in txn] + upgraded = bool(row[1]) return current_version, applied_deltas, upgraded return None @@ -634,5 +649,5 @@ class _DirectoryListing: `file_name` attr is kept first. """ - file_name = attr.ib() - absolute_path = attr.ib() + file_name = attr.ib(type=str) + absolute_path = attr.ib(type=str) From e3ff527e2aac9183bbaca65179a60774078338e8 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 22 Dec 2020 09:27:27 -0500 Subject: [PATCH 07/12] Add newsfragment. 
--- changelog.d/8980.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/8980.misc diff --git a/changelog.d/8980.misc b/changelog.d/8980.misc new file mode 100644 index 000000000000..83ef3c5def4b --- /dev/null +++ b/changelog.d/8980.misc @@ -0,0 +1 @@ +Add type hints to the base storage code. From 2824fb2bc6810aaffb08c6e01811984b0920598f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 22 Dec 2020 10:05:46 -0500 Subject: [PATCH 08/12] Fix imports. --- synapse/storage/background_updates.py | 2 +- synapse/storage/state.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 95a9af8079c2..29b8ca676a9b 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -464,7 +464,7 @@ async def _background_update_progress( ) def _background_update_progress_txn( - self, txn: LoggingTransaction, update_name: str, progress: JsonDict + self, txn: "LoggingTransaction", update_name: str, progress: JsonDict ) -> None: """Update the progress of a background update diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 56faca33e1f9..f9e76aba9f2a 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -29,11 +29,11 @@ from synapse.api.constants import EventTypes from synapse.events import EventBase -from synapse.storage.databases import Databases from synapse.types import MutableStateMap, StateMap if TYPE_CHECKING: from synapse.app.homeserver import HomeServer + from synapse.storage.databases import Databases logger = logging.getLogger(__name__) @@ -343,7 +343,7 @@ class StateGroupStorage: """High level interface to fetching state for event. """ - def __init__(self, hs: "HomeServer", stores: Databases): + def __init__(self, hs: "HomeServer", stores: "Databases"): self.stores = stores async def get_state_group_delta( From 97dc83b968227345945851f7c19a6a3025754a73 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 22 Dec 2020 13:44:50 -0500 Subject: [PATCH 09/12] Try to fix py3.5. 
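typing.Counter is missing from the oldest Python 3.5 point releases that
Synapse still supports, so the generic alias is imported from
typing_extensions instead, while the runtime class still comes from
collections. A minimal sketch of the resulting pattern (the file names are
illustrative):

    from collections import Counter
    from typing_extensions import Counter as CounterType

    file_name_counter = Counter()  # type: CounterType[str]
    file_name_counter.update(["01foo.sql", "01foo.sql", "02bar.sql"])
    assert file_name_counter["01foo.sql"] == 2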
--- synapse/storage/prepare_database.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 09ef7991f909..68a7ba0107d6 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -19,7 +19,6 @@ import re from collections import Counter from typing import ( - Counter as CounterType, Generator, Iterable, List, @@ -29,6 +28,7 @@ ) import attr +from typing_extensions import Counter as CounterType from synapse.config.homeserver import HomeServerConfig from synapse.storage.database import LoggingDatabaseConnection From 244737beba3e16ff377451585f19dde9598f1a75 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 22 Dec 2020 13:56:37 -0500 Subject: [PATCH 10/12] lint --- synapse/storage/prepare_database.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 68a7ba0107d6..f91a2eae7aa4 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -18,14 +18,7 @@ import os import re from collections import Counter -from typing import ( - Generator, - Iterable, - List, - Optional, - TextIO, - Tuple, -) +from typing import Generator, Iterable, List, Optional, TextIO, Tuple import attr from typing_extensions import Counter as CounterType From 3bda9277c9c47108b37e1e0c46080835eb1d793e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 30 Dec 2020 07:12:04 -0500 Subject: [PATCH 11/12] Fix typos. Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> --- synapse/storage/relations.py | 2 +- synapse/storage/state.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index 03d554734fa5..2564f34b4713 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -91,7 +91,7 @@ class AggregationPaginationToken: aggregation groups, we can just use them as our pagination token. Attributes: - count: The count of relations in the boundar group. + count: The count of relations in the boundary group. stream: The MAX stream ordering in the boundary group. """ diff --git a/synapse/storage/state.py b/synapse/storage/state.py index f9e76aba9f2a..31ccbf23dc20 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -357,7 +357,7 @@ async def get_state_group_delta( Returns: A tuple of the previous group and a state map of the event IDs which - make up the delta between the old ad new state groups. + make up the delta between the old and new state groups. """ return await self.stores.state.get_state_group_delta(state_group) From b366044ace8590de38042270f7ae54bdcc2ba17c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 30 Dec 2020 07:18:01 -0500 Subject: [PATCH 12/12] Add a missing type. --- synapse/storage/_base.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index f1cc33d55b96..a25c4093bce9 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -48,7 +48,11 @@ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer" self.rand = random.SystemRandom() def process_replication_rows( - self, stream_name: str, instance_name: str, token: StreamToken, rows + self, + stream_name: str, + instance_name: str, + token: StreamToken, + rows: Iterable[Any], ) -> None: pass
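For reference, the completed process_replication_rows signature is the one
worker stores override; a hypothetical subclass (the store, stream and cache
names are made up, not from this patch set) would now line up as:

    from typing import Any, Iterable

    from synapse.storage._base import SQLBaseStore
    from synapse.types import StreamToken

    class ExampleWorkerStore(SQLBaseStore):
        def process_replication_rows(
            self,
            stream_name: str,
            instance_name: str,
            token: StreamToken,
            rows: Iterable[Any],
        ) -> None:
            # Only react to the replication stream this store cares about.
            if stream_name == "example_stream":
                for row in rows:
                    self._attempt_to_invalidate_cache("example_cache", (row,))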