Compare commits

...

2 Commits

Author SHA1 Message Date
Erik Johnston
3bfc7b2e99 Newsfile 2024-07-08 15:05:40 +01:00
Erik Johnston
eab09e2262 Handle to-device extensions to Sliding Sync 2024-07-08 15:01:33 +01:00
5 changed files with 157 additions and 10 deletions

View File

@@ -0,0 +1 @@
Add to-device support to experimental sliding sync implementation.

View File

@@ -526,11 +526,15 @@ class SlidingSyncHandler:
rooms[room_id] = room_sync_result
extensions = await self.get_extensions_response(
sync_config=sync_config, to_token=to_token
)
return SlidingSyncResult(
next_pos=to_token,
lists=lists,
rooms=rooms,
extensions={},
extensions=extensions,
)
async def get_sync_room_ids_for_user(
@@ -1367,3 +1371,86 @@ class SlidingSyncHandler:
notification_count=0,
highlight_count=0,
)
async def get_extensions_response(
self,
sync_config: SlidingSyncConfig,
to_token: StreamToken,
) -> SlidingSyncResult.Extensions:
"""Handle extension requests."""
if sync_config.extensions is None:
return SlidingSyncResult.Extensions()
to_device_response = None
if sync_config.extensions.to_device:
to_device_response = await self.get_to_device_extensions_response(
sync_config=sync_config,
to_device_request=sync_config.extensions.to_device,
to_token=to_token,
)
return SlidingSyncResult.Extensions(to_device=to_device_response)
async def get_to_device_extensions_response(
self,
sync_config: SlidingSyncConfig,
to_device_request: SlidingSyncConfig.Extensions.ToDeviceExtension,
to_token: StreamToken,
) -> SlidingSyncResult.Extensions.ToDeviceExtension:
"""Handle to-device extension (MSC3885)"""
user_id = sync_config.user.to_string()
device_id = sync_config.device_id
# Check that this request has a valid device ID (not all requests have
# to belong to a device, and so device_id is None), and that the
# extension is enabled.
if device_id is None or not to_device_request.enabled:
return SlidingSyncResult.Extensions.ToDeviceExtension(
next_batch=f"{to_token.to_device_key}",
events=[],
)
since_stream_id = 0
if to_device_request.since:
# We've already validated this is an int.
since_stream_id = int(to_device_request.since)
if to_token.to_device_key < since_stream_id:
# The since token is ahead of our current token, so we return an
# empty response.
logger.warning(
"Got to-device.since from the future. Next_batch: %r, current: %r",
since_stream_id,
to_token.to_device_key,
)
return SlidingSyncResult.Extensions.ToDeviceExtension(
next_batch=to_device_request.since,
events=[],
)
# Delete everything before the given since token, as we know the
# device must have received them.
deleted = await self.store.delete_messages_for_device(
user_id=user_id,
device_id=device_id,
up_to_stream_id=since_stream_id,
)
logger.debug(
"Deleted %d to-device messages up to %d", deleted, since_stream_id
)
messages, stream_id = await self.store.get_messages_for_device(
user_id=user_id,
device_id=device_id,
from_stream_id=since_stream_id,
to_stream_id=to_token.to_device_key,
limit=min(to_device_request.limit, 100), # Limit to at most 100 events
)
return SlidingSyncResult.Extensions.ToDeviceExtension(
next_batch=f"{stream_id}",
events=messages,
)

View File

@@ -942,7 +942,9 @@ class SlidingSyncRestServlet(RestServlet):
response["rooms"] = await self.encode_rooms(
requester, sliding_sync_result.rooms
)
response["extensions"] = {} # TODO: sliding_sync_result.extensions
response["extensions"] = await self.encode_extensions(
requester, sliding_sync_result.extensions
)
return response
@@ -1053,6 +1055,19 @@ class SlidingSyncRestServlet(RestServlet):
return serialized_rooms
async def encode_extensions(
self, requester: Requester, extensions: SlidingSyncResult.Extensions
) -> JsonDict:
result = {}
if extensions.to_device is not None:
result["to_device"] = {
"next_batch": extensions.to_device.next_batch,
"events": extensions.to_device.events,
}
return result
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
SyncRestServlet(hs).register(http_server)

View File

@@ -18,7 +18,7 @@
#
#
from enum import Enum
from typing import TYPE_CHECKING, Dict, Final, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, Final, List, Optional, Sequence, Tuple
import attr
from typing_extensions import TypedDict
@@ -244,10 +244,27 @@ class SlidingSyncResult:
count: int
ops: List[Operation]
@attr.s(slots=True, frozen=True, auto_attribs=True)
class Extensions:
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ToDeviceExtension:
"""The to-device extension (MSC3885)"""
next_batch: str
events: Sequence[JsonMapping]
def __bool__(self) -> bool:
return bool(self.events)
to_device: Optional[ToDeviceExtension] = None
def __bool__(self) -> bool:
return bool(self.to_device)
next_pos: StreamToken
lists: Dict[str, SlidingWindowList]
rooms: Dict[str, RoomResult]
extensions: JsonMapping
extensions: Extensions
def __bool__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
@@ -263,5 +280,5 @@ class SlidingSyncResult:
next_pos=next_pos,
lists={},
rooms={},
extensions={},
extensions=SlidingSyncResult.Extensions(),
)

View File

@@ -276,10 +276,37 @@ class SlidingSyncBody(RequestBodyModel):
class RoomSubscription(CommonRoomParameters):
pass
class Extension(RequestBodyModel):
enabled: Optional[StrictBool] = False
lists: Optional[List[StrictStr]] = None
rooms: Optional[List[StrictStr]] = None
class Extensions(RequestBodyModel):
"""The extensions section of the request."""
class ToDeviceExtension(RequestBodyModel):
"""The to-device extension (MSC3885)
Args:
enabled
limit: Maximum number of to-device messages to return
since: The `next_batch` from the previous sync response
"""
enabled: Optional[StrictBool] = False
limit: StrictInt = 100
since: Optional[StrictStr] = None
@validator("since")
def lists_length_check(
cls, value: Optional[StrictStr]
) -> Optional[StrictStr]:
if value is None:
return value
try:
int(value)
except ValueError:
raise ValueError("'extensions.to_device.since' is invalid")
return value
to_device: Optional[ToDeviceExtension] = None
# mypy workaround via https://github.com/pydantic/pydantic/issues/156#issuecomment-1130883884
if TYPE_CHECKING:
@@ -287,7 +314,7 @@ class SlidingSyncBody(RequestBodyModel):
else:
lists: Optional[Dict[constr(max_length=64, strict=True), SlidingSyncList]] = None # type: ignore[valid-type]
room_subscriptions: Optional[Dict[StrictStr, RoomSubscription]] = None
extensions: Optional[Dict[StrictStr, Extension]] = None
extensions: Optional[Extensions] = None
@validator("lists")
def lists_length_check(