mirror of
https://github.com/element-hq/synapse.git
synced 2025-12-07 01:20:16 +00:00
Compare commits
4 Commits
madlittlem
...
erikj/rust
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f5817281f8 | ||
|
|
87406aa5d3 | ||
|
|
6842974391 | ||
|
|
c93ef61fa3 |
1114
Cargo.lock
generated
1114
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -21,14 +21,25 @@ name = "synapse.synapse_rust"
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow = "1.0.63"
|
anyhow = "1.0.63"
|
||||||
|
env_logger = "0.10.0"
|
||||||
|
futures = "0.3.25"
|
||||||
|
futures-util = "0.3.25"
|
||||||
|
http = "0.2.8"
|
||||||
|
hyper = { version = "0.14.23", features = ["client", "http1", "http2", "runtime", "server", "full"] }
|
||||||
|
hyper-tls = "0.5.0"
|
||||||
lazy_static = "1.4.0"
|
lazy_static = "1.4.0"
|
||||||
log = "0.4.17"
|
log = "0.4.17"
|
||||||
|
native-tls = "0.2.11"
|
||||||
pyo3 = { version = "0.17.1", features = ["extension-module", "macros", "anyhow", "abi3", "abi3-py37"] }
|
pyo3 = { version = "0.17.1", features = ["extension-module", "macros", "anyhow", "abi3", "abi3-py37"] }
|
||||||
|
pyo3-asyncio = { version = "0.17.0", features = ["tokio", "tokio-runtime"] }
|
||||||
pyo3-log = "0.7.0"
|
pyo3-log = "0.7.0"
|
||||||
pythonize = "0.17.0"
|
pythonize = "0.17.0"
|
||||||
regex = "1.6.0"
|
regex = "1.6.0"
|
||||||
serde = { version = "1.0.144", features = ["derive"] }
|
serde = { version = "1.0.144", features = ["derive"] }
|
||||||
serde_json = "1.0.85"
|
serde_json = "1.0.85"
|
||||||
|
tokio = "1.23.0"
|
||||||
|
tokio-native-tls = "0.3.0"
|
||||||
|
trust-dns-resolver = "0.22.0"
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
blake2 = "0.10.4"
|
blake2 = "0.10.4"
|
||||||
|
|||||||
158
rust/src/http/mod.rs
Normal file
158
rust/src/http/mod.rs
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use anyhow::Error;
|
||||||
|
use http::{Request, Uri};
|
||||||
|
use hyper::Body;
|
||||||
|
use log::info;
|
||||||
|
use pyo3::{
|
||||||
|
pyclass, pymethods,
|
||||||
|
types::{PyBytes, PyModule},
|
||||||
|
IntoPy, PyAny, PyObject, PyResult, Python, ToPyObject,
|
||||||
|
};
|
||||||
|
|
||||||
|
use self::resolver::{MatrixConnector, MatrixResolver};
|
||||||
|
|
||||||
|
pub mod resolver;
|
||||||
|
|
||||||
|
/// Called when registering modules with python.
|
||||||
|
pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||||
|
let child_module = PyModule::new(py, "http")?;
|
||||||
|
child_module.add_class::<HttpClient>()?;
|
||||||
|
child_module.add_class::<MatrixResponse>()?;
|
||||||
|
|
||||||
|
m.add_submodule(child_module)?;
|
||||||
|
|
||||||
|
// We need to manually add the module to sys.modules to make `from
|
||||||
|
// synapse.synapse_rust import push` work.
|
||||||
|
py.import("sys")?
|
||||||
|
.getattr("modules")?
|
||||||
|
.set_item("synapse.synapse_rust.http", child_module)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct Bytes(pub Vec<u8>);
|
||||||
|
|
||||||
|
impl ToPyObject for Bytes {
|
||||||
|
fn to_object(&self, py: Python<'_>) -> pyo3::PyObject {
|
||||||
|
PyBytes::new(py, &self.0).into_py(py)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoPy<PyObject> for Bytes {
|
||||||
|
fn into_py(self, py: Python<'_>) -> PyObject {
|
||||||
|
self.to_object(py)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
#[pyclass]
|
||||||
|
pub struct MatrixResponse {
|
||||||
|
#[pyo3(get)]
|
||||||
|
pub code: u16,
|
||||||
|
#[pyo3(get)]
|
||||||
|
pub phrase: &'static str,
|
||||||
|
#[pyo3(get)]
|
||||||
|
pub content: Bytes,
|
||||||
|
#[pyo3(get)]
|
||||||
|
pub headers: HashMap<String, Bytes>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pyclass]
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct HttpClient {
|
||||||
|
client: hyper::Client<MatrixConnector>,
|
||||||
|
resolver: MatrixResolver,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HttpClient {
|
||||||
|
pub fn new() -> Result<Self, Error> {
|
||||||
|
let resolver = MatrixResolver::new()?;
|
||||||
|
|
||||||
|
let client =
|
||||||
|
hyper::Client::builder().build(MatrixConnector::with_resolver(resolver.clone()));
|
||||||
|
|
||||||
|
Ok(HttpClient { client, resolver })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn async_request(
|
||||||
|
&self,
|
||||||
|
url: String,
|
||||||
|
method: String,
|
||||||
|
headers: HashMap<Vec<u8>, Vec<Vec<u8>>>,
|
||||||
|
body: Option<Vec<u8>>,
|
||||||
|
) -> Result<MatrixResponse, Error> {
|
||||||
|
let uri: Uri = url.try_into()?;
|
||||||
|
|
||||||
|
let mut builder = Request::builder().method(&*method).uri(uri.clone());
|
||||||
|
|
||||||
|
for (key, values) in headers {
|
||||||
|
for value in values {
|
||||||
|
builder = builder.header(key.clone(), value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if uri.scheme_str() == Some("matrix") {
|
||||||
|
let endpoints = self.resolver.resolve_server_name_from_uri(&uri).await?;
|
||||||
|
if let Some(endpoint) = endpoints.first() {
|
||||||
|
builder = builder.header("Host", &endpoint.host_header);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let request = if let Some(body) = body {
|
||||||
|
builder.body(Body::from(body))?
|
||||||
|
} else {
|
||||||
|
builder.body(Body::empty())?
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = self.client.request(request).await?;
|
||||||
|
|
||||||
|
let code = response.status().as_u16();
|
||||||
|
let phrase = response.status().canonical_reason().unwrap_or_default();
|
||||||
|
|
||||||
|
let headers = response
|
||||||
|
.headers()
|
||||||
|
.iter()
|
||||||
|
.map(|(k, v)| (k.to_string(), Bytes(v.as_bytes().to_owned())))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let body = response.into_body();
|
||||||
|
|
||||||
|
let bytes = hyper::body::to_bytes(body).await?;
|
||||||
|
let content = Bytes(bytes.to_vec());
|
||||||
|
|
||||||
|
Ok(MatrixResponse {
|
||||||
|
code,
|
||||||
|
phrase,
|
||||||
|
content,
|
||||||
|
headers,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pymethods]
|
||||||
|
impl HttpClient {
|
||||||
|
#[new]
|
||||||
|
fn py_new() -> Result<Self, Error> {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn request<'a>(
|
||||||
|
&'a self,
|
||||||
|
py: Python<'a>,
|
||||||
|
url: String,
|
||||||
|
method: String,
|
||||||
|
headers: HashMap<Vec<u8>, Vec<Vec<u8>>>,
|
||||||
|
body: Option<Vec<u8>>,
|
||||||
|
) -> PyResult<&'a PyAny> {
|
||||||
|
pyo3::prepare_freethreaded_python();
|
||||||
|
|
||||||
|
let client = self.clone();
|
||||||
|
|
||||||
|
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||||
|
let resp = client.async_request(url, method, headers, body).await?;
|
||||||
|
Ok(resp)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
432
rust/src/http/resolver.rs
Normal file
432
rust/src/http/resolver.rs
Normal file
@@ -0,0 +1,432 @@
|
|||||||
|
use std::collections::BTreeMap;
|
||||||
|
use std::future::Future;
|
||||||
|
use std::net::IpAddr;
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::str::FromStr;
|
||||||
|
use std::{
|
||||||
|
io::Cursor,
|
||||||
|
sync::{Arc, Mutex},
|
||||||
|
task::{self, Poll},
|
||||||
|
};
|
||||||
|
|
||||||
|
use anyhow::{bail, Error};
|
||||||
|
use futures::{FutureExt, TryFutureExt};
|
||||||
|
use futures_util::stream::StreamExt;
|
||||||
|
use http::Uri;
|
||||||
|
use hyper::client::connect::Connection;
|
||||||
|
use hyper::client::connect::{Connected, HttpConnector};
|
||||||
|
use hyper::server::conn::Http;
|
||||||
|
use hyper::service::Service;
|
||||||
|
use hyper::Client;
|
||||||
|
use hyper_tls::HttpsConnector;
|
||||||
|
use hyper_tls::MaybeHttpsStream;
|
||||||
|
use log::{debug, info};
|
||||||
|
use native_tls::TlsConnector;
|
||||||
|
use serde::Deserialize;
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||||
|
use tokio::net::TcpStream;
|
||||||
|
use tokio_native_tls::TlsConnector as AsyncTlsConnector;
|
||||||
|
use trust_dns_resolver::error::ResolveErrorKind;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Endpoint {
|
||||||
|
pub host: String,
|
||||||
|
pub port: u16,
|
||||||
|
|
||||||
|
pub host_header: String,
|
||||||
|
pub tls_name: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct MatrixResolver {
|
||||||
|
resolver: trust_dns_resolver::TokioAsyncResolver,
|
||||||
|
http_client: Client<HttpsConnector<HttpConnector>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MatrixResolver {
|
||||||
|
pub fn new() -> Result<MatrixResolver, Error> {
|
||||||
|
let http_client = hyper::Client::builder().build(HttpsConnector::new());
|
||||||
|
|
||||||
|
MatrixResolver::with_client(http_client)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_client(
|
||||||
|
http_client: Client<HttpsConnector<HttpConnector>>,
|
||||||
|
) -> Result<MatrixResolver, Error> {
|
||||||
|
let resolver = trust_dns_resolver::TokioAsyncResolver::tokio_from_system_conf()?;
|
||||||
|
|
||||||
|
Ok(MatrixResolver {
|
||||||
|
resolver,
|
||||||
|
http_client,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Does SRV lookup
|
||||||
|
pub async fn resolve_server_name_from_uri(&self, uri: &Uri) -> Result<Vec<Endpoint>, Error> {
|
||||||
|
let host = uri.host().expect("URI has no host").to_string();
|
||||||
|
let port = uri.port_u16();
|
||||||
|
|
||||||
|
self.resolve_server_name_from_host_port(host, port).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn resolve_server_name_from_host_port(
|
||||||
|
&self,
|
||||||
|
mut host: String,
|
||||||
|
mut port: Option<u16>,
|
||||||
|
) -> Result<Vec<Endpoint>, Error> {
|
||||||
|
let mut authority = if let Some(p) = port {
|
||||||
|
format!("{}:{}", host, p)
|
||||||
|
} else {
|
||||||
|
host.to_string()
|
||||||
|
};
|
||||||
|
|
||||||
|
// If a literal IP or includes port then we shortcircuit.
|
||||||
|
if host.parse::<IpAddr>().is_ok() || port.is_some() {
|
||||||
|
return Ok(vec![Endpoint {
|
||||||
|
host: host.to_string(),
|
||||||
|
port: port.unwrap_or(8448),
|
||||||
|
|
||||||
|
host_header: authority.to_string(),
|
||||||
|
tls_name: host.to_string(),
|
||||||
|
}]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do well-known delegation lookup.
|
||||||
|
if let Some(server) = get_well_known(&self.http_client, &host).await {
|
||||||
|
let a = http::uri::Authority::from_str(&server.server)?;
|
||||||
|
host = a.host().to_string();
|
||||||
|
port = a.port_u16();
|
||||||
|
authority = a.to_string();
|
||||||
|
}
|
||||||
|
|
||||||
|
// If a literal IP or includes port then we shortcircuit.
|
||||||
|
if host.parse::<IpAddr>().is_ok() || port.is_some() {
|
||||||
|
return Ok(vec![Endpoint {
|
||||||
|
host: host.clone(),
|
||||||
|
port: port.unwrap_or(8448),
|
||||||
|
|
||||||
|
host_header: authority.to_string(),
|
||||||
|
tls_name: host.clone(),
|
||||||
|
}]);
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = self
|
||||||
|
.resolver
|
||||||
|
.srv_lookup(format!("_matrix._tcp.{}", host))
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let records = match result {
|
||||||
|
Ok(records) => records,
|
||||||
|
Err(err) => match err.kind() {
|
||||||
|
ResolveErrorKind::NoRecordsFound { .. } => {
|
||||||
|
return Ok(vec![Endpoint {
|
||||||
|
host: host.clone(),
|
||||||
|
port: 8448,
|
||||||
|
host_header: authority.to_string(),
|
||||||
|
tls_name: host.clone(),
|
||||||
|
}])
|
||||||
|
}
|
||||||
|
_ => return Err(err.into()),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut priority_map: BTreeMap<u16, Vec<_>> = BTreeMap::new();
|
||||||
|
|
||||||
|
let mut count = 0;
|
||||||
|
for record in records {
|
||||||
|
count += 1;
|
||||||
|
let priority = record.priority();
|
||||||
|
priority_map.entry(priority).or_default().push(record);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut results = Vec::with_capacity(count);
|
||||||
|
|
||||||
|
for (_priority, records) in priority_map {
|
||||||
|
// TODO: Correctly shuffle records
|
||||||
|
results.extend(records.into_iter().map(|record| Endpoint {
|
||||||
|
host: record.target().to_utf8(),
|
||||||
|
port: record.port(),
|
||||||
|
|
||||||
|
host_header: host.to_string(),
|
||||||
|
tls_name: host.to_string(),
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(results)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_well_known<C>(http_client: &Client<C>, host: &str) -> Option<WellKnownServer>
|
||||||
|
where
|
||||||
|
C: Service<Uri> + Clone + Sync + Send + 'static,
|
||||||
|
C::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
|
||||||
|
C::Future: Unpin + Send,
|
||||||
|
C::Response: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static,
|
||||||
|
{
|
||||||
|
// TODO: Add timeout.
|
||||||
|
|
||||||
|
let uri = hyper::Uri::builder()
|
||||||
|
.scheme("https")
|
||||||
|
.authority(host)
|
||||||
|
.path_and_query("/.well-known/matrix/server")
|
||||||
|
.build()
|
||||||
|
.ok()?;
|
||||||
|
|
||||||
|
let mut body = http_client.get(uri).await.ok()?.into_body();
|
||||||
|
|
||||||
|
let mut vec = Vec::new();
|
||||||
|
while let Some(next) = body.next().await {
|
||||||
|
let chunk = next.ok()?;
|
||||||
|
vec.extend(chunk);
|
||||||
|
}
|
||||||
|
|
||||||
|
serde_json::from_slice(&vec).ok()?
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct WellKnownServer {
|
||||||
|
#[serde(rename = "m.server")]
|
||||||
|
server: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct MatrixConnector {
|
||||||
|
resolver: MatrixResolver,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MatrixConnector {
|
||||||
|
pub fn with_resolver(resolver: MatrixResolver) -> MatrixConnector {
|
||||||
|
MatrixConnector { resolver }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Service<Uri> for MatrixConnector {
|
||||||
|
type Response = MaybeHttpsStream<TcpStream>;
|
||||||
|
type Error = Error;
|
||||||
|
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||||
|
|
||||||
|
fn poll_ready(&mut self, _: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
// This connector is always ready, but others might not be.
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn call(&mut self, dst: Uri) -> Self::Future {
|
||||||
|
let resolver = self.resolver.clone();
|
||||||
|
|
||||||
|
if dst.scheme_str() != Some("matrix") {
|
||||||
|
debug!("Got non-matrix scheme");
|
||||||
|
return HttpsConnector::new()
|
||||||
|
.call(dst)
|
||||||
|
.map_err(|e| Error::msg(e))
|
||||||
|
.boxed();
|
||||||
|
}
|
||||||
|
|
||||||
|
async move {
|
||||||
|
let endpoints = resolver
|
||||||
|
.resolve_server_name_from_host_port(
|
||||||
|
dst.host().expect("hostname").to_string(),
|
||||||
|
dst.port_u16(),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
debug!("Got endpoints: {:?}", endpoints);
|
||||||
|
|
||||||
|
for endpoint in endpoints {
|
||||||
|
match try_connecting(&dst, &endpoint).await {
|
||||||
|
Ok(r) => return Ok(r),
|
||||||
|
// Errors here are not unexpected, and we just move on
|
||||||
|
// with our lives.
|
||||||
|
Err(e) => info!(
|
||||||
|
"Failed to connect to {} via {}:{} because {}",
|
||||||
|
dst.host().expect("hostname"),
|
||||||
|
endpoint.host,
|
||||||
|
endpoint.port,
|
||||||
|
e,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bail!(
|
||||||
|
"failed to resolve host: {:?} port {:?}",
|
||||||
|
dst.host(),
|
||||||
|
dst.port()
|
||||||
|
)
|
||||||
|
}
|
||||||
|
.boxed()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Attempts to connect to a particular endpoint.
|
||||||
|
async fn try_connecting(
|
||||||
|
dst: &Uri,
|
||||||
|
endpoint: &Endpoint,
|
||||||
|
) -> Result<MaybeHttpsStream<TcpStream>, Error> {
|
||||||
|
let tcp = TcpStream::connect((&endpoint.host as &str, endpoint.port)).await?;
|
||||||
|
|
||||||
|
let connector: AsyncTlsConnector = if dst.host().expect("hostname").contains("localhost") {
|
||||||
|
TlsConnector::builder()
|
||||||
|
.danger_accept_invalid_certs(true)
|
||||||
|
.build()?
|
||||||
|
.into()
|
||||||
|
} else {
|
||||||
|
TlsConnector::new().unwrap().into()
|
||||||
|
};
|
||||||
|
|
||||||
|
let tls = connector.connect(&endpoint.tls_name, tcp).await?;
|
||||||
|
|
||||||
|
Ok(tls.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A connector that reutrns a connection which returns 200 OK to all connections.
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct TestConnector;
|
||||||
|
|
||||||
|
impl Service<Uri> for TestConnector {
|
||||||
|
type Response = TestConnection;
|
||||||
|
type Error = Error;
|
||||||
|
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||||
|
|
||||||
|
fn poll_ready(&mut self, _: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||||
|
// This connector is always ready, but others might not be.
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn call(&mut self, _dst: Uri) -> Self::Future {
|
||||||
|
let (client, server) = TestConnection::double_ended();
|
||||||
|
|
||||||
|
{
|
||||||
|
let service = hyper::service::service_fn(|_| async move {
|
||||||
|
Ok(hyper::Response::new(hyper::Body::from("Hello World")))
|
||||||
|
as Result<_, hyper::http::Error>
|
||||||
|
});
|
||||||
|
let fut = Http::new().serve_connection(server, service);
|
||||||
|
tokio::spawn(fut);
|
||||||
|
}
|
||||||
|
|
||||||
|
futures::future::ok(client).boxed()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
struct TestConnectionInner {
|
||||||
|
outbound_buffer: Cursor<Vec<u8>>,
|
||||||
|
inbound_buffer: Cursor<Vec<u8>>,
|
||||||
|
wakers: Vec<futures::task::Waker>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A in memory connection for use with tests.
|
||||||
|
#[derive(Clone, Default)]
|
||||||
|
pub struct TestConnection {
|
||||||
|
inner: Arc<Mutex<TestConnectionInner>>,
|
||||||
|
direction: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TestConnection {
|
||||||
|
pub fn double_ended() -> (TestConnection, TestConnection) {
|
||||||
|
let inner: Arc<Mutex<TestConnectionInner>> = Arc::default();
|
||||||
|
|
||||||
|
let a = TestConnection {
|
||||||
|
inner: inner.clone(),
|
||||||
|
direction: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
let b = TestConnection {
|
||||||
|
inner,
|
||||||
|
direction: true,
|
||||||
|
};
|
||||||
|
|
||||||
|
(a, b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncRead for TestConnection {
|
||||||
|
fn poll_read(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut task::Context<'_>,
|
||||||
|
buf: &mut ReadBuf<'_>,
|
||||||
|
) -> Poll<Result<(), std::io::Error>> {
|
||||||
|
let mut conn = self.inner.lock().expect("mutex");
|
||||||
|
|
||||||
|
let buffer = if self.direction {
|
||||||
|
&mut conn.inbound_buffer
|
||||||
|
} else {
|
||||||
|
&mut conn.outbound_buffer
|
||||||
|
};
|
||||||
|
|
||||||
|
let bytes_read = std::io::Read::read(buffer, buf.initialize_unfilled())?;
|
||||||
|
buf.advance(bytes_read);
|
||||||
|
if bytes_read > 0 {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
} else {
|
||||||
|
conn.wakers.push(cx.waker().clone());
|
||||||
|
Poll::Pending
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AsyncWrite for TestConnection {
|
||||||
|
fn poll_write(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
_cx: &mut task::Context<'_>,
|
||||||
|
buf: &[u8],
|
||||||
|
) -> Poll<Result<usize, std::io::Error>> {
|
||||||
|
let mut conn = self.inner.lock().expect("mutex");
|
||||||
|
|
||||||
|
if self.direction {
|
||||||
|
conn.outbound_buffer.get_mut().extend_from_slice(buf);
|
||||||
|
} else {
|
||||||
|
conn.inbound_buffer.get_mut().extend_from_slice(buf);
|
||||||
|
}
|
||||||
|
|
||||||
|
for waker in conn.wakers.drain(..) {
|
||||||
|
waker.wake()
|
||||||
|
}
|
||||||
|
|
||||||
|
Poll::Ready(Ok(buf.len()))
|
||||||
|
}
|
||||||
|
fn poll_flush(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut task::Context<'_>,
|
||||||
|
) -> Poll<Result<(), std::io::Error>> {
|
||||||
|
let mut conn = self.inner.lock().expect("mutex");
|
||||||
|
|
||||||
|
if self.direction {
|
||||||
|
Pin::new(&mut conn.outbound_buffer).poll_flush(cx)
|
||||||
|
} else {
|
||||||
|
Pin::new(&mut conn.inbound_buffer).poll_flush(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn poll_shutdown(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut task::Context<'_>,
|
||||||
|
) -> Poll<Result<(), std::io::Error>> {
|
||||||
|
let mut conn = self.inner.lock().expect("mutex");
|
||||||
|
|
||||||
|
if self.direction {
|
||||||
|
Pin::new(&mut conn.outbound_buffer).poll_shutdown(cx)
|
||||||
|
} else {
|
||||||
|
Pin::new(&mut conn.inbound_buffer).poll_shutdown(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Connection for TestConnection {
|
||||||
|
fn connected(&self) -> Connected {
|
||||||
|
Connected::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_memory_connection() {
|
||||||
|
let client: hyper::Client<_, hyper::Body> = hyper::Client::builder().build(TestConnector);
|
||||||
|
|
||||||
|
let response = client
|
||||||
|
.get("http://localhost".parse().unwrap())
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(response.status().is_success());
|
||||||
|
|
||||||
|
let bytes = hyper::body::to_bytes(response.into_body()).await.unwrap();
|
||||||
|
assert_eq!(&bytes[..], b"Hello World");
|
||||||
|
}
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
use pyo3::prelude::*;
|
use pyo3::prelude::*;
|
||||||
|
|
||||||
|
pub mod http;
|
||||||
pub mod push;
|
pub mod push;
|
||||||
|
|
||||||
/// Returns the hash of all the rust source files at the time it was compiled.
|
/// Returns the hash of all the rust source files at the time it was compiled.
|
||||||
@@ -26,6 +27,7 @@ fn synapse_rust(py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
|||||||
m.add_function(wrap_pyfunction!(get_rust_file_digest, m)?)?;
|
m.add_function(wrap_pyfunction!(get_rust_file_digest, m)?)?;
|
||||||
|
|
||||||
push::register_module(py, m)?;
|
push::register_module(py, m)?;
|
||||||
|
http::register_module(py, m)?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
16
stubs/synapse/synapse_rust/http.pyi
Normal file
16
stubs/synapse/synapse_rust/http.pyi
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
class MatrixResponse:
|
||||||
|
code: int
|
||||||
|
phrase: str
|
||||||
|
content: bytes
|
||||||
|
headers: Dict[str, str]
|
||||||
|
|
||||||
|
class HttpClient:
|
||||||
|
async def request(
|
||||||
|
self,
|
||||||
|
url: str,
|
||||||
|
method: str,
|
||||||
|
headers: Dict[bytes, List[bytes]],
|
||||||
|
body: Optional[bytes],
|
||||||
|
) -> MatrixResponse: ...
|
||||||
@@ -29,7 +29,7 @@ if sys.version_info < (3, 7):
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
# Allow using the asyncio reactor via env var.
|
# Allow using the asyncio reactor via env var.
|
||||||
if strtobool(os.environ.get("SYNAPSE_ASYNC_IO_REACTOR", "0")):
|
if strtobool(os.environ.get("SYNAPSE_ASYNC_IO_REACTOR", "0")) or True:
|
||||||
from incremental import Version
|
from incremental import Version
|
||||||
|
|
||||||
import twisted
|
import twisted
|
||||||
|
|||||||
@@ -12,6 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import abc
|
import abc
|
||||||
|
import asyncio
|
||||||
import cgi
|
import cgi
|
||||||
import codecs
|
import codecs
|
||||||
import logging
|
import logging
|
||||||
@@ -42,14 +43,18 @@ from canonicaljson import encode_canonical_json
|
|||||||
from prometheus_client import Counter
|
from prometheus_client import Counter
|
||||||
from signedjson.sign import sign_json
|
from signedjson.sign import sign_json
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
from zope.interface import implementer
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from twisted.internet.error import DNSLookupError
|
from twisted.internet.error import DNSLookupError
|
||||||
from twisted.internet.interfaces import IReactorTime
|
from twisted.internet.interfaces import IReactorTime
|
||||||
|
from twisted.internet.protocol import Protocol
|
||||||
from twisted.internet.task import Cooperator
|
from twisted.internet.task import Cooperator
|
||||||
from twisted.web.client import ResponseFailed
|
from twisted.internet.testing import StringTransport
|
||||||
|
from twisted.python.failure import Failure
|
||||||
|
from twisted.web.client import Response, ResponseDone, ResponseFailed
|
||||||
from twisted.web.http_headers import Headers
|
from twisted.web.http_headers import Headers
|
||||||
from twisted.web.iweb import IBodyProducer, IResponse
|
from twisted.web.iweb import UNKNOWN_LENGTH, IBodyProducer, IResponse
|
||||||
|
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
import synapse.util.retryutils
|
import synapse.util.retryutils
|
||||||
@@ -75,6 +80,7 @@ from synapse.http.types import QueryParams
|
|||||||
from synapse.logging import opentracing
|
from synapse.logging import opentracing
|
||||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||||
from synapse.logging.opentracing import set_tag, start_active_span, tags
|
from synapse.logging.opentracing import set_tag, start_active_span, tags
|
||||||
|
from synapse.synapse_rust.http import HttpClient
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util import json_decoder
|
from synapse.util import json_decoder
|
||||||
from synapse.util.async_helpers import AwakenableSleeper, timeout_deferred
|
from synapse.util.async_helpers import AwakenableSleeper, timeout_deferred
|
||||||
@@ -199,6 +205,33 @@ class JsonParser(ByteParser[Union[JsonDict, list]]):
|
|||||||
return json_decoder.decode(self._buffer.getvalue())
|
return json_decoder.decode(self._buffer.getvalue())
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(auto_attribs=True)
|
||||||
|
@implementer(IResponse)
|
||||||
|
class RustResponse:
|
||||||
|
version: tuple
|
||||||
|
|
||||||
|
code: int
|
||||||
|
|
||||||
|
phrase: bytes
|
||||||
|
|
||||||
|
headers: Headers
|
||||||
|
|
||||||
|
length: Union[int, UNKNOWN_LENGTH]
|
||||||
|
|
||||||
|
# request: Optional[IClientRequest]
|
||||||
|
|
||||||
|
# previousResponse: Optional[IResponse]
|
||||||
|
|
||||||
|
_data: bytes
|
||||||
|
|
||||||
|
def deliverBody(self, protocol: Protocol):
|
||||||
|
protocol.dataReceived(self._data)
|
||||||
|
protocol.connectionLost(Failure(ResponseDone("Response body fully received")))
|
||||||
|
|
||||||
|
def setPreviousResponse(self, response: IResponse):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
async def _handle_response(
|
async def _handle_response(
|
||||||
reactor: IReactorTime,
|
reactor: IReactorTime,
|
||||||
timeout_sec: float,
|
timeout_sec: float,
|
||||||
@@ -372,6 +405,8 @@ class MatrixFederationHttpClient:
|
|||||||
|
|
||||||
self._sleeper = AwakenableSleeper(self.reactor)
|
self._sleeper = AwakenableSleeper(self.reactor)
|
||||||
|
|
||||||
|
self._rust_client = HttpClient()
|
||||||
|
|
||||||
def wake_destination(self, destination: str) -> None:
|
def wake_destination(self, destination: str) -> None:
|
||||||
"""Called when the remote server may have come back online."""
|
"""Called when the remote server may have come back online."""
|
||||||
|
|
||||||
@@ -556,11 +591,8 @@ class MatrixFederationHttpClient:
|
|||||||
destination_bytes, method_bytes, url_to_sign_bytes, json
|
destination_bytes, method_bytes, url_to_sign_bytes, json
|
||||||
)
|
)
|
||||||
data = encode_canonical_json(json)
|
data = encode_canonical_json(json)
|
||||||
producer: Optional[IBodyProducer] = QuieterFileBodyProducer(
|
|
||||||
BytesIO(data), cooperator=self._cooperator
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
producer = None
|
data = None
|
||||||
auth_headers = self.build_auth_headers(
|
auth_headers = self.build_auth_headers(
|
||||||
destination_bytes, method_bytes, url_to_sign_bytes
|
destination_bytes, method_bytes, url_to_sign_bytes
|
||||||
)
|
)
|
||||||
@@ -591,23 +623,33 @@ class MatrixFederationHttpClient:
|
|||||||
# * The `Deferred` that joins the forks back together is
|
# * The `Deferred` that joins the forks back together is
|
||||||
# wrapped in `make_deferred_yieldable` to restore the
|
# wrapped in `make_deferred_yieldable` to restore the
|
||||||
# logging context regardless of the path taken.
|
# logging context regardless of the path taken.
|
||||||
request_deferred = run_in_background(
|
# request_deferred = run_in_background(
|
||||||
self.agent.request,
|
# self._rust_client.request,
|
||||||
method_bytes,
|
# url_str,
|
||||||
url_bytes,
|
# request.method,
|
||||||
headers=Headers(headers_dict),
|
# headers_dict,
|
||||||
bodyProducer=producer,
|
# data,
|
||||||
)
|
# )
|
||||||
request_deferred = timeout_deferred(
|
# request_deferred = timeout_deferred(
|
||||||
request_deferred,
|
# request_deferred,
|
||||||
timeout=_sec_timeout,
|
# timeout=_sec_timeout,
|
||||||
reactor=self.reactor,
|
# reactor=self.reactor,
|
||||||
)
|
# )
|
||||||
|
|
||||||
response = await make_deferred_yieldable(request_deferred)
|
# response = await make_deferred_yieldable(request_deferred)
|
||||||
|
|
||||||
|
response_d = run_in_background(
|
||||||
|
self._rust_client.request,
|
||||||
|
url_str,
|
||||||
|
request.method,
|
||||||
|
headers_dict,
|
||||||
|
data,
|
||||||
|
)
|
||||||
|
response = await make_deferred_yieldable(response_d)
|
||||||
except DNSLookupError as e:
|
except DNSLookupError as e:
|
||||||
raise RequestSendFailed(e, can_retry=retry_on_dns_fail) from e
|
raise RequestSendFailed(e, can_retry=retry_on_dns_fail) from e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
logger.exception("ERROR")
|
||||||
raise RequestSendFailed(e, can_retry=True) from e
|
raise RequestSendFailed(e, can_retry=True) from e
|
||||||
|
|
||||||
incoming_responses_counter.labels(
|
incoming_responses_counter.labels(
|
||||||
@@ -615,7 +657,7 @@ class MatrixFederationHttpClient:
|
|||||||
).inc()
|
).inc()
|
||||||
|
|
||||||
set_tag(tags.HTTP_STATUS_CODE, response.code)
|
set_tag(tags.HTTP_STATUS_CODE, response.code)
|
||||||
response_phrase = response.phrase.decode("ascii", errors="replace")
|
response_phrase = response.phrase
|
||||||
|
|
||||||
if 200 <= response.code < 300:
|
if 200 <= response.code < 300:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@@ -635,25 +677,7 @@ class MatrixFederationHttpClient:
|
|||||||
)
|
)
|
||||||
# :'(
|
# :'(
|
||||||
# Update transactions table?
|
# Update transactions table?
|
||||||
d = treq.content(response)
|
body = response.content
|
||||||
d = timeout_deferred(
|
|
||||||
d, timeout=_sec_timeout, reactor=self.reactor
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
body = await make_deferred_yieldable(d)
|
|
||||||
except Exception as e:
|
|
||||||
# Eh, we're already going to raise an exception so lets
|
|
||||||
# ignore if this fails.
|
|
||||||
logger.warning(
|
|
||||||
"{%s} [%s] Failed to get error response: %s %s: %s",
|
|
||||||
request.txn_id,
|
|
||||||
request.destination,
|
|
||||||
request.method,
|
|
||||||
url_str,
|
|
||||||
_flatten_response_never_received(e),
|
|
||||||
)
|
|
||||||
body = None
|
|
||||||
|
|
||||||
exc = HttpResponseException(
|
exc = HttpResponseException(
|
||||||
response.code, response_phrase, body
|
response.code, response_phrase, body
|
||||||
@@ -715,7 +739,19 @@ class MatrixFederationHttpClient:
|
|||||||
_flatten_response_never_received(e),
|
_flatten_response_never_received(e),
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
return response
|
|
||||||
|
headers = Headers()
|
||||||
|
for key, value in response.headers.items():
|
||||||
|
headers.addRawHeader(key, value)
|
||||||
|
|
||||||
|
return RustResponse(
|
||||||
|
("HTTP", 1, 1),
|
||||||
|
response.code,
|
||||||
|
response.phrase.encode("ascii"),
|
||||||
|
headers,
|
||||||
|
UNKNOWN_LENGTH,
|
||||||
|
response.content,
|
||||||
|
)
|
||||||
|
|
||||||
def build_auth_headers(
|
def build_auth_headers(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ import logging
|
|||||||
import threading
|
import threading
|
||||||
import typing
|
import typing
|
||||||
import warnings
|
import warnings
|
||||||
|
from asyncio import Future
|
||||||
from types import TracebackType
|
from types import TracebackType
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
@@ -814,6 +815,8 @@ def run_in_background( # type: ignore[misc]
|
|||||||
res = defer.ensureDeferred(res)
|
res = defer.ensureDeferred(res)
|
||||||
elif isinstance(res, defer.Deferred):
|
elif isinstance(res, defer.Deferred):
|
||||||
pass
|
pass
|
||||||
|
elif isinstance(res, Future):
|
||||||
|
res = defer.Deferred.fromFuture(res)
|
||||||
elif isinstance(res, Awaitable):
|
elif isinstance(res, Awaitable):
|
||||||
# `res` is probably some kind of completed awaitable, such as a `DoneAwaitable`
|
# `res` is probably some kind of completed awaitable, such as a `DoneAwaitable`
|
||||||
# or `Future` from `make_awaitable`.
|
# or `Future` from `make_awaitable`.
|
||||||
|
|||||||
Reference in New Issue
Block a user