@@ -201,95 +201,19 @@ async def query_devices(
201
201
r [user_id ] = remote_queries [user_id ]
202
202
203
203
# Now fetch any devices that we don't have in our cache
204
- @trace
205
- async def do_remote_query (destination : str ) -> None :
206
- """This is called when we are querying the device list of a user on
207
- a remote homeserver and their device list is not in the device list
208
- cache. If we share a room with this user and we're not querying for
209
- specific user we will update the cache with their device list.
210
- """
211
-
212
- destination_query = remote_queries_not_in_cache [destination ]
213
-
214
- # We first consider whether we wish to update the device list cache with
215
- # the users device list. We want to track a user's devices when the
216
- # authenticated user shares a room with the queried user and the query
217
- # has not specified a particular device.
218
- # If we update the cache for the queried user we remove them from further
219
- # queries. We use the more efficient batched query_client_keys for all
220
- # remaining users
221
- user_ids_updated = []
222
- for (user_id , device_list ) in destination_query .items ():
223
- if user_id in user_ids_updated :
224
- continue
225
-
226
- if device_list :
227
- continue
228
-
229
- room_ids = await self .store .get_rooms_for_user (user_id )
230
- if not room_ids :
231
- continue
232
-
233
- # We've decided we're sharing a room with this user and should
234
- # probably be tracking their device lists. However, we haven't
235
- # done an initial sync on the device list so we do it now.
236
- try :
237
- if self ._is_master :
238
- user_devices = await self .device_handler .device_list_updater .user_device_resync (
239
- user_id
240
- )
241
- else :
242
- user_devices = await self ._user_device_resync_client (
243
- user_id = user_id
244
- )
245
-
246
- user_devices = user_devices ["devices" ]
247
- user_results = results .setdefault (user_id , {})
248
- for device in user_devices :
249
- user_results [device ["device_id" ]] = device ["keys" ]
250
- user_ids_updated .append (user_id )
251
- except Exception as e :
252
- failures [destination ] = _exception_to_failure (e )
253
-
254
- if len (destination_query ) == len (user_ids_updated ):
255
- # We've updated all the users in the query and we do not need to
256
- # make any further remote calls.
257
- return
258
-
259
- # Remove all the users from the query which we have updated
260
- for user_id in user_ids_updated :
261
- destination_query .pop (user_id )
262
-
263
- try :
264
- remote_result = await self .federation .query_client_keys (
265
- destination , {"device_keys" : destination_query }, timeout = timeout
266
- )
267
-
268
- for user_id , keys in remote_result ["device_keys" ].items ():
269
- if user_id in destination_query :
270
- results [user_id ] = keys
271
-
272
- if "master_keys" in remote_result :
273
- for user_id , key in remote_result ["master_keys" ].items ():
274
- if user_id in destination_query :
275
- cross_signing_keys ["master_keys" ][user_id ] = key
276
-
277
- if "self_signing_keys" in remote_result :
278
- for user_id , key in remote_result ["self_signing_keys" ].items ():
279
- if user_id in destination_query :
280
- cross_signing_keys ["self_signing_keys" ][user_id ] = key
281
-
282
- except Exception as e :
283
- failure = _exception_to_failure (e )
284
- failures [destination ] = failure
285
- set_tag ("error" , True )
286
- set_tag ("reason" , failure )
287
-
288
204
await make_deferred_yieldable (
289
205
defer .gatherResults (
290
206
[
291
- run_in_background (do_remote_query , destination )
292
- for destination in remote_queries_not_in_cache
207
+ run_in_background (
208
+ self ._query_devices_for_destination ,
209
+ results ,
210
+ cross_signing_keys ,
211
+ failures ,
212
+ destination ,
213
+ queries ,
214
+ timeout ,
215
+ )
216
+ for destination , queries in remote_queries_not_in_cache .items ()
293
217
],
294
218
consumeErrors = True ,
295
219
).addErrback (unwrapFirstError )
@@ -301,6 +225,121 @@ async def do_remote_query(destination: str) -> None:
301
225
302
226
return ret
303
227
228
+ @trace
229
+ async def _query_devices_for_destination (
230
+ self ,
231
+ results : JsonDict ,
232
+ cross_signing_keys : JsonDict ,
233
+ failures : Dict [str , JsonDict ],
234
+ destination : str ,
235
+ destination_query : Dict [str , Iterable [str ]],
236
+ timeout : int ,
237
+ ) -> None :
238
+ """This is called when we are querying the device list of a user on
239
+ a remote homeserver and their device list is not in the device list
240
+ cache. If we share a room with this user and we're not querying for
241
+ specific user we will update the cache with their device list.
242
+
243
+ Args:
244
+ results: A map from user ID to their device keys, which gets
245
+ updated with the newly fetched keys.
246
+ cross_signing_keys: Map from user ID to their cross signing keys,
247
+ which gets updated with the newly fetched keys.
248
+ failures: Map of destinations to failures that have occurred while
249
+ attempting to fetch keys.
250
+ destination: The remote server to query
251
+ destination_query: The query dict of devices to query the remote
252
+ server for.
253
+ timeout: The timeout for remote HTTP requests.
254
+ """
255
+
256
+ # We first consider whether we wish to update the device list cache with
257
+ # the users device list. We want to track a user's devices when the
258
+ # authenticated user shares a room with the queried user and the query
259
+ # has not specified a particular device.
260
+ # If we update the cache for the queried user we remove them from further
261
+ # queries. We use the more efficient batched query_client_keys for all
262
+ # remaining users
263
+ user_ids_updated = []
264
+ for (user_id , device_list ) in destination_query .items ():
265
+ if user_id in user_ids_updated :
266
+ continue
267
+
268
+ if device_list :
269
+ continue
270
+
271
+ room_ids = await self .store .get_rooms_for_user (user_id )
272
+ if not room_ids :
273
+ continue
274
+
275
+ # We've decided we're sharing a room with this user and should
276
+ # probably be tracking their device lists. However, we haven't
277
+ # done an initial sync on the device list so we do it now.
278
+ try :
279
+ if self ._is_master :
280
+ resync_results = await self .device_handler .device_list_updater .user_device_resync (
281
+ user_id
282
+ )
283
+ else :
284
+ resync_results = await self ._user_device_resync_client (
285
+ user_id = user_id
286
+ )
287
+
288
+ # Add the device keys to the results.
289
+ user_devices = resync_results ["devices" ]
290
+ user_results = results .setdefault (user_id , {})
291
+ for device in user_devices :
292
+ user_results [device ["device_id" ]] = device ["keys" ]
293
+ user_ids_updated .append (user_id )
294
+
295
+ # Add any cross signing keys to the results.
296
+ master_key = resync_results .get ("master_key" )
297
+ self_signing_key = resync_results .get ("self_signing_key" )
298
+
299
+ if master_key :
300
+ cross_signing_keys ["master_keys" ][user_id ] = master_key
301
+
302
+ if self_signing_key :
303
+ cross_signing_keys ["self_signing_keys" ][user_id ] = self_signing_key
304
+ except Exception as e :
305
+ failures [destination ] = _exception_to_failure (e )
306
+
307
+ if len (destination_query ) == len (user_ids_updated ):
308
+ # We've updated all the users in the query and we do not need to
309
+ # make any further remote calls.
310
+ return
311
+
312
+ # Remove all the users from the query which we have updated
313
+ for user_id in user_ids_updated :
314
+ destination_query .pop (user_id )
315
+
316
+ try :
317
+ remote_result = await self .federation .query_client_keys (
318
+ destination , {"device_keys" : destination_query }, timeout = timeout
319
+ )
320
+
321
+ for user_id , keys in remote_result ["device_keys" ].items ():
322
+ if user_id in destination_query :
323
+ results [user_id ] = keys
324
+
325
+ if "master_keys" in remote_result :
326
+ for user_id , key in remote_result ["master_keys" ].items ():
327
+ if user_id in destination_query :
328
+ cross_signing_keys ["master_keys" ][user_id ] = key
329
+
330
+ if "self_signing_keys" in remote_result :
331
+ for user_id , key in remote_result ["self_signing_keys" ].items ():
332
+ if user_id in destination_query :
333
+ cross_signing_keys ["self_signing_keys" ][user_id ] = key
334
+
335
+ except Exception as e :
336
+ failure = _exception_to_failure (e )
337
+ failures [destination ] = failure
338
+ set_tag ("error" , True )
339
+ set_tag ("reason" , failure )
340
+
341
+ return
342
+
304
343
async def get_cross_signing_keys_from_cache (
305
344
self , query : Iterable [str ], from_user_id : Optional [str ]
306
345
) -> Dict [str , Dict [str , dict ]]:
0 commit comments