Compare commits

..

2 Commits

Author SHA1 Message Date
Olivier Wilkinson (reivilibre)
5d7c35b4d9 Tweak changelog 2022-12-06 11:58:15 +00:00
Olivier Wilkinson (reivilibre)
dc6b60f68d 1.73.0 2022-12-06 11:49:37 +00:00
89 changed files with 692 additions and 1774 deletions

View File

@@ -208,7 +208,7 @@ jobs:
steps:
- uses: actions/checkout@v3
- uses: JasonEtco/create-an-issue@77399b6110ef82b94c1c9f9f615acf9e604f7f56 # v2.5.0, 2020-12-06
- uses: JasonEtco/create-an-issue@5d9504915f79f9cc6d791934b8ef34f2353dd74d # v2.5.0, 2020-12-06
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:

View File

@@ -109,29 +109,7 @@ jobs:
components: clippy
- uses: Swatinem/rust-cache@v2
- run: cargo clippy -- -D warnings
# We also lint against a nightly rustc so that we can lint the benchmark
# suite, which requires a nightly compiler.
lint-clippy-nightly:
runs-on: ubuntu-latest
needs: changes
if: ${{ needs.changes.outputs.rust == 'true' }}
steps:
- uses: actions/checkout@v3
- name: Install Rust
# There don't seem to be versioned releases of this action per se: for each rust
# version there is a branch which gets constantly rebased on top of master.
# We pin to a specific commit for paranoia's sake.
uses: dtolnay/rust-toolchain@e645b0cf01249a964ec099494d38d2da0f0b349f
with:
toolchain: nightly-2022-12-01
components: clippy
- uses: Swatinem/rust-cache@v2
- run: cargo clippy --all-features -- -D warnings
- run: cargo clippy
lint-rustfmt:
runs-on: ubuntu-latest

View File

@@ -174,7 +174,7 @@ jobs:
steps:
- uses: actions/checkout@v3
- uses: JasonEtco/create-an-issue@77399b6110ef82b94c1c9f9f615acf9e604f7f56 # v2.5.0, 2020-12-06
- uses: JasonEtco/create-an-issue@5d9504915f79f9cc6d791934b8ef34f2353dd74d # v2.5.0, 2020-12-06
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:

View File

@@ -1,8 +1,14 @@
Synapse 1.73.0rc2 (2022-12-01)
==============================
Synapse 1.73.0 (2022-12-06)
===========================
Please note that legacy Prometheus metric names have been removed in this release; see [the upgrade notes](https://github.com/matrix-org/synapse/blob/release-v1.73/docs/upgrade.md#legacy-prometheus-metric-names-have-now-been-removed) for more details.
No significant changes since 1.73.0rc2.
Synapse 1.73.0rc2 (2022-12-01)
==============================
Bugfixes
--------
@@ -17,7 +23,7 @@ Features
- Speed-up `/messages` with `filter_events_for_client` optimizations. ([\#14527](https://github.com/matrix-org/synapse/issues/14527))
- Improve DB performance by reducing amount of data that gets read in `device_lists_changes_in_room`. ([\#14534](https://github.com/matrix-org/synapse/issues/14534))
- Adds support for handling avatar in SSO login. Contributed by @ashfame. ([\#13917](https://github.com/matrix-org/synapse/issues/13917))
- Adds support for handling avatar in SSO OIDC login. Contributed by @ashfame. ([\#13917](https://github.com/matrix-org/synapse/issues/13917))
- Move MSC3030 `/timestamp_to_event` endpoints to stable `v1` location (`/_matrix/client/v1/rooms/<roomID>/timestamp_to_event?ts=<timestamp>&dir=<direction>`, `/_matrix/federation/v1/timestamp_to_event/<roomID>?ts=<timestamp>&dir=<direction>`). ([\#14471](https://github.com/matrix-org/synapse/issues/14471))
- Reduce database load of [Client-Server endpoints](https://spec.matrix.org/v1.5/client-server-api/#aggregations) which return bundled aggregations. ([\#14491](https://github.com/matrix-org/synapse/issues/14491), [\#14508](https://github.com/matrix-org/synapse/issues/14508), [\#14510](https://github.com/matrix-org/synapse/issues/14510))
- Add unstable support for an Extensible Events room version (`org.matrix.msc1767.10`) via [MSC1767](https://github.com/matrix-org/matrix-spec-proposals/pull/1767), [MSC3931](https://github.com/matrix-org/matrix-spec-proposals/pull/3931), [MSC3932](https://github.com/matrix-org/matrix-spec-proposals/pull/3932), and [MSC3933](https://github.com/matrix-org/matrix-spec-proposals/pull/3933). ([\#14520](https://github.com/matrix-org/synapse/issues/14520), [\#14521](https://github.com/matrix-org/synapse/issues/14521), [\#14524](https://github.com/matrix-org/synapse/issues/14524))

View File

@@ -3,7 +3,3 @@
[workspace]
members = ["rust"]
[profile.dbgrelease]
inherits = "release"
debug = true

View File

@@ -1 +0,0 @@
Optimise push badge count calculations. Contributed by Nick @ Beeper (@fizzadar).

View File

@@ -1 +0,0 @@
Stop using deprecated `keyIds` parameter when calling `/_matrix/key/v2/server`.

View File

@@ -1 +0,0 @@
Update worker settings for `pusher` and `federation_sender` functionality.

View File

@@ -1 +0,0 @@
Add links to third party package repositories, and point to the bug which highlights Ubuntu's out-of-date packages.

View File

@@ -1 +0,0 @@
Stop using deprecated `keyIds` parameter when calling `/_matrix/key/v2/server`.

View File

@@ -1 +0,0 @@
Share the `ClientRestResource` for both workers and the main process.

View File

@@ -1 +0,0 @@
Faster joins: use servers list approximation to send read receipts when in partial state instead of waiting for the full state of the room.

View File

@@ -1 +0,0 @@
Add new `push.enabled` config option to allow opting out of push notification calculation.

View File

@@ -1 +0,0 @@
Modernize unit tests configuration related to workers.

View File

@@ -1 +0,0 @@
Advertise support for Matrix 1.5 on `/_matrix/client/versions`.

View File

@@ -1 +0,0 @@
Bump jsonschema from 4.17.0 to 4.17.3.

View File

@@ -1 +0,0 @@
Fix a long-standing bug where a device list update might not be sent to clients in certain circumstances.

View File

@@ -1 +0,0 @@
Add missing type hints.

View File

@@ -1 +0,0 @@
Fix Rust lint CI.

View File

@@ -1 +0,0 @@
Bump JasonEtco/create-an-issue from 2.5.0 to 2.8.1.

6
debian/changelog vendored
View File

@@ -1,3 +1,9 @@
matrix-synapse-py3 (1.73.0) stable; urgency=medium
* New Synapse release 1.73.0.
-- Synapse Packaging team <packages@matrix.org> Tue, 06 Dec 2022 11:48:56 +0000
matrix-synapse-py3 (1.73.0~rc2) stable; urgency=medium
* New Synapse release 1.73.0rc2.

View File

@@ -84,9 +84,7 @@ file when you upgrade the Debian package to a later version.
##### Downstream Debian packages
Andrej Shadura maintains a
[`matrix-synapse`](https://packages.debian.org/sid/matrix-synapse) package in
the Debian repositories.
Andrej Shadura maintains a `matrix-synapse` package in the Debian repositories.
For `bookworm` and `sid`, it can be installed simply with:
```sh
@@ -102,27 +100,23 @@ for information on how to use backports.
##### Downstream Ubuntu packages
We do not recommend using the packages in the default Ubuntu repository
at this time, as they are [old and suffer from known security vulnerabilities](
https://bugs.launchpad.net/ubuntu/+source/matrix-synapse/+bug/1848709
).
at this time, as they are old and suffer from known security vulnerabilities.
The latest version of Synapse can be installed from [our repository](#matrixorg-packages).
#### Fedora
Synapse is in the Fedora repositories as
[`matrix-synapse`](https://src.fedoraproject.org/rpms/matrix-synapse):
Synapse is in the Fedora repositories as `matrix-synapse`:
```sh
sudo dnf install matrix-synapse
```
Additionally, Oleg Girko provides Fedora RPMs at
Oleg Girko provides Fedora RPMs at
<https://obs.infoserver.lv/project/monitor/matrix-synapse>
#### OpenSUSE
Synapse is in the OpenSUSE repositories as
[`matrix-synapse`](https://software.opensuse.org/package/matrix-synapse):
Synapse is in the OpenSUSE repositories as `matrix-synapse`:
```sh
sudo zypper install matrix-synapse
@@ -157,8 +151,7 @@ sudo pip install py-bcrypt
#### Void Linux
Synapse can be found in the void repositories as
['synapse'](https://github.com/void-linux/void-packages/tree/master/srcpkgs/synapse):
Synapse can be found in the void repositories as 'synapse':
```sh
xbps-install -Su

View File

@@ -858,7 +858,7 @@ which are older than the room's maximum retention period. Synapse will also
filter events received over federation so that events that should have been
purged are ignored and not stored again.
The message retention policies feature is disabled by default. Please be advised
The message retention policies feature is disabled by default. Please be advised
that enabling this feature carries some risk. There are known bugs with the implementation
which can cause database corruption. Setting retention to delete older history
is less risky than deleting newer history but in general caution is advised when enabling this
@@ -3003,7 +3003,7 @@ Options for each entry include:
which is set to the claims returned by the UserInfo Endpoint and/or
in the ID Token.
* `backchannel_logout_enabled`: set to `true` to process OIDC Back-Channel Logout notifications.
* `backchannel_logout_enabled`: set to `true` to process OIDC Back-Channel Logout notifications.
Those notifications are expected to be received on `/_synapse/client/oidc/backchannel_logout`.
Defaults to `false`.
@@ -3355,10 +3355,6 @@ Configuration settings related to push notifications
This setting defines options for push notifications.
This option has a number of sub-options. They are as follows:
* `enable_push`: Enables or disables push notification calculation. Note, disabling this will also
stop unread counts being calculated for rooms. This mode of operation is intended
for homeservers which may only have bots or appservice users connected, or are otherwise
not interested in push/unread counters. This is enabled by default.
* `include_content`: Clients requesting push notifications can either have the body of
the message sent in the notification poke along with other details
like the sender, or just the event ID and room ID (`event_id_only`).
@@ -3379,7 +3375,6 @@ This option has a number of sub-options. They are as follows:
Example configuration:
```yaml
push:
enable_push: true
include_content: false
group_unread_count_by_room: false
```
@@ -3425,7 +3420,7 @@ This option has the following sub-options:
NB. If you set this to true, and the last time the user_directory search
indexes were (re)built was before Synapse 1.44, you'll have to
rebuild the indexes in order to search through all known users.
These indexes are built the first time Synapse starts; admins can
manually trigger a rebuild via the API following the instructions
[for running background updates](../administration/admin_api/background_updates.md#run),
@@ -3684,7 +3679,7 @@ As a result, the worker configuration is divided into two parts.
1. The first part (in this section of the manual) defines which shardable tasks
are delegated to privileged workers. This allows unprivileged workers to make
requests to a privileged worker to act on their behalf.
request a privileged worker to act on their behalf.
1. [The second part](#individual-worker-configuration)
controls the behaviour of individual workers in isolation.
@@ -3696,7 +3691,7 @@ For guidance on setting up workers, see the [worker documentation](../../workers
A shared secret used by the replication APIs on the main process to authenticate
HTTP requests from workers.
The default, this value is omitted (equivalently `null`), which means that
The default, this value is omitted (equivalently `null`), which means that
traffic between the workers and the main process is not authenticated.
Example configuration:
@@ -3706,8 +3701,6 @@ worker_replication_secret: "secret_secret"
---
### `start_pushers`
Unnecessary to set if using [`pusher_instances`](#pusher_instances) with [`generic_workers`](../../workers.md#synapseappgeneric_worker).
Controls sending of push notifications on the main process. Set to `false`
if using a [pusher worker](../../workers.md#synapseapppusher). Defaults to `true`.
@@ -3718,30 +3711,25 @@ start_pushers: false
---
### `pusher_instances`
It is possible to scale the processes that handle sending push notifications to [sygnal](https://github.com/matrix-org/sygnal)
and email by running a [`generic_worker`](../../workers.md#synapseappgeneric_worker) and adding it's [`worker_name`](#worker_name) to
a `pusher_instances` map. Doing so will remove handling of this function from the main
process. Multiple workers can be added to this map, in which case the work is balanced
across them. Ensure the main process and all pusher workers are restarted after changing
this option.
It is possible to run multiple [pusher workers](../../workers.md#synapseapppusher),
in which case the work is balanced across them. Use this setting to list the pushers by
[`worker_name`](#worker_name). Ensure the main process and all pusher workers are
restarted after changing this option.
Example configuration for a single worker:
```yaml
pusher_instances:
- pusher_worker1
```
And for multiple workers:
If no or only one pusher worker is configured, this setting is not necessary.
The main process will send out push notifications by default if you do not disable
it by setting [`start_pushers: false`](#start_pushers).
Example configuration:
```yaml
start_pushers: false
pusher_instances:
- pusher_worker1
- pusher_worker2
```
---
### `send_federation`
Unnecessary to set if using [`federation_sender_instances`](#federation_sender_instances) with [`generic_workers`](../../workers.md#synapseappgeneric_worker).
Controls sending of outbound federation transactions on the main process.
Set to `false` if using a [federation sender worker](../../workers.md#synapseappfederation_sender).
Defaults to `true`.
@@ -3753,36 +3741,29 @@ send_federation: false
---
### `federation_sender_instances`
It is possible to scale the processes that handle sending outbound federation requests
by running a [`generic_worker`](../../workers.md#synapseappgeneric_worker) and adding it's [`worker_name`](#worker_name) to
a `federation_sender_instances` map. Doing so will remove handling of this function from
the main process. Multiple workers can be added to this map, in which case the work is
balanced across them.
It is possible to run multiple
[federation sender worker](../../workers.md#synapseappfederation_sender), in which
case the work is balanced across them. Use this setting to list the senders.
This configuration setting must be shared between all workers handling federation
sending, and if changed all federation sender workers must be stopped at the same time
and then started, to ensure that all instances are running with the same config (otherwise
This configuration setting must be shared between all federation sender workers, and if
changed all federation sender workers must be stopped at the same time and then
started, to ensure that all instances are running with the same config (otherwise
events may be dropped).
Example configuration for a single worker:
Example configuration:
```yaml
send_federation: false
federation_sender_instances:
- federation_sender1
```
And for multiple workers:
```yaml
federation_sender_instances:
- federation_sender1
- federation_sender2
```
---
### `instance_map`
When using workers this should be a map from [`worker_name`](#worker_name) to the
HTTP replication listener of the worker, if configured.
Each worker declared under [`stream_writers`](../../workers.md#stream-writers) needs
Each worker declared under [`stream_writers`](../../workers.md#stream-writers) needs
a HTTP replication listener, and that listener should be included in the `instance_map`.
(The main process also needs an HTTP replication listener, but it should not be
(The main process also needs an HTTP replication listener, but it should not be
listed in the `instance_map`.)
Example configuration:
@@ -3916,8 +3897,8 @@ worker_replication_http_tls: true
---
### `worker_listeners`
A worker can handle HTTP requests. To do so, a `worker_listeners` option
must be declared, in the same way as the [`listeners` option](#listeners)
A worker can handle HTTP requests. To do so, a `worker_listeners` option
must be declared, in the same way as the [`listeners` option](#listeners)
in the shared config.
Workers declared in [`stream_writers`](#stream_writers) will need to include a
@@ -3936,7 +3917,7 @@ worker_listeners:
### `worker_daemonize`
Specifies whether the worker should be started as a daemon process.
If Synapse is being managed by [systemd](../../systemd-with-workers/README.md), this option
If Synapse is being managed by [systemd](../../systemd-with-workers/README.md), this option
must be omitted or set to `false`.
Defaults to `false`.
@@ -3948,11 +3929,11 @@ worker_daemonize: true
---
### `worker_pid_file`
When running a worker as a daemon, we need a place to store the
When running a worker as a daemon, we need a place to store the
[PID](https://en.wikipedia.org/wiki/Process_identifier) of the worker.
This option defines the location of that "pid file".
This option is required if `worker_daemonize` is `true` and ignored
This option is required if `worker_daemonize` is `true` and ignored
otherwise. It has no default.
See also the [`pid_file` option](#pid_file) option for the main Synapse process.
@@ -4002,3 +3983,4 @@ background_updates:
min_batch_size: 10
default_batch_size: 50
```

View File

@@ -505,9 +505,6 @@ worker application type.
### `synapse.app.pusher`
It is likely this option will be deprecated in the future and is not recommended for new
installations. Instead, [use `synapse.app.generic_worker` with the `pusher_instances`](usage/configuration/config_documentation.md#pusher_instances).
Handles sending push notifications to sygnal and email. Doesn't handle any
REST endpoints itself, but you should set
[`start_pushers: false`](usage/configuration/config_documentation.md#start_pushers) in the
@@ -546,9 +543,6 @@ Note this worker cannot be load-balanced: only one instance should be active.
### `synapse.app.federation_sender`
It is likely this option will be deprecated in the future and not recommended for
new installations. Instead, [use `synapse.app.generic_worker` with the `federation_sender_instances`](usage/configuration/config_documentation.md#federation_sender_instances).
Handles sending federation traffic to other servers. Doesn't handle any
REST endpoints itself, but you should set
[`send_federation: false`](usage/configuration/config_documentation.md#send_federation)
@@ -645,9 +639,7 @@ equivalent to `synapse.app.generic_worker`:
* `synapse.app.client_reader`
* `synapse.app.event_creator`
* `synapse.app.federation_reader`
* `synapse.app.federation_sender`
* `synapse.app.frontend_proxy`
* `synapse.app.pusher`
* `synapse.app.synchrotron`

View File

@@ -59,6 +59,16 @@ exclude = (?x)
|tests/server_notices/test_resource_limits_server_notices.py
|tests/test_state.py
|tests/test_terms_auth.py
|tests/util/test_async_helpers.py
|tests/util/test_batching_queue.py
|tests/util/test_dict_cache.py
|tests/util/test_expiring_cache.py
|tests/util/test_file_consumer.py
|tests/util/test_linearizer.py
|tests/util/test_logcontext.py
|tests/util/test_lrucache.py
|tests/util/test_rwlock.py
|tests/util/test_wheel_timer.py
)$
[mypy-synapse.federation.transport.client]
@@ -127,9 +137,6 @@ disallow_untyped_defs = True
[mypy-tests.util.caches.test_descriptors]
disallow_untyped_defs = False
[mypy-tests.util.*]
disallow_untyped_defs = True
[mypy-tests.utils]
disallow_untyped_defs = True

34
poetry.lock generated
View File

@@ -452,7 +452,7 @@ i18n = ["Babel (>=2.7)"]
[[package]]
name = "jsonschema"
version = "4.17.3"
version = "4.17.0"
description = "An implementation of JSON Schema validation for Python"
category = "main"
optional = false
@@ -888,17 +888,17 @@ tests = ["hypothesis (>=3.27.0)", "pytest (>=3.2.1,!=3.3.0)"]
[[package]]
name = "pyopenssl"
version = "22.1.0"
version = "22.0.0"
description = "Python wrapper module around the OpenSSL library"
category = "main"
optional = false
python-versions = ">=3.6"
[package.dependencies]
cryptography = ">=38.0.0,<39"
cryptography = ">=35.0"
[package.extras]
docs = ["sphinx (!=5.2.0,!=5.2.0.post0)", "sphinx-rtd-theme"]
docs = ["sphinx", "sphinx-rtd-theme"]
test = ["flaky", "pretend", "pytest (>=3.0.1)"]
[[package]]
@@ -1076,7 +1076,7 @@ doc = ["Sphinx", "sphinx-rtd-theme"]
[[package]]
name = "sentry-sdk"
version = "1.11.1"
version = "1.11.0"
description = "Python client for Sentry (https://sentry.io)"
category = "main"
optional = true
@@ -1380,7 +1380,7 @@ python-versions = ">=3.6"
[[package]]
name = "types-bleach"
version = "5.0.3.1"
version = "5.0.3"
description = "Typing stubs for bleach"
category = "dev"
optional = false
@@ -1448,7 +1448,7 @@ python-versions = "*"
[[package]]
name = "types-psycopg2"
version = "2.9.21.2"
version = "2.9.21.1"
description = "Typing stubs for psycopg2"
category = "dev"
optional = false
@@ -2013,8 +2013,8 @@ jinja2 = [
{file = "Jinja2-3.1.2.tar.gz", hash = "sha256:31351a702a408a9e7595a8fc6150fc3f43bb6bf7e319770cbc0db9df9437e852"},
]
jsonschema = [
{file = "jsonschema-4.17.3-py3-none-any.whl", hash = "sha256:a870ad254da1a8ca84b6a2905cac29d265f805acc57af304784962a2aa6508f6"},
{file = "jsonschema-4.17.3.tar.gz", hash = "sha256:0f864437ab8b6076ba6707453ef8f98a6a0d512a80e93f8abdb676f737ecb60d"},
{file = "jsonschema-4.17.0-py3-none-any.whl", hash = "sha256:f660066c3966db7d6daeaea8a75e0b68237a48e51cf49882087757bb59916248"},
{file = "jsonschema-4.17.0.tar.gz", hash = "sha256:5bfcf2bca16a087ade17e02b282d34af7ccd749ef76241e7f9bd7c0cb8a9424d"},
]
keyring = [
{file = "keyring-23.5.0-py3-none-any.whl", hash = "sha256:b0d28928ac3ec8e42ef4cc227822647a19f1d544f21f96457965dc01cf555261"},
@@ -2452,8 +2452,8 @@ pynacl = [
{file = "PyNaCl-1.5.0.tar.gz", hash = "sha256:8ac7448f09ab85811607bdd21ec2464495ac8b7c66d146bf545b0f08fb9220ba"},
]
pyopenssl = [
{file = "pyOpenSSL-22.1.0-py3-none-any.whl", hash = "sha256:b28437c9773bb6c6958628cf9c3bebe585de661dba6f63df17111966363dd15e"},
{file = "pyOpenSSL-22.1.0.tar.gz", hash = "sha256:7a83b7b272dd595222d672f5ce29aa030f1fb837630ef229f62e72e395ce8968"},
{file = "pyOpenSSL-22.0.0-py2.py3-none-any.whl", hash = "sha256:ea252b38c87425b64116f808355e8da644ef9b07e429398bfece610f893ee2e0"},
{file = "pyOpenSSL-22.0.0.tar.gz", hash = "sha256:660b1b1425aac4a1bea1d94168a85d99f0b3144c869dd4390d27629d0087f1bf"},
]
pyparsing = [
{file = "pyparsing-3.0.7-py3-none-any.whl", hash = "sha256:a6c06a88f252e6c322f65faf8f418b16213b51bdfaece0524c1c1bc30c63c484"},
@@ -2569,8 +2569,8 @@ semantic-version = [
{file = "semantic_version-2.10.0.tar.gz", hash = "sha256:bdabb6d336998cbb378d4b9db3a4b56a1e3235701dc05ea2690d9a997ed5041c"},
]
sentry-sdk = [
{file = "sentry-sdk-1.11.1.tar.gz", hash = "sha256:675f6279b6bb1fea09fd61751061f9a90dca3b5929ef631dd50dc8b3aeb245e9"},
{file = "sentry_sdk-1.11.1-py2.py3-none-any.whl", hash = "sha256:8b4ff696c0bdcceb3f70bbb87a57ba84fd3168b1332d493fcd16c137f709578c"},
{file = "sentry-sdk-1.11.0.tar.gz", hash = "sha256:e7b78a1ddf97a5f715a50ab8c3f7a93f78b114c67307785ee828ef67a5d6f117"},
{file = "sentry_sdk-1.11.0-py2.py3-none-any.whl", hash = "sha256:f467e6c7fac23d4d42bc83eb049c400f756cd2d65ab44f0cc1165d0c7c3d40bc"},
]
service-identity = [
{file = "service-identity-21.1.0.tar.gz", hash = "sha256:6e6c6086ca271dc11b033d17c3a8bea9f24ebff920c587da090afc9519419d34"},
@@ -2781,8 +2781,8 @@ typed-ast = [
{file = "typed_ast-1.5.2.tar.gz", hash = "sha256:525a2d4088e70a9f75b08b3f87a51acc9cde640e19cc523c7e41aa355564ae27"},
]
types-bleach = [
{file = "types-bleach-5.0.3.1.tar.gz", hash = "sha256:ce8772ea5126dab1883851b41e3aeff229aa5213ced36096990344e632e92373"},
{file = "types_bleach-5.0.3.1-py3-none-any.whl", hash = "sha256:af5f1b3a54ff279f54c29eccb2e6988ebb6718bc4061469588a5fd4880a79287"},
{file = "types-bleach-5.0.3.tar.gz", hash = "sha256:f7b3df8278efe176d9670d0f063a66c866c77577f71f54b9c7a320e31b1a7bbd"},
{file = "types_bleach-5.0.3-py3-none-any.whl", hash = "sha256:5931525d03571f36b2bb40210c34b662c4d26c8fd6f2b1e1e83fe4d2d2fd63c7"},
]
types-commonmark = [
{file = "types-commonmark-0.9.2.tar.gz", hash = "sha256:b894b67750c52fd5abc9a40a9ceb9da4652a391d75c1b480bba9cef90f19fc86"},
@@ -2813,8 +2813,8 @@ types-pillow = [
{file = "types_Pillow-9.3.0.1-py3-none-any.whl", hash = "sha256:79837755fe9659f29efd1016e9903ac4a500e0c73260483f07296bd6ca47668b"},
]
types-psycopg2 = [
{file = "types-psycopg2-2.9.21.2.tar.gz", hash = "sha256:bff045579642ce00b4a3c8f2e401b7f96dfaa34939f10be64b0dd3b53feca57d"},
{file = "types_psycopg2-2.9.21.2-py3-none-any.whl", hash = "sha256:084558d6bc4b2cfa249b06be0fdd9a14a69d307bae5bb5809a2f14cfbaa7a23f"},
{file = "types-psycopg2-2.9.21.1.tar.gz", hash = "sha256:f5532cf15afdc6b5ebb1e59b7d896617217321f488fd1fbd74e7efb94decfab6"},
{file = "types_psycopg2-2.9.21.1-py3-none-any.whl", hash = "sha256:858838f1972f39da2a6e28274201fed8619a40a235dd86e7f66f4548ec474395"},
]
types-pyopenssl = [
{file = "types-pyOpenSSL-22.1.0.2.tar.gz", hash = "sha256:7a350e29e55bc3ee4571f996b4b1c18c4e4098947db45f7485b016eaa35b44bc"},

View File

@@ -57,7 +57,7 @@ manifest-path = "rust/Cargo.toml"
[tool.poetry]
name = "matrix-synapse"
version = "1.73.0rc2"
version = "1.73.0"
description = "Homeserver for the Matrix decentralised comms protocol"
authors = ["Matrix.org Team and Contributors <packages@matrix.org>"]
license = "Apache-2.0"

View File

@@ -33,12 +33,10 @@ fn bench_match_exact(b: &mut Bencher) {
let eval = PushRuleEvaluator::py_new(
flattened_keys,
10,
Some(0),
0,
Default::default(),
Default::default(),
true,
vec![],
false,
)
.unwrap();
@@ -69,12 +67,10 @@ fn bench_match_word(b: &mut Bencher) {
let eval = PushRuleEvaluator::py_new(
flattened_keys,
10,
Some(0),
0,
Default::default(),
Default::default(),
true,
vec![],
false,
)
.unwrap();
@@ -105,12 +101,10 @@ fn bench_match_word_miss(b: &mut Bencher) {
let eval = PushRuleEvaluator::py_new(
flattened_keys,
10,
Some(0),
0,
Default::default(),
Default::default(),
true,
vec![],
false,
)
.unwrap();
@@ -141,12 +135,10 @@ fn bench_eval_message(b: &mut Bencher) {
let eval = PushRuleEvaluator::py_new(
flattened_keys,
10,
Some(0),
0,
Default::default(),
Default::default(),
true,
vec![],
false,
)
.unwrap();

View File

@@ -1,77 +0,0 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#![feature(test)]
use synapse::tree_cache::TreeCache;
use test::Bencher;
extern crate test;
#[bench]
fn bench_tree_cache_get_non_empty(b: &mut Bencher) {
let mut cache: TreeCache<&str, &str> = TreeCache::new();
cache.set(["a", "b", "c", "d"], "f").unwrap();
b.iter(|| cache.get(&["a", "b", "c", "d"]));
}
#[bench]
fn bench_tree_cache_get_empty(b: &mut Bencher) {
let cache: TreeCache<&str, &str> = TreeCache::new();
b.iter(|| cache.get(&["a", "b", "c", "d"]));
}
#[bench]
fn bench_tree_cache_set(b: &mut Bencher) {
let mut cache: TreeCache<&str, &str> = TreeCache::new();
b.iter(|| cache.set(["a", "b", "c", "d"], "f").unwrap());
}
#[bench]
fn bench_tree_cache_length(b: &mut Bencher) {
let mut cache: TreeCache<u32, u32> = TreeCache::new();
for c1 in 0..=10 {
for c2 in 0..=10 {
for c3 in 0..=10 {
for c4 in 0..=10 {
cache.set([c1, c2, c3, c4], 1).unwrap()
}
}
}
}
b.iter(|| cache.len());
}
#[bench]
fn tree_cache_iterate(b: &mut Bencher) {
let mut cache: TreeCache<u32, u32> = TreeCache::new();
for c1 in 0..=10 {
for c2 in 0..=10 {
for c3 in 0..=10 {
for c4 in 0..=10 {
cache.set([c1, c2, c3, c4], 1).unwrap()
}
}
}
}
b.iter(|| cache.items().count());
}

View File

@@ -1,7 +1,6 @@
use pyo3::prelude::*;
pub mod push;
pub mod tree_cache;
/// Returns the hash of all the rust source files at the time it was compiled.
///
@@ -27,7 +26,6 @@ fn synapse_rust(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(get_rust_file_digest, m)?)?;
push::register_module(py, m)?;
tree_cache::binding::register_module(py, m)?;
Ok(())
}

View File

@@ -12,8 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::borrow::Cow;
use std::collections::BTreeMap;
use crate::push::{PushRule, PushRules};
use anyhow::{Context, Error};
use lazy_static::lazy_static;
use log::warn;
@@ -96,7 +98,6 @@ pub struct PushRuleEvaluator {
#[pymethods]
impl PushRuleEvaluator {
/// Create a new `PushRuleEvaluator`. See struct docstring for details.
#[allow(clippy::too_many_arguments)]
#[new]
pub fn py_new(
flattened_keys: BTreeMap<String, String>,
@@ -152,12 +153,15 @@ impl PushRuleEvaluator {
let mut has_rver_condition = false;
for condition in push_rule.conditions.iter() {
has_rver_condition |= matches!(
condition,
// per MSC3932, we just need *any* room version condition to match
Condition::Known(KnownCondition::RoomVersionSupports { feature: _ }),
);
has_rver_condition = has_rver_condition
|| match condition {
Condition::Known(known) => match known {
// per MSC3932, we just need *any* room version condition to match
KnownCondition::RoomVersionSupports { feature: _ } => true,
_ => false,
},
_ => false,
};
match self.match_condition(condition, user_id, display_name) {
Ok(true) => {}
Ok(false) => continue 'outer,
@@ -440,10 +444,6 @@ fn push_rule_evaluator() {
#[test]
fn test_requires_room_version_supports_condition() {
use std::borrow::Cow;
use crate::push::{PushRule, PushRules};
let mut flattened_keys = BTreeMap::new();
flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string());
let flags = vec![RoomVersionFeatures::ExtensibleEvents.as_str().to_string()];

View File

@@ -1,247 +0,0 @@
use std::hash::Hash;
use anyhow::Error;
use pyo3::{
pyclass, pymethods,
types::{PyModule, PyTuple},
IntoPy, PyAny, PyObject, PyResult, Python, ToPyObject,
};
use super::TreeCache;
pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
let child_module = PyModule::new(py, "tree_cache")?;
child_module.add_class::<PythonTreeCache>()?;
child_module.add_class::<StringTreeCache>()?;
m.add_submodule(child_module)?;
// We need to manually add the module to sys.modules to make `from
// synapse.synapse_rust import push` work.
py.import("sys")?
.getattr("modules")?
.set_item("synapse.synapse_rust.tree_cache", child_module)?;
Ok(())
}
#[derive(Clone)]
struct HashablePyObject {
obj: PyObject,
hash: isize,
}
impl HashablePyObject {
pub fn new(obj: &PyAny) -> Result<Self, Error> {
let hash = obj.hash()?;
Ok(HashablePyObject {
obj: obj.to_object(obj.py()),
hash,
})
}
}
impl IntoPy<PyObject> for HashablePyObject {
fn into_py(self, _: Python<'_>) -> PyObject {
self.obj.clone()
}
}
impl IntoPy<PyObject> for &HashablePyObject {
fn into_py(self, _: Python<'_>) -> PyObject {
self.obj.clone()
}
}
impl ToPyObject for HashablePyObject {
fn to_object(&self, _py: Python<'_>) -> PyObject {
self.obj.clone()
}
}
impl Hash for HashablePyObject {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.hash.hash(state);
}
}
impl PartialEq for HashablePyObject {
fn eq(&self, other: &Self) -> bool {
let equal = Python::with_gil(|py| {
let result = self.obj.as_ref(py).eq(other.obj.as_ref(py));
result.unwrap_or(false)
});
equal
}
}
impl Eq for HashablePyObject {}
#[pyclass]
struct PythonTreeCache(TreeCache<HashablePyObject, PyObject>);
#[pymethods]
impl PythonTreeCache {
#[new]
fn new() -> Self {
PythonTreeCache(Default::default())
}
pub fn set(&mut self, key: &PyAny, value: PyObject) -> Result<(), Error> {
let v: Vec<HashablePyObject> = key
.iter()?
.map(|obj| HashablePyObject::new(obj?))
.collect::<Result<_, _>>()?;
self.0.set(v, value)?;
Ok(())
}
pub fn get_node<'a>(
&'a self,
py: Python<'a>,
key: &'a PyAny,
) -> Result<Option<Vec<(&'a PyTuple, &'a PyObject)>>, Error> {
let v: Vec<HashablePyObject> = key
.iter()?
.map(|obj| HashablePyObject::new(obj?))
.collect::<Result<_, _>>()?;
let Some(node) = self.0.get_node(v.clone())? else {
return Ok(None)
};
let items = node
.items()
.map(|(k, value)| {
let vec = v.iter().chain(k.iter().map(|a| *a)).collect::<Vec<_>>();
let nk = PyTuple::new(py, vec);
(nk, value)
})
.collect::<Vec<_>>();
Ok(Some(items))
}
pub fn get(&self, key: &PyAny) -> Result<Option<&PyObject>, Error> {
let v: Vec<HashablePyObject> = key
.iter()?
.map(|obj| HashablePyObject::new(obj?))
.collect::<Result<_, _>>()?;
Ok(self.0.get(&v)?)
}
pub fn pop_node<'a>(
&'a mut self,
py: Python<'a>,
key: &'a PyAny,
) -> Result<Option<Vec<(&'a PyTuple, PyObject)>>, Error> {
let v: Vec<HashablePyObject> = key
.iter()?
.map(|obj| HashablePyObject::new(obj?))
.collect::<Result<_, _>>()?;
let Some(node) = self.0.pop_node(v.clone())? else {
return Ok(None)
};
let items = node
.into_items()
.map(|(k, value)| {
let vec = v.iter().chain(k.iter()).collect::<Vec<_>>();
let nk = PyTuple::new(py, vec);
(nk, value)
})
.collect::<Vec<_>>();
Ok(Some(items))
}
pub fn pop(&mut self, key: &PyAny) -> Result<Option<PyObject>, Error> {
let v: Vec<HashablePyObject> = key
.iter()?
.map(|obj| HashablePyObject::new(obj?))
.collect::<Result<_, _>>()?;
Ok(self.0.pop(&v)?)
}
pub fn clear(&mut self) {
self.0.clear()
}
pub fn len(&self) -> usize {
self.0.len()
}
pub fn values(&self) -> Vec<&PyObject> {
self.0.values().collect()
}
pub fn items(&self) -> Vec<(Vec<&HashablePyObject>, &PyObject)> {
todo!()
}
}
#[pyclass]
struct StringTreeCache(TreeCache<String, String>);
#[pymethods]
impl StringTreeCache {
#[new]
fn new() -> Self {
StringTreeCache(Default::default())
}
pub fn set(&mut self, key: &PyAny, value: String) -> Result<(), Error> {
let key = key
.iter()?
.map(|o| o.expect("iter failed").extract().expect("not a string"));
self.0.set(key, value)?;
Ok(())
}
// pub fn get_node(&self, key: &PyAny) -> Result<Option<&TreeCacheNode<K, PyObject>>, Error> {
// todo!()
// }
pub fn get(&self, key: &PyAny) -> Result<Option<&String>, Error> {
let key = key.iter()?.map(|o| {
o.expect("iter failed")
.extract::<String>()
.expect("not a string")
});
Ok(self.0.get(key)?)
}
// pub fn pop_node(&mut self, key: &PyAny) -> Result<Option<TreeCacheNode<K, PyObject>>, Error> {
// todo!()
// }
pub fn pop(&mut self, key: Vec<String>) -> Result<Option<String>, Error> {
Ok(self.0.pop(&key)?)
}
pub fn clear(&mut self) {
self.0.clear()
}
pub fn len(&self) -> usize {
self.0.len()
}
pub fn values(&self) -> Vec<&String> {
self.0.values().collect()
}
pub fn items(&self) -> Vec<(Vec<&HashablePyObject>, &PyObject)> {
todo!()
}
}

View File

@@ -1,421 +0,0 @@
use std::{borrow::Borrow, collections::HashMap, hash::Hash};
use anyhow::{bail, Error};
pub mod binding;
pub enum TreeCacheNode<K, V> {
Leaf(V),
Branch(usize, HashMap<K, TreeCacheNode<K, V>>),
}
impl<K, V> TreeCacheNode<K, V> {
pub fn new_branch() -> Self {
TreeCacheNode::Branch(0, Default::default())
}
fn len(&self) -> usize {
match self {
TreeCacheNode::Leaf(_) => 1,
TreeCacheNode::Branch(size, _) => *size,
}
}
}
impl<'a, K: Eq + Hash + 'a, V> TreeCacheNode<K, V> {
pub fn set(
&mut self,
mut key: impl Iterator<Item = K>,
value: V,
) -> Result<(usize, usize), Error> {
if let Some(k) = key.next() {
match self {
TreeCacheNode::Leaf(_) => bail!("Given key is too long"),
TreeCacheNode::Branch(size, map) => {
let node = map.entry(k).or_insert_with(TreeCacheNode::new_branch);
let (added, removed) = node.set(key, value)?;
*size += added;
*size -= removed;
Ok((added, removed))
}
}
} else {
let added = if let TreeCacheNode::Branch(_, map) = self {
(1, map.len())
} else {
(0, 0)
};
*self = TreeCacheNode::Leaf(value);
Ok(added)
}
}
pub fn pop<Q>(
&mut self,
current_key: Q,
mut next_keys: impl Iterator<Item = Q>,
) -> Result<Option<TreeCacheNode<K, V>>, Error>
where
Q: Borrow<K>,
Q: Hash + Eq + 'a,
{
if let Some(next_key) = next_keys.next() {
match self {
TreeCacheNode::Leaf(_) => bail!("Given key is too long"),
TreeCacheNode::Branch(size, map) => {
let node = if let Some(node) = map.get_mut(current_key.borrow()) {
node
} else {
return Ok(None);
};
if let Some(popped) = node.pop(next_key, next_keys)? {
*size -= node.len();
Ok(Some(popped))
} else {
Ok(None)
}
}
}
} else {
match self {
TreeCacheNode::Leaf(_) => bail!("Given key is too long"),
TreeCacheNode::Branch(size, map) => {
if let Some(node) = map.remove(current_key.borrow()) {
*size -= node.len();
Ok(Some(node))
} else {
Ok(None)
}
}
}
}
}
pub fn items(&'a self) -> impl Iterator<Item = (Vec<&K>, &V)> {
// To avoid a lot of mallocs we guess the length of the key. Ideally
// we'd know this.
let capacity_guesstimate = 10;
let mut stack = vec![(Vec::with_capacity(capacity_guesstimate), self)];
std::iter::from_fn(move || {
while let Some((prefix, node)) = stack.pop() {
match node {
TreeCacheNode::Leaf(value) => return Some((prefix, value)),
TreeCacheNode::Branch(_, map) => {
stack.extend(map.iter().map(|(k, v)| {
let mut new_prefix = Vec::with_capacity(capacity_guesstimate);
new_prefix.extend_from_slice(&prefix);
new_prefix.push(k);
(new_prefix, v)
}));
}
}
}
None
})
}
pub fn values(&'a self) -> impl Iterator<Item = &V> {
let mut stack = vec![self];
std::iter::from_fn(move || {
while let Some(node) = stack.pop() {
match node {
TreeCacheNode::Leaf(value) => return Some(value),
TreeCacheNode::Branch(_, map) => {
stack.extend(map.iter().map(|(_k, v)| v));
}
}
}
None
})
}
}
impl<'a, K: Clone + Eq + Hash + 'a, V> TreeCacheNode<K, V> {
pub fn into_items(self) -> impl Iterator<Item = (Vec<K>, V)> {
let mut stack = vec![(Vec::new(), self)];
std::iter::from_fn(move || {
while let Some((prefix, node)) = stack.pop() {
match node {
TreeCacheNode::Leaf(value) => return Some((prefix, value)),
TreeCacheNode::Branch(_, map) => {
stack.extend(map.into_iter().map(|(k, v)| {
let mut prefix = prefix.clone();
prefix.push(k);
(prefix, v)
}));
}
}
}
None
})
}
}
impl<K, V> Default for TreeCacheNode<K, V> {
fn default() -> Self {
TreeCacheNode::new_branch()
}
}
pub struct TreeCache<K, V> {
root: TreeCacheNode<K, V>,
}
impl<K, V> TreeCache<K, V> {
pub fn new() -> Self {
TreeCache {
root: TreeCacheNode::new_branch(),
}
}
}
impl<'a, K: Eq + Hash + 'a, V> TreeCache<K, V> {
pub fn set(&mut self, key: impl IntoIterator<Item = K>, value: V) -> Result<(), Error> {
self.root.set(key.into_iter(), value)?;
Ok(())
}
pub fn get_node<Q>(
&self,
key: impl IntoIterator<Item = Q>,
) -> Result<Option<&TreeCacheNode<K, V>>, Error>
where
Q: Borrow<K>,
Q: Hash + Eq + 'a,
{
let mut node = &self.root;
for k in key {
match node {
TreeCacheNode::Leaf(_) => bail!("Given key is too long"),
TreeCacheNode::Branch(_, map) => {
node = if let Some(node) = map.get(k.borrow()) {
node
} else {
return Ok(None);
};
}
}
}
Ok(Some(node))
}
pub fn get<Q>(&self, key: impl IntoIterator<Item = Q>) -> Result<Option<&V>, Error>
where
Q: Borrow<K>,
Q: Hash + Eq + 'a,
{
if let Some(node) = self.get_node(key)? {
match node {
TreeCacheNode::Leaf(value) => Ok(Some(value)),
TreeCacheNode::Branch(_, _) => bail!("Given key is too short"),
}
} else {
Ok(None)
}
}
pub fn pop_node<Q>(
&mut self,
key: impl IntoIterator<Item = Q>,
) -> Result<Option<TreeCacheNode<K, V>>, Error>
where
Q: Borrow<K>,
Q: Hash + Eq + 'a,
{
let mut key_iter = key.into_iter();
let k = if let Some(k) = key_iter.next() {
k
} else {
let node = std::mem::replace(&mut self.root, TreeCacheNode::new_branch());
return Ok(Some(node));
};
self.root.pop(k, key_iter)
}
pub fn pop(&mut self, key: &[K]) -> Result<Option<V>, Error> {
if let Some(node) = self.pop_node(key)? {
match node {
TreeCacheNode::Leaf(value) => Ok(Some(value)),
TreeCacheNode::Branch(_, _) => bail!("Given key is too short"),
}
} else {
Ok(None)
}
}
pub fn clear(&mut self) {
self.root = TreeCacheNode::new_branch();
}
pub fn len(&self) -> usize {
match self.root {
TreeCacheNode::Leaf(_) => 1,
TreeCacheNode::Branch(size, _) => size,
}
}
pub fn values(&self) -> impl Iterator<Item = &V> {
let mut stack = vec![&self.root];
std::iter::from_fn(move || {
while let Some(node) = stack.pop() {
match node {
TreeCacheNode::Leaf(value) => return Some(value),
TreeCacheNode::Branch(_, map) => {
stack.extend(map.values());
}
}
}
None
})
}
pub fn items(&self) -> impl Iterator<Item = (Vec<&K>, &V)> {
self.root.items()
}
}
impl<K, V> Default for TreeCache<K, V> {
fn default() -> Self {
TreeCache::new()
}
}
#[cfg(test)]
mod test {
use std::collections::BTreeSet;
use super::*;
#[test]
fn get_set() -> Result<(), Error> {
let mut cache = TreeCache::new();
cache.set(vec!["a", "b"], "c")?;
assert_eq!(cache.get(&["a", "b"])?, Some(&"c"));
let node = cache.get_node(&["a"])?.unwrap();
match node {
TreeCacheNode::Leaf(_) => bail!("expected branch"),
TreeCacheNode::Branch(_, map) => {
assert_eq!(map.len(), 1);
assert!(map.contains_key("b"));
}
}
Ok(())
}
#[test]
fn length() -> Result<(), Error> {
let mut cache = TreeCache::new();
cache.set(vec!["a", "b"], "c")?;
assert_eq!(cache.len(), 1);
cache.set(vec!["a", "b"], "d")?;
assert_eq!(cache.len(), 1);
cache.set(vec!["e", "f"], "g")?;
assert_eq!(cache.len(), 2);
cache.set(vec!["e", "h"], "i")?;
assert_eq!(cache.len(), 3);
cache.set(vec!["e"], "i")?;
assert_eq!(cache.len(), 2);
cache.pop_node(&["a"])?;
assert_eq!(cache.len(), 1);
Ok(())
}
#[test]
fn clear() -> Result<(), Error> {
let mut cache = TreeCache::new();
cache.set(vec!["a", "b"], "c")?;
assert_eq!(cache.len(), 1);
cache.clear();
assert_eq!(cache.len(), 0);
assert_eq!(cache.get(&["a", "b"])?, None);
Ok(())
}
#[test]
fn pop() -> Result<(), Error> {
let mut cache = TreeCache::new();
cache.set(vec!["a", "b"], "c")?;
assert_eq!(cache.pop(&["a", "b"])?, Some("c"));
assert_eq!(cache.pop(&["a", "b"])?, None);
Ok(())
}
#[test]
fn values() -> Result<(), Error> {
let mut cache = TreeCache::new();
cache.set(vec!["a", "b"], "c")?;
let expected = ["c"].iter().collect();
assert_eq!(cache.values().collect::<BTreeSet<_>>(), expected);
cache.set(vec!["d", "e"], "f")?;
let expected = ["c", "f"].iter().collect();
assert_eq!(cache.values().collect::<BTreeSet<_>>(), expected);
Ok(())
}
#[test]
fn items() -> Result<(), Error> {
let mut cache = TreeCache::new();
cache.set(vec!["a", "b"], "c")?;
cache.set(vec!["d", "e"], "f")?;
let expected = [(vec![&"a", &"b"], &"c"), (vec![&"d", &"e"], &"f")]
.into_iter()
.collect();
assert_eq!(cache.items().collect::<BTreeSet<_>>(), expected);
Ok(())
}
}

View File

@@ -44,8 +44,40 @@ from synapse.http.server import JsonResource, OptionsResource
from synapse.logging.context import LoggingContext
from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
from synapse.rest import ClientRestResource
from synapse.rest.admin import register_servlets_for_media_repo
from synapse.rest.client import (
account_data,
events,
initial_sync,
login,
presence,
profile,
push_rule,
read_marker,
receipts,
relations,
room,
room_batch,
room_keys,
sendtodevice,
sync,
tags,
user_directory,
versions,
voip,
)
from synapse.rest.client.account import ThreepidRestServlet, WhoamiRestServlet
from synapse.rest.client.devices import DevicesRestServlet
from synapse.rest.client.keys import (
KeyChangesServlet,
KeyQueryServlet,
KeyUploadServlet,
OneTimeKeyServlet,
)
from synapse.rest.client.register import (
RegisterRestServlet,
RegistrationTokenValidityRestServlet,
)
from synapse.rest.health import HealthResource
from synapse.rest.key.v2 import KeyResource
from synapse.rest.synapse.client import build_synapse_client_resource_tree
@@ -168,7 +200,45 @@ class GenericWorkerServer(HomeServer):
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
elif name == "client":
resource: Resource = ClientRestResource(self)
resource = JsonResource(self, canonical_json=False)
RegisterRestServlet(self).register(resource)
RegistrationTokenValidityRestServlet(self).register(resource)
login.register_servlets(self, resource)
ThreepidRestServlet(self).register(resource)
WhoamiRestServlet(self).register(resource)
DevicesRestServlet(self).register(resource)
# Read-only
KeyUploadServlet(self).register(resource)
KeyQueryServlet(self).register(resource)
KeyChangesServlet(self).register(resource)
OneTimeKeyServlet(self).register(resource)
voip.register_servlets(self, resource)
push_rule.register_servlets(self, resource)
versions.register_servlets(self, resource)
profile.register_servlets(self, resource)
sync.register_servlets(self, resource)
events.register_servlets(self, resource)
room.register_servlets(self, resource, is_worker=True)
relations.register_servlets(self, resource)
room.register_deprecated_servlets(self, resource)
initial_sync.register_servlets(self, resource)
room_batch.register_servlets(self, resource)
room_keys.register_servlets(self, resource)
tags.register_servlets(self, resource)
account_data.register_servlets(self, resource)
receipts.register_servlets(self, resource)
read_marker.register_servlets(self, resource)
sendtodevice.register_servlets(self, resource)
user_directory.register_servlets(self, resource)
presence.register_servlets(self, resource)
resources[CLIENT_API_PREFIX] = resource

View File

@@ -26,7 +26,6 @@ class PushConfig(Config):
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
push_config = config.get("push") or {}
self.push_include_content = push_config.get("include_content", True)
self.enable_push = push_config.get("enabled", True)
self.push_group_unread_count_by_room = push_config.get(
"group_unread_count_by_room", True
)

View File

@@ -14,6 +14,7 @@
import abc
import logging
import urllib
from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple
import attr
@@ -812,27 +813,31 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
results = {}
async def get_keys(key_to_fetch_item: _FetchKeyRequest) -> None:
async def get_key(key_to_fetch_item: _FetchKeyRequest) -> None:
server_name = key_to_fetch_item.server_name
key_ids = key_to_fetch_item.key_ids
try:
keys = await self.get_server_verify_keys_v2_direct(server_name)
keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
results[server_name] = keys
except KeyLookupError as e:
logger.warning("Error looking up keys from %s: %s", server_name, e)
logger.warning(
"Error looking up keys %s from %s: %s", key_ids, server_name, e
)
except Exception:
logger.exception("Error getting keys from %s", server_name)
logger.exception("Error getting keys %s from %s", key_ids, server_name)
await yieldable_gather_results(get_keys, keys_to_fetch)
await yieldable_gather_results(get_key, keys_to_fetch)
return results
async def get_server_verify_keys_v2_direct(
self, server_name: str
async def get_server_verify_key_v2_direct(
self, server_name: str, key_ids: Iterable[str]
) -> Dict[str, FetchKeyResult]:
"""
Args:
server_name: Server to request keys from
server_name:
key_ids:
Returns:
Map from key ID to lookup result
@@ -840,41 +845,57 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
Raises:
KeyLookupError if there was a problem making the lookup
"""
time_now_ms = self.clock.time_msec()
try:
response = await self.client.get_json(
destination=server_name,
path="/_matrix/key/v2/server",
ignore_backoff=True,
# we only give the remote server 10s to respond. It should be an
# easy request to handle, so if it doesn't reply within 10s, it's
# probably not going to.
#
# Furthermore, when we are acting as a notary server, we cannot
# wait all day for all of the origin servers, as the requesting
# server will otherwise time out before we can respond.
#
# (Note that get_json may make 4 attempts, so this can still take
# almost 45 seconds to fetch the headers, plus up to another 60s to
# read the response).
timeout=10000,
)
except (NotRetryingDestination, RequestSendFailed) as e:
# these both have str() representations which we can't really improve
# upon
raise KeyLookupError(str(e))
except HttpResponseException as e:
raise KeyLookupError("Remote server returned an error: %s" % (e,))
keys: Dict[str, FetchKeyResult] = {}
assert isinstance(response, dict)
if response["server_name"] != server_name:
raise KeyLookupError(
"Expected a response for server %r not %r"
% (server_name, response["server_name"])
)
for requested_key_id in key_ids:
# we may have found this key as a side-effect of asking for another.
if requested_key_id in keys:
continue
return await self.process_v2_response(
from_server=server_name,
response_json=response,
time_added_ms=time_now_ms,
)
time_now_ms = self.clock.time_msec()
try:
response = await self.client.get_json(
destination=server_name,
path="/_matrix/key/v2/server/"
+ urllib.parse.quote(requested_key_id, safe=""),
ignore_backoff=True,
# we only give the remote server 10s to respond. It should be an
# easy request to handle, so if it doesn't reply within 10s, it's
# probably not going to.
#
# Furthermore, when we are acting as a notary server, we cannot
# wait all day for all of the origin servers, as the requesting
# server will otherwise time out before we can respond.
#
# (Note that get_json may make 4 attempts, so this can still take
# almost 45 seconds to fetch the headers, plus up to another 60s to
# read the response).
timeout=10000,
)
except (NotRetryingDestination, RequestSendFailed) as e:
# these both have str() representations which we can't really improve
# upon
raise KeyLookupError(str(e))
except HttpResponseException as e:
raise KeyLookupError("Remote server returned an error: %s" % (e,))
assert isinstance(response, dict)
if response["server_name"] != server_name:
raise KeyLookupError(
"Expected a response for server %r not %r"
% (server_name, response["server_name"])
)
response_keys = await self.process_v2_response(
from_server=server_name,
response_json=response,
time_added_ms=time_now_ms,
)
await self.store.store_server_verify_keys(
server_name,
time_now_ms,
((server_name, key_id, key) for key_id, key in response_keys.items()),
)
keys.update(response_keys)
return keys

View File

@@ -647,7 +647,7 @@ class FederationSender(AbstractFederationSender):
room_id = receipt.room_id
# Work out which remote servers should be poked and poke them.
domains_set = await self._storage_controllers.state.get_current_hosts_in_room_or_partial_state_approximation(
domains_set = await self._storage_controllers.state.get_current_hosts_in_room(
room_id
)
domains = [

View File

@@ -1764,14 +1764,14 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
Returns:
A list of presence states for the given user to receive.
"""
updated_users = None
if from_key:
# Only return updates since the last sync
updated_users = self.store.presence_stream_cache.get_all_entities_changed(
from_key
)
if not updated_users:
updated_users = []
if updated_users is not None:
# Get the actual presence update for each change
users_to_state = await self.get_presence_handler().current_state_for_users(
updated_users

View File

@@ -106,7 +106,6 @@ class BulkPushRuleEvaluator:
self.store = hs.get_datastores().main
self.clock = hs.get_clock()
self._event_auth_handler = hs.get_event_auth_handler()
self.should_calculate_push_rules = self.hs.config.push.enable_push
self._related_event_match_enabled = self.hs.config.experimental.msc3664_enabled
@@ -270,8 +269,6 @@ class BulkPushRuleEvaluator:
for each event, check if the message should increment the unread count, and
insert the results into the event_push_actions_staging table.
"""
if not self.should_calculate_push_rules:
return
# For batched events the power level events may not have been persisted yet,
# so we pass in the batched events. Thus if the event cannot be found in the
# database we can check in the batch.

View File

@@ -17,6 +17,7 @@ from synapse.events import EventBase
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
from synapse.storage.controllers import StorageControllers
from synapse.storage.databases.main import DataStore
from synapse.util.async_helpers import concurrently_execute
async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -> int:
@@ -25,12 +26,23 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
badge = len(invites)
room_to_count = await store.get_unread_counts_by_room_for_user(user_id)
for room_id, notify_count in room_to_count.items():
# room_to_count may include rooms which the user has left,
# ignore those.
if room_id not in joins:
continue
room_notifs = []
async def get_room_unread_count(room_id: str) -> None:
room_notifs.append(
await store.get_unread_event_push_actions_by_room_for_user(
room_id,
user_id,
)
)
await concurrently_execute(get_room_unread_count, joins, 10)
for notifs in room_notifs:
# Combine the counts from all the threads.
notify_count = notifs.main_timeline.notify_count + sum(
n.notify_count for n in notifs.threads.values()
)
if notify_count == 0:
continue
@@ -39,10 +51,8 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
# return one badge count per conversation
badge += 1
else:
# Increase badge by number of notifications in room
# NOTE: this includes threaded and unthreaded notifications.
# increment the badge count by the number of unread messages in the room
badge += notify_count
return badge

View File

@@ -29,7 +29,7 @@ from synapse.rest.client import (
initial_sync,
keys,
knock,
login,
login as v1_login,
login_token_request,
logout,
mutual_rooms,
@@ -82,10 +82,6 @@ class ClientRestResource(JsonResource):
@staticmethod
def register_servlets(client_resource: HttpServer, hs: "HomeServer") -> None:
# Some servlets are only registered on the main process (and not worker
# processes).
is_main_process = hs.config.worker.worker_app is None
versions.register_servlets(hs, client_resource)
# Deprecated in r0
@@ -96,58 +92,45 @@ class ClientRestResource(JsonResource):
events.register_servlets(hs, client_resource)
room.register_servlets(hs, client_resource)
login.register_servlets(hs, client_resource)
v1_login.register_servlets(hs, client_resource)
profile.register_servlets(hs, client_resource)
presence.register_servlets(hs, client_resource)
if is_main_process:
directory.register_servlets(hs, client_resource)
directory.register_servlets(hs, client_resource)
voip.register_servlets(hs, client_resource)
if is_main_process:
pusher.register_servlets(hs, client_resource)
pusher.register_servlets(hs, client_resource)
push_rule.register_servlets(hs, client_resource)
if is_main_process:
logout.register_servlets(hs, client_resource)
logout.register_servlets(hs, client_resource)
sync.register_servlets(hs, client_resource)
if is_main_process:
filter.register_servlets(hs, client_resource)
filter.register_servlets(hs, client_resource)
account.register_servlets(hs, client_resource)
register.register_servlets(hs, client_resource)
if is_main_process:
auth.register_servlets(hs, client_resource)
auth.register_servlets(hs, client_resource)
receipts.register_servlets(hs, client_resource)
read_marker.register_servlets(hs, client_resource)
room_keys.register_servlets(hs, client_resource)
keys.register_servlets(hs, client_resource)
if is_main_process:
tokenrefresh.register_servlets(hs, client_resource)
tokenrefresh.register_servlets(hs, client_resource)
tags.register_servlets(hs, client_resource)
account_data.register_servlets(hs, client_resource)
if is_main_process:
report_event.register_servlets(hs, client_resource)
openid.register_servlets(hs, client_resource)
notifications.register_servlets(hs, client_resource)
report_event.register_servlets(hs, client_resource)
openid.register_servlets(hs, client_resource)
notifications.register_servlets(hs, client_resource)
devices.register_servlets(hs, client_resource)
if is_main_process:
thirdparty.register_servlets(hs, client_resource)
thirdparty.register_servlets(hs, client_resource)
sendtodevice.register_servlets(hs, client_resource)
user_directory.register_servlets(hs, client_resource)
if is_main_process:
room_upgrade_rest_servlet.register_servlets(hs, client_resource)
room_upgrade_rest_servlet.register_servlets(hs, client_resource)
room_batch.register_servlets(hs, client_resource)
if is_main_process:
capabilities.register_servlets(hs, client_resource)
account_validity.register_servlets(hs, client_resource)
capabilities.register_servlets(hs, client_resource)
account_validity.register_servlets(hs, client_resource)
relations.register_servlets(hs, client_resource)
if is_main_process:
password_policy.register_servlets(hs, client_resource)
knock.register_servlets(hs, client_resource)
password_policy.register_servlets(hs, client_resource)
knock.register_servlets(hs, client_resource)
# moving to /_synapse/admin
if is_main_process:
admin.register_servlets_for_client_rest_resource(hs, client_resource)
admin.register_servlets_for_client_rest_resource(hs, client_resource)
# unstable
if is_main_process:
mutual_rooms.register_servlets(hs, client_resource)
login_token_request.register_servlets(hs, client_resource)
rendezvous.register_servlets(hs, client_resource)
mutual_rooms.register_servlets(hs, client_resource)
login_token_request.register_servlets(hs, client_resource)
rendezvous.register_servlets(hs, client_resource)

View File

@@ -875,21 +875,19 @@ class AccountStatusRestServlet(RestServlet):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
if hs.config.worker.worker_app is None:
EmailPasswordRequestTokenRestServlet(hs).register(http_server)
PasswordRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server)
EmailThreepidRequestTokenRestServlet(hs).register(http_server)
MsisdnThreepidRequestTokenRestServlet(hs).register(http_server)
AddThreepidEmailSubmitTokenServlet(hs).register(http_server)
AddThreepidMsisdnSubmitTokenServlet(hs).register(http_server)
EmailPasswordRequestTokenRestServlet(hs).register(http_server)
PasswordRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server)
EmailThreepidRequestTokenRestServlet(hs).register(http_server)
MsisdnThreepidRequestTokenRestServlet(hs).register(http_server)
AddThreepidEmailSubmitTokenServlet(hs).register(http_server)
AddThreepidMsisdnSubmitTokenServlet(hs).register(http_server)
ThreepidRestServlet(hs).register(http_server)
if hs.config.worker.worker_app is None:
ThreepidAddRestServlet(hs).register(http_server)
ThreepidBindRestServlet(hs).register(http_server)
ThreepidUnbindRestServlet(hs).register(http_server)
ThreepidDeleteRestServlet(hs).register(http_server)
ThreepidAddRestServlet(hs).register(http_server)
ThreepidBindRestServlet(hs).register(http_server)
ThreepidUnbindRestServlet(hs).register(http_server)
ThreepidDeleteRestServlet(hs).register(http_server)
WhoamiRestServlet(hs).register(http_server)
if hs.config.worker.worker_app is None and hs.config.experimental.msc3720_enabled:
if hs.config.experimental.msc3720_enabled:
AccountStatusRestServlet(hs).register(http_server)

View File

@@ -342,10 +342,8 @@ class ClaimDehydratedDeviceServlet(RestServlet):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
if hs.config.worker.worker_app is None:
DeleteDevicesRestServlet(hs).register(http_server)
DeleteDevicesRestServlet(hs).register(http_server)
DevicesRestServlet(hs).register(http_server)
if hs.config.worker.worker_app is None:
DeviceRestServlet(hs).register(http_server)
DehydratedDeviceServlet(hs).register(http_server)
ClaimDehydratedDeviceServlet(hs).register(http_server)
DeviceRestServlet(hs).register(http_server)
DehydratedDeviceServlet(hs).register(http_server)
ClaimDehydratedDeviceServlet(hs).register(http_server)

View File

@@ -376,6 +376,5 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
KeyQueryServlet(hs).register(http_server)
KeyChangesServlet(hs).register(http_server)
OneTimeKeyServlet(hs).register(http_server)
if hs.config.worker.worker_app is None:
SigningKeyUploadServlet(hs).register(http_server)
SignaturesUploadServlet(hs).register(http_server)
SigningKeyUploadServlet(hs).register(http_server)
SignaturesUploadServlet(hs).register(http_server)

View File

@@ -949,10 +949,9 @@ def _calculate_registration_flows(
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
if hs.config.worker.worker_app is None:
EmailRegisterRequestTokenRestServlet(hs).register(http_server)
MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
UsernameAvailabilityRestServlet(hs).register(http_server)
RegistrationSubmitTokenServlet(hs).register(http_server)
EmailRegisterRequestTokenRestServlet(hs).register(http_server)
MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
UsernameAvailabilityRestServlet(hs).register(http_server)
RegistrationSubmitTokenServlet(hs).register(http_server)
RegistrationTokenValidityRestServlet(hs).register(http_server)
RegisterRestServlet(hs).register(http_server)

View File

@@ -1395,7 +1395,9 @@ class RoomSummaryRestServlet(ResolveRoomIdMixin, RestServlet):
)
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
def register_servlets(
hs: "HomeServer", http_server: HttpServer, is_worker: bool = False
) -> None:
RoomStateEventRestServlet(hs).register(http_server)
RoomMemberListRestServlet(hs).register(http_server)
JoinedRoomMemberListRestServlet(hs).register(http_server)
@@ -1419,7 +1421,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
TimestampLookupRestServlet(hs).register(http_server)
# Some servlets only get registered for the main process.
if hs.config.worker.worker_app is None:
if not is_worker:
RoomForgetRestServlet(hs).register(http_server)

View File

@@ -77,7 +77,6 @@ class VersionsRestServlet(RestServlet):
"v1.2",
"v1.3",
"v1.4",
"v1.5",
],
# as per MSC1497:
"unstable_features": {

View File

@@ -842,11 +842,12 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
user_ids, from_key
)
# If an empty set was returned, there's nothing to do.
if user_ids_to_check is not None and not user_ids_to_check:
if not user_ids_to_check:
return set()
def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
changes: Set[str] = set()
stream_id_where_clause = "stream_id > ?"
sql_args = [from_key]
@@ -857,25 +858,19 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
sql = f"""
SELECT DISTINCT user_id FROM device_lists_stream
WHERE {stream_id_where_clause}
AND
"""
# If the stream change cache gave us no information, fetch *all*
# users between the stream IDs.
if user_ids_to_check is None:
txn.execute(sql, sql_args)
return {user_id for user_id, in txn}
# Otherwise, fetch changes for the given users.
else:
changes: Set[str] = set()
# Query device changes with a batch of users at a time
for chunk in batch_iter(user_ids_to_check, 100):
clause, args = make_in_list_sql_clause(
txn.database_engine, "user_id", chunk
)
txn.execute(sql + " AND " + clause, sql_args + args)
changes.update(user_id for user_id, in txn)
# Query device changes with a batch of users at a time
# Assertion for mypy's benefit; see also
# https://mypy.readthedocs.io/en/stable/common_issues.html#narrowing-and-inner-functions
assert user_ids_to_check is not None
for chunk in batch_iter(user_ids_to_check, 100):
clause, args = make_in_list_sql_clause(
txn.database_engine, "user_id", chunk
)
txn.execute(sql + clause, sql_args + args)
changes.update(user_id for user_id, in txn)
return changes

View File

@@ -74,7 +74,6 @@ receipt.
"""
import logging
from collections import defaultdict
from typing import (
TYPE_CHECKING,
Collection,
@@ -96,7 +95,6 @@ from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
PostgresEngine,
)
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.stream import StreamWorkerStore
@@ -465,153 +463,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
return result
async def get_unread_counts_by_room_for_user(self, user_id: str) -> Dict[str, int]:
"""Get the notification count by room for a user. Only considers notifications,
not highlight or unread counts, and threads are currently aggregated under their room.
This function is intentionally not cached because it is called to calculate the
unread badge for push notifications and thus the result is expected to change.
Note that this function assumes the user is a member of the room. Because
summary rows are not removed when a user leaves a room, the caller must
filter out those results from the result.
Returns:
A map of room ID to notification counts for the given user.
"""
return await self.db_pool.runInteraction(
"get_unread_counts_by_room_for_user",
self._get_unread_counts_by_room_for_user_txn,
user_id,
)
def _get_unread_counts_by_room_for_user_txn(
self, txn: LoggingTransaction, user_id: str
) -> Dict[str, int]:
receipt_types_clause, args = make_in_list_sql_clause(
self.database_engine,
"receipt_type",
(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
)
args.extend([user_id, user_id])
receipts_cte = f"""
WITH all_receipts AS (
SELECT room_id, thread_id, MAX(event_stream_ordering) AS max_receipt_stream_ordering
FROM receipts_linearized
LEFT JOIN events USING (room_id, event_id)
WHERE
{receipt_types_clause}
AND user_id = ?
GROUP BY room_id, thread_id
)
"""
receipts_joins = """
LEFT JOIN (
SELECT room_id, thread_id,
max_receipt_stream_ordering AS threaded_receipt_stream_ordering
FROM all_receipts
WHERE thread_id IS NOT NULL
) AS threaded_receipts USING (room_id, thread_id)
LEFT JOIN (
SELECT room_id, thread_id,
max_receipt_stream_ordering AS unthreaded_receipt_stream_ordering
FROM all_receipts
WHERE thread_id IS NULL
) AS unthreaded_receipts USING (room_id)
"""
# First get summary counts by room / thread for the user. We use the max receipt
# stream ordering of both threaded & unthreaded receipts to compare against the
# summary table.
#
# PostgreSQL and SQLite differ in comparing scalar numerics.
if isinstance(self.database_engine, PostgresEngine):
# GREATEST ignores NULLs.
max_clause = """GREATEST(
threaded_receipt_stream_ordering,
unthreaded_receipt_stream_ordering
)"""
else:
# MAX returns NULL if any are NULL, so COALESCE to 0 first.
max_clause = """MAX(
COALESCE(threaded_receipt_stream_ordering, 0),
COALESCE(unthreaded_receipt_stream_ordering, 0)
)"""
sql = f"""
{receipts_cte}
SELECT eps.room_id, eps.thread_id, notif_count
FROM event_push_summary AS eps
{receipts_joins}
WHERE user_id = ?
AND notif_count != 0
AND (
(last_receipt_stream_ordering IS NULL AND stream_ordering > {max_clause})
OR last_receipt_stream_ordering = {max_clause}
)
"""
txn.execute(sql, args)
seen_thread_ids = set()
room_to_count: Dict[str, int] = defaultdict(int)
for room_id, thread_id, notif_count in txn:
room_to_count[room_id] += notif_count
seen_thread_ids.add(thread_id)
# Now get any event push actions that haven't been rotated using the same OR
# join and filter by receipt and event push summary rotated up to stream ordering.
sql = f"""
{receipts_cte}
SELECT epa.room_id, epa.thread_id, COUNT(CASE WHEN epa.notif = 1 THEN 1 END) AS notif_count
FROM event_push_actions AS epa
{receipts_joins}
WHERE user_id = ?
AND epa.notif = 1
AND stream_ordering > (SELECT stream_ordering FROM event_push_summary_stream_ordering)
AND (threaded_receipt_stream_ordering IS NULL OR stream_ordering > threaded_receipt_stream_ordering)
AND (unthreaded_receipt_stream_ordering IS NULL OR stream_ordering > unthreaded_receipt_stream_ordering)
GROUP BY epa.room_id, epa.thread_id
"""
txn.execute(sql, args)
for room_id, thread_id, notif_count in txn:
# Note: only count push actions we have valid summaries for with up to date receipt.
if thread_id not in seen_thread_ids:
continue
room_to_count[room_id] += notif_count
thread_id_clause, thread_ids_args = make_in_list_sql_clause(
self.database_engine, "epa.thread_id", seen_thread_ids
)
# Finally re-check event_push_actions for any rooms not in the summary, ignoring
# the rotated up-to position. This handles the case where a read receipt has arrived
# but not been rotated meaning the summary table is out of date, so we go back to
# the push actions table.
sql = f"""
{receipts_cte}
SELECT epa.room_id, COUNT(CASE WHEN epa.notif = 1 THEN 1 END) AS notif_count
FROM event_push_actions AS epa
{receipts_joins}
WHERE user_id = ?
AND NOT {thread_id_clause}
AND epa.notif = 1
AND (threaded_receipt_stream_ordering IS NULL OR stream_ordering > threaded_receipt_stream_ordering)
AND (unthreaded_receipt_stream_ordering IS NULL OR stream_ordering > unthreaded_receipt_stream_ordering)
GROUP BY epa.room_id
"""
args.extend(thread_ids_args)
txn.execute(sql, args)
for room_id, notif_count in txn:
room_to_count[room_id] += notif_count
return room_to_count
@cached(tree=True, max_entries=5000, iterable=True)
async def get_unread_event_push_actions_by_room_for_user(
self,

View File

@@ -433,7 +433,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
async def get_json(destination, path, **kwargs):
self.assertEqual(destination, SERVER_NAME)
self.assertEqual(path, "/_matrix/key/v2/server")
self.assertEqual(path, "/_matrix/key/v2/server/key1")
return response
self.http_client.get_json.side_effect = get_json
@@ -469,6 +469,18 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
self.assertEqual(keys, {})
def test_keyid_containing_forward_slash(self) -> None:
"""We should url-encode any url unsafe chars in key ids.
Detects https://github.com/matrix-org/synapse/issues/14488.
"""
fetcher = ServerKeyFetcher(self.hs)
self.get_success(fetcher.get_keys("example.com", ["key/potato"], 0))
self.http_client.get_json.assert_called_once()
args, kwargs = self.http_client.get_json.call_args
self.assertEqual(kwargs["path"], "/_matrix/key/v2/server/key%2Fpotato")
class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):

View File

@@ -126,13 +126,6 @@ class PresenceRouterTestModule:
class PresenceRouterTestCase(FederatingHomeserverTestCase):
"""
Test cases using a custom PresenceRouter
By default in test cases, federation sending is disabled. This class re-enables it
for the main process by setting `federation_sender_instances` to None.
"""
servlets = [
admin.register_servlets,
login.register_servlets,
@@ -157,11 +150,6 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
self.sync_handler = self.hs.get_sync_handler()
self.module_api = homeserver.get_module_api()
def default_config(self) -> JsonDict:
config = super().default_config()
config["federation_sender_instances"] = None
return config
@override_config(
{
"presence": {
@@ -174,6 +162,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
},
}
},
"send_federation": True,
}
)
def test_receiving_all_presence_legacy(self):
@@ -191,6 +180,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
},
},
],
"send_federation": True,
}
)
def test_receiving_all_presence(self):
@@ -300,6 +290,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
},
}
},
"send_federation": True,
}
)
def test_send_local_online_presence_to_with_module_legacy(self):
@@ -319,6 +310,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
},
},
],
"send_federation": True,
}
)
def test_send_local_online_presence_to_with_module(self):

View File

@@ -7,21 +7,13 @@ from synapse.federation.sender import PerDestinationQueue, TransactionManager
from synapse.federation.units import Edu
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.types import JsonDict
from synapse.util.retryutils import NotRetryingDestination
from tests.test_utils import event_injection, make_awaitable
from tests.unittest import FederatingHomeserverTestCase
from tests.unittest import FederatingHomeserverTestCase, override_config
class FederationCatchUpTestCases(FederatingHomeserverTestCase):
"""
Tests cases of catching up over federation.
By default for test cases federation sending is disabled. This Test class has it
re-enabled for the main process.
"""
servlets = [
admin.register_servlets,
room.register_servlets,
@@ -50,11 +42,6 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
self.record_transaction
)
def default_config(self) -> JsonDict:
config = super().default_config()
config["federation_sender_instances"] = None
return config
async def record_transaction(self, txn, json_cb):
if self.is_online:
data = json_cb()
@@ -92,6 +79,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
)[0]
return {"event_id": event_id, "stream_ordering": stream_ordering}
@override_config({"send_federation": True})
def test_catch_up_destination_rooms_tracking(self):
"""
Tests that we populate the `destination_rooms` table as needed.
@@ -117,6 +105,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
self.assertEqual(row_2["event_id"], event_id_2)
self.assertEqual(row_1["stream_ordering"], row_2["stream_ordering"] - 1)
@override_config({"send_federation": True})
def test_catch_up_last_successful_stream_ordering_tracking(self):
"""
Tests that we populate the `destination_rooms` table as needed.
@@ -174,6 +163,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
"Send succeeded but not marked as last_successful_stream_ordering",
)
@override_config({"send_federation": True}) # critical to federate
def test_catch_up_from_blank_state(self):
"""
Runs an overall test of federation catch-up from scratch.
@@ -270,6 +260,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
return per_dest_queue, results_list
@override_config({"send_federation": True})
def test_catch_up_loop(self):
"""
Tests the behaviour of _catch_up_transmission_loop.
@@ -334,6 +325,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
event_5.internal_metadata.stream_ordering,
)
@override_config({"send_federation": True})
def test_catch_up_on_synapse_startup(self):
"""
Tests the behaviour of get_catch_up_outstanding_destinations and
@@ -432,6 +424,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
# - all destinations are woken exactly once; they appear once in woken.
self.assertCountEqual(woken, server_names[:-1])
@override_config({"send_federation": True})
def test_not_latest_event(self):
"""Test that we send the latest event in the room even if its not ours."""

View File

@@ -25,17 +25,10 @@ from synapse.rest.client import login
from synapse.types import JsonDict, ReadReceipt
from tests.test_utils import make_awaitable
from tests.unittest import HomeserverTestCase
from tests.unittest import HomeserverTestCase, override_config
class FederationSenderReceiptsTestCases(HomeserverTestCase):
"""
Test federation sending to update receipts.
By default for test cases federation sending is disabled. This Test class has it
re-enabled for the main process.
"""
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
federation_transport_client=Mock(spec=["send_transaction"]),
@@ -45,17 +38,9 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
return_value=make_awaitable({"test", "host2"})
)
hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = (
hs.get_storage_controllers().state.get_current_hosts_in_room
)
return hs
def default_config(self) -> JsonDict:
config = super().default_config()
config["federation_sender_instances"] = None
return config
@override_config({"send_federation": True})
def test_send_receipts(self):
mock_send_transaction = (
self.hs.get_federation_transport_client().send_transaction
@@ -98,6 +83,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
],
)
@override_config({"send_federation": True})
def test_send_receipts_thread(self):
mock_send_transaction = (
self.hs.get_federation_transport_client().send_transaction
@@ -174,6 +160,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
],
)
@override_config({"send_federation": True})
def test_send_receipts_with_backoff(self):
"""Send two receipts in quick succession; the second should be flushed, but
only after 20ms"""
@@ -260,13 +247,6 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
class FederationSenderDevicesTestCases(HomeserverTestCase):
"""
Test federation sending to update devices.
By default for test cases federation sending is disabled. This Test class has it
re-enabled for the main process.
"""
servlets = [
admin.register_servlets,
login.register_servlets,
@@ -281,8 +261,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
def default_config(self):
c = super().default_config()
# Enable federation sending on the main process.
c["federation_sender_instances"] = None
c["send_federation"] = True
return c
def prepare(self, reactor, clock, hs):

View File

@@ -992,8 +992,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
def default_config(self):
config = super().default_config()
# Enable federation sending on the main process.
config["federation_sender_instances"] = None
config["send_federation"] = True
return config
def prepare(self, reactor, clock, hs):

View File

@@ -200,8 +200,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
],
)
# Enable federation sending on the main process.
@override_config({"federation_sender_instances": None})
@override_config({"send_federation": True})
def test_started_typing_remote_send(self) -> None:
self.room_members = [U_APPLE, U_ONION]
@@ -306,8 +305,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEqual(events[0], [])
self.assertEqual(events[1], 0)
# Enable federation sending on the main process.
@override_config({"federation_sender_instances": None})
@override_config({"send_federation": True})
def test_stopped_typing(self) -> None:
self.room_members = [U_APPLE, U_BANANA, U_ONION]

View File

@@ -56,8 +56,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
# Re-enables updating the user directory, as that function is needed below.
config["update_user_directory_from_worker"] = None
config["update_user_directory"] = True
self.appservice = ApplicationService(
token="i_am_an_app_service",
@@ -1046,9 +1045,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
# Re-enables updating the user directory, as that function is needed below. It
# will be force disabled later
config["update_user_directory_from_worker"] = None
config["update_user_directory"] = True
hs = self.setup_test_homeserver(config=config)
self.config = hs.config

View File

@@ -336,8 +336,7 @@ class ModuleApiTestCase(HomeserverTestCase):
# Test sending local online presence to users from the main process
_test_sending_local_online_presence_to_local_user(self, test_with_workers=False)
# Enable federation sending on the main process.
@override_config({"federation_sender_instances": None})
@override_config({"send_federation": True})
def test_send_local_online_presence_to_federation(self):
"""Tests that send_local_presence_to_users sends local online presence to remote users."""
# Create a user who will send presence updates

View File

@@ -6,11 +6,10 @@ from synapse.rest import admin
from synapse.rest.client import login, register, room
from synapse.types import create_requester
from tests.test_utils import simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
from tests import unittest
class TestBulkPushRuleEvaluator(HomeserverTestCase):
class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase):
servlets = [
admin.register_servlets_for_client_rest_resource,
@@ -73,43 +72,3 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
# should not raise
self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)]))
@override_config({"push": {"enabled": False}})
def test_action_for_event_by_user_disabled_by_config(self) -> None:
"""Ensure that push rules are not calculated when disabled in the config"""
# Create a new user and room.
alice = self.register_user("alice", "pass")
token = self.login(alice, "pass")
room_id = self.helper.create_room_as(
alice, room_version=RoomVersions.V9.identifier, tok=token
)
# Alter the power levels in that room to include stringy and floaty levels.
# We need to suppress the validation logic or else it will reject these dodgy
# values. (Presumably this validation was not always present.)
event_creation_handler = self.hs.get_event_creation_handler()
requester = create_requester(alice)
# Create a new message event, and try to evaluate it under the dodgy
# power level event.
event, context = self.get_success(
event_creation_handler.create_event(
requester,
{
"type": "m.room.message",
"room_id": room_id,
"content": {
"msgtype": "m.text",
"body": "helo",
},
"sender": alice,
},
)
)
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
bulk_evaluator._action_for_event_by_user = simple_async_mock() # type: ignore[assignment]
# should not raise
self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)]))
bulk_evaluator._action_for_event_by_user.assert_not_called()

View File

@@ -66,6 +66,7 @@ class EmailPusherTests(HomeserverTestCase):
"riot_base_url": None,
}
config["public_baseurl"] = "http://aaa"
config["start_pushers"] = True
hs = self.setup_test_homeserver(config=config)

View File

@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
from unittest.mock import Mock
from twisted.internet.defer import Deferred
@@ -41,6 +41,11 @@ class HTTPPusherTests(HomeserverTestCase):
user_id = True
hijack_auth = False
def default_config(self) -> Dict[str, Any]:
config = super().default_config()
config["start_pushers"] = True
return config
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.push_attempts: List[Tuple[Deferred, str, dict]] = []

View File

@@ -307,7 +307,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
stream to the master HS.
Args:
worker_app: Type of worker, e.g. `synapse.app.generic_worker`.
worker_app: Type of worker, e.g. `synapse.app.federation_sender`.
extra_config: Any extra config to use for this instances.
**kwargs: Options that get passed to `self.setup_test_homeserver`,
useful to e.g. pass some mocks for things like `federation_http_client`

View File

@@ -22,8 +22,9 @@ class FederationStreamTestCase(BaseStreamTestCase):
def _get_worker_hs_config(self) -> dict:
# enable federation sending on the worker
config = super()._get_worker_hs_config()
config["worker_name"] = "federation_sender1"
config["federation_sender_instances"] = ["federation_sender1"]
# TODO: make it so we don't need both of these
config["send_federation"] = False
config["worker_app"] = "synapse.app.federation_sender"
return config
def test_catchup(self):

View File

@@ -38,7 +38,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
def _get_worker_hs_config(self) -> dict:
config = self.default_config()
config["worker_app"] = "synapse.app.generic_worker"
config["worker_app"] = "synapse.app.client_reader"
config["worker_replication_host"] = "testserv"
config["worker_replication_http_port"] = "8765"
@@ -53,7 +53,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
4. Return the final request.
"""
worker_hs = self.make_worker_hs("synapse.app.generic_worker")
worker_hs = self.make_worker_hs("synapse.app.client_reader")
site = self._hs_to_site[worker_hs]
channel_1 = make_request(

View File

@@ -22,20 +22,20 @@ logger = logging.getLogger(__name__)
class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
"""Test using one or more generic workers for registration."""
"""Test using one or more client readers for registration."""
servlets = [register.register_servlets]
def _get_worker_hs_config(self) -> dict:
config = self.default_config()
config["worker_app"] = "synapse.app.generic_worker"
config["worker_app"] = "synapse.app.client_reader"
config["worker_replication_host"] = "testserv"
config["worker_replication_http_port"] = "8765"
return config
def test_register_single_worker(self):
"""Test that registration works when using a single generic worker."""
worker_hs = self.make_worker_hs("synapse.app.generic_worker")
"""Test that registration works when using a single client reader worker."""
worker_hs = self.make_worker_hs("synapse.app.client_reader")
site = self._hs_to_site[worker_hs]
channel_1 = make_request(
@@ -64,9 +64,9 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(channel_2.json_body["user_id"], "@user:test")
def test_register_multi_worker(self):
"""Test that registration works when using multiple generic workers."""
worker_hs_1 = self.make_worker_hs("synapse.app.generic_worker")
worker_hs_2 = self.make_worker_hs("synapse.app.generic_worker")
"""Test that registration works when using multiple client reader workers."""
worker_hs_1 = self.make_worker_hs("synapse.app.client_reader")
worker_hs_2 = self.make_worker_hs("synapse.app.client_reader")
site_1 = self._hs_to_site[worker_hs_1]
channel_1 = make_request(

View File

@@ -25,9 +25,8 @@ from tests.unittest import HomeserverTestCase
class FederationAckTestCase(HomeserverTestCase):
def default_config(self) -> dict:
config = super().default_config()
config["worker_app"] = "synapse.app.generic_worker"
config["worker_name"] = "federation_sender1"
config["federation_sender_instances"] = ["federation_sender1"]
config["worker_app"] = "synapse.app.federation_sender"
config["send_federation"] = False
return config
def make_homeserver(self, reactor, clock):

View File

@@ -27,19 +27,17 @@ logger = logging.getLogger(__name__)
class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
"""
Various tests for federation sending on workers.
Federation sending is disabled by default, it will be enabled in each test by
updating 'federation_sender_instances'.
"""
servlets = [
login.register_servlets,
register_servlets_for_client_rest_resource,
room.register_servlets,
]
def default_config(self):
conf = super().default_config()
conf["send_federation"] = False
return conf
def test_send_event_single_sender(self):
"""Test that using a single federation sender worker correctly sends a
new event.
@@ -48,11 +46,8 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
mock_client.put_json.return_value = make_awaitable({})
self.make_worker_hs(
"synapse.app.generic_worker",
{
"worker_name": "federation_sender1",
"federation_sender_instances": ["federation_sender1"],
},
"synapse.app.federation_sender",
{"send_federation": False},
federation_http_client=mock_client,
)
@@ -78,13 +73,11 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
mock_client1 = Mock(spec=["put_json"])
mock_client1.put_json.return_value = make_awaitable({})
self.make_worker_hs(
"synapse.app.generic_worker",
"synapse.app.federation_sender",
{
"worker_name": "federation_sender1",
"federation_sender_instances": [
"federation_sender1",
"federation_sender2",
],
"send_federation": True,
"worker_name": "sender1",
"federation_sender_instances": ["sender1", "sender2"],
},
federation_http_client=mock_client1,
)
@@ -92,13 +85,11 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
mock_client2 = Mock(spec=["put_json"])
mock_client2.put_json.return_value = make_awaitable({})
self.make_worker_hs(
"synapse.app.generic_worker",
"synapse.app.federation_sender",
{
"worker_name": "federation_sender2",
"federation_sender_instances": [
"federation_sender1",
"federation_sender2",
],
"send_federation": True,
"worker_name": "sender2",
"federation_sender_instances": ["sender1", "sender2"],
},
federation_http_client=mock_client2,
)
@@ -145,13 +136,11 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
mock_client1 = Mock(spec=["put_json"])
mock_client1.put_json.return_value = make_awaitable({})
self.make_worker_hs(
"synapse.app.generic_worker",
"synapse.app.federation_sender",
{
"worker_name": "federation_sender1",
"federation_sender_instances": [
"federation_sender1",
"federation_sender2",
],
"send_federation": True,
"worker_name": "sender1",
"federation_sender_instances": ["sender1", "sender2"],
},
federation_http_client=mock_client1,
)
@@ -159,13 +148,11 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
mock_client2 = Mock(spec=["put_json"])
mock_client2.put_json.return_value = make_awaitable({})
self.make_worker_hs(
"synapse.app.generic_worker",
"synapse.app.federation_sender",
{
"worker_name": "federation_sender2",
"federation_sender_instances": [
"federation_sender1",
"federation_sender2",
],
"send_federation": True,
"worker_name": "sender2",
"federation_sender_instances": ["sender1", "sender2"],
},
federation_http_client=mock_client2,
)

View File

@@ -38,6 +38,11 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
self.other_user_id = self.register_user("otheruser", "pass")
self.other_access_token = self.login("otheruser", "pass")
def default_config(self):
conf = super().default_config()
conf["start_pushers"] = False
return conf
def _create_pusher_and_send_msg(self, localpart):
# Create a user that will get push notifications
user_id = self.register_user(localpart, "pass")
@@ -87,8 +92,8 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
)
self.make_worker_hs(
"synapse.app.generic_worker",
{"worker_name": "pusher1", "pusher_instances": ["pusher1"]},
"synapse.app.pusher",
{"start_pushers": False},
proxied_blacklisted_http_client=http_client_mock,
)
@@ -117,8 +122,9 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
)
self.make_worker_hs(
"synapse.app.generic_worker",
"synapse.app.pusher",
{
"start_pushers": True,
"worker_name": "pusher1",
"pusher_instances": ["pusher1", "pusher2"],
},
@@ -131,8 +137,9 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
)
self.make_worker_hs(
"synapse.app.generic_worker",
"synapse.app.pusher",
{
"start_pushers": True,
"worker_name": "pusher2",
"pusher_instances": ["pusher1", "pusher2"],
},

View File

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import urllib.parse
from io import BytesIO, StringIO
from typing import Any, Dict, Optional, Union
from unittest.mock import Mock
@@ -64,7 +65,9 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
self.assertTrue(ignore_backoff)
self.assertEqual(destination, server_name)
key_id = "%s:%s" % (signing_key.alg, signing_key.version)
self.assertEqual(path, "/_matrix/key/v2/server")
self.assertEqual(
path, "/_matrix/key/v2/server/%s" % (urllib.parse.quote(key_id),)
)
response = {
"server_name": server_name,

View File

@@ -156,7 +156,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
last_event_id: str
def _assert_counts(notif_count: int, highlight_count: int) -> None:
def _assert_counts(noitf_count: int, highlight_count: int) -> None:
counts = self.get_success(
self.store.db_pool.runInteraction(
"get-unread-counts",
@@ -168,22 +168,13 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
self.assertEqual(
counts.main_timeline,
NotifCounts(
notify_count=notif_count,
notify_count=noitf_count,
unread_count=0,
highlight_count=highlight_count,
),
)
self.assertEqual(counts.threads, {})
aggregate_counts = self.get_success(
self.store.db_pool.runInteraction(
"get-aggregate-unread-counts",
self.store._get_unread_counts_by_room_for_user_txn,
user_id,
)
)
self.assertEqual(aggregate_counts[room_id], notif_count)
def _create_event(highlight: bool = False) -> str:
result = self.helper.send_event(
room_id,
@@ -292,7 +283,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
last_event_id: str
def _assert_counts(
notif_count: int,
noitf_count: int,
highlight_count: int,
thread_notif_count: int,
thread_highlight_count: int,
@@ -308,7 +299,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
self.assertEqual(
counts.main_timeline,
NotifCounts(
notify_count=notif_count,
notify_count=noitf_count,
unread_count=0,
highlight_count=highlight_count,
),
@@ -327,17 +318,6 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
else:
self.assertEqual(counts.threads, {})
aggregate_counts = self.get_success(
self.store.db_pool.runInteraction(
"get-aggregate-unread-counts",
self.store._get_unread_counts_by_room_for_user_txn,
user_id,
)
)
self.assertEqual(
aggregate_counts[room_id], notif_count + thread_notif_count
)
def _create_event(
highlight: bool = False, thread_id: Optional[str] = None
) -> str:
@@ -474,7 +454,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
last_event_id: str
def _assert_counts(
notif_count: int,
noitf_count: int,
highlight_count: int,
thread_notif_count: int,
thread_highlight_count: int,
@@ -490,7 +470,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
self.assertEqual(
counts.main_timeline,
NotifCounts(
notify_count=notif_count,
notify_count=noitf_count,
unread_count=0,
highlight_count=highlight_count,
),
@@ -509,17 +489,6 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
else:
self.assertEqual(counts.threads, {})
aggregate_counts = self.get_success(
self.store.db_pool.runInteraction(
"get-aggregate-unread-counts",
self.store._get_unread_counts_by_room_for_user_txn,
user_id,
)
)
self.assertEqual(
aggregate_counts[room_id], notif_count + thread_notif_count
)
def _create_event(
highlight: bool = False, thread_id: Optional[str] = None
) -> str:
@@ -677,7 +646,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
)
return result["event_id"]
def _assert_counts(notif_count: int, thread_notif_count: int) -> None:
def _assert_counts(noitf_count: int, thread_notif_count: int) -> None:
counts = self.get_success(
self.store.db_pool.runInteraction(
"get-unread-counts",
@@ -689,7 +658,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
self.assertEqual(
counts.main_timeline,
NotifCounts(
notify_count=notif_count, unread_count=0, highlight_count=0
notify_count=noitf_count, unread_count=0, highlight_count=0
),
)
if thread_notif_count:

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import traceback
from typing import Generator, List, NoReturn, Optional
from parameterized import parameterized_class
@@ -42,8 +41,8 @@ from tests.unittest import TestCase
class ObservableDeferredTest(TestCase):
def test_succeed(self) -> None:
origin_d: "Deferred[int]" = Deferred()
def test_succeed(self):
origin_d = Deferred()
observable = ObservableDeferred(origin_d)
observer1 = observable.observe()
@@ -53,18 +52,16 @@ class ObservableDeferredTest(TestCase):
self.assertFalse(observer2.called)
# check the first observer is called first
def check_called_first(res: int) -> int:
def check_called_first(res):
self.assertFalse(observer2.called)
return res
observer1.addBoth(check_called_first)
# store the results
results: List[Optional[ObservableDeferred[int]]] = [None, None]
results = [None, None]
def check_val(
res: ObservableDeferred[int], idx: int
) -> ObservableDeferred[int]:
def check_val(res, idx):
results[idx] = res
return res
@@ -75,8 +72,8 @@ class ObservableDeferredTest(TestCase):
self.assertEqual(results[0], 123, "observer 1 callback result")
self.assertEqual(results[1], 123, "observer 2 callback result")
def test_failure(self) -> None:
origin_d: Deferred = Deferred()
def test_failure(self):
origin_d = Deferred()
observable = ObservableDeferred(origin_d, consumeErrors=True)
observer1 = observable.observe()
@@ -86,16 +83,16 @@ class ObservableDeferredTest(TestCase):
self.assertFalse(observer2.called)
# check the first observer is called first
def check_called_first(res: int) -> int:
def check_called_first(res):
self.assertFalse(observer2.called)
return res
observer1.addBoth(check_called_first)
# store the results
results: List[Optional[ObservableDeferred[str]]] = [None, None]
results = [None, None]
def check_val(res: ObservableDeferred[str], idx: int) -> None:
def check_val(res, idx):
results[idx] = res
return None
@@ -106,12 +103,10 @@ class ObservableDeferredTest(TestCase):
raise Exception("gah!")
except Exception as e:
origin_d.errback(e)
assert results[0] is not None
self.assertEqual(str(results[0].value), "gah!", "observer 1 errback result")
assert results[1] is not None
self.assertEqual(str(results[1].value), "gah!", "observer 2 errback result")
def test_cancellation(self) -> None:
def test_cancellation(self):
"""Test that cancelling an observer does not affect other observers."""
origin_d: "Deferred[int]" = Deferred()
observable = ObservableDeferred(origin_d, consumeErrors=True)
@@ -141,38 +136,37 @@ class ObservableDeferredTest(TestCase):
class TimeoutDeferredTest(TestCase):
def setUp(self) -> None:
def setUp(self):
self.clock = Clock()
def test_times_out(self) -> None:
def test_times_out(self):
"""Basic test case that checks that the original deferred is cancelled and that
the timing-out deferred is errbacked
"""
cancelled = False
cancelled = [False]
def canceller(_d: Deferred) -> None:
nonlocal cancelled
cancelled = True
def canceller(_d):
cancelled[0] = True
non_completing_d: Deferred = Deferred(canceller)
non_completing_d = Deferred(canceller)
timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock)
self.assertNoResult(timing_out_d)
self.assertFalse(cancelled, "deferred was cancelled prematurely")
self.assertFalse(cancelled[0], "deferred was cancelled prematurely")
self.clock.pump((1.0,))
self.assertTrue(cancelled, "deferred was not cancelled by timeout")
self.assertTrue(cancelled[0], "deferred was not cancelled by timeout")
self.failureResultOf(timing_out_d, defer.TimeoutError)
def test_times_out_when_canceller_throws(self) -> None:
def test_times_out_when_canceller_throws(self):
"""Test that we have successfully worked around
https://twistedmatrix.com/trac/ticket/9534"""
def canceller(_d: Deferred) -> None:
def canceller(_d):
raise Exception("can't cancel this deferred")
non_completing_d: Deferred = Deferred(canceller)
non_completing_d = Deferred(canceller)
timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock)
self.assertNoResult(timing_out_d)
@@ -181,24 +175,22 @@ class TimeoutDeferredTest(TestCase):
self.failureResultOf(timing_out_d, defer.TimeoutError)
def test_logcontext_is_preserved_on_cancellation(self) -> None:
blocking_was_cancelled = False
def test_logcontext_is_preserved_on_cancellation(self):
blocking_was_cancelled = [False]
@defer.inlineCallbacks
def blocking() -> Generator["Deferred[object]", object, None]:
nonlocal blocking_was_cancelled
non_completing_d: Deferred = Deferred()
def blocking():
non_completing_d = Deferred()
with PreserveLoggingContext():
try:
yield non_completing_d
except CancelledError:
blocking_was_cancelled = True
blocking_was_cancelled[0] = True
raise
with LoggingContext("one") as context_one:
# the errbacks should be run in the test logcontext
def errback(res: Failure, deferred_name: str) -> Failure:
def errback(res, deferred_name):
self.assertIs(
current_context(),
context_one,
@@ -217,7 +209,7 @@ class TimeoutDeferredTest(TestCase):
self.clock.pump((1.0,))
self.assertTrue(
blocking_was_cancelled, "non-completing deferred was not cancelled"
blocking_was_cancelled[0], "non-completing deferred was not cancelled"
)
self.failureResultOf(timing_out_d, defer.TimeoutError)
self.assertIs(current_context(), context_one)
@@ -228,13 +220,13 @@ class _TestException(Exception):
class ConcurrentlyExecuteTest(TestCase):
def test_limits_runners(self) -> None:
def test_limits_runners(self):
"""If we have more tasks than runners, we should get the limit of runners"""
started = 0
waiters = []
processed = []
async def callback(v: int) -> None:
async def callback(v):
# when we first enter, bump the start count
nonlocal started
started += 1
@@ -243,7 +235,7 @@ class ConcurrentlyExecuteTest(TestCase):
processed.append(v)
# wait for the goahead before returning
d2: "Deferred[int]" = Deferred()
d2 = Deferred()
waiters.append(d2)
await d2
@@ -273,16 +265,16 @@ class ConcurrentlyExecuteTest(TestCase):
self.assertCountEqual(processed, [1, 2, 3, 4, 5])
self.successResultOf(d2)
def test_preserves_stacktraces(self) -> None:
def test_preserves_stacktraces(self):
"""Test that the stacktrace from an exception thrown in the callback is preserved"""
d1: "Deferred[int]" = Deferred()
d1 = Deferred()
async def callback(v: int) -> None:
async def callback(v):
# alas, this doesn't work at all without an await here
await d1
raise _TestException("bah")
async def caller() -> None:
async def caller():
try:
await concurrently_execute(callback, [1], 2)
except _TestException as e:
@@ -298,17 +290,17 @@ class ConcurrentlyExecuteTest(TestCase):
d1.callback(0)
self.successResultOf(d2)
def test_preserves_stacktraces_on_preformed_failure(self) -> None:
def test_preserves_stacktraces_on_preformed_failure(self):
"""Test that the stacktrace on a Failure returned by the callback is preserved"""
d1: "Deferred[int]" = Deferred()
d1 = Deferred()
f = Failure(_TestException("bah"))
async def callback(v: int) -> None:
async def callback(v):
# alas, this doesn't work at all without an await here
await d1
await defer.fail(f)
async def caller() -> None:
async def caller():
try:
await concurrently_execute(callback, [1], 2)
except _TestException as e:
@@ -344,7 +336,7 @@ class CancellationWrapperTests(TestCase):
else:
raise ValueError(f"Unsupported wrapper type: {self.wrapper}")
def test_succeed(self) -> None:
def test_succeed(self):
"""Test that the new `Deferred` receives the result."""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = self.wrap_deferred(deferred)
@@ -354,7 +346,7 @@ class CancellationWrapperTests(TestCase):
self.assertTrue(wrapper_deferred.called)
self.assertEqual("success", self.successResultOf(wrapper_deferred))
def test_failure(self) -> None:
def test_failure(self):
"""Test that the new `Deferred` receives the `Failure`."""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = self.wrap_deferred(deferred)
@@ -369,7 +361,7 @@ class CancellationWrapperTests(TestCase):
class StopCancellationTests(TestCase):
"""Tests for the `stop_cancellation` function."""
def test_cancellation(self) -> None:
def test_cancellation(self):
"""Test that cancellation of the new `Deferred` leaves the original running."""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = stop_cancellation(deferred)
@@ -392,7 +384,7 @@ class StopCancellationTests(TestCase):
class DelayCancellationTests(TestCase):
"""Tests for the `delay_cancellation` function."""
def test_deferred_cancellation(self) -> None:
def test_deferred_cancellation(self):
"""Test that cancellation of the new `Deferred` waits for the original."""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = delay_cancellation(deferred)
@@ -413,12 +405,12 @@ class DelayCancellationTests(TestCase):
# Now that the original `Deferred` has failed, we should get a `CancelledError`.
self.failureResultOf(wrapper_deferred, CancelledError)
def test_coroutine_cancellation(self) -> None:
def test_coroutine_cancellation(self):
"""Test that cancellation of the new `Deferred` waits for the original."""
blocking_deferred: "Deferred[None]" = Deferred()
completion_deferred: "Deferred[None]" = Deferred()
async def task() -> NoReturn:
async def task():
await blocking_deferred
completion_deferred.callback(None)
# Raise an exception. Twisted should consume it, otherwise unwanted
@@ -442,7 +434,7 @@ class DelayCancellationTests(TestCase):
# Now that the original coroutine has failed, we should get a `CancelledError`.
self.failureResultOf(wrapper_deferred, CancelledError)
def test_suppresses_second_cancellation(self) -> None:
def test_suppresses_second_cancellation(self):
"""Test that a second cancellation is suppressed.
Identical to `test_cancellation` except the new `Deferred` is cancelled twice.
@@ -467,7 +459,7 @@ class DelayCancellationTests(TestCase):
# Now that the original `Deferred` has failed, we should get a `CancelledError`.
self.failureResultOf(wrapper_deferred, CancelledError)
def test_propagates_cancelled_error(self) -> None:
def test_propagates_cancelled_error(self):
"""Test that a `CancelledError` from the original `Deferred` gets propagated."""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = delay_cancellation(deferred)
@@ -480,14 +472,14 @@ class DelayCancellationTests(TestCase):
self.assertTrue(wrapper_deferred.called)
self.assertIs(cancelled_error, self.failureResultOf(wrapper_deferred).value)
def test_preserves_logcontext(self) -> None:
def test_preserves_logcontext(self):
"""Test that logging contexts are preserved."""
blocking_d: "Deferred[None]" = Deferred()
async def inner() -> None:
async def inner():
await make_deferred_yieldable(blocking_d)
async def outer() -> None:
async def outer():
with LoggingContext("c") as c:
try:
await delay_cancellation(inner())
@@ -511,7 +503,7 @@ class DelayCancellationTests(TestCase):
class AwakenableSleeperTests(TestCase):
"Tests AwakenableSleeper"
def test_sleep(self) -> None:
def test_sleep(self):
reactor, _ = get_clock()
sleeper = AwakenableSleeper(reactor)
@@ -526,7 +518,7 @@ class AwakenableSleeperTests(TestCase):
reactor.advance(0.6)
self.assertTrue(d.called)
def test_explicit_wake(self) -> None:
def test_explicit_wake(self):
reactor, _ = get_clock()
sleeper = AwakenableSleeper(reactor)
@@ -543,7 +535,7 @@ class AwakenableSleeperTests(TestCase):
reactor.advance(0.6)
def test_multiple_sleepers_timeout(self) -> None:
def test_multiple_sleepers_timeout(self):
reactor, _ = get_clock()
sleeper = AwakenableSleeper(reactor)
@@ -563,7 +555,7 @@ class AwakenableSleeperTests(TestCase):
reactor.advance(0.6)
self.assertTrue(d2.called)
def test_multiple_sleepers_wake(self) -> None:
def test_multiple_sleepers_wake(self):
reactor, _ = get_clock()
sleeper = AwakenableSleeper(reactor)

View File

@@ -11,10 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple
from prometheus_client import Gauge
from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable
@@ -30,7 +26,7 @@ from tests.unittest import TestCase
class BatchingQueueTestCase(TestCase):
def setUp(self) -> None:
def setUp(self):
self.clock, hs_clock = get_clock()
# We ensure that we remove any existing metrics for "test_queue".
@@ -41,27 +37,25 @@ class BatchingQueueTestCase(TestCase):
except KeyError:
pass
self._pending_calls: List[Tuple[List[str], defer.Deferred]] = []
self.queue: BatchingQueue[str, str] = BatchingQueue(
"test_queue", hs_clock, self._process_queue
)
self._pending_calls = []
self.queue = BatchingQueue("test_queue", hs_clock, self._process_queue)
async def _process_queue(self, values: List[str]) -> str:
d: "defer.Deferred[str]" = defer.Deferred()
async def _process_queue(self, values):
d = defer.Deferred()
self._pending_calls.append((values, d))
return await make_deferred_yieldable(d)
def _get_sample_with_name(self, metric: Gauge, name: str) -> float:
def _get_sample_with_name(self, metric, name) -> int:
"""For a prometheus metric get the value of the sample that has a
matching "name" label.
"""
for sample in next(iter(metric.collect())).samples:
for sample in metric.collect()[0].samples:
if sample.labels.get("name") == name:
return sample.value
self.fail("Found no matching sample")
def _assert_metrics(self, queued: int, keys: int, in_flight: int) -> None:
def _assert_metrics(self, queued, keys, in_flight):
"""Assert that the metrics are correct"""
sample = self._get_sample_with_name(number_queued, self.queue._name)
@@ -81,7 +75,7 @@ class BatchingQueueTestCase(TestCase):
"number_in_flight",
)
def test_simple(self) -> None:
def test_simple(self):
"""Tests the basic case of calling `add_to_queue` once and having
`_process_queue` return.
"""
@@ -112,7 +106,7 @@ class BatchingQueueTestCase(TestCase):
self._assert_metrics(queued=0, keys=0, in_flight=0)
def test_batching(self) -> None:
def test_batching(self):
"""Test that multiple calls at the same time get batched up into one
call to `_process_queue`.
"""
@@ -140,7 +134,7 @@ class BatchingQueueTestCase(TestCase):
self.assertEqual(self.successResultOf(queue_d2), "bar")
self._assert_metrics(queued=0, keys=0, in_flight=0)
def test_queuing(self) -> None:
def test_queuing(self):
"""Test that we queue up requests while a `_process_queue` is being
called.
"""
@@ -190,7 +184,7 @@ class BatchingQueueTestCase(TestCase):
self.assertEqual(self.successResultOf(queue_d3), "bar2")
self._assert_metrics(queued=0, keys=0, in_flight=0)
def test_different_keys(self) -> None:
def test_different_keys(self):
"""Test that calls to different keys get processed in parallel."""
self.assertFalse(self._pending_calls)

View File

@@ -1,20 +1,5 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from os import PathLike
from typing import Generator, Optional, Union
from typing import Generator, Optional
from unittest.mock import patch
from synapse.util.check_dependencies import (
@@ -27,17 +12,17 @@ from tests.unittest import TestCase
class DummyDistribution(metadata.Distribution):
def __init__(self, version: str):
def __init__(self, version: object):
self._version = version
@property
def version(self) -> str:
def version(self):
return self._version
def locate_file(self, path: Union[str, PathLike]) -> PathLike:
def locate_file(self, path):
raise NotImplementedError()
def read_text(self, filename: str) -> None:
def read_text(self, filename):
raise NotImplementedError()
@@ -45,7 +30,7 @@ old = DummyDistribution("0.1.2")
old_release_candidate = DummyDistribution("0.1.2rc3")
new = DummyDistribution("1.2.3")
new_release_candidate = DummyDistribution("1.2.3rc4")
distribution_with_no_version = DummyDistribution(None) # type: ignore[arg-type]
distribution_with_no_version = DummyDistribution(None)
# could probably use stdlib TestCase --- no need for twisted here
@@ -60,7 +45,7 @@ class TestDependencyChecker(TestCase):
If `distribution = None`, we pretend that the package is not installed.
"""
def mock_distribution(name: str) -> DummyDistribution:
def mock_distribution(name: str):
if distribution is None:
raise metadata.PackageNotFoundError
else:

View File

@@ -19,12 +19,10 @@ from tests import unittest
class DictCacheTestCase(unittest.TestCase):
def setUp(self) -> None:
self.cache: DictionaryCache[str, str, str] = DictionaryCache(
"foobar", max_entries=10
)
def setUp(self):
self.cache = DictionaryCache("foobar", max_entries=10)
def test_simple_cache_hit_full(self) -> None:
def test_simple_cache_hit_full(self):
key = "test_simple_cache_hit_full"
v = self.cache.get(key)
@@ -39,7 +37,7 @@ class DictCacheTestCase(unittest.TestCase):
c = self.cache.get(key)
self.assertEqual(test_value, c.value)
def test_simple_cache_hit_partial(self) -> None:
def test_simple_cache_hit_partial(self):
key = "test_simple_cache_hit_partial"
seq = self.cache.sequence
@@ -49,7 +47,7 @@ class DictCacheTestCase(unittest.TestCase):
c = self.cache.get(key, ["test"])
self.assertEqual(test_value, c.value)
def test_simple_cache_miss_partial(self) -> None:
def test_simple_cache_miss_partial(self):
key = "test_simple_cache_miss_partial"
seq = self.cache.sequence
@@ -59,7 +57,7 @@ class DictCacheTestCase(unittest.TestCase):
c = self.cache.get(key, ["test2"])
self.assertEqual({}, c.value)
def test_simple_cache_hit_miss_partial(self) -> None:
def test_simple_cache_hit_miss_partial(self):
key = "test_simple_cache_hit_miss_partial"
seq = self.cache.sequence
@@ -73,7 +71,7 @@ class DictCacheTestCase(unittest.TestCase):
c = self.cache.get(key, ["test2"])
self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"}, c.value)
def test_multi_insert(self) -> None:
def test_multi_insert(self):
key = "test_simple_cache_hit_miss_partial"
seq = self.cache.sequence
@@ -94,7 +92,7 @@ class DictCacheTestCase(unittest.TestCase):
)
self.assertEqual(c.full, False)
def test_invalidation(self) -> None:
def test_invalidation(self):
"""Test that the partial dict and full dicts get invalidated
separately.
"""
@@ -108,7 +106,7 @@ class DictCacheTestCase(unittest.TestCase):
# entry for "a" warm.
for i in range(20):
self.cache.get(key, ["a"])
self.cache.update(seq, f"key{i}", {"1": "2"})
self.cache.update(seq, f"key{i}", {1: 2})
# We should have evicted the full dict...
r = self.cache.get(key)

View File

@@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, cast
from synapse.util import Clock
from synapse.util.caches.expiringcache import ExpiringCache
from tests.utils import MockClock
@@ -23,21 +21,17 @@ from .. import unittest
class ExpiringCacheTestCase(unittest.HomeserverTestCase):
def test_get_set(self) -> None:
def test_get_set(self):
clock = MockClock()
cache: ExpiringCache[str, str] = ExpiringCache(
"test", cast(Clock, clock), max_len=1
)
cache = ExpiringCache("test", clock, max_len=1)
cache["key"] = "value"
self.assertEqual(cache.get("key"), "value")
self.assertEqual(cache["key"], "value")
def test_eviction(self) -> None:
def test_eviction(self):
clock = MockClock()
cache: ExpiringCache[str, str] = ExpiringCache(
"test", cast(Clock, clock), max_len=2
)
cache = ExpiringCache("test", clock, max_len=2)
cache["key"] = "value"
cache["key2"] = "value2"
@@ -49,11 +43,9 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache.get("key2"), "value2")
self.assertEqual(cache.get("key3"), "value3")
def test_iterable_eviction(self) -> None:
def test_iterable_eviction(self):
clock = MockClock()
cache: ExpiringCache[str, List[int]] = ExpiringCache(
"test", cast(Clock, clock), max_len=5, iterable=True
)
cache = ExpiringCache("test", clock, max_len=5, iterable=True)
cache["key"] = [1]
cache["key2"] = [2, 3]
@@ -69,11 +61,9 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache.get("key3"), [4, 5])
self.assertEqual(cache.get("key4"), [6, 7])
def test_time_eviction(self) -> None:
def test_time_eviction(self):
clock = MockClock()
cache: ExpiringCache[str, int] = ExpiringCache(
"test", cast(Clock, clock), expiry_ms=1000
)
cache = ExpiringCache("test", clock, expiry_ms=1000)
cache["key"] = 1
clock.advance_time(0.5)

View File

@@ -12,28 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
from io import BytesIO
from typing import BinaryIO, Generator, Optional, cast
from io import StringIO
from unittest.mock import NonCallableMock
from zope.interface import implementer
from twisted.internet import defer, reactor
from twisted.internet import defer, reactor as _reactor
from twisted.internet.interfaces import IPullProducer
from synapse.types import ISynapseReactor
from synapse.util.file_consumer import BackgroundFileConsumer
from tests import unittest
reactor = cast(ISynapseReactor, _reactor)
class FileConsumerTests(unittest.TestCase):
@defer.inlineCallbacks
def test_pull_consumer(self) -> Generator["defer.Deferred[object]", object, None]:
string_file = BytesIO()
def test_pull_consumer(self):
string_file = StringIO()
consumer = BackgroundFileConsumer(string_file, reactor=reactor)
try:
@@ -41,57 +35,55 @@ class FileConsumerTests(unittest.TestCase):
yield producer.register_with_consumer(consumer)
yield producer.write_and_wait(b"Foo")
yield producer.write_and_wait("Foo")
self.assertEqual(string_file.getvalue(), b"Foo")
self.assertEqual(string_file.getvalue(), "Foo")
yield producer.write_and_wait(b"Bar")
yield producer.write_and_wait("Bar")
self.assertEqual(string_file.getvalue(), b"FooBar")
self.assertEqual(string_file.getvalue(), "FooBar")
finally:
consumer.unregisterProducer()
yield consumer.wait() # type: ignore[misc]
yield consumer.wait()
self.assertTrue(string_file.closed)
@defer.inlineCallbacks
def test_push_consumer(self) -> Generator["defer.Deferred[object]", object, None]:
string_file = BlockingBytesWrite()
consumer = BackgroundFileConsumer(cast(BinaryIO, string_file), reactor=reactor)
def test_push_consumer(self):
string_file = BlockingStringWrite()
consumer = BackgroundFileConsumer(string_file, reactor=reactor)
try:
producer = NonCallableMock(spec_set=[])
consumer.registerProducer(producer, True)
consumer.write(b"Foo")
yield string_file.wait_for_n_writes(1) # type: ignore[misc]
consumer.write("Foo")
yield string_file.wait_for_n_writes(1)
self.assertEqual(string_file.buffer, b"Foo")
self.assertEqual(string_file.buffer, "Foo")
consumer.write(b"Bar")
yield string_file.wait_for_n_writes(2) # type: ignore[misc]
consumer.write("Bar")
yield string_file.wait_for_n_writes(2)
self.assertEqual(string_file.buffer, b"FooBar")
self.assertEqual(string_file.buffer, "FooBar")
finally:
consumer.unregisterProducer()
yield consumer.wait() # type: ignore[misc]
yield consumer.wait()
self.assertTrue(string_file.closed)
@defer.inlineCallbacks
def test_push_producer_feedback(
self,
) -> Generator["defer.Deferred[object]", object, None]:
string_file = BlockingBytesWrite()
consumer = BackgroundFileConsumer(cast(BinaryIO, string_file), reactor=reactor)
def test_push_producer_feedback(self):
string_file = BlockingStringWrite()
consumer = BackgroundFileConsumer(string_file, reactor=reactor)
try:
producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"])
resume_deferred: defer.Deferred = defer.Deferred()
resume_deferred = defer.Deferred()
producer.resumeProducing.side_effect = lambda: resume_deferred.callback(
None
)
@@ -101,72 +93,65 @@ class FileConsumerTests(unittest.TestCase):
number_writes = 0
with string_file.write_lock:
for _ in range(consumer._PAUSE_ON_QUEUE_SIZE):
consumer.write(b"Foo")
consumer.write("Foo")
number_writes += 1
producer.pauseProducing.assert_called_once()
yield string_file.wait_for_n_writes(number_writes) # type: ignore[misc]
yield string_file.wait_for_n_writes(number_writes)
yield resume_deferred
producer.resumeProducing.assert_called_once()
finally:
consumer.unregisterProducer()
yield consumer.wait() # type: ignore[misc]
yield consumer.wait()
self.assertTrue(string_file.closed)
@implementer(IPullProducer)
class DummyPullProducer:
def __init__(self) -> None:
self.consumer: Optional[BackgroundFileConsumer] = None
self.deferred: "defer.Deferred[object]" = defer.Deferred()
def __init__(self):
self.consumer = None
self.deferred = defer.Deferred()
def resumeProducing(self) -> None:
def resumeProducing(self):
d = self.deferred
self.deferred = defer.Deferred()
d.callback(None)
def stopProducing(self) -> None:
raise RuntimeError("Unexpected call")
def write_and_wait(self, write_bytes: bytes) -> "defer.Deferred[object]":
assert self.consumer is not None
def write_and_wait(self, bytes):
d = self.deferred
self.consumer.write(write_bytes)
self.consumer.write(bytes)
return d
def register_with_consumer(
self, consumer: BackgroundFileConsumer
) -> "defer.Deferred[object]":
def register_with_consumer(self, consumer):
d = self.deferred
self.consumer = consumer
self.consumer.registerProducer(self, False)
return d
class BlockingBytesWrite:
def __init__(self) -> None:
self.buffer = b""
class BlockingStringWrite:
def __init__(self):
self.buffer = ""
self.closed = False
self.write_lock = threading.Lock()
self._notify_write_deferred: Optional[defer.Deferred] = None
self._notify_write_deferred = None
self._number_of_writes = 0
def write(self, write_bytes: bytes) -> None:
def write(self, bytes):
with self.write_lock:
self.buffer += write_bytes
self.buffer += bytes
self._number_of_writes += 1
reactor.callFromThread(self._notify_write)
def close(self) -> None:
def close(self):
self.closed = True
def _notify_write(self) -> None:
def _notify_write(self):
"Called by write to indicate a write happened"
with self.write_lock:
if not self._notify_write_deferred:
@@ -176,9 +161,7 @@ class BlockingBytesWrite:
d.callback(None)
@defer.inlineCallbacks
def wait_for_n_writes(
self, n: int
) -> Generator["defer.Deferred[object]", object, None]:
def wait_for_n_writes(self, n):
"Wait for n writes to have happened"
while True:
with self.write_lock:

View File

@@ -19,7 +19,7 @@ from tests.unittest import TestCase
class ChunkSeqTests(TestCase):
def test_short_seq(self) -> None:
def test_short_seq(self):
parts = chunk_seq("123", 8)
self.assertEqual(
@@ -27,7 +27,7 @@ class ChunkSeqTests(TestCase):
["123"],
)
def test_long_seq(self) -> None:
def test_long_seq(self):
parts = chunk_seq("abcdefghijklmnop", 8)
self.assertEqual(
@@ -35,7 +35,7 @@ class ChunkSeqTests(TestCase):
["abcdefgh", "ijklmnop"],
)
def test_uneven_parts(self) -> None:
def test_uneven_parts(self):
parts = chunk_seq("abcdefghijklmnop", 5)
self.assertEqual(
@@ -43,7 +43,7 @@ class ChunkSeqTests(TestCase):
["abcde", "fghij", "klmno", "p"],
)
def test_empty_input(self) -> None:
def test_empty_input(self):
parts: Iterable[Sequence] = chunk_seq([], 5)
self.assertEqual(
@@ -53,13 +53,13 @@ class ChunkSeqTests(TestCase):
class SortTopologically(TestCase):
def test_empty(self) -> None:
def test_empty(self):
"Test that an empty graph works correctly"
graph: Dict[int, List[int]] = {}
self.assertEqual(list(sorted_topologically([], graph)), [])
def test_handle_empty_graph(self) -> None:
def test_handle_empty_graph(self):
"Test that a graph where a node doesn't have an entry is treated as empty"
graph: Dict[int, List[int]] = {}
@@ -67,7 +67,7 @@ class SortTopologically(TestCase):
# For disconnected nodes the output is simply sorted.
self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
def test_disconnected(self) -> None:
def test_disconnected(self):
"Test that a graph with no edges work"
graph: Dict[int, List[int]] = {1: [], 2: []}
@@ -75,20 +75,20 @@ class SortTopologically(TestCase):
# For disconnected nodes the output is simply sorted.
self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
def test_linear(self) -> None:
def test_linear(self):
"Test that a simple `4 -> 3 -> 2 -> 1` graph works"
graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]}
self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
def test_subset(self) -> None:
def test_subset(self):
"Test that only sorting a subset of the graph works"
graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]}
self.assertEqual(list(sorted_topologically([4, 3], graph)), [3, 4])
def test_fork(self) -> None:
def test_fork(self):
"Test that a forked graph works"
graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [1], 4: [2, 3]}
@@ -96,13 +96,13 @@ class SortTopologically(TestCase):
# always get the same one.
self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
def test_duplicates(self) -> None:
def test_duplicates(self):
"Test that a graph with duplicate edges work"
graph: Dict[int, List[int]] = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]}
self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
def test_multiple_paths(self) -> None:
def test_multiple_paths(self):
"Test that a graph with multiple paths between two nodes work"
graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]}

View File

@@ -1,21 +1,5 @@
# Copyright 2014-2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Generator, cast
import twisted.python.failure
from twisted.internet import defer, reactor as _reactor
from twisted.internet import defer, reactor
from synapse.logging.context import (
SENTINEL_CONTEXT,
@@ -26,30 +10,25 @@ from synapse.logging.context import (
nested_logging_context,
run_in_background,
)
from synapse.types import ISynapseReactor
from synapse.util import Clock
from .. import unittest
reactor = cast(ISynapseReactor, _reactor)
class LoggingContextTestCase(unittest.TestCase):
def _check_test_key(self, value: str) -> None:
context = current_context()
assert isinstance(context, LoggingContext)
self.assertEqual(context.name, value)
def _check_test_key(self, value):
self.assertEqual(current_context().name, value)
def test_with_context(self) -> None:
def test_with_context(self):
with LoggingContext("test"):
self._check_test_key("test")
@defer.inlineCallbacks
def test_sleep(self) -> Generator["defer.Deferred[object]", object, None]:
def test_sleep(self):
clock = Clock(reactor)
@defer.inlineCallbacks
def competing_callback() -> Generator["defer.Deferred[object]", object, None]:
def competing_callback():
with LoggingContext("competing"):
yield clock.sleep(0)
self._check_test_key("competing")
@@ -60,18 +39,17 @@ class LoggingContextTestCase(unittest.TestCase):
yield clock.sleep(0)
self._check_test_key("one")
def _test_run_in_background(self, function: Callable[[], object]) -> defer.Deferred:
def _test_run_in_background(self, function):
sentinel_context = current_context()
callback_completed = False
callback_completed = [False]
with LoggingContext("one"):
# fire off function, but don't wait on it.
d2 = run_in_background(function)
def cb(res: object) -> object:
nonlocal callback_completed
callback_completed = True
def cb(res):
callback_completed[0] = True
return res
d2.addCallback(cb)
@@ -82,8 +60,8 @@ class LoggingContextTestCase(unittest.TestCase):
# the logcontext is left in a sane state.
d2 = defer.Deferred()
def check_logcontext() -> None:
if not callback_completed:
def check_logcontext():
if not callback_completed[0]:
reactor.callLater(0.01, check_logcontext)
return
@@ -100,31 +78,31 @@ class LoggingContextTestCase(unittest.TestCase):
# test is done once d2 finishes
return d2
def test_run_in_background_with_blocking_fn(self) -> defer.Deferred:
def test_run_in_background_with_blocking_fn(self):
@defer.inlineCallbacks
def blocking_function() -> Generator["defer.Deferred[object]", object, None]:
def blocking_function():
yield Clock(reactor).sleep(0)
return self._test_run_in_background(blocking_function)
def test_run_in_background_with_non_blocking_fn(self) -> defer.Deferred:
def test_run_in_background_with_non_blocking_fn(self):
@defer.inlineCallbacks
def nonblocking_function() -> Generator["defer.Deferred[object]", object, None]:
def nonblocking_function():
with PreserveLoggingContext():
yield defer.succeed(None)
return self._test_run_in_background(nonblocking_function)
def test_run_in_background_with_chained_deferred(self) -> defer.Deferred:
def test_run_in_background_with_chained_deferred(self):
# a function which returns a deferred which looks like it has been
# called, but is actually paused
def testfunc() -> defer.Deferred:
def testfunc():
return make_deferred_yieldable(_chained_deferred_function())
return self._test_run_in_background(testfunc)
def test_run_in_background_with_coroutine(self) -> defer.Deferred:
async def testfunc() -> None:
def test_run_in_background_with_coroutine(self):
async def testfunc():
self._check_test_key("one")
d = Clock(reactor).sleep(0)
self.assertIs(current_context(), SENTINEL_CONTEXT)
@@ -133,20 +111,18 @@ class LoggingContextTestCase(unittest.TestCase):
return self._test_run_in_background(testfunc)
def test_run_in_background_with_nonblocking_coroutine(self) -> defer.Deferred:
async def testfunc() -> None:
def test_run_in_background_with_nonblocking_coroutine(self):
async def testfunc():
self._check_test_key("one")
return self._test_run_in_background(testfunc)
@defer.inlineCallbacks
def test_make_deferred_yieldable(
self,
) -> Generator["defer.Deferred[object]", object, None]:
def test_make_deferred_yieldable(self):
# a function which returns an incomplete deferred, but doesn't follow
# the synapse rules.
def blocking_function() -> defer.Deferred:
d: defer.Deferred = defer.Deferred()
def blocking_function():
d = defer.Deferred()
reactor.callLater(0, d.callback, None)
return d
@@ -163,9 +139,7 @@ class LoggingContextTestCase(unittest.TestCase):
self._check_test_key("one")
@defer.inlineCallbacks
def test_make_deferred_yieldable_with_chained_deferreds(
self,
) -> Generator["defer.Deferred[object]", object, None]:
def test_make_deferred_yieldable_with_chained_deferreds(self):
sentinel_context = current_context()
with LoggingContext("one"):
@@ -178,7 +152,7 @@ class LoggingContextTestCase(unittest.TestCase):
# now it should be restored
self._check_test_key("one")
def test_nested_logging_context(self) -> None:
def test_nested_logging_context(self):
with LoggingContext("foo"):
nested_context = nested_logging_context(suffix="bar")
self.assertEqual(nested_context.name, "foo-bar")
@@ -187,11 +161,11 @@ class LoggingContextTestCase(unittest.TestCase):
# a function which returns a deferred which has been "called", but
# which had a function which returned another incomplete deferred on
# its callback list, so won't yet call any other new callbacks.
def _chained_deferred_function() -> defer.Deferred:
def _chained_deferred_function():
d = defer.succeed(None)
def cb(res: object) -> defer.Deferred:
d2: defer.Deferred = defer.Deferred()
def cb(res):
d2 = defer.Deferred()
reactor.callLater(0, d2.callback, res)
return d2

View File

@@ -23,7 +23,7 @@ class TestException(Exception):
class LogFormatterTestCase(unittest.TestCase):
def test_formatter(self) -> None:
def test_formatter(self):
formatter = LogFormatter()
try:

View File

@@ -13,11 +13,10 @@
# limitations under the License.
from typing import List, Tuple
from typing import List
from unittest.mock import Mock, patch
from synapse.metrics.jemalloc import JemallocStats
from synapse.types import JsonDict
from synapse.util.caches.lrucache import LruCache, setup_expire_lru_cache_entries
from synapse.util.caches.treecache import TreeCache
@@ -26,14 +25,14 @@ from tests.unittest import override_config
class LruCacheTestCase(unittest.HomeserverTestCase):
def test_get_set(self) -> None:
cache: LruCache[str, str] = LruCache(1)
def test_get_set(self):
cache = LruCache(1)
cache["key"] = "value"
self.assertEqual(cache.get("key"), "value")
self.assertEqual(cache["key"], "value")
def test_eviction(self) -> None:
cache: LruCache[int, int] = LruCache(2)
def test_eviction(self):
cache = LruCache(2)
cache[1] = 1
cache[2] = 2
@@ -46,8 +45,8 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache.get(2), 2)
self.assertEqual(cache.get(3), 3)
def test_setdefault(self) -> None:
cache: LruCache[str, int] = LruCache(1)
def test_setdefault(self):
cache = LruCache(1)
self.assertEqual(cache.setdefault("key", 1), 1)
self.assertEqual(cache.get("key"), 1)
self.assertEqual(cache.setdefault("key", 2), 1)
@@ -55,15 +54,14 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
cache["key"] = 2 # Make sure overriding works.
self.assertEqual(cache.get("key"), 2)
def test_pop(self) -> None:
cache: LruCache[str, int] = LruCache(1)
def test_pop(self):
cache = LruCache(1)
cache["key"] = 1
self.assertEqual(cache.pop("key"), 1)
self.assertEqual(cache.pop("key"), None)
def test_del_multi(self) -> None:
# The type here isn't quite correct as they don't handle TreeCache well.
cache: LruCache[Tuple[str, str], str] = LruCache(4, cache_type=TreeCache)
def test_del_multi(self):
cache = LruCache(4, cache_type=TreeCache)
cache[("animal", "cat")] = "mew"
cache[("animal", "dog")] = "woof"
cache[("vehicles", "car")] = "vroom"
@@ -73,7 +71,7 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache.get(("animal", "cat")), "mew")
self.assertEqual(cache.get(("vehicles", "car")), "vroom")
cache.del_multi(("animal",)) # type: ignore[arg-type]
cache.del_multi(("animal",))
self.assertEqual(len(cache), 2)
self.assertEqual(cache.get(("animal", "cat")), None)
self.assertEqual(cache.get(("animal", "dog")), None)
@@ -81,22 +79,22 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache.get(("vehicles", "train")), "chuff")
# Man from del_multi say "Yes".
def test_clear(self) -> None:
cache: LruCache[str, int] = LruCache(1)
def test_clear(self):
cache = LruCache(1)
cache["key"] = 1
cache.clear()
self.assertEqual(len(cache), 0)
@override_config({"caches": {"per_cache_factors": {"mycache": 10}}})
def test_special_size(self) -> None:
cache: LruCache = LruCache(10, "mycache")
def test_special_size(self):
cache = LruCache(10, "mycache")
self.assertEqual(cache.max_size, 100)
class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
def test_get(self) -> None:
def test_get(self):
m = Mock()
cache: LruCache[str, str] = LruCache(1)
cache = LruCache(1)
cache.set("key", "value")
self.assertFalse(m.called)
@@ -113,9 +111,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
cache.set("key", "value")
self.assertEqual(m.call_count, 1)
def test_multi_get(self) -> None:
def test_multi_get(self):
m = Mock()
cache: LruCache[str, str] = LruCache(1)
cache = LruCache(1)
cache.set("key", "value")
self.assertFalse(m.called)
@@ -132,9 +130,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
cache.set("key", "value")
self.assertEqual(m.call_count, 1)
def test_set(self) -> None:
def test_set(self):
m = Mock()
cache: LruCache[str, str] = LruCache(1)
cache = LruCache(1)
cache.set("key", "value", callbacks=[m])
self.assertFalse(m.called)
@@ -148,9 +146,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
cache.set("key", "value")
self.assertEqual(m.call_count, 1)
def test_pop(self) -> None:
def test_pop(self):
m = Mock()
cache: LruCache[str, str] = LruCache(1)
cache = LruCache(1)
cache.set("key", "value", callbacks=[m])
self.assertFalse(m.called)
@@ -164,13 +162,12 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
cache.pop("key")
self.assertEqual(m.call_count, 1)
def test_del_multi(self) -> None:
def test_del_multi(self):
m1 = Mock()
m2 = Mock()
m3 = Mock()
m4 = Mock()
# The type here isn't quite correct as they don't handle TreeCache well.
cache: LruCache[Tuple[str, str], str] = LruCache(4, cache_type=TreeCache)
cache = LruCache(4, cache_type=TreeCache)
cache.set(("a", "1"), "value", callbacks=[m1])
cache.set(("a", "2"), "value", callbacks=[m2])
@@ -182,17 +179,17 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
self.assertEqual(m3.call_count, 0)
self.assertEqual(m4.call_count, 0)
cache.del_multi(("a",)) # type: ignore[arg-type]
cache.del_multi(("a",))
self.assertEqual(m1.call_count, 1)
self.assertEqual(m2.call_count, 1)
self.assertEqual(m3.call_count, 0)
self.assertEqual(m4.call_count, 0)
def test_clear(self) -> None:
def test_clear(self):
m1 = Mock()
m2 = Mock()
cache: LruCache[str, str] = LruCache(5)
cache = LruCache(5)
cache.set("key1", "value", callbacks=[m1])
cache.set("key2", "value", callbacks=[m2])
@@ -205,11 +202,11 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
self.assertEqual(m1.call_count, 1)
self.assertEqual(m2.call_count, 1)
def test_eviction(self) -> None:
def test_eviction(self):
m1 = Mock(name="m1")
m2 = Mock(name="m2")
m3 = Mock(name="m3")
cache: LruCache[str, str] = LruCache(2)
cache = LruCache(2)
cache.set("key1", "value", callbacks=[m1])
cache.set("key2", "value", callbacks=[m2])
@@ -244,8 +241,8 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
class LruCacheSizedTestCase(unittest.HomeserverTestCase):
def test_evict(self) -> None:
cache: LruCache[str, List[int]] = LruCache(5, size_callback=len)
def test_evict(self):
cache = LruCache(5, size_callback=len)
cache["key1"] = [0]
cache["key2"] = [1, 2]
cache["key3"] = [3]
@@ -272,7 +269,6 @@ class LruCacheSizedTestCase(unittest.HomeserverTestCase):
cache["key1"] = []
self.assertEqual(len(cache), 0)
assert isinstance(cache.cache, dict)
cache.cache["key1"].drop_from_cache()
self.assertIsNone(
cache.pop("key1"), "Cache entry should have been evicted but wasn't"
@@ -282,17 +278,17 @@ class LruCacheSizedTestCase(unittest.HomeserverTestCase):
class TimeEvictionTestCase(unittest.HomeserverTestCase):
"""Test that time based eviction works correctly."""
def default_config(self) -> JsonDict:
def default_config(self):
config = super().default_config()
config.setdefault("caches", {})["expiry_time"] = "30m"
return config
def test_evict(self) -> None:
def test_evict(self):
setup_expire_lru_cache_entries(self.hs)
cache: LruCache[str, int] = LruCache(5, clock=self.hs.get_clock())
cache = LruCache(5, clock=self.hs.get_clock())
# Check that we evict entries we haven't accessed for 30 minutes.
cache["key1"] = 1
@@ -336,7 +332,7 @@ class MemoryEvictionTestCase(unittest.HomeserverTestCase):
}
)
@patch("synapse.util.caches.lrucache.get_jemalloc_stats")
def test_evict_memory(self, jemalloc_interface: Mock) -> None:
def test_evict_memory(self, jemalloc_interface) -> None:
mock_jemalloc_class = Mock(spec=JemallocStats)
jemalloc_interface.return_value = mock_jemalloc_class
@@ -344,7 +340,7 @@ class MemoryEvictionTestCase(unittest.HomeserverTestCase):
mock_jemalloc_class.get_stat.return_value = 924288000
setup_expire_lru_cache_entries(self.hs)
cache: LruCache[str, int] = LruCache(4, clock=self.hs.get_clock())
cache = LruCache(4, clock=self.hs.get_clock())
cache["key1"] = 1
cache["key2"] = 2

View File

@@ -21,14 +21,14 @@ from tests.unittest import TestCase
class MacaroonGeneratorTestCase(TestCase):
def setUp(self) -> None:
def setUp(self):
self.reactor, hs_clock = get_clock()
self.macaroon_generator = MacaroonGenerator(hs_clock, "tesths", b"verysecret")
self.other_macaroon_generator = MacaroonGenerator(
hs_clock, "tesths", b"anothersecretkey"
)
def test_guest_access_token(self) -> None:
def test_guest_access_token(self):
"""Test the generation and verification of guest access tokens"""
token = self.macaroon_generator.generate_guest_access_token("@user:tesths")
user_id = self.macaroon_generator.verify_guest_token(token)
@@ -47,7 +47,7 @@ class MacaroonGeneratorTestCase(TestCase):
with self.assertRaises(MacaroonVerificationFailedException):
self.macaroon_generator.verify_guest_token(token)
def test_delete_pusher_token(self) -> None:
def test_delete_pusher_token(self):
"""Test the generation and verification of delete_pusher tokens"""
token = self.macaroon_generator.generate_delete_pusher_token(
"@user:tesths", "m.mail", "john@example.com"
@@ -84,7 +84,7 @@ class MacaroonGeneratorTestCase(TestCase):
)
self.assertEqual(user_id, "@user:tesths")
def test_oidc_session_token(self) -> None:
def test_oidc_session_token(self):
"""Test the generation and verification of OIDC session cookies"""
state = "arandomstate"
session_data = OidcSessionData(

View File

@@ -13,19 +13,16 @@
# limitations under the License.
from typing import Optional
from twisted.internet.defer import Deferred
from synapse.config.homeserver import HomeServerConfig
from synapse.config.ratelimiting import FederationRatelimitSettings
from synapse.util.ratelimitutils import FederationRateLimiter
from tests.server import ThreadedMemoryReactorClock, get_clock
from tests.server import get_clock
from tests.unittest import TestCase
from tests.utils import default_config
class FederationRateLimiterTestCase(TestCase):
def test_ratelimit(self) -> None:
def test_ratelimit(self):
"""A simple test with the default values"""
reactor, clock = get_clock()
rc_config = build_rc_config()
@@ -35,7 +32,7 @@ class FederationRateLimiterTestCase(TestCase):
# shouldn't block
self.successResultOf(d1)
def test_concurrent_limit(self) -> None:
def test_concurrent_limit(self):
"""Test what happens when we hit the concurrent limit"""
reactor, clock = get_clock()
rc_config = build_rc_config({"rc_federation": {"concurrent": 2}})
@@ -59,7 +56,7 @@ class FederationRateLimiterTestCase(TestCase):
cm2.__exit__(None, None, None)
self.successResultOf(d3)
def test_sleep_limit(self) -> None:
def test_sleep_limit(self):
"""Test what happens when we hit the sleep limit"""
reactor, clock = get_clock()
rc_config = build_rc_config(
@@ -82,7 +79,7 @@ class FederationRateLimiterTestCase(TestCase):
self.assertAlmostEqual(sleep_time, 500, places=3)
def _await_resolution(reactor: ThreadedMemoryReactorClock, d: Deferred) -> float:
def _await_resolution(reactor, d):
"""advance the clock until the deferred completes.
Returns the number of milliseconds it took to complete.
@@ -93,7 +90,7 @@ def _await_resolution(reactor: ThreadedMemoryReactorClock, d: Deferred) -> float
return (reactor.seconds() - start_time) * 1000
def build_rc_config(settings: Optional[dict] = None) -> FederationRatelimitSettings:
def build_rc_config(settings: Optional[dict] = None):
config_dict = default_config("test")
config_dict.update(settings or {})
config = HomeServerConfig()

View File

@@ -22,7 +22,7 @@ from tests.unittest import HomeserverTestCase
class RetryLimiterTestCase(HomeserverTestCase):
def test_new_destination(self) -> None:
def test_new_destination(self):
"""A happy-path case with a new destination and a successful operation"""
store = self.hs.get_datastores().main
limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
@@ -36,7 +36,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
self.assertIsNone(new_timings)
def test_limiter(self) -> None:
def test_limiter(self):
"""General test case which walks through the process of a failing request"""
store = self.hs.get_datastores().main

View File

@@ -49,7 +49,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
acquired_d: "Deferred[None]" = Deferred()
unblock_d: "Deferred[None]" = Deferred()
async def reader_or_writer() -> str:
async def reader_or_writer():
async with read_or_write(key):
acquired_d.callback(None)
await unblock_d
@@ -134,7 +134,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
d.called, msg="deferred %d was unexpectedly resolved" % (i + n)
)
def test_rwlock(self) -> None:
def test_rwlock(self):
rwlock = ReadWriteLock()
key = "key"
@@ -197,7 +197,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
_, acquired_d = self._start_nonblocking_reader(rwlock, key, "last reader")
self.assertTrue(acquired_d.called)
def test_lock_handoff_to_nonblocking_writer(self) -> None:
def test_lock_handoff_to_nonblocking_writer(self):
"""Test a writer handing the lock to another writer that completes instantly."""
rwlock = ReadWriteLock()
key = "key"
@@ -216,7 +216,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
d3, _ = self._start_nonblocking_writer(rwlock, key, "write 3 completed")
self.assertTrue(d3.called)
def test_cancellation_while_holding_read_lock(self) -> None:
def test_cancellation_while_holding_read_lock(self):
"""Test cancellation while holding a read lock.
A waiting writer should be given the lock when the reader holding the lock is
@@ -242,7 +242,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
)
self.assertEqual("write completed", self.successResultOf(writer_d))
def test_cancellation_while_holding_write_lock(self) -> None:
def test_cancellation_while_holding_write_lock(self):
"""Test cancellation while holding a write lock.
A waiting reader should be given the lock when the writer holding the lock is
@@ -268,7 +268,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
)
self.assertEqual("read completed", self.successResultOf(reader_d))
def test_cancellation_while_waiting_for_read_lock(self) -> None:
def test_cancellation_while_waiting_for_read_lock(self):
"""Test cancellation while waiting for a read lock.
Tests that cancelling a waiting reader:
@@ -319,7 +319,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
)
self.assertEqual("write 2 completed", self.successResultOf(writer2_d))
def test_cancellation_while_waiting_for_write_lock(self) -> None:
def test_cancellation_while_waiting_for_write_lock(self):
"""Test cancellation while waiting for a write lock.
Tests that cancelling a waiting writer:

View File

@@ -8,7 +8,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
Tests for StreamChangeCache.
"""
def test_prefilled_cache(self) -> None:
def test_prefilled_cache(self):
"""
Providing a prefilled cache to StreamChangeCache will result in a cache
with the prefilled-cache entered in.
@@ -16,7 +16,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
cache = StreamChangeCache("#test", 1, prefilled_cache={"user@foo.com": 2})
self.assertTrue(cache.has_entity_changed("user@foo.com", 1))
def test_has_entity_changed(self) -> None:
def test_has_entity_changed(self):
"""
StreamChangeCache.entity_has_changed will mark entities as changed, and
has_entity_changed will observe the changed entities.
@@ -52,7 +52,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
self.assertTrue(cache.has_entity_changed("user@foo.com", 0))
self.assertTrue(cache.has_entity_changed("not@here.website", 0))
def test_entity_has_changed_pops_off_start(self) -> None:
def test_entity_has_changed_pops_off_start(self):
"""
StreamChangeCache.entity_has_changed will respect the max size and
purge the oldest items upon reaching that max size.
@@ -86,7 +86,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
)
self.assertIsNone(cache.get_all_entities_changed(1))
def test_get_all_entities_changed(self) -> None:
def test_get_all_entities_changed(self):
"""
StreamChangeCache.get_all_entities_changed will return all changed
entities since the given position. If the position is before the start
@@ -142,7 +142,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
r = cache.get_all_entities_changed(3)
self.assertTrue(r == ok1 or r == ok2)
def test_has_any_entity_changed(self) -> None:
def test_has_any_entity_changed(self):
"""
StreamChangeCache.has_any_entity_changed will return True if any
entities have been changed since the provided stream position, and
@@ -168,7 +168,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
self.assertFalse(cache.has_any_entity_changed(2))
self.assertFalse(cache.has_any_entity_changed(3))
def test_get_entities_changed(self) -> None:
def test_get_entities_changed(self):
"""
StreamChangeCache.get_entities_changed will return the entities in the
given list that have changed since the provided stream ID. If the
@@ -228,7 +228,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
{"bar@baz.net"},
)
def test_max_pos(self) -> None:
def test_max_pos(self):
"""
StreamChangeCache.get_max_pos_of_last_change will return the most
recent point where the entity could have changed. If the entity is not

View File

@@ -19,7 +19,7 @@ from .. import unittest
class StringUtilsTestCase(unittest.TestCase):
def test_client_secret_regex(self) -> None:
def test_client_secret_regex(self):
"""Ensure that client_secret does not contain illegal characters"""
good = [
"abcde12345",
@@ -46,7 +46,7 @@ class StringUtilsTestCase(unittest.TestCase):
with self.assertRaises(SynapseError):
assert_valid_client_secret(client_secret)
def test_base62_encode(self) -> None:
def test_base62_encode(self):
self.assertEqual("0", base62_encode(0))
self.assertEqual("10", base62_encode(62))
self.assertEqual("1c", base62_encode(100))

View File

@@ -18,31 +18,31 @@ from tests.unittest import HomeserverTestCase
class CanonicaliseEmailTests(HomeserverTestCase):
def test_no_at(self) -> None:
def test_no_at(self):
with self.assertRaises(ValueError):
canonicalise_email("address-without-at.bar")
def test_two_at(self) -> None:
def test_two_at(self):
with self.assertRaises(ValueError):
canonicalise_email("foo@foo@test.bar")
def test_bad_format(self) -> None:
def test_bad_format(self):
with self.assertRaises(ValueError):
canonicalise_email("user@bad.example.net@good.example.com")
def test_valid_format(self) -> None:
def test_valid_format(self):
self.assertEqual(canonicalise_email("foo@test.bar"), "foo@test.bar")
def test_domain_to_lower(self) -> None:
def test_domain_to_lower(self):
self.assertEqual(canonicalise_email("foo@TEST.BAR"), "foo@test.bar")
def test_domain_with_umlaut(self) -> None:
def test_domain_with_umlaut(self):
self.assertEqual(canonicalise_email("foo@Öumlaut.com"), "foo@öumlaut.com")
def test_address_casefold(self) -> None:
def test_address_casefold(self):
self.assertEqual(
canonicalise_email("Strauß@Example.com"), "strauss@example.com"
)
def test_address_trim(self) -> None:
def test_address_trim(self):
self.assertEqual(canonicalise_email(" foo@test.bar "), "foo@test.bar")

View File

@@ -19,7 +19,7 @@ from .. import unittest
class TreeCacheTestCase(unittest.TestCase):
def test_get_set_onelevel(self) -> None:
def test_get_set_onelevel(self):
cache = TreeCache()
cache[("a",)] = "A"
cache[("b",)] = "B"
@@ -27,7 +27,7 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEqual(cache.get(("b",)), "B")
self.assertEqual(len(cache), 2)
def test_pop_onelevel(self) -> None:
def test_pop_onelevel(self):
cache = TreeCache()
cache[("a",)] = "A"
cache[("b",)] = "B"
@@ -36,7 +36,7 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEqual(cache.get(("b",)), "B")
self.assertEqual(len(cache), 1)
def test_get_set_twolevel(self) -> None:
def test_get_set_twolevel(self):
cache = TreeCache()
cache[("a", "a")] = "AA"
cache[("a", "b")] = "AB"
@@ -46,7 +46,7 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEqual(cache.get(("b", "a")), "BA")
self.assertEqual(len(cache), 3)
def test_pop_twolevel(self) -> None:
def test_pop_twolevel(self):
cache = TreeCache()
cache[("a", "a")] = "AA"
cache[("a", "b")] = "AB"
@@ -58,7 +58,7 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEqual(cache.pop(("b", "a")), None)
self.assertEqual(len(cache), 1)
def test_pop_mixedlevel(self) -> None:
def test_pop_mixedlevel(self):
cache = TreeCache()
cache[("a", "a")] = "AA"
cache[("a", "b")] = "AB"
@@ -72,14 +72,14 @@ class TreeCacheTestCase(unittest.TestCase):
self.assertEqual({"AA", "AB"}, set(iterate_tree_cache_entry(popped)))
def test_clear(self) -> None:
def test_clear(self):
cache = TreeCache()
cache[("a",)] = "A"
cache[("b",)] = "B"
cache.clear()
self.assertEqual(len(cache), 0)
def test_contains(self) -> None:
def test_contains(self):
cache = TreeCache()
cache[("a",)] = "A"
self.assertTrue(("a",) in cache)

View File

@@ -18,8 +18,8 @@ from .. import unittest
class WheelTimerTestCase(unittest.TestCase):
def test_single_insert_fetch(self) -> None:
wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
def test_single_insert_fetch(self):
wheel = WheelTimer(bucket_size=5)
obj = object()
wheel.insert(100, obj, 150)
@@ -32,8 +32,8 @@ class WheelTimerTestCase(unittest.TestCase):
self.assertListEqual(wheel.fetch(156), [obj])
self.assertListEqual(wheel.fetch(170), [])
def test_multi_insert(self) -> None:
wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
def test_multi_insert(self):
wheel = WheelTimer(bucket_size=5)
obj1 = object()
obj2 = object()
@@ -50,15 +50,15 @@ class WheelTimerTestCase(unittest.TestCase):
self.assertListEqual(wheel.fetch(200), [obj3])
self.assertListEqual(wheel.fetch(210), [])
def test_insert_past(self) -> None:
wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
def test_insert_past(self):
wheel = WheelTimer(bucket_size=5)
obj = object()
wheel.insert(100, obj, 50)
self.assertListEqual(wheel.fetch(120), [obj])
def test_insert_past_multi(self) -> None:
wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
def test_insert_past_multi(self):
wheel = WheelTimer(bucket_size=5)
obj1 = object()
obj2 = object()

View File

@@ -125,8 +125,7 @@ def default_config(
"""
config_dict = {
"server_name": name,
# Setting this to an empty list turns off federation sending.
"federation_sender_instances": [],
"send_federation": False,
"media_store_path": "media",
# the test signing key is just an arbitrary ed25519 key to keep the config
# parser happy
@@ -184,9 +183,8 @@ def default_config(
# rooms will fail.
"default_room_version": DEFAULT_ROOM_VERSION,
# disable user directory updates, because they get done in the
# background, which upsets the test runner. Setting this to an
# (obviously) fake worker name disables updating the user directory.
"update_user_directory_from_worker": "does_not_exist_worker_name",
# background, which upsets the test runner.
"update_user_directory": False,
"caches": {"global_factor": 1, "sync_response_cache_duration": 0},
"listeners": [{"port": 0, "type": "http"}],
}