mirror of
https://github.com/element-hq/synapse.git
synced 2025-12-05 01:10:13 +00:00
Compare commits
18 Commits
anoa/codex
...
madlittlem
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4b3d923a47 | ||
|
|
1d731ec93d | ||
|
|
c076082191 | ||
|
|
37ea1ae686 | ||
|
|
84d64251dc | ||
|
|
2bed3fb566 | ||
|
|
2c60b67a95 | ||
|
|
6358afff8d | ||
|
|
f7b547e2d8 | ||
|
|
8f7bd946de | ||
|
|
4f80fa4b0a | ||
|
|
b2592667a4 | ||
|
|
769d30a247 | ||
|
|
7ecfe8b1a8 | ||
|
|
e1036ffa48 | ||
|
|
8c98cf7e55 | ||
|
|
ec64c3e88d | ||
|
|
ada3a3b2b3 |
2
.github/workflows/docker.yml
vendored
2
.github/workflows/docker.yml
vendored
@@ -120,7 +120,7 @@ jobs:
|
||||
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # v3.11.1
|
||||
|
||||
- name: Install Cosign
|
||||
uses: sigstore/cosign-installer@d58896d6a1865668819e1d91763c7751a165e159 # v3.9.2
|
||||
uses: sigstore/cosign-installer@d7543c93d881b35a8faa02e8e3605f69b7a1ce62 # v3.10.0
|
||||
|
||||
- name: Calculate docker image tag
|
||||
uses: docker/metadata-action@c1e51972afc2121e065aed6d45c65596fe445f3f # v5.8.0
|
||||
|
||||
23
Cargo.lock
generated
23
Cargo.lock
generated
@@ -1250,18 +1250,28 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.219"
|
||||
version = "1.0.224"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6"
|
||||
checksum = "6aaeb1e94f53b16384af593c71e20b095e958dab1d26939c1b70645c5cfbcc0b"
|
||||
dependencies = [
|
||||
"serde_core",
|
||||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_core"
|
||||
version = "1.0.224"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "32f39390fa6346e24defbcdd3d9544ba8a19985d0af74df8501fbfe9a64341ab"
|
||||
dependencies = [
|
||||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.219"
|
||||
version = "1.0.224"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00"
|
||||
checksum = "87ff78ab5e8561c9a675bfc1785cb07ae721f0ee53329a595cefd8c04c2ac4e0"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
@@ -1270,14 +1280,15 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.143"
|
||||
version = "1.0.145"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d401abef1d108fbd9cbaebc3e46611f4b1021f714a0597a71f41ee463f5f4a5a"
|
||||
checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"memchr",
|
||||
"ryu",
|
||||
"serde",
|
||||
"serde_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
1
changelog.d/18641.bugfix
Normal file
1
changelog.d/18641.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Ensure all PDUs sent via `/send` pass canonical JSON checks.
|
||||
1
changelog.d/18695.feature
Normal file
1
changelog.d/18695.feature
Normal file
@@ -0,0 +1 @@
|
||||
Add experimental support for [MSC4308: Thread Subscriptions extension to Sliding Sync](https://github.com/matrix-org/matrix-spec-proposals/pull/4308) when [MSC4306: Thread Subscriptions](https://github.com/matrix-org/matrix-spec-proposals/pull/4306) and [MSC4186: Simplified Sliding Sync](https://github.com/matrix-org/matrix-spec-proposals/pull/4186) are enabled.
|
||||
1
changelog.d/18848.feature
Normal file
1
changelog.d/18848.feature
Normal file
@@ -0,0 +1 @@
|
||||
Add `get_media_upload_limits_for_user` and `on_media_upload_limit_exceeded` module API callbacks for media repository.
|
||||
1
changelog.d/18856.doc
Normal file
1
changelog.d/18856.doc
Normal file
@@ -0,0 +1 @@
|
||||
Clarify Python dependency constraints in our deprecation policy.
|
||||
1
changelog.d/18870.misc
Normal file
1
changelog.d/18870.misc
Normal file
@@ -0,0 +1 @@
|
||||
Remove `sentinel` logcontext usage where we log in `setup`, `start` and exit.
|
||||
1
changelog.d/18906.misc
Normal file
1
changelog.d/18906.misc
Normal file
@@ -0,0 +1 @@
|
||||
Better explain how we manage the logcontext in `run_in_background(...)` and `run_as_background_process(...)`.
|
||||
1
changelog.d/18933.misc
Normal file
1
changelog.d/18933.misc
Normal file
@@ -0,0 +1 @@
|
||||
Split loading config from homeserver `setup`.
|
||||
@@ -1,13 +1,11 @@
|
||||
Deprecation Policy for Platform Dependencies
|
||||
============================================
|
||||
# Deprecation Policy
|
||||
|
||||
Synapse has a number of platform dependencies, including Python, Rust,
|
||||
PostgreSQL and SQLite. This document outlines the policy towards which versions
|
||||
we support, and when we drop support for versions in the future.
|
||||
Synapse has a number of **platform dependencies** (Python, Rust, PostgreSQL, and SQLite)
|
||||
and **application dependencies** (Python and Rust packages). This document outlines the
|
||||
policy towards which versions we support, and when we drop support for versions in the
|
||||
future.
|
||||
|
||||
|
||||
Policy
|
||||
------
|
||||
## Platform Dependencies
|
||||
|
||||
Synapse follows the upstream support life cycles for Python and PostgreSQL,
|
||||
i.e. when a version reaches End of Life Synapse will withdraw support for that
|
||||
@@ -26,8 +24,8 @@ The oldest supported version of SQLite is the version
|
||||
[provided](https://packages.debian.org/bullseye/libsqlite3-0) by
|
||||
[Debian oldstable](https://wiki.debian.org/DebianOldStable).
|
||||
|
||||
Context
|
||||
-------
|
||||
|
||||
### Context
|
||||
|
||||
It is important for system admins to have a clear understanding of the platform
|
||||
requirements of Synapse and its deprecation policies so that they can
|
||||
@@ -50,4 +48,42 @@ the ecosystem.
|
||||
On a similar note, SQLite does not generally have a concept of "supported
|
||||
release"; bugfixes are published for the latest minor release only. We chose to
|
||||
track Debian's oldstable as this is relatively conservative, predictably updated
|
||||
and is consistent with the `.deb` packages released by Matrix.org.
|
||||
and is consistent with the `.deb` packages released by Matrix.org.
|
||||
|
||||
|
||||
## Application dependencies
|
||||
|
||||
For application-level Python dependencies, we often specify loose version constraints
|
||||
(ex. `>=X.Y.Z`) to be forwards compatible with any new versions. Upper bounds (`<A.B.C`)
|
||||
are only added when necessary to prevent known incompatibilities.
|
||||
|
||||
When selecting a minimum version, while we are mindful of the impact on downstream
|
||||
package maintainers, our primary focus is on the maintainability and progress of Synapse
|
||||
itself.
|
||||
|
||||
For developers, a Python dependency version can be considered a "no-brainer" upgrade once it is
|
||||
available in both the latest [Debian Stable](https://packages.debian.org/stable/) and
|
||||
[Ubuntu LTS](https://launchpad.net/ubuntu) repositories. No need to burden yourself with
|
||||
extra scrutiny or consideration at this point.
|
||||
|
||||
We aggressively update Rust dependencies. Since these are statically linked and managed
|
||||
entirely by `cargo` during build, they *can* pose no ongoing maintenance burden on others.
|
||||
This allows us to freely upgrade to leverage the latest ecosystem advancements assuming
|
||||
they don't have their own system-level dependencies.
|
||||
|
||||
|
||||
### Context
|
||||
|
||||
Because Python dependencies can easily be managed in a virtual environment, we are less
|
||||
concerned about the criteria for selecting minimum versions. The only thing of concern
|
||||
is making sure we're not making it unnecessarily difficult for downstream package
|
||||
maintainers. Generally, this just means avoiding the bleeding edge for a few months.
|
||||
|
||||
The situation for Rust dependencies is fundamentally different. For packagers, the
|
||||
concerns around Python dependency versions do not apply. The `cargo` tool handles
|
||||
downloading and building all libraries to satisfy dependencies, and these libraries are
|
||||
statically linked into the final binary. This means that from a packager's perspective,
|
||||
the Rust dependency versions are an internal build detail, not a runtime dependency to
|
||||
be managed on the target system. Consequently, we have even greater flexibility to
|
||||
upgrade Rust dependencies as needed for the project. Some distros (e.g. Fedora) do
|
||||
package Rust libraries, but this appears to be the outlier rather than the norm.
|
||||
|
||||
@@ -64,3 +64,68 @@ If multiple modules implement this callback, they will be considered in order. I
|
||||
returns `True`, Synapse falls through to the next one. The value of the first callback that
|
||||
returns `False` will be used. If this happens, Synapse will not call any of the subsequent
|
||||
implementations of this callback.
|
||||
|
||||
### `get_media_upload_limits_for_user`
|
||||
|
||||
_First introduced in Synapse v1.139.0_
|
||||
|
||||
```python
|
||||
async def get_media_upload_limits_for_user(user_id: str, size: int) -> Optional[List[synapse.module_api.MediaUploadLimit]]
|
||||
```
|
||||
|
||||
**<span style="color:red">
|
||||
Caution: This callback is currently experimental. The method signature or behaviour
|
||||
may change without notice.
|
||||
</span>**
|
||||
|
||||
Called when processing a request to store content in the media repository. This can be used to dynamically override
|
||||
the [media upload limits configuration](../usage/configuration/config_documentation.html#media_upload_limits).
|
||||
|
||||
The arguments passed to this callback are:
|
||||
|
||||
* `user_id`: The Matrix user ID of the user (e.g. `@alice:example.com`) making the request.
|
||||
|
||||
If the callback returns a list then it will be used as the limits instead of those in the configuration (if any).
|
||||
|
||||
If an empty list is returned then no limits are applied (**warning:** users will be able
|
||||
to upload as much data as they desire).
|
||||
|
||||
If multiple modules implement this callback, they will be considered in order. If a
|
||||
callback returns `None`, Synapse falls through to the next one. The value of the first
|
||||
callback that does not return `None` will be used. If this happens, Synapse will not call
|
||||
any of the subsequent implementations of this callback.
|
||||
|
||||
If there are no registered modules, or if all modules return `None`, then
|
||||
the default
|
||||
[media upload limits configuration](../usage/configuration/config_documentation.html#media_upload_limits)
|
||||
will be used.
|
||||
|
||||
### `on_media_upload_limit_exceeded`
|
||||
|
||||
_First introduced in Synapse v1.139.0_
|
||||
|
||||
```python
|
||||
async def on_media_upload_limit_exceeded(user_id: str, limit: synapse.module_api.MediaUploadLimit, sent_bytes: int, attempted_bytes: int) -> None
|
||||
```
|
||||
|
||||
**<span style="color:red">
|
||||
Caution: This callback is currently experimental. The method signature or behaviour
|
||||
may change without notice.
|
||||
</span>**
|
||||
|
||||
Called when a user attempts to upload media that would exceed a
|
||||
[configured media upload limit](../usage/configuration/config_documentation.html#media_upload_limits).
|
||||
|
||||
This callback will only be called on workers which handle
|
||||
[POST /_matrix/media/v3/upload](https://spec.matrix.org/v1.15/client-server-api/#post_matrixmediav3upload)
|
||||
requests.
|
||||
|
||||
This could be used to inform the user that they have reached a media upload limit through
|
||||
some external method.
|
||||
|
||||
The arguments passed to this callback are:
|
||||
|
||||
* `user_id`: The Matrix user ID of the user (e.g. `@alice:example.com`) making the request.
|
||||
* `limit`: The `synapse.module_api.MediaUploadLimit` representing the limit that was reached.
|
||||
* `sent_bytes`: The number of bytes already sent during the period of the limit.
|
||||
* `attempted_bytes`: The number of bytes that the user attempted to send.
|
||||
|
||||
@@ -2168,9 +2168,12 @@ max_upload_size: 60M
|
||||
### `media_upload_limits`
|
||||
|
||||
*(array)* A list of media upload limits defining how much data a given user can upload in a given time period.
|
||||
These limits are applied in addition to the `max_upload_size` limit above (which applies to individual uploads).
|
||||
|
||||
An empty list means no limits are applied.
|
||||
|
||||
These settings can be overridden using the `get_media_upload_limits_for_user` module API [callback](../../modules/media_repository_callbacks.md#get_media_upload_limits_for_user).
|
||||
|
||||
Defaults to `[]`.
|
||||
|
||||
Example configuration:
|
||||
|
||||
24
poetry.lock
generated
24
poetry.lock
generated
@@ -34,15 +34,15 @@ tests-mypy = ["mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" a
|
||||
|
||||
[[package]]
|
||||
name = "authlib"
|
||||
version = "1.6.1"
|
||||
version = "1.6.3"
|
||||
description = "The ultimate Python library in building OAuth and OpenID Connect servers and clients."
|
||||
optional = true
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
markers = "extra == \"all\" or extra == \"jwt\" or extra == \"oidc\""
|
||||
files = [
|
||||
{file = "authlib-1.6.1-py2.py3-none-any.whl", hash = "sha256:e9d2031c34c6309373ab845afc24168fe9e93dc52d252631f52642f21f5ed06e"},
|
||||
{file = "authlib-1.6.1.tar.gz", hash = "sha256:4dffdbb1460ba6ec8c17981a4c67af7d8af131231b5a36a88a1e8c80c111cdfd"},
|
||||
{file = "authlib-1.6.3-py2.py3-none-any.whl", hash = "sha256:7ea0f082edd95a03b7b72edac65ec7f8f68d703017d7e37573aee4fc603f2a48"},
|
||||
{file = "authlib-1.6.3.tar.gz", hash = "sha256:9f7a982cc395de719e4c2215c5707e7ea690ecf84f1ab126f28c053f4219e610"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -1774,14 +1774,14 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "pydantic"
|
||||
version = "2.11.7"
|
||||
version = "2.11.9"
|
||||
description = "Data validation using Python type hints"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main", "dev"]
|
||||
files = [
|
||||
{file = "pydantic-2.11.7-py3-none-any.whl", hash = "sha256:dde5df002701f6de26248661f6835bbe296a47bf73990135c7d07ce741b9623b"},
|
||||
{file = "pydantic-2.11.7.tar.gz", hash = "sha256:d989c3c6cb79469287b1569f7447a17848c998458d49ebe294e975b9baf0f0db"},
|
||||
{file = "pydantic-2.11.9-py3-none-any.whl", hash = "sha256:c42dd626f5cfc1c6950ce6205ea58c93efa406da65f479dcb4029d5934857da2"},
|
||||
{file = "pydantic-2.11.9.tar.gz", hash = "sha256:6b8ffda597a14812a7975c90b82a8a2e777d9257aba3453f973acd3c032a18e2"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -2971,14 +2971,14 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "types-psycopg2"
|
||||
version = "2.9.21.20250809"
|
||||
version = "2.9.21.20250915"
|
||||
description = "Typing stubs for psycopg2"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "types_psycopg2-2.9.21.20250809-py3-none-any.whl", hash = "sha256:59b7b0ed56dcae9efae62b8373497274fc1a0484bdc5135cdacbe5a8f44e1d7b"},
|
||||
{file = "types_psycopg2-2.9.21.20250809.tar.gz", hash = "sha256:b7c2cbdcf7c0bd16240f59ba694347329b0463e43398de69784ea4dee45f3c6d"},
|
||||
{file = "types_psycopg2-2.9.21.20250915-py3-none-any.whl", hash = "sha256:eefe5ccdc693fc086146e84c9ba437bb278efe1ef330b299a0cb71169dc6c55f"},
|
||||
{file = "types_psycopg2-2.9.21.20250915.tar.gz", hash = "sha256:bfeb8f54c32490e7b5edc46215ab4163693192bc90407b4a023822de9239f5c8"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3026,14 +3026,14 @@ urllib3 = ">=2"
|
||||
|
||||
[[package]]
|
||||
name = "types-setuptools"
|
||||
version = "80.9.0.20250809"
|
||||
version = "80.9.0.20250822"
|
||||
description = "Typing stubs for setuptools"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "types_setuptools-80.9.0.20250809-py3-none-any.whl", hash = "sha256:7c6539b4c7ac7b4ab4db2be66d8a58fb1e28affa3ee3834be48acafd94f5976a"},
|
||||
{file = "types_setuptools-80.9.0.20250809.tar.gz", hash = "sha256:e986ba37ffde364073d76189e1d79d9928fb6f5278c7d07589cde353d0218864"},
|
||||
{file = "types_setuptools-80.9.0.20250822-py3-none-any.whl", hash = "sha256:53bf881cb9d7e46ed12c76ef76c0aaf28cfe6211d3fab12e0b83620b1a8642c3"},
|
||||
{file = "types_setuptools-80.9.0.20250822.tar.gz", hash = "sha256:070ea7716968ec67a84c7f7768d9952ff24d28b65b6594797a464f1b3066f965"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -2415,8 +2415,15 @@ properties:
|
||||
A list of media upload limits defining how much data a given user can
|
||||
upload in a given time period.
|
||||
|
||||
These limits are applied in addition to the `max_upload_size` limit above
|
||||
(which applies to individual uploads).
|
||||
|
||||
|
||||
An empty list means no limits are applied.
|
||||
|
||||
|
||||
These settings can be overridden using the `get_media_upload_limits_for_user`
|
||||
module API [callback](../../modules/media_repository_callbacks.md#get_media_upload_limits_for_user).
|
||||
default: []
|
||||
items:
|
||||
time_period:
|
||||
|
||||
@@ -72,7 +72,7 @@ from synapse.events.auto_accept_invites import InviteAutoAccepter
|
||||
from synapse.events.presence_router import load_legacy_presence_router
|
||||
from synapse.handlers.auth import load_legacy_password_auth_providers
|
||||
from synapse.http.site import SynapseSite
|
||||
from synapse.logging.context import PreserveLoggingContext
|
||||
from synapse.logging.context import LoggingContext, PreserveLoggingContext
|
||||
from synapse.logging.opentracing import init_tracer
|
||||
from synapse.metrics import install_gc_manager, register_threadpool
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
@@ -183,25 +183,23 @@ def start_reactor(
|
||||
if gc_thresholds:
|
||||
gc.set_threshold(*gc_thresholds)
|
||||
install_gc_manager()
|
||||
run_command()
|
||||
|
||||
# make sure that we run the reactor with the sentinel log context,
|
||||
# otherwise other PreserveLoggingContext instances will get confused
|
||||
# and complain when they see the logcontext arbitrarily swapping
|
||||
# between the sentinel and `run` logcontexts.
|
||||
#
|
||||
# We also need to drop the logcontext before forking if we're daemonizing,
|
||||
# otherwise the cputime metrics get confused about the per-thread resource usage
|
||||
# appearing to go backwards.
|
||||
with PreserveLoggingContext():
|
||||
if daemonize:
|
||||
assert pid_file is not None
|
||||
# Reset the logging context when we start the reactor (whenever we yield control
|
||||
# to the reactor, the `sentinel` logging context needs to be set so we don't
|
||||
# leak the current logging context and erroneously apply it to the next task the
|
||||
# reactor event loop picks up)
|
||||
with PreserveLoggingContext():
|
||||
run_command()
|
||||
|
||||
if print_pidfile:
|
||||
print(pid_file)
|
||||
if daemonize:
|
||||
assert pid_file is not None
|
||||
|
||||
daemonize_process(pid_file, logger)
|
||||
run()
|
||||
if print_pidfile:
|
||||
print(pid_file)
|
||||
|
||||
daemonize_process(pid_file, logger)
|
||||
|
||||
run()
|
||||
|
||||
|
||||
def quit_with_error(error_string: str) -> NoReturn:
|
||||
@@ -601,10 +599,12 @@ async def start(hs: "HomeServer") -> None:
|
||||
hs.get_datastores().main.db_pool.start_profiling()
|
||||
hs.get_pusherpool().start()
|
||||
|
||||
def log_shutdown() -> None:
|
||||
with LoggingContext("log_shutdown"):
|
||||
logger.info("Shutting down...")
|
||||
|
||||
# Log when we start the shut down process.
|
||||
hs.get_reactor().addSystemEventTrigger(
|
||||
"before", "shutdown", logger.info, "Shutting down..."
|
||||
)
|
||||
hs.get_reactor().addSystemEventTrigger("before", "shutdown", log_shutdown)
|
||||
|
||||
setup_sentry(hs)
|
||||
setup_sdnotify(hs)
|
||||
|
||||
@@ -24,7 +24,7 @@ import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from typing import List, Mapping, Optional, Sequence
|
||||
from typing import List, Mapping, Optional, Sequence, Tuple
|
||||
|
||||
from twisted.internet import defer, task
|
||||
|
||||
@@ -256,7 +256,7 @@ class FileExfiltrationWriter(ExfiltrationWriter):
|
||||
return self.base_directory
|
||||
|
||||
|
||||
def start(config_options: List[str]) -> None:
|
||||
def load_config(argv_options: List[str]) -> Tuple[HomeServerConfig, argparse.Namespace]:
|
||||
parser = argparse.ArgumentParser(description="Synapse Admin Command")
|
||||
HomeServerConfig.add_arguments_to_parser(parser)
|
||||
|
||||
@@ -282,11 +282,15 @@ def start(config_options: List[str]) -> None:
|
||||
export_data_parser.set_defaults(func=export_data_command)
|
||||
|
||||
try:
|
||||
config, args = HomeServerConfig.load_config_with_parser(parser, config_options)
|
||||
config, args = HomeServerConfig.load_config_with_parser(parser, argv_options)
|
||||
except ConfigError as e:
|
||||
sys.stderr.write("\n" + str(e) + "\n")
|
||||
sys.exit(1)
|
||||
|
||||
return config, args
|
||||
|
||||
|
||||
def start(config: HomeServerConfig, args: argparse.Namespace) -> None:
|
||||
if config.worker.worker_app is not None:
|
||||
assert config.worker.worker_app == "synapse.app.admin_cmd"
|
||||
|
||||
@@ -325,7 +329,7 @@ def start(config_options: List[str]) -> None:
|
||||
# command.
|
||||
|
||||
async def run() -> None:
|
||||
with LoggingContext("command"):
|
||||
with LoggingContext(name="command"):
|
||||
await _base.start(ss)
|
||||
await args.func(ss, args)
|
||||
|
||||
@@ -337,5 +341,6 @@ def start(config_options: List[str]) -> None:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
with LoggingContext("main"):
|
||||
start(sys.argv[1:])
|
||||
homeserver_config, args = load_config(sys.argv[1:])
|
||||
with LoggingContext(name="main"):
|
||||
start(homeserver_config, args)
|
||||
|
||||
@@ -21,13 +21,14 @@
|
||||
|
||||
import sys
|
||||
|
||||
from synapse.app.generic_worker import start
|
||||
from synapse.app.generic_worker import load_config, start
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
|
||||
|
||||
def main() -> None:
|
||||
with LoggingContext("main"):
|
||||
start(sys.argv[1:])
|
||||
homeserver_config = load_config(sys.argv[1:])
|
||||
with LoggingContext(name="main"):
|
||||
start(homeserver_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -21,13 +21,14 @@
|
||||
|
||||
import sys
|
||||
|
||||
from synapse.app.generic_worker import start
|
||||
from synapse.app.generic_worker import load_config, start
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
|
||||
|
||||
def main() -> None:
|
||||
with LoggingContext("main"):
|
||||
start(sys.argv[1:])
|
||||
homeserver_config = load_config(sys.argv[1:])
|
||||
with LoggingContext(name="main"):
|
||||
start(homeserver_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -20,13 +20,14 @@
|
||||
|
||||
import sys
|
||||
|
||||
from synapse.app.generic_worker import start
|
||||
from synapse.app.generic_worker import load_config, start
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
|
||||
|
||||
def main() -> None:
|
||||
with LoggingContext("main"):
|
||||
start(sys.argv[1:])
|
||||
homeserver_config = load_config(sys.argv[1:])
|
||||
with LoggingContext(name="main"):
|
||||
start(homeserver_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -21,13 +21,14 @@
|
||||
|
||||
import sys
|
||||
|
||||
from synapse.app.generic_worker import start
|
||||
from synapse.app.generic_worker import load_config, start
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
|
||||
|
||||
def main() -> None:
|
||||
with LoggingContext("main"):
|
||||
start(sys.argv[1:])
|
||||
homeserver_config = load_config(sys.argv[1:])
|
||||
with LoggingContext(name="main"):
|
||||
start(homeserver_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -21,13 +21,14 @@
|
||||
|
||||
import sys
|
||||
|
||||
from synapse.app.generic_worker import start
|
||||
from synapse.app.generic_worker import load_config, start
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
|
||||
|
||||
def main() -> None:
|
||||
with LoggingContext("main"):
|
||||
start(sys.argv[1:])
|
||||
homeserver_config = load_config(sys.argv[1:])
|
||||
with LoggingContext(name="main"):
|
||||
start(homeserver_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -21,13 +21,14 @@
|
||||
|
||||
import sys
|
||||
|
||||
from synapse.app.generic_worker import start
|
||||
from synapse.app.generic_worker import load_config, start
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
|
||||
|
||||
def main() -> None:
|
||||
with LoggingContext("main"):
|
||||
start(sys.argv[1:])
|
||||
homeserver_config = load_config(sys.argv[1:])
|
||||
with LoggingContext(name="main"):
|
||||
start(homeserver_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -310,13 +310,26 @@ class GenericWorkerServer(HomeServer):
|
||||
self.get_replication_command_handler().start_replication(self)
|
||||
|
||||
|
||||
def start(config_options: List[str]) -> None:
|
||||
def load_config(argv_options: List[str]) -> HomeServerConfig:
|
||||
"""
|
||||
Parse the commandline and config files (does not generate config)
|
||||
|
||||
Args:
|
||||
argv_options: The options passed to Synapse. Usually `sys.argv[1:]`.
|
||||
|
||||
Returns:
|
||||
Config object.
|
||||
"""
|
||||
try:
|
||||
config = HomeServerConfig.load_config("Synapse worker", config_options)
|
||||
config = HomeServerConfig.load_config("Synapse worker", argv_options)
|
||||
except ConfigError as e:
|
||||
sys.stderr.write("\n" + str(e) + "\n")
|
||||
sys.exit(1)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def start(config: HomeServerConfig) -> None:
|
||||
# For backwards compatibility let any of the old app names.
|
||||
assert config.worker.worker_app in (
|
||||
"synapse.app.appservice",
|
||||
@@ -365,8 +378,9 @@ def start(config_options: List[str]) -> None:
|
||||
|
||||
|
||||
def main() -> None:
|
||||
with LoggingContext("main"):
|
||||
start(sys.argv[1:])
|
||||
homeserver_config = load_config(sys.argv[1:])
|
||||
with LoggingContext(name="main"):
|
||||
start(homeserver_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -308,17 +308,21 @@ class SynapseHomeServer(HomeServer):
|
||||
logger.warning("Unrecognized listener type: %s", listener.type)
|
||||
|
||||
|
||||
def setup(config_options: List[str]) -> SynapseHomeServer:
|
||||
def load_or_generate_config(argv_options: List[str]) -> HomeServerConfig:
|
||||
"""
|
||||
Parse the commandline and config files
|
||||
|
||||
Supports generation of config files, so is used for the main homeserver app.
|
||||
|
||||
Args:
|
||||
config_options_options: The options passed to Synapse. Usually `sys.argv[1:]`.
|
||||
argv_options: The options passed to Synapse. Usually `sys.argv[1:]`.
|
||||
|
||||
Returns:
|
||||
A homeserver instance.
|
||||
"""
|
||||
try:
|
||||
config = HomeServerConfig.load_or_generate_config(
|
||||
"Synapse Homeserver", config_options
|
||||
"Synapse Homeserver", argv_options
|
||||
)
|
||||
except ConfigError as e:
|
||||
sys.stderr.write("\n")
|
||||
@@ -332,6 +336,20 @@ def setup(config_options: List[str]) -> SynapseHomeServer:
|
||||
# generating config files and shouldn't try to continue.
|
||||
sys.exit(0)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def setup(config: HomeServerConfig) -> SynapseHomeServer:
|
||||
"""
|
||||
Create and setup a Synapse homeserver instance given a configuration.
|
||||
|
||||
Args:
|
||||
config: The configuration for the homeserver.
|
||||
|
||||
Returns:
|
||||
A homeserver instance.
|
||||
"""
|
||||
|
||||
if config.worker.worker_app:
|
||||
raise ConfigError(
|
||||
"You have specified `worker_app` in the config but are attempting to start a non-worker "
|
||||
@@ -377,15 +395,17 @@ def setup(config_options: List[str]) -> SynapseHomeServer:
|
||||
handle_startup_exception(e)
|
||||
|
||||
async def start() -> None:
|
||||
# Load the OIDC provider metadatas, if OIDC is enabled.
|
||||
if hs.config.oidc.oidc_enabled:
|
||||
oidc = hs.get_oidc_handler()
|
||||
# Loading the provider metadata also ensures the provider config is valid.
|
||||
await oidc.load_metadata()
|
||||
# Re-establish log context now that we're back from the reactor
|
||||
with LoggingContext("start"):
|
||||
# Load the OIDC provider metadatas, if OIDC is enabled.
|
||||
if hs.config.oidc.oidc_enabled:
|
||||
oidc = hs.get_oidc_handler()
|
||||
# Loading the provider metadata also ensures the provider config is valid.
|
||||
await oidc.load_metadata()
|
||||
|
||||
await _base.start(hs)
|
||||
await _base.start(hs)
|
||||
|
||||
hs.get_datastores().main.db_pool.updates.start_doing_background_updates()
|
||||
hs.get_datastores().main.db_pool.updates.start_doing_background_updates()
|
||||
|
||||
register_start(start)
|
||||
|
||||
@@ -405,10 +425,12 @@ def run(hs: HomeServer) -> None:
|
||||
|
||||
|
||||
def main() -> None:
|
||||
homeserver_config = load_or_generate_config(sys.argv[1:])
|
||||
|
||||
with LoggingContext("main"):
|
||||
# check base requirements
|
||||
check_requirements()
|
||||
hs = setup(sys.argv[1:])
|
||||
hs = setup(homeserver_config)
|
||||
|
||||
# redirect stdio to the logs, if configured.
|
||||
if not hs.config.logging.no_redirect_stdio:
|
||||
|
||||
@@ -21,13 +21,14 @@
|
||||
|
||||
import sys
|
||||
|
||||
from synapse.app.generic_worker import start
|
||||
from synapse.app.generic_worker import load_config, start
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
|
||||
|
||||
def main() -> None:
|
||||
with LoggingContext("main"):
|
||||
start(sys.argv[1:])
|
||||
homeserver_config = load_config(sys.argv[1:])
|
||||
with LoggingContext(name="main"):
|
||||
start(homeserver_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -21,13 +21,14 @@
|
||||
|
||||
import sys
|
||||
|
||||
from synapse.app.generic_worker import start
|
||||
from synapse.app.generic_worker import load_config, start
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
|
||||
|
||||
def main() -> None:
|
||||
with LoggingContext("main"):
|
||||
start(sys.argv[1:])
|
||||
homeserver_config = load_config(sys.argv[1:])
|
||||
with LoggingContext(name="main"):
|
||||
start(homeserver_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -21,13 +21,14 @@
|
||||
|
||||
import sys
|
||||
|
||||
from synapse.app.generic_worker import start
|
||||
from synapse.app.generic_worker import load_config, start
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
|
||||
|
||||
def main() -> None:
|
||||
with LoggingContext("main"):
|
||||
start(sys.argv[1:])
|
||||
homeserver_config = load_config(sys.argv[1:])
|
||||
with LoggingContext(name="main"):
|
||||
start(homeserver_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -21,13 +21,14 @@
|
||||
|
||||
import sys
|
||||
|
||||
from synapse.app.generic_worker import start
|
||||
from synapse.app.generic_worker import load_config, start
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
|
||||
|
||||
def main() -> None:
|
||||
with LoggingContext("main"):
|
||||
start(sys.argv[1:])
|
||||
homeserver_config = load_config(sys.argv[1:])
|
||||
with LoggingContext(name="main"):
|
||||
start(homeserver_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -646,12 +646,16 @@ class RootConfig:
|
||||
|
||||
@classmethod
|
||||
def load_or_generate_config(
|
||||
cls: Type[TRootConfig], description: str, argv: List[str]
|
||||
cls: Type[TRootConfig], description: str, argv_options: List[str]
|
||||
) -> Optional[TRootConfig]:
|
||||
"""Parse the commandline and config files
|
||||
|
||||
Supports generation of config files, so is used for the main homeserver app.
|
||||
|
||||
Args:
|
||||
description: TODO
|
||||
argv_options: The options passed to Synapse. Usually `sys.argv[1:]`.
|
||||
|
||||
Returns:
|
||||
Config object, or None if --generate-config or --generate-keys was set
|
||||
"""
|
||||
@@ -747,7 +751,7 @@ class RootConfig:
|
||||
)
|
||||
|
||||
cls.invoke_all_static("add_arguments", parser)
|
||||
config_args = parser.parse_args(argv)
|
||||
config_args = parser.parse_args(argv_options)
|
||||
|
||||
config_files = find_config_files(search_paths=config_args.config_path)
|
||||
|
||||
|
||||
@@ -590,5 +590,5 @@ class ExperimentalConfig(Config):
|
||||
self.msc4293_enabled: bool = experimental.get("msc4293_enabled", False)
|
||||
|
||||
# MSC4306: Thread Subscriptions
|
||||
# (and MSC4308: sliding sync extension for thread subscriptions)
|
||||
# (and MSC4308: Thread Subscriptions extension to Sliding Sync)
|
||||
self.msc4306_enabled: bool = experimental.get("msc4306_enabled", False)
|
||||
|
||||
@@ -120,11 +120,19 @@ def parse_thumbnail_requirements(
|
||||
|
||||
@attr.s(auto_attribs=True, slots=True, frozen=True)
|
||||
class MediaUploadLimit:
|
||||
"""A limit on the amount of data a user can upload in a given time
|
||||
period."""
|
||||
"""
|
||||
Represents a limit on the amount of data a user can upload in a given time
|
||||
period.
|
||||
|
||||
These can be configured through the `media_upload_limits` [config option](https://element-hq.github.io/synapse/latest/usage/configuration/config_documentation.html#media_upload_limits)
|
||||
or via the `get_media_upload_limits_for_user` module API [callback](https://element-hq.github.io/synapse/latest/modules/media_repository_callbacks.html#get_media_upload_limits_for_user).
|
||||
"""
|
||||
|
||||
max_bytes: int
|
||||
"""The maximum number of bytes that can be uploaded in the given time period."""
|
||||
|
||||
time_period_ms: int
|
||||
"""The time period in milliseconds."""
|
||||
|
||||
|
||||
class ContentRepositoryConfig(Config):
|
||||
|
||||
@@ -26,7 +26,7 @@ from synapse.api.constants import EduTypes
|
||||
from synapse.api.errors import HttpResponseException
|
||||
from synapse.events import EventBase
|
||||
from synapse.federation.persistence import TransactionActions
|
||||
from synapse.federation.units import Edu, Transaction
|
||||
from synapse.federation.units import Edu, Transaction, serialize_and_filter_pdus
|
||||
from synapse.logging.opentracing import (
|
||||
extract_text_map,
|
||||
set_tag,
|
||||
@@ -119,7 +119,7 @@ class TransactionManager:
|
||||
transaction_id=txn_id,
|
||||
origin=self.server_name,
|
||||
destination=destination,
|
||||
pdus=[p.get_pdu_json() for p in pdus],
|
||||
pdus=serialize_and_filter_pdus(pdus),
|
||||
edus=[edu.get_dict() for edu in edus],
|
||||
)
|
||||
|
||||
|
||||
@@ -135,7 +135,7 @@ class PublicRoomList(BaseFederationServlet):
|
||||
if not self.allow_access:
|
||||
raise FederationDeniedError(origin)
|
||||
|
||||
limit = parse_integer_from_args(query, "limit", 0)
|
||||
limit: Optional[int] = parse_integer_from_args(query, "limit", 0)
|
||||
since_token = parse_string_from_args(query, "since", None)
|
||||
include_all_networks = parse_boolean_from_args(
|
||||
query, "include_all_networks", default=False
|
||||
|
||||
@@ -211,7 +211,7 @@ class SlidingSyncHandler:
|
||||
|
||||
Args:
|
||||
sync_config: Sync configuration
|
||||
to_token: The point in the stream to sync up to.
|
||||
to_token: The latest point in the stream to sync up to.
|
||||
from_token: The point in the stream to sync from. Token of the end of the
|
||||
previous batch. May be `None` if this is the initial sync request.
|
||||
"""
|
||||
|
||||
@@ -27,7 +27,7 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from typing_extensions import assert_never
|
||||
from typing_extensions import TypeAlias, assert_never
|
||||
|
||||
from synapse.api.constants import AccountDataTypes, EduTypes
|
||||
from synapse.handlers.receipts import ReceiptEventSource
|
||||
@@ -40,6 +40,7 @@ from synapse.types import (
|
||||
SlidingSyncStreamToken,
|
||||
StrCollection,
|
||||
StreamToken,
|
||||
ThreadSubscriptionsToken,
|
||||
)
|
||||
from synapse.types.handlers.sliding_sync import (
|
||||
HaveSentRoomFlag,
|
||||
@@ -54,6 +55,13 @@ from synapse.util.async_helpers import (
|
||||
gather_optional_coroutines,
|
||||
)
|
||||
|
||||
_ThreadSubscription: TypeAlias = (
|
||||
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadSubscription
|
||||
)
|
||||
_ThreadUnsubscription: TypeAlias = (
|
||||
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadUnsubscription
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
@@ -68,6 +76,7 @@ class SlidingSyncExtensionHandler:
|
||||
self.event_sources = hs.get_event_sources()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
self.push_rules_handler = hs.get_push_rules_handler()
|
||||
self._enable_thread_subscriptions = hs.config.experimental.msc4306_enabled
|
||||
|
||||
@trace
|
||||
async def get_extensions_response(
|
||||
@@ -93,7 +102,7 @@ class SlidingSyncExtensionHandler:
|
||||
actual_room_ids: The actual room IDs in the the Sliding Sync response.
|
||||
actual_room_response_map: A map of room ID to room results in the the
|
||||
Sliding Sync response.
|
||||
to_token: The point in the stream to sync up to.
|
||||
to_token: The latest point in the stream to sync up to.
|
||||
from_token: The point in the stream to sync from.
|
||||
"""
|
||||
|
||||
@@ -156,18 +165,32 @@ class SlidingSyncExtensionHandler:
|
||||
from_token=from_token,
|
||||
)
|
||||
|
||||
thread_subs_coro = None
|
||||
if (
|
||||
sync_config.extensions.thread_subscriptions is not None
|
||||
and self._enable_thread_subscriptions
|
||||
):
|
||||
thread_subs_coro = self.get_thread_subscriptions_extension_response(
|
||||
sync_config=sync_config,
|
||||
thread_subscriptions_request=sync_config.extensions.thread_subscriptions,
|
||||
to_token=to_token,
|
||||
from_token=from_token,
|
||||
)
|
||||
|
||||
(
|
||||
to_device_response,
|
||||
e2ee_response,
|
||||
account_data_response,
|
||||
receipts_response,
|
||||
typing_response,
|
||||
thread_subs_response,
|
||||
) = await gather_optional_coroutines(
|
||||
to_device_coro,
|
||||
e2ee_coro,
|
||||
account_data_coro,
|
||||
receipts_coro,
|
||||
typing_coro,
|
||||
thread_subs_coro,
|
||||
)
|
||||
|
||||
return SlidingSyncResult.Extensions(
|
||||
@@ -176,6 +199,7 @@ class SlidingSyncExtensionHandler:
|
||||
account_data=account_data_response,
|
||||
receipts=receipts_response,
|
||||
typing=typing_response,
|
||||
thread_subscriptions=thread_subs_response,
|
||||
)
|
||||
|
||||
def find_relevant_room_ids_for_extension(
|
||||
@@ -877,3 +901,72 @@ class SlidingSyncExtensionHandler:
|
||||
return SlidingSyncResult.Extensions.TypingExtension(
|
||||
room_id_to_typing_map=room_id_to_typing_map,
|
||||
)
|
||||
|
||||
async def get_thread_subscriptions_extension_response(
|
||||
self,
|
||||
sync_config: SlidingSyncConfig,
|
||||
thread_subscriptions_request: SlidingSyncConfig.Extensions.ThreadSubscriptionsExtension,
|
||||
to_token: StreamToken,
|
||||
from_token: Optional[SlidingSyncStreamToken],
|
||||
) -> Optional[SlidingSyncResult.Extensions.ThreadSubscriptionsExtension]:
|
||||
"""Handle Thread Subscriptions extension (MSC4308)
|
||||
|
||||
Args:
|
||||
sync_config: Sync configuration
|
||||
thread_subscriptions_request: The thread_subscriptions extension from the request
|
||||
to_token: The point in the stream to sync up to.
|
||||
from_token: The point in the stream to sync from.
|
||||
|
||||
Returns:
|
||||
the response (None if empty or thread subscriptions are disabled)
|
||||
"""
|
||||
if not thread_subscriptions_request.enabled:
|
||||
return None
|
||||
|
||||
limit = thread_subscriptions_request.limit
|
||||
|
||||
if from_token:
|
||||
from_stream_id = from_token.stream_token.thread_subscriptions_key
|
||||
else:
|
||||
from_stream_id = StreamToken.START.thread_subscriptions_key
|
||||
|
||||
to_stream_id = to_token.thread_subscriptions_key
|
||||
|
||||
updates = await self.store.get_latest_updated_thread_subscriptions_for_user(
|
||||
user_id=sync_config.user.to_string(),
|
||||
from_id=from_stream_id,
|
||||
to_id=to_stream_id,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
if len(updates) == 0:
|
||||
return None
|
||||
|
||||
subscribed_threads: Dict[str, Dict[str, _ThreadSubscription]] = {}
|
||||
unsubscribed_threads: Dict[str, Dict[str, _ThreadUnsubscription]] = {}
|
||||
for stream_id, room_id, thread_root_id, subscribed, automatic in updates:
|
||||
if subscribed:
|
||||
subscribed_threads.setdefault(room_id, {})[thread_root_id] = (
|
||||
_ThreadSubscription(
|
||||
automatic=automatic,
|
||||
bump_stamp=stream_id,
|
||||
)
|
||||
)
|
||||
else:
|
||||
unsubscribed_threads.setdefault(room_id, {})[thread_root_id] = (
|
||||
_ThreadUnsubscription(bump_stamp=stream_id)
|
||||
)
|
||||
|
||||
prev_batch = None
|
||||
if len(updates) == limit:
|
||||
# Tell the client about a potential gap where there may be more
|
||||
# thread subscriptions for it to backpaginate.
|
||||
# We subtract one because the 'later in the stream' bound is inclusive,
|
||||
# and we already saw the element at index 0.
|
||||
prev_batch = ThreadSubscriptionsToken(updates[0][0] - 1)
|
||||
|
||||
return SlidingSyncResult.Extensions.ThreadSubscriptionsExtension(
|
||||
subscribed=subscribed_threads,
|
||||
unsubscribed=unsubscribed_threads,
|
||||
prev_batch=prev_batch,
|
||||
)
|
||||
|
||||
@@ -9,7 +9,7 @@ from synapse.storage.databases.main.thread_subscriptions import (
|
||||
AutomaticSubscriptionConflicted,
|
||||
ThreadSubscription,
|
||||
)
|
||||
from synapse.types import EventOrderings, UserID
|
||||
from synapse.types import EventOrderings, StreamKeyType, UserID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
@@ -22,6 +22,7 @@ class ThreadSubscriptionsHandler:
|
||||
self.store = hs.get_datastores().main
|
||||
self.event_handler = hs.get_event_handler()
|
||||
self.auth = hs.get_auth()
|
||||
self._notifier = hs.get_notifier()
|
||||
|
||||
async def get_thread_subscription_settings(
|
||||
self,
|
||||
@@ -132,6 +133,15 @@ class ThreadSubscriptionsHandler:
|
||||
errcode=Codes.MSC4306_CONFLICTING_UNSUBSCRIPTION,
|
||||
)
|
||||
|
||||
if outcome is not None:
|
||||
# wake up user streams (e.g. sliding sync) on the same worker
|
||||
self._notifier.on_new_event(
|
||||
StreamKeyType.THREAD_SUBSCRIPTIONS,
|
||||
# outcome is a stream_id
|
||||
outcome,
|
||||
users=[user_id.to_string()],
|
||||
)
|
||||
|
||||
return outcome
|
||||
|
||||
async def unsubscribe_user_from_thread(
|
||||
@@ -162,8 +172,19 @@ class ThreadSubscriptionsHandler:
|
||||
logger.info("rejecting thread subscriptions change (thread not accessible)")
|
||||
raise NotFoundError("No such thread root")
|
||||
|
||||
return await self.store.unsubscribe_user_from_thread(
|
||||
outcome = await self.store.unsubscribe_user_from_thread(
|
||||
user_id.to_string(),
|
||||
event.room_id,
|
||||
thread_root_event_id,
|
||||
)
|
||||
|
||||
if outcome is not None:
|
||||
# wake up user streams (e.g. sliding sync) on the same worker
|
||||
self._notifier.on_new_event(
|
||||
StreamKeyType.THREAD_SUBSCRIPTIONS,
|
||||
# outcome is a stream_id
|
||||
outcome,
|
||||
users=[user_id.to_string()],
|
||||
)
|
||||
|
||||
return outcome
|
||||
|
||||
@@ -130,6 +130,16 @@ def parse_integer(
|
||||
return parse_integer_from_args(args, name, default, required, negative)
|
||||
|
||||
|
||||
@overload
|
||||
def parse_integer_from_args(
|
||||
args: Mapping[bytes, Sequence[bytes]],
|
||||
name: str,
|
||||
default: int,
|
||||
required: Literal[False] = False,
|
||||
negative: bool = False,
|
||||
) -> int: ...
|
||||
|
||||
|
||||
@overload
|
||||
def parse_integer_from_args(
|
||||
args: Mapping[bytes, Sequence[bytes]],
|
||||
|
||||
@@ -802,8 +802,9 @@ def run_in_background(
|
||||
deferred returned by the function completes.
|
||||
|
||||
To explain how the log contexts work here:
|
||||
- When this function is called, the current context is stored ("original"), we kick
|
||||
off the background task, and we restore that original context before returning
|
||||
- When `run_in_background` is called, the current context is stored ("original"),
|
||||
we kick off the background task in the current context, and we restore that
|
||||
original context before returning
|
||||
- When the background task finishes, we don't want to leak our context into the
|
||||
reactor which would erroneously get attached to the next operation picked up by
|
||||
the event loop. We add a callback to the deferred which will clear the logging
|
||||
@@ -828,6 +829,7 @@ def run_in_background(
|
||||
"""
|
||||
calling_context = current_context()
|
||||
try:
|
||||
# (kick off the task in the current context)
|
||||
res = f(*args, **kwargs)
|
||||
except Exception:
|
||||
# the assumption here is that the caller doesn't want to be disturbed
|
||||
|
||||
@@ -179,11 +179,13 @@ class MediaRepository:
|
||||
|
||||
# We get the media upload limits and sort them in descending order of
|
||||
# time period, so that we can apply some optimizations.
|
||||
self.media_upload_limits = hs.config.media.media_upload_limits
|
||||
self.media_upload_limits.sort(
|
||||
self.default_media_upload_limits = hs.config.media.media_upload_limits
|
||||
self.default_media_upload_limits.sort(
|
||||
key=lambda limit: limit.time_period_ms, reverse=True
|
||||
)
|
||||
|
||||
self.media_repository_callbacks = hs.get_module_api_callbacks().media_repository
|
||||
|
||||
def _start_update_recently_accessed(self) -> Deferred:
|
||||
return run_as_background_process(
|
||||
"update_recently_accessed_media",
|
||||
@@ -340,16 +342,27 @@ class MediaRepository:
|
||||
|
||||
# Check that the user has not exceeded any of the media upload limits.
|
||||
|
||||
# Use limits from module API if provided
|
||||
media_upload_limits = (
|
||||
await self.media_repository_callbacks.get_media_upload_limits_for_user(
|
||||
auth_user.to_string()
|
||||
)
|
||||
)
|
||||
|
||||
# Otherwise use the default limits from config
|
||||
if media_upload_limits is None:
|
||||
# Note: the media upload limits are sorted so larger time periods are
|
||||
# first.
|
||||
media_upload_limits = self.default_media_upload_limits
|
||||
|
||||
# This is the total size of media uploaded by the user in the last
|
||||
# `time_period_ms` milliseconds, or None if we haven't checked yet.
|
||||
uploaded_media_size: Optional[int] = None
|
||||
|
||||
# Note: the media upload limits are sorted so larger time periods are
|
||||
# first.
|
||||
for limit in self.media_upload_limits:
|
||||
for limit in media_upload_limits:
|
||||
# We only need to check the amount of media uploaded by the user in
|
||||
# this latest (smaller) time period if the amount of media uploaded
|
||||
# in a previous (larger) time period is above the limit.
|
||||
# in a previous (larger) time period is below the limit.
|
||||
#
|
||||
# This optimization means that in the common case where the user
|
||||
# hasn't uploaded much media, we only need to query the database
|
||||
@@ -363,6 +376,12 @@ class MediaRepository:
|
||||
)
|
||||
|
||||
if uploaded_media_size + content_length > limit.max_bytes:
|
||||
await self.media_repository_callbacks.on_media_upload_limit_exceeded(
|
||||
user_id=auth_user.to_string(),
|
||||
limit=limit,
|
||||
sent_bytes=uploaded_media_size,
|
||||
attempted_bytes=content_length,
|
||||
)
|
||||
raise SynapseError(
|
||||
400, "Media upload limit exceeded", Codes.RESOURCE_LIMIT_EXCEEDED
|
||||
)
|
||||
|
||||
@@ -286,9 +286,11 @@ def run_as_background_process(
|
||||
).dec()
|
||||
|
||||
# To explain how the log contexts work here:
|
||||
# - When this function is called, the current context is stored (using
|
||||
# `PreserveLoggingContext`), we kick off the background task, and we restore the
|
||||
# original context before returning (also part of `PreserveLoggingContext`).
|
||||
# - When `run_as_background_process` is called, the current context is stored
|
||||
# (using `PreserveLoggingContext`), we kick off the background task, and we
|
||||
# restore the original context before returning (also part of
|
||||
# `PreserveLoggingContext`).
|
||||
# - The background task runs in its own new logcontext named after `desc`
|
||||
# - When the background task finishes, we don't want to leak our background context
|
||||
# into the reactor which would erroneously get attached to the next operation
|
||||
# picked up by the event loop. We use `PreserveLoggingContext` to set the
|
||||
|
||||
@@ -50,6 +50,7 @@ from synapse.api.constants import ProfileFields
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.presence import UserPresenceState
|
||||
from synapse.config import ConfigError
|
||||
from synapse.config.repository import MediaUploadLimit
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.presence_router import (
|
||||
GET_INTERESTED_USERS_CALLBACK,
|
||||
@@ -94,7 +95,9 @@ from synapse.module_api.callbacks.account_validity_callbacks import (
|
||||
)
|
||||
from synapse.module_api.callbacks.media_repository_callbacks import (
|
||||
GET_MEDIA_CONFIG_FOR_USER_CALLBACK,
|
||||
GET_MEDIA_UPLOAD_LIMITS_FOR_USER_CALLBACK,
|
||||
IS_USER_ALLOWED_TO_UPLOAD_MEDIA_OF_SIZE_CALLBACK,
|
||||
ON_MEDIA_UPLOAD_LIMIT_EXCEEDED_CALLBACK,
|
||||
)
|
||||
from synapse.module_api.callbacks.ratelimit_callbacks import (
|
||||
GET_RATELIMIT_OVERRIDE_FOR_USER_CALLBACK,
|
||||
@@ -205,6 +208,7 @@ __all__ = [
|
||||
"RoomAlias",
|
||||
"UserProfile",
|
||||
"RatelimitOverride",
|
||||
"MediaUploadLimit",
|
||||
]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -462,6 +466,12 @@ class ModuleApi:
|
||||
is_user_allowed_to_upload_media_of_size: Optional[
|
||||
IS_USER_ALLOWED_TO_UPLOAD_MEDIA_OF_SIZE_CALLBACK
|
||||
] = None,
|
||||
get_media_upload_limits_for_user: Optional[
|
||||
GET_MEDIA_UPLOAD_LIMITS_FOR_USER_CALLBACK
|
||||
] = None,
|
||||
on_media_upload_limit_exceeded: Optional[
|
||||
ON_MEDIA_UPLOAD_LIMIT_EXCEEDED_CALLBACK
|
||||
] = None,
|
||||
) -> None:
|
||||
"""Registers callbacks for media repository capabilities.
|
||||
Added in Synapse v1.132.0.
|
||||
@@ -469,6 +479,8 @@ class ModuleApi:
|
||||
return self._callbacks.media_repository.register_callbacks(
|
||||
get_media_config_for_user=get_media_config_for_user,
|
||||
is_user_allowed_to_upload_media_of_size=is_user_allowed_to_upload_media_of_size,
|
||||
get_media_upload_limits_for_user=get_media_upload_limits_for_user,
|
||||
on_media_upload_limit_exceeded=on_media_upload_limit_exceeded,
|
||||
)
|
||||
|
||||
def register_third_party_rules_callbacks(
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
|
||||
|
||||
from synapse.config.repository import MediaUploadLimit
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util.async_helpers import delay_cancellation
|
||||
from synapse.util.metrics import Measure
|
||||
@@ -28,6 +29,14 @@ GET_MEDIA_CONFIG_FOR_USER_CALLBACK = Callable[[str], Awaitable[Optional[JsonDict
|
||||
|
||||
IS_USER_ALLOWED_TO_UPLOAD_MEDIA_OF_SIZE_CALLBACK = Callable[[str, int], Awaitable[bool]]
|
||||
|
||||
GET_MEDIA_UPLOAD_LIMITS_FOR_USER_CALLBACK = Callable[
|
||||
[str], Awaitable[Optional[List[MediaUploadLimit]]]
|
||||
]
|
||||
|
||||
ON_MEDIA_UPLOAD_LIMIT_EXCEEDED_CALLBACK = Callable[
|
||||
[str, MediaUploadLimit, int, int], Awaitable[None]
|
||||
]
|
||||
|
||||
|
||||
class MediaRepositoryModuleApiCallbacks:
|
||||
def __init__(self, hs: "HomeServer") -> None:
|
||||
@@ -39,6 +48,12 @@ class MediaRepositoryModuleApiCallbacks:
|
||||
self._is_user_allowed_to_upload_media_of_size_callbacks: List[
|
||||
IS_USER_ALLOWED_TO_UPLOAD_MEDIA_OF_SIZE_CALLBACK
|
||||
] = []
|
||||
self._get_media_upload_limits_for_user_callbacks: List[
|
||||
GET_MEDIA_UPLOAD_LIMITS_FOR_USER_CALLBACK
|
||||
] = []
|
||||
self._on_media_upload_limit_exceeded_callbacks: List[
|
||||
ON_MEDIA_UPLOAD_LIMIT_EXCEEDED_CALLBACK
|
||||
] = []
|
||||
|
||||
def register_callbacks(
|
||||
self,
|
||||
@@ -46,6 +61,12 @@ class MediaRepositoryModuleApiCallbacks:
|
||||
is_user_allowed_to_upload_media_of_size: Optional[
|
||||
IS_USER_ALLOWED_TO_UPLOAD_MEDIA_OF_SIZE_CALLBACK
|
||||
] = None,
|
||||
get_media_upload_limits_for_user: Optional[
|
||||
GET_MEDIA_UPLOAD_LIMITS_FOR_USER_CALLBACK
|
||||
] = None,
|
||||
on_media_upload_limit_exceeded: Optional[
|
||||
ON_MEDIA_UPLOAD_LIMIT_EXCEEDED_CALLBACK
|
||||
] = None,
|
||||
) -> None:
|
||||
"""Register callbacks from module for each hook."""
|
||||
if get_media_config_for_user is not None:
|
||||
@@ -56,6 +77,16 @@ class MediaRepositoryModuleApiCallbacks:
|
||||
is_user_allowed_to_upload_media_of_size
|
||||
)
|
||||
|
||||
if get_media_upload_limits_for_user is not None:
|
||||
self._get_media_upload_limits_for_user_callbacks.append(
|
||||
get_media_upload_limits_for_user
|
||||
)
|
||||
|
||||
if on_media_upload_limit_exceeded is not None:
|
||||
self._on_media_upload_limit_exceeded_callbacks.append(
|
||||
on_media_upload_limit_exceeded
|
||||
)
|
||||
|
||||
async def get_media_config_for_user(self, user_id: str) -> Optional[JsonDict]:
|
||||
for callback in self._get_media_config_for_user_callbacks:
|
||||
with Measure(
|
||||
@@ -83,3 +114,47 @@ class MediaRepositoryModuleApiCallbacks:
|
||||
return res
|
||||
|
||||
return True
|
||||
|
||||
async def get_media_upload_limits_for_user(
|
||||
self, user_id: str
|
||||
) -> Optional[List[MediaUploadLimit]]:
|
||||
"""
|
||||
Get the first non-None list of MediaUploadLimits for the user from the registered callbacks.
|
||||
If a list is returned it will be sorted in descending order of duration.
|
||||
"""
|
||||
for callback in self._get_media_upload_limits_for_user_callbacks:
|
||||
with Measure(
|
||||
self.clock,
|
||||
name=f"{callback.__module__}.{callback.__qualname__}",
|
||||
server_name=self.server_name,
|
||||
):
|
||||
res: Optional[List[MediaUploadLimit]] = await delay_cancellation(
|
||||
callback(user_id)
|
||||
)
|
||||
if res is not None: # to allow [] to be returned meaning no limit
|
||||
# We sort them in descending order of time period
|
||||
res.sort(key=lambda limit: limit.time_period_ms, reverse=True)
|
||||
return res
|
||||
|
||||
return None
|
||||
|
||||
async def on_media_upload_limit_exceeded(
|
||||
self,
|
||||
user_id: str,
|
||||
limit: MediaUploadLimit,
|
||||
sent_bytes: int,
|
||||
attempted_bytes: int,
|
||||
) -> None:
|
||||
for callback in self._on_media_upload_limit_exceeded_callbacks:
|
||||
with Measure(
|
||||
self.clock,
|
||||
name=f"{callback.__module__}.{callback.__qualname__}",
|
||||
server_name=self.server_name,
|
||||
):
|
||||
# Use a copy of the data in case the module modifies it
|
||||
limit_copy = MediaUploadLimit(
|
||||
max_bytes=limit.max_bytes, time_period_ms=limit.time_period_ms
|
||||
)
|
||||
await delay_cancellation(
|
||||
callback(user_id, limit_copy, sent_bytes, attempted_bytes)
|
||||
)
|
||||
|
||||
@@ -532,6 +532,7 @@ class Notifier:
|
||||
StreamKeyType.TO_DEVICE,
|
||||
StreamKeyType.TYPING,
|
||||
StreamKeyType.UN_PARTIAL_STATED_ROOMS,
|
||||
StreamKeyType.THREAD_SUBSCRIPTIONS,
|
||||
],
|
||||
new_token: int,
|
||||
users: Optional[Collection[Union[str, UserID]]] = None,
|
||||
|
||||
@@ -44,6 +44,7 @@ from synapse.replication.tcp.streams import (
|
||||
UnPartialStatedEventStream,
|
||||
UnPartialStatedRoomStream,
|
||||
)
|
||||
from synapse.replication.tcp.streams._base import ThreadSubscriptionsStream
|
||||
from synapse.replication.tcp.streams.events import (
|
||||
EventsStream,
|
||||
EventsStreamEventRow,
|
||||
@@ -255,6 +256,12 @@ class ReplicationDataHandler:
|
||||
self._state_storage_controller.notify_event_un_partial_stated(
|
||||
row.event_id
|
||||
)
|
||||
elif stream_name == ThreadSubscriptionsStream.NAME:
|
||||
self.notifier.on_new_event(
|
||||
StreamKeyType.THREAD_SUBSCRIPTIONS,
|
||||
token,
|
||||
users=[row.user_id for row in rows],
|
||||
)
|
||||
|
||||
await self._presence_handler.process_replication_rows(
|
||||
stream_name, instance_name, token, rows
|
||||
|
||||
@@ -23,6 +23,8 @@ import logging
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.api.constants import AccountDataTypes, EduTypes, Membership, PresenceState
|
||||
from synapse.api.errors import Codes, StoreError, SynapseError
|
||||
from synapse.api.filtering import FilterCollection
|
||||
@@ -632,12 +634,21 @@ class SyncRestServlet(RestServlet):
|
||||
|
||||
class SlidingSyncRestServlet(RestServlet):
|
||||
"""
|
||||
API endpoint for MSC3575 Sliding Sync `/sync`. Allows for clients to request a
|
||||
API endpoint for MSC4186 Simplified Sliding Sync `/sync`, which was historically derived
|
||||
from MSC3575 (Sliding Sync; now abandoned). Allows for clients to request a
|
||||
subset (sliding window) of rooms, state, and timeline events (just what they need)
|
||||
in order to bootstrap quickly and subscribe to only what the client cares about.
|
||||
Because the client can specify what it cares about, we can respond quickly and skip
|
||||
all of the work we would normally have to do with a sync v2 response.
|
||||
|
||||
Extensions of various features are defined in:
|
||||
- to-device messaging (MSC3885)
|
||||
- end-to-end encryption (MSC3884)
|
||||
- typing notifications (MSC3961)
|
||||
- receipts (MSC3960)
|
||||
- account data (MSC3959)
|
||||
- thread subscriptions (MSC4308)
|
||||
|
||||
Request query parameters:
|
||||
timeout: How long to wait for new events in milliseconds.
|
||||
pos: Stream position token when asking for incremental deltas.
|
||||
@@ -1074,9 +1085,48 @@ class SlidingSyncRestServlet(RestServlet):
|
||||
"rooms": extensions.typing.room_id_to_typing_map,
|
||||
}
|
||||
|
||||
# excludes both None and falsy `thread_subscriptions`
|
||||
if extensions.thread_subscriptions:
|
||||
serialized_extensions["io.element.msc4308.thread_subscriptions"] = (
|
||||
_serialise_thread_subscriptions(extensions.thread_subscriptions)
|
||||
)
|
||||
|
||||
return serialized_extensions
|
||||
|
||||
|
||||
def _serialise_thread_subscriptions(
|
||||
thread_subscriptions: SlidingSyncResult.Extensions.ThreadSubscriptionsExtension,
|
||||
) -> JsonDict:
|
||||
out: JsonDict = {}
|
||||
|
||||
if thread_subscriptions.subscribed:
|
||||
out["subscribed"] = {
|
||||
room_id: {
|
||||
thread_root_id: attr.asdict(
|
||||
change, filter=lambda _attr, v: v is not None
|
||||
)
|
||||
for thread_root_id, change in room_threads.items()
|
||||
}
|
||||
for room_id, room_threads in thread_subscriptions.subscribed.items()
|
||||
}
|
||||
|
||||
if thread_subscriptions.unsubscribed:
|
||||
out["unsubscribed"] = {
|
||||
room_id: {
|
||||
thread_root_id: attr.asdict(
|
||||
change, filter=lambda _attr, v: v is not None
|
||||
)
|
||||
for thread_root_id, change in room_threads.items()
|
||||
}
|
||||
for room_id, room_threads in thread_subscriptions.unsubscribed.items()
|
||||
}
|
||||
|
||||
if thread_subscriptions.prev_batch:
|
||||
out["prev_batch"] = thread_subscriptions.prev_batch.to_string()
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
SyncRestServlet(hs).register(http_server)
|
||||
|
||||
|
||||
@@ -1,21 +1,39 @@
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple
|
||||
|
||||
import attr
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from synapse.api.errors import Codes, NotFoundError, SynapseError
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
parse_and_validate_json_object_from_request,
|
||||
parse_integer,
|
||||
parse_string,
|
||||
)
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.client._base import client_patterns
|
||||
from synapse.types import JsonDict, RoomID
|
||||
from synapse.types import (
|
||||
JsonDict,
|
||||
RoomID,
|
||||
SlidingSyncStreamToken,
|
||||
ThreadSubscriptionsToken,
|
||||
)
|
||||
from synapse.types.handlers.sliding_sync import SlidingSyncResult
|
||||
from synapse.types.rest import RequestBodyModel
|
||||
from synapse.util.pydantic_models import AnyEventId
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
_ThreadSubscription: TypeAlias = (
|
||||
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadSubscription
|
||||
)
|
||||
_ThreadUnsubscription: TypeAlias = (
|
||||
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadUnsubscription
|
||||
)
|
||||
|
||||
|
||||
class ThreadSubscriptionsRestServlet(RestServlet):
|
||||
PATTERNS = client_patterns(
|
||||
@@ -100,6 +118,130 @@ class ThreadSubscriptionsRestServlet(RestServlet):
|
||||
return HTTPStatus.OK, {}
|
||||
|
||||
|
||||
class ThreadSubscriptionsPaginationRestServlet(RestServlet):
|
||||
PATTERNS = client_patterns(
|
||||
"/io.element.msc4308/thread_subscriptions$",
|
||||
unstable=True,
|
||||
releases=(),
|
||||
)
|
||||
CATEGORY = "Thread Subscriptions requests (unstable)"
|
||||
|
||||
# Maximum number of thread subscriptions to return in one request.
|
||||
MAX_LIMIT = 512
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.auth = hs.get_auth()
|
||||
self.is_mine = hs.is_mine
|
||||
self.store = hs.get_datastores().main
|
||||
|
||||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
|
||||
limit = min(
|
||||
parse_integer(request, "limit", default=100, negative=False),
|
||||
ThreadSubscriptionsPaginationRestServlet.MAX_LIMIT,
|
||||
)
|
||||
from_end_opt = parse_string(request, "from", required=False)
|
||||
to_start_opt = parse_string(request, "to", required=False)
|
||||
_direction = parse_string(request, "dir", required=True, allowed_values=("b",))
|
||||
|
||||
if limit <= 0:
|
||||
# condition needed because `negative=False` still allows 0
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"limit must be greater than 0",
|
||||
errcode=Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
if from_end_opt is not None:
|
||||
try:
|
||||
# because of backwards pagination, the `from` token is actually the
|
||||
# bound closest to the end of the stream
|
||||
end_stream_id = ThreadSubscriptionsToken.from_string(
|
||||
from_end_opt
|
||||
).stream_id
|
||||
except ValueError:
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"`from` is not a valid token",
|
||||
errcode=Codes.INVALID_PARAM,
|
||||
)
|
||||
else:
|
||||
end_stream_id = self.store.get_max_thread_subscriptions_stream_id()
|
||||
|
||||
if to_start_opt is not None:
|
||||
# because of backwards pagination, the `to` token is actually the
|
||||
# bound closest to the start of the stream
|
||||
try:
|
||||
start_stream_id = ThreadSubscriptionsToken.from_string(
|
||||
to_start_opt
|
||||
).stream_id
|
||||
except ValueError:
|
||||
# we also accept sliding sync `pos` tokens on this parameter
|
||||
try:
|
||||
sliding_sync_pos = await SlidingSyncStreamToken.from_string(
|
||||
self.store, to_start_opt
|
||||
)
|
||||
start_stream_id = (
|
||||
sliding_sync_pos.stream_token.thread_subscriptions_key
|
||||
)
|
||||
except ValueError:
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"`to` is not a valid token",
|
||||
errcode=Codes.INVALID_PARAM,
|
||||
)
|
||||
else:
|
||||
# the start of time is ID 1; the lower bound is exclusive though
|
||||
start_stream_id = 0
|
||||
|
||||
subscriptions = (
|
||||
await self.store.get_latest_updated_thread_subscriptions_for_user(
|
||||
requester.user.to_string(),
|
||||
from_id=start_stream_id,
|
||||
to_id=end_stream_id,
|
||||
limit=limit,
|
||||
)
|
||||
)
|
||||
|
||||
subscribed_threads: Dict[str, Dict[str, JsonDict]] = {}
|
||||
unsubscribed_threads: Dict[str, Dict[str, JsonDict]] = {}
|
||||
for stream_id, room_id, thread_root_id, subscribed, automatic in subscriptions:
|
||||
if subscribed:
|
||||
subscribed_threads.setdefault(room_id, {})[thread_root_id] = (
|
||||
attr.asdict(
|
||||
_ThreadSubscription(
|
||||
automatic=automatic,
|
||||
bump_stamp=stream_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
unsubscribed_threads.setdefault(room_id, {})[thread_root_id] = (
|
||||
attr.asdict(_ThreadUnsubscription(bump_stamp=stream_id))
|
||||
)
|
||||
|
||||
result: JsonDict = {}
|
||||
if subscribed_threads:
|
||||
result["subscribed"] = subscribed_threads
|
||||
if unsubscribed_threads:
|
||||
result["unsubscribed"] = unsubscribed_threads
|
||||
|
||||
if len(subscriptions) == limit:
|
||||
# We hit the limit, so there might be more entries to return.
|
||||
# Generate a new token that has moved backwards, ready for the next
|
||||
# request.
|
||||
min_returned_stream_id, _, _, _, _ = subscriptions[0]
|
||||
result["end"] = ThreadSubscriptionsToken(
|
||||
# We subtract one because the 'later in the stream' bound is inclusive,
|
||||
# and we already saw the element at index 0.
|
||||
stream_id=min_returned_stream_id - 1
|
||||
).to_string()
|
||||
|
||||
return HTTPStatus.OK, result
|
||||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
if hs.config.experimental.msc4306_enabled:
|
||||
ThreadSubscriptionsRestServlet(hs).register(http_server)
|
||||
ThreadSubscriptionsPaginationRestServlet(hs).register(http_server)
|
||||
|
||||
@@ -53,7 +53,7 @@ from synapse.storage.databases.main.stream import (
|
||||
generate_pagination_where_clause,
|
||||
)
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.types import JsonDict, MultiWriterStreamToken, StreamKeyType, StreamToken
|
||||
from synapse.types import JsonDict, StreamKeyType, StreamToken
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -316,17 +316,8 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||
StreamKeyType.ROOM, next_key
|
||||
)
|
||||
else:
|
||||
next_token = StreamToken(
|
||||
room_key=next_key,
|
||||
presence_key=0,
|
||||
typing_key=0,
|
||||
receipt_key=MultiWriterStreamToken(stream=0),
|
||||
account_data_key=0,
|
||||
push_rules_key=0,
|
||||
to_device_key=0,
|
||||
device_list_key=MultiWriterStreamToken(stream=0),
|
||||
groups_key=0,
|
||||
un_partial_stated_rooms_key=0,
|
||||
next_token = StreamToken.START.copy_and_replace(
|
||||
StreamKeyType.ROOM, next_key
|
||||
)
|
||||
|
||||
return events[:limit], next_token
|
||||
|
||||
@@ -492,7 +492,7 @@ class PerConnectionStateDB:
|
||||
"""An equivalent to `PerConnectionState` that holds data in a format stored
|
||||
in the DB.
|
||||
|
||||
The principle difference is that the tokens for the different streams are
|
||||
The principal difference is that the tokens for the different streams are
|
||||
serialized to strings.
|
||||
|
||||
When persisting this *only* contains updates to the state.
|
||||
|
||||
@@ -505,6 +505,9 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
|
||||
"""
|
||||
return self._thread_subscriptions_id_gen.get_current_token()
|
||||
|
||||
def get_thread_subscriptions_stream_id_generator(self) -> MultiWriterIdGenerator:
|
||||
return self._thread_subscriptions_id_gen
|
||||
|
||||
async def get_updated_thread_subscriptions(
|
||||
self, *, from_id: int, to_id: int, limit: int
|
||||
) -> List[Tuple[int, str, str, str]]:
|
||||
@@ -538,34 +541,52 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
|
||||
get_updated_thread_subscriptions_txn,
|
||||
)
|
||||
|
||||
async def get_updated_thread_subscriptions_for_user(
|
||||
async def get_latest_updated_thread_subscriptions_for_user(
|
||||
self, user_id: str, *, from_id: int, to_id: int, limit: int
|
||||
) -> List[Tuple[int, str, str]]:
|
||||
"""Get updates to thread subscriptions for a specific user.
|
||||
) -> List[Tuple[int, str, str, bool, Optional[bool]]]:
|
||||
"""Get the latest updates to thread subscriptions for a specific user.
|
||||
|
||||
Args:
|
||||
user_id: The ID of the user
|
||||
from_id: The starting stream ID (exclusive)
|
||||
to_id: The ending stream ID (inclusive)
|
||||
limit: The maximum number of rows to return
|
||||
If there are too many rows to return, rows from the start (closer to `from_id`)
|
||||
will be omitted.
|
||||
|
||||
Returns:
|
||||
A list of (stream_id, room_id, thread_root_event_id) tuples.
|
||||
A list of (stream_id, room_id, thread_root_event_id, subscribed, automatic) tuples.
|
||||
The row with lowest `stream_id` is the first row.
|
||||
"""
|
||||
|
||||
def get_updated_thread_subscriptions_for_user_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> List[Tuple[int, str, str]]:
|
||||
) -> List[Tuple[int, str, str, bool, Optional[bool]]]:
|
||||
sql = """
|
||||
SELECT stream_id, room_id, event_id
|
||||
FROM thread_subscriptions
|
||||
WHERE user_id = ? AND ? < stream_id AND stream_id <= ?
|
||||
WITH the_updates AS (
|
||||
SELECT stream_id, room_id, event_id, subscribed, automatic
|
||||
FROM thread_subscriptions
|
||||
WHERE user_id = ? AND ? < stream_id AND stream_id <= ?
|
||||
ORDER BY stream_id DESC
|
||||
LIMIT ?
|
||||
)
|
||||
SELECT stream_id, room_id, event_id, subscribed, automatic
|
||||
FROM the_updates
|
||||
ORDER BY stream_id ASC
|
||||
LIMIT ?
|
||||
"""
|
||||
|
||||
txn.execute(sql, (user_id, from_id, to_id, limit))
|
||||
return [(row[0], row[1], row[2]) for row in txn]
|
||||
return [
|
||||
(
|
||||
stream_id,
|
||||
room_id,
|
||||
event_id,
|
||||
# SQLite integer to boolean conversions
|
||||
bool(subscribed),
|
||||
bool(automatic) if subscribed else None,
|
||||
)
|
||||
for (stream_id, room_id, event_id, subscribed, automatic) in txn
|
||||
]
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_updated_thread_subscriptions_for_user",
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
--
|
||||
-- This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
--
|
||||
-- Copyright (C) 2025 New Vector, Ltd
|
||||
--
|
||||
-- This program is free software: you can redistribute it and/or modify
|
||||
-- it under the terms of the GNU Affero General Public License as
|
||||
-- published by the Free Software Foundation, either version 3 of the
|
||||
-- License, or (at your option) any later version.
|
||||
--
|
||||
-- See the GNU Affero General Public License for more details:
|
||||
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
|
||||
-- Work around https://github.com/element-hq/synapse/issues/18712 by advancing the
|
||||
-- stream sequence.
|
||||
-- This makes last_value of the sequence point to a position that will not get later
|
||||
-- returned by nextval.
|
||||
-- (For blank thread subscription streams, this means last_value = 2, nextval() = 3 after this line.)
|
||||
SELECT nextval('thread_subscriptions_sequence');
|
||||
@@ -187,8 +187,12 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
||||
Warning: Streams using this generator start at ID 2, because ID 1 is always assumed
|
||||
to have been 'seen as persisted'.
|
||||
Unclear if this extant behaviour is desirable for some reason.
|
||||
When creating a new sequence for a new stream,
|
||||
it will be necessary to use `START WITH 2`.
|
||||
When creating a new sequence for a new stream, it will be necessary to advance it
|
||||
so that position 1 is consumed.
|
||||
DO NOT USE `START WITH 2` FOR THIS PURPOSE:
|
||||
see https://github.com/element-hq/synapse/issues/18712
|
||||
Instead, use `SELECT nextval('sequence_name');` immediately after the
|
||||
`CREATE SEQUENCE` statement.
|
||||
|
||||
Args:
|
||||
db_conn
|
||||
|
||||
@@ -33,7 +33,6 @@ from synapse.logging.opentracing import trace
|
||||
from synapse.streams import EventSource
|
||||
from synapse.types import (
|
||||
AbstractMultiWriterStreamToken,
|
||||
MultiWriterStreamToken,
|
||||
StreamKeyType,
|
||||
StreamToken,
|
||||
)
|
||||
@@ -84,6 +83,7 @@ class EventSources:
|
||||
un_partial_stated_rooms_key = self.store.get_un_partial_stated_rooms_token(
|
||||
self._instance_name
|
||||
)
|
||||
thread_subscriptions_key = self.store.get_max_thread_subscriptions_stream_id()
|
||||
|
||||
token = StreamToken(
|
||||
room_key=self.sources.room.get_current_key(),
|
||||
@@ -97,6 +97,7 @@ class EventSources:
|
||||
# Groups key is unused.
|
||||
groups_key=0,
|
||||
un_partial_stated_rooms_key=un_partial_stated_rooms_key,
|
||||
thread_subscriptions_key=thread_subscriptions_key,
|
||||
)
|
||||
return token
|
||||
|
||||
@@ -123,6 +124,7 @@ class EventSources:
|
||||
StreamKeyType.TO_DEVICE: self.store.get_to_device_id_generator(),
|
||||
StreamKeyType.DEVICE_LIST: self.store.get_device_stream_id_generator(),
|
||||
StreamKeyType.UN_PARTIAL_STATED_ROOMS: self.store.get_un_partial_stated_rooms_id_generator(),
|
||||
StreamKeyType.THREAD_SUBSCRIPTIONS: self.store.get_thread_subscriptions_stream_id_generator(),
|
||||
}
|
||||
|
||||
for _, key in StreamKeyType.__members__.items():
|
||||
@@ -195,16 +197,7 @@ class EventSources:
|
||||
Returns:
|
||||
The current token for pagination.
|
||||
"""
|
||||
token = StreamToken(
|
||||
room_key=await self.sources.room.get_current_key_for_room(room_id),
|
||||
presence_key=0,
|
||||
typing_key=0,
|
||||
receipt_key=MultiWriterStreamToken(stream=0),
|
||||
account_data_key=0,
|
||||
push_rules_key=0,
|
||||
to_device_key=0,
|
||||
device_list_key=MultiWriterStreamToken(stream=0),
|
||||
groups_key=0,
|
||||
un_partial_stated_rooms_key=0,
|
||||
return StreamToken.START.copy_and_replace(
|
||||
StreamKeyType.ROOM,
|
||||
await self.sources.room.get_current_key_for_room(room_id),
|
||||
)
|
||||
return token
|
||||
|
||||
@@ -996,6 +996,7 @@ class StreamKeyType(Enum):
|
||||
TO_DEVICE = "to_device_key"
|
||||
DEVICE_LIST = "device_list_key"
|
||||
UN_PARTIAL_STATED_ROOMS = "un_partial_stated_rooms_key"
|
||||
THREAD_SUBSCRIPTIONS = "thread_subscriptions_key"
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
@@ -1003,7 +1004,7 @@ class StreamToken:
|
||||
"""A collection of keys joined together by underscores in the following
|
||||
order and which represent the position in their respective streams.
|
||||
|
||||
ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1_379`
|
||||
ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1_379_4242`
|
||||
1. `room_key`: `s2633508` which is a `RoomStreamToken`
|
||||
- `RoomStreamToken`'s can also look like `t426-2633508` or `m56~2.58~3.59`
|
||||
- See the docstring for `RoomStreamToken` for more details.
|
||||
@@ -1016,6 +1017,7 @@ class StreamToken:
|
||||
8. `device_list_key`: `265584`
|
||||
9. `groups_key`: `1` (note that this key is now unused)
|
||||
10. `un_partial_stated_rooms_key`: `379`
|
||||
11. `thread_subscriptions_key`: 4242
|
||||
|
||||
You can see how many of these keys correspond to the various
|
||||
fields in a "/sync" response:
|
||||
@@ -1074,6 +1076,7 @@ class StreamToken:
|
||||
# Note that the groups key is no longer used and may have bogus values.
|
||||
groups_key: int
|
||||
un_partial_stated_rooms_key: int
|
||||
thread_subscriptions_key: int
|
||||
|
||||
_SEPARATOR = "_"
|
||||
START: ClassVar["StreamToken"]
|
||||
@@ -1101,6 +1104,7 @@ class StreamToken:
|
||||
device_list_key,
|
||||
groups_key,
|
||||
un_partial_stated_rooms_key,
|
||||
thread_subscriptions_key,
|
||||
) = keys
|
||||
|
||||
return cls(
|
||||
@@ -1116,6 +1120,7 @@ class StreamToken:
|
||||
),
|
||||
groups_key=int(groups_key),
|
||||
un_partial_stated_rooms_key=int(un_partial_stated_rooms_key),
|
||||
thread_subscriptions_key=int(thread_subscriptions_key),
|
||||
)
|
||||
except CancelledError:
|
||||
raise
|
||||
@@ -1138,6 +1143,7 @@ class StreamToken:
|
||||
# if additional tokens are added.
|
||||
str(self.groups_key),
|
||||
str(self.un_partial_stated_rooms_key),
|
||||
str(self.thread_subscriptions_key),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -1202,6 +1208,7 @@ class StreamToken:
|
||||
StreamKeyType.TO_DEVICE,
|
||||
StreamKeyType.TYPING,
|
||||
StreamKeyType.UN_PARTIAL_STATED_ROOMS,
|
||||
StreamKeyType.THREAD_SUBSCRIPTIONS,
|
||||
],
|
||||
) -> int: ...
|
||||
|
||||
@@ -1257,7 +1264,8 @@ class StreamToken:
|
||||
f"typing: {self.typing_key}, receipt: {self.receipt_key}, "
|
||||
f"account_data: {self.account_data_key}, push_rules: {self.push_rules_key}, "
|
||||
f"to_device: {self.to_device_key}, device_list: {self.device_list_key}, "
|
||||
f"groups: {self.groups_key}, un_partial_stated_rooms: {self.un_partial_stated_rooms_key})"
|
||||
f"groups: {self.groups_key}, un_partial_stated_rooms: {self.un_partial_stated_rooms_key},"
|
||||
f"thread_subscriptions: {self.thread_subscriptions_key})"
|
||||
)
|
||||
|
||||
|
||||
@@ -1272,6 +1280,7 @@ StreamToken.START = StreamToken(
|
||||
device_list_key=MultiWriterStreamToken(stream=0),
|
||||
groups_key=0,
|
||||
un_partial_stated_rooms_key=0,
|
||||
thread_subscriptions_key=0,
|
||||
)
|
||||
|
||||
|
||||
@@ -1318,6 +1327,27 @@ class SlidingSyncStreamToken:
|
||||
return f"{self.connection_position}/{stream_token_str}"
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class ThreadSubscriptionsToken:
|
||||
"""
|
||||
Token for a position in the thread subscriptions stream.
|
||||
|
||||
Format: `ts<stream_id>`
|
||||
"""
|
||||
|
||||
stream_id: int
|
||||
|
||||
@staticmethod
|
||||
def from_string(s: str) -> "ThreadSubscriptionsToken":
|
||||
if not s.startswith("ts"):
|
||||
raise ValueError("thread subscription token must start with `ts`")
|
||||
|
||||
return ThreadSubscriptionsToken(stream_id=int(s[2:]))
|
||||
|
||||
def to_string(self) -> str:
|
||||
return f"ts{self.stream_id}"
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class PersistedPosition:
|
||||
"""Position of a newly persisted row with instance that persisted it."""
|
||||
|
||||
@@ -50,6 +50,7 @@ from synapse.types import (
|
||||
SlidingSyncStreamToken,
|
||||
StrCollection,
|
||||
StreamToken,
|
||||
ThreadSubscriptionsToken,
|
||||
UserID,
|
||||
)
|
||||
from synapse.types.rest.client import SlidingSyncBody
|
||||
@@ -357,11 +358,50 @@ class SlidingSyncResult:
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.room_id_to_typing_map)
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class ThreadSubscriptionsExtension:
|
||||
"""The Thread Subscriptions extension (MSC4308)
|
||||
|
||||
Attributes:
|
||||
subscribed: map (room_id -> thread_root_id -> info) of new or changed subscriptions
|
||||
unsubscribed: map (room_id -> thread_root_id -> info) of new unsubscriptions
|
||||
prev_batch: if present, there is a gap and the client can use this token to backpaginate
|
||||
"""
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class ThreadSubscription:
|
||||
# always present when `subscribed`
|
||||
automatic: Optional[bool]
|
||||
|
||||
# the same as our stream_id; useful for clients to resolve
|
||||
# race conditions locally
|
||||
bump_stamp: int
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class ThreadUnsubscription:
|
||||
# the same as our stream_id; useful for clients to resolve
|
||||
# race conditions locally
|
||||
bump_stamp: int
|
||||
|
||||
# room_id -> event_id (of thread root) -> the subscription change
|
||||
subscribed: Optional[Mapping[str, Mapping[str, ThreadSubscription]]]
|
||||
# room_id -> event_id (of thread root) -> the unsubscription
|
||||
unsubscribed: Optional[Mapping[str, Mapping[str, ThreadUnsubscription]]]
|
||||
prev_batch: Optional[ThreadSubscriptionsToken]
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return (
|
||||
bool(self.subscribed)
|
||||
or bool(self.unsubscribed)
|
||||
or bool(self.prev_batch)
|
||||
)
|
||||
|
||||
to_device: Optional[ToDeviceExtension] = None
|
||||
e2ee: Optional[E2eeExtension] = None
|
||||
account_data: Optional[AccountDataExtension] = None
|
||||
receipts: Optional[ReceiptsExtension] = None
|
||||
typing: Optional[TypingExtension] = None
|
||||
thread_subscriptions: Optional[ThreadSubscriptionsExtension] = None
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(
|
||||
@@ -370,6 +410,7 @@ class SlidingSyncResult:
|
||||
or self.account_data
|
||||
or self.receipts
|
||||
or self.typing
|
||||
or self.thread_subscriptions
|
||||
)
|
||||
|
||||
next_pos: SlidingSyncStreamToken
|
||||
|
||||
@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from synapse._pydantic_compat import (
|
||||
Extra,
|
||||
Field,
|
||||
StrictBool,
|
||||
StrictInt,
|
||||
StrictStr,
|
||||
@@ -364,11 +365,25 @@ class SlidingSyncBody(RequestBodyModel):
|
||||
# Process all room subscriptions defined in the Room Subscription API. (This is the default.)
|
||||
rooms: Optional[List[StrictStr]] = ["*"]
|
||||
|
||||
class ThreadSubscriptionsExtension(RequestBodyModel):
|
||||
"""The Thread Subscriptions extension (MSC4308)
|
||||
|
||||
Attributes:
|
||||
enabled
|
||||
limit: maximum number of subscription changes to return (default 100)
|
||||
"""
|
||||
|
||||
enabled: Optional[StrictBool] = False
|
||||
limit: StrictInt = 100
|
||||
|
||||
to_device: Optional[ToDeviceExtension] = None
|
||||
e2ee: Optional[E2eeExtension] = None
|
||||
account_data: Optional[AccountDataExtension] = None
|
||||
receipts: Optional[ReceiptsExtension] = None
|
||||
typing: Optional[TypingExtension] = None
|
||||
thread_subscriptions: Optional[ThreadSubscriptionsExtension] = Field(
|
||||
alias="io.element.msc4308.thread_subscriptions"
|
||||
)
|
||||
|
||||
conn_id: Optional[StrictStr]
|
||||
|
||||
|
||||
@@ -347,6 +347,7 @@ T2 = TypeVar("T2")
|
||||
T3 = TypeVar("T3")
|
||||
T4 = TypeVar("T4")
|
||||
T5 = TypeVar("T5")
|
||||
T6 = TypeVar("T6")
|
||||
|
||||
|
||||
@overload
|
||||
@@ -461,6 +462,23 @@ async def gather_optional_coroutines(
|
||||
) -> Tuple[Optional[T1], Optional[T2], Optional[T3], Optional[T4], Optional[T5]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
async def gather_optional_coroutines(
|
||||
*coroutines: Unpack[
|
||||
Tuple[
|
||||
Optional[Coroutine[Any, Any, T1]],
|
||||
Optional[Coroutine[Any, Any, T2]],
|
||||
Optional[Coroutine[Any, Any, T3]],
|
||||
Optional[Coroutine[Any, Any, T4]],
|
||||
Optional[Coroutine[Any, Any, T5]],
|
||||
Optional[Coroutine[Any, Any, T6]],
|
||||
]
|
||||
],
|
||||
) -> Tuple[
|
||||
Optional[T1], Optional[T2], Optional[T3], Optional[T4], Optional[T5], Optional[T6]
|
||||
]: ...
|
||||
|
||||
|
||||
async def gather_optional_coroutines(
|
||||
*coroutines: Unpack[Tuple[Optional[Coroutine[Any, Any, T1]], ...]],
|
||||
) -> Tuple[Optional[T1], ...]:
|
||||
|
||||
@@ -29,6 +29,11 @@ import sys
|
||||
from types import FrameType, TracebackType
|
||||
from typing import NoReturn, Optional, Type
|
||||
|
||||
from synapse.logging.context import (
|
||||
LoggingContext,
|
||||
PreserveLoggingContext,
|
||||
)
|
||||
|
||||
|
||||
def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -> None:
|
||||
"""daemonize the current process
|
||||
@@ -64,8 +69,14 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -
|
||||
pid_fh.write(old_pid)
|
||||
sys.exit(1)
|
||||
|
||||
# Fork, creating a new process for the child.
|
||||
process_id = os.fork()
|
||||
# Stop the existing context *before* we fork the process. Otherwise the cputime
|
||||
# metrics get confused about the per-thread resource usage appearing to go backwards
|
||||
# because we're comparing the resource usage from the original process to the forked
|
||||
# process. `PreserveLoggingContext` already takes care of restarting the original
|
||||
# context *after* the block.
|
||||
with PreserveLoggingContext():
|
||||
# Fork, creating a new process for the child.
|
||||
process_id = os.fork()
|
||||
|
||||
if process_id != 0:
|
||||
# parent process: exit.
|
||||
@@ -140,9 +151,10 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -
|
||||
|
||||
# Cleanup pid file at exit.
|
||||
def exit() -> None:
|
||||
logger.warning("Stopping daemon.")
|
||||
os.remove(pid_file)
|
||||
sys.exit(0)
|
||||
with LoggingContext("atexit"):
|
||||
logger.warning("Stopping daemon.")
|
||||
os.remove(pid_file)
|
||||
sys.exit(0)
|
||||
|
||||
atexit.register(exit)
|
||||
|
||||
|
||||
@@ -37,4 +37,7 @@ class HomeserverAppStartTestCase(ConfigFileTestCase):
|
||||
self.add_lines_to_config([" main:", " host: 127.0.0.1", " port: 1234"])
|
||||
# Ensure that starting master process with worker config raises an exception
|
||||
with self.assertRaises(ConfigError):
|
||||
synapse.app.homeserver.setup(["-c", self.config_file])
|
||||
homeserver_config = synapse.app.homeserver.load_or_generate_config(
|
||||
["-c", self.config_file]
|
||||
)
|
||||
synapse.app.homeserver.setup(homeserver_config)
|
||||
|
||||
@@ -112,4 +112,7 @@ class RegistrationConfigTestCase(ConfigFileTestCase):
|
||||
|
||||
# Test that allowing open registration without verification raises an error
|
||||
with self.assertRaises(ConfigError):
|
||||
synapse.app.homeserver.setup(["-c", self.config_file])
|
||||
homeserver_config = synapse.app.homeserver.load_or_generate_config(
|
||||
["-c", self.config_file]
|
||||
)
|
||||
synapse.app.homeserver.setup(homeserver_config)
|
||||
|
||||
@@ -2244,7 +2244,7 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
def test_topo_token_is_accepted(self) -> None:
|
||||
"""Test Topo Token is accepted."""
|
||||
token = "t1-0_0_0_0_0_0_0_0_0_0"
|
||||
token = "t1-0_0_0_0_0_0_0_0_0_0_0"
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token),
|
||||
@@ -2258,7 +2258,7 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None:
|
||||
"""Test that stream token is accepted for forward pagination."""
|
||||
token = "s0_0_0_0_0_0_0_0_0_0"
|
||||
token = "s0_0_0_0_0_0_0_0_0_0_0"
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token),
|
||||
|
||||
@@ -0,0 +1,497 @@
|
||||
#
|
||||
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
#
|
||||
# Copyright (C) 2025 New Vector, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# See the GNU Affero General Public License for more details:
|
||||
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
#
|
||||
import logging
|
||||
from http import HTTPStatus
|
||||
from typing import List, Optional, Tuple, cast
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
import synapse.rest.admin
|
||||
from synapse.rest.client import login, room, sync, thread_subscriptions
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# The name of the extension. Currently unstable-prefixed.
|
||||
EXT_NAME = "io.element.msc4308.thread_subscriptions"
|
||||
|
||||
|
||||
class SlidingSyncThreadSubscriptionsExtensionTestCase(SlidingSyncBase):
|
||||
"""
|
||||
Test the thread subscriptions extension in the Sliding Sync API.
|
||||
"""
|
||||
|
||||
maxDiff = None
|
||||
|
||||
servlets = [
|
||||
synapse.rest.admin.register_servlets,
|
||||
login.register_servlets,
|
||||
room.register_servlets,
|
||||
sync.register_servlets,
|
||||
thread_subscriptions.register_servlets,
|
||||
]
|
||||
|
||||
def default_config(self) -> JsonDict:
|
||||
config = super().default_config()
|
||||
config["experimental_features"] = {"msc4306_enabled": True}
|
||||
return config
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.store = hs.get_datastores().main
|
||||
self.storage_controllers = hs.get_storage_controllers()
|
||||
super().prepare(reactor, clock, hs)
|
||||
|
||||
def test_no_data_initial_sync(self) -> None:
|
||||
"""
|
||||
Test enabling thread subscriptions extension during initial sync with no data.
|
||||
"""
|
||||
user1_id = self.register_user("user1", "pass")
|
||||
user1_tok = self.login(user1_id, "pass")
|
||||
sync_body = {
|
||||
"lists": {},
|
||||
"extensions": {
|
||||
EXT_NAME: {
|
||||
"enabled": True,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# Sync
|
||||
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
|
||||
|
||||
# Assert
|
||||
self.assertNotIn(EXT_NAME, response_body["extensions"])
|
||||
|
||||
def test_no_data_incremental_sync(self) -> None:
|
||||
"""
|
||||
Test enabling thread subscriptions extension during incremental sync with no data.
|
||||
"""
|
||||
user1_id = self.register_user("user1", "pass")
|
||||
user1_tok = self.login(user1_id, "pass")
|
||||
initial_sync_body: JsonDict = {
|
||||
"lists": {},
|
||||
}
|
||||
|
||||
# Initial sync
|
||||
response_body, sync_pos = self.do_sync(initial_sync_body, tok=user1_tok)
|
||||
|
||||
# Incremental sync with extension enabled
|
||||
sync_body = {
|
||||
"lists": {},
|
||||
"extensions": {
|
||||
EXT_NAME: {
|
||||
"enabled": True,
|
||||
}
|
||||
},
|
||||
}
|
||||
response_body, _ = self.do_sync(sync_body, tok=user1_tok, since=sync_pos)
|
||||
|
||||
# Assert
|
||||
self.assertNotIn(
|
||||
EXT_NAME,
|
||||
response_body["extensions"],
|
||||
response_body,
|
||||
)
|
||||
|
||||
def test_thread_subscription_initial_sync(self) -> None:
|
||||
"""
|
||||
Test thread subscriptions appear in initial sync response.
|
||||
"""
|
||||
user1_id = self.register_user("user1", "pass")
|
||||
user1_tok = self.login(user1_id, "pass")
|
||||
room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
|
||||
thread_root_resp = self.helper.send(room_id, body="Thread root", tok=user1_tok)
|
||||
thread_root_id = thread_root_resp["event_id"]
|
||||
|
||||
# get the baseline stream_id of the thread_subscriptions stream
|
||||
# before we write any data.
|
||||
# Required because the initial value differs between SQLite and Postgres.
|
||||
base = self.store.get_max_thread_subscriptions_stream_id()
|
||||
|
||||
self._subscribe_to_thread(user1_id, room_id, thread_root_id)
|
||||
sync_body = {
|
||||
"lists": {},
|
||||
"extensions": {
|
||||
EXT_NAME: {
|
||||
"enabled": True,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# Sync
|
||||
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
|
||||
|
||||
# Assert
|
||||
self.assertEqual(
|
||||
response_body["extensions"][EXT_NAME],
|
||||
{
|
||||
"subscribed": {
|
||||
room_id: {
|
||||
thread_root_id: {
|
||||
"automatic": False,
|
||||
"bump_stamp": base + 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
def test_thread_subscription_incremental_sync(self) -> None:
|
||||
"""
|
||||
Test new thread subscriptions appear in incremental sync response.
|
||||
"""
|
||||
user1_id = self.register_user("user1", "pass")
|
||||
user1_tok = self.login(user1_id, "pass")
|
||||
room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
|
||||
sync_body = {
|
||||
"lists": {},
|
||||
"extensions": {
|
||||
EXT_NAME: {
|
||||
"enabled": True,
|
||||
}
|
||||
},
|
||||
}
|
||||
thread_root_resp = self.helper.send(room_id, body="Thread root", tok=user1_tok)
|
||||
thread_root_id = thread_root_resp["event_id"]
|
||||
|
||||
# get the baseline stream_id of the thread_subscriptions stream
|
||||
# before we write any data.
|
||||
# Required because the initial value differs between SQLite and Postgres.
|
||||
base = self.store.get_max_thread_subscriptions_stream_id()
|
||||
|
||||
# Initial sync
|
||||
_, sync_pos = self.do_sync(sync_body, tok=user1_tok)
|
||||
logger.info("Synced to: %r, now subscribing to thread", sync_pos)
|
||||
|
||||
# Subscribe
|
||||
self._subscribe_to_thread(user1_id, room_id, thread_root_id)
|
||||
|
||||
# Incremental sync
|
||||
response_body, sync_pos = self.do_sync(sync_body, tok=user1_tok, since=sync_pos)
|
||||
logger.info("Synced to: %r", sync_pos)
|
||||
|
||||
# Assert
|
||||
self.assertEqual(
|
||||
response_body["extensions"][EXT_NAME],
|
||||
{
|
||||
"subscribed": {
|
||||
room_id: {
|
||||
thread_root_id: {
|
||||
"automatic": False,
|
||||
"bump_stamp": base + 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
def test_unsubscribe_from_thread(self) -> None:
|
||||
"""
|
||||
Test unsubscribing from a thread.
|
||||
"""
|
||||
user1_id = self.register_user("user1", "pass")
|
||||
user1_tok = self.login(user1_id, "pass")
|
||||
room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
|
||||
thread_root_resp = self.helper.send(room_id, body="Thread root", tok=user1_tok)
|
||||
thread_root_id = thread_root_resp["event_id"]
|
||||
|
||||
# get the baseline stream_id of the thread_subscriptions stream
|
||||
# before we write any data.
|
||||
# Required because the initial value differs between SQLite and Postgres.
|
||||
base = self.store.get_max_thread_subscriptions_stream_id()
|
||||
|
||||
self._subscribe_to_thread(user1_id, room_id, thread_root_id)
|
||||
sync_body = {
|
||||
"lists": {},
|
||||
"extensions": {
|
||||
EXT_NAME: {
|
||||
"enabled": True,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
response_body, sync_pos = self.do_sync(sync_body, tok=user1_tok)
|
||||
|
||||
# Assert: Subscription present
|
||||
self.assertIn(EXT_NAME, response_body["extensions"])
|
||||
self.assertEqual(
|
||||
response_body["extensions"][EXT_NAME],
|
||||
{
|
||||
"subscribed": {
|
||||
room_id: {
|
||||
thread_root_id: {"automatic": False, "bump_stamp": base + 1}
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Unsubscribe
|
||||
self._unsubscribe_from_thread(user1_id, room_id, thread_root_id)
|
||||
|
||||
# Incremental sync
|
||||
response_body, sync_pos = self.do_sync(sync_body, tok=user1_tok, since=sync_pos)
|
||||
|
||||
# Assert: Unsubscription present
|
||||
self.assertEqual(
|
||||
response_body["extensions"][EXT_NAME],
|
||||
{"unsubscribed": {room_id: {thread_root_id: {"bump_stamp": base + 2}}}},
|
||||
)
|
||||
|
||||
def test_multiple_thread_subscriptions(self) -> None:
|
||||
"""
|
||||
Test handling of multiple thread subscriptions.
|
||||
"""
|
||||
user1_id = self.register_user("user1", "pass")
|
||||
user1_tok = self.login(user1_id, "pass")
|
||||
room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
|
||||
|
||||
# Create thread roots
|
||||
thread_root_resp1 = self.helper.send(
|
||||
room_id, body="Thread root 1", tok=user1_tok
|
||||
)
|
||||
thread_root_id1 = thread_root_resp1["event_id"]
|
||||
thread_root_resp2 = self.helper.send(
|
||||
room_id, body="Thread root 2", tok=user1_tok
|
||||
)
|
||||
thread_root_id2 = thread_root_resp2["event_id"]
|
||||
thread_root_resp3 = self.helper.send(
|
||||
room_id, body="Thread root 3", tok=user1_tok
|
||||
)
|
||||
thread_root_id3 = thread_root_resp3["event_id"]
|
||||
|
||||
# get the baseline stream_id of the thread_subscriptions stream
|
||||
# before we write any data.
|
||||
# Required because the initial value differs between SQLite and Postgres.
|
||||
base = self.store.get_max_thread_subscriptions_stream_id()
|
||||
|
||||
# Subscribe to threads
|
||||
self._subscribe_to_thread(user1_id, room_id, thread_root_id1)
|
||||
self._subscribe_to_thread(user1_id, room_id, thread_root_id2)
|
||||
self._subscribe_to_thread(user1_id, room_id, thread_root_id3)
|
||||
|
||||
sync_body = {
|
||||
"lists": {},
|
||||
"extensions": {
|
||||
EXT_NAME: {
|
||||
"enabled": True,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# Sync
|
||||
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
|
||||
|
||||
# Assert
|
||||
self.assertEqual(
|
||||
response_body["extensions"][EXT_NAME],
|
||||
{
|
||||
"subscribed": {
|
||||
room_id: {
|
||||
thread_root_id1: {
|
||||
"automatic": False,
|
||||
"bump_stamp": base + 1,
|
||||
},
|
||||
thread_root_id2: {
|
||||
"automatic": False,
|
||||
"bump_stamp": base + 2,
|
||||
},
|
||||
thread_root_id3: {
|
||||
"automatic": False,
|
||||
"bump_stamp": base + 3,
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
def test_limit_parameter(self) -> None:
|
||||
"""
|
||||
Test limit parameter in thread subscriptions extension.
|
||||
"""
|
||||
user1_id = self.register_user("user1", "pass")
|
||||
user1_tok = self.login(user1_id, "pass")
|
||||
room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
|
||||
|
||||
# Create 5 thread roots and subscribe to each
|
||||
thread_root_ids = []
|
||||
for i in range(5):
|
||||
thread_root_resp = self.helper.send(
|
||||
room_id, body=f"Thread root {i}", tok=user1_tok
|
||||
)
|
||||
thread_root_ids.append(thread_root_resp["event_id"])
|
||||
self._subscribe_to_thread(user1_id, room_id, thread_root_ids[-1])
|
||||
|
||||
sync_body = {
|
||||
"lists": {},
|
||||
"extensions": {EXT_NAME: {"enabled": True, "limit": 3}},
|
||||
}
|
||||
|
||||
# Sync
|
||||
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
|
||||
|
||||
# Assert
|
||||
thread_subscriptions = response_body["extensions"][EXT_NAME]
|
||||
self.assertEqual(
|
||||
len(thread_subscriptions["subscribed"][room_id]), 3, thread_subscriptions
|
||||
)
|
||||
|
||||
def test_limit_and_companion_backpagination(self) -> None:
|
||||
"""
|
||||
Create 1 thread subscription, do a sync, create 4 more,
|
||||
then sync with a limit of 2 and fill in the gap
|
||||
using the companion /thread_subscriptions endpoint.
|
||||
"""
|
||||
|
||||
thread_root_ids: List[str] = []
|
||||
|
||||
def make_subscription() -> None:
|
||||
thread_root_resp = self.helper.send(
|
||||
room_id, body="Some thread root", tok=user1_tok
|
||||
)
|
||||
thread_root_ids.append(thread_root_resp["event_id"])
|
||||
self._subscribe_to_thread(user1_id, room_id, thread_root_ids[-1])
|
||||
|
||||
user1_id = self.register_user("user1", "pass")
|
||||
user1_tok = self.login(user1_id, "pass")
|
||||
room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
|
||||
|
||||
# get the baseline stream_id of the thread_subscriptions stream
|
||||
# before we write any data.
|
||||
# Required because the initial value differs between SQLite and Postgres.
|
||||
base = self.store.get_max_thread_subscriptions_stream_id()
|
||||
|
||||
# Make our first subscription
|
||||
make_subscription()
|
||||
|
||||
# Sync for the first time
|
||||
sync_body = {
|
||||
"lists": {},
|
||||
"extensions": {EXT_NAME: {"enabled": True, "limit": 2}},
|
||||
}
|
||||
|
||||
sync_resp, first_sync_pos = self.do_sync(sync_body, tok=user1_tok)
|
||||
|
||||
thread_subscriptions = sync_resp["extensions"][EXT_NAME]
|
||||
self.assertEqual(
|
||||
thread_subscriptions["subscribed"],
|
||||
{
|
||||
room_id: {
|
||||
thread_root_ids[0]: {"automatic": False, "bump_stamp": base + 1},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Get our pos for the next sync
|
||||
first_sync_pos = sync_resp["pos"]
|
||||
|
||||
# Create 5 more thread subscriptions and subscribe to each
|
||||
for _ in range(5):
|
||||
make_subscription()
|
||||
|
||||
# Now sync again. Our limit is 2,
|
||||
# so we should get the latest 2 subscriptions,
|
||||
# with a gap of 3 more subscriptions in the middle
|
||||
sync_resp, _pos = self.do_sync(sync_body, tok=user1_tok, since=first_sync_pos)
|
||||
|
||||
thread_subscriptions = sync_resp["extensions"][EXT_NAME]
|
||||
self.assertEqual(
|
||||
thread_subscriptions["subscribed"],
|
||||
{
|
||||
room_id: {
|
||||
thread_root_ids[4]: {"automatic": False, "bump_stamp": base + 5},
|
||||
thread_root_ids[5]: {"automatic": False, "bump_stamp": base + 6},
|
||||
}
|
||||
},
|
||||
)
|
||||
# 1st backpagination: expecting a page with 2 subscriptions
|
||||
page, end_tok = self._do_backpaginate(
|
||||
from_tok=thread_subscriptions["prev_batch"],
|
||||
to_tok=first_sync_pos,
|
||||
limit=2,
|
||||
access_token=user1_tok,
|
||||
)
|
||||
self.assertIsNotNone(end_tok, "backpagination should continue")
|
||||
self.assertEqual(
|
||||
page["subscribed"],
|
||||
{
|
||||
room_id: {
|
||||
thread_root_ids[2]: {"automatic": False, "bump_stamp": base + 3},
|
||||
thread_root_ids[3]: {"automatic": False, "bump_stamp": base + 4},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# 2nd backpagination: expecting a page with only 1 subscription
|
||||
# and no other token for further backpagination
|
||||
assert end_tok is not None
|
||||
page, end_tok = self._do_backpaginate(
|
||||
from_tok=end_tok, to_tok=first_sync_pos, limit=2, access_token=user1_tok
|
||||
)
|
||||
self.assertIsNone(end_tok, "backpagination should have finished")
|
||||
self.assertEqual(
|
||||
page["subscribed"],
|
||||
{
|
||||
room_id: {
|
||||
thread_root_ids[1]: {"automatic": False, "bump_stamp": base + 2},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
def _do_backpaginate(
|
||||
self, *, from_tok: str, to_tok: str, limit: int, access_token: str
|
||||
) -> Tuple[JsonDict, Optional[str]]:
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/client/unstable/io.element.msc4308/thread_subscriptions"
|
||||
f"?from={from_tok}&to={to_tok}&limit={limit}&dir=b",
|
||||
access_token=access_token,
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
|
||||
body = channel.json_body
|
||||
return body, cast(Optional[str], body.get("end"))
|
||||
|
||||
def _subscribe_to_thread(
|
||||
self, user_id: str, room_id: str, thread_root_id: str
|
||||
) -> None:
|
||||
"""
|
||||
Helper method to subscribe a user to a thread.
|
||||
"""
|
||||
self.get_success(
|
||||
self.store.subscribe_user_to_thread(
|
||||
user_id=user_id,
|
||||
room_id=room_id,
|
||||
thread_root_event_id=thread_root_id,
|
||||
automatic_event_orderings=None,
|
||||
)
|
||||
)
|
||||
|
||||
def _unsubscribe_from_thread(
|
||||
self, user_id: str, room_id: str, thread_root_id: str
|
||||
) -> None:
|
||||
"""
|
||||
Helper method to unsubscribe a user from a thread.
|
||||
"""
|
||||
self.get_success(
|
||||
self.store.unsubscribe_user_from_thread(
|
||||
user_id=user_id,
|
||||
room_id=room_id,
|
||||
thread_root_event_id=thread_root_id,
|
||||
)
|
||||
)
|
||||
@@ -46,6 +46,7 @@ from twisted.web.resource import Resource
|
||||
|
||||
from synapse.api.errors import HttpResponseException
|
||||
from synapse.api.ratelimiting import Ratelimiter
|
||||
from synapse.config._base import Config
|
||||
from synapse.config.oembed import OEmbedEndpointConfig
|
||||
from synapse.http.client import MultipartResponse
|
||||
from synapse.http.types import QueryParams
|
||||
@@ -53,6 +54,7 @@ from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.media._base import FileInfo, ThumbnailInfo
|
||||
from synapse.media.thumbnailer import ThumbnailProvider
|
||||
from synapse.media.url_previewer import IMAGE_CACHE_EXPIRY_MS
|
||||
from synapse.module_api import MediaUploadLimit
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client import login, media
|
||||
from synapse.server import HomeServer
|
||||
@@ -2967,3 +2969,192 @@ class MediaUploadLimits(unittest.HomeserverTestCase):
|
||||
# This will succeed as the weekly limit has reset
|
||||
channel = self.upload_media(900)
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
|
||||
class MediaUploadLimitsModuleOverrides(unittest.HomeserverTestCase):
|
||||
"""
|
||||
This test case simulates a homeserver with media upload limits being overridden by the module API.
|
||||
"""
|
||||
|
||||
servlets = [
|
||||
media.register_servlets,
|
||||
login.register_servlets,
|
||||
admin.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
config = self.default_config()
|
||||
|
||||
self.storage_path = self.mktemp()
|
||||
self.media_store_path = self.mktemp()
|
||||
os.mkdir(self.storage_path)
|
||||
os.mkdir(self.media_store_path)
|
||||
config["media_store_path"] = self.media_store_path
|
||||
|
||||
provider_config = {
|
||||
"module": "synapse.media.storage_provider.FileStorageProviderBackend",
|
||||
"store_local": True,
|
||||
"store_synchronous": False,
|
||||
"store_remote": True,
|
||||
"config": {"directory": self.storage_path},
|
||||
}
|
||||
|
||||
config["media_storage_providers"] = [provider_config]
|
||||
|
||||
# default limits to use
|
||||
config["media_upload_limits"] = [
|
||||
{"time_period": "1d", "max_size": "1K"},
|
||||
{"time_period": "1w", "max_size": "3K"},
|
||||
]
|
||||
|
||||
return self.setup_test_homeserver(config=config)
|
||||
|
||||
async def _get_media_upload_limits_for_user(
|
||||
self,
|
||||
user_id: str,
|
||||
) -> Optional[List[MediaUploadLimit]]:
|
||||
# user1 has custom limits
|
||||
if user_id == self.user1:
|
||||
# n.b. we return these in increasing duration order and Synapse will need to sort them correctly
|
||||
return [
|
||||
MediaUploadLimit(
|
||||
time_period_ms=Config.parse_duration("1d"), max_bytes=5000
|
||||
),
|
||||
MediaUploadLimit(
|
||||
time_period_ms=Config.parse_duration("1w"), max_bytes=15000
|
||||
),
|
||||
]
|
||||
# user2 has no limits
|
||||
if user_id == self.user2:
|
||||
return []
|
||||
# otherwise use default
|
||||
return None
|
||||
|
||||
async def _on_media_upload_limit_exceeded(
|
||||
self,
|
||||
user_id: str,
|
||||
limit: MediaUploadLimit,
|
||||
sent_bytes: int,
|
||||
attempted_bytes: int,
|
||||
) -> None:
|
||||
self.last_media_upload_limit_exceeded: Optional[dict[str, object]] = {
|
||||
"user_id": user_id,
|
||||
"limit": limit,
|
||||
"sent_bytes": sent_bytes,
|
||||
"attempted_bytes": attempted_bytes,
|
||||
}
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.repo = hs.get_media_repository()
|
||||
self.client = hs.get_federation_http_client()
|
||||
self.store = hs.get_datastores().main
|
||||
self.user1 = self.register_user("user1", "pass")
|
||||
self.tok1 = self.login("user1", "pass")
|
||||
self.user2 = self.register_user("user2", "pass")
|
||||
self.tok2 = self.login("user2", "pass")
|
||||
self.user3 = self.register_user("user3", "pass")
|
||||
self.tok3 = self.login("user3", "pass")
|
||||
self.last_media_upload_limit_exceeded = None
|
||||
self.hs.get_module_api().register_media_repository_callbacks(
|
||||
get_media_upload_limits_for_user=self._get_media_upload_limits_for_user,
|
||||
on_media_upload_limit_exceeded=self._on_media_upload_limit_exceeded,
|
||||
)
|
||||
|
||||
def create_resource_dict(self) -> Dict[str, Resource]:
|
||||
resources = super().create_resource_dict()
|
||||
resources["/_matrix/media"] = self.hs.get_media_repository_resource()
|
||||
return resources
|
||||
|
||||
def upload_media(self, size: int, tok: str) -> FakeChannel:
|
||||
"""Helper to upload media of a given size with a given token."""
|
||||
return self.make_request(
|
||||
"POST",
|
||||
"/_matrix/media/v3/upload",
|
||||
content=b"0" * size,
|
||||
access_token=tok,
|
||||
shorthand=False,
|
||||
content_type=b"text/plain",
|
||||
custom_headers=[("Content-Length", str(size))],
|
||||
)
|
||||
|
||||
def test_upload_under_limit(self) -> None:
|
||||
"""Test that uploading media under the limit works."""
|
||||
|
||||
# User 1 uploads 100 bytes
|
||||
channel = self.upload_media(100, self.tok1)
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
# User 2 (unlimited) uploads 100 bytes
|
||||
channel = self.upload_media(100, self.tok2)
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
# User 3 (default) uploads 100 bytes
|
||||
channel = self.upload_media(100, self.tok3)
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
self.assertEqual(self.last_media_upload_limit_exceeded, None)
|
||||
|
||||
def test_uses_custom_limit(self) -> None:
|
||||
"""Test that uploading media over the module provided daily limit fails."""
|
||||
|
||||
# User 1 uploads 3000 bytes
|
||||
channel = self.upload_media(3000, self.tok1)
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
# User 1 attempts to upload 4000 bytes taking it over the limit
|
||||
channel = self.upload_media(4000, self.tok1)
|
||||
self.assertEqual(channel.code, 400)
|
||||
assert self.last_media_upload_limit_exceeded is not None
|
||||
self.assertEqual(self.last_media_upload_limit_exceeded["user_id"], self.user1)
|
||||
self.assertEqual(
|
||||
self.last_media_upload_limit_exceeded["limit"],
|
||||
MediaUploadLimit(
|
||||
max_bytes=5000, time_period_ms=Config.parse_duration("1d")
|
||||
),
|
||||
)
|
||||
self.assertEqual(self.last_media_upload_limit_exceeded["sent_bytes"], 3000)
|
||||
self.assertEqual(self.last_media_upload_limit_exceeded["attempted_bytes"], 4000)
|
||||
|
||||
# User 1 attempts to upload 20000 bytes which is over the weekly limit
|
||||
# This tests that the limits have been sorted as expected
|
||||
channel = self.upload_media(20000, self.tok1)
|
||||
self.assertEqual(channel.code, 400)
|
||||
assert self.last_media_upload_limit_exceeded is not None
|
||||
self.assertEqual(self.last_media_upload_limit_exceeded["user_id"], self.user1)
|
||||
self.assertEqual(
|
||||
self.last_media_upload_limit_exceeded["limit"],
|
||||
MediaUploadLimit(
|
||||
max_bytes=15000, time_period_ms=Config.parse_duration("1w")
|
||||
),
|
||||
)
|
||||
self.assertEqual(self.last_media_upload_limit_exceeded["sent_bytes"], 3000)
|
||||
self.assertEqual(
|
||||
self.last_media_upload_limit_exceeded["attempted_bytes"], 20000
|
||||
)
|
||||
|
||||
def test_uses_unlimited(self) -> None:
|
||||
"""Test that unlimited user is not limited when module returns []."""
|
||||
# User 2 uploads 10000 bytes which is over the default limit
|
||||
channel = self.upload_media(10000, self.tok2)
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(self.last_media_upload_limit_exceeded, None)
|
||||
|
||||
def test_uses_defaults(self) -> None:
|
||||
"""Test that the default limits are applied when module returned None."""
|
||||
# User 3 uploads 500 bytes
|
||||
channel = self.upload_media(500, self.tok3)
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
# User 3 uploads 800 bytes which is over the limit
|
||||
channel = self.upload_media(800, self.tok3)
|
||||
self.assertEqual(channel.code, 400)
|
||||
assert self.last_media_upload_limit_exceeded is not None
|
||||
self.assertEqual(self.last_media_upload_limit_exceeded["user_id"], self.user3)
|
||||
self.assertEqual(
|
||||
self.last_media_upload_limit_exceeded["limit"],
|
||||
MediaUploadLimit(
|
||||
max_bytes=1024, time_period_ms=Config.parse_duration("1d")
|
||||
),
|
||||
)
|
||||
self.assertEqual(self.last_media_upload_limit_exceeded["sent_bytes"], 500)
|
||||
self.assertEqual(self.last_media_upload_limit_exceeded["attempted_bytes"], 800)
|
||||
|
||||
@@ -2245,7 +2245,7 @@ class RoomMessageListTestCase(RoomBase):
|
||||
self.room_id = self.helper.create_room_as(self.user_id)
|
||||
|
||||
def test_topo_token_is_accepted(self) -> None:
|
||||
token = "t1-0_0_0_0_0_0_0_0_0_0"
|
||||
token = "t1-0_0_0_0_0_0_0_0_0_0_0"
|
||||
channel = self.make_request(
|
||||
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
|
||||
)
|
||||
@@ -2256,7 +2256,7 @@ class RoomMessageListTestCase(RoomBase):
|
||||
self.assertTrue("end" in channel.json_body)
|
||||
|
||||
def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None:
|
||||
token = "s0_0_0_0_0_0_0_0_0_0"
|
||||
token = "s0_0_0_0_0_0_0_0_0_0_0"
|
||||
channel = self.make_request(
|
||||
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
|
||||
)
|
||||
|
||||
@@ -189,19 +189,19 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
|
||||
self._subscribe(self.other_thread_root_id, automatic_event_orderings=None)
|
||||
|
||||
subscriptions = self.get_success(
|
||||
self.store.get_updated_thread_subscriptions_for_user(
|
||||
self.store.get_latest_updated_thread_subscriptions_for_user(
|
||||
self.user_id,
|
||||
from_id=0,
|
||||
to_id=50,
|
||||
limit=50,
|
||||
)
|
||||
)
|
||||
min_id = min(id for (id, _, _) in subscriptions)
|
||||
min_id = min(id for (id, _, _, _, _) in subscriptions)
|
||||
self.assertEqual(
|
||||
subscriptions,
|
||||
[
|
||||
(min_id, self.room_id, self.thread_root_id),
|
||||
(min_id + 1, self.room_id, self.other_thread_root_id),
|
||||
(min_id, self.room_id, self.thread_root_id, True, True),
|
||||
(min_id + 1, self.room_id, self.other_thread_root_id, True, False),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -212,7 +212,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
# Check user has no subscriptions
|
||||
subscriptions = self.get_success(
|
||||
self.store.get_updated_thread_subscriptions_for_user(
|
||||
self.store.get_latest_updated_thread_subscriptions_for_user(
|
||||
self.user_id,
|
||||
from_id=0,
|
||||
to_id=50,
|
||||
@@ -280,20 +280,22 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
# Get updates for main user
|
||||
updates = self.get_success(
|
||||
self.store.get_updated_thread_subscriptions_for_user(
|
||||
self.store.get_latest_updated_thread_subscriptions_for_user(
|
||||
self.user_id, from_id=0, to_id=stream_id2, limit=10
|
||||
)
|
||||
)
|
||||
self.assertEqual(updates, [(stream_id1, self.room_id, self.thread_root_id)])
|
||||
self.assertEqual(
|
||||
updates, [(stream_id1, self.room_id, self.thread_root_id, True, True)]
|
||||
)
|
||||
|
||||
# Get updates for other user
|
||||
updates = self.get_success(
|
||||
self.store.get_updated_thread_subscriptions_for_user(
|
||||
self.store.get_latest_updated_thread_subscriptions_for_user(
|
||||
other_user_id, from_id=0, to_id=max(stream_id1, stream_id2), limit=10
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
updates, [(stream_id2, self.room_id, self.other_thread_root_id)]
|
||||
updates, [(stream_id2, self.room_id, self.other_thread_root_id, True, True)]
|
||||
)
|
||||
|
||||
def test_should_skip_autosubscription_after_unsubscription(self) -> None:
|
||||
|
||||
Reference in New Issue
Block a user