Compare commits

...

13 Commits

Author SHA1 Message Date
H. Shay
4809a0cb35 merge in develop 2023-03-16 14:55:49 -07:00
H. Shay
ad325833e2 Merge branch 'develop' into shay/rework_module 2023-03-16 14:49:47 -07:00
H. Shay
d4ed0a48c1 requested changes 2023-03-16 14:31:10 -07:00
H. Shay
13676fb097 more develop merge fix 2023-03-06 12:53:52 -08:00
H. Shay
7fc487421f Merge branch 'develop' into shay/rework_module 2023-03-06 12:48:59 -08:00
H. Shay
2ab6cece75 add clearer return values 2023-03-06 12:21:27 -08:00
H. Shay
a9b0093d3a update changelog 2023-02-21 14:53:26 -08:00
H. Shay
9b702df296 newsfragment 2023-02-21 14:31:52 -08:00
H. Shay
7b610fca1a update docs with information on v2 callback 2023-02-21 14:27:53 -08:00
H. Shay
b564f29fe2 add a new test to check sending an additional event into room 2023-02-21 14:27:38 -08:00
H. Shay
5cebb3767b update tests to reflect new function signature 2023-02-21 14:27:19 -08:00
H. Shay
2e14bc3745 change callsites to reflect new function signature 2023-02-21 14:26:29 -08:00
H. Shay
aab5fb622e add check_event_allowed_v2 2023-02-21 14:25:49 -08:00
22 changed files with 373 additions and 100 deletions

1
changelog.d/15131.misc Normal file
View File

@@ -0,0 +1 @@
Add a new third party callback `check_event_allowed_v2` that is compatible with new batch persisting mechanisms.

View File

@@ -10,6 +10,75 @@ The available third party rules callbacks are:
### `check_event_allowed`
_First introduced in Synapse v1.7x.x
```python
async def check_event_allowed_v2(
event: "synapse.events.EventBase",
state_events: "synapse.types.StateMap",
) -> Tuple[bool, Optional[dict], Optional[dict]]
```
**<span style="color:red">
This callback is very experimental and can and will break without notice. Module developers
are encouraged to implement `check_event_for_spam` from the spam checker category instead.
</span>**
Returns:
- A tuple consisting of:
- a boolean representing whether or not the event is allowed
- an optional dict to form the basis of a replacement event for the event
- an optional dict to form the basis of an additional event to be sent into the
room
Called when processing any incoming event, with the event and a `StateMap`
representing the current state of the room the event is being sent into. A `StateMap` is
a dictionary that maps tuples containing an event type and a state key to the
corresponding state event. For example retrieving the room's `m.room.create` event from
the `state_events` argument would look like this: `state_events.get(("m.room.create", ""))`.
The module must return a boolean indicating whether the event can be allowed.
Note that this callback function processes incoming events coming via federation
traffic (on top of client traffic). This means denying an event might cause the local
copy of the room's history to diverge from that of remote servers. This may cause
federation issues in the room. It is strongly recommended to only deny events using this
callback function if the sender is a local user, or in a private federation in which all
servers are using the same module, with the same configuration.
If the boolean returned by the module is `True`, it may tell Synapse to replace the
event with new data by returning the new event's data as a dictionary. In order to do
that, it is recommended the module calls `event.get_dict()` to get the current event as a
dictionary, and modify the returned dictionary accordingly.
Module writers may also wish to use this check to send a second event into the room along
with the event being checked, if this is the case the module writer must provide a dict that
will form the basis of the event that is to be added to the room and it must be returned by `check_event_allowed_v2`.
This dict will then be turned into an event at the appropriate time and it will be persisted after the event
that triggered it, and if the event that triggered it is in a batch of events for persisting, it will be added to the
end of that batch. Note that the event MAY NOT be a membership event.
If `check_event_allowed_v2` raises an exception, the module is assumed to have failed.
The event will not be accepted but is not treated as explicitly rejected, either.
An HTTP request causing the module check will likely result in a 500 Internal
Server Error.
When the boolean returned by the module is `False`, the event is rejected.
(Module developers should not use exceptions for rejection.)
Note that replacing the event or adding an event only works for events sent by local users, not for events
received over federation.
If multiple modules implement this callback, they will be considered in order. If a
callback returns `True`, Synapse falls through to the next one. The value of the first
callback that does not return `True` will be used. If this happens, Synapse will not call
any of the subsequent implementations of this callback. This callback cannot be used in conjunction with `check_event_allowed`,
only one of these callbacks may be operational at a time - if both `check_event_allowed` and `check_event_allowed_v2`
active only `check_event_allowed` will be executed.
### `check_event_allowed`
_First introduced in Synapse v1.39.0_
```python

View File

@@ -32,6 +32,10 @@ logger = logging.getLogger(__name__)
CHECK_EVENT_ALLOWED_CALLBACK = Callable[
[EventBase, StateMap[EventBase]], Awaitable[Tuple[bool, Optional[dict]]]
]
CHECK_EVENT_ALLOWED_V2_CALLBACK = Callable[
[EventBase, StateMap[EventBase]],
Awaitable[Tuple[bool, Optional[dict], Optional[dict]]],
]
ON_CREATE_ROOM_CALLBACK = Callable[[Requester, dict, bool], Awaitable]
CHECK_THREEPID_CAN_BE_INVITED_CALLBACK = Callable[
[str, str, StateMap[EventBase]], Awaitable[bool]
@@ -155,6 +159,9 @@ class ThirdPartyEventRules:
self._storage_controllers = hs.get_storage_controllers()
self._check_event_allowed_callbacks: List[CHECK_EVENT_ALLOWED_CALLBACK] = []
self._check_event_allowed_v2_callbacks: List[
CHECK_EVENT_ALLOWED_V2_CALLBACK
] = []
self._on_create_room_callbacks: List[ON_CREATE_ROOM_CALLBACK] = []
self._check_threepid_can_be_invited_callbacks: List[
CHECK_THREEPID_CAN_BE_INVITED_CALLBACK
@@ -184,6 +191,7 @@ class ThirdPartyEventRules:
def register_third_party_rules_callbacks(
self,
check_event_allowed: Optional[CHECK_EVENT_ALLOWED_CALLBACK] = None,
check_event_allowed_v2: Optional[CHECK_EVENT_ALLOWED_V2_CALLBACK] = None,
on_create_room: Optional[ON_CREATE_ROOM_CALLBACK] = None,
check_threepid_can_be_invited: Optional[
CHECK_THREEPID_CAN_BE_INVITED_CALLBACK
@@ -210,6 +218,9 @@ class ThirdPartyEventRules:
if check_event_allowed is not None:
self._check_event_allowed_callbacks.append(check_event_allowed)
if check_event_allowed_v2 is not None:
self._check_event_allowed_v2_callbacks.append(check_event_allowed_v2)
if on_create_room is not None:
self._on_create_room_callbacks.append(on_create_room)
@@ -256,7 +267,7 @@ class ThirdPartyEventRules:
self,
event: EventBase,
context: UnpersistedEventContextBase,
) -> Tuple[bool, Optional[dict]]:
) -> Tuple[bool, Optional[dict], Optional[dict]]:
"""Check if a provided event should be allowed in the given context.
The module can return:
@@ -264,7 +275,8 @@ class ThirdPartyEventRules:
* False: the event is not allowed, and should be rejected with M_FORBIDDEN.
If the event is allowed, the module can also return a dictionary to use as a
replacement for the event.
replacement for the event, and/or return a dictionary to use as the basis for
another event to be sent into the room.
Args:
event: The event to be checked.
@@ -274,8 +286,11 @@ class ThirdPartyEventRules:
The result from the ThirdPartyRules module, as above.
"""
# Bail out early without hitting the store if we don't have any callbacks to run.
if len(self._check_event_allowed_callbacks) == 0:
return True, None
if (
len(self._check_event_allowed_callbacks) == 0
and len(self._check_event_allowed_v2_callbacks) == 0
):
return True, None, None
prev_state_ids = await context.get_prev_state_ids()
@@ -288,35 +303,63 @@ class ThirdPartyEventRules:
# the hashes and signatures.
event.freeze()
for callback in self._check_event_allowed_callbacks:
try:
res, replacement_data = await delay_cancellation(
callback(event, state_events)
)
except CancelledError:
raise
except SynapseError as e:
# FIXME: Being able to throw SynapseErrors is relied upon by
# some modules. PR #10386 accidentally broke this ability.
# That said, we aren't keen on exposing this implementation detail
# to modules and we should one day have a proper way to do what
# is wanted.
# This module callback needs a rework so that hacks such as
# this one are not necessary.
raise e
except Exception:
raise ModuleFailedException(
"Failed to run `check_event_allowed` module API callback"
)
if len(self._check_event_allowed_callbacks) != 0:
for callback in self._check_event_allowed_callbacks:
try:
res, replacement_data = await delay_cancellation(
callback(event, state_events)
)
except CancelledError:
raise
except SynapseError as e:
# FIXME: Being able to throw SynapseErrors is relied upon by
# some modules. PR #10386 accidentally broke this ability.
# That said, we aren't keen on exposing this implementation detail
# to modules and we should one day have a proper way to do what
# is wanted.
# This module callback needs a rework so that hacks such as
# this one are not necessary.
raise e
except Exception:
raise ModuleFailedException(
"Failed to run `check_event_allowed` module API callback"
)
# Return if the event shouldn't be allowed or if the module came up with a
# replacement dict for the event.
if res is False:
return res, None
elif isinstance(replacement_data, dict):
return True, replacement_data
# Return if the event shouldn't be allowed or if the module came up with a
# replacement dict for the event.
if res is False:
return res, None, None
elif isinstance(replacement_data, dict):
return True, replacement_data, None
else:
for v2_callback in self._check_event_allowed_v2_callbacks:
try:
res, replacement_data, new_event = await delay_cancellation(
v2_callback(event, state_events)
)
except CancelledError:
raise
except SynapseError as e:
# FIXME: Being able to throw SynapseErrors is relied upon by
# some modules. PR #10386 accidentally broke this ability.
# That said, we aren't keen on exposing this implementation detail
# to modules and we should one day have a proper way to do what
# is wanted.
# This module callback needs a rework so that hacks such as
# this one are not necessary.
raise e
except Exception:
raise ModuleFailedException(
"Failed to run `check_event_allowed_v2` module API callback"
)
return True, None
# Return if the event shouldn't be allowed, if the module came up with a
# replacement dict for the event, or if the module wants to send a new event
if res is False:
return res, None, None
else:
return True, replacement_data, new_event
return True, None, None
async def on_create_room(
self, requester: Requester, config: dict, is_requester_admin: bool

View File

@@ -1007,6 +1007,7 @@ class FederationHandler:
(
event,
unpersisted_context,
_,
) = await self.event_creation_handler.create_new_client_event(
builder=builder,
prev_event_ids=prev_event_ids,
@@ -1198,7 +1199,7 @@ class FederationHandler:
},
)
event, _ = await self.event_creation_handler.create_new_client_event(
event, _, _ = await self.event_creation_handler.create_new_client_event(
builder=builder
)
@@ -1251,9 +1252,10 @@ class FederationHandler:
(
event,
unpersisted_context,
_,
) = await self.event_creation_handler.create_new_client_event(builder=builder)
event_allowed, _ = await self.third_party_event_rules.check_event_allowed(
event_allowed, _, _ = await self.third_party_event_rules.check_event_allowed(
event, unpersisted_context
)
if not event_allowed:
@@ -1446,6 +1448,7 @@ class FederationHandler:
(
event,
unpersisted_context,
_,
) = await self.event_creation_handler.create_new_client_event(
builder=builder
)
@@ -1528,6 +1531,7 @@ class FederationHandler:
(
event,
unpersisted_context,
_,
) = await self.event_creation_handler.create_new_client_event(
builder=builder
)
@@ -1610,6 +1614,7 @@ class FederationHandler:
(
event,
unpersisted_context,
_,
) = await self.event_creation_handler.create_new_client_event(builder=builder)
EventValidator().validate_new(event, self.config)

View File

@@ -404,9 +404,11 @@ class FederationEventHandler:
# for knock events, we run the third-party event rules. It's not entirely clear
# why we don't do this for other sorts of membership events.
if event.membership == Membership.KNOCK:
event_allowed, _ = await self._third_party_event_rules.check_event_allowed(
event, context
)
(
event_allowed,
_,
_,
) = await self._third_party_event_rules.check_event_allowed(event, context)
if not event_allowed:
logger.info("Sending of knock %s forbidden by third-party rules", event)
raise SynapseError(

View File

@@ -16,6 +16,7 @@
# limitations under the License.
import logging
import random
from builtins import dict
from http import HTTPStatus
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple
@@ -577,7 +578,7 @@ class EventCreationHandler:
state_map: Optional[StateMap[str]] = None,
for_batch: bool = False,
current_state_group: Optional[int] = None,
) -> Tuple[EventBase, UnpersistedEventContextBase]:
) -> Tuple[EventBase, UnpersistedEventContextBase, Optional[dict]]:
"""
Given a dict from a client, create a new event. If bool for_batch is true, will
create an event using the prev_event_ids, and will create an event context for
@@ -649,7 +650,9 @@ class EventCreationHandler:
exceeded
Returns:
Tuple of created event, Context
Tuple of created event, Context, and an optional event dict to form the basis
of a new event if third_party_rules would like to send an additional event as a
consequence of this event.
"""
await self.auth_blocking.check_auth_blocking(requester=requester)
@@ -711,7 +714,7 @@ class EventCreationHandler:
builder.internal_metadata.historical = historical
event, unpersisted_context = await self.create_new_client_event(
event, unpersisted_context, new_event = await self.create_new_client_event(
builder=builder,
requester=requester,
allow_no_prev_events=allow_no_prev_events,
@@ -765,7 +768,7 @@ class EventCreationHandler:
)
self.validator.validate_new(event, self.config)
return event, unpersisted_context
return event, unpersisted_context, new_event
async def _is_exempt_from_privacy_policy(
self, builder: EventBuilder, requester: Requester
@@ -1005,7 +1008,11 @@ class EventCreationHandler:
max_retries = 5
for i in range(max_retries):
try:
event, unpersisted_context = await self.create_event(
(
event,
unpersisted_context,
third_party_event_dict,
) = await self.create_event(
requester,
event_dict,
txn_id=txn_id,
@@ -1054,9 +1061,24 @@ class EventCreationHandler:
Codes.FORBIDDEN,
)
events_and_context = [(event, context)]
if third_party_event_dict:
(
third_party_event,
unpersisted_third_party_context,
_,
) = await self.create_event(
requester,
third_party_event_dict,
)
third_party_context = await unpersisted_third_party_context.persist(
third_party_event
)
events_and_context.append((third_party_event, third_party_context))
ev = await self.handle_new_client_event(
requester=requester,
events_and_context=[(event, context)],
events_and_context=events_and_context,
ratelimit=ratelimit,
ignore_shadow_ban=ignore_shadow_ban,
)
@@ -1086,7 +1108,7 @@ class EventCreationHandler:
state_map: Optional[StateMap[str]] = None,
for_batch: bool = False,
current_state_group: Optional[int] = None,
) -> Tuple[EventBase, UnpersistedEventContextBase]:
) -> Tuple[EventBase, UnpersistedEventContextBase, Optional[dict]]:
"""Create a new event for a local client. If bool for_batch is true, will
create an event using the prev_event_ids, and will create an event context for
the event using the parameters state_map and current_state_group, thus these parameters
@@ -1135,7 +1157,9 @@ class EventCreationHandler:
batch persisting
Returns:
Tuple of created event, UnpersistedEventContext
Tuple of created event, UnpersistedEventContext, and an optional event dict
to form the basis of a new event if third_party_rules would like to send an
additional event as a consequence of this event.
"""
# Strip down the state_event_ids to only what we need to auth the event.
# For example, we don't need extra m.room.member that don't match event.sender
@@ -1269,9 +1293,11 @@ class EventCreationHandler:
if requester:
context.app_service = requester.app_service
res, new_content = await self.third_party_event_rules.check_event_allowed(
event, context
)
(
res,
new_content,
new_event,
) = await self.third_party_event_rules.check_event_allowed(event, context)
if res is False:
logger.info(
"Event %s forbidden by third-party rules",
@@ -1291,7 +1317,7 @@ class EventCreationHandler:
await self._validate_event_relation(event)
logger.debug("Created event %s", event.event_id)
return event, context
return event, context, new_event
async def _validate_event_relation(self, event: EventBase) -> None:
"""
@@ -2046,7 +2072,7 @@ class EventCreationHandler:
max_retries = 5
for i in range(max_retries):
try:
event, unpersisted_context = await self.create_event(
event, unpersisted_context, _ = await self.create_event(
requester,
{
"type": EventTypes.Dummy,

View File

@@ -213,6 +213,7 @@ class RoomCreationHandler:
(
tombstone_event,
tombstone_unpersisted_context,
_,
) = await self.event_creation_handler.create_event(
requester,
{
@@ -1066,7 +1067,11 @@ class RoomCreationHandler:
content: JsonDict,
for_batch: bool,
**kwargs: Any,
) -> Tuple[EventBase, synapse.events.snapshot.UnpersistedEventContextBase]:
) -> Tuple[
EventBase,
synapse.events.snapshot.UnpersistedEventContextBase,
Optional[dict],
]:
"""
Creates an event and associated event context.
Args:
@@ -1088,6 +1093,7 @@ class RoomCreationHandler:
(
new_event,
new_unpersisted_context,
third_party_event,
) = await self.event_creation_handler.create_event(
creator,
event_dict,
@@ -1103,7 +1109,7 @@ class RoomCreationHandler:
prev_event = [new_event.event_id]
state_map[(new_event.type, new_event.state_key)] = new_event.event_id
return new_event, new_unpersisted_context
return new_event, new_unpersisted_context, third_party_event
visibility = room_config.get("visibility", "private")
preset_config = room_config.get(
@@ -1121,7 +1127,7 @@ class RoomCreationHandler:
)
creation_content.update({"creator": creator_id})
creation_event, unpersisted_creation_context = await create_event(
creation_event, unpersisted_creation_context, _ = await create_event(
EventTypes.Create, creation_content, False
)
creation_context = await unpersisted_creation_context.persist(creation_event)
@@ -1161,14 +1167,17 @@ class RoomCreationHandler:
current_state_group = event_to_state[member_event_id]
events_to_send = []
third_party_events_to_append = []
# We treat the power levels override specially as this needs to be one
# of the first events that get sent into a room.
pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None)
if pl_content is not None:
power_event, power_context = await create_event(
power_event, power_context, power_tp_event = await create_event(
EventTypes.PowerLevels, pl_content, True
)
events_to_send.append((power_event, power_context))
if power_tp_event:
third_party_events_to_append.append(power_tp_event)
else:
power_level_content: JsonDict = {
"users": {creator_id: 100},
@@ -1211,76 +1220,114 @@ class RoomCreationHandler:
# apply those.
if power_level_content_override:
power_level_content.update(power_level_content_override)
pl_event, pl_context = await create_event(
pl_event, pl_context, pl_tp_event = await create_event(
EventTypes.PowerLevels,
power_level_content,
True,
)
events_to_send.append((pl_event, pl_context))
if pl_tp_event:
third_party_events_to_append.append(pl_tp_event)
if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state:
room_alias_event, room_alias_context = await create_event(
room_alias_event, room_alias_context, ra_tp_event = await create_event(
EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True
)
events_to_send.append((room_alias_event, room_alias_context))
if ra_tp_event:
third_party_events_to_append.append(ra_tp_event)
if (EventTypes.JoinRules, "") not in initial_state:
join_rules_event, join_rules_context = await create_event(
join_rules_event, join_rules_context, jr_tp_event = await create_event(
EventTypes.JoinRules,
{"join_rule": config["join_rules"]},
True,
)
events_to_send.append((join_rules_event, join_rules_context))
if jr_tp_event:
third_party_events_to_append.append(jr_tp_event)
if (EventTypes.RoomHistoryVisibility, "") not in initial_state:
visibility_event, visibility_context = await create_event(
visibility_event, visibility_context, vis_tp_event = await create_event(
EventTypes.RoomHistoryVisibility,
{"history_visibility": config["history_visibility"]},
True,
)
events_to_send.append((visibility_event, visibility_context))
if vis_tp_event:
third_party_events_to_append.append(vis_tp_event)
if config["guest_can_join"]:
if (EventTypes.GuestAccess, "") not in initial_state:
guest_access_event, guest_access_context = await create_event(
(
guest_access_event,
guest_access_context,
ga_tp_event,
) = await create_event(
EventTypes.GuestAccess,
{EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN},
True,
)
events_to_send.append((guest_access_event, guest_access_context))
if ga_tp_event:
third_party_events_to_append.append(ga_tp_event)
for (etype, state_key), content in initial_state.items():
event, context = await create_event(
event, context, tp_event = await create_event(
etype, content, True, state_key=state_key
)
events_to_send.append((event, context))
if tp_event:
third_party_events_to_append.append(tp_event)
if config["encrypted"]:
encryption_event, encryption_context = await create_event(
encryption_event, encryption_context, encrypt_tp_event = await create_event(
EventTypes.RoomEncryption,
{"algorithm": RoomEncryptionAlgorithms.DEFAULT},
True,
state_key="",
)
events_to_send.append((encryption_event, encryption_context))
if encrypt_tp_event:
third_party_events_to_append.append(encrypt_tp_event)
if "name" in room_config:
name = room_config["name"]
name_event, name_context = await create_event(
name_event, name_context, name_tp_event = await create_event(
EventTypes.Name,
{"name": name},
True,
)
events_to_send.append((name_event, name_context))
if name_tp_event:
third_party_events_to_append.append(name_tp_event)
if "topic" in room_config:
topic = room_config["topic"]
topic_event, topic_context = await create_event(
topic_event, topic_context, topic_tp_event = await create_event(
EventTypes.Topic,
{"topic": topic},
True,
)
events_to_send.append((topic_event, topic_context))
if topic_tp_event:
third_party_events_to_append.append(topic_tp_event)
for event_dict in third_party_events_to_append:
(
event,
unpersisted_context,
_,
) = await self.event_creation_handler.create_event(
creator,
event_dict,
prev_event_ids=prev_event,
state_map=state_map,
for_batch=True,
current_state_group=current_state_group,
)
context = await unpersisted_context.persist(event)
events_to_send.append((event, context))
datastore = self.hs.get_datastores().state
events_and_context = (

View File

@@ -327,7 +327,11 @@ class RoomBatchHandler:
# Mark all events as historical
event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True
event, unpersisted_context = await self.event_creation_handler.create_event(
(
event,
unpersisted_context,
_,
) = await self.event_creation_handler.create_event(
await self.create_requester_for_user_id_from_app_service(
ev["sender"], app_service_requester.app_service
),

View File

@@ -418,6 +418,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
(
event,
unpersisted_context,
third_party_event,
) = await self.event_creation_handler.create_event(
requester,
{
@@ -472,6 +473,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
ratelimit=ratelimit,
)
)
if third_party_event:
(
tp_event,
tp_unpersisted_context,
_,
) = await self.event_creation_handler.create_event(
requester,
third_party_event,
prev_event_ids=[result_event.event_id],
)
tp_context = await tp_unpersisted_context.persist(tp_event)
await self.event_creation_handler.handle_new_client_event(
requester, events_and_context=[(tp_event, tp_context)]
)
if event.membership == Membership.LEAVE:
if prev_member_event_id:
@@ -1951,6 +1966,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
(
event,
unpersisted_context,
third_party_event_dict,
) = await self.event_creation_handler.create_event(
requester,
event_dict,
@@ -1962,10 +1978,24 @@ class RoomMemberMasterHandler(RoomMemberHandler):
context = await unpersisted_context.persist(event)
event.internal_metadata.out_of_band_membership = True
events_and_context = [(event, context)]
if third_party_event_dict:
(
third_party_event,
third_party_unpersisted_context,
_,
) = await self.event_creation_handler.create_event(
requester, third_party_event_dict
)
third_party_context = await third_party_unpersisted_context.persist(
event
)
events_and_context.append((third_party_event, third_party_context))
result_event = (
await self.event_creation_handler.handle_new_client_event(
requester,
events_and_context=[(event, context)],
events_and_context=events_and_context,
extra_users=[UserID.from_string(target_user)],
)
)

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Tuple
from typing import Optional, Tuple
from twisted.test.proto_helpers import MemoryReactor
@@ -81,7 +81,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
def _create_duplicate_event(
self, txn_id: str
) -> Tuple[EventBase, UnpersistedEventContextBase]:
) -> Tuple[EventBase, UnpersistedEventContextBase, Optional[dict]]:
"""Create a new event with the given transaction ID. All events produced
by this method will be considered duplicates.
"""
@@ -109,7 +109,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
txn_id = "something_suitably_random"
event1, unpersisted_context = self._create_duplicate_event(txn_id)
event1, unpersisted_context, _ = self._create_duplicate_event(txn_id)
context = self.get_success(unpersisted_context.persist(event1))
ret_event1 = self.get_success(
@@ -122,7 +122,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.assertEqual(event1.event_id, ret_event1.event_id)
event2, unpersisted_context = self._create_duplicate_event(txn_id)
event2, unpersisted_context, _ = self._create_duplicate_event(txn_id)
context = self.get_success(unpersisted_context.persist(event2))
# We want to test that the deduplication at the persit event end works,
@@ -144,7 +144,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
# Let's test that calling `persist_event` directly also does the right
# thing.
event3, unpersisted_context = self._create_duplicate_event(txn_id)
event3, unpersisted_context, _ = self._create_duplicate_event(txn_id)
context = self.get_success(unpersisted_context.persist(event3))
self.assertNotEqual(event1.event_id, event3.event_id)
@@ -160,8 +160,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
# Let's test that calling `persist_events` directly also does the right
# thing.
event4, unpersisted_context = self._create_duplicate_event(txn_id)
event4, unpersisted_context, _ = self._create_duplicate_event(txn_id)
context = self.get_success(unpersisted_context.persist(event4))
self.assertNotEqual(event1.event_id, event3.event_id)
events, _ = self.get_success(
@@ -181,9 +182,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
txn_id = "something_else_suitably_random"
# Create two duplicate events to persist at the same time
event1, unpersisted_context1 = self._create_duplicate_event(txn_id)
event1, unpersisted_context1, _ = self._create_duplicate_event(txn_id)
context1 = self.get_success(unpersisted_context1.persist(event1))
event2, unpersisted_context2 = self._create_duplicate_event(txn_id)
event2, unpersisted_context2, _ = self._create_duplicate_event(txn_id)
context2 = self.get_success(unpersisted_context2.persist(event2))
# Ensure their event IDs are different to start with
@@ -209,7 +210,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
memberEvent, _ = self._create_and_persist_member_event()
# Try to create the event with empty prev_events bit with some auth_events
event, _ = self.get_success(
event, _, _ = self.get_success(
self.handler.create_event(
self.requester,
{

View File

@@ -507,7 +507,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Lower the permissions of the inviter.
event_creation_handler = self.hs.get_event_creation_handler()
requester = create_requester(inviter)
event, unpersisted_context = self.get_success(
event, unpersisted_context, _ = self.get_success(
event_creation_handler.create_event(
requester,
{

View File

@@ -965,7 +965,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
},
)
event, unpersisted_context = self.get_success(
event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)

View File

@@ -130,7 +130,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
# Create a new message event, and try to evaluate it under the dodgy
# power level event.
event, unpersisted_context = self.get_success(
event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_event(
self.requester,
{
@@ -171,7 +171,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
"""Ensure that push rules are not calculated when disabled in the config"""
# Create a new message event which should cause a notification.
event, unpersisted_context = self.get_success(
event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_event(
self.requester,
{
@@ -202,7 +202,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
) -> bool:
"""Returns true iff the `mentions` trigger an event push action."""
# Create a new message event which should cause a notification.
event, unpersisted_context = self.get_success(
event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_event(
self.requester,
{
@@ -378,7 +378,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
# Create & persist an event to use as the parent of the relation.
event, unpersisted_context = self.get_success(
event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_event(
self.requester,
{

View File

@@ -2935,7 +2935,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
},
)
event, unpersisted_context = self.get_success(
event, unpersisted_context, _ = self.get_success(
event_creation_handler.create_new_client_event(builder)
)

View File

@@ -275,6 +275,46 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
ev = channel.json_body
self.assertEqual(ev["content"]["x"], "y")
def test_add_event(self) -> None:
# needs checking of combo of return conditions, ie replace event and send event
async def check(
ev: EventBase, state: StateMap[EventBase]
) -> Tuple[bool, Optional[JsonDict], Optional[dict]]:
event_dict = {
"type": "m.room.test",
"room_id": self.room_id,
"sender": self.user_id,
"content": {
"creator": "test_user",
"body": "message",
"msgtype": "message",
},
}
if ev.type == "message":
return True, None, event_dict
else:
return True, None, None
self.hs.get_third_party_event_rules()._check_event_allowed_v2_callbacks = [
check
]
channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/send/message/1" % self.room_id,
{"x": "x"},
access_token=self.tok,
)
self.assertEqual(channel.code, 200, channel.result)
events = self.get_success(
self.hs.get_datastores().main.get_forward_extremities_for_room(self.room_id)
)
event = events[1]
e = self.get_success(self.hs.get_datastores().main.get_event(event["event_id"]))
self.assertEqual("m.room.test", e.type)
def test_message_edit(self) -> None:
"""Ensure that the module doesn't cause issues with edited messages."""

View File

@@ -522,7 +522,8 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
latest_event_ids = self.get_success(
self.store.get_prev_events_for_room(room_id)
)
event, unpersisted_context = self.get_success(
event, unpersisted_context, _ = self.get_success(
event_handler.create_event(
self.requester,
{
@@ -545,7 +546,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
assert state_ids1 is not None
state1 = set(state_ids1.values())
event, unpersisted_context = self.get_success(
event, unpersisted_context, _ = self.get_success(
event_handler.create_event(
self.requester,
{

View File

@@ -74,7 +74,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
},
)
event, unpersisted_context = self.get_success(
event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
@@ -98,7 +98,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
},
)
event, unpersisted_context = self.get_success(
event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
@@ -123,7 +123,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
},
)
event, unpersisted_context = self.get_success(
event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
@@ -265,7 +265,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
def internal_metadata(self) -> _EventInternalMetadata:
return self._base_builder.internal_metadata
event_1, unpersisted_context_1 = self.get_success(
event_1, unpersisted_context_1, _ = self.get_success(
self.event_creation_handler.create_new_client_event(
cast(
EventBuilder,
@@ -290,7 +290,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.get_success(self._persistence.persist_event(event_1, context_1))
event_2, unpersisted_context_2 = self.get_success(
event_2, unpersisted_context_2, _ = self.get_success(
self.event_creation_handler.create_new_client_event(
cast(
EventBuilder,
@@ -431,7 +431,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
},
)
redaction_event, unpersisted_context = self.get_success(
redaction_event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)

View File

@@ -67,7 +67,7 @@ class StateStoreTestCase(HomeserverTestCase):
},
)
event, unpersisted_context = self.get_success(
event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
@@ -521,7 +521,7 @@ class StateStoreTestCase(HomeserverTestCase):
},
)
event1, unpersisted_context1 = self.get_success(
event1, unpersisted_context1, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
events_and_context.append((event1, unpersisted_context1))
@@ -537,7 +537,7 @@ class StateStoreTestCase(HomeserverTestCase):
},
)
event2, unpersisted_context2 = self.get_success(
event2, unpersisted_context2, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder2)
)
events_and_context.append((event2, unpersisted_context2))
@@ -552,7 +552,7 @@ class StateStoreTestCase(HomeserverTestCase):
},
)
event3, unpersisted_context3 = self.get_success(
event3, unpersisted_context3, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder3)
)
events_and_context.append((event3, unpersisted_context3))
@@ -568,7 +568,7 @@ class StateStoreTestCase(HomeserverTestCase):
},
)
event4, unpersisted_context4 = self.get_success(
event4, unpersisted_context4, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder4)
)
events_and_context.append((event4, unpersisted_context4))

View File

@@ -95,6 +95,7 @@ async def create_event(
(
event,
unpersisted_context,
_,
) = await hs.get_event_creation_handler().create_new_client_event(
builder, prev_event_ids=prev_event_ids
)

View File

@@ -207,7 +207,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
},
)
event, unpersisted_context = self.get_success(
event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
context = self.get_success(unpersisted_context.persist(event))
@@ -233,7 +233,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
},
)
event, unpersisted_context = self.get_success(
event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
context = self.get_success(unpersisted_context.persist(event))
@@ -256,7 +256,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
},
)
event, unpersisted_context = self.get_success(
event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
context = self.get_success(unpersisted_context.persist(event))

View File

@@ -723,7 +723,7 @@ class HomeserverTestCase(TestCase):
event_creator = self.hs.get_event_creation_handler()
requester = create_requester(user)
event, unpersisted_context = self.get_success(
event, unpersisted_context, _ = self.get_success(
event_creator.create_event(
requester,
{

View File

@@ -335,9 +335,11 @@ async def create_room(hs: HomeServer, room_id: str, creator_id: str) -> None:
},
)
event, unpersisted_context = await event_creation_handler.create_new_client_event(
builder
)
(
event,
unpersisted_context,
_,
) = await event_creation_handler.create_new_client_event(builder)
context = await unpersisted_context.persist(event)
await persistence_store.persist_event(event, context)