Mirror of https://github.com/element-hq/synapse.git (synced 2025-12-13 01:50:46 +00:00)

Compare commits: dmr/oidc-c...dmr/reject (7 commits)

| SHA1 |
|---|
| 0953cad3e4 |
| ad6a6675bf |
| 21d1347f2c |
| a9fe3350f8 |
| a1adede444 |
| 79f1cef5e4 |
| 8c977edec8 |

CHANGES.md (11 lines changed)
@@ -1,14 +1,3 @@
Synapse 1.59.1 (2022-05-18)
===========================

This release fixes a long-standing issue which could prevent Synapse's user directory from updating properly.

Bugfixes
--------

- Fix a long-standing bug where the user directory background process would fail to make forward progress if a user included a null codepoint in their display name or avatar. Contributed by Nick @ Beeper. ([\#12762](https://github.com/matrix-org/synapse/issues/12762))


Synapse 1.59.0 (2022-05-17)
===========================

@@ -1 +0,0 @@
Preparation for faster-room-join work: return subsets of room state which we already have, immediately.
@@ -1 +0,0 @@
Add support for [MSC3787: Allowing knocks to restricted rooms](https://github.com/matrix-org/matrix-spec-proposals/pull/3787).
@@ -1 +0,0 @@
Remove code which updates unused database column `application_services_state.last_txn`.
@@ -1 +0,0 @@
Add some type hints to datastore.
@@ -1 +0,0 @@
Fix push to dismiss notifications when read on another client. Contributed by @SpiritCroc @ Beeper.
@@ -1 +0,0 @@
Downgrade some OIDC errors to warnings in the logs, to reduce the noise of Sentry reports.
@@ -1 +0,0 @@
Link to the configuration manual from the welcome page of the documentation.
@@ -1 +0,0 @@
Add some type hints to datastore.
@@ -1 +0,0 @@
Add information regarding the `rc_invites` ratelimiting option to the configuration docs.
@@ -1 +0,0 @@
Add documentation for cancellation of request processing.
@@ -1 +0,0 @@
Recommend using docker to run tests against postgres.
@@ -1 +0,0 @@
Tweak the mypy plugin so that `@cached` can accept `on_invalidate=None`.
@@ -1 +0,0 @@
Delete events from the `federation_inbound_events_staging` table when a room is purged through the admin API.
@@ -1 +0,0 @@
Move methods that call `add_push_rule` to the `PushRuleStore` class.
@@ -1 +0,0 @@
Make handling of federation Authorization header (more) compliant with RFC7230.
@@ -1 +0,0 @@
Refactor `resolve_state_groups_for_events` to not pull out full state when no state resolution happens.
@@ -1 +0,0 @@
Give a meaningful error message when a client tries to create a room with an invalid alias localpart.
@@ -1 +0,0 @@
Do not keep going if there are 5 back-to-back background update failures.
@@ -1 +0,0 @@
Fix federation when using the demo scripts.
@@ -1 +0,0 @@
Fix invalid YAML syntax in the example documentation for the `url_preview_accept_language` config option.
debian/changelog (vendored, 6 lines changed)
@@ -1,9 +1,3 @@
matrix-synapse-py3 (1.59.1) stable; urgency=medium

  * New Synapse release 1.59.1.

 -- Synapse Packaging team <packages@matrix.org>  Wed, 18 May 2022 11:41:46 +0100

matrix-synapse-py3 (1.59.0) stable; urgency=medium

  * New Synapse release 1.59.0.

@@ -12,7 +12,6 @@ export PYTHONPATH
echo "$PYTHONPATH"

# Create servers which listen on HTTP at 808x and HTTPS at 848x.
for port in 8080 8081 8082; do
    echo "Starting server on port $port... "

@@ -20,12 +19,10 @@ for port in 8080 8081 8082; do
    mkdir -p demo/$port
    pushd demo/$port || exit

    # Generate the configuration for the homeserver at localhost:848x, note that
    # the homeserver name needs to match the HTTPS listening port for federation
    # to properly work.
    # Generate the configuration for the homeserver at localhost:848x.
    python3 -m synapse.app.homeserver \
        --generate-config \
        --server-name "localhost:$https_port" \
        --server-name "localhost:$port" \
        --config-path "$port.config" \
        --report-stats no

@@ -89,7 +89,6 @@
  - [Database Schemas](development/database_schema.md)
  - [Experimental features](development/experimental_features.md)
  - [Synapse Architecture]()
    - [Cancellation](development/synapse_architecture/cancellation.md)
  - [Log Contexts](log_contexts.md)
  - [Replication](replication.md)
  - [TCP Replication](tcp_replication.md)

@@ -206,32 +206,7 @@ This means that we need to run our unit tests against PostgreSQL too. Our CI does
this automatically for pull requests and release candidates, but it's sometimes
useful to reproduce this locally.

#### Using Docker

The easiest way to do so is to run Postgres via a docker container. In one
terminal:

```shell
docker run --rm -e POSTGRES_PASSWORD=mysecretpassword -e POSTGRES_USER=postgres -e POSTGRES_DB=postgress -p 5432:5432 postgres:14
```

If you see an error like

```
docker: Error response from daemon: driver failed programming external connectivity on endpoint nice_ride (b57bbe2e251b70015518d00c9981e8cb8346b5c785250341a6c53e3c899875f1): Error starting userland proxy: listen tcp4 0.0.0.0:5432: bind: address already in use.
```

then something is already bound to port 5432. You're probably already running postgres locally.

Once you have a postgres server running, invoke `trial` in a second terminal:

```shell
SYNAPSE_POSTGRES=1 SYNAPSE_POSTGRES_HOST=127.0.0.1 SYNAPSE_POSTGRES_USER=postgres SYNAPSE_POSTGRES_PASSWORD=mysecretpassword poetry run trial tests
```

#### Using an existing Postgres installation

If you have postgres already installed on your system, you can run `trial` with the
To do so, [configure Postgres](../postgres.md) and run `trial` with the
following environment variables matching your configuration:

- `SYNAPSE_POSTGRES` to anything nonempty

@@ -254,8 +229,8 @@ You don't need to specify the host, user, port or password if your Postgres
server is set to authenticate you over the UNIX socket (i.e. if the `psql` command
works without further arguments).

Your Postgres account needs to be able to create databases; see the postgres
docs for [`ALTER ROLE`](https://www.postgresql.org/docs/current/sql-alterrole.html).
Your Postgres account needs to be able to create databases.
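For instance (an illustrative sketch, not part of the original guide; the role name is a placeholder), a Postgres superuser could grant that permission with:

```shell
# Substitute the role you run `trial` as.
psql -c "ALTER ROLE synapse_user CREATEDB;"
```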
## Run the integration tests ([Sytest](https://github.com/matrix-org/sytest)).
@@ -5,7 +5,7 @@
Requires you to have a [Synapse development environment setup](https://matrix-org.github.io/synapse/develop/development/contributing_guide.html#4-install-the-dependencies).

The demo setup allows running three federation Synapse servers, with server
names `localhost:8480`, `localhost:8481`, and `localhost:8482`.
names `localhost:8080`, `localhost:8081`, and `localhost:8082`.

You can access them via any Matrix client over HTTP at `localhost:8080`,
`localhost:8081`, and `localhost:8082` or over HTTPS at `localhost:8480`,

@@ -20,10 +20,9 @@ and the servers are configured in a highly insecure way, including:
The servers are configured to store their data under `demo/8080`, `demo/8081`, and
`demo/8082`. This includes configuration, logs, SQLite databases, and media.

Note that when joining a public room on a different homeserver via "#foo:bar.net",
then you are (in the current implementation) joining a room with room_id "foo".
This means that it won't work if your homeserver already has a room with that
name.
Note that when joining a public room on a different HS via "#foo:bar.net", then
you are (in the current impl) joining a room with room_id "foo". This means that
it won't work if your HS already has a room with that name.

## Using the demo scripts

@@ -1,392 +0,0 @@
# Cancellation
Sometimes, requests take a long time to service and clients disconnect
before Synapse produces a response. To avoid wasting resources, Synapse
can cancel request processing for select endpoints marked with the
`@cancellable` decorator.

Synapse makes use of Twisted's `Deferred.cancel()` feature to make
cancellation work. The `@cancellable` decorator does nothing by itself
and merely acts as a flag, signalling to developers and other code alike
that a method can be cancelled.
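As a rough sketch (illustrative only; the exact implementation may differ), such a flag-style decorator just tags the function and returns it unchanged:

```python
from typing import Any, Callable, TypeVar

F = TypeVar("F", bound=Callable[..., Any])


def cancellable(function: F) -> F:
    """Mark a request handler as safe to cancel when the client disconnects.

    The decorator does not change behaviour; request-handling code checks for
    the attribute before calling `Deferred.cancel()`.
    """
    function.cancellable = True  # type: ignore[attr-defined]
    return function
```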
## Enabling cancellation for an endpoint
1. Check that the endpoint method, and any `async` functions in its call
   tree handle cancellation correctly. See
   [Handling cancellation correctly](#handling-cancellation-correctly)
   for a list of things to look out for.
2. Add the `@cancellable` decorator to the `on_GET/POST/PUT/DELETE`
   method. It's not recommended to make non-`GET` methods cancellable,
   since cancellation midway through some database updates is less
   likely to be handled correctly.

## Mechanics
There are two stages to cancellation: downward propagation of a
`cancel()` call, followed by upwards propagation of a `CancelledError`
out of a blocked `await`.
Both Twisted and asyncio have a cancellation mechanism.

|         | Method              | Exception                               | Exception inherits from |
|---------|---------------------|-----------------------------------------|-------------------------|
| Twisted | `Deferred.cancel()` | `twisted.internet.defer.CancelledError` | `Exception` (!)         |
| asyncio | `Task.cancel()`     | `asyncio.CancelledError`                | `BaseException`         |

### Deferred.cancel()
When Synapse starts handling a request, it runs the async method
responsible for handling it using `defer.ensureDeferred`, which returns
a `Deferred`. For example:

```python
def do_something() -> Deferred[None]:
    ...

@cancellable
async def on_GET() -> Tuple[int, JsonDict]:
    d = make_deferred_yieldable(do_something())
    await d
    return 200, {}

request = defer.ensureDeferred(on_GET())
```

When a client disconnects early, Synapse checks for the presence of the
`@cancellable` decorator on `on_GET`. Since `on_GET` is cancellable,
`Deferred.cancel()` is called on the `Deferred` from
`defer.ensureDeferred`, ie. `request`. Twisted knows which `Deferred`
`request` is waiting on and passes the `cancel()` call on to `d`.

The `Deferred` being waited on, `d`, may have its own handling for
`cancel()` and pass the call on to other `Deferred`s.

Eventually, a `Deferred` handles the `cancel()` call by resolving itself
with a `CancelledError`.

### CancelledError
The `CancelledError` gets raised out of the `await` and bubbles up, as
per normal Python exception handling.

## Handling cancellation correctly
In general, when writing code that might be subject to cancellation, two
things must be considered:
* The effect of `CancelledError`s raised out of `await`s.
* The effect of `Deferred`s being `cancel()`ed.

Examples of code that handles cancellation incorrectly include:
* `try-except` blocks which swallow `CancelledError`s.
* Code that shares the same `Deferred`, which may be cancelled, between
  multiple requests.
* Code that starts some processing that's exempt from cancellation, but
  uses a logging context from cancellable code. The logging context
  will be finished upon cancellation, while the uncancelled processing
  is still using it.

Some common patterns are listed below in more detail.

### `async` function calls
Most functions in Synapse are relatively straightforward from a
cancellation standpoint: they don't do anything with `Deferred`s and
purely call and `await` other `async` functions.

An `async` function handles cancellation correctly if its own code
handles cancellation correctly and all the `async` functions it calls
handle cancellation correctly. For example:
```python
async def do_two_things() -> None:
    check_something()
    await do_something()
    await do_something_else()
```
`do_two_things` handles cancellation correctly if `do_something` and
`do_something_else` handle cancellation correctly.

That is, when checking whether a function handles cancellation
correctly, its implementation and all its `async` function calls need to
be checked, recursively.

As `check_something` is not `async`, it does not need to be checked.

### CancelledErrors
Because Twisted's `CancelledError`s are `Exception`s, it's easy to
accidentally catch and suppress them. Care must be taken to ensure that
`CancelledError`s are allowed to propagate upwards.

<table width="100%">
<tr>
<td width="50%" valign="top">

**Bad**:
```python
try:
    await do_something()
except Exception:
    # `CancelledError` gets swallowed here.
    logger.info(...)
```
</td>
<td width="50%" valign="top">

**Good**:
```python
try:
    await do_something()
except CancelledError:
    raise
except Exception:
    logger.info(...)
```
</td>
</tr>
<tr>
<td width="50%" valign="top">

**OK**:
```python
try:
    check_something()
    # A `CancelledError` won't ever be raised here.
except Exception:
    logger.info(...)
```
</td>
<td width="50%" valign="top">

**Good**:
```python
try:
    await do_something()
except ValueError:
    logger.info(...)
```
</td>
</tr>
</table>

#### defer.gatherResults
`defer.gatherResults` produces a `Deferred` which:
* broadcasts `cancel()` calls to every `Deferred` being waited on.
* wraps the first exception it sees in a `FirstError`.

Together, this means that `CancelledError`s will be wrapped in
a `FirstError` unless unwrapped. Such `FirstError`s are liable to be
swallowed, so they must be unwrapped.

<table width="100%">
<tr>
<td width="50%" valign="top">

**Bad**:
```python
async def do_something() -> None:
    await make_deferred_yieldable(
        defer.gatherResults([...], consumeErrors=True)
    )

try:
    await do_something()
except CancelledError:
    raise
except Exception:
    # `FirstError(CancelledError)` gets swallowed here.
    logger.info(...)
```

</td>
<td width="50%" valign="top">

**Good**:
```python
async def do_something() -> None:
    await make_deferred_yieldable(
        defer.gatherResults([...], consumeErrors=True)
    ).addErrback(unwrapFirstError)

try:
    await do_something()
except CancelledError:
    raise
except Exception:
    logger.info(...)
```
</td>
</tr>
</table>

### Creation of `Deferred`s
If a function creates a `Deferred`, the effect of cancelling it must be considered. `Deferred`s that get shared are likely to have unintended behaviour when cancelled.

<table width="100%">
<tr>
<td width="50%" valign="top">

**Bad**:
```python
cache: Dict[str, Deferred[None]] = {}

def wait_for_room(room_id: str) -> Deferred[None]:
    deferred = cache.get(room_id)
    if deferred is None:
        deferred = Deferred()
        cache[room_id] = deferred
    # `deferred` can have multiple waiters.
    # All of them will observe a `CancelledError`
    # if any one of them is cancelled.
    return make_deferred_yieldable(deferred)

# Request 1
await wait_for_room("!aAAaaAaaaAAAaAaAA:matrix.org")
# Request 2
await wait_for_room("!aAAaaAaaaAAAaAaAA:matrix.org")
```
</td>
<td width="50%" valign="top">

**Good**:
```python
cache: Dict[str, Deferred[None]] = {}

def wait_for_room(room_id: str) -> Deferred[None]:
    deferred = cache.get(room_id)
    if deferred is None:
        deferred = Deferred()
        cache[room_id] = deferred
    # `deferred` will never be cancelled now.
    # A `CancelledError` will still come out of
    # the `await`.
    # `delay_cancellation` may also be used.
    return make_deferred_yieldable(stop_cancellation(deferred))

# Request 1
await wait_for_room("!aAAaaAaaaAAAaAaAA:matrix.org")
# Request 2
await wait_for_room("!aAAaaAaaaAAAaAaAA:matrix.org")
```
</td>
</tr>
<tr>
<td width="50%" valign="top">
</td>
<td width="50%" valign="top">

**Good**:
```python
cache: Dict[str, List[Deferred[None]]] = {}

def wait_for_room(room_id: str) -> Deferred[None]:
    if room_id not in cache:
        cache[room_id] = []
    # Each request gets its own `Deferred` to wait on.
    deferred = Deferred()
    cache[room_id].append(deferred)
    return make_deferred_yieldable(deferred)

# Request 1
await wait_for_room("!aAAaaAaaaAAAaAaAA:matrix.org")
# Request 2
await wait_for_room("!aAAaaAaaaAAAaAaAA:matrix.org")
```
</td>
</table>

### Uncancelled processing
Some `async` functions may kick off some `async` processing which is
intentionally protected from cancellation, by `stop_cancellation` or
other means. If the `async` processing inherits the logcontext of the
request which initiated it, care must be taken to ensure that the
logcontext is not finished before the `async` processing completes.

<table width="100%">
<tr>
<td width="50%" valign="top">

**Bad**:
```python
cache: Optional[ObservableDeferred[None]] = None

async def do_something_else(
    to_resolve: Deferred[None]
) -> None:
    await ...
    logger.info("done!")
    to_resolve.callback(None)

async def do_something() -> None:
    if not cache:
        to_resolve = Deferred()
        cache = ObservableDeferred(to_resolve)
        # `do_something_else` will never be cancelled and
        # can outlive the `request-1` logging context.
        run_in_background(do_something_else, to_resolve)

    await make_deferred_yieldable(cache.observe())

with LoggingContext("request-1"):
    await do_something()
```
</td>
<td width="50%" valign="top">

**Good**:
```python
cache: Optional[ObservableDeferred[None]] = None

async def do_something_else(
    to_resolve: Deferred[None]
) -> None:
    await ...
    logger.info("done!")
    to_resolve.callback(None)

async def do_something() -> None:
    if not cache:
        to_resolve = Deferred()
        cache = ObservableDeferred(to_resolve)
        run_in_background(do_something_else, to_resolve)
        # We'll wait until `do_something_else` is
        # done before raising a `CancelledError`.
        await make_deferred_yieldable(
            delay_cancellation(cache.observe())
        )
    else:
        await make_deferred_yieldable(cache.observe())

with LoggingContext("request-1"):
    await do_something()
```
</td>
</tr>
<tr>
<td width="50%">

**OK**:
```python
cache: Optional[ObservableDeferred[None]] = None

async def do_something_else(
    to_resolve: Deferred[None]
) -> None:
    await ...
    logger.info("done!")
    to_resolve.callback(None)

async def do_something() -> None:
    if not cache:
        to_resolve = Deferred()
        cache = ObservableDeferred(to_resolve)
        # `do_something_else` will get its own independent
        # logging context. `request-1` will not count any
        # metrics from `do_something_else`.
        run_as_background_process(
            "do_something_else",
            do_something_else,
            to_resolve,
        )

    await make_deferred_yieldable(cache.observe())

with LoggingContext("request-1"):
    await do_something()
```
</td>
<td width="50%">
</td>
</tr>
</table>

@@ -1194,7 +1194,7 @@ For more information on using Synapse with Postgres,
see [here](../../postgres.md).

Example SQLite configuration:
```yaml
```
database:
  name: sqlite3
  args:

@@ -1202,7 +1202,7 @@ database:
```

Example Postgres configuration:
```yaml
```
database:
  name: psycopg2
  txn_limit: 10000

@@ -1357,20 +1357,6 @@ This option sets ratelimiting how often invites can be sent in a room or to a
specific user. `per_room` defaults to `per_second: 0.3`, `burst_count: 10` and
`per_user` defaults to `per_second: 0.003`, `burst_count: 5`.

Client requests that invite user(s) when [creating a
room](https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3createroom)
will count against the `rc_invites.per_room` limit, whereas
client requests to [invite a single user to a
room](https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidinvite)
will count against both the `rc_invites.per_user` and `rc_invites.per_room` limits.

Federation requests to invite a user will count against the `rc_invites.per_user`
limit only, as Synapse presumes ratelimiting by room will be done by the sending server.

The `rc_invites.per_user` limit applies to the *receiver* of the invite, rather than the
sender, meaning that a `rc_invite.per_user.burst_count` of 5 mandates that a single user
cannot *receive* more than a burst of 5 invites at a time.

Example configuration:
```yaml
rc_invites:
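  # Illustration only (not part of the diff): the example above is cut short by
  # the hunk; using the default values quoted in the prose it would continue
  # roughly as follows.
  per_room:
    per_second: 0.3
    burst_count: 10
  per_user:
    per_second: 0.003
    burst_count: 5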
@@ -1679,10 +1665,10 @@ Defaults to "en".
Example configuration:
```yaml
url_preview_accept_language:
  - 'en-UK'
  - 'en-US;q=0.9'
  - 'fr;q=0.8'
  - '*;q=0.7'
  - en-UK
  - en-US;q=0.9
  - fr;q=0.8
  - *;q=0.7
```
----
Config option: `oembed`

@@ -7,10 +7,10 @@ team.
## Installing and using Synapse

This documentation covers topics for **installation**, **configuration** and
**maintenance** of your Synapse process:
**maintainence** of your Synapse process:

* Learn how to [install](setup/installation.md) and
  [configure](usage/configuration/config_documentation.md) your own instance, perhaps with [Single
  [configure](usage/configuration/index.html) your own instance, perhaps with [Single
  Sign-On](usage/configuration/user_authentication/index.html).

* See how to [upgrade](upgrade.md) between Synapse versions.

@@ -65,7 +65,7 @@ following documentation:

Want to help keep Synapse going but don't know how to code? Synapse is a
[Matrix.org Foundation](https://matrix.org) project. Consider becoming a
supporter on [Liberapay](https://liberapay.com/matrixdotorg),
supportor on [Liberapay](https://liberapay.com/matrixdotorg),
[Patreon](https://patreon.com/matrixdotorg) or through
[PayPal](https://paypal.me/matrixdotorg) via a one-time donation.

mypy.ini (11 lines changed)
@@ -1,6 +1,6 @@
[mypy]
namespace_packages = True
plugins = pydantic.mypy, mypy_zope:plugin, scripts-dev/mypy_synapse_plugin.py
plugins = mypy_zope:plugin, scripts-dev/mypy_synapse_plugin.py
follow_imports = normal
check_untyped_defs = True
show_error_codes = True

@@ -27,6 +27,9 @@ exclude = (?x)
  |synapse/storage/databases/__init__.py
  |synapse/storage/databases/main/cache.py
  |synapse/storage/databases/main/devices.py
  |synapse/storage/databases/main/event_federation.py
  |synapse/storage/databases/main/push_rule.py
  |synapse/storage/databases/main/roommember.py
  |synapse/storage/schema/

  |tests/api/test_auth.py

@@ -86,12 +89,6 @@ exclude = (?x)
  |tests/utils.py
  )$

[pydantic-mypy]
init_forbid_extra = True
init_typed = True
warn_required_dynamic_aliases = True
warn_untyped_fields = True

[mypy-synapse._scripts.*]
disallow_untyped_defs = True

poetry.lock (generated, 54 lines changed)
@@ -778,21 +778,6 @@ category = "main"
|
||||
optional = false
|
||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
|
||||
|
||||
[[package]]
|
||||
name = "pydantic"
|
||||
version = "1.9.0"
|
||||
description = "Data validation and settings management using python 3.6 type hinting"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.6.1"
|
||||
|
||||
[package.dependencies]
|
||||
typing-extensions = ">=3.7.4.3"
|
||||
|
||||
[package.extras]
|
||||
dotenv = ["python-dotenv (>=0.10.4)"]
|
||||
email = ["email-validator (>=1.0.3)"]
|
||||
|
||||
[[package]]
|
||||
name = "pyflakes"
|
||||
version = "2.4.0"
|
||||
@@ -1578,7 +1563,7 @@ url_preview = ["lxml"]
|
||||
[metadata]
|
||||
lock-version = "1.1"
|
||||
python-versions = "^3.7.1"
|
||||
content-hash = "54ec27d5187386653b8d0d13ed843f86ae68b3ebbee633c82dfffc7605b99f74"
|
||||
content-hash = "d39d5ac5d51c014581186b7691999b861058b569084c525523baf70b77f292b1"
|
||||
|
||||
[metadata.files]
|
||||
attrs = [
|
||||
@@ -2266,43 +2251,6 @@ pycparser = [
|
||||
{file = "pycparser-2.21-py2.py3-none-any.whl", hash = "sha256:8ee45429555515e1f6b185e78100aea234072576aa43ab53aefcae078162fca9"},
|
||||
{file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"},
|
||||
]
|
||||
pydantic = [
|
||||
{file = "pydantic-1.9.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:cb23bcc093697cdea2708baae4f9ba0e972960a835af22560f6ae4e7e47d33f5"},
|
||||
{file = "pydantic-1.9.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1d5278bd9f0eee04a44c712982343103bba63507480bfd2fc2790fa70cd64cf4"},
|
||||
{file = "pydantic-1.9.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ab624700dc145aa809e6f3ec93fb8e7d0f99d9023b713f6a953637429b437d37"},
|
||||
{file = "pydantic-1.9.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c8d7da6f1c1049eefb718d43d99ad73100c958a5367d30b9321b092771e96c25"},
|
||||
{file = "pydantic-1.9.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:3c3b035103bd4e2e4a28da9da7ef2fa47b00ee4a9cf4f1a735214c1bcd05e0f6"},
|
||||
{file = "pydantic-1.9.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:3011b975c973819883842c5ab925a4e4298dffccf7782c55ec3580ed17dc464c"},
|
||||
{file = "pydantic-1.9.0-cp310-cp310-win_amd64.whl", hash = "sha256:086254884d10d3ba16da0588604ffdc5aab3f7f09557b998373e885c690dd398"},
|
||||
{file = "pydantic-1.9.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:0fe476769acaa7fcddd17cadd172b156b53546ec3614a4d880e5d29ea5fbce65"},
|
||||
{file = "pydantic-1.9.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c8e9dcf1ac499679aceedac7e7ca6d8641f0193c591a2d090282aaf8e9445a46"},
|
||||
{file = "pydantic-1.9.0-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d1e4c28f30e767fd07f2ddc6f74f41f034d1dd6bc526cd59e63a82fe8bb9ef4c"},
|
||||
{file = "pydantic-1.9.0-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:c86229333cabaaa8c51cf971496f10318c4734cf7b641f08af0a6fbf17ca3054"},
|
||||
{file = "pydantic-1.9.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:c0727bda6e38144d464daec31dff936a82917f431d9c39c39c60a26567eae3ed"},
|
||||
{file = "pydantic-1.9.0-cp36-cp36m-win_amd64.whl", hash = "sha256:dee5ef83a76ac31ab0c78c10bd7d5437bfdb6358c95b91f1ba7ff7b76f9996a1"},
|
||||
{file = "pydantic-1.9.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:d9c9bdb3af48e242838f9f6e6127de9be7063aad17b32215ccc36a09c5cf1070"},
|
||||
{file = "pydantic-1.9.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2ee7e3209db1e468341ef41fe263eb655f67f5c5a76c924044314e139a1103a2"},
|
||||
{file = "pydantic-1.9.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0b6037175234850ffd094ca77bf60fb54b08b5b22bc85865331dd3bda7a02fa1"},
|
||||
{file = "pydantic-1.9.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:b2571db88c636d862b35090ccf92bf24004393f85c8870a37f42d9f23d13e032"},
|
||||
{file = "pydantic-1.9.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:8b5ac0f1c83d31b324e57a273da59197c83d1bb18171e512908fe5dc7278a1d6"},
|
||||
{file = "pydantic-1.9.0-cp37-cp37m-win_amd64.whl", hash = "sha256:bbbc94d0c94dd80b3340fc4f04fd4d701f4b038ebad72c39693c794fd3bc2d9d"},
|
||||
{file = "pydantic-1.9.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e0896200b6a40197405af18828da49f067c2fa1f821491bc8f5bde241ef3f7d7"},
|
||||
{file = "pydantic-1.9.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7bdfdadb5994b44bd5579cfa7c9b0e1b0e540c952d56f627eb227851cda9db77"},
|
||||
{file = "pydantic-1.9.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:574936363cd4b9eed8acdd6b80d0143162f2eb654d96cb3a8ee91d3e64bf4cf9"},
|
||||
{file = "pydantic-1.9.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c556695b699f648c58373b542534308922c46a1cda06ea47bc9ca45ef5b39ae6"},
|
||||
{file = "pydantic-1.9.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:f947352c3434e8b937e3aa8f96f47bdfe6d92779e44bb3f41e4c213ba6a32145"},
|
||||
{file = "pydantic-1.9.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:5e48ef4a8b8c066c4a31409d91d7ca372a774d0212da2787c0d32f8045b1e034"},
|
||||
{file = "pydantic-1.9.0-cp38-cp38-win_amd64.whl", hash = "sha256:96f240bce182ca7fe045c76bcebfa0b0534a1bf402ed05914a6f1dadff91877f"},
|
||||
{file = "pydantic-1.9.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:815ddebb2792efd4bba5488bc8fde09c29e8ca3227d27cf1c6990fc830fd292b"},
|
||||
{file = "pydantic-1.9.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6c5b77947b9e85a54848343928b597b4f74fc364b70926b3c4441ff52620640c"},
|
||||
{file = "pydantic-1.9.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4c68c3bc88dbda2a6805e9a142ce84782d3930f8fdd9655430d8576315ad97ce"},
|
||||
{file = "pydantic-1.9.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5a79330f8571faf71bf93667d3ee054609816f10a259a109a0738dac983b23c3"},
|
||||
{file = "pydantic-1.9.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:f5a64b64ddf4c99fe201ac2724daada8595ada0d102ab96d019c1555c2d6441d"},
|
||||
{file = "pydantic-1.9.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a733965f1a2b4090a5238d40d983dcd78f3ecea221c7af1497b845a9709c1721"},
|
||||
{file = "pydantic-1.9.0-cp39-cp39-win_amd64.whl", hash = "sha256:2cc6a4cb8a118ffec2ca5fcb47afbacb4f16d0ab8b7350ddea5e8ef7bcc53a16"},
|
||||
{file = "pydantic-1.9.0-py3-none-any.whl", hash = "sha256:085ca1de245782e9b46cefcf99deecc67d418737a1fd3f6a4f511344b613a5b3"},
|
||||
{file = "pydantic-1.9.0.tar.gz", hash = "sha256:742645059757a56ecd886faf4ed2441b9c0cd406079c2b4bee51bcc3fbcd510a"},
|
||||
]
|
||||
pyflakes = [
|
||||
{file = "pyflakes-2.4.0-py2.py3-none-any.whl", hash = "sha256:3bb3a3f256f4b7968c9c788781e4ff07dce46bdf12339dcda61053375426ee2e"},
|
||||
{file = "pyflakes-2.4.0.tar.gz", hash = "sha256:05a85c2872edf37a4ed30b0cce2f6093e1d0581f8c19d7393122da7e25b2b24c"},
|
||||
|
||||
@@ -54,7 +54,7 @@ skip_gitignore = true
[tool.poetry]
name = "matrix-synapse"
version = "1.59.1"
version = "1.59.0"
description = "Homeserver for the Matrix decentralised comms protocol"
authors = ["Matrix.org Team and Contributors <packages@matrix.org>"]
license = "Apache-2.0"

@@ -182,7 +182,6 @@ hiredis = { version = "*", optional = true }
Pympler = { version = "*", optional = true }
parameterized = { version = ">=0.7.4", optional = true }
idna = { version = ">=2.5", optional = true }
pydantic = ">=1.9.0"

[tool.poetry.extras]
# NB: Packages that should be part of `pip install matrix-synapse[all]` need to be specified
@@ -21,7 +21,7 @@ from typing import Callable, Optional, Type
from mypy.nodes import ARG_NAMED_OPT
from mypy.plugin import MethodSigContext, Plugin
from mypy.typeops import bind_self
from mypy.types import CallableType, NoneType, UnionType
from mypy.types import CallableType, NoneType


class SynapsePlugin(Plugin):

@@ -72,20 +72,13 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:

    # Third, we add an optional "on_invalidate" argument.
    #
    # This is either
    # - a callable which accepts no input and returns nothing, or
    # - None.
    calltyp = UnionType(
        [
            NoneType(),
            CallableType(
                arg_types=[],
                arg_kinds=[],
                arg_names=[],
                ret_type=NoneType(),
                fallback=ctx.api.named_generic_type("builtins.function", []),
            ),
        ]
    # This is a callable which accepts no input and returns nothing.
    calltyp = CallableType(
        arg_types=[],
        arg_kinds=[],
        arg_names=[],
        ret_type=NoneType(),
        fallback=ctx.api.named_generic_type("builtins.function", []),
    )

    arg_types.append(calltyp)

@@ -102,7 +95,7 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:


def plugin(version: str) -> Type[SynapsePlugin]:
    # This is the entry point of the plugin, and lets us deal with the fact
    # This is the entry point of the plugin, and let's us deal with the fact
    # that the mypy plugin interface is *not* stable by looking at the version
    # string.
    #

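The `UnionType` above corresponds to the changelog entry about letting `@cached` accept `on_invalidate=None`. As a rough, hypothetical sketch (method name invented for illustration), the call signature the plugin is modelling looks like:

```python
from typing import Callable, Optional

# A @cached-style method gains an optional `on_invalidate` keyword argument:
# either a zero-argument callback, or (with the UnionType above) None.
async def get_profile(
    user_id: str,
    on_invalidate: Optional[Callable[[], None]] = None,
) -> dict:
    ...


# Both calls should now satisfy mypy:
#     await get_profile("@alice:example.org", on_invalidate=lambda: None)
#     await get_profile("@alice:example.org", on_invalidate=None)
```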
@@ -65,8 +65,6 @@ class JoinRules:
    PRIVATE: Final = "private"
    # As defined for MSC3083.
    RESTRICTED: Final = "restricted"
    # As defined for MSC3787.
    KNOCK_RESTRICTED: Final = "knock_restricted"


class RestrictedJoinRuleTypes:

@@ -81,9 +81,6 @@ class RoomVersion:
    msc2716_historical: bool
    # MSC2716: Adds support for redacting "insertion", "chunk", and "marker" events
    msc2716_redactions: bool
    # MSC3787: Adds support for a `knock_restricted` join rule, mixing concepts of
    # knocks and restricted join rules into the same join condition.
    msc3787_knock_restricted_join_rule: bool


class RoomVersions:

@@ -102,7 +99,6 @@ class RoomVersions:
|
||||
msc2403_knocking=False,
|
||||
msc2716_historical=False,
|
||||
msc2716_redactions=False,
|
||||
msc3787_knock_restricted_join_rule=False,
|
||||
)
|
||||
V2 = RoomVersion(
|
||||
"2",
|
||||
@@ -119,7 +115,6 @@ class RoomVersions:
|
||||
msc2403_knocking=False,
|
||||
msc2716_historical=False,
|
||||
msc2716_redactions=False,
|
||||
msc3787_knock_restricted_join_rule=False,
|
||||
)
|
||||
V3 = RoomVersion(
|
||||
"3",
|
||||
@@ -136,7 +131,6 @@ class RoomVersions:
|
||||
msc2403_knocking=False,
|
||||
msc2716_historical=False,
|
||||
msc2716_redactions=False,
|
||||
msc3787_knock_restricted_join_rule=False,
|
||||
)
|
||||
V4 = RoomVersion(
|
||||
"4",
|
||||
@@ -153,7 +147,6 @@ class RoomVersions:
|
||||
msc2403_knocking=False,
|
||||
msc2716_historical=False,
|
||||
msc2716_redactions=False,
|
||||
msc3787_knock_restricted_join_rule=False,
|
||||
)
|
||||
V5 = RoomVersion(
|
||||
"5",
|
||||
@@ -170,7 +163,6 @@ class RoomVersions:
|
||||
msc2403_knocking=False,
|
||||
msc2716_historical=False,
|
||||
msc2716_redactions=False,
|
||||
msc3787_knock_restricted_join_rule=False,
|
||||
)
|
||||
V6 = RoomVersion(
|
||||
"6",
|
||||
@@ -187,7 +179,6 @@ class RoomVersions:
|
||||
msc2403_knocking=False,
|
||||
msc2716_historical=False,
|
||||
msc2716_redactions=False,
|
||||
msc3787_knock_restricted_join_rule=False,
|
||||
)
|
||||
MSC2176 = RoomVersion(
|
||||
"org.matrix.msc2176",
|
||||
@@ -204,7 +195,6 @@ class RoomVersions:
|
||||
msc2403_knocking=False,
|
||||
msc2716_historical=False,
|
||||
msc2716_redactions=False,
|
||||
msc3787_knock_restricted_join_rule=False,
|
||||
)
|
||||
V7 = RoomVersion(
|
||||
"7",
|
||||
@@ -221,7 +211,6 @@ class RoomVersions:
|
||||
msc2403_knocking=True,
|
||||
msc2716_historical=False,
|
||||
msc2716_redactions=False,
|
||||
msc3787_knock_restricted_join_rule=False,
|
||||
)
|
||||
V8 = RoomVersion(
|
||||
"8",
|
||||
@@ -238,7 +227,6 @@ class RoomVersions:
|
||||
msc2403_knocking=True,
|
||||
msc2716_historical=False,
|
||||
msc2716_redactions=False,
|
||||
msc3787_knock_restricted_join_rule=False,
|
||||
)
|
||||
V9 = RoomVersion(
|
||||
"9",
|
||||
@@ -255,7 +243,6 @@ class RoomVersions:
|
||||
msc2403_knocking=True,
|
||||
msc2716_historical=False,
|
||||
msc2716_redactions=False,
|
||||
msc3787_knock_restricted_join_rule=False,
|
||||
)
|
||||
MSC2716v3 = RoomVersion(
|
||||
"org.matrix.msc2716v3",
|
||||
@@ -272,24 +259,6 @@ class RoomVersions:
|
||||
msc2403_knocking=True,
|
||||
msc2716_historical=True,
|
||||
msc2716_redactions=True,
|
||||
msc3787_knock_restricted_join_rule=False,
|
||||
)
|
||||
MSC3787 = RoomVersion(
|
||||
"org.matrix.msc3787",
|
||||
RoomDisposition.UNSTABLE,
|
||||
EventFormatVersions.V3,
|
||||
StateResolutionVersions.V2,
|
||||
enforce_key_validity=True,
|
||||
special_case_aliases_auth=False,
|
||||
strict_canonicaljson=True,
|
||||
limit_notifications_power_levels=True,
|
||||
msc2176_redaction_rules=False,
|
||||
msc3083_join_rules=True,
|
||||
msc3375_redaction_rules=True,
|
||||
msc2403_knocking=True,
|
||||
msc2716_historical=False,
|
||||
msc2716_redactions=False,
|
||||
msc3787_knock_restricted_join_rule=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -307,7 +276,6 @@ KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = {
|
||||
RoomVersions.V8,
|
||||
RoomVersions.V9,
|
||||
RoomVersions.MSC2716v3,
|
||||
RoomVersions.MSC3787,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,240 +0,0 @@
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Mapping, Optional, Tuple
|
||||
|
||||
from pydantic import BaseModel, StrictBool, StrictStr, constr, validator
|
||||
from pydantic.fields import ModelField
|
||||
|
||||
from synapse.util.stringutils import parse_and_validate_mxc_uri
|
||||
|
||||
# Ugly workaround for https://github.com/samuelcolvin/pydantic/issues/156. Mypy doesn't
|
||||
# consider expressions like `constr(...)` to be valid types.
|
||||
if TYPE_CHECKING:
|
||||
IDP_ID_TYPE = str
|
||||
IDP_BRAND_TYPE = str
|
||||
else:
|
||||
IDP_ID_TYPE = constr(
|
||||
strict=True,
|
||||
min_length=1,
|
||||
max_length=250,
|
||||
regex="^[A-Za-z0-9._~-]+$", # noqa: F722
|
||||
)
|
||||
IDP_BRAND_TYPE = constr(
|
||||
strict=True,
|
||||
min_length=1,
|
||||
max_length=255,
|
||||
regex="^[a-z][a-z0-9_.-]*$", # noqa: F722
|
||||
)
|
||||
|
||||
|
||||
# the following list of enum members is the same as the keys of
|
||||
# authlib.oauth2.auth.ClientAuth.DEFAULT_AUTH_METHODS. We inline it
|
||||
# to avoid importing authlib here.
|
||||
class ClientAuthMethods(str, Enum):
|
||||
# The duplication is unfortunate. 3.11 should have StrEnum though,
|
||||
# and there is a backport available for 3.8.6.
|
||||
client_secret_basic = "client_secret_basic"
|
||||
client_secret_post = "client_secret_post"
|
||||
none = "none"
|
||||
|
||||
|
||||
class UserProfileMethod(str, Enum):
|
||||
# The duplication is unfortunate. 3.11 should have StrEnum though,
|
||||
# and there is a backport available for 3.8.6.
|
||||
auto = "auto"
|
||||
userinfo_endpoint = "userinfo_endpoint"
|
||||
|
||||
|
||||
class SSOAttributeRequirement(BaseModel):
|
||||
class Config:
|
||||
# Complain if someone provides a field that's not one of those listed here.
|
||||
# Pydantic suggests making your own BaseModel subclass if you want to do this,
|
||||
# see https://pydantic-docs.helpmanual.io/usage/model_config/#change-behaviour-globally
|
||||
extra = "forbid"
|
||||
|
||||
attribute: StrictStr
|
||||
# Note: a comment in config/oidc.py suggests that `value` may be optional. But
|
||||
# The JSON schema seems to forbid this.
|
||||
value: StrictStr
|
||||
|
||||
|
||||
class ClientSecretJWTKey(BaseModel):
|
||||
class Config:
|
||||
extra = "forbid"
|
||||
|
||||
# a pem-encoded signing key
|
||||
# TODO: how should we handle key_file?
|
||||
key: StrictStr
|
||||
|
||||
# properties to include in the JWT header
|
||||
# TODO: validator should enforce that jwt_header contains an 'alg'.
|
||||
jwt_header: Mapping[str, str]
|
||||
|
||||
# properties to include in the JWT payload.
|
||||
jwt_payload: Mapping[str, str] = {}
|
||||
|
||||
|
||||
class OIDCProviderModel(BaseModel):
|
||||
"""
|
||||
Notes on Pydantic:
|
||||
- I've used StrictStr because a plain `str` e.g. accepts integers and calls str()
|
||||
on them
|
||||
- pulling out constr() into IDP_ID_TYPE is a little awkward, but necessary to keep
|
||||
mypy happy
|
||||
-
|
||||
"""
|
||||
|
||||
# a unique identifier for this identity provider. Used in the 'user_external_ids'
|
||||
# table, as well as the query/path parameter used in the login protocol.
|
||||
idp_id: IDP_ID_TYPE
|
||||
|
||||
@validator("idp_id")
|
||||
def ensure_idp_id_prefix(cls, idp_id: str) -> str:
|
||||
"""Prefix the given IDP with a prefix specific to the SSO mechanism, to avoid
|
||||
clashes with other mechs (such as SAML, CAS).
|
||||
|
||||
We allow "oidc" as an exception so that people migrating from old-style
|
||||
"oidc_config" format (which has long used "oidc" as its idp_id) can migrate to
|
||||
a new-style "oidc_providers" entry without changing the idp_id for their provider
|
||||
(and thereby invalidating their user_external_ids data).
|
||||
"""
|
||||
if idp_id != "oidc":
|
||||
return "oidc-" + idp_id
|
||||
return idp_id
|
||||
|
||||
# user-facing name for this identity provider.
|
||||
idp_name: StrictStr
|
||||
|
||||
# Optional MXC URI for icon for this IdP.
|
||||
idp_icon: Optional[StrictStr]
|
||||
|
||||
@validator("idp_icon")
|
||||
def idp_icon_is_an_mxc_url(cls, idp_icon: str) -> str:
|
||||
parse_and_validate_mxc_uri(idp_icon)
|
||||
return idp_icon
|
||||
|
||||
# Optional brand identifier for this IdP.
|
||||
idp_brand: Optional[StrictStr]
|
||||
|
||||
# whether the OIDC discovery mechanism is used to discover endpoints
|
||||
discover: StrictBool = True
|
||||
|
||||
# the OIDC issuer. Used to validate tokens and (if discovery is enabled) to
|
||||
# discover the provider's endpoints.
|
||||
issuer: StrictStr
|
||||
|
||||
# oauth2 client id to use
|
||||
client_id: StrictStr
|
||||
|
||||
# oauth2 client secret to use. if `None`, use client_secret_jwt_key to generate
|
||||
# a secret.
|
||||
client_secret: Optional[StrictStr]
|
||||
|
||||
# key to use to construct a JWT to use as a client secret. May be `None` if
|
||||
# `client_secret` is set.
|
||||
# TODO: test that ClientSecretJWTKey is being parsed correctly
|
||||
client_secret_jwt_key: Optional[ClientSecretJWTKey]
|
||||
|
||||
# TODO: what is the precise relationship between client_auth_method, client_secret
|
||||
# and client_secret_jwt_key? Is there anything we should enforce with a validator?
|
||||
# auth method to use when exchanging the token.
|
||||
# Valid values are 'client_secret_basic', 'client_secret_post' and
|
||||
# 'none'.
|
||||
client_auth_method: ClientAuthMethods = ClientAuthMethods.client_secret_basic
|
||||
|
||||
# list of scopes to request
|
||||
scopes: Tuple[StrictStr, ...] = ("openid",)
|
||||
|
||||
# the oauth2 authorization endpoint. Required if discovery is disabled.
|
||||
authorization_endpoint: Optional[StrictStr]
|
||||
|
||||
# the oauth2 token endpoint. Required if discovery is disabled.
|
||||
token_endpoint: Optional[StrictStr]
|
||||
|
||||
# Normally, validators aren't run when fields don't have a value provided.
|
||||
# Using validate=True ensures we run the validator even in that situation.
|
||||
@validator("authorization_endpoint", "token_endpoint", always=True)
|
||||
def endpoints_required_if_discovery_disabled(
|
||||
cls,
|
||||
endpoint_url: Optional[str],
|
||||
values: Mapping[str, Any],
|
||||
field: ModelField,
|
||||
) -> Optional[str]:
|
||||
# `if "discover" in values means: don't run our checks if "discover" didn't
|
||||
# pass validation. (NB: validation order is the field definition order)
|
||||
if "discover" in values and not values["discover"] and endpoint_url is None:
|
||||
raise ValueError(f"{field.name} is required if discovery is disabled")
|
||||
return endpoint_url
|
||||
|
||||
# the OIDC userinfo endpoint. Required if discovery is disabled and the
|
||||
# "openid" scope is not requested.
|
||||
userinfo_endpoint: Optional[StrictStr]
|
||||
|
||||
@validator("userinfo_endpoint", always=True)
|
||||
def userinfo_endpoint_required_without_discovery_and_without_openid_scope(
|
||||
cls, userinfo_endpoint: Optional[str], values: Mapping[str, Any]
|
||||
) -> Optional[str]:
|
||||
discovery_disabled = "discover" in values and not values["discover"]
|
||||
openid_scope_not_requested = (
|
||||
"scopes" in values and "openid" not in values["scopes"]
|
||||
)
|
||||
if (
|
||||
discovery_disabled
|
||||
and openid_scope_not_requested
|
||||
and userinfo_endpoint is None
|
||||
):
|
||||
raise ValueError(
|
||||
"userinfo_requirement is required if discovery is disabled and"
|
||||
"the 'openid' scope is not requested"
|
||||
)
|
||||
return userinfo_endpoint
|
||||
|
||||
# URI where to fetch the JWKS. Required if discovery is disabled and the
|
||||
# "openid" scope is used.
|
||||
jwks_uri: Optional[StrictStr]
|
||||
|
||||
@validator("jwks_uri", always=True)
|
||||
def jwks_uri_required_without_discovery_but_with_openid_scope(
|
||||
cls, jwks_uri: Optional[str], values: Mapping[str, Any]
|
||||
) -> Optional[str]:
|
||||
discovery_disabled = "discover" in values and not values["discover"]
|
||||
openid_scope_requested = "scopes" in values and "openid" in values["scopes"]
|
||||
if discovery_disabled and openid_scope_requested and jwks_uri is None:
|
||||
raise ValueError(
|
||||
"jwks_uri is required if discovery is disabled and"
|
||||
"the 'openid' scope is not requested"
|
||||
)
|
||||
return jwks_uri
|
||||
|
||||
# Whether to skip metadata verification
|
||||
skip_verification: StrictBool = False
|
||||
|
||||
# Whether to fetch the user profile from the userinfo endpoint. Valid
|
||||
# values are: "auto" or "userinfo_endpoint".
|
||||
user_profile_method: UserProfileMethod = UserProfileMethod.auto
|
||||
|
||||
# whether to allow a user logging in via OIDC to match a pre-existing account
|
||||
# instead of failing
|
||||
allow_existing_users: StrictBool = False
|
||||
|
||||
# the class of the user mapping provider
|
||||
# TODO there was logic for this
|
||||
user_mapping_provider_class: Any # TODO: Type
|
||||
|
||||
# the config of the user mapping provider
|
||||
# TODO
|
||||
user_mapping_provider_config: Any
|
||||
|
||||
# required attributes to require in userinfo to allow login/registration
|
||||
# TODO: wouldn't this be better expressed as a Mapping[str, str]?
|
||||
attribute_requirements: Tuple[SSOAttributeRequirement, ...] = ()
|
||||
|
||||
|
||||
class LegacyOIDCProviderModel(OIDCProviderModel):
|
||||
# These fields could be omitted in the old scheme.
|
||||
idp_id: IDP_ID_TYPE = "oidc"
|
||||
idp_name: StrictStr = "OIDC"
|
||||
|
||||
|
||||
# TODO
|
||||
# top-level config: check we don't have any duplicate idp_ids now
|
||||
# compute callback url
|
||||
@@ -414,12 +414,7 @@ def _is_membership_change_allowed(
|
||||
raise AuthError(403, "You are banned from this room")
|
||||
elif join_rule == JoinRules.PUBLIC:
|
||||
pass
|
||||
elif (
|
||||
room_version.msc3083_join_rules and join_rule == JoinRules.RESTRICTED
|
||||
) or (
|
||||
room_version.msc3787_knock_restricted_join_rule
|
||||
and join_rule == JoinRules.KNOCK_RESTRICTED
|
||||
):
|
||||
elif room_version.msc3083_join_rules and join_rule == JoinRules.RESTRICTED:
|
||||
# This is the same as public, but the event must contain a reference
|
||||
# to the server who authorised the join. If the event does not contain
|
||||
# the proper content it is rejected.
|
||||
@@ -445,13 +440,8 @@ def _is_membership_change_allowed(
|
||||
if authorising_user_level < invite_level:
|
||||
raise AuthError(403, "Join event authorised by invalid server.")
|
||||
|
||||
elif (
|
||||
join_rule == JoinRules.INVITE
|
||||
or (room_version.msc2403_knocking and join_rule == JoinRules.KNOCK)
|
||||
or (
|
||||
room_version.msc3787_knock_restricted_join_rule
|
||||
and join_rule == JoinRules.KNOCK_RESTRICTED
|
||||
)
|
||||
elif join_rule == JoinRules.INVITE or (
|
||||
room_version.msc2403_knocking and join_rule == JoinRules.KNOCK
|
||||
):
|
||||
if not caller_in_room and not caller_invited:
|
||||
raise AuthError(403, "You are not invited to this room.")
|
||||
@@ -472,10 +462,7 @@ def _is_membership_change_allowed(
|
||||
if user_level < ban_level or user_level <= target_level:
|
||||
raise AuthError(403, "You don't have permission to ban")
|
||||
elif room_version.msc2403_knocking and Membership.KNOCK == membership:
|
||||
if join_rule != JoinRules.KNOCK and (
|
||||
not room_version.msc3787_knock_restricted_join_rule
|
||||
or join_rule != JoinRules.KNOCK_RESTRICTED
|
||||
):
|
||||
if join_rule != JoinRules.KNOCK:
|
||||
raise AuthError(403, "You don't have permission to knock")
|
||||
elif target_user_id != event.user_id:
|
||||
raise AuthError(403, "You cannot knock for other users")
|
||||
|
||||
@@ -15,17 +15,7 @@
|
||||
import abc
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Collection,
|
||||
Dict,
|
||||
Hashable,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
)
|
||||
from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Set, Tuple
|
||||
|
||||
import attr
|
||||
from prometheus_client import Counter
|
||||
@@ -419,7 +409,7 @@ class FederationSender(AbstractFederationSender):
|
||||
)
|
||||
return
|
||||
|
||||
destinations: Optional[Collection[str]] = None
|
||||
destinations: Optional[Set[str]] = None
|
||||
if not event.prev_event_ids():
|
||||
# If there are no prev event IDs then the state is empty
|
||||
# and so no remote servers in the room
|
||||
@@ -454,7 +444,7 @@ class FederationSender(AbstractFederationSender):
|
||||
)
|
||||
return
|
||||
|
||||
sharded_destinations = {
|
||||
destinations = {
|
||||
d
|
||||
for d in destinations
|
||||
if self._federation_shard_config.should_handle(
|
||||
@@ -466,12 +456,12 @@ class FederationSender(AbstractFederationSender):
|
||||
# If we are sending the event on behalf of another server
|
||||
# then it already has the event and there is no reason to
|
||||
# send the event to it.
|
||||
sharded_destinations.discard(send_on_behalf_of)
|
||||
destinations.discard(send_on_behalf_of)
|
||||
|
||||
logger.debug("Sending %s to %r", event, sharded_destinations)
|
||||
logger.debug("Sending %s to %r", event, destinations)
|
||||
|
||||
if sharded_destinations:
|
||||
await self._send_pdu(event, sharded_destinations)
|
||||
if destinations:
|
||||
await self._send_pdu(event, destinations)
|
||||
|
||||
now = self.clock.time_msec()
|
||||
ts = await self.store.get_received_ts(event.event_id)
|
||||
|
||||
@@ -169,16 +169,14 @@ def _parse_auth_header(header_bytes: bytes) -> Tuple[str, str, str, Optional[str
    """
    try:
        header_str = header_bytes.decode("utf-8")
        params = re.split(" +", header_str)[1].split(",")
        params = header_str.split(" ")[1].split(",")
        param_dict: Dict[str, str] = {
            k.lower(): v for k, v in [param.split("=", maxsplit=1) for param in params]
            k: v for k, v in [param.split("=", maxsplit=1) for param in params]
        }

        def strip_quotes(value: str) -> str:
            if value.startswith('"'):
                return re.sub(
                    "\\\\(.)", lambda matchobj: matchobj.group(1), value[1:-1]
                )
                return value[1:-1]
            else:
                return value

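To make the parsing concrete, here is a small self-contained sketch (header values invented for illustration) of what the splitting and quote-stripping above produce for an `X-Matrix` Authorization header:

```python
import re

header = 'X-Matrix origin="origin.example.com",key="ed25519:1",sig="ABC",destination="dest.example.com"'

# Split off the scheme on runs of spaces, then split the parameter list on commas.
params = re.split(" +", header)[1].split(",")
param_dict = {k.lower(): v for k, v in (p.split("=", maxsplit=1) for p in params)}


def strip_quotes(value: str) -> str:
    # Simplified: the version above also unescapes backslash-escaped characters.
    return value[1:-1] if value.startswith('"') else value


print(strip_quotes(param_dict["origin"]))       # origin.example.com
print(strip_quotes(param_dict["destination"]))  # dest.example.com
```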
@@ -71,9 +71,6 @@ class DirectoryHandler:
        if wchar in room_alias.localpart:
            raise SynapseError(400, "Invalid characters in room alias")

        if ":" in room_alias.localpart:
            raise SynapseError(400, "Invalid character in room alias localpart: ':'.")

        if not self.hs.is_mine(room_alias):
            raise SynapseError(400, "Room alias must be local")
            # TODO(erikj): Change this.

@@ -241,15 +241,7 @@ class EventAuthHandler:
        # If the join rule is not restricted, this doesn't apply.
        join_rules_event = await self._store.get_event(join_rules_event_id)
        content_join_rule = join_rules_event.content.get("join_rule")
        if content_join_rule == JoinRules.RESTRICTED:
            return True

        # also check for MSC3787 behaviour
        if room_version.msc3787_knock_restricted_join_rule:
            return content_join_rule == JoinRules.KNOCK_RESTRICTED

        return False
        return join_rules_event.content.get("join_rule") == JoinRules.RESTRICTED

    async def get_rooms_that_allow_join(
        self, state_ids: StateMap[str]

@@ -224,7 +224,7 @@ class OidcHandler:
            self._sso_handler.render_error(request, "invalid_session", str(e))
            return
        except MacaroonInvalidSignatureException as e:
            logger.warning("Could not verify session for OIDC callback: %s", e)
            logger.exception("Could not verify session for OIDC callback")
            self._sso_handler.render_error(request, "mismatching_session", str(e))
            return

@@ -827,7 +827,7 @@ class OidcProvider:
            logger.debug("Exchanging OAuth2 code for a token")
            token = await self._exchange_code(code)
        except OidcError as e:
            logger.warning("Could not exchange OAuth2 code: %s", e)
            logger.exception("Could not exchange OAuth2 code")
            self._sso_handler.render_error(request, e.error, e.error_description)
            return

@@ -751,21 +751,6 @@ class RoomCreationHandler:
if wchar in config["room_alias_name"]:
raise SynapseError(400, "Invalid characters in room alias")

if ":" in config["room_alias_name"]:
# Prevent someone from trying to pass in a full alias here.
# Note that it's permissible for a room alias to have multiple
# hash symbols at the start (notably bridged over from IRC, too),
# but the first colon in the alias is defined to separate the local
# part from the server name.
# (remember server names can contain port numbers, also separated
# by a colon. But under no circumstances should the local part be
# allowed to contain a colon!)
raise SynapseError(
400,
"':' is not permitted in the room alias name. "
"Please note this expects a local part — 'wombat', not '#wombat:example.com'.",
)

room_alias = RoomAlias(config["room_alias_name"], self.hs.hostname)
mapping = await self.store.get_association_from_room_alias(room_alias)

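In other words, `room_alias_name` is expected to be a bare localpart. A sketch of that validation under stated assumptions: the whitespace list and the use of `ValueError` in place of `SynapseError` are illustrative, not the exact code:

```python
# Illustrative whitespace characters to reject; the handler iterates over its own list.
WCHARS = [" ", "\t", "\n"]


def validate_room_alias_localpart(room_alias_name: str) -> None:
    if any(wchar in room_alias_name for wchar in WCHARS):
        raise ValueError("Invalid characters in room alias")
    if ":" in room_alias_name:
        # A colon separates the localpart from the server name, so a full
        # alias such as '#wombat:example.com' must be rejected here.
        raise ValueError(
            "':' is not permitted in the room alias name "
            "(expects a local part such as 'wombat', not '#wombat:example.com')."
        )


validate_room_alias_localpart("wombat")  # accepted
try:
    validate_room_alias_localpart("#wombat:example.com")
except ValueError:
    pass  # rejected, as intended
```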
@@ -53,7 +53,6 @@ class RoomBatchHandler:
# We want to use the successor event depth so they appear after `prev_event` because
# it has a larger `depth` but before the successor event because the `stream_ordering`
# is negative before the successor event.
assert most_recent_prev_event_id is not None
successor_event_ids = await self.store.get_successor_events(
most_recent_prev_event_id
)
@@ -140,7 +139,6 @@ class RoomBatchHandler:
_,
) = await self.store.get_max_depth_of(event_ids)
# mapping from (type, state_key) -> state_event_id
assert most_recent_event_id is not None
prev_state_map = await self.state_store.get_state_ids_for_event(
most_recent_event_id
)

@@ -562,13 +562,8 @@ class RoomSummaryHandler:
if join_rules_event_id:
join_rules_event = await self._store.get_event(join_rules_event_id)
join_rule = join_rules_event.content.get("join_rule")
if (
join_rule == JoinRules.PUBLIC
or (room_version.msc2403_knocking and join_rule == JoinRules.KNOCK)
or (
room_version.msc3787_knock_restricted_join_rule
and join_rule == JoinRules.KNOCK_RESTRICTED
)
if join_rule == JoinRules.PUBLIC or (
room_version.msc2403_knocking and join_rule == JoinRules.KNOCK
):
return True

@@ -411,10 +411,10 @@ class SyncHandler:
set_tag(SynapseTags.SYNC_RESULT, bool(sync_result))
return sync_result

async def push_rules_for_user(self, user: UserID) -> Dict[str, Dict[str, list]]:
async def push_rules_for_user(self, user: UserID) -> JsonDict:
user_id = user.to_string()
rules_raw = await self.store.get_push_rules_for_user(user_id)
rules = format_push_rules_for_user(user, rules_raw)
rules = await self.store.get_push_rules_for_user(user_id)
rules = format_push_rules_for_user(user, rules)
return rules

async def ephemeral_by_room(

@@ -747,7 +747,7 @@ class MatrixFederationHttpClient:
for key, sig in request["signatures"][self.server_name].items():
auth_headers.append(
(
'X-Matrix origin="%s",key="%s",sig="%s",destination="%s"'
'X-Matrix origin=%s,key="%s",sig="%s",destination="%s"'
% (
self.server_name,
key,

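This is the client-side counterpart of the header change: the two format strings differ only in whether the `origin` value is quoted. A hedged sketch of building such a header from the quoted variant shown above (the key name and signature values are placeholders):

```python
def build_x_matrix_header(origin: str, key: str, sig: str, destination: str) -> str:
    # Quote every parameter value, matching the 'X-Matrix origin="%s",...' format above.
    return 'X-Matrix origin="%s",key="%s",sig="%s",destination="%s"' % (
        origin,
        key,
        sig,
        destination,
    )


# Placeholder values for illustration only.
header = build_x_matrix_header("example.org", "ed25519:1", "c2lnbmF0dXJl", "matrix.org")
assert header.startswith('X-Matrix origin="example.org"')
```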
@@ -405,7 +405,7 @@ class HttpPusher(Pusher):
rejected = []
if "rejected" in resp:
rejected = resp["rejected"]
if not rejected:
else:
self.badge_count_last_call = badge
return rejected

@@ -148,9 +148,9 @@ class PushRuleRestServlet(RestServlet):
# we build up the full structure and then decide which bits of it
# to send which means doing unnecessary work sometimes but is
# is probably not going to make a whole lot of difference
rules_raw = await self.store.get_push_rules_for_user(user_id)
rules = await self.store.get_push_rules_for_user(user_id)

rules = format_push_rules_for_user(requester.user, rules_raw)
rules = format_push_rules_for_user(requester.user, rules)

path_parts = path.split("/")[1:]

@@ -239,13 +239,13 @@ class StateHandler:
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
return await self.store.get_joined_users_from_state(room_id, entry)

async def get_current_hosts_in_room(self, room_id: str) -> FrozenSet[str]:
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
event_ids = await self.store.get_latest_event_ids_in_room(room_id)
return await self.get_hosts_in_room_at_events(room_id, event_ids)

async def get_hosts_in_room_at_events(
self, room_id: str, event_ids: Collection[str]
) -> FrozenSet[str]:
) -> Set[str]:
"""Get the hosts that were in a room at the given event ids

Args:
@@ -288,6 +288,7 @@ class StateHandler:
#
# first of all, figure out the state before the event
#

if old_state:
# if we're given the state before the event, then we use that
state_ids_before_event: StateMap[str] = {
@@ -418,37 +419,33 @@ class StateHandler:
"""
logger.debug("resolve_state_groups event_ids %s", event_ids)

state_groups = await self.state_store.get_state_group_for_events(event_ids)
# map from state group id to the state in that state group (where
# 'state' is a map from state key to event id)
# dict[int, dict[(str, str), str]]
state_groups_ids = await self.state_store.get_state_groups_ids(
room_id, event_ids
)

state_group_ids = state_groups.values()
if len(state_groups_ids) == 0:
return _StateCacheEntry(state={}, state_group=None)
elif len(state_groups_ids) == 1:
name, state_list = list(state_groups_ids.items()).pop()

prev_group, delta_ids = await self.state_store.get_state_group_delta(name)

# check if each event has same state group id, if so there's no state to resolve
state_group_ids_set = set(state_group_ids)
if len(state_group_ids_set) == 1:
(state_group_id,) = state_group_ids_set
state = await self.state_store.get_state_for_groups(state_group_ids_set)
prev_group, delta_ids = await self.state_store.get_state_group_delta(
state_group_id
)
return _StateCacheEntry(
state=state[state_group_id],
state_group=state_group_id,
state=state_list,
state_group=name,
prev_group=prev_group,
delta_ids=delta_ids,
)
elif len(state_group_ids_set) == 0:
return _StateCacheEntry(state={}, state_group=None)

room_version = await self.store.get_room_version_id(room_id)

state_to_resolve = await self.state_store.get_state_for_groups(
state_group_ids_set
)

result = await self._state_resolution_handler.resolve_state_groups(
room_id,
room_version,
state_to_resolve,
state_groups_ids,
None,
state_res_store=StateResolutionStore(self.store),
)

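The shape of this refactor: look up the state group of every event first, and only pull out and resolve full state when the events do not all share a single state group. A schematic sketch of that fast-path decision, with a hypothetical event-to-state-group mapping standing in for the store lookups:

```python
from typing import Mapping, Set


def needs_state_resolution(state_group_for_event: Mapping[str, int]) -> bool:
    """Return True only when the events span more than one state group.

    If every event maps to the same state group there is nothing to resolve,
    and the caller can return that group's state directly.
    """
    state_group_ids: Set[int] = set(state_group_for_event.values())
    return len(state_group_ids) > 1


# Hypothetical mapping of event_id -> state_group_id.
assert not needs_state_resolution({"$a": 1, "$b": 1})
assert needs_state_resolution({"$a": 1, "$b": 2})
```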
@@ -282,20 +282,12 @@ class BackgroundUpdater:

self._running = True

back_to_back_failures = 0

try:
logger.info("Starting background schema updates")
while self.enabled:
try:
result = await self.do_next_background_update(sleep)
back_to_back_failures = 0
except Exception:
back_to_back_failures += 1
if back_to_back_failures >= 5:
raise RuntimeError(
"5 back-to-back background update failures; aborting."
)
logger.exception("Error doing update")
else:
if result:

@@ -26,7 +26,11 @@ from synapse.storage.database import (
from synapse.storage.databases.main.stats import UserSortOrder
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.storage.util.id_generators import (
IdGenerator,
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.types import JsonDict, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache

@@ -151,6 +155,8 @@ class DataStore(
],
)

self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
self._group_updates_id_gen = StreamIdGenerator(
db_conn, "local_group_updates", "stream_id"
)

@@ -203,29 +203,19 @@ class ApplicationServiceTransactionWorkerStore(
"""Get the application service state.

Args:
service: The service whose state to get.
service: The service whose state to set.
Returns:
An ApplicationServiceState, or None if we have yet to attempt any
transactions to the AS.
An ApplicationServiceState or none.
"""
# if we have created transactions for this AS but not yet attempted to send
# them, we will have a row in the table with state=NULL (recording the stream
# positions we have processed up to).
#
# On the other hand, if we have yet to create any transactions for this AS at
# all, then there will be no row for the AS.
#
# In either case, we return None to indicate "we don't yet know the state of
# this AS".
result = await self.db_pool.simple_select_one_onecol(
result = await self.db_pool.simple_select_one(
"application_services_state",
{"as_id": service.id},
retcol="state",
["state"],
allow_none=True,
desc="get_appservice_state",
)
if result:
return ApplicationServiceState(result)
return ApplicationServiceState(result.get("state"))
return None

async def set_appservice_state(
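The comment block in this hunk distinguishes three cases for the `application_services_state` row. A hedged sketch of how a fetched row maps onto the result described there (the row dicts and the "up" value are illustrative):

```python
from typing import Optional


def interpret_appservice_state_row(row: Optional[dict]) -> Optional[str]:
    """Follow the comment above: no row, or a row whose state is NULL, both
    mean 'we don't yet know the state of this AS'."""
    if row is None:
        # No transactions have ever been created for this AS.
        return None
    state = row.get("state")
    if state is None:
        # Transactions exist but none have been attempted yet (state=NULL).
        return None
    return state


assert interpret_appservice_state_row(None) is None
assert interpret_appservice_state_row({"state": None}) is None
assert interpret_appservice_state_row({"state": "up"}) == "up"
```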
@@ -306,6 +296,14 @@ class ApplicationServiceTransactionWorkerStore(
"""

def _complete_appservice_txn(txn: LoggingTransaction) -> None:
# Set current txn_id for AS to 'txn_id'
self.db_pool.simple_upsert_txn(
txn,
"application_services_state",
{"as_id": service.id},
{"last_txn": txn_id},
)

# Delete txn
self.db_pool.simple_delete_txn(
txn,
@@ -454,15 +452,16 @@ class ApplicationServiceTransactionWorkerStore(
% (stream_type,)
)

# this may be the first time that we're recording any state for this AS, so
# we don't yet know if a row for it exists; hence we have to upsert here.
await self.db_pool.simple_upsert(
table="application_services_state",
keyvalues={"as_id": service.id},
values={f"{stream_type}_stream_id": pos},
# no need to lock when emulating upsert: as_id is a unique key
lock=False,
desc="set_appservice_stream_type_pos",
def set_appservice_stream_type_pos_txn(txn: LoggingTransaction) -> None:
stream_id_type = "%s_stream_id" % stream_type
txn.execute(
"UPDATE application_services_state SET %s = ? WHERE as_id=?"
% stream_id_type,
(pos, service.id),
)

await self.db_pool.runInteraction(
"set_appservice_stream_type_pos", set_appservice_stream_type_pos_txn
)

@@ -14,17 +14,7 @@
|
||||
import itertools
|
||||
import logging
|
||||
from queue import Empty, PriorityQueue
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Collection,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
cast,
|
||||
)
|
||||
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
|
||||
|
||||
import attr
|
||||
from prometheus_client import Counter, Gauge
|
||||
@@ -43,7 +33,7 @@ from synapse.storage.database import (
|
||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||
from synapse.storage.databases.main.signatures import SignatureWorkerStore
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.types import JsonDict
|
||||
from synapse.storage.types import Cursor
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.descriptors import cached
|
||||
from synapse.util.caches.lrucache import LruCache
|
||||
@@ -145,7 +135,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
|
||||
# Check if we have indexed the room so we can use the chain cover
|
||||
# algorithm.
|
||||
room = await self.get_room(room_id) # type: ignore[attr-defined]
|
||||
room = await self.get_room(room_id)
|
||||
if room["has_auth_chain_index"]:
|
||||
try:
|
||||
return await self.db_pool.runInteraction(
|
||||
@@ -168,11 +158,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
)
|
||||
|
||||
def _get_auth_chain_ids_using_cover_index_txn(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
room_id: str,
|
||||
event_ids: Collection[str],
|
||||
include_given: bool,
|
||||
self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool
|
||||
) -> Set[str]:
|
||||
"""Calculates the auth chain IDs using the chain index."""
|
||||
|
||||
@@ -229,9 +215,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
chains: Dict[int, int] = {}
|
||||
|
||||
# Add all linked chains reachable from initial set of chains.
|
||||
for batch2 in batch_iter(event_chains, 1000):
|
||||
for batch in batch_iter(event_chains, 1000):
|
||||
clause, args = make_in_list_sql_clause(
|
||||
txn.database_engine, "origin_chain_id", batch2
|
||||
txn.database_engine, "origin_chain_id", batch
|
||||
)
|
||||
txn.execute(sql % (clause,), args)
|
||||
|
||||
@@ -311,7 +297,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
|
||||
front = set(event_ids)
|
||||
while front:
|
||||
new_front: Set[str] = set()
|
||||
new_front = set()
|
||||
for chunk in batch_iter(front, 100):
|
||||
# Pull the auth events either from the cache or DB.
|
||||
to_fetch: List[str] = [] # Event IDs to fetch from DB
|
||||
@@ -330,7 +316,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
|
||||
# Note we need to batch up the results by event ID before
|
||||
# adding to the cache.
|
||||
to_cache: Dict[str, List[Tuple[str, int]]] = {}
|
||||
to_cache = {}
|
||||
for event_id, auth_event_id, auth_event_depth in txn:
|
||||
to_cache.setdefault(event_id, []).append(
|
||||
(auth_event_id, auth_event_depth)
|
||||
@@ -363,7 +349,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
|
||||
# Check if we have indexed the room so we can use the chain cover
|
||||
# algorithm.
|
||||
room = await self.get_room(room_id) # type: ignore[attr-defined]
|
||||
room = await self.get_room(room_id)
|
||||
if room["has_auth_chain_index"]:
|
||||
try:
|
||||
return await self.db_pool.runInteraction(
|
||||
@@ -384,7 +370,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
)
|
||||
|
||||
def _get_auth_chain_difference_using_cover_index_txn(
|
||||
self, txn: LoggingTransaction, room_id: str, state_sets: List[Set[str]]
|
||||
self, txn: Cursor, room_id: str, state_sets: List[Set[str]]
|
||||
) -> Set[str]:
|
||||
"""Calculates the auth chain difference using the chain index.
|
||||
|
||||
@@ -458,9 +444,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
|
||||
# (We need to take a copy of `seen_chains` as we want to mutate it in
|
||||
# the loop)
|
||||
for batch2 in batch_iter(set(seen_chains), 1000):
|
||||
for batch in batch_iter(set(seen_chains), 1000):
|
||||
clause, args = make_in_list_sql_clause(
|
||||
txn.database_engine, "origin_chain_id", batch2
|
||||
txn.database_engine, "origin_chain_id", batch
|
||||
)
|
||||
txn.execute(sql % (clause,), args)
|
||||
|
||||
@@ -543,7 +529,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
return result
|
||||
|
||||
def _get_auth_chain_difference_txn(
|
||||
self, txn: LoggingTransaction, state_sets: List[Set[str]]
|
||||
self, txn, state_sets: List[Set[str]]
|
||||
) -> Set[str]:
|
||||
"""Calculates the auth chain difference using a breadth first search.
|
||||
|
||||
@@ -616,7 +602,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
|
||||
# I think building a temporary list with fetchall is more efficient than
|
||||
# just `search.extend(txn)`, but this is unconfirmed
|
||||
search.extend(cast(List[Tuple[int, str]], txn.fetchall()))
|
||||
search.extend(txn.fetchall())
|
||||
|
||||
# sort by depth
|
||||
search.sort()
|
||||
@@ -659,7 +645,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
# We parse the results and add the to the `found` set and the
|
||||
# cache (note we need to batch up the results by event ID before
|
||||
# adding to the cache).
|
||||
to_cache: Dict[str, List[Tuple[str, int]]] = {}
|
||||
to_cache = {}
|
||||
for event_id, auth_event_id, auth_event_depth in txn:
|
||||
to_cache.setdefault(event_id, []).append(
|
||||
(auth_event_id, auth_event_depth)
|
||||
@@ -710,7 +696,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
return {eid for eid, n in event_to_missing_sets.items() if n}
|
||||
|
||||
async def get_oldest_event_ids_with_depth_in_room(
|
||||
self, room_id: str
|
||||
self, room_id
|
||||
) -> List[Tuple[str, int]]:
|
||||
"""Gets the oldest events(backwards extremities) in the room along with the
|
||||
aproximate depth.
|
||||
@@ -727,9 +713,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
List of (event_id, depth) tuples
|
||||
"""
|
||||
|
||||
def get_oldest_event_ids_with_depth_in_room_txn(
|
||||
txn: LoggingTransaction, room_id: str
|
||||
) -> List[Tuple[str, int]]:
|
||||
def get_oldest_event_ids_with_depth_in_room_txn(txn, room_id):
|
||||
# Assemble a dictionary with event_id -> depth for the oldest events
|
||||
# we know of in the room. Backwards extremeties are the oldest
|
||||
# events we know of in the room but we only know of them because
|
||||
@@ -759,7 +743,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
|
||||
txn.execute(sql, (room_id, False))
|
||||
|
||||
return cast(List[Tuple[str, int]], txn.fetchall())
|
||||
return txn.fetchall()
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_oldest_event_ids_with_depth_in_room",
|
||||
@@ -768,7 +752,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
)
|
||||
|
||||
async def get_insertion_event_backward_extremities_in_room(
|
||||
self, room_id: str
|
||||
self, room_id
|
||||
) -> List[Tuple[str, int]]:
|
||||
"""Get the insertion events we know about that we haven't backfilled yet.
|
||||
|
||||
@@ -784,9 +768,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
List of (event_id, depth) tuples
|
||||
"""
|
||||
|
||||
def get_insertion_event_backward_extremities_in_room_txn(
|
||||
txn: LoggingTransaction, room_id: str
|
||||
) -> List[Tuple[str, int]]:
|
||||
def get_insertion_event_backward_extremities_in_room_txn(txn, room_id):
|
||||
sql = """
|
||||
SELECT b.event_id, MAX(e.depth) FROM insertion_events as i
|
||||
/* We only want insertion events that are also marked as backwards extremities */
|
||||
@@ -798,7 +780,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
"""
|
||||
|
||||
txn.execute(sql, (room_id,))
|
||||
return cast(List[Tuple[str, int]], txn.fetchall())
|
||||
return txn.fetchall()
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_insertion_event_backward_extremities_in_room",
|
||||
@@ -806,7 +788,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
room_id,
|
||||
)
|
||||
|
||||
async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]:
|
||||
async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[str, int]:
|
||||
"""Returns the event ID and depth for the event that has the max depth from a set of event IDs
|
||||
|
||||
Args:
|
||||
@@ -835,7 +817,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
|
||||
return max_depth_event_id, current_max_depth
|
||||
|
||||
async def get_min_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]:
|
||||
async def get_min_depth_of(self, event_ids: List[str]) -> Tuple[str, int]:
|
||||
"""Returns the event ID and depth for the event that has the min depth from a set of event IDs
|
||||
|
||||
Args:
|
||||
@@ -883,9 +865,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
"get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id
|
||||
)
|
||||
|
||||
def _get_prev_events_for_room_txn(
|
||||
self, txn: LoggingTransaction, room_id: str
|
||||
) -> List[str]:
|
||||
def _get_prev_events_for_room_txn(self, txn, room_id: str):
|
||||
# we just use the 10 newest events. Older events will become
|
||||
# prev_events of future events.
|
||||
|
||||
@@ -916,7 +896,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
sorted by extremity count.
|
||||
"""
|
||||
|
||||
def _get_rooms_with_many_extremities_txn(txn: LoggingTransaction) -> List[str]:
|
||||
def _get_rooms_with_many_extremities_txn(txn):
|
||||
where_clause = "1=1"
|
||||
if room_id_filter:
|
||||
where_clause = "room_id NOT IN (%s)" % (
|
||||
@@ -957,9 +937,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
"get_min_depth", self._get_min_depth_interaction, room_id
|
||||
)
|
||||
|
||||
def _get_min_depth_interaction(
|
||||
self, txn: LoggingTransaction, room_id: str
|
||||
) -> Optional[int]:
|
||||
def _get_min_depth_interaction(self, txn, room_id):
|
||||
min_depth = self.db_pool.simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="room_depth",
|
||||
@@ -988,24 +966,22 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
"""
|
||||
# We want to make the cache more effective, so we clamp to the last
|
||||
# change before the given ordering.
|
||||
last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id) # type: ignore[attr-defined]
|
||||
last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id)
|
||||
|
||||
# We don't always have a full stream_to_exterm_id table, e.g. after
|
||||
# the upgrade that introduced it, so we make sure we never ask for a
|
||||
# stream_ordering from before a restart
|
||||
last_change = max(self._stream_order_on_start, last_change) # type: ignore[attr-defined]
|
||||
last_change = max(self._stream_order_on_start, last_change)
|
||||
|
||||
# provided the last_change is recent enough, we now clamp the requested
|
||||
# stream_ordering to it.
|
||||
if last_change > self.stream_ordering_month_ago: # type: ignore[attr-defined]
|
||||
if last_change > self.stream_ordering_month_ago:
|
||||
stream_ordering = min(last_change, stream_ordering)
|
||||
|
||||
return await self._get_forward_extremeties_for_room(room_id, stream_ordering)
|
||||
|
||||
@cached(max_entries=5000, num_args=2)
|
||||
async def _get_forward_extremeties_for_room(
|
||||
self, room_id: str, stream_ordering: int
|
||||
) -> List[str]:
|
||||
async def _get_forward_extremeties_for_room(self, room_id, stream_ordering):
|
||||
"""For a given room_id and stream_ordering, return the forward
|
||||
extremeties of the room at that point in "time".
|
||||
|
||||
@@ -1013,7 +989,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
stream_orderings from that point.
|
||||
"""
|
||||
|
||||
if stream_ordering <= self.stream_ordering_month_ago: # type: ignore[attr-defined]
|
||||
if stream_ordering <= self.stream_ordering_month_ago:
|
||||
raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,))
|
||||
|
||||
sql = """
|
||||
@@ -1026,7 +1002,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
WHERE room_id = ?
|
||||
"""
|
||||
|
||||
def get_forward_extremeties_for_room_txn(txn: LoggingTransaction) -> List[str]:
|
||||
def get_forward_extremeties_for_room_txn(txn):
|
||||
txn.execute(sql, (stream_ordering, room_id))
|
||||
return [event_id for event_id, in txn]
|
||||
|
||||
@@ -1128,8 +1104,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
]
|
||||
|
||||
async def get_backfill_events(
|
||||
self, room_id: str, seed_event_id_list: List[str], limit: int
|
||||
) -> List[EventBase]:
|
||||
self, room_id: str, seed_event_id_list: list, limit: int
|
||||
):
|
||||
"""Get a list of Events for a given topic that occurred before (and
|
||||
including) the events in seed_event_id_list. Return a list of max size `limit`
|
||||
|
||||
@@ -1147,19 +1123,10 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
)
|
||||
events = await self.get_events_as_list(event_ids)
|
||||
return sorted(
|
||||
# type-ignore: mypy doesn't like negating the Optional[int] stream_ordering.
|
||||
# But it's never None, because these events were previously persisted to the DB.
|
||||
events,
|
||||
key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering), # type: ignore[operator]
|
||||
events, key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering)
|
||||
)
|
||||
|
||||
def _get_backfill_events(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
room_id: str,
|
||||
seed_event_id_list: List[str],
|
||||
limit: int,
|
||||
) -> Set[str]:
|
||||
def _get_backfill_events(self, txn, room_id, seed_event_id_list, limit):
|
||||
"""
|
||||
We want to make sure that we do a breadth-first, "depth" ordered search.
|
||||
We also handle navigating historical branches of history connected by
|
||||
@@ -1172,7 +1139,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
limit,
|
||||
)
|
||||
|
||||
event_id_results: Set[str] = set()
|
||||
event_id_results = set()
|
||||
|
||||
# In a PriorityQueue, the lowest valued entries are retrieved first.
|
||||
# We're using depth as the priority in the queue and tie-break based on
|
||||
@@ -1180,7 +1147,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
# highest and newest-in-time message. We add events to the queue with a
|
||||
# negative depth so that we process the newest-in-time messages first
|
||||
# going backwards in time. stream_ordering follows the same pattern.
|
||||
queue: "PriorityQueue[Tuple[int, int, str, str]]" = PriorityQueue()
|
||||
queue = PriorityQueue()
|
||||
|
||||
for seed_event_id in seed_event_id_list:
|
||||
event_lookup_result = self.db_pool.simple_select_one_txn(
|
||||
@@ -1286,13 +1253,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
|
||||
return event_id_results
|
||||
|
||||
async def get_missing_events(
|
||||
self,
|
||||
room_id: str,
|
||||
earliest_events: List[str],
|
||||
latest_events: List[str],
|
||||
limit: int,
|
||||
) -> List[EventBase]:
|
||||
async def get_missing_events(self, room_id, earliest_events, latest_events, limit):
|
||||
ids = await self.db_pool.runInteraction(
|
||||
"get_missing_events",
|
||||
self._get_missing_events,
|
||||
@@ -1303,18 +1264,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
)
|
||||
return await self.get_events_as_list(ids)
|
||||
|
||||
def _get_missing_events(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
room_id: str,
|
||||
earliest_events: List[str],
|
||||
latest_events: List[str],
|
||||
limit: int,
|
||||
) -> List[str]:
|
||||
def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
|
||||
|
||||
seen_events = set(earliest_events)
|
||||
front = set(latest_events) - seen_events
|
||||
event_results: List[str] = []
|
||||
event_results = []
|
||||
|
||||
query = (
|
||||
"SELECT prev_event_id FROM event_edges "
|
||||
@@ -1357,7 +1311,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
|
||||
@wrap_as_background_process("delete_old_forward_extrem_cache")
|
||||
async def _delete_old_forward_extrem_cache(self) -> None:
|
||||
def _delete_old_forward_extrem_cache_txn(txn: LoggingTransaction) -> None:
|
||||
def _delete_old_forward_extrem_cache_txn(txn):
|
||||
# Delete entries older than a month, while making sure we don't delete
|
||||
# the only entries for a room.
|
||||
sql = """
|
||||
@@ -1370,7 +1324,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
) AND stream_ordering < ?
|
||||
"""
|
||||
txn.execute(
|
||||
sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago) # type: ignore[attr-defined]
|
||||
sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago)
|
||||
)
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
@@ -1428,9 +1382,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
"""
|
||||
if self.db_pool.engine.supports_returning:
|
||||
|
||||
def _remove_received_event_from_staging_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Optional[int]:
|
||||
def _remove_received_event_from_staging_txn(txn):
|
||||
sql = """
|
||||
DELETE FROM federation_inbound_events_staging
|
||||
WHERE origin = ? AND event_id = ?
|
||||
@@ -1438,24 +1390,21 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
"""
|
||||
|
||||
txn.execute(sql, (origin, event_id))
|
||||
row = cast(Optional[Tuple[int]], txn.fetchone())
|
||||
return txn.fetchone()
|
||||
|
||||
if row is None:
|
||||
return None
|
||||
|
||||
return row[0]
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
row = await self.db_pool.runInteraction(
|
||||
"remove_received_event_from_staging",
|
||||
_remove_received_event_from_staging_txn,
|
||||
db_autocommit=True,
|
||||
)
|
||||
if row is None:
|
||||
return None
|
||||
|
||||
return row[0]
|
||||
|
||||
else:
|
||||
|
||||
def _remove_received_event_from_staging_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Optional[int]:
|
||||
def _remove_received_event_from_staging_txn(txn):
|
||||
received_ts = self.db_pool.simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="federation_inbound_events_staging",
|
||||
@@ -1488,9 +1437,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
) -> Optional[Tuple[str, str]]:
|
||||
"""Get the next event ID in the staging area for the given room."""
|
||||
|
||||
def _get_next_staged_event_id_for_room_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Optional[Tuple[str, str]]:
|
||||
def _get_next_staged_event_id_for_room_txn(txn):
|
||||
sql = """
|
||||
SELECT origin, event_id
|
||||
FROM federation_inbound_events_staging
|
||||
@@ -1501,7 +1448,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
|
||||
txn.execute(sql, (room_id,))
|
||||
|
||||
return cast(Optional[Tuple[str, str]], txn.fetchone())
|
||||
return txn.fetchone()
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_next_staged_event_id_for_room", _get_next_staged_event_id_for_room_txn
|
||||
@@ -1514,9 +1461,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
) -> Optional[Tuple[str, EventBase]]:
|
||||
"""Get the next event in the staging area for the given room."""
|
||||
|
||||
def _get_next_staged_event_for_room_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Optional[Tuple[str, str, str]]:
|
||||
def _get_next_staged_event_for_room_txn(txn):
|
||||
sql = """
|
||||
SELECT event_json, internal_metadata, origin
|
||||
FROM federation_inbound_events_staging
|
||||
@@ -1526,7 +1471,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
"""
|
||||
txn.execute(sql, (room_id,))
|
||||
|
||||
return cast(Optional[Tuple[str, str, str]], txn.fetchone())
|
||||
return txn.fetchone()
|
||||
|
||||
row = await self.db_pool.runInteraction(
|
||||
"get_next_staged_event_for_room", _get_next_staged_event_for_room_txn
|
||||
@@ -1654,20 +1599,18 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
|
||||
)
|
||||
|
||||
@wrap_as_background_process("_get_stats_for_federation_staging")
|
||||
async def _get_stats_for_federation_staging(self) -> None:
|
||||
async def _get_stats_for_federation_staging(self):
|
||||
"""Update the prometheus metrics for the inbound federation staging area."""
|
||||
|
||||
def _get_stats_for_federation_staging_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Tuple[int, int]:
|
||||
def _get_stats_for_federation_staging_txn(txn):
|
||||
txn.execute("SELECT count(*) FROM federation_inbound_events_staging")
|
||||
(count,) = cast(Tuple[int], txn.fetchone())
|
||||
(count,) = txn.fetchone()
|
||||
|
||||
txn.execute(
|
||||
"SELECT min(received_ts) FROM federation_inbound_events_staging"
|
||||
)
|
||||
|
||||
(received_ts,) = cast(Tuple[Optional[int]], txn.fetchone())
|
||||
(received_ts,) = txn.fetchone()
|
||||
|
||||
# If there is nothing in the staging area default it to 0.
|
||||
age = 0
|
||||
@@ -1708,21 +1651,19 @@ class EventFederationStore(EventFederationWorkerStore):
|
||||
self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
|
||||
)
|
||||
|
||||
async def clean_room_for_join(self, room_id: str) -> None:
|
||||
await self.db_pool.runInteraction(
|
||||
async def clean_room_for_join(self, room_id):
|
||||
return await self.db_pool.runInteraction(
|
||||
"clean_room_for_join", self._clean_room_for_join_txn, room_id
|
||||
)
|
||||
|
||||
def _clean_room_for_join_txn(self, txn: LoggingTransaction, room_id: str) -> None:
|
||||
def _clean_room_for_join_txn(self, txn, room_id):
|
||||
query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
|
||||
|
||||
txn.execute(query, (room_id,))
|
||||
txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
|
||||
|
||||
async def _background_delete_non_state_event_auth(
|
||||
self, progress: JsonDict, batch_size: int
|
||||
) -> int:
|
||||
def delete_event_auth(txn: LoggingTransaction) -> bool:
|
||||
async def _background_delete_non_state_event_auth(self, progress, batch_size):
|
||||
def delete_event_auth(txn):
|
||||
target_min_stream_id = progress.get("target_min_stream_id_inclusive")
|
||||
max_stream_id = progress.get("max_stream_id_exclusive")
|
||||
|
||||
|
||||
@@ -14,19 +14,16 @@
|
||||
import calendar
|
||||
import logging
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Dict, List, Tuple, cast
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
from synapse.metrics import GaugeBucketCollector
|
||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import (
|
||||
DatabasePool,
|
||||
LoggingDatabaseConnection,
|
||||
LoggingTransaction,
|
||||
)
|
||||
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
|
||||
from synapse.storage.databases.main.event_push_actions import (
|
||||
EventPushActionsWorkerStore,
|
||||
)
|
||||
from synapse.storage.types import Cursor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
@@ -76,7 +73,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
||||
|
||||
@wrap_as_background_process("read_forward_extremities")
|
||||
async def _read_forward_extremities(self) -> None:
|
||||
def fetch(txn: LoggingTransaction) -> List[Tuple[int, int]]:
|
||||
def fetch(txn):
|
||||
txn.execute(
|
||||
"""
|
||||
SELECT t1.c, t2.c
|
||||
@@ -89,7 +86,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
||||
) t2 ON t1.room_id = t2.room_id
|
||||
"""
|
||||
)
|
||||
return cast(List[Tuple[int, int]], txn.fetchall())
|
||||
return txn.fetchall()
|
||||
|
||||
res = await self.db_pool.runInteraction("read_forward_extremities", fetch)
|
||||
|
||||
@@ -107,20 +104,20 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
||||
call to this function, it will return None.
|
||||
"""
|
||||
|
||||
def _count_messages(txn: LoggingTransaction) -> int:
|
||||
def _count_messages(txn):
|
||||
sql = """
|
||||
SELECT COUNT(*) FROM events
|
||||
WHERE type = 'm.room.encrypted'
|
||||
AND stream_ordering > ?
|
||||
"""
|
||||
txn.execute(sql, (self.stream_ordering_day_ago,))
|
||||
(count,) = cast(Tuple[int], txn.fetchone())
|
||||
(count,) = txn.fetchone()
|
||||
return count
|
||||
|
||||
return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)
|
||||
|
||||
async def count_daily_sent_e2ee_messages(self) -> int:
|
||||
def _count_messages(txn: LoggingTransaction) -> int:
|
||||
def _count_messages(txn):
|
||||
# This is good enough as if you have silly characters in your own
|
||||
# hostname then that's your own fault.
|
||||
like_clause = "%:" + self.hs.hostname
|
||||
@@ -133,7 +130,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
||||
"""
|
||||
|
||||
txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
|
||||
(count,) = cast(Tuple[int], txn.fetchone())
|
||||
(count,) = txn.fetchone()
|
||||
return count
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
@@ -141,14 +138,14 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
||||
)
|
||||
|
||||
async def count_daily_active_e2ee_rooms(self) -> int:
|
||||
def _count(txn: LoggingTransaction) -> int:
|
||||
def _count(txn):
|
||||
sql = """
|
||||
SELECT COUNT(DISTINCT room_id) FROM events
|
||||
WHERE type = 'm.room.encrypted'
|
||||
AND stream_ordering > ?
|
||||
"""
|
||||
txn.execute(sql, (self.stream_ordering_day_ago,))
|
||||
(count,) = cast(Tuple[int], txn.fetchone())
|
||||
(count,) = txn.fetchone()
|
||||
return count
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
@@ -163,20 +160,20 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
||||
call to this function, it will return None.
|
||||
"""
|
||||
|
||||
def _count_messages(txn: LoggingTransaction) -> int:
|
||||
def _count_messages(txn):
|
||||
sql = """
|
||||
SELECT COUNT(*) FROM events
|
||||
WHERE type = 'm.room.message'
|
||||
AND stream_ordering > ?
|
||||
"""
|
||||
txn.execute(sql, (self.stream_ordering_day_ago,))
|
||||
(count,) = cast(Tuple[int], txn.fetchone())
|
||||
(count,) = txn.fetchone()
|
||||
return count
|
||||
|
||||
return await self.db_pool.runInteraction("count_messages", _count_messages)
|
||||
|
||||
async def count_daily_sent_messages(self) -> int:
|
||||
def _count_messages(txn: LoggingTransaction) -> int:
|
||||
def _count_messages(txn):
|
||||
# This is good enough as if you have silly characters in your own
|
||||
# hostname then that's your own fault.
|
||||
like_clause = "%:" + self.hs.hostname
|
||||
@@ -189,7 +186,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
||||
"""
|
||||
|
||||
txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
|
||||
(count,) = cast(Tuple[int], txn.fetchone())
|
||||
(count,) = txn.fetchone()
|
||||
return count
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
@@ -197,14 +194,14 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
||||
)
|
||||
|
||||
async def count_daily_active_rooms(self) -> int:
|
||||
def _count(txn: LoggingTransaction) -> int:
|
||||
def _count(txn):
|
||||
sql = """
|
||||
SELECT COUNT(DISTINCT room_id) FROM events
|
||||
WHERE type = 'm.room.message'
|
||||
AND stream_ordering > ?
|
||||
"""
|
||||
txn.execute(sql, (self.stream_ordering_day_ago,))
|
||||
(count,) = cast(Tuple[int], txn.fetchone())
|
||||
(count,) = txn.fetchone()
|
||||
return count
|
||||
|
||||
return await self.db_pool.runInteraction("count_daily_active_rooms", _count)
|
||||
@@ -230,7 +227,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
||||
"count_monthly_users", self._count_users, thirty_days_ago
|
||||
)
|
||||
|
||||
def _count_users(self, txn: LoggingTransaction, time_from: int) -> int:
|
||||
def _count_users(self, txn: Cursor, time_from: int) -> int:
|
||||
"""
|
||||
Returns number of users seen in the past time_from period
|
||||
"""
|
||||
@@ -245,7 +242,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
||||
# Mypy knows that fetchone() might return None if there are no rows.
|
||||
# We know better: "SELECT COUNT(...) FROM ..." without any GROUP BY always
|
||||
# returns exactly one row.
|
||||
(count,) = cast(Tuple[int], txn.fetchone())
|
||||
(count,) = txn.fetchone() # type: ignore[misc]
|
||||
return count
|
||||
|
||||
async def count_r30_users(self) -> Dict[str, int]:
|
||||
@@ -259,7 +256,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
||||
A mapping of counts globally as well as broken out by platform.
|
||||
"""
|
||||
|
||||
def _count_r30_users(txn: LoggingTransaction) -> Dict[str, int]:
|
||||
def _count_r30_users(txn):
|
||||
thirty_days_in_secs = 86400 * 30
|
||||
now = int(self._clock.time())
|
||||
thirty_days_ago_in_secs = now - thirty_days_in_secs
|
||||
@@ -324,7 +321,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
||||
|
||||
txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
|
||||
|
||||
(count,) = cast(Tuple[int], txn.fetchone())
|
||||
(count,) = txn.fetchone()
|
||||
results["all"] = count
|
||||
|
||||
return results
|
||||
@@ -351,7 +348,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
||||
- "web" (any web application -- it's not possible to distinguish Element Web here)
|
||||
"""
|
||||
|
||||
def _count_r30v2_users(txn: LoggingTransaction) -> Dict[str, int]:
|
||||
def _count_r30v2_users(txn):
|
||||
thirty_days_in_secs = 86400 * 30
|
||||
now = int(self._clock.time())
|
||||
sixty_days_ago_in_secs = now - 2 * thirty_days_in_secs
|
||||
@@ -448,8 +445,11 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
||||
thirty_days_in_secs * 1000,
|
||||
),
|
||||
)
|
||||
(count,) = cast(Tuple[int], txn.fetchone())
|
||||
results["all"] = count
|
||||
row = txn.fetchone()
|
||||
if row is None:
|
||||
results["all"] = 0
|
||||
else:
|
||||
results["all"] = row[0]
|
||||
|
||||
return results
|
||||
|
||||
@@ -471,7 +471,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
|
||||
Generates daily visit data for use in cohort/ retention analysis
|
||||
"""
|
||||
|
||||
def _generate_user_daily_visits(txn: LoggingTransaction) -> None:
|
||||
def _generate_user_daily_visits(txn):
|
||||
logger.info("Calling _generate_user_daily_visits")
|
||||
today_start = self._get_start_of_day()
|
||||
a_day_in_milliseconds = 24 * 60 * 60 * 1000
|
||||
|
||||
@@ -417,7 +417,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
|
||||
"room_account_data",
|
||||
"room_tags",
|
||||
"local_current_membership",
|
||||
"federation_inbound_events_staging",
|
||||
):
|
||||
logger.info("[purge] removing %s from %s", room_id, table)
|
||||
txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,))
|
||||
|
||||
@@ -14,18 +14,14 @@
|
||||
# limitations under the License.
|
||||
import abc
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple, Union, cast
|
||||
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
|
||||
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.config.homeserver import ExperimentalConfig
|
||||
from synapse.push.baserules import list_with_base_rules
|
||||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import (
|
||||
DatabasePool,
|
||||
LoggingDatabaseConnection,
|
||||
LoggingTransaction,
|
||||
)
|
||||
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
|
||||
from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
|
||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||
from synapse.storage.databases.main.pusher import PusherWorkerStore
|
||||
@@ -34,12 +30,9 @@ from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
||||
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
||||
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
|
||||
from synapse.storage.util.id_generators import (
|
||||
AbstractStreamIdGenerator,
|
||||
AbstractStreamIdTracker,
|
||||
IdGenerator,
|
||||
StreamIdGenerator,
|
||||
)
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
@@ -64,11 +57,7 @@ def _is_experimental_rule_enabled(
|
||||
return True
|
||||
|
||||
|
||||
def _load_rules(
|
||||
rawrules: List[JsonDict],
|
||||
enabled_map: Dict[str, bool],
|
||||
experimental_config: ExperimentalConfig,
|
||||
) -> List[JsonDict]:
|
||||
def _load_rules(rawrules, enabled_map, experimental_config: ExperimentalConfig):
|
||||
ruleslist = []
|
||||
for rawrule in rawrules:
|
||||
rule = dict(rawrule)
|
||||
@@ -148,7 +137,7 @@ class PushRulesWorkerStore(
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_max_push_rules_stream_id(self) -> int:
|
||||
def get_max_push_rules_stream_id(self):
|
||||
"""Get the position of the push rules stream.
|
||||
|
||||
Returns:
|
||||
@@ -157,7 +146,7 @@ class PushRulesWorkerStore(
|
||||
raise NotImplementedError()
|
||||
|
||||
@cached(max_entries=5000)
|
||||
async def get_push_rules_for_user(self, user_id: str) -> List[JsonDict]:
|
||||
async def get_push_rules_for_user(self, user_id):
|
||||
rows = await self.db_pool.simple_select_list(
|
||||
table="push_rules",
|
||||
keyvalues={"user_name": user_id},
|
||||
@@ -179,7 +168,7 @@ class PushRulesWorkerStore(
|
||||
return _load_rules(rows, enabled_map, self.hs.config.experimental)
|
||||
|
||||
@cached(max_entries=5000)
|
||||
async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
|
||||
async def get_push_rules_enabled_for_user(self, user_id) -> Dict[str, bool]:
|
||||
results = await self.db_pool.simple_select_list(
|
||||
table="push_rules_enable",
|
||||
keyvalues={"user_name": user_id},
|
||||
@@ -195,13 +184,13 @@ class PushRulesWorkerStore(
|
||||
return False
|
||||
else:
|
||||
|
||||
def have_push_rules_changed_txn(txn: LoggingTransaction) -> bool:
|
||||
def have_push_rules_changed_txn(txn):
|
||||
sql = (
|
||||
"SELECT COUNT(stream_id) FROM push_rules_stream"
|
||||
" WHERE user_id = ? AND ? < stream_id"
|
||||
)
|
||||
txn.execute(sql, (user_id, last_id))
|
||||
(count,) = cast(Tuple[int], txn.fetchone())
|
||||
(count,) = txn.fetchone()
|
||||
return bool(count)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
@@ -213,13 +202,11 @@ class PushRulesWorkerStore(
|
||||
list_name="user_ids",
|
||||
num_args=1,
|
||||
)
|
||||
async def bulk_get_push_rules(
|
||||
self, user_ids: Collection[str]
|
||||
) -> Dict[str, List[JsonDict]]:
|
||||
async def bulk_get_push_rules(self, user_ids):
|
||||
if not user_ids:
|
||||
return {}
|
||||
|
||||
results: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}
|
||||
results = {user_id: [] for user_id in user_ids}
|
||||
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
table="push_rules",
|
||||
@@ -243,18 +230,67 @@ class PushRulesWorkerStore(
|
||||
|
||||
return results
|
||||
|
||||
async def copy_push_rule_from_room_to_room(
|
||||
self, new_room_id: str, user_id: str, rule: dict
|
||||
) -> None:
|
||||
"""Copy a single push rule from one room to another for a specific user.
|
||||
|
||||
Args:
|
||||
new_room_id: ID of the new room.
|
||||
user_id : ID of user the push rule belongs to.
|
||||
rule: A push rule.
|
||||
"""
|
||||
# Create new rule id
|
||||
rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
|
||||
new_rule_id = rule_id_scope + "/" + new_room_id
|
||||
|
||||
# Change room id in each condition
|
||||
for condition in rule.get("conditions", []):
|
||||
if condition.get("key") == "room_id":
|
||||
condition["pattern"] = new_room_id
|
||||
|
||||
# Add the rule for the new room
|
||||
await self.add_push_rule(
|
||||
user_id=user_id,
|
||||
rule_id=new_rule_id,
|
||||
priority_class=rule["priority_class"],
|
||||
conditions=rule["conditions"],
|
||||
actions=rule["actions"],
|
||||
)
|
||||
|
||||
async def copy_push_rules_from_room_to_room_for_user(
|
||||
self, old_room_id: str, new_room_id: str, user_id: str
|
||||
) -> None:
|
||||
"""Copy all of the push rules from one room to another for a specific
|
||||
user.
|
||||
|
||||
Args:
|
||||
old_room_id: ID of the old room.
|
||||
new_room_id: ID of the new room.
|
||||
user_id: ID of user to copy push rules for.
|
||||
"""
|
||||
# Retrieve push rules for this user
|
||||
user_push_rules = await self.get_push_rules_for_user(user_id)
|
||||
|
||||
# Get rules relating to the old room and copy them to the new room
|
||||
for rule in user_push_rules:
|
||||
conditions = rule.get("conditions", [])
|
||||
if any(
|
||||
(c.get("key") == "room_id" and c.get("pattern") == old_room_id)
|
||||
for c in conditions
|
||||
):
|
||||
await self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
|
||||
|
||||
@cachedList(
|
||||
cached_method_name="get_push_rules_enabled_for_user",
|
||||
list_name="user_ids",
|
||||
num_args=1,
|
||||
)
|
||||
async def bulk_get_push_rules_enabled(
|
||||
self, user_ids: Collection[str]
|
||||
) -> Dict[str, Dict[str, bool]]:
|
||||
async def bulk_get_push_rules_enabled(self, user_ids):
|
||||
if not user_ids:
|
||||
return {}
|
||||
|
||||
results: Dict[str, Dict[str, bool]] = {user_id: {} for user_id in user_ids}
|
||||
results = {user_id: {} for user_id in user_ids}
|
||||
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
table="push_rules_enable",
|
||||
@@ -270,7 +306,7 @@ class PushRulesWorkerStore(
|
||||
|
||||
async def get_all_push_rule_updates(
|
||||
self, instance_name: str, last_id: int, current_id: int, limit: int
|
||||
) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]:
|
||||
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
|
||||
"""Get updates for push_rules replication stream.
|
||||
|
||||
Args:
|
||||
@@ -295,9 +331,7 @@ class PushRulesWorkerStore(
|
||||
if last_id == current_id:
|
||||
return [], current_id, False
|
||||
|
||||
def get_all_push_rule_updates_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]:
|
||||
def get_all_push_rule_updates_txn(txn):
|
||||
sql = """
|
||||
SELECT stream_id, user_id
|
||||
FROM push_rules_stream
|
||||
@@ -306,10 +340,7 @@ class PushRulesWorkerStore(
|
||||
LIMIT ?
|
||||
"""
|
||||
txn.execute(sql, (last_id, current_id, limit))
|
||||
updates = cast(
|
||||
List[Tuple[int, Tuple[str]]],
|
||||
[(stream_id, (user_id,)) for stream_id, user_id in txn],
|
||||
)
|
||||
updates = [(stream_id, (user_id,)) for stream_id, user_id in txn]
|
||||
|
||||
limited = False
|
||||
upper_bound = current_id
|
||||
@@ -325,30 +356,15 @@ class PushRulesWorkerStore(
|
||||
|
||||
|
||||
class PushRuleStore(PushRulesWorkerStore):
|
||||
# Because we have write access, this will be a StreamIdGenerator
|
||||
# (see PushRulesWorkerStore.__init__)
|
||||
_push_rules_stream_id_gen: AbstractStreamIdGenerator
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
database: DatabasePool,
|
||||
db_conn: LoggingDatabaseConnection,
|
||||
hs: "HomeServer",
|
||||
):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
|
||||
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
|
||||
|
||||
async def add_push_rule(
|
||||
self,
|
||||
user_id: str,
|
||||
rule_id: str,
|
||||
priority_class: int,
|
||||
conditions: List[Dict[str, str]],
|
||||
actions: List[Union[JsonDict, str]],
|
||||
before: Optional[str] = None,
|
||||
after: Optional[str] = None,
|
||||
user_id,
|
||||
rule_id,
|
||||
priority_class,
|
||||
conditions,
|
||||
actions,
|
||||
before=None,
|
||||
after=None,
|
||||
) -> None:
|
||||
conditions_json = json_encoder.encode(conditions)
|
||||
actions_json = json_encoder.encode(actions)
|
||||
@@ -384,17 +400,17 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
|
||||
def _add_push_rule_relative_txn(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
stream_id: int,
|
||||
event_stream_ordering: int,
|
||||
user_id: str,
|
||||
rule_id: str,
|
||||
priority_class: int,
|
||||
conditions_json: str,
|
||||
actions_json: str,
|
||||
before: str,
|
||||
after: str,
|
||||
) -> None:
|
||||
txn,
|
||||
stream_id,
|
||||
event_stream_ordering,
|
||||
user_id,
|
||||
rule_id,
|
||||
priority_class,
|
||||
conditions_json,
|
||||
actions_json,
|
||||
before,
|
||||
after,
|
||||
):
|
||||
# Lock the table since otherwise we'll have annoying races between the
|
||||
# SELECT here and the UPSERT below.
|
||||
self.database_engine.lock_table(txn, "push_rules")
|
||||
@@ -454,15 +470,15 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
|
||||
def _add_push_rule_highest_priority_txn(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
stream_id: int,
|
||||
event_stream_ordering: int,
|
||||
user_id: str,
|
||||
rule_id: str,
|
||||
priority_class: int,
|
||||
conditions_json: str,
|
||||
actions_json: str,
|
||||
) -> None:
|
||||
txn,
|
||||
stream_id,
|
||||
event_stream_ordering,
|
||||
user_id,
|
||||
rule_id,
|
||||
priority_class,
|
||||
conditions_json,
|
||||
actions_json,
|
||||
):
|
||||
# Lock the table since otherwise we'll have annoying races between the
|
||||
# SELECT here and the UPSERT below.
|
||||
self.database_engine.lock_table(txn, "push_rules")
|
||||
@@ -494,17 +510,17 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
|
||||
def _upsert_push_rule_txn(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
stream_id: int,
|
||||
event_stream_ordering: int,
|
||||
user_id: str,
|
||||
rule_id: str,
|
||||
priority_class: int,
|
||||
priority: int,
|
||||
conditions_json: str,
|
||||
actions_json: str,
|
||||
update_stream: bool = True,
|
||||
) -> None:
|
||||
txn,
|
||||
stream_id,
|
||||
event_stream_ordering,
|
||||
user_id,
|
||||
rule_id,
|
||||
priority_class,
|
||||
priority,
|
||||
conditions_json,
|
||||
actions_json,
|
||||
update_stream=True,
|
||||
):
|
||||
"""Specialised version of simple_upsert_txn that picks a push_rule_id
|
||||
using the _push_rule_id_gen if it needs to insert the rule. It assumes
|
||||
that the "push_rules" table is locked"""
|
||||
@@ -584,11 +600,7 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
rule_id: The rule_id of the rule to be deleted
|
||||
"""
|
||||
|
||||
def delete_push_rule_txn(
|
||||
txn: LoggingTransaction,
|
||||
stream_id: int,
|
||||
event_stream_ordering: int,
|
||||
) -> None:
|
||||
def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
|
||||
# we don't use simple_delete_one_txn because that would fail if the
|
||||
# user did not have a push_rule_enable row.
|
||||
self.db_pool.simple_delete_txn(
|
||||
@@ -649,14 +661,14 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
|
||||
def _set_push_rule_enabled_txn(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
stream_id: int,
|
||||
event_stream_ordering: int,
|
||||
user_id: str,
|
||||
rule_id: str,
|
||||
enabled: bool,
|
||||
is_default_rule: bool,
|
||||
) -> None:
|
||||
txn,
|
||||
stream_id,
|
||||
event_stream_ordering,
|
||||
user_id,
|
||||
rule_id,
|
||||
enabled,
|
||||
is_default_rule,
|
||||
):
|
||||
new_id = self._push_rules_enable_id_gen.get_next()
|
||||
|
||||
if not is_default_rule:
|
||||
@@ -728,11 +740,7 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
"""
|
||||
actions_json = json_encoder.encode(actions)
|
||||
|
||||
def set_push_rule_actions_txn(
|
||||
txn: LoggingTransaction,
|
||||
stream_id: int,
|
||||
event_stream_ordering: int,
|
||||
) -> None:
|
||||
def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
|
||||
if is_default_rule:
|
||||
# Add a dummy rule to the rules table with the user specified
|
||||
# actions.
|
||||
@@ -786,15 +794,8 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
)
|
||||
|
||||
def _insert_push_rules_update_txn(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
stream_id: int,
|
||||
event_stream_ordering: int,
|
||||
user_id: str,
|
||||
rule_id: str,
|
||||
op: str,
|
||||
data: Optional[JsonDict] = None,
|
||||
) -> None:
|
||||
self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None
|
||||
):
|
||||
values = {
|
||||
"stream_id": stream_id,
|
||||
"event_stream_ordering": event_stream_ordering,
|
||||
@@ -813,56 +814,5 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
|
||||
)
|
||||
|
||||
def get_max_push_rules_stream_id(self) -> int:
|
||||
def get_max_push_rules_stream_id(self):
|
||||
return self._push_rules_stream_id_gen.get_current_token()
|
||||
|
||||
async def copy_push_rule_from_room_to_room(
|
||||
self, new_room_id: str, user_id: str, rule: dict
|
||||
) -> None:
|
||||
"""Copy a single push rule from one room to another for a specific user.
|
||||
|
||||
Args:
|
||||
new_room_id: ID of the new room.
|
||||
user_id : ID of user the push rule belongs to.
|
||||
rule: A push rule.
|
||||
"""
|
||||
# Create new rule id
|
||||
rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
|
||||
new_rule_id = rule_id_scope + "/" + new_room_id
|
||||
|
||||
# Change room id in each condition
|
||||
for condition in rule.get("conditions", []):
|
||||
if condition.get("key") == "room_id":
|
||||
condition["pattern"] = new_room_id
|
||||
|
||||
# Add the rule for the new room
|
||||
await self.add_push_rule(
|
||||
user_id=user_id,
|
||||
rule_id=new_rule_id,
|
||||
priority_class=rule["priority_class"],
|
||||
conditions=rule["conditions"],
|
||||
actions=rule["actions"],
|
||||
)
|
||||
|
||||
async def copy_push_rules_from_room_to_room_for_user(
|
||||
self, old_room_id: str, new_room_id: str, user_id: str
|
||||
) -> None:
|
||||
"""Copy all of the push rules from one room to another for a specific
|
||||
user.
|
||||
|
||||
Args:
|
||||
old_room_id: ID of the old room.
|
||||
new_room_id: ID of the new room.
|
||||
user_id: ID of user to copy push rules for.
|
||||
"""
|
||||
# Retrieve push rules for this user
|
||||
user_push_rules = await self.get_push_rules_for_user(user_id)
|
||||
|
||||
# Get rules relating to the old room and copy them to the new room
|
||||
for rule in user_push_rules:
|
||||
conditions = rule.get("conditions", [])
|
||||
if any(
|
||||
(c.get("key") == "room_id" and c.get("pattern") == old_room_id)
|
||||
for c in conditions
|
||||
):
|
||||
await self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
|
||||
|
||||
@@ -15,7 +15,6 @@
 import logging
 from typing import (
     TYPE_CHECKING,
-    Callable,
     Collection,
     Dict,
     FrozenSet,
@@ -38,12 +37,7 @@ from synapse.metrics.background_process_metrics import (
     wrap_as_background_process,
 )
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import (
-    DatabasePool,
-    LoggingDatabaseConnection,
-    LoggingTransaction,
-)
-from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.engines import Sqlite3Engine
 from synapse.storage.roommember import (
@@ -52,7 +46,7 @@ from synapse.storage.roommember import (
     ProfileInfo,
     RoomsForUser,
 )
-from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain_from_id
+from synapse.types import PersistedEventPosition, get_domain_from_id
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
@@ -121,7 +115,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
)
|
||||
|
||||
@wrap_as_background_process("_count_known_servers")
|
||||
async def _count_known_servers(self) -> int:
|
||||
async def _count_known_servers(self):
|
||||
"""
|
||||
Count the servers that this server knows about.
|
||||
|
||||
@@ -129,7 +123,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
`synapse_federation_known_servers` LaterGauge to collect.
|
||||
"""
|
||||
|
||||
def _transact(txn: LoggingTransaction) -> int:
|
||||
def _transact(txn):
|
||||
if isinstance(self.database_engine, Sqlite3Engine):
|
||||
query = """
|
||||
SELECT COUNT(DISTINCT substr(out.user_id, pos+1))
|
||||
@@ -156,9 +150,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
self._known_servers_count = max([count, 1])
|
||||
return self._known_servers_count
|
||||
|
||||
def _check_safe_current_state_events_membership_updated_txn(
|
||||
self, txn: LoggingTransaction
|
||||
) -> None:
|
||||
def _check_safe_current_state_events_membership_updated_txn(self, txn):
|
||||
"""Checks if it is safe to assume the new current_state_events
|
||||
membership column is up to date
|
||||
"""
|
||||
@@ -190,7 +182,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
"get_users_in_room", self.get_users_in_room_txn, room_id
|
||||
)
|
||||
|
||||
def get_users_in_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[str]:
|
||||
def get_users_in_room_txn(self, txn, room_id: str) -> List[str]:
|
||||
# If we can assume current_state_events.membership is up to date
|
||||
# then we can avoid a join, which is a Very Good Thing given how
|
||||
# frequently this function gets called.
|
||||
@@ -230,9 +222,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
A mapping from user ID to ProfileInfo.
|
||||
"""
|
||||
|
||||
def _get_users_in_room_with_profiles(
|
||||
txn: LoggingTransaction,
|
||||
) -> Dict[str, ProfileInfo]:
|
||||
def _get_users_in_room_with_profiles(txn) -> Dict[str, ProfileInfo]:
|
||||
sql = """
|
||||
SELECT state_key, display_name, avatar_url FROM room_memberships as m
|
||||
INNER JOIN current_state_events as c
|
||||
@@ -260,9 +250,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
dict of membership states, pointing to a MemberSummary named tuple.
|
||||
"""
|
||||
|
||||
def _get_room_summary_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Dict[str, MemberSummary]:
|
||||
def _get_room_summary_txn(txn):
|
||||
# first get counts.
|
||||
# We do this all in one transaction to keep the cache small.
|
||||
# FIXME: get rid of this when we have room_stats
|
||||
@@ -291,7 +279,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
"""
|
||||
|
||||
txn.execute(sql, (room_id,))
|
||||
res: Dict[str, MemberSummary] = {}
|
||||
res = {}
|
||||
for count, membership in txn:
|
||||
res.setdefault(membership, MemberSummary([], count))
|
||||
|
||||
@@ -412,7 +400,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
|
||||
def _get_rooms_for_local_user_where_membership_is_txn(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
txn,
|
||||
user_id: str,
|
||||
membership_list: List[str],
|
||||
) -> List[RoomsForUser]:
|
||||
@@ -500,7 +488,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
)
|
||||
|
||||
def _get_rooms_for_user_with_stream_ordering_txn(
|
||||
self, txn: LoggingTransaction, user_id: str
|
||||
self, txn, user_id: str
|
||||
) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
|
||||
# We use `current_state_events` here and not `local_current_membership`
|
||||
# as a) this gets called with remote users and b) this only gets called
|
||||
@@ -554,7 +542,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
)
|
||||
|
||||
def _get_rooms_for_users_with_stream_ordering_txn(
|
||||
self, txn: LoggingTransaction, user_ids: Collection[str]
|
||||
self, txn, user_ids: Collection[str]
|
||||
) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]:
|
||||
|
||||
clause, args = make_in_list_sql_clause(
|
||||
@@ -587,9 +575,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
|
||||
txn.execute(sql, [Membership.JOIN] + args)
|
||||
|
||||
result: Dict[str, Set[GetRoomsForUserWithStreamOrdering]] = {
|
||||
user_id: set() for user_id in user_ids
|
||||
}
|
||||
result = {user_id: set() for user_id in user_ids}
|
||||
for user_id, room_id, instance, stream_id in txn:
|
||||
result[user_id].add(
|
||||
GetRoomsForUserWithStreamOrdering(
|
||||
@@ -609,9 +595,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
if not user_ids:
|
||||
return set()
|
||||
|
||||
def _get_users_server_still_shares_room_with_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Set[str]:
|
||||
def _get_users_server_still_shares_room_with_txn(txn):
|
||||
sql = """
|
||||
SELECT state_key FROM current_state_events
|
||||
WHERE
|
||||
@@ -635,7 +619,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
)
|
||||
|
||||
async def get_rooms_for_user(
|
||||
self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None
|
||||
self, user_id: str, on_invalidate=None
|
||||
) -> FrozenSet[str]:
|
||||
"""Returns a set of room_ids the user is currently joined to.
|
||||
|
||||
@@ -673,7 +657,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
async def get_joined_users_from_context(
|
||||
self, event: EventBase, context: EventContext
|
||||
) -> Dict[str, ProfileInfo]:
|
||||
state_group: Union[object, int] = context.state_group
|
||||
state_group = context.state_group
|
||||
if not state_group:
|
||||
# If state_group is None it means it has yet to be assigned a
|
||||
# state group, i.e. we need to make sure that calls with a state_group
|
||||
@@ -682,16 +666,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
state_group = object()
|
||||
|
||||
current_state_ids = await context.get_current_state_ids()
|
||||
assert current_state_ids is not None
|
||||
assert state_group is not None
|
||||
return await self._get_joined_users_from_context(
|
||||
event.room_id, state_group, current_state_ids, event=event, context=context
|
||||
)
|
||||
|
||||
async def get_joined_users_from_state(
|
||||
self, room_id: str, state_entry: "_StateCacheEntry"
|
||||
self, room_id, state_entry
|
||||
) -> Dict[str, ProfileInfo]:
|
||||
state_group: Union[object, int] = state_entry.state_group
|
||||
state_group = state_entry.state_group
|
||||
if not state_group:
|
||||
# If state_group is None it means it has yet to be assigned a
|
||||
# state group, i.e. we need to make sure that calls with a state_group
|
||||
@@ -699,7 +681,6 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
# To do this we set the state_group to a new object as object() != object()
|
||||
state_group = object()
|
||||
|
||||
assert state_group is not None
|
||||
with Measure(self._clock, "get_joined_users_from_state"):
|
||||
return await self._get_joined_users_from_context(
|
||||
room_id, state_group, state_entry.state, context=state_entry
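(Aside on the `object() != object()` trick relied on in the hunks above: a freshly created `object()` compares equal only to itself, so it can stand in for "no state group assigned yet" without ever colliding with a real integer state group or with another placeholder. A minimal, Synapse-independent illustration:)

# Two distinct sentinels are never equal, so a fresh object() can mark
# "state group not yet assigned" without clashing with real group ids.
placeholder_a = object()
placeholder_b = object()
assert placeholder_a == placeholder_a
assert placeholder_a != placeholder_b
assert placeholder_a != 42  # never equal to a real (integer) state group
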
|
||||
@@ -708,12 +689,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
@cached(num_args=2, cache_context=True, iterable=True, max_entries=100000)
|
||||
async def _get_joined_users_from_context(
|
||||
self,
|
||||
room_id: str,
|
||||
state_group: Union[object, int],
|
||||
current_state_ids: StateMap[str],
|
||||
cache_context: _CacheContext,
|
||||
event: Optional[EventBase] = None,
|
||||
context: Optional[Union[EventContext, "_StateCacheEntry"]] = None,
|
||||
room_id,
|
||||
state_group,
|
||||
current_state_ids,
|
||||
cache_context,
|
||||
event=None,
|
||||
context=None,
|
||||
) -> Dict[str, ProfileInfo]:
|
||||
# We don't use `state_group`, it's there so that we can cache based
|
||||
# on it. However, it's important that it's never None, since two current_states
|
||||
@@ -784,18 +765,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
return users_in_room
|
||||
|
||||
@cached(max_entries=10000)
|
||||
def _get_joined_profile_from_event_id(
|
||||
self, event_id: str
|
||||
) -> Optional[Tuple[str, ProfileInfo]]:
|
||||
def _get_joined_profile_from_event_id(self, event_id):
|
||||
raise NotImplementedError()
|
||||
|
||||
@cachedList(
|
||||
cached_method_name="_get_joined_profile_from_event_id",
|
||||
list_name="event_ids",
|
||||
)
|
||||
async def _get_joined_profiles_from_event_ids(
|
||||
self, event_ids: Iterable[str]
|
||||
) -> Dict[str, Optional[Tuple[str, ProfileInfo]]]:
|
||||
async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
|
||||
"""For given set of member event_ids check if they point to a join
|
||||
event and if so return the associated user and profile info.
|
||||
|
||||
@@ -803,7 +780,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
event_ids: The member event IDs to lookup
|
||||
|
||||
Returns:
|
||||
Map from event ID to `user_id` and ProfileInfo (or None if not join event).
|
||||
dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID
|
||||
to `user_id` and ProfileInfo (or None if not join event).
|
||||
"""
|
||||
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
@@ -869,10 +847,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
|
||||
return True
|
||||
|
||||
async def get_joined_hosts(
|
||||
self, room_id: str, state_entry: "_StateCacheEntry"
|
||||
) -> FrozenSet[str]:
|
||||
state_group: Union[object, int] = state_entry.state_group
|
||||
async def get_joined_hosts(self, room_id: str, state_entry):
|
||||
state_group = state_entry.state_group
|
||||
if not state_group:
|
||||
# If state_group is None it means it has yet to be assigned a
|
||||
# state group, i.e. we need to make sure that calls with a state_group
|
||||
@@ -880,7 +856,6 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
# To do this we set the state_group to a new object as object() != object()
|
||||
state_group = object()
|
||||
|
||||
assert state_group is not None
|
||||
with Measure(self._clock, "get_joined_hosts"):
|
||||
return await self._get_joined_hosts(
|
||||
room_id, state_group, state_entry=state_entry
|
||||
@@ -888,10 +863,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
|
||||
@cached(num_args=2, max_entries=10000, iterable=True)
|
||||
async def _get_joined_hosts(
|
||||
self,
|
||||
room_id: str,
|
||||
state_group: Union[object, int],
|
||||
state_entry: "_StateCacheEntry",
|
||||
self, room_id: str, state_group: int, state_entry: "_StateCacheEntry"
|
||||
) -> FrozenSet[str]:
|
||||
# We don't use `state_group`, it's there so that we can cache based on
|
||||
# it. However, its important that its never None, since two
|
||||
@@ -909,7 +881,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
# `get_joined_hosts` is called with the "current" state group for the
|
||||
# room, and so consecutive calls will be for consecutive state groups
|
||||
# which point to the previous state group.
|
||||
cache = await self._get_joined_hosts_cache(room_id) # type: ignore[misc]
|
||||
cache = await self._get_joined_hosts_cache(room_id)
|
||||
|
||||
# If the state group in the cache matches, we already have the data we need.
|
||||
if state_entry.state_group == cache.state_group:
|
||||
@@ -925,7 +897,6 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
elif state_entry.prev_group == cache.state_group:
|
||||
# The cached work is for the previous state group, so we work out
|
||||
# the delta.
|
||||
assert state_entry.delta_ids is not None
|
||||
for (typ, state_key), event_id in state_entry.delta_ids.items():
|
||||
if typ != EventTypes.Member:
|
||||
continue
|
||||
@@ -971,7 +942,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
|
||||
Returns False if they have since re-joined."""
|
||||
|
||||
def f(txn: LoggingTransaction) -> int:
|
||||
def f(txn):
|
||||
sql = (
|
||||
"SELECT"
|
||||
" COUNT(*)"
|
||||
@@ -1002,7 +973,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
The forgotten rooms.
|
||||
"""
|
||||
|
||||
def _get_forgotten_rooms_for_user_txn(txn: LoggingTransaction) -> Set[str]:
|
||||
def _get_forgotten_rooms_for_user_txn(txn):
|
||||
# This is a slightly convoluted query that first looks up all rooms
|
||||
# that the user has forgotten in the past, then rechecks that list
|
||||
# to see if any have subsequently been updated. This is done so that
|
||||
@@ -1105,9 +1076,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
clause,
|
||||
)
|
||||
|
||||
def _is_local_host_in_room_ignoring_users_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> bool:
|
||||
def _is_local_host_in_room_ignoring_users_txn(txn):
|
||||
txn.execute(sql, (room_id, Membership.JOIN, *args))
|
||||
|
||||
return bool(txn.fetchone())
|
||||
@@ -1141,17 +1110,15 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
|
||||
where_clause="forgotten = 1",
|
||||
)
|
||||
|
||||
async def _background_add_membership_profile(
|
||||
self, progress: JsonDict, batch_size: int
|
||||
) -> int:
|
||||
async def _background_add_membership_profile(self, progress, batch_size):
|
||||
target_min_stream_id = progress.get(
|
||||
"target_min_stream_id_inclusive", self._min_stream_order_on_start # type: ignore[attr-defined]
|
||||
"target_min_stream_id_inclusive", self._min_stream_order_on_start
|
||||
)
|
||||
max_stream_id = progress.get(
|
||||
"max_stream_id_exclusive", self._stream_order_on_start + 1 # type: ignore[attr-defined]
|
||||
"max_stream_id_exclusive", self._stream_order_on_start + 1
|
||||
)
|
||||
|
||||
def add_membership_profile_txn(txn: LoggingTransaction) -> int:
|
||||
def add_membership_profile_txn(txn):
|
||||
sql = """
|
||||
SELECT stream_ordering, event_id, events.room_id, event_json.json
|
||||
FROM events
|
||||
@@ -1215,17 +1182,13 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
|
||||
|
||||
return result
|
||||
|
||||
async def _background_current_state_membership(
|
||||
self, progress: JsonDict, batch_size: int
|
||||
) -> int:
|
||||
async def _background_current_state_membership(self, progress, batch_size):
|
||||
"""Update the new membership column on current_state_events.
|
||||
|
||||
This works by iterating over all rooms in alphebetical order.
|
||||
"""
|
||||
|
||||
def _background_current_state_membership_txn(
|
||||
txn: LoggingTransaction, last_processed_room: str
|
||||
) -> Tuple[int, bool]:
|
||||
def _background_current_state_membership_txn(txn, last_processed_room):
|
||||
processed = 0
|
||||
while processed < batch_size:
|
||||
txn.execute(
|
||||
@@ -1279,11 +1242,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
|
||||
return row_count
|
||||
|
||||
|
||||
class RoomMemberStore(
|
||||
RoomMemberWorkerStore,
|
||||
RoomMemberBackgroundUpdateStore,
|
||||
CacheInvalidationWorkerStore,
|
||||
):
|
||||
class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
|
||||
def __init__(
|
||||
self,
|
||||
database: DatabasePool,
|
||||
@@ -1295,7 +1254,7 @@ class RoomMemberStore(
|
||||
async def forget(self, user_id: str, room_id: str) -> None:
|
||||
"""Indicate that user_id wishes to discard history for room_id."""
|
||||
|
||||
def f(txn: LoggingTransaction) -> None:
|
||||
def f(txn):
|
||||
sql = (
|
||||
"UPDATE"
|
||||
" room_memberships"
|
||||
@@ -1329,5 +1288,5 @@ class _JoinedHostsCache:
|
||||
# equal to anything else).
|
||||
state_group: Union[object, int] = attr.Factory(object)
|
||||
|
||||
def __len__(self) -> int:
|
||||
def __len__(self):
|
||||
return sum(len(v) for v in self.hosts_to_joined_users.values())
|
||||
|
||||
@@ -189,7 +189,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         group: int,
         state_filter: StateFilter,
     ) -> Tuple[MutableStateMap[str], bool]:
-        """Checks if group is in cache. See `get_state_for_groups`
+        """Checks if group is in cache. See `_get_state_for_groups`

         Args:
             cache: the state group cache to use
@@ -61,9 +61,7 @@ Changes in SCHEMA_VERSION = 68:

 Changes in SCHEMA_VERSION = 69:
     - We now write to `device_lists_changes_in_room` table.
-    - We now use a PostgreSQL sequence to generate future txn_ids for
-      `application_services_txns`. `application_services_state.last_txn` is no longer
-      updated.
+    - Use sequence to generate future `application_services_txns.txn_id`s

 Changes in SCHEMA_VERSION = 70:
     - event_reference_hashes is no longer written to.
@@ -73,7 +71,6 @@ Changes in SCHEMA_VERSION = 70:
 SCHEMA_COMPAT_VERSION = (
     # We now assume that `device_lists_changes_in_room` has been filled out for
     # recent device_list_updates.
-    # ... and that `application_services_state.last_txn` is not used.
     69
 )
 """Limit on how far the synapse codebase can be rolled back without breaking db compat
@@ -1,5 +1,4 @@
 # Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2022 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,7 +15,6 @@ import logging
 from typing import (
     TYPE_CHECKING,
     Awaitable,
-    Callable,
     Collection,
     Dict,
     Iterable,
@@ -534,44 +532,6 @@ class StateFilter:
|
||||
new_all, new_excludes, new_wildcards, new_concrete_keys
|
||||
)
|
||||
|
||||
def must_await_full_state(self, is_mine_id: Callable[[str], bool]) -> bool:
|
||||
"""Check if we need to wait for full state to complete to calculate this state
|
||||
|
||||
If we have a state filter which is completely satisfied even with partial
|
||||
state, then we don't need to await_full_state before we can return it.
|
||||
|
||||
Args:
|
||||
is_mine_id: a callable which confirms if a given state_key matches a mxid
|
||||
of a local user
|
||||
"""
|
||||
|
||||
# TODO(faster_joins): it's not entirely clear that this is safe. In particular,
|
||||
# there may be circumstances in which we return a piece of state that, once we
|
||||
# resync the state, we discover is invalid. For example: if it turns out that
|
||||
# the sender of a piece of state wasn't actually in the room, then clearly that
|
||||
# state shouldn't have been returned.
|
||||
# We should at least add some tests around this to see what happens.
|
||||
|
||||
# if we haven't requested membership events, then it depends on the value of
|
||||
# 'include_others'
|
||||
if EventTypes.Member not in self.types:
|
||||
return self.include_others
|
||||
|
||||
# if we're looking for *all* membership events, then we have to wait
|
||||
member_state_keys = self.types[EventTypes.Member]
|
||||
if member_state_keys is None:
|
||||
return True
|
||||
|
||||
# otherwise, consider whose membership we are looking for. If it's entirely
|
||||
# local users, then we don't need to wait.
|
||||
for state_key in member_state_keys:
|
||||
if not is_mine_id(state_key):
|
||||
# remote user
|
||||
return True
|
||||
|
||||
# local users only
|
||||
return False
|
||||
|
||||
|
||||
_ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True)
|
||||
_ALL_NON_MEMBER_STATE_FILTER = StateFilter(
|
||||
@@ -584,7 +544,6 @@ class StateGroupStorage:
     """High level interface to fetching state for event."""

     def __init__(self, hs: "HomeServer", stores: "Databases"):
-        self._is_mine_id = hs.is_mine_id
         self.stores = stores
         self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)

@@ -627,7 +586,7 @@ class StateGroupStorage:
         if not event_ids:
             return {}

-        event_to_groups = await self.get_state_group_for_events(event_ids)
+        event_to_groups = await self._get_state_group_for_events(event_ids)

         groups = set(event_to_groups.values())
         group_to_state = await self.stores.state._get_state_for_groups(groups)
@@ -643,7 +602,7 @@ class StateGroupStorage:
         Returns:
             Resolves to a map of (type, state_key) -> event_id
         """
-        group_to_state = await self.get_state_for_groups((state_group,))
+        group_to_state = await self._get_state_for_groups((state_group,))

         return group_to_state[state_group]

@@ -716,13 +675,7 @@ class StateGroupStorage:
             RuntimeError if we don't have a state group for one or more of the events
             (ie they are outliers or unknown)
         """
-        await_full_state = True
-        if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
-            await_full_state = False
-
-        event_to_groups = await self.get_state_group_for_events(
-            event_ids, await_full_state=await_full_state
-        )
+        event_to_groups = await self._get_state_group_for_events(event_ids)

         groups = set(event_to_groups.values())
         group_to_state = await self.stores.state._get_state_for_groups(
@@ -746,9 +699,7 @@ class StateGroupStorage:
|
||||
return {event: event_to_state[event] for event in event_ids}
|
||||
|
||||
async def get_state_ids_for_events(
|
||||
self,
|
||||
event_ids: Collection[str],
|
||||
state_filter: Optional[StateFilter] = None,
|
||||
self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
|
||||
) -> Dict[str, StateMap[str]]:
|
||||
"""
|
||||
Get the state dicts corresponding to a list of events, containing the event_ids
|
||||
@@ -765,13 +716,7 @@ class StateGroupStorage:
|
||||
RuntimeError if we don't have a state group for one or more of the events
|
||||
(ie they are outliers or unknown)
|
||||
"""
|
||||
await_full_state = True
|
||||
if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
|
||||
await_full_state = False
|
||||
|
||||
event_to_groups = await self.get_state_group_for_events(
|
||||
event_ids, await_full_state=await_full_state
|
||||
)
|
||||
event_to_groups = await self._get_state_group_for_events(event_ids)
|
||||
|
||||
groups = set(event_to_groups.values())
|
||||
group_to_state = await self.stores.state._get_state_for_groups(
|
||||
@@ -829,7 +774,7 @@ class StateGroupStorage:
|
||||
)
|
||||
return state_map[event_id]
|
||||
|
||||
def get_state_for_groups(
|
||||
def _get_state_for_groups(
|
||||
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
|
||||
) -> Awaitable[Dict[int, MutableStateMap[str]]]:
|
||||
"""Gets the state at each of a list of state groups, optionally
|
||||
@@ -847,7 +792,7 @@ class StateGroupStorage:
|
||||
groups, state_filter or StateFilter.all()
|
||||
)
|
||||
|
||||
async def get_state_group_for_events(
|
||||
async def _get_state_group_for_events(
|
||||
self,
|
||||
event_ids: Collection[str],
|
||||
await_full_state: bool = True,
|
||||
@@ -857,7 +802,7 @@ class StateGroupStorage:
|
||||
Args:
|
||||
event_ids: events to get state groups for
|
||||
await_full_state: if true, will block if we do not yet have complete
|
||||
state at these events.
|
||||
state at this event.
|
||||
"""
|
||||
if await_full_state:
|
||||
await self._partial_state_events_tracker.await_full_state(event_ids)
|
||||
|
||||
@@ -1,415 +0,0 @@
|
||||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict
|
||||
from unittest import TestCase
|
||||
|
||||
import yaml
|
||||
from parameterized import parameterized
|
||||
from pydantic import ValidationError
|
||||
|
||||
from synapse.config.oidc2 import (
|
||||
ClientAuthMethods,
|
||||
LegacyOIDCProviderModel,
|
||||
OIDCProviderModel,
|
||||
)
|
||||
|
||||
SAMPLE_CONFIG = yaml.safe_load(
|
||||
"""
|
||||
idp_id: my_idp
|
||||
idp_name: My OpenID provider
|
||||
idp_icon: "mxc://example.com/blahblahblah"
|
||||
idp_brand: "brandy"
|
||||
issuer: "https://accountns.exeample.com"
|
||||
client_id: "provided-by-your-issuer"
|
||||
client_secret_jwt_key:
|
||||
key: DUMMY_PRIVATE_KEY
|
||||
jwt_header:
|
||||
alg: ES256
|
||||
kid: potato123
|
||||
jwt_payload:
|
||||
iss: issuer456
|
||||
client_auth_method: "client_secret_post"
|
||||
scopes: ["name", "email", "openid"]
|
||||
authorization_endpoint: https://example.com/auth/authorize?response_mode=form_post
|
||||
token_endpoint: https://id.example.com/dummy_url_here
|
||||
jwks_uri: "https://accounts.example.com/.well-known/jwks.json"
|
||||
user_mapping_provider:
|
||||
config:
|
||||
email_template: "{{ user.email }}"
|
||||
localpart_template: "{{ user.email|localpart_from_email }}"
|
||||
confirm_localpart: true
|
||||
attribute_requirements:
|
||||
- attribute: userGroup
|
||||
value: "synapseUsers"
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
class PydanticOIDCTestCase(TestCase):
|
||||
"""Examples to build confidence that pydantic is doing the validation we think
|
||||
it's doing"""
|
||||
|
||||
# Each test gets a dummy config it can change as it sees fit
|
||||
config: Dict[str, Any]
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.config = deepcopy(SAMPLE_CONFIG)
|
||||
|
||||
@contextmanager
|
||||
def assertRaises(self, *args, **kwargs):
|
||||
"""To demonstrate the example error messages generated by Pydantic, uncomment
|
||||
this method."""
|
||||
with super().assertRaises(*args, **kwargs) as result:
|
||||
yield result
|
||||
print()
|
||||
print(result.exception)
|
||||
|
||||
def test_example_config(self):
|
||||
# Check that parsing the sample config doesn't raise an error.
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
|
||||
def test_idp_id(self) -> None:
|
||||
"""Example of using a Pydantic constr() field without a default."""
|
||||
# Enforce that idp_id is required.
|
||||
with self.assertRaises(ValidationError):
|
||||
del self.config["idp_id"]
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
|
||||
# Enforce that idp_id is a string.
|
||||
for bad_value in 123, None, ["a"], {"a": "b"}:
|
||||
with self.assertRaises(ValidationError):
|
||||
self.config["idp_id"] = bad_value
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
|
||||
# Enforce a length between 1 and 250.
|
||||
with self.assertRaises(ValidationError):
|
||||
self.config["idp_id"] = ""
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
with self.assertRaises(ValidationError):
|
||||
self.config["idp_id"] = "a" * 251
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
|
||||
# Enforce the regex
|
||||
with self.assertRaises(ValidationError):
|
||||
self.config["idp_id"] = "$"
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
|
||||
# What happens with a really long string of prohibited characters?
|
||||
with self.assertRaises(ValidationError):
|
||||
self.config["idp_id"] = "$" * 500
|
||||
OIDCProviderModel.parse_obj(self.config)
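
(For reference, the checks above correspond roughly to a pydantic v1 `constr` field with no default. The sketch below is illustrative only: the model name, bounds and regex are assumptions, not the actual definition in synapse.config.oidc2.)

from pydantic import BaseModel, ValidationError, constr

class IdpIdSketch(BaseModel):
    # No default, so omitting idp_id raises ValidationError; the constraints
    # mirror the cases exercised by the test above (assumed values).
    idp_id: constr(strict=True, min_length=1, max_length=250, regex="^[A-Za-z0-9._~-]+$")

IdpIdSketch(idp_id="my_idp")        # accepted
try:
    IdpIdSketch(idp_id="$" * 500)   # wrong characters and too long
except ValidationError as exc:
    print(exc)
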
|
||||
|
||||
def test_legacy_model(self) -> None:
|
||||
"""Example of widening a field's type in a subclass."""
|
||||
# Check that parsing the sample config doesn't raise an error.
|
||||
LegacyOIDCProviderModel.parse_obj(self.config)
|
||||
|
||||
# Check we have default values for the attributes which have a legacy fallback
|
||||
del self.config["idp_id"]
|
||||
del self.config["idp_name"]
|
||||
model = LegacyOIDCProviderModel.parse_obj(self.config)
|
||||
self.assertEqual(model.idp_id, "oidc")
|
||||
self.assertEqual(model.idp_name, "OIDC")
|
||||
|
||||
# Check we still reject bad types
|
||||
for bad_value in 123, [], {}, None:
|
||||
with self.assertRaises(ValidationError) as e:
|
||||
self.config["idp_id"] = bad_value
|
||||
self.config["idp_name"] = bad_value
|
||||
LegacyOIDCProviderModel.parse_obj(self.config)
|
||||
# And while we're at it, check that we spot errors in both fields
|
||||
reported_bad_fields = {item["loc"] for item in e.exception.errors()}
|
||||
expected_bad_fields = {("idp_id",), ("idp_name",)}
|
||||
self.assertEqual(
|
||||
reported_bad_fields, expected_bad_fields, e.exception.errors()
|
||||
)
|
||||
|
||||
def test_issuer(self) -> None:
|
||||
"""Example of a StrictStr field without a default."""
|
||||
|
||||
# Empty and nonempty strings should be accepted.
|
||||
for good_value in "", "hello", "hello" * 1000, "☃":
|
||||
self.config["issuer"] = good_value
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
|
||||
# Invalid types should be rejected.
|
||||
for bad_value in 123, None, ["h", "e", "l", "l", "o"], {"hello": "there"}:
|
||||
with self.assertRaises(ValidationError):
|
||||
self.config["issuer"] = bad_value
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
|
||||
# A missing issuer should be rejected.
|
||||
with self.assertRaises(ValidationError):
|
||||
del self.config["issuer"]
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
|
||||
def test_idp_brand(self) -> None:
|
||||
"""Example of an Optional[StrictStr] field."""
|
||||
# Empty and nonempty strings should be accepted.
|
||||
for good_value in "", "hello", "hello" * 1000, "☃":
|
||||
self.config["idp_brand"] = good_value
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
|
||||
# Invalid types should be rejected.
|
||||
for bad_value in 123, ["h", "e", "l", "l", "o"], {"hello": "there"}:
|
||||
with self.assertRaises(ValidationError):
|
||||
self.config["idp_brand"] = bad_value
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
|
||||
# A lack of an idp_brand is fine...
|
||||
del self.config["idp_brand"]
|
||||
model = OIDCProviderModel.parse_obj(self.config)
|
||||
self.assertIsNone(model.idp_brand)
|
||||
|
||||
# ... and interpreted the same as an explicit `None`.
|
||||
self.config["idp_brand"] = None
|
||||
model = OIDCProviderModel.parse_obj(self.config)
|
||||
self.assertIsNone(model.idp_brand)
|
||||
|
||||
def test_idp_icon(self) -> None:
|
||||
"""Example of a field with a custom validator."""
|
||||
# Test that bad types are rejected, even with our validator in place
|
||||
bad_value: object
|
||||
for bad_value in None, {}, [], 123, 45.6:
|
||||
with self.assertRaises(ValidationError):
|
||||
self.config["idp_icon"] = bad_value
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
|
||||
# Test that bad strings are rejected by our validator
|
||||
for bad_value in "", "notaurl", "https://example.com", "mxc://mxc://mxc://":
|
||||
with self.assertRaises(ValidationError):
|
||||
self.config["idp_icon"] = bad_value
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
|
||||
def test_discover(self) -> None:
|
||||
"""Example of a StrictBool field with a default."""
|
||||
# Booleans are permitted.
|
||||
for value in True, False:
|
||||
self.config["discover"] = value
|
||||
model = OIDCProviderModel.parse_obj(self.config)
|
||||
self.assertEqual(model.discover, value)
|
||||
|
||||
# Invalid types should be rejected.
|
||||
for bad_value in (
|
||||
-1.0,
|
||||
0,
|
||||
1,
|
||||
float("nan"),
|
||||
"yes",
|
||||
"NO",
|
||||
"True",
|
||||
"true",
|
||||
None,
|
||||
"None",
|
||||
"null",
|
||||
["a"],
|
||||
{"a": "b"},
|
||||
):
|
||||
self.config["discover"] = bad_value
|
||||
with self.assertRaises(ValidationError):
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
|
||||
# A missing value is okay, because this field has a default.
|
||||
del self.config["discover"]
|
||||
model = OIDCProviderModel.parse_obj(self.config)
|
||||
self.assertIs(model.discover, True)
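
(A minimal sketch of the StrictBool-with-default shape exercised above; the field's real declaration in synapse.config.oidc2 may differ.)

from pydantic import BaseModel, StrictBool

class DiscoverSketch(BaseModel):
    discover: StrictBool = True   # default used when the key is omitted

assert DiscoverSketch().discover is True
assert DiscoverSketch(discover=False).discover is False
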
|
||||
|
||||
def test_client_auth_method(self) -> None:
|
||||
"""This is an example of using a Pydantic string enum field."""
|
||||
# check the allowed values are permitted and deserialise to an enum member
|
||||
for method in "client_secret_basic", "client_secret_post", "none":
|
||||
self.config["client_auth_method"] = method
|
||||
model = OIDCProviderModel.parse_obj(self.config)
|
||||
self.assertIs(model.client_auth_method, ClientAuthMethods[method])
|
||||
|
||||
# check the default applies if no auth method is provided.
|
||||
del self.config["client_auth_method"]
|
||||
model = OIDCProviderModel.parse_obj(self.config)
|
||||
self.assertIs(model.client_auth_method, ClientAuthMethods.client_secret_basic)
|
||||
|
||||
# Check invalid types are rejected
|
||||
for bad_value in 123, ["client_secret_basic"], {"a": 1}, None:
|
||||
with self.assertRaises(ValidationError):
|
||||
self.config["client_auth_method"] = bad_value
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
|
||||
# Check that disallowed strings are rejected
|
||||
with self.assertRaises(ValidationError):
|
||||
self.config["client_auth_method"] = "No, Luke, _I_ am your father!"
|
||||
OIDCProviderModel.parse_obj(self.config)
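
(A rough sketch of a string-enum field with a default, in the spirit of the test above; ClientAuthMethods' actual definition may be structured differently.)

from enum import Enum
from pydantic import BaseModel

class AuthMethodSketch(str, Enum):
    client_secret_basic = "client_secret_basic"
    client_secret_post = "client_secret_post"
    none = "none"

class AuthModelSketch(BaseModel):
    client_auth_method: AuthMethodSketch = AuthMethodSketch.client_secret_basic

assert AuthModelSketch().client_auth_method is AuthMethodSketch.client_secret_basic
assert AuthModelSketch(client_auth_method="none").client_auth_method is AuthMethodSketch.none
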
|
||||
|
||||
def test_scopes(self) -> None:
|
||||
"""Example of a Tuple[StrictStr] with a default."""
|
||||
# Check that the parsed object holds a tuple
|
||||
self.config["scopes"] = []
|
||||
model = OIDCProviderModel.parse_obj(self.config)
|
||||
self.assertEqual(model.scopes, ())
|
||||
|
||||
# Check a variety of list lengths are accepted.
|
||||
for good_value in ["aa"], ["hello", "world"], ["a"] * 4, [""] * 20:
|
||||
self.config["scopes"] = good_value
|
||||
model = OIDCProviderModel.parse_obj(self.config)
|
||||
self.assertEqual(model.scopes, tuple(good_value))
|
||||
|
||||
# Check invalid types are rejected.
|
||||
for bad_value in (
|
||||
"",
|
||||
"abc",
|
||||
123,
|
||||
{},
|
||||
{"a": 1},
|
||||
None,
|
||||
[None],
|
||||
[["a"]],
|
||||
[{}],
|
||||
[456],
|
||||
):
|
||||
with self.assertRaises(ValidationError):
|
||||
self.config["scopes"] = bad_value
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
|
||||
# Check that "scopes" may be omitted.
|
||||
del self.config["scopes"]
|
||||
model = OIDCProviderModel.parse_obj(self.config)
|
||||
self.assertEqual(model.scopes, ("openid",))
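
(Sketch of a Tuple[StrictStr, ...] field with a default, as the test above suggests; names and the default value here are assumptions.)

from typing import Tuple
from pydantic import BaseModel, StrictStr

class ScopesSketch(BaseModel):
    scopes: Tuple[StrictStr, ...] = ("openid",)

assert ScopesSketch().scopes == ("openid",)                      # default applies
assert ScopesSketch(scopes=["email", "openid"]).scopes == ("email", "openid")
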
|
||||
|
||||
@parameterized.expand(["authorization_endpoint", "token_endpoint"])
|
||||
def test_endpoints_required_when_discovery_disabled(self, key: str) -> None:
|
||||
"""Example of a validator that applies to multiple fields."""
|
||||
# Test that this field is required if discovery is disabled
|
||||
self.config["discover"] = False
|
||||
with self.assertRaises(ValidationError):
|
||||
self.config[key] = None
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
with self.assertRaises(ValidationError):
|
||||
del self.config[key]
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
# We don't validate that the endpoint is a sensible URL; anything str will do
|
||||
self.config[key] = "blahblah"
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
|
||||
def check_all_cases_pass():
|
||||
self.config[key] = None
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
|
||||
del self.config[key]
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
|
||||
self.config[key] = "blahblah"
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
|
||||
# With discovery enabled, all three cases are accepted.
|
||||
self.config["discover"] = True
|
||||
check_all_cases_pass()
|
||||
|
||||
# If not specified, discovery is also on by default.
|
||||
del self.config["discover"]
|
||||
check_all_cases_pass()
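
(One way to attach a single validator to both endpoint fields, sketched with pydantic v1; this is an assumption about the shape of the real validator, not its actual code.)

from typing import Optional
from pydantic import BaseModel, StrictBool, StrictStr, validator

class EndpointsSketch(BaseModel):
    discover: StrictBool = True
    authorization_endpoint: Optional[StrictStr] = None
    token_endpoint: Optional[StrictStr] = None

    # One validator shared by both endpoint fields.
    @validator("authorization_endpoint", "token_endpoint", always=True)
    def _required_without_discovery(cls, value, values):
        if not values.get("discover", True) and value is None:
            raise ValueError("required when OIDC discovery is disabled")
        return value

EndpointsSketch()  # fine: discovery on, endpoints may be omitted
EndpointsSketch(discover=False, authorization_endpoint="a", token_endpoint="b")  # fine
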
|
||||
|
||||
def test_userinfo_endpoint(self) -> None:
|
||||
"""Example of a more fiddly validator"""
|
||||
# This field is required if discovery is disabled and the openid scope
|
||||
# not requested.
|
||||
self.assertNotIn("userinfo_endpoint", self.config)
|
||||
with self.assertRaises(ValidationError):
|
||||
self.config["discover"] = False
|
||||
self.config["scopes"] = ()
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
|
||||
# Still an error even if other scopes are provided
|
||||
with self.assertRaises(ValidationError):
|
||||
self.config["discover"] = False
|
||||
self.config["scopes"] = ("potato", "tomato")
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
|
||||
# Passing an explicit None for userinfo_endpoint should also be an error.
|
||||
with self.assertRaises(ValidationError):
|
||||
self.config["discover"] = False
|
||||
self.config["scopes"] = ()
|
||||
self.config["userinfo_endpoint"] = None
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
|
||||
# No error if we enable discovery.
|
||||
self.config["discover"] = True
|
||||
self.config["scopes"] = ()
|
||||
self.config["userinfo_endpoint"] = None
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
|
||||
# No error if we enable the openid scope.
|
||||
self.config["discover"] = False
|
||||
self.config["scopes"] = ("openid",)
|
||||
self.config["userinfo_endpoint"] = None
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
|
||||
# No error if we don't specify scopes. (They default to `("openid", )`)
|
||||
self.config["discover"] = False
|
||||
del self.config["scopes"]
|
||||
self.config["userinfo_endpoint"] = None
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
|
||||
def test_attribute_requirements(self):
|
||||
# Example of a field involving a nested model
|
||||
model = OIDCProviderModel.parse_obj(self.config)
|
||||
self.assertIsInstance(model.attribute_requirements, tuple)
|
||||
self.assertEqual(
|
||||
len(model.attribute_requirements), 1, model.attribute_requirements
|
||||
)
|
||||
|
||||
# Bad types should be rejected
|
||||
bad_value: object
|
||||
for bad_value in 123, 456.0, False, None, {}, ["hello"]:
|
||||
with self.assertRaises(ValidationError):
|
||||
self.config["attribute_requirements"] = bad_value
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
|
||||
# An empty list of requirements is okay, ...
|
||||
self.config["attribute_requirements"] = []
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
|
||||
# ...as is an omitted list of requirements...
|
||||
del self.config["attribute_requirements"]
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
|
||||
# ...but not an explicit None.
|
||||
with self.assertRaises(ValidationError):
|
||||
self.config["attribute_requirements"] = None
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
|
||||
# Multiple requirements are fine.
|
||||
self.config["attribute_requirements"] = [{"attribute": "k", "value": "v"}] * 3
|
||||
model = OIDCProviderModel.parse_obj(self.config)
|
||||
self.assertEqual(
|
||||
len(model.attribute_requirements), 3, model.attribute_requirements
|
||||
)
|
||||
|
||||
# The submodel's field types should be enforced too.
|
||||
with self.assertRaises(ValidationError):
|
||||
self.config["attribute_requirements"] = [{"attribute": "key", "value": 123}]
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
with self.assertRaises(ValidationError):
|
||||
self.config["attribute_requirements"] = [{"attribute": 123, "value": "val"}]
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
with self.assertRaises(ValidationError):
|
||||
self.config["attribute_requirements"] = [{"attribute": "a", "value": ["b"]}]
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
with self.assertRaises(ValidationError):
|
||||
self.config["attribute_requirements"] = [{"attribute": "a", "value": None}]
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
|
||||
# Missing fields in the submodel are an error.
|
||||
with self.assertRaises(ValidationError):
|
||||
self.config["attribute_requirements"] = [{"attribute": "a"}]
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
with self.assertRaises(ValidationError):
|
||||
self.config["attribute_requirements"] = [{"value": "v"}]
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
with self.assertRaises(ValidationError):
|
||||
self.config["attribute_requirements"] = [{}]
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
|
||||
# Extra fields in the submodel are an error.
|
||||
with self.assertRaises(ValidationError):
|
||||
self.config["attribute_requirements"] = [
|
||||
{"attribute": "a", "value": "v", "answer": "forty-two"}
|
||||
]
|
||||
OIDCProviderModel.parse_obj(self.config)
|
||||
@@ -17,7 +17,7 @@ from typing import Dict, List, Tuple

 from synapse.api.errors import Codes
 from synapse.federation.transport.server import BaseFederationServlet
-from synapse.federation.transport.server._base import Authenticator, _parse_auth_header
+from synapse.federation.transport.server._base import Authenticator
 from synapse.http.server import JsonResource, cancellable
 from synapse.server import HomeServer
 from synapse.types import JsonDict
@@ -112,30 +112,3 @@ class BaseFederationServletCancellationTests(
|
||||
expect_cancellation=False,
|
||||
expected_body={"result": True},
|
||||
)
|
||||
|
||||
|
||||
class BaseFederationAuthorizationTests(unittest.TestCase):
|
||||
def test_authorization_header(self) -> None:
|
||||
"""Tests that the Authorization header is parsed correctly."""
|
||||
|
||||
# test a "normal" Authorization header
|
||||
self.assertEqual(
|
||||
_parse_auth_header(
|
||||
b'X-Matrix origin=foo,key="ed25519:1",sig="sig",destination="bar"'
|
||||
),
|
||||
("foo", "ed25519:1", "sig", "bar"),
|
||||
)
|
||||
# test an Authorization with extra spaces, upper-case names, and escaped
|
||||
# characters
|
||||
self.assertEqual(
|
||||
_parse_auth_header(
|
||||
b'X-Matrix ORIGIN=foo,KEY="ed25\\519:1",SIG="sig",destination="bar"'
|
||||
),
|
||||
("foo", "ed25519:1", "sig", "bar"),
|
||||
)
|
||||
self.assertEqual(
|
||||
_parse_auth_header(
|
||||
b'X-Matrix origin=foo,key="ed25519:1",sig="sig",destination="bar",extra_field=ignored'
|
||||
),
|
||||
("foo", "ed25519:1", "sig", "bar"),
|
||||
)
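
(The test above exercises parsing of the X-Matrix Authorization scheme. The helper below is a simplified, illustrative parser only — it is not Synapse's _parse_auth_header — showing the general shape of the problem: case-insensitive parameter names, optional quoting, backslash escapes, and extra parameters being ignored.)

import re
from typing import Dict, Tuple

_PARAM_RE = re.compile(r'\s*([A-Za-z_]+)=(?:"((?:\\.|[^"])*)"|([^,]*))(?:,|$)')

def parse_x_matrix_header(header: bytes) -> Tuple[str, str, str, str]:
    scheme, _, rest = header.decode("ascii").partition(" ")
    if scheme != "X-Matrix":
        raise ValueError("unsupported scheme")
    params: Dict[str, str] = {}
    for name, quoted, bare in _PARAM_RE.findall(rest):
        # Parameter names are case-insensitive; backslash escapes are unwrapped.
        params[name.lower()] = (quoted or bare).replace("\\", "")
    return params["origin"], params["key"], params["sig"], params["destination"]

assert parse_x_matrix_header(
    b'X-Matrix origin=foo,key="ed25519:1",sig="sig",destination="bar"'
) == ("foo", "ed25519:1", "sig", "bar")
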
|
||||
|
||||
@@ -434,6 +434,16 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
|
||||
},
|
||||
)
|
||||
|
||||
# "Complete" a transaction.
|
||||
# All this really does for us is make an entry in the application_services_state
|
||||
# database table, which tracks the current stream_token per stream ID per AS.
|
||||
self.get_success(
|
||||
self.hs.get_datastores().main.complete_appservice_txn(
|
||||
0,
|
||||
interested_appservice,
|
||||
)
|
||||
)
|
||||
|
||||
# Now, pretend that we receive a large burst of read receipts (300 total) that
|
||||
# all come in at once.
|
||||
for i in range(300):
|
||||
|
||||
@@ -332,7 +332,6 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
|
||||
most_recent_prev_event_depth,
|
||||
) = self.get_success(self.store.get_max_depth_of(prev_event_ids))
|
||||
# mapping from (type, state_key) -> state_event_id
|
||||
assert most_recent_prev_event_id is not None
|
||||
prev_state_map = self.get_success(
|
||||
self.state_store.get_state_ids_for_event(most_recent_prev_event_id)
|
||||
)
|
||||
|
||||
@@ -2489,5 +2489,4 @@ PURGE_TABLES = [
|
||||
"room_tags",
|
||||
# "state_groups", # Current impl leaves orphaned state groups around.
|
||||
"state_groups_state",
|
||||
"federation_inbound_events_staging",
|
||||
]
|
||||
|
||||
@@ -14,7 +14,7 @@
 import json
 import os
 import tempfile
-from typing import List, cast
+from typing import List, Optional, cast
 from unittest.mock import Mock

 import yaml
@@ -149,12 +149,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
             outfile.write(yaml.dump(as_yaml))
         self.as_yaml_files.append(as_token)

-    def _set_state(self, id: str, state: ApplicationServiceState):
+    def _set_state(
+        self, id: str, state: ApplicationServiceState, txn: Optional[int] = None
+    ):
         return self.db_pool.runOperation(
             self.engine.convert_param_style(
-                "INSERT INTO application_services_state(as_id, state) VALUES(?,?)"
+                "INSERT INTO application_services_state(as_id, state, last_txn) "
+                "VALUES(?,?,?)"
             ),
-            (id, state.value),
+            (id, state.value, txn),
         )

     def _insert_txn(self, as_id, txn_id, events):
@@ -277,6 +280,17 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
|
||||
self.store.complete_appservice_txn(txn_id=txn_id, service=service)
|
||||
)
|
||||
|
||||
res = self.get_success(
|
||||
self.db_pool.runQuery(
|
||||
self.engine.convert_param_style(
|
||||
"SELECT last_txn FROM application_services_state WHERE as_id=?"
|
||||
),
|
||||
(service.id,),
|
||||
)
|
||||
)
|
||||
self.assertEqual(1, len(res))
|
||||
self.assertEqual(txn_id, res[0][0])
|
||||
|
||||
res = self.get_success(
|
||||
self.db_pool.runQuery(
|
||||
self.engine.convert_param_style(
|
||||
@@ -302,13 +316,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
         res = self.get_success(
             self.db_pool.runQuery(
                 self.engine.convert_param_style(
-                    "SELECT state FROM application_services_state WHERE as_id=?"
+                    "SELECT last_txn, state FROM application_services_state WHERE as_id=?"
                 ),
                 (service.id,),
             )
         )
         self.assertEqual(1, len(res))
-        self.assertEqual(ApplicationServiceState.UP.value, res[0][0])
+        self.assertEqual(txn_id, res[0][0])
+        self.assertEqual(ApplicationServiceState.UP.value, res[0][1])

         res = self.get_success(
             self.db_pool.runQuery(
|
||||
@@ -129,19 +129,6 @@ class _DummyStore:
|
||||
async def get_room_version_id(self, room_id):
|
||||
return RoomVersions.V1.identifier
|
||||
|
||||
async def get_state_group_for_events(self, event_ids):
|
||||
res = {}
|
||||
for event in event_ids:
|
||||
res[event] = self._event_to_state_group[event]
|
||||
return res
|
||||
|
||||
async def get_state_for_groups(self, groups):
|
||||
res = {}
|
||||
for group in groups:
|
||||
state = self._group_to_state[group]
|
||||
res[group] = state
|
||||
return res
|
||||
|
||||
|
||||
class DictObj(dict):
|
||||
def __init__(self, **kwargs):
|
||||
|
||||