Support for room version 12

This commit is contained in:
Kegan Dougal
2025-05-13 10:58:01 +01:00
committed by Andrew Morgan
parent edac7a471f
commit 731e81c9a3
28 changed files with 2174 additions and 189 deletions

View File

@@ -4094,7 +4094,7 @@ The default power levels for each preset are:
"m.room.history_visibility": 100
"m.room.canonical_alias": 50
"m.room.avatar": 50
"m.room.tombstone": 100
"m.room.tombstone": 100 (150 if MSC4289 is used)
"m.room.server_acl": 100
"m.room.encryption": 100
```

View File

@@ -5109,7 +5109,7 @@ properties:
"m.room.avatar": 50
"m.room.tombstone": 100
"m.room.tombstone": 100 (150 if MSC4289 is used)
"m.room.server_acl": 100

View File

@@ -46,6 +46,9 @@ MAX_USERID_LENGTH = 255
# Constant value used for the pseudo-thread which is the main timeline.
MAIN_TIMELINE: Final = "main"
# MAX_INT + 1, so it always trumps any PL in canonical JSON.
CREATOR_POWER_LEVEL = 2**53
class Membership:
"""Represents the membership states of a user in a room."""
@@ -235,6 +238,8 @@ class EventContentFields:
#
# This is deprecated in MSC2175.
ROOM_CREATOR: Final = "creator"
# MSC4289
ADDITIONAL_CREATORS: Final = "additional_creators"
# The version of the room for `m.room.create` events.
ROOM_VERSION: Final = "room_version"

View File

@@ -36,12 +36,14 @@ class EventFormatVersions:
ROOM_V1_V2 = 1 # $id:server event id format: used for room v1 and v2
ROOM_V3 = 2 # MSC1659-style $hash event id format: used for room v3
ROOM_V4_PLUS = 3 # MSC1884-style $hash format: introduced for room v4
ROOM_V11_HYDRA_PLUS = 4 # MSC4291 room IDs as hashes: introduced for room HydraV11
KNOWN_EVENT_FORMAT_VERSIONS = {
EventFormatVersions.ROOM_V1_V2,
EventFormatVersions.ROOM_V3,
EventFormatVersions.ROOM_V4_PLUS,
EventFormatVersions.ROOM_V11_HYDRA_PLUS,
}
@@ -50,6 +52,7 @@ class StateResolutionVersions:
V1 = 1 # room v1 state res
V2 = 2 # MSC1442 state res: room v2 and later
V2_1 = 3 # MSC4297 state res
class RoomDisposition:
@@ -109,6 +112,10 @@ class RoomVersion:
msc3931_push_features: Tuple[str, ...] # values from PushRuleRoomFlag
# MSC3757: Restricting who can overwrite a state event
msc3757_enabled: bool
# MSC4289: Creator power enabled
msc4289_creator_power_enabled: bool
# MSC4291: Room IDs as hashes of the create event
msc4291_room_ids_as_hashes: bool
class RoomVersions:
@@ -131,6 +138,8 @@ class RoomVersions:
enforce_int_power_levels=False,
msc3931_push_features=(),
msc3757_enabled=False,
msc4289_creator_power_enabled=False,
msc4291_room_ids_as_hashes=False,
)
V2 = RoomVersion(
"2",
@@ -151,6 +160,8 @@ class RoomVersions:
enforce_int_power_levels=False,
msc3931_push_features=(),
msc3757_enabled=False,
msc4289_creator_power_enabled=False,
msc4291_room_ids_as_hashes=False,
)
V3 = RoomVersion(
"3",
@@ -171,6 +182,8 @@ class RoomVersions:
enforce_int_power_levels=False,
msc3931_push_features=(),
msc3757_enabled=False,
msc4289_creator_power_enabled=False,
msc4291_room_ids_as_hashes=False,
)
V4 = RoomVersion(
"4",
@@ -191,6 +204,8 @@ class RoomVersions:
enforce_int_power_levels=False,
msc3931_push_features=(),
msc3757_enabled=False,
msc4289_creator_power_enabled=False,
msc4291_room_ids_as_hashes=False,
)
V5 = RoomVersion(
"5",
@@ -211,6 +226,8 @@ class RoomVersions:
enforce_int_power_levels=False,
msc3931_push_features=(),
msc3757_enabled=False,
msc4289_creator_power_enabled=False,
msc4291_room_ids_as_hashes=False,
)
V6 = RoomVersion(
"6",
@@ -231,6 +248,8 @@ class RoomVersions:
enforce_int_power_levels=False,
msc3931_push_features=(),
msc3757_enabled=False,
msc4289_creator_power_enabled=False,
msc4291_room_ids_as_hashes=False,
)
V7 = RoomVersion(
"7",
@@ -251,6 +270,8 @@ class RoomVersions:
enforce_int_power_levels=False,
msc3931_push_features=(),
msc3757_enabled=False,
msc4289_creator_power_enabled=False,
msc4291_room_ids_as_hashes=False,
)
V8 = RoomVersion(
"8",
@@ -271,6 +292,8 @@ class RoomVersions:
enforce_int_power_levels=False,
msc3931_push_features=(),
msc3757_enabled=False,
msc4289_creator_power_enabled=False,
msc4291_room_ids_as_hashes=False,
)
V9 = RoomVersion(
"9",
@@ -291,6 +314,8 @@ class RoomVersions:
enforce_int_power_levels=False,
msc3931_push_features=(),
msc3757_enabled=False,
msc4289_creator_power_enabled=False,
msc4291_room_ids_as_hashes=False,
)
V10 = RoomVersion(
"10",
@@ -311,6 +336,8 @@ class RoomVersions:
enforce_int_power_levels=True,
msc3931_push_features=(),
msc3757_enabled=False,
msc4289_creator_power_enabled=False,
msc4291_room_ids_as_hashes=False,
)
MSC1767v10 = RoomVersion(
# MSC1767 (Extensible Events) based on room version "10"
@@ -332,6 +359,8 @@ class RoomVersions:
enforce_int_power_levels=True,
msc3931_push_features=(PushRuleRoomFlag.EXTENSIBLE_EVENTS,),
msc3757_enabled=False,
msc4289_creator_power_enabled=False,
msc4291_room_ids_as_hashes=False,
)
MSC3757v10 = RoomVersion(
# MSC3757 (Restricting who can overwrite a state event) based on room version "10"
@@ -353,6 +382,8 @@ class RoomVersions:
enforce_int_power_levels=True,
msc3931_push_features=(),
msc3757_enabled=True,
msc4289_creator_power_enabled=False,
msc4291_room_ids_as_hashes=False,
)
V11 = RoomVersion(
"11",
@@ -373,6 +404,8 @@ class RoomVersions:
enforce_int_power_levels=True,
msc3931_push_features=(),
msc3757_enabled=False,
msc4289_creator_power_enabled=False,
msc4291_room_ids_as_hashes=False,
)
MSC3757v11 = RoomVersion(
# MSC3757 (Restricting who can overwrite a state event) based on room version "11"
@@ -394,6 +427,52 @@ class RoomVersions:
enforce_int_power_levels=True,
msc3931_push_features=(),
msc3757_enabled=True,
msc4289_creator_power_enabled=False,
msc4291_room_ids_as_hashes=False,
)
HydraV11 = RoomVersion(
"org.matrix.hydra.11",
RoomDisposition.UNSTABLE,
EventFormatVersions.ROOM_V11_HYDRA_PLUS,
StateResolutionVersions.V2_1, # Changed from v11
enforce_key_validity=True,
special_case_aliases_auth=False,
strict_canonicaljson=True,
limit_notifications_power_levels=True,
implicit_room_creator=True, # Used by MSC3820
updated_redaction_rules=True, # Used by MSC3820
restricted_join_rule=True,
restricted_join_rule_fix=True,
knock_join_rule=True,
msc3389_relation_redactions=False,
knock_restricted_join_rule=True,
enforce_int_power_levels=True,
msc3931_push_features=(),
msc3757_enabled=False,
msc4289_creator_power_enabled=True, # Changed from v11
msc4291_room_ids_as_hashes=True, # Changed from v11
)
V12 = RoomVersion(
"12",
RoomDisposition.STABLE,
EventFormatVersions.ROOM_V11_HYDRA_PLUS,
StateResolutionVersions.V2_1, # Changed from v11
enforce_key_validity=True,
special_case_aliases_auth=False,
strict_canonicaljson=True,
limit_notifications_power_levels=True,
implicit_room_creator=True, # Used by MSC3820
updated_redaction_rules=True, # Used by MSC3820
restricted_join_rule=True,
restricted_join_rule_fix=True,
knock_join_rule=True,
msc3389_relation_redactions=False,
knock_restricted_join_rule=True,
enforce_int_power_levels=True,
msc3931_push_features=(),
msc3757_enabled=False,
msc4289_creator_power_enabled=True, # Changed from v11
msc4291_room_ids_as_hashes=True, # Changed from v11
)
@@ -411,6 +490,7 @@ KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = {
RoomVersions.V9,
RoomVersions.V10,
RoomVersions.V11,
RoomVersions.V12,
RoomVersions.MSC3757v10,
RoomVersions.MSC3757v11,
)

View File

@@ -101,6 +101,9 @@ def compute_content_hash(
event_dict.pop("outlier", None)
event_dict.pop("destinations", None)
# N.B. no need to pop the room_id from create events in MSC4291 rooms
# as they shouldn't have one.
event_json_bytes = encode_canonical_json(event_dict)
hashed = hash_algorithm(event_json_bytes)

View File

@@ -45,6 +45,7 @@ from signedjson.sign import SignatureVerifyException, verify_signed_json
from unpaddedbase64 import decode_base64
from synapse.api.constants import (
CREATOR_POWER_LEVEL,
MAX_PDU_SIZE,
EventContentFields,
EventTypes,
@@ -65,6 +66,7 @@ from synapse.api.room_versions import (
RoomVersions,
)
from synapse.state import CREATE_KEY
from synapse.events import is_creator
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.types import (
MutableStateMap,
@@ -261,7 +263,8 @@ async def check_state_independent_auth_rules(
f"Event {event.event_id} has unexpected auth_event for {k}: {auth_event_id}",
)
# We also need to check that the auth event itself is not rejected.
# 2.3 ... If there are entries which were themselves rejected under the checks performed on receipt
# of a PDU, reject.
if auth_event.rejected_reason:
raise AuthError(
403,
@@ -271,7 +274,7 @@ async def check_state_independent_auth_rules(
auth_dict[k] = auth_event_id
# 3. If event does not have a m.room.create in its auth_events, reject.
# 2.4. If event does not have a m.room.create in its auth_events, reject.
creation_event = auth_dict.get((EventTypes.Create, ""), None)
if not creation_event:
raise AuthError(403, "No create event in auth events")
@@ -311,13 +314,14 @@ def check_state_dependent_auth_rules(
# Later code relies on there being a create event e.g _can_federate, _is_membership_change_allowed
# so produce a more intelligible error if we don't have one.
if auth_dict.get(CREATE_KEY) is None:
create_event = auth_dict.get(CREATE_KEY)
if create_event is None:
raise AuthError(
403, f"Event {event.event_id} is missing a create event in auth_events."
)
# additional check for m.federate
creating_domain = get_domain_from_id(event.room_id)
creating_domain = get_domain_from_id(create_event.sender)
originating_domain = get_domain_from_id(event.sender)
if creating_domain != originating_domain:
if not _can_federate(event, auth_dict):
@@ -470,12 +474,20 @@ def _check_create(event: "EventBase") -> None:
if event.prev_event_ids():
raise AuthError(403, "Create event has prev events")
# 1.2 If the domain of the room_id does not match the domain of the sender,
# reject.
sender_domain = get_domain_from_id(event.sender)
room_id_domain = get_domain_from_id(event.room_id)
if room_id_domain != sender_domain:
raise AuthError(403, "Creation event's room_id domain does not match sender's")
if event.room_version.msc4291_room_ids_as_hashes:
# 1.2 If the create event has a room_id, reject
if "room_id" in event:
raise AuthError(403, "Create event has a room_id")
else:
# 1.2 If the domain of the room_id does not match the domain of the sender,
# reject.
if not event.room_version.msc4291_room_ids_as_hashes:
sender_domain = get_domain_from_id(event.sender)
room_id_domain = get_domain_from_id(event.room_id)
if room_id_domain != sender_domain:
raise AuthError(
403, "Creation event's room_id domain does not match sender's"
)
# 1.3 If content.room_version is present and is not a recognised version, reject
room_version_prop = event.content.get("room_version", "1")
@@ -492,6 +504,16 @@ def _check_create(event: "EventBase") -> None:
):
raise AuthError(403, "Create event lacks a 'creator' property")
# 1.5 If the additional_creators field is present and is not an array of strings where each
# string is a valid user ID, reject.
if (
event.room_version.msc4289_creator_power_enabled
and EventContentFields.ADDITIONAL_CREATORS in event.content
):
check_valid_additional_creators(
event.content[EventContentFields.ADDITIONAL_CREATORS]
)
def _can_federate(event: "EventBase", auth_events: StateMap["EventBase"]) -> bool:
creation_event = auth_events.get((EventTypes.Create, ""))
@@ -533,7 +555,13 @@ def _is_membership_change_allowed(
target_user_id = event.state_key
creating_domain = get_domain_from_id(event.room_id)
# We need the create event in order to check if we can federate or not.
# If it's missing, yell loudly. Previously we only did this inside the
# _can_federate check.
create_event = auth_events.get((EventTypes.Create, ""))
if not create_event:
raise AuthError(403, "Create event missing from auth_events")
creating_domain = get_domain_from_id(create_event.sender)
target_domain = get_domain_from_id(target_user_id)
if creating_domain != target_domain:
if not _can_federate(event, auth_events):
@@ -903,6 +931,32 @@ def _check_power_levels(
except Exception:
raise SynapseError(400, "Not a valid power level: %s" % (v,))
if room_version_obj.msc4289_creator_power_enabled:
# Enforce the creator does not appear in the users map
create_event = auth_events.get((EventTypes.Create, ""))
if not create_event:
raise SynapseError(
400, "Cannot check power levels without a create event in auth_events"
)
if create_event.sender in user_list:
raise SynapseError(
400,
"Creator user %s must not appear in content.users"
% (create_event.sender,),
)
additional_creators = create_event.content.get(
EventContentFields.ADDITIONAL_CREATORS, []
)
if additional_creators:
creators_in_user_list = set(additional_creators).intersection(
set(user_list)
)
if len(creators_in_user_list) > 0:
raise SynapseError(
400,
"Additional creators users must not appear in content.users",
)
# Reject events with stringy power levels if required by room version
if (
event.type == EventTypes.PowerLevels
@@ -1028,6 +1082,9 @@ def get_user_power_level(user_id: str, auth_events: StateMap["EventBase"]) -> in
"A create event in the auth events chain is required to calculate user power level correctly,"
" but was not found. This indicates a bug"
)
if create_event.room_version.msc4289_creator_power_enabled:
if is_creator(create_event, user_id):
return CREATOR_POWER_LEVEL
power_level_event = get_power_level_event(auth_events)
if power_level_event:
level = power_level_event.content.get("users", {}).get(user_id)
@@ -1188,3 +1245,26 @@ def auth_types_for_event(
auth_types.add(key)
return auth_types
def check_valid_additional_creators(additional_creators: Any) -> None:
"""Check if the additional_creators provided is valid according to MSC4289.
The additional_creators can be supplied from an m.room.create event or from an /upgrade request.
Raises:
AuthError if the additional_creators is invalid for some reason.
"""
if type(additional_creators) is not list:
raise AuthError(400, "additional_creators must be an array")
for entry in additional_creators:
if type(entry) is not str:
raise AuthError(400, "entry in additional_creators is not a string")
if not UserID.is_valid(entry):
raise AuthError(400, "entry in additional_creators is not a valid user ID")
# UserID.is_valid doesn't actually validate everything, so check the rest manually.
if len(entry) > 255 or len(entry.encode("utf-8")) > 255:
raise AuthError(
400,
"entry in additional_creators too long",
)

View File

@@ -41,10 +41,13 @@ from typing import (
import attr
from unpaddedbase64 import encode_base64
from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
from synapse.synapse_rust.events import EventInternalMetadata
from synapse.types import JsonDict, StrCollection
from synapse.types import (
JsonDict,
StrCollection,
)
from synapse.util.caches import intern_dict
from synapse.util.frozenutils import freeze
@@ -209,7 +212,6 @@ class EventBase(metaclass=abc.ABCMeta):
content: DictProperty[JsonDict] = DictProperty("content")
hashes: DictProperty[Dict[str, str]] = DictProperty("hashes")
origin_server_ts: DictProperty[int] = DictProperty("origin_server_ts")
room_id: DictProperty[str] = DictProperty("room_id")
sender: DictProperty[str] = DictProperty("sender")
# TODO state_key should be Optional[str]. This is generally asserted in Synapse
# by calling is_state() first (which ensures it is not None), but it is hard (not possible?)
@@ -224,6 +226,10 @@ class EventBase(metaclass=abc.ABCMeta):
def event_id(self) -> str:
raise NotImplementedError()
@property
def room_id(self) -> str:
raise NotImplementedError()
@property
def membership(self) -> str:
return self.content["membership"]
@@ -386,6 +392,10 @@ class FrozenEvent(EventBase):
def event_id(self) -> str:
return self._event_id
@property
def room_id(self) -> str:
return self._dict["room_id"]
class FrozenEventV2(EventBase):
format_version = EventFormatVersions.ROOM_V3 # All events of this type are V2
@@ -443,6 +453,10 @@ class FrozenEventV2(EventBase):
self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1])
return self._event_id
@property
def room_id(self) -> str:
return self._dict["room_id"]
def prev_event_ids(self) -> List[str]:
"""Returns the list of prev event IDs. The order matches the order
specified in the event, though there is no meaning to it.
@@ -481,6 +495,67 @@ class FrozenEventV3(FrozenEventV2):
return self._event_id
class FrozenEventV4(FrozenEventV3):
"""FrozenEventV4 for MSC4291 room IDs are hashes"""
format_version = EventFormatVersions.ROOM_V11_HYDRA_PLUS
"""Override the room_id for m.room.create events"""
def __init__(
self,
event_dict: JsonDict,
room_version: RoomVersion,
internal_metadata_dict: Optional[JsonDict] = None,
rejected_reason: Optional[str] = None,
):
super().__init__(
event_dict=event_dict,
room_version=room_version,
internal_metadata_dict=internal_metadata_dict,
rejected_reason=rejected_reason,
)
self._room_id: Optional[str] = None
@property
def room_id(self) -> str:
# if we have calculated the room ID already, don't do it again.
if self._room_id:
return self._room_id
is_create_event = self.type == EventTypes.Create and self.get_state_key() == ""
# for non-create events: use the supplied value from the JSON, as per FrozenEventV3
if not is_create_event:
self._room_id = self._dict["room_id"]
assert self._room_id is not None
return self._room_id
# for create events: calculate the room ID
from synapse.crypto.event_signing import compute_event_reference_hash
self._room_id = "!" + encode_base64(
compute_event_reference_hash(self)[1], urlsafe=True
)
return self._room_id
def auth_event_ids(self) -> StrCollection:
"""Returns the list of auth event IDs. The order matches the order
specified in the event, though there is no meaning to it.
Returns:
The list of event IDs of this event's auth_events
Includes the creation event ID for convenience of all the codepaths
which expects the auth chain to include the creator ID, even though
it's explicitly not included on the wire. Excludes the create event
for the create event itself.
"""
create_event_id = "$" + self.room_id[1:]
assert create_event_id not in self._dict["auth_events"]
if self.type == EventTypes.Create and self.get_state_key() == "":
return self._dict["auth_events"] # should be []
return self._dict["auth_events"] + [create_event_id]
def _event_type_from_format_version(
format_version: int,
) -> Type[Union[FrozenEvent, FrozenEventV2, FrozenEventV3]]:
@@ -500,6 +575,8 @@ def _event_type_from_format_version(
return FrozenEventV2
elif format_version == EventFormatVersions.ROOM_V4_PLUS:
return FrozenEventV3
elif format_version == EventFormatVersions.ROOM_V11_HYDRA_PLUS:
return FrozenEventV4
else:
raise Exception("No event format %r" % (format_version,))
@@ -559,6 +636,23 @@ def relation_from_event(event: EventBase) -> Optional[_EventRelation]:
return _EventRelation(parent_id, rel_type, aggregation_key)
def is_creator(create: EventBase, user_id: str) -> bool:
"""
Return true if the provided user ID is the room creator.
This includes additional creators in MSC4289.
"""
assert create.type == EventTypes.Create
if create.sender == user_id:
return True
if create.room_version.msc4289_creator_power_enabled:
additional_creators = set(
create.content.get(EventContentFields.ADDITIONAL_CREATORS, [])
)
return user_id in additional_creators
return False
@attr.s(slots=True, frozen=True, auto_attribs=True)
class StrippedStateEvent:
"""

View File

@@ -82,7 +82,8 @@ class EventBuilder:
room_version: RoomVersion
room_id: str
# MSC4291 makes the room ID == the create event ID. This means the create event has no room_id.
room_id: Optional[str]
type: str
sender: str
@@ -142,7 +143,14 @@ class EventBuilder:
Returns:
The signed and hashed event.
"""
# Create events always have empty auth_events.
if self.type == EventTypes.Create and self.is_state() and self.state_key == "":
auth_event_ids = []
# Calculate auth_events for non-create events
if auth_event_ids is None:
# Every non-create event must have a room ID
assert self.room_id is not None
state_ids = await self._state.compute_state_after_events(
self.room_id,
prev_event_ids,
@@ -224,12 +232,31 @@ class EventBuilder:
"auth_events": auth_events,
"prev_events": prev_events,
"type": self.type,
"room_id": self.room_id,
"sender": self.sender,
"content": self.content,
"unsigned": self.unsigned,
"depth": depth,
}
if self.room_id is not None:
event_dict["room_id"] = self.room_id
if self.room_version.msc4291_room_ids_as_hashes:
# In MSC4291: the create event has no room ID as the create event ID /is/ the room ID.
if (
self.type == EventTypes.Create
and self.is_state()
and self._state_key == ""
):
assert self.room_id is None
else:
# All other events do not reference the create event in auth_events, as the room ID
# /is/ the create event. However, the rest of the code (for consistency between room
# versions) assume that the create event remains part of the auth events. c.f. event
# class which automatically adds the create event when `.auth_event_ids()` is called
assert self.room_id is not None
create_event_id = "$" + self.room_id[1:]
auth_event_ids.remove(create_event_id)
event_dict["auth_events"] = auth_event_ids
if self.is_state():
event_dict["state_key"] = self._state_key
@@ -285,7 +312,7 @@ class EventBuilderFactory:
room_version=room_version,
type=key_values["type"],
state_key=key_values.get("state_key"),
room_id=key_values["room_id"],
room_id=key_values.get("room_id"),
sender=key_values["sender"],
content=key_values.get("content", {}),
unsigned=key_values.get("unsigned", {}),

View File

@@ -176,9 +176,12 @@ def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDic
if room_version.updated_redaction_rules:
# MSC2176 rules state that create events cannot have their `content` redacted.
new_content = event_dict["content"]
elif not room_version.implicit_room_creator:
if not room_version.implicit_room_creator:
# Some room versions give meaning to `creator`
add_fields("creator")
if room_version.msc4291_room_ids_as_hashes:
# room_id is not allowed on the create event as it's derived from the event ID
allowed_keys.remove("room_id")
elif event_type == EventTypes.JoinRules:
add_fields("join_rule")
@@ -527,6 +530,10 @@ def serialize_event(
if config.as_client_event:
d = config.event_format(d)
# Ensure the room_id field is set for create events in MSC4291 rooms
if e.type == EventTypes.Create and e.room_version.msc4291_room_ids_as_hashes:
d["room_id"] = e.room_id
# If the event is a redaction, the field with the redacted event ID appears
# in a different location depending on the room version. e.redacts handles
# fetching from the proper location; copy it to the other location for forwards-
@@ -869,6 +876,14 @@ def strip_event(event: EventBase) -> JsonDict:
Stripped state events can only have the `sender`, `type`, `state_key` and `content`
properties present.
"""
# MSC4311: Ensure the create event is available on invites and knocks.
# TODO: Implement the rest of MSC4311
if (
event.room_version.msc4291_room_ids_as_hashes
and event.type == EventTypes.Create
and event.get_state_key() == ""
):
return event.get_pdu_json()
return {
"type": event.type,

View File

@@ -183,8 +183,18 @@ class EventValidator:
fields an event would have
"""
create_event_as_room_id = (
event.room_version.msc4291_room_ids_as_hashes
and event.type == EventTypes.Create
and hasattr(event, "state_key")
and event.state_key == ""
)
strings = ["room_id", "sender", "type"]
if create_event_as_room_id:
strings.remove("room_id")
if hasattr(event, "state_key"):
strings.append("state_key")
@@ -192,7 +202,14 @@ class EventValidator:
if not isinstance(getattr(event, s), str):
raise SynapseError(400, "Not '%s' a string type" % (s,))
RoomID.from_string(event.room_id)
if not create_event_as_room_id:
assert event.room_id is not None
RoomID.from_string(event.room_id)
if event.room_version.msc4291_room_ids_as_hashes and not RoomID.is_valid(
event.room_id
):
raise SynapseError(400, f"Invalid room ID '{event.room_id}'")
UserID.from_string(event.sender)
if event.type == EventTypes.Message:

View File

@@ -342,6 +342,21 @@ def event_from_pdu_json(pdu_json: JsonDict, room_version: RoomVersion) -> EventB
if room_version.strict_canonicaljson:
validate_canonicaljson(pdu_json)
# enforce that MSC4291 auth events don't include the create event.
# N.B. if they DO include a spurious create event, it'll fail auth checks elsewhere, so we don't
# need to do expensive DB lookups to find which event ID is the create event here.
if room_version.msc4291_room_ids_as_hashes:
room_id = pdu_json.get("room_id")
if room_id:
create_event_id = "$" + room_id[1:]
auth_events = pdu_json.get("auth_events")
if auth_events:
if create_event_id in auth_events:
raise SynapseError(
400,
"auth_events must not contain the create event",
Codes.BAD_JSON,
)
event = make_event_from_dict(pdu_json, room_version)
return event

View File

@@ -23,6 +23,8 @@ from typing import TYPE_CHECKING, List, Mapping, Optional, Union
from synapse import event_auth
from synapse.api.constants import (
CREATOR_POWER_LEVEL,
EventContentFields,
EventTypes,
JoinRules,
Membership,
@@ -141,6 +143,8 @@ class EventAuthHandler:
Raises:
SynapseError if no appropriate user is found.
"""
create_event_id = current_state_ids[(EventTypes.Create, "")]
create_event = await self._store.get_event(create_event_id)
power_level_event_id = current_state_ids.get((EventTypes.PowerLevels, ""))
invite_level = 0
users_default_level = 0
@@ -156,15 +160,28 @@ class EventAuthHandler:
# Find the user with the highest power level (only interested in local
# users).
user_power_level = 0
chosen_user = None
local_users_in_room = await self._store.get_local_users_in_room(room_id)
chosen_user = max(
local_users_in_room,
key=lambda user: users.get(user, users_default_level),
default=None,
)
if create_event.room_version.msc4289_creator_power_enabled:
creators = set(
create_event.content.get(EventContentFields.ADDITIONAL_CREATORS, [])
)
creators.add(create_event.sender)
local_creators = creators.intersection(set(local_users_in_room))
if len(local_creators) > 0:
chosen_user = local_creators.pop() # random creator
user_power_level = CREATOR_POWER_LEVEL
else:
chosen_user = max(
local_users_in_room,
key=lambda user: users.get(user, users_default_level),
default=None,
)
# Return the chosen if they can issue invites.
if chosen_user:
user_power_level = users.get(chosen_user, users_default_level)
# Return the chosen if they can issue invites.
user_power_level = users.get(chosen_user, users_default_level)
if chosen_user and user_power_level >= invite_level:
logger.debug(
"Found a user who can issue invites %s with power level %d >= invite level %d",

View File

@@ -674,7 +674,10 @@ class EventCreationHandler:
Codes.USER_ACCOUNT_SUSPENDED,
)
if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
is_create_event = (
event_dict["type"] == EventTypes.Create and event_dict["state_key"] == ""
)
if is_create_event:
room_version_id = event_dict["content"]["room_version"]
maybe_room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id)
if not maybe_room_version_obj:
@@ -780,6 +783,7 @@ class EventCreationHandler:
"""
# the only thing the user can do is join the server notices room.
if builder.type == EventTypes.Member:
assert builder.room_id is not None
membership = builder.content.get("membership", None)
if membership == Membership.JOIN:
return await self.store.is_server_notice_room(builder.room_id)
@@ -1242,13 +1246,40 @@ class EventCreationHandler:
for_verification=False,
)
if (
builder.room_version.msc4291_room_ids_as_hashes
and builder.type == EventTypes.Create
and builder.is_state()
):
if builder.room_id is not None:
raise SynapseError(
400,
"Cannot resend m.room.create event",
Codes.INVALID_PARAM,
)
else:
assert builder.room_id is not None
if prev_event_ids is not None:
assert len(prev_event_ids) <= 10, (
"Attempting to create an event with %i prev_events"
% (len(prev_event_ids),)
)
else:
prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id)
if builder.room_id:
prev_event_ids = await self.store.get_prev_events_for_room(
builder.room_id
)
else:
prev_event_ids = [] # can only happen for the create event in MSC4291 rooms
if builder.type == EventTypes.Create and builder.is_state():
if len(prev_event_ids) != 0:
raise SynapseError(
400,
"Cannot resend m.room.create event",
Codes.INVALID_PARAM,
)
# We now ought to have some `prev_events` (unless it's a create event).
#
@@ -2124,6 +2155,7 @@ class EventCreationHandler:
original_event.room_version, third_party_result
)
self.validator.validate_builder(builder)
assert builder.room_id is not None
except SynapseError as e:
raise Exception(
"Third party rules module created an invalid event: " + e.msg,

View File

@@ -81,6 +81,7 @@ from synapse.types import (
Requester,
RoomAlias,
RoomID,
RoomIdWithDomain,
RoomStreamToken,
StateMap,
StrCollection,
@@ -188,7 +189,11 @@ class RoomCreationHandler:
)
async def upgrade_room(
self, requester: Requester, old_room_id: str, new_version: RoomVersion
self,
requester: Requester,
old_room_id: str,
new_version: RoomVersion,
additional_creators: Optional[List[str]],
) -> str:
"""Replace a room with a new room with a different version
@@ -196,6 +201,7 @@ class RoomCreationHandler:
requester: the user requesting the upgrade
old_room_id: the id of the room to be replaced
new_version: the new room version to use
additional_creators: additional room creators, for MSC4289.
Returns:
the new room id
@@ -224,8 +230,29 @@ class RoomCreationHandler:
old_room = await self.store.get_room(old_room_id)
if old_room is None:
raise NotFoundError("Unknown room id %s" % (old_room_id,))
old_room_is_public, _ = old_room
new_room_id = self._generate_room_id()
creation_event_with_context = None
if new_version.msc4291_room_ids_as_hashes:
old_room_create_event = await self.store.get_create_event_for_room(
old_room_id
)
creation_content = self._calculate_upgraded_room_creation_content(
old_room_create_event,
tombstone_event_id=None,
new_room_version=new_version,
additional_creators=additional_creators,
)
creation_event_with_context = await self._generate_create_event_for_room_id(
requester,
creation_content,
old_room_is_public,
new_version,
)
(create_event, _) = creation_event_with_context
new_room_id = create_event.room_id
else:
new_room_id = self._generate_room_id()
# Try several times, it could fail with PartialStateConflictError
# in _upgrade_room, cf comment in except block.
@@ -274,6 +301,8 @@ class RoomCreationHandler:
new_version,
tombstone_event,
tombstone_context,
additional_creators,
creation_event_with_context,
)
return ret
@@ -297,6 +326,10 @@ class RoomCreationHandler:
new_version: RoomVersion,
tombstone_event: EventBase,
tombstone_context: synapse.events.snapshot.EventContext,
additional_creators: Optional[List[str]],
creation_event_with_context: Optional[
Tuple[EventBase, synapse.events.snapshot.EventContext]
] = None,
) -> str:
"""
Args:
@@ -308,6 +341,8 @@ class RoomCreationHandler:
new_version: the version to upgrade the room to
tombstone_event: the tombstone event to send to the old room
tombstone_context: the context for the tombstone event
additional_creators: additional room creators, for MSC4289.
creation_event_with_context: The new room's create event, for room IDs as create event IDs.
Raises:
ShadowBanError if the requester is shadow-banned.
@@ -317,14 +352,16 @@ class RoomCreationHandler:
logger.info("Creating new room %s to replace %s", new_room_id, old_room_id)
# create the new room. may raise a `StoreError` in the exceedingly unlikely
# event of a room ID collision.
await self.store.store_room(
room_id=new_room_id,
room_creator_user_id=user_id,
is_public=old_room[0],
room_version=new_version,
)
# We've already stored the room if we have the create event
if not creation_event_with_context:
# create the new room. may raise a `StoreError` in the exceedingly unlikely
# event of a room ID collision.
await self.store.store_room(
room_id=new_room_id,
room_creator_user_id=user_id,
is_public=old_room[0],
room_version=new_version,
)
await self.clone_existing_room(
requester,
@@ -332,6 +369,8 @@ class RoomCreationHandler:
new_room_id=new_room_id,
new_room_version=new_version,
tombstone_event_id=tombstone_event.event_id,
additional_creators=additional_creators,
creation_event_with_context=creation_event_with_context,
)
# now send the tombstone
@@ -365,6 +404,7 @@ class RoomCreationHandler:
old_room_id,
new_room_id,
old_room_state,
additional_creators,
)
return new_room_id
@@ -375,6 +415,7 @@ class RoomCreationHandler:
old_room_id: str,
new_room_id: str,
old_room_state: StateMap[str],
additional_creators: Optional[List[str]],
) -> None:
"""Send updated power levels in both rooms after an upgrade
@@ -383,7 +424,7 @@ class RoomCreationHandler:
old_room_id: the id of the room to be replaced
new_room_id: the id of the replacement room
old_room_state: the state map for the old room
additional_creators: Additional creators in the new room.
Raises:
ShadowBanError if the requester is shadow-banned.
"""
@@ -439,6 +480,14 @@ class RoomCreationHandler:
except AuthError as e:
logger.warning("Unable to update PLs in old room: %s", e)
new_room_version = await self.store.get_room_version(new_room_id)
if new_room_version.msc4289_creator_power_enabled:
self._remove_creators_from_pl_users_map(
old_room_pl_state.content.get("users", {}),
requester.user.to_string(),
additional_creators,
)
await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
@@ -453,6 +502,30 @@ class RoomCreationHandler:
ratelimit=False,
)
def _calculate_upgraded_room_creation_content(
self,
old_room_create_event: EventBase,
tombstone_event_id: Optional[str],
new_room_version: RoomVersion,
) -> JsonDict:
creation_content: JsonDict = {
"room_version": new_room_version.identifier,
"predecessor": {
"room_id": old_room_create_event.room_id,
},
}
if tombstone_event_id is not None:
creation_content["predecessor"]["event_id"] = tombstone_event_id
# Check if old room was non-federatable
if not old_room_create_event.content.get(EventContentFields.FEDERATE, True):
# If so, mark the new room as non-federatable as well
creation_content[EventContentFields.FEDERATE] = False
# Copy the room type as per MSC3818.
room_type = old_room_create_event.content.get(EventContentFields.ROOM_TYPE)
if room_type is not None:
creation_content[EventContentFields.ROOM_TYPE] = room_type
return creation_content
async def clone_existing_room(
self,
requester: Requester,
@@ -460,6 +533,10 @@ class RoomCreationHandler:
new_room_id: str,
new_room_version: RoomVersion,
tombstone_event_id: str,
additional_creators: Optional[List[str]],
creation_event_with_context: Optional[
Tuple[EventBase, synapse.events.snapshot.EventContext]
] = None,
) -> None:
"""Populate a new room based on an old room
@@ -470,24 +547,23 @@ class RoomCreationHandler:
created with _generate_room_id())
new_room_version: the new room version to use
tombstone_event_id: the ID of the tombstone event in the old room.
creation_event_with_context: The create event of the new room, if the new room supports
room ID as create event ID hash.
"""
user_id = requester.user.to_string()
creation_content: JsonDict = {
"room_version": new_room_version.identifier,
"predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id},
}
# Check if old room was non-federatable
# Get old room's create event
old_room_create_event = await self.store.get_create_event_for_room(old_room_id)
# Check if the create event specified a non-federatable room
if not old_room_create_event.content.get(EventContentFields.FEDERATE, True):
# If so, mark the new room as non-federatable as well
creation_content[EventContentFields.FEDERATE] = False
if creation_event_with_context:
create_event, _ = creation_event_with_context
creation_content = create_event.content
else:
creation_content = self._calculate_upgraded_room_creation_content(
old_room_create_event,
tombstone_event_id,
new_room_version,
)
initial_state = {}
# Replicate relevant room events
@@ -503,11 +579,8 @@ class RoomCreationHandler:
(EventTypes.PowerLevels, ""),
]
# Copy the room type as per MSC3818.
room_type = old_room_create_event.content.get(EventContentFields.ROOM_TYPE)
if room_type is not None:
creation_content[EventContentFields.ROOM_TYPE] = room_type
# If the old room was a space, copy over the rooms in the space.
if room_type == RoomTypes.SPACE:
types_to_copy.append((EventTypes.SpaceChild, None))
@@ -579,6 +652,14 @@ class RoomCreationHandler:
if current_power_level_int < needed_power_level:
user_power_levels[user_id] = needed_power_level
if new_room_version.msc4289_creator_power_enabled:
# the creator(s) cannot be in the users map
self._remove_creators_from_pl_users_map(
user_power_levels,
user_id,
additional_creators,
)
# We construct what the body of a call to /createRoom would look like for passing
# to the spam checker. We don't include a preset here, as we expect the
# initial state to contain everything we need.
@@ -607,6 +688,7 @@ class RoomCreationHandler:
invite_list=[],
initial_state=initial_state,
creation_content=creation_content,
creation_event_with_context=creation_event_with_context,
)
# Transfer membership events
@@ -890,6 +972,7 @@ class RoomCreationHandler:
power_level_content_override = config.get("power_level_content_override")
if (
power_level_content_override
and not room_version.msc4289_creator_power_enabled # this validation doesn't apply in MSC4289 rooms
and "users" in power_level_content_override
and user_id not in power_level_content_override["users"]
):
@@ -906,11 +989,41 @@ class RoomCreationHandler:
self._validate_room_config(config, visibility)
room_id = await self._generate_and_create_room_id(
creator_id=user_id,
is_public=is_public,
room_version=room_version,
)
creation_content = config.get("creation_content", {})
# override any attempt to set room versions via the creation_content
creation_content["room_version"] = room_version.identifier
# trusted private chats have the invited users marked as additional creators
if (
room_version.msc4289_creator_power_enabled
and config.get("preset", None) == RoomCreationPreset.TRUSTED_PRIVATE_CHAT
and len(config.get("invite", [])) > 0
):
# the other user(s) are additional creators
invitees = config.get("invite", [])
# we don't want to replace any additional_creators additionally specified, and we want
# to remove duplicates.
creation_content[EventContentFields.ADDITIONAL_CREATORS] = list(
set(creation_content.get(EventContentFields.ADDITIONAL_CREATORS, []))
| set(invitees)
)
creation_event_with_context = None
if room_version.msc4291_room_ids_as_hashes:
creation_event_with_context = await self._generate_create_event_for_room_id(
requester,
creation_content,
is_public,
room_version,
)
(create_event, _) = creation_event_with_context
room_id = create_event.room_id
else:
room_id = await self._generate_and_create_room_id(
creator_id=user_id,
is_public=is_public,
room_version=room_version,
)
# Check whether this visibility value is blocked by a third party module
allowed_by_third_party_rules = await (
@@ -947,11 +1060,6 @@ class RoomCreationHandler:
for val in raw_initial_state:
initial_state[(val["type"], val.get("state_key", ""))] = val["content"]
creation_content = config.get("creation_content", {})
# override any attempt to set room versions via the creation_content
creation_content["room_version"] = room_version.identifier
(
last_stream_id,
last_sent_event_id,
@@ -968,6 +1076,7 @@ class RoomCreationHandler:
power_level_content_override=power_level_content_override,
creator_join_profile=creator_join_profile,
ignore_forced_encryption=ignore_forced_encryption,
creation_event_with_context=creation_event_with_context,
)
# we avoid dropping the lock between invites, as otherwise joins can
@@ -1033,6 +1142,38 @@ class RoomCreationHandler:
return room_id, room_alias, last_stream_id
async def _generate_create_event_for_room_id(
self,
creator: Requester,
creation_content: JsonDict,
is_public: bool,
room_version: RoomVersion,
) -> Tuple[EventBase, synapse.events.snapshot.EventContext]:
(
creation_event,
new_unpersisted_context,
) = await self.event_creation_handler.create_event(
creator,
{
"content": creation_content,
"sender": creator.user.to_string(),
"type": EventTypes.Create,
"state_key": "",
},
prev_event_ids=[],
depth=1,
state_map={},
for_batch=False,
)
await self.store.store_room(
room_id=creation_event.room_id,
room_creator_user_id=creator.user.to_string(),
is_public=is_public,
room_version=room_version,
)
creation_context = await new_unpersisted_context.persist(creation_event)
return (creation_event, creation_context)
async def _send_events_for_new_room(
self,
creator: Requester,
@@ -1046,6 +1187,9 @@ class RoomCreationHandler:
power_level_content_override: Optional[JsonDict] = None,
creator_join_profile: Optional[JsonDict] = None,
ignore_forced_encryption: bool = False,
creation_event_with_context: Optional[
Tuple[EventBase, synapse.events.snapshot.EventContext]
] = None,
) -> Tuple[int, str, int]:
"""Sends the initial events into a new room. Sends the room creation, membership,
and power level events into the room sequentially, then creates and batches up the
@@ -1082,7 +1226,10 @@ class RoomCreationHandler:
user in this room.
ignore_forced_encryption:
Ignore encryption forced by `encryption_enabled_by_default_for_room_type` setting.
creation_event_with_context:
Set in MSC4291 rooms where the create event determines the room ID. If provided,
does not create an additional create event but instead appends the remaining new
events onto the provided create event.
Returns:
A tuple containing the stream ID, event ID and depth of the last
event sent to the room.
@@ -1147,13 +1294,26 @@ class RoomCreationHandler:
preset_config, config = self._room_preset_config(room_config)
# MSC2175 removes the creator field from the create event.
if not room_version.implicit_room_creator:
creation_content["creator"] = creator_id
creation_event, unpersisted_creation_context = await create_event(
EventTypes.Create, creation_content, False
)
creation_context = await unpersisted_creation_context.persist(creation_event)
if creation_event_with_context is None:
# MSC2175 removes the creator field from the create event.
if not room_version.implicit_room_creator:
creation_content["creator"] = creator_id
creation_event, unpersisted_creation_context = await create_event(
EventTypes.Create, creation_content, False
)
creation_context = await unpersisted_creation_context.persist(
creation_event
)
else:
(creation_event, creation_context) = creation_event_with_context
# we had to do the above already in order to have a room ID, so just updates local vars
# and continue.
depth = 2
prev_event = [creation_event.event_id]
state_map[(creation_event.type, creation_event.state_key)] = (
creation_event.event_id
)
logger.debug("Sending %s in new room", EventTypes.Member)
ev = await self.event_creation_handler.handle_new_client_event(
requester=creator,
@@ -1202,7 +1362,9 @@ class RoomCreationHandler:
# Please update the docs for `default_power_level_content_override` when
# updating the `events` dict below
power_level_content: JsonDict = {
"users": {creator_id: 100},
"users": {creator_id: 100}
if not room_version.msc4289_creator_power_enabled
else {},
"users_default": 0,
"events": {
EventTypes.Name: 50,
@@ -1210,7 +1372,9 @@ class RoomCreationHandler:
EventTypes.RoomHistoryVisibility: 100,
EventTypes.CanonicalAlias: 50,
EventTypes.RoomAvatar: 50,
EventTypes.Tombstone: 100,
EventTypes.Tombstone: 150
if room_version.msc4289_creator_power_enabled
else 100,
EventTypes.ServerACL: 100,
EventTypes.RoomEncryption: 100,
},
@@ -1223,7 +1387,13 @@ class RoomCreationHandler:
"historical": 100,
}
if config["original_invitees_have_ops"]:
# original_invitees_have_ops is set on preset:trusted_private_chat which will already
# have set these users as additional_creators, hence don't set the PL for creators as
# that is invalid.
if (
config["original_invitees_have_ops"]
and not room_version.msc4289_creator_power_enabled
):
for invitee in invite_list:
power_level_content["users"][invitee] = 100
@@ -1396,6 +1566,19 @@ class RoomCreationHandler:
)
return preset_name, preset_config
def _remove_creators_from_pl_users_map(
self,
users_map: Dict[str, int],
creator: str,
additional_creators: Optional[List[str]],
) -> None:
creators = [creator]
if additional_creators:
creators.extend(additional_creators)
for creator in creators:
# the creator(s) cannot be in the users map
users_map.pop(creator, None)
def _generate_room_id(self) -> str:
"""Generates a random room ID.
@@ -1413,7 +1596,7 @@ class RoomCreationHandler:
A random room ID of the form "!opaque_id:domain".
"""
random_string = stringutils.random_string(18)
return RoomID(random_string, self.hs.hostname).to_string()
return RoomIdWithDomain(random_string, self.hs.hostname).to_string()
async def _generate_and_create_room_id(
self,

View File

@@ -42,7 +42,7 @@ from synapse.api.errors import (
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.event_auth import get_named_level, get_power_level_event
from synapse.events import EventBase
from synapse.events import EventBase, is_creator
from synapse.events.snapshot import EventContext
from synapse.handlers.pagination import PURGE_ROOM_ACTION_NAME
from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
@@ -1154,9 +1154,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
elif effective_membership_state == Membership.KNOCK:
if not is_host_in_room:
# The knock needs to be sent over federation instead
remote_room_hosts.append(get_domain_from_id(room_id))
# we used to add the domain of the room ID to remote_room_hosts.
# This is not safe in MSC4291 rooms which do not have a domain.
content["membership"] = Membership.KNOCK
try:
@@ -2313,6 +2312,7 @@ def get_users_which_can_issue_invite(auth_events: StateMap[EventBase]) -> List[s
# Check which members are able to invite by ensuring they're joined and have
# the necessary power level.
create_event = auth_events[(EventTypes.Create, "")]
for (event_type, state_key), event in auth_events.items():
if event_type != EventTypes.Member:
continue
@@ -2320,8 +2320,12 @@ def get_users_which_can_issue_invite(auth_events: StateMap[EventBase]) -> List[s
if event.membership != Membership.JOIN:
continue
if create_event.room_version.msc4289_creator_power_enabled and is_creator(
create_event, state_key
):
result.append(state_key)
# Check if the user has a custom power level.
if users.get(state_key, users_default_level) >= invite_level:
elif users.get(state_key, users_default_level) >= invite_level:
result.append(state_key)
return result

View File

@@ -627,6 +627,15 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
]
admin_users.sort(key=lambda user: user_power[user])
if create_event.room_version.msc4289_creator_power_enabled:
creators = create_event.content.get("additional_creators", []) + [
create_event.sender
]
for creator in creators:
if self.is_mine_id(creator):
# include the creator as they won't be in the PL users map.
admin_users.insert(0, creator)
if not admin_users:
raise SynapseError(
HTTPStatus.BAD_REQUEST, "No local admin user in room"
@@ -666,7 +675,11 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
# updated power level event.
new_pl_content = dict(pl_content)
new_pl_content["users"] = dict(pl_content.get("users", {}))
new_pl_content["users"][user_to_add] = new_pl_content["users"][admin_user_id]
# give the new user the same PL as the admin, default to 100 in case there is no PL event.
# This means in v12+ rooms we get PL100 if the creator promotes us.
new_pl_content["users"][user_to_add] = new_pl_content["users"].get(
admin_user_id, 100
)
fake_requester = create_requester(
admin_user_id,

View File

@@ -24,6 +24,7 @@ from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, ShadowBanError, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.event_auth import check_valid_additional_creators
from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
from synapse.http.server import HttpServer
from synapse.http.servlet import (
@@ -85,13 +86,18 @@ class RoomUpgradeRestServlet(RestServlet):
"Your homeserver does not support this room version",
Codes.UNSUPPORTED_ROOM_VERSION,
)
additional_creators = None
if new_version.msc4289_creator_power_enabled:
additional_creators = content.get("additional_creators")
if additional_creators is not None:
check_valid_additional_creators(additional_creators)
try:
async with self._worker_lock_handler.acquire_read_write_lock(
NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False
):
new_room_id = await self._room_creation_handler.upgrade_room(
requester, room_id, new_version
requester, room_id, new_version, additional_creators
)
except ShadowBanError:
# Generate a random room ID.

View File

@@ -53,6 +53,7 @@ from synapse.logging.context import ContextResourceUsage
from synapse.logging.opentracing import tag_args, trace
from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
from synapse.state import v1, v2
from synapse.storage.databases.main.event_federation import StateDifference
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.types import StateMap, StrCollection
from synapse.types.state import StateFilter
@@ -976,17 +977,35 @@ class StateResolutionStore:
)
def get_auth_chain_difference(
self, room_id: str, state_sets: List[Set[str]]
) -> Awaitable[Set[str]]:
"""Given sets of state events figure out the auth chain difference (as
self,
room_id: str,
state_sets: List[Set[str]],
conflicted_state: Optional[Set[str]],
additional_backwards_reachable_conflicted_events: Optional[Set[str]],
) -> Awaitable[StateDifference]:
""" "Given sets of state events figure out the auth chain difference (as
per state res v2 algorithm).
This equivalent to fetching the full auth chain for each set of state
This is equivalent to fetching the full auth chain for each set of state
and returning the events that don't appear in each and every auth
chain.
If conflicted_state is not None, calculate and return the conflicted sub-graph as per
state res v2.1. The event IDs in the conflicted state MUST be a subset of the event IDs in
state_sets.
If additional_backwards_reachable_conflicted_events is set, the provided events are included
when calculating the conflicted subgraph. This is primarily useful for calculating the
subgraph across a combination of persisted and unpersisted events.
Returns:
An awaitable that resolves to a set of event IDs.
information on the auth chain difference, and also the conflicted subgraph if
conflicted_state is not None
"""
return self.main_store.get_auth_chain_difference(room_id, state_sets)
return self.main_store.get_auth_chain_difference_extended(
room_id,
state_sets,
conflicted_state,
additional_backwards_reachable_conflicted_events,
)

View File

@@ -39,10 +39,11 @@ from typing import (
)
from synapse import event_auth
from synapse.api.constants import EventTypes
from synapse.api.constants import CREATOR_POWER_LEVEL, EventTypes
from synapse.api.errors import AuthError
from synapse.api.room_versions import RoomVersion
from synapse.events import EventBase
from synapse.api.room_versions import RoomVersion, StateResolutionVersions
from synapse.events import EventBase, is_creator
from synapse.storage.databases.main.event_federation import StateDifference
from synapse.types import MutableStateMap, StateMap, StrCollection
logger = logging.getLogger(__name__)
@@ -63,8 +64,12 @@ class StateResolutionStore(Protocol):
) -> Awaitable[Dict[str, EventBase]]: ...
def get_auth_chain_difference(
self, room_id: str, state_sets: List[Set[str]]
) -> Awaitable[Set[str]]: ...
self,
room_id: str,
state_sets: List[Set[str]],
conflicted_state: Optional[Set[str]],
additional_backwards_reachable_conflicted_events: Optional[set[str]],
) -> Awaitable[StateDifference]: ...
# We want to await to the reactor occasionally during state res when dealing
@@ -123,12 +128,17 @@ async def resolve_events_with_store(
logger.debug("%d conflicted state entries", len(conflicted_state))
logger.debug("Calculating auth chain difference")
# Also fetch all auth events that appear in only some of the state sets'
# auth chains.
conflicted_set: Optional[Set[str]] = None
if room_version.state_res == StateResolutionVersions.V2_1:
# calculate the conflicted subgraph
conflicted_set = set(itertools.chain.from_iterable(conflicted_state.values()))
auth_diff = await _get_auth_chain_difference(
room_id, state_sets, event_map, state_res_store
room_id,
state_sets,
event_map,
state_res_store,
conflicted_set,
)
full_conflicted_set = set(
itertools.chain(
itertools.chain.from_iterable(conflicted_state.values()), auth_diff
@@ -168,15 +178,26 @@ async def resolve_events_with_store(
logger.debug("sorted %d power events", len(sorted_power_events))
# v2.1 starts iterative auth checks from the empty set and not the unconflicted state.
# It relies on IAC behaviour which populates the base state with the events from auth_events
# if the state tuple is missing from the base state. This ensures the base state is only
# populated from auth_events rather than whatever the unconflicted state is (which could be
# completely bogus).
base_state = (
{}
if room_version.state_res == StateResolutionVersions.V2_1
else unconflicted_state
)
# Now sequentially auth each one
resolved_state = await _iterative_auth_checks(
clock,
room_id,
room_version,
sorted_power_events,
unconflicted_state,
event_map,
state_res_store,
event_ids=sorted_power_events,
base_state=base_state,
event_map=event_map,
state_res_store=state_res_store,
)
logger.debug("resolved power events")
@@ -239,13 +260,23 @@ async def _get_power_level_for_sender(
event = await _get_event(room_id, event_id, event_map, state_res_store)
pl = None
create = None
for aid in event.auth_event_ids():
aev = await _get_event(
room_id, aid, event_map, state_res_store, allow_none=True
)
if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
pl = aev
break
if aev and (aev.type, aev.state_key) == (EventTypes.Create, ""):
create = aev
if event.type != EventTypes.Create:
# we should always have a create event
assert create is not None
if create and create.room_version.msc4289_creator_power_enabled:
if is_creator(create, event.sender):
return CREATOR_POWER_LEVEL
if pl is None:
# Couldn't find power level. Check if they're the creator of the room
@@ -286,6 +317,7 @@ async def _get_auth_chain_difference(
state_sets: Sequence[StateMap[str]],
unpersisted_events: Dict[str, EventBase],
state_res_store: StateResolutionStore,
conflicted_state: Optional[Set[str]],
) -> Set[str]:
"""Compare the auth chains of each state set and return the set of events
that only appear in some, but not all of the auth chains.
@@ -294,11 +326,18 @@ async def _get_auth_chain_difference(
state_sets: The input state sets we are trying to resolve across.
unpersisted_events: A map from event ID to EventBase containing all unpersisted
events involved in this resolution.
state_res_store:
state_res_store: A way to retrieve events and extract graph information on the auth chains.
conflicted_state: which event IDs are conflicted. Used in v2.1 for calculating the conflicted
subgraph.
Returns:
The auth difference of the given state sets, as a set of event IDs.
The auth difference of the given state sets, as a set of event IDs. Also includes the
conflicted subgraph if `conflicted_state` is set.
"""
is_state_res_v21 = conflicted_state is not None
num_conflicted_state = (
len(conflicted_state) if conflicted_state is not None else None
)
# The `StateResolutionStore.get_auth_chain_difference` function assumes that
# all events passed to it (and their auth chains) have been persisted
@@ -318,14 +357,19 @@ async def _get_auth_chain_difference(
# the event's auth chain with the events in `unpersisted_events` *plus* their
# auth event IDs.
events_to_auth_chain: Dict[str, Set[str]] = {}
# remember the forward links when doing the graph traversal, we'll need it for v2.1 checks
# This is a map from an event to the set of events that contain it as an auth event.
event_to_next_event: Dict[str, Set[str]] = {}
for event in unpersisted_events.values():
chain = {event.event_id}
events_to_auth_chain[event.event_id] = chain
to_search = [event]
while to_search:
for auth_id in to_search.pop().auth_event_ids():
next_event = to_search.pop()
for auth_id in next_event.auth_event_ids():
chain.add(auth_id)
event_to_next_event.setdefault(auth_id, set()).add(next_event.event_id)
auth_event = unpersisted_events.get(auth_id)
if auth_event:
to_search.append(auth_event)
@@ -335,6 +379,8 @@ async def _get_auth_chain_difference(
#
# Note: If there are no `unpersisted_events` (which is the common case), we can do a
# much simpler calculation.
additional_backwards_reachable_conflicted_events: Set[str] = set()
unpersisted_conflicted_events: Set[str] = set()
if unpersisted_events:
# The list of state sets to pass to the store, where each state set is a set
# of the event ids making up the state. This is similar to `state_sets`,
@@ -372,7 +418,16 @@ async def _get_auth_chain_difference(
)
else:
set_ids.add(event_id)
if conflicted_state:
for conflicted_event_id in conflicted_state:
# presence in this map means it is unpersisted.
event_chain = events_to_auth_chain.get(conflicted_event_id)
if event_chain is not None:
unpersisted_conflicted_events.add(conflicted_event_id)
# tell the DB layer that we have some unpersisted conflicted events
additional_backwards_reachable_conflicted_events.update(
e for e in event_chain if e not in unpersisted_events
)
# The auth chain difference of the unpersisted events of the state sets
# is calculated by taking the difference between the union and
# intersections.
@@ -384,12 +439,89 @@ async def _get_auth_chain_difference(
auth_difference_unpersisted_part = ()
state_sets_ids = [set(state_set.values()) for state_set in state_sets]
difference = await state_res_store.get_auth_chain_difference(
room_id, state_sets_ids
)
difference.update(auth_difference_unpersisted_part)
if conflicted_state:
# to ensure that conflicted state is a subset of state set IDs, we need to remove UNPERSISTED
# conflicted state set ids as we removed them above.
conflicted_state = conflicted_state - unpersisted_conflicted_events
return difference
difference = await state_res_store.get_auth_chain_difference(
room_id,
state_sets_ids,
conflicted_state,
additional_backwards_reachable_conflicted_events,
)
difference.auth_difference.update(auth_difference_unpersisted_part)
# if we're doing v2.1 we may need to add or expand the conflicted subgraph
if (
is_state_res_v21
and difference.conflicted_subgraph is not None
and unpersisted_events
):
# we always include the conflicted events themselves in the subgraph.
if conflicted_state:
difference.conflicted_subgraph.update(conflicted_state)
# we may need to expand the subgraph in the case where the subgraph starts in the DB and
# ends in unpersisted events. To do this, we first need to see where the subgraph got up to,
# which we can do by finding the intersection between the additional backwards reachable
# conflicted events and the conflicted subgraph. Events in both sets mean A) some unpersisted
# conflicted event could backwards reach it and B) some persisted conflicted event could forward
# reach it.
subgraph_frontier = difference.conflicted_subgraph.intersection(
additional_backwards_reachable_conflicted_events
)
# we can now combine the 2 scenarios:
# - subgraph starts in DB and ends in unpersisted
# - subgraph starts in unpersisted and ends in unpersisted
# by expanding the frontier into unpersisted events.
# The frontier is currently all persisted events. We want to expand this into unpersisted
# events. Mark every forwards reachable event from the frontier in the forwards_conflicted_set
# but NOT the backwards conflicted set. This mirrors what the DB layer does but in reverse:
# we supplied events which are backwards reachable to the DB and now the DB is providing
# forwards reachable events from the DB.
forwards_conflicted_set: Set[str] = set()
# we include unpersisted conflicted events here to process exclusive unpersisted subgraphs
search_queue = subgraph_frontier.union(unpersisted_conflicted_events)
while search_queue:
frontier_event = search_queue.pop()
next_event_ids = event_to_next_event.get(frontier_event, set())
search_queue.update(next_event_ids)
forwards_conflicted_set.add(frontier_event)
# we've already calculated the backwards form as this is the auth chain for each
# unpersisted conflicted event.
backwards_conflicted_set: Set[str] = set()
for uce in unpersisted_conflicted_events:
backwards_conflicted_set.update(events_to_auth_chain.get(uce, []))
# the unpersisted conflicted subgraph is the intersection of the backwards/forwards sets
conflicted_subgraph_unpersisted_part = backwards_conflicted_set.intersection(
forwards_conflicted_set
)
# print(f"event_to_next_event={event_to_next_event}")
# print(f"unpersisted_conflicted_events={unpersisted_conflicted_events}")
# print(f"unperssited backwards_conflicted_set={backwards_conflicted_set}")
# print(f"unperssited forwards_conflicted_set={forwards_conflicted_set}")
difference.conflicted_subgraph.update(conflicted_subgraph_unpersisted_part)
if difference.conflicted_subgraph:
old_events = difference.auth_difference.union(
conflicted_state if conflicted_state else set()
)
additional_events = difference.conflicted_subgraph.difference(old_events)
logger.debug(
"v2.1 %s additional events replayed=%d num_conflicts=%d conflicted_subgraph=%d auth_difference=%d",
room_id,
len(additional_events),
num_conflicted_state,
len(difference.conflicted_subgraph),
len(difference.auth_difference),
)
# State res v2.1 includes the conflicted subgraph in the difference
return difference.auth_difference.union(difference.conflicted_subgraph)
return difference.auth_difference
def _seperate(

View File

@@ -110,6 +110,12 @@ _LONGEST_BACKOFF_PERIOD_MILLISECONDS = (
assert 0 < _LONGEST_BACKOFF_PERIOD_MILLISECONDS <= ((2**31) - 1)
# We use 2^53-1 as a "very large number", it has no particular
# importance other than knowing synapse can support it (given canonical json
# requires it).
MAX_CHAIN_LENGTH = (2**53) - 1
# All the info we need while iterating the DAG while backfilling
@attr.s(frozen=True, slots=True, auto_attribs=True)
class BackfillQueueNavigationItem:
@@ -119,6 +125,14 @@ class BackfillQueueNavigationItem:
type: str
@attr.s(frozen=True, slots=True, auto_attribs=True)
class StateDifference:
# The event IDs in the auth difference.
auth_difference: Set[str]
# The event IDs in the conflicted state subgraph. Used in v2.1 only.
conflicted_subgraph: Optional[Set[str]]
class _NoChainCoverIndex(Exception):
def __init__(self, room_id: str):
super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,))
@@ -467,17 +481,41 @@ class EventFederationWorkerStore(
return results
async def get_auth_chain_difference(
self, room_id: str, state_sets: List[Set[str]]
self,
room_id: str,
state_sets: List[Set[str]],
) -> Set[str]:
"""Given sets of state events figure out the auth chain difference (as
state_diff = await self.get_auth_chain_difference_extended(
room_id, state_sets, None, None
)
return state_diff.auth_difference
async def get_auth_chain_difference_extended(
self,
room_id: str,
state_sets: List[Set[str]],
conflicted_set: Optional[Set[str]],
additional_backwards_reachable_conflicted_events: Optional[Set[str]],
) -> StateDifference:
""" "Given sets of state events figure out the auth chain difference (as
per state res v2 algorithm).
This equivalent to fetching the full auth chain for each set of state
This is equivalent to fetching the full auth chain for each set of state
and returning the events that don't appear in each and every auth
chain.
If conflicted_set is not None, calculate and return the conflicted sub-graph as per
state res v2.1. The event IDs in the conflicted set MUST be a subset of the event IDs in
state_sets.
If additional_backwards_reachable_conflicted_events is set, the provided events are included
when calculating the conflicted subgraph. This is primarily useful for calculating the
subgraph across a combination of persisted and unpersisted events. The event IDs in this set
MUST be a subset of the event IDs in state_sets.
Returns:
The set of the difference in auth chains.
information on the auth chain difference, and also the conflicted subgraph if
conflicted_set is not None
"""
# Check if we have indexed the room so we can use the chain cover
@@ -491,6 +529,8 @@ class EventFederationWorkerStore(
self._get_auth_chain_difference_using_cover_index_txn,
room_id,
state_sets,
conflicted_set,
additional_backwards_reachable_conflicted_events,
)
except _NoChainCoverIndex:
# For whatever reason we don't actually have a chain cover index
@@ -499,25 +539,48 @@ class EventFederationWorkerStore(
if not self.tests_allow_no_chain_cover_index:
raise
return await self.db_pool.runInteraction(
# It's been 4 years since we added chain cover, so we expect all rooms to have it.
# If they don't, we will error out when trying to do state res v2.1
if conflicted_set is not None:
raise _NoChainCoverIndex(room_id)
auth_diff = await self.db_pool.runInteraction(
"get_auth_chain_difference",
self._get_auth_chain_difference_txn,
state_sets,
)
return StateDifference(auth_difference=auth_diff, conflicted_subgraph=None)
def _get_auth_chain_difference_using_cover_index_txn(
self, txn: LoggingTransaction, room_id: str, state_sets: List[Set[str]]
) -> Set[str]:
self,
txn: LoggingTransaction,
room_id: str,
state_sets: List[Set[str]],
conflicted_set: Optional[Set[str]] = None,
additional_backwards_reachable_conflicted_events: Optional[Set[str]] = None,
) -> StateDifference:
"""Calculates the auth chain difference using the chain index.
See docs/auth_chain_difference_algorithm.md for details
"""
is_state_res_v21 = conflicted_set is not None
# First we look up the chain ID/sequence numbers for all the events, and
# work out the chain/sequence numbers reachable from each state set.
initial_events = set(state_sets[0]).union(*state_sets[1:])
if is_state_res_v21:
# Sanity check v2.1 fields
assert conflicted_set is not None
assert conflicted_set.issubset(initial_events)
# It's possible for the conflicted_set to be empty if all the conflicts are in
# unpersisted events, so we don't assert that conflicted_set has len > 0
if additional_backwards_reachable_conflicted_events:
assert additional_backwards_reachable_conflicted_events.issubset(
initial_events
)
# Map from event_id -> (chain ID, seq no)
chain_info: Dict[str, Tuple[int, int]] = {}
@@ -553,14 +616,14 @@ class EventFederationWorkerStore(
events_missing_chain_info = initial_events.difference(chain_info)
# The result set to return, i.e. the auth chain difference.
result: Set[str] = set()
auth_difference_result: Set[str] = set()
if events_missing_chain_info:
# For some reason we have events we haven't calculated the chain
# index for, so we need to handle those separately. This should only
# happen for older rooms where the server doesn't have all the auth
# events.
result = self._fixup_auth_chain_difference_sets(
auth_difference_result = self._fixup_auth_chain_difference_sets(
txn,
room_id,
state_sets=state_sets,
@@ -579,6 +642,45 @@ class EventFederationWorkerStore(
fetch_chain_info(new_events_to_fetch)
# State Res v2.1 needs extra data structures to calculate the conflicted subgraph which
# are outlined below.
# A subset of chain_info for conflicted events only, as we need to
# loop all conflicted chain positions. Map from event_id -> (chain ID, seq no)
conflicted_chain_positions: Dict[str, Tuple[int, int]] = {}
# For each chain, remember the positions where conflicted events are.
# We need this for calculating the forward reachable events.
conflicted_chain_to_seq: Dict[int, Set[int]] = {} # chain_id => {seq_num}
# A subset of chain_info for additional backwards reachable events only, as we need to
# loop all additional backwards reachable events for calculating backwards reachable events.
additional_backwards_reachable_positions: Dict[
str, Tuple[int, int]
] = {} # event_id => (chain_id, seq_num)
# These next two fields are critical as the intersection of them is the conflicted subgraph.
# We'll populate them when we walk the chain links.
# chain_id => max(seq_num) backwards reachable (e.g 4 means 1,2,3,4 are backwards reachable)
conflicted_backwards_reachable: Dict[int, int] = {}
# chain_id => min(seq_num) forwards reachable (e.g 4 means 4,5,6..n are forwards reachable)
conflicted_forwards_reachable: Dict[int, int] = {}
# populate the v2.1 data structures
if is_state_res_v21:
assert conflicted_set is not None
# provide chain positions for each conflicted event
for conflicted_event_id in conflicted_set:
(chain_id, seq_num) = chain_info[conflicted_event_id]
conflicted_chain_positions[conflicted_event_id] = (chain_id, seq_num)
conflicted_chain_to_seq.setdefault(chain_id, set()).add(seq_num)
if additional_backwards_reachable_conflicted_events:
for (
additional_event_id
) in additional_backwards_reachable_conflicted_events:
(chain_id, seq_num) = chain_info[additional_event_id]
additional_backwards_reachable_positions[additional_event_id] = (
chain_id,
seq_num,
)
# Corresponds to `state_sets`, except as a map from chain ID to max
# sequence number reachable from the state set.
set_to_chain: List[Dict[int, int]] = []
@@ -596,6 +698,8 @@ class EventFederationWorkerStore(
# (We need to take a copy of `seen_chains` as the function mutates it)
for links in self._get_chain_links(txn, set(seen_chains)):
# `links` encodes the backwards reachable events _from a single chain_ all the way to
# the root of the graph.
for chains in set_to_chain:
for chain_id in links:
if chain_id not in chains:
@@ -604,6 +708,87 @@ class EventFederationWorkerStore(
_materialize(chain_id, chains[chain_id], links, chains)
seen_chains.update(chains)
if is_state_res_v21:
# Apply v2.1 conflicted event reachability checks.
#
# A <-- B <-- C <-- D <-- E
#
# Backwards reachable from C = {A,B}
# Forwards reachable from C = {D,E}
# this handles calculating forwards reachable information and updates
# conflicted_forwards_reachable.
accumulate_forwards_reachable_events(
conflicted_forwards_reachable,
links,
conflicted_chain_positions,
)
# handle backwards reachable information
for (
conflicted_chain_id,
conflicted_chain_seq,
) in conflicted_chain_positions.values():
if conflicted_chain_id not in links:
# This conflicted event does not lie on the path to the root.
continue
# The conflicted chain position itself encodes reachability information
# _within_ the chain. Set it now before walking to other links.
conflicted_backwards_reachable[conflicted_chain_id] = max(
conflicted_chain_seq,
conflicted_backwards_reachable.get(conflicted_chain_id, 0),
)
# Build backwards reachability paths. This is the same as what the auth difference
# code does. We find which chain the conflicted event
# belongs to then walk it backwards to the root. We store reachability info
# for all conflicted events in the same map 'conflicted_backwards_reachable'
# as we don't care about the paths themselves.
_materialize(
conflicted_chain_id,
conflicted_chain_seq,
links,
conflicted_backwards_reachable,
)
# Mark some extra events as backwards reachable. This is used when we have some
# unpersisted events and want to know the subgraph across the persisted/unpersisted
# boundary:
# |
# A <-- B <-- C <-|- D <-- E <-- F
# persisted | unpersisted
#
# Assume {B,E} are conflicted, we want to return {B,C,D,E}
#
# The unpersisted code ensures it passes C as an additional backwards reachable
# event. C is NOT a conflicted event, but we do need to consider it as part of
# the backwards reachable set. When we then calculate the forwards reachable set
# from B, C will be in both the backwards and forwards reachable sets and hence
# will be included in the conflicted subgraph.
for (
additional_chain_id,
additional_chain_seq,
) in additional_backwards_reachable_positions.values():
if additional_chain_id not in links:
# The additional backwards reachable event does not lie on the path to the root.
continue
# the additional event chain position itself encodes reachability information.
# It means that position and all positions earlier in that chain are backwards reachable
# by some unpersisted conflicted event.
conflicted_backwards_reachable[additional_chain_id] = max(
additional_chain_seq,
conflicted_backwards_reachable.get(additional_chain_id, 0),
)
# Now walk the chains back, marking backwards reachable events.
# This is the same thing we do for auth difference / conflicted events.
_materialize(
additional_chain_id, # walk all links back, marking them as backwards reachable
additional_chain_seq,
links,
conflicted_backwards_reachable,
)
# Now for each chain we figure out the maximum sequence number reachable
# from *any* state set and the minimum sequence number reachable from
@@ -612,7 +797,7 @@ class EventFederationWorkerStore(
# Mapping from chain ID to the range of sequence numbers that should be
# pulled from the database.
chain_to_gap: Dict[int, Tuple[int, int]] = {}
auth_diff_chain_to_gap: Dict[int, Tuple[int, int]] = {}
for chain_id in seen_chains:
min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain)
@@ -625,15 +810,76 @@ class EventFederationWorkerStore(
for seq_no in range(min_seq_no + 1, max_seq_no + 1):
event_id = chain_to_event.get(chain_id, {}).get(seq_no)
if event_id:
result.add(event_id)
auth_difference_result.add(event_id)
else:
chain_to_gap[chain_id] = (min_seq_no, max_seq_no)
auth_diff_chain_to_gap[chain_id] = (min_seq_no, max_seq_no)
break
if not chain_to_gap:
# If there are no gaps to fetch, we're done!
return result
conflicted_subgraph_result: Set[str] = set()
# Mapping from chain ID to the range of sequence numbers that should be
# pulled from the database.
conflicted_subgraph_chain_to_gap: Dict[int, Tuple[int, int]] = {}
if is_state_res_v21:
# also include the conflicted subgraph using backward/forward reachability info from all
# the conflicted events. To calculate this, we want to extract the intersection between
# the backwards and forwards reachability sets, e.g:
# A <- B <- C <- D <- E
# Assume B and D are conflicted so we want {C} as the conflicted subgraph.
# B_backwards={A}, B_forwards={C,D,E}
# D_backwards={A,B,C} D_forwards={E}
# ALL_backwards={A,B,C} ALL_forwards={C,D,E}
# Intersection(ALL_backwards, ALL_forwards) = {C}
#
# It's worth noting that once we have the ALL_ sets, we no longer care about the paths.
# We're dealing with chains and not singular events, but we've already got the ALL_ sets.
# As such, we can inspect each chain in isolation and check for overlapping sequence
# numbers:
# 1,2,3,4,5 Seq Num
# Chain N [A,B,C,D,E]
#
# if (N,4) is in the backwards set and (N,2) is in the forwards set, then the
# intersection is events between 2 < 4. We will include the conflicted events themselves
# in the subgraph, but they will already be, hence the full set of events is {B,C,D}.
for chain_id, backwards_seq_num in conflicted_backwards_reachable.items():
forwards_seq_num = conflicted_forwards_reachable.get(chain_id)
if forwards_seq_num is None:
continue # this chain isn't in both sets so can't intersect
if forwards_seq_num > backwards_seq_num:
continue # this chain is in both sets but they don't overap
for seq_no in range(
forwards_seq_num, backwards_seq_num + 1
): # inclusive of both
event_id = chain_to_event.get(chain_id, {}).get(seq_no)
if event_id:
conflicted_subgraph_result.add(event_id)
else:
conflicted_subgraph_chain_to_gap[chain_id] = (
# _fetch_event_ids_from_chains_txn is exclusive of the min value
forwards_seq_num - 1,
backwards_seq_num,
)
break
if auth_diff_chain_to_gap:
auth_difference_result.update(
self._fetch_event_ids_from_chains_txn(txn, auth_diff_chain_to_gap)
)
if conflicted_subgraph_chain_to_gap:
conflicted_subgraph_result.update(
self._fetch_event_ids_from_chains_txn(
txn, conflicted_subgraph_chain_to_gap
)
)
return StateDifference(
auth_difference=auth_difference_result,
conflicted_subgraph=conflicted_subgraph_result,
)
def _fetch_event_ids_from_chains_txn(
self, txn: LoggingTransaction, chains: Dict[int, Tuple[int, int]]
) -> Set[str]:
result: Set[str] = set()
if isinstance(self.database_engine, PostgresEngine):
# We can use `execute_values` to efficiently fetch the gaps when
# using postgres.
@@ -647,7 +893,7 @@ class EventFederationWorkerStore(
args = [
(chain_id, min_no, max_no)
for chain_id, (min_no, max_no) in chain_to_gap.items()
for chain_id, (min_no, max_no) in chains.items()
]
rows = txn.execute_values(sql, args)
@@ -658,10 +904,9 @@ class EventFederationWorkerStore(
SELECT event_id FROM event_auth_chains
WHERE chain_id = ? AND ? < sequence_number AND sequence_number <= ?
"""
for chain_id, (min_no, max_no) in chain_to_gap.items():
for chain_id, (min_no, max_no) in chains.items():
txn.execute(sql, (chain_id, min_no, max_no))
result.update(r for (r,) in txn)
return result
def _fixup_auth_chain_difference_sets(
@@ -2155,6 +2400,7 @@ def _materialize(
origin_sequence_number: int,
links: Dict[int, List[Tuple[int, int, int]]],
materialized: Dict[int, int],
backwards: bool = True,
) -> None:
"""Helper function for fetching auth chain links. For a given origin chain
ID / sequence number and a dictionary of links, updates the materialized
@@ -2171,6 +2417,7 @@ def _materialize(
target sequence number.
materialized: dict to update with new reachability information, as a
map from chain ID to max sequence number reachable.
backwards: If True, walks backwards down the chains. If False, walks forwards from the chains.
"""
# Do a standard graph traversal.
@@ -2185,12 +2432,104 @@ def _materialize(
target_chain_id,
target_sequence_number,
) in chain_links:
# Ignore any links that are higher up the chain
if sequence_number > s:
continue
if backwards:
# Ignore any links that are higher up the chain
if sequence_number > s:
continue
# Check if we have already visited the target chain before, if so we
# can skip it.
if materialized.get(target_chain_id, 0) < target_sequence_number:
stack.append((target_chain_id, target_sequence_number))
materialized[target_chain_id] = target_sequence_number
# Check if we have already visited the target chain before, if so we
# can skip it.
if materialized.get(target_chain_id, 0) < target_sequence_number:
stack.append((target_chain_id, target_sequence_number))
materialized[target_chain_id] = target_sequence_number
else:
# Ignore any links that are lower down the chain.
if sequence_number < s:
continue
# Check if we have already visited the target chain before, if so we
# can skip it.
if (
materialized.get(target_chain_id, MAX_CHAIN_LENGTH)
> target_sequence_number
):
stack.append((target_chain_id, target_sequence_number))
materialized[target_chain_id] = target_sequence_number
def _generate_forward_links(
links: Dict[int, List[Tuple[int, int, int]]],
) -> Dict[int, List[Tuple[int, int, int]]]:
"""Reverse the input links from the given backwards links"""
new_links: Dict[int, List[Tuple[int, int, int]]] = {}
for origin_chain_id, chain_links in links.items():
for origin_seq_num, target_chain_id, target_seq_num in chain_links:
new_links.setdefault(target_chain_id, []).append(
(target_seq_num, origin_chain_id, origin_seq_num)
)
return new_links
def accumulate_forwards_reachable_events(
conflicted_forwards_reachable: Dict[int, int],
back_links: Dict[int, List[Tuple[int, int, int]]],
conflicted_chain_positions: Dict[str, Tuple[int, int]],
) -> None:
"""Accumulate new forwards reachable events using the back_links provided.
Accumulating forwards reachable information is quite different from backwards reachable information
because _get_chain_links returns the entire linkage information for backwards reachable events,
but not _forwards_ reachable events. We are only interested in the forwards reachable information
that is encoded in the backwards reachable links, so we can just invert all the operations we do
for backwards reachable events to calculate a subset of forwards reachable information. The
caveat with this approach is that it is a _subset_. This means new back_links may encode new
forwards reachable information which we also need. Consider this scenario:
A <-- B <-- C <--- D <-- E <-- F Chain 1
|
`----- G <-- H <-- I Chain 2
|
`---- J <-- K Chain 3
Now consider what happens when B is a conflicted event. _get_chain_links returns the conflicted
chain and ALL links heading towards the root of the graph. This means we will know the
Chain 1 to Chain 2 link via C (as all links for the chain are returned, not strictly ones with
a lower sequence number), but we will NOT know the Chain 2 to Chain 3 link via H. We can be
blissfully unaware of Chain 3 entirely, if and only if there isn't some other conflicted event
on that chain. Consider what happens when K is /also/ conflicted. _get_chain_links will generate
two iterations: one for B and one for K. It's important that we re-evaluate the forwards reachable
information for B to include Chain 3 when we process the K iteration, hence we are "accumulating"
forwards reachability information.
NB: We don't consider 'additional backwards reachable events' here because they have no effect
on forwards reachability calculations, only backwards.
Args:
conflicted_forwards_reachable: The materialised dict of forwards reachable information.
The output to this function are stored here.
back_links: One iteration of _get_chain_links which encodes backwards reachable information.
conflicted_chain_positions: The conflicted events.
"""
# links go backwards but we want them to go forwards as well for v2.1
fwd_links = _generate_forward_links(back_links)
# for each conflicted event, accumulate forwards reachability information
for (
conflicted_chain_id,
conflicted_chain_seq,
) in conflicted_chain_positions.values():
# the conflicted event itself encodes reachability information
# e.g if D was conflicted, it encodes E,F as forwards reachable.
conflicted_forwards_reachable[conflicted_chain_id] = min(
conflicted_chain_seq,
conflicted_forwards_reachable.get(conflicted_chain_id, MAX_CHAIN_LENGTH),
)
# Walk from the conflicted event forwards to explore the links.
# This function checks if we've visited the chain before and skips reprocessing, so this
# does not repeatedly traverse the graph.
_materialize(
conflicted_chain_id,
conflicted_chain_seq,
fwd_links,
conflicted_forwards_reachable,
backwards=False,
)

View File

@@ -354,12 +354,78 @@ class RoomAlias(DomainSpecificString):
@attr.s(slots=True, frozen=True, repr=False)
class RoomID(DomainSpecificString):
"""Structure representing a room id."""
class RoomIdWithDomain(DomainSpecificString):
"""Structure representing a room ID with a domain suffix."""
SIGIL = "!"
# the set of urlsafe base64 characters, no padding.
ROOM_ID_PATTERN_DOMAINLESS = re.compile(r"^[A-Za-z0-9\-_]{43}$")
@attr.define(slots=True, frozen=True, repr=False)
class RoomID:
"""Structure representing a room id without a domain.
There are two forms of room IDs:
- "!localpart:domain" used in most room versions prior to MSC4291.
- "!event_id_base_64" used in room versions post MSC4291.
This class will accept any room ID which meets either of these two criteria.
"""
SIGIL = "!"
id: str
room_id_with_domain: Optional[RoomIdWithDomain]
@classmethod
def is_valid(cls: Type["RoomID"], s: str) -> bool:
if ":" in s:
return RoomIdWithDomain.is_valid(s)
try:
cls.from_string(s)
return True
except Exception:
return False
def get_domain(self) -> Optional[str]:
if not self.room_id_with_domain:
return None
return self.room_id_with_domain.domain
def to_string(self) -> str:
if self.room_id_with_domain:
return self.room_id_with_domain.to_string()
return self.id
__repr__ = to_string
@classmethod
def from_string(cls: Type["RoomID"], s: str) -> "RoomID":
# sigil check
if len(s) < 1 or s[0] != cls.SIGIL:
raise SynapseError(
400,
"Expected %s string to start with '%s'" % (cls.__name__, cls.SIGIL),
Codes.INVALID_PARAM,
)
room_id_with_domain: Optional[RoomIdWithDomain] = None
if ":" in s:
room_id_with_domain = RoomIdWithDomain.from_string(s)
else:
# MSC4291 room IDs must be valid urlsafe unpadded base64
val = s[1:]
if not ROOM_ID_PATTERN_DOMAINLESS.match(val):
raise SynapseError(
400,
"Expected %s string to be valid urlsafe unpadded base64 '%s'"
% (cls.__name__, val),
Codes.INVALID_PARAM,
)
return cls(id=s, room_id_with_domain=room_id_with_domain)
@attr.s(slots=True, frozen=True, repr=False)
class EventID(DomainSpecificString):
"""Structure representing an event ID which is namespaced to a homeserver.

View File

@@ -33,6 +33,7 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.constants import EventTypes, Membership, RoomTypes
from synapse.api.errors import Codes
from synapse.api.room_versions import RoomVersions
from synapse.handlers.pagination import (
PURGE_ROOM_ACTION_NAME,
SHUTDOWN_AND_PURGE_ROOM_ACTION_NAME,
@@ -2892,6 +2893,30 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
"No local admin user in room with power to update power levels.",
)
def test_v12_room(self) -> None:
"""Test that you can be promoted to admin in v12 rooms which won't have the admin the PL event."""
room_id = self.helper.create_room_as(
self.creator,
tok=self.creator_tok,
room_version=RoomVersions.V12.identifier,
)
channel = self.make_request(
"POST",
f"/_synapse/admin/v1/rooms/{room_id}/make_room_admin",
content={},
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
# Now we test that we can join the room and that the admin user has PL 100.
self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok)
pl = self.helper.get_state(
room_id, EventTypes.PowerLevels, tok=self.creator_tok
)
self.assertEquals(pl["users"][self.admin_user], 100)
class BlockRoomTestCase(unittest.HomeserverTestCase):
servlets = [

View File

@@ -45,6 +45,7 @@ from synapse.state.v2 import (
lexicographical_topological_sort,
resolve_events_with_store,
)
from synapse.storage.databases.main.event_federation import StateDifference
from synapse.types import EventID, StateMap
from tests import unittest
@@ -734,7 +735,11 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
store = TestStateResolutionStore(persisted_events)
diff_d = _get_auth_chain_difference(
ROOM_ID, state_sets, unpersited_events, store
ROOM_ID,
state_sets,
unpersited_events,
store,
None,
)
difference = self.successResultOf(defer.ensureDeferred(diff_d))
@@ -791,7 +796,11 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
store = TestStateResolutionStore(persisted_events)
diff_d = _get_auth_chain_difference(
ROOM_ID, state_sets, unpersited_events, store
ROOM_ID,
state_sets,
unpersited_events,
store,
None,
)
difference = self.successResultOf(defer.ensureDeferred(diff_d))
@@ -858,7 +867,11 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
store = TestStateResolutionStore(persisted_events)
diff_d = _get_auth_chain_difference(
ROOM_ID, state_sets, unpersited_events, store
ROOM_ID,
state_sets,
unpersited_events,
store,
None,
)
difference = self.successResultOf(defer.ensureDeferred(diff_d))
@@ -1070,9 +1083,18 @@ class TestStateResolutionStore:
return list(result)
def get_auth_chain_difference(
self, room_id: str, auth_sets: List[Set[str]]
) -> "defer.Deferred[Set[str]]":
self,
room_id: str,
auth_sets: List[Set[str]],
conflicted_state: Optional[Set[str]],
additional_backwards_reachable_conflicted_events: Optional[Set[str]],
) -> "defer.Deferred[StateDifference]":
chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
common = set(chains[0]).intersection(*chains[1:])
return defer.succeed(set(chains[0]).union(*chains[1:]) - common)
return defer.succeed(
StateDifference(
auth_difference=set(chains[0]).union(*chains[1:]) - common,
conflicted_subgraph=set(),
),
)

503
tests/state/test_v21.py Normal file
View File

@@ -0,0 +1,503 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2025 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#
import itertools
from typing import Dict, List, Optional, Sequence, Set
from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase
from synapse.federation.federation_base import event_from_pdu_json
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.state import StateResolutionStore
from synapse.state.v2 import (
StateResolutionStore as StateResolutionStoreInterface,
_get_auth_chain_difference,
_seperate,
resolve_events_with_store,
)
from synapse.types import StateMap
from synapse.util import Clock
from tests import unittest
from tests.state.test_v2 import TestStateResolutionStore
ALICE = "@alice:example.com"
BOB = "@bob:example.com"
CHARLIE = "@charlie:example.com"
EVELYN = "@evelyn:example.com"
ZARA = "@zara:example.com"
ROOM_ID = "!test:example.com"
MEMBERSHIP_CONTENT_JOIN = {"membership": Membership.JOIN}
MEMBERSHIP_CONTENT_INVITE = {"membership": Membership.INVITE}
MEMBERSHIP_CONTENT_LEAVE = {"membership": Membership.LEAVE}
ORIGIN_SERVER_TS = 0
def monotonic_timestamp() -> int:
global ORIGIN_SERVER_TS
ORIGIN_SERVER_TS += 1
return ORIGIN_SERVER_TS
class FakeClock:
def sleep(self, msec: float) -> "defer.Deferred[None]":
return defer.succeed(None)
class StateResV21TestCase(unittest.HomeserverTestCase):
servlets = [
admin.register_servlets,
room.register_servlets,
login.register_servlets,
]
def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self.state = self.hs.get_state_handler()
persistence = self.hs.get_storage_controllers().persistence
assert persistence is not None
self._persistence = persistence
self._state_storage_controller = self.hs.get_storage_controllers().state
self._state_deletion = self.hs.get_datastores().state_deletion
self.store = self.hs.get_datastores().main
self.register_user("user", "pass")
self.token = self.login("user", "pass")
def test_state_reset_replay_conflicted_subgraph(self) -> None:
# 1. Alice creates a room.
e1_create = self.create_event(
EventTypes.Create,
"",
sender=ALICE,
content={"creator": ALICE},
auth_events=[],
)
# 2. Alice joins it.
e2_ma = self.create_event(
EventTypes.Member,
ALICE,
sender=ALICE,
content=MEMBERSHIP_CONTENT_JOIN,
auth_events=[],
prev_events=[e1_create.event_id],
room_id=e1_create.room_id,
)
# 3. Alice is the creator
e3_power1 = self.create_event(
EventTypes.PowerLevels,
"",
sender=ALICE,
content={"users": {}},
auth_events=[e2_ma.event_id],
room_id=e1_create.room_id,
)
# 4. Alice sets the room to public.
e4_jr = self.create_event(
EventTypes.JoinRules,
"",
sender=ALICE,
content={"join_rule": JoinRules.PUBLIC},
auth_events=[e2_ma.event_id, e3_power1.event_id],
room_id=e1_create.room_id,
)
# 5. Bob joins the room.
e5_mb = self.create_event(
EventTypes.Member,
BOB,
sender=BOB,
content=MEMBERSHIP_CONTENT_JOIN,
auth_events=[e3_power1.event_id, e4_jr.event_id],
room_id=e1_create.room_id,
)
# 6. Charlie joins the room.
e6_mc = self.create_event(
EventTypes.Member,
CHARLIE,
sender=CHARLIE,
content=MEMBERSHIP_CONTENT_JOIN,
auth_events=[e3_power1.event_id, e4_jr.event_id],
room_id=e1_create.room_id,
)
# 7. Alice promotes Bob.
e7_power2 = self.create_event(
EventTypes.PowerLevels,
"",
sender=ALICE,
content={"users": {BOB: 50}},
auth_events=[e2_ma.event_id, e3_power1.event_id],
room_id=e1_create.room_id,
)
# 8. Bob promotes Charlie.
e8_power3 = self.create_event(
EventTypes.PowerLevels,
"",
sender=BOB,
content={"users": {BOB: 50, CHARLIE: 50}},
auth_events=[e5_mb.event_id, e7_power2.event_id],
room_id=e1_create.room_id,
)
# 9. Eve joins the room.
e9_me1 = self.create_event(
EventTypes.Member,
EVELYN,
sender=EVELYN,
content=MEMBERSHIP_CONTENT_JOIN,
auth_events=[e8_power3.event_id, e4_jr.event_id],
room_id=e1_create.room_id,
)
# 10. Eve changes her name, /!\\ but cites old power levels /!\
e10_me2 = self.create_event(
EventTypes.Member,
EVELYN,
sender=EVELYN,
content=MEMBERSHIP_CONTENT_JOIN,
auth_events=[
e3_power1.event_id,
e4_jr.event_id,
e9_me1.event_id,
],
room_id=e1_create.room_id,
)
# 11. Zara joins the room, citing the most recent power levels.
e11_mz = self.create_event(
EventTypes.Member,
ZARA,
sender=ZARA,
content=MEMBERSHIP_CONTENT_JOIN,
auth_events=[e8_power3.event_id, e4_jr.event_id],
room_id=e1_create.room_id,
)
# Event 10 above is DODGY: it directly cites old auth events, but indirectly
# cites new ones. If the state after event 10 contains old power level and old
# join events, we are vulnerable to a reset.
dodgy_state_after_eve_rename: StateMap[str] = {
(EventTypes.Create, ""): e1_create.event_id,
(EventTypes.Member, ALICE): e2_ma.event_id,
(EventTypes.Member, BOB): e5_mb.event_id,
(EventTypes.Member, CHARLIE): e6_mc.event_id,
(EventTypes.Member, EVELYN): e10_me2.event_id,
(EventTypes.PowerLevels, ""): e3_power1.event_id, # old and /!\\ DODGY /!\
(EventTypes.JoinRules, ""): e4_jr.event_id,
}
sensible_state_after_zara_joins: StateMap[str] = {
(EventTypes.Create, ""): e1_create.event_id,
(EventTypes.Member, ALICE): e2_ma.event_id,
(EventTypes.Member, BOB): e5_mb.event_id,
(EventTypes.Member, CHARLIE): e6_mc.event_id,
(EventTypes.Member, ZARA): e11_mz.event_id,
(EventTypes.PowerLevels, ""): e8_power3.event_id,
(EventTypes.JoinRules, ""): e4_jr.event_id,
}
expected: StateMap[str] = {
(EventTypes.Create, ""): e1_create.event_id,
(EventTypes.Member, ALICE): e2_ma.event_id,
(EventTypes.Member, BOB): e5_mb.event_id,
(EventTypes.Member, CHARLIE): e6_mc.event_id,
# Expect ME2 replayed first: it's in the POWER 1 epoch
# Then ME1, in the POWER 3 epoch
(EventTypes.Member, EVELYN): e9_me1.event_id,
(EventTypes.Member, ZARA): e11_mz.event_id,
(EventTypes.PowerLevels, ""): e8_power3.event_id,
(EventTypes.JoinRules, ""): e4_jr.event_id,
}
self.get_resolution_and_verify_expected(
[dodgy_state_after_eve_rename, sensible_state_after_zara_joins],
[
e1_create,
e2_ma,
e3_power1,
e4_jr,
e5_mb,
e6_mc,
e7_power2,
e8_power3,
e9_me1,
e10_me2,
e11_mz,
],
expected,
)
def test_state_reset_start_empty_set(self) -> None:
# The join rules reset to missing, when:
# - join rules were in conflict
# - the membership of those join rules' senders were not in conflict
# - those memberships are all leaves.
# 1. Alice creates a room.
e1_create = self.create_event(
EventTypes.Create,
"",
sender=ALICE,
content={"creator": ALICE},
auth_events=[],
)
# 2. Alice joins it.
e2_ma1 = self.create_event(
EventTypes.Member,
ALICE,
sender=ALICE,
content=MEMBERSHIP_CONTENT_JOIN,
auth_events=[],
room_id=e1_create.room_id,
)
# 3. Alice makes Bob an admin.
e3_power = self.create_event(
EventTypes.PowerLevels,
"",
sender=ALICE,
content={"users": {BOB: 100}},
auth_events=[e2_ma1.event_id],
room_id=e1_create.room_id,
)
# 4. Alice sets the room to public.
e4_jr1 = self.create_event(
EventTypes.JoinRules,
"",
sender=ALICE,
content={"join_rule": JoinRules.PUBLIC},
auth_events=[e2_ma1.event_id, e3_power.event_id],
room_id=e1_create.room_id,
)
# 5. Bob joins.
e5_mb = self.create_event(
EventTypes.Member,
BOB,
sender=BOB,
content=MEMBERSHIP_CONTENT_JOIN,
auth_events=[e3_power.event_id, e4_jr1.event_id],
room_id=e1_create.room_id,
)
# 6. Alice sets join rules to invite.
e6_jr2 = self.create_event(
EventTypes.JoinRules,
"",
sender=ALICE,
content={"join_rule": JoinRules.INVITE},
auth_events=[e2_ma1.event_id, e3_power.event_id],
room_id=e1_create.room_id,
)
# 7. Alice then leaves.
e7_ma2 = self.create_event(
EventTypes.Member,
ALICE,
sender=ALICE,
content=MEMBERSHIP_CONTENT_LEAVE,
auth_events=[e3_power.event_id, e2_ma1.event_id],
room_id=e1_create.room_id,
)
correct_state: StateMap[str] = {
(EventTypes.Create, ""): e1_create.event_id,
(EventTypes.Member, ALICE): e7_ma2.event_id,
(EventTypes.Member, BOB): e5_mb.event_id,
(EventTypes.PowerLevels, ""): e3_power.event_id,
(EventTypes.JoinRules, ""): e6_jr2.event_id,
}
# Imagine that another server gives us incorrect state on a fork
# (via e.g. backfill). It cites the old join rules.
incorrect_state: StateMap[str] = {
(EventTypes.Create, ""): e1_create.event_id,
(EventTypes.Member, ALICE): e7_ma2.event_id,
(EventTypes.Member, BOB): e5_mb.event_id,
(EventTypes.PowerLevels, ""): e3_power.event_id,
(EventTypes.JoinRules, ""): e4_jr1.event_id,
}
# Resolving those two should give us the new join rules.
expected: StateMap[str] = {
(EventTypes.Create, ""): e1_create.event_id,
(EventTypes.Member, ALICE): e7_ma2.event_id,
(EventTypes.Member, BOB): e5_mb.event_id,
(EventTypes.PowerLevels, ""): e3_power.event_id,
(EventTypes.JoinRules, ""): e6_jr2.event_id,
}
self.get_resolution_and_verify_expected(
[correct_state, incorrect_state],
[e1_create, e2_ma1, e3_power, e4_jr1, e5_mb, e6_jr2, e7_ma2],
expected,
)
async def _get_auth_difference_and_conflicted_subgraph(
self,
room_id: str,
state_maps: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: StateResolutionStoreInterface,
) -> Set[str]:
_, conflicted_state = _seperate(state_maps)
conflicted_set: Optional[Set[str]] = set(
itertools.chain.from_iterable(conflicted_state.values())
)
if event_map is None:
event_map = {}
return await _get_auth_chain_difference(
room_id,
state_maps,
event_map,
state_res_store,
conflicted_set,
)
def get_resolution_and_verify_expected(
self,
state_maps: Sequence[StateMap[str]],
events: List[EventBase],
expected: StateMap[str],
) -> None:
room_id = events[0].room_id
# First we try everything in-memory to check that the test case works.
event_map = {ev.event_id: ev for ev in events}
resolution = self.successResultOf(
resolve_events_with_store(
FakeClock(),
room_id,
events[0].room_version,
state_maps,
event_map=event_map,
state_res_store=TestStateResolutionStore(event_map),
)
)
self.assertEqual(resolution, expected)
got_auth_diff = self.successResultOf(
self._get_auth_difference_and_conflicted_subgraph(
room_id,
state_maps,
event_map,
TestStateResolutionStore(event_map),
)
)
# we should never see the create event in the auth diff. If we do, this implies the
# conflicted subgraph is wrong and is returning too many old events.
assert events[0].event_id not in got_auth_diff
# now let's make the room exist on the DB, some queries rely on there being a row in
# the rooms table when persisting
self.get_success(
self.store.store_room(
room_id,
events[0].sender,
True,
events[0].room_version,
)
)
def resolve_and_check() -> None:
event_map = {ev.event_id: ev for ev in events}
store = StateResolutionStore(
self._persistence.main_store,
self._state_deletion,
)
resolution = self.get_success(
resolve_events_with_store(
FakeClock(),
room_id,
RoomVersions.HydraV11,
state_maps,
event_map=event_map,
state_res_store=store,
)
)
self.assertEqual(resolution, expected)
got_auth_diff2 = self.get_success(
self._get_auth_difference_and_conflicted_subgraph(
room_id,
state_maps,
event_map,
store,
)
)
# no matter how many events are persisted, the overall diff should always be the same.
self.assertEquals(got_auth_diff, got_auth_diff2)
# now we will drip feed in `events` one-by-one, persisting them then resolving with the
# rest. This ensures we correctly handle mixed persisted/unpersisted events. We will finish
# with doing the test with all persisted events.
while len(events) > 0:
event_to_persist = events.pop(0)
self.persist_event(event_to_persist)
# now retest
resolve_and_check()
def persist_event(
self, event: EventBase, state: Optional[StateMap[str]] = None
) -> None:
"""Persist the event, with optional state"""
context = self.get_success(
self.state.compute_event_context(
event,
state_ids_before_event=state,
partial_state=None if state is None else False,
)
)
self.get_success(self._persistence.persist_event(event, context))
def create_event(
self,
event_type: str,
state_key: Optional[str],
sender: str,
content: Dict,
auth_events: List[str],
prev_events: Optional[List[str]] = None,
room_id: Optional[str] = None,
) -> EventBase:
"""Short-hand for event_from_pdu_json for fields we typically care about.
Tests can override by just calling event_from_pdu_json directly."""
if prev_events is None:
prev_events = []
pdu = {
"type": event_type,
"state_key": state_key,
"content": content,
"sender": sender,
"depth": 5,
"prev_events": prev_events,
"auth_events": auth_events,
"origin_server_ts": monotonic_timestamp(),
}
if event_type != EventTypes.Create:
if room_id is None:
raise Exception("must specify a room_id to create_event")
pdu["room_id"] = room_id
return event_from_pdu_json(
pdu,
RoomVersions.HydraV11,
)

View File

@@ -26,6 +26,7 @@ from typing import (
Iterable,
List,
Mapping,
NamedTuple,
Set,
Tuple,
TypeVar,
@@ -730,6 +731,202 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
)
self.assertSetEqual(difference, set())
def test_conflicted_subgraph(self) -> None:
"""Test that the conflicted subgraph code in state res v2.1 can walk arbitrary length links
and chains.
We construct a chain cover index like this:
A1 <- A2 <- A3
^--- B1 <- B2 <- B3
^--- C1 <- C2 <- C3 v--- G1 <- G2
^--- D1 <- D2 <- D3
^--- E1
^--- F1 <- F2
..and then pick various events to be conflicted / additional backwards reachable to assert
that the code walks the chains correctly. We're particularly interested in ensuring that
the code walks multiple links between chains, hence why we have so many chains.
"""
class TestNode(NamedTuple):
event_id: str
chain_id: int
seq_num: int
class TestLink(NamedTuple):
origin_chain_and_seq: Tuple[int, int]
target_chain_and_seq: Tuple[int, int]
# Map to chain IDs / seq nums
nodes: List[TestNode] = [
TestNode("A1", 1, 1),
TestNode("A2", 1, 2),
TestNode("A3", 1, 3),
TestNode("B1", 2, 1),
TestNode("B2", 2, 2),
TestNode("B3", 2, 3),
TestNode("C1", 3, 1),
TestNode("C2", 3, 2),
TestNode("C3", 3, 3),
TestNode("D1", 4, 1),
TestNode("D2", 4, 2),
TestNode("D3", 4, 3),
TestNode("E1", 5, 1),
TestNode("F1", 6, 1),
TestNode("F2", 6, 2),
TestNode("G1", 7, 1),
TestNode("G2", 7, 2),
]
links: List[TestLink] = [
TestLink((2, 1), (1, 2)), # B1 -> A2
TestLink((3, 1), (2, 2)), # C1 -> B2
TestLink((4, 1), (3, 1)), # D1 -> C1
TestLink((5, 1), (4, 2)), # E1 -> D2
TestLink((6, 1), (5, 1)), # F1 -> E1
TestLink((7, 1), (4, 3)), # G1 -> D3
]
# populate the chain cover index tables as that's all we need
for node in nodes:
self.get_success(
self.store.db_pool.simple_insert(
"event_auth_chains",
{
"event_id": node.event_id,
"chain_id": node.chain_id,
"sequence_number": node.seq_num,
},
desc="insert",
)
)
for link in links:
self.get_success(
self.store.db_pool.simple_insert(
"event_auth_chain_links",
{
"origin_chain_id": link.origin_chain_and_seq[0],
"origin_sequence_number": link.origin_chain_and_seq[1],
"target_chain_id": link.target_chain_and_seq[0],
"target_sequence_number": link.target_chain_and_seq[1],
},
desc="insert",
)
)
# Define the test cases
class TestCase(NamedTuple):
name: str
conflicted: Set[str]
additional_backwards_reachable: Set[str]
want_conflicted_subgraph: Set[str]
# Reminder:
# A1 <- A2 <- A3
# ^--- B1 <- B2 <- B3
# ^--- C1 <- C2 <- C3 v--- G1 <- G2
# ^--- D1 <- D2 <- D3
# ^--- E1
# ^--- F1 <- F2
test_cases = [
TestCase(
name="basic_single_chain",
conflicted={"B1", "B3"},
additional_backwards_reachable=set(),
want_conflicted_subgraph={"B1", "B2", "B3"},
),
TestCase(
name="basic_single_link",
conflicted={"A1", "B2"},
additional_backwards_reachable=set(),
want_conflicted_subgraph={"A1", "A2", "B1", "B2"},
),
TestCase(
name="basic_multi_link",
conflicted={"B1", "F1"},
additional_backwards_reachable=set(),
want_conflicted_subgraph={"B1", "B2", "C1", "D1", "D2", "E1", "F1"},
),
# Repeat these tests but put the later event as an additional backwards reachable event.
# The output should be the same.
TestCase(
name="basic_single_chain_as_additional",
conflicted={"B1"},
additional_backwards_reachable={"B3"},
want_conflicted_subgraph={"B1", "B2", "B3"},
),
TestCase(
name="basic_single_link_as_additional",
conflicted={"A1"},
additional_backwards_reachable={"B2"},
want_conflicted_subgraph={"A1", "A2", "B1", "B2"},
),
TestCase(
name="basic_multi_link_as_additional",
conflicted={"B1"},
additional_backwards_reachable={"F1"},
want_conflicted_subgraph={"B1", "B2", "C1", "D1", "D2", "E1", "F1"},
),
TestCase(
name="mixed_multi_link",
conflicted={"D1", "F1"},
additional_backwards_reachable={"G1"},
want_conflicted_subgraph={"D1", "D2", "D3", "E1", "F1", "G1"},
),
TestCase(
name="additional_backwards_doesnt_add_forwards_info",
conflicted={"C1", "C3"},
# This is on the path to the root but Chain C isn't forwards reachable so this doesn't
# change anything.
additional_backwards_reachable={"B1"},
want_conflicted_subgraph={"C1", "C2", "C3"},
),
TestCase(
name="empty_subgraph",
conflicted={
"B3",
"C3",
}, # these can't reach each other going forwards, so the subgraph is empty
additional_backwards_reachable=set(),
want_conflicted_subgraph={"B3", "C3"},
),
TestCase(
name="empty_subgraph_with_additional",
conflicted={"C1"},
# This is on the path to the root but Chain C isn't forwards reachable so this doesn't
# change anything.
additional_backwards_reachable={"B1"},
want_conflicted_subgraph={"C1"},
),
TestCase(
name="empty_subgraph_single_conflict",
conflicted={"C1"}, # no subgraph can form as you need 2+
additional_backwards_reachable=set(),
want_conflicted_subgraph={"C1"},
),
]
def run_test(txn: LoggingTransaction, test_case: TestCase) -> None:
result = self.store._get_auth_chain_difference_using_cover_index_txn(
txn,
"!not_relevant",
[test_case.conflicted.union(test_case.additional_backwards_reachable)],
test_case.conflicted,
test_case.additional_backwards_reachable,
)
self.assertEquals(
result.conflicted_subgraph,
test_case.want_conflicted_subgraph,
f"{test_case.name} : conflicted subgraph mismatch",
)
for test_case in test_cases:
self.get_success(
self.store.db_pool.runInteraction(
f"test_case_{test_case.name}", run_test, test_case
)
)
@parameterized.expand(
[(room_version,) for room_version in KNOWN_ROOM_VERSIONS.values()]
)

View File

@@ -66,6 +66,10 @@ class ExtremPruneTestCase(HomeserverTestCase):
body = self.helper.send(self.room_id, body="Test", tok=self.token)
local_message_event_id = body["event_id"]
current_state = self.get_success(
self._state_storage_controller.get_current_state_ids(self.room_id)
)
# Fudge a remote event and persist it. This will be the extremity before
# the gap.
self.remote_event_1 = event_from_pdu_json(
@@ -77,7 +81,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
"sender": "@user:other",
"depth": 5,
"prev_events": [local_message_event_id],
"auth_events": [],
"auth_events": [
current_state.get((EventTypes.Create, "")),
current_state.get((EventTypes.PowerLevels, "")),
current_state.get((EventTypes.JoinRules, "")),
],
"origin_server_ts": self.clock.time_msec(),
},
RoomVersions.V6,
@@ -113,6 +121,10 @@ class ExtremPruneTestCase(HomeserverTestCase):
the same domain.
"""
state_before_gap = self.get_success(
self._state_storage_controller.get_current_state_ids(self.room_id)
)
# Fudge a second event which points to an event we don't have. This is a
# state event so that the state changes (otherwise we won't prune the
# extremity as they'll have the same state group).
@@ -125,16 +137,16 @@ class ExtremPruneTestCase(HomeserverTestCase):
"sender": "@user:other",
"depth": 50,
"prev_events": ["$some_unknown_message"],
"auth_events": [],
"auth_events": [
state_before_gap.get((EventTypes.Create, "")),
state_before_gap.get((EventTypes.PowerLevels, "")),
state_before_gap.get((EventTypes.JoinRules, "")),
],
"origin_server_ts": self.clock.time_msec(),
},
RoomVersions.V6,
)
state_before_gap = self.get_success(
self._state_storage_controller.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
@@ -145,6 +157,15 @@ class ExtremPruneTestCase(HomeserverTestCase):
state is different.
"""
# Now we persist it with state with a dropped history visibility
# setting. The state resolution across the old and new event will then
# include it, and so the resolved state won't match the new state.
state_before_gap = dict(
self.get_success(
self._state_storage_controller.get_current_state_ids(self.room_id)
)
)
# Fudge a second event which points to an event we don't have.
remote_event_2 = event_from_pdu_json(
{
@@ -155,20 +176,15 @@ class ExtremPruneTestCase(HomeserverTestCase):
"sender": "@user:other",
"depth": 10,
"prev_events": ["$some_unknown_message"],
"auth_events": [],
"auth_events": [
state_before_gap.get((EventTypes.Create, "")),
state_before_gap.get((EventTypes.PowerLevels, "")),
],
"origin_server_ts": self.clock.time_msec(),
},
RoomVersions.V6,
)
# Now we persist it with state with a dropped history visibility
# setting. The state resolution across the old and new event will then
# include it, and so the resolved state won't match the new state.
state_before_gap = dict(
self.get_success(
self._state_storage_controller.get_current_state_ids(self.room_id)
)
)
state_before_gap.pop(("m.room.history_visibility", ""))
context = self.get_success(
@@ -193,6 +209,10 @@ class ExtremPruneTestCase(HomeserverTestCase):
# also set the depth to "lots".
self.reactor.advance(7 * 24 * 60 * 60)
state_before_gap = self.get_success(
self._state_storage_controller.get_current_state_ids(self.room_id)
)
# Fudge a second event which points to an event we don't have. This is a
# state event so that the state changes (otherwise we won't prune the
# extremity as they'll have the same state group).
@@ -205,16 +225,16 @@ class ExtremPruneTestCase(HomeserverTestCase):
"sender": "@user:other2",
"depth": 10000,
"prev_events": ["$some_unknown_message"],
"auth_events": [],
"auth_events": [
state_before_gap.get((EventTypes.Create, "")),
state_before_gap.get((EventTypes.PowerLevels, "")),
state_before_gap.get((EventTypes.JoinRules, "")),
],
"origin_server_ts": self.clock.time_msec(),
},
RoomVersions.V6,
)
state_before_gap = self.get_success(
self._state_storage_controller.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
@@ -225,6 +245,10 @@ class ExtremPruneTestCase(HomeserverTestCase):
from a different domain.
"""
state_before_gap = self.get_success(
self._state_storage_controller.get_current_state_ids(self.room_id)
)
# Fudge a second event which points to an event we don't have. This is a
# state event so that the state changes (otherwise we won't prune the
# extremity as they'll have the same state group).
@@ -237,16 +261,15 @@ class ExtremPruneTestCase(HomeserverTestCase):
"sender": "@user:other2",
"depth": 10,
"prev_events": ["$some_unknown_message"],
"auth_events": [],
"auth_events": [
state_before_gap.get((EventTypes.Create, "")),
state_before_gap.get((EventTypes.PowerLevels, "")),
],
"origin_server_ts": self.clock.time_msec(),
},
RoomVersions.V6,
)
state_before_gap = self.get_success(
self._state_storage_controller.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
@@ -267,6 +290,10 @@ class ExtremPruneTestCase(HomeserverTestCase):
# also set the depth to "lots".
self.reactor.advance(7 * 24 * 60 * 60)
state_before_gap = self.get_success(
self._state_storage_controller.get_current_state_ids(self.room_id)
)
# Fudge a second event which points to an event we don't have. This is a
# state event so that the state changes (otherwise we won't prune the
# extremity as they'll have the same state group).
@@ -279,16 +306,16 @@ class ExtremPruneTestCase(HomeserverTestCase):
"sender": "@user:other2",
"depth": 10000,
"prev_events": ["$some_unknown_message"],
"auth_events": [],
"auth_events": [
state_before_gap.get((EventTypes.Create, "")),
state_before_gap.get((EventTypes.PowerLevels, "")),
state_before_gap.get((EventTypes.JoinRules, "")),
],
"origin_server_ts": self.clock.time_msec(),
},
RoomVersions.V6,
)
state_before_gap = self.get_success(
self._state_storage_controller.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
@@ -311,6 +338,10 @@ class ExtremPruneTestCase(HomeserverTestCase):
# also set the depth to "lots".
self.reactor.advance(7 * 24 * 60 * 60)
state_before_gap = self.get_success(
self._state_storage_controller.get_current_state_ids(self.room_id)
)
# Fudge a second event which points to an event we don't have. This is a
# state event so that the state changes (otherwise we won't prune the
# extremity as they'll have the same state group).
@@ -323,16 +354,16 @@ class ExtremPruneTestCase(HomeserverTestCase):
"sender": "@user:other2",
"depth": 10000,
"prev_events": ["$some_unknown_message"],
"auth_events": [],
"auth_events": [
state_before_gap.get((EventTypes.Create, "")),
state_before_gap.get((EventTypes.PowerLevels, "")),
state_before_gap.get((EventTypes.JoinRules, "")),
],
"origin_server_ts": self.clock.time_msec(),
},
RoomVersions.V6,
)
state_before_gap = self.get_success(
self._state_storage_controller.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
@@ -347,6 +378,10 @@ class ExtremPruneTestCase(HomeserverTestCase):
local_message_event_id = body["event_id"]
self.assert_extremities([local_message_event_id])
state_before_gap = self.get_success(
self._state_storage_controller.get_current_state_ids(self.room_id)
)
# Fudge a second event which points to an event we don't have. This is a
# state event so that the state changes (otherwise we won't prune the
# extremity as they'll have the same state group).
@@ -359,16 +394,16 @@ class ExtremPruneTestCase(HomeserverTestCase):
"sender": "@user:other2",
"depth": 10000,
"prev_events": ["$some_unknown_message"],
"auth_events": [],
"auth_events": [
state_before_gap.get((EventTypes.Create, "")),
state_before_gap.get((EventTypes.PowerLevels, "")),
state_before_gap.get((EventTypes.JoinRules, "")),
],
"origin_server_ts": self.clock.time_msec(),
},
RoomVersions.V6,
)
state_before_gap = self.get_success(
self._state_storage_controller.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.

View File

@@ -25,7 +25,7 @@ from canonicaljson import json
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.events import EventBase
from synapse.events.builder import EventBuilder
from synapse.server import HomeServer
@@ -263,12 +263,17 @@ class RedactionTestCase(unittest.HomeserverTestCase):
@property
def room_id(self) -> str:
assert self._base_builder.room_id is not None
return self._base_builder.room_id
@property
def type(self) -> str:
return self._base_builder.type
@property
def room_version(self) -> RoomVersion:
return self._base_builder.room_version
@property
def internal_metadata(self) -> EventInternalMetadata:
return self._base_builder.internal_metadata

51
tests/types/test_init.py Normal file
View File

@@ -0,0 +1,51 @@
from synapse.api.errors import SynapseError
from synapse.types import RoomID
from tests.unittest import TestCase
class RoomIDTestCase(TestCase):
def test_can_create_msc4291_room_ids(self) -> None:
valid_msc4291_room_id = "!31hneApxJ_1o-63DmFrpeqnkFfWppnzWso1JvH3ogLM"
room_id = RoomID.from_string(valid_msc4291_room_id)
self.assertEquals(RoomID.is_valid(valid_msc4291_room_id), True)
self.assertEquals(
room_id.to_string(),
valid_msc4291_room_id,
)
self.assertEquals(room_id.id, "!31hneApxJ_1o-63DmFrpeqnkFfWppnzWso1JvH3ogLM")
self.assertEquals(room_id.get_domain(), None)
def test_cannot_create_invalid_msc4291_room_ids(self) -> None:
invalid_room_ids = [
"!wronglength",
"!31hneApxJ_1o-63DmFrpeqnNOTurlsafeBASE64/gLM",
"!",
"! ",
]
for bad_room_id in invalid_room_ids:
with self.assertRaises(SynapseError):
RoomID.from_string(bad_room_id)
if not RoomID.is_valid(bad_room_id):
raise SynapseError(400, "invalid")
def test_cannot_create_invalid_legacy_room_ids(self) -> None:
invalid_room_ids = [
"!something:invalid$_chars.com",
]
for bad_room_id in invalid_room_ids:
with self.assertRaises(SynapseError):
RoomID.from_string(bad_room_id)
if not RoomID.is_valid(bad_room_id):
raise SynapseError(400, "invalid")
def test_can_create_valid_legacy_room_ids(self) -> None:
valid_room_ids = [
"!foo:example.com",
"!foo:example.com:8448",
"!💩💩💩:example.com",
]
for room_id_str in valid_room_ids:
room_id = RoomID.from_string(room_id_str)
self.assertEquals(RoomID.is_valid(room_id_str), True)
self.assertIsNotNone(room_id.get_domain())