diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 9062fac91a4..d5a8bc01266 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -23,7 +23,7 @@ from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Set, Tuple from synapse.api import errors -from synapse.api.constants import EduTypes, EventTypes +from synapse.api.constants import EduTypes, EventTypes, Membership from synapse.api.errors import ( Codes, FederationDeniedError, @@ -33,11 +33,13 @@ SynapseError, ) from synapse.logging.opentracing import log_kv, set_tag, trace +import synapse.metrics from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, ) from synapse.storage.databases.main.client_ips import DeviceLastConnectionInfo +from synapse.storage.databases.main.state_deltas import StateDelta from synapse.types import ( JsonDict, JsonMapping, @@ -54,7 +56,7 @@ from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.cancellation import cancellable -from synapse.util.metrics import measure_func +from synapse.util.metrics import measure_func, Measure from synapse.util.retryutils import ( NotRetryingDestination, filter_destinations_by_retry_limiter, @@ -428,6 +430,7 @@ def __init__(self, hs: "HomeServer"): self._account_data_handler = hs.get_account_data_handler() self._storage_controllers = hs.get_storage_controllers() self.db_pool = hs.get_datastores().main.db_pool + self._is_processing = False self.device_list_updater = DeviceListUpdater(hs, self) @@ -461,6 +464,145 @@ def __init__(self, hs: "HomeServer"): self._delete_stale_devices, ) + # Listen for state delta updates. We do this so we can send device list updates on room join + # to remote servers. We do not remember where we got up to before, as we only need to send + # these updates on a best-effort basis, as they quickly heal due to /keys/query requests. + # We want to send device list updates eagerly to improve our robustness on unreliable + # networks. + # See https://github.com/element-hq/synapse/issues/11374#issuecomment-1908396300 + self._event_pos = self.store.get_room_max_stream_ordering() + self._event_processing = False + self.notifier.add_replication_callback(self.notify_new_event) + + def notify_new_event(self) -> None: + """Called when there may be more deltas to process""" + if self._event_processing: + return + + self._event_processing = True + + async def process() -> None: + try: + await self._unsafe_process() + finally: + self._event_processing = False + + run_as_background_process("device.notify_new_event", process) + + async def _unsafe_process(self) -> None: + # Loop round handling deltas until we're up to date + while True: + with Measure(self.clock, "device_list_delta"): + room_max_stream_ordering = self.store.get_room_max_stream_ordering() + if self._event_pos == room_max_stream_ordering: + return + + logger.debug( + "Processing device list stats %s->%s", + self._event_pos, + room_max_stream_ordering, + ) + ( + max_pos, + deltas, + ) = await self._storage_controllers.state.get_current_state_deltas( + self._event_pos, room_max_stream_ordering + ) + + # We may get multiple deltas for different rooms, but we want to + # handle them on a room by room basis, so we batch them up by + # room. + deltas_by_room: Dict[str, List[StateDelta]] = {} + for delta in deltas: + deltas_by_room.setdefault(delta.room_id, []).append(delta) + + for room_id, deltas_for_room in deltas_by_room.items(): + newly_joined_local_users = await self._get_newly_joined_local_users(room_id, deltas_for_room) + if not newly_joined_local_users: + continue + # if a local user newly joins a room, we want to broadcast their device lists to + # federated servers in that room, if we haven't already. + hosts = await self.store.get_current_hosts_in_room(room_id) + # filter out ourselves + hosts = [h for h in hosts if not self.hs.is_mine_server_name(h)] + if len(hosts) == 0: + continue + # broadcast device lists for these users in the room + num_pokes = 0 + for user_id in newly_joined_local_users: + # the join is for the user, we need to send device list updates for all + # their devices. + device_ids = await self.store.get_devices_by_user(user_id) + for device_id in device_ids.keys(): + num_pokes += 1 + await self.store.add_device_list_outbound_pokes( + user_id=user_id, + device_id=device_id, + room_id=room_id, + hosts=hosts, + context=None, + ) + logger.info( + "Found %d hosts to send device list updates to for a new room join, " + + "added %s device_list_outbound_pokes", + len(hosts), num_pokes, + ) + + # Notify things that device lists need to be sent out. + self.notifier.notify_replication() + await self.federation_sender.send_device_messages( + hosts, immediate=False + ) + + self._event_pos = max_pos + + # Expose current event processing position to prometheus + synapse.metrics.event_processing_positions.labels("device").set( + max_pos + ) + + async def _get_newly_joined_local_users(self, room_id: str, deltas: List[StateDelta]) -> Optional[Set[str]]: + """Process current state deltas for the room to find new joins that need + to be handled. + """ + newly_joined_local_users = set() + + for delta in deltas: + assert room_id == delta.room_id + logger.debug( + "device.handling: %r %r, %s", delta.event_type, delta.state_key, delta.event_id + ) + # Drop any event that isn't a membership join + if delta.event_type != EventTypes.Member: + continue + if delta.event_id is None: + # state has been deleted, so this is not a join. We only care about joins. + continue + # Drop any event that is for a non-local user + membership_change_user = UserID.from_string(delta.state_key) + if not self.hs.is_mine(membership_change_user): + continue + event = await self.store.get_event(delta.event_id, allow_none=True) + if not event or event.content.get("membership") != Membership.JOIN: + # We only care about joins + continue + if delta.prev_event_id: + prev_event = await self.store.get_event( + delta.prev_event_id, allow_none=True + ) + if ( + prev_event + and prev_event.content.get("membership") == Membership.JOIN + ): + # Ignore changes to join events. + continue + newly_joined_local_users.add(delta.state_key) + + if not newly_joined_local_users: + # If nobody has joined then there's nothing to do. + return + return newly_joined_local_users + def _check_device_name_length(self, name: Optional[str]) -> None: """ Checks whether a device name is longer than the maximum allowed length.