Speed up MAS token introspection (#18357)

We do this by shoving it into Rust. We believe our python http client is
a bit slow.

Also bumps minimum rust version to 1.81.0, released last September (over
six months ago)

To allow for async Rust, includes some adapters between Tokio in Rust
and the Twisted reactor in Python.
This commit is contained in:
Erik Johnston
2025-06-16 10:41:35 -05:00
committed by GitHub
parent df04931f0b
commit f500c7d982
11 changed files with 1828 additions and 301 deletions

View File

@@ -85,7 +85,7 @@ jobs:
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Install Rust
uses: dtolnay/rust-toolchain@e05ebb0e73db581a4877c6ce762e29fe1e0b5073 # 1.66.0
uses: dtolnay/rust-toolchain@c1678930c21fb233e4987c4ae12158f9125e5762 # 1.81.0
- uses: Swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2.7.8
- uses: matrix-org/setup-python-poetry@5bbf6603c5c930615ec8a29f1b5d7d258d905aa4 # v2.0.0
with:
@@ -149,7 +149,7 @@ jobs:
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Install Rust
uses: dtolnay/rust-toolchain@e05ebb0e73db581a4877c6ce762e29fe1e0b5073 # 1.66.0
uses: dtolnay/rust-toolchain@c1678930c21fb233e4987c4ae12158f9125e5762 # 1.81.0
- uses: Swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2.7.8
- name: Setup Poetry
@@ -210,7 +210,7 @@ jobs:
with:
ref: ${{ github.event.pull_request.head.sha }}
- name: Install Rust
uses: dtolnay/rust-toolchain@e05ebb0e73db581a4877c6ce762e29fe1e0b5073 # 1.66.0
uses: dtolnay/rust-toolchain@c1678930c21fb233e4987c4ae12158f9125e5762 # 1.81.0
- uses: Swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2.7.8
- uses: matrix-org/setup-python-poetry@5bbf6603c5c930615ec8a29f1b5d7d258d905aa4 # v2.0.0
with:
@@ -227,7 +227,7 @@ jobs:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Install Rust
uses: dtolnay/rust-toolchain@e05ebb0e73db581a4877c6ce762e29fe1e0b5073 # 1.66.0
uses: dtolnay/rust-toolchain@0d72692bcfbf448b1e2afa01a67f71b455a9dcec # 1.86.0
with:
components: clippy
- uses: Swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2.7.8
@@ -247,7 +247,7 @@ jobs:
- name: Install Rust
uses: dtolnay/rust-toolchain@56f84321dbccf38fb67ce29ab63e4754056677e0 # master (rust 1.85.1)
with:
toolchain: nightly-2022-12-01
toolchain: nightly-2025-04-23
components: clippy
- uses: Swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2.7.8
@@ -265,7 +265,7 @@ jobs:
uses: dtolnay/rust-toolchain@56f84321dbccf38fb67ce29ab63e4754056677e0 # master (rust 1.85.1)
with:
# We use nightly so that it correctly groups together imports
toolchain: nightly-2022-12-01
toolchain: nightly-2025-04-23
components: rustfmt
- uses: Swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2.7.8
@@ -362,7 +362,7 @@ jobs:
postgres:${{ matrix.job.postgres-version }}
- name: Install Rust
uses: dtolnay/rust-toolchain@e05ebb0e73db581a4877c6ce762e29fe1e0b5073 # 1.66.0
uses: dtolnay/rust-toolchain@c1678930c21fb233e4987c4ae12158f9125e5762 # 1.81.0
- uses: Swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2.7.8
- uses: matrix-org/setup-python-poetry@5bbf6603c5c930615ec8a29f1b5d7d258d905aa4 # v2.0.0
@@ -404,7 +404,7 @@ jobs:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Install Rust
uses: dtolnay/rust-toolchain@e05ebb0e73db581a4877c6ce762e29fe1e0b5073 # 1.66.0
uses: dtolnay/rust-toolchain@c1678930c21fb233e4987c4ae12158f9125e5762 # 1.81.0
- uses: Swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2.7.8
# There aren't wheels for some of the older deps, so we need to install
@@ -519,7 +519,7 @@ jobs:
run: cat sytest-blacklist .ci/worker-blacklist > synapse-blacklist-with-workers
- name: Install Rust
uses: dtolnay/rust-toolchain@e05ebb0e73db581a4877c6ce762e29fe1e0b5073 # 1.66.0
uses: dtolnay/rust-toolchain@c1678930c21fb233e4987c4ae12158f9125e5762 # 1.81.0
- uses: Swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2.7.8
- name: Run SyTest
@@ -663,7 +663,7 @@ jobs:
path: synapse
- name: Install Rust
uses: dtolnay/rust-toolchain@e05ebb0e73db581a4877c6ce762e29fe1e0b5073 # 1.66.0
uses: dtolnay/rust-toolchain@c1678930c21fb233e4987c4ae12158f9125e5762 # 1.81.0
- uses: Swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2.7.8
- name: Prepare Complement's Prerequisites
@@ -695,7 +695,7 @@ jobs:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Install Rust
uses: dtolnay/rust-toolchain@e05ebb0e73db581a4877c6ce762e29fe1e0b5073 # 1.66.0
uses: dtolnay/rust-toolchain@c1678930c21fb233e4987c4ae12158f9125e5762 # 1.81.0
- uses: Swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2.7.8
- run: cargo test

1333
Cargo.lock generated

File diff suppressed because it is too large Load Diff

1
changelog.d/18357.misc Normal file
View File

@@ -0,0 +1 @@
Increase performance of introspecting access tokens when using delegated auth.

View File

@@ -7,7 +7,7 @@ name = "synapse"
version = "0.1.0"
edition = "2021"
rust-version = "1.66.0"
rust-version = "1.81.0"
[lib]
name = "synapse"
@@ -36,13 +36,21 @@ pyo3 = { version = "0.24.2", features = [
"abi3",
"abi3-py39",
] }
pyo3-log = "0.12.0"
pyo3-log = "0.12.3"
pythonize = "0.24.0"
regex = "1.6.0"
sha2 = "0.10.8"
serde = { version = "1.0.144", features = ["derive"] }
serde_json = "1.0.85"
ulid = "1.1.2"
reqwest = { version = "0.12.15", default-features = false, features = [
"http2",
"stream",
"rustls-tls-native-roots",
] }
http-body-util = "0.1.3"
futures = "0.3.31"
tokio = { version = "1.44.2", features = ["rt", "rt-multi-thread"] }
[features]
extension-module = ["pyo3/extension-module"]

View File

@@ -58,3 +58,15 @@ impl NotFoundError {
NotFoundError::new_err(())
}
}
import_exception!(synapse.api.errors, HttpResponseException);
impl HttpResponseException {
pub fn new(status: StatusCode, bytes: Vec<u8>) -> pyo3::PyErr {
HttpResponseException::new_err((
status.as_u16(),
status.canonical_reason().unwrap_or_default(),
bytes,
))
}
}

218
rust/src/http_client.rs Normal file
View File

@@ -0,0 +1,218 @@
/*
* This file is licensed under the Affero General Public License (AGPL) version 3.
*
* Copyright (C) 2025 New Vector, Ltd
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* See the GNU Affero General Public License for more details:
* <https://www.gnu.org/licenses/agpl-3.0.html>.
*/
use std::{collections::HashMap, future::Future, panic::AssertUnwindSafe, sync::LazyLock};
use anyhow::Context;
use futures::{FutureExt, TryStreamExt};
use pyo3::{exceptions::PyException, prelude::*, types::PyString};
use reqwest::RequestBuilder;
use tokio::runtime::Runtime;
use crate::errors::HttpResponseException;
/// The tokio runtime that we're using to run async Rust libs.
static RUNTIME: LazyLock<Runtime> = LazyLock::new(|| {
tokio::runtime::Builder::new_multi_thread()
.worker_threads(4)
.enable_all()
.build()
.unwrap()
});
/// A reference to the `Deferred` python class.
static DEFERRED_CLASS: LazyLock<PyObject> = LazyLock::new(|| {
Python::with_gil(|py| {
py.import("twisted.internet.defer")
.expect("module 'twisted.internet.defer' should be importable")
.getattr("Deferred")
.expect("module 'twisted.internet.defer' should have a 'Deferred' class")
.unbind()
})
});
/// A reference to the twisted `reactor`.
static TWISTED_REACTOR: LazyLock<Py<PyModule>> = LazyLock::new(|| {
Python::with_gil(|py| {
py.import("twisted.internet.reactor")
.expect("module 'twisted.internet.reactor' should be importable")
.unbind()
})
});
/// Called when registering modules with python.
pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
let child_module: Bound<'_, PyModule> = PyModule::new(py, "http_client")?;
child_module.add_class::<HttpClient>()?;
// Make sure we fail early if we can't build the lazy statics.
LazyLock::force(&RUNTIME);
LazyLock::force(&DEFERRED_CLASS);
m.add_submodule(&child_module)?;
// We need to manually add the module to sys.modules to make `from
// synapse.synapse_rust import acl` work.
py.import("sys")?
.getattr("modules")?
.set_item("synapse.synapse_rust.http_client", child_module)?;
Ok(())
}
#[pyclass]
#[derive(Clone)]
struct HttpClient {
client: reqwest::Client,
}
#[pymethods]
impl HttpClient {
#[new]
pub fn py_new(user_agent: &str) -> PyResult<HttpClient> {
// The twisted reactor can only be imported after Synapse has been
// imported, to allow Synapse to change the twisted reactor. If we try
// and import the reactor too early twisted installs a default reactor,
// which can't be replaced.
LazyLock::force(&TWISTED_REACTOR);
Ok(HttpClient {
client: reqwest::Client::builder()
.user_agent(user_agent)
.build()
.context("building reqwest client")?,
})
}
pub fn get<'a>(
&self,
py: Python<'a>,
url: String,
response_limit: usize,
) -> PyResult<Bound<'a, PyAny>> {
self.send_request(py, self.client.get(url), response_limit)
}
pub fn post<'a>(
&self,
py: Python<'a>,
url: String,
response_limit: usize,
headers: HashMap<String, String>,
request_body: String,
) -> PyResult<Bound<'a, PyAny>> {
let mut builder = self.client.post(url);
for (name, value) in headers {
builder = builder.header(name, value);
}
builder = builder.body(request_body);
self.send_request(py, builder, response_limit)
}
}
impl HttpClient {
fn send_request<'a>(
&self,
py: Python<'a>,
builder: RequestBuilder,
response_limit: usize,
) -> PyResult<Bound<'a, PyAny>> {
create_deferred(py, async move {
let response = builder.send().await.context("sending request")?;
let status = response.status();
let mut stream = response.bytes_stream();
let mut buffer = Vec::new();
while let Some(chunk) = stream.try_next().await.context("reading body")? {
if buffer.len() + chunk.len() > response_limit {
Err(anyhow::anyhow!("Response size too large"))?;
}
buffer.extend_from_slice(&chunk);
}
if !status.is_success() {
return Err(HttpResponseException::new(status, buffer));
}
let r = Python::with_gil(|py| buffer.into_pyobject(py).map(|o| o.unbind()))?;
Ok(r)
})
}
}
/// Creates a twisted deferred from the given future, spawning the task on the
/// tokio runtime.
///
/// Does not handle deferred cancellation or contextvars.
fn create_deferred<F, O>(py: Python, fut: F) -> PyResult<Bound<'_, PyAny>>
where
F: Future<Output = PyResult<O>> + Send + 'static,
for<'a> O: IntoPyObject<'a>,
{
let deferred = DEFERRED_CLASS.bind(py).call0()?;
let deferred_callback = deferred.getattr("callback")?.unbind();
let deferred_errback = deferred.getattr("errback")?.unbind();
RUNTIME.spawn(async move {
// TODO: Is it safe to assert unwind safety here? I think so, as we
// don't use anything that could be tainted by the panic afterwards.
// Note that `.spawn(..)` asserts unwind safety on the future too.
let res = AssertUnwindSafe(fut).catch_unwind().await;
Python::with_gil(move |py| {
// Flatten the panic into standard python error
let res = match res {
Ok(r) => r,
Err(panic_err) => {
let panic_message = get_panic_message(&panic_err);
Err(PyException::new_err(
PyString::new(py, panic_message).unbind(),
))
}
};
// Send the result to the deferred, via `.callback(..)` or `.errback(..)`
match res {
Ok(obj) => {
TWISTED_REACTOR
.call_method(py, "callFromThread", (deferred_callback, obj), None)
.expect("callFromThread should not fail"); // There's nothing we can really do with errors here
}
Err(err) => {
TWISTED_REACTOR
.call_method(py, "callFromThread", (deferred_errback, err), None)
.expect("callFromThread should not fail"); // There's nothing we can really do with errors here
}
}
});
});
Ok(deferred)
}
/// Try and get the panic message out of the panic
fn get_panic_message<'a>(panic_err: &'a (dyn std::any::Any + Send + 'static)) -> &'a str {
// Apparently this is how you extract the panic message from a panic
if let Some(str_slice) = panic_err.downcast_ref::<&str>() {
str_slice
} else if let Some(string) = panic_err.downcast_ref::<String>() {
string
} else {
"unknown error"
}
}

View File

@@ -8,6 +8,7 @@ pub mod acl;
pub mod errors;
pub mod events;
pub mod http;
pub mod http_client;
pub mod identifier;
pub mod matrix_const;
pub mod push;
@@ -50,6 +51,7 @@ fn synapse_rust(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
acl::register_module(py, m)?;
push::register_module(py, m)?;
events::register_module(py, m)?;
http_client::register_module(py, m)?;
rendezvous::register_module(py, m)?;
Ok(())

View File

@@ -30,9 +30,6 @@ from authlib.oauth2.rfc7662 import IntrospectionToken
from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
from prometheus_client import Histogram
from twisted.web.client import readBody
from twisted.web.http_headers import Headers
from synapse.api.auth.base import BaseAuth
from synapse.api.errors import (
AuthError,
@@ -43,8 +40,14 @@ from synapse.api.errors import (
UnrecognizedRequestError,
)
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import active_span, force_tracing, start_active_span
from synapse.logging.context import PreserveLoggingContext
from synapse.logging.opentracing import (
active_span,
force_tracing,
inject_request_headers,
start_active_span,
)
from synapse.synapse_rust.http_client import HttpClient
from synapse.types import Requester, UserID, create_requester
from synapse.util import json_decoder
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
@@ -179,6 +182,10 @@ class MSC3861DelegatedAuth(BaseAuth):
self._admin_token: Callable[[], Optional[str]] = self._config.admin_token
self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users
self._rust_http_client = HttpClient(
user_agent=self._http_client.user_agent.decode("utf8")
)
# # Token Introspection Cache
# This remembers what users/devices are represented by which access tokens,
# in order to reduce overall system load:
@@ -301,7 +308,6 @@ class MSC3861DelegatedAuth(BaseAuth):
introspection_endpoint = await self._introspection_endpoint()
raw_headers: Dict[str, str] = {
"Content-Type": "application/x-www-form-urlencoded",
"User-Agent": str(self._http_client.user_agent, "utf-8"),
"Accept": "application/json",
# Tell MAS that we support reading the device ID as an explicit
# value, not encoded in the scope. This is supported by MAS 0.15+
@@ -315,38 +321,34 @@ class MSC3861DelegatedAuth(BaseAuth):
uri, raw_headers, body = self._client_auth.prepare(
method="POST", uri=introspection_endpoint, headers=raw_headers, body=body
)
headers = Headers({k: [v] for (k, v) in raw_headers.items()})
# Do the actual request
# We're not using the SimpleHttpClient util methods as we don't want to
# check the HTTP status code, and we do the body encoding ourselves.
logger.debug("Fetching token from MAS")
start_time = self._clock.time()
try:
response = await self._http_client.request(
method="POST",
uri=uri,
data=body.encode("utf-8"),
headers=headers,
)
resp_body = await make_deferred_yieldable(readBody(response))
with start_active_span("mas-introspect-token"):
inject_request_headers(raw_headers)
with PreserveLoggingContext():
resp_body = await self._rust_http_client.post(
url=uri,
response_limit=1 * 1024 * 1024,
headers=raw_headers,
request_body=body,
)
except HttpResponseException as e:
end_time = self._clock.time()
introspection_response_timer.labels(e.code).observe(end_time - start_time)
raise
except Exception:
end_time = self._clock.time()
introspection_response_timer.labels("ERR").observe(end_time - start_time)
raise
end_time = self._clock.time()
introspection_response_timer.labels(response.code).observe(
end_time - start_time
)
logger.debug("Fetched token from MAS")
if response.code < 200 or response.code >= 300:
raise HttpResponseException(
response.code,
response.phrase.decode("ascii", errors="replace"),
resp_body,
)
end_time = self._clock.time()
introspection_response_timer.labels(200).observe(end_time - start_time)
resp = json_decoder.decode(resp_body.decode("utf-8"))

View File

@@ -796,6 +796,13 @@ def inject_response_headers(response_headers: Headers) -> None:
response_headers.addRawHeader("Synapse-Trace-Id", f"{trace_id:x}")
@ensure_active_span("inject the span into a header dict")
def inject_request_headers(headers: Dict[str, str]) -> None:
span = opentracing.tracer.active_span
assert span is not None
opentracing.tracer.inject(span.context, opentracing.Format.HTTP_HEADERS, headers)
@ensure_active_span(
"get the active span context as a dict", ret=cast(Dict[str, str], {})
)

View File

@@ -0,0 +1,24 @@
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2025 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
from typing import Awaitable, Mapping
class HttpClient:
def __init__(self, user_agent: str) -> None: ...
def get(self, url: str, response_limit: int) -> Awaitable[bytes]: ...
def post(
self,
url: str,
response_limit: int,
headers: Mapping[str, str],
request_body: str,
) -> Awaitable[bytes]: ...

View File

@@ -19,9 +19,10 @@
#
#
import json
from http import HTTPStatus
from io import BytesIO
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Union
from unittest.mock import ANY, AsyncMock, Mock
from urllib.parse import parse_qs
@@ -33,12 +34,11 @@ from signedjson.key import (
from signedjson.sign import sign_json
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.http_headers import Headers
from twisted.web.iweb import IResponse
from synapse.api.errors import (
AuthError,
Codes,
HttpResponseException,
InvalidClientTokenError,
OAuthInsufficientScopeError,
SynapseError,
@@ -52,7 +52,7 @@ from synapse.types import JsonDict, UserID
from synapse.util import Clock
from tests.server import FakeChannel
from tests.test_utils import FakeResponse, get_awaitable_result
from tests.test_utils import get_awaitable_result
from tests.unittest import HomeserverTestCase, override_config, skip_unless
from tests.utils import HAS_AUTHLIB, checked_cast, mock_getRawHeaders
@@ -145,6 +145,9 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
self.auth = checked_cast(MSC3861DelegatedAuth, hs.get_auth())
self._rust_client = Mock(spec=["post"])
self.auth._rust_http_client = self._rust_client
return hs
def prepare(
@@ -157,9 +160,15 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
store.store_device(USER_ID, DEVICE, initial_device_display_name=None)
)
def _set_introspection_returnvalue(self, response_value: Any) -> AsyncMock:
self._rust_client.post = mock = AsyncMock(
return_value=json.dumps(response_value).encode("utf-8")
)
return mock
def _assertParams(self) -> None:
"""Assert that the request parameters are correct."""
params = parse_qs(self.http_client.request.call_args[1]["data"].decode("utf-8"))
params = parse_qs(self._rust_client.post.call_args[1]["request_body"])
self.assertEqual(params["token"], ["mockAccessToken"])
self.assertEqual(params["client_id"], [CLIENT_ID])
self.assertEqual(params["client_secret"], [CLIENT_SECRET])
@@ -167,128 +176,125 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_inactive_token(self) -> None:
"""The handler should return a 403 where the token is inactive."""
self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={"active": False},
)
)
self._set_introspection_returnvalue({"active": False})
request = Mock(args={})
request.args[b"access_token"] = [b"mockAccessToken"]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
self.http_client.request.assert_called_once_with(
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
self._rust_client.post.assert_called_once_with(
url=INTROSPECTION_ENDPOINT,
response_limit=ANY,
request_body=ANY,
headers=ANY,
)
self._assertParams()
def test_active_no_scope(self) -> None:
"""The handler should return a 403 where no scope is given."""
self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={"active": True},
)
)
self._set_introspection_returnvalue({"active": True})
request = Mock(args={})
request.args[b"access_token"] = [b"mockAccessToken"]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
self.http_client.request.assert_called_once_with(
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
self._rust_client.post.assert_called_once_with(
url=INTROSPECTION_ENDPOINT,
response_limit=ANY,
request_body=ANY,
headers=ANY,
)
self._assertParams()
def test_active_user_no_subject(self) -> None:
"""The handler should return a 500 when no subject is present."""
self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={"active": True, "scope": " ".join([MATRIX_USER_SCOPE])},
)
self._set_introspection_returnvalue(
{"active": True, "scope": " ".join([MATRIX_USER_SCOPE])}
)
request = Mock(args={})
request.args[b"access_token"] = [b"mockAccessToken"]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
self.http_client.request.assert_called_once_with(
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
self._rust_client.post.assert_called_once_with(
url=INTROSPECTION_ENDPOINT,
response_limit=ANY,
request_body=ANY,
headers=ANY,
)
self._assertParams()
def test_active_no_user_scope(self) -> None:
"""The handler should return a 500 when no subject is present."""
self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={
"active": True,
"sub": SUBJECT,
"scope": " ".join([MATRIX_DEVICE_SCOPE]),
},
)
self._set_introspection_returnvalue(
{
"active": True,
"sub": SUBJECT,
"scope": " ".join([MATRIX_DEVICE_SCOPE]),
}
)
request = Mock(args={})
request.args[b"access_token"] = [b"mockAccessToken"]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
self.http_client.request.assert_called_once_with(
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
self._rust_client.post.assert_called_once_with(
url=INTROSPECTION_ENDPOINT,
response_limit=ANY,
request_body=ANY,
headers=ANY,
)
self._assertParams()
def test_active_admin_not_user(self) -> None:
"""The handler should raise when the scope has admin right but not user."""
self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={
"active": True,
"sub": SUBJECT,
"scope": " ".join([SYNAPSE_ADMIN_SCOPE]),
"username": USERNAME,
},
)
self._set_introspection_returnvalue(
{
"active": True,
"sub": SUBJECT,
"scope": " ".join([SYNAPSE_ADMIN_SCOPE]),
"username": USERNAME,
}
)
request = Mock(args={})
request.args[b"access_token"] = [b"mockAccessToken"]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
self.http_client.request.assert_called_once_with(
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
self._rust_client.post.assert_called_once_with(
url=INTROSPECTION_ENDPOINT,
response_limit=ANY,
request_body=ANY,
headers=ANY,
)
self._assertParams()
def test_active_admin(self) -> None:
"""The handler should return a requester with admin rights."""
self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={
"active": True,
"sub": SUBJECT,
"scope": " ".join([SYNAPSE_ADMIN_SCOPE, MATRIX_USER_SCOPE]),
"username": USERNAME,
},
)
self._set_introspection_returnvalue(
{
"active": True,
"sub": SUBJECT,
"scope": " ".join([SYNAPSE_ADMIN_SCOPE, MATRIX_USER_SCOPE]),
"username": USERNAME,
}
)
request = Mock(args={})
request.args[b"access_token"] = [b"mockAccessToken"]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = self.get_success(self.auth.get_user_by_req(request))
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
self.http_client.request.assert_called_once_with(
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
self._rust_client.post.assert_called_once_with(
url=INTROSPECTION_ENDPOINT,
response_limit=ANY,
request_body=ANY,
headers=ANY,
)
self._assertParams()
self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
@@ -301,26 +307,26 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_active_admin_highest_privilege(self) -> None:
"""The handler should resolve to the most permissive scope."""
self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={
"active": True,
"sub": SUBJECT,
"scope": " ".join(
[SYNAPSE_ADMIN_SCOPE, MATRIX_USER_SCOPE, MATRIX_GUEST_SCOPE]
),
"username": USERNAME,
},
)
self._set_introspection_returnvalue(
{
"active": True,
"sub": SUBJECT,
"scope": " ".join(
[SYNAPSE_ADMIN_SCOPE, MATRIX_USER_SCOPE, MATRIX_GUEST_SCOPE]
),
"username": USERNAME,
}
)
request = Mock(args={})
request.args[b"access_token"] = [b"mockAccessToken"]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = self.get_success(self.auth.get_user_by_req(request))
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
self.http_client.request.assert_called_once_with(
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
self._rust_client.post.assert_called_once_with(
url=INTROSPECTION_ENDPOINT,
response_limit=ANY,
request_body=ANY,
headers=ANY,
)
self._assertParams()
self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
@@ -333,24 +339,24 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_active_user(self) -> None:
"""The handler should return a requester with normal user rights."""
self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={
"active": True,
"sub": SUBJECT,
"scope": " ".join([MATRIX_USER_SCOPE]),
"username": USERNAME,
},
)
self._set_introspection_returnvalue(
{
"active": True,
"sub": SUBJECT,
"scope": " ".join([MATRIX_USER_SCOPE]),
"username": USERNAME,
}
)
request = Mock(args={})
request.args[b"access_token"] = [b"mockAccessToken"]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = self.get_success(self.auth.get_user_by_req(request))
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
self.http_client.request.assert_called_once_with(
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
self._rust_client.post.assert_called_once_with(
url=INTROSPECTION_ENDPOINT,
response_limit=ANY,
request_body=ANY,
headers=ANY,
)
self._assertParams()
self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
@@ -363,24 +369,24 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_active_user_with_device(self) -> None:
"""The handler should return a requester with normal user rights and a device ID."""
self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={
"active": True,
"sub": SUBJECT,
"scope": " ".join([MATRIX_USER_SCOPE, MATRIX_DEVICE_SCOPE]),
"username": USERNAME,
},
)
self._set_introspection_returnvalue(
{
"active": True,
"sub": SUBJECT,
"scope": " ".join([MATRIX_USER_SCOPE, MATRIX_DEVICE_SCOPE]),
"username": USERNAME,
}
)
request = Mock(args={})
request.args[b"access_token"] = [b"mockAccessToken"]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = self.get_success(self.auth.get_user_by_req(request))
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
self.http_client.request.assert_called_once_with(
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
self._rust_client.post.assert_called_once_with(
url=INTROSPECTION_ENDPOINT,
response_limit=ANY,
request_body=ANY,
headers=ANY,
)
self._assertParams()
self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
@@ -393,32 +399,32 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_active_user_with_device_explicit_device_id(self) -> None:
"""The handler should return a requester with normal user rights and a device ID, given explicitly, as supported by MAS 0.15+"""
self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={
"active": True,
"sub": SUBJECT,
"scope": " ".join([MATRIX_USER_SCOPE]),
"device_id": DEVICE,
"username": USERNAME,
},
)
self._set_introspection_returnvalue(
{
"active": True,
"sub": SUBJECT,
"scope": " ".join([MATRIX_USER_SCOPE]),
"device_id": DEVICE,
"username": USERNAME,
}
)
request = Mock(args={})
request.args[b"access_token"] = [b"mockAccessToken"]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
requester = self.get_success(self.auth.get_user_by_req(request))
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
self.http_client.request.assert_called_once_with(
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
self._rust_client.post.assert_called_once_with(
url=INTROSPECTION_ENDPOINT,
response_limit=ANY,
request_body=ANY,
headers=ANY,
)
# It should have called with the 'X-MAS-Supports-Device-Id: 1' header
self.assertEqual(
self.http_client.request.call_args[1]["headers"].getRawHeaders(
b"X-MAS-Supports-Device-Id",
self._rust_client.post.call_args[1]["headers"].get(
"X-MAS-Supports-Device-Id",
),
[b"1"],
"1",
)
self._assertParams()
self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
@@ -431,22 +437,19 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_multiple_devices(self) -> None:
"""The handler should raise an error if multiple devices are found in the scope."""
self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={
"active": True,
"sub": SUBJECT,
"scope": " ".join(
[
MATRIX_USER_SCOPE,
f"{MATRIX_DEVICE_SCOPE_PREFIX}AABBCC",
f"{MATRIX_DEVICE_SCOPE_PREFIX}DDEEFF",
]
),
"username": USERNAME,
},
)
self._set_introspection_returnvalue(
{
"active": True,
"sub": SUBJECT,
"scope": " ".join(
[
MATRIX_USER_SCOPE,
f"{MATRIX_DEVICE_SCOPE_PREFIX}AABBCC",
f"{MATRIX_DEVICE_SCOPE_PREFIX}DDEEFF",
]
),
"username": USERNAME,
}
)
request = Mock(args={})
request.args[b"access_token"] = [b"mockAccessToken"]
@@ -456,16 +459,13 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_active_guest_not_allowed(self) -> None:
"""The handler should return an insufficient scope error."""
self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={
"active": True,
"sub": SUBJECT,
"scope": " ".join([MATRIX_GUEST_SCOPE, MATRIX_DEVICE_SCOPE]),
"username": USERNAME,
},
)
self._set_introspection_returnvalue(
{
"active": True,
"sub": SUBJECT,
"scope": " ".join([MATRIX_GUEST_SCOPE, MATRIX_DEVICE_SCOPE]),
"username": USERNAME,
}
)
request = Mock(args={})
request.args[b"access_token"] = [b"mockAccessToken"]
@@ -474,8 +474,11 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
self.auth.get_user_by_req(request), OAuthInsufficientScopeError
)
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
self.http_client.request.assert_called_once_with(
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
self._rust_client.post.assert_called_once_with(
url=INTROSPECTION_ENDPOINT,
response_limit=ANY,
request_body=ANY,
headers=ANY,
)
self._assertParams()
self.assertEqual(
@@ -486,16 +489,13 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_active_guest_allowed(self) -> None:
"""The handler should return a requester with guest user rights and a device ID."""
self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={
"active": True,
"sub": SUBJECT,
"scope": " ".join([MATRIX_GUEST_SCOPE, MATRIX_DEVICE_SCOPE]),
"username": USERNAME,
},
)
self._set_introspection_returnvalue(
{
"active": True,
"sub": SUBJECT,
"scope": " ".join([MATRIX_GUEST_SCOPE, MATRIX_DEVICE_SCOPE]),
"username": USERNAME,
}
)
request = Mock(args={})
request.args[b"access_token"] = [b"mockAccessToken"]
@@ -504,8 +504,11 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
self.auth.get_user_by_req(request, allow_guest=True)
)
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
self.http_client.request.assert_called_once_with(
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
self._rust_client.post.assert_called_once_with(
url=INTROSPECTION_ENDPOINT,
response_limit=ANY,
request_body=ANY,
headers=ANY,
)
self._assertParams()
self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
@@ -522,30 +525,28 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
# The introspection endpoint is returning an error.
self.http_client.request = AsyncMock(
return_value=FakeResponse(code=500, body=b"Internal Server Error")
)
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
self.assertEqual(error.value.code, 503)
# The introspection endpoint request fails.
self.http_client.request = AsyncMock(side_effect=Exception())
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
self.assertEqual(error.value.code, 503)
# The introspection endpoint does not return a JSON object.
self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200, payload=["this is an array", "not an object"]
self._rust_client.post = AsyncMock(
side_effect=HttpResponseException(
code=500, msg="Internal Server Error", response=b"{}"
)
)
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
self.assertEqual(error.value.code, 503)
# The introspection endpoint request fails.
self._rust_client.post = AsyncMock(side_effect=Exception())
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
self.assertEqual(error.value.code, 503)
# The introspection endpoint does not return a JSON object.
self._set_introspection_returnvalue(["this is an array", "not an object"])
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
self.assertEqual(error.value.code, 503)
# The introspection endpoint does not return valid JSON.
self.http_client.request = AsyncMock(
return_value=FakeResponse(code=200, body=b"this is not valid JSON")
)
self._set_introspection_returnvalue("this is not valid JSON")
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
self.assertEqual(error.value.code, 503)
@@ -554,23 +555,21 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
an expiry time, the introspection response is cached and then the entry is
re-requested after it has expired."""
self.http_client.request = introspection_mock = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={
"active": True,
"sub": SUBJECT,
"scope": " ".join(
[
MATRIX_USER_SCOPE,
f"{MATRIX_DEVICE_SCOPE_PREFIX}AABBCC",
]
),
"username": USERNAME,
"expires_in": 60,
},
)
introspection_mock = self._set_introspection_returnvalue(
{
"active": True,
"sub": SUBJECT,
"scope": " ".join(
[
MATRIX_USER_SCOPE,
f"{MATRIX_DEVICE_SCOPE_PREFIX}AABBCC",
]
),
"username": USERNAME,
"expires_in": 60,
}
)
request = Mock(args={})
request.args[b"access_token"] = [b"mockAccessToken"]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
@@ -607,16 +606,13 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_cross_signing(self) -> None:
"""Try uploading device keys with OAuth delegation enabled."""
self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={
"active": True,
"sub": SUBJECT,
"scope": " ".join([MATRIX_USER_SCOPE, MATRIX_DEVICE_SCOPE]),
"username": USERNAME,
},
)
self._set_introspection_returnvalue(
{
"active": True,
"sub": SUBJECT,
"scope": " ".join([MATRIX_USER_SCOPE, MATRIX_DEVICE_SCOPE]),
"username": USERNAME,
}
)
keys_upload_body = self.make_device_keys(USER_ID, DEVICE)
channel = self.make_request(
@@ -778,16 +774,13 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
# Because we still support those endpoints with ASes, it checks the
# access token before returning 404
self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={
"active": True,
"sub": SUBJECT,
"scope": " ".join([MATRIX_USER_SCOPE, MATRIX_DEVICE_SCOPE]),
"username": USERNAME,
},
)
self._set_introspection_returnvalue(
{
"active": True,
"sub": SUBJECT,
"scope": " ".join([MATRIX_USER_SCOPE, MATRIX_DEVICE_SCOPE]),
"username": USERNAME,
},
)
self.expect_unrecognized("POST", "/_matrix/client/v3/delete_devices", auth=True)
@@ -820,9 +813,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_admin_token(self) -> None:
"""The handler should return a requester with admin rights when admin_token is used."""
self.http_client.request = AsyncMock(
return_value=FakeResponse.json(code=200, payload={"active": False}),
)
self._set_introspection_returnvalue({"active": False})
request = Mock(args={})
request.args[b"access_token"] = [b"admin_token_value"]
@@ -839,7 +830,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
)
# There should be no call to the introspection endpoint
self.http_client.request.assert_not_called()
self._rust_client.post.assert_not_called()
@override_config({"mau_stats_only": True})
def test_request_tracking(self) -> None:
@@ -852,28 +843,23 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
known_token = "token-token-GOOD-:)"
async def mock_http_client_request(
method: str,
uri: str,
data: Optional[bytes] = None,
headers: Optional[Headers] = None,
) -> IResponse:
url: str, request_body: str, **kwargs: Any
) -> bytes:
"""Mocked auth provider response."""
assert method == "POST"
token = parse_qs(data)[b"token"][0].decode("utf-8")
token = parse_qs(request_body)["token"][0]
if token == known_token:
return FakeResponse.json(
code=200,
payload={
return json.dumps(
{
"active": True,
"scope": MATRIX_USER_SCOPE,
"sub": SUBJECT,
"username": USERNAME,
},
)
).encode("utf-8")
return FakeResponse.json(code=200, payload={"active": False})
return json.dumps({"active": False}).encode("utf-8")
self.http_client.request = mock_http_client_request
self._rust_client.post = mock_http_client_request
EXAMPLE_IPV4_ADDR = "123.123.123.123"
EXAMPLE_USER_AGENT = "httprettygood"