mirror of
https://github.com/element-hq/synapse.git
synced 2025-12-05 01:10:13 +00:00
Speed up MAS token introspection (#18357)
We do this by shoving it into Rust. We believe our python http client is a bit slow. Also bumps minimum rust version to 1.81.0, released last September (over six months ago) To allow for async Rust, includes some adapters between Tokio in Rust and the Twisted reactor in Python.
This commit is contained in:
22
.github/workflows/tests.yml
vendored
22
.github/workflows/tests.yml
vendored
@@ -85,7 +85,7 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- name: Install Rust
|
||||
uses: dtolnay/rust-toolchain@e05ebb0e73db581a4877c6ce762e29fe1e0b5073 # 1.66.0
|
||||
uses: dtolnay/rust-toolchain@c1678930c21fb233e4987c4ae12158f9125e5762 # 1.81.0
|
||||
- uses: Swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2.7.8
|
||||
- uses: matrix-org/setup-python-poetry@5bbf6603c5c930615ec8a29f1b5d7d258d905aa4 # v2.0.0
|
||||
with:
|
||||
@@ -149,7 +149,7 @@ jobs:
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
|
||||
- name: Install Rust
|
||||
uses: dtolnay/rust-toolchain@e05ebb0e73db581a4877c6ce762e29fe1e0b5073 # 1.66.0
|
||||
uses: dtolnay/rust-toolchain@c1678930c21fb233e4987c4ae12158f9125e5762 # 1.81.0
|
||||
- uses: Swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2.7.8
|
||||
|
||||
- name: Setup Poetry
|
||||
@@ -210,7 +210,7 @@ jobs:
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha }}
|
||||
- name: Install Rust
|
||||
uses: dtolnay/rust-toolchain@e05ebb0e73db581a4877c6ce762e29fe1e0b5073 # 1.66.0
|
||||
uses: dtolnay/rust-toolchain@c1678930c21fb233e4987c4ae12158f9125e5762 # 1.81.0
|
||||
- uses: Swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2.7.8
|
||||
- uses: matrix-org/setup-python-poetry@5bbf6603c5c930615ec8a29f1b5d7d258d905aa4 # v2.0.0
|
||||
with:
|
||||
@@ -227,7 +227,7 @@ jobs:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
|
||||
- name: Install Rust
|
||||
uses: dtolnay/rust-toolchain@e05ebb0e73db581a4877c6ce762e29fe1e0b5073 # 1.66.0
|
||||
uses: dtolnay/rust-toolchain@0d72692bcfbf448b1e2afa01a67f71b455a9dcec # 1.86.0
|
||||
with:
|
||||
components: clippy
|
||||
- uses: Swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2.7.8
|
||||
@@ -247,7 +247,7 @@ jobs:
|
||||
- name: Install Rust
|
||||
uses: dtolnay/rust-toolchain@56f84321dbccf38fb67ce29ab63e4754056677e0 # master (rust 1.85.1)
|
||||
with:
|
||||
toolchain: nightly-2022-12-01
|
||||
toolchain: nightly-2025-04-23
|
||||
components: clippy
|
||||
- uses: Swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2.7.8
|
||||
|
||||
@@ -265,7 +265,7 @@ jobs:
|
||||
uses: dtolnay/rust-toolchain@56f84321dbccf38fb67ce29ab63e4754056677e0 # master (rust 1.85.1)
|
||||
with:
|
||||
# We use nightly so that it correctly groups together imports
|
||||
toolchain: nightly-2022-12-01
|
||||
toolchain: nightly-2025-04-23
|
||||
components: rustfmt
|
||||
- uses: Swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2.7.8
|
||||
|
||||
@@ -362,7 +362,7 @@ jobs:
|
||||
postgres:${{ matrix.job.postgres-version }}
|
||||
|
||||
- name: Install Rust
|
||||
uses: dtolnay/rust-toolchain@e05ebb0e73db581a4877c6ce762e29fe1e0b5073 # 1.66.0
|
||||
uses: dtolnay/rust-toolchain@c1678930c21fb233e4987c4ae12158f9125e5762 # 1.81.0
|
||||
- uses: Swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2.7.8
|
||||
|
||||
- uses: matrix-org/setup-python-poetry@5bbf6603c5c930615ec8a29f1b5d7d258d905aa4 # v2.0.0
|
||||
@@ -404,7 +404,7 @@ jobs:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
|
||||
- name: Install Rust
|
||||
uses: dtolnay/rust-toolchain@e05ebb0e73db581a4877c6ce762e29fe1e0b5073 # 1.66.0
|
||||
uses: dtolnay/rust-toolchain@c1678930c21fb233e4987c4ae12158f9125e5762 # 1.81.0
|
||||
- uses: Swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2.7.8
|
||||
|
||||
# There aren't wheels for some of the older deps, so we need to install
|
||||
@@ -519,7 +519,7 @@ jobs:
|
||||
run: cat sytest-blacklist .ci/worker-blacklist > synapse-blacklist-with-workers
|
||||
|
||||
- name: Install Rust
|
||||
uses: dtolnay/rust-toolchain@e05ebb0e73db581a4877c6ce762e29fe1e0b5073 # 1.66.0
|
||||
uses: dtolnay/rust-toolchain@c1678930c21fb233e4987c4ae12158f9125e5762 # 1.81.0
|
||||
- uses: Swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2.7.8
|
||||
|
||||
- name: Run SyTest
|
||||
@@ -663,7 +663,7 @@ jobs:
|
||||
path: synapse
|
||||
|
||||
- name: Install Rust
|
||||
uses: dtolnay/rust-toolchain@e05ebb0e73db581a4877c6ce762e29fe1e0b5073 # 1.66.0
|
||||
uses: dtolnay/rust-toolchain@c1678930c21fb233e4987c4ae12158f9125e5762 # 1.81.0
|
||||
- uses: Swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2.7.8
|
||||
|
||||
- name: Prepare Complement's Prerequisites
|
||||
@@ -695,7 +695,7 @@ jobs:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
|
||||
- name: Install Rust
|
||||
uses: dtolnay/rust-toolchain@e05ebb0e73db581a4877c6ce762e29fe1e0b5073 # 1.66.0
|
||||
uses: dtolnay/rust-toolchain@c1678930c21fb233e4987c4ae12158f9125e5762 # 1.81.0
|
||||
- uses: Swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2.7.8
|
||||
|
||||
- run: cargo test
|
||||
|
||||
1333
Cargo.lock
generated
1333
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
1
changelog.d/18357.misc
Normal file
1
changelog.d/18357.misc
Normal file
@@ -0,0 +1 @@
|
||||
Increase performance of introspecting access tokens when using delegated auth.
|
||||
@@ -7,7 +7,7 @@ name = "synapse"
|
||||
version = "0.1.0"
|
||||
|
||||
edition = "2021"
|
||||
rust-version = "1.66.0"
|
||||
rust-version = "1.81.0"
|
||||
|
||||
[lib]
|
||||
name = "synapse"
|
||||
@@ -36,13 +36,21 @@ pyo3 = { version = "0.24.2", features = [
|
||||
"abi3",
|
||||
"abi3-py39",
|
||||
] }
|
||||
pyo3-log = "0.12.0"
|
||||
pyo3-log = "0.12.3"
|
||||
pythonize = "0.24.0"
|
||||
regex = "1.6.0"
|
||||
sha2 = "0.10.8"
|
||||
serde = { version = "1.0.144", features = ["derive"] }
|
||||
serde_json = "1.0.85"
|
||||
ulid = "1.1.2"
|
||||
reqwest = { version = "0.12.15", default-features = false, features = [
|
||||
"http2",
|
||||
"stream",
|
||||
"rustls-tls-native-roots",
|
||||
] }
|
||||
http-body-util = "0.1.3"
|
||||
futures = "0.3.31"
|
||||
tokio = { version = "1.44.2", features = ["rt", "rt-multi-thread"] }
|
||||
|
||||
[features]
|
||||
extension-module = ["pyo3/extension-module"]
|
||||
|
||||
@@ -58,3 +58,15 @@ impl NotFoundError {
|
||||
NotFoundError::new_err(())
|
||||
}
|
||||
}
|
||||
|
||||
import_exception!(synapse.api.errors, HttpResponseException);
|
||||
|
||||
impl HttpResponseException {
|
||||
pub fn new(status: StatusCode, bytes: Vec<u8>) -> pyo3::PyErr {
|
||||
HttpResponseException::new_err((
|
||||
status.as_u16(),
|
||||
status.canonical_reason().unwrap_or_default(),
|
||||
bytes,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
218
rust/src/http_client.rs
Normal file
218
rust/src/http_client.rs
Normal file
@@ -0,0 +1,218 @@
|
||||
/*
|
||||
* This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
*
|
||||
* Copyright (C) 2025 New Vector, Ltd
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as
|
||||
* published by the Free Software Foundation, either version 3 of the
|
||||
* License, or (at your option) any later version.
|
||||
*
|
||||
* See the GNU Affero General Public License for more details:
|
||||
* <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
*/
|
||||
|
||||
use std::{collections::HashMap, future::Future, panic::AssertUnwindSafe, sync::LazyLock};
|
||||
|
||||
use anyhow::Context;
|
||||
use futures::{FutureExt, TryStreamExt};
|
||||
use pyo3::{exceptions::PyException, prelude::*, types::PyString};
|
||||
use reqwest::RequestBuilder;
|
||||
use tokio::runtime::Runtime;
|
||||
|
||||
use crate::errors::HttpResponseException;
|
||||
|
||||
/// The tokio runtime that we're using to run async Rust libs.
|
||||
static RUNTIME: LazyLock<Runtime> = LazyLock::new(|| {
|
||||
tokio::runtime::Builder::new_multi_thread()
|
||||
.worker_threads(4)
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
/// A reference to the `Deferred` python class.
|
||||
static DEFERRED_CLASS: LazyLock<PyObject> = LazyLock::new(|| {
|
||||
Python::with_gil(|py| {
|
||||
py.import("twisted.internet.defer")
|
||||
.expect("module 'twisted.internet.defer' should be importable")
|
||||
.getattr("Deferred")
|
||||
.expect("module 'twisted.internet.defer' should have a 'Deferred' class")
|
||||
.unbind()
|
||||
})
|
||||
});
|
||||
|
||||
/// A reference to the twisted `reactor`.
|
||||
static TWISTED_REACTOR: LazyLock<Py<PyModule>> = LazyLock::new(|| {
|
||||
Python::with_gil(|py| {
|
||||
py.import("twisted.internet.reactor")
|
||||
.expect("module 'twisted.internet.reactor' should be importable")
|
||||
.unbind()
|
||||
})
|
||||
});
|
||||
|
||||
/// Called when registering modules with python.
|
||||
pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
let child_module: Bound<'_, PyModule> = PyModule::new(py, "http_client")?;
|
||||
child_module.add_class::<HttpClient>()?;
|
||||
|
||||
// Make sure we fail early if we can't build the lazy statics.
|
||||
LazyLock::force(&RUNTIME);
|
||||
LazyLock::force(&DEFERRED_CLASS);
|
||||
|
||||
m.add_submodule(&child_module)?;
|
||||
|
||||
// We need to manually add the module to sys.modules to make `from
|
||||
// synapse.synapse_rust import acl` work.
|
||||
py.import("sys")?
|
||||
.getattr("modules")?
|
||||
.set_item("synapse.synapse_rust.http_client", child_module)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
#[derive(Clone)]
|
||||
struct HttpClient {
|
||||
client: reqwest::Client,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl HttpClient {
|
||||
#[new]
|
||||
pub fn py_new(user_agent: &str) -> PyResult<HttpClient> {
|
||||
// The twisted reactor can only be imported after Synapse has been
|
||||
// imported, to allow Synapse to change the twisted reactor. If we try
|
||||
// and import the reactor too early twisted installs a default reactor,
|
||||
// which can't be replaced.
|
||||
LazyLock::force(&TWISTED_REACTOR);
|
||||
|
||||
Ok(HttpClient {
|
||||
client: reqwest::Client::builder()
|
||||
.user_agent(user_agent)
|
||||
.build()
|
||||
.context("building reqwest client")?,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get<'a>(
|
||||
&self,
|
||||
py: Python<'a>,
|
||||
url: String,
|
||||
response_limit: usize,
|
||||
) -> PyResult<Bound<'a, PyAny>> {
|
||||
self.send_request(py, self.client.get(url), response_limit)
|
||||
}
|
||||
|
||||
pub fn post<'a>(
|
||||
&self,
|
||||
py: Python<'a>,
|
||||
url: String,
|
||||
response_limit: usize,
|
||||
headers: HashMap<String, String>,
|
||||
request_body: String,
|
||||
) -> PyResult<Bound<'a, PyAny>> {
|
||||
let mut builder = self.client.post(url);
|
||||
for (name, value) in headers {
|
||||
builder = builder.header(name, value);
|
||||
}
|
||||
builder = builder.body(request_body);
|
||||
|
||||
self.send_request(py, builder, response_limit)
|
||||
}
|
||||
}
|
||||
|
||||
impl HttpClient {
|
||||
fn send_request<'a>(
|
||||
&self,
|
||||
py: Python<'a>,
|
||||
builder: RequestBuilder,
|
||||
response_limit: usize,
|
||||
) -> PyResult<Bound<'a, PyAny>> {
|
||||
create_deferred(py, async move {
|
||||
let response = builder.send().await.context("sending request")?;
|
||||
|
||||
let status = response.status();
|
||||
|
||||
let mut stream = response.bytes_stream();
|
||||
let mut buffer = Vec::new();
|
||||
while let Some(chunk) = stream.try_next().await.context("reading body")? {
|
||||
if buffer.len() + chunk.len() > response_limit {
|
||||
Err(anyhow::anyhow!("Response size too large"))?;
|
||||
}
|
||||
|
||||
buffer.extend_from_slice(&chunk);
|
||||
}
|
||||
|
||||
if !status.is_success() {
|
||||
return Err(HttpResponseException::new(status, buffer));
|
||||
}
|
||||
|
||||
let r = Python::with_gil(|py| buffer.into_pyobject(py).map(|o| o.unbind()))?;
|
||||
|
||||
Ok(r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a twisted deferred from the given future, spawning the task on the
|
||||
/// tokio runtime.
|
||||
///
|
||||
/// Does not handle deferred cancellation or contextvars.
|
||||
fn create_deferred<F, O>(py: Python, fut: F) -> PyResult<Bound<'_, PyAny>>
|
||||
where
|
||||
F: Future<Output = PyResult<O>> + Send + 'static,
|
||||
for<'a> O: IntoPyObject<'a>,
|
||||
{
|
||||
let deferred = DEFERRED_CLASS.bind(py).call0()?;
|
||||
let deferred_callback = deferred.getattr("callback")?.unbind();
|
||||
let deferred_errback = deferred.getattr("errback")?.unbind();
|
||||
|
||||
RUNTIME.spawn(async move {
|
||||
// TODO: Is it safe to assert unwind safety here? I think so, as we
|
||||
// don't use anything that could be tainted by the panic afterwards.
|
||||
// Note that `.spawn(..)` asserts unwind safety on the future too.
|
||||
let res = AssertUnwindSafe(fut).catch_unwind().await;
|
||||
|
||||
Python::with_gil(move |py| {
|
||||
// Flatten the panic into standard python error
|
||||
let res = match res {
|
||||
Ok(r) => r,
|
||||
Err(panic_err) => {
|
||||
let panic_message = get_panic_message(&panic_err);
|
||||
Err(PyException::new_err(
|
||||
PyString::new(py, panic_message).unbind(),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// Send the result to the deferred, via `.callback(..)` or `.errback(..)`
|
||||
match res {
|
||||
Ok(obj) => {
|
||||
TWISTED_REACTOR
|
||||
.call_method(py, "callFromThread", (deferred_callback, obj), None)
|
||||
.expect("callFromThread should not fail"); // There's nothing we can really do with errors here
|
||||
}
|
||||
Err(err) => {
|
||||
TWISTED_REACTOR
|
||||
.call_method(py, "callFromThread", (deferred_errback, err), None)
|
||||
.expect("callFromThread should not fail"); // There's nothing we can really do with errors here
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
Ok(deferred)
|
||||
}
|
||||
|
||||
/// Try and get the panic message out of the panic
|
||||
fn get_panic_message<'a>(panic_err: &'a (dyn std::any::Any + Send + 'static)) -> &'a str {
|
||||
// Apparently this is how you extract the panic message from a panic
|
||||
if let Some(str_slice) = panic_err.downcast_ref::<&str>() {
|
||||
str_slice
|
||||
} else if let Some(string) = panic_err.downcast_ref::<String>() {
|
||||
string
|
||||
} else {
|
||||
"unknown error"
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,7 @@ pub mod acl;
|
||||
pub mod errors;
|
||||
pub mod events;
|
||||
pub mod http;
|
||||
pub mod http_client;
|
||||
pub mod identifier;
|
||||
pub mod matrix_const;
|
||||
pub mod push;
|
||||
@@ -50,6 +51,7 @@ fn synapse_rust(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
acl::register_module(py, m)?;
|
||||
push::register_module(py, m)?;
|
||||
events::register_module(py, m)?;
|
||||
http_client::register_module(py, m)?;
|
||||
rendezvous::register_module(py, m)?;
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -30,9 +30,6 @@ from authlib.oauth2.rfc7662 import IntrospectionToken
|
||||
from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from twisted.web.client import readBody
|
||||
from twisted.web.http_headers import Headers
|
||||
|
||||
from synapse.api.auth.base import BaseAuth
|
||||
from synapse.api.errors import (
|
||||
AuthError,
|
||||
@@ -43,8 +40,14 @@ from synapse.api.errors import (
|
||||
UnrecognizedRequestError,
|
||||
)
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.logging.opentracing import active_span, force_tracing, start_active_span
|
||||
from synapse.logging.context import PreserveLoggingContext
|
||||
from synapse.logging.opentracing import (
|
||||
active_span,
|
||||
force_tracing,
|
||||
inject_request_headers,
|
||||
start_active_span,
|
||||
)
|
||||
from synapse.synapse_rust.http_client import HttpClient
|
||||
from synapse.types import Requester, UserID, create_requester
|
||||
from synapse.util import json_decoder
|
||||
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
|
||||
@@ -179,6 +182,10 @@ class MSC3861DelegatedAuth(BaseAuth):
|
||||
self._admin_token: Callable[[], Optional[str]] = self._config.admin_token
|
||||
self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users
|
||||
|
||||
self._rust_http_client = HttpClient(
|
||||
user_agent=self._http_client.user_agent.decode("utf8")
|
||||
)
|
||||
|
||||
# # Token Introspection Cache
|
||||
# This remembers what users/devices are represented by which access tokens,
|
||||
# in order to reduce overall system load:
|
||||
@@ -301,7 +308,6 @@ class MSC3861DelegatedAuth(BaseAuth):
|
||||
introspection_endpoint = await self._introspection_endpoint()
|
||||
raw_headers: Dict[str, str] = {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"User-Agent": str(self._http_client.user_agent, "utf-8"),
|
||||
"Accept": "application/json",
|
||||
# Tell MAS that we support reading the device ID as an explicit
|
||||
# value, not encoded in the scope. This is supported by MAS 0.15+
|
||||
@@ -315,38 +321,34 @@ class MSC3861DelegatedAuth(BaseAuth):
|
||||
uri, raw_headers, body = self._client_auth.prepare(
|
||||
method="POST", uri=introspection_endpoint, headers=raw_headers, body=body
|
||||
)
|
||||
headers = Headers({k: [v] for (k, v) in raw_headers.items()})
|
||||
|
||||
# Do the actual request
|
||||
# We're not using the SimpleHttpClient util methods as we don't want to
|
||||
# check the HTTP status code, and we do the body encoding ourselves.
|
||||
|
||||
logger.debug("Fetching token from MAS")
|
||||
start_time = self._clock.time()
|
||||
try:
|
||||
response = await self._http_client.request(
|
||||
method="POST",
|
||||
uri=uri,
|
||||
data=body.encode("utf-8"),
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
resp_body = await make_deferred_yieldable(readBody(response))
|
||||
with start_active_span("mas-introspect-token"):
|
||||
inject_request_headers(raw_headers)
|
||||
with PreserveLoggingContext():
|
||||
resp_body = await self._rust_http_client.post(
|
||||
url=uri,
|
||||
response_limit=1 * 1024 * 1024,
|
||||
headers=raw_headers,
|
||||
request_body=body,
|
||||
)
|
||||
except HttpResponseException as e:
|
||||
end_time = self._clock.time()
|
||||
introspection_response_timer.labels(e.code).observe(end_time - start_time)
|
||||
raise
|
||||
except Exception:
|
||||
end_time = self._clock.time()
|
||||
introspection_response_timer.labels("ERR").observe(end_time - start_time)
|
||||
raise
|
||||
|
||||
end_time = self._clock.time()
|
||||
introspection_response_timer.labels(response.code).observe(
|
||||
end_time - start_time
|
||||
)
|
||||
logger.debug("Fetched token from MAS")
|
||||
|
||||
if response.code < 200 or response.code >= 300:
|
||||
raise HttpResponseException(
|
||||
response.code,
|
||||
response.phrase.decode("ascii", errors="replace"),
|
||||
resp_body,
|
||||
)
|
||||
end_time = self._clock.time()
|
||||
introspection_response_timer.labels(200).observe(end_time - start_time)
|
||||
|
||||
resp = json_decoder.decode(resp_body.decode("utf-8"))
|
||||
|
||||
|
||||
@@ -796,6 +796,13 @@ def inject_response_headers(response_headers: Headers) -> None:
|
||||
response_headers.addRawHeader("Synapse-Trace-Id", f"{trace_id:x}")
|
||||
|
||||
|
||||
@ensure_active_span("inject the span into a header dict")
|
||||
def inject_request_headers(headers: Dict[str, str]) -> None:
|
||||
span = opentracing.tracer.active_span
|
||||
assert span is not None
|
||||
opentracing.tracer.inject(span.context, opentracing.Format.HTTP_HEADERS, headers)
|
||||
|
||||
|
||||
@ensure_active_span(
|
||||
"get the active span context as a dict", ret=cast(Dict[str, str], {})
|
||||
)
|
||||
|
||||
24
synapse/synapse_rust/http_client.pyi
Normal file
24
synapse/synapse_rust/http_client.pyi
Normal file
@@ -0,0 +1,24 @@
|
||||
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
#
|
||||
# Copyright (C) 2025 New Vector, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# See the GNU Affero General Public License for more details:
|
||||
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
|
||||
from typing import Awaitable, Mapping
|
||||
|
||||
class HttpClient:
|
||||
def __init__(self, user_agent: str) -> None: ...
|
||||
def get(self, url: str, response_limit: int) -> Awaitable[bytes]: ...
|
||||
def post(
|
||||
self,
|
||||
url: str,
|
||||
response_limit: int,
|
||||
headers: Mapping[str, str],
|
||||
request_body: str,
|
||||
) -> Awaitable[bytes]: ...
|
||||
@@ -19,9 +19,10 @@
|
||||
#
|
||||
#
|
||||
|
||||
import json
|
||||
from http import HTTPStatus
|
||||
from io import BytesIO
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any, Dict, Union
|
||||
from unittest.mock import ANY, AsyncMock, Mock
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
@@ -33,12 +34,11 @@ from signedjson.key import (
|
||||
from signedjson.sign import sign_json
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
from twisted.web.http_headers import Headers
|
||||
from twisted.web.iweb import IResponse
|
||||
|
||||
from synapse.api.errors import (
|
||||
AuthError,
|
||||
Codes,
|
||||
HttpResponseException,
|
||||
InvalidClientTokenError,
|
||||
OAuthInsufficientScopeError,
|
||||
SynapseError,
|
||||
@@ -52,7 +52,7 @@ from synapse.types import JsonDict, UserID
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests.server import FakeChannel
|
||||
from tests.test_utils import FakeResponse, get_awaitable_result
|
||||
from tests.test_utils import get_awaitable_result
|
||||
from tests.unittest import HomeserverTestCase, override_config, skip_unless
|
||||
from tests.utils import HAS_AUTHLIB, checked_cast, mock_getRawHeaders
|
||||
|
||||
@@ -145,6 +145,9 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
|
||||
self.auth = checked_cast(MSC3861DelegatedAuth, hs.get_auth())
|
||||
|
||||
self._rust_client = Mock(spec=["post"])
|
||||
self.auth._rust_http_client = self._rust_client
|
||||
|
||||
return hs
|
||||
|
||||
def prepare(
|
||||
@@ -157,9 +160,15 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
store.store_device(USER_ID, DEVICE, initial_device_display_name=None)
|
||||
)
|
||||
|
||||
def _set_introspection_returnvalue(self, response_value: Any) -> AsyncMock:
|
||||
self._rust_client.post = mock = AsyncMock(
|
||||
return_value=json.dumps(response_value).encode("utf-8")
|
||||
)
|
||||
return mock
|
||||
|
||||
def _assertParams(self) -> None:
|
||||
"""Assert that the request parameters are correct."""
|
||||
params = parse_qs(self.http_client.request.call_args[1]["data"].decode("utf-8"))
|
||||
params = parse_qs(self._rust_client.post.call_args[1]["request_body"])
|
||||
self.assertEqual(params["token"], ["mockAccessToken"])
|
||||
self.assertEqual(params["client_id"], [CLIENT_ID])
|
||||
self.assertEqual(params["client_secret"], [CLIENT_SECRET])
|
||||
@@ -167,128 +176,125 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
def test_inactive_token(self) -> None:
|
||||
"""The handler should return a 403 where the token is inactive."""
|
||||
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={"active": False},
|
||||
)
|
||||
)
|
||||
self._set_introspection_returnvalue({"active": False})
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
|
||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||
self.http_client.request.assert_called_once_with(
|
||||
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
||||
self._rust_client.post.assert_called_once_with(
|
||||
url=INTROSPECTION_ENDPOINT,
|
||||
response_limit=ANY,
|
||||
request_body=ANY,
|
||||
headers=ANY,
|
||||
)
|
||||
self._assertParams()
|
||||
|
||||
def test_active_no_scope(self) -> None:
|
||||
"""The handler should return a 403 where no scope is given."""
|
||||
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={"active": True},
|
||||
)
|
||||
)
|
||||
self._set_introspection_returnvalue({"active": True})
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
|
||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||
self.http_client.request.assert_called_once_with(
|
||||
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
||||
self._rust_client.post.assert_called_once_with(
|
||||
url=INTROSPECTION_ENDPOINT,
|
||||
response_limit=ANY,
|
||||
request_body=ANY,
|
||||
headers=ANY,
|
||||
)
|
||||
self._assertParams()
|
||||
|
||||
def test_active_user_no_subject(self) -> None:
|
||||
"""The handler should return a 500 when no subject is present."""
|
||||
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={"active": True, "scope": " ".join([MATRIX_USER_SCOPE])},
|
||||
)
|
||||
self._set_introspection_returnvalue(
|
||||
{"active": True, "scope": " ".join([MATRIX_USER_SCOPE])}
|
||||
)
|
||||
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
|
||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||
self.http_client.request.assert_called_once_with(
|
||||
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
||||
self._rust_client.post.assert_called_once_with(
|
||||
url=INTROSPECTION_ENDPOINT,
|
||||
response_limit=ANY,
|
||||
request_body=ANY,
|
||||
headers=ANY,
|
||||
)
|
||||
self._assertParams()
|
||||
|
||||
def test_active_no_user_scope(self) -> None:
|
||||
"""The handler should return a 500 when no subject is present."""
|
||||
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_DEVICE_SCOPE]),
|
||||
},
|
||||
)
|
||||
self._set_introspection_returnvalue(
|
||||
{
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_DEVICE_SCOPE]),
|
||||
}
|
||||
)
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
|
||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||
self.http_client.request.assert_called_once_with(
|
||||
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
||||
self._rust_client.post.assert_called_once_with(
|
||||
url=INTROSPECTION_ENDPOINT,
|
||||
response_limit=ANY,
|
||||
request_body=ANY,
|
||||
headers=ANY,
|
||||
)
|
||||
self._assertParams()
|
||||
|
||||
def test_active_admin_not_user(self) -> None:
|
||||
"""The handler should raise when the scope has admin right but not user."""
|
||||
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([SYNAPSE_ADMIN_SCOPE]),
|
||||
"username": USERNAME,
|
||||
},
|
||||
)
|
||||
self._set_introspection_returnvalue(
|
||||
{
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([SYNAPSE_ADMIN_SCOPE]),
|
||||
"username": USERNAME,
|
||||
}
|
||||
)
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
|
||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||
self.http_client.request.assert_called_once_with(
|
||||
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
||||
self._rust_client.post.assert_called_once_with(
|
||||
url=INTROSPECTION_ENDPOINT,
|
||||
response_limit=ANY,
|
||||
request_body=ANY,
|
||||
headers=ANY,
|
||||
)
|
||||
self._assertParams()
|
||||
|
||||
def test_active_admin(self) -> None:
|
||||
"""The handler should return a requester with admin rights."""
|
||||
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([SYNAPSE_ADMIN_SCOPE, MATRIX_USER_SCOPE]),
|
||||
"username": USERNAME,
|
||||
},
|
||||
)
|
||||
self._set_introspection_returnvalue(
|
||||
{
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([SYNAPSE_ADMIN_SCOPE, MATRIX_USER_SCOPE]),
|
||||
"username": USERNAME,
|
||||
}
|
||||
)
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
requester = self.get_success(self.auth.get_user_by_req(request))
|
||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||
self.http_client.request.assert_called_once_with(
|
||||
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
||||
self._rust_client.post.assert_called_once_with(
|
||||
url=INTROSPECTION_ENDPOINT,
|
||||
response_limit=ANY,
|
||||
request_body=ANY,
|
||||
headers=ANY,
|
||||
)
|
||||
self._assertParams()
|
||||
self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
|
||||
@@ -301,26 +307,26 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
def test_active_admin_highest_privilege(self) -> None:
|
||||
"""The handler should resolve to the most permissive scope."""
|
||||
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join(
|
||||
[SYNAPSE_ADMIN_SCOPE, MATRIX_USER_SCOPE, MATRIX_GUEST_SCOPE]
|
||||
),
|
||||
"username": USERNAME,
|
||||
},
|
||||
)
|
||||
self._set_introspection_returnvalue(
|
||||
{
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join(
|
||||
[SYNAPSE_ADMIN_SCOPE, MATRIX_USER_SCOPE, MATRIX_GUEST_SCOPE]
|
||||
),
|
||||
"username": USERNAME,
|
||||
}
|
||||
)
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
requester = self.get_success(self.auth.get_user_by_req(request))
|
||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||
self.http_client.request.assert_called_once_with(
|
||||
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
||||
self._rust_client.post.assert_called_once_with(
|
||||
url=INTROSPECTION_ENDPOINT,
|
||||
response_limit=ANY,
|
||||
request_body=ANY,
|
||||
headers=ANY,
|
||||
)
|
||||
self._assertParams()
|
||||
self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
|
||||
@@ -333,24 +339,24 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
def test_active_user(self) -> None:
|
||||
"""The handler should return a requester with normal user rights."""
|
||||
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_USER_SCOPE]),
|
||||
"username": USERNAME,
|
||||
},
|
||||
)
|
||||
self._set_introspection_returnvalue(
|
||||
{
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_USER_SCOPE]),
|
||||
"username": USERNAME,
|
||||
}
|
||||
)
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
requester = self.get_success(self.auth.get_user_by_req(request))
|
||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||
self.http_client.request.assert_called_once_with(
|
||||
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
||||
self._rust_client.post.assert_called_once_with(
|
||||
url=INTROSPECTION_ENDPOINT,
|
||||
response_limit=ANY,
|
||||
request_body=ANY,
|
||||
headers=ANY,
|
||||
)
|
||||
self._assertParams()
|
||||
self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
|
||||
@@ -363,24 +369,24 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
def test_active_user_with_device(self) -> None:
|
||||
"""The handler should return a requester with normal user rights and a device ID."""
|
||||
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_USER_SCOPE, MATRIX_DEVICE_SCOPE]),
|
||||
"username": USERNAME,
|
||||
},
|
||||
)
|
||||
self._set_introspection_returnvalue(
|
||||
{
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_USER_SCOPE, MATRIX_DEVICE_SCOPE]),
|
||||
"username": USERNAME,
|
||||
}
|
||||
)
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
requester = self.get_success(self.auth.get_user_by_req(request))
|
||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||
self.http_client.request.assert_called_once_with(
|
||||
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
||||
self._rust_client.post.assert_called_once_with(
|
||||
url=INTROSPECTION_ENDPOINT,
|
||||
response_limit=ANY,
|
||||
request_body=ANY,
|
||||
headers=ANY,
|
||||
)
|
||||
self._assertParams()
|
||||
self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
|
||||
@@ -393,32 +399,32 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
def test_active_user_with_device_explicit_device_id(self) -> None:
|
||||
"""The handler should return a requester with normal user rights and a device ID, given explicitly, as supported by MAS 0.15+"""
|
||||
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_USER_SCOPE]),
|
||||
"device_id": DEVICE,
|
||||
"username": USERNAME,
|
||||
},
|
||||
)
|
||||
self._set_introspection_returnvalue(
|
||||
{
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_USER_SCOPE]),
|
||||
"device_id": DEVICE,
|
||||
"username": USERNAME,
|
||||
}
|
||||
)
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
requester = self.get_success(self.auth.get_user_by_req(request))
|
||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||
self.http_client.request.assert_called_once_with(
|
||||
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
||||
self._rust_client.post.assert_called_once_with(
|
||||
url=INTROSPECTION_ENDPOINT,
|
||||
response_limit=ANY,
|
||||
request_body=ANY,
|
||||
headers=ANY,
|
||||
)
|
||||
# It should have called with the 'X-MAS-Supports-Device-Id: 1' header
|
||||
self.assertEqual(
|
||||
self.http_client.request.call_args[1]["headers"].getRawHeaders(
|
||||
b"X-MAS-Supports-Device-Id",
|
||||
self._rust_client.post.call_args[1]["headers"].get(
|
||||
"X-MAS-Supports-Device-Id",
|
||||
),
|
||||
[b"1"],
|
||||
"1",
|
||||
)
|
||||
self._assertParams()
|
||||
self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
|
||||
@@ -431,22 +437,19 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
def test_multiple_devices(self) -> None:
|
||||
"""The handler should raise an error if multiple devices are found in the scope."""
|
||||
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join(
|
||||
[
|
||||
MATRIX_USER_SCOPE,
|
||||
f"{MATRIX_DEVICE_SCOPE_PREFIX}AABBCC",
|
||||
f"{MATRIX_DEVICE_SCOPE_PREFIX}DDEEFF",
|
||||
]
|
||||
),
|
||||
"username": USERNAME,
|
||||
},
|
||||
)
|
||||
self._set_introspection_returnvalue(
|
||||
{
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join(
|
||||
[
|
||||
MATRIX_USER_SCOPE,
|
||||
f"{MATRIX_DEVICE_SCOPE_PREFIX}AABBCC",
|
||||
f"{MATRIX_DEVICE_SCOPE_PREFIX}DDEEFF",
|
||||
]
|
||||
),
|
||||
"username": USERNAME,
|
||||
}
|
||||
)
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
@@ -456,16 +459,13 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
def test_active_guest_not_allowed(self) -> None:
|
||||
"""The handler should return an insufficient scope error."""
|
||||
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_GUEST_SCOPE, MATRIX_DEVICE_SCOPE]),
|
||||
"username": USERNAME,
|
||||
},
|
||||
)
|
||||
self._set_introspection_returnvalue(
|
||||
{
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_GUEST_SCOPE, MATRIX_DEVICE_SCOPE]),
|
||||
"username": USERNAME,
|
||||
}
|
||||
)
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
@@ -474,8 +474,11 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
self.auth.get_user_by_req(request), OAuthInsufficientScopeError
|
||||
)
|
||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||
self.http_client.request.assert_called_once_with(
|
||||
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
||||
self._rust_client.post.assert_called_once_with(
|
||||
url=INTROSPECTION_ENDPOINT,
|
||||
response_limit=ANY,
|
||||
request_body=ANY,
|
||||
headers=ANY,
|
||||
)
|
||||
self._assertParams()
|
||||
self.assertEqual(
|
||||
@@ -486,16 +489,13 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
def test_active_guest_allowed(self) -> None:
|
||||
"""The handler should return a requester with guest user rights and a device ID."""
|
||||
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_GUEST_SCOPE, MATRIX_DEVICE_SCOPE]),
|
||||
"username": USERNAME,
|
||||
},
|
||||
)
|
||||
self._set_introspection_returnvalue(
|
||||
{
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_GUEST_SCOPE, MATRIX_DEVICE_SCOPE]),
|
||||
"username": USERNAME,
|
||||
}
|
||||
)
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
@@ -504,8 +504,11 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
self.auth.get_user_by_req(request, allow_guest=True)
|
||||
)
|
||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||
self.http_client.request.assert_called_once_with(
|
||||
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
||||
self._rust_client.post.assert_called_once_with(
|
||||
url=INTROSPECTION_ENDPOINT,
|
||||
response_limit=ANY,
|
||||
request_body=ANY,
|
||||
headers=ANY,
|
||||
)
|
||||
self._assertParams()
|
||||
self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
|
||||
@@ -522,30 +525,28 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
|
||||
# The introspection endpoint is returning an error.
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse(code=500, body=b"Internal Server Error")
|
||||
)
|
||||
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
|
||||
self.assertEqual(error.value.code, 503)
|
||||
|
||||
# The introspection endpoint request fails.
|
||||
self.http_client.request = AsyncMock(side_effect=Exception())
|
||||
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
|
||||
self.assertEqual(error.value.code, 503)
|
||||
|
||||
# The introspection endpoint does not return a JSON object.
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200, payload=["this is an array", "not an object"]
|
||||
self._rust_client.post = AsyncMock(
|
||||
side_effect=HttpResponseException(
|
||||
code=500, msg="Internal Server Error", response=b"{}"
|
||||
)
|
||||
)
|
||||
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
|
||||
self.assertEqual(error.value.code, 503)
|
||||
|
||||
# The introspection endpoint request fails.
|
||||
self._rust_client.post = AsyncMock(side_effect=Exception())
|
||||
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
|
||||
self.assertEqual(error.value.code, 503)
|
||||
|
||||
# The introspection endpoint does not return a JSON object.
|
||||
self._set_introspection_returnvalue(["this is an array", "not an object"])
|
||||
|
||||
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
|
||||
self.assertEqual(error.value.code, 503)
|
||||
|
||||
# The introspection endpoint does not return valid JSON.
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse(code=200, body=b"this is not valid JSON")
|
||||
)
|
||||
self._set_introspection_returnvalue("this is not valid JSON")
|
||||
|
||||
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
|
||||
self.assertEqual(error.value.code, 503)
|
||||
|
||||
@@ -554,23 +555,21 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
an expiry time, the introspection response is cached and then the entry is
|
||||
re-requested after it has expired."""
|
||||
|
||||
self.http_client.request = introspection_mock = AsyncMock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join(
|
||||
[
|
||||
MATRIX_USER_SCOPE,
|
||||
f"{MATRIX_DEVICE_SCOPE_PREFIX}AABBCC",
|
||||
]
|
||||
),
|
||||
"username": USERNAME,
|
||||
"expires_in": 60,
|
||||
},
|
||||
)
|
||||
introspection_mock = self._set_introspection_returnvalue(
|
||||
{
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join(
|
||||
[
|
||||
MATRIX_USER_SCOPE,
|
||||
f"{MATRIX_DEVICE_SCOPE_PREFIX}AABBCC",
|
||||
]
|
||||
),
|
||||
"username": USERNAME,
|
||||
"expires_in": 60,
|
||||
}
|
||||
)
|
||||
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
@@ -607,16 +606,13 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
def test_cross_signing(self) -> None:
|
||||
"""Try uploading device keys with OAuth delegation enabled."""
|
||||
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_USER_SCOPE, MATRIX_DEVICE_SCOPE]),
|
||||
"username": USERNAME,
|
||||
},
|
||||
)
|
||||
self._set_introspection_returnvalue(
|
||||
{
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_USER_SCOPE, MATRIX_DEVICE_SCOPE]),
|
||||
"username": USERNAME,
|
||||
}
|
||||
)
|
||||
keys_upload_body = self.make_device_keys(USER_ID, DEVICE)
|
||||
channel = self.make_request(
|
||||
@@ -778,16 +774,13 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
|
||||
# Because we still support those endpoints with ASes, it checks the
|
||||
# access token before returning 404
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_USER_SCOPE, MATRIX_DEVICE_SCOPE]),
|
||||
"username": USERNAME,
|
||||
},
|
||||
)
|
||||
self._set_introspection_returnvalue(
|
||||
{
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_USER_SCOPE, MATRIX_DEVICE_SCOPE]),
|
||||
"username": USERNAME,
|
||||
},
|
||||
)
|
||||
|
||||
self.expect_unrecognized("POST", "/_matrix/client/v3/delete_devices", auth=True)
|
||||
@@ -820,9 +813,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
|
||||
def test_admin_token(self) -> None:
|
||||
"""The handler should return a requester with admin rights when admin_token is used."""
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse.json(code=200, payload={"active": False}),
|
||||
)
|
||||
self._set_introspection_returnvalue({"active": False})
|
||||
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"admin_token_value"]
|
||||
@@ -839,7 +830,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
)
|
||||
|
||||
# There should be no call to the introspection endpoint
|
||||
self.http_client.request.assert_not_called()
|
||||
self._rust_client.post.assert_not_called()
|
||||
|
||||
@override_config({"mau_stats_only": True})
|
||||
def test_request_tracking(self) -> None:
|
||||
@@ -852,28 +843,23 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
known_token = "token-token-GOOD-:)"
|
||||
|
||||
async def mock_http_client_request(
|
||||
method: str,
|
||||
uri: str,
|
||||
data: Optional[bytes] = None,
|
||||
headers: Optional[Headers] = None,
|
||||
) -> IResponse:
|
||||
url: str, request_body: str, **kwargs: Any
|
||||
) -> bytes:
|
||||
"""Mocked auth provider response."""
|
||||
assert method == "POST"
|
||||
token = parse_qs(data)[b"token"][0].decode("utf-8")
|
||||
token = parse_qs(request_body)["token"][0]
|
||||
if token == known_token:
|
||||
return FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
return json.dumps(
|
||||
{
|
||||
"active": True,
|
||||
"scope": MATRIX_USER_SCOPE,
|
||||
"sub": SUBJECT,
|
||||
"username": USERNAME,
|
||||
},
|
||||
)
|
||||
).encode("utf-8")
|
||||
|
||||
return FakeResponse.json(code=200, payload={"active": False})
|
||||
return json.dumps({"active": False}).encode("utf-8")
|
||||
|
||||
self.http_client.request = mock_http_client_request
|
||||
self._rust_client.post = mock_http_client_request
|
||||
|
||||
EXAMPLE_IPV4_ADDR = "123.123.123.123"
|
||||
EXAMPLE_USER_AGENT = "httprettygood"
|
||||
|
||||
Reference in New Issue
Block a user