14
14
15
15
import logging
16
16
import re
17
- from typing import Any , Dict , Iterable , Optional , Set , Tuple
17
+ from typing import (
18
+ TYPE_CHECKING ,
19
+ Dict ,
20
+ Iterable ,
21
+ List ,
22
+ Optional ,
23
+ Sequence ,
24
+ Set ,
25
+ Tuple ,
26
+ cast ,
27
+ )
28
+
29
+ if TYPE_CHECKING :
30
+ from synapse .server import HomeServer
18
31
19
32
from synapse .api .constants import EventTypes , HistoryVisibility , JoinRules
20
- from synapse .storage .database import DatabasePool
33
+ from synapse .storage .database import DatabasePool , LoggingTransaction
21
34
from synapse .storage .databases .main .state import StateFilter
22
35
from synapse .storage .databases .main .state_deltas import StateDeltasStore
23
36
from synapse .storage .engines import PostgresEngine , Sqlite3Engine
24
- from synapse .types import get_domain_from_id , get_localpart_from_id
37
+ from synapse .storage .types import Connection
38
+ from synapse .types import JsonDict , get_domain_from_id , get_localpart_from_id
25
39
from synapse .util .caches .descriptors import cached
26
40
27
41
logger = logging .getLogger (__name__ )
@@ -36,7 +50,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
36
50
# add_users_who_share_private_rooms?
37
51
SHARE_PRIVATE_WORKING_SET = 500
38
52
39
- def __init__ (self , database : DatabasePool , db_conn , hs ):
53
+ def __init__ (
54
+ self ,
55
+ database : DatabasePool ,
56
+ db_conn : Connection ,
57
+ hs : "HomeServer" ,
58
+ ):
40
59
super ().__init__ (database , db_conn , hs )
41
60
42
61
self .server_name = hs .hostname
@@ -57,10 +76,12 @@ def __init__(self, database: DatabasePool, db_conn, hs):
57
76
"populate_user_directory_cleanup" , self ._populate_user_directory_cleanup
58
77
)
59
78
60
- async def _populate_user_directory_createtables (self , progress , batch_size ):
79
+ async def _populate_user_directory_createtables (
80
+ self , progress : JsonDict , batch_size : int
81
+ ) -> int :
61
82
62
83
# Get all the rooms that we want to process.
63
- def _make_staging_area (txn ) :
84
+ def _make_staging_area (txn : LoggingTransaction ) -> None :
64
85
sql = (
65
86
"CREATE TABLE IF NOT EXISTS "
66
87
+ TEMP_TABLE
@@ -110,16 +131,20 @@ def _make_staging_area(txn):
110
131
)
111
132
return 1
112
133
113
- async def _populate_user_directory_cleanup (self , progress , batch_size ):
134
+ async def _populate_user_directory_cleanup (
135
+ self ,
136
+ progress : JsonDict ,
137
+ batch_size : int ,
138
+ ) -> int :
114
139
"""
115
140
Update the user directory stream position, then clean up the old tables.
116
141
"""
117
142
position = await self .db_pool .simple_select_one_onecol (
118
- TEMP_TABLE + "_position" , None , "position"
143
+ TEMP_TABLE + "_position" , {} , "position"
119
144
)
120
145
await self .update_user_directory_stream_pos (position )
121
146
122
- def _delete_staging_area (txn ) :
147
+ def _delete_staging_area (txn : LoggingTransaction ) -> None :
123
148
txn .execute ("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms" )
124
149
txn .execute ("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users" )
125
150
txn .execute ("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position" )
@@ -133,18 +158,32 @@ def _delete_staging_area(txn):
133
158
)
134
159
return 1
135
160
136
- async def _populate_user_directory_process_rooms (self , progress , batch_size ):
161
+ async def _populate_user_directory_process_rooms (
162
+ self , progress : JsonDict , batch_size : int
163
+ ) -> int :
137
164
"""
165
+ Rescan the state of all rooms so we can track
166
+
167
+ - who's in a public room;
168
+ - which local users share a private room with other users (local
169
+ and remote); and
170
+ - who should be in the user_directory.
171
+
138
172
Args:
139
173
progress (dict)
140
174
batch_size (int): Maximum number of state events to process
141
175
per cycle.
176
+
177
+ Returns:
178
+ number of events processed.
142
179
"""
143
180
# If we don't have progress filed, delete everything.
144
181
if not progress :
145
182
await self .delete_all_from_user_dir ()
146
183
147
- def _get_next_batch (txn ):
184
+ def _get_next_batch (
185
+ txn : LoggingTransaction ,
186
+ ) -> Optional [Sequence [Tuple [str , int ]]]:
148
187
# Only fetch 250 rooms, so we don't fetch too many at once, even
149
188
# if those 250 rooms have less than batch_size state events.
150
189
sql = """
@@ -155,15 +194,17 @@ def _get_next_batch(txn):
155
194
TEMP_TABLE + "_rooms" ,
156
195
)
157
196
txn .execute (sql )
158
- rooms_to_work_on = txn .fetchall ()
197
+ rooms_to_work_on = cast ( List [ Tuple [ str , int ]], txn .fetchall () )
159
198
160
199
if not rooms_to_work_on :
161
200
return None
162
201
163
202
# Get how many are left to process, so we can give status on how
164
203
# far we are in processing
165
204
txn .execute ("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms" )
166
- progress ["remaining" ] = txn .fetchone ()[0 ]
205
+ result = txn .fetchone ()
206
+ assert result is not None
207
+ progress ["remaining" ] = result [0 ]
167
208
168
209
return rooms_to_work_on
169
210
@@ -261,29 +302,33 @@ def _get_next_batch(txn):
261
302
262
303
return processed_event_count
263
304
264
- async def _populate_user_directory_process_users (self , progress , batch_size ):
305
+ async def _populate_user_directory_process_users (
306
+ self , progress : JsonDict , batch_size : int
307
+ ) -> int :
265
308
"""
266
309
Add all local users to the user directory.
267
310
"""
268
311
269
- def _get_next_batch (txn ) :
312
+ def _get_next_batch (txn : LoggingTransaction ) -> Optional [ List [ str ]] :
270
313
sql = "SELECT user_id FROM %s LIMIT %s" % (
271
314
TEMP_TABLE + "_users" ,
272
315
str (batch_size ),
273
316
)
274
317
txn .execute (sql )
275
- users_to_work_on = txn .fetchall ()
318
+ user_result = cast ( List [ Tuple [ str ]], txn .fetchall () )
276
319
277
- if not users_to_work_on :
320
+ if not user_result :
278
321
return None
279
322
280
- users_to_work_on = [x [0 ] for x in users_to_work_on ]
323
+ users_to_work_on = [x [0 ] for x in user_result ]
281
324
282
325
# Get how many are left to process, so we can give status on how
283
326
# far we are in processing
284
327
sql = "SELECT COUNT(*) FROM " + TEMP_TABLE + "_users"
285
328
txn .execute (sql )
286
- progress ["remaining" ] = txn .fetchone ()[0 ]
329
+ count_result = txn .fetchone ()
330
+ assert count_result is not None
331
+ progress ["remaining" ] = count_result [0 ]
287
332
288
333
return users_to_work_on
289
334
@@ -324,7 +369,7 @@ def _get_next_batch(txn):
324
369
325
370
return len (users_to_work_on )
326
371
327
- async def is_room_world_readable_or_publicly_joinable (self , room_id ) :
372
+ async def is_room_world_readable_or_publicly_joinable (self , room_id : str ) -> bool :
328
373
"""Check if the room is either world_readable or publically joinable"""
329
374
330
375
# Create a state filter that only queries join and history state event
@@ -368,7 +413,7 @@ async def update_profile_in_user_dir(
368
413
if not isinstance (avatar_url , str ):
369
414
avatar_url = None
370
415
371
- def _update_profile_in_user_dir_txn (txn ) :
416
+ def _update_profile_in_user_dir_txn (txn : LoggingTransaction ) -> None :
372
417
self .db_pool .simple_upsert_txn (
373
418
txn ,
374
419
table = "user_directory" ,
@@ -435,7 +480,7 @@ async def add_users_who_share_private_room(
435
480
for user_id , other_user_id in user_id_tuples
436
481
],
437
482
value_names = (),
438
- value_values = None ,
483
+ value_values = () ,
439
484
desc = "add_users_who_share_room" ,
440
485
)
441
486
@@ -454,14 +499,14 @@ async def add_users_in_public_rooms(
454
499
key_names = ["user_id" , "room_id" ],
455
500
key_values = [(user_id , room_id ) for user_id in user_ids ],
456
501
value_names = (),
457
- value_values = None ,
502
+ value_values = () ,
458
503
desc = "add_users_in_public_rooms" ,
459
504
)
460
505
461
506
async def delete_all_from_user_dir (self ) -> None :
462
507
"""Delete the entire user directory"""
463
508
464
- def _delete_all_from_user_dir_txn (txn ) :
509
+ def _delete_all_from_user_dir_txn (txn : LoggingTransaction ) -> None :
465
510
txn .execute ("DELETE FROM user_directory" )
466
511
txn .execute ("DELETE FROM user_directory_search" )
467
512
txn .execute ("DELETE FROM users_in_public_rooms" )
@@ -473,7 +518,7 @@ def _delete_all_from_user_dir_txn(txn):
473
518
)
474
519
475
520
@cached ()
476
- async def get_user_in_directory (self , user_id : str ) -> Optional [Dict [str , Any ]]:
521
+ async def get_user_in_directory (self , user_id : str ) -> Optional [Dict [str , str ]]:
477
522
return await self .db_pool .simple_select_one (
478
523
table = "user_directory" ,
479
524
keyvalues = {"user_id" : user_id },
@@ -497,7 +542,12 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
497
542
# add_users_who_share_private_rooms?
498
543
SHARE_PRIVATE_WORKING_SET = 500
499
544
500
- def __init__ (self , database : DatabasePool , db_conn , hs ):
545
+ def __init__ (
546
+ self ,
547
+ database : DatabasePool ,
548
+ db_conn : Connection ,
549
+ hs : "HomeServer" ,
550
+ ) -> None :
501
551
super ().__init__ (database , db_conn , hs )
502
552
503
553
self ._prefer_local_users_in_search = (
@@ -506,7 +556,7 @@ def __init__(self, database: DatabasePool, db_conn, hs):
506
556
self ._server_name = hs .config .server .server_name
507
557
508
558
async def remove_from_user_dir (self , user_id : str ) -> None :
509
- def _remove_from_user_dir_txn (txn ) :
559
+ def _remove_from_user_dir_txn (txn : LoggingTransaction ) -> None :
510
560
self .db_pool .simple_delete_txn (
511
561
txn , table = "user_directory" , keyvalues = {"user_id" : user_id }
512
562
)
@@ -532,7 +582,7 @@ def _remove_from_user_dir_txn(txn):
532
582
"remove_from_user_dir" , _remove_from_user_dir_txn
533
583
)
534
584
535
- async def get_users_in_dir_due_to_room (self , room_id ) :
585
+ async def get_users_in_dir_due_to_room (self , room_id : str ) -> Set [ str ] :
536
586
"""Get all user_ids that are in the room directory because they're
537
587
in the given room_id
538
588
"""
@@ -565,7 +615,7 @@ async def remove_user_who_share_room(self, user_id: str, room_id: str) -> None:
565
615
room_id
566
616
"""
567
617
568
- def _remove_user_who_share_room_txn (txn ) :
618
+ def _remove_user_who_share_room_txn (txn : LoggingTransaction ) -> None :
569
619
self .db_pool .simple_delete_txn (
570
620
txn ,
571
621
table = "users_who_share_private_rooms" ,
@@ -586,7 +636,7 @@ def _remove_user_who_share_room_txn(txn):
586
636
"remove_user_who_share_room" , _remove_user_who_share_room_txn
587
637
)
588
638
589
- async def get_user_dir_rooms_user_is_in (self , user_id ) :
639
+ async def get_user_dir_rooms_user_is_in (self , user_id : str ) -> List [ str ] :
590
640
"""
591
641
Returns the rooms that a user is in.
592
642
@@ -628,7 +678,9 @@ async def get_shared_rooms_for_users(
628
678
A set of room ID's that the users share.
629
679
"""
630
680
631
- def _get_shared_rooms_for_users_txn (txn ):
681
+ def _get_shared_rooms_for_users_txn (
682
+ txn : LoggingTransaction ,
683
+ ) -> List [Dict [str , str ]]:
632
684
txn .execute (
633
685
"""
634
686
SELECT p1.room_id
@@ -669,7 +721,9 @@ async def get_user_directory_stream_pos(self) -> Optional[int]:
669
721
desc = "get_user_directory_stream_pos" ,
670
722
)
671
723
672
- async def search_user_dir (self , user_id , search_term , limit ):
724
+ async def search_user_dir (
725
+ self , user_id : str , search_term : str , limit : int
726
+ ) -> JsonDict :
673
727
"""Searches for users in directory
674
728
675
729
Returns:
@@ -705,7 +759,7 @@ async def search_user_dir(self, user_id, search_term, limit):
705
759
# We allow manipulating the ranking algorithm by injecting statements
706
760
# based on config options.
707
761
additional_ordering_statements = []
708
- ordering_arguments = ()
762
+ ordering_arguments : Tuple [ str , ...] = ()
709
763
710
764
if isinstance (self .database_engine , PostgresEngine ):
711
765
full_query , exact_query , prefix_query = _parse_query_postgres (search_term )
@@ -811,7 +865,7 @@ async def search_user_dir(self, user_id, search_term, limit):
811
865
return {"limited" : limited , "results" : results }
812
866
813
867
814
- def _parse_query_sqlite (search_term ) :
868
+ def _parse_query_sqlite (search_term : str ) -> str :
815
869
"""Takes a plain unicode string from the user and converts it into a form
816
870
that can be passed to database.
817
871
We use this so that we can add prefix matching, which isn't something
@@ -826,7 +880,7 @@ def _parse_query_sqlite(search_term):
826
880
return " & " .join ("(%s* OR %s)" % (result , result ) for result in results )
827
881
828
882
829
- def _parse_query_postgres (search_term ) :
883
+ def _parse_query_postgres (search_term : str ) -> Tuple [ str , str , str ] :
830
884
"""Takes a plain unicode string from the user and converts it into a form
831
885
that can be passed to database.
832
886
We use this so that we can add prefix matching, which isn't something
0 commit comments