Mirror of https://github.com/element-hq/synapse.git (synced 2025-12-05 01:10:13 +00:00)

Support for room version 12

Commit 731e81c9a3, committed by Andrew Morgan (parent edac7a471f)
@@ -4094,7 +4094,7 @@ The default power levels for each preset are:
     "m.room.history_visibility": 100
     "m.room.canonical_alias": 50
     "m.room.avatar": 50
-    "m.room.tombstone": 100
+    "m.room.tombstone": 100 (150 if MSC4289 is used)
     "m.room.server_acl": 100
     "m.room.encryption": 100
 ```
@@ -5109,7 +5109,7 @@ properties:
     "m.room.avatar": 50
-    "m.room.tombstone": 100
+    "m.room.tombstone": 100 (150 if MSC4289 is used)
     "m.room.server_acl": 100
@@ -46,6 +46,9 @@ MAX_USERID_LENGTH = 255
 # Constant value used for the pseudo-thread which is the main timeline.
 MAIN_TIMELINE: Final = "main"
 
+# MAX_INT + 1, so it always trumps any PL in canonical JSON.
+CREATOR_POWER_LEVEL = 2**53
+
 
 class Membership:
     """Represents the membership states of a user in a room."""
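For context on the added constant: canonical JSON (required by room version 6 and later) only permits integers in the range [-(2**53)+1, 2**53-1], so 2**53 can never appear as an explicit power level in a valid event and therefore always outranks one. A standalone illustration, not part of the diff:

```python
# Standalone sketch: the largest integer canonical JSON allows is 2**53 - 1,
# so CREATOR_POWER_LEVEL (2**53) compares greater than any explicit power level.
CANONICAL_JSON_MAX_INT = 2**53 - 1
CREATOR_POWER_LEVEL = 2**53

assert CREATOR_POWER_LEVEL > CANONICAL_JSON_MAX_INT
```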
@@ -235,6 +238,8 @@ class EventContentFields:
     #
     # This is deprecated in MSC2175.
     ROOM_CREATOR: Final = "creator"
+    # MSC4289
+    ADDITIONAL_CREATORS: Final = "additional_creators"
 
     # The version of the room for `m.room.create` events.
     ROOM_VERSION: Final = "room_version"
@@ -36,12 +36,14 @@ class EventFormatVersions:
     ROOM_V1_V2 = 1  # $id:server event id format: used for room v1 and v2
     ROOM_V3 = 2  # MSC1659-style $hash event id format: used for room v3
     ROOM_V4_PLUS = 3  # MSC1884-style $hash format: introduced for room v4
+    ROOM_V11_HYDRA_PLUS = 4  # MSC4291 room IDs as hashes: introduced for room HydraV11
 
 
 KNOWN_EVENT_FORMAT_VERSIONS = {
     EventFormatVersions.ROOM_V1_V2,
     EventFormatVersions.ROOM_V3,
     EventFormatVersions.ROOM_V4_PLUS,
+    EventFormatVersions.ROOM_V11_HYDRA_PLUS,
 }
@@ -50,6 +52,7 @@ class StateResolutionVersions:
 
     V1 = 1  # room v1 state res
     V2 = 2  # MSC1442 state res: room v2 and later
+    V2_1 = 3  # MSC4297 state res
 
 
 class RoomDisposition:
@@ -109,6 +112,10 @@ class RoomVersion:
     msc3931_push_features: Tuple[str, ...]  # values from PushRuleRoomFlag
     # MSC3757: Restricting who can overwrite a state event
     msc3757_enabled: bool
+    # MSC4289: Creator power enabled
+    msc4289_creator_power_enabled: bool
+    # MSC4291: Room IDs as hashes of the create event
+    msc4291_room_ids_as_hashes: bool
 
 
 class RoomVersions:
@@ -131,6 +138,8 @@ class RoomVersions:
|
||||
enforce_int_power_levels=False,
|
||||
msc3931_push_features=(),
|
||||
msc3757_enabled=False,
|
||||
msc4289_creator_power_enabled=False,
|
||||
msc4291_room_ids_as_hashes=False,
|
||||
)
|
||||
V2 = RoomVersion(
|
||||
"2",
|
||||
@@ -151,6 +160,8 @@ class RoomVersions:
|
||||
enforce_int_power_levels=False,
|
||||
msc3931_push_features=(),
|
||||
msc3757_enabled=False,
|
||||
msc4289_creator_power_enabled=False,
|
||||
msc4291_room_ids_as_hashes=False,
|
||||
)
|
||||
V3 = RoomVersion(
|
||||
"3",
|
||||
@@ -171,6 +182,8 @@ class RoomVersions:
|
||||
enforce_int_power_levels=False,
|
||||
msc3931_push_features=(),
|
||||
msc3757_enabled=False,
|
||||
msc4289_creator_power_enabled=False,
|
||||
msc4291_room_ids_as_hashes=False,
|
||||
)
|
||||
V4 = RoomVersion(
|
||||
"4",
|
||||
@@ -191,6 +204,8 @@ class RoomVersions:
|
||||
enforce_int_power_levels=False,
|
||||
msc3931_push_features=(),
|
||||
msc3757_enabled=False,
|
||||
msc4289_creator_power_enabled=False,
|
||||
msc4291_room_ids_as_hashes=False,
|
||||
)
|
||||
V5 = RoomVersion(
|
||||
"5",
|
||||
@@ -211,6 +226,8 @@ class RoomVersions:
|
||||
enforce_int_power_levels=False,
|
||||
msc3931_push_features=(),
|
||||
msc3757_enabled=False,
|
||||
msc4289_creator_power_enabled=False,
|
||||
msc4291_room_ids_as_hashes=False,
|
||||
)
|
||||
V6 = RoomVersion(
|
||||
"6",
|
||||
@@ -231,6 +248,8 @@ class RoomVersions:
|
||||
enforce_int_power_levels=False,
|
||||
msc3931_push_features=(),
|
||||
msc3757_enabled=False,
|
||||
msc4289_creator_power_enabled=False,
|
||||
msc4291_room_ids_as_hashes=False,
|
||||
)
|
||||
V7 = RoomVersion(
|
||||
"7",
|
||||
@@ -251,6 +270,8 @@ class RoomVersions:
|
||||
enforce_int_power_levels=False,
|
||||
msc3931_push_features=(),
|
||||
msc3757_enabled=False,
|
||||
msc4289_creator_power_enabled=False,
|
||||
msc4291_room_ids_as_hashes=False,
|
||||
)
|
||||
V8 = RoomVersion(
|
||||
"8",
|
||||
@@ -271,6 +292,8 @@ class RoomVersions:
|
||||
enforce_int_power_levels=False,
|
||||
msc3931_push_features=(),
|
||||
msc3757_enabled=False,
|
||||
msc4289_creator_power_enabled=False,
|
||||
msc4291_room_ids_as_hashes=False,
|
||||
)
|
||||
V9 = RoomVersion(
|
||||
"9",
|
||||
@@ -291,6 +314,8 @@ class RoomVersions:
|
||||
enforce_int_power_levels=False,
|
||||
msc3931_push_features=(),
|
||||
msc3757_enabled=False,
|
||||
msc4289_creator_power_enabled=False,
|
||||
msc4291_room_ids_as_hashes=False,
|
||||
)
|
||||
V10 = RoomVersion(
|
||||
"10",
|
||||
@@ -311,6 +336,8 @@ class RoomVersions:
|
||||
enforce_int_power_levels=True,
|
||||
msc3931_push_features=(),
|
||||
msc3757_enabled=False,
|
||||
msc4289_creator_power_enabled=False,
|
||||
msc4291_room_ids_as_hashes=False,
|
||||
)
|
||||
MSC1767v10 = RoomVersion(
|
||||
# MSC1767 (Extensible Events) based on room version "10"
|
||||
@@ -332,6 +359,8 @@ class RoomVersions:
|
||||
enforce_int_power_levels=True,
|
||||
msc3931_push_features=(PushRuleRoomFlag.EXTENSIBLE_EVENTS,),
|
||||
msc3757_enabled=False,
|
||||
msc4289_creator_power_enabled=False,
|
||||
msc4291_room_ids_as_hashes=False,
|
||||
)
|
||||
MSC3757v10 = RoomVersion(
|
||||
# MSC3757 (Restricting who can overwrite a state event) based on room version "10"
|
||||
@@ -353,6 +382,8 @@ class RoomVersions:
|
||||
enforce_int_power_levels=True,
|
||||
msc3931_push_features=(),
|
||||
msc3757_enabled=True,
|
||||
msc4289_creator_power_enabled=False,
|
||||
msc4291_room_ids_as_hashes=False,
|
||||
)
|
||||
V11 = RoomVersion(
|
||||
"11",
|
||||
@@ -373,6 +404,8 @@ class RoomVersions:
|
||||
enforce_int_power_levels=True,
|
||||
msc3931_push_features=(),
|
||||
msc3757_enabled=False,
|
||||
msc4289_creator_power_enabled=False,
|
||||
msc4291_room_ids_as_hashes=False,
|
||||
)
|
||||
MSC3757v11 = RoomVersion(
|
||||
# MSC3757 (Restricting who can overwrite a state event) based on room version "11"
|
||||
@@ -394,6 +427,52 @@ class RoomVersions:
         enforce_int_power_levels=True,
         msc3931_push_features=(),
         msc3757_enabled=True,
+        msc4289_creator_power_enabled=False,
+        msc4291_room_ids_as_hashes=False,
     )
+    HydraV11 = RoomVersion(
+        "org.matrix.hydra.11",
+        RoomDisposition.UNSTABLE,
+        EventFormatVersions.ROOM_V11_HYDRA_PLUS,
+        StateResolutionVersions.V2_1,  # Changed from v11
+        enforce_key_validity=True,
+        special_case_aliases_auth=False,
+        strict_canonicaljson=True,
+        limit_notifications_power_levels=True,
+        implicit_room_creator=True,  # Used by MSC3820
+        updated_redaction_rules=True,  # Used by MSC3820
+        restricted_join_rule=True,
+        restricted_join_rule_fix=True,
+        knock_join_rule=True,
+        msc3389_relation_redactions=False,
+        knock_restricted_join_rule=True,
+        enforce_int_power_levels=True,
+        msc3931_push_features=(),
+        msc3757_enabled=False,
+        msc4289_creator_power_enabled=True,  # Changed from v11
+        msc4291_room_ids_as_hashes=True,  # Changed from v11
+    )
+    V12 = RoomVersion(
+        "12",
+        RoomDisposition.STABLE,
+        EventFormatVersions.ROOM_V11_HYDRA_PLUS,
+        StateResolutionVersions.V2_1,  # Changed from v11
+        enforce_key_validity=True,
+        special_case_aliases_auth=False,
+        strict_canonicaljson=True,
+        limit_notifications_power_levels=True,
+        implicit_room_creator=True,  # Used by MSC3820
+        updated_redaction_rules=True,  # Used by MSC3820
+        restricted_join_rule=True,
+        restricted_join_rule_fix=True,
+        knock_join_rule=True,
+        msc3389_relation_redactions=False,
+        knock_restricted_join_rule=True,
+        enforce_int_power_levels=True,
+        msc3931_push_features=(),
+        msc3757_enabled=False,
+        msc4289_creator_power_enabled=True,  # Changed from v11
+        msc4291_room_ids_as_hashes=True,  # Changed from v11
+    )
 
 
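To summarise what changes between room versions 11 and 12 in the definitions above, here is a quick introspection sketch (it assumes `RoomVersion` is an attrs class, as it is elsewhere in `synapse.api.room_versions`):

```python
# Sketch: print the RoomVersion fields that differ between v11 and v12.
import attr

from synapse.api.room_versions import RoomVersions

v11 = attr.asdict(RoomVersions.V11)
v12 = attr.asdict(RoomVersions.V12)
for field, old_value in v11.items():
    if v12[field] != old_value:
        print(f"{field}: {old_value!r} -> {v12[field]!r}")

# Expected, per the definitions above: the identifier ("11" -> "12"), the event
# format (ROOM_V11_HYDRA_PLUS), state resolution v2.1, and the two new flags
# msc4289_creator_power_enabled and msc4291_room_ids_as_hashes.
```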
@@ -411,6 +490,7 @@ KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = {
         RoomVersions.V9,
         RoomVersions.V10,
         RoomVersions.V11,
+        RoomVersions.V12,
         RoomVersions.MSC3757v10,
         RoomVersions.MSC3757v11,
     )
@@ -101,6 +101,9 @@ def compute_content_hash(
     event_dict.pop("outlier", None)
     event_dict.pop("destinations", None)
 
+    # N.B. no need to pop the room_id from create events in MSC4291 rooms
+    # as they shouldn't have one.
+
     event_json_bytes = encode_canonical_json(event_dict)
 
     hashed = hash_algorithm(event_json_bytes)
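The comment above relies on the MSC4291 identifier scheme used throughout this commit: the room ID is the create event's reference hash with a `!` sigil, and the create event ID is the same hash with a `$` sigil (see `FrozenEventV4.room_id` and `auth_event_ids` later in the diff). A minimal sketch of that relationship:

```python
# Minimal sketch: converting between an MSC4291 room ID and its create event ID
# is a sigil swap over the same URL-safe base64 reference hash.
def create_event_id_from_room_id(room_id: str) -> str:
    assert room_id.startswith("!")
    return "$" + room_id[1:]


def room_id_from_create_event_id(event_id: str) -> str:
    assert event_id.startswith("$")
    return "!" + event_id[1:]
```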
@@ -45,6 +45,7 @@ from signedjson.sign import SignatureVerifyException, verify_signed_json
 from unpaddedbase64 import decode_base64
 
 from synapse.api.constants import (
+    CREATOR_POWER_LEVEL,
     MAX_PDU_SIZE,
     EventContentFields,
     EventTypes,
@@ -65,6 +66,7 @@ from synapse.api.room_versions import (
     RoomVersions,
 )
 from synapse.state import CREATE_KEY
+from synapse.events import is_creator
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.types import (
     MutableStateMap,
@@ -261,7 +263,8 @@ async def check_state_independent_auth_rules(
                 f"Event {event.event_id} has unexpected auth_event for {k}: {auth_event_id}",
             )
 
-        # We also need to check that the auth event itself is not rejected.
+        # 2.3 ... If there are entries which were themselves rejected under the checks performed on receipt
+        # of a PDU, reject.
        if auth_event.rejected_reason:
            raise AuthError(
                403,
@@ -271,7 +274,7 @@ async def check_state_independent_auth_rules(
 
        auth_dict[k] = auth_event_id
 
-    # 3. If event does not have a m.room.create in its auth_events, reject.
+    # 2.4. If event does not have a m.room.create in its auth_events, reject.
    creation_event = auth_dict.get((EventTypes.Create, ""), None)
    if not creation_event:
        raise AuthError(403, "No create event in auth events")
@@ -311,13 +314,14 @@ def check_state_dependent_auth_rules(
 
    # Later code relies on there being a create event e.g _can_federate, _is_membership_change_allowed
    # so produce a more intelligible error if we don't have one.
-    if auth_dict.get(CREATE_KEY) is None:
+    create_event = auth_dict.get(CREATE_KEY)
+    if create_event is None:
        raise AuthError(
            403, f"Event {event.event_id} is missing a create event in auth_events."
        )
 
    # additional check for m.federate
-    creating_domain = get_domain_from_id(event.room_id)
+    creating_domain = get_domain_from_id(create_event.sender)
    originating_domain = get_domain_from_id(event.sender)
    if creating_domain != originating_domain:
        if not _can_federate(event, auth_dict):
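The switch from `event.room_id` to `create_event.sender` matters because MSC4291 room IDs no longer embed a server name, so there is no domain to extract from them. Illustrative only (the hashed room ID below is hypothetical):

```python
# Classic room IDs carry a domain; MSC4291 room IDs are bare hashes, so the
# creating server must now be read from the create event's sender instead.
classic_room_id = "!abcdefg:example.org"
hashed_room_id = "!31hnA0zy0jipDaKf5cQ3ty1ATZ0zj0-yzqzNsfu1nPc"  # hypothetical

assert ":" in classic_room_id     # domain recoverable from the room ID
assert ":" not in hashed_room_id  # nothing for get_domain_from_id to return
```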
@@ -470,12 +474,20 @@ def _check_create(event: "EventBase") -> None:
     if event.prev_event_ids():
         raise AuthError(403, "Create event has prev events")
 
-    # 1.2 If the domain of the room_id does not match the domain of the sender,
-    # reject.
-    sender_domain = get_domain_from_id(event.sender)
-    room_id_domain = get_domain_from_id(event.room_id)
-    if room_id_domain != sender_domain:
-        raise AuthError(403, "Creation event's room_id domain does not match sender's")
+    if event.room_version.msc4291_room_ids_as_hashes:
+        # 1.2 If the create event has a room_id, reject
+        if "room_id" in event:
+            raise AuthError(403, "Create event has a room_id")
+    else:
+        # 1.2 If the domain of the room_id does not match the domain of the sender,
+        # reject.
+        if not event.room_version.msc4291_room_ids_as_hashes:
+            sender_domain = get_domain_from_id(event.sender)
+            room_id_domain = get_domain_from_id(event.room_id)
+            if room_id_domain != sender_domain:
+                raise AuthError(
+                    403, "Creation event's room_id domain does not match sender's"
+                )
 
     # 1.3 If content.room_version is present and is not a recognised version, reject
     room_version_prop = event.content.get("room_version", "1")
@@ -492,6 +504,16 @@ def _check_create(event: "EventBase") -> None:
     ):
         raise AuthError(403, "Create event lacks a 'creator' property")
 
+    # 1.5 If the additional_creators field is present and is not an array of strings where each
+    # string is a valid user ID, reject.
+    if (
+        event.room_version.msc4289_creator_power_enabled
+        and EventContentFields.ADDITIONAL_CREATORS in event.content
+    ):
+        check_valid_additional_creators(
+            event.content[EventContentFields.ADDITIONAL_CREATORS]
+        )
+
 
 def _can_federate(event: "EventBase", auth_events: StateMap["EventBase"]) -> bool:
     creation_event = auth_events.get((EventTypes.Create, ""))
@@ -533,7 +555,13 @@ def _is_membership_change_allowed(
 
     target_user_id = event.state_key
 
-    creating_domain = get_domain_from_id(event.room_id)
+    # We need the create event in order to check if we can federate or not.
+    # If it's missing, yell loudly. Previously we only did this inside the
+    # _can_federate check.
+    create_event = auth_events.get((EventTypes.Create, ""))
+    if not create_event:
+        raise AuthError(403, "Create event missing from auth_events")
+    creating_domain = get_domain_from_id(create_event.sender)
     target_domain = get_domain_from_id(target_user_id)
     if creating_domain != target_domain:
         if not _can_federate(event, auth_events):
@@ -903,6 +931,32 @@ def _check_power_levels(
         except Exception:
             raise SynapseError(400, "Not a valid power level: %s" % (v,))
 
+    if room_version_obj.msc4289_creator_power_enabled:
+        # Enforce the creator does not appear in the users map
+        create_event = auth_events.get((EventTypes.Create, ""))
+        if not create_event:
+            raise SynapseError(
+                400, "Cannot check power levels without a create event in auth_events"
+            )
+        if create_event.sender in user_list:
+            raise SynapseError(
+                400,
+                "Creator user %s must not appear in content.users"
+                % (create_event.sender,),
+            )
+        additional_creators = create_event.content.get(
+            EventContentFields.ADDITIONAL_CREATORS, []
+        )
+        if additional_creators:
+            creators_in_user_list = set(additional_creators).intersection(
+                set(user_list)
+            )
+            if len(creators_in_user_list) > 0:
+                raise SynapseError(
+                    400,
+                    "Additional creators users must not appear in content.users",
+                )
+
     # Reject events with stringy power levels if required by room version
     if (
         event.type == EventTypes.PowerLevels
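A hypothetical example of what the new check rejects: in a room whose create event was sent by `@alice:example.org` (an assumed user ID), an `m.room.power_levels` event that lists the creator in `content.users` now fails with a 400, since creator power is implicit under MSC4289.

```python
# Hypothetical power_levels content rejected by the check above in an MSC4289
# room created by @alice:example.org.
rejected_power_levels_content = {
    "users": {
        "@alice:example.org": 100,  # creator must not appear here
        "@bob:example.org": 50,
    },
    "users_default": 0,
}
```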
@@ -1028,6 +1082,9 @@ def get_user_power_level(user_id: str, auth_events: StateMap["EventBase"]) -> in
             "A create event in the auth events chain is required to calculate user power level correctly,"
             " but was not found. This indicates a bug"
         )
+    if create_event.room_version.msc4289_creator_power_enabled:
+        if is_creator(create_event, user_id):
+            return CREATOR_POWER_LEVEL
     power_level_event = get_power_level_event(auth_events)
     if power_level_event:
         level = power_level_event.content.get("users", {}).get(user_id)
@@ -1188,3 +1245,26 @@ def auth_types_for_event(
             auth_types.add(key)
 
     return auth_types
+
+
+def check_valid_additional_creators(additional_creators: Any) -> None:
+    """Check if the additional_creators provided is valid according to MSC4289.
+
+    The additional_creators can be supplied from an m.room.create event or from an /upgrade request.
+
+    Raises:
+        AuthError if the additional_creators is invalid for some reason.
+    """
+    if type(additional_creators) is not list:
+        raise AuthError(400, "additional_creators must be an array")
+    for entry in additional_creators:
+        if type(entry) is not str:
+            raise AuthError(400, "entry in additional_creators is not a string")
+        if not UserID.is_valid(entry):
+            raise AuthError(400, "entry in additional_creators is not a valid user ID")
+        # UserID.is_valid doesn't actually validate everything, so check the rest manually.
+        if len(entry) > 255 or len(entry.encode("utf-8")) > 255:
+            raise AuthError(
+                400,
+                "entry in additional_creators too long",
+            )
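A hypothetical usage sketch of the validator above (user IDs are illustrative):

```python
from synapse.api.errors import AuthError
from synapse.event_auth import check_valid_additional_creators

check_valid_additional_creators(["@bob:example.org"])  # passes silently

try:
    check_valid_additional_creators("@bob:example.org")  # not a list
except AuthError as e:
    print(e.code, e.msg)  # 400 additional_creators must be an array
```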
@@ -41,10 +41,13 @@ from typing import (
|
||||
import attr
|
||||
from unpaddedbase64 import encode_base64
|
||||
|
||||
from synapse.api.constants import EventTypes, RelationTypes
|
||||
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
|
||||
from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
|
||||
from synapse.synapse_rust.events import EventInternalMetadata
|
||||
from synapse.types import JsonDict, StrCollection
|
||||
from synapse.types import (
|
||||
JsonDict,
|
||||
StrCollection,
|
||||
)
|
||||
from synapse.util.caches import intern_dict
|
||||
from synapse.util.frozenutils import freeze
|
||||
|
||||
@@ -209,7 +212,6 @@ class EventBase(metaclass=abc.ABCMeta):
|
||||
content: DictProperty[JsonDict] = DictProperty("content")
|
||||
hashes: DictProperty[Dict[str, str]] = DictProperty("hashes")
|
||||
origin_server_ts: DictProperty[int] = DictProperty("origin_server_ts")
|
||||
room_id: DictProperty[str] = DictProperty("room_id")
|
||||
sender: DictProperty[str] = DictProperty("sender")
|
||||
# TODO state_key should be Optional[str]. This is generally asserted in Synapse
|
||||
# by calling is_state() first (which ensures it is not None), but it is hard (not possible?)
|
||||
@@ -224,6 +226,10 @@ class EventBase(metaclass=abc.ABCMeta):
|
||||
def event_id(self) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def room_id(self) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def membership(self) -> str:
|
||||
return self.content["membership"]
|
||||
@@ -386,6 +392,10 @@ class FrozenEvent(EventBase):
|
||||
def event_id(self) -> str:
|
||||
return self._event_id
|
||||
|
||||
@property
|
||||
def room_id(self) -> str:
|
||||
return self._dict["room_id"]
|
||||
|
||||
|
||||
class FrozenEventV2(EventBase):
|
||||
format_version = EventFormatVersions.ROOM_V3 # All events of this type are V2
|
||||
@@ -443,6 +453,10 @@ class FrozenEventV2(EventBase):
|
||||
self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1])
|
||||
return self._event_id
|
||||
|
||||
@property
|
||||
def room_id(self) -> str:
|
||||
return self._dict["room_id"]
|
||||
|
||||
def prev_event_ids(self) -> List[str]:
|
||||
"""Returns the list of prev event IDs. The order matches the order
|
||||
specified in the event, though there is no meaning to it.
|
||||
@@ -481,6 +495,67 @@ class FrozenEventV3(FrozenEventV2):
         return self._event_id
 
 
+class FrozenEventV4(FrozenEventV3):
+    """FrozenEventV4 for MSC4291 room IDs are hashes"""
+
+    format_version = EventFormatVersions.ROOM_V11_HYDRA_PLUS
+
+    """Override the room_id for m.room.create events"""
+
+    def __init__(
+        self,
+        event_dict: JsonDict,
+        room_version: RoomVersion,
+        internal_metadata_dict: Optional[JsonDict] = None,
+        rejected_reason: Optional[str] = None,
+    ):
+        super().__init__(
+            event_dict=event_dict,
+            room_version=room_version,
+            internal_metadata_dict=internal_metadata_dict,
+            rejected_reason=rejected_reason,
+        )
+        self._room_id: Optional[str] = None
+
+    @property
+    def room_id(self) -> str:
+        # if we have calculated the room ID already, don't do it again.
+        if self._room_id:
+            return self._room_id
+
+        is_create_event = self.type == EventTypes.Create and self.get_state_key() == ""
+
+        # for non-create events: use the supplied value from the JSON, as per FrozenEventV3
+        if not is_create_event:
+            self._room_id = self._dict["room_id"]
+            assert self._room_id is not None
+            return self._room_id
+
+        # for create events: calculate the room ID
+        from synapse.crypto.event_signing import compute_event_reference_hash
+
+        self._room_id = "!" + encode_base64(
+            compute_event_reference_hash(self)[1], urlsafe=True
+        )
+        return self._room_id
+
+    def auth_event_ids(self) -> StrCollection:
+        """Returns the list of auth event IDs. The order matches the order
+        specified in the event, though there is no meaning to it.
+
+        Returns:
+            The list of event IDs of this event's auth_events
+
+        Includes the creation event ID for convenience of all the codepaths
+        which expects the auth chain to include the creator ID, even though
+        it's explicitly not included on the wire. Excludes the create event
+        for the create event itself.
+        """
+        create_event_id = "$" + self.room_id[1:]
+        assert create_event_id not in self._dict["auth_events"]
+        if self.type == EventTypes.Create and self.get_state_key() == "":
+            return self._dict["auth_events"]  # should be []
+        return self._dict["auth_events"] + [create_event_id]
+
+
 def _event_type_from_format_version(
     format_version: int,
 ) -> Type[Union[FrozenEvent, FrozenEventV2, FrozenEventV3]]:
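To make the wire/local split explicit: for MSC4291 rooms the create event never appears in a PDU's `auth_events`, but `auth_event_ids()` re-adds it locally by deriving it from the room ID. A standalone sketch with hypothetical IDs:

```python
# Hypothetical IDs; the reference hash is shared between room ID and create event ID.
room_id = "!31hnA0zy0jipDaKf5cQ3ty1ATZ0zj0-yzqzNsfu1nPc"
wire_auth_events = ["$power_levels_event", "$sender_membership_event"]  # as received

create_event_id = "$" + room_id[1:]
local_auth_event_ids = wire_auth_events + [create_event_id]  # what callers see
```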
@@ -500,6 +575,8 @@ def _event_type_from_format_version(
         return FrozenEventV2
     elif format_version == EventFormatVersions.ROOM_V4_PLUS:
         return FrozenEventV3
+    elif format_version == EventFormatVersions.ROOM_V11_HYDRA_PLUS:
+        return FrozenEventV4
     else:
         raise Exception("No event format %r" % (format_version,))
@@ -559,6 +636,23 @@ def relation_from_event(event: EventBase) -> Optional[_EventRelation]:
     return _EventRelation(parent_id, rel_type, aggregation_key)
 
 
+def is_creator(create: EventBase, user_id: str) -> bool:
+    """
+    Return true if the provided user ID is the room creator.
+
+    This includes additional creators in MSC4289.
+    """
+    assert create.type == EventTypes.Create
+    if create.sender == user_id:
+        return True
+    if create.room_version.msc4289_creator_power_enabled:
+        additional_creators = set(
+            create.content.get(EventContentFields.ADDITIONAL_CREATORS, [])
+        )
+        return user_id in additional_creators
+    return False
+
+
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class StrippedStateEvent:
     """
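The rule `is_creator()` encodes, shown as plain data with hypothetical users: both the create event's sender and anyone in `additional_creators` count as creators in an MSC4289 room.

```python
create_event_sender = "@alice:example.org"        # hypothetical
create_event_content = {
    "room_version": "12",
    "additional_creators": ["@bob:example.org"],  # hypothetical
}

creators = {create_event_sender, *create_event_content["additional_creators"]}
assert "@alice:example.org" in creators
assert "@bob:example.org" in creators
assert "@mallory:example.org" not in creators
```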
@@ -82,7 +82,8 @@ class EventBuilder:
|
||||
|
||||
room_version: RoomVersion
|
||||
|
||||
room_id: str
|
||||
# MSC4291 makes the room ID == the create event ID. This means the create event has no room_id.
|
||||
room_id: Optional[str]
|
||||
type: str
|
||||
sender: str
|
||||
|
||||
@@ -142,7 +143,14 @@ class EventBuilder:
|
||||
Returns:
|
||||
The signed and hashed event.
|
||||
"""
|
||||
# Create events always have empty auth_events.
|
||||
if self.type == EventTypes.Create and self.is_state() and self.state_key == "":
|
||||
auth_event_ids = []
|
||||
|
||||
# Calculate auth_events for non-create events
|
||||
if auth_event_ids is None:
|
||||
# Every non-create event must have a room ID
|
||||
assert self.room_id is not None
|
||||
state_ids = await self._state.compute_state_after_events(
|
||||
self.room_id,
|
||||
prev_event_ids,
|
||||
@@ -224,12 +232,31 @@
             "auth_events": auth_events,
             "prev_events": prev_events,
             "type": self.type,
-            "room_id": self.room_id,
             "sender": self.sender,
             "content": self.content,
             "unsigned": self.unsigned,
             "depth": depth,
         }
+        if self.room_id is not None:
+            event_dict["room_id"] = self.room_id
+
+        if self.room_version.msc4291_room_ids_as_hashes:
+            # In MSC4291: the create event has no room ID as the create event ID /is/ the room ID.
+            if (
+                self.type == EventTypes.Create
+                and self.is_state()
+                and self._state_key == ""
+            ):
+                assert self.room_id is None
+            else:
+                # All other events do not reference the create event in auth_events, as the room ID
+                # /is/ the create event. However, the rest of the code (for consistency between room
+                # versions) assume that the create event remains part of the auth events. c.f. event
+                # class which automatically adds the create event when `.auth_event_ids()` is called
+                assert self.room_id is not None
+                create_event_id = "$" + self.room_id[1:]
+                auth_event_ids.remove(create_event_id)
+                event_dict["auth_events"] = auth_event_ids
 
         if self.is_state():
             event_dict["state_key"] = self._state_key
@@ -285,7 +312,7 @@ class EventBuilderFactory:
|
||||
room_version=room_version,
|
||||
type=key_values["type"],
|
||||
state_key=key_values.get("state_key"),
|
||||
room_id=key_values["room_id"],
|
||||
room_id=key_values.get("room_id"),
|
||||
sender=key_values["sender"],
|
||||
content=key_values.get("content", {}),
|
||||
unsigned=key_values.get("unsigned", {}),
|
||||
|
||||
@@ -176,9 +176,12 @@ def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDic
         if room_version.updated_redaction_rules:
             # MSC2176 rules state that create events cannot have their `content` redacted.
             new_content = event_dict["content"]
-        elif not room_version.implicit_room_creator:
+        if not room_version.implicit_room_creator:
             # Some room versions give meaning to `creator`
             add_fields("creator")
+        if room_version.msc4291_room_ids_as_hashes:
+            # room_id is not allowed on the create event as it's derived from the event ID
+            allowed_keys.remove("room_id")
 
     elif event_type == EventTypes.JoinRules:
         add_fields("join_rule")
@@ -527,6 +530,10 @@ def serialize_event(
|
||||
if config.as_client_event:
|
||||
d = config.event_format(d)
|
||||
|
||||
# Ensure the room_id field is set for create events in MSC4291 rooms
|
||||
if e.type == EventTypes.Create and e.room_version.msc4291_room_ids_as_hashes:
|
||||
d["room_id"] = e.room_id
|
||||
|
||||
# If the event is a redaction, the field with the redacted event ID appears
|
||||
# in a different location depending on the room version. e.redacts handles
|
||||
# fetching from the proper location; copy it to the other location for forwards-
|
||||
@@ -869,6 +876,14 @@ def strip_event(event: EventBase) -> JsonDict:
     Stripped state events can only have the `sender`, `type`, `state_key` and `content`
     properties present.
     """
+    # MSC4311: Ensure the create event is available on invites and knocks.
+    # TODO: Implement the rest of MSC4311
+    if (
+        event.room_version.msc4291_room_ids_as_hashes
+        and event.type == EventTypes.Create
+        and event.get_state_key() == ""
+    ):
+        return event.get_pdu_json()
 
     return {
         "type": event.type,
@@ -183,8 +183,18 @@ class EventValidator:
|
||||
fields an event would have
|
||||
"""
|
||||
|
||||
create_event_as_room_id = (
|
||||
event.room_version.msc4291_room_ids_as_hashes
|
||||
and event.type == EventTypes.Create
|
||||
and hasattr(event, "state_key")
|
||||
and event.state_key == ""
|
||||
)
|
||||
|
||||
strings = ["room_id", "sender", "type"]
|
||||
|
||||
if create_event_as_room_id:
|
||||
strings.remove("room_id")
|
||||
|
||||
if hasattr(event, "state_key"):
|
||||
strings.append("state_key")
|
||||
|
||||
@@ -192,7 +202,14 @@ class EventValidator:
|
||||
if not isinstance(getattr(event, s), str):
|
||||
raise SynapseError(400, "Not '%s' a string type" % (s,))
|
||||
|
||||
RoomID.from_string(event.room_id)
|
||||
if not create_event_as_room_id:
|
||||
assert event.room_id is not None
|
||||
RoomID.from_string(event.room_id)
|
||||
if event.room_version.msc4291_room_ids_as_hashes and not RoomID.is_valid(
|
||||
event.room_id
|
||||
):
|
||||
raise SynapseError(400, f"Invalid room ID '{event.room_id}'")
|
||||
|
||||
UserID.from_string(event.sender)
|
||||
|
||||
if event.type == EventTypes.Message:
|
||||
|
||||
@@ -342,6 +342,21 @@ def event_from_pdu_json(pdu_json: JsonDict, room_version: RoomVersion) -> EventB
     if room_version.strict_canonicaljson:
         validate_canonicaljson(pdu_json)
 
+    # enforce that MSC4291 auth events don't include the create event.
+    # N.B. if they DO include a spurious create event, it'll fail auth checks elsewhere, so we don't
+    # need to do expensive DB lookups to find which event ID is the create event here.
+    if room_version.msc4291_room_ids_as_hashes:
+        room_id = pdu_json.get("room_id")
+        if room_id:
+            create_event_id = "$" + room_id[1:]
+            auth_events = pdu_json.get("auth_events")
+            if auth_events:
+                if create_event_id in auth_events:
+                    raise SynapseError(
+                        400,
+                        "auth_events must not contain the create event",
+                        Codes.BAD_JSON,
+                    )
     event = make_event_from_dict(pdu_json, room_version)
     return event
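A hypothetical PDU fragment that the new validation rejects with `400 M_BAD_JSON`: naming the create event in `auth_events` is redundant in an MSC4291 room, because the room ID already identifies it.

```python
bad_pdu = {
    "room_id": "!31hnA0zy0jipDaKf5cQ3ty1ATZ0zj0-yzqzNsfu1nPc",  # hypothetical
    "auth_events": [
        "$31hnA0zy0jipDaKf5cQ3ty1ATZ0zj0-yzqzNsfu1nPc",  # the create event: rejected
        "$some_power_levels_event",
    ],
    # ...remaining PDU fields elided...
}
```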
@@ -23,6 +23,8 @@ from typing import TYPE_CHECKING, List, Mapping, Optional, Union
|
||||
|
||||
from synapse import event_auth
|
||||
from synapse.api.constants import (
|
||||
CREATOR_POWER_LEVEL,
|
||||
EventContentFields,
|
||||
EventTypes,
|
||||
JoinRules,
|
||||
Membership,
|
||||
@@ -141,6 +143,8 @@ class EventAuthHandler:
|
||||
Raises:
|
||||
SynapseError if no appropriate user is found.
|
||||
"""
|
||||
create_event_id = current_state_ids[(EventTypes.Create, "")]
|
||||
create_event = await self._store.get_event(create_event_id)
|
||||
power_level_event_id = current_state_ids.get((EventTypes.PowerLevels, ""))
|
||||
invite_level = 0
|
||||
users_default_level = 0
|
||||
@@ -156,15 +160,28 @@
 
         # Find the user with the highest power level (only interested in local
         # users).
+        user_power_level = 0
+        chosen_user = None
         local_users_in_room = await self._store.get_local_users_in_room(room_id)
-        chosen_user = max(
-            local_users_in_room,
-            key=lambda user: users.get(user, users_default_level),
-            default=None,
-        )
+        if create_event.room_version.msc4289_creator_power_enabled:
+            creators = set(
+                create_event.content.get(EventContentFields.ADDITIONAL_CREATORS, [])
+            )
+            creators.add(create_event.sender)
+            local_creators = creators.intersection(set(local_users_in_room))
+            if len(local_creators) > 0:
+                chosen_user = local_creators.pop()  # random creator
+                user_power_level = CREATOR_POWER_LEVEL
+        else:
+            chosen_user = max(
+                local_users_in_room,
+                key=lambda user: users.get(user, users_default_level),
+                default=None,
+            )
+            if chosen_user:
+                user_power_level = users.get(chosen_user, users_default_level)
 
         # Return the chosen if they can issue invites.
-        user_power_level = users.get(chosen_user, users_default_level)
         if chosen_user and user_power_level >= invite_level:
             logger.debug(
                 "Found a user who can issue invites %s with power level %d >= invite level %d",
@@ -674,7 +674,10 @@ class EventCreationHandler:
|
||||
Codes.USER_ACCOUNT_SUSPENDED,
|
||||
)
|
||||
|
||||
if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
|
||||
is_create_event = (
|
||||
event_dict["type"] == EventTypes.Create and event_dict["state_key"] == ""
|
||||
)
|
||||
if is_create_event:
|
||||
room_version_id = event_dict["content"]["room_version"]
|
||||
maybe_room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id)
|
||||
if not maybe_room_version_obj:
|
||||
@@ -780,6 +783,7 @@ class EventCreationHandler:
|
||||
"""
|
||||
# the only thing the user can do is join the server notices room.
|
||||
if builder.type == EventTypes.Member:
|
||||
assert builder.room_id is not None
|
||||
membership = builder.content.get("membership", None)
|
||||
if membership == Membership.JOIN:
|
||||
return await self.store.is_server_notice_room(builder.room_id)
|
||||
@@ -1242,13 +1246,40 @@ class EventCreationHandler:
|
||||
for_verification=False,
|
||||
)
|
||||
|
||||
if (
|
||||
builder.room_version.msc4291_room_ids_as_hashes
|
||||
and builder.type == EventTypes.Create
|
||||
and builder.is_state()
|
||||
):
|
||||
if builder.room_id is not None:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Cannot resend m.room.create event",
|
||||
Codes.INVALID_PARAM,
|
||||
)
|
||||
else:
|
||||
assert builder.room_id is not None
|
||||
|
||||
if prev_event_ids is not None:
|
||||
assert len(prev_event_ids) <= 10, (
|
||||
"Attempting to create an event with %i prev_events"
|
||||
% (len(prev_event_ids),)
|
||||
)
|
||||
else:
|
||||
prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id)
|
||||
if builder.room_id:
|
||||
prev_event_ids = await self.store.get_prev_events_for_room(
|
||||
builder.room_id
|
||||
)
|
||||
else:
|
||||
prev_event_ids = [] # can only happen for the create event in MSC4291 rooms
|
||||
|
||||
if builder.type == EventTypes.Create and builder.is_state():
|
||||
if len(prev_event_ids) != 0:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Cannot resend m.room.create event",
|
||||
Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
# We now ought to have some `prev_events` (unless it's a create event).
|
||||
#
|
||||
@@ -2124,6 +2155,7 @@ class EventCreationHandler:
|
||||
original_event.room_version, third_party_result
|
||||
)
|
||||
self.validator.validate_builder(builder)
|
||||
assert builder.room_id is not None
|
||||
except SynapseError as e:
|
||||
raise Exception(
|
||||
"Third party rules module created an invalid event: " + e.msg,
|
||||
|
||||
@@ -81,6 +81,7 @@ from synapse.types import (
|
||||
Requester,
|
||||
RoomAlias,
|
||||
RoomID,
|
||||
RoomIdWithDomain,
|
||||
RoomStreamToken,
|
||||
StateMap,
|
||||
StrCollection,
|
||||
@@ -188,7 +189,11 @@ class RoomCreationHandler:
|
||||
)
|
||||
|
||||
async def upgrade_room(
|
||||
self, requester: Requester, old_room_id: str, new_version: RoomVersion
|
||||
self,
|
||||
requester: Requester,
|
||||
old_room_id: str,
|
||||
new_version: RoomVersion,
|
||||
additional_creators: Optional[List[str]],
|
||||
) -> str:
|
||||
"""Replace a room with a new room with a different version
|
||||
|
||||
@@ -196,6 +201,7 @@ class RoomCreationHandler:
|
||||
requester: the user requesting the upgrade
|
||||
old_room_id: the id of the room to be replaced
|
||||
new_version: the new room version to use
|
||||
additional_creators: additional room creators, for MSC4289.
|
||||
|
||||
Returns:
|
||||
the new room id
|
||||
@@ -224,8 +230,29 @@ class RoomCreationHandler:
|
||||
old_room = await self.store.get_room(old_room_id)
|
||||
if old_room is None:
|
||||
raise NotFoundError("Unknown room id %s" % (old_room_id,))
|
||||
old_room_is_public, _ = old_room
|
||||
|
||||
new_room_id = self._generate_room_id()
|
||||
creation_event_with_context = None
|
||||
if new_version.msc4291_room_ids_as_hashes:
|
||||
old_room_create_event = await self.store.get_create_event_for_room(
|
||||
old_room_id
|
||||
)
|
||||
creation_content = self._calculate_upgraded_room_creation_content(
|
||||
old_room_create_event,
|
||||
tombstone_event_id=None,
|
||||
new_room_version=new_version,
|
||||
additional_creators=additional_creators,
|
||||
)
|
||||
creation_event_with_context = await self._generate_create_event_for_room_id(
|
||||
requester,
|
||||
creation_content,
|
||||
old_room_is_public,
|
||||
new_version,
|
||||
)
|
||||
(create_event, _) = creation_event_with_context
|
||||
new_room_id = create_event.room_id
|
||||
else:
|
||||
new_room_id = self._generate_room_id()
|
||||
|
||||
# Try several times, it could fail with PartialStateConflictError
|
||||
# in _upgrade_room, cf comment in except block.
|
||||
@@ -274,6 +301,8 @@ class RoomCreationHandler:
|
||||
new_version,
|
||||
tombstone_event,
|
||||
tombstone_context,
|
||||
additional_creators,
|
||||
creation_event_with_context,
|
||||
)
|
||||
|
||||
return ret
|
||||
@@ -297,6 +326,10 @@ class RoomCreationHandler:
|
||||
new_version: RoomVersion,
|
||||
tombstone_event: EventBase,
|
||||
tombstone_context: synapse.events.snapshot.EventContext,
|
||||
additional_creators: Optional[List[str]],
|
||||
creation_event_with_context: Optional[
|
||||
Tuple[EventBase, synapse.events.snapshot.EventContext]
|
||||
] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Args:
|
||||
@@ -308,6 +341,8 @@ class RoomCreationHandler:
|
||||
new_version: the version to upgrade the room to
|
||||
tombstone_event: the tombstone event to send to the old room
|
||||
tombstone_context: the context for the tombstone event
|
||||
additional_creators: additional room creators, for MSC4289.
|
||||
creation_event_with_context: The new room's create event, for room IDs as create event IDs.
|
||||
|
||||
Raises:
|
||||
ShadowBanError if the requester is shadow-banned.
|
||||
@@ -317,14 +352,16 @@ class RoomCreationHandler:
|
||||
|
||||
logger.info("Creating new room %s to replace %s", new_room_id, old_room_id)
|
||||
|
||||
# create the new room. may raise a `StoreError` in the exceedingly unlikely
|
||||
# event of a room ID collision.
|
||||
await self.store.store_room(
|
||||
room_id=new_room_id,
|
||||
room_creator_user_id=user_id,
|
||||
is_public=old_room[0],
|
||||
room_version=new_version,
|
||||
)
|
||||
# We've already stored the room if we have the create event
|
||||
if not creation_event_with_context:
|
||||
# create the new room. may raise a `StoreError` in the exceedingly unlikely
|
||||
# event of a room ID collision.
|
||||
await self.store.store_room(
|
||||
room_id=new_room_id,
|
||||
room_creator_user_id=user_id,
|
||||
is_public=old_room[0],
|
||||
room_version=new_version,
|
||||
)
|
||||
|
||||
await self.clone_existing_room(
|
||||
requester,
|
||||
@@ -332,6 +369,8 @@ class RoomCreationHandler:
|
||||
new_room_id=new_room_id,
|
||||
new_room_version=new_version,
|
||||
tombstone_event_id=tombstone_event.event_id,
|
||||
additional_creators=additional_creators,
|
||||
creation_event_with_context=creation_event_with_context,
|
||||
)
|
||||
|
||||
# now send the tombstone
|
||||
@@ -365,6 +404,7 @@ class RoomCreationHandler:
|
||||
old_room_id,
|
||||
new_room_id,
|
||||
old_room_state,
|
||||
additional_creators,
|
||||
)
|
||||
|
||||
return new_room_id
|
||||
@@ -375,6 +415,7 @@ class RoomCreationHandler:
|
||||
old_room_id: str,
|
||||
new_room_id: str,
|
||||
old_room_state: StateMap[str],
|
||||
additional_creators: Optional[List[str]],
|
||||
) -> None:
|
||||
"""Send updated power levels in both rooms after an upgrade
|
||||
|
||||
@@ -383,7 +424,7 @@ class RoomCreationHandler:
|
||||
old_room_id: the id of the room to be replaced
|
||||
new_room_id: the id of the replacement room
|
||||
old_room_state: the state map for the old room
|
||||
|
||||
additional_creators: Additional creators in the new room.
|
||||
Raises:
|
||||
ShadowBanError if the requester is shadow-banned.
|
||||
"""
|
||||
@@ -439,6 +480,14 @@ class RoomCreationHandler:
|
||||
except AuthError as e:
|
||||
logger.warning("Unable to update PLs in old room: %s", e)
|
||||
|
||||
new_room_version = await self.store.get_room_version(new_room_id)
|
||||
if new_room_version.msc4289_creator_power_enabled:
|
||||
self._remove_creators_from_pl_users_map(
|
||||
old_room_pl_state.content.get("users", {}),
|
||||
requester.user.to_string(),
|
||||
additional_creators,
|
||||
)
|
||||
|
||||
await self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester,
|
||||
{
|
||||
@@ -453,6 +502,30 @@ class RoomCreationHandler:
|
||||
ratelimit=False,
|
||||
)
|
||||
|
||||
def _calculate_upgraded_room_creation_content(
|
||||
self,
|
||||
old_room_create_event: EventBase,
|
||||
tombstone_event_id: Optional[str],
|
||||
new_room_version: RoomVersion,
|
||||
) -> JsonDict:
|
||||
creation_content: JsonDict = {
|
||||
"room_version": new_room_version.identifier,
|
||||
"predecessor": {
|
||||
"room_id": old_room_create_event.room_id,
|
||||
},
|
||||
}
|
||||
if tombstone_event_id is not None:
|
||||
creation_content["predecessor"]["event_id"] = tombstone_event_id
|
||||
# Check if old room was non-federatable
|
||||
if not old_room_create_event.content.get(EventContentFields.FEDERATE, True):
|
||||
# If so, mark the new room as non-federatable as well
|
||||
creation_content[EventContentFields.FEDERATE] = False
|
||||
# Copy the room type as per MSC3818.
|
||||
room_type = old_room_create_event.content.get(EventContentFields.ROOM_TYPE)
|
||||
if room_type is not None:
|
||||
creation_content[EventContentFields.ROOM_TYPE] = room_type
|
||||
return creation_content
|
||||
|
||||
async def clone_existing_room(
|
||||
self,
|
||||
requester: Requester,
|
||||
@@ -460,6 +533,10 @@ class RoomCreationHandler:
|
||||
new_room_id: str,
|
||||
new_room_version: RoomVersion,
|
||||
tombstone_event_id: str,
|
||||
additional_creators: Optional[List[str]],
|
||||
creation_event_with_context: Optional[
|
||||
Tuple[EventBase, synapse.events.snapshot.EventContext]
|
||||
] = None,
|
||||
) -> None:
|
||||
"""Populate a new room based on an old room
|
||||
|
||||
@@ -470,24 +547,23 @@ class RoomCreationHandler:
|
||||
created with _generate_room_id())
|
||||
new_room_version: the new room version to use
|
||||
tombstone_event_id: the ID of the tombstone event in the old room.
|
||||
creation_event_with_context: The create event of the new room, if the new room supports
|
||||
room ID as create event ID hash.
|
||||
"""
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
creation_content: JsonDict = {
|
||||
"room_version": new_room_version.identifier,
|
||||
"predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id},
|
||||
}
|
||||
|
||||
# Check if old room was non-federatable
|
||||
|
||||
# Get old room's create event
|
||||
old_room_create_event = await self.store.get_create_event_for_room(old_room_id)
|
||||
|
||||
# Check if the create event specified a non-federatable room
|
||||
if not old_room_create_event.content.get(EventContentFields.FEDERATE, True):
|
||||
# If so, mark the new room as non-federatable as well
|
||||
creation_content[EventContentFields.FEDERATE] = False
|
||||
|
||||
if creation_event_with_context:
|
||||
create_event, _ = creation_event_with_context
|
||||
creation_content = create_event.content
|
||||
else:
|
||||
creation_content = self._calculate_upgraded_room_creation_content(
|
||||
old_room_create_event,
|
||||
tombstone_event_id,
|
||||
new_room_version,
|
||||
)
|
||||
initial_state = {}
|
||||
|
||||
# Replicate relevant room events
|
||||
@@ -503,11 +579,8 @@ class RoomCreationHandler:
|
||||
(EventTypes.PowerLevels, ""),
|
||||
]
|
||||
|
||||
# Copy the room type as per MSC3818.
|
||||
room_type = old_room_create_event.content.get(EventContentFields.ROOM_TYPE)
|
||||
if room_type is not None:
|
||||
creation_content[EventContentFields.ROOM_TYPE] = room_type
|
||||
|
||||
# If the old room was a space, copy over the rooms in the space.
|
||||
if room_type == RoomTypes.SPACE:
|
||||
types_to_copy.append((EventTypes.SpaceChild, None))
|
||||
@@ -579,6 +652,14 @@ class RoomCreationHandler:
|
||||
if current_power_level_int < needed_power_level:
|
||||
user_power_levels[user_id] = needed_power_level
|
||||
|
||||
if new_room_version.msc4289_creator_power_enabled:
|
||||
# the creator(s) cannot be in the users map
|
||||
self._remove_creators_from_pl_users_map(
|
||||
user_power_levels,
|
||||
user_id,
|
||||
additional_creators,
|
||||
)
|
||||
|
||||
# We construct what the body of a call to /createRoom would look like for passing
|
||||
# to the spam checker. We don't include a preset here, as we expect the
|
||||
# initial state to contain everything we need.
|
||||
@@ -607,6 +688,7 @@ class RoomCreationHandler:
|
||||
invite_list=[],
|
||||
initial_state=initial_state,
|
||||
creation_content=creation_content,
|
||||
creation_event_with_context=creation_event_with_context,
|
||||
)
|
||||
|
||||
# Transfer membership events
|
||||
@@ -890,6 +972,7 @@ class RoomCreationHandler:
|
||||
power_level_content_override = config.get("power_level_content_override")
|
||||
if (
|
||||
power_level_content_override
|
||||
and not room_version.msc4289_creator_power_enabled # this validation doesn't apply in MSC4289 rooms
|
||||
and "users" in power_level_content_override
|
||||
and user_id not in power_level_content_override["users"]
|
||||
):
|
||||
@@ -906,11 +989,41 @@ class RoomCreationHandler:
|
||||
|
||||
self._validate_room_config(config, visibility)
|
||||
|
||||
room_id = await self._generate_and_create_room_id(
|
||||
creator_id=user_id,
|
||||
is_public=is_public,
|
||||
room_version=room_version,
|
||||
)
|
||||
creation_content = config.get("creation_content", {})
|
||||
# override any attempt to set room versions via the creation_content
|
||||
creation_content["room_version"] = room_version.identifier
|
||||
|
||||
# trusted private chats have the invited users marked as additional creators
|
||||
if (
|
||||
room_version.msc4289_creator_power_enabled
|
||||
and config.get("preset", None) == RoomCreationPreset.TRUSTED_PRIVATE_CHAT
|
||||
and len(config.get("invite", [])) > 0
|
||||
):
|
||||
# the other user(s) are additional creators
|
||||
invitees = config.get("invite", [])
|
||||
# we don't want to replace any additional_creators additionally specified, and we want
|
||||
# to remove duplicates.
|
||||
creation_content[EventContentFields.ADDITIONAL_CREATORS] = list(
|
||||
set(creation_content.get(EventContentFields.ADDITIONAL_CREATORS, []))
|
||||
| set(invitees)
|
||||
)
|
||||
|
||||
creation_event_with_context = None
|
||||
if room_version.msc4291_room_ids_as_hashes:
|
||||
creation_event_with_context = await self._generate_create_event_for_room_id(
|
||||
requester,
|
||||
creation_content,
|
||||
is_public,
|
||||
room_version,
|
||||
)
|
||||
(create_event, _) = creation_event_with_context
|
||||
room_id = create_event.room_id
|
||||
else:
|
||||
room_id = await self._generate_and_create_room_id(
|
||||
creator_id=user_id,
|
||||
is_public=is_public,
|
||||
room_version=room_version,
|
||||
)
|
||||
|
||||
# Check whether this visibility value is blocked by a third party module
|
||||
allowed_by_third_party_rules = await (
|
||||
@@ -947,11 +1060,6 @@ class RoomCreationHandler:
|
||||
for val in raw_initial_state:
|
||||
initial_state[(val["type"], val.get("state_key", ""))] = val["content"]
|
||||
|
||||
creation_content = config.get("creation_content", {})
|
||||
|
||||
# override any attempt to set room versions via the creation_content
|
||||
creation_content["room_version"] = room_version.identifier
|
||||
|
||||
(
|
||||
last_stream_id,
|
||||
last_sent_event_id,
|
||||
@@ -968,6 +1076,7 @@ class RoomCreationHandler:
|
||||
power_level_content_override=power_level_content_override,
|
||||
creator_join_profile=creator_join_profile,
|
||||
ignore_forced_encryption=ignore_forced_encryption,
|
||||
creation_event_with_context=creation_event_with_context,
|
||||
)
|
||||
|
||||
# we avoid dropping the lock between invites, as otherwise joins can
|
||||
@@ -1033,6 +1142,38 @@ class RoomCreationHandler:
|
||||
|
||||
return room_id, room_alias, last_stream_id
|
||||
|
||||
async def _generate_create_event_for_room_id(
|
||||
self,
|
||||
creator: Requester,
|
||||
creation_content: JsonDict,
|
||||
is_public: bool,
|
||||
room_version: RoomVersion,
|
||||
) -> Tuple[EventBase, synapse.events.snapshot.EventContext]:
|
||||
(
|
||||
creation_event,
|
||||
new_unpersisted_context,
|
||||
) = await self.event_creation_handler.create_event(
|
||||
creator,
|
||||
{
|
||||
"content": creation_content,
|
||||
"sender": creator.user.to_string(),
|
||||
"type": EventTypes.Create,
|
||||
"state_key": "",
|
||||
},
|
||||
prev_event_ids=[],
|
||||
depth=1,
|
||||
state_map={},
|
||||
for_batch=False,
|
||||
)
|
||||
await self.store.store_room(
|
||||
room_id=creation_event.room_id,
|
||||
room_creator_user_id=creator.user.to_string(),
|
||||
is_public=is_public,
|
||||
room_version=room_version,
|
||||
)
|
||||
creation_context = await new_unpersisted_context.persist(creation_event)
|
||||
return (creation_event, creation_context)
|
||||
|
||||
async def _send_events_for_new_room(
|
||||
self,
|
||||
creator: Requester,
|
||||
@@ -1046,6 +1187,9 @@ class RoomCreationHandler:
|
||||
power_level_content_override: Optional[JsonDict] = None,
|
||||
creator_join_profile: Optional[JsonDict] = None,
|
||||
ignore_forced_encryption: bool = False,
|
||||
creation_event_with_context: Optional[
|
||||
Tuple[EventBase, synapse.events.snapshot.EventContext]
|
||||
] = None,
|
||||
) -> Tuple[int, str, int]:
|
||||
"""Sends the initial events into a new room. Sends the room creation, membership,
|
||||
and power level events into the room sequentially, then creates and batches up the
|
||||
@@ -1082,7 +1226,10 @@ class RoomCreationHandler:
|
||||
user in this room.
|
||||
ignore_forced_encryption:
|
||||
Ignore encryption forced by `encryption_enabled_by_default_for_room_type` setting.
|
||||
|
||||
creation_event_with_context:
|
||||
Set in MSC4291 rooms where the create event determines the room ID. If provided,
|
||||
does not create an additional create event but instead appends the remaining new
|
||||
events onto the provided create event.
|
||||
Returns:
|
||||
A tuple containing the stream ID, event ID and depth of the last
|
||||
event sent to the room.
|
||||
@@ -1147,13 +1294,26 @@ class RoomCreationHandler:
|
||||
|
||||
preset_config, config = self._room_preset_config(room_config)
|
||||
|
||||
# MSC2175 removes the creator field from the create event.
|
||||
if not room_version.implicit_room_creator:
|
||||
creation_content["creator"] = creator_id
|
||||
creation_event, unpersisted_creation_context = await create_event(
|
||||
EventTypes.Create, creation_content, False
|
||||
)
|
||||
creation_context = await unpersisted_creation_context.persist(creation_event)
|
||||
if creation_event_with_context is None:
|
||||
# MSC2175 removes the creator field from the create event.
|
||||
if not room_version.implicit_room_creator:
|
||||
creation_content["creator"] = creator_id
|
||||
creation_event, unpersisted_creation_context = await create_event(
|
||||
EventTypes.Create, creation_content, False
|
||||
)
|
||||
creation_context = await unpersisted_creation_context.persist(
|
||||
creation_event
|
||||
)
|
||||
else:
|
||||
(creation_event, creation_context) = creation_event_with_context
|
||||
# we had to do the above already in order to have a room ID, so just updates local vars
|
||||
# and continue.
|
||||
depth = 2
|
||||
prev_event = [creation_event.event_id]
|
||||
state_map[(creation_event.type, creation_event.state_key)] = (
|
||||
creation_event.event_id
|
||||
)
|
||||
|
||||
logger.debug("Sending %s in new room", EventTypes.Member)
|
||||
ev = await self.event_creation_handler.handle_new_client_event(
|
||||
requester=creator,
|
||||
@@ -1202,7 +1362,9 @@
         # Please update the docs for `default_power_level_content_override` when
         # updating the `events` dict below
         power_level_content: JsonDict = {
-            "users": {creator_id: 100},
+            "users": {creator_id: 100}
+            if not room_version.msc4289_creator_power_enabled
+            else {},
             "users_default": 0,
             "events": {
                 EventTypes.Name: 50,
@@ -1210,7 +1372,9 @@
                 EventTypes.RoomHistoryVisibility: 100,
                 EventTypes.CanonicalAlias: 50,
                 EventTypes.RoomAvatar: 50,
-                EventTypes.Tombstone: 100,
+                EventTypes.Tombstone: 150
+                if room_version.msc4289_creator_power_enabled
+                else 100,
                 EventTypes.ServerACL: 100,
                 EventTypes.RoomEncryption: 100,
             },
@@ -1223,7 +1387,13 @@
             "historical": 100,
         }
 
-        if config["original_invitees_have_ops"]:
+        # original_invitees_have_ops is set on preset:trusted_private_chat which will already
+        # have set these users as additional_creators, hence don't set the PL for creators as
+        # that is invalid.
+        if (
+            config["original_invitees_have_ops"]
+            and not room_version.msc4289_creator_power_enabled
+        ):
             for invitee in invite_list:
                 power_level_content["users"][invitee] = 100
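Putting the defaults above together, the initial `m.room.power_levels` content for a v12 room created by `@alice:example.org` (hypothetical) would look roughly like the following; the creator is omitted from `users` and `m.room.tombstone` needs PL 150, matching the documentation change at the top of this diff. The `m.room.power_levels: 100` entry is assumed from the standard preset defaults rather than shown in this hunk.

```python
v12_default_power_levels = {
    "users": {},          # creator omitted: their power is implicit under MSC4289
    "users_default": 0,
    "events": {
        "m.room.name": 50,
        "m.room.power_levels": 100,
        "m.room.history_visibility": 100,
        "m.room.canonical_alias": 50,
        "m.room.avatar": 50,
        "m.room.tombstone": 150,  # 100 in earlier room versions
        "m.room.server_acl": 100,
        "m.room.encryption": 100,
    },
}
```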
@@ -1396,6 +1566,19 @@ class RoomCreationHandler:
|
||||
)
|
||||
return preset_name, preset_config
|
||||
|
||||
def _remove_creators_from_pl_users_map(
|
||||
self,
|
||||
users_map: Dict[str, int],
|
||||
creator: str,
|
||||
additional_creators: Optional[List[str]],
|
||||
) -> None:
|
||||
creators = [creator]
|
||||
if additional_creators:
|
||||
creators.extend(additional_creators)
|
||||
for creator in creators:
|
||||
# the creator(s) cannot be in the users map
|
||||
users_map.pop(creator, None)
|
||||
|
||||
def _generate_room_id(self) -> str:
|
||||
"""Generates a random room ID.
|
||||
|
||||
@@ -1413,7 +1596,7 @@ class RoomCreationHandler:
|
||||
A random room ID of the form "!opaque_id:domain".
|
||||
"""
|
||||
random_string = stringutils.random_string(18)
|
||||
return RoomID(random_string, self.hs.hostname).to_string()
|
||||
return RoomIdWithDomain(random_string, self.hs.hostname).to_string()
|
||||
|
||||
async def _generate_and_create_room_id(
|
||||
self,
|
||||
|
||||
@@ -42,7 +42,7 @@ from synapse.api.errors import (
|
||||
)
|
||||
from synapse.api.ratelimiting import Ratelimiter
|
||||
from synapse.event_auth import get_named_level, get_power_level_event
|
||||
from synapse.events import EventBase
|
||||
from synapse.events import EventBase, is_creator
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.handlers.pagination import PURGE_ROOM_ACTION_NAME
|
||||
from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
|
||||
@@ -1154,9 +1154,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||
|
||||
elif effective_membership_state == Membership.KNOCK:
|
||||
if not is_host_in_room:
|
||||
# The knock needs to be sent over federation instead
|
||||
remote_room_hosts.append(get_domain_from_id(room_id))
|
||||
|
||||
# we used to add the domain of the room ID to remote_room_hosts.
|
||||
# This is not safe in MSC4291 rooms which do not have a domain.
|
||||
content["membership"] = Membership.KNOCK
|
||||
|
||||
try:
|
||||
@@ -2313,6 +2312,7 @@ def get_users_which_can_issue_invite(auth_events: StateMap[EventBase]) -> List[s
 
     # Check which members are able to invite by ensuring they're joined and have
     # the necessary power level.
+    create_event = auth_events[(EventTypes.Create, "")]
     for (event_type, state_key), event in auth_events.items():
         if event_type != EventTypes.Member:
             continue
@@ -2320,8 +2320,12 @@ def get_users_which_can_issue_invite(auth_events: StateMap[EventBase]) -> List[s
         if event.membership != Membership.JOIN:
             continue
 
+        if create_event.room_version.msc4289_creator_power_enabled and is_creator(
+            create_event, state_key
+        ):
+            result.append(state_key)
         # Check if the user has a custom power level.
-        if users.get(state_key, users_default_level) >= invite_level:
+        elif users.get(state_key, users_default_level) >= invite_level:
             result.append(state_key)
 
     return result
@@ -627,6 +627,15 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
             ]
             admin_users.sort(key=lambda user: user_power[user])
 
+            if create_event.room_version.msc4289_creator_power_enabled:
+                creators = create_event.content.get("additional_creators", []) + [
+                    create_event.sender
+                ]
+                for creator in creators:
+                    if self.is_mine_id(creator):
+                        # include the creator as they won't be in the PL users map.
+                        admin_users.insert(0, creator)
+
             if not admin_users:
                 raise SynapseError(
                     HTTPStatus.BAD_REQUEST, "No local admin user in room"
@@ -666,7 +675,11 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
             # updated power level event.
             new_pl_content = dict(pl_content)
             new_pl_content["users"] = dict(pl_content.get("users", {}))
-            new_pl_content["users"][user_to_add] = new_pl_content["users"][admin_user_id]
+            # give the new user the same PL as the admin, default to 100 in case there is no PL event.
+            # This means in v12+ rooms we get PL100 if the creator promotes us.
+            new_pl_content["users"][user_to_add] = new_pl_content["users"].get(
+                admin_user_id, 100
+            )
 
             fake_requester = create_requester(
                 admin_user_id,
@@ -24,6 +24,7 @@ from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api.errors import Codes, ShadowBanError, SynapseError
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
from synapse.event_auth import check_valid_additional_creators
|
||||
from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import (
|
||||
@@ -85,13 +86,18 @@ class RoomUpgradeRestServlet(RestServlet):
|
||||
"Your homeserver does not support this room version",
|
||||
Codes.UNSUPPORTED_ROOM_VERSION,
|
||||
)
|
||||
additional_creators = None
|
||||
if new_version.msc4289_creator_power_enabled:
|
||||
additional_creators = content.get("additional_creators")
|
||||
if additional_creators is not None:
|
||||
check_valid_additional_creators(additional_creators)
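# Illustrative request body for this endpoint (hypothetical user ID; the
# "additional_creators" field is the one read just above and is only honoured
# when the target room version has MSC4289 creator power enabled):
#
#   POST /_matrix/client/v3/rooms/{roomId}/upgrade
#   {"new_version": "12", "additional_creators": ["@cofounder:example.org"]}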
|
||||
|
||||
try:
|
||||
async with self._worker_lock_handler.acquire_read_write_lock(
|
||||
NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False
|
||||
):
|
||||
new_room_id = await self._room_creation_handler.upgrade_room(
|
||||
requester, room_id, new_version
|
||||
requester, room_id, new_version, additional_creators
|
||||
)
|
||||
except ShadowBanError:
|
||||
# Generate a random room ID.
|
||||
|
||||
@@ -53,6 +53,7 @@ from synapse.logging.context import ContextResourceUsage
|
||||
from synapse.logging.opentracing import tag_args, trace
|
||||
from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
|
||||
from synapse.state import v1, v2
|
||||
from synapse.storage.databases.main.event_federation import StateDifference
|
||||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
||||
from synapse.types import StateMap, StrCollection
|
||||
from synapse.types.state import StateFilter
|
||||
@@ -976,17 +977,35 @@ class StateResolutionStore:
|
||||
)
|
||||
|
||||
def get_auth_chain_difference(
|
||||
self, room_id: str, state_sets: List[Set[str]]
|
||||
) -> Awaitable[Set[str]]:
|
||||
"""Given sets of state events figure out the auth chain difference (as
|
||||
self,
|
||||
room_id: str,
|
||||
state_sets: List[Set[str]],
|
||||
conflicted_state: Optional[Set[str]],
|
||||
additional_backwards_reachable_conflicted_events: Optional[Set[str]],
|
||||
) -> Awaitable[StateDifference]:
|
||||
""" "Given sets of state events figure out the auth chain difference (as
|
||||
per state res v2 algorithm).
|
||||
|
||||
This equivalent to fetching the full auth chain for each set of state
|
||||
This is equivalent to fetching the full auth chain for each set of state
|
||||
and returning the events that don't appear in each and every auth
|
||||
chain.
|
||||
|
||||
If conflicted_state is not None, calculate and return the conflicted sub-graph as per
|
||||
state res v2.1. The event IDs in the conflicted state MUST be a subset of the event IDs in
|
||||
state_sets.
|
||||
|
||||
If additional_backwards_reachable_conflicted_events is set, the provided events are included
|
||||
when calculating the conflicted subgraph. This is primarily useful for calculating the
|
||||
subgraph across a combination of persisted and unpersisted events.
|
||||
|
||||
Returns:
|
||||
An awaitable that resolves to a set of event IDs.
|
||||
information on the auth chain difference, and also the conflicted subgraph if
|
||||
conflicted_state is not None
|
||||
"""
|
||||
|
||||
return self.main_store.get_auth_chain_difference(room_id, state_sets)
|
||||
return self.main_store.get_auth_chain_difference_extended(
|
||||
room_id,
|
||||
state_sets,
|
||||
conflicted_state,
|
||||
additional_backwards_reachable_conflicted_events,
|
||||
)
|
||||
|
||||
@@ -39,10 +39,11 @@ from typing import (
|
||||
)
|
||||
|
||||
from synapse import event_auth
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.constants import CREATOR_POWER_LEVEL, EventTypes
|
||||
from synapse.api.errors import AuthError
|
||||
from synapse.api.room_versions import RoomVersion
|
||||
from synapse.events import EventBase
|
||||
from synapse.api.room_versions import RoomVersion, StateResolutionVersions
|
||||
from synapse.events import EventBase, is_creator
|
||||
from synapse.storage.databases.main.event_federation import StateDifference
|
||||
from synapse.types import MutableStateMap, StateMap, StrCollection
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -63,8 +64,12 @@ class StateResolutionStore(Protocol):
|
||||
) -> Awaitable[Dict[str, EventBase]]: ...
|
||||
|
||||
def get_auth_chain_difference(
|
||||
self, room_id: str, state_sets: List[Set[str]]
|
||||
) -> Awaitable[Set[str]]: ...
|
||||
self,
|
||||
room_id: str,
|
||||
state_sets: List[Set[str]],
|
||||
conflicted_state: Optional[Set[str]],
|
||||
additional_backwards_reachable_conflicted_events: Optional[Set[str]],
|
||||
) -> Awaitable[StateDifference]: ...
|
||||
|
||||
|
||||
# We want to await to the reactor occasionally during state res when dealing
|
||||
@@ -123,12 +128,17 @@ async def resolve_events_with_store(
|
||||
logger.debug("%d conflicted state entries", len(conflicted_state))
|
||||
logger.debug("Calculating auth chain difference")
|
||||
|
||||
# Also fetch all auth events that appear in only some of the state sets'
|
||||
# auth chains.
|
||||
conflicted_set: Optional[Set[str]] = None
|
||||
if room_version.state_res == StateResolutionVersions.V2_1:
|
||||
# calculate the conflicted subgraph
|
||||
conflicted_set = set(itertools.chain.from_iterable(conflicted_state.values()))
|
||||
auth_diff = await _get_auth_chain_difference(
|
||||
room_id, state_sets, event_map, state_res_store
|
||||
room_id,
|
||||
state_sets,
|
||||
event_map,
|
||||
state_res_store,
|
||||
conflicted_set,
|
||||
)
|
||||
|
||||
full_conflicted_set = set(
|
||||
itertools.chain(
|
||||
itertools.chain.from_iterable(conflicted_state.values()), auth_diff
|
||||
@@ -168,15 +178,26 @@ async def resolve_events_with_store(
|
||||
|
||||
logger.debug("sorted %d power events", len(sorted_power_events))
|
||||
|
||||
# v2.1 starts iterative auth checks from the empty set and not the unconflicted state.
|
||||
# It relies on IAC behaviour which populates the base state with the events from auth_events
|
||||
# if the state tuple is missing from the base state. This ensures the base state is only
|
||||
# populated from auth_events rather than whatever the unconflicted state is (which could be
|
||||
# completely bogus).
|
||||
base_state = (
|
||||
{}
|
||||
if room_version.state_res == StateResolutionVersions.V2_1
|
||||
else unconflicted_state
|
||||
)
|
||||
|
||||
# Now sequentially auth each one
|
||||
resolved_state = await _iterative_auth_checks(
|
||||
clock,
|
||||
room_id,
|
||||
room_version,
|
||||
sorted_power_events,
|
||||
unconflicted_state,
|
||||
event_map,
|
||||
state_res_store,
|
||||
event_ids=sorted_power_events,
|
||||
base_state=base_state,
|
||||
event_map=event_map,
|
||||
state_res_store=state_res_store,
|
||||
)
|
||||
|
||||
logger.debug("resolved power events")
|
||||
@@ -239,13 +260,23 @@ async def _get_power_level_for_sender(
|
||||
event = await _get_event(room_id, event_id, event_map, state_res_store)
|
||||
|
||||
pl = None
|
||||
create = None
|
||||
for aid in event.auth_event_ids():
|
||||
aev = await _get_event(
|
||||
room_id, aid, event_map, state_res_store, allow_none=True
|
||||
)
|
||||
if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
|
||||
pl = aev
|
||||
break
|
||||
if aev and (aev.type, aev.state_key) == (EventTypes.Create, ""):
|
||||
create = aev
|
||||
|
||||
if event.type != EventTypes.Create:
|
||||
# we should always have a create event
|
||||
assert create is not None
|
||||
|
||||
if create and create.room_version.msc4289_creator_power_enabled:
|
||||
if is_creator(create, event.sender):
|
||||
return CREATOR_POWER_LEVEL
|
||||
|
||||
if pl is None:
|
||||
# Couldn't find power level. Check if they're the creator of the room
|
||||
@@ -286,6 +317,7 @@ async def _get_auth_chain_difference(
|
||||
state_sets: Sequence[StateMap[str]],
|
||||
unpersisted_events: Dict[str, EventBase],
|
||||
state_res_store: StateResolutionStore,
|
||||
conflicted_state: Optional[Set[str]],
|
||||
) -> Set[str]:
|
||||
"""Compare the auth chains of each state set and return the set of events
|
||||
that only appear in some, but not all of the auth chains.
|
||||
@@ -294,11 +326,18 @@ async def _get_auth_chain_difference(
|
||||
state_sets: The input state sets we are trying to resolve across.
|
||||
unpersisted_events: A map from event ID to EventBase containing all unpersisted
|
||||
events involved in this resolution.
|
||||
state_res_store:
|
||||
state_res_store: A way to retrieve events and extract graph information on the auth chains.
|
||||
conflicted_state: which event IDs are conflicted. Used in v2.1 for calculating the conflicted
|
||||
subgraph.
|
||||
|
||||
Returns:
|
||||
The auth difference of the given state sets, as a set of event IDs.
|
||||
The auth difference of the given state sets, as a set of event IDs. Also includes the
|
||||
conflicted subgraph if `conflicted_state` is set.
|
||||
"""
|
||||
is_state_res_v21 = conflicted_state is not None
|
||||
num_conflicted_state = (
|
||||
len(conflicted_state) if conflicted_state is not None else None
|
||||
)
|
||||
|
||||
# The `StateResolutionStore.get_auth_chain_difference` function assumes that
|
||||
# all events passed to it (and their auth chains) have been persisted
|
||||
@@ -318,14 +357,19 @@ async def _get_auth_chain_difference(
|
||||
# the event's auth chain with the events in `unpersisted_events` *plus* their
|
||||
# auth event IDs.
|
||||
events_to_auth_chain: Dict[str, Set[str]] = {}
|
||||
# remember the forward links when doing the graph traversal, we'll need it for v2.1 checks
|
||||
# This is a map from an event to the set of events that contain it as an auth event.
|
||||
event_to_next_event: Dict[str, Set[str]] = {}
|
||||
for event in unpersisted_events.values():
|
||||
chain = {event.event_id}
|
||||
events_to_auth_chain[event.event_id] = chain
|
||||
|
||||
to_search = [event]
|
||||
while to_search:
|
||||
for auth_id in to_search.pop().auth_event_ids():
|
||||
next_event = to_search.pop()
|
||||
for auth_id in next_event.auth_event_ids():
|
||||
chain.add(auth_id)
|
||||
event_to_next_event.setdefault(auth_id, set()).add(next_event.event_id)
|
||||
auth_event = unpersisted_events.get(auth_id)
|
||||
if auth_event:
|
||||
to_search.append(auth_event)
|
||||
@@ -335,6 +379,8 @@ async def _get_auth_chain_difference(
|
||||
#
|
||||
# Note: If there are no `unpersisted_events` (which is the common case), we can do a
|
||||
# much simpler calculation.
|
||||
additional_backwards_reachable_conflicted_events: Set[str] = set()
|
||||
unpersisted_conflicted_events: Set[str] = set()
|
||||
if unpersisted_events:
|
||||
# The list of state sets to pass to the store, where each state set is a set
|
||||
# of the event ids making up the state. This is similar to `state_sets`,
|
||||
@@ -372,7 +418,16 @@ async def _get_auth_chain_difference(
|
||||
)
|
||||
else:
|
||||
set_ids.add(event_id)
|
||||
|
||||
if conflicted_state:
|
||||
for conflicted_event_id in conflicted_state:
|
||||
# presence in this map means it is unpersisted.
|
||||
event_chain = events_to_auth_chain.get(conflicted_event_id)
|
||||
if event_chain is not None:
|
||||
unpersisted_conflicted_events.add(conflicted_event_id)
|
||||
# tell the DB layer that we have some unpersisted conflicted events
|
||||
additional_backwards_reachable_conflicted_events.update(
|
||||
e for e in event_chain if e not in unpersisted_events
|
||||
)
|
||||
# The auth chain difference of the unpersisted events of the state sets
|
||||
# is calculated by taking the difference between the union and
|
||||
# intersections.
|
||||
@@ -384,12 +439,89 @@ async def _get_auth_chain_difference(
|
||||
auth_difference_unpersisted_part = ()
|
||||
state_sets_ids = [set(state_set.values()) for state_set in state_sets]
|
||||
|
||||
difference = await state_res_store.get_auth_chain_difference(
|
||||
room_id, state_sets_ids
|
||||
)
|
||||
difference.update(auth_difference_unpersisted_part)
|
||||
if conflicted_state:
|
||||
# to ensure that conflicted state is a subset of state set IDs, we need to remove UNPERSISTED
|
||||
# conflicted state set ids as we removed them above.
|
||||
conflicted_state = conflicted_state - unpersisted_conflicted_events
|
||||
|
||||
return difference
|
||||
difference = await state_res_store.get_auth_chain_difference(
|
||||
room_id,
|
||||
state_sets_ids,
|
||||
conflicted_state,
|
||||
additional_backwards_reachable_conflicted_events,
|
||||
)
|
||||
difference.auth_difference.update(auth_difference_unpersisted_part)
|
||||
|
||||
# if we're doing v2.1 we may need to add or expand the conflicted subgraph
|
||||
if (
|
||||
is_state_res_v21
|
||||
and difference.conflicted_subgraph is not None
|
||||
and unpersisted_events
|
||||
):
|
||||
# we always include the conflicted events themselves in the subgraph.
|
||||
if conflicted_state:
|
||||
difference.conflicted_subgraph.update(conflicted_state)
|
||||
# we may need to expand the subgraph in the case where the subgraph starts in the DB and
|
||||
# ends in unpersisted events. To do this, we first need to see where the subgraph got up to,
|
||||
# which we can do by finding the intersection between the additional backwards reachable
|
||||
# conflicted events and the conflicted subgraph. Events in both sets mean A) some unpersisted
|
||||
# conflicted event could backwards reach it and B) some persisted conflicted event could forward
|
||||
# reach it.
|
||||
subgraph_frontier = difference.conflicted_subgraph.intersection(
|
||||
additional_backwards_reachable_conflicted_events
|
||||
)
|
||||
# we can now combine the 2 scenarios:
|
||||
# - subgraph starts in DB and ends in unpersisted
|
||||
# - subgraph starts in unpersisted and ends in unpersisted
|
||||
# by expanding the frontier into unpersisted events.
|
||||
# The frontier is currently all persisted events. We want to expand this into unpersisted
|
||||
# events. Mark every forwards reachable event from the frontier in the forwards_conflicted_set
|
||||
# but NOT the backwards conflicted set. This mirrors what the DB layer does but in reverse:
|
||||
# we supplied events which are backwards reachable to the DB and now the DB is providing
|
||||
# forwards reachable events from the DB.
|
||||
forwards_conflicted_set: Set[str] = set()
|
||||
# we include unpersisted conflicted events here to process exclusive unpersisted subgraphs
|
||||
search_queue = subgraph_frontier.union(unpersisted_conflicted_events)
|
||||
while search_queue:
|
||||
frontier_event = search_queue.pop()
|
||||
next_event_ids = event_to_next_event.get(frontier_event, set())
|
||||
search_queue.update(next_event_ids)
|
||||
forwards_conflicted_set.add(frontier_event)
|
||||
|
||||
# we've already calculated the backwards form as this is the auth chain for each
|
||||
# unpersisted conflicted event.
|
||||
backwards_conflicted_set: Set[str] = set()
|
||||
for uce in unpersisted_conflicted_events:
|
||||
backwards_conflicted_set.update(events_to_auth_chain.get(uce, []))
|
||||
|
||||
# the unpersisted conflicted subgraph is the intersection of the backwards/forwards sets
|
||||
conflicted_subgraph_unpersisted_part = backwards_conflicted_set.intersection(
|
||||
forwards_conflicted_set
|
||||
)
|
||||
|
||||
difference.conflicted_subgraph.update(conflicted_subgraph_unpersisted_part)
|
||||
|
||||
if difference.conflicted_subgraph:
|
||||
old_events = difference.auth_difference.union(
|
||||
conflicted_state if conflicted_state else set()
|
||||
)
|
||||
additional_events = difference.conflicted_subgraph.difference(old_events)
|
||||
|
||||
logger.debug(
|
||||
"v2.1 %s additional events replayed=%d num_conflicts=%d conflicted_subgraph=%d auth_difference=%d",
|
||||
room_id,
|
||||
len(additional_events),
|
||||
num_conflicted_state,
|
||||
len(difference.conflicted_subgraph),
|
||||
len(difference.auth_difference),
|
||||
)
|
||||
# State res v2.1 includes the conflicted subgraph in the difference
|
||||
return difference.auth_difference.union(difference.conflicted_subgraph)
|
||||
|
||||
return difference.auth_difference
|
||||
|
||||
|
||||
def _seperate(
|
||||
|
||||
@@ -110,6 +110,12 @@ _LONGEST_BACKOFF_PERIOD_MILLISECONDS = (
|
||||
assert 0 < _LONGEST_BACKOFF_PERIOD_MILLISECONDS <= ((2**31) - 1)
|
||||
|
||||
|
||||
# We use 2^53-1 as a "very large number", it has no particular
|
||||
# importance other than knowing synapse can support it (given canonical json
|
||||
# requires it).
|
||||
MAX_CHAIN_LENGTH = (2**53) - 1
|
||||
|
||||
|
||||
# All the info we need while iterating the DAG while backfilling
|
||||
@attr.s(frozen=True, slots=True, auto_attribs=True)
|
||||
class BackfillQueueNavigationItem:
|
||||
@@ -119,6 +125,14 @@ class BackfillQueueNavigationItem:
|
||||
type: str
|
||||
|
||||
|
||||
@attr.s(frozen=True, slots=True, auto_attribs=True)
|
||||
class StateDifference:
|
||||
# The event IDs in the auth difference.
|
||||
auth_difference: Set[str]
|
||||
# The event IDs in the conflicted state subgraph. Used in v2.1 only.
|
||||
conflicted_subgraph: Optional[Set[str]]
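A small illustrative sketch (made-up event IDs) of how the resolver consumes this container: the v2.1 path replays the union of the auth difference and the conflicted subgraph, while plain v2 only ever sees the auth difference.

from typing import Optional, Set

import attr

@attr.s(frozen=True, slots=True, auto_attribs=True)
class StateDifferenceSketch:  # local stand-in mirroring StateDifference above
    auth_difference: Set[str]
    conflicted_subgraph: Optional[Set[str]]

diff = StateDifferenceSketch(
    auth_difference={"$pl2", "$pl3"},
    conflicted_subgraph={"$member_eve"},
)
# v2.1: both sets are fed back into iterative auth checks.
assert diff.auth_difference.union(diff.conflicted_subgraph or set()) == {
    "$pl2",
    "$pl3",
    "$member_eve",
}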
|
||||
|
||||
|
||||
class _NoChainCoverIndex(Exception):
|
||||
def __init__(self, room_id: str):
|
||||
super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,))
|
||||
@@ -467,17 +481,41 @@ class EventFederationWorkerStore(
|
||||
return results
|
||||
|
||||
async def get_auth_chain_difference(
|
||||
self, room_id: str, state_sets: List[Set[str]]
|
||||
self,
|
||||
room_id: str,
|
||||
state_sets: List[Set[str]],
|
||||
) -> Set[str]:
|
||||
"""Given sets of state events figure out the auth chain difference (as
|
||||
state_diff = await self.get_auth_chain_difference_extended(
|
||||
room_id, state_sets, None, None
|
||||
)
|
||||
return state_diff.auth_difference
|
||||
|
||||
async def get_auth_chain_difference_extended(
|
||||
self,
|
||||
room_id: str,
|
||||
state_sets: List[Set[str]],
|
||||
conflicted_set: Optional[Set[str]],
|
||||
additional_backwards_reachable_conflicted_events: Optional[Set[str]],
|
||||
) -> StateDifference:
|
||||
""" "Given sets of state events figure out the auth chain difference (as
|
||||
per state res v2 algorithm).
|
||||
|
||||
This equivalent to fetching the full auth chain for each set of state
|
||||
This is equivalent to fetching the full auth chain for each set of state
|
||||
and returning the events that don't appear in each and every auth
|
||||
chain.
|
||||
|
||||
If conflicted_set is not None, calculate and return the conflicted sub-graph as per
|
||||
state res v2.1. The event IDs in the conflicted set MUST be a subset of the event IDs in
|
||||
state_sets.
|
||||
|
||||
If additional_backwards_reachable_conflicted_events is set, the provided events are included
|
||||
when calculating the conflicted subgraph. This is primarily useful for calculating the
|
||||
subgraph across a combination of persisted and unpersisted events. The event IDs in this set
|
||||
MUST be a subset of the event IDs in state_sets.
|
||||
|
||||
Returns:
|
||||
The set of the difference in auth chains.
|
||||
information on the auth chain difference, and also the conflicted subgraph if
|
||||
conflicted_set is not None
|
||||
"""
|
||||
|
||||
# Check if we have indexed the room so we can use the chain cover
|
||||
@@ -491,6 +529,8 @@ class EventFederationWorkerStore(
|
||||
self._get_auth_chain_difference_using_cover_index_txn,
|
||||
room_id,
|
||||
state_sets,
|
||||
conflicted_set,
|
||||
additional_backwards_reachable_conflicted_events,
|
||||
)
|
||||
except _NoChainCoverIndex:
|
||||
# For whatever reason we don't actually have a chain cover index
|
||||
@@ -499,25 +539,48 @@ class EventFederationWorkerStore(
|
||||
if not self.tests_allow_no_chain_cover_index:
|
||||
raise
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
# It's been 4 years since we added chain cover, so we expect all rooms to have it.
|
||||
# If they don't, we will error out when trying to do state res v2.1
|
||||
if conflicted_set is not None:
|
||||
raise _NoChainCoverIndex(room_id)
|
||||
|
||||
auth_diff = await self.db_pool.runInteraction(
|
||||
"get_auth_chain_difference",
|
||||
self._get_auth_chain_difference_txn,
|
||||
state_sets,
|
||||
)
|
||||
return StateDifference(auth_difference=auth_diff, conflicted_subgraph=None)
|
||||
|
||||
def _get_auth_chain_difference_using_cover_index_txn(
|
||||
self, txn: LoggingTransaction, room_id: str, state_sets: List[Set[str]]
|
||||
) -> Set[str]:
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
room_id: str,
|
||||
state_sets: List[Set[str]],
|
||||
conflicted_set: Optional[Set[str]] = None,
|
||||
additional_backwards_reachable_conflicted_events: Optional[Set[str]] = None,
|
||||
) -> StateDifference:
|
||||
"""Calculates the auth chain difference using the chain index.
|
||||
|
||||
See docs/auth_chain_difference_algorithm.md for details
|
||||
"""
|
||||
is_state_res_v21 = conflicted_set is not None
|
||||
|
||||
# First we look up the chain ID/sequence numbers for all the events, and
|
||||
# work out the chain/sequence numbers reachable from each state set.
|
||||
|
||||
initial_events = set(state_sets[0]).union(*state_sets[1:])
|
||||
|
||||
if is_state_res_v21:
|
||||
# Sanity check v2.1 fields
|
||||
assert conflicted_set is not None
|
||||
assert conflicted_set.issubset(initial_events)
|
||||
# It's possible for the conflicted_set to be empty if all the conflicts are in
|
||||
# unpersisted events, so we don't assert that conflicted_set has len > 0
|
||||
if additional_backwards_reachable_conflicted_events:
|
||||
assert additional_backwards_reachable_conflicted_events.issubset(
|
||||
initial_events
|
||||
)
|
||||
|
||||
# Map from event_id -> (chain ID, seq no)
|
||||
chain_info: Dict[str, Tuple[int, int]] = {}
|
||||
|
||||
@@ -553,14 +616,14 @@ class EventFederationWorkerStore(
|
||||
events_missing_chain_info = initial_events.difference(chain_info)
|
||||
|
||||
# The result set to return, i.e. the auth chain difference.
|
||||
result: Set[str] = set()
|
||||
auth_difference_result: Set[str] = set()
|
||||
|
||||
if events_missing_chain_info:
|
||||
# For some reason we have events we haven't calculated the chain
|
||||
# index for, so we need to handle those separately. This should only
|
||||
# happen for older rooms where the server doesn't have all the auth
|
||||
# events.
|
||||
result = self._fixup_auth_chain_difference_sets(
|
||||
auth_difference_result = self._fixup_auth_chain_difference_sets(
|
||||
txn,
|
||||
room_id,
|
||||
state_sets=state_sets,
|
||||
@@ -579,6 +642,45 @@ class EventFederationWorkerStore(
|
||||
|
||||
fetch_chain_info(new_events_to_fetch)
|
||||
|
||||
# State Res v2.1 needs extra data structures to calculate the conflicted subgraph which
|
||||
# are outlined below.
|
||||
|
||||
# A subset of chain_info for conflicted events only, as we need to
|
||||
# loop all conflicted chain positions. Map from event_id -> (chain ID, seq no)
|
||||
conflicted_chain_positions: Dict[str, Tuple[int, int]] = {}
|
||||
# For each chain, remember the positions where conflicted events are.
|
||||
# We need this for calculating the forward reachable events.
|
||||
conflicted_chain_to_seq: Dict[int, Set[int]] = {} # chain_id => {seq_num}
|
||||
# A subset of chain_info for additional backwards reachable events only, as we need to
|
||||
# loop all additional backwards reachable events for calculating backwards reachable events.
|
||||
additional_backwards_reachable_positions: Dict[
|
||||
str, Tuple[int, int]
|
||||
] = {} # event_id => (chain_id, seq_num)
|
||||
# These next two fields are critical as the intersection of them is the conflicted subgraph.
|
||||
# We'll populate them when we walk the chain links.
|
||||
# chain_id => max(seq_num) backwards reachable (e.g 4 means 1,2,3,4 are backwards reachable)
|
||||
conflicted_backwards_reachable: Dict[int, int] = {}
|
||||
# chain_id => min(seq_num) forwards reachable (e.g 4 means 4,5,6..n are forwards reachable)
|
||||
conflicted_forwards_reachable: Dict[int, int] = {}
|
||||
|
||||
# populate the v2.1 data structures
|
||||
if is_state_res_v21:
|
||||
assert conflicted_set is not None
|
||||
# provide chain positions for each conflicted event
|
||||
for conflicted_event_id in conflicted_set:
|
||||
(chain_id, seq_num) = chain_info[conflicted_event_id]
|
||||
conflicted_chain_positions[conflicted_event_id] = (chain_id, seq_num)
|
||||
conflicted_chain_to_seq.setdefault(chain_id, set()).add(seq_num)
|
||||
if additional_backwards_reachable_conflicted_events:
|
||||
for (
|
||||
additional_event_id
|
||||
) in additional_backwards_reachable_conflicted_events:
|
||||
(chain_id, seq_num) = chain_info[additional_event_id]
|
||||
additional_backwards_reachable_positions[additional_event_id] = (
|
||||
chain_id,
|
||||
seq_num,
|
||||
)
|
||||
|
||||
# Corresponds to `state_sets`, except as a map from chain ID to max
|
||||
# sequence number reachable from the state set.
|
||||
set_to_chain: List[Dict[int, int]] = []
|
||||
@@ -596,6 +698,8 @@ class EventFederationWorkerStore(
|
||||
|
||||
# (We need to take a copy of `seen_chains` as the function mutates it)
|
||||
for links in self._get_chain_links(txn, set(seen_chains)):
|
||||
# `links` encodes the backwards reachable events _from a single chain_ all the way to
|
||||
# the root of the graph.
|
||||
for chains in set_to_chain:
|
||||
for chain_id in links:
|
||||
if chain_id not in chains:
|
||||
@@ -604,6 +708,87 @@ class EventFederationWorkerStore(
|
||||
_materialize(chain_id, chains[chain_id], links, chains)
|
||||
|
||||
seen_chains.update(chains)
|
||||
if is_state_res_v21:
|
||||
# Apply v2.1 conflicted event reachability checks.
|
||||
#
|
||||
# A <-- B <-- C <-- D <-- E
|
||||
#
|
||||
# Backwards reachable from C = {A,B}
|
||||
# Forwards reachable from C = {D,E}
|
||||
|
||||
# this handles calculating forwards reachable information and updates
|
||||
# conflicted_forwards_reachable.
|
||||
accumulate_forwards_reachable_events(
|
||||
conflicted_forwards_reachable,
|
||||
links,
|
||||
conflicted_chain_positions,
|
||||
)
|
||||
|
||||
# handle backwards reachable information
|
||||
for (
|
||||
conflicted_chain_id,
|
||||
conflicted_chain_seq,
|
||||
) in conflicted_chain_positions.values():
|
||||
if conflicted_chain_id not in links:
|
||||
# This conflicted event does not lie on the path to the root.
|
||||
continue
|
||||
|
||||
# The conflicted chain position itself encodes reachability information
|
||||
# _within_ the chain. Set it now before walking to other links.
|
||||
conflicted_backwards_reachable[conflicted_chain_id] = max(
|
||||
conflicted_chain_seq,
|
||||
conflicted_backwards_reachable.get(conflicted_chain_id, 0),
|
||||
)
|
||||
|
||||
# Build backwards reachability paths. This is the same as what the auth difference
|
||||
# code does. We find which chain the conflicted event
|
||||
# belongs to then walk it backwards to the root. We store reachability info
|
||||
# for all conflicted events in the same map 'conflicted_backwards_reachable'
|
||||
# as we don't care about the paths themselves.
|
||||
_materialize(
|
||||
conflicted_chain_id,
|
||||
conflicted_chain_seq,
|
||||
links,
|
||||
conflicted_backwards_reachable,
|
||||
)
|
||||
# Mark some extra events as backwards reachable. This is used when we have some
|
||||
# unpersisted events and want to know the subgraph across the persisted/unpersisted
|
||||
# boundary:
|
||||
# |
|
||||
# A <-- B <-- C <-|- D <-- E <-- F
|
||||
# persisted | unpersisted
|
||||
#
|
||||
# Assume {B,E} are conflicted, we want to return {B,C,D,E}
|
||||
#
|
||||
# The unpersisted code ensures it passes C as an additional backwards reachable
|
||||
# event. C is NOT a conflicted event, but we do need to consider it as part of
|
||||
# the backwards reachable set. When we then calculate the forwards reachable set
|
||||
# from B, C will be in both the backwards and forwards reachable sets and hence
|
||||
# will be included in the conflicted subgraph.
|
||||
for (
|
||||
additional_chain_id,
|
||||
additional_chain_seq,
|
||||
) in additional_backwards_reachable_positions.values():
|
||||
if additional_chain_id not in links:
|
||||
# The additional backwards reachable event does not lie on the path to the root.
|
||||
continue
|
||||
|
||||
# the additional event chain position itself encodes reachability information.
|
||||
# It means that position and all positions earlier in that chain are backwards reachable
|
||||
# by some unpersisted conflicted event.
|
||||
conflicted_backwards_reachable[additional_chain_id] = max(
|
||||
additional_chain_seq,
|
||||
conflicted_backwards_reachable.get(additional_chain_id, 0),
|
||||
)
|
||||
|
||||
# Now walk the chains back, marking backwards reachable events.
|
||||
# This is the same thing we do for auth difference / conflicted events.
|
||||
_materialize(
|
||||
additional_chain_id, # walk all links back, marking them as backwards reachable
|
||||
additional_chain_seq,
|
||||
links,
|
||||
conflicted_backwards_reachable,
|
||||
)
|
||||
|
||||
# Now for each chain we figure out the maximum sequence number reachable
|
||||
# from *any* state set and the minimum sequence number reachable from
|
||||
@@ -612,7 +797,7 @@ class EventFederationWorkerStore(
|
||||
|
||||
# Mapping from chain ID to the range of sequence numbers that should be
|
||||
# pulled from the database.
|
||||
chain_to_gap: Dict[int, Tuple[int, int]] = {}
|
||||
auth_diff_chain_to_gap: Dict[int, Tuple[int, int]] = {}
|
||||
|
||||
for chain_id in seen_chains:
|
||||
min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain)
|
||||
@@ -625,15 +810,76 @@ class EventFederationWorkerStore(
|
||||
for seq_no in range(min_seq_no + 1, max_seq_no + 1):
|
||||
event_id = chain_to_event.get(chain_id, {}).get(seq_no)
|
||||
if event_id:
|
||||
result.add(event_id)
|
||||
auth_difference_result.add(event_id)
|
||||
else:
|
||||
chain_to_gap[chain_id] = (min_seq_no, max_seq_no)
|
||||
auth_diff_chain_to_gap[chain_id] = (min_seq_no, max_seq_no)
|
||||
break
|
||||
|
||||
if not chain_to_gap:
|
||||
# If there are no gaps to fetch, we're done!
|
||||
return result
|
||||
conflicted_subgraph_result: Set[str] = set()
|
||||
# Mapping from chain ID to the range of sequence numbers that should be
|
||||
# pulled from the database.
|
||||
conflicted_subgraph_chain_to_gap: Dict[int, Tuple[int, int]] = {}
|
||||
if is_state_res_v21:
|
||||
# also include the conflicted subgraph using backward/forward reachability info from all
|
||||
# the conflicted events. To calculate this, we want to extract the intersection between
|
||||
# the backwards and forwards reachability sets, e.g:
|
||||
# A <- B <- C <- D <- E
|
||||
# Assume B and D are conflicted so we want {C} as the conflicted subgraph.
|
||||
# B_backwards={A}, B_forwards={C,D,E}
|
||||
# D_backwards={A,B,C} D_forwards={E}
|
||||
# ALL_backwards={A,B,C} ALL_forwards={C,D,E}
|
||||
# Intersection(ALL_backwards, ALL_forwards) = {C}
|
||||
#
|
||||
# It's worth noting that once we have the ALL_ sets, we no longer care about the paths.
|
||||
# We're dealing with chains and not singular events, but we've already got the ALL_ sets.
|
||||
# As such, we can inspect each chain in isolation and check for overlapping sequence
|
||||
# numbers:
|
||||
# 1,2,3,4,5 Seq Num
|
||||
# Chain N [A,B,C,D,E]
|
||||
#
|
||||
# if (N,4) is in the backwards set and (N,2) is in the forwards set, then the
|
||||
# intersection is the events at positions 2 through 4. We also want the conflicted events themselves
|
||||
# in the subgraph, but they are already included, hence the full set of events is {B,C,D}.
|
||||
for chain_id, backwards_seq_num in conflicted_backwards_reachable.items():
|
||||
forwards_seq_num = conflicted_forwards_reachable.get(chain_id)
|
||||
if forwards_seq_num is None:
|
||||
continue # this chain isn't in both sets so can't intersect
|
||||
if forwards_seq_num > backwards_seq_num:
|
||||
continue # this chain is in both sets but they don't overlap
|
||||
for seq_no in range(
|
||||
forwards_seq_num, backwards_seq_num + 1
|
||||
): # inclusive of both
|
||||
event_id = chain_to_event.get(chain_id, {}).get(seq_no)
|
||||
if event_id:
|
||||
conflicted_subgraph_result.add(event_id)
|
||||
else:
|
||||
conflicted_subgraph_chain_to_gap[chain_id] = (
|
||||
# _fetch_event_ids_from_chains_txn is exclusive of the min value
|
||||
forwards_seq_num - 1,
|
||||
backwards_seq_num,
|
||||
)
|
||||
break
|
||||
|
||||
if auth_diff_chain_to_gap:
|
||||
auth_difference_result.update(
|
||||
self._fetch_event_ids_from_chains_txn(txn, auth_diff_chain_to_gap)
|
||||
)
|
||||
if conflicted_subgraph_chain_to_gap:
|
||||
conflicted_subgraph_result.update(
|
||||
self._fetch_event_ids_from_chains_txn(
|
||||
txn, conflicted_subgraph_chain_to_gap
|
||||
)
|
||||
)
|
||||
|
||||
return StateDifference(
|
||||
auth_difference=auth_difference_result,
|
||||
conflicted_subgraph=conflicted_subgraph_result,
|
||||
)
|
||||
|
||||
def _fetch_event_ids_from_chains_txn(
|
||||
self, txn: LoggingTransaction, chains: Dict[int, Tuple[int, int]]
|
||||
) -> Set[str]:
|
||||
result: Set[str] = set()
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
# We can use `execute_values` to efficiently fetch the gaps when
|
||||
# using postgres.
|
||||
@@ -647,7 +893,7 @@ class EventFederationWorkerStore(
|
||||
|
||||
args = [
|
||||
(chain_id, min_no, max_no)
|
||||
for chain_id, (min_no, max_no) in chain_to_gap.items()
|
||||
for chain_id, (min_no, max_no) in chains.items()
|
||||
]
|
||||
|
||||
rows = txn.execute_values(sql, args)
|
||||
@@ -658,10 +904,9 @@ class EventFederationWorkerStore(
|
||||
SELECT event_id FROM event_auth_chains
|
||||
WHERE chain_id = ? AND ? < sequence_number AND sequence_number <= ?
|
||||
"""
|
||||
for chain_id, (min_no, max_no) in chain_to_gap.items():
|
||||
for chain_id, (min_no, max_no) in chains.items():
|
||||
txn.execute(sql, (chain_id, min_no, max_no))
|
||||
result.update(r for (r,) in txn)
|
||||
|
||||
return result
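A runnable toy version (illustrative chain IDs and sequence numbers only) of the per-chain overlap check described in the comments further up: on each chain, positions no greater than the backwards bound and no smaller than the forwards bound fall inside the conflicted subgraph.

from typing import Dict, Set, Tuple

# chain_id -> max seq number backwards reachable from some conflicted event
conflicted_backwards_reachable: Dict[int, int] = {7: 4}
# chain_id -> min seq number forwards reachable from some conflicted event
conflicted_forwards_reachable: Dict[int, int] = {7: 2}

subgraph_positions: Set[Tuple[int, int]] = set()
for chain_id, back_max in conflicted_backwards_reachable.items():
    fwd_min = conflicted_forwards_reachable.get(chain_id)
    if fwd_min is None or fwd_min > back_max:
        continue  # this chain is not in both sets, or the ranges don't overlap
    for seq in range(fwd_min, back_max + 1):
        subgraph_positions.add((chain_id, seq))

# Matches the {B,C,D} example above: positions 2, 3 and 4 overlap on that chain.
assert subgraph_positions == {(7, 2), (7, 3), (7, 4)}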
|
||||
|
||||
def _fixup_auth_chain_difference_sets(
|
||||
@@ -2155,6 +2400,7 @@ def _materialize(
|
||||
origin_sequence_number: int,
|
||||
links: Dict[int, List[Tuple[int, int, int]]],
|
||||
materialized: Dict[int, int],
|
||||
backwards: bool = True,
|
||||
) -> None:
|
||||
"""Helper function for fetching auth chain links. For a given origin chain
|
||||
ID / sequence number and a dictionary of links, updates the materialized
|
||||
@@ -2171,6 +2417,7 @@ def _materialize(
|
||||
target sequence number.
|
||||
materialized: dict to update with new reachability information, as a
|
||||
map from chain ID to max sequence number reachable.
|
||||
backwards: If True, walks backwards down the chains. If False, walks forwards from the chains.
|
||||
"""
|
||||
|
||||
# Do a standard graph traversal.
|
||||
@@ -2185,12 +2432,104 @@ def _materialize(
|
||||
target_chain_id,
|
||||
target_sequence_number,
|
||||
) in chain_links:
|
||||
# Ignore any links that are higher up the chain
|
||||
if sequence_number > s:
|
||||
continue
|
||||
if backwards:
|
||||
# Ignore any links that are higher up the chain
|
||||
if sequence_number > s:
|
||||
continue
|
||||
|
||||
# Check if we have already visited the target chain before, if so we
|
||||
# can skip it.
|
||||
if materialized.get(target_chain_id, 0) < target_sequence_number:
|
||||
stack.append((target_chain_id, target_sequence_number))
|
||||
materialized[target_chain_id] = target_sequence_number
|
||||
# Check if we have already visited the target chain before, if so we
|
||||
# can skip it.
|
||||
if materialized.get(target_chain_id, 0) < target_sequence_number:
|
||||
stack.append((target_chain_id, target_sequence_number))
|
||||
materialized[target_chain_id] = target_sequence_number
|
||||
else:
|
||||
# Ignore any links that are lower down the chain.
|
||||
if sequence_number < s:
|
||||
continue
|
||||
# Check if we have already visited the target chain before, if so we
|
||||
# can skip it.
|
||||
if (
|
||||
materialized.get(target_chain_id, MAX_CHAIN_LENGTH)
|
||||
> target_sequence_number
|
||||
):
|
||||
stack.append((target_chain_id, target_sequence_number))
|
||||
materialized[target_chain_id] = target_sequence_number
|
||||
|
||||
|
||||
def _generate_forward_links(
|
||||
links: Dict[int, List[Tuple[int, int, int]]],
|
||||
) -> Dict[int, List[Tuple[int, int, int]]]:
|
||||
"""Reverse the input links from the given backwards links"""
|
||||
new_links: Dict[int, List[Tuple[int, int, int]]] = {}
|
||||
for origin_chain_id, chain_links in links.items():
|
||||
for origin_seq_num, target_chain_id, target_seq_num in chain_links:
|
||||
new_links.setdefault(target_chain_id, []).append(
|
||||
(target_seq_num, origin_chain_id, origin_seq_num)
|
||||
)
|
||||
return new_links
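A quick illustrative check (toy data) of the inversion performed above; a links map takes an origin chain ID to a list of (origin_seq, target_chain_id, target_seq) triples.

from typing import Dict, List, Tuple

back_links: Dict[int, List[Tuple[int, int, int]]] = {
    2: [(1, 1, 3)],  # chain 2, seq 1 links back to chain 1, seq 3
}
fwd_links: Dict[int, List[Tuple[int, int, int]]] = {}
for origin_chain_id, chain_links in back_links.items():
    for origin_seq, target_chain_id, target_seq in chain_links:
        # Reverse the direction of each link, as _generate_forward_links does.
        fwd_links.setdefault(target_chain_id, []).append(
            (target_seq, origin_chain_id, origin_seq)
        )

# chain 1, seq 3 now links forward to chain 2, seq 1.
assert fwd_links == {1: [(3, 2, 1)]}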
|
||||
|
||||
|
||||
def accumulate_forwards_reachable_events(
|
||||
conflicted_forwards_reachable: Dict[int, int],
|
||||
back_links: Dict[int, List[Tuple[int, int, int]]],
|
||||
conflicted_chain_positions: Dict[str, Tuple[int, int]],
|
||||
) -> None:
|
||||
"""Accumulate new forwards reachable events using the back_links provided.
|
||||
|
||||
Accumulating forwards reachable information is quite different from backwards reachable information
|
||||
because _get_chain_links returns the entire linkage information for backwards reachable events,
|
||||
but not _forwards_ reachable events. We are only interested in the forwards reachable information
|
||||
that is encoded in the backwards reachable links, so we can just invert all the operations we do
|
||||
for backwards reachable events to calculate a subset of forwards reachable information. The
|
||||
caveat with this approach is that it is a _subset_. This means new back_links may encode new
|
||||
forwards reachable information which we also need. Consider this scenario:
|
||||
|
||||
A <-- B <-- C <--- D <-- E <-- F Chain 1
|
||||
|
|
||||
`----- G <-- H <-- I Chain 2
|
||||
|
|
||||
`---- J <-- K Chain 3
|
||||
|
||||
Now consider what happens when B is a conflicted event. _get_chain_links returns the conflicted
|
||||
chain and ALL links heading towards the root of the graph. This means we will know the
|
||||
Chain 1 to Chain 2 link via C (as all links for the chain are returned, not strictly ones with
|
||||
a lower sequence number), but we will NOT know the Chain 2 to Chain 3 link via H. We can be
|
||||
blissfully unaware of Chain 3 entirely, if and only if there isn't some other conflicted event
|
||||
on that chain. Consider what happens when K is /also/ conflicted. _get_chain_links will generate
|
||||
two iterations: one for B and one for K. It's important that we re-evaluate the forwards reachable
|
||||
information for B to include Chain 3 when we process the K iteration, hence we are "accumulating"
|
||||
forwards reachability information.
|
||||
|
||||
NB: We don't consider 'additional backwards reachable events' here because they have no effect
|
||||
on forwards reachability calculations, only backwards.
|
||||
|
||||
Args:
|
||||
conflicted_forwards_reachable: The materialised dict of forwards reachable information.
|
||||
The output to this function are stored here.
|
||||
back_links: One iteration of _get_chain_links which encodes backwards reachable information.
|
||||
conflicted_chain_positions: The conflicted events.
|
||||
"""
|
||||
# links go backwards but we want them to go forwards as well for v2.1
|
||||
fwd_links = _generate_forward_links(back_links)
|
||||
|
||||
# for each conflicted event, accumulate forwards reachability information
|
||||
for (
|
||||
conflicted_chain_id,
|
||||
conflicted_chain_seq,
|
||||
) in conflicted_chain_positions.values():
|
||||
# the conflicted event itself encodes reachability information
|
||||
# e.g if D was conflicted, it encodes E,F as forwards reachable.
|
||||
conflicted_forwards_reachable[conflicted_chain_id] = min(
|
||||
conflicted_chain_seq,
|
||||
conflicted_forwards_reachable.get(conflicted_chain_id, MAX_CHAIN_LENGTH),
|
||||
)
|
||||
# Walk from the conflicted event forwards to explore the links.
|
||||
# This function checks if we've visited the chain before and skips reprocessing, so this
|
||||
# does not repeatedly traverse the graph.
|
||||
_materialize(
|
||||
conflicted_chain_id,
|
||||
conflicted_chain_seq,
|
||||
fwd_links,
|
||||
conflicted_forwards_reachable,
|
||||
backwards=False,
|
||||
)
|
||||
|
||||
@@ -354,12 +354,78 @@ class RoomAlias(DomainSpecificString):
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, repr=False)
|
||||
class RoomID(DomainSpecificString):
|
||||
"""Structure representing a room id."""
|
||||
class RoomIdWithDomain(DomainSpecificString):
|
||||
"""Structure representing a room ID with a domain suffix."""
|
||||
|
||||
SIGIL = "!"
|
||||
|
||||
|
||||
# the set of urlsafe base64 characters, no padding.
|
||||
ROOM_ID_PATTERN_DOMAINLESS = re.compile(r"^[A-Za-z0-9\-_]{43}$")
|
||||
|
||||
|
||||
@attr.define(slots=True, frozen=True, repr=False)
|
||||
class RoomID:
|
||||
"""Structure representing a room id without a domain.
|
||||
There are two forms of room IDs:
|
||||
- "!localpart:domain" used in most room versions prior to MSC4291.
|
||||
- "!event_id_base_64" used in room versions post MSC4291.
|
||||
This class will accept any room ID which meets either of these two criteria.
|
||||
"""
|
||||
|
||||
SIGIL = "!"
|
||||
id: str
|
||||
room_id_with_domain: Optional[RoomIdWithDomain]
|
||||
|
||||
@classmethod
|
||||
def is_valid(cls: Type["RoomID"], s: str) -> bool:
|
||||
if ":" in s:
|
||||
return RoomIdWithDomain.is_valid(s)
|
||||
try:
|
||||
cls.from_string(s)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def get_domain(self) -> Optional[str]:
|
||||
if not self.room_id_with_domain:
|
||||
return None
|
||||
return self.room_id_with_domain.domain
|
||||
|
||||
def to_string(self) -> str:
|
||||
if self.room_id_with_domain:
|
||||
return self.room_id_with_domain.to_string()
|
||||
return self.id
|
||||
|
||||
__repr__ = to_string
|
||||
|
||||
@classmethod
|
||||
def from_string(cls: Type["RoomID"], s: str) -> "RoomID":
|
||||
# sigil check
|
||||
if len(s) < 1 or s[0] != cls.SIGIL:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Expected %s string to start with '%s'" % (cls.__name__, cls.SIGIL),
|
||||
Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
room_id_with_domain: Optional[RoomIdWithDomain] = None
|
||||
if ":" in s:
|
||||
room_id_with_domain = RoomIdWithDomain.from_string(s)
|
||||
else:
|
||||
# MSC4291 room IDs must be valid urlsafe unpadded base64
|
||||
val = s[1:]
|
||||
if not ROOM_ID_PATTERN_DOMAINLESS.match(val):
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Expected %s string to be valid urlsafe unpadded base64 '%s'"
|
||||
% (cls.__name__, val),
|
||||
Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
return cls(id=s, room_id_with_domain=room_id_with_domain)
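A usage sketch of the two accepted forms (hypothetical identifiers; the 43-character value is just a placeholder of the right shape, not a real event hash), assuming RoomID is imported from synapse.types:

# Classic form: "!localpart:domain" keeps its server name.
legacy = RoomID.from_string("!abcdef:example.org")
assert legacy.get_domain() == "example.org"

# MSC4291 form: "!" followed by 43 characters of unpadded urlsafe base64, no domain.
hydra = RoomID.from_string("!" + "A" * 43)
assert hydra.get_domain() is None
assert hydra.to_string() == "!" + "A" * 43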
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, repr=False)
|
||||
class EventID(DomainSpecificString):
|
||||
"""Structure representing an event ID which is namespaced to a homeserver.
|
||||
|
||||
@@ -33,6 +33,7 @@ from twisted.test.proto_helpers import MemoryReactor
|
||||
import synapse.rest.admin
|
||||
from synapse.api.constants import EventTypes, Membership, RoomTypes
|
||||
from synapse.api.errors import Codes
|
||||
from synapse.api.room_versions import RoomVersions
|
||||
from synapse.handlers.pagination import (
|
||||
PURGE_ROOM_ACTION_NAME,
|
||||
SHUTDOWN_AND_PURGE_ROOM_ACTION_NAME,
|
||||
@@ -2892,6 +2893,30 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase):
|
||||
"No local admin user in room with power to update power levels.",
|
||||
)
|
||||
|
||||
def test_v12_room(self) -> None:
|
||||
"""Test that you can be promoted to admin in v12 rooms which won't have the admin the PL event."""
|
||||
room_id = self.helper.create_room_as(
|
||||
self.creator,
|
||||
tok=self.creator_tok,
|
||||
room_version=RoomVersions.V12.identifier,
|
||||
)
|
||||
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
f"/_synapse/admin/v1/rooms/{room_id}/make_room_admin",
|
||||
content={},
|
||||
access_token=self.admin_user_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(200, channel.code, msg=channel.json_body)
|
||||
|
||||
# Now we test that we can join the room and that the admin user has PL 100.
|
||||
self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok)
|
||||
pl = self.helper.get_state(
|
||||
room_id, EventTypes.PowerLevels, tok=self.creator_tok
|
||||
)
|
||||
self.assertEquals(pl["users"][self.admin_user], 100)
|
||||
|
||||
|
||||
class BlockRoomTestCase(unittest.HomeserverTestCase):
|
||||
servlets = [
|
||||
|
||||
@@ -45,6 +45,7 @@ from synapse.state.v2 import (
|
||||
lexicographical_topological_sort,
|
||||
resolve_events_with_store,
|
||||
)
|
||||
from synapse.storage.databases.main.event_federation import StateDifference
|
||||
from synapse.types import EventID, StateMap
|
||||
|
||||
from tests import unittest
|
||||
@@ -734,7 +735,11 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
|
||||
store = TestStateResolutionStore(persisted_events)
|
||||
|
||||
diff_d = _get_auth_chain_difference(
|
||||
ROOM_ID, state_sets, unpersited_events, store
|
||||
ROOM_ID,
|
||||
state_sets,
|
||||
unpersited_events,
|
||||
store,
|
||||
None,
|
||||
)
|
||||
difference = self.successResultOf(defer.ensureDeferred(diff_d))
|
||||
|
||||
@@ -791,7 +796,11 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
|
||||
store = TestStateResolutionStore(persisted_events)
|
||||
|
||||
diff_d = _get_auth_chain_difference(
|
||||
ROOM_ID, state_sets, unpersited_events, store
|
||||
ROOM_ID,
|
||||
state_sets,
|
||||
unpersited_events,
|
||||
store,
|
||||
None,
|
||||
)
|
||||
difference = self.successResultOf(defer.ensureDeferred(diff_d))
|
||||
|
||||
@@ -858,7 +867,11 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
|
||||
store = TestStateResolutionStore(persisted_events)
|
||||
|
||||
diff_d = _get_auth_chain_difference(
|
||||
ROOM_ID, state_sets, unpersited_events, store
|
||||
ROOM_ID,
|
||||
state_sets,
|
||||
unpersited_events,
|
||||
store,
|
||||
None,
|
||||
)
|
||||
difference = self.successResultOf(defer.ensureDeferred(diff_d))
|
||||
|
||||
@@ -1070,9 +1083,18 @@ class TestStateResolutionStore:
|
||||
return list(result)
|
||||
|
||||
def get_auth_chain_difference(
|
||||
self, room_id: str, auth_sets: List[Set[str]]
|
||||
) -> "defer.Deferred[Set[str]]":
|
||||
self,
|
||||
room_id: str,
|
||||
auth_sets: List[Set[str]],
|
||||
conflicted_state: Optional[Set[str]],
|
||||
additional_backwards_reachable_conflicted_events: Optional[Set[str]],
|
||||
) -> "defer.Deferred[StateDifference]":
|
||||
chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
|
||||
|
||||
common = set(chains[0]).intersection(*chains[1:])
|
||||
return defer.succeed(set(chains[0]).union(*chains[1:]) - common)
|
||||
return defer.succeed(
|
||||
StateDifference(
|
||||
auth_difference=set(chains[0]).union(*chains[1:]) - common,
|
||||
conflicted_subgraph=set(),
|
||||
),
|
||||
)
|
||||
|
||||
503
tests/state/test_v21.py
Normal file
@@ -0,0 +1,503 @@
|
||||
#
|
||||
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
#
|
||||
# Copyright (C) 2025 New Vector, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# See the GNU Affero General Public License for more details:
|
||||
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
#
|
||||
# Originally licensed under the Apache License, Version 2.0:
|
||||
# <http://www.apache.org/licenses/LICENSE-2.0>.
|
||||
#
|
||||
# [This file includes modifications made by New Vector Limited]
|
||||
#
|
||||
#
|
||||
import itertools
|
||||
from typing import Dict, List, Optional, Sequence, Set
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.constants import EventTypes, JoinRules, Membership
|
||||
from synapse.api.room_versions import RoomVersions
|
||||
from synapse.events import EventBase
|
||||
from synapse.federation.federation_base import event_from_pdu_json
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client import login, room
|
||||
from synapse.server import HomeServer
|
||||
from synapse.state import StateResolutionStore
|
||||
from synapse.state.v2 import (
|
||||
StateResolutionStore as StateResolutionStoreInterface,
|
||||
_get_auth_chain_difference,
|
||||
_seperate,
|
||||
resolve_events_with_store,
|
||||
)
|
||||
from synapse.types import StateMap
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.state.test_v2 import TestStateResolutionStore
|
||||
|
||||
ALICE = "@alice:example.com"
|
||||
BOB = "@bob:example.com"
|
||||
CHARLIE = "@charlie:example.com"
|
||||
EVELYN = "@evelyn:example.com"
|
||||
ZARA = "@zara:example.com"
|
||||
|
||||
ROOM_ID = "!test:example.com"
|
||||
|
||||
MEMBERSHIP_CONTENT_JOIN = {"membership": Membership.JOIN}
|
||||
MEMBERSHIP_CONTENT_INVITE = {"membership": Membership.INVITE}
|
||||
MEMBERSHIP_CONTENT_LEAVE = {"membership": Membership.LEAVE}
|
||||
|
||||
|
||||
ORIGIN_SERVER_TS = 0
|
||||
|
||||
|
||||
def monotonic_timestamp() -> int:
|
||||
global ORIGIN_SERVER_TS
|
||||
ORIGIN_SERVER_TS += 1
|
||||
return ORIGIN_SERVER_TS
|
||||
|
||||
|
||||
class FakeClock:
|
||||
def sleep(self, msec: float) -> "defer.Deferred[None]":
|
||||
return defer.succeed(None)
|
||||
|
||||
|
||||
class StateResV21TestCase(unittest.HomeserverTestCase):
|
||||
servlets = [
|
||||
admin.register_servlets,
|
||||
room.register_servlets,
|
||||
login.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(
|
||||
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
|
||||
) -> None:
|
||||
self.state = self.hs.get_state_handler()
|
||||
persistence = self.hs.get_storage_controllers().persistence
|
||||
assert persistence is not None
|
||||
self._persistence = persistence
|
||||
self._state_storage_controller = self.hs.get_storage_controllers().state
|
||||
self._state_deletion = self.hs.get_datastores().state_deletion
|
||||
self.store = self.hs.get_datastores().main
|
||||
|
||||
self.register_user("user", "pass")
|
||||
self.token = self.login("user", "pass")
|
||||
|
||||
def test_state_reset_replay_conflicted_subgraph(self) -> None:
|
||||
# 1. Alice creates a room.
|
||||
e1_create = self.create_event(
|
||||
EventTypes.Create,
|
||||
"",
|
||||
sender=ALICE,
|
||||
content={"creator": ALICE},
|
||||
auth_events=[],
|
||||
)
|
||||
# 2. Alice joins it.
|
||||
e2_ma = self.create_event(
|
||||
EventTypes.Member,
|
||||
ALICE,
|
||||
sender=ALICE,
|
||||
content=MEMBERSHIP_CONTENT_JOIN,
|
||||
auth_events=[],
|
||||
prev_events=[e1_create.event_id],
|
||||
room_id=e1_create.room_id,
|
||||
)
|
||||
# 3. Alice is the creator
|
||||
e3_power1 = self.create_event(
|
||||
EventTypes.PowerLevels,
|
||||
"",
|
||||
sender=ALICE,
|
||||
content={"users": {}},
|
||||
auth_events=[e2_ma.event_id],
|
||||
room_id=e1_create.room_id,
|
||||
)
|
||||
# 4. Alice sets the room to public.
|
||||
e4_jr = self.create_event(
|
||||
EventTypes.JoinRules,
|
||||
"",
|
||||
sender=ALICE,
|
||||
content={"join_rule": JoinRules.PUBLIC},
|
||||
auth_events=[e2_ma.event_id, e3_power1.event_id],
|
||||
room_id=e1_create.room_id,
|
||||
)
|
||||
# 5. Bob joins the room.
|
||||
e5_mb = self.create_event(
|
||||
EventTypes.Member,
|
||||
BOB,
|
||||
sender=BOB,
|
||||
content=MEMBERSHIP_CONTENT_JOIN,
|
||||
auth_events=[e3_power1.event_id, e4_jr.event_id],
|
||||
room_id=e1_create.room_id,
|
||||
)
|
||||
# 6. Charlie joins the room.
|
||||
e6_mc = self.create_event(
|
||||
EventTypes.Member,
|
||||
CHARLIE,
|
||||
sender=CHARLIE,
|
||||
content=MEMBERSHIP_CONTENT_JOIN,
|
||||
auth_events=[e3_power1.event_id, e4_jr.event_id],
|
||||
room_id=e1_create.room_id,
|
||||
)
|
||||
# 7. Alice promotes Bob.
|
||||
e7_power2 = self.create_event(
|
||||
EventTypes.PowerLevels,
|
||||
"",
|
||||
sender=ALICE,
|
||||
content={"users": {BOB: 50}},
|
||||
auth_events=[e2_ma.event_id, e3_power1.event_id],
|
||||
room_id=e1_create.room_id,
|
||||
)
|
||||
# 8. Bob promotes Charlie.
|
||||
e8_power3 = self.create_event(
|
||||
EventTypes.PowerLevels,
|
||||
"",
|
||||
sender=BOB,
|
||||
content={"users": {BOB: 50, CHARLIE: 50}},
|
||||
auth_events=[e5_mb.event_id, e7_power2.event_id],
|
||||
room_id=e1_create.room_id,
|
||||
)
|
||||
# 9. Eve joins the room.
|
||||
e9_me1 = self.create_event(
|
||||
EventTypes.Member,
|
||||
EVELYN,
|
||||
sender=EVELYN,
|
||||
content=MEMBERSHIP_CONTENT_JOIN,
|
||||
auth_events=[e8_power3.event_id, e4_jr.event_id],
|
||||
room_id=e1_create.room_id,
|
||||
)
|
||||
# 10. Eve changes her name, /!\ but cites old power levels /!\
|
||||
e10_me2 = self.create_event(
|
||||
EventTypes.Member,
|
||||
EVELYN,
|
||||
sender=EVELYN,
|
||||
content=MEMBERSHIP_CONTENT_JOIN,
|
||||
auth_events=[
|
||||
e3_power1.event_id,
|
||||
e4_jr.event_id,
|
||||
e9_me1.event_id,
|
||||
],
|
||||
room_id=e1_create.room_id,
|
||||
)
|
||||
# 11. Zara joins the room, citing the most recent power levels.
|
||||
e11_mz = self.create_event(
|
||||
EventTypes.Member,
|
||||
ZARA,
|
||||
sender=ZARA,
|
||||
content=MEMBERSHIP_CONTENT_JOIN,
|
||||
auth_events=[e8_power3.event_id, e4_jr.event_id],
|
||||
room_id=e1_create.room_id,
|
||||
)
|
||||
|
||||
# Event 10 above is DODGY: it directly cites old auth events, but indirectly
|
||||
# cites new ones. If the state after event 10 contains old power level and old
|
||||
# join events, we are vulnerable to a reset.
|
||||
|
||||
dodgy_state_after_eve_rename: StateMap[str] = {
|
||||
(EventTypes.Create, ""): e1_create.event_id,
|
||||
(EventTypes.Member, ALICE): e2_ma.event_id,
|
||||
(EventTypes.Member, BOB): e5_mb.event_id,
|
||||
(EventTypes.Member, CHARLIE): e6_mc.event_id,
|
||||
(EventTypes.Member, EVELYN): e10_me2.event_id,
|
||||
(EventTypes.PowerLevels, ""): e3_power1.event_id, # old and /!\ DODGY /!\
|
||||
(EventTypes.JoinRules, ""): e4_jr.event_id,
|
||||
}
|
||||
|
||||
sensible_state_after_zara_joins: StateMap[str] = {
|
||||
(EventTypes.Create, ""): e1_create.event_id,
|
||||
(EventTypes.Member, ALICE): e2_ma.event_id,
|
||||
(EventTypes.Member, BOB): e5_mb.event_id,
|
||||
(EventTypes.Member, CHARLIE): e6_mc.event_id,
|
||||
(EventTypes.Member, ZARA): e11_mz.event_id,
|
||||
(EventTypes.PowerLevels, ""): e8_power3.event_id,
|
||||
(EventTypes.JoinRules, ""): e4_jr.event_id,
|
||||
}
|
||||
|
||||
expected: StateMap[str] = {
|
||||
(EventTypes.Create, ""): e1_create.event_id,
|
||||
(EventTypes.Member, ALICE): e2_ma.event_id,
|
||||
(EventTypes.Member, BOB): e5_mb.event_id,
|
||||
(EventTypes.Member, CHARLIE): e6_mc.event_id,
|
||||
# Expect ME2 replayed first: it's in the POWER 1 epoch
|
||||
# Then ME1, in the POWER 3 epoch
|
||||
(EventTypes.Member, EVELYN): e9_me1.event_id,
|
||||
(EventTypes.Member, ZARA): e11_mz.event_id,
|
||||
(EventTypes.PowerLevels, ""): e8_power3.event_id,
|
||||
(EventTypes.JoinRules, ""): e4_jr.event_id,
|
||||
}
|
||||
|
||||
self.get_resolution_and_verify_expected(
|
||||
[dodgy_state_after_eve_rename, sensible_state_after_zara_joins],
|
||||
[
|
||||
e1_create,
|
||||
e2_ma,
|
||||
e3_power1,
|
||||
e4_jr,
|
||||
e5_mb,
|
||||
e6_mc,
|
||||
e7_power2,
|
||||
e8_power3,
|
||||
e9_me1,
|
||||
e10_me2,
|
||||
e11_mz,
|
||||
],
|
||||
expected,
|
||||
)
|
||||
|
||||
    def test_state_reset_start_empty_set(self) -> None:
        # The join rules reset to missing, when:
        # - join rules were in conflict
        # - the membership of those join rules' senders were not in conflict
        # - those memberships are all leaves.

        # 1. Alice creates a room.
        e1_create = self.create_event(
            EventTypes.Create,
            "",
            sender=ALICE,
            content={"creator": ALICE},
            auth_events=[],
        )
        # 2. Alice joins it.
        e2_ma1 = self.create_event(
            EventTypes.Member,
            ALICE,
            sender=ALICE,
            content=MEMBERSHIP_CONTENT_JOIN,
            auth_events=[],
            room_id=e1_create.room_id,
        )
        # 3. Alice makes Bob an admin.
        e3_power = self.create_event(
            EventTypes.PowerLevels,
            "",
            sender=ALICE,
            content={"users": {BOB: 100}},
            auth_events=[e2_ma1.event_id],
            room_id=e1_create.room_id,
        )
        # 4. Alice sets the room to public.
        e4_jr1 = self.create_event(
            EventTypes.JoinRules,
            "",
            sender=ALICE,
            content={"join_rule": JoinRules.PUBLIC},
            auth_events=[e2_ma1.event_id, e3_power.event_id],
            room_id=e1_create.room_id,
        )
        # 5. Bob joins.
        e5_mb = self.create_event(
            EventTypes.Member,
            BOB,
            sender=BOB,
            content=MEMBERSHIP_CONTENT_JOIN,
            auth_events=[e3_power.event_id, e4_jr1.event_id],
            room_id=e1_create.room_id,
        )
        # 6. Alice sets join rules to invite.
        e6_jr2 = self.create_event(
            EventTypes.JoinRules,
            "",
            sender=ALICE,
            content={"join_rule": JoinRules.INVITE},
            auth_events=[e2_ma1.event_id, e3_power.event_id],
            room_id=e1_create.room_id,
        )
        # 7. Alice then leaves.
        e7_ma2 = self.create_event(
            EventTypes.Member,
            ALICE,
            sender=ALICE,
            content=MEMBERSHIP_CONTENT_LEAVE,
            auth_events=[e3_power.event_id, e2_ma1.event_id],
            room_id=e1_create.room_id,
        )

        correct_state: StateMap[str] = {
            (EventTypes.Create, ""): e1_create.event_id,
            (EventTypes.Member, ALICE): e7_ma2.event_id,
            (EventTypes.Member, BOB): e5_mb.event_id,
            (EventTypes.PowerLevels, ""): e3_power.event_id,
            (EventTypes.JoinRules, ""): e6_jr2.event_id,
        }

        # Imagine that another server gives us incorrect state on a fork
        # (via e.g. backfill). It cites the old join rules.
        incorrect_state: StateMap[str] = {
            (EventTypes.Create, ""): e1_create.event_id,
            (EventTypes.Member, ALICE): e7_ma2.event_id,
            (EventTypes.Member, BOB): e5_mb.event_id,
            (EventTypes.PowerLevels, ""): e3_power.event_id,
            (EventTypes.JoinRules, ""): e4_jr1.event_id,
        }

        # Resolving those two should give us the new join rules.
        expected: StateMap[str] = {
            (EventTypes.Create, ""): e1_create.event_id,
            (EventTypes.Member, ALICE): e7_ma2.event_id,
            (EventTypes.Member, BOB): e5_mb.event_id,
            (EventTypes.PowerLevels, ""): e3_power.event_id,
            (EventTypes.JoinRules, ""): e6_jr2.event_id,
        }

        self.get_resolution_and_verify_expected(
            [correct_state, incorrect_state],
            [e1_create, e2_ma1, e3_power, e4_jr1, e5_mb, e6_jr2, e7_ma2],
            expected,
        )

    async def _get_auth_difference_and_conflicted_subgraph(
        self,
        room_id: str,
        state_maps: Sequence[StateMap[str]],
        event_map: Optional[Dict[str, EventBase]],
        state_res_store: StateResolutionStoreInterface,
    ) -> Set[str]:
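        """Test helper: derive the conflicted event set from ``state_maps`` (via
        ``_seperate``) and return the auth chain difference computed with that
        conflicted set, so that the v2.1 conflicted-subgraph handling is exercised.
        """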
        _, conflicted_state = _seperate(state_maps)
        conflicted_set: Optional[Set[str]] = set(
            itertools.chain.from_iterable(conflicted_state.values())
        )
        if event_map is None:
            event_map = {}
        return await _get_auth_chain_difference(
            room_id,
            state_maps,
            event_map,
            state_res_store,
            conflicted_set,
        )

    def get_resolution_and_verify_expected(
        self,
        state_maps: Sequence[StateMap[str]],
        events: List[EventBase],
        expected: StateMap[str],
    ) -> None:
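        """Resolve ``state_maps`` over ``events`` and assert the result is ``expected``.

        The resolution is run several ways: first entirely in-memory, then
        repeatedly against the database as the events are persisted one at a
        time, checking that the resolved state and the auth chain difference
        stay the same throughout.
        """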
        room_id = events[0].room_id
        # First we try everything in-memory to check that the test case works.
        event_map = {ev.event_id: ev for ev in events}
        resolution = self.successResultOf(
            resolve_events_with_store(
                FakeClock(),
                room_id,
                events[0].room_version,
                state_maps,
                event_map=event_map,
                state_res_store=TestStateResolutionStore(event_map),
            )
        )
        self.assertEqual(resolution, expected)

        got_auth_diff = self.successResultOf(
            self._get_auth_difference_and_conflicted_subgraph(
                room_id,
                state_maps,
                event_map,
                TestStateResolutionStore(event_map),
            )
        )
        # we should never see the create event in the auth diff. If we do, this implies the
        # conflicted subgraph is wrong and is returning too many old events.
        assert events[0].event_id not in got_auth_diff

        # now let's make the room exist in the DB, as some queries rely on there being a
        # row in the rooms table when persisting
        self.get_success(
            self.store.store_room(
                room_id,
                events[0].sender,
                True,
                events[0].room_version,
            )
        )

        def resolve_and_check() -> None:
            event_map = {ev.event_id: ev for ev in events}
            store = StateResolutionStore(
                self._persistence.main_store,
                self._state_deletion,
            )
            resolution = self.get_success(
                resolve_events_with_store(
                    FakeClock(),
                    room_id,
                    RoomVersions.HydraV11,
                    state_maps,
                    event_map=event_map,
                    state_res_store=store,
                )
            )
            self.assertEqual(resolution, expected)
            got_auth_diff2 = self.get_success(
                self._get_auth_difference_and_conflicted_subgraph(
                    room_id,
                    state_maps,
                    event_map,
                    store,
                )
            )
            # no matter how many events are persisted, the overall diff should always be the same.
            self.assertEquals(got_auth_diff, got_auth_diff2)

        # now we will drip feed in `events` one-by-one, persisting them then resolving with the
        # rest. This ensures we correctly handle mixed persisted/unpersisted events. We will finish
        # with doing the test with all persisted events.
        while len(events) > 0:
            event_to_persist = events.pop(0)
            self.persist_event(event_to_persist)
            # now retest
            resolve_and_check()

    def persist_event(
        self, event: EventBase, state: Optional[StateMap[str]] = None
    ) -> None:
        """Persist the event, with optional state"""
        context = self.get_success(
            self.state.compute_event_context(
                event,
                state_ids_before_event=state,
                partial_state=None if state is None else False,
            )
        )
        self.get_success(self._persistence.persist_event(event, context))

    def create_event(
        self,
        event_type: str,
        state_key: Optional[str],
        sender: str,
        content: Dict,
        auth_events: List[str],
        prev_events: Optional[List[str]] = None,
        room_id: Optional[str] = None,
    ) -> EventBase:
        """Short-hand for event_from_pdu_json for fields we typically care about.
        Tests can override by just calling event_from_pdu_json directly."""
        if prev_events is None:
            prev_events = []

        pdu = {
            "type": event_type,
            "state_key": state_key,
            "content": content,
            "sender": sender,
            "depth": 5,
            "prev_events": prev_events,
            "auth_events": auth_events,
            "origin_server_ts": monotonic_timestamp(),
        }
        if event_type != EventTypes.Create:
            if room_id is None:
                raise Exception("must specify a room_id to create_event")
            pdu["room_id"] = room_id
        return event_from_pdu_json(
            pdu,
            RoomVersions.HydraV11,
        )
@@ -26,6 +26,7 @@ from typing import (
    Iterable,
    List,
    Mapping,
    NamedTuple,
    Set,
    Tuple,
    TypeVar,
@@ -730,6 +731,202 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
        )
        self.assertSetEqual(difference, set())

    def test_conflicted_subgraph(self) -> None:
        """Test that the conflicted subgraph code in state res v2.1 can walk arbitrary length links
        and chains.

        We construct a chain cover index like this:

            A1 <- A2 <- A3
                  ^--- B1 <- B2 <- B3
                             ^--- C1 <- C2 <- C3   v--- G1 <- G2
                                  ^--- D1 <- D2 <- D3
                                             ^--- E1
                                                  ^--- F1 <- F2

        ..and then pick various events to be conflicted / additional backwards reachable to assert
        that the code walks the chains correctly. We're particularly interested in ensuring that
        the code walks multiple links between chains, hence why we have so many chains.
        """

        class TestNode(NamedTuple):
            event_id: str
            chain_id: int
            seq_num: int

        class TestLink(NamedTuple):
            origin_chain_and_seq: Tuple[int, int]
            target_chain_and_seq: Tuple[int, int]

        # Map to chain IDs / seq nums
        nodes: List[TestNode] = [
            TestNode("A1", 1, 1),
            TestNode("A2", 1, 2),
            TestNode("A3", 1, 3),
            TestNode("B1", 2, 1),
            TestNode("B2", 2, 2),
            TestNode("B3", 2, 3),
            TestNode("C1", 3, 1),
            TestNode("C2", 3, 2),
            TestNode("C3", 3, 3),
            TestNode("D1", 4, 1),
            TestNode("D2", 4, 2),
            TestNode("D3", 4, 3),
            TestNode("E1", 5, 1),
            TestNode("F1", 6, 1),
            TestNode("F2", 6, 2),
            TestNode("G1", 7, 1),
            TestNode("G2", 7, 2),
        ]
        links: List[TestLink] = [
            TestLink((2, 1), (1, 2)),  # B1 -> A2
            TestLink((3, 1), (2, 2)),  # C1 -> B2
            TestLink((4, 1), (3, 1)),  # D1 -> C1
            TestLink((5, 1), (4, 2)),  # E1 -> D2
            TestLink((6, 1), (5, 1)),  # F1 -> E1
            TestLink((7, 1), (4, 3)),  # G1 -> D3
        ]

        # populate the chain cover index tables as that's all we need
        for node in nodes:
            self.get_success(
                self.store.db_pool.simple_insert(
                    "event_auth_chains",
                    {
                        "event_id": node.event_id,
                        "chain_id": node.chain_id,
                        "sequence_number": node.seq_num,
                    },
                    desc="insert",
                )
            )
        for link in links:
            self.get_success(
                self.store.db_pool.simple_insert(
                    "event_auth_chain_links",
                    {
                        "origin_chain_id": link.origin_chain_and_seq[0],
                        "origin_sequence_number": link.origin_chain_and_seq[1],
                        "target_chain_id": link.target_chain_and_seq[0],
                        "target_sequence_number": link.target_chain_and_seq[1],
                    },
                    desc="insert",
                )
            )

        # Define the test cases
        class TestCase(NamedTuple):
            name: str
            conflicted: Set[str]
            additional_backwards_reachable: Set[str]
            want_conflicted_subgraph: Set[str]

        # Reminder:
        #     A1 <- A2 <- A3
        #           ^--- B1 <- B2 <- B3
        #                      ^--- C1 <- C2 <- C3   v--- G1 <- G2
        #                           ^--- D1 <- D2 <- D3
        #                                      ^--- E1
        #                                           ^--- F1 <- F2
        test_cases = [
            TestCase(
                name="basic_single_chain",
                conflicted={"B1", "B3"},
                additional_backwards_reachable=set(),
                want_conflicted_subgraph={"B1", "B2", "B3"},
            ),
            TestCase(
                name="basic_single_link",
                conflicted={"A1", "B2"},
                additional_backwards_reachable=set(),
                want_conflicted_subgraph={"A1", "A2", "B1", "B2"},
            ),
            TestCase(
                name="basic_multi_link",
                conflicted={"B1", "F1"},
                additional_backwards_reachable=set(),
                want_conflicted_subgraph={"B1", "B2", "C1", "D1", "D2", "E1", "F1"},
            ),
            # Repeat these tests but put the later event as an additional backwards reachable event.
            # The output should be the same.
            TestCase(
                name="basic_single_chain_as_additional",
                conflicted={"B1"},
                additional_backwards_reachable={"B3"},
                want_conflicted_subgraph={"B1", "B2", "B3"},
            ),
            TestCase(
                name="basic_single_link_as_additional",
                conflicted={"A1"},
                additional_backwards_reachable={"B2"},
                want_conflicted_subgraph={"A1", "A2", "B1", "B2"},
            ),
            TestCase(
                name="basic_multi_link_as_additional",
                conflicted={"B1"},
                additional_backwards_reachable={"F1"},
                want_conflicted_subgraph={"B1", "B2", "C1", "D1", "D2", "E1", "F1"},
            ),
            TestCase(
                name="mixed_multi_link",
                conflicted={"D1", "F1"},
                additional_backwards_reachable={"G1"},
                want_conflicted_subgraph={"D1", "D2", "D3", "E1", "F1", "G1"},
            ),
            TestCase(
                name="additional_backwards_doesnt_add_forwards_info",
                conflicted={"C1", "C3"},
                # This is on the path to the root but Chain C isn't forwards reachable so this doesn't
                # change anything.
                additional_backwards_reachable={"B1"},
                want_conflicted_subgraph={"C1", "C2", "C3"},
            ),
            TestCase(
                name="empty_subgraph",
                conflicted={
                    "B3",
                    "C3",
                },  # these can't reach each other going forwards, so the subgraph is empty
                additional_backwards_reachable=set(),
                want_conflicted_subgraph={"B3", "C3"},
            ),
            TestCase(
                name="empty_subgraph_with_additional",
                conflicted={"C1"},
                # This is on the path to the root but Chain C isn't forwards reachable so this doesn't
                # change anything.
                additional_backwards_reachable={"B1"},
                want_conflicted_subgraph={"C1"},
            ),
            TestCase(
                name="empty_subgraph_single_conflict",
                conflicted={"C1"},  # no subgraph can form as you need 2+
                additional_backwards_reachable=set(),
                want_conflicted_subgraph={"C1"},
            ),
        ]

        def run_test(txn: LoggingTransaction, test_case: TestCase) -> None:
            result = self.store._get_auth_chain_difference_using_cover_index_txn(
                txn,
                "!not_relevant",
                [test_case.conflicted.union(test_case.additional_backwards_reachable)],
                test_case.conflicted,
                test_case.additional_backwards_reachable,
            )
            self.assertEquals(
                result.conflicted_subgraph,
                test_case.want_conflicted_subgraph,
                f"{test_case.name} : conflicted subgraph mismatch",
            )

        for test_case in test_cases:
            self.get_success(
                self.store.db_pool.runInteraction(
                    f"test_case_{test_case.name}", run_test, test_case
                )
            )

    @parameterized.expand(
        [(room_version,) for room_version in KNOWN_ROOM_VERSIONS.values()]
    )

@@ -66,6 +66,10 @@ class ExtremPruneTestCase(HomeserverTestCase):
        body = self.helper.send(self.room_id, body="Test", tok=self.token)
        local_message_event_id = body["event_id"]

        current_state = self.get_success(
            self._state_storage_controller.get_current_state_ids(self.room_id)
        )

        # Fudge a remote event and persist it. This will be the extremity before
        # the gap.
        self.remote_event_1 = event_from_pdu_json(
@@ -77,7 +81,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
                "sender": "@user:other",
                "depth": 5,
                "prev_events": [local_message_event_id],
                "auth_events": [],
                "auth_events": [
                    current_state.get((EventTypes.Create, "")),
                    current_state.get((EventTypes.PowerLevels, "")),
                    current_state.get((EventTypes.JoinRules, "")),
                ],
                "origin_server_ts": self.clock.time_msec(),
            },
            RoomVersions.V6,
@@ -113,6 +121,10 @@ class ExtremPruneTestCase(HomeserverTestCase):
        the same domain.
        """

        state_before_gap = self.get_success(
            self._state_storage_controller.get_current_state_ids(self.room_id)
        )

        # Fudge a second event which points to an event we don't have. This is a
        # state event so that the state changes (otherwise we won't prune the
        # extremity as they'll have the same state group).
@@ -125,16 +137,16 @@ class ExtremPruneTestCase(HomeserverTestCase):
                "sender": "@user:other",
                "depth": 50,
                "prev_events": ["$some_unknown_message"],
                "auth_events": [],
                "auth_events": [
                    state_before_gap.get((EventTypes.Create, "")),
                    state_before_gap.get((EventTypes.PowerLevels, "")),
                    state_before_gap.get((EventTypes.JoinRules, "")),
                ],
                "origin_server_ts": self.clock.time_msec(),
            },
            RoomVersions.V6,
        )

        state_before_gap = self.get_success(
            self._state_storage_controller.get_current_state_ids(self.room_id)
        )

        self.persist_event(remote_event_2, state=state_before_gap)

        # Check the new extremity is just the new remote event.
@@ -145,6 +157,15 @@ class ExtremPruneTestCase(HomeserverTestCase):
        state is different.
        """

        # Now we persist it with state with a dropped history visibility
        # setting. The state resolution across the old and new event will then
        # include it, and so the resolved state won't match the new state.
        state_before_gap = dict(
            self.get_success(
                self._state_storage_controller.get_current_state_ids(self.room_id)
            )
        )

        # Fudge a second event which points to an event we don't have.
        remote_event_2 = event_from_pdu_json(
            {
@@ -155,20 +176,15 @@ class ExtremPruneTestCase(HomeserverTestCase):
                "sender": "@user:other",
                "depth": 10,
                "prev_events": ["$some_unknown_message"],
                "auth_events": [],
                "auth_events": [
                    state_before_gap.get((EventTypes.Create, "")),
                    state_before_gap.get((EventTypes.PowerLevels, "")),
                ],
                "origin_server_ts": self.clock.time_msec(),
            },
            RoomVersions.V6,
        )

        # Now we persist it with state with a dropped history visibility
        # setting. The state resolution across the old and new event will then
        # include it, and so the resolved state won't match the new state.
        state_before_gap = dict(
            self.get_success(
                self._state_storage_controller.get_current_state_ids(self.room_id)
            )
        )
        state_before_gap.pop(("m.room.history_visibility", ""))

        context = self.get_success(
@@ -193,6 +209,10 @@ class ExtremPruneTestCase(HomeserverTestCase):
        # also set the depth to "lots".
        self.reactor.advance(7 * 24 * 60 * 60)

        state_before_gap = self.get_success(
            self._state_storage_controller.get_current_state_ids(self.room_id)
        )

        # Fudge a second event which points to an event we don't have. This is a
        # state event so that the state changes (otherwise we won't prune the
        # extremity as they'll have the same state group).
@@ -205,16 +225,16 @@ class ExtremPruneTestCase(HomeserverTestCase):
                "sender": "@user:other2",
                "depth": 10000,
                "prev_events": ["$some_unknown_message"],
                "auth_events": [],
                "auth_events": [
                    state_before_gap.get((EventTypes.Create, "")),
                    state_before_gap.get((EventTypes.PowerLevels, "")),
                    state_before_gap.get((EventTypes.JoinRules, "")),
                ],
                "origin_server_ts": self.clock.time_msec(),
            },
            RoomVersions.V6,
        )

        state_before_gap = self.get_success(
            self._state_storage_controller.get_current_state_ids(self.room_id)
        )

        self.persist_event(remote_event_2, state=state_before_gap)

        # Check the new extremity is just the new remote event.
@@ -225,6 +245,10 @@ class ExtremPruneTestCase(HomeserverTestCase):
        from a different domain.
        """

        state_before_gap = self.get_success(
            self._state_storage_controller.get_current_state_ids(self.room_id)
        )

        # Fudge a second event which points to an event we don't have. This is a
        # state event so that the state changes (otherwise we won't prune the
        # extremity as they'll have the same state group).
@@ -237,16 +261,15 @@ class ExtremPruneTestCase(HomeserverTestCase):
                "sender": "@user:other2",
                "depth": 10,
                "prev_events": ["$some_unknown_message"],
                "auth_events": [],
                "auth_events": [
                    state_before_gap.get((EventTypes.Create, "")),
                    state_before_gap.get((EventTypes.PowerLevels, "")),
                ],
                "origin_server_ts": self.clock.time_msec(),
            },
            RoomVersions.V6,
        )

        state_before_gap = self.get_success(
            self._state_storage_controller.get_current_state_ids(self.room_id)
        )

        self.persist_event(remote_event_2, state=state_before_gap)

        # Check the new extremity is just the new remote event.
@@ -267,6 +290,10 @@ class ExtremPruneTestCase(HomeserverTestCase):
        # also set the depth to "lots".
        self.reactor.advance(7 * 24 * 60 * 60)

        state_before_gap = self.get_success(
            self._state_storage_controller.get_current_state_ids(self.room_id)
        )

        # Fudge a second event which points to an event we don't have. This is a
        # state event so that the state changes (otherwise we won't prune the
        # extremity as they'll have the same state group).
@@ -279,16 +306,16 @@ class ExtremPruneTestCase(HomeserverTestCase):
                "sender": "@user:other2",
                "depth": 10000,
                "prev_events": ["$some_unknown_message"],
                "auth_events": [],
                "auth_events": [
                    state_before_gap.get((EventTypes.Create, "")),
                    state_before_gap.get((EventTypes.PowerLevels, "")),
                    state_before_gap.get((EventTypes.JoinRules, "")),
                ],
                "origin_server_ts": self.clock.time_msec(),
            },
            RoomVersions.V6,
        )

        state_before_gap = self.get_success(
            self._state_storage_controller.get_current_state_ids(self.room_id)
        )

        self.persist_event(remote_event_2, state=state_before_gap)

        # Check the new extremity is just the new remote event.
@@ -311,6 +338,10 @@ class ExtremPruneTestCase(HomeserverTestCase):
        # also set the depth to "lots".
        self.reactor.advance(7 * 24 * 60 * 60)

        state_before_gap = self.get_success(
            self._state_storage_controller.get_current_state_ids(self.room_id)
        )

        # Fudge a second event which points to an event we don't have. This is a
        # state event so that the state changes (otherwise we won't prune the
        # extremity as they'll have the same state group).
@@ -323,16 +354,16 @@ class ExtremPruneTestCase(HomeserverTestCase):
                "sender": "@user:other2",
                "depth": 10000,
                "prev_events": ["$some_unknown_message"],
                "auth_events": [],
                "auth_events": [
                    state_before_gap.get((EventTypes.Create, "")),
                    state_before_gap.get((EventTypes.PowerLevels, "")),
                    state_before_gap.get((EventTypes.JoinRules, "")),
                ],
                "origin_server_ts": self.clock.time_msec(),
            },
            RoomVersions.V6,
        )

        state_before_gap = self.get_success(
            self._state_storage_controller.get_current_state_ids(self.room_id)
        )

        self.persist_event(remote_event_2, state=state_before_gap)

        # Check the new extremity is just the new remote event.
@@ -347,6 +378,10 @@ class ExtremPruneTestCase(HomeserverTestCase):
        local_message_event_id = body["event_id"]
        self.assert_extremities([local_message_event_id])

        state_before_gap = self.get_success(
            self._state_storage_controller.get_current_state_ids(self.room_id)
        )

        # Fudge a second event which points to an event we don't have. This is a
        # state event so that the state changes (otherwise we won't prune the
        # extremity as they'll have the same state group).
@@ -359,16 +394,16 @@ class ExtremPruneTestCase(HomeserverTestCase):
                "sender": "@user:other2",
                "depth": 10000,
                "prev_events": ["$some_unknown_message"],
                "auth_events": [],
                "auth_events": [
                    state_before_gap.get((EventTypes.Create, "")),
                    state_before_gap.get((EventTypes.PowerLevels, "")),
                    state_before_gap.get((EventTypes.JoinRules, "")),
                ],
                "origin_server_ts": self.clock.time_msec(),
            },
            RoomVersions.V6,
        )

        state_before_gap = self.get_success(
            self._state_storage_controller.get_current_state_ids(self.room_id)
        )

        self.persist_event(remote_event_2, state=state_before_gap)

        # Check the new extremity is just the new remote event.

@@ -25,7 +25,7 @@ from canonicaljson import json
from twisted.test.proto_helpers import MemoryReactor

from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.events import EventBase
from synapse.events.builder import EventBuilder
from synapse.server import HomeServer
@@ -263,12 +263,17 @@ class RedactionTestCase(unittest.HomeserverTestCase):

    @property
    def room_id(self) -> str:
        assert self._base_builder.room_id is not None
        return self._base_builder.room_id

    @property
    def type(self) -> str:
        return self._base_builder.type

    @property
    def room_version(self) -> RoomVersion:
        return self._base_builder.room_version

    @property
    def internal_metadata(self) -> EventInternalMetadata:
        return self._base_builder.internal_metadata

tests/types/test_init.py (new file)
@@ -0,0 +1,51 @@
from synapse.api.errors import SynapseError
from synapse.types import RoomID

from tests.unittest import TestCase


class RoomIDTestCase(TestCase):
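    """Tests for parsing room IDs, including MSC4291-style room IDs.

    Under MSC4291 a room ID is derived from the create event rather than minted
    by a server (roughly: "!" followed by unpadded URL-safe base64, as the test
    values below assume), so unlike a legacy "!localpart:domain" ID it carries
    no server name and get_domain() returns None.
    """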
    def test_can_create_msc4291_room_ids(self) -> None:
        valid_msc4291_room_id = "!31hneApxJ_1o-63DmFrpeqnkFfWppnzWso1JvH3ogLM"
        room_id = RoomID.from_string(valid_msc4291_room_id)
        self.assertEquals(RoomID.is_valid(valid_msc4291_room_id), True)
        self.assertEquals(
            room_id.to_string(),
            valid_msc4291_room_id,
        )
        self.assertEquals(room_id.id, "!31hneApxJ_1o-63DmFrpeqnkFfWppnzWso1JvH3ogLM")
        self.assertEquals(room_id.get_domain(), None)

    def test_cannot_create_invalid_msc4291_room_ids(self) -> None:
        invalid_room_ids = [
            "!wronglength",
            "!31hneApxJ_1o-63DmFrpeqnNOTurlsafeBASE64/gLM",
            "!",
            "! ",
        ]
        for bad_room_id in invalid_room_ids:
            with self.assertRaises(SynapseError):
                RoomID.from_string(bad_room_id)
                if not RoomID.is_valid(bad_room_id):
                    raise SynapseError(400, "invalid")

    def test_cannot_create_invalid_legacy_room_ids(self) -> None:
        invalid_room_ids = [
            "!something:invalid$_chars.com",
        ]
        for bad_room_id in invalid_room_ids:
            with self.assertRaises(SynapseError):
                RoomID.from_string(bad_room_id)
                if not RoomID.is_valid(bad_room_id):
                    raise SynapseError(400, "invalid")

    def test_can_create_valid_legacy_room_ids(self) -> None:
        valid_room_ids = [
            "!foo:example.com",
            "!foo:example.com:8448",
            "!💩💩💩:example.com",
        ]
        for room_id_str in valid_room_ids:
            room_id = RoomID.from_string(room_id_str)
            self.assertEquals(RoomID.is_valid(room_id_str), True)
            self.assertIsNotNone(room_id.get_domain())