Fix case where get_partial_current_state_deltas could return >100 rows (#18960)

This commit is contained in:
Andrew Morgan
2025-11-26 17:17:04 +00:00
committed by GitHub
parent c928347779
commit 703464c1f7
4 changed files with 378 additions and 32 deletions

changelog.d/18960.bugfix (new file)

@@ -0,0 +1 @@
+Fix a bug in the database function for fetching state deltas that could result in unnecessarily long query times.


@@ -683,7 +683,7 @@ class StateStorageController:
         # https://github.com/matrix-org/synapse/issues/13008
         return await self.stores.main.get_partial_current_state_deltas(
-            prev_stream_id, max_stream_id
+            prev_stream_id, max_stream_id, limit=100
         )
 
     @trace
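For orientation, here is a minimal sketch (not from this commit; `store` and `drain_state_deltas` are illustrative names) of the caller-side contract that the `limit` parameter supports: page through the stream until the returned clipped stream id reaches `max_stream_id`, relying on the guarantee below that each call advances.

    async def drain_state_deltas(store, max_stream_id: int, prev_stream_id: int = 0) -> None:
        while True:
            clipped_stream_id, deltas = await store.get_partial_current_state_deltas(
                prev_stream_id, max_stream_id, limit=100
            )
            for delta in deltas:
                ...  # process each StateDelta here
            if clipped_stream_id == max_stream_id:
                break  # caught up with the stream
            # Progress is guaranteed: clipped_stream_id > prev_stream_id.
            prev_stream_id = clipped_stream_id
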

@@ -78,27 +78,41 @@ class StateDeltasStore(SQLBaseStore):
         )
 
     async def get_partial_current_state_deltas(
-        self, prev_stream_id: int, max_stream_id: int
+        self, prev_stream_id: int, max_stream_id: int, limit: int = 100
     ) -> tuple[int, list[StateDelta]]:
-        """Fetch a list of room state changes since the given stream id
+        """Fetch a list of room state changes since the given stream id.
 
         This may be the partial state if we're lazy joining the room.
 
+        This method takes care to handle state deltas that share the same
+        `stream_id`. That can happen when persisting state in a batch,
+        potentially as the result of state resolution (both adding new state
+        and undoing previous state).
+
+        State deltas are grouped by `stream_id`. When hitting the given `limit`
+        would return only part of a "group" of state deltas, that entire group
+        is omitted. Thus, this function may return *up to* `limit` state deltas,
+        or slightly more when a single group itself exceeds `limit`.
+
         Args:
             prev_stream_id: point to get changes since (exclusive)
             max_stream_id: the point that we know has been correctly persisted
                - ie, an upper limit to return changes from.
+            limit: the maximum number of rows to return.
 
         Returns:
             A tuple consisting of:
                 - the stream id which these results go up to
                 - list of current_state_delta_stream rows. If it is empty, we are
                   up to date.
-
-            A maximum of 100 rows will be returned.
         """
         prev_stream_id = int(prev_stream_id)
 
+        if limit <= 0:
+            raise ValueError(
+                "Invalid `limit` passed to `get_partial_current_state_deltas`"
+            )
+
         # check we're not going backwards
         assert prev_stream_id <= max_stream_id, (
             f"New stream id {max_stream_id} is smaller than prev stream id {prev_stream_id}"
@@ -115,45 +129,62 @@ class StateDeltasStore(SQLBaseStore):
         def get_current_state_deltas_txn(
             txn: LoggingTransaction,
         ) -> tuple[int, list[StateDelta]]:
-            # First we calculate the max stream id that will give us less than
-            # N results.
-            # We arbitrarily limit to 100 stream_id entries to ensure we don't
-            # select toooo many.
-            sql = """
-                SELECT stream_id, count(*)
+            # First we group state deltas by `stream_id` and calculate which
+            # groups can be returned without exceeding the provided `limit`.
+            sql_grouped = """
+                SELECT stream_id, COUNT(*) AS c
                 FROM current_state_delta_stream
                 WHERE stream_id > ? AND stream_id <= ?
                 GROUP BY stream_id
-                ORDER BY stream_id ASC
-                LIMIT 100
+                ORDER BY stream_id
+                LIMIT ?
             """
-            txn.execute(sql, (prev_stream_id, max_stream_id))
-
-            total = 0
-
-            for stream_id, count in txn:
-                total += count
-                if total > 100:
-                    # We arbitrarily limit to 100 entries to ensure we don't
-                    # select toooo many.
-                    logger.debug(
-                        "Clipping current_state_delta_stream rows to stream_id %i",
-                        stream_id,
-                    )
-                    clipped_stream_id = stream_id
-                    break
-            else:
-                # if there's no problem, we may as well go right up to the max_stream_id
-                clipped_stream_id = max_stream_id
-
-            # Now actually get the deltas
-            sql = """
+            group_limit = limit + 1
+            txn.execute(sql_grouped, (prev_stream_id, max_stream_id, group_limit))
+            grouped_rows = txn.fetchall()
+
+            if not grouped_rows:
+                # Nothing to return in the range; we are up to date through max_stream_id.
+                return max_stream_id, []
+
+            # Always retrieve the first group, at the bare minimum. This ensures the
+            # caller always makes progress, even if a single group exceeds `limit`.
+            fetch_upto_stream_id, included_rows = grouped_rows[0]
+
+            # Determine which other groups we can retrieve at the same time,
+            # without blowing the budget.
+            included_all_groups = True
+            for stream_id, count in grouped_rows[1:]:
+                if included_rows + count > limit:
+                    included_all_groups = False
+                    break
+                included_rows += count
+                fetch_upto_stream_id = stream_id
+
+            # If we retrieved fewer groups than the limit *and* we didn't hit the
+            # `LIMIT ?` cap on the grouping query, we know we've caught up with
+            # the stream.
+            caught_up_with_stream = (
+                included_all_groups and len(grouped_rows) < group_limit
+            )
+
+            # At this point we should have advanced, or bailed out early above.
+            assert fetch_upto_stream_id != prev_stream_id
+
+            # 2) Fetch the actual rows for only the included stream_id groups.
+            sql_rows = """
                 SELECT stream_id, room_id, type, state_key, event_id, prev_event_id
                 FROM current_state_delta_stream
                 WHERE ? < stream_id AND stream_id <= ?
                 ORDER BY stream_id ASC
             """
-            txn.execute(sql, (prev_stream_id, clipped_stream_id))
+            txn.execute(sql_rows, (prev_stream_id, fetch_upto_stream_id))
+            rows = txn.fetchall()
+
+            clipped_stream_id = (
+                max_stream_id if caught_up_with_stream else fetch_upto_stream_id
+            )
+
             return clipped_stream_id, [
                 StateDelta(
                     stream_id=row[0],
@@ -163,7 +194,7 @@ class StateDeltasStore(SQLBaseStore):
                     event_id=row[4],
                     prev_event_id=row[5],
                 )
-                for row in txn.fetchall()
+                for row in rows
             ]
 
         return await self.db_pool.runInteraction(
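
The selection logic above can be read as a small pure function over the (stream_id, count) rows returned by the grouping query. The following sketch (`select_groups` is a hypothetical helper, not part of the commit) mirrors its three guarantees: the first group is always taken, later groups are taken only while they fit within `limit`, and "caught up" is reported only when no group was excluded and the grouping query was not truncated by its own `LIMIT`.

    def select_groups(
        grouped_rows: list[tuple[int, int]], limit: int, group_limit: int
    ) -> tuple[int, bool]:
        """Return (fetch_upto_stream_id, caught_up_with_stream)."""
        # The first group is always included so the caller makes progress.
        fetch_upto_stream_id, included_rows = grouped_rows[0]
        included_all_groups = True
        for stream_id, count in grouped_rows[1:]:
            if included_rows + count > limit:
                # Taking this group would overshoot the budget; stop here.
                included_all_groups = False
                break
            included_rows += count
            fetch_upto_stream_id = stream_id
        # Only "caught up" if nothing was excluded and the grouping query
        # itself returned fewer than `group_limit` groups.
        caught_up = included_all_groups and len(grouped_rows) < group_limit
        return fetch_upto_stream_id, caught_up

    # Groups of sizes 1 and 4 with limit=4: only the first group is taken,
    # matching the second test below.
    assert select_groups([(2, 1), (3, 4)], limit=4, group_limit=5) == (2, False)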


@@ -19,6 +19,7 @@
 #
 #
+import json
 import logging
 from typing import cast
@@ -33,6 +34,7 @@ from synapse.server import HomeServer
 from synapse.types import JsonDict, RoomID, StateMap, UserID
 from synapse.types.state import StateFilter
 from synapse.util.clock import Clock
+from synapse.util.stringutils import random_string
 
 from tests.unittest import HomeserverTestCase
@@ -643,3 +645,315 @@ class StateStoreTestCase(HomeserverTestCase):
             ),
         )
         self.assertEqual(context.state_group_before_event, groups[0][0])

class CurrentStateDeltaStreamTestCase(HomeserverTestCase):
    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
        super().prepare(reactor, clock, hs)
        self.store = hs.get_datastores().main
        self.storage = hs.get_storage_controllers()
        self.state_datastore = self.storage.state.stores.state
        self.event_creation_handler = hs.get_event_creation_handler()
        self.event_builder_factory = hs.get_event_builder_factory()

        # Create a made-up room and a user.
        self.alice_user_id = UserID.from_string("@alice:test")
        self.room = RoomID.from_string("!abc1234:test")

        self.get_success(
            self.store.store_room(
                self.room.to_string(),
                room_creator_user_id="@creator:test",
                is_public=True,
                room_version=RoomVersions.V1,
            )
        )

    def inject_state_event(
        self, room: RoomID, sender: UserID, typ: str, state_key: str, content: JsonDict
    ) -> EventBase:
        builder = self.event_builder_factory.for_room_version(
            RoomVersions.V1,
            {
                "type": typ,
                "sender": sender.to_string(),
                "state_key": state_key,
                "room_id": room.to_string(),
                "content": content,
            },
        )

        event, unpersisted_context = self.get_success(
            self.event_creation_handler.create_new_client_event(builder)
        )
        context = self.get_success(unpersisted_context.persist(event))

        assert self.storage.persistence is not None
        self.get_success(self.storage.persistence.persist_event(event, context))

        return event
    def test_get_partial_current_state_deltas_limit(self) -> None:
        """
        Tests that `get_partial_current_state_deltas` returns at most `limit` rows.

        Regression test for https://github.com/element-hq/synapse/pull/18960.
        """
        # Inject a create event which other events can auth with.
        self.inject_state_event(self.room, self.alice_user_id, EventTypes.Create, "", {})

        limit = 2

        # Make `2 * limit` state changes in the room, resulting in `2 * limit + 1`
        # total state events (including the create event) in the room.
        for i in range(limit * 2):
            self.inject_state_event(
                self.room,
                self.alice_user_id,
                EventTypes.Name,
                "",
                {"name": f"rename #{i}"},
            )

        # Call the function under test. This must return <= `limit` rows.
        max_stream_id = self.store.get_room_max_stream_ordering()
        clipped_stream_id, deltas = self.get_success(
            self.store.get_partial_current_state_deltas(
                prev_stream_id=0,
                max_stream_id=max_stream_id,
                limit=limit,
            )
        )
        self.assertLessEqual(
            len(deltas), limit, f"Returned {len(deltas)} rows, expected at most {limit}"
        )

        # Advancing from the clipped point should eventually drain the remainder.
        # Make sure we make progress and don't get stuck.
        if deltas:
            next_prev = clipped_stream_id
            next_clipped, next_deltas = self.get_success(
                self.store.get_partial_current_state_deltas(
                    prev_stream_id=next_prev, max_stream_id=max_stream_id, limit=limit
                )
            )
            self.assertNotEqual(
                next_clipped, clipped_stream_id, "Did not advance clipped_stream_id"
            )
            # The second page should still respect the limit.
            self.assertLessEqual(len(next_deltas), limit)
    def test_non_unique_stream_ids_in_current_state_delta_stream(self) -> None:
        """
        Tests that `get_partial_current_state_deltas` always returns entire
        groups of state deltas (grouped by `stream_id`), and never part of one.

        We check by passing a `limit` to the function that, if followed
        blindly, would split a group of state deltas that share a `stream_id`.
        The test passes if that group is not returned at all (because returning
        it would overshoot the limit of returned state deltas).

        Regression test for https://github.com/element-hq/synapse/pull/18960.
        """
        # Inject a create event to start with.
        self.inject_state_event(self.room, self.alice_user_id, EventTypes.Create, "", {})

        # Then inject one "real" m.room.name event. This will give us a stream_id
        # that we can create some more (fake) events with.
        self.inject_state_event(
            self.room,
            self.alice_user_id,
            EventTypes.Name,
            "",
            {"name": "rename #1"},
        )

        # Get the stream_id of the last-inserted event.
        max_stream_id = self.store.get_room_max_stream_ordering()

        # Make 3 more state changes in the room, resulting in 5 total state
        # events (including the create event and the first name update) in
        # the room.
        #
        # All of these state deltas have the same `stream_id` as the original
        # name event. Edit the table directly, as that's the simplest way to
        # have them all share the same `stream_id`.
        self.get_success(
            self.store.db_pool.simple_insert_many(
                "current_state_delta_stream",
                keys=(
                    "stream_id",
                    "room_id",
                    "type",
                    "state_key",
                    "event_id",
                    "prev_event_id",
                    "instance_name",
                ),
                values=[
                    (
                        max_stream_id,
                        self.room.to_string(),
                        EventTypes.Name,
                        "",
                        f"${random_string(5)}:test",
                        json.dumps({"name": f"rename #{i}"}),
                        "master",
                    )
                    for i in range(3)
                ],
                desc="inject_room_name_state_events",
            )
        )

        # Call the function under test with a limit of 4. Without the limit, we
        # would return 5 state deltas:
        #
        #   C N N N N
        #   1 2 3 4 5
        #
        # C = m.room.create
        # N = m.room.name
        #
        # With the limit, we should return only the create event, as returning 4
        # state deltas would mean splitting a group:
        #
        #   2 3 3 3 3  - stream IDs/groups
        #   C N N N N
        #   1 2 3 4 X
        clipped_stream_id, deltas = self.get_success(
            self.store.get_partial_current_state_deltas(
                prev_stream_id=0,
                max_stream_id=max_stream_id,
                limit=4,
            )
        )
        # 2 is the stream ID of the m.room.create event.
        self.assertEqual(clipped_stream_id, 2)
        self.assertEqual(
            len(deltas),
            1,
            f"Returned {len(deltas)} rows, expected only one (the create event): {deltas}",
        )

        # Advance once more with our limit of 4. We should now get all 4
        # `m.room.name` state deltas, as they fit under the limit.
        clipped_stream_id, next_deltas = self.get_success(
            self.store.get_partial_current_state_deltas(
                prev_stream_id=clipped_stream_id, max_stream_id=max_stream_id, limit=4
            )
        )
        # 3 is the stream ID of the 4 m.room.name events.
        self.assertEqual(clipped_stream_id, 3)
        self.assertEqual(
            len(next_deltas),
            4,
            f"Returned {len(next_deltas)} rows, expected all 4 m.room.name events: {next_deltas}",
        )
    def test_get_partial_current_state_deltas_does_not_enter_infinite_loop(
        self,
    ) -> None:
        """
        Tests that `get_partial_current_state_deltas` does not repeatedly return
        zero entries when the passed `limit` is smaller than the size of the next
        group of state deltas after the given `prev_stream_id`.
        """
        # Inject a create event to start with.
        self.inject_state_event(self.room, self.alice_user_id, EventTypes.Create, "", {})

        # Then inject one "real" m.room.name event. This will give us a stream_id
        # that we can create some more (fake) events with.
        self.inject_state_event(
            self.room,
            self.alice_user_id,
            EventTypes.Name,
            "",
            {"name": "rename #1"},
        )

        # Get the stream_id of the last-inserted event.
        max_stream_id = self.store.get_room_max_stream_ordering()

        # Make 3 more state changes in the room, resulting in 5 total state
        # events (including the create event and the first name update) in
        # the room.
        #
        # All of these state deltas have the same `stream_id` as the original
        # name event. Edit the table directly, as that's the simplest way to
        # have them all share the same `stream_id`.
        self.get_success(
            self.store.db_pool.simple_insert_many(
                "current_state_delta_stream",
                keys=(
                    "stream_id",
                    "room_id",
                    "type",
                    "state_key",
                    "event_id",
                    "prev_event_id",
                    "instance_name",
                ),
                values=[
                    (
                        max_stream_id,
                        self.room.to_string(),
                        EventTypes.Name,
                        "",
                        f"${random_string(5)}:test",
                        json.dumps({"name": f"rename #{i}"}),
                        "master",
                    )
                    for i in range(3)
                ],
                desc="inject_room_name_state_events",
            )
        )

        # Call the function under test with a limit of 2, starting just after the
        # create event. The next group of state deltas (the 4 m.room.name events,
        # which share a stream_id) is larger than the limit:
        #
        #   2 3 3 3 3  - stream IDs/groups
        #   C N N N N
        #
        # C = m.room.create
        # N = m.room.name
        #
        # The whole group must be returned anyway: returning 0 rows here would
        # leave the caller unable to make progress.
        clipped_stream_id, deltas = self.get_success(
            self.store.get_partial_current_state_deltas(
                prev_stream_id=2,  # Start after the create event (which has stream_id 2).
                max_stream_id=max_stream_id,
                limit=2,  # Less than the size of the next group (which is 4).
            )
        )
        # 3 is the stream ID of the 4 m.room.name events.
        self.assertEqual(clipped_stream_id, 3)
        # We should get all 4 `m.room.name` state deltas, instead of 0, which
        # would leave the caller stuck in an infinite loop.
        self.assertEqual(
            len(deltas),
            4,
            f"Returned {len(deltas)} rows, expected 4 even though it exceeds our limit: {deltas}",
        )