Compare commits

...

4 Commits

Author SHA1 Message Date
Erik Johnston
99cf1e7c1c Faster sliding sync sorting 2024-07-11 15:40:46 +01:00
Erik Johnston
4ca13ce0dd Handle to-device extensions to Sliding Sync (#17416)
Implements MSC3885

---------

Co-authored-by: Eric Eastwood <eric.eastwood@beta.gouv.fr>
2024-07-10 11:58:42 +01:00
Quentin Gliech
8e229535fa Merge branch 'release-v1.111' into develop 2024-07-10 11:36:07 +02:00
Eric Eastwood
1cf3ff6b40 Add rooms name and avatar to Sliding Sync /sync (#17418)
Based on [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575): Sliding Sync
2024-07-09 12:26:45 -05:00
12 changed files with 931 additions and 89 deletions

View File

@@ -0,0 +1 @@
Add to-device extension support to experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync` endpoint.

View File

@@ -0,0 +1 @@
Populate `name`/`avatar` fields in experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync` endpoint.

View File

@@ -18,6 +18,7 @@
#
#
import logging
from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, Final, List, Optional, Set, Tuple
import attr
@@ -33,6 +34,7 @@ from synapse.api.constants import (
from synapse.events import EventBase
from synapse.events.utils import strip_event
from synapse.handlers.relations import BundledAggregations
from synapse.logging.opentracing import trace
from synapse.storage.databases.main.stream import CurrentStateDeltaMembership
from synapse.types import (
JsonDict,
@@ -464,6 +466,7 @@ class SlidingSyncHandler:
membership_state_keys = room_sync_config.required_state_map.get(
EventTypes.Member
)
# Also see `StateFilter.must_await_full_state(...)` for comparison
lazy_loading = (
membership_state_keys is not None
and len(membership_state_keys) == 1
@@ -540,11 +543,15 @@ class SlidingSyncHandler:
rooms[room_id] = room_sync_result
extensions = await self.get_extensions_response(
sync_config=sync_config, to_token=to_token
)
return SlidingSyncResult(
next_pos=to_token,
lists=lists,
rooms=rooms,
extensions={},
extensions=extensions,
)
async def get_sync_room_ids_for_user(
@@ -982,6 +989,7 @@ class SlidingSyncHandler:
# Assemble a new sync room map but only with the `filtered_room_id_set`
return {room_id: sync_room_map[room_id] for room_id in filtered_room_id_set}
@trace
async def sort_rooms(
self,
sync_room_map: Dict[str, _RoomMembershipForUser],
@@ -1003,24 +1011,18 @@ class SlidingSyncHandler:
# Assemble a map of room ID to the `stream_ordering` of the last activity that the
# user should see in the room (<= `to_token`)
last_activity_in_room_map: Dict[str, int] = {}
to_fetch = []
for room_id, room_for_user in sync_room_map.items():
# If they are fully-joined to the room, let's find the latest activity
# at/before the `to_token`.
if room_for_user.membership == Membership.JOIN:
last_event_result = (
await self.store.get_last_event_pos_in_room_before_stream_ordering(
room_id, to_token.room_key
)
)
stream = self.store._events_stream_cache._entity_to_key.get(room_id)
if stream is not None:
if stream <= to_token.room_key.stream:
last_activity_in_room_map[room_id] = stream
continue
# If the room has no events at/before the `to_token`, this is probably a
# mistake in the code that generates the `sync_room_map` since that should
# only give us rooms that the user had membership in during the token range.
assert last_event_result is not None
_, event_pos = last_event_result
last_activity_in_room_map[room_id] = event_pos.stream
to_fetch.append(room_id)
else:
# Otherwise, if the user has left/been invited/knocked/been banned from
# a room, they shouldn't see anything past that point.
@@ -1031,6 +1033,20 @@ class SlidingSyncHandler:
# https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1653045932
last_activity_in_room_map[room_id] = room_for_user.event_pos.stream
ordering_map = await self.store.get_max_stream_ordering_in_rooms(to_fetch)
for room_id, stream_pos in ordering_map.items():
if stream_pos is None:
continue
if stream_pos.persisted_after(to_token.room_key):
continue
last_activity_in_room_map[room_id] = stream_pos.stream
for room_id in sync_room_map.keys() - last_activity_in_room_map.keys():
# TODO: Handle better
last_activity_in_room_map[room_id] = sync_room_map[room_id].event_pos.stream
return sorted(
sync_room_map.values(),
# Sort by the last activity (stream_ordering) in the room
@@ -1202,7 +1218,7 @@ class SlidingSyncHandler:
# Figure out any stripped state events for invite/knocks. This allows the
# potential joiner to identify the room.
stripped_state: List[JsonDict] = []
stripped_state: Optional[List[JsonDict]] = None
if room_membership_for_user_at_to_token.membership in (
Membership.INVITE,
Membership.KNOCK,
@@ -1239,7 +1255,7 @@ class SlidingSyncHandler:
# updates.
initial = True
# Fetch the required state for the room
# Fetch the `required_state` for the room
#
# No `required_state` for invite/knock rooms (just `stripped_state`)
#
@@ -1247,13 +1263,15 @@ class SlidingSyncHandler:
# of membership. Currently, we have to make this optional because
# `invite`/`knock` rooms only have `stripped_state`. See
# https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1653045932
#
# Calculate the `StateFilter` based on the `required_state` for the room
room_state: Optional[StateMap[EventBase]] = None
required_room_state: Optional[StateMap[EventBase]] = None
if room_membership_for_user_at_to_token.membership not in (
Membership.INVITE,
Membership.KNOCK,
):
# Calculate the `StateFilter` based on the `required_state` for the room
state_filter: Optional[StateFilter] = StateFilter.none()
required_state_filter = StateFilter.none()
# If we have a double wildcard ("*", "*") in the `required_state`, we need
# to fetch all state for the room
#
@@ -1276,7 +1294,7 @@ class SlidingSyncHandler:
if StateValues.WILDCARD in room_sync_config.required_state_map.get(
StateValues.WILDCARD, set()
):
state_filter = StateFilter.all()
required_state_filter = StateFilter.all()
# TODO: `StateFilter` currently doesn't support wildcard event types. We're
# currently working around this by returning all state to the client but it
# would be nice to fetch less from the database and return just what the
@@ -1285,7 +1303,7 @@ class SlidingSyncHandler:
room_sync_config.required_state_map.get(StateValues.WILDCARD)
is not None
):
state_filter = StateFilter.all()
required_state_filter = StateFilter.all()
else:
required_state_types: List[Tuple[str, Optional[str]]] = []
for (
@@ -1317,77 +1335,125 @@ class SlidingSyncHandler:
else:
required_state_types.append((state_type, state_key))
state_filter = StateFilter.from_types(required_state_types)
required_state_filter = StateFilter.from_types(required_state_types)
# We can skip fetching state if we don't need any
if state_filter != StateFilter.none():
# We can return all of the state that was requested if we're doing an
# initial sync
if initial:
# People shouldn't see past their leave/ban event
if room_membership_for_user_at_to_token.membership in (
Membership.LEAVE,
Membership.BAN,
):
room_state = await self.storage_controllers.state.get_state_at(
room_id,
stream_position=to_token.copy_and_replace(
StreamKeyType.ROOM,
room_membership_for_user_at_to_token.event_pos.to_room_stream_token(),
),
state_filter=state_filter,
# Partially-stated rooms should have all state events except for
# the membership events and since we've already excluded
# partially-stated rooms unless `required_state` only has
# `["m.room.member", "$LAZY"]` for membership, we should be able
# to retrieve everything requested. Plus we don't want to block
# the whole sync waiting for this one room.
await_full_state=False,
)
# Otherwise, we can get the latest current state in the room
else:
room_state = await self.storage_controllers.state.get_current_state(
room_id,
state_filter,
# Partially-stated rooms should have all state events except for
# the membership events and since we've already excluded
# partially-stated rooms unless `required_state` only has
# `["m.room.member", "$LAZY"]` for membership, we should be able
# to retrieve everything requested. Plus we don't want to block
# the whole sync waiting for this one room.
await_full_state=False,
)
# TODO: Query `current_state_delta_stream` and reverse/rewind back to the `to_token`
# We need this base set of info for the response so let's just fetch it along
# with the `required_state` for the room
META_ROOM_STATE = [(EventTypes.Name, ""), (EventTypes.RoomAvatar, "")]
state_filter = StateFilter(
types=StateFilter.from_types(
chain(META_ROOM_STATE, required_state_filter.to_types())
).types,
include_others=required_state_filter.include_others,
)
# We can return all of the state that was requested if this was the first
# time we've sent the room down this connection.
if initial:
# People shouldn't see past their leave/ban event
if room_membership_for_user_at_to_token.membership in (
Membership.LEAVE,
Membership.BAN,
):
room_state = await self.storage_controllers.state.get_state_at(
room_id,
stream_position=to_token.copy_and_replace(
StreamKeyType.ROOM,
room_membership_for_user_at_to_token.event_pos.to_room_stream_token(),
),
state_filter=state_filter,
# Partially-stated rooms should have all state events except for
# remote membership events. Since we've already excluded
# partially-stated rooms unless `required_state` only has
# `["m.room.member", "$LAZY"]` for membership, we should be able to
# retrieve everything requested. When we're lazy-loading, if there
# are some remote senders in the timeline, we should also have their
# membership event because we had to auth that timeline event. Plus
# we don't want to block the whole sync waiting for this one room.
await_full_state=False,
)
# Otherwise, we can get the latest current state in the room
else:
# TODO: Once we can figure out if we've sent a room down this connection before,
# we can return updates instead of the full required state.
raise NotImplementedError()
room_state = await self.storage_controllers.state.get_current_state(
room_id,
state_filter,
# Partially-stated rooms should have all state events except for
# remote membership events. Since we've already excluded
# partially-stated rooms unless `required_state` only has
# `["m.room.member", "$LAZY"]` for membership, we should be able to
# retrieve everything requested. When we're lazy-loading, if there
# are some remote senders in the timeline, we should also have their
# membership event because we had to auth that timeline event. Plus
# we don't want to block the whole sync waiting for this one room.
await_full_state=False,
)
# TODO: Query `current_state_delta_stream` and reverse/rewind back to the `to_token`
else:
# TODO: Once we can figure out if we've sent a room down this connection before,
# we can return updates instead of the full required state.
raise NotImplementedError()
if required_state_filter != StateFilter.none():
required_room_state = required_state_filter.filter_state(room_state)
# Find the room name and avatar from the state
room_name: Optional[str] = None
room_avatar: Optional[str] = None
if room_state is not None:
name_event = room_state.get((EventTypes.Name, ""))
if name_event is not None:
room_name = name_event.content.get("name")
avatar_event = room_state.get((EventTypes.RoomAvatar, ""))
if avatar_event is not None:
room_avatar = avatar_event.content.get("url")
elif stripped_state is not None:
for event in stripped_state:
if event["type"] == EventTypes.Name:
room_name = event.get("content", {}).get("name")
elif event["type"] == EventTypes.RoomAvatar:
room_avatar = event.get("content", {}).get("url")
# Found everything so we can stop looking
if room_name is not None and room_avatar is not None:
break
# Figure out the last bump event in the room
last_bump_event_result = (
await self.store.get_last_event_pos_in_room_before_stream_ordering(
room_id, to_token.room_key, event_types=DEFAULT_BUMP_EVENT_TYPES
last_bump_event_stream_ordering = None
if timeline_events:
for e in reversed(timeline_events):
if e.type in DEFAULT_BUMP_EVENT_TYPES:
last_bump_event_stream_ordering = (
e.internal_metadata.stream_ordering
)
break
if last_bump_event_stream_ordering is None:
last_bump_event_result = (
await self.store.get_last_event_pos_in_room_before_stream_ordering(
room_id, to_token.room_key, event_types=DEFAULT_BUMP_EVENT_TYPES
)
)
)
if last_bump_event_result is not None:
last_bump_event_stream_ordering = last_bump_event_result[1].stream
# By default, just choose the membership event position
bump_stamp = room_membership_for_user_at_to_token.event_pos.stream
# But if we found a bump event, use that instead
if last_bump_event_result is not None:
_, bump_event_pos = last_bump_event_result
bump_stamp = bump_event_pos.stream
if last_bump_event_stream_ordering is not None:
bump_stamp = last_bump_event_stream_ordering
return SlidingSyncResult.RoomResult(
# TODO: Dummy value
name=None,
# TODO: Dummy value
avatar=None,
name=room_name,
avatar=room_avatar,
# TODO: Dummy value
heroes=None,
# TODO: Dummy value
is_dm=False,
initial=initial,
required_state=list(room_state.values()) if room_state else None,
required_state=(
list(required_room_state.values()) if required_room_state else None
),
timeline_events=timeline_events,
bundled_aggregations=bundled_aggregations,
stripped_state=stripped_state,
@@ -1404,3 +1470,100 @@ class SlidingSyncHandler:
notification_count=0,
highlight_count=0,
)
async def get_extensions_response(
self,
sync_config: SlidingSyncConfig,
to_token: StreamToken,
) -> SlidingSyncResult.Extensions:
"""Handle extension requests.
Args:
sync_config: Sync configuration
to_token: The point in the stream to sync up to.
"""
if sync_config.extensions is None:
return SlidingSyncResult.Extensions()
to_device_response = None
if sync_config.extensions.to_device:
to_device_response = await self.get_to_device_extensions_response(
sync_config=sync_config,
to_device_request=sync_config.extensions.to_device,
to_token=to_token,
)
return SlidingSyncResult.Extensions(to_device=to_device_response)
async def get_to_device_extensions_response(
self,
sync_config: SlidingSyncConfig,
to_device_request: SlidingSyncConfig.Extensions.ToDeviceExtension,
to_token: StreamToken,
) -> SlidingSyncResult.Extensions.ToDeviceExtension:
"""Handle to-device extension (MSC3885)
Args:
sync_config: Sync configuration
to_device_request: The to-device extension from the request
to_token: The point in the stream to sync up to.
"""
user_id = sync_config.user.to_string()
device_id = sync_config.device_id
# Check that this request has a valid device ID (not all requests have
# to belong to a device, and so device_id is None), and that the
# extension is enabled.
if device_id is None or not to_device_request.enabled:
return SlidingSyncResult.Extensions.ToDeviceExtension(
next_batch=f"{to_token.to_device_key}",
events=[],
)
since_stream_id = 0
if to_device_request.since is not None:
# We've already validated this is an int.
since_stream_id = int(to_device_request.since)
if to_token.to_device_key < since_stream_id:
# The since token is ahead of our current token, so we return an
# empty response.
logger.warning(
"Got to-device.since from the future. since token: %r is ahead of our current to_device stream position: %r",
since_stream_id,
to_token.to_device_key,
)
return SlidingSyncResult.Extensions.ToDeviceExtension(
next_batch=to_device_request.since,
events=[],
)
# Delete everything before the given since token, as we know the
# device must have received them.
deleted = await self.store.delete_messages_for_device(
user_id=user_id,
device_id=device_id,
up_to_stream_id=since_stream_id,
)
logger.debug(
"Deleted %d to-device messages up to %d for %s",
deleted,
since_stream_id,
user_id,
)
messages, stream_id = await self.store.get_messages_for_device(
user_id=user_id,
device_id=device_id,
from_stream_id=since_stream_id,
to_stream_id=to_token.to_device_key,
limit=min(to_device_request.limit, 100), # Limit to at most 100 events
)
return SlidingSyncResult.Extensions.ToDeviceExtension(
next_batch=f"{stream_id}",
events=messages,
)

View File

@@ -942,7 +942,9 @@ class SlidingSyncRestServlet(RestServlet):
response["rooms"] = await self.encode_rooms(
requester, sliding_sync_result.rooms
)
response["extensions"] = {} # TODO: sliding_sync_result.extensions
response["extensions"] = await self.encode_extensions(
requester, sliding_sync_result.extensions
)
return response
@@ -1054,6 +1056,19 @@ class SlidingSyncRestServlet(RestServlet):
return serialized_rooms
async def encode_extensions(
self, requester: Requester, extensions: SlidingSyncResult.Extensions
) -> JsonDict:
result = {}
if extensions.to_device is not None:
result["to_device"] = {
"next_batch": extensions.to_device.next_batch,
"events": extensions.to_device.events,
}
return result
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
SyncRestServlet(hs).register(http_server)

View File

@@ -309,6 +309,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
if not backfilled:
self._events_stream_cache.entity_has_changed(room_id, stream_ordering) # type: ignore[attr-defined]
self._attempt_to_invalidate_cache(
"get_max_stream_ordering_in_room", (room_id,)
)
if redacts:
self._invalidate_local_get_event_cache(redacts) # type: ignore[attr-defined]

View File

@@ -551,7 +551,7 @@ class PersistEventsStore:
# From this point onwards the events are only events that we haven't
# seen before.
self._store_event_txn(txn, events_and_contexts=events_and_contexts)
self._store_event_txn(txn, room_id, events_and_contexts=events_and_contexts)
if new_forward_extremities:
self._update_forward_extremities_txn(
@@ -1555,6 +1555,7 @@ class PersistEventsStore:
def _store_event_txn(
self,
txn: LoggingTransaction,
room_id: str,
events_and_contexts: Collection[Tuple[EventBase, EventContext]],
) -> None:
"""Insert new events into the event, event_json, redaction and
@@ -1629,6 +1630,27 @@ class PersistEventsStore:
],
)
# Update the `sliding_sync_room_metadata` with the latest
# (non-backfilled, ie positive) stream ordering.
#
# We know this list is sorted and non-empty, so we just take the last
# one event.
max_stream_ordering: int
for e, _ in events_and_contexts:
assert e.internal_metadata.stream_ordering is not None
max_stream_ordering = e.internal_metadata.stream_ordering
if max_stream_ordering > 0:
self.db_pool.simple_upsert_txn(
txn,
table="sliding_sync_room_metadata",
keyvalues={"room_id": room_id},
values={
"instance_name": self._instance_name,
"last_stream_ordering": max_stream_ordering,
},
)
# If we're persisting an unredacted event we go and ensure
# that we mark any redactions that reference this event as
# requiring censoring.

View File

@@ -50,6 +50,7 @@ from typing import (
Dict,
Iterable,
List,
Mapping,
Optional,
Set,
Tuple,
@@ -78,8 +79,13 @@ from synapse.storage.database import (
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import PersistedEventPosition, RoomStreamToken
from synapse.util.caches.descriptors import cached
from synapse.types import (
JsonDict,
PersistedEventPosition,
RoomStreamToken,
StrCollection,
)
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.cancellation import cancellable
@@ -610,6 +616,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
self._stream_order_on_start = self.get_room_max_stream_ordering()
self._min_stream_order_on_start = self.get_room_min_stream_ordering()
database.updates.register_background_update_handler(
"sliding_sync_room_metadata", self._sliding_sync_room_metadata_bg_update
)
def get_room_max_stream_ordering(self) -> int:
"""Get the stream_ordering of regular events that we have committed up to
@@ -1185,6 +1195,52 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return None
@cachedList(
cached_method_name="get_max_stream_ordering_in_room",
list_name="room_ids",
)
async def get_max_stream_ordering_in_rooms(
self, room_ids: StrCollection
) -> Mapping[str, Optional[PersistedEventPosition]]:
"""Get the positions for the latest event in a room.
A batched version of `get_max_stream_ordering_in_room`.
"""
rows = await self.db_pool.simple_select_many_batch(
table="sliding_sync_room_metadata",
column="room_id",
iterable=room_ids,
retcols=("room_id", "instance_name", "last_stream_ordering"),
desc="get_max_stream_ordering_in_rooms",
)
return {
room_id: PersistedEventPosition(instance_name, stream)
for room_id, instance_name, stream in rows
}
@cached(max_entries=10000)
async def get_max_stream_ordering_in_room(
self,
room_id: str,
) -> Optional[PersistedEventPosition]:
"""Get the position for the latest event in a room.
Note: this may be after the current token for the room stream on this
process (e.g. due to replication lag)
"""
row = await self.db_pool.simple_select_one(
table="sliding_sync_room_metadata",
retcols=("instance_name", "last_stream_ordering"),
keyvalues={"room_id": room_id},
allow_none=True,
desc="get_max_stream_ordering_in_room",
)
if not row:
return None
return PersistedEventPosition(instance_name=row[0], stream=row[1])
async def get_last_event_pos_in_room_before_stream_ordering(
self,
room_id: str,
@@ -1983,3 +2039,88 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return RoomStreamToken(stream=last_position.stream - 1)
return None
async def _sliding_sync_room_metadata_bg_update(
self, progress: JsonDict, batch_size: int
) -> int:
"""Background update to fill out 'sliding_sync_room_metadata' table"""
previous_room = progress.get("previous_room", "")
def _sliding_sync_room_metadata_bg_update_txn(txn: LoggingTransaction) -> int:
# Both these queries are just getting the most recent
# instance_name/stream ordering for the next N rooms.
if isinstance(self.database_engine, PostgresEngine):
sql = """
SELECT room_id, instance_name, stream_ordering FROM rooms AS r,
LATERAL (
SELECT instance_name, stream_ordering
FROM events WHERE events.room_id = r.room_id
ORDER BY stream_ordering DESC
LIMIT 1
) e
WHERE r.room_id > ?
ORDER BY r.room_id ASC
LIMIT ?
"""
else:
sql = """
SELECT
room_id,
(
SELECT instance_name
FROM events WHERE events.room_id = r.room_id
ORDER BY stream_ordering DESC
LIMIT 1
),
(
SELECT stream_ordering
FROM events WHERE events.room_id = r.room_id
ORDER BY stream_ordering DESC
LIMIT 1
)
FROM rooms AS r
WHERE r.room_id > ?
ORDER BY r.room_id ASC
LIMIT ?
"""
txn.execute(sql, (previous_room, batch_size))
rows = txn.fetchall()
if not rows:
return 0
self.db_pool.simple_upsert_many_txn(
txn,
table="sliding_sync_room_metadata",
key_names=("room_id",),
key_values=[(room_id,) for room_id, _, _ in rows],
value_names=(
"instance_name",
"last_stream_ordering",
),
value_values=[
(
instance_name or "master",
stream,
)
for _, instance_name, stream in rows
],
)
self.db_pool.updates._background_update_progress_txn(
txn, "sliding_sync_room_metadata", {"previous_room": rows[-1][0]}
)
return len(rows)
rows = await self.db_pool.runInteraction(
"_sliding_sync_room_metadata_bg_update",
_sliding_sync_room_metadata_bg_update_txn,
)
if rows == 0:
await self.db_pool.updates._end_background_update(
"sliding_sync_room_metadata"
)
return rows

View File

@@ -0,0 +1,24 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2024 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
-- A table that maps from room ID to metadata useful for sliding sync.
CREATE TABLE sliding_sync_room_metadata (
room_id TEXT NOT NULL PRIMARY KEY,
-- The instance_name / stream ordering of the last event in the room.
instance_name TEXT NOT NULL,
last_stream_ordering BIGINT NOT NULL
);
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(8507, 'sliding_sync_room_metadata', '{}');

View File

@@ -18,7 +18,7 @@
#
#
from enum import Enum
from typing import TYPE_CHECKING, Dict, Final, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, Final, List, Optional, Sequence, Tuple
import attr
from typing_extensions import TypedDict
@@ -252,10 +252,39 @@ class SlidingSyncResult:
count: int
ops: List[Operation]
@attr.s(slots=True, frozen=True, auto_attribs=True)
class Extensions:
"""Responses for extensions
Attributes:
to_device: The to-device extension (MSC3885)
"""
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ToDeviceExtension:
"""The to-device extension (MSC3885)
Attributes:
next_batch: The to-device stream token the client should use
to get more results
events: A list of to-device messages for the client
"""
next_batch: str
events: Sequence[JsonMapping]
def __bool__(self) -> bool:
return bool(self.events)
to_device: Optional[ToDeviceExtension] = None
def __bool__(self) -> bool:
return bool(self.to_device)
next_pos: StreamToken
lists: Dict[str, SlidingWindowList]
rooms: Dict[str, RoomResult]
extensions: JsonMapping
extensions: Extensions
def __bool__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
@@ -271,5 +300,5 @@ class SlidingSyncResult:
next_pos=next_pos,
lists={},
rooms={},
extensions={},
extensions=SlidingSyncResult.Extensions(),
)

View File

@@ -276,10 +276,48 @@ class SlidingSyncBody(RequestBodyModel):
class RoomSubscription(CommonRoomParameters):
pass
class Extension(RequestBodyModel):
enabled: Optional[StrictBool] = False
lists: Optional[List[StrictStr]] = None
rooms: Optional[List[StrictStr]] = None
class Extensions(RequestBodyModel):
"""The extensions section of the request.
Extensions MUST have an `enabled` flag which defaults to `false`. If a client
sends an unknown extension name, the server MUST ignore it (or else backwards
compatibility between clients and servers is broken when a newer client tries to
communicate with an older server).
"""
class ToDeviceExtension(RequestBodyModel):
"""The to-device extension (MSC3885)
Attributes:
enabled
limit: Maximum number of to-device messages to return
since: The `next_batch` from the previous sync response
"""
enabled: Optional[StrictBool] = False
limit: StrictInt = 100
since: Optional[StrictStr] = None
@validator("since")
def since_token_check(
cls, value: Optional[StrictStr]
) -> Optional[StrictStr]:
# `since` comes in as an opaque string token but we know that it's just
# an integer representing the position in the device inbox stream. We
# want to pre-validate it to make sure it works fine in downstream code.
if value is None:
return value
try:
int(value)
except ValueError:
raise ValueError(
"'extensions.to_device.since' is invalid (should look like an int)"
)
return value
to_device: Optional[ToDeviceExtension] = None
# mypy workaround via https://github.com/pydantic/pydantic/issues/156#issuecomment-1130883884
if TYPE_CHECKING:
@@ -287,7 +325,7 @@ class SlidingSyncBody(RequestBodyModel):
else:
lists: Optional[Dict[constr(max_length=64, strict=True), SlidingSyncList]] = None # type: ignore[valid-type]
room_subscriptions: Optional[Dict[StrictStr, RoomSubscription]] = None
extensions: Optional[Dict[StrictStr, Extension]] = None
extensions: Optional[Extensions] = None
@validator("lists")
def lists_length_check(

View File

@@ -38,7 +38,16 @@ from synapse.api.constants import (
)
from synapse.events import EventBase
from synapse.handlers.sliding_sync import StateValues
from synapse.rest.client import devices, knock, login, read_marker, receipts, room, sync
from synapse.rest.client import (
devices,
knock,
login,
read_marker,
receipts,
room,
sendtodevice,
sync,
)
from synapse.server import HomeServer
from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken, UserID
from synapse.util import Clock
@@ -47,7 +56,7 @@ from tests import unittest
from tests.federation.transport.test_knocking import (
KnockingStrippedStateEventHelperMixin,
)
from tests.server import TimedOutException
from tests.server import FakeChannel, TimedOutException
from tests.test_utils.event_injection import mark_event_as_partial_state
logger = logging.getLogger(__name__)
@@ -1802,6 +1811,206 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
channel.json_body["lists"]["foo-list"],
)
def test_rooms_meta_when_joined(self) -> None:
"""
Test that the `rooms` `name` and `avatar` (soon to test `heroes`) are included
in the response when the user is joined to the room.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
user2_id = self.register_user("user2", "pass")
user2_tok = self.login(user2_id, "pass")
room_id1 = self.helper.create_room_as(
user2_id,
tok=user2_tok,
extra_content={
"name": "my super room",
},
)
# Set the room avatar URL
self.helper.send_state(
room_id1,
EventTypes.RoomAvatar,
{"url": "mxc://DUMMY_MEDIA_ID"},
tok=user2_tok,
)
self.helper.join(room_id1, user1_id, tok=user1_tok)
# Make the Sliding Sync request
channel = self.make_request(
"POST",
self.sync_endpoint,
{
"lists": {
"foo-list": {
"ranges": [[0, 1]],
"required_state": [],
"timeline_limit": 0,
}
}
},
access_token=user1_tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
# Reflect the current state of the room
self.assertEqual(
channel.json_body["rooms"][room_id1]["name"],
"my super room",
channel.json_body["rooms"][room_id1],
)
self.assertEqual(
channel.json_body["rooms"][room_id1]["avatar"],
"mxc://DUMMY_MEDIA_ID",
channel.json_body["rooms"][room_id1],
)
def test_rooms_meta_when_invited(self) -> None:
"""
Test that the `rooms` `name` and `avatar` (soon to test `heroes`) are included
in the response when the user is invited to the room.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
user2_id = self.register_user("user2", "pass")
user2_tok = self.login(user2_id, "pass")
room_id1 = self.helper.create_room_as(
user2_id,
tok=user2_tok,
extra_content={
"name": "my super room",
},
)
# Set the room avatar URL
self.helper.send_state(
room_id1,
EventTypes.RoomAvatar,
{"url": "mxc://DUMMY_MEDIA_ID"},
tok=user2_tok,
)
self.helper.join(room_id1, user1_id, tok=user1_tok)
# Update the room name after user1 has left
self.helper.send_state(
room_id1,
EventTypes.Name,
{"name": "my super duper room"},
tok=user2_tok,
)
# Update the room avatar URL after user1 has left
self.helper.send_state(
room_id1,
EventTypes.RoomAvatar,
{"url": "mxc://UPDATED_DUMMY_MEDIA_ID"},
tok=user2_tok,
)
# Make the Sliding Sync request
channel = self.make_request(
"POST",
self.sync_endpoint,
{
"lists": {
"foo-list": {
"ranges": [[0, 1]],
"required_state": [],
"timeline_limit": 0,
}
}
},
access_token=user1_tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
# This should still reflect the current state of the room even when the user is
# invited.
self.assertEqual(
channel.json_body["rooms"][room_id1]["name"],
"my super duper room",
channel.json_body["rooms"][room_id1],
)
self.assertEqual(
channel.json_body["rooms"][room_id1]["avatar"],
"mxc://UPDATED_DUMMY_MEDIA_ID",
channel.json_body["rooms"][room_id1],
)
def test_rooms_meta_when_banned(self) -> None:
"""
Test that the `rooms` `name` and `avatar` (soon to test `heroes`) reflect the
state of the room when the user was banned (do not leak current state).
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
user2_id = self.register_user("user2", "pass")
user2_tok = self.login(user2_id, "pass")
room_id1 = self.helper.create_room_as(
user2_id,
tok=user2_tok,
extra_content={
"name": "my super room",
},
)
# Set the room avatar URL
self.helper.send_state(
room_id1,
EventTypes.RoomAvatar,
{"url": "mxc://DUMMY_MEDIA_ID"},
tok=user2_tok,
)
self.helper.join(room_id1, user1_id, tok=user1_tok)
self.helper.ban(room_id1, src=user2_id, targ=user1_id, tok=user2_tok)
# Update the room name after user1 has left
self.helper.send_state(
room_id1,
EventTypes.Name,
{"name": "my super duper room"},
tok=user2_tok,
)
# Update the room avatar URL after user1 has left
self.helper.send_state(
room_id1,
EventTypes.RoomAvatar,
{"url": "mxc://UPDATED_DUMMY_MEDIA_ID"},
tok=user2_tok,
)
# Make the Sliding Sync request
channel = self.make_request(
"POST",
self.sync_endpoint,
{
"lists": {
"foo-list": {
"ranges": [[0, 1]],
"required_state": [],
"timeline_limit": 0,
}
}
},
access_token=user1_tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
# Reflect the state of the room at the time of leaving
self.assertEqual(
channel.json_body["rooms"][room_id1]["name"],
"my super room",
channel.json_body["rooms"][room_id1],
)
self.assertEqual(
channel.json_body["rooms"][room_id1]["avatar"],
"mxc://DUMMY_MEDIA_ID",
channel.json_body["rooms"][room_id1],
)
def test_rooms_limited_initial_sync(self) -> None:
"""
Test that we mark `rooms` as `limited=True` when we saturate the `timeline_limit`
@@ -2973,6 +3182,7 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
},
exact=True,
)
self.assertIsNone(channel.json_body["rooms"][room_id1].get("invite_state"))
def test_rooms_required_state_incremental_sync(self) -> None:
"""
@@ -3027,6 +3237,7 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
},
exact=True,
)
self.assertIsNone(channel.json_body["rooms"][room_id1].get("invite_state"))
def test_rooms_required_state_wildcard(self) -> None:
"""
@@ -3084,6 +3295,7 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
state_map.values(),
exact=True,
)
self.assertIsNone(channel.json_body["rooms"][room_id1].get("invite_state"))
def test_rooms_required_state_wildcard_event_type(self) -> None:
"""
@@ -3147,6 +3359,7 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
# events when the `event_type` is a wildcard.
exact=False,
)
self.assertIsNone(channel.json_body["rooms"][room_id1].get("invite_state"))
def test_rooms_required_state_wildcard_state_key(self) -> None:
"""
@@ -3192,6 +3405,7 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
},
exact=True,
)
self.assertIsNone(channel.json_body["rooms"][room_id1].get("invite_state"))
def test_rooms_required_state_lazy_loading_room_members(self) -> None:
"""
@@ -3247,6 +3461,7 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
},
exact=True,
)
self.assertIsNone(channel.json_body["rooms"][room_id1].get("invite_state"))
@parameterized.expand([(Membership.LEAVE,), (Membership.BAN,)])
def test_rooms_required_state_leave_ban(self, stop_membership: str) -> None:
@@ -3329,6 +3544,7 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
},
exact=True,
)
self.assertIsNone(channel.json_body["rooms"][room_id1].get("invite_state"))
def test_rooms_required_state_combine_superset(self) -> None:
"""
@@ -3401,6 +3617,7 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
},
exact=True,
)
self.assertIsNone(channel.json_body["rooms"][room_id1].get("invite_state"))
def test_rooms_required_state_partial_state(self) -> None:
"""
@@ -3488,3 +3705,190 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
],
channel.json_body["lists"]["foo-list"],
)
class SlidingSyncToDeviceExtensionTestCase(unittest.HomeserverTestCase):
"""Tests for the to-device sliding sync extension"""
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
sync.register_servlets,
sendtodevice.register_servlets,
]
def default_config(self) -> JsonDict:
config = super().default_config()
# Enable sliding sync
config["experimental_features"] = {"msc3575_enabled": True}
return config
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.sync_endpoint = (
"/_matrix/client/unstable/org.matrix.simplified_msc3575/sync"
)
def _assert_to_device_response(
self, channel: FakeChannel, expected_messages: List[JsonDict]
) -> str:
"""Assert the sliding sync response was successful and has the expected
to-device messages.
Returns the next_batch token from the to-device section.
"""
self.assertEqual(channel.code, 200, channel.json_body)
extensions = channel.json_body["extensions"]
to_device = extensions["to_device"]
self.assertIsInstance(to_device["next_batch"], str)
self.assertEqual(to_device["events"], expected_messages)
return to_device["next_batch"]
def test_no_data(self) -> None:
"""Test that enabling to-device extension works, even if there is
no-data
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
channel = self.make_request(
"POST",
self.sync_endpoint,
{
"lists": {},
"extensions": {
"to_device": {
"enabled": True,
}
},
},
access_token=user1_tok,
)
# We expect no to-device messages
self._assert_to_device_response(channel, [])
def test_data_initial_sync(self) -> None:
"""Test that we get to-device messages when we don't specify a since
token"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass", "d1")
user2_id = self.register_user("u2", "pass")
user2_tok = self.login(user2_id, "pass", "d2")
# Send the to-device message
test_msg = {"foo": "bar"}
chan = self.make_request(
"PUT",
"/_matrix/client/r0/sendToDevice/m.test/1234",
content={"messages": {user1_id: {"d1": test_msg}}},
access_token=user2_tok,
)
self.assertEqual(chan.code, 200, chan.result)
channel = self.make_request(
"POST",
self.sync_endpoint,
{
"lists": {},
"extensions": {
"to_device": {
"enabled": True,
}
},
},
access_token=user1_tok,
)
self._assert_to_device_response(
channel,
[{"content": test_msg, "sender": user2_id, "type": "m.test"}],
)
def test_data_incremental_sync(self) -> None:
"""Test that we get to-device messages over incremental syncs"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass", "d1")
user2_id = self.register_user("u2", "pass")
user2_tok = self.login(user2_id, "pass", "d2")
channel = self.make_request(
"POST",
self.sync_endpoint,
{
"lists": {},
"extensions": {
"to_device": {
"enabled": True,
}
},
},
access_token=user1_tok,
)
# No to-device messages yet.
next_batch = self._assert_to_device_response(channel, [])
test_msg = {"foo": "bar"}
chan = self.make_request(
"PUT",
"/_matrix/client/r0/sendToDevice/m.test/1234",
content={"messages": {user1_id: {"d1": test_msg}}},
access_token=user2_tok,
)
self.assertEqual(chan.code, 200, chan.result)
channel = self.make_request(
"POST",
self.sync_endpoint,
{
"lists": {},
"extensions": {
"to_device": {
"enabled": True,
"since": next_batch,
}
},
},
access_token=user1_tok,
)
next_batch = self._assert_to_device_response(
channel,
[{"content": test_msg, "sender": user2_id, "type": "m.test"}],
)
# The next sliding sync request should not include the to-device
# message.
channel = self.make_request(
"POST",
self.sync_endpoint,
{
"lists": {},
"extensions": {
"to_device": {
"enabled": True,
"since": next_batch,
}
},
},
access_token=user1_tok,
)
self._assert_to_device_response(channel, [])
# An initial sliding sync request should not include the to-device
# message, as it should have been deleted
channel = self.make_request(
"POST",
self.sync_endpoint,
{
"lists": {},
"extensions": {
"to_device": {
"enabled": True,
}
},
},
access_token=user1_tok,
)
self._assert_to_device_response(channel, [])

View File

@@ -440,6 +440,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
assert persist_events_store is not None
persist_events_store._store_event_txn(
txn,
events[0].room_id,
[
(e, EventContext(self.hs.get_storage_controllers(), {}))
for e in events