Mirror of https://github.com/element-hq/synapse.git, synced 2025-12-07 01:20:16 +00:00.

Compare commits: 19 commits by madlittlem

| Author | SHA1 | Date |
|---|---|---|
| madlittlem | 41874bb8c8 | |
| madlittlem | 4fe5413b41 | |
| madlittlem | a906d1aa0b | |
| madlittlem | 2432390bf2 | |
| madlittlem | 4a9730ee76 | |
| madlittlem | 19a4f8c741 | |
| madlittlem | 62f93ded1f | |
| madlittlem | 7b83c9fcbc | |
| madlittlem | 6938134f7d | |
| madlittlem | 0f8076ab41 | |
| madlittlem | 731f36e131 | |
| madlittlem | 12e7cf4487 | |
| madlittlem | f1e6887d14 | |
| madlittlem | 2d1331f0e4 | |
| madlittlem | bccd224489 | |
| madlittlem | e45dd4f03e | |
| madlittlem | 4367fb2d07 | |
| madlittlem | b596faa4ec | |
| madlittlem | 6f9fab1089 | |
changelog.d/18899.feature (new file)
@@ -0,0 +1 @@
+Add an in-memory cache to `_get_e2e_cross_signing_signatures_for_devices` to reduce DB load.

changelog.d/18909.bugfix (new file)
@@ -0,0 +1 @@
+Fix open redirect in legacy SSO flow with the `idp` query parameter.

changelog.d/18931.doc (new file)
@@ -0,0 +1,2 @@
+Clarify necessary `jwt_config` parameter in OIDC documentation for authentik.
+Contributed by @maxkratz.

changelog.d/18944.misc (new file)
@@ -0,0 +1 @@
+Introduce `Clock.call_when_running(...)` to wrap startup code in a logcontext, ensuring we can identify which server generated the logs.

changelog.d/18945.misc (new file)
@@ -0,0 +1 @@
+Introduce `Clock.add_system_event_trigger(...)` to wrap system event callback code in a logcontext, ensuring we can identify which server generated the logs.
@@ -186,6 +186,7 @@ oidc_providers:
 4. Note the slug of your application, Client ID and Client Secret.
 
 Note: RSA keys must be used for signing for Authentik, ECC keys do not work.
+Note: The provider must have a signing key set and must not use an encryption key.
 
 Synapse config:
 ```yaml
@@ -204,6 +205,12 @@ oidc_providers:
       config:
         localpart_template: "{{ user.preferred_username }}"
        display_name_template: "{{ user.preferred_username|capitalize }}" # TO BE FILLED: If your users have names in Authentik and you want those in Synapse, this should be replaced with user.name|capitalize.
+[...]
+jwt_config:
+  enabled: true
+  secret: "your client secret" # TO BE FILLED (same as `client_secret` above)
+  algorithm: "RS256"
+  # (...other fields)
 ```
 
 ### Dex
@@ -68,6 +68,18 @@ PROMETHEUS_METRIC_MISSING_FROM_LIST_TO_CHECK = ErrorCode(
     category="per-homeserver-tenant-metrics",
 )
 
+PREFER_SYNAPSE_CLOCK_CALL_WHEN_RUNNING = ErrorCode(
+    "prefer-synapse-clock-call-when-running",
+    "`synapse.util.Clock.call_when_running` should be used instead of `reactor.callWhenRunning`",
+    category="synapse-reactor-clock",
+)
+
+PREFER_SYNAPSE_CLOCK_ADD_SYSTEM_EVENT_TRIGGER = ErrorCode(
+    "prefer-synapse-clock-add-system-event-trigger",
+    "`synapse.util.Clock.add_system_event_trigger` should be used instead of `reactor.addSystemEventTrigger`",
+    category="synapse-reactor-clock",
+)
+
 
 class Sentinel(enum.Enum):
     # defining a sentinel in this way allows mypy to correctly handle the
@@ -229,9 +241,77 @@ class SynapsePlugin(Plugin):
         ):
             return check_is_cacheable_wrapper
 
+        if fullname in (
+            "twisted.internet.interfaces.IReactorCore.callWhenRunning",
+            "synapse.types.ISynapseThreadlessReactor.callWhenRunning",
+            "synapse.types.ISynapseReactor.callWhenRunning",
+        ):
+            return check_call_when_running
+
+        if fullname in (
+            "twisted.internet.interfaces.IReactorCore.addSystemEventTrigger",
+            "synapse.types.ISynapseThreadlessReactor.addSystemEventTrigger",
+            "synapse.types.ISynapseReactor.addSystemEventTrigger",
+        ):
+            return check_add_system_event_trigger
+
         return None
 
 
+def check_call_when_running(ctx: MethodSigContext) -> CallableType:
+    """
+    Ensure that `reactor.callWhenRunning` isn't used at any callsite.
+
+    `synapse.util.Clock.call_when_running` should always be used instead of
+    `reactor.callWhenRunning`.
+
+    Since `reactor.callWhenRunning` is a reactor callback, the callback will start out
+    with the sentinel logcontext. `synapse.util.Clock` starts a default logcontext, as
+    we want to know which server the logs came from.
+
+    Args:
+        ctx: The `MethodSigContext` from mypy.
+    """
+    signature: CallableType = ctx.default_signature
+    ctx.api.fail(
+        (
+            "Expected all `reactor.callWhenRunning` calls to use `synapse.util.Clock.call_when_running` instead. "
+            "This is so all Synapse code runs with a logcontext, as we want to know which server the logs came from."
+        ),
+        ctx.context,
+        code=PREFER_SYNAPSE_CLOCK_CALL_WHEN_RUNNING,
+    )
+
+    return signature
+
+
+def check_add_system_event_trigger(ctx: MethodSigContext) -> CallableType:
+    """
+    Ensure that `reactor.addSystemEventTrigger` isn't used at any callsite.
+
+    `synapse.util.Clock.add_system_event_trigger` should always be used instead of
+    `reactor.addSystemEventTrigger`.
+
+    Since `reactor.addSystemEventTrigger` is a reactor callback, the callback will
+    start out with the sentinel logcontext. `synapse.util.Clock` starts a default
+    logcontext, as we want to know which server the logs came from.
+
+    Args:
+        ctx: The `MethodSigContext` from mypy.
+    """
+    signature: CallableType = ctx.default_signature
+    ctx.api.fail(
+        (
+            "Expected all `reactor.addSystemEventTrigger` calls to use `synapse.util.Clock.add_system_event_trigger` instead. "
+            "This is so all Synapse code runs with a logcontext, as we want to know which server the logs came from."
+        ),
+        ctx.context,
+        code=PREFER_SYNAPSE_CLOCK_ADD_SYSTEM_EVENT_TRIGGER,
+    )
+
+    return signature
+
+
 def analyze_prometheus_metric_classes(ctx: ClassDefContext) -> None:
     """
     Cross-check the list of Prometheus metric classes against the
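The two checks above don't rewrite anything; they fail the type check at the offending callsite and point at the `synapse.util.Clock` replacement. As a rough sketch of the before/after these lints push for (the `hs` homeserver instance and the callback names here are hypothetical placeholders, not code from this diff):

```python
# Flagged by prefer-synapse-clock-call-when-running: the callback would run
# under the sentinel logcontext, so its logs can't be tied to a server.
# reactor.callWhenRunning(start_background_work)

# Preferred: Clock wraps the callback in a default logcontext first.
hs.get_clock().call_when_running(start_background_work)

# Flagged by prefer-synapse-clock-add-system-event-trigger:
# reactor.addSystemEventTrigger("before", "shutdown", flush_pending_work)

# Preferred:
hs.get_clock().add_system_event_trigger("before", "shutdown", flush_pending_work)
```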
@@ -30,7 +30,7 @@ from signedjson.sign import sign_json
 
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.crypto.event_signing import add_hashes_and_signatures
-from synapse.util import json_encoder
+from synapse.util.json import json_encoder
 
 
 def main() -> None:
@@ -54,11 +54,11 @@ from twisted.internet import defer, reactor as reactor_
 from synapse.config.database import DatabaseConnectionConfig
 from synapse.config.homeserver import HomeServerConfig
 from synapse.logging.context import (
     LoggingContext,
     make_deferred_yieldable,
     run_in_background,
 )
-from synapse.notifier import ReplicationNotifier
+from synapse.server import HomeServer
 from synapse.storage import DataStore
 from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn
 from synapse.storage.databases.main import FilteringWorkerStore
 from synapse.storage.databases.main.account_data import AccountDataWorkerStore
@@ -98,8 +98,7 @@ from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStor
 from synapse.storage.engines import create_engine
 from synapse.storage.prepare_database import prepare_database
 from synapse.types import ISynapseReactor
-from synapse.util import SYNAPSE_VERSION, Clock
-from synapse.util.stringutils import random_string
+from synapse.util import SYNAPSE_VERSION
 
 # Cast safety: Twisted does some naughty magic which replaces the
 # twisted.internet.reactor module with a Reactor instance at runtime.
@@ -318,31 +317,16 @@ class Store(
 )
 
 
-class MockHomeserver:
+class MockHomeserver(HomeServer):
+    DATASTORE_CLASS = DataStore
+
     def __init__(self, config: HomeServerConfig):
-        self.clock = Clock(reactor)
-        self.config = config
-        self.hostname = config.server.server_name
-        self.version_string = SYNAPSE_VERSION
-        self.instance_id = random_string(5)
-
-    def get_clock(self) -> Clock:
-        return self.clock
-
-    def get_reactor(self) -> ISynapseReactor:
-        return reactor
-
-    def get_instance_id(self) -> str:
-        return self.instance_id
-
-    def get_instance_name(self) -> str:
-        return "master"
-
-    def should_send_federation(self) -> bool:
-        return False
-
-    def get_replication_notifier(self) -> ReplicationNotifier:
-        return ReplicationNotifier()
+        super().__init__(
+            hostname=config.server.server_name,
+            config=config,
+            reactor=reactor,
+            version_string=f"Synapse/{SYNAPSE_VERSION}",
+        )
 
 
 class Porter:
@@ -351,12 +335,12 @@ class Porter:
         sqlite_config: Dict[str, Any],
         progress: "Progress",
         batch_size: int,
-        hs_config: HomeServerConfig,
+        hs: HomeServer,
     ):
         self.sqlite_config = sqlite_config
         self.progress = progress
         self.batch_size = batch_size
-        self.hs_config = hs_config
+        self.hs = hs
 
     async def setup_table(self, table: str) -> Tuple[str, int, int, int, int]:
         if table in APPEND_ONLY_TABLES:
@@ -676,8 +660,7 @@ class Porter:
 
         engine = create_engine(db_config.config)
 
-        hs = MockHomeserver(self.hs_config)
-        server_name = hs.hostname
+        server_name = self.hs.hostname
 
         with make_conn(
             db_config=db_config,
@@ -688,16 +671,16 @@ class Porter:
             engine.check_database(
                 db_conn, allow_outdated_version=allow_outdated_version
             )
-            prepare_database(db_conn, engine, config=self.hs_config)
+            prepare_database(db_conn, engine, config=self.hs.config)
             # Type safety: ignore that we're using Mock homeservers here.
             store = Store(
                 DatabasePool(
-                    hs,  # type: ignore[arg-type]
+                    self.hs,
                     db_config,
                     engine,
                 ),
                 db_conn,
-                hs,  # type: ignore[arg-type]
+                self.hs,
             )
             db_conn.commit()
 
@@ -795,7 +778,7 @@ class Porter:
             return
 
         self.postgres_store = self.build_db_store(
-            self.hs_config.database.get_single_database()
+            self.hs.config.database.get_single_database()
        )
 
         await self.remove_ignored_background_updates_from_database()
@@ -1584,6 +1567,8 @@ def main() -> None:
     config = HomeServerConfig()
     config.parse_config_dict(hs_config, "", "")
 
+    hs = MockHomeserver(config)
+
     def start(stdscr: Optional["curses.window"] = None) -> None:
         progress: Progress
         if stdscr:
@@ -1595,15 +1580,14 @@ def main() -> None:
             sqlite_config=sqlite_config,
             progress=progress,
             batch_size=args.batch_size,
-            hs_config=config,
+            hs=hs,
         )
 
         @defer.inlineCallbacks
         def run() -> Generator["defer.Deferred[Any]", Any, None]:
-            with LoggingContext("synapse_port_db_run"):
-                yield defer.ensureDeferred(porter.run())
+            yield defer.ensureDeferred(porter.run())
 
-        reactor.callWhenRunning(run)
+        hs.get_clock().call_when_running(run)
 
         reactor.run()
@@ -74,7 +74,7 @@ def run_background_updates(hs: HomeServer) -> None:
         )
     )
 
-    reactor.callWhenRunning(run)
+    hs.get_clock().call_when_running(run)
 
    reactor.run()
@@ -43,9 +43,9 @@ from synapse.logging.opentracing import (
 from synapse.metrics import SERVER_NAME_LABEL
 from synapse.synapse_rust.http_client import HttpClient
 from synapse.types import JsonDict, Requester, UserID, create_requester
-from synapse.util import json_decoder
 from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
 from synapse.util.caches.response_cache import ResponseCache, ResponseCacheContext
+from synapse.util.json import json_decoder
 
 from . import introspection_response_timer
 
@@ -48,9 +48,9 @@ from synapse.logging.opentracing import (
 from synapse.metrics import SERVER_NAME_LABEL
 from synapse.synapse_rust.http_client import HttpClient
 from synapse.types import Requester, UserID, create_requester
-from synapse.util import json_decoder
 from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
 from synapse.util.caches.response_cache import ResponseCache, ResponseCacheContext
+from synapse.util.json import json_decoder
 
 from . import introspection_response_timer
@@ -30,7 +30,7 @@ from typing import Any, Dict, List, Optional, Union
 
 from twisted.web import http
 
-from synapse.util import json_decoder
+from synapse.util.json import json_decoder
 
 if typing.TYPE_CHECKING:
     from synapse.config.homeserver import HomeServerConfig
@@ -26,7 +26,7 @@ from synapse.api.errors import LimitExceededError
 from synapse.config.ratelimiting import RatelimitSettings
 from synapse.storage.databases.main import DataStore
 from synapse.types import Requester
-from synapse.util import Clock
+from synapse.util.clock import Clock
 
 if TYPE_CHECKING:
     # To avoid circular imports:
@@ -22,6 +22,7 @@
 """Contains the URL paths to prefix various aspects of the server with."""
 
 import hmac
+import urllib.parse
 from hashlib import sha256
 from typing import Optional
 from urllib.parse import urlencode, urljoin
@@ -96,11 +97,21 @@ class LoginSSORedirectURIBuilder:
         serialized_query_parameters = urlencode({"redirectUrl": client_redirect_url})
 
         if idp_id:
+            # Since this is a user-controlled string, make it safe to include in a URL path.
+            url_encoded_idp_id = urllib.parse.quote(
+                idp_id,
+                # Since this defaults to `safe="/"`, we have to override it. We're
+                # working with an individual URL path parameter, so there shouldn't be
+                # any slashes in it which could change the request path.
+                safe="",
+                encoding="utf8",
+            )
+
             resultant_url = urljoin(
                 # We have to add a trailing slash to the base URL to ensure that the
                 # last path segment is not stripped away when joining with another path.
                 f"{base_url}/",
-                f"{idp_id}?{serialized_query_parameters}",
+                f"{url_encoded_idp_id}?{serialized_query_parameters}",
             )
         else:
             resultant_url = f"{base_url}?{serialized_query_parameters}"
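The `safe=""` override is the crux of the fix. A small, self-contained illustration of the difference (the URLs here are made up; only the stdlib `quote` and `urljoin` behaviour matters):

```python
from urllib.parse import quote, urljoin

# A malicious `idp` value tries to climb out of the intended path
# using dot-segments.
idp_id = "../../evil.example.org"

base_url = "https://matrix.example.com/_matrix/client/v3/login/sso/redirect"

# Without quoting, urljoin() resolves the dot-segments and the value
# escapes its path segment entirely:
print(urljoin(f"{base_url}/", idp_id))
# -> https://matrix.example.com/_matrix/client/v3/login/evil.example.org

# With safe="", slashes are percent-encoded, so the value stays a single
# path segment and cannot change the request path:
print(urljoin(f"{base_url}/", quote(idp_id, safe="")))
# -> https://matrix.example.com/_matrix/client/v3/login/sso/redirect/..%2F..%2Fevil.example.org
```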
@@ -241,7 +241,7 @@ def redirect_stdio_to_logs() -> None:
 
 
 def register_start(
-    cb: Callable[P, Awaitable], *args: P.args, **kwargs: P.kwargs
+    hs: "HomeServer", cb: Callable[P, Awaitable], *args: P.args, **kwargs: P.kwargs
 ) -> None:
     """Register a callback with the reactor, to be called once it is running
 
@@ -278,7 +278,8 @@ def register_start(
             # on as normal.
             os._exit(1)
 
-    reactor.callWhenRunning(lambda: defer.ensureDeferred(wrapper()))
+    clock = hs.get_clock()
+    clock.call_when_running(lambda: defer.ensureDeferred(wrapper()))
 
 
 def listen_metrics(bind_addresses: StrCollection, port: int) -> None:
@@ -517,7 +518,9 @@ async def start(hs: "HomeServer") -> None:
     # numbers of DNS requests don't starve out other users of the threadpool.
     resolver_threadpool = ThreadPool(name="gai_resolver")
     resolver_threadpool.start()
-    reactor.addSystemEventTrigger("during", "shutdown", resolver_threadpool.stop)
+    hs.get_clock().add_system_event_trigger(
+        "during", "shutdown", resolver_threadpool.stop
+    )
     reactor.installNameResolver(
         GAIResolver(reactor, getThreadPool=lambda: resolver_threadpool)
     )
@@ -604,7 +607,7 @@ async def start(hs: "HomeServer") -> None:
         logger.info("Shutting down...")
 
     # Log when we start the shut down process.
-    hs.get_reactor().addSystemEventTrigger("before", "shutdown", log_shutdown)
+    hs.get_clock().add_system_event_trigger("before", "shutdown", log_shutdown)
 
     setup_sentry(hs)
     setup_sdnotify(hs)
@@ -719,7 +722,7 @@ def setup_sdnotify(hs: "HomeServer") -> None:
     # we're not using systemd.
     sdnotify(b"READY=1\nMAINPID=%i" % (os.getpid(),))
 
-    hs.get_reactor().addSystemEventTrigger(
+    hs.get_clock().add_system_event_trigger(
         "before", "shutdown", sdnotify, b"STOPPING=1"
    )
@@ -356,11 +356,9 @@ def start(config_options: List[str]) -> None:
         handle_startup_exception(e)
 
     async def start() -> None:
-        # Re-establish log context now that we're back from the reactor
-        with LoggingContext("start"):
-            await _base.start(hs)
+        await _base.start(hs)
 
-    register_start(start)
+    register_start(hs, start)
 
     # redirect stdio to the logs, if configured.
     if not hs.config.logging.no_redirect_stdio:
@@ -377,19 +377,17 @@ def setup(config_options: List[str]) -> SynapseHomeServer:
         handle_startup_exception(e)
 
     async def start() -> None:
-        # Re-establish log context now that we're back from the reactor
-        with LoggingContext("start"):
-            # Load the OIDC provider metadatas, if OIDC is enabled.
-            if hs.config.oidc.oidc_enabled:
-                oidc = hs.get_oidc_handler()
-                # Loading the provider metadata also ensures the provider config is valid.
-                await oidc.load_metadata()
+        # Load the OIDC provider metadatas, if OIDC is enabled.
+        if hs.config.oidc.oidc_enabled:
+            oidc = hs.get_oidc_handler()
+            # Loading the provider metadata also ensures the provider config is valid.
+            await oidc.load_metadata()
 
-            await _base.start(hs)
+        await _base.start(hs)
 
-            hs.get_datastores().main.db_pool.updates.start_doing_background_updates()
+        hs.get_datastores().main.db_pool.updates.start_doing_background_updates()
 
-    register_start(start)
+    register_start(hs, start)
 
     return hs
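The explicit `with LoggingContext("start")` wrappers can be dropped because `Clock.call_when_running` now establishes a logcontext before invoking the callback. A rough sketch of the behaviour being relied on, per the changelog entries above (this is an assumption about the shape of the implementation, not a copy of it):

```python
from synapse.logging.context import LoggingContext

def call_when_running(self, callback, *args, **kwargs):
    """Sketch: run `callback` once the reactor is up, inside a logcontext."""

    def wrapper():
        # A named logcontext (rather than the sentinel) means any logs emitted
        # during startup can be attributed to this homeserver.
        with LoggingContext("call_when_running"):
            callback(*args, **kwargs)

    # The mypy plugin lints *callsites*; the Clock itself is the one place
    # allowed to touch the raw reactor API.
    self._reactor.callWhenRunning(wrapper)
```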
@@ -84,7 +84,7 @@ from synapse.logging.context import run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.databases.main import DataStore
 from synapse.types import DeviceListUpdates, JsonMapping
-from synapse.util import Clock
+from synapse.util.clock import Clock
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -38,7 +38,7 @@ from synapse.storage.databases.main import DataStore
 from synapse.synapse_rust.events import EventInternalMetadata
 from synapse.types import EventID, JsonDict, StrCollection
 from synapse.types.state import StateFilter
-from synapse.util import Clock
+from synapse.util.clock import Clock
 from synapse.util.stringutils import random_string
 
 if TYPE_CHECKING:
@@ -178,7 +178,7 @@ from synapse.types import (
     StrCollection,
     get_domain_from_id,
 )
-from synapse.util import Clock
+from synapse.util.clock import Clock
 from synapse.util.metrics import Measure
 from synapse.util.retryutils import filter_destinations_by_retry_limiter
 
@@ -36,7 +36,7 @@ from synapse.logging.opentracing import (
 )
 from synapse.metrics import SERVER_NAME_LABEL
 from synapse.types import JsonDict
-from synapse.util import json_decoder
+from synapse.util.json import json_decoder
 from synapse.util.metrics import measure_func
 
 if TYPE_CHECKING:
@@ -62,7 +62,7 @@ class DeactivateAccountHandler:
         # Start the user parter loop so it can resume parting users from rooms where
         # it left off (if it has work left to do).
         if hs.config.worker.worker_app is None:
-            hs.get_reactor().callWhenRunning(self._start_user_parting)
+            hs.get_clock().call_when_running(self._start_user_parting)
         else:
             self._notify_account_deactivated_client = (
                 ReplicationNotifyAccountDeactivatedServlet.make_client(hs)
@@ -1002,7 +1002,7 @@ class DeviceWriterHandler(DeviceHandler):
         # rolling-restarting Synapse.
         if self._is_main_device_list_writer:
             # On start up check if there are any updates pending.
-            hs.get_reactor().callWhenRunning(self._handle_new_device_update_async)
+            hs.get_clock().call_when_running(self._handle_new_device_update_async)
             self.device_list_updater = DeviceListUpdater(hs, self)
             hs.get_federation_registry().register_edu_handler(
                 EduTypes.DEVICE_LIST_UPDATE,
@@ -34,7 +34,7 @@ from synapse.logging.opentracing import (
     set_tag,
 )
 from synapse.types import JsonDict, Requester, StreamKeyType, UserID, get_domain_from_id
-from synapse.util import json_encoder
+from synapse.util.json import json_encoder
 from synapse.util.stringutils import random_string
 
 if TYPE_CHECKING:
@@ -44,9 +44,9 @@ from synapse.types import (
     get_domain_from_id,
     get_verify_key_from_cross_signing_key,
 )
-from synapse.util import json_decoder
 from synapse.util.async_helpers import Linearizer, concurrently_execute
 from synapse.util.cancellation import cancellable
+from synapse.util.json import json_decoder
 from synapse.util.retryutils import (
     NotRetryingDestination,
     filter_destinations_by_retry_limiter,
@@ -39,8 +39,8 @@ from synapse.http import RequestTimedOutError
 from synapse.http.client import SimpleHttpClient
 from synapse.http.site import SynapseRequest
 from synapse.types import JsonDict, Requester
-from synapse.util import json_decoder
 from synapse.util.hash import sha256_and_url_safe_base64
+from synapse.util.json import json_decoder
 from synapse.util.stringutils import (
     assert_valid_client_secret,
     random_string,
@@ -81,9 +81,10 @@ from synapse.types import (
     create_requester,
 )
 from synapse.types.state import StateFilter
-from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError
+from synapse.util import log_failure, unwrapFirstError
 from synapse.util.async_helpers import Linearizer, gather_results
 from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.json import json_decoder, json_encoder
 from synapse.util.metrics import measure_func
 from synapse.visibility import get_effective_room_visibility_from_state
 
@@ -67,8 +67,9 @@ from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
 from synapse.module_api import ModuleApi
 from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
-from synapse.util import Clock, json_decoder
 from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
+from synapse.util.clock import Clock
+from synapse.util.json import json_decoder
 from synapse.util.macaroons import MacaroonGenerator, OidcSessionData
 from synapse.util.templates import _localpart_from_email_filter
@@ -541,7 +541,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
             self.send_stop_syncing, UPDATE_SYNCING_USERS_MS
         )
 
-        hs.get_reactor().addSystemEventTrigger(
+        hs.get_clock().add_system_event_trigger(
             "before",
             "shutdown",
             run_as_background_process,
@@ -842,7 +842,7 @@ class PresenceHandler(BasePresenceHandler):
         # have not yet been persisted
         self.unpersisted_users_changes: Set[str] = set()
 
-        hs.get_reactor().addSystemEventTrigger(
+        hs.get_clock().add_system_event_trigger(
             "before",
             "shutdown",
             run_as_background_process,
@@ -27,7 +27,7 @@ from twisted.web.client import PartialDownloadError
 
 from synapse.api.constants import LoginType
 from synapse.api.errors import Codes, LoginError, SynapseError
-from synapse.util import json_decoder
+from synapse.util.json import json_decoder
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -87,8 +87,8 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.logging.opentracing import set_tag, start_active_span, tags
 from synapse.metrics import SERVER_NAME_LABEL
 from synapse.types import ISynapseReactor, StrSequence
-from synapse.util import json_decoder
 from synapse.util.async_helpers import timeout_deferred
+from synapse.util.json import json_decoder
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -49,7 +49,7 @@ from synapse.http.federation.well_known_resolver import WellKnownResolver
 from synapse.http.proxyagent import ProxyAgent
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.types import ISynapseReactor
-from synapse.util import Clock
+from synapse.util.clock import Clock
 
 logger = logging.getLogger(__name__)
 
@@ -27,7 +27,6 @@ from typing import Callable, Dict, Optional, Tuple
 import attr
 
 from twisted.internet import defer
-from twisted.internet.interfaces import IReactorTime
 from twisted.web.client import RedirectAgent
 from twisted.web.http import stringToDatetime
 from twisted.web.http_headers import Headers
@@ -35,8 +34,10 @@ from twisted.web.iweb import IAgent, IResponse
 
 from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size
 from synapse.logging.context import make_deferred_yieldable
-from synapse.util import Clock, json_decoder
+from synapse.types import ISynapseThreadlessReactor
 from synapse.util.caches.ttlcache import TTLCache
+from synapse.util.clock import Clock
+from synapse.util.json import json_decoder
 from synapse.util.metrics import Measure
 
 # period to cache .well-known results for by default
@@ -88,7 +89,7 @@ class WellKnownResolver:
     def __init__(
         self,
         server_name: str,
-        reactor: IReactorTime,
+        reactor: ISynapseThreadlessReactor,
         agent: IAgent,
         user_agent: bytes,
         well_known_cache: Optional[TTLCache[bytes, Optional[bytes]]] = None,
@@ -89,8 +89,8 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.logging.opentracing import set_tag, start_active_span, tags
 from synapse.metrics import SERVER_NAME_LABEL
 from synapse.types import JsonDict
-from synapse.util import json_decoder
 from synapse.util.async_helpers import AwakenableSleeper, Linearizer, timeout_deferred
+from synapse.util.json import json_decoder
 from synapse.util.metrics import Measure
 from synapse.util.stringutils import parse_and_validate_server_name
 
@@ -52,10 +52,11 @@ from zope.interface import implementer
 
 from twisted.internet import defer, interfaces, reactor
 from twisted.internet.defer import CancelledError
-from twisted.internet.interfaces import IReactorTime
 from twisted.python import failure
 from twisted.web import resource
 
+from synapse.types import ISynapseThreadlessReactor
+
 try:
     from twisted.web.pages import notFound
 except ImportError:
@@ -77,10 +78,11 @@ from synapse.api.errors import (
 from synapse.config.homeserver import HomeServerConfig
 from synapse.logging.context import defer_to_thread, preserve_fn, run_in_background
 from synapse.logging.opentracing import active_span, start_active_span, trace_servlet
-from synapse.util import Clock, json_encoder
 from synapse.util.caches import intern_dict
 from synapse.util.cancellation import is_function_cancellable
+from synapse.util.clock import Clock
 from synapse.util.iterutils import chunk_seq
+from synapse.util.json import json_encoder
 
 if TYPE_CHECKING:
     import opentracing
@@ -410,7 +412,7 @@ class DirectServeJsonResource(_AsyncResource):
         clock: Optional[Clock] = None,
     ):
         if clock is None:
-            clock = Clock(cast(IReactorTime, reactor))
+            clock = Clock(cast(ISynapseThreadlessReactor, reactor))
 
         super().__init__(clock, extract_context)
         self.canonical_json = canonical_json
@@ -589,7 +591,7 @@ class DirectServeHtmlResource(_AsyncResource):
         clock: Optional[Clock] = None,
     ):
         if clock is None:
-            clock = Clock(cast(IReactorTime, reactor))
+            clock = Clock(cast(ISynapseThreadlessReactor, reactor))
 
         super().__init__(clock, extract_context)
@@ -51,7 +51,7 @@ from synapse.api.errors import Codes, SynapseError
 from synapse.http import redact_uri
 from synapse.http.server import HttpServer
 from synapse.types import JsonDict, RoomAlias, RoomID, StrCollection
-from synapse.util import json_decoder
+from synapse.util.json import json_decoder
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -60,8 +60,18 @@ class PeriodicallyFlushingMemoryHandler(MemoryHandler):
         else:
             reactor_to_use = reactor
 
-        # call our hook when the reactor starts up
-        reactor_to_use.callWhenRunning(on_reactor_running)
+        # Call our hook when the reactor starts up.
+        #
+        # type-ignore: Ideally, we'd use `Clock.call_when_running(...)`, but
+        # `PeriodicallyFlushingMemoryHandler` is instantiated via Python logging
+        # configuration, so it's not straightforward to pass in the homeserver's clock
+        # (and we don't want to burden other people's logging config with the details).
+        #
+        # The important reason why we want to use `Clock.call_when_running` is so that
+        # the callback runs with a logcontext, as we want to know which server the logs
+        # came from. But since we don't log anything in the callback, it's safe to
+        # ignore the lint here.
+        reactor_to_use.callWhenRunning(on_reactor_running)  # type: ignore[prefer-synapse-clock-call-when-running]
 
     def shouldFlush(self, record: LogRecord) -> bool:
         """
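For context, this handler is wired up through the standard `logging` dictConfig/YAML machinery rather than by Synapse itself, which is why there is no natural place to hand it a `Clock`. A log config along these lines (an illustrative snippet; Synapse's sample logging config is the authoritative version):

```yaml
handlers:
  buffer:
    class: synapse.logging.handlers.PeriodicallyFlushingMemoryHandler
    target: file
    capacity: 10
    flushLevel: 30  # flush immediately on WARNING and above
    period: 5       # otherwise flush buffered records every 5 seconds
```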
@@ -204,7 +204,7 @@ from twisted.web.http import Request
 from twisted.web.http_headers import Headers
 
 from synapse.config import ConfigError
-from synapse.util import json_decoder, json_encoder
+from synapse.util.json import json_decoder, json_encoder
 
 if TYPE_CHECKING:
     from synapse.http.site import SynapseRequest
@@ -54,8 +54,8 @@ from synapse.logging.context import (
     make_deferred_yieldable,
     run_in_background,
 )
-from synapse.util import Clock
 from synapse.util.async_helpers import DeferredEvent
+from synapse.util.clock import Clock
 from synapse.util.stringutils import is_ascii
 
 if TYPE_CHECKING:
@@ -55,7 +55,7 @@ from synapse.api.errors import NotFoundError
 from synapse.logging.context import defer_to_thread, run_in_background
 from synapse.logging.opentracing import start_active_span, trace, trace_with_opname
 from synapse.media._base import ThreadedFileSender
-from synapse.util import Clock
+from synapse.util.clock import Clock
 from synapse.util.file_consumer import BackgroundFileConsumer
 
 from ..types import JsonDict
@@ -27,7 +27,7 @@ import attr
 
 from synapse.media.preview_html import parse_html_description
 from synapse.types import JsonDict
-from synapse.util import json_decoder
+from synapse.util.json import json_decoder
 
 if TYPE_CHECKING:
     from lxml import etree
@@ -46,9 +46,9 @@ from synapse.media.oembed import OEmbedProvider
 from synapse.media.preview_html import decode_body, parse_html_to_open_graph
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.types import JsonDict, UserID
-from synapse.util import json_encoder
 from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.json import json_encoder
 from synapse.util.stringutils import random_string
 
 if TYPE_CHECKING:
@@ -158,9 +158,9 @@ from synapse.types import (
     create_requester,
 )
 from synapse.types.state import StateFilter
-from synapse.util import Clock
 from synapse.util.async_helpers import maybe_awaitable
 from synapse.util.caches.descriptors import CachedFunction, cached as _cached
+from synapse.util.clock import Clock
 from synapse.util.frozenutils import freeze
 
 if TYPE_CHECKING:
@@ -29,7 +29,7 @@ import logging
 from typing import List, Optional, Tuple, Type, TypeVar
 
 from synapse.replication.tcp.streams._base import StreamRow
-from synapse.util import json_decoder, json_encoder
+from synapse.util.json import json_decoder, json_encoder
 
 logger = logging.getLogger(__name__)
 
@@ -27,7 +27,7 @@ from prometheus_client import Counter, Histogram
 from synapse.logging import opentracing
 from synapse.logging.context import make_deferred_yieldable
 from synapse.metrics import SERVER_NAME_LABEL
-from synapse.util import json_decoder, json_encoder
+from synapse.util.json import json_decoder, json_encoder
 
 if TYPE_CHECKING:
     from txredisapi import ConnectionHandler
@@ -55,7 +55,7 @@ from synapse.replication.tcp.commands import (
     ServerCommand,
     parse_command_from_line,
 )
-from synapse.util import Clock
+from synapse.util.clock import Clock
 from synapse.util.stringutils import random_string
 
 if TYPE_CHECKING:
@@ -58,8 +58,8 @@ from synapse.logging.opentracing import log_kv, set_tag, trace_with_opname
 from synapse.rest.admin.experimental_features import ExperimentalFeature
 from synapse.types import JsonDict, Requester, SlidingSyncStreamToken, StreamToken
 from synapse.types.rest.client import SlidingSyncBody
-from synapse.util import json_decoder
 from synapse.util.caches.lrucache import LruCache
+from synapse.util.json import json_decoder
 
 from ._base import client_patterns, set_timeline_upper_limit
 
@@ -38,8 +38,8 @@ from synapse.http.servlet import (
 from synapse.storage.keys import FetchKeyResultForRemote
 from synapse.types import JsonDict
 from synapse.types.rest import RequestBodyModel
-from synapse.util import json_decoder
 from synapse.util.async_helpers import yieldable_gather_results
+from synapse.util.json import json_decoder
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -63,6 +63,22 @@ class PickIdpResource(DirectServeHtmlResource):
         if not idp:
             return await self._serve_id_picker(request, client_redirect_url)
 
+        # Validate the `idp` query parameter. We should only be working with known IdPs.
+        # No need to waste further effort if we don't know about it.
+        #
+        # Although we primarily prevent open redirect attacks by URL encoding all of
+        # the parameters we use in the redirect URL below, this validation also helps
+        # prevent Synapse from crafting arbitrary URLs and being used in open redirect
+        # attacks (defense in depth).
+        providers = self._sso_handler.get_identity_providers()
+        auth_provider = providers.get(idp)
+        if not auth_provider:
+            logger.info("Unknown idp %r", idp)
+            self._sso_handler.render_error(
+                request, "unknown_idp", "Unknown identity provider ID"
+            )
+            return
+
         # Otherwise, redirect to the login SSO redirect endpoint for the given IdP
         # (which will in turn take us to the IdP's redirect URI).
         #
@@ -28,7 +28,7 @@ from synapse.api.errors import NotFoundError
 from synapse.http.server import DirectServeJsonResource
 from synapse.http.site import SynapseRequest
 from synapse.types import JsonDict
-from synapse.util import json_encoder
+from synapse.util.json import json_encoder
 from synapse.util.stringutils import parse_server_name
 
 if TYPE_CHECKING:
@@ -156,7 +156,7 @@ from synapse.storage.controllers import StorageControllers
 from synapse.streams.events import EventSources
 from synapse.synapse_rust.rendezvous import RendezvousHandler
 from synapse.types import DomainSpecificString, ISynapseReactor
-from synapse.util import Clock
+from synapse.util.clock import Clock
 from synapse.util.distributor import Distributor
 from synapse.util.macaroons import MacaroonGenerator
 from synapse.util.ratelimitutils import FederationRateLimiter
@@ -1007,7 +1007,7 @@ class HomeServer(metaclass=abc.ABCMeta):
         )
 
         media_threadpool.start()
-        self.get_reactor().addSystemEventTrigger(
+        self.get_clock().add_system_event_trigger(
             "during", "shutdown", media_threadpool.stop
        )
@@ -29,8 +29,8 @@ from synapse.storage.database import (
     make_in_list_sql_clause,  # noqa: F401
 )
 from synapse.types import get_domain_from_id
-from synapse.util import json_decoder
 from synapse.util.caches.descriptors import CachedFunction
+from synapse.util.json import json_decoder
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -45,7 +45,8 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.types import Connection, Cursor
 from synapse.types import JsonDict, StrCollection
-from synapse.util import Clock, json_encoder
+from synapse.util.clock import Clock
+from synapse.util.json import json_encoder
 
 from . import engines
@@ -2653,8 +2653,7 @@ def make_in_list_sql_clause(
 
 
 # These overloads ensure that `columns` and `iterable` values have the same length.
-# Suppress "Single overload definition, multiple required" complaint.
-@overload  # type: ignore[misc]
+@overload
 def make_tuple_in_list_sql_clause(
     database_engine: BaseDatabaseEngine,
     columns: Tuple[str, str],
@@ -2662,6 +2661,14 @@ def make_tuple_in_list_sql_clause(
 ) -> Tuple[str, list]: ...
 
 
+@overload
+def make_tuple_in_list_sql_clause(
+    database_engine: BaseDatabaseEngine,
+    columns: Tuple[str, str, str],
+    iterable: Collection[Tuple[Any, Any, Any]],
+) -> Tuple[str, list]: ...
+
+
 def make_tuple_in_list_sql_clause(
     database_engine: BaseDatabaseEngine,
     columns: Tuple[str, ...],
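The overloads above only pin the typing down; the helper's job is to expand a list of value tuples into a row-value `IN` clause plus a flat parameter list. A simplified stand-in showing the shape of the output (this is not Synapse's implementation, which also handles engine-specific differences):

```python
from typing import Any, List, Sequence, Tuple

def tuple_in_clause(
    columns: Tuple[str, ...], rows: Sequence[Tuple[Any, ...]]
) -> Tuple[str, List[Any]]:
    """Build `(col1, col2, ...) IN ((?, ?, ...), ...)` and its flat parameter list."""
    placeholders = ", ".join(
        "(" + ", ".join("?" for _ in columns) + ")" for _ in rows
    )
    params = [value for row in rows for value in row]
    return f"({', '.join(columns)}) IN ({placeholders})", params

sql, params = tuple_in_clause(
    ("target_user_id", "target_device_id", "user_id"),
    [("@alice:example.org", "DEVICE1", "@alice:example.org")],
)
# sql    -> (target_user_id, target_device_id, user_id) IN ((?, ?, ?))
# params -> ['@alice:example.org', 'DEVICE1', '@alice:example.org']
```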
@@ -48,9 +48,9 @@ from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
 from synapse.storage.invite_rule import InviteRulesConfig
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import JsonDict, JsonMapping
-from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.util.json import json_encoder
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -42,8 +42,8 @@ from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.types import Cursor
 from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import DeviceListUpdates, JsonMapping
-from synapse.util import json_encoder
 from synapse.util.caches.descriptors import _CacheContext, cached
+from synapse.util.json import json_encoder
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -21,6 +21,7 @@
 
 
 import itertools
+import json
 import logging
 from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Tuple
 
@@ -62,6 +63,12 @@ PURGE_HISTORY_CACHE_NAME = "ph_cache_fake"
 # As above, but for invalidating room caches on room deletion
 DELETE_ROOM_CACHE_NAME = "dr_cache_fake"
 
+# This cache takes a list of tuples as its first argument, which requires
+# special handling.
+GET_E2E_CROSS_SIGNING_SIGNATURES_FOR_DEVICE_CACHE_NAME = (
+    "_get_e2e_cross_signing_signatures_for_device"
+)
+
 # How long between cache invalidation table cleanups, once we have caught up
 # with the backlog.
 REGULAR_CLEANUP_INTERVAL_MS = Config.parse_duration("1h")
@@ -270,6 +277,33 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
                     # room membership.
                     #
                     # self._membership_stream_cache.all_entities_changed(token)  # type: ignore[attr-defined]
+                elif (
+                    row.cache_func
+                    == GET_E2E_CROSS_SIGNING_SIGNATURES_FOR_DEVICE_CACHE_NAME
+                ):
+                    # "keys" is a list of strings, where each string is a
+                    # JSON-encoded representation of the tuple keys, i.e.
+                    # keys: ['["@userid:domain", "DEVICEID"]', '["@userid2:domain", "DEVICEID2"]']
+                    #
+                    # This is a side-effect of not being able to send nested
+                    # information over replication.
+                    for json_str in row.keys:
+                        try:
+                            user_id, device_id = json.loads(json_str)
+                        except (json.JSONDecodeError, TypeError):
+                            logger.error(
+                                "Failed to deserialise cache key as valid JSON: %s",
+                                json_str,
+                            )
+                            continue
+
+                        # Invalidate each key.
+                        #
+                        # Note: .invalidate takes a tuple of arguments, hence the need
+                        # to nest our tuple in another tuple.
+                        self._get_e2e_cross_signing_signatures_for_device.invalidate(  # type: ignore[attr-defined]
+                            ((user_id, device_id),)
+                        )
                 else:
                     self._attempt_to_invalidate_cache(row.cache_func, row.keys)
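The encode and decode halves of this scheme live in different files (the encoder in the e2e keys store further down, the decoder here), so it is worth seeing the round trip in one place. A minimal, self-contained illustration:

```python
import json

# Sender side: a tuple cache key is flattened to a single JSON string, since
# replication can only carry a primitive per function argument.
cache_key = ("@alice:example.org", "DEVICE1")
wire_value = json.dumps(list(cache_key))  # '["@alice:example.org", "DEVICE1"]'

# Receiver side: unpack and rebuild the tuple before invalidating. A malformed
# value raises, which the real code logs and skips with `continue`.
try:
    user_id, device_id = json.loads(wire_value)
except (json.JSONDecodeError, TypeError):
    user_id = device_id = None
else:
    restored_key = (user_id, device_id)
    assert restored_key == cache_key
```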
@@ -32,7 +32,7 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
-from synapse.util import json_encoder
+from synapse.util.json import json_encoder
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -455,7 +455,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorke
         self._client_ip_looper = self._clock.looping_call(
             self._update_client_ips_batch, 5 * 1000
         )
-        self.hs.get_reactor().addSystemEventTrigger(
+        self.hs.get_clock().add_system_event_trigger(
             "before", "shutdown", self._update_client_ips_batch
        )
@@ -22,7 +22,8 @@ from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import LoggingTransaction, StoreError
 from synapse.storage.engines import PostgresEngine
 from synapse.types import JsonDict, RoomID
-from synapse.util import json_encoder, stringutils as stringutils
+from synapse.util import stringutils
+from synapse.util.json import json_encoder
 
 logger = logging.getLogger(__name__)
 
@@ -53,10 +53,11 @@ from synapse.storage.database import (
 )
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import JsonDict, StrCollection
-from synapse.util import Duration, json_encoder
+from synapse.util import Duration
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.iterutils import batch_iter
+from synapse.util.json import json_encoder
 from synapse.util.stringutils import parse_and_validate_server_name
 
 if TYPE_CHECKING:
@@ -64,11 +64,11 @@ from synapse.types import (
     StrCollection,
     get_verify_key_from_cross_signing_key,
 )
-from synapse.util import json_decoder, json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.cancellation import cancellable
 from synapse.util.iterutils import batch_iter
+from synapse.util.json import json_decoder, json_encoder
 from synapse.util.stringutils import shortstr
 
 if TYPE_CHECKING:
@@ -41,7 +41,7 @@ from synapse.storage.database import (
     LoggingTransaction,
 )
 from synapse.types import JsonDict, JsonSerializable, StreamKeyType
-from synapse.util import json_encoder
+from synapse.util.json import json_encoder
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -20,6 +20,7 @@
 #
 #
 import abc
+import json
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -60,10 +61,10 @@ from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import JsonDict, JsonMapping, MultiWriterStreamToken
-from synapse.util import json_decoder, json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.cancellation import cancellable
 from synapse.util.iterutils import batch_iter
+from synapse.util.json import json_decoder, json_encoder
 
 if TYPE_CHECKING:
     from synapse.handlers.e2e_keys import SignatureListItem
@@ -354,15 +355,17 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             )
 
         for batch in batch_iter(signature_query, 50):
-            cross_sigs_result = await self.db_pool.runInteraction(
-                "get_e2e_cross_signing_signatures_for_devices",
-                self._get_e2e_cross_signing_signatures_for_devices_txn,
-                batch,
+            cross_sigs_result = (
+                await self._get_e2e_cross_signing_signatures_for_devices(batch)
             )
 
             # add each cross-signing signature to the correct device in the result dict.
-            for user_id, key_id, device_id, signature in cross_sigs_result:
+            for (
+                user_id,
+                device_id,
+            ), signature_list in cross_sigs_result.items():
                 target_device_result = result[user_id][device_id]
 
                 # We've only looked up cross-signatures for non-deleted devices with key
                 # data.
                 assert target_device_result is not None
@@ -373,7 +376,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
                 signing_user_signatures = target_device_signatures.setdefault(
                     user_id, {}
                 )
-                signing_user_signatures[key_id] = signature
+
+                for key_id, signature in signature_list:
+                    signing_user_signatures[key_id] = signature
 
         log_kv(result)
         return result
@@ -479,41 +484,83 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 
         return result
 
-    def _get_e2e_cross_signing_signatures_for_devices_txn(
-        self, txn: LoggingTransaction, device_query: Iterable[Tuple[str, str]]
-    ) -> List[Tuple[str, str, str, str]]:
-        """Get cross-signing signatures for a given list of devices
-
-        Returns signatures made by the owners of the devices.
-
-        Returns: a list of results; each entry in the list is a tuple of
-        (user_id, key_id, target_device_id, signature).
-        """
-        signature_query_clauses = []
-        signature_query_params = []
-
-        for user_id, device_id in device_query:
-            signature_query_clauses.append(
-                "target_user_id = ? AND target_device_id = ? AND user_id = ?"
-            )
-            signature_query_params.extend([user_id, device_id, user_id])
-
-        signature_sql = """
-            SELECT user_id, key_id, target_device_id, signature
-            FROM e2e_cross_signing_signatures WHERE %s
-        """ % (" OR ".join("(" + q + ")" for q in signature_query_clauses))
-
-        txn.execute(signature_sql, signature_query_params)
-        return cast(
-            List[
-                Tuple[
-                    str,
-                    str,
-                    str,
-                    str,
-                ]
-            ],
-            txn.fetchall(),
-        )
+    @cached()
+    def _get_e2e_cross_signing_signatures_for_device(
+        self,
+        user_id_and_device_id: Tuple[str, str],
+    ) -> Sequence[Tuple[str, str]]:
+        """
+        The single-item version of `_get_e2e_cross_signing_signatures_for_devices`.
+        See `@cachedList` for why a separate method is needed.
+        """
+        raise NotImplementedError()
+
+    @cachedList(
+        cached_method_name="_get_e2e_cross_signing_signatures_for_device",
+        list_name="device_query",
+    )
+    async def _get_e2e_cross_signing_signatures_for_devices(
+        self, device_query: Iterable[Tuple[str, str]]
+    ) -> Mapping[Tuple[str, str], Sequence[Tuple[str, str]]]:
+        """Get cross-signing signatures for a given list of user IDs and devices.
+
+        Args:
+            device_query: An iterable containing tuples of (user ID, device ID).
+
+        Returns:
+            A mapping of results. The keys are the original (user_id, device_id)
+            tuples, while each value is the matching list of (key_id, signature)
+            tuples. The value will be an empty list if no signatures exist for
+            the device.
+
+            Given this method is annotated with `@cachedList`, the return dict's
+            keys match the tuples within `device_query`, so that cache entries
+            can be computed from the corresponding values.
+
+            As results are cached, the return type is immutable.
+        """
+
+        def _get_e2e_cross_signing_signatures_for_devices_txn(
+            txn: LoggingTransaction, device_query: Iterable[Tuple[str, str]]
+        ) -> Mapping[Tuple[str, str], Sequence[Tuple[str, str]]]:
+            where_clause_sql, where_clause_params = make_tuple_in_list_sql_clause(
+                self.database_engine,
+                columns=("target_user_id", "target_device_id", "user_id"),
+                iterable=[
+                    (user_id, device_id, user_id) for user_id, device_id in device_query
+                ],
+            )
+
+            signature_sql = f"""
+                SELECT user_id, key_id, target_device_id, signature
+                FROM e2e_cross_signing_signatures WHERE {where_clause_sql}
+            """
+
+            txn.execute(signature_sql, where_clause_params)
+
+            devices_and_signatures: Dict[Tuple[str, str], List[Tuple[str, str]]] = {}
+
+            # `@cachedList` requires we return one key for every item in `device_query`.
+            # Pre-populate `devices_and_signatures` with each key so that none are missing.
+            #
+            # If any were missing, they would be cached as `None`, which is not
+            # what callers expect.
+            for user_id, device_id in device_query:
+                devices_and_signatures.setdefault((user_id, device_id), [])
+
+            # Populate the return dictionary with each found key_id and signature.
+            for user_id, key_id, target_device_id, signature in txn.fetchall():
+                signature_tuple = (key_id, signature)
+                devices_and_signatures[(user_id, target_device_id)].append(
+                    signature_tuple
+                )
+
+            return devices_and_signatures
+
+        return await self.db_pool.runInteraction(
+            "_get_e2e_cross_signing_signatures_for_devices_txn",
+            _get_e2e_cross_signing_signatures_for_devices_txn,
+            device_query,
+        )
 
     async def get_e2e_one_time_keys(
@@ -1772,26 +1819,71 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             user_id: the user who made the signatures
             signatures: signatures to add
         """
-        await self.db_pool.simple_insert_many(
-            "e2e_cross_signing_signatures",
-            keys=(
-                "user_id",
-                "key_id",
-                "target_user_id",
-                "target_device_id",
-                "signature",
-            ),
-            values=[
-                (
-                    user_id,
-                    item.signing_key_id,
-                    item.target_user_id,
-                    item.target_device_id,
-                    item.signature,
-                )
-                for item in signatures
-            ],
-            desc="add_e2e_signing_key",
-        )
+
+        def _store_e2e_cross_signing_signatures(
+            txn: LoggingTransaction,
+            signatures: "Iterable[SignatureListItem]",
+        ) -> None:
+            self.db_pool.simple_insert_many_txn(
+                txn,
+                "e2e_cross_signing_signatures",
+                keys=(
+                    "user_id",
+                    "key_id",
+                    "target_user_id",
+                    "target_device_id",
+                    "signature",
+                ),
+                values=[
+                    (
+                        user_id,
+                        item.signing_key_id,
+                        item.target_user_id,
+                        item.target_device_id,
+                        item.signature,
+                    )
+                    for item in signatures
+                ],
+            )
+
+            to_invalidate = [
+                # Each entry is a tuple of arguments to
+                # `_get_e2e_cross_signing_signatures_for_device`, which
+                # itself takes a tuple. Hence the double-tuple.
+                ((user_id, item.target_device_id),)
+                for item in signatures
+            ]
+
+            if to_invalidate:
+                # Invalidate the local cache of this worker.
+                for cache_key in to_invalidate:
+                    txn.call_after(
+                        self._get_e2e_cross_signing_signatures_for_device.invalidate,
+                        cache_key,
+                    )
+
+                # Stream cache invalidation keys over replication.
+                #
+                # We can only send a primitive per function argument across
+                # replication.
+                #
+                # Encode the array of strings as a JSON string, and we'll unpack
+                # it on the other side.
+                to_send = [
+                    (json.dumps([user_id, item.target_device_id]),)
+                    for item in signatures
+                ]
+
+                self._send_invalidation_to_replication_bulk(
+                    txn,
+                    cache_name=self._get_e2e_cross_signing_signatures_for_device.__name__,
+                    key_tuples=to_send,
+                )
+
+        await self.db_pool.runInteraction(
+            "add_e2e_signing_key",
+            _store_e2e_cross_signing_signatures,
+            signatures,
+        )
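The `@cached`/`@cachedList` pairing above is a standard Synapse pattern: the single-key method exists purely to own the cache, and the batched method fills it, one entry per key. A condensed sketch of the pattern (the store class and method names here are hypothetical; the decorators are the real ones from `synapse.util.caches.descriptors`):

```python
from typing import Iterable, Mapping, Sequence, Tuple

from synapse.util.caches.descriptors import cached, cachedList

class MyStore:
    @cached()
    def _get_thing(self, key: Tuple[str, str]) -> Sequence[str]:
        # Never called directly: it only exists to own the cache that the
        # batched method below populates.
        raise NotImplementedError()

    @cachedList(cached_method_name="_get_thing", list_name="keys")
    async def _get_things(
        self, keys: Iterable[Tuple[str, str]]
    ) -> Mapping[Tuple[str, str], Sequence[str]]:
        # Must return an entry for *every* requested key; missing keys would
        # be cached as None, which callers don't expect.
        return {key: [] for key in keys}
```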
@@ -59,11 +59,11 @@ from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||
from synapse.storage.databases.main.signatures import SignatureWorkerStore
|
||||
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
||||
from synapse.types import JsonDict, StrCollection
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.descriptors import cached
|
||||
from synapse.util.caches.lrucache import LruCache
|
||||
from synapse.util.cancellation import cancellable
|
||||
from synapse.util.iterutils import batch_iter
|
||||
from synapse.util.json import json_encoder
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
@@ -107,8 +107,8 @@ from synapse.storage.database import (
|
||||
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
|
||||
from synapse.storage.databases.main.stream import StreamWorkerStore
|
||||
from synapse.types import JsonDict, StrCollection
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.descriptors import cached
|
||||
from synapse.util.json import json_encoder
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
@@ -83,9 +83,9 @@ from synapse.types import (
|
||||
)
|
||||
from synapse.types.handlers import SLIDING_SYNC_DEFAULT_BUMP_EVENT_TYPES
|
||||
from synapse.types.state import StateFilter
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.events import get_plain_text_topic_from_event_content
|
||||
from synapse.util.iterutils import batch_iter, sorted_topologically
|
||||
from synapse.util.json import json_encoder
|
||||
from synapse.util.stringutils import non_null_str_or_none
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@@ -58,8 +58,8 @@ from synapse.types import JsonDict, RoomStreamToken, StateMap, StrCollection
|
||||
from synapse.types.handlers import SLIDING_SYNC_DEFAULT_BUMP_EVENT_TYPES
|
||||
from synapse.types.state import StateFilter
|
||||
from synapse.types.storage import _BackgroundUpdates
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.iterutils import batch_iter
|
||||
from synapse.util.json import json_encoder
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
@@ -38,7 +38,7 @@ from synapse.storage.database import (
|
||||
LoggingTransaction,
|
||||
)
|
||||
from synapse.types import ISynapseReactor
|
||||
from synapse.util import Clock
|
||||
from synapse.util.clock import Clock
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -99,7 +99,7 @@ class LockStore(SQLBaseStore):
|
||||
# lead to a race, as we may drop the lock while we are still processing.
|
||||
# However, a) it should be a small window, b) the lock is best effort
|
||||
# anyway and c) we want to really avoid leaking locks when we restart.
|
||||
hs.get_reactor().addSystemEventTrigger(
|
||||
hs.get_clock().add_system_event_trigger(
|
||||
"before",
|
||||
"shutdown",
|
||||
self._on_shutdown,
|
||||
|
||||
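The `LockStore` change is representative of how call sites migrate: the wrapper keeps the reactor method's `(phase, eventType, callable)` shape, so only the receiver changes. A minimal sketch of the new call site (the shutdown callback is illustrative; the cast mirrors the `Linearizer`/`LruCache` hunks later in this diff):

```python
from typing import cast

from twisted.internet import reactor

from synapse.types import ISynapseThreadlessReactor
from synapse.util.clock import Clock

clock = Clock(cast(ISynapseThreadlessReactor, reactor))

def _on_shutdown() -> None:
    # Unlike a raw `reactor.addSystemEventTrigger` callback, this runs inside
    # a "system_event" logcontext, so its logs identify the homeserver.
    print("dropping held locks before shutdown")

clock.add_system_event_trigger("before", "shutdown", _on_shutdown)
```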
@@ -56,10 +56,11 @@ from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundExc
from synapse.storage.util.id_generators import IdGenerator, MultiWriterIdGenerator
from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules
from synapse.types import JsonDict
from synapse.util import json_encoder, unwrapFirstError
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import gather_results
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.json import json_encoder

if TYPE_CHECKING:
    from synapse.server import HomeServer

@@ -42,8 +42,8 @@ from synapse.storage.database import (
)
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.json import json_encoder

if TYPE_CHECKING:
    from synapse.server import HomeServer

@@ -55,10 +55,10 @@ from synapse.types import (
    PersistedPosition,
    StrCollection,
)
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.iterutils import batch_iter
from synapse.util.json import json_encoder

if TYPE_CHECKING:
    from synapse.server import HomeServer

@@ -65,8 +65,8 @@ from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import IdGenerator, MultiWriterIdGenerator
from synapse.types import JsonDict, RetentionPolicy, StrCollection, ThirdPartyInstanceID
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.json import json_encoder
from synapse.util.stringutils import MXC_REGEX

if TYPE_CHECKING:

@@ -30,7 +30,7 @@ from synapse.storage.database import (
    LoggingTransaction,
)
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.json import json_encoder

if TYPE_CHECKING:
    from synapse.server import HomeServer

@@ -35,8 +35,8 @@ from synapse.types.handlers.sliding_sync import (
    RoomStatusMap,
    RoomSyncConfig,
)
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.json import json_encoder

if TYPE_CHECKING:
    from synapse.server import HomeServer

@@ -30,8 +30,8 @@ from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
from synapse.storage.util.id_generators import AbstractStreamIdGenerator
from synapse.types import JsonDict, JsonMapping
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.json import json_encoder

logger = logging.getLogger(__name__)

@@ -29,7 +29,7 @@ from synapse.storage.database import (
    make_in_list_sql_clause,
)
from synapse.types import JsonDict, JsonMapping, ScheduledTask, TaskStatus
from synapse.util import json_encoder
from synapse.util.json import json_encoder

if TYPE_CHECKING:
    from synapse.server import HomeServer

@@ -27,7 +27,8 @@ from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import LoggingTransaction
from synapse.types import JsonDict
from synapse.util import json_encoder, stringutils
from synapse.util import stringutils
from synapse.util.json import json_encoder


@attr.s(slots=True, auto_attribs=True)

@@ -116,13 +116,27 @@ StrSequence = Union[Tuple[str, ...], List[str]]

# Note that this seems to require inheriting *directly* from Interface in order
# for mypy-zope to realize it is an interface.
class ISynapseReactor(
class ISynapseThreadlessReactor(
    IReactorTCP,
    IReactorSSL,
    IReactorUNIX,
    IReactorPluggableNameResolver,
    IReactorTime,
    IReactorCore,
    Interface,
):
    """
    The interfaces necessary for Synapse to function (without threads).

    Helpful because we use `twisted.internet.testing.MemoryReactorClock` in tests which
    doesn't implement `IReactorThreads`.
    """


# Note that this seems to require inheriting *directly* from Interface in order
# for mypy-zope to realize it is an interface.
class ISynapseReactor(
    ISynapseThreadlessReactor,
    IReactorThreads,
    Interface,
):
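Splitting the interface matters mostly for tests: `MemoryReactorClock` implements everything Synapse needs except `IReactorThreads`, so code that only needs the threadless subset can now be typed against it. A small sketch of the pattern (the assertions are just illustrative):

```python
from typing import cast

from twisted.internet.testing import MemoryReactorClock

from synapse.types import ISynapseThreadlessReactor
from synapse.util.clock import Clock

# A test reactor satisfies the threadless interface, so it can back a Clock
# without pretending to implement IReactorThreads.
reactor = MemoryReactorClock()
clock = Clock(cast(ISynapseThreadlessReactor, reactor))

assert clock.time() == reactor.seconds()
assert clock.time_msec() == int(reactor.seconds() * 1000)
```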
@@ -20,12 +20,9 @@
#

import collections.abc
import json
import logging
import typing
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    Mapping,
@@ -36,17 +33,11 @@ from typing import (
)

import attr
from immutabledict import immutabledict
from matrix_common.versionstring import get_distribution_version_string
from typing_extensions import ParamSpec

from twisted.internet import defer, task
from twisted.internet.interfaces import IDelayedCall, IReactorTime
from twisted.internet.task import LoopingCall
from twisted.internet import defer
from twisted.python.failure import Failure

from synapse.logging import context

if typing.TYPE_CHECKING:
    pass

@@ -62,41 +53,6 @@ class Duration:
    DAY_MS = 24 * HOUR_MS


def _reject_invalid_json(val: Any) -> None:
    """Do not allow Infinity, -Infinity, or NaN values in JSON."""
    raise ValueError("Invalid JSON value: '%s'" % val)


def _handle_immutabledict(obj: Any) -> Dict[Any, Any]:
    """Helper for json_encoder. Makes immutabledicts serializable by returning
    the underlying dict
    """
    if type(obj) is immutabledict:
        # fishing the protected dict out of the object is a bit nasty,
        # but we don't really want the overhead of copying the dict.
        try:
            # Safety: we catch the AttributeError immediately below.
            return obj._dict
        except AttributeError:
            # If all else fails, resort to making a copy of the immutabledict
            return dict(obj)
    raise TypeError(
        "Object of type %s is not JSON serializable" % obj.__class__.__name__
    )


# A custom JSON encoder which:
#   * handles immutabledicts
#   * produces valid JSON (no NaNs etc)
#   * reduces redundant whitespace
json_encoder = json.JSONEncoder(
    allow_nan=False, separators=(",", ":"), default=_handle_immutabledict
)

# Create a custom decoder to reject Python extensions to JSON.
json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json)


def unwrapFirstError(failure: Failure) -> Failure:
    # Deprecated: you probably just want to catch defer.FirstError and reraise
    # the subFailure's value, which will do a better job of preserving stacktraces.
@@ -105,129 +61,6 @@ def unwrapFirstError(failure: Failure) -> Failure:
    return failure.value.subFailure


P = ParamSpec("P")


@attr.s(slots=True)
class Clock:
    """
    A Clock wraps a Twisted reactor and provides utilities on top of it.

    Args:
        reactor: The Twisted reactor to use.
    """

    _reactor: IReactorTime = attr.ib()

    async def sleep(self, seconds: float) -> None:
        d: defer.Deferred[float] = defer.Deferred()
        with context.PreserveLoggingContext():
            self._reactor.callLater(seconds, d.callback, seconds)
            await d

    def time(self) -> float:
        """Returns the current system time in seconds since epoch."""
        return self._reactor.seconds()

    def time_msec(self) -> int:
        """Returns the current system time in milliseconds since epoch."""
        return int(self.time() * 1000)

    def looping_call(
        self,
        f: Callable[P, object],
        msec: float,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> LoopingCall:
        """Call a function repeatedly.

        Waits `msec` initially before calling `f` for the first time.

        If the function given to `looping_call` returns an awaitable/deferred, the next
        call isn't scheduled until after the returned awaitable has finished. We get
        this functionality thanks to this function being a thin wrapper around
        `twisted.internet.task.LoopingCall`.

        Note that the function will be called with no logcontext, so if it is anything
        other than trivial, you probably want to wrap it in run_as_background_process.

        Args:
            f: The function to call repeatedly.
            msec: How long to wait between calls in milliseconds.
            *args: Positional arguments to pass to function.
            **kwargs: Keyword arguments to pass to function.
        """
        return self._looping_call_common(f, msec, False, *args, **kwargs)

    def looping_call_now(
        self,
        f: Callable[P, object],
        msec: float,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> LoopingCall:
        """Call a function immediately, and then repeatedly thereafter.

        As with `looping_call`: subsequent calls are not scheduled until after the
        Awaitable returned by a previous call has finished.

        Also as with `looping_call`: the function is called with no logcontext and
        you probably want to wrap it in `run_as_background_process`.

        Args:
            f: The function to call repeatedly.
            msec: How long to wait between calls in milliseconds.
            *args: Positional arguments to pass to function.
            **kwargs: Keyword arguments to pass to function.
        """
        return self._looping_call_common(f, msec, True, *args, **kwargs)

    def _looping_call_common(
        self,
        f: Callable[P, object],
        msec: float,
        now: bool,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> LoopingCall:
        """Common functionality for `looping_call` and `looping_call_now`"""
        call = task.LoopingCall(f, *args, **kwargs)
        call.clock = self._reactor
        d = call.start(msec / 1000.0, now=now)
        d.addErrback(log_failure, "Looping call died", consumeErrors=False)
        return call

    def call_later(
        self, delay: float, callback: Callable, *args: Any, **kwargs: Any
    ) -> IDelayedCall:
        """Call something later

        Note that the function will be called with no logcontext, so if it is anything
        other than trivial, you probably want to wrap it in run_as_background_process.

        Args:
            delay: How long to wait in seconds.
            callback: Function to call
            *args: Positional arguments to pass to function.
            **kwargs: Keyword arguments to pass to function.
        """

        def wrapped_callback(*args: Any, **kwargs: Any) -> None:
            with context.PreserveLoggingContext():
                callback(*args, **kwargs)

        with context.PreserveLoggingContext():
            return self._reactor.callLater(delay, wrapped_callback, *args, **kwargs)

    def cancel_call_later(self, timer: IDelayedCall, ignore_errs: bool = False) -> None:
        try:
            timer.cancel()
        except Exception:
            if not ignore_errs:
                raise


def log_failure(
    failure: Failure, msg: str, consumeErrors: bool = True
) -> Optional[Failure]:

@@ -65,7 +65,8 @@ from synapse.logging.context import (
    run_coroutine_in_background,
    run_in_background,
)
from synapse.util import Clock
from synapse.types import ISynapseThreadlessReactor
from synapse.util.clock import Clock

logger = logging.getLogger(__name__)

@@ -566,7 +567,7 @@ class Linearizer:
        if not clock:
            from twisted.internet import reactor

            clock = Clock(cast(IReactorTime, reactor))
            clock = Clock(cast(ISynapseThreadlessReactor, reactor))
        self._clock = clock
        self.max_count = max_count


@@ -39,7 +39,7 @@ from twisted.internet import defer
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics import SERVER_NAME_LABEL
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util import Clock
from synapse.util.clock import Clock

logger = logging.getLogger(__name__)


@@ -579,9 +579,12 @@ def cachedList(
    Used to do batch lookups for an already created cache. One of the arguments
    is specified as a list that is iterated through to lookup keys in the
    original cache. A new tuple consisting of the (deduplicated) keys that weren't in
    the cache gets passed to the original function, which is expected to results
    the cache gets passed to the original function, which is expected to result
    in a map of key to value for each passed value. The new results are stored in the
    original cache. Note that any missing values are cached as None.
    original cache.

    Note that any values in the input that end up being missing from both the
    cache and the returned dictionary will be cached as `None`.

    Args:
        cached_method_name: The name of the single-item lookup method.
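Read against a concrete pair of methods, the reworded `cachedList` contract is easier to check. A minimal sketch (hypothetical store and method names, shapes only, not a real Synapse store):

```python
from typing import Dict, Iterable, Optional

from synapse.util.caches.descriptors import cached, cachedList


class ExampleStore:
    @cached()
    async def get_thing(self, thing_id: str) -> Optional[str]:
        ...

    @cachedList(cached_method_name="get_thing", list_name="thing_ids")
    async def get_things(self, thing_ids: Iterable[str]) -> Dict[str, str]:
        # Only invoked with the (deduplicated) ids that missed the cache.
        # Any id absent from the returned dict is cached as `None`.
        return {thing_id: "value" for thing_id in thing_ids}
```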
@@ -29,8 +29,8 @@ from twisted.internet import defer

from synapse.config import cache as cache_config
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util import Clock
from synapse.util.caches import EvictionReason, register_cache
from synapse.util.clock import Clock

logger = logging.getLogger(__name__)


@@ -46,20 +46,21 @@ from typing import (
)

from twisted.internet import defer, reactor
from twisted.internet.interfaces import IReactorTime

from synapse.config import cache as cache_config
from synapse.metrics.background_process_metrics import (
    run_as_background_process,
)
from synapse.metrics.jemalloc import get_jemalloc_stats
from synapse.util import Clock, caches
from synapse.types import ISynapseThreadlessReactor
from synapse.util import caches
from synapse.util.caches import CacheMetric, EvictionReason, register_cache
from synapse.util.caches.treecache import (
    TreeCache,
    iterate_tree_cache_entry,
    iterate_tree_cache_items,
)
from synapse.util.clock import Clock
from synapse.util.linked_list import ListNode

if TYPE_CHECKING:

@@ -496,7 +497,7 @@ class LruCache(Generic[KT, VT]):
        # Default `clock` to something sensible. Note that we rename it to
        # `real_clock` so that mypy doesn't think it's still `Optional`.
        if clock is None:
            real_clock = Clock(cast(IReactorTime, reactor))
            real_clock = Clock(cast(ISynapseThreadlessReactor, reactor))
        else:
            real_clock = clock


@@ -41,9 +41,9 @@ from synapse.logging.opentracing import (
    start_active_span,
    start_active_span_follows_from,
)
from synapse.util import Clock
from synapse.util.async_helpers import AbstractObservableDeferred, ObservableDeferred
from synapse.util.caches import EvictionReason, register_cache
from synapse.util.clock import Clock

logger = logging.getLogger(__name__)
264
synapse/util/clock.py
Normal file
@@ -0,0 +1,264 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2025 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl_3.0.html>.
#
#

from typing import (
    Any,
    Callable,
)

import attr
from typing_extensions import ParamSpec

from twisted.internet import defer, task
from twisted.internet.interfaces import IDelayedCall
from twisted.internet.task import LoopingCall

from synapse.logging import context
from synapse.types import ISynapseThreadlessReactor
from synapse.util import log_failure

P = ParamSpec("P")


@attr.s(slots=True)
class Clock:
    """
    A Clock wraps a Twisted reactor and provides utilities on top of it.

    Args:
        reactor: The Twisted reactor to use.
    """

    _reactor: ISynapseThreadlessReactor = attr.ib()

    async def sleep(self, seconds: float) -> None:
        d: defer.Deferred[float] = defer.Deferred()
        with context.PreserveLoggingContext():
            self._reactor.callLater(seconds, d.callback, seconds)
            await d

    def time(self) -> float:
        """Returns the current system time in seconds since epoch."""
        return self._reactor.seconds()

    def time_msec(self) -> int:
        """Returns the current system time in milliseconds since epoch."""
        return int(self.time() * 1000)

    def looping_call(
        self,
        f: Callable[P, object],
        msec: float,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> LoopingCall:
        """Call a function repeatedly.

        Waits `msec` initially before calling `f` for the first time.

        If the function given to `looping_call` returns an awaitable/deferred, the next
        call isn't scheduled until after the returned awaitable has finished. We get
        this functionality thanks to this function being a thin wrapper around
        `twisted.internet.task.LoopingCall`.

        Note that the function will be called with no logcontext, so if it is anything
        other than trivial, you probably want to wrap it in run_as_background_process.

        Args:
            f: The function to call repeatedly.
            msec: How long to wait between calls in milliseconds.
            *args: Positional arguments to pass to function.
            **kwargs: Keyword arguments to pass to function.
        """
        return self._looping_call_common(f, msec, False, *args, **kwargs)

    def looping_call_now(
        self,
        f: Callable[P, object],
        msec: float,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> LoopingCall:
        """Call a function immediately, and then repeatedly thereafter.

        As with `looping_call`: subsequent calls are not scheduled until after the
        Awaitable returned by a previous call has finished.

        Also as with `looping_call`: the function is called with no logcontext and
        you probably want to wrap it in `run_as_background_process`.

        Args:
            f: The function to call repeatedly.
            msec: How long to wait between calls in milliseconds.
            *args: Positional arguments to pass to function.
            **kwargs: Keyword arguments to pass to function.
        """
        return self._looping_call_common(f, msec, True, *args, **kwargs)

    def _looping_call_common(
        self,
        f: Callable[P, object],
        msec: float,
        now: bool,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> LoopingCall:
        """Common functionality for `looping_call` and `looping_call_now`"""
        call = task.LoopingCall(f, *args, **kwargs)
        call.clock = self._reactor
        d = call.start(msec / 1000.0, now=now)
        d.addErrback(log_failure, "Looping call died", consumeErrors=False)
        return call

    def call_later(
        self, delay: float, callback: Callable, *args: Any, **kwargs: Any
    ) -> IDelayedCall:
        """Call something later

        Note that the function will be called with no logcontext, so if it is anything
        other than trivial, you probably want to wrap it in run_as_background_process.

        Args:
            delay: How long to wait in seconds.
            callback: Function to call
            *args: Positional arguments to pass to function.
            **kwargs: Keyword arguments to pass to function.
        """

        def wrapped_callback(*args: Any, **kwargs: Any) -> None:
            with context.PreserveLoggingContext():
                callback(*args, **kwargs)

        with context.PreserveLoggingContext():
            return self._reactor.callLater(delay, wrapped_callback, *args, **kwargs)

    def cancel_call_later(self, timer: IDelayedCall, ignore_errs: bool = False) -> None:
        try:
            timer.cancel()
        except Exception:
            if not ignore_errs:
                raise

    def call_when_running(
        self,
        callback: Callable[P, object],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> None:
        """
        Call a function when the reactor is running.

        If the reactor has not started, the callable will be scheduled to run when it
        does start. Otherwise, the callable will be invoked immediately.

        Args:
            callback: Function to call
            *args: Positional arguments to pass to function.
            **kwargs: Keyword arguments to pass to function.
        """

        def wrapped_callback(*args: Any, **kwargs: Any) -> None:
            # Since this callback can be invoked immediately if the reactor is already
            # running, we can't always assume that we're running in the sentinel
            # logcontext (i.e. we can't assert that we're in the sentinel context like
            # we can in other methods).
            #
            # We will only be running in the sentinel logcontext if the reactor was not
            # running when `call_when_running` was invoked and later starts up.
            #
            # assert context.current_context() is context.SENTINEL_CONTEXT

            # Because this is a callback from the reactor, we will be using the
            # `sentinel` log context at this point. We want the function to log with
            # some logcontext as we want to know which server the logs came from.
            #
            # We use `PreserveLoggingContext` to prevent our new `call_when_running`
            # logcontext from finishing as soon as we exit this function, in case `f`
            # returns an awaitable/deferred which would continue running and may try to
            # restore the `call_when_running` context when it's done (because it's
            # trying to adhere to the Synapse logcontext rules.)
            #
            # This also ensures that we return to the `sentinel` context when we exit
            # this function and yield control back to the reactor to avoid leaking the
            # current logcontext to the reactor (which would then get picked up and
            # associated with the next thing the reactor does)
            with context.PreserveLoggingContext(
                context.LoggingContext("call_when_running")
            ):
                # We use `run_in_background` to reset the logcontext after `f` (or the
                # awaitable returned by `f`) completes to avoid leaking the current
                # logcontext to the reactor
                context.run_in_background(callback, *args, **kwargs)

        # We can ignore the lint here since this class is the one location
        # `callWhenRunning` should be called.
        self._reactor.callWhenRunning(wrapped_callback, *args, **kwargs)  # type: ignore[prefer-synapse-clock-call-when-running]

    def add_system_event_trigger(
        self,
        phase: str,
        event_type: str,
        callback: Callable[P, object],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> None:
        """
        Add a function to be called when a system event occurs.

        Equivalent to `reactor.addSystemEventTrigger` (see that docstring for more
        details), but ensures that the callback is run in a logging context.

        Args:
            phase: a time to call the event -- either the string 'before', 'after', or
                'during', describing when to call it relative to the event's execution.
            event_type: this is a string describing the type of event.
            callback: Function to call
            *args: Positional arguments to pass to function.
            **kwargs: Keyword arguments to pass to function.
        """

        def wrapped_callback(*args: Any, **kwargs: Any) -> None:
            assert context.current_context() is context.SENTINEL_CONTEXT, (
                "Expected `add_system_event_trigger` callback from the reactor to start with the sentinel logcontext "
                f"but saw {context.current_context()}. In other words, another task shouldn't have "
                "leaked their logcontext to us."
            )

            # Because this is a callback from the reactor, we will be using the
            # `sentinel` log context at this point. We want the function to log with
            # some logcontext as we want to know which server the logs came from.
            #
            # We use `PreserveLoggingContext` to prevent our new `system_event`
            # logcontext from finishing as soon as we exit this function, in case `f`
            # returns an awaitable/deferred which would continue running and may try to
            # restore the `system_event` context when it's done (because it's trying to
            # adhere to the Synapse logcontext rules.)
            #
            # This also ensures that we return to the `sentinel` context when we exit
            # this function and yield control back to the reactor to avoid leaking the
            # current logcontext to the reactor (which would then get picked up and
            # associated with the next thing the reactor does)
            with context.PreserveLoggingContext(context.LoggingContext("system_event")):
                # We use `run_in_background` to reset the logcontext after `f` (or the
                # awaitable returned by `f`) completes to avoid leaking the current
                # logcontext to the reactor
                context.run_in_background(callback, *args, **kwargs)

        # We can ignore the lint here since this class is the one location
        # `addSystemEventTrigger` should be called.
        self._reactor.addSystemEventTrigger(
            phase, event_type, wrapped_callback, *args, **kwargs
        )  # type: ignore[prefer-synapse-clock-add-system-event-trigger]
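From the call-site side, the new wrapper is a drop-in replacement for `reactor.callWhenRunning`. A minimal sketch of an entrypoint using it (the startup function is illustrative):

```python
from typing import cast

from twisted.internet import reactor

from synapse.types import ISynapseThreadlessReactor
from synapse.util.clock import Clock

clock = Clock(cast(ISynapseThreadlessReactor, reactor))

def start() -> None:
    # Runs inside a "call_when_running" logcontext rather than the sentinel
    # one, so startup logs can be attributed to this homeserver.
    print("starting up")

clock.call_when_running(start)
reactor.run()
```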
57
synapse/util/json.py
Normal file
@@ -0,0 +1,57 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2025 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl_3.0.html>.
#
#

import json
from typing import (
    Any,
    Dict,
)

from immutabledict import immutabledict


def _reject_invalid_json(val: Any) -> None:
    """Do not allow Infinity, -Infinity, or NaN values in JSON."""
    raise ValueError("Invalid JSON value: '%s'" % val)


def _handle_immutabledict(obj: Any) -> Dict[Any, Any]:
    """Helper for json_encoder. Makes immutabledicts serializable by returning
    the underlying dict
    """
    if type(obj) is immutabledict:
        # fishing the protected dict out of the object is a bit nasty,
        # but we don't really want the overhead of copying the dict.
        try:
            # Safety: we catch the AttributeError immediately below.
            return obj._dict
        except AttributeError:
            # If all else fails, resort to making a copy of the immutabledict
            return dict(obj)
    raise TypeError(
        "Object of type %s is not JSON serializable" % obj.__class__.__name__
    )


# A custom JSON encoder which:
#   * handles immutabledicts
#   * produces valid JSON (no NaNs etc)
#   * reduces redundant whitespace
json_encoder = json.JSONEncoder(
    allow_nan=False, separators=(",", ":"), default=_handle_immutabledict
)

# Create a custom decoder to reject Python extensions to JSON.
json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json)
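The module is a straight move out of `synapse/util/__init__.py`, so its behaviour can be sanity-checked the same way as before. A small sketch of the two guarantees the comments describe:

```python
from immutabledict import immutabledict

from synapse.util.json import json_decoder, json_encoder

# immutabledicts serialize like plain dicts, with no redundant whitespace.
assert json_encoder.encode(immutabledict({"a": 1})) == '{"a":1}'

# Non-standard constants are rejected on both the encode and decode paths.
try:
    json_encoder.encode(float("nan"))
except ValueError:
    pass

try:
    json_decoder.decode("NaN")
except ValueError:
    pass
```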
@@ -28,7 +28,8 @@ import attr
import pymacaroons
from pymacaroons.exceptions import MacaroonVerificationFailedException

from synapse.util import Clock, stringutils
from synapse.util import stringutils
from synapse.util.clock import Clock

MacaroonType = Literal["access", "delete_pusher", "session"]


@@ -42,7 +42,7 @@ from synapse.logging.context import (
    current_context,
)
from synapse.metrics import SERVER_NAME_LABEL, InFlightGauge
from synapse.util import Clock
from synapse.util.clock import Clock

logger = logging.getLogger(__name__)


@@ -53,7 +53,7 @@ from synapse.logging.context import (
)
from synapse.logging.opentracing import start_active_span
from synapse.metrics import SERVER_NAME_LABEL, Histogram, LaterGauge
from synapse.util import Clock
from synapse.util.clock import Clock

if typing.TYPE_CHECKING:
    from contextlib import _GeneratorContextManager

@@ -27,7 +27,7 @@ from synapse.api.errors import CodeMessageException
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage import DataStore
from synapse.types import StrCollection
from synapse.util import Clock
from synapse.util.clock import Clock

if TYPE_CHECKING:
    from synapse.notifier import Notifier

@@ -55,7 +55,7 @@ from synapse.types import (
    get_domain_from_id,
)
from synapse.types.state import StateFilter
from synapse.util import Clock
from synapse.util.clock import Clock

logger = logging.getLogger(__name__)
filtered_event_logger = logging.getLogger("synapse.visibility.filtered_event_debug")

@@ -62,7 +62,10 @@ def make_test(
        return res

    d.addBoth(on_done)
    reactor.callWhenRunning(lambda: d.callback(True))
    # type-ignore: This is outside of Synapse (just a utility benchmark script)
    # so we don't need to worry about which server the logs are coming from
    # (`Clock.call_when_running` manages the logcontext for us).
    reactor.callWhenRunning(lambda: d.callback(True))  # type: ignore[prefer-synapse-clock-call-when-running]
    reactor.run()

    # mypy thinks this is an object for some reason.

@@ -37,7 +37,7 @@ from synapse.config.logger import _setup_stdlib_logging
from synapse.logging import RemoteHandler
from synapse.synapse_rust import reset_logging_config
from synapse.types import ISynapseReactor
from synapse.util import Clock
from synapse.util.clock import Clock


class LineCounter(LineOnlyReceiver):

@@ -39,7 +39,7 @@ from synapse.appservice import ApplicationService
from synapse.server import HomeServer
from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import Requester, UserID
from synapse.util import Clock
from synapse.util.clock import Clock

from tests import unittest
from tests.unittest import override_config

@@ -33,7 +33,7 @@ from synapse.api.filtering import Filter
from synapse.api.presence import UserPresenceState
from synapse.server import HomeServer
from synapse.types import JsonDict, UserID
from synapse.util import Clock
from synapse.util.clock import Clock
from synapse.util.frozenutils import freeze

from tests import unittest

@@ -17,7 +17,7 @@ from twisted.internet.testing import MemoryReactor

from synapse.api.urls import LoginSSORedirectURIBuilder
from synapse.server import HomeServer
from synapse.util import Clock
from synapse.util.clock import Clock

from tests.unittest import HomeserverTestCase

@@ -53,3 +53,29 @@ class LoginSSORedirectURIBuilderTestCase(HomeserverTestCase):
            ),
            "https://test/_matrix/client/v3/login/sso/redirect/oidc-github?redirectUrl=https%3A%2F%2Fx%3F%3Cab+c%3E%26q%22%2B%253D%252B%22%3D%22f%C3%B6%2526%3Do%22",
        )

    def test_idp_id_with_slash_is_escaped(self) -> None:
        """
        Test to make sure that we properly URL encode the IdP ID.
        """
        self.assertEqual(
            self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
                idp_id="foo/bar",
                client_redirect_url="http://example.com/redirect",
            ),
            "https://test/_matrix/client/v3/login/sso/redirect/foo%2Fbar?redirectUrl=http%3A%2F%2Fexample.com%2Fredirect",
        )

    def test_url_as_idp_id_is_escaped(self) -> None:
        """
        Test to make sure that we properly URL encode the IdP ID.

        The IdP ID shouldn't be a URL.
        """
        self.assertEqual(
            self.login_sso_redirect_url_builder.build_login_sso_redirect_uri(
                idp_id="http://should-not-be-url.com/",
                client_redirect_url="http://example.com/redirect",
            ),
            "https://test/_matrix/client/v3/login/sso/redirect/http%3A%2F%2Fshould-not-be-url.com%2F?redirectUrl=http%3A%2F%2Fexample.com%2Fredirect",
        )

@@ -29,7 +29,7 @@ from synapse.app.homeserver import SynapseHomeServer
from synapse.config.server import parse_listener_def
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
from synapse.util.clock import Clock

from tests.server import make_request
from tests.unittest import HomeserverTestCase
Some files were not shown because too many files have changed in this diff.