Compare commits

..

12 Commits

Author SHA1 Message Date
Andrew Morgan
3dcc1efc43 Move callback-related code from AccountData to its own class 2023-03-09 16:50:31 +00:00
Andrew Morgan
46c0ab559b Move callback-related code from the PasswordAuthProvider to its own class 2023-03-09 16:50:31 +00:00
Andrew Morgan
e8cdfc771b Move callback-related code from the BackgroundUpdater to its own class 2023-03-09 16:50:31 +00:00
Andrew Morgan
1b30b82ac6 Move callback-related code from the PresenceRouter to its own class 2023-03-09 16:50:31 +00:00
Andrew Morgan
266f426c50 Move callback-related code from ThirdPartyEventRules to its own class
And update the many references.
2023-03-09 16:50:31 +00:00
Andrew Morgan
c3c3c6d200 Move callback-related code from the SpamChecker to its own class
And update the many references.
2023-03-09 16:50:31 +00:00
Andrew Morgan
9cd8fecdc5 Move Account Validity callbacks to a dedicated file
Spreading module callback registration across the codebase is both a bit
messy and makes it unclear where a user should register callbacks if
they want to define a new class of callbacks.

Consolidating these under the synapse.module_api module cleans things up
a bit, puts related code in the same place and makes it much more
obvious how to extend it.
2023-03-09 16:50:31 +00:00
Patrick Cloke
f4fc83ac75 Add a missing endpoint to the workers documentation. (#15223) 2023-03-08 07:51:34 -05:00
Shay
a368d30c1c More speedups/fixes to creating batched events (#15195) 2023-03-07 13:54:39 -08:00
Patrick Cloke
20ed8c926b Stabilize support for MSC3873: disambuguated event push keys. (#15190)
This removes the experimental configuration option and
always escapes the push rule condition keys.

Also escapes any (experimental) push rule condition keys
in the base rules which contain dot in a field name.
2023-03-07 11:27:57 -05:00
Quentin Gliech
47bc84dd53 Pass the Requester down to the HttpTransactionCache. (#15200) 2023-03-07 16:05:22 +00:00
Patrick Cloke
820f02b70b Stabilize support for MSC3966: event_property_contains push condition. (#15187)
This removes the configuration flag & updates the identifiers to
use the stable version.
2023-03-07 10:06:02 -05:00
61 changed files with 1696 additions and 1592 deletions

View File

@@ -1,25 +1,3 @@
Synapse 1.79.0 (2023-03-14)
===========================
No significant changes since 1.79.0rc2.
Synapse 1.79.0rc2 (2023-03-13)
==============================
Bugfixes
--------
- Fix a bug introduced in Synapse 1.79.0rc1 where attempting to register a `on_remove_user_third_party_identifier` module API callback would be a no-op. ([\#15227](https://github.com/matrix-org/synapse/issues/15227))
- Fix a rare bug introduced in Synapse 1.73 where events could remain unsent to other homeservers after a faster-join to a room. ([\#15248](https://github.com/matrix-org/synapse/issues/15248))
Internal Changes
----------------
- Refactor `filter_events_for_server`. ([\#15240](https://github.com/matrix-org/synapse/issues/15240))
Synapse 1.79.0rc1 (2023-03-07)
==============================
@@ -69,7 +47,7 @@ Improved Documentation
Deprecations and Removals
-------------------------
- Deprecate the `on_threepid_bind` module callback, to be replaced by [`on_add_user_third_party_identifier`](https://matrix-org.github.io/synapse/v1.79/modules/third_party_rules_callbacks.html#on_add_user_third_party_identifier). See [upgrade notes](https://github.com/matrix-org/synapse/blob/release-v1.79/docs/upgrade.md#upgrading-to-v1790). ([\#15044](https://github.com/matrix-org/synapse/issues/15044))
- Deprecate the `on_threepid_bind` module callback, to be replaced by [`on_add_user_third_party_identifier`](https://matrix-org.github.io/synapse/v1.79/modules/third_party_rules_callbacks.html#on_add_user_third_party_identifier). See [upgrade notes](https://github.com/matrix-org/synapse/blob/release-v1.79/docs/upgrade.md#upgrading-to-v1790). ([\#15044]
- Remove the unspecced `room_alias` field from the [`/createRoom`](https://spec.matrix.org/v1.6/client-server-api/#post_matrixclientv3createroom) response. ([\#15093](https://github.com/matrix-org/synapse/issues/15093))
- Remove the unspecced `PUT` on the `/knock/{roomIdOrAlias}` endpoint. ([\#15189](https://github.com/matrix-org/synapse/issues/15189))
- Remove the undocumented and unspecced `type` parameter to the `/thumbnail` endpoint. ([\#15137](https://github.com/matrix-org/synapse/issues/15137))

View File

@@ -0,0 +1 @@
Stabilise support for [MSC3966](https://github.com/matrix-org/matrix-spec-proposals/pull/3966): `event_property_contains` push condition.

1
changelog.d/15190.bugfix Normal file
View File

@@ -0,0 +1 @@
Implement [MSC3873](https://github.com/matrix-org/matrix-spec-proposals/pull/3873) to fix a long-standing bug where properties with dots were handled ambiguously in push rules.

1
changelog.d/15195.misc Normal file
View File

@@ -0,0 +1 @@
Improve performance of creating and authenticating events.

1
changelog.d/15200.misc Normal file
View File

@@ -0,0 +1 @@
Make the `HttpTransactionCache` use the `Requester` in addition of the just the `Request` to build the transaction key.

1
changelog.d/15223.doc Normal file
View File

@@ -0,0 +1 @@
Add a missing endpoint to the workers documentation.

12
debian/changelog vendored
View File

@@ -1,15 +1,3 @@
matrix-synapse-py3 (1.79.0) stable; urgency=medium
* New Synapse release 1.79.0.
-- Synapse Packaging team <packages@matrix.org> Tue, 14 Mar 2023 16:14:50 +0100
matrix-synapse-py3 (1.79.0~rc2) stable; urgency=medium
* New Synapse release 1.79.0rc2.
-- Synapse Packaging team <packages@matrix.org> Mon, 13 Mar 2023 12:54:21 +0000
matrix-synapse-py3 (1.79.0~rc1) stable; urgency=medium
* New Synapse release 1.79.0rc1.

View File

@@ -231,6 +231,7 @@ information.
^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/event/
^/_matrix/client/(api/v1|r0|v3|unstable)/joined_rooms$
^/_matrix/client/v1/rooms/.*/timestamp_to_event$
^/_matrix/client/(api/v1|r0|v3|unstable/.*)/rooms/.*/aliases
^/_matrix/client/(api/v1|r0|v3|unstable)/search$
^/_matrix/client/(r0|v3|unstable)/user/.*/filter(/|$)

View File

@@ -89,7 +89,7 @@ manifest-path = "rust/Cargo.toml"
[tool.poetry]
name = "matrix-synapse"
version = "1.79.0"
version = "1.79.0rc1"
description = "Homeserver for the Matrix decentralised comms protocol"
authors = ["Matrix.org Team and Contributors <packages@matrix.org>"]
license = "Apache-2.0"

View File

@@ -52,7 +52,6 @@ fn bench_match_exact(b: &mut Bencher) {
true,
vec![],
false,
false,
)
.unwrap();
@@ -98,7 +97,6 @@ fn bench_match_word(b: &mut Bencher) {
true,
vec![],
false,
false,
)
.unwrap();
@@ -144,7 +142,6 @@ fn bench_match_word_miss(b: &mut Bencher) {
true,
vec![],
false,
false,
)
.unwrap();
@@ -190,7 +187,6 @@ fn bench_eval_message(b: &mut Bencher) {
true,
vec![],
false,
false,
)
.unwrap();

View File

@@ -71,7 +71,7 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[
priority_class: 5,
conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::EventMatch(
EventMatchCondition {
key: Cow::Borrowed("content.m.relates_to.rel_type"),
key: Cow::Borrowed("content.m\\.relates_to.rel_type"),
pattern: Cow::Borrowed("m.replace"),
},
))]),
@@ -146,7 +146,7 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[
priority_class: 5,
conditions: Cow::Borrowed(&[Condition::Known(
KnownCondition::ExactEventPropertyContainsType(EventPropertyIsTypeCondition {
key: Cow::Borrowed("content.org.matrix.msc3952.mentions.user_ids"),
key: Cow::Borrowed("content.org\\.matrix\\.msc3952\\.mentions.user_ids"),
value_type: Cow::Borrowed(&EventMatchPatternType::UserId),
}),
)]),
@@ -167,7 +167,7 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[
priority_class: 5,
conditions: Cow::Borrowed(&[
Condition::Known(KnownCondition::EventPropertyIs(EventPropertyIsCondition {
key: Cow::Borrowed("content.org.matrix.msc3952.mentions.room"),
key: Cow::Borrowed("content.org\\.matrix\\.msc3952\\.mentions.room"),
value: Cow::Borrowed(&SimpleJsonValue::Bool(true)),
})),
Condition::Known(KnownCondition::SenderNotificationPermission {

View File

@@ -96,9 +96,6 @@ pub struct PushRuleEvaluator {
/// If MSC3931 (room version feature flags) is enabled. Usually controlled by the same
/// flag as MSC1767 (extensible events core).
msc3931_enabled: bool,
/// If MSC3966 (exact_event_property_contains push rule condition) is enabled.
msc3966_exact_event_property_contains: bool,
}
#[pymethods]
@@ -116,7 +113,6 @@ impl PushRuleEvaluator {
related_event_match_enabled: bool,
room_version_feature_flags: Vec<String>,
msc3931_enabled: bool,
msc3966_exact_event_property_contains: bool,
) -> Result<Self, Error> {
let body = match flattened_keys.get("content.body") {
Some(JsonValue::Value(SimpleJsonValue::Str(s))) => s.clone(),
@@ -134,7 +130,6 @@ impl PushRuleEvaluator {
related_event_match_enabled,
room_version_feature_flags,
msc3931_enabled,
msc3966_exact_event_property_contains,
})
}
@@ -301,8 +296,8 @@ impl PushRuleEvaluator {
Some(Cow::Borrowed(pattern)),
)?
}
KnownCondition::ExactEventPropertyContains(event_property_is) => self
.match_exact_event_property_contains(
KnownCondition::EventPropertyContains(event_property_is) => self
.match_event_property_contains(
event_property_is.key.clone(),
event_property_is.value.clone(),
)?,
@@ -321,7 +316,7 @@ impl PushRuleEvaluator {
EventMatchPatternType::UserLocalpart => get_localpart_from_id(user_id)?,
};
self.match_exact_event_property_contains(
self.match_event_property_contains(
exact_event_match.key.clone(),
Cow::Borrowed(&SimpleJsonValue::Str(pattern.to_string())),
)?
@@ -454,17 +449,12 @@ impl PushRuleEvaluator {
}
}
/// Evaluates a `exact_event_property_contains` condition. (MSC3966)
fn match_exact_event_property_contains(
/// Evaluates a `event_property_contains` condition.
fn match_event_property_contains(
&self,
key: Cow<str>,
value: Cow<SimpleJsonValue>,
) -> Result<bool, Error> {
// First check if the feature is enabled.
if !self.msc3966_exact_event_property_contains {
return Ok(false);
}
let haystack = if let Some(JsonValue::Array(haystack)) = self.flattened_keys.get(&*key) {
haystack
} else {
@@ -515,7 +505,6 @@ fn push_rule_evaluator() {
true,
vec![],
true,
true,
)
.unwrap();
@@ -545,7 +534,6 @@ fn test_requires_room_version_supports_condition() {
false,
flags,
true,
true,
)
.unwrap();

View File

@@ -337,13 +337,9 @@ pub enum KnownCondition {
// Identical to related_event_match but gives predefined patterns. Cannot be added by users.
#[serde(skip_deserializing, rename = "im.nheko.msc3664.related_event_match")]
RelatedEventMatchType(RelatedEventMatchTypeCondition),
#[serde(rename = "org.matrix.msc3966.exact_event_property_contains")]
ExactEventPropertyContains(EventPropertyIsCondition),
EventPropertyContains(EventPropertyIsCondition),
// Identical to exact_event_property_contains but gives predefined patterns. Cannot be added by users.
#[serde(
skip_deserializing,
rename = "org.matrix.msc3966.exact_event_property_contains"
)]
#[serde(skip_deserializing, rename = "event_property_contains")]
ExactEventPropertyContainsType(EventPropertyIsTypeCondition),
ContainsDisplayName,
RoomMemberCount {

View File

@@ -65,7 +65,6 @@ class PushRuleEvaluator:
related_event_match_enabled: bool,
room_version_feature_flags: Tuple[str, ...],
msc3931_enabled: bool,
msc3966_exact_event_property_contains: bool,
): ...
def run(
self,

View File

@@ -58,9 +58,6 @@ from synapse.config._base import format_config_error
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import ListenerConfig, ManholeConfig
from synapse.crypto import context_factory
from synapse.events.presence_router import load_legacy_presence_router
from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.http.site import SynapseSite
from synapse.logging.context import PreserveLoggingContext
@@ -68,6 +65,15 @@ from synapse.logging.opentracing import init_tracer
from synapse.metrics import install_gc_manager, register_threadpool
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.metrics.jemalloc import setup_jemalloc_stats
from synapse.module_api.callbacks.presence_router_callbacks import (
load_legacy_presence_router,
)
from synapse.module_api.callbacks.spam_checker_callbacks import (
load_legacy_spam_checkers,
)
from synapse.module_api.callbacks.third_party_event_rules_callbacks import (
load_legacy_third_party_event_rules,
)
from synapse.types import ISynapseReactor
from synapse.util import SYNAPSE_VERSION
from synapse.util.caches.lrucache import setup_expire_lru_cache_entries

View File

@@ -166,20 +166,9 @@ class ExperimentalConfig(Config):
# MSC3391: Removing account data.
self.msc3391_enabled = experimental.get("msc3391_enabled", False)
# MSC3873: Disambiguate event_match keys.
self.msc3873_escape_event_match_key = experimental.get(
"msc3873_escape_event_match_key", False
)
# MSC3966: exact_event_property_contains push rule condition.
self.msc3966_exact_event_property_contains = experimental.get(
"msc3966_exact_event_property_contains", False
)
# MSC3952: Intentional mentions, this depends on MSC3966.
self.msc3952_intentional_mentions = (
experimental.get("msc3952_intentional_mentions", False)
and self.msc3966_exact_event_property_contains
self.msc3952_intentional_mentions = experimental.get(
"msc3952_intentional_mentions", False
)
# MSC3959: Do not generate notifications for edits.
@@ -187,10 +176,5 @@ class ExperimentalConfig(Config):
"msc3958_supress_edit_notifs", False
)
# MSC3966: exact_event_property_contains push rule condition.
self.msc3966_exact_event_property_contains = experimental.get(
"msc3966_exact_event_property_contains", False
)
# MSC3967: Do not require UIA when first uploading cross signing keys
self.msc3967_enabled = experimental.get("msc3967_enabled", False)

View File

@@ -168,13 +168,24 @@ async def check_state_independent_auth_rules(
return
# 2. Reject if event has auth_events that: ...
auth_events = await store.get_events(
event.auth_event_ids(),
redact_behaviour=EventRedactBehaviour.as_is,
allow_rejected=True,
)
if batched_auth_events:
auth_events.update(batched_auth_events)
# Copy the batched auth events to avoid mutating them.
auth_events = dict(batched_auth_events)
needed_auth_event_ids = set(event.auth_event_ids()) - batched_auth_events.keys()
if needed_auth_event_ids:
auth_events.update(
await store.get_events(
needed_auth_event_ids,
redact_behaviour=EventRedactBehaviour.as_is,
allow_rejected=True,
)
)
else:
auth_events = await store.get_events(
event.auth_event_ids(),
redact_behaviour=EventRedactBehaviour.as_is,
allow_rejected=True,
)
room_id = event.room_id
auth_dict: MutableStateMap[str] = {}

View File

@@ -12,93 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
Iterable,
List,
Optional,
Set,
TypeVar,
Union,
)
from typing_extensions import ParamSpec
from typing import TYPE_CHECKING, Dict, Iterable, Set, Union
from twisted.internet.defer import CancelledError
from synapse.api.presence import UserPresenceState
from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
from synapse.util.async_helpers import delay_cancellation
if TYPE_CHECKING:
from synapse.server import HomeServer
GET_USERS_FOR_STATES_CALLBACK = Callable[
[Iterable[UserPresenceState]], Awaitable[Dict[str, Set[UserPresenceState]]]
]
# This must either return a set of strings or the constant PresenceRouter.ALL_USERS.
GET_INTERESTED_USERS_CALLBACK = Callable[[str], Awaitable[Union[Set[str], str]]]
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
def load_legacy_presence_router(hs: "HomeServer") -> None:
"""Wrapper that loads a presence router module configured using the old
configuration, and registers the hooks they implement.
"""
if hs.config.server.presence_router_module_class is None:
return
module = hs.config.server.presence_router_module_class
config = hs.config.server.presence_router_config
api = hs.get_module_api()
presence_router = module(config=config, module_api=api)
# The known hooks. If a module implements a method which name appears in this set,
# we'll want to register it.
presence_router_methods = {
"get_users_for_states",
"get_interested_users",
}
# All methods that the module provides should be async, but this wasn't enforced
# in the old module system, so we wrap them if needed
def async_wrapper(
f: Optional[Callable[P, R]]
) -> Optional[Callable[P, Awaitable[R]]]:
# f might be None if the callback isn't implemented by the module. In this
# case we don't want to register a callback at all so we return None.
if f is None:
return None
def run(*args: P.args, **kwargs: P.kwargs) -> Awaitable[R]:
# Assertion required because mypy can't prove we won't change `f`
# back to `None`. See
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
assert f is not None
return maybe_awaitable(f(*args, **kwargs))
return run
# Register the hooks through the module API.
hooks: Dict[str, Optional[Callable[..., Any]]] = {
hook: async_wrapper(getattr(presence_router, hook, None))
for hook in presence_router_methods
}
api.register_presence_router_callbacks(**hooks)
class PresenceRouter:
"""
A module that the homeserver will call upon to help route user presence updates to
@@ -108,30 +34,7 @@ class PresenceRouter:
ALL_USERS = "ALL"
def __init__(self, hs: "HomeServer"):
# Initially there are no callbacks
self._get_users_for_states_callbacks: List[GET_USERS_FOR_STATES_CALLBACK] = []
self._get_interested_users_callbacks: List[GET_INTERESTED_USERS_CALLBACK] = []
def register_presence_router_callbacks(
self,
get_users_for_states: Optional[GET_USERS_FOR_STATES_CALLBACK] = None,
get_interested_users: Optional[GET_INTERESTED_USERS_CALLBACK] = None,
) -> None:
# PresenceRouter modules are required to implement both of these methods
# or neither of them as they are assumed to act in a complementary manner
paired_methods = [get_users_for_states, get_interested_users]
if paired_methods.count(None) == 1:
raise RuntimeError(
"PresenceRouter modules must register neither or both of the paired callbacks: "
"[get_users_for_states, get_interested_users]"
)
# Append the methods provided to the lists of callbacks
if get_users_for_states is not None:
self._get_users_for_states_callbacks.append(get_users_for_states)
if get_interested_users is not None:
self._get_interested_users_callbacks.append(get_interested_users)
self._module_api_callbacks = hs.get_module_api_callbacks().presence_router
async def get_users_for_states(
self,
@@ -150,13 +53,13 @@ class PresenceRouter:
"""
# Bail out early if we don't have any callbacks to run.
if len(self._get_users_for_states_callbacks) == 0:
if len(self._module_api_callbacks.get_users_for_states_callbacks) == 0:
# Don't include any extra destinations for presence updates
return {}
users_for_states: Dict[str, Set[UserPresenceState]] = {}
# run all the callbacks for get_users_for_states and combine the results
for callback in self._get_users_for_states_callbacks:
for callback in self._module_api_callbacks.get_users_for_states_callbacks:
try:
# Note: result is an object here, because we don't trust modules to
# return the types they're supposed to.
@@ -206,13 +109,13 @@ class PresenceRouter:
"""
# Bail out early if we don't have any callbacks to run.
if len(self._get_interested_users_callbacks) == 0:
if len(self._module_api_callbacks.get_interested_users_callbacks) == 0:
# Don't report any additional interested users
return set()
interested_users = set()
# run all the callbacks for get_interested_users and combine the results
for callback in self._get_interested_users_callbacks:
for callback in self._module_api_callbacks.get_interested_users_callbacks:
try:
result = await delay_cancellation(callback(user_id))
except CancelledError:

View File

@@ -293,6 +293,7 @@ class EventContext(UnpersistedEventContextBase):
Maps a (type, state_key) to the event ID of the state event matching
this tuple.
"""
assert self.state_group_before_event is not None
return await self._storage.state.get_state_ids_for_group(
self.state_group_before_event, state_filter

View File

@@ -13,19 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import logging
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Collection,
List,
Optional,
Tuple,
Union,
)
from typing import TYPE_CHECKING, Collection, Optional, Tuple, Union
# `Literal` appears with Python 3.8.
from typing_extensions import Literal
@@ -37,7 +26,7 @@ from synapse.media._base import FileInfo
from synapse.media.media_storage import ReadableFileWrapper
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import JsonDict, RoomAlias, UserProfile
from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
from synapse.util.async_helpers import delay_cancellation
from synapse.util.metrics import Measure
if TYPE_CHECKING:
@@ -46,338 +35,13 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
CHECK_EVENT_FOR_SPAM_CALLBACK = Callable[
["synapse.events.EventBase"],
Awaitable[
Union[
str,
Codes,
# Highly experimental, not officially part of the spamchecker API, may
# disappear without warning depending on the results of ongoing
# experiments.
# Use this to return additional information as part of an error.
Tuple[Codes, JsonDict],
# Deprecated
bool,
]
],
]
SHOULD_DROP_FEDERATED_EVENT_CALLBACK = Callable[
["synapse.events.EventBase"],
Awaitable[Union[bool, str]],
]
USER_MAY_JOIN_ROOM_CALLBACK = Callable[
[str, str, bool],
Awaitable[
Union[
Literal["NOT_SPAM"],
Codes,
# Highly experimental, not officially part of the spamchecker API, may
# disappear without warning depending on the results of ongoing
# experiments.
# Use this to return additional information as part of an error.
Tuple[Codes, JsonDict],
# Deprecated
bool,
]
],
]
USER_MAY_INVITE_CALLBACK = Callable[
[str, str, str],
Awaitable[
Union[
Literal["NOT_SPAM"],
Codes,
# Highly experimental, not officially part of the spamchecker API, may
# disappear without warning depending on the results of ongoing
# experiments.
# Use this to return additional information as part of an error.
Tuple[Codes, JsonDict],
# Deprecated
bool,
]
],
]
USER_MAY_SEND_3PID_INVITE_CALLBACK = Callable[
[str, str, str, str],
Awaitable[
Union[
Literal["NOT_SPAM"],
Codes,
# Highly experimental, not officially part of the spamchecker API, may
# disappear without warning depending on the results of ongoing
# experiments.
# Use this to return additional information as part of an error.
Tuple[Codes, JsonDict],
# Deprecated
bool,
]
],
]
USER_MAY_CREATE_ROOM_CALLBACK = Callable[
[str],
Awaitable[
Union[
Literal["NOT_SPAM"],
Codes,
# Highly experimental, not officially part of the spamchecker API, may
# disappear without warning depending on the results of ongoing
# experiments.
# Use this to return additional information as part of an error.
Tuple[Codes, JsonDict],
# Deprecated
bool,
]
],
]
USER_MAY_CREATE_ROOM_ALIAS_CALLBACK = Callable[
[str, RoomAlias],
Awaitable[
Union[
Literal["NOT_SPAM"],
Codes,
# Highly experimental, not officially part of the spamchecker API, may
# disappear without warning depending on the results of ongoing
# experiments.
# Use this to return additional information as part of an error.
Tuple[Codes, JsonDict],
# Deprecated
bool,
]
],
]
USER_MAY_PUBLISH_ROOM_CALLBACK = Callable[
[str, str],
Awaitable[
Union[
Literal["NOT_SPAM"],
Codes,
# Highly experimental, not officially part of the spamchecker API, may
# disappear without warning depending on the results of ongoing
# experiments.
# Use this to return additional information as part of an error.
Tuple[Codes, JsonDict],
# Deprecated
bool,
]
],
]
CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[UserProfile], Awaitable[bool]]
LEGACY_CHECK_REGISTRATION_FOR_SPAM_CALLBACK = Callable[
[
Optional[dict],
Optional[str],
Collection[Tuple[str, str]],
],
Awaitable[RegistrationBehaviour],
]
CHECK_REGISTRATION_FOR_SPAM_CALLBACK = Callable[
[
Optional[dict],
Optional[str],
Collection[Tuple[str, str]],
Optional[str],
],
Awaitable[RegistrationBehaviour],
]
CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK = Callable[
[ReadableFileWrapper, FileInfo],
Awaitable[
Union[
Literal["NOT_SPAM"],
Codes,
# Highly experimental, not officially part of the spamchecker API, may
# disappear without warning depending on the results of ongoing
# experiments.
# Use this to return additional information as part of an error.
Tuple[Codes, JsonDict],
# Deprecated
bool,
]
],
]
def load_legacy_spam_checkers(hs: "synapse.server.HomeServer") -> None:
"""Wrapper that loads spam checkers configured using the old configuration, and
registers the spam checker hooks they implement.
"""
spam_checkers: List[Any] = []
api = hs.get_module_api()
for module, config in hs.config.spamchecker.spam_checkers:
# Older spam checkers don't accept the `api` argument, so we
# try and detect support.
spam_args = inspect.getfullargspec(module)
if "api" in spam_args.args:
spam_checkers.append(module(config=config, api=api))
else:
spam_checkers.append(module(config=config))
# The known spam checker hooks. If a spam checker module implements a method
# which name appears in this set, we'll want to register it.
spam_checker_methods = {
"check_event_for_spam",
"user_may_invite",
"user_may_create_room",
"user_may_create_room_alias",
"user_may_publish_room",
"check_username_for_spam",
"check_registration_for_spam",
"check_media_file_for_spam",
}
for spam_checker in spam_checkers:
# Methods on legacy spam checkers might not be async, so we wrap them around a
# wrapper that will call maybe_awaitable on the result.
def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
# f might be None if the callback isn't implemented by the module. In this
# case we don't want to register a callback at all so we return None.
if f is None:
return None
wrapped_func = f
if f.__name__ == "check_registration_for_spam":
checker_args = inspect.signature(f)
if len(checker_args.parameters) == 3:
# Backwards compatibility; some modules might implement a hook that
# doesn't expect a 4th argument. In this case, wrap it in a function
# that gives it only 3 arguments and drops the auth_provider_id on
# the floor.
def wrapper(
email_threepid: Optional[dict],
username: Optional[str],
request_info: Collection[Tuple[str, str]],
auth_provider_id: Optional[str],
) -> Union[Awaitable[RegistrationBehaviour], RegistrationBehaviour]:
# Assertion required because mypy can't prove we won't
# change `f` back to `None`. See
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
assert f is not None
return f(
email_threepid,
username,
request_info,
)
wrapped_func = wrapper
elif len(checker_args.parameters) != 4:
raise RuntimeError(
"Bad signature for callback check_registration_for_spam",
)
def run(*args: Any, **kwargs: Any) -> Awaitable:
# Assertion required because mypy can't prove we won't change `f`
# back to `None`. See
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
assert wrapped_func is not None
return maybe_awaitable(wrapped_func(*args, **kwargs))
return run
# Register the hooks through the module API.
hooks = {
hook: async_wrapper(getattr(spam_checker, hook, None))
for hook in spam_checker_methods
}
api.register_spam_checker_callbacks(**hooks)
class SpamChecker:
NOT_SPAM: Literal["NOT_SPAM"] = "NOT_SPAM"
def __init__(self, hs: "synapse.server.HomeServer") -> None:
self.hs = hs
def __init__(self, hs: "synapse.server.HomeServer"):
self.clock = hs.get_clock()
self._check_event_for_spam_callbacks: List[CHECK_EVENT_FOR_SPAM_CALLBACK] = []
self._should_drop_federated_event_callbacks: List[
SHOULD_DROP_FEDERATED_EVENT_CALLBACK
] = []
self._user_may_join_room_callbacks: List[USER_MAY_JOIN_ROOM_CALLBACK] = []
self._user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = []
self._user_may_send_3pid_invite_callbacks: List[
USER_MAY_SEND_3PID_INVITE_CALLBACK
] = []
self._user_may_create_room_callbacks: List[USER_MAY_CREATE_ROOM_CALLBACK] = []
self._user_may_create_room_alias_callbacks: List[
USER_MAY_CREATE_ROOM_ALIAS_CALLBACK
] = []
self._user_may_publish_room_callbacks: List[USER_MAY_PUBLISH_ROOM_CALLBACK] = []
self._check_username_for_spam_callbacks: List[
CHECK_USERNAME_FOR_SPAM_CALLBACK
] = []
self._check_registration_for_spam_callbacks: List[
CHECK_REGISTRATION_FOR_SPAM_CALLBACK
] = []
self._check_media_file_for_spam_callbacks: List[
CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK
] = []
def register_callbacks(
self,
check_event_for_spam: Optional[CHECK_EVENT_FOR_SPAM_CALLBACK] = None,
should_drop_federated_event: Optional[
SHOULD_DROP_FEDERATED_EVENT_CALLBACK
] = None,
user_may_join_room: Optional[USER_MAY_JOIN_ROOM_CALLBACK] = None,
user_may_invite: Optional[USER_MAY_INVITE_CALLBACK] = None,
user_may_send_3pid_invite: Optional[USER_MAY_SEND_3PID_INVITE_CALLBACK] = None,
user_may_create_room: Optional[USER_MAY_CREATE_ROOM_CALLBACK] = None,
user_may_create_room_alias: Optional[
USER_MAY_CREATE_ROOM_ALIAS_CALLBACK
] = None,
user_may_publish_room: Optional[USER_MAY_PUBLISH_ROOM_CALLBACK] = None,
check_username_for_spam: Optional[CHECK_USERNAME_FOR_SPAM_CALLBACK] = None,
check_registration_for_spam: Optional[
CHECK_REGISTRATION_FOR_SPAM_CALLBACK
] = None,
check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
) -> None:
"""Register callbacks from module for each hook."""
if check_event_for_spam is not None:
self._check_event_for_spam_callbacks.append(check_event_for_spam)
if should_drop_federated_event is not None:
self._should_drop_federated_event_callbacks.append(
should_drop_federated_event
)
if user_may_join_room is not None:
self._user_may_join_room_callbacks.append(user_may_join_room)
if user_may_invite is not None:
self._user_may_invite_callbacks.append(user_may_invite)
if user_may_send_3pid_invite is not None:
self._user_may_send_3pid_invite_callbacks.append(
user_may_send_3pid_invite,
)
if user_may_create_room is not None:
self._user_may_create_room_callbacks.append(user_may_create_room)
if user_may_create_room_alias is not None:
self._user_may_create_room_alias_callbacks.append(
user_may_create_room_alias,
)
if user_may_publish_room is not None:
self._user_may_publish_room_callbacks.append(user_may_publish_room)
if check_username_for_spam is not None:
self._check_username_for_spam_callbacks.append(check_username_for_spam)
if check_registration_for_spam is not None:
self._check_registration_for_spam_callbacks.append(
check_registration_for_spam,
)
if check_media_file_for_spam is not None:
self._check_media_file_for_spam_callbacks.append(check_media_file_for_spam)
self._module_api_callbacks = hs.get_module_api_callbacks().spam_checker
@trace
async def check_event_for_spam(
@@ -401,7 +65,7 @@ class SpamChecker:
string should be used as the client-facing error message. This usage is
generally discouraged as it doesn't support internationalization.
"""
for callback in self._check_event_for_spam_callbacks:
for callback in self._module_api_callbacks.check_event_for_spam_callbacks:
with Measure(
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
@@ -456,7 +120,9 @@ class SpamChecker:
Returns:
True if the event should be silently dropped
"""
for callback in self._should_drop_federated_event_callbacks:
for (
callback
) in self._module_api_callbacks.should_drop_federated_event_callbacks:
with Measure(
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
@@ -480,7 +146,7 @@ class SpamChecker:
Returns:
NOT_SPAM if the operation is permitted, [Codes, Dict] otherwise.
"""
for callback in self._user_may_join_room_callbacks:
for callback in self._module_api_callbacks.user_may_join_room_callbacks:
with Measure(
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
@@ -521,7 +187,7 @@ class SpamChecker:
Returns:
NOT_SPAM if the operation is permitted, Codes otherwise.
"""
for callback in self._user_may_invite_callbacks:
for callback in self._module_api_callbacks.user_may_invite_callbacks:
with Measure(
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
@@ -568,7 +234,7 @@ class SpamChecker:
Returns:
NOT_SPAM if the operation is permitted, Codes otherwise.
"""
for callback in self._user_may_send_3pid_invite_callbacks:
for callback in self._module_api_callbacks.user_may_send_3pid_invite_callbacks:
with Measure(
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
@@ -605,7 +271,7 @@ class SpamChecker:
Args:
userid: The ID of the user attempting to create a room
"""
for callback in self._user_may_create_room_callbacks:
for callback in self._module_api_callbacks.user_may_create_room_callbacks:
with Measure(
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
@@ -641,7 +307,7 @@ class SpamChecker:
room_alias: The alias to be created
"""
for callback in self._user_may_create_room_alias_callbacks:
for callback in self._module_api_callbacks.user_may_create_room_alias_callbacks:
with Measure(
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
@@ -676,7 +342,7 @@ class SpamChecker:
userid: The user ID attempting to publish the room
room_id: The ID of the room that would be published
"""
for callback in self._user_may_publish_room_callbacks:
for callback in self._module_api_callbacks.user_may_publish_room_callbacks:
with Measure(
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
@@ -717,7 +383,7 @@ class SpamChecker:
Returns:
True if the user is spammy.
"""
for callback in self._check_username_for_spam_callbacks:
for callback in self._module_api_callbacks.check_username_for_spam_callbacks:
with Measure(
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
@@ -751,7 +417,9 @@ class SpamChecker:
Enum for how the request should be handled
"""
for callback in self._check_registration_for_spam_callbacks:
for (
callback
) in self._module_api_callbacks.check_registration_for_spam_callbacks:
with Measure(
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
@@ -794,7 +462,7 @@ class SpamChecker:
file_info: Metadata about the file.
"""
for callback in self._check_media_file_for_spam_callbacks:
for callback in self._module_api_callbacks.check_media_file_for_spam_callbacks:
with Measure(
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Optional, Tuple
from typing import TYPE_CHECKING, Optional, Tuple
from twisted.internet.defer import CancelledError
@@ -21,7 +21,7 @@ from synapse.events import EventBase
from synapse.events.snapshot import UnpersistedEventContextBase
from synapse.storage.roommember import ProfileInfo
from synapse.types import Requester, StateMap
from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
from synapse.util.async_helpers import delay_cancellation
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -29,117 +29,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
CHECK_EVENT_ALLOWED_CALLBACK = Callable[
[EventBase, StateMap[EventBase]], Awaitable[Tuple[bool, Optional[dict]]]
]
ON_CREATE_ROOM_CALLBACK = Callable[[Requester, dict, bool], Awaitable]
CHECK_THREEPID_CAN_BE_INVITED_CALLBACK = Callable[
[str, str, StateMap[EventBase]], Awaitable[bool]
]
CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[
[str, StateMap[EventBase], str], Awaitable[bool]
]
ON_NEW_EVENT_CALLBACK = Callable[[EventBase, StateMap[EventBase]], Awaitable]
CHECK_CAN_SHUTDOWN_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]]
CHECK_CAN_DEACTIVATE_USER_CALLBACK = Callable[[str, bool], Awaitable[bool]]
ON_PROFILE_UPDATE_CALLBACK = Callable[[str, ProfileInfo, bool, bool], Awaitable]
ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK = Callable[[str, bool, bool], Awaitable]
ON_THREEPID_BIND_CALLBACK = Callable[[str, str, str], Awaitable]
ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK = Callable[[str, str, str], Awaitable]
ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK = Callable[[str, str, str], Awaitable]
def load_legacy_third_party_event_rules(hs: "HomeServer") -> None:
"""Wrapper that loads a third party event rules module configured using the old
configuration, and registers the hooks they implement.
"""
if hs.config.thirdpartyrules.third_party_event_rules is None:
return
module, config = hs.config.thirdpartyrules.third_party_event_rules
api = hs.get_module_api()
third_party_rules = module(config=config, module_api=api)
# The known hooks. If a module implements a method which name appears in this set,
# we'll want to register it.
third_party_event_rules_methods = {
"check_event_allowed",
"on_create_room",
"check_threepid_can_be_invited",
"check_visibility_can_be_modified",
}
def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
# f might be None if the callback isn't implemented by the module. In this
# case we don't want to register a callback at all so we return None.
if f is None:
return None
# We return a separate wrapper for these methods because, in order to wrap them
# correctly, we need to await its result. Therefore it doesn't make a lot of
# sense to make it go through the run() wrapper.
if f.__name__ == "check_event_allowed":
# We need to wrap check_event_allowed because its old form would return either
# a boolean or a dict, but now we want to return the dict separately from the
# boolean.
async def wrap_check_event_allowed(
event: EventBase,
state_events: StateMap[EventBase],
) -> Tuple[bool, Optional[dict]]:
# Assertion required because mypy can't prove we won't change
# `f` back to `None`. See
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
assert f is not None
res = await f(event, state_events)
if isinstance(res, dict):
return True, res
else:
return res, None
return wrap_check_event_allowed
if f.__name__ == "on_create_room":
# We need to wrap on_create_room because its old form would return a boolean
# if the room creation is denied, but now we just want it to raise an
# exception.
async def wrap_on_create_room(
requester: Requester, config: dict, is_requester_admin: bool
) -> None:
# Assertion required because mypy can't prove we won't change
# `f` back to `None`. See
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
assert f is not None
res = await f(requester, config, is_requester_admin)
if res is False:
raise SynapseError(
403,
"Room creation forbidden with these parameters",
)
return wrap_on_create_room
def run(*args: Any, **kwargs: Any) -> Awaitable:
# Assertion required because mypy can't prove we won't change `f`
# back to `None`. See
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
assert f is not None
return maybe_awaitable(f(*args, **kwargs))
return run
# Register the hooks through the module API.
hooks = {
hook: async_wrapper(getattr(third_party_rules, hook, None))
for hook in third_party_event_rules_methods
}
api.register_third_party_rules_callbacks(**hooks)
class ThirdPartyEventRules:
"""Allows server admins to provide a Python module implementing an extra
set of rules to apply when processing events.
@@ -153,104 +42,9 @@ class ThirdPartyEventRules:
self.store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self._check_event_allowed_callbacks: List[CHECK_EVENT_ALLOWED_CALLBACK] = []
self._on_create_room_callbacks: List[ON_CREATE_ROOM_CALLBACK] = []
self._check_threepid_can_be_invited_callbacks: List[
CHECK_THREEPID_CAN_BE_INVITED_CALLBACK
] = []
self._check_visibility_can_be_modified_callbacks: List[
CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
] = []
self._on_new_event_callbacks: List[ON_NEW_EVENT_CALLBACK] = []
self._check_can_shutdown_room_callbacks: List[
CHECK_CAN_SHUTDOWN_ROOM_CALLBACK
] = []
self._check_can_deactivate_user_callbacks: List[
CHECK_CAN_DEACTIVATE_USER_CALLBACK
] = []
self._on_profile_update_callbacks: List[ON_PROFILE_UPDATE_CALLBACK] = []
self._on_user_deactivation_status_changed_callbacks: List[
ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK
] = []
self._on_threepid_bind_callbacks: List[ON_THREEPID_BIND_CALLBACK] = []
self._on_add_user_third_party_identifier_callbacks: List[
ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK
] = []
self._on_remove_user_third_party_identifier_callbacks: List[
ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK
] = []
def register_third_party_rules_callbacks(
self,
check_event_allowed: Optional[CHECK_EVENT_ALLOWED_CALLBACK] = None,
on_create_room: Optional[ON_CREATE_ROOM_CALLBACK] = None,
check_threepid_can_be_invited: Optional[
CHECK_THREEPID_CAN_BE_INVITED_CALLBACK
] = None,
check_visibility_can_be_modified: Optional[
CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
] = None,
on_new_event: Optional[ON_NEW_EVENT_CALLBACK] = None,
check_can_shutdown_room: Optional[CHECK_CAN_SHUTDOWN_ROOM_CALLBACK] = None,
check_can_deactivate_user: Optional[CHECK_CAN_DEACTIVATE_USER_CALLBACK] = None,
on_profile_update: Optional[ON_PROFILE_UPDATE_CALLBACK] = None,
on_user_deactivation_status_changed: Optional[
ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK
] = None,
on_threepid_bind: Optional[ON_THREEPID_BIND_CALLBACK] = None,
on_add_user_third_party_identifier: Optional[
ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK
] = None,
on_remove_user_third_party_identifier: Optional[
ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK
] = None,
) -> None:
"""Register callbacks from modules for each hook."""
if check_event_allowed is not None:
self._check_event_allowed_callbacks.append(check_event_allowed)
if on_create_room is not None:
self._on_create_room_callbacks.append(on_create_room)
if check_threepid_can_be_invited is not None:
self._check_threepid_can_be_invited_callbacks.append(
check_threepid_can_be_invited,
)
if check_visibility_can_be_modified is not None:
self._check_visibility_can_be_modified_callbacks.append(
check_visibility_can_be_modified,
)
if on_new_event is not None:
self._on_new_event_callbacks.append(on_new_event)
if check_can_shutdown_room is not None:
self._check_can_shutdown_room_callbacks.append(check_can_shutdown_room)
if check_can_deactivate_user is not None:
self._check_can_deactivate_user_callbacks.append(check_can_deactivate_user)
if on_profile_update is not None:
self._on_profile_update_callbacks.append(on_profile_update)
if on_user_deactivation_status_changed is not None:
self._on_user_deactivation_status_changed_callbacks.append(
on_user_deactivation_status_changed,
)
if on_threepid_bind is not None:
self._on_threepid_bind_callbacks.append(on_threepid_bind)
if on_add_user_third_party_identifier is not None:
self._on_add_user_third_party_identifier_callbacks.append(
on_add_user_third_party_identifier
)
if on_remove_user_third_party_identifier is not None:
self._on_remove_user_third_party_identifier_callbacks.append(
on_remove_user_third_party_identifier
)
self._module_api_callbacks = (
hs.get_module_api_callbacks().third_party_event_rules
)
async def check_event_allowed(
self,
@@ -274,7 +68,7 @@ class ThirdPartyEventRules:
The result from the ThirdPartyRules module, as above.
"""
# Bail out early without hitting the store if we don't have any callbacks to run.
if len(self._check_event_allowed_callbacks) == 0:
if len(self._module_api_callbacks.check_event_allowed_callbacks) == 0:
return True, None
prev_state_ids = await context.get_prev_state_ids()
@@ -288,7 +82,7 @@ class ThirdPartyEventRules:
# the hashes and signatures.
event.freeze()
for callback in self._check_event_allowed_callbacks:
for callback in self._module_api_callbacks.check_event_allowed_callbacks:
try:
res, replacement_data = await delay_cancellation(
callback(event, state_events)
@@ -329,7 +123,7 @@ class ThirdPartyEventRules:
config: The creation config from the client.
is_requester_admin: If the requester is an admin
"""
for callback in self._on_create_room_callbacks:
for callback in self._module_api_callbacks.on_create_room_callbacks:
try:
await callback(requester, config, is_requester_admin)
except Exception as e:
@@ -357,12 +151,14 @@ class ThirdPartyEventRules:
True if the 3PID can be invited, False if not.
"""
# Bail out early without hitting the store if we don't have any callbacks to run.
if len(self._check_threepid_can_be_invited_callbacks) == 0:
if len(self._module_api_callbacks.check_threepid_can_be_invited_callbacks) == 0:
return True
state_events = await self._get_state_map_for_room(room_id)
for callback in self._check_threepid_can_be_invited_callbacks:
for (
callback
) in self._module_api_callbacks.check_threepid_can_be_invited_callbacks:
try:
threepid_can_be_invited = await delay_cancellation(
callback(medium, address, state_events)
@@ -390,12 +186,17 @@ class ThirdPartyEventRules:
True if the room's visibility can be modified, False if not.
"""
# Bail out early without hitting the store if we don't have any callback
if len(self._check_visibility_can_be_modified_callbacks) == 0:
if (
len(self._module_api_callbacks.check_visibility_can_be_modified_callbacks)
== 0
):
return True
state_events = await self._get_state_map_for_room(room_id)
for callback in self._check_visibility_can_be_modified_callbacks:
for (
callback
) in self._module_api_callbacks.check_visibility_can_be_modified_callbacks:
try:
visibility_can_be_modified = await delay_cancellation(
callback(room_id, state_events, new_visibility)
@@ -417,13 +218,13 @@ class ThirdPartyEventRules:
event_id: The ID of the event.
"""
# Bail out early without hitting the store if we don't have any callbacks
if len(self._on_new_event_callbacks) == 0:
if len(self._module_api_callbacks.on_new_event_callbacks) == 0:
return
event = await self.store.get_event(event_id)
state_events = await self._get_state_map_for_room(event.room_id)
for callback in self._on_new_event_callbacks:
for callback in self._module_api_callbacks.on_new_event_callbacks:
try:
await callback(event, state_events)
except Exception as e:
@@ -439,7 +240,7 @@ class ThirdPartyEventRules:
requester: The ID of the user requesting the shutdown.
room_id: The ID of the room.
"""
for callback in self._check_can_shutdown_room_callbacks:
for callback in self._module_api_callbacks.check_can_shutdown_room_callbacks:
try:
can_shutdown_room = await delay_cancellation(callback(user_id, room_id))
if can_shutdown_room is False:
@@ -464,7 +265,7 @@ class ThirdPartyEventRules:
requester
user_id: The ID of the room.
"""
for callback in self._check_can_deactivate_user_callbacks:
for callback in self._module_api_callbacks.check_can_deactivate_user_callbacks:
try:
can_deactivate_user = await delay_cancellation(
callback(user_id, by_admin)
@@ -502,7 +303,7 @@ class ThirdPartyEventRules:
by_admin: Whether the profile update was performed by a server admin.
deactivation: Whether this change was made while deactivating the user.
"""
for callback in self._on_profile_update_callbacks:
for callback in self._module_api_callbacks.on_profile_update_callbacks:
try:
await callback(user_id, new_profile, by_admin, deactivation)
except Exception as e:
@@ -520,7 +321,9 @@ class ThirdPartyEventRules:
deactivated: Whether the user is now deactivated.
by_admin: Whether the deactivation was performed by a server admin.
"""
for callback in self._on_user_deactivation_status_changed_callbacks:
for (
callback
) in self._module_api_callbacks.on_user_deactivation_status_changed_callbacks:
try:
await callback(user_id, deactivated, by_admin)
except Exception as e:
@@ -543,7 +346,7 @@ class ThirdPartyEventRules:
medium: the threepid's medium.
address: the threepid's address.
"""
for callback in self._on_threepid_bind_callbacks:
for callback in self._module_api_callbacks.on_threepid_bind_callbacks:
try:
await callback(user_id, medium, address)
except Exception as e:
@@ -562,7 +365,9 @@ class ThirdPartyEventRules:
medium: The medium of the third-party ID (email, msisdn).
address: The address of the third-party ID (i.e. an email address).
"""
for callback in self._on_add_user_third_party_identifier_callbacks:
for (
callback
) in self._module_api_callbacks.on_add_user_third_party_identifier_callbacks:
try:
await callback(user_id, medium, address)
except Exception as e:
@@ -584,7 +389,9 @@ class ThirdPartyEventRules:
medium: The medium of the third-party ID (email, msisdn).
address: The address of the third-party ID (i.e. an email address).
"""
for callback in self._on_remove_user_third_party_identifier_callbacks:
for (
callback
) in self._module_api_callbacks.on_remove_user_third_party_identifier_callbacks:
try:
await callback(user_id, medium, address)
except Exception as e:

View File

@@ -497,8 +497,8 @@ class PerDestinationQueue:
#
# Note: `catchup_pdus` will have exactly one PDU per room.
for pdu in catchup_pdus:
# The PDU from the DB will be the newest PDU in the room from
# *this server* that we tried---but were unable---to send to the remote.
# The PDU from the DB will be the last PDU in the room from
# *this server* that wasn't sent to the remote. However, other
# servers may have sent lots of events since then, and we want
# to try and tell the remote only about the *latest* events in
# the room. This is so that it doesn't get inundated by events
@@ -516,11 +516,6 @@ class PerDestinationQueue:
# If the event is in the extremities, then great! We can just
# use that without having to do further checks.
room_catchup_pdus = [pdu]
elif await self._store.is_partial_state_room(pdu.room_id):
# We can't be sure which events the destination should
# see using only partial state. Avoid doing so, and just retry
# sending our the newest PDU the remote is missing from us.
room_catchup_pdus = [pdu]
else:
# If not, fetch the extremities and figure out which we can
# send.
@@ -552,8 +547,6 @@ class PerDestinationQueue:
self._server_name,
new_pdus,
redact=False,
filter_out_erased_senders=True,
filter_out_remote_partial_state_events=True,
)
# If we've filtered out all the extremities, fall back to

View File

@@ -14,7 +14,7 @@
# limitations under the License.
import logging
import random
from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse.api.constants import AccountDataTypes
from synapse.replication.http.account_data import (
@@ -33,10 +33,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
ON_ACCOUNT_DATA_UPDATED_CALLBACK = Callable[
[str, Optional[str], str, JsonDict], Awaitable
]
class AccountDataHandler:
def __init__(self, hs: "HomeServer"):
@@ -60,16 +56,7 @@ class AccountDataHandler:
self._remove_tag_client = ReplicationRemoveTagRestServlet.make_client(hs)
self._account_data_writers = hs.config.worker.writers.account_data
self._on_account_data_updated_callbacks: List[
ON_ACCOUNT_DATA_UPDATED_CALLBACK
] = []
def register_module_callbacks(
self, on_account_data_updated: Optional[ON_ACCOUNT_DATA_UPDATED_CALLBACK] = None
) -> None:
"""Register callbacks from modules."""
if on_account_data_updated is not None:
self._on_account_data_updated_callbacks.append(on_account_data_updated)
self._module_api_callbacks = hs.get_module_api_callbacks().account_data
async def _notify_modules(
self,
@@ -92,7 +79,7 @@ class AccountDataHandler:
account_data_type: The type of the account data.
content: The content that is now associated with this type.
"""
for callback in self._on_account_data_updated_callbacks:
for callback in self._module_api_callbacks.on_account_data_updated_callbacks:
try:
await callback(user_id, room_id, account_data_type, content)
except Exception as e:

View File

@@ -15,9 +15,7 @@
import email.mime.multipart
import email.utils
import logging
from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple
from twisted.web.http import Request
from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse.api.errors import AuthError, StoreError, SynapseError
from synapse.metrics.background_process_metrics import wrap_as_background_process
@@ -30,25 +28,17 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# Types for callbacks to be registered via the module api
IS_USER_EXPIRED_CALLBACK = Callable[[str], Awaitable[Optional[bool]]]
ON_USER_REGISTRATION_CALLBACK = Callable[[str], Awaitable]
# Temporary hooks to allow for a transition from `/_matrix/client` endpoints
# to `/_synapse/client/account_validity`. See `register_account_validity_callbacks`.
ON_LEGACY_SEND_MAIL_CALLBACK = Callable[[str], Awaitable]
ON_LEGACY_RENEW_CALLBACK = Callable[[str], Awaitable[Tuple[bool, bool, int]]]
ON_LEGACY_ADMIN_REQUEST = Callable[[Request], Awaitable]
class AccountValidityHandler:
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.config = hs.config
self.store = self.hs.get_datastores().main
self.send_email_handler = self.hs.get_send_email_handler()
self.clock = self.hs.get_clock()
self.store = hs.get_datastores().main
self.send_email_handler = hs.get_send_email_handler()
self.clock = hs.get_clock()
self._app_name = self.hs.config.email.email_app_name
self._app_name = hs.config.email.email_app_name
self._module_api_callbacks = hs.get_module_api_callbacks().account_validity
self._account_validity_enabled = (
hs.config.account_validity.account_validity_enabled
@@ -78,69 +68,6 @@ class AccountValidityHandler:
if hs.config.worker.run_background_tasks:
self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000)
self._is_user_expired_callbacks: List[IS_USER_EXPIRED_CALLBACK] = []
self._on_user_registration_callbacks: List[ON_USER_REGISTRATION_CALLBACK] = []
self._on_legacy_send_mail_callback: Optional[
ON_LEGACY_SEND_MAIL_CALLBACK
] = None
self._on_legacy_renew_callback: Optional[ON_LEGACY_RENEW_CALLBACK] = None
# The legacy admin requests callback isn't a protected attribute because we need
# to access it from the admin servlet, which is outside of this handler.
self.on_legacy_admin_request_callback: Optional[ON_LEGACY_ADMIN_REQUEST] = None
def register_account_validity_callbacks(
self,
is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None,
on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None,
on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None,
on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None,
on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None,
) -> None:
"""Register callbacks from module for each hook."""
if is_user_expired is not None:
self._is_user_expired_callbacks.append(is_user_expired)
if on_user_registration is not None:
self._on_user_registration_callbacks.append(on_user_registration)
# The builtin account validity feature exposes 3 endpoints (send_mail, renew, and
# an admin one). As part of moving the feature into a module, we need to change
# the path from /_matrix/client/unstable/account_validity/... to
# /_synapse/client/account_validity, because:
#
# * the feature isn't part of the Matrix spec thus shouldn't live under /_matrix
# * the way we register servlets means that modules can't register resources
# under /_matrix/client
#
# We need to allow for a transition period between the old and new endpoints
# in order to allow for clients to update (and for emails to be processed).
#
# Once the email-account-validity module is loaded, it will take control of account
# validity by moving the rows from our `account_validity` table into its own table.
#
# Therefore, we need to allow modules (in practice just the one implementing the
# email-based account validity) to temporarily hook into the legacy endpoints so we
# can route the traffic coming into the old endpoints into the module, which is
# why we have the following three temporary hooks.
if on_legacy_send_mail is not None:
if self._on_legacy_send_mail_callback is not None:
raise RuntimeError("Tried to register on_legacy_send_mail twice")
self._on_legacy_send_mail_callback = on_legacy_send_mail
if on_legacy_renew is not None:
if self._on_legacy_renew_callback is not None:
raise RuntimeError("Tried to register on_legacy_renew twice")
self._on_legacy_renew_callback = on_legacy_renew
if on_legacy_admin_request is not None:
if self.on_legacy_admin_request_callback is not None:
raise RuntimeError("Tried to register on_legacy_admin_request twice")
self.on_legacy_admin_request_callback = on_legacy_admin_request
async def is_user_expired(self, user_id: str) -> bool:
"""Checks if a user has expired against third-party modules.
@@ -150,7 +77,7 @@ class AccountValidityHandler:
Returns:
Whether the user has expired.
"""
for callback in self._is_user_expired_callbacks:
for callback in self._module_api_callbacks.is_user_expired_callbacks:
expired = await delay_cancellation(callback(user_id))
if expired is not None:
return expired
@@ -168,7 +95,7 @@ class AccountValidityHandler:
Args:
user_id: The ID of the newly registered user.
"""
for callback in self._on_user_registration_callbacks:
for callback in self._module_api_callbacks.on_user_registration_callbacks:
await callback(user_id)
@wrap_as_background_process("send_renewals")
@@ -198,8 +125,8 @@ class AccountValidityHandler:
"""
# If a module supports sending a renewal email from here, do that, otherwise do
# the legacy dance.
if self._on_legacy_send_mail_callback is not None:
await self._on_legacy_send_mail_callback(user_id)
if self._module_api_callbacks.on_legacy_send_mail_callback is not None:
await self._module_api_callbacks.on_legacy_send_mail_callback(user_id)
return
if not self._account_validity_renew_by_email_enabled:
@@ -336,8 +263,10 @@ class AccountValidityHandler:
"""
# If a module supports triggering a renew from here, do that, otherwise do the
# legacy dance.
if self._on_legacy_renew_callback is not None:
return await self._on_legacy_renew_callback(renewal_token)
if self._module_api_callbacks.on_legacy_renew_callback is not None:
return await self._module_api_callbacks.on_legacy_renew_callback(
renewal_token
)
try:
(

View File

@@ -65,6 +65,10 @@ from synapse.http.server import finish_request, respond_with_html
from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api.callbacks.password_auth_provider_callbacks import (
CHECK_3PID_AUTH_CALLBACK,
ON_LOGGED_OUT_CALLBACK,
)
from synapse.storage.databases.main.registration import (
LoginTokenExpired,
LoginTokenLookupResult,
@@ -1096,7 +1100,7 @@ class AuthHandler:
return self._password_enabled_for_login and self._password_localdb_enabled
def get_supported_login_types(self) -> Iterable[str]:
"""Get a the login types supported for the /login API
"""Get the login types supported for the /login API
By default this is just 'm.login.password' (unless password_enabled is
False in the config file), but password auth providers can provide
@@ -1999,124 +2003,16 @@ def load_single_legacy_password_auth_provider(
)
CHECK_3PID_AUTH_CALLBACK = Callable[
[str, str, str],
Awaitable[
Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
],
]
ON_LOGGED_OUT_CALLBACK = Callable[[str, Optional[str], str], Awaitable]
CHECK_AUTH_CALLBACK = Callable[
[str, str, JsonDict],
Awaitable[
Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
],
]
GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[
[JsonDict, JsonDict],
Awaitable[Optional[str]],
]
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK = Callable[
[JsonDict, JsonDict],
Awaitable[Optional[str]],
]
IS_3PID_ALLOWED_CALLBACK = Callable[[str, str, bool], Awaitable[bool]]
class PasswordAuthProvider:
"""
A class that the AuthHandler calls when authenticating users
It allows modules to provide alternative methods for authentication
"""
def __init__(self) -> None:
# lists of callbacks
self.check_3pid_auth_callbacks: List[CHECK_3PID_AUTH_CALLBACK] = []
self.on_logged_out_callbacks: List[ON_LOGGED_OUT_CALLBACK] = []
self.get_username_for_registration_callbacks: List[
GET_USERNAME_FOR_REGISTRATION_CALLBACK
] = []
self.get_displayname_for_registration_callbacks: List[
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
] = []
self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = []
# Mapping from login type to login parameters
self._supported_login_types: Dict[str, Tuple[str, ...]] = {}
# Mapping from login type to auth checker callbacks
self.auth_checker_callbacks: Dict[str, List[CHECK_AUTH_CALLBACK]] = {}
def register_password_auth_provider_callbacks(
self,
check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None,
on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None,
is_3pid_allowed: Optional[IS_3PID_ALLOWED_CALLBACK] = None,
auth_checkers: Optional[
Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK]
] = None,
get_username_for_registration: Optional[
GET_USERNAME_FOR_REGISTRATION_CALLBACK
] = None,
get_displayname_for_registration: Optional[
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
] = None,
) -> None:
# Register check_3pid_auth callback
if check_3pid_auth is not None:
self.check_3pid_auth_callbacks.append(check_3pid_auth)
# register on_logged_out callback
if on_logged_out is not None:
self.on_logged_out_callbacks.append(on_logged_out)
if auth_checkers is not None:
# register a new supported login_type
# Iterate through all of the types being registered
for (login_type, fields), callback in auth_checkers.items():
# Note: fields may be empty here. This would allow a modules auth checker to
# be called with just 'login_type' and no password or other secrets
# Need to check that all the field names are strings or may get nasty errors later
for f in fields:
if not isinstance(f, str):
raise RuntimeError(
"A module tried to register support for login type: %s with parameters %s"
" but all parameter names must be strings"
% (login_type, fields)
)
# 2 modules supporting the same login type must expect the same fields
# e.g. 1 can't expect "pass" if the other expects "password"
# so throw an exception if that happens
if login_type not in self._supported_login_types.get(login_type, []):
self._supported_login_types[login_type] = fields
else:
fields_currently_supported = self._supported_login_types.get(
login_type
)
if fields_currently_supported != fields:
raise RuntimeError(
"A module tried to register support for login type: %s with parameters %s"
" but another module had already registered support for that type with parameters %s"
% (login_type, fields, fields_currently_supported)
)
# Add the new method to the list of auth_checker_callbacks for this login type
self.auth_checker_callbacks.setdefault(login_type, []).append(callback)
if get_username_for_registration is not None:
self.get_username_for_registration_callbacks.append(
get_username_for_registration,
)
if get_displayname_for_registration is not None:
self.get_displayname_for_registration_callbacks.append(
get_displayname_for_registration,
)
if is_3pid_allowed is not None:
self.is_3pid_allowed_callbacks.append(is_3pid_allowed)
def __init__(self, hs: "HomeServer") -> None:
self._module_api_callbacks = (
hs.get_module_api_callbacks().password_auth_provider
)
def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
"""Get the login types supported by this password provider
@@ -2126,7 +2022,7 @@ class PasswordAuthProvider:
to the /login API.
"""
return self._supported_login_types
return self._module_api_callbacks.supported_login_types
async def check_auth(
self, username: str, login_type: str, login_dict: JsonDict
@@ -2149,7 +2045,7 @@ class PasswordAuthProvider:
# Go through all callbacks for the login type until one returns with a value
# other than None (i.e. until a callback returns a success)
for callback in self.auth_checker_callbacks[login_type]:
for callback in self._module_api_callbacks.auth_checker_callbacks[login_type]:
try:
result = await delay_cancellation(
callback(username, login_type, login_dict)
@@ -2214,7 +2110,7 @@ class PasswordAuthProvider:
# (user_id, callback_func), where callback_func should be run
# after we've finished everything else
for callback in self.check_3pid_auth_callbacks:
for callback in self._module_api_callbacks.check_3pid_auth_callbacks:
try:
result = await delay_cancellation(callback(medium, address, password))
except CancelledError:
@@ -2272,7 +2168,7 @@ class PasswordAuthProvider:
self, user_id: str, device_id: Optional[str], access_token: str
) -> None:
# call all of the on_logged_out callbacks
for callback in self.on_logged_out_callbacks:
for callback in self._module_api_callbacks.on_logged_out_callbacks:
try:
await callback(user_id, device_id, access_token)
except Exception as e:
@@ -2297,7 +2193,9 @@ class PasswordAuthProvider:
The localpart to use when registering this user, or None if no module
returned a localpart.
"""
for callback in self.get_username_for_registration_callbacks:
for (
callback
) in self._module_api_callbacks.get_username_for_registration_callbacks:
try:
res = await delay_cancellation(callback(uia_results, params))
@@ -2342,7 +2240,9 @@ class PasswordAuthProvider:
A tuple which first element is the display name, and the second is an MXC URL
to the user's avatar.
"""
for callback in self.get_displayname_for_registration_callbacks:
for (
callback
) in self._module_api_callbacks.get_displayname_for_registration_callbacks:
try:
res = await delay_cancellation(callback(uia_results, params))
@@ -2385,7 +2285,7 @@ class PasswordAuthProvider:
Returns:
Whether the 3PID is allowed to be bound on this homeserver
"""
for callback in self.is_3pid_allowed_callbacks:
for callback in self._module_api_callbacks.is_3pid_allowed_callbacks:
try:
res = await delay_cancellation(callback(medium, address, registration))

View File

@@ -63,9 +63,18 @@ class EventAuthHandler:
self._store, event, batched_auth_events
)
auth_event_ids = event.auth_event_ids()
auth_events_by_id = await self._store.get_events(auth_event_ids)
if batched_auth_events:
auth_events_by_id.update(batched_auth_events)
# Copy the batched auth events to avoid mutating them.
auth_events_by_id = dict(batched_auth_events)
needed_auth_event_ids = set(auth_event_ids) - set(batched_auth_events)
if needed_auth_event_ids:
auth_events_by_id.update(
await self._store.get_events(needed_auth_event_ids)
)
else:
auth_events_by_id = await self._store.get_events(auth_event_ids)
check_state_dependent_auth_rules(event, auth_events_by_id.values())
def compute_auth_events(

View File

@@ -392,7 +392,7 @@ class FederationHandler:
get_prev_content=False,
)
# We unset `filter_out_erased_senders` as we might otherwise get false
# We set `check_history_visibility_only` as we might otherwise get false
# positives from users having been erased.
filtered_extremities = await filter_events_for_server(
self._storage_controllers,
@@ -400,8 +400,7 @@ class FederationHandler:
self.server_name,
events_to_check,
redact=False,
filter_out_erased_senders=False,
filter_out_remote_partial_state_events=False,
check_history_visibility_only=True,
)
if filtered_extremities:
extremities_to_request.append(bp.event_id)
@@ -1332,13 +1331,7 @@ class FederationHandler:
)
events = await filter_events_for_server(
self._storage_controllers,
origin,
self.server_name,
events,
redact=True,
filter_out_erased_senders=True,
filter_out_remote_partial_state_events=True,
self._storage_controllers, origin, self.server_name, events
)
return events
@@ -1369,13 +1362,7 @@ class FederationHandler:
await self._event_auth_handler.assert_host_in_room(event.room_id, origin)
events = await filter_events_for_server(
self._storage_controllers,
origin,
self.server_name,
[event],
redact=True,
filter_out_erased_senders=True,
filter_out_remote_partial_state_events=True,
self._storage_controllers, origin, self.server_name, [event]
)
event = events[0]
return event
@@ -1403,13 +1390,7 @@ class FederationHandler:
)
missing_events = await filter_events_for_server(
self._storage_controllers,
origin,
self.server_name,
missing_events,
redact=True,
filter_out_erased_senders=True,
filter_out_remote_partial_state_events=True,
self._storage_controllers, origin, self.server_name, missing_events
)
return missing_events

View File

@@ -1123,7 +1123,9 @@ class RoomCreationHandler:
event_dict,
prev_event_ids=prev_event,
depth=depth,
state_map=state_map,
# Take a copy to ensure each event gets a unique copy of
# state_map since it is modified below.
state_map=dict(state_map),
for_batch=for_batch,
)

View File

@@ -39,56 +39,9 @@ from twisted.web.resource import Resource
from synapse.api import errors
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.events.presence_router import (
GET_INTERESTED_USERS_CALLBACK,
GET_USERS_FOR_STATES_CALLBACK,
PresenceRouter,
)
from synapse.events.spamcheck import (
CHECK_EVENT_FOR_SPAM_CALLBACK,
CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK,
CHECK_REGISTRATION_FOR_SPAM_CALLBACK,
CHECK_USERNAME_FOR_SPAM_CALLBACK,
SHOULD_DROP_FEDERATED_EVENT_CALLBACK,
USER_MAY_CREATE_ROOM_ALIAS_CALLBACK,
USER_MAY_CREATE_ROOM_CALLBACK,
USER_MAY_INVITE_CALLBACK,
USER_MAY_JOIN_ROOM_CALLBACK,
USER_MAY_PUBLISH_ROOM_CALLBACK,
USER_MAY_SEND_3PID_INVITE_CALLBACK,
SpamChecker,
)
from synapse.events.third_party_rules import (
CHECK_CAN_DEACTIVATE_USER_CALLBACK,
CHECK_CAN_SHUTDOWN_ROOM_CALLBACK,
CHECK_EVENT_ALLOWED_CALLBACK,
CHECK_THREEPID_CAN_BE_INVITED_CALLBACK,
CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK,
ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK,
ON_CREATE_ROOM_CALLBACK,
ON_NEW_EVENT_CALLBACK,
ON_PROFILE_UPDATE_CALLBACK,
ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK,
ON_THREEPID_BIND_CALLBACK,
ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK,
)
from synapse.handlers.account_data import ON_ACCOUNT_DATA_UPDATED_CALLBACK
from synapse.handlers.account_validity import (
IS_USER_EXPIRED_CALLBACK,
ON_LEGACY_ADMIN_REQUEST,
ON_LEGACY_RENEW_CALLBACK,
ON_LEGACY_SEND_MAIL_CALLBACK,
ON_USER_REGISTRATION_CALLBACK,
)
from synapse.handlers.auth import (
CHECK_3PID_AUTH_CALLBACK,
CHECK_AUTH_CALLBACK,
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK,
GET_USERNAME_FOR_REGISTRATION_CALLBACK,
IS_3PID_ALLOWED_CALLBACK,
ON_LOGGED_OUT_CALLBACK,
AuthHandler,
)
from synapse.events.presence_router import PresenceRouter
from synapse.events.spamcheck import SpamChecker
from synapse.handlers.auth import AuthHandler
from synapse.handlers.device import DeviceHandler
from synapse.handlers.push_rules import RuleSpec, check_actions
from synapse.http.client import SimpleHttpClient
@@ -105,13 +58,62 @@ from synapse.logging.context import (
run_in_background,
)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.client.login import LoginResponse
from synapse.storage import DataStore
from synapse.storage.background_updates import (
from synapse.module_api.callbacks.account_data_callbacks import (
ON_ACCOUNT_DATA_UPDATED_CALLBACK,
)
from synapse.module_api.callbacks.account_validity_callbacks import (
IS_USER_EXPIRED_CALLBACK,
ON_LEGACY_ADMIN_REQUEST,
ON_LEGACY_RENEW_CALLBACK,
ON_LEGACY_SEND_MAIL_CALLBACK,
ON_USER_REGISTRATION_CALLBACK,
)
from synapse.module_api.callbacks.background_updater_callbacks import (
DEFAULT_BATCH_SIZE_CALLBACK,
MIN_BATCH_SIZE_CALLBACK,
ON_UPDATE_CALLBACK,
)
from synapse.module_api.callbacks.password_auth_provider_callbacks import (
CHECK_3PID_AUTH_CALLBACK,
CHECK_AUTH_CALLBACK,
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK,
GET_USERNAME_FOR_REGISTRATION_CALLBACK,
IS_3PID_ALLOWED_CALLBACK,
ON_LOGGED_OUT_CALLBACK,
)
from synapse.module_api.callbacks.presence_router_callbacks import (
GET_INTERESTED_USERS_CALLBACK,
GET_USERS_FOR_STATES_CALLBACK,
)
from synapse.module_api.callbacks.spam_checker_callbacks import (
CHECK_EVENT_FOR_SPAM_CALLBACK,
CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK,
CHECK_REGISTRATION_FOR_SPAM_CALLBACK,
CHECK_USERNAME_FOR_SPAM_CALLBACK,
SHOULD_DROP_FEDERATED_EVENT_CALLBACK,
USER_MAY_CREATE_ROOM_ALIAS_CALLBACK,
USER_MAY_CREATE_ROOM_CALLBACK,
USER_MAY_INVITE_CALLBACK,
USER_MAY_JOIN_ROOM_CALLBACK,
USER_MAY_PUBLISH_ROOM_CALLBACK,
USER_MAY_SEND_3PID_INVITE_CALLBACK,
)
from synapse.module_api.callbacks.third_party_event_rules_callbacks import (
CHECK_CAN_DEACTIVATE_USER_CALLBACK,
CHECK_CAN_SHUTDOWN_ROOM_CALLBACK,
CHECK_EVENT_ALLOWED_CALLBACK,
CHECK_THREEPID_CAN_BE_INVITED_CALLBACK,
CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK,
ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK,
ON_CREATE_ROOM_CALLBACK,
ON_NEW_EVENT_CALLBACK,
ON_PROFILE_UPDATE_CALLBACK,
ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK,
ON_THREEPID_BIND_CALLBACK,
ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK,
)
from synapse.rest.client.login import LoginResponse
from synapse.storage import DataStore
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.roommember import ProfileInfo
from synapse.types import (
@@ -250,6 +252,7 @@ class ModuleApi:
self._push_rules_handler = hs.get_push_rules_handler()
self._device_handler = hs.get_device_handler()
self.custom_template_dir = hs.config.server.custom_template_directory
self._callbacks = hs.get_module_api_callbacks()
try:
app_name = self._hs.config.email.email_app_name
@@ -270,13 +273,6 @@ class ModuleApi:
self._public_room_list_manager = PublicRoomListManager(hs)
self._account_data_manager = AccountDataManager(hs)
self._spam_checker = hs.get_spam_checker()
self._account_validity_handler = hs.get_account_validity_handler()
self._third_party_event_rules = hs.get_third_party_event_rules()
self._password_auth_provider = hs.get_password_auth_provider()
self._presence_router = hs.get_presence_router()
self._account_data_handler = hs.get_account_data_handler()
#################################################################################
# The following methods should only be called during the module's initialisation.
@@ -305,7 +301,7 @@ class ModuleApi:
Added in Synapse v1.37.0.
"""
return self._spam_checker.register_callbacks(
return self._callbacks.spam_checker.register_callbacks(
check_event_for_spam=check_event_for_spam,
should_drop_federated_event=should_drop_federated_event,
user_may_join_room=user_may_join_room,
@@ -332,7 +328,7 @@ class ModuleApi:
Added in Synapse v1.39.0.
"""
return self._account_validity_handler.register_account_validity_callbacks(
return self._callbacks.account_validity.register_callbacks(
is_user_expired=is_user_expired,
on_user_registration=on_user_registration,
on_legacy_send_mail=on_legacy_send_mail,
@@ -370,7 +366,7 @@ class ModuleApi:
Added in Synapse v1.39.0.
"""
return self._third_party_event_rules.register_third_party_rules_callbacks(
return self._callbacks.third_party_event_rules.register_callbacks(
check_event_allowed=check_event_allowed,
on_create_room=on_create_room,
check_threepid_can_be_invited=check_threepid_can_be_invited,
@@ -395,7 +391,7 @@ class ModuleApi:
Added in Synapse v1.42.0.
"""
return self._presence_router.register_presence_router_callbacks(
return self._callbacks.presence_router.register_callbacks(
get_users_for_states=get_users_for_states,
get_interested_users=get_interested_users,
)
@@ -420,7 +416,7 @@ class ModuleApi:
Added in Synapse v1.46.0.
"""
return self._password_auth_provider.register_password_auth_provider_callbacks(
return self._callbacks.password_auth_provider.register_callbacks(
check_3pid_auth=check_3pid_auth,
on_logged_out=on_logged_out,
is_3pid_allowed=is_3pid_allowed,
@@ -441,12 +437,11 @@ class ModuleApi:
Added in Synapse v1.49.0.
"""
for db in self._hs.get_datastores().databases:
db.updates.register_update_controller_callbacks(
on_update=on_update,
default_batch_size=default_batch_size,
min_batch_size=min_batch_size,
)
self._callbacks.background_updater.register_callbacks(
on_update=on_update,
default_batch_size=default_batch_size,
min_batch_size=min_batch_size,
)
def register_account_data_callbacks(
self,
@@ -457,7 +452,7 @@ class ModuleApi:
Added in Synapse 1.57.0.
"""
return self._account_data_handler.register_module_callbacks(
return self._callbacks.account_data.register_callbacks(
on_account_data_updated=on_account_data_updated,
)

View File

@@ -0,0 +1,36 @@
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .account_data_callbacks import AccountDataModuleApiCallbacks
from .account_validity_callbacks import AccountValidityModuleApiCallbacks
from .background_updater_callbacks import BackgroundUpdaterModuleApiCallbacks
from .password_auth_provider_callbacks import PasswordAuthProviderModuleApiCallbacks
from .presence_router_callbacks import PresenceRouterModuleApiCallbacks
from .spam_checker_callbacks import SpamCheckerModuleApiCallbacks
from .third_party_event_rules_callbacks import ThirdPartyEventRulesModuleApiCallbacks
__all__ = [
"ModuleApiCallbacks",
]
class ModuleApiCallbacks:
def __init__(self) -> None:
self.account_data = AccountDataModuleApiCallbacks()
self.account_validity = AccountValidityModuleApiCallbacks()
self.background_updater = BackgroundUpdaterModuleApiCallbacks()
self.password_auth_provider = PasswordAuthProviderModuleApiCallbacks()
self.presence_router = PresenceRouterModuleApiCallbacks()
self.spam_checker = SpamCheckerModuleApiCallbacks()
self.third_party_event_rules = ThirdPartyEventRulesModuleApiCallbacks()

View File

@@ -0,0 +1,35 @@
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2021, 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Awaitable, Callable, List, Optional
from synapse.types import JsonDict
ON_ACCOUNT_DATA_UPDATED_CALLBACK = Callable[
[str, Optional[str], str, JsonDict], Awaitable
]
class AccountDataModuleApiCallbacks:
def __init__(self) -> None:
self.on_account_data_updated_callbacks: List[
ON_ACCOUNT_DATA_UPDATED_CALLBACK
] = []
def register_callbacks(
self, on_account_data_updated: Optional[ON_ACCOUNT_DATA_UPDATED_CALLBACK] = None
) -> None:
"""Register callbacks from modules."""
if on_account_data_updated is not None:
self.on_account_data_updated_callbacks.append(on_account_data_updated)

View File

@@ -0,0 +1,93 @@
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Awaitable, Callable, List, Optional, Tuple
from twisted.web.http import Request
logger = logging.getLogger(__name__)
# Types for callbacks to be registered via the module api
IS_USER_EXPIRED_CALLBACK = Callable[[str], Awaitable[Optional[bool]]]
ON_USER_REGISTRATION_CALLBACK = Callable[[str], Awaitable]
# Temporary hooks to allow for a transition from `/_matrix/client` endpoints
# to `/_synapse/client/account_validity`. See `register_account_validity_callbacks`.
ON_LEGACY_SEND_MAIL_CALLBACK = Callable[[str], Awaitable]
ON_LEGACY_RENEW_CALLBACK = Callable[[str], Awaitable[Tuple[bool, bool, int]]]
ON_LEGACY_ADMIN_REQUEST = Callable[[Request], Awaitable]
class AccountValidityModuleApiCallbacks:
def __init__(self) -> None:
self.is_user_expired_callbacks: List[IS_USER_EXPIRED_CALLBACK] = []
self.on_user_registration_callbacks: List[ON_USER_REGISTRATION_CALLBACK] = []
self.on_legacy_send_mail_callback: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None
self.on_legacy_renew_callback: Optional[ON_LEGACY_RENEW_CALLBACK] = None
# The legacy admin requests callback isn't a protected attribute because we need
# to access it from the admin servlet, which is outside of this handler.
self.on_legacy_admin_request_callback: Optional[ON_LEGACY_ADMIN_REQUEST] = None
def register_callbacks(
self,
is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None,
on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None,
on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None,
on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None,
on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None,
) -> None:
"""Register callbacks from module for each hook."""
if is_user_expired is not None:
self.is_user_expired_callbacks.append(is_user_expired)
if on_user_registration is not None:
self.on_user_registration_callbacks.append(on_user_registration)
# The builtin account validity feature exposes 3 endpoints (send_mail, renew, and
# an admin one). As part of moving the feature into a module, we need to change
# the path from /_matrix/client/unstable/account_validity/... to
# /_synapse/client/account_validity, because:
#
# * the feature isn't part of the Matrix spec thus shouldn't live under /_matrix
# * the way we register servlets means that modules can't register resources
# under /_matrix/client
#
# We need to allow for a transition period between the old and new endpoints
# in order to allow for clients to update (and for emails to be processed).
#
# Once the email-account-validity module is loaded, it will take control of account
# validity by moving the rows from our `account_validity` table into its own table.
#
# Therefore, we need to allow modules (in practice just the one implementing the
# email-based account validity) to temporarily hook into the legacy endpoints so we
# can route the traffic coming into the old endpoints into the module, which is
# why we have the following three temporary hooks.
if on_legacy_send_mail is not None:
if self.on_legacy_send_mail_callback is not None:
raise RuntimeError("Tried to register on_legacy_send_mail twice")
self.on_legacy_send_mail_callback = on_legacy_send_mail
if on_legacy_renew is not None:
if self.on_legacy_renew_callback is not None:
raise RuntimeError("Tried to register on_legacy_renew twice")
self.on_legacy_renew_callback = on_legacy_renew
if on_legacy_admin_request is not None:
if self.on_legacy_admin_request_callback is not None:
raise RuntimeError("Tried to register on_legacy_admin_request twice")
self.on_legacy_admin_request_callback = on_legacy_admin_request

View File

@@ -0,0 +1,54 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import AsyncContextManager, Awaitable, Callable, Optional
logger = logging.getLogger(__name__)
ON_UPDATE_CALLBACK = Callable[[str, str, bool], AsyncContextManager[int]]
DEFAULT_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
MIN_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
class BackgroundUpdaterModuleApiCallbacks:
def __init__(self) -> None:
self.on_update_callback: Optional[ON_UPDATE_CALLBACK] = None
self.default_batch_size_callback: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None
self.min_batch_size_callback: Optional[MIN_BATCH_SIZE_CALLBACK] = None
def register_callbacks(
self,
on_update: ON_UPDATE_CALLBACK,
default_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
min_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
) -> None:
"""Register callbacks from a module for each hook."""
if self.on_update_callback is not None:
logger.warning(
"More than one module tried to register callbacks for controlling"
" background updates. Only the callbacks registered by the first module"
" (in order of appearance in Synapse's configuration file) that tried to"
" do so will be called."
)
return
self.on_update_callback = on_update
if default_batch_size is not None:
self.default_batch_size_callback = default_batch_size
if min_batch_size is not None:
self.min_batch_size_callback = min_batch_size

View File

@@ -0,0 +1,138 @@
# Copyright 2014 - 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
# Copyright 2019 - 2020, 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Tuple
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.module_api import LoginResponse
logger = logging.getLogger(__name__)
CHECK_3PID_AUTH_CALLBACK = Callable[
[str, str, str],
Awaitable[
Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
],
]
ON_LOGGED_OUT_CALLBACK = Callable[[str, Optional[str], str], Awaitable]
CHECK_AUTH_CALLBACK = Callable[
[str, str, JsonDict],
Awaitable[
Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
],
]
GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[
[JsonDict, JsonDict],
Awaitable[Optional[str]],
]
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK = Callable[
[JsonDict, JsonDict],
Awaitable[Optional[str]],
]
IS_3PID_ALLOWED_CALLBACK = Callable[[str, str, bool], Awaitable[bool]]
class PasswordAuthProviderModuleApiCallbacks:
def __init__(self) -> None:
# Mapping from login type to login parameters
self.supported_login_types: Dict[str, Tuple[str, ...]] = {}
self.check_3pid_auth_callbacks: List[CHECK_3PID_AUTH_CALLBACK] = []
self.on_logged_out_callbacks: List[ON_LOGGED_OUT_CALLBACK] = []
self.get_username_for_registration_callbacks: List[
GET_USERNAME_FOR_REGISTRATION_CALLBACK
] = []
self.get_displayname_for_registration_callbacks: List[
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
] = []
self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = []
# Mapping from login type to auth checker callbacks
self.auth_checker_callbacks: Dict[str, List[CHECK_AUTH_CALLBACK]] = {}
def register_callbacks(
self,
check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None,
on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None,
is_3pid_allowed: Optional[IS_3PID_ALLOWED_CALLBACK] = None,
auth_checkers: Optional[
Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK]
] = None,
get_username_for_registration: Optional[
GET_USERNAME_FOR_REGISTRATION_CALLBACK
] = None,
get_displayname_for_registration: Optional[
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
] = None,
) -> None:
# Register check_3pid_auth callback
if check_3pid_auth is not None:
self.check_3pid_auth_callbacks.append(check_3pid_auth)
# register on_logged_out callback
if on_logged_out is not None:
self.on_logged_out_callbacks.append(on_logged_out)
if auth_checkers is not None:
# register a new supported login_type
# Iterate through all of the types being registered
for (login_type, fields), callback in auth_checkers.items():
# Note: fields may be empty here. This would allow a modules auth checker to
# be called with just 'login_type' and no password or other secrets
# Need to check that all the field names are strings or may get nasty errors later
for f in fields:
if not isinstance(f, str):
raise RuntimeError(
"A module tried to register support for login type: %s with parameters %s"
" but all parameter names must be strings"
% (login_type, fields)
)
# 2 modules supporting the same login type must expect the same fields
# e.g. 1 can't expect "pass" if the other expects "password"
# so throw an exception if that happens
if login_type not in self.supported_login_types.get(login_type, []):
self.supported_login_types[login_type] = fields
else:
fields_currently_supported = self.supported_login_types.get(
login_type
)
if fields_currently_supported != fields:
raise RuntimeError(
"A module tried to register support for login type: %s with parameters %s"
" but another module had already registered support for that type with parameters %s"
% (login_type, fields, fields_currently_supported)
)
# Add the new method to the list of auth_checker_callbacks for this login type
self.auth_checker_callbacks.setdefault(login_type, []).append(callback)
if get_username_for_registration is not None:
self.get_username_for_registration_callbacks.append(
get_username_for_registration,
)
if get_displayname_for_registration is not None:
self.get_displayname_for_registration_callbacks.append(
get_displayname_for_registration,
)
if is_3pid_allowed is not None:
self.is_3pid_allowed_callbacks.append(is_3pid_allowed)

View File

@@ -0,0 +1,122 @@
# Copyright 2021, 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
Iterable,
List,
Optional,
Set,
TypeVar,
Union,
)
from typing_extensions import ParamSpec
from synapse.api.presence import UserPresenceState
from synapse.util.async_helpers import maybe_awaitable
if TYPE_CHECKING:
from synapse.server import HomeServer
GET_USERS_FOR_STATES_CALLBACK = Callable[
[Iterable[UserPresenceState]], Awaitable[Dict[str, Set[UserPresenceState]]]
]
# This must either return a set of strings or the constant PresenceRouter.ALL_USERS.
GET_INTERESTED_USERS_CALLBACK = Callable[[str], Awaitable[Union[Set[str], str]]]
P = ParamSpec("P")
R = TypeVar("R")
def load_legacy_presence_router(hs: "HomeServer") -> None:
"""Wrapper that loads a presence router module configured using the old
configuration, and registers the hooks they implement.
"""
if hs.config.server.presence_router_module_class is None:
return
module = hs.config.server.presence_router_module_class
config = hs.config.server.presence_router_config
api = hs.get_module_api()
presence_router = module(config=config, module_api=api)
# The known hooks. If a module implements a method which name appears in this set,
# we'll want to register it.
presence_router_methods = {
"get_users_for_states",
"get_interested_users",
}
# All methods that the module provides should be async, but this wasn't enforced
# in the old module system, so we wrap them if needed
def async_wrapper(
f: Optional[Callable[P, R]]
) -> Optional[Callable[P, Awaitable[R]]]:
# f might be None if the callback isn't implemented by the module. In this
# case we don't want to register a callback at all so we return None.
if f is None:
return None
def run(*args: P.args, **kwargs: P.kwargs) -> Awaitable[R]:
# Assertion required because mypy can't prove we won't change `f`
# back to `None`. See
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
assert f is not None
return maybe_awaitable(f(*args, **kwargs))
return run
# Register the hooks through the module API.
hooks: Dict[str, Optional[Callable[..., Any]]] = {
hook: async_wrapper(getattr(presence_router, hook, None))
for hook in presence_router_methods
}
api.register_presence_router_callbacks(**hooks)
class PresenceRouterModuleApiCallbacks:
def __init__(self) -> None:
# Initially there are no callbacks
self.get_users_for_states_callbacks: List[GET_USERS_FOR_STATES_CALLBACK] = []
self.get_interested_users_callbacks: List[GET_INTERESTED_USERS_CALLBACK] = []
def register_callbacks(
self,
get_users_for_states: Optional[GET_USERS_FOR_STATES_CALLBACK] = None,
get_interested_users: Optional[GET_INTERESTED_USERS_CALLBACK] = None,
) -> None:
# PresenceRouter modules are required to implement both of these methods
# or neither of them as they are assumed to act in a complementary manner
paired_methods = [get_users_for_states, get_interested_users]
if paired_methods.count(None) == 1:
raise RuntimeError(
"PresenceRouter modules must register neither or both of the paired callbacks: "
"[get_users_for_states, get_interested_users]"
)
# Append the methods provided to the lists of callbacks
if get_users_for_states is not None:
self.get_users_for_states_callbacks.append(get_users_for_states)
if get_interested_users is not None:
self.get_interested_users_callbacks.append(get_interested_users)

View File

@@ -0,0 +1,373 @@
# Copyright 2017 New Vector Ltd
# Copyright 2019, 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import logging
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Collection,
List,
Optional,
Tuple,
Union,
)
# `Literal` appears with Python 3.8.
from typing_extensions import Literal
import synapse
from synapse.api.errors import Codes
from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.media_storage import ReadableFileWrapper
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import JsonDict, RoomAlias, UserProfile
from synapse.util.async_helpers import maybe_awaitable
if TYPE_CHECKING:
import synapse.events
import synapse.server
logger = logging.getLogger(__name__)
CHECK_EVENT_FOR_SPAM_CALLBACK = Callable[
["synapse.events.EventBase"],
Awaitable[
Union[
str,
Codes,
# Highly experimental, not officially part of the spamchecker API, may
# disappear without warning depending on the results of ongoing
# experiments.
# Use this to return additional information as part of an error.
Tuple[Codes, JsonDict],
# Deprecated
bool,
]
],
]
SHOULD_DROP_FEDERATED_EVENT_CALLBACK = Callable[
["synapse.events.EventBase"],
Awaitable[Union[bool, str]],
]
USER_MAY_JOIN_ROOM_CALLBACK = Callable[
[str, str, bool],
Awaitable[
Union[
Literal["NOT_SPAM"],
Codes,
# Highly experimental, not officially part of the spamchecker API, may
# disappear without warning depending on the results of ongoing
# experiments.
# Use this to return additional information as part of an error.
Tuple[Codes, JsonDict],
# Deprecated
bool,
]
],
]
USER_MAY_INVITE_CALLBACK = Callable[
[str, str, str],
Awaitable[
Union[
Literal["NOT_SPAM"],
Codes,
# Highly experimental, not officially part of the spamchecker API, may
# disappear without warning depending on the results of ongoing
# experiments.
# Use this to return additional information as part of an error.
Tuple[Codes, JsonDict],
# Deprecated
bool,
]
],
]
USER_MAY_SEND_3PID_INVITE_CALLBACK = Callable[
[str, str, str, str],
Awaitable[
Union[
Literal["NOT_SPAM"],
Codes,
# Highly experimental, not officially part of the spamchecker API, may
# disappear without warning depending on the results of ongoing
# experiments.
# Use this to return additional information as part of an error.
Tuple[Codes, JsonDict],
# Deprecated
bool,
]
],
]
USER_MAY_CREATE_ROOM_CALLBACK = Callable[
[str],
Awaitable[
Union[
Literal["NOT_SPAM"],
Codes,
# Highly experimental, not officially part of the spamchecker API, may
# disappear without warning depending on the results of ongoing
# experiments.
# Use this to return additional information as part of an error.
Tuple[Codes, JsonDict],
# Deprecated
bool,
]
],
]
USER_MAY_CREATE_ROOM_ALIAS_CALLBACK = Callable[
[str, RoomAlias],
Awaitable[
Union[
Literal["NOT_SPAM"],
Codes,
# Highly experimental, not officially part of the spamchecker API, may
# disappear without warning depending on the results of ongoing
# experiments.
# Use this to return additional information as part of an error.
Tuple[Codes, JsonDict],
# Deprecated
bool,
]
],
]
USER_MAY_PUBLISH_ROOM_CALLBACK = Callable[
[str, str],
Awaitable[
Union[
Literal["NOT_SPAM"],
Codes,
# Highly experimental, not officially part of the spamchecker API, may
# disappear without warning depending on the results of ongoing
# experiments.
# Use this to return additional information as part of an error.
Tuple[Codes, JsonDict],
# Deprecated
bool,
]
],
]
CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[UserProfile], Awaitable[bool]]
LEGACY_CHECK_REGISTRATION_FOR_SPAM_CALLBACK = Callable[
[
Optional[dict],
Optional[str],
Collection[Tuple[str, str]],
],
Awaitable[RegistrationBehaviour],
]
CHECK_REGISTRATION_FOR_SPAM_CALLBACK = Callable[
[
Optional[dict],
Optional[str],
Collection[Tuple[str, str]],
Optional[str],
],
Awaitable[RegistrationBehaviour],
]
CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK = Callable[
[ReadableFileWrapper, FileInfo],
Awaitable[
Union[
Literal["NOT_SPAM"],
Codes,
# Highly experimental, not officially part of the spamchecker API, may
# disappear without warning depending on the results of ongoing
# experiments.
# Use this to return additional information as part of an error.
Tuple[Codes, JsonDict],
# Deprecated
bool,
]
],
]
def load_legacy_spam_checkers(hs: "synapse.server.HomeServer") -> None:
"""Wrapper that loads spam checkers configured using the old configuration, and
registers the spam checker hooks they implement.
"""
spam_checkers: List[Any] = []
api = hs.get_module_api()
for module, config in hs.config.spamchecker.spam_checkers:
# Older spam checkers don't accept the `api` argument, so we
# try and detect support.
spam_args = inspect.getfullargspec(module)
if "api" in spam_args.args:
spam_checkers.append(module(config=config, api=api))
else:
spam_checkers.append(module(config=config))
# The known spam checker hooks. If a spam checker module implements a method
# which name appears in this set, we'll want to register it.
spam_checker_methods = {
"check_event_for_spam",
"user_may_invite",
"user_may_create_room",
"user_may_create_room_alias",
"user_may_publish_room",
"check_username_for_spam",
"check_registration_for_spam",
"check_media_file_for_spam",
}
for spam_checker in spam_checkers:
# Methods on legacy spam checkers might not be async, so we wrap them around a
# wrapper that will call maybe_awaitable on the result.
def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
# f might be None if the callback isn't implemented by the module. In this
# case we don't want to register a callback at all so we return None.
if f is None:
return None
wrapped_func = f
if f.__name__ == "check_registration_for_spam":
checker_args = inspect.signature(f)
if len(checker_args.parameters) == 3:
# Backwards compatibility; some modules might implement a hook that
# doesn't expect a 4th argument. In this case, wrap it in a function
# that gives it only 3 arguments and drops the auth_provider_id on
# the floor.
def wrapper(
email_threepid: Optional[dict],
username: Optional[str],
request_info: Collection[Tuple[str, str]],
auth_provider_id: Optional[str],
) -> Union[Awaitable[RegistrationBehaviour], RegistrationBehaviour]:
# Assertion required because mypy can't prove we won't
# change `f` back to `None`. See
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
assert f is not None
return f(
email_threepid,
username,
request_info,
)
wrapped_func = wrapper
elif len(checker_args.parameters) != 4:
raise RuntimeError(
"Bad signature for callback check_registration_for_spam",
)
def run(*args: Any, **kwargs: Any) -> Awaitable:
# Assertion required because mypy can't prove we won't change `f`
# back to `None`. See
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
assert wrapped_func is not None
return maybe_awaitable(wrapped_func(*args, **kwargs))
return run
# Register the hooks through the module API.
hooks = {
hook: async_wrapper(getattr(spam_checker, hook, None))
for hook in spam_checker_methods
}
api.register_spam_checker_callbacks(**hooks)
class SpamCheckerModuleApiCallbacks:
def __init__(self) -> None:
self.check_event_for_spam_callbacks: List[CHECK_EVENT_FOR_SPAM_CALLBACK] = []
self.should_drop_federated_event_callbacks: List[
SHOULD_DROP_FEDERATED_EVENT_CALLBACK
] = []
self.user_may_join_room_callbacks: List[USER_MAY_JOIN_ROOM_CALLBACK] = []
self.user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = []
self.user_may_send_3pid_invite_callbacks: List[
USER_MAY_SEND_3PID_INVITE_CALLBACK
] = []
self.user_may_create_room_callbacks: List[USER_MAY_CREATE_ROOM_CALLBACK] = []
self.user_may_create_room_alias_callbacks: List[
USER_MAY_CREATE_ROOM_ALIAS_CALLBACK
] = []
self.user_may_publish_room_callbacks: List[USER_MAY_PUBLISH_ROOM_CALLBACK] = []
self.check_username_for_spam_callbacks: List[
CHECK_USERNAME_FOR_SPAM_CALLBACK
] = []
self.check_registration_for_spam_callbacks: List[
CHECK_REGISTRATION_FOR_SPAM_CALLBACK
] = []
self.check_media_file_for_spam_callbacks: List[
CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK
] = []
def register_callbacks(
self,
check_event_for_spam: Optional[CHECK_EVENT_FOR_SPAM_CALLBACK] = None,
should_drop_federated_event: Optional[
SHOULD_DROP_FEDERATED_EVENT_CALLBACK
] = None,
user_may_join_room: Optional[USER_MAY_JOIN_ROOM_CALLBACK] = None,
user_may_invite: Optional[USER_MAY_INVITE_CALLBACK] = None,
user_may_send_3pid_invite: Optional[USER_MAY_SEND_3PID_INVITE_CALLBACK] = None,
user_may_create_room: Optional[USER_MAY_CREATE_ROOM_CALLBACK] = None,
user_may_create_room_alias: Optional[
USER_MAY_CREATE_ROOM_ALIAS_CALLBACK
] = None,
user_may_publish_room: Optional[USER_MAY_PUBLISH_ROOM_CALLBACK] = None,
check_username_for_spam: Optional[CHECK_USERNAME_FOR_SPAM_CALLBACK] = None,
check_registration_for_spam: Optional[
CHECK_REGISTRATION_FOR_SPAM_CALLBACK
] = None,
check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
) -> None:
"""Register callbacks from module for each hook."""
if check_event_for_spam is not None:
self.check_event_for_spam_callbacks.append(check_event_for_spam)
if should_drop_federated_event is not None:
self.should_drop_federated_event_callbacks.append(
should_drop_federated_event
)
if user_may_join_room is not None:
self.user_may_join_room_callbacks.append(user_may_join_room)
if user_may_invite is not None:
self.user_may_invite_callbacks.append(user_may_invite)
if user_may_send_3pid_invite is not None:
self.user_may_send_3pid_invite_callbacks.append(
user_may_send_3pid_invite,
)
if user_may_create_room is not None:
self.user_may_create_room_callbacks.append(user_may_create_room)
if user_may_create_room_alias is not None:
self.user_may_create_room_alias_callbacks.append(
user_may_create_room_alias,
)
if user_may_publish_room is not None:
self.user_may_publish_room_callbacks.append(user_may_publish_room)
if check_username_for_spam is not None:
self.check_username_for_spam_callbacks.append(check_username_for_spam)
if check_registration_for_spam is not None:
self.check_registration_for_spam_callbacks.append(
check_registration_for_spam,
)
if check_media_file_for_spam is not None:
self.check_media_file_for_spam_callbacks.append(check_media_file_for_spam)

View File

@@ -0,0 +1,238 @@
# Copyright 2019, 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Optional, Tuple
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.storage.roommember import ProfileInfo
from synapse.types import Requester, StateMap
from synapse.util.async_helpers import maybe_awaitable
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
CHECK_EVENT_ALLOWED_CALLBACK = Callable[
[EventBase, StateMap[EventBase]], Awaitable[Tuple[bool, Optional[dict]]]
]
ON_CREATE_ROOM_CALLBACK = Callable[[Requester, dict, bool], Awaitable]
CHECK_THREEPID_CAN_BE_INVITED_CALLBACK = Callable[
[str, str, StateMap[EventBase]], Awaitable[bool]
]
CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[
[str, StateMap[EventBase], str], Awaitable[bool]
]
ON_NEW_EVENT_CALLBACK = Callable[[EventBase, StateMap[EventBase]], Awaitable]
CHECK_CAN_SHUTDOWN_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]]
CHECK_CAN_DEACTIVATE_USER_CALLBACK = Callable[[str, bool], Awaitable[bool]]
ON_PROFILE_UPDATE_CALLBACK = Callable[[str, ProfileInfo, bool, bool], Awaitable]
ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK = Callable[[str, bool, bool], Awaitable]
ON_THREEPID_BIND_CALLBACK = Callable[[str, str, str], Awaitable]
ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK = Callable[[str, str, str], Awaitable]
ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK = Callable[[str, str, str], Awaitable]
def load_legacy_third_party_event_rules(hs: "HomeServer") -> None:
"""Wrapper that loads a third party event rules module configured using the old
configuration, and registers the hooks they implement.
"""
if hs.config.thirdpartyrules.third_party_event_rules is None:
return
module, config = hs.config.thirdpartyrules.third_party_event_rules
api = hs.get_module_api()
third_party_rules = module(config=config, module_api=api)
# The known hooks. If a module implements a method which name appears in this set,
# we'll want to register it.
third_party_event_rules_methods = {
"check_event_allowed",
"on_create_room",
"check_threepid_can_be_invited",
"check_visibility_can_be_modified",
}
def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
# f might be None if the callback isn't implemented by the module. In this
# case we don't want to register a callback at all so we return None.
if f is None:
return None
# We return a separate wrapper for these methods because, in order to wrap them
# correctly, we need to await its result. Therefore it doesn't make a lot of
# sense to make it go through the run() wrapper.
if f.__name__ == "check_event_allowed":
# We need to wrap check_event_allowed because its old form would return either
# a boolean or a dict, but now we want to return the dict separately from the
# boolean.
async def wrap_check_event_allowed(
event: EventBase,
state_events: StateMap[EventBase],
) -> Tuple[bool, Optional[dict]]:
# Assertion required because mypy can't prove we won't change
# `f` back to `None`. See
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
assert f is not None
res = await f(event, state_events)
if isinstance(res, dict):
return True, res
else:
return res, None
return wrap_check_event_allowed
if f.__name__ == "on_create_room":
# We need to wrap on_create_room because its old form would return a boolean
# if the room creation is denied, but now we just want it to raise an
# exception.
async def wrap_on_create_room(
requester: Requester, config: dict, is_requester_admin: bool
) -> None:
# Assertion required because mypy can't prove we won't change
# `f` back to `None`. See
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
assert f is not None
res = await f(requester, config, is_requester_admin)
if res is False:
raise SynapseError(
403,
"Room creation forbidden with these parameters",
)
return wrap_on_create_room
def run(*args: Any, **kwargs: Any) -> Awaitable:
# Assertion required because mypy can't prove we won't change `f`
# back to `None`. See
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
assert f is not None
return maybe_awaitable(f(*args, **kwargs))
return run
# Register the hooks through the module API.
hooks = {
hook: async_wrapper(getattr(third_party_rules, hook, None))
for hook in third_party_event_rules_methods
}
api.register_third_party_rules_callbacks(**hooks)
class ThirdPartyEventRulesModuleApiCallbacks:
def __init__(self) -> None:
self.check_event_allowed_callbacks: List[CHECK_EVENT_ALLOWED_CALLBACK] = []
self.on_create_room_callbacks: List[ON_CREATE_ROOM_CALLBACK] = []
self.check_threepid_can_be_invited_callbacks: List[
CHECK_THREEPID_CAN_BE_INVITED_CALLBACK
] = []
self.check_visibility_can_be_modified_callbacks: List[
CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
] = []
self.on_new_event_callbacks: List[ON_NEW_EVENT_CALLBACK] = []
self.check_can_shutdown_room_callbacks: List[
CHECK_CAN_SHUTDOWN_ROOM_CALLBACK
] = []
self.check_can_deactivate_user_callbacks: List[
CHECK_CAN_DEACTIVATE_USER_CALLBACK
] = []
self.on_profile_update_callbacks: List[ON_PROFILE_UPDATE_CALLBACK] = []
self.on_user_deactivation_status_changed_callbacks: List[
ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK
] = []
self.on_threepid_bind_callbacks: List[ON_THREEPID_BIND_CALLBACK] = []
self.on_add_user_third_party_identifier_callbacks: List[
ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK
] = []
self.on_remove_user_third_party_identifier_callbacks: List[
ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK
] = []
def register_callbacks(
self,
check_event_allowed: Optional[CHECK_EVENT_ALLOWED_CALLBACK] = None,
on_create_room: Optional[ON_CREATE_ROOM_CALLBACK] = None,
check_threepid_can_be_invited: Optional[
CHECK_THREEPID_CAN_BE_INVITED_CALLBACK
] = None,
check_visibility_can_be_modified: Optional[
CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
] = None,
on_new_event: Optional[ON_NEW_EVENT_CALLBACK] = None,
check_can_shutdown_room: Optional[CHECK_CAN_SHUTDOWN_ROOM_CALLBACK] = None,
check_can_deactivate_user: Optional[CHECK_CAN_DEACTIVATE_USER_CALLBACK] = None,
on_profile_update: Optional[ON_PROFILE_UPDATE_CALLBACK] = None,
on_user_deactivation_status_changed: Optional[
ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK
] = None,
on_threepid_bind: Optional[ON_THREEPID_BIND_CALLBACK] = None,
on_add_user_third_party_identifier: Optional[
ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK
] = None,
on_remove_user_third_party_identifier: Optional[
ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK
] = None,
) -> None:
"""Register callbacks from modules for each hook."""
if check_event_allowed is not None:
self.check_event_allowed_callbacks.append(check_event_allowed)
if on_create_room is not None:
self.on_create_room_callbacks.append(on_create_room)
if check_threepid_can_be_invited is not None:
self.check_threepid_can_be_invited_callbacks.append(
check_threepid_can_be_invited,
)
if check_visibility_can_be_modified is not None:
self.check_visibility_can_be_modified_callbacks.append(
check_visibility_can_be_modified,
)
if on_new_event is not None:
self.on_new_event_callbacks.append(on_new_event)
if check_can_shutdown_room is not None:
self.check_can_shutdown_room_callbacks.append(check_can_shutdown_room)
if check_can_deactivate_user is not None:
self.check_can_deactivate_user_callbacks.append(check_can_deactivate_user)
if on_profile_update is not None:
self.on_profile_update_callbacks.append(on_profile_update)
if on_user_deactivation_status_changed is not None:
self.on_user_deactivation_status_changed_callbacks.append(
on_user_deactivation_status_changed,
)
if on_threepid_bind is not None:
self.on_threepid_bind_callbacks.append(on_threepid_bind)
if on_add_user_third_party_identifier is not None:
self.on_add_user_third_party_identifier_callbacks.append(
on_add_user_third_party_identifier
)
if on_remove_user_third_party_identifier is not None:
self.on_remove_user_third_party_identifier_callbacks.append(
on_remove_user_third_party_identifier
)

View File

@@ -273,10 +273,7 @@ class BulkPushRuleEvaluator:
related_event_id, allow_none=True
)
if related_event is not None:
related_events[relation_type] = _flatten_dict(
related_event,
msc3873_escape_event_match_key=self.hs.config.experimental.msc3873_escape_event_match_key,
)
related_events[relation_type] = _flatten_dict(related_event)
reply_event_id = (
event.content.get("m.relates_to", {})
@@ -291,10 +288,7 @@ class BulkPushRuleEvaluator:
)
if related_event is not None:
related_events["m.in_reply_to"] = _flatten_dict(
related_event,
msc3873_escape_event_match_key=self.hs.config.experimental.msc3873_escape_event_match_key,
)
related_events["m.in_reply_to"] = _flatten_dict(related_event)
# indicate that this is from a fallback relation.
if relation_type == "m.thread" and event.content.get(
@@ -401,10 +395,7 @@ class BulkPushRuleEvaluator:
)
evaluator = PushRuleEvaluator(
_flatten_dict(
event,
msc3873_escape_event_match_key=self.hs.config.experimental.msc3873_escape_event_match_key,
),
_flatten_dict(event),
has_mentions,
room_member_count,
sender_power_level,
@@ -413,7 +404,6 @@ class BulkPushRuleEvaluator:
self._related_event_match_enabled,
event.room_version.msc3931_push_features,
self.hs.config.experimental.msc1767_enabled, # MSC3931 flag
self.hs.config.experimental.msc3966_exact_event_property_contains,
)
users = rules_by_user.keys()
@@ -495,8 +485,6 @@ def _flatten_dict(
d: Union[EventBase, Mapping[str, Any]],
prefix: Optional[List[str]] = None,
result: Optional[Dict[str, JsonValue]] = None,
*,
msc3873_escape_event_match_key: bool = False,
) -> Dict[str, JsonValue]:
"""
Given a JSON dictionary (or event) which might contain sub dictionaries,
@@ -525,11 +513,10 @@ def _flatten_dict(
if result is None:
result = {}
for key, value in d.items():
if msc3873_escape_event_match_key:
# Escape periods in the key with a backslash (and backslashes with an
# extra backslash). This is since a period is used as a separator between
# nested fields.
key = key.replace("\\", "\\\\").replace(".", "\\.")
# Escape periods in the key with a backslash (and backslashes with an
# extra backslash). This is since a period is used as a separator between
# nested fields.
key = key.replace("\\", "\\\\").replace(".", "\\.")
if _is_simple_value(value):
result[".".join(prefix + [key])] = value
@@ -537,12 +524,7 @@ def _flatten_dict(
result[".".join(prefix + [key])] = [v for v in value if _is_simple_value(v)]
elif isinstance(value, Mapping):
# do not set `room_version` due to recursion considerations below
_flatten_dict(
value,
prefix=(prefix + [key]),
result=result,
msc3873_escape_event_match_key=msc3873_escape_event_match_key,
)
_flatten_dict(value, prefix=(prefix + [key]), result=result)
# `room_version` should only ever be set when looking at the top level of an event
if (

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from http import HTTPStatus
from typing import TYPE_CHECKING, Awaitable, Optional, Tuple
from typing import TYPE_CHECKING, Optional, Tuple
from synapse.api.constants import EventTypes
from synapse.api.errors import NotFoundError, SynapseError
@@ -23,10 +23,10 @@ from synapse.http.servlet import (
parse_json_object_from_request,
)
from synapse.http.site import SynapseRequest
from synapse.rest.admin import assert_requester_is_admin
from synapse.rest.admin._base import admin_patterns
from synapse.logging.opentracing import set_tag
from synapse.rest.admin._base import admin_patterns, assert_user_is_admin
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.types import JsonDict, UserID
from synapse.types import JsonDict, Requester, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -70,10 +70,13 @@ class SendServerNoticeServlet(RestServlet):
self.__class__.__name__,
)
async def on_POST(
self, request: SynapseRequest, txn_id: Optional[str] = None
async def _do(
self,
request: SynapseRequest,
requester: Requester,
txn_id: Optional[str],
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
await assert_user_is_admin(self.auth, requester)
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ("user_id", "content"))
event_type = body.get("type", EventTypes.Message)
@@ -106,9 +109,18 @@ class SendServerNoticeServlet(RestServlet):
return HTTPStatus.OK, {"event_id": event.event_id}
def on_PUT(
async def on_POST(
self,
request: SynapseRequest,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
return await self._do(request, requester, None)
async def on_PUT(
self, request: SynapseRequest, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, txn_id
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
set_tag("txn_id", txn_id)
return await self.txns.fetch_or_execute_request(
request, requester, self._do, request, requester, txn_id
)

View File

@@ -683,19 +683,18 @@ class AccountValidityRenewServlet(RestServlet):
PATTERNS = admin_patterns("/account_validity/validity$")
def __init__(self, hs: "HomeServer"):
self.account_activity_handler = hs.get_account_validity_handler()
self.account_validity_handler = hs.get_account_validity_handler()
self.account_validity_module_callbacks = (
hs.get_module_api_callbacks().account_validity
)
self.auth = hs.get_auth()
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
if self.account_activity_handler.on_legacy_admin_request_callback:
expiration_ts = (
await (
self.account_activity_handler.on_legacy_admin_request_callback(
request
)
)
if self.account_validity_module_callbacks.on_legacy_admin_request_callback:
expiration_ts = await self.account_validity_module_callbacks.on_legacy_admin_request_callback(
request
)
else:
body = parse_json_object_from_request(request)
@@ -706,7 +705,7 @@ class AccountValidityRenewServlet(RestServlet):
"Missing property 'user_id' in the request body",
)
expiration_ts = await self.account_activity_handler.renew_account_for_user(
expiration_ts = await self.account_validity_handler.renew_account_for_user(
body["user_id"],
body.get("expiration_ts"),
not body.get("enable_renewal_emails", True),

View File

@@ -57,7 +57,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.client._base import client_patterns
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, StreamToken, ThirdPartyInstanceID, UserID
from synapse.types import JsonDict, Requester, StreamToken, ThirdPartyInstanceID, UserID
from synapse.types.state import StateFilter
from synapse.util import json_decoder
from synapse.util.cancellation import cancellable
@@ -151,15 +151,22 @@ class RoomCreateRestServlet(TransactionRestServlet):
PATTERNS = "/createRoom"
register_txn_path(self, PATTERNS, http_server)
def on_PUT(
async def on_PUT(
self, request: SynapseRequest, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(request, self.on_POST, request)
return await self.txns.fetch_or_execute_request(
request, requester, self._do, request, requester
)
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
return await self._do(request, requester)
async def _do(
self, request: SynapseRequest, requester: Requester
) -> Tuple[int, JsonDict]:
room_id, _, _ = await self._room_creation_handler.create_room(
requester, self.get_room_config(request)
)
@@ -172,9 +179,9 @@ class RoomCreateRestServlet(TransactionRestServlet):
# TODO: Needs unit testing for generic events
class RoomStateEventRestServlet(TransactionRestServlet):
class RoomStateEventRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
super().__init__()
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
self.message_handler = hs.get_message_handler()
@@ -324,16 +331,16 @@ class RoomSendEventRestServlet(TransactionRestServlet):
def register(self, http_server: HttpServer) -> None:
# /rooms/$roomid/send/$event_type[/$txn_id]
PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)"
register_txn_path(self, PATTERNS, http_server, with_get=True)
register_txn_path(self, PATTERNS, http_server)
async def on_POST(
async def _do(
self,
request: SynapseRequest,
requester: Requester,
room_id: str,
event_type: str,
txn_id: Optional[str] = None,
txn_id: Optional[str],
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
event_dict: JsonDict = {
@@ -362,18 +369,30 @@ class RoomSendEventRestServlet(TransactionRestServlet):
set_tag("event_id", event_id)
return 200, {"event_id": event_id}
def on_GET(
self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
) -> Tuple[int, str]:
return 200, "Not implemented"
async def on_POST(
self,
request: SynapseRequest,
room_id: str,
event_type: str,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
return await self._do(request, requester, room_id, event_type, None)
def on_PUT(
async def on_PUT(
self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, event_type, txn_id
return await self.txns.fetch_or_execute_request(
request,
requester,
self._do,
request,
requester,
room_id,
event_type,
txn_id,
)
@@ -389,14 +408,13 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
PATTERNS = "/join/(?P<room_identifier>[^/]*)"
register_txn_path(self, PATTERNS, http_server)
async def on_POST(
async def _do(
self,
request: SynapseRequest,
requester: Requester,
room_identifier: str,
txn_id: Optional[str] = None,
txn_id: Optional[str],
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request, allow_empty_body=True)
# twisted.web.server.Request.args is incorrectly defined as Optional[Any]
@@ -420,22 +438,31 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
return 200, {"room_id": room_id}
def on_PUT(
async def on_POST(
self,
request: SynapseRequest,
room_identifier: str,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
return await self._do(request, requester, room_identifier, None)
async def on_PUT(
self, request: SynapseRequest, room_identifier: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_identifier, txn_id
return await self.txns.fetch_or_execute_request(
request, requester, self._do, request, requester, room_identifier, txn_id
)
# TODO: Needs unit testing
class PublicRoomListRestServlet(TransactionRestServlet):
class PublicRoomListRestServlet(RestServlet):
PATTERNS = client_patterns("/publicRooms$", v1=True)
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
@@ -907,22 +934,25 @@ class RoomForgetRestServlet(TransactionRestServlet):
PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget"
register_txn_path(self, PATTERNS, http_server)
async def on_POST(
self, request: SynapseRequest, room_id: str, txn_id: Optional[str] = None
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=False)
async def _do(self, requester: Requester, room_id: str) -> Tuple[int, JsonDict]:
await self.room_member_handler.forget(user=requester.user, room_id=room_id)
return 200, {}
def on_PUT(
async def on_POST(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=False)
return await self._do(requester, room_id)
async def on_PUT(
self, request: SynapseRequest, room_id: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=False)
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, txn_id
return await self.txns.fetch_or_execute_request(
request, requester, self._do, requester, room_id
)
@@ -941,15 +971,14 @@ class RoomMembershipRestServlet(TransactionRestServlet):
)
register_txn_path(self, PATTERNS, http_server)
async def on_POST(
async def _do(
self,
request: SynapseRequest,
requester: Requester,
room_id: str,
membership_action: str,
txn_id: Optional[str] = None,
txn_id: Optional[str],
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
if requester.is_guest and membership_action not in {
Membership.JOIN,
Membership.LEAVE,
@@ -1014,13 +1043,30 @@ class RoomMembershipRestServlet(TransactionRestServlet):
return 200, return_value
def on_PUT(
async def on_POST(
self,
request: SynapseRequest,
room_id: str,
membership_action: str,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
return await self._do(request, requester, room_id, membership_action, None)
async def on_PUT(
self, request: SynapseRequest, room_id: str, membership_action: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, membership_action, txn_id
return await self.txns.fetch_or_execute_request(
request,
requester,
self._do,
request,
requester,
room_id,
membership_action,
txn_id,
)
@@ -1036,14 +1082,14 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)"
register_txn_path(self, PATTERNS, http_server)
async def on_POST(
async def _do(
self,
request: SynapseRequest,
requester: Requester,
room_id: str,
event_id: str,
txn_id: Optional[str] = None,
txn_id: Optional[str],
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
try:
@@ -1094,13 +1140,23 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
set_tag("event_id", event_id)
return 200, {"event_id": event_id}
def on_PUT(
async def on_POST(
self,
request: SynapseRequest,
room_id: str,
event_id: str,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
return await self._do(request, requester, room_id, event_id, None)
async def on_PUT(
self, request: SynapseRequest, room_id: str, event_id: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, event_id, txn_id
return await self.txns.fetch_or_execute_request(
request, requester, self._do, request, requester, room_id, event_id, txn_id
)
@@ -1224,7 +1280,6 @@ def register_txn_path(
servlet: RestServlet,
regex_string: str,
http_server: HttpServer,
with_get: bool = False,
) -> None:
"""Registers a transaction-based path.
@@ -1236,7 +1291,6 @@ def register_txn_path(
regex_string: The regex string to register. Must NOT have a
trailing $ as this string will be appended to.
http_server: The http_server to register paths with.
with_get: True to also register respective GET paths for the PUTs.
"""
on_POST = getattr(servlet, "on_POST", None)
on_PUT = getattr(servlet, "on_PUT", None)
@@ -1254,18 +1308,6 @@ def register_txn_path(
on_PUT,
servlet.__class__.__name__,
)
on_GET = getattr(servlet, "on_GET", None)
if with_get:
if on_GET is None:
raise RuntimeError(
"register_txn_path called with with_get = True, but no on_GET method exists"
)
http_server.register_paths(
"GET",
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
on_GET,
servlet.__class__.__name__,
)
class TimestampLookupRestServlet(RestServlet):

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Awaitable, Tuple
from typing import TYPE_CHECKING, Tuple
from synapse.http import servlet
from synapse.http.server import HttpServer
@@ -21,7 +21,7 @@ from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_r
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import set_tag
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.types import JsonDict
from synapse.types import JsonDict, Requester
from ._base import client_patterns
@@ -43,19 +43,26 @@ class SendToDeviceRestServlet(servlet.RestServlet):
self.txns = HttpTransactionCache(hs)
self.device_message_handler = hs.get_device_message_handler()
def on_PUT(
self, request: SynapseRequest, message_type: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
request, self._put, request, message_type, txn_id
)
async def _put(
async def on_PUT(
self, request: SynapseRequest, message_type: str, txn_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
set_tag("txn_id", txn_id)
return await self.txns.fetch_or_execute_request(
request,
requester,
self._put,
request,
requester,
message_type,
)
async def _put(
self,
request: SynapseRequest,
requester: Requester,
message_type: str,
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
assert_params_in_dict(content, ("messages",))

View File

@@ -15,16 +15,16 @@
"""This module contains logic for storing HTTP PUT transactions. This is used
to ensure idempotency when performing PUTs using the REST API."""
import logging
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Tuple
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Hashable, Tuple
from typing_extensions import ParamSpec
from twisted.internet.defer import Deferred
from twisted.python.failure import Failure
from twisted.web.server import Request
from twisted.web.iweb import IRequest
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.types import JsonDict
from synapse.types import JsonDict, Requester
from synapse.util.async_helpers import ObservableDeferred
if TYPE_CHECKING:
@@ -41,53 +41,47 @@ P = ParamSpec("P")
class HttpTransactionCache:
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = self.hs.get_auth()
self.clock = self.hs.get_clock()
# $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp)
self.transactions: Dict[
str, Tuple[ObservableDeferred[Tuple[int, JsonDict]], int]
Hashable, Tuple[ObservableDeferred[Tuple[int, JsonDict]], int]
] = {}
# Try to clean entries every 30 mins. This means entries will exist
# for at *LEAST* 30 mins, and at *MOST* 60 mins.
self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS)
def _get_transaction_key(self, request: Request) -> str:
def _get_transaction_key(self, request: IRequest, requester: Requester) -> Hashable:
"""A helper function which returns a transaction key that can be used
with TransactionCache for idempotent requests.
Idempotency is based on the returned key being the same for separate
requests to the same endpoint. The key is formed from the HTTP request
path and the access_token for the requesting user.
path and attributes from the requester: the access_token_id for regular users,
the user ID for guest users, and the appservice ID for appservice users.
Args:
request: The incoming request. Must contain an access_token.
request: The incoming request.
requester: The requester doing the request.
Returns:
A transaction key
"""
assert request.path is not None
token = self.auth.get_access_token_from_request(request)
return request.path.decode("utf8") + "/" + token
path: str = request.path.decode("utf8")
if requester.is_guest:
assert requester.user is not None, "Guest requester must have a user ID set"
return (path, "guest", requester.user)
elif requester.app_service is not None:
return (path, "appservice", requester.app_service.id)
else:
assert (
requester.access_token_id is not None
), "Requester must have an access_token_id"
return (path, "user", requester.access_token_id)
def fetch_or_execute_request(
self,
request: Request,
fn: Callable[P, Awaitable[Tuple[int, JsonDict]]],
*args: P.args,
**kwargs: P.kwargs,
) -> Awaitable[Tuple[int, JsonDict]]:
"""A helper function for fetch_or_execute which extracts
a transaction key from the given request.
See:
fetch_or_execute
"""
return self.fetch_or_execute(
self._get_transaction_key(request), fn, *args, **kwargs
)
def fetch_or_execute(
self,
txn_key: str,
request: IRequest,
requester: Requester,
fn: Callable[P, Awaitable[Tuple[int, JsonDict]]],
*args: P.args,
**kwargs: P.kwargs,
@@ -96,14 +90,15 @@ class HttpTransactionCache:
to produce a response for this transaction.
Args:
txn_key: A key to ensure idempotency should fetch_or_execute be
called again at a later point in time.
request:
requester:
fn: A function which returns a tuple of (response_code, response_dict).
*args: Arguments to pass to fn.
**kwargs: Keyword arguments to pass to fn.
Returns:
Deferred which resolves to a tuple of (response_code, response_dict).
"""
txn_key = self._get_transaction_key(request, requester)
if txn_key in self.transactions:
observable = self.transactions[txn_key][0]
else:

View File

@@ -108,6 +108,7 @@ from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.media.media_repository import MediaRepository
from synapse.metrics.common_usage_metrics import CommonUsageMetricsManager
from synapse.module_api import ModuleApi
from synapse.module_api.callbacks import ModuleApiCallbacks
from synapse.notifier import Notifier, ReplicationNotifier
from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
from synapse.push.pusherpool import PusherPool
@@ -673,7 +674,7 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_password_auth_provider(self) -> PasswordAuthProvider:
return PasswordAuthProvider()
return PasswordAuthProvider(self)
@cache_in_self
def get_room_member_handler(self) -> RoomMemberHandler:
@@ -777,6 +778,10 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_module_api(self) -> ModuleApi:
return ModuleApi(self, self.get_auth_handler())
@cache_in_self
def get_module_api_callbacks(self) -> ModuleApiCallbacks:
return ModuleApiCallbacks()
@cache_in_self
def get_account_data_handler(self) -> AccountDataHandler:
return AccountDataHandler(self)

View File

@@ -42,11 +42,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
ON_UPDATE_CALLBACK = Callable[[str, str, bool], AsyncContextManager[int]]
DEFAULT_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
MIN_BATCH_SIZE_CALLBACK = Callable[[str, str], Awaitable[int]]
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _BackgroundUpdateHandler:
"""A handler for a given background update.
@@ -149,13 +144,11 @@ class BackgroundUpdater:
self._database_name = database.name()
self._module_api_callbacks = hs.get_module_api_callbacks().background_updater
# if a background update is currently running, its name.
self._current_background_update: Optional[str] = None
self._on_update_callback: Optional[ON_UPDATE_CALLBACK] = None
self._default_batch_size_callback: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None
self._min_batch_size_callback: Optional[MIN_BATCH_SIZE_CALLBACK] = None
self._background_update_performance: Dict[str, BackgroundUpdatePerformance] = {}
self._background_update_handlers: Dict[str, _BackgroundUpdateHandler] = {}
self._all_done = False
@@ -175,31 +168,6 @@ class BackgroundUpdater:
self.sleep_duration_ms = hs.config.background_updates.sleep_duration_ms
self.sleep_enabled = hs.config.background_updates.sleep_enabled
def register_update_controller_callbacks(
self,
on_update: ON_UPDATE_CALLBACK,
default_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
min_batch_size: Optional[DEFAULT_BATCH_SIZE_CALLBACK] = None,
) -> None:
"""Register callbacks from a module for each hook."""
if self._on_update_callback is not None:
logger.warning(
"More than one module tried to register callbacks for controlling"
" background updates. Only the callbacks registered by the first module"
" (in order of appearance in Synapse's configuration file) that tried to"
" do so will be called."
)
return
self._on_update_callback = on_update
if default_batch_size is not None:
self._default_batch_size_callback = default_batch_size
if min_batch_size is not None:
self._min_batch_size_callback = min_batch_size
def _get_context_manager_for_update(
self,
sleep: bool,
@@ -228,8 +196,10 @@ class BackgroundUpdater:
Note: this is a *target*, and an iteration may take substantially longer or
shorter.
"""
if self._on_update_callback is not None:
return self._on_update_callback(update_name, database_name, oneshot)
if self._module_api_callbacks.on_update_callback is not None:
return self._module_api_callbacks.on_update_callback(
update_name, database_name, oneshot
)
return _BackgroundUpdateContextManager(
sleep, self._clock, self.sleep_duration_ms, self.update_duration_ms
@@ -239,8 +209,10 @@ class BackgroundUpdater:
"""The batch size to use for the first iteration of a new background
update.
"""
if self._default_batch_size_callback is not None:
return await self._default_batch_size_callback(update_name, database_name)
if self._module_api_callbacks.default_batch_size_callback is not None:
return await self._module_api_callbacks.default_batch_size_callback(
update_name, database_name
)
return self.default_background_batch_size
@@ -249,8 +221,10 @@ class BackgroundUpdater:
Used to ensure that progress is always made. Must be greater than 0.
"""
if self._min_batch_size_callback is not None:
return await self._min_batch_size_callback(update_name, database_name)
if self._module_api_callbacks.min_batch_size_callback is not None:
return await self._module_api_callbacks.min_batch_size_callback(
update_name, database_name
)
return self.minimum_background_batch_size

View File

@@ -14,17 +14,7 @@
# limitations under the License.
import logging
from enum import Enum, auto
from typing import (
Collection,
Dict,
FrozenSet,
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
)
from typing import Collection, Dict, FrozenSet, List, Optional, Tuple
import attr
from typing_extensions import Final
@@ -575,43 +565,29 @@ async def filter_events_for_server(
storage: StorageControllers,
target_server_name: str,
local_server_name: str,
events: Sequence[EventBase],
*,
redact: bool,
filter_out_erased_senders: bool,
filter_out_remote_partial_state_events: bool,
events: List[EventBase],
redact: bool = True,
check_history_visibility_only: bool = False,
) -> List[EventBase]:
"""Filter a list of events based on whether the target server is allowed to
"""Filter a list of events based on whether given server is allowed to
see them.
For a fully stated room, the target server is allowed to see an event E if:
- the state at E has world readable or shared history vis, OR
- the state at E says that the target server is in the room.
For a partially stated room, the target server is allowed to see E if:
- E was created by this homeserver, AND:
- the partial state at E has world readable or shared history vis, OR
- the partial state at E says that the target server is in the room.
TODO: state before or state after?
Args:
storage
target_server_name
local_server_name
server_name
events
redact: Controls what to do with events which have been filtered out.
If True, include their redacted forms; if False, omit them entirely.
filter_out_erased_senders: If true, also filter out events whose sender has been
redact: Whether to return a redacted version of the event, or
to filter them out entirely.
check_history_visibility_only: Whether to only check the
history visibility, rather than things like if the sender has been
erased. This is used e.g. during pagination to decide whether to
backfill or not.
filter_out_remote_partial_state_events: If True, also filter out events in
partial state rooms created by other homeservers.
Returns
The filtered events.
"""
def is_sender_erased(event: EventBase, erased_senders: Mapping[str, bool]) -> bool:
def is_sender_erased(event: EventBase, erased_senders: Dict[str, bool]) -> bool:
if erased_senders and erased_senders[event.sender]:
logger.info("Sender of %s has been erased, redacting", event.event_id)
return True
@@ -640,7 +616,7 @@ async def filter_events_for_server(
# server has no users in the room: redact
return False
if filter_out_erased_senders:
if not check_history_visibility_only:
erased_senders = await storage.main.are_users_erased(e.sender for e in events)
else:
# We don't want to check whether users are erased, which is equivalent
@@ -655,15 +631,15 @@ async def filter_events_for_server(
# otherwise a room could be fully joined after we retrieve those, which would then bypass
# this check but would base the filtering on an outdated view of the membership events.
partial_state_invisible_event_ids: Set[str] = set()
if filter_out_remote_partial_state_events:
partial_state_invisible_events = set()
if not check_history_visibility_only:
for e in events:
sender_domain = get_domain_from_id(e.sender)
if (
sender_domain != local_server_name
and await storage.main.is_partial_state_room(e.room_id)
):
partial_state_invisible_event_ids.add(e.event_id)
partial_state_invisible_events.add(e)
# Let's check to see if all the events have a history visibility
# of "shared" or "world_readable". If that's the case then we don't
@@ -682,20 +658,17 @@ async def filter_events_for_server(
target_server_name,
)
def include_event_in_output(e: EventBase) -> bool:
to_return = []
for e in events:
erased = is_sender_erased(e, erased_senders)
visible = check_event_is_visible(
event_to_history_vis[e.event_id], event_to_memberships.get(e.event_id, {})
)
if e.event_id in partial_state_invisible_event_ids:
if e in partial_state_invisible_events:
visible = False
return visible and not erased
to_return = []
for e in events:
if include_event_in_output(e):
if visible and not erased:
to_return.append(e)
elif redact:
to_return.append(prune_event(e))

View File

@@ -19,10 +19,13 @@ import attr
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EduTypes
from synapse.events.presence_router import PresenceRouter, load_legacy_presence_router
from synapse.events.presence_router import PresenceRouter
from synapse.federation.units import Transaction
from synapse.handlers.presence import UserPresenceState
from synapse.module_api import ModuleApi
from synapse.module_api.callbacks.presence_router_callbacks import (
load_legacy_presence_router,
)
from synapse.rest import admin
from synapse.rest.client import login, presence, room
from synapse.server import HomeServer

View File

@@ -1,5 +1,4 @@
from typing import Callable, Collection, List, Optional, Tuple
from unittest import mock
from typing import Callable, List, Optional, Tuple
from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactor
@@ -501,87 +500,3 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
self.assertEqual(len(sent_pdus), 1)
self.assertEqual(sent_pdus[0].event_id, event_2.event_id)
self.assertFalse(per_dest_queue._catching_up)
def test_catch_up_is_not_blocked_by_remote_event_in_partial_state_room(
self,
) -> None:
"""Detects (part of?) https://github.com/matrix-org/synapse/issues/15220."""
# ARRANGE:
# - a local user (u1)
# - a room which contains u1 and two remote users, @u2:host2 and @u3:other
# - events in that room such that
# - history visibility is restricted
# - u1 sent message events e1 and e2
# - afterwards, u3 sent a remote event e3
# - catchup to begin for host2; last successfully sent event was e1
per_dest_queue, sent_pdus = self.make_fake_destination_queue()
self.register_user("u1", "you the one")
u1_token = self.login("u1", "you the one")
room = self.helper.create_room_as("u1", tok=u1_token)
self.helper.send_state(
room_id=room,
event_type="m.room.history_visibility",
body={"history_visibility": "joined"},
tok=u1_token,
)
self.get_success(
event_injection.inject_member_event(self.hs, room, "@u2:host2", "join")
)
self.get_success(
event_injection.inject_member_event(self.hs, room, "@u3:other", "join")
)
# create some events
event_id_1 = self.helper.send(room, "hello", tok=u1_token)["event_id"]
event_id_2 = self.helper.send(room, "world", tok=u1_token)["event_id"]
# pretend that u3 changes their displayname
event_id_3 = self.get_success(
event_injection.inject_member_event(self.hs, room, "@u3:other", "join")
).event_id
# destination_rooms should already be populated, but let us pretend that we already
# sent (successfully) up to and including event id 1
event_1 = self.get_success(self.hs.get_datastores().main.get_event(event_id_1))
assert event_1.internal_metadata.stream_ordering is not None
self.get_success(
self.hs.get_datastores().main.set_destination_last_successful_stream_ordering(
"host2", event_1.internal_metadata.stream_ordering
)
)
# also fetch event 2 so we can compare its stream ordering to the sender's
# last_successful_stream_ordering later
event_2 = self.get_success(self.hs.get_datastores().main.get_event(event_id_2))
# Mock event 3 as having partial state
self.get_success(
event_injection.mark_event_as_partial_state(self.hs, event_id_3, room)
)
# Fail the test if we block on full state for event 3.
async def mock_await_full_state(event_ids: Collection[str]) -> None:
if event_id_3 in event_ids:
raise AssertionError("Tried to await full state for event_id_3")
# ACT
with mock.patch.object(
self.hs.get_storage_controllers().state._partial_state_events_tracker,
"await_full_state",
mock_await_full_state,
):
self.get_success(per_dest_queue._catch_up_transmission_loop())
# ASSERT
# We should have:
# - not sent event 3: it's not ours, and the room is partial stated
# - fallen back to sending event 2: it's the most recent event in the room
# we tried to send to host2
# - completed catch-up
self.assertEqual(len(sent_pdus), 1)
self.assertEqual(sent_pdus[0].event_id, event_id_2)
self.assertFalse(per_dest_queue._catching_up)
self.assertEqual(
per_dest_queue._last_successful_stream_ordering,
event_2.internal_metadata.stream_ordering,
)

View File

@@ -727,7 +727,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.called = True
on_logged_out = Mock(side_effect=on_logged_out)
self.hs.get_password_auth_provider().on_logged_out_callbacks.append(
self.hs.get_module_api_callbacks().password_auth_provider.on_logged_out_callbacks.append(
on_logged_out
)
@@ -857,7 +857,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
)
m = Mock(return_value=make_awaitable(False))
self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m]
self.hs.get_module_api_callbacks().password_auth_provider.is_3pid_allowed_callbacks = [
m
]
self.register_user(username, "password")
tok = self.login(username, "password")
@@ -887,7 +889,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
m.assert_called_once_with("email", "foo@test.com", registration)
m = Mock(return_value=make_awaitable(True))
self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m]
self.hs.get_module_api_callbacks().password_auth_provider.is_3pid_allowed_callbacks = [
m
]
channel = self.make_request(
"POST",

View File

@@ -791,8 +791,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
return False
# Configure a spam checker that does not filter any users.
spam_checker = self.hs.get_spam_checker()
spam_checker._check_username_for_spam_callbacks = [allow_all]
spam_checker_callbacks = self.hs.get_module_api_callbacks().spam_checker
spam_checker_callbacks.check_username_for_spam_callbacks = [allow_all]
# The results do not change:
# We get one search result when searching for user2 by user1.
@@ -804,7 +804,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# All users are spammy.
return True
spam_checker._check_username_for_spam_callbacks = [block_all]
spam_checker_callbacks.check_username_for_spam_callbacks = [block_all]
# User1 now gets no search results for any of the other users.
s = self.get_success(self.handler.search_users(u1, "user2", 10))

View File

@@ -31,7 +31,6 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import Codes
from synapse.events import EventBase
from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.http.types import QueryParams
from synapse.logging.context import make_deferred_yieldable
from synapse.media._base import FileInfo
@@ -39,6 +38,9 @@ from synapse.media.filepath import MediaFilePaths
from synapse.media.media_storage import MediaStorage, ReadableFileWrapper
from synapse.media.storage_provider import FileStorageProviderBackend
from synapse.module_api import ModuleApi
from synapse.module_api.callbacks.spam_checker_callbacks import (
load_legacy_spam_checkers,
)
from synapse.rest import admin
from synapse.rest.client import login
from synapse.server import HomeServer

View File

@@ -228,14 +228,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
)
return len(result) > 0
@override_config(
{
"experimental_features": {
"msc3952_intentional_mentions": True,
"msc3966_exact_event_property_contains": True,
}
}
)
@override_config({"experimental_features": {"msc3952_intentional_mentions": True}})
def test_user_mentions(self) -> None:
"""Test the behavior of an event which includes invalid user mentions."""
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
@@ -331,14 +324,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
)
)
@override_config(
{
"experimental_features": {
"msc3952_intentional_mentions": True,
"msc3966_exact_event_property_contains": True,
}
}
)
@override_config({"experimental_features": {"msc3952_intentional_mentions": True}})
def test_room_mentions(self) -> None:
"""Test the behavior of an event which includes invalid room mentions."""
bulk_evaluator = BulkPushRuleEvaluator(self.hs)

View File

@@ -51,11 +51,7 @@ class FlattenDictTestCase(unittest.TestCase):
# If a field has a dot in it, escape it.
input = {"m.foo": {"b\\ar": "abc"}}
self.assertEqual({"m.foo.b\\ar": "abc"}, _flatten_dict(input))
self.assertEqual(
{"m\\.foo.b\\\\ar": "abc"},
_flatten_dict(input, msc3873_escape_event_match_key=True),
)
self.assertEqual({"m\\.foo.b\\\\ar": "abc"}, _flatten_dict(input))
def test_non_string(self) -> None:
"""String, booleans, ints, nulls and list of those should be kept while other items are dropped."""
@@ -125,7 +121,7 @@ class FlattenDictTestCase(unittest.TestCase):
"room_id": "!test:test",
"sender": "@alice:test",
"type": "m.room.message",
"content.org.matrix.msc1767.markup": [],
"content.org\\.matrix\\.msc1767\\.markup": [],
}
self.assertEqual(expected, _flatten_dict(event))
@@ -137,7 +133,7 @@ class FlattenDictTestCase(unittest.TestCase):
"room_id": "!test:test",
"sender": "@alice:test",
"type": "m.room.message",
"content.org.matrix.msc1767.markup": [],
"content.org\\.matrix\\.msc1767\\.markup": [],
}
self.assertEqual(expected, _flatten_dict(event))
@@ -173,7 +169,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
related_event_match_enabled=True,
room_version_feature_flags=event.room_version.msc3931_push_features,
msc3931_enabled=True,
msc3966_exact_event_property_contains=True,
)
def test_display_name(self) -> None:
@@ -526,7 +521,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
"""Check that exact_event_property_contains conditions work as expected."""
condition = {
"kind": "org.matrix.msc3966.exact_event_property_contains",
"kind": "event_property_contains",
"key": "content.value",
"value": "foobaz",
}

View File

@@ -1249,9 +1249,8 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
# account status will fail.
return UserID.from_string(user_id).localpart == "someuser"
self.hs.get_account_validity_handler()._is_user_expired_callbacks.append(
is_expired
)
account_validity_callbacks = self.hs.get_module_api_callbacks().account_validity
account_validity_callbacks.is_user_expired_callbacks.append(is_expired)
self._test_status(
users=[user],

View File

@@ -33,7 +33,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase):
a user's account data changes.
"""
mocked_callback = Mock(return_value=make_awaitable(None))
self.hs.get_account_data_handler()._on_account_data_updated_callbacks.append(
self.hs.get_module_api_callbacks().account_data.on_account_data_updated_callbacks.append(
mocked_callback
)

View File

@@ -814,7 +814,8 @@ class RoomsCreateTestCase(RoomBase):
return False
join_mock = Mock(side_effect=user_may_join_room)
self.hs.get_spam_checker()._user_may_join_room_callbacks.append(join_mock)
spam_checker_callbacks = self.hs.get_module_api_callbacks().spam_checker
spam_checker_callbacks.user_may_join_room_callbacks.append(join_mock)
channel = self.make_request(
"POST",
@@ -840,7 +841,8 @@ class RoomsCreateTestCase(RoomBase):
return Codes.CONSENT_NOT_GIVEN
join_mock = Mock(side_effect=user_may_join_room_codes)
self.hs.get_spam_checker()._user_may_join_room_callbacks.append(join_mock)
spam_checker_callbacks = self.hs.get_module_api_callbacks().spam_checker
spam_checker_callbacks.user_may_join_room_callbacks.append(join_mock)
channel = self.make_request(
"POST",
@@ -1162,7 +1164,8 @@ class RoomJoinTestCase(RoomBase):
# `spec` argument is needed for this function mock to have `__qualname__`, which
# is needed for `Measure` metrics buried in SpamChecker.
callback_mock = Mock(side_effect=user_may_join_room, spec=lambda *x: None)
self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock)
spam_checker_callbacks = self.hs.get_module_api_callbacks().spam_checker
spam_checker_callbacks.user_may_join_room_callbacks.append(callback_mock)
# Join a first room, without being invited to it.
self.helper.join(self.room1, self.user2, tok=self.tok2)
@@ -1227,7 +1230,8 @@ class RoomJoinTestCase(RoomBase):
# `spec` argument is needed for this function mock to have `__qualname__`, which
# is needed for `Measure` metrics buried in SpamChecker.
callback_mock = Mock(side_effect=user_may_join_room, spec=lambda *x: None)
self.hs.get_spam_checker()._user_may_join_room_callbacks.append(callback_mock)
spam_checker_callbacks = self.hs.get_module_api_callbacks().spam_checker
spam_checker_callbacks.user_may_join_room_callbacks.append(callback_mock)
# Join a first room, without being invited to it.
self.helper.join(self.room1, self.user2, tok=self.tok2)
@@ -1642,8 +1646,8 @@ class RoomMessagesTestCase(RoomBase):
return self.mock_return_value
spam_checker = SpamCheck()
self.hs.get_spam_checker()._check_event_for_spam_callbacks.append(
spam_checker_callbacks = self.hs.get_module_api_callbacks().spam_checker
spam_checker_callbacks.check_event_for_spam_callbacks.append(
spam_checker.check_event_for_spam
)
@@ -3381,7 +3385,8 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
# `spec` argument is needed for this function mock to have `__qualname__`, which
# is needed for `Measure` metrics buried in SpamChecker.
mock = Mock(return_value=make_awaitable(True), spec=lambda *x: None)
self.hs.get_spam_checker()._user_may_send_3pid_invite_callbacks.append(mock)
spam_checker_callbacks = self.hs.get_module_api_callbacks().spam_checker
spam_checker_callbacks.user_may_send_3pid_invite_callbacks.append(mock)
# Send a 3PID invite into the room and check that it succeeded.
email_to_invite = "teresa@example.com"
@@ -3446,7 +3451,8 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
return_value=make_awaitable(synapse.module_api.NOT_SPAM),
spec=lambda *x: None,
)
self.hs.get_spam_checker()._user_may_send_3pid_invite_callbacks.append(mock)
spam_checker_callbacks = self.hs.get_module_api_callbacks().spam_checker
spam_checker_callbacks.user_may_send_3pid_invite_callbacks.append(mock)
# Send a 3PID invite into the room and check that it succeeded.
email_to_invite = "teresa@example.com"

View File

@@ -22,7 +22,9 @@ from synapse.api.errors import SynapseError
from synapse.api.room_versions import RoomVersion
from synapse.config.homeserver import HomeServerConfig
from synapse.events import EventBase
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.module_api.callbacks.third_party_event_rules_callbacks import (
load_legacy_third_party_event_rules,
)
from synapse.rest import admin
from synapse.rest.client import account, login, profile, room
from synapse.server import HomeServer
@@ -146,7 +148,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
return ev.type != "foo.bar.forbidden", None
callback = Mock(spec=[], side_effect=check)
self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [
self.hs.get_module_api_callbacks().third_party_event_rules.check_event_allowed_callbacks = [
callback
]
@@ -202,7 +204,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
) -> Tuple[bool, Optional[JsonDict]]:
raise NastyHackException(429, "message")
self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
self.hs.get_module_api_callbacks().third_party_event_rules.check_event_allowed_callbacks = [
check
]
# Make a request
channel = self.make_request(
@@ -229,7 +233,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
ev.content = {"x": "y"}
return True, None
self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
self.hs.get_module_api_callbacks().third_party_event_rules.check_event_allowed_callbacks = [
check
]
# now send the event
channel = self.make_request(
@@ -253,7 +259,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
d["content"] = {"x": "y"}
return True, d
self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
self.hs.get_module_api_callbacks().third_party_event_rules.check_event_allowed_callbacks = [
check
]
# now send the event
channel = self.make_request(
@@ -289,7 +297,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
}
return True, d
self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
self.hs.get_module_api_callbacks().third_party_event_rules.check_event_allowed_callbacks = [
check
]
# Send an event, then edit it.
channel = self.make_request(
@@ -440,7 +450,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
)
return True, None
self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [test_fn]
self.hs.get_module_api_callbacks().third_party_event_rules.check_event_allowed_callbacks = [
test_fn
]
# Sometimes the bug might not happen the first time the event type is added
# to the state but might happen when an event updates the state of the room for
@@ -466,7 +478,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
def test_on_new_event(self) -> None:
"""Test that the on_new_event callback is called on new events"""
on_new_event = Mock(make_awaitable(None))
self.hs.get_third_party_event_rules()._on_new_event_callbacks.append(
self.hs.get_module_api_callbacks().third_party_event_rules.on_new_event_callbacks.append(
on_new_event
)
@@ -569,7 +581,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
# Register a mock callback.
m = Mock(return_value=make_awaitable(None))
self.hs.get_third_party_event_rules()._on_profile_update_callbacks.append(m)
self.hs.get_module_api_callbacks().third_party_event_rules.on_profile_update_callbacks.append(
m
)
# Change the display name.
channel = self.make_request(
@@ -628,7 +642,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
# Register a mock callback.
m = Mock(return_value=make_awaitable(None))
self.hs.get_third_party_event_rules()._on_profile_update_callbacks.append(m)
self.hs.get_module_api_callbacks().third_party_event_rules.on_profile_update_callbacks.append(
m
)
# Register an admin user.
self.register_user("admin", "password", admin=True)
@@ -667,15 +683,15 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"""
# Register a mocked callback.
deactivation_mock = Mock(return_value=make_awaitable(None))
third_party_rules = self.hs.get_third_party_event_rules()
third_party_rules._on_user_deactivation_status_changed_callbacks.append(
third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
third_party_rules.on_user_deactivation_status_changed_callbacks.append(
deactivation_mock,
)
# Also register a mocked callback for profile updates, to check that the
# deactivation code calls it in a way that let modules know the user is being
# deactivated.
profile_mock = Mock(return_value=make_awaitable(None))
self.hs.get_third_party_event_rules()._on_profile_update_callbacks.append(
self.hs.get_module_api_callbacks().third_party_event_rules.on_profile_update_callbacks.append(
profile_mock,
)
@@ -725,8 +741,8 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"""
# Register a mock callback.
m = Mock(return_value=make_awaitable(None))
third_party_rules = self.hs.get_third_party_event_rules()
third_party_rules._on_user_deactivation_status_changed_callbacks.append(m)
third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
third_party_rules.on_user_deactivation_status_changed_callbacks.append(m)
# Register an admin user.
self.register_user("admin", "password", admin=True)
@@ -779,8 +795,8 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"""
# Register a mocked callback.
deactivation_mock = Mock(return_value=make_awaitable(False))
third_party_rules = self.hs.get_third_party_event_rules()
third_party_rules._check_can_deactivate_user_callbacks.append(
third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
third_party_rules.check_can_deactivate_user_callbacks.append(
deactivation_mock,
)
@@ -825,8 +841,8 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"""
# Register a mocked callback.
deactivation_mock = Mock(return_value=make_awaitable(False))
third_party_rules = self.hs.get_third_party_event_rules()
third_party_rules._check_can_deactivate_user_callbacks.append(
third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
third_party_rules.check_can_deactivate_user_callbacks.append(
deactivation_mock,
)
@@ -864,8 +880,8 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"""
# Register a mocked callback.
shutdown_mock = Mock(return_value=make_awaitable(False))
third_party_rules = self.hs.get_third_party_event_rules()
third_party_rules._check_can_shutdown_room_callbacks.append(
third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
third_party_rules.check_can_shutdown_room_callbacks.append(
shutdown_mock,
)
@@ -900,8 +916,8 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"""
# Register a mocked callback.
threepid_bind_mock = Mock(return_value=make_awaitable(None))
third_party_rules = self.hs.get_third_party_event_rules()
third_party_rules._on_threepid_bind_callbacks.append(threepid_bind_mock)
third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
third_party_rules.on_threepid_bind_callbacks.append(threepid_bind_mock)
# Register an admin user.
self.register_user("admin", "password", admin=True)
@@ -941,16 +957,18 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
just before associating and removing a 3PID to/from an account.
"""
# Pretend to be a Synapse module and register both callbacks as mocks.
third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
on_add_user_third_party_identifier_callback_mock = Mock(
return_value=make_awaitable(None)
)
on_remove_user_third_party_identifier_callback_mock = Mock(
return_value=make_awaitable(None)
)
third_party_rules = self.hs.get_third_party_event_rules()
third_party_rules.register_third_party_rules_callbacks(
on_add_user_third_party_identifier=on_add_user_third_party_identifier_callback_mock,
on_remove_user_third_party_identifier=on_remove_user_third_party_identifier_callback_mock,
third_party_rules.on_threepid_bind_callbacks.append(
on_add_user_third_party_identifier_callback_mock
)
third_party_rules.on_threepid_bind_callbacks.append(
on_remove_user_third_party_identifier_callback_mock
)
# Register an admin user.
@@ -1006,12 +1024,12 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
when a user is deactivated and their third-party ID associations are deleted.
"""
# Pretend to be a Synapse module and register both callbacks as mocks.
third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
on_remove_user_third_party_identifier_callback_mock = Mock(
return_value=make_awaitable(None)
)
third_party_rules = self.hs.get_third_party_event_rules()
third_party_rules.register_third_party_rules_callbacks(
on_remove_user_third_party_identifier=on_remove_user_third_party_identifier_callback_mock,
third_party_rules.on_threepid_bind_callbacks.append(
on_remove_user_third_party_identifier_callback_mock
)
# Register an admin user.
@@ -1037,9 +1055,6 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
)
self.assertEqual(channel.code, 200, channel.json_body)
# Check that the mock was not called on the act of adding a third-party ID.
on_remove_user_third_party_identifier_callback_mock.assert_not_called()
# Now deactivate the user.
channel = self.make_request(
"PUT",

View File

@@ -39,15 +39,23 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
self.cache = HttpTransactionCache(self.hs)
self.mock_http_response = (HTTPStatus.OK, {"result": "GOOD JOB!"})
self.mock_key = "foo"
# Here we make sure that we're setting all the fields that HttpTransactionCache
# uses to build the transaction key.
self.mock_request = Mock()
self.mock_request.path = b"/foo/bar"
self.mock_requester = Mock()
self.mock_requester.app_service = None
self.mock_requester.is_guest = False
self.mock_requester.access_token_id = 1234
@defer.inlineCallbacks
def test_executes_given_function(
self,
) -> Generator["defer.Deferred[Any]", object, None]:
cb = Mock(return_value=make_awaitable(self.mock_http_response))
res = yield self.cache.fetch_or_execute(
self.mock_key, cb, "some_arg", keyword="arg"
res = yield self.cache.fetch_or_execute_request(
self.mock_request, self.mock_requester, cb, "some_arg", keyword="arg"
)
cb.assert_called_once_with("some_arg", keyword="arg")
self.assertEqual(res, self.mock_http_response)
@@ -58,8 +66,13 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
) -> Generator["defer.Deferred[Any]", object, None]:
cb = Mock(return_value=make_awaitable(self.mock_http_response))
for i in range(3): # invoke multiple times
res = yield self.cache.fetch_or_execute(
self.mock_key, cb, "some_arg", keyword="arg", changing_args=i
res = yield self.cache.fetch_or_execute_request(
self.mock_request,
self.mock_requester,
cb,
"some_arg",
keyword="arg",
changing_args=i,
)
self.assertEqual(res, self.mock_http_response)
# expect only a single call to do the work
@@ -77,7 +90,9 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test() -> Generator["defer.Deferred[Any]", object, None]:
with LoggingContext("c") as c1:
res = yield self.cache.fetch_or_execute(self.mock_key, cb)
res = yield self.cache.fetch_or_execute_request(
self.mock_request, self.mock_requester, cb
)
self.assertIs(current_context(), c1)
self.assertEqual(res, (1, {}))
@@ -106,12 +121,16 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
with LoggingContext("test") as test_context:
try:
yield self.cache.fetch_or_execute(self.mock_key, cb)
yield self.cache.fetch_or_execute_request(
self.mock_request, self.mock_requester, cb
)
except Exception as e:
self.assertEqual(e.args[0], "boo")
self.assertIs(current_context(), test_context)
res = yield self.cache.fetch_or_execute(self.mock_key, cb)
res = yield self.cache.fetch_or_execute_request(
self.mock_request, self.mock_requester, cb
)
self.assertEqual(res, self.mock_http_response)
self.assertIs(current_context(), test_context)
@@ -134,29 +153,39 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
with LoggingContext("test") as test_context:
try:
yield self.cache.fetch_or_execute(self.mock_key, cb)
yield self.cache.fetch_or_execute_request(
self.mock_request, self.mock_requester, cb
)
except Exception as e:
self.assertEqual(e.args[0], "boo")
self.assertIs(current_context(), test_context)
res = yield self.cache.fetch_or_execute(self.mock_key, cb)
res = yield self.cache.fetch_or_execute_request(
self.mock_request, self.mock_requester, cb
)
self.assertEqual(res, self.mock_http_response)
self.assertIs(current_context(), test_context)
@defer.inlineCallbacks
def test_cleans_up(self) -> Generator["defer.Deferred[Any]", object, None]:
cb = Mock(return_value=make_awaitable(self.mock_http_response))
yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
yield self.cache.fetch_or_execute_request(
self.mock_request, self.mock_requester, cb, "an arg"
)
# should NOT have cleaned up yet
self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2)
yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
yield self.cache.fetch_or_execute_request(
self.mock_request, self.mock_requester, cb, "an arg"
)
# still using cache
cb.assert_called_once_with("an arg")
self.clock.advance_time_msec(CLEANUP_PERIOD_MS)
yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
yield self.cache.fetch_or_execute_request(
self.mock_request, self.mock_requester, cb, "an arg"
)
# no longer using cache
self.assertEqual(cb.call_count, 2)
self.assertEqual(cb.call_args_list, [call("an arg"), call("an arg")])

View File

@@ -71,12 +71,18 @@ from twisted.web.server import Request, Site
from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
from synapse.events.presence_router import load_legacy_presence_router
from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.http.site import SynapseRequest
from synapse.logging.context import ContextResourceUsage
from synapse.module_api.callbacks.presence_router_callbacks import (
load_legacy_presence_router,
)
from synapse.module_api.callbacks.spam_checker_callbacks import (
load_legacy_spam_checkers,
)
from synapse.module_api.callbacks.third_party_event_rules_callbacks import (
load_legacy_third_party_event_rules,
)
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.engines import PostgresEngine, create_engine

View File

@@ -102,34 +102,3 @@ async def create_event(
context = await unpersisted_context.persist(event)
return event, context
async def mark_event_as_partial_state(
hs: synapse.server.HomeServer,
event_id: str,
room_id: str,
) -> None:
"""
(Falsely) mark an event as having partial state.
Naughty, but occasionally useful when checking that partial state doesn't
block something from happening.
If the event already has partial state, this insert will fail (event_id is unique
in this table).
"""
store = hs.get_datastores().main
await store.db_pool.simple_upsert(
table="partial_state_rooms",
keyvalues={"room_id": room_id},
values={},
insertion_values={"room_id": room_id},
)
await store.db_pool.simple_insert(
table="partial_state_events",
values={
"room_id": room_id,
"event_id": event_id,
},
)

View File

@@ -63,13 +63,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
filtered = self.get_success(
filter_events_for_server(
self._storage_controllers,
"test_server",
"hs",
events_to_filter,
redact=True,
filter_out_erased_senders=True,
filter_out_remote_partial_state_events=True,
self._storage_controllers, "test_server", "hs", events_to_filter
)
)
@@ -91,13 +85,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
self.assertEqual(
self.get_success(
filter_events_for_server(
self._storage_controllers,
"remote_hs",
"hs",
[outlier],
redact=True,
filter_out_erased_senders=True,
filter_out_remote_partial_state_events=True,
self._storage_controllers, "remote_hs", "hs", [outlier]
)
),
[outlier],
@@ -108,13 +96,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
filtered = self.get_success(
filter_events_for_server(
self._storage_controllers,
"remote_hs",
"local_hs",
[outlier, evt],
redact=True,
filter_out_erased_senders=True,
filter_out_remote_partial_state_events=True,
self._storage_controllers, "remote_hs", "local_hs", [outlier, evt]
)
)
self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}")
@@ -126,13 +108,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
# be redacted)
filtered = self.get_success(
filter_events_for_server(
self._storage_controllers,
"other_server",
"local_hs",
[outlier, evt],
redact=True,
filter_out_erased_senders=True,
filter_out_remote_partial_state_events=True,
self._storage_controllers, "other_server", "local_hs", [outlier, evt]
)
)
self.assertEqual(filtered[0], outlier)
@@ -167,13 +143,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
# ... and the filtering happens.
filtered = self.get_success(
filter_events_for_server(
self._storage_controllers,
"test_server",
"local_hs",
events_to_filter,
redact=True,
filter_out_erased_senders=True,
filter_out_remote_partial_state_events=True,
self._storage_controllers, "test_server", "local_hs", events_to_filter
)
)