Mirror of https://github.com/element-hq/synapse.git (synced 2025-12-05 01:10:13 +00:00)

Compare commits: release-v1...erikj/rust (4 commits)
| Author | SHA1 | Date |
|---|---|---|
| | f5817281f8 | |
| | 87406aa5d3 | |
| | 6842974391 | |
| | c93ef61fa3 | |
Cargo.lock (generated): 1114 lines changed. File diff suppressed because it is too large.
@@ -21,14 +21,25 @@ name = "synapse.synapse_rust"
[dependencies]
anyhow = "1.0.63"
env_logger = "0.10.0"
futures = "0.3.25"
futures-util = "0.3.25"
http = "0.2.8"
hyper = { version = "0.14.23", features = ["client", "http1", "http2", "runtime", "server", "full"] }
hyper-tls = "0.5.0"
lazy_static = "1.4.0"
log = "0.4.17"
native-tls = "0.2.11"
pyo3 = { version = "0.17.1", features = ["extension-module", "macros", "anyhow", "abi3", "abi3-py37"] }
pyo3-asyncio = { version = "0.17.0", features = ["tokio", "tokio-runtime"] }
pyo3-log = "0.7.0"
pythonize = "0.17.0"
regex = "1.6.0"
serde = { version = "1.0.144", features = ["derive"] }
serde_json = "1.0.85"
tokio = "1.23.0"
tokio-native-tls = "0.3.0"
trust-dns-resolver = "0.22.0"

[build-dependencies]
blake2 = "0.10.4"
rust/src/http/mod.rs (new file, 158 lines)
@@ -0,0 +1,158 @@
use std::collections::HashMap;

use anyhow::Error;
use http::{Request, Uri};
use hyper::Body;
use log::info;
use pyo3::{
    pyclass, pymethods,
    types::{PyBytes, PyModule},
    IntoPy, PyAny, PyObject, PyResult, Python, ToPyObject,
};

use self::resolver::{MatrixConnector, MatrixResolver};

pub mod resolver;

/// Called when registering modules with python.
pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
    let child_module = PyModule::new(py, "http")?;
    child_module.add_class::<HttpClient>()?;
    child_module.add_class::<MatrixResponse>()?;

    m.add_submodule(child_module)?;

    // We need to manually add the module to sys.modules to make `from
    // synapse.synapse_rust import http` work.
    py.import("sys")?
        .getattr("modules")?
        .set_item("synapse.synapse_rust.http", child_module)?;

    Ok(())
}

#[derive(Clone, Debug)]
pub struct Bytes(pub Vec<u8>);

impl ToPyObject for Bytes {
    fn to_object(&self, py: Python<'_>) -> pyo3::PyObject {
        PyBytes::new(py, &self.0).into_py(py)
    }
}

impl IntoPy<PyObject> for Bytes {
    fn into_py(self, py: Python<'_>) -> PyObject {
        self.to_object(py)
    }
}

#[derive(Debug)]
#[pyclass]
pub struct MatrixResponse {
    #[pyo3(get)]
    pub code: u16,
    #[pyo3(get)]
    pub phrase: &'static str,
    #[pyo3(get)]
    pub content: Bytes,
    #[pyo3(get)]
    pub headers: HashMap<String, Bytes>,
}

#[pyclass]
#[derive(Clone)]
pub struct HttpClient {
    client: hyper::Client<MatrixConnector>,
    resolver: MatrixResolver,
}

impl HttpClient {
    pub fn new() -> Result<Self, Error> {
        let resolver = MatrixResolver::new()?;

        let client =
            hyper::Client::builder().build(MatrixConnector::with_resolver(resolver.clone()));

        Ok(HttpClient { client, resolver })
    }

    pub async fn async_request(
        &self,
        url: String,
        method: String,
        headers: HashMap<Vec<u8>, Vec<Vec<u8>>>,
        body: Option<Vec<u8>>,
    ) -> Result<MatrixResponse, Error> {
        let uri: Uri = url.try_into()?;

        let mut builder = Request::builder().method(&*method).uri(uri.clone());

        for (key, values) in headers {
            for value in values {
                builder = builder.header(key.clone(), value);
            }
        }

        if uri.scheme_str() == Some("matrix") {
            let endpoints = self.resolver.resolve_server_name_from_uri(&uri).await?;
            if let Some(endpoint) = endpoints.first() {
                builder = builder.header("Host", &endpoint.host_header);
            }
        }

        let request = if let Some(body) = body {
            builder.body(Body::from(body))?
        } else {
            builder.body(Body::empty())?
        };

        let response = self.client.request(request).await?;

        let code = response.status().as_u16();
        let phrase = response.status().canonical_reason().unwrap_or_default();

        let headers = response
            .headers()
            .iter()
            .map(|(k, v)| (k.to_string(), Bytes(v.as_bytes().to_owned())))
            .collect();

        let body = response.into_body();

        let bytes = hyper::body::to_bytes(body).await?;
        let content = Bytes(bytes.to_vec());

        Ok(MatrixResponse {
            code,
            phrase,
            content,
            headers,
        })
    }
}

#[pymethods]
impl HttpClient {
    #[new]
    fn py_new() -> Result<Self, Error> {
        Self::new()
    }

    fn request<'a>(
        &'a self,
        py: Python<'a>,
        url: String,
        method: String,
        headers: HashMap<Vec<u8>, Vec<Vec<u8>>>,
        body: Option<Vec<u8>>,
    ) -> PyResult<&'a PyAny> {
        pyo3::prepare_freethreaded_python();

        let client = self.clone();

        pyo3_asyncio::tokio::future_into_py(py, async move {
            let resp = client.async_request(url, method, headers, body).await?;
            Ok(resp)
        })
    }
}
rust/src/http/resolver.rs (new file, 432 lines)
@@ -0,0 +1,432 @@
use std::collections::BTreeMap;
use std::future::Future;
use std::net::IpAddr;
use std::pin::Pin;
use std::str::FromStr;
use std::{
    io::Cursor,
    sync::{Arc, Mutex},
    task::{self, Poll},
};

use anyhow::{bail, Error};
use futures::{FutureExt, TryFutureExt};
use futures_util::stream::StreamExt;
use http::Uri;
use hyper::client::connect::Connection;
use hyper::client::connect::{Connected, HttpConnector};
use hyper::server::conn::Http;
use hyper::service::Service;
use hyper::Client;
use hyper_tls::HttpsConnector;
use hyper_tls::MaybeHttpsStream;
use log::{debug, info};
use native_tls::TlsConnector;
use serde::Deserialize;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::TcpStream;
use tokio_native_tls::TlsConnector as AsyncTlsConnector;
use trust_dns_resolver::error::ResolveErrorKind;

#[derive(Debug, Clone)]
pub struct Endpoint {
    pub host: String,
    pub port: u16,

    pub host_header: String,
    pub tls_name: String,
}

#[derive(Clone)]
pub struct MatrixResolver {
    resolver: trust_dns_resolver::TokioAsyncResolver,
    http_client: Client<HttpsConnector<HttpConnector>>,
}

impl MatrixResolver {
    pub fn new() -> Result<MatrixResolver, Error> {
        let http_client = hyper::Client::builder().build(HttpsConnector::new());

        MatrixResolver::with_client(http_client)
    }

    pub fn with_client(
        http_client: Client<HttpsConnector<HttpConnector>>,
    ) -> Result<MatrixResolver, Error> {
        let resolver = trust_dns_resolver::TokioAsyncResolver::tokio_from_system_conf()?;

        Ok(MatrixResolver {
            resolver,
            http_client,
        })
    }

    /// Does SRV lookup
    pub async fn resolve_server_name_from_uri(&self, uri: &Uri) -> Result<Vec<Endpoint>, Error> {
        let host = uri.host().expect("URI has no host").to_string();
        let port = uri.port_u16();

        self.resolve_server_name_from_host_port(host, port).await
    }

    pub async fn resolve_server_name_from_host_port(
        &self,
        mut host: String,
        mut port: Option<u16>,
    ) -> Result<Vec<Endpoint>, Error> {
        let mut authority = if let Some(p) = port {
            format!("{}:{}", host, p)
        } else {
            host.to_string()
        };

        // If a literal IP or includes port then we shortcircuit.
        if host.parse::<IpAddr>().is_ok() || port.is_some() {
            return Ok(vec![Endpoint {
                host: host.to_string(),
                port: port.unwrap_or(8448),

                host_header: authority.to_string(),
                tls_name: host.to_string(),
            }]);
        }

        // Do well-known delegation lookup.
        if let Some(server) = get_well_known(&self.http_client, &host).await {
            let a = http::uri::Authority::from_str(&server.server)?;
            host = a.host().to_string();
            port = a.port_u16();
            authority = a.to_string();
        }

        // If a literal IP or includes port then we shortcircuit.
        if host.parse::<IpAddr>().is_ok() || port.is_some() {
            return Ok(vec![Endpoint {
                host: host.clone(),
                port: port.unwrap_or(8448),

                host_header: authority.to_string(),
                tls_name: host.clone(),
            }]);
        }

        let result = self
            .resolver
            .srv_lookup(format!("_matrix._tcp.{}", host))
            .await;

        let records = match result {
            Ok(records) => records,
            Err(err) => match err.kind() {
                ResolveErrorKind::NoRecordsFound { .. } => {
                    return Ok(vec![Endpoint {
                        host: host.clone(),
                        port: 8448,
                        host_header: authority.to_string(),
                        tls_name: host.clone(),
                    }])
                }
                _ => return Err(err.into()),
            },
        };

        let mut priority_map: BTreeMap<u16, Vec<_>> = BTreeMap::new();

        let mut count = 0;
        for record in records {
            count += 1;
            let priority = record.priority();
            priority_map.entry(priority).or_default().push(record);
        }

        let mut results = Vec::with_capacity(count);

        for (_priority, records) in priority_map {
            // TODO: Correctly shuffle records
            results.extend(records.into_iter().map(|record| Endpoint {
                host: record.target().to_utf8(),
                port: record.port(),

                host_header: host.to_string(),
                tls_name: host.to_string(),
            }))
        }

        Ok(results)
    }
}

async fn get_well_known<C>(http_client: &Client<C>, host: &str) -> Option<WellKnownServer>
where
    C: Service<Uri> + Clone + Sync + Send + 'static,
    C::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    C::Future: Unpin + Send,
    C::Response: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static,
{
    // TODO: Add timeout.

    let uri = hyper::Uri::builder()
        .scheme("https")
        .authority(host)
        .path_and_query("/.well-known/matrix/server")
        .build()
        .ok()?;

    let mut body = http_client.get(uri).await.ok()?.into_body();

    let mut vec = Vec::new();
    while let Some(next) = body.next().await {
        let chunk = next.ok()?;
        vec.extend(chunk);
    }

    serde_json::from_slice(&vec).ok()?
}

#[derive(Deserialize)]
struct WellKnownServer {
    #[serde(rename = "m.server")]
    server: String,
}

#[derive(Clone)]
pub struct MatrixConnector {
    resolver: MatrixResolver,
}

impl MatrixConnector {
    pub fn with_resolver(resolver: MatrixResolver) -> MatrixConnector {
        MatrixConnector { resolver }
    }
}

impl Service<Uri> for MatrixConnector {
    type Response = MaybeHttpsStream<TcpStream>;
    type Error = Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, _: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
        // This connector is always ready, but others might not be.
        Poll::Ready(Ok(()))
    }

    fn call(&mut self, dst: Uri) -> Self::Future {
        let resolver = self.resolver.clone();

        if dst.scheme_str() != Some("matrix") {
            debug!("Got non-matrix scheme");
            return HttpsConnector::new()
                .call(dst)
                .map_err(|e| Error::msg(e))
                .boxed();
        }

        async move {
            let endpoints = resolver
                .resolve_server_name_from_host_port(
                    dst.host().expect("hostname").to_string(),
                    dst.port_u16(),
                )
                .await?;

            debug!("Got endpoints: {:?}", endpoints);

            for endpoint in endpoints {
                match try_connecting(&dst, &endpoint).await {
                    Ok(r) => return Ok(r),
                    // Errors here are not unexpected, and we just move on
                    // with our lives.
                    Err(e) => info!(
                        "Failed to connect to {} via {}:{} because {}",
                        dst.host().expect("hostname"),
                        endpoint.host,
                        endpoint.port,
                        e,
                    ),
                }
            }

            bail!(
                "failed to resolve host: {:?} port {:?}",
                dst.host(),
                dst.port()
            )
        }
        .boxed()
    }
}

/// Attempts to connect to a particular endpoint.
async fn try_connecting(
    dst: &Uri,
    endpoint: &Endpoint,
) -> Result<MaybeHttpsStream<TcpStream>, Error> {
    let tcp = TcpStream::connect((&endpoint.host as &str, endpoint.port)).await?;

    let connector: AsyncTlsConnector = if dst.host().expect("hostname").contains("localhost") {
        TlsConnector::builder()
            .danger_accept_invalid_certs(true)
            .build()?
            .into()
    } else {
        TlsConnector::new().unwrap().into()
    };

    let tls = connector.connect(&endpoint.tls_name, tcp).await?;

    Ok(tls.into())
}

/// A connector that returns a connection which returns 200 OK to all requests.
#[derive(Clone)]
pub struct TestConnector;

impl Service<Uri> for TestConnector {
    type Response = TestConnection;
    type Error = Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, _: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
        // This connector is always ready, but others might not be.
        Poll::Ready(Ok(()))
    }

    fn call(&mut self, _dst: Uri) -> Self::Future {
        let (client, server) = TestConnection::double_ended();

        {
            let service = hyper::service::service_fn(|_| async move {
                Ok(hyper::Response::new(hyper::Body::from("Hello World")))
                    as Result<_, hyper::http::Error>
            });
            let fut = Http::new().serve_connection(server, service);
            tokio::spawn(fut);
        }

        futures::future::ok(client).boxed()
    }
}

#[derive(Default)]
struct TestConnectionInner {
    outbound_buffer: Cursor<Vec<u8>>,
    inbound_buffer: Cursor<Vec<u8>>,
    wakers: Vec<futures::task::Waker>,
}

/// An in-memory connection for use with tests.
#[derive(Clone, Default)]
pub struct TestConnection {
    inner: Arc<Mutex<TestConnectionInner>>,
    direction: bool,
}

impl TestConnection {
    pub fn double_ended() -> (TestConnection, TestConnection) {
        let inner: Arc<Mutex<TestConnectionInner>> = Arc::default();

        let a = TestConnection {
            inner: inner.clone(),
            direction: false,
        };

        let b = TestConnection {
            inner,
            direction: true,
        };

        (a, b)
    }
}

impl AsyncRead for TestConnection {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let mut conn = self.inner.lock().expect("mutex");

        let buffer = if self.direction {
            &mut conn.inbound_buffer
        } else {
            &mut conn.outbound_buffer
        };

        let bytes_read = std::io::Read::read(buffer, buf.initialize_unfilled())?;
        buf.advance(bytes_read);
        if bytes_read > 0 {
            Poll::Ready(Ok(()))
        } else {
            conn.wakers.push(cx.waker().clone());
            Poll::Pending
        }
    }
}

impl AsyncWrite for TestConnection {
    fn poll_write(
        self: Pin<&mut Self>,
        _cx: &mut task::Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        let mut conn = self.inner.lock().expect("mutex");

        if self.direction {
            conn.outbound_buffer.get_mut().extend_from_slice(buf);
        } else {
            conn.inbound_buffer.get_mut().extend_from_slice(buf);
        }

        for waker in conn.wakers.drain(..) {
            waker.wake()
        }

        Poll::Ready(Ok(buf.len()))
    }

    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let mut conn = self.inner.lock().expect("mutex");

        if self.direction {
            Pin::new(&mut conn.outbound_buffer).poll_flush(cx)
        } else {
            Pin::new(&mut conn.inbound_buffer).poll_flush(cx)
        }
    }

    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let mut conn = self.inner.lock().expect("mutex");

        if self.direction {
            Pin::new(&mut conn.outbound_buffer).poll_shutdown(cx)
        } else {
            Pin::new(&mut conn.inbound_buffer).poll_shutdown(cx)
        }
    }
}

impl Connection for TestConnection {
    fn connected(&self) -> Connected {
        Connected::new()
    }
}

#[tokio::test]
async fn test_memory_connection() {
    let client: hyper::Client<_, hyper::Body> = hyper::Client::builder().build(TestConnector);

    let response = client
        .get("http://localhost".parse().unwrap())
        .await
        .unwrap();

    assert!(response.status().is_success());

    let bytes = hyper::body::to_bytes(response.into_body()).await.unwrap();
    assert_eq!(&bytes[..], b"Hello World");
}
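To make the delegation step above concrete: the resolver fetches `/.well-known/matrix/server`, reads only the `m.server` key, and otherwise falls back to an SRV lookup and finally to the default federation port 8448. A small illustrative sketch of that parsing follows; the hostname and port in the document are made up, not taken from the change.

```python
import json

# Illustrative well-known document; "synapse.example.org:443" is a made-up value.
well_known = json.loads('{"m.server": "synapse.example.org:443"}')

delegated = well_known["m.server"]
host, _, port = delegated.partition(":")
print(host, int(port) if port else 8448)  # synapse.example.org 443
```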
@@ -1,5 +1,6 @@
use pyo3::prelude::*;

pub mod http;
pub mod push;

/// Returns the hash of all the rust source files at the time it was compiled.
@@ -26,6 +27,7 @@ fn synapse_rust(py: Python<'_>, m: &PyModule) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(get_rust_file_digest, m)?)?;

    push::register_module(py, m)?;
    http::register_module(py, m)?;

    Ok(())
}
stubs/synapse/synapse_rust/http.pyi (new file, 16 lines)
@@ -0,0 +1,16 @@
from typing import Dict, List, Optional

class MatrixResponse:
    code: int
    phrase: str
    content: bytes
    headers: Dict[str, str]

class HttpClient:
    async def request(
        self,
        url: str,
        method: str,
        headers: Dict[bytes, List[bytes]],
        body: Optional[bytes],
    ) -> MatrixResponse: ...
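For orientation, here is a minimal sketch of how the stubbed client could be driven from plain asyncio code. The URL, the header value, and the use of `asyncio.run` are illustrative assumptions, not part of the change itself.

```python
import asyncio

from synapse.synapse_rust.http import HttpClient


async def main() -> None:
    client = HttpClient()
    # matrix:// URLs go through the Rust resolver (well-known, then SRV,
    # then the default federation port 8448).
    resp = await client.request(
        "matrix://example.org/_matrix/federation/v1/version",  # illustrative URL
        "GET",
        {b"User-Agent": [b"SynapseRustSketch/0.1"]},  # illustrative header
        None,
    )
    print(resp.code, resp.phrase, len(resp.content))


asyncio.run(main())
```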
@@ -29,7 +29,7 @@ if sys.version_info < (3, 7):
    sys.exit(1)

# Allow using the asyncio reactor via env var.
if strtobool(os.environ.get("SYNAPSE_ASYNC_IO_REACTOR", "0")):
if strtobool(os.environ.get("SYNAPSE_ASYNC_IO_REACTOR", "0")) or True:
    from incremental import Version

    import twisted
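The hunk above hard-enables the branch that Synapse normally only takes when `SYNAPSE_ASYNC_IO_REACTOR` is set, presumably because the pyo3-asyncio bridge needs an asyncio event loop behind the Twisted reactor. As a rough sketch only, installing that reactor typically looks like the following; the exact wiring inside Synapse may differ.

```python
# Sketch: install Twisted's asyncio reactor before anything imports
# twisted.internet.reactor. Synapse's real startup code may differ in detail.
from twisted.internet import asyncioreactor

asyncioreactor.install()
```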
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import asyncio
import cgi
import codecs
import logging
@@ -42,14 +43,18 @@ from canonicaljson import encode_canonical_json
from prometheus_client import Counter
from signedjson.sign import sign_json
from typing_extensions import Literal
from zope.interface import implementer

from twisted.internet import defer
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import IReactorTime
from twisted.internet.protocol import Protocol
from twisted.internet.task import Cooperator
from twisted.web.client import ResponseFailed
from twisted.internet.testing import StringTransport
from twisted.python.failure import Failure
from twisted.web.client import Response, ResponseDone, ResponseFailed
from twisted.web.http_headers import Headers
from twisted.web.iweb import IBodyProducer, IResponse
from twisted.web.iweb import UNKNOWN_LENGTH, IBodyProducer, IResponse

import synapse.metrics
import synapse.util.retryutils
@@ -75,6 +80,7 @@ from synapse.http.types import QueryParams
from synapse.logging import opentracing
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.synapse_rust.http import HttpClient
from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import AwakenableSleeper, timeout_deferred
@@ -199,6 +205,33 @@ class JsonParser(ByteParser[Union[JsonDict, list]]):
        return json_decoder.decode(self._buffer.getvalue())


@attr.s(auto_attribs=True)
@implementer(IResponse)
class RustResponse:
    version: tuple

    code: int

    phrase: bytes

    headers: Headers

    length: Union[int, UNKNOWN_LENGTH]

    # request: Optional[IClientRequest]

    # previousResponse: Optional[IResponse]

    _data: bytes

    def deliverBody(self, protocol: Protocol):
        protocol.dataReceived(self._data)
        protocol.connectionLost(Failure(ResponseDone("Response body fully received")))

    def setPreviousResponse(self, response: IResponse):
        pass


async def _handle_response(
    reactor: IReactorTime,
    timeout_sec: float,
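`RustResponse` above fakes just enough of Twisted's `IResponse` for the existing response-handling code: `deliverBody` replays the already-buffered bytes into whatever protocol the caller attaches and then signals completion. A small sketch of the consuming side shows the contract it has to satisfy; the `_Collector` class and `read_body` helper are hypothetical names, not part of the change.

```python
# Sketch of a consumer for IResponse.deliverBody; _Collector is hypothetical.
from twisted.internet.defer import Deferred
from twisted.internet.protocol import Protocol


class _Collector(Protocol):
    def __init__(self, finished: Deferred) -> None:
        self._buffer = b""
        self._finished = finished

    def dataReceived(self, data: bytes) -> None:
        self._buffer += data

    def connectionLost(self, reason) -> None:
        # Fired with ResponseDone once the whole body has been delivered.
        self._finished.callback(self._buffer)


def read_body(response) -> Deferred:
    d: Deferred = Deferred()
    response.deliverBody(_Collector(d))
    return d
```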
@@ -372,6 +405,8 @@ class MatrixFederationHttpClient:

        self._sleeper = AwakenableSleeper(self.reactor)

        self._rust_client = HttpClient()

    def wake_destination(self, destination: str) -> None:
        """Called when the remote server may have come back online."""

@@ -556,11 +591,8 @@ class MatrixFederationHttpClient:
                destination_bytes, method_bytes, url_to_sign_bytes, json
            )
            data = encode_canonical_json(json)
            producer: Optional[IBodyProducer] = QuieterFileBodyProducer(
                BytesIO(data), cooperator=self._cooperator
            )
        else:
            producer = None
            data = None
            auth_headers = self.build_auth_headers(
                destination_bytes, method_bytes, url_to_sign_bytes
            )
@@ -591,23 +623,33 @@ class MatrixFederationHttpClient:
                    # * The `Deferred` that joins the forks back together is
                    #   wrapped in `make_deferred_yieldable` to restore the
                    #   logging context regardless of the path taken.
                    request_deferred = run_in_background(
                        self.agent.request,
                        method_bytes,
                        url_bytes,
                        headers=Headers(headers_dict),
                        bodyProducer=producer,
                    )
                    request_deferred = timeout_deferred(
                        request_deferred,
                        timeout=_sec_timeout,
                        reactor=self.reactor,
                    )
                    # request_deferred = run_in_background(
                    #     self._rust_client.request,
                    #     url_str,
                    #     request.method,
                    #     headers_dict,
                    #     data,
                    # )
                    # request_deferred = timeout_deferred(
                    #     request_deferred,
                    #     timeout=_sec_timeout,
                    #     reactor=self.reactor,
                    # )

                    response = await make_deferred_yieldable(request_deferred)
                    # response = await make_deferred_yieldable(request_deferred)

                    response_d = run_in_background(
                        self._rust_client.request,
                        url_str,
                        request.method,
                        headers_dict,
                        data,
                    )
                    response = await make_deferred_yieldable(response_d)
            except DNSLookupError as e:
                raise RequestSendFailed(e, can_retry=retry_on_dns_fail) from e
            except Exception as e:
                logger.exception("ERROR")
                raise RequestSendFailed(e, can_retry=True) from e

            incoming_responses_counter.labels(
@@ -615,7 +657,7 @@ class MatrixFederationHttpClient:
            ).inc()

            set_tag(tags.HTTP_STATUS_CODE, response.code)
            response_phrase = response.phrase.decode("ascii", errors="replace")
            response_phrase = response.phrase

            if 200 <= response.code < 300:
                logger.debug(
@@ -635,25 +677,7 @@ class MatrixFederationHttpClient:
                )
                # :'(
                # Update transactions table?
                d = treq.content(response)
                d = timeout_deferred(
                    d, timeout=_sec_timeout, reactor=self.reactor
                )

                try:
                    body = await make_deferred_yieldable(d)
                except Exception as e:
                    # Eh, we're already going to raise an exception so lets
                    # ignore if this fails.
                    logger.warning(
                        "{%s} [%s] Failed to get error response: %s %s: %s",
                        request.txn_id,
                        request.destination,
                        request.method,
                        url_str,
                        _flatten_response_never_received(e),
                    )
                    body = None
                body = response.content

                exc = HttpResponseException(
                    response.code, response_phrase, body
@@ -715,7 +739,19 @@ class MatrixFederationHttpClient:
                        _flatten_response_never_received(e),
                    )
                    raise
        return response

        headers = Headers()
        for key, value in response.headers.items():
            headers.addRawHeader(key, value)

        return RustResponse(
            ("HTTP", 1, 1),
            response.code,
            response.phrase.encode("ascii"),
            headers,
            UNKNOWN_LENGTH,
            response.content,
        )

    def build_auth_headers(
        self,
@@ -26,6 +26,7 @@ import logging
import threading
import typing
import warnings
from asyncio import Future
from types import TracebackType
from typing import (
    TYPE_CHECKING,
@@ -814,6 +815,8 @@ def run_in_background( # type: ignore[misc]
        res = defer.ensureDeferred(res)
    elif isinstance(res, defer.Deferred):
        pass
    elif isinstance(res, Future):
        res = defer.Deferred.fromFuture(res)
    elif isinstance(res, Awaitable):
        # `res` is probably some kind of completed awaitable, such as a `DoneAwaitable`
        # or `Future` from `make_awaitable`.
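The new `elif isinstance(res, Future)` branch exists because the Rust client's results arrive as asyncio Futures rather than Deferreds, so `run_in_background` has to convert them before its usual logging-context bookkeeping. A minimal sketch of that conversion, assuming the reactor shares the running asyncio loop; `rust_like_call` and `to_deferred` are illustrative stand-ins.

```python
# Sketch: converting an asyncio Future into a Twisted Deferred, as the new
# branch does. Assumes Twisted's asyncio reactor is backed by the same loop.
import asyncio

from twisted.internet import defer


async def rust_like_call() -> str:
    return "ok"  # stand-in for HttpClient.request(...)


def to_deferred(loop: asyncio.AbstractEventLoop) -> "defer.Deferred[str]":
    future = asyncio.ensure_future(rust_like_call(), loop=loop)
    return defer.Deferred.fromFuture(future)
```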