Compare commits

..

7 Commits

Author SHA1 Message Date
David Robertson
0953cad3e4 docstringggggg 2022-05-18 10:19:30 +01:00
David Robertson
ad6a6675bf go away flake8 2022-05-17 14:14:40 +01:00
David Robertson
21d1347f2c Changelog 2022-05-17 14:00:56 +01:00
David Robertson
a9fe3350f8 Drive-by typo fix I spotted while debugging 2022-05-17 13:55:58 +01:00
David Robertson
a1adede444 discard strings with nulls before user dir update 2022-05-17 13:55:58 +01:00
David Robertson
79f1cef5e4 Move non_null_str_or_none to s.utils.stringutils 2022-05-17 13:55:58 +01:00
David Robertson
8c977edec8 Reproduce #12755 2022-05-17 13:55:44 +01:00
68 changed files with 402 additions and 1917 deletions

View File

@@ -1,14 +1,3 @@
Synapse 1.59.1 (2022-05-18)
===========================
This release fixes a long-standing issue which could prevent Synapse's user directory for updating properly.
Bugfixes
----------------
- Fix a long-standing bug where the user directory background process would fail to make forward progress if a user included a null codepoint in their display name or avatar. Contributed by Nick @ Beeper. ([\#12762](https://github.com/matrix-org/synapse/issues/12762))
Synapse 1.59.0 (2022-05-17)
===========================

View File

@@ -1 +0,0 @@
Preparation for faster-room-join work: return subsets of room state which we already have, immediately.

View File

@@ -1 +0,0 @@
Add support for [MSC3787: Allowing knocks to restricted rooms](https://github.com/matrix-org/matrix-spec-proposals/pull/3787).

View File

@@ -1 +0,0 @@
Remove code which updates unused database column `application_services_state.last_txn`.

View File

@@ -1 +0,0 @@
Add some type hints to datastore.

View File

@@ -1 +0,0 @@
Fix push to dismiss notifications when read on another client. Contributed by @SpiritCroc @ Beeper.

View File

@@ -1 +0,0 @@
Downgrade some OIDC errors to warnings in the logs, to reduce the noise of Sentry reports.

View File

@@ -1 +0,0 @@
Link to the configuration manual from the welcome page of the documentation.

View File

@@ -1 +0,0 @@
Add some type hints to datastore.

View File

@@ -1 +0,0 @@
Add information regarding the `rc_invites` ratelimiting option to the configuration docs.

View File

@@ -1 +0,0 @@
Add documentation for cancellation of request processing.

View File

@@ -1 +0,0 @@
Recommend using docker to run tests against postgres.

View File

@@ -1 +0,0 @@
Tweak the mypy plugin so that `@cached` can accept `on_invalidate=None`.

View File

@@ -1 +0,0 @@
Delete events from the `federation_inbound_events_staging` table when a room is purged through the admin API.

View File

@@ -1 +0,0 @@
Move methods that call `add_push_rule` to the `PushRuleStore` class.

View File

@@ -1 +0,0 @@
Make handling of federation Authorization header (more) compliant with RFC7230.

View File

@@ -1 +0,0 @@
Refactor `resolve_state_groups_for_events` to not pull out full state when no state resolution happens.

View File

@@ -1 +0,0 @@
Give a meaningful error message when a client tries to create a room with an invalid alias localpart.

View File

@@ -1 +0,0 @@
Do not keep going if there are 5 back-to-back background update failures.

View File

@@ -1 +0,0 @@
Fix federation when using the demo scripts.

View File

@@ -1 +0,0 @@
Fix invalid YAML syntax in the example documentation for the `url_preview_accept_language` config option.

6
debian/changelog vendored
View File

@@ -1,9 +1,3 @@
matrix-synapse-py3 (1.59.1) stable; urgency=medium
* New Synapse release 1.59.1.
-- Synapse Packaging team <packages@matrix.org> Wed, 18 May 2022 11:41:46 +0100
matrix-synapse-py3 (1.59.0) stable; urgency=medium
* New Synapse release 1.59.0.

View File

@@ -12,7 +12,6 @@ export PYTHONPATH
echo "$PYTHONPATH"
# Create servers which listen on HTTP at 808x and HTTPS at 848x.
for port in 8080 8081 8082; do
echo "Starting server on port $port... "
@@ -20,12 +19,10 @@ for port in 8080 8081 8082; do
mkdir -p demo/$port
pushd demo/$port || exit
# Generate the configuration for the homeserver at localhost:848x, note that
# the homeserver name needs to match the HTTPS listening port for federation
# to properly work..
# Generate the configuration for the homeserver at localhost:848x.
python3 -m synapse.app.homeserver \
--generate-config \
--server-name "localhost:$https_port" \
--server-name "localhost:$port" \
--config-path "$port.config" \
--report-stats no

View File

@@ -89,7 +89,6 @@
- [Database Schemas](development/database_schema.md)
- [Experimental features](development/experimental_features.md)
- [Synapse Architecture]()
- [Cancellation](development/synapse_architecture/cancellation.md)
- [Log Contexts](log_contexts.md)
- [Replication](replication.md)
- [TCP Replication](tcp_replication.md)

View File

@@ -206,32 +206,7 @@ This means that we need to run our unit tests against PostgreSQL too. Our CI doe
this automatically for pull requests and release candidates, but it's sometimes
useful to reproduce this locally.
#### Using Docker
The easiest way to do so is to run Postgres via a docker container. In one
terminal:
```shell
docker run --rm -e POSTGRES_PASSWORD=mysecretpassword -e POSTGRES_USER=postgres -e POSTGRES_DB=postgress -p 5432:5432 postgres:14
```
If you see an error like
```
docker: Error response from daemon: driver failed programming external connectivity on endpoint nice_ride (b57bbe2e251b70015518d00c9981e8cb8346b5c785250341a6c53e3c899875f1): Error starting userland proxy: listen tcp4 0.0.0.0:5432: bind: address already in use.
```
then something is already bound to port 5432. You're probably already running postgres locally.
Once you have a postgres server running, invoke `trial` in a second terminal:
```shell
SYNAPSE_POSTGRES=1 SYNAPSE_POSTGRES_HOST=127.0.0.1 SYNAPSE_POSTGRES_USER=postgres SYNAPSE_POSTGRES_PASSWORD=mysecretpassword poetry run trial tests
````
#### Using an existing Postgres installation
If you have postgres already installed on your system, you can run `trial` with the
To do so, [configure Postgres](../postgres.md) and run `trial` with the
following environment variables matching your configuration:
- `SYNAPSE_POSTGRES` to anything nonempty
@@ -254,8 +229,8 @@ You don't need to specify the host, user, port or password if your Postgres
server is set to authenticate you over the UNIX socket (i.e. if the `psql` command
works without further arguments).
Your Postgres account needs to be able to create databases; see the postgres
docs for [`ALTER ROLE`](https://www.postgresql.org/docs/current/sql-alterrole.html).
Your Postgres account needs to be able to create databases.
## Run the integration tests ([Sytest](https://github.com/matrix-org/sytest)).

View File

@@ -5,7 +5,7 @@
Requires you to have a [Synapse development environment setup](https://matrix-org.github.io/synapse/develop/development/contributing_guide.html#4-install-the-dependencies).
The demo setup allows running three federation Synapse servers, with server
names `localhost:8480`, `localhost:8481`, and `localhost:8482`.
names `localhost:8080`, `localhost:8081`, and `localhost:8082`.
You can access them via any Matrix client over HTTP at `localhost:8080`,
`localhost:8081`, and `localhost:8082` or over HTTPS at `localhost:8480`,
@@ -20,10 +20,9 @@ and the servers are configured in a highly insecure way, including:
The servers are configured to store their data under `demo/8080`, `demo/8081`, and
`demo/8082`. This includes configuration, logs, SQLite databases, and media.
Note that when joining a public room on a different homeserver via "#foo:bar.net",
then you are (in the current implementation) joining a room with room_id "foo".
This means that it won't work if your homeserver already has a room with that
name.
Note that when joining a public room on a different HS via "#foo:bar.net", then
you are (in the current impl) joining a room with room_id "foo". This means that
it won't work if your HS already has a room with that name.
## Using the demo scripts

View File

@@ -1,392 +0,0 @@
# Cancellation
Sometimes, requests take a long time to service and clients disconnect
before Synapse produces a response. To avoid wasting resources, Synapse
can cancel request processing for select endpoints marked with the
`@cancellable` decorator.
Synapse makes use of Twisted's `Deferred.cancel()` feature to make
cancellation work. The `@cancellable` decorator does nothing by itself
and merely acts as a flag, signalling to developers and other code alike
that a method can be cancelled.
## Enabling cancellation for an endpoint
1. Check that the endpoint method, and any `async` functions in its call
tree handle cancellation correctly. See
[Handling cancellation correctly](#handling-cancellation-correctly)
for a list of things to look out for.
2. Add the `@cancellable` decorator to the `on_GET/POST/PUT/DELETE`
method. It's not recommended to make non-`GET` methods cancellable,
since cancellation midway through some database updates is less
likely to be handled correctly.
## Mechanics
There are two stages to cancellation: downward propagation of a
`cancel()` call, followed by upwards propagation of a `CancelledError`
out of a blocked `await`.
Both Twisted and asyncio have a cancellation mechanism.
| | Method | Exception | Exception inherits from |
|---------------|---------------------|-----------------------------------------|-------------------------|
| Twisted | `Deferred.cancel()` | `twisted.internet.defer.CancelledError` | `Exception` (!) |
| asyncio | `Task.cancel()` | `asyncio.CancelledError` | `BaseException` |
### Deferred.cancel()
When Synapse starts handling a request, it runs the async method
responsible for handling it using `defer.ensureDeferred`, which returns
a `Deferred`. For example:
```python
def do_something() -> Deferred[None]:
...
@cancellable
async def on_GET() -> Tuple[int, JsonDict]:
d = make_deferred_yieldable(do_something())
await d
return 200, {}
request = defer.ensureDeferred(on_GET())
```
When a client disconnects early, Synapse checks for the presence of the
`@cancellable` decorator on `on_GET`. Since `on_GET` is cancellable,
`Deferred.cancel()` is called on the `Deferred` from
`defer.ensureDeferred`, ie. `request`. Twisted knows which `Deferred`
`request` is waiting on and passes the `cancel()` call on to `d`.
The `Deferred` being waited on, `d`, may have its own handling for
`cancel()` and pass the call on to other `Deferred`s.
Eventually, a `Deferred` handles the `cancel()` call by resolving itself
with a `CancelledError`.
### CancelledError
The `CancelledError` gets raised out of the `await` and bubbles up, as
per normal Python exception handling.
## Handling cancellation correctly
In general, when writing code that might be subject to cancellation, two
things must be considered:
* The effect of `CancelledError`s raised out of `await`s.
* The effect of `Deferred`s being `cancel()`ed.
Examples of code that handles cancellation incorrectly include:
* `try-except` blocks which swallow `CancelledError`s.
* Code that shares the same `Deferred`, which may be cancelled, between
multiple requests.
* Code that starts some processing that's exempt from cancellation, but
uses a logging context from cancellable code. The logging context
will be finished upon cancellation, while the uncancelled processing
is still using it.
Some common patterns are listed below in more detail.
### `async` function calls
Most functions in Synapse are relatively straightforward from a
cancellation standpoint: they don't do anything with `Deferred`s and
purely call and `await` other `async` functions.
An `async` function handles cancellation correctly if its own code
handles cancellation correctly and all the async function it calls
handle cancellation correctly. For example:
```python
async def do_two_things() -> None:
check_something()
await do_something()
await do_something_else()
```
`do_two_things` handles cancellation correctly if `do_something` and
`do_something_else` handle cancellation correctly.
That is, when checking whether a function handles cancellation
correctly, its implementation and all its `async` function calls need to
be checked, recursively.
As `check_something` is not `async`, it does not need to be checked.
### CancelledErrors
Because Twisted's `CancelledError`s are `Exception`s, it's easy to
accidentally catch and suppress them. Care must be taken to ensure that
`CancelledError`s are allowed to propagate upwards.
<table width="100%">
<tr>
<td width="50%" valign="top">
**Bad**:
```python
try:
await do_something()
except Exception:
# `CancelledError` gets swallowed here.
logger.info(...)
```
</td>
<td width="50%" valign="top">
**Good**:
```python
try:
await do_something()
except CancelledError:
raise
except Exception:
logger.info(...)
```
</td>
</tr>
<tr>
<td width="50%" valign="top">
**OK**:
```python
try:
check_something()
# A `CancelledError` won't ever be raised here.
except Exception:
logger.info(...)
```
</td>
<td width="50%" valign="top">
**Good**:
```python
try:
await do_something()
except ValueError:
logger.info(...)
```
</td>
</tr>
</table>
#### defer.gatherResults
`defer.gatherResults` produces a `Deferred` which:
* broadcasts `cancel()` calls to every `Deferred` being waited on.
* wraps the first exception it sees in a `FirstError`.
Together, this means that `CancelledError`s will be wrapped in
a `FirstError` unless unwrapped. Such `FirstError`s are liable to be
swallowed, so they must be unwrapped.
<table width="100%">
<tr>
<td width="50%" valign="top">
**Bad**:
```python
async def do_something() -> None:
await make_deferred_yieldable(
defer.gatherResults([...], consumeErrors=True)
)
try:
await do_something()
except CancelledError:
raise
except Exception:
# `FirstError(CancelledError)` gets swallowed here.
logger.info(...)
```
</td>
<td width="50%" valign="top">
**Good**:
```python
async def do_something() -> None:
await make_deferred_yieldable(
defer.gatherResults([...], consumeErrors=True)
).addErrback(unwrapFirstError)
try:
await do_something()
except CancelledError:
raise
except Exception:
logger.info(...)
```
</td>
</tr>
</table>
### Creation of `Deferred`s
If a function creates a `Deferred`, the effect of cancelling it must be considered. `Deferred`s that get shared are likely to have unintended behaviour when cancelled.
<table width="100%">
<tr>
<td width="50%" valign="top">
**Bad**:
```python
cache: Dict[str, Deferred[None]] = {}
def wait_for_room(room_id: str) -> Deferred[None]:
deferred = cache.get(room_id)
if deferred is None:
deferred = Deferred()
cache[room_id] = deferred
# `deferred` can have multiple waiters.
# All of them will observe a `CancelledError`
# if any one of them is cancelled.
return make_deferred_yieldable(deferred)
# Request 1
await wait_for_room("!aAAaaAaaaAAAaAaAA:matrix.org")
# Request 2
await wait_for_room("!aAAaaAaaaAAAaAaAA:matrix.org")
```
</td>
<td width="50%" valign="top">
**Good**:
```python
cache: Dict[str, Deferred[None]] = {}
def wait_for_room(room_id: str) -> Deferred[None]:
deferred = cache.get(room_id)
if deferred is None:
deferred = Deferred()
cache[room_id] = deferred
# `deferred` will never be cancelled now.
# A `CancelledError` will still come out of
# the `await`.
# `delay_cancellation` may also be used.
return make_deferred_yieldable(stop_cancellation(deferred))
# Request 1
await wait_for_room("!aAAaaAaaaAAAaAaAA:matrix.org")
# Request 2
await wait_for_room("!aAAaaAaaaAAAaAaAA:matrix.org")
```
</td>
</tr>
<tr>
<td width="50%" valign="top">
</td>
<td width="50%" valign="top">
**Good**:
```python
cache: Dict[str, List[Deferred[None]]] = {}
def wait_for_room(room_id: str) -> Deferred[None]:
if room_id not in cache:
cache[room_id] = []
# Each request gets its own `Deferred` to wait on.
deferred = Deferred()
cache[room_id]].append(deferred)
return make_deferred_yieldable(deferred)
# Request 1
await wait_for_room("!aAAaaAaaaAAAaAaAA:matrix.org")
# Request 2
await wait_for_room("!aAAaaAaaaAAAaAaAA:matrix.org")
```
</td>
</table>
### Uncancelled processing
Some `async` functions may kick off some `async` processing which is
intentionally protected from cancellation, by `stop_cancellation` or
other means. If the `async` processing inherits the logcontext of the
request which initiated it, care must be taken to ensure that the
logcontext is not finished before the `async` processing completes.
<table width="100%">
<tr>
<td width="50%" valign="top">
**Bad**:
```python
cache: Optional[ObservableDeferred[None]] = None
async def do_something_else(
to_resolve: Deferred[None]
) -> None:
await ...
logger.info("done!")
to_resolve.callback(None)
async def do_something() -> None:
if not cache:
to_resolve = Deferred()
cache = ObservableDeferred(to_resolve)
# `do_something_else` will never be cancelled and
# can outlive the `request-1` logging context.
run_in_background(do_something_else, to_resolve)
await make_deferred_yieldable(cache.observe())
with LoggingContext("request-1"):
await do_something()
```
</td>
<td width="50%" valign="top">
**Good**:
```python
cache: Optional[ObservableDeferred[None]] = None
async def do_something_else(
to_resolve: Deferred[None]
) -> None:
await ...
logger.info("done!")
to_resolve.callback(None)
async def do_something() -> None:
if not cache:
to_resolve = Deferred()
cache = ObservableDeferred(to_resolve)
run_in_background(do_something_else, to_resolve)
# We'll wait until `do_something_else` is
# done before raising a `CancelledError`.
await make_deferred_yieldable(
delay_cancellation(cache.observe())
)
else:
await make_deferred_yieldable(cache.observe())
with LoggingContext("request-1"):
await do_something()
```
</td>
</tr>
<tr>
<td width="50%">
**OK**:
```python
cache: Optional[ObservableDeferred[None]] = None
async def do_something_else(
to_resolve: Deferred[None]
) -> None:
await ...
logger.info("done!")
to_resolve.callback(None)
async def do_something() -> None:
if not cache:
to_resolve = Deferred()
cache = ObservableDeferred(to_resolve)
# `do_something_else` will get its own independent
# logging context. `request-1` will not count any
# metrics from `do_something_else`.
run_as_background_process(
"do_something_else",
do_something_else,
to_resolve,
)
await make_deferred_yieldable(cache.observe())
with LoggingContext("request-1"):
await do_something()
```
</td>
<td width="50%">
</td>
</tr>
</table>

View File

@@ -1194,7 +1194,7 @@ For more information on using Synapse with Postgres,
see [here](../../postgres.md).
Example SQLite configuration:
```yaml
```
database:
name: sqlite3
args:
@@ -1202,7 +1202,7 @@ database:
```
Example Postgres configuration:
```yaml
```
database:
name: psycopg2
txn_limit: 10000
@@ -1357,20 +1357,6 @@ This option sets ratelimiting how often invites can be sent in a room or to a
specific user. `per_room` defaults to `per_second: 0.3`, `burst_count: 10` and
`per_user` defaults to `per_second: 0.003`, `burst_count: 5`.
Client requests that invite user(s) when [creating a
room](https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3createroom)
will count against the `rc_invites.per_room` limit, whereas
client requests to [invite a single user to a
room](https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidinvite)
will count against both the `rc_invites.per_user` and `rc_invites.per_room` limits.
Federation requests to invite a user will count against the `rc_invites.per_user`
limit only, as Synapse presumes ratelimiting by room will be done by the sending server.
The `rc_invites.per_user` limit applies to the *receiver* of the invite, rather than the
sender, meaning that a `rc_invite.per_user.burst_count` of 5 mandates that a single user
cannot *receive* more than a burst of 5 invites at a time.
Example configuration:
```yaml
rc_invites:
@@ -1679,10 +1665,10 @@ Defaults to "en".
Example configuration:
```yaml
url_preview_accept_language:
- 'en-UK'
- 'en-US;q=0.9'
- 'fr;q=0.8'
- '*;q=0.7'
- en-UK
- en-US;q=0.9
- fr;q=0.8
- *;q=0.7
```
----
Config option: `oembed`

View File

@@ -7,10 +7,10 @@ team.
## Installing and using Synapse
This documentation covers topics for **installation**, **configuration** and
**maintenance** of your Synapse process:
**maintainence** of your Synapse process:
* Learn how to [install](setup/installation.md) and
[configure](usage/configuration/config_documentation.md) your own instance, perhaps with [Single
[configure](usage/configuration/index.html) your own instance, perhaps with [Single
Sign-On](usage/configuration/user_authentication/index.html).
* See how to [upgrade](upgrade.md) between Synapse versions.
@@ -65,7 +65,7 @@ following documentation:
Want to help keep Synapse going but don't know how to code? Synapse is a
[Matrix.org Foundation](https://matrix.org) project. Consider becoming a
supporter on [Liberapay](https://liberapay.com/matrixdotorg),
supportor on [Liberapay](https://liberapay.com/matrixdotorg),
[Patreon](https://patreon.com/matrixdotorg) or through
[PayPal](https://paypal.me/matrixdotorg) via a one-time donation.

View File

@@ -1,6 +1,6 @@
[mypy]
namespace_packages = True
plugins = pydantic.mypy, mypy_zope:plugin, scripts-dev/mypy_synapse_plugin.py
plugins = mypy_zope:plugin, scripts-dev/mypy_synapse_plugin.py
follow_imports = normal
check_untyped_defs = True
show_error_codes = True
@@ -27,6 +27,9 @@ exclude = (?x)
|synapse/storage/databases/__init__.py
|synapse/storage/databases/main/cache.py
|synapse/storage/databases/main/devices.py
|synapse/storage/databases/main/event_federation.py
|synapse/storage/databases/main/push_rule.py
|synapse/storage/databases/main/roommember.py
|synapse/storage/schema/
|tests/api/test_auth.py
@@ -86,12 +89,6 @@ exclude = (?x)
|tests/utils.py
)$
[pydantic-mypy]
init_forbid_extra = True
init_typed = True
warn_required_dynamic_aliases = True
warn_untyped_fields = True
[mypy-synapse._scripts.*]
disallow_untyped_defs = True

54
poetry.lock generated
View File

@@ -778,21 +778,6 @@ category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
[[package]]
name = "pydantic"
version = "1.9.0"
description = "Data validation and settings management using python 3.6 type hinting"
category = "main"
optional = false
python-versions = ">=3.6.1"
[package.dependencies]
typing-extensions = ">=3.7.4.3"
[package.extras]
dotenv = ["python-dotenv (>=0.10.4)"]
email = ["email-validator (>=1.0.3)"]
[[package]]
name = "pyflakes"
version = "2.4.0"
@@ -1578,7 +1563,7 @@ url_preview = ["lxml"]
[metadata]
lock-version = "1.1"
python-versions = "^3.7.1"
content-hash = "54ec27d5187386653b8d0d13ed843f86ae68b3ebbee633c82dfffc7605b99f74"
content-hash = "d39d5ac5d51c014581186b7691999b861058b569084c525523baf70b77f292b1"
[metadata.files]
attrs = [
@@ -2266,43 +2251,6 @@ pycparser = [
{file = "pycparser-2.21-py2.py3-none-any.whl", hash = "sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9"},
{file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"},
]
pydantic = [
{file = "pydantic-1.9.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:cb23bcc093697cdea2708baae4f9ba0e972960a835af22560f6ae4e7e47d33f5"},
{file = "pydantic-1.9.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1d5278bd9f0eee04a44c712982343103bba63507480bfd2fc2790fa70cd64cf4"},
{file = "pydantic-1.9.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ab624700dc145aa809e6f3ec93fb8e7d0f99d9023b713f6a953637429b437d37"},
{file = "pydantic-1.9.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c8d7da6f1c1049eefb718d43d99ad73100c958a5367d30b9321b092771e96c25"},
{file = "pydantic-1.9.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:3c3b035103bd4e2e4a28da9da7ef2fa47b00ee4a9cf4f1a735214c1bcd05e0f6"},
{file = "pydantic-1.9.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:3011b975c973819883842c5ab925a4e4298dffccf7782c55ec3580ed17dc464c"},
{file = "pydantic-1.9.0-cp310-cp310-win_amd64.whl", hash = "sha256:086254884d10d3ba16da0588604ffdc5aab3f7f09557b998373e885c690dd398"},
{file = "pydantic-1.9.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:0fe476769acaa7fcddd17cadd172b156b53546ec3614a4d880e5d29ea5fbce65"},
{file = "pydantic-1.9.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c8e9dcf1ac499679aceedac7e7ca6d8641f0193c591a2d090282aaf8e9445a46"},
{file = "pydantic-1.9.0-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d1e4c28f30e767fd07f2ddc6f74f41f034d1dd6bc526cd59e63a82fe8bb9ef4c"},
{file = "pydantic-1.9.0-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:c86229333cabaaa8c51cf971496f10318c4734cf7b641f08af0a6fbf17ca3054"},
{file = "pydantic-1.9.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:c0727bda6e38144d464daec31dff936a82917f431d9c39c39c60a26567eae3ed"},
{file = "pydantic-1.9.0-cp36-cp36m-win_amd64.whl", hash = "sha256:dee5ef83a76ac31ab0c78c10bd7d5437bfdb6358c95b91f1ba7ff7b76f9996a1"},
{file = "pydantic-1.9.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:d9c9bdb3af48e242838f9f6e6127de9be7063aad17b32215ccc36a09c5cf1070"},
{file = "pydantic-1.9.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2ee7e3209db1e468341ef41fe263eb655f67f5c5a76c924044314e139a1103a2"},
{file = "pydantic-1.9.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0b6037175234850ffd094ca77bf60fb54b08b5b22bc85865331dd3bda7a02fa1"},
{file = "pydantic-1.9.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:b2571db88c636d862b35090ccf92bf24004393f85c8870a37f42d9f23d13e032"},
{file = "pydantic-1.9.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:8b5ac0f1c83d31b324e57a273da59197c83d1bb18171e512908fe5dc7278a1d6"},
{file = "pydantic-1.9.0-cp37-cp37m-win_amd64.whl", hash = "sha256:bbbc94d0c94dd80b3340fc4f04fd4d701f4b038ebad72c39693c794fd3bc2d9d"},
{file = "pydantic-1.9.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e0896200b6a40197405af18828da49f067c2fa1f821491bc8f5bde241ef3f7d7"},
{file = "pydantic-1.9.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7bdfdadb5994b44bd5579cfa7c9b0e1b0e540c952d56f627eb227851cda9db77"},
{file = "pydantic-1.9.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:574936363cd4b9eed8acdd6b80d0143162f2eb654d96cb3a8ee91d3e64bf4cf9"},
{file = "pydantic-1.9.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c556695b699f648c58373b542534308922c46a1cda06ea47bc9ca45ef5b39ae6"},
{file = "pydantic-1.9.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:f947352c3434e8b937e3aa8f96f47bdfe6d92779e44bb3f41e4c213ba6a32145"},
{file = "pydantic-1.9.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:5e48ef4a8b8c066c4a31409d91d7ca372a774d0212da2787c0d32f8045b1e034"},
{file = "pydantic-1.9.0-cp38-cp38-win_amd64.whl", hash = "sha256:96f240bce182ca7fe045c76bcebfa0b0534a1bf402ed05914a6f1dadff91877f"},
{file = "pydantic-1.9.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:815ddebb2792efd4bba5488bc8fde09c29e8ca3227d27cf1c6990fc830fd292b"},
{file = "pydantic-1.9.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6c5b77947b9e85a54848343928b597b4f74fc364b70926b3c4441ff52620640c"},
{file = "pydantic-1.9.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4c68c3bc88dbda2a6805e9a142ce84782d3930f8fdd9655430d8576315ad97ce"},
{file = "pydantic-1.9.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5a79330f8571faf71bf93667d3ee054609816f10a259a109a0738dac983b23c3"},
{file = "pydantic-1.9.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:f5a64b64ddf4c99fe201ac2724daada8595ada0d102ab96d019c1555c2d6441d"},
{file = "pydantic-1.9.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a733965f1a2b4090a5238d40d983dcd78f3ecea221c7af1497b845a9709c1721"},
{file = "pydantic-1.9.0-cp39-cp39-win_amd64.whl", hash = "sha256:2cc6a4cb8a118ffec2ca5fcb47afbacb4f16d0ab8b7350ddea5e8ef7bcc53a16"},
{file = "pydantic-1.9.0-py3-none-any.whl", hash = "sha256:085ca1de245782e9b46cefcf99deecc67d418737a1fd3f6a4f511344b613a5b3"},
{file = "pydantic-1.9.0.tar.gz", hash = "sha256:742645059757a56ecd886faf4ed2441b9c0cd406079c2b4bee51bcc3fbcd510a"},
]
pyflakes = [
{file = "pyflakes-2.4.0-py2.py3-none-any.whl", hash = "sha256:3bb3a3f256f4b7968c9c788781e4ff07dce46bdf12339dcda61053375426ee2e"},
{file = "pyflakes-2.4.0.tar.gz", hash = "sha256:05a85c2872edf37a4ed30b0cce2f6093e1d0581f8c19d7393122da7e25b2b24c"},

View File

@@ -54,7 +54,7 @@ skip_gitignore = true
[tool.poetry]
name = "matrix-synapse"
version = "1.59.1"
version = "1.59.0"
description = "Homeserver for the Matrix decentralised comms protocol"
authors = ["Matrix.org Team and Contributors <packages@matrix.org>"]
license = "Apache-2.0"
@@ -182,7 +182,6 @@ hiredis = { version = "*", optional = true }
Pympler = { version = "*", optional = true }
parameterized = { version = ">=0.7.4", optional = true }
idna = { version = ">=2.5", optional = true }
pydantic = ">=1.9.0"
[tool.poetry.extras]
# NB: Packages that should be part of `pip install matrix-synapse[all]` need to be specified

View File

@@ -21,7 +21,7 @@ from typing import Callable, Optional, Type
from mypy.nodes import ARG_NAMED_OPT
from mypy.plugin import MethodSigContext, Plugin
from mypy.typeops import bind_self
from mypy.types import CallableType, NoneType, UnionType
from mypy.types import CallableType, NoneType
class SynapsePlugin(Plugin):
@@ -72,20 +72,13 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
# Third, we add an optional "on_invalidate" argument.
#
# This is a either
# - a callable which accepts no input and returns nothing, or
# - None.
calltyp = UnionType(
[
NoneType(),
CallableType(
arg_types=[],
arg_kinds=[],
arg_names=[],
ret_type=NoneType(),
fallback=ctx.api.named_generic_type("builtins.function", []),
),
]
# This is a callable which accepts no input and returns nothing.
calltyp = CallableType(
arg_types=[],
arg_kinds=[],
arg_names=[],
ret_type=NoneType(),
fallback=ctx.api.named_generic_type("builtins.function", []),
)
arg_types.append(calltyp)
@@ -102,7 +95,7 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
def plugin(version: str) -> Type[SynapsePlugin]:
# This is the entry point of the plugin, and lets us deal with the fact
# This is the entry point of the plugin, and let's us deal with the fact
# that the mypy plugin interface is *not* stable by looking at the version
# string.
#

View File

@@ -65,8 +65,6 @@ class JoinRules:
PRIVATE: Final = "private"
# As defined for MSC3083.
RESTRICTED: Final = "restricted"
# As defined for MSC3787.
KNOCK_RESTRICTED: Final = "knock_restricted"
class RestrictedJoinRuleTypes:

View File

@@ -81,9 +81,6 @@ class RoomVersion:
msc2716_historical: bool
# MSC2716: Adds support for redacting "insertion", "chunk", and "marker" events
msc2716_redactions: bool
# MSC3787: Adds support for a `knock_restricted` join rule, mixing concepts of
# knocks and restricted join rules into the same join condition.
msc3787_knock_restricted_join_rule: bool
class RoomVersions:
@@ -102,7 +99,6 @@ class RoomVersions:
msc2403_knocking=False,
msc2716_historical=False,
msc2716_redactions=False,
msc3787_knock_restricted_join_rule=False,
)
V2 = RoomVersion(
"2",
@@ -119,7 +115,6 @@ class RoomVersions:
msc2403_knocking=False,
msc2716_historical=False,
msc2716_redactions=False,
msc3787_knock_restricted_join_rule=False,
)
V3 = RoomVersion(
"3",
@@ -136,7 +131,6 @@ class RoomVersions:
msc2403_knocking=False,
msc2716_historical=False,
msc2716_redactions=False,
msc3787_knock_restricted_join_rule=False,
)
V4 = RoomVersion(
"4",
@@ -153,7 +147,6 @@ class RoomVersions:
msc2403_knocking=False,
msc2716_historical=False,
msc2716_redactions=False,
msc3787_knock_restricted_join_rule=False,
)
V5 = RoomVersion(
"5",
@@ -170,7 +163,6 @@ class RoomVersions:
msc2403_knocking=False,
msc2716_historical=False,
msc2716_redactions=False,
msc3787_knock_restricted_join_rule=False,
)
V6 = RoomVersion(
"6",
@@ -187,7 +179,6 @@ class RoomVersions:
msc2403_knocking=False,
msc2716_historical=False,
msc2716_redactions=False,
msc3787_knock_restricted_join_rule=False,
)
MSC2176 = RoomVersion(
"org.matrix.msc2176",
@@ -204,7 +195,6 @@ class RoomVersions:
msc2403_knocking=False,
msc2716_historical=False,
msc2716_redactions=False,
msc3787_knock_restricted_join_rule=False,
)
V7 = RoomVersion(
"7",
@@ -221,7 +211,6 @@ class RoomVersions:
msc2403_knocking=True,
msc2716_historical=False,
msc2716_redactions=False,
msc3787_knock_restricted_join_rule=False,
)
V8 = RoomVersion(
"8",
@@ -238,7 +227,6 @@ class RoomVersions:
msc2403_knocking=True,
msc2716_historical=False,
msc2716_redactions=False,
msc3787_knock_restricted_join_rule=False,
)
V9 = RoomVersion(
"9",
@@ -255,7 +243,6 @@ class RoomVersions:
msc2403_knocking=True,
msc2716_historical=False,
msc2716_redactions=False,
msc3787_knock_restricted_join_rule=False,
)
MSC2716v3 = RoomVersion(
"org.matrix.msc2716v3",
@@ -272,24 +259,6 @@ class RoomVersions:
msc2403_knocking=True,
msc2716_historical=True,
msc2716_redactions=True,
msc3787_knock_restricted_join_rule=False,
)
MSC3787 = RoomVersion(
"org.matrix.msc3787",
RoomDisposition.UNSTABLE,
EventFormatVersions.V3,
StateResolutionVersions.V2,
enforce_key_validity=True,
special_case_aliases_auth=False,
strict_canonicaljson=True,
limit_notifications_power_levels=True,
msc2176_redaction_rules=False,
msc3083_join_rules=True,
msc3375_redaction_rules=True,
msc2403_knocking=True,
msc2716_historical=False,
msc2716_redactions=False,
msc3787_knock_restricted_join_rule=True,
)
@@ -307,7 +276,6 @@ KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = {
RoomVersions.V8,
RoomVersions.V9,
RoomVersions.MSC2716v3,
RoomVersions.MSC3787,
)
}

View File

@@ -1,240 +0,0 @@
from enum import Enum
from typing import TYPE_CHECKING, Any, Mapping, Optional, Tuple
from pydantic import BaseModel, StrictBool, StrictStr, constr, validator
from pydantic.fields import ModelField
from synapse.util.stringutils import parse_and_validate_mxc_uri
# Ugly workaround for https://github.com/samuelcolvin/pydantic/issues/156. Mypy doesn't
# consider expressions like `constr(...)` to be valid types.
if TYPE_CHECKING:
IDP_ID_TYPE = str
IDP_BRAND_TYPE = str
else:
IDP_ID_TYPE = constr(
strict=True,
min_length=1,
max_length=250,
regex="^[A-Za-z0-9._~-]+$", # noqa: F722
)
IDP_BRAND_TYPE = constr(
strict=True,
min_length=1,
max_length=255,
regex="^[a-z][a-z0-9_.-]*$", # noqa: F722
)
# the following list of enum members is the same as the keys of
# authlib.oauth2.auth.ClientAuth.DEFAULT_AUTH_METHODS. We inline it
# to avoid importing authlib here.
class ClientAuthMethods(str, Enum):
# The duplication is unfortunate. 3.11 should have StrEnum though,
# and there is a backport available for 3.8.6.
client_secret_basic = "client_secret_basic"
client_secret_post = "client_secret_post"
none = "none"
class UserProfileMethod(str, Enum):
# The duplication is unfortunate. 3.11 should have StrEnum though,
# and there is a backport available for 3.8.6.
auto = "auto"
userinfo_endpoint = "userinfo_endpoint"
class SSOAttributeRequirement(BaseModel):
class Config:
# Complain if someone provides a field that's not one of those listed here.
# Pydantic suggests making your own BaseModel subclass if you want to do this,
# see https://pydantic-docs.helpmanual.io/usage/model_config/#change-behaviour-globally
extra = "forbid"
attribute: StrictStr
# Note: a comment in config/oidc.py suggests that `value` may be optional. But
# The JSON schema seems to forbid this.
value: StrictStr
class ClientSecretJWTKey(BaseModel):
class Config:
extra = "forbid"
# a pem-encoded signing key
# TODO: how should we handle key_file?
key: StrictStr
# properties to include in the JWT header
# TODO: validator should enforce that jwt_header contains an 'alg'.
jwt_header: Mapping[str, str]
# properties to include in the JWT payload.
jwt_payload: Mapping[str, str] = {}
class OIDCProviderModel(BaseModel):
"""
Notes on Pydantic:
- I've used StrictStr because a plain `str` e.g. accepts integers and calls str()
on them
- pulling out constr() into IDP_ID_TYPE is a little awkward, but necessary to keep
mypy happy
-
"""
# a unique identifier for this identity provider. Used in the 'user_external_ids'
# table, as well as the query/path parameter used in the login protocol.
idp_id: IDP_ID_TYPE
@validator("idp_id")
def ensure_idp_id_prefix(cls, idp_id: str) -> str:
"""Prefix the given IDP with a prefix specific to the SSO mechanism, to avoid
clashes with other mechs (such as SAML, CAS).
We allow "oidc" as an exception so that people migrating from old-style
"oidc_config" format (which has long used "oidc" as its idp_id) can migrate to
a new-style "oidc_providers" entry without changing the idp_id for their provider
(and thereby invalidating their user_external_ids data).
"""
if idp_id != "oidc":
return "oidc-" + idp_id
return idp_id
# user-facing name for this identity provider.
idp_name: StrictStr
# Optional MXC URI for icon for this IdP.
idp_icon: Optional[StrictStr]
@validator("idp_icon")
def idp_icon_is_an_mxc_url(cls, idp_icon: str) -> str:
parse_and_validate_mxc_uri(idp_icon)
return idp_icon
# Optional brand identifier for this IdP.
idp_brand: Optional[StrictStr]
# whether the OIDC discovery mechanism is used to discover endpoints
discover: StrictBool = True
# the OIDC issuer. Used to validate tokens and (if discovery is enabled) to
# discover the provider's endpoints.
issuer: StrictStr
# oauth2 client id to use
client_id: StrictStr
# oauth2 client secret to use. if `None`, use client_secret_jwt_key to generate
# a secret.
client_secret: Optional[StrictStr]
# key to use to construct a JWT to use as a client secret. May be `None` if
# `client_secret` is set.
# TODO: test that ClientSecretJWTKey is being parsed correctly
client_secret_jwt_key: Optional[ClientSecretJWTKey]
# TODO: what is the precise relationship between client_auth_method, client_secret
# and client_secret_jwt_key? Is there anything we should enforce with a validator?
# auth method to use when exchanging the token.
# Valid values are 'client_secret_basic', 'client_secret_post' and
# 'none'.
client_auth_method: ClientAuthMethods = ClientAuthMethods.client_secret_basic
# list of scopes to request
scopes: Tuple[StrictStr, ...] = ("openid",)
# the oauth2 authorization endpoint. Required if discovery is disabled.
authorization_endpoint: Optional[StrictStr]
# the oauth2 token endpoint. Required if discovery is disabled.
token_endpoint: Optional[StrictStr]
# Normally, validators aren't run when fields don't have a value provided.
# Using validate=True ensures we run the validator even in that situation.
@validator("authorization_endpoint", "token_endpoint", always=True)
def endpoints_required_if_discovery_disabled(
cls,
endpoint_url: Optional[str],
values: Mapping[str, Any],
field: ModelField,
) -> Optional[str]:
# `if "discover" in values means: don't run our checks if "discover" didn't
# pass validation. (NB: validation order is the field definition order)
if "discover" in values and not values["discover"] and endpoint_url is None:
raise ValueError(f"{field.name} is required if discovery is disabled")
return endpoint_url
# the OIDC userinfo endpoint. Required if discovery is disabled and the
# "openid" scope is not requested.
userinfo_endpoint: Optional[StrictStr]
@validator("userinfo_endpoint", always=True)
def userinfo_endpoint_required_without_discovery_and_without_openid_scope(
cls, userinfo_endpoint: Optional[str], values: Mapping[str, Any]
) -> Optional[str]:
discovery_disabled = "discover" in values and not values["discover"]
openid_scope_not_requested = (
"scopes" in values and "openid" not in values["scopes"]
)
if (
discovery_disabled
and openid_scope_not_requested
and userinfo_endpoint is None
):
raise ValueError(
"userinfo_requirement is required if discovery is disabled and"
"the 'openid' scope is not requested"
)
return userinfo_endpoint
# URI where to fetch the JWKS. Required if discovery is disabled and the
# "openid" scope is used.
jwks_uri: Optional[StrictStr]
@validator("jwks_uri", always=True)
def jwks_uri_required_without_discovery_but_with_openid_scope(
cls, jwks_uri: Optional[str], values: Mapping[str, Any]
) -> Optional[str]:
discovery_disabled = "discover" in values and not values["discover"]
openid_scope_requested = "scopes" in values and "openid" in values["scopes"]
if discovery_disabled and openid_scope_requested and jwks_uri is None:
raise ValueError(
"jwks_uri is required if discovery is disabled and"
"the 'openid' scope is not requested"
)
return jwks_uri
# Whether to skip metadata verification
skip_verification: StrictBool = False
# Whether to fetch the user profile from the userinfo endpoint. Valid
# values are: "auto" or "userinfo_endpoint".
user_profile_method: UserProfileMethod = UserProfileMethod.auto
# whether to allow a user logging in via OIDC to match a pre-existing account
# instead of failing
allow_existing_users: StrictBool = False
# the class of the user mapping provider
# TODO there was logic for this
user_mapping_provider_class: Any # TODO: Type
# the config of the user mapping provider
# TODO
user_mapping_provider_config: Any
# required attributes to require in userinfo to allow login/registration
# TODO: wouldn't this be better expressed as a Mapping[str, str]?
attribute_requirements: Tuple[SSOAttributeRequirement, ...] = ()
class LegacyOIDCProviderModel(OIDCProviderModel):
# These fields could be omitted in the old scheme.
idp_id: IDP_ID_TYPE = "oidc"
idp_name: StrictStr = "OIDC"
# TODO
# top-level config: check we don't have any duplicate idp_ids now
# compute callback url

View File

@@ -414,12 +414,7 @@ def _is_membership_change_allowed(
raise AuthError(403, "You are banned from this room")
elif join_rule == JoinRules.PUBLIC:
pass
elif (
room_version.msc3083_join_rules and join_rule == JoinRules.RESTRICTED
) or (
room_version.msc3787_knock_restricted_join_rule
and join_rule == JoinRules.KNOCK_RESTRICTED
):
elif room_version.msc3083_join_rules and join_rule == JoinRules.RESTRICTED:
# This is the same as public, but the event must contain a reference
# to the server who authorised the join. If the event does not contain
# the proper content it is rejected.
@@ -445,13 +440,8 @@ def _is_membership_change_allowed(
if authorising_user_level < invite_level:
raise AuthError(403, "Join event authorised by invalid server.")
elif (
join_rule == JoinRules.INVITE
or (room_version.msc2403_knocking and join_rule == JoinRules.KNOCK)
or (
room_version.msc3787_knock_restricted_join_rule
and join_rule == JoinRules.KNOCK_RESTRICTED
)
elif join_rule == JoinRules.INVITE or (
room_version.msc2403_knocking and join_rule == JoinRules.KNOCK
):
if not caller_in_room and not caller_invited:
raise AuthError(403, "You are not invited to this room.")
@@ -472,10 +462,7 @@ def _is_membership_change_allowed(
if user_level < ban_level or user_level <= target_level:
raise AuthError(403, "You don't have permission to ban")
elif room_version.msc2403_knocking and Membership.KNOCK == membership:
if join_rule != JoinRules.KNOCK and (
not room_version.msc3787_knock_restricted_join_rule
or join_rule != JoinRules.KNOCK_RESTRICTED
):
if join_rule != JoinRules.KNOCK:
raise AuthError(403, "You don't have permission to knock")
elif target_user_id != event.user_id:
raise AuthError(403, "You cannot knock for other users")

View File

@@ -15,17 +15,7 @@
import abc
import logging
from collections import OrderedDict
from typing import (
TYPE_CHECKING,
Collection,
Dict,
Hashable,
Iterable,
List,
Optional,
Set,
Tuple,
)
from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Set, Tuple
import attr
from prometheus_client import Counter
@@ -419,7 +409,7 @@ class FederationSender(AbstractFederationSender):
)
return
destinations: Optional[Collection[str]] = None
destinations: Optional[Set[str]] = None
if not event.prev_event_ids():
# If there are no prev event IDs then the state is empty
# and so no remote servers in the room
@@ -454,7 +444,7 @@ class FederationSender(AbstractFederationSender):
)
return
sharded_destinations = {
destinations = {
d
for d in destinations
if self._federation_shard_config.should_handle(
@@ -466,12 +456,12 @@ class FederationSender(AbstractFederationSender):
# If we are sending the event on behalf of another server
# then it already has the event and there is no reason to
# send the event to it.
sharded_destinations.discard(send_on_behalf_of)
destinations.discard(send_on_behalf_of)
logger.debug("Sending %s to %r", event, sharded_destinations)
logger.debug("Sending %s to %r", event, destinations)
if sharded_destinations:
await self._send_pdu(event, sharded_destinations)
if destinations:
await self._send_pdu(event, destinations)
now = self.clock.time_msec()
ts = await self.store.get_received_ts(event.event_id)

View File

@@ -169,16 +169,14 @@ def _parse_auth_header(header_bytes: bytes) -> Tuple[str, str, str, Optional[str
"""
try:
header_str = header_bytes.decode("utf-8")
params = re.split(" +", header_str)[1].split(",")
params = header_str.split(" ")[1].split(",")
param_dict: Dict[str, str] = {
k.lower(): v for k, v in [param.split("=", maxsplit=1) for param in params]
k: v for k, v in [param.split("=", maxsplit=1) for param in params]
}
def strip_quotes(value: str) -> str:
if value.startswith('"'):
return re.sub(
"\\\\(.)", lambda matchobj: matchobj.group(1), value[1:-1]
)
return value[1:-1]
else:
return value

View File

@@ -71,9 +71,6 @@ class DirectoryHandler:
if wchar in room_alias.localpart:
raise SynapseError(400, "Invalid characters in room alias")
if ":" in room_alias.localpart:
raise SynapseError(400, "Invalid character in room alias localpart: ':'.")
if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room alias must be local")
# TODO(erikj): Change this.

View File

@@ -241,15 +241,7 @@ class EventAuthHandler:
# If the join rule is not restricted, this doesn't apply.
join_rules_event = await self._store.get_event(join_rules_event_id)
content_join_rule = join_rules_event.content.get("join_rule")
if content_join_rule == JoinRules.RESTRICTED:
return True
# also check for MSC3787 behaviour
if room_version.msc3787_knock_restricted_join_rule:
return content_join_rule == JoinRules.KNOCK_RESTRICTED
return False
return join_rules_event.content.get("join_rule") == JoinRules.RESTRICTED
async def get_rooms_that_allow_join(
self, state_ids: StateMap[str]

View File

@@ -224,7 +224,7 @@ class OidcHandler:
self._sso_handler.render_error(request, "invalid_session", str(e))
return
except MacaroonInvalidSignatureException as e:
logger.warning("Could not verify session for OIDC callback: %s", e)
logger.exception("Could not verify session for OIDC callback")
self._sso_handler.render_error(request, "mismatching_session", str(e))
return
@@ -827,7 +827,7 @@ class OidcProvider:
logger.debug("Exchanging OAuth2 code for a token")
token = await self._exchange_code(code)
except OidcError as e:
logger.warning("Could not exchange OAuth2 code: %s", e)
logger.exception("Could not exchange OAuth2 code")
self._sso_handler.render_error(request, e.error, e.error_description)
return

View File

@@ -751,21 +751,6 @@ class RoomCreationHandler:
if wchar in config["room_alias_name"]:
raise SynapseError(400, "Invalid characters in room alias")
if ":" in config["room_alias_name"]:
# Prevent someone from trying to pass in a full alias here.
# Note that it's permissible for a room alias to have multiple
# hash symbols at the start (notably bridged over from IRC, too),
# but the first colon in the alias is defined to separate the local
# part from the server name.
# (remember server names can contain port numbers, also separated
# by a colon. But under no circumstances should the local part be
# allowed to contain a colon!)
raise SynapseError(
400,
"':' is not permitted in the room alias name. "
"Please note this expects a local part — 'wombat', not '#wombat:example.com'.",
)
room_alias = RoomAlias(config["room_alias_name"], self.hs.hostname)
mapping = await self.store.get_association_from_room_alias(room_alias)

View File

@@ -53,7 +53,6 @@ class RoomBatchHandler:
# We want to use the successor event depth so they appear after `prev_event` because
# it has a larger `depth` but before the successor event because the `stream_ordering`
# is negative before the successor event.
assert most_recent_prev_event_id is not None
successor_event_ids = await self.store.get_successor_events(
most_recent_prev_event_id
)
@@ -140,7 +139,6 @@ class RoomBatchHandler:
_,
) = await self.store.get_max_depth_of(event_ids)
# mapping from (type, state_key) -> state_event_id
assert most_recent_event_id is not None
prev_state_map = await self.state_store.get_state_ids_for_event(
most_recent_event_id
)

View File

@@ -562,13 +562,8 @@ class RoomSummaryHandler:
if join_rules_event_id:
join_rules_event = await self._store.get_event(join_rules_event_id)
join_rule = join_rules_event.content.get("join_rule")
if (
join_rule == JoinRules.PUBLIC
or (room_version.msc2403_knocking and join_rule == JoinRules.KNOCK)
or (
room_version.msc3787_knock_restricted_join_rule
and join_rule == JoinRules.KNOCK_RESTRICTED
)
if join_rule == JoinRules.PUBLIC or (
room_version.msc2403_knocking and join_rule == JoinRules.KNOCK
):
return True

View File

@@ -411,10 +411,10 @@ class SyncHandler:
set_tag(SynapseTags.SYNC_RESULT, bool(sync_result))
return sync_result
async def push_rules_for_user(self, user: UserID) -> Dict[str, Dict[str, list]]:
async def push_rules_for_user(self, user: UserID) -> JsonDict:
user_id = user.to_string()
rules_raw = await self.store.get_push_rules_for_user(user_id)
rules = format_push_rules_for_user(user, rules_raw)
rules = await self.store.get_push_rules_for_user(user_id)
rules = format_push_rules_for_user(user, rules)
return rules
async def ephemeral_by_room(

View File

@@ -747,7 +747,7 @@ class MatrixFederationHttpClient:
for key, sig in request["signatures"][self.server_name].items():
auth_headers.append(
(
'X-Matrix origin="%s",key="%s",sig="%s",destination="%s"'
'X-Matrix origin=%s,key="%s",sig="%s",destination="%s"'
% (
self.server_name,
key,

View File

@@ -405,7 +405,7 @@ class HttpPusher(Pusher):
rejected = []
if "rejected" in resp:
rejected = resp["rejected"]
if not rejected:
else:
self.badge_count_last_call = badge
return rejected

View File

@@ -148,9 +148,9 @@ class PushRuleRestServlet(RestServlet):
# we build up the full structure and then decide which bits of it
# to send which means doing unnecessary work sometimes but is
# is probably not going to make a whole lot of difference
rules_raw = await self.store.get_push_rules_for_user(user_id)
rules = await self.store.get_push_rules_for_user(user_id)
rules = format_push_rules_for_user(requester.user, rules_raw)
rules = format_push_rules_for_user(requester.user, rules)
path_parts = path.split("/")[1:]

View File

@@ -239,13 +239,13 @@ class StateHandler:
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
return await self.store.get_joined_users_from_state(room_id, entry)
async def get_current_hosts_in_room(self, room_id: str) -> FrozenSet[str]:
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
event_ids = await self.store.get_latest_event_ids_in_room(room_id)
return await self.get_hosts_in_room_at_events(room_id, event_ids)
async def get_hosts_in_room_at_events(
self, room_id: str, event_ids: Collection[str]
) -> FrozenSet[str]:
) -> Set[str]:
"""Get the hosts that were in a room at the given event ids
Args:
@@ -288,6 +288,7 @@ class StateHandler:
#
# first of all, figure out the state before the event
#
if old_state:
# if we're given the state before the event, then we use that
state_ids_before_event: StateMap[str] = {
@@ -418,37 +419,33 @@ class StateHandler:
"""
logger.debug("resolve_state_groups event_ids %s", event_ids)
state_groups = await self.state_store.get_state_group_for_events(event_ids)
# map from state group id to the state in that state group (where
# 'state' is a map from state key to event id)
# dict[int, dict[(str, str), str]]
state_groups_ids = await self.state_store.get_state_groups_ids(
room_id, event_ids
)
state_group_ids = state_groups.values()
if len(state_groups_ids) == 0:
return _StateCacheEntry(state={}, state_group=None)
elif len(state_groups_ids) == 1:
name, state_list = list(state_groups_ids.items()).pop()
prev_group, delta_ids = await self.state_store.get_state_group_delta(name)
# check if each event has same state group id, if so there's no state to resolve
state_group_ids_set = set(state_group_ids)
if len(state_group_ids_set) == 1:
(state_group_id,) = state_group_ids_set
state = await self.state_store.get_state_for_groups(state_group_ids_set)
prev_group, delta_ids = await self.state_store.get_state_group_delta(
state_group_id
)
return _StateCacheEntry(
state=state[state_group_id],
state_group=state_group_id,
state=state_list,
state_group=name,
prev_group=prev_group,
delta_ids=delta_ids,
)
elif len(state_group_ids_set) == 0:
return _StateCacheEntry(state={}, state_group=None)
room_version = await self.store.get_room_version_id(room_id)
state_to_resolve = await self.state_store.get_state_for_groups(
state_group_ids_set
)
result = await self._state_resolution_handler.resolve_state_groups(
room_id,
room_version,
state_to_resolve,
state_groups_ids,
None,
state_res_store=StateResolutionStore(self.store),
)

View File

@@ -282,20 +282,12 @@ class BackgroundUpdater:
self._running = True
back_to_back_failures = 0
try:
logger.info("Starting background schema updates")
while self.enabled:
try:
result = await self.do_next_background_update(sleep)
back_to_back_failures = 0
except Exception:
back_to_back_failures += 1
if back_to_back_failures >= 5:
raise RuntimeError(
"5 back-to-back background update failures; aborting."
)
logger.exception("Error doing update")
else:
if result:

View File

@@ -26,7 +26,11 @@ from synapse.storage.database import (
from synapse.storage.databases.main.stats import UserSortOrder
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.storage.util.id_generators import (
IdGenerator,
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.types import JsonDict, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -151,6 +155,8 @@ class DataStore(
],
)
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
self._group_updates_id_gen = StreamIdGenerator(
db_conn, "local_group_updates", "stream_id"
)

View File

@@ -203,29 +203,19 @@ class ApplicationServiceTransactionWorkerStore(
"""Get the application service state.
Args:
service: The service whose state to get.
service: The service whose state to set.
Returns:
An ApplicationServiceState, or None if we have yet to attempt any
transactions to the AS.
An ApplicationServiceState or none.
"""
# if we have created transactions for this AS but not yet attempted to send
# them, we will have a row in the table with state=NULL (recording the stream
# positions we have processed up to).
#
# On the other hand, if we have yet to create any transactions for this AS at
# all, then there will be no row for the AS.
#
# In either case, we return None to indicate "we don't yet know the state of
# this AS".
result = await self.db_pool.simple_select_one_onecol(
result = await self.db_pool.simple_select_one(
"application_services_state",
{"as_id": service.id},
retcol="state",
["state"],
allow_none=True,
desc="get_appservice_state",
)
if result:
return ApplicationServiceState(result)
return ApplicationServiceState(result.get("state"))
return None
async def set_appservice_state(
@@ -306,6 +296,14 @@ class ApplicationServiceTransactionWorkerStore(
"""
def _complete_appservice_txn(txn: LoggingTransaction) -> None:
# Set current txn_id for AS to 'txn_id'
self.db_pool.simple_upsert_txn(
txn,
"application_services_state",
{"as_id": service.id},
{"last_txn": txn_id},
)
# Delete txn
self.db_pool.simple_delete_txn(
txn,
@@ -454,15 +452,16 @@ class ApplicationServiceTransactionWorkerStore(
% (stream_type,)
)
# this may be the first time that we're recording any state for this AS, so
# we don't yet know if a row for it exists; hence we have to upsert here.
await self.db_pool.simple_upsert(
table="application_services_state",
keyvalues={"as_id": service.id},
values={f"{stream_type}_stream_id": pos},
# no need to lock when emulating upsert: as_id is a unique key
lock=False,
desc="set_appservice_stream_type_pos",
def set_appservice_stream_type_pos_txn(txn: LoggingTransaction) -> None:
stream_id_type = "%s_stream_id" % stream_type
txn.execute(
"UPDATE application_services_state SET %s = ? WHERE as_id=?"
% stream_id_type,
(pos, service.id),
)
await self.db_pool.runInteraction(
"set_appservice_stream_type_pos", set_appservice_stream_type_pos_txn
)

View File

@@ -14,17 +14,7 @@
import itertools
import logging
from queue import Empty, PriorityQueue
from typing import (
TYPE_CHECKING,
Collection,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
cast,
)
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
import attr
from prometheus_client import Counter, Gauge
@@ -43,7 +33,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.types import JsonDict
from synapse.storage.types import Cursor
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
@@ -145,7 +135,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Check if we have indexed the room so we can use the chain cover
# algorithm.
room = await self.get_room(room_id) # type: ignore[attr-defined]
room = await self.get_room(room_id)
if room["has_auth_chain_index"]:
try:
return await self.db_pool.runInteraction(
@@ -168,11 +158,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
def _get_auth_chain_ids_using_cover_index_txn(
self,
txn: LoggingTransaction,
room_id: str,
event_ids: Collection[str],
include_given: bool,
self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool
) -> Set[str]:
"""Calculates the auth chain IDs using the chain index."""
@@ -229,9 +215,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
chains: Dict[int, int] = {}
# Add all linked chains reachable from initial set of chains.
for batch2 in batch_iter(event_chains, 1000):
for batch in batch_iter(event_chains, 1000):
clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch2
txn.database_engine, "origin_chain_id", batch
)
txn.execute(sql % (clause,), args)
@@ -311,7 +297,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
front = set(event_ids)
while front:
new_front: Set[str] = set()
new_front = set()
for chunk in batch_iter(front, 100):
# Pull the auth events either from the cache or DB.
to_fetch: List[str] = [] # Event IDs to fetch from DB
@@ -330,7 +316,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Note we need to batch up the results by event ID before
# adding to the cache.
to_cache: Dict[str, List[Tuple[str, int]]] = {}
to_cache = {}
for event_id, auth_event_id, auth_event_depth in txn:
to_cache.setdefault(event_id, []).append(
(auth_event_id, auth_event_depth)
@@ -363,7 +349,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Check if we have indexed the room so we can use the chain cover
# algorithm.
room = await self.get_room(room_id) # type: ignore[attr-defined]
room = await self.get_room(room_id)
if room["has_auth_chain_index"]:
try:
return await self.db_pool.runInteraction(
@@ -384,7 +370,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
def _get_auth_chain_difference_using_cover_index_txn(
self, txn: LoggingTransaction, room_id: str, state_sets: List[Set[str]]
self, txn: Cursor, room_id: str, state_sets: List[Set[str]]
) -> Set[str]:
"""Calculates the auth chain difference using the chain index.
@@ -458,9 +444,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# (We need to take a copy of `seen_chains` as we want to mutate it in
# the loop)
for batch2 in batch_iter(set(seen_chains), 1000):
for batch in batch_iter(set(seen_chains), 1000):
clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch2
txn.database_engine, "origin_chain_id", batch
)
txn.execute(sql % (clause,), args)
@@ -543,7 +529,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return result
def _get_auth_chain_difference_txn(
self, txn: LoggingTransaction, state_sets: List[Set[str]]
self, txn, state_sets: List[Set[str]]
) -> Set[str]:
"""Calculates the auth chain difference using a breadth first search.
@@ -616,7 +602,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# I think building a temporary list with fetchall is more efficient than
# just `search.extend(txn)`, but this is unconfirmed
search.extend(cast(List[Tuple[int, str]], txn.fetchall()))
search.extend(txn.fetchall())
# sort by depth
search.sort()
@@ -659,7 +645,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# We parse the results and add the to the `found` set and the
# cache (note we need to batch up the results by event ID before
# adding to the cache).
to_cache: Dict[str, List[Tuple[str, int]]] = {}
to_cache = {}
for event_id, auth_event_id, auth_event_depth in txn:
to_cache.setdefault(event_id, []).append(
(auth_event_id, auth_event_depth)
@@ -710,7 +696,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return {eid for eid, n in event_to_missing_sets.items() if n}
async def get_oldest_event_ids_with_depth_in_room(
self, room_id: str
self, room_id
) -> List[Tuple[str, int]]:
"""Gets the oldest events(backwards extremities) in the room along with the
aproximate depth.
@@ -727,9 +713,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
List of (event_id, depth) tuples
"""
def get_oldest_event_ids_with_depth_in_room_txn(
txn: LoggingTransaction, room_id: str
) -> List[Tuple[str, int]]:
def get_oldest_event_ids_with_depth_in_room_txn(txn, room_id):
# Assemble a dictionary with event_id -> depth for the oldest events
# we know of in the room. Backwards extremeties are the oldest
# events we know of in the room but we only know of them because
@@ -759,7 +743,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
txn.execute(sql, (room_id, False))
return cast(List[Tuple[str, int]], txn.fetchall())
return txn.fetchall()
return await self.db_pool.runInteraction(
"get_oldest_event_ids_with_depth_in_room",
@@ -768,7 +752,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
async def get_insertion_event_backward_extremities_in_room(
self, room_id: str
self, room_id
) -> List[Tuple[str, int]]:
"""Get the insertion events we know about that we haven't backfilled yet.
@@ -784,9 +768,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
List of (event_id, depth) tuples
"""
def get_insertion_event_backward_extremities_in_room_txn(
txn: LoggingTransaction, room_id: str
) -> List[Tuple[str, int]]:
def get_insertion_event_backward_extremities_in_room_txn(txn, room_id):
sql = """
SELECT b.event_id, MAX(e.depth) FROM insertion_events as i
/* We only want insertion events that are also marked as backwards extremities */
@@ -798,7 +780,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"""
txn.execute(sql, (room_id,))
return cast(List[Tuple[str, int]], txn.fetchall())
return txn.fetchall()
return await self.db_pool.runInteraction(
"get_insertion_event_backward_extremities_in_room",
@@ -806,7 +788,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
room_id,
)
async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]:
async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[str, int]:
"""Returns the event ID and depth for the event that has the max depth from a set of event IDs
Args:
@@ -835,7 +817,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return max_depth_event_id, current_max_depth
async def get_min_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]:
async def get_min_depth_of(self, event_ids: List[str]) -> Tuple[str, int]:
"""Returns the event ID and depth for the event that has the min depth from a set of event IDs
Args:
@@ -883,9 +865,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id
)
def _get_prev_events_for_room_txn(
self, txn: LoggingTransaction, room_id: str
) -> List[str]:
def _get_prev_events_for_room_txn(self, txn, room_id: str):
# we just use the 10 newest events. Older events will become
# prev_events of future events.
@@ -916,7 +896,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
sorted by extremity count.
"""
def _get_rooms_with_many_extremities_txn(txn: LoggingTransaction) -> List[str]:
def _get_rooms_with_many_extremities_txn(txn):
where_clause = "1=1"
if room_id_filter:
where_clause = "room_id NOT IN (%s)" % (
@@ -957,9 +937,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"get_min_depth", self._get_min_depth_interaction, room_id
)
def _get_min_depth_interaction(
self, txn: LoggingTransaction, room_id: str
) -> Optional[int]:
def _get_min_depth_interaction(self, txn, room_id):
min_depth = self.db_pool.simple_select_one_onecol_txn(
txn,
table="room_depth",
@@ -988,24 +966,22 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"""
# We want to make the cache more effective, so we clamp to the last
# change before the given ordering.
last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id) # type: ignore[attr-defined]
last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id)
# We don't always have a full stream_to_exterm_id table, e.g. after
# the upgrade that introduced it, so we make sure we never ask for a
# stream_ordering from before a restart
last_change = max(self._stream_order_on_start, last_change) # type: ignore[attr-defined]
last_change = max(self._stream_order_on_start, last_change)
# provided the last_change is recent enough, we now clamp the requested
# stream_ordering to it.
if last_change > self.stream_ordering_month_ago: # type: ignore[attr-defined]
if last_change > self.stream_ordering_month_ago:
stream_ordering = min(last_change, stream_ordering)
return await self._get_forward_extremeties_for_room(room_id, stream_ordering)
@cached(max_entries=5000, num_args=2)
async def _get_forward_extremeties_for_room(
self, room_id: str, stream_ordering: int
) -> List[str]:
async def _get_forward_extremeties_for_room(self, room_id, stream_ordering):
"""For a given room_id and stream_ordering, return the forward
extremeties of the room at that point in "time".
@@ -1013,7 +989,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
stream_orderings from that point.
"""
if stream_ordering <= self.stream_ordering_month_ago: # type: ignore[attr-defined]
if stream_ordering <= self.stream_ordering_month_ago:
raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,))
sql = """
@@ -1026,7 +1002,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
WHERE room_id = ?
"""
def get_forward_extremeties_for_room_txn(txn: LoggingTransaction) -> List[str]:
def get_forward_extremeties_for_room_txn(txn):
txn.execute(sql, (stream_ordering, room_id))
return [event_id for event_id, in txn]
@@ -1128,8 +1104,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
]
async def get_backfill_events(
self, room_id: str, seed_event_id_list: List[str], limit: int
) -> List[EventBase]:
self, room_id: str, seed_event_id_list: list, limit: int
):
"""Get a list of Events for a given topic that occurred before (and
including) the events in seed_event_id_list. Return a list of max size `limit`
@@ -1147,19 +1123,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
events = await self.get_events_as_list(event_ids)
return sorted(
# type-ignore: mypy doesn't like negating the Optional[int] stream_ordering.
# But it's never None, because these events were previously persisted to the DB.
events,
key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering), # type: ignore[operator]
events, key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering)
)
def _get_backfill_events(
self,
txn: LoggingTransaction,
room_id: str,
seed_event_id_list: List[str],
limit: int,
) -> Set[str]:
def _get_backfill_events(self, txn, room_id, seed_event_id_list, limit):
"""
We want to make sure that we do a breadth-first, "depth" ordered search.
We also handle navigating historical branches of history connected by
@@ -1172,7 +1139,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
limit,
)
event_id_results: Set[str] = set()
event_id_results = set()
# In a PriorityQueue, the lowest valued entries are retrieved first.
# We're using depth as the priority in the queue and tie-break based on
@@ -1180,7 +1147,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# highest and newest-in-time message. We add events to the queue with a
# negative depth so that we process the newest-in-time messages first
# going backwards in time. stream_ordering follows the same pattern.
queue: "PriorityQueue[Tuple[int, int, str, str]]" = PriorityQueue()
queue = PriorityQueue()
for seed_event_id in seed_event_id_list:
event_lookup_result = self.db_pool.simple_select_one_txn(
@@ -1286,13 +1253,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return event_id_results
async def get_missing_events(
self,
room_id: str,
earliest_events: List[str],
latest_events: List[str],
limit: int,
) -> List[EventBase]:
async def get_missing_events(self, room_id, earliest_events, latest_events, limit):
ids = await self.db_pool.runInteraction(
"get_missing_events",
self._get_missing_events,
@@ -1303,18 +1264,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
return await self.get_events_as_list(ids)
def _get_missing_events(
self,
txn: LoggingTransaction,
room_id: str,
earliest_events: List[str],
latest_events: List[str],
limit: int,
) -> List[str]:
def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
seen_events = set(earliest_events)
front = set(latest_events) - seen_events
event_results: List[str] = []
event_results = []
query = (
"SELECT prev_event_id FROM event_edges "
@@ -1357,7 +1311,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
@wrap_as_background_process("delete_old_forward_extrem_cache")
async def _delete_old_forward_extrem_cache(self) -> None:
def _delete_old_forward_extrem_cache_txn(txn: LoggingTransaction) -> None:
def _delete_old_forward_extrem_cache_txn(txn):
# Delete entries older than a month, while making sure we don't delete
# the only entries for a room.
sql = """
@@ -1370,7 +1324,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
) AND stream_ordering < ?
"""
txn.execute(
sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago) # type: ignore[attr-defined]
sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago)
)
await self.db_pool.runInteraction(
@@ -1428,9 +1382,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"""
if self.db_pool.engine.supports_returning:
def _remove_received_event_from_staging_txn(
txn: LoggingTransaction,
) -> Optional[int]:
def _remove_received_event_from_staging_txn(txn):
sql = """
DELETE FROM federation_inbound_events_staging
WHERE origin = ? AND event_id = ?
@@ -1438,24 +1390,21 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"""
txn.execute(sql, (origin, event_id))
row = cast(Optional[Tuple[int]], txn.fetchone())
return txn.fetchone()
if row is None:
return None
return row[0]
return await self.db_pool.runInteraction(
row = await self.db_pool.runInteraction(
"remove_received_event_from_staging",
_remove_received_event_from_staging_txn,
db_autocommit=True,
)
if row is None:
return None
return row[0]
else:
def _remove_received_event_from_staging_txn(
txn: LoggingTransaction,
) -> Optional[int]:
def _remove_received_event_from_staging_txn(txn):
received_ts = self.db_pool.simple_select_one_onecol_txn(
txn,
table="federation_inbound_events_staging",
@@ -1488,9 +1437,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
) -> Optional[Tuple[str, str]]:
"""Get the next event ID in the staging area for the given room."""
def _get_next_staged_event_id_for_room_txn(
txn: LoggingTransaction,
) -> Optional[Tuple[str, str]]:
def _get_next_staged_event_id_for_room_txn(txn):
sql = """
SELECT origin, event_id
FROM federation_inbound_events_staging
@@ -1501,7 +1448,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
txn.execute(sql, (room_id,))
return cast(Optional[Tuple[str, str]], txn.fetchone())
return txn.fetchone()
return await self.db_pool.runInteraction(
"get_next_staged_event_id_for_room", _get_next_staged_event_id_for_room_txn
@@ -1514,9 +1461,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
) -> Optional[Tuple[str, EventBase]]:
"""Get the next event in the staging area for the given room."""
def _get_next_staged_event_for_room_txn(
txn: LoggingTransaction,
) -> Optional[Tuple[str, str, str]]:
def _get_next_staged_event_for_room_txn(txn):
sql = """
SELECT event_json, internal_metadata, origin
FROM federation_inbound_events_staging
@@ -1526,7 +1471,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
"""
txn.execute(sql, (room_id,))
return cast(Optional[Tuple[str, str, str]], txn.fetchone())
return txn.fetchone()
row = await self.db_pool.runInteraction(
"get_next_staged_event_for_room", _get_next_staged_event_for_room_txn
@@ -1654,20 +1599,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
@wrap_as_background_process("_get_stats_for_federation_staging")
async def _get_stats_for_federation_staging(self) -> None:
async def _get_stats_for_federation_staging(self):
"""Update the prometheus metrics for the inbound federation staging area."""
def _get_stats_for_federation_staging_txn(
txn: LoggingTransaction,
) -> Tuple[int, int]:
def _get_stats_for_federation_staging_txn(txn):
txn.execute("SELECT count(*) FROM federation_inbound_events_staging")
(count,) = cast(Tuple[int], txn.fetchone())
(count,) = txn.fetchone()
txn.execute(
"SELECT min(received_ts) FROM federation_inbound_events_staging"
)
(received_ts,) = cast(Tuple[Optional[int]], txn.fetchone())
(received_ts,) = txn.fetchone()
# If there is nothing in the staging area default it to 0.
age = 0
@@ -1708,21 +1651,19 @@ class EventFederationStore(EventFederationWorkerStore):
self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
)
async def clean_room_for_join(self, room_id: str) -> None:
await self.db_pool.runInteraction(
async def clean_room_for_join(self, room_id):
return await self.db_pool.runInteraction(
"clean_room_for_join", self._clean_room_for_join_txn, room_id
)
def _clean_room_for_join_txn(self, txn: LoggingTransaction, room_id: str) -> None:
def _clean_room_for_join_txn(self, txn, room_id):
query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
txn.execute(query, (room_id,))
txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
async def _background_delete_non_state_event_auth(
self, progress: JsonDict, batch_size: int
) -> int:
def delete_event_auth(txn: LoggingTransaction) -> bool:
async def _background_delete_non_state_event_auth(self, progress, batch_size):
def delete_event_auth(txn):
target_min_stream_id = progress.get("target_min_stream_id_inclusive")
max_stream_id = progress.get("max_stream_id_exclusive")

View File

@@ -14,19 +14,16 @@
import calendar
import logging
import time
from typing import TYPE_CHECKING, Dict, List, Tuple, cast
from typing import TYPE_CHECKING, Dict
from synapse.metrics import GaugeBucketCollector
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore,
)
from synapse.storage.types import Cursor
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -76,7 +73,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
@wrap_as_background_process("read_forward_extremities")
async def _read_forward_extremities(self) -> None:
def fetch(txn: LoggingTransaction) -> List[Tuple[int, int]]:
def fetch(txn):
txn.execute(
"""
SELECT t1.c, t2.c
@@ -89,7 +86,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
) t2 ON t1.room_id = t2.room_id
"""
)
return cast(List[Tuple[int, int]], txn.fetchall())
return txn.fetchall()
res = await self.db_pool.runInteraction("read_forward_extremities", fetch)
@@ -107,20 +104,20 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
call to this function, it will return None.
"""
def _count_messages(txn: LoggingTransaction) -> int:
def _count_messages(txn):
sql = """
SELECT COUNT(*) FROM events
WHERE type = 'm.room.encrypted'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
(count,) = cast(Tuple[int], txn.fetchone())
(count,) = txn.fetchone()
return count
return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)
async def count_daily_sent_e2ee_messages(self) -> int:
def _count_messages(txn: LoggingTransaction) -> int:
def _count_messages(txn):
# This is good enough as if you have silly characters in your own
# hostname then that's your own fault.
like_clause = "%:" + self.hs.hostname
@@ -133,7 +130,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"""
txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
(count,) = cast(Tuple[int], txn.fetchone())
(count,) = txn.fetchone()
return count
return await self.db_pool.runInteraction(
@@ -141,14 +138,14 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
)
async def count_daily_active_e2ee_rooms(self) -> int:
def _count(txn: LoggingTransaction) -> int:
def _count(txn):
sql = """
SELECT COUNT(DISTINCT room_id) FROM events
WHERE type = 'm.room.encrypted'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
(count,) = cast(Tuple[int], txn.fetchone())
(count,) = txn.fetchone()
return count
return await self.db_pool.runInteraction(
@@ -163,20 +160,20 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
call to this function, it will return None.
"""
def _count_messages(txn: LoggingTransaction) -> int:
def _count_messages(txn):
sql = """
SELECT COUNT(*) FROM events
WHERE type = 'm.room.message'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
(count,) = cast(Tuple[int], txn.fetchone())
(count,) = txn.fetchone()
return count
return await self.db_pool.runInteraction("count_messages", _count_messages)
async def count_daily_sent_messages(self) -> int:
def _count_messages(txn: LoggingTransaction) -> int:
def _count_messages(txn):
# This is good enough as if you have silly characters in your own
# hostname then that's your own fault.
like_clause = "%:" + self.hs.hostname
@@ -189,7 +186,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"""
txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
(count,) = cast(Tuple[int], txn.fetchone())
(count,) = txn.fetchone()
return count
return await self.db_pool.runInteraction(
@@ -197,14 +194,14 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
)
async def count_daily_active_rooms(self) -> int:
def _count(txn: LoggingTransaction) -> int:
def _count(txn):
sql = """
SELECT COUNT(DISTINCT room_id) FROM events
WHERE type = 'm.room.message'
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
(count,) = cast(Tuple[int], txn.fetchone())
(count,) = txn.fetchone()
return count
return await self.db_pool.runInteraction("count_daily_active_rooms", _count)
@@ -230,7 +227,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"count_monthly_users", self._count_users, thirty_days_ago
)
def _count_users(self, txn: LoggingTransaction, time_from: int) -> int:
def _count_users(self, txn: Cursor, time_from: int) -> int:
"""
Returns number of users seen in the past time_from period
"""
@@ -245,7 +242,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
# Mypy knows that fetchone() might return None if there are no rows.
# We know better: "SELECT COUNT(...) FROM ..." without any GROUP BY always
# returns exactly one row.
(count,) = cast(Tuple[int], txn.fetchone())
(count,) = txn.fetchone() # type: ignore[misc]
return count
async def count_r30_users(self) -> Dict[str, int]:
@@ -259,7 +256,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
A mapping of counts globally as well as broken out by platform.
"""
def _count_r30_users(txn: LoggingTransaction) -> Dict[str, int]:
def _count_r30_users(txn):
thirty_days_in_secs = 86400 * 30
now = int(self._clock.time())
thirty_days_ago_in_secs = now - thirty_days_in_secs
@@ -324,7 +321,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
(count,) = cast(Tuple[int], txn.fetchone())
(count,) = txn.fetchone()
results["all"] = count
return results
@@ -351,7 +348,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
- "web" (any web application -- it's not possible to distinguish Element Web here)
"""
def _count_r30v2_users(txn: LoggingTransaction) -> Dict[str, int]:
def _count_r30v2_users(txn):
thirty_days_in_secs = 86400 * 30
now = int(self._clock.time())
sixty_days_ago_in_secs = now - 2 * thirty_days_in_secs
@@ -448,8 +445,11 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
thirty_days_in_secs * 1000,
),
)
(count,) = cast(Tuple[int], txn.fetchone())
results["all"] = count
row = txn.fetchone()
if row is None:
results["all"] = 0
else:
results["all"] = row[0]
return results
@@ -471,7 +471,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
Generates daily visit data for use in cohort/ retention analysis
"""
def _generate_user_daily_visits(txn: LoggingTransaction) -> None:
def _generate_user_daily_visits(txn):
logger.info("Calling _generate_user_daily_visits")
today_start = self._get_start_of_day()
a_day_in_milliseconds = 24 * 60 * 60 * 1000

View File

@@ -417,7 +417,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
"room_account_data",
"room_tags",
"local_current_membership",
"federation_inbound_events_staging",
):
logger.info("[purge] removing %s from %s", room_id, table)
txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,))

View File

@@ -14,18 +14,14 @@
# limitations under the License.
import abc
import logging
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple, Union, cast
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
from synapse.api.errors import StoreError
from synapse.config.homeserver import ExperimentalConfig
from synapse.push.baserules import list_with_base_rules
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.pusher import PusherWorkerStore
@@ -34,12 +30,9 @@ from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
AbstractStreamIdTracker,
IdGenerator,
StreamIdGenerator,
)
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -64,11 +57,7 @@ def _is_experimental_rule_enabled(
return True
def _load_rules(
rawrules: List[JsonDict],
enabled_map: Dict[str, bool],
experimental_config: ExperimentalConfig,
) -> List[JsonDict]:
def _load_rules(rawrules, enabled_map, experimental_config: ExperimentalConfig):
ruleslist = []
for rawrule in rawrules:
rule = dict(rawrule)
@@ -148,7 +137,7 @@ class PushRulesWorkerStore(
)
@abc.abstractmethod
def get_max_push_rules_stream_id(self) -> int:
def get_max_push_rules_stream_id(self):
"""Get the position of the push rules stream.
Returns:
@@ -157,7 +146,7 @@ class PushRulesWorkerStore(
raise NotImplementedError()
@cached(max_entries=5000)
async def get_push_rules_for_user(self, user_id: str) -> List[JsonDict]:
async def get_push_rules_for_user(self, user_id):
rows = await self.db_pool.simple_select_list(
table="push_rules",
keyvalues={"user_name": user_id},
@@ -179,7 +168,7 @@ class PushRulesWorkerStore(
return _load_rules(rows, enabled_map, self.hs.config.experimental)
@cached(max_entries=5000)
async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
async def get_push_rules_enabled_for_user(self, user_id) -> Dict[str, bool]:
results = await self.db_pool.simple_select_list(
table="push_rules_enable",
keyvalues={"user_name": user_id},
@@ -195,13 +184,13 @@ class PushRulesWorkerStore(
return False
else:
def have_push_rules_changed_txn(txn: LoggingTransaction) -> bool:
def have_push_rules_changed_txn(txn):
sql = (
"SELECT COUNT(stream_id) FROM push_rules_stream"
" WHERE user_id = ? AND ? < stream_id"
)
txn.execute(sql, (user_id, last_id))
(count,) = cast(Tuple[int], txn.fetchone())
(count,) = txn.fetchone()
return bool(count)
return await self.db_pool.runInteraction(
@@ -213,13 +202,11 @@ class PushRulesWorkerStore(
list_name="user_ids",
num_args=1,
)
async def bulk_get_push_rules(
self, user_ids: Collection[str]
) -> Dict[str, List[JsonDict]]:
async def bulk_get_push_rules(self, user_ids):
if not user_ids:
return {}
results: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}
results = {user_id: [] for user_id in user_ids}
rows = await self.db_pool.simple_select_many_batch(
table="push_rules",
@@ -243,18 +230,67 @@ class PushRulesWorkerStore(
return results
async def copy_push_rule_from_room_to_room(
self, new_room_id: str, user_id: str, rule: dict
) -> None:
"""Copy a single push rule from one room to another for a specific user.
Args:
new_room_id: ID of the new room.
user_id : ID of user the push rule belongs to.
rule: A push rule.
"""
# Create new rule id
rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
new_rule_id = rule_id_scope + "/" + new_room_id
# Change room id in each condition
for condition in rule.get("conditions", []):
if condition.get("key") == "room_id":
condition["pattern"] = new_room_id
# Add the rule for the new room
await self.add_push_rule(
user_id=user_id,
rule_id=new_rule_id,
priority_class=rule["priority_class"],
conditions=rule["conditions"],
actions=rule["actions"],
)
async def copy_push_rules_from_room_to_room_for_user(
self, old_room_id: str, new_room_id: str, user_id: str
) -> None:
"""Copy all of the push rules from one room to another for a specific
user.
Args:
old_room_id: ID of the old room.
new_room_id: ID of the new room.
user_id: ID of user to copy push rules for.
"""
# Retrieve push rules for this user
user_push_rules = await self.get_push_rules_for_user(user_id)
# Get rules relating to the old room and copy them to the new room
for rule in user_push_rules:
conditions = rule.get("conditions", [])
if any(
(c.get("key") == "room_id" and c.get("pattern") == old_room_id)
for c in conditions
):
await self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
@cachedList(
cached_method_name="get_push_rules_enabled_for_user",
list_name="user_ids",
num_args=1,
)
async def bulk_get_push_rules_enabled(
self, user_ids: Collection[str]
) -> Dict[str, Dict[str, bool]]:
async def bulk_get_push_rules_enabled(self, user_ids):
if not user_ids:
return {}
results: Dict[str, Dict[str, bool]] = {user_id: {} for user_id in user_ids}
results = {user_id: {} for user_id in user_ids}
rows = await self.db_pool.simple_select_many_batch(
table="push_rules_enable",
@@ -270,7 +306,7 @@ class PushRulesWorkerStore(
async def get_all_push_rule_updates(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]:
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
"""Get updates for push_rules replication stream.
Args:
@@ -295,9 +331,7 @@ class PushRulesWorkerStore(
if last_id == current_id:
return [], current_id, False
def get_all_push_rule_updates_txn(
txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]:
def get_all_push_rule_updates_txn(txn):
sql = """
SELECT stream_id, user_id
FROM push_rules_stream
@@ -306,10 +340,7 @@ class PushRulesWorkerStore(
LIMIT ?
"""
txn.execute(sql, (last_id, current_id, limit))
updates = cast(
List[Tuple[int, Tuple[str]]],
[(stream_id, (user_id,)) for stream_id, user_id in txn],
)
updates = [(stream_id, (user_id,)) for stream_id, user_id in txn]
limited = False
upper_bound = current_id
@@ -325,30 +356,15 @@ class PushRulesWorkerStore(
class PushRuleStore(PushRulesWorkerStore):
# Because we have write access, this will be a StreamIdGenerator
# (see PushRulesWorkerStore.__init__)
_push_rules_stream_id_gen: AbstractStreamIdGenerator
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
async def add_push_rule(
self,
user_id: str,
rule_id: str,
priority_class: int,
conditions: List[Dict[str, str]],
actions: List[Union[JsonDict, str]],
before: Optional[str] = None,
after: Optional[str] = None,
user_id,
rule_id,
priority_class,
conditions,
actions,
before=None,
after=None,
) -> None:
conditions_json = json_encoder.encode(conditions)
actions_json = json_encoder.encode(actions)
@@ -384,17 +400,17 @@ class PushRuleStore(PushRulesWorkerStore):
def _add_push_rule_relative_txn(
self,
txn: LoggingTransaction,
stream_id: int,
event_stream_ordering: int,
user_id: str,
rule_id: str,
priority_class: int,
conditions_json: str,
actions_json: str,
before: str,
after: str,
) -> None:
txn,
stream_id,
event_stream_ordering,
user_id,
rule_id,
priority_class,
conditions_json,
actions_json,
before,
after,
):
# Lock the table since otherwise we'll have annoying races between the
# SELECT here and the UPSERT below.
self.database_engine.lock_table(txn, "push_rules")
@@ -454,15 +470,15 @@ class PushRuleStore(PushRulesWorkerStore):
def _add_push_rule_highest_priority_txn(
self,
txn: LoggingTransaction,
stream_id: int,
event_stream_ordering: int,
user_id: str,
rule_id: str,
priority_class: int,
conditions_json: str,
actions_json: str,
) -> None:
txn,
stream_id,
event_stream_ordering,
user_id,
rule_id,
priority_class,
conditions_json,
actions_json,
):
# Lock the table since otherwise we'll have annoying races between the
# SELECT here and the UPSERT below.
self.database_engine.lock_table(txn, "push_rules")
@@ -494,17 +510,17 @@ class PushRuleStore(PushRulesWorkerStore):
def _upsert_push_rule_txn(
self,
txn: LoggingTransaction,
stream_id: int,
event_stream_ordering: int,
user_id: str,
rule_id: str,
priority_class: int,
priority: int,
conditions_json: str,
actions_json: str,
update_stream: bool = True,
) -> None:
txn,
stream_id,
event_stream_ordering,
user_id,
rule_id,
priority_class,
priority,
conditions_json,
actions_json,
update_stream=True,
):
"""Specialised version of simple_upsert_txn that picks a push_rule_id
using the _push_rule_id_gen if it needs to insert the rule. It assumes
that the "push_rules" table is locked"""
@@ -584,11 +600,7 @@ class PushRuleStore(PushRulesWorkerStore):
rule_id: The rule_id of the rule to be deleted
"""
def delete_push_rule_txn(
txn: LoggingTransaction,
stream_id: int,
event_stream_ordering: int,
) -> None:
def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
# we don't use simple_delete_one_txn because that would fail if the
# user did not have a push_rule_enable row.
self.db_pool.simple_delete_txn(
@@ -649,14 +661,14 @@ class PushRuleStore(PushRulesWorkerStore):
def _set_push_rule_enabled_txn(
self,
txn: LoggingTransaction,
stream_id: int,
event_stream_ordering: int,
user_id: str,
rule_id: str,
enabled: bool,
is_default_rule: bool,
) -> None:
txn,
stream_id,
event_stream_ordering,
user_id,
rule_id,
enabled,
is_default_rule,
):
new_id = self._push_rules_enable_id_gen.get_next()
if not is_default_rule:
@@ -728,11 +740,7 @@ class PushRuleStore(PushRulesWorkerStore):
"""
actions_json = json_encoder.encode(actions)
def set_push_rule_actions_txn(
txn: LoggingTransaction,
stream_id: int,
event_stream_ordering: int,
) -> None:
def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
if is_default_rule:
# Add a dummy rule to the rules table with the user specified
# actions.
@@ -786,15 +794,8 @@ class PushRuleStore(PushRulesWorkerStore):
)
def _insert_push_rules_update_txn(
self,
txn: LoggingTransaction,
stream_id: int,
event_stream_ordering: int,
user_id: str,
rule_id: str,
op: str,
data: Optional[JsonDict] = None,
) -> None:
self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None
):
values = {
"stream_id": stream_id,
"event_stream_ordering": event_stream_ordering,
@@ -813,56 +814,5 @@ class PushRuleStore(PushRulesWorkerStore):
self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
)
def get_max_push_rules_stream_id(self) -> int:
def get_max_push_rules_stream_id(self):
return self._push_rules_stream_id_gen.get_current_token()
async def copy_push_rule_from_room_to_room(
self, new_room_id: str, user_id: str, rule: dict
) -> None:
"""Copy a single push rule from one room to another for a specific user.
Args:
new_room_id: ID of the new room.
user_id : ID of user the push rule belongs to.
rule: A push rule.
"""
# Create new rule id
rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
new_rule_id = rule_id_scope + "/" + new_room_id
# Change room id in each condition
for condition in rule.get("conditions", []):
if condition.get("key") == "room_id":
condition["pattern"] = new_room_id
# Add the rule for the new room
await self.add_push_rule(
user_id=user_id,
rule_id=new_rule_id,
priority_class=rule["priority_class"],
conditions=rule["conditions"],
actions=rule["actions"],
)
async def copy_push_rules_from_room_to_room_for_user(
self, old_room_id: str, new_room_id: str, user_id: str
) -> None:
"""Copy all of the push rules from one room to another for a specific
user.
Args:
old_room_id: ID of the old room.
new_room_id: ID of the new room.
user_id: ID of user to copy push rules for.
"""
# Retrieve push rules for this user
user_push_rules = await self.get_push_rules_for_user(user_id)
# Get rules relating to the old room and copy them to the new room
for rule in user_push_rules:
conditions = rule.get("conditions", [])
if any(
(c.get("key") == "room_id" and c.get("pattern") == old_room_id)
for c in conditions
):
await self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)

View File

@@ -15,7 +15,6 @@
import logging
from typing import (
TYPE_CHECKING,
Callable,
Collection,
Dict,
FrozenSet,
@@ -38,12 +37,7 @@ from synapse.metrics.background_process_metrics import (
wrap_as_background_process,
)
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import Sqlite3Engine
from synapse.storage.roommember import (
@@ -52,7 +46,7 @@ from synapse.storage.roommember import (
ProfileInfo,
RoomsForUser,
)
from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain_from_id
from synapse.types import PersistedEventPosition, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
@@ -121,7 +115,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
@wrap_as_background_process("_count_known_servers")
async def _count_known_servers(self) -> int:
async def _count_known_servers(self):
"""
Count the servers that this server knows about.
@@ -129,7 +123,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
`synapse_federation_known_servers` LaterGauge to collect.
"""
def _transact(txn: LoggingTransaction) -> int:
def _transact(txn):
if isinstance(self.database_engine, Sqlite3Engine):
query = """
SELECT COUNT(DISTINCT substr(out.user_id, pos+1))
@@ -156,9 +150,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
self._known_servers_count = max([count, 1])
return self._known_servers_count
def _check_safe_current_state_events_membership_updated_txn(
self, txn: LoggingTransaction
) -> None:
def _check_safe_current_state_events_membership_updated_txn(self, txn):
"""Checks if it is safe to assume the new current_state_events
membership column is up to date
"""
@@ -190,7 +182,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"get_users_in_room", self.get_users_in_room_txn, room_id
)
def get_users_in_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[str]:
def get_users_in_room_txn(self, txn, room_id: str) -> List[str]:
# If we can assume current_state_events.membership is up to date
# then we can avoid a join, which is a Very Good Thing given how
# frequently this function gets called.
@@ -230,9 +222,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
A mapping from user ID to ProfileInfo.
"""
def _get_users_in_room_with_profiles(
txn: LoggingTransaction,
) -> Dict[str, ProfileInfo]:
def _get_users_in_room_with_profiles(txn) -> Dict[str, ProfileInfo]:
sql = """
SELECT state_key, display_name, avatar_url FROM room_memberships as m
INNER JOIN current_state_events as c
@@ -260,9 +250,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
dict of membership states, pointing to a MemberSummary named tuple.
"""
def _get_room_summary_txn(
txn: LoggingTransaction,
) -> Dict[str, MemberSummary]:
def _get_room_summary_txn(txn):
# first get counts.
# We do this all in one transaction to keep the cache small.
# FIXME: get rid of this when we have room_stats
@@ -291,7 +279,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""
txn.execute(sql, (room_id,))
res: Dict[str, MemberSummary] = {}
res = {}
for count, membership in txn:
res.setdefault(membership, MemberSummary([], count))
@@ -412,7 +400,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
def _get_rooms_for_local_user_where_membership_is_txn(
self,
txn: LoggingTransaction,
txn,
user_id: str,
membership_list: List[str],
) -> List[RoomsForUser]:
@@ -500,7 +488,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
def _get_rooms_for_user_with_stream_ordering_txn(
self, txn: LoggingTransaction, user_id: str
self, txn, user_id: str
) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
# We use `current_state_events` here and not `local_current_membership`
# as a) this gets called with remote users and b) this only gets called
@@ -554,7 +542,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
def _get_rooms_for_users_with_stream_ordering_txn(
self, txn: LoggingTransaction, user_ids: Collection[str]
self, txn, user_ids: Collection[str]
) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]:
clause, args = make_in_list_sql_clause(
@@ -587,9 +575,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(sql, [Membership.JOIN] + args)
result: Dict[str, Set[GetRoomsForUserWithStreamOrdering]] = {
user_id: set() for user_id in user_ids
}
result = {user_id: set() for user_id in user_ids}
for user_id, room_id, instance, stream_id in txn:
result[user_id].add(
GetRoomsForUserWithStreamOrdering(
@@ -609,9 +595,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
if not user_ids:
return set()
def _get_users_server_still_shares_room_with_txn(
txn: LoggingTransaction,
) -> Set[str]:
def _get_users_server_still_shares_room_with_txn(txn):
sql = """
SELECT state_key FROM current_state_events
WHERE
@@ -635,7 +619,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
async def get_rooms_for_user(
self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None
self, user_id: str, on_invalidate=None
) -> FrozenSet[str]:
"""Returns a set of room_ids the user is currently joined to.
@@ -673,7 +657,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
async def get_joined_users_from_context(
self, event: EventBase, context: EventContext
) -> Dict[str, ProfileInfo]:
state_group: Union[object, int] = context.state_group
state_group = context.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
@@ -682,16 +666,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
state_group = object()
current_state_ids = await context.get_current_state_ids()
assert current_state_ids is not None
assert state_group is not None
return await self._get_joined_users_from_context(
event.room_id, state_group, current_state_ids, event=event, context=context
)
async def get_joined_users_from_state(
self, room_id: str, state_entry: "_StateCacheEntry"
self, room_id, state_entry
) -> Dict[str, ProfileInfo]:
state_group: Union[object, int] = state_entry.state_group
state_group = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
@@ -699,7 +681,6 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# To do this we set the state_group to a new object as object() != object()
state_group = object()
assert state_group is not None
with Measure(self._clock, "get_joined_users_from_state"):
return await self._get_joined_users_from_context(
room_id, state_group, state_entry.state, context=state_entry
@@ -708,12 +689,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
@cached(num_args=2, cache_context=True, iterable=True, max_entries=100000)
async def _get_joined_users_from_context(
self,
room_id: str,
state_group: Union[object, int],
current_state_ids: StateMap[str],
cache_context: _CacheContext,
event: Optional[EventBase] = None,
context: Optional[Union[EventContext, "_StateCacheEntry"]] = None,
room_id,
state_group,
current_state_ids,
cache_context,
event=None,
context=None,
) -> Dict[str, ProfileInfo]:
# We don't use `state_group`, it's there so that we can cache based
# on it. However, it's important that it's never None, since two current_states
@@ -784,18 +765,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return users_in_room
@cached(max_entries=10000)
def _get_joined_profile_from_event_id(
self, event_id: str
) -> Optional[Tuple[str, ProfileInfo]]:
def _get_joined_profile_from_event_id(self, event_id):
raise NotImplementedError()
@cachedList(
cached_method_name="_get_joined_profile_from_event_id",
list_name="event_ids",
)
async def _get_joined_profiles_from_event_ids(
self, event_ids: Iterable[str]
) -> Dict[str, Optional[Tuple[str, ProfileInfo]]]:
async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
"""For given set of member event_ids check if they point to a join
event and if so return the associated user and profile info.
@@ -803,7 +780,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
event_ids: The member event IDs to lookup
Returns:
Map from event ID to `user_id` and ProfileInfo (or None if not join event).
dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID
to `user_id` and ProfileInfo (or None if not join event).
"""
rows = await self.db_pool.simple_select_many_batch(
@@ -869,10 +847,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return True
async def get_joined_hosts(
self, room_id: str, state_entry: "_StateCacheEntry"
) -> FrozenSet[str]:
state_group: Union[object, int] = state_entry.state_group
async def get_joined_hosts(self, room_id: str, state_entry):
state_group = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
@@ -880,7 +856,6 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# To do this we set the state_group to a new object as object() != object()
state_group = object()
assert state_group is not None
with Measure(self._clock, "get_joined_hosts"):
return await self._get_joined_hosts(
room_id, state_group, state_entry=state_entry
@@ -888,10 +863,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
@cached(num_args=2, max_entries=10000, iterable=True)
async def _get_joined_hosts(
self,
room_id: str,
state_group: Union[object, int],
state_entry: "_StateCacheEntry",
self, room_id: str, state_group: int, state_entry: "_StateCacheEntry"
) -> FrozenSet[str]:
# We don't use `state_group`, it's there so that we can cache based on
# it. However, its important that its never None, since two
@@ -909,7 +881,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# `get_joined_hosts` is called with the "current" state group for the
# room, and so consecutive calls will be for consecutive state groups
# which point to the previous state group.
cache = await self._get_joined_hosts_cache(room_id) # type: ignore[misc]
cache = await self._get_joined_hosts_cache(room_id)
# If the state group in the cache matches, we already have the data we need.
if state_entry.state_group == cache.state_group:
@@ -925,7 +897,6 @@ class RoomMemberWorkerStore(EventsWorkerStore):
elif state_entry.prev_group == cache.state_group:
# The cached work is for the previous state group, so we work out
# the delta.
assert state_entry.delta_ids is not None
for (typ, state_key), event_id in state_entry.delta_ids.items():
if typ != EventTypes.Member:
continue
@@ -971,7 +942,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
Returns False if they have since re-joined."""
def f(txn: LoggingTransaction) -> int:
def f(txn):
sql = (
"SELECT"
" COUNT(*)"
@@ -1002,7 +973,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
The forgotten rooms.
"""
def _get_forgotten_rooms_for_user_txn(txn: LoggingTransaction) -> Set[str]:
def _get_forgotten_rooms_for_user_txn(txn):
# This is a slightly convoluted query that first looks up all rooms
# that the user has forgotten in the past, then rechecks that list
# to see if any have subsequently been updated. This is done so that
@@ -1105,9 +1076,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
clause,
)
def _is_local_host_in_room_ignoring_users_txn(
txn: LoggingTransaction,
) -> bool:
def _is_local_host_in_room_ignoring_users_txn(txn):
txn.execute(sql, (room_id, Membership.JOIN, *args))
return bool(txn.fetchone())
@@ -1141,17 +1110,15 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
where_clause="forgotten = 1",
)
async def _background_add_membership_profile(
self, progress: JsonDict, batch_size: int
) -> int:
async def _background_add_membership_profile(self, progress, batch_size):
target_min_stream_id = progress.get(
"target_min_stream_id_inclusive", self._min_stream_order_on_start # type: ignore[attr-defined]
"target_min_stream_id_inclusive", self._min_stream_order_on_start
)
max_stream_id = progress.get(
"max_stream_id_exclusive", self._stream_order_on_start + 1 # type: ignore[attr-defined]
"max_stream_id_exclusive", self._stream_order_on_start + 1
)
def add_membership_profile_txn(txn: LoggingTransaction) -> int:
def add_membership_profile_txn(txn):
sql = """
SELECT stream_ordering, event_id, events.room_id, event_json.json
FROM events
@@ -1215,17 +1182,13 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
return result
async def _background_current_state_membership(
self, progress: JsonDict, batch_size: int
) -> int:
async def _background_current_state_membership(self, progress, batch_size):
"""Update the new membership column on current_state_events.
This works by iterating over all rooms in alphebetical order.
"""
def _background_current_state_membership_txn(
txn: LoggingTransaction, last_processed_room: str
) -> Tuple[int, bool]:
def _background_current_state_membership_txn(txn, last_processed_room):
processed = 0
while processed < batch_size:
txn.execute(
@@ -1279,11 +1242,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
return row_count
class RoomMemberStore(
RoomMemberWorkerStore,
RoomMemberBackgroundUpdateStore,
CacheInvalidationWorkerStore,
):
class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
def __init__(
self,
database: DatabasePool,
@@ -1295,7 +1254,7 @@ class RoomMemberStore(
async def forget(self, user_id: str, room_id: str) -> None:
"""Indicate that user_id wishes to discard history for room_id."""
def f(txn: LoggingTransaction) -> None:
def f(txn):
sql = (
"UPDATE"
" room_memberships"
@@ -1329,5 +1288,5 @@ class _JoinedHostsCache:
# equal to anything else).
state_group: Union[object, int] = attr.Factory(object)
def __len__(self) -> int:
def __len__(self):
return sum(len(v) for v in self.hosts_to_joined_users.values())

View File

@@ -189,7 +189,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
group: int,
state_filter: StateFilter,
) -> Tuple[MutableStateMap[str], bool]:
"""Checks if group is in cache. See `get_state_for_groups`
"""Checks if group is in cache. See `_get_state_for_groups`
Args:
cache: the state group cache to use

View File

@@ -61,9 +61,7 @@ Changes in SCHEMA_VERSION = 68:
Changes in SCHEMA_VERSION = 69:
- We now write to `device_lists_changes_in_room` table.
- We now use a PostgreSQL sequence to generate future txn_ids for
`application_services_txns`. `application_services_state.last_txn` is no longer
updated.
- Use sequence to generate future `application_services_txns.txn_id`s
Changes in SCHEMA_VERSION = 70:
- event_reference_hashes is no longer written to.
@@ -73,7 +71,6 @@ Changes in SCHEMA_VERSION = 70:
SCHEMA_COMPAT_VERSION = (
# We now assume that `device_lists_changes_in_room` has been filled out for
# recent device_list_updates.
# ... and that `application_services_state.last_txn` is not used.
69
)
"""Limit on how far the synapse codebase can be rolled back without breaking db compat

View File

@@ -1,5 +1,4 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,7 +15,6 @@ import logging
from typing import (
TYPE_CHECKING,
Awaitable,
Callable,
Collection,
Dict,
Iterable,
@@ -534,44 +532,6 @@ class StateFilter:
new_all, new_excludes, new_wildcards, new_concrete_keys
)
def must_await_full_state(self, is_mine_id: Callable[[str], bool]) -> bool:
"""Check if we need to wait for full state to complete to calculate this state
If we have a state filter which is completely satisfied even with partial
state, then we don't need to await_full_state before we can return it.
Args:
is_mine_id: a callable which confirms if a given state_key matches a mxid
of a local user
"""
# TODO(faster_joins): it's not entirely clear that this is safe. In particular,
# there may be circumstances in which we return a piece of state that, once we
# resync the state, we discover is invalid. For example: if it turns out that
# the sender of a piece of state wasn't actually in the room, then clearly that
# state shouldn't have been returned.
# We should at least add some tests around this to see what happens.
# if we haven't requested membership events, then it depends on the value of
# 'include_others'
if EventTypes.Member not in self.types:
return self.include_others
# if we're looking for *all* membership events, then we have to wait
member_state_keys = self.types[EventTypes.Member]
if member_state_keys is None:
return True
# otherwise, consider whose membership we are looking for. If it's entirely
# local users, then we don't need to wait.
for state_key in member_state_keys:
if not is_mine_id(state_key):
# remote user
return True
# local users only
return False
_ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True)
_ALL_NON_MEMBER_STATE_FILTER = StateFilter(
@@ -584,7 +544,6 @@ class StateGroupStorage:
"""High level interface to fetching state for event."""
def __init__(self, hs: "HomeServer", stores: "Databases"):
self._is_mine_id = hs.is_mine_id
self.stores = stores
self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)
@@ -627,7 +586,7 @@ class StateGroupStorage:
if not event_ids:
return {}
event_to_groups = await self.get_state_group_for_events(event_ids)
event_to_groups = await self._get_state_group_for_events(event_ids)
groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(groups)
@@ -643,7 +602,7 @@ class StateGroupStorage:
Returns:
Resolves to a map of (type, state_key) -> event_id
"""
group_to_state = await self.get_state_for_groups((state_group,))
group_to_state = await self._get_state_for_groups((state_group,))
return group_to_state[state_group]
@@ -716,13 +675,7 @@ class StateGroupStorage:
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
await_full_state = True
if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
await_full_state = False
event_to_groups = await self.get_state_group_for_events(
event_ids, await_full_state=await_full_state
)
event_to_groups = await self._get_state_group_for_events(event_ids)
groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(
@@ -746,9 +699,7 @@ class StateGroupStorage:
return {event: event_to_state[event] for event in event_ids}
async def get_state_ids_for_events(
self,
event_ids: Collection[str],
state_filter: Optional[StateFilter] = None,
self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
) -> Dict[str, StateMap[str]]:
"""
Get the state dicts corresponding to a list of events, containing the event_ids
@@ -765,13 +716,7 @@ class StateGroupStorage:
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
await_full_state = True
if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
await_full_state = False
event_to_groups = await self.get_state_group_for_events(
event_ids, await_full_state=await_full_state
)
event_to_groups = await self._get_state_group_for_events(event_ids)
groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(
@@ -829,7 +774,7 @@ class StateGroupStorage:
)
return state_map[event_id]
def get_state_for_groups(
def _get_state_for_groups(
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
) -> Awaitable[Dict[int, MutableStateMap[str]]]:
"""Gets the state at each of a list of state groups, optionally
@@ -847,7 +792,7 @@ class StateGroupStorage:
groups, state_filter or StateFilter.all()
)
async def get_state_group_for_events(
async def _get_state_group_for_events(
self,
event_ids: Collection[str],
await_full_state: bool = True,
@@ -857,7 +802,7 @@ class StateGroupStorage:
Args:
event_ids: events to get state groups for
await_full_state: if true, will block if we do not yet have complete
state at these events.
state at this event.
"""
if await_full_state:
await self._partial_state_events_tracker.await_full_state(event_ids)

View File

@@ -1,415 +0,0 @@
from contextlib import contextmanager
from copy import deepcopy
from typing import Any, Dict
from unittest import TestCase
import yaml
from parameterized import parameterized
from pydantic import ValidationError
from synapse.config.oidc2 import (
ClientAuthMethods,
LegacyOIDCProviderModel,
OIDCProviderModel,
)
SAMPLE_CONFIG = yaml.safe_load(
"""
idp_id: my_idp
idp_name: My OpenID provider
idp_icon: "mxc://example.com/blahblahblah"
idp_brand: "brandy"
issuer: "https://accountns.exeample.com"
client_id: "provided-by-your-issuer"
client_secret_jwt_key:
key: DUMMY_PRIVATE_KEY
jwt_header:
alg: ES256
kid: potato123
jwt_payload:
iss: issuer456
client_auth_method: "client_secret_post"
scopes: ["name", "email", "openid"]
authorization_endpoint: https://example.com/auth/authorize?response_mode=form_post
token_endpoint: https://id.example.com/dummy_url_here
jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
user_mapping_provider:
config:
email_template: "{{ user.email }}"
localpart_template: "{{ user.email|localpart_from_email }}"
confirm_localpart: true
attribute_requirements:
- attribute: userGroup
value: "synapseUsers"
"""
)
class PydanticOIDCTestCase(TestCase):
"""Examples to build confidence that pydantic is doing the validation we think
it's doing"""
# Each test gets a dummy config it can change as it sees fit
config: Dict[str, Any]
def setUp(self) -> None:
self.config = deepcopy(SAMPLE_CONFIG)
@contextmanager
def assertRaises(self, *args, **kwargs):
"""To demonstrate the example error messages generated by Pydantic, uncomment
this method."""
with super().assertRaises(*args, **kwargs) as result:
yield result
print()
print(result.exception)
def test_example_config(self):
# Check that parsing the sample config doesn't raise an error.
OIDCProviderModel.parse_obj(self.config)
def test_idp_id(self) -> None:
"""Example of using a Pydantic constr() field without a default."""
# Enforce that idp_id is required.
with self.assertRaises(ValidationError):
del self.config["idp_id"]
OIDCProviderModel.parse_obj(self.config)
# Enforce that idp_id is a string.
for bad_value in 123, None, ["a"], {"a": "b"}:
with self.assertRaises(ValidationError):
self.config["idp_id"] = bad_value
OIDCProviderModel.parse_obj(self.config)
# Enforce a length between 1 and 250.
with self.assertRaises(ValidationError):
self.config["idp_id"] = ""
OIDCProviderModel.parse_obj(self.config)
with self.assertRaises(ValidationError):
self.config["idp_id"] = "a" * 251
OIDCProviderModel.parse_obj(self.config)
# Enforce the regex
with self.assertRaises(ValidationError):
self.config["idp_id"] = "$"
OIDCProviderModel.parse_obj(self.config)
# What happens with a really long string of prohibited characters?
with self.assertRaises(ValidationError):
self.config["idp_id"] = "$" * 500
OIDCProviderModel.parse_obj(self.config)
def test_legacy_model(self) -> None:
"""Example of widening a field's type in a subclass."""
# Check that parsing the sample config doesn't raise an error.
LegacyOIDCProviderModel.parse_obj(self.config)
# Check we have default values for the attributes which have a legacy fallback
del self.config["idp_id"]
del self.config["idp_name"]
model = LegacyOIDCProviderModel.parse_obj(self.config)
self.assertEqual(model.idp_id, "oidc")
self.assertEqual(model.idp_name, "OIDC")
# Check we still reject bad types
for bad_value in 123, [], {}, None:
with self.assertRaises(ValidationError) as e:
self.config["idp_id"] = bad_value
self.config["idp_name"] = bad_value
LegacyOIDCProviderModel.parse_obj(self.config)
# And while we're at it, check that we spot errors in both fields
reported_bad_fields = {item["loc"] for item in e.exception.errors()}
expected_bad_fields = {("idp_id",), ("idp_name",)}
self.assertEqual(
reported_bad_fields, expected_bad_fields, e.exception.errors()
)
def test_issuer(self) -> None:
"""Example of a StrictStr field without a default."""
# Empty and nonempty strings should be accepted.
for good_value in "", "hello", "hello" * 1000, "":
self.config["issuer"] = good_value
OIDCProviderModel.parse_obj(self.config)
# Invalid types should be rejected.
for bad_value in 123, None, ["h", "e", "l", "l", "o"], {"hello": "there"}:
with self.assertRaises(ValidationError):
self.config["issuer"] = bad_value
OIDCProviderModel.parse_obj(self.config)
# A missing issuer should be rejected.
with self.assertRaises(ValidationError):
del self.config["issuer"]
OIDCProviderModel.parse_obj(self.config)
def test_idp_brand(self) -> None:
"""Example of an Optional[StrictStr] field."""
# Empty and nonempty strings should be accepted.
for good_value in "", "hello", "hello" * 1000, "":
self.config["idp_brand"] = good_value
OIDCProviderModel.parse_obj(self.config)
# Invalid types should be rejected.
for bad_value in 123, ["h", "e", "l", "l", "o"], {"hello": "there"}:
with self.assertRaises(ValidationError):
self.config["idp_brand"] = bad_value
OIDCProviderModel.parse_obj(self.config)
# A lack of an idp_brand is fine...
del self.config["idp_brand"]
model = OIDCProviderModel.parse_obj(self.config)
self.assertIsNone(model.idp_brand)
# ... and interpreted the same as an explicit `None`.
self.config["idp_brand"] = None
model = OIDCProviderModel.parse_obj(self.config)
self.assertIsNone(model.idp_brand)
def test_idp_icon(self) -> None:
"""Example of a field with a custom validator."""
# Test that bad types are rejected, even with our validator in place
bad_value: object
for bad_value in None, {}, [], 123, 45.6:
with self.assertRaises(ValidationError):
self.config["idp_icon"] = bad_value
OIDCProviderModel.parse_obj(self.config)
# Test that bad strings are rejected by our validator
for bad_value in "", "notaurl", "https://example.com", "mxc://mxc://mxc://":
with self.assertRaises(ValidationError):
self.config["idp_icon"] = bad_value
OIDCProviderModel.parse_obj(self.config)
def test_discover(self) -> None:
"""Example of a StrictBool field with a default."""
# Booleans are permitted.
for value in True, False:
self.config["discover"] = value
model = OIDCProviderModel.parse_obj(self.config)
self.assertEqual(model.discover, value)
# Invalid types should be rejected.
for bad_value in (
-1.0,
0,
1,
float("nan"),
"yes",
"NO",
"True",
"true",
None,
"None",
"null",
["a"],
{"a": "b"},
):
self.config["discover"] = bad_value
with self.assertRaises(ValidationError):
OIDCProviderModel.parse_obj(self.config)
# A missing value is okay, because this field has a default.
del self.config["discover"]
model = OIDCProviderModel.parse_obj(self.config)
self.assertIs(model.discover, True)
def test_client_auth_method(self) -> None:
"""This is an example of using a Pydantic string enum field."""
# check the allowed values are permitted and deserialise to an enum member
for method in "client_secret_basic", "client_secret_post", "none":
self.config["client_auth_method"] = method
model = OIDCProviderModel.parse_obj(self.config)
self.assertIs(model.client_auth_method, ClientAuthMethods[method])
# check the default applies if no auth method is provided.
del self.config["client_auth_method"]
model = OIDCProviderModel.parse_obj(self.config)
self.assertIs(model.client_auth_method, ClientAuthMethods.client_secret_basic)
# Check invalid types are rejected
for bad_value in 123, ["client_secret_basic"], {"a": 1}, None:
with self.assertRaises(ValidationError):
self.config["client_auth_method"] = bad_value
OIDCProviderModel.parse_obj(self.config)
# Check that disallowed strings are rejected
with self.assertRaises(ValidationError):
self.config["client_auth_method"] = "No, Luke, _I_ am your father!"
OIDCProviderModel.parse_obj(self.config)
def test_scopes(self) -> None:
"""Example of a Tuple[StrictStr] with a default."""
# Check that the parsed object holds a tuple
self.config["scopes"] = []
model = OIDCProviderModel.parse_obj(self.config)
self.assertEqual(model.scopes, ())
# Check a variety of list lengths are accepted.
for good_value in ["aa"], ["hello", "world"], ["a"] * 4, [""] * 20:
self.config["scopes"] = good_value
model = OIDCProviderModel.parse_obj(self.config)
self.assertEqual(model.scopes, tuple(good_value))
# Check invalid types are rejected.
for bad_value in (
"",
"abc",
123,
{},
{"a": 1},
None,
[None],
[["a"]],
[{}],
[456],
):
with self.assertRaises(ValidationError):
self.config["scopes"] = bad_value
OIDCProviderModel.parse_obj(self.config)
# Check that "scopes" may be omitted.
del self.config["scopes"]
model = OIDCProviderModel.parse_obj(self.config)
self.assertEqual(model.scopes, ("openid",))
@parameterized.expand(["authorization_endpoint", "token_endpoint"])
def test_endpoints_required_when_discovery_disabled(self, key: str) -> None:
"""Example of a validator that applies to multiple fields."""
# Test that this field is required if discovery is disabled
self.config["discover"] = False
with self.assertRaises(ValidationError):
self.config[key] = None
OIDCProviderModel.parse_obj(self.config)
with self.assertRaises(ValidationError):
del self.config[key]
OIDCProviderModel.parse_obj(self.config)
# We don't validate that the endpoint is a sensible URL; anything str will do
self.config[key] = "blahblah"
OIDCProviderModel.parse_obj(self.config)
def check_all_cases_pass():
self.config[key] = None
OIDCProviderModel.parse_obj(self.config)
del self.config[key]
OIDCProviderModel.parse_obj(self.config)
self.config[key] = "blahblah"
OIDCProviderModel.parse_obj(self.config)
# With discovery enabled, all three cases are accepted.
self.config["discover"] = True
check_all_cases_pass()
# If not specified, discovery is also on by default.
del self.config["discover"]
check_all_cases_pass()
def test_userinfo_endpoint(self) -> None:
"""Example of a more fiddly validator"""
# This field is required if discovery is disabled and the openid scope
# not requested.
self.assertNotIn("userinfo_endpoint", self.config)
with self.assertRaises(ValidationError):
self.config["discover"] = False
self.config["scopes"] = ()
OIDCProviderModel.parse_obj(self.config)
# Still an error even if other scopes are provided
with self.assertRaises(ValidationError):
self.config["discover"] = False
self.config["scopes"] = ("potato", "tomato")
OIDCProviderModel.parse_obj(self.config)
# Passing an explicit None for userinfo_endpoint should also be an error.
with self.assertRaises(ValidationError):
self.config["discover"] = False
self.config["scopes"] = ()
self.config["userinfo_endpoint"] = None
OIDCProviderModel.parse_obj(self.config)
# No error if we enable discovery.
self.config["discover"] = True
self.config["scopes"] = ()
self.config["userinfo_endpoint"] = None
OIDCProviderModel.parse_obj(self.config)
# No error if we enable the openid scope.
self.config["discover"] = False
self.config["scopes"] = ("openid",)
self.config["userinfo_endpoint"] = None
OIDCProviderModel.parse_obj(self.config)
# No error if we don't specify scopes. (They default to `("openid", )`)
self.config["discover"] = False
del self.config["scopes"]
self.config["userinfo_endpoint"] = None
OIDCProviderModel.parse_obj(self.config)
def test_attribute_requirements(self):
# Example of a field involving a nested model
model = OIDCProviderModel.parse_obj(self.config)
self.assertIsInstance(model.attribute_requirements, tuple)
self.assertEqual(
len(model.attribute_requirements), 1, model.attribute_requirements
)
# Bad types should be rejected
bad_value: object
for bad_value in 123, 456.0, False, None, {}, ["hello"]:
with self.assertRaises(ValidationError):
self.config["attribute_requirements"] = bad_value
OIDCProviderModel.parse_obj(self.config)
# An empty list of requirements is okay, ...
self.config["attribute_requirements"] = []
OIDCProviderModel.parse_obj(self.config)
# ...as is an omitted list of requirements...
del self.config["attribute_requirements"]
OIDCProviderModel.parse_obj(self.config)
# ...but not an explicit None.
with self.assertRaises(ValidationError):
self.config["attribute_requirements"] = None
OIDCProviderModel.parse_obj(self.config)
# Multiple requirements are fine.
self.config["attribute_requirements"] = [{"attribute": "k", "value": "v"}] * 3
model = OIDCProviderModel.parse_obj(self.config)
self.assertEqual(
len(model.attribute_requirements), 3, model.attribute_requirements
)
# The submodel's field types should be enforced too.
with self.assertRaises(ValidationError):
self.config["attribute_requirements"] = [{"attribute": "key", "value": 123}]
OIDCProviderModel.parse_obj(self.config)
with self.assertRaises(ValidationError):
self.config["attribute_requirements"] = [{"attribute": 123, "value": "val"}]
OIDCProviderModel.parse_obj(self.config)
with self.assertRaises(ValidationError):
self.config["attribute_requirements"] = [{"attribute": "a", "value": ["b"]}]
OIDCProviderModel.parse_obj(self.config)
with self.assertRaises(ValidationError):
self.config["attribute_requirements"] = [{"attribute": "a", "value": None}]
OIDCProviderModel.parse_obj(self.config)
# Missing fields in the submodel are an error.
with self.assertRaises(ValidationError):
self.config["attribute_requirements"] = [{"attribute": "a"}]
OIDCProviderModel.parse_obj(self.config)
with self.assertRaises(ValidationError):
self.config["attribute_requirements"] = [{"value": "v"}]
OIDCProviderModel.parse_obj(self.config)
with self.assertRaises(ValidationError):
self.config["attribute_requirements"] = [{}]
OIDCProviderModel.parse_obj(self.config)
# Extra fields in the submodel are an error.
with self.assertRaises(ValidationError):
self.config["attribute_requirements"] = [
{"attribute": "a", "value": "v", "answer": "forty-two"}
]
OIDCProviderModel.parse_obj(self.config)

View File

@@ -17,7 +17,7 @@ from typing import Dict, List, Tuple
from synapse.api.errors import Codes
from synapse.federation.transport.server import BaseFederationServlet
from synapse.federation.transport.server._base import Authenticator, _parse_auth_header
from synapse.federation.transport.server._base import Authenticator
from synapse.http.server import JsonResource, cancellable
from synapse.server import HomeServer
from synapse.types import JsonDict
@@ -112,30 +112,3 @@ class BaseFederationServletCancellationTests(
expect_cancellation=False,
expected_body={"result": True},
)
class BaseFederationAuthorizationTests(unittest.TestCase):
def test_authorization_header(self) -> None:
"""Tests that the Authorization header is parsed correctly."""
# test a "normal" Authorization header
self.assertEqual(
_parse_auth_header(
b'X-Matrix origin=foo,key="ed25519:1",sig="sig",destination="bar"'
),
("foo", "ed25519:1", "sig", "bar"),
)
# test an Authorization with extra spaces, upper-case names, and escaped
# characters
self.assertEqual(
_parse_auth_header(
b'X-Matrix ORIGIN=foo,KEY="ed25\\519:1",SIG="sig",destination="bar"'
),
("foo", "ed25519:1", "sig", "bar"),
)
self.assertEqual(
_parse_auth_header(
b'X-Matrix origin=foo,key="ed25519:1",sig="sig",destination="bar",extra_field=ignored'
),
("foo", "ed25519:1", "sig", "bar"),
)

View File

@@ -434,6 +434,16 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
},
)
# "Complete" a transaction.
# All this really does for us is make an entry in the application_services_state
# database table, which tracks the current stream_token per stream ID per AS.
self.get_success(
self.hs.get_datastores().main.complete_appservice_txn(
0,
interested_appservice,
)
)
# Now, pretend that we receive a large burst of read receipts (300 total) that
# all come in at once.
for i in range(300):

View File

@@ -332,7 +332,6 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
most_recent_prev_event_depth,
) = self.get_success(self.store.get_max_depth_of(prev_event_ids))
# mapping from (type, state_key) -> state_event_id
assert most_recent_prev_event_id is not None
prev_state_map = self.get_success(
self.state_store.get_state_ids_for_event(most_recent_prev_event_id)
)

View File

@@ -2489,5 +2489,4 @@ PURGE_TABLES = [
"room_tags",
# "state_groups", # Current impl leaves orphaned state groups around.
"state_groups_state",
"federation_inbound_events_staging",
]

View File

@@ -14,7 +14,7 @@
import json
import os
import tempfile
from typing import List, cast
from typing import List, Optional, cast
from unittest.mock import Mock
import yaml
@@ -149,12 +149,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
outfile.write(yaml.dump(as_yaml))
self.as_yaml_files.append(as_token)
def _set_state(self, id: str, state: ApplicationServiceState):
def _set_state(
self, id: str, state: ApplicationServiceState, txn: Optional[int] = None
):
return self.db_pool.runOperation(
self.engine.convert_param_style(
"INSERT INTO application_services_state(as_id, state) VALUES(?,?)"
"INSERT INTO application_services_state(as_id, state, last_txn) "
"VALUES(?,?,?)"
),
(id, state.value),
(id, state.value, txn),
)
def _insert_txn(self, as_id, txn_id, events):
@@ -277,6 +280,17 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
self.store.complete_appservice_txn(txn_id=txn_id, service=service)
)
res = self.get_success(
self.db_pool.runQuery(
self.engine.convert_param_style(
"SELECT last_txn FROM application_services_state WHERE as_id=?"
),
(service.id,),
)
)
self.assertEqual(1, len(res))
self.assertEqual(txn_id, res[0][0])
res = self.get_success(
self.db_pool.runQuery(
self.engine.convert_param_style(
@@ -302,13 +316,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
res = self.get_success(
self.db_pool.runQuery(
self.engine.convert_param_style(
"SELECT state FROM application_services_state WHERE as_id=?"
"SELECT last_txn, state FROM application_services_state WHERE as_id=?"
),
(service.id,),
)
)
self.assertEqual(1, len(res))
self.assertEqual(ApplicationServiceState.UP.value, res[0][0])
self.assertEqual(txn_id, res[0][0])
self.assertEqual(ApplicationServiceState.UP.value, res[0][1])
res = self.get_success(
self.db_pool.runQuery(

View File

@@ -129,19 +129,6 @@ class _DummyStore:
async def get_room_version_id(self, room_id):
return RoomVersions.V1.identifier
async def get_state_group_for_events(self, event_ids):
res = {}
for event in event_ids:
res[event] = self._event_to_state_group[event]
return res
async def get_state_for_groups(self, groups):
res = {}
for group in groups:
state = self._group_to_state[group]
res[group] = state
return res
class DictObj(dict):
def __init__(self, **kwargs):