diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index a0dd661c70..68303308cd 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -4174,7 +4174,7 @@ The default power levels for each preset are: "m.room.history_visibility": 100 "m.room.canonical_alias": 50 "m.room.avatar": 50 -"m.room.tombstone": 100 +"m.room.tombstone": 100 (150 if MSC4289 is used) "m.room.server_acl": 100 "m.room.encryption": 100 ``` diff --git a/schema/synapse-config.schema.yaml b/schema/synapse-config.schema.yaml index af59ec17e7..584f6e0ae8 100644 --- a/schema/synapse-config.schema.yaml +++ b/schema/synapse-config.schema.yaml @@ -5184,7 +5184,7 @@ properties: "m.room.avatar": 50 - "m.room.tombstone": 100 + "m.room.tombstone": 100 (150 if MSC4289 is used) "m.room.server_acl": 100 diff --git a/synapse/api/constants.py b/synapse/api/constants.py index f60a94ffc3..7a8f546d6b 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -46,6 +46,9 @@ MAX_USERID_LENGTH = 255 # Constant value used for the pseudo-thread which is the main timeline. MAIN_TIMELINE: Final = "main" +# MAX_INT + 1, so it always trumps any PL in canonical JSON. +CREATOR_POWER_LEVEL = 2**53 + class Membership: """Represents the membership states of a user in a room.""" @@ -235,6 +238,8 @@ class EventContentFields: # # This is deprecated in MSC2175. ROOM_CREATOR: Final = "creator" + # MSC4289 + ADDITIONAL_CREATORS: Final = "additional_creators" # The version of the room for `m.room.create` events. ROOM_VERSION: Final = "room_version" diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py index 4bde385f78..12b73546d0 100644 --- a/synapse/api/room_versions.py +++ b/synapse/api/room_versions.py @@ -36,12 +36,14 @@ class EventFormatVersions: ROOM_V1_V2 = 1 # $id:server event id format: used for room v1 and v2 ROOM_V3 = 2 # MSC1659-style $hash event id format: used for room v3 ROOM_V4_PLUS = 3 # MSC1884-style $hash format: introduced for room v4 + ROOM_V11_HYDRA_PLUS = 4 # MSC4291 room IDs as hashes: introduced for room HydraV11 KNOWN_EVENT_FORMAT_VERSIONS = { EventFormatVersions.ROOM_V1_V2, EventFormatVersions.ROOM_V3, EventFormatVersions.ROOM_V4_PLUS, + EventFormatVersions.ROOM_V11_HYDRA_PLUS, } @@ -50,6 +52,7 @@ class StateResolutionVersions: V1 = 1 # room v1 state res V2 = 2 # MSC1442 state res: room v2 and later + V2_1 = 3 # MSC4297 state res class RoomDisposition: @@ -109,6 +112,10 @@ class RoomVersion: msc3931_push_features: Tuple[str, ...] 
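Aside on the new CREATOR_POWER_LEVEL constant above: a minimal, self-contained sketch (plain Python, not part of the patch) of why 2**53 always outranks any power level that can appear in a valid event.

# Canonical JSON restricts integers to the range [-(2**53) + 1, 2**53 - 1], so
# CREATOR_POWER_LEVEL (2**53) is strictly greater than any power level that can
# legally appear in an m.room.power_levels event.
CANONICAL_JSON_MAX_INT = 2**53 - 1
CREATOR_POWER_LEVEL = 2**53


def creator_outranks(power_level_from_event: int) -> bool:
    # Any PL that survived canonical JSON validation fits in the allowed range.
    assert abs(power_level_from_event) <= CANONICAL_JSON_MAX_INT
    return CREATOR_POWER_LEVEL > power_level_from_event


assert creator_outranks(100)
assert creator_outranks(CANONICAL_JSON_MAX_INT)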
# values from PushRuleRoomFlag # MSC3757: Restricting who can overwrite a state event msc3757_enabled: bool + # MSC4289: Creator power enabled + msc4289_creator_power_enabled: bool + # MSC4291: Room IDs as hashes of the create event + msc4291_room_ids_as_hashes: bool class RoomVersions: @@ -131,6 +138,8 @@ class RoomVersions: enforce_int_power_levels=False, msc3931_push_features=(), msc3757_enabled=False, + msc4289_creator_power_enabled=False, + msc4291_room_ids_as_hashes=False, ) V2 = RoomVersion( "2", @@ -151,6 +160,8 @@ class RoomVersions: enforce_int_power_levels=False, msc3931_push_features=(), msc3757_enabled=False, + msc4289_creator_power_enabled=False, + msc4291_room_ids_as_hashes=False, ) V3 = RoomVersion( "3", @@ -171,6 +182,8 @@ class RoomVersions: enforce_int_power_levels=False, msc3931_push_features=(), msc3757_enabled=False, + msc4289_creator_power_enabled=False, + msc4291_room_ids_as_hashes=False, ) V4 = RoomVersion( "4", @@ -191,6 +204,8 @@ class RoomVersions: enforce_int_power_levels=False, msc3931_push_features=(), msc3757_enabled=False, + msc4289_creator_power_enabled=False, + msc4291_room_ids_as_hashes=False, ) V5 = RoomVersion( "5", @@ -211,6 +226,8 @@ class RoomVersions: enforce_int_power_levels=False, msc3931_push_features=(), msc3757_enabled=False, + msc4289_creator_power_enabled=False, + msc4291_room_ids_as_hashes=False, ) V6 = RoomVersion( "6", @@ -231,6 +248,8 @@ class RoomVersions: enforce_int_power_levels=False, msc3931_push_features=(), msc3757_enabled=False, + msc4289_creator_power_enabled=False, + msc4291_room_ids_as_hashes=False, ) V7 = RoomVersion( "7", @@ -251,6 +270,8 @@ class RoomVersions: enforce_int_power_levels=False, msc3931_push_features=(), msc3757_enabled=False, + msc4289_creator_power_enabled=False, + msc4291_room_ids_as_hashes=False, ) V8 = RoomVersion( "8", @@ -271,6 +292,8 @@ class RoomVersions: enforce_int_power_levels=False, msc3931_push_features=(), msc3757_enabled=False, + msc4289_creator_power_enabled=False, + msc4291_room_ids_as_hashes=False, ) V9 = RoomVersion( "9", @@ -291,6 +314,8 @@ class RoomVersions: enforce_int_power_levels=False, msc3931_push_features=(), msc3757_enabled=False, + msc4289_creator_power_enabled=False, + msc4291_room_ids_as_hashes=False, ) V10 = RoomVersion( "10", @@ -311,6 +336,8 @@ class RoomVersions: enforce_int_power_levels=True, msc3931_push_features=(), msc3757_enabled=False, + msc4289_creator_power_enabled=False, + msc4291_room_ids_as_hashes=False, ) MSC1767v10 = RoomVersion( # MSC1767 (Extensible Events) based on room version "10" @@ -332,6 +359,8 @@ class RoomVersions: enforce_int_power_levels=True, msc3931_push_features=(PushRuleRoomFlag.EXTENSIBLE_EVENTS,), msc3757_enabled=False, + msc4289_creator_power_enabled=False, + msc4291_room_ids_as_hashes=False, ) MSC3757v10 = RoomVersion( # MSC3757 (Restricting who can overwrite a state event) based on room version "10" @@ -353,6 +382,8 @@ class RoomVersions: enforce_int_power_levels=True, msc3931_push_features=(), msc3757_enabled=True, + msc4289_creator_power_enabled=False, + msc4291_room_ids_as_hashes=False, ) V11 = RoomVersion( "11", @@ -373,6 +404,8 @@ class RoomVersions: enforce_int_power_levels=True, msc3931_push_features=(), msc3757_enabled=False, + msc4289_creator_power_enabled=False, + msc4291_room_ids_as_hashes=False, ) MSC3757v11 = RoomVersion( # MSC3757 (Restricting who can overwrite a state event) based on room version "11" @@ -394,6 +427,52 @@ class RoomVersions: enforce_int_power_levels=True, msc3931_push_features=(), msc3757_enabled=True, + 
msc4289_creator_power_enabled=False, + msc4291_room_ids_as_hashes=False, + ) + HydraV11 = RoomVersion( + "org.matrix.hydra.11", + RoomDisposition.UNSTABLE, + EventFormatVersions.ROOM_V11_HYDRA_PLUS, + StateResolutionVersions.V2_1, # Changed from v11 + enforce_key_validity=True, + special_case_aliases_auth=False, + strict_canonicaljson=True, + limit_notifications_power_levels=True, + implicit_room_creator=True, # Used by MSC3820 + updated_redaction_rules=True, # Used by MSC3820 + restricted_join_rule=True, + restricted_join_rule_fix=True, + knock_join_rule=True, + msc3389_relation_redactions=False, + knock_restricted_join_rule=True, + enforce_int_power_levels=True, + msc3931_push_features=(), + msc3757_enabled=False, + msc4289_creator_power_enabled=True, # Changed from v11 + msc4291_room_ids_as_hashes=True, # Changed from v11 + ) + V12 = RoomVersion( + "12", + RoomDisposition.STABLE, + EventFormatVersions.ROOM_V11_HYDRA_PLUS, + StateResolutionVersions.V2_1, # Changed from v11 + enforce_key_validity=True, + special_case_aliases_auth=False, + strict_canonicaljson=True, + limit_notifications_power_levels=True, + implicit_room_creator=True, # Used by MSC3820 + updated_redaction_rules=True, # Used by MSC3820 + restricted_join_rule=True, + restricted_join_rule_fix=True, + knock_join_rule=True, + msc3389_relation_redactions=False, + knock_restricted_join_rule=True, + enforce_int_power_levels=True, + msc3931_push_features=(), + msc3757_enabled=False, + msc4289_creator_power_enabled=True, # Changed from v11 + msc4291_room_ids_as_hashes=True, # Changed from v11 ) @@ -411,6 +490,7 @@ KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = { RoomVersions.V9, RoomVersions.V10, RoomVersions.V11, + RoomVersions.V12, RoomVersions.MSC3757v10, RoomVersions.MSC3757v11, ) diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py index b85df1ce42..c36398cec0 100644 --- a/synapse/crypto/event_signing.py +++ b/synapse/crypto/event_signing.py @@ -101,6 +101,9 @@ def compute_content_hash( event_dict.pop("outlier", None) event_dict.pop("destinations", None) + # N.B. no need to pop the room_id from create events in MSC4291 rooms + # as they shouldn't have one. + event_json_bytes = encode_canonical_json(event_dict) hashed = hash_algorithm(event_json_bytes) diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 35d02c8294..64de3f7ef8 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -45,6 +45,7 @@ from signedjson.sign import SignatureVerifyException, verify_signed_json from unpaddedbase64 import decode_base64 from synapse.api.constants import ( + CREATOR_POWER_LEVEL, MAX_PDU_SIZE, EventContentFields, EventTypes, @@ -64,6 +65,7 @@ from synapse.api.room_versions import ( RoomVersion, RoomVersions, ) +from synapse.events import is_creator from synapse.state import CREATE_KEY from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.types import ( @@ -261,7 +263,8 @@ async def check_state_independent_auth_rules( f"Event {event.event_id} has unexpected auth_event for {k}: {auth_event_id}", ) - # We also need to check that the auth event itself is not rejected. + # 2.3 ... If there are entries which were themselves rejected under the checks performed on receipt + # of a PDU, reject. if auth_event.rejected_reason: raise AuthError( 403, @@ -271,7 +274,7 @@ async def check_state_independent_auth_rules( auth_dict[k] = auth_event_id - # 3. If event does not have a m.room.create in its auth_events, reject. + # 2.4. 
If event does not have a m.room.create in its auth_events, reject. creation_event = auth_dict.get((EventTypes.Create, ""), None) if not creation_event: raise AuthError(403, "No create event in auth events") @@ -311,13 +314,14 @@ def check_state_dependent_auth_rules( # Later code relies on there being a create event e.g _can_federate, _is_membership_change_allowed # so produce a more intelligible error if we don't have one. - if auth_dict.get(CREATE_KEY) is None: + create_event = auth_dict.get(CREATE_KEY) + if create_event is None: raise AuthError( 403, f"Event {event.event_id} is missing a create event in auth_events." ) # additional check for m.federate - creating_domain = get_domain_from_id(event.room_id) + creating_domain = get_domain_from_id(create_event.sender) originating_domain = get_domain_from_id(event.sender) if creating_domain != originating_domain: if not _can_federate(event, auth_dict): @@ -470,12 +474,20 @@ def _check_create(event: "EventBase") -> None: if event.prev_event_ids(): raise AuthError(403, "Create event has prev events") - # 1.2 If the domain of the room_id does not match the domain of the sender, - # reject. - sender_domain = get_domain_from_id(event.sender) - room_id_domain = get_domain_from_id(event.room_id) - if room_id_domain != sender_domain: - raise AuthError(403, "Creation event's room_id domain does not match sender's") + if event.room_version.msc4291_room_ids_as_hashes: + # 1.2 If the create event has a room_id, reject + if "room_id" in event: + raise AuthError(403, "Create event has a room_id") + else: + # 1.2 If the domain of the room_id does not match the domain of the sender, + # reject. + if not event.room_version.msc4291_room_ids_as_hashes: + sender_domain = get_domain_from_id(event.sender) + room_id_domain = get_domain_from_id(event.room_id) + if room_id_domain != sender_domain: + raise AuthError( + 403, "Creation event's room_id domain does not match sender's" + ) # 1.3 If content.room_version is present and is not a recognised version, reject room_version_prop = event.content.get("room_version", "1") @@ -492,6 +504,16 @@ def _check_create(event: "EventBase") -> None: ): raise AuthError(403, "Create event lacks a 'creator' property") + # 1.5 If the additional_creators field is present and is not an array of strings where each + # string is a valid user ID, reject. + if ( + event.room_version.msc4289_creator_power_enabled + and EventContentFields.ADDITIONAL_CREATORS in event.content + ): + check_valid_additional_creators( + event.content[EventContentFields.ADDITIONAL_CREATORS] + ) + def _can_federate(event: "EventBase", auth_events: StateMap["EventBase"]) -> bool: creation_event = auth_events.get((EventTypes.Create, "")) @@ -533,7 +555,13 @@ def _is_membership_change_allowed( target_user_id = event.state_key - creating_domain = get_domain_from_id(event.room_id) + # We need the create event in order to check if we can federate or not. + # If it's missing, yell loudly. Previously we only did this inside the + # _can_federate check. 
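As context for the creating_domain change above, a hedged sketch of why the creating server can no longer be parsed out of the room ID under MSC4291. The get_domain_from_id below is a simplified stand-in for synapse.types.get_domain_from_id, not the real helper.

def get_domain_from_id(identifier: str) -> str:
    # simplified stand-in: everything after the first colon
    return identifier.split(":", 1)[1]


# Room v11 and earlier: the room ID carries the creating server's name.
assert get_domain_from_id("!OGEhHVWSdvArJzumhm:example.org") == "example.org"

# MSC4291: the room ID is "!" + a hash with no server part, so the creating
# domain is taken from the create event's sender instead.
create_event_sender = "@alice:example.org"
assert get_domain_from_id(create_event_sender) == "example.org"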
+ create_event = auth_events.get((EventTypes.Create, "")) + if not create_event: + raise AuthError(403, "Create event missing from auth_events") + creating_domain = get_domain_from_id(create_event.sender) target_domain = get_domain_from_id(target_user_id) if creating_domain != target_domain: if not _can_federate(event, auth_events): @@ -903,6 +931,32 @@ def _check_power_levels( except Exception: raise SynapseError(400, "Not a valid power level: %s" % (v,)) + if room_version_obj.msc4289_creator_power_enabled: + # Enforce the creator does not appear in the users map + create_event = auth_events.get((EventTypes.Create, "")) + if not create_event: + raise SynapseError( + 400, "Cannot check power levels without a create event in auth_events" + ) + if create_event.sender in user_list: + raise SynapseError( + 400, + "Creator user %s must not appear in content.users" + % (create_event.sender,), + ) + additional_creators = create_event.content.get( + EventContentFields.ADDITIONAL_CREATORS, [] + ) + if additional_creators: + creators_in_user_list = set(additional_creators).intersection( + set(user_list) + ) + if len(creators_in_user_list) > 0: + raise SynapseError( + 400, + "Additional creators users must not appear in content.users", + ) + # Reject events with stringy power levels if required by room version if ( event.type == EventTypes.PowerLevels @@ -1028,6 +1082,9 @@ def get_user_power_level(user_id: str, auth_events: StateMap["EventBase"]) -> in "A create event in the auth events chain is required to calculate user power level correctly," " but was not found. This indicates a bug" ) + if create_event.room_version.msc4289_creator_power_enabled: + if is_creator(create_event, user_id): + return CREATOR_POWER_LEVEL power_level_event = get_power_level_event(auth_events) if power_level_event: level = power_level_event.content.get("users", {}).get(user_id) @@ -1188,3 +1245,26 @@ def auth_types_for_event( auth_types.add(key) return auth_types + + +def check_valid_additional_creators(additional_creators: Any) -> None: + """Check if the additional_creators provided is valid according to MSC4289. + + The additional_creators can be supplied from an m.room.create event or from an /upgrade request. + + Raises: + AuthError if the additional_creators is invalid for some reason. + """ + if type(additional_creators) is not list: + raise AuthError(400, "additional_creators must be an array") + for entry in additional_creators: + if type(entry) is not str: + raise AuthError(400, "entry in additional_creators is not a string") + if not UserID.is_valid(entry): + raise AuthError(400, "entry in additional_creators is not a valid user ID") + # UserID.is_valid doesn't actually validate everything, so check the rest manually. 
+ if len(entry) > 255 or len(entry.encode("utf-8")) > 255: + raise AuthError( + 400, + "entry in additional_creators too long", + ) diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index c77d569e2e..db38754280 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -41,10 +41,13 @@ from typing import ( import attr from unpaddedbase64 import encode_base64 -from synapse.api.constants import EventTypes, RelationTypes +from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions from synapse.synapse_rust.events import EventInternalMetadata -from synapse.types import JsonDict, StrCollection +from synapse.types import ( + JsonDict, + StrCollection, +) from synapse.util.caches import intern_dict from synapse.util.frozenutils import freeze @@ -209,7 +212,6 @@ class EventBase(metaclass=abc.ABCMeta): content: DictProperty[JsonDict] = DictProperty("content") hashes: DictProperty[Dict[str, str]] = DictProperty("hashes") origin_server_ts: DictProperty[int] = DictProperty("origin_server_ts") - room_id: DictProperty[str] = DictProperty("room_id") sender: DictProperty[str] = DictProperty("sender") # TODO state_key should be Optional[str]. This is generally asserted in Synapse # by calling is_state() first (which ensures it is not None), but it is hard (not possible?) @@ -224,6 +226,10 @@ class EventBase(metaclass=abc.ABCMeta): def event_id(self) -> str: raise NotImplementedError() + @property + def room_id(self) -> str: + raise NotImplementedError() + @property def membership(self) -> str: return self.content["membership"] @@ -386,6 +392,10 @@ class FrozenEvent(EventBase): def event_id(self) -> str: return self._event_id + @property + def room_id(self) -> str: + return self._dict["room_id"] + class FrozenEventV2(EventBase): format_version = EventFormatVersions.ROOM_V3 # All events of this type are V2 @@ -443,6 +453,10 @@ class FrozenEventV2(EventBase): self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1]) return self._event_id + @property + def room_id(self) -> str: + return self._dict["room_id"] + def prev_event_ids(self) -> List[str]: """Returns the list of prev event IDs. The order matches the order specified in the event, though there is no meaning to it. @@ -481,6 +495,67 @@ class FrozenEventV3(FrozenEventV2): return self._event_id +class FrozenEventV4(FrozenEventV3): + """FrozenEventV4 for MSC4291 room IDs are hashes""" + + format_version = EventFormatVersions.ROOM_V11_HYDRA_PLUS + + """Override the room_id for m.room.create events""" + + def __init__( + self, + event_dict: JsonDict, + room_version: RoomVersion, + internal_metadata_dict: Optional[JsonDict] = None, + rejected_reason: Optional[str] = None, + ): + super().__init__( + event_dict=event_dict, + room_version=room_version, + internal_metadata_dict=internal_metadata_dict, + rejected_reason=rejected_reason, + ) + self._room_id: Optional[str] = None + + @property + def room_id(self) -> str: + # if we have calculated the room ID already, don't do it again. 
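A hedged sketch of the identifier relationship the room_id property below computes (the sha256 input here is only a stand-in for the create event's reference hash): under MSC4291 the room ID and the create event ID are the same unpadded URL-safe base64 hash behind different sigils, which is why other parts of this patch recover the create event ID as "$" + room_id[1:].

import hashlib

from unpaddedbase64 import encode_base64

# stand-in bytes; in Synapse this is compute_event_reference_hash(create_event)[1]
reference_hash = hashlib.sha256(b"stand-in for the create event's reference hash").digest()
encoded = encode_base64(reference_hash, urlsafe=True)

room_id = "!" + encoded
create_event_id = "$" + encoded

assert create_event_id == "$" + room_id[1:]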
+ if self._room_id: + return self._room_id + + is_create_event = self.type == EventTypes.Create and self.get_state_key() == "" + + # for non-create events: use the supplied value from the JSON, as per FrozenEventV3 + if not is_create_event: + self._room_id = self._dict["room_id"] + assert self._room_id is not None + return self._room_id + + # for create events: calculate the room ID + from synapse.crypto.event_signing import compute_event_reference_hash + + self._room_id = "!" + encode_base64( + compute_event_reference_hash(self)[1], urlsafe=True + ) + return self._room_id + + def auth_event_ids(self) -> StrCollection: + """Returns the list of auth event IDs. The order matches the order + specified in the event, though there is no meaning to it. + Returns: + The list of event IDs of this event's auth_events + Includes the creation event ID for convenience of all the codepaths + which expects the auth chain to include the creator ID, even though + it's explicitly not included on the wire. Excludes the create event + for the create event itself. + """ + create_event_id = "$" + self.room_id[1:] + assert create_event_id not in self._dict["auth_events"] + if self.type == EventTypes.Create and self.get_state_key() == "": + return self._dict["auth_events"] # should be [] + return self._dict["auth_events"] + [create_event_id] + + def _event_type_from_format_version( format_version: int, ) -> Type[Union[FrozenEvent, FrozenEventV2, FrozenEventV3]]: @@ -500,6 +575,8 @@ def _event_type_from_format_version( return FrozenEventV2 elif format_version == EventFormatVersions.ROOM_V4_PLUS: return FrozenEventV3 + elif format_version == EventFormatVersions.ROOM_V11_HYDRA_PLUS: + return FrozenEventV4 else: raise Exception("No event format %r" % (format_version,)) @@ -559,6 +636,23 @@ def relation_from_event(event: EventBase) -> Optional[_EventRelation]: return _EventRelation(parent_id, rel_type, aggregation_key) +def is_creator(create: EventBase, user_id: str) -> bool: + """ + Return true if the provided user ID is the room creator. + + This includes additional creators in MSC4289. + """ + assert create.type == EventTypes.Create + if create.sender == user_id: + return True + if create.room_version.msc4289_creator_power_enabled: + additional_creators = set( + create.content.get(EventContentFields.ADDITIONAL_CREATORS, []) + ) + return user_id in additional_creators + return False + + @attr.s(slots=True, frozen=True, auto_attribs=True) class StrippedStateEvent: """ diff --git a/synapse/events/builder.py b/synapse/events/builder.py index afb04881df..5e1913d389 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -82,7 +82,8 @@ class EventBuilder: room_version: RoomVersion - room_id: str + # MSC4291 makes the room ID == the create event ID. This means the create event has no room_id. + room_id: Optional[str] type: str sender: str @@ -142,7 +143,14 @@ class EventBuilder: Returns: The signed and hashed event. """ + # Create events always have empty auth_events. 
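A hedged sketch (plain lists and strings, not real Synapse events) of the auth_events round trip that FrozenEventV4.auth_event_ids() above and the builder below implement between them: MSC4291 events omit the create event from auth_events on the wire, and it is re-added in memory because the room ID already pins the create event.

from typing import List


def strip_create_event_for_wire(room_id: str, auth_event_ids: List[str]) -> List[str]:
    # the create event ID is derivable from the room ID, so drop it on the wire
    create_event_id = "$" + room_id[1:]
    return [e for e in auth_event_ids if e != create_event_id]


def auth_event_ids_in_memory(room_id: str, wire_auth_events: List[str]) -> List[str]:
    # re-add the create event so existing codepaths still see it in the auth chain
    return list(wire_auth_events) + ["$" + room_id[1:]]


room_id = "!abc123"
on_the_wire = strip_create_event_for_wire(room_id, ["$abc123", "$power", "$member"])
assert on_the_wire == ["$power", "$member"]
assert "$abc123" in auth_event_ids_in_memory(room_id, on_the_wire)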
+ if self.type == EventTypes.Create and self.is_state() and self.state_key == "": + auth_event_ids = [] + + # Calculate auth_events for non-create events if auth_event_ids is None: + # Every non-create event must have a room ID + assert self.room_id is not None state_ids = await self._state.compute_state_after_events( self.room_id, prev_event_ids, @@ -224,12 +232,31 @@ class EventBuilder: "auth_events": auth_events, "prev_events": prev_events, "type": self.type, - "room_id": self.room_id, "sender": self.sender, "content": self.content, "unsigned": self.unsigned, "depth": depth, } + if self.room_id is not None: + event_dict["room_id"] = self.room_id + + if self.room_version.msc4291_room_ids_as_hashes: + # In MSC4291: the create event has no room ID as the create event ID /is/ the room ID. + if ( + self.type == EventTypes.Create + and self.is_state() + and self._state_key == "" + ): + assert self.room_id is None + else: + # All other events do not reference the create event in auth_events, as the room ID + # /is/ the create event. However, the rest of the code (for consistency between room + # versions) assume that the create event remains part of the auth events. c.f. event + # class which automatically adds the create event when `.auth_event_ids()` is called + assert self.room_id is not None + create_event_id = "$" + self.room_id[1:] + auth_event_ids.remove(create_event_id) + event_dict["auth_events"] = auth_event_ids if self.is_state(): event_dict["state_key"] = self._state_key @@ -285,7 +312,7 @@ class EventBuilderFactory: room_version=room_version, type=key_values["type"], state_key=key_values.get("state_key"), - room_id=key_values["room_id"], + room_id=key_values.get("room_id"), sender=key_values["sender"], content=key_values.get("content", {}), unsigned=key_values.get("unsigned", {}), diff --git a/synapse/events/utils.py b/synapse/events/utils.py index cd7d3e6687..cae27136ce 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -176,9 +176,12 @@ def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDic if room_version.updated_redaction_rules: # MSC2176 rules state that create events cannot have their `content` redacted. new_content = event_dict["content"] - elif not room_version.implicit_room_creator: + if not room_version.implicit_room_creator: # Some room versions give meaning to `creator` add_fields("creator") + if room_version.msc4291_room_ids_as_hashes: + # room_id is not allowed on the create event as it's derived from the event ID + allowed_keys.remove("room_id") elif event_type == EventTypes.JoinRules: add_fields("join_rule") @@ -527,6 +530,10 @@ def serialize_event( if config.as_client_event: d = config.event_format(d) + # Ensure the room_id field is set for create events in MSC4291 rooms + if e.type == EventTypes.Create and e.room_version.msc4291_room_ids_as_hashes: + d["room_id"] = e.room_id + # If the event is a redaction, the field with the redacted event ID appears # in a different location depending on the room version. e.redacts handles # fetching from the proper location; copy it to the other location for forwards- @@ -872,6 +879,14 @@ def strip_event(event: EventBase) -> JsonDict: Stripped state events can only have the `sender`, `type`, `state_key` and `content` properties present. """ + # MSC4311: Ensure the create event is available on invites and knocks. 
+ # TODO: Implement the rest of MSC4311 + if ( + event.room_version.msc4291_room_ids_as_hashes + and event.type == EventTypes.Create + and event.get_state_key() == "" + ): + return event.get_pdu_json() return { "type": event.type, diff --git a/synapse/events/validator.py b/synapse/events/validator.py index 15095cc4ef..4d9ba15829 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -183,8 +183,18 @@ class EventValidator: fields an event would have """ + create_event_as_room_id = ( + event.room_version.msc4291_room_ids_as_hashes + and event.type == EventTypes.Create + and hasattr(event, "state_key") + and event.state_key == "" + ) + strings = ["room_id", "sender", "type"] + if create_event_as_room_id: + strings.remove("room_id") + if hasattr(event, "state_key"): strings.append("state_key") @@ -192,7 +202,14 @@ class EventValidator: if not isinstance(getattr(event, s), str): raise SynapseError(400, "Not '%s' a string type" % (s,)) - RoomID.from_string(event.room_id) + if not create_event_as_room_id: + assert event.room_id is not None + RoomID.from_string(event.room_id) + if event.room_version.msc4291_room_ids_as_hashes and not RoomID.is_valid( + event.room_id + ): + raise SynapseError(400, f"Invalid room ID '{event.room_id}'") + UserID.from_string(event.sender) if event.type == EventTypes.Message: diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index c6be60ac78..a1c9c286ac 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -343,6 +343,21 @@ def event_from_pdu_json(pdu_json: JsonDict, room_version: RoomVersion) -> EventB if room_version.strict_canonicaljson: validate_canonicaljson(pdu_json) + # enforce that MSC4291 auth events don't include the create event. + # N.B. if they DO include a spurious create event, it'll fail auth checks elsewhere, so we don't + # need to do expensive DB lookups to find which event ID is the create event here. + if room_version.msc4291_room_ids_as_hashes: + room_id = pdu_json.get("room_id") + if room_id: + create_event_id = "$" + room_id[1:] + auth_events = pdu_json.get("auth_events") + if auth_events: + if create_event_id in auth_events: + raise SynapseError( + 400, + "auth_events must not contain the create event", + Codes.BAD_JSON, + ) event = make_event_from_dict(pdu_json, room_version) return event diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py index c4dbf22408..1f1f67dc0d 100644 --- a/synapse/handlers/event_auth.py +++ b/synapse/handlers/event_auth.py @@ -23,6 +23,8 @@ from typing import TYPE_CHECKING, List, Mapping, Optional, Union from synapse import event_auth from synapse.api.constants import ( + CREATOR_POWER_LEVEL, + EventContentFields, EventTypes, JoinRules, Membership, @@ -141,6 +143,8 @@ class EventAuthHandler: Raises: SynapseError if no appropriate user is found. """ + create_event_id = current_state_ids[(EventTypes.Create, "")] + create_event = await self._store.get_event(create_event_id) power_level_event_id = current_state_ids.get((EventTypes.PowerLevels, "")) invite_level = 0 users_default_level = 0 @@ -156,15 +160,28 @@ class EventAuthHandler: # Find the user with the highest power level (only interested in local # users). 
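A hedged sketch (stub create-event fields rather than a real EventBase) of the MSC4289 creator rules this handler and event_auth.get_user_power_level rely on: the creators are the create event's sender plus content["additional_creators"], and every creator's effective power level is the out-of-band CREATOR_POWER_LEVEL rather than an entry in the power-levels "users" map.

from typing import Any, Dict

CREATOR_POWER_LEVEL = 2**53


def effective_power_level(
    user_id: str,
    create_sender: str,
    create_content: Dict[str, Any],
    pl_users: Dict[str, int],
    users_default: int = 0,
) -> int:
    creators = {create_sender, *create_content.get("additional_creators", [])}
    if user_id in creators:
        # creators never appear in the users map; their power is implicit
        return CREATOR_POWER_LEVEL
    return pl_users.get(user_id, users_default)


pl_users = {"@mod:example.org": 50}
create_sender = "@alice:example.org"
create_content = {"additional_creators": ["@bob:example.org"]}

assert effective_power_level("@alice:example.org", create_sender, create_content, pl_users) == CREATOR_POWER_LEVEL
assert effective_power_level("@bob:example.org", create_sender, create_content, pl_users) == CREATOR_POWER_LEVEL
assert effective_power_level("@mod:example.org", create_sender, create_content, pl_users) == 50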
+ user_power_level = 0 + chosen_user = None local_users_in_room = await self._store.get_local_users_in_room(room_id) - chosen_user = max( - local_users_in_room, - key=lambda user: users.get(user, users_default_level), - default=None, - ) + if create_event.room_version.msc4289_creator_power_enabled: + creators = set( + create_event.content.get(EventContentFields.ADDITIONAL_CREATORS, []) + ) + creators.add(create_event.sender) + local_creators = creators.intersection(set(local_users_in_room)) + if len(local_creators) > 0: + chosen_user = local_creators.pop() # random creator + user_power_level = CREATOR_POWER_LEVEL + else: + chosen_user = max( + local_users_in_room, + key=lambda user: users.get(user, users_default_level), + default=None, + ) + # Return the chosen if they can issue invites. + if chosen_user: + user_power_level = users.get(chosen_user, users_default_level) - # Return the chosen if they can issue invites. - user_power_level = users.get(chosen_user, users_default_level) if chosen_user and user_power_level >= invite_level: logger.debug( "Found a user who can issue invites %s with power level %d >= invite level %d", diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 22301f9e63..4a939b9646 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -688,7 +688,10 @@ class EventCreationHandler: Codes.USER_ACCOUNT_SUSPENDED, ) - if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "": + is_create_event = ( + event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "" + ) + if is_create_event: room_version_id = event_dict["content"]["room_version"] maybe_room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id) if not maybe_room_version_obj: @@ -794,6 +797,7 @@ class EventCreationHandler: """ # the only thing the user can do is join the server notices room. if builder.type == EventTypes.Member: + assert builder.room_id is not None membership = builder.content.get("membership", None) if membership == Membership.JOIN: return await self.store.is_server_notice_room(builder.room_id) @@ -1259,13 +1263,40 @@ class EventCreationHandler: for_verification=False, ) + if ( + builder.room_version.msc4291_room_ids_as_hashes + and builder.type == EventTypes.Create + and builder.is_state() + ): + if builder.room_id is not None: + raise SynapseError( + 400, + "Cannot resend m.room.create event", + Codes.INVALID_PARAM, + ) + else: + assert builder.room_id is not None + if prev_event_ids is not None: assert len(prev_event_ids) <= 10, ( "Attempting to create an event with %i prev_events" % (len(prev_event_ids),) ) else: - prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id) + if builder.room_id: + prev_event_ids = await self.store.get_prev_events_for_room( + builder.room_id + ) + else: + prev_event_ids = [] # can only happen for the create event in MSC4291 rooms + + if builder.type == EventTypes.Create and builder.is_state(): + if len(prev_event_ids) != 0: + raise SynapseError( + 400, + "Cannot resend m.room.create event", + Codes.INVALID_PARAM, + ) # We now ought to have some `prev_events` (unless it's a create event). 
# @@ -2228,6 +2259,7 @@ class EventCreationHandler: original_event.room_version, third_party_result ) self.validator.validate_builder(builder) + assert builder.room_id is not None except SynapseError as e: raise Exception( "Third party rules module created an invalid event: " + e.msg, diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 50ef25b09d..72afb35ed7 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -82,6 +82,7 @@ from synapse.types import ( Requester, RoomAlias, RoomID, + RoomIdWithDomain, RoomStreamToken, StateMap, StrCollection, @@ -195,7 +196,11 @@ class RoomCreationHandler: ) async def upgrade_room( - self, requester: Requester, old_room_id: str, new_version: RoomVersion + self, + requester: Requester, + old_room_id: str, + new_version: RoomVersion, + additional_creators: Optional[List[str]], ) -> str: """Replace a room with a new room with a different version @@ -203,6 +208,7 @@ class RoomCreationHandler: requester: the user requesting the upgrade old_room_id: the id of the room to be replaced new_version: the new room version to use + additional_creators: additional room creators, for MSC4289. Returns: the new room id @@ -235,8 +241,29 @@ class RoomCreationHandler: old_room = await self.store.get_room(old_room_id) if old_room is None: raise NotFoundError("Unknown room id %s" % (old_room_id,)) + old_room_is_public, _ = old_room - new_room_id = self._generate_room_id() + creation_event_with_context = None + if new_version.msc4291_room_ids_as_hashes: + old_room_create_event = await self.store.get_create_event_for_room( + old_room_id + ) + creation_content = self._calculate_upgraded_room_creation_content( + old_room_create_event, + tombstone_event_id=None, + new_room_version=new_version, + additional_creators=additional_creators, + ) + creation_event_with_context = await self._generate_create_event_for_room_id( + requester, + creation_content, + old_room_is_public, + new_version, + ) + (create_event, _) = creation_event_with_context + new_room_id = create_event.room_id + else: + new_room_id = self._generate_room_id() # Try several times, it could fail with PartialStateConflictError # in _upgrade_room, cf comment in except block. @@ -285,6 +312,8 @@ class RoomCreationHandler: new_version, tombstone_event, tombstone_context, + additional_creators, + creation_event_with_context, ) return ret @@ -308,6 +337,10 @@ class RoomCreationHandler: new_version: RoomVersion, tombstone_event: EventBase, tombstone_context: synapse.events.snapshot.EventContext, + additional_creators: Optional[List[str]], + creation_event_with_context: Optional[ + Tuple[EventBase, synapse.events.snapshot.EventContext] + ] = None, ) -> str: """ Args: @@ -319,6 +352,8 @@ class RoomCreationHandler: new_version: the version to upgrade the room to tombstone_event: the tombstone event to send to the old room tombstone_context: the context for the tombstone event + additional_creators: additional room creators, for MSC4289. + creation_event_with_context: The new room's create event, for room IDs as create event IDs. Raises: ShadowBanError if the requester is shadow-banned. @@ -328,14 +363,16 @@ class RoomCreationHandler: logger.info("Creating new room %s to replace %s", new_room_id, old_room_id) - # create the new room. may raise a `StoreError` in the exceedingly unlikely - # event of a room ID collision. 
- await self.store.store_room( - room_id=new_room_id, - room_creator_user_id=user_id, - is_public=old_room[0], - room_version=new_version, - ) + # We've already stored the room if we have the create event + if not creation_event_with_context: + # create the new room. may raise a `StoreError` in the exceedingly unlikely + # event of a room ID collision. + await self.store.store_room( + room_id=new_room_id, + room_creator_user_id=user_id, + is_public=old_room[0], + room_version=new_version, + ) await self.clone_existing_room( requester, @@ -343,6 +380,8 @@ class RoomCreationHandler: new_room_id=new_room_id, new_room_version=new_version, tombstone_event_id=tombstone_event.event_id, + additional_creators=additional_creators, + creation_event_with_context=creation_event_with_context, ) # now send the tombstone @@ -376,6 +415,7 @@ class RoomCreationHandler: old_room_id, new_room_id, old_room_state, + additional_creators, ) return new_room_id @@ -386,6 +426,7 @@ class RoomCreationHandler: old_room_id: str, new_room_id: str, old_room_state: StateMap[str], + additional_creators: Optional[List[str]], ) -> None: """Send updated power levels in both rooms after an upgrade @@ -394,7 +435,7 @@ class RoomCreationHandler: old_room_id: the id of the room to be replaced new_room_id: the id of the replacement room old_room_state: the state map for the old room - + additional_creators: Additional creators in the new room. Raises: ShadowBanError if the requester is shadow-banned. """ @@ -450,6 +491,14 @@ class RoomCreationHandler: except AuthError as e: logger.warning("Unable to update PLs in old room: %s", e) + new_room_version = await self.store.get_room_version(new_room_id) + if new_room_version.msc4289_creator_power_enabled: + self._remove_creators_from_pl_users_map( + old_room_pl_state.content.get("users", {}), + requester.user.to_string(), + additional_creators, + ) + await self.event_creation_handler.create_and_send_nonmember_event( requester, { @@ -464,6 +513,30 @@ class RoomCreationHandler: ratelimit=False, ) + def _calculate_upgraded_room_creation_content( + self, + old_room_create_event: EventBase, + tombstone_event_id: Optional[str], + new_room_version: RoomVersion, + ) -> JsonDict: + creation_content: JsonDict = { + "room_version": new_room_version.identifier, + "predecessor": { + "room_id": old_room_create_event.room_id, + }, + } + if tombstone_event_id is not None: + creation_content["predecessor"]["event_id"] = tombstone_event_id + # Check if old room was non-federatable + if not old_room_create_event.content.get(EventContentFields.FEDERATE, True): + # If so, mark the new room as non-federatable as well + creation_content[EventContentFields.FEDERATE] = False + # Copy the room type as per MSC3818. + room_type = old_room_create_event.content.get(EventContentFields.ROOM_TYPE) + if room_type is not None: + creation_content[EventContentFields.ROOM_TYPE] = room_type + return creation_content + async def clone_existing_room( self, requester: Requester, @@ -471,6 +544,10 @@ class RoomCreationHandler: new_room_id: str, new_room_version: RoomVersion, tombstone_event_id: str, + additional_creators: Optional[List[str]], + creation_event_with_context: Optional[ + Tuple[EventBase, synapse.events.snapshot.EventContext] + ] = None, ) -> None: """Populate a new room based on an old room @@ -481,24 +558,24 @@ class RoomCreationHandler: created with _generate_room_id()) new_room_version: the new room version to use tombstone_event_id: the ID of the tombstone event in the old room. 
+ additional_creators: additional room creators, for MSC4289. + creation_event_with_context: The create event of the new room, if the new room supports + room ID as create event ID hash. """ user_id = requester.user.to_string() - creation_content: JsonDict = { - "room_version": new_room_version.identifier, - "predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id}, - } - - # Check if old room was non-federatable - # Get old room's create event old_room_create_event = await self.store.get_create_event_for_room(old_room_id) - # Check if the create event specified a non-federatable room - if not old_room_create_event.content.get(EventContentFields.FEDERATE, True): - # If so, mark the new room as non-federatable as well - creation_content[EventContentFields.FEDERATE] = False - + if creation_event_with_context: + create_event, _ = creation_event_with_context + creation_content = create_event.content + else: + creation_content = self._calculate_upgraded_room_creation_content( + old_room_create_event, + tombstone_event_id, + new_room_version, + ) initial_state = {} # Replicate relevant room events @@ -514,11 +591,8 @@ class RoomCreationHandler: (EventTypes.PowerLevels, ""), ] - # Copy the room type as per MSC3818. room_type = old_room_create_event.content.get(EventContentFields.ROOM_TYPE) if room_type is not None: - creation_content[EventContentFields.ROOM_TYPE] = room_type - # If the old room was a space, copy over the rooms in the space. if room_type == RoomTypes.SPACE: types_to_copy.append((EventTypes.SpaceChild, None)) @@ -590,6 +664,14 @@ class RoomCreationHandler: if current_power_level_int < needed_power_level: user_power_levels[user_id] = needed_power_level + if new_room_version.msc4289_creator_power_enabled: + # the creator(s) cannot be in the users map + self._remove_creators_from_pl_users_map( + user_power_levels, + user_id, + additional_creators, + ) + # We construct what the body of a call to /createRoom would look like for passing # to the spam checker. We don't include a preset here, as we expect the # initial state to contain everything we need. @@ -618,6 +700,7 @@ class RoomCreationHandler: invite_list=[], initial_state=initial_state, creation_content=creation_content, + creation_event_with_context=creation_event_with_context, ) # Transfer membership events @@ -904,6 +987,7 @@ class RoomCreationHandler: power_level_content_override = config.get("power_level_content_override") if ( power_level_content_override + and not room_version.msc4289_creator_power_enabled # this validation doesn't apply in MSC4289 rooms and "users" in power_level_content_override and user_id not in power_level_content_override["users"] ): @@ -933,11 +1017,41 @@ class RoomCreationHandler: additional_fields=spam_check[1], ) - room_id = await self._generate_and_create_room_id( - creator_id=user_id, - is_public=is_public, - room_version=room_version, - ) + creation_content = config.get("creation_content", {}) + # override any attempt to set room versions via the creation_content + creation_content["room_version"] = room_version.identifier + + # trusted private chats have the invited users marked as additional creators + if ( + room_version.msc4289_creator_power_enabled + and config.get("preset", None) == RoomCreationPreset.TRUSTED_PRIVATE_CHAT + and len(config.get("invite", [])) > 0 + ): + # the other user(s) are additional creators + invitees = config.get("invite", []) + # we don't want to replace any additional_creators additionally specified, and we want + # to remove duplicates. 
+ creation_content[EventContentFields.ADDITIONAL_CREATORS] = list( + set(creation_content.get(EventContentFields.ADDITIONAL_CREATORS, [])) + | set(invitees) + ) + + creation_event_with_context = None + if room_version.msc4291_room_ids_as_hashes: + creation_event_with_context = await self._generate_create_event_for_room_id( + requester, + creation_content, + is_public, + room_version, + ) + (create_event, _) = creation_event_with_context + room_id = create_event.room_id + else: + room_id = await self._generate_and_create_room_id( + creator_id=user_id, + is_public=is_public, + room_version=room_version, + ) # Check whether this visibility value is blocked by a third party module allowed_by_third_party_rules = await ( @@ -974,11 +1088,6 @@ class RoomCreationHandler: for val in raw_initial_state: initial_state[(val["type"], val.get("state_key", ""))] = val["content"] - creation_content = config.get("creation_content", {}) - - # override any attempt to set room versions via the creation_content - creation_content["room_version"] = room_version.identifier - ( last_stream_id, last_sent_event_id, @@ -995,6 +1104,7 @@ class RoomCreationHandler: power_level_content_override=power_level_content_override, creator_join_profile=creator_join_profile, ignore_forced_encryption=ignore_forced_encryption, + creation_event_with_context=creation_event_with_context, ) # we avoid dropping the lock between invites, as otherwise joins can @@ -1060,6 +1170,38 @@ class RoomCreationHandler: return room_id, room_alias, last_stream_id + async def _generate_create_event_for_room_id( + self, + creator: Requester, + creation_content: JsonDict, + is_public: bool, + room_version: RoomVersion, + ) -> Tuple[EventBase, synapse.events.snapshot.EventContext]: + ( + creation_event, + new_unpersisted_context, + ) = await self.event_creation_handler.create_event( + creator, + { + "content": creation_content, + "sender": creator.user.to_string(), + "type": EventTypes.Create, + "state_key": "", + }, + prev_event_ids=[], + depth=1, + state_map={}, + for_batch=False, + ) + await self.store.store_room( + room_id=creation_event.room_id, + room_creator_user_id=creator.user.to_string(), + is_public=is_public, + room_version=room_version, + ) + creation_context = await new_unpersisted_context.persist(creation_event) + return (creation_event, creation_context) + async def _send_events_for_new_room( self, creator: Requester, @@ -1073,6 +1215,9 @@ class RoomCreationHandler: power_level_content_override: Optional[JsonDict] = None, creator_join_profile: Optional[JsonDict] = None, ignore_forced_encryption: bool = False, + creation_event_with_context: Optional[ + Tuple[EventBase, synapse.events.snapshot.EventContext] + ] = None, ) -> Tuple[int, str, int]: """Sends the initial events into a new room. Sends the room creation, membership, and power level events into the room sequentially, then creates and batches up the @@ -1109,7 +1254,10 @@ class RoomCreationHandler: user in this room. ignore_forced_encryption: Ignore encryption forced by `encryption_enabled_by_default_for_room_type` setting. - + creation_event_with_context: + Set in MSC4291 rooms where the create event determines the room ID. If provided, + does not create an additional create event but instead appends the remaining new + events onto the provided create event. Returns: A tuple containing the stream ID, event ID and depth of the last event sent to the room. 
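For reference, a hedged sketch of what the default m.room.power_levels content ends up looking like in an MSC4289 room, using only the values visible in this patch (the docs change above and the defaults in the hunk below): the creator is left out of "users" because their power is implicit, and the tombstone requirement rises to 150 so that only a creator (at 2**53) can upgrade the room.

DEFAULT_POWER_LEVELS_MSC4289 = {
    # creator(s) intentionally absent: their power level is implicit (2**53)
    "users": {},
    "users_default": 0,
    "events": {
        "m.room.name": 50,
        "m.room.power_levels": 100,
        "m.room.history_visibility": 100,
        "m.room.canonical_alias": 50,
        "m.room.avatar": 50,
        "m.room.tombstone": 150,  # 100 in room versions without MSC4289
        "m.room.server_acl": 100,
        "m.room.encryption": 100,
    },
    # remaining keys (ban, kick, invite, redact, ...) are unchanged by MSC4289
}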
@@ -1174,13 +1322,26 @@ class RoomCreationHandler: preset_config, config = self._room_preset_config(room_config) - # MSC2175 removes the creator field from the create event. - if not room_version.implicit_room_creator: - creation_content["creator"] = creator_id - creation_event, unpersisted_creation_context = await create_event( - EventTypes.Create, creation_content, False - ) - creation_context = await unpersisted_creation_context.persist(creation_event) + if creation_event_with_context is None: + # MSC2175 removes the creator field from the create event. + if not room_version.implicit_room_creator: + creation_content["creator"] = creator_id + creation_event, unpersisted_creation_context = await create_event( + EventTypes.Create, creation_content, False + ) + creation_context = await unpersisted_creation_context.persist( + creation_event + ) + else: + (creation_event, creation_context) = creation_event_with_context + # we had to do the above already in order to have a room ID, so just updates local vars + # and continue. + depth = 2 + prev_event = [creation_event.event_id] + state_map[(creation_event.type, creation_event.state_key)] = ( + creation_event.event_id + ) + logger.debug("Sending %s in new room", EventTypes.Member) ev = await self.event_creation_handler.handle_new_client_event( requester=creator, @@ -1229,7 +1390,9 @@ class RoomCreationHandler: # Please update the docs for `default_power_level_content_override` when # updating the `events` dict below power_level_content: JsonDict = { - "users": {creator_id: 100}, + "users": {creator_id: 100} + if not room_version.msc4289_creator_power_enabled + else {}, "users_default": 0, "events": { EventTypes.Name: 50, @@ -1237,7 +1400,9 @@ class RoomCreationHandler: EventTypes.RoomHistoryVisibility: 100, EventTypes.CanonicalAlias: 50, EventTypes.RoomAvatar: 50, - EventTypes.Tombstone: 100, + EventTypes.Tombstone: 150 + if room_version.msc4289_creator_power_enabled + else 100, EventTypes.ServerACL: 100, EventTypes.RoomEncryption: 100, }, @@ -1250,7 +1415,13 @@ class RoomCreationHandler: "historical": 100, } - if config["original_invitees_have_ops"]: + # original_invitees_have_ops is set on preset:trusted_private_chat which will already + # have set these users as additional_creators, hence don't set the PL for creators as + # that is invalid. + if ( + config["original_invitees_have_ops"] + and not room_version.msc4289_creator_power_enabled + ): for invitee in invite_list: power_level_content["users"][invitee] = 100 @@ -1423,6 +1594,19 @@ class RoomCreationHandler: ) return preset_name, preset_config + def _remove_creators_from_pl_users_map( + self, + users_map: Dict[str, int], + creator: str, + additional_creators: Optional[List[str]], + ) -> None: + creators = [creator] + if additional_creators: + creators.extend(additional_creators) + for creator in creators: + # the creator(s) cannot be in the users map + users_map.pop(creator, None) + def _generate_room_id(self) -> str: """Generates a random room ID. @@ -1440,7 +1624,7 @@ class RoomCreationHandler: A random room ID of the form "!opaque_id:domain". 
""" random_string = stringutils.random_string(18) - return RoomID(random_string, self.hs.hostname).to_string() + return RoomIdWithDomain(random_string, self.hs.hostname).to_string() async def _generate_and_create_room_id( self, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 3ab2d29c75..5ba64912c9 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -42,7 +42,7 @@ from synapse.api.errors import ( ) from synapse.api.ratelimiting import Ratelimiter from synapse.event_auth import get_named_level, get_power_level_event -from synapse.events import EventBase +from synapse.events import EventBase, is_creator from synapse.events.snapshot import EventContext from synapse.handlers.pagination import PURGE_ROOM_ACTION_NAME from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN @@ -1160,9 +1160,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): elif effective_membership_state == Membership.KNOCK: if not is_host_in_room: - # The knock needs to be sent over federation instead - remote_room_hosts.append(get_domain_from_id(room_id)) - + # we used to add the domain of the room ID to remote_room_hosts. + # This is not safe in MSC4291 rooms which do not have a domain. content["membership"] = Membership.KNOCK try: @@ -2324,6 +2323,7 @@ def get_users_which_can_issue_invite(auth_events: StateMap[EventBase]) -> List[s # Check which members are able to invite by ensuring they're joined and have # the necessary power level. + create_event = auth_events[(EventTypes.Create, "")] for (event_type, state_key), event in auth_events.items(): if event_type != EventTypes.Member: continue @@ -2331,8 +2331,12 @@ def get_users_which_can_issue_invite(auth_events: StateMap[EventBase]) -> List[s if event.membership != Membership.JOIN: continue + if create_event.room_version.msc4289_creator_power_enabled and is_creator( + create_event, state_key + ): + result.append(state_key) # Check if the user has a custom power level. - if users.get(state_key, users_default_level) >= invite_level: + elif users.get(state_key, users_default_level) >= invite_level: result.append(state_key) return result diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index f8c5bf18d4..efcc60a2de 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -627,6 +627,15 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): ] admin_users.sort(key=lambda user: user_power[user]) + if create_event.room_version.msc4289_creator_power_enabled: + creators = create_event.content.get("additional_creators", []) + [ + create_event.sender + ] + for creator in creators: + if self.is_mine_id(creator): + # include the creator as they won't be in the PL users map. + admin_users.insert(0, creator) + if not admin_users: raise SynapseError( HTTPStatus.BAD_REQUEST, "No local admin user in room" @@ -666,7 +675,11 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): # updated power level event. new_pl_content = dict(pl_content) new_pl_content["users"] = dict(pl_content.get("users", {})) - new_pl_content["users"][user_to_add] = new_pl_content["users"][admin_user_id] + # give the new user the same PL as the admin, default to 100 in case there is no PL event. + # This means in v12+ rooms we get PL100 if the creator promotes us. 
+ new_pl_content["users"][user_to_add] = new_pl_content["users"].get( + admin_user_id, 100 + ) fake_requester = create_requester( admin_user_id, diff --git a/synapse/rest/client/room_upgrade_rest_servlet.py b/synapse/rest/client/room_upgrade_rest_servlet.py index 130ae31619..a9717781b0 100644 --- a/synapse/rest/client/room_upgrade_rest_servlet.py +++ b/synapse/rest/client/room_upgrade_rest_servlet.py @@ -24,6 +24,7 @@ from typing import TYPE_CHECKING, Tuple from synapse.api.errors import Codes, ShadowBanError, SynapseError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.event_auth import check_valid_additional_creators from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME from synapse.http.server import HttpServer from synapse.http.servlet import ( @@ -85,13 +86,18 @@ class RoomUpgradeRestServlet(RestServlet): "Your homeserver does not support this room version", Codes.UNSUPPORTED_ROOM_VERSION, ) + additional_creators = None + if new_version.msc4289_creator_power_enabled: + additional_creators = content.get("additional_creators") + if additional_creators is not None: + check_valid_additional_creators(additional_creators) try: async with self._worker_lock_handler.acquire_read_write_lock( NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False ): new_room_id = await self._room_creation_handler.upgrade_room( - requester, room_id, new_version + requester, room_id, new_version, additional_creators ) except ShadowBanError: # Generate a random room ID. diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index c9f952b817..3d8016c264 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -54,6 +54,7 @@ from synapse.logging.opentracing import tag_args, trace from synapse.metrics import SERVER_NAME_LABEL from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet from synapse.state import v1, v2 +from synapse.storage.databases.main.event_federation import StateDifference from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.types import StateMap, StrCollection from synapse.types.state import StateFilter @@ -990,17 +991,35 @@ class StateResolutionStore: ) def get_auth_chain_difference( - self, room_id: str, state_sets: List[Set[str]] - ) -> Awaitable[Set[str]]: - """Given sets of state events figure out the auth chain difference (as + self, + room_id: str, + state_sets: List[Set[str]], + conflicted_state: Optional[Set[str]], + additional_backwards_reachable_conflicted_events: Optional[Set[str]], + ) -> Awaitable[StateDifference]: + """ "Given sets of state events figure out the auth chain difference (as per state res v2 algorithm). - This equivalent to fetching the full auth chain for each set of state + This is equivalent to fetching the full auth chain for each set of state and returning the events that don't appear in each and every auth chain. + If conflicted_state is not None, calculate and return the conflicted sub-graph as per + state res v2.1. The event IDs in the conflicted state MUST be a subset of the event IDs in + state_sets. + + If additional_backwards_reachable_conflicted_events is set, the provided events are included + when calculating the conflicted subgraph. This is primarily useful for calculating the + subgraph across a combination of persisted and unpersisted events. + Returns: - An awaitable that resolves to a set of event IDs. 
+ information on the auth chain difference, and also the conflicted subgraph if + conflicted_state is not None """ - return self.main_store.get_auth_chain_difference(room_id, state_sets) + return self.main_store.get_auth_chain_difference_extended( + room_id, + state_sets, + conflicted_state, + additional_backwards_reachable_conflicted_events, + ) diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 44b191d4e4..8bf6706434 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -39,10 +39,11 @@ from typing import ( ) from synapse import event_auth -from synapse.api.constants import EventTypes +from synapse.api.constants import CREATOR_POWER_LEVEL, EventTypes from synapse.api.errors import AuthError -from synapse.api.room_versions import RoomVersion -from synapse.events import EventBase +from synapse.api.room_versions import RoomVersion, StateResolutionVersions +from synapse.events import EventBase, is_creator +from synapse.storage.databases.main.event_federation import StateDifference from synapse.types import MutableStateMap, StateMap, StrCollection logger = logging.getLogger(__name__) @@ -63,8 +64,12 @@ class StateResolutionStore(Protocol): ) -> Awaitable[Dict[str, EventBase]]: ... def get_auth_chain_difference( - self, room_id: str, state_sets: List[Set[str]] - ) -> Awaitable[Set[str]]: ... + self, + room_id: str, + state_sets: List[Set[str]], + conflicted_state: Optional[Set[str]], + additional_backwards_reachable_conflicted_events: Optional[set[str]], + ) -> Awaitable[StateDifference]: ... # We want to await to the reactor occasionally during state res when dealing @@ -123,12 +128,17 @@ async def resolve_events_with_store( logger.debug("%d conflicted state entries", len(conflicted_state)) logger.debug("Calculating auth chain difference") - # Also fetch all auth events that appear in only some of the state sets' - # auth chains. + conflicted_set: Optional[Set[str]] = None + if room_version.state_res == StateResolutionVersions.V2_1: + # calculate the conflicted subgraph + conflicted_set = set(itertools.chain.from_iterable(conflicted_state.values())) auth_diff = await _get_auth_chain_difference( - room_id, state_sets, event_map, state_res_store + room_id, + state_sets, + event_map, + state_res_store, + conflicted_set, ) - full_conflicted_set = set( itertools.chain( itertools.chain.from_iterable(conflicted_state.values()), auth_diff @@ -168,15 +178,26 @@ async def resolve_events_with_store( logger.debug("sorted %d power events", len(sorted_power_events)) + # v2.1 starts iterative auth checks from the empty set and not the unconflicted state. + # It relies on IAC behaviour which populates the base state with the events from auth_events + # if the state tuple is missing from the base state. This ensures the base state is only + # populated from auth_events rather than whatever the unconflicted state is (which could be + # completely bogus). 
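A hedged sketch of the iterative-auth-checks behaviour the comment above relies on (assumed helper shape, not Synapse's actual _iterative_auth_checks): when state res v2.1 starts from an empty base state, any state tuple missing from it is filled from the event's own auth_events, so the base state is only ever populated from auth events rather than from the unconflicted state.

from typing import Dict, Tuple

StateKey = Tuple[str, str]  # (event type, state key)


def auth_state_for_event(
    base_state: Dict[StateKey, str],
    event_auth_events: Dict[StateKey, str],
    needed_types: Tuple[StateKey, ...],
) -> Dict[StateKey, str]:
    auth_state: Dict[StateKey, str] = {}
    for key in needed_types:
        if key in base_state:
            auth_state[key] = base_state[key]
        elif key in event_auth_events:
            # missing from the base state: fall back to the event's own auth_events
            auth_state[key] = event_auth_events[key]
    return auth_state


# With an empty base state, everything comes from the event's auth_events.
assert auth_state_for_event(
    {}, {("m.room.create", ""): "$create"}, (("m.room.create", ""),)
) == {("m.room.create", ""): "$create"}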
+ base_state = ( + {} + if room_version.state_res == StateResolutionVersions.V2_1 + else unconflicted_state + ) + # Now sequentially auth each one resolved_state = await _iterative_auth_checks( clock, room_id, room_version, - sorted_power_events, - unconflicted_state, - event_map, - state_res_store, + event_ids=sorted_power_events, + base_state=base_state, + event_map=event_map, + state_res_store=state_res_store, ) logger.debug("resolved power events") @@ -239,13 +260,23 @@ async def _get_power_level_for_sender( event = await _get_event(room_id, event_id, event_map, state_res_store) pl = None + create = None for aid in event.auth_event_ids(): aev = await _get_event( room_id, aid, event_map, state_res_store, allow_none=True ) if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): pl = aev - break + if aev and (aev.type, aev.state_key) == (EventTypes.Create, ""): + create = aev + + if event.type != EventTypes.Create: + # we should always have a create event + assert create is not None + + if create and create.room_version.msc4289_creator_power_enabled: + if is_creator(create, event.sender): + return CREATOR_POWER_LEVEL if pl is None: # Couldn't find power level. Check if they're the creator of the room @@ -286,6 +317,7 @@ async def _get_auth_chain_difference( state_sets: Sequence[StateMap[str]], unpersisted_events: Dict[str, EventBase], state_res_store: StateResolutionStore, + conflicted_state: Optional[Set[str]], ) -> Set[str]: """Compare the auth chains of each state set and return the set of events that only appear in some, but not all of the auth chains. @@ -294,11 +326,18 @@ async def _get_auth_chain_difference( state_sets: The input state sets we are trying to resolve across. unpersisted_events: A map from event ID to EventBase containing all unpersisted events involved in this resolution. - state_res_store: + state_res_store: A way to retrieve events and extract graph information on the auth chains. + conflicted_state: which event IDs are conflicted. Used in v2.1 for calculating the conflicted + subgraph. Returns: - The auth difference of the given state sets, as a set of event IDs. + The auth difference of the given state sets, as a set of event IDs. Also includes the + conflicted subgraph if `conflicted_state` is set. """ + is_state_res_v21 = conflicted_state is not None + num_conflicted_state = ( + len(conflicted_state) if conflicted_state is not None else None + ) # The `StateResolutionStore.get_auth_chain_difference` function assumes that # all events passed to it (and their auth chains) have been persisted @@ -318,14 +357,19 @@ async def _get_auth_chain_difference( # the event's auth chain with the events in `unpersisted_events` *plus* their # auth event IDs. events_to_auth_chain: Dict[str, Set[str]] = {} + # remember the forward links when doing the graph traversal, we'll need it for v2.1 checks + # This is a map from an event to the set of events that contain it as an auth event. 
+ event_to_next_event: Dict[str, Set[str]] = {} for event in unpersisted_events.values(): chain = {event.event_id} events_to_auth_chain[event.event_id] = chain to_search = [event] while to_search: - for auth_id in to_search.pop().auth_event_ids(): + next_event = to_search.pop() + for auth_id in next_event.auth_event_ids(): chain.add(auth_id) + event_to_next_event.setdefault(auth_id, set()).add(next_event.event_id) auth_event = unpersisted_events.get(auth_id) if auth_event: to_search.append(auth_event) @@ -335,6 +379,8 @@ async def _get_auth_chain_difference( # # Note: If there are no `unpersisted_events` (which is the common case), we can do a # much simpler calculation. + additional_backwards_reachable_conflicted_events: Set[str] = set() + unpersisted_conflicted_events: Set[str] = set() if unpersisted_events: # The list of state sets to pass to the store, where each state set is a set # of the event ids making up the state. This is similar to `state_sets`, @@ -372,7 +418,16 @@ async def _get_auth_chain_difference( ) else: set_ids.add(event_id) - + if conflicted_state: + for conflicted_event_id in conflicted_state: + # presence in this map means it is unpersisted. + event_chain = events_to_auth_chain.get(conflicted_event_id) + if event_chain is not None: + unpersisted_conflicted_events.add(conflicted_event_id) + # tell the DB layer that we have some unpersisted conflicted events + additional_backwards_reachable_conflicted_events.update( + e for e in event_chain if e not in unpersisted_events + ) # The auth chain difference of the unpersisted events of the state sets # is calculated by taking the difference between the union and # intersections. @@ -384,12 +439,89 @@ async def _get_auth_chain_difference( auth_difference_unpersisted_part = () state_sets_ids = [set(state_set.values()) for state_set in state_sets] - difference = await state_res_store.get_auth_chain_difference( - room_id, state_sets_ids - ) - difference.update(auth_difference_unpersisted_part) + if conflicted_state: + # to ensure that conflicted state is a subset of state set IDs, we need to remove UNPERSISTED + # conflicted state set ids as we removed them above. + conflicted_state = conflicted_state - unpersisted_conflicted_events - return difference + difference = await state_res_store.get_auth_chain_difference( + room_id, + state_sets_ids, + conflicted_state, + additional_backwards_reachable_conflicted_events, + ) + difference.auth_difference.update(auth_difference_unpersisted_part) + + # if we're doing v2.1 we may need to add or expand the conflicted subgraph + if ( + is_state_res_v21 + and difference.conflicted_subgraph is not None + and unpersisted_events + ): + # we always include the conflicted events themselves in the subgraph. + if conflicted_state: + difference.conflicted_subgraph.update(conflicted_state) + # we may need to expand the subgraph in the case where the subgraph starts in the DB and + # ends in unpersisted events. To do this, we first need to see where the subgraph got up to, + # which we can do by finding the intersection between the additional backwards reachable + # conflicted events and the conflicted subgraph. Events in both sets mean A) some unpersisted + # conflicted event could backwards reach it and B) some persisted conflicted event could forward + # reach it. 
+ subgraph_frontier = difference.conflicted_subgraph.intersection( + additional_backwards_reachable_conflicted_events + ) + # we can now combine the 2 scenarios: + # - subgraph starts in DB and ends in unpersisted + # - subgraph starts in unpersisted and ends in unpersisted + # by expanding the frontier into unpersisted events. + # The frontier is currently all persisted events. We want to expand this into unpersisted + # events. Mark every forwards reachable event from the frontier in the forwards_conflicted_set + # but NOT the backwards conflicted set. This mirrors what the DB layer does but in reverse: + # we supplied events which are backwards reachable to the DB and now the DB is providing + # forwards reachable events from the DB. + forwards_conflicted_set: Set[str] = set() + # we include unpersisted conflicted events here to process exclusive unpersisted subgraphs + search_queue = subgraph_frontier.union(unpersisted_conflicted_events) + while search_queue: + frontier_event = search_queue.pop() + next_event_ids = event_to_next_event.get(frontier_event, set()) + search_queue.update(next_event_ids) + forwards_conflicted_set.add(frontier_event) + + # we've already calculated the backwards form as this is the auth chain for each + # unpersisted conflicted event. + backwards_conflicted_set: Set[str] = set() + for uce in unpersisted_conflicted_events: + backwards_conflicted_set.update(events_to_auth_chain.get(uce, [])) + + # the unpersisted conflicted subgraph is the intersection of the backwards/forwards sets + conflicted_subgraph_unpersisted_part = backwards_conflicted_set.intersection( + forwards_conflicted_set + ) + # print(f"event_to_next_event={event_to_next_event}") + # print(f"unpersisted_conflicted_events={unpersisted_conflicted_events}") + # print(f"unperssited backwards_conflicted_set={backwards_conflicted_set}") + # print(f"unperssited forwards_conflicted_set={forwards_conflicted_set}") + difference.conflicted_subgraph.update(conflicted_subgraph_unpersisted_part) + + if difference.conflicted_subgraph: + old_events = difference.auth_difference.union( + conflicted_state if conflicted_state else set() + ) + additional_events = difference.conflicted_subgraph.difference(old_events) + + logger.debug( + "v2.1 %s additional events replayed=%d num_conflicts=%d conflicted_subgraph=%d auth_difference=%d", + room_id, + len(additional_events), + num_conflicted_state, + len(difference.conflicted_subgraph), + len(difference.auth_difference), + ) + # State res v2.1 includes the conflicted subgraph in the difference + return difference.auth_difference.union(difference.conflicted_subgraph) + + return difference.auth_difference def _seperate( diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index bee34ef6a3..26a91109df 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -114,6 +114,12 @@ _LONGEST_BACKOFF_PERIOD_MILLISECONDS = ( assert 0 < _LONGEST_BACKOFF_PERIOD_MILLISECONDS <= ((2**31) - 1) +# We use 2^53-1 as a "very large number", it has no particular +# importance other than knowing synapse can support it (given canonical json +# requires it). 
+MAX_CHAIN_LENGTH = (2**53) - 1 + + # All the info we need while iterating the DAG while backfilling @attr.s(frozen=True, slots=True, auto_attribs=True) class BackfillQueueNavigationItem: @@ -123,6 +129,14 @@ class BackfillQueueNavigationItem: type: str +@attr.s(frozen=True, slots=True, auto_attribs=True) +class StateDifference: + # The event IDs in the auth difference. + auth_difference: Set[str] + # The event IDs in the conflicted state subgraph. Used in v2.1 only. + conflicted_subgraph: Optional[Set[str]] + + class _NoChainCoverIndex(Exception): def __init__(self, room_id: str): super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,)) @@ -471,17 +485,41 @@ class EventFederationWorkerStore( return results async def get_auth_chain_difference( - self, room_id: str, state_sets: List[Set[str]] + self, + room_id: str, + state_sets: List[Set[str]], ) -> Set[str]: - """Given sets of state events figure out the auth chain difference (as + state_diff = await self.get_auth_chain_difference_extended( + room_id, state_sets, None, None + ) + return state_diff.auth_difference + + async def get_auth_chain_difference_extended( + self, + room_id: str, + state_sets: List[Set[str]], + conflicted_set: Optional[Set[str]], + additional_backwards_reachable_conflicted_events: Optional[Set[str]], + ) -> StateDifference: + """ "Given sets of state events figure out the auth chain difference (as per state res v2 algorithm). - This equivalent to fetching the full auth chain for each set of state + This is equivalent to fetching the full auth chain for each set of state and returning the events that don't appear in each and every auth chain. + If conflicted_set is not None, calculate and return the conflicted sub-graph as per + state res v2.1. The event IDs in the conflicted set MUST be a subset of the event IDs in + state_sets. + + If additional_backwards_reachable_conflicted_events is set, the provided events are included + when calculating the conflicted subgraph. This is primarily useful for calculating the + subgraph across a combination of persisted and unpersisted events. The event IDs in this set + MUST be a subset of the event IDs in state_sets. + Returns: - The set of the difference in auth chains. + information on the auth chain difference, and also the conflicted subgraph if + conflicted_set is not None """ # Check if we have indexed the room so we can use the chain cover @@ -495,6 +533,8 @@ class EventFederationWorkerStore( self._get_auth_chain_difference_using_cover_index_txn, room_id, state_sets, + conflicted_set, + additional_backwards_reachable_conflicted_events, ) except _NoChainCoverIndex: # For whatever reason we don't actually have a chain cover index @@ -503,25 +543,48 @@ class EventFederationWorkerStore( if not self.tests_allow_no_chain_cover_index: raise - return await self.db_pool.runInteraction( + # It's been 4 years since we added chain cover, so we expect all rooms to have it. 
+ # If they don't, we will error out when trying to do state res v2.1 + if conflicted_set is not None: + raise _NoChainCoverIndex(room_id) + + auth_diff = await self.db_pool.runInteraction( "get_auth_chain_difference", self._get_auth_chain_difference_txn, state_sets, ) + return StateDifference(auth_difference=auth_diff, conflicted_subgraph=None) def _get_auth_chain_difference_using_cover_index_txn( - self, txn: LoggingTransaction, room_id: str, state_sets: List[Set[str]] - ) -> Set[str]: + self, + txn: LoggingTransaction, + room_id: str, + state_sets: List[Set[str]], + conflicted_set: Optional[Set[str]] = None, + additional_backwards_reachable_conflicted_events: Optional[Set[str]] = None, + ) -> StateDifference: """Calculates the auth chain difference using the chain index. See docs/auth_chain_difference_algorithm.md for details """ + is_state_res_v21 = conflicted_set is not None # First we look up the chain ID/sequence numbers for all the events, and # work out the chain/sequence numbers reachable from each state set. initial_events = set(state_sets[0]).union(*state_sets[1:]) + if is_state_res_v21: + # Sanity check v2.1 fields + assert conflicted_set is not None + assert conflicted_set.issubset(initial_events) + # It's possible for the conflicted_set to be empty if all the conflicts are in + # unpersisted events, so we don't assert that conflicted_set has len > 0 + if additional_backwards_reachable_conflicted_events: + assert additional_backwards_reachable_conflicted_events.issubset( + initial_events + ) + # Map from event_id -> (chain ID, seq no) chain_info: Dict[str, Tuple[int, int]] = {} @@ -557,14 +620,14 @@ class EventFederationWorkerStore( events_missing_chain_info = initial_events.difference(chain_info) # The result set to return, i.e. the auth chain difference. - result: Set[str] = set() + auth_difference_result: Set[str] = set() if events_missing_chain_info: # For some reason we have events we haven't calculated the chain # index for, so we need to handle those separately. This should only # happen for older rooms where the server doesn't have all the auth # events. - result = self._fixup_auth_chain_difference_sets( + auth_difference_result = self._fixup_auth_chain_difference_sets( txn, room_id, state_sets=state_sets, @@ -583,6 +646,45 @@ class EventFederationWorkerStore( fetch_chain_info(new_events_to_fetch) + # State Res v2.1 needs extra data structures to calculate the conflicted subgraph which + # are outlined below. + + # A subset of chain_info for conflicted events only, as we need to + # loop all conflicted chain positions. Map from event_id -> (chain ID, seq no) + conflicted_chain_positions: Dict[str, Tuple[int, int]] = {} + # For each chain, remember the positions where conflicted events are. + # We need this for calculating the forward reachable events. + conflicted_chain_to_seq: Dict[int, Set[int]] = {} # chain_id => {seq_num} + # A subset of chain_info for additional backwards reachable events only, as we need to + # loop all additional backwards reachable events for calculating backwards reachable events. + additional_backwards_reachable_positions: Dict[ + str, Tuple[int, int] + ] = {} # event_id => (chain_id, seq_num) + # These next two fields are critical as the intersection of them is the conflicted subgraph. + # We'll populate them when we walk the chain links. 
+ # chain_id => max(seq_num) backwards reachable (e.g 4 means 1,2,3,4 are backwards reachable) + conflicted_backwards_reachable: Dict[int, int] = {} + # chain_id => min(seq_num) forwards reachable (e.g 4 means 4,5,6..n are forwards reachable) + conflicted_forwards_reachable: Dict[int, int] = {} + + # populate the v2.1 data structures + if is_state_res_v21: + assert conflicted_set is not None + # provide chain positions for each conflicted event + for conflicted_event_id in conflicted_set: + (chain_id, seq_num) = chain_info[conflicted_event_id] + conflicted_chain_positions[conflicted_event_id] = (chain_id, seq_num) + conflicted_chain_to_seq.setdefault(chain_id, set()).add(seq_num) + if additional_backwards_reachable_conflicted_events: + for ( + additional_event_id + ) in additional_backwards_reachable_conflicted_events: + (chain_id, seq_num) = chain_info[additional_event_id] + additional_backwards_reachable_positions[additional_event_id] = ( + chain_id, + seq_num, + ) + # Corresponds to `state_sets`, except as a map from chain ID to max # sequence number reachable from the state set. set_to_chain: List[Dict[int, int]] = [] @@ -600,6 +702,8 @@ class EventFederationWorkerStore( # (We need to take a copy of `seen_chains` as the function mutates it) for links in self._get_chain_links(txn, set(seen_chains)): + # `links` encodes the backwards reachable events _from a single chain_ all the way to + # the root of the graph. for chains in set_to_chain: for chain_id in links: if chain_id not in chains: @@ -608,6 +712,87 @@ class EventFederationWorkerStore( _materialize(chain_id, chains[chain_id], links, chains) seen_chains.update(chains) + if is_state_res_v21: + # Apply v2.1 conflicted event reachability checks. + # + # A <-- B <-- C <-- D <-- E + # + # Backwards reachable from C = {A,B} + # Forwards reachable from C = {D,E} + + # this handles calculating forwards reachable information and updates + # conflicted_forwards_reachable. + accumulate_forwards_reachable_events( + conflicted_forwards_reachable, + links, + conflicted_chain_positions, + ) + + # handle backwards reachable information + for ( + conflicted_chain_id, + conflicted_chain_seq, + ) in conflicted_chain_positions.values(): + if conflicted_chain_id not in links: + # This conflicted event does not lie on the path to the root. + continue + + # The conflicted chain position itself encodes reachability information + # _within_ the chain. Set it now before walking to other links. + conflicted_backwards_reachable[conflicted_chain_id] = max( + conflicted_chain_seq, + conflicted_backwards_reachable.get(conflicted_chain_id, 0), + ) + + # Build backwards reachability paths. This is the same as what the auth difference + # code does. We find which chain the conflicted event + # belongs to then walk it backwards to the root. We store reachability info + # for all conflicted events in the same map 'conflicted_backwards_reachable' + # as we don't care about the paths themselves. + _materialize( + conflicted_chain_id, + conflicted_chain_seq, + links, + conflicted_backwards_reachable, + ) + # Mark some extra events as backwards reachable. This is used when we have some + # unpersisted events and want to know the subgraph across the persisted/unpersisted + # boundary: + # | + # A <-- B <-- C <-|- D <-- E <-- F + # persisted | unpersisted + # + # Assume {B,E} are conflicted, we want to return {B,C,D,E} + # + # The unpersisted code ensures it passes C as an additional backwards reachable + # event. 
C is NOT a conflicted event, but we do need to consider it as part of + # the backwards reachable set. When we then calculate the forwards reachable set + # from B, C will be in both the backwards and forwards reachable sets and hence + # will be included in the conflicted subgraph. + for ( + additional_chain_id, + additional_chain_seq, + ) in additional_backwards_reachable_positions.values(): + if additional_chain_id not in links: + # The additional backwards reachable event does not lie on the path to the root. + continue + + # the additional event chain position itself encodes reachability information. + # It means that position and all positions earlier in that chain are backwards reachable + # by some unpersisted conflicted event. + conflicted_backwards_reachable[additional_chain_id] = max( + additional_chain_seq, + conflicted_backwards_reachable.get(additional_chain_id, 0), + ) + + # Now walk the chains back, marking backwards reachable events. + # This is the same thing we do for auth difference / conflicted events. + _materialize( + additional_chain_id, # walk all links back, marking them as backwards reachable + additional_chain_seq, + links, + conflicted_backwards_reachable, + ) # Now for each chain we figure out the maximum sequence number reachable # from *any* state set and the minimum sequence number reachable from @@ -616,7 +801,7 @@ class EventFederationWorkerStore( # Mapping from chain ID to the range of sequence numbers that should be # pulled from the database. - chain_to_gap: Dict[int, Tuple[int, int]] = {} + auth_diff_chain_to_gap: Dict[int, Tuple[int, int]] = {} for chain_id in seen_chains: min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain) @@ -629,15 +814,76 @@ class EventFederationWorkerStore( for seq_no in range(min_seq_no + 1, max_seq_no + 1): event_id = chain_to_event.get(chain_id, {}).get(seq_no) if event_id: - result.add(event_id) + auth_difference_result.add(event_id) else: - chain_to_gap[chain_id] = (min_seq_no, max_seq_no) + auth_diff_chain_to_gap[chain_id] = (min_seq_no, max_seq_no) break - if not chain_to_gap: - # If there are no gaps to fetch, we're done! - return result + conflicted_subgraph_result: Set[str] = set() + # Mapping from chain ID to the range of sequence numbers that should be + # pulled from the database. + conflicted_subgraph_chain_to_gap: Dict[int, Tuple[int, int]] = {} + if is_state_res_v21: + # also include the conflicted subgraph using backward/forward reachability info from all + # the conflicted events. To calculate this, we want to extract the intersection between + # the backwards and forwards reachability sets, e.g: + # A <- B <- C <- D <- E + # Assume B and D are conflicted so we want {C} as the conflicted subgraph. + # B_backwards={A}, B_forwards={C,D,E} + # D_backwards={A,B,C} D_forwards={E} + # ALL_backwards={A,B,C} ALL_forwards={C,D,E} + # Intersection(ALL_backwards, ALL_forwards) = {C} + # + # It's worth noting that once we have the ALL_ sets, we no longer care about the paths. + # We're dealing with chains and not singular events, but we've already got the ALL_ sets. + # As such, we can inspect each chain in isolation and check for overlapping sequence + # numbers: + # 1,2,3,4,5 Seq Num + # Chain N [A,B,C,D,E] + # + # if (N,4) is in the backwards set and (N,2) is in the forwards set, then the + # intersection is events between 2 < 4. We will include the conflicted events themselves + # in the subgraph, but they will already be, hence the full set of events is {B,C,D}. 
+ for chain_id, backwards_seq_num in conflicted_backwards_reachable.items(): + forwards_seq_num = conflicted_forwards_reachable.get(chain_id) + if forwards_seq_num is None: + continue # this chain isn't in both sets so can't intersect + if forwards_seq_num > backwards_seq_num: + continue # this chain is in both sets but they don't overap + for seq_no in range( + forwards_seq_num, backwards_seq_num + 1 + ): # inclusive of both + event_id = chain_to_event.get(chain_id, {}).get(seq_no) + if event_id: + conflicted_subgraph_result.add(event_id) + else: + conflicted_subgraph_chain_to_gap[chain_id] = ( + # _fetch_event_ids_from_chains_txn is exclusive of the min value + forwards_seq_num - 1, + backwards_seq_num, + ) + break + if auth_diff_chain_to_gap: + auth_difference_result.update( + self._fetch_event_ids_from_chains_txn(txn, auth_diff_chain_to_gap) + ) + if conflicted_subgraph_chain_to_gap: + conflicted_subgraph_result.update( + self._fetch_event_ids_from_chains_txn( + txn, conflicted_subgraph_chain_to_gap + ) + ) + + return StateDifference( + auth_difference=auth_difference_result, + conflicted_subgraph=conflicted_subgraph_result, + ) + + def _fetch_event_ids_from_chains_txn( + self, txn: LoggingTransaction, chains: Dict[int, Tuple[int, int]] + ) -> Set[str]: + result: Set[str] = set() if isinstance(self.database_engine, PostgresEngine): # We can use `execute_values` to efficiently fetch the gaps when # using postgres. @@ -651,7 +897,7 @@ class EventFederationWorkerStore( args = [ (chain_id, min_no, max_no) - for chain_id, (min_no, max_no) in chain_to_gap.items() + for chain_id, (min_no, max_no) in chains.items() ] rows = txn.execute_values(sql, args) @@ -662,10 +908,9 @@ class EventFederationWorkerStore( SELECT event_id FROM event_auth_chains WHERE chain_id = ? AND ? < sequence_number AND sequence_number <= ? """ - for chain_id, (min_no, max_no) in chain_to_gap.items(): + for chain_id, (min_no, max_no) in chains.items(): txn.execute(sql, (chain_id, min_no, max_no)) result.update(r for (r,) in txn) - return result def _fixup_auth_chain_difference_sets( @@ -2165,6 +2410,7 @@ def _materialize( origin_sequence_number: int, links: Dict[int, List[Tuple[int, int, int]]], materialized: Dict[int, int], + backwards: bool = True, ) -> None: """Helper function for fetching auth chain links. For a given origin chain ID / sequence number and a dictionary of links, updates the materialized @@ -2181,6 +2427,7 @@ def _materialize( target sequence number. materialized: dict to update with new reachability information, as a map from chain ID to max sequence number reachable. + backwards: If True, walks backwards down the chains. If False, walks forwards from the chains. """ # Do a standard graph traversal. @@ -2195,12 +2442,104 @@ def _materialize( target_chain_id, target_sequence_number, ) in chain_links: - # Ignore any links that are higher up the chain - if sequence_number > s: - continue + if backwards: + # Ignore any links that are higher up the chain + if sequence_number > s: + continue - # Check if we have already visited the target chain before, if so we - # can skip it. - if materialized.get(target_chain_id, 0) < target_sequence_number: - stack.append((target_chain_id, target_sequence_number)) - materialized[target_chain_id] = target_sequence_number + # Check if we have already visited the target chain before, if so we + # can skip it. 
+ if materialized.get(target_chain_id, 0) < target_sequence_number: + stack.append((target_chain_id, target_sequence_number)) + materialized[target_chain_id] = target_sequence_number + else: + # Ignore any links that are lower down the chain. + if sequence_number < s: + continue + # Check if we have already visited the target chain before, if so we + # can skip it. + if ( + materialized.get(target_chain_id, MAX_CHAIN_LENGTH) + > target_sequence_number + ): + stack.append((target_chain_id, target_sequence_number)) + materialized[target_chain_id] = target_sequence_number + + +def _generate_forward_links( + links: Dict[int, List[Tuple[int, int, int]]], +) -> Dict[int, List[Tuple[int, int, int]]]: + """Reverse the input links from the given backwards links""" + new_links: Dict[int, List[Tuple[int, int, int]]] = {} + for origin_chain_id, chain_links in links.items(): + for origin_seq_num, target_chain_id, target_seq_num in chain_links: + new_links.setdefault(target_chain_id, []).append( + (target_seq_num, origin_chain_id, origin_seq_num) + ) + return new_links + + +def accumulate_forwards_reachable_events( + conflicted_forwards_reachable: Dict[int, int], + back_links: Dict[int, List[Tuple[int, int, int]]], + conflicted_chain_positions: Dict[str, Tuple[int, int]], +) -> None: + """Accumulate new forwards reachable events using the back_links provided. + + Accumulating forwards reachable information is quite different from backwards reachable information + because _get_chain_links returns the entire linkage information for backwards reachable events, + but not _forwards_ reachable events. We are only interested in the forwards reachable information + that is encoded in the backwards reachable links, so we can just invert all the operations we do + for backwards reachable events to calculate a subset of forwards reachable information. The + caveat with this approach is that it is a _subset_. This means new back_links may encode new + forwards reachable information which we also need. Consider this scenario: + + A <-- B <-- C <--- D <-- E <-- F Chain 1 + | + `----- G <-- H <-- I Chain 2 + | + `---- J <-- K Chain 3 + + Now consider what happens when B is a conflicted event. _get_chain_links returns the conflicted + chain and ALL links heading towards the root of the graph. This means we will know the + Chain 1 to Chain 2 link via C (as all links for the chain are returned, not strictly ones with + a lower sequence number), but we will NOT know the Chain 2 to Chain 3 link via H. We can be + blissfully unaware of Chain 3 entirely, if and only if there isn't some other conflicted event + on that chain. Consider what happens when K is /also/ conflicted. _get_chain_links will generate + two iterations: one for B and one for K. It's important that we re-evaluate the forwards reachable + information for B to include Chain 3 when we process the K iteration, hence we are "accumulating" + forwards reachability information. + + NB: We don't consider 'additional backwards reachable events' here because they have no effect + on forwards reachability calculations, only backwards. + + Args: + conflicted_forwards_reachable: The materialised dict of forwards reachable information. + The output to this function are stored here. + back_links: One iteration of _get_chain_links which encodes backwards reachable information. + conflicted_chain_positions: The conflicted events. 
+ """ + # links go backwards but we want them to go forwards as well for v2.1 + fwd_links = _generate_forward_links(back_links) + + # for each conflicted event, accumulate forwards reachability information + for ( + conflicted_chain_id, + conflicted_chain_seq, + ) in conflicted_chain_positions.values(): + # the conflicted event itself encodes reachability information + # e.g if D was conflicted, it encodes E,F as forwards reachable. + conflicted_forwards_reachable[conflicted_chain_id] = min( + conflicted_chain_seq, + conflicted_forwards_reachable.get(conflicted_chain_id, MAX_CHAIN_LENGTH), + ) + # Walk from the conflicted event forwards to explore the links. + # This function checks if we've visited the chain before and skips reprocessing, so this + # does not repeatedly traverse the graph. + _materialize( + conflicted_chain_id, + conflicted_chain_seq, + fwd_links, + conflicted_forwards_reachable, + backwards=False, + ) diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index 0ea3a0a4a8..8850c2616f 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -355,12 +355,78 @@ class RoomAlias(DomainSpecificString): @attr.s(slots=True, frozen=True, repr=False) -class RoomID(DomainSpecificString): - """Structure representing a room id.""" +class RoomIdWithDomain(DomainSpecificString): + """Structure representing a room ID with a domain suffix.""" SIGIL = "!" +# the set of urlsafe base64 characters, no padding. +ROOM_ID_PATTERN_DOMAINLESS = re.compile(r"^[A-Za-z0-9\-_]{43}$") + + +@attr.define(slots=True, frozen=True, repr=False) +class RoomID: + """Structure representing a room id without a domain. + There are two forms of room IDs: + - "!localpart:domain" used in most room versions prior to MSC4291. + - "!event_id_base_64" used in room versions post MSC4291. + This class will accept any room ID which meets either of these two criteria. + """ + + SIGIL = "!" + id: str + room_id_with_domain: Optional[RoomIdWithDomain] + + @classmethod + def is_valid(cls: Type["RoomID"], s: str) -> bool: + if ":" in s: + return RoomIdWithDomain.is_valid(s) + try: + cls.from_string(s) + return True + except Exception: + return False + + def get_domain(self) -> Optional[str]: + if not self.room_id_with_domain: + return None + return self.room_id_with_domain.domain + + def to_string(self) -> str: + if self.room_id_with_domain: + return self.room_id_with_domain.to_string() + return self.id + + __repr__ = to_string + + @classmethod + def from_string(cls: Type["RoomID"], s: str) -> "RoomID": + # sigil check + if len(s) < 1 or s[0] != cls.SIGIL: + raise SynapseError( + 400, + "Expected %s string to start with '%s'" % (cls.__name__, cls.SIGIL), + Codes.INVALID_PARAM, + ) + + room_id_with_domain: Optional[RoomIdWithDomain] = None + if ":" in s: + room_id_with_domain = RoomIdWithDomain.from_string(s) + else: + # MSC4291 room IDs must be valid urlsafe unpadded base64 + val = s[1:] + if not ROOM_ID_PATTERN_DOMAINLESS.match(val): + raise SynapseError( + 400, + "Expected %s string to be valid urlsafe unpadded base64 '%s'" + % (cls.__name__, val), + Codes.INVALID_PARAM, + ) + + return cls(id=s, room_id_with_domain=room_id_with_domain) + + @attr.s(slots=True, frozen=True, repr=False) class EventID(DomainSpecificString): """Structure representing an event ID which is namespaced to a homeserver. 
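Reviewer asides (not part of the patch; illustrative sketches only, using made-up names rather than Synapse APIs):

The reworked RoomID type above accepts two shapes of identifier: the legacy "!localpart:domain" form, and the MSC4291 domainless form of 43 characters of unpadded urlsafe base64 (per the docstring, the base64-encoded event ID of the create event). A minimal standalone sketch of the same validation, assuming only the regex and the two forms described above; parse_room_id and _DOMAINLESS are illustrative names:

    import re
    from typing import Optional, Tuple

    # 43 urlsafe base64 chars, no padding (same pattern as ROOM_ID_PATTERN_DOMAINLESS above).
    _DOMAINLESS = re.compile(r"^[A-Za-z0-9\-_]{43}$")

    def parse_room_id(s: str) -> Tuple[str, Optional[str]]:
        """Return (room_id, domain or None), raising ValueError if malformed."""
        if not s.startswith("!"):
            raise ValueError("room IDs must start with '!'")
        if ":" in s:
            # Legacy "!localpart:domain" form.
            localpart, domain = s[1:].split(":", 1)
            if not localpart or not domain:
                raise ValueError("malformed '!localpart:domain' room ID")
            return s, domain
        # MSC4291 form: no domain, just unpadded urlsafe base64.
        if not _DOMAINLESS.match(s[1:]):
            raise ValueError("expected 43 characters of unpadded urlsafe base64")
        return s, None

    assert parse_room_id("!room:example.com") == ("!room:example.com", "example.com")

Stepping back from the chain cover plumbing, the conflicted subgraph used by state res v2.1 has a compact definition: every event that is both backwards reachable (via auth_events) from some conflicted event and forwards reachable from some conflicted event, together with the conflicted events themselves. The sketch below shows that definition directly on a toy auth DAG; the real implementation above computes the same thing per chain on the cover index by intersecting backwards/forwards sequence-number ranges. Function and variable names here are illustrative only:

    from typing import Dict, Set

    def conflicted_subgraph(
        auth_events: Dict[str, Set[str]],  # event_id -> IDs of its auth events
        conflicted: Set[str],
    ) -> Set[str]:
        """Events lying between conflicted events, plus the conflicted events."""
        # Forward edges: event -> events that cite it as an auth event.
        children: Dict[str, Set[str]] = {}
        for eid, auths in auth_events.items():
            for aid in auths:
                children.setdefault(aid, set()).add(eid)

        def reach(starts: Set[str], edges: Dict[str, Set[str]]) -> Set[str]:
            # Standard graph traversal; excludes the start events themselves.
            seen: Set[str] = set()
            stack = list(starts)
            while stack:
                for nxt in edges.get(stack.pop(), set()):
                    if nxt not in seen:
                        seen.add(nxt)
                        stack.append(nxt)
            return seen

        backwards = reach(conflicted, auth_events)  # auth ancestors of conflicted events
        forwards = reach(conflicted, children)      # events citing them, transitively
        return (backwards & forwards) | conflicted

    # Mirroring the linear "A <-- B <-- C <-- D <-- E" examples in the comments above:
    auth = {"A": set(), "B": {"A"}, "C": {"B"}, "D": {"C"}, "E": {"D"}}
    assert conflicted_subgraph(auth, {"B", "D"}) == {"B", "C", "D"}
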
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index a07264463c..2dff58bf12 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -33,6 +33,7 @@ from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.api.constants import EventTypes, Membership, RoomTypes from synapse.api.errors import Codes +from synapse.api.room_versions import RoomVersions from synapse.handlers.pagination import ( PURGE_ROOM_ACTION_NAME, SHUTDOWN_AND_PURGE_ROOM_ACTION_NAME, @@ -2892,6 +2893,30 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): "No local admin user in room with power to update power levels.", ) + def test_v12_room(self) -> None: + """Test that you can be promoted to admin in v12 rooms which won't have the admin the PL event.""" + room_id = self.helper.create_room_as( + self.creator, + tok=self.creator_tok, + room_version=RoomVersions.V12.identifier, + ) + + channel = self.make_request( + "POST", + f"/_synapse/admin/v1/rooms/{room_id}/make_room_admin", + content={}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + + # Now we test that we can join the room and that the admin user has PL 100. + self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok) + pl = self.helper.get_state( + room_id, EventTypes.PowerLevels, tok=self.creator_tok + ) + self.assertEquals(pl["users"][self.admin_user], 100) + class BlockRoomTestCase(unittest.HomeserverTestCase): servlets = [ diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index 5a0096d8cb..b4f2b98cc4 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -45,6 +45,7 @@ from synapse.state.v2 import ( lexicographical_topological_sort, resolve_events_with_store, ) +from synapse.storage.databases.main.event_federation import StateDifference from synapse.types import EventID, StateMap from tests import unittest @@ -734,7 +735,11 @@ class AuthChainDifferenceTestCase(unittest.TestCase): store = TestStateResolutionStore(persisted_events) diff_d = _get_auth_chain_difference( - ROOM_ID, state_sets, unpersited_events, store + ROOM_ID, + state_sets, + unpersited_events, + store, + None, ) difference = self.successResultOf(defer.ensureDeferred(diff_d)) @@ -791,7 +796,11 @@ class AuthChainDifferenceTestCase(unittest.TestCase): store = TestStateResolutionStore(persisted_events) diff_d = _get_auth_chain_difference( - ROOM_ID, state_sets, unpersited_events, store + ROOM_ID, + state_sets, + unpersited_events, + store, + None, ) difference = self.successResultOf(defer.ensureDeferred(diff_d)) @@ -858,7 +867,11 @@ class AuthChainDifferenceTestCase(unittest.TestCase): store = TestStateResolutionStore(persisted_events) diff_d = _get_auth_chain_difference( - ROOM_ID, state_sets, unpersited_events, store + ROOM_ID, + state_sets, + unpersited_events, + store, + None, ) difference = self.successResultOf(defer.ensureDeferred(diff_d)) @@ -1070,9 +1083,18 @@ class TestStateResolutionStore: return list(result) def get_auth_chain_difference( - self, room_id: str, auth_sets: List[Set[str]] - ) -> "defer.Deferred[Set[str]]": + self, + room_id: str, + auth_sets: List[Set[str]], + conflicted_state: Optional[Set[str]], + additional_backwards_reachable_conflicted_events: Optional[Set[str]], + ) -> "defer.Deferred[StateDifference]": chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets] common = set(chains[0]).intersection(*chains[1:]) - return defer.succeed(set(chains[0]).union(*chains[1:]) - common) + return 
defer.succeed( + StateDifference( + auth_difference=set(chains[0]).union(*chains[1:]) - common, + conflicted_subgraph=set(), + ), + ) diff --git a/tests/state/test_v21.py b/tests/state/test_v21.py new file mode 100644 index 0000000000..b40c1f125e --- /dev/null +++ b/tests/state/test_v21.py @@ -0,0 +1,503 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . +# +# Originally licensed under the Apache License, Version 2.0: +# . +# +# [This file includes modifications made by New Vector Limited] +# +# +import itertools +from typing import Dict, List, Optional, Sequence, Set + +from twisted.internet import defer +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.constants import EventTypes, JoinRules, Membership +from synapse.api.room_versions import RoomVersions +from synapse.events import EventBase +from synapse.federation.federation_base import event_from_pdu_json +from synapse.rest import admin +from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.state import StateResolutionStore +from synapse.state.v2 import ( + StateResolutionStore as StateResolutionStoreInterface, + _get_auth_chain_difference, + _seperate, + resolve_events_with_store, +) +from synapse.types import StateMap +from synapse.util import Clock + +from tests import unittest +from tests.state.test_v2 import TestStateResolutionStore + +ALICE = "@alice:example.com" +BOB = "@bob:example.com" +CHARLIE = "@charlie:example.com" +EVELYN = "@evelyn:example.com" +ZARA = "@zara:example.com" + +ROOM_ID = "!test:example.com" + +MEMBERSHIP_CONTENT_JOIN = {"membership": Membership.JOIN} +MEMBERSHIP_CONTENT_INVITE = {"membership": Membership.INVITE} +MEMBERSHIP_CONTENT_LEAVE = {"membership": Membership.LEAVE} + + +ORIGIN_SERVER_TS = 0 + + +def monotonic_timestamp() -> int: + global ORIGIN_SERVER_TS + ORIGIN_SERVER_TS += 1 + return ORIGIN_SERVER_TS + + +class FakeClock: + def sleep(self, msec: float) -> "defer.Deferred[None]": + return defer.succeed(None) + + +class StateResV21TestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + self.state = self.hs.get_state_handler() + persistence = self.hs.get_storage_controllers().persistence + assert persistence is not None + self._persistence = persistence + self._state_storage_controller = self.hs.get_storage_controllers().state + self._state_deletion = self.hs.get_datastores().state_deletion + self.store = self.hs.get_datastores().main + + self.register_user("user", "pass") + self.token = self.login("user", "pass") + + def test_state_reset_replay_conflicted_subgraph(self) -> None: + # 1. Alice creates a room. + e1_create = self.create_event( + EventTypes.Create, + "", + sender=ALICE, + content={"creator": ALICE}, + auth_events=[], + ) + # 2. Alice joins it. 
+ e2_ma = self.create_event( + EventTypes.Member, + ALICE, + sender=ALICE, + content=MEMBERSHIP_CONTENT_JOIN, + auth_events=[], + prev_events=[e1_create.event_id], + room_id=e1_create.room_id, + ) + # 3. Alice is the creator + e3_power1 = self.create_event( + EventTypes.PowerLevels, + "", + sender=ALICE, + content={"users": {}}, + auth_events=[e2_ma.event_id], + room_id=e1_create.room_id, + ) + # 4. Alice sets the room to public. + e4_jr = self.create_event( + EventTypes.JoinRules, + "", + sender=ALICE, + content={"join_rule": JoinRules.PUBLIC}, + auth_events=[e2_ma.event_id, e3_power1.event_id], + room_id=e1_create.room_id, + ) + # 5. Bob joins the room. + e5_mb = self.create_event( + EventTypes.Member, + BOB, + sender=BOB, + content=MEMBERSHIP_CONTENT_JOIN, + auth_events=[e3_power1.event_id, e4_jr.event_id], + room_id=e1_create.room_id, + ) + # 6. Charlie joins the room. + e6_mc = self.create_event( + EventTypes.Member, + CHARLIE, + sender=CHARLIE, + content=MEMBERSHIP_CONTENT_JOIN, + auth_events=[e3_power1.event_id, e4_jr.event_id], + room_id=e1_create.room_id, + ) + # 7. Alice promotes Bob. + e7_power2 = self.create_event( + EventTypes.PowerLevels, + "", + sender=ALICE, + content={"users": {BOB: 50}}, + auth_events=[e2_ma.event_id, e3_power1.event_id], + room_id=e1_create.room_id, + ) + # 8. Bob promotes Charlie. + e8_power3 = self.create_event( + EventTypes.PowerLevels, + "", + sender=BOB, + content={"users": {BOB: 50, CHARLIE: 50}}, + auth_events=[e5_mb.event_id, e7_power2.event_id], + room_id=e1_create.room_id, + ) + # 9. Eve joins the room. + e9_me1 = self.create_event( + EventTypes.Member, + EVELYN, + sender=EVELYN, + content=MEMBERSHIP_CONTENT_JOIN, + auth_events=[e8_power3.event_id, e4_jr.event_id], + room_id=e1_create.room_id, + ) + # 10. Eve changes her name, /!\\ but cites old power levels /!\ + e10_me2 = self.create_event( + EventTypes.Member, + EVELYN, + sender=EVELYN, + content=MEMBERSHIP_CONTENT_JOIN, + auth_events=[ + e3_power1.event_id, + e4_jr.event_id, + e9_me1.event_id, + ], + room_id=e1_create.room_id, + ) + # 11. Zara joins the room, citing the most recent power levels. + e11_mz = self.create_event( + EventTypes.Member, + ZARA, + sender=ZARA, + content=MEMBERSHIP_CONTENT_JOIN, + auth_events=[e8_power3.event_id, e4_jr.event_id], + room_id=e1_create.room_id, + ) + + # Event 10 above is DODGY: it directly cites old auth events, but indirectly + # cites new ones. If the state after event 10 contains old power level and old + # join events, we are vulnerable to a reset. 
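+        # State res v2.1 additionally replays the conflicted subgraph (the events lying
+        # between the conflicting power levels and memberships) through the iterative
+        # auth checks, starting from an empty base state rather than the unconflicted
+        # state, which is what protects against this kind of reset.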
+ + dodgy_state_after_eve_rename: StateMap[str] = { + (EventTypes.Create, ""): e1_create.event_id, + (EventTypes.Member, ALICE): e2_ma.event_id, + (EventTypes.Member, BOB): e5_mb.event_id, + (EventTypes.Member, CHARLIE): e6_mc.event_id, + (EventTypes.Member, EVELYN): e10_me2.event_id, + (EventTypes.PowerLevels, ""): e3_power1.event_id, # old and /!\\ DODGY /!\ + (EventTypes.JoinRules, ""): e4_jr.event_id, + } + + sensible_state_after_zara_joins: StateMap[str] = { + (EventTypes.Create, ""): e1_create.event_id, + (EventTypes.Member, ALICE): e2_ma.event_id, + (EventTypes.Member, BOB): e5_mb.event_id, + (EventTypes.Member, CHARLIE): e6_mc.event_id, + (EventTypes.Member, ZARA): e11_mz.event_id, + (EventTypes.PowerLevels, ""): e8_power3.event_id, + (EventTypes.JoinRules, ""): e4_jr.event_id, + } + + expected: StateMap[str] = { + (EventTypes.Create, ""): e1_create.event_id, + (EventTypes.Member, ALICE): e2_ma.event_id, + (EventTypes.Member, BOB): e5_mb.event_id, + (EventTypes.Member, CHARLIE): e6_mc.event_id, + # Expect ME2 replayed first: it's in the POWER 1 epoch + # Then ME1, in the POWER 3 epoch + (EventTypes.Member, EVELYN): e9_me1.event_id, + (EventTypes.Member, ZARA): e11_mz.event_id, + (EventTypes.PowerLevels, ""): e8_power3.event_id, + (EventTypes.JoinRules, ""): e4_jr.event_id, + } + + self.get_resolution_and_verify_expected( + [dodgy_state_after_eve_rename, sensible_state_after_zara_joins], + [ + e1_create, + e2_ma, + e3_power1, + e4_jr, + e5_mb, + e6_mc, + e7_power2, + e8_power3, + e9_me1, + e10_me2, + e11_mz, + ], + expected, + ) + + def test_state_reset_start_empty_set(self) -> None: + # The join rules reset to missing, when: + # - join rules were in conflict + # - the membership of those join rules' senders were not in conflict + # - those memberships are all leaves. + + # 1. Alice creates a room. + e1_create = self.create_event( + EventTypes.Create, + "", + sender=ALICE, + content={"creator": ALICE}, + auth_events=[], + ) + # 2. Alice joins it. + e2_ma1 = self.create_event( + EventTypes.Member, + ALICE, + sender=ALICE, + content=MEMBERSHIP_CONTENT_JOIN, + auth_events=[], + room_id=e1_create.room_id, + ) + # 3. Alice makes Bob an admin. + e3_power = self.create_event( + EventTypes.PowerLevels, + "", + sender=ALICE, + content={"users": {BOB: 100}}, + auth_events=[e2_ma1.event_id], + room_id=e1_create.room_id, + ) + # 4. Alice sets the room to public. + e4_jr1 = self.create_event( + EventTypes.JoinRules, + "", + sender=ALICE, + content={"join_rule": JoinRules.PUBLIC}, + auth_events=[e2_ma1.event_id, e3_power.event_id], + room_id=e1_create.room_id, + ) + # 5. Bob joins. + e5_mb = self.create_event( + EventTypes.Member, + BOB, + sender=BOB, + content=MEMBERSHIP_CONTENT_JOIN, + auth_events=[e3_power.event_id, e4_jr1.event_id], + room_id=e1_create.room_id, + ) + # 6. Alice sets join rules to invite. + e6_jr2 = self.create_event( + EventTypes.JoinRules, + "", + sender=ALICE, + content={"join_rule": JoinRules.INVITE}, + auth_events=[e2_ma1.event_id, e3_power.event_id], + room_id=e1_create.room_id, + ) + # 7. Alice then leaves. 
+ e7_ma2 = self.create_event( + EventTypes.Member, + ALICE, + sender=ALICE, + content=MEMBERSHIP_CONTENT_LEAVE, + auth_events=[e3_power.event_id, e2_ma1.event_id], + room_id=e1_create.room_id, + ) + + correct_state: StateMap[str] = { + (EventTypes.Create, ""): e1_create.event_id, + (EventTypes.Member, ALICE): e7_ma2.event_id, + (EventTypes.Member, BOB): e5_mb.event_id, + (EventTypes.PowerLevels, ""): e3_power.event_id, + (EventTypes.JoinRules, ""): e6_jr2.event_id, + } + + # Imagine that another server gives us incorrect state on a fork + # (via e.g. backfill). It cites the old join rules. + incorrect_state: StateMap[str] = { + (EventTypes.Create, ""): e1_create.event_id, + (EventTypes.Member, ALICE): e7_ma2.event_id, + (EventTypes.Member, BOB): e5_mb.event_id, + (EventTypes.PowerLevels, ""): e3_power.event_id, + (EventTypes.JoinRules, ""): e4_jr1.event_id, + } + + # Resolving those two should give us the new join rules. + expected: StateMap[str] = { + (EventTypes.Create, ""): e1_create.event_id, + (EventTypes.Member, ALICE): e7_ma2.event_id, + (EventTypes.Member, BOB): e5_mb.event_id, + (EventTypes.PowerLevels, ""): e3_power.event_id, + (EventTypes.JoinRules, ""): e6_jr2.event_id, + } + + self.get_resolution_and_verify_expected( + [correct_state, incorrect_state], + [e1_create, e2_ma1, e3_power, e4_jr1, e5_mb, e6_jr2, e7_ma2], + expected, + ) + + async def _get_auth_difference_and_conflicted_subgraph( + self, + room_id: str, + state_maps: Sequence[StateMap[str]], + event_map: Optional[Dict[str, EventBase]], + state_res_store: StateResolutionStoreInterface, + ) -> Set[str]: + _, conflicted_state = _seperate(state_maps) + conflicted_set: Optional[Set[str]] = set( + itertools.chain.from_iterable(conflicted_state.values()) + ) + if event_map is None: + event_map = {} + return await _get_auth_chain_difference( + room_id, + state_maps, + event_map, + state_res_store, + conflicted_set, + ) + + def get_resolution_and_verify_expected( + self, + state_maps: Sequence[StateMap[str]], + events: List[EventBase], + expected: StateMap[str], + ) -> None: + room_id = events[0].room_id + # First we try everything in-memory to check that the test case works. + event_map = {ev.event_id: ev for ev in events} + resolution = self.successResultOf( + resolve_events_with_store( + FakeClock(), + room_id, + events[0].room_version, + state_maps, + event_map=event_map, + state_res_store=TestStateResolutionStore(event_map), + ) + ) + self.assertEqual(resolution, expected) + + got_auth_diff = self.successResultOf( + self._get_auth_difference_and_conflicted_subgraph( + room_id, + state_maps, + event_map, + TestStateResolutionStore(event_map), + ) + ) + # we should never see the create event in the auth diff. If we do, this implies the + # conflicted subgraph is wrong and is returning too many old events. 
+ assert events[0].event_id not in got_auth_diff + + # now let's make the room exist on the DB, some queries rely on there being a row in + # the rooms table when persisting + self.get_success( + self.store.store_room( + room_id, + events[0].sender, + True, + events[0].room_version, + ) + ) + + def resolve_and_check() -> None: + event_map = {ev.event_id: ev for ev in events} + store = StateResolutionStore( + self._persistence.main_store, + self._state_deletion, + ) + resolution = self.get_success( + resolve_events_with_store( + FakeClock(), + room_id, + RoomVersions.HydraV11, + state_maps, + event_map=event_map, + state_res_store=store, + ) + ) + self.assertEqual(resolution, expected) + got_auth_diff2 = self.get_success( + self._get_auth_difference_and_conflicted_subgraph( + room_id, + state_maps, + event_map, + store, + ) + ) + # no matter how many events are persisted, the overall diff should always be the same. + self.assertEquals(got_auth_diff, got_auth_diff2) + + # now we will drip feed in `events` one-by-one, persisting them then resolving with the + # rest. This ensures we correctly handle mixed persisted/unpersisted events. We will finish + # with doing the test with all persisted events. + while len(events) > 0: + event_to_persist = events.pop(0) + self.persist_event(event_to_persist) + # now retest + resolve_and_check() + + def persist_event( + self, event: EventBase, state: Optional[StateMap[str]] = None + ) -> None: + """Persist the event, with optional state""" + context = self.get_success( + self.state.compute_event_context( + event, + state_ids_before_event=state, + partial_state=None if state is None else False, + ) + ) + self.get_success(self._persistence.persist_event(event, context)) + + def create_event( + self, + event_type: str, + state_key: Optional[str], + sender: str, + content: Dict, + auth_events: List[str], + prev_events: Optional[List[str]] = None, + room_id: Optional[str] = None, + ) -> EventBase: + """Short-hand for event_from_pdu_json for fields we typically care about. + Tests can override by just calling event_from_pdu_json directly.""" + if prev_events is None: + prev_events = [] + + pdu = { + "type": event_type, + "state_key": state_key, + "content": content, + "sender": sender, + "depth": 5, + "prev_events": prev_events, + "auth_events": auth_events, + "origin_server_ts": monotonic_timestamp(), + } + if event_type != EventTypes.Create: + if room_id is None: + raise Exception("must specify a room_id to create_event") + pdu["room_id"] = room_id + return event_from_pdu_json( + pdu, + RoomVersions.HydraV11, + ) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index a943ed975a..2f79068f6b 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -26,6 +26,7 @@ from typing import ( Iterable, List, Mapping, + NamedTuple, Set, Tuple, TypeVar, @@ -730,6 +731,202 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): ) self.assertSetEqual(difference, set()) + def test_conflicted_subgraph(self) -> None: + """Test that the conflicted subgraph code in state res v2.1 can walk arbitrary length links + and chains. + + We construct a chain cover index like this: + + A1 <- A2 <- A3 + ^--- B1 <- B2 <- B3 + ^--- C1 <- C2 <- C3 v--- G1 <- G2 + ^--- D1 <- D2 <- D3 + ^--- E1 + ^--- F1 <- F2 + + ..and then pick various events to be conflicted / additional backwards reachable to assert + that the code walks the chains correctly. 
We're particularly interested in ensuring that + the code walks multiple links between chains, hence why we have so many chains. + """ + + class TestNode(NamedTuple): + event_id: str + chain_id: int + seq_num: int + + class TestLink(NamedTuple): + origin_chain_and_seq: Tuple[int, int] + target_chain_and_seq: Tuple[int, int] + + # Map to chain IDs / seq nums + nodes: List[TestNode] = [ + TestNode("A1", 1, 1), + TestNode("A2", 1, 2), + TestNode("A3", 1, 3), + TestNode("B1", 2, 1), + TestNode("B2", 2, 2), + TestNode("B3", 2, 3), + TestNode("C1", 3, 1), + TestNode("C2", 3, 2), + TestNode("C3", 3, 3), + TestNode("D1", 4, 1), + TestNode("D2", 4, 2), + TestNode("D3", 4, 3), + TestNode("E1", 5, 1), + TestNode("F1", 6, 1), + TestNode("F2", 6, 2), + TestNode("G1", 7, 1), + TestNode("G2", 7, 2), + ] + links: List[TestLink] = [ + TestLink((2, 1), (1, 2)), # B1 -> A2 + TestLink((3, 1), (2, 2)), # C1 -> B2 + TestLink((4, 1), (3, 1)), # D1 -> C1 + TestLink((5, 1), (4, 2)), # E1 -> D2 + TestLink((6, 1), (5, 1)), # F1 -> E1 + TestLink((7, 1), (4, 3)), # G1 -> D3 + ] + + # populate the chain cover index tables as that's all we need + for node in nodes: + self.get_success( + self.store.db_pool.simple_insert( + "event_auth_chains", + { + "event_id": node.event_id, + "chain_id": node.chain_id, + "sequence_number": node.seq_num, + }, + desc="insert", + ) + ) + for link in links: + self.get_success( + self.store.db_pool.simple_insert( + "event_auth_chain_links", + { + "origin_chain_id": link.origin_chain_and_seq[0], + "origin_sequence_number": link.origin_chain_and_seq[1], + "target_chain_id": link.target_chain_and_seq[0], + "target_sequence_number": link.target_chain_and_seq[1], + }, + desc="insert", + ) + ) + + # Define the test cases + class TestCase(NamedTuple): + name: str + conflicted: Set[str] + additional_backwards_reachable: Set[str] + want_conflicted_subgraph: Set[str] + + # Reminder: + # A1 <- A2 <- A3 + # ^--- B1 <- B2 <- B3 + # ^--- C1 <- C2 <- C3 v--- G1 <- G2 + # ^--- D1 <- D2 <- D3 + # ^--- E1 + # ^--- F1 <- F2 + test_cases = [ + TestCase( + name="basic_single_chain", + conflicted={"B1", "B3"}, + additional_backwards_reachable=set(), + want_conflicted_subgraph={"B1", "B2", "B3"}, + ), + TestCase( + name="basic_single_link", + conflicted={"A1", "B2"}, + additional_backwards_reachable=set(), + want_conflicted_subgraph={"A1", "A2", "B1", "B2"}, + ), + TestCase( + name="basic_multi_link", + conflicted={"B1", "F1"}, + additional_backwards_reachable=set(), + want_conflicted_subgraph={"B1", "B2", "C1", "D1", "D2", "E1", "F1"}, + ), + # Repeat these tests but put the later event as an additional backwards reachable event. + # The output should be the same. 
+ TestCase( + name="basic_single_chain_as_additional", + conflicted={"B1"}, + additional_backwards_reachable={"B3"}, + want_conflicted_subgraph={"B1", "B2", "B3"}, + ), + TestCase( + name="basic_single_link_as_additional", + conflicted={"A1"}, + additional_backwards_reachable={"B2"}, + want_conflicted_subgraph={"A1", "A2", "B1", "B2"}, + ), + TestCase( + name="basic_multi_link_as_additional", + conflicted={"B1"}, + additional_backwards_reachable={"F1"}, + want_conflicted_subgraph={"B1", "B2", "C1", "D1", "D2", "E1", "F1"}, + ), + TestCase( + name="mixed_multi_link", + conflicted={"D1", "F1"}, + additional_backwards_reachable={"G1"}, + want_conflicted_subgraph={"D1", "D2", "D3", "E1", "F1", "G1"}, + ), + TestCase( + name="additional_backwards_doesnt_add_forwards_info", + conflicted={"C1", "C3"}, + # This is on the path to the root but Chain C isn't forwards reachable so this doesn't + # change anything. + additional_backwards_reachable={"B1"}, + want_conflicted_subgraph={"C1", "C2", "C3"}, + ), + TestCase( + name="empty_subgraph", + conflicted={ + "B3", + "C3", + }, # these can't reach each other going forwards, so the subgraph is empty + additional_backwards_reachable=set(), + want_conflicted_subgraph={"B3", "C3"}, + ), + TestCase( + name="empty_subgraph_with_additional", + conflicted={"C1"}, + # This is on the path to the root but Chain C isn't forwards reachable so this doesn't + # change anything. + additional_backwards_reachable={"B1"}, + want_conflicted_subgraph={"C1"}, + ), + TestCase( + name="empty_subgraph_single_conflict", + conflicted={"C1"}, # no subgraph can form as you need 2+ + additional_backwards_reachable=set(), + want_conflicted_subgraph={"C1"}, + ), + ] + + def run_test(txn: LoggingTransaction, test_case: TestCase) -> None: + result = self.store._get_auth_chain_difference_using_cover_index_txn( + txn, + "!not_relevant", + [test_case.conflicted.union(test_case.additional_backwards_reachable)], + test_case.conflicted, + test_case.additional_backwards_reachable, + ) + self.assertEquals( + result.conflicted_subgraph, + test_case.want_conflicted_subgraph, + f"{test_case.name} : conflicted subgraph mismatch", + ) + + for test_case in test_cases: + self.get_success( + self.store.db_pool.runInteraction( + f"test_case_{test_case.name}", run_test, test_case + ) + ) + @parameterized.expand( [(room_version,) for room_version in KNOWN_ROOM_VERSIONS.values()] ) diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py index f4469a8d8d..6d2e4e4bbe 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py @@ -66,6 +66,10 @@ class ExtremPruneTestCase(HomeserverTestCase): body = self.helper.send(self.room_id, body="Test", tok=self.token) local_message_event_id = body["event_id"] + current_state = self.get_success( + self._state_storage_controller.get_current_state_ids(self.room_id) + ) + # Fudge a remote event and persist it. This will be the extremity before # the gap. self.remote_event_1 = event_from_pdu_json( @@ -77,7 +81,11 @@ class ExtremPruneTestCase(HomeserverTestCase): "sender": "@user:other", "depth": 5, "prev_events": [local_message_event_id], - "auth_events": [], + "auth_events": [ + current_state.get((EventTypes.Create, "")), + current_state.get((EventTypes.PowerLevels, "")), + current_state.get((EventTypes.JoinRules, "")), + ], "origin_server_ts": self.clock.time_msec(), }, RoomVersions.V6, @@ -113,6 +121,10 @@ class ExtremPruneTestCase(HomeserverTestCase): the same domain. 
""" + state_before_gap = self.get_success( + self._state_storage_controller.get_current_state_ids(self.room_id) + ) + # Fudge a second event which points to an event we don't have. This is a # state event so that the state changes (otherwise we won't prune the # extremity as they'll have the same state group). @@ -125,16 +137,16 @@ class ExtremPruneTestCase(HomeserverTestCase): "sender": "@user:other", "depth": 50, "prev_events": ["$some_unknown_message"], - "auth_events": [], + "auth_events": [ + state_before_gap.get((EventTypes.Create, "")), + state_before_gap.get((EventTypes.PowerLevels, "")), + state_before_gap.get((EventTypes.JoinRules, "")), + ], "origin_server_ts": self.clock.time_msec(), }, RoomVersions.V6, ) - state_before_gap = self.get_success( - self._state_storage_controller.get_current_state_ids(self.room_id) - ) - self.persist_event(remote_event_2, state=state_before_gap) # Check the new extremity is just the new remote event. @@ -145,6 +157,15 @@ class ExtremPruneTestCase(HomeserverTestCase): state is different. """ + # Now we persist it with state with a dropped history visibility + # setting. The state resolution across the old and new event will then + # include it, and so the resolved state won't match the new state. + state_before_gap = dict( + self.get_success( + self._state_storage_controller.get_current_state_ids(self.room_id) + ) + ) + # Fudge a second event which points to an event we don't have. remote_event_2 = event_from_pdu_json( { @@ -155,20 +176,15 @@ class ExtremPruneTestCase(HomeserverTestCase): "sender": "@user:other", "depth": 10, "prev_events": ["$some_unknown_message"], - "auth_events": [], + "auth_events": [ + state_before_gap.get((EventTypes.Create, "")), + state_before_gap.get((EventTypes.PowerLevels, "")), + ], "origin_server_ts": self.clock.time_msec(), }, RoomVersions.V6, ) - # Now we persist it with state with a dropped history visibility - # setting. The state resolution across the old and new event will then - # include it, and so the resolved state won't match the new state. - state_before_gap = dict( - self.get_success( - self._state_storage_controller.get_current_state_ids(self.room_id) - ) - ) state_before_gap.pop(("m.room.history_visibility", "")) context = self.get_success( @@ -193,6 +209,10 @@ class ExtremPruneTestCase(HomeserverTestCase): # also set the depth to "lots". self.reactor.advance(7 * 24 * 60 * 60) + state_before_gap = self.get_success( + self._state_storage_controller.get_current_state_ids(self.room_id) + ) + # Fudge a second event which points to an event we don't have. This is a # state event so that the state changes (otherwise we won't prune the # extremity as they'll have the same state group). @@ -205,16 +225,16 @@ class ExtremPruneTestCase(HomeserverTestCase): "sender": "@user:other2", "depth": 10000, "prev_events": ["$some_unknown_message"], - "auth_events": [], + "auth_events": [ + state_before_gap.get((EventTypes.Create, "")), + state_before_gap.get((EventTypes.PowerLevels, "")), + state_before_gap.get((EventTypes.JoinRules, "")), + ], "origin_server_ts": self.clock.time_msec(), }, RoomVersions.V6, ) - state_before_gap = self.get_success( - self._state_storage_controller.get_current_state_ids(self.room_id) - ) - self.persist_event(remote_event_2, state=state_before_gap) # Check the new extremity is just the new remote event. @@ -225,6 +245,10 @@ class ExtremPruneTestCase(HomeserverTestCase): from a different domain. 
""" + state_before_gap = self.get_success( + self._state_storage_controller.get_current_state_ids(self.room_id) + ) + # Fudge a second event which points to an event we don't have. This is a # state event so that the state changes (otherwise we won't prune the # extremity as they'll have the same state group). @@ -237,16 +261,15 @@ class ExtremPruneTestCase(HomeserverTestCase): "sender": "@user:other2", "depth": 10, "prev_events": ["$some_unknown_message"], - "auth_events": [], + "auth_events": [ + state_before_gap.get((EventTypes.Create, "")), + state_before_gap.get((EventTypes.PowerLevels, "")), + ], "origin_server_ts": self.clock.time_msec(), }, RoomVersions.V6, ) - state_before_gap = self.get_success( - self._state_storage_controller.get_current_state_ids(self.room_id) - ) - self.persist_event(remote_event_2, state=state_before_gap) # Check the new extremity is just the new remote event. @@ -267,6 +290,10 @@ class ExtremPruneTestCase(HomeserverTestCase): # also set the depth to "lots". self.reactor.advance(7 * 24 * 60 * 60) + state_before_gap = self.get_success( + self._state_storage_controller.get_current_state_ids(self.room_id) + ) + # Fudge a second event which points to an event we don't have. This is a # state event so that the state changes (otherwise we won't prune the # extremity as they'll have the same state group). @@ -279,16 +306,16 @@ class ExtremPruneTestCase(HomeserverTestCase): "sender": "@user:other2", "depth": 10000, "prev_events": ["$some_unknown_message"], - "auth_events": [], + "auth_events": [ + state_before_gap.get((EventTypes.Create, "")), + state_before_gap.get((EventTypes.PowerLevels, "")), + state_before_gap.get((EventTypes.JoinRules, "")), + ], "origin_server_ts": self.clock.time_msec(), }, RoomVersions.V6, ) - state_before_gap = self.get_success( - self._state_storage_controller.get_current_state_ids(self.room_id) - ) - self.persist_event(remote_event_2, state=state_before_gap) # Check the new extremity is just the new remote event. @@ -311,6 +338,10 @@ class ExtremPruneTestCase(HomeserverTestCase): # also set the depth to "lots". self.reactor.advance(7 * 24 * 60 * 60) + state_before_gap = self.get_success( + self._state_storage_controller.get_current_state_ids(self.room_id) + ) + # Fudge a second event which points to an event we don't have. This is a # state event so that the state changes (otherwise we won't prune the # extremity as they'll have the same state group). @@ -323,16 +354,16 @@ class ExtremPruneTestCase(HomeserverTestCase): "sender": "@user:other2", "depth": 10000, "prev_events": ["$some_unknown_message"], - "auth_events": [], + "auth_events": [ + state_before_gap.get((EventTypes.Create, "")), + state_before_gap.get((EventTypes.PowerLevels, "")), + state_before_gap.get((EventTypes.JoinRules, "")), + ], "origin_server_ts": self.clock.time_msec(), }, RoomVersions.V6, ) - state_before_gap = self.get_success( - self._state_storage_controller.get_current_state_ids(self.room_id) - ) - self.persist_event(remote_event_2, state=state_before_gap) # Check the new extremity is just the new remote event. @@ -347,6 +378,10 @@ class ExtremPruneTestCase(HomeserverTestCase): local_message_event_id = body["event_id"] self.assert_extremities([local_message_event_id]) + state_before_gap = self.get_success( + self._state_storage_controller.get_current_state_ids(self.room_id) + ) + # Fudge a second event which points to an event we don't have. 
         # state event so that the state changes (otherwise we won't prune the
         # extremity as they'll have the same state group).
@@ -359,16 +394,16 @@ class ExtremPruneTestCase(HomeserverTestCase):
                 "sender": "@user:other2",
                 "depth": 10000,
                 "prev_events": ["$some_unknown_message"],
-                "auth_events": [],
+                "auth_events": [
+                    state_before_gap.get((EventTypes.Create, "")),
+                    state_before_gap.get((EventTypes.PowerLevels, "")),
+                    state_before_gap.get((EventTypes.JoinRules, "")),
+                ],
                 "origin_server_ts": self.clock.time_msec(),
             },
             RoomVersions.V6,
         )
 
-        state_before_gap = self.get_success(
-            self._state_storage_controller.get_current_state_ids(self.room_id)
-        )
-
         self.persist_event(remote_event_2, state=state_before_gap)
 
         # Check the new extremity is just the new remote event.
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 4f9cfede7b..a9c0d7d9a9 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -25,7 +25,7 @@ from canonicaljson import json
 from twisted.internet.testing import MemoryReactor
 
 from synapse.api.constants import EventTypes, Membership
-from synapse.api.room_versions import RoomVersions
+from synapse.api.room_versions import RoomVersion, RoomVersions
 from synapse.events import EventBase
 from synapse.events.builder import EventBuilder
 from synapse.server import HomeServer
@@ -263,12 +263,17 @@ class RedactionTestCase(unittest.HomeserverTestCase):
 
             @property
             def room_id(self) -> str:
+                assert self._base_builder.room_id is not None
                 return self._base_builder.room_id
 
             @property
             def type(self) -> str:
                 return self._base_builder.type
 
+            @property
+            def room_version(self) -> RoomVersion:
+                return self._base_builder.room_version
+
             @property
             def internal_metadata(self) -> EventInternalMetadata:
                 return self._base_builder.internal_metadata
diff --git a/tests/types/test_init.py b/tests/types/test_init.py
new file mode 100644
index 0000000000..b7a5b93ce9
--- /dev/null
+++ b/tests/types/test_init.py
@@ -0,0 +1,51 @@
+from synapse.api.errors import SynapseError
+from synapse.types import RoomID
+
+from tests.unittest import TestCase
+
+
+class RoomIDTestCase(TestCase):
+    def test_can_create_msc4291_room_ids(self) -> None:
+        valid_msc4291_room_id = "!31hneApxJ_1o-63DmFrpeqnkFfWppnzWso1JvH3ogLM"
+        room_id = RoomID.from_string(valid_msc4291_room_id)
+        self.assertEquals(RoomID.is_valid(valid_msc4291_room_id), True)
+        self.assertEquals(
+            room_id.to_string(),
+            valid_msc4291_room_id,
+        )
+        self.assertEquals(room_id.id, "!31hneApxJ_1o-63DmFrpeqnkFfWppnzWso1JvH3ogLM")
+        self.assertEquals(room_id.get_domain(), None)
+
+    def test_cannot_create_invalid_msc4291_room_ids(self) -> None:
+        invalid_room_ids = [
+            "!wronglength",
+            "!31hneApxJ_1o-63DmFrpeqnNOTurlsafeBASE64/gLM",
+            "!",
+            "! ",
+        ]
+        for bad_room_id in invalid_room_ids:
+            with self.assertRaises(SynapseError):
+                RoomID.from_string(bad_room_id)
+                if not RoomID.is_valid(bad_room_id):
+                    raise SynapseError(400, "invalid")
+
+    def test_cannot_create_invalid_legacy_room_ids(self) -> None:
+        invalid_room_ids = [
+            "!something:invalid$_chars.com",
+        ]
+        for bad_room_id in invalid_room_ids:
+            with self.assertRaises(SynapseError):
+                RoomID.from_string(bad_room_id)
+                if not RoomID.is_valid(bad_room_id):
+                    raise SynapseError(400, "invalid")
+
+    def test_can_create_valid_legacy_room_ids(self) -> None:
+        valid_room_ids = [
+            "!foo:example.com",
+            "!foo:example.com:8448",
+            "!💩💩💩:example.com",
+        ]
+        for room_id_str in valid_room_ids:
+            room_id = RoomID.from_string(room_id_str)
+            self.assertEquals(RoomID.is_valid(room_id_str), True)
+            self.assertIsNotNone(room_id.get_domain())