Compare commits

...

4 Commits

Author SHA1 Message Date
Erik Johnston
98cb4b8755 Add batching 2024-10-29 10:10:12 +00:00
Erik Johnston
9a482a61a9 fixup 2024-10-29 10:10:12 +00:00
Erik Johnston
fb751d3914 Use fast path 2024-10-29 10:10:12 +00:00
Erik Johnston
35d797a9c4 Add history visibility index table 2024-10-29 10:10:12 +00:00
6 changed files with 275 additions and 20 deletions

View File

@@ -49,7 +49,7 @@ from prometheus_client import Counter, Histogram
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.constants import EventTypes, HistoryVisibility, Membership
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
@@ -635,6 +635,44 @@ class EventsPersistenceStorageController:
room_id, [e for e, _ in chunk]
)
visibilities: Dict[str, str] = {}
with Measure(self._clock, "calculate_history_vis"):
# TODO: We only need to do this on changes, rather than looking
# up the state for every event
for event, context in events_and_contexts:
if (
backfilled
or event.internal_metadata.is_outlier()
or context.rejected
):
continue
state = await context.get_current_state_ids(
StateFilter.from_types([(EventTypes.RoomHistoryVisibility, "")])
)
# We're not an outlier
assert state is not None
history_visibility = HistoryVisibility.SHARED
history_visibility_event_id = state.get(
(EventTypes.RoomHistoryVisibility, "")
)
if history_visibility_event_id:
for event, _ in events_and_contexts:
if event.event_id == history_visibility_event_id:
history_visibility_event = event
break
else:
history_visibility_event = await self.main_store.get_event(
history_visibility_event_id,
get_prev_content=False,
)
history_visibility = history_visibility_event.content.get(
"history_visibility", HistoryVisibility.SHARED
)
visibilities[event.event_id] = history_visibility
await self.persist_events_store._persist_events_and_state_updates(
room_id,
chunk,
@@ -643,6 +681,7 @@ class EventsPersistenceStorageController:
use_negative_stream_ordering=backfilled,
inhibit_local_membership_updates=backfilled,
new_event_links=new_event_links,
visibilities=visibilities,
)
return replaced_events

View File

@@ -31,6 +31,7 @@ from typing import (
Generator,
Iterable,
List,
Mapping,
Optional,
Sequence,
Set,
@@ -271,6 +272,7 @@ class PersistEventsStore:
new_event_links: Dict[str, NewEventChainLinks],
use_negative_stream_ordering: bool = False,
inhibit_local_membership_updates: bool = False,
visibilities: Mapping[str, str] = {},
) -> None:
"""Persist a set of events alongside updates to the current state and
forward extremities tables.
@@ -355,6 +357,7 @@ class PersistEventsStore:
new_forward_extremities=new_forward_extremities,
new_event_links=new_event_links,
sliding_sync_table_changes=sliding_sync_table_changes,
visibilities=visibilities,
)
persist_event_counter.inc(len(events_and_contexts))
@@ -874,6 +877,7 @@ class PersistEventsStore:
new_forward_extremities: Optional[Set[str]],
new_event_links: Dict[str, NewEventChainLinks],
sliding_sync_table_changes: Optional[SlidingSyncTableChanges],
visibilities: Mapping[str, str] = {},
) -> None:
"""Insert some number of room events into the necessary database tables.
@@ -1027,6 +1031,52 @@ class PersistEventsStore:
txn, room_id, events_and_contexts
)
# Maintain the `history_visibility_ranges` table for the events persisted in
# this transaction.
#
# Each entry pairs the visibility supplied for an event with that event's
# stream ordering. Events with no entry in `visibilities` are ignored.
changes = [
(visibilities[event.event_id], event.internal_metadata.stream_ordering)
for event, context in events_and_contexts
if event.event_id in visibilities
]
if changes:
# Fetch the room's most recent range, i.e. the row with the highest
# `start_range`, if there is one.
sql = """
SELECT visibility, start_range FROM history_visibility_ranges
WHERE room_id = ?
ORDER BY start_range DESC
LIMIT 1
"""
txn.execute(sql, (room_id,))
row = txn.fetchone()
# `prev_visibility` / `start_range` describe the latest range seen so far.
# Both stay None when the room has no rows yet.
prev_visibility = None
start_range = None
if row:
(
prev_visibility,
start_range,
) = row
# NOTE(review): this assumes `changes` is in ascending stream ordering, so
# that each new range starts after the previous one - confirm with callers.
for new_visibility, stream_ordering in changes:
assert stream_ordering is not None
# Nothing is written while the visibility stays the same as the previous
# range's; a row is only added when it changes.
if new_visibility != prev_visibility:
if start_range is not None:
# End the previous range at the stream ordering where the new one starts.
self.db_pool.simple_update_one_txn(
txn,
table="history_visibility_ranges",
keyvalues={"room_id": room_id, "start_range": start_range},
updatevalues={"end_range": stream_ordering},
)
# Start a new range at this event. It is inserted with a NULL `end_range`,
# which is only filled in once a later change supersedes it.
self.db_pool.simple_insert_txn(
txn,
table="history_visibility_ranges",
values={
"room_id": room_id,
"visibility": new_visibility,
"start_range": stream_ordering,
"end_range": None,
},
)
prev_visibility = new_visibility
start_range = stream_ordering
def _persist_event_auth_chain_txn(
self,
txn: LoggingTransaction,

View File

@@ -14,14 +14,16 @@
import logging
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Set, cast
from typing import TYPE_CHECKING, Collection, Dict, List, Mapping, Optional, Set, cast
import attr
from synapse.api.errors import SlidingSyncUnknownPosition
from synapse.events import EventBase
from synapse.logging.opentracing import log_kv
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import LoggingTransaction
from synapse.storage.engines import PostgresEngine
from synapse.types import MultiWriterStreamToken, RoomStreamToken
from synapse.types.handlers.sliding_sync import (
HaveSentRoom,
@@ -451,6 +453,69 @@ class SlidingSyncStore(SQLBaseStore):
room_configs=room_configs,
)
async def get_visibility_for_events(
    self, room_id: str, events: Collection[EventBase]
) -> Mapping[str, Optional[str]]:
    """Look up the recorded history visibility for each of the given events.

    The visibility is read from the `history_visibility_ranges` table. An
    event matches the row for `room_id` whose `[start_range, end_range)`
    interval contains the event's stream ordering, where a NULL `end_range`
    is treated as unbounded.

    Args:
        room_id: The room whose ranges are consulted.
        events: The events to look up. Events that have no stream ordering
            cannot be matched against a range and are skipped.

    Returns:
        A map from event ID to visibility. Events for which no matching
        range was found are omitted.
    """
    # Only events with a stream ordering can be compared against the ranges;
    # comparing `None` with an integer would raise a `TypeError`.
    orderings_by_event_id = {
        event.event_id: event.internal_metadata.stream_ordering
        for event in events
        if event.internal_metadata.stream_ordering is not None
    }
    if not orderings_by_event_id:
        # Nothing to look up, so avoid a pointless database round trip.
        return {}

    def get_visibility_for_events_txn(
        txn: LoggingTransaction,
    ) -> Mapping[str, Optional[str]]:
        results: Dict[str, Optional[str]] = {}

        if isinstance(txn.database_engine, PostgresEngine):
            # Fetch every range containing at least one of the stream
            # orderings in a single query, then assign them to events below.
            sql = """
                SELECT start_range, end_range, visibility FROM history_visibility_ranges
                WHERE int8range(start_range, end_range, '[)') @> ANY(?::bigint[])
                AND room_id = ?
            """
            txn.execute(sql, (list(orderings_by_event_id.values()), room_id))
            ranges = [
                ((start_range, end_range), visibility)
                for start_range, end_range, visibility in txn
            ]

            for event_id, stream_ordering in orderings_by_event_id.items():
                for (start_range, end_range), visibility in ranges:
                    if stream_ordering < start_range:
                        continue
                    # `end_range` is exclusive; NULL means "still open".
                    if end_range is not None and end_range <= stream_ordering:
                        continue

                    results[event_id] = visibility
                    break
        else:
            # Other engines: issue one point lookup per event.
            sql = """
                SELECT visibility FROM history_visibility_ranges
                WHERE start_range <= ? AND (? < end_range OR end_range IS NULL)
                AND room_id = ?
            """
            for event_id, stream_ordering in orderings_by_event_id.items():
                txn.execute(sql, (stream_ordering, stream_ordering, room_id))
                row = txn.fetchone()
                if row is not None:
                    results[event_id] = row[0]

        return results

    return await self.db_pool.runInteraction(
        "get_visibility_for_events", get_visibility_for_events_txn
    )
@attr.s(auto_attribs=True, frozen=True)
class PerConnectionStateDB:

View File

@@ -0,0 +1,25 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2024 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
-- Records, per room, which history visibility applies to which range of
-- stream orderings. Each row covers the half-open interval
-- [start_range, end_range); a NULL end_range marks a range with no upper bound.
CREATE TABLE IF NOT EXISTS history_visibility_ranges (
room_id TEXT NOT NULL,
visibility TEXT NOT NULL,
-- Stream ordering at which this range begins (inclusive).
start_range BIGINT NOT NULL,
-- Stream ordering at which this range ends (exclusive), or NULL while open.
end_range BIGINT
);
-- Supports looking up the range containing a given stream ordering in a room.
CREATE INDEX history_visibility_ranges_idx ON history_visibility_ranges(room_id, start_range, end_range DESC);
-- A room has at most one range starting at any given stream ordering.
CREATE UNIQUE INDEX history_visibility_ranges_uniq_idx ON history_visibility_ranges(room_id, start_range);
-- NOTE: the GiST index over int8range(start_range, end_range, '[)') is not
-- created here; it lives in a separate delta together with the
-- `CREATE EXTENSION IF NOT EXISTS btree_gist` statement that precedes it.

View File

@@ -0,0 +1,15 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2024 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
-- Index each row of history_visibility_ranges as the half-open interval
-- [start_range, end_range), matching the int8range(...) containment (`@>`)
-- expression used when looking up the visibility for a set of events.
-- NOTE(review): btree_gist is presumably enabled so that the plain `room_id`
-- column can take part in the GiST index - confirm.
CREATE EXTENSION IF NOT EXISTS btree_gist;
CREATE INDEX history_visibility_ranges_idx_gist ON history_visibility_ranges USING gist(room_id, int8range(start_range, end_range, '[)'));

View File

@@ -105,6 +105,7 @@ async def filter_events_for_client(
The filtered events. The `unsigned` data is annotated with the membership state
of `user_id` at each event.
"""
# Filter out events that have been soft failed so that we don't relay them
# to clients.
events_before_filtering = events
@@ -117,13 +118,41 @@ async def filter_events_for_client(
[event.event_id for event in events],
)
types = (_HISTORY_VIS_KEY, (EventTypes.Member, user_id))
types = (
_HISTORY_VIS_KEY,
(EventTypes.Member, user_id),
)
if not events:
return []
room_id = events[0].room_id
assert all(event.room_id == room_id for event in events)
visibilities: Dict[str, str] = {}
memberships: Dict[str, Optional[EventBase]] = {}
events_to_fetch = {e.event_id for e in events if not e.internal_metadata.outlier}
if not is_peeking:
fetched_visibilities = await storage.main.get_visibility_for_events(
room_id, [e for e in events if not e.internal_metadata.outlier]
)
for event_id, visibility in fetched_visibilities.items():
if visibility in (
HistoryVisibility.SHARED,
HistoryVisibility.WORLD_READABLE,
):
events_to_fetch.discard(event_id)
visibilities[event_id] = visibility
# we exclude outliers at this point, and then handle them separately later
event_id_to_state = await storage.state.get_state_for_events(
frozenset(e.event_id for e in events if not e.internal_metadata.outlier),
state_filter=StateFilter.from_types(types),
)
if events_to_fetch:
event_id_to_state = await storage.state.get_state_for_events(
events_to_fetch,
state_filter=StateFilter.from_types(types),
)
for event_id, state in event_id_to_state.items():
visibilities[event_id] = get_effective_room_visibility_from_state(state)
memberships[event_id] = state.get((EventTypes.Member, user_id))
# Get the users who are ignored by the requesting user.
ignore_list = await storage.main.ignored_users(user_id)
@@ -140,8 +169,8 @@ async def filter_events_for_client(
] = await storage.main.get_retention_policy_for_room(room_id)
def allowed(event: EventBase) -> Optional[EventBase]:
state_after_event = event_id_to_state.get(event.event_id)
filtered = _check_client_allowed_to_see_event(
# state_after_event = event_id_to_state.get(event.event_id)
filtered = _check_client_allowed_to_see_event_with_state(
user_id=user_id,
event=event,
clock=storage.main.clock,
@@ -149,9 +178,10 @@ async def filter_events_for_client(
sender_ignored=event.sender in ignore_list,
always_include_ids=always_include_ids,
retention_policy=retention_policies[event.room_id],
state=state_after_event,
is_peeking=is_peeking,
sender_erased=erased_senders.get(event.sender, False),
visibility=visibilities[event.event_id],
membership_event=memberships.get(event.event_id),
)
if filtered is None:
return None
@@ -165,11 +195,9 @@ async def filter_events_for_client(
user_membership_event: Optional[EventBase]
if event.type == EventTypes.Member and event.state_key == user_id:
user_membership_event = event
elif state_after_event is not None:
user_membership_event = state_after_event.get((EventTypes.Member, user_id))
else:
# unreachable!
raise Exception("Missing state for event that is not user's own membership")
# TODO: Actually get the proper membership
#
# Look the membership up by the ID of the event being filtered. A bare
# `event_id` is not a local of this function: it would resolve to a loop
# variable left over in the enclosing scope (or be unbound if no loop ran),
# so the wrong membership could be used.
user_membership_event = memberships.get(event.event_id)
user_membership = (
user_membership_event.membership
@@ -353,6 +381,41 @@ def _check_client_allowed_to_see_event(
the original event if they can see it as normal.
"""
visibility = HistoryVisibility.SHARED
if state is not None:
visibility = get_effective_room_visibility_from_state(state)
membership_event = state.get((EventTypes.Member, user_id)) if state else None
return _check_client_allowed_to_see_event_with_state(
user_id,
event,
clock,
filter_send_to_client,
is_peeking,
always_include_ids,
sender_ignored,
retention_policy,
sender_erased,
visibility=visibility,
membership_event=membership_event,
)
def _check_client_allowed_to_see_event_with_state(
user_id: str,
event: EventBase,
clock: Clock,
filter_send_to_client: bool,
is_peeking: bool,
always_include_ids: FrozenSet[str],
sender_ignored: bool,
retention_policy: RetentionPolicy,
sender_erased: bool,
visibility: str,
membership_event: Optional[EventBase],
) -> Optional[EventBase]:
# Only run some checks if these events aren't about to be sent to clients. This is
# because, if this is not the case, we're probably only checking if the users can
# see events in the room at that point in the DAG, and that shouldn't be decided
@@ -390,12 +453,6 @@ def _check_client_allowed_to_see_event(
)
return None
if state is None:
raise Exception("Missing state for non-outlier event")
# get the room_visibility at the time of the event.
visibility = get_effective_room_visibility_from_state(state)
# Check if the room has lax history visibility, allowing us to skip
# membership checks.
#
@@ -408,6 +465,10 @@ def _check_client_allowed_to_see_event(
):
return event
if membership_event:
state = {(EventTypes.Member, user_id): membership_event}
else:
state = {}
membership_result = _check_membership(user_id, event, visibility, state, is_peeking)
if not membership_result.allowed:
filtered_event_logger.debug(