Compare commits

...

9 Commits

Author SHA1 Message Date
Erik Johnston
90e1df262d Handle SQLite 2024-06-26 13:14:48 +01:00
Erik Johnston
b8b6fe0da3 Remove debug logging 2024-06-26 13:14:42 +01:00
Erik Johnston
fe9fa90af4 Add a cache to auth links 2024-05-17 10:35:03 +01:00
Erik Johnston
4ffe5a4459 Up batch size 2024-05-14 10:26:31 +01:00
Erik Johnston
b58ed63884 Go faster stripes 2024-05-13 13:10:16 +01:00
Erik Johnston
2d4cea496b Debug logging 2024-05-13 09:54:32 +01:00
Erik Johnston
ca79b4d87d Use a sortedset instead 2024-05-09 10:58:00 +01:00
Erik Johnston
202a09cdb3 Newsfile 2024-05-08 16:05:24 +01:00
Erik Johnston
db25e30a25 Perf improvement to getting auth chains 2024-05-08 16:04:35 +01:00
3 changed files with 179 additions and 7 deletions

1
changelog.d/17169.misc Normal file
View File

@@ -0,0 +1 @@
Add a database performance improvement when fetching auth chains.

View File

@@ -21,6 +21,7 @@
import datetime
import itertools
import logging
import time
from queue import Empty, PriorityQueue
from typing import (
TYPE_CHECKING,
@@ -39,6 +40,7 @@ from typing import (
import attr
from prometheus_client import Counter, Gauge
from sortedcontainers import SortedSet
from synapse.api.constants import MAX_DEPTH
from synapse.api.errors import StoreError
@@ -118,6 +120,11 @@ class BackfillQueueNavigationItem:
type: str
# Cache entry for the auth-chain-links LRU cache: for a single origin chain,
# the links discovered from it as tuples of (origin sequence number, target
# chain ID, target sequence number, target chain's cache entry).  Holding a
# direct reference to the target's entry lets a cache hit walk the link graph
# in memory without further database queries.
@attr.s(frozen=True, slots=True, auto_attribs=True)
class _ChainLinksCacheEntry:
    # NOTE(review): instances are frozen but `links` is appended to in place
    # after creation — attrs' freezing only prevents attribute rebinding, not
    # mutation of the list itself.  Confirm this is intentional.
    links: List[Tuple[int, int, int, "_ChainLinksCacheEntry"]] = attr.Factory(list)
class _NoChainCoverIndex(Exception):
def __init__(self, room_id: str):
super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,))
@@ -138,6 +145,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
self.hs = hs
self._chain_links_cache: LruCache[int, _ChainLinksCacheEntry] = LruCache(
max_size=10000, cache_name="chain_links_cache"
)
if hs.config.worker.run_background_tasks:
hs.get_clock().looping_call(
self._delete_old_forward_extrem_cache, 60 * 60 * 1000
@@ -283,7 +294,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# A map from chain ID to max sequence number *reachable* from any event ID.
chains: Dict[int, int] = {}
for links in self._get_chain_links(txn, set(event_chains.keys())):
for links in self._get_chain_links(
txn, event_chains.keys(), self._chain_links_cache
):
for chain_id in links:
if chain_id not in event_chains:
continue
@@ -335,7 +348,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
@classmethod
def _get_chain_links(
cls, txn: LoggingTransaction, chains_to_fetch: Set[int]
cls,
txn: LoggingTransaction,
chains_to_fetch: Collection[int],
cache: Optional[LruCache[int, _ChainLinksCacheEntry]] = None,
) -> Generator[Dict[int, List[Tuple[int, int, int]]], None, None]:
"""Fetch all auth chain links from the given set of chains, and all
links from those chains, recursively.
@@ -347,12 +363,55 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
of origin sequence number, target chain ID and target sequence number.
"""
found_cached_chains = set()
if cache:
entries: Dict[int, _ChainLinksCacheEntry] = {}
for chain_id in chains_to_fetch:
entry = cache.get(chain_id)
if entry:
entries[chain_id] = entry
cached_links: Dict[int, List[Tuple[int, int, int]]] = {}
while entries:
origin_chain_id, entry = entries.popitem()
for (
origin_sequence_number,
target_chain_id,
target_sequence_number,
target_entry,
) in entry.links:
if target_chain_id in found_cached_chains:
continue
found_cached_chains.add(target_chain_id)
cache.get(chain_id)
entries[chain_id] = target_entry
cached_links.setdefault(origin_chain_id, []).append(
(
origin_sequence_number,
target_chain_id,
target_sequence_number,
)
)
yield cached_links
# This query is structured to first get all chain IDs reachable, and
# then pull out all links from those chains. This does pull out more
# rows than is strictly necessary, however there isn't a way of
# structuring the recursive part of query to pull out the links without
# also returning large quantities of redundant data (which can make it a
# lot slower).
if isinstance(txn.database_engine, PostgresEngine):
# JIT and sequential scans sometimes get hit on this code path, which
# can make the queries much more expensive
txn.execute("SET LOCAL jit = off")
txn.execute("SET LOCAL enable_seqscan = off")
sql = """
WITH RECURSIVE links(chain_id) AS (
SELECT
@@ -371,9 +430,22 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
"""
while chains_to_fetch:
batch2 = tuple(itertools.islice(chains_to_fetch, 1000))
chains_to_fetch.difference_update(batch2)
# We fetch the links in batches. Separate batches will likely fetch the
# same set of links (e.g. they'll always pull in the links to create
# event). To try and minimize the amount of redundant links, we query
# the chain IDs in reverse order, as there will be a correlation between
# the order of chain IDs and links (i.e., higher chain IDs are more
# likely to depend on lower chain IDs than vice versa).
BATCH_SIZE = 5000
chains_to_fetch_sorted = SortedSet(chains_to_fetch)
chains_to_fetch_sorted.difference_update(found_cached_chains)
start_block = time.monotonic()
while chains_to_fetch_sorted:
batch2 = list(chains_to_fetch_sorted.islice(-BATCH_SIZE))
chains_to_fetch_sorted.difference_update(batch2)
clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch2
)
@@ -381,6 +453,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
links: Dict[int, List[Tuple[int, int, int]]] = {}
cache_entries: Dict[int, _ChainLinksCacheEntry] = {}
for (
origin_chain_id,
origin_sequence_number,
@@ -391,10 +465,33 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
(origin_sequence_number, target_chain_id, target_sequence_number)
)
chains_to_fetch.difference_update(links)
if cache:
origin_entry = cache_entries.setdefault(
origin_chain_id, _ChainLinksCacheEntry()
)
target_entry = cache_entries.setdefault(
target_chain_id, _ChainLinksCacheEntry()
)
origin_entry.links.append(
(
origin_sequence_number,
target_chain_id,
target_sequence_number,
target_entry,
)
)
if cache:
for chain_id, entry in cache_entries.items():
if chain_id not in cache:
cache[chain_id] = entry
chains_to_fetch_sorted.difference_update(links)
yield links
end_block = time.monotonic()
def _get_auth_chain_ids_txn(
self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
) -> Set[str]:
@@ -581,7 +678,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# are reachable from any event.
# (We need to take a copy of `seen_chains` as the function mutates it)
for links in self._get_chain_links(txn, set(seen_chains)):
for links in self._get_chain_links(txn, seen_chains, self._chain_links_cache):
for chains in set_to_chain:
for chain_id in links:
if chain_id not in chains:

View File

@@ -25,6 +25,7 @@ from synapse.rest.client import room
from synapse.server import HomeServer
from synapse.util import Clock
from tests.test_utils.event_injection import inject_event
from tests.unittest import HomeserverTestCase
@@ -128,3 +129,76 @@ class PurgeTests(HomeserverTestCase):
self.store._invalidate_local_get_event_cache(create_event.event_id)
self.get_failure(self.store.get_event(create_event.event_id), NotFoundError)
self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)
def test_state_groups_state_decreases(self) -> None:
    """Purging history in batches should shrink `state_groups_state`.

    Builds a room with many state changes, including pairs of conflicting
    state events branching from the same prev event (which forces state
    resolution and hence non-trivial state groups), then purges history up
    to successively later tokens and checks the number of rows in
    `state_groups_state` for the room never grows and strictly decreases
    overall.
    """
    response = self.helper.send(self.room_id, body="first")
    first_event_id = response["event_id"]

    batches = []

    previous_event_id = first_event_id
    for i in range(50):
        # Two conflicting state events with the same prev event, so state
        # resolution has to merge the two branches.
        state_event1 = self.get_success(
            inject_event(
                self.hs,
                type="test.state",
                sender=self.user_id,
                state_key="",
                room_id=self.room_id,
                content={"key": i, "e": 1},
                prev_event_ids=[previous_event_id],
                origin_server_ts=1,
            )
        )
        state_event2 = self.get_success(
            inject_event(
                self.hs,
                type="test.state",
                sender=self.user_id,
                state_key="",
                room_id=self.room_id,
                content={"key": i, "e": 2},
                prev_event_ids=[previous_event_id],
                origin_server_ts=2,
            )
        )

        # A message event that joins the two branches back together.
        message_event = self.get_success(
            inject_event(
                self.hs,
                type="dummy_event",
                sender=self.user_id,
                room_id=self.room_id,
                content={},
                prev_event_ids=[state_event1.event_id, state_event2.event_id],
            )
        )

        token = self.get_success(
            self.store.get_topological_token_for_event(state_event1.event_id)
        )
        batches.append(token)

        previous_event_id = message_event.event_id

    self.helper.send(self.room_id, body="last event")

    def count_state_groups() -> int:
        # Row count for this room in the de-normalised state table.
        sql = "SELECT COUNT(*) FROM state_groups_state WHERE room_id = ?"
        rows = self.get_success(
            self.store.db_pool.execute("test_deduplicate_joins", sql, self.room_id)
        )
        return rows[0][0]

    initial_count = count_state_groups()
    previous_count = initial_count

    # Purging progressively more history must never increase the row count.
    for token in batches:
        token_str = self.get_success(token.to_string(self.hs.get_datastores().main))
        self.get_success(
            self._storage_controllers.purge_events.purge_history(
                self.room_id, token_str, False
            )
        )

        current_count = count_state_groups()
        self.assertLessEqual(current_count, previous_count)
        previous_count = current_count

    # After purging everything up to the final batch, the table must have
    # strictly shrunk.
    self.assertLess(previous_count, initial_count)