Compare commits

...

4 Commits

Author SHA1 Message Date
Erik Johnston
f5817281f8 Fixup2 2022-12-15 14:05:59 +00:00
Erik Johnston
87406aa5d3 Fixup 2022-12-15 13:20:20 +00:00
Erik Johnston
6842974391 Fixup 2022-12-15 13:15:47 +00:00
Erik Johnston
c93ef61fa3 WIP Rust HTTP for federation 2022-12-14 11:02:16 +00:00
9 changed files with 1807 additions and 47 deletions

1114
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -21,14 +21,25 @@ name = "synapse.synapse_rust"
[dependencies]
anyhow = "1.0.63"
env_logger = "0.10.0"
futures = "0.3.25"
futures-util = "0.3.25"
http = "0.2.8"
hyper = { version = "0.14.23", features = ["client", "http1", "http2", "runtime", "server", "full"] }
hyper-tls = "0.5.0"
lazy_static = "1.4.0"
log = "0.4.17"
native-tls = "0.2.11"
pyo3 = { version = "0.17.1", features = ["extension-module", "macros", "anyhow", "abi3", "abi3-py37"] }
pyo3-asyncio = { version = "0.17.0", features = ["tokio", "tokio-runtime"] }
pyo3-log = "0.7.0"
pythonize = "0.17.0"
regex = "1.6.0"
serde = { version = "1.0.144", features = ["derive"] }
serde_json = "1.0.85"
tokio = "1.23.0"
tokio-native-tls = "0.3.0"
trust-dns-resolver = "0.22.0"
[build-dependencies]
blake2 = "0.10.4"

158
rust/src/http/mod.rs Normal file
View File

@@ -0,0 +1,158 @@
use std::collections::HashMap;
use anyhow::Error;
use http::{Request, Uri};
use hyper::Body;
use log::info;
use pyo3::{
pyclass, pymethods,
types::{PyBytes, PyModule},
IntoPy, PyAny, PyObject, PyResult, Python, ToPyObject,
};
use self::resolver::{MatrixConnector, MatrixResolver};
pub mod resolver;
/// Registers the `http` submodule, and the classes it exposes, on the parent
/// Python module `m`.
pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
    let http_module = PyModule::new(py, "http")?;
    http_module.add_class::<HttpClient>()?;
    http_module.add_class::<MatrixResponse>()?;
    m.add_submodule(http_module)?;

    // The module is also added to `sys.modules` by hand so that
    // `from synapse.synapse_rust import http` works.
    let sys_modules = py.import("sys")?.getattr("modules")?;
    sys_modules.set_item("synapse.synapse_rust.http", http_module)?;

    Ok(())
}
#[derive(Clone, Debug)]
pub struct Bytes(pub Vec<u8>);
impl ToPyObject for Bytes {
fn to_object(&self, py: Python<'_>) -> pyo3::PyObject {
PyBytes::new(py, &self.0).into_py(py)
}
}
impl IntoPy<PyObject> for Bytes {
fn into_py(self, py: Python<'_>) -> PyObject {
self.to_object(py)
}
}
/// Response handed back to Python by [`HttpClient`].
///
/// Every field is exposed to Python through `#[pyo3(get)]`.
#[derive(Debug)]
#[pyclass]
pub struct MatrixResponse {
/// Numeric HTTP status code of the response.
#[pyo3(get)]
pub code: u16,
/// Canonical reason phrase for `code`, or an empty string when the status
/// has no canonical reason.
#[pyo3(get)]
pub phrase: &'static str,
/// The complete response body.
#[pyo3(get)]
pub content: Bytes,
/// Response headers keyed by header name. The map is built by collecting
/// `(name, value)` pairs, so a header name that occurs several times keeps
/// only a single value.
#[pyo3(get)]
pub headers: HashMap<String, Bytes>,
}
/// HTTP client exposed to Python as a class.
///
/// `Clone` is derived so that `request` can move a copy of the client into
/// the future it returns.
#[pyclass]
#[derive(Clone)]
pub struct HttpClient {
/// hyper client whose connections are opened through a [`MatrixConnector`].
client: hyper::Client<MatrixConnector>,
/// Resolver used by `async_request` to choose the `Host` header for
/// `matrix://` URLs.
resolver: MatrixResolver,
}
impl HttpClient {
pub fn new() -> Result<Self, Error> {
let resolver = MatrixResolver::new()?;
let client =
hyper::Client::builder().build(MatrixConnector::with_resolver(resolver.clone()));
Ok(HttpClient { client, resolver })
}
pub async fn async_request(
&self,
url: String,
method: String,
headers: HashMap<Vec<u8>, Vec<Vec<u8>>>,
body: Option<Vec<u8>>,
) -> Result<MatrixResponse, Error> {
let uri: Uri = url.try_into()?;
let mut builder = Request::builder().method(&*method).uri(uri.clone());
for (key, values) in headers {
for value in values {
builder = builder.header(key.clone(), value);
}
}
if uri.scheme_str() == Some("matrix") {
let endpoints = self.resolver.resolve_server_name_from_uri(&uri).await?;
if let Some(endpoint) = endpoints.first() {
builder = builder.header("Host", &endpoint.host_header);
}
}
let request = if let Some(body) = body {
builder.body(Body::from(body))?
} else {
builder.body(Body::empty())?
};
let response = self.client.request(request).await?;
let code = response.status().as_u16();
let phrase = response.status().canonical_reason().unwrap_or_default();
let headers = response
.headers()
.iter()
.map(|(k, v)| (k.to_string(), Bytes(v.as_bytes().to_owned())))
.collect();
let body = response.into_body();
let bytes = hyper::body::to_bytes(body).await?;
let content = Bytes(bytes.to_vec());
Ok(MatrixResponse {
code,
phrase,
content,
headers,
})
}
}
#[pymethods]
impl HttpClient {
    /// Python constructor (`HttpClient()`); delegates to [`HttpClient::new`].
    #[new]
    fn py_new() -> Result<Self, Error> {
        Self::new()
    }

    /// Starts a request and returns a Python awaitable that resolves to a
    /// [`MatrixResponse`].
    ///
    /// The arguments are forwarded unchanged to
    /// [`HttpClient::async_request`], which runs on the tokio runtime managed
    /// by `pyo3_asyncio`. An error from the request becomes a Python
    /// exception raised when the awaitable is awaited.
    fn request<'a>(
        &'a self,
        py: Python<'a>,
        url: String,
        method: String,
        headers: HashMap<Vec<u8>, Vec<Vec<u8>>>,
        body: Option<Vec<u8>>,
    ) -> PyResult<&'a PyAny> {
        // `pyo3::prepare_freethreaded_python()` is deliberately not called:
        // this method receives a `Python<'a>` token, so the interpreter is
        // already initialised and the GIL is held. That function is meant for
        // Rust programs that embed Python.

        // The future must be `'static`, so it owns a clone of the client
        // rather than borrowing `self`.
        let client = self.clone();
        pyo3_asyncio::tokio::future_into_py(py, async move {
            let resp = client.async_request(url, method, headers, body).await?;
            Ok(resp)
        })
    }
}

432
rust/src/http/resolver.rs Normal file
View File

@@ -0,0 +1,432 @@
use std::collections::BTreeMap;
use std::future::Future;
use std::net::IpAddr;
use std::pin::Pin;
use std::str::FromStr;
use std::{
io::Cursor,
sync::{Arc, Mutex},
task::{self, Poll},
};
use anyhow::{bail, Error};
use futures::{FutureExt, TryFutureExt};
use futures_util::stream::StreamExt;
use http::Uri;
use hyper::client::connect::Connection;
use hyper::client::connect::{Connected, HttpConnector};
use hyper::server::conn::Http;
use hyper::service::Service;
use hyper::Client;
use hyper_tls::HttpsConnector;
use hyper_tls::MaybeHttpsStream;
use log::{debug, info};
use native_tls::TlsConnector;
use serde::Deserialize;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::TcpStream;
use tokio_native_tls::TlsConnector as AsyncTlsConnector;
use trust_dns_resolver::error::ResolveErrorKind;
/// One concrete place to try connecting to for a Matrix server name.
#[derive(Debug, Clone)]
pub struct Endpoint {
/// Host that the TCP connection is opened to.
pub host: String,
/// Port that the TCP connection is opened to.
pub port: u16,
/// Value to send as the HTTP `Host` header.
pub host_header: String,
/// Name passed to the TLS connector as the domain when connecting.
pub tls_name: String,
}
/// Resolves Matrix server names to the [`Endpoint`]s to connect to.
#[derive(Clone)]
pub struct MatrixResolver {
/// DNS resolver used for the `_matrix._tcp` SRV lookup.
resolver: trust_dns_resolver::TokioAsyncResolver,
/// HTTPS client used to fetch `/.well-known/matrix/server`.
http_client: Client<HttpsConnector<HttpConnector>>,
}
impl MatrixResolver {
pub fn new() -> Result<MatrixResolver, Error> {
let http_client = hyper::Client::builder().build(HttpsConnector::new());
MatrixResolver::with_client(http_client)
}
pub fn with_client(
http_client: Client<HttpsConnector<HttpConnector>>,
) -> Result<MatrixResolver, Error> {
let resolver = trust_dns_resolver::TokioAsyncResolver::tokio_from_system_conf()?;
Ok(MatrixResolver {
resolver,
http_client,
})
}
/// Does SRV lookup
pub async fn resolve_server_name_from_uri(&self, uri: &Uri) -> Result<Vec<Endpoint>, Error> {
let host = uri.host().expect("URI has no host").to_string();
let port = uri.port_u16();
self.resolve_server_name_from_host_port(host, port).await
}
pub async fn resolve_server_name_from_host_port(
&self,
mut host: String,
mut port: Option<u16>,
) -> Result<Vec<Endpoint>, Error> {
let mut authority = if let Some(p) = port {
format!("{}:{}", host, p)
} else {
host.to_string()
};
// If a literal IP or includes port then we shortcircuit.
if host.parse::<IpAddr>().is_ok() || port.is_some() {
return Ok(vec![Endpoint {
host: host.to_string(),
port: port.unwrap_or(8448),
host_header: authority.to_string(),
tls_name: host.to_string(),
}]);
}
// Do well-known delegation lookup.
if let Some(server) = get_well_known(&self.http_client, &host).await {
let a = http::uri::Authority::from_str(&server.server)?;
host = a.host().to_string();
port = a.port_u16();
authority = a.to_string();
}
// If a literal IP or includes port then we shortcircuit.
if host.parse::<IpAddr>().is_ok() || port.is_some() {
return Ok(vec![Endpoint {
host: host.clone(),
port: port.unwrap_or(8448),
host_header: authority.to_string(),
tls_name: host.clone(),
}]);
}
let result = self
.resolver
.srv_lookup(format!("_matrix._tcp.{}", host))
.await;
let records = match result {
Ok(records) => records,
Err(err) => match err.kind() {
ResolveErrorKind::NoRecordsFound { .. } => {
return Ok(vec![Endpoint {
host: host.clone(),
port: 8448,
host_header: authority.to_string(),
tls_name: host.clone(),
}])
}
_ => return Err(err.into()),
},
};
let mut priority_map: BTreeMap<u16, Vec<_>> = BTreeMap::new();
let mut count = 0;
for record in records {
count += 1;
let priority = record.priority();
priority_map.entry(priority).or_default().push(record);
}
let mut results = Vec::with_capacity(count);
for (_priority, records) in priority_map {
// TODO: Correctly shuffle records
results.extend(records.into_iter().map(|record| Endpoint {
host: record.target().to_utf8(),
port: record.port(),
host_header: host.to_string(),
tls_name: host.to_string(),
}))
}
Ok(results)
}
}
/// Fetches and parses `https://<host>/.well-known/matrix/server`.
///
/// Returns `None` when the URI cannot be built, the request fails, the
/// response status is not `200 OK`, the body is larger than 50 KiB, or the
/// body is not JSON containing a string `m.server` field. Callers treat
/// `None` as "no delegation".
async fn get_well_known<C>(http_client: &Client<C>, host: &str) -> Option<WellKnownServer>
where
    C: Service<Uri> + Clone + Sync + Send + 'static,
    C::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
    C::Future: Unpin + Send,
    C::Response: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static,
{
    // The body comes from an arbitrary remote server, so cap how much of it
    // is buffered. A valid well-known file is a tiny JSON object.
    const MAX_WELL_KNOWN_SIZE: usize = 50 * 1024;

    // TODO: Add timeout.
    let uri = hyper::Uri::builder()
        .scheme("https")
        .authority(host)
        .path_and_query("/.well-known/matrix/server")
        .build()
        .ok()?;

    let response = http_client.get(uri).await.ok()?;

    // Only a 200 response counts as a valid well-known answer; error pages
    // and redirects must not be interpreted as delegation.
    if response.status() != hyper::StatusCode::OK {
        return None;
    }

    let mut body = response.into_body();
    let mut vec: Vec<u8> = Vec::new();
    while let Some(next) = body.next().await {
        let chunk = next.ok()?;
        if vec.len() + chunk.len() > MAX_WELL_KNOWN_SIZE {
            return None;
        }
        vec.extend_from_slice(&chunk);
    }

    serde_json::from_slice(&vec).ok()
}
/// The part of a `/.well-known/matrix/server` JSON document that is read.
#[derive(Deserialize)]
struct WellKnownServer {
/// Value of the `m.server` key; later parsed as a URI authority
/// (`host[:port]`) naming the delegated server.
#[serde(rename = "m.server")]
server: String,
}
/// hyper connector that opens connections for `matrix://` URIs using the
/// endpoints produced by a [`MatrixResolver`].
#[derive(Clone)]
pub struct MatrixConnector {
    /// Resolver consulted for every `matrix://` destination.
    resolver: MatrixResolver,
}

impl MatrixConnector {
    /// Wraps `resolver` in a connector.
    pub fn with_resolver(resolver: MatrixResolver) -> MatrixConnector {
        Self { resolver }
    }
}
impl Service<Uri> for MatrixConnector {
type Response = MaybeHttpsStream<TcpStream>;
type Error = Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
// This connector is always ready, but others might not be.
Poll::Ready(Ok(()))
}
fn call(&mut self, dst: Uri) -> Self::Future {
let resolver = self.resolver.clone();
if dst.scheme_str() != Some("matrix") {
debug!("Got non-matrix scheme");
return HttpsConnector::new()
.call(dst)
.map_err(|e| Error::msg(e))
.boxed();
}
async move {
let endpoints = resolver
.resolve_server_name_from_host_port(
dst.host().expect("hostname").to_string(),
dst.port_u16(),
)
.await?;
debug!("Got endpoints: {:?}", endpoints);
for endpoint in endpoints {
match try_connecting(&dst, &endpoint).await {
Ok(r) => return Ok(r),
// Errors here are not unexpected, and we just move on
// with our lives.
Err(e) => info!(
"Failed to connect to {} via {}:{} because {}",
dst.host().expect("hostname"),
endpoint.host,
endpoint.port,
e,
),
}
}
bail!(
"failed to resolve host: {:?} port {:?}",
dst.host(),
dst.port()
)
}
.boxed()
}
}
/// Attempts to connect to a particular endpoint.
async fn try_connecting(
dst: &Uri,
endpoint: &Endpoint,
) -> Result<MaybeHttpsStream<TcpStream>, Error> {
let tcp = TcpStream::connect((&endpoint.host as &str, endpoint.port)).await?;
let connector: AsyncTlsConnector = if dst.host().expect("hostname").contains("localhost") {
TlsConnector::builder()
.danger_accept_invalid_certs(true)
.build()?
.into()
} else {
TlsConnector::new().unwrap().into()
};
let tls = connector.connect(&endpoint.tls_name, tcp).await?;
Ok(tls.into())
}
/// A connector that reutrns a connection which returns 200 OK to all connections.
#[derive(Clone)]
pub struct TestConnector;
impl Service<Uri> for TestConnector {
type Response = TestConnection;
type Error = Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
// This connector is always ready, but others might not be.
Poll::Ready(Ok(()))
}
fn call(&mut self, _dst: Uri) -> Self::Future {
let (client, server) = TestConnection::double_ended();
{
let service = hyper::service::service_fn(|_| async move {
Ok(hyper::Response::new(hyper::Body::from("Hello World")))
as Result<_, hyper::http::Error>
});
let fut = Http::new().serve_connection(server, service);
tokio::spawn(fut);
}
futures::future::ok(client).boxed()
}
}
/// State shared by both ends of a [`TestConnection`] pair.
#[derive(Default)]
struct TestConnectionInner {
/// Written by the end whose `direction` is `true`, read by the other end.
outbound_buffer: Cursor<Vec<u8>>,
/// Written by the end whose `direction` is `false`, read by the other end.
inbound_buffer: Cursor<Vec<u8>>,
/// Wakers of reads that found no data; every one is woken on the next
/// write from either end.
wakers: Vec<futures::task::Waker>,
}
/// An in-memory connection for use with tests.
#[derive(Clone, Default)]
pub struct TestConnection {
/// Buffers and wakers shared with the other end of the pair.
inner: Arc<Mutex<TestConnectionInner>>,
/// Selects which shared buffer this end reads and which it writes:
/// `true` reads `inbound_buffer` and writes `outbound_buffer`, `false`
/// does the opposite.
direction: bool,
}
impl TestConnection {
pub fn double_ended() -> (TestConnection, TestConnection) {
let inner: Arc<Mutex<TestConnectionInner>> = Arc::default();
let a = TestConnection {
inner: inner.clone(),
direction: false,
};
let b = TestConnection {
inner,
direction: true,
};
(a, b)
}
}
impl AsyncRead for TestConnection {
    /// Copies any unread bytes from this end's incoming buffer into `buf`.
    ///
    /// When nothing is available the task's waker is stored and
    /// `Poll::Pending` is returned; a later write wakes it.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let mut guard = self.inner.lock().expect("mutex");

        let source = if self.direction {
            &mut guard.inbound_buffer
        } else {
            &mut guard.outbound_buffer
        };

        let bytes_read = std::io::Read::read(source, buf.initialize_unfilled())?;
        buf.advance(bytes_read);

        if bytes_read == 0 {
            guard.wakers.push(cx.waker().clone());
            return Poll::Pending;
        }

        Poll::Ready(Ok(()))
    }
}
impl AsyncWrite for TestConnection {
    /// Appends `buf` to this end's outgoing buffer and wakes every stored
    /// reader. Always accepts the whole slice.
    fn poll_write(
        self: Pin<&mut Self>,
        _cx: &mut task::Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        let mut guard = self.inner.lock().expect("mutex");

        let sink = if self.direction {
            &mut guard.outbound_buffer
        } else {
            &mut guard.inbound_buffer
        };
        sink.get_mut().extend_from_slice(buf);

        guard.wakers.drain(..).for_each(|waker| waker.wake());

        Poll::Ready(Ok(buf.len()))
    }

    /// Delegates to the flush of this end's outgoing buffer.
    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let mut guard = self.inner.lock().expect("mutex");

        let sink = if self.direction {
            &mut guard.outbound_buffer
        } else {
            &mut guard.inbound_buffer
        };

        Pin::new(sink).poll_flush(cx)
    }

    /// Delegates to the shutdown of this end's outgoing buffer.
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let mut guard = self.inner.lock().expect("mutex");

        let sink = if self.direction {
            &mut guard.outbound_buffer
        } else {
            &mut guard.inbound_buffer
        };

        Pin::new(sink).poll_shutdown(cx)
    }
}
// Lets hyper's client use `TestConnection` as a transport.
impl Connection for TestConnection {
// Reports default connection metadata.
fn connected(&self) -> Connected {
Connected::new()
}
}
#[tokio::test]
async fn test_memory_connection() {
let client: hyper::Client<_, hyper::Body> = hyper::Client::builder().build(TestConnector);
let response = client
.get("http://localhost".parse().unwrap())
.await
.unwrap();
assert!(response.status().is_success());
let bytes = hyper::body::to_bytes(response.into_body()).await.unwrap();
assert_eq!(&bytes[..], b"Hello World");
}

View File

@@ -1,5 +1,6 @@
use pyo3::prelude::*;
pub mod http;
pub mod push;
/// Returns the hash of all the rust source files at the time it was compiled.
@@ -26,6 +27,7 @@ fn synapse_rust(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(get_rust_file_digest, m)?)?;
push::register_module(py, m)?;
http::register_module(py, m)?;
Ok(())
}

View File

@@ -0,0 +1,16 @@
from typing import Dict, List, Optional
class MatrixResponse:
    # Response object returned by `HttpClient.request`.

    # Numeric HTTP status code.
    code: int
    # Reason phrase for the status code.
    phrase: str
    # Full response body.
    content: bytes
    # Header name -> raw header value. The Rust side exposes the values as
    # `bytes`, not `str`.
    headers: Dict[str, bytes]
class HttpClient:
    # Type stub for the Rust-implemented HTTP client class.

    # Sends one request and resolves to the fully-read response.
    #   url:     target URL.
    #   method:  HTTP method name.
    #   headers: header name -> list of values to send for that name.
    #   body:    request body, or None for an empty body.
    async def request(
        self,
        url: str,
        method: str,
        headers: Dict[bytes, List[bytes]],
        body: Optional[bytes],
    ) -> MatrixResponse: ...

View File

@@ -29,7 +29,7 @@ if sys.version_info < (3, 7):
sys.exit(1)
# Allow using the asyncio reactor via env var.
if strtobool(os.environ.get("SYNAPSE_ASYNC_IO_REACTOR", "0")):
if strtobool(os.environ.get("SYNAPSE_ASYNC_IO_REACTOR", "0")) or True:
from incremental import Version
import twisted

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import asyncio
import cgi
import codecs
import logging
@@ -42,14 +43,18 @@ from canonicaljson import encode_canonical_json
from prometheus_client import Counter
from signedjson.sign import sign_json
from typing_extensions import Literal
from zope.interface import implementer
from twisted.internet import defer
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import IReactorTime
from twisted.internet.protocol import Protocol
from twisted.internet.task import Cooperator
from twisted.web.client import ResponseFailed
from twisted.internet.testing import StringTransport
from twisted.python.failure import Failure
from twisted.web.client import Response, ResponseDone, ResponseFailed
from twisted.web.http_headers import Headers
from twisted.web.iweb import IBodyProducer, IResponse
from twisted.web.iweb import UNKNOWN_LENGTH, IBodyProducer, IResponse
import synapse.metrics
import synapse.util.retryutils
@@ -75,6 +80,7 @@ from synapse.http.types import QueryParams
from synapse.logging import opentracing
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.synapse_rust.http import HttpClient
from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import AwakenableSleeper, timeout_deferred
@@ -199,6 +205,33 @@ class JsonParser(ByteParser[Union[JsonDict, list]]):
return json_decoder.decode(self._buffer.getvalue())
@attr.s(auto_attribs=True)
@implementer(IResponse)
class RustResponse:
    # Wraps an already fully-read response body so it can be handed to code
    # that expects a Twisted `IResponse`.

    version: tuple
    code: int
    phrase: bytes
    headers: Headers
    # NOTE(review): `UNKNOWN_LENGTH` is imported from `twisted.web.iweb` as a
    # value, not a type - confirm type checkers accept it inside `Union`.
    length: Union[int, UNKNOWN_LENGTH]
    # request: Optional[IClientRequest]
    # previousResponse: Optional[IResponse]
    # The complete response body.
    _data: bytes

    def deliverBody(self, protocol: Protocol) -> None:
        # The body is already in memory, so it is delivered in one call and the
        # protocol is then told the response is done.
        protocol.dataReceived(self._data)
        protocol.connectionLost(Failure(ResponseDone("Response body fully received")))

    def setPreviousResponse(self, response: IResponse) -> None:
        # No previous response is recorded.
        pass
async def _handle_response(
reactor: IReactorTime,
timeout_sec: float,
@@ -372,6 +405,8 @@ class MatrixFederationHttpClient:
self._sleeper = AwakenableSleeper(self.reactor)
self._rust_client = HttpClient()
def wake_destination(self, destination: str) -> None:
"""Called when the remote server may have come back online."""
@@ -556,11 +591,8 @@ class MatrixFederationHttpClient:
destination_bytes, method_bytes, url_to_sign_bytes, json
)
data = encode_canonical_json(json)
producer: Optional[IBodyProducer] = QuieterFileBodyProducer(
BytesIO(data), cooperator=self._cooperator
)
else:
producer = None
data = None
auth_headers = self.build_auth_headers(
destination_bytes, method_bytes, url_to_sign_bytes
)
@@ -591,23 +623,33 @@ class MatrixFederationHttpClient:
# * The `Deferred` that joins the forks back together is
# wrapped in `make_deferred_yieldable` to restore the
# logging context regardless of the path taken.
request_deferred = run_in_background(
self.agent.request,
method_bytes,
url_bytes,
headers=Headers(headers_dict),
bodyProducer=producer,
)
request_deferred = timeout_deferred(
request_deferred,
timeout=_sec_timeout,
reactor=self.reactor,
)
# request_deferred = run_in_background(
# self._rust_client.request,
# url_str,
# request.method,
# headers_dict,
# data,
# )
# request_deferred = timeout_deferred(
# request_deferred,
# timeout=_sec_timeout,
# reactor=self.reactor,
# )
response = await make_deferred_yieldable(request_deferred)
# response = await make_deferred_yieldable(request_deferred)
response_d = run_in_background(
self._rust_client.request,
url_str,
request.method,
headers_dict,
data,
)
response = await make_deferred_yieldable(response_d)
except DNSLookupError as e:
raise RequestSendFailed(e, can_retry=retry_on_dns_fail) from e
except Exception as e:
logger.exception("ERROR")
raise RequestSendFailed(e, can_retry=True) from e
incoming_responses_counter.labels(
@@ -615,7 +657,7 @@ class MatrixFederationHttpClient:
).inc()
set_tag(tags.HTTP_STATUS_CODE, response.code)
response_phrase = response.phrase.decode("ascii", errors="replace")
response_phrase = response.phrase
if 200 <= response.code < 300:
logger.debug(
@@ -635,25 +677,7 @@ class MatrixFederationHttpClient:
)
# :'(
# Update transactions table?
d = treq.content(response)
d = timeout_deferred(
d, timeout=_sec_timeout, reactor=self.reactor
)
try:
body = await make_deferred_yieldable(d)
except Exception as e:
# Eh, we're already going to raise an exception so lets
# ignore if this fails.
logger.warning(
"{%s} [%s] Failed to get error response: %s %s: %s",
request.txn_id,
request.destination,
request.method,
url_str,
_flatten_response_never_received(e),
)
body = None
body = response.content
exc = HttpResponseException(
response.code, response_phrase, body
@@ -715,7 +739,19 @@ class MatrixFederationHttpClient:
_flatten_response_never_received(e),
)
raise
return response
headers = Headers()
for key, value in response.headers.items():
headers.addRawHeader(key, value)
return RustResponse(
("HTTP", 1, 1),
response.code,
response.phrase.encode("ascii"),
headers,
UNKNOWN_LENGTH,
response.content,
)
def build_auth_headers(
self,

View File

@@ -26,6 +26,7 @@ import logging
import threading
import typing
import warnings
from asyncio import Future
from types import TracebackType
from typing import (
TYPE_CHECKING,
@@ -814,6 +815,8 @@ def run_in_background( # type: ignore[misc]
res = defer.ensureDeferred(res)
elif isinstance(res, defer.Deferred):
pass
elif isinstance(res, Future):
res = defer.Deferred.fromFuture(res)
elif isinstance(res, Awaitable):
# `res` is probably some kind of completed awaitable, such as a `DoneAwaitable`
# or `Future` from `make_awaitable`.