Compare commits

...

5 Commits

Author SHA1 Message Date
Hugh Nimmo-Smith
678fa567c9 Make user_type extensible and allow default user_type to be set
- Allows additional user_type values to be set in the config
- Allows a default user_type to be specified that is used when registering users (where a specific user_type hasn't been specified)
2025-06-02 16:27:40 +01:00
Hugh Nimmo-Smith
6a5f1f31e9 Media repository callbacks for the module API to control file upload size
Adds new callbacks for media related functionality:

- get_media_config_for_user
- is_user_allowed_to_upload_media_of_size
2025-06-02 16:27:22 +01:00
Hugh Nimmo-Smith
b7c3fc4ada Add user_may_send_state_event callback to spam checker module API 2025-06-02 16:21:50 +01:00
Hugh Nimmo-Smith
73f57490e7 Pass room_config argument to user_may_create_room spam checker module callback 2025-06-02 16:21:50 +01:00
Hugh Nimmo-Smith
479ef78873 Add ratelimit callbacks to module API to allow dynamic ratelimiting
Adds new callback `get_ratelimit_override_for_user` which is invoked for a small subset of limiter types.
2025-06-02 16:21:38 +01:00
37 changed files with 1200 additions and 86 deletions

View File

@@ -0,0 +1 @@
Add user_may_send_state_event callback to spam checker module API.

View File

@@ -0,0 +1 @@
Support configuration of default and extra user types.

View File

@@ -0,0 +1 @@
Add new module API callbacks that allows overriding of media repository maximum upload size.

View File

@@ -0,0 +1 @@
Add a new module API callback that allows overriding of per user ratelimits.

View File

@@ -0,0 +1 @@
Pass room_config argument to user_may_create_room spam checker module callback.

View File

@@ -49,6 +49,8 @@
- [Background update controller callbacks](modules/background_update_controller_callbacks.md)
- [Account data callbacks](modules/account_data_callbacks.md)
- [Add extra fields to client events unsigned section callbacks](modules/add_extra_fields_to_client_events_unsigned.md)
- [Ratelimit callbacks](modules/ratelimit_callbacks.md)
- [Media repository](modules/media_repository_callbacks.md)
- [Porting a legacy module to the new interface](modules/porting_legacy_module.md)
- [Workers](workers.md)
- [Using `synctl` with Workers](synctl_workers.md)

View File

@@ -163,7 +163,8 @@ Body parameters:
- `locked` - **bool**, optional. If unspecified, locked state will be left unchanged.
- `user_type` - **string** or null, optional. If not provided, the user type will be
not be changed. If `null` is given, the user type will be cleared.
Other allowed options are: `bot` and `support`.
Other allowed options are: `bot` and `support` and any extra values defined in the homserver
[configuration](../usage/configuration/config_documentation.md#user_types).
## List Accounts
### List Accounts (V2)

View File

@@ -0,0 +1,47 @@
# Media repository callbacks
Media repository callbacks allow module developers to customise the behaviour of the
media repository on a per user basis. Media repository callbacks can be registered
using the module API's `register_media_repository_callbacks` method.
The available media repository callbacks are:
### `get_media_config_for_user`
_First introduced in Synapse v1.X.X_
```python
async def get_media_config_for_user(user: str) -> Optional[JsonDict]
```
Called when processing a request from a client for the configuration of the content
repository. The module can return a JSON dictionary that should be returned for the use
or `None` if the module is happy for the default dictionary to be used. The user is
represented by their Matrix user ID (e.g. `@alice:example.com`).
If multiple modules implement this callback, they will be considered in order. If a
callback returns `None`, Synapse falls through to the next one. The value of the first
callback that does not return `None` will be used. If this happens, Synapse will not call
any of the subsequent implementations of this callback.
If no module returns a non-`None` value then the default configuration will be returned.
### `is_user_allowed_to_upload_media_of_size`
_First introduced in Synapse v1.X.X_
```python
async def is_user_allowed_to_upload_media_of_size(user: str, size: int) -> bool
```
Called before media is accepted for upload from a user, in case the module needs to
enforce a different limit for the particular user. The user is represented by their Matrix
user ID. The size is in bytes.
If the module returns `False`, the current request will be denied with the error code
`M_TOO_LARGE` and the HTTP status code 413.
If multiple modules implement this callback, they will be considered in order. If a callback
returns `True`, Synapse falls through to the next one. The value of the first callback that
returns `False` will be used. If this happens, Synapse will not call any of the subsequent
implementations of this callback.

View File

@@ -0,0 +1,33 @@
# Ratelimit callbacks
Ratelimit callbacks allow module developers to override ratelimit settings dynamically whilst
Synapse is running. Ratelimit callbacks can be registered using the module API's
`register_ratelimit_callbacks` method.
The available ratelimit callbacks are:
### `get_ratelimit_override_for_user`
_First introduced in Synapse v1.X.X_
```python
async def get_ratelimit_override_for_user(user: str, limiter_name: str) -> Optional[RatelimitOverride]
```
Called when constructing a ratelimiter of a particular type for a user. The module can
return a `messages_per_second` and `burst_count` to be used, or `None` if no
the default settings are adequate. The user is represented by their Matrix user ID
(e.g. `@alice:example.com`). The limiter name is usually taken from the `RatelimitSettings` key
value.
The limiters that are currently supported are:
- `rc_invites.per_room`
- `rc_invites.per_user`
- `rc_invites.per_issuer`
If multiple modules implement this callback, they will be considered in order. If a
callback returns `None`, Synapse falls through to the next one. The value of the first
callback that does not return `None` will be used. If this happens, Synapse will not call
any of the subsequent implementations of this callback. If no module returns a non-`None` value
then the default settings will be used.

View File

@@ -159,12 +159,19 @@ _First introduced in Synapse v1.37.0_
_Changed in Synapse v1.62.0: `synapse.module_api.NOT_SPAM` and `synapse.module_api.errors.Codes` can be returned by this callback. Returning a boolean is now deprecated._
_Changed in Synapse v1.x.x: Added the `room_config` argument. Callbacks that only expect a single `user_id` argument are still supported._
```python
async def user_may_create_room(user_id: str) -> Union["synapse.module_api.NOT_SPAM", "synapse.module_api.errors.Codes", bool]
async def user_may_create_room(user_id: str, room_config: synapse.module_api.JsonDict) -> Union["synapse.module_api.NOT_SPAM", "synapse.module_api.errors.Codes", bool]
```
Called when processing a room creation request.
The arguments passed to this callback are:
* `user_id`: The Matrix user ID of the user (e.g. `@alice:example.com`).
* `room_config`: The contents of the body of a [/createRoom request](https://spec.matrix.org/latest/client-server-api/#post_matrixclientv3createroom) as a dictionary.
The callback must return one of:
- `synapse.module_api.NOT_SPAM`, to allow the operation. Other callbacks may still
decide to reject it.
@@ -239,6 +246,36 @@ be used. If this happens, Synapse will not call any of the subsequent implementa
this callback.
### `user_may_send_state_event`
_First introduced in Synapse vX.X.X_
```python
async def user_may_send_state_event(user_id: str, room_id: str, event_type: str, state_key: str, content: JsonDict) -> Union["synapse.module_api.NOT_SPAM", "synapse.module_api.errors.Codes"]
```
Called when processing a request to [send state events](https://spec.matrix.org/latest/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey) to a room.
The arguments passed to this callback are:
* `user_id`: The Matrix user ID of the user (e.g. `@alice:example.com`) sending the state event.
* `room_id`: The ID of the room that the requested state event is being sent to.
* `event_type`: The requested type of event.
* `state_key`: The requested state key.
* `content`: The requested event contents.
The callback must return one of:
- `synapse.module_api.NOT_SPAM`, to allow the operation. Other callbacks may still
decide to reject it.
- `synapse.module_api.errors.Codes` to reject the operation with an error code. In case
of doubt, `synapse.module_api.errors.Codes.FORBIDDEN` is a good error code.
If multiple modules implement this callback, they will be considered in order. If a
callback returns `synapse.module_api.NOT_SPAM`, Synapse falls through to the next one.
The value of the first callback that does not return `synapse.module_api.NOT_SPAM` will
be used. If this happens, Synapse will not call any of the subsequent implementations of
this callback.
### `check_username_for_spam`

View File

@@ -63,7 +63,7 @@ class ExampleSpamChecker:
async def user_may_invite(self, inviter_userid, invitee_userid, room_id):
return True # allow all invites
async def user_may_create_room(self, userid):
async def user_may_create_room(self, userid, room_config):
return True # allow all room creations
async def user_may_create_room_alias(self, userid, room_alias):

View File

@@ -834,6 +834,24 @@ Example configuration:
```yaml
max_event_delay_duration: 24h
```
---
### `user_types`
Configuration settings related to the user types feature.
This setting has the following sub-options:
* `default_user_type`: The default user type to use for registering new users when no value has been specified.
Defaults to none.
* `extra_user_types`: Array of additional user types to allow. These are treated as real users. Defaults to [].
Example configuration:
```yaml
user_types:
default_user_type: "custom"
extra_user_types:
- "custom"
- "custom2"
```
## Homeserver blocking
Useful options for Synapse admins.

View File

@@ -20,7 +20,7 @@
#
#
from typing import Dict, Hashable, Optional, Tuple
from typing import TYPE_CHECKING, Dict, Hashable, Optional, Tuple
from synapse.api.errors import LimitExceededError
from synapse.config.ratelimiting import RatelimitSettings
@@ -28,6 +28,12 @@ from synapse.storage.databases.main import DataStore
from synapse.types import Requester
from synapse.util import Clock
if TYPE_CHECKING:
# To avoid circular imports:
from synapse.module_api.callbacks.ratelimit_callbacks import (
RatelimitModuleApiCallbacks,
)
class Ratelimiter:
"""
@@ -72,12 +78,14 @@ class Ratelimiter:
store: DataStore,
clock: Clock,
cfg: RatelimitSettings,
ratelimit_callbacks: Optional["RatelimitModuleApiCallbacks"] = None,
):
self.clock = clock
self.rate_hz = cfg.per_second
self.burst_count = cfg.burst_count
self.store = store
self._limiter_name = cfg.key
self._ratelimit_callbacks = ratelimit_callbacks
# A dictionary representing the token buckets tracked by this rate
# limiter. Each entry maps a key of arbitrary type to a tuple representing:
@@ -165,6 +173,20 @@ class Ratelimiter:
if override and not override.messages_per_second:
return True, -1.0
if requester and self._ratelimit_callbacks:
# Check if the user has a custom rate limit for this specific limiter
# as returned by the module API.
module_override = (
await self._ratelimit_callbacks.get_ratelimit_override_for_user(
requester.user.to_string(),
self._limiter_name,
)
)
if module_override:
rate_hz = module_override.messages_per_second
burst_count = module_override.burst_count
# Override default values if set
time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
rate_hz = rate_hz if rate_hz is not None else self.rate_hz

View File

@@ -59,6 +59,7 @@ from synapse.config import ( # noqa: F401
tls,
tracer,
user_directory,
user_types,
voip,
workers,
)
@@ -122,6 +123,7 @@ class RootConfig:
retention: retention.RetentionConfig
background_updates: background_updates.BackgroundUpdateConfig
auto_accept_invites: auto_accept_invites.AutoAcceptInvitesConfig
user_types: user_types.UserTypesConfig
config_classes: List[Type["Config"]] = ...
config_files: List[str]

View File

@@ -59,6 +59,7 @@ from .third_party_event_rules import ThirdPartyRulesConfig
from .tls import TlsConfig
from .tracer import TracerConfig
from .user_directory import UserDirectoryConfig
from .user_types import UserTypesConfig
from .voip import VoipConfig
from .workers import WorkerConfig
@@ -107,4 +108,5 @@ class HomeServerConfig(RootConfig):
ExperimentalConfig,
BackgroundUpdateConfig,
AutoAcceptInvitesConfig,
UserTypesConfig,
]

View File

@@ -0,0 +1,44 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2025 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
from typing import Any, List, Optional
from synapse.api.constants import UserTypes
from synapse.types import JsonDict
from ._base import Config, ConfigError
class UserTypesConfig(Config):
section = "user_types"
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
user_types: JsonDict = config.get("user_types", {})
self.default_user_type: Optional[str] = user_types.get(
"default_user_type", None
)
self.extra_user_types: List[str] = user_types.get("extra_user_types", [])
all_user_types: List[str] = []
all_user_types.extend(UserTypes.ALL_USER_TYPES)
all_user_types.extend(self.extra_user_types)
self.all_user_types = all_user_types
if self.default_user_type is not None:
if self.default_user_type not in all_user_types:
raise ConfigError(
f"Default user type {self.default_user_type} is not in the list of all user types: {all_user_types}"
)

View File

@@ -115,6 +115,7 @@ class RegistrationHandler:
self._user_consent_version = self.hs.config.consent.user_consent_version
self._server_notices_mxid = hs.config.servernotices.server_notices_mxid
self._server_name = hs.hostname
self._user_types_config = hs.config.user_types
self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
@@ -306,6 +307,9 @@ class RegistrationHandler:
elif default_display_name is None:
default_display_name = localpart
if user_type is None:
user_type = self._user_types_config.default_user_type
await self.register_with_store(
user_id=user_id,
password_hash=password_hash,

View File

@@ -468,17 +468,6 @@ class RoomCreationHandler:
"""
user_id = requester.user.to_string()
spam_check = await self._spam_checker_module_callbacks.user_may_create_room(
user_id
)
if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
raise SynapseError(
403,
"You are not permitted to create rooms",
errcode=spam_check[0],
additional_fields=spam_check[1],
)
creation_content: JsonDict = {
"room_version": new_room_version.identifier,
"predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id},
@@ -585,6 +574,24 @@ class RoomCreationHandler:
if current_power_level_int < needed_power_level:
user_power_levels[user_id] = needed_power_level
# We construct what the body of a call to /createRoom would look like for passing
# to the spam checker. We don't include a preset here, as we expect the
# initial state to contain everything we need.
spam_check = await self._spam_checker_module_callbacks.user_may_create_room(
user_id,
{
"creation_content": creation_content,
"initial_state": list(initial_state.items()),
},
)
if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
raise SynapseError(
403,
"You are not permitted to create rooms",
errcode=spam_check[0],
additional_fields=spam_check[1],
)
await self._send_events_for_new_room(
requester,
new_room_id,
@@ -786,7 +793,7 @@ class RoomCreationHandler:
if not is_requester_admin:
spam_check = await self._spam_checker_module_callbacks.user_may_create_room(
user_id
user_id, config
)
if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
raise SynapseError(

View File

@@ -158,6 +158,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
store=self.store,
clock=self.clock,
cfg=hs.config.ratelimiting.rc_invites_per_room,
ratelimit_callbacks=hs.get_module_api_callbacks().ratelimit,
)
# Ratelimiter for invites, keyed by recipient (across all rooms, all
@@ -166,6 +167,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
store=self.store,
clock=self.clock,
cfg=hs.config.ratelimiting.rc_invites_per_user,
ratelimit_callbacks=hs.get_module_api_callbacks().ratelimit,
)
# Ratelimiter for invites, keyed by issuer (across all rooms, all
@@ -174,6 +176,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
store=self.store,
clock=self.clock,
cfg=hs.config.ratelimiting.rc_invites_per_issuer,
ratelimit_callbacks=hs.get_module_api_callbacks().ratelimit,
)
self._third_party_invite_limiter = Ratelimiter(

View File

@@ -90,6 +90,13 @@ from synapse.module_api.callbacks.account_validity_callbacks import (
ON_USER_LOGIN_CALLBACK,
ON_USER_REGISTRATION_CALLBACK,
)
from synapse.module_api.callbacks.media_repository_callbacks import (
GET_MEDIA_CONFIG_FOR_USER_CALLBACK,
IS_USER_ALLOWED_TO_UPLOAD_MEDIA_OF_SIZE_CALLBACK,
)
from synapse.module_api.callbacks.ratelimit_callbacks import (
GET_RATELIMIT_OVERRIDE_FOR_USER_CALLBACK,
)
from synapse.module_api.callbacks.spamchecker_callbacks import (
CHECK_EVENT_FOR_SPAM_CALLBACK,
CHECK_LOGIN_FOR_SPAM_CALLBACK,
@@ -103,6 +110,7 @@ from synapse.module_api.callbacks.spamchecker_callbacks import (
USER_MAY_JOIN_ROOM_CALLBACK,
USER_MAY_PUBLISH_ROOM_CALLBACK,
USER_MAY_SEND_3PID_INVITE_CALLBACK,
USER_MAY_SEND_STATE_EVENT_CALLBACK,
SpamCheckerModuleApiCallbacks,
)
from synapse.module_api.callbacks.third_party_event_rules_callbacks import (
@@ -311,6 +319,7 @@ class ModuleApi:
USER_MAY_CREATE_ROOM_ALIAS_CALLBACK
] = None,
user_may_publish_room: Optional[USER_MAY_PUBLISH_ROOM_CALLBACK] = None,
user_may_send_state_event: Optional[USER_MAY_SEND_STATE_EVENT_CALLBACK] = None,
check_username_for_spam: Optional[CHECK_USERNAME_FOR_SPAM_CALLBACK] = None,
check_registration_for_spam: Optional[
CHECK_REGISTRATION_FOR_SPAM_CALLBACK
@@ -335,6 +344,7 @@ class ModuleApi:
check_registration_for_spam=check_registration_for_spam,
check_media_file_for_spam=check_media_file_for_spam,
check_login_for_spam=check_login_for_spam,
user_may_send_state_event=user_may_send_state_event,
)
def register_account_validity_callbacks(
@@ -360,6 +370,36 @@ class ModuleApi:
on_legacy_admin_request=on_legacy_admin_request,
)
def register_ratelimit_callbacks(
self,
*,
get_ratelimit_override_for_user: Optional[
GET_RATELIMIT_OVERRIDE_FOR_USER_CALLBACK
] = None,
) -> None:
"""Registers callbacks for ratelimit capabilities.
Added in Synapse v1.x.x.
"""
return self._callbacks.ratelimit.register_callbacks(
get_ratelimit_override_for_user=get_ratelimit_override_for_user,
)
def register_media_repository_callbacks(
self,
*,
get_media_config_for_user: Optional[GET_MEDIA_CONFIG_FOR_USER_CALLBACK] = None,
is_user_allowed_to_upload_media_of_size: Optional[
IS_USER_ALLOWED_TO_UPLOAD_MEDIA_OF_SIZE_CALLBACK
] = None,
) -> None:
"""Registers callbacks for media repository capabilities.
Added in Synapse v1.x.x.
"""
return self._callbacks.media_repository.register_callbacks(
get_media_config_for_user=get_media_config_for_user,
is_user_allowed_to_upload_media_of_size=is_user_allowed_to_upload_media_of_size,
)
def register_third_party_rules_callbacks(
self,
*,

View File

@@ -27,6 +27,12 @@ if TYPE_CHECKING:
from synapse.module_api.callbacks.account_validity_callbacks import (
AccountValidityModuleApiCallbacks,
)
from synapse.module_api.callbacks.media_repository_callbacks import (
MediaRepositoryModuleApiCallbacks,
)
from synapse.module_api.callbacks.ratelimit_callbacks import (
RatelimitModuleApiCallbacks,
)
from synapse.module_api.callbacks.spamchecker_callbacks import (
SpamCheckerModuleApiCallbacks,
)
@@ -38,5 +44,7 @@ from synapse.module_api.callbacks.third_party_event_rules_callbacks import (
class ModuleApiCallbacks:
def __init__(self, hs: "HomeServer") -> None:
self.account_validity = AccountValidityModuleApiCallbacks()
self.ratelimit = RatelimitModuleApiCallbacks(hs)
self.media_repository = MediaRepositoryModuleApiCallbacks(hs)
self.spam_checker = SpamCheckerModuleApiCallbacks(hs)
self.third_party_event_rules = ThirdPartyEventRulesModuleApiCallbacks(hs)

View File

@@ -0,0 +1,76 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2025 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
import logging
from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
from synapse.types import JsonDict
from synapse.util.async_helpers import delay_cancellation
from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
GET_MEDIA_CONFIG_FOR_USER_CALLBACK = Callable[[str], Awaitable[Optional[JsonDict]]]
IS_USER_ALLOWED_TO_UPLOAD_MEDIA_OF_SIZE_CALLBACK = Callable[[str, int], Awaitable[bool]]
class MediaRepositoryModuleApiCallbacks:
def __init__(self, hs: "HomeServer") -> None:
self.clock = hs.get_clock()
self._get_media_config_for_user_callbacks: List[
GET_MEDIA_CONFIG_FOR_USER_CALLBACK
] = []
self._is_user_allowed_to_upload_media_of_size_callbacks: List[
IS_USER_ALLOWED_TO_UPLOAD_MEDIA_OF_SIZE_CALLBACK
] = []
def register_callbacks(
self,
get_media_config_for_user: Optional[GET_MEDIA_CONFIG_FOR_USER_CALLBACK] = None,
is_user_allowed_to_upload_media_of_size: Optional[
IS_USER_ALLOWED_TO_UPLOAD_MEDIA_OF_SIZE_CALLBACK
] = None,
) -> None:
"""Register callbacks from module for each hook."""
if get_media_config_for_user is not None:
self._get_media_config_for_user_callbacks.append(get_media_config_for_user)
if is_user_allowed_to_upload_media_of_size is not None:
self._is_user_allowed_to_upload_media_of_size_callbacks.append(
is_user_allowed_to_upload_media_of_size
)
async def get_media_config_for_user(self, user_id: str) -> Optional[JsonDict]:
for callback in self._get_media_config_for_user_callbacks:
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
res: Optional[JsonDict] = await delay_cancellation(callback(user_id))
if res:
return res
return None
async def is_user_allowed_to_upload_media_of_size(
self, user_id: str, size: int
) -> bool:
for callback in self._is_user_allowed_to_upload_media_of_size_callbacks:
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
res: bool = await delay_cancellation(callback(user_id, size))
if not res:
return res
return True

View File

@@ -0,0 +1,62 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2025 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
import logging
from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
from synapse.storage.databases.main.room import RatelimitOverride
from synapse.util.async_helpers import delay_cancellation
from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
GET_RATELIMIT_OVERRIDE_FOR_USER_CALLBACK = Callable[
[str, str], Awaitable[Optional[RatelimitOverride]]
]
class RatelimitModuleApiCallbacks:
def __init__(self, hs: "HomeServer") -> None:
self.clock = hs.get_clock()
self._get_ratelimit_override_for_user_callbacks: List[
GET_RATELIMIT_OVERRIDE_FOR_USER_CALLBACK
] = []
def register_callbacks(
self,
get_ratelimit_override_for_user: Optional[
GET_RATELIMIT_OVERRIDE_FOR_USER_CALLBACK
] = None,
) -> None:
"""Register callbacks from module for each hook."""
if get_ratelimit_override_for_user is not None:
self._get_ratelimit_override_for_user_callbacks.append(
get_ratelimit_override_for_user
)
async def get_ratelimit_override_for_user(
self, user_id: str, limiter_name: str
) -> Optional[RatelimitOverride]:
for callback in self._get_ratelimit_override_for_user_callbacks:
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
res: Optional[RatelimitOverride] = await delay_cancellation(
callback(user_id, limiter_name)
)
if res:
return res
return None

View File

@@ -120,20 +120,24 @@ USER_MAY_SEND_3PID_INVITE_CALLBACK = Callable[
]
],
]
USER_MAY_CREATE_ROOM_CALLBACK = Callable[
[str],
Awaitable[
Union[
Literal["NOT_SPAM"],
Codes,
# Highly experimental, not officially part of the spamchecker API, may
# disappear without warning depending on the results of ongoing
# experiments.
# Use this to return additional information as part of an error.
Tuple[Codes, JsonDict],
# Deprecated
bool,
]
USER_MAY_CREATE_ROOM_CALLBACK_RETURN_VALUE = Union[
Literal["NOT_SPAM"],
Codes,
# Highly experimental, not officially part of the spamchecker API, may
# disappear without warning depending on the results of ongoing
# experiments.
# Use this to return additional information as part of an error.
Tuple[Codes, JsonDict],
# Deprecated
bool,
]
USER_MAY_CREATE_ROOM_CALLBACK = Union[
Callable[
[str, JsonDict],
Awaitable[USER_MAY_CREATE_ROOM_CALLBACK_RETURN_VALUE],
],
Callable[ # Single argument variant for backwards compatibility
[str], Awaitable[USER_MAY_CREATE_ROOM_CALLBACK_RETURN_VALUE]
],
]
USER_MAY_CREATE_ROOM_ALIAS_CALLBACK = Callable[
@@ -168,6 +172,20 @@ USER_MAY_PUBLISH_ROOM_CALLBACK = Callable[
]
],
]
USER_MAY_SEND_STATE_EVENT_CALLBACK = Callable[
[str, str, str, str, JsonDict],
Awaitable[
Union[
Literal["NOT_SPAM"],
Codes,
# Highly experimental, not officially part of the spamchecker API, may
# disappear without warning depending on the results of ongoing
# experiments.
# Use this to return additional information as part of an error.
Tuple[Codes, JsonDict],
]
],
]
CHECK_USERNAME_FOR_SPAM_CALLBACK = Union[
Callable[[UserProfile], Awaitable[bool]],
Callable[[UserProfile, str], Awaitable[bool]],
@@ -332,6 +350,9 @@ class SpamCheckerModuleApiCallbacks:
USER_MAY_SEND_3PID_INVITE_CALLBACK
] = []
self._user_may_create_room_callbacks: List[USER_MAY_CREATE_ROOM_CALLBACK] = []
self._user_may_send_state_event_callbacks: List[
USER_MAY_SEND_STATE_EVENT_CALLBACK
] = []
self._user_may_create_room_alias_callbacks: List[
USER_MAY_CREATE_ROOM_ALIAS_CALLBACK
] = []
@@ -367,6 +388,7 @@ class SpamCheckerModuleApiCallbacks:
] = None,
check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
check_login_for_spam: Optional[CHECK_LOGIN_FOR_SPAM_CALLBACK] = None,
user_may_send_state_event: Optional[USER_MAY_SEND_STATE_EVENT_CALLBACK] = None,
) -> None:
"""Register callbacks from module for each hook."""
if check_event_for_spam is not None:
@@ -391,6 +413,11 @@ class SpamCheckerModuleApiCallbacks:
if user_may_create_room is not None:
self._user_may_create_room_callbacks.append(user_may_create_room)
if user_may_send_state_event is not None:
self._user_may_send_state_event_callbacks.append(
user_may_send_state_event,
)
if user_may_create_room_alias is not None:
self._user_may_create_room_alias_callbacks.append(
user_may_create_room_alias,
@@ -622,16 +649,41 @@ class SpamCheckerModuleApiCallbacks:
return self.NOT_SPAM
async def user_may_create_room(
self, userid: str
self, userid: str, room_config: JsonDict
) -> Union[Tuple[Codes, dict], Literal["NOT_SPAM"]]:
"""Checks if a given user may create a room
Args:
userid: The ID of the user attempting to create a room
room_config: The room creation configuration which is the body of the /createRoom request
"""
for callback in self._user_may_create_room_callbacks:
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
res = await delay_cancellation(callback(userid))
checker_args = inspect.signature(callback)
# Also ensure backwards compatibility with spam checker callbacks
# that don't expect the room_config argument.
if len(checker_args.parameters) == 2:
callback_with_requester_id = cast(
Callable[
[str, JsonDict],
Awaitable[USER_MAY_CREATE_ROOM_CALLBACK_RETURN_VALUE],
],
callback,
)
# We make a copy of the config to ensure the spam checker cannot modify it.
res = await delay_cancellation(
callback_with_requester_id(userid, room_config.copy())
)
else:
callback_without_requester_id = cast(
Callable[
[str], Awaitable[USER_MAY_CREATE_ROOM_CALLBACK_RETURN_VALUE]
],
callback,
)
res = await delay_cancellation(
callback_without_requester_id(userid)
)
if res is True or res is self.NOT_SPAM:
continue
elif res is False:
@@ -653,6 +705,37 @@ class SpamCheckerModuleApiCallbacks:
return self.NOT_SPAM
async def user_may_send_state_event(
self,
userid: str,
room_id: str,
event_type: str,
state_key: str,
content: JsonDict,
) -> Union[Tuple[Codes, dict], Literal["NOT_SPAM"]]:
"""Checks if a given user may create a room with a given visibility
Args:
userid: The ID of the user attempting to create a room
visibility: The visibility of the room to be created
"""
for callback in self._user_may_send_state_event_callbacks:
with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
# We make a copy of the content to ensure that the spam checker cannot modify it.
res = await delay_cancellation(
callback(userid, room_id, event_type, state_key, content.copy())
)
if res is self.NOT_SPAM:
continue
elif isinstance(res, synapse.api.errors.Codes):
return res, {}
else:
logger.warning(
"Module returned invalid value, rejecting room creation as spam"
)
return synapse.api.errors.Codes.FORBIDDEN, {}
return self.NOT_SPAM
async def user_may_create_room_alias(
self, userid: str, room_alias: RoomAlias
) -> Union[Tuple[Codes, dict], Literal["NOT_SPAM"]]:

View File

@@ -28,7 +28,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import attr
from synapse._pydantic_compat import StrictBool, StrictInt, StrictStr
from synapse.api.constants import Direction, UserTypes
from synapse.api.constants import Direction
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import (
RestServlet,
@@ -230,6 +230,7 @@ class UserRestServletV2(RestServlet):
self.registration_handler = hs.get_registration_handler()
self.pusher_pool = hs.get_pusherpool()
self._msc3866_enabled = hs.config.experimental.msc3866.enabled
self._all_user_types = hs.config.user_types.all_user_types
async def on_GET(
self, request: SynapseRequest, user_id: str
@@ -277,7 +278,7 @@ class UserRestServletV2(RestServlet):
assert_params_in_dict(external_id, ["auth_provider", "external_id"])
user_type = body.get("user_type", None)
if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
if user_type is not None and user_type not in self._all_user_types:
raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid user type")
set_admin_to = body.get("admin", False)
@@ -524,6 +525,7 @@ class UserRegisterServlet(RestServlet):
self.reactor = hs.get_reactor()
self.nonces: Dict[str, int] = {}
self.hs = hs
self._all_user_types = hs.config.user_types.all_user_types
def _clear_old_nonces(self) -> None:
"""
@@ -605,7 +607,7 @@ class UserRegisterServlet(RestServlet):
user_type = body.get("user_type", None)
displayname = body.get("displayname", None)
if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
if user_type is not None and user_type not in self._all_user_types:
raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid user type")
if "mac" not in body:

View File

@@ -102,10 +102,17 @@ class MediaConfigResource(RestServlet):
self.clock = hs.get_clock()
self.auth = hs.get_auth()
self.limits_dict = {"m.upload.size": config.media.max_upload_size}
self.media_repository_callbacks = hs.get_module_api_callbacks().media_repository
async def on_GET(self, request: SynapseRequest) -> None:
await self.auth.get_user_by_req(request)
respond_with_json(request, 200, self.limits_dict, send_cors=True)
requester = await self.auth.get_user_by_req(request)
user_specific_config = (
await self.media_repository_callbacks.get_media_config_for_user(
requester.user.to_string(),
)
)
response = user_specific_config if user_specific_config else self.limits_dict
respond_with_json(request, 200, response, send_cors=True)
class ThumbnailResource(RestServlet):

View File

@@ -198,6 +198,7 @@ class RoomStateEventRestServlet(RestServlet):
self.delayed_events_handler = hs.get_delayed_events_handler()
self.auth = hs.get_auth()
self._max_event_delay_ms = hs.config.server.max_event_delay_ms
self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
def register(self, http_server: HttpServer) -> None:
# /rooms/$roomid/state/$eventtype
@@ -289,6 +290,25 @@ class RoomStateEventRestServlet(RestServlet):
content = parse_json_object_from_request(request)
is_requester_admin = await self.auth.is_server_admin(requester)
if not is_requester_admin:
spam_check = (
await self._spam_checker_module_callbacks.user_may_send_state_event(
userid=requester.user.to_string(),
room_id=room_id,
event_type=event_type,
state_key=state_key,
content=content,
)
)
if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
raise SynapseError(
403,
"You are not permitted to send the state event",
errcode=spam_check[0],
additional_fields=spam_check[1],
)
origin_server_ts = None
if requester.app_service:
origin_server_ts = parse_integer(request, "ts")

View File

@@ -40,7 +40,14 @@ class MediaConfigResource(RestServlet):
self.clock = hs.get_clock()
self.auth = hs.get_auth()
self.limits_dict = {"m.upload.size": config.media.max_upload_size}
self.media_repository_callbacks = hs.get_module_api_callbacks().media_repository
async def on_GET(self, request: SynapseRequest) -> None:
await self.auth.get_user_by_req(request)
respond_with_json(request, 200, self.limits_dict, send_cors=True)
requester = await self.auth.get_user_by_req(request)
user_specific_config = (
await self.media_repository_callbacks.get_media_config_for_user(
requester.user.to_string()
)
)
response = user_specific_config if user_specific_config else self.limits_dict
respond_with_json(request, 200, response, send_cors=True)

View File

@@ -50,9 +50,12 @@ class BaseUploadServlet(RestServlet):
self.server_name = hs.hostname
self.auth = hs.get_auth()
self.max_upload_size = hs.config.media.max_upload_size
self._media_repository_callbacks = (
hs.get_module_api_callbacks().media_repository
)
def _get_file_metadata(
self, request: SynapseRequest
async def _get_file_metadata(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, Optional[str], str]:
raw_content_length = request.getHeader("Content-Length")
if raw_content_length is None:
@@ -67,7 +70,14 @@ class BaseUploadServlet(RestServlet):
code=413,
errcode=Codes.TOO_LARGE,
)
if not await self._media_repository_callbacks.is_user_allowed_to_upload_media_of_size(
user_id, content_length
):
raise SynapseError(
msg="Upload request body is too large",
code=413,
errcode=Codes.TOO_LARGE,
)
args: Dict[bytes, List[bytes]] = request.args # type: ignore
upload_name_bytes = parse_bytes_from_args(args, "filename")
if upload_name_bytes:
@@ -104,7 +114,9 @@ class UploadServlet(BaseUploadServlet):
async def on_POST(self, request: SynapseRequest) -> None:
requester = await self.auth.get_user_by_req(request)
content_length, upload_name, media_type = self._get_file_metadata(request)
content_length, upload_name, media_type = await self._get_file_metadata(
request, requester.user.to_string()
)
try:
content: IO = request.content # type: ignore
@@ -152,7 +164,9 @@ class AsyncUploadServlet(BaseUploadServlet):
async with lock:
await self.media_repo.verify_can_upload(media_id, requester.user)
content_length, upload_name, media_type = self._get_file_metadata(request)
content_length, upload_name, media_type = await self._get_file_metadata(
request, requester.user.to_string()
)
try:
content: IO = request.content # type: ignore

View File

@@ -583,7 +583,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn)
async def set_user_type(self, user: UserID, user_type: Optional[UserTypes]) -> None:
async def set_user_type(
self, user: UserID, user_type: Optional[Union[UserTypes, str]]
) -> None:
"""Sets the user type.
Args:
@@ -683,7 +685,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
retcol="user_type",
allow_none=True,
)
return res is None
return res is None or res not in [UserTypes.BOT, UserTypes.SUPPORT]
def is_support_user_txn(self, txn: LoggingTransaction, user_id: str) -> bool:
res = self.db_pool.simple_select_one_onecol_txn(
@@ -959,10 +961,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return await self.db_pool.runInteraction("count_users", _count_users)
async def count_real_users(self) -> int:
"""Counts all users without a special user_type registered on the homeserver."""
"""Counts all users without the bot or support user_types registered on the homeserver."""
def _count_users(txn: LoggingTransaction) -> int:
txn.execute("SELECT COUNT(*) FROM users where user_type is null")
txn.execute(
f"SELECT COUNT(*) FROM users WHERE user_type IS NULL OR user_type NOT IN ('{UserTypes.BOT}', '{UserTypes.SUPPORT}')"
)
row = txn.fetchone()
assert row is not None
return row[0]
@@ -2545,7 +2549,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
the user, setting their displayname to the given value
admin: is an admin user?
user_type: type of user. One of the values from api.constants.UserTypes,
or None for a normal user.
a custom value set in the configuration file, or None for a normal
user.
shadow_banned: Whether the user is shadow-banned, i.e. they may be
told their requests succeeded but we ignore them.
approved: Whether to consider the user has already been approved by an

View File

@@ -77,7 +77,7 @@ logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class RatelimitOverride:
messages_per_second: int
messages_per_second: float
burst_count: int

View File

@@ -1,6 +1,10 @@
from typing import Optional
from synapse.api.ratelimiting import LimitExceededError, Ratelimiter
from synapse.appservice import ApplicationService
from synapse.config.ratelimiting import RatelimitSettings
from synapse.module_api.callbacks.ratelimit_callbacks import RatelimitModuleApiCallbacks
from synapse.storage.databases.main.room import RatelimitOverride
from synapse.types import create_requester
from tests import unittest
@@ -440,3 +444,49 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter.can_do_action(requester=None, key="a", _time_now_s=20.0)
)
self.assertTrue(success)
def test_get_ratelimit_override_for_user_callback(self) -> None:
test_user_id = "@user:test"
test_limiter_name = "name"
callbacks = RatelimitModuleApiCallbacks(self.hs)
requester = create_requester(test_user_id)
limiter = Ratelimiter(
store=self.hs.get_datastores().main,
clock=self.clock,
cfg=RatelimitSettings(
test_limiter_name,
per_second=0.1,
burst_count=3,
),
ratelimit_callbacks=callbacks,
)
# Observe four actions, exceeding the burst_count.
limiter.record_action(requester=requester, n_actions=4, _time_now_s=0.0)
# We should be prevented from taking a new action now.
success, _ = self.get_success_or_raise(
limiter.can_do_action(requester=requester, _time_now_s=0.0)
)
self.assertFalse(success)
# Now register a callback that overrides the ratelimit for this user
# and limiter name.
async def get_ratelimit_override_for_user(
user_id: str, limiter_name: str
) -> Optional[RatelimitOverride]:
if user_id == test_user_id:
return RatelimitOverride(
messages_per_second=0.1,
burst_count=10,
)
return None
callbacks.register_callbacks(
get_ratelimit_override_for_user=get_ratelimit_override_for_user
)
success, _ = self.get_success_or_raise(
limiter.can_do_action(requester=requester, _time_now_s=0.0)
)
self.assertTrue(success)

View File

@@ -738,6 +738,41 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.handler.register_user(localpart="bobflimflob", auth_provider_id="saml")
)
def test_register_default_user_type(self) -> None:
"""Test that the default user type is none when registering a user."""
user_id = self.get_success(self.handler.register_user(localpart="user"))
user_info = self.get_success(self.store.get_user_by_id(user_id))
assert user_info is not None
self.assertEqual(user_info.user_type, None)
def test_register_extra_user_types_valid(self) -> None:
"""
Test that the specified user type is set correctly when registering a user.
n.b. No validation is done on the user type, so this test
is only to ensure that the user type can be set to any value.
"""
user_id = self.get_success(
self.handler.register_user(localpart="user", user_type="anyvalue")
)
user_info = self.get_success(self.store.get_user_by_id(user_id))
assert user_info is not None
self.assertEqual(user_info.user_type, "anyvalue")
@override_config(
{
"user_types": {
"extra_user_types": ["extra1", "extra2"],
"default_user_type": "extra1",
}
}
)
def test_register_extra_user_types_with_default(self) -> None:
"""Test that the default_user_type in config is set correctly when registering a user."""
user_id = self.get_success(self.handler.register_user(localpart="user"))
user_info = self.get_success(self.store.get_user_by_id(user_id))
assert user_info is not None
self.assertEqual(user_info.user_type, "extra1")
async def get_or_create_user(
self,
requester: Requester,

View File

@@ -1360,3 +1360,42 @@ class MediaHashesTestCase(unittest.HomeserverTestCase):
store_media.sha256,
SMALL_PNG_SHA256,
)
class MediaRepoSizeModuleCallbackTestCase(unittest.HomeserverTestCase):
servlets = [
login.register_servlets,
admin.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user = self.register_user("user", "pass")
self.tok = self.login("user", "pass")
self.mock_result = True # Allow all uploads by default
hs.get_module_api().register_media_repository_callbacks(
is_user_allowed_to_upload_media_of_size=self.is_user_allowed_to_upload_media_of_size,
)
def create_resource_dict(self) -> Dict[str, Resource]:
resources = super().create_resource_dict()
resources["/_matrix/media"] = self.hs.get_media_repository_resource()
return resources
async def is_user_allowed_to_upload_media_of_size(
self, user_id: str, size: int
) -> bool:
self.last_user_id = user_id
self.last_size = size
return self.mock_result
def test_upload_allowed(self) -> None:
self.helper.upload_media(SMALL_PNG, tok=self.tok, expect_code=200)
assert self.last_user_id == self.user
assert self.last_size == len(SMALL_PNG)
def test_upload_not_allowed(self) -> None:
self.mock_result = False
self.helper.upload_media(SMALL_PNG, tok=self.tok, expect_code=413)
assert self.last_user_id == self.user
assert self.last_size == len(SMALL_PNG)

View File

@@ -0,0 +1,243 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2025 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
#
from typing import Literal, Union
from twisted.test.proto_helpers import MemoryReactor
from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.rest import admin, login, room, room_upgrade_rest_servlet
from synapse.server import HomeServer
from synapse.types import Codes, JsonDict
from synapse.util import Clock
from tests.server import FakeChannel
from tests.unittest import HomeserverTestCase
class SpamCheckerTestCase(HomeserverTestCase):
servlets = [
room.register_servlets,
admin.register_servlets,
login.register_servlets,
room_upgrade_rest_servlet.register_servlets,
]
def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self._module_api = homeserver.get_module_api()
self.user_id = self.register_user("user", "password")
self.token = self.login("user", "password")
def create_room(self, content: JsonDict) -> FakeChannel:
channel = self.make_request(
"POST",
"/_matrix/client/r0/createRoom",
content,
access_token=self.token,
)
return channel
def test_may_user_create_room(self) -> None:
"""Test that the may_user_create_room callback is called when a user
creates a room, and that it receives the correct parameters.
"""
async def user_may_create_room(
user_id: str, room_config: JsonDict
) -> Union[Literal["NOT_SPAM"], Codes]:
self.last_room_config = room_config
self.last_user_id = user_id
return "NOT_SPAM"
self._module_api.register_spam_checker_callbacks(
user_may_create_room=user_may_create_room
)
channel = self.create_room({"foo": "baa"})
self.assertEqual(channel.code, 200)
self.assertEqual(self.last_user_id, self.user_id)
self.assertEqual(self.last_room_config["foo"], "baa")
def test_may_user_create_room_on_upgrade(self) -> None:
"""Test that the may_user_create_room callback is called when a room is upgraded."""
# First, create a room to upgrade.
channel = self.create_room({"topic": "foo"})
self.assertEqual(channel.code, 200)
room_id = channel.json_body["room_id"]
async def user_may_create_room(
user_id: str, room_config: JsonDict
) -> Union[Literal["NOT_SPAM"], Codes]:
self.last_room_config = room_config
self.last_user_id = user_id
return "NOT_SPAM"
# Register the callback for spam checking.
self._module_api.register_spam_checker_callbacks(
user_may_create_room=user_may_create_room
)
# Now upgrade the room.
channel = self.make_request(
"POST",
f"/_matrix/client/r0/rooms/{room_id}/upgrade",
# This will upgrade a room to the same version, but that's fine.
content={"new_version": DEFAULT_ROOM_VERSION},
access_token=self.token,
)
# Check that the callback was called and the room was upgraded.
self.assertEqual(channel.code, 200)
self.assertEqual(self.last_user_id, self.user_id)
# Check that the initial state received by callback contains the topic event.
self.assertTrue(
any(
event[0][0] == "m.room.topic" and event[1].get("topic") == "foo"
for event in self.last_room_config["initial_state"]
)
)
def test_may_user_create_room_disallowed(self) -> None:
"""Test that the codes response from may_user_create_room callback is respected
and returned via the API.
"""
async def user_may_create_room(
user_id: str, room_config: JsonDict
) -> Union[Literal["NOT_SPAM"], Codes]:
self.last_room_config = room_config
self.last_user_id = user_id
return Codes.UNAUTHORIZED
self._module_api.register_spam_checker_callbacks(
user_may_create_room=user_may_create_room
)
channel = self.create_room({"foo": "baa"})
self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEqual(self.last_user_id, self.user_id)
self.assertEqual(self.last_room_config["foo"], "baa")
def test_may_user_create_room_compatibility(self) -> None:
"""Test that the may_user_create_room callback is called when a user
creates a room for a module that uses the old callback signature
(without the `room_config` parameter)
"""
async def user_may_create_room(
user_id: str,
) -> Union[Literal["NOT_SPAM"], Codes]:
self.last_user_id = user_id
return "NOT_SPAM"
self._module_api.register_spam_checker_callbacks(
user_may_create_room=user_may_create_room
)
channel = self.create_room({"foo": "baa"})
self.assertEqual(channel.code, 200)
def test_user_may_send_state_event(self) -> None:
"""Test that the user_may_send_state_event callback is called when a state event
is sent, and that it receives the correct parameters.
"""
async def user_may_send_state_event(
user_id: str,
room_id: str,
event_type: str,
state_key: str,
content: JsonDict,
) -> Union[Literal["NOT_SPAM"], Codes]:
self.last_user_id = user_id
self.last_room_id = room_id
self.last_event_type = event_type
self.last_state_key = state_key
self.last_content = content
return "NOT_SPAM"
self._module_api.register_spam_checker_callbacks(
user_may_send_state_event=user_may_send_state_event
)
channel = self.create_room({})
self.assertEqual(channel.code, 200)
room_id = channel.json_body["room_id"]
event_type = "test.event.type"
state_key = "test.state.key"
channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/state/%s/%s"
% (
room_id,
event_type,
state_key,
),
content={"foo": "bar"},
access_token=self.token,
)
self.assertEqual(channel.code, 200)
self.assertEqual(self.last_user_id, self.user_id)
self.assertEqual(self.last_room_id, room_id)
self.assertEqual(self.last_event_type, event_type)
self.assertEqual(self.last_state_key, state_key)
self.assertEqual(self.last_content, {"foo": "bar"})
def test_user_may_send_state_event_disallows(self) -> None:
"""Test that the user_may_send_state_event callback is called when a state event
is sent, and that the response is honoured.
"""
async def user_may_send_state_event(
user_id: str,
room_id: str,
event_type: str,
state_key: str,
content: JsonDict,
) -> Union[Literal["NOT_SPAM"], Codes]:
return Codes.FORBIDDEN
self._module_api.register_spam_checker_callbacks(
user_may_send_state_event=user_may_send_state_event
)
channel = self.create_room({})
self.assertEqual(channel.code, 200)
room_id = channel.json_body["room_id"]
event_type = "test.event.type"
state_key = "test.state.key"
channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/state/%s/%s"
% (
room_id,
event_type,
state_key,
),
content={"foo": "bar"},
access_token=self.token,
)
self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)

View File

@@ -328,6 +328,61 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid user type", channel.json_body["error"])
@override_config(
{
"user_types": {
"extra_user_types": ["extra1", "extra2"],
}
}
)
def test_extra_user_type(self) -> None:
"""
Check that the extra user type can be used when registering a user.
"""
def nonce_mac(user_type: str) -> tuple[str, str]:
"""
Get a nonce and the expected HMAC for that nonce.
"""
channel = self.make_request("GET", self.url)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
want_mac.update(
nonce.encode("ascii")
+ b"\x00alice\x00abc123\x00notadmin\x00"
+ user_type.encode("ascii")
)
want_mac_str = want_mac.hexdigest()
return nonce, want_mac_str
nonce, mac = nonce_mac("extra1")
# Valid user_type
body = {
"nonce": nonce,
"username": "alice",
"password": "abc123",
"user_type": "extra1",
"mac": mac,
}
channel = self.make_request("POST", self.url, body)
self.assertEqual(200, channel.code, msg=channel.json_body)
nonce, mac = nonce_mac("extra3")
# Invalid user_type
body = {
"nonce": nonce,
"username": "alice",
"password": "abc123",
"user_type": "extra3",
"mac": mac,
}
channel = self.make_request("POST", self.url, body)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid user type", channel.json_body["error"])
def test_displayname(self) -> None:
"""
Test that displayname of new user is set
@@ -1186,6 +1241,80 @@ class UsersListTestCase(unittest.HomeserverTestCase):
not_user_types=["custom"],
)
@override_config(
{
"user_types": {
"extra_user_types": ["extra1", "extra2"],
}
}
)
def test_filter_not_user_types_with_extra(self) -> None:
"""Tests that the endpoint handles the not_user_types param when extra_user_types are configured"""
regular_user_id = self.register_user("normalo", "secret")
extra1_user_id = self.register_user("extra1", "secret")
self.make_request(
"PUT",
"/_synapse/admin/v2/users/" + urllib.parse.quote(extra1_user_id),
{"user_type": "extra1"},
access_token=self.admin_user_tok,
)
def test_user_type(
expected_user_ids: List[str], not_user_types: Optional[List[str]] = None
) -> None:
"""Runs a test for the not_user_types param
Args:
expected_user_ids: Ids of the users that are expected to be returned
not_user_types: List of values for the not_user_types param
"""
user_type_query = ""
if not_user_types is not None:
user_type_query = "&".join(
[f"not_user_type={u}" for u in not_user_types]
)
test_url = f"{self.url}?{user_type_query}"
channel = self.make_request(
"GET",
test_url,
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code)
self.assertEqual(channel.json_body["total"], len(expected_user_ids))
self.assertEqual(
expected_user_ids,
[u["name"] for u in channel.json_body["users"]],
)
# Request without user_types → all users expected
test_user_type([self.admin_user, extra1_user_id, regular_user_id])
# Request and exclude extra1 user type
test_user_type(
[self.admin_user, regular_user_id],
not_user_types=["extra1"],
)
# Request and exclude extra1 and extra2 user types
test_user_type(
[self.admin_user, regular_user_id],
not_user_types=["extra1", "extra2"],
)
# Request and exclude empty user types → only expected the extra1 user
test_user_type([extra1_user_id], not_user_types=[""])
# Request and exclude an unregistered type → expect all users
test_user_type(
[self.admin_user, extra1_user_id, regular_user_id],
not_user_types=["extra3"],
)
def test_erasure_status(self) -> None:
# Create a new user.
user_id = self.register_user("eraseme", "eraseme")
@@ -2977,56 +3106,66 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["admin"])
def set_user_type(self, user_type: Optional[str]) -> None:
# Set to user_type
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content={"user_type": user_type},
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(user_type, channel.json_body["user_type"])
# Get user
channel = self.make_request(
"GET",
self.url_other_user,
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(user_type, channel.json_body["user_type"])
def test_set_user_type(self) -> None:
"""
Test changing user type.
"""
# Set to support type
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content={"user_type": UserTypes.SUPPORT},
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"])
# Get user
channel = self.make_request(
"GET",
self.url_other_user,
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"])
self.set_user_type(UserTypes.SUPPORT)
# Change back to a regular user
self.set_user_type(None)
@override_config({"user_types": {"extra_user_types": ["extra1", "extra2"]}})
def test_set_user_type_with_extras(self) -> None:
"""
Test changing user type with extra_user_types configured.
"""
# Check that we can still set to support type
self.set_user_type(UserTypes.SUPPORT)
# Check that we can set to an extra user type
self.set_user_type("extra2")
# Change back to a regular user
self.set_user_type(None)
# Try setting to invalid type
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content={"user_type": None},
content={"user_type": "extra3"},
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertIsNone(channel.json_body["user_type"])
# Get user
channel = self.make_request(
"GET",
self.url_other_user,
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertIsNone(channel.json_body["user_type"])
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid user type", channel.json_body["error"])
def test_accidental_deactivation_prevention(self) -> None:
"""

View File

@@ -1618,6 +1618,63 @@ class MediaConfigTest(unittest.HomeserverTestCase):
)
class MediaConfigModuleCallbackTestCase(unittest.HomeserverTestCase):
servlets = [
media.register_servlets,
admin.register_servlets,
login.register_servlets,
]
def make_homeserver(
self, reactor: ThreadedMemoryReactorClock, clock: Clock
) -> HomeServer:
config = self.default_config()
self.storage_path = self.mktemp()
self.media_store_path = self.mktemp()
os.mkdir(self.storage_path)
os.mkdir(self.media_store_path)
config["media_store_path"] = self.media_store_path
provider_config = {
"module": "synapse.media.storage_provider.FileStorageProviderBackend",
"store_local": True,
"store_synchronous": False,
"store_remote": True,
"config": {"directory": self.storage_path},
}
config["media_storage_providers"] = [provider_config]
return self.setup_test_homeserver(config=config)
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user = self.register_user("user", "password")
self.tok = self.login("user", "password")
hs.get_module_api().register_media_repository_callbacks(
get_media_config_for_user=self.get_media_config_for_user,
)
async def get_media_config_for_user(
self,
user_id: str,
) -> Optional[JsonDict]:
# We echo back the user_id and set a custom upload size.
return {"m.upload.size": 1024, "user_id": user_id}
def test_media_config(self) -> None:
channel = self.make_request(
"GET",
"/_matrix/client/v1/media/config",
shorthand=False,
access_token=self.tok,
)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["m.upload.size"], 1024)
self.assertEqual(channel.json_body["user_id"], self.user)
class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
servlets = [
media.register_servlets,