Compare commits

...

1 Commits

Author SHA1 Message Date
Erik Johnston
ac05d88bcd WORKER PROXY WIP 2023-05-10 14:15:25 +01:00
7 changed files with 204 additions and 34 deletions

View File

@@ -381,6 +381,7 @@ def listen_unix(
def listen_http(
hs: "HomeServer",
listener_config: ListenerConfig,
root_resource: Resource,
version_string: str,
@@ -401,6 +402,7 @@ def listen_http(
version_string,
max_request_body_size=max_request_body_size,
reactor=reactor,
federation_agent=hs.get_federation_http_client().agent,
)
if isinstance(listener_config, TCPListenerConfig):

View File

@@ -223,6 +223,7 @@ class GenericWorkerServer(HomeServer):
root_resource = create_resource_tree(resources, OptionsResource())
_base.listen_http(
self,
listener_config,
root_resource,
self.version_string,

View File

@@ -139,6 +139,7 @@ class SynapseHomeServer(HomeServer):
root_resource = OptionsResource()
ports = listen_http(
self,
listener_config,
create_resource_tree(resources, root_resource),
self.version_string,

View File

@@ -50,7 +50,7 @@ from twisted.internet.interfaces import IReactorTime
from twisted.internet.task import Cooperator
from twisted.web.client import ResponseFailed
from twisted.web.http_headers import Headers
from twisted.web.iweb import IBodyProducer, IResponse
from twisted.web.iweb import IAgent, IBodyProducer, IResponse
import synapse.metrics
import synapse.util.retryutils
@@ -398,7 +398,7 @@ class MatrixFederationHttpClient:
# Use a BlacklistingAgentWrapper to prevent circumventing the IP
# blacklist via IP literals in server names
self.agent = BlacklistingAgentWrapper(
self.agent: IAgent = BlacklistingAgentWrapper(
federation_agent,
ip_blacklist=hs.config.server.federation_ip_range_blacklist,
)

150
synapse/http/proxy.py Normal file
View File

@@ -0,0 +1,150 @@
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import urllib.parse
from typing import TYPE_CHECKING, Any, Optional, Tuple, cast
from twisted.internet import protocol
from twisted.internet.interfaces import ITCPTransport
from twisted.internet.protocol import connectionDone
from twisted.python import failure
from twisted.python.failure import Failure
from twisted.web.client import ResponseDone
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent, IResponse
from twisted.web.resource import IResource
from twisted.web.server import Site
from synapse.http import QuieterFileBodyProducer
from synapse.http.server import _AsyncResource
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.types import ISynapseReactor
from synapse.util.async_helpers import timeout_deferred
if TYPE_CHECKING:
from synapse.http.site import SynapseRequest
logger = logging.getLogger(__name__)
class ProxyResource(_AsyncResource):
    """Resource that relays a `matrix://` request through the federation agent.

    The incoming request's method, URI, body and a small allow-list of headers
    are passed to `federation_agent.request`. The upstream status code and
    headers are copied onto the local request and the upstream body is
    streamed back via `_ProxyResponseBody`. Any failure is reported to the
    client as an empty 502 response.
    """

    isLeaf = True

    # Request headers that are copied from the incoming request to the
    # upstream request; everything else is dropped.
    _FORWARDED_REQUEST_HEADERS = (b"User-Agent", b"Authorization", b"Content-Type")

    # Seconds to wait for the deferred returned by the agent before timing out.
    _REQUEST_TIMEOUT_SECONDS = 90

    def __init__(self, reactor: ISynapseReactor, federation_agent: IAgent):
        """
        Args:
            reactor: reactor used to schedule the upstream request timeout.
            federation_agent: agent used to make the upstream request.
        """
        # NOTE(review): the positional `True` is presumably `extract_context`
        # on `_AsyncResource.__init__` -- confirm against the base class.
        super().__init__(True)
        self.reactor = reactor
        self.agent = federation_agent

    async def _async_render(self, request: "SynapseRequest") -> Tuple[int, Any]:
        """Send the request upstream and return `(status code, IResponse)`.

        Raises:
            ValueError: if the request URI does not start with `matrix://`.
        """
        # The URI comes straight from the client, so validate it explicitly
        # rather than with `assert` (which is stripped when running with -O).
        if not request.uri.startswith(b"matrix://"):
            raise ValueError("Proxy request URI must start with matrix://")

        logger.info("Got proxy request %s", request.uri)

        headers = Headers()
        for header_name in self._FORWARDED_REQUEST_HEADERS:
            header_value = request.getHeader(header_name)
            if header_value:
                headers.addRawHeader(header_name, header_value)

        request_deferred = run_in_background(
            self.agent.request,
            request.method,
            request.uri,
            headers=headers,
            bodyProducer=QuieterFileBodyProducer(request.content),
        )
        request_deferred = timeout_deferred(
            request_deferred,
            timeout=self._REQUEST_TIMEOUT_SECONDS,
            reactor=self.reactor,
        )
        response = await make_deferred_yieldable(request_deferred)

        logger.info("Got proxy response %s", response.code)

        # The response object itself is handed to `_send_response`, which
        # streams the body rather than buffering it here.
        return response.code, response

    def _send_response(
        self,
        request: "SynapseRequest",
        code: int,
        response_object: Any,
    ) -> None:
        """Copy the upstream status and headers, then stream the body.

        Args:
            request: the local request to respond to.
            code: HTTP status code to send.
            response_object: the upstream `IResponse` returned by
                `_async_render`.
        """
        response = cast(IResponse, response_object)

        request.setResponseCode(code)

        # Copy headers.
        for k, v in response.headers.getAllRawHeaders():
            request.responseHeaders.setRawHeaders(k, v)

        # `_ProxyResponseBody` finishes (or aborts) the local request once the
        # upstream body has been delivered.
        response.deliverBody(_ProxyResponseBody(request))

    def _send_error_response(
        self,
        f: failure.Failure,
        request: "SynapseRequest",
    ) -> None:
        """Respond with an empty 502, unless the request can no longer be answered.

        Args:
            f: the failure that occurred while proxying (not included in the
                response).
            request: the local request to respond to.
        """
        # Calling `finish()` after the client has gone away raises in Twisted,
        # and finishing twice is an error too, so there is nothing to send.
        if request.finished or request._disconnected:
            return

        request.setResponseCode(502)
        request.finish()
class _ProxyResponseBody(protocol.Protocol):
    """Protocol that writes an upstream response body to a local request.

    Each chunk received from upstream is written to the wrapped request. When
    the upstream body ends, the local request is finished; if the upstream
    connection failed, the local connection is aborted instead.
    """

    # Set by Twisted when the protocol is connected; None until then.
    transport: Optional[ITCPTransport] = None

    def __init__(self, request: "SynapseRequest") -> None:
        """
        Args:
            request: the local request that the upstream body is written to.
        """
        self._request = request

    def dataReceived(self, data: bytes) -> None:
        """Forward a chunk of the upstream body to the local request."""
        if self._request._disconnected:
            # The client has gone away, so the data would be discarded anyway.
            # Drop it, and stop the upstream transfer if we are able to.
            if self.transport is not None:
                self.transport.abortConnection()
            return

        self._request.write(data)

    def connectionLost(self, reason: Failure = connectionDone) -> None:
        """Finish or abort the local request once the upstream body ends."""
        # Nothing to do if the response was already completed, or if the
        # client disconnected: `finish()` raises on a disconnected request.
        if self._request.finished or self._request._disconnected:
            return

        if reason.check(ResponseDone, PotentialDataLoss):
            # `PotentialDataLoss` means the upstream body was delimited by
            # the connection closing, so we cannot tell whether it was
            # truncated. Treat it as complete, as for `ResponseDone`.
            self._request.finish()
        else:
            # The upstream transfer failed part-way through. The status and
            # headers have already been sent, so the only way to signal the
            # error to the client is to drop the connection.
            local_transport = self._request.transport
            if local_transport is not None:
                local_transport.abortConnection()
class ProxySite(Site):
    """A `Site` that routes requests with a `matrix` URI scheme to a proxy.

    Requests whose URI scheme is `matrix` are handled by a `ProxyResource`;
    all other requests go through the normal `Site` resource lookup.
    """

    def __init__(
        self,
        resource: IResource,
        reactor: ISynapseReactor,
        federation_agent: IAgent,
    ):
        """
        Args:
            resource: root resource used for non-proxy requests.
            reactor: reactor passed to `Site` and to the proxy resource.
            federation_agent: agent the proxy resource uses for upstream
                requests.
        """
        super().__init__(resource, reactor=reactor)

        self._proxy_resource = ProxyResource(reactor, federation_agent)

    def getResourceFor(self, request: "SynapseRequest") -> IResource:
        """Return the proxy resource for `matrix` URIs, else defer to `Site`."""
        try:
            scheme = urllib.parse.urlparse(request.uri).scheme
        except ValueError:
            # `urlparse` raises ValueError (including UnicodeDecodeError) for
            # malformed or non-ASCII URIs. Such a request cannot be a proxy
            # request, so let the normal lookup deal with it.
            scheme = b""

        if scheme == b"matrix":
            return self._proxy_resource

        return super().getResourceFor(request)

View File

@@ -18,6 +18,7 @@ import html
import logging
import types
import urllib
import urllib.parse
from http import HTTPStatus
from http.client import FOUND
from inspect import isawaitable
@@ -65,7 +66,6 @@ from synapse.api.errors import (
UnrecognizedRequestError,
)
from synapse.config.homeserver import HomeServerConfig
from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread, preserve_fn, run_in_background
from synapse.logging.opentracing import active_span, start_active_span, trace_servlet
from synapse.util import json_encoder
@@ -76,6 +76,7 @@ from synapse.util.iterutils import chunk_seq
if TYPE_CHECKING:
import opentracing
from synapse.http.site import SynapseRequest
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -102,7 +103,7 @@ HTTP_STATUS_REQUEST_CANCELLED = 499
def return_json_error(
f: failure.Failure, request: SynapseRequest, config: Optional[HomeServerConfig]
f: failure.Failure, request: "SynapseRequest", config: Optional[HomeServerConfig]
) -> None:
"""Sends a JSON error response to clients."""
@@ -214,8 +215,8 @@ def return_html_error(
def wrap_async_request_handler(
h: Callable[["_AsyncResource", SynapseRequest], Awaitable[None]]
) -> Callable[["_AsyncResource", SynapseRequest], "defer.Deferred[None]"]:
h: Callable[["_AsyncResource", "SynapseRequest"], Awaitable[None]]
) -> Callable[["_AsyncResource", "SynapseRequest"], "defer.Deferred[None]"]:
"""Wraps an async request handler so that it calls request.processing.
This helps ensure that work done by the request handler after the request is completed
@@ -229,7 +230,7 @@ def wrap_async_request_handler(
"""
async def wrapped_async_request_handler(
self: "_AsyncResource", request: SynapseRequest
self: "_AsyncResource", request: "SynapseRequest"
) -> None:
with request.processing():
await h(self, request)
@@ -294,7 +295,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
self._extract_context = extract_context
def render(self, request: SynapseRequest) -> int:
def render(self, request: "SynapseRequest") -> int:
"""This gets called by twisted every time someone sends us a request."""
request.render_deferred = defer.ensureDeferred(
self._async_render_wrapper(request)
@@ -302,7 +303,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
return NOT_DONE_YET
@wrap_async_request_handler
async def _async_render_wrapper(self, request: SynapseRequest) -> None:
async def _async_render_wrapper(self, request: "SynapseRequest") -> None:
"""This is a wrapper that delegates to `_async_render` and handles
exceptions, return values, metrics, etc.
"""
@@ -320,9 +321,14 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
# of our stack, and thus gives us a sensible stack
# trace.
f = failure.Failure()
logger.exception(
"Error handling request", exc_info=(f.type, f.value, f.getTracebackObject()) # type: ignore[arg-type]
)
self._send_error_response(f, request)
async def _async_render(self, request: SynapseRequest) -> Optional[Tuple[int, Any]]:
async def _async_render(
self, request: "SynapseRequest"
) -> Optional[Tuple[int, Any]]:
"""Delegates to `_async_render_<METHOD>` methods, or returns a 400 if
no appropriate method exists. Can be overridden in sub classes for
different routing.
@@ -352,7 +358,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
@abc.abstractmethod
def _send_response(
self,
request: SynapseRequest,
request: "SynapseRequest",
code: int,
response_object: Any,
) -> None:
@@ -362,7 +368,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
def _send_error_response(
self,
f: failure.Failure,
request: SynapseRequest,
request: "SynapseRequest",
) -> None:
raise NotImplementedError()
@@ -378,7 +384,7 @@ class DirectServeJsonResource(_AsyncResource):
def _send_response(
self,
request: SynapseRequest,
request: "SynapseRequest",
code: int,
response_object: Any,
) -> None:
@@ -395,7 +401,7 @@ class DirectServeJsonResource(_AsyncResource):
def _send_error_response(
self,
f: failure.Failure,
request: SynapseRequest,
request: "SynapseRequest",
) -> None:
"""Implements _AsyncResource._send_error_response"""
return_json_error(f, request, None)
@@ -467,7 +473,7 @@ class JsonResource(DirectServeJsonResource):
)
def _get_handler_for_request(
self, request: SynapseRequest
self, request: "SynapseRequest"
) -> Tuple[ServletCallback, str, Dict[str, str]]:
"""Finds a callback method to handle the given request.
@@ -497,7 +503,7 @@ class JsonResource(DirectServeJsonResource):
# Huh. No one wanted to handle that? Fiiiiiine.
raise UnrecognizedRequestError(code=404)
async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]:
async def _async_render(self, request: "SynapseRequest") -> Tuple[int, Any]:
callback, servlet_classname, group_dict = self._get_handler_for_request(request)
request.is_render_cancellable = is_function_cancellable(callback)
@@ -529,7 +535,7 @@ class JsonResource(DirectServeJsonResource):
def _send_error_response(
self,
f: failure.Failure,
request: SynapseRequest,
request: "SynapseRequest",
) -> None:
"""Implements _AsyncResource._send_error_response"""
return_json_error(f, request, self.hs.config)
@@ -545,7 +551,7 @@ class DirectServeHtmlResource(_AsyncResource):
def _send_response(
self,
request: SynapseRequest,
request: "SynapseRequest",
code: int,
response_object: Any,
) -> None:
@@ -559,7 +565,7 @@ class DirectServeHtmlResource(_AsyncResource):
def _send_error_response(
self,
f: failure.Failure,
request: SynapseRequest,
request: "SynapseRequest",
) -> None:
"""Implements _AsyncResource._send_error_response"""
return_html_error(f, request, self.ERROR_TEMPLATE)
@@ -586,7 +592,7 @@ class UnrecognizedRequestResource(resource.Resource):
errcode of M_UNRECOGNIZED.
"""
def render(self, request: SynapseRequest) -> int:
def render(self, request: "SynapseRequest") -> int:
f = failure.Failure(UnrecognizedRequestError(code=404))
return_json_error(f, request, None)
# A response has already been sent but Twisted requires either NOT_DONE_YET
@@ -616,7 +622,7 @@ class RootRedirect(resource.Resource):
class OptionsResource(resource.Resource):
"""Responds to OPTION requests for itself and all children."""
def render_OPTIONS(self, request: SynapseRequest) -> bytes:
def render_OPTIONS(self, request: "SynapseRequest") -> bytes:
request.setResponseCode(204)
request.setHeader(b"Content-Length", b"0")
@@ -731,7 +737,7 @@ def _encode_json_bytes(json_object: object) -> bytes:
def respond_with_json(
request: SynapseRequest,
request: "SynapseRequest",
code: int,
json_object: Any,
send_cors: bool = False,
@@ -781,7 +787,7 @@ def respond_with_json(
def respond_with_json_bytes(
request: SynapseRequest,
request: "SynapseRequest",
code: int,
json_bytes: bytes,
send_cors: bool = False,
@@ -819,7 +825,7 @@ def respond_with_json_bytes(
async def _async_write_json_to_request_in_thread(
request: SynapseRequest,
request: "SynapseRequest",
json_encoder: Callable[[Any], bytes],
json_object: Any,
) -> None:
@@ -877,7 +883,7 @@ def _write_bytes_to_request(request: Request, bytes_to_write: bytes) -> None:
_ByteProducer(request, bytes_generator)
def set_cors_headers(request: SynapseRequest) -> None:
def set_cors_headers(request: "SynapseRequest") -> None:
"""Set the CORS headers so that javascript running in a web browsers can
use this API
@@ -975,7 +981,7 @@ def set_clickjacking_protection_headers(request: Request) -> None:
def respond_with_redirect(
request: SynapseRequest, url: bytes, statusCode: int = FOUND, cors: bool = False
request: "SynapseRequest", url: bytes, statusCode: int = FOUND, cors: bool = False
) -> None:
"""
Write a 302 (or other specified status code) response to the request, if it is still alive.

View File

@@ -21,25 +21,28 @@ from zope.interface import implementer
from twisted.internet.address import UNIXAddress
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IAddress, IReactorTime
from twisted.internet.interfaces import IAddress
from twisted.python.failure import Failure
from twisted.web.http import HTTPChannel
from twisted.web.iweb import IAgent
from twisted.web.resource import IResource, Resource
from twisted.web.server import Request, Site
from twisted.web.server import Request
from synapse.config.server import ListenerConfig
from synapse.http import get_request_user_agent, redact_uri
from synapse.http.proxy import ProxySite
from synapse.http.request_metrics import RequestMetrics, requests_counter
from synapse.logging.context import (
ContextRequest,
LoggingContext,
PreserveLoggingContext,
)
from synapse.types import Requester
from synapse.types import ISynapseReactor, Requester
if TYPE_CHECKING:
import opentracing
logger = logging.getLogger(__name__)
_next_request_seq = 0
@@ -102,7 +105,7 @@ class SynapseRequest(Request):
# A boolean indicating whether `render_deferred` should be cancelled if the
# client disconnects early. Expected to be set by the coroutine started by
# `Resource.render`, if rendering is asynchronous.
self.is_render_cancellable = False
self.is_render_cancellable: bool = False
global _next_request_seq
self.request_seq = _next_request_seq
@@ -596,7 +599,7 @@ class _XForwardedForAddress:
host: str
class SynapseSite(Site):
class SynapseSite(ProxySite):
"""
Synapse-specific twisted http Site
@@ -618,7 +621,8 @@ class SynapseSite(Site):
resource: IResource,
server_version_string: str,
max_request_body_size: int,
reactor: IReactorTime,
reactor: ISynapseReactor,
federation_agent: IAgent,
):
"""
@@ -633,7 +637,11 @@ class SynapseSite(Site):
dropping the connection
reactor: reactor to be used to manage connection timeouts
"""
Site.__init__(self, resource, reactor=reactor)
super().__init__(
resource=resource,
reactor=reactor,
federation_agent=federation_agent,
)
self.site_tag = site_tag
self.reactor = reactor
@@ -644,7 +652,9 @@ class SynapseSite(Site):
request_id_header = config.http_options.request_id_header
self.experimental_cors_msc3886 = config.http_options.experimental_cors_msc3886
self.experimental_cors_msc3886: bool = (
config.http_options.experimental_cors_msc3886
)
def request_factory(channel: HTTPChannel, queued: bool) -> Request:
return request_class(