Compare commits

...

2 Commits

Author SHA1 Message Date
Quentin Gliech
676ad7b523 Pass the reactor explicitly to the HTTP client 2025-07-17 14:57:06 +02:00
Quentin Gliech
906a5f7ecd Spawn Rust futures in a separate task 2025-07-17 14:36:01 +02:00
5 changed files with 68 additions and 58 deletions

1
Cargo.lock generated
View File

@@ -1476,6 +1476,7 @@ dependencies = [
"lazy_static",
"log",
"mime",
"once_cell",
"pyo3",
"pyo3-log",
"pythonize",

View File

@@ -52,6 +52,7 @@ reqwest = { version = "0.12.15", default-features = false, features = [
http-body-util = "0.1.3"
futures = "0.3.31"
tokio = { version = "1.44.2", features = ["rt", "rt-multi-thread"] }
once_cell = "1.18.0"
[features]
extension-module = ["pyo3/extension-module"]

View File

@@ -12,44 +12,49 @@
* <https://www.gnu.org/licenses/agpl-3.0.html>.
*/
use std::{collections::HashMap, future::Future, panic::AssertUnwindSafe, sync::LazyLock};
use std::{collections::HashMap, future::Future};
use anyhow::Context;
use futures::{FutureExt, TryStreamExt};
use pyo3::{exceptions::PyException, prelude::*, types::PyString};
use futures::TryStreamExt;
use once_cell::sync::OnceCell;
use pyo3::{create_exception, exceptions::PyException, prelude::*, types::PyString};
use reqwest::RequestBuilder;
use tokio::runtime::Runtime;
use crate::errors::HttpResponseException;
create_exception!(
synapse.synapse_rust.http_client,
RustPanicError,
PyException,
"A panic which happened in a Rust future"
);
/// The tokio runtime that we're using to run async Rust libs.
static RUNTIME: LazyLock<Runtime> = LazyLock::new(|| {
tokio::runtime::Builder::new_multi_thread()
.worker_threads(4)
.enable_all()
.build()
.unwrap()
});
static RUNTIME: OnceCell<Runtime> = OnceCell::new();
/// A reference to the `Deferred` python class.
static DEFERRED_CLASS: LazyLock<PyObject> = LazyLock::new(|| {
Python::with_gil(|py| {
py.import("twisted.internet.defer")
.expect("module 'twisted.internet.defer' should be importable")
.getattr("Deferred")
.expect("module 'twisted.internet.defer' should have a 'Deferred' class")
.unbind()
})
});
/// A reference to the `twisted.internet.defer` module.
static DEFER: OnceCell<PyObject> = OnceCell::new();
/// A reference to the twisted `reactor`.
static TWISTED_REACTOR: LazyLock<Py<PyModule>> = LazyLock::new(|| {
Python::with_gil(|py| {
py.import("twisted.internet.reactor")
.expect("module 'twisted.internet.reactor' should be importable")
.unbind()
/// Access the tokio runtime.
fn runtime() -> PyResult<&'static Runtime> {
RUNTIME.get_or_try_init(|| {
let runtime = tokio::runtime::Builder::new_multi_thread()
.worker_threads(4)
.enable_all()
.build()
.context("building tokio runtime")?;
Ok(runtime)
})
});
}
/// Access to the `twisted.internet.defer` module.
fn defer(py: Python<'_>) -> PyResult<&Bound<PyAny>> {
Ok(DEFER
.get_or_try_init(|| py.import("twisted.internet.defer").map(Into::into))?
.bind(py))
}
/// Called when registering modules with python.
pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
@@ -57,8 +62,8 @@ pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()>
child_module.add_class::<HttpClient>()?;
// Make sure we fail early if we can't build the lazy statics.
LazyLock::force(&RUNTIME);
LazyLock::force(&DEFERRED_CLASS);
runtime()?;
defer(py)?;
m.add_submodule(&child_module)?;
@@ -72,26 +77,21 @@ pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()>
}
#[pyclass]
#[derive(Clone)]
struct HttpClient {
client: reqwest::Client,
reactor: PyObject,
}
#[pymethods]
impl HttpClient {
#[new]
pub fn py_new(user_agent: &str) -> PyResult<HttpClient> {
// The twisted reactor can only be imported after Synapse has been
// imported, to allow Synapse to change the twisted reactor. If we try
// and import the reactor too early twisted installs a default reactor,
// which can't be replaced.
LazyLock::force(&TWISTED_REACTOR);
pub fn py_new(reactor: PyObject, user_agent: &str) -> PyResult<HttpClient> {
Ok(HttpClient {
client: reqwest::Client::builder()
.user_agent(user_agent)
.build()
.context("building reqwest client")?,
reactor,
})
}
@@ -129,7 +129,7 @@ impl HttpClient {
builder: RequestBuilder,
response_limit: usize,
) -> PyResult<Bound<'a, PyAny>> {
create_deferred(py, async move {
create_deferred(py, self.reactor.clone_ref(py), async move {
let response = builder.send().await.context("sending request")?;
let status = response.status();
@@ -159,43 +159,48 @@ impl HttpClient {
/// tokio runtime.
///
/// Does not handle deferred cancellation or contextvars.
fn create_deferred<F, O>(py: Python, fut: F) -> PyResult<Bound<'_, PyAny>>
fn create_deferred<F, O>(py: Python, reactor: PyObject, fut: F) -> PyResult<Bound<'_, PyAny>>
where
F: Future<Output = PyResult<O>> + Send + 'static,
for<'a> O: IntoPyObject<'a>,
for<'a> O: IntoPyObject<'a> + Send + 'static,
{
let deferred = DEFERRED_CLASS.bind(py).call0()?;
let deferred = defer(py)?.call_method0("Deferred")?;
let deferred_callback = deferred.getattr("callback")?.unbind();
let deferred_errback = deferred.getattr("errback")?.unbind();
RUNTIME.spawn(async move {
// TODO: Is it safe to assert unwind safety here? I think so, as we
// don't use anything that could be tainted by the panic afterwards.
// Note that `.spawn(..)` asserts unwind safety on the future too.
let res = AssertUnwindSafe(fut).catch_unwind().await;
let rt = runtime()?;
let task = rt.spawn(fut);
rt.spawn(async move {
let res = task.await;
Python::with_gil(move |py| {
// Flatten the panic into standard python error
let res = match res {
Ok(r) => r,
Err(panic_err) => {
let panic_message = get_panic_message(&panic_err);
Err(PyException::new_err(
PyString::new(py, panic_message).unbind(),
))
}
Err(join_err) => match join_err.try_into_panic() {
Ok(panic_err) => {
let panic_message = get_panic_message(&panic_err);
Err(RustPanicError::new_err(
PyString::new(py, panic_message).unbind(),
))
}
Err(err) => Err(PyException::new_err(format!("Task cancelled: {err}"))),
},
};
let reactor = reactor.bind(py);
// Send the result to the deferred, via `.callback(..)` or `.errback(..)`
match res {
Ok(obj) => {
TWISTED_REACTOR
.call_method(py, "callFromThread", (deferred_callback, obj), None)
reactor
.call_method("callFromThread", (deferred_callback, obj), None)
.expect("callFromThread should not fail"); // There's nothing we can really do with errors here
}
Err(err) => {
TWISTED_REACTOR
.call_method(py, "callFromThread", (deferred_errback, err), None)
reactor
.call_method("callFromThread", (deferred_errback, err), None)
.expect("callFromThread should not fail"); // There's nothing we can really do with errors here
}
}

View File

@@ -183,7 +183,8 @@ class MSC3861DelegatedAuth(BaseAuth):
self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users
self._rust_http_client = HttpClient(
user_agent=self._http_client.user_agent.decode("utf8")
reactor=hs.get_reactor(),
user_agent=self._http_client.user_agent.decode("utf8"),
)
# # Token Introspection Cache

View File

@@ -12,8 +12,10 @@
from typing import Awaitable, Mapping
from synapse.types import ISynapseReactor
class HttpClient:
def __init__(self, user_agent: str) -> None: ...
def __init__(self, reactor: ISynapseReactor, user_agent: str) -> None: ...
def get(self, url: str, response_limit: int) -> Awaitable[bytes]: ...
def post(
self,