mirror of
https://github.com/element-hq/synapse.git
synced 2025-12-05 01:10:13 +00:00
Compare commits
24 Commits
erikj/ss_u
...
erikj/push
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
055dc16d49 | ||
|
|
9319bf036c | ||
|
|
e5c2ea6341 | ||
|
|
8141a0d0b3 | ||
|
|
8e47d72992 | ||
|
|
da10dfc311 | ||
|
|
f9d470b2da | ||
|
|
8b33331cb5 | ||
|
|
2ebb0c6f99 | ||
|
|
3bbe3074fb | ||
|
|
6fd8b850ed | ||
|
|
4b5a1a45da | ||
|
|
2dd2ca17a0 | ||
|
|
456a394bf7 | ||
|
|
4bd06c9c98 | ||
|
|
c8c12ac13a | ||
|
|
9bb3bbe153 | ||
|
|
68ff8f3575 | ||
|
|
11efe7231f | ||
|
|
f69785e875 | ||
|
|
151cb6e2f4 | ||
|
|
d882ee6219 | ||
|
|
94cd2cad4f | ||
|
|
155399a145 |
1
changelog.d/12811.misc
Normal file
1
changelog.d/12811.misc
Normal file
@@ -0,0 +1 @@
|
||||
Reduce the amount of state we pull from the DB.
|
||||
1
changelog.d/12828.misc
Normal file
1
changelog.d/12828.misc
Normal file
@@ -0,0 +1 @@
|
||||
Pull out less state when handling gaps in room DAG.
|
||||
@@ -61,7 +61,6 @@ class Auth:
|
||||
self.hs = hs
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastores().main
|
||||
self.state = hs.get_state_handler()
|
||||
self._account_validity_handler = hs.get_account_validity_handler()
|
||||
|
||||
self.token_cache: LruCache[str, Tuple[str, bool]] = LruCache(
|
||||
@@ -81,7 +80,7 @@ class Auth:
|
||||
user_id: str,
|
||||
current_state: Optional[StateMap[EventBase]] = None,
|
||||
allow_departed_users: bool = False,
|
||||
) -> EventBase:
|
||||
) -> Tuple[str, Optional[str]]:
|
||||
"""Check if the user is in the room, or was at some point.
|
||||
Args:
|
||||
room_id: The room to check.
|
||||
@@ -99,29 +98,28 @@ class Auth:
|
||||
Raises:
|
||||
AuthError if the user is/was not in the room.
|
||||
Returns:
|
||||
Membership event for the user if the user was in the
|
||||
room. This will be the join event if they are currently joined to
|
||||
the room. This will be the leave event if they have left the room.
|
||||
The current membership of the user in the room and the
|
||||
membership event ID of the user.
|
||||
"""
|
||||
if current_state:
|
||||
member = current_state.get((EventTypes.Member, user_id), None)
|
||||
else:
|
||||
member = await self.state.get_current_state(
|
||||
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
|
||||
)
|
||||
|
||||
if member:
|
||||
membership = member.membership
|
||||
(
|
||||
membership,
|
||||
member_event_id,
|
||||
) = await self.store.get_local_current_membership_for_user_in_room(
|
||||
user_id=user_id,
|
||||
room_id=room_id,
|
||||
)
|
||||
|
||||
if membership:
|
||||
if membership == Membership.JOIN:
|
||||
return member
|
||||
return membership, member_event_id
|
||||
|
||||
# XXX this looks totally bogus. Why do we not allow users who have been banned,
|
||||
# or those who were members previously and have been re-invited?
|
||||
if allow_departed_users and membership == Membership.LEAVE:
|
||||
forgot = await self.store.did_forget(user_id, room_id)
|
||||
if not forgot:
|
||||
return member
|
||||
return membership, member_event_id
|
||||
|
||||
raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
|
||||
|
||||
@@ -602,7 +600,8 @@ class Auth:
|
||||
# We currently require the user is a "moderator" in the room. We do this
|
||||
# by checking if they would (theoretically) be able to change the
|
||||
# m.room.canonical_alias events
|
||||
power_level_event = await self.state.get_current_state(
|
||||
|
||||
power_level_event = await self.store.get_current_state_event(
|
||||
room_id, EventTypes.PowerLevels, ""
|
||||
)
|
||||
|
||||
@@ -693,12 +692,11 @@ class Auth:
|
||||
# * The user is a non-guest user, and was ever in the room
|
||||
# * The user is a guest user, and has joined the room
|
||||
# else it will throw.
|
||||
member_event = await self.check_user_in_room(
|
||||
return await self.check_user_in_room(
|
||||
room_id, user_id, allow_departed_users=allow_departed_users
|
||||
)
|
||||
return member_event.membership, member_event.event_id
|
||||
except AuthError:
|
||||
visibility = await self.state.get_current_state(
|
||||
visibility = await self.store.get_current_state_event(
|
||||
room_id, EventTypes.RoomHistoryVisibility, ""
|
||||
)
|
||||
if (
|
||||
|
||||
@@ -1167,14 +1167,10 @@ class FederationServer(FederationBase):
|
||||
Raises:
|
||||
AuthError if the server does not match the ACL
|
||||
"""
|
||||
state_ids = await self.store.get_current_state_ids(room_id)
|
||||
acl_event_id = state_ids.get((EventTypes.ServerACL, ""))
|
||||
|
||||
if not acl_event_id:
|
||||
return
|
||||
|
||||
acl_event = await self.store.get_event(acl_event_id)
|
||||
if server_matches_acl_event(server_name, acl_event):
|
||||
acl_event = await self.store.get_current_state_event(
|
||||
room_id, EventTypes.ServerACL, ""
|
||||
)
|
||||
if not acl_event or server_matches_acl_event(server_name, acl_event):
|
||||
return
|
||||
|
||||
raise AuthError(code=403, msg="Server is banned from room")
|
||||
|
||||
@@ -602,7 +602,7 @@ class FederationSender(AbstractFederationSender):
|
||||
room_id = receipt.room_id
|
||||
|
||||
# Work out which remote servers should be poked and poke them.
|
||||
domains_set = await self.state.get_current_hosts_in_room(room_id)
|
||||
domains_set = await self.store.get_current_hosts_in_room(room_id)
|
||||
domains = [
|
||||
d
|
||||
for d in domains_set
|
||||
|
||||
@@ -36,6 +36,7 @@ from synapse.metrics import sent_transactions_counter
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.types import ReadReceipt
|
||||
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
|
||||
from synapse.visibility import filter_events_for_server
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import synapse.server
|
||||
@@ -76,6 +77,7 @@ class PerDestinationQueue:
|
||||
):
|
||||
self._server_name = hs.hostname
|
||||
self._clock = hs.get_clock()
|
||||
self._storage = hs.get_storage()
|
||||
self._store = hs.get_datastores().main
|
||||
self._transaction_manager = transaction_manager
|
||||
self._instance_name = hs.get_instance_name()
|
||||
@@ -441,6 +443,12 @@ class PerDestinationQueue:
|
||||
"This should not happen." % event_ids
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Catching up destination %s with %d PDUs",
|
||||
self._destination,
|
||||
len(catchup_pdus),
|
||||
)
|
||||
|
||||
# We send transactions with events from one room only, as its likely
|
||||
# that the remote will have to do additional processing, which may
|
||||
# take some time. It's better to give it small amounts of work
|
||||
@@ -486,19 +494,17 @@ class PerDestinationQueue:
|
||||
):
|
||||
continue
|
||||
|
||||
# Filter out events where the server is not in the room,
|
||||
# e.g. it may have left/been kicked. *Ideally* we'd pull
|
||||
# out the kick and send that, but it's a rare edge case
|
||||
# so we don't bother for now (the server that sent the
|
||||
# kick should send it out if its online).
|
||||
hosts = await self._state.get_hosts_in_room_at_events(
|
||||
p.room_id, [p.event_id]
|
||||
)
|
||||
if self._destination not in hosts:
|
||||
continue
|
||||
|
||||
new_pdus.append(p)
|
||||
|
||||
# Filter out events where the server is not in the room,
|
||||
# e.g. it may have left/been kicked. *Ideally* we'd pull
|
||||
# out the kick and send that, but it's a rare edge case
|
||||
# so we don't bother for now (the server that sent the
|
||||
# kick should send it out if its online).
|
||||
new_pdus = await filter_events_for_server(
|
||||
self._storage, self._destination, new_pdus, redact=False
|
||||
)
|
||||
|
||||
# If we've filtered out all the extremities, fall back to
|
||||
# sending the original event. This should ensure that the
|
||||
# server gets at least some of missed events (especially if
|
||||
|
||||
@@ -319,7 +319,7 @@ class DirectoryHandler:
|
||||
Raises:
|
||||
ShadowBanError if the requester has been shadow-banned.
|
||||
"""
|
||||
alias_event = await self.state.get_current_state(
|
||||
alias_event = await self.store.get_current_state_event(
|
||||
room_id, EventTypes.CanonicalAlias, ""
|
||||
)
|
||||
|
||||
|
||||
@@ -353,7 +353,7 @@ class FederationHandler:
|
||||
# First we try hosts that are already in the room
|
||||
# TODO: HEURISTIC ALERT.
|
||||
|
||||
curr_state = await self.state_handler.get_current_state(room_id)
|
||||
curr_state = await self.store.get_current_state(room_id)
|
||||
|
||||
curr_domains = get_domains_from_state(curr_state)
|
||||
|
||||
|
||||
@@ -463,7 +463,9 @@ class FederationEventHandler:
|
||||
with nested_logging_context(suffix=event.event_id):
|
||||
context = await self._state_handler.compute_event_context(
|
||||
event,
|
||||
old_state=state,
|
||||
state_ids_before_event={
|
||||
(e.type, e.state_key): e.event_id for e in state
|
||||
},
|
||||
partial_state=partial_state,
|
||||
)
|
||||
|
||||
@@ -501,7 +503,7 @@ class FederationEventHandler:
|
||||
# build a new state group for it if need be
|
||||
context = await self._state_handler.compute_event_context(
|
||||
event,
|
||||
old_state=state,
|
||||
state_ids_before_event=state,
|
||||
)
|
||||
if context.partial_state:
|
||||
# this can happen if some or all of the event's prev_events still have
|
||||
@@ -765,7 +767,7 @@ class FederationEventHandler:
|
||||
|
||||
async def _resolve_state_at_missing_prevs(
|
||||
self, dest: str, event: EventBase
|
||||
) -> Optional[Iterable[EventBase]]:
|
||||
) -> Optional[StateMap[str]]:
|
||||
"""Calculate the state at an event with missing prev_events.
|
||||
|
||||
This is used when we have pulled a batch of events from a remote server, and
|
||||
@@ -792,8 +794,8 @@ class FederationEventHandler:
|
||||
event: an event to check for missing prevs.
|
||||
|
||||
Returns:
|
||||
if we already had all the prev events, `None`. Otherwise, returns a list of
|
||||
the events in the state at `event`.
|
||||
if we already had all the prev events, `None`. Otherwise, returns
|
||||
the state at `event`.
|
||||
"""
|
||||
room_id = event.room_id
|
||||
event_id = event.event_id
|
||||
@@ -837,13 +839,7 @@ class FederationEventHandler:
|
||||
dest, room_id, p
|
||||
)
|
||||
|
||||
remote_state_map = {
|
||||
(x.type, x.state_key): x.event_id for x in remote_state
|
||||
}
|
||||
state_maps.append(remote_state_map)
|
||||
|
||||
for x in remote_state:
|
||||
event_map[x.event_id] = x
|
||||
state_maps.append(remote_state)
|
||||
|
||||
room_version = await self._store.get_room_version_id(room_id)
|
||||
state_map = await self._state_resolution_handler.resolve_events_with_store(
|
||||
@@ -854,19 +850,6 @@ class FederationEventHandler:
|
||||
state_res_store=StateResolutionStore(self._store),
|
||||
)
|
||||
|
||||
# We need to give _process_received_pdu the actual state events
|
||||
# rather than event ids, so generate that now.
|
||||
|
||||
# First though we need to fetch all the events that are in
|
||||
# state_map, so we can build up the state below.
|
||||
evs = await self._store.get_events(
|
||||
list(state_map.values()),
|
||||
get_prev_content=False,
|
||||
redact_behaviour=EventRedactBehaviour.as_is,
|
||||
)
|
||||
event_map.update(evs)
|
||||
|
||||
state = [event_map[e] for e in state_map.values()]
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Error attempting to resolve state at missing prev_events",
|
||||
@@ -878,14 +861,14 @@ class FederationEventHandler:
|
||||
"We can't get valid state history.",
|
||||
affected=event_id,
|
||||
)
|
||||
return state
|
||||
return state_map
|
||||
|
||||
async def _get_state_after_missing_prev_event(
|
||||
self,
|
||||
destination: str,
|
||||
room_id: str,
|
||||
event_id: str,
|
||||
) -> List[EventBase]:
|
||||
) -> StateMap[str]:
|
||||
"""Requests all of the room state at a given event from a remote homeserver.
|
||||
|
||||
Args:
|
||||
@@ -894,7 +877,7 @@ class FederationEventHandler:
|
||||
event_id: The id of the event we want the state at.
|
||||
|
||||
Returns:
|
||||
A list of events in the state, including the event itself
|
||||
The state *after* the given event.
|
||||
"""
|
||||
(
|
||||
state_event_ids,
|
||||
@@ -913,15 +896,13 @@ class FederationEventHandler:
|
||||
desired_events = set(state_event_ids)
|
||||
desired_events.add(event_id)
|
||||
logger.debug("Fetching %i events from cache/store", len(desired_events))
|
||||
fetched_events = await self._store.get_events(
|
||||
desired_events, allow_rejected=True
|
||||
)
|
||||
have_events = await self._store.have_seen_events(room_id, desired_events)
|
||||
|
||||
missing_desired_events = desired_events - fetched_events.keys()
|
||||
missing_desired_events = desired_events - have_events
|
||||
logger.debug(
|
||||
"We are missing %i events (got %i)",
|
||||
len(missing_desired_events),
|
||||
len(fetched_events),
|
||||
len(have_events),
|
||||
)
|
||||
|
||||
# We probably won't need most of the auth events, so let's just check which
|
||||
@@ -932,7 +913,7 @@ class FederationEventHandler:
|
||||
# already have a bunch of the state events. It would be nice if the
|
||||
# federation api gave us a way of finding out which we actually need.
|
||||
|
||||
missing_auth_events = set(auth_event_ids) - fetched_events.keys()
|
||||
missing_auth_events = set(auth_event_ids) - have_events
|
||||
missing_auth_events.difference_update(
|
||||
await self._store.have_seen_events(room_id, missing_auth_events)
|
||||
)
|
||||
@@ -958,47 +939,54 @@ class FederationEventHandler:
|
||||
destination=destination, room_id=room_id, event_ids=missing_events
|
||||
)
|
||||
|
||||
# we need to make sure we re-load from the database to get the rejected
|
||||
# state correct.
|
||||
fetched_events.update(
|
||||
await self._store.get_events(missing_desired_events, allow_rejected=True)
|
||||
)
|
||||
event_metadata = await self._store.get_metadata_for_events(state_event_ids)
|
||||
|
||||
# check for events which were in the wrong room.
|
||||
#
|
||||
# this can happen if a remote server claims that the state or
|
||||
# auth_events at an event in room A are actually events in room B
|
||||
|
||||
bad_events = [
|
||||
(event_id, event.room_id)
|
||||
for event_id, event in fetched_events.items()
|
||||
if event.room_id != room_id
|
||||
]
|
||||
event_metadata = await self._store.get_metadata_for_events(state_event_ids)
|
||||
|
||||
for bad_event_id, bad_room_id in bad_events:
|
||||
# This is a bogus situation, but since we may only discover it a long time
|
||||
# after it happened, we try our best to carry on, by just omitting the
|
||||
# bad events from the returned state set.
|
||||
logger.warning(
|
||||
"Remote server %s claims event %s in room %s is an auth/state "
|
||||
"event in room %s",
|
||||
destination,
|
||||
bad_event_id,
|
||||
bad_room_id,
|
||||
room_id,
|
||||
)
|
||||
state_map = {}
|
||||
|
||||
del fetched_events[bad_event_id]
|
||||
for state_event_id, metadata in event_metadata.items():
|
||||
if metadata.room_id != room_id:
|
||||
# This is a bogus situation, but since we may only discover it a long time
|
||||
# after it happened, we try our best to carry on, by just omitting the
|
||||
# bad events from the returned state set.
|
||||
logger.warning(
|
||||
"Remote server %s claims event %s in room %s is an auth/state "
|
||||
"event in room %s",
|
||||
destination,
|
||||
state_event_id,
|
||||
metadata.room_id,
|
||||
room_id,
|
||||
)
|
||||
continue
|
||||
|
||||
if metadata.state_key is None:
|
||||
logger.warning(
|
||||
"Remote server gave us non-state event in state: %s", state_event_id
|
||||
)
|
||||
continue
|
||||
|
||||
state_map[(metadata.event_type, metadata.state_key)] = state_event_id
|
||||
|
||||
# if we couldn't get the prev event in question, that's a problem.
|
||||
remote_event = fetched_events.get(event_id)
|
||||
remote_event = await self._store.get_event(
|
||||
event_id,
|
||||
allow_none=True,
|
||||
allow_rejected=True,
|
||||
redact_behaviour=EventRedactBehaviour.as_is,
|
||||
)
|
||||
if not remote_event:
|
||||
raise Exception("Unable to get missing prev_event %s" % (event_id,))
|
||||
|
||||
# missing state at that event is a warning, not a blocker
|
||||
# XXX: this doesn't sound right? it means that we'll end up with incomplete
|
||||
# state.
|
||||
failed_to_fetch = desired_events - fetched_events.keys()
|
||||
failed_to_fetch = desired_events - event_metadata.keys()
|
||||
if failed_to_fetch:
|
||||
logger.warning(
|
||||
"Failed to fetch missing state events for %s %s",
|
||||
@@ -1006,14 +994,12 @@ class FederationEventHandler:
|
||||
failed_to_fetch,
|
||||
)
|
||||
|
||||
remote_state = [
|
||||
fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events
|
||||
]
|
||||
|
||||
if remote_event.is_state() and remote_event.rejected_reason is None:
|
||||
remote_state.append(remote_event)
|
||||
state_map[
|
||||
(remote_event.type, remote_event.state_key)
|
||||
] = remote_event.event_id
|
||||
|
||||
return remote_state
|
||||
return state_map
|
||||
|
||||
async def _get_state_and_persist(
|
||||
self, destination: str, room_id: str, event_id: str
|
||||
@@ -1040,7 +1026,7 @@ class FederationEventHandler:
|
||||
self,
|
||||
origin: str,
|
||||
event: EventBase,
|
||||
state: Optional[Iterable[EventBase]],
|
||||
state: Optional[StateMap[str]],
|
||||
backfilled: bool = False,
|
||||
) -> None:
|
||||
"""Called when we have a new non-outlier event.
|
||||
@@ -1074,7 +1060,7 @@ class FederationEventHandler:
|
||||
|
||||
try:
|
||||
context = await self._state_handler.compute_event_context(
|
||||
event, old_state=state
|
||||
event, state_ids_before_event=state
|
||||
)
|
||||
context = await self._check_event_auth(
|
||||
origin,
|
||||
@@ -1558,14 +1544,14 @@ class FederationEventHandler:
|
||||
if guest_access == GuestAccess.CAN_JOIN:
|
||||
return
|
||||
|
||||
current_state_map = await self._state_handler.get_current_state(event.room_id)
|
||||
current_state = list(current_state_map.values())
|
||||
await self._get_room_member_handler().kick_guest_users(current_state)
|
||||
current_state = await self._store.get_current_state(event.room_id)
|
||||
current_state_list = list(current_state.values())
|
||||
await self._get_room_member_handler().kick_guest_users(current_state_list)
|
||||
|
||||
async def _check_for_soft_fail(
|
||||
self,
|
||||
event: EventBase,
|
||||
state: Optional[Iterable[EventBase]],
|
||||
state: Optional[StateMap[str]],
|
||||
origin: str,
|
||||
) -> None:
|
||||
"""Checks if we should soft fail the event; if so, marks the event as
|
||||
@@ -1588,6 +1574,9 @@ class FederationEventHandler:
|
||||
room_version = await self._store.get_room_version_id(event.room_id)
|
||||
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
|
||||
|
||||
# The event types we want to pull from the "current" state.
|
||||
auth_types = auth_types_for_event(room_version_obj, event)
|
||||
|
||||
# Calculate the "current state".
|
||||
if state is not None:
|
||||
# If we're explicitly given the state then we won't have all the
|
||||
@@ -1602,20 +1591,24 @@ class FederationEventHandler:
|
||||
# given state at the event. This should correctly handle cases
|
||||
# like bans, especially with state res v2.
|
||||
|
||||
state_sets_d = await self._state_store.get_state_groups(
|
||||
state_sets_d = await self._state_store.get_state_groups_ids(
|
||||
event.room_id, extrem_ids
|
||||
)
|
||||
state_sets: List[Iterable[EventBase]] = list(state_sets_d.values())
|
||||
state_sets: List[StateMap[str]] = list(state_sets_d.values())
|
||||
state_sets.append(state)
|
||||
current_states = await self._state_handler.resolve_events(
|
||||
room_version, state_sets, event
|
||||
|
||||
current_state_ids = (
|
||||
await self._state_resolution_handler.resolve_events_with_store(
|
||||
event.room_id,
|
||||
room_version,
|
||||
state_sets,
|
||||
event_map={},
|
||||
state_res_store=StateResolutionStore(self._store),
|
||||
)
|
||||
)
|
||||
current_state_ids: StateMap[str] = {
|
||||
k: e.event_id for k, e in current_states.items()
|
||||
}
|
||||
else:
|
||||
current_state_ids = await self._state_handler.get_current_state_ids(
|
||||
event.room_id, latest_event_ids=extrem_ids
|
||||
current_state_ids = await self._store.get_filtered_current_state_ids(
|
||||
event.room_id, StateFilter.from_types(auth_types)
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
@@ -1625,7 +1618,6 @@ class FederationEventHandler:
|
||||
)
|
||||
|
||||
# Now check if event pass auth against said current state
|
||||
auth_types = auth_types_for_event(room_version_obj, event)
|
||||
current_state_ids_list = [
|
||||
e for k, e in current_state_ids.items() if k in auth_types
|
||||
]
|
||||
|
||||
@@ -190,7 +190,7 @@ class InitialSyncHandler:
|
||||
if event.membership == Membership.JOIN:
|
||||
room_end_token = now_token.room_key
|
||||
deferred_room_state = run_in_background(
|
||||
self.state_handler.get_current_state, event.room_id
|
||||
self.store.get_current_state, event.room_id
|
||||
)
|
||||
elif event.membership == Membership.LEAVE:
|
||||
room_end_token = RoomStreamToken(
|
||||
@@ -404,7 +404,7 @@ class InitialSyncHandler:
|
||||
membership: str,
|
||||
is_peeking: bool,
|
||||
) -> JsonDict:
|
||||
current_state = await self.state.get_current_state(room_id=room_id)
|
||||
current_state = await self.store.get_current_state(room_id=room_id)
|
||||
|
||||
# TODO: These concurrently
|
||||
time_now = self.clock.time_msec()
|
||||
|
||||
@@ -117,7 +117,9 @@ class MessageHandler:
|
||||
)
|
||||
|
||||
if membership == Membership.JOIN:
|
||||
data = await self.state.get_current_state(room_id, event_type, state_key)
|
||||
data = await self.store.get_current_state_event(
|
||||
room_id, event_type, state_key
|
||||
)
|
||||
elif membership == Membership.LEAVE:
|
||||
key = (event_type, state_key)
|
||||
# If the membership is not JOIN, then the event ID should exist.
|
||||
@@ -1021,8 +1023,21 @@ class EventCreationHandler:
|
||||
#
|
||||
# TODO(faster_joins): figure out how this works, and make sure that the
|
||||
# old state is complete.
|
||||
old_state = await self.store.get_events_as_list(state_event_ids)
|
||||
context = await self.state.compute_event_context(event, old_state=old_state)
|
||||
metadata = await self.store.get_metadata_for_events(state_event_ids)
|
||||
|
||||
state_map = {}
|
||||
for event_id, data in metadata.items():
|
||||
if data.state_key is None:
|
||||
raise Exception(
|
||||
"Trying to set non-state event as state: %s", event_id
|
||||
)
|
||||
|
||||
state_map[(data.event_type, data.state_key)] = event_id
|
||||
|
||||
context = await self.state.compute_event_context(
|
||||
event,
|
||||
state_ids_before_event=state_map,
|
||||
)
|
||||
else:
|
||||
context = await self.state.compute_event_context(event)
|
||||
|
||||
|
||||
@@ -1399,7 +1399,7 @@ class TimestampLookupHandler:
|
||||
)
|
||||
|
||||
# Find other homeservers from the given state in the room
|
||||
curr_state = await self.state_handler.get_current_state(room_id)
|
||||
curr_state = await self.store.get_current_state(room_id)
|
||||
curr_domains = get_domains_from_state(curr_state)
|
||||
likely_domains = [
|
||||
domain for domain, depth in curr_domains if domain != self.server_name
|
||||
|
||||
@@ -1409,7 +1409,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||
txn_id: Optional[str],
|
||||
id_access_token: Optional[str] = None,
|
||||
) -> int:
|
||||
room_state = await self.state_handler.get_current_state(room_id)
|
||||
room_state = await self.store.get_filtered_current_state(
|
||||
room_id,
|
||||
StateFilter.from_types(
|
||||
[
|
||||
(EventTypes.Member, user.to_string()),
|
||||
(EventTypes.CanonicalAlias, ""),
|
||||
(EventTypes.Name, ""),
|
||||
(EventTypes.Create, ""),
|
||||
(EventTypes.JoinRules, ""),
|
||||
(EventTypes.RoomAvatar, ""),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
inviter_display_name = ""
|
||||
inviter_avatar_url = ""
|
||||
@@ -1805,7 +1817,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
|
||||
async def forget(self, user: UserID, room_id: str) -> None:
|
||||
user_id = user.to_string()
|
||||
|
||||
member = await self.state_handler.get_current_state(
|
||||
member = await self.store.get_current_state_event(
|
||||
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
|
||||
)
|
||||
membership = member.membership if member else None
|
||||
|
||||
@@ -348,7 +348,7 @@ class SearchHandler:
|
||||
state_results = {}
|
||||
if include_state:
|
||||
for room_id in {e.room_id for e in search_result.allowed_events}:
|
||||
state = await self.state_handler.get_current_state(room_id)
|
||||
state = await self.store.get_current_state(room_id)
|
||||
state_results[room_id] = list(state.values())
|
||||
|
||||
aggregations = await self._relations_handler.get_bundled_aggregations(
|
||||
|
||||
@@ -18,6 +18,8 @@ from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tup
|
||||
import attr
|
||||
from prometheus_client import Counter
|
||||
|
||||
from twisted.python import failure
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership, ReceiptTypes
|
||||
from synapse.api.filtering import FilterCollection
|
||||
from synapse.api.presence import UserPresenceState
|
||||
@@ -643,6 +645,13 @@ class SyncHandler:
|
||||
event: event of interest
|
||||
state_filter: The state filter used to fetch state from the database.
|
||||
"""
|
||||
f = failure.Failure()
|
||||
logger.info(
|
||||
"SYNC get_state_after_event in room %s",
|
||||
event.room_id,
|
||||
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
|
||||
)
|
||||
|
||||
state_ids = await self.state_store.get_state_ids_for_event(
|
||||
event.event_id, state_filter=state_filter or StateFilter.all()
|
||||
)
|
||||
|
||||
@@ -681,7 +681,7 @@ class Notifier:
|
||||
return joined_room_ids, True
|
||||
|
||||
async def _is_world_readable(self, room_id: str) -> bool:
|
||||
state = await self.state_handler.get_current_state(
|
||||
state = await self.store.get_current_state_event(
|
||||
room_id, EventTypes.RoomHistoryVisibility, ""
|
||||
)
|
||||
if state and "history_visibility" in state.content:
|
||||
|
||||
@@ -209,7 +209,12 @@ class BulkPushRuleEvaluator:
|
||||
rules_by_user = await self._get_rules_for_event(event, context)
|
||||
actions_by_user: Dict[str, List[Union[dict, str]]] = {}
|
||||
|
||||
room_members = await self.store.get_joined_users_from_context(event, context)
|
||||
# FIXME!!!
|
||||
# room_members = await self.store.get_joined_users_from_context(event, context)
|
||||
|
||||
room_member_count = await self.store.get_number_joined_users_in_room(
|
||||
event.room_id
|
||||
)
|
||||
|
||||
(
|
||||
power_levels,
|
||||
@@ -217,7 +222,7 @@ class BulkPushRuleEvaluator:
|
||||
) = await self._get_power_levels_and_sender_level(event, context)
|
||||
|
||||
evaluator = PushRuleEvaluatorForEvent(
|
||||
event, len(room_members), sender_power_level, power_levels
|
||||
event, room_member_count, sender_power_level, power_levels
|
||||
)
|
||||
|
||||
# If the event is not a state event check if any users ignore the sender.
|
||||
@@ -234,9 +239,10 @@ class BulkPushRuleEvaluator:
|
||||
continue
|
||||
|
||||
display_name = None
|
||||
profile_info = room_members.get(uid)
|
||||
if profile_info:
|
||||
display_name = profile_info.display_name
|
||||
# FIXME!!!
|
||||
# profile_info = room_members.get(uid)
|
||||
# if profile_info:
|
||||
# display_name = profile_info.display_name
|
||||
|
||||
if not display_name:
|
||||
# Handle the case where we are pushing a membership event to
|
||||
@@ -387,77 +393,27 @@ class RulesForRoom:
|
||||
self.room_push_rule_cache_metrics.inc_hits()
|
||||
return self.data.rules_by_user
|
||||
|
||||
self.room_push_rule_cache_metrics.inc_misses()
|
||||
|
||||
ret_rules_by_user = {}
|
||||
missing_member_event_ids = {}
|
||||
if state_group and self.data.state_group == context.prev_group:
|
||||
# If we have a simple delta then we can reuse most of the previous
|
||||
# results.
|
||||
ret_rules_by_user = self.data.rules_by_user
|
||||
current_state_ids = context.delta_ids
|
||||
|
||||
push_rules_delta_state_cache_metric.inc_hits()
|
||||
else:
|
||||
current_state_ids = await context.get_current_state_ids()
|
||||
push_rules_delta_state_cache_metric.inc_misses()
|
||||
# Ensure the state IDs exist.
|
||||
assert current_state_ids is not None
|
||||
|
||||
push_rules_state_size_counter.inc(len(current_state_ids))
|
||||
|
||||
logger.debug(
|
||||
"Looking for member changes in %r %r", state_group, current_state_ids
|
||||
local_users = await self.store.get_local_users_in_room(
|
||||
self.room_id, on_invalidate=self.invalidate_all_cb
|
||||
)
|
||||
|
||||
# Loop through to see which member events we've seen and have rules
|
||||
# for and which we need to fetch
|
||||
for key in current_state_ids:
|
||||
typ, user_id = key
|
||||
if typ != EventTypes.Member:
|
||||
continue
|
||||
if event.type == EventTypes.Member and event.membership == Membership.JOIN:
|
||||
if self.is_mine_id(event.state_key):
|
||||
local_users = list(local_users)
|
||||
local_users.append(event.state_key)
|
||||
|
||||
if user_id in self.data.uninteresting_user_set:
|
||||
continue
|
||||
ret_rules_by_user = await self.store.bulk_get_push_rules(
|
||||
local_users, on_invalidate=self.invalidate_all_cb
|
||||
)
|
||||
|
||||
if not self.is_mine_id(user_id):
|
||||
self.data.uninteresting_user_set.add(user_id)
|
||||
continue
|
||||
logger.info("Users in room: %s", local_users)
|
||||
|
||||
if self.store.get_if_app_services_interested_in_user(user_id):
|
||||
self.data.uninteresting_user_set.add(user_id)
|
||||
continue
|
||||
|
||||
event_id = current_state_ids[key]
|
||||
|
||||
res = self.data.member_map.get(event_id, None)
|
||||
if res:
|
||||
if res.membership == Membership.JOIN:
|
||||
rules = self.data.rules_by_user.get(res.user_id, None)
|
||||
if rules:
|
||||
ret_rules_by_user[res.user_id] = rules
|
||||
continue
|
||||
|
||||
# If a user has left a room we remove their push rule. If they
|
||||
# joined then we re-add it later in _update_rules_with_member_event_ids
|
||||
ret_rules_by_user.pop(user_id, None)
|
||||
missing_member_event_ids[user_id] = event_id
|
||||
|
||||
if missing_member_event_ids:
|
||||
# If we have some member events we haven't seen, look them up
|
||||
# and fetch push rules for them if appropriate.
|
||||
logger.debug("Found new member events %r", missing_member_event_ids)
|
||||
await self._update_rules_with_member_event_ids(
|
||||
ret_rules_by_user, missing_member_event_ids, state_group, event
|
||||
)
|
||||
else:
|
||||
# The push rules didn't change but lets update the cache anyway
|
||||
self.update_cache(
|
||||
self.data.sequence,
|
||||
members={}, # There were no membership changes
|
||||
rules_by_user=ret_rules_by_user,
|
||||
state_group=state_group,
|
||||
)
|
||||
self.update_cache(
|
||||
self.data.sequence,
|
||||
members={}, # There were no membership changes
|
||||
rules_by_user=ret_rules_by_user,
|
||||
state_group=state_group,
|
||||
)
|
||||
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug(
|
||||
@@ -465,67 +421,6 @@ class RulesForRoom:
|
||||
)
|
||||
return ret_rules_by_user
|
||||
|
||||
async def _update_rules_with_member_event_ids(
|
||||
self,
|
||||
ret_rules_by_user: Dict[str, list],
|
||||
member_event_ids: Dict[str, str],
|
||||
state_group: Optional[int],
|
||||
event: EventBase,
|
||||
) -> None:
|
||||
"""Update the partially filled rules_by_user dict by fetching rules for
|
||||
any newly joined users in the `member_event_ids` list.
|
||||
|
||||
Args:
|
||||
ret_rules_by_user: Partially filled dict of push rules. Gets
|
||||
updated with any new rules.
|
||||
member_event_ids: Dict of user id to event id for membership events
|
||||
that have happened since the last time we filled rules_by_user
|
||||
state_group: The state group we are currently computing push rules
|
||||
for. Used when updating the cache.
|
||||
event: The event we are currently computing push rules for.
|
||||
"""
|
||||
sequence = self.data.sequence
|
||||
|
||||
members = await self.store.get_membership_from_event_ids(
|
||||
member_event_ids.values()
|
||||
)
|
||||
|
||||
# If the event is a join event then it will be in current state events
|
||||
# map but not in the DB, so we have to explicitly insert it.
|
||||
if event.type == EventTypes.Member:
|
||||
for event_id in member_event_ids.values():
|
||||
if event_id == event.event_id:
|
||||
members[event_id] = EventIdMembership(
|
||||
user_id=event.state_key, membership=event.membership
|
||||
)
|
||||
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug("Found members %r: %r", self.room_id, members.values())
|
||||
|
||||
joined_user_ids = {
|
||||
entry.user_id
|
||||
for entry in members.values()
|
||||
if entry and entry.membership == Membership.JOIN
|
||||
}
|
||||
|
||||
logger.debug("Joined: %r", joined_user_ids)
|
||||
|
||||
# Previously we only considered users with pushers or read receipts in that
|
||||
# room. We can't do this anymore because we use push actions to calculate unread
|
||||
# counts, which don't rely on the user having pushers or sent a read receipt into
|
||||
# the room. Therefore we just need to filter for local users here.
|
||||
user_ids = list(filter(self.is_mine_id, joined_user_ids))
|
||||
|
||||
rules_by_user = await self.store.bulk_get_push_rules(
|
||||
user_ids, on_invalidate=self.invalidate_all_cb
|
||||
)
|
||||
|
||||
ret_rules_by_user.update(
|
||||
item for item in rules_by_user.items() if item[0] is not None
|
||||
)
|
||||
|
||||
self.update_cache(sequence, members, ret_rules_by_user, state_group)
|
||||
|
||||
def update_cache(
|
||||
self,
|
||||
sequence: int,
|
||||
|
||||
@@ -34,6 +34,7 @@ from synapse.rest.admin._base import (
|
||||
assert_user_is_admin,
|
||||
)
|
||||
from synapse.storage.databases.main.room import RoomSortOrder
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.types import JsonDict, RoomID, UserID, create_requester
|
||||
from synapse.util import json_decoder
|
||||
|
||||
@@ -447,7 +448,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
|
||||
super().__init__(hs)
|
||||
self.auth = hs.get_auth()
|
||||
self.admin_handler = hs.get_admin_handler()
|
||||
self.state_handler = hs.get_state_handler()
|
||||
self.store = hs.get_datastores().main
|
||||
self.is_mine = hs.is_mine
|
||||
|
||||
async def on_POST(
|
||||
@@ -489,8 +490,9 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
|
||||
)
|
||||
|
||||
# send invite if room has "JoinRules.INVITE"
|
||||
room_state = await self.state_handler.get_current_state(room_id)
|
||||
join_rules_event = room_state.get((EventTypes.JoinRules, ""))
|
||||
join_rules_event = await self.store.get_current_state_event(
|
||||
room_id, EventTypes.JoinRules, ""
|
||||
)
|
||||
if join_rules_event:
|
||||
if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC):
|
||||
# update_membership with an action of "invite" can raise a
|
||||
@@ -552,12 +554,22 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
|
||||
user_to_add = content.get("user_id", requester.user.to_string())
|
||||
|
||||
# Figure out which local users currently have power in the room, if any.
|
||||
room_state = await self.state_handler.get_current_state(room_id)
|
||||
if not room_state:
|
||||
filtered_room_state = await self.store.get_filtered_current_state(
|
||||
room_id,
|
||||
StateFilter.from_types(
|
||||
[
|
||||
(EventTypes.Create, ""),
|
||||
(EventTypes.PowerLevels, ""),
|
||||
(EventTypes.JoinRules, ""),
|
||||
(EventTypes.Member, user_to_add),
|
||||
]
|
||||
),
|
||||
)
|
||||
if not filtered_room_state:
|
||||
raise SynapseError(HTTPStatus.BAD_REQUEST, "Server not in room")
|
||||
|
||||
create_event = room_state[(EventTypes.Create, "")]
|
||||
power_levels = room_state.get((EventTypes.PowerLevels, ""))
|
||||
create_event = filtered_room_state[(EventTypes.Create, "")]
|
||||
power_levels = filtered_room_state.get((EventTypes.PowerLevels, ""))
|
||||
|
||||
if power_levels is not None:
|
||||
# We pick the local user with the highest power.
|
||||
@@ -633,7 +645,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
|
||||
|
||||
# Now we check if the user we're granting admin rights to is already in
|
||||
# the room. If not and it's not a public room we invite them.
|
||||
member_event = room_state.get((EventTypes.Member, user_to_add))
|
||||
member_event = filtered_room_state.get((EventTypes.Member, user_to_add))
|
||||
is_joined = False
|
||||
if member_event:
|
||||
is_joined = member_event.content["membership"] in (
|
||||
@@ -644,7 +656,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
|
||||
if is_joined:
|
||||
return HTTPStatus.OK, {}
|
||||
|
||||
join_rules = room_state.get((EventTypes.JoinRules, ""))
|
||||
join_rules = filtered_room_state.get((EventTypes.JoinRules, ""))
|
||||
is_public = False
|
||||
if join_rules:
|
||||
is_public = join_rules.content.get("join_rule") == JoinRules.PUBLIC
|
||||
|
||||
@@ -673,7 +673,7 @@ class RoomEventServlet(RestServlet):
|
||||
if include_unredacted_content and not await self.auth.is_server_admin(
|
||||
requester.user
|
||||
):
|
||||
power_level_event = await self._state.get_current_state(
|
||||
power_level_event = await self._store.get_current_state_event(
|
||||
room_id, EventTypes.PowerLevels, ""
|
||||
)
|
||||
|
||||
|
||||
@@ -178,8 +178,8 @@ class ResourceLimitsServerNotices:
|
||||
currently_blocked = False
|
||||
pinned_state_event = None
|
||||
try:
|
||||
pinned_state_event = await self._state.get_current_state(
|
||||
room_id, event_type=EventTypes.Pinned
|
||||
pinned_state_event = await self._store.get_current_state_event(
|
||||
room_id, event_type=EventTypes.Pinned, state_key=""
|
||||
)
|
||||
except AuthError:
|
||||
# The user has yet to join the server notices room
|
||||
|
||||
@@ -32,13 +32,11 @@ from typing import (
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
|
||||
import attr
|
||||
from frozendict import frozendict
|
||||
from prometheus_client import Counter, Histogram
|
||||
from typing_extensions import Literal
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
|
||||
@@ -132,85 +130,20 @@ class StateHandler:
|
||||
self._state_resolution_handler = hs.get_state_resolution_handler()
|
||||
self._storage = hs.get_storage()
|
||||
|
||||
@overload
|
||||
async def get_current_state(
|
||||
self,
|
||||
room_id: str,
|
||||
event_type: Literal[None] = None,
|
||||
state_key: str = "",
|
||||
latest_event_ids: Optional[List[str]] = None,
|
||||
) -> StateMap[EventBase]:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def get_current_state(
|
||||
self,
|
||||
room_id: str,
|
||||
event_type: str,
|
||||
state_key: str = "",
|
||||
latest_event_ids: Optional[List[str]] = None,
|
||||
) -> Optional[EventBase]:
|
||||
...
|
||||
|
||||
async def get_current_state(
|
||||
self,
|
||||
room_id: str,
|
||||
event_type: Optional[str] = None,
|
||||
state_key: str = "",
|
||||
latest_event_ids: Optional[List[str]] = None,
|
||||
) -> Union[Optional[EventBase], StateMap[EventBase]]:
|
||||
"""Retrieves the current state for the room. This is done by
|
||||
calling `get_latest_events_in_room` to get the leading edges of the
|
||||
event graph and then resolving any of the state conflicts.
|
||||
|
||||
This is equivalent to getting the state of an event that were to send
|
||||
next before receiving any new events.
|
||||
|
||||
Returns:
|
||||
If `event_type` is specified, then the method returns only the one
|
||||
event (or None) with that `event_type` and `state_key`.
|
||||
|
||||
Otherwise, a map from (type, state_key) to event.
|
||||
"""
|
||||
if not latest_event_ids:
|
||||
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
|
||||
assert latest_event_ids is not None
|
||||
|
||||
logger.debug("calling resolve_state_groups from get_current_state")
|
||||
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
||||
state = ret.state
|
||||
|
||||
if event_type:
|
||||
event_id = state.get((event_type, state_key))
|
||||
event = None
|
||||
if event_id:
|
||||
event = await self.store.get_event(event_id, allow_none=True)
|
||||
return event
|
||||
|
||||
state_map = await self.store.get_events(
|
||||
list(state.values()), get_prev_content=False
|
||||
)
|
||||
return {
|
||||
key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
|
||||
}
|
||||
|
||||
async def get_current_state_ids(
|
||||
self, room_id: str, latest_event_ids: Optional[Collection[str]] = None
|
||||
self,
|
||||
room_id: str,
|
||||
latest_event_ids: Collection[str],
|
||||
) -> StateMap[str]:
|
||||
"""Get the current state, or the state at a set of events, for a room
|
||||
|
||||
Args:
|
||||
room_id:
|
||||
latest_event_ids: if given, the forward extremities to resolve. If
|
||||
None, we look them up from the database (via a cache).
|
||||
latest_event_ids: The forward extremities to resolve.
|
||||
|
||||
Returns:
|
||||
the state dict, mapping from (event_type, state_key) -> event_id
|
||||
"""
|
||||
if not latest_event_ids:
|
||||
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
|
||||
assert latest_event_ids is not None
|
||||
|
||||
logger.debug("calling resolve_state_groups from get_current_state_ids")
|
||||
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
||||
return ret.state
|
||||
@@ -239,10 +172,6 @@ class StateHandler:
|
||||
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
||||
return await self.store.get_joined_users_from_state(room_id, entry)
|
||||
|
||||
async def get_current_hosts_in_room(self, room_id: str) -> FrozenSet[str]:
|
||||
event_ids = await self.store.get_latest_event_ids_in_room(room_id)
|
||||
return await self.get_hosts_in_room_at_events(room_id, event_ids)
|
||||
|
||||
async def get_hosts_in_room_at_events(
|
||||
self, room_id: str, event_ids: Collection[str]
|
||||
) -> FrozenSet[str]:
|
||||
@@ -261,7 +190,7 @@ class StateHandler:
|
||||
async def compute_event_context(
|
||||
self,
|
||||
event: EventBase,
|
||||
old_state: Optional[Iterable[EventBase]] = None,
|
||||
state_ids_before_event: Optional[StateMap[str]] = None,
|
||||
partial_state: bool = False,
|
||||
) -> EventContext:
|
||||
"""Build an EventContext structure for a non-outlier event.
|
||||
@@ -273,12 +202,12 @@ class StateHandler:
|
||||
|
||||
Args:
|
||||
event:
|
||||
old_state: The state at the event if it can't be
|
||||
state_ids_before_event: The state at the event if it can't be
|
||||
calculated from existing events. This is normally only specified
|
||||
when receiving an event from federation where we don't have the
|
||||
prev events for, e.g. when backfilling.
|
||||
partial_state: True if `old_state` is partial and omits non-critical
|
||||
membership events
|
||||
partial_state: True if `state_ids_before_event` is partial and omits
|
||||
non-critical membership events
|
||||
Returns:
|
||||
The event context.
|
||||
"""
|
||||
@@ -288,11 +217,7 @@ class StateHandler:
|
||||
#
|
||||
# first of all, figure out the state before the event
|
||||
#
|
||||
if old_state:
|
||||
# if we're given the state before the event, then we use that
|
||||
state_ids_before_event: StateMap[str] = {
|
||||
(s.type, s.state_key): s.event_id for s in old_state
|
||||
}
|
||||
if state_ids_before_event:
|
||||
state_group_before_event = None
|
||||
state_group_before_event_prev_group = None
|
||||
deltas_to_state_group_before_event = None
|
||||
|
||||
@@ -74,6 +74,9 @@ class SQLBaseStore(metaclass=ABCMeta):
|
||||
self._attempt_to_invalidate_cache(
|
||||
"get_users_in_room_with_profiles", (room_id,)
|
||||
)
|
||||
self._attempt_to_invalidate_cache(
|
||||
"get_number_joined_users_in_room.invalidate", (room_id,)
|
||||
)
|
||||
|
||||
# Purge other caches based on room state.
|
||||
self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
|
||||
|
||||
@@ -217,6 +217,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||
if etype == EventTypes.Member:
|
||||
self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
|
||||
self.get_invited_rooms_for_local_user.invalidate((state_key,))
|
||||
self.get_local_users_in_room.invalidate((room_id,))
|
||||
self.get_number_joined_users_in_room((room_id,))
|
||||
|
||||
if relates_to:
|
||||
self.get_relations_for_event.invalidate((relates_to,))
|
||||
|
||||
@@ -1766,6 +1766,14 @@ class PersistEventsStore:
|
||||
self.store.get_invited_rooms_for_local_user.invalidate,
|
||||
(event.state_key,),
|
||||
)
|
||||
txn.call_after(
|
||||
self.store.get_local_users_in_room.invalidate,
|
||||
(event.room_id,),
|
||||
)
|
||||
txn.call_after(
|
||||
self.store.get_number_joined_users_in_room.invalidate,
|
||||
(event.room_id,),
|
||||
)
|
||||
|
||||
# The `_get_membership_from_event_id` is immutable, except for the
|
||||
# case where we look up an event *before* persisting it.
|
||||
|
||||
@@ -337,6 +337,15 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
"get_room_summary", _get_room_summary_txn
|
||||
)
|
||||
|
||||
@cached()
|
||||
async def get_number_joined_users_in_room(self, room_id: str) -> int:
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="current_state_events",
|
||||
keyvalues={"room_id": room_id, "membership": Membership.JOIN},
|
||||
retcol="COUNT(*)",
|
||||
desc="get_number_joined_users_in_room",
|
||||
)
|
||||
|
||||
@cached()
|
||||
async def get_invited_rooms_for_local_user(
|
||||
self, user_id: str
|
||||
@@ -444,6 +453,15 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
|
||||
return results
|
||||
|
||||
@cached()
|
||||
async def get_local_users_in_room(self, room_id: str) -> List[str]:
|
||||
return await self.db_pool.simple_select_onecol(
|
||||
table="local_current_membership",
|
||||
keyvalues={"room_id": room_id, "membership": Membership.JOIN},
|
||||
retcol="user_id",
|
||||
desc="get_local_users_in_room",
|
||||
)
|
||||
|
||||
async def get_local_current_membership_for_user_in_room(
|
||||
self, user_id: str, room_id: str
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
@@ -869,6 +887,29 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
|
||||
return True
|
||||
|
||||
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
|
||||
"""Get current hosts in room."""
|
||||
|
||||
if isinstance(self.database_engine, Sqlite3Engine):
|
||||
users = await self.get_users_in_room(room_id)
|
||||
return {get_domain_from_id(u) for u in users}
|
||||
|
||||
def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]:
|
||||
sql = """
|
||||
SELECT DISTINCT substring(state_key FROM '@[^:]*:(.*)$')
|
||||
FROM current_state_events
|
||||
WHERE
|
||||
type = 'm.room.member'
|
||||
AND membership = 'join'
|
||||
AND room_id = ?
|
||||
"""
|
||||
txn.execute(sql, (room_id,))
|
||||
return {d for d, in txn}
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_current_hosts_in_room", get_current_hosts_in_room_txn
|
||||
)
|
||||
|
||||
async def get_joined_hosts(
|
||||
self, room_id: str, state_entry: "_StateCacheEntry"
|
||||
) -> FrozenSet[str]:
|
||||
|
||||
@@ -16,6 +16,8 @@ import collections.abc
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
|
||||
@@ -26,6 +28,7 @@ from synapse.storage.database import (
|
||||
DatabasePool,
|
||||
LoggingDatabaseConnection,
|
||||
LoggingTransaction,
|
||||
make_in_list_sql_clause,
|
||||
)
|
||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
||||
@@ -43,6 +46,15 @@ logger = logging.getLogger(__name__)
|
||||
MAX_STATE_DELTA_HOPS = 100
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class EventMetadata:
|
||||
"""Returned by `get_metadata_for_events`"""
|
||||
|
||||
room_id: str
|
||||
event_type: str
|
||||
state_key: Optional[str]
|
||||
|
||||
|
||||
def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion:
|
||||
v = KNOWN_ROOM_VERSIONS.get(room_version_id)
|
||||
if not v:
|
||||
@@ -133,6 +145,36 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
|
||||
return room_version
|
||||
|
||||
async def get_metadata_for_events(
|
||||
self, event_ids: Collection[str]
|
||||
) -> Dict[str, EventMetadata]:
|
||||
"""Get some metadata (room_id, type, state_key) for the given events."""
|
||||
|
||||
clause, args = make_in_list_sql_clause(
|
||||
self.database_engine, "e.event_id", event_ids
|
||||
)
|
||||
|
||||
sql = f"""
|
||||
SELECT e.event_id, e.room_id, e.type, e.state_key FROM events AS e
|
||||
LEFT JOIN state_events USING (event_id)
|
||||
WHERE {clause}
|
||||
"""
|
||||
|
||||
def get_metadata_for_events_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Dict[str, EventMetadata]:
|
||||
txn.execute(sql, args)
|
||||
return {
|
||||
event_id: EventMetadata(
|
||||
room_id=room_id, event_type=event_type, state_key=state_key
|
||||
)
|
||||
for event_id, room_id, event_type, state_key in txn
|
||||
}
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_metadata_for_events", get_metadata_for_events_txn
|
||||
)
|
||||
|
||||
async def get_room_predecessor(self, room_id: str) -> Optional[JsonMapping]:
|
||||
"""Get the predecessor of an upgraded room if it exists.
|
||||
Otherwise return None.
|
||||
@@ -218,6 +260,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
"get_current_state_ids", _get_current_state_ids_txn
|
||||
)
|
||||
|
||||
async def get_current_state(self, room_id: str) -> StateMap[EventBase]:
|
||||
"""Same as `get_current_state_ids` but also fetches the events"""
|
||||
state_map_ids = await self.get_current_state_ids(room_id)
|
||||
|
||||
event_map = await self.get_events(list(state_map_ids.values()))
|
||||
|
||||
state_map = {}
|
||||
for key, event_id in state_map_ids.items():
|
||||
event = event_map.get(event_id)
|
||||
if event:
|
||||
state_map[key] = event
|
||||
|
||||
return state_map
|
||||
|
||||
# FIXME: how should this be cached?
|
||||
async def get_filtered_current_state_ids(
|
||||
self, room_id: str, state_filter: Optional[StateFilter] = None
|
||||
@@ -269,6 +325,39 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
"get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
|
||||
)
|
||||
|
||||
async def get_filtered_current_state(
|
||||
self, room_id: str, state_filter: Optional[StateFilter] = None
|
||||
) -> StateMap[EventBase]:
|
||||
"""Same as `get_filtered_current_state_ids` but also fetches the events"""
|
||||
state_map_ids = await self.get_filtered_current_state_ids(room_id, state_filter)
|
||||
|
||||
event_map = await self.get_events(list(state_map_ids.values()))
|
||||
|
||||
state_map = {}
|
||||
for key, event_id in state_map_ids.items():
|
||||
event = event_map.get(event_id)
|
||||
if event:
|
||||
state_map[key] = event
|
||||
|
||||
return state_map
|
||||
|
||||
async def get_current_state_event(
|
||||
self, room_id: str, event_type: str, state_key: str
|
||||
) -> Optional[EventBase]:
|
||||
"""Get the current state event for the given type/state_key."""
|
||||
|
||||
key = (event_type, state_key)
|
||||
state_map = await self.get_filtered_current_state_ids(
|
||||
room_id, StateFilter.from_types((key,))
|
||||
)
|
||||
event_id = state_map.get(key)
|
||||
|
||||
event = None
|
||||
if event_id:
|
||||
event = await self.get_event(event_id, allow_none=True)
|
||||
|
||||
return event
|
||||
|
||||
async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]:
|
||||
"""Get canonical alias for room, if any
|
||||
|
||||
|
||||
@@ -30,16 +30,16 @@ from tests.unittest import HomeserverTestCase, override_config
|
||||
|
||||
class FederationSenderReceiptsTestCases(HomeserverTestCase):
|
||||
def make_homeserver(self, reactor, clock):
|
||||
mock_state_handler = Mock(spec=["get_current_hosts_in_room"])
|
||||
# Ensure a new Awaitable is created for each call.
|
||||
mock_state_handler.get_current_hosts_in_room.return_value = make_awaitable(
|
||||
["test", "host2"]
|
||||
)
|
||||
return self.setup_test_homeserver(
|
||||
state_handler=mock_state_handler,
|
||||
hs = self.setup_test_homeserver(
|
||||
federation_transport_client=Mock(spec=["send_transaction"]),
|
||||
)
|
||||
|
||||
hs.get_datastores().main.get_current_hosts_in_room = Mock(
|
||||
return_value=make_awaitable(["test", "host2"])
|
||||
)
|
||||
|
||||
return hs
|
||||
|
||||
@override_config({"send_federation": True})
|
||||
def test_send_receipts(self):
|
||||
mock_send_transaction = (
|
||||
|
||||
@@ -207,7 +207,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
|
||||
|
||||
# the room should show that the new user is a member
|
||||
r = self.get_success(
|
||||
self.hs.get_state_handler().get_current_state(self._room_id)
|
||||
self.hs.get_datastores().main.get_current_state(self._room_id)
|
||||
)
|
||||
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
|
||||
|
||||
@@ -258,7 +258,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
|
||||
|
||||
# the room should show that the new user is a member
|
||||
r = self.get_success(
|
||||
self.hs.get_state_handler().get_current_state(self._room_id)
|
||||
self.hs.get_datastores().main.get_current_state(self._room_id)
|
||||
)
|
||||
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
|
||||
|
||||
|
||||
@@ -335,7 +335,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||
def _get_canonical_alias(self):
|
||||
"""Get the canonical alias state of the room."""
|
||||
return self.get_success(
|
||||
self.state_handler.get_current_state(
|
||||
self.store.get_current_state_event(
|
||||
self.room_id, EventTypes.CanonicalAlias, ""
|
||||
)
|
||||
)
|
||||
|
||||
@@ -276,7 +276,9 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
|
||||
# federation handler wanting to backfill the fake event.
|
||||
self.get_success(
|
||||
federation_event_handler._process_received_pdu(
|
||||
self.OTHER_SERVER_NAME, event, state=current_state
|
||||
self.OTHER_SERVER_NAME,
|
||||
event,
|
||||
state={(e.type, e.state_key): e.event_id for e in current_state},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -69,7 +69,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||
def persist_event(self, event, state=None):
|
||||
"""Persist the event, with optional state"""
|
||||
context = self.get_success(
|
||||
self.state.compute_event_context(event, old_state=state)
|
||||
self.state.compute_event_context(event, state_ids_before_event=state)
|
||||
)
|
||||
self.get_success(self.persistence.persist_event(event, context))
|
||||
|
||||
@@ -103,9 +103,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||
RoomVersions.V6,
|
||||
)
|
||||
|
||||
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
|
||||
state_before_gap = self.get_success(
|
||||
self.store.get_current_state_ids(self.room_id)
|
||||
)
|
||||
|
||||
self.persist_event(remote_event_2, state=state_before_gap.values())
|
||||
self.persist_event(remote_event_2, state=state_before_gap)
|
||||
|
||||
# Check the new extremity is just the new remote event.
|
||||
self.assert_extremities([remote_event_2.event_id])
|
||||
@@ -135,13 +137,14 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||
# setting. The state resolution across the old and new event will then
|
||||
# include it, and so the resolved state won't match the new state.
|
||||
state_before_gap = dict(
|
||||
self.get_success(self.state.get_current_state(self.room_id))
|
||||
self.get_success(self.store.get_current_state_ids(self.room_id))
|
||||
)
|
||||
state_before_gap.pop(("m.room.history_visibility", ""))
|
||||
|
||||
context = self.get_success(
|
||||
self.state.compute_event_context(
|
||||
remote_event_2, old_state=state_before_gap.values()
|
||||
remote_event_2,
|
||||
state_ids_before_event=state_before_gap,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -177,9 +180,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||
RoomVersions.V6,
|
||||
)
|
||||
|
||||
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
|
||||
state_before_gap = self.get_success(
|
||||
self.store.get_current_state_ids(self.room_id)
|
||||
)
|
||||
|
||||
self.persist_event(remote_event_2, state=state_before_gap.values())
|
||||
self.persist_event(remote_event_2, state=state_before_gap)
|
||||
|
||||
# Check the new extremity is just the new remote event.
|
||||
self.assert_extremities([remote_event_2.event_id])
|
||||
@@ -207,9 +212,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||
RoomVersions.V6,
|
||||
)
|
||||
|
||||
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
|
||||
state_before_gap = self.get_success(
|
||||
self.store.get_current_state_ids(self.room_id)
|
||||
)
|
||||
|
||||
self.persist_event(remote_event_2, state=state_before_gap.values())
|
||||
self.persist_event(remote_event_2, state=state_before_gap)
|
||||
|
||||
# Check the new extremity is just the new remote event.
|
||||
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
|
||||
@@ -247,9 +254,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||
RoomVersions.V6,
|
||||
)
|
||||
|
||||
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
|
||||
state_before_gap = self.get_success(
|
||||
self.store.get_current_state_ids(self.room_id)
|
||||
)
|
||||
|
||||
self.persist_event(remote_event_2, state=state_before_gap.values())
|
||||
self.persist_event(remote_event_2, state=state_before_gap)
|
||||
|
||||
# Check the new extremity is just the new remote event.
|
||||
self.assert_extremities([remote_event_2.event_id])
|
||||
@@ -289,9 +298,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||
RoomVersions.V6,
|
||||
)
|
||||
|
||||
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
|
||||
state_before_gap = self.get_success(
|
||||
self.store.get_current_state_ids(self.room_id)
|
||||
)
|
||||
|
||||
self.persist_event(remote_event_2, state=state_before_gap.values())
|
||||
self.persist_event(remote_event_2, state=state_before_gap)
|
||||
|
||||
# Check the new extremity is just the new remote event.
|
||||
self.assert_extremities([remote_event_2.event_id, local_message_event_id])
|
||||
@@ -323,9 +334,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
|
||||
RoomVersions.V6,
|
||||
)
|
||||
|
||||
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
|
||||
state_before_gap = self.get_success(
|
||||
self.store.get_current_state_ids(self.room_id)
|
||||
)
|
||||
|
||||
self.persist_event(remote_event_2, state=state_before_gap.values())
|
||||
self.persist_event(remote_event_2, state=state_before_gap)
|
||||
|
||||
# Check the new extremity is just the new remote event.
|
||||
self.assert_extremities([local_message_event_id, remote_event_2.event_id])
|
||||
|
||||
@@ -98,9 +98,8 @@ class PurgeTests(HomeserverTestCase):
|
||||
first = self.helper.send(self.room_id, body="test1")
|
||||
|
||||
# Get the current room state.
|
||||
state_handler = self.hs.get_state_handler()
|
||||
create_event = self.get_success(
|
||||
state_handler.get_current_state(self.room_id, "m.room.create", "")
|
||||
self.store.get_current_state_event(self.room_id, "m.room.create", "")
|
||||
)
|
||||
self.assertIsNotNone(create_event)
|
||||
|
||||
|
||||
@@ -442,7 +442,12 @@ class StateTestCase(unittest.TestCase):
|
||||
]
|
||||
|
||||
context = yield defer.ensureDeferred(
|
||||
self.state.compute_event_context(event, old_state=old_state)
|
||||
self.state.compute_event_context(
|
||||
event,
|
||||
state_ids_before_event={
|
||||
(e.type, e.state_key): e.event_id for e in old_state
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
|
||||
@@ -467,7 +472,12 @@ class StateTestCase(unittest.TestCase):
|
||||
]
|
||||
|
||||
context = yield defer.ensureDeferred(
|
||||
self.state.compute_event_context(event, old_state=old_state)
|
||||
self.state.compute_event_context(
|
||||
event,
|
||||
state_ids_before_event={
|
||||
(e.type, e.state_key): e.event_id for e in old_state
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
|
||||
|
||||
Reference in New Issue
Block a user