Compare commits

...

24 Commits

Author SHA1 Message Date
Erik Johnston
055dc16d49 Fix sending receipts 2022-05-23 09:40:36 +01:00
Erik Johnston
9319bf036c Log when we get state at sync 2022-05-21 17:10:54 +01:00
Erik Johnston
e5c2ea6341 Don't pull out state for catchup 2022-05-21 16:53:01 +01:00
Erik Johnston
8141a0d0b3 Fix test for prev 2022-05-21 16:51:58 +01:00
Erik Johnston
8e47d72992 Pull current hosts out from current_state table 2022-05-21 16:42:51 +01:00
Erik Johnston
da10dfc311 Merge branch 'erikj/less_state_on_missing' into erikj/push_hack 2022-05-21 14:13:10 +01:00
Erik Johnston
f9d470b2da Fix tests 2022-05-21 14:11:52 +01:00
Erik Johnston
8b33331cb5 Newsfile 2022-05-21 14:01:21 +01:00
Erik Johnston
2ebb0c6f99 Pull out less state when handling gaps 2022-05-21 13:58:52 +01:00
Erik Johnston
3bbe3074fb Ignore display name push 2022-05-20 18:57:26 +01:00
Erik Johnston
6fd8b850ed Reduce state that push rules pull from DB 2022-05-20 18:57:13 +01:00
Erik Johnston
4b5a1a45da Don't pull out stuff for push 2022-05-20 18:50:59 +01:00
Erik Johnston
2dd2ca17a0 Fix lint 2022-05-20 18:34:05 +01:00
Erik Johnston
456a394bf7 Use 'get_domains_from_state' still for backfill 2022-05-20 15:13:06 +01:00
Erik Johnston
4bd06c9c98 Newsfile 2022-05-20 13:02:47 +01:00
Erik Johnston
c8c12ac13a Pull out less state when checking soft fail 2022-05-20 12:58:50 +01:00
Erik Johnston
9bb3bbe153 Require 'latest_event_ids' 2022-05-20 12:58:50 +01:00
Erik Johnston
68ff8f3575 Remove 'get_current_state' from StateHandler 2022-05-20 12:58:50 +01:00
Erik Johnston
11efe7231f Use new helper functions 2022-05-20 12:51:53 +01:00
Erik Johnston
f69785e875 Add helper methods to store 2022-05-20 12:51:53 +01:00
Erik Johnston
151cb6e2f4 Use new store.get_current_state_event 2022-05-20 12:51:53 +01:00
Erik Johnston
d882ee6219 Use helper function elsewhere 2022-05-20 12:07:37 +01:00
Erik Johnston
94cd2cad4f Use helper function in auth 2022-05-20 12:03:40 +01:00
Erik Johnston
155399a145 Add helper function to get the current state event in the room 2022-05-20 12:03:40 +01:00
33 changed files with 419 additions and 390 deletions

1
changelog.d/12811.misc Normal file
View File

@@ -0,0 +1 @@
Reduce the amount of state we pull from the DB.

1
changelog.d/12828.misc Normal file
View File

@@ -0,0 +1 @@
Pull out less state when handling gaps in room DAG.

View File

@@ -61,7 +61,6 @@ class Auth:
self.hs = hs
self.clock = hs.get_clock()
self.store = hs.get_datastores().main
self.state = hs.get_state_handler()
self._account_validity_handler = hs.get_account_validity_handler()
self.token_cache: LruCache[str, Tuple[str, bool]] = LruCache(
@@ -81,7 +80,7 @@ class Auth:
user_id: str,
current_state: Optional[StateMap[EventBase]] = None,
allow_departed_users: bool = False,
) -> EventBase:
) -> Tuple[str, Optional[str]]:
"""Check if the user is in the room, or was at some point.
Args:
room_id: The room to check.
@@ -99,29 +98,28 @@ class Auth:
Raises:
AuthError if the user is/was not in the room.
Returns:
Membership event for the user if the user was in the
room. This will be the join event if they are currently joined to
the room. This will be the leave event if they have left the room.
The current membership of the user in the room and the
membership event ID of the user.
"""
if current_state:
member = current_state.get((EventTypes.Member, user_id), None)
else:
member = await self.state.get_current_state(
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
)
if member:
membership = member.membership
(
membership,
member_event_id,
) = await self.store.get_local_current_membership_for_user_in_room(
user_id=user_id,
room_id=room_id,
)
if membership:
if membership == Membership.JOIN:
return member
return membership, member_event_id
# XXX this looks totally bogus. Why do we not allow users who have been banned,
# or those who were members previously and have been re-invited?
if allow_departed_users and membership == Membership.LEAVE:
forgot = await self.store.did_forget(user_id, room_id)
if not forgot:
return member
return membership, member_event_id
raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
@@ -602,7 +600,8 @@ class Auth:
# We currently require the user is a "moderator" in the room. We do this
# by checking if they would (theoretically) be able to change the
# m.room.canonical_alias events
power_level_event = await self.state.get_current_state(
power_level_event = await self.store.get_current_state_event(
room_id, EventTypes.PowerLevels, ""
)
@@ -693,12 +692,11 @@ class Auth:
# * The user is a non-guest user, and was ever in the room
# * The user is a guest user, and has joined the room
# else it will throw.
member_event = await self.check_user_in_room(
return await self.check_user_in_room(
room_id, user_id, allow_departed_users=allow_departed_users
)
return member_event.membership, member_event.event_id
except AuthError:
visibility = await self.state.get_current_state(
visibility = await self.store.get_current_state_event(
room_id, EventTypes.RoomHistoryVisibility, ""
)
if (

View File

@@ -1167,14 +1167,10 @@ class FederationServer(FederationBase):
Raises:
AuthError if the server does not match the ACL
"""
state_ids = await self.store.get_current_state_ids(room_id)
acl_event_id = state_ids.get((EventTypes.ServerACL, ""))
if not acl_event_id:
return
acl_event = await self.store.get_event(acl_event_id)
if server_matches_acl_event(server_name, acl_event):
acl_event = await self.store.get_current_state_event(
room_id, EventTypes.ServerACL, ""
)
if not acl_event or server_matches_acl_event(server_name, acl_event):
return
raise AuthError(code=403, msg="Server is banned from room")

View File

@@ -602,7 +602,7 @@ class FederationSender(AbstractFederationSender):
room_id = receipt.room_id
# Work out which remote servers should be poked and poke them.
domains_set = await self.state.get_current_hosts_in_room(room_id)
domains_set = await self.store.get_current_hosts_in_room(room_id)
domains = [
d
for d in domains_set

View File

@@ -36,6 +36,7 @@ from synapse.metrics import sent_transactions_counter
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import ReadReceipt
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
from synapse.visibility import filter_events_for_server
if TYPE_CHECKING:
import synapse.server
@@ -76,6 +77,7 @@ class PerDestinationQueue:
):
self._server_name = hs.hostname
self._clock = hs.get_clock()
self._storage = hs.get_storage()
self._store = hs.get_datastores().main
self._transaction_manager = transaction_manager
self._instance_name = hs.get_instance_name()
@@ -441,6 +443,12 @@ class PerDestinationQueue:
"This should not happen." % event_ids
)
logger.info(
"Catching up destination %s with %d PDUs",
self._destination,
len(catchup_pdus),
)
# We send transactions with events from one room only, as its likely
# that the remote will have to do additional processing, which may
# take some time. It's better to give it small amounts of work
@@ -486,19 +494,17 @@ class PerDestinationQueue:
):
continue
# Filter out events where the server is not in the room,
# e.g. it may have left/been kicked. *Ideally* we'd pull
# out the kick and send that, but it's a rare edge case
# so we don't bother for now (the server that sent the
# kick should send it out if its online).
hosts = await self._state.get_hosts_in_room_at_events(
p.room_id, [p.event_id]
)
if self._destination not in hosts:
continue
new_pdus.append(p)
# Filter out events where the server is not in the room,
# e.g. it may have left/been kicked. *Ideally* we'd pull
# out the kick and send that, but it's a rare edge case
# so we don't bother for now (the server that sent the
# kick should send it out if its online).
new_pdus = await filter_events_for_server(
self._storage, self._destination, new_pdus, redact=False
)
# If we've filtered out all the extremities, fall back to
# sending the original event. This should ensure that the
# server gets at least some of missed events (especially if

View File

@@ -319,7 +319,7 @@ class DirectoryHandler:
Raises:
ShadowBanError if the requester has been shadow-banned.
"""
alias_event = await self.state.get_current_state(
alias_event = await self.store.get_current_state_event(
room_id, EventTypes.CanonicalAlias, ""
)

View File

@@ -353,7 +353,7 @@ class FederationHandler:
# First we try hosts that are already in the room
# TODO: HEURISTIC ALERT.
curr_state = await self.state_handler.get_current_state(room_id)
curr_state = await self.store.get_current_state(room_id)
curr_domains = get_domains_from_state(curr_state)

View File

@@ -463,7 +463,9 @@ class FederationEventHandler:
with nested_logging_context(suffix=event.event_id):
context = await self._state_handler.compute_event_context(
event,
old_state=state,
state_ids_before_event={
(e.type, e.state_key): e.event_id for e in state
},
partial_state=partial_state,
)
@@ -501,7 +503,7 @@ class FederationEventHandler:
# build a new state group for it if need be
context = await self._state_handler.compute_event_context(
event,
old_state=state,
state_ids_before_event=state,
)
if context.partial_state:
# this can happen if some or all of the event's prev_events still have
@@ -765,7 +767,7 @@ class FederationEventHandler:
async def _resolve_state_at_missing_prevs(
self, dest: str, event: EventBase
) -> Optional[Iterable[EventBase]]:
) -> Optional[StateMap[str]]:
"""Calculate the state at an event with missing prev_events.
This is used when we have pulled a batch of events from a remote server, and
@@ -792,8 +794,8 @@ class FederationEventHandler:
event: an event to check for missing prevs.
Returns:
if we already had all the prev events, `None`. Otherwise, returns a list of
the events in the state at `event`.
if we already had all the prev events, `None`. Otherwise, returns
the state at `event`.
"""
room_id = event.room_id
event_id = event.event_id
@@ -837,13 +839,7 @@ class FederationEventHandler:
dest, room_id, p
)
remote_state_map = {
(x.type, x.state_key): x.event_id for x in remote_state
}
state_maps.append(remote_state_map)
for x in remote_state:
event_map[x.event_id] = x
state_maps.append(remote_state)
room_version = await self._store.get_room_version_id(room_id)
state_map = await self._state_resolution_handler.resolve_events_with_store(
@@ -854,19 +850,6 @@ class FederationEventHandler:
state_res_store=StateResolutionStore(self._store),
)
# We need to give _process_received_pdu the actual state events
# rather than event ids, so generate that now.
# First though we need to fetch all the events that are in
# state_map, so we can build up the state below.
evs = await self._store.get_events(
list(state_map.values()),
get_prev_content=False,
redact_behaviour=EventRedactBehaviour.as_is,
)
event_map.update(evs)
state = [event_map[e] for e in state_map.values()]
except Exception:
logger.warning(
"Error attempting to resolve state at missing prev_events",
@@ -878,14 +861,14 @@ class FederationEventHandler:
"We can't get valid state history.",
affected=event_id,
)
return state
return state_map
async def _get_state_after_missing_prev_event(
self,
destination: str,
room_id: str,
event_id: str,
) -> List[EventBase]:
) -> StateMap[str]:
"""Requests all of the room state at a given event from a remote homeserver.
Args:
@@ -894,7 +877,7 @@ class FederationEventHandler:
event_id: The id of the event we want the state at.
Returns:
A list of events in the state, including the event itself
The state *after* the given event.
"""
(
state_event_ids,
@@ -913,15 +896,13 @@ class FederationEventHandler:
desired_events = set(state_event_ids)
desired_events.add(event_id)
logger.debug("Fetching %i events from cache/store", len(desired_events))
fetched_events = await self._store.get_events(
desired_events, allow_rejected=True
)
have_events = await self._store.have_seen_events(room_id, desired_events)
missing_desired_events = desired_events - fetched_events.keys()
missing_desired_events = desired_events - have_events
logger.debug(
"We are missing %i events (got %i)",
len(missing_desired_events),
len(fetched_events),
len(have_events),
)
# We probably won't need most of the auth events, so let's just check which
@@ -932,7 +913,7 @@ class FederationEventHandler:
# already have a bunch of the state events. It would be nice if the
# federation api gave us a way of finding out which we actually need.
missing_auth_events = set(auth_event_ids) - fetched_events.keys()
missing_auth_events = set(auth_event_ids) - have_events
missing_auth_events.difference_update(
await self._store.have_seen_events(room_id, missing_auth_events)
)
@@ -958,47 +939,54 @@ class FederationEventHandler:
destination=destination, room_id=room_id, event_ids=missing_events
)
# we need to make sure we re-load from the database to get the rejected
# state correct.
fetched_events.update(
await self._store.get_events(missing_desired_events, allow_rejected=True)
)
event_metadata = await self._store.get_metadata_for_events(state_event_ids)
# check for events which were in the wrong room.
#
# this can happen if a remote server claims that the state or
# auth_events at an event in room A are actually events in room B
bad_events = [
(event_id, event.room_id)
for event_id, event in fetched_events.items()
if event.room_id != room_id
]
event_metadata = await self._store.get_metadata_for_events(state_event_ids)
for bad_event_id, bad_room_id in bad_events:
# This is a bogus situation, but since we may only discover it a long time
# after it happened, we try our best to carry on, by just omitting the
# bad events from the returned state set.
logger.warning(
"Remote server %s claims event %s in room %s is an auth/state "
"event in room %s",
destination,
bad_event_id,
bad_room_id,
room_id,
)
state_map = {}
del fetched_events[bad_event_id]
for state_event_id, metadata in event_metadata.items():
if metadata.room_id != room_id:
# This is a bogus situation, but since we may only discover it a long time
# after it happened, we try our best to carry on, by just omitting the
# bad events from the returned state set.
logger.warning(
"Remote server %s claims event %s in room %s is an auth/state "
"event in room %s",
destination,
state_event_id,
metadata.room_id,
room_id,
)
continue
if metadata.state_key is None:
logger.warning(
"Remote server gave us non-state event in state: %s", state_event_id
)
continue
state_map[(metadata.event_type, metadata.state_key)] = state_event_id
# if we couldn't get the prev event in question, that's a problem.
remote_event = fetched_events.get(event_id)
remote_event = await self._store.get_event(
event_id,
allow_none=True,
allow_rejected=True,
redact_behaviour=EventRedactBehaviour.as_is,
)
if not remote_event:
raise Exception("Unable to get missing prev_event %s" % (event_id,))
# missing state at that event is a warning, not a blocker
# XXX: this doesn't sound right? it means that we'll end up with incomplete
# state.
failed_to_fetch = desired_events - fetched_events.keys()
failed_to_fetch = desired_events - event_metadata.keys()
if failed_to_fetch:
logger.warning(
"Failed to fetch missing state events for %s %s",
@@ -1006,14 +994,12 @@ class FederationEventHandler:
failed_to_fetch,
)
remote_state = [
fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events
]
if remote_event.is_state() and remote_event.rejected_reason is None:
remote_state.append(remote_event)
state_map[
(remote_event.type, remote_event.state_key)
] = remote_event.event_id
return remote_state
return state_map
async def _get_state_and_persist(
self, destination: str, room_id: str, event_id: str
@@ -1040,7 +1026,7 @@ class FederationEventHandler:
self,
origin: str,
event: EventBase,
state: Optional[Iterable[EventBase]],
state: Optional[StateMap[str]],
backfilled: bool = False,
) -> None:
"""Called when we have a new non-outlier event.
@@ -1074,7 +1060,7 @@ class FederationEventHandler:
try:
context = await self._state_handler.compute_event_context(
event, old_state=state
event, state_ids_before_event=state
)
context = await self._check_event_auth(
origin,
@@ -1558,14 +1544,14 @@ class FederationEventHandler:
if guest_access == GuestAccess.CAN_JOIN:
return
current_state_map = await self._state_handler.get_current_state(event.room_id)
current_state = list(current_state_map.values())
await self._get_room_member_handler().kick_guest_users(current_state)
current_state = await self._store.get_current_state(event.room_id)
current_state_list = list(current_state.values())
await self._get_room_member_handler().kick_guest_users(current_state_list)
async def _check_for_soft_fail(
self,
event: EventBase,
state: Optional[Iterable[EventBase]],
state: Optional[StateMap[str]],
origin: str,
) -> None:
"""Checks if we should soft fail the event; if so, marks the event as
@@ -1588,6 +1574,9 @@ class FederationEventHandler:
room_version = await self._store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
# The event types we want to pull from the "current" state.
auth_types = auth_types_for_event(room_version_obj, event)
# Calculate the "current state".
if state is not None:
# If we're explicitly given the state then we won't have all the
@@ -1602,20 +1591,24 @@ class FederationEventHandler:
# given state at the event. This should correctly handle cases
# like bans, especially with state res v2.
state_sets_d = await self._state_store.get_state_groups(
state_sets_d = await self._state_store.get_state_groups_ids(
event.room_id, extrem_ids
)
state_sets: List[Iterable[EventBase]] = list(state_sets_d.values())
state_sets: List[StateMap[str]] = list(state_sets_d.values())
state_sets.append(state)
current_states = await self._state_handler.resolve_events(
room_version, state_sets, event
current_state_ids = (
await self._state_resolution_handler.resolve_events_with_store(
event.room_id,
room_version,
state_sets,
event_map={},
state_res_store=StateResolutionStore(self._store),
)
)
current_state_ids: StateMap[str] = {
k: e.event_id for k, e in current_states.items()
}
else:
current_state_ids = await self._state_handler.get_current_state_ids(
event.room_id, latest_event_ids=extrem_ids
current_state_ids = await self._store.get_filtered_current_state_ids(
event.room_id, StateFilter.from_types(auth_types)
)
logger.debug(
@@ -1625,7 +1618,6 @@ class FederationEventHandler:
)
# Now check if event pass auth against said current state
auth_types = auth_types_for_event(room_version_obj, event)
current_state_ids_list = [
e for k, e in current_state_ids.items() if k in auth_types
]

View File

@@ -190,7 +190,7 @@ class InitialSyncHandler:
if event.membership == Membership.JOIN:
room_end_token = now_token.room_key
deferred_room_state = run_in_background(
self.state_handler.get_current_state, event.room_id
self.store.get_current_state, event.room_id
)
elif event.membership == Membership.LEAVE:
room_end_token = RoomStreamToken(
@@ -404,7 +404,7 @@ class InitialSyncHandler:
membership: str,
is_peeking: bool,
) -> JsonDict:
current_state = await self.state.get_current_state(room_id=room_id)
current_state = await self.store.get_current_state(room_id=room_id)
# TODO: These concurrently
time_now = self.clock.time_msec()

View File

@@ -117,7 +117,9 @@ class MessageHandler:
)
if membership == Membership.JOIN:
data = await self.state.get_current_state(room_id, event_type, state_key)
data = await self.store.get_current_state_event(
room_id, event_type, state_key
)
elif membership == Membership.LEAVE:
key = (event_type, state_key)
# If the membership is not JOIN, then the event ID should exist.
@@ -1021,8 +1023,21 @@ class EventCreationHandler:
#
# TODO(faster_joins): figure out how this works, and make sure that the
# old state is complete.
old_state = await self.store.get_events_as_list(state_event_ids)
context = await self.state.compute_event_context(event, old_state=old_state)
metadata = await self.store.get_metadata_for_events(state_event_ids)
state_map = {}
for event_id, data in metadata.items():
if data.state_key is None:
raise Exception(
"Trying to set non-state event as state: %s", event_id
)
state_map[(data.event_type, data.state_key)] = event_id
context = await self.state.compute_event_context(
event,
state_ids_before_event=state_map,
)
else:
context = await self.state.compute_event_context(event)

View File

@@ -1399,7 +1399,7 @@ class TimestampLookupHandler:
)
# Find other homeservers from the given state in the room
curr_state = await self.state_handler.get_current_state(room_id)
curr_state = await self.store.get_current_state(room_id)
curr_domains = get_domains_from_state(curr_state)
likely_domains = [
domain for domain, depth in curr_domains if domain != self.server_name

View File

@@ -1409,7 +1409,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
txn_id: Optional[str],
id_access_token: Optional[str] = None,
) -> int:
room_state = await self.state_handler.get_current_state(room_id)
room_state = await self.store.get_filtered_current_state(
room_id,
StateFilter.from_types(
[
(EventTypes.Member, user.to_string()),
(EventTypes.CanonicalAlias, ""),
(EventTypes.Name, ""),
(EventTypes.Create, ""),
(EventTypes.JoinRules, ""),
(EventTypes.RoomAvatar, ""),
]
),
)
inviter_display_name = ""
inviter_avatar_url = ""
@@ -1805,7 +1817,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
async def forget(self, user: UserID, room_id: str) -> None:
user_id = user.to_string()
member = await self.state_handler.get_current_state(
member = await self.store.get_current_state_event(
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
)
membership = member.membership if member else None

View File

@@ -348,7 +348,7 @@ class SearchHandler:
state_results = {}
if include_state:
for room_id in {e.room_id for e in search_result.allowed_events}:
state = await self.state_handler.get_current_state(room_id)
state = await self.store.get_current_state(room_id)
state_results[room_id] = list(state.values())
aggregations = await self._relations_handler.get_bundled_aggregations(

View File

@@ -18,6 +18,8 @@ from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tup
import attr
from prometheus_client import Counter
from twisted.python import failure
from synapse.api.constants import EventTypes, Membership, ReceiptTypes
from synapse.api.filtering import FilterCollection
from synapse.api.presence import UserPresenceState
@@ -643,6 +645,13 @@ class SyncHandler:
event: event of interest
state_filter: The state filter used to fetch state from the database.
"""
f = failure.Failure()
logger.info(
"SYNC get_state_after_event in room %s",
event.room_id,
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
)
state_ids = await self.state_store.get_state_ids_for_event(
event.event_id, state_filter=state_filter or StateFilter.all()
)

View File

@@ -681,7 +681,7 @@ class Notifier:
return joined_room_ids, True
async def _is_world_readable(self, room_id: str) -> bool:
state = await self.state_handler.get_current_state(
state = await self.store.get_current_state_event(
room_id, EventTypes.RoomHistoryVisibility, ""
)
if state and "history_visibility" in state.content:

View File

@@ -209,7 +209,12 @@ class BulkPushRuleEvaluator:
rules_by_user = await self._get_rules_for_event(event, context)
actions_by_user: Dict[str, List[Union[dict, str]]] = {}
room_members = await self.store.get_joined_users_from_context(event, context)
# FIXME!!!
# room_members = await self.store.get_joined_users_from_context(event, context)
room_member_count = await self.store.get_number_joined_users_in_room(
event.room_id
)
(
power_levels,
@@ -217,7 +222,7 @@ class BulkPushRuleEvaluator:
) = await self._get_power_levels_and_sender_level(event, context)
evaluator = PushRuleEvaluatorForEvent(
event, len(room_members), sender_power_level, power_levels
event, room_member_count, sender_power_level, power_levels
)
# If the event is not a state event check if any users ignore the sender.
@@ -234,9 +239,10 @@ class BulkPushRuleEvaluator:
continue
display_name = None
profile_info = room_members.get(uid)
if profile_info:
display_name = profile_info.display_name
# FIXME!!!
# profile_info = room_members.get(uid)
# if profile_info:
# display_name = profile_info.display_name
if not display_name:
# Handle the case where we are pushing a membership event to
@@ -387,77 +393,27 @@ class RulesForRoom:
self.room_push_rule_cache_metrics.inc_hits()
return self.data.rules_by_user
self.room_push_rule_cache_metrics.inc_misses()
ret_rules_by_user = {}
missing_member_event_ids = {}
if state_group and self.data.state_group == context.prev_group:
# If we have a simple delta then we can reuse most of the previous
# results.
ret_rules_by_user = self.data.rules_by_user
current_state_ids = context.delta_ids
push_rules_delta_state_cache_metric.inc_hits()
else:
current_state_ids = await context.get_current_state_ids()
push_rules_delta_state_cache_metric.inc_misses()
# Ensure the state IDs exist.
assert current_state_ids is not None
push_rules_state_size_counter.inc(len(current_state_ids))
logger.debug(
"Looking for member changes in %r %r", state_group, current_state_ids
local_users = await self.store.get_local_users_in_room(
self.room_id, on_invalidate=self.invalidate_all_cb
)
# Loop through to see which member events we've seen and have rules
# for and which we need to fetch
for key in current_state_ids:
typ, user_id = key
if typ != EventTypes.Member:
continue
if event.type == EventTypes.Member and event.membership == Membership.JOIN:
if self.is_mine_id(event.state_key):
local_users = list(local_users)
local_users.append(event.state_key)
if user_id in self.data.uninteresting_user_set:
continue
ret_rules_by_user = await self.store.bulk_get_push_rules(
local_users, on_invalidate=self.invalidate_all_cb
)
if not self.is_mine_id(user_id):
self.data.uninteresting_user_set.add(user_id)
continue
logger.info("Users in room: %s", local_users)
if self.store.get_if_app_services_interested_in_user(user_id):
self.data.uninteresting_user_set.add(user_id)
continue
event_id = current_state_ids[key]
res = self.data.member_map.get(event_id, None)
if res:
if res.membership == Membership.JOIN:
rules = self.data.rules_by_user.get(res.user_id, None)
if rules:
ret_rules_by_user[res.user_id] = rules
continue
# If a user has left a room we remove their push rule. If they
# joined then we re-add it later in _update_rules_with_member_event_ids
ret_rules_by_user.pop(user_id, None)
missing_member_event_ids[user_id] = event_id
if missing_member_event_ids:
# If we have some member events we haven't seen, look them up
# and fetch push rules for them if appropriate.
logger.debug("Found new member events %r", missing_member_event_ids)
await self._update_rules_with_member_event_ids(
ret_rules_by_user, missing_member_event_ids, state_group, event
)
else:
# The push rules didn't change but lets update the cache anyway
self.update_cache(
self.data.sequence,
members={}, # There were no membership changes
rules_by_user=ret_rules_by_user,
state_group=state_group,
)
self.update_cache(
self.data.sequence,
members={}, # There were no membership changes
rules_by_user=ret_rules_by_user,
state_group=state_group,
)
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
@@ -465,67 +421,6 @@ class RulesForRoom:
)
return ret_rules_by_user
async def _update_rules_with_member_event_ids(
self,
ret_rules_by_user: Dict[str, list],
member_event_ids: Dict[str, str],
state_group: Optional[int],
event: EventBase,
) -> None:
"""Update the partially filled rules_by_user dict by fetching rules for
any newly joined users in the `member_event_ids` list.
Args:
ret_rules_by_user: Partially filled dict of push rules. Gets
updated with any new rules.
member_event_ids: Dict of user id to event id for membership events
that have happened since the last time we filled rules_by_user
state_group: The state group we are currently computing push rules
for. Used when updating the cache.
event: The event we are currently computing push rules for.
"""
sequence = self.data.sequence
members = await self.store.get_membership_from_event_ids(
member_event_ids.values()
)
# If the event is a join event then it will be in current state events
# map but not in the DB, so we have to explicitly insert it.
if event.type == EventTypes.Member:
for event_id in member_event_ids.values():
if event_id == event.event_id:
members[event_id] = EventIdMembership(
user_id=event.state_key, membership=event.membership
)
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Found members %r: %r", self.room_id, members.values())
joined_user_ids = {
entry.user_id
for entry in members.values()
if entry and entry.membership == Membership.JOIN
}
logger.debug("Joined: %r", joined_user_ids)
# Previously we only considered users with pushers or read receipts in that
# room. We can't do this anymore because we use push actions to calculate unread
# counts, which don't rely on the user having pushers or sent a read receipt into
# the room. Therefore we just need to filter for local users here.
user_ids = list(filter(self.is_mine_id, joined_user_ids))
rules_by_user = await self.store.bulk_get_push_rules(
user_ids, on_invalidate=self.invalidate_all_cb
)
ret_rules_by_user.update(
item for item in rules_by_user.items() if item[0] is not None
)
self.update_cache(sequence, members, ret_rules_by_user, state_group)
def update_cache(
self,
sequence: int,

View File

@@ -34,6 +34,7 @@ from synapse.rest.admin._base import (
assert_user_is_admin,
)
from synapse.storage.databases.main.room import RoomSortOrder
from synapse.storage.state import StateFilter
from synapse.types import JsonDict, RoomID, UserID, create_requester
from synapse.util import json_decoder
@@ -447,7 +448,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
super().__init__(hs)
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
self.state_handler = hs.get_state_handler()
self.store = hs.get_datastores().main
self.is_mine = hs.is_mine
async def on_POST(
@@ -489,8 +490,9 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
)
# send invite if room has "JoinRules.INVITE"
room_state = await self.state_handler.get_current_state(room_id)
join_rules_event = room_state.get((EventTypes.JoinRules, ""))
join_rules_event = await self.store.get_current_state_event(
room_id, EventTypes.JoinRules, ""
)
if join_rules_event:
if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC):
# update_membership with an action of "invite" can raise a
@@ -552,12 +554,22 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
user_to_add = content.get("user_id", requester.user.to_string())
# Figure out which local users currently have power in the room, if any.
room_state = await self.state_handler.get_current_state(room_id)
if not room_state:
filtered_room_state = await self.store.get_filtered_current_state(
room_id,
StateFilter.from_types(
[
(EventTypes.Create, ""),
(EventTypes.PowerLevels, ""),
(EventTypes.JoinRules, ""),
(EventTypes.Member, user_to_add),
]
),
)
if not filtered_room_state:
raise SynapseError(HTTPStatus.BAD_REQUEST, "Server not in room")
create_event = room_state[(EventTypes.Create, "")]
power_levels = room_state.get((EventTypes.PowerLevels, ""))
create_event = filtered_room_state[(EventTypes.Create, "")]
power_levels = filtered_room_state.get((EventTypes.PowerLevels, ""))
if power_levels is not None:
# We pick the local user with the highest power.
@@ -633,7 +645,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
# Now we check if the user we're granting admin rights to is already in
# the room. If not and it's not a public room we invite them.
member_event = room_state.get((EventTypes.Member, user_to_add))
member_event = filtered_room_state.get((EventTypes.Member, user_to_add))
is_joined = False
if member_event:
is_joined = member_event.content["membership"] in (
@@ -644,7 +656,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
if is_joined:
return HTTPStatus.OK, {}
join_rules = room_state.get((EventTypes.JoinRules, ""))
join_rules = filtered_room_state.get((EventTypes.JoinRules, ""))
is_public = False
if join_rules:
is_public = join_rules.content.get("join_rule") == JoinRules.PUBLIC

View File

@@ -673,7 +673,7 @@ class RoomEventServlet(RestServlet):
if include_unredacted_content and not await self.auth.is_server_admin(
requester.user
):
power_level_event = await self._state.get_current_state(
power_level_event = await self._store.get_current_state_event(
room_id, EventTypes.PowerLevels, ""
)

View File

@@ -178,8 +178,8 @@ class ResourceLimitsServerNotices:
currently_blocked = False
pinned_state_event = None
try:
pinned_state_event = await self._state.get_current_state(
room_id, event_type=EventTypes.Pinned
pinned_state_event = await self._store.get_current_state_event(
room_id, event_type=EventTypes.Pinned, state_key=""
)
except AuthError:
# The user has yet to join the server notices room

View File

@@ -32,13 +32,11 @@ from typing import (
Set,
Tuple,
Union,
overload,
)
import attr
from frozendict import frozendict
from prometheus_client import Counter, Histogram
from typing_extensions import Literal
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
@@ -132,85 +130,20 @@ class StateHandler:
self._state_resolution_handler = hs.get_state_resolution_handler()
self._storage = hs.get_storage()
@overload
async def get_current_state(
self,
room_id: str,
event_type: Literal[None] = None,
state_key: str = "",
latest_event_ids: Optional[List[str]] = None,
) -> StateMap[EventBase]:
...
@overload
async def get_current_state(
self,
room_id: str,
event_type: str,
state_key: str = "",
latest_event_ids: Optional[List[str]] = None,
) -> Optional[EventBase]:
...
async def get_current_state(
self,
room_id: str,
event_type: Optional[str] = None,
state_key: str = "",
latest_event_ids: Optional[List[str]] = None,
) -> Union[Optional[EventBase], StateMap[EventBase]]:
"""Retrieves the current state for the room. This is done by
calling `get_latest_events_in_room` to get the leading edges of the
event graph and then resolving any of the state conflicts.
This is equivalent to getting the state of an event that were to send
next before receiving any new events.
Returns:
If `event_type` is specified, then the method returns only the one
event (or None) with that `event_type` and `state_key`.
Otherwise, a map from (type, state_key) to event.
"""
if not latest_event_ids:
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
assert latest_event_ids is not None
logger.debug("calling resolve_state_groups from get_current_state")
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state
if event_type:
event_id = state.get((event_type, state_key))
event = None
if event_id:
event = await self.store.get_event(event_id, allow_none=True)
return event
state_map = await self.store.get_events(
list(state.values()), get_prev_content=False
)
return {
key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
}
async def get_current_state_ids(
self, room_id: str, latest_event_ids: Optional[Collection[str]] = None
self,
room_id: str,
latest_event_ids: Collection[str],
) -> StateMap[str]:
"""Get the current state, or the state at a set of events, for a room
Args:
room_id:
latest_event_ids: if given, the forward extremities to resolve. If
None, we look them up from the database (via a cache).
latest_event_ids: The forward extremities to resolve.
Returns:
the state dict, mapping from (event_type, state_key) -> event_id
"""
if not latest_event_ids:
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
assert latest_event_ids is not None
logger.debug("calling resolve_state_groups from get_current_state_ids")
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
return ret.state
@@ -239,10 +172,6 @@ class StateHandler:
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
return await self.store.get_joined_users_from_state(room_id, entry)
async def get_current_hosts_in_room(self, room_id: str) -> FrozenSet[str]:
event_ids = await self.store.get_latest_event_ids_in_room(room_id)
return await self.get_hosts_in_room_at_events(room_id, event_ids)
async def get_hosts_in_room_at_events(
self, room_id: str, event_ids: Collection[str]
) -> FrozenSet[str]:
@@ -261,7 +190,7 @@ class StateHandler:
async def compute_event_context(
self,
event: EventBase,
old_state: Optional[Iterable[EventBase]] = None,
state_ids_before_event: Optional[StateMap[str]] = None,
partial_state: bool = False,
) -> EventContext:
"""Build an EventContext structure for a non-outlier event.
@@ -273,12 +202,12 @@ class StateHandler:
Args:
event:
old_state: The state at the event if it can't be
state_ids_before_event: The state at the event if it can't be
calculated from existing events. This is normally only specified
when receiving an event from federation where we don't have the
prev events for, e.g. when backfilling.
partial_state: True if `old_state` is partial and omits non-critical
membership events
partial_state: True if `state_ids_before_event` is partial and omits
non-critical membership events
Returns:
The event context.
"""
@@ -288,11 +217,7 @@ class StateHandler:
#
# first of all, figure out the state before the event
#
if old_state:
# if we're given the state before the event, then we use that
state_ids_before_event: StateMap[str] = {
(s.type, s.state_key): s.event_id for s in old_state
}
if state_ids_before_event:
state_group_before_event = None
state_group_before_event_prev_group = None
deltas_to_state_group_before_event = None

View File

@@ -74,6 +74,9 @@ class SQLBaseStore(metaclass=ABCMeta):
self._attempt_to_invalidate_cache(
"get_users_in_room_with_profiles", (room_id,)
)
self._attempt_to_invalidate_cache(
"get_number_joined_users_in_room.invalidate", (room_id,)
)
# Purge other caches based on room state.
self._attempt_to_invalidate_cache("get_room_summary", (room_id,))

View File

@@ -217,6 +217,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
if etype == EventTypes.Member:
self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
self.get_invited_rooms_for_local_user.invalidate((state_key,))
self.get_local_users_in_room.invalidate((room_id,))
self.get_number_joined_users_in_room((room_id,))
if relates_to:
self.get_relations_for_event.invalidate((relates_to,))

View File

@@ -1766,6 +1766,14 @@ class PersistEventsStore:
self.store.get_invited_rooms_for_local_user.invalidate,
(event.state_key,),
)
txn.call_after(
self.store.get_local_users_in_room.invalidate,
(event.room_id,),
)
txn.call_after(
self.store.get_number_joined_users_in_room.invalidate,
(event.room_id,),
)
# The `_get_membership_from_event_id` is immutable, except for the
# case where we look up an event *before* persisting it.

View File

@@ -337,6 +337,15 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"get_room_summary", _get_room_summary_txn
)
@cached()
async def get_number_joined_users_in_room(self, room_id: str) -> int:
return await self.db_pool.simple_select_one_onecol(
table="current_state_events",
keyvalues={"room_id": room_id, "membership": Membership.JOIN},
retcol="COUNT(*)",
desc="get_number_joined_users_in_room",
)
@cached()
async def get_invited_rooms_for_local_user(
self, user_id: str
@@ -444,6 +453,15 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return results
@cached()
async def get_local_users_in_room(self, room_id: str) -> List[str]:
return await self.db_pool.simple_select_onecol(
table="local_current_membership",
keyvalues={"room_id": room_id, "membership": Membership.JOIN},
retcol="user_id",
desc="get_local_users_in_room",
)
async def get_local_current_membership_for_user_in_room(
self, user_id: str, room_id: str
) -> Tuple[Optional[str], Optional[str]]:
@@ -869,6 +887,29 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return True
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
"""Get current hosts in room."""
if isinstance(self.database_engine, Sqlite3Engine):
users = await self.get_users_in_room(room_id)
return {get_domain_from_id(u) for u in users}
def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]:
sql = """
SELECT DISTINCT substring(state_key FROM '@[^:]*:(.*)$')
FROM current_state_events
WHERE
type = 'm.room.member'
AND membership = 'join'
AND room_id = ?
"""
txn.execute(sql, (room_id,))
return {d for d, in txn}
return await self.db_pool.runInteraction(
"get_current_hosts_in_room", get_current_hosts_in_room_txn
)
async def get_joined_hosts(
self, room_id: str, state_entry: "_StateCacheEntry"
) -> FrozenSet[str]:

View File

@@ -16,6 +16,8 @@ import collections.abc
import logging
from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple
import attr
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
@@ -26,6 +28,7 @@ from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
make_in_list_sql_clause,
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
@@ -43,6 +46,15 @@ logger = logging.getLogger(__name__)
MAX_STATE_DELTA_HOPS = 100
@attr.s(slots=True, frozen=True, auto_attribs=True)
class EventMetadata:
"""Returned by `get_metadata_for_events`"""
room_id: str
event_type: str
state_key: Optional[str]
def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion:
v = KNOWN_ROOM_VERSIONS.get(room_version_id)
if not v:
@@ -133,6 +145,36 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return room_version
async def get_metadata_for_events(
self, event_ids: Collection[str]
) -> Dict[str, EventMetadata]:
"""Get some metadata (room_id, type, state_key) for the given events."""
clause, args = make_in_list_sql_clause(
self.database_engine, "e.event_id", event_ids
)
sql = f"""
SELECT e.event_id, e.room_id, e.type, e.state_key FROM events AS e
LEFT JOIN state_events USING (event_id)
WHERE {clause}
"""
def get_metadata_for_events_txn(
txn: LoggingTransaction,
) -> Dict[str, EventMetadata]:
txn.execute(sql, args)
return {
event_id: EventMetadata(
room_id=room_id, event_type=event_type, state_key=state_key
)
for event_id, room_id, event_type, state_key in txn
}
return await self.db_pool.runInteraction(
"get_metadata_for_events", get_metadata_for_events_txn
)
async def get_room_predecessor(self, room_id: str) -> Optional[JsonMapping]:
"""Get the predecessor of an upgraded room if it exists.
Otherwise return None.
@@ -218,6 +260,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"get_current_state_ids", _get_current_state_ids_txn
)
async def get_current_state(self, room_id: str) -> StateMap[EventBase]:
"""Same as `get_current_state_ids` but also fetches the events"""
state_map_ids = await self.get_current_state_ids(room_id)
event_map = await self.get_events(list(state_map_ids.values()))
state_map = {}
for key, event_id in state_map_ids.items():
event = event_map.get(event_id)
if event:
state_map[key] = event
return state_map
# FIXME: how should this be cached?
async def get_filtered_current_state_ids(
self, room_id: str, state_filter: Optional[StateFilter] = None
@@ -269,6 +325,39 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
)
async def get_filtered_current_state(
self, room_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[EventBase]:
"""Same as `get_filtered_current_state_ids` but also fetches the events"""
state_map_ids = await self.get_filtered_current_state_ids(room_id, state_filter)
event_map = await self.get_events(list(state_map_ids.values()))
state_map = {}
for key, event_id in state_map_ids.items():
event = event_map.get(event_id)
if event:
state_map[key] = event
return state_map
async def get_current_state_event(
self, room_id: str, event_type: str, state_key: str
) -> Optional[EventBase]:
"""Get the current state event for the given type/state_key."""
key = (event_type, state_key)
state_map = await self.get_filtered_current_state_ids(
room_id, StateFilter.from_types((key,))
)
event_id = state_map.get(key)
event = None
if event_id:
event = await self.get_event(event_id, allow_none=True)
return event
async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]:
"""Get canonical alias for room, if any

View File

@@ -30,16 +30,16 @@ from tests.unittest import HomeserverTestCase, override_config
class FederationSenderReceiptsTestCases(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
mock_state_handler = Mock(spec=["get_current_hosts_in_room"])
# Ensure a new Awaitable is created for each call.
mock_state_handler.get_current_hosts_in_room.return_value = make_awaitable(
["test", "host2"]
)
return self.setup_test_homeserver(
state_handler=mock_state_handler,
hs = self.setup_test_homeserver(
federation_transport_client=Mock(spec=["send_transaction"]),
)
hs.get_datastores().main.get_current_hosts_in_room = Mock(
return_value=make_awaitable(["test", "host2"])
)
return hs
@override_config({"send_federation": True})
def test_send_receipts(self):
mock_send_transaction = (

View File

@@ -207,7 +207,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
# the room should show that the new user is a member
r = self.get_success(
self.hs.get_state_handler().get_current_state(self._room_id)
self.hs.get_datastores().main.get_current_state(self._room_id)
)
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
@@ -258,7 +258,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
# the room should show that the new user is a member
r = self.get_success(
self.hs.get_state_handler().get_current_state(self._room_id)
self.hs.get_datastores().main.get_current_state(self._room_id)
)
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")

View File

@@ -335,7 +335,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
def _get_canonical_alias(self):
"""Get the canonical alias state of the room."""
return self.get_success(
self.state_handler.get_current_state(
self.store.get_current_state_event(
self.room_id, EventTypes.CanonicalAlias, ""
)
)

View File

@@ -276,7 +276,9 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
# federation handler wanting to backfill the fake event.
self.get_success(
federation_event_handler._process_received_pdu(
self.OTHER_SERVER_NAME, event, state=current_state
self.OTHER_SERVER_NAME,
event,
state={(e.type, e.state_key): e.event_id for e in current_state},
)
)

View File

@@ -69,7 +69,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
def persist_event(self, event, state=None):
"""Persist the event, with optional state"""
context = self.get_success(
self.state.compute_event_context(event, old_state=state)
self.state.compute_event_context(event, state_ids_before_event=state)
)
self.get_success(self.persistence.persist_event(event, context))
@@ -103,9 +103,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
state_before_gap = self.get_success(
self.store.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap.values())
self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id])
@@ -135,13 +137,14 @@ class ExtremPruneTestCase(HomeserverTestCase):
# setting. The state resolution across the old and new event will then
# include it, and so the resolved state won't match the new state.
state_before_gap = dict(
self.get_success(self.state.get_current_state(self.room_id))
self.get_success(self.store.get_current_state_ids(self.room_id))
)
state_before_gap.pop(("m.room.history_visibility", ""))
context = self.get_success(
self.state.compute_event_context(
remote_event_2, old_state=state_before_gap.values()
remote_event_2,
state_ids_before_event=state_before_gap,
)
)
@@ -177,9 +180,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
state_before_gap = self.get_success(
self.store.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap.values())
self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id])
@@ -207,9 +212,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
state_before_gap = self.get_success(
self.store.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap.values())
self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
@@ -247,9 +254,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
state_before_gap = self.get_success(
self.store.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap.values())
self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id])
@@ -289,9 +298,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
state_before_gap = self.get_success(
self.store.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap.values())
self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id, local_message_event_id])
@@ -323,9 +334,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
state_before_gap = self.get_success(
self.store.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap.values())
self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([local_message_event_id, remote_event_2.event_id])

View File

@@ -98,9 +98,8 @@ class PurgeTests(HomeserverTestCase):
first = self.helper.send(self.room_id, body="test1")
# Get the current room state.
state_handler = self.hs.get_state_handler()
create_event = self.get_success(
state_handler.get_current_state(self.room_id, "m.room.create", "")
self.store.get_current_state_event(self.room_id, "m.room.create", "")
)
self.assertIsNotNone(create_event)

View File

@@ -442,7 +442,12 @@ class StateTestCase(unittest.TestCase):
]
context = yield defer.ensureDeferred(
self.state.compute_event_context(event, old_state=old_state)
self.state.compute_event_context(
event,
state_ids_before_event={
(e.type, e.state_key): e.event_id for e in old_state
},
)
)
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
@@ -467,7 +472,12 @@ class StateTestCase(unittest.TestCase):
]
context = yield defer.ensureDeferred(
self.state.compute_event_context(event, old_state=old_state)
self.state.compute_event_context(
event,
state_ids_before_event={
(e.type, e.state_key): e.event_id for e in old_state
},
)
)
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())