Compare commits

...

25 Commits

Author SHA1 Message Date
Erik Johnston
b4fd84de35 Comments 2025-01-24 15:19:39 +00:00
Erik Johnston
ce41c93878 Clean up 2025-01-24 15:16:48 +00:00
Erik Johnston
1e26c26ffc Add function to check if we can delete state groups 2025-01-24 11:27:33 +00:00
Erik Johnston
d7cd57a848 Add row level lock 2025-01-24 11:27:20 +00:00
Erik Johnston
e0d84ab922 Add 'state_groups_persisting' table 2025-01-23 14:56:28 +00:00
Erik Johnston
d5c1473c00 comments 2025-01-23 11:29:46 +00:00
Erik Johnston
721512cad0 Check for deletions 2025-01-22 16:54:01 +00:00
Erik Johnston
e4887f823a Fixup 2025-01-22 14:18:07 +00:00
Erik Johnston
38246d23f6 Comment 2025-01-22 14:15:45 +00:00
Erik Johnston
ed1b4eff45 Check more stuff 2025-01-22 14:14:27 +00:00
Erik Johnston
3f2e5730f3 Fix tests 2025-01-22 14:05:04 +00:00
Erik Johnston
192fd59943 Check state epoch on state group persistence 2025-01-22 14:02:24 +00:00
Erik Johnston
0b4eeacb58 Move stuff 2025-01-22 11:55:12 +00:00
Erik Johnston
ae4a6304cd Check for deleted state groups 2025-01-21 10:56:12 +00:00
Erik Johnston
c02938c670 Check that we have persisted the events 2025-01-21 10:09:29 +00:00
Erik Johnston
ef1daf38bd WIP 2025-01-20 17:22:34 +00:00
Erik Johnston
73603ef1db WIP 2025-01-17 11:08:59 +00:00
Erik Johnston
6318af4a46 Start adding state_epoch to eventcontext 2025-01-09 15:57:12 +00:00
Erik Johnston
4268448971 Fixup inheritance 2025-01-09 15:35:23 +00:00
Erik Johnston
1a5950a39e Expand comment 2025-01-09 14:34:56 +00:00
Erik Johnston
37441464f2 More move 2025-01-09 14:10:25 +00:00
Erik Johnston
ec4c476270 Move stuff 2025-01-09 14:08:23 +00:00
Erik Johnston
9db36a6ab1 Add helper functions 2025-01-08 16:14:51 +00:00
Erik Johnston
73a4d298c8 Periodically advance epoch 2025-01-08 16:06:20 +00:00
Erik Johnston
b63f5b6580 WIP 2025-01-08 16:06:20 +00:00
11 changed files with 523 additions and 47 deletions

View File

@@ -36,6 +36,7 @@ if TYPE_CHECKING:
from synapse.types.state import StateFilter
@attr.s(slots=True, auto_attribs=True)
class UnpersistedEventContextBase(ABC):
"""
This is a base class for EventContext and UnpersistedEventContext, objects which
@@ -47,11 +48,12 @@ class UnpersistedEventContextBase(ABC):
_storage: storage controllers for interfacing with the database
app_service: If the associated event is being sent by a (local) application service, that
app service.
state_epoch: The state epoch of when we created the event, if not an outlier
"""
def __init__(self, storage_controller: "StorageControllers"):
self._storage: "StorageControllers" = storage_controller
self.app_service: Optional[ApplicationService] = None
_storage: "StorageControllers"
state_epoch: Optional[int]
app_service: Optional[ApplicationService] = attr.field(default=None, init=False)
@abstractmethod
async def persist(
@@ -132,13 +134,11 @@ class EventContext(UnpersistedEventContextBase):
incomplete state.
"""
_storage: "StorageControllers"
state_group_deltas: Dict[Tuple[int, int], StateMap[str]]
rejected: Optional[str] = None
_state_group: Optional[int] = None
state_group_before_event: Optional[int] = None
_state_delta_due_to_event: Optional[StateMap[str]] = None
app_service: Optional[ApplicationService] = None
partial_state: bool = False
@@ -150,6 +150,7 @@ class EventContext(UnpersistedEventContextBase):
state_delta_due_to_event: Optional[StateMap[str]],
partial_state: bool,
state_group_deltas: Dict[Tuple[int, int], StateMap[str]],
state_epoch: int,
) -> "EventContext":
return EventContext(
storage=storage,
@@ -158,6 +159,7 @@ class EventContext(UnpersistedEventContextBase):
state_delta_due_to_event=state_delta_due_to_event,
state_group_deltas=state_group_deltas,
partial_state=partial_state,
state_epoch=state_epoch,
)
@staticmethod
@@ -165,7 +167,11 @@ class EventContext(UnpersistedEventContextBase):
storage: "StorageControllers",
) -> "EventContext":
"""Return an EventContext instance suitable for persisting an outlier event"""
return EventContext(storage=storage, state_group_deltas={})
return EventContext(
storage=storage,
state_group_deltas={},
state_epoch=None,
)
async def persist(self, event: EventBase) -> "EventContext":
return self
@@ -191,6 +197,7 @@ class EventContext(UnpersistedEventContextBase):
),
"app_service_id": self.app_service.id if self.app_service else None,
"partial_state": self.partial_state,
"state_epoch": self.state_epoch,
}
@staticmethod
@@ -218,6 +225,7 @@ class EventContext(UnpersistedEventContextBase):
),
rejected=input["rejected"],
partial_state=input.get("partial_state", False),
state_epoch=input.get("state_epoch", None),
)
app_service_id = input["app_service_id"]
@@ -347,7 +355,7 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
A map of the state before the event, i.e. the state at `state_group_before_event`
"""
_storage: "StorageControllers"
state_epoch: int # `state_epoch` is required for `UnpersistedEventContext`
state_group_before_event: Optional[int]
state_group_after_event: Optional[int]
state_delta_due_to_event: Optional[StateMap[str]]
@@ -390,6 +398,7 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
state_delta_due_to_event=unpersisted_context.state_delta_due_to_event,
partial_state=unpersisted_context.partial_state,
state_group_deltas=state_group_deltas,
state_epoch=unpersisted_context.state_epoch,
)
events_and_persisted_context.append((event, context))
return events_and_persisted_context
@@ -464,6 +473,7 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
state_delta_due_to_event=self.state_delta_due_to_event,
state_group_deltas=state_group_deltas,
partial_state=self.partial_state,
state_epoch=self.state_epoch,
)
def _build_state_group_deltas(self) -> Dict[Tuple[int, int], StateMap]:

View File

@@ -151,6 +151,8 @@ class FederationEventHandler:
def __init__(self, hs: "HomeServer"):
self._clock = hs.get_clock()
self._store = hs.get_datastores().main
self._state_store = hs.get_datastores().state
self._state_epoch_store = hs.get_datastores().state_epochs
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state
@@ -580,7 +582,9 @@ class FederationEventHandler:
room_version.identifier,
state_maps_to_resolve,
event_map=None,
state_res_store=StateResolutionStore(self._store),
state_res_store=StateResolutionStore(
self._store, self._state_epoch_store
),
)
)
else:
@@ -1179,7 +1183,9 @@ class FederationEventHandler:
room_version,
state_maps,
event_map={event_id: event},
state_res_store=StateResolutionStore(self._store),
state_res_store=StateResolutionStore(
self._store, self._state_epoch_store
),
)
except Exception as e:
@@ -1874,7 +1880,9 @@ class FederationEventHandler:
room_version,
[local_state_id_map, claimed_auth_events_id_map],
event_map=None,
state_res_store=StateResolutionStore(self._store),
state_res_store=StateResolutionStore(
self._store, self._state_epoch_store
),
)
)
else:
@@ -2014,7 +2022,9 @@ class FederationEventHandler:
room_version,
state_sets,
event_map=None,
state_res_store=StateResolutionStore(self._store),
state_res_store=StateResolutionStore(
self._store, self._state_epoch_store
),
)
)
else:

View File

@@ -59,11 +59,13 @@ from synapse.types.state import StateFilter
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import Measure, measure_func
from synapse.util.stringutils import shortstr
if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.storage.controllers import StateStorageController
from synapse.storage.databases.main import DataStore
from synapse.storage.databases.state.epochs import StateEpochDataStore
logger = logging.getLogger(__name__)
metrics_logger = logging.getLogger("synapse.state.metrics")
@@ -194,6 +196,8 @@ class StateHandler:
self._storage_controllers = hs.get_storage_controllers()
self._events_shard_config = hs.config.worker.events_shard_config
self._instance_name = hs.get_instance_name()
self._state_store = hs.get_datastores().state
self._state_epoch_store = hs.get_datastores().state_epochs
self._update_current_state_client = (
ReplicationUpdateCurrentStateRestServlet.make_client(hs)
@@ -311,6 +315,11 @@ class StateHandler:
"""
assert not event.internal_metadata.is_outlier()
# Record the state epoch before we start calculating state groups, to
# ensure that nothing we're relying on gets deleted. See the store class
# docstring for more information.
state_epoch = await self._state_epoch_store.get_state_epoch()
#
# first of all, figure out the state before the event, unless we
# already have it.
@@ -396,6 +405,7 @@ class StateHandler:
delta_ids_to_state_group_before_event=deltas_to_state_group_before_event,
partial_state=partial_state,
state_map_before_event=state_ids_before_event,
state_epoch=state_epoch,
)
#
@@ -426,6 +436,7 @@ class StateHandler:
delta_ids_to_state_group_before_event=deltas_to_state_group_before_event,
partial_state=partial_state,
state_map_before_event=state_ids_before_event,
state_epoch=state_epoch,
)
async def compute_event_context(
@@ -475,7 +486,10 @@ class StateHandler:
@trace
@measure_func()
async def resolve_state_groups_for_events(
self, room_id: str, event_ids: StrCollection, await_full_state: bool = True
self,
room_id: str,
event_ids: StrCollection,
await_full_state: bool = True,
) -> _StateCacheEntry:
"""Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them.
@@ -495,6 +509,15 @@ class StateHandler:
"""
logger.debug("resolve_state_groups event_ids %s", event_ids)
# Check we have the events persisted as non-outliers, to ensure we have
# the state.
persisted_event_ids = await self.store.have_events_in_timeline(event_ids)
missing_event_ids = set(event_ids) - persisted_event_ids
if missing_event_ids:
raise Exception(
f"Trying to resolve state across events we have not persisted: {shortstr(missing_event_ids)}",
)
state_groups = await self._state_storage_controller.get_state_group_for_events(
event_ids, await_full_state=await_full_state
)
@@ -511,6 +534,19 @@ class StateHandler:
) = await self._state_storage_controller.get_state_group_delta(
state_group_id
)
if prev_group:
# Ensure that we still have the prev group, and ensure we don't
# delete it while we're persisting the event.
missing_state_group = (
await self._state_epoch_store.check_state_groups_and_bump_deletion(
{prev_group}
)
)
if missing_state_group:
prev_group = None
delta_ids = None
return _StateCacheEntry(
state=None,
state_group=state_group_id,
@@ -531,7 +567,7 @@ class StateHandler:
room_version,
state_to_resolve,
None,
state_res_store=StateResolutionStore(self.store),
state_res_store=StateResolutionStore(self.store, self._state_epoch_store),
)
return result
@@ -663,7 +699,25 @@ class StateResolutionHandler:
async with self.resolve_linearizer.queue(group_names):
cache = self._state_cache.get(group_names, None)
if cache:
return cache
# Check that the returned cache entry doesn't point to deleted
# state groups.
state_groups_to_check = set()
if cache.state_group is not None:
state_groups_to_check.add(cache.state_group)
if cache.prev_group is not None:
state_groups_to_check.add(cache.prev_group)
missing_state_groups = await state_res_store.state_epoch_store.check_state_groups_and_bump_deletion(
state_groups_to_check
)
if not missing_state_groups:
return cache
else:
# There are missing state groups, so let's remove the stale
# entry and continue as if it was a cache miss.
self._state_cache.pop(group_names, None)
logger.info(
"Resolving state for %s with groups %s",
@@ -671,6 +725,16 @@ class StateResolutionHandler:
list(group_names),
)
# We double check that none of the state groups have been deleted.
# They shouldn't be as all these state groups should be referenced.
missing_state_groups = await state_res_store.state_epoch_store.check_state_groups_and_bump_deletion(
group_names
)
if missing_state_groups:
raise Exception(
f"State groups have been deleted: {shortstr(missing_state_groups)}"
)
state_groups_histogram.observe(len(state_groups_ids))
new_state = await self.resolve_events_with_store(
@@ -884,7 +948,8 @@ class StateResolutionStore:
in well defined way.
"""
store: "DataStore"
main_store: "DataStore"
state_epoch_store: "StateEpochDataStore"
def get_events(
self, event_ids: StrCollection, allow_rejected: bool = False
@@ -899,7 +964,7 @@ class StateResolutionStore:
An awaitable which resolves to a dict from event_id to event.
"""
return self.store.get_events(
return self.main_store.get_events(
event_ids,
redact_behaviour=EventRedactBehaviour.as_is,
get_prev_content=False,
@@ -920,4 +985,4 @@ class StateResolutionStore:
An awaitable that resolves to a set of event IDs.
"""
return self.store.get_auth_chain_difference(room_id, state_sets)
return self.main_store.get_auth_chain_difference(room_id, state_sets)

View File

@@ -332,6 +332,7 @@ class EventsPersistenceStorageController:
# store for now.
self.main_store = stores.main
self.state_store = stores.state
self._state_epoch_store = stores.state_epochs
assert stores.persist_events
self.persist_events_store = stores.persist_events
@@ -549,7 +550,9 @@ class EventsPersistenceStorageController:
room_version,
state_maps_by_state_group,
event_map=None,
state_res_store=StateResolutionStore(self.main_store),
state_res_store=StateResolutionStore(
self.main_store, self._state_epoch_store
),
)
return await res.get_state(self._state_controller, StateFilter.all())
@@ -635,15 +638,20 @@ class EventsPersistenceStorageController:
room_id, [e for e, _ in chunk]
)
await self.persist_events_store._persist_events_and_state_updates(
room_id,
chunk,
state_delta_for_room=state_delta_for_room,
new_forward_extremities=new_forward_extremities,
use_negative_stream_ordering=backfilled,
inhibit_local_membership_updates=backfilled,
new_event_links=new_event_links,
)
# Stop the state groups from being deleted while we're persisting
# them.
async with self._state_epoch_store.persisting_state_group_references(
events_and_contexts
):
await self.persist_events_store._persist_events_and_state_updates(
room_id,
chunk,
state_delta_for_room=state_delta_for_room,
new_forward_extremities=new_forward_extremities,
use_negative_stream_ordering=backfilled,
inhibit_local_membership_updates=backfilled,
new_event_links=new_event_links,
)
return replaced_events
@@ -965,7 +973,9 @@ class EventsPersistenceStorageController:
room_version,
state_groups,
events_map,
state_res_store=StateResolutionStore(self.main_store),
state_res_store=StateResolutionStore(
self.main_store, self._state_epoch_store
),
)
state_resolutions_during_persistence.inc()

View File

@@ -26,6 +26,7 @@ from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_conn
from synapse.storage.databases.main.events import PersistEventsStore
from synapse.storage.databases.state import StateGroupDataStore
from synapse.storage.databases.state.epochs import StateEpochDataStore
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
@@ -49,12 +50,14 @@ class Databases(Generic[DataStoreT]):
main
state
persist_events
state_epochs
"""
databases: List[DatabasePool]
main: "DataStore" # FIXME: https://github.com/matrix-org/synapse/issues/11165: actually an instance of `main_store_class`
state: StateGroupDataStore
persist_events: Optional[PersistEventsStore]
state_epochs: StateEpochDataStore
def __init__(self, main_store_class: Type[DataStoreT], hs: "HomeServer"):
# Note we pass in the main store class here as workers use a different main
@@ -63,6 +66,7 @@ class Databases(Generic[DataStoreT]):
self.databases = []
main: Optional[DataStoreT] = None
state: Optional[StateGroupDataStore] = None
state_epochs: Optional[StateEpochDataStore] = None
persist_events: Optional[PersistEventsStore] = None
for database_config in hs.config.database.databases:
@@ -114,7 +118,8 @@ class Databases(Generic[DataStoreT]):
if state:
raise Exception("'state' data store already configured")
state = StateGroupDataStore(database, db_conn, hs)
state_epochs = StateEpochDataStore(database, db_conn, hs)
state = StateGroupDataStore(database, db_conn, hs, state_epochs)
db_conn.commit()
@@ -135,7 +140,7 @@ class Databases(Generic[DataStoreT]):
if not main:
raise Exception("No 'main' database configured")
if not state:
if not state or not state_epochs:
raise Exception("No 'state' database configured")
# We use local variables here to ensure that the databases do not have
@@ -143,3 +148,4 @@ class Databases(Generic[DataStoreT]):
self.main = main # type: ignore[assignment]
self.state = state
self.persist_events = persist_events
self.state_epochs = state_epochs

View File

@@ -0,0 +1,332 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2025 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
import contextlib
from typing import (
TYPE_CHECKING,
AbstractSet,
AsyncIterator,
Collection,
Set,
Tuple,
)
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
make_in_list_sql_clause,
)
from synapse.storage.engines import PostgresEngine
from synapse.util.stringutils import shortstr
if TYPE_CHECKING:
from synapse.server import HomeServer
class StateEpochDataStore:
    """Manages state epochs and checks for state group deletion.

    Deleting state groups is challenging as before we actually delete them we
    need to ensure that there are no in-flight events that refer to the state
    groups that we want to delete.

    To handle this, we take two approaches. First, before we persist any event
    we ensure that the state groups still exist and mark in the
    `state_groups_persisting` table that the state group is about to be used.
    (Note that we have to have the extra table here as state groups and events
    can be in different databases, and thus we can't check for the existence of
    state groups in the persist event transaction). Once the event has been
    persisted, we can remove the row from `state_groups_persisting`. So long as
    we check that table before deleting state groups, we can ensure that we
    never persist events that reference deleted state groups, maintaining
    database integrity.

    However, we want to avoid throwing exceptions so deep in the process of
    persisting events. So we use a concept of `state_epochs`, where we mark
    state groups as pending/proposed for deletion and wait for a certain number
    of epoch increments before performing the deletion. When we come to handle
    new events that reference state groups, we check if they are pending
    deletion and bump the epoch when they'll be deleted in (to give a chance
    for the event to be persisted, or not).
    """

    # How frequently, roughly, to increment epochs.
    TIME_BETWEEN_EPOCH_INCREMENTS_MS = 5 * 60 * 1000

    # The number of epoch increases that must have happened between marking a
    # state group as pending and actually deleting it.
    NUMBER_EPOCHS_BEFORE_DELETION = 3

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        self._clock = hs.get_clock()
        self.db_pool = database
        self._instance_name = hs.get_instance_name()

        # TODO: Clear from `state_groups_persisting` any holdovers from previous
        # running instance.

        if hs.config.worker.run_background_tasks:
            # Add a background loop to periodically check if we should bump
            # state epoch.
            self._clock.looping_call_now(
                self._advance_state_epoch, self.TIME_BETWEEN_EPOCH_INCREMENTS_MS / 5
            )

    @wrap_as_background_process("_advance_state_epoch")
    async def _advance_state_epoch(self) -> None:
        """Advances the state epoch, checking that we haven't advanced it too
        recently.
        """

        now = self._clock.time_msec()
        update_if_before_ts = now - self.TIME_BETWEEN_EPOCH_INCREMENTS_MS

        def advance_state_epoch_txn(txn: LoggingTransaction) -> None:
            # The `WHERE updated_ts <= ?` guard makes this a no-op if another
            # worker (or an earlier run of this loop) has incremented recently.
            sql = """
                UPDATE state_epoch
                SET state_epoch = state_epoch + 1, updated_ts = ?
                WHERE updated_ts <= ?
            """
            txn.execute(sql, (now, update_if_before_ts))

        await self.db_pool.runInteraction(
            "_advance_state_epoch", advance_state_epoch_txn, db_autocommit=True
        )

    async def get_state_epoch(self) -> int:
        """Get the current state epoch"""
        return await self.db_pool.simple_select_one_onecol(
            table="state_epoch",
            retcol="state_epoch",
            keyvalues={},
            desc="get_state_epoch",
        )

    async def check_state_groups_and_bump_deletion(
        self, state_groups: AbstractSet[int]
    ) -> Collection[int]:
        """Checks to make sure that the state groups haven't been deleted, and
        if they're pending deletion we delay it (allowing time for any event
        that will use them to finish persisting).

        Returns:
            The state groups that are missing, if any.
        """
        return await self.db_pool.runInteraction(
            "check_state_groups_and_bump_deletion",
            self._check_state_groups_and_bump_deletion_txn,
            state_groups,
        )

    def _check_state_groups_and_bump_deletion_txn(
        self, txn: LoggingTransaction, state_groups: AbstractSet[int]
    ) -> Collection[int]:
        existing_state_groups = self._get_existing_groups_with_lock(txn, state_groups)

        # If any requested groups are already gone, report them without
        # touching the pending-deletion table.
        if state_groups - existing_state_groups:
            return state_groups - existing_state_groups

        # Push back the deletion epoch for any of these groups that are
        # pending deletion, giving in-flight events time to persist.
        clause, args = make_in_list_sql_clause(
            self.db_pool.engine, "state_group", state_groups
        )
        sql = f"""
            UPDATE state_groups_pending_deletion
            SET state_epoch = (SELECT state_epoch FROM state_epoch)
            WHERE {clause}
        """
        txn.execute(sql, args)

        return ()

    def _get_existing_groups_with_lock(
        self, txn: LoggingTransaction, state_groups: Collection[int]
    ) -> AbstractSet[int]:
        """Return which of the given state groups are in the database, and locks
        those rows with `KEY SHARE` to ensure they don't get concurrently
        deleted."""
        clause, args = make_in_list_sql_clause(self.db_pool.engine, "id", state_groups)

        sql = f"""
            SELECT id FROM state_groups
            WHERE {clause}
        """
        if isinstance(self.db_pool.engine, PostgresEngine):
            # On postgres we add a row level lock to the rows to ensure that we
            # conflict with any concurrent DELETEs. `FOR KEY SHARE` lock will
            # not conflict with other reads.
            sql += """
            FOR KEY SHARE
            """

        txn.execute(sql, args)
        return {state_group for (state_group,) in txn}

    @contextlib.asynccontextmanager
    async def persisting_state_group_references(
        self, event_and_contexts: Collection[Tuple[EventBase, EventContext]]
    ) -> AsyncIterator[None]:
        """Wraps the persistence of the given events and contexts, ensuring that
        any state groups referenced still exist and that they don't get deleted
        during this."""

        referenced_state_groups: Set[int] = set()
        state_epochs = []
        for event, ctx in event_and_contexts:
            # Rejected events and outliers have no state to protect.
            if ctx.rejected or event.internal_metadata.is_outlier():
                continue

            assert ctx.state_epoch is not None
            assert ctx.state_group is not None

            state_epochs.append(ctx.state_epoch)

            referenced_state_groups.add(ctx.state_group)

            if ctx.state_group_before_event:
                referenced_state_groups.add(ctx.state_group_before_event)

        if not referenced_state_groups:
            # We don't reference any state groups, so nothing to do
            yield
            return

        assert state_epochs  # If we have state groups we have a state epoch
        min_state_epoch = min(state_epochs)

        await self.db_pool.runInteraction(
            "mark_state_groups_as_used",
            self._mark_state_groups_as_used_txn,
            min_state_epoch,
            referenced_state_groups,
        )

        try:
            yield None
        finally:
            # Whether or not persistence succeeded, drop our "in use" markers
            # so the groups become eligible for deletion again.
            await self.db_pool.simple_delete_many(
                table="state_groups_persisting",
                column="state_group",
                iterable=referenced_state_groups,
                keyvalues={"instance_name": self._instance_name},
                desc="persisting_state_group_references_delete",
            )

    def _mark_state_groups_as_used_txn(
        self, txn: LoggingTransaction, min_state_epoch: int, state_groups: Set[int]
    ) -> None:
        """Marks the given state groups as used. Also checks that the given
        state epoch is not too old.

        Raises:
            Exception: if any of the state groups have already been deleted,
                or if `min_state_epoch` is too far behind the current epoch.
        """
        existing_state_groups = self._get_existing_groups_with_lock(txn, state_groups)
        missing_state_groups = state_groups - existing_state_groups
        if missing_state_groups:
            raise Exception(
                f"state groups have been deleted: {shortstr(missing_state_groups)}"
            )

        # Make sure the state epoch isn't too old.
        current_epoch = self.db_pool.simple_select_one_onecol_txn(
            txn,
            table="state_epoch",
            retcol="state_epoch",
            keyvalues={},
        )

        if current_epoch - min_state_epoch > self.NUMBER_EPOCHS_BEFORE_DELETION:
            raise Exception("Event took too long to persist")

        # These groups are now in use, so cancel any pending deletion and
        # record that this instance is persisting events that reference them.
        self.db_pool.simple_delete_many_batch_txn(
            txn,
            table="state_groups_pending_deletion",
            keys=("state_group",),
            values=[(state_group,) for state_group in state_groups],
        )

        self.db_pool.simple_insert_many_txn(
            txn,
            table="state_groups_persisting",
            keys=("state_group", "instance_name"),
            values=[(state_group, self._instance_name) for state_group in state_groups],
        )

    def get_state_groups_that_can_be_purged_txn(
        self, txn: LoggingTransaction, state_groups: Collection[int]
    ) -> Collection[int]:
        """Given a set of state groups, return which state groups can be deleted."""

        if not state_groups:
            return state_groups

        if isinstance(self.db_pool.engine, PostgresEngine):
            # On postgres we want to lock the rows FOR UPDATE as early as
            # possible to help conflicts.
            clause, args = make_in_list_sql_clause(
                self.db_pool.engine, "id", state_groups
            )
            # NOTE: this must be an f-string so that `clause` is actually
            # interpolated into the SQL (matching the other queries here).
            sql = f"""
                SELECT id FROM state_groups
                WHERE {clause}
                FOR UPDATE
            """
            txn.execute(sql, args)

        current_state_epoch = self.db_pool.simple_select_one_onecol_txn(
            txn,
            table="state_epoch",
            retcol="state_epoch",
            keyvalues={},
        )

        # Check the deletion status in the DB of the given state groups
        clause, args = make_in_list_sql_clause(
            self.db_pool.engine, column="state_group", iterable=state_groups
        )

        sql = f"""
            SELECT state_group, state_epoch FROM (
                SELECT state_group, state_epoch FROM state_groups_pending_deletion
                UNION
                SELECT state_group, null FROM state_groups_persisting
            ) AS s
            WHERE {clause}
        """

        txn.execute(sql, args)

        can_delete = set()
        for state_group, state_epoch in txn:
            if state_epoch is None:
                # A null state epoch means that we are currently persisting
                # events that reference the state group, so we don't delete
                # them.
                continue

            if current_state_epoch - state_epoch < self.NUMBER_EPOCHS_BEFORE_DELETION:
                # Not enough state epochs have occurred to allow us to delete.
                continue

            can_delete.add(state_group)

        return can_delete

View File

@@ -36,7 +36,10 @@ import attr
from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.events.snapshot import UnpersistedEventContext, UnpersistedEventContextBase
from synapse.events.snapshot import (
UnpersistedEventContext,
UnpersistedEventContextBase,
)
from synapse.logging.opentracing import tag_args, trace
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
@@ -55,6 +58,7 @@ from synapse.util.cancellation import cancellable
if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.storage.databases.state.epochs import StateEpochDataStore
logger = logging.getLogger(__name__)
@@ -83,8 +87,10 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
epoch_store: "StateEpochDataStore",
):
super().__init__(database, db_conn, hs)
self._epoch_store = epoch_store
# Originally the state store used a single DictionaryCache to cache the
# event IDs for the state types in a given state group to avoid hammering
@@ -467,14 +473,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
Returns:
A list of state groups
"""
is_in_db = self.db_pool.simple_select_one_onecol_txn(
# We need to check that the prev group isn't about to be deleted
is_missing = self._epoch_store._check_state_groups_and_bump_deletion_txn(
txn,
table="state_groups",
keyvalues={"id": prev_group},
retcol="id",
allow_none=True,
{prev_group},
)
if not is_in_db:
if is_missing:
raise Exception(
"Trying to persist state with unpersisted prev_group: %r"
% (prev_group,)
@@ -546,6 +551,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
for key, state_id in context.state_delta_due_to_event.items()
],
)
return events_and_context
return await self.db_pool.runInteraction(
@@ -601,14 +607,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
The state group if successfully created, or None if the state
needs to be persisted as a full state.
"""
is_in_db = self.db_pool.simple_select_one_onecol_txn(
# We need to check that the prev group isn't about to be deleted
is_missing = self._epoch_store._check_state_groups_and_bump_deletion_txn(
txn,
table="state_groups",
keyvalues={"id": prev_group},
retcol="id",
allow_none=True,
{prev_group},
)
if not is_in_db:
if is_missing:
raise Exception(
"Trying to persist state with unpersisted prev_group: %r"
% (prev_group,)

View File

@@ -19,7 +19,7 @@
#
#
SCHEMA_VERSION = 88 # remember to update the list below when updating
SCHEMA_VERSION = 89 # remember to update the list below when updating
"""Represents the expectations made by the codebase about the database schema
This should be incremented whenever the codebase changes its requirements on the

View File

@@ -0,0 +1,37 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2024 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
-- Single-row table holding the current state epoch: a logical counter used to
-- delay deletion of state groups until any in-flight events that reference
-- them have been persisted.
CREATE TABLE IF NOT EXISTS state_epoch (
    Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
    state_epoch BIGINT NOT NULL,
    updated_ts BIGINT NOT NULL, -- Timestamp (ms) when the epoch was last advanced.
    CHECK (Lock='X')
);

-- Seed the single row so that UPDATE-based increments always have a row to act on.
INSERT INTO state_epoch (state_epoch, updated_ts) VALUES (0, 0);

-- State groups that have been proposed for deletion, tagged with the epoch at
-- which they were (last) marked; deletion only proceeds once enough epochs
-- have elapsed since that mark.
CREATE TABLE IF NOT EXISTS state_groups_pending_deletion (
    state_group BIGINT NOT NULL,
    state_epoch BIGINT NOT NULL,
    PRIMARY KEY (state_group, state_epoch)
);

CREATE INDEX state_groups_pending_deletion_epoch ON state_groups_pending_deletion(state_epoch);

-- State groups currently referenced by events being persisted by a worker
-- instance; rows here block deletion of the corresponding state groups.
CREATE TABLE IF NOT EXISTS state_groups_persisting (
    state_group BIGINT NOT NULL,
    instance_name TEXT NOT NULL,
    PRIMARY KEY (state_group, instance_name)
);

View File

@@ -807,6 +807,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}"
main_store = self.hs.get_datastores().main
epoch_store = self.hs.get_datastores().state_epochs
# Create the room.
kermit_user_id = self.register_user("kermit", "test")
@@ -958,7 +959,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
bert_member_event.event_id: bert_member_event,
rejected_kick_event.event_id: rejected_kick_event,
},
state_res_store=StateResolutionStore(main_store),
state_res_store=StateResolutionStore(main_store, epoch_store),
)
),
[bert_member_event.event_id, rejected_kick_event.event_id],
@@ -1003,7 +1004,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
rejected_power_levels_event.event_id,
],
event_map={},
state_res_store=StateResolutionStore(main_store),
state_res_store=StateResolutionStore(main_store, epoch_store),
full_conflicted_set=set(),
)
),

View File

@@ -441,7 +441,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
persist_events_store._store_event_txn(
txn,
[
(e, EventContext(self.hs.get_storage_controllers(), {}))
(e, EventContext(self.hs.get_storage_controllers(), 1, {}))
for e in events
],
)