Compare commits


11 Commits

Author          SHA1         Message                                                                        Date
Erik Johnston   9339b8b8ea   Lint                                                                           2024-06-26 16:48:08 +01:00
Erik Johnston   dac74db74e   Merge remote-tracking branch 'origin/develop' into erikj/faster_auth_chains   2024-06-26 13:15:14 +01:00
Erik Johnston   90e1df262d   Handle SQLite                                                                  2024-06-26 13:14:48 +01:00
Erik Johnston   b8b6fe0da3   Remove debug logging                                                           2024-06-26 13:14:42 +01:00
Erik Johnston   fe9fa90af4   Add a cache to auth links                                                      2024-05-17 10:35:03 +01:00
Erik Johnston   4ffe5a4459   Up batch size                                                                  2024-05-14 10:26:31 +01:00
Erik Johnston   b58ed63884   Go faster stripes                                                              2024-05-13 13:10:16 +01:00
Erik Johnston   2d4cea496b   Debug logging                                                                  2024-05-13 09:54:32 +01:00
Erik Johnston   ca79b4d87d   Use a sortedset instead                                                        2024-05-09 10:58:00 +01:00
Erik Johnston   202a09cdb3   Newsfile                                                                       2024-05-08 16:05:24 +01:00
Erik Johnston   db25e30a25   Perf improvement to getting auth chains                                        2024-05-08 16:04:35 +01:00
3 changed files with 174 additions and 7 deletions

changelog.d/17169.misc (new file, 1 addition)

@@ -0,0 +1 @@
+Add database performance improvement when fetching auth chains.
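For context (not part of the diff): `_get_chain_links` below yields maps from origin chain ID to lists of `(origin sequence number, target chain ID, target sequence number)` tuples, per its docstring. A hand-written illustration of that shape and of how reachability is read off it:

    # Illustrative data only; keyed by origin chain ID, each tuple is
    # (origin_sequence_number, target_chain_id, target_sequence_number).
    links = {
        1: [(3, 2, 7), (5, 4, 2)],  # chain 1 links into chains 2 and 4
        2: [(7, 4, 1)],             # chain 2 links into chain 4
    }

    # An event sitting at sequence number 4 on chain 1 reaches exactly the
    # targets of links whose origin sequence number is <= 4.
    reachable = [target for (seq, target, _) in links[1] if seq <= 4]
    print(reachable)  # [2]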

synapse/storage/databases/main/event_federation.py

@@ -39,6 +39,7 @@ from typing import (
 import attr
 from prometheus_client import Counter, Gauge
+from sortedcontainers import SortedSet

 from synapse.api.constants import MAX_DEPTH
 from synapse.api.errors import StoreError
@@ -118,6 +119,11 @@ class BackfillQueueNavigationItem:
     type: str


+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class _ChainLinksCacheEntry:
+    links: List[Tuple[int, int, int, "_ChainLinksCacheEntry"]] = attr.Factory(list)
+
+
 class _NoChainCoverIndex(Exception):
     def __init__(self, room_id: str):
         super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,))
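A minimal standalone sketch (invented data; the class is copied from the hunk above) of how these entries chain together: each link carries a direct reference to the target chain's entry, so a cached subgraph can be walked without touching the database:

    from typing import List, Tuple

    import attr


    @attr.s(frozen=True, slots=True, auto_attribs=True)
    class _ChainLinksCacheEntry:
        # (origin_seq, target_chain_id, target_seq, target_entry)
        links: List[Tuple[int, int, int, "_ChainLinksCacheEntry"]] = attr.Factory(list)


    # Build a tiny chain graph: 1 -> 2 -> 3.
    entry3 = _ChainLinksCacheEntry()
    entry2 = _ChainLinksCacheEntry()
    entry2.links.append((5, 3, 1, entry3))
    entry1 = _ChainLinksCacheEntry()
    entry1.links.append((2, 2, 4, entry2))

    # Walk every chain reachable from chain 1 purely via entry references.
    seen = set()
    stack = [entry1]
    while stack:
        entry = stack.pop()
        for _origin_seq, target_id, _target_seq, target_entry in entry.links:
            if target_id not in seen:
                seen.add(target_id)
                stack.append(target_entry)

    print(seen)  # {2, 3}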
@@ -138,6 +144,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         self.hs = hs

+        self._chain_links_cache: LruCache[int, _ChainLinksCacheEntry] = LruCache(
+            max_size=10000, cache_name="chain_links_cache"
+        )
+
         if hs.config.worker.run_background_tasks:
             hs.get_clock().looping_call(
                 self._delete_old_forward_extrem_cache, 60 * 60 * 1000
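The cache itself is Synapse's `LruCache`, keyed by chain ID. As a rough mental model only (a toy stand-in, not Synapse's implementation, which also handles metrics and invalidation callbacks), the behaviour the diff relies on is a size-bounded mapping where `get()` refreshes recency and inserts evict the least-recently-used entry:

    from collections import OrderedDict
    from typing import Generic, Optional, TypeVar

    K = TypeVar("K")
    V = TypeVar("V")


    class ToyLruCache(Generic[K, V]):
        def __init__(self, max_size: int) -> None:
            self.max_size = max_size
            self._data: "OrderedDict[K, V]" = OrderedDict()

        def get(self, key: K) -> Optional[V]:
            if key in self._data:
                self._data.move_to_end(key)  # mark as most recently used
                return self._data[key]
            return None

        def __contains__(self, key: K) -> bool:
            return key in self._data

        def __setitem__(self, key: K, value: V) -> None:
            self._data[key] = value
            self._data.move_to_end(key)
            if len(self._data) > self.max_size:
                self._data.popitem(last=False)  # evict least recently used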
@@ -289,7 +299,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         # A map from chain ID to max sequence number *reachable* from any event ID.
         chains: Dict[int, int] = {}
-        for links in self._get_chain_links(txn, set(event_chains.keys())):
+        for links in self._get_chain_links(
+            txn, event_chains.keys(), self._chain_links_cache
+        ):
             for chain_id in links:
                 if chain_id not in event_chains:
                     continue
@@ -341,7 +353,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
     @classmethod
     def _get_chain_links(
-        cls, txn: LoggingTransaction, chains_to_fetch: Set[int]
+        cls,
+        txn: LoggingTransaction,
+        chains_to_fetch: Collection[int],
+        cache: Optional[LruCache[int, _ChainLinksCacheEntry]] = None,
     ) -> Generator[Dict[int, List[Tuple[int, int, int]]], None, None]:
         """Fetch all auth chain links from the given set of chains, and all
         links from those chains, recursively.
@@ -353,12 +368,55 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         of origin sequence number, target chain ID and target sequence number.
         """

+        found_cached_chains = set()
+        if cache:
+            entries: Dict[int, _ChainLinksCacheEntry] = {}
+            for chain_id in chains_to_fetch:
+                entry = cache.get(chain_id)
+                if entry:
+                    entries[chain_id] = entry
+
+            cached_links: Dict[int, List[Tuple[int, int, int]]] = {}
+            while entries:
+                origin_chain_id, entry = entries.popitem()
+                for (
+                    origin_sequence_number,
+                    target_chain_id,
+                    target_sequence_number,
+                    target_entry,
+                ) in entry.links:
+                    if target_chain_id in found_cached_chains:
+                        continue
+
+                    found_cached_chains.add(target_chain_id)
+                    # Bump the target entry in the LRU and recurse into it.
+                    cache.get(target_chain_id)
+                    entries[target_chain_id] = target_entry
+
+                    cached_links.setdefault(origin_chain_id, []).append(
+                        (
+                            origin_sequence_number,
+                            target_chain_id,
+                            target_sequence_number,
+                        )
+                    )
+
+            yield cached_links
+
         # This query is structured to first get all chain IDs reachable, and
         # then pull out all links from those chains. This does pull out more
         # rows than is strictly necessary, however there isn't a way of
         # structuring the recursive part of the query to pull out the links
         # without also returning large quantities of redundant data (which can
         # make it a lot slower).
+
+        if isinstance(txn.database_engine, PostgresEngine):
+            # JIT and sequential scans sometimes get hit on this code path,
+            # which can make the queries much more expensive.
+            txn.execute("SET LOCAL jit = off")
+            txn.execute("SET LOCAL enable_seqscan = off")
+
         sql = """
             WITH RECURSIVE links(chain_id) AS (
                 SELECT
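The comment above describes a two-phase query: the recursive part computes every reachable chain ID, and the outer part then pulls all links originating from those chains. A rough in-memory Python equivalent of that shape (hypothetical edge list, not the real `event_auth_chain_links` table):

    from collections import defaultdict

    # Simplified rows of the links table: (origin_chain_id, target_chain_id).
    edges = [(1, 2), (2, 3), (4, 5)]

    by_origin = defaultdict(list)
    for origin, target in edges:
        by_origin[origin].append(target)

    # Phase 1 (the recursive CTE): chain IDs reachable from the seed chains.
    reachable = {1}
    frontier = [1]
    while frontier:
        for target in by_origin[frontier.pop()]:
            if target not in reachable:
                reachable.add(target)
                frontier.append(target)

    # Phase 2 (the outer SELECT): every link whose origin chain is reachable.
    links = [(o, t) for (o, t) in edges if o in reachable]
    print(links)  # [(1, 2), (2, 3)]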
@@ -377,9 +435,20 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
                 INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
         """

-        while chains_to_fetch:
-            batch2 = tuple(itertools.islice(chains_to_fetch, 1000))
-            chains_to_fetch.difference_update(batch2)
+        # We fetch the links in batches. Separate batches will likely fetch the
+        # same set of links (e.g. they'll always pull in the links to the
+        # create event). To try and minimize the amount of redundant links, we
+        # query the chain IDs in reverse order, as there will be a correlation
+        # between the order of chain IDs and links (i.e., higher chain IDs are
+        # more likely to depend on lower chain IDs than vice versa).
+        BATCH_SIZE = 5000
+
+        chains_to_fetch_sorted = SortedSet(chains_to_fetch)
+        chains_to_fetch_sorted.difference_update(found_cached_chains)
+
+        while chains_to_fetch_sorted:
+            batch2 = list(chains_to_fetch_sorted.islice(-BATCH_SIZE))
+            chains_to_fetch_sorted.difference_update(batch2)
+
             clause, args = make_in_list_sql_clause(
                 txn.database_engine, "origin_chain_id", batch2
             )
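`SortedSet.islice(-BATCH_SIZE)` iterates the largest `BATCH_SIZE` elements still in the set (in ascending order), which is what produces the reverse-order batching the comment describes. A standalone sketch with made-up chain IDs:

    from sortedcontainers import SortedSet

    BATCH_SIZE = 3
    chains = SortedSet([1, 4, 9, 12, 15, 20, 22])

    while chains:
        # islice(-BATCH_SIZE) yields the last BATCH_SIZE elements, i.e. the
        # highest chain IDs still to fetch.
        batch = list(chains.islice(-BATCH_SIZE))
        chains.difference_update(batch)
        print(batch)

    # Output:
    # [15, 20, 22]
    # [4, 9, 12]
    # [1]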
@@ -387,6 +456,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             links: Dict[int, List[Tuple[int, int, int]]] = {}
+
+            cache_entries: Dict[int, _ChainLinksCacheEntry] = {}
             for (
                 origin_chain_id,
                 origin_sequence_number,
@@ -397,7 +468,28 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
                     (origin_sequence_number, target_chain_id, target_sequence_number)
                 )

-            chains_to_fetch.difference_update(links)
+                if cache:
+                    origin_entry = cache_entries.setdefault(
+                        origin_chain_id, _ChainLinksCacheEntry()
+                    )
+                    target_entry = cache_entries.setdefault(
+                        target_chain_id, _ChainLinksCacheEntry()
+                    )
+
+                    origin_entry.links.append(
+                        (
+                            origin_sequence_number,
+                            target_chain_id,
+                            target_sequence_number,
+                            target_entry,
+                        )
+                    )
+
+            if cache:
+                for chain_id, entry in cache_entries.items():
+                    if chain_id not in cache:
+                        cache[chain_id] = entry
+
+            chains_to_fetch_sorted.difference_update(links)

             yield links
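To see how a caller consumes the yielded batches, here is a hand-rolled sketch (invented link data; the folding mirrors, in simplified form, the `chains: Dict[int, int]` bookkeeping in the earlier hunk):

    from typing import Dict, List, Tuple

    # Stand-ins for two batches yielded by _get_chain_links.
    yielded: List[Dict[int, List[Tuple[int, int, int]]]] = [
        {1: [(3, 2, 7)]},  # chain 1 @ seq 3 links to chain 2 @ seq 7
        {2: [(5, 4, 1)]},  # chain 2 @ seq 5 links to chain 4 @ seq 1
    ]

    event_chains = {1: 3, 2: 5}  # chain ID -> max sequence number of our events
    chains: Dict[int, int] = {}  # chain ID -> max *reachable* sequence number

    for links in yielded:
        for chain_id in links:
            if chain_id not in event_chains:
                continue
            max_seq = event_chains[chain_id]
            for origin_seq, target_id, target_seq in links[chain_id]:
                # A link is usable only if it hangs off a point at or below
                # the event's position in its own chain.
                if origin_seq <= max_seq:
                    chains[target_id] = max(chains.get(target_id, 0), target_seq)

    print(chains)  # {2: 7, 4: 1}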
@@ -589,7 +681,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         # are reachable from any event.
-        # (We need to take a copy of `seen_chains` as the function mutates it)
-        for links in self._get_chain_links(txn, set(seen_chains)):
+        for links in self._get_chain_links(txn, seen_chains, self._chain_links_cache):
             for chains in set_to_chain:
                 for chain_id in links:
                     if chain_id not in chains:

tests/storage/test_purge.py

@@ -25,6 +25,7 @@ from synapse.rest.client import room
 from synapse.server import HomeServer
 from synapse.util import Clock

+from tests.test_utils.event_injection import inject_event
 from tests.unittest import HomeserverTestCase
@@ -128,3 +129,76 @@ class PurgeTests(HomeserverTestCase):
         self.store._invalidate_local_get_event_cache(create_event.event_id)
         self.get_failure(self.store.get_event(create_event.event_id), NotFoundError)
         self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)
+
+    def test_state_groups_state_decreases(self) -> None:
+        """Purging old history should shrink `state_groups_state` for the room."""
+        response = self.helper.send(self.room_id, body="first")
+        first_event_id = response["event_id"]
+
+        batches = []
+
+        previous_event_id = first_event_id
+        for i in range(50):
+            state_event1 = self.get_success(
+                inject_event(
+                    self.hs,
+                    type="test.state",
+                    sender=self.user_id,
+                    state_key="",
+                    room_id=self.room_id,
+                    content={"key": i, "e": 1},
+                    prev_event_ids=[previous_event_id],
+                    origin_server_ts=1,
+                )
+            )
+
+            state_event2 = self.get_success(
+                inject_event(
+                    self.hs,
+                    type="test.state",
+                    sender=self.user_id,
+                    state_key="",
+                    room_id=self.room_id,
+                    content={"key": i, "e": 2},
+                    prev_event_ids=[previous_event_id],
+                    origin_server_ts=2,
+                )
+            )
+
+            message_event = self.get_success(
+                inject_event(
+                    self.hs,
+                    type="dummy_event",
+                    sender=self.user_id,
+                    room_id=self.room_id,
+                    content={},
+                    prev_event_ids=[state_event1.event_id, state_event2.event_id],
+                )
+            )
+
+            token = self.get_success(
+                self.store.get_topological_token_for_event(state_event1.event_id)
+            )
+            batches.append(token)
+
+            previous_event_id = message_event.event_id
+
+        self.helper.send(self.room_id, body="last event")
+
+        def count_state_groups() -> int:
+            sql = "SELECT COUNT(*) FROM state_groups_state WHERE room_id = ?"
+            rows = self.get_success(
+                self.store.db_pool.execute("count_state_groups", sql, self.room_id)
+            )
+            return rows[0][0]
+
+        count_before = count_state_groups()
+
+        for token in batches:
+            token_str = self.get_success(
+                token.to_string(self.hs.get_datastores().main)
+            )
+            self.get_success(
+                self._storage_controllers.purge_events.purge_history(
+                    self.room_id, token_str, False
+                )
+            )
+
+        # Each purged batch should free state groups, so the row count drops.
+        self.assertLess(count_state_groups(), count_before)