Compare commits

...

12 Commits

Author SHA1 Message Date
Erik Johnston
16e8580571 Write JSON response using a producer 2021-09-21 08:58:28 +01:00
Erik Johnston
0ce9315f9f Fix tests 2021-09-17 17:30:32 +01:00
Erik Johnston
c2d84edffc Don't create temporary function 2021-09-17 17:12:38 +01:00
Erik Johnston
d1f25c6df4 Fix tests 2021-09-17 17:12:09 +01:00
Erik Johnston
6c0bc18139 Fix tests 2021-09-17 15:28:39 +01:00
Erik Johnston
8521a0c976 Newsfile 2021-09-17 15:06:33 +01:00
Erik Johnston
e369a20d0a Encode JSON responses on a thread 2021-09-17 14:58:24 +01:00
Erik Johnston
40c99c22ff Add a _write_json_to_request_in_thread 2021-09-17 14:53:02 +01:00
Erik Johnston
fbcbfb4aa4 Require SynapseRequest for respond_with_json 2021-09-17 14:53:02 +01:00
Erik Johnston
2f8abe0905 Add reactor to SynapseRequest 2021-09-17 14:52:58 +01:00
Erik Johnston
341a92b7d0 Fix SynapseRequest.site type.
There were two issues: 1) `channel` is actually private type so its hard
to type `.site`, and 2) `.site` was actually overwriting an existing
member and so we need to rename it.
2021-09-17 14:41:29 +01:00
Erik Johnston
d18c71abab Add types to http.site 2021-09-17 14:41:29 +01:00
26 changed files with 219 additions and 117 deletions

View File

@@ -0,0 +1 @@
Speed up responding with large JSON objects to requests.

View File

@@ -21,7 +21,6 @@ import types
import urllib
from http import HTTPStatus
from inspect import isawaitable
from io import BytesIO
from typing import (
Any,
Awaitable,
@@ -37,7 +36,7 @@ from typing import (
)
import jinja2
from canonicaljson import iterencode_canonical_json
from canonicaljson import encode_canonical_json
from typing_extensions import Protocol
from zope.interface import implementer
@@ -45,7 +44,7 @@ from twisted.internet import defer, interfaces
from twisted.python import failure
from twisted.web import resource
from twisted.web.server import NOT_DONE_YET, Request
from twisted.web.static import File, NoRangeStaticProducer
from twisted.web.static import File
from twisted.web.util import redirectTo
from synapse.api.errors import (
@@ -56,10 +55,11 @@ from synapse.api.errors import (
UnrecognizedRequestError,
)
from synapse.http.site import SynapseRequest
from synapse.logging.context import preserve_fn
from synapse.logging.context import defer_to_thread, preserve_fn, run_in_background
from synapse.logging.opentracing import trace_servlet
from synapse.util import json_encoder
from synapse.util.caches import intern_dict
from synapse.util.iterutils import chunk_seq
logger = logging.getLogger(__name__)
@@ -320,7 +320,7 @@ class DirectServeJsonResource(_AsyncResource):
def _send_response(
self,
request: Request,
request: SynapseRequest,
code: int,
response_object: Any,
):
@@ -620,16 +620,15 @@ class _ByteProducer:
self._request = None
def _encode_json_bytes(json_object: Any) -> Iterator[bytes]:
def _encode_json_bytes(json_object: Any) -> bytes:
"""
Encode an object into JSON. Returns an iterator of bytes.
"""
for chunk in json_encoder.iterencode(json_object):
yield chunk.encode("utf-8")
return json_encoder.encode(json_object).encode("utf-8")
def respond_with_json(
request: Request,
request: SynapseRequest,
code: int,
json_object: Any,
send_cors: bool = False,
@@ -659,7 +658,7 @@ def respond_with_json(
return None
if canonical_json:
encoder = iterencode_canonical_json
encoder = encode_canonical_json
else:
encoder = _encode_json_bytes
@@ -670,7 +669,9 @@ def respond_with_json(
if send_cors:
set_cors_headers(request)
_ByteProducer(request, encoder(json_object))
run_in_background(
_async_write_json_to_request_in_thread, request, encoder, json_object
)
return NOT_DONE_YET
@@ -706,15 +707,35 @@ def respond_with_json_bytes(
if send_cors:
set_cors_headers(request)
# note that this is zero-copy (the bytesio shares a copy-on-write buffer with
# the original `bytes`).
bytes_io = BytesIO(json_bytes)
producer = NoRangeStaticProducer(request, bytes_io)
producer.start()
_write_json_bytes_to_request(request, json_bytes)
return NOT_DONE_YET
def _write_json_bytes_to_request(request: Request, json_bytes: bytes) -> None:
"""Writes the JSON bytes to the request using an appropriate producer.
Note: This should be used instead of `Request.write` to correctly handle
large response bodies.
"""
# The problem with dumping all of the json response into the `Request`
# object at once (via `Request.write`) is that doing so starts the timeout
# for the next request to be received: so if it takes longer than 60s to
# stream back the response to the client, the client never gets it.
#
# The correct solution is to use a Producer; then the timeout is only
# started once all of the content is sent over the TCP connection.
# To make sure we don't write the whole of the json at once we split it up
# into chunks.
chunk_size = 4096
bytes_generator = chunk_seq(json_bytes, chunk_size)
# We use a `_ByteProducer` here rather than `NoRangeStaticProducer` as the
# unit tests can't cope with being given a pull producer.
_ByteProducer(request, bytes_generator)
def set_cors_headers(request: Request):
"""Set the CORS headers so that javascript running in a web browsers can
use this API
@@ -809,3 +830,24 @@ def finish_request(request: Request):
request.finish()
except RuntimeError as e:
logger.info("Connection disconnected before response was written: %r", e)
async def _async_write_json_to_request_in_thread(
request: SynapseRequest,
json_encoder: Callable[[Any], bytes],
json_object: Any,
):
"""Encodes the given JSON object on a thread and then writes it to the
request.
This is done so that encoding large JSON objects doesn't block the reactor
thread.
Note: We don't use JsonEncoder.iterencode here as that falls back to the
Python implementation (rather than the C backend), which is *much* more
expensive.
"""
json_str = await defer_to_thread(request.reactor, json_encoder, json_object)
_write_json_bytes_to_request(request, json_str)

View File

@@ -14,14 +14,15 @@
import contextlib
import logging
import time
from typing import Optional, Tuple, Union
from typing import Generator, Optional, Tuple, Union
import attr
from zope.interface import implementer
from twisted.internet.interfaces import IAddress, IReactorTime
from twisted.python.failure import Failure
from twisted.web.resource import IResource
from twisted.web.http import HTTPChannel
from twisted.web.resource import IResource, Resource
from twisted.web.server import Request, Site
from synapse.config.server import ListenerConfig
@@ -61,10 +62,18 @@ class SynapseRequest(Request):
logcontext: the log context for this request
"""
def __init__(self, channel, *args, max_request_body_size=1024, **kw):
Request.__init__(self, channel, *args, **kw)
def __init__(
self,
channel: HTTPChannel,
site: "SynapseSite",
*args,
max_request_body_size: int = 1024,
**kw,
):
super().__init__(channel, *args, **kw)
self._max_request_body_size = max_request_body_size
self.site: SynapseSite = channel.site
self.synapse_site = site
self.reactor = site.reactor
self._channel = channel # this is used by the tests
self.start_time = 0.0
@@ -83,13 +92,13 @@ class SynapseRequest(Request):
self._is_processing = False
# the time when the asynchronous request handler completed its processing
self._processing_finished_time = None
self._processing_finished_time: Optional[float] = None
# what time we finished sending the response to the client (or the connection
# dropped)
self.finish_time = None
self.finish_time: Optional[float] = None
def __repr__(self):
def __repr__(self) -> str:
# We overwrite this so that we don't log ``access_token``
return "<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>" % (
self.__class__.__name__,
@@ -97,10 +106,10 @@ class SynapseRequest(Request):
self.get_method(),
self.get_redacted_uri(),
self.clientproto.decode("ascii", errors="replace"),
self.site.site_tag,
self.synapse_site.site_tag,
)
def handleContentChunk(self, data):
def handleContentChunk(self, data: bytes) -> None:
# we should have a `content` by now.
assert self.content, "handleContentChunk() called before gotLength()"
if self.content.tell() + len(data) > self._max_request_body_size:
@@ -139,7 +148,7 @@ class SynapseRequest(Request):
# If there's no authenticated entity, it was the requester.
self.logcontext.request.authenticated_entity = authenticated_entity or requester
def get_request_id(self):
def get_request_id(self) -> str:
return "%s-%i" % (self.get_method(), self.request_seq)
def get_redacted_uri(self) -> str:
@@ -205,7 +214,7 @@ class SynapseRequest(Request):
return None, None
def render(self, resrc):
def render(self, resrc: Resource) -> None:
# this is called once a Resource has been found to serve the request; in our
# case the Resource in question will normally be a JsonResource.
@@ -216,7 +225,7 @@ class SynapseRequest(Request):
request=ContextRequest(
request_id=request_id,
ip_address=self.getClientIP(),
site_tag=self.site.site_tag,
site_tag=self.synapse_site.site_tag,
# The requester is going to be unknown at this point.
requester=None,
authenticated_entity=None,
@@ -228,7 +237,7 @@ class SynapseRequest(Request):
)
# override the Server header which is set by twisted
self.setHeader("Server", self.site.server_version_string)
self.setHeader("Server", self.synapse_site.server_version_string)
with PreserveLoggingContext(self.logcontext):
# we start the request metrics timer here with an initial stab
@@ -247,7 +256,7 @@ class SynapseRequest(Request):
requests_counter.labels(self.get_method(), self.request_metrics.name).inc()
@contextlib.contextmanager
def processing(self):
def processing(self) -> Generator[None, None, None]:
"""Record the fact that we are processing this request.
Returns a context manager; the correct way to use this is:
@@ -282,7 +291,7 @@ class SynapseRequest(Request):
if self.finish_time is not None:
self._finished_processing()
def finish(self):
def finish(self) -> None:
"""Called when all response data has been written to this Request.
Overrides twisted.web.server.Request.finish to record the finish time and do
@@ -295,7 +304,7 @@ class SynapseRequest(Request):
with PreserveLoggingContext(self.logcontext):
self._finished_processing()
def connectionLost(self, reason):
def connectionLost(self, reason: Union[Failure, Exception]) -> None:
"""Called when the client connection is closed before the response is written.
Overrides twisted.web.server.Request.connectionLost to record the finish time and
@@ -327,7 +336,7 @@ class SynapseRequest(Request):
if not self._is_processing:
self._finished_processing()
def _started_processing(self, servlet_name):
def _started_processing(self, servlet_name: str) -> None:
"""Record the fact that we are processing this request.
This will log the request's arrival. Once the request completes,
@@ -346,17 +355,19 @@ class SynapseRequest(Request):
self.start_time, name=servlet_name, method=self.get_method()
)
self.site.access_logger.debug(
self.synapse_site.access_logger.debug(
"%s - %s - Received request: %s %s",
self.getClientIP(),
self.site.site_tag,
self.synapse_site.site_tag,
self.get_method(),
self.get_redacted_uri(),
)
def _finished_processing(self):
def _finished_processing(self) -> None:
"""Log the completion of this request and update the metrics"""
assert self.logcontext is not None
assert self.finish_time is not None
usage = self.logcontext.get_resource_usage()
if self._processing_finished_time is None:
@@ -386,13 +397,13 @@ class SynapseRequest(Request):
if authenticated_entity:
requester = f"{authenticated_entity}|{requester}"
self.site.access_logger.log(
self.synapse_site.access_logger.log(
log_level,
"%s - %s - {%s}"
" Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
' %sB %s "%s %s %s" "%s" [%d dbevts]',
self.getClientIP(),
self.site.site_tag,
self.synapse_site.site_tag,
requester,
processing_time,
response_send_time,
@@ -437,7 +448,7 @@ class XForwardedForRequest(SynapseRequest):
_forwarded_for: "Optional[_XForwardedForAddress]" = None
_forwarded_https: bool = False
def requestReceived(self, command, path, version):
def requestReceived(self, command: bytes, path: bytes, version: bytes) -> None:
# this method is called by the Channel once the full request has been
# received, to dispatch the request to a resource.
# We can use it to set the IP address and protocol according to the
@@ -445,7 +456,7 @@ class XForwardedForRequest(SynapseRequest):
self._process_forwarded_headers()
return super().requestReceived(command, path, version)
def _process_forwarded_headers(self):
def _process_forwarded_headers(self) -> None:
headers = self.requestHeaders.getRawHeaders(b"x-forwarded-for")
if not headers:
return
@@ -470,7 +481,7 @@ class XForwardedForRequest(SynapseRequest):
)
self._forwarded_https = True
def isSecure(self):
def isSecure(self) -> bool:
if self._forwarded_https:
return True
return super().isSecure()
@@ -520,7 +531,7 @@ class SynapseSite(Site):
site_tag: str,
config: ListenerConfig,
resource: IResource,
server_version_string,
server_version_string: str,
max_request_body_size: int,
reactor: IReactorTime,
):
@@ -540,19 +551,23 @@ class SynapseSite(Site):
Site.__init__(self, resource, reactor=reactor)
self.site_tag = site_tag
self.reactor = reactor
assert config.http_options is not None
proxied = config.http_options.x_forwarded
request_class = XForwardedForRequest if proxied else SynapseRequest
def request_factory(channel, queued) -> Request:
def request_factory(channel: HTTPChannel, queued: bool) -> Request:
return request_class(
channel, max_request_body_size=max_request_body_size, queued=queued
channel,
self,
max_request_body_size=max_request_body_size,
queued=queued,
)
self.requestFactory = request_factory # type: ignore
self.access_logger = logging.getLogger(logger_name)
self.server_version_string = server_version_string.encode("ascii")
def log(self, request):
def log(self, request: SynapseRequest) -> None:
pass

View File

@@ -184,7 +184,7 @@ class EmailPusher(Pusher):
should_notify_at = max(notif_ready_at, room_ready_at)
if should_notify_at < self.clock.time_msec():
if should_notify_at <= self.clock.time_msec():
# one of our notifications is ready for sending, so we send
# *one* email updating the user on their notifications,
# we then consider all previously outstanding notifications

View File

@@ -17,12 +17,11 @@ from typing import TYPE_CHECKING, Dict
from signedjson.sign import sign_json
from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_integer, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import yieldable_gather_results
@@ -100,7 +99,7 @@ class RemoteKey(DirectServeJsonResource):
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
self.config = hs.config
async def _async_render_GET(self, request: Request) -> None:
async def _async_render_GET(self, request: SynapseRequest) -> None:
assert request.postpath is not None
if len(request.postpath) == 1:
(server,) = request.postpath
@@ -117,7 +116,7 @@ class RemoteKey(DirectServeJsonResource):
await self.query_keys(request, query, query_remote_on_cache_miss=True)
async def _async_render_POST(self, request: Request) -> None:
async def _async_render_POST(self, request: SynapseRequest) -> None:
content = parse_json_object_from_request(request)
query = content["server_keys"]
@@ -126,7 +125,7 @@ class RemoteKey(DirectServeJsonResource):
async def query_keys(
self,
request: Request,
request: SynapseRequest,
query: JsonDict,
query_remote_on_cache_miss: bool = False,
) -> None:

View File

@@ -27,6 +27,7 @@ from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError, cs_error
from synapse.http.server import finish_request, respond_with_json
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.util.stringutils import is_ascii
@@ -74,7 +75,7 @@ def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
)
def respond_404(request: Request) -> None:
def respond_404(request: SynapseRequest) -> None:
respond_with_json(
request,
404,
@@ -84,7 +85,7 @@ def respond_404(request: Request) -> None:
async def respond_with_file(
request: Request,
request: SynapseRequest,
media_type: str,
file_path: str,
file_size: Optional[int] = None,
@@ -221,7 +222,7 @@ def _can_encode_filename_as_token(x: str) -> bool:
async def respond_with_responder(
request: Request,
request: SynapseRequest,
responder: "Optional[Responder]",
media_type: str,
file_size: Optional[int],

View File

@@ -16,8 +16,6 @@
from typing import TYPE_CHECKING
from twisted.web.server import Request
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.site import SynapseRequest
@@ -39,5 +37,5 @@ class MediaConfigResource(DirectServeJsonResource):
await self.auth.get_user_by_req(request)
respond_with_json(request, 200, self.limits_dict, send_cors=True)
async def _async_render_OPTIONS(self, request: Request) -> None:
async def _async_render_OPTIONS(self, request: SynapseRequest) -> None:
respond_with_json(request, 200, {}, send_cors=True)

View File

@@ -15,10 +15,9 @@
import logging
from typing import TYPE_CHECKING
from twisted.web.server import Request
from synapse.http.server import DirectServeJsonResource, set_cors_headers
from synapse.http.servlet import parse_boolean
from synapse.http.site import SynapseRequest
from ._base import parse_media_id, respond_404
@@ -37,7 +36,7 @@ class DownloadResource(DirectServeJsonResource):
self.media_repo = media_repo
self.server_name = hs.hostname
async def _async_render_GET(self, request: Request) -> None:
async def _async_render_GET(self, request: SynapseRequest) -> None:
set_cors_headers(request)
request.setHeader(
b"Content-Security-Policy",

View File

@@ -23,7 +23,6 @@ import twisted.internet.error
import twisted.web.http
from twisted.internet.defer import Deferred
from twisted.web.resource import Resource
from twisted.web.server import Request
from synapse.api.errors import (
FederationDeniedError,
@@ -34,6 +33,7 @@ from synapse.api.errors import (
)
from synapse.config._base import ConfigError
from synapse.config.repository import ThumbnailRequirement
from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID
@@ -187,7 +187,7 @@ class MediaRepository:
return "mxc://%s/%s" % (self.server_name, media_id)
async def get_local_media(
self, request: Request, media_id: str, name: Optional[str]
self, request: SynapseRequest, media_id: str, name: Optional[str]
) -> None:
"""Responds to requests for local media, if exists, or returns 404.
@@ -221,7 +221,11 @@ class MediaRepository:
)
async def get_remote_media(
self, request: Request, server_name: str, media_id: str, name: Optional[str]
self,
request: SynapseRequest,
server_name: str,
media_id: str,
name: Optional[str],
) -> None:
"""Respond to requests for remote media.

View File

@@ -29,7 +29,6 @@ import attr
from twisted.internet.defer import Deferred
from twisted.internet.error import DNSLookupError
from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError
from synapse.http.client import SimpleHttpClient
@@ -167,7 +166,7 @@ class PreviewUrlResource(DirectServeJsonResource):
self._start_expire_url_cache_data, 10 * 1000
)
async def _async_render_OPTIONS(self, request: Request) -> None:
async def _async_render_OPTIONS(self, request: SynapseRequest) -> None:
request.setHeader(b"Allow", b"OPTIONS, GET")
respond_with_json(request, 200, {}, send_cors=True)

View File

@@ -17,11 +17,10 @@
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from twisted.web.server import Request
from synapse.api.errors import SynapseError
from synapse.http.server import DirectServeJsonResource, set_cors_headers
from synapse.http.servlet import parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.media.v1.media_storage import MediaStorage
from ._base import (
@@ -57,7 +56,7 @@ class ThumbnailResource(DirectServeJsonResource):
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.server_name = hs.hostname
async def _async_render_GET(self, request: Request) -> None:
async def _async_render_GET(self, request: SynapseRequest) -> None:
set_cors_headers(request)
server_name, media_id, _ = parse_media_id(request)
width = parse_integer(request, "width", required=True)
@@ -88,7 +87,7 @@ class ThumbnailResource(DirectServeJsonResource):
async def _respond_local_thumbnail(
self,
request: Request,
request: SynapseRequest,
media_id: str,
width: int,
height: int,
@@ -121,7 +120,7 @@ class ThumbnailResource(DirectServeJsonResource):
async def _select_or_generate_local_thumbnail(
self,
request: Request,
request: SynapseRequest,
media_id: str,
desired_width: int,
desired_height: int,
@@ -186,7 +185,7 @@ class ThumbnailResource(DirectServeJsonResource):
async def _select_or_generate_remote_thumbnail(
self,
request: Request,
request: SynapseRequest,
server_name: str,
media_id: str,
desired_width: int,
@@ -249,7 +248,7 @@ class ThumbnailResource(DirectServeJsonResource):
async def _respond_remote_thumbnail(
self,
request: Request,
request: SynapseRequest,
server_name: str,
media_id: str,
width: int,
@@ -280,7 +279,7 @@ class ThumbnailResource(DirectServeJsonResource):
async def _select_and_respond_with_thumbnail(
self,
request: Request,
request: SynapseRequest,
desired_width: int,
desired_height: int,
desired_method: str,

View File

@@ -16,8 +16,6 @@
import logging
from typing import IO, TYPE_CHECKING, Dict, List, Optional
from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_bytes_from_args
@@ -46,7 +44,7 @@ class UploadResource(DirectServeJsonResource):
self.max_upload_size = hs.config.max_upload_size
self.clock = hs.get_clock()
async def _async_render_OPTIONS(self, request: Request) -> None:
async def _async_render_OPTIONS(self, request: SynapseRequest) -> None:
respond_with_json(request, 200, {}, send_cors=True)
async def _async_render_POST(self, request: SynapseRequest) -> None:

View File

@@ -21,13 +21,28 @@ from typing import (
Iterable,
Iterator,
Mapping,
Sequence,
Set,
Sized,
Tuple,
TypeVar,
)
from typing_extensions import Protocol
T = TypeVar("T")
S = TypeVar("S", bound="_SelfSlice")
class _SelfSlice(Sized, Protocol):
"""A helper protocol that matches types where taking a slice results in the
same type being returned.
This is more specific than `Sequence`, which allows another `Sequence` to be
returned.
"""
def __getitem__(self: S, i: slice) -> S:
...
def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]:
@@ -46,7 +61,7 @@ def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]:
return iter(lambda: tuple(islice(sourceiter, size)), ())
def chunk_seq(iseq: Sequence[T], maxlen: int) -> Iterable[Sequence[T]]:
def chunk_seq(iseq: S, maxlen: int) -> Iterator[S]:
"""Split the given sequence into chunks of the given size
The last chunk may be shorter than the given size.

View File

@@ -45,7 +45,9 @@ class AdditionalResourceTests(HomeserverTestCase):
handler = _AsyncTestCustomEndpoint({}, None).handle_request
resource = AdditionalResource(self.hs, handler)
channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
channel = make_request(
self.reactor, FakeSite(resource, self.reactor), "GET", "/"
)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_async"})
@@ -54,7 +56,9 @@ class AdditionalResourceTests(HomeserverTestCase):
handler = _SyncTestCustomEndpoint({}, None).handle_request
resource = AdditionalResource(self.hs, handler)
channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
channel = make_request(
self.reactor, FakeSite(resource, self.reactor), "GET", "/"
)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_sync"})

View File

@@ -152,7 +152,8 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
site = Mock(spec=["site_tag", "server_version_string", "getResourceFor"])
site.site_tag = "test-site"
site.server_version_string = "Server v1"
request = SynapseRequest(FakeChannel(site, None))
site.reactor = Mock()
request = SynapseRequest(FakeChannel(site, None), site)
# Call requestReceived to finish instantiating the object.
request.content = BytesIO()
# Partially skip some of the internal processing of SynapseRequest.

View File

@@ -68,7 +68,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
resource = hs.get_media_repository_resource().children[b"download"]
channel = make_request(
self.reactor,
FakeSite(resource),
FakeSite(resource, self.reactor),
"GET",
f"/{target}/{media_id}",
shorthand=False,

View File

@@ -201,7 +201,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
"""Ensure a piece of media is quarantined when trying to access it."""
channel = make_request(
self.reactor,
FakeSite(self.download_resource),
FakeSite(self.download_resource, self.reactor),
"GET",
server_and_media_id,
shorthand=False,
@@ -271,7 +271,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Attempt to access the media
channel = make_request(
self.reactor,
FakeSite(self.download_resource),
FakeSite(self.download_resource, self.reactor),
"GET",
server_name_and_media_id,
shorthand=False,
@@ -458,7 +458,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Attempt to access each piece of media
channel = make_request(
self.reactor,
FakeSite(self.download_resource),
FakeSite(self.download_resource, self.reactor),
"GET",
server_and_media_id_2,
shorthand=False,

View File

@@ -125,7 +125,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
# Attempt to access media
channel = make_request(
self.reactor,
FakeSite(download_resource),
FakeSite(download_resource, self.reactor),
"GET",
server_and_media_id,
shorthand=False,
@@ -164,7 +164,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
# Attempt to access media
channel = make_request(
self.reactor,
FakeSite(download_resource),
FakeSite(download_resource, self.reactor),
"GET",
server_and_media_id,
shorthand=False,
@@ -525,7 +525,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
channel = make_request(
self.reactor,
FakeSite(download_resource),
FakeSite(download_resource, self.reactor),
"GET",
server_and_media_id,
shorthand=False,

View File

@@ -2973,7 +2973,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
# Try to access a media and to create `last_access_ts`
channel = make_request(
self.reactor,
FakeSite(download_resource),
FakeSite(download_resource, self.reactor),
"GET",
server_and_media_id,
shorthand=False,

View File

@@ -312,7 +312,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Load the password reset confirmation page
channel = make_request(
self.reactor,
FakeSite(self.submit_token_resource),
FakeSite(self.submit_token_resource, self.reactor),
"GET",
path,
shorthand=False,
@@ -326,7 +326,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Confirm the password reset
channel = make_request(
self.reactor,
FakeSite(self.submit_token_resource),
FakeSite(self.submit_token_resource, self.reactor),
"POST",
path,
content=b"",

View File

@@ -61,7 +61,11 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
"""You can observe the terms form without specifying a user"""
resource = consent_resource.ConsentResource(self.hs)
channel = make_request(
self.reactor, FakeSite(resource), "GET", "/consent?v=1", shorthand=False
self.reactor,
FakeSite(resource, self.reactor),
"GET",
"/consent?v=1",
shorthand=False,
)
self.assertEqual(channel.code, 200)
@@ -83,7 +87,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
)
channel = make_request(
self.reactor,
FakeSite(resource),
FakeSite(resource, self.reactor),
"GET",
consent_uri,
access_token=access_token,
@@ -98,7 +102,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
# POST to the consent page, saying we've agreed
channel = make_request(
self.reactor,
FakeSite(resource),
FakeSite(resource, self.reactor),
"POST",
consent_uri + "&v=" + version,
access_token=access_token,
@@ -110,7 +114,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
# changed
channel = make_request(
self.reactor,
FakeSite(resource),
FakeSite(resource, self.reactor),
"GET",
consent_uri,
access_token=access_token,

View File

@@ -372,7 +372,7 @@ class RestHelper:
path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
channel = make_request(
self.hs.get_reactor(),
FakeSite(resource),
FakeSite(resource, self.hs.get_reactor()),
"POST",
path,
content=image_data,

View File

@@ -84,7 +84,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
Checks that the response is a 200 and returns the decoded json body.
"""
channel = FakeChannel(self.site, self.reactor)
req = SynapseRequest(channel)
req = SynapseRequest(channel, self.site)
req.content = BytesIO(b"")
req.requestReceived(
b"GET",
@@ -183,7 +183,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
)
channel = FakeChannel(self.site, self.reactor)
req = SynapseRequest(channel)
req = SynapseRequest(channel, self.site)
req.content = BytesIO(encode_canonical_json(data))
req.requestReceived(

View File

@@ -252,7 +252,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
channel = make_request(
self.reactor,
FakeSite(self.download_resource),
FakeSite(self.download_resource, self.reactor),
"GET",
self.media_id,
shorthand=False,
@@ -384,7 +384,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
params = "?width=32&height=32&method=scale"
channel = make_request(
self.reactor,
FakeSite(self.thumbnail_resource),
FakeSite(self.thumbnail_resource, self.reactor),
"GET",
self.media_id + params,
shorthand=False,
@@ -413,7 +413,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
channel = make_request(
self.reactor,
FakeSite(self.thumbnail_resource),
FakeSite(self.thumbnail_resource, self.reactor),
"GET",
self.media_id + params,
shorthand=False,
@@ -433,7 +433,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
params = "?width=32&height=32&method=" + method
channel = make_request(
self.reactor,
FakeSite(self.thumbnail_resource),
FakeSite(self.thumbnail_resource, self.reactor),
"GET",
self.media_id + params,
shorthand=False,

View File

@@ -19,6 +19,7 @@ from twisted.internet.interfaces import (
IPullProducer,
IPushProducer,
IReactorPluggableNameResolver,
IReactorTime,
IResolverSimple,
ITransport,
)
@@ -181,13 +182,14 @@ class FakeSite:
site_tag = "test"
access_logger = logging.getLogger("synapse.access.http.fake")
def __init__(self, resource: IResource):
def __init__(self, resource: IResource, reactor: IReactorTime):
"""
Args:
resource: the resource to be used for rendering all requests
"""
self._resource = resource
self.reactor = reactor
def getResourceFor(self, request):
return self._resource
@@ -268,7 +270,7 @@ def make_request(
channel = FakeChannel(site, reactor, ip=client_ip)
req = request(channel)
req = request(channel, site)
req.content = BytesIO(content)
# Twisted expects to be at the end of the content when parsing the request.
req.content.seek(SEEK_END)

View File

@@ -65,7 +65,10 @@ class JsonResourceTests(unittest.TestCase):
)
make_request(
self.reactor, FakeSite(res), b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83"
self.reactor,
FakeSite(res, self.reactor),
b"GET",
b"/_matrix/foo/%E2%98%83?a=%E2%98%83",
)
self.assertEqual(got_kwargs, {"room_id": "\N{SNOWMAN}"})
@@ -84,7 +87,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
)
channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
channel = make_request(
self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo"
)
self.assertEqual(channel.result["code"], b"500")
@@ -100,7 +105,7 @@ class JsonResourceTests(unittest.TestCase):
def _callback(request, **kwargs):
d = Deferred()
d.addCallback(_throw)
self.reactor.callLater(1, d.callback, True)
self.reactor.callLater(0.5, d.callback, True)
return make_deferred_yieldable(d)
res = JsonResource(self.homeserver)
@@ -108,7 +113,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
)
channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
channel = make_request(
self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo"
)
self.assertEqual(channel.result["code"], b"500")
@@ -126,7 +133,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
)
channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
channel = make_request(
self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo"
)
self.assertEqual(channel.result["code"], b"403")
self.assertEqual(channel.json_body["error"], "Forbidden!!one!")
@@ -148,7 +157,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
)
channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foobar")
channel = make_request(
self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foobar"
)
self.assertEqual(channel.result["code"], b"400")
self.assertEqual(channel.json_body["error"], "Unrecognized request")
@@ -173,7 +184,9 @@ class JsonResourceTests(unittest.TestCase):
)
# The path was registered as GET, but this is a HEAD request.
channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/_matrix/foo")
channel = make_request(
self.reactor, FakeSite(res, self.reactor), b"HEAD", b"/_matrix/foo"
)
self.assertEqual(channel.result["code"], b"200")
self.assertNotIn("body", channel.result)
@@ -280,7 +293,9 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback
channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
channel = make_request(
self.reactor, FakeSite(res, self.reactor), b"GET", b"/path"
)
self.assertEqual(channel.result["code"], b"200")
body = channel.result["body"]
@@ -298,7 +313,9 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback
channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
channel = make_request(
self.reactor, FakeSite(res, self.reactor), b"GET", b"/path"
)
self.assertEqual(channel.result["code"], b"301")
headers = channel.result["headers"]
@@ -319,7 +336,9 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback
channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
channel = make_request(
self.reactor, FakeSite(res, self.reactor), b"GET", b"/path"
)
self.assertEqual(channel.result["code"], b"304")
headers = channel.result["headers"]
@@ -338,7 +357,9 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback
channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/path")
channel = make_request(
self.reactor, FakeSite(res, self.reactor), b"HEAD", b"/path"
)
self.assertEqual(channel.result["code"], b"200")
self.assertNotIn("body", channel.result)