mirror of
https://github.com/element-hq/synapse.git
synced 2025-12-15 02:00:21 +00:00
Compare commits
43 Commits
v1.85.2
...
shay/fix_d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9cd13d0f26 | ||
|
|
b4cc31d906 | ||
|
|
d69d109ad6 | ||
|
|
a317ccbc7a | ||
|
|
7bc2aefe92 | ||
|
|
67f152b476 | ||
|
|
d0c4257f14 | ||
|
|
e0f2429d13 | ||
|
|
30a5076da8 | ||
|
|
8af29155ec | ||
|
|
5ed0e8c61f | ||
|
|
d1693f0362 | ||
|
|
0b5f64ff09 | ||
|
|
6f18812bb0 | ||
|
|
874378c052 | ||
|
|
daf3a67908 | ||
|
|
c01343de43 | ||
|
|
6fc3deb029 | ||
|
|
ceb3dd77db | ||
|
|
32a2f05004 | ||
|
|
f739bde962 | ||
|
|
98afc57d59 | ||
|
|
14a5be9c4d | ||
|
|
ec9379d7e2 | ||
|
|
e343125b38 | ||
|
|
4d0231b364 | ||
|
|
c008b44b4f | ||
|
|
bad1f2cd35 | ||
|
|
249f4a338d | ||
|
|
03920bdd4e | ||
|
|
31691d6151 | ||
|
|
5fe96082d0 | ||
|
|
28a9663bdf | ||
|
|
a1374b5c70 | ||
|
|
d20669971a | ||
|
|
f9cd549f64 | ||
|
|
7628dbf4e9 | ||
|
|
c5cf1b421d | ||
|
|
e82ec6d008 | ||
|
|
8f576aa462 | ||
|
|
765244faee | ||
|
|
e2c8458bba | ||
|
|
5d8c659373 |
1
.github/workflows/release-artifacts.yml
vendored
1
.github/workflows/release-artifacts.yml
vendored
@@ -34,6 +34,7 @@ jobs:
|
||||
- id: set-distros
|
||||
run: |
|
||||
# if we're running from a tag, get the full list of distros; otherwise just use debian:sid
|
||||
# NOTE: inside the actual Dockerfile-dhvirtualenv, the image name is expanded into its full image path
|
||||
dists='["debian:sid"]'
|
||||
if [[ $GITHUB_REF == refs/tags/* ]]; then
|
||||
dists=$(scripts-dev/build_debian_packages.py --show-dists-json)
|
||||
|
||||
21
.github/workflows/tests.yml
vendored
21
.github/workflows/tests.yml
vendored
@@ -45,16 +45,6 @@ jobs:
|
||||
- run: poetry run scripts-dev/generate_sample_config.sh --check
|
||||
- run: poetry run scripts-dev/config-lint.sh
|
||||
|
||||
check-schema-delta:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.x"
|
||||
- run: "pip install 'click==8.1.1' 'GitPython>=3.1.20'"
|
||||
- run: scripts-dev/check_schema_delta.py --force-colors
|
||||
|
||||
check-lockfile:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
@@ -221,7 +211,6 @@ jobs:
|
||||
- lint-newsfile
|
||||
- lint-pydantic
|
||||
- check-sampleconfig
|
||||
- check-schema-delta
|
||||
- check-lockfile
|
||||
- lint-clippy
|
||||
- lint-rustfmt
|
||||
@@ -609,6 +598,16 @@ jobs:
|
||||
|
||||
- run: cargo bench --no-run
|
||||
|
||||
check-schema-delta:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.x"
|
||||
- run: "pip install 'click==8.1.1' 'GitPython>=3.1.20'"
|
||||
- run: scripts-dev/check_schema_delta.py --force-colors
|
||||
|
||||
# a job which marks all the other jobs as complete, thus allowing PRs to be merged.
|
||||
tests-done:
|
||||
if: ${{ always() }}
|
||||
|
||||
1
changelog.d/14213.misc
Normal file
1
changelog.d/14213.misc
Normal file
@@ -0,0 +1 @@
|
||||
Log when events are (maybe unexpectedly) filtered out of responses in tests.
|
||||
1
changelog.d/15388.feature
Normal file
1
changelog.d/15388.feature
Normal file
@@ -0,0 +1 @@
|
||||
Stable support for [MSC3882](https://github.com/matrix-org/matrix-spec-proposals/pull/3882) to allow an existing device/session to generate a login token for use on a new device/session.
|
||||
1
changelog.d/15450.feature
Normal file
1
changelog.d/15450.feature
Normal file
@@ -0,0 +1 @@
|
||||
Support resolving a room's [canonical alias](https://spec.matrix.org/v1.7/client-server-api/#mroomcanonical_alias) via the module API.
|
||||
1
changelog.d/15582.feature
Normal file
1
changelog.d/15582.feature
Normal file
@@ -0,0 +1 @@
|
||||
Experimental [MSC3861](https://github.com/matrix-org/matrix-spec-proposals/pull/3861) support: delegate auth to an OIDC provider.
|
||||
1
changelog.d/15649.misc
Normal file
1
changelog.d/15649.misc
Normal file
@@ -0,0 +1 @@
|
||||
Read from column `full_user_id` rather than `user_id` of tables `profiles` and `user_filters`.
|
||||
1
changelog.d/15674.feature
Normal file
1
changelog.d/15674.feature
Normal file
@@ -0,0 +1 @@
|
||||
Add Syanpse version deploy annotations to Grafana dashboard which enables easy correlation between behavior changes witnessed in a graph to a certain Synapse version and nail down regressions.
|
||||
1
changelog.d/15675.misc
Normal file
1
changelog.d/15675.misc
Normal file
@@ -0,0 +1 @@
|
||||
Cache requests for user's devices over federation.
|
||||
1
changelog.d/15689.misc
Normal file
1
changelog.d/15689.misc
Normal file
@@ -0,0 +1 @@
|
||||
Add fully qualified docker image names to Dockerfiles.
|
||||
1
changelog.d/15690.misc
Normal file
1
changelog.d/15690.misc
Normal file
@@ -0,0 +1 @@
|
||||
Remove some unused code.
|
||||
1
changelog.d/15694.misc
Normal file
1
changelog.d/15694.misc
Normal file
@@ -0,0 +1 @@
|
||||
Improve type hints.
|
||||
1
changelog.d/15697.misc
Normal file
1
changelog.d/15697.misc
Normal file
@@ -0,0 +1 @@
|
||||
Improve type hints.
|
||||
1
changelog.d/15705.feature
Normal file
1
changelog.d/15705.feature
Normal file
@@ -0,0 +1 @@
|
||||
Add a catch-all * to the supported relation types when redacting an event and its related events. This is an update to [MSC3912](https://github.com/matrix-org/matrix-spec-proposals/pull/3861) implementation.
|
||||
1
changelog.d/15724.bugfix
Normal file
1
changelog.d/15724.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix missing dependencies in background jobs.
|
||||
File diff suppressed because it is too large
Load Diff
@@ -27,7 +27,7 @@ ARG PYTHON_VERSION=3.11
|
||||
###
|
||||
# We hardcode the use of Debian bullseye here because this could change upstream
|
||||
# and other Dockerfiles used for testing are expecting bullseye.
|
||||
FROM docker.io/python:${PYTHON_VERSION}-slim-bullseye as requirements
|
||||
FROM docker.io/library/python:${PYTHON_VERSION}-slim-bullseye as requirements
|
||||
|
||||
# RUN --mount is specific to buildkit and is documented at
|
||||
# https://github.com/moby/buildkit/blob/master/frontend/dockerfile/docs/syntax.md#build-mounts-run---mount.
|
||||
@@ -87,7 +87,7 @@ RUN if [ -z "$TEST_ONLY_IGNORE_POETRY_LOCKFILE" ]; then \
|
||||
###
|
||||
### Stage 1: builder
|
||||
###
|
||||
FROM docker.io/python:${PYTHON_VERSION}-slim-bullseye as builder
|
||||
FROM docker.io/library/python:${PYTHON_VERSION}-slim-bullseye as builder
|
||||
|
||||
# install the OS build deps
|
||||
RUN \
|
||||
@@ -158,7 +158,7 @@ RUN --mount=type=cache,target=/synapse/target,sharing=locked \
|
||||
### Stage 2: runtime
|
||||
###
|
||||
|
||||
FROM docker.io/python:${PYTHON_VERSION}-slim-bullseye
|
||||
FROM docker.io/library/python:${PYTHON_VERSION}-slim-bullseye
|
||||
|
||||
LABEL org.opencontainers.image.url='https://matrix.org/docs/projects/server/synapse'
|
||||
LABEL org.opencontainers.image.documentation='https://github.com/matrix-org/synapse/blob/master/docker/README.md'
|
||||
|
||||
@@ -24,7 +24,7 @@ ARG distro=""
|
||||
# https://launchpad.net/~jyrki-pulliainen/+archive/ubuntu/dh-virtualenv, but
|
||||
# it's not obviously easier to use that than to build our own.)
|
||||
|
||||
FROM ${distro} as builder
|
||||
FROM docker.io/library/${distro} as builder
|
||||
|
||||
RUN apt-get update -qq -o Acquire::Languages=none
|
||||
RUN env DEBIAN_FRONTEND=noninteractive apt-get install \
|
||||
@@ -55,7 +55,7 @@ RUN cd /dh-virtualenv && DEB_BUILD_OPTIONS=nodoc dpkg-buildpackage -us -uc -b
|
||||
###
|
||||
### Stage 1
|
||||
###
|
||||
FROM ${distro}
|
||||
FROM docker.io/library/${distro}
|
||||
|
||||
# Get the distro we want to pull from as a dynamic build variable
|
||||
# (We need to define it in each build stage)
|
||||
|
||||
@@ -7,7 +7,7 @@ ARG FROM=matrixdotorg/synapse:$SYNAPSE_VERSION
|
||||
# target image. For repeated rebuilds, this is much faster than apt installing
|
||||
# each time.
|
||||
|
||||
FROM debian:bullseye-slim AS deps_base
|
||||
FROM docker.io/library/debian:bullseye-slim AS deps_base
|
||||
RUN \
|
||||
--mount=type=cache,target=/var/cache/apt,sharing=locked \
|
||||
--mount=type=cache,target=/var/lib/apt,sharing=locked \
|
||||
@@ -21,7 +21,7 @@ FROM debian:bullseye-slim AS deps_base
|
||||
# which makes it much easier to copy (but we need to make sure we use an image
|
||||
# based on the same debian version as the synapse image, to make sure we get
|
||||
# the expected version of libc.
|
||||
FROM redis:6-bullseye AS redis_base
|
||||
FROM docker.io/library/redis:6-bullseye AS redis_base
|
||||
|
||||
# now build the final image, based on the the regular Synapse docker image
|
||||
FROM $FROM
|
||||
|
||||
@@ -73,7 +73,8 @@ The following environment variables are supported in `generate` mode:
|
||||
will log sensitive information such as access tokens.
|
||||
This should not be needed unless you are a developer attempting to debug something
|
||||
particularly tricky.
|
||||
|
||||
* `SYNAPSE_LOG_TESTING`: if set, Synapse will log additional information useful
|
||||
for testing.
|
||||
|
||||
## Postgres
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
# https://github.com/matrix-org/synapse/blob/develop/docker/README-testing.md#testing-with-postgresql-and-single-or-multi-process-synapse
|
||||
|
||||
ARG SYNAPSE_VERSION=latest
|
||||
# This is an intermediate image, to be built locally (not pulled from a registry).
|
||||
ARG FROM=matrixdotorg/synapse-workers:$SYNAPSE_VERSION
|
||||
|
||||
FROM $FROM
|
||||
@@ -19,8 +20,8 @@ FROM $FROM
|
||||
# the same debian version as Synapse's docker image (so the versions of the
|
||||
# shared libraries match).
|
||||
RUN adduser --system --uid 999 postgres --home /var/lib/postgresql
|
||||
COPY --from=postgres:13-bullseye /usr/lib/postgresql /usr/lib/postgresql
|
||||
COPY --from=postgres:13-bullseye /usr/share/postgresql /usr/share/postgresql
|
||||
COPY --from=docker.io/library/postgres:13-bullseye /usr/lib/postgresql /usr/lib/postgresql
|
||||
COPY --from=docker.io/library/postgres:13-bullseye /usr/share/postgresql /usr/share/postgresql
|
||||
RUN mkdir /var/run/postgresql && chown postgres /var/run/postgresql
|
||||
ENV PATH="${PATH}:/usr/lib/postgresql/13/bin"
|
||||
ENV PGDATA=/var/lib/postgresql/data
|
||||
|
||||
@@ -49,17 +49,35 @@ handlers:
|
||||
class: logging.StreamHandler
|
||||
formatter: precise
|
||||
|
||||
{% if not SYNAPSE_LOG_SENSITIVE %}
|
||||
{#
|
||||
If SYNAPSE_LOG_SENSITIVE is unset, then override synapse.storage.SQL to INFO
|
||||
so that DEBUG entries (containing sensitive information) are not emitted.
|
||||
#}
|
||||
loggers:
|
||||
# This is just here so we can leave `loggers` in the config regardless of whether
|
||||
# we configure other loggers below (avoid empty yaml dict error).
|
||||
_placeholder:
|
||||
level: "INFO"
|
||||
|
||||
{% if not SYNAPSE_LOG_SENSITIVE %}
|
||||
{#
|
||||
If SYNAPSE_LOG_SENSITIVE is unset, then override synapse.storage.SQL to INFO
|
||||
so that DEBUG entries (containing sensitive information) are not emitted.
|
||||
#}
|
||||
synapse.storage.SQL:
|
||||
# beware: increasing this to DEBUG will make synapse log sensitive
|
||||
# information such as access tokens.
|
||||
level: INFO
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
|
||||
{% if SYNAPSE_LOG_TESTING %}
|
||||
{#
|
||||
If Synapse is under test, log a few more useful things for a developer
|
||||
attempting to debug something particularly tricky.
|
||||
|
||||
With `synapse.visibility.filtered_event_debug`, it logs when events are (maybe
|
||||
unexpectedly) filtered out of responses in tests. It's just nice to be able to
|
||||
look at the CI log and figure out why an event isn't being returned.
|
||||
#}
|
||||
synapse.visibility.filtered_event_debug:
|
||||
level: DEBUG
|
||||
{% endif %}
|
||||
|
||||
root:
|
||||
level: {{ SYNAPSE_LOG_LEVEL or "INFO" }}
|
||||
|
||||
@@ -40,6 +40,8 @@
|
||||
# log level. INFO is the default.
|
||||
# * SYNAPSE_LOG_SENSITIVE: If unset, SQL and SQL values won't be logged,
|
||||
# regardless of the SYNAPSE_LOG_LEVEL setting.
|
||||
# * SYNAPSE_LOG_TESTING: if set, Synapse will log additional information useful
|
||||
# for testing.
|
||||
#
|
||||
# NOTE: According to Complement's ENTRYPOINT expectations for a homeserver image (as defined
|
||||
# in the project's README), this script may be run multiple times, and functionality should
|
||||
@@ -947,6 +949,7 @@ def generate_worker_log_config(
|
||||
extra_log_template_args["SYNAPSE_LOG_SENSITIVE"] = environ.get(
|
||||
"SYNAPSE_LOG_SENSITIVE"
|
||||
)
|
||||
extra_log_template_args["SYNAPSE_LOG_TESTING"] = environ.get("SYNAPSE_LOG_TESTING")
|
||||
|
||||
# Render and write the file
|
||||
log_config_filepath = f"/conf/workers/{worker_name}.log.config"
|
||||
|
||||
@@ -10,7 +10,7 @@ ARG PYTHON_VERSION=3.9
|
||||
###
|
||||
# We hardcode the use of Debian bullseye here because this could change upstream
|
||||
# and other Dockerfiles used for testing are expecting bullseye.
|
||||
FROM docker.io/python:${PYTHON_VERSION}-slim-bullseye
|
||||
FROM docker.io/library/python:${PYTHON_VERSION}-slim-bullseye
|
||||
|
||||
# Install Rust and other dependencies (stolen from normal Dockerfile)
|
||||
# install the OS build deps
|
||||
|
||||
@@ -2570,7 +2570,50 @@ Example configuration:
|
||||
```yaml
|
||||
nonrefreshable_access_token_lifetime: 24h
|
||||
```
|
||||
---
|
||||
### `ui_auth`
|
||||
|
||||
The amount of time to allow a user-interactive authentication session to be active.
|
||||
|
||||
This defaults to 0, meaning the user is queried for their credentials
|
||||
before every action, but this can be overridden to allow a single
|
||||
validation to be re-used. This weakens the protections afforded by
|
||||
the user-interactive authentication process, by allowing for multiple
|
||||
(and potentially different) operations to use the same validation session.
|
||||
|
||||
This is ignored for potentially "dangerous" operations (including
|
||||
deactivating an account, modifying an account password, adding a 3PID,
|
||||
and minting additional login tokens).
|
||||
|
||||
Use the `session_timeout` sub-option here to change the time allowed for credential validation.
|
||||
|
||||
Example configuration:
|
||||
```yaml
|
||||
ui_auth:
|
||||
session_timeout: "15s"
|
||||
```
|
||||
---
|
||||
### `login_via_existing_session`
|
||||
|
||||
Matrix supports the ability of an existing session to mint a login token for
|
||||
another client.
|
||||
|
||||
Synapse disables this by default as it has security ramifications -- a malicious
|
||||
client could use the mechanism to spawn more than one session.
|
||||
|
||||
The duration of time the generated token is valid for can be configured with the
|
||||
`token_timeout` sub-option.
|
||||
|
||||
User-interactive authentication is required when this is enabled unless the
|
||||
`require_ui_auth` sub-option is set to `False`.
|
||||
|
||||
Example configuration:
|
||||
```yaml
|
||||
login_via_existing_session:
|
||||
enabled: true
|
||||
require_ui_auth: false
|
||||
token_timeout: "5m"
|
||||
```
|
||||
---
|
||||
## Metrics
|
||||
Config options related to metrics.
|
||||
@@ -3415,28 +3458,6 @@ password_config:
|
||||
require_uppercase: true
|
||||
```
|
||||
---
|
||||
### `ui_auth`
|
||||
|
||||
The amount of time to allow a user-interactive authentication session to be active.
|
||||
|
||||
This defaults to 0, meaning the user is queried for their credentials
|
||||
before every action, but this can be overridden to allow a single
|
||||
validation to be re-used. This weakens the protections afforded by
|
||||
the user-interactive authentication process, by allowing for multiple
|
||||
(and potentially different) operations to use the same validation session.
|
||||
|
||||
This is ignored for potentially "dangerous" operations (including
|
||||
deactivating an account, modifying an account password, and
|
||||
adding a 3PID).
|
||||
|
||||
Use the `session_timeout` sub-option here to change the time allowed for credential validation.
|
||||
|
||||
Example configuration:
|
||||
```yaml
|
||||
ui_auth:
|
||||
session_timeout: "15s"
|
||||
```
|
||||
---
|
||||
## Push
|
||||
Configuration settings related to push notifications
|
||||
|
||||
|
||||
26
mypy.ini
26
mypy.ini
@@ -2,17 +2,29 @@
|
||||
namespace_packages = True
|
||||
plugins = pydantic.mypy, mypy_zope:plugin, scripts-dev/mypy_synapse_plugin.py
|
||||
follow_imports = normal
|
||||
check_untyped_defs = True
|
||||
show_error_codes = True
|
||||
show_traceback = True
|
||||
mypy_path = stubs
|
||||
warn_unreachable = True
|
||||
warn_unused_ignores = True
|
||||
local_partial_types = True
|
||||
no_implicit_optional = True
|
||||
|
||||
# Strict checks, see mypy --help
|
||||
warn_unused_configs = True
|
||||
# disallow_any_generics = True
|
||||
disallow_subclassing_any = True
|
||||
# disallow_untyped_calls = True
|
||||
disallow_untyped_defs = True
|
||||
strict_equality = True
|
||||
disallow_incomplete_defs = True
|
||||
# check_untyped_defs = True
|
||||
# disallow_untyped_decorators = True
|
||||
warn_redundant_casts = True
|
||||
warn_unused_ignores = True
|
||||
# warn_return_any = True
|
||||
# no_implicit_reexport = True
|
||||
strict_equality = True
|
||||
strict_concatenate = True
|
||||
|
||||
# Run mypy type checking with the minimum supported Python version to catch new usage
|
||||
# that isn't backwards-compatible (types, overloads, etc).
|
||||
python_version = 3.8
|
||||
@@ -31,6 +43,7 @@ warn_unused_ignores = False
|
||||
|
||||
[mypy-synapse.util.caches.treecache]
|
||||
disallow_untyped_defs = False
|
||||
disallow_incomplete_defs = False
|
||||
|
||||
;; Dependencies without annotations
|
||||
;; Before ignoring a module, check to see if type stubs are available.
|
||||
@@ -40,18 +53,18 @@ disallow_untyped_defs = False
|
||||
;; which we can pull in as a dev dependency by adding to `pyproject.toml`'s
|
||||
;; `[tool.poetry.dev-dependencies]` list.
|
||||
|
||||
# https://github.com/lepture/authlib/issues/460
|
||||
[mypy-authlib.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-ijson.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-lxml]
|
||||
ignore_missing_imports = True
|
||||
|
||||
# https://github.com/msgpack/msgpack-python/issues/448
|
||||
[mypy-msgpack]
|
||||
ignore_missing_imports = True
|
||||
|
||||
# https://github.com/wolever/parameterized/issues/143
|
||||
[mypy-parameterized.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
@@ -73,6 +86,7 @@ ignore_missing_imports = True
|
||||
[mypy-srvlookup.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
# https://github.com/twisted/treq/pull/366
|
||||
[mypy-treq.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
|
||||
25
poetry.lock
generated
25
poetry.lock
generated
@@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "alabaster"
|
||||
@@ -1215,6 +1215,21 @@ html5 = ["html5lib"]
|
||||
htmlsoup = ["BeautifulSoup4"]
|
||||
source = ["Cython (>=0.29.7)"]
|
||||
|
||||
[[package]]
|
||||
name = "lxml-stubs"
|
||||
version = "0.4.0"
|
||||
description = "Type annotations for the lxml package"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "lxml-stubs-0.4.0.tar.gz", hash = "sha256:184877b42127256abc2b932ba8bd0ab5ea80bd0b0fee618d16daa40e0b71abee"},
|
||||
{file = "lxml_stubs-0.4.0-py3-none-any.whl", hash = "sha256:3b381e9e82397c64ea3cc4d6f79d1255d015f7b114806d4826218805c10ec003"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
test = ["coverage[toml] (==5.2)", "pytest (>=6.0.0)", "pytest-mypy-plugins (==1.9.3)"]
|
||||
|
||||
[[package]]
|
||||
name = "markdown-it-py"
|
||||
version = "2.2.0"
|
||||
@@ -3409,22 +3424,22 @@ docs = ["Sphinx", "repoze.sphinx.autointerface"]
|
||||
test = ["zope.i18nmessageid", "zope.testing", "zope.testrunner"]
|
||||
|
||||
[extras]
|
||||
all = ["matrix-synapse-ldap3", "psycopg2", "psycopg2cffi", "psycopg2cffi-compat", "pysaml2", "authlib", "lxml", "sentry-sdk", "jaeger-client", "opentracing", "txredisapi", "hiredis", "Pympler", "pyicu"]
|
||||
all = ["Pympler", "authlib", "hiredis", "jaeger-client", "lxml", "matrix-synapse-ldap3", "opentracing", "psycopg2", "psycopg2cffi", "psycopg2cffi-compat", "pyicu", "pysaml2", "sentry-sdk", "txredisapi"]
|
||||
cache-memory = ["Pympler"]
|
||||
jwt = ["authlib"]
|
||||
matrix-synapse-ldap3 = ["matrix-synapse-ldap3"]
|
||||
oidc = ["authlib"]
|
||||
opentracing = ["jaeger-client", "opentracing"]
|
||||
postgres = ["psycopg2", "psycopg2cffi", "psycopg2cffi-compat"]
|
||||
redis = ["txredisapi", "hiredis"]
|
||||
redis = ["hiredis", "txredisapi"]
|
||||
saml2 = ["pysaml2"]
|
||||
sentry = ["sentry-sdk"]
|
||||
systemd = ["systemd-python"]
|
||||
test = ["parameterized", "idna"]
|
||||
test = ["idna", "parameterized"]
|
||||
url-preview = ["lxml"]
|
||||
user-search = ["pyicu"]
|
||||
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.7.1"
|
||||
content-hash = "ef3a16dd66177f7141239e1a2d3e07cc14c08f1e4e0c5127184d022bc062da52"
|
||||
content-hash = "7ad11e62a675e09444cf33ca2de3216fc4efc5874a2575e54d95d577a52439d3"
|
||||
|
||||
@@ -314,6 +314,7 @@ black = ">=22.3.0"
|
||||
ruff = "0.0.265"
|
||||
|
||||
# Typechecking
|
||||
lxml-stubs = ">=0.4.0"
|
||||
mypy = "*"
|
||||
mypy-zope = "*"
|
||||
types-bleach = ">=4.1.0"
|
||||
|
||||
@@ -20,6 +20,8 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
from types import FrameType
|
||||
from typing import Collection, Optional, Sequence, Set
|
||||
|
||||
# These are expanded inside the dockerfile to be a fully qualified image name.
|
||||
# e.g. docker.io/library/debian:bullseye
|
||||
DISTS = (
|
||||
"debian:buster", # oldstable: EOL 2022-08
|
||||
"debian:bullseye",
|
||||
|
||||
@@ -269,6 +269,10 @@ if [[ -n "$SYNAPSE_TEST_LOG_LEVEL" ]]; then
|
||||
export PASS_SYNAPSE_LOG_SENSITIVE=1
|
||||
fi
|
||||
|
||||
# Log a few more useful things for a developer attempting to debug something
|
||||
# particularly tricky.
|
||||
export PASS_SYNAPSE_LOG_TESTING=1
|
||||
|
||||
# Run the tests!
|
||||
echo "Images built; running complement"
|
||||
cd "$COMPLEMENT_DIR"
|
||||
|
||||
175
synapse/api/auth/__init__.py
Normal file
175
synapse/api/auth/__init__.py
Normal file
@@ -0,0 +1,175 @@
|
||||
# Copyright 2023 The Matrix.org Foundation.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from typing_extensions import Protocol
|
||||
|
||||
from twisted.web.server import Request
|
||||
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.types import Requester
|
||||
|
||||
# guests always get this device id.
|
||||
GUEST_DEVICE_ID = "guest_device"
|
||||
|
||||
|
||||
class Auth(Protocol):
|
||||
"""The interface that an auth provider must implement."""
|
||||
|
||||
async def check_user_in_room(
|
||||
self,
|
||||
room_id: str,
|
||||
requester: Requester,
|
||||
allow_departed_users: bool = False,
|
||||
) -> Tuple[str, Optional[str]]:
|
||||
"""Check if the user is in the room, or was at some point.
|
||||
Args:
|
||||
room_id: The room to check.
|
||||
|
||||
user_id: The user to check.
|
||||
|
||||
current_state: Optional map of the current state of the room.
|
||||
If provided then that map is used to check whether they are a
|
||||
member of the room. Otherwise the current membership is
|
||||
loaded from the database.
|
||||
|
||||
allow_departed_users: if True, accept users that were previously
|
||||
members but have now departed.
|
||||
|
||||
Raises:
|
||||
AuthError if the user is/was not in the room.
|
||||
Returns:
|
||||
The current membership of the user in the room and the
|
||||
membership event ID of the user.
|
||||
"""
|
||||
|
||||
async def get_user_by_req(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
allow_guest: bool = False,
|
||||
allow_expired: bool = False,
|
||||
) -> Requester:
|
||||
"""Get a registered user's ID.
|
||||
|
||||
Args:
|
||||
request: An HTTP request with an access_token query parameter.
|
||||
allow_guest: If False, will raise an AuthError if the user making the
|
||||
request is a guest.
|
||||
allow_expired: If True, allow the request through even if the account
|
||||
is expired, or session token lifetime has ended. Note that
|
||||
/login will deliver access tokens regardless of expiration.
|
||||
|
||||
Returns:
|
||||
Resolves to the requester
|
||||
Raises:
|
||||
InvalidClientCredentialsError if no user by that token exists or the token
|
||||
is invalid.
|
||||
AuthError if access is denied for the user in the access token
|
||||
"""
|
||||
|
||||
async def validate_appservice_can_control_user_id(
|
||||
self, app_service: ApplicationService, user_id: str
|
||||
) -> None:
|
||||
"""Validates that the app service is allowed to control
|
||||
the given user.
|
||||
|
||||
Args:
|
||||
app_service: The app service that controls the user
|
||||
user_id: The author MXID that the app service is controlling
|
||||
|
||||
Raises:
|
||||
AuthError: If the application service is not allowed to control the user
|
||||
(user namespace regex does not match, wrong homeserver, etc)
|
||||
or if the user has not been registered yet.
|
||||
"""
|
||||
|
||||
async def get_user_by_access_token(
|
||||
self,
|
||||
token: str,
|
||||
allow_expired: bool = False,
|
||||
) -> Requester:
|
||||
"""Validate access token and get user_id from it
|
||||
|
||||
Args:
|
||||
token: The access token to get the user by
|
||||
allow_expired: If False, raises an InvalidClientTokenError
|
||||
if the token is expired
|
||||
|
||||
Raises:
|
||||
InvalidClientTokenError if a user by that token exists, but the token is
|
||||
expired
|
||||
InvalidClientCredentialsError if no user by that token exists or the token
|
||||
is invalid
|
||||
"""
|
||||
|
||||
async def is_server_admin(self, requester: Requester) -> bool:
|
||||
"""Check if the given user is a local server admin.
|
||||
|
||||
Args:
|
||||
requester: user to check
|
||||
|
||||
Returns:
|
||||
True if the user is an admin
|
||||
"""
|
||||
|
||||
async def check_can_change_room_list(
|
||||
self, room_id: str, requester: Requester
|
||||
) -> bool:
|
||||
"""Determine whether the user is allowed to edit the room's entry in the
|
||||
published room list.
|
||||
|
||||
Args:
|
||||
room_id
|
||||
user
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def has_access_token(request: Request) -> bool:
|
||||
"""Checks if the request has an access_token.
|
||||
|
||||
Returns:
|
||||
False if no access_token was given, True otherwise.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_access_token_from_request(request: Request) -> str:
|
||||
"""Extracts the access_token from the request.
|
||||
|
||||
Args:
|
||||
request: The http request.
|
||||
Returns:
|
||||
The access_token
|
||||
Raises:
|
||||
MissingClientTokenError: If there isn't a single access_token in the
|
||||
request
|
||||
"""
|
||||
|
||||
async def check_user_in_room_or_world_readable(
|
||||
self, room_id: str, requester: Requester, allow_departed_users: bool = False
|
||||
) -> Tuple[str, Optional[str]]:
|
||||
"""Checks that the user is or was in the room or the room is world
|
||||
readable. If it isn't then an exception is raised.
|
||||
|
||||
Args:
|
||||
room_id: room to check
|
||||
user_id: user to check
|
||||
allow_departed_users: if True, accept users that were previously
|
||||
members but have now departed
|
||||
|
||||
Returns:
|
||||
Resolves to the current membership of the user in the room and the
|
||||
membership event ID of the user. If the user is not in the room and
|
||||
never has been, then `(Membership.JOIN, None)` is returned.
|
||||
"""
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright 2014 - 2016 OpenMarket Ltd
|
||||
# Copyright 2023 The Matrix.org Foundation.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -14,7 +14,6 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
|
||||
import pymacaroons
|
||||
from netaddr import IPAddress
|
||||
|
||||
from twisted.web.server import Request
|
||||
@@ -24,19 +23,11 @@ from synapse.api.constants import EventTypes, HistoryVisibility, Membership
|
||||
from synapse.api.errors import (
|
||||
AuthError,
|
||||
Codes,
|
||||
InvalidClientTokenError,
|
||||
MissingClientTokenError,
|
||||
UnstableSpecAuthError,
|
||||
)
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.http import get_request_user_agent
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.opentracing import (
|
||||
active_span,
|
||||
force_tracing,
|
||||
start_active_span,
|
||||
trace,
|
||||
)
|
||||
from synapse.logging.opentracing import trace
|
||||
from synapse.types import Requester, create_requester
|
||||
from synapse.util.cancellation import cancellable
|
||||
|
||||
@@ -46,26 +37,13 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# guests always get this device id.
|
||||
GUEST_DEVICE_ID = "guest_device"
|
||||
|
||||
|
||||
class Auth:
|
||||
"""
|
||||
This class contains functions for authenticating users of our client-server API.
|
||||
"""
|
||||
class BaseAuth:
|
||||
"""Common base class for all auth implementations."""
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastores().main
|
||||
self._account_validity_handler = hs.get_account_validity_handler()
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self._macaroon_generator = hs.get_macaroon_generator()
|
||||
|
||||
self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips
|
||||
self._track_puppeted_user_ips = hs.config.api.track_puppeted_user_ips
|
||||
self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users
|
||||
|
||||
async def check_user_in_room(
|
||||
self,
|
||||
@@ -119,139 +97,49 @@ class Auth:
|
||||
errcode=Codes.NOT_JOINED,
|
||||
)
|
||||
|
||||
@cancellable
|
||||
async def get_user_by_req(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
allow_guest: bool = False,
|
||||
allow_expired: bool = False,
|
||||
) -> Requester:
|
||||
"""Get a registered user's ID.
|
||||
@trace
|
||||
async def check_user_in_room_or_world_readable(
|
||||
self, room_id: str, requester: Requester, allow_departed_users: bool = False
|
||||
) -> Tuple[str, Optional[str]]:
|
||||
"""Checks that the user is or was in the room or the room is world
|
||||
readable. If it isn't then an exception is raised.
|
||||
|
||||
Args:
|
||||
request: An HTTP request with an access_token query parameter.
|
||||
allow_guest: If False, will raise an AuthError if the user making the
|
||||
request is a guest.
|
||||
allow_expired: If True, allow the request through even if the account
|
||||
is expired, or session token lifetime has ended. Note that
|
||||
/login will deliver access tokens regardless of expiration.
|
||||
room_id: room to check
|
||||
user_id: user to check
|
||||
allow_departed_users: if True, accept users that were previously
|
||||
members but have now departed
|
||||
|
||||
Returns:
|
||||
Resolves to the requester
|
||||
Raises:
|
||||
InvalidClientCredentialsError if no user by that token exists or the token
|
||||
is invalid.
|
||||
AuthError if access is denied for the user in the access token
|
||||
Resolves to the current membership of the user in the room and the
|
||||
membership event ID of the user. If the user is not in the room and
|
||||
never has been, then `(Membership.JOIN, None)` is returned.
|
||||
"""
|
||||
parent_span = active_span()
|
||||
with start_active_span("get_user_by_req"):
|
||||
requester = await self._wrapped_get_user_by_req(
|
||||
request, allow_guest, allow_expired
|
||||
)
|
||||
|
||||
if parent_span:
|
||||
if requester.authenticated_entity in self._force_tracing_for_users:
|
||||
# request tracing is enabled for this user, so we need to force it
|
||||
# tracing on for the parent span (which will be the servlet span).
|
||||
#
|
||||
# It's too late for the get_user_by_req span to inherit the setting,
|
||||
# so we also force it on for that.
|
||||
force_tracing()
|
||||
force_tracing(parent_span)
|
||||
parent_span.set_tag(
|
||||
"authenticated_entity", requester.authenticated_entity
|
||||
)
|
||||
parent_span.set_tag("user_id", requester.user.to_string())
|
||||
if requester.device_id is not None:
|
||||
parent_span.set_tag("device_id", requester.device_id)
|
||||
if requester.app_service is not None:
|
||||
parent_span.set_tag("appservice_id", requester.app_service.id)
|
||||
return requester
|
||||
|
||||
@cancellable
|
||||
async def _wrapped_get_user_by_req(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
allow_guest: bool,
|
||||
allow_expired: bool,
|
||||
) -> Requester:
|
||||
"""Helper for get_user_by_req
|
||||
|
||||
Once get_user_by_req has set up the opentracing span, this does the actual work.
|
||||
"""
|
||||
try:
|
||||
ip_addr = request.getClientAddress().host
|
||||
user_agent = get_request_user_agent(request)
|
||||
|
||||
access_token = self.get_access_token_from_request(request)
|
||||
|
||||
# First check if it could be a request from an appservice
|
||||
requester = await self._get_appservice_user(request)
|
||||
if not requester:
|
||||
# If not, it should be from a regular user
|
||||
requester = await self.get_user_by_access_token(
|
||||
access_token, allow_expired=allow_expired
|
||||
)
|
||||
|
||||
# Deny the request if the user account has expired.
|
||||
# This check is only done for regular users, not appservice ones.
|
||||
if not allow_expired:
|
||||
if await self._account_validity_handler.is_user_expired(
|
||||
requester.user.to_string()
|
||||
):
|
||||
# Raise the error if either an account validity module has determined
|
||||
# the account has expired, or the legacy account validity
|
||||
# implementation is enabled and determined the account has expired
|
||||
raise AuthError(
|
||||
403,
|
||||
"User account has expired",
|
||||
errcode=Codes.EXPIRED_ACCOUNT,
|
||||
)
|
||||
|
||||
if ip_addr and (
|
||||
not requester.app_service or self._track_appservice_user_ips
|
||||
# check_user_in_room will return the most recent membership
|
||||
# event for the user if:
|
||||
# * The user is a non-guest user, and was ever in the room
|
||||
# * The user is a guest user, and has joined the room
|
||||
# else it will throw.
|
||||
return await self.check_user_in_room(
|
||||
room_id, requester, allow_departed_users=allow_departed_users
|
||||
)
|
||||
except AuthError:
|
||||
visibility = await self._storage_controllers.state.get_current_state_event(
|
||||
room_id, EventTypes.RoomHistoryVisibility, ""
|
||||
)
|
||||
if (
|
||||
visibility
|
||||
and visibility.content.get("history_visibility")
|
||||
== HistoryVisibility.WORLD_READABLE
|
||||
):
|
||||
# XXX(quenting): I'm 95% confident that we could skip setting the
|
||||
# device_id to "dummy-device" for appservices, and that the only impact
|
||||
# would be some rows which whould not deduplicate in the 'user_ips'
|
||||
# table during the transition
|
||||
recorded_device_id = (
|
||||
"dummy-device"
|
||||
if requester.device_id is None and requester.app_service is not None
|
||||
else requester.device_id
|
||||
)
|
||||
await self.store.insert_client_ip(
|
||||
user_id=requester.authenticated_entity,
|
||||
access_token=access_token,
|
||||
ip=ip_addr,
|
||||
user_agent=user_agent,
|
||||
device_id=recorded_device_id,
|
||||
)
|
||||
|
||||
# Track also the puppeted user client IP if enabled and the user is puppeting
|
||||
if (
|
||||
requester.user.to_string() != requester.authenticated_entity
|
||||
and self._track_puppeted_user_ips
|
||||
):
|
||||
await self.store.insert_client_ip(
|
||||
user_id=requester.user.to_string(),
|
||||
access_token=access_token,
|
||||
ip=ip_addr,
|
||||
user_agent=user_agent,
|
||||
device_id=requester.device_id,
|
||||
)
|
||||
|
||||
if requester.is_guest and not allow_guest:
|
||||
raise AuthError(
|
||||
403,
|
||||
"Guest access not allowed",
|
||||
errcode=Codes.GUEST_ACCESS_FORBIDDEN,
|
||||
)
|
||||
|
||||
request.requester = requester
|
||||
return requester
|
||||
except KeyError:
|
||||
raise MissingClientTokenError()
|
||||
return Membership.JOIN, None
|
||||
raise AuthError(
|
||||
403,
|
||||
"User %r not in room %s, and room previews are disabled"
|
||||
% (requester.user, room_id),
|
||||
)
|
||||
|
||||
async def validate_appservice_can_control_user_id(
|
||||
self, app_service: ApplicationService, user_id: str
|
||||
@@ -284,184 +172,16 @@ class Auth:
|
||||
403, "Application service has not registered this user (%s)" % user_id
|
||||
)
|
||||
|
||||
@cancellable
|
||||
async def _get_appservice_user(self, request: Request) -> Optional[Requester]:
|
||||
"""
|
||||
Given a request, reads the request parameters to determine:
|
||||
- whether it's an application service that's making this request
|
||||
- what user the application service should be treated as controlling
|
||||
(the user_id URI parameter allows an application service to masquerade
|
||||
any applicable user in its namespace)
|
||||
- what device the application service should be treated as controlling
|
||||
(the device_id[^1] URI parameter allows an application service to masquerade
|
||||
as any device that exists for the relevant user)
|
||||
|
||||
[^1] Unstable and provided by MSC3202.
|
||||
Must use `org.matrix.msc3202.device_id` in place of `device_id` for now.
|
||||
|
||||
Returns:
|
||||
the application service `Requester` of that request
|
||||
|
||||
Postconditions:
|
||||
- The `app_service` field in the returned `Requester` is set
|
||||
- The `user_id` field in the returned `Requester` is either the application
|
||||
service sender or the controlled user set by the `user_id` URI parameter
|
||||
- The returned application service is permitted to control the returned user ID.
|
||||
- The returned device ID, if present, has been checked to be a valid device ID
|
||||
for the returned user ID.
|
||||
"""
|
||||
DEVICE_ID_ARG_NAME = b"org.matrix.msc3202.device_id"
|
||||
|
||||
app_service = self.store.get_app_service_by_token(
|
||||
self.get_access_token_from_request(request)
|
||||
)
|
||||
if app_service is None:
|
||||
return None
|
||||
|
||||
if app_service.ip_range_whitelist:
|
||||
ip_address = IPAddress(request.getClientAddress().host)
|
||||
if ip_address not in app_service.ip_range_whitelist:
|
||||
return None
|
||||
|
||||
# This will always be set by the time Twisted calls us.
|
||||
assert request.args is not None
|
||||
|
||||
if b"user_id" in request.args:
|
||||
effective_user_id = request.args[b"user_id"][0].decode("utf8")
|
||||
await self.validate_appservice_can_control_user_id(
|
||||
app_service, effective_user_id
|
||||
)
|
||||
else:
|
||||
effective_user_id = app_service.sender
|
||||
|
||||
effective_device_id: Optional[str] = None
|
||||
|
||||
if (
|
||||
self.hs.config.experimental.msc3202_device_masquerading_enabled
|
||||
and DEVICE_ID_ARG_NAME in request.args
|
||||
):
|
||||
effective_device_id = request.args[DEVICE_ID_ARG_NAME][0].decode("utf8")
|
||||
# We only just set this so it can't be None!
|
||||
assert effective_device_id is not None
|
||||
device_opt = await self.store.get_device(
|
||||
effective_user_id, effective_device_id
|
||||
)
|
||||
if device_opt is None:
|
||||
# For now, use 400 M_EXCLUSIVE if the device doesn't exist.
|
||||
# This is an open thread of discussion on MSC3202 as of 2021-12-09.
|
||||
raise AuthError(
|
||||
400,
|
||||
f"Application service trying to use a device that doesn't exist ('{effective_device_id}' for {effective_user_id})",
|
||||
Codes.EXCLUSIVE,
|
||||
)
|
||||
|
||||
return create_requester(
|
||||
effective_user_id, app_service=app_service, device_id=effective_device_id
|
||||
)
|
||||
|
||||
async def get_user_by_access_token(
|
||||
self,
|
||||
token: str,
|
||||
allow_expired: bool = False,
|
||||
) -> Requester:
|
||||
"""Validate access token and get user_id from it
|
||||
|
||||
Args:
|
||||
token: The access token to get the user by
|
||||
allow_expired: If False, raises an InvalidClientTokenError
|
||||
if the token is expired
|
||||
|
||||
Raises:
|
||||
InvalidClientTokenError if a user by that token exists, but the token is
|
||||
expired
|
||||
InvalidClientCredentialsError if no user by that token exists or the token
|
||||
is invalid
|
||||
"""
|
||||
|
||||
# First look in the database to see if the access token is present
|
||||
# as an opaque token.
|
||||
user_info = await self.store.get_user_by_access_token(token)
|
||||
if user_info:
|
||||
valid_until_ms = user_info.valid_until_ms
|
||||
if (
|
||||
not allow_expired
|
||||
and valid_until_ms is not None
|
||||
and valid_until_ms < self.clock.time_msec()
|
||||
):
|
||||
# there was a valid access token, but it has expired.
|
||||
# soft-logout the user.
|
||||
raise InvalidClientTokenError(
|
||||
msg="Access token has expired", soft_logout=True
|
||||
)
|
||||
|
||||
# Mark the token as used. This is used to invalidate old refresh
|
||||
# tokens after some time.
|
||||
await self.store.mark_access_token_as_used(user_info.token_id)
|
||||
|
||||
requester = create_requester(
|
||||
user_id=user_info.user_id,
|
||||
access_token_id=user_info.token_id,
|
||||
is_guest=user_info.is_guest,
|
||||
shadow_banned=user_info.shadow_banned,
|
||||
device_id=user_info.device_id,
|
||||
authenticated_entity=user_info.token_owner,
|
||||
)
|
||||
|
||||
return requester
|
||||
|
||||
# If the token isn't found in the database, then it could still be a
|
||||
# macaroon for a guest, so we check that here.
|
||||
try:
|
||||
user_id = self._macaroon_generator.verify_guest_token(token)
|
||||
|
||||
# Guest access tokens are not stored in the database (there can
|
||||
# only be one access token per guest, anyway).
|
||||
#
|
||||
# In order to prevent guest access tokens being used as regular
|
||||
# user access tokens (and hence getting around the invalidation
|
||||
# process), we look up the user id and check that it is indeed
|
||||
# a guest user.
|
||||
#
|
||||
# It would of course be much easier to store guest access
|
||||
# tokens in the database as well, but that would break existing
|
||||
# guest tokens.
|
||||
stored_user = await self.store.get_user_by_id(user_id)
|
||||
if not stored_user:
|
||||
raise InvalidClientTokenError("Unknown user_id %s" % user_id)
|
||||
if not stored_user["is_guest"]:
|
||||
raise InvalidClientTokenError(
|
||||
"Guest access token used for regular user"
|
||||
)
|
||||
|
||||
return create_requester(
|
||||
user_id=user_id,
|
||||
is_guest=True,
|
||||
# all guests get the same device id
|
||||
device_id=GUEST_DEVICE_ID,
|
||||
authenticated_entity=user_id,
|
||||
)
|
||||
except (
|
||||
pymacaroons.exceptions.MacaroonException,
|
||||
TypeError,
|
||||
ValueError,
|
||||
) as e:
|
||||
logger.warning(
|
||||
"Invalid access token in auth: %s %s.",
|
||||
type(e),
|
||||
e,
|
||||
)
|
||||
raise InvalidClientTokenError("Invalid access token passed.")
|
||||
|
||||
async def is_server_admin(self, requester: Requester) -> bool:
|
||||
"""Check if the given user is a local server admin.
|
||||
|
||||
Args:
|
||||
requester: The user making the request, according to the access token.
|
||||
requester: user to check
|
||||
|
||||
Returns:
|
||||
True if the user is an admin
|
||||
"""
|
||||
return await self.store.is_server_admin(requester.user)
|
||||
raise NotImplementedError()
|
||||
|
||||
async def check_can_change_room_list(
|
||||
self, room_id: str, requester: Requester
|
||||
@@ -470,8 +190,8 @@ class Auth:
|
||||
published room list.
|
||||
|
||||
Args:
|
||||
room_id: The room to check.
|
||||
requester: The user making the request, according to the access token.
|
||||
room_id
|
||||
user
|
||||
"""
|
||||
|
||||
is_admin = await self.is_server_admin(requester)
|
||||
@@ -518,7 +238,6 @@ class Auth:
|
||||
return bool(query_params) or bool(auth_headers)
|
||||
|
||||
@staticmethod
|
||||
@cancellable
|
||||
def get_access_token_from_request(request: Request) -> str:
|
||||
"""Extracts the access_token from the request.
|
||||
|
||||
@@ -556,47 +275,77 @@ class Auth:
|
||||
|
||||
return query_params[0].decode("ascii")
|
||||
|
||||
@trace
|
||||
async def check_user_in_room_or_world_readable(
|
||||
self, room_id: str, requester: Requester, allow_departed_users: bool = False
|
||||
) -> Tuple[str, Optional[str]]:
|
||||
"""Checks that the user is or was in the room or the room is world
|
||||
readable. If it isn't then an exception is raised.
|
||||
@cancellable
|
||||
async def get_appservice_user(
|
||||
self, request: Request, access_token: str
|
||||
) -> Optional[Requester]:
|
||||
"""
|
||||
Given a request, reads the request parameters to determine:
|
||||
- whether it's an application service that's making this request
|
||||
- what user the application service should be treated as controlling
|
||||
(the user_id URI parameter allows an application service to masquerade
|
||||
any applicable user in its namespace)
|
||||
- what device the application service should be treated as controlling
|
||||
(the device_id[^1] URI parameter allows an application service to masquerade
|
||||
as any device that exists for the relevant user)
|
||||
|
||||
Args:
|
||||
room_id: The room to check.
|
||||
requester: The user making the request, according to the access token.
|
||||
allow_departed_users: If True, accept users that were previously
|
||||
members but have now departed.
|
||||
[^1] Unstable and provided by MSC3202.
|
||||
Must use `org.matrix.msc3202.device_id` in place of `device_id` for now.
|
||||
|
||||
Returns:
|
||||
Resolves to the current membership of the user in the room and the
|
||||
membership event ID of the user. If the user is not in the room and
|
||||
never has been, then `(Membership.JOIN, None)` is returned.
|
||||
"""
|
||||
the application service `Requester` of that request
|
||||
|
||||
try:
|
||||
# check_user_in_room will return the most recent membership
|
||||
# event for the user if:
|
||||
# * The user is a non-guest user, and was ever in the room
|
||||
# * The user is a guest user, and has joined the room
|
||||
# else it will throw.
|
||||
return await self.check_user_in_room(
|
||||
room_id, requester, allow_departed_users=allow_departed_users
|
||||
Postconditions:
|
||||
- The `app_service` field in the returned `Requester` is set
|
||||
- The `user_id` field in the returned `Requester` is either the application
|
||||
service sender or the controlled user set by the `user_id` URI parameter
|
||||
- The returned application service is permitted to control the returned user ID.
|
||||
- The returned device ID, if present, has been checked to be a valid device ID
|
||||
for the returned user ID.
|
||||
"""
|
||||
DEVICE_ID_ARG_NAME = b"org.matrix.msc3202.device_id"
|
||||
|
||||
app_service = self.store.get_app_service_by_token(access_token)
|
||||
if app_service is None:
|
||||
return None
|
||||
|
||||
if app_service.ip_range_whitelist:
|
||||
ip_address = IPAddress(request.getClientAddress().host)
|
||||
if ip_address not in app_service.ip_range_whitelist:
|
||||
return None
|
||||
|
||||
# This will always be set by the time Twisted calls us.
|
||||
assert request.args is not None
|
||||
|
||||
if b"user_id" in request.args:
|
||||
effective_user_id = request.args[b"user_id"][0].decode("utf8")
|
||||
await self.validate_appservice_can_control_user_id(
|
||||
app_service, effective_user_id
|
||||
)
|
||||
except AuthError:
|
||||
visibility = await self._storage_controllers.state.get_current_state_event(
|
||||
room_id, EventTypes.RoomHistoryVisibility, ""
|
||||
)
|
||||
if (
|
||||
visibility
|
||||
and visibility.content.get("history_visibility")
|
||||
== HistoryVisibility.WORLD_READABLE
|
||||
):
|
||||
return Membership.JOIN, None
|
||||
raise UnstableSpecAuthError(
|
||||
403,
|
||||
"User %s not in room %s, and room previews are disabled"
|
||||
% (requester.user, room_id),
|
||||
errcode=Codes.NOT_JOINED,
|
||||
else:
|
||||
effective_user_id = app_service.sender
|
||||
|
||||
effective_device_id: Optional[str] = None
|
||||
|
||||
if (
|
||||
self.hs.config.experimental.msc3202_device_masquerading_enabled
|
||||
and DEVICE_ID_ARG_NAME in request.args
|
||||
):
|
||||
effective_device_id = request.args[DEVICE_ID_ARG_NAME][0].decode("utf8")
|
||||
# We only just set this so it can't be None!
|
||||
assert effective_device_id is not None
|
||||
device_opt = await self.store.get_device(
|
||||
effective_user_id, effective_device_id
|
||||
)
|
||||
if device_opt is None:
|
||||
# For now, use 400 M_EXCLUSIVE if the device doesn't exist.
|
||||
# This is an open thread of discussion on MSC3202 as of 2021-12-09.
|
||||
raise AuthError(
|
||||
400,
|
||||
f"Application service trying to use a device that doesn't exist ('{effective_device_id}' for {effective_user_id})",
|
||||
Codes.EXCLUSIVE,
|
||||
)
|
||||
|
||||
return create_requester(
|
||||
effective_user_id, app_service=app_service, device_id=effective_device_id
|
||||
)
|
||||
291
synapse/api/auth/internal.py
Normal file
291
synapse/api/auth/internal.py
Normal file
@@ -0,0 +1,291 @@
|
||||
# Copyright 2023 The Matrix.org Foundation.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pymacaroons
|
||||
|
||||
from synapse.api.errors import (
|
||||
AuthError,
|
||||
Codes,
|
||||
InvalidClientTokenError,
|
||||
MissingClientTokenError,
|
||||
)
|
||||
from synapse.http import get_request_user_agent
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.opentracing import active_span, force_tracing, start_active_span
|
||||
from synapse.types import Requester, create_requester
|
||||
from synapse.util.cancellation import cancellable
|
||||
|
||||
from . import GUEST_DEVICE_ID
|
||||
from .base import BaseAuth
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InternalAuth(BaseAuth):
|
||||
"""
|
||||
This class contains functions for authenticating users of our client-server API.
|
||||
"""
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
self.clock = hs.get_clock()
|
||||
self._account_validity_handler = hs.get_account_validity_handler()
|
||||
self._macaroon_generator = hs.get_macaroon_generator()
|
||||
|
||||
self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips
|
||||
self._track_puppeted_user_ips = hs.config.api.track_puppeted_user_ips
|
||||
self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users
|
||||
|
||||
@cancellable
|
||||
async def get_user_by_req(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
allow_guest: bool = False,
|
||||
allow_expired: bool = False,
|
||||
) -> Requester:
|
||||
"""Get a registered user's ID.
|
||||
|
||||
Args:
|
||||
request: An HTTP request with an access_token query parameter.
|
||||
allow_guest: If False, will raise an AuthError if the user making the
|
||||
request is a guest.
|
||||
allow_expired: If True, allow the request through even if the account
|
||||
is expired, or session token lifetime has ended. Note that
|
||||
/login will deliver access tokens regardless of expiration.
|
||||
|
||||
Returns:
|
||||
Resolves to the requester
|
||||
Raises:
|
||||
InvalidClientCredentialsError if no user by that token exists or the token
|
||||
is invalid.
|
||||
AuthError if access is denied for the user in the access token
|
||||
"""
|
||||
parent_span = active_span()
|
||||
with start_active_span("get_user_by_req"):
|
||||
requester = await self._wrapped_get_user_by_req(
|
||||
request, allow_guest, allow_expired
|
||||
)
|
||||
|
||||
if parent_span:
|
||||
if requester.authenticated_entity in self._force_tracing_for_users:
|
||||
# request tracing is enabled for this user, so we need to force it
|
||||
# tracing on for the parent span (which will be the servlet span).
|
||||
#
|
||||
# It's too late for the get_user_by_req span to inherit the setting,
|
||||
# so we also force it on for that.
|
||||
force_tracing()
|
||||
force_tracing(parent_span)
|
||||
parent_span.set_tag(
|
||||
"authenticated_entity", requester.authenticated_entity
|
||||
)
|
||||
parent_span.set_tag("user_id", requester.user.to_string())
|
||||
if requester.device_id is not None:
|
||||
parent_span.set_tag("device_id", requester.device_id)
|
||||
if requester.app_service is not None:
|
||||
parent_span.set_tag("appservice_id", requester.app_service.id)
|
||||
return requester
|
||||
|
||||
@cancellable
|
||||
async def _wrapped_get_user_by_req(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
allow_guest: bool,
|
||||
allow_expired: bool,
|
||||
) -> Requester:
|
||||
"""Helper for get_user_by_req
|
||||
|
||||
Once get_user_by_req has set up the opentracing span, this does the actual work.
|
||||
"""
|
||||
try:
|
||||
ip_addr = request.getClientAddress().host
|
||||
user_agent = get_request_user_agent(request)
|
||||
|
||||
access_token = self.get_access_token_from_request(request)
|
||||
|
||||
# First check if it could be a request from an appservice
|
||||
requester = await self.get_appservice_user(request, access_token)
|
||||
if not requester:
|
||||
# If not, it should be from a regular user
|
||||
requester = await self.get_user_by_access_token(
|
||||
access_token, allow_expired=allow_expired
|
||||
)
|
||||
|
||||
# Deny the request if the user account has expired.
|
||||
# This check is only done for regular users, not appservice ones.
|
||||
if not allow_expired:
|
||||
if await self._account_validity_handler.is_user_expired(
|
||||
requester.user.to_string()
|
||||
):
|
||||
# Raise the error if either an account validity module has determined
|
||||
# the account has expired, or the legacy account validity
|
||||
# implementation is enabled and determined the account has expired
|
||||
raise AuthError(
|
||||
403,
|
||||
"User account has expired",
|
||||
errcode=Codes.EXPIRED_ACCOUNT,
|
||||
)
|
||||
|
||||
if ip_addr and (
|
||||
not requester.app_service or self._track_appservice_user_ips
|
||||
):
|
||||
# XXX(quenting): I'm 95% confident that we could skip setting the
|
||||
# device_id to "dummy-device" for appservices, and that the only impact
|
||||
# would be some rows which whould not deduplicate in the 'user_ips'
|
||||
# table during the transition
|
||||
recorded_device_id = (
|
||||
"dummy-device"
|
||||
if requester.device_id is None and requester.app_service is not None
|
||||
else requester.device_id
|
||||
)
|
||||
await self.store.insert_client_ip(
|
||||
user_id=requester.authenticated_entity,
|
||||
access_token=access_token,
|
||||
ip=ip_addr,
|
||||
user_agent=user_agent,
|
||||
device_id=recorded_device_id,
|
||||
)
|
||||
|
||||
# Track also the puppeted user client IP if enabled and the user is puppeting
|
||||
if (
|
||||
requester.user.to_string() != requester.authenticated_entity
|
||||
and self._track_puppeted_user_ips
|
||||
):
|
||||
await self.store.insert_client_ip(
|
||||
user_id=requester.user.to_string(),
|
||||
access_token=access_token,
|
||||
ip=ip_addr,
|
||||
user_agent=user_agent,
|
||||
device_id=requester.device_id,
|
||||
)
|
||||
|
||||
if requester.is_guest and not allow_guest:
|
||||
raise AuthError(
|
||||
403,
|
||||
"Guest access not allowed",
|
||||
errcode=Codes.GUEST_ACCESS_FORBIDDEN,
|
||||
)
|
||||
|
||||
request.requester = requester
|
||||
return requester
|
||||
except KeyError:
|
||||
raise MissingClientTokenError()
|
||||
|
||||
async def get_user_by_access_token(
|
||||
self,
|
||||
token: str,
|
||||
allow_expired: bool = False,
|
||||
) -> Requester:
|
||||
"""Validate access token and get user_id from it
|
||||
|
||||
Args:
|
||||
token: The access token to get the user by
|
||||
allow_expired: If False, raises an InvalidClientTokenError
|
||||
if the token is expired
|
||||
|
||||
Raises:
|
||||
InvalidClientTokenError if a user by that token exists, but the token is
|
||||
expired
|
||||
InvalidClientCredentialsError if no user by that token exists or the token
|
||||
is invalid
|
||||
"""
|
||||
|
||||
# First look in the database to see if the access token is present
|
||||
# as an opaque token.
|
||||
user_info = await self.store.get_user_by_access_token(token)
|
||||
if user_info:
|
||||
valid_until_ms = user_info.valid_until_ms
|
||||
if (
|
||||
not allow_expired
|
||||
and valid_until_ms is not None
|
||||
and valid_until_ms < self.clock.time_msec()
|
||||
):
|
||||
# there was a valid access token, but it has expired.
|
||||
# soft-logout the user.
|
||||
raise InvalidClientTokenError(
|
||||
msg="Access token has expired", soft_logout=True
|
||||
)
|
||||
|
||||
# Mark the token as used. This is used to invalidate old refresh
|
||||
# tokens after some time.
|
||||
await self.store.mark_access_token_as_used(user_info.token_id)
|
||||
|
||||
requester = create_requester(
|
||||
user_id=user_info.user_id,
|
||||
access_token_id=user_info.token_id,
|
||||
is_guest=user_info.is_guest,
|
||||
shadow_banned=user_info.shadow_banned,
|
||||
device_id=user_info.device_id,
|
||||
authenticated_entity=user_info.token_owner,
|
||||
)
|
||||
|
||||
return requester
|
||||
|
||||
# If the token isn't found in the database, then it could still be a
|
||||
# macaroon for a guest, so we check that here.
|
||||
try:
|
||||
user_id = self._macaroon_generator.verify_guest_token(token)
|
||||
|
||||
# Guest access tokens are not stored in the database (there can
|
||||
# only be one access token per guest, anyway).
|
||||
#
|
||||
# In order to prevent guest access tokens being used as regular
|
||||
# user access tokens (and hence getting around the invalidation
|
||||
# process), we look up the user id and check that it is indeed
|
||||
# a guest user.
|
||||
#
|
||||
# It would of course be much easier to store guest access
|
||||
# tokens in the database as well, but that would break existing
|
||||
# guest tokens.
|
||||
stored_user = await self.store.get_user_by_id(user_id)
|
||||
if not stored_user:
|
||||
raise InvalidClientTokenError("Unknown user_id %s" % user_id)
|
||||
if not stored_user["is_guest"]:
|
||||
raise InvalidClientTokenError(
|
||||
"Guest access token used for regular user"
|
||||
)
|
||||
|
||||
return create_requester(
|
||||
user_id=user_id,
|
||||
is_guest=True,
|
||||
# all guests get the same device id
|
||||
device_id=GUEST_DEVICE_ID,
|
||||
authenticated_entity=user_id,
|
||||
)
|
||||
except (
|
||||
pymacaroons.exceptions.MacaroonException,
|
||||
TypeError,
|
||||
ValueError,
|
||||
) as e:
|
||||
logger.warning(
|
||||
"Invalid access token in auth: %s %s.",
|
||||
type(e),
|
||||
e,
|
||||
)
|
||||
raise InvalidClientTokenError("Invalid access token passed.")
|
||||
|
||||
async def is_server_admin(self, requester: Requester) -> bool:
|
||||
"""Check if the given user is a local server admin.
|
||||
|
||||
Args:
|
||||
requester: The user making the request, according to the access token.
|
||||
|
||||
Returns:
|
||||
True if the user is an admin
|
||||
"""
|
||||
return await self.store.is_server_admin(requester.user)
|
||||
352
synapse/api/auth/msc3861_delegated.py
Normal file
352
synapse/api/auth/msc3861_delegated.py
Normal file
@@ -0,0 +1,352 @@
|
||||
# Copyright 2023 The Matrix.org Foundation.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from authlib.oauth2 import ClientAuth
|
||||
from authlib.oauth2.auth import encode_client_secret_basic, encode_client_secret_post
|
||||
from authlib.oauth2.rfc7523 import ClientSecretJWT, PrivateKeyJWT, private_key_jwt_sign
|
||||
from authlib.oauth2.rfc7662 import IntrospectionToken
|
||||
from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
|
||||
|
||||
from twisted.web.client import readBody
|
||||
from twisted.web.http_headers import Headers
|
||||
|
||||
from synapse.api.auth.base import BaseAuth
|
||||
from synapse.api.errors import (
|
||||
AuthError,
|
||||
HttpResponseException,
|
||||
InvalidClientTokenError,
|
||||
OAuthInsufficientScopeError,
|
||||
StoreError,
|
||||
SynapseError,
|
||||
)
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.types import Requester, UserID, create_requester
|
||||
from synapse.util import json_decoder
|
||||
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Scope as defined by MSC2967
|
||||
# https://github.com/matrix-org/matrix-spec-proposals/pull/2967
|
||||
SCOPE_MATRIX_API = "urn:matrix:org.matrix.msc2967.client:api:*"
|
||||
SCOPE_MATRIX_GUEST = "urn:matrix:org.matrix.msc2967.client:api:guest"
|
||||
SCOPE_MATRIX_DEVICE_PREFIX = "urn:matrix:org.matrix.msc2967.client:device:"
|
||||
|
||||
# Scope which allows access to the Synapse admin API
|
||||
SCOPE_SYNAPSE_ADMIN = "urn:synapse:admin:*"
|
||||
|
||||
|
||||
def scope_to_list(scope: str) -> List[str]:
|
||||
"""Convert a scope string to a list of scope tokens"""
|
||||
return scope.strip().split(" ")
|
||||
|
||||
|
||||
class PrivateKeyJWTWithKid(PrivateKeyJWT): # type: ignore[misc]
|
||||
"""An implementation of the private_key_jwt client auth method that includes a kid header.
|
||||
|
||||
This is needed because some providers (Keycloak) require the kid header to figure
|
||||
out which key to use to verify the signature.
|
||||
"""
|
||||
|
||||
def sign(self, auth: Any, token_endpoint: str) -> bytes:
|
||||
return private_key_jwt_sign(
|
||||
auth.client_secret,
|
||||
client_id=auth.client_id,
|
||||
token_endpoint=token_endpoint,
|
||||
claims=self.claims,
|
||||
header={"kid": auth.client_secret["kid"]},
|
||||
)
|
||||
|
||||
|
||||
class MSC3861DelegatedAuth(BaseAuth):
|
||||
AUTH_METHODS = {
|
||||
"client_secret_post": encode_client_secret_post,
|
||||
"client_secret_basic": encode_client_secret_basic,
|
||||
"client_secret_jwt": ClientSecretJWT(),
|
||||
"private_key_jwt": PrivateKeyJWTWithKid(),
|
||||
}
|
||||
|
||||
EXTERNAL_ID_PROVIDER = "oauth-delegated"
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
|
||||
self._config = hs.config.experimental.msc3861
|
||||
auth_method = MSC3861DelegatedAuth.AUTH_METHODS.get(
|
||||
self._config.client_auth_method.value, None
|
||||
)
|
||||
# Those assertions are already checked when parsing the config
|
||||
assert self._config.enabled, "OAuth delegation is not enabled"
|
||||
assert self._config.issuer, "No issuer provided"
|
||||
assert self._config.client_id, "No client_id provided"
|
||||
assert auth_method is not None, "Invalid client_auth_method provided"
|
||||
|
||||
self._http_client = hs.get_proxied_http_client()
|
||||
self._hostname = hs.hostname
|
||||
self._admin_token = self._config.admin_token
|
||||
|
||||
self._issuer_metadata = RetryOnExceptionCachedCall(self._load_metadata)
|
||||
|
||||
if isinstance(auth_method, PrivateKeyJWTWithKid):
|
||||
# Use the JWK as the client secret when using the private_key_jwt method
|
||||
assert self._config.jwk, "No JWK provided"
|
||||
self._client_auth = ClientAuth(
|
||||
self._config.client_id, self._config.jwk, auth_method
|
||||
)
|
||||
else:
|
||||
# Else use the client secret
|
||||
assert self._config.client_secret, "No client_secret provided"
|
||||
self._client_auth = ClientAuth(
|
||||
self._config.client_id, self._config.client_secret, auth_method
|
||||
)
|
||||
|
||||
async def _load_metadata(self) -> OpenIDProviderMetadata:
|
||||
if self._config.issuer_metadata is not None:
|
||||
return OpenIDProviderMetadata(**self._config.issuer_metadata)
|
||||
url = get_well_known_url(self._config.issuer, external=True)
|
||||
response = await self._http_client.get_json(url)
|
||||
metadata = OpenIDProviderMetadata(**response)
|
||||
# metadata.validate_introspection_endpoint()
|
||||
return metadata
|
||||
|
||||
async def _introspect_token(self, token: str) -> IntrospectionToken:
|
||||
"""
|
||||
Send a token to the introspection endpoint and returns the introspection response
|
||||
|
||||
Parameters:
|
||||
token: The token to introspect
|
||||
|
||||
Raises:
|
||||
HttpResponseException: If the introspection endpoint returns a non-2xx response
|
||||
ValueError: If the introspection endpoint returns an invalid JSON response
|
||||
JSONDecodeError: If the introspection endpoint returns a non-JSON response
|
||||
Exception: If the HTTP request fails
|
||||
|
||||
Returns:
|
||||
The introspection response
|
||||
"""
|
||||
metadata = await self._issuer_metadata.get()
|
||||
introspection_endpoint = metadata.get("introspection_endpoint")
|
||||
raw_headers: Dict[str, str] = {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"User-Agent": str(self._http_client.user_agent, "utf-8"),
|
||||
"Accept": "application/json",
|
||||
}
|
||||
|
||||
args = {"token": token, "token_type_hint": "access_token"}
|
||||
body = urlencode(args, True)
|
||||
|
||||
# Fill the body/headers with credentials
|
||||
uri, raw_headers, body = self._client_auth.prepare(
|
||||
method="POST", uri=introspection_endpoint, headers=raw_headers, body=body
|
||||
)
|
||||
headers = Headers({k: [v] for (k, v) in raw_headers.items()})
|
||||
|
||||
# Do the actual request
|
||||
# We're not using the SimpleHttpClient util methods as we don't want to
|
||||
# check the HTTP status code, and we do the body encoding ourselves.
|
||||
response = await self._http_client.request(
|
||||
method="POST",
|
||||
uri=uri,
|
||||
data=body.encode("utf-8"),
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
resp_body = await make_deferred_yieldable(readBody(response))
|
||||
|
||||
if response.code < 200 or response.code >= 300:
|
||||
raise HttpResponseException(
|
||||
response.code,
|
||||
response.phrase.decode("ascii", errors="replace"),
|
||||
resp_body,
|
||||
)
|
||||
|
||||
resp = json_decoder.decode(resp_body.decode("utf-8"))
|
||||
|
||||
if not isinstance(resp, dict):
|
||||
raise ValueError(
|
||||
"The introspection endpoint returned an invalid JSON response."
|
||||
)
|
||||
|
||||
return IntrospectionToken(**resp)
|
||||
|
||||
async def is_server_admin(self, requester: Requester) -> bool:
|
||||
return "urn:synapse:admin:*" in requester.scope
|
||||
|
||||
async def get_user_by_req(
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
allow_guest: bool = False,
|
||||
allow_expired: bool = False,
|
||||
) -> Requester:
|
||||
access_token = self.get_access_token_from_request(request)
|
||||
|
||||
requester = await self.get_appservice_user(request, access_token)
|
||||
if not requester:
|
||||
# TODO: we probably want to assert the allow_guest inside this call
|
||||
# so that we don't provision the user if they don't have enough permission:
|
||||
requester = await self.get_user_by_access_token(access_token, allow_expired)
|
||||
|
||||
if not allow_guest and requester.is_guest:
|
||||
raise OAuthInsufficientScopeError([SCOPE_MATRIX_API])
|
||||
|
||||
request.requester = requester
|
||||
|
||||
return requester
|
||||
|
||||
async def get_user_by_access_token(
|
||||
self,
|
||||
token: str,
|
||||
allow_expired: bool = False,
|
||||
) -> Requester:
|
||||
if self._admin_token is not None and token == self._admin_token:
|
||||
# XXX: This is a temporary solution so that the admin API can be called by
|
||||
# the OIDC provider. This will be removed once we have OIDC client
|
||||
# credentials grant support in matrix-authentication-service.
|
||||
logging.info("Admin toked used")
|
||||
# XXX: that user doesn't exist and won't be provisioned.
|
||||
# This is mostly fine for admin calls, but we should also think about doing
|
||||
# requesters without a user_id.
|
||||
admin_user = UserID("__oidc_admin", self._hostname)
|
||||
return create_requester(
|
||||
user_id=admin_user,
|
||||
scope=["urn:synapse:admin:*"],
|
||||
)
|
||||
|
||||
try:
|
||||
introspection_result = await self._introspect_token(token)
|
||||
except Exception:
|
||||
logger.exception("Failed to introspect token")
|
||||
raise SynapseError(503, "Unable to introspect the access token")
|
||||
|
||||
logger.info(f"Introspection result: {introspection_result!r}")
|
||||
|
||||
# TODO: introspection verification should be more extensive, especially:
|
||||
# - verify the audience
|
||||
if not introspection_result.get("active"):
|
||||
raise InvalidClientTokenError("Token is not active")
|
||||
|
||||
# Let's look at the scope
|
||||
scope: List[str] = scope_to_list(introspection_result.get("scope", ""))
|
||||
|
||||
# Determine type of user based on presence of particular scopes
|
||||
has_user_scope = SCOPE_MATRIX_API in scope
|
||||
has_guest_scope = SCOPE_MATRIX_GUEST in scope
|
||||
|
||||
if not has_user_scope and not has_guest_scope:
|
||||
raise InvalidClientTokenError("No scope in token granting user rights")
|
||||
|
||||
# Match via the sub claim
|
||||
sub: Optional[str] = introspection_result.get("sub")
|
||||
if sub is None:
|
||||
raise InvalidClientTokenError(
|
||||
"Invalid sub claim in the introspection result"
|
||||
)
|
||||
|
||||
user_id_str = await self.store.get_user_by_external_id(
|
||||
MSC3861DelegatedAuth.EXTERNAL_ID_PROVIDER, sub
|
||||
)
|
||||
if user_id_str is None:
|
||||
# If we could not find a user via the external_id, it either does not exist,
|
||||
# or the external_id was never recorded
|
||||
|
||||
# TODO: claim mapping should be configurable
|
||||
username: Optional[str] = introspection_result.get("username")
|
||||
if username is None or not isinstance(username, str):
|
||||
raise AuthError(
|
||||
500,
|
||||
"Invalid username claim in the introspection result",
|
||||
)
|
||||
user_id = UserID(username, self._hostname)
|
||||
|
||||
# First try to find a user from the username claim
|
||||
user_info = await self.store.get_userinfo_by_id(user_id=user_id.to_string())
|
||||
if user_info is None:
|
||||
# If the user does not exist, we should create it on the fly
|
||||
# TODO: we could use SCIM to provision users ahead of time and listen
|
||||
# for SCIM SET events if those ever become standard:
|
||||
# https://datatracker.ietf.org/doc/html/draft-hunt-scim-notify-00
|
||||
|
||||
# TODO: claim mapping should be configurable
|
||||
# If present, use the name claim as the displayname
|
||||
name: Optional[str] = introspection_result.get("name")
|
||||
|
||||
await self.store.register_user(
|
||||
user_id=user_id.to_string(), create_profile_with_displayname=name
|
||||
)
|
||||
|
||||
# And record the sub as external_id
|
||||
await self.store.record_user_external_id(
|
||||
MSC3861DelegatedAuth.EXTERNAL_ID_PROVIDER, sub, user_id.to_string()
|
||||
)
|
||||
else:
|
||||
user_id = UserID.from_string(user_id_str)
|
||||
|
||||
# Find device_ids in scope
|
||||
# We only allow a single device_id in the scope, so we find them all in the
|
||||
# scope list, and raise if there are more than one. The OIDC server should be
|
||||
# the one enforcing valid scopes, so we raise a 500 if we find an invalid scope.
|
||||
device_ids = [
|
||||
tok[len(SCOPE_MATRIX_DEVICE_PREFIX) :]
|
||||
for tok in scope
|
||||
if tok.startswith(SCOPE_MATRIX_DEVICE_PREFIX)
|
||||
]
|
||||
|
||||
if len(device_ids) > 1:
|
||||
raise AuthError(
|
||||
500,
|
||||
"Multiple device IDs in scope",
|
||||
)
|
||||
|
||||
device_id = device_ids[0] if device_ids else None
|
||||
if device_id is not None:
|
||||
# Sanity check the device_id
|
||||
if len(device_id) > 255 or len(device_id) < 1:
|
||||
raise AuthError(
|
||||
500,
|
||||
"Invalid device ID in scope",
|
||||
)
|
||||
|
||||
# Create the device on the fly if it does not exist
|
||||
try:
|
||||
await self.store.get_device(
|
||||
user_id=user_id.to_string(), device_id=device_id
|
||||
)
|
||||
except StoreError:
|
||||
await self.store.store_device(
|
||||
user_id=user_id.to_string(),
|
||||
device_id=device_id,
|
||||
initial_device_display_name="OIDC-native client",
|
||||
)
|
||||
|
||||
# TODO: there is a few things missing in the requester here, which still need
|
||||
# to be figured out, like:
|
||||
# - impersonation, with the `authenticated_entity`, which is used for
|
||||
# rate-limiting, MAU limits, etc.
|
||||
# - shadow-banning, with the `shadow_banned` flag
|
||||
# - a proper solution for appservices, which still needs to be figured out in
|
||||
# the context of MSC3861
|
||||
return create_requester(
|
||||
user_id=user_id,
|
||||
device_id=device_id,
|
||||
scope=scope,
|
||||
is_guest=(has_guest_scope and not has_user_scope),
|
||||
)
|
||||
@@ -119,14 +119,20 @@ class Codes(str, Enum):
|
||||
|
||||
|
||||
class CodeMessageException(RuntimeError):
|
||||
"""An exception with integer code and message string attributes.
|
||||
"""An exception with integer code, a message string attributes and optional headers.
|
||||
|
||||
Attributes:
|
||||
code: HTTP error code
|
||||
msg: string describing the error
|
||||
headers: optional response headers to send
|
||||
"""
|
||||
|
||||
def __init__(self, code: Union[int, HTTPStatus], msg: str):
|
||||
def __init__(
|
||||
self,
|
||||
code: Union[int, HTTPStatus],
|
||||
msg: str,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
super().__init__("%d: %s" % (code, msg))
|
||||
|
||||
# Some calls to this method pass instances of http.HTTPStatus for `code`.
|
||||
@@ -137,6 +143,7 @@ class CodeMessageException(RuntimeError):
|
||||
# To eliminate this behaviour, we convert them to their integer equivalents here.
|
||||
self.code = int(code)
|
||||
self.msg = msg
|
||||
self.headers = headers
|
||||
|
||||
|
||||
class RedirectException(CodeMessageException):
|
||||
@@ -182,6 +189,7 @@ class SynapseError(CodeMessageException):
|
||||
msg: str,
|
||||
errcode: str = Codes.UNKNOWN,
|
||||
additional_fields: Optional[Dict] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
"""Constructs a synapse error.
|
||||
|
||||
@@ -190,7 +198,7 @@ class SynapseError(CodeMessageException):
|
||||
msg: The human-readable error message.
|
||||
errcode: The matrix error code e.g 'M_FORBIDDEN'
|
||||
"""
|
||||
super().__init__(code, msg)
|
||||
super().__init__(code, msg, headers)
|
||||
self.errcode = errcode
|
||||
if additional_fields is None:
|
||||
self._additional_fields: Dict = {}
|
||||
@@ -335,6 +343,20 @@ class AuthError(SynapseError):
|
||||
super().__init__(code, msg, errcode, additional_fields)
|
||||
|
||||
|
||||
class OAuthInsufficientScopeError(SynapseError):
|
||||
"""An error raised when the caller does not have sufficient scope to perform the requested action"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
required_scopes: List[str],
|
||||
):
|
||||
headers = {
|
||||
"WWW-Authenticate": 'Bearer error="insufficient_scope", scope="%s"'
|
||||
% (" ".join(required_scopes))
|
||||
}
|
||||
super().__init__(401, "Insufficient scope", Codes.FORBIDDEN, None, headers)
|
||||
|
||||
|
||||
class UnstableSpecAuthError(AuthError):
|
||||
"""An error raised when a new error code is being proposed to replace a previous one.
|
||||
This error will return a "org.matrix.unstable.errcode" property with the new error code,
|
||||
|
||||
@@ -152,9 +152,9 @@ class Filtering:
|
||||
self.DEFAULT_FILTER_COLLECTION = FilterCollection(hs, {})
|
||||
|
||||
async def get_user_filter(
|
||||
self, user_localpart: str, filter_id: Union[int, str]
|
||||
self, user_id: UserID, filter_id: Union[int, str]
|
||||
) -> "FilterCollection":
|
||||
result = await self.store.get_user_filter(user_localpart, filter_id)
|
||||
result = await self.store.get_user_filter(user_id, filter_id)
|
||||
return FilterCollection(self._hs, result)
|
||||
|
||||
def add_user_filter(self, user_id: UserID, user_filter: JsonDict) -> Awaitable[int]:
|
||||
|
||||
@@ -29,7 +29,14 @@ class AuthConfig(Config):
|
||||
if password_config is None:
|
||||
password_config = {}
|
||||
|
||||
passwords_enabled = password_config.get("enabled", True)
|
||||
# The default value of password_config.enabled is True, unless msc3861 is enabled.
|
||||
msc3861_enabled = (
|
||||
config.get("experimental_features", {})
|
||||
.get("msc3861", {})
|
||||
.get("enabled", False)
|
||||
)
|
||||
passwords_enabled = password_config.get("enabled", not msc3861_enabled)
|
||||
|
||||
# 'only_for_reauth' allows users who have previously set a password to use it,
|
||||
# even though passwords would otherwise be disabled.
|
||||
passwords_for_reauth_only = passwords_enabled == "only_for_reauth"
|
||||
@@ -53,3 +60,13 @@ class AuthConfig(Config):
|
||||
self.ui_auth_session_timeout = self.parse_duration(
|
||||
ui_auth.get("session_timeout", 0)
|
||||
)
|
||||
|
||||
# Logging in with an existing session.
|
||||
login_via_existing = config.get("login_via_existing_session", {})
|
||||
self.login_via_existing_enabled = login_via_existing.get("enabled", False)
|
||||
self.login_via_existing_require_ui_auth = login_via_existing.get(
|
||||
"require_ui_auth", True
|
||||
)
|
||||
self.login_via_existing_token_timeout = self.parse_duration(
|
||||
login_via_existing.get("token_timeout", "5m")
|
||||
)
|
||||
|
||||
@@ -12,15 +12,216 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Optional
|
||||
import enum
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import attr
|
||||
import attr.validators
|
||||
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
|
||||
from synapse.config import ConfigError
|
||||
from synapse.config._base import Config
|
||||
from synapse.config._base import Config, RootConfig
|
||||
from synapse.types import JsonDict
|
||||
|
||||
# Determine whether authlib is installed.
|
||||
try:
|
||||
import authlib # noqa: F401
|
||||
|
||||
HAS_AUTHLIB = True
|
||||
except ImportError:
|
||||
HAS_AUTHLIB = False
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# Only import this if we're type checking, as it might not be installed at runtime.
|
||||
from authlib.jose.rfc7517 import JsonWebKey
|
||||
|
||||
|
||||
class ClientAuthMethod(enum.Enum):
|
||||
"""List of supported client auth methods."""
|
||||
|
||||
CLIENT_SECRET_POST = "client_secret_post"
|
||||
CLIENT_SECRET_BASIC = "client_secret_basic"
|
||||
CLIENT_SECRET_JWT = "client_secret_jwt"
|
||||
PRIVATE_KEY_JWT = "private_key_jwt"
|
||||
|
||||
|
||||
def _parse_jwks(jwks: Optional[JsonDict]) -> Optional["JsonWebKey"]:
|
||||
"""A helper function to parse a JWK dict into a JsonWebKey."""
|
||||
|
||||
if jwks is None:
|
||||
return None
|
||||
|
||||
from authlib.jose.rfc7517 import JsonWebKey
|
||||
|
||||
return JsonWebKey.import_key(jwks)
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True)
|
||||
class MSC3861:
|
||||
"""Configuration for MSC3861: Matrix architecture change to delegate authentication via OIDC"""
|
||||
|
||||
enabled: bool = attr.ib(default=False, validator=attr.validators.instance_of(bool))
|
||||
"""Whether to enable MSC3861 auth delegation."""
|
||||
|
||||
@enabled.validator
|
||||
def _check_enabled(self, attribute: attr.Attribute, value: bool) -> None:
|
||||
# Only allow enabling MSC3861 if authlib is installed
|
||||
if value and not HAS_AUTHLIB:
|
||||
raise ConfigError(
|
||||
"MSC3861 is enabled but authlib is not installed. "
|
||||
"Please install authlib to use MSC3861.",
|
||||
("experimental", "msc3861", "enabled"),
|
||||
)
|
||||
|
||||
issuer: str = attr.ib(default="", validator=attr.validators.instance_of(str))
|
||||
"""The URL of the OIDC Provider."""
|
||||
|
||||
issuer_metadata: Optional[JsonDict] = attr.ib(default=None)
|
||||
"""The issuer metadata to use, otherwise discovered from /.well-known/openid-configuration as per MSC2965."""
|
||||
|
||||
client_id: str = attr.ib(
|
||||
default="",
|
||||
validator=attr.validators.instance_of(str),
|
||||
)
|
||||
"""The client ID to use when calling the introspection endpoint."""
|
||||
|
||||
client_auth_method: ClientAuthMethod = attr.ib(
|
||||
default=ClientAuthMethod.CLIENT_SECRET_POST, converter=ClientAuthMethod
|
||||
)
|
||||
"""The auth method used when calling the introspection endpoint."""
|
||||
|
||||
client_secret: Optional[str] = attr.ib(
|
||||
default=None,
|
||||
validator=attr.validators.optional(attr.validators.instance_of(str)),
|
||||
)
|
||||
"""
|
||||
The client secret to use when calling the introspection endpoint,
|
||||
when using any of the client_secret_* client auth methods.
|
||||
"""
|
||||
|
||||
jwk: Optional["JsonWebKey"] = attr.ib(default=None, converter=_parse_jwks)
|
||||
"""
|
||||
The JWKS to use when calling the introspection endpoint,
|
||||
when using the private_key_jwt client auth method.
|
||||
"""
|
||||
|
||||
@client_auth_method.validator
|
||||
def _check_client_auth_method(
|
||||
self, attribute: attr.Attribute, value: ClientAuthMethod
|
||||
) -> None:
|
||||
# Check that the right client credentials are provided for the client auth method.
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
if value == ClientAuthMethod.PRIVATE_KEY_JWT and self.jwk is None:
|
||||
raise ConfigError(
|
||||
"A JWKS must be provided when using the private_key_jwt client auth method",
|
||||
("experimental", "msc3861", "client_auth_method"),
|
||||
)
|
||||
|
||||
if (
|
||||
value
|
||||
in (
|
||||
ClientAuthMethod.CLIENT_SECRET_POST,
|
||||
ClientAuthMethod.CLIENT_SECRET_BASIC,
|
||||
ClientAuthMethod.CLIENT_SECRET_JWT,
|
||||
)
|
||||
and self.client_secret is None
|
||||
):
|
||||
raise ConfigError(
|
||||
f"A client secret must be provided when using the {value} client auth method",
|
||||
("experimental", "msc3861", "client_auth_method"),
|
||||
)
|
||||
|
||||
account_management_url: Optional[str] = attr.ib(
|
||||
default=None,
|
||||
validator=attr.validators.optional(attr.validators.instance_of(str)),
|
||||
)
|
||||
"""The URL of the My Account page on the OIDC Provider as per MSC2965."""
|
||||
|
||||
admin_token: Optional[str] = attr.ib(
|
||||
default=None,
|
||||
validator=attr.validators.optional(attr.validators.instance_of(str)),
|
||||
)
|
||||
"""
|
||||
A token that should be considered as an admin token.
|
||||
This is used by the OIDC provider, to make admin calls to Synapse.
|
||||
"""
|
||||
|
||||
def check_config_conflicts(self, root: RootConfig) -> None:
|
||||
"""Checks for any configuration conflicts with other parts of Synapse.
|
||||
|
||||
Raises:
|
||||
ConfigError: If there are any configuration conflicts.
|
||||
"""
|
||||
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
if (
|
||||
root.auth.password_enabled_for_reauth
|
||||
or root.auth.password_enabled_for_login
|
||||
):
|
||||
raise ConfigError(
|
||||
"Password auth cannot be enabled when OAuth delegation is enabled",
|
||||
("password_config", "enabled"),
|
||||
)
|
||||
|
||||
if root.registration.enable_registration:
|
||||
raise ConfigError(
|
||||
"Registration cannot be enabled when OAuth delegation is enabled",
|
||||
("enable_registration",),
|
||||
)
|
||||
|
||||
if (
|
||||
root.oidc.oidc_enabled
|
||||
or root.saml2.saml2_enabled
|
||||
or root.cas.cas_enabled
|
||||
or root.jwt.jwt_enabled
|
||||
):
|
||||
raise ConfigError("SSO cannot be enabled when OAuth delegation is enabled")
|
||||
|
||||
if bool(root.authproviders.password_providers):
|
||||
raise ConfigError(
|
||||
"Password auth providers cannot be enabled when OAuth delegation is enabled"
|
||||
)
|
||||
|
||||
if root.captcha.enable_registration_captcha:
|
||||
raise ConfigError(
|
||||
"CAPTCHA cannot be enabled when OAuth delegation is enabled",
|
||||
("captcha", "enable_registration_captcha"),
|
||||
)
|
||||
|
||||
if root.auth.login_via_existing_enabled:
|
||||
raise ConfigError(
|
||||
"Login via existing session cannot be enabled when OAuth delegation is enabled",
|
||||
("login_via_existing_session", "enabled"),
|
||||
)
|
||||
|
||||
if root.registration.refresh_token_lifetime:
|
||||
raise ConfigError(
|
||||
"refresh_token_lifetime cannot be set when OAuth delegation is enabled",
|
||||
("refresh_token_lifetime",),
|
||||
)
|
||||
|
||||
if root.registration.nonrefreshable_access_token_lifetime:
|
||||
raise ConfigError(
|
||||
"nonrefreshable_access_token_lifetime cannot be set when OAuth delegation is enabled",
|
||||
("nonrefreshable_access_token_lifetime",),
|
||||
)
|
||||
|
||||
if root.registration.session_lifetime:
|
||||
raise ConfigError(
|
||||
"session_lifetime cannot be set when OAuth delegation is enabled",
|
||||
("session_lifetime",),
|
||||
)
|
||||
|
||||
if not root.experimental.msc3970_enabled:
|
||||
raise ConfigError(
|
||||
"experimental_features.msc3970_enabled must be 'true' when OAuth delegation is enabled",
|
||||
("experimental_features", "msc3970_enabled"),
|
||||
)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True, frozen=True, slots=True)
|
||||
class MSC3866Config:
|
||||
@@ -118,13 +319,6 @@ class ExperimentalConfig(Config):
|
||||
# MSC3881: Remotely toggle push notifications for another client
|
||||
self.msc3881_enabled: bool = experimental.get("msc3881_enabled", False)
|
||||
|
||||
# MSC3882: Allow an existing session to sign in a new session
|
||||
self.msc3882_enabled: bool = experimental.get("msc3882_enabled", False)
|
||||
self.msc3882_ui_auth: bool = experimental.get("msc3882_ui_auth", True)
|
||||
self.msc3882_token_timeout = self.parse_duration(
|
||||
experimental.get("msc3882_token_timeout", "5m")
|
||||
)
|
||||
|
||||
# MSC3874: Filtering /messages with rel_types / not_rel_types.
|
||||
self.msc3874_enabled: bool = experimental.get("msc3874_enabled", False)
|
||||
|
||||
@@ -182,8 +376,19 @@ class ExperimentalConfig(Config):
|
||||
"msc3981_recurse_relations", False
|
||||
)
|
||||
|
||||
# MSC3861: Matrix architecture change to delegate authentication via OIDC
|
||||
try:
|
||||
self.msc3861 = MSC3861(**experimental.get("msc3861", {}))
|
||||
except ValueError as exc:
|
||||
raise ConfigError(
|
||||
"Invalid MSC3861 configuration", ("experimental", "msc3861")
|
||||
) from exc
|
||||
|
||||
# MSC3970: Scope transaction IDs to devices
|
||||
self.msc3970_enabled = experimental.get("msc3970_enabled", False)
|
||||
self.msc3970_enabled = experimental.get("msc3970_enabled", self.msc3861.enabled)
|
||||
|
||||
# Check that none of the other config options conflict with MSC3861 when enabled
|
||||
self.msc3861.check_config_conflicts(self.root)
|
||||
|
||||
# MSC4009: E.164 Matrix IDs
|
||||
self.msc4009_e164_mxids = experimental.get("msc4009_e164_mxids", False)
|
||||
|
||||
@@ -515,7 +515,7 @@ class FederationServer(FederationBase):
|
||||
logger.error(
|
||||
"Failed to handle PDU %s",
|
||||
event_id,
|
||||
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
|
||||
exc_info=(f.type, f.value, f.getTracebackObject()),
|
||||
)
|
||||
return {"error": str(e)}
|
||||
|
||||
@@ -1247,7 +1247,7 @@ class FederationServer(FederationBase):
|
||||
logger.error(
|
||||
"Failed to handle PDU %s",
|
||||
event.event_id,
|
||||
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
|
||||
exc_info=(f.type, f.value, f.getTracebackObject()),
|
||||
)
|
||||
|
||||
received_ts = await self.store.remove_received_event_from_staging(
|
||||
@@ -1291,9 +1291,6 @@ class FederationServer(FederationBase):
|
||||
return
|
||||
lock = new_lock
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "<ReplicationLayer(%s)>" % self.server_name
|
||||
|
||||
async def exchange_third_party_invite(
|
||||
self, sender_user_id: str, target_user_id: str, room_id: str, signed: Dict
|
||||
) -> None:
|
||||
|
||||
@@ -164,7 +164,7 @@ class AccountValidityHandler:
|
||||
|
||||
try:
|
||||
user_display_name = await self.store.get_profile_displayname(
|
||||
UserID.from_string(user_id).localpart
|
||||
UserID.from_string(user_id)
|
||||
)
|
||||
if user_display_name is None:
|
||||
user_display_name = user_id
|
||||
|
||||
@@ -89,7 +89,7 @@ class AdminHandler:
|
||||
}
|
||||
|
||||
# Add additional user metadata
|
||||
profile = await self._store.get_profileinfo(user.localpart)
|
||||
profile = await self._store.get_profileinfo(user)
|
||||
threepids = await self._store.user_get_threepids(user.to_string())
|
||||
external_ids = [
|
||||
({"auth_provider": auth_provider, "external_id": external_id})
|
||||
|
||||
@@ -274,6 +274,8 @@ class AuthHandler:
|
||||
# response.
|
||||
self._extra_attributes: Dict[str, SsoLoginExtraAttributes] = {}
|
||||
|
||||
self.msc3861_oauth_delegation_enabled = hs.config.experimental.msc3861.enabled
|
||||
|
||||
async def validate_user_via_ui_auth(
|
||||
self,
|
||||
requester: Requester,
|
||||
@@ -322,8 +324,12 @@ class AuthHandler:
|
||||
|
||||
LimitExceededError if the ratelimiter's failed request count for this
|
||||
user is too high to proceed
|
||||
|
||||
"""
|
||||
if self.msc3861_oauth_delegation_enabled:
|
||||
raise SynapseError(
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR, "UIA shouldn't be used with MSC3861"
|
||||
)
|
||||
|
||||
if not requester.access_token_id:
|
||||
raise ValueError("Cannot validate a user without an access token")
|
||||
if can_skip_ui_auth and self._ui_auth_session_timeout:
|
||||
@@ -1753,7 +1759,7 @@ class AuthHandler:
|
||||
return
|
||||
|
||||
user_profile_data = await self.store.get_profileinfo(
|
||||
UserID.from_string(registered_user_id).localpart
|
||||
UserID.from_string(registered_user_id)
|
||||
)
|
||||
|
||||
# Store any extra attributes which will be passed in the login response.
|
||||
|
||||
@@ -297,5 +297,5 @@ class DeactivateAccountHandler:
|
||||
# Add the user to the directory, if necessary. Note that
|
||||
# this must be done after the user is re-activated, because
|
||||
# deactivated users are excluded from the user directory.
|
||||
profile = await self.store.get_profileinfo(user.localpart)
|
||||
profile = await self.store.get_profileinfo(user)
|
||||
await self.user_directory_handler.handle_local_profile_change(user_id, profile)
|
||||
|
||||
@@ -1354,7 +1354,7 @@ class OidcProvider:
|
||||
finish_request(request)
|
||||
|
||||
|
||||
class LogoutToken(JWTClaims):
|
||||
class LogoutToken(JWTClaims): # type: ignore[misc]
|
||||
"""
|
||||
Holds and verify claims of a logout token, as per
|
||||
https://openid.net/specs/openid-connect-backchannel-1_0.html#LogoutToken
|
||||
|
||||
@@ -360,7 +360,7 @@ class PaginationHandler:
|
||||
except Exception:
|
||||
f = Failure()
|
||||
logger.error(
|
||||
"[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject()) # type: ignore
|
||||
"[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject())
|
||||
)
|
||||
self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED
|
||||
self._purges_by_id[purge_id].error = f.getErrorMessage()
|
||||
@@ -689,7 +689,7 @@ class PaginationHandler:
|
||||
f = Failure()
|
||||
logger.error(
|
||||
"failed",
|
||||
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
|
||||
exc_info=(f.type, f.value, f.getTracebackObject()),
|
||||
)
|
||||
self._delete_by_id[delete_id].status = DeleteStatus.STATUS_FAILED
|
||||
self._delete_by_id[delete_id].error = f.getErrorMessage()
|
||||
|
||||
@@ -67,7 +67,7 @@ class ProfileHandler:
|
||||
target_user = UserID.from_string(user_id)
|
||||
|
||||
if self.hs.is_mine(target_user):
|
||||
profileinfo = await self.store.get_profileinfo(target_user.localpart)
|
||||
profileinfo = await self.store.get_profileinfo(target_user)
|
||||
if profileinfo.display_name is None:
|
||||
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
|
||||
|
||||
@@ -99,9 +99,7 @@ class ProfileHandler:
|
||||
async def get_displayname(self, target_user: UserID) -> Optional[str]:
|
||||
if self.hs.is_mine(target_user):
|
||||
try:
|
||||
displayname = await self.store.get_profile_displayname(
|
||||
target_user.localpart
|
||||
)
|
||||
displayname = await self.store.get_profile_displayname(target_user)
|
||||
except StoreError as e:
|
||||
if e.code == 404:
|
||||
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
|
||||
@@ -147,7 +145,7 @@ class ProfileHandler:
|
||||
raise AuthError(400, "Cannot set another user's displayname")
|
||||
|
||||
if not by_admin and not self.hs.config.registration.enable_set_displayname:
|
||||
profile = await self.store.get_profileinfo(target_user.localpart)
|
||||
profile = await self.store.get_profileinfo(target_user)
|
||||
if profile.display_name:
|
||||
raise SynapseError(
|
||||
400,
|
||||
@@ -180,7 +178,7 @@ class ProfileHandler:
|
||||
|
||||
await self.store.set_profile_displayname(target_user, displayname_to_set)
|
||||
|
||||
profile = await self.store.get_profileinfo(target_user.localpart)
|
||||
profile = await self.store.get_profileinfo(target_user)
|
||||
await self.user_directory_handler.handle_local_profile_change(
|
||||
target_user.to_string(), profile
|
||||
)
|
||||
@@ -194,9 +192,7 @@ class ProfileHandler:
|
||||
async def get_avatar_url(self, target_user: UserID) -> Optional[str]:
|
||||
if self.hs.is_mine(target_user):
|
||||
try:
|
||||
avatar_url = await self.store.get_profile_avatar_url(
|
||||
target_user.localpart
|
||||
)
|
||||
avatar_url = await self.store.get_profile_avatar_url(target_user)
|
||||
except StoreError as e:
|
||||
if e.code == 404:
|
||||
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
|
||||
@@ -241,7 +237,7 @@ class ProfileHandler:
|
||||
raise AuthError(400, "Cannot set another user's avatar_url")
|
||||
|
||||
if not by_admin and not self.hs.config.registration.enable_set_avatar_url:
|
||||
profile = await self.store.get_profileinfo(target_user.localpart)
|
||||
profile = await self.store.get_profileinfo(target_user)
|
||||
if profile.avatar_url:
|
||||
raise SynapseError(
|
||||
400, "Changing avatar is disabled on this server", Codes.FORBIDDEN
|
||||
@@ -272,7 +268,7 @@ class ProfileHandler:
|
||||
|
||||
await self.store.set_profile_avatar_url(target_user, avatar_url_to_set)
|
||||
|
||||
profile = await self.store.get_profileinfo(target_user.localpart)
|
||||
profile = await self.store.get_profileinfo(target_user)
|
||||
await self.user_directory_handler.handle_local_profile_change(
|
||||
target_user.to_string(), profile
|
||||
)
|
||||
@@ -369,14 +365,10 @@ class ProfileHandler:
|
||||
response = {}
|
||||
try:
|
||||
if just_field is None or just_field == "displayname":
|
||||
response["displayname"] = await self.store.get_profile_displayname(
|
||||
user.localpart
|
||||
)
|
||||
response["displayname"] = await self.store.get_profile_displayname(user)
|
||||
|
||||
if just_field is None or just_field == "avatar_url":
|
||||
response["avatar_url"] = await self.store.get_profile_avatar_url(
|
||||
user.localpart
|
||||
)
|
||||
response["avatar_url"] = await self.store.get_profile_avatar_url(user)
|
||||
except StoreError as e:
|
||||
if e.code == 404:
|
||||
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
|
||||
|
||||
@@ -315,7 +315,7 @@ class RegistrationHandler:
|
||||
approved=approved,
|
||||
)
|
||||
|
||||
profile = await self.store.get_profileinfo(localpart)
|
||||
profile = await self.store.get_profileinfo(user)
|
||||
await self.user_directory_handler.handle_local_profile_change(
|
||||
user_id, profile
|
||||
)
|
||||
|
||||
@@ -205,16 +205,22 @@ class RelationsHandler:
|
||||
event_id: The event IDs to look and redact relations of.
|
||||
initial_redaction_event: The redaction for the event referred to by
|
||||
event_id.
|
||||
relation_types: The types of relations to look for.
|
||||
relation_types: The types of relations to look for. If "*" is in the list,
|
||||
all related events will be redacted regardless of the type.
|
||||
|
||||
Raises:
|
||||
ShadowBanError if the requester is shadow-banned
|
||||
"""
|
||||
related_event_ids = (
|
||||
await self._main_store.get_all_relations_for_event_with_types(
|
||||
event_id, relation_types
|
||||
if "*" in relation_types:
|
||||
related_event_ids = await self._main_store.get_all_relations_for_event(
|
||||
event_id
|
||||
)
|
||||
else:
|
||||
related_event_ids = (
|
||||
await self._main_store.get_all_relations_for_event_with_types(
|
||||
event_id, relation_types
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
for related_event_id in related_event_ids:
|
||||
try:
|
||||
|
||||
@@ -108,9 +108,12 @@ def return_json_error(
|
||||
|
||||
if f.check(SynapseError):
|
||||
# mypy doesn't understand that f.check asserts the type.
|
||||
exc: SynapseError = f.value # type: ignore
|
||||
exc: SynapseError = f.value
|
||||
error_code = exc.code
|
||||
error_dict = exc.error_dict(config)
|
||||
if exc.headers is not None:
|
||||
for header, value in exc.headers.items():
|
||||
request.setHeader(header, value)
|
||||
logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg)
|
||||
elif f.check(CancelledError):
|
||||
error_code = HTTP_STATUS_REQUEST_CANCELLED
|
||||
@@ -121,7 +124,7 @@ def return_json_error(
|
||||
"Got cancellation before client disconnection from %r: %r",
|
||||
request.request_metrics.name,
|
||||
request,
|
||||
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore[arg-type]
|
||||
exc_info=(f.type, f.value, f.getTracebackObject()),
|
||||
)
|
||||
else:
|
||||
error_code = 500
|
||||
@@ -131,7 +134,7 @@ def return_json_error(
|
||||
"Failed handle request via %r: %r",
|
||||
request.request_metrics.name,
|
||||
request,
|
||||
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore[arg-type]
|
||||
exc_info=(f.type, f.value, f.getTracebackObject()),
|
||||
)
|
||||
|
||||
# Only respond with an error response if we haven't already started writing,
|
||||
@@ -169,9 +172,12 @@ def return_html_error(
|
||||
"""
|
||||
if f.check(CodeMessageException):
|
||||
# mypy doesn't understand that f.check asserts the type.
|
||||
cme: CodeMessageException = f.value # type: ignore
|
||||
cme: CodeMessageException = f.value
|
||||
code = cme.code
|
||||
msg = cme.msg
|
||||
if cme.headers is not None:
|
||||
for header, value in cme.headers.items():
|
||||
request.setHeader(header, value)
|
||||
|
||||
if isinstance(cme, RedirectException):
|
||||
logger.info("%s redirect to %s", request, cme.location)
|
||||
@@ -183,7 +189,7 @@ def return_html_error(
|
||||
logger.error(
|
||||
"Failed handle request %r",
|
||||
request,
|
||||
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore[arg-type]
|
||||
exc_info=(f.type, f.value, f.getTracebackObject()),
|
||||
)
|
||||
elif f.check(CancelledError):
|
||||
code = HTTP_STATUS_REQUEST_CANCELLED
|
||||
@@ -193,7 +199,7 @@ def return_html_error(
|
||||
logger.error(
|
||||
"Got cancellation before client disconnection when handling request %r",
|
||||
request,
|
||||
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore[arg-type]
|
||||
exc_info=(f.type, f.value, f.getTracebackObject()),
|
||||
)
|
||||
else:
|
||||
code = HTTPStatus.INTERNAL_SERVER_ERROR
|
||||
@@ -202,7 +208,7 @@ def return_html_error(
|
||||
logger.error(
|
||||
"Failed handle request %r",
|
||||
request,
|
||||
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore[arg-type]
|
||||
exc_info=(f.type, f.value, f.getTracebackObject()),
|
||||
)
|
||||
|
||||
if isinstance(error_template, str):
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
import html
|
||||
import logging
|
||||
import urllib.parse
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from typing import TYPE_CHECKING, List, Optional, cast
|
||||
|
||||
import attr
|
||||
|
||||
@@ -98,7 +98,7 @@ class OEmbedProvider:
|
||||
# No match.
|
||||
return None
|
||||
|
||||
def autodiscover_from_html(self, tree: "etree.Element") -> Optional[str]:
|
||||
def autodiscover_from_html(self, tree: "etree._Element") -> Optional[str]:
|
||||
"""
|
||||
Search an HTML document for oEmbed autodiscovery information.
|
||||
|
||||
@@ -109,18 +109,22 @@ class OEmbedProvider:
|
||||
The URL to use for oEmbed information, or None if no URL was found.
|
||||
"""
|
||||
# Search for link elements with the proper rel and type attributes.
|
||||
for tag in tree.xpath(
|
||||
"//link[@rel='alternate'][@type='application/json+oembed']"
|
||||
# Cast: the type returned by xpath depends on the xpath expression: mypy can't deduce this.
|
||||
for tag in cast(
|
||||
List["etree._Element"],
|
||||
tree.xpath("//link[@rel='alternate'][@type='application/json+oembed']"),
|
||||
):
|
||||
if "href" in tag.attrib:
|
||||
return tag.attrib["href"]
|
||||
return cast(str, tag.attrib["href"])
|
||||
|
||||
# Some providers (e.g. Flickr) use alternative instead of alternate.
|
||||
for tag in tree.xpath(
|
||||
"//link[@rel='alternative'][@type='application/json+oembed']"
|
||||
# Cast: the type returned by xpath depends on the xpath expression: mypy can't deduce this.
|
||||
for tag in cast(
|
||||
List["etree._Element"],
|
||||
tree.xpath("//link[@rel='alternative'][@type='application/json+oembed']"),
|
||||
):
|
||||
if "href" in tag.attrib:
|
||||
return tag.attrib["href"]
|
||||
return cast(str, tag.attrib["href"])
|
||||
|
||||
return None
|
||||
|
||||
@@ -212,11 +216,12 @@ class OEmbedProvider:
|
||||
return OEmbedResult(open_graph_response, author_name, cache_age)
|
||||
|
||||
|
||||
def _fetch_urls(tree: "etree.Element", tag_name: str) -> List[str]:
|
||||
def _fetch_urls(tree: "etree._Element", tag_name: str) -> List[str]:
|
||||
results = []
|
||||
for tag in tree.xpath("//*/" + tag_name):
|
||||
# Cast: the type returned by xpath depends on the xpath expression: mypy can't deduce this.
|
||||
for tag in cast(List["etree._Element"], tree.xpath("//*/" + tag_name)):
|
||||
if "src" in tag.attrib:
|
||||
results.append(tag.attrib["src"])
|
||||
results.append(cast(str, tag.attrib["src"]))
|
||||
return results
|
||||
|
||||
|
||||
@@ -244,11 +249,12 @@ def calc_description_and_urls(open_graph_response: JsonDict, html_body: str) ->
|
||||
parser = etree.HTMLParser(recover=True, encoding="utf-8")
|
||||
|
||||
# Attempt to parse the body. If this fails, log and return no metadata.
|
||||
tree = etree.fromstring(html_body, parser)
|
||||
# TODO Develop of lxml-stubs has this correct.
|
||||
tree = etree.fromstring(html_body, parser) # type: ignore[arg-type]
|
||||
|
||||
# The data was successfully parsed, but no tree was found.
|
||||
if tree is None:
|
||||
return
|
||||
return # type: ignore[unreachable]
|
||||
|
||||
# Attempt to find interesting URLs (images, videos, embeds).
|
||||
if "og:image" not in open_graph_response:
|
||||
|
||||
@@ -24,6 +24,7 @@ from typing import (
|
||||
Optional,
|
||||
Set,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -115,7 +116,7 @@ def _get_html_media_encodings(
|
||||
|
||||
def decode_body(
|
||||
body: bytes, uri: str, content_type: Optional[str] = None
|
||||
) -> Optional["etree.Element"]:
|
||||
) -> Optional["etree._Element"]:
|
||||
"""
|
||||
This uses lxml to parse the HTML document.
|
||||
|
||||
@@ -152,11 +153,12 @@ def decode_body(
|
||||
|
||||
# Attempt to parse the body. Returns None if the body was successfully
|
||||
# parsed, but no tree was found.
|
||||
return etree.fromstring(body, parser)
|
||||
# TODO Develop of lxml-stubs has this correct.
|
||||
return etree.fromstring(body, parser) # type: ignore[arg-type]
|
||||
|
||||
|
||||
def _get_meta_tags(
|
||||
tree: "etree.Element",
|
||||
tree: "etree._Element",
|
||||
property: str,
|
||||
prefix: str,
|
||||
property_mapper: Optional[Callable[[str], Optional[str]]] = None,
|
||||
@@ -175,9 +177,15 @@ def _get_meta_tags(
|
||||
Returns:
|
||||
A map of tag name to value.
|
||||
"""
|
||||
# This actually returns Dict[str, str], but the caller sets this as a variable
|
||||
# which is Dict[str, Optional[str]].
|
||||
results: Dict[str, Optional[str]] = {}
|
||||
for tag in tree.xpath(
|
||||
f"//*/meta[starts-with(@{property}, '{prefix}:')][@content][not(@content='')]"
|
||||
# Cast: the type returned by xpath depends on the xpath expression: mypy can't deduce this.
|
||||
for tag in cast(
|
||||
List["etree._Element"],
|
||||
tree.xpath(
|
||||
f"//*/meta[starts-with(@{property}, '{prefix}:')][@content][not(@content='')]"
|
||||
),
|
||||
):
|
||||
# if we've got more than 50 tags, someone is taking the piss
|
||||
if len(results) >= 50:
|
||||
@@ -187,14 +195,15 @@ def _get_meta_tags(
|
||||
)
|
||||
return {}
|
||||
|
||||
key = tag.attrib[property]
|
||||
key = cast(str, tag.attrib[property])
|
||||
if property_mapper:
|
||||
key = property_mapper(key)
|
||||
new_key = property_mapper(key)
|
||||
# None is a special value used to ignore a value.
|
||||
if key is None:
|
||||
if new_key is None:
|
||||
continue
|
||||
key = new_key
|
||||
|
||||
results[key] = tag.attrib["content"]
|
||||
results[key] = cast(str, tag.attrib["content"])
|
||||
|
||||
return results
|
||||
|
||||
@@ -219,7 +228,7 @@ def _map_twitter_to_open_graph(key: str) -> Optional[str]:
|
||||
return "og" + key[7:]
|
||||
|
||||
|
||||
def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]:
|
||||
def parse_html_to_open_graph(tree: "etree._Element") -> Dict[str, Optional[str]]:
|
||||
"""
|
||||
Parse the HTML document into an Open Graph response.
|
||||
|
||||
@@ -276,24 +285,36 @@ def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]:
|
||||
|
||||
if "og:title" not in og:
|
||||
# Attempt to find a title from the title tag, or the biggest header on the page.
|
||||
title = tree.xpath("((//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1])/text()")
|
||||
# Cast: the type returned by xpath depends on the xpath expression: mypy can't deduce this.
|
||||
title = cast(
|
||||
List["etree._ElementUnicodeResult"],
|
||||
tree.xpath("((//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1])/text()"),
|
||||
)
|
||||
if title:
|
||||
og["og:title"] = title[0].strip()
|
||||
else:
|
||||
og["og:title"] = None
|
||||
|
||||
if "og:image" not in og:
|
||||
meta_image = tree.xpath(
|
||||
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image'][not(@content='')]/@content[1]"
|
||||
# Cast: the type returned by xpath depends on the xpath expression: mypy can't deduce this.
|
||||
meta_image = cast(
|
||||
List["etree._ElementUnicodeResult"],
|
||||
tree.xpath(
|
||||
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image'][not(@content='')]/@content[1]"
|
||||
),
|
||||
)
|
||||
# If a meta image is found, use it.
|
||||
if meta_image:
|
||||
og["og:image"] = meta_image[0]
|
||||
else:
|
||||
# Try to find images which are larger than 10px by 10px.
|
||||
# Cast: the type returned by xpath depends on the xpath expression: mypy can't deduce this.
|
||||
#
|
||||
# TODO: consider inlined CSS styles as well as width & height attribs
|
||||
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
|
||||
images = cast(
|
||||
List["etree._Element"],
|
||||
tree.xpath("//img[@src][number(@width)>10][number(@height)>10]"),
|
||||
)
|
||||
images = sorted(
|
||||
images,
|
||||
key=lambda i: (
|
||||
@@ -302,20 +323,29 @@ def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]:
|
||||
)
|
||||
# If no images were found, try to find *any* images.
|
||||
if not images:
|
||||
images = tree.xpath("//img[@src][1]")
|
||||
# Cast: the type returned by xpath depends on the xpath expression: mypy can't deduce this.
|
||||
images = cast(List["etree._Element"], tree.xpath("//img[@src][1]"))
|
||||
if images:
|
||||
og["og:image"] = images[0].attrib["src"]
|
||||
og["og:image"] = cast(str, images[0].attrib["src"])
|
||||
|
||||
# Finally, fallback to the favicon if nothing else.
|
||||
else:
|
||||
favicons = tree.xpath("//link[@href][contains(@rel, 'icon')]/@href[1]")
|
||||
# Cast: the type returned by xpath depends on the xpath expression: mypy can't deduce this.
|
||||
favicons = cast(
|
||||
List["etree._ElementUnicodeResult"],
|
||||
tree.xpath("//link[@href][contains(@rel, 'icon')]/@href[1]"),
|
||||
)
|
||||
if favicons:
|
||||
og["og:image"] = favicons[0]
|
||||
|
||||
if "og:description" not in og:
|
||||
# Check the first meta description tag for content.
|
||||
meta_description = tree.xpath(
|
||||
"//*/meta[translate(@name, 'DESCRIPTION', 'description')='description'][not(@content='')]/@content[1]"
|
||||
# Cast: the type returned by xpath depends on the xpath expression: mypy can't deduce this.
|
||||
meta_description = cast(
|
||||
List["etree._ElementUnicodeResult"],
|
||||
tree.xpath(
|
||||
"//*/meta[translate(@name, 'DESCRIPTION', 'description')='description'][not(@content='')]/@content[1]"
|
||||
),
|
||||
)
|
||||
# If a meta description is found with content, use it.
|
||||
if meta_description:
|
||||
@@ -332,7 +362,7 @@ def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]:
|
||||
return og
|
||||
|
||||
|
||||
def parse_html_description(tree: "etree.Element") -> Optional[str]:
|
||||
def parse_html_description(tree: "etree._Element") -> Optional[str]:
|
||||
"""
|
||||
Calculate a text description based on an HTML document.
|
||||
|
||||
@@ -368,6 +398,9 @@ def parse_html_description(tree: "etree.Element") -> Optional[str]:
|
||||
"canvas",
|
||||
"img",
|
||||
"picture",
|
||||
# etree.Comment is a function which creates an etree._Comment element.
|
||||
# The "tag" attribute of an etree._Comment instance is confusingly the
|
||||
# etree.Comment function instead of a string.
|
||||
etree.Comment,
|
||||
}
|
||||
|
||||
@@ -381,8 +414,8 @@ def parse_html_description(tree: "etree.Element") -> Optional[str]:
|
||||
|
||||
|
||||
def _iterate_over_text(
|
||||
tree: Optional["etree.Element"],
|
||||
tags_to_ignore: Set[Union[str, "etree.Comment"]],
|
||||
tree: Optional["etree._Element"],
|
||||
tags_to_ignore: Set[object],
|
||||
stack_limit: int = 1024,
|
||||
) -> Generator[str, None, None]:
|
||||
"""Iterate over the tree returning text nodes in a depth first fashion,
|
||||
@@ -402,7 +435,7 @@ def _iterate_over_text(
|
||||
|
||||
# This is a stack whose items are elements to iterate over *or* strings
|
||||
# to be returned.
|
||||
elements: List[Union[str, "etree.Element"]] = [tree]
|
||||
elements: List[Union[str, "etree._Element"]] = [tree]
|
||||
while elements:
|
||||
el = elements.pop()
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ from twisted.web.resource import Resource
|
||||
|
||||
from synapse.api import errors
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.config import ConfigError
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.presence_router import (
|
||||
GET_INTERESTED_USERS_CALLBACK,
|
||||
@@ -121,6 +122,7 @@ from synapse.types import (
|
||||
JsonMapping,
|
||||
Requester,
|
||||
RoomAlias,
|
||||
RoomID,
|
||||
StateMap,
|
||||
UserID,
|
||||
UserInfo,
|
||||
@@ -252,6 +254,7 @@ class ModuleApi:
|
||||
self._device_handler = hs.get_device_handler()
|
||||
self.custom_template_dir = hs.config.server.custom_template_directory
|
||||
self._callbacks = hs.get_module_api_callbacks()
|
||||
self.msc3861_oauth_delegation_enabled = hs.config.experimental.msc3861.enabled
|
||||
|
||||
try:
|
||||
app_name = self._hs.config.email.email_app_name
|
||||
@@ -419,6 +422,11 @@ class ModuleApi:
|
||||
|
||||
Added in Synapse v1.46.0.
|
||||
"""
|
||||
if self.msc3861_oauth_delegation_enabled:
|
||||
raise ConfigError(
|
||||
"Cannot use password auth provider callbacks when OAuth delegation is enabled"
|
||||
)
|
||||
|
||||
return self._password_auth_provider.register_password_auth_provider_callbacks(
|
||||
check_3pid_auth=check_3pid_auth,
|
||||
on_logged_out=on_logged_out,
|
||||
@@ -647,7 +655,9 @@ class ModuleApi:
|
||||
Returns:
|
||||
The profile information (i.e. display name and avatar URL).
|
||||
"""
|
||||
return await self._store.get_profileinfo(localpart)
|
||||
server_name = self._hs.hostname
|
||||
user_id = UserID.from_string(f"@{localpart}:{server_name}")
|
||||
return await self._store.get_profileinfo(user_id)
|
||||
|
||||
async def get_threepids_for_user(self, user_id: str) -> List[Dict[str, str]]:
|
||||
"""Look up the threepids (email addresses and phone numbers) associated with the
|
||||
@@ -1563,6 +1573,32 @@ class ModuleApi:
|
||||
start_timestamp, end_timestamp
|
||||
)
|
||||
|
||||
async def get_canonical_room_alias(self, room_id: RoomID) -> Optional[RoomAlias]:
|
||||
"""
|
||||
Retrieve the given room's current canonical alias.
|
||||
|
||||
A room may declare an alias as "canonical", meaning that it is the
|
||||
preferred alias to use when referring to the room. This function
|
||||
retrieves that alias from the room's state.
|
||||
|
||||
Added in Synapse v1.86.0.
|
||||
|
||||
Args:
|
||||
room_id: The Room ID to find the alias of.
|
||||
|
||||
Returns:
|
||||
None if the room ID does not exist, or if the room exists but has no canonical alias.
|
||||
Otherwise, the parsed room alias.
|
||||
"""
|
||||
room_alias_str = (
|
||||
await self._storage_controllers.state.get_canonical_alias_for_room(
|
||||
room_id.to_string()
|
||||
)
|
||||
)
|
||||
if room_alias_str:
|
||||
return RoomAlias.from_string(room_alias_str)
|
||||
return None
|
||||
|
||||
async def lookup_room_alias(self, room_alias: str) -> Tuple[str, List[str]]:
|
||||
"""
|
||||
Get the room ID associated with a room alias.
|
||||
|
||||
@@ -247,7 +247,7 @@ class Mailer:
|
||||
|
||||
try:
|
||||
user_display_name = await self.store.get_profile_displayname(
|
||||
UserID.from_string(user_id).localpart
|
||||
UserID.from_string(user_id)
|
||||
)
|
||||
if user_display_name is None:
|
||||
user_display_name = user_id
|
||||
|
||||
@@ -257,9 +257,11 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
DeleteRoomStatusByRoomIdRestServlet(hs).register(http_server)
|
||||
JoinRoomAliasServlet(hs).register(http_server)
|
||||
VersionServlet(hs).register(http_server)
|
||||
UserAdminServlet(hs).register(http_server)
|
||||
if not hs.config.experimental.msc3861.enabled:
|
||||
UserAdminServlet(hs).register(http_server)
|
||||
UserMembershipRestServlet(hs).register(http_server)
|
||||
UserTokenRestServlet(hs).register(http_server)
|
||||
if not hs.config.experimental.msc3861.enabled:
|
||||
UserTokenRestServlet(hs).register(http_server)
|
||||
UserRestServletV2(hs).register(http_server)
|
||||
UsersRestServletV2(hs).register(http_server)
|
||||
UserMediaStatisticsRestServlet(hs).register(http_server)
|
||||
@@ -274,9 +276,10 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
RoomEventContextServlet(hs).register(http_server)
|
||||
RateLimitRestServlet(hs).register(http_server)
|
||||
UsernameAvailableRestServlet(hs).register(http_server)
|
||||
ListRegistrationTokensRestServlet(hs).register(http_server)
|
||||
NewRegistrationTokenRestServlet(hs).register(http_server)
|
||||
RegistrationTokenRestServlet(hs).register(http_server)
|
||||
if not hs.config.experimental.msc3861.enabled:
|
||||
ListRegistrationTokensRestServlet(hs).register(http_server)
|
||||
NewRegistrationTokenRestServlet(hs).register(http_server)
|
||||
RegistrationTokenRestServlet(hs).register(http_server)
|
||||
DestinationMembershipRestServlet(hs).register(http_server)
|
||||
DestinationResetConnectionRestServlet(hs).register(http_server)
|
||||
DestinationRestServlet(hs).register(http_server)
|
||||
@@ -306,10 +309,12 @@ def register_servlets_for_client_rest_resource(
|
||||
# The following resources can only be run on the main process.
|
||||
if hs.config.worker.worker_app is None:
|
||||
DeactivateAccountRestServlet(hs).register(http_server)
|
||||
ResetPasswordRestServlet(hs).register(http_server)
|
||||
if not hs.config.experimental.msc3861.enabled:
|
||||
ResetPasswordRestServlet(hs).register(http_server)
|
||||
SearchUsersRestServlet(hs).register(http_server)
|
||||
UserRegisterServlet(hs).register(http_server)
|
||||
AccountValidityRenewServlet(hs).register(http_server)
|
||||
if not hs.config.experimental.msc3861.enabled:
|
||||
UserRegisterServlet(hs).register(http_server)
|
||||
AccountValidityRenewServlet(hs).register(http_server)
|
||||
|
||||
# Load the media repo ones if we're using them. Otherwise load the servlets which
|
||||
# don't need a media repo (typically readonly admin APIs).
|
||||
|
||||
@@ -71,6 +71,7 @@ class UsersRestServletV2(RestServlet):
|
||||
self.auth = hs.get_auth()
|
||||
self.admin_handler = hs.get_admin_handler()
|
||||
self._msc3866_enabled = hs.config.experimental.msc3866.enabled
|
||||
self._msc3861_enabled = hs.config.experimental.msc3861.enabled
|
||||
|
||||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
@@ -94,7 +95,14 @@ class UsersRestServletV2(RestServlet):
|
||||
|
||||
user_id = parse_string(request, "user_id")
|
||||
name = parse_string(request, "name")
|
||||
|
||||
guests = parse_boolean(request, "guests", default=True)
|
||||
if self._msc3861_enabled and guests:
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"The guests parameter is not supported when MSC3861 is enabled.",
|
||||
errcode=Codes.INVALID_PARAM,
|
||||
)
|
||||
deactivated = parse_boolean(request, "deactivated", default=False)
|
||||
|
||||
# If support for MSC3866 is not enabled, apply no filtering based on the
|
||||
|
||||
@@ -27,6 +27,7 @@ from synapse.api.constants import LoginType
|
||||
from synapse.api.errors import (
|
||||
Codes,
|
||||
InteractiveAuthIncompleteError,
|
||||
NotFoundError,
|
||||
SynapseError,
|
||||
ThreepidValidationError,
|
||||
)
|
||||
@@ -600,6 +601,9 @@ class ThreepidRestServlet(RestServlet):
|
||||
# ThreePidBindRestServelet.PostBody with an `alias_generator` to handle
|
||||
# `threePidCreds` versus `three_pid_creds`.
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
if self.hs.config.experimental.msc3861.enabled:
|
||||
raise NotFoundError(errcode=Codes.UNRECOGNIZED)
|
||||
|
||||
if not self.hs.config.registration.enable_3pid_changes:
|
||||
raise SynapseError(
|
||||
400, "3PID changes are disabled on this server", Codes.FORBIDDEN
|
||||
@@ -890,19 +894,21 @@ class AccountStatusRestServlet(RestServlet):
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
if hs.config.worker.worker_app is None:
|
||||
EmailPasswordRequestTokenRestServlet(hs).register(http_server)
|
||||
PasswordRestServlet(hs).register(http_server)
|
||||
DeactivateAccountRestServlet(hs).register(http_server)
|
||||
EmailThreepidRequestTokenRestServlet(hs).register(http_server)
|
||||
MsisdnThreepidRequestTokenRestServlet(hs).register(http_server)
|
||||
AddThreepidEmailSubmitTokenServlet(hs).register(http_server)
|
||||
AddThreepidMsisdnSubmitTokenServlet(hs).register(http_server)
|
||||
if not hs.config.experimental.msc3861.enabled:
|
||||
EmailPasswordRequestTokenRestServlet(hs).register(http_server)
|
||||
DeactivateAccountRestServlet(hs).register(http_server)
|
||||
PasswordRestServlet(hs).register(http_server)
|
||||
EmailThreepidRequestTokenRestServlet(hs).register(http_server)
|
||||
MsisdnThreepidRequestTokenRestServlet(hs).register(http_server)
|
||||
AddThreepidEmailSubmitTokenServlet(hs).register(http_server)
|
||||
AddThreepidMsisdnSubmitTokenServlet(hs).register(http_server)
|
||||
ThreepidRestServlet(hs).register(http_server)
|
||||
if hs.config.worker.worker_app is None:
|
||||
ThreepidAddRestServlet(hs).register(http_server)
|
||||
ThreepidBindRestServlet(hs).register(http_server)
|
||||
ThreepidUnbindRestServlet(hs).register(http_server)
|
||||
ThreepidDeleteRestServlet(hs).register(http_server)
|
||||
if not hs.config.experimental.msc3861.enabled:
|
||||
ThreepidAddRestServlet(hs).register(http_server)
|
||||
ThreepidDeleteRestServlet(hs).register(http_server)
|
||||
WhoamiRestServlet(hs).register(http_server)
|
||||
|
||||
if hs.config.worker.worker_app is None and hs.config.experimental.msc3720_enabled:
|
||||
|
||||
@@ -65,6 +65,9 @@ class CapabilitiesRestServlet(RestServlet):
|
||||
"m.3pid_changes": {
|
||||
"enabled": self.config.registration.enable_3pid_changes
|
||||
},
|
||||
"m.get_login_token": {
|
||||
"enabled": self.config.auth.login_via_existing_enabled,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
from pydantic import Extra, StrictStr
|
||||
|
||||
from synapse.api import errors
|
||||
from synapse.api.errors import NotFoundError
|
||||
from synapse.api.errors import NotFoundError, UnrecognizedRequestError
|
||||
from synapse.handlers.device import DeviceHandler
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import (
|
||||
@@ -135,6 +135,7 @@ class DeviceRestServlet(RestServlet):
|
||||
self.device_handler = handler
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
self._msc3852_enabled = hs.config.experimental.msc3852_enabled
|
||||
self._msc3861_oauth_delegation_enabled = hs.config.experimental.msc3861.enabled
|
||||
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, device_id: str
|
||||
@@ -166,6 +167,9 @@ class DeviceRestServlet(RestServlet):
|
||||
async def on_DELETE(
|
||||
self, request: SynapseRequest, device_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
if self._msc3861_oauth_delegation_enabled:
|
||||
raise UnrecognizedRequestError(code=404)
|
||||
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
|
||||
try:
|
||||
@@ -344,7 +348,10 @@ class ClaimDehydratedDeviceServlet(RestServlet):
|
||||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
if hs.config.worker.worker_app is None:
|
||||
if (
|
||||
hs.config.worker.worker_app is None
|
||||
and not hs.config.experimental.msc3861.enabled
|
||||
):
|
||||
DeleteDevicesRestServlet(hs).register(http_server)
|
||||
DevicesRestServlet(hs).register(http_server)
|
||||
if hs.config.worker.worker_app is None:
|
||||
|
||||
@@ -58,7 +58,7 @@ class GetFilterRestServlet(RestServlet):
|
||||
|
||||
try:
|
||||
filter_collection = await self.filtering.get_user_filter(
|
||||
user_localpart=target_user.localpart, filter_id=filter_id_int
|
||||
user_id=target_user, filter_id=filter_id_int
|
||||
)
|
||||
except StoreError as e:
|
||||
if e.code != 404:
|
||||
|
||||
@@ -17,9 +17,10 @@
|
||||
import logging
|
||||
import re
|
||||
from collections import Counter
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
|
||||
|
||||
from synapse.api.errors import InvalidAPICallError, SynapseError
|
||||
from synapse.api.errors import Codes, InvalidAPICallError, SynapseError
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
@@ -375,9 +376,29 @@ class SigningKeyUploadServlet(RestServlet):
|
||||
user_id = requester.user.to_string()
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
if self.hs.config.experimental.msc3967_enabled:
|
||||
if await self.e2e_keys_handler.is_cross_signing_set_up_for_user(user_id):
|
||||
# If we already have a master key then cross signing is set up and we require UIA to reset
|
||||
is_cross_signing_setup = (
|
||||
await self.e2e_keys_handler.is_cross_signing_set_up_for_user(user_id)
|
||||
)
|
||||
|
||||
# Before MSC3967 we required UIA both when setting up cross signing for the
|
||||
# first time and when resetting the device signing key. With MSC3967 we only
|
||||
# require UIA when resetting cross-signing, and not when setting up the first
|
||||
# time. Because there is no UIA in MSC3861, for now we throw an error if the
|
||||
# user tries to reset the device signing key when MSC3861 is enabled, but allow
|
||||
# first-time setup.
|
||||
if self.hs.config.experimental.msc3861.enabled:
|
||||
# There is no way to reset the device signing key with MSC3861
|
||||
if is_cross_signing_setup:
|
||||
raise SynapseError(
|
||||
HTTPStatus.NOT_IMPLEMENTED,
|
||||
"Resetting cross signing keys is not yet supported with MSC3861",
|
||||
Codes.UNRECOGNIZED,
|
||||
)
|
||||
# But first-time setup is fine
|
||||
|
||||
elif self.hs.config.experimental.msc3967_enabled:
|
||||
# If we already have a master key then cross signing is set up and we require UIA to reset
|
||||
if is_cross_signing_setup:
|
||||
await self.auth_handler.validate_user_via_ui_auth(
|
||||
requester,
|
||||
request,
|
||||
@@ -387,6 +408,7 @@ class SigningKeyUploadServlet(RestServlet):
|
||||
can_skip_ui_auth=False,
|
||||
)
|
||||
# Otherwise we don't require UIA since we are setting up cross signing for first time
|
||||
|
||||
else:
|
||||
# Previous behaviour is to always require UIA but allow it to be skipped
|
||||
await self.auth_handler.validate_user_via_ui_auth(
|
||||
|
||||
@@ -104,6 +104,9 @@ class LoginRestServlet(RestServlet):
|
||||
and hs.config.experimental.msc3866.require_approval_for_new_accounts
|
||||
)
|
||||
|
||||
# Whether get login token is enabled.
|
||||
self._get_login_token_enabled = hs.config.auth.login_via_existing_enabled
|
||||
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
self.clock = hs.get_clock()
|
||||
@@ -142,6 +145,9 @@ class LoginRestServlet(RestServlet):
|
||||
# to SSO.
|
||||
flows.append({"type": LoginRestServlet.CAS_TYPE})
|
||||
|
||||
# The login token flow requires m.login.token to be advertised.
|
||||
support_login_token_flow = self._get_login_token_enabled
|
||||
|
||||
if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
|
||||
flows.append(
|
||||
{
|
||||
@@ -153,14 +159,23 @@ class LoginRestServlet(RestServlet):
|
||||
}
|
||||
)
|
||||
|
||||
# While it's valid for us to advertise this login type generally,
|
||||
# synapse currently only gives out these tokens as part of the
|
||||
# SSO login flow.
|
||||
# Generally we don't want to advertise login flows that clients
|
||||
# don't know how to implement, since they (currently) will always
|
||||
# fall back to the fallback API if they don't understand one of the
|
||||
# login flow types returned.
|
||||
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
|
||||
# SSO requires a login token to be generated, so we need to advertise that flow
|
||||
support_login_token_flow = True
|
||||
|
||||
# While it's valid for us to advertise this login type generally,
|
||||
# synapse currently only gives out these tokens as part of the
|
||||
# SSO login flow or as part of login via an existing session.
|
||||
#
|
||||
# Generally we don't want to advertise login flows that clients
|
||||
# don't know how to implement, since they (currently) will always
|
||||
# fall back to the fallback API if they don't understand one of the
|
||||
# login flow types returned.
|
||||
if support_login_token_flow:
|
||||
tokenTypeFlow: Dict[str, Any] = {"type": LoginRestServlet.TOKEN_TYPE}
|
||||
# If the login token flow is enabled advertise the get_login_token flag.
|
||||
if self._get_login_token_enabled:
|
||||
tokenTypeFlow["get_login_token"] = True
|
||||
flows.append(tokenTypeFlow)
|
||||
|
||||
flows.extend({"type": t} for t in self.auth_handler.get_supported_login_types())
|
||||
|
||||
@@ -633,6 +648,9 @@ class CasTicketServlet(RestServlet):
|
||||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
if hs.config.experimental.msc3861.enabled:
|
||||
return
|
||||
|
||||
LoginRestServlet(hs).register(http_server)
|
||||
if (
|
||||
hs.config.worker.worker_app is None
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api.ratelimiting import Ratelimiter
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||
from synapse.http.site import SynapseRequest
|
||||
@@ -33,7 +34,7 @@ class LoginTokenRequestServlet(RestServlet):
|
||||
|
||||
Request:
|
||||
|
||||
POST /login/token HTTP/1.1
|
||||
POST /login/get_token HTTP/1.1
|
||||
Content-Type: application/json
|
||||
|
||||
{}
|
||||
@@ -43,30 +44,45 @@ class LoginTokenRequestServlet(RestServlet):
|
||||
HTTP/1.1 200 OK
|
||||
{
|
||||
"login_token": "ABDEFGH",
|
||||
"expires_in": 3600,
|
||||
"expires_in_ms": 3600000,
|
||||
}
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns(
|
||||
"/org.matrix.msc3882/login/token$", releases=[], v1=False, unstable=True
|
||||
)
|
||||
PATTERNS = [
|
||||
*client_patterns(
|
||||
"/login/get_token$", releases=["v1"], v1=False, unstable=False
|
||||
),
|
||||
# TODO: this is no longer needed once unstable MSC3882 does not need to be supported:
|
||||
*client_patterns(
|
||||
"/org.matrix.msc3882/login/token$", releases=[], v1=False, unstable=True
|
||||
),
|
||||
]
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastores().main
|
||||
self.clock = hs.get_clock()
|
||||
self.server_name = hs.config.server.server_name
|
||||
self._main_store = hs.get_datastores().main
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
self.token_timeout = hs.config.experimental.msc3882_token_timeout
|
||||
self.ui_auth = hs.config.experimental.msc3882_ui_auth
|
||||
self.token_timeout = hs.config.auth.login_via_existing_token_timeout
|
||||
self._require_ui_auth = hs.config.auth.login_via_existing_require_ui_auth
|
||||
|
||||
# Ratelimit aggressively to a maxmimum of 1 request per minute.
|
||||
#
|
||||
# This endpoint can be used to spawn additional sessions and could be
|
||||
# abused by a malicious client to create many sessions.
|
||||
self._ratelimiter = Ratelimiter(
|
||||
store=self._main_store,
|
||||
clock=hs.get_clock(),
|
||||
rate_hz=1 / 60,
|
||||
burst_count=1,
|
||||
)
|
||||
|
||||
@interactive_auth_handler
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
if self.ui_auth:
|
||||
if self._require_ui_auth:
|
||||
await self.auth_handler.validate_user_via_ui_auth(
|
||||
requester,
|
||||
request,
|
||||
@@ -75,9 +91,12 @@ class LoginTokenRequestServlet(RestServlet):
|
||||
can_skip_ui_auth=False, # Don't allow skipping of UI auth
|
||||
)
|
||||
|
||||
# Ensure that this endpoint isn't being used too often. (Ensure this is
|
||||
# done *after* UI auth.)
|
||||
await self._ratelimiter.ratelimit(None, requester.user.to_string().lower())
|
||||
|
||||
login_token = await self.auth_handler.create_login_token_for_user_id(
|
||||
user_id=requester.user.to_string(),
|
||||
auth_provider_id="org.matrix.msc3882.login_token_request",
|
||||
duration_ms=self.token_timeout,
|
||||
)
|
||||
|
||||
@@ -85,11 +104,13 @@ class LoginTokenRequestServlet(RestServlet):
|
||||
200,
|
||||
{
|
||||
"login_token": login_token,
|
||||
# TODO: this is no longer needed once unstable MSC3882 does not need to be supported:
|
||||
"expires_in": self.token_timeout // 1000,
|
||||
"expires_in_ms": self.token_timeout,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
if hs.config.experimental.msc3882_enabled:
|
||||
if hs.config.auth.login_via_existing_enabled:
|
||||
LoginTokenRequestServlet(hs).register(http_server)
|
||||
|
||||
@@ -80,5 +80,8 @@ class LogoutAllRestServlet(RestServlet):
|
||||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
if hs.config.experimental.msc3861.enabled:
|
||||
return
|
||||
|
||||
LogoutRestServlet(hs).register(http_server)
|
||||
LogoutAllRestServlet(hs).register(http_server)
|
||||
|
||||
@@ -869,6 +869,74 @@ class RegisterRestServlet(RestServlet):
|
||||
return 200, result
|
||||
|
||||
|
||||
class RegisterAppServiceOnlyRestServlet(RestServlet):
|
||||
"""An alternative registration API endpoint that only allows ASes to register
|
||||
|
||||
This replaces the regular /register endpoint if MSC3861. There are two notable
|
||||
differences with the regular /register endpoint:
|
||||
- It only allows the `m.login.application_service` login type
|
||||
- It does not create a device or access token for the just-registered user
|
||||
|
||||
Note that the exact behaviour of this endpoint is not yet finalised. It should be
|
||||
just good enough to make most ASes work.
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns("/register$")
|
||||
CATEGORY = "Registration/login requests"
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
|
||||
self.auth = hs.get_auth()
|
||||
self.registration_handler = hs.get_registration_handler()
|
||||
self.ratelimiter = hs.get_registration_ratelimiter()
|
||||
|
||||
@interactive_auth_handler
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
client_addr = request.getClientAddress().host
|
||||
|
||||
await self.ratelimiter.ratelimit(None, client_addr, update=False)
|
||||
|
||||
kind = parse_string(request, "kind", default="user")
|
||||
|
||||
if kind == "guest":
|
||||
raise SynapseError(403, "Guest access is disabled")
|
||||
elif kind != "user":
|
||||
raise UnrecognizedRequestError(
|
||||
f"Do not understand membership kind: {kind}",
|
||||
)
|
||||
|
||||
# Pull out the provided username and do basic sanity checks early since
|
||||
# the auth layer will store these in sessions.
|
||||
desired_username = body.get("username")
|
||||
if not isinstance(desired_username, str) or len(desired_username) > 512:
|
||||
raise SynapseError(400, "Invalid username")
|
||||
|
||||
# Allow only ASes to use this API.
|
||||
if body.get("type") != APP_SERVICE_REGISTRATION_TYPE:
|
||||
raise SynapseError(403, "Non-application service registration type")
|
||||
|
||||
if not self.auth.has_access_token(request):
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Appservice token must be provided when using a type of m.login.application_service",
|
||||
)
|
||||
|
||||
# XXX we should check that desired_username is valid. Currently
|
||||
# we give appservices carte blanche for any insanity in mxids,
|
||||
# because the IRC bridges rely on being able to register stupid
|
||||
# IDs.
|
||||
|
||||
as_token = self.auth.get_access_token_from_request(request)
|
||||
|
||||
user_id = await self.registration_handler.appservice_register(
|
||||
desired_username, as_token
|
||||
)
|
||||
return 200, {"user_id": user_id}
|
||||
|
||||
|
||||
def _calculate_registration_flows(
|
||||
config: HomeServerConfig, auth_handler: AuthHandler
|
||||
) -> List[List[str]]:
|
||||
@@ -955,6 +1023,10 @@ def _calculate_registration_flows(
|
||||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
if hs.config.experimental.msc3861.enabled:
|
||||
RegisterAppServiceOnlyRestServlet(hs).register(http_server)
|
||||
return
|
||||
|
||||
if hs.config.worker.worker_app is None:
|
||||
EmailRegisterRequestTokenRestServlet(hs).register(http_server)
|
||||
MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
|
||||
|
||||
@@ -178,7 +178,7 @@ class SyncRestServlet(RestServlet):
|
||||
else:
|
||||
try:
|
||||
filter_collection = await self.filtering.get_user_filter(
|
||||
user.localpart, filter_id
|
||||
user, filter_id
|
||||
)
|
||||
except StoreError as err:
|
||||
if err.code != 404:
|
||||
|
||||
@@ -113,8 +113,8 @@ class VersionsRestServlet(RestServlet):
|
||||
"fi.mau.msc2815": self.config.experimental.msc2815_enabled,
|
||||
# Adds a ping endpoint for appservices to check HS->AS connection
|
||||
"fi.mau.msc2659.stable": True, # TODO: remove when "v1.7" is added above
|
||||
# Adds support for login token requests as per MSC3882
|
||||
"org.matrix.msc3882": self.config.experimental.msc3882_enabled,
|
||||
# TODO: this is no longer needed once unstable MSC3882 does not need to be supported:
|
||||
"org.matrix.msc3882": self.config.auth.login_via_existing_enabled,
|
||||
# Adds support for remotely enabling/disabling pushers, as per MSC3881
|
||||
"org.matrix.msc3881": self.config.experimental.msc3881_enabled,
|
||||
# Adds support for filtering /messages by event relation.
|
||||
|
||||
@@ -46,6 +46,12 @@ def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resourc
|
||||
"/_synapse/client/unsubscribe": UnsubscribeResource(hs),
|
||||
}
|
||||
|
||||
# Expose the JWKS endpoint if OAuth2 delegation is enabled
|
||||
if hs.config.experimental.msc3861.enabled:
|
||||
from synapse.rest.synapse.client.jwks import JwksResource
|
||||
|
||||
resources["/_synapse/jwks"] = JwksResource(hs)
|
||||
|
||||
# provider-specific SSO bits. Only load these if they are enabled, since they
|
||||
# rely on optional dependencies.
|
||||
if hs.config.oidc.oidc_enabled:
|
||||
|
||||
70
synapse/rest/synapse/client/jwks.py
Normal file
70
synapse/rest/synapse/client/jwks.py
Normal file
@@ -0,0 +1,70 @@
|
||||
# Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.http.server import DirectServeJsonResource
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.types import JsonDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class JwksResource(DirectServeJsonResource):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(extract_context=True)
|
||||
|
||||
# Parameters that are allowed to be exposed in the public key.
|
||||
# This is done manually, because authlib's private to public key conversion
|
||||
# is unreliable depending on the version. Instead, we just serialize the private
|
||||
# key and only keep the public parameters.
|
||||
# List from https://www.iana.org/assignments/jose/jose.xhtml#web-key-parameters
|
||||
public_parameters = {
|
||||
"kty",
|
||||
"use",
|
||||
"key_ops",
|
||||
"alg",
|
||||
"kid",
|
||||
"x5u",
|
||||
"x5c",
|
||||
"x5t",
|
||||
"x5t#S256",
|
||||
"crv",
|
||||
"x",
|
||||
"y",
|
||||
"n",
|
||||
"e",
|
||||
"ext",
|
||||
}
|
||||
|
||||
key = hs.config.experimental.msc3861.jwk
|
||||
|
||||
if key is not None:
|
||||
private_key = key.as_dict()
|
||||
public_key = {
|
||||
k: v for k, v in private_key.items() if k in public_parameters
|
||||
}
|
||||
keys = [public_key]
|
||||
else:
|
||||
keys = []
|
||||
|
||||
self.res = {
|
||||
"keys": keys,
|
||||
}
|
||||
|
||||
async def _async_render_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
return 200, self.res
|
||||
@@ -44,6 +44,16 @@ class WellKnownBuilder:
|
||||
"base_url": self._config.registration.default_identity_server
|
||||
}
|
||||
|
||||
# We use the MSC3861 values as they are used by multiple MSCs
|
||||
if self._config.experimental.msc3861.enabled:
|
||||
result["org.matrix.msc2965.authentication"] = {
|
||||
"issuer": self._config.experimental.msc3861.issuer
|
||||
}
|
||||
if self._config.experimental.msc3861.account_management_url is not None:
|
||||
result["org.matrix.msc2965.authentication"][
|
||||
"account"
|
||||
] = self._config.experimental.msc3861.account_management_url
|
||||
|
||||
if self._config.server.extra_well_known_client_content:
|
||||
for (
|
||||
key,
|
||||
|
||||
@@ -31,6 +31,7 @@ from twisted.web.iweb import IPolicyForHTTPS
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
from synapse.api.auth import Auth
|
||||
from synapse.api.auth.internal import InternalAuth
|
||||
from synapse.api.auth_blocking import AuthBlocking
|
||||
from synapse.api.filtering import Filtering
|
||||
from synapse.api.ratelimiting import Ratelimiter, RequestRatelimiter
|
||||
@@ -427,7 +428,11 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||
|
||||
@cache_in_self
|
||||
def get_auth(self) -> Auth:
|
||||
return Auth(self)
|
||||
if self.config.experimental.msc3861.enabled:
|
||||
from synapse.api.auth.msc3861_delegated import MSC3861DelegatedAuth
|
||||
|
||||
return MSC3861DelegatedAuth(self)
|
||||
return InternalAuth(self)
|
||||
|
||||
@cache_in_self
|
||||
def get_auth_blocking(self) -> AuthBlocking:
|
||||
|
||||
@@ -485,7 +485,7 @@ class StateStorageController:
|
||||
if not event:
|
||||
return None
|
||||
|
||||
return event.content.get("canonical_alias")
|
||||
return event.content.get("alias")
|
||||
|
||||
@trace
|
||||
@tag_args
|
||||
|
||||
@@ -1941,6 +1941,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||
user_id,
|
||||
stream_ids[-1],
|
||||
)
|
||||
txn.call_after(
|
||||
self._get_e2e_device_keys_for_federation_query_inner.invalidate,
|
||||
(user_id,),
|
||||
)
|
||||
|
||||
min_stream_id = stream_ids[0]
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
import abc
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Collection,
|
||||
Dict,
|
||||
Iterable,
|
||||
@@ -39,6 +40,7 @@ from synapse.appservice import (
|
||||
TransactionUnusedFallbackKeys,
|
||||
)
|
||||
from synapse.logging.opentracing import log_kv, set_tag, trace
|
||||
from synapse.replication.tcp.streams._base import DeviceListsStream
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import (
|
||||
DatabasePool,
|
||||
@@ -104,6 +106,23 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||
self.hs.config.federation.allow_device_name_lookup_over_federation
|
||||
)
|
||||
|
||||
def process_replication_rows(
|
||||
self,
|
||||
stream_name: str,
|
||||
instance_name: str,
|
||||
token: int,
|
||||
rows: Iterable[Any],
|
||||
) -> None:
|
||||
if stream_name == DeviceListsStream.NAME:
|
||||
for row in rows:
|
||||
assert isinstance(row, DeviceListsStream.DeviceListsStreamRow)
|
||||
if row.entity.startswith("@"):
|
||||
self._get_e2e_device_keys_for_federation_query_inner.invalidate(
|
||||
(row.entity,)
|
||||
)
|
||||
|
||||
super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||
|
||||
async def get_e2e_device_keys_for_federation_query(
|
||||
self, user_id: str
|
||||
) -> Tuple[int, List[JsonDict]]:
|
||||
@@ -114,6 +133,50 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||
"""
|
||||
now_stream_id = self.get_device_stream_token()
|
||||
|
||||
# We need to be careful with the caching here, as we need to always
|
||||
# return *all* persisted devices, however there may be a lag between a
|
||||
# new device being persisted and the cache being invalidated.
|
||||
cached_results = (
|
||||
self._get_e2e_device_keys_for_federation_query_inner.cache.get_immediate(
|
||||
user_id, None
|
||||
)
|
||||
)
|
||||
if cached_results is not None:
|
||||
# Check that there have been no new devices added by another worker
|
||||
# after the cache. This should be quick as there should be few rows
|
||||
# with a higher stream ordering.
|
||||
#
|
||||
# Note that we invalidate based on the device stream, so we only
|
||||
# have to check for potential invalidations after the
|
||||
# `now_stream_id`.
|
||||
sql = """
|
||||
SELECT user_id FROM device_lists_stream
|
||||
WHERE stream_id >= ? AND user_id = ?
|
||||
"""
|
||||
rows = await self.db_pool.execute(
|
||||
"get_e2e_device_keys_for_federation_query_check",
|
||||
None,
|
||||
sql,
|
||||
now_stream_id,
|
||||
user_id,
|
||||
)
|
||||
if not rows:
|
||||
# No new rows, so cache is still valid.
|
||||
return now_stream_id, cached_results
|
||||
|
||||
# There has, so let's invalidate the cache and run the query.
|
||||
self._get_e2e_device_keys_for_federation_query_inner.invalidate((user_id,))
|
||||
|
||||
results = await self._get_e2e_device_keys_for_federation_query_inner(user_id)
|
||||
|
||||
return now_stream_id, results
|
||||
|
||||
@cached(iterable=True)
|
||||
async def _get_e2e_device_keys_for_federation_query_inner(
|
||||
self, user_id: str
|
||||
) -> List[JsonDict]:
|
||||
"""Get all devices (with any device keys) for a user"""
|
||||
|
||||
devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)])
|
||||
|
||||
if devices:
|
||||
@@ -134,9 +197,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||
|
||||
results.append(result)
|
||||
|
||||
return now_stream_id, results
|
||||
return results
|
||||
|
||||
return now_stream_id, []
|
||||
return []
|
||||
|
||||
@trace
|
||||
@cancellable
|
||||
|
||||
@@ -145,7 +145,7 @@ class FilteringWorkerStore(SQLBaseStore):
|
||||
|
||||
@cached(num_args=2)
|
||||
async def get_user_filter(
|
||||
self, user_localpart: str, filter_id: Union[int, str]
|
||||
self, user_id: UserID, filter_id: Union[int, str]
|
||||
) -> JsonDict:
|
||||
# filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
|
||||
# with a coherent error message rather than 500 M_UNKNOWN.
|
||||
@@ -156,7 +156,7 @@ class FilteringWorkerStore(SQLBaseStore):
|
||||
|
||||
def_json = await self.db_pool.simple_select_one_onecol(
|
||||
table="user_filters",
|
||||
keyvalues={"user_id": user_localpart, "filter_id": filter_id},
|
||||
keyvalues={"full_user_id": user_id.to_string(), "filter_id": filter_id},
|
||||
retcol="filter_json",
|
||||
allow_none=False,
|
||||
desc="get_user_filter",
|
||||
@@ -172,15 +172,15 @@ class FilteringWorkerStore(SQLBaseStore):
|
||||
def _do_txn(txn: LoggingTransaction) -> int:
|
||||
sql = (
|
||||
"SELECT filter_id FROM user_filters "
|
||||
"WHERE user_id = ? AND filter_json = ?"
|
||||
"WHERE full_user_id = ? AND filter_json = ?"
|
||||
)
|
||||
txn.execute(sql, (user_id.localpart, bytearray(def_json)))
|
||||
txn.execute(sql, (user_id.to_string(), bytearray(def_json)))
|
||||
filter_id_response = txn.fetchone()
|
||||
if filter_id_response is not None:
|
||||
return filter_id_response[0]
|
||||
|
||||
sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?"
|
||||
txn.execute(sql, (user_id.localpart,))
|
||||
sql = "SELECT MAX(filter_id) FROM user_filters WHERE full_user_id = ?"
|
||||
txn.execute(sql, (user_id.to_string(),))
|
||||
max_id = cast(Tuple[Optional[int]], txn.fetchone())[0]
|
||||
if max_id is None:
|
||||
filter_id = 0
|
||||
|
||||
@@ -137,11 +137,11 @@ class ProfileWorkerStore(SQLBaseStore):
|
||||
|
||||
return 50
|
||||
|
||||
async def get_profileinfo(self, user_localpart: str) -> ProfileInfo:
|
||||
async def get_profileinfo(self, user_id: UserID) -> ProfileInfo:
|
||||
try:
|
||||
profile = await self.db_pool.simple_select_one(
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
keyvalues={"full_user_id": user_id.to_string()},
|
||||
retcols=("displayname", "avatar_url"),
|
||||
desc="get_profileinfo",
|
||||
)
|
||||
@@ -156,18 +156,18 @@ class ProfileWorkerStore(SQLBaseStore):
|
||||
avatar_url=profile["avatar_url"], display_name=profile["displayname"]
|
||||
)
|
||||
|
||||
async def get_profile_displayname(self, user_localpart: str) -> Optional[str]:
|
||||
async def get_profile_displayname(self, user_id: UserID) -> Optional[str]:
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
keyvalues={"full_user_id": user_id.to_string()},
|
||||
retcol="displayname",
|
||||
desc="get_profile_displayname",
|
||||
)
|
||||
|
||||
async def get_profile_avatar_url(self, user_localpart: str) -> Optional[str]:
|
||||
async def get_profile_avatar_url(self, user_id: UserID) -> Optional[str]:
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
keyvalues={"full_user_id": user_id.to_string()},
|
||||
retcol="avatar_url",
|
||||
desc="get_profile_avatar_url",
|
||||
)
|
||||
|
||||
@@ -365,6 +365,36 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||
func=get_all_relation_ids_for_event_with_types_txn,
|
||||
)
|
||||
|
||||
async def get_all_relations_for_event(
|
||||
self,
|
||||
event_id: str,
|
||||
) -> List[str]:
|
||||
"""Get the event IDs of all events that have a relation to the given event.
|
||||
|
||||
Args:
|
||||
event_id: The event for which to look for related events.
|
||||
|
||||
Returns:
|
||||
A list of the IDs of the events that relate to the given event.
|
||||
"""
|
||||
|
||||
def get_all_relation_ids_for_event_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> List[str]:
|
||||
rows = self.db_pool.simple_select_list_txn(
|
||||
txn=txn,
|
||||
table="event_relations",
|
||||
keyvalues={"relates_to_id": event_id},
|
||||
retcols=["event_id"],
|
||||
)
|
||||
|
||||
return [row["event_id"] for row in rows]
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
desc="get_all_relation_ids_for_event",
|
||||
func=get_all_relation_ids_for_event_txn,
|
||||
)
|
||||
|
||||
async def event_includes_relation(self, event_id: str) -> bool:
|
||||
"""Check if the given event relates to another event.
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ from synapse.storage.database import (
|
||||
LoggingDatabaseConnection,
|
||||
LoggingTransaction,
|
||||
)
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
||||
from synapse.types import MutableStateMap, StateMap
|
||||
from synapse.types.state import StateFilter
|
||||
from synapse.util.caches import intern_string
|
||||
@@ -328,6 +328,15 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
|
||||
columns=["event_stream_ordering"],
|
||||
)
|
||||
|
||||
self.db_pool.updates.register_background_update_handler(
|
||||
"add_event_stream_ordering",
|
||||
self._add_event_stream_ordering,
|
||||
)
|
||||
|
||||
self.db_pool.updates.register_background_update_handler(
|
||||
"add_stream_ordering_triggers", self._add_triggers_in_bg
|
||||
)
|
||||
|
||||
async def _background_deduplicate_state(
|
||||
self, progress: dict, batch_size: int
|
||||
) -> int:
|
||||
@@ -504,3 +513,175 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
|
||||
)
|
||||
|
||||
return 1
|
||||
|
||||
async def _add_event_stream_ordering(self, progress: dict, batch_size: int) -> int:
|
||||
"""
|
||||
Add denormalised copies of `stream_ordering` from the corresponding row in `events`
|
||||
to the tables current_state_events, local_current_membership, and room_memberships.
|
||||
This is done to improve database performance by reduring JOINs.
|
||||
|
||||
"""
|
||||
tables = [
|
||||
"current_state_events",
|
||||
"local_current_membership",
|
||||
"room_memberships",
|
||||
]
|
||||
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
|
||||
def check_pg_column(txn: LoggingTransaction, table: str) -> list:
|
||||
"""
|
||||
check if the column event_stream_ordering already exists
|
||||
"""
|
||||
check_sql = f"""
|
||||
SELECT column_name FROM information_schema.columns
|
||||
WHERE table_name = '{table}' and column_name = 'event_stream_ordering';
|
||||
"""
|
||||
txn.execute(check_sql)
|
||||
column = txn.fetchall()
|
||||
return column
|
||||
|
||||
def add_pg_column(txn: LoggingTransaction, table: str) -> None:
|
||||
"""
|
||||
Add column event_stream_ordering to A given table
|
||||
"""
|
||||
add_column_sql = f"""
|
||||
ALTER TABLE {table} ADD COLUMN event_stream_ordering BIGINT;
|
||||
"""
|
||||
txn.execute(add_column_sql)
|
||||
|
||||
add_fk_sql = f"""
|
||||
ALTER TABLE {table} ADD CONSTRAINT event_stream_ordering_fkey
|
||||
FOREIGN KEY(event_stream_ordering) REFERENCES events(stream_ordering) NOT VALID;
|
||||
"""
|
||||
txn.execute(add_fk_sql)
|
||||
|
||||
for table in tables:
|
||||
res = await self.db_pool.runInteraction(
|
||||
"check_column", check_pg_column, table
|
||||
)
|
||||
# if the column exists do nothing
|
||||
if not res:
|
||||
await self.db_pool.runInteraction(
|
||||
"add_event_stream_ordering",
|
||||
add_pg_column,
|
||||
table,
|
||||
)
|
||||
await self.db_pool.updates._end_background_update(
|
||||
"add_event_stream_ordering"
|
||||
)
|
||||
return 1
|
||||
|
||||
elif isinstance(self.database_engine, Sqlite3Engine):
|
||||
|
||||
def check_sqlite_column(txn: LoggingTransaction, table: str) -> List[tuple]:
|
||||
"""
|
||||
Get table info (to see if column event_stream_ordering exists)
|
||||
"""
|
||||
check_sql = f"""
|
||||
PRAGMA table_info({table})
|
||||
"""
|
||||
txn.execute(check_sql)
|
||||
res = txn.fetchall()
|
||||
return res
|
||||
|
||||
def add_sqlite_column(txn: LoggingTransaction, table: str) -> None:
|
||||
"""
|
||||
Add column event_stream_ordering to given table
|
||||
"""
|
||||
add_column_sql = f"""
|
||||
ALTER TABLE {table} ADD COLUMN event_stream_ordering BIGINT REFERENCES events(stream_ordering);
|
||||
"""
|
||||
txn.execute(add_column_sql)
|
||||
|
||||
for table in tables:
|
||||
res = await self.db_pool.runInteraction(
|
||||
"check_column", check_sqlite_column, table
|
||||
)
|
||||
columns = [tup[1] for tup in res]
|
||||
|
||||
# if the column exists do nothing
|
||||
if "event_stream_ordering" not in columns:
|
||||
await self.db_pool.runInteraction(
|
||||
"add_event_stream_ordering", add_sqlite_column, table
|
||||
)
|
||||
|
||||
await self.db_pool.updates._end_background_update(
|
||||
"add_event_stream_ordering"
|
||||
)
|
||||
return 1
|
||||
|
||||
async def _add_triggers_in_bg(self, progress: dict, batch_size: int) -> int:
|
||||
"""
|
||||
Adds triggers to the room membership tables to enforce consistency
|
||||
"""
|
||||
# Complain if the `event_stream_ordering` in membership tables doesn't match
|
||||
# the `stream_ordering` row with the same `event_id` in `events`.
|
||||
if isinstance(self.database_engine, Sqlite3Engine):
|
||||
|
||||
def add_sqlite_triggers(txn: LoggingTransaction) -> None:
|
||||
for table in (
|
||||
"current_state_events",
|
||||
"local_current_membership",
|
||||
"room_memberships",
|
||||
):
|
||||
txn.execute(
|
||||
f"""
|
||||
CREATE TRIGGER IF NOT EXISTS {table}_bad_event_stream_ordering
|
||||
BEFORE INSERT ON {table}
|
||||
FOR EACH ROW
|
||||
BEGIN
|
||||
SELECT RAISE(ABORT, 'Incorrect event_stream_ordering in {table}')
|
||||
WHERE EXISTS (
|
||||
SELECT 1 FROM events
|
||||
WHERE events.event_id = NEW.event_id
|
||||
AND events.stream_ordering != NEW.event_stream_ordering
|
||||
);
|
||||
END;
|
||||
"""
|
||||
)
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
"add_sqlite_triggers", add_sqlite_triggers
|
||||
)
|
||||
elif isinstance(self.database_engine, PostgresEngine):
|
||||
|
||||
def add_pg_triggers(txn: LoggingTransaction) -> None:
|
||||
txn.execute(
|
||||
"""
|
||||
CREATE OR REPLACE FUNCTION check_event_stream_ordering() RETURNS trigger AS $BODY$
|
||||
BEGIN
|
||||
IF EXISTS (
|
||||
SELECT 1 FROM events
|
||||
WHERE events.event_id = NEW.event_id
|
||||
AND events.stream_ordering != NEW.event_stream_ordering
|
||||
) THEN
|
||||
RAISE EXCEPTION 'Incorrect event_stream_ordering';
|
||||
END IF;
|
||||
RETURN NEW;
|
||||
END;
|
||||
$BODY$ LANGUAGE plpgsql;
|
||||
"""
|
||||
)
|
||||
|
||||
for table in (
|
||||
"current_state_events",
|
||||
"local_current_membership",
|
||||
"room_memberships",
|
||||
):
|
||||
txn.execute(
|
||||
f"""
|
||||
CREATE TRIGGER check_event_stream_ordering BEFORE INSERT OR UPDATE ON {table}
|
||||
FOR EACH ROW
|
||||
EXECUTE PROCEDURE check_event_stream_ordering()
|
||||
"""
|
||||
)
|
||||
|
||||
await self.db_pool.runInteraction("add_postgres_triggers", add_pg_triggers)
|
||||
else:
|
||||
raise NotImplementedError("Unknown database engine")
|
||||
|
||||
await self.db_pool.updates._end_background_update(
|
||||
"add_stream_ordering_triggers"
|
||||
)
|
||||
return 1
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
SCHEMA_VERSION = 77 # remember to update the list below when updating
|
||||
SCHEMA_VERSION = 78 # remember to update the list below when updating
|
||||
"""Represents the expectations made by the codebase about the database schema
|
||||
|
||||
This should be incremented whenever the codebase changes its requirements on the
|
||||
@@ -103,6 +103,9 @@ Changes in SCHEMA_VERSION = 76:
|
||||
|
||||
Changes in SCHEMA_VERSION = 77
|
||||
- (Postgres) Add NOT VALID CHECK (full_user_id IS NOT NULL) to tables profiles and user_filters
|
||||
|
||||
Changes in SCHEMA_VERSION = 78
|
||||
- Validate check (full_user_id IS NOT NULL) on tables profiles and user_filters
|
||||
"""
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
/* Copyright 2023 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
INSERT INTO background_updates (ordering, update_name, progress_json, depends_on)
|
||||
VALUES
|
||||
(7403, 'add_event_stream_ordering', '{}', 'replace_stream_ordering_column');
|
||||
@@ -1,29 +0,0 @@
|
||||
/* Copyright 2022 Beeper
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- Each of these are denormalised copies of `stream_ordering` from the corresponding row in` events` which
|
||||
-- we use to improve database performance by reduring JOINs.
|
||||
|
||||
-- NOTE: these are set to NOT VALID to prevent locks while adding the column on large existing tables,
|
||||
-- which will be validated in a later migration. For all new/updated rows the FKEY will be checked.
|
||||
|
||||
ALTER TABLE current_state_events ADD COLUMN event_stream_ordering BIGINT;
|
||||
ALTER TABLE current_state_events ADD CONSTRAINT event_stream_ordering_fkey FOREIGN KEY (event_stream_ordering) REFERENCES events(stream_ordering) NOT VALID;
|
||||
|
||||
ALTER TABLE local_current_membership ADD COLUMN event_stream_ordering BIGINT;
|
||||
ALTER TABLE local_current_membership ADD CONSTRAINT event_stream_ordering_fkey FOREIGN KEY (event_stream_ordering) REFERENCES events(stream_ordering) NOT VALID;
|
||||
|
||||
ALTER TABLE room_memberships ADD COLUMN event_stream_ordering BIGINT;
|
||||
ALTER TABLE room_memberships ADD CONSTRAINT event_stream_ordering_fkey FOREIGN KEY (event_stream_ordering) REFERENCES events(stream_ordering) NOT VALID;
|
||||
@@ -1,23 +0,0 @@
|
||||
/* Copyright 2022 Beeper
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- Each of these are denormalised copies of `stream_ordering` from the corresponding row in` events` which
|
||||
-- we use to improve database performance by reduring JOINs.
|
||||
|
||||
-- NOTE: sqlite does not support ADD CONSTRAINT so we add the new columns with FK constraint as-is
|
||||
|
||||
ALTER TABLE current_state_events ADD COLUMN event_stream_ordering BIGINT REFERENCES events(stream_ordering);
|
||||
ALTER TABLE local_current_membership ADD COLUMN event_stream_ordering BIGINT REFERENCES events(stream_ordering);
|
||||
ALTER TABLE room_memberships ADD COLUMN event_stream_ordering BIGINT REFERENCES events(stream_ordering);
|
||||
@@ -1,79 +0,0 @@
|
||||
# Copyright 2022 Beeper
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
This migration adds triggers to the room membership tables to enforce consistency.
|
||||
Triggers cannot be expressed in .sql files, so we have to use a separate file.
|
||||
"""
|
||||
from synapse.storage.database import LoggingTransaction
|
||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
|
||||
|
||||
|
||||
def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None:
|
||||
# Complain if the `event_stream_ordering` in membership tables doesn't match
|
||||
# the `stream_ordering` row with the same `event_id` in `events`.
|
||||
if isinstance(database_engine, Sqlite3Engine):
|
||||
for table in (
|
||||
"current_state_events",
|
||||
"local_current_membership",
|
||||
"room_memberships",
|
||||
):
|
||||
cur.execute(
|
||||
f"""
|
||||
CREATE TRIGGER IF NOT EXISTS {table}_bad_event_stream_ordering
|
||||
BEFORE INSERT ON {table}
|
||||
FOR EACH ROW
|
||||
BEGIN
|
||||
SELECT RAISE(ABORT, 'Incorrect event_stream_ordering in {table}')
|
||||
WHERE EXISTS (
|
||||
SELECT 1 FROM events
|
||||
WHERE events.event_id = NEW.event_id
|
||||
AND events.stream_ordering != NEW.event_stream_ordering
|
||||
);
|
||||
END;
|
||||
"""
|
||||
)
|
||||
elif isinstance(database_engine, PostgresEngine):
|
||||
cur.execute(
|
||||
"""
|
||||
CREATE OR REPLACE FUNCTION check_event_stream_ordering() RETURNS trigger AS $BODY$
|
||||
BEGIN
|
||||
IF EXISTS (
|
||||
SELECT 1 FROM events
|
||||
WHERE events.event_id = NEW.event_id
|
||||
AND events.stream_ordering != NEW.event_stream_ordering
|
||||
) THEN
|
||||
RAISE EXCEPTION 'Incorrect event_stream_ordering';
|
||||
END IF;
|
||||
RETURN NEW;
|
||||
END;
|
||||
$BODY$ LANGUAGE plpgsql;
|
||||
"""
|
||||
)
|
||||
|
||||
for table in (
|
||||
"current_state_events",
|
||||
"local_current_membership",
|
||||
"room_memberships",
|
||||
):
|
||||
cur.execute(
|
||||
f"""
|
||||
CREATE TRIGGER check_event_stream_ordering BEFORE INSERT OR UPDATE ON {table}
|
||||
FOR EACH ROW
|
||||
EXECUTE PROCEDURE check_event_stream_ordering()
|
||||
"""
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Unknown database engine")
|
||||
@@ -0,0 +1,22 @@
|
||||
/* Copyright 2023 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
|
||||
|
||||
-- This migration adds triggers to the room membership tables to enforce consistency.
|
||||
|
||||
INSERT INTO background_updates (ordering, update_name, progress_json, depends_on)
|
||||
VALUES
|
||||
(7404, 'add_stream_ordering_triggers', '{}', 'add_event_stream_ordering');
|
||||
@@ -13,8 +13,8 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
INSERT INTO background_updates (ordering, update_name, progress_json)
|
||||
INSERT INTO background_updates (ordering, update_name, progress_json, depends_on)
|
||||
VALUES
|
||||
(7714, 'current_state_events_stream_ordering_idx', '{}'),
|
||||
(7714, 'local_current_membership_stream_ordering_idx', '{}'),
|
||||
(7714, 'room_memberships_stream_ordering_idx', '{}');
|
||||
(7714, 'current_state_events_stream_ordering_idx', '{}', 'add_event_stream_ordering'),
|
||||
(7714, 'local_current_membership_stream_ordering_idx', '{}', 'add_event_stream_ordering'),
|
||||
(7714, 'room_memberships_stream_ordering_idx', '{}', 'add_event_stream_ordering');
|
||||
|
||||
@@ -0,0 +1,92 @@
|
||||
# Copyright 2023 The Matrix.org Foundation C.I.C
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.storage.database import LoggingTransaction
|
||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
|
||||
|
||||
|
||||
def run_upgrade(
|
||||
cur: LoggingTransaction,
|
||||
database_engine: BaseDatabaseEngine,
|
||||
config: HomeServerConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Part 3 of a multi-step migration to drop the column `user_id` and replace it with
|
||||
`full_user_id`. See the database schema docs for more information on the full
|
||||
migration steps.
|
||||
"""
|
||||
hostname = config.server.server_name
|
||||
|
||||
if isinstance(database_engine, PostgresEngine):
|
||||
# check if the constraint can be validated
|
||||
check_sql = """
|
||||
SELECT user_id from profiles WHERE full_user_id IS NULL
|
||||
"""
|
||||
cur.execute(check_sql)
|
||||
res = cur.fetchall()
|
||||
|
||||
if res:
|
||||
# there are rows the background job missed, finish them here before we validate the constraint
|
||||
process_rows_sql = """
|
||||
UPDATE profiles
|
||||
SET full_user_id = '@' || user_id || ?
|
||||
WHERE user_id IN (
|
||||
SELECT user_id FROM profiles WHERE full_user_id IS NULL
|
||||
)
|
||||
"""
|
||||
cur.execute(process_rows_sql, (f":{hostname}",))
|
||||
|
||||
# Now we can validate
|
||||
validate_sql = """
|
||||
ALTER TABLE profiles VALIDATE CONSTRAINT full_user_id_not_null
|
||||
"""
|
||||
cur.execute(validate_sql)
|
||||
|
||||
else:
|
||||
# in SQLite we need to rewrite the table to add the constraint.
|
||||
# First drop any temporary table that might be here from a previous failed migration.
|
||||
cur.execute("DROP TABLE IF EXISTS temp_profiles")
|
||||
|
||||
create_sql = """
|
||||
CREATE TABLE temp_profiles (
|
||||
full_user_id text NOT NULL,
|
||||
user_id text,
|
||||
displayname text,
|
||||
avatar_url text,
|
||||
UNIQUE (full_user_id),
|
||||
UNIQUE (user_id)
|
||||
)
|
||||
"""
|
||||
cur.execute(create_sql)
|
||||
|
||||
copy_sql = """
|
||||
INSERT INTO temp_profiles (
|
||||
user_id,
|
||||
displayname,
|
||||
avatar_url,
|
||||
full_user_id)
|
||||
SELECT user_id, displayname, avatar_url, '@' || user_id || ':' || ? FROM profiles
|
||||
"""
|
||||
cur.execute(copy_sql, (f"{hostname}",))
|
||||
|
||||
drop_sql = """
|
||||
DROP TABLE profiles
|
||||
"""
|
||||
cur.execute(drop_sql)
|
||||
|
||||
rename_sql = """
|
||||
ALTER TABLE temp_profiles RENAME to profiles
|
||||
"""
|
||||
cur.execute(rename_sql)
|
||||
@@ -0,0 +1,95 @@
|
||||
# Copyright 2023 The Matrix.org Foundation C.I.C
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.storage.database import LoggingTransaction
|
||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
|
||||
|
||||
|
||||
def run_upgrade(
|
||||
cur: LoggingTransaction,
|
||||
database_engine: BaseDatabaseEngine,
|
||||
config: HomeServerConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Part 3 of a multi-step migration to drop the column `user_id` and replace it with
|
||||
`full_user_id`. See the database schema docs for more information on the full
|
||||
migration steps.
|
||||
"""
|
||||
hostname = config.server.server_name
|
||||
|
||||
if isinstance(database_engine, PostgresEngine):
|
||||
# check if the constraint can be validated
|
||||
check_sql = """
|
||||
SELECT user_id from user_filters WHERE full_user_id IS NULL
|
||||
"""
|
||||
cur.execute(check_sql)
|
||||
res = cur.fetchall()
|
||||
|
||||
if res:
|
||||
# there are rows the background job missed, finish them here before we validate constraint
|
||||
process_rows_sql = """
|
||||
UPDATE user_filters
|
||||
SET full_user_id = '@' || user_id || ?
|
||||
WHERE user_id IN (
|
||||
SELECT user_id FROM user_filters WHERE full_user_id IS NULL
|
||||
)
|
||||
"""
|
||||
cur.execute(process_rows_sql, (f":{hostname}",))
|
||||
|
||||
# Now we can validate
|
||||
validate_sql = """
|
||||
ALTER TABLE user_filters VALIDATE CONSTRAINT full_user_id_not_null
|
||||
"""
|
||||
cur.execute(validate_sql)
|
||||
|
||||
else:
|
||||
cur.execute("DROP TABLE IF EXISTS temp_user_filters")
|
||||
create_sql = """
|
||||
CREATE TABLE temp_user_filters (
|
||||
full_user_id text NOT NULL,
|
||||
user_id text NOT NULL,
|
||||
filter_id bigint NOT NULL,
|
||||
filter_json bytea NOT NULL,
|
||||
UNIQUE (full_user_id),
|
||||
UNIQUE (user_id)
|
||||
)
|
||||
"""
|
||||
cur.execute(create_sql)
|
||||
|
||||
index_sql = """
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS user_filters_unique ON
|
||||
temp_user_filters (user_id, filter_id)
|
||||
"""
|
||||
cur.execute(index_sql)
|
||||
|
||||
copy_sql = """
|
||||
INSERT INTO temp_user_filters (
|
||||
user_id,
|
||||
filter_id,
|
||||
filter_json,
|
||||
full_user_id)
|
||||
SELECT user_id, filter_id, filter_json, '@' || user_id || ':' || ? FROM user_filters
|
||||
"""
|
||||
cur.execute(copy_sql, (f"{hostname}",))
|
||||
|
||||
drop_sql = """
|
||||
DROP TABLE user_filters
|
||||
"""
|
||||
cur.execute(drop_sql)
|
||||
|
||||
rename_sql = """
|
||||
ALTER TABLE temp_user_filters RENAME to user_filters
|
||||
"""
|
||||
cur.execute(rename_sql)
|
||||
@@ -131,6 +131,7 @@ class Requester:
|
||||
user: "UserID"
|
||||
access_token_id: Optional[int]
|
||||
is_guest: bool
|
||||
scope: Set[str]
|
||||
shadow_banned: bool
|
||||
device_id: Optional[str]
|
||||
app_service: Optional["ApplicationService"]
|
||||
@@ -147,6 +148,7 @@ class Requester:
|
||||
"user_id": self.user.to_string(),
|
||||
"access_token_id": self.access_token_id,
|
||||
"is_guest": self.is_guest,
|
||||
"scope": list(self.scope),
|
||||
"shadow_banned": self.shadow_banned,
|
||||
"device_id": self.device_id,
|
||||
"app_server_id": self.app_service.id if self.app_service else None,
|
||||
@@ -175,6 +177,7 @@ class Requester:
|
||||
user=UserID.from_string(input["user_id"]),
|
||||
access_token_id=input["access_token_id"],
|
||||
is_guest=input["is_guest"],
|
||||
scope=set(input["scope"]),
|
||||
shadow_banned=input["shadow_banned"],
|
||||
device_id=input["device_id"],
|
||||
app_service=appservice,
|
||||
@@ -186,6 +189,7 @@ def create_requester(
|
||||
user_id: Union[str, "UserID"],
|
||||
access_token_id: Optional[int] = None,
|
||||
is_guest: bool = False,
|
||||
scope: StrCollection = (),
|
||||
shadow_banned: bool = False,
|
||||
device_id: Optional[str] = None,
|
||||
app_service: Optional["ApplicationService"] = None,
|
||||
@@ -199,6 +203,7 @@ def create_requester(
|
||||
access_token_id: *ID* of the access token used for this
|
||||
request, or None if it came via the appservice API or similar
|
||||
is_guest: True if the user making this request is a guest user
|
||||
scope: the scope of the access token used for this request, if any
|
||||
shadow_banned: True if the user making this request is shadow-banned.
|
||||
device_id: device_id which was set at authentication time
|
||||
app_service: the AS requesting on behalf of the user
|
||||
@@ -215,10 +220,13 @@ def create_requester(
|
||||
if authenticated_entity is None:
|
||||
authenticated_entity = user_id.to_string()
|
||||
|
||||
scope = set(scope)
|
||||
|
||||
return Requester(
|
||||
user_id,
|
||||
access_token_id,
|
||||
is_guest,
|
||||
scope,
|
||||
shadow_banned,
|
||||
device_id,
|
||||
app_service,
|
||||
|
||||
@@ -76,7 +76,7 @@ def unwrapFirstError(failure: Failure) -> Failure:
|
||||
# the subFailure's value, which will do a better job of preserving stacktraces.
|
||||
# (actually, you probably want to use yieldable_gather_results anyway)
|
||||
failure.trap(defer.FirstError)
|
||||
return failure.value.subFailure # type: ignore[union-attr] # Issue in Twisted's annotations
|
||||
return failure.value.subFailure
|
||||
|
||||
|
||||
P = ParamSpec("P")
|
||||
@@ -178,7 +178,7 @@ def log_failure(
|
||||
"""
|
||||
|
||||
logger.error(
|
||||
msg, exc_info=(failure.type, failure.value, failure.getTracebackObject()) # type: ignore[arg-type]
|
||||
msg, exc_info=(failure.type, failure.value, failure.getTracebackObject())
|
||||
)
|
||||
|
||||
if not consumeErrors:
|
||||
|
||||
@@ -138,7 +138,7 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]):
|
||||
for observer in observers:
|
||||
# This is a little bit of magic to correctly propagate stack
|
||||
# traces when we `await` on one of the observer deferreds.
|
||||
f.value.__failure__ = f # type: ignore[union-attr]
|
||||
f.value.__failure__ = f
|
||||
try:
|
||||
observer.errback(f)
|
||||
except Exception as e:
|
||||
|
||||
@@ -93,10 +93,8 @@ VT = TypeVar("VT")
|
||||
# a general type var, distinct from either KT or VT
|
||||
T = TypeVar("T")
|
||||
|
||||
P = TypeVar("P")
|
||||
|
||||
|
||||
class _TimedListNode(ListNode[P]):
|
||||
class _TimedListNode(ListNode[T]):
|
||||
"""A `ListNode` that tracks last access time."""
|
||||
|
||||
__slots__ = ["last_access_ts_secs"]
|
||||
@@ -821,7 +819,7 @@ class AsyncLruCache(Generic[KT, VT]):
|
||||
utilize external cache systems that require await behaviour to be created.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs): # type: ignore
|
||||
def __init__(self, *args: Any, **kwargs: Any):
|
||||
self._lru_cache: LruCache[KT, VT] = LruCache(*args, **kwargs)
|
||||
|
||||
async def get(
|
||||
|
||||
@@ -41,7 +41,7 @@ from synapse.types.state import StateFilter
|
||||
from synapse.util import Clock
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
filtered_event_logger = logging.getLogger("synapse.visibility.filtered_event_debug")
|
||||
|
||||
VISIBILITY_PRIORITY = (
|
||||
HistoryVisibility.WORLD_READABLE,
|
||||
@@ -97,8 +97,8 @@ async def filter_events_for_client(
|
||||
events_before_filtering = events
|
||||
events = [e for e in events if not e.internal_metadata.is_soft_failed()]
|
||||
if len(events_before_filtering) != len(events):
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug(
|
||||
if filtered_event_logger.isEnabledFor(logging.DEBUG):
|
||||
filtered_event_logger.debug(
|
||||
"filter_events_for_client: Filtered out soft-failed events: Before=%s, After=%s",
|
||||
[event.event_id for event in events_before_filtering],
|
||||
[event.event_id for event in events],
|
||||
@@ -319,7 +319,7 @@ def _check_client_allowed_to_see_event(
|
||||
_check_filter_send_to_client(event, clock, retention_policy, sender_ignored)
|
||||
== _CheckFilter.DENIED
|
||||
):
|
||||
logger.debug(
|
||||
filtered_event_logger.debug(
|
||||
"_check_client_allowed_to_see_event(event=%s): Filtered out event because `_check_filter_send_to_client` returned `_CheckFilter.DENIED`",
|
||||
event.event_id,
|
||||
)
|
||||
@@ -341,7 +341,7 @@ def _check_client_allowed_to_see_event(
|
||||
)
|
||||
return event
|
||||
|
||||
logger.debug(
|
||||
filtered_event_logger.debug(
|
||||
"_check_client_allowed_to_see_event(event=%s): Filtered out event because it's an outlier",
|
||||
event.event_id,
|
||||
)
|
||||
@@ -367,7 +367,7 @@ def _check_client_allowed_to_see_event(
|
||||
|
||||
membership_result = _check_membership(user_id, event, visibility, state, is_peeking)
|
||||
if not membership_result.allowed:
|
||||
logger.debug(
|
||||
filtered_event_logger.debug(
|
||||
"_check_client_allowed_to_see_event(event=%s): Filtered out event because the user can't see the event because of their membership, membership_result.allowed=%s membership_result.joined=%s",
|
||||
event.event_id,
|
||||
membership_result.allowed,
|
||||
@@ -378,7 +378,7 @@ def _check_client_allowed_to_see_event(
|
||||
# If the sender has been erased and the user was not joined at the time, we
|
||||
# must only return the redacted form.
|
||||
if sender_erased and not membership_result.joined:
|
||||
logger.debug(
|
||||
filtered_event_logger.debug(
|
||||
"_check_client_allowed_to_see_event(event=%s): Returning pruned event because `sender_erased` and the user was not joined at the time",
|
||||
event.event_id,
|
||||
)
|
||||
|
||||
@@ -18,7 +18,7 @@ import pymacaroons
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.auth import Auth
|
||||
from synapse.api.auth.internal import InternalAuth
|
||||
from synapse.api.auth_blocking import AuthBlocking
|
||||
from synapse.api.constants import UserTypes
|
||||
from synapse.api.errors import (
|
||||
@@ -48,7 +48,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
# have been called by the HomeserverTestCase machinery.
|
||||
hs.datastores.main = self.store # type: ignore[union-attr]
|
||||
hs.get_auth_handler().store = self.store
|
||||
self.auth = Auth(hs)
|
||||
self.auth = InternalAuth(hs)
|
||||
|
||||
# AuthBlocking reads from the hs' config on initialization. We need to
|
||||
# modify its config instead of the hs'
|
||||
@@ -426,6 +426,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
access_token_id=None,
|
||||
device_id="FOOBAR",
|
||||
is_guest=False,
|
||||
scope=set(),
|
||||
shadow_banned=False,
|
||||
app_service=appservice,
|
||||
authenticated_entity="@appservice:server",
|
||||
@@ -456,6 +457,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
access_token_id=None,
|
||||
device_id="FOOBAR",
|
||||
is_guest=False,
|
||||
scope=set(),
|
||||
shadow_banned=False,
|
||||
app_service=appservice,
|
||||
authenticated_entity="@appservice:server",
|
||||
|
||||
@@ -35,7 +35,6 @@ from tests.events.test_utils import MockEvent
|
||||
|
||||
user_id = UserID.from_string("@test_user:test")
|
||||
user2_id = UserID.from_string("@test_user2:test")
|
||||
user_localpart = "test_user"
|
||||
|
||||
|
||||
class FilteringTestCase(unittest.HomeserverTestCase):
|
||||
@@ -449,9 +448,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||
]
|
||||
|
||||
user_filter = self.get_success(
|
||||
self.filtering.get_user_filter(
|
||||
user_localpart=user_localpart, filter_id=filter_id
|
||||
)
|
||||
self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
|
||||
)
|
||||
|
||||
results = self.get_success(user_filter.filter_presence(presence_states))
|
||||
@@ -479,9 +476,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||
]
|
||||
|
||||
user_filter = self.get_success(
|
||||
self.filtering.get_user_filter(
|
||||
user_localpart=user_localpart + "2", filter_id=filter_id
|
||||
)
|
||||
self.filtering.get_user_filter(user_id=user2_id, filter_id=filter_id)
|
||||
)
|
||||
|
||||
results = self.get_success(user_filter.filter_presence(presence_states))
|
||||
@@ -498,9 +493,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||
events = [event]
|
||||
|
||||
user_filter = self.get_success(
|
||||
self.filtering.get_user_filter(
|
||||
user_localpart=user_localpart, filter_id=filter_id
|
||||
)
|
||||
self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
|
||||
)
|
||||
|
||||
results = self.get_success(user_filter.filter_room_state(events=events))
|
||||
@@ -519,9 +512,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||
events = [event]
|
||||
|
||||
user_filter = self.get_success(
|
||||
self.filtering.get_user_filter(
|
||||
user_localpart=user_localpart, filter_id=filter_id
|
||||
)
|
||||
self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
|
||||
)
|
||||
|
||||
results = self.get_success(user_filter.filter_room_state(events))
|
||||
@@ -603,9 +594,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||
user_filter_json,
|
||||
(
|
||||
self.get_success(
|
||||
self.datastore.get_user_filter(
|
||||
user_localpart=user_localpart, filter_id=0
|
||||
)
|
||||
self.datastore.get_user_filter(user_id=user_id, filter_id=0)
|
||||
)
|
||||
),
|
||||
)
|
||||
@@ -620,9 +609,7 @@ class FilteringTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
filter = self.get_success(
|
||||
self.filtering.get_user_filter(
|
||||
user_localpart=user_localpart, filter_id=filter_id
|
||||
)
|
||||
self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id)
|
||||
)
|
||||
|
||||
self.assertEqual(filter.get_filter_json(), user_filter_json)
|
||||
|
||||
257
tests/config/test_oauth_delegation.py
Normal file
257
tests/config/test_oauth_delegation.py
Normal file
@@ -0,0 +1,257 @@
|
||||
# Copyright 2023 Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
from synapse.config import ConfigError
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.module_api import ModuleApi
|
||||
from synapse.types import JsonDict
|
||||
|
||||
from tests.server import get_clock, setup_test_homeserver
|
||||
from tests.unittest import TestCase, skip_unless
|
||||
from tests.utils import default_config
|
||||
|
||||
try:
|
||||
import authlib # noqa: F401
|
||||
|
||||
HAS_AUTHLIB = True
|
||||
except ImportError:
|
||||
HAS_AUTHLIB = False
|
||||
|
||||
|
||||
# These are a few constants that are used as config parameters in the tests.
|
||||
SERVER_NAME = "test"
|
||||
ISSUER = "https://issuer/"
|
||||
CLIENT_ID = "test-client-id"
|
||||
CLIENT_SECRET = "test-client-secret"
|
||||
BASE_URL = "https://synapse/"
|
||||
|
||||
|
||||
class CustomAuthModule:
|
||||
"""A module which registers a password auth provider."""
|
||||
|
||||
@staticmethod
|
||||
def parse_config(config: JsonDict) -> None:
|
||||
pass
|
||||
|
||||
def __init__(self, config: None, api: ModuleApi):
|
||||
api.register_password_auth_provider_callbacks(
|
||||
auth_checkers={("m.login.password", ("password",)): Mock()},
|
||||
)
|
||||
|
||||
|
||||
@skip_unless(HAS_AUTHLIB, "requires authlib")
|
||||
class MSC3861OAuthDelegation(TestCase):
|
||||
"""Test that the Homeserver fails to initialize if the config is invalid."""
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.config_dict: JsonDict = {
|
||||
**default_config("test"),
|
||||
"public_baseurl": BASE_URL,
|
||||
"enable_registration": False,
|
||||
"experimental_features": {
|
||||
"msc3861": {
|
||||
"enabled": True,
|
||||
"issuer": ISSUER,
|
||||
"client_id": CLIENT_ID,
|
||||
"client_auth_method": "client_secret_post",
|
||||
"client_secret": CLIENT_SECRET,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
def parse_config(self) -> HomeServerConfig:
|
||||
config = HomeServerConfig()
|
||||
config.parse_config_dict(self.config_dict, "", "")
|
||||
return config
|
||||
|
||||
def test_client_secret_post_works(self) -> None:
|
||||
self.config_dict["experimental_features"]["msc3861"].update(
|
||||
client_auth_method="client_secret_post",
|
||||
client_secret=CLIENT_SECRET,
|
||||
)
|
||||
|
||||
self.parse_config()
|
||||
|
||||
def test_client_secret_post_requires_client_secret(self) -> None:
|
||||
self.config_dict["experimental_features"]["msc3861"].update(
|
||||
client_auth_method="client_secret_post",
|
||||
client_secret=None,
|
||||
)
|
||||
|
||||
with self.assertRaises(ConfigError):
|
||||
self.parse_config()
|
||||
|
||||
def test_client_secret_basic_works(self) -> None:
|
||||
self.config_dict["experimental_features"]["msc3861"].update(
|
||||
client_auth_method="client_secret_basic",
|
||||
client_secret=CLIENT_SECRET,
|
||||
)
|
||||
|
||||
self.parse_config()
|
||||
|
||||
def test_client_secret_basic_requires_client_secret(self) -> None:
|
||||
self.config_dict["experimental_features"]["msc3861"].update(
|
||||
client_auth_method="client_secret_basic",
|
||||
client_secret=None,
|
||||
)
|
||||
|
||||
with self.assertRaises(ConfigError):
|
||||
self.parse_config()
|
||||
|
||||
def test_client_secret_jwt_works(self) -> None:
|
||||
self.config_dict["experimental_features"]["msc3861"].update(
|
||||
client_auth_method="client_secret_jwt",
|
||||
client_secret=CLIENT_SECRET,
|
||||
)
|
||||
|
||||
self.parse_config()
|
||||
|
||||
def test_client_secret_jwt_requires_client_secret(self) -> None:
|
||||
self.config_dict["experimental_features"]["msc3861"].update(
|
||||
client_auth_method="client_secret_jwt",
|
||||
client_secret=None,
|
||||
)
|
||||
|
||||
with self.assertRaises(ConfigError):
|
||||
self.parse_config()
|
||||
|
||||
def test_invalid_client_auth_method(self) -> None:
|
||||
self.config_dict["experimental_features"]["msc3861"].update(
|
||||
client_auth_method="invalid",
|
||||
)
|
||||
|
||||
with self.assertRaises(ConfigError):
|
||||
self.parse_config()
|
||||
|
||||
def test_private_key_jwt_requires_jwk(self) -> None:
|
||||
self.config_dict["experimental_features"]["msc3861"].update(
|
||||
client_auth_method="private_key_jwt",
|
||||
)
|
||||
|
||||
with self.assertRaises(ConfigError):
|
||||
self.parse_config()
|
||||
|
||||
def test_private_key_jwt_works(self) -> None:
|
||||
self.config_dict["experimental_features"]["msc3861"].update(
|
||||
client_auth_method="private_key_jwt",
|
||||
jwk={
|
||||
"p": "-frVdP_tZ-J_nIR6HNMDq1N7aunwm51nAqNnhqIyuA8ikx7LlQED1tt2LD3YEvYyW8nxE2V95HlCRZXQPMiRJBFOsbmYkzl2t-MpavTaObB_fct_JqcRtdXddg4-_ihdjRDwUOreq_dpWh6MIKsC3UyekfkHmeEJg5YpOTL15j8",
|
||||
"kty": "RSA",
|
||||
"q": "oFw-Enr_YozQB1ab-kawn4jY3yHi8B1nSmYT0s8oTCflrmps5BFJfCkHL5ij3iY15z0o2m0N-jjB1oSJ98O4RayEEYNQlHnTNTl0kRIWzpoqblHUIxVcahIpP_xTovBJzwi8XXoLGqHOOMA-r40LSyVgP2Ut8D9qBwV6_UfT0LU",
|
||||
"d": "WFkDPYo4b4LIS64D_QtQfGGuAObPvc3HFfp9VZXyq3SJR58XZRHE0jqtlEMNHhOTgbMYS3w8nxPQ_qVzY-5hs4fIanwvB64mAoOGl0qMHO65DTD_WsGFwzYClJPBVniavkLE2Hmpu8IGe6lGliN8vREC6_4t69liY-XcN_ECboVtC2behKkLOEASOIMuS7YcKAhTJFJwkl1dqDlliEn5A4u4xy7nuWQz3juB1OFdKlwGA5dfhDNglhoLIwNnkLsUPPFO-WB5ZNEW35xxHOToxj4bShvDuanVA6mJPtTKjz0XibjB36bj_nF_j7EtbE2PdGJ2KevAVgElR4lqS4ISgQ",
|
||||
"e": "AQAB",
|
||||
"kid": "test",
|
||||
"qi": "cPfNk8l8W5exVNNea4d7QZZ8Qr8LgHghypYAxz8PQh1fNa8Ya1SNUDVzC2iHHhszxxA0vB9C7jGze8dBrvnzWYF1XvQcqNIVVgHhD57R1Nm3dj2NoHIKe0Cu4bCUtP8xnZQUN4KX7y4IIcgRcBWG1hT6DEYZ4BxqicnBXXNXAUI",
|
||||
"dp": "dKlMHvslV1sMBQaKWpNb3gPq0B13TZhqr3-E2_8sPlvJ3fD8P4CmwwnOn50JDuhY3h9jY5L06sBwXjspYISVv8hX-ndMLkEeF3lrJeA5S70D8rgakfZcPIkffm3tlf1Ok3v5OzoxSv3-67Df4osMniyYwDUBCB5Oq1tTx77xpU8",
|
||||
"dq": "S4ooU1xNYYcjl9FcuJEEMqKsRrAXzzSKq6laPTwIp5dDwt2vXeAm1a4eDHXC-6rUSZGt5PbqVqzV4s-cjnJMI8YYkIdjNg4NSE1Ac_YpeDl3M3Colb5CQlU7yUB7xY2bt0NOOFp9UJZYJrOo09mFMGjy5eorsbitoZEbVqS3SuE",
|
||||
"n": "nJbYKqFwnURKimaviyDFrNLD3gaKR1JW343Qem25VeZxoMq1665RHVoO8n1oBm4ClZdjIiZiVdpyqzD5-Ow12YQgQEf1ZHP3CCcOQQhU57Rh5XvScTe5IxYVkEW32IW2mp_CJ6WfjYpfeL4azarVk8H3Vr59d1rSrKTVVinVdZer9YLQyC_rWAQNtHafPBMrf6RYiNGV9EiYn72wFIXlLlBYQ9Fx7bfe1PaL6qrQSsZP3_rSpuvVdLh1lqGeCLR0pyclA9uo5m2tMyCXuuGQLbA_QJm5xEc7zd-WFdux2eXF045oxnSZ_kgQt-pdN7AxGWOVvwoTf9am6mSkEdv6iw",
|
||||
},
|
||||
)
|
||||
self.parse_config()
|
||||
|
||||
def test_registration_cannot_be_enabled(self) -> None:
|
||||
self.config_dict["enable_registration"] = True
|
||||
with self.assertRaises(ConfigError):
|
||||
self.parse_config()
|
||||
|
||||
def test_password_config_cannot_be_enabled(self) -> None:
|
||||
self.config_dict["password_config"] = {"enabled": True}
|
||||
with self.assertRaises(ConfigError):
|
||||
self.parse_config()
|
||||
|
||||
def test_oidc_sso_cannot_be_enabled(self) -> None:
|
||||
self.config_dict["oidc_providers"] = [
|
||||
{
|
||||
"idp_id": "microsoft",
|
||||
"idp_name": "Microsoft",
|
||||
"issuer": "https://login.microsoftonline.com/<tenant id>/v2.0",
|
||||
"client_id": "<client id>",
|
||||
"client_secret": "<client secret>",
|
||||
"scopes": ["openid", "profile"],
|
||||
"authorization_endpoint": "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/authorize",
|
||||
"token_endpoint": "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/token",
|
||||
"userinfo_endpoint": "https://graph.microsoft.com/oidc/userinfo",
|
||||
}
|
||||
]
|
||||
|
||||
with self.assertRaises(ConfigError):
|
||||
self.parse_config()
|
||||
|
||||
def test_cas_sso_cannot_be_enabled(self) -> None:
|
||||
self.config_dict["cas_config"] = {
|
||||
"enabled": True,
|
||||
"server_url": "https://cas-server.com",
|
||||
"displayname_attribute": "name",
|
||||
"required_attributes": {"userGroup": "staff", "department": "None"},
|
||||
}
|
||||
|
||||
with self.assertRaises(ConfigError):
|
||||
self.parse_config()
|
||||
|
||||
def test_auth_providers_cannot_be_enabled(self) -> None:
|
||||
self.config_dict["modules"] = [
|
||||
{
|
||||
"module": f"{__name__}.{CustomAuthModule.__qualname__}",
|
||||
"config": {},
|
||||
}
|
||||
]
|
||||
|
||||
# This requires actually setting up an HS, as the module will be run on setup,
|
||||
# which should raise as the module tries to register an auth provider
|
||||
config = self.parse_config()
|
||||
reactor, clock = get_clock()
|
||||
with self.assertRaises(ConfigError):
|
||||
setup_test_homeserver(
|
||||
self.addCleanup, reactor=reactor, clock=clock, config=config
|
||||
)
|
||||
|
||||
def test_jwt_auth_cannot_be_enabled(self) -> None:
|
||||
self.config_dict["jwt_config"] = {
|
||||
"enabled": True,
|
||||
"secret": "my-secret-token",
|
||||
"algorithm": "HS256",
|
||||
}
|
||||
|
||||
with self.assertRaises(ConfigError):
|
||||
self.parse_config()
|
||||
|
||||
def test_login_via_existing_session_cannot_be_enabled(self) -> None:
|
||||
self.config_dict["login_via_existing_session"] = {"enabled": True}
|
||||
with self.assertRaises(ConfigError):
|
||||
self.parse_config()
|
||||
|
||||
def test_captcha_cannot_be_enabled(self) -> None:
|
||||
self.config_dict.update(
|
||||
enable_registration_captcha=True,
|
||||
recaptcha_public_key="test",
|
||||
recaptcha_private_key="test",
|
||||
)
|
||||
with self.assertRaises(ConfigError):
|
||||
self.parse_config()
|
||||
|
||||
def test_refreshable_tokens_cannot_be_enabled(self) -> None:
|
||||
self.config_dict.update(
|
||||
refresh_token_lifetime="24h",
|
||||
refreshable_access_token_lifetime="10m",
|
||||
nonrefreshable_access_token_lifetime="24h",
|
||||
)
|
||||
with self.assertRaises(ConfigError):
|
||||
self.parse_config()
|
||||
|
||||
def test_session_lifetime_cannot_be_set(self) -> None:
|
||||
self.config_dict["session_lifetime"] = "24h"
|
||||
with self.assertRaises(ConfigError):
|
||||
self.parse_config()
|
||||
664
tests/handlers/test_oauth_delegation.py
Normal file
664
tests/handlers/test_oauth_delegation.py
Normal file
@@ -0,0 +1,664 @@
|
||||
# Copyright 2022 Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Dict, Union
|
||||
from unittest.mock import ANY, Mock
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
from signedjson.key import (
|
||||
encode_verify_key_base64,
|
||||
generate_signing_key,
|
||||
get_verify_key,
|
||||
)
|
||||
from signedjson.sign import sign_json
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.errors import (
|
||||
AuthError,
|
||||
Codes,
|
||||
InvalidClientTokenError,
|
||||
OAuthInsufficientScopeError,
|
||||
SynapseError,
|
||||
)
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client import account, devices, keys, login, logout, register
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
|
||||
from tests.unittest import HomeserverTestCase, skip_unless
|
||||
from tests.utils import mock_getRawHeaders
|
||||
|
||||
try:
|
||||
import authlib # noqa: F401
|
||||
|
||||
HAS_AUTHLIB = True
|
||||
except ImportError:
|
||||
HAS_AUTHLIB = False
|
||||
|
||||
|
||||
# These are a few constants that are used as config parameters in the tests.
|
||||
SERVER_NAME = "test"
|
||||
ISSUER = "https://issuer/"
|
||||
CLIENT_ID = "test-client-id"
|
||||
CLIENT_SECRET = "test-client-secret"
|
||||
BASE_URL = "https://synapse/"
|
||||
SCOPES = ["openid"]
|
||||
|
||||
AUTHORIZATION_ENDPOINT = ISSUER + "authorize"
|
||||
TOKEN_ENDPOINT = ISSUER + "token"
|
||||
USERINFO_ENDPOINT = ISSUER + "userinfo"
|
||||
WELL_KNOWN = ISSUER + ".well-known/openid-configuration"
|
||||
JWKS_URI = ISSUER + ".well-known/jwks.json"
|
||||
INTROSPECTION_ENDPOINT = ISSUER + "introspect"
|
||||
|
||||
SYNAPSE_ADMIN_SCOPE = "urn:synapse:admin:*"
|
||||
MATRIX_USER_SCOPE = "urn:matrix:org.matrix.msc2967.client:api:*"
|
||||
MATRIX_GUEST_SCOPE = "urn:matrix:org.matrix.msc2967.client:api:guest"
|
||||
MATRIX_DEVICE_SCOPE_PREFIX = "urn:matrix:org.matrix.msc2967.client:device:"
|
||||
DEVICE = "AABBCCDD"
|
||||
MATRIX_DEVICE_SCOPE = MATRIX_DEVICE_SCOPE_PREFIX + DEVICE
|
||||
SUBJECT = "abc-def-ghi"
|
||||
USERNAME = "test-user"
|
||||
USER_ID = "@" + USERNAME + ":" + SERVER_NAME
|
||||
|
||||
|
||||
async def get_json(url: str) -> JsonDict:
|
||||
# Mock get_json calls to handle jwks & oidc discovery endpoints
|
||||
if url == WELL_KNOWN:
|
||||
# Minimal discovery document, as defined in OpenID.Discovery
|
||||
# https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
|
||||
return {
|
||||
"issuer": ISSUER,
|
||||
"authorization_endpoint": AUTHORIZATION_ENDPOINT,
|
||||
"token_endpoint": TOKEN_ENDPOINT,
|
||||
"jwks_uri": JWKS_URI,
|
||||
"userinfo_endpoint": USERINFO_ENDPOINT,
|
||||
"introspection_endpoint": INTROSPECTION_ENDPOINT,
|
||||
"response_types_supported": ["code"],
|
||||
"subject_types_supported": ["public"],
|
||||
"id_token_signing_alg_values_supported": ["RS256"],
|
||||
}
|
||||
elif url == JWKS_URI:
|
||||
return {"keys": []}
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
@skip_unless(HAS_AUTHLIB, "requires authlib")
|
||||
class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
servlets = [
|
||||
account.register_servlets,
|
||||
devices.register_servlets,
|
||||
keys.register_servlets,
|
||||
register.register_servlets,
|
||||
login.register_servlets,
|
||||
logout.register_servlets,
|
||||
admin.register_servlets,
|
||||
]
|
||||
|
||||
def default_config(self) -> Dict[str, Any]:
|
||||
config = super().default_config()
|
||||
config["public_baseurl"] = BASE_URL
|
||||
config["disable_registration"] = True
|
||||
config["experimental_features"] = {
|
||||
"msc3861": {
|
||||
"enabled": True,
|
||||
"issuer": ISSUER,
|
||||
"client_id": CLIENT_ID,
|
||||
"client_auth_method": "client_secret_post",
|
||||
"client_secret": CLIENT_SECRET,
|
||||
}
|
||||
}
|
||||
return config
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
self.http_client = Mock(spec=["get_json"])
|
||||
self.http_client.get_json.side_effect = get_json
|
||||
self.http_client.user_agent = b"Synapse Test"
|
||||
|
||||
hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
|
||||
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
return hs
|
||||
|
||||
def _assertParams(self) -> None:
|
||||
"""Assert that the request parameters are correct."""
|
||||
params = parse_qs(self.http_client.request.call_args[1]["data"].decode("utf-8"))
|
||||
self.assertEqual(params["token"], ["mockAccessToken"])
|
||||
self.assertEqual(params["client_id"], [CLIENT_ID])
|
||||
self.assertEqual(params["client_secret"], [CLIENT_SECRET])
|
||||
|
||||
def test_inactive_token(self) -> None:
|
||||
"""The handler should return a 403 where the token is inactive."""
|
||||
|
||||
self.http_client.request = simple_async_mock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={"active": False},
|
||||
)
|
||||
)
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
|
||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||
self.http_client.request.assert_called_once_with(
|
||||
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
||||
)
|
||||
self._assertParams()
|
||||
|
||||
def test_active_no_scope(self) -> None:
|
||||
"""The handler should return a 403 where no scope is given."""
|
||||
|
||||
self.http_client.request = simple_async_mock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={"active": True},
|
||||
)
|
||||
)
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
|
||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||
self.http_client.request.assert_called_once_with(
|
||||
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
||||
)
|
||||
self._assertParams()
|
||||
|
||||
def test_active_user_no_subject(self) -> None:
|
||||
"""The handler should return a 500 when no subject is present."""
|
||||
|
||||
self.http_client.request = simple_async_mock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={"active": True, "scope": " ".join([MATRIX_USER_SCOPE])},
|
||||
)
|
||||
)
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
|
||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||
self.http_client.request.assert_called_once_with(
|
||||
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
||||
)
|
||||
self._assertParams()
|
||||
|
||||
def test_active_no_user_scope(self) -> None:
|
||||
"""The handler should return a 500 when no subject is present."""
|
||||
|
||||
self.http_client.request = simple_async_mock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_DEVICE_SCOPE]),
|
||||
},
|
||||
)
|
||||
)
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
|
||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||
self.http_client.request.assert_called_once_with(
|
||||
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
||||
)
|
||||
self._assertParams()
|
||||
|
||||
def test_active_admin_not_user(self) -> None:
|
||||
"""The handler should raise when the scope has admin right but not user."""
|
||||
|
||||
self.http_client.request = simple_async_mock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([SYNAPSE_ADMIN_SCOPE]),
|
||||
"username": USERNAME,
|
||||
},
|
||||
)
|
||||
)
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
|
||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||
self.http_client.request.assert_called_once_with(
|
||||
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
||||
)
|
||||
self._assertParams()
|
||||
|
||||
def test_active_admin(self) -> None:
|
||||
"""The handler should return a requester with admin rights."""
|
||||
|
||||
self.http_client.request = simple_async_mock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([SYNAPSE_ADMIN_SCOPE, MATRIX_USER_SCOPE]),
|
||||
"username": USERNAME,
|
||||
},
|
||||
)
|
||||
)
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
requester = self.get_success(self.auth.get_user_by_req(request))
|
||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||
self.http_client.request.assert_called_once_with(
|
||||
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
||||
)
|
||||
self._assertParams()
|
||||
self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
|
||||
self.assertEqual(requester.is_guest, False)
|
||||
self.assertEqual(requester.device_id, None)
|
||||
self.assertEqual(
|
||||
get_awaitable_result(self.auth.is_server_admin(requester)), True
|
||||
)
|
||||
|
||||
def test_active_admin_highest_privilege(self) -> None:
|
||||
"""The handler should resolve to the most permissive scope."""
|
||||
|
||||
self.http_client.request = simple_async_mock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join(
|
||||
[SYNAPSE_ADMIN_SCOPE, MATRIX_USER_SCOPE, MATRIX_GUEST_SCOPE]
|
||||
),
|
||||
"username": USERNAME,
|
||||
},
|
||||
)
|
||||
)
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
requester = self.get_success(self.auth.get_user_by_req(request))
|
||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||
self.http_client.request.assert_called_once_with(
|
||||
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
||||
)
|
||||
self._assertParams()
|
||||
self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
|
||||
self.assertEqual(requester.is_guest, False)
|
||||
self.assertEqual(requester.device_id, None)
|
||||
self.assertEqual(
|
||||
get_awaitable_result(self.auth.is_server_admin(requester)), True
|
||||
)
|
||||
|
||||
def test_active_user(self) -> None:
|
||||
"""The handler should return a requester with normal user rights."""
|
||||
|
||||
self.http_client.request = simple_async_mock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_USER_SCOPE]),
|
||||
"username": USERNAME,
|
||||
},
|
||||
)
|
||||
)
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
requester = self.get_success(self.auth.get_user_by_req(request))
|
||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||
self.http_client.request.assert_called_once_with(
|
||||
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
||||
)
|
||||
self._assertParams()
|
||||
self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
|
||||
self.assertEqual(requester.is_guest, False)
|
||||
self.assertEqual(requester.device_id, None)
|
||||
self.assertEqual(
|
||||
get_awaitable_result(self.auth.is_server_admin(requester)), False
|
||||
)
|
||||
|
||||
def test_active_user_with_device(self) -> None:
|
||||
"""The handler should return a requester with normal user rights and a device ID."""
|
||||
|
||||
self.http_client.request = simple_async_mock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_USER_SCOPE, MATRIX_DEVICE_SCOPE]),
|
||||
"username": USERNAME,
|
||||
},
|
||||
)
|
||||
)
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
requester = self.get_success(self.auth.get_user_by_req(request))
|
||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||
self.http_client.request.assert_called_once_with(
|
||||
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
||||
)
|
||||
self._assertParams()
|
||||
self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
|
||||
self.assertEqual(requester.is_guest, False)
|
||||
self.assertEqual(
|
||||
get_awaitable_result(self.auth.is_server_admin(requester)), False
|
||||
)
|
||||
self.assertEqual(requester.device_id, DEVICE)
|
||||
|
||||
def test_multiple_devices(self) -> None:
|
||||
"""The handler should raise an error if multiple devices are found in the scope."""
|
||||
|
||||
self.http_client.request = simple_async_mock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join(
|
||||
[
|
||||
MATRIX_USER_SCOPE,
|
||||
f"{MATRIX_DEVICE_SCOPE_PREFIX}AABBCC",
|
||||
f"{MATRIX_DEVICE_SCOPE_PREFIX}DDEEFF",
|
||||
]
|
||||
),
|
||||
"username": USERNAME,
|
||||
},
|
||||
)
|
||||
)
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
self.get_failure(self.auth.get_user_by_req(request), AuthError)
|
||||
|
||||
def test_active_guest_not_allowed(self) -> None:
|
||||
"""The handler should return an insufficient scope error."""
|
||||
|
||||
self.http_client.request = simple_async_mock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_GUEST_SCOPE, MATRIX_DEVICE_SCOPE]),
|
||||
"username": USERNAME,
|
||||
},
|
||||
)
|
||||
)
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
error = self.get_failure(
|
||||
self.auth.get_user_by_req(request), OAuthInsufficientScopeError
|
||||
)
|
||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||
self.http_client.request.assert_called_once_with(
|
||||
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
||||
)
|
||||
self._assertParams()
|
||||
self.assertEqual(
|
||||
getattr(error.value, "headers", {})["WWW-Authenticate"],
|
||||
'Bearer error="insufficient_scope", scope="urn:matrix:org.matrix.msc2967.client:api:*"',
|
||||
)
|
||||
|
||||
def test_active_guest_allowed(self) -> None:
|
||||
"""The handler should return a requester with guest user rights and a device ID."""
|
||||
|
||||
self.http_client.request = simple_async_mock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_GUEST_SCOPE, MATRIX_DEVICE_SCOPE]),
|
||||
"username": USERNAME,
|
||||
},
|
||||
)
|
||||
)
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
requester = self.get_success(
|
||||
self.auth.get_user_by_req(request, allow_guest=True)
|
||||
)
|
||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||
self.http_client.request.assert_called_once_with(
|
||||
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
||||
)
|
||||
self._assertParams()
|
||||
self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
|
||||
self.assertEqual(requester.is_guest, True)
|
||||
self.assertEqual(
|
||||
get_awaitable_result(self.auth.is_server_admin(requester)), False
|
||||
)
|
||||
self.assertEqual(requester.device_id, DEVICE)
|
||||
|
||||
def test_unavailable_introspection_endpoint(self) -> None:
|
||||
"""The handler should return an internal server error."""
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
|
||||
# The introspection endpoint is returning an error.
|
||||
self.http_client.request = simple_async_mock(
|
||||
return_value=FakeResponse(code=500, body=b"Internal Server Error")
|
||||
)
|
||||
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
|
||||
self.assertEqual(error.value.code, 503)
|
||||
|
||||
# The introspection endpoint request fails.
|
||||
self.http_client.request = simple_async_mock(raises=Exception())
|
||||
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
|
||||
self.assertEqual(error.value.code, 503)
|
||||
|
||||
# The introspection endpoint does not return a JSON object.
|
||||
self.http_client.request = simple_async_mock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200, payload=["this is an array", "not an object"]
|
||||
)
|
||||
)
|
||||
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
|
||||
self.assertEqual(error.value.code, 503)
|
||||
|
||||
# The introspection endpoint does not return valid JSON.
|
||||
self.http_client.request = simple_async_mock(
|
||||
return_value=FakeResponse(code=200, body=b"this is not valid JSON")
|
||||
)
|
||||
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
|
||||
self.assertEqual(error.value.code, 503)
|
||||
|
||||
def make_device_keys(self, user_id: str, device_id: str) -> JsonDict:
|
||||
# We only generate a master key to simplify the test.
|
||||
master_signing_key = generate_signing_key(device_id)
|
||||
master_verify_key = encode_verify_key_base64(get_verify_key(master_signing_key))
|
||||
|
||||
return {
|
||||
"master_key": sign_json(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"usage": ["master"],
|
||||
"keys": {"ed25519:" + master_verify_key: master_verify_key},
|
||||
},
|
||||
user_id,
|
||||
master_signing_key,
|
||||
),
|
||||
}
|
||||
|
||||
def test_cross_signing(self) -> None:
|
||||
"""Try uploading device keys with OAuth delegation enabled."""
|
||||
|
||||
self.http_client.request = simple_async_mock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_USER_SCOPE, MATRIX_DEVICE_SCOPE]),
|
||||
"username": USERNAME,
|
||||
},
|
||||
)
|
||||
)
|
||||
keys_upload_body = self.make_device_keys(USER_ID, DEVICE)
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_matrix/client/v3/keys/device_signing/upload",
|
||||
keys_upload_body,
|
||||
access_token="mockAccessToken",
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_matrix/client/v3/keys/device_signing/upload",
|
||||
keys_upload_body,
|
||||
access_token="mockAccessToken",
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, HTTPStatus.NOT_IMPLEMENTED, channel.json_body)
|
||||
|
||||
def expect_unauthorized(
|
||||
self, method: str, path: str, content: Union[bytes, str, JsonDict] = ""
|
||||
) -> None:
|
||||
channel = self.make_request(method, path, content, shorthand=False)
|
||||
|
||||
self.assertEqual(channel.code, 401, channel.json_body)
|
||||
|
||||
def expect_unrecognized(
|
||||
self, method: str, path: str, content: Union[bytes, str, JsonDict] = ""
|
||||
) -> None:
|
||||
channel = self.make_request(method, path, content)
|
||||
|
||||
self.assertEqual(channel.code, 404, channel.json_body)
|
||||
self.assertEqual(
|
||||
channel.json_body["errcode"], Codes.UNRECOGNIZED, channel.json_body
|
||||
)
|
||||
|
||||
def test_uia_endpoints(self) -> None:
|
||||
"""Test that endpoints that were removed in MSC2964 are no longer available."""
|
||||
|
||||
# This is just an endpoint that should remain visible (but requires auth):
|
||||
self.expect_unauthorized("GET", "/_matrix/client/v3/devices")
|
||||
|
||||
# This remains usable, but will require a uia scope:
|
||||
self.expect_unauthorized(
|
||||
"POST", "/_matrix/client/v3/keys/device_signing/upload"
|
||||
)
|
||||
|
||||
def test_3pid_endpoints(self) -> None:
|
||||
"""Test that 3pid account management endpoints that were removed in MSC2964 are no longer available."""
|
||||
|
||||
# Remains and requires auth:
|
||||
self.expect_unauthorized("GET", "/_matrix/client/v3/account/3pid")
|
||||
self.expect_unauthorized(
|
||||
"POST",
|
||||
"/_matrix/client/v3/account/3pid/bind",
|
||||
{
|
||||
"client_secret": "foo",
|
||||
"id_access_token": "bar",
|
||||
"id_server": "foo",
|
||||
"sid": "bar",
|
||||
},
|
||||
)
|
||||
self.expect_unauthorized("POST", "/_matrix/client/v3/account/3pid/unbind", {})
|
||||
|
||||
# These are gone:
|
||||
self.expect_unrecognized(
|
||||
"POST", "/_matrix/client/v3/account/3pid"
|
||||
) # deprecated
|
||||
self.expect_unrecognized("POST", "/_matrix/client/v3/account/3pid/add")
|
||||
self.expect_unrecognized("POST", "/_matrix/client/v3/account/3pid/delete")
|
||||
self.expect_unrecognized(
|
||||
"POST", "/_matrix/client/v3/account/3pid/email/requestToken"
|
||||
)
|
||||
self.expect_unrecognized(
|
||||
"POST", "/_matrix/client/v3/account/3pid/msisdn/requestToken"
|
||||
)
|
||||
|
||||
def test_account_management_endpoints_removed(self) -> None:
|
||||
"""Test that account management endpoints that were removed in MSC2964 are no longer available."""
|
||||
self.expect_unrecognized("POST", "/_matrix/client/v3/account/deactivate")
|
||||
self.expect_unrecognized("POST", "/_matrix/client/v3/account/password")
|
||||
self.expect_unrecognized(
|
||||
"POST", "/_matrix/client/v3/account/password/email/requestToken"
|
||||
)
|
||||
self.expect_unrecognized(
|
||||
"POST", "/_matrix/client/v3/account/password/msisdn/requestToken"
|
||||
)
|
||||
|
||||
def test_registration_endpoints_removed(self) -> None:
|
||||
"""Test that registration endpoints that were removed in MSC2964 are no longer available."""
|
||||
self.expect_unrecognized(
|
||||
"GET", "/_matrix/client/v1/register/m.login.registration_token/validity"
|
||||
)
|
||||
# This is still available for AS registrations
|
||||
# self.expect_unrecognized("POST", "/_matrix/client/v3/register")
|
||||
self.expect_unrecognized("GET", "/_matrix/client/v3/register/available")
|
||||
self.expect_unrecognized(
|
||||
"POST", "/_matrix/client/v3/register/email/requestToken"
|
||||
)
|
||||
self.expect_unrecognized(
|
||||
"POST", "/_matrix/client/v3/register/msisdn/requestToken"
|
||||
)
|
||||
|
||||
def test_session_management_endpoints_removed(self) -> None:
|
||||
"""Test that session management endpoints that were removed in MSC2964 are no longer available."""
|
||||
self.expect_unrecognized("GET", "/_matrix/client/v3/login")
|
||||
self.expect_unrecognized("POST", "/_matrix/client/v3/login")
|
||||
self.expect_unrecognized("GET", "/_matrix/client/v3/login/sso/redirect")
|
||||
self.expect_unrecognized("POST", "/_matrix/client/v3/logout")
|
||||
self.expect_unrecognized("POST", "/_matrix/client/v3/logout/all")
|
||||
self.expect_unrecognized("POST", "/_matrix/client/v3/refresh")
|
||||
self.expect_unrecognized("GET", "/_matrix/static/client/login")
|
||||
|
||||
def test_device_management_endpoints_removed(self) -> None:
|
||||
"""Test that device management endpoints that were removed in MSC2964 are no longer available."""
|
||||
self.expect_unrecognized("POST", "/_matrix/client/v3/delete_devices")
|
||||
self.expect_unrecognized("DELETE", "/_matrix/client/v3/devices/{DEVICE}")
|
||||
|
||||
def test_openid_endpoints_removed(self) -> None:
|
||||
"""Test that OpenID id_token endpoints that were removed in MSC2964 are no longer available."""
|
||||
self.expect_unrecognized(
|
||||
"POST", "/_matrix/client/v3/user/{USERNAME}/openid/request_token"
|
||||
)
|
||||
|
||||
def test_admin_api_endpoints_removed(self) -> None:
|
||||
"""Test that admin API endpoints that were removed in MSC2964 are no longer available."""
|
||||
self.expect_unrecognized("GET", "/_synapse/admin/v1/registration_tokens")
|
||||
self.expect_unrecognized("POST", "/_synapse/admin/v1/registration_tokens/new")
|
||||
self.expect_unrecognized("GET", "/_synapse/admin/v1/registration_tokens/abcd")
|
||||
self.expect_unrecognized("PUT", "/_synapse/admin/v1/registration_tokens/abcd")
|
||||
self.expect_unrecognized(
|
||||
"DELETE", "/_synapse/admin/v1/registration_tokens/abcd"
|
||||
)
|
||||
self.expect_unrecognized("POST", "/_synapse/admin/v1/reset_password/foo")
|
||||
self.expect_unrecognized("POST", "/_synapse/admin/v1/users/foo/login")
|
||||
self.expect_unrecognized("GET", "/_synapse/admin/v1/register")
|
||||
self.expect_unrecognized("POST", "/_synapse/admin/v1/register")
|
||||
self.expect_unrecognized("GET", "/_synapse/admin/v1/users/foo/admin")
|
||||
self.expect_unrecognized("PUT", "/_synapse/admin/v1/users/foo/admin")
|
||||
self.expect_unrecognized("POST", "/_synapse/admin/v1/account_validity/validity")
|
||||
@@ -80,11 +80,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
(
|
||||
self.get_success(
|
||||
self.store.get_profile_displayname(self.frank.localpart)
|
||||
)
|
||||
),
|
||||
(self.get_success(self.store.get_profile_displayname(self.frank))),
|
||||
"Frank Jr.",
|
||||
)
|
||||
|
||||
@@ -96,11 +92,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
(
|
||||
self.get_success(
|
||||
self.store.get_profile_displayname(self.frank.localpart)
|
||||
)
|
||||
),
|
||||
(self.get_success(self.store.get_profile_displayname(self.frank))),
|
||||
"Frank",
|
||||
)
|
||||
|
||||
@@ -112,7 +104,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
self.assertIsNone(
|
||||
self.get_success(self.store.get_profile_displayname(self.frank.localpart))
|
||||
self.get_success(self.store.get_profile_displayname(self.frank))
|
||||
)
|
||||
|
||||
def test_set_my_name_if_disabled(self) -> None:
|
||||
@@ -122,11 +114,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
self.get_success(self.store.set_profile_displayname(self.frank, "Frank"))
|
||||
|
||||
self.assertEqual(
|
||||
(
|
||||
self.get_success(
|
||||
self.store.get_profile_displayname(self.frank.localpart)
|
||||
)
|
||||
),
|
||||
(self.get_success(self.store.get_profile_displayname(self.frank))),
|
||||
"Frank",
|
||||
)
|
||||
|
||||
@@ -201,7 +189,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
|
||||
(self.get_success(self.store.get_profile_avatar_url(self.frank))),
|
||||
"http://my.server/pic.gif",
|
||||
)
|
||||
|
||||
@@ -215,7 +203,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
|
||||
(self.get_success(self.store.get_profile_avatar_url(self.frank))),
|
||||
"http://my.server/me.png",
|
||||
)
|
||||
|
||||
@@ -229,7 +217,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
self.assertIsNone(
|
||||
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
|
||||
(self.get_success(self.store.get_profile_avatar_url(self.frank))),
|
||||
)
|
||||
|
||||
def test_set_my_avatar_if_disabled(self) -> None:
|
||||
@@ -241,7 +229,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
(self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
|
||||
(self.get_success(self.store.get_profile_avatar_url(self.frank))),
|
||||
"http://my.server/me.png",
|
||||
)
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ from unittest.mock import Mock
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.auth import Auth
|
||||
from synapse.api.auth.internal import InternalAuth
|
||||
from synapse.api.constants import UserTypes
|
||||
from synapse.api.errors import (
|
||||
CodeMessageException,
|
||||
@@ -683,7 +683,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [token.encode("ascii")]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
auth = Auth(self.hs)
|
||||
auth = InternalAuth(self.hs)
|
||||
requester = self.get_success(auth.get_user_by_req(request))
|
||||
|
||||
self.assertTrue(requester.shadow_banned)
|
||||
|
||||
@@ -24,7 +24,7 @@ from tests import unittest
|
||||
try:
|
||||
import lxml
|
||||
except ImportError:
|
||||
lxml = None
|
||||
lxml = None # type: ignore[assignment]
|
||||
|
||||
|
||||
class SummarizeTestCase(unittest.TestCase):
|
||||
@@ -160,6 +160,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
|
||||
"""
|
||||
|
||||
tree = decode_body(html, "http://example.com/test.html")
|
||||
assert tree is not None
|
||||
og = parse_html_to_open_graph(tree)
|
||||
|
||||
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
|
||||
@@ -176,6 +177,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
|
||||
"""
|
||||
|
||||
tree = decode_body(html, "http://example.com/test.html")
|
||||
assert tree is not None
|
||||
og = parse_html_to_open_graph(tree)
|
||||
|
||||
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
|
||||
@@ -195,6 +197,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
|
||||
"""
|
||||
|
||||
tree = decode_body(html, "http://example.com/test.html")
|
||||
assert tree is not None
|
||||
og = parse_html_to_open_graph(tree)
|
||||
|
||||
self.assertEqual(
|
||||
@@ -217,6 +220,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
|
||||
"""
|
||||
|
||||
tree = decode_body(html, "http://example.com/test.html")
|
||||
assert tree is not None
|
||||
og = parse_html_to_open_graph(tree)
|
||||
|
||||
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
|
||||
@@ -231,6 +235,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
|
||||
"""
|
||||
|
||||
tree = decode_body(html, "http://example.com/test.html")
|
||||
assert tree is not None
|
||||
og = parse_html_to_open_graph(tree)
|
||||
|
||||
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
|
||||
@@ -246,6 +251,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
|
||||
"""
|
||||
|
||||
tree = decode_body(html, "http://example.com/test.html")
|
||||
assert tree is not None
|
||||
og = parse_html_to_open_graph(tree)
|
||||
|
||||
self.assertEqual(og, {"og:title": "Title", "og:description": "Title"})
|
||||
@@ -261,6 +267,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
|
||||
"""
|
||||
|
||||
tree = decode_body(html, "http://example.com/test.html")
|
||||
assert tree is not None
|
||||
og = parse_html_to_open_graph(tree)
|
||||
|
||||
self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
|
||||
@@ -281,6 +288,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
|
||||
"""
|
||||
|
||||
tree = decode_body(html, "http://example.com/test.html")
|
||||
assert tree is not None
|
||||
og = parse_html_to_open_graph(tree)
|
||||
|
||||
self.assertEqual(og, {"og:title": "Title", "og:description": "Finally!"})
|
||||
@@ -296,6 +304,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
|
||||
"""
|
||||
|
||||
tree = decode_body(html, "http://example.com/test.html")
|
||||
assert tree is not None
|
||||
og = parse_html_to_open_graph(tree)
|
||||
|
||||
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
|
||||
@@ -324,6 +333,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
|
||||
<head><title>Foo</title></head><body>Some text.</body></html>
|
||||
""".strip()
|
||||
tree = decode_body(html, "http://example.com/test.html")
|
||||
assert tree is not None
|
||||
og = parse_html_to_open_graph(tree)
|
||||
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
|
||||
|
||||
@@ -338,6 +348,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
|
||||
</html>
|
||||
"""
|
||||
tree = decode_body(html, "http://example.com/test.html", "invalid-encoding")
|
||||
assert tree is not None
|
||||
og = parse_html_to_open_graph(tree)
|
||||
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
|
||||
|
||||
@@ -353,6 +364,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
|
||||
</html>
|
||||
"""
|
||||
tree = decode_body(html, "http://example.com/test.html")
|
||||
assert tree is not None
|
||||
og = parse_html_to_open_graph(tree)
|
||||
self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})
|
||||
|
||||
@@ -367,6 +379,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
|
||||
</html>
|
||||
"""
|
||||
tree = decode_body(html, "http://example.com/test.html")
|
||||
assert tree is not None
|
||||
og = parse_html_to_open_graph(tree)
|
||||
self.assertEqual(og, {"og:title": "ó", "og:description": "Some text."})
|
||||
|
||||
@@ -380,6 +393,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
|
||||
</html>
|
||||
"""
|
||||
tree = decode_body(html, "http://example.com/test.html")
|
||||
assert tree is not None
|
||||
og = parse_html_to_open_graph(tree)
|
||||
self.assertEqual(
|
||||
og,
|
||||
@@ -401,6 +415,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
|
||||
</html>
|
||||
"""
|
||||
tree = decode_body(html, "http://example.com/test.html")
|
||||
assert tree is not None
|
||||
og = parse_html_to_open_graph(tree)
|
||||
self.assertEqual(
|
||||
og,
|
||||
@@ -419,6 +434,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase):
|
||||
with a cheeky SVG</svg></u> and <strong>some</strong> tail text</b></a>
|
||||
"""
|
||||
tree = decode_body(html, "http://example.com/test.html")
|
||||
assert tree is not None
|
||||
og = parse_html_to_open_graph(tree)
|
||||
self.assertEqual(
|
||||
og,
|
||||
|
||||
@@ -28,7 +28,7 @@ from tests.unittest import HomeserverTestCase
|
||||
try:
|
||||
import lxml
|
||||
except ImportError:
|
||||
lxml = None
|
||||
lxml = None # type: ignore[assignment]
|
||||
|
||||
|
||||
class OEmbedTests(HomeserverTestCase):
|
||||
|
||||
@@ -24,7 +24,7 @@ from tests.unittest import override_config
|
||||
try:
|
||||
import lxml
|
||||
except ImportError:
|
||||
lxml = None
|
||||
lxml = None # type: ignore[assignment]
|
||||
|
||||
|
||||
class URLPreviewTests(unittest.HomeserverTestCase):
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user