Compare commits

...

5 Commits

Author SHA1 Message Date
dependabot[bot]
85088d647b Bump pyopenssl from 24.0.0 to 24.1.0
Bumps [pyopenssl](https://github.com/pyca/pyopenssl) from 24.0.0 to 24.1.0.
- [Changelog](https://github.com/pyca/pyopenssl/blob/main/CHANGELOG.rst)
- [Commits](https://github.com/pyca/pyopenssl/compare/24.0.0...24.1.0)

---
updated-dependencies:
- dependency-name: pyopenssl
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-05-27 10:43:31 +00:00
Shay
9edb725ebc Support MSC3916 by adding unstable media endpoints to _matrix/client (#17213)
[MSC3916](https://github.com/matrix-org/matrix-spec-proposals/blob/rav/authentication-for-media/proposals/3916-authentication-for-media.md)
adds new media endpoints under `_matrix/client`. This PR adds the
`/preview_url`, `/config`, and `/thumbnail` endpoints. `/download` will
be added in a follow-up PR once the work for the federation `/download`
endpoint is complete (see
https://github.com/element-hq/synapse/pull/17172).

Should be reviewable commit-by-commit.
2024-05-24 09:47:37 +01:00
Eric Eastwood
c97251d5ba Add Sliding Sync /sync/e2ee endpoint for To-Device messages (#17167)
This is being introduced as part of Sliding Sync but doesn't have any sliding window component. It's just a way to get E2EE events without having to sit through a big initial sync  (`/sync` v2). And we can avoid encryption events being backed up by the main sync response or vice-versa.

Part of some Sliding Sync simplification/experimentation. See [this discussion](https://github.com/element-hq/synapse/pull/17167#discussion_r1610495866) for why it may not be as useful as we thought.

Based on:

 - https://github.com/matrix-org/matrix-spec-proposals/pull/3575
 - https://github.com/matrix-org/matrix-spec-proposals/pull/3885
 - https://github.com/matrix-org/matrix-spec-proposals/pull/3884
2024-05-23 12:06:16 -05:00
reivilibre
7e2412265d Log exceptions when failing to auto-join new user according to the auto_join_rooms option. (#17176)
Would have been useful for tracking down #16878.

Signed-off-by: Olivier 'reivilibre <oliverw@matrix.org>
2024-05-22 14:22:33 +01:00
reivilibre
7ef00b7628 Add logging to tasks managed by the task scheduler, showing CPU and database usage. (#17219)
The log format is the same as the request log format, except:

- fields that are specific to HTTP requests have been removed
- the task's params are included at the end of the log line.

These log lines are emitted:
- when the task function finishes — both completion and failure (and I
suppose it is possible for a task to become schedulable again?)
- every 5 minutes whilst it is running

Closes #17217.

---------

Signed-off-by: Olivier 'reivilibre <oliverw@matrix.org>
2024-05-22 14:12:58 +01:00
18 changed files with 3328 additions and 731 deletions

View File

@@ -0,0 +1 @@
Add experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync/e2ee` endpoint for To-Device messages and device encryption info.

1
changelog.d/17176.misc Normal file
View File

@@ -0,0 +1 @@
Log exceptions when failing to auto-join new user according to the `auto_join_rooms` option.

View File

@@ -0,0 +1 @@
Support MSC3916 by adding unstable media endpoints to `_matrix/client` (#17213).

View File

@@ -0,0 +1 @@
Add logging to tasks managed by the task scheduler, showing CPU and database usage.

8
poetry.lock generated
View File

@@ -1997,13 +1997,13 @@ tests = ["hypothesis (>=3.27.0)", "pytest (>=3.2.1,!=3.3.0)"]
[[package]]
name = "pyopenssl"
version = "24.0.0"
version = "24.1.0"
description = "Python wrapper module around the OpenSSL library"
optional = false
python-versions = ">=3.7"
files = [
{file = "pyOpenSSL-24.0.0-py3-none-any.whl", hash = "sha256:ba07553fb6fd6a7a2259adb9b84e12302a9a8a75c44046e8bb5d3e5ee887e3c3"},
{file = "pyOpenSSL-24.0.0.tar.gz", hash = "sha256:6aa33039a93fffa4563e655b61d11364d01264be8ccb49906101e02a334530bf"},
{file = "pyOpenSSL-24.1.0-py3-none-any.whl", hash = "sha256:17ed5be5936449c5418d1cd269a1a9e9081bc54c17aed272b45856a3d3dc86ad"},
{file = "pyOpenSSL-24.1.0.tar.gz", hash = "sha256:cabed4bfaa5df9f1a16c0ef64a0cb65318b5cd077a7eda7d6970131ca2f41a6f"},
]
[package.dependencies]
@@ -2011,7 +2011,7 @@ cryptography = ">=41.0.5,<43"
[package.extras]
docs = ["sphinx (!=5.2.0,!=5.2.0.post0,!=7.2.5)", "sphinx-rtd-theme"]
test = ["flaky", "pretend", "pytest (>=3.0.1)"]
test = ["pretend", "pytest (>=3.0.1)", "pytest-rerunfailures"]
[[package]]
name = "pysaml2"

View File

@@ -332,6 +332,9 @@ class ExperimentalConfig(Config):
# MSC3391: Removing account data.
self.msc3391_enabled = experimental.get("msc3391_enabled", False)
# MSC3575 (Sliding Sync API endpoints)
self.msc3575_enabled: bool = experimental.get("msc3575_enabled", False)
# MSC3773: Thread notifications
self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False)
@@ -436,3 +439,7 @@ class ExperimentalConfig(Config):
self.msc4115_membership_on_events = experimental.get(
"msc4115_membership_on_events", False
)
self.msc3916_authenticated_media_enabled = experimental.get(
"msc3916_authenticated_media_enabled", False
)

View File

@@ -590,7 +590,7 @@ class RegistrationHandler:
# moving away from bare excepts is a good thing to do.
logger.error("Failed to join new user to %r: %r", r, e)
except Exception as e:
logger.error("Failed to join new user to %r: %r", r, e)
logger.error("Failed to join new user to %r: %r", r, e, exc_info=True)
async def _auto_join_rooms(self, user_id: str) -> None:
"""Automatically joins users to auto join rooms - creating the room in the first place

View File

@@ -28,11 +28,14 @@ from typing import (
Dict,
FrozenSet,
List,
Literal,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Union,
overload,
)
import attr
@@ -128,6 +131,8 @@ class SyncVersion(Enum):
# Traditional `/sync` endpoint
SYNC_V2 = "sync_v2"
# Part of MSC3575 Sliding Sync
E2EE_SYNC = "e2ee_sync"
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -280,6 +285,26 @@ class SyncResult:
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class E2eeSyncResult:
"""
Attributes:
next_batch: Token for the next sync
to_device: List of direct messages for the device.
device_lists: List of user_ids whose devices have changed
device_one_time_keys_count: Dict of algorithm to count for one time keys
for this device
device_unused_fallback_key_types: List of key types that have an unused fallback
key
"""
next_batch: StreamToken
to_device: List[JsonDict]
device_lists: DeviceListUpdates
device_one_time_keys_count: JsonMapping
device_unused_fallback_key_types: List[str]
class SyncHandler:
def __init__(self, hs: "HomeServer"):
self.hs_config = hs.config
@@ -322,6 +347,31 @@ class SyncHandler:
self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync
@overload
async def wait_for_sync_for_user(
self,
requester: Requester,
sync_config: SyncConfig,
sync_version: Literal[SyncVersion.SYNC_V2],
request_key: SyncRequestKey,
since_token: Optional[StreamToken] = None,
timeout: int = 0,
full_state: bool = False,
) -> SyncResult: ...
@overload
async def wait_for_sync_for_user(
self,
requester: Requester,
sync_config: SyncConfig,
sync_version: Literal[SyncVersion.E2EE_SYNC],
request_key: SyncRequestKey,
since_token: Optional[StreamToken] = None,
timeout: int = 0,
full_state: bool = False,
) -> E2eeSyncResult: ...
@overload
async def wait_for_sync_for_user(
self,
requester: Requester,
@@ -331,7 +381,18 @@ class SyncHandler:
since_token: Optional[StreamToken] = None,
timeout: int = 0,
full_state: bool = False,
) -> SyncResult:
) -> Union[SyncResult, E2eeSyncResult]: ...
async def wait_for_sync_for_user(
self,
requester: Requester,
sync_config: SyncConfig,
sync_version: SyncVersion,
request_key: SyncRequestKey,
since_token: Optional[StreamToken] = None,
timeout: int = 0,
full_state: bool = False,
) -> Union[SyncResult, E2eeSyncResult]:
"""Get the sync for a client if we have new data for it now. Otherwise
wait for new data to arrive on the server. If the timeout expires, then
return an empty sync result.
@@ -344,8 +405,10 @@ class SyncHandler:
since_token: The point in the stream to sync from.
timeout: How long to wait for new data to arrive before giving up.
full_state: Whether to return the full state for each room.
Returns:
When `SyncVersion.SYNC_V2`, returns a full `SyncResult`.
When `SyncVersion.E2EE_SYNC`, returns a `E2eeSyncResult`.
"""
# If the user is not part of the mau group, then check that limits have
# not been exceeded (if not part of the group by this point, almost certain
@@ -366,6 +429,29 @@ class SyncHandler:
logger.debug("Returning sync response for %s", user_id)
return res
@overload
async def _wait_for_sync_for_user(
self,
sync_config: SyncConfig,
sync_version: Literal[SyncVersion.SYNC_V2],
since_token: Optional[StreamToken],
timeout: int,
full_state: bool,
cache_context: ResponseCacheContext[SyncRequestKey],
) -> SyncResult: ...
@overload
async def _wait_for_sync_for_user(
self,
sync_config: SyncConfig,
sync_version: Literal[SyncVersion.E2EE_SYNC],
since_token: Optional[StreamToken],
timeout: int,
full_state: bool,
cache_context: ResponseCacheContext[SyncRequestKey],
) -> E2eeSyncResult: ...
@overload
async def _wait_for_sync_for_user(
self,
sync_config: SyncConfig,
@@ -374,7 +460,17 @@ class SyncHandler:
timeout: int,
full_state: bool,
cache_context: ResponseCacheContext[SyncRequestKey],
) -> SyncResult:
) -> Union[SyncResult, E2eeSyncResult]: ...
async def _wait_for_sync_for_user(
self,
sync_config: SyncConfig,
sync_version: SyncVersion,
since_token: Optional[StreamToken],
timeout: int,
full_state: bool,
cache_context: ResponseCacheContext[SyncRequestKey],
) -> Union[SyncResult, E2eeSyncResult]:
"""The start of the machinery that produces a /sync response.
See https://spec.matrix.org/v1.1/client-server-api/#syncing for full details.
@@ -417,14 +513,16 @@ class SyncHandler:
if timeout == 0 or since_token is None or full_state:
# we are going to return immediately, so don't bother calling
# notifier.wait_for_events.
result: SyncResult = await self.current_sync_for_user(
sync_config, sync_version, since_token, full_state=full_state
result: Union[SyncResult, E2eeSyncResult] = (
await self.current_sync_for_user(
sync_config, sync_version, since_token, full_state=full_state
)
)
else:
# Otherwise, we wait for something to happen and report it to the user.
async def current_sync_callback(
before_token: StreamToken, after_token: StreamToken
) -> SyncResult:
) -> Union[SyncResult, E2eeSyncResult]:
return await self.current_sync_for_user(
sync_config, sync_version, since_token
)
@@ -456,14 +554,43 @@ class SyncHandler:
return result
@overload
async def current_sync_for_user(
self,
sync_config: SyncConfig,
sync_version: Literal[SyncVersion.SYNC_V2],
since_token: Optional[StreamToken] = None,
full_state: bool = False,
) -> SyncResult: ...
@overload
async def current_sync_for_user(
self,
sync_config: SyncConfig,
sync_version: Literal[SyncVersion.E2EE_SYNC],
since_token: Optional[StreamToken] = None,
full_state: bool = False,
) -> E2eeSyncResult: ...
@overload
async def current_sync_for_user(
self,
sync_config: SyncConfig,
sync_version: SyncVersion,
since_token: Optional[StreamToken] = None,
full_state: bool = False,
) -> SyncResult:
"""Generates the response body of a sync result, represented as a SyncResult.
) -> Union[SyncResult, E2eeSyncResult]: ...
async def current_sync_for_user(
self,
sync_config: SyncConfig,
sync_version: SyncVersion,
since_token: Optional[StreamToken] = None,
full_state: bool = False,
) -> Union[SyncResult, E2eeSyncResult]:
"""
Generates the response body of a sync result, represented as a
`SyncResult`/`E2eeSyncResult`.
This is a wrapper around `generate_sync_result` which starts an open tracing
span to track the sync. See `generate_sync_result` for the next part of your
@@ -474,15 +601,25 @@ class SyncHandler:
sync_version: Determines what kind of sync response to generate.
since_token: The point in the stream to sync from.p.
full_state: Whether to return the full state for each room.
Returns:
When `SyncVersion.SYNC_V2`, returns a full `SyncResult`.
When `SyncVersion.E2EE_SYNC`, returns a `E2eeSyncResult`.
"""
with start_active_span("sync.current_sync_for_user"):
log_kv({"since_token": since_token})
# Go through the `/sync` v2 path
if sync_version == SyncVersion.SYNC_V2:
sync_result: SyncResult = await self.generate_sync_result(
sync_config, since_token, full_state
sync_result: Union[SyncResult, E2eeSyncResult] = (
await self.generate_sync_result(
sync_config, since_token, full_state
)
)
# Go through the MSC3575 Sliding Sync `/sync/e2ee` path
elif sync_version == SyncVersion.E2EE_SYNC:
sync_result = await self.generate_e2ee_sync_result(
sync_config, since_token
)
else:
raise Exception(
@@ -1691,6 +1828,96 @@ class SyncHandler:
next_batch=sync_result_builder.now_token,
)
async def generate_e2ee_sync_result(
self,
sync_config: SyncConfig,
since_token: Optional[StreamToken] = None,
) -> E2eeSyncResult:
"""
Generates the response body of a MSC3575 Sliding Sync `/sync/e2ee` result.
This is represented by a `E2eeSyncResult` struct, which is built from small
pieces using a `SyncResultBuilder`. The `sync_result_builder` is passed as a
mutable ("inout") parameter to various helper functions. These retrieve and
process the data which forms the sync body, often writing to the
`sync_result_builder` to store their output.
At the end, we transfer data from the `sync_result_builder` to a new `E2eeSyncResult`
instance to signify that the sync calculation is complete.
"""
user_id = sync_config.user.to_string()
app_service = self.store.get_app_service_by_user_id(user_id)
if app_service:
# We no longer support AS users using /sync directly.
# See https://github.com/matrix-org/matrix-doc/issues/1144
raise NotImplementedError()
sync_result_builder = await self.get_sync_result_builder(
sync_config,
since_token,
full_state=False,
)
# 1. Calculate `to_device` events
await self._generate_sync_entry_for_to_device(sync_result_builder)
# 2. Calculate `device_lists`
# Device list updates are sent if a since token is provided.
device_lists = DeviceListUpdates()
include_device_list_updates = bool(since_token and since_token.device_list_key)
if include_device_list_updates:
# Note that _generate_sync_entry_for_rooms sets sync_result_builder.joined, which
# is used in calculate_user_changes below.
#
# TODO: Running `_generate_sync_entry_for_rooms()` is a lot of work just to
# figure out the membership changes/derived info needed for
# `_generate_sync_entry_for_device_list()`. In the future, we should try to
# refactor this away.
(
newly_joined_rooms,
newly_left_rooms,
) = await self._generate_sync_entry_for_rooms(sync_result_builder)
# This uses the sync_result_builder.joined which is set in
# `_generate_sync_entry_for_rooms`, if that didn't find any joined
# rooms for some reason it is a no-op.
(
newly_joined_or_invited_or_knocked_users,
newly_left_users,
) = sync_result_builder.calculate_user_changes()
device_lists = await self._generate_sync_entry_for_device_list(
sync_result_builder,
newly_joined_rooms=newly_joined_rooms,
newly_joined_or_invited_or_knocked_users=newly_joined_or_invited_or_knocked_users,
newly_left_rooms=newly_left_rooms,
newly_left_users=newly_left_users,
)
# 3. Calculate `device_one_time_keys_count` and `device_unused_fallback_key_types`
device_id = sync_config.device_id
one_time_keys_count: JsonMapping = {}
unused_fallback_key_types: List[str] = []
if device_id:
# TODO: We should have a way to let clients differentiate between the states of:
# * no change in OTK count since the provided since token
# * the server has zero OTKs left for this device
# Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298
one_time_keys_count = await self.store.count_e2e_one_time_keys(
user_id, device_id
)
unused_fallback_key_types = list(
await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
)
return E2eeSyncResult(
to_device=sync_result_builder.to_device,
device_lists=device_lists,
device_one_time_keys_count=one_time_keys_count,
device_unused_fallback_key_types=unused_fallback_key_types,
next_batch=sync_result_builder.now_token,
)
async def get_sync_result_builder(
self,
sync_config: SyncConfig,
@@ -1889,7 +2116,7 @@ class SyncHandler:
users_that_have_changed = (
await self._device_handler.get_device_changes_in_shared_rooms(
user_id,
sync_result_builder.joined_room_ids,
joined_room_ids,
from_token=since_token,
now_token=sync_result_builder.now_token,
)

View File

@@ -22,11 +22,27 @@
import logging
from io import BytesIO
from types import TracebackType
from typing import Optional, Tuple, Type
from typing import TYPE_CHECKING, List, Optional, Tuple, Type
from PIL import Image
from synapse.api.errors import Codes, SynapseError, cs_error
from synapse.config.repository import THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP
from synapse.http.server import respond_with_json
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import trace
from synapse.media._base import (
FileInfo,
ThumbnailInfo,
respond_404,
respond_with_file,
respond_with_responder,
)
from synapse.media.media_storage import MediaStorage
if TYPE_CHECKING:
from synapse.media.media_repository import MediaRepository
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -231,3 +247,471 @@ class Thumbnailer:
def __del__(self) -> None:
# Make sure we actually do close the image, rather than leak data.
self.close()
class ThumbnailProvider:
def __init__(
self,
hs: "HomeServer",
media_repo: "MediaRepository",
media_storage: MediaStorage,
):
self.hs = hs
self.media_repo = media_repo
self.media_storage = media_storage
self.store = hs.get_datastores().main
self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
async def respond_local_thumbnail(
self,
request: SynapseRequest,
media_id: str,
width: int,
height: int,
method: str,
m_type: str,
max_timeout_ms: int,
) -> None:
media_info = await self.media_repo.get_local_media_info(
request, media_id, max_timeout_ms
)
if not media_info:
return
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
await self._select_and_respond_with_thumbnail(
request,
width,
height,
method,
m_type,
thumbnail_infos,
media_id,
media_id,
url_cache=bool(media_info.url_cache),
server_name=None,
)
async def select_or_generate_local_thumbnail(
self,
request: SynapseRequest,
media_id: str,
desired_width: int,
desired_height: int,
desired_method: str,
desired_type: str,
max_timeout_ms: int,
) -> None:
media_info = await self.media_repo.get_local_media_info(
request, media_id, max_timeout_ms
)
if not media_info:
return
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
for info in thumbnail_infos:
t_w = info.width == desired_width
t_h = info.height == desired_height
t_method = info.method == desired_method
t_type = info.type == desired_type
if t_w and t_h and t_method and t_type:
file_info = FileInfo(
server_name=None,
file_id=media_id,
url_cache=bool(media_info.url_cache),
thumbnail=info,
)
responder = await self.media_storage.fetch_media(file_info)
if responder:
await respond_with_responder(
request, responder, info.type, info.length
)
return
logger.debug("We don't have a thumbnail of that size. Generating")
# Okay, so we generate one.
file_path = await self.media_repo.generate_local_exact_thumbnail(
media_id,
desired_width,
desired_height,
desired_method,
desired_type,
url_cache=bool(media_info.url_cache),
)
if file_path:
await respond_with_file(request, desired_type, file_path)
else:
logger.warning("Failed to generate thumbnail")
raise SynapseError(400, "Failed to generate thumbnail.")
async def select_or_generate_remote_thumbnail(
self,
request: SynapseRequest,
server_name: str,
media_id: str,
desired_width: int,
desired_height: int,
desired_method: str,
desired_type: str,
max_timeout_ms: int,
) -> None:
media_info = await self.media_repo.get_remote_media_info(
server_name, media_id, max_timeout_ms
)
if not media_info:
respond_404(request)
return
thumbnail_infos = await self.store.get_remote_media_thumbnails(
server_name, media_id
)
file_id = media_info.filesystem_id
for info in thumbnail_infos:
t_w = info.width == desired_width
t_h = info.height == desired_height
t_method = info.method == desired_method
t_type = info.type == desired_type
if t_w and t_h and t_method and t_type:
file_info = FileInfo(
server_name=server_name,
file_id=file_id,
thumbnail=info,
)
responder = await self.media_storage.fetch_media(file_info)
if responder:
await respond_with_responder(
request, responder, info.type, info.length
)
return
logger.debug("We don't have a thumbnail of that size. Generating")
# Okay, so we generate one.
file_path = await self.media_repo.generate_remote_exact_thumbnail(
server_name,
file_id,
media_id,
desired_width,
desired_height,
desired_method,
desired_type,
)
if file_path:
await respond_with_file(request, desired_type, file_path)
else:
logger.warning("Failed to generate thumbnail")
raise SynapseError(400, "Failed to generate thumbnail.")
async def respond_remote_thumbnail(
self,
request: SynapseRequest,
server_name: str,
media_id: str,
width: int,
height: int,
method: str,
m_type: str,
max_timeout_ms: int,
) -> None:
# TODO: Don't download the whole remote file
# We should proxy the thumbnail from the remote server instead of
# downloading the remote file and generating our own thumbnails.
media_info = await self.media_repo.get_remote_media_info(
server_name, media_id, max_timeout_ms
)
if not media_info:
return
thumbnail_infos = await self.store.get_remote_media_thumbnails(
server_name, media_id
)
await self._select_and_respond_with_thumbnail(
request,
width,
height,
method,
m_type,
thumbnail_infos,
media_id,
media_info.filesystem_id,
url_cache=False,
server_name=server_name,
)
async def _select_and_respond_with_thumbnail(
self,
request: SynapseRequest,
desired_width: int,
desired_height: int,
desired_method: str,
desired_type: str,
thumbnail_infos: List[ThumbnailInfo],
media_id: str,
file_id: str,
url_cache: bool,
server_name: Optional[str] = None,
) -> None:
"""
Respond to a request with an appropriate thumbnail from the previously generated thumbnails.
Args:
request: The incoming request.
desired_width: The desired width, the returned thumbnail may be larger than this.
desired_height: The desired height, the returned thumbnail may be larger than this.
desired_method: The desired method used to generate the thumbnail.
desired_type: The desired content-type of the thumbnail.
thumbnail_infos: A list of thumbnail info of candidate thumbnails.
file_id: The ID of the media that a thumbnail is being requested for.
url_cache: True if this is from a URL cache.
server_name: The server name, if this is a remote thumbnail.
"""
logger.debug(
"_select_and_respond_with_thumbnail: media_id=%s desired=%sx%s (%s) thumbnail_infos=%s",
media_id,
desired_width,
desired_height,
desired_method,
thumbnail_infos,
)
# If `dynamic_thumbnails` is enabled, we expect Synapse to go down a
# different code path to handle it.
assert not self.dynamic_thumbnails
if thumbnail_infos:
file_info = self._select_thumbnail(
desired_width,
desired_height,
desired_method,
desired_type,
thumbnail_infos,
file_id,
url_cache,
server_name,
)
if not file_info:
logger.info("Couldn't find a thumbnail matching the desired inputs")
respond_404(request)
return
# The thumbnail property must exist.
assert file_info.thumbnail is not None
responder = await self.media_storage.fetch_media(file_info)
if responder:
await respond_with_responder(
request,
responder,
file_info.thumbnail.type,
file_info.thumbnail.length,
)
return
# If we can't find the thumbnail we regenerate it. This can happen
# if e.g. we've deleted the thumbnails but still have the original
# image somewhere.
#
# Since we have an entry for the thumbnail in the DB we a) know we
# have have successfully generated the thumbnail in the past (so we
# don't need to worry about repeatedly failing to generate
# thumbnails), and b) have already calculated that appropriate
# width/height/method so we can just call the "generate exact"
# methods.
# First let's check that we do actually have the original image
# still. This will throw a 404 if we don't.
# TODO: We should refetch the thumbnails for remote media.
await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=url_cache)
)
if server_name:
await self.media_repo.generate_remote_exact_thumbnail(
server_name,
file_id=file_id,
media_id=media_id,
t_width=file_info.thumbnail.width,
t_height=file_info.thumbnail.height,
t_method=file_info.thumbnail.method,
t_type=file_info.thumbnail.type,
)
else:
await self.media_repo.generate_local_exact_thumbnail(
media_id=media_id,
t_width=file_info.thumbnail.width,
t_height=file_info.thumbnail.height,
t_method=file_info.thumbnail.method,
t_type=file_info.thumbnail.type,
url_cache=url_cache,
)
responder = await self.media_storage.fetch_media(file_info)
await respond_with_responder(
request,
responder,
file_info.thumbnail.type,
file_info.thumbnail.length,
)
else:
# This might be because:
# 1. We can't create thumbnails for the given media (corrupted or
# unsupported file type), or
# 2. The thumbnailing process never ran or errored out initially
# when the media was first uploaded (these bugs should be
# reported and fixed).
# Note that we don't attempt to generate a thumbnail now because
# `dynamic_thumbnails` is disabled.
logger.info("Failed to find any generated thumbnails")
assert request.path is not None
respond_with_json(
request,
400,
cs_error(
"Cannot find any thumbnails for the requested media ('%s'). This might mean the media is not a supported_media_format=(%s) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)"
% (
request.path.decode(),
", ".join(THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP.keys()),
),
code=Codes.UNKNOWN,
),
send_cors=True,
)
def _select_thumbnail(
self,
desired_width: int,
desired_height: int,
desired_method: str,
desired_type: str,
thumbnail_infos: List[ThumbnailInfo],
file_id: str,
url_cache: bool,
server_name: Optional[str],
) -> Optional[FileInfo]:
"""
Choose an appropriate thumbnail from the previously generated thumbnails.
Args:
desired_width: The desired width, the returned thumbnail may be larger than this.
desired_height: The desired height, the returned thumbnail may be larger than this.
desired_method: The desired method used to generate the thumbnail.
desired_type: The desired content-type of the thumbnail.
thumbnail_infos: A list of thumbnail infos of candidate thumbnails.
file_id: The ID of the media that a thumbnail is being requested for.
url_cache: True if this is from a URL cache.
server_name: The server name, if this is a remote thumbnail.
Returns:
The thumbnail which best matches the desired parameters.
"""
desired_method = desired_method.lower()
# The chosen thumbnail.
thumbnail_info = None
d_w = desired_width
d_h = desired_height
if desired_method == "crop":
# Thumbnails that match equal or larger sizes of desired width/height.
crop_info_list: List[
Tuple[int, int, int, bool, Optional[int], ThumbnailInfo]
] = []
# Other thumbnails.
crop_info_list2: List[
Tuple[int, int, int, bool, Optional[int], ThumbnailInfo]
] = []
for info in thumbnail_infos:
# Skip thumbnails generated with different methods.
if info.method != "crop":
continue
t_w = info.width
t_h = info.height
aspect_quality = abs(d_w * t_h - d_h * t_w)
min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
size_quality = abs((d_w - t_w) * (d_h - t_h))
type_quality = desired_type != info.type
length_quality = info.length
if t_w >= d_w or t_h >= d_h:
crop_info_list.append(
(
aspect_quality,
min_quality,
size_quality,
type_quality,
length_quality,
info,
)
)
else:
crop_info_list2.append(
(
aspect_quality,
min_quality,
size_quality,
type_quality,
length_quality,
info,
)
)
# Pick the most appropriate thumbnail. Some values of `desired_width` and
# `desired_height` may result in a tie, in which case we avoid comparing on
# the thumbnail info and pick the thumbnail that appears earlier
# in the list of candidates.
if crop_info_list:
thumbnail_info = min(crop_info_list, key=lambda t: t[:-1])[-1]
elif crop_info_list2:
thumbnail_info = min(crop_info_list2, key=lambda t: t[:-1])[-1]
elif desired_method == "scale":
# Thumbnails that match equal or larger sizes of desired width/height.
info_list: List[Tuple[int, bool, int, ThumbnailInfo]] = []
# Other thumbnails.
info_list2: List[Tuple[int, bool, int, ThumbnailInfo]] = []
for info in thumbnail_infos:
# Skip thumbnails generated with different methods.
if info.method != "scale":
continue
t_w = info.width
t_h = info.height
size_quality = abs((d_w - t_w) * (d_h - t_h))
type_quality = desired_type != info.type
length_quality = info.length
if t_w >= d_w or t_h >= d_h:
info_list.append((size_quality, type_quality, length_quality, info))
else:
info_list2.append(
(size_quality, type_quality, length_quality, info)
)
# Pick the most appropriate thumbnail. Some values of `desired_width` and
# `desired_height` may result in a tie, in which case we avoid comparing on
# the thumbnail info and pick the thumbnail that appears earlier
# in the list of candidates.
if info_list:
thumbnail_info = min(info_list, key=lambda t: t[:-1])[-1]
elif info_list2:
thumbnail_info = min(info_list2, key=lambda t: t[:-1])[-1]
if thumbnail_info:
return FileInfo(
file_id=file_id,
url_cache=url_cache,
server_name=server_name,
thumbnail=thumbnail_info,
)
# No matching thumbnail was found.
return None

View File

@@ -0,0 +1,205 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright (C) 2024 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#
import logging
import re
from synapse.http.server import (
HttpServer,
respond_with_json,
respond_with_json_bytes,
set_corp_headers,
set_cors_headers,
)
from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.media._base import (
DEFAULT_MAX_TIMEOUT_MS,
MAXIMUM_ALLOWED_MAX_TIMEOUT_MS,
respond_404,
)
from synapse.media.media_repository import MediaRepository
from synapse.media.media_storage import MediaStorage
from synapse.media.thumbnailer import ThumbnailProvider
from synapse.server import HomeServer
from synapse.util.stringutils import parse_and_validate_server_name
logger = logging.getLogger(__name__)
class UnstablePreviewURLServlet(RestServlet):
"""
Same as `GET /_matrix/media/r0/preview_url`, this endpoint provides a generic preview API
for URLs which outputs Open Graph (https://ogp.me/) responses (with some Matrix
specific additions).
This does have trade-offs compared to other designs:
* Pros:
* Simple and flexible; can be used by any clients at any point
* Cons:
* If each homeserver provides one of these independently, all the homeservers in a
room may needlessly DoS the target URI
* The URL metadata must be stored somewhere, rather than just using Matrix
itself to store the media.
* Matrix cannot be used to distribute the metadata between homeservers.
"""
PATTERNS = [
re.compile(r"^/_matrix/client/unstable/org.matrix.msc3916/media/preview_url$")
]
def __init__(
self,
hs: "HomeServer",
media_repo: "MediaRepository",
media_storage: MediaStorage,
):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.media_repo = media_repo
self.media_storage = media_storage
assert self.media_repo.url_previewer is not None
self.url_previewer = self.media_repo.url_previewer
async def on_GET(self, request: SynapseRequest) -> None:
requester = await self.auth.get_user_by_req(request)
url = parse_string(request, "url", required=True)
ts = parse_integer(request, "ts")
if ts is None:
ts = self.clock.time_msec()
og = await self.url_previewer.preview(url, requester.user, ts)
respond_with_json_bytes(request, 200, og, send_cors=True)
class UnstableMediaConfigResource(RestServlet):
PATTERNS = [
re.compile(r"^/_matrix/client/unstable/org.matrix.msc3916/media/config$")
]
def __init__(self, hs: "HomeServer"):
super().__init__()
config = hs.config
self.clock = hs.get_clock()
self.auth = hs.get_auth()
self.limits_dict = {"m.upload.size": config.media.max_upload_size}
async def on_GET(self, request: SynapseRequest) -> None:
await self.auth.get_user_by_req(request)
respond_with_json(request, 200, self.limits_dict, send_cors=True)
class UnstableThumbnailResource(RestServlet):
PATTERNS = [
re.compile(
"/_matrix/client/unstable/org.matrix.msc3916/media/thumbnail/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$"
)
]
def __init__(
self,
hs: "HomeServer",
media_repo: "MediaRepository",
media_storage: MediaStorage,
):
super().__init__()
self.store = hs.get_datastores().main
self.media_repo = media_repo
self.media_storage = media_storage
self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
self._is_mine_server_name = hs.is_mine_server_name
self._server_name = hs.hostname
self.prevent_media_downloads_from = hs.config.media.prevent_media_downloads_from
self.thumbnailer = ThumbnailProvider(hs, media_repo, media_storage)
self.auth = hs.get_auth()
async def on_GET(
self, request: SynapseRequest, server_name: str, media_id: str
) -> None:
# Validate the server name, raising if invalid
parse_and_validate_server_name(server_name)
await self.auth.get_user_by_req(request)
set_cors_headers(request)
set_corp_headers(request)
width = parse_integer(request, "width", required=True)
height = parse_integer(request, "height", required=True)
method = parse_string(request, "method", "scale")
# TODO Parse the Accept header to get an prioritised list of thumbnail types.
m_type = "image/png"
max_timeout_ms = parse_integer(
request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS
)
max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS)
if self._is_mine_server_name(server_name):
if self.dynamic_thumbnails:
await self.thumbnailer.select_or_generate_local_thumbnail(
request, media_id, width, height, method, m_type, max_timeout_ms
)
else:
await self.thumbnailer.respond_local_thumbnail(
request, media_id, width, height, method, m_type, max_timeout_ms
)
self.media_repo.mark_recently_accessed(None, media_id)
else:
# Don't let users download media from configured domains, even if it
# is already downloaded. This is Trust & Safety tooling to make some
# media inaccessible to local users.
# See `prevent_media_downloads_from` config docs for more info.
if server_name in self.prevent_media_downloads_from:
respond_404(request)
return
remote_resp_function = (
self.thumbnailer.select_or_generate_remote_thumbnail
if self.dynamic_thumbnails
else self.thumbnailer.respond_remote_thumbnail
)
await remote_resp_function(
request,
server_name,
media_id,
width,
height,
method,
m_type,
max_timeout_ms,
)
self.media_repo.mark_recently_accessed(server_name, media_id)
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
if hs.config.experimental.msc3916_authenticated_media_enabled:
media_repo = hs.get_media_repository()
if hs.config.media.url_preview_enabled:
UnstablePreviewURLServlet(
hs, media_repo, media_repo.media_storage
).register(http_server)
UnstableMediaConfigResource(hs).register(http_server)
UnstableThumbnailResource(hs, media_repo, media_repo.media_storage).register(
http_server
)

View File

@@ -567,5 +567,176 @@ class SyncRestServlet(RestServlet):
return result
class SlidingSyncE2eeRestServlet(RestServlet):
"""
API endpoint for MSC3575 Sliding Sync `/sync/e2ee`. This is being introduced as part
of Sliding Sync but doesn't have any sliding window component. It's just a way to
get E2EE events without having to sit through a big initial sync (`/sync` v2). And
we can avoid encryption events being backed up by the main sync response.
Having To-Device messages split out to this sync endpoint also helps when clients
need to have 2 or more sync streams open at a time, e.g a push notification process
and a main process. This can cause the two processes to race to fetch the To-Device
events, resulting in the need for complex synchronisation rules to ensure the token
is correctly and atomically exchanged between processes.
GET parameters::
timeout(int): How long to wait for new events in milliseconds.
since(batch_token): Batch token when asking for incremental deltas.
Response JSON::
{
"next_batch": // batch token for the next /sync
"to_device": {
// list of to-device events
"events": [
{
"content: { "algorithm": "m.olm.v1.curve25519-aes-sha2", "ciphertext": { ... }, "org.matrix.msgid": "abcd", "session_id": "abcd" },
"type": "m.room.encrypted",
"sender": "@alice:example.com",
}
// ...
]
},
"device_lists": {
"changed": ["@alice:example.com"],
"left": ["@bob:example.com"]
},
"device_one_time_keys_count": {
"signed_curve25519": 50
},
"device_unused_fallback_key_types": [
"signed_curve25519"
]
}
"""
PATTERNS = client_patterns(
"/org.matrix.msc3575/sync/e2ee$", releases=[], v1=False, unstable=True
)
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
self.sync_handler = hs.get_sync_handler()
# Filtering only matters for the `device_lists` because it requires a bunch of
# derived information from rooms (see how `_generate_sync_entry_for_rooms()`
# prepares a bunch of data for `_generate_sync_entry_for_device_list()`).
self.only_member_events_filter_collection = FilterCollection(
self.hs,
{
"room": {
# We only care about membership events for the `device_lists`.
# Membership will tell us whether a user has joined/left a room and
# if there are new devices to encrypt for.
"timeline": {
"types": ["m.room.member"],
},
"state": {
"types": ["m.room.member"],
},
# We don't want any extra account_data generated because it's not
# returned by this endpoint. This helps us avoid work in
# `_generate_sync_entry_for_rooms()`
"account_data": {
"not_types": ["*"],
},
# We don't want any extra ephemeral data generated because it's not
# returned by this endpoint. This helps us avoid work in
# `_generate_sync_entry_for_rooms()`
"ephemeral": {
"not_types": ["*"],
},
},
# We don't want any extra account_data generated because it's not
# returned by this endpoint. (This is just here for good measure)
"account_data": {
"not_types": ["*"],
},
# We don't want any extra presence data generated because it's not
# returned by this endpoint. (This is just here for good measure)
"presence": {
"not_types": ["*"],
},
},
)
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user = requester.user
device_id = requester.device_id
timeout = parse_integer(request, "timeout", default=0)
since = parse_string(request, "since")
sync_config = SyncConfig(
user=user,
filter_collection=self.only_member_events_filter_collection,
is_guest=requester.is_guest,
device_id=device_id,
)
since_token = None
if since is not None:
since_token = await StreamToken.from_string(self.store, since)
# Request cache key
request_key = (
SyncVersion.E2EE_SYNC,
user,
timeout,
since,
)
# Gather data for the response
sync_result = await self.sync_handler.wait_for_sync_for_user(
requester,
sync_config,
SyncVersion.E2EE_SYNC,
request_key,
since_token=since_token,
timeout=timeout,
full_state=False,
)
# The client may have disconnected by now; don't bother to serialize the
# response if so.
if request._disconnected:
logger.info("Client has disconnected; not serializing response.")
return 200, {}
response: JsonDict = defaultdict(dict)
response["next_batch"] = await sync_result.next_batch.to_string(self.store)
if sync_result.to_device:
response["to_device"] = {"events": sync_result.to_device}
if sync_result.device_lists.changed:
response["device_lists"]["changed"] = list(sync_result.device_lists.changed)
if sync_result.device_lists.left:
response["device_lists"]["left"] = list(sync_result.device_lists.left)
# We always include this because https://github.com/vector-im/element-android/issues/3725
# The spec isn't terribly clear on when this can be omitted and how a client would tell
# the difference between "no keys present" and "nothing changed" in terms of whole field
# absent / individual key type entry absent
# Corresponding synapse issue: https://github.com/matrix-org/synapse/issues/10456
response["device_one_time_keys_count"] = sync_result.device_one_time_keys_count
# https://github.com/matrix-org/matrix-doc/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md
# states that this field should always be included, as long as the server supports the feature.
response["device_unused_fallback_key_types"] = (
sync_result.device_unused_fallback_key_types
)
return 200, response
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
SyncRestServlet(hs).register(http_server)
if hs.config.experimental.msc3575_enabled:
SlidingSyncE2eeRestServlet(hs).register(http_server)

View File

@@ -22,23 +22,18 @@
import logging
import re
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING
from synapse.api.errors import Codes, SynapseError, cs_error
from synapse.config.repository import THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP
from synapse.http.server import respond_with_json, set_corp_headers, set_cors_headers
from synapse.http.server import set_corp_headers, set_cors_headers
from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.media._base import (
DEFAULT_MAX_TIMEOUT_MS,
MAXIMUM_ALLOWED_MAX_TIMEOUT_MS,
FileInfo,
ThumbnailInfo,
respond_404,
respond_with_file,
respond_with_responder,
)
from synapse.media.media_storage import MediaStorage
from synapse.media.thumbnailer import ThumbnailProvider
from synapse.util.stringutils import parse_and_validate_server_name
if TYPE_CHECKING:
@@ -66,10 +61,11 @@ class ThumbnailResource(RestServlet):
self.store = hs.get_datastores().main
self.media_repo = media_repo
self.media_storage = media_storage
self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
self._is_mine_server_name = hs.is_mine_server_name
self._server_name = hs.hostname
self.prevent_media_downloads_from = hs.config.media.prevent_media_downloads_from
self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
self.thumbnail_provider = ThumbnailProvider(hs, media_repo, media_storage)
async def on_GET(
self, request: SynapseRequest, server_name: str, media_id: str
@@ -91,11 +87,11 @@ class ThumbnailResource(RestServlet):
if self._is_mine_server_name(server_name):
if self.dynamic_thumbnails:
await self._select_or_generate_local_thumbnail(
await self.thumbnail_provider.select_or_generate_local_thumbnail(
request, media_id, width, height, method, m_type, max_timeout_ms
)
else:
await self._respond_local_thumbnail(
await self.thumbnail_provider.respond_local_thumbnail(
request, media_id, width, height, method, m_type, max_timeout_ms
)
self.media_repo.mark_recently_accessed(None, media_id)
@@ -109,9 +105,9 @@ class ThumbnailResource(RestServlet):
return
remote_resp_function = (
self._select_or_generate_remote_thumbnail
self.thumbnail_provider.select_or_generate_remote_thumbnail
if self.dynamic_thumbnails
else self._respond_remote_thumbnail
else self.thumbnail_provider.respond_remote_thumbnail
)
await remote_resp_function(
request,
@@ -124,457 +120,3 @@ class ThumbnailResource(RestServlet):
max_timeout_ms,
)
self.media_repo.mark_recently_accessed(server_name, media_id)
async def _respond_local_thumbnail(
self,
request: SynapseRequest,
media_id: str,
width: int,
height: int,
method: str,
m_type: str,
max_timeout_ms: int,
) -> None:
media_info = await self.media_repo.get_local_media_info(
request, media_id, max_timeout_ms
)
if not media_info:
return
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
await self._select_and_respond_with_thumbnail(
request,
width,
height,
method,
m_type,
thumbnail_infos,
media_id,
media_id,
url_cache=bool(media_info.url_cache),
server_name=None,
)
async def _select_or_generate_local_thumbnail(
self,
request: SynapseRequest,
media_id: str,
desired_width: int,
desired_height: int,
desired_method: str,
desired_type: str,
max_timeout_ms: int,
) -> None:
media_info = await self.media_repo.get_local_media_info(
request, media_id, max_timeout_ms
)
if not media_info:
return
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
for info in thumbnail_infos:
t_w = info.width == desired_width
t_h = info.height == desired_height
t_method = info.method == desired_method
t_type = info.type == desired_type
if t_w and t_h and t_method and t_type:
file_info = FileInfo(
server_name=None,
file_id=media_id,
url_cache=bool(media_info.url_cache),
thumbnail=info,
)
responder = await self.media_storage.fetch_media(file_info)
if responder:
await respond_with_responder(
request, responder, info.type, info.length
)
return
logger.debug("We don't have a thumbnail of that size. Generating")
# Okay, so we generate one.
file_path = await self.media_repo.generate_local_exact_thumbnail(
media_id,
desired_width,
desired_height,
desired_method,
desired_type,
url_cache=bool(media_info.url_cache),
)
if file_path:
await respond_with_file(request, desired_type, file_path)
else:
logger.warning("Failed to generate thumbnail")
raise SynapseError(400, "Failed to generate thumbnail.")
async def _select_or_generate_remote_thumbnail(
self,
request: SynapseRequest,
server_name: str,
media_id: str,
desired_width: int,
desired_height: int,
desired_method: str,
desired_type: str,
max_timeout_ms: int,
) -> None:
media_info = await self.media_repo.get_remote_media_info(
server_name, media_id, max_timeout_ms
)
if not media_info:
respond_404(request)
return
thumbnail_infos = await self.store.get_remote_media_thumbnails(
server_name, media_id
)
file_id = media_info.filesystem_id
for info in thumbnail_infos:
t_w = info.width == desired_width
t_h = info.height == desired_height
t_method = info.method == desired_method
t_type = info.type == desired_type
if t_w and t_h and t_method and t_type:
file_info = FileInfo(
server_name=server_name,
file_id=file_id,
thumbnail=info,
)
responder = await self.media_storage.fetch_media(file_info)
if responder:
await respond_with_responder(
request, responder, info.type, info.length
)
return
logger.debug("We don't have a thumbnail of that size. Generating")
# Okay, so we generate one.
file_path = await self.media_repo.generate_remote_exact_thumbnail(
server_name,
file_id,
media_id,
desired_width,
desired_height,
desired_method,
desired_type,
)
if file_path:
await respond_with_file(request, desired_type, file_path)
else:
logger.warning("Failed to generate thumbnail")
raise SynapseError(400, "Failed to generate thumbnail.")
async def _respond_remote_thumbnail(
self,
request: SynapseRequest,
server_name: str,
media_id: str,
width: int,
height: int,
method: str,
m_type: str,
max_timeout_ms: int,
) -> None:
# TODO: Don't download the whole remote file
# We should proxy the thumbnail from the remote server instead of
# downloading the remote file and generating our own thumbnails.
media_info = await self.media_repo.get_remote_media_info(
server_name, media_id, max_timeout_ms
)
if not media_info:
return
thumbnail_infos = await self.store.get_remote_media_thumbnails(
server_name, media_id
)
await self._select_and_respond_with_thumbnail(
request,
width,
height,
method,
m_type,
thumbnail_infos,
media_id,
media_info.filesystem_id,
url_cache=False,
server_name=server_name,
)
async def _select_and_respond_with_thumbnail(
self,
request: SynapseRequest,
desired_width: int,
desired_height: int,
desired_method: str,
desired_type: str,
thumbnail_infos: List[ThumbnailInfo],
media_id: str,
file_id: str,
url_cache: bool,
server_name: Optional[str] = None,
) -> None:
"""
Respond to a request with an appropriate thumbnail from the previously generated thumbnails.
Args:
request: The incoming request.
desired_width: The desired width, the returned thumbnail may be larger than this.
desired_height: The desired height, the returned thumbnail may be larger than this.
desired_method: The desired method used to generate the thumbnail.
desired_type: The desired content-type of the thumbnail.
thumbnail_infos: A list of thumbnail info of candidate thumbnails.
file_id: The ID of the media that a thumbnail is being requested for.
url_cache: True if this is from a URL cache.
server_name: The server name, if this is a remote thumbnail.
"""
logger.debug(
"_select_and_respond_with_thumbnail: media_id=%s desired=%sx%s (%s) thumbnail_infos=%s",
media_id,
desired_width,
desired_height,
desired_method,
thumbnail_infos,
)
# If `dynamic_thumbnails` is enabled, we expect Synapse to go down a
# different code path to handle it.
assert not self.dynamic_thumbnails
if thumbnail_infos:
file_info = self._select_thumbnail(
desired_width,
desired_height,
desired_method,
desired_type,
thumbnail_infos,
file_id,
url_cache,
server_name,
)
if not file_info:
logger.info("Couldn't find a thumbnail matching the desired inputs")
respond_404(request)
return
# The thumbnail property must exist.
assert file_info.thumbnail is not None
responder = await self.media_storage.fetch_media(file_info)
if responder:
await respond_with_responder(
request,
responder,
file_info.thumbnail.type,
file_info.thumbnail.length,
)
return
# If we can't find the thumbnail we regenerate it. This can happen
# if e.g. we've deleted the thumbnails but still have the original
# image somewhere.
#
# Since we have an entry for the thumbnail in the DB we a) know we
# have have successfully generated the thumbnail in the past (so we
# don't need to worry about repeatedly failing to generate
# thumbnails), and b) have already calculated that appropriate
# width/height/method so we can just call the "generate exact"
# methods.
# First let's check that we do actually have the original image
# still. This will throw a 404 if we don't.
# TODO: We should refetch the thumbnails for remote media.
await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=url_cache)
)
if server_name:
await self.media_repo.generate_remote_exact_thumbnail(
server_name,
file_id=file_id,
media_id=media_id,
t_width=file_info.thumbnail.width,
t_height=file_info.thumbnail.height,
t_method=file_info.thumbnail.method,
t_type=file_info.thumbnail.type,
)
else:
await self.media_repo.generate_local_exact_thumbnail(
media_id=media_id,
t_width=file_info.thumbnail.width,
t_height=file_info.thumbnail.height,
t_method=file_info.thumbnail.method,
t_type=file_info.thumbnail.type,
url_cache=url_cache,
)
responder = await self.media_storage.fetch_media(file_info)
await respond_with_responder(
request,
responder,
file_info.thumbnail.type,
file_info.thumbnail.length,
)
else:
# This might be because:
# 1. We can't create thumbnails for the given media (corrupted or
# unsupported file type), or
# 2. The thumbnailing process never ran or errored out initially
# when the media was first uploaded (these bugs should be
# reported and fixed).
# Note that we don't attempt to generate a thumbnail now because
# `dynamic_thumbnails` is disabled.
logger.info("Failed to find any generated thumbnails")
assert request.path is not None
respond_with_json(
request,
400,
cs_error(
"Cannot find any thumbnails for the requested media ('%s'). This might mean the media is not a supported_media_format=(%s) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)"
% (
request.path.decode(),
", ".join(THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP.keys()),
),
code=Codes.UNKNOWN,
),
send_cors=True,
)
def _select_thumbnail(
self,
desired_width: int,
desired_height: int,
desired_method: str,
desired_type: str,
thumbnail_infos: List[ThumbnailInfo],
file_id: str,
url_cache: bool,
server_name: Optional[str],
) -> Optional[FileInfo]:
"""
Choose an appropriate thumbnail from the previously generated thumbnails.
Args:
desired_width: The desired width, the returned thumbnail may be larger than this.
desired_height: The desired height, the returned thumbnail may be larger than this.
desired_method: The desired method used to generate the thumbnail.
desired_type: The desired content-type of the thumbnail.
thumbnail_infos: A list of thumbnail infos of candidate thumbnails.
file_id: The ID of the media that a thumbnail is being requested for.
url_cache: True if this is from a URL cache.
server_name: The server name, if this is a remote thumbnail.
Returns:
The thumbnail which best matches the desired parameters.
"""
desired_method = desired_method.lower()
# The chosen thumbnail.
thumbnail_info = None
d_w = desired_width
d_h = desired_height
if desired_method == "crop":
# Thumbnails that match equal or larger sizes of desired width/height.
crop_info_list: List[
Tuple[int, int, int, bool, Optional[int], ThumbnailInfo]
] = []
# Other thumbnails.
crop_info_list2: List[
Tuple[int, int, int, bool, Optional[int], ThumbnailInfo]
] = []
for info in thumbnail_infos:
# Skip thumbnails generated with different methods.
if info.method != "crop":
continue
t_w = info.width
t_h = info.height
aspect_quality = abs(d_w * t_h - d_h * t_w)
min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
size_quality = abs((d_w - t_w) * (d_h - t_h))
type_quality = desired_type != info.type
length_quality = info.length
if t_w >= d_w or t_h >= d_h:
crop_info_list.append(
(
aspect_quality,
min_quality,
size_quality,
type_quality,
length_quality,
info,
)
)
else:
crop_info_list2.append(
(
aspect_quality,
min_quality,
size_quality,
type_quality,
length_quality,
info,
)
)
# Pick the most appropriate thumbnail. Some values of `desired_width` and
# `desired_height` may result in a tie, in which case we avoid comparing on
# the thumbnail info and pick the thumbnail that appears earlier
# in the list of candidates.
if crop_info_list:
thumbnail_info = min(crop_info_list, key=lambda t: t[:-1])[-1]
elif crop_info_list2:
thumbnail_info = min(crop_info_list2, key=lambda t: t[:-1])[-1]
elif desired_method == "scale":
# Thumbnails that match equal or larger sizes of desired width/height.
info_list: List[Tuple[int, bool, int, ThumbnailInfo]] = []
# Other thumbnails.
info_list2: List[Tuple[int, bool, int, ThumbnailInfo]] = []
for info in thumbnail_infos:
# Skip thumbnails generated with different methods.
if info.method != "scale":
continue
t_w = info.width
t_h = info.height
size_quality = abs((d_w - t_w) * (d_h - t_h))
type_quality = desired_type != info.type
length_quality = info.length
if t_w >= d_w or t_h >= d_h:
info_list.append((size_quality, type_quality, length_quality, info))
else:
info_list2.append(
(size_quality, type_quality, length_quality, info)
)
# Pick the most appropriate thumbnail. Some values of `desired_width` and
# `desired_height` may result in a tie, in which case we avoid comparing on
# the thumbnail info and pick the thumbnail that appears earlier
# in the list of candidates.
if info_list:
thumbnail_info = min(info_list, key=lambda t: t[:-1])[-1]
elif info_list2:
thumbnail_info = min(info_list2, key=lambda t: t[:-1])[-1]
if thumbnail_info:
return FileInfo(
file_id=file_id,
url_cache=url_cache,
server_name=server_name,
thumbnail=thumbnail_info,
)
# No matching thumbnail was found.
return None

View File

@@ -24,7 +24,12 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Set
from twisted.python.failure import Failure
from synapse.logging.context import nested_logging_context
from synapse.logging.context import (
ContextResourceUsage,
LoggingContext,
nested_logging_context,
set_current_context,
)
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import (
run_as_background_process,
@@ -81,6 +86,8 @@ class TaskScheduler:
MAX_CONCURRENT_RUNNING_TASKS = 5
# Time from the last task update after which we will log a warning
LAST_UPDATE_BEFORE_WARNING_MS = 24 * 60 * 60 * 1000 # 24hrs
# Report a running task's status and usage every so often.
OCCASIONAL_REPORT_INTERVAL_MS = 5 * 60 * 1000 # 5 minutes
def __init__(self, hs: "HomeServer"):
self._hs = hs
@@ -346,6 +353,33 @@ class TaskScheduler:
assert task.id not in self._running_tasks
await self._store.delete_scheduled_task(task.id)
@staticmethod
def _log_task_usage(
state: str, task: ScheduledTask, usage: ContextResourceUsage, active_time: float
) -> None:
"""
Log a line describing the state and usage of a task.
The log line is inspired by / a copy of the request log line format,
but with irrelevant fields removed.
active_time: Time that the task has been running for, in seconds.
"""
logger.info(
"Task %s: %.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
" [%d dbevts] %r, %r",
state,
active_time,
usage.ru_utime,
usage.ru_stime,
usage.db_sched_duration_sec,
usage.db_txn_duration_sec,
int(usage.db_txn_count),
usage.evt_db_fetch_count,
task.resource_id,
task.params,
)
async def _launch_task(self, task: ScheduledTask) -> None:
"""Launch a scheduled task now.
@@ -360,8 +394,32 @@ class TaskScheduler:
)
function = self._actions[task.action]
def _occasional_report(
task_log_context: LoggingContext, start_time: float
) -> None:
"""
Helper to log a 'Task continuing' line every so often.
"""
current_time = self._clock.time()
calling_context = set_current_context(task_log_context)
try:
usage = task_log_context.get_resource_usage()
TaskScheduler._log_task_usage(
"continuing", task, usage, current_time - start_time
)
finally:
set_current_context(calling_context)
async def wrapper() -> None:
with nested_logging_context(task.id):
with nested_logging_context(task.id) as log_context:
start_time = self._clock.time()
occasional_status_call = self._clock.looping_call(
_occasional_report,
TaskScheduler.OCCASIONAL_REPORT_INTERVAL_MS,
log_context,
start_time,
)
try:
(status, result, error) = await function(task)
except Exception:
@@ -383,6 +441,13 @@ class TaskScheduler:
)
self._running_tasks.remove(task.id)
current_time = self._clock.time()
usage = log_context.get_resource_usage()
TaskScheduler._log_task_usage(
status.value, task, usage, current_time - start_time
)
occasional_status_call.stop()
# Try launch a new task since we've finished with this one.
self._clock.call_later(0.1, self._launch_scheduled_tasks)

View File

@@ -18,6 +18,7 @@
# [This file includes modifications made by New Vector Limited]
#
#
import itertools
import os
import shutil
import tempfile
@@ -46,11 +47,11 @@ from synapse.media._base import FileInfo, ThumbnailInfo
from synapse.media.filepath import MediaFilePaths
from synapse.media.media_storage import MediaStorage, ReadableFileWrapper
from synapse.media.storage_provider import FileStorageProviderBackend
from synapse.media.thumbnailer import ThumbnailProvider
from synapse.module_api import ModuleApi
from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_checkers
from synapse.rest import admin
from synapse.rest.client import login
from synapse.rest.media.thumbnail_resource import ThumbnailResource
from synapse.rest.client import login, media
from synapse.server import HomeServer
from synapse.types import JsonDict, RoomAlias
from synapse.util import Clock
@@ -153,68 +154,54 @@ class _TestImage:
is_inline: bool = True
@parameterized_class(
("test_image",),
[
# small png
(
_TestImage(
SMALL_PNG,
b"image/png",
b".png",
unhexlify(
b"89504e470d0a1a0a0000000d4948445200000020000000200806"
b"000000737a7af40000001a49444154789cedc101010000008220"
b"ffaf6e484001000000ef0610200001194334ee0000000049454e"
b"44ae426082"
),
unhexlify(
b"89504e470d0a1a0a0000000d4948445200000001000000010806"
b"0000001f15c4890000000d49444154789c636060606000000005"
b"0001a5f645400000000049454e44ae426082"
),
),
),
# small png with transparency.
(
_TestImage(
unhexlify(
b"89504e470d0a1a0a0000000d49484452000000010000000101000"
b"00000376ef9240000000274524e5300010194fdae0000000a4944"
b"4154789c636800000082008177cd72b60000000049454e44ae426"
b"082"
),
b"image/png",
b".png",
# Note that we don't check the output since it varies across
# different versions of Pillow.
),
),
# small lossless webp
(
_TestImage(
unhexlify(
b"524946461a000000574542505650384c0d0000002f0000001007"
b"1011118888fe0700"
),
b"image/webp",
b".webp",
),
),
# an empty file
(
_TestImage(
b"",
b"image/gif",
b".gif",
expected_found=False,
unable_to_thumbnail=True,
),
),
# An SVG.
(
_TestImage(
b"""<?xml version="1.0"?>
small_png = _TestImage(
SMALL_PNG,
b"image/png",
b".png",
unhexlify(
b"89504e470d0a1a0a0000000d4948445200000020000000200806"
b"000000737a7af40000001a49444154789cedc101010000008220"
b"ffaf6e484001000000ef0610200001194334ee0000000049454e"
b"44ae426082"
),
unhexlify(
b"89504e470d0a1a0a0000000d4948445200000001000000010806"
b"0000001f15c4890000000d49444154789c636060606000000005"
b"0001a5f645400000000049454e44ae426082"
),
)
small_png_with_transparency = _TestImage(
unhexlify(
b"89504e470d0a1a0a0000000d49484452000000010000000101000"
b"00000376ef9240000000274524e5300010194fdae0000000a4944"
b"4154789c636800000082008177cd72b60000000049454e44ae426"
b"082"
),
b"image/png",
b".png",
# Note that we don't check the output since it varies across
# different versions of Pillow.
)
small_lossless_webp = _TestImage(
unhexlify(
b"524946461a000000574542505650384c0d0000002f0000001007" b"1011118888fe0700"
),
b"image/webp",
b".webp",
)
empty_file = _TestImage(
b"",
b"image/gif",
b".gif",
expected_found=False,
unable_to_thumbnail=True,
)
SVG = _TestImage(
b"""<?xml version="1.0"?>
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
@@ -223,19 +210,32 @@ class _TestImage:
<circle cx="100" cy="100" r="50" stroke="black"
stroke-width="5" fill="red" />
</svg>""",
b"image/svg",
b".svg",
expected_found=False,
unable_to_thumbnail=True,
is_inline=False,
),
),
],
b"image/svg",
b".svg",
expected_found=False,
unable_to_thumbnail=True,
is_inline=False,
)
test_images = [
small_png,
small_png_with_transparency,
small_lossless_webp,
empty_file,
SVG,
]
urls = [
"_matrix/media/r0/thumbnail",
"_matrix/client/unstable/org.matrix.msc3916/media/thumbnail",
]
@parameterized_class(("test_image", "url"), itertools.product(test_images, urls))
class MediaRepoTests(unittest.HomeserverTestCase):
servlets = [media.register_servlets]
test_image: ClassVar[_TestImage]
hijack_auth = True
user_id = "@test:user"
url: ClassVar[str]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.fetches: List[
@@ -298,6 +298,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
"config": {"directory": self.storage_path},
}
config["media_storage_providers"] = [provider_config]
config["experimental_features"] = {"msc3916_authenticated_media_enabled": True}
hs = self.setup_test_homeserver(config=config, federation_http_client=client)
@@ -502,7 +503,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
params = "?width=32&height=32&method=scale"
channel = self.make_request(
"GET",
f"/_matrix/media/v3/thumbnail/{self.media_id}{params}",
f"/{self.url}/{self.media_id}{params}",
shorthand=False,
await_result=False,
)
@@ -530,7 +531,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
f"/_matrix/media/v3/thumbnail/{self.media_id}{params}",
f"/{self.url}/{self.media_id}{params}",
shorthand=False,
await_result=False,
)
@@ -566,12 +567,11 @@ class MediaRepoTests(unittest.HomeserverTestCase):
params = "?width=32&height=32&method=" + method
channel = self.make_request(
"GET",
f"/_matrix/media/r0/thumbnail/{self.media_id}{params}",
f"/{self.url}/{self.media_id}{params}",
shorthand=False,
await_result=False,
)
self.pump()
headers = {
b"Content-Length": [b"%d" % (len(self.test_image.data))],
b"Content-Type": [self.test_image.content_type],
@@ -580,7 +580,6 @@ class MediaRepoTests(unittest.HomeserverTestCase):
(self.test_image.data, (len(self.test_image.data), headers))
)
self.pump()
if expected_found:
self.assertEqual(channel.code, 200)
@@ -603,7 +602,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
channel.json_body,
{
"errcode": "M_UNKNOWN",
"error": "Cannot find any thumbnails for the requested media ('/_matrix/media/r0/thumbnail/example.com/12345'). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)",
"error": f"Cannot find any thumbnails for the requested media ('/{self.url}/example.com/12345'). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)",
},
)
else:
@@ -613,7 +612,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
channel.json_body,
{
"errcode": "M_NOT_FOUND",
"error": "Not found '/_matrix/media/r0/thumbnail/example.com/12345'",
"error": f"Not found '/{self.url}/example.com/12345'",
},
)
@@ -625,12 +624,12 @@ class MediaRepoTests(unittest.HomeserverTestCase):
content_type = self.test_image.content_type.decode()
media_repo = self.hs.get_media_repository()
thumbnail_resouce = ThumbnailResource(
thumbnail_provider = ThumbnailProvider(
self.hs, media_repo, media_repo.media_storage
)
self.assertIsNotNone(
thumbnail_resouce._select_thumbnail(
thumbnail_provider._select_thumbnail(
desired_width=desired_size,
desired_height=desired_size,
desired_method=method,

View File

@@ -24,8 +24,8 @@ from twisted.internet.defer import ensureDeferred
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import NotFoundError
from synapse.rest import admin, devices, room, sync
from synapse.rest.client import account, keys, login, register
from synapse.rest import admin, devices, sync
from synapse.rest.client import keys, login, register
from synapse.server import HomeServer
from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
@@ -33,146 +33,6 @@ from synapse.util import Clock
from tests import unittest
class DeviceListsTestCase(unittest.HomeserverTestCase):
"""Tests regarding device list changes."""
servlets = [
admin.register_servlets_for_client_rest_resource,
login.register_servlets,
register.register_servlets,
account.register_servlets,
room.register_servlets,
sync.register_servlets,
devices.register_servlets,
]
def test_receiving_local_device_list_changes(self) -> None:
"""Tests that a local users that share a room receive each other's device list
changes.
"""
# Register two users
test_device_id = "TESTDEVICE"
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
bob_user_id = self.register_user("bob", "ponyponypony")
bob_access_token = self.login(bob_user_id, "ponyponypony")
# Create a room for them to coexist peacefully in
new_room_id = self.helper.create_room_as(
alice_user_id, is_public=True, tok=alice_access_token
)
self.assertIsNotNone(new_room_id)
# Have Bob join the room
self.helper.invite(
new_room_id, alice_user_id, bob_user_id, tok=alice_access_token
)
self.helper.join(new_room_id, bob_user_id, tok=bob_access_token)
# Now have Bob initiate an initial sync (in order to get a since token)
channel = self.make_request(
"GET",
"/sync",
access_token=bob_access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
next_batch_token = channel.json_body["next_batch"]
# ...and then an incremental sync. This should block until the sync stream is woken up,
# which we hope will happen as a result of Alice updating their device list.
bob_sync_channel = self.make_request(
"GET",
f"/sync?since={next_batch_token}&timeout=30000",
access_token=bob_access_token,
# Start the request, then continue on.
await_result=False,
)
# Have alice update their device list
channel = self.make_request(
"PUT",
f"/devices/{test_device_id}",
{
"display_name": "New Device Name",
},
access_token=alice_access_token,
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
# Check that bob's incremental sync contains the updated device list.
# If not, the client would only receive the device list update on the
# *next* sync.
bob_sync_channel.await_result()
self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
"changed", []
)
self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body)
def test_not_receiving_local_device_list_changes(self) -> None:
"""Tests a local users DO NOT receive device updates from each other if they do not
share a room.
"""
# Register two users
test_device_id = "TESTDEVICE"
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
bob_user_id = self.register_user("bob", "ponyponypony")
bob_access_token = self.login(bob_user_id, "ponyponypony")
# These users do not share a room. They are lonely.
# Have Bob initiate an initial sync (in order to get a since token)
channel = self.make_request(
"GET",
"/sync",
access_token=bob_access_token,
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
next_batch_token = channel.json_body["next_batch"]
# ...and then an incremental sync. This should block until the sync stream is woken up,
# which we hope will happen as a result of Alice updating their device list.
bob_sync_channel = self.make_request(
"GET",
f"/sync?since={next_batch_token}&timeout=1000",
access_token=bob_access_token,
# Start the request, then continue on.
await_result=False,
)
# Have alice update their device list
channel = self.make_request(
"PUT",
f"/devices/{test_device_id}",
{
"display_name": "New Device Name",
},
access_token=alice_access_token,
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
# Check that bob's incremental sync does not contain the updated device list.
bob_sync_channel.await_result()
self.assertEqual(
bob_sync_channel.code, HTTPStatus.OK, bob_sync_channel.json_body
)
changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
"changed", []
)
self.assertNotIn(
alice_user_id, changed_device_lists, bob_sync_channel.json_body
)
class DevicesTestCase(unittest.HomeserverTestCase):
servlets = [
admin.register_servlets,

File diff suppressed because it is too large Load Diff

View File

@@ -18,15 +18,39 @@
# [This file includes modifications made by New Vector Limited]
#
#
from parameterized import parameterized_class
from synapse.api.constants import EduTypes
from synapse.rest import admin
from synapse.rest.client import login, sendtodevice, sync
from synapse.types import JsonDict
from tests.unittest import HomeserverTestCase, override_config
@parameterized_class(
("sync_endpoint", "experimental_features"),
[
("/sync", {}),
(
"/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
# Enable sliding sync
{"msc3575_enabled": True},
),
],
)
class SendToDeviceTestCase(HomeserverTestCase):
"""
Test `/sendToDevice` will deliver messages across to people receiving them over `/sync`.
Attributes:
sync_endpoint: The endpoint under test to use for syncing.
experimental_features: The experimental features homeserver config to use.
"""
sync_endpoint: str
experimental_features: JsonDict
servlets = [
admin.register_servlets,
login.register_servlets,
@@ -34,6 +58,11 @@ class SendToDeviceTestCase(HomeserverTestCase):
sync.register_servlets,
]
def default_config(self) -> JsonDict:
config = super().default_config()
config["experimental_features"] = self.experimental_features
return config
def test_user_to_user(self) -> None:
"""A to-device message from one user to another should get delivered"""
@@ -54,7 +83,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
self.assertEqual(chan.code, 200, chan.result)
# check it appears
channel = self.make_request("GET", "/sync", access_token=user2_tok)
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
expected_result = {
"events": [
@@ -67,15 +96,19 @@ class SendToDeviceTestCase(HomeserverTestCase):
}
self.assertEqual(channel.json_body["to_device"], expected_result)
# it should re-appear if we do another sync
channel = self.make_request("GET", "/sync", access_token=user2_tok)
# it should re-appear if we do another sync because the to-device message is not
# deleted until we acknowledge it by sending a `?since=...` parameter in the
# next sync request corresponding to the `next_batch` value from the response.
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body["to_device"], expected_result)
# it should *not* appear if we do an incremental sync
sync_token = channel.json_body["next_batch"]
channel = self.make_request(
"GET", f"/sync?since={sync_token}", access_token=user2_tok
"GET",
f"{self.sync_endpoint}?since={sync_token}",
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body.get("to_device", {}).get("events", []), [])
@@ -99,15 +132,19 @@ class SendToDeviceTestCase(HomeserverTestCase):
)
self.assertEqual(chan.code, 200, chan.result)
# now sync: we should get two of the three
channel = self.make_request("GET", "/sync", access_token=user2_tok)
# now sync: we should get two of the three (because burst_count=2)
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
self.assertEqual(len(msgs), 2)
for i in range(2):
self.assertEqual(
msgs[i],
{"sender": user1, "type": "m.room_key_request", "content": {"idx": i}},
{
"sender": user1,
"type": "m.room_key_request",
"content": {"idx": i},
},
)
sync_token = channel.json_body["next_batch"]
@@ -125,7 +162,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
# ... which should arrive
channel = self.make_request(
"GET", f"/sync?since={sync_token}", access_token=user2_tok
"GET",
f"{self.sync_endpoint}?since={sync_token}",
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
@@ -159,7 +198,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
)
# now sync: we should get two of the three
channel = self.make_request("GET", "/sync", access_token=user2_tok)
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
self.assertEqual(len(msgs), 2)
@@ -193,7 +232,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
# ... which should arrive
channel = self.make_request(
"GET", f"/sync?since={sync_token}", access_token=user2_tok
"GET",
f"{self.sync_endpoint}?since={sync_token}",
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
@@ -217,7 +258,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
user2_tok = self.login("u2", "pass", "d2")
# Do an initial sync
channel = self.make_request("GET", "/sync", access_token=user2_tok)
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
sync_token = channel.json_body["next_batch"]
@@ -233,7 +274,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
self.assertEqual(chan.code, 200, chan.result)
channel = self.make_request(
"GET", f"/sync?since={sync_token}&timeout=300000", access_token=user2_tok
"GET",
f"{self.sync_endpoint}?since={sync_token}&timeout=300000",
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
messages = channel.json_body.get("to_device", {}).get("events", [])
@@ -241,7 +284,9 @@ class SendToDeviceTestCase(HomeserverTestCase):
sync_token = channel.json_body["next_batch"]
channel = self.make_request(
"GET", f"/sync?since={sync_token}&timeout=300000", access_token=user2_tok
"GET",
f"{self.sync_endpoint}?since={sync_token}&timeout=300000",
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
messages = channel.json_body.get("to_device", {}).get("events", [])

View File

@@ -21,7 +21,7 @@
import json
from typing import List
from parameterized import parameterized
from parameterized import parameterized, parameterized_class
from twisted.test.proto_helpers import MemoryReactor
@@ -688,24 +688,180 @@ class SyncCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.json_body)
@parameterized_class(
("sync_endpoint", "experimental_features"),
[
("/sync", {}),
(
"/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
# Enable sliding sync
{"msc3575_enabled": True},
),
],
)
class DeviceListSyncTestCase(unittest.HomeserverTestCase):
"""
Tests regarding device list (`device_lists`) changes.
Attributes:
sync_endpoint: The endpoint under test to use for syncing.
experimental_features: The experimental features homeserver config to use.
"""
sync_endpoint: str
experimental_features: JsonDict
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
room.register_servlets,
sync.register_servlets,
devices.register_servlets,
]
def default_config(self) -> JsonDict:
config = super().default_config()
config["experimental_features"] = self.experimental_features
return config
def test_receiving_local_device_list_changes(self) -> None:
"""Tests that a local users that share a room receive each other's device list
changes.
"""
# Register two users
test_device_id = "TESTDEVICE"
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
bob_user_id = self.register_user("bob", "ponyponypony")
bob_access_token = self.login(bob_user_id, "ponyponypony")
# Create a room for them to coexist peacefully in
new_room_id = self.helper.create_room_as(
alice_user_id, is_public=True, tok=alice_access_token
)
self.assertIsNotNone(new_room_id)
# Have Bob join the room
self.helper.invite(
new_room_id, alice_user_id, bob_user_id, tok=alice_access_token
)
self.helper.join(new_room_id, bob_user_id, tok=bob_access_token)
# Now have Bob initiate an initial sync (in order to get a since token)
channel = self.make_request(
"GET",
self.sync_endpoint,
access_token=bob_access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
next_batch_token = channel.json_body["next_batch"]
# ...and then an incremental sync. This should block until the sync stream is woken up,
# which we hope will happen as a result of Alice updating their device list.
bob_sync_channel = self.make_request(
"GET",
f"{self.sync_endpoint}?since={next_batch_token}&timeout=30000",
access_token=bob_access_token,
# Start the request, then continue on.
await_result=False,
)
# Have alice update their device list
channel = self.make_request(
"PUT",
f"/devices/{test_device_id}",
{
"display_name": "New Device Name",
},
access_token=alice_access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
# Check that bob's incremental sync contains the updated device list.
# If not, the client would only receive the device list update on the
# *next* sync.
bob_sync_channel.await_result()
self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
"changed", []
)
self.assertIn(alice_user_id, changed_device_lists, bob_sync_channel.json_body)
def test_not_receiving_local_device_list_changes(self) -> None:
"""Tests a local users DO NOT receive device updates from each other if they do not
share a room.
"""
# Register two users
test_device_id = "TESTDEVICE"
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
bob_user_id = self.register_user("bob", "ponyponypony")
bob_access_token = self.login(bob_user_id, "ponyponypony")
# These users do not share a room. They are lonely.
# Have Bob initiate an initial sync (in order to get a since token)
channel = self.make_request(
"GET",
self.sync_endpoint,
access_token=bob_access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
next_batch_token = channel.json_body["next_batch"]
# ...and then an incremental sync. This should block until the sync stream is woken up,
# which we hope will happen as a result of Alice updating their device list.
bob_sync_channel = self.make_request(
"GET",
f"{self.sync_endpoint}?since={next_batch_token}&timeout=1000",
access_token=bob_access_token,
# Start the request, then continue on.
await_result=False,
)
# Have alice update their device list
channel = self.make_request(
"PUT",
f"/devices/{test_device_id}",
{
"display_name": "New Device Name",
},
access_token=alice_access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
# Check that bob's incremental sync does not contain the updated device list.
bob_sync_channel.await_result()
self.assertEqual(bob_sync_channel.code, 200, bob_sync_channel.json_body)
changed_device_lists = bob_sync_channel.json_body.get("device_lists", {}).get(
"changed", []
)
self.assertNotIn(
alice_user_id, changed_device_lists, bob_sync_channel.json_body
)
def test_user_with_no_rooms_receives_self_device_list_updates(self) -> None:
"""Tests that a user with no rooms still receives their own device list updates"""
device_id = "TESTDEVICE"
test_device_id = "TESTDEVICE"
# Register a user and login, creating a device
self.user_id = self.register_user("kermit", "monkey")
self.tok = self.login("kermit", "monkey", device_id=device_id)
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
# Request an initial sync
channel = self.make_request("GET", "/sync", access_token=self.tok)
channel = self.make_request(
"GET", self.sync_endpoint, access_token=alice_access_token
)
self.assertEqual(channel.code, 200, channel.json_body)
next_batch = channel.json_body["next_batch"]
@@ -713,19 +869,19 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
# It won't return until something has happened
incremental_sync_channel = self.make_request(
"GET",
f"/sync?since={next_batch}&timeout=30000",
access_token=self.tok,
f"{self.sync_endpoint}?since={next_batch}&timeout=30000",
access_token=alice_access_token,
await_result=False,
)
# Change our device's display name
channel = self.make_request(
"PUT",
f"devices/{device_id}",
f"devices/{test_device_id}",
{
"display_name": "freeze ray",
},
access_token=self.tok,
access_token=alice_access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
@@ -739,7 +895,230 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
).get("changed", [])
self.assertIn(
self.user_id, device_list_changes, incremental_sync_channel.json_body
alice_user_id, device_list_changes, incremental_sync_channel.json_body
)
@parameterized_class(
("sync_endpoint", "experimental_features"),
[
("/sync", {}),
(
"/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
# Enable sliding sync
{"msc3575_enabled": True},
),
],
)
class DeviceOneTimeKeysSyncTestCase(unittest.HomeserverTestCase):
"""
Tests regarding device one time keys (`device_one_time_keys_count`) changes.
Attributes:
sync_endpoint: The endpoint under test to use for syncing.
experimental_features: The experimental features homeserver config to use.
"""
sync_endpoint: str
experimental_features: JsonDict
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
sync.register_servlets,
devices.register_servlets,
]
def default_config(self) -> JsonDict:
config = super().default_config()
config["experimental_features"] = self.experimental_features
return config
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.e2e_keys_handler = hs.get_e2e_keys_handler()
def test_no_device_one_time_keys(self) -> None:
"""
Tests when no one time keys set, it still has the default `signed_curve25519` in
`device_one_time_keys_count`
"""
test_device_id = "TESTDEVICE"
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
# Request an initial sync
channel = self.make_request(
"GET", self.sync_endpoint, access_token=alice_access_token
)
self.assertEqual(channel.code, 200, channel.json_body)
# Check for those one time key counts
self.assertDictEqual(
channel.json_body["device_one_time_keys_count"],
# Note that "signed_curve25519" is always returned in key count responses
# regardless of whether we uploaded any keys for it. This is necessary until
# https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
{"signed_curve25519": 0},
channel.json_body["device_one_time_keys_count"],
)
def test_returns_device_one_time_keys(self) -> None:
"""
Tests that one time keys for the device/user are counted correctly in the `/sync`
response
"""
test_device_id = "TESTDEVICE"
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
# Upload one time keys for the user/device
keys: JsonDict = {
"alg1:k1": "key1",
"alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
"alg2:k3": {"key": "key3"},
}
res = self.get_success(
self.e2e_keys_handler.upload_keys_for_user(
alice_user_id, test_device_id, {"one_time_keys": keys}
)
)
# Note that "signed_curve25519" is always returned in key count responses
# regardless of whether we uploaded any keys for it. This is necessary until
# https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
self.assertDictEqual(
res,
{"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}},
)
# Request an initial sync
channel = self.make_request(
"GET", self.sync_endpoint, access_token=alice_access_token
)
self.assertEqual(channel.code, 200, channel.json_body)
# Check for those one time key counts
self.assertDictEqual(
channel.json_body["device_one_time_keys_count"],
{"alg1": 1, "alg2": 2, "signed_curve25519": 0},
channel.json_body["device_one_time_keys_count"],
)
@parameterized_class(
("sync_endpoint", "experimental_features"),
[
("/sync", {}),
(
"/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
# Enable sliding sync
{"msc3575_enabled": True},
),
],
)
class DeviceUnusedFallbackKeySyncTestCase(unittest.HomeserverTestCase):
"""
Tests regarding device one time keys (`device_unused_fallback_key_types`) changes.
Attributes:
sync_endpoint: The endpoint under test to use for syncing.
experimental_features: The experimental features homeserver config to use.
"""
sync_endpoint: str
experimental_features: JsonDict
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
sync.register_servlets,
devices.register_servlets,
]
def default_config(self) -> JsonDict:
config = super().default_config()
config["experimental_features"] = self.experimental_features
return config
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = self.hs.get_datastores().main
self.e2e_keys_handler = hs.get_e2e_keys_handler()
def test_no_device_unused_fallback_key(self) -> None:
"""
Test when no unused fallback key is set, it just returns an empty list. The MSC
says "The device_unused_fallback_key_types parameter must be present if the
server supports fallback keys.",
https://github.com/matrix-org/matrix-spec-proposals/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md
"""
test_device_id = "TESTDEVICE"
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
# Request an initial sync
channel = self.make_request(
"GET", self.sync_endpoint, access_token=alice_access_token
)
self.assertEqual(channel.code, 200, channel.json_body)
# Check for those one time key counts
self.assertListEqual(
channel.json_body["device_unused_fallback_key_types"],
[],
channel.json_body["device_unused_fallback_key_types"],
)
def test_returns_device_one_time_keys(self) -> None:
"""
Tests that device unused fallback key type is returned correctly in the `/sync`
"""
test_device_id = "TESTDEVICE"
alice_user_id = self.register_user("alice", "correcthorse")
alice_access_token = self.login(
alice_user_id, "correcthorse", device_id=test_device_id
)
# We shouldn't have any unused fallback keys yet
res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(alice_user_id, test_device_id)
)
self.assertEqual(res, [])
# Upload a fallback key for the user/device
fallback_key = {"alg1:k1": "fallback_key1"}
self.get_success(
self.e2e_keys_handler.upload_keys_for_user(
alice_user_id,
test_device_id,
{"fallback_keys": fallback_key},
)
)
# We should now have an unused alg1 key
fallback_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(alice_user_id, test_device_id)
)
self.assertEqual(fallback_res, ["alg1"], fallback_res)
# Request an initial sync
channel = self.make_request(
"GET", self.sync_endpoint, access_token=alice_access_token
)
self.assertEqual(channel.code, 200, channel.json_body)
# Check for the unused fallback key types
self.assertListEqual(
channel.json_body["device_unused_fallback_key_types"],
["alg1"],
channel.json_body["device_unused_fallback_key_types"],
)