Compare commits


18 Commits

Author SHA1 Message Date
Eric Eastwood
4b3d923a47 Add changelog 2025-09-17 13:52:19 -05:00
Eric Eastwood
1d731ec93d Remove server_name changes
From https://github.com/element-hq/synapse/pull/18868
2025-09-17 13:46:34 -05:00
Eric Eastwood
c076082191 Fill out other entrypoints
Changes from 3a5bab7a13

```
git checkout 3a5bab7a13 -- synapse/app/admin_cmd.py synapse/app/appservice.py synapse/app/client_reader.py synapse/app/event_creator.py synapse/app/federation_reader.py synapse/app/federation_sender.py synapse/app/frontend_proxy.py synapse/app/generic_worker.py synapse/app/media_repository.py synapse/app/pusher.py synapse/app/synchrotron.py synapse/app/user_dir.py
```
2025-09-17 13:42:04 -05:00
Eric Eastwood
37ea1ae686 Split loading config vs setting up the homeserver
This allows us to get access to `server_name` beforehand, which we
may want to use in the `with LoggingContext("main"):`
call early on.

This also allows us more flexibility to parse config however
we want and set up a Synapse homeserver, like what we do
in Synapse Pro for Small Hosts.
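
The split described above can be sketched as a minimal, self-contained illustration. Note that `load_config` and `start` here are stubs standing in for the real functions in `synapse.app.*` (shown in the diffs further down); only the shape of the pattern is the point.

```python
import sys

# Minimal sketch of the pattern described above: parse config first, so that
# values like `server_name` are available before any homeserver setup (or any
# `with LoggingContext("main"):` block) begins. `load_config` and `start` are
# stubs standing in for the real functions in `synapse.app.*`.
def load_config(argv_options: list) -> dict:
    """Parse command-line options and config files (stubbed)."""
    return {"server_name": "example.com"}

def start(config: dict) -> str:
    """Set up and start the homeserver (stubbed)."""
    return f"starting homeserver for {config['server_name']}"

def main(argv: list) -> str:
    config = load_config(argv)
    # config["server_name"] is usable here, e.g. to name a logging context.
    return start(config)

if __name__ == "__main__":
    print(main(sys.argv[1:]))
```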
2025-09-17 13:35:04 -05:00
Eric Eastwood
84d64251dc Remove sentinel logcontext where we log in setup, start and exit (#18870)
Remove `sentinel` logcontext where we log in `setup`, `start`, and exit.

Instead of having one giant PR that removes all places we use `sentinel`
logcontext, I've decided to tackle this piecemeal. This PR covers
the case where you just start up Synapse and exit it with no requests or
activity going on in between.

Part of https://github.com/element-hq/synapse/issues/18905 (Remove
`sentinel` logcontext where we log in Synapse)

Prerequisite for https://github.com/element-hq/synapse/pull/18868.
Logging with the `sentinel` logcontext means we won't know which server
the log came from.



### Why


9cc4001778/docs/log_contexts.md (L71-L81)

(docs updated in https://github.com/element-hq/synapse/pull/18900)


### Testing strategy

1. Run Synapse normally and with `daemonize: true`: `poetry run synapse_homeserver --config-path homeserver.yaml`
1. Execute some requests
1. Shut down the server
1. Look for any bad log entries in your homeserver logs:
   - `Expected logging context sentinel but found main`
   - `Expected logging context main was lost`
   - `Expected previous context`
   - `utime went backwards!`/`stime went backwards!`
   - `Called stop on logcontext POST-0 without recording a start rusage`
1. Look for any logs coming from the `sentinel` context


With these changes, you should only see the following logs (not from
Synapse) using the `sentinel` context if you start up Synapse and exit:

`homeserver.log`
```
2025-09-10 14:45:39,924 - asyncio - 64 - DEBUG - sentinel - Using selector: EpollSelector

2025-09-10 14:45:40,562 - twisted - 281 - INFO - sentinel - Received SIGINT, shutting down.

2025-09-10 14:45:40,562 - twisted - 281 - INFO - sentinel - (TCP Port 9322 Closed)
2025-09-10 14:45:40,563 - twisted - 281 - INFO - sentinel - (TCP Port 8008 Closed)
2025-09-10 14:45:40,563 - twisted - 281 - INFO - sentinel - (TCP Port 9093 Closed)
2025-09-10 14:45:40,564 - twisted - 281 - INFO - sentinel - Main loop terminated.
```
2025-09-16 17:15:08 -05:00
dependabot[bot]
2bed3fb566 Bump serde from 1.0.219 to 1.0.223 (#18920)
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-09-15 20:05:23 +01:00
dependabot[bot]
2c60b67a95 Bump types-setuptools from 80.9.0.20250809 to 80.9.0.20250822 (#18924)
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-09-15 17:37:43 +01:00
dependabot[bot]
6358afff8d Bump pydantic from 2.11.7 to 2.11.9 (#18922)
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-09-15 17:37:24 +01:00
dependabot[bot]
f7b547e2d8 Bump authlib from 1.6.1 to 1.6.3 (#18921)
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-09-15 17:35:11 +01:00
dependabot[bot]
8f7bd946de Bump serde_json from 1.0.143 to 1.0.145 (#18919)
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-09-15 17:31:12 +01:00
dependabot[bot]
4f80fa4b0a Bump types-psycopg2 from 2.9.21.20250809 to 2.9.21.20250915 (#18918)
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-09-15 17:29:49 +01:00
dependabot[bot]
b2592667a4 Bump sigstore/cosign-installer from 3.9.2 to 3.10.0 (#18917)
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-09-15 17:26:04 +01:00
Eric Eastwood
769d30a247 Clarify Python dependency constraints (#18856)
Clarify Python dependency constraints

Spawning from
https://github.com/element-hq/synapse/pull/18852#issuecomment-3212003675
as I don't actually know the exact rule of thumb. It's unclear to me
what we care about exactly. Our [deprecation
policy](https://element-hq.github.io/synapse/latest/deprecation_policy.html)
mentions Debian oldstable support, at least for the version of SQLite.
But then we only refer to Debian stable for the Twisted dependency.
2025-09-15 09:45:41 -05:00
Eric Eastwood
7ecfe8b1a8 Better explain which context the task is run in when using run_in_background(...) or run_as_background_process(...) (#18906)
Follow-up to https://github.com/element-hq/synapse/pull/18900
2025-09-12 09:29:35 -05:00
Hugh Nimmo-Smith
e1036ffa48 Add get_media_upload_limits_for_user and on_media_upload_limit_exceeded callbacks to module API (#18848)
Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>
2025-09-12 12:26:19 +01:00
Andrew Morgan
8c98cf7e55 Remove usage of deprecated pkg_resources interface (#18910) 2025-09-12 10:57:04 +01:00
Kegan Dougal
ec64c3e88d Ensure we /send PDUs which pass canonical JSON checks (#18641)
### Pull Request Checklist

Fixes https://github.com/element-hq/synapse/issues/18554

Looks like this was missed when it was
[implemented](2277df2a1e).


* [x] Pull request is based on the develop branch
* [x] Pull request includes a [changelog
file](https://element-hq.github.io/synapse/latest/development/contributing_guide.html#changelog).
The entry should:
- Be a short description of your change which makes sense to users.
"Fixed a bug that prevented receiving messages from other servers."
instead of "Moved X method from `EventStore` to `EventWorkerStore`.".
  - Use markdown where necessary, mostly for `code blocks`.
  - End with either a period (.) or an exclamation mark (!).
  - Start with a capital letter.
- Feel free to credit yourself, by adding a sentence "Contributed by
@github_username." or "Contributed by [Your Name]." to the end of the
entry.
* [x] [Code
style](https://element-hq.github.io/synapse/latest/code_style.html) is
correct (run the
[linters](https://element-hq.github.io/synapse/latest/development/contributing_guide.html#run-the-linters))

---------

Co-authored-by: reivilibre <oliverw@element.io>
2025-09-12 08:54:20 +00:00
reivilibre
ada3a3b2b3 Add experimental support for MSC4308: Thread Subscriptions extension to Sliding Sync when MSC4306 and MSC4186 are enabled. (#18695)
Closes: #18436

Implements:
https://github.com/matrix-org/matrix-spec-proposals/pull/4308

Follows: #18674

Adds an extension to Sliding Sync and a companion
endpoint needed for backpaginating missed thread subscription changes,
as described in MSC4308.

---------

Signed-off-by: Olivier 'reivilibre' <oliverw@matrix.org>
Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>
2025-09-11 14:45:04 +01:00
64 changed files with 1654 additions and 188 deletions


@@ -120,7 +120,7 @@ jobs:
uses: docker/setup-buildx-action@e468171a9de216ec08956ac3ada2f0791b6bd435 # v3.11.1
- name: Install Cosign
uses: sigstore/cosign-installer@d58896d6a1865668819e1d91763c7751a165e159 # v3.9.2
uses: sigstore/cosign-installer@d7543c93d881b35a8faa02e8e3605f69b7a1ce62 # v3.10.0
- name: Calculate docker image tag
uses: docker/metadata-action@c1e51972afc2121e065aed6d45c65596fe445f3f # v5.8.0

Cargo.lock (generated)

@@ -1250,18 +1250,28 @@ dependencies = [
[[package]]
name = "serde"
version = "1.0.219"
version = "1.0.224"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6"
checksum = "6aaeb1e94f53b16384af593c71e20b095e958dab1d26939c1b70645c5cfbcc0b"
dependencies = [
"serde_core",
"serde_derive",
]
[[package]]
name = "serde_core"
version = "1.0.224"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32f39390fa6346e24defbcdd3d9544ba8a19985d0af74df8501fbfe9a64341ab"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.219"
version = "1.0.224"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00"
checksum = "87ff78ab5e8561c9a675bfc1785cb07ae721f0ee53329a595cefd8c04c2ac4e0"
dependencies = [
"proc-macro2",
"quote",
@@ -1270,14 +1280,15 @@ dependencies = [
[[package]]
name = "serde_json"
version = "1.0.143"
version = "1.0.145"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d401abef1d108fbd9cbaebc3e46611f4b1021f714a0597a71f41ee463f5f4a5a"
checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c"
dependencies = [
"itoa",
"memchr",
"ryu",
"serde",
"serde_core",
]
[[package]]

changelog.d/18641.bugfix (new file)

@@ -0,0 +1 @@
Ensure all PDUs sent via `/send` pass canonical JSON checks.


@@ -0,0 +1 @@
Add experimental support for [MSC4308: Thread Subscriptions extension to Sliding Sync](https://github.com/matrix-org/matrix-spec-proposals/pull/4308) when [MSC4306: Thread Subscriptions](https://github.com/matrix-org/matrix-spec-proposals/pull/4306) and [MSC4186: Simplified Sliding Sync](https://github.com/matrix-org/matrix-spec-proposals/pull/4186) are enabled.


@@ -0,0 +1 @@
Add `get_media_upload_limits_for_user` and `on_media_upload_limit_exceeded` module API callbacks for media repository.

changelog.d/18856.doc (new file)

@@ -0,0 +1 @@
Clarify Python dependency constraints in our deprecation policy.

changelog.d/18870.misc (new file)

@@ -0,0 +1 @@
Remove `sentinel` logcontext usage where we log in `setup`, `start` and exit.

changelog.d/18906.misc (new file)

@@ -0,0 +1 @@
Better explain how we manage the logcontext in `run_in_background(...)` and `run_as_background_process(...)`.

changelog.d/18933.misc (new file)

@@ -0,0 +1 @@
Split loading config from homeserver `setup`.


@@ -1,13 +1,11 @@
Deprecation Policy for Platform Dependencies
============================================
# Deprecation Policy
Synapse has a number of platform dependencies, including Python, Rust,
PostgreSQL and SQLite. This document outlines the policy towards which versions
we support, and when we drop support for versions in the future.
Synapse has a number of **platform dependencies** (Python, Rust, PostgreSQL, and SQLite)
and **application dependencies** (Python and Rust packages). This document outlines the
policy towards which versions we support, and when we drop support for versions in the
future.
Policy
------
## Platform Dependencies
Synapse follows the upstream support life cycles for Python and PostgreSQL,
i.e. when a version reaches End of Life Synapse will withdraw support for that
@@ -26,8 +24,8 @@ The oldest supported version of SQLite is the version
[provided](https://packages.debian.org/bullseye/libsqlite3-0) by
[Debian oldstable](https://wiki.debian.org/DebianOldStable).
Context
-------
### Context
It is important for system admins to have a clear understanding of the platform
requirements of Synapse and its deprecation policies so that they can
@@ -50,4 +48,42 @@ the ecosystem.
On a similar note, SQLite does not generally have a concept of "supported
release"; bugfixes are published for the latest minor release only. We chose to
track Debian's oldstable as this is relatively conservative, predictably updated
and is consistent with the `.deb` packages released by Matrix.org.
and is consistent with the `.deb` packages released by Matrix.org.
## Application dependencies
For application-level Python dependencies, we often specify loose version constraints
(ex. `>=X.Y.Z`) to be forwards compatible with any new versions. Upper bounds (`<A.B.C`)
are only added when necessary to prevent known incompatibilities.
When selecting a minimum version, while we are mindful of the impact on downstream
package maintainers, our primary focus is on the maintainability and progress of Synapse
itself.
For developers, a Python dependency version can be considered a "no-brainer" upgrade once it is
available in both the latest [Debian Stable](https://packages.debian.org/stable/) and
[Ubuntu LTS](https://launchpad.net/ubuntu) repositories. No need to burden yourself with
extra scrutiny or consideration at this point.
We aggressively update Rust dependencies. Since these are statically linked and managed
entirely by `cargo` during build, they pose no ongoing maintenance burden on others.
This allows us to freely upgrade to leverage the latest ecosystem advancements, assuming
they don't have their own system-level dependencies.
### Context
Because Python dependencies can easily be managed in a virtual environment, we are less
concerned about the criteria for selecting minimum versions. The only thing of concern
is making sure we're not making it unnecessarily difficult for downstream package
maintainers. Generally, this just means avoiding the bleeding edge for a few months.
The situation for Rust dependencies is fundamentally different. For packagers, the
concerns around Python dependency versions do not apply. The `cargo` tool handles
downloading and building all libraries to satisfy dependencies, and these libraries are
statically linked into the final binary. This means that from a packager's perspective,
the Rust dependency versions are an internal build detail, not a runtime dependency to
be managed on the target system. Consequently, we have even greater flexibility to
upgrade Rust dependencies as needed for the project. Some distros (e.g. Fedora) do
package Rust libraries, but this appears to be the outlier rather than the norm.


@@ -64,3 +64,68 @@ If multiple modules implement this callback, they will be considered in order. I
returns `True`, Synapse falls through to the next one. The value of the first callback that
returns `False` will be used. If this happens, Synapse will not call any of the subsequent
implementations of this callback.
### `get_media_upload_limits_for_user`
_First introduced in Synapse v1.139.0_
```python
async def get_media_upload_limits_for_user(user_id: str, size: int) -> Optional[List[synapse.module_api.MediaUploadLimit]]
```
**<span style="color:red">
Caution: This callback is currently experimental. The method signature or behaviour
may change without notice.
</span>**
Called when processing a request to store content in the media repository. This can be used to dynamically override
the [media upload limits configuration](../usage/configuration/config_documentation.html#media_upload_limits).
The arguments passed to this callback are:
* `user_id`: The Matrix user ID of the user (e.g. `@alice:example.com`) making the request.
If the callback returns a list then it will be used as the limits instead of those in the configuration (if any).
If an empty list is returned then no limits are applied (**warning:** users will be able
to upload as much data as they desire).
If multiple modules implement this callback, they will be considered in order. If a
callback returns `None`, Synapse falls through to the next one. The value of the first
callback that does not return `None` will be used. If this happens, Synapse will not call
any of the subsequent implementations of this callback.
If there are no registered modules, or if all modules return `None`, then
the default
[media upload limits configuration](../usage/configuration/config_documentation.html#media_upload_limits)
will be used.
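
As a rough sketch (not taken from the Synapse docs), a module might implement this callback along the following lines. `MediaUploadLimit` is stubbed here with a plain dataclass whose field names (`max_bytes`, `time_period_ms`) are assumptions for illustration; a real module would use `synapse.module_api.MediaUploadLimit`.

```python
import asyncio
from dataclasses import dataclass
from typing import List, Optional

# Stand-in for synapse.module_api.MediaUploadLimit; the field names here
# are assumptions for illustration only.
@dataclass
class MediaUploadLimit:
    max_bytes: int
    time_period_ms: int

async def get_media_upload_limits_for_user(
    user_id: str, size: int
) -> Optional[List[MediaUploadLimit]]:
    # `size` (presumably the size of the attempted upload) is unused here.
    if user_id == "@media-bot:example.com":
        # Empty list: no limits at all for this hypothetical trusted account.
        return []
    if user_id.endswith(":example.com"):
        # Local users get a 100 MiB / 7 day cap instead of the configured limits.
        return [MediaUploadLimit(max_bytes=100 * 1024 * 1024,
                                 time_period_ms=7 * 24 * 3600 * 1000)]
    # None: fall through to the next module, or to the configured limits.
    return None

# Quick demonstration; in Synapse the homeserver awaits this callback itself.
print(asyncio.run(get_media_upload_limits_for_user("@media-bot:example.com", 0)))  # prints []
```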
### `on_media_upload_limit_exceeded`
_First introduced in Synapse v1.139.0_
```python
async def on_media_upload_limit_exceeded(user_id: str, limit: synapse.module_api.MediaUploadLimit, sent_bytes: int, attempted_bytes: int) -> None
```
**<span style="color:red">
Caution: This callback is currently experimental. The method signature or behaviour
may change without notice.
</span>**
Called when a user attempts to upload media that would exceed a
[configured media upload limit](../usage/configuration/config_documentation.html#media_upload_limits).
This callback will only be called on workers which handle
[POST /_matrix/media/v3/upload](https://spec.matrix.org/v1.15/client-server-api/#post_matrixmediav3upload)
requests.
This could be used to inform the user that they have reached a media upload limit through
some external method.
The arguments passed to this callback are:
* `user_id`: The Matrix user ID of the user (e.g. `@alice:example.com`) making the request.
* `limit`: The `synapse.module_api.MediaUploadLimit` representing the limit that was reached.
* `sent_bytes`: The number of bytes already sent during the period of the limit.
* `attempted_bytes`: The number of bytes that the user attempted to send.
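
A sketch of this notification callback, again with a stubbed `MediaUploadLimit` (the real type comes from `synapse.module_api`, and the stub's field names are assumptions). Here the "external method" is just a log line plus an in-memory record.

```python
import asyncio
import logging
from dataclasses import dataclass
from typing import Dict

logger = logging.getLogger(__name__)

# Stand-in for synapse.module_api.MediaUploadLimit; the field names here
# are assumptions for illustration only.
@dataclass
class MediaUploadLimit:
    max_bytes: int
    time_period_ms: int

# Where a real module might notify an external service, this sketch just
# records the total bytes the user was at when the limit was hit.
notified: Dict[str, int] = {}

async def on_media_upload_limit_exceeded(
    user_id: str,
    limit: MediaUploadLimit,
    sent_bytes: int,
    attempted_bytes: int,
) -> None:
    # Side effects only: the return value is ignored by the caller.
    logger.warning(
        "%s exceeded an upload limit of %d bytes (already sent %d, attempted %d more)",
        user_id, limit.max_bytes, sent_bytes, attempted_bytes,
    )
    notified[user_id] = sent_bytes + attempted_bytes

# Quick demonstration; in Synapse the homeserver awaits this callback itself.
asyncio.run(on_media_upload_limit_exceeded(
    "@alice:example.com",
    MediaUploadLimit(max_bytes=10, time_period_ms=1000),
    sent_bytes=8,
    attempted_bytes=5,
))
```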


@@ -2168,9 +2168,12 @@ max_upload_size: 60M
### `media_upload_limits`
*(array)* A list of media upload limits defining how much data a given user can upload in a given time period.
These limits are applied in addition to the `max_upload_size` limit above (which applies to individual uploads).
An empty list means no limits are applied.
These settings can be overridden using the `get_media_upload_limits_for_user` module API [callback](../../modules/media_repository_callbacks.md#get_media_upload_limits_for_user).
Defaults to `[]`.
Example configuration:

poetry.lock (generated)

@@ -34,15 +34,15 @@ tests-mypy = ["mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" a
[[package]]
name = "authlib"
version = "1.6.1"
version = "1.6.3"
description = "The ultimate Python library in building OAuth and OpenID Connect servers and clients."
optional = true
python-versions = ">=3.9"
groups = ["main"]
markers = "extra == \"all\" or extra == \"jwt\" or extra == \"oidc\""
files = [
{file = "authlib-1.6.1-py2.py3-none-any.whl", hash = "sha256:e9d2031c34c6309373ab845afc24168fe9e93dc52d252631f52642f21f5ed06e"},
{file = "authlib-1.6.1.tar.gz", hash = "sha256:4dffdbb1460ba6ec8c17981a4c67af7d8af131231b5a36a88a1e8c80c111cdfd"},
{file = "authlib-1.6.3-py2.py3-none-any.whl", hash = "sha256:7ea0f082edd95a03b7b72edac65ec7f8f68d703017d7e37573aee4fc603f2a48"},
{file = "authlib-1.6.3.tar.gz", hash = "sha256:9f7a982cc395de719e4c2215c5707e7ea690ecf84f1ab126f28c053f4219e610"},
]
[package.dependencies]
@@ -1774,14 +1774,14 @@ files = [
[[package]]
name = "pydantic"
version = "2.11.7"
version = "2.11.9"
description = "Data validation using Python type hints"
optional = false
python-versions = ">=3.9"
groups = ["main", "dev"]
files = [
{file = "pydantic-2.11.7-py3-none-any.whl", hash = "sha256:dde5df002701f6de26248661f6835bbe296a47bf73990135c7d07ce741b9623b"},
{file = "pydantic-2.11.7.tar.gz", hash = "sha256:d989c3c6cb79469287b1569f7447a17848c998458d49ebe294e975b9baf0f0db"},
{file = "pydantic-2.11.9-py3-none-any.whl", hash = "sha256:c42dd626f5cfc1c6950ce6205ea58c93efa406da65f479dcb4029d5934857da2"},
{file = "pydantic-2.11.9.tar.gz", hash = "sha256:6b8ffda597a14812a7975c90b82a8a2e777d9257aba3453f973acd3c032a18e2"},
]
[package.dependencies]
@@ -2971,14 +2971,14 @@ files = [
[[package]]
name = "types-psycopg2"
version = "2.9.21.20250809"
version = "2.9.21.20250915"
description = "Typing stubs for psycopg2"
optional = false
python-versions = ">=3.9"
groups = ["dev"]
files = [
{file = "types_psycopg2-2.9.21.20250809-py3-none-any.whl", hash = "sha256:59b7b0ed56dcae9efae62b8373497274fc1a0484bdc5135cdacbe5a8f44e1d7b"},
{file = "types_psycopg2-2.9.21.20250809.tar.gz", hash = "sha256:b7c2cbdcf7c0bd16240f59ba694347329b0463e43398de69784ea4dee45f3c6d"},
{file = "types_psycopg2-2.9.21.20250915-py3-none-any.whl", hash = "sha256:eefe5ccdc693fc086146e84c9ba437bb278efe1ef330b299a0cb71169dc6c55f"},
{file = "types_psycopg2-2.9.21.20250915.tar.gz", hash = "sha256:bfeb8f54c32490e7b5edc46215ab4163693192bc90407b4a023822de9239f5c8"},
]
[[package]]
@@ -3026,14 +3026,14 @@ urllib3 = ">=2"
[[package]]
name = "types-setuptools"
version = "80.9.0.20250809"
version = "80.9.0.20250822"
description = "Typing stubs for setuptools"
optional = false
python-versions = ">=3.9"
groups = ["dev"]
files = [
{file = "types_setuptools-80.9.0.20250809-py3-none-any.whl", hash = "sha256:7c6539b4c7ac7b4ab4db2be66d8a58fb1e28affa3ee3834be48acafd94f5976a"},
{file = "types_setuptools-80.9.0.20250809.tar.gz", hash = "sha256:e986ba37ffde364073d76189e1d79d9928fb6f5278c7d07589cde353d0218864"},
{file = "types_setuptools-80.9.0.20250822-py3-none-any.whl", hash = "sha256:53bf881cb9d7e46ed12c76ef76c0aaf28cfe6211d3fab12e0b83620b1a8642c3"},
{file = "types_setuptools-80.9.0.20250822.tar.gz", hash = "sha256:070ea7716968ec67a84c7f7768d9952ff24d28b65b6594797a464f1b3066f965"},
]
[[package]]


@@ -2415,8 +2415,15 @@ properties:
A list of media upload limits defining how much data a given user can
upload in a given time period.
These limits are applied in addition to the `max_upload_size` limit above
(which applies to individual uploads).
An empty list means no limits are applied.
These settings can be overridden using the `get_media_upload_limits_for_user`
module API [callback](../../modules/media_repository_callbacks.md#get_media_upload_limits_for_user).
default: []
items:
time_period:


@@ -72,7 +72,7 @@ from synapse.events.auto_accept_invites import InviteAutoAccepter
from synapse.events.presence_router import load_legacy_presence_router
from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.http.site import SynapseSite
from synapse.logging.context import PreserveLoggingContext
from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.logging.opentracing import init_tracer
from synapse.metrics import install_gc_manager, register_threadpool
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -183,25 +183,23 @@ def start_reactor(
if gc_thresholds:
gc.set_threshold(*gc_thresholds)
install_gc_manager()
run_command()
# make sure that we run the reactor with the sentinel log context,
# otherwise other PreserveLoggingContext instances will get confused
# and complain when they see the logcontext arbitrarily swapping
# between the sentinel and `run` logcontexts.
#
# We also need to drop the logcontext before forking if we're daemonizing,
# otherwise the cputime metrics get confused about the per-thread resource usage
# appearing to go backwards.
with PreserveLoggingContext():
if daemonize:
assert pid_file is not None
# Reset the logging context when we start the reactor (whenever we yield control
# to the reactor, the `sentinel` logging context needs to be set so we don't
# leak the current logging context and erroneously apply it to the next task the
# reactor event loop picks up)
with PreserveLoggingContext():
run_command()
if print_pidfile:
print(pid_file)
if daemonize:
assert pid_file is not None
daemonize_process(pid_file, logger)
run()
if print_pidfile:
print(pid_file)
daemonize_process(pid_file, logger)
run()
def quit_with_error(error_string: str) -> NoReturn:
@@ -601,10 +599,12 @@ async def start(hs: "HomeServer") -> None:
hs.get_datastores().main.db_pool.start_profiling()
hs.get_pusherpool().start()
def log_shutdown() -> None:
with LoggingContext("log_shutdown"):
logger.info("Shutting down...")
# Log when we start the shut down process.
hs.get_reactor().addSystemEventTrigger(
"before", "shutdown", logger.info, "Shutting down..."
)
hs.get_reactor().addSystemEventTrigger("before", "shutdown", log_shutdown)
setup_sentry(hs)
setup_sdnotify(hs)


@@ -24,7 +24,7 @@ import logging
import os
import sys
import tempfile
from typing import List, Mapping, Optional, Sequence
from typing import List, Mapping, Optional, Sequence, Tuple
from twisted.internet import defer, task
@@ -256,7 +256,7 @@ class FileExfiltrationWriter(ExfiltrationWriter):
return self.base_directory
def start(config_options: List[str]) -> None:
def load_config(argv_options: List[str]) -> Tuple[HomeServerConfig, argparse.Namespace]:
parser = argparse.ArgumentParser(description="Synapse Admin Command")
HomeServerConfig.add_arguments_to_parser(parser)
@@ -282,11 +282,15 @@ def start(config_options: List[str]) -> None:
export_data_parser.set_defaults(func=export_data_command)
try:
config, args = HomeServerConfig.load_config_with_parser(parser, config_options)
config, args = HomeServerConfig.load_config_with_parser(parser, argv_options)
except ConfigError as e:
sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)
return config, args
def start(config: HomeServerConfig, args: argparse.Namespace) -> None:
if config.worker.worker_app is not None:
assert config.worker.worker_app == "synapse.app.admin_cmd"
@@ -325,7 +329,7 @@ def start(config_options: List[str]) -> None:
# command.
async def run() -> None:
with LoggingContext("command"):
with LoggingContext(name="command"):
await _base.start(ss)
await args.func(ss, args)
@@ -337,5 +341,6 @@ def start(config_options: List[str]) -> None:
if __name__ == "__main__":
with LoggingContext("main"):
start(sys.argv[1:])
homeserver_config, args = load_config(sys.argv[1:])
with LoggingContext(name="main"):
start(homeserver_config, args)


@@ -21,13 +21,14 @@
import sys
from synapse.app.generic_worker import start
from synapse.app.generic_worker import load_config, start
from synapse.util.logcontext import LoggingContext
def main() -> None:
with LoggingContext("main"):
start(sys.argv[1:])
homeserver_config = load_config(sys.argv[1:])
with LoggingContext(name="main"):
start(homeserver_config)
if __name__ == "__main__":


@@ -21,13 +21,14 @@
import sys
from synapse.app.generic_worker import start
from synapse.app.generic_worker import load_config, start
from synapse.util.logcontext import LoggingContext
def main() -> None:
with LoggingContext("main"):
start(sys.argv[1:])
homeserver_config = load_config(sys.argv[1:])
with LoggingContext(name="main"):
start(homeserver_config)
if __name__ == "__main__":


@@ -20,13 +20,14 @@
import sys
from synapse.app.generic_worker import start
from synapse.app.generic_worker import load_config, start
from synapse.util.logcontext import LoggingContext
def main() -> None:
with LoggingContext("main"):
start(sys.argv[1:])
homeserver_config = load_config(sys.argv[1:])
with LoggingContext(name="main"):
start(homeserver_config)
if __name__ == "__main__":


@@ -21,13 +21,14 @@
import sys
from synapse.app.generic_worker import start
from synapse.app.generic_worker import load_config, start
from synapse.util.logcontext import LoggingContext
def main() -> None:
with LoggingContext("main"):
start(sys.argv[1:])
homeserver_config = load_config(sys.argv[1:])
with LoggingContext(name="main"):
start(homeserver_config)
if __name__ == "__main__":


@@ -21,13 +21,14 @@
import sys
from synapse.app.generic_worker import start
from synapse.app.generic_worker import load_config, start
from synapse.util.logcontext import LoggingContext
def main() -> None:
with LoggingContext("main"):
start(sys.argv[1:])
homeserver_config = load_config(sys.argv[1:])
with LoggingContext(name="main"):
start(homeserver_config)
if __name__ == "__main__":


@@ -21,13 +21,14 @@
import sys
from synapse.app.generic_worker import start
from synapse.app.generic_worker import load_config, start
from synapse.util.logcontext import LoggingContext
def main() -> None:
with LoggingContext("main"):
start(sys.argv[1:])
homeserver_config = load_config(sys.argv[1:])
with LoggingContext(name="main"):
start(homeserver_config)
if __name__ == "__main__":


@@ -310,13 +310,26 @@ class GenericWorkerServer(HomeServer):
self.get_replication_command_handler().start_replication(self)
def start(config_options: List[str]) -> None:
def load_config(argv_options: List[str]) -> HomeServerConfig:
"""
Parse the commandline and config files (does not generate config)
Args:
argv_options: The options passed to Synapse. Usually `sys.argv[1:]`.
Returns:
Config object.
"""
try:
config = HomeServerConfig.load_config("Synapse worker", config_options)
config = HomeServerConfig.load_config("Synapse worker", argv_options)
except ConfigError as e:
sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)
return config
def start(config: HomeServerConfig) -> None:
# For backwards compatibility let any of the old app names.
assert config.worker.worker_app in (
"synapse.app.appservice",
@@ -365,8 +378,9 @@ def start(config_options: List[str]) -> None:
def main() -> None:
with LoggingContext("main"):
start(sys.argv[1:])
homeserver_config = load_config(sys.argv[1:])
with LoggingContext(name="main"):
start(homeserver_config)
if __name__ == "__main__":


@@ -308,17 +308,21 @@ class SynapseHomeServer(HomeServer):
logger.warning("Unrecognized listener type: %s", listener.type)
def setup(config_options: List[str]) -> SynapseHomeServer:
def load_or_generate_config(argv_options: List[str]) -> HomeServerConfig:
"""
Parse the commandline and config files
Supports generation of config files, so is used for the main homeserver app.
Args:
config_options_options: The options passed to Synapse. Usually `sys.argv[1:]`.
argv_options: The options passed to Synapse. Usually `sys.argv[1:]`.
Returns:
A homeserver instance.
"""
try:
config = HomeServerConfig.load_or_generate_config(
"Synapse Homeserver", config_options
"Synapse Homeserver", argv_options
)
except ConfigError as e:
sys.stderr.write("\n")
@@ -332,6 +336,20 @@ def setup(config_options: List[str]) -> SynapseHomeServer:
# generating config files and shouldn't try to continue.
sys.exit(0)
return config
def setup(config: HomeServerConfig) -> SynapseHomeServer:
"""
Create and setup a Synapse homeserver instance given a configuration.
Args:
config: The configuration for the homeserver.
Returns:
A homeserver instance.
"""
if config.worker.worker_app:
raise ConfigError(
"You have specified `worker_app` in the config but are attempting to start a non-worker "
@@ -377,15 +395,17 @@ def setup(config_options: List[str]) -> SynapseHomeServer:
handle_startup_exception(e)
async def start() -> None:
# Load the OIDC provider metadatas, if OIDC is enabled.
if hs.config.oidc.oidc_enabled:
oidc = hs.get_oidc_handler()
# Loading the provider metadata also ensures the provider config is valid.
await oidc.load_metadata()
# Re-establish log context now that we're back from the reactor
with LoggingContext("start"):
# Load the OIDC provider metadatas, if OIDC is enabled.
if hs.config.oidc.oidc_enabled:
oidc = hs.get_oidc_handler()
# Loading the provider metadata also ensures the provider config is valid.
await oidc.load_metadata()
await _base.start(hs)
await _base.start(hs)
hs.get_datastores().main.db_pool.updates.start_doing_background_updates()
hs.get_datastores().main.db_pool.updates.start_doing_background_updates()
register_start(start)
@@ -405,10 +425,12 @@ def run(hs: HomeServer) -> None:
def main() -> None:
homeserver_config = load_or_generate_config(sys.argv[1:])
with LoggingContext("main"):
# check base requirements
check_requirements()
hs = setup(sys.argv[1:])
hs = setup(homeserver_config)
# redirect stdio to the logs, if configured.
if not hs.config.logging.no_redirect_stdio:


@@ -21,13 +21,14 @@
import sys
from synapse.app.generic_worker import start
from synapse.app.generic_worker import load_config, start
from synapse.util.logcontext import LoggingContext
def main() -> None:
with LoggingContext("main"):
start(sys.argv[1:])
homeserver_config = load_config(sys.argv[1:])
with LoggingContext(name="main"):
start(homeserver_config)
if __name__ == "__main__":

View File

@@ -21,13 +21,14 @@
import sys
from synapse.app.generic_worker import start
from synapse.app.generic_worker import load_config, start
from synapse.util.logcontext import LoggingContext
def main() -> None:
with LoggingContext("main"):
start(sys.argv[1:])
homeserver_config = load_config(sys.argv[1:])
with LoggingContext(name="main"):
start(homeserver_config)
if __name__ == "__main__":

View File

@@ -21,13 +21,14 @@
import sys
from synapse.app.generic_worker import start
from synapse.app.generic_worker import load_config, start
from synapse.util.logcontext import LoggingContext
def main() -> None:
with LoggingContext("main"):
start(sys.argv[1:])
homeserver_config = load_config(sys.argv[1:])
with LoggingContext(name="main"):
start(homeserver_config)
if __name__ == "__main__":

View File

@@ -21,13 +21,14 @@
import sys
from synapse.app.generic_worker import start
from synapse.app.generic_worker import load_config, start
from synapse.util.logcontext import LoggingContext
def main() -> None:
with LoggingContext("main"):
start(sys.argv[1:])
homeserver_config = load_config(sys.argv[1:])
with LoggingContext(name="main"):
start(homeserver_config)
if __name__ == "__main__":

View File

@@ -646,12 +646,16 @@ class RootConfig:
@classmethod
def load_or_generate_config(
cls: Type[TRootConfig], description: str, argv: List[str]
cls: Type[TRootConfig], description: str, argv_options: List[str]
) -> Optional[TRootConfig]:
"""Parse the commandline and config files
Supports generation of config files, so is used for the main homeserver app.
Args:
description: A description of the program, used in the argument parser's usage text.
argv_options: The options passed to Synapse. Usually `sys.argv[1:]`.
Returns:
Config object, or None if --generate-config or --generate-keys was set
"""
@@ -747,7 +751,7 @@ class RootConfig:
)
cls.invoke_all_static("add_arguments", parser)
config_args = parser.parse_args(argv)
config_args = parser.parse_args(argv_options)
config_files = find_config_files(search_paths=config_args.config_path)

View File

@@ -590,5 +590,5 @@ class ExperimentalConfig(Config):
self.msc4293_enabled: bool = experimental.get("msc4293_enabled", False)
# MSC4306: Thread Subscriptions
# (and MSC4308: sliding sync extension for thread subscriptions)
# (and MSC4308: Thread Subscriptions extension to Sliding Sync)
self.msc4306_enabled: bool = experimental.get("msc4306_enabled", False)

View File

@@ -120,11 +120,19 @@ def parse_thumbnail_requirements(
@attr.s(auto_attribs=True, slots=True, frozen=True)
class MediaUploadLimit:
"""A limit on the amount of data a user can upload in a given time
period."""
"""
Represents a limit on the amount of data a user can upload in a given time
period.
These can be configured through the `media_upload_limits` [config option](https://element-hq.github.io/synapse/latest/usage/configuration/config_documentation.html#media_upload_limits)
or via the `get_media_upload_limits_for_user` module API [callback](https://element-hq.github.io/synapse/latest/modules/media_repository_callbacks.html#get_media_upload_limits_for_user).
"""
max_bytes: int
"""The maximum number of bytes that can be uploaded in the given time period."""
time_period_ms: int
"""The time period in milliseconds."""
class ContentRepositoryConfig(Config):

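As the new docstring notes, these limits come from config or a module callback and are later sorted by time period. A small sketch (using a stdlib dataclass as a simplified stand-in for the attrs class) of the descending sort the media repository relies on:

```python
from dataclasses import dataclass
from typing import List

@dataclass(frozen=True)
class MediaUploadLimit:
    """Simplified stand-in for synapse.config.repository.MediaUploadLimit."""
    max_bytes: int
    time_period_ms: int

def sort_limits(limits: List[MediaUploadLimit]) -> List[MediaUploadLimit]:
    # Largest time period first: later checks can then reuse the usage
    # figure computed for a wider window (see the MediaRepository change).
    return sorted(limits, key=lambda limit: limit.time_period_ms, reverse=True)

limits = sort_limits([
    MediaUploadLimit(max_bytes=100_000_000, time_period_ms=3_600_000),      # 100 MB / hour
    MediaUploadLimit(max_bytes=1_000_000_000, time_period_ms=604_800_000),  # 1 GB / week
])
# limits[0] is now the weekly (largest-window) limit
```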
View File

@@ -26,7 +26,7 @@ from synapse.api.constants import EduTypes
from synapse.api.errors import HttpResponseException
from synapse.events import EventBase
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
from synapse.federation.units import Edu, Transaction, serialize_and_filter_pdus
from synapse.logging.opentracing import (
extract_text_map,
set_tag,
@@ -119,7 +119,7 @@ class TransactionManager:
transaction_id=txn_id,
origin=self.server_name,
destination=destination,
pdus=[p.get_pdu_json() for p in pdus],
pdus=serialize_and_filter_pdus(pdus),
edus=[edu.get_dict() for edu in edus],
)

View File

@@ -135,7 +135,7 @@ class PublicRoomList(BaseFederationServlet):
if not self.allow_access:
raise FederationDeniedError(origin)
limit = parse_integer_from_args(query, "limit", 0)
limit: Optional[int] = parse_integer_from_args(query, "limit", 0)
since_token = parse_string_from_args(query, "since", None)
include_all_networks = parse_boolean_from_args(
query, "include_all_networks", default=False

View File

@@ -211,7 +211,7 @@ class SlidingSyncHandler:
Args:
sync_config: Sync configuration
to_token: The point in the stream to sync up to.
to_token: The latest point in the stream to sync up to.
from_token: The point in the stream to sync from. Token of the end of the
previous batch. May be `None` if this is the initial sync request.
"""

View File

@@ -27,7 +27,7 @@ from typing import (
cast,
)
from typing_extensions import assert_never
from typing_extensions import TypeAlias, assert_never
from synapse.api.constants import AccountDataTypes, EduTypes
from synapse.handlers.receipts import ReceiptEventSource
@@ -40,6 +40,7 @@ from synapse.types import (
SlidingSyncStreamToken,
StrCollection,
StreamToken,
ThreadSubscriptionsToken,
)
from synapse.types.handlers.sliding_sync import (
HaveSentRoomFlag,
@@ -54,6 +55,13 @@ from synapse.util.async_helpers import (
gather_optional_coroutines,
)
_ThreadSubscription: TypeAlias = (
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadSubscription
)
_ThreadUnsubscription: TypeAlias = (
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadUnsubscription
)
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -68,6 +76,7 @@ class SlidingSyncExtensionHandler:
self.event_sources = hs.get_event_sources()
self.device_handler = hs.get_device_handler()
self.push_rules_handler = hs.get_push_rules_handler()
self._enable_thread_subscriptions = hs.config.experimental.msc4306_enabled
@trace
async def get_extensions_response(
@@ -93,7 +102,7 @@ class SlidingSyncExtensionHandler:
actual_room_ids: The actual room IDs in the Sliding Sync response.
actual_room_response_map: A map of room ID to room results in the
Sliding Sync response.
to_token: The point in the stream to sync up to.
to_token: The latest point in the stream to sync up to.
from_token: The point in the stream to sync from.
"""
@@ -156,18 +165,32 @@ class SlidingSyncExtensionHandler:
from_token=from_token,
)
thread_subs_coro = None
if (
sync_config.extensions.thread_subscriptions is not None
and self._enable_thread_subscriptions
):
thread_subs_coro = self.get_thread_subscriptions_extension_response(
sync_config=sync_config,
thread_subscriptions_request=sync_config.extensions.thread_subscriptions,
to_token=to_token,
from_token=from_token,
)
(
to_device_response,
e2ee_response,
account_data_response,
receipts_response,
typing_response,
thread_subs_response,
) = await gather_optional_coroutines(
to_device_coro,
e2ee_coro,
account_data_coro,
receipts_coro,
typing_coro,
thread_subs_coro,
)
return SlidingSyncResult.Extensions(
@@ -176,6 +199,7 @@ class SlidingSyncExtensionHandler:
account_data=account_data_response,
receipts=receipts_response,
typing=typing_response,
thread_subscriptions=thread_subs_response,
)
def find_relevant_room_ids_for_extension(
@@ -877,3 +901,72 @@ class SlidingSyncExtensionHandler:
return SlidingSyncResult.Extensions.TypingExtension(
room_id_to_typing_map=room_id_to_typing_map,
)
async def get_thread_subscriptions_extension_response(
self,
sync_config: SlidingSyncConfig,
thread_subscriptions_request: SlidingSyncConfig.Extensions.ThreadSubscriptionsExtension,
to_token: StreamToken,
from_token: Optional[SlidingSyncStreamToken],
) -> Optional[SlidingSyncResult.Extensions.ThreadSubscriptionsExtension]:
"""Handle Thread Subscriptions extension (MSC4308)
Args:
sync_config: Sync configuration
thread_subscriptions_request: The thread_subscriptions extension from the request
to_token: The point in the stream to sync up to.
from_token: The point in the stream to sync from.
Returns:
the response (None if empty or thread subscriptions are disabled)
"""
if not thread_subscriptions_request.enabled:
return None
limit = thread_subscriptions_request.limit
if from_token:
from_stream_id = from_token.stream_token.thread_subscriptions_key
else:
from_stream_id = StreamToken.START.thread_subscriptions_key
to_stream_id = to_token.thread_subscriptions_key
updates = await self.store.get_latest_updated_thread_subscriptions_for_user(
user_id=sync_config.user.to_string(),
from_id=from_stream_id,
to_id=to_stream_id,
limit=limit,
)
if len(updates) == 0:
return None
subscribed_threads: Dict[str, Dict[str, _ThreadSubscription]] = {}
unsubscribed_threads: Dict[str, Dict[str, _ThreadUnsubscription]] = {}
for stream_id, room_id, thread_root_id, subscribed, automatic in updates:
if subscribed:
subscribed_threads.setdefault(room_id, {})[thread_root_id] = (
_ThreadSubscription(
automatic=automatic,
bump_stamp=stream_id,
)
)
else:
unsubscribed_threads.setdefault(room_id, {})[thread_root_id] = (
_ThreadUnsubscription(bump_stamp=stream_id)
)
prev_batch = None
if len(updates) == limit:
# Tell the client about a potential gap where there may be more
# thread subscriptions for it to backpaginate.
# We subtract one because the 'later in the stream' bound is inclusive,
# and we already saw the element at index 0.
prev_batch = ThreadSubscriptionsToken(updates[0][0] - 1)
return SlidingSyncResult.Extensions.ThreadSubscriptionsExtension(
subscribed=subscribed_threads,
unsubscribed=unsubscribed_threads,
prev_batch=prev_batch,
)

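The response-building loop above can be sketched standalone: `updates` rows are ascending `(stream_id, room_id, thread_root_id, subscribed, automatic)` tuples, and `prev_batch` marks a potential gap when the row limit was hit. Plain dicts stand in for the attrs result types:

```python
from typing import Dict, List, Optional, Tuple

Update = Tuple[int, str, str, bool, Optional[bool]]

def build_extension_response(updates: List[Update], limit: int):
    subscribed: Dict[str, Dict[str, dict]] = {}
    unsubscribed: Dict[str, Dict[str, dict]] = {}
    for stream_id, room_id, thread_root_id, is_subscribed, automatic in updates:
        if is_subscribed:
            subscribed.setdefault(room_id, {})[thread_root_id] = {
                "automatic": automatic, "bump_stamp": stream_id,
            }
        else:
            unsubscribed.setdefault(room_id, {})[thread_root_id] = {
                "bump_stamp": stream_id,
            }
    prev_batch = None
    if len(updates) == limit:
        # We may have dropped older rows: hand back a token one below the
        # lowest stream_id we returned, since that bound is inclusive.
        prev_batch = updates[0][0] - 1
    return subscribed, unsubscribed, prev_batch

subs, unsubs, prev = build_extension_response(
    [(5, "!a:hs", "$t1", True, False), (6, "!a:hs", "$t2", False, None)],
    limit=2,
)
# prev == 4: the client can backpaginate from stream position 4
```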
View File

@@ -9,7 +9,7 @@ from synapse.storage.databases.main.thread_subscriptions import (
AutomaticSubscriptionConflicted,
ThreadSubscription,
)
from synapse.types import EventOrderings, UserID
from synapse.types import EventOrderings, StreamKeyType, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -22,6 +22,7 @@ class ThreadSubscriptionsHandler:
self.store = hs.get_datastores().main
self.event_handler = hs.get_event_handler()
self.auth = hs.get_auth()
self._notifier = hs.get_notifier()
async def get_thread_subscription_settings(
self,
@@ -132,6 +133,15 @@ class ThreadSubscriptionsHandler:
errcode=Codes.MSC4306_CONFLICTING_UNSUBSCRIPTION,
)
if outcome is not None:
# wake up user streams (e.g. sliding sync) on the same worker
self._notifier.on_new_event(
StreamKeyType.THREAD_SUBSCRIPTIONS,
# outcome is a stream_id
outcome,
users=[user_id.to_string()],
)
return outcome
async def unsubscribe_user_from_thread(
@@ -162,8 +172,19 @@ class ThreadSubscriptionsHandler:
logger.info("rejecting thread subscriptions change (thread not accessible)")
raise NotFoundError("No such thread root")
return await self.store.unsubscribe_user_from_thread(
outcome = await self.store.unsubscribe_user_from_thread(
user_id.to_string(),
event.room_id,
thread_root_event_id,
)
if outcome is not None:
# wake up user streams (e.g. sliding sync) on the same worker
self._notifier.on_new_event(
StreamKeyType.THREAD_SUBSCRIPTIONS,
# outcome is a stream_id
outcome,
users=[user_id.to_string()],
)
return outcome

View File

@@ -130,6 +130,16 @@ def parse_integer(
return parse_integer_from_args(args, name, default, required, negative)
@overload
def parse_integer_from_args(
args: Mapping[bytes, Sequence[bytes]],
name: str,
default: int,
required: Literal[False] = False,
negative: bool = False,
) -> int: ...
@overload
def parse_integer_from_args(
args: Mapping[bytes, Sequence[bytes]],

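The new overload makes the return type non-`Optional` when an `int` default is supplied. The same typing pattern in miniature (generic names, not Synapse's servlet helpers):

```python
from typing import Mapping, Optional, Sequence, overload

@overload
def get_int(args: Mapping[str, Sequence[str]], name: str, default: int) -> int: ...
@overload
def get_int(
    args: Mapping[str, Sequence[str]], name: str, default: None = None
) -> Optional[int]: ...

def get_int(
    args: Mapping[str, Sequence[str]], name: str, default: Optional[int] = None
) -> Optional[int]:
    # With an int default, a type checker now knows the result can never be None.
    values = args.get(name)
    return int(values[0]) if values else default

limit: int = get_int({"limit": ["10"]}, "limit", 0)
missing: int = get_int({}, "limit", 0)
```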
View File

@@ -802,8 +802,9 @@ def run_in_background(
deferred returned by the function completes.
To explain how the log contexts work here:
- When this function is called, the current context is stored ("original"), we kick
off the background task, and we restore that original context before returning
- When `run_in_background` is called, the current context is stored ("original"),
we kick off the background task in the current context, and we restore that
original context before returning
- When the background task finishes, we don't want to leak our context into the
reactor which would erroneously get attached to the next operation picked up by
the event loop. We add a callback to the deferred which will clear the logging
@@ -828,6 +829,7 @@ def run_in_background(
"""
calling_context = current_context()
try:
# (kick off the task in the current context)
res = f(*args, **kwargs)
except Exception:
# the assumption here is that the caller doesn't want to be disturbed

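The store/kick-off/restore sequence described above can be mimicked with `contextvars`. This is a toy model only; Synapse's real `LoggingContext` machinery is considerably more involved:

```python
import contextvars
from typing import Callable, List

logcontext = contextvars.ContextVar("logcontext", default="sentinel")
observed: List[str] = []

def run_in_background_sketch(f: Callable[[], None]) -> None:
    # Store the calling ("original") context...
    original = logcontext.get()
    # ...kick off the task in the current context...
    f()
    # ...and restore the original before returning, so whatever context the
    # task left behind doesn't leak back into the caller.
    logcontext.set(original)

logcontext.set("main")

def task() -> None:
    observed.append(logcontext.get())  # runs in the caller's context
    logcontext.set("background-task")

run_in_background_sketch(task)
# observed == ["main"]; the caller's context is "main" again afterwards
```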
View File

@@ -179,11 +179,13 @@ class MediaRepository:
# We get the media upload limits and sort them in descending order of
# time period, so that we can apply some optimizations.
self.media_upload_limits = hs.config.media.media_upload_limits
self.media_upload_limits.sort(
self.default_media_upload_limits = hs.config.media.media_upload_limits
self.default_media_upload_limits.sort(
key=lambda limit: limit.time_period_ms, reverse=True
)
self.media_repository_callbacks = hs.get_module_api_callbacks().media_repository
def _start_update_recently_accessed(self) -> Deferred:
return run_as_background_process(
"update_recently_accessed_media",
@@ -340,16 +342,27 @@ class MediaRepository:
# Check that the user has not exceeded any of the media upload limits.
# Use limits from module API if provided
media_upload_limits = (
await self.media_repository_callbacks.get_media_upload_limits_for_user(
auth_user.to_string()
)
)
# Otherwise use the default limits from config
if media_upload_limits is None:
# Note: the media upload limits are sorted so larger time periods are
# first.
media_upload_limits = self.default_media_upload_limits
# This is the total size of media uploaded by the user in the last
# `time_period_ms` milliseconds, or None if we haven't checked yet.
uploaded_media_size: Optional[int] = None
# Note: the media upload limits are sorted so larger time periods are
# first.
for limit in self.media_upload_limits:
for limit in media_upload_limits:
# We only need to check the amount of media uploaded by the user in
# this latest (smaller) time period if the amount of media uploaded
# in a previous (larger) time period is above the limit.
# in a previous (larger) time period is below the limit.
#
# This optimization means that in the common case where the user
# hasn't uploaded much media, we only need to query the database
@@ -363,6 +376,12 @@ class MediaRepository:
)
if uploaded_media_size + content_length > limit.max_bytes:
await self.media_repository_callbacks.on_media_upload_limit_exceeded(
user_id=auth_user.to_string(),
limit=limit,
sent_bytes=uploaded_media_size,
attempted_bytes=content_length,
)
raise SynapseError(
400, "Media upload limit exceeded", Codes.RESOURCE_LIMIT_EXCEEDED
)

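Usage inside a smaller window can never exceed usage in a larger one, so the total fetched for a wide window can settle a narrower limit without another lookup. A hypothetical sketch of that check loop, with a counted stand-in for the database query:

```python
from dataclasses import dataclass
from typing import Callable, List, Optional

@dataclass(frozen=True)
class MediaUploadLimit:
    max_bytes: int
    time_period_ms: int

def check_limits(
    limits: List[MediaUploadLimit],
    usage_in_window: Callable[[int], int],  # stand-in for the DB query
    content_length: int,
) -> None:
    # Limits are sorted with the largest time period first.
    limits = sorted(limits, key=lambda l: l.time_period_ms, reverse=True)
    uploaded: Optional[int] = None
    for limit in limits:
        if uploaded is None or uploaded + content_length > limit.max_bytes:
            # Only (re-)query when the wider window's total doesn't already
            # prove this narrower limit is satisfied.
            uploaded = usage_in_window(limit.time_period_ms)
        if uploaded + content_length > limit.max_bytes:
            raise RuntimeError("Media upload limit exceeded")

queries: List[int] = []
def usage(period_ms: int) -> int:
    queries.append(period_ms)
    return 10

check_limits(
    [MediaUploadLimit(1_000, 3_600_000), MediaUploadLimit(5_000, 86_400_000)],
    usage,
    content_length=5,
)
# Only the daily (larger) window was queried: queries == [86_400_000]
```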
View File

@@ -286,9 +286,11 @@ def run_as_background_process(
).dec()
# To explain how the log contexts work here:
# - When this function is called, the current context is stored (using
# `PreserveLoggingContext`), we kick off the background task, and we restore the
# original context before returning (also part of `PreserveLoggingContext`).
# - When `run_as_background_process` is called, the current context is stored
# (using `PreserveLoggingContext`), we kick off the background task, and we
# restore the original context before returning (also part of
# `PreserveLoggingContext`).
# - The background task runs in its own new logcontext named after `desc`
# - When the background task finishes, we don't want to leak our background context
# into the reactor which would erroneously get attached to the next operation
# picked up by the event loop. We use `PreserveLoggingContext` to set the

View File

@@ -50,6 +50,7 @@ from synapse.api.constants import ProfileFields
from synapse.api.errors import SynapseError
from synapse.api.presence import UserPresenceState
from synapse.config import ConfigError
from synapse.config.repository import MediaUploadLimit
from synapse.events import EventBase
from synapse.events.presence_router import (
GET_INTERESTED_USERS_CALLBACK,
@@ -94,7 +95,9 @@ from synapse.module_api.callbacks.account_validity_callbacks import (
)
from synapse.module_api.callbacks.media_repository_callbacks import (
GET_MEDIA_CONFIG_FOR_USER_CALLBACK,
GET_MEDIA_UPLOAD_LIMITS_FOR_USER_CALLBACK,
IS_USER_ALLOWED_TO_UPLOAD_MEDIA_OF_SIZE_CALLBACK,
ON_MEDIA_UPLOAD_LIMIT_EXCEEDED_CALLBACK,
)
from synapse.module_api.callbacks.ratelimit_callbacks import (
GET_RATELIMIT_OVERRIDE_FOR_USER_CALLBACK,
@@ -205,6 +208,7 @@ __all__ = [
"RoomAlias",
"UserProfile",
"RatelimitOverride",
"MediaUploadLimit",
]
logger = logging.getLogger(__name__)
@@ -462,6 +466,12 @@ class ModuleApi:
is_user_allowed_to_upload_media_of_size: Optional[
IS_USER_ALLOWED_TO_UPLOAD_MEDIA_OF_SIZE_CALLBACK
] = None,
get_media_upload_limits_for_user: Optional[
GET_MEDIA_UPLOAD_LIMITS_FOR_USER_CALLBACK
] = None,
on_media_upload_limit_exceeded: Optional[
ON_MEDIA_UPLOAD_LIMIT_EXCEEDED_CALLBACK
] = None,
) -> None:
"""Registers callbacks for media repository capabilities.
Added in Synapse v1.132.0.
@@ -469,6 +479,8 @@ class ModuleApi:
return self._callbacks.media_repository.register_callbacks(
get_media_config_for_user=get_media_config_for_user,
is_user_allowed_to_upload_media_of_size=is_user_allowed_to_upload_media_of_size,
get_media_upload_limits_for_user=get_media_upload_limits_for_user,
on_media_upload_limit_exceeded=on_media_upload_limit_exceeded,
)
def register_third_party_rules_callbacks(

View File

@@ -15,6 +15,7 @@
import logging
from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
from synapse.config.repository import MediaUploadLimit
from synapse.types import JsonDict
from synapse.util.async_helpers import delay_cancellation
from synapse.util.metrics import Measure
@@ -28,6 +29,14 @@ GET_MEDIA_CONFIG_FOR_USER_CALLBACK = Callable[[str], Awaitable[Optional[JsonDict
IS_USER_ALLOWED_TO_UPLOAD_MEDIA_OF_SIZE_CALLBACK = Callable[[str, int], Awaitable[bool]]
GET_MEDIA_UPLOAD_LIMITS_FOR_USER_CALLBACK = Callable[
[str], Awaitable[Optional[List[MediaUploadLimit]]]
]
ON_MEDIA_UPLOAD_LIMIT_EXCEEDED_CALLBACK = Callable[
[str, MediaUploadLimit, int, int], Awaitable[None]
]
class MediaRepositoryModuleApiCallbacks:
def __init__(self, hs: "HomeServer") -> None:
@@ -39,6 +48,12 @@ class MediaRepositoryModuleApiCallbacks:
self._is_user_allowed_to_upload_media_of_size_callbacks: List[
IS_USER_ALLOWED_TO_UPLOAD_MEDIA_OF_SIZE_CALLBACK
] = []
self._get_media_upload_limits_for_user_callbacks: List[
GET_MEDIA_UPLOAD_LIMITS_FOR_USER_CALLBACK
] = []
self._on_media_upload_limit_exceeded_callbacks: List[
ON_MEDIA_UPLOAD_LIMIT_EXCEEDED_CALLBACK
] = []
def register_callbacks(
self,
@@ -46,6 +61,12 @@ class MediaRepositoryModuleApiCallbacks:
is_user_allowed_to_upload_media_of_size: Optional[
IS_USER_ALLOWED_TO_UPLOAD_MEDIA_OF_SIZE_CALLBACK
] = None,
get_media_upload_limits_for_user: Optional[
GET_MEDIA_UPLOAD_LIMITS_FOR_USER_CALLBACK
] = None,
on_media_upload_limit_exceeded: Optional[
ON_MEDIA_UPLOAD_LIMIT_EXCEEDED_CALLBACK
] = None,
) -> None:
"""Register callbacks from module for each hook."""
if get_media_config_for_user is not None:
@@ -56,6 +77,16 @@ class MediaRepositoryModuleApiCallbacks:
is_user_allowed_to_upload_media_of_size
)
if get_media_upload_limits_for_user is not None:
self._get_media_upload_limits_for_user_callbacks.append(
get_media_upload_limits_for_user
)
if on_media_upload_limit_exceeded is not None:
self._on_media_upload_limit_exceeded_callbacks.append(
on_media_upload_limit_exceeded
)
async def get_media_config_for_user(self, user_id: str) -> Optional[JsonDict]:
for callback in self._get_media_config_for_user_callbacks:
with Measure(
@@ -83,3 +114,47 @@ class MediaRepositoryModuleApiCallbacks:
return res
return True
async def get_media_upload_limits_for_user(
self, user_id: str
) -> Optional[List[MediaUploadLimit]]:
"""
Get the first non-None list of MediaUploadLimits for the user from the registered callbacks.
If a list is returned it will be sorted in descending order of duration.
"""
for callback in self._get_media_upload_limits_for_user_callbacks:
with Measure(
self.clock,
name=f"{callback.__module__}.{callback.__qualname__}",
server_name=self.server_name,
):
res: Optional[List[MediaUploadLimit]] = await delay_cancellation(
callback(user_id)
)
if res is not None: # to allow [] to be returned meaning no limit
# We sort them in descending order of time period
res.sort(key=lambda limit: limit.time_period_ms, reverse=True)
return res
return None
async def on_media_upload_limit_exceeded(
self,
user_id: str,
limit: MediaUploadLimit,
sent_bytes: int,
attempted_bytes: int,
) -> None:
for callback in self._on_media_upload_limit_exceeded_callbacks:
with Measure(
self.clock,
name=f"{callback.__module__}.{callback.__qualname__}",
server_name=self.server_name,
):
# Use a copy of the data in case the module modifies it
limit_copy = MediaUploadLimit(
max_bytes=limit.max_bytes, time_period_ms=limit.time_period_ms
)
await delay_cancellation(
callback(user_id, limit_copy, sent_bytes, attempted_bytes)
)

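`get_media_upload_limits_for_user` walks the registered callbacks and stops at the first non-`None` result, deliberately treating `[]` as a real answer ("no limits"). The dispatch pattern in isolation, with hypothetical callbacks:

```python
import asyncio
from typing import Awaitable, Callable, List, Optional

Callback = Callable[[str], Awaitable[Optional[List[int]]]]

async def first_non_none(callbacks: List[Callback], user_id: str) -> Optional[List[int]]:
    for callback in callbacks:
        result = await callback(user_id)
        if result is not None:
            # `is not None` (not truthiness): an empty list means "no limit"
            # and must short-circuit the remaining callbacks.
            return result
    return None

async def no_opinion(user_id: str) -> Optional[List[int]]:
    return None

async def unlimited(user_id: str) -> Optional[List[int]]:
    return []

answer = asyncio.run(first_non_none([no_opinion, unlimited], "@alice:example.org"))
# answer == [] — the second callback's "no limit" wins over the default
```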
View File

@@ -532,6 +532,7 @@ class Notifier:
StreamKeyType.TO_DEVICE,
StreamKeyType.TYPING,
StreamKeyType.UN_PARTIAL_STATED_ROOMS,
StreamKeyType.THREAD_SUBSCRIPTIONS,
],
new_token: int,
users: Optional[Collection[Union[str, UserID]]] = None,

View File

@@ -44,6 +44,7 @@ from synapse.replication.tcp.streams import (
UnPartialStatedEventStream,
UnPartialStatedRoomStream,
)
from synapse.replication.tcp.streams._base import ThreadSubscriptionsStream
from synapse.replication.tcp.streams.events import (
EventsStream,
EventsStreamEventRow,
@@ -255,6 +256,12 @@ class ReplicationDataHandler:
self._state_storage_controller.notify_event_un_partial_stated(
row.event_id
)
elif stream_name == ThreadSubscriptionsStream.NAME:
self.notifier.on_new_event(
StreamKeyType.THREAD_SUBSCRIPTIONS,
token,
users=[row.user_id for row in rows],
)
await self._presence_handler.process_replication_rows(
stream_name, instance_name, token, rows

View File

@@ -23,6 +23,8 @@ import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union
import attr
from synapse.api.constants import AccountDataTypes, EduTypes, Membership, PresenceState
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.api.filtering import FilterCollection
@@ -632,12 +634,21 @@ class SyncRestServlet(RestServlet):
class SlidingSyncRestServlet(RestServlet):
"""
API endpoint for MSC3575 Sliding Sync `/sync`. Allows for clients to request a
API endpoint for MSC4186 Simplified Sliding Sync `/sync`, which was historically derived
from MSC3575 (Sliding Sync; now abandoned). Allows for clients to request a
subset (sliding window) of rooms, state, and timeline events (just what they need)
in order to bootstrap quickly and subscribe to only what the client cares about.
Because the client can specify what it cares about, we can respond quickly and skip
all of the work we would normally have to do with a sync v2 response.
Extensions of various features are defined in:
- to-device messaging (MSC3885)
- end-to-end encryption (MSC3884)
- typing notifications (MSC3961)
- receipts (MSC3960)
- account data (MSC3959)
- thread subscriptions (MSC4308)
Request query parameters:
timeout: How long to wait for new events in milliseconds.
pos: Stream position token when asking for incremental deltas.
@@ -1074,9 +1085,48 @@ class SlidingSyncRestServlet(RestServlet):
"rooms": extensions.typing.room_id_to_typing_map,
}
# excludes both None and falsy `thread_subscriptions`
if extensions.thread_subscriptions:
serialized_extensions["io.element.msc4308.thread_subscriptions"] = (
_serialise_thread_subscriptions(extensions.thread_subscriptions)
)
return serialized_extensions
def _serialise_thread_subscriptions(
thread_subscriptions: SlidingSyncResult.Extensions.ThreadSubscriptionsExtension,
) -> JsonDict:
out: JsonDict = {}
if thread_subscriptions.subscribed:
out["subscribed"] = {
room_id: {
thread_root_id: attr.asdict(
change, filter=lambda _attr, v: v is not None
)
for thread_root_id, change in room_threads.items()
}
for room_id, room_threads in thread_subscriptions.subscribed.items()
}
if thread_subscriptions.unsubscribed:
out["unsubscribed"] = {
room_id: {
thread_root_id: attr.asdict(
change, filter=lambda _attr, v: v is not None
)
for thread_root_id, change in room_threads.items()
}
for room_id, room_threads in thread_subscriptions.unsubscribed.items()
}
if thread_subscriptions.prev_batch:
out["prev_batch"] = thread_subscriptions.prev_batch.to_string()
return out
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
SyncRestServlet(hs).register(http_server)

View File

@@ -1,21 +1,39 @@
from http import HTTPStatus
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, Dict, Optional, Tuple
import attr
from typing_extensions import TypeAlias
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
parse_and_validate_json_object_from_request,
parse_integer,
parse_string,
)
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.types import JsonDict, RoomID
from synapse.types import (
JsonDict,
RoomID,
SlidingSyncStreamToken,
ThreadSubscriptionsToken,
)
from synapse.types.handlers.sliding_sync import SlidingSyncResult
from synapse.types.rest import RequestBodyModel
from synapse.util.pydantic_models import AnyEventId
if TYPE_CHECKING:
from synapse.server import HomeServer
_ThreadSubscription: TypeAlias = (
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadSubscription
)
_ThreadUnsubscription: TypeAlias = (
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadUnsubscription
)
class ThreadSubscriptionsRestServlet(RestServlet):
PATTERNS = client_patterns(
@@ -100,6 +118,130 @@ class ThreadSubscriptionsRestServlet(RestServlet):
return HTTPStatus.OK, {}
class ThreadSubscriptionsPaginationRestServlet(RestServlet):
PATTERNS = client_patterns(
"/io.element.msc4308/thread_subscriptions$",
unstable=True,
releases=(),
)
CATEGORY = "Thread Subscriptions requests (unstable)"
# Maximum number of thread subscriptions to return in one request.
MAX_LIMIT = 512
def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth()
self.is_mine = hs.is_mine
self.store = hs.get_datastores().main
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
limit = min(
parse_integer(request, "limit", default=100, negative=False),
ThreadSubscriptionsPaginationRestServlet.MAX_LIMIT,
)
from_end_opt = parse_string(request, "from", required=False)
to_start_opt = parse_string(request, "to", required=False)
_direction = parse_string(request, "dir", required=True, allowed_values=("b",))
if limit <= 0:
# condition needed because `negative=False` still allows 0
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"limit must be greater than 0",
errcode=Codes.INVALID_PARAM,
)
if from_end_opt is not None:
try:
# because of backwards pagination, the `from` token is actually the
# bound closest to the end of the stream
end_stream_id = ThreadSubscriptionsToken.from_string(
from_end_opt
).stream_id
except ValueError:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"`from` is not a valid token",
errcode=Codes.INVALID_PARAM,
)
else:
end_stream_id = self.store.get_max_thread_subscriptions_stream_id()
if to_start_opt is not None:
# because of backwards pagination, the `to` token is actually the
# bound closest to the start of the stream
try:
start_stream_id = ThreadSubscriptionsToken.from_string(
to_start_opt
).stream_id
except ValueError:
# we also accept sliding sync `pos` tokens on this parameter
try:
sliding_sync_pos = await SlidingSyncStreamToken.from_string(
self.store, to_start_opt
)
start_stream_id = (
sliding_sync_pos.stream_token.thread_subscriptions_key
)
except ValueError:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"`to` is not a valid token",
errcode=Codes.INVALID_PARAM,
)
else:
# the start of time is ID 1; the lower bound is exclusive though
start_stream_id = 0
subscriptions = (
await self.store.get_latest_updated_thread_subscriptions_for_user(
requester.user.to_string(),
from_id=start_stream_id,
to_id=end_stream_id,
limit=limit,
)
)
subscribed_threads: Dict[str, Dict[str, JsonDict]] = {}
unsubscribed_threads: Dict[str, Dict[str, JsonDict]] = {}
for stream_id, room_id, thread_root_id, subscribed, automatic in subscriptions:
if subscribed:
subscribed_threads.setdefault(room_id, {})[thread_root_id] = (
attr.asdict(
_ThreadSubscription(
automatic=automatic,
bump_stamp=stream_id,
)
)
)
else:
unsubscribed_threads.setdefault(room_id, {})[thread_root_id] = (
attr.asdict(_ThreadUnsubscription(bump_stamp=stream_id))
)
result: JsonDict = {}
if subscribed_threads:
result["subscribed"] = subscribed_threads
if unsubscribed_threads:
result["unsubscribed"] = unsubscribed_threads
if len(subscriptions) == limit:
# We hit the limit, so there might be more entries to return.
# Generate a new token that has moved backwards, ready for the next
# request.
min_returned_stream_id, _, _, _, _ = subscriptions[0]
result["end"] = ThreadSubscriptionsToken(
# We subtract one because the 'later in the stream' bound is inclusive,
# and we already saw the element at index 0.
stream_id=min_returned_stream_id - 1
).to_string()
return HTTPStatus.OK, result
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
if hs.config.experimental.msc4306_enabled:
ThreadSubscriptionsRestServlet(hs).register(http_server)
ThreadSubscriptionsPaginationRestServlet(hs).register(http_server)

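From a client's perspective, the `end` token (lowest returned `stream_id` minus one) drives backwards pagination until a short page signals completion. A toy walk of that loop over an in-memory stream:

```python
from typing import Callable, List

def paginate_backwards(
    fetch: Callable[[int, int, int], List[int]], end_id: int, limit: int
) -> List[List[int]]:
    """fetch(from_exclusive, to_inclusive, limit) returns ascending stream ids."""
    pages = []
    while True:
        page = fetch(0, end_id, limit)
        pages.append(page)
        if len(page) < limit:
            return pages  # a short page means nothing older remains
        # Next page's inclusive upper bound: one below the lowest id seen.
        end_id = page[0] - 1

# An in-memory stand-in for the store: ids 1..7, latest `limit` of them.
def fetch(from_id: int, to_id: int, limit: int) -> List[int]:
    return [i for i in range(1, 8) if from_id < i <= to_id][-limit:]

pages = paginate_backwards(fetch, end_id=7, limit=3)
# pages == [[5, 6, 7], [2, 3, 4], [1]]
```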
View File

@@ -53,7 +53,7 @@ from synapse.storage.databases.main.stream import (
generate_pagination_where_clause,
)
from synapse.storage.engines import PostgresEngine
from synapse.types import JsonDict, MultiWriterStreamToken, StreamKeyType, StreamToken
from synapse.types import JsonDict, StreamKeyType, StreamToken
from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING:
@@ -316,17 +316,8 @@ class RelationsWorkerStore(SQLBaseStore):
StreamKeyType.ROOM, next_key
)
else:
next_token = StreamToken(
room_key=next_key,
presence_key=0,
typing_key=0,
receipt_key=MultiWriterStreamToken(stream=0),
account_data_key=0,
push_rules_key=0,
to_device_key=0,
device_list_key=MultiWriterStreamToken(stream=0),
groups_key=0,
un_partial_stated_rooms_key=0,
next_token = StreamToken.START.copy_and_replace(
StreamKeyType.ROOM, next_key
)
return events[:limit], next_token

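`StreamToken.START.copy_and_replace` replaces the hand-written zero-filled constructor call. With a frozen dataclass the same idea is `dataclasses.replace` (field names abbreviated, a sketch rather than Synapse's real token type):

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class StreamTokenSketch:
    room_key: int = 0
    typing_key: int = 0
    receipt_key: int = 0

START = StreamTokenSketch()  # the all-zero sentinel

# Override just the stream we advanced, instead of listing every field.
next_token = replace(START, room_key=42)
```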
View File

@@ -492,7 +492,7 @@ class PerConnectionStateDB:
"""An equivalent to `PerConnectionState` that holds data in a format stored
in the DB.
The principle difference is that the tokens for the different streams are
The principal difference is that the tokens for the different streams are
serialized to strings.
When persisting this *only* contains updates to the state.

View File

@@ -505,6 +505,9 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
"""
return self._thread_subscriptions_id_gen.get_current_token()
def get_thread_subscriptions_stream_id_generator(self) -> MultiWriterIdGenerator:
return self._thread_subscriptions_id_gen
async def get_updated_thread_subscriptions(
self, *, from_id: int, to_id: int, limit: int
) -> List[Tuple[int, str, str, str]]:
@@ -538,34 +541,52 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
get_updated_thread_subscriptions_txn,
)
async def get_updated_thread_subscriptions_for_user(
async def get_latest_updated_thread_subscriptions_for_user(
self, user_id: str, *, from_id: int, to_id: int, limit: int
) -> List[Tuple[int, str, str]]:
"""Get updates to thread subscriptions for a specific user.
) -> List[Tuple[int, str, str, bool, Optional[bool]]]:
"""Get the latest updates to thread subscriptions for a specific user.
Args:
user_id: The ID of the user
from_id: The starting stream ID (exclusive)
to_id: The ending stream ID (inclusive)
limit: The maximum number of rows to return
If there are too many rows to return, rows from the start (closer to `from_id`)
will be omitted.
Returns:
A list of (stream_id, room_id, thread_root_event_id) tuples.
A list of (stream_id, room_id, thread_root_event_id, subscribed, automatic) tuples.
The row with lowest `stream_id` is the first row.
"""
def get_updated_thread_subscriptions_for_user_txn(
txn: LoggingTransaction,
) -> List[Tuple[int, str, str]]:
) -> List[Tuple[int, str, str, bool, Optional[bool]]]:
sql = """
SELECT stream_id, room_id, event_id
FROM thread_subscriptions
WHERE user_id = ? AND ? < stream_id AND stream_id <= ?
WITH the_updates AS (
SELECT stream_id, room_id, event_id, subscribed, automatic
FROM thread_subscriptions
WHERE user_id = ? AND ? < stream_id AND stream_id <= ?
ORDER BY stream_id DESC
LIMIT ?
)
SELECT stream_id, room_id, event_id, subscribed, automatic
FROM the_updates
ORDER BY stream_id ASC
LIMIT ?
"""
txn.execute(sql, (user_id, from_id, to_id, limit, limit))
return [(row[0], row[1], row[2]) for row in txn]
return [
(
stream_id,
room_id,
event_id,
# SQLite integer to boolean conversions
bool(subscribed),
bool(automatic) if subscribed else None,
)
for (stream_id, room_id, event_id, subscribed, automatic) in txn
]
return await self.db_pool.runInteraction(
"get_updated_thread_subscriptions_for_user",


@@ -0,0 +1,19 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2025 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
-- Work around https://github.com/element-hq/synapse/issues/18712 by advancing the
-- stream sequence.
-- This makes last_value of the sequence point to a position that will not later be
-- returned by nextval.
-- (For blank thread subscription streams, this means last_value = 2, nextval() = 3 after this line.)
SELECT nextval('thread_subscriptions_sequence');
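The behaviour this migration relies on can be modelled with a toy Postgres-style sequence (a hedged sketch; `Seq` is a hypothetical stand-in, not Synapse code): the first `nextval` call on a fresh sequence returns `last_value` itself rather than incrementing it, which is why consuming one value leaves `last_value = 2` and `nextval() = 3`.

```python
class Seq:
    """Toy model of Postgres sequence semantics (last_value / is_called)."""

    def __init__(self, start: int = 1) -> None:
        self.last_value = start
        self.is_called = False  # first nextval returns `start` itself

    def nextval(self) -> int:
        if self.is_called:
            self.last_value += 1
        self.is_called = True
        return self.last_value

# A sequence created with START WITH 2 reports last_value = 2, yet the
# first nextval() *also* returns 2 -- the overlap behind issue 18712.
seq = Seq(start=2)
assert seq.last_value == 2
assert seq.nextval() == 2
# After consuming that value (as this migration does), nextval() moves on:
assert seq.nextval() == 3
```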


@@ -187,8 +187,12 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
Warning: Streams using this generator start at ID 2, because ID 1 is always assumed
to have been 'seen as persisted'.
Unclear if this extant behaviour is desirable for some reason.
When creating a new sequence for a new stream,
it will be necessary to use `START WITH 2`.
When creating a new sequence for a new stream, it will be necessary to advance it
so that position 1 is consumed.
DO NOT USE `START WITH 2` FOR THIS PURPOSE:
see https://github.com/element-hq/synapse/issues/18712
Instead, use `SELECT nextval('sequence_name');` immediately after the
`CREATE SEQUENCE` statement.
Args:
db_conn


@@ -33,7 +33,6 @@ from synapse.logging.opentracing import trace
from synapse.streams import EventSource
from synapse.types import (
AbstractMultiWriterStreamToken,
MultiWriterStreamToken,
StreamKeyType,
StreamToken,
)
@@ -84,6 +83,7 @@ class EventSources:
un_partial_stated_rooms_key = self.store.get_un_partial_stated_rooms_token(
self._instance_name
)
thread_subscriptions_key = self.store.get_max_thread_subscriptions_stream_id()
token = StreamToken(
room_key=self.sources.room.get_current_key(),
@@ -97,6 +97,7 @@ class EventSources:
# Groups key is unused.
groups_key=0,
un_partial_stated_rooms_key=un_partial_stated_rooms_key,
thread_subscriptions_key=thread_subscriptions_key,
)
return token
@@ -123,6 +124,7 @@ class EventSources:
StreamKeyType.TO_DEVICE: self.store.get_to_device_id_generator(),
StreamKeyType.DEVICE_LIST: self.store.get_device_stream_id_generator(),
StreamKeyType.UN_PARTIAL_STATED_ROOMS: self.store.get_un_partial_stated_rooms_id_generator(),
StreamKeyType.THREAD_SUBSCRIPTIONS: self.store.get_thread_subscriptions_stream_id_generator(),
}
for _, key in StreamKeyType.__members__.items():
@@ -195,16 +197,7 @@ class EventSources:
Returns:
The current token for pagination.
"""
token = StreamToken(
room_key=await self.sources.room.get_current_key_for_room(room_id),
presence_key=0,
typing_key=0,
receipt_key=MultiWriterStreamToken(stream=0),
account_data_key=0,
push_rules_key=0,
to_device_key=0,
device_list_key=MultiWriterStreamToken(stream=0),
groups_key=0,
un_partial_stated_rooms_key=0,
return StreamToken.START.copy_and_replace(
StreamKeyType.ROOM,
await self.sources.room.get_current_key_for_room(room_id),
)
return token


@@ -996,6 +996,7 @@ class StreamKeyType(Enum):
TO_DEVICE = "to_device_key"
DEVICE_LIST = "device_list_key"
UN_PARTIAL_STATED_ROOMS = "un_partial_stated_rooms_key"
THREAD_SUBSCRIPTIONS = "thread_subscriptions_key"
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -1003,7 +1004,7 @@ class StreamToken:
"""A collection of keys joined together by underscores in the following
order and which represent the position in their respective streams.
ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1_379`
ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1_379_4242`
1. `room_key`: `s2633508` which is a `RoomStreamToken`
- `RoomStreamToken`'s can also look like `t426-2633508` or `m56~2.58~3.59`
- See the docstring for `RoomStreamToken` for more details.
@@ -1016,6 +1017,7 @@ class StreamToken:
8. `device_list_key`: `265584`
9. `groups_key`: `1` (note that this key is now unused)
10. `un_partial_stated_rooms_key`: `379`
11. `thread_subscriptions_key`: `4242`
You can see how many of these keys correspond to the various
fields in a "/sync" response:
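The serialized form can be unpacked by splitting on the separator; a quick sketch using the example token from the docstring above (field positions per the numbered list):

```python
# Example token from the docstring above, now with eleven keys:
token = "s2633508_17_338_6732159_1082514_541479_274711_265584_1_379_4242"
keys = token.split("_")
assert len(keys) == 11  # thread_subscriptions_key is the newly appended key

room_key = keys[0]                        # "s2633508", a RoomStreamToken
thread_subscriptions_key = int(keys[-1])  # 4242
print(room_key, thread_subscriptions_key)
# → s2633508 4242
```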
@@ -1074,6 +1076,7 @@ class StreamToken:
# Note that the groups key is no longer used and may have bogus values.
groups_key: int
un_partial_stated_rooms_key: int
thread_subscriptions_key: int
_SEPARATOR = "_"
START: ClassVar["StreamToken"]
@@ -1101,6 +1104,7 @@ class StreamToken:
device_list_key,
groups_key,
un_partial_stated_rooms_key,
thread_subscriptions_key,
) = keys
return cls(
@@ -1116,6 +1120,7 @@ class StreamToken:
),
groups_key=int(groups_key),
un_partial_stated_rooms_key=int(un_partial_stated_rooms_key),
thread_subscriptions_key=int(thread_subscriptions_key),
)
except CancelledError:
raise
@@ -1138,6 +1143,7 @@ class StreamToken:
# if additional tokens are added.
str(self.groups_key),
str(self.un_partial_stated_rooms_key),
str(self.thread_subscriptions_key),
]
)
@@ -1202,6 +1208,7 @@ class StreamToken:
StreamKeyType.TO_DEVICE,
StreamKeyType.TYPING,
StreamKeyType.UN_PARTIAL_STATED_ROOMS,
StreamKeyType.THREAD_SUBSCRIPTIONS,
],
) -> int: ...
@@ -1257,7 +1264,8 @@ class StreamToken:
f"typing: {self.typing_key}, receipt: {self.receipt_key}, "
f"account_data: {self.account_data_key}, push_rules: {self.push_rules_key}, "
f"to_device: {self.to_device_key}, device_list: {self.device_list_key}, "
f"groups: {self.groups_key}, un_partial_stated_rooms: {self.un_partial_stated_rooms_key})"
f"groups: {self.groups_key}, un_partial_stated_rooms: {self.un_partial_stated_rooms_key},"
f"thread_subscriptions: {self.thread_subscriptions_key})"
)
@@ -1272,6 +1280,7 @@ StreamToken.START = StreamToken(
device_list_key=MultiWriterStreamToken(stream=0),
groups_key=0,
un_partial_stated_rooms_key=0,
thread_subscriptions_key=0,
)
@@ -1318,6 +1327,27 @@ class SlidingSyncStreamToken:
return f"{self.connection_position}/{stream_token_str}"
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ThreadSubscriptionsToken:
"""
Token for a position in the thread subscriptions stream.
Format: `ts<stream_id>`
"""
stream_id: int
@staticmethod
def from_string(s: str) -> "ThreadSubscriptionsToken":
if not s.startswith("ts"):
raise ValueError("thread subscription token must start with `ts`")
return ThreadSubscriptionsToken(stream_id=int(s[2:]))
def to_string(self) -> str:
return f"ts{self.stream_id}"
@attr.s(slots=True, frozen=True, auto_attribs=True)
class PersistedPosition:
"""Position of a newly persisted row with instance that persisted it."""


@@ -50,6 +50,7 @@ from synapse.types import (
SlidingSyncStreamToken,
StrCollection,
StreamToken,
ThreadSubscriptionsToken,
UserID,
)
from synapse.types.rest.client import SlidingSyncBody
@@ -357,11 +358,50 @@ class SlidingSyncResult:
def __bool__(self) -> bool:
return bool(self.room_id_to_typing_map)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ThreadSubscriptionsExtension:
"""The Thread Subscriptions extension (MSC4308)
Attributes:
subscribed: map (room_id -> thread_root_id -> info) of new or changed subscriptions
unsubscribed: map (room_id -> thread_root_id -> info) of new unsubscriptions
prev_batch: if present, there is a gap and the client can use this token to backpaginate
"""
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ThreadSubscription:
# always present when `subscribed`
automatic: Optional[bool]
# the same as our stream_id; useful for clients to resolve
# race conditions locally
bump_stamp: int
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ThreadUnsubscription:
# the same as our stream_id; useful for clients to resolve
# race conditions locally
bump_stamp: int
# room_id -> event_id (of thread root) -> the subscription change
subscribed: Optional[Mapping[str, Mapping[str, ThreadSubscription]]]
# room_id -> event_id (of thread root) -> the unsubscription
unsubscribed: Optional[Mapping[str, Mapping[str, ThreadUnsubscription]]]
prev_batch: Optional[ThreadSubscriptionsToken]
def __bool__(self) -> bool:
return (
bool(self.subscribed)
or bool(self.unsubscribed)
or bool(self.prev_batch)
)
to_device: Optional[ToDeviceExtension] = None
e2ee: Optional[E2eeExtension] = None
account_data: Optional[AccountDataExtension] = None
receipts: Optional[ReceiptsExtension] = None
typing: Optional[TypingExtension] = None
thread_subscriptions: Optional[ThreadSubscriptionsExtension] = None
def __bool__(self) -> bool:
return bool(
@@ -370,6 +410,7 @@ class SlidingSyncResult:
or self.account_data
or self.receipts
or self.typing
or self.thread_subscriptions
)
next_pos: SlidingSyncStreamToken


@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from synapse._pydantic_compat import (
Extra,
Field,
StrictBool,
StrictInt,
StrictStr,
@@ -364,11 +365,25 @@ class SlidingSyncBody(RequestBodyModel):
# Process all room subscriptions defined in the Room Subscription API. (This is the default.)
rooms: Optional[List[StrictStr]] = ["*"]
class ThreadSubscriptionsExtension(RequestBodyModel):
"""The Thread Subscriptions extension (MSC4308)
Attributes:
enabled
limit: maximum number of subscription changes to return (default 100)
"""
enabled: Optional[StrictBool] = False
limit: StrictInt = 100
to_device: Optional[ToDeviceExtension] = None
e2ee: Optional[E2eeExtension] = None
account_data: Optional[AccountDataExtension] = None
receipts: Optional[ReceiptsExtension] = None
typing: Optional[TypingExtension] = None
thread_subscriptions: Optional[ThreadSubscriptionsExtension] = Field(
alias="io.element.msc4308.thread_subscriptions"
)
conn_id: Optional[StrictStr]


@@ -347,6 +347,7 @@ T2 = TypeVar("T2")
T3 = TypeVar("T3")
T4 = TypeVar("T4")
T5 = TypeVar("T5")
T6 = TypeVar("T6")
@overload
@@ -461,6 +462,23 @@ async def gather_optional_coroutines(
) -> Tuple[Optional[T1], Optional[T2], Optional[T3], Optional[T4], Optional[T5]]: ...
@overload
async def gather_optional_coroutines(
*coroutines: Unpack[
Tuple[
Optional[Coroutine[Any, Any, T1]],
Optional[Coroutine[Any, Any, T2]],
Optional[Coroutine[Any, Any, T3]],
Optional[Coroutine[Any, Any, T4]],
Optional[Coroutine[Any, Any, T5]],
Optional[Coroutine[Any, Any, T6]],
]
],
) -> Tuple[
Optional[T1], Optional[T2], Optional[T3], Optional[T4], Optional[T5], Optional[T6]
]: ...
async def gather_optional_coroutines(
*coroutines: Unpack[Tuple[Optional[Coroutine[Any, Any, T1]], ...]],
) -> Tuple[Optional[T1], ...]:


@@ -29,6 +29,11 @@ import sys
from types import FrameType, TracebackType
from typing import NoReturn, Optional, Type
from synapse.logging.context import (
LoggingContext,
PreserveLoggingContext,
)
def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -> None:
"""daemonize the current process
@@ -64,8 +69,14 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -
pid_fh.write(old_pid)
sys.exit(1)
# Fork, creating a new process for the child.
process_id = os.fork()
# Stop the existing context *before* we fork the process. Otherwise the cputime
# metrics get confused about the per-thread resource usage appearing to go backwards
# because we're comparing the resource usage from the original process to the forked
# process. `PreserveLoggingContext` already takes care of restarting the original
# context *after* the block.
with PreserveLoggingContext():
# Fork, creating a new process for the child.
process_id = os.fork()
if process_id != 0:
# parent process: exit.
@@ -140,9 +151,10 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -
# Cleanup pid file at exit.
def exit() -> None:
logger.warning("Stopping daemon.")
os.remove(pid_file)
sys.exit(0)
with LoggingContext("atexit"):
logger.warning("Stopping daemon.")
os.remove(pid_file)
sys.exit(0)
atexit.register(exit)


@@ -37,4 +37,7 @@ class HomeserverAppStartTestCase(ConfigFileTestCase):
self.add_lines_to_config([" main:", " host: 127.0.0.1", " port: 1234"])
# Ensure that starting master process with worker config raises an exception
with self.assertRaises(ConfigError):
synapse.app.homeserver.setup(["-c", self.config_file])
homeserver_config = synapse.app.homeserver.load_or_generate_config(
["-c", self.config_file]
)
synapse.app.homeserver.setup(homeserver_config)


@@ -112,4 +112,7 @@ class RegistrationConfigTestCase(ConfigFileTestCase):
# Test that allowing open registration without verification raises an error
with self.assertRaises(ConfigError):
synapse.app.homeserver.setup(["-c", self.config_file])
homeserver_config = synapse.app.homeserver.load_or_generate_config(
["-c", self.config_file]
)
synapse.app.homeserver.setup(homeserver_config)


@@ -2244,7 +2244,7 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase):
def test_topo_token_is_accepted(self) -> None:
"""Test Topo Token is accepted."""
token = "t1-0_0_0_0_0_0_0_0_0_0"
token = "t1-0_0_0_0_0_0_0_0_0_0_0"
channel = self.make_request(
"GET",
"/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token),
@@ -2258,7 +2258,7 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase):
def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None:
"""Test that stream token is accepted for forward pagination."""
token = "s0_0_0_0_0_0_0_0_0_0"
token = "s0_0_0_0_0_0_0_0_0_0_0"
channel = self.make_request(
"GET",
"/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token),


@@ -0,0 +1,497 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2025 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
import logging
from http import HTTPStatus
from typing import List, Optional, Tuple, cast
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.rest.client import login, room, sync, thread_subscriptions
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase
logger = logging.getLogger(__name__)
# The name of the extension. Currently unstable-prefixed.
EXT_NAME = "io.element.msc4308.thread_subscriptions"
class SlidingSyncThreadSubscriptionsExtensionTestCase(SlidingSyncBase):
"""
Test the thread subscriptions extension in the Sliding Sync API.
"""
maxDiff = None
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
room.register_servlets,
sync.register_servlets,
thread_subscriptions.register_servlets,
]
def default_config(self) -> JsonDict:
config = super().default_config()
config["experimental_features"] = {"msc4306_enabled": True}
return config
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.storage_controllers = hs.get_storage_controllers()
super().prepare(reactor, clock, hs)
def test_no_data_initial_sync(self) -> None:
"""
Test enabling thread subscriptions extension during initial sync with no data.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
sync_body = {
"lists": {},
"extensions": {
EXT_NAME: {
"enabled": True,
}
},
}
# Sync
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
# Assert
self.assertNotIn(EXT_NAME, response_body["extensions"])
def test_no_data_incremental_sync(self) -> None:
"""
Test enabling thread subscriptions extension during incremental sync with no data.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
initial_sync_body: JsonDict = {
"lists": {},
}
# Initial sync
response_body, sync_pos = self.do_sync(initial_sync_body, tok=user1_tok)
# Incremental sync with extension enabled
sync_body = {
"lists": {},
"extensions": {
EXT_NAME: {
"enabled": True,
}
},
}
response_body, _ = self.do_sync(sync_body, tok=user1_tok, since=sync_pos)
# Assert
self.assertNotIn(
EXT_NAME,
response_body["extensions"],
response_body,
)
def test_thread_subscription_initial_sync(self) -> None:
"""
Test thread subscriptions appear in initial sync response.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
thread_root_resp = self.helper.send(room_id, body="Thread root", tok=user1_tok)
thread_root_id = thread_root_resp["event_id"]
# get the baseline stream_id of the thread_subscriptions stream
# before we write any data.
# Required because the initial value differs between SQLite and Postgres.
base = self.store.get_max_thread_subscriptions_stream_id()
self._subscribe_to_thread(user1_id, room_id, thread_root_id)
sync_body = {
"lists": {},
"extensions": {
EXT_NAME: {
"enabled": True,
}
},
}
# Sync
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
# Assert
self.assertEqual(
response_body["extensions"][EXT_NAME],
{
"subscribed": {
room_id: {
thread_root_id: {
"automatic": False,
"bump_stamp": base + 1,
}
}
}
},
)
def test_thread_subscription_incremental_sync(self) -> None:
"""
Test new thread subscriptions appear in incremental sync response.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
sync_body = {
"lists": {},
"extensions": {
EXT_NAME: {
"enabled": True,
}
},
}
thread_root_resp = self.helper.send(room_id, body="Thread root", tok=user1_tok)
thread_root_id = thread_root_resp["event_id"]
# get the baseline stream_id of the thread_subscriptions stream
# before we write any data.
# Required because the initial value differs between SQLite and Postgres.
base = self.store.get_max_thread_subscriptions_stream_id()
# Initial sync
_, sync_pos = self.do_sync(sync_body, tok=user1_tok)
logger.info("Synced to: %r, now subscribing to thread", sync_pos)
# Subscribe
self._subscribe_to_thread(user1_id, room_id, thread_root_id)
# Incremental sync
response_body, sync_pos = self.do_sync(sync_body, tok=user1_tok, since=sync_pos)
logger.info("Synced to: %r", sync_pos)
# Assert
self.assertEqual(
response_body["extensions"][EXT_NAME],
{
"subscribed": {
room_id: {
thread_root_id: {
"automatic": False,
"bump_stamp": base + 1,
}
}
}
},
)
def test_unsubscribe_from_thread(self) -> None:
"""
Test unsubscribing from a thread.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
thread_root_resp = self.helper.send(room_id, body="Thread root", tok=user1_tok)
thread_root_id = thread_root_resp["event_id"]
# get the baseline stream_id of the thread_subscriptions stream
# before we write any data.
# Required because the initial value differs between SQLite and Postgres.
base = self.store.get_max_thread_subscriptions_stream_id()
self._subscribe_to_thread(user1_id, room_id, thread_root_id)
sync_body = {
"lists": {},
"extensions": {
EXT_NAME: {
"enabled": True,
}
},
}
response_body, sync_pos = self.do_sync(sync_body, tok=user1_tok)
# Assert: Subscription present
self.assertIn(EXT_NAME, response_body["extensions"])
self.assertEqual(
response_body["extensions"][EXT_NAME],
{
"subscribed": {
room_id: {
thread_root_id: {"automatic": False, "bump_stamp": base + 1}
}
}
},
)
# Unsubscribe
self._unsubscribe_from_thread(user1_id, room_id, thread_root_id)
# Incremental sync
response_body, sync_pos = self.do_sync(sync_body, tok=user1_tok, since=sync_pos)
# Assert: Unsubscription present
self.assertEqual(
response_body["extensions"][EXT_NAME],
{"unsubscribed": {room_id: {thread_root_id: {"bump_stamp": base + 2}}}},
)
def test_multiple_thread_subscriptions(self) -> None:
"""
Test handling of multiple thread subscriptions.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
# Create thread roots
thread_root_resp1 = self.helper.send(
room_id, body="Thread root 1", tok=user1_tok
)
thread_root_id1 = thread_root_resp1["event_id"]
thread_root_resp2 = self.helper.send(
room_id, body="Thread root 2", tok=user1_tok
)
thread_root_id2 = thread_root_resp2["event_id"]
thread_root_resp3 = self.helper.send(
room_id, body="Thread root 3", tok=user1_tok
)
thread_root_id3 = thread_root_resp3["event_id"]
# get the baseline stream_id of the thread_subscriptions stream
# before we write any data.
# Required because the initial value differs between SQLite and Postgres.
base = self.store.get_max_thread_subscriptions_stream_id()
# Subscribe to threads
self._subscribe_to_thread(user1_id, room_id, thread_root_id1)
self._subscribe_to_thread(user1_id, room_id, thread_root_id2)
self._subscribe_to_thread(user1_id, room_id, thread_root_id3)
sync_body = {
"lists": {},
"extensions": {
EXT_NAME: {
"enabled": True,
}
},
}
# Sync
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
# Assert
self.assertEqual(
response_body["extensions"][EXT_NAME],
{
"subscribed": {
room_id: {
thread_root_id1: {
"automatic": False,
"bump_stamp": base + 1,
},
thread_root_id2: {
"automatic": False,
"bump_stamp": base + 2,
},
thread_root_id3: {
"automatic": False,
"bump_stamp": base + 3,
},
}
}
},
)
def test_limit_parameter(self) -> None:
"""
Test limit parameter in thread subscriptions extension.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
# Create 5 thread roots and subscribe to each
thread_root_ids = []
for i in range(5):
thread_root_resp = self.helper.send(
room_id, body=f"Thread root {i}", tok=user1_tok
)
thread_root_ids.append(thread_root_resp["event_id"])
self._subscribe_to_thread(user1_id, room_id, thread_root_ids[-1])
sync_body = {
"lists": {},
"extensions": {EXT_NAME: {"enabled": True, "limit": 3}},
}
# Sync
response_body, _ = self.do_sync(sync_body, tok=user1_tok)
# Assert
thread_subscriptions = response_body["extensions"][EXT_NAME]
self.assertEqual(
len(thread_subscriptions["subscribed"][room_id]), 3, thread_subscriptions
)
def test_limit_and_companion_backpagination(self) -> None:
"""
Create 1 thread subscription, do a sync, create 5 more,
then sync with a limit of 2 and fill in the gap
using the companion /thread_subscriptions endpoint.
"""
thread_root_ids: List[str] = []
def make_subscription() -> None:
thread_root_resp = self.helper.send(
room_id, body="Some thread root", tok=user1_tok
)
thread_root_ids.append(thread_root_resp["event_id"])
self._subscribe_to_thread(user1_id, room_id, thread_root_ids[-1])
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
# get the baseline stream_id of the thread_subscriptions stream
# before we write any data.
# Required because the initial value differs between SQLite and Postgres.
base = self.store.get_max_thread_subscriptions_stream_id()
# Make our first subscription
make_subscription()
# Sync for the first time
sync_body = {
"lists": {},
"extensions": {EXT_NAME: {"enabled": True, "limit": 2}},
}
sync_resp, first_sync_pos = self.do_sync(sync_body, tok=user1_tok)
thread_subscriptions = sync_resp["extensions"][EXT_NAME]
self.assertEqual(
thread_subscriptions["subscribed"],
{
room_id: {
thread_root_ids[0]: {"automatic": False, "bump_stamp": base + 1},
}
},
)
# Get our pos for the next sync
first_sync_pos = sync_resp["pos"]
# Create 5 more thread subscriptions and subscribe to each
for _ in range(5):
make_subscription()
# Now sync again. Our limit is 2,
# so we should get the latest 2 subscriptions,
# with a gap of 3 more subscriptions in the middle
sync_resp, _pos = self.do_sync(sync_body, tok=user1_tok, since=first_sync_pos)
thread_subscriptions = sync_resp["extensions"][EXT_NAME]
self.assertEqual(
thread_subscriptions["subscribed"],
{
room_id: {
thread_root_ids[4]: {"automatic": False, "bump_stamp": base + 5},
thread_root_ids[5]: {"automatic": False, "bump_stamp": base + 6},
}
},
)
# 1st backpagination: expecting a page with 2 subscriptions
page, end_tok = self._do_backpaginate(
from_tok=thread_subscriptions["prev_batch"],
to_tok=first_sync_pos,
limit=2,
access_token=user1_tok,
)
self.assertIsNotNone(end_tok, "backpagination should continue")
self.assertEqual(
page["subscribed"],
{
room_id: {
thread_root_ids[2]: {"automatic": False, "bump_stamp": base + 3},
thread_root_ids[3]: {"automatic": False, "bump_stamp": base + 4},
}
},
)
# 2nd backpagination: expecting a page with only 1 subscription
# and no other token for further backpagination
assert end_tok is not None
page, end_tok = self._do_backpaginate(
from_tok=end_tok, to_tok=first_sync_pos, limit=2, access_token=user1_tok
)
self.assertIsNone(end_tok, "backpagination should have finished")
self.assertEqual(
page["subscribed"],
{
room_id: {
thread_root_ids[1]: {"automatic": False, "bump_stamp": base + 2},
}
},
)
def _do_backpaginate(
self, *, from_tok: str, to_tok: str, limit: int, access_token: str
) -> Tuple[JsonDict, Optional[str]]:
channel = self.make_request(
"GET",
"/_matrix/client/unstable/io.element.msc4308/thread_subscriptions"
f"?from={from_tok}&to={to_tok}&limit={limit}&dir=b",
access_token=access_token,
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
body = channel.json_body
return body, cast(Optional[str], body.get("end"))
def _subscribe_to_thread(
self, user_id: str, room_id: str, thread_root_id: str
) -> None:
"""
Helper method to subscribe a user to a thread.
"""
self.get_success(
self.store.subscribe_user_to_thread(
user_id=user_id,
room_id=room_id,
thread_root_event_id=thread_root_id,
automatic_event_orderings=None,
)
)
def _unsubscribe_from_thread(
self, user_id: str, room_id: str, thread_root_id: str
) -> None:
"""
Helper method to unsubscribe a user from a thread.
"""
self.get_success(
self.store.unsubscribe_user_from_thread(
user_id=user_id,
room_id=room_id,
thread_root_event_id=thread_root_id,
)
)


@@ -46,6 +46,7 @@ from twisted.web.resource import Resource
from synapse.api.errors import HttpResponseException
from synapse.api.ratelimiting import Ratelimiter
from synapse.config._base import Config
from synapse.config.oembed import OEmbedEndpointConfig
from synapse.http.client import MultipartResponse
from synapse.http.types import QueryParams
@@ -53,6 +54,7 @@ from synapse.logging.context import make_deferred_yieldable
from synapse.media._base import FileInfo, ThumbnailInfo
from synapse.media.thumbnailer import ThumbnailProvider
from synapse.media.url_previewer import IMAGE_CACHE_EXPIRY_MS
from synapse.module_api import MediaUploadLimit
from synapse.rest import admin
from synapse.rest.client import login, media
from synapse.server import HomeServer
@@ -2967,3 +2969,192 @@ class MediaUploadLimits(unittest.HomeserverTestCase):
# This will succeed as the weekly limit has reset
channel = self.upload_media(900)
self.assertEqual(channel.code, 200)
class MediaUploadLimitsModuleOverrides(unittest.HomeserverTestCase):
"""
This test case simulates a homeserver with media upload limits being overridden by the module API.
"""
servlets = [
media.register_servlets,
login.register_servlets,
admin.register_servlets,
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
self.storage_path = self.mktemp()
self.media_store_path = self.mktemp()
os.mkdir(self.storage_path)
os.mkdir(self.media_store_path)
config["media_store_path"] = self.media_store_path
provider_config = {
"module": "synapse.media.storage_provider.FileStorageProviderBackend",
"store_local": True,
"store_synchronous": False,
"store_remote": True,
"config": {"directory": self.storage_path},
}
config["media_storage_providers"] = [provider_config]
# default limits to use
config["media_upload_limits"] = [
{"time_period": "1d", "max_size": "1K"},
{"time_period": "1w", "max_size": "3K"},
]
return self.setup_test_homeserver(config=config)
async def _get_media_upload_limits_for_user(
self,
user_id: str,
) -> Optional[List[MediaUploadLimit]]:
# user1 has custom limits
if user_id == self.user1:
# n.b. we return these in increasing duration order and Synapse will need to sort them correctly
return [
MediaUploadLimit(
time_period_ms=Config.parse_duration("1d"), max_bytes=5000
),
MediaUploadLimit(
time_period_ms=Config.parse_duration("1w"), max_bytes=15000
),
]
# user2 has no limits
if user_id == self.user2:
return []
# otherwise use default
return None
async def _on_media_upload_limit_exceeded(
self,
user_id: str,
limit: MediaUploadLimit,
sent_bytes: int,
attempted_bytes: int,
) -> None:
self.last_media_upload_limit_exceeded: Optional[dict[str, object]] = {
"user_id": user_id,
"limit": limit,
"sent_bytes": sent_bytes,
"attempted_bytes": attempted_bytes,
}
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.repo = hs.get_media_repository()
self.client = hs.get_federation_http_client()
self.store = hs.get_datastores().main
self.user1 = self.register_user("user1", "pass")
self.tok1 = self.login("user1", "pass")
self.user2 = self.register_user("user2", "pass")
self.tok2 = self.login("user2", "pass")
self.user3 = self.register_user("user3", "pass")
self.tok3 = self.login("user3", "pass")
self.last_media_upload_limit_exceeded = None
self.hs.get_module_api().register_media_repository_callbacks(
get_media_upload_limits_for_user=self._get_media_upload_limits_for_user,
on_media_upload_limit_exceeded=self._on_media_upload_limit_exceeded,
)
def create_resource_dict(self) -> Dict[str, Resource]:
resources = super().create_resource_dict()
resources["/_matrix/media"] = self.hs.get_media_repository_resource()
return resources
def upload_media(self, size: int, tok: str) -> FakeChannel:
"""Helper to upload media of a given size with a given token."""
return self.make_request(
"POST",
"/_matrix/media/v3/upload",
content=b"0" * size,
access_token=tok,
shorthand=False,
content_type=b"text/plain",
custom_headers=[("Content-Length", str(size))],
)
def test_upload_under_limit(self) -> None:
"""Test that uploading media under the limit works."""
# User 1 uploads 100 bytes
channel = self.upload_media(100, self.tok1)
self.assertEqual(channel.code, 200)
# User 2 (unlimited) uploads 100 bytes
channel = self.upload_media(100, self.tok2)
self.assertEqual(channel.code, 200)
# User 3 (default) uploads 100 bytes
channel = self.upload_media(100, self.tok3)
self.assertEqual(channel.code, 200)
self.assertEqual(self.last_media_upload_limit_exceeded, None)
def test_uses_custom_limit(self) -> None:
"""Test that uploading media over the module provided daily limit fails."""
# User 1 uploads 3000 bytes
channel = self.upload_media(3000, self.tok1)
self.assertEqual(channel.code, 200)
# User 1 attempts to upload 4000 bytes taking it over the limit
channel = self.upload_media(4000, self.tok1)
self.assertEqual(channel.code, 400)
assert self.last_media_upload_limit_exceeded is not None
self.assertEqual(self.last_media_upload_limit_exceeded["user_id"], self.user1)
self.assertEqual(
self.last_media_upload_limit_exceeded["limit"],
MediaUploadLimit(
max_bytes=5000, time_period_ms=Config.parse_duration("1d")
),
)
self.assertEqual(self.last_media_upload_limit_exceeded["sent_bytes"], 3000)
self.assertEqual(self.last_media_upload_limit_exceeded["attempted_bytes"], 4000)
# User 1 attempts to upload 20000 bytes which is over the weekly limit
# This tests that the limits have been sorted as expected
channel = self.upload_media(20000, self.tok1)
self.assertEqual(channel.code, 400)
assert self.last_media_upload_limit_exceeded is not None
self.assertEqual(self.last_media_upload_limit_exceeded["user_id"], self.user1)
self.assertEqual(
self.last_media_upload_limit_exceeded["limit"],
MediaUploadLimit(
max_bytes=15000, time_period_ms=Config.parse_duration("1w")
),
)
self.assertEqual(self.last_media_upload_limit_exceeded["sent_bytes"], 3000)
self.assertEqual(
self.last_media_upload_limit_exceeded["attempted_bytes"], 20000
)
def test_uses_unlimited(self) -> None:
"""Test that unlimited user is not limited when module returns []."""
# User 2 uploads 10000 bytes which is over the default limit
channel = self.upload_media(10000, self.tok2)
self.assertEqual(channel.code, 200)
self.assertEqual(self.last_media_upload_limit_exceeded, None)
def test_uses_defaults(self) -> None:
"""Test that the default limits are applied when module returned None."""
# User 3 uploads 500 bytes
channel = self.upload_media(500, self.tok3)
self.assertEqual(channel.code, 200)
# User 3 uploads 800 bytes which is over the limit
channel = self.upload_media(800, self.tok3)
self.assertEqual(channel.code, 400)
assert self.last_media_upload_limit_exceeded is not None
self.assertEqual(self.last_media_upload_limit_exceeded["user_id"], self.user3)
self.assertEqual(
self.last_media_upload_limit_exceeded["limit"],
MediaUploadLimit(
max_bytes=1024, time_period_ms=Config.parse_duration("1d")
),
)
self.assertEqual(self.last_media_upload_limit_exceeded["sent_bytes"], 500)
self.assertEqual(self.last_media_upload_limit_exceeded["attempted_bytes"], 800)
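The tests above rely on the limits being sorted: a 4000-byte upload trips the daily 5000-byte cap, while a 20000-byte upload is reported against the weekly 15000-byte cap rather than the daily one, which only works if the longest window is checked first. A minimal, self-contained sketch of that selection logic — the `MediaUploadLimit` dataclass, the `first_exceeded_limit` helper, and the per-window bookkeeping dict are illustrative assumptions, not Synapse's actual implementation:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class MediaUploadLimit:
    # Illustrative stand-in for Synapse's MediaUploadLimit: a byte budget
    # over a rolling time window expressed in milliseconds.
    max_bytes: int
    time_period_ms: int


def first_exceeded_limit(
    limits: list[MediaUploadLimit],
    sent_bytes_in_window: dict[int, int],
    attempted_bytes: int,
) -> Optional[MediaUploadLimit]:
    """Return the first limit the attempted upload would break, or None.

    `sent_bytes_in_window` maps a limit's time_period_ms to the bytes the
    user already uploaded inside that window (hypothetical bookkeeping).
    Limits are checked longest-window first, which reproduces the ordering
    the tests above depend on.
    """
    for limit in sorted(limits, key=lambda l: l.time_period_ms, reverse=True):
        already_sent = sent_bytes_in_window.get(limit.time_period_ms, 0)
        if already_sent + attempted_bytes > limit.max_bytes:
            return limit
    return None


DAY_MS = 24 * 60 * 60 * 1000
WEEK_MS = 7 * DAY_MS
limits = [
    MediaUploadLimit(max_bytes=15000, time_period_ms=WEEK_MS),
    MediaUploadLimit(max_bytes=5000, time_period_ms=DAY_MS),
]

# 3000 bytes already sent today and this week; 4000 more breaks the daily
# cap, while 20000 more is reported against the weekly cap first.
hit_daily = first_exceeded_limit(limits, {DAY_MS: 3000, WEEK_MS: 3000}, 4000)
hit_weekly = first_exceeded_limit(limits, {DAY_MS: 3000, WEEK_MS: 3000}, 20000)
```

With this ordering, an upload that would exceed both windows is attributed to the widest one, matching the 20000-byte assertion in `test_uses_custom_limit`.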


@@ -2245,7 +2245,7 @@ class RoomMessageListTestCase(RoomBase):
         self.room_id = self.helper.create_room_as(self.user_id)

     def test_topo_token_is_accepted(self) -> None:
-        token = "t1-0_0_0_0_0_0_0_0_0_0"
+        token = "t1-0_0_0_0_0_0_0_0_0_0_0"
         channel = self.make_request(
             "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
         )
@@ -2256,7 +2256,7 @@ class RoomMessageListTestCase(RoomBase):
         self.assertTrue("end" in channel.json_body)

     def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None:
-        token = "s0_0_0_0_0_0_0_0_0_0"
+        token = "s0_0_0_0_0_0_0_0_0_0_0"
         channel = self.make_request(
             "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
         )
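Both hunks above change only the hard-coded pagination tokens, adding one more `_0` field: the underscore-separated numeric body grows from 10 to 11 components. Treating the strings purely as text (no claim here about what the new field encodes inside Synapse's token types), the difference can be checked mechanically; the `field_count` helper below is a hypothetical illustration:

```python
def field_count(token: str) -> int:
    # Count the underscore-separated components of the token body,
    # i.e. everything after the "t1-" topological prefix or the
    # leading "s" of a stream token.
    body = token.split("-", 1)[-1].lstrip("s")
    return len(body.split("_"))


old_topo = "t1-0_0_0_0_0_0_0_0_0_0"
new_topo = "t1-0_0_0_0_0_0_0_0_0_0_0"
```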


@@ -189,19 +189,19 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
         self._subscribe(self.other_thread_root_id, automatic_event_orderings=None)

         subscriptions = self.get_success(
-            self.store.get_updated_thread_subscriptions_for_user(
+            self.store.get_latest_updated_thread_subscriptions_for_user(
                 self.user_id,
                 from_id=0,
                 to_id=50,
                 limit=50,
             )
         )
-        min_id = min(id for (id, _, _) in subscriptions)
+        min_id = min(id for (id, _, _, _, _) in subscriptions)
         self.assertEqual(
             subscriptions,
             [
-                (min_id, self.room_id, self.thread_root_id),
-                (min_id + 1, self.room_id, self.other_thread_root_id),
+                (min_id, self.room_id, self.thread_root_id, True, True),
+                (min_id + 1, self.room_id, self.other_thread_root_id, True, False),
             ],
         )
@@ -212,7 +212,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
         # Check user has no subscriptions
         subscriptions = self.get_success(
-            self.store.get_updated_thread_subscriptions_for_user(
+            self.store.get_latest_updated_thread_subscriptions_for_user(
                 self.user_id,
                 from_id=0,
                 to_id=50,
@@ -280,20 +280,22 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
         # Get updates for main user
         updates = self.get_success(
-            self.store.get_updated_thread_subscriptions_for_user(
+            self.store.get_latest_updated_thread_subscriptions_for_user(
                 self.user_id, from_id=0, to_id=stream_id2, limit=10
             )
         )
-        self.assertEqual(updates, [(stream_id1, self.room_id, self.thread_root_id)])
+        self.assertEqual(
+            updates, [(stream_id1, self.room_id, self.thread_root_id, True, True)]
+        )

         # Get updates for other user
         updates = self.get_success(
-            self.store.get_updated_thread_subscriptions_for_user(
+            self.store.get_latest_updated_thread_subscriptions_for_user(
                 other_user_id, from_id=0, to_id=max(stream_id1, stream_id2), limit=10
             )
         )
         self.assertEqual(
-            updates, [(stream_id2, self.room_id, self.other_thread_root_id)]
+            updates, [(stream_id2, self.room_id, self.other_thread_root_id, True, True)]
         )

     def test_should_skip_autosubscription_after_unsubscription(self) -> None:
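The thread-subscription hunks above widen each returned row from a 3-tuple `(stream_id, room_id, thread_root_id)` to a 5-tuple carrying two extra boolean fields (the diff does not name them, so any interpretation is a guess). Every consumer of the rows, such as the `min_id` generator expression, must widen its unpacking to match. A minimal illustration with placeholder data:

```python
# Hypothetical rows in the new 5-tuple shape; the room and thread IDs and
# the two trailing booleans are placeholders, not values from the tests.
rows = [
    (7, "!room:example.org", "$thread1", True, True),
    (8, "!room:example.org", "$thread2", True, False),
]

# The old 3-tuple pattern `(id, _, _)` would raise ValueError on these
# rows; the updated pattern simply ignores the two new trailing fields.
min_id = min(stream_id for (stream_id, _, _, _, _) in rows)
```

Widening the unpacking rather than indexing `row[0]` keeps the test failing loudly if the row shape changes again.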