Compare commits

...

19 Commits

Author SHA1 Message Date
Eric Eastwood
41874bb8c8 Add changelog 2025-09-19 16:27:21 -05:00
Eric Eastwood
4fe5413b41 Fix add_system_event_trigger lints 2025-09-19 16:24:22 -05:00
Eric Eastwood
a906d1aa0b Add lint to prefer Clock.add_system_event_trigger 2025-09-19 16:22:25 -05:00
Eric Eastwood
2432390bf2 Add Clock.add_system_event_trigger(...) 2025-09-19 15:51:14 -05:00
Eric Eastwood
4a9730ee76 Fix wrong config typo 2025-09-19 15:36:42 -05:00
Eric Eastwood
19a4f8c741 Fix lints 2025-09-19 15:20:20 -05:00
Eric Eastwood
62f93ded1f Better changelog 2025-09-19 15:13:46 -05:00
Eric Eastwood
7b83c9fcbc Add changelog 2025-09-19 15:11:58 -05:00
Eric Eastwood
6938134f7d We can't assume sentinel context when callback is called 2025-09-19 15:05:44 -05:00
Eric Eastwood
0f8076ab41 Fix copy-paste typos 2025-09-19 15:00:30 -05:00
Eric Eastwood
731f36e131 Remove manual LoggingContext 2025-09-19 14:59:12 -05:00
Eric Eastwood
12e7cf4487 Fix reactor.callWhenRunning lints 2025-09-19 14:57:48 -05:00
Eric Eastwood
f1e6887d14 Add lint to prefer Clock.call_when_running 2025-09-19 14:28:30 -05:00
Eric Eastwood
2d1331f0e4 Add Clock.call_when_running(...) 2025-09-19 14:03:50 -05:00
Eric Eastwood
bccd224489 Fix imports 2025-09-18 21:34:52 -05:00
Eric Eastwood
e45dd4f03e Split out JSON and Clock utilities to avoid circular imports 2025-09-18 21:18:49 -05:00
Max Kratz
4367fb2d07 OIDC doc: adds missing jwt_config values to authentik example (#18931)
Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>
2025-09-18 15:05:41 +01:00
Andrew Morgan
b596faa4ec Cache _get_e2e_cross_signing_signatures_for_devices (#18899) 2025-09-18 12:06:08 +01:00
Eric Eastwood
6f9fab1089 Fix open redirect in legacy SSO flow (idp) (#18909)
- Validate the `idp` parameter to only accept the ones that are known in
the config file
- URL-encode the `idp` parameter for safety's sake (this is the main
fix)

Fix https://github.com/matrix-org/internal-config/issues/1651 (internal
link)

Regressed in https://github.com/element-hq/synapse/pull/17972
2025-09-17 13:54:47 -05:00
304 changed files with 1077 additions and 619 deletions

View File

@@ -0,0 +1 @@
Add an in-memory cache to `_get_e2e_cross_signing_signatures_for_devices` to reduce DB load.

1
changelog.d/18909.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix open redirect in legacy SSO flow with the `idp` query parameter.

2
changelog.d/18931.doc Normal file
View File

@@ -0,0 +1,2 @@
Clarify necessary `jwt_config` parameter in OIDC documentation for authentik.
Contributed by @maxkratz.

1
changelog.d/18944.misc Normal file
View File

@@ -0,0 +1 @@
Introduce `Clock.call_when_running(...)` to wrap startup code in a logcontext, ensuring we can identify which server generated the logs.

1
changelog.d/18945.misc Normal file
View File

@@ -0,0 +1 @@
Introduce `Clock.add_system_event_trigger(...)` to wrap system event callback code in a logcontext, ensuring we can identify which server generated the logs.

View File

@@ -186,6 +186,7 @@ oidc_providers:
4. Note the slug of your application, Client ID and Client Secret.
Note: RSA keys must be used for signing for Authentik, ECC keys do not work.
Note: The provider must have a signing key set and must not use an encryption key.
Synapse config:
```yaml
@@ -204,6 +205,12 @@ oidc_providers:
config:
localpart_template: "{{ user.preferred_username }}"
display_name_template: "{{ user.preferred_username|capitalize }}" # TO BE FILLED: If your users have names in Authentik and you want those in Synapse, this should be replaced with user.name|capitalize.
[...]
jwt_config:
enabled: true
secret: "your client secret" # TO BE FILLED (same as `client_secret` above)
algorithm: "RS256"
# (...other fields)
```
### Dex

View File

@@ -68,6 +68,18 @@ PROMETHEUS_METRIC_MISSING_FROM_LIST_TO_CHECK = ErrorCode(
category="per-homeserver-tenant-metrics",
)
PREFER_SYNAPSE_CLOCK_CALL_WHEN_RUNNING = ErrorCode(
"prefer-synapse-clock-call-when-running",
"`synapse.util.Clock.call_when_running` should be used instead of `reactor.callWhenRunning`",
category="synapse-reactor-clock",
)
PREFER_SYNAPSE_CLOCK_ADD_SYSTEM_EVENT_TRIGGER = ErrorCode(
"prefer-synapse-clock-add-system-event-trigger",
"`synapse.util.Clock.add_system_event_trigger` should be used instead of `reactor.addSystemEventTrigger`",
category="synapse-reactor-clock",
)
class Sentinel(enum.Enum):
# defining a sentinel in this way allows mypy to correctly handle the
@@ -229,9 +241,77 @@ class SynapsePlugin(Plugin):
):
return check_is_cacheable_wrapper
if fullname in (
"twisted.internet.interfaces.IReactorCore.callWhenRunning",
"synapse.types.ISynapseThreadlessReactor.callWhenRunning",
"synapse.types.ISynapseReactor.callWhenRunning",
):
return check_call_when_running
if fullname in (
"twisted.internet.interfaces.IReactorCore.addSystemEventTrigger",
"synapse.types.ISynapseThreadlessReactor.addSystemEventTrigger",
"synapse.types.ISynapseReactor.addSystemEventTrigger",
):
return check_add_system_event_trigger
return None
def check_call_when_running(ctx: MethodSigContext) -> CallableType:
"""
Ensure that the `reactor.callWhenRunning` callsites aren't used.
`synapse.util.Clock.call_when_running` should always be used instead of
`reactor.callWhenRunning`.
Since `reactor.callWhenRunning` is a reactor callback, the callback will start out
with the sentinel logcontext. `synapse.util.Clock` starts a default logcontext as we
want to know which server the logs came from.
Args:
ctx: The `FunctionSigContext` from mypy.
"""
signature: CallableType = ctx.default_signature
ctx.api.fail(
(
"Expected all `reactor.callWhenRunning` calls to use `synapse.util.Clock.call_when_running` instead. "
"This is so all Synapse code runs with a logcontext as we want to know which server the logs came from."
),
ctx.context,
code=PREFER_SYNAPSE_CLOCK_CALL_WHEN_RUNNING,
)
return signature
def check_add_system_event_trigger(ctx: MethodSigContext) -> CallableType:
"""
Ensure that the `reactor.addSystemEventTrigger` callsites aren't used.
`synapse.util.Clock.add_system_event_trigger` should always be used instead of
`reactor.addSystemEventTrigger`.
Since `reactor.addSystemEventTrigger` is a reactor callback, the callback will start out
with the sentinel logcontext. `synapse.util.Clock` starts a default logcontext as we
want to know which server the logs came from.
Args:
ctx: The `FunctionSigContext` from mypy.
"""
signature: CallableType = ctx.default_signature
ctx.api.fail(
(
"Expected all `reactor.addSystemEventTrigger` calls to use `synapse.util.Clock.add_system_event_trigger` instead. "
"This is so all Synapse code runs with a logcontext as we want to know which server the logs came from."
),
ctx.context,
code=PREFER_SYNAPSE_CLOCK_ADD_SYSTEM_EVENT_TRIGGER,
)
return signature
def analyze_prometheus_metric_classes(ctx: ClassDefContext) -> None:
"""
Cross-check the list of Prometheus metric classes against the

View File

@@ -30,7 +30,7 @@ from signedjson.sign import sign_json
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.util import json_encoder
from synapse.util.json import json_encoder
def main() -> None:

View File

@@ -54,11 +54,11 @@ from twisted.internet import defer, reactor as reactor_
from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
from synapse.logging.context import (
LoggingContext,
make_deferred_yieldable,
run_in_background,
)
from synapse.notifier import ReplicationNotifier
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn
from synapse.storage.databases.main import FilteringWorkerStore
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
@@ -98,8 +98,7 @@ from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStor
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
from synapse.types import ISynapseReactor
from synapse.util import SYNAPSE_VERSION, Clock
from synapse.util.stringutils import random_string
from synapse.util import SYNAPSE_VERSION
# Cast safety: Twisted does some naughty magic which replaces the
# twisted.internet.reactor module with a Reactor instance at runtime.
@@ -318,31 +317,16 @@ class Store(
)
class MockHomeserver:
class MockHomeserver(HomeServer):
DATASTORE_CLASS = DataStore
def __init__(self, config: HomeServerConfig):
self.clock = Clock(reactor)
self.config = config
self.hostname = config.server.server_name
self.version_string = SYNAPSE_VERSION
self.instance_id = random_string(5)
def get_clock(self) -> Clock:
return self.clock
def get_reactor(self) -> ISynapseReactor:
return reactor
def get_instance_id(self) -> str:
return self.instance_id
def get_instance_name(self) -> str:
return "master"
def should_send_federation(self) -> bool:
return False
def get_replication_notifier(self) -> ReplicationNotifier:
return ReplicationNotifier()
super().__init__(
hostname=config.server.server_name,
config=config,
reactor=reactor,
version_string=f"Synapse/{SYNAPSE_VERSION}",
)
class Porter:
@@ -351,12 +335,12 @@ class Porter:
sqlite_config: Dict[str, Any],
progress: "Progress",
batch_size: int,
hs_config: HomeServerConfig,
hs: HomeServer,
):
self.sqlite_config = sqlite_config
self.progress = progress
self.batch_size = batch_size
self.hs_config = hs_config
self.hs = hs
async def setup_table(self, table: str) -> Tuple[str, int, int, int, int]:
if table in APPEND_ONLY_TABLES:
@@ -676,8 +660,7 @@ class Porter:
engine = create_engine(db_config.config)
hs = MockHomeserver(self.hs_config)
server_name = hs.hostname
server_name = self.hs.hostname
with make_conn(
db_config=db_config,
@@ -688,16 +671,16 @@ class Porter:
engine.check_database(
db_conn, allow_outdated_version=allow_outdated_version
)
prepare_database(db_conn, engine, config=self.hs_config)
prepare_database(db_conn, engine, config=self.hs.config)
# Type safety: ignore that we're using Mock homeservers here.
store = Store(
DatabasePool(
hs, # type: ignore[arg-type]
self.hs,
db_config,
engine,
),
db_conn,
hs, # type: ignore[arg-type]
self.hs,
)
db_conn.commit()
@@ -795,7 +778,7 @@ class Porter:
return
self.postgres_store = self.build_db_store(
self.hs_config.database.get_single_database()
self.hs.config.database.get_single_database()
)
await self.remove_ignored_background_updates_from_database()
@@ -1584,6 +1567,8 @@ def main() -> None:
config = HomeServerConfig()
config.parse_config_dict(hs_config, "", "")
hs = MockHomeserver(config)
def start(stdscr: Optional["curses.window"] = None) -> None:
progress: Progress
if stdscr:
@@ -1595,15 +1580,14 @@ def main() -> None:
sqlite_config=sqlite_config,
progress=progress,
batch_size=args.batch_size,
hs_config=config,
hs=hs,
)
@defer.inlineCallbacks
def run() -> Generator["defer.Deferred[Any]", Any, None]:
with LoggingContext("synapse_port_db_run"):
yield defer.ensureDeferred(porter.run())
yield defer.ensureDeferred(porter.run())
reactor.callWhenRunning(run)
hs.get_clock().call_when_running(run)
reactor.run()

View File

@@ -74,7 +74,7 @@ def run_background_updates(hs: HomeServer) -> None:
)
)
reactor.callWhenRunning(run)
hs.get_clock().call_when_running(run)
reactor.run()

View File

@@ -43,9 +43,9 @@ from synapse.logging.opentracing import (
from synapse.metrics import SERVER_NAME_LABEL
from synapse.synapse_rust.http_client import HttpClient
from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.util import json_decoder
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
from synapse.util.caches.response_cache import ResponseCache, ResponseCacheContext
from synapse.util.json import json_decoder
from . import introspection_response_timer

View File

@@ -48,9 +48,9 @@ from synapse.logging.opentracing import (
from synapse.metrics import SERVER_NAME_LABEL
from synapse.synapse_rust.http_client import HttpClient
from synapse.types import Requester, UserID, create_requester
from synapse.util import json_decoder
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
from synapse.util.caches.response_cache import ResponseCache, ResponseCacheContext
from synapse.util.json import json_decoder
from . import introspection_response_timer

View File

@@ -30,7 +30,7 @@ from typing import Any, Dict, List, Optional, Union
from twisted.web import http
from synapse.util import json_decoder
from synapse.util.json import json_decoder
if typing.TYPE_CHECKING:
from synapse.config.homeserver import HomeServerConfig

View File

@@ -26,7 +26,7 @@ from synapse.api.errors import LimitExceededError
from synapse.config.ratelimiting import RatelimitSettings
from synapse.storage.databases.main import DataStore
from synapse.types import Requester
from synapse.util import Clock
from synapse.util.clock import Clock
if TYPE_CHECKING:
# To avoid circular imports:

View File

@@ -22,6 +22,7 @@
"""Contains the URL paths to prefix various aspects of the server with."""
import hmac
import urllib.parse
from hashlib import sha256
from typing import Optional
from urllib.parse import urlencode, urljoin
@@ -96,11 +97,21 @@ class LoginSSORedirectURIBuilder:
serialized_query_parameters = urlencode({"redirectUrl": client_redirect_url})
if idp_id:
# Since this is a user-controlled string, make it safe to include in a URL path.
url_encoded_idp_id = urllib.parse.quote(
idp_id,
# Since this defaults to `safe="/"`, we have to override it. We're
# working with an individual URL path parameter so there shouldn't be
# any slashes in it which could change the request path.
safe="",
encoding="utf8",
)
resultant_url = urljoin(
# We have to add a trailing slash to the base URL to ensure that the
# last path segment is not stripped away when joining with another path.
f"{base_url}/",
f"{idp_id}?{serialized_query_parameters}",
f"{url_encoded_idp_id}?{serialized_query_parameters}",
)
else:
resultant_url = f"{base_url}?{serialized_query_parameters}"

View File

@@ -241,7 +241,7 @@ def redirect_stdio_to_logs() -> None:
def register_start(
cb: Callable[P, Awaitable], *args: P.args, **kwargs: P.kwargs
hs: "HomeServer", cb: Callable[P, Awaitable], *args: P.args, **kwargs: P.kwargs
) -> None:
"""Register a callback with the reactor, to be called once it is running
@@ -278,7 +278,8 @@ def register_start(
# on as normal.
os._exit(1)
reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper()))
clock = hs.get_clock()
clock.call_when_running(lambda: defer.ensureDeferred(wrapper()))
def listen_metrics(bind_addresses: StrCollection, port: int) -> None:
@@ -517,7 +518,9 @@ async def start(hs: "HomeServer") -> None:
# numbers of DNS requests don't starve out other users of the threadpool.
resolver_threadpool = ThreadPool(name="gai_resolver")
resolver_threadpool.start()
reactor.addSystemEventTrigger("during", "shutdown", resolver_threadpool.stop)
hs.get_clock().add_system_event_trigger(
"during", "shutdown", resolver_threadpool.stop
)
reactor.installNameResolver(
GAIResolver(reactor, getThreadPool=lambda: resolver_threadpool)
)
@@ -604,7 +607,7 @@ async def start(hs: "HomeServer") -> None:
logger.info("Shutting down...")
# Log when we start the shut down process.
hs.get_reactor().addSystemEventTrigger("before", "shutdown", log_shutdown)
hs.get_clock().add_system_event_trigger("before", "shutdown", log_shutdown)
setup_sentry(hs)
setup_sdnotify(hs)
@@ -719,7 +722,7 @@ def setup_sdnotify(hs: "HomeServer") -> None:
# we're not using systemd.
sdnotify(b"READY=1\nMAINPID=%i" % (os.getpid(),))
hs.get_reactor().addSystemEventTrigger(
hs.get_clock().add_system_event_trigger(
"before", "shutdown", sdnotify, b"STOPPING=1"
)

View File

@@ -356,11 +356,9 @@ def start(config_options: List[str]) -> None:
handle_startup_exception(e)
async def start() -> None:
# Re-establish log context now that we're back from the reactor
with LoggingContext("start"):
await _base.start(hs)
await _base.start(hs)
register_start(start)
register_start(hs, start)
# redirect stdio to the logs, if configured.
if not hs.config.logging.no_redirect_stdio:

View File

@@ -377,19 +377,17 @@ def setup(config_options: List[str]) -> SynapseHomeServer:
handle_startup_exception(e)
async def start() -> None:
# Re-establish log context now that we're back from the reactor
with LoggingContext("start"):
# Load the OIDC provider metadatas, if OIDC is enabled.
if hs.config.oidc.oidc_enabled:
oidc = hs.get_oidc_handler()
# Loading the provider metadata also ensures the provider config is valid.
await oidc.load_metadata()
# Load the OIDC provider metadatas, if OIDC is enabled.
if hs.config.oidc.oidc_enabled:
oidc = hs.get_oidc_handler()
# Loading the provider metadata also ensures the provider config is valid.
await oidc.load_metadata()
await _base.start(hs)
await _base.start(hs)
hs.get_datastores().main.db_pool.updates.start_doing_background_updates()
hs.get_datastores().main.db_pool.updates.start_doing_background_updates()
register_start(start)
register_start(hs, start)
return hs

View File

@@ -84,7 +84,7 @@ from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases.main import DataStore
from synapse.types import DeviceListUpdates, JsonMapping
from synapse.util import Clock
from synapse.util.clock import Clock
if TYPE_CHECKING:
from synapse.server import HomeServer

View File

@@ -38,7 +38,7 @@ from synapse.storage.databases.main import DataStore
from synapse.synapse_rust.events import EventInternalMetadata
from synapse.types import EventID, JsonDict, StrCollection
from synapse.types.state import StateFilter
from synapse.util import Clock
from synapse.util.clock import Clock
from synapse.util.stringutils import random_string
if TYPE_CHECKING:

View File

@@ -178,7 +178,7 @@ from synapse.types import (
StrCollection,
get_domain_from_id,
)
from synapse.util import Clock
from synapse.util.clock import Clock
from synapse.util.metrics import Measure
from synapse.util.retryutils import filter_destinations_by_retry_limiter

View File

@@ -36,7 +36,7 @@ from synapse.logging.opentracing import (
)
from synapse.metrics import SERVER_NAME_LABEL
from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.json import json_decoder
from synapse.util.metrics import measure_func
if TYPE_CHECKING:

View File

@@ -62,7 +62,7 @@ class DeactivateAccountHandler:
# Start the user parter loop so it can resume parting users from rooms where
# it left off (if it has work left to do).
if hs.config.worker.worker_app is None:
hs.get_reactor().callWhenRunning(self._start_user_parting)
hs.get_clock().call_when_running(self._start_user_parting)
else:
self._notify_account_deactivated_client = (
ReplicationNotifyAccountDeactivatedServlet.make_client(hs)

View File

@@ -1002,7 +1002,7 @@ class DeviceWriterHandler(DeviceHandler):
# rolling-restarting Synapse.
if self._is_main_device_list_writer:
# On start up check if there are any updates pending.
hs.get_reactor().callWhenRunning(self._handle_new_device_update_async)
hs.get_clock().call_when_running(self._handle_new_device_update_async)
self.device_list_updater = DeviceListUpdater(hs, self)
hs.get_federation_registry().register_edu_handler(
EduTypes.DEVICE_LIST_UPDATE,

View File

@@ -34,7 +34,7 @@ from synapse.logging.opentracing import (
set_tag,
)
from synapse.types import JsonDict, Requester, StreamKeyType, UserID, get_domain_from_id
from synapse.util import json_encoder
from synapse.util.json import json_encoder
from synapse.util.stringutils import random_string
if TYPE_CHECKING:

View File

@@ -44,9 +44,9 @@ from synapse.types import (
get_domain_from_id,
get_verify_key_from_cross_signing_key,
)
from synapse.util import json_decoder
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.cancellation import cancellable
from synapse.util.json import json_decoder
from synapse.util.retryutils import (
NotRetryingDestination,
filter_destinations_by_retry_limiter,

View File

@@ -39,8 +39,8 @@ from synapse.http import RequestTimedOutError
from synapse.http.client import SimpleHttpClient
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict, Requester
from synapse.util import json_decoder
from synapse.util.hash import sha256_and_url_safe_base64
from synapse.util.json import json_decoder
from synapse.util.stringutils import (
assert_valid_client_secret,
random_string,

View File

@@ -81,9 +81,10 @@ from synapse.types import (
create_requester,
)
from synapse.types.state import StateFilter
from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError
from synapse.util import log_failure, unwrapFirstError
from synapse.util.async_helpers import Linearizer, gather_results
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.json import json_decoder, json_encoder
from synapse.util.metrics import measure_func
from synapse.visibility import get_effective_room_visibility_from_state

View File

@@ -67,8 +67,9 @@ from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.module_api import ModuleApi
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
from synapse.util import Clock, json_decoder
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
from synapse.util.clock import Clock
from synapse.util.json import json_decoder
from synapse.util.macaroons import MacaroonGenerator, OidcSessionData
from synapse.util.templates import _localpart_from_email_filter

View File

@@ -541,7 +541,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
self.send_stop_syncing, UPDATE_SYNCING_USERS_MS
)
hs.get_reactor().addSystemEventTrigger(
hs.get_clock().add_system_event_trigger(
"before",
"shutdown",
run_as_background_process,
@@ -842,7 +842,7 @@ class PresenceHandler(BasePresenceHandler):
# have not yet been persisted
self.unpersisted_users_changes: Set[str] = set()
hs.get_reactor().addSystemEventTrigger(
hs.get_clock().add_system_event_trigger(
"before",
"shutdown",
run_as_background_process,

View File

@@ -27,7 +27,7 @@ from twisted.web.client import PartialDownloadError
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.util import json_decoder
from synapse.util.json import json_decoder
if TYPE_CHECKING:
from synapse.server import HomeServer

View File

@@ -87,8 +87,8 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.metrics import SERVER_NAME_LABEL
from synapse.types import ISynapseReactor, StrSequence
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
from synapse.util.json import json_decoder
if TYPE_CHECKING:
from synapse.server import HomeServer

View File

@@ -49,7 +49,7 @@ from synapse.http.federation.well_known_resolver import WellKnownResolver
from synapse.http.proxyagent import ProxyAgent
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.types import ISynapseReactor
from synapse.util import Clock
from synapse.util.clock import Clock
logger = logging.getLogger(__name__)

View File

@@ -27,7 +27,6 @@ from typing import Callable, Dict, Optional, Tuple
import attr
from twisted.internet import defer
from twisted.internet.interfaces import IReactorTime
from twisted.web.client import RedirectAgent
from twisted.web.http import stringToDatetime
from twisted.web.http_headers import Headers
@@ -35,8 +34,10 @@ from twisted.web.iweb import IAgent, IResponse
from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size
from synapse.logging.context import make_deferred_yieldable
from synapse.util import Clock, json_decoder
from synapse.types import ISynapseThreadlessReactor
from synapse.util.caches.ttlcache import TTLCache
from synapse.util.clock import Clock
from synapse.util.json import json_decoder
from synapse.util.metrics import Measure
# period to cache .well-known results for by default
@@ -88,7 +89,7 @@ class WellKnownResolver:
def __init__(
self,
server_name: str,
reactor: IReactorTime,
reactor: ISynapseThreadlessReactor,
agent: IAgent,
user_agent: bytes,
well_known_cache: Optional[TTLCache[bytes, Optional[bytes]]] = None,

View File

@@ -89,8 +89,8 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.metrics import SERVER_NAME_LABEL
from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import AwakenableSleeper, Linearizer, timeout_deferred
from synapse.util.json import json_decoder
from synapse.util.metrics import Measure
from synapse.util.stringutils import parse_and_validate_server_name

View File

@@ -52,10 +52,11 @@ from zope.interface import implementer
from twisted.internet import defer, interfaces, reactor
from twisted.internet.defer import CancelledError
from twisted.internet.interfaces import IReactorTime
from twisted.python import failure
from twisted.web import resource
from synapse.types import ISynapseThreadlessReactor
try:
from twisted.web.pages import notFound
except ImportError:
@@ -77,10 +78,11 @@ from synapse.api.errors import (
from synapse.config.homeserver import HomeServerConfig
from synapse.logging.context import defer_to_thread, preserve_fn, run_in_background
from synapse.logging.opentracing import active_span, start_active_span, trace_servlet
from synapse.util import Clock, json_encoder
from synapse.util.caches import intern_dict
from synapse.util.cancellation import is_function_cancellable
from synapse.util.clock import Clock
from synapse.util.iterutils import chunk_seq
from synapse.util.json import json_encoder
if TYPE_CHECKING:
import opentracing
@@ -410,7 +412,7 @@ class DirectServeJsonResource(_AsyncResource):
clock: Optional[Clock] = None,
):
if clock is None:
clock = Clock(cast(IReactorTime, reactor))
clock = Clock(cast(ISynapseThreadlessReactor, reactor))
super().__init__(clock, extract_context)
self.canonical_json = canonical_json
@@ -589,7 +591,7 @@ class DirectServeHtmlResource(_AsyncResource):
clock: Optional[Clock] = None,
):
if clock is None:
clock = Clock(cast(IReactorTime, reactor))
clock = Clock(cast(ISynapseThreadlessReactor, reactor))
super().__init__(clock, extract_context)

View File

@@ -51,7 +51,7 @@ from synapse.api.errors import Codes, SynapseError
from synapse.http import redact_uri
from synapse.http.server import HttpServer
from synapse.types import JsonDict, RoomAlias, RoomID, StrCollection
from synapse.util import json_decoder
from synapse.util.json import json_decoder
if TYPE_CHECKING:
from synapse.server import HomeServer

View File

@@ -60,8 +60,18 @@ class PeriodicallyFlushingMemoryHandler(MemoryHandler):
else:
reactor_to_use = reactor
# call our hook when the reactor start up
reactor_to_use.callWhenRunning(on_reactor_running)
# Call our hook when the reactor start up
#
# type-ignore: Ideally, we'd use `Clock.call_when_running(...)`, but
# `PeriodicallyFlushingMemoryHandler` is instantiated via Python logging
# configuration, so it's not straightforward to pass in the homeserver's clock
# (and we don't want to burden other peoples logging config with the details).
#
# The important reason why we want to use `Clock.call_when_running` is so that
# the callback runs with a logcontext as we want to know which server the logs
# came from. But since we don't log anything in the callback, it's safe to
# ignore the lint here.
reactor_to_use.callWhenRunning(on_reactor_running) # type: ignore[prefer-synapse-clock-call-when-running]
def shouldFlush(self, record: LogRecord) -> bool:
"""

View File

@@ -204,7 +204,7 @@ from twisted.web.http import Request
from twisted.web.http_headers import Headers
from synapse.config import ConfigError
from synapse.util import json_decoder, json_encoder
from synapse.util.json import json_decoder, json_encoder
if TYPE_CHECKING:
from synapse.http.site import SynapseRequest

View File

@@ -54,8 +54,8 @@ from synapse.logging.context import (
make_deferred_yieldable,
run_in_background,
)
from synapse.util import Clock
from synapse.util.async_helpers import DeferredEvent
from synapse.util.clock import Clock
from synapse.util.stringutils import is_ascii
if TYPE_CHECKING:

View File

@@ -55,7 +55,7 @@ from synapse.api.errors import NotFoundError
from synapse.logging.context import defer_to_thread, run_in_background
from synapse.logging.opentracing import start_active_span, trace, trace_with_opname
from synapse.media._base import ThreadedFileSender
from synapse.util import Clock
from synapse.util.clock import Clock
from synapse.util.file_consumer import BackgroundFileConsumer
from ..types import JsonDict

View File

@@ -27,7 +27,7 @@ import attr
from synapse.media.preview_html import parse_html_description
from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.json import json_decoder
if TYPE_CHECKING:
from lxml import etree

View File

@@ -46,9 +46,9 @@ from synapse.media.oembed import OEmbedProvider
from synapse.media.preview_html import decode_body, parse_html_to_open_graph
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import JsonDict, UserID
from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.json import json_encoder
from synapse.util.stringutils import random_string
if TYPE_CHECKING:

View File

@@ -158,9 +158,9 @@ from synapse.types import (
create_requester,
)
from synapse.types.state import StateFilter
from synapse.util import Clock
from synapse.util.async_helpers import maybe_awaitable
from synapse.util.caches.descriptors import CachedFunction, cached as _cached
from synapse.util.clock import Clock
from synapse.util.frozenutils import freeze
if TYPE_CHECKING:

View File

@@ -29,7 +29,7 @@ import logging
from typing import List, Optional, Tuple, Type, TypeVar
from synapse.replication.tcp.streams._base import StreamRow
from synapse.util import json_decoder, json_encoder
from synapse.util.json import json_decoder, json_encoder
logger = logging.getLogger(__name__)

View File

@@ -27,7 +27,7 @@ from prometheus_client import Counter, Histogram
from synapse.logging import opentracing
from synapse.logging.context import make_deferred_yieldable
from synapse.metrics import SERVER_NAME_LABEL
from synapse.util import json_decoder, json_encoder
from synapse.util.json import json_decoder, json_encoder
if TYPE_CHECKING:
from txredisapi import ConnectionHandler

View File

@@ -55,7 +55,7 @@ from synapse.replication.tcp.commands import (
ServerCommand,
parse_command_from_line,
)
from synapse.util import Clock
from synapse.util.clock import Clock
from synapse.util.stringutils import random_string
if TYPE_CHECKING:

View File

@@ -58,8 +58,8 @@ from synapse.logging.opentracing import log_kv, set_tag, trace_with_opname
from synapse.rest.admin.experimental_features import ExperimentalFeature
from synapse.types import JsonDict, Requester, SlidingSyncStreamToken, StreamToken
from synapse.types.rest.client import SlidingSyncBody
from synapse.util import json_decoder
from synapse.util.caches.lrucache import LruCache
from synapse.util.json import json_decoder
from ._base import client_patterns, set_timeline_upper_limit

View File

@@ -38,8 +38,8 @@ from synapse.http.servlet import (
from synapse.storage.keys import FetchKeyResultForRemote
from synapse.types import JsonDict
from synapse.types.rest import RequestBodyModel
from synapse.util import json_decoder
from synapse.util.async_helpers import yieldable_gather_results
from synapse.util.json import json_decoder
if TYPE_CHECKING:
from synapse.server import HomeServer

View File

@@ -63,6 +63,22 @@ class PickIdpResource(DirectServeHtmlResource):
if not idp:
return await self._serve_id_picker(request, client_redirect_url)
# Validate the `idp` query parameter. We should only be working with known IdPs.
# No need waste further effort if we don't know about it.
#
# Although, we primarily prevent open redirect attacks by URL encoding all of
# the parameters we use in the redirect URL below, this validation also helps
# prevent Synapse from crafting arbitrary URLs and being used in open redirect
# attacks (defense in depth).
providers = self._sso_handler.get_identity_providers()
auth_provider = providers.get(idp)
if not auth_provider:
logger.info("Unknown idp %r", idp)
self._sso_handler.render_error(
request, "unknown_idp", "Unknown identity provider ID"
)
return
# Otherwise, redirect to the login SSO redirect endpoint for the given IdP
# (which will in turn take us to the the IdP's redirect URI).
#

View File

@@ -28,7 +28,7 @@ from synapse.api.errors import NotFoundError
from synapse.http.server import DirectServeJsonResource
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.json import json_encoder
from synapse.util.stringutils import parse_server_name
if TYPE_CHECKING:

View File

@@ -156,7 +156,7 @@ from synapse.storage.controllers import StorageControllers
from synapse.streams.events import EventSources
from synapse.synapse_rust.rendezvous import RendezvousHandler
from synapse.types import DomainSpecificString, ISynapseReactor
from synapse.util import Clock
from synapse.util.clock import Clock
from synapse.util.distributor import Distributor
from synapse.util.macaroons import MacaroonGenerator
from synapse.util.ratelimitutils import FederationRateLimiter
@@ -1007,7 +1007,7 @@ class HomeServer(metaclass=abc.ABCMeta):
)
media_threadpool.start()
self.get_reactor().addSystemEventTrigger(
self.get_clock().add_system_event_trigger(
"during", "shutdown", media_threadpool.stop
)

View File

@@ -29,8 +29,8 @@ from synapse.storage.database import (
make_in_list_sql_clause, # noqa: F401
)
from synapse.types import get_domain_from_id
from synapse.util import json_decoder
from synapse.util.caches.descriptors import CachedFunction
from synapse.util.json import json_decoder
if TYPE_CHECKING:
from synapse.server import HomeServer

View File

@@ -45,7 +45,8 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Connection, Cursor
from synapse.types import JsonDict, StrCollection
from synapse.util import Clock, json_encoder
from synapse.util.clock import Clock
from synapse.util.json import json_encoder
from . import engines

View File

@@ -2653,8 +2653,7 @@ def make_in_list_sql_clause(
# These overloads ensure that `columns` and `iterable` values have the same length.
# Suppress "Single overload definition, multiple required" complaint.
@overload # type: ignore[misc]
@overload
def make_tuple_in_list_sql_clause(
database_engine: BaseDatabaseEngine,
columns: Tuple[str, str],
@@ -2662,6 +2661,14 @@ def make_tuple_in_list_sql_clause(
) -> Tuple[str, list]: ...
@overload
def make_tuple_in_list_sql_clause(
database_engine: BaseDatabaseEngine,
columns: Tuple[str, str, str],
iterable: Collection[Tuple[Any, Any, Any]],
) -> Tuple[str, list]: ...
def make_tuple_in_list_sql_clause(
database_engine: BaseDatabaseEngine,
columns: Tuple[str, ...],

View File

@@ -48,9 +48,9 @@ from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
from synapse.storage.invite_rule import InviteRulesConfig
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import JsonDict, JsonMapping
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.json import json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer

View File

@@ -42,8 +42,8 @@ from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import DeviceListUpdates, JsonMapping
from synapse.util import json_encoder
from synapse.util.caches.descriptors import _CacheContext, cached
from synapse.util.json import json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer

View File

@@ -21,6 +21,7 @@
import itertools
import json
import logging
from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Tuple
@@ -62,6 +63,12 @@ PURGE_HISTORY_CACHE_NAME = "ph_cache_fake"
# As above, but for invalidating room caches on room deletion
DELETE_ROOM_CACHE_NAME = "dr_cache_fake"
# This cache takes a list of tuples as its first argument, which requires
# special handling.
GET_E2E_CROSS_SIGNING_SIGNATURES_FOR_DEVICE_CACHE_NAME = (
"_get_e2e_cross_signing_signatures_for_device"
)
# How long between cache invalidation table cleanups, once we have caught up
# with the backlog.
REGULAR_CLEANUP_INTERVAL_MS = Config.parse_duration("1h")
@@ -270,6 +277,33 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
# room membership.
#
# self._membership_stream_cache.all_entities_changed(token) # type: ignore[attr-defined]
elif (
row.cache_func
== GET_E2E_CROSS_SIGNING_SIGNATURES_FOR_DEVICE_CACHE_NAME
):
# "keys" is a list of strings, where each string is a
# JSON-encoded representation of the tuple keys, i.e.
# keys: ['["@userid:domain", "DEVICEID"]','["@userid2:domain", "DEVICEID2"]']
#
# This is a side-effect of not being able to send nested
# information over replication.
for json_str in row.keys:
try:
user_id, device_id = json.loads(json_str)
except (json.JSONDecodeError, TypeError):
logger.error(
"Failed to deserialise cache key as valid JSON: %s",
json_str,
)
continue
# Invalidate each key.
#
# Note: .invalidate takes a tuple of arguments, hence the need
# to nest our tuple in another tuple.
self._get_e2e_cross_signing_signatures_for_device.invalidate( # type: ignore[attr-defined]
((user_id, device_id),)
)
else:
self._attempt_to_invalidate_cache(row.cache_func, row.keys)

View File

@@ -32,7 +32,7 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.util import json_encoder
from synapse.util.json import json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer

View File

@@ -455,7 +455,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
self._client_ip_looper = self._clock.looping_call(
self._update_client_ips_batch, 5 * 1000
)
self.hs.get_reactor().addSystemEventTrigger(
self.hs.get_clock().add_system_event_trigger(
"before", "shutdown", self._update_client_ips_batch
)

View File

@@ -22,7 +22,8 @@ from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import LoggingTransaction, StoreError
from synapse.storage.engines import PostgresEngine
from synapse.types import JsonDict, RoomID
from synapse.util import json_encoder, stringutils as stringutils
from synapse.util import stringutils
from synapse.util.json import json_encoder
logger = logging.getLogger(__name__)

View File

@@ -53,10 +53,11 @@ from synapse.storage.database import (
)
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import JsonDict, StrCollection
from synapse.util import Duration, json_encoder
from synapse.util import Duration
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.iterutils import batch_iter
from synapse.util.json import json_encoder
from synapse.util.stringutils import parse_and_validate_server_name
if TYPE_CHECKING:

View File

@@ -64,11 +64,11 @@ from synapse.types import (
StrCollection,
get_verify_key_from_cross_signing_key,
)
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
from synapse.util.json import json_decoder, json_encoder
from synapse.util.stringutils import shortstr
if TYPE_CHECKING:

View File

@@ -41,7 +41,7 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.types import JsonDict, JsonSerializable, StreamKeyType
from synapse.util import json_encoder
from synapse.util.json import json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer

View File

@@ -20,6 +20,7 @@
#
#
import abc
import json
from typing import (
TYPE_CHECKING,
Any,
@@ -60,10 +61,10 @@ from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import JsonDict, JsonMapping, MultiWriterStreamToken
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
from synapse.util.json import json_decoder, json_encoder
if TYPE_CHECKING:
from synapse.handlers.e2e_keys import SignatureListItem
@@ -354,15 +355,17 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
)
for batch in batch_iter(signature_query, 50):
cross_sigs_result = await self.db_pool.runInteraction(
"get_e2e_cross_signing_signatures_for_devices",
self._get_e2e_cross_signing_signatures_for_devices_txn,
batch,
cross_sigs_result = (
await self._get_e2e_cross_signing_signatures_for_devices(batch)
)
# add each cross-signing signature to the correct device in the result dict.
for user_id, key_id, device_id, signature in cross_sigs_result:
for (
user_id,
device_id,
), signature_list in cross_sigs_result.items():
target_device_result = result[user_id][device_id]
# We've only looked up cross-signatures for non-deleted devices with key
# data.
assert target_device_result is not None
@@ -373,7 +376,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
signing_user_signatures = target_device_signatures.setdefault(
user_id, {}
)
signing_user_signatures[key_id] = signature
for key_id, signature in signature_list:
signing_user_signatures[key_id] = signature
log_kv(result)
return result
@@ -479,41 +484,83 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
return result
def _get_e2e_cross_signing_signatures_for_devices_txn(
self, txn: LoggingTransaction, device_query: Iterable[Tuple[str, str]]
) -> List[Tuple[str, str, str, str]]:
"""Get cross-signing signatures for a given list of devices
Returns signatures made by the owners of the devices.
Returns: a list of results; each entry in the list is a tuple of
(user_id, key_id, target_device_id, signature).
@cached()
def _get_e2e_cross_signing_signatures_for_device(
self,
user_id_and_device_id: Tuple[str, str],
) -> Sequence[Tuple[str, str]]:
"""
signature_query_clauses = []
signature_query_params = []
The single-item version of `_get_e2e_cross_signing_signatures_for_devices`.
See @cachedList for why a separate method is needed.
"""
raise NotImplementedError()
for user_id, device_id in device_query:
signature_query_clauses.append(
"target_user_id = ? AND target_device_id = ? AND user_id = ?"
@cachedList(
cached_method_name="_get_e2e_cross_signing_signatures_for_device",
list_name="device_query",
)
async def _get_e2e_cross_signing_signatures_for_devices(
self, device_query: Iterable[Tuple[str, str]]
) -> Mapping[Tuple[str, str], Sequence[Tuple[str, str]]]:
"""Get cross-signing signatures for a given list of user IDs and devices.
Args:
An iterable containing tuples of (user ID, device ID).
Returns:
A mapping of results. The keys are the original (user_id, device_id)
tuple, while the value is the matching list of tuples of
(key_id, signature). The value will be an empty list if no
signatures exist for the device.
Given this method is annotated with `@cachedList`, the return dict's
keys match the tuples within `device_query`, so that cache entries can
be computed from the corresponding values.
As results are cached, the return type is immutable.
"""
def _get_e2e_cross_signing_signatures_for_devices_txn(
txn: LoggingTransaction, device_query: Iterable[Tuple[str, str]]
) -> Mapping[Tuple[str, str], Sequence[Tuple[str, str]]]:
where_clause_sql, where_clause_params = make_tuple_in_list_sql_clause(
self.database_engine,
columns=("target_user_id", "target_device_id", "user_id"),
iterable=[
(user_id, device_id, user_id) for user_id, device_id in device_query
],
)
signature_query_params.extend([user_id, device_id, user_id])
signature_sql = """
SELECT user_id, key_id, target_device_id, signature
FROM e2e_cross_signing_signatures WHERE %s
""" % (" OR ".join("(" + q + ")" for q in signature_query_clauses))
signature_sql = f"""
SELECT user_id, key_id, target_device_id, signature
FROM e2e_cross_signing_signatures WHERE {where_clause_sql}
"""
txn.execute(signature_sql, signature_query_params)
return cast(
List[
Tuple[
str,
str,
str,
str,
]
],
txn.fetchall(),
txn.execute(signature_sql, where_clause_params)
devices_and_signatures: Dict[Tuple[str, str], List[Tuple[str, str]]] = {}
# `@cachedList` requires we return one key for every item in `device_query`.
# Pre-populate `devices_and_signatures` with each key so that none are missing.
#
# If any are missing, they will be cached as `None`, which is not
# what callers expected.
for user_id, device_id in device_query:
devices_and_signatures.setdefault((user_id, device_id), [])
# Populate the return dictionary with each found key_id and signature.
for user_id, key_id, target_device_id, signature in txn.fetchall():
signature_tuple = (key_id, signature)
devices_and_signatures[(user_id, target_device_id)].append(
signature_tuple
)
return devices_and_signatures
return await self.db_pool.runInteraction(
"_get_e2e_cross_signing_signatures_for_devices_txn",
_get_e2e_cross_signing_signatures_for_devices_txn,
device_query,
)
async def get_e2e_one_time_keys(
@@ -1772,26 +1819,71 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
user_id: the user who made the signatures
signatures: signatures to add
"""
await self.db_pool.simple_insert_many(
"e2e_cross_signing_signatures",
keys=(
"user_id",
"key_id",
"target_user_id",
"target_device_id",
"signature",
),
values=[
(
user_id,
item.signing_key_id,
item.target_user_id,
item.target_device_id,
item.signature,
)
def _store_e2e_cross_signing_signatures(
txn: LoggingTransaction,
signatures: "Iterable[SignatureListItem]",
) -> None:
self.db_pool.simple_insert_many_txn(
txn,
"e2e_cross_signing_signatures",
keys=(
"user_id",
"key_id",
"target_user_id",
"target_device_id",
"signature",
),
values=[
(
user_id,
item.signing_key_id,
item.target_user_id,
item.target_device_id,
item.signature,
)
for item in signatures
],
)
to_invalidate = [
# Each entry is a tuple of arguments to
# `_get_e2e_cross_signing_signatures_for_device`, which
# itself takes a tuple. Hence the double-tuple.
((user_id, item.target_device_id),)
for item in signatures
],
desc="add_e2e_signing_key",
]
if to_invalidate:
# Invalidate the local cache of this worker.
for cache_key in to_invalidate:
txn.call_after(
self._get_e2e_cross_signing_signatures_for_device.invalidate,
cache_key,
)
# Stream cache invalidate keys over replication.
#
# We can only send a primitive per function argument across
# replication.
#
# Encode the array of strings as a JSON string, and we'll unpack
# it on the other side.
to_send = [
(json.dumps([user_id, item.target_device_id]),)
for item in signatures
]
self._send_invalidation_to_replication_bulk(
txn,
cache_name=self._get_e2e_cross_signing_signatures_for_device.__name__,
key_tuples=to_send,
)
await self.db_pool.runInteraction(
"add_e2e_signing_key",
_store_e2e_cross_signing_signatures,
signatures,
)

View File

@@ -59,11 +59,11 @@ from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import JsonDict, StrCollection
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
from synapse.util.json import json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer

View File

@@ -107,8 +107,8 @@ from synapse.storage.database import (
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.stream import StreamWorkerStore
from synapse.types import JsonDict, StrCollection
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.json import json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer

View File

@@ -83,9 +83,9 @@ from synapse.types import (
)
from synapse.types.handlers import SLIDING_SYNC_DEFAULT_BUMP_EVENT_TYPES
from synapse.types.state import StateFilter
from synapse.util import json_encoder
from synapse.util.events import get_plain_text_topic_from_event_content
from synapse.util.iterutils import batch_iter, sorted_topologically
from synapse.util.json import json_encoder
from synapse.util.stringutils import non_null_str_or_none
if TYPE_CHECKING:

View File

@@ -58,8 +58,8 @@ from synapse.types import JsonDict, RoomStreamToken, StateMap, StrCollection
from synapse.types.handlers import SLIDING_SYNC_DEFAULT_BUMP_EVENT_TYPES
from synapse.types.state import StateFilter
from synapse.types.storage import _BackgroundUpdates
from synapse.util import json_encoder
from synapse.util.iterutils import batch_iter
from synapse.util.json import json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer

View File

@@ -38,7 +38,7 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.types import ISynapseReactor
from synapse.util import Clock
from synapse.util.clock import Clock
from synapse.util.stringutils import random_string
if TYPE_CHECKING:
@@ -99,7 +99,7 @@ class LockStore(SQLBaseStore):
# lead to a race, as we may drop the lock while we are still processing.
# However, a) it should be a small window, b) the lock is best effort
# anyway and c) we want to really avoid leaking locks when we restart.
hs.get_reactor().addSystemEventTrigger(
hs.get_clock().add_system_event_trigger(
"before",
"shutdown",
self._on_shutdown,

View File

@@ -56,10 +56,11 @@ from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundExc
from synapse.storage.util.id_generators import IdGenerator, MultiWriterIdGenerator
from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules
from synapse.types import JsonDict
from synapse.util import json_encoder, unwrapFirstError
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import gather_results
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.json import json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer

View File

@@ -42,8 +42,8 @@ from synapse.storage.database import (
)
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.json import json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer

View File

@@ -55,10 +55,10 @@ from synapse.types import (
PersistedPosition,
StrCollection,
)
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.iterutils import batch_iter
from synapse.util.json import json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer

View File

@@ -65,8 +65,8 @@ from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import IdGenerator, MultiWriterIdGenerator
from synapse.types import JsonDict, RetentionPolicy, StrCollection, ThirdPartyInstanceID
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.json import json_encoder
from synapse.util.stringutils import MXC_REGEX
if TYPE_CHECKING:

View File

@@ -30,7 +30,7 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.json import json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer

View File

@@ -35,8 +35,8 @@ from synapse.types.handlers.sliding_sync import (
RoomStatusMap,
RoomSyncConfig,
)
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.json import json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer

View File

@@ -30,8 +30,8 @@ from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
from synapse.storage.util.id_generators import AbstractStreamIdGenerator
from synapse.types import JsonDict, JsonMapping
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.json import json_encoder
logger = logging.getLogger(__name__)

View File

@@ -29,7 +29,7 @@ from synapse.storage.database import (
make_in_list_sql_clause,
)
from synapse.types import JsonDict, JsonMapping, ScheduledTask, TaskStatus
from synapse.util import json_encoder
from synapse.util.json import json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer

View File

@@ -27,7 +27,8 @@ from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import LoggingTransaction
from synapse.types import JsonDict
from synapse.util import json_encoder, stringutils
from synapse.util import stringutils
from synapse.util.json import json_encoder
@attr.s(slots=True, auto_attribs=True)

View File

@@ -116,13 +116,27 @@ StrSequence = Union[Tuple[str, ...], List[str]]
# Note that this seems to require inheriting *directly* from Interface in order
# for mypy-zope to realize it is an interface.
class ISynapseReactor(
class ISynapseThreadlessReactor(
IReactorTCP,
IReactorSSL,
IReactorUNIX,
IReactorPluggableNameResolver,
IReactorTime,
IReactorCore,
Interface,
):
"""
The interfaces necessary for Synapse to function (without threads).
Helpful because we use `twisted.internet.testing.MemoryReactorClock` in tests which
doesn't implement `IReactorThreads`.
"""
# Note that this seems to require inheriting *directly* from Interface in order
# for mypy-zope to realize it is an interface.
class ISynapseReactor(
ISynapseThreadlessReactor,
IReactorThreads,
Interface,
):

View File

@@ -20,12 +20,9 @@
#
import collections.abc
import json
import logging
import typing
from typing import (
Any,
Callable,
Dict,
Iterator,
Mapping,
@@ -36,17 +33,11 @@ from typing import (
)
import attr
from immutabledict import immutabledict
from matrix_common.versionstring import get_distribution_version_string
from typing_extensions import ParamSpec
from twisted.internet import defer, task
from twisted.internet.interfaces import IDelayedCall, IReactorTime
from twisted.internet.task import LoopingCall
from twisted.internet import defer
from twisted.python.failure import Failure
from synapse.logging import context
if typing.TYPE_CHECKING:
pass
@@ -62,41 +53,6 @@ class Duration:
DAY_MS = 24 * HOUR_MS
def _reject_invalid_json(val: Any) -> None:
"""Do not allow Infinity, -Infinity, or NaN values in JSON."""
raise ValueError("Invalid JSON value: '%s'" % val)
def _handle_immutabledict(obj: Any) -> Dict[Any, Any]:
"""Helper for json_encoder. Makes immutabledicts serializable by returning
the underlying dict
"""
if type(obj) is immutabledict:
# fishing the protected dict out of the object is a bit nasty,
# but we don't really want the overhead of copying the dict.
try:
# Safety: we catch the AttributeError immediately below.
return obj._dict
except AttributeError:
# If all else fails, resort to making a copy of the immutabledict
return dict(obj)
raise TypeError(
"Object of type %s is not JSON serializable" % obj.__class__.__name__
)
# A custom JSON encoder which:
# * handles immutabledicts
# * produces valid JSON (no NaNs etc)
# * reduces redundant whitespace
json_encoder = json.JSONEncoder(
allow_nan=False, separators=(",", ":"), default=_handle_immutabledict
)
# Create a custom decoder to reject Python extensions to JSON.
json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json)
def unwrapFirstError(failure: Failure) -> Failure:
# Deprecated: you probably just want to catch defer.FirstError and reraise
# the subFailure's value, which will do a better job of preserving stacktraces.
@@ -105,129 +61,6 @@ def unwrapFirstError(failure: Failure) -> Failure:
return failure.value.subFailure
P = ParamSpec("P")
@attr.s(slots=True)
class Clock:
"""
A Clock wraps a Twisted reactor and provides utilities on top of it.
Args:
reactor: The Twisted reactor to use.
"""
_reactor: IReactorTime = attr.ib()
async def sleep(self, seconds: float) -> None:
d: defer.Deferred[float] = defer.Deferred()
with context.PreserveLoggingContext():
self._reactor.callLater(seconds, d.callback, seconds)
await d
def time(self) -> float:
"""Returns the current system time in seconds since epoch."""
return self._reactor.seconds()
def time_msec(self) -> int:
"""Returns the current system time in milliseconds since epoch."""
return int(self.time() * 1000)
def looping_call(
self,
f: Callable[P, object],
msec: float,
*args: P.args,
**kwargs: P.kwargs,
) -> LoopingCall:
"""Call a function repeatedly.
Waits `msec` initially before calling `f` for the first time.
If the function given to `looping_call` returns an awaitable/deferred, the next
call isn't scheduled until after the returned awaitable has finished. We get
this functionality thanks to this function being a thin wrapper around
`twisted.internet.task.LoopingCall`.
Note that the function will be called with no logcontext, so if it is anything
other than trivial, you probably want to wrap it in run_as_background_process.
Args:
f: The function to call repeatedly.
msec: How long to wait between calls in milliseconds.
*args: Positional arguments to pass to function.
**kwargs: Key arguments to pass to function.
"""
return self._looping_call_common(f, msec, False, *args, **kwargs)
def looping_call_now(
self,
f: Callable[P, object],
msec: float,
*args: P.args,
**kwargs: P.kwargs,
) -> LoopingCall:
"""Call a function immediately, and then repeatedly thereafter.
As with `looping_call`: subsequent calls are not scheduled until after the
the Awaitable returned by a previous call has finished.
Also as with `looping_call`: the function is called with no logcontext and
you probably want to wrap it in `run_as_background_process`.
Args:
f: The function to call repeatedly.
msec: How long to wait between calls in milliseconds.
*args: Positional arguments to pass to function.
**kwargs: Key arguments to pass to function.
"""
return self._looping_call_common(f, msec, True, *args, **kwargs)
def _looping_call_common(
self,
f: Callable[P, object],
msec: float,
now: bool,
*args: P.args,
**kwargs: P.kwargs,
) -> LoopingCall:
"""Common functionality for `looping_call` and `looping_call_now`"""
call = task.LoopingCall(f, *args, **kwargs)
call.clock = self._reactor
d = call.start(msec / 1000.0, now=now)
d.addErrback(log_failure, "Looping call died", consumeErrors=False)
return call
def call_later(
self, delay: float, callback: Callable, *args: Any, **kwargs: Any
) -> IDelayedCall:
"""Call something later
Note that the function will be called with no logcontext, so if it is anything
other than trivial, you probably want to wrap it in run_as_background_process.
Args:
delay: How long to wait in seconds.
callback: Function to call
*args: Postional arguments to pass to function.
**kwargs: Key arguments to pass to function.
"""
def wrapped_callback(*args: Any, **kwargs: Any) -> None:
with context.PreserveLoggingContext():
callback(*args, **kwargs)
with context.PreserveLoggingContext():
return self._reactor.callLater(delay, wrapped_callback, *args, **kwargs)
def cancel_call_later(self, timer: IDelayedCall, ignore_errs: bool = False) -> None:
try:
timer.cancel()
except Exception:
if not ignore_errs:
raise
def log_failure(
failure: Failure, msg: str, consumeErrors: bool = True
) -> Optional[Failure]:

View File

@@ -65,7 +65,8 @@ from synapse.logging.context import (
run_coroutine_in_background,
run_in_background,
)
from synapse.util import Clock
from synapse.types import ISynapseThreadlessReactor
from synapse.util.clock import Clock
logger = logging.getLogger(__name__)
@@ -566,7 +567,7 @@ class Linearizer:
if not clock:
from twisted.internet import reactor
clock = Clock(cast(IReactorTime, reactor))
clock = Clock(cast(ISynapseThreadlessReactor, reactor))
self._clock = clock
self.max_count = max_count

View File

@@ -39,7 +39,7 @@ from twisted.internet import defer
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics import SERVER_NAME_LABEL
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util import Clock
from synapse.util.clock import Clock
logger = logging.getLogger(__name__)

View File

@@ -579,9 +579,12 @@ def cachedList(
Used to do batch lookups for an already created cache. One of the arguments
is specified as a list that is iterated through to lookup keys in the
original cache. A new tuple consisting of the (deduplicated) keys that weren't in
the cache gets passed to the original function, which is expected to results
the cache gets passed to the original function, which is expected to result
in a map of key to value for each passed value. The new results are stored in the
original cache. Note that any missing values are cached as None.
original cache.
Note that any values in the input that end up being missing from both the
cache and the returned dictionary will be cached as `None`.
Args:
cached_method_name: The name of the single-item lookup method.

View File

@@ -29,8 +29,8 @@ from twisted.internet import defer
from synapse.config import cache as cache_config
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util import Clock
from synapse.util.caches import EvictionReason, register_cache
from synapse.util.clock import Clock
logger = logging.getLogger(__name__)

View File

@@ -46,20 +46,21 @@ from typing import (
)
from twisted.internet import defer, reactor
from twisted.internet.interfaces import IReactorTime
from synapse.config import cache as cache_config
from synapse.metrics.background_process_metrics import (
run_as_background_process,
)
from synapse.metrics.jemalloc import get_jemalloc_stats
from synapse.util import Clock, caches
from synapse.types import ISynapseThreadlessReactor
from synapse.util import caches
from synapse.util.caches import CacheMetric, EvictionReason, register_cache
from synapse.util.caches.treecache import (
TreeCache,
iterate_tree_cache_entry,
iterate_tree_cache_items,
)
from synapse.util.clock import Clock
from synapse.util.linked_list import ListNode
if TYPE_CHECKING:
@@ -496,7 +497,7 @@ class LruCache(Generic[KT, VT]):
# Default `clock` to something sensible. Note that we rename it to
# `real_clock` so that mypy doesn't think its still `Optional`.
if clock is None:
real_clock = Clock(cast(IReactorTime, reactor))
real_clock = Clock(cast(ISynapseThreadlessReactor, reactor))
else:
real_clock = clock

View File

@@ -41,9 +41,9 @@ from synapse.logging.opentracing import (
start_active_span,
start_active_span_follows_from,
)
from synapse.util import Clock
from synapse.util.async_helpers import AbstractObservableDeferred, ObservableDeferred
from synapse.util.caches import EvictionReason, register_cache
from synapse.util.clock import Clock
logger = logging.getLogger(__name__)

264
synapse/util/clock.py Normal file
View File

@@ -0,0 +1,264 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2025 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl_3.0.html>.
#
#
from typing import (
Any,
Callable,
)
import attr
from typing_extensions import ParamSpec
from twisted.internet import defer, task
from twisted.internet.interfaces import IDelayedCall
from twisted.internet.task import LoopingCall
from synapse.logging import context
from synapse.types import ISynapseThreadlessReactor
from synapse.util import log_failure
P = ParamSpec("P")
@attr.s(slots=True)
class Clock:
"""
A Clock wraps a Twisted reactor and provides utilities on top of it.
Args:
reactor: The Twisted reactor to use.
"""
_reactor: ISynapseThreadlessReactor = attr.ib()
async def sleep(self, seconds: float) -> None:
d: defer.Deferred[float] = defer.Deferred()
with context.PreserveLoggingContext():
self._reactor.callLater(seconds, d.callback, seconds)
await d
def time(self) -> float:
"""Returns the current system time in seconds since epoch."""
return self._reactor.seconds()
def time_msec(self) -> int:
"""Returns the current system time in milliseconds since epoch."""
return int(self.time() * 1000)
def looping_call(
self,
f: Callable[P, object],
msec: float,
*args: P.args,
**kwargs: P.kwargs,
) -> LoopingCall:
"""Call a function repeatedly.
Waits `msec` initially before calling `f` for the first time.
If the function given to `looping_call` returns an awaitable/deferred, the next
call isn't scheduled until after the returned awaitable has finished. We get
this functionality thanks to this function being a thin wrapper around
`twisted.internet.task.LoopingCall`.
Note that the function will be called with no logcontext, so if it is anything
other than trivial, you probably want to wrap it in run_as_background_process.
Args:
f: The function to call repeatedly.
msec: How long to wait between calls in milliseconds.
*args: Positional arguments to pass to function.
**kwargs: Key arguments to pass to function.
"""
return self._looping_call_common(f, msec, False, *args, **kwargs)
def looping_call_now(
self,
f: Callable[P, object],
msec: float,
*args: P.args,
**kwargs: P.kwargs,
) -> LoopingCall:
"""Call a function immediately, and then repeatedly thereafter.
As with `looping_call`: subsequent calls are not scheduled until after the
the Awaitable returned by a previous call has finished.
Also as with `looping_call`: the function is called with no logcontext and
you probably want to wrap it in `run_as_background_process`.
Args:
f: The function to call repeatedly.
msec: How long to wait between calls in milliseconds.
*args: Positional arguments to pass to function.
**kwargs: Key arguments to pass to function.
"""
return self._looping_call_common(f, msec, True, *args, **kwargs)
def _looping_call_common(
self,
f: Callable[P, object],
msec: float,
now: bool,
*args: P.args,
**kwargs: P.kwargs,
) -> LoopingCall:
"""Common functionality for `looping_call` and `looping_call_now`"""
call = task.LoopingCall(f, *args, **kwargs)
call.clock = self._reactor
d = call.start(msec / 1000.0, now=now)
d.addErrback(log_failure, "Looping call died", consumeErrors=False)
return call
def call_later(
self, delay: float, callback: Callable, *args: Any, **kwargs: Any
) -> IDelayedCall:
"""Call something later
Note that the function will be called with no logcontext, so if it is anything
other than trivial, you probably want to wrap it in run_as_background_process.
Args:
delay: How long to wait in seconds.
callback: Function to call
*args: Postional arguments to pass to function.
**kwargs: Key arguments to pass to function.
"""
def wrapped_callback(*args: Any, **kwargs: Any) -> None:
with context.PreserveLoggingContext():
callback(*args, **kwargs)
with context.PreserveLoggingContext():
return self._reactor.callLater(delay, wrapped_callback, *args, **kwargs)
def cancel_call_later(self, timer: IDelayedCall, ignore_errs: bool = False) -> None:
try:
timer.cancel()
except Exception:
if not ignore_errs:
raise
def call_when_running(
self,
callback: Callable[P, object],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
"""
Call a function when the reactor is running.
If the reactor has not started, the callable will be scheduled to run when it
does start. Otherwise, the callable will be invoked immediately.
Args:
callback: Function to call
*args: Postional arguments to pass to function.
**kwargs: Key arguments to pass to function.
"""
def wrapped_callback(*args: Any, **kwargs: Any) -> None:
# Since this callback can be invoked immediately if the reactor is already
# running, we can't always assume that we're running in the sentinel
# logcontext (i.e. we can't assert that we're in the sentinel context like
# we can in other methods).
#
# We will only be running in the sentinel logcontext if the reactor was not
# running when `call_when_running` was invoked and later starts up.
#
# assert context.current_context() is context.SENTINEL_CONTEXT
# Because this is a callback from the reactor, we will be using the
# `sentinel` log context at this point. We want the function to log with
# some logcontext as we want to know which server the logs came from.
#
# We use `PreserveLoggingContext` to prevent our new `call_when_running`
# logcontext from finishing as soon as we exit this function, in case `f`
# returns an awaitable/deferred which would continue running and may try to
# restore the `loop_call` context when it's done (because it's trying to
# adhere to the Synapse logcontext rules.)
#
# This also ensures that we return to the `sentinel` context when we exit
# this function and yield control back to the reactor to avoid leaking the
# current logcontext to the reactor (which would then get picked up and
# associated with the next thing the reactor does)
with context.PreserveLoggingContext(
context.LoggingContext("call_when_running")
):
# We use `run_in_background` to reset the logcontext after `f` (or the
# awaitable returned by `f`) completes to avoid leaking the current
# logcontext to the reactor
context.run_in_background(callback, *args, **kwargs)
# We can ignore the lint here since this class is the one location
# callWhenRunning should be called.
self._reactor.callWhenRunning(wrapped_callback, *args, **kwargs) # type: ignore[prefer-synapse-clock-call-when-running]
def add_system_event_trigger(
self,
phase: str,
event_type: str,
callback: Callable[P, object],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
"""
Add a function to be called when a system event occurs.
Equivalent to `reactor.addSystemEventTrigger` (see the that docstring for more
details), but ensures that the callback is run in a logging context.
Args:
phase: a time to call the event -- either the string 'before', 'after', or
'during', describing when to call it relative to the event's execution.
eventType: this is a string describing the type of event.
callback: Function to call
*args: Postional arguments to pass to function.
**kwargs: Key arguments to pass to function.
"""
def wrapped_callback(*args: Any, **kwargs: Any) -> None:
assert context.current_context() is context.SENTINEL_CONTEXT, (
"Expected `add_system_event_trigger` callback from the reactor to start with the sentinel logcontext "
f"but saw {context.current_context()}. In other words, another task shouldn't have "
"leaked their logcontext to us."
)
# Because this is a callback from the reactor, we will be using the
# `sentinel` log context at this point. We want the function to log with
# some logcontext as we want to know which server the logs came from.
#
# We use `PreserveLoggingContext` to prevent our new `system_event`
# logcontext from finishing as soon as we exit this function, in case `f`
# returns an awaitable/deferred which would continue running and may try to
# restore the `loop_call` context when it's done (because it's trying to
# adhere to the Synapse logcontext rules.)
#
# This also ensures that we return to the `sentinel` context when we exit
# this function and yield control back to the reactor to avoid leaking the
# current logcontext to the reactor (which would then get picked up and
# associated with the next thing the reactor does)
with context.PreserveLoggingContext(context.LoggingContext("system_event")):
# We use `run_in_background` to reset the logcontext after `f` (or the
# awaitable returned by `f`) completes to avoid leaking the current
# logcontext to the reactor
context.run_in_background(callback, *args, **kwargs)
# We can ignore the lint here since this class is the one location
# `addSystemEventTrigger` should be called.
self._reactor.addSystemEventTrigger(
phase, event_type, wrapped_callback, *args, **kwargs
) # type: ignore[prefer-synapse-clock-add-system-event-trigger]

57
synapse/util/json.py Normal file
View File

@@ -0,0 +1,57 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2025 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl_3.0.html>.
#
#
import json
from typing import (
Any,
Dict,
)
from immutabledict import immutabledict
def _reject_invalid_json(val: Any) -> None:
"""Do not allow Infinity, -Infinity, or NaN values in JSON."""
raise ValueError("Invalid JSON value: '%s'" % val)
def _handle_immutabledict(obj: Any) -> Dict[Any, Any]:
"""Helper for json_encoder. Makes immutabledicts serializable by returning
the underlying dict
"""
if type(obj) is immutabledict:
# fishing the protected dict out of the object is a bit nasty,
# but we don't really want the overhead of copying the dict.
try:
# Safety: we catch the AttributeError immediately below.
return obj._dict
except AttributeError:
# If all else fails, resort to making a copy of the immutabledict
return dict(obj)
raise TypeError(
"Object of type %s is not JSON serializable" % obj.__class__.__name__
)
# A custom JSON encoder which:
# * handles immutabledicts
# * produces valid JSON (no NaNs etc)
# * reduces redundant whitespace
json_encoder = json.JSONEncoder(
allow_nan=False, separators=(",", ":"), default=_handle_immutabledict
)
# Create a custom decoder to reject Python extensions to JSON.
json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json)

View File

@@ -28,7 +28,8 @@ import attr
import pymacaroons
from pymacaroons.exceptions import MacaroonVerificationFailedException
from synapse.util import Clock, stringutils
from synapse.util import stringutils
from synapse.util.clock import Clock
MacaroonType = Literal["access", "delete_pusher", "session"]

View File

@@ -42,7 +42,7 @@ from synapse.logging.context import (
current_context,
)
from synapse.metrics import SERVER_NAME_LABEL, InFlightGauge
from synapse.util import Clock
from synapse.util.clock import Clock
logger = logging.getLogger(__name__)

View File

@@ -53,7 +53,7 @@ from synapse.logging.context import (
)
from synapse.logging.opentracing import start_active_span
from synapse.metrics import SERVER_NAME_LABEL, Histogram, LaterGauge
from synapse.util import Clock
from synapse.util.clock import Clock
if typing.TYPE_CHECKING:
from contextlib import _GeneratorContextManager

View File

@@ -27,7 +27,7 @@ from synapse.api.errors import CodeMessageException
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage import DataStore
from synapse.types import StrCollection
from synapse.util import Clock
from synapse.util.clock import Clock
if TYPE_CHECKING:
from synapse.notifier import Notifier

View File

@@ -55,7 +55,7 @@ from synapse.types import (
get_domain_from_id,
)
from synapse.types.state import StateFilter
from synapse.util import Clock
from synapse.util.clock import Clock
logger = logging.getLogger(__name__)
filtered_event_logger = logging.getLogger("synapse.visibility.filtered_event_debug")

View File

@@ -62,7 +62,10 @@ def make_test(
return res
d.addBoth(on_done)
reactor.callWhenRunning(lambda: d.callback(True))
# type-ignore: This is outside of Synapse (just a utility benchmark script)
# so we don't need to worry about which server the logs are coming from
# (`Clock.call_when_running` manages the logcontext for us).
reactor.callWhenRunning(lambda: d.callback(True)) # type: ignore[prefer-synapse-clock-call-when-running]
reactor.run()
# mypy thinks this is an object for some reason.

View File

@@ -37,7 +37,7 @@ from synapse.config.logger import _setup_stdlib_logging
from synapse.logging import RemoteHandler
from synapse.synapse_rust import reset_logging_config
from synapse.types import ISynapseReactor
from synapse.util import Clock
from synapse.util.clock import Clock
class LineCounter(LineOnlyReceiver):

View File

@@ -39,7 +39,7 @@ from synapse.appservice import ApplicationService
from synapse.server import HomeServer
from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import Requester, UserID
from synapse.util import Clock
from synapse.util.clock import Clock
from tests import unittest
from tests.unittest import override_config

View File

@@ -33,7 +33,7 @@ from synapse.api.filtering import Filter
from synapse.api.presence import UserPresenceState
from synapse.server import HomeServer
from synapse.types import JsonDict, UserID
from synapse.util import Clock
from synapse.util.clock import Clock
from synapse.util.frozenutils import freeze
from tests import unittest

View File

@@ -17,7 +17,7 @@ from twisted.internet.testing import MemoryReactor
from synapse.api.urls import LoginSSORedirectURIBuilder
from synapse.server import HomeServer
from synapse.util import Clock
from synapse.util.clock import Clock
from tests.unittest import HomeserverTestCase
@@ -53,3 +53,29 @@ class LoginSSORedirectURIBuilderTestCase(HomeserverTestCase):
),
"https://test/_matrix/client/v3/login/sso/redirect/oidc-github?redirectUrl=https%3A%2F%2Fx%3F%3Cab+c%3E%26q%22%2B%253D%252B%22%3D%22f%C3%B6%2526%3Do%22",
)
def test_idp_id_with_slash_is_escaped(self) -> None:
"""
Test to make sure that we properly URL encode the IdP ID.
"""
self.assertEqual(
self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
idp_id="foo/bar",
client_redirect_url="http://example.com/redirect",
),
"https://test/_matrix/client/v3/login/sso/redirect/foo%2Fbar?redirectUrl=http%3A%2F%2Fexample.com%2Fredirect",
)
def test_url_as_idp_id_is_escaped(self) -> None:
"""
Test to make sure that we properly URL encode the IdP ID.
The IdP ID shouldn't be a URL.
"""
self.assertEqual(
self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
idp_id="http://should-not-be-url.com/",
client_redirect_url="http://example.com/redirect",
),
"https://test/_matrix/client/v3/login/sso/redirect/http%3A%2F%2Fshould-not-be-url.com%2F?redirectUrl=http%3A%2F%2Fexample.com%2Fredirect",
)

View File

@@ -29,7 +29,7 @@ from synapse.app.homeserver import SynapseHomeServer
from synapse.config.server import parse_listener_def
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
from synapse.util.clock import Clock
from tests.server import make_request
from tests.unittest import HomeserverTestCase

Some files were not shown because too many files have changed in this diff Show More