Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 7f33527

Browse files
David Robertsonreivilibre
David Robertson
andauthored
Improve typing in user_directory files (#10891)
* Improve typing in user_directory files This makes the user_directory.py in storage pass most of mypy's checks (including `no-untyped-defs`). Unfortunately that file is in the tangled web of Store class inheritance so doesn't pass mypy at the moment. The handlers directory has already been mypyed. Co-authored-by: reivilibre <[email protected]>
1 parent e704cc2 commit 7f33527

File tree

4 files changed

+95
-37
lines changed

4 files changed

+95
-37
lines changed

changelog.d/10891.misc

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Improve type hinting in the user directory code.

mypy.ini

+2
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,11 @@ files =
8585
tests/handlers/test_room_summary.py,
8686
tests/handlers/test_send_email.py,
8787
tests/handlers/test_sync.py,
88+
tests/handlers/test_user_directory.py,
8889
tests/rest/client/test_login.py,
8990
tests/rest/client/test_auth.py,
9091
tests/storage/test_state.py,
92+
tests/storage/test_user_directory.py,
9193
tests/util/test_itertools.py,
9294
tests/util/test_stream_change_cache.py
9395

synapse/storage/databases/main/user_directory.py

+89-35
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,28 @@
1414

1515
import logging
1616
import re
17-
from typing import Any, Dict, Iterable, Optional, Set, Tuple
17+
from typing import (
18+
TYPE_CHECKING,
19+
Dict,
20+
Iterable,
21+
List,
22+
Optional,
23+
Sequence,
24+
Set,
25+
Tuple,
26+
cast,
27+
)
28+
29+
if TYPE_CHECKING:
30+
from synapse.server import HomeServer
1831

1932
from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
20-
from synapse.storage.database import DatabasePool
33+
from synapse.storage.database import DatabasePool, LoggingTransaction
2134
from synapse.storage.databases.main.state import StateFilter
2235
from synapse.storage.databases.main.state_deltas import StateDeltasStore
2336
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
24-
from synapse.types import get_domain_from_id, get_localpart_from_id
37+
from synapse.storage.types import Connection
38+
from synapse.types import JsonDict, get_domain_from_id, get_localpart_from_id
2539
from synapse.util.caches.descriptors import cached
2640

2741
logger = logging.getLogger(__name__)
@@ -36,7 +50,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
3650
# add_users_who_share_private_rooms?
3751
SHARE_PRIVATE_WORKING_SET = 500
3852

39-
def __init__(self, database: DatabasePool, db_conn, hs):
53+
def __init__(
54+
self,
55+
database: DatabasePool,
56+
db_conn: Connection,
57+
hs: "HomeServer",
58+
):
4059
super().__init__(database, db_conn, hs)
4160

4261
self.server_name = hs.hostname
@@ -57,10 +76,12 @@ def __init__(self, database: DatabasePool, db_conn, hs):
5776
"populate_user_directory_cleanup", self._populate_user_directory_cleanup
5877
)
5978

60-
async def _populate_user_directory_createtables(self, progress, batch_size):
79+
async def _populate_user_directory_createtables(
80+
self, progress: JsonDict, batch_size: int
81+
) -> int:
6182

6283
# Get all the rooms that we want to process.
63-
def _make_staging_area(txn):
84+
def _make_staging_area(txn: LoggingTransaction) -> None:
6485
sql = (
6586
"CREATE TABLE IF NOT EXISTS "
6687
+ TEMP_TABLE
@@ -110,16 +131,20 @@ def _make_staging_area(txn):
110131
)
111132
return 1
112133

113-
async def _populate_user_directory_cleanup(self, progress, batch_size):
134+
async def _populate_user_directory_cleanup(
135+
self,
136+
progress: JsonDict,
137+
batch_size: int,
138+
) -> int:
114139
"""
115140
Update the user directory stream position, then clean up the old tables.
116141
"""
117142
position = await self.db_pool.simple_select_one_onecol(
118-
TEMP_TABLE + "_position", None, "position"
143+
TEMP_TABLE + "_position", {}, "position"
119144
)
120145
await self.update_user_directory_stream_pos(position)
121146

122-
def _delete_staging_area(txn):
147+
def _delete_staging_area(txn: LoggingTransaction) -> None:
123148
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms")
124149
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users")
125150
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position")
@@ -133,18 +158,32 @@ def _delete_staging_area(txn):
133158
)
134159
return 1
135160

136-
async def _populate_user_directory_process_rooms(self, progress, batch_size):
161+
async def _populate_user_directory_process_rooms(
162+
self, progress: JsonDict, batch_size: int
163+
) -> int:
137164
"""
165+
Rescan the state of all rooms so we can track
166+
167+
- who's in a public room;
168+
- which local users share a private room with other users (local
169+
and remote); and
170+
- who should be in the user_directory.
171+
138172
Args:
139173
progress (dict)
140174
batch_size (int): Maximum number of state events to process
141175
per cycle.
176+
177+
Returns:
178+
number of events processed.
142179
"""
143180
# If we don't have progress filed, delete everything.
144181
if not progress:
145182
await self.delete_all_from_user_dir()
146183

147-
def _get_next_batch(txn):
184+
def _get_next_batch(
185+
txn: LoggingTransaction,
186+
) -> Optional[Sequence[Tuple[str, int]]]:
148187
# Only fetch 250 rooms, so we don't fetch too many at once, even
149188
# if those 250 rooms have less than batch_size state events.
150189
sql = """
@@ -155,15 +194,17 @@ def _get_next_batch(txn):
155194
TEMP_TABLE + "_rooms",
156195
)
157196
txn.execute(sql)
158-
rooms_to_work_on = txn.fetchall()
197+
rooms_to_work_on = cast(List[Tuple[str, int]], txn.fetchall())
159198

160199
if not rooms_to_work_on:
161200
return None
162201

163202
# Get how many are left to process, so we can give status on how
164203
# far we are in processing
165204
txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms")
166-
progress["remaining"] = txn.fetchone()[0]
205+
result = txn.fetchone()
206+
assert result is not None
207+
progress["remaining"] = result[0]
167208

168209
return rooms_to_work_on
169210

@@ -261,29 +302,33 @@ def _get_next_batch(txn):
261302

262303
return processed_event_count
263304

264-
async def _populate_user_directory_process_users(self, progress, batch_size):
305+
async def _populate_user_directory_process_users(
306+
self, progress: JsonDict, batch_size: int
307+
) -> int:
265308
"""
266309
Add all local users to the user directory.
267310
"""
268311

269-
def _get_next_batch(txn):
312+
def _get_next_batch(txn: LoggingTransaction) -> Optional[List[str]]:
270313
sql = "SELECT user_id FROM %s LIMIT %s" % (
271314
TEMP_TABLE + "_users",
272315
str(batch_size),
273316
)
274317
txn.execute(sql)
275-
users_to_work_on = txn.fetchall()
318+
user_result = cast(List[Tuple[str]], txn.fetchall())
276319

277-
if not users_to_work_on:
320+
if not user_result:
278321
return None
279322

280-
users_to_work_on = [x[0] for x in users_to_work_on]
323+
users_to_work_on = [x[0] for x in user_result]
281324

282325
# Get how many are left to process, so we can give status on how
283326
# far we are in processing
284327
sql = "SELECT COUNT(*) FROM " + TEMP_TABLE + "_users"
285328
txn.execute(sql)
286-
progress["remaining"] = txn.fetchone()[0]
329+
count_result = txn.fetchone()
330+
assert count_result is not None
331+
progress["remaining"] = count_result[0]
287332

288333
return users_to_work_on
289334

@@ -324,7 +369,7 @@ def _get_next_batch(txn):
324369

325370
return len(users_to_work_on)
326371

327-
async def is_room_world_readable_or_publicly_joinable(self, room_id):
372+
async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool:
328373
"""Check if the room is either world_readable or publically joinable"""
329374

330375
# Create a state filter that only queries join and history state event
@@ -368,7 +413,7 @@ async def update_profile_in_user_dir(
368413
if not isinstance(avatar_url, str):
369414
avatar_url = None
370415

371-
def _update_profile_in_user_dir_txn(txn):
416+
def _update_profile_in_user_dir_txn(txn: LoggingTransaction) -> None:
372417
self.db_pool.simple_upsert_txn(
373418
txn,
374419
table="user_directory",
@@ -435,7 +480,7 @@ async def add_users_who_share_private_room(
435480
for user_id, other_user_id in user_id_tuples
436481
],
437482
value_names=(),
438-
value_values=None,
483+
value_values=(),
439484
desc="add_users_who_share_room",
440485
)
441486

@@ -454,14 +499,14 @@ async def add_users_in_public_rooms(
454499
key_names=["user_id", "room_id"],
455500
key_values=[(user_id, room_id) for user_id in user_ids],
456501
value_names=(),
457-
value_values=None,
502+
value_values=(),
458503
desc="add_users_in_public_rooms",
459504
)
460505

461506
async def delete_all_from_user_dir(self) -> None:
462507
"""Delete the entire user directory"""
463508

464-
def _delete_all_from_user_dir_txn(txn):
509+
def _delete_all_from_user_dir_txn(txn: LoggingTransaction) -> None:
465510
txn.execute("DELETE FROM user_directory")
466511
txn.execute("DELETE FROM user_directory_search")
467512
txn.execute("DELETE FROM users_in_public_rooms")
@@ -473,7 +518,7 @@ def _delete_all_from_user_dir_txn(txn):
473518
)
474519

475520
@cached()
476-
async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, Any]]:
521+
async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, str]]:
477522
return await self.db_pool.simple_select_one(
478523
table="user_directory",
479524
keyvalues={"user_id": user_id},
@@ -497,7 +542,12 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
497542
# add_users_who_share_private_rooms?
498543
SHARE_PRIVATE_WORKING_SET = 500
499544

500-
def __init__(self, database: DatabasePool, db_conn, hs):
545+
def __init__(
546+
self,
547+
database: DatabasePool,
548+
db_conn: Connection,
549+
hs: "HomeServer",
550+
) -> None:
501551
super().__init__(database, db_conn, hs)
502552

503553
self._prefer_local_users_in_search = (
@@ -506,7 +556,7 @@ def __init__(self, database: DatabasePool, db_conn, hs):
506556
self._server_name = hs.config.server.server_name
507557

508558
async def remove_from_user_dir(self, user_id: str) -> None:
509-
def _remove_from_user_dir_txn(txn):
559+
def _remove_from_user_dir_txn(txn: LoggingTransaction) -> None:
510560
self.db_pool.simple_delete_txn(
511561
txn, table="user_directory", keyvalues={"user_id": user_id}
512562
)
@@ -532,7 +582,7 @@ def _remove_from_user_dir_txn(txn):
532582
"remove_from_user_dir", _remove_from_user_dir_txn
533583
)
534584

535-
async def get_users_in_dir_due_to_room(self, room_id):
585+
async def get_users_in_dir_due_to_room(self, room_id: str) -> Set[str]:
536586
"""Get all user_ids that are in the room directory because they're
537587
in the given room_id
538588
"""
@@ -565,7 +615,7 @@ async def remove_user_who_share_room(self, user_id: str, room_id: str) -> None:
565615
room_id
566616
"""
567617

568-
def _remove_user_who_share_room_txn(txn):
618+
def _remove_user_who_share_room_txn(txn: LoggingTransaction) -> None:
569619
self.db_pool.simple_delete_txn(
570620
txn,
571621
table="users_who_share_private_rooms",
@@ -586,7 +636,7 @@ def _remove_user_who_share_room_txn(txn):
586636
"remove_user_who_share_room", _remove_user_who_share_room_txn
587637
)
588638

589-
async def get_user_dir_rooms_user_is_in(self, user_id):
639+
async def get_user_dir_rooms_user_is_in(self, user_id: str) -> List[str]:
590640
"""
591641
Returns the rooms that a user is in.
592642
@@ -628,7 +678,9 @@ async def get_shared_rooms_for_users(
628678
A set of room ID's that the users share.
629679
"""
630680

631-
def _get_shared_rooms_for_users_txn(txn):
681+
def _get_shared_rooms_for_users_txn(
682+
txn: LoggingTransaction,
683+
) -> List[Dict[str, str]]:
632684
txn.execute(
633685
"""
634686
SELECT p1.room_id
@@ -669,7 +721,9 @@ async def get_user_directory_stream_pos(self) -> Optional[int]:
669721
desc="get_user_directory_stream_pos",
670722
)
671723

672-
async def search_user_dir(self, user_id, search_term, limit):
724+
async def search_user_dir(
725+
self, user_id: str, search_term: str, limit: int
726+
) -> JsonDict:
673727
"""Searches for users in directory
674728
675729
Returns:
@@ -705,7 +759,7 @@ async def search_user_dir(self, user_id, search_term, limit):
705759
# We allow manipulating the ranking algorithm by injecting statements
706760
# based on config options.
707761
additional_ordering_statements = []
708-
ordering_arguments = ()
762+
ordering_arguments: Tuple[str, ...] = ()
709763

710764
if isinstance(self.database_engine, PostgresEngine):
711765
full_query, exact_query, prefix_query = _parse_query_postgres(search_term)
@@ -811,7 +865,7 @@ async def search_user_dir(self, user_id, search_term, limit):
811865
return {"limited": limited, "results": results}
812866

813867

814-
def _parse_query_sqlite(search_term):
868+
def _parse_query_sqlite(search_term: str) -> str:
815869
"""Takes a plain unicode string from the user and converts it into a form
816870
that can be passed to database.
817871
We use this so that we can add prefix matching, which isn't something
@@ -826,7 +880,7 @@ def _parse_query_sqlite(search_term):
826880
return " & ".join("(%s* OR %s)" % (result, result) for result in results)
827881

828882

829-
def _parse_query_postgres(search_term):
883+
def _parse_query_postgres(search_term: str) -> Tuple[str, str, str]:
830884
"""Takes a plain unicode string from the user and converts it into a form
831885
that can be passed to database.
832886
We use this so that we can add prefix matching, which isn't something

tests/handlers/test_user_directory.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from typing import List, Tuple
1415
from unittest.mock import Mock
1516
from urllib.parse import quote
1617

@@ -325,7 +326,7 @@ def _compress_shared(self, shared):
325326
r.add((i["user_id"], i["other_user_id"], i["room_id"]))
326327
return r
327328

328-
def get_users_in_public_rooms(self):
329+
def get_users_in_public_rooms(self) -> List[Tuple[str, str]]:
329330
r = self.get_success(
330331
self.store.db_pool.simple_select_list(
331332
"users_in_public_rooms", None, ("user_id", "room_id")
@@ -336,7 +337,7 @@ def get_users_in_public_rooms(self):
336337
retval.append((i["user_id"], i["room_id"]))
337338
return retval
338339

339-
def get_users_who_share_private_rooms(self):
340+
def get_users_who_share_private_rooms(self) -> List[Tuple[str, str, str]]:
340341
return self.get_success(
341342
self.store.db_pool.simple_select_list(
342343
"users_who_share_private_rooms",

0 commit comments

Comments
 (0)