Compare commits

...

12 Commits

Author SHA1 Message Date
Erik Johnston
16e8580571 Write JSON response using a producer 2021-09-21 08:58:28 +01:00
Erik Johnston
0ce9315f9f Fix tests 2021-09-17 17:30:32 +01:00
Erik Johnston
c2d84edffc Don't create temporary function 2021-09-17 17:12:38 +01:00
Erik Johnston
d1f25c6df4 Fix tests 2021-09-17 17:12:09 +01:00
Erik Johnston
6c0bc18139 Fix tests 2021-09-17 15:28:39 +01:00
Erik Johnston
8521a0c976 Newsfile 2021-09-17 15:06:33 +01:00
Erik Johnston
e369a20d0a Encode JSON responses on a thread 2021-09-17 14:58:24 +01:00
Erik Johnston
40c99c22ff Add a _write_json_to_request_in_thread 2021-09-17 14:53:02 +01:00
Erik Johnston
fbcbfb4aa4 Require SynapseRequest for respond_with_json 2021-09-17 14:53:02 +01:00
Erik Johnston
2f8abe0905 Add reactor to SynapseRequest 2021-09-17 14:52:58 +01:00
Erik Johnston
341a92b7d0 Fix SynapseRequest.site type.
There were two issues: 1) `channel` is actually private type so its hard
to type `.site`, and 2) `.site` was actually overwriting an existing
member and so we need to rename it.
2021-09-17 14:41:29 +01:00
Erik Johnston
d18c71abab Add types to http.site 2021-09-17 14:41:29 +01:00
26 changed files with 219 additions and 117 deletions

View File

@@ -0,0 +1 @@
Speed up responding with large JSON objects to requests.

View File

@@ -21,7 +21,6 @@ import types
import urllib import urllib
from http import HTTPStatus from http import HTTPStatus
from inspect import isawaitable from inspect import isawaitable
from io import BytesIO
from typing import ( from typing import (
Any, Any,
Awaitable, Awaitable,
@@ -37,7 +36,7 @@ from typing import (
) )
import jinja2 import jinja2
from canonicaljson import iterencode_canonical_json from canonicaljson import encode_canonical_json
from typing_extensions import Protocol from typing_extensions import Protocol
from zope.interface import implementer from zope.interface import implementer
@@ -45,7 +44,7 @@ from twisted.internet import defer, interfaces
from twisted.python import failure from twisted.python import failure
from twisted.web import resource from twisted.web import resource
from twisted.web.server import NOT_DONE_YET, Request from twisted.web.server import NOT_DONE_YET, Request
from twisted.web.static import File, NoRangeStaticProducer from twisted.web.static import File
from twisted.web.util import redirectTo from twisted.web.util import redirectTo
from synapse.api.errors import ( from synapse.api.errors import (
@@ -56,10 +55,11 @@ from synapse.api.errors import (
UnrecognizedRequestError, UnrecognizedRequestError,
) )
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.logging.context import preserve_fn from synapse.logging.context import defer_to_thread, preserve_fn, run_in_background
from synapse.logging.opentracing import trace_servlet from synapse.logging.opentracing import trace_servlet
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.caches import intern_dict from synapse.util.caches import intern_dict
from synapse.util.iterutils import chunk_seq
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -320,7 +320,7 @@ class DirectServeJsonResource(_AsyncResource):
def _send_response( def _send_response(
self, self,
request: Request, request: SynapseRequest,
code: int, code: int,
response_object: Any, response_object: Any,
): ):
@@ -620,16 +620,15 @@ class _ByteProducer:
self._request = None self._request = None
def _encode_json_bytes(json_object: Any) -> Iterator[bytes]: def _encode_json_bytes(json_object: Any) -> bytes:
""" """
Encode an object into JSON. Returns an iterator of bytes. Encode an object into JSON. Returns an iterator of bytes.
""" """
for chunk in json_encoder.iterencode(json_object): return json_encoder.encode(json_object).encode("utf-8")
yield chunk.encode("utf-8")
def respond_with_json( def respond_with_json(
request: Request, request: SynapseRequest,
code: int, code: int,
json_object: Any, json_object: Any,
send_cors: bool = False, send_cors: bool = False,
@@ -659,7 +658,7 @@ def respond_with_json(
return None return None
if canonical_json: if canonical_json:
encoder = iterencode_canonical_json encoder = encode_canonical_json
else: else:
encoder = _encode_json_bytes encoder = _encode_json_bytes
@@ -670,7 +669,9 @@ def respond_with_json(
if send_cors: if send_cors:
set_cors_headers(request) set_cors_headers(request)
_ByteProducer(request, encoder(json_object)) run_in_background(
_async_write_json_to_request_in_thread, request, encoder, json_object
)
return NOT_DONE_YET return NOT_DONE_YET
@@ -706,15 +707,35 @@ def respond_with_json_bytes(
if send_cors: if send_cors:
set_cors_headers(request) set_cors_headers(request)
# note that this is zero-copy (the bytesio shares a copy-on-write buffer with _write_json_bytes_to_request(request, json_bytes)
# the original `bytes`).
bytes_io = BytesIO(json_bytes)
producer = NoRangeStaticProducer(request, bytes_io)
producer.start()
return NOT_DONE_YET return NOT_DONE_YET
def _write_json_bytes_to_request(request: Request, json_bytes: bytes) -> None:
"""Writes the JSON bytes to the request using an appropriate producer.
Note: This should be used instead of `Request.write` to correctly handle
large response bodies.
"""
# The problem with dumping all of the json response into the `Request`
# object at once (via `Request.write`) is that doing so starts the timeout
# for the next request to be received: so if it takes longer than 60s to
# stream back the response to the client, the client never gets it.
#
# The correct solution is to use a Producer; then the timeout is only
# started once all of the content is sent over the TCP connection.
# To make sure we don't write the whole of the json at once we split it up
# into chunks.
chunk_size = 4096
bytes_generator = chunk_seq(json_bytes, chunk_size)
# We use a `_ByteProducer` here rather than `NoRangeStaticProducer` as the
# unit tests can't cope with being given a pull producer.
_ByteProducer(request, bytes_generator)
def set_cors_headers(request: Request): def set_cors_headers(request: Request):
"""Set the CORS headers so that javascript running in a web browsers can """Set the CORS headers so that javascript running in a web browsers can
use this API use this API
@@ -809,3 +830,24 @@ def finish_request(request: Request):
request.finish() request.finish()
except RuntimeError as e: except RuntimeError as e:
logger.info("Connection disconnected before response was written: %r", e) logger.info("Connection disconnected before response was written: %r", e)
async def _async_write_json_to_request_in_thread(
request: SynapseRequest,
json_encoder: Callable[[Any], bytes],
json_object: Any,
):
"""Encodes the given JSON object on a thread and then writes it to the
request.
This is done so that encoding large JSON objects doesn't block the reactor
thread.
Note: We don't use JsonEncoder.iterencode here as that falls back to the
Python implementation (rather than the C backend), which is *much* more
expensive.
"""
json_str = await defer_to_thread(request.reactor, json_encoder, json_object)
_write_json_bytes_to_request(request, json_str)

View File

@@ -14,14 +14,15 @@
import contextlib import contextlib
import logging import logging
import time import time
from typing import Optional, Tuple, Union from typing import Generator, Optional, Tuple, Union
import attr import attr
from zope.interface import implementer from zope.interface import implementer
from twisted.internet.interfaces import IAddress, IReactorTime from twisted.internet.interfaces import IAddress, IReactorTime
from twisted.python.failure import Failure from twisted.python.failure import Failure
from twisted.web.resource import IResource from twisted.web.http import HTTPChannel
from twisted.web.resource import IResource, Resource
from twisted.web.server import Request, Site from twisted.web.server import Request, Site
from synapse.config.server import ListenerConfig from synapse.config.server import ListenerConfig
@@ -61,10 +62,18 @@ class SynapseRequest(Request):
logcontext: the log context for this request logcontext: the log context for this request
""" """
def __init__(self, channel, *args, max_request_body_size=1024, **kw): def __init__(
Request.__init__(self, channel, *args, **kw) self,
channel: HTTPChannel,
site: "SynapseSite",
*args,
max_request_body_size: int = 1024,
**kw,
):
super().__init__(channel, *args, **kw)
self._max_request_body_size = max_request_body_size self._max_request_body_size = max_request_body_size
self.site: SynapseSite = channel.site self.synapse_site = site
self.reactor = site.reactor
self._channel = channel # this is used by the tests self._channel = channel # this is used by the tests
self.start_time = 0.0 self.start_time = 0.0
@@ -83,13 +92,13 @@ class SynapseRequest(Request):
self._is_processing = False self._is_processing = False
# the time when the asynchronous request handler completed its processing # the time when the asynchronous request handler completed its processing
self._processing_finished_time = None self._processing_finished_time: Optional[float] = None
# what time we finished sending the response to the client (or the connection # what time we finished sending the response to the client (or the connection
# dropped) # dropped)
self.finish_time = None self.finish_time: Optional[float] = None
def __repr__(self): def __repr__(self) -> str:
# We overwrite this so that we don't log ``access_token`` # We overwrite this so that we don't log ``access_token``
return "<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>" % ( return "<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>" % (
self.__class__.__name__, self.__class__.__name__,
@@ -97,10 +106,10 @@ class SynapseRequest(Request):
self.get_method(), self.get_method(),
self.get_redacted_uri(), self.get_redacted_uri(),
self.clientproto.decode("ascii", errors="replace"), self.clientproto.decode("ascii", errors="replace"),
self.site.site_tag, self.synapse_site.site_tag,
) )
def handleContentChunk(self, data): def handleContentChunk(self, data: bytes) -> None:
# we should have a `content` by now. # we should have a `content` by now.
assert self.content, "handleContentChunk() called before gotLength()" assert self.content, "handleContentChunk() called before gotLength()"
if self.content.tell() + len(data) > self._max_request_body_size: if self.content.tell() + len(data) > self._max_request_body_size:
@@ -139,7 +148,7 @@ class SynapseRequest(Request):
# If there's no authenticated entity, it was the requester. # If there's no authenticated entity, it was the requester.
self.logcontext.request.authenticated_entity = authenticated_entity or requester self.logcontext.request.authenticated_entity = authenticated_entity or requester
def get_request_id(self): def get_request_id(self) -> str:
return "%s-%i" % (self.get_method(), self.request_seq) return "%s-%i" % (self.get_method(), self.request_seq)
def get_redacted_uri(self) -> str: def get_redacted_uri(self) -> str:
@@ -205,7 +214,7 @@ class SynapseRequest(Request):
return None, None return None, None
def render(self, resrc): def render(self, resrc: Resource) -> None:
# this is called once a Resource has been found to serve the request; in our # this is called once a Resource has been found to serve the request; in our
# case the Resource in question will normally be a JsonResource. # case the Resource in question will normally be a JsonResource.
@@ -216,7 +225,7 @@ class SynapseRequest(Request):
request=ContextRequest( request=ContextRequest(
request_id=request_id, request_id=request_id,
ip_address=self.getClientIP(), ip_address=self.getClientIP(),
site_tag=self.site.site_tag, site_tag=self.synapse_site.site_tag,
# The requester is going to be unknown at this point. # The requester is going to be unknown at this point.
requester=None, requester=None,
authenticated_entity=None, authenticated_entity=None,
@@ -228,7 +237,7 @@ class SynapseRequest(Request):
) )
# override the Server header which is set by twisted # override the Server header which is set by twisted
self.setHeader("Server", self.site.server_version_string) self.setHeader("Server", self.synapse_site.server_version_string)
with PreserveLoggingContext(self.logcontext): with PreserveLoggingContext(self.logcontext):
# we start the request metrics timer here with an initial stab # we start the request metrics timer here with an initial stab
@@ -247,7 +256,7 @@ class SynapseRequest(Request):
requests_counter.labels(self.get_method(), self.request_metrics.name).inc() requests_counter.labels(self.get_method(), self.request_metrics.name).inc()
@contextlib.contextmanager @contextlib.contextmanager
def processing(self): def processing(self) -> Generator[None, None, None]:
"""Record the fact that we are processing this request. """Record the fact that we are processing this request.
Returns a context manager; the correct way to use this is: Returns a context manager; the correct way to use this is:
@@ -282,7 +291,7 @@ class SynapseRequest(Request):
if self.finish_time is not None: if self.finish_time is not None:
self._finished_processing() self._finished_processing()
def finish(self): def finish(self) -> None:
"""Called when all response data has been written to this Request. """Called when all response data has been written to this Request.
Overrides twisted.web.server.Request.finish to record the finish time and do Overrides twisted.web.server.Request.finish to record the finish time and do
@@ -295,7 +304,7 @@ class SynapseRequest(Request):
with PreserveLoggingContext(self.logcontext): with PreserveLoggingContext(self.logcontext):
self._finished_processing() self._finished_processing()
def connectionLost(self, reason): def connectionLost(self, reason: Union[Failure, Exception]) -> None:
"""Called when the client connection is closed before the response is written. """Called when the client connection is closed before the response is written.
Overrides twisted.web.server.Request.connectionLost to record the finish time and Overrides twisted.web.server.Request.connectionLost to record the finish time and
@@ -327,7 +336,7 @@ class SynapseRequest(Request):
if not self._is_processing: if not self._is_processing:
self._finished_processing() self._finished_processing()
def _started_processing(self, servlet_name): def _started_processing(self, servlet_name: str) -> None:
"""Record the fact that we are processing this request. """Record the fact that we are processing this request.
This will log the request's arrival. Once the request completes, This will log the request's arrival. Once the request completes,
@@ -346,17 +355,19 @@ class SynapseRequest(Request):
self.start_time, name=servlet_name, method=self.get_method() self.start_time, name=servlet_name, method=self.get_method()
) )
self.site.access_logger.debug( self.synapse_site.access_logger.debug(
"%s - %s - Received request: %s %s", "%s - %s - Received request: %s %s",
self.getClientIP(), self.getClientIP(),
self.site.site_tag, self.synapse_site.site_tag,
self.get_method(), self.get_method(),
self.get_redacted_uri(), self.get_redacted_uri(),
) )
def _finished_processing(self): def _finished_processing(self) -> None:
"""Log the completion of this request and update the metrics""" """Log the completion of this request and update the metrics"""
assert self.logcontext is not None assert self.logcontext is not None
assert self.finish_time is not None
usage = self.logcontext.get_resource_usage() usage = self.logcontext.get_resource_usage()
if self._processing_finished_time is None: if self._processing_finished_time is None:
@@ -386,13 +397,13 @@ class SynapseRequest(Request):
if authenticated_entity: if authenticated_entity:
requester = f"{authenticated_entity}|{requester}" requester = f"{authenticated_entity}|{requester}"
self.site.access_logger.log( self.synapse_site.access_logger.log(
log_level, log_level,
"%s - %s - {%s}" "%s - %s - {%s}"
" Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)" " Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
' %sB %s "%s %s %s" "%s" [%d dbevts]', ' %sB %s "%s %s %s" "%s" [%d dbevts]',
self.getClientIP(), self.getClientIP(),
self.site.site_tag, self.synapse_site.site_tag,
requester, requester,
processing_time, processing_time,
response_send_time, response_send_time,
@@ -437,7 +448,7 @@ class XForwardedForRequest(SynapseRequest):
_forwarded_for: "Optional[_XForwardedForAddress]" = None _forwarded_for: "Optional[_XForwardedForAddress]" = None
_forwarded_https: bool = False _forwarded_https: bool = False
def requestReceived(self, command, path, version): def requestReceived(self, command: bytes, path: bytes, version: bytes) -> None:
# this method is called by the Channel once the full request has been # this method is called by the Channel once the full request has been
# received, to dispatch the request to a resource. # received, to dispatch the request to a resource.
# We can use it to set the IP address and protocol according to the # We can use it to set the IP address and protocol according to the
@@ -445,7 +456,7 @@ class XForwardedForRequest(SynapseRequest):
self._process_forwarded_headers() self._process_forwarded_headers()
return super().requestReceived(command, path, version) return super().requestReceived(command, path, version)
def _process_forwarded_headers(self): def _process_forwarded_headers(self) -> None:
headers = self.requestHeaders.getRawHeaders(b"x-forwarded-for") headers = self.requestHeaders.getRawHeaders(b"x-forwarded-for")
if not headers: if not headers:
return return
@@ -470,7 +481,7 @@ class XForwardedForRequest(SynapseRequest):
) )
self._forwarded_https = True self._forwarded_https = True
def isSecure(self): def isSecure(self) -> bool:
if self._forwarded_https: if self._forwarded_https:
return True return True
return super().isSecure() return super().isSecure()
@@ -520,7 +531,7 @@ class SynapseSite(Site):
site_tag: str, site_tag: str,
config: ListenerConfig, config: ListenerConfig,
resource: IResource, resource: IResource,
server_version_string, server_version_string: str,
max_request_body_size: int, max_request_body_size: int,
reactor: IReactorTime, reactor: IReactorTime,
): ):
@@ -540,19 +551,23 @@ class SynapseSite(Site):
Site.__init__(self, resource, reactor=reactor) Site.__init__(self, resource, reactor=reactor)
self.site_tag = site_tag self.site_tag = site_tag
self.reactor = reactor
assert config.http_options is not None assert config.http_options is not None
proxied = config.http_options.x_forwarded proxied = config.http_options.x_forwarded
request_class = XForwardedForRequest if proxied else SynapseRequest request_class = XForwardedForRequest if proxied else SynapseRequest
def request_factory(channel, queued) -> Request: def request_factory(channel: HTTPChannel, queued: bool) -> Request:
return request_class( return request_class(
channel, max_request_body_size=max_request_body_size, queued=queued channel,
self,
max_request_body_size=max_request_body_size,
queued=queued,
) )
self.requestFactory = request_factory # type: ignore self.requestFactory = request_factory # type: ignore
self.access_logger = logging.getLogger(logger_name) self.access_logger = logging.getLogger(logger_name)
self.server_version_string = server_version_string.encode("ascii") self.server_version_string = server_version_string.encode("ascii")
def log(self, request): def log(self, request: SynapseRequest) -> None:
pass pass

View File

@@ -184,7 +184,7 @@ class EmailPusher(Pusher):
should_notify_at = max(notif_ready_at, room_ready_at) should_notify_at = max(notif_ready_at, room_ready_at)
if should_notify_at < self.clock.time_msec(): if should_notify_at <= self.clock.time_msec():
# one of our notifications is ready for sending, so we send # one of our notifications is ready for sending, so we send
# *one* email updating the user on their notifications, # *one* email updating the user on their notifications,
# we then consider all previously outstanding notifications # we then consider all previously outstanding notifications

View File

@@ -17,12 +17,11 @@ from typing import TYPE_CHECKING, Dict
from signedjson.sign import sign_json from signedjson.sign import sign_json
from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyring import ServerKeyFetcher from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_integer, parse_json_object_from_request from synapse.http.servlet import parse_integer, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import json_decoder from synapse.util import json_decoder
from synapse.util.async_helpers import yieldable_gather_results from synapse.util.async_helpers import yieldable_gather_results
@@ -100,7 +99,7 @@ class RemoteKey(DirectServeJsonResource):
self.federation_domain_whitelist = hs.config.federation_domain_whitelist self.federation_domain_whitelist = hs.config.federation_domain_whitelist
self.config = hs.config self.config = hs.config
async def _async_render_GET(self, request: Request) -> None: async def _async_render_GET(self, request: SynapseRequest) -> None:
assert request.postpath is not None assert request.postpath is not None
if len(request.postpath) == 1: if len(request.postpath) == 1:
(server,) = request.postpath (server,) = request.postpath
@@ -117,7 +116,7 @@ class RemoteKey(DirectServeJsonResource):
await self.query_keys(request, query, query_remote_on_cache_miss=True) await self.query_keys(request, query, query_remote_on_cache_miss=True)
async def _async_render_POST(self, request: Request) -> None: async def _async_render_POST(self, request: SynapseRequest) -> None:
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
query = content["server_keys"] query = content["server_keys"]
@@ -126,7 +125,7 @@ class RemoteKey(DirectServeJsonResource):
async def query_keys( async def query_keys(
self, self,
request: Request, request: SynapseRequest,
query: JsonDict, query: JsonDict,
query_remote_on_cache_miss: bool = False, query_remote_on_cache_miss: bool = False,
) -> None: ) -> None:

View File

@@ -27,6 +27,7 @@ from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError, cs_error from synapse.api.errors import Codes, SynapseError, cs_error
from synapse.http.server import finish_request, respond_with_json from synapse.http.server import finish_request, respond_with_json
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
from synapse.util.stringutils import is_ascii from synapse.util.stringutils import is_ascii
@@ -74,7 +75,7 @@ def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
) )
def respond_404(request: Request) -> None: def respond_404(request: SynapseRequest) -> None:
respond_with_json( respond_with_json(
request, request,
404, 404,
@@ -84,7 +85,7 @@ def respond_404(request: Request) -> None:
async def respond_with_file( async def respond_with_file(
request: Request, request: SynapseRequest,
media_type: str, media_type: str,
file_path: str, file_path: str,
file_size: Optional[int] = None, file_size: Optional[int] = None,
@@ -221,7 +222,7 @@ def _can_encode_filename_as_token(x: str) -> bool:
async def respond_with_responder( async def respond_with_responder(
request: Request, request: SynapseRequest,
responder: "Optional[Responder]", responder: "Optional[Responder]",
media_type: str, media_type: str,
file_size: Optional[int], file_size: Optional[int],

View File

@@ -16,8 +16,6 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from twisted.web.server import Request
from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
@@ -39,5 +37,5 @@ class MediaConfigResource(DirectServeJsonResource):
await self.auth.get_user_by_req(request) await self.auth.get_user_by_req(request)
respond_with_json(request, 200, self.limits_dict, send_cors=True) respond_with_json(request, 200, self.limits_dict, send_cors=True)
async def _async_render_OPTIONS(self, request: Request) -> None: async def _async_render_OPTIONS(self, request: SynapseRequest) -> None:
respond_with_json(request, 200, {}, send_cors=True) respond_with_json(request, 200, {}, send_cors=True)

View File

@@ -15,10 +15,9 @@
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from twisted.web.server import Request
from synapse.http.server import DirectServeJsonResource, set_cors_headers from synapse.http.server import DirectServeJsonResource, set_cors_headers
from synapse.http.servlet import parse_boolean from synapse.http.servlet import parse_boolean
from synapse.http.site import SynapseRequest
from ._base import parse_media_id, respond_404 from ._base import parse_media_id, respond_404
@@ -37,7 +36,7 @@ class DownloadResource(DirectServeJsonResource):
self.media_repo = media_repo self.media_repo = media_repo
self.server_name = hs.hostname self.server_name = hs.hostname
async def _async_render_GET(self, request: Request) -> None: async def _async_render_GET(self, request: SynapseRequest) -> None:
set_cors_headers(request) set_cors_headers(request)
request.setHeader( request.setHeader(
b"Content-Security-Policy", b"Content-Security-Policy",

View File

@@ -23,7 +23,6 @@ import twisted.internet.error
import twisted.web.http import twisted.web.http
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.server import Request
from synapse.api.errors import ( from synapse.api.errors import (
FederationDeniedError, FederationDeniedError,
@@ -34,6 +33,7 @@ from synapse.api.errors import (
) )
from synapse.config._base import ConfigError from synapse.config._base import ConfigError
from synapse.config.repository import ThumbnailRequirement from synapse.config.repository import ThumbnailRequirement
from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID from synapse.types import UserID
@@ -187,7 +187,7 @@ class MediaRepository:
return "mxc://%s/%s" % (self.server_name, media_id) return "mxc://%s/%s" % (self.server_name, media_id)
async def get_local_media( async def get_local_media(
self, request: Request, media_id: str, name: Optional[str] self, request: SynapseRequest, media_id: str, name: Optional[str]
) -> None: ) -> None:
"""Responds to requests for local media, if exists, or returns 404. """Responds to requests for local media, if exists, or returns 404.
@@ -221,7 +221,11 @@ class MediaRepository:
) )
async def get_remote_media( async def get_remote_media(
self, request: Request, server_name: str, media_id: str, name: Optional[str] self,
request: SynapseRequest,
server_name: str,
media_id: str,
name: Optional[str],
) -> None: ) -> None:
"""Respond to requests for remote media. """Respond to requests for remote media.

View File

@@ -29,7 +29,6 @@ import attr
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred
from twisted.internet.error import DNSLookupError from twisted.internet.error import DNSLookupError
from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.client import SimpleHttpClient from synapse.http.client import SimpleHttpClient
@@ -167,7 +166,7 @@ class PreviewUrlResource(DirectServeJsonResource):
self._start_expire_url_cache_data, 10 * 1000 self._start_expire_url_cache_data, 10 * 1000
) )
async def _async_render_OPTIONS(self, request: Request) -> None: async def _async_render_OPTIONS(self, request: SynapseRequest) -> None:
request.setHeader(b"Allow", b"OPTIONS, GET") request.setHeader(b"Allow", b"OPTIONS, GET")
respond_with_json(request, 200, {}, send_cors=True) respond_with_json(request, 200, {}, send_cors=True)

View File

@@ -17,11 +17,10 @@
import logging import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from twisted.web.server import Request
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.server import DirectServeJsonResource, set_cors_headers from synapse.http.server import DirectServeJsonResource, set_cors_headers
from synapse.http.servlet import parse_integer, parse_string from synapse.http.servlet import parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.media.v1.media_storage import MediaStorage from synapse.rest.media.v1.media_storage import MediaStorage
from ._base import ( from ._base import (
@@ -57,7 +56,7 @@ class ThumbnailResource(DirectServeJsonResource):
self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.server_name = hs.hostname self.server_name = hs.hostname
async def _async_render_GET(self, request: Request) -> None: async def _async_render_GET(self, request: SynapseRequest) -> None:
set_cors_headers(request) set_cors_headers(request)
server_name, media_id, _ = parse_media_id(request) server_name, media_id, _ = parse_media_id(request)
width = parse_integer(request, "width", required=True) width = parse_integer(request, "width", required=True)
@@ -88,7 +87,7 @@ class ThumbnailResource(DirectServeJsonResource):
async def _respond_local_thumbnail( async def _respond_local_thumbnail(
self, self,
request: Request, request: SynapseRequest,
media_id: str, media_id: str,
width: int, width: int,
height: int, height: int,
@@ -121,7 +120,7 @@ class ThumbnailResource(DirectServeJsonResource):
async def _select_or_generate_local_thumbnail( async def _select_or_generate_local_thumbnail(
self, self,
request: Request, request: SynapseRequest,
media_id: str, media_id: str,
desired_width: int, desired_width: int,
desired_height: int, desired_height: int,
@@ -186,7 +185,7 @@ class ThumbnailResource(DirectServeJsonResource):
async def _select_or_generate_remote_thumbnail( async def _select_or_generate_remote_thumbnail(
self, self,
request: Request, request: SynapseRequest,
server_name: str, server_name: str,
media_id: str, media_id: str,
desired_width: int, desired_width: int,
@@ -249,7 +248,7 @@ class ThumbnailResource(DirectServeJsonResource):
async def _respond_remote_thumbnail( async def _respond_remote_thumbnail(
self, self,
request: Request, request: SynapseRequest,
server_name: str, server_name: str,
media_id: str, media_id: str,
width: int, width: int,
@@ -280,7 +279,7 @@ class ThumbnailResource(DirectServeJsonResource):
async def _select_and_respond_with_thumbnail( async def _select_and_respond_with_thumbnail(
self, self,
request: Request, request: SynapseRequest,
desired_width: int, desired_width: int,
desired_height: int, desired_height: int,
desired_method: str, desired_method: str,

View File

@@ -16,8 +16,6 @@
import logging import logging
from typing import IO, TYPE_CHECKING, Dict, List, Optional from typing import IO, TYPE_CHECKING, Dict, List, Optional
from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_bytes_from_args from synapse.http.servlet import parse_bytes_from_args
@@ -46,7 +44,7 @@ class UploadResource(DirectServeJsonResource):
self.max_upload_size = hs.config.max_upload_size self.max_upload_size = hs.config.max_upload_size
self.clock = hs.get_clock() self.clock = hs.get_clock()
async def _async_render_OPTIONS(self, request: Request) -> None: async def _async_render_OPTIONS(self, request: SynapseRequest) -> None:
respond_with_json(request, 200, {}, send_cors=True) respond_with_json(request, 200, {}, send_cors=True)
async def _async_render_POST(self, request: SynapseRequest) -> None: async def _async_render_POST(self, request: SynapseRequest) -> None:

View File

@@ -21,13 +21,28 @@ from typing import (
Iterable, Iterable,
Iterator, Iterator,
Mapping, Mapping,
Sequence,
Set, Set,
Sized,
Tuple, Tuple,
TypeVar, TypeVar,
) )
from typing_extensions import Protocol
T = TypeVar("T") T = TypeVar("T")
S = TypeVar("S", bound="_SelfSlice")
class _SelfSlice(Sized, Protocol):
"""A helper protocol that matches types where taking a slice results in the
same type being returned.
This is more specific than `Sequence`, which allows another `Sequence` to be
returned.
"""
def __getitem__(self: S, i: slice) -> S:
...
def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]: def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]:
@@ -46,7 +61,7 @@ def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]:
return iter(lambda: tuple(islice(sourceiter, size)), ()) return iter(lambda: tuple(islice(sourceiter, size)), ())
def chunk_seq(iseq: Sequence[T], maxlen: int) -> Iterable[Sequence[T]]: def chunk_seq(iseq: S, maxlen: int) -> Iterator[S]:
"""Split the given sequence into chunks of the given size """Split the given sequence into chunks of the given size
The last chunk may be shorter than the given size. The last chunk may be shorter than the given size.

View File

@@ -45,7 +45,9 @@ class AdditionalResourceTests(HomeserverTestCase):
handler = _AsyncTestCustomEndpoint({}, None).handle_request handler = _AsyncTestCustomEndpoint({}, None).handle_request
resource = AdditionalResource(self.hs, handler) resource = AdditionalResource(self.hs, handler)
channel = make_request(self.reactor, FakeSite(resource), "GET", "/") channel = make_request(
self.reactor, FakeSite(resource, self.reactor), "GET", "/"
)
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_async"}) self.assertEqual(channel.json_body, {"some_key": "some_value_async"})
@@ -54,7 +56,9 @@ class AdditionalResourceTests(HomeserverTestCase):
handler = _SyncTestCustomEndpoint({}, None).handle_request handler = _SyncTestCustomEndpoint({}, None).handle_request
resource = AdditionalResource(self.hs, handler) resource = AdditionalResource(self.hs, handler)
channel = make_request(self.reactor, FakeSite(resource), "GET", "/") channel = make_request(
self.reactor, FakeSite(resource, self.reactor), "GET", "/"
)
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_sync"}) self.assertEqual(channel.json_body, {"some_key": "some_value_sync"})

View File

@@ -152,7 +152,8 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
site = Mock(spec=["site_tag", "server_version_string", "getResourceFor"]) site = Mock(spec=["site_tag", "server_version_string", "getResourceFor"])
site.site_tag = "test-site" site.site_tag = "test-site"
site.server_version_string = "Server v1" site.server_version_string = "Server v1"
request = SynapseRequest(FakeChannel(site, None)) site.reactor = Mock()
request = SynapseRequest(FakeChannel(site, None), site)
# Call requestReceived to finish instantiating the object. # Call requestReceived to finish instantiating the object.
request.content = BytesIO() request.content = BytesIO()
# Partially skip some of the internal processing of SynapseRequest. # Partially skip some of the internal processing of SynapseRequest.

View File

@@ -68,7 +68,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
resource = hs.get_media_repository_resource().children[b"download"] resource = hs.get_media_repository_resource().children[b"download"]
channel = make_request( channel = make_request(
self.reactor, self.reactor,
FakeSite(resource), FakeSite(resource, self.reactor),
"GET", "GET",
f"/{target}/{media_id}", f"/{target}/{media_id}",
shorthand=False, shorthand=False,

View File

@@ -201,7 +201,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
"""Ensure a piece of media is quarantined when trying to access it.""" """Ensure a piece of media is quarantined when trying to access it."""
channel = make_request( channel = make_request(
self.reactor, self.reactor,
FakeSite(self.download_resource), FakeSite(self.download_resource, self.reactor),
"GET", "GET",
server_and_media_id, server_and_media_id,
shorthand=False, shorthand=False,
@@ -271,7 +271,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Attempt to access the media # Attempt to access the media
channel = make_request( channel = make_request(
self.reactor, self.reactor,
FakeSite(self.download_resource), FakeSite(self.download_resource, self.reactor),
"GET", "GET",
server_name_and_media_id, server_name_and_media_id,
shorthand=False, shorthand=False,
@@ -458,7 +458,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
# Attempt to access each piece of media # Attempt to access each piece of media
channel = make_request( channel = make_request(
self.reactor, self.reactor,
FakeSite(self.download_resource), FakeSite(self.download_resource, self.reactor),
"GET", "GET",
server_and_media_id_2, server_and_media_id_2,
shorthand=False, shorthand=False,

View File

@@ -125,7 +125,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
# Attempt to access media # Attempt to access media
channel = make_request( channel = make_request(
self.reactor, self.reactor,
FakeSite(download_resource), FakeSite(download_resource, self.reactor),
"GET", "GET",
server_and_media_id, server_and_media_id,
shorthand=False, shorthand=False,
@@ -164,7 +164,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
# Attempt to access media # Attempt to access media
channel = make_request( channel = make_request(
self.reactor, self.reactor,
FakeSite(download_resource), FakeSite(download_resource, self.reactor),
"GET", "GET",
server_and_media_id, server_and_media_id,
shorthand=False, shorthand=False,
@@ -525,7 +525,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
channel = make_request( channel = make_request(
self.reactor, self.reactor,
FakeSite(download_resource), FakeSite(download_resource, self.reactor),
"GET", "GET",
server_and_media_id, server_and_media_id,
shorthand=False, shorthand=False,

View File

@@ -2973,7 +2973,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
# Try to access a media and to create `last_access_ts` # Try to access a media and to create `last_access_ts`
channel = make_request( channel = make_request(
self.reactor, self.reactor,
FakeSite(download_resource), FakeSite(download_resource, self.reactor),
"GET", "GET",
server_and_media_id, server_and_media_id,
shorthand=False, shorthand=False,

View File

@@ -312,7 +312,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Load the password reset confirmation page # Load the password reset confirmation page
channel = make_request( channel = make_request(
self.reactor, self.reactor,
FakeSite(self.submit_token_resource), FakeSite(self.submit_token_resource, self.reactor),
"GET", "GET",
path, path,
shorthand=False, shorthand=False,
@@ -326,7 +326,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Confirm the password reset # Confirm the password reset
channel = make_request( channel = make_request(
self.reactor, self.reactor,
FakeSite(self.submit_token_resource), FakeSite(self.submit_token_resource, self.reactor),
"POST", "POST",
path, path,
content=b"", content=b"",

View File

@@ -61,7 +61,11 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
"""You can observe the terms form without specifying a user""" """You can observe the terms form without specifying a user"""
resource = consent_resource.ConsentResource(self.hs) resource = consent_resource.ConsentResource(self.hs)
channel = make_request( channel = make_request(
self.reactor, FakeSite(resource), "GET", "/consent?v=1", shorthand=False self.reactor,
FakeSite(resource, self.reactor),
"GET",
"/consent?v=1",
shorthand=False,
) )
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
@@ -83,7 +87,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
) )
channel = make_request( channel = make_request(
self.reactor, self.reactor,
FakeSite(resource), FakeSite(resource, self.reactor),
"GET", "GET",
consent_uri, consent_uri,
access_token=access_token, access_token=access_token,
@@ -98,7 +102,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
# POST to the consent page, saying we've agreed # POST to the consent page, saying we've agreed
channel = make_request( channel = make_request(
self.reactor, self.reactor,
FakeSite(resource), FakeSite(resource, self.reactor),
"POST", "POST",
consent_uri + "&v=" + version, consent_uri + "&v=" + version,
access_token=access_token, access_token=access_token,
@@ -110,7 +114,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
# changed # changed
channel = make_request( channel = make_request(
self.reactor, self.reactor,
FakeSite(resource), FakeSite(resource, self.reactor),
"GET", "GET",
consent_uri, consent_uri,
access_token=access_token, access_token=access_token,

View File

@@ -372,7 +372,7 @@ class RestHelper:
path = "/_matrix/media/r0/upload?filename=%s" % (filename,) path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
channel = make_request( channel = make_request(
self.hs.get_reactor(), self.hs.get_reactor(),
FakeSite(resource), FakeSite(resource, self.hs.get_reactor()),
"POST", "POST",
path, path,
content=image_data, content=image_data,

View File

@@ -84,7 +84,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
Checks that the response is a 200 and returns the decoded json body. Checks that the response is a 200 and returns the decoded json body.
""" """
channel = FakeChannel(self.site, self.reactor) channel = FakeChannel(self.site, self.reactor)
req = SynapseRequest(channel) req = SynapseRequest(channel, self.site)
req.content = BytesIO(b"") req.content = BytesIO(b"")
req.requestReceived( req.requestReceived(
b"GET", b"GET",
@@ -183,7 +183,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
) )
channel = FakeChannel(self.site, self.reactor) channel = FakeChannel(self.site, self.reactor)
req = SynapseRequest(channel) req = SynapseRequest(channel, self.site)
req.content = BytesIO(encode_canonical_json(data)) req.content = BytesIO(encode_canonical_json(data))
req.requestReceived( req.requestReceived(

View File

@@ -252,7 +252,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
channel = make_request( channel = make_request(
self.reactor, self.reactor,
FakeSite(self.download_resource), FakeSite(self.download_resource, self.reactor),
"GET", "GET",
self.media_id, self.media_id,
shorthand=False, shorthand=False,
@@ -384,7 +384,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
params = "?width=32&height=32&method=scale" params = "?width=32&height=32&method=scale"
channel = make_request( channel = make_request(
self.reactor, self.reactor,
FakeSite(self.thumbnail_resource), FakeSite(self.thumbnail_resource, self.reactor),
"GET", "GET",
self.media_id + params, self.media_id + params,
shorthand=False, shorthand=False,
@@ -413,7 +413,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
channel = make_request( channel = make_request(
self.reactor, self.reactor,
FakeSite(self.thumbnail_resource), FakeSite(self.thumbnail_resource, self.reactor),
"GET", "GET",
self.media_id + params, self.media_id + params,
shorthand=False, shorthand=False,
@@ -433,7 +433,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
params = "?width=32&height=32&method=" + method params = "?width=32&height=32&method=" + method
channel = make_request( channel = make_request(
self.reactor, self.reactor,
FakeSite(self.thumbnail_resource), FakeSite(self.thumbnail_resource, self.reactor),
"GET", "GET",
self.media_id + params, self.media_id + params,
shorthand=False, shorthand=False,

View File

@@ -19,6 +19,7 @@ from twisted.internet.interfaces import (
IPullProducer, IPullProducer,
IPushProducer, IPushProducer,
IReactorPluggableNameResolver, IReactorPluggableNameResolver,
IReactorTime,
IResolverSimple, IResolverSimple,
ITransport, ITransport,
) )
@@ -181,13 +182,14 @@ class FakeSite:
site_tag = "test" site_tag = "test"
access_logger = logging.getLogger("synapse.access.http.fake") access_logger = logging.getLogger("synapse.access.http.fake")
def __init__(self, resource: IResource): def __init__(self, resource: IResource, reactor: IReactorTime):
""" """
Args: Args:
resource: the resource to be used for rendering all requests resource: the resource to be used for rendering all requests
""" """
self._resource = resource self._resource = resource
self.reactor = reactor
def getResourceFor(self, request): def getResourceFor(self, request):
return self._resource return self._resource
@@ -268,7 +270,7 @@ def make_request(
channel = FakeChannel(site, reactor, ip=client_ip) channel = FakeChannel(site, reactor, ip=client_ip)
req = request(channel) req = request(channel, site)
req.content = BytesIO(content) req.content = BytesIO(content)
# Twisted expects to be at the end of the content when parsing the request. # Twisted expects to be at the end of the content when parsing the request.
req.content.seek(SEEK_END) req.content.seek(SEEK_END)

View File

@@ -65,7 +65,10 @@ class JsonResourceTests(unittest.TestCase):
) )
make_request( make_request(
self.reactor, FakeSite(res), b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83" self.reactor,
FakeSite(res, self.reactor),
b"GET",
b"/_matrix/foo/%E2%98%83?a=%E2%98%83",
) )
self.assertEqual(got_kwargs, {"room_id": "\N{SNOWMAN}"}) self.assertEqual(got_kwargs, {"room_id": "\N{SNOWMAN}"})
@@ -84,7 +87,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
) )
channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo") channel = make_request(
self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo"
)
self.assertEqual(channel.result["code"], b"500") self.assertEqual(channel.result["code"], b"500")
@@ -100,7 +105,7 @@ class JsonResourceTests(unittest.TestCase):
def _callback(request, **kwargs): def _callback(request, **kwargs):
d = Deferred() d = Deferred()
d.addCallback(_throw) d.addCallback(_throw)
self.reactor.callLater(1, d.callback, True) self.reactor.callLater(0.5, d.callback, True)
return make_deferred_yieldable(d) return make_deferred_yieldable(d)
res = JsonResource(self.homeserver) res = JsonResource(self.homeserver)
@@ -108,7 +113,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
) )
channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo") channel = make_request(
self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo"
)
self.assertEqual(channel.result["code"], b"500") self.assertEqual(channel.result["code"], b"500")
@@ -126,7 +133,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
) )
channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo") channel = make_request(
self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo"
)
self.assertEqual(channel.result["code"], b"403") self.assertEqual(channel.result["code"], b"403")
self.assertEqual(channel.json_body["error"], "Forbidden!!one!") self.assertEqual(channel.json_body["error"], "Forbidden!!one!")
@@ -148,7 +157,9 @@ class JsonResourceTests(unittest.TestCase):
"GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
) )
channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foobar") channel = make_request(
self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foobar"
)
self.assertEqual(channel.result["code"], b"400") self.assertEqual(channel.result["code"], b"400")
self.assertEqual(channel.json_body["error"], "Unrecognized request") self.assertEqual(channel.json_body["error"], "Unrecognized request")
@@ -173,7 +184,9 @@ class JsonResourceTests(unittest.TestCase):
) )
# The path was registered as GET, but this is a HEAD request. # The path was registered as GET, but this is a HEAD request.
channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/_matrix/foo") channel = make_request(
self.reactor, FakeSite(res, self.reactor), b"HEAD", b"/_matrix/foo"
)
self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.result["code"], b"200")
self.assertNotIn("body", channel.result) self.assertNotIn("body", channel.result)
@@ -280,7 +293,9 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource() res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback res.callback = callback
channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") channel = make_request(
self.reactor, FakeSite(res, self.reactor), b"GET", b"/path"
)
self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.result["code"], b"200")
body = channel.result["body"] body = channel.result["body"]
@@ -298,7 +313,9 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource() res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback res.callback = callback
channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") channel = make_request(
self.reactor, FakeSite(res, self.reactor), b"GET", b"/path"
)
self.assertEqual(channel.result["code"], b"301") self.assertEqual(channel.result["code"], b"301")
headers = channel.result["headers"] headers = channel.result["headers"]
@@ -319,7 +336,9 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource() res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback res.callback = callback
channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") channel = make_request(
self.reactor, FakeSite(res, self.reactor), b"GET", b"/path"
)
self.assertEqual(channel.result["code"], b"304") self.assertEqual(channel.result["code"], b"304")
headers = channel.result["headers"] headers = channel.result["headers"]
@@ -338,7 +357,9 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
res = WrapHtmlRequestHandlerTests.TestResource() res = WrapHtmlRequestHandlerTests.TestResource()
res.callback = callback res.callback = callback
channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/path") channel = make_request(
self.reactor, FakeSite(res, self.reactor), b"HEAD", b"/path"
)
self.assertEqual(channel.result["code"], b"200") self.assertEqual(channel.result["code"], b"200")
self.assertNotIn("body", channel.result) self.assertNotIn("body", channel.result)