Compare commits


2 Commits

Author        SHA1        Message                            Date
Devon Hudson  7450052e60  Add changelog entry                2025-11-12 13:27:45 -07:00
Devon Hudson  a4345c391e  Remove support for PostgreSQL 13.  2025-11-12 13:24:43 -07:00
27 changed files with 323 additions and 381 deletions


@@ -72,7 +72,7 @@ trial_postgres_tests = [
{
"python-version": "3.10",
"database": "postgres",
"postgres-version": "13",
"postgres-version": "14",
"extras": "all",
},
{


@@ -617,7 +617,7 @@ jobs:
matrix:
include:
- python-version: "3.10"
postgres-version: "13"
postgres-version: "14"
- python-version: "3.14"
postgres-version: "17"


@@ -1 +0,0 @@
Provide additional servers with federation room directory results.


@@ -1 +0,0 @@
Add a shortcut return when there are no events to purge.


@@ -1 +0,0 @@
Remove authentication from `POST /_matrix/client/v1/delayed_events`, and allow calling this endpoint with the update action to take (`send`/`cancel`/`restart`) in the request path instead of the body.


@@ -1 +0,0 @@
Point out which event caused the exception when checking [MSC4293](https://github.com/matrix-org/matrix-spec-proposals/pull/4293) redactions.


@@ -0,0 +1 @@
Remove support for PostgreSQL 13.


@@ -1 +0,0 @@
Restore printing `sentinel` for the log record `request` when no logcontext is active.


@@ -1 +0,0 @@
Add debug logs to track `Clock` utilities.


@@ -1 +0,0 @@
Run background updates on all databases.


@@ -11,7 +11,7 @@ ARG SYNAPSE_VERSION=latest
ARG FROM=matrixdotorg/synapse-workers:$SYNAPSE_VERSION
ARG DEBIAN_VERSION=trixie
FROM docker.io/library/postgres:13-${DEBIAN_VERSION} AS postgres_base
FROM docker.io/library/postgres:14-${DEBIAN_VERSION} AS postgres_base
FROM $FROM
# First of all, we copy postgres server from the official postgres image,
@@ -26,7 +26,7 @@ RUN adduser --system --uid 999 postgres --home /var/lib/postgresql
COPY --from=postgres_base /usr/lib/postgresql /usr/lib/postgresql
COPY --from=postgres_base /usr/share/postgresql /usr/share/postgresql
COPY --from=postgres_base --chown=postgres /var/run/postgresql /var/run/postgresql
ENV PATH="${PATH}:/usr/lib/postgresql/13/bin"
ENV PATH="${PATH}:/usr/lib/postgresql/14/bin"
ENV PGDATA=/var/lib/postgresql/data
# We also initialize the database at build time, rather than runtime, so that it's faster to spin up the image.


@@ -117,6 +117,14 @@ each upgrade are complete before moving on to the next upgrade, to avoid
stacking them up. You can monitor the currently running background updates with
[the Admin API](usage/administration/admin_api/background_updates.html#status).
# Upgrading to v1.143.0
## Dropping support for PostgreSQL 13
In line with our [deprecation policy](deprecation_policy.md), we've dropped
support for PostgreSQL 13, as it is no longer supported upstream.
This release of Synapse requires PostgreSQL 14+.
# Upgrading to v1.142.0
## Python 3.10+ is now required
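As a pre-upgrade sanity check, admins can ask the server for its numeric version before rolling out v1.143.0. A minimal sketch using psycopg2 (the connection string is a placeholder for your deployment):

```python
import psycopg2

# Placeholder DSN: point this at your Synapse database.
conn = psycopg2.connect("dbname=synapse user=synapse_user host=localhost")
with conn, conn.cursor() as cur:
    # PostgreSQL encodes its version as major*10000 + minor, e.g. 14.13 -> 140013.
    cur.execute("SHOW server_version_num")
    (version_num,) = cur.fetchone()
    assert int(version_num) >= 140000, f"PostgreSQL {version_num} is too old for Synapse v1.143.0"
```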


@@ -58,7 +58,6 @@ from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn
from synapse.storage.databases.main import FilteringWorkerStore
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore
from synapse.storage.databases.main.delayed_events import DelayedEventsStore
from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore
from synapse.storage.databases.main.devices import DeviceBackgroundUpdateStore
from synapse.storage.databases.main.e2e_room_keys import EndToEndRoomKeyBackgroundStore
@@ -274,7 +273,6 @@ class Store(
RelationsWorkerStore,
EventFederationWorkerStore,
SlidingSyncStore,
DelayedEventsStore,
):
def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]:
return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)


@@ -450,8 +450,7 @@ async def start(
await _base.start(hs, freeze=freeze)
# TODO: Feels like this should be moved somewhere else.
for db in hs.get_datastores().databases:
db.updates.start_doing_background_updates()
hs.get_datastores().main.db_pool.updates.start_doing_background_updates()
def start_reactor(


@@ -65,6 +65,8 @@ from typing import (
Sequence,
)
from twisted.internet.interfaces import IDelayedCall
from synapse.appservice import (
ApplicationService,
ApplicationServiceState,
@@ -76,7 +78,7 @@ from synapse.events import EventBase
from synapse.logging.context import run_in_background
from synapse.storage.databases.main import DataStore
from synapse.types import DeviceListUpdates, JsonMapping
from synapse.util.clock import Clock, DelayedCallWrapper
from synapse.util.clock import Clock
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -501,7 +503,7 @@ class _Recoverer:
self.service = service
self.callback = callback
self.backoff_counter = 1
self.scheduled_recovery: DelayedCallWrapper | None = None
self.scheduled_recovery: IDelayedCall | None = None
def recover(self) -> None:
delay = 2**self.backoff_counter
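For context, the recoverer above spaces out recovery attempts exponentially via `delay = 2**self.backoff_counter`. An illustration of the resulting delays:

```python
# Illustrative only: recovery delays in seconds as backoff_counter climbs from 1 to 6.
delays = [2**n for n in range(1, 7)]
print(delays)  # [2, 4, 8, 16, 32, 64]
```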


@@ -21,7 +21,6 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import ShadowBanError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.config.workers import MAIN_PROCESS_INSTANCE_NAME
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import set_tag
from synapse.metrics import SERVER_NAME_LABEL, event_processing_positions
@@ -30,9 +29,11 @@ from synapse.replication.http.delayed_events import (
)
from synapse.storage.databases.main.delayed_events import (
DelayedEventDetails,
DelayID,
EventType,
StateKey,
Timestamp,
UserLocalpart,
)
from synapse.storage.databases.main.state_deltas import StateDelta
from synapse.types import (
@@ -398,63 +399,96 @@ class DelayedEventsHandler:
if self._next_send_ts_changed(next_send_ts):
self._schedule_next_at(next_send_ts)
async def cancel(self, request: SynapseRequest, delay_id: str) -> None:
async def cancel(self, requester: Requester, delay_id: str) -> None:
"""
Cancels the scheduled delivery of the matching delayed event.
Args:
requester: The owner of the delayed event to act on.
delay_id: The ID of the delayed event to act on.
Raises:
NotFoundError: if no matching delayed event could be found.
"""
assert self._is_master
await self._delayed_event_mgmt_ratelimiter.ratelimit(
None, request.getClientAddress().host
requester,
(requester.user.to_string(), requester.device_id),
)
await make_deferred_yieldable(self._initialized_from_db)
next_send_ts = await self._store.cancel_delayed_event(delay_id)
next_send_ts = await self._store.cancel_delayed_event(
delay_id=delay_id,
user_localpart=requester.user.localpart,
)
if self._next_send_ts_changed(next_send_ts):
self._schedule_next_at_or_none(next_send_ts)
async def restart(self, request: SynapseRequest, delay_id: str) -> None:
async def restart(self, requester: Requester, delay_id: str) -> None:
"""
Restarts the scheduled delivery of the matching delayed event.
Args:
requester: The owner of the delayed event to act on.
delay_id: The ID of the delayed event to act on.
Raises:
NotFoundError: if no matching delayed event could be found.
"""
assert self._is_master
await self._delayed_event_mgmt_ratelimiter.ratelimit(
None, request.getClientAddress().host
requester,
(requester.user.to_string(), requester.device_id),
)
await make_deferred_yieldable(self._initialized_from_db)
next_send_ts = await self._store.restart_delayed_event(
delay_id, self._get_current_ts()
delay_id=delay_id,
user_localpart=requester.user.localpart,
current_ts=self._get_current_ts(),
)
if self._next_send_ts_changed(next_send_ts):
self._schedule_next_at(next_send_ts)
async def send(self, request: SynapseRequest, delay_id: str) -> None:
async def send(self, requester: Requester, delay_id: str) -> None:
"""
Immediately sends the matching delayed event, instead of waiting for its scheduled delivery.
Args:
requester: The owner of the delayed event to act on.
delay_id: The ID of the delayed event to act on.
Raises:
NotFoundError: if no matching delayed event could be found.
"""
assert self._is_master
await self._delayed_event_mgmt_ratelimiter.ratelimit(
None, request.getClientAddress().host
)
# Use standard request limiter for sending delayed events on-demand,
# as an on-demand send is similar to sending a regular event.
await self._request_ratelimiter.ratelimit(requester)
await make_deferred_yieldable(self._initialized_from_db)
event, next_send_ts = await self._store.process_target_delayed_event(delay_id)
event, next_send_ts = await self._store.process_target_delayed_event(
delay_id=delay_id,
user_localpart=requester.user.localpart,
)
if self._next_send_ts_changed(next_send_ts):
self._schedule_next_at_or_none(next_send_ts)
await self._send_event(event)
await self._send_event(
DelayedEventDetails(
delay_id=DelayID(delay_id),
user_localpart=UserLocalpart(requester.user.localpart),
room_id=event.room_id,
type=event.type,
state_key=event.state_key,
origin_server_ts=event.origin_server_ts,
content=event.content,
device_id=event.device_id,
)
)
async def _send_on_timeout(self) -> None:
self._next_delayed_event_call = None
@@ -577,7 +611,9 @@ class DelayedEventsHandler:
finally:
# TODO: If this is a temporary error, retry. Otherwise, consider notifying clients of the failure
try:
await self._store.delete_processed_delayed_event(event.delay_id)
await self._store.delete_processed_delayed_event(
event.delay_id, event.user_localpart
)
except Exception:
logger.exception("Failed to delete processed delayed event")
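In the hunks above, `cancel` and `restart` move from an IP-keyed management ratelimit to one keyed on the authenticated user and device, while the on-demand `send` switches to the standard request ratelimiter. A sketch of the key change (values are placeholders):

```python
# Before: the management ratelimiter bucketed callers by client IP.
old_key = "203.0.113.7"  # request.getClientAddress().host

# After: it buckets by authenticated user and device instead.
new_key = ("@alice:example.org", "ABCDEFGH")  # (requester.user.to_string(), requester.device_id)
```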


@@ -321,7 +321,16 @@ class DirectoryHandler:
if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room Alias is not hosted on this homeserver")
return await self.get_association(room_alias)
result = await self.get_association_from_room_alias(room_alias)
if result is not None:
return {"room_id": result.room_id, "servers": result.servers}
else:
raise SynapseError(
404,
"Room alias %r not found" % (room_alias.to_string(),),
Codes.NOT_FOUND,
)
async def _update_canonical_alias(
self, requester: Requester, user_id: str, room_id: str, room_alias: RoomAlias
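With the explicit 404 above, a lookup for an unknown alias surfaces as a standard Matrix error. An illustrative response body (the alias is a placeholder):

```python
# What a client would see for a missing alias after this change (illustrative):
expected_error = {
    "errcode": "M_NOT_FOUND",
    "error": "Room alias '#missing:example.org' not found",
}
```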


@@ -619,24 +619,19 @@ class LoggingContextFilter(logging.Filter):
True to include the record in the log output.
"""
context = current_context()
# type-ignore: `context` should never be `None`, but if it somehow ends up
# being, then we end up in a death spiral of infinite loops, so let's check, for
record.request = self._default_request
# Avoid overwriting an existing `server_name` on the record. This is running in
# the context of a global log record filter so there may be 3rd-party code that
# adds their own `server_name` and we don't want to interfere with that
# (clobber).
if not hasattr(record, "server_name"):
record.server_name = "unknown_server_from_no_logcontext"
# context should never be None, but if it somehow ends up being, then
# we end up in a death spiral of infinite loops, so let's check, for
# robustness' sake.
#
# Add some default values to avoid log formatting errors.
if context is None:
record.request = self._default_request # type: ignore[unreachable]
# Avoid overwriting an existing `server_name` on the record. This is running in
# the context of a global log record filter so there may be 3rd-party code that
# adds their own `server_name` and we don't want to interfere with that
# (clobber).
if not hasattr(record, "server_name"):
record.server_name = "unknown_server_from_no_logcontext"
# Otherwise, in the normal, expected case, fill in the log record attributes
# from the logcontext.
else:
if context is not None:
def safe_set(attr: str, value: Any) -> None:
"""


@@ -47,11 +47,14 @@ class UpdateDelayedEventServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.delayed_events_handler = hs.get_delayed_events_handler()
async def on_POST(
self, request: SynapseRequest, delay_id: str
) -> tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
body = parse_json_object_from_request(request)
try:
action = str(body["action"])
@@ -72,65 +75,11 @@ class UpdateDelayedEventServlet(RestServlet):
)
if enum_action == _UpdateDelayedEventAction.CANCEL:
await self.delayed_events_handler.cancel(request, delay_id)
await self.delayed_events_handler.cancel(requester, delay_id)
elif enum_action == _UpdateDelayedEventAction.RESTART:
await self.delayed_events_handler.restart(request, delay_id)
await self.delayed_events_handler.restart(requester, delay_id)
elif enum_action == _UpdateDelayedEventAction.SEND:
await self.delayed_events_handler.send(request, delay_id)
return 200, {}
class CancelDelayedEventServlet(RestServlet):
PATTERNS = client_patterns(
r"/org\.matrix\.msc4140/delayed_events/(?P<delay_id>[^/]+)/cancel$",
releases=(),
)
CATEGORY = "Delayed event management requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
self.delayed_events_handler = hs.get_delayed_events_handler()
async def on_POST(
self, request: SynapseRequest, delay_id: str
) -> tuple[int, JsonDict]:
await self.delayed_events_handler.cancel(request, delay_id)
return 200, {}
class RestartDelayedEventServlet(RestServlet):
PATTERNS = client_patterns(
r"/org\.matrix\.msc4140/delayed_events/(?P<delay_id>[^/]+)/restart$",
releases=(),
)
CATEGORY = "Delayed event management requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
self.delayed_events_handler = hs.get_delayed_events_handler()
async def on_POST(
self, request: SynapseRequest, delay_id: str
) -> tuple[int, JsonDict]:
await self.delayed_events_handler.restart(request, delay_id)
return 200, {}
class SendDelayedEventServlet(RestServlet):
PATTERNS = client_patterns(
r"/org\.matrix\.msc4140/delayed_events/(?P<delay_id>[^/]+)/send$",
releases=(),
)
CATEGORY = "Delayed event management requests"
def __init__(self, hs: "HomeServer"):
super().__init__()
self.delayed_events_handler = hs.get_delayed_events_handler()
async def on_POST(
self, request: SynapseRequest, delay_id: str
) -> tuple[int, JsonDict]:
await self.delayed_events_handler.send(request, delay_id)
await self.delayed_events_handler.send(requester, delay_id)
return 200, {}
@@ -159,7 +108,4 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
# The following can't currently be instantiated on workers.
if hs.config.worker.worker_app is None:
UpdateDelayedEventServlet(hs).register(http_server)
CancelDelayedEventServlet(hs).register(http_server)
RestartDelayedEventServlet(hs).register(http_server)
SendDelayedEventServlet(hs).register(http_server)
DelayedEventsServlet(hs).register(http_server)
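With the action-in-path servlets gone, clients drive all three actions through the single update endpoint, passing the action in the JSON body. A hypothetical client call (host, token, and delay ID are placeholders; the unstable prefix matches the test suite below):

```python
import requests

resp = requests.post(
    "https://synapse.example.org/_matrix/client/unstable/org.matrix.msc4140/delayed_events/syd_abc123",
    headers={"Authorization": "Bearer <access_token>"},
    json={"action": "cancel"},  # one of "cancel", "restart", "send"
)
resp.raise_for_status()
```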


@@ -13,26 +13,18 @@
#
import logging
from typing import TYPE_CHECKING, NewType
from typing import NewType
import attr
from synapse.api.errors import NotFoundError
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
StoreError,
)
from synapse.storage.database import LoggingTransaction, StoreError
from synapse.storage.engines import PostgresEngine
from synapse.types import JsonDict, RoomID
from synapse.util import stringutils
from synapse.util.json import json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -63,27 +55,6 @@ class DelayedEventDetails(EventDetails):
class DelayedEventsStore(SQLBaseStore):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
# Set delayed events to be uniquely identifiable by their delay_id.
# In practice, delay_ids are already unique because they are generated
# from cryptographically strong random strings.
# Therefore, adding this constraint is not expected to ever fail,
# despite the current pkey technically allowing non-unique delay_ids.
self.db_pool.updates.register_background_index_update(
update_name="delayed_events_idx",
index_name="delayed_events_idx",
table="delayed_events",
columns=("delay_id",),
unique=True,
)
async def get_delayed_events_stream_pos(self) -> int:
"""
Gets the stream position of the background process to watch for state events
@@ -163,7 +134,9 @@ class DelayedEventsStore(SQLBaseStore):
async def restart_delayed_event(
self,
*,
delay_id: str,
user_localpart: str,
current_ts: Timestamp,
) -> Timestamp:
"""
@@ -172,6 +145,7 @@ class DelayedEventsStore(SQLBaseStore):
Args:
delay_id: The ID of the delayed event to restart.
user_localpart: The localpart of the delayed event's owner.
current_ts: The current time, which will be used to calculate the new send time.
Returns: The send time of the next delayed event to be sent,
@@ -189,11 +163,13 @@ class DelayedEventsStore(SQLBaseStore):
"""
UPDATE delayed_events
SET send_ts = ? + delay
WHERE delay_id = ? AND NOT is_processed
WHERE delay_id = ? AND user_localpart = ?
AND NOT is_processed
""",
(
current_ts,
delay_id,
user_localpart,
),
)
if txn.rowcount == 0:
@@ -343,15 +319,21 @@ class DelayedEventsStore(SQLBaseStore):
async def process_target_delayed_event(
self,
*,
delay_id: str,
user_localpart: str,
) -> tuple[
DelayedEventDetails,
EventDetails,
Timestamp | None,
]:
"""
Marks for processing the matching delayed event, regardless of its timeout time,
as long as it has not already been marked as such.
Args:
delay_id: The ID of the delayed event to process.
user_localpart: The localpart of the delayed event's owner.
Returns: The details of the matching delayed event,
and the send time of the next delayed event to be sent, if any.
@@ -362,38 +344,39 @@ class DelayedEventsStore(SQLBaseStore):
def process_target_delayed_event_txn(
txn: LoggingTransaction,
) -> tuple[
DelayedEventDetails,
EventDetails,
Timestamp | None,
]:
txn.execute(
"""
UPDATE delayed_events
SET is_processed = TRUE
WHERE delay_id = ? AND NOT is_processed
WHERE delay_id = ? AND user_localpart = ?
AND NOT is_processed
RETURNING
room_id,
event_type,
state_key,
origin_server_ts,
content,
device_id,
user_localpart
device_id
""",
(delay_id,),
(
delay_id,
user_localpart,
),
)
row = txn.fetchone()
if row is None:
raise NotFoundError("Delayed event not found")
event = DelayedEventDetails(
event = EventDetails(
RoomID.from_string(row[0]),
EventType(row[1]),
StateKey(row[2]) if row[2] is not None else None,
Timestamp(row[3]) if row[3] is not None else None,
db_to_json(row[4]),
DeviceID(row[5]) if row[5] is not None else None,
DelayID(delay_id),
UserLocalpart(row[6]),
)
return event, self._get_next_delayed_event_send_ts_txn(txn)
@@ -402,10 +385,19 @@ class DelayedEventsStore(SQLBaseStore):
"process_target_delayed_event", process_target_delayed_event_txn
)
async def cancel_delayed_event(self, delay_id: str) -> Timestamp | None:
async def cancel_delayed_event(
self,
*,
delay_id: str,
user_localpart: str,
) -> Timestamp | None:
"""
Cancels the matching delayed event, i.e. removes it as long as it hasn't been processed.
Args:
delay_id: The ID of the delayed event to cancel.
user_localpart: The localpart of the delayed event's owner.
Returns: The send time of the next delayed event to be sent, if any.
Raises:
@@ -421,6 +413,7 @@ class DelayedEventsStore(SQLBaseStore):
table="delayed_events",
keyvalues={
"delay_id": delay_id,
"user_localpart": user_localpart,
"is_processed": False,
},
)
@@ -480,7 +473,11 @@ class DelayedEventsStore(SQLBaseStore):
"cancel_delayed_state_events", cancel_delayed_state_events_txn
)
async def delete_processed_delayed_event(self, delay_id: DelayID) -> None:
async def delete_processed_delayed_event(
self,
delay_id: DelayID,
user_localpart: UserLocalpart,
) -> None:
"""
Delete the matching delayed event, as long as it has been marked as processed.
@@ -491,6 +488,7 @@ class DelayedEventsStore(SQLBaseStore):
table="delayed_events",
keyvalues={
"delay_id": delay_id,
"user_localpart": user_localpart,
"is_processed": True,
},
desc="delete_processed_delayed_event",
@@ -556,7 +554,7 @@ def _generate_delay_id() -> DelayID:
# We use the following format for delay IDs:
# syd_<random string>
# They are not scoped to user localparts, but the random string
# is expected to be sufficiently random to be globally unique.
# They are scoped to user localparts, so it is possible for
# the same ID to exist for multiple users.
return DelayID(f"syd_{stringutils.random_string(20)}")
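The scoping change above means a delay ID is only meaningful together with its owner's localpart, which is why every lookup in this store now filters on both columns. A sketch of the effective key (values are placeholders):

```python
# Two users may hold the same delay_id; rows are distinguished by the pair.
row_key = ("syd_AbCdEfGhIjKlMnOpQrSt", "alice")  # (delay_id, user_localpart)
```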


@@ -1600,21 +1600,18 @@ class EventsWorkerStore(SQLBaseStore):
if d:
d.redactions.append(redacter)
# check for MSC4293 redactions
# check for MSC4932 redactions
to_check = []
events: list[_EventRow] = []
for e in evs:
try:
event = event_dict.get(e)
if not event:
continue
events.append(event)
event_json = json.loads(event.json)
room_id = event_json.get("room_id")
user_id = event_json.get("sender")
to_check.append((room_id, user_id))
except Exception as exc:
raise InvalidEventError(f"Invalid event {event_id}") from exc
event = event_dict.get(e)
if not event:
continue
events.append(event)
event_json = json.loads(event.json)
room_id = event_json.get("room_id")
user_id = event_json.get("sender")
to_check.append((room_id, user_id))
# likely that some of these events may be for the same room/user combo, in
# which case we don't need to do redundant queries
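A minimal sketch of the deduplication that comment describes, collapsing repeated pairs so each room/user combination is only queried once:

```python
# Illustrative: `to_check` may contain duplicate (room_id, sender) pairs; a set collapses them.
unique_combos = {(room_id, user_id) for room_id, user_id in to_check}
```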


@@ -239,16 +239,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
txn.execute("SELECT event_id, should_delete FROM events_to_purge")
event_rows = txn.fetchall()
if len(event_rows) == 0:
logger.info("[purge] no events found to purge")
# For the sake of cleanliness: drop the temp table.
# This will commit the txn in sqlite, so make sure to keep this actually last.
txn.execute("DROP TABLE events_to_purge")
# no referenced state groups
return set()
logger.info(
"[purge] found %i events before cutoff, of which %i can be deleted",
len(event_rows),


@@ -99,8 +99,8 @@ class PostgresEngine(
allow_unsafe_locale = self.config.get("allow_unsafe_locale", False)
# Are we on a supported PostgreSQL version?
if not allow_outdated_version and self._version < 130000:
raise RuntimeError("Synapse requires PostgreSQL 13 or above.")
if not allow_outdated_version and self._version < 140000:
raise RuntimeError("Synapse requires PostgreSQL 14 or above.")
with db_conn.cursor() as txn:
txn.execute("SHOW SERVER_ENCODING")
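For reference, PostgreSQL reports `server_version_num` as major*10000 + minor, so the bumped gate rejects every 13.x release while accepting 14.0 and later. An illustration with sample values:

```python
# Illustrative values only: 13.16 -> 130016, 14.0 -> 140000, 17.4 -> 170004.
for version_num in (130016, 140000, 170004):
    ok = version_num >= 140000
    print(version_num, "supported" if ok else "requires upgrade to PostgreSQL 14 or above")
```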


@@ -19,7 +19,7 @@
#
#
SCHEMA_VERSION = 93 # remember to update the list below when updating
SCHEMA_VERSION = 92 # remember to update the list below when updating
"""Represents the expectations made by the codebase about the database schema
This should be incremented whenever the codebase changes its requirements on the
@@ -168,15 +168,11 @@ Changes in SCHEMA_VERSION = 91
Changes in SCHEMA_VERSION = 92
- Cleaned up a trigger that was added in #18260 and then reverted.
Changes in SCHEMA_VERSION = 93
- MSC4140: Set delayed events to be uniquely identifiable by their delay ID.
"""
SCHEMA_COMPAT_VERSION = (
# Transitive links are no longer written to `event_auth_chain_links`
# TODO: On the next compat bump, update the primary key of `delayed_events`
84
)
"""Limit on how far the synapse codebase can be rolled back without breaking db compat


@@ -1,15 +0,0 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2025 Element Creations, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(9301, 'delayed_events_idx', '{}');


@@ -14,7 +14,6 @@
#
import logging
from typing import (
Any,
Callable,
@@ -31,14 +30,10 @@ from twisted.internet.task import LoopingCall
from synapse.logging import context
from synapse.types import ISynapseThreadlessReactor
from synapse.util import log_failure
from synapse.util.stringutils import random_string_insecure_fast
P = ParamSpec("P")
logger = logging.getLogger(__name__)
class Clock:
"""
A Clock wraps a Twisted reactor and provides utilities on top of it.
@@ -69,12 +64,7 @@ class Clock:
"""List of active looping calls"""
self._call_id_to_delayed_call: dict[int, IDelayedCall] = {}
"""
Mapping from unique call ID to delayed call.
For "performance", this only tracks a subset of delayed calls: those created
with `call_later` with `call_later_cancel_on_shutdown=True`.
"""
"""Mapping from unique call ID to delayed call"""
self._is_shutdown = False
"""Whether shutdown has been requested by the HomeServer"""
@@ -163,20 +153,11 @@ class Clock:
**kwargs: P.kwargs,
) -> LoopingCall:
"""Common functionality for `looping_call` and `looping_call_now`"""
instance_id = random_string_insecure_fast(5)
if self._is_shutdown:
raise Exception("Cannot start looping call. Clock has been shutdown")
looping_call_context_string = "looping_call"
if now:
looping_call_context_string = "looping_call_now"
def wrapped_f(*args: P.args, **kwargs: P.kwargs) -> Deferred:
logger.debug(
"%s(%s): Executing callback", looping_call_context_string, instance_id
)
assert context.current_context() is context.SENTINEL_CONTEXT, (
"Expected `looping_call` callback from the reactor to start with the sentinel logcontext "
f"but saw {context.current_context()}. In other words, another task shouldn't have "
@@ -220,17 +201,6 @@ class Clock:
d = call.start(msec / 1000.0, now=now)
d.addErrback(log_failure, "Looping call died", consumeErrors=False)
self._looping_calls.append(call)
logger.debug(
"%s(%s): Scheduled looping call every %sms later",
looping_call_context_string,
instance_id,
msec,
# Find out who is scheduling the call which makes it easy to follow in the
# logs.
stack_info=True,
)
return call
def cancel_all_looping_calls(self, consumeErrors: bool = True) -> None:
@@ -256,7 +226,7 @@ class Clock:
*args: Any,
call_later_cancel_on_shutdown: bool = True,
**kwargs: Any,
) -> "DelayedCallWrapper":
) -> IDelayedCall:
"""Call something later
Note that the function will be called with generic `call_later` logcontext, so
@@ -275,79 +245,74 @@ class Clock:
issue, we can just track all delayed calls.
**kwargs: Key arguments to pass to function.
"""
call_id = self._delayed_call_id
self._delayed_call_id = self._delayed_call_id + 1
if self._is_shutdown:
raise Exception("Cannot start delayed call. Clock has been shutdown")
def wrapped_callback(*args: Any, **kwargs: Any) -> None:
logger.debug("call_later(%s): Executing callback", call_id)
def create_wrapped_callback(
track_for_shutdown_cancellation: bool,
) -> Callable[P, None]:
def wrapped_callback(*args: Any, **kwargs: Any) -> None:
assert context.current_context() is context.SENTINEL_CONTEXT, (
"Expected `call_later` callback from the reactor to start with the sentinel logcontext "
f"but saw {context.current_context()}. In other words, another task shouldn't have "
"leaked their logcontext to us."
)
assert context.current_context() is context.SENTINEL_CONTEXT, (
"Expected `call_later` callback from the reactor to start with the sentinel logcontext "
f"but saw {context.current_context()}. In other words, another task shouldn't have "
"leaked their logcontext to us."
)
# Because this is a callback from the reactor, we will be using the
# `sentinel` log context at this point. We want the function to log with
# some logcontext as we want to know which server the logs came from.
#
# We use `PreserveLoggingContext` to prevent our new `call_later`
# logcontext from finishing as soon as we exit this function, in case `f`
# returns an awaitable/deferred which would continue running and may try to
# restore the `call_later` context when it's done (because it's trying to
# adhere to the Synapse logcontext rules.)
#
# This also ensures that we return to the `sentinel` context when we exit
# this function and yield control back to the reactor to avoid leaking the
# current logcontext to the reactor (which would then get picked up and
# associated with the next thing the reactor does)
try:
with context.PreserveLoggingContext(
context.LoggingContext(
name="call_later", server_name=self._server_name
)
):
# We use `run_in_background` to reset the logcontext after `f` (or the
# awaitable returned by `f`) completes to avoid leaking the current
# logcontext to the reactor
context.run_in_background(callback, *args, **kwargs)
finally:
if track_for_shutdown_cancellation:
# We still want to remove the call from the tracking map. Even if
# the callback raises an exception.
self._call_id_to_delayed_call.pop(call_id)
# Because this is a callback from the reactor, we will be using the
# `sentinel` log context at this point. We want the function to log with
# some logcontext as we want to know which server the logs came from.
#
# We use `PreserveLoggingContext` to prevent our new `call_later`
# logcontext from finishing as soon as we exit this function, in case `f`
# returns an awaitable/deferred which would continue running and may try to
# restore the `call_later` context when it's done (because it's trying to
# adhere to the Synapse logcontext rules.)
#
# This also ensures that we return to the `sentinel` context when we exit
# this function and yield control back to the reactor to avoid leaking the
# current logcontext to the reactor (which would then get picked up and
# associated with the next thing the reactor does)
try:
with context.PreserveLoggingContext(
context.LoggingContext(
name="call_later", server_name=self._server_name
)
):
# We use `run_in_background` to reset the logcontext after `f` (or the
# awaitable returned by `f`) completes to avoid leaking the current
# logcontext to the reactor
context.run_in_background(callback, *args, **kwargs)
finally:
if call_later_cancel_on_shutdown:
# We still want to remove the call from the tracking map. Even if
# the callback raises an exception.
self._call_id_to_delayed_call.pop(call_id)
return wrapped_callback
# We can ignore the lint here since this class is the one location callLater should
# be called.
call = self._reactor.callLater(delay, wrapped_callback, *args, **kwargs) # type: ignore[call-later-not-tracked]
logger.debug(
"call_later(%s): Scheduled call for %ss later (tracked for shutdown: %s)",
call_id,
delay,
call_later_cancel_on_shutdown,
# Find out who is scheduling the call which makes it easy to follow in the
# logs.
stack_info=True,
)
wrapped_call = DelayedCallWrapper(call, call_id, self)
if call_later_cancel_on_shutdown:
self._call_id_to_delayed_call[call_id] = wrapped_call
call_id = self._delayed_call_id
self._delayed_call_id = self._delayed_call_id + 1
return wrapped_call
# We can ignore the lint here since this class is the one location callLater
# should be called.
call = self._reactor.callLater(
delay, create_wrapped_callback(True), *args, **kwargs
) # type: ignore[call-later-not-tracked]
call = DelayedCallWrapper(call, call_id, self)
self._call_id_to_delayed_call[call_id] = call
return call
else:
# We can ignore the lint here since this class is the one location callLater should
# be called.
return self._reactor.callLater(
delay, create_wrapped_callback(False), *args, **kwargs
) # type: ignore[call-later-not-tracked]
def cancel_call_later(
self, wrapped_call: "DelayedCallWrapper", ignore_errs: bool = False
) -> None:
def cancel_call_later(self, timer: IDelayedCall, ignore_errs: bool = False) -> None:
try:
logger.debug(
"cancel_call_later: cancelling scheduled call %s", wrapped_call.call_id
)
wrapped_call.delayed_call.cancel()
timer.cancel()
except Exception:
if not ignore_errs:
raise
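With this change, `call_later` hands back Twisted's plain `IDelayedCall`, and cancellation goes through the method above. A hypothetical usage sketch (`clock` and `do_cleanup` are assumed to exist):

```python
# Schedule a callback 30 seconds out, then cancel it before it fires.
timer = clock.call_later(30, do_cleanup)  # now returns an IDelayedCall directly
clock.cancel_call_later(timer, ignore_errs=True)
```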
@@ -362,11 +327,8 @@ class Clock:
"""
# We make a copy here since calling `cancel()` on a delayed_call
# will result in the call removing itself from the map mid-iteration.
for call_id, call in list(self._call_id_to_delayed_call.items()):
for call in list(self._call_id_to_delayed_call.values()):
try:
logger.debug(
"cancel_all_delayed_calls: cancelling scheduled call %s", call_id
)
call.cancel()
except Exception:
if not ignore_errs:
@@ -390,11 +352,8 @@ class Clock:
*args: Positional arguments to pass to function.
**kwargs: Key arguments to pass to function.
"""
instance_id = random_string_insecure_fast(5)
def wrapped_callback(*args: Any, **kwargs: Any) -> None:
logger.debug("call_when_running(%s): Executing callback", instance_id)
# Since this callback can be invoked immediately if the reactor is already
# running, we can't always assume that we're running in the sentinel
# logcontext (i.e. we can't assert that we're in the sentinel context like
@@ -433,14 +392,6 @@ class Clock:
# callWhenRunning should be called.
self._reactor.callWhenRunning(wrapped_callback, *args, **kwargs) # type: ignore[prefer-synapse-clock-call-when-running]
logger.debug(
"call_when_running(%s): Scheduled call",
instance_id,
# Find out who is scheduling the call which makes it easy to follow in the
# logs.
stack_info=True,
)
def add_system_event_trigger(
self,
phase: str,
@@ -466,16 +417,8 @@ class Clock:
Returns:
an ID that can be used to remove this call with `reactor.removeSystemEventTrigger`.
"""
instance_id = random_string_insecure_fast(5)
def wrapped_callback(*args: Any, **kwargs: Any) -> None:
logger.debug(
"add_system_event_trigger(%s): Executing %s %s callback",
instance_id,
phase,
event_type,
)
assert context.current_context() is context.SENTINEL_CONTEXT, (
"Expected `add_system_event_trigger` callback from the reactor to start with the sentinel logcontext "
f"but saw {context.current_context()}. In other words, another task shouldn't have "
@@ -506,16 +449,6 @@ class Clock:
# logcontext to the reactor
context.run_in_background(callback, *args, **kwargs)
logger.debug(
"add_system_event_trigger(%s) for %s %s",
instance_id,
phase,
event_type,
# Find out who is scheduling the call which makes it easy to follow in the
# logs.
stack_info=True,
)
# We can ignore the lint here since this class is the one location
# `addSystemEventTrigger` should be called.
return self._reactor.addSystemEventTrigger(


@@ -28,7 +28,6 @@ from synapse.types import JsonDict
from synapse.util.clock import Clock
from tests import unittest
from tests.server import FakeChannel
from tests.unittest import HomeserverTestCase
PATH_PREFIX = "/_matrix/client/unstable/org.matrix.msc4140/delayed_events"
@@ -128,10 +127,6 @@ class DelayedEventsTestCase(HomeserverTestCase):
)
self.assertEqual(setter_expected, content.get(setter_key), content)
def test_get_delayed_events_auth(self) -> None:
channel = self.make_request("GET", PATH_PREFIX)
self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, channel.result)
@unittest.override_config(
{"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}}
)
@@ -159,6 +154,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
channel = self.make_request(
"POST",
f"{PATH_PREFIX}/",
access_token=self.user1_access_token,
)
self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, channel.result)
@@ -166,6 +162,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
channel = self.make_request(
"POST",
f"{PATH_PREFIX}/abc",
access_token=self.user1_access_token,
)
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
self.assertEqual(
@@ -178,6 +175,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
"POST",
f"{PATH_PREFIX}/abc",
{},
self.user1_access_token,
)
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
self.assertEqual(
@@ -190,6 +188,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
"POST",
f"{PATH_PREFIX}/abc",
{"action": "oops"},
self.user1_access_token,
)
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
self.assertEqual(
@@ -197,21 +196,17 @@ class DelayedEventsTestCase(HomeserverTestCase):
channel.json_body["errcode"],
)
@parameterized.expand(
(
(action, action_in_path)
for action in ("cancel", "restart", "send")
for action_in_path in (True, False)
)
)
def test_update_delayed_event_without_match(
self, action: str, action_in_path: bool
) -> None:
channel = self._update_delayed_event("abc", action, action_in_path)
@parameterized.expand(["cancel", "restart", "send"])
def test_update_delayed_event_without_match(self, action: str) -> None:
channel = self.make_request(
"POST",
f"{PATH_PREFIX}/abc",
{"action": action},
self.user1_access_token,
)
self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, channel.result)
@parameterized.expand((True, False))
def test_cancel_delayed_state_event(self, action_in_path: bool) -> None:
def test_cancel_delayed_state_event(self) -> None:
state_key = "to_never_send"
setter_key = "setter"
@@ -226,7 +221,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
delay_id = channel.json_body.get("delay_id")
assert delay_id is not None
self.assertIsNotNone(delay_id)
self.reactor.advance(1)
events = self._get_delayed_events()
@@ -241,7 +236,12 @@ class DelayedEventsTestCase(HomeserverTestCase):
expect_code=HTTPStatus.NOT_FOUND,
)
channel = self._update_delayed_event(delay_id, "cancel", action_in_path)
channel = self.make_request(
"POST",
f"{PATH_PREFIX}/{delay_id}",
{"action": "cancel"},
self.user1_access_token,
)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertListEqual([], self._get_delayed_events())
@@ -254,11 +254,10 @@ class DelayedEventsTestCase(HomeserverTestCase):
expect_code=HTTPStatus.NOT_FOUND,
)
@parameterized.expand((True, False))
@unittest.override_config(
{"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}}
)
def test_cancel_delayed_event_ratelimit(self, action_in_path: bool) -> None:
def test_cancel_delayed_event_ratelimit(self) -> None:
delay_ids = []
for _ in range(2):
channel = self.make_request(
@@ -269,17 +268,38 @@ class DelayedEventsTestCase(HomeserverTestCase):
)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
delay_id = channel.json_body.get("delay_id")
assert delay_id is not None
self.assertIsNotNone(delay_id)
delay_ids.append(delay_id)
channel = self._update_delayed_event(delay_ids.pop(0), "cancel", action_in_path)
channel = self.make_request(
"POST",
f"{PATH_PREFIX}/{delay_ids.pop(0)}",
{"action": "cancel"},
self.user1_access_token,
)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
channel = self._update_delayed_event(delay_ids.pop(0), "cancel", action_in_path)
args = (
"POST",
f"{PATH_PREFIX}/{delay_ids.pop(0)}",
{"action": "cancel"},
self.user1_access_token,
)
channel = self.make_request(*args)
self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result)
@parameterized.expand((True, False))
def test_send_delayed_state_event(self, action_in_path: bool) -> None:
# Add the current user to the ratelimit overrides, exempting them from ratelimiting.
self.get_success(
self.hs.get_datastores().main.set_ratelimit_for_user(
self.user1_user_id, 0, 0
)
)
# Test that the request isn't ratelimited anymore.
channel = self.make_request(*args)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
def test_send_delayed_state_event(self) -> None:
state_key = "to_send_on_request"
setter_key = "setter"
@@ -294,7 +314,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
delay_id = channel.json_body.get("delay_id")
assert delay_id is not None
self.assertIsNotNone(delay_id)
self.reactor.advance(1)
events = self._get_delayed_events()
@@ -309,7 +329,12 @@ class DelayedEventsTestCase(HomeserverTestCase):
expect_code=HTTPStatus.NOT_FOUND,
)
channel = self._update_delayed_event(delay_id, "send", action_in_path)
channel = self.make_request(
"POST",
f"{PATH_PREFIX}/{delay_id}",
{"action": "send"},
self.user1_access_token,
)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertListEqual([], self._get_delayed_events())
content = self.helper.get_state(
@@ -320,9 +345,8 @@ class DelayedEventsTestCase(HomeserverTestCase):
)
self.assertEqual(setter_expected, content.get(setter_key), content)
@parameterized.expand((True, False))
@unittest.override_config({"rc_message": {"per_second": 2.5, "burst_count": 3}})
def test_send_delayed_event_ratelimit(self, action_in_path: bool) -> None:
@unittest.override_config({"rc_message": {"per_second": 3.5, "burst_count": 4}})
def test_send_delayed_event_ratelimit(self) -> None:
delay_ids = []
for _ in range(2):
channel = self.make_request(
@@ -333,17 +357,38 @@ class DelayedEventsTestCase(HomeserverTestCase):
)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
delay_id = channel.json_body.get("delay_id")
assert delay_id is not None
self.assertIsNotNone(delay_id)
delay_ids.append(delay_id)
channel = self._update_delayed_event(delay_ids.pop(0), "send", action_in_path)
channel = self.make_request(
"POST",
f"{PATH_PREFIX}/{delay_ids.pop(0)}",
{"action": "send"},
self.user1_access_token,
)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
channel = self._update_delayed_event(delay_ids.pop(0), "send", action_in_path)
args = (
"POST",
f"{PATH_PREFIX}/{delay_ids.pop(0)}",
{"action": "send"},
self.user1_access_token,
)
channel = self.make_request(*args)
self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result)
@parameterized.expand((True, False))
def test_restart_delayed_state_event(self, action_in_path: bool) -> None:
# Add the current user to the ratelimit overrides, exempting them from ratelimiting.
self.get_success(
self.hs.get_datastores().main.set_ratelimit_for_user(
self.user1_user_id, 0, 0
)
)
# Test that the request isn't ratelimited anymore.
channel = self.make_request(*args)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
def test_restart_delayed_state_event(self) -> None:
state_key = "to_send_on_restarted_timeout"
setter_key = "setter"
@@ -358,7 +403,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
delay_id = channel.json_body.get("delay_id")
assert delay_id is not None
self.assertIsNotNone(delay_id)
self.reactor.advance(1)
events = self._get_delayed_events()
@@ -373,7 +418,12 @@ class DelayedEventsTestCase(HomeserverTestCase):
expect_code=HTTPStatus.NOT_FOUND,
)
channel = self._update_delayed_event(delay_id, "restart", action_in_path)
channel = self.make_request(
"POST",
f"{PATH_PREFIX}/{delay_id}",
{"action": "restart"},
self.user1_access_token,
)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.reactor.advance(1)
@@ -399,11 +449,10 @@ class DelayedEventsTestCase(HomeserverTestCase):
)
self.assertEqual(setter_expected, content.get(setter_key), content)
@parameterized.expand((True, False))
@unittest.override_config(
{"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}}
)
def test_restart_delayed_event_ratelimit(self, action_in_path: bool) -> None:
def test_restart_delayed_event_ratelimit(self) -> None:
delay_ids = []
for _ in range(2):
channel = self.make_request(
@@ -414,19 +463,37 @@ class DelayedEventsTestCase(HomeserverTestCase):
)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
delay_id = channel.json_body.get("delay_id")
assert delay_id is not None
self.assertIsNotNone(delay_id)
delay_ids.append(delay_id)
channel = self._update_delayed_event(
delay_ids.pop(0), "restart", action_in_path
channel = self.make_request(
"POST",
f"{PATH_PREFIX}/{delay_ids.pop(0)}",
{"action": "restart"},
self.user1_access_token,
)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
channel = self._update_delayed_event(
delay_ids.pop(0), "restart", action_in_path
args = (
"POST",
f"{PATH_PREFIX}/{delay_ids.pop(0)}",
{"action": "restart"},
self.user1_access_token,
)
channel = self.make_request(*args)
self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result)
# Add the current user to the ratelimit overrides, exempting them from ratelimiting.
self.get_success(
self.hs.get_datastores().main.set_ratelimit_for_user(
self.user1_user_id, 0, 0
)
)
# Test that the request isn't ratelimited anymore.
channel = self.make_request(*args)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
def test_delayed_state_is_not_cancelled_by_new_state_from_same_user(
self,
) -> None:
@@ -531,17 +598,6 @@ class DelayedEventsTestCase(HomeserverTestCase):
return content
def _update_delayed_event(
self, delay_id: str, action: str, action_in_path: bool
) -> FakeChannel:
path = f"{PATH_PREFIX}/{delay_id}"
body = {}
if action_in_path:
path += f"/{action}"
else:
body["action"] = action
return self.make_request("POST", path, body)
def _get_path_for_delayed_state(
room_id: str, event_type: str, state_key: str, delay_ms: int