mirror of
https://github.com/element-hq/synapse.git
synced 2025-12-05 01:10:13 +00:00
Compare commits
18 Commits
v1.139.1
...
erikj/rust
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
25c05d4fb4 | ||
|
|
7283b6bcd8 | ||
|
|
661734111d | ||
|
|
f2bd5b042e | ||
|
|
f501987677 | ||
|
|
48d2217f2d | ||
|
|
540b46088a | ||
|
|
d5845d5442 | ||
|
|
a8a34bb811 | ||
|
|
59473e3377 | ||
|
|
236d7a7ef0 | ||
|
|
61b7ed02d8 | ||
|
|
27da4ecde2 | ||
|
|
44a301a32f | ||
|
|
34c65ee95b | ||
|
|
95f3b249c4 | ||
|
|
7254716b7e | ||
|
|
a20a0bcbe2 |
22
.github/workflows/tests.yml
vendored
22
.github/workflows/tests.yml
vendored
@@ -85,7 +85,7 @@ jobs:
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
- name: Install Rust
|
||||
uses: dtolnay/rust-toolchain@e05ebb0e73db581a4877c6ce762e29fe1e0b5073 # 1.66.0
|
||||
uses: dtolnay/rust-toolchain@c1678930c21fb233e4987c4ae12158f9125e5762 # 1.81.0
|
||||
- uses: Swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2.7.8
|
||||
- uses: matrix-org/setup-python-poetry@5bbf6603c5c930615ec8a29f1b5d7d258d905aa4 # v2.0.0
|
||||
with:
|
||||
@@ -149,7 +149,7 @@ jobs:
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
|
||||
- name: Install Rust
|
||||
uses: dtolnay/rust-toolchain@e05ebb0e73db581a4877c6ce762e29fe1e0b5073 # 1.66.0
|
||||
uses: dtolnay/rust-toolchain@c1678930c21fb233e4987c4ae12158f9125e5762 # 1.81.0
|
||||
- uses: Swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2.7.8
|
||||
|
||||
- name: Setup Poetry
|
||||
@@ -210,7 +210,7 @@ jobs:
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha }}
|
||||
- name: Install Rust
|
||||
uses: dtolnay/rust-toolchain@e05ebb0e73db581a4877c6ce762e29fe1e0b5073 # 1.66.0
|
||||
uses: dtolnay/rust-toolchain@c1678930c21fb233e4987c4ae12158f9125e5762 # 1.81.0
|
||||
- uses: Swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2.7.8
|
||||
- uses: matrix-org/setup-python-poetry@5bbf6603c5c930615ec8a29f1b5d7d258d905aa4 # v2.0.0
|
||||
with:
|
||||
@@ -227,7 +227,7 @@ jobs:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
|
||||
- name: Install Rust
|
||||
uses: dtolnay/rust-toolchain@e05ebb0e73db581a4877c6ce762e29fe1e0b5073 # 1.66.0
|
||||
uses: dtolnay/rust-toolchain@0d72692bcfbf448b1e2afa01a67f71b455a9dcec # 1.86.0
|
||||
with:
|
||||
components: clippy
|
||||
- uses: Swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2.7.8
|
||||
@@ -247,7 +247,7 @@ jobs:
|
||||
- name: Install Rust
|
||||
uses: dtolnay/rust-toolchain@56f84321dbccf38fb67ce29ab63e4754056677e0 # master (rust 1.85.1)
|
||||
with:
|
||||
toolchain: nightly-2022-12-01
|
||||
toolchain: nightly-2025-04-23
|
||||
components: clippy
|
||||
- uses: Swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2.7.8
|
||||
|
||||
@@ -265,7 +265,7 @@ jobs:
|
||||
uses: dtolnay/rust-toolchain@56f84321dbccf38fb67ce29ab63e4754056677e0 # master (rust 1.85.1)
|
||||
with:
|
||||
# We use nightly so that it correctly groups together imports
|
||||
toolchain: nightly-2022-12-01
|
||||
toolchain: nightly-2025-04-23
|
||||
components: rustfmt
|
||||
- uses: Swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2.7.8
|
||||
|
||||
@@ -362,7 +362,7 @@ jobs:
|
||||
postgres:${{ matrix.job.postgres-version }}
|
||||
|
||||
- name: Install Rust
|
||||
uses: dtolnay/rust-toolchain@e05ebb0e73db581a4877c6ce762e29fe1e0b5073 # 1.66.0
|
||||
uses: dtolnay/rust-toolchain@c1678930c21fb233e4987c4ae12158f9125e5762 # 1.81.0
|
||||
- uses: Swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2.7.8
|
||||
|
||||
- uses: matrix-org/setup-python-poetry@5bbf6603c5c930615ec8a29f1b5d7d258d905aa4 # v2.0.0
|
||||
@@ -404,7 +404,7 @@ jobs:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
|
||||
- name: Install Rust
|
||||
uses: dtolnay/rust-toolchain@e05ebb0e73db581a4877c6ce762e29fe1e0b5073 # 1.66.0
|
||||
uses: dtolnay/rust-toolchain@c1678930c21fb233e4987c4ae12158f9125e5762 # 1.81.0
|
||||
- uses: Swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2.7.8
|
||||
|
||||
# There aren't wheels for some of the older deps, so we need to install
|
||||
@@ -519,7 +519,7 @@ jobs:
|
||||
run: cat sytest-blacklist .ci/worker-blacklist > synapse-blacklist-with-workers
|
||||
|
||||
- name: Install Rust
|
||||
uses: dtolnay/rust-toolchain@e05ebb0e73db581a4877c6ce762e29fe1e0b5073 # 1.66.0
|
||||
uses: dtolnay/rust-toolchain@c1678930c21fb233e4987c4ae12158f9125e5762 # 1.81.0
|
||||
- uses: Swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2.7.8
|
||||
|
||||
- name: Run SyTest
|
||||
@@ -663,7 +663,7 @@ jobs:
|
||||
path: synapse
|
||||
|
||||
- name: Install Rust
|
||||
uses: dtolnay/rust-toolchain@e05ebb0e73db581a4877c6ce762e29fe1e0b5073 # 1.66.0
|
||||
uses: dtolnay/rust-toolchain@c1678930c21fb233e4987c4ae12158f9125e5762 # 1.81.0
|
||||
- uses: Swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2.7.8
|
||||
|
||||
- name: Prepare Complement's Prerequisites
|
||||
@@ -695,7 +695,7 @@ jobs:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
|
||||
- name: Install Rust
|
||||
uses: dtolnay/rust-toolchain@e05ebb0e73db581a4877c6ce762e29fe1e0b5073 # 1.66.0
|
||||
uses: dtolnay/rust-toolchain@c1678930c21fb233e4987c4ae12158f9125e5762 # 1.81.0
|
||||
- uses: Swatinem/rust-cache@9d47c6ad4b02e050fd481d890b2ea34778fd09d6 # v2.7.8
|
||||
|
||||
- run: cargo test
|
||||
|
||||
1333
Cargo.lock
generated
1333
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
1
changelog.d/18357.misc
Normal file
1
changelog.d/18357.misc
Normal file
@@ -0,0 +1 @@
|
||||
Increase performance of introspecting access tokens when using delegated auth.
|
||||
@@ -7,7 +7,7 @@ name = "synapse"
|
||||
version = "0.1.0"
|
||||
|
||||
edition = "2021"
|
||||
rust-version = "1.66.0"
|
||||
rust-version = "1.81.0"
|
||||
|
||||
[lib]
|
||||
name = "synapse"
|
||||
@@ -36,13 +36,21 @@ pyo3 = { version = "0.24.2", features = [
|
||||
"abi3",
|
||||
"abi3-py39",
|
||||
] }
|
||||
pyo3-log = "0.12.0"
|
||||
pyo3-log = "0.12.3"
|
||||
pythonize = "0.24.0"
|
||||
regex = "1.6.0"
|
||||
sha2 = "0.10.8"
|
||||
serde = { version = "1.0.144", features = ["derive"] }
|
||||
serde_json = "1.0.85"
|
||||
ulid = "1.1.2"
|
||||
reqwest = { version = "0.12.15", default-features = false, features = [
|
||||
"http2",
|
||||
"stream",
|
||||
"rustls-tls-native-roots",
|
||||
] }
|
||||
http-body-util = "0.1.3"
|
||||
futures = "0.3.31"
|
||||
tokio = { version = "1.44.2", features = ["rt", "rt-multi-thread"] }
|
||||
|
||||
[features]
|
||||
extension-module = ["pyo3/extension-module"]
|
||||
|
||||
@@ -58,3 +58,15 @@ impl NotFoundError {
|
||||
NotFoundError::new_err(())
|
||||
}
|
||||
}
|
||||
|
||||
import_exception!(synapse.api.errors, HttpResponseException);
|
||||
|
||||
impl HttpResponseException {
|
||||
pub fn new(status: StatusCode, bytes: Vec<u8>) -> pyo3::PyErr {
|
||||
HttpResponseException::new_err((
|
||||
status.as_u16(),
|
||||
status.canonical_reason().unwrap_or_default(),
|
||||
bytes,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
218
rust/src/http_client.rs
Normal file
218
rust/src/http_client.rs
Normal file
@@ -0,0 +1,218 @@
|
||||
/*
|
||||
* This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
*
|
||||
* Copyright (C) 2025 New Vector, Ltd
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as
|
||||
* published by the Free Software Foundation, either version 3 of the
|
||||
* License, or (at your option) any later version.
|
||||
*
|
||||
* See the GNU Affero General Public License for more details:
|
||||
* <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
*/
|
||||
|
||||
use std::{collections::HashMap, future::Future, panic::AssertUnwindSafe, sync::LazyLock};
|
||||
|
||||
use anyhow::Context;
|
||||
use futures::{FutureExt, TryStreamExt};
|
||||
use pyo3::{exceptions::PyException, prelude::*, types::PyString};
|
||||
use reqwest::RequestBuilder;
|
||||
use tokio::runtime::Runtime;
|
||||
|
||||
use crate::errors::HttpResponseException;
|
||||
|
||||
/// The tokio runtime that we're using to run async Rust libs.
|
||||
static RUNTIME: LazyLock<Runtime> = LazyLock::new(|| {
|
||||
tokio::runtime::Builder::new_multi_thread()
|
||||
.worker_threads(4)
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
/// A reference to the `Deferred` python class.
|
||||
static DEFERRED_CLASS: LazyLock<PyObject> = LazyLock::new(|| {
|
||||
Python::with_gil(|py| {
|
||||
py.import("twisted.internet.defer")
|
||||
.expect("module 'twisted.internet.defer' should be importable")
|
||||
.getattr("Deferred")
|
||||
.expect("module 'twisted.internet.defer' should have a 'Deferred' class")
|
||||
.unbind()
|
||||
})
|
||||
});
|
||||
|
||||
/// A reference to the twisted `reactor`.
|
||||
static TWISTED_REACTOR: LazyLock<Py<PyModule>> = LazyLock::new(|| {
|
||||
Python::with_gil(|py| {
|
||||
py.import("twisted.internet.reactor")
|
||||
.expect("module 'twisted.internet.reactor' should be importable")
|
||||
.unbind()
|
||||
})
|
||||
});
|
||||
|
||||
/// Called when registering modules with python.
|
||||
pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
let child_module: Bound<'_, PyModule> = PyModule::new(py, "http_client")?;
|
||||
child_module.add_class::<HttpClient>()?;
|
||||
|
||||
// Make sure we fail early if we can't build the lazy statics.
|
||||
LazyLock::force(&RUNTIME);
|
||||
LazyLock::force(&DEFERRED_CLASS);
|
||||
|
||||
m.add_submodule(&child_module)?;
|
||||
|
||||
// We need to manually add the module to sys.modules to make `from
|
||||
// synapse.synapse_rust import acl` work.
|
||||
py.import("sys")?
|
||||
.getattr("modules")?
|
||||
.set_item("synapse.synapse_rust.http_client", child_module)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
#[derive(Clone)]
|
||||
struct HttpClient {
|
||||
client: reqwest::Client,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl HttpClient {
|
||||
#[new]
|
||||
pub fn py_new(user_agent: &str) -> PyResult<HttpClient> {
|
||||
// The twisted reactor can only be imported after Synapse has been
|
||||
// imported, to allow Synapse to change the twisted reactor. If we try
|
||||
// and import the reactor too early twisted installs a default reactor,
|
||||
// which can't be replaced.
|
||||
LazyLock::force(&TWISTED_REACTOR);
|
||||
|
||||
Ok(HttpClient {
|
||||
client: reqwest::Client::builder()
|
||||
.user_agent(user_agent)
|
||||
.build()
|
||||
.context("building reqwest client")?,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get<'a>(
|
||||
&self,
|
||||
py: Python<'a>,
|
||||
url: String,
|
||||
response_limit: usize,
|
||||
) -> PyResult<Bound<'a, PyAny>> {
|
||||
self.send_request(py, self.client.get(url), response_limit)
|
||||
}
|
||||
|
||||
pub fn post<'a>(
|
||||
&self,
|
||||
py: Python<'a>,
|
||||
url: String,
|
||||
response_limit: usize,
|
||||
headers: HashMap<String, String>,
|
||||
request_body: String,
|
||||
) -> PyResult<Bound<'a, PyAny>> {
|
||||
let mut builder = self.client.post(url);
|
||||
for (name, value) in headers {
|
||||
builder = builder.header(name, value);
|
||||
}
|
||||
builder = builder.body(request_body);
|
||||
|
||||
self.send_request(py, builder, response_limit)
|
||||
}
|
||||
}
|
||||
|
||||
impl HttpClient {
|
||||
fn send_request<'a>(
|
||||
&self,
|
||||
py: Python<'a>,
|
||||
builder: RequestBuilder,
|
||||
response_limit: usize,
|
||||
) -> PyResult<Bound<'a, PyAny>> {
|
||||
create_deferred(py, async move {
|
||||
let response = builder.send().await.context("sending request")?;
|
||||
|
||||
let status = response.status();
|
||||
|
||||
let mut stream = response.bytes_stream();
|
||||
let mut buffer = Vec::new();
|
||||
while let Some(chunk) = stream.try_next().await.context("reading body")? {
|
||||
if buffer.len() + chunk.len() > response_limit {
|
||||
Err(anyhow::anyhow!("Response size too large"))?;
|
||||
}
|
||||
|
||||
buffer.extend_from_slice(&chunk);
|
||||
}
|
||||
|
||||
if !status.is_success() {
|
||||
return Err(HttpResponseException::new(status, buffer));
|
||||
}
|
||||
|
||||
let r = Python::with_gil(|py| buffer.into_pyobject(py).map(|o| o.unbind()))?;
|
||||
|
||||
Ok(r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a twisted deferred from the given future, spawning the task on the
|
||||
/// tokio runtime.
|
||||
///
|
||||
/// Does not handle deferred cancellation or contextvars.
|
||||
fn create_deferred<F, O>(py: Python, fut: F) -> PyResult<Bound<'_, PyAny>>
|
||||
where
|
||||
F: Future<Output = PyResult<O>> + Send + 'static,
|
||||
for<'a> O: IntoPyObject<'a>,
|
||||
{
|
||||
let deferred = DEFERRED_CLASS.bind(py).call0()?;
|
||||
let deferred_callback = deferred.getattr("callback")?.unbind();
|
||||
let deferred_errback = deferred.getattr("errback")?.unbind();
|
||||
|
||||
RUNTIME.spawn(async move {
|
||||
// TODO: Is it safe to assert unwind safety here? I think so, as we
|
||||
// don't use anything that could be tainted by the panic afterwards.
|
||||
// Note that `.spawn(..)` asserts unwind safety on the future too.
|
||||
let res = AssertUnwindSafe(fut).catch_unwind().await;
|
||||
|
||||
Python::with_gil(move |py| {
|
||||
// Flatten the panic into standard python error
|
||||
let res = match res {
|
||||
Ok(r) => r,
|
||||
Err(panic_err) => {
|
||||
let panic_message = get_panic_message(&panic_err);
|
||||
Err(PyException::new_err(
|
||||
PyString::new(py, panic_message).unbind(),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// Send the result to the deferred, via `.callback(..)` or `.errback(..)`
|
||||
match res {
|
||||
Ok(obj) => {
|
||||
TWISTED_REACTOR
|
||||
.call_method(py, "callFromThread", (deferred_callback, obj), None)
|
||||
.expect("callFromThread should not fail"); // There's nothing we can really do with errors here
|
||||
}
|
||||
Err(err) => {
|
||||
TWISTED_REACTOR
|
||||
.call_method(py, "callFromThread", (deferred_errback, err), None)
|
||||
.expect("callFromThread should not fail"); // There's nothing we can really do with errors here
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
Ok(deferred)
|
||||
}
|
||||
|
||||
/// Try and get the panic message out of the panic
|
||||
fn get_panic_message<'a>(panic_err: &'a (dyn std::any::Any + Send + 'static)) -> &'a str {
|
||||
// Apparently this is how you extract the panic message from a panic
|
||||
if let Some(str_slice) = panic_err.downcast_ref::<&str>() {
|
||||
str_slice
|
||||
} else if let Some(string) = panic_err.downcast_ref::<String>() {
|
||||
string
|
||||
} else {
|
||||
"unknown error"
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,7 @@ pub mod acl;
|
||||
pub mod errors;
|
||||
pub mod events;
|
||||
pub mod http;
|
||||
pub mod http_client;
|
||||
pub mod identifier;
|
||||
pub mod matrix_const;
|
||||
pub mod push;
|
||||
@@ -50,6 +51,7 @@ fn synapse_rust(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
acl::register_module(py, m)?;
|
||||
push::register_module(py, m)?;
|
||||
events::register_module(py, m)?;
|
||||
http_client::register_module(py, m)?;
|
||||
rendezvous::register_module(py, m)?;
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -30,9 +30,6 @@ from authlib.oauth2.rfc7662 import IntrospectionToken
|
||||
from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from twisted.web.client import readBody
|
||||
from twisted.web.http_headers import Headers
|
||||
|
||||
from synapse.api.auth.base import BaseAuth
|
||||
from synapse.api.errors import (
|
||||
AuthError,
|
||||
@@ -43,8 +40,14 @@ from synapse.api.errors import (
|
||||
UnrecognizedRequestError,
|
||||
)
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.logging.opentracing import active_span, force_tracing, start_active_span
|
||||
from synapse.logging.context import PreserveLoggingContext
|
||||
from synapse.logging.opentracing import (
|
||||
active_span,
|
||||
force_tracing,
|
||||
inject_request_headers,
|
||||
start_active_span,
|
||||
)
|
||||
from synapse.synapse_rust.http_client import HttpClient
|
||||
from synapse.types import Requester, UserID, create_requester
|
||||
from synapse.util import json_decoder
|
||||
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
|
||||
@@ -179,6 +182,10 @@ class MSC3861DelegatedAuth(BaseAuth):
|
||||
self._admin_token: Callable[[], Optional[str]] = self._config.admin_token
|
||||
self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users
|
||||
|
||||
self._rust_http_client = HttpClient(
|
||||
user_agent=self._http_client.user_agent.decode("utf8")
|
||||
)
|
||||
|
||||
# # Token Introspection Cache
|
||||
# This remembers what users/devices are represented by which access tokens,
|
||||
# in order to reduce overall system load:
|
||||
@@ -301,7 +308,6 @@ class MSC3861DelegatedAuth(BaseAuth):
|
||||
introspection_endpoint = await self._introspection_endpoint()
|
||||
raw_headers: Dict[str, str] = {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"User-Agent": str(self._http_client.user_agent, "utf-8"),
|
||||
"Accept": "application/json",
|
||||
# Tell MAS that we support reading the device ID as an explicit
|
||||
# value, not encoded in the scope. This is supported by MAS 0.15+
|
||||
@@ -315,38 +321,34 @@ class MSC3861DelegatedAuth(BaseAuth):
|
||||
uri, raw_headers, body = self._client_auth.prepare(
|
||||
method="POST", uri=introspection_endpoint, headers=raw_headers, body=body
|
||||
)
|
||||
headers = Headers({k: [v] for (k, v) in raw_headers.items()})
|
||||
|
||||
# Do the actual request
|
||||
# We're not using the SimpleHttpClient util methods as we don't want to
|
||||
# check the HTTP status code, and we do the body encoding ourselves.
|
||||
|
||||
logger.debug("Fetching token from MAS")
|
||||
start_time = self._clock.time()
|
||||
try:
|
||||
response = await self._http_client.request(
|
||||
method="POST",
|
||||
uri=uri,
|
||||
data=body.encode("utf-8"),
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
resp_body = await make_deferred_yieldable(readBody(response))
|
||||
with start_active_span("mas-introspect-token"):
|
||||
inject_request_headers(raw_headers)
|
||||
with PreserveLoggingContext():
|
||||
resp_body = await self._rust_http_client.post(
|
||||
url=uri,
|
||||
response_limit=1 * 1024 * 1024,
|
||||
headers=raw_headers,
|
||||
request_body=body,
|
||||
)
|
||||
except HttpResponseException as e:
|
||||
end_time = self._clock.time()
|
||||
introspection_response_timer.labels(e.code).observe(end_time - start_time)
|
||||
raise
|
||||
except Exception:
|
||||
end_time = self._clock.time()
|
||||
introspection_response_timer.labels("ERR").observe(end_time - start_time)
|
||||
raise
|
||||
|
||||
end_time = self._clock.time()
|
||||
introspection_response_timer.labels(response.code).observe(
|
||||
end_time - start_time
|
||||
)
|
||||
logger.debug("Fetched token from MAS")
|
||||
|
||||
if response.code < 200 or response.code >= 300:
|
||||
raise HttpResponseException(
|
||||
response.code,
|
||||
response.phrase.decode("ascii", errors="replace"),
|
||||
resp_body,
|
||||
)
|
||||
end_time = self._clock.time()
|
||||
introspection_response_timer.labels(200).observe(end_time - start_time)
|
||||
|
||||
resp = json_decoder.decode(resp_body.decode("utf-8"))
|
||||
|
||||
|
||||
@@ -796,6 +796,13 @@ def inject_response_headers(response_headers: Headers) -> None:
|
||||
response_headers.addRawHeader("Synapse-Trace-Id", f"{trace_id:x}")
|
||||
|
||||
|
||||
@ensure_active_span("inject the span into a header dict")
|
||||
def inject_request_headers(headers: Dict[str, str]) -> None:
|
||||
span = opentracing.tracer.active_span
|
||||
assert span is not None
|
||||
opentracing.tracer.inject(span.context, opentracing.Format.HTTP_HEADERS, headers)
|
||||
|
||||
|
||||
@ensure_active_span(
|
||||
"get the active span context as a dict", ret=cast(Dict[str, str], {})
|
||||
)
|
||||
|
||||
24
synapse/synapse_rust/http_client.pyi
Normal file
24
synapse/synapse_rust/http_client.pyi
Normal file
@@ -0,0 +1,24 @@
|
||||
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
#
|
||||
# Copyright (C) 2025 New Vector, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# See the GNU Affero General Public License for more details:
|
||||
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
|
||||
from typing import Awaitable, Mapping
|
||||
|
||||
class HttpClient:
|
||||
def __init__(self, user_agent: str) -> None: ...
|
||||
def get(self, url: str, response_limit: int) -> Awaitable[bytes]: ...
|
||||
def post(
|
||||
self,
|
||||
url: str,
|
||||
response_limit: int,
|
||||
headers: Mapping[str, str],
|
||||
request_body: str,
|
||||
) -> Awaitable[bytes]: ...
|
||||
@@ -19,9 +19,10 @@
|
||||
#
|
||||
#
|
||||
|
||||
import json
|
||||
from http import HTTPStatus
|
||||
from io import BytesIO
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any, Dict, Union
|
||||
from unittest.mock import ANY, AsyncMock, Mock
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
@@ -33,12 +34,11 @@ from signedjson.key import (
|
||||
from signedjson.sign import sign_json
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
from twisted.web.http_headers import Headers
|
||||
from twisted.web.iweb import IResponse
|
||||
|
||||
from synapse.api.errors import (
|
||||
AuthError,
|
||||
Codes,
|
||||
HttpResponseException,
|
||||
InvalidClientTokenError,
|
||||
OAuthInsufficientScopeError,
|
||||
SynapseError,
|
||||
@@ -52,7 +52,7 @@ from synapse.types import JsonDict, UserID
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests.server import FakeChannel
|
||||
from tests.test_utils import FakeResponse, get_awaitable_result
|
||||
from tests.test_utils import get_awaitable_result
|
||||
from tests.unittest import HomeserverTestCase, override_config, skip_unless
|
||||
from tests.utils import HAS_AUTHLIB, checked_cast, mock_getRawHeaders
|
||||
|
||||
@@ -145,6 +145,9 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
|
||||
self.auth = checked_cast(MSC3861DelegatedAuth, hs.get_auth())
|
||||
|
||||
self._rust_client = Mock(spec=["post"])
|
||||
self.auth._rust_http_client = self._rust_client
|
||||
|
||||
return hs
|
||||
|
||||
def prepare(
|
||||
@@ -157,9 +160,15 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
store.store_device(USER_ID, DEVICE, initial_device_display_name=None)
|
||||
)
|
||||
|
||||
def _set_introspection_returnvalue(self, response_value: Any) -> AsyncMock:
|
||||
self._rust_client.post = mock = AsyncMock(
|
||||
return_value=json.dumps(response_value).encode("utf-8")
|
||||
)
|
||||
return mock
|
||||
|
||||
def _assertParams(self) -> None:
|
||||
"""Assert that the request parameters are correct."""
|
||||
params = parse_qs(self.http_client.request.call_args[1]["data"].decode("utf-8"))
|
||||
params = parse_qs(self._rust_client.post.call_args[1]["request_body"])
|
||||
self.assertEqual(params["token"], ["mockAccessToken"])
|
||||
self.assertEqual(params["client_id"], [CLIENT_ID])
|
||||
self.assertEqual(params["client_secret"], [CLIENT_SECRET])
|
||||
@@ -167,128 +176,125 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
def test_inactive_token(self) -> None:
|
||||
"""The handler should return a 403 where the token is inactive."""
|
||||
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={"active": False},
|
||||
)
|
||||
)
|
||||
self._set_introspection_returnvalue({"active": False})
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
|
||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||
self.http_client.request.assert_called_once_with(
|
||||
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
||||
self._rust_client.post.assert_called_once_with(
|
||||
url=INTROSPECTION_ENDPOINT,
|
||||
response_limit=ANY,
|
||||
request_body=ANY,
|
||||
headers=ANY,
|
||||
)
|
||||
self._assertParams()
|
||||
|
||||
def test_active_no_scope(self) -> None:
|
||||
"""The handler should return a 403 where no scope is given."""
|
||||
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={"active": True},
|
||||
)
|
||||
)
|
||||
self._set_introspection_returnvalue({"active": True})
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
|
||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||
self.http_client.request.assert_called_once_with(
|
||||
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
||||
self._rust_client.post.assert_called_once_with(
|
||||
url=INTROSPECTION_ENDPOINT,
|
||||
response_limit=ANY,
|
||||
request_body=ANY,
|
||||
headers=ANY,
|
||||
)
|
||||
self._assertParams()
|
||||
|
||||
def test_active_user_no_subject(self) -> None:
|
||||
"""The handler should return a 500 when no subject is present."""
|
||||
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={"active": True, "scope": " ".join([MATRIX_USER_SCOPE])},
|
||||
)
|
||||
self._set_introspection_returnvalue(
|
||||
{"active": True, "scope": " ".join([MATRIX_USER_SCOPE])}
|
||||
)
|
||||
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
|
||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||
self.http_client.request.assert_called_once_with(
|
||||
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
||||
self._rust_client.post.assert_called_once_with(
|
||||
url=INTROSPECTION_ENDPOINT,
|
||||
response_limit=ANY,
|
||||
request_body=ANY,
|
||||
headers=ANY,
|
||||
)
|
||||
self._assertParams()
|
||||
|
||||
def test_active_no_user_scope(self) -> None:
|
||||
"""The handler should return a 500 when no subject is present."""
|
||||
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_DEVICE_SCOPE]),
|
||||
},
|
||||
)
|
||||
self._set_introspection_returnvalue(
|
||||
{
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_DEVICE_SCOPE]),
|
||||
}
|
||||
)
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
|
||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||
self.http_client.request.assert_called_once_with(
|
||||
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
||||
self._rust_client.post.assert_called_once_with(
|
||||
url=INTROSPECTION_ENDPOINT,
|
||||
response_limit=ANY,
|
||||
request_body=ANY,
|
||||
headers=ANY,
|
||||
)
|
||||
self._assertParams()
|
||||
|
||||
def test_active_admin_not_user(self) -> None:
|
||||
"""The handler should raise when the scope has admin right but not user."""
|
||||
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([SYNAPSE_ADMIN_SCOPE]),
|
||||
"username": USERNAME,
|
||||
},
|
||||
)
|
||||
self._set_introspection_returnvalue(
|
||||
{
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([SYNAPSE_ADMIN_SCOPE]),
|
||||
"username": USERNAME,
|
||||
}
|
||||
)
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError)
|
||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||
self.http_client.request.assert_called_once_with(
|
||||
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
||||
self._rust_client.post.assert_called_once_with(
|
||||
url=INTROSPECTION_ENDPOINT,
|
||||
response_limit=ANY,
|
||||
request_body=ANY,
|
||||
headers=ANY,
|
||||
)
|
||||
self._assertParams()
|
||||
|
||||
def test_active_admin(self) -> None:
|
||||
"""The handler should return a requester with admin rights."""
|
||||
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([SYNAPSE_ADMIN_SCOPE, MATRIX_USER_SCOPE]),
|
||||
"username": USERNAME,
|
||||
},
|
||||
)
|
||||
self._set_introspection_returnvalue(
|
||||
{
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([SYNAPSE_ADMIN_SCOPE, MATRIX_USER_SCOPE]),
|
||||
"username": USERNAME,
|
||||
}
|
||||
)
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
requester = self.get_success(self.auth.get_user_by_req(request))
|
||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||
self.http_client.request.assert_called_once_with(
|
||||
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
||||
self._rust_client.post.assert_called_once_with(
|
||||
url=INTROSPECTION_ENDPOINT,
|
||||
response_limit=ANY,
|
||||
request_body=ANY,
|
||||
headers=ANY,
|
||||
)
|
||||
self._assertParams()
|
||||
self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
|
||||
@@ -301,26 +307,26 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
def test_active_admin_highest_privilege(self) -> None:
|
||||
"""The handler should resolve to the most permissive scope."""
|
||||
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join(
|
||||
[SYNAPSE_ADMIN_SCOPE, MATRIX_USER_SCOPE, MATRIX_GUEST_SCOPE]
|
||||
),
|
||||
"username": USERNAME,
|
||||
},
|
||||
)
|
||||
self._set_introspection_returnvalue(
|
||||
{
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join(
|
||||
[SYNAPSE_ADMIN_SCOPE, MATRIX_USER_SCOPE, MATRIX_GUEST_SCOPE]
|
||||
),
|
||||
"username": USERNAME,
|
||||
}
|
||||
)
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
requester = self.get_success(self.auth.get_user_by_req(request))
|
||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||
self.http_client.request.assert_called_once_with(
|
||||
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
||||
self._rust_client.post.assert_called_once_with(
|
||||
url=INTROSPECTION_ENDPOINT,
|
||||
response_limit=ANY,
|
||||
request_body=ANY,
|
||||
headers=ANY,
|
||||
)
|
||||
self._assertParams()
|
||||
self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
|
||||
@@ -333,24 +339,24 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
def test_active_user(self) -> None:
|
||||
"""The handler should return a requester with normal user rights."""
|
||||
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_USER_SCOPE]),
|
||||
"username": USERNAME,
|
||||
},
|
||||
)
|
||||
self._set_introspection_returnvalue(
|
||||
{
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_USER_SCOPE]),
|
||||
"username": USERNAME,
|
||||
}
|
||||
)
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
requester = self.get_success(self.auth.get_user_by_req(request))
|
||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||
self.http_client.request.assert_called_once_with(
|
||||
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
||||
self._rust_client.post.assert_called_once_with(
|
||||
url=INTROSPECTION_ENDPOINT,
|
||||
response_limit=ANY,
|
||||
request_body=ANY,
|
||||
headers=ANY,
|
||||
)
|
||||
self._assertParams()
|
||||
self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
|
||||
@@ -363,24 +369,24 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
def test_active_user_with_device(self) -> None:
|
||||
"""The handler should return a requester with normal user rights and a device ID."""
|
||||
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_USER_SCOPE, MATRIX_DEVICE_SCOPE]),
|
||||
"username": USERNAME,
|
||||
},
|
||||
)
|
||||
self._set_introspection_returnvalue(
|
||||
{
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_USER_SCOPE, MATRIX_DEVICE_SCOPE]),
|
||||
"username": USERNAME,
|
||||
}
|
||||
)
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
requester = self.get_success(self.auth.get_user_by_req(request))
|
||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||
self.http_client.request.assert_called_once_with(
|
||||
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
||||
self._rust_client.post.assert_called_once_with(
|
||||
url=INTROSPECTION_ENDPOINT,
|
||||
response_limit=ANY,
|
||||
request_body=ANY,
|
||||
headers=ANY,
|
||||
)
|
||||
self._assertParams()
|
||||
self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
|
||||
@@ -393,32 +399,32 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
def test_active_user_with_device_explicit_device_id(self) -> None:
|
||||
"""The handler should return a requester with normal user rights and a device ID, given explicitly, as supported by MAS 0.15+"""
|
||||
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_USER_SCOPE]),
|
||||
"device_id": DEVICE,
|
||||
"username": USERNAME,
|
||||
},
|
||||
)
|
||||
self._set_introspection_returnvalue(
|
||||
{
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_USER_SCOPE]),
|
||||
"device_id": DEVICE,
|
||||
"username": USERNAME,
|
||||
}
|
||||
)
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
requester = self.get_success(self.auth.get_user_by_req(request))
|
||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||
self.http_client.request.assert_called_once_with(
|
||||
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
||||
self._rust_client.post.assert_called_once_with(
|
||||
url=INTROSPECTION_ENDPOINT,
|
||||
response_limit=ANY,
|
||||
request_body=ANY,
|
||||
headers=ANY,
|
||||
)
|
||||
# It should have called with the 'X-MAS-Supports-Device-Id: 1' header
|
||||
self.assertEqual(
|
||||
self.http_client.request.call_args[1]["headers"].getRawHeaders(
|
||||
b"X-MAS-Supports-Device-Id",
|
||||
self._rust_client.post.call_args[1]["headers"].get(
|
||||
"X-MAS-Supports-Device-Id",
|
||||
),
|
||||
[b"1"],
|
||||
"1",
|
||||
)
|
||||
self._assertParams()
|
||||
self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
|
||||
@@ -431,22 +437,19 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
def test_multiple_devices(self) -> None:
|
||||
"""The handler should raise an error if multiple devices are found in the scope."""
|
||||
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join(
|
||||
[
|
||||
MATRIX_USER_SCOPE,
|
||||
f"{MATRIX_DEVICE_SCOPE_PREFIX}AABBCC",
|
||||
f"{MATRIX_DEVICE_SCOPE_PREFIX}DDEEFF",
|
||||
]
|
||||
),
|
||||
"username": USERNAME,
|
||||
},
|
||||
)
|
||||
self._set_introspection_returnvalue(
|
||||
{
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join(
|
||||
[
|
||||
MATRIX_USER_SCOPE,
|
||||
f"{MATRIX_DEVICE_SCOPE_PREFIX}AABBCC",
|
||||
f"{MATRIX_DEVICE_SCOPE_PREFIX}DDEEFF",
|
||||
]
|
||||
),
|
||||
"username": USERNAME,
|
||||
}
|
||||
)
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
@@ -456,16 +459,13 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
def test_active_guest_not_allowed(self) -> None:
|
||||
"""The handler should return an insufficient scope error."""
|
||||
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_GUEST_SCOPE, MATRIX_DEVICE_SCOPE]),
|
||||
"username": USERNAME,
|
||||
},
|
||||
)
|
||||
self._set_introspection_returnvalue(
|
||||
{
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_GUEST_SCOPE, MATRIX_DEVICE_SCOPE]),
|
||||
"username": USERNAME,
|
||||
}
|
||||
)
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
@@ -474,8 +474,11 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
self.auth.get_user_by_req(request), OAuthInsufficientScopeError
|
||||
)
|
||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||
self.http_client.request.assert_called_once_with(
|
||||
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
||||
self._rust_client.post.assert_called_once_with(
|
||||
url=INTROSPECTION_ENDPOINT,
|
||||
response_limit=ANY,
|
||||
request_body=ANY,
|
||||
headers=ANY,
|
||||
)
|
||||
self._assertParams()
|
||||
self.assertEqual(
|
||||
@@ -486,16 +489,13 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
def test_active_guest_allowed(self) -> None:
|
||||
"""The handler should return a requester with guest user rights and a device ID."""
|
||||
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_GUEST_SCOPE, MATRIX_DEVICE_SCOPE]),
|
||||
"username": USERNAME,
|
||||
},
|
||||
)
|
||||
self._set_introspection_returnvalue(
|
||||
{
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_GUEST_SCOPE, MATRIX_DEVICE_SCOPE]),
|
||||
"username": USERNAME,
|
||||
}
|
||||
)
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
@@ -504,8 +504,11 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
self.auth.get_user_by_req(request, allow_guest=True)
|
||||
)
|
||||
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
||||
self.http_client.request.assert_called_once_with(
|
||||
method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY
|
||||
self._rust_client.post.assert_called_once_with(
|
||||
url=INTROSPECTION_ENDPOINT,
|
||||
response_limit=ANY,
|
||||
request_body=ANY,
|
||||
headers=ANY,
|
||||
)
|
||||
self._assertParams()
|
||||
self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME))
|
||||
@@ -522,30 +525,28 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
|
||||
# The introspection endpoint is returning an error.
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse(code=500, body=b"Internal Server Error")
|
||||
)
|
||||
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
|
||||
self.assertEqual(error.value.code, 503)
|
||||
|
||||
# The introspection endpoint request fails.
|
||||
self.http_client.request = AsyncMock(side_effect=Exception())
|
||||
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
|
||||
self.assertEqual(error.value.code, 503)
|
||||
|
||||
# The introspection endpoint does not return a JSON object.
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200, payload=["this is an array", "not an object"]
|
||||
self._rust_client.post = AsyncMock(
|
||||
side_effect=HttpResponseException(
|
||||
code=500, msg="Internal Server Error", response=b"{}"
|
||||
)
|
||||
)
|
||||
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
|
||||
self.assertEqual(error.value.code, 503)
|
||||
|
||||
# The introspection endpoint request fails.
|
||||
self._rust_client.post = AsyncMock(side_effect=Exception())
|
||||
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
|
||||
self.assertEqual(error.value.code, 503)
|
||||
|
||||
# The introspection endpoint does not return a JSON object.
|
||||
self._set_introspection_returnvalue(["this is an array", "not an object"])
|
||||
|
||||
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
|
||||
self.assertEqual(error.value.code, 503)
|
||||
|
||||
# The introspection endpoint does not return valid JSON.
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse(code=200, body=b"this is not valid JSON")
|
||||
)
|
||||
self._set_introspection_returnvalue("this is not valid JSON")
|
||||
|
||||
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
|
||||
self.assertEqual(error.value.code, 503)
|
||||
|
||||
@@ -554,23 +555,21 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
an expiry time, the introspection response is cached and then the entry is
|
||||
re-requested after it has expired."""
|
||||
|
||||
self.http_client.request = introspection_mock = AsyncMock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join(
|
||||
[
|
||||
MATRIX_USER_SCOPE,
|
||||
f"{MATRIX_DEVICE_SCOPE_PREFIX}AABBCC",
|
||||
]
|
||||
),
|
||||
"username": USERNAME,
|
||||
"expires_in": 60,
|
||||
},
|
||||
)
|
||||
introspection_mock = self._set_introspection_returnvalue(
|
||||
{
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join(
|
||||
[
|
||||
MATRIX_USER_SCOPE,
|
||||
f"{MATRIX_DEVICE_SCOPE_PREFIX}AABBCC",
|
||||
]
|
||||
),
|
||||
"username": USERNAME,
|
||||
"expires_in": 60,
|
||||
}
|
||||
)
|
||||
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"mockAccessToken"]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
@@ -607,16 +606,13 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
def test_cross_signing(self) -> None:
|
||||
"""Try uploading device keys with OAuth delegation enabled."""
|
||||
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_USER_SCOPE, MATRIX_DEVICE_SCOPE]),
|
||||
"username": USERNAME,
|
||||
},
|
||||
)
|
||||
self._set_introspection_returnvalue(
|
||||
{
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_USER_SCOPE, MATRIX_DEVICE_SCOPE]),
|
||||
"username": USERNAME,
|
||||
}
|
||||
)
|
||||
keys_upload_body = self.make_device_keys(USER_ID, DEVICE)
|
||||
channel = self.make_request(
|
||||
@@ -778,16 +774,13 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
|
||||
# Because we still support those endpoints with ASes, it checks the
|
||||
# access token before returning 404
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_USER_SCOPE, MATRIX_DEVICE_SCOPE]),
|
||||
"username": USERNAME,
|
||||
},
|
||||
)
|
||||
self._set_introspection_returnvalue(
|
||||
{
|
||||
"active": True,
|
||||
"sub": SUBJECT,
|
||||
"scope": " ".join([MATRIX_USER_SCOPE, MATRIX_DEVICE_SCOPE]),
|
||||
"username": USERNAME,
|
||||
},
|
||||
)
|
||||
|
||||
self.expect_unrecognized("POST", "/_matrix/client/v3/delete_devices", auth=True)
|
||||
@@ -820,9 +813,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
|
||||
def test_admin_token(self) -> None:
|
||||
"""The handler should return a requester with admin rights when admin_token is used."""
|
||||
self.http_client.request = AsyncMock(
|
||||
return_value=FakeResponse.json(code=200, payload={"active": False}),
|
||||
)
|
||||
self._set_introspection_returnvalue({"active": False})
|
||||
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [b"admin_token_value"]
|
||||
@@ -839,7 +830,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
)
|
||||
|
||||
# There should be no call to the introspection endpoint
|
||||
self.http_client.request.assert_not_called()
|
||||
self._rust_client.post.assert_not_called()
|
||||
|
||||
@override_config({"mau_stats_only": True})
|
||||
def test_request_tracking(self) -> None:
|
||||
@@ -852,28 +843,23 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
|
||||
known_token = "token-token-GOOD-:)"
|
||||
|
||||
async def mock_http_client_request(
|
||||
method: str,
|
||||
uri: str,
|
||||
data: Optional[bytes] = None,
|
||||
headers: Optional[Headers] = None,
|
||||
) -> IResponse:
|
||||
url: str, request_body: str, **kwargs: Any
|
||||
) -> bytes:
|
||||
"""Mocked auth provider response."""
|
||||
assert method == "POST"
|
||||
token = parse_qs(data)[b"token"][0].decode("utf-8")
|
||||
token = parse_qs(request_body)["token"][0]
|
||||
if token == known_token:
|
||||
return FakeResponse.json(
|
||||
code=200,
|
||||
payload={
|
||||
return json.dumps(
|
||||
{
|
||||
"active": True,
|
||||
"scope": MATRIX_USER_SCOPE,
|
||||
"sub": SUBJECT,
|
||||
"username": USERNAME,
|
||||
},
|
||||
)
|
||||
).encode("utf-8")
|
||||
|
||||
return FakeResponse.json(code=200, payload={"active": False})
|
||||
return json.dumps({"active": False}).encode("utf-8")
|
||||
|
||||
self.http_client.request = mock_http_client_request
|
||||
self._rust_client.post = mock_http_client_request
|
||||
|
||||
EXAMPLE_IPV4_ADDR = "123.123.123.123"
|
||||
EXAMPLE_USER_AGENT = "httprettygood"
|
||||
|
||||
Reference in New Issue
Block a user