Mirror of https://github.com/element-hq/synapse.git
Synced 2025-12-05 01:10:13 +00:00

Compare commits: devon/back ... devon/depr (2 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | 7450052e60 |  |
|  | a4345c391e |  |
@@ -72,7 +72,7 @@ trial_postgres_tests = [
     {
         "python-version": "3.10",
         "database": "postgres",
-        "postgres-version": "13",
+        "postgres-version": "14",
         "extras": "all",
     },
     {
.github/workflows/tests.yml (vendored): 2 lines changed
@@ -617,7 +617,7 @@ jobs:
       matrix:
         include:
           - python-version: "3.10"
-            postgres-version: "13"
+            postgres-version: "14"

           - python-version: "3.14"
             postgres-version: "17"
@@ -1 +0,0 @@
-Provide additional servers with federation room directory results.

@@ -1 +0,0 @@
-Add a shortcut return when there are no events to purge.

@@ -1 +0,0 @@
-Remove authentication from `POST /_matrix/client/v1/delayed_events`, and allow calling this endpoint with the update action to take (`send`/`cancel`/`restart`) in the request path instead of the body.

@@ -1 +0,0 @@
-Point out which event caused the exception when checking [MSC4293](https://github.com/matrix-org/matrix-spec-proposals/pull/4293) redactions.
changelog.d/19170.removal (new file): 1 line

@@ -0,0 +1 @@
+Remove support for PostgreSQL 13.
@@ -1 +0,0 @@
-Restore printing `sentinel` for the log record `request` when no logcontext is active.

@@ -1 +0,0 @@
-Add debug logs to track `Clock` utilities.

@@ -1 +0,0 @@
-Run background updates on all databases.
@@ -11,7 +11,7 @@ ARG SYNAPSE_VERSION=latest
 ARG FROM=matrixdotorg/synapse-workers:$SYNAPSE_VERSION
 ARG DEBIAN_VERSION=trixie

-FROM docker.io/library/postgres:13-${DEBIAN_VERSION} AS postgres_base
+FROM docker.io/library/postgres:14-${DEBIAN_VERSION} AS postgres_base

 FROM $FROM
 # First of all, we copy postgres server from the official postgres image,
@@ -26,7 +26,7 @@ RUN adduser --system --uid 999 postgres --home /var/lib/postgresql
 COPY --from=postgres_base /usr/lib/postgresql /usr/lib/postgresql
 COPY --from=postgres_base /usr/share/postgresql /usr/share/postgresql
 COPY --from=postgres_base --chown=postgres /var/run/postgresql /var/run/postgresql
-ENV PATH="${PATH}:/usr/lib/postgresql/13/bin"
+ENV PATH="${PATH}:/usr/lib/postgresql/14/bin"
 ENV PGDATA=/var/lib/postgresql/data

 # We also initialize the database at build time, rather than runtime, so that it's faster to spin up the image.
@@ -117,6 +117,14 @@ each upgrade are complete before moving on to the next upgrade, to avoid
 stacking them up. You can monitor the currently running background updates with
 [the Admin API](usage/administration/admin_api/background_updates.html#status).

+# Upgrading to v1.143.0
+
+## Dropping support for PostgreSQL 13
+
+In line with our [deprecation policy](deprecation_policy.md), we've dropped
+support for PostgreSQL 13, as it is no longer supported upstream.
+This release of Synapse requires PostgreSQL 14+.
+
 # Upgrading to v1.142.0

 ## Python 3.10+ is now required
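Before upgrading, an admin can confirm the running server meets the new floor. A minimal sketch, assuming psycopg2 (the driver Synapse itself uses for PostgreSQL) and placeholder connection details:

import psycopg2

# Placeholder connection details for your deployment.
conn = psycopg2.connect(dbname="synapse", user="synapse_user", host="localhost")
with conn.cursor() as cur:
    cur.execute("SHOW server_version_num")
    version_num = int(cur.fetchone()[0])
conn.close()

# PostgreSQL reports major * 10000 + minor, e.g. 140011 for 14.11.
if version_num < 140000:
    print(f"Too old for Synapse v1.143.0: server_version_num={version_num}")
else:
    print(f"OK: server_version_num={version_num}")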
@@ -58,7 +58,6 @@ from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn
 from synapse.storage.databases.main import FilteringWorkerStore
 from synapse.storage.databases.main.account_data import AccountDataWorkerStore
 from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore
-from synapse.storage.databases.main.delayed_events import DelayedEventsStore
 from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore
 from synapse.storage.databases.main.devices import DeviceBackgroundUpdateStore
 from synapse.storage.databases.main.e2e_room_keys import EndToEndRoomKeyBackgroundStore
@@ -274,7 +273,6 @@ class Store(
     RelationsWorkerStore,
     EventFederationWorkerStore,
     SlidingSyncStore,
-    DelayedEventsStore,
 ):
     def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]:
         return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)
@@ -450,8 +450,7 @@ async def start(
     await _base.start(hs, freeze=freeze)

     # TODO: Feels like this should be moved somewhere else.
-    for db in hs.get_datastores().databases:
-        db.updates.start_doing_background_updates()
+    hs.get_datastores().main.db_pool.updates.start_doing_background_updates()


 def start_reactor(
@@ -65,6 +65,8 @@ from typing import (
     Sequence,
 )

+from twisted.internet.interfaces import IDelayedCall
+
 from synapse.appservice import (
     ApplicationService,
     ApplicationServiceState,
@@ -76,7 +78,7 @@ from synapse.events import EventBase
 from synapse.logging.context import run_in_background
 from synapse.storage.databases.main import DataStore
 from synapse.types import DeviceListUpdates, JsonMapping
-from synapse.util.clock import Clock, DelayedCallWrapper
+from synapse.util.clock import Clock

 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -501,7 +503,7 @@ class _Recoverer:
         self.service = service
         self.callback = callback
         self.backoff_counter = 1
-        self.scheduled_recovery: DelayedCallWrapper | None = None
+        self.scheduled_recovery: IDelayedCall | None = None

     def recover(self) -> None:
         delay = 2**self.backoff_counter
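For context on `_Recoverer`: its retry delay is `2**backoff_counter` seconds, doubling per failed attempt. A standalone sketch of that schedule (the six-attempt horizon is illustrative, not from the source):

backoff_counter = 1  # matches the initial value set in __init__ above
delays = []
for _ in range(6):
    delays.append(2**backoff_counter)
    backoff_counter += 1
print(delays)  # [2, 4, 8, 16, 32, 64] seconds between recovery attempts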
@@ -21,7 +21,6 @@ from synapse.api.constants import EventTypes
 from synapse.api.errors import ShadowBanError, SynapseError
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.config.workers import MAIN_PROCESS_INSTANCE_NAME
-from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
 from synapse.logging.opentracing import set_tag
 from synapse.metrics import SERVER_NAME_LABEL, event_processing_positions
@@ -30,9 +29,11 @@ from synapse.replication.http.delayed_events import (
 )
 from synapse.storage.databases.main.delayed_events import (
+    DelayedEventDetails,
+    DelayID,
     EventType,
     StateKey,
     Timestamp,
+    UserLocalpart,
 )
 from synapse.storage.databases.main.state_deltas import StateDelta
 from synapse.types import (
@@ -398,63 +399,96 @@ class DelayedEventsHandler:
         if self._next_send_ts_changed(next_send_ts):
             self._schedule_next_at(next_send_ts)

-    async def cancel(self, request: SynapseRequest, delay_id: str) -> None:
+    async def cancel(self, requester: Requester, delay_id: str) -> None:
         """
         Cancels the scheduled delivery of the matching delayed event.

         Args:
+            requester: The owner of the delayed event to act on.
             delay_id: The ID of the delayed event to act on.

         Raises:
             NotFoundError: if no matching delayed event could be found.
         """
         assert self._is_master
         await self._delayed_event_mgmt_ratelimiter.ratelimit(
-            None, request.getClientAddress().host
+            requester,
+            (requester.user.to_string(), requester.device_id),
         )
         await make_deferred_yieldable(self._initialized_from_db)

-        next_send_ts = await self._store.cancel_delayed_event(delay_id)
+        next_send_ts = await self._store.cancel_delayed_event(
+            delay_id=delay_id,
+            user_localpart=requester.user.localpart,
+        )

         if self._next_send_ts_changed(next_send_ts):
             self._schedule_next_at_or_none(next_send_ts)

-    async def restart(self, request: SynapseRequest, delay_id: str) -> None:
+    async def restart(self, requester: Requester, delay_id: str) -> None:
         """
         Restarts the scheduled delivery of the matching delayed event.

         Args:
+            requester: The owner of the delayed event to act on.
             delay_id: The ID of the delayed event to act on.

         Raises:
             NotFoundError: if no matching delayed event could be found.
         """
         assert self._is_master
         await self._delayed_event_mgmt_ratelimiter.ratelimit(
-            None, request.getClientAddress().host
+            requester,
+            (requester.user.to_string(), requester.device_id),
         )
         await make_deferred_yieldable(self._initialized_from_db)

         next_send_ts = await self._store.restart_delayed_event(
-            delay_id, self._get_current_ts()
+            delay_id=delay_id,
+            user_localpart=requester.user.localpart,
+            current_ts=self._get_current_ts(),
         )

         if self._next_send_ts_changed(next_send_ts):
             self._schedule_next_at(next_send_ts)

-    async def send(self, request: SynapseRequest, delay_id: str) -> None:
+    async def send(self, requester: Requester, delay_id: str) -> None:
         """
         Immediately sends the matching delayed event, instead of waiting for its scheduled delivery.

         Args:
+            requester: The owner of the delayed event to act on.
             delay_id: The ID of the delayed event to act on.

         Raises:
             NotFoundError: if no matching delayed event could be found.
         """
         assert self._is_master
-        await self._delayed_event_mgmt_ratelimiter.ratelimit(
-            None, request.getClientAddress().host
-        )
+        # Use standard request limiter for sending delayed events on-demand,
+        # as an on-demand send is similar to sending a regular event.
+        await self._request_ratelimiter.ratelimit(requester)
         await make_deferred_yieldable(self._initialized_from_db)

-        event, next_send_ts = await self._store.process_target_delayed_event(delay_id)
+        event, next_send_ts = await self._store.process_target_delayed_event(
+            delay_id=delay_id,
+            user_localpart=requester.user.localpart,
+        )

         if self._next_send_ts_changed(next_send_ts):
             self._schedule_next_at_or_none(next_send_ts)

-        await self._send_event(event)
+        await self._send_event(
+            DelayedEventDetails(
+                delay_id=DelayID(delay_id),
+                user_localpart=UserLocalpart(requester.user.localpart),
+                room_id=event.room_id,
+                type=event.type,
+                state_key=event.state_key,
+                origin_server_ts=event.origin_server_ts,
+                content=event.content,
+                device_id=event.device_id,
+            )
+        )

     async def _send_on_timeout(self) -> None:
         self._next_delayed_event_call = None
@@ -577,7 +611,9 @@ class DelayedEventsHandler:
         finally:
             # TODO: If this is a temporary error, retry. Otherwise, consider notifying clients of the failure
             try:
-                await self._store.delete_processed_delayed_event(event.delay_id)
+                await self._store.delete_processed_delayed_event(
+                    event.delay_id, event.user_localpart
+                )
             except Exception:
                 logger.exception("Failed to delete processed delayed event")
@@ -321,7 +321,16 @@ class DirectoryHandler:
         if not self.hs.is_mine(room_alias):
             raise SynapseError(400, "Room Alias is not hosted on this homeserver")

-        return await self.get_association(room_alias)
+        result = await self.get_association_from_room_alias(room_alias)
+
+        if result is not None:
+            return {"room_id": result.room_id, "servers": result.servers}
+        else:
+            raise SynapseError(
+                404,
+                "Room alias %r not found" % (room_alias.to_string(),),
+                Codes.NOT_FOUND,
+            )

     async def _update_canonical_alias(
         self, requester: Requester, user_id: str, room_id: str, room_alias: RoomAlias
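The restored branch turns an unknown local alias into a standard Matrix error response; the JSON a client would receive looks roughly like this (values illustrative, errcode shape per the Matrix spec):

error_body = {
    "errcode": "M_NOT_FOUND",
    "error": "Room alias '#missing:example.org' not found",
}
print(error_body)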
@@ -619,24 +619,19 @@ class LoggingContextFilter(logging.Filter):
             True to include the record in the log output.
         """
         context = current_context()
-        # type-ignore: `context` should never be `None`, but if it somehow ends up
-        # being, then we end up in a death spiral of infinite loops, so let's check, for
-        record.request = self._default_request
-
-        # Avoid overwriting an existing `server_name` on the record. This is running in
-        # the context of a global log record filter so there may be 3rd-party code that
-        # adds their own `server_name` and we don't want to interfere with that
-        # (clobber).
-        if not hasattr(record, "server_name"):
-            record.server_name = "unknown_server_from_no_logcontext"
-
-        if context is not None:
+        # context should never be None, but if it somehow ends up being, then
+        # we end up in a death spiral of infinite loops, so let's check, for
+        # robustness' sake.
+        #
+        # Add some default values to avoid log formatting errors.
+        if context is None:
+            record.request = self._default_request  # type: ignore[unreachable]
+
+            # Avoid overwriting an existing `server_name` on the record. This is running in
+            # the context of a global log record filter so there may be 3rd-party code that
+            # adds their own `server_name` and we don't want to interfere with that
+            # (clobber).
+            if not hasattr(record, "server_name"):
+                record.server_name = "unknown_server_from_no_logcontext"
+
+        # Otherwise, in the normal, expected case, fill in the log record attributes
+        # from the logcontext.
+        else:

             def safe_set(attr: str, value: Any) -> None:
                 """
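Both sides of this hunk rely on the same underlying pattern: a global `logging.Filter` that injects default attributes so format strings such as `%(request)s` never fail. A minimal self-contained sketch of the pattern (not Synapse code; names are illustrative):

import logging

class DefaultAttrFilter(logging.Filter):
    """Inject fallback attributes that the formatter expects."""

    def filter(self, record: logging.LogRecord) -> bool:
        if not hasattr(record, "request"):
            record.request = "sentinel"  # placeholder default
        if not hasattr(record, "server_name"):
            record.server_name = "unknown_server_from_no_logcontext"
        return True

handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(server_name)s - %(request)s - %(message)s"))
handler.addFilter(DefaultAttrFilter())
logging.getLogger("demo").addHandler(handler)
logging.getLogger("demo").warning("hello")  # prints the defaults for both fields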
@@ -47,11 +47,14 @@ class UpdateDelayedEventServlet(RestServlet):

     def __init__(self, hs: "HomeServer"):
         super().__init__()
+        self.auth = hs.get_auth()
         self.delayed_events_handler = hs.get_delayed_events_handler()

     async def on_POST(
         self, request: SynapseRequest, delay_id: str
     ) -> tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+
         body = parse_json_object_from_request(request)
         try:
             action = str(body["action"])
@@ -72,65 +75,11 @@ class UpdateDelayedEventServlet(RestServlet):
         )

         if enum_action == _UpdateDelayedEventAction.CANCEL:
-            await self.delayed_events_handler.cancel(request, delay_id)
+            await self.delayed_events_handler.cancel(requester, delay_id)
         elif enum_action == _UpdateDelayedEventAction.RESTART:
-            await self.delayed_events_handler.restart(request, delay_id)
+            await self.delayed_events_handler.restart(requester, delay_id)
         elif enum_action == _UpdateDelayedEventAction.SEND:
-            await self.delayed_events_handler.send(request, delay_id)
+            await self.delayed_events_handler.send(requester, delay_id)
         return 200, {}


-class CancelDelayedEventServlet(RestServlet):
-    PATTERNS = client_patterns(
-        r"/org\.matrix\.msc4140/delayed_events/(?P<delay_id>[^/]+)/cancel$",
-        releases=(),
-    )
-    CATEGORY = "Delayed event management requests"
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.delayed_events_handler = hs.get_delayed_events_handler()
-
-    async def on_POST(
-        self, request: SynapseRequest, delay_id: str
-    ) -> tuple[int, JsonDict]:
-        await self.delayed_events_handler.cancel(request, delay_id)
-        return 200, {}
-
-
-class RestartDelayedEventServlet(RestServlet):
-    PATTERNS = client_patterns(
-        r"/org\.matrix\.msc4140/delayed_events/(?P<delay_id>[^/]+)/restart$",
-        releases=(),
-    )
-    CATEGORY = "Delayed event management requests"
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.delayed_events_handler = hs.get_delayed_events_handler()
-
-    async def on_POST(
-        self, request: SynapseRequest, delay_id: str
-    ) -> tuple[int, JsonDict]:
-        await self.delayed_events_handler.restart(request, delay_id)
-        return 200, {}
-
-
-class SendDelayedEventServlet(RestServlet):
-    PATTERNS = client_patterns(
-        r"/org\.matrix\.msc4140/delayed_events/(?P<delay_id>[^/]+)/send$",
-        releases=(),
-    )
-    CATEGORY = "Delayed event management requests"
-
-    def __init__(self, hs: "HomeServer"):
-        super().__init__()
-        self.delayed_events_handler = hs.get_delayed_events_handler()
-
-    async def on_POST(
-        self, request: SynapseRequest, delay_id: str
-    ) -> tuple[int, JsonDict]:
-        await self.delayed_events_handler.send(request, delay_id)
-        return 200, {}
@@ -159,7 +108,4 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     # The following can't currently be instantiated on workers.
     if hs.config.worker.worker_app is None:
         UpdateDelayedEventServlet(hs).register(http_server)
-        CancelDelayedEventServlet(hs).register(http_server)
-        RestartDelayedEventServlet(hs).register(http_server)
-        SendDelayedEventServlet(hs).register(http_server)
     DelayedEventsServlet(hs).register(http_server)
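With the path-based servlets removed, clients select the action in the JSON body of an authenticated POST. A client-side sketch, assuming a local homeserver, a valid access token, and an existing delay ID (all placeholders):

import requests

BASE = "http://localhost:8008/_matrix/client/unstable/org.matrix.msc4140"
delay_id = "syd_abcdefghijklmnopqrst"  # placeholder

resp = requests.post(
    f"{BASE}/delayed_events/{delay_id}",
    headers={"Authorization": "Bearer <access_token>"},
    json={"action": "cancel"},  # or "restart" / "send"
)
print(resp.status_code, resp.json())  # 200 {} on success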
@@ -13,26 +13,18 @@
 #

 import logging
-from typing import TYPE_CHECKING, NewType
+from typing import NewType

 import attr

 from synapse.api.errors import NotFoundError
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import (
-    DatabasePool,
-    LoggingDatabaseConnection,
-    LoggingTransaction,
-    StoreError,
-)
+from synapse.storage.database import LoggingTransaction, StoreError
 from synapse.storage.engines import PostgresEngine
 from synapse.types import JsonDict, RoomID
 from synapse.util import stringutils
 from synapse.util.json import json_encoder

-if TYPE_CHECKING:
-    from synapse.server import HomeServer
-
 logger = logging.getLogger(__name__)
@@ -63,27 +55,6 @@ class DelayedEventDetails(EventDetails):


 class DelayedEventsStore(SQLBaseStore):
-    def __init__(
-        self,
-        database: DatabasePool,
-        db_conn: LoggingDatabaseConnection,
-        hs: "HomeServer",
-    ):
-        super().__init__(database, db_conn, hs)
-
-        # Set delayed events to be uniquely identifiable by their delay_id.
-        # In practice, delay_ids are already unique because they are generated
-        # from cryptographically strong random strings.
-        # Therefore, adding this constraint is not expected to ever fail,
-        # despite the current pkey technically allowing non-unique delay_ids.
-        self.db_pool.updates.register_background_index_update(
-            update_name="delayed_events_idx",
-            index_name="delayed_events_idx",
-            table="delayed_events",
-            columns=("delay_id",),
-            unique=True,
-        )
-
     async def get_delayed_events_stream_pos(self) -> int:
         """
         Gets the stream position of the background process to watch for state events
@@ -163,7 +134,9 @@ class DelayedEventsStore(SQLBaseStore):

     async def restart_delayed_event(
         self,
+        *,
         delay_id: str,
+        user_localpart: str,
         current_ts: Timestamp,
     ) -> Timestamp:
         """
@@ -172,6 +145,7 @@ class DelayedEventsStore(SQLBaseStore):

         Args:
             delay_id: The ID of the delayed event to restart.
+            user_localpart: The localpart of the delayed event's owner.
             current_ts: The current time, which will be used to calculate the new send time.

         Returns: The send time of the next delayed event to be sent,
@@ -189,11 +163,13 @@ class DelayedEventsStore(SQLBaseStore):
             """
             UPDATE delayed_events
             SET send_ts = ? + delay
-            WHERE delay_id = ? AND NOT is_processed
+            WHERE delay_id = ? AND user_localpart = ?
+                AND NOT is_processed
             """,
             (
                 current_ts,
                 delay_id,
+                user_localpart,
             ),
         )
         if txn.rowcount == 0:
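Note the `SET send_ts = ? + delay` above: restarting does not shift the old deadline but starts the row's stored delay over from the current time. The same arithmetic in isolation, with illustrative values:

current_ts = 1_700_000_000_000  # placeholder wall-clock time in ms
delay = 90_000                  # the row's stored delay: 90 seconds
send_ts = current_ts + delay    # new deadline, a full delay from "now"
assert send_ts == 1_700_000_090_000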
@@ -343,15 +319,21 @@ class DelayedEventsStore(SQLBaseStore):

     async def process_target_delayed_event(
         self,
+        *,
         delay_id: str,
+        user_localpart: str,
     ) -> tuple[
-        DelayedEventDetails,
+        EventDetails,
         Timestamp | None,
     ]:
         """
         Marks for processing the matching delayed event, regardless of its timeout time,
         as long as it has not already been marked as such.

         Args:
             delay_id: The ID of the delayed event to restart.
+            user_localpart: The localpart of the delayed event's owner.

         Returns: The details of the matching delayed event,
             and the send time of the next delayed event to be sent, if any.

@@ -362,38 +344,39 @@ class DelayedEventsStore(SQLBaseStore):
         def process_target_delayed_event_txn(
             txn: LoggingTransaction,
         ) -> tuple[
-            DelayedEventDetails,
+            EventDetails,
             Timestamp | None,
         ]:
             txn.execute(
                 """
                 UPDATE delayed_events
                 SET is_processed = TRUE
-                WHERE delay_id = ? AND NOT is_processed
+                WHERE delay_id = ? AND user_localpart = ?
+                    AND NOT is_processed
                 RETURNING
                     room_id,
                     event_type,
                     state_key,
                     origin_server_ts,
                     content,
-                    device_id,
-                    user_localpart
+                    device_id
                 """,
-                (delay_id,),
+                (
+                    delay_id,
+                    user_localpart,
+                ),
             )
             row = txn.fetchone()
             if row is None:
                 raise NotFoundError("Delayed event not found")

-            event = DelayedEventDetails(
+            event = EventDetails(
                 RoomID.from_string(row[0]),
                 EventType(row[1]),
                 StateKey(row[2]) if row[2] is not None else None,
                 Timestamp(row[3]) if row[3] is not None else None,
                 db_to_json(row[4]),
                 DeviceID(row[5]) if row[5] is not None else None,
-                DelayID(delay_id),
-                UserLocalpart(row[6]),
             )

             return event, self._get_next_delayed_event_send_ts_txn(txn)
@@ -402,10 +385,19 @@ class DelayedEventsStore(SQLBaseStore):
             "process_target_delayed_event", process_target_delayed_event_txn
         )

-    async def cancel_delayed_event(self, delay_id: str) -> Timestamp | None:
+    async def cancel_delayed_event(
+        self,
+        *,
+        delay_id: str,
+        user_localpart: str,
+    ) -> Timestamp | None:
         """
         Cancels the matching delayed event, i.e. remove it as long as it hasn't been processed.

         Args:
             delay_id: The ID of the delayed event to restart.
+            user_localpart: The localpart of the delayed event's owner.

         Returns: The send time of the next delayed event to be sent, if any.

         Raises:
@@ -421,6 +413,7 @@ class DelayedEventsStore(SQLBaseStore):
             table="delayed_events",
             keyvalues={
                 "delay_id": delay_id,
+                "user_localpart": user_localpart,
                 "is_processed": False,
             },
         )
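`simple_delete` turns the `keyvalues` mapping into a parameterised DELETE, so adding `user_localpart` tightens the WHERE clause to the owner's rows. A sketch of that mapping (not Synapse's actual SQL builder; values are placeholders):

def build_simple_delete(table: str, keyvalues: dict) -> tuple[str, tuple]:
    clause = " AND ".join(f"{col} = ?" for col in keyvalues)
    return f"DELETE FROM {table} WHERE {clause}", tuple(keyvalues.values())

sql, args = build_simple_delete(
    "delayed_events",
    {"delay_id": "syd_abc", "user_localpart": "alice", "is_processed": False},
)
print(sql)   # DELETE FROM delayed_events WHERE delay_id = ? AND user_localpart = ? AND is_processed = ?
print(args)  # ('syd_abc', 'alice', False)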
@@ -480,7 +473,11 @@ class DelayedEventsStore(SQLBaseStore):
             "cancel_delayed_state_events", cancel_delayed_state_events_txn
         )

-    async def delete_processed_delayed_event(self, delay_id: DelayID) -> None:
+    async def delete_processed_delayed_event(
+        self,
+        delay_id: DelayID,
+        user_localpart: UserLocalpart,
+    ) -> None:
         """
         Delete the matching delayed event, as long as it has been marked as processed.

@@ -491,6 +488,7 @@ class DelayedEventsStore(SQLBaseStore):
             table="delayed_events",
             keyvalues={
                 "delay_id": delay_id,
+                "user_localpart": user_localpart,
                 "is_processed": True,
             },
             desc="delete_processed_delayed_event",
@@ -556,7 +554,7 @@ def _generate_delay_id() -> DelayID:

     # We use the following format for delay IDs:
     # syd_<random string>
-    # They are not scoped to user localparts, but the random string
-    # is expected to be sufficiently random to be globally unique.
+    # They are scoped to user localparts, so it is possible for
+    # the same ID to exist for multiple users.

     return DelayID(f"syd_{stringutils.random_string(20)}")
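A standalone sketch of the same ID shape using only the standard library (the exact alphabet of Synapse's `random_string` is an assumption here):

import secrets
import string

def generate_delay_id() -> str:
    # syd_<20 random characters>; per the comment above, uniqueness is only
    # required per user localpart, not globally.
    alphabet = string.ascii_letters + string.digits  # assumed alphabet
    return "syd_" + "".join(secrets.choice(alphabet) for _ in range(20))

print(generate_delay_id())  # e.g. syd_kF2pQ9xLr7TbVw0sYeMn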
@@ -1600,21 +1600,18 @@ class EventsWorkerStore(SQLBaseStore):
             if d:
                 d.redactions.append(redacter)

-        # check for MSC4293 redactions
+        # check for MSC4932 redactions
         to_check = []
         events: list[_EventRow] = []
         for e in evs:
-            try:
-                event = event_dict.get(e)
-                if not event:
-                    continue
-                events.append(event)
-                event_json = json.loads(event.json)
-                room_id = event_json.get("room_id")
-                user_id = event_json.get("sender")
-                to_check.append((room_id, user_id))
-            except Exception as exc:
-                raise InvalidEventError(f"Invalid event {event_id}") from exc
+            event = event_dict.get(e)
+            if not event:
+                continue
+            events.append(event)
+            event_json = json.loads(event.json)
+            room_id = event_json.get("room_id")
+            user_id = event_json.get("sender")
+            to_check.append((room_id, user_id))

         # likely that some of these events may be for the same room/user combo, in
         # which case we don't need to do redundant queries
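The closing comment alludes to de-duplicating `(room_id, user_id)` pairs before querying; the idea in isolation, with placeholder IDs:

to_check = [
    ("!room_a:example.org", "@alice:example.org"),
    ("!room_a:example.org", "@alice:example.org"),  # repeated combo
    ("!room_b:example.org", "@bob:example.org"),
]
unique_combos = set(to_check)  # each (room, user) pair is queried only once
print(len(to_check), len(unique_combos))  # 3 2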
@@ -239,16 +239,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):

         txn.execute("SELECT event_id, should_delete FROM events_to_purge")
         event_rows = txn.fetchall()

-        if len(event_rows) == 0:
-            logger.info("[purge] no events found to purge")
-
-            # For the sake of cleanliness: drop the temp table.
-            # This will commit the txn in sqlite, so make sure to keep this actually last.
-            txn.execute("DROP TABLE events_to_purge")
-            # no referenced state groups
-            return set()
-
         logger.info(
             "[purge] found %i events before cutoff, of which %i can be deleted",
             len(event_rows),
@@ -99,8 +99,8 @@ class PostgresEngine(
         allow_unsafe_locale = self.config.get("allow_unsafe_locale", False)

         # Are we on a supported PostgreSQL version?
-        if not allow_outdated_version and self._version < 130000:
-            raise RuntimeError("Synapse requires PostgreSQL 13 or above.")
+        if not allow_outdated_version and self._version < 140000:
+            raise RuntimeError("Synapse requires PostgreSQL 14 or above.")

         with db_conn.cursor() as txn:
             txn.execute("SHOW SERVER_ENCODING")
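`self._version` holds PostgreSQL's integer `server_version_num` (major * 10000 + minor), which is why the floor moves from 130000 to 140000. The gate in isolation:

def is_supported(server_version_num: int, minimum: int = 140000) -> bool:
    # e.g. PostgreSQL 13.16 reports 130016, 14.11 reports 140011
    return server_version_num >= minimum

assert not is_supported(130016)  # any 13.x is now rejected
assert is_supported(140011)      # 14.x and later pass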
@@ -19,7 +19,7 @@
 #
 #

-SCHEMA_VERSION = 93  # remember to update the list below when updating
+SCHEMA_VERSION = 92  # remember to update the list below when updating
 """Represents the expectations made by the codebase about the database schema

 This should be incremented whenever the codebase changes its requirements on the
@@ -168,15 +168,11 @@ Changes in SCHEMA_VERSION = 91

 Changes in SCHEMA_VERSION = 92
     - Cleaned up a trigger that was added in #18260 and then reverted.
-
-Changes in SCHEMA_VERSION = 93
-    - MSC4140: Set delayed events to be uniquely identifiable by their delay ID.
 """


 SCHEMA_COMPAT_VERSION = (
     # Transitive links are no longer written to `event_auth_chain_links`
-    # TODO: On the next compat bump, update the primary key of `delayed_events`
     84
 )
 """Limit on how far the synapse codebase can be rolled back without breaking db compat
@@ -1,15 +0,0 @@
---
--- This file is licensed under the Affero General Public License (AGPL) version 3.
---
--- Copyright (C) 2025 Element Creations, Ltd
---
--- This program is free software: you can redistribute it and/or modify
--- it under the terms of the GNU Affero General Public License as
--- published by the Free Software Foundation, either version 3 of the
--- License, or (at your option) any later version.
---
--- See the GNU Affero General Public License for more details:
--- <https://www.gnu.org/licenses/agpl-3.0.html>.
-
-INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
-  (9301, 'delayed_events_idx', '{}');
@@ -14,7 +14,6 @@
 #

-
 import logging
 from typing import (
     Any,
     Callable,
@@ -31,14 +30,10 @@ from twisted.internet.task import LoopingCall
 from synapse.logging import context
 from synapse.types import ISynapseThreadlessReactor
 from synapse.util import log_failure
-from synapse.util.stringutils import random_string_insecure_fast

 P = ParamSpec("P")

-
 logger = logging.getLogger(__name__)

-
 class Clock:
     """
     A Clock wraps a Twisted reactor and provides utilities on top of it.
@@ -69,12 +64,7 @@ class Clock:
         """List of active looping calls"""

         self._call_id_to_delayed_call: dict[int, IDelayedCall] = {}
-        """
-        Mapping from unique call ID to delayed call.
-
-        For "performance", this only tracks a subset of delayed calls: those created
-        with `call_later` with `call_later_cancel_on_shutdown=True`.
-        """
+        """Mapping from unique call ID to delayed call"""

         self._is_shutdown = False
         """Whether shutdown has been requested by the HomeServer"""
@@ -163,20 +153,11 @@ class Clock:
         **kwargs: P.kwargs,
     ) -> LoopingCall:
         """Common functionality for `looping_call` and `looping_call_now`"""
-        instance_id = random_string_insecure_fast(5)
-
         if self._is_shutdown:
             raise Exception("Cannot start looping call. Clock has been shutdown")

-        looping_call_context_string = "looping_call"
-        if now:
-            looping_call_context_string = "looping_call_now"
-
         def wrapped_f(*args: P.args, **kwargs: P.kwargs) -> Deferred:
-            logger.debug(
-                "%s(%s): Executing callback", looping_call_context_string, instance_id
-            )
-
             assert context.current_context() is context.SENTINEL_CONTEXT, (
                 "Expected `looping_call` callback from the reactor to start with the sentinel logcontext "
                 f"but saw {context.current_context()}. In other words, another task shouldn't have "
@@ -220,17 +201,6 @@ class Clock:
         d = call.start(msec / 1000.0, now=now)
         d.addErrback(log_failure, "Looping call died", consumeErrors=False)
         self._looping_calls.append(call)

-        logger.debug(
-            "%s(%s): Scheduled looping call every %sms later",
-            looping_call_context_string,
-            instance_id,
-            msec,
-            # Find out who is scheduling the call which makes it easy to follow in the
-            # logs.
-            stack_info=True,
-        )
-
         return call

     def cancel_all_looping_calls(self, consumeErrors: bool = True) -> None:
@@ -256,7 +226,7 @@ class Clock:
         *args: Any,
         call_later_cancel_on_shutdown: bool = True,
         **kwargs: Any,
-    ) -> "DelayedCallWrapper":
+    ) -> IDelayedCall:
         """Call something later

         Note that the function will be called with generic `call_later` logcontext, so
@@ -275,79 +245,74 @@ class Clock:
         issue, we can just track all delayed calls.
             **kwargs: Key arguments to pass to function.
         """
-        call_id = self._delayed_call_id
-        self._delayed_call_id = self._delayed_call_id + 1
-
         if self._is_shutdown:
             raise Exception("Cannot start delayed call. Clock has been shutdown")

-        def wrapped_callback(*args: Any, **kwargs: Any) -> None:
-            logger.debug("call_later(%s): Executing callback", call_id)
-
-            assert context.current_context() is context.SENTINEL_CONTEXT, (
-                "Expected `call_later` callback from the reactor to start with the sentinel logcontext "
-                f"but saw {context.current_context()}. In other words, another task shouldn't have "
-                "leaked their logcontext to us."
-            )
-
-            # Because this is a callback from the reactor, we will be using the
-            # `sentinel` log context at this point. We want the function to log with
-            # some logcontext as we want to know which server the logs came from.
-            #
-            # We use `PreserveLoggingContext` to prevent our new `call_later`
-            # logcontext from finishing as soon as we exit this function, in case `f`
-            # returns an awaitable/deferred which would continue running and may try to
-            # restore the `call_later` context when it's done (because it's trying to
-            # adhere to the Synapse logcontext rules.)
-            #
-            # This also ensures that we return to the `sentinel` context when we exit
-            # this function and yield control back to the reactor to avoid leaking the
-            # current logcontext to the reactor (which would then get picked up and
-            # associated with the next thing the reactor does)
-            try:
-                with context.PreserveLoggingContext(
-                    context.LoggingContext(
-                        name="call_later", server_name=self._server_name
-                    )
-                ):
-                    # We use `run_in_background` to reset the logcontext after `f` (or the
-                    # awaitable returned by `f`) completes to avoid leaking the current
-                    # logcontext to the reactor
-                    context.run_in_background(callback, *args, **kwargs)
-            finally:
-                if call_later_cancel_on_shutdown:
-                    # We still want to remove the call from the tracking map. Even if
-                    # the callback raises an exception.
-                    self._call_id_to_delayed_call.pop(call_id)
-
-        # We can ignore the lint here since this class is the one location callLater should
-        # be called.
-        call = self._reactor.callLater(delay, wrapped_callback, *args, **kwargs)  # type: ignore[call-later-not-tracked]
-
-        logger.debug(
-            "call_later(%s): Scheduled call for %ss later (tracked for shutdown: %s)",
-            call_id,
-            delay,
-            call_later_cancel_on_shutdown,
-            # Find out who is scheduling the call which makes it easy to follow in the
-            # logs.
-            stack_info=True,
-        )
-
-        wrapped_call = DelayedCallWrapper(call, call_id, self)
-        if call_later_cancel_on_shutdown:
-            self._call_id_to_delayed_call[call_id] = wrapped_call
-
-        return wrapped_call
+        def create_wrapped_callback(
+            track_for_shutdown_cancellation: bool,
+        ) -> Callable[P, None]:
+            def wrapped_callback(*args: Any, **kwargs: Any) -> None:
+                assert context.current_context() is context.SENTINEL_CONTEXT, (
+                    "Expected `call_later` callback from the reactor to start with the sentinel logcontext "
+                    f"but saw {context.current_context()}. In other words, another task shouldn't have "
+                    "leaked their logcontext to us."
+                )
+
+                # Because this is a callback from the reactor, we will be using the
+                # `sentinel` log context at this point. We want the function to log with
+                # some logcontext as we want to know which server the logs came from.
+                #
+                # We use `PreserveLoggingContext` to prevent our new `call_later`
+                # logcontext from finishing as soon as we exit this function, in case `f`
+                # returns an awaitable/deferred which would continue running and may try to
+                # restore the `call_later` context when it's done (because it's trying to
+                # adhere to the Synapse logcontext rules.)
+                #
+                # This also ensures that we return to the `sentinel` context when we exit
+                # this function and yield control back to the reactor to avoid leaking the
+                # current logcontext to the reactor (which would then get picked up and
+                # associated with the next thing the reactor does)
+                try:
+                    with context.PreserveLoggingContext(
+                        context.LoggingContext(
+                            name="call_later", server_name=self._server_name
+                        )
+                    ):
+                        # We use `run_in_background` to reset the logcontext after `f` (or the
+                        # awaitable returned by `f`) completes to avoid leaking the current
+                        # logcontext to the reactor
+                        context.run_in_background(callback, *args, **kwargs)
+                finally:
+                    if track_for_shutdown_cancellation:
+                        # We still want to remove the call from the tracking map. Even if
+                        # the callback raises an exception.
+                        self._call_id_to_delayed_call.pop(call_id)
+
+            return wrapped_callback
+
+        if call_later_cancel_on_shutdown:
+            call_id = self._delayed_call_id
+            self._delayed_call_id = self._delayed_call_id + 1
+
+            # We can ignore the lint here since this class is the one location callLater
+            # should be called.
+            call = self._reactor.callLater(
+                delay, create_wrapped_callback(True), *args, **kwargs
+            )  # type: ignore[call-later-not-tracked]
+            call = DelayedCallWrapper(call, call_id, self)
+            self._call_id_to_delayed_call[call_id] = call
+            return call
+        else:
+            # We can ignore the lint here since this class is the one location callLater should
+            # be called.
+            return self._reactor.callLater(
+                delay, create_wrapped_callback(False), *args, **kwargs
+            )  # type: ignore[call-later-not-tracked]

-    def cancel_call_later(
-        self, wrapped_call: "DelayedCallWrapper", ignore_errs: bool = False
-    ) -> None:
+    def cancel_call_later(self, timer: IDelayedCall, ignore_errs: bool = False) -> None:
         try:
-            logger.debug(
-                "cancel_call_later: cancelling scheduled call %s", wrapped_call.call_id
-            )
-            wrapped_call.delayed_call.cancel()
+            timer.cancel()
         except Exception:
             if not ignore_errs:
                 raise
@@ -362,11 +327,8 @@ class Clock:
         """
         # We make a copy here since calling `cancel()` on a delayed_call
         # will result in the call removing itself from the map mid-iteration.
-        for call_id, call in list(self._call_id_to_delayed_call.items()):
+        for call in list(self._call_id_to_delayed_call.values()):
             try:
-                logger.debug(
-                    "cancel_all_delayed_calls: cancelling scheduled call %s", call_id
-                )
                 call.cancel()
             except Exception:
                 if not ignore_errs:
@@ -390,11 +352,8 @@ class Clock:
         *args: Postional arguments to pass to function.
         **kwargs: Key arguments to pass to function.
         """
-        instance_id = random_string_insecure_fast(5)
-
         def wrapped_callback(*args: Any, **kwargs: Any) -> None:
-            logger.debug("call_when_running(%s): Executing callback", instance_id)
-
             # Since this callback can be invoked immediately if the reactor is already
             # running, we can't always assume that we're running in the sentinel
             # logcontext (i.e. we can't assert that we're in the sentinel context like
@@ -433,14 +392,6 @@ class Clock:
         # callWhenRunning should be called.
         self._reactor.callWhenRunning(wrapped_callback, *args, **kwargs)  # type: ignore[prefer-synapse-clock-call-when-running]

-        logger.debug(
-            "call_when_running(%s): Scheduled call",
-            instance_id,
-            # Find out who is scheduling the call which makes it easy to follow in the
-            # logs.
-            stack_info=True,
-        )
-
     def add_system_event_trigger(
         self,
         phase: str,
@@ -466,16 +417,8 @@ class Clock:
         Returns:
             an ID that can be used to remove this call with `reactor.removeSystemEventTrigger`.
         """
-        instance_id = random_string_insecure_fast(5)
-
         def wrapped_callback(*args: Any, **kwargs: Any) -> None:
-            logger.debug(
-                "add_system_event_trigger(%s): Executing %s %s callback",
-                instance_id,
-                phase,
-                event_type,
-            )
-
             assert context.current_context() is context.SENTINEL_CONTEXT, (
                 "Expected `add_system_event_trigger` callback from the reactor to start with the sentinel logcontext "
                 f"but saw {context.current_context()}. In other words, another task shouldn't have "
@@ -506,16 +449,6 @@ class Clock:
             # logcontext to the reactor
             context.run_in_background(callback, *args, **kwargs)

-        logger.debug(
-            "add_system_event_trigger(%s) for %s %s",
-            instance_id,
-            phase,
-            event_type,
-            # Find out who is scheduling the call which makes it easy to follow in the
-            # logs.
-            stack_info=True,
-        )
-
         # We can ignore the lint here since this class is the one location
         # `addSystemEventTrigger` should be called.
         return self._reactor.addSystemEventTrigger(
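For orientation, a usage sketch of the right-hand (restored) API, where `call_later` hands back an `IDelayedCall`; `clock` is assumed to be an already-constructed `synapse.util.clock.Clock` with a running reactor:

def on_timeout() -> None:
    print("fired")

# Tracked by default (call_later_cancel_on_shutdown=True), so it is also
# cancelled automatically at shutdown.
timer = clock.call_later(30.0, on_timeout)

# Cancel manually; ignore_errs swallows "already called/cancelled" errors.
clock.cancel_call_later(timer, ignore_errs=True)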
@@ -28,7 +28,6 @@ from synapse.types import JsonDict
 from synapse.util.clock import Clock

 from tests import unittest
-from tests.server import FakeChannel
 from tests.unittest import HomeserverTestCase

 PATH_PREFIX = "/_matrix/client/unstable/org.matrix.msc4140/delayed_events"
@@ -128,10 +127,6 @@ class DelayedEventsTestCase(HomeserverTestCase):
         )
         self.assertEqual(setter_expected, content.get(setter_key), content)

-    def test_get_delayed_events_auth(self) -> None:
-        channel = self.make_request("GET", PATH_PREFIX)
-        self.assertEqual(HTTPStatus.UNAUTHORIZED, channel.code, channel.result)
-
     @unittest.override_config(
         {"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}}
     )
@@ -159,6 +154,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
         channel = self.make_request(
             "POST",
             f"{PATH_PREFIX}/",
+            access_token=self.user1_access_token,
         )
         self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, channel.result)
@@ -166,6 +162,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
         channel = self.make_request(
             "POST",
             f"{PATH_PREFIX}/abc",
+            access_token=self.user1_access_token,
         )
         self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
         self.assertEqual(
@@ -178,6 +175,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
             "POST",
             f"{PATH_PREFIX}/abc",
             {},
+            self.user1_access_token,
         )
         self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
         self.assertEqual(
@@ -190,6 +188,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
             "POST",
             f"{PATH_PREFIX}/abc",
             {"action": "oops"},
+            self.user1_access_token,
         )
         self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
         self.assertEqual(
@@ -197,21 +196,17 @@ class DelayedEventsTestCase(HomeserverTestCase):
             channel.json_body["errcode"],
         )

-    @parameterized.expand(
-        (
-            (action, action_in_path)
-            for action in ("cancel", "restart", "send")
-            for action_in_path in (True, False)
-        )
-    )
-    def test_update_delayed_event_without_match(
-        self, action: str, action_in_path: bool
-    ) -> None:
-        channel = self._update_delayed_event("abc", action, action_in_path)
+    @parameterized.expand(["cancel", "restart", "send"])
+    def test_update_delayed_event_without_match(self, action: str) -> None:
+        channel = self.make_request(
+            "POST",
+            f"{PATH_PREFIX}/abc",
+            {"action": action},
+            self.user1_access_token,
+        )
         self.assertEqual(HTTPStatus.NOT_FOUND, channel.code, channel.result)

-    @parameterized.expand((True, False))
-    def test_cancel_delayed_state_event(self, action_in_path: bool) -> None:
+    def test_cancel_delayed_state_event(self) -> None:
         state_key = "to_never_send"

         setter_key = "setter"
@@ -226,7 +221,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
         )
         self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
         delay_id = channel.json_body.get("delay_id")
-        assert delay_id is not None
+        self.assertIsNotNone(delay_id)

         self.reactor.advance(1)
         events = self._get_delayed_events()
@@ -241,7 +236,12 @@ class DelayedEventsTestCase(HomeserverTestCase):
             expect_code=HTTPStatus.NOT_FOUND,
         )

-        channel = self._update_delayed_event(delay_id, "cancel", action_in_path)
+        channel = self.make_request(
+            "POST",
+            f"{PATH_PREFIX}/{delay_id}",
+            {"action": "cancel"},
+            self.user1_access_token,
+        )
         self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
         self.assertListEqual([], self._get_delayed_events())
@@ -254,11 +254,10 @@ class DelayedEventsTestCase(HomeserverTestCase):
             expect_code=HTTPStatus.NOT_FOUND,
         )

-    @parameterized.expand((True, False))
     @unittest.override_config(
         {"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}}
     )
-    def test_cancel_delayed_event_ratelimit(self, action_in_path: bool) -> None:
+    def test_cancel_delayed_event_ratelimit(self) -> None:
         delay_ids = []
         for _ in range(2):
             channel = self.make_request(
@@ -269,17 +268,38 @@ class DelayedEventsTestCase(HomeserverTestCase):
             )
             self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
             delay_id = channel.json_body.get("delay_id")
-            assert delay_id is not None
+            self.assertIsNotNone(delay_id)
             delay_ids.append(delay_id)

-        channel = self._update_delayed_event(delay_ids.pop(0), "cancel", action_in_path)
+        channel = self.make_request(
+            "POST",
+            f"{PATH_PREFIX}/{delay_ids.pop(0)}",
+            {"action": "cancel"},
+            self.user1_access_token,
+        )
         self.assertEqual(HTTPStatus.OK, channel.code, channel.result)

-        channel = self._update_delayed_event(delay_ids.pop(0), "cancel", action_in_path)
+        args = (
+            "POST",
+            f"{PATH_PREFIX}/{delay_ids.pop(0)}",
+            {"action": "cancel"},
+            self.user1_access_token,
+        )
+        channel = self.make_request(*args)
         self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result)

-    @parameterized.expand((True, False))
-    def test_send_delayed_state_event(self, action_in_path: bool) -> None:
+        # Add the current user to the ratelimit overrides, allowing them no ratelimiting.
+        self.get_success(
+            self.hs.get_datastores().main.set_ratelimit_for_user(
+                self.user1_user_id, 0, 0
+            )
+        )
+
+        # Test that the request isn't ratelimited anymore.
+        channel = self.make_request(*args)
+        self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+
+    def test_send_delayed_state_event(self) -> None:
         state_key = "to_send_on_request"

         setter_key = "setter"
@@ -294,7 +314,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
         )
         self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
         delay_id = channel.json_body.get("delay_id")
-        assert delay_id is not None
+        self.assertIsNotNone(delay_id)

         self.reactor.advance(1)
         events = self._get_delayed_events()
@@ -309,7 +329,12 @@ class DelayedEventsTestCase(HomeserverTestCase):
             expect_code=HTTPStatus.NOT_FOUND,
         )

-        channel = self._update_delayed_event(delay_id, "send", action_in_path)
+        channel = self.make_request(
+            "POST",
+            f"{PATH_PREFIX}/{delay_id}",
+            {"action": "send"},
+            self.user1_access_token,
+        )
         self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
         self.assertListEqual([], self._get_delayed_events())
         content = self.helper.get_state(
@@ -320,9 +345,8 @@ class DelayedEventsTestCase(HomeserverTestCase):
         )
         self.assertEqual(setter_expected, content.get(setter_key), content)

-    @parameterized.expand((True, False))
-    @unittest.override_config({"rc_message": {"per_second": 2.5, "burst_count": 3}})
-    def test_send_delayed_event_ratelimit(self, action_in_path: bool) -> None:
+    @unittest.override_config({"rc_message": {"per_second": 3.5, "burst_count": 4}})
+    def test_send_delayed_event_ratelimit(self) -> None:
         delay_ids = []
         for _ in range(2):
             channel = self.make_request(
@@ -333,17 +357,38 @@ class DelayedEventsTestCase(HomeserverTestCase):
             )
             self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
             delay_id = channel.json_body.get("delay_id")
-            assert delay_id is not None
+            self.assertIsNotNone(delay_id)
             delay_ids.append(delay_id)

-        channel = self._update_delayed_event(delay_ids.pop(0), "send", action_in_path)
+        channel = self.make_request(
+            "POST",
+            f"{PATH_PREFIX}/{delay_ids.pop(0)}",
+            {"action": "send"},
+            self.user1_access_token,
+        )
         self.assertEqual(HTTPStatus.OK, channel.code, channel.result)

-        channel = self._update_delayed_event(delay_ids.pop(0), "send", action_in_path)
+        args = (
+            "POST",
+            f"{PATH_PREFIX}/{delay_ids.pop(0)}",
+            {"action": "send"},
+            self.user1_access_token,
+        )
+        channel = self.make_request(*args)
         self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result)

-    @parameterized.expand((True, False))
-    def test_restart_delayed_state_event(self, action_in_path: bool) -> None:
+        # Add the current user to the ratelimit overrides, allowing them no ratelimiting.
+        self.get_success(
+            self.hs.get_datastores().main.set_ratelimit_for_user(
+                self.user1_user_id, 0, 0
+            )
+        )
+
+        # Test that the request isn't ratelimited anymore.
+        channel = self.make_request(*args)
+        self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+
+    def test_restart_delayed_state_event(self) -> None:
         state_key = "to_send_on_restarted_timeout"

         setter_key = "setter"
@@ -358,7 +403,7 @@ class DelayedEventsTestCase(HomeserverTestCase):
         )
         self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
         delay_id = channel.json_body.get("delay_id")
-        assert delay_id is not None
+        self.assertIsNotNone(delay_id)

         self.reactor.advance(1)
         events = self._get_delayed_events()
@@ -373,7 +418,12 @@ class DelayedEventsTestCase(HomeserverTestCase):
             expect_code=HTTPStatus.NOT_FOUND,
         )

-        channel = self._update_delayed_event(delay_id, "restart", action_in_path)
+        channel = self.make_request(
+            "POST",
+            f"{PATH_PREFIX}/{delay_id}",
+            {"action": "restart"},
+            self.user1_access_token,
+        )
         self.assertEqual(HTTPStatus.OK, channel.code, channel.result)

         self.reactor.advance(1)
@@ -399,11 +449,10 @@ class DelayedEventsTestCase(HomeserverTestCase):
         )
         self.assertEqual(setter_expected, content.get(setter_key), content)

-    @parameterized.expand((True, False))
     @unittest.override_config(
         {"rc_delayed_event_mgmt": {"per_second": 0.5, "burst_count": 1}}
     )
-    def test_restart_delayed_event_ratelimit(self, action_in_path: bool) -> None:
+    def test_restart_delayed_event_ratelimit(self) -> None:
         delay_ids = []
         for _ in range(2):
             channel = self.make_request(
@@ -414,19 +463,37 @@ class DelayedEventsTestCase(HomeserverTestCase):
             )
             self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
             delay_id = channel.json_body.get("delay_id")
-            assert delay_id is not None
+            self.assertIsNotNone(delay_id)
             delay_ids.append(delay_id)

-        channel = self._update_delayed_event(
-            delay_ids.pop(0), "restart", action_in_path
-        )
+        channel = self.make_request(
+            "POST",
+            f"{PATH_PREFIX}/{delay_ids.pop(0)}",
+            {"action": "restart"},
+            self.user1_access_token,
+        )
         self.assertEqual(HTTPStatus.OK, channel.code, channel.result)

-        channel = self._update_delayed_event(
-            delay_ids.pop(0), "restart", action_in_path
-        )
+        args = (
+            "POST",
+            f"{PATH_PREFIX}/{delay_ids.pop(0)}",
+            {"action": "restart"},
+            self.user1_access_token,
+        )
+        channel = self.make_request(*args)
         self.assertEqual(HTTPStatus.TOO_MANY_REQUESTS, channel.code, channel.result)

+        # Add the current user to the ratelimit overrides, allowing them no ratelimiting.
+        self.get_success(
+            self.hs.get_datastores().main.set_ratelimit_for_user(
+                self.user1_user_id, 0, 0
+            )
+        )
+
+        # Test that the request isn't ratelimited anymore.
+        channel = self.make_request(*args)
+        self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
+
     def test_delayed_state_is_not_cancelled_by_new_state_from_same_user(
         self,
     ) -> None:
@@ -531,17 +598,6 @@ class DelayedEventsTestCase(HomeserverTestCase):

         return content

-    def _update_delayed_event(
-        self, delay_id: str, action: str, action_in_path: bool
-    ) -> FakeChannel:
-        path = f"{PATH_PREFIX}/{delay_id}"
-        body = {}
-        if action_in_path:
-            path += f"/{action}"
-        else:
-            body["action"] = action
-        return self.make_request("POST", path, body)
-
-
 def _get_path_for_delayed_state(
     room_id: str, event_type: str, state_key: str, delay_ms: int