Compare commits

...

5 Commits

Author SHA1 Message Date
Erik Johnston
6a0d2dc6fc Only check for all rooms if not outbound poke 2024-05-18 16:12:46 +01:00
Erik Johnston
e6d3d808aa Newsfile 2024-05-18 12:33:46 +01:00
Erik Johnston
cf474a094f Add stream change cache for device lists in room 2024-05-18 12:33:26 +01:00
Erik Johnston
5b2b3120c2 Cap the top stream ID when fetching changed devices 2024-05-18 12:27:27 +01:00
Erik Johnston
bec0313e1b Improve perf of sync device lists (#17191)
It's almost always more efficient to query the rooms that have device
list changes, rather than looking at the list of all users whose devices
have changed and then look for shared rooms.
2024-05-18 12:27:12 +01:00
5 changed files with 103 additions and 62 deletions

1
changelog.d/17216.misc Normal file
View File

@@ -0,0 +1 @@
Improve performance of calculating device lists changes in `/sync`.

View File

@@ -159,20 +159,32 @@ class DeviceWorkerHandler:
@cancellable
async def get_device_changes_in_shared_rooms(
self, user_id: str, room_ids: StrCollection, from_token: StreamToken
self,
user_id: str,
room_ids: StrCollection,
from_token: StreamToken,
now_token: Optional[StreamToken] = None,
) -> Set[str]:
"""Get the set of users whose devices have changed who share a room with
the given user.
"""
now_device_lists_key = self.store.get_device_stream_token()
if now_token:
now_device_lists_key = now_token.device_list_key
changed_users = await self.store.get_device_list_changes_in_rooms(
room_ids, from_token.device_list_key
room_ids,
from_token.device_list_key,
now_device_lists_key,
)
if changed_users is not None:
# We also check if the given user has changed their device. If
# they're in no rooms then the above query won't include them.
changed = await self.store.get_users_whose_devices_changed(
from_token.device_list_key, [user_id]
from_token.device_list_key,
[user_id],
to_key=now_device_lists_key,
)
changed_users.update(changed)
return changed_users
@@ -190,7 +202,9 @@ class DeviceWorkerHandler:
tracked_users.add(user_id)
changed = await self.store.get_users_whose_devices_changed(
from_token.device_list_key, tracked_users
from_token.device_list_key,
tracked_users,
to_key=now_device_lists_key,
)
return changed

View File

@@ -1886,38 +1886,14 @@ class SyncHandler:
# Step 1a, check for changes in devices of users we share a room
# with
#
# We do this in two different ways depending on what we have cached.
# If we already have a list of all the user that have changed since
# the last sync then it's likely more efficient to compare the rooms
# they're in with the rooms the syncing user is in.
#
# If we don't have that info cached then we get all the users that
# share a room with our user and check if those users have changed.
cache_result = self.store.get_cached_device_list_changes(
since_token.device_list_key
)
if cache_result.hit:
changed_users = cache_result.entities
result = await self.store.get_rooms_for_users(changed_users)
for changed_user_id, entries in result.items():
# Check if the changed user shares any rooms with the user,
# or if the changed user is the syncing user (as we always
# want to include device list updates of their own devices).
if user_id == changed_user_id or any(
rid in joined_room_ids for rid in entries
):
users_that_have_changed.add(changed_user_id)
else:
users_that_have_changed = (
await self._device_handler.get_device_changes_in_shared_rooms(
user_id,
sync_result_builder.joined_room_ids,
from_token=since_token,
)
users_that_have_changed = (
await self._device_handler.get_device_changes_in_shared_rooms(
user_id,
sync_result_builder.joined_room_ids,
from_token=since_token,
now_token=sync_result_builder.now_token,
)
)
# Step 1b, check for newly joined rooms
for room_id in newly_joined_rooms:

View File

@@ -112,6 +112,15 @@ class ReplicationDataHandler:
token: stream token for this batch of rows
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
"""
all_room_ids: Set[str] = set()
if stream_name == DeviceListsStream.NAME:
if any(row.entity.startswith("@") and not row.is_signature for row in rows):
prev_token = self.store.get_device_stream_token()
all_room_ids = await self.store.get_all_device_list_changes(
prev_token, token
)
self.store.device_lists_in_rooms_have_changed(all_room_ids, token)
self.store.process_replication_rows(stream_name, instance_name, token, rows)
# NOTE: this must be called after process_replication_rows to ensure any
# cache invalidations are first handled before any stream ID advances.
@@ -146,12 +155,6 @@ class ReplicationDataHandler:
StreamKeyType.TO_DEVICE, token, users=entities
)
elif stream_name == DeviceListsStream.NAME:
all_room_ids: Set[str] = set()
for row in rows:
if row.entity.startswith("@") and not row.is_signature:
room_ids = await self.store.get_rooms_for_user(row.entity)
all_room_ids.update(room_ids)
# `all_room_ids` can be large, so let's wake up those streams in batches
for batched_room_ids in batch_iter(all_room_ids, 100):
self.notifier.on_new_event(

View File

@@ -70,10 +70,7 @@ from synapse.types import (
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.stream_change_cache import (
AllEntitiesChangedResult,
StreamChangeCache,
)
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr
@@ -132,6 +129,20 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
prefilled_cache=device_list_prefill,
)
device_list_room_prefill, min_device_list_room_id = self.db_pool.get_cache_dict(
db_conn,
"device_lists_changes_in_room",
entity_column="room_id",
stream_column="stream_id",
max_value=device_list_max,
limit=10000,
)
self._device_list_room_stream_cache = StreamChangeCache(
"DeviceListRoomStreamChangeCache",
min_device_list_room_id,
prefilled_cache=device_list_room_prefill,
)
(
user_signature_stream_prefill,
user_signature_stream_list_id,
@@ -209,6 +220,13 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
row.entity, token
)
def device_lists_in_rooms_have_changed(
self, room_ids: StrCollection, token: int
) -> None:
"Record that device lists have changed in rooms"
for room_id in room_ids:
self._device_list_room_stream_cache.entity_has_changed(room_id, token)
def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()
@@ -832,16 +850,6 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
)
return {device[0]: db_to_json(device[1]) for device in devices}
def get_cached_device_list_changes(
self,
from_key: int,
) -> AllEntitiesChangedResult:
"""Get set of users whose devices have changed since `from_key`, or None
if that information is not in our cache.
"""
return self._device_list_stream_cache.get_all_entities_changed(from_key)
@cancellable
async def get_all_devices_changed(
self,
@@ -1457,7 +1465,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
@cancellable
async def get_device_list_changes_in_rooms(
self, room_ids: Collection[str], from_id: int
self, room_ids: Collection[str], from_id: int, to_id: int
) -> Optional[Set[str]]:
"""Return the set of users whose devices have changed in the given rooms
since the given stream ID.
@@ -1473,9 +1481,15 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
if min_stream_id > from_id:
return None
changed_room_ids = self._device_list_room_stream_cache.get_entities_changed(
room_ids, from_id
)
if not changed_room_ids:
return set()
sql = """
SELECT DISTINCT user_id FROM device_lists_changes_in_room
WHERE {clause} AND stream_id >= ?
WHERE {clause} AND stream_id > ? AND stream_id <= ?
"""
def _get_device_list_changes_in_rooms_txn(
@@ -1487,11 +1501,12 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
return {user_id for user_id, in txn}
changes = set()
for chunk in batch_iter(room_ids, 1000):
for chunk in batch_iter(changed_room_ids, 1000):
clause, args = make_in_list_sql_clause(
self.database_engine, "room_id", chunk
)
args.append(from_id)
args.append(to_id)
changes |= await self.db_pool.runInteraction(
"get_device_list_changes_in_rooms",
@@ -1502,6 +1517,34 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
return changes
async def get_all_device_list_changes(self, from_id: int, to_id: int) -> Set[str]:
"""Return the set of rooms where devices have changed since the given
stream ID.
Will raise an exception if the given stream ID is too old.
"""
min_stream_id = await self._get_min_device_lists_changes_in_room()
if min_stream_id > from_id:
raise Exception("stream ID is too old")
sql = """
SELECT DISTINCT room_id FROM device_lists_changes_in_room
WHERE stream_id > ? AND stream_id <= ?
"""
def _get_all_device_list_changes_txn(
txn: LoggingTransaction,
) -> Set[str]:
txn.execute(sql, (from_id, to_id))
return {room_id for room_id, in txn}
return await self.db_pool.runInteraction(
"get_all_device_list_changes",
_get_all_device_list_changes_txn,
)
async def get_device_list_changes_in_room(
self, room_id: str, min_stream_id: int
) -> Collection[Tuple[str, str]]:
@@ -1962,8 +2005,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
async def add_device_change_to_streams(
self,
user_id: str,
device_ids: Collection[str],
room_ids: Collection[str],
device_ids: StrCollection,
room_ids: StrCollection,
) -> Optional[int]:
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
@@ -2122,8 +2165,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self,
txn: LoggingTransaction,
user_id: str,
device_ids: Iterable[str],
room_ids: Collection[str],
device_ids: StrCollection,
room_ids: StrCollection,
stream_ids: List[int],
context: Dict[str, str],
) -> None:
@@ -2161,6 +2204,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
],
)
txn.call_after(
self.device_lists_in_rooms_have_changed, room_ids, max(stream_ids)
)
async def get_uncoverted_outbound_room_pokes(
self, start_stream_id: int, start_room_id: str, limit: int = 10
) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]: