Mirror of https://github.com/element-hq/synapse.git (synced 2025-12-17 02:10:27 +00:00)

Compare commits: v1.135.0...erikj/fast (12 commits)
| SHA1 |
|---|
| 16e8580571 |
| 0ce9315f9f |
| c2d84edffc |
| d1f25c6df4 |
| 6c0bc18139 |
| 8521a0c976 |
| e369a20d0a |
| 40c99c22ff |
| fbcbfb4aa4 |
| 2f8abe0905 |
| 341a92b7d0 |
| d18c71abab |
changelog.d/10844.feature (new file, 1 line)

```diff
@@ -0,0 +1 @@
+Speed up responding with large JSON objects to requests.
```
```diff
@@ -21,7 +21,6 @@ import types
 import urllib
 from http import HTTPStatus
 from inspect import isawaitable
-from io import BytesIO
 from typing import (
     Any,
     Awaitable,
@@ -37,7 +36,7 @@ from typing import (
 )

 import jinja2
-from canonicaljson import iterencode_canonical_json
+from canonicaljson import encode_canonical_json
 from typing_extensions import Protocol
 from zope.interface import implementer

@@ -45,7 +44,7 @@ from twisted.internet import defer, interfaces
 from twisted.python import failure
 from twisted.web import resource
 from twisted.web.server import NOT_DONE_YET, Request
-from twisted.web.static import File, NoRangeStaticProducer
+from twisted.web.static import File
 from twisted.web.util import redirectTo

 from synapse.api.errors import (
@@ -56,10 +55,11 @@ from synapse.api.errors import (
     UnrecognizedRequestError,
 )
 from synapse.http.site import SynapseRequest
-from synapse.logging.context import preserve_fn
+from synapse.logging.context import defer_to_thread, preserve_fn, run_in_background
 from synapse.logging.opentracing import trace_servlet
 from synapse.util import json_encoder
 from synapse.util.caches import intern_dict
+from synapse.util.iterutils import chunk_seq

 logger = logging.getLogger(__name__)

@@ -320,7 +320,7 @@ class DirectServeJsonResource(_AsyncResource):

     def _send_response(
         self,
-        request: Request,
+        request: SynapseRequest,
         code: int,
         response_object: Any,
     ):
@@ -620,16 +620,15 @@ class _ByteProducer:
         self._request = None


-def _encode_json_bytes(json_object: Any) -> Iterator[bytes]:
+def _encode_json_bytes(json_object: Any) -> bytes:
     """
     Encode an object into JSON. Returns an iterator of bytes.
     """
-    for chunk in json_encoder.iterencode(json_object):
-        yield chunk.encode("utf-8")
+    return json_encoder.encode(json_object).encode("utf-8")


 def respond_with_json(
-    request: Request,
+    request: SynapseRequest,
     code: int,
     json_object: Any,
     send_cors: bool = False,
@@ -659,7 +658,7 @@ def respond_with_json(
         return None

     if canonical_json:
-        encoder = iterencode_canonical_json
+        encoder = encode_canonical_json
     else:
         encoder = _encode_json_bytes

@@ -670,7 +669,9 @@ def respond_with_json(
     if send_cors:
         set_cors_headers(request)

-    _ByteProducer(request, encoder(json_object))
+    run_in_background(
+        _async_write_json_to_request_in_thread, request, encoder, json_object
+    )
     return NOT_DONE_YET


```
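The hunks above (apparently Synapse's HTTP server module, going by the functions touched) swap the iterating encoders for their one-shot counterparts and push the actual encode call onto a background thread. The reasoning is spelled out in a later hunk: `JSONEncoder.iterencode` falls back to the pure-Python encoder, while `encode` can use the C accelerator. A standalone way to observe that difference; the payload shape is made up for illustration and is not taken from the Synapse change:

```python
# Compare the two encoding strategies referenced above. The payload is
# hypothetical; only the relative timing difference is of interest.
import json
import timeit

payload = {"rows": [{"id": i, "body": "x" * 80} for i in range(50_000)]}
encoder = json.JSONEncoder(separators=(",", ":"))  # stand-in for Synapse's json_encoder


def via_iterencode() -> bytes:
    # Old style: many small str chunks, produced by the pure-Python encoder.
    return b"".join(chunk.encode("utf-8") for chunk in encoder.iterencode(payload))


def via_encode() -> bytes:
    # New style: one C-accelerated encode, then a single utf-8 conversion.
    return encoder.encode(payload).encode("utf-8")


assert via_iterencode() == via_encode()
print("iterencode:", timeit.timeit(via_iterencode, number=3))
print("encode:    ", timeit.timeit(via_encode, number=3))
```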
```diff
@@ -706,15 +707,35 @@ def respond_with_json_bytes(
     if send_cors:
         set_cors_headers(request)

-    # note that this is zero-copy (the bytesio shares a copy-on-write buffer with
-    # the original `bytes`).
-    bytes_io = BytesIO(json_bytes)
-
-    producer = NoRangeStaticProducer(request, bytes_io)
-    producer.start()
+    _write_json_bytes_to_request(request, json_bytes)
     return NOT_DONE_YET


+def _write_json_bytes_to_request(request: Request, json_bytes: bytes) -> None:
+    """Writes the JSON bytes to the request using an appropriate producer.
+
+    Note: This should be used instead of `Request.write` to correctly handle
+    large response bodies.
+    """
+
+    # The problem with dumping all of the json response into the `Request`
+    # object at once (via `Request.write`) is that doing so starts the timeout
+    # for the next request to be received: so if it takes longer than 60s to
+    # stream back the response to the client, the client never gets it.
+    #
+    # The correct solution is to use a Producer; then the timeout is only
+    # started once all of the content is sent over the TCP connection.
+
+    # To make sure we don't write the whole of the json at once we split it up
+    # into chunks.
+    chunk_size = 4096
+    bytes_generator = chunk_seq(json_bytes, chunk_size)
+
+    # We use a `_ByteProducer` here rather than `NoRangeStaticProducer` as the
+    # unit tests can't cope with being given a pull producer.
+    _ByteProducer(request, bytes_generator)
+
+
 def set_cors_headers(request: Request):
     """Set the CORS headers so that javascript running in a web browsers can
     use this API
```
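The comment block in this hunk carries the core reasoning: writing the whole body with `Request.write` starts Twisted's timeout for the next request, so a client that takes longer than 60s to download the response never gets it, whereas a registered producer only lets that timeout start once the content has actually been sent. A minimal push producer in the spirit of the `_ByteProducer` used above (simplified; the real class also batches small chunks and guards against already-disconnected requests):

```python
# Minimal sketch of a streaming (push) producer feeding chunks to a Twisted
# request. Simplified relative to Synapse's _ByteProducer; illustrative only.
from typing import Iterator, Optional

from twisted.internet.interfaces import IPushProducer
from twisted.web.server import Request
from zope.interface import implementer


@implementer(IPushProducer)
class ChunkedProducer:
    def __init__(self, request: Request, chunks: Iterator[bytes]):
        self._request: Optional[Request] = request
        self._chunks = chunks
        self._paused = False
        # With a registered producer the connection timeout only starts once
        # the content has gone out, rather than as soon as write() buffered it.
        request.registerProducer(self, True)
        self.resumeProducing()

    def pauseProducing(self) -> None:
        self._paused = True

    def resumeProducing(self) -> None:
        if self._request is None:
            return
        self._paused = False
        for chunk in self._chunks:
            self._request.write(chunk)
            if self._paused:
                return
        # Iterator exhausted: tell Twisted we are done and finish the response.
        self._request.unregisterProducer()
        self._request.finish()
        self._request = None

    def stopProducing(self) -> None:
        self._paused = True
        self._request = None
```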
```diff
@@ -809,3 +830,24 @@ def finish_request(request: Request):
         request.finish()
     except RuntimeError as e:
         logger.info("Connection disconnected before response was written: %r", e)
+
+
+async def _async_write_json_to_request_in_thread(
+    request: SynapseRequest,
+    json_encoder: Callable[[Any], bytes],
+    json_object: Any,
+):
+    """Encodes the given JSON object on a thread and then writes it to the
+    request.
+
+    This is done so that encoding large JSON objects doesn't block the reactor
+    thread.
+
+    Note: We don't use JsonEncoder.iterencode here as that falls back to the
+    Python implementation (rather than the C backend), which is *much* more
+    expensive.
+    """
+
+    json_str = await defer_to_thread(request.reactor, json_encoder, json_object)
+
+    _write_json_bytes_to_request(request, json_str)
```
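The new `_async_write_json_to_request_in_thread` helper is where the reactor thread gets protected: the encode runs on a worker thread and only the writing returns to the reactor. Outside of Synapse's logcontext-aware wrappers (`defer_to_thread`, `run_in_background`), the same shape can be sketched with plain Twisted primitives; all names below are illustrative, not Synapse's:

```python
# Plain-Twisted sketch of "encode on a thread, write from the reactor".
from typing import Any, Callable

from twisted.internet import reactor
from twisted.internet.defer import Deferred, ensureDeferred
from twisted.internet.threads import deferToThreadPool


async def write_json_off_thread(
    request: Any, encoder: Callable[[Any], bytes], json_object: Any
) -> None:
    # CPU-bound: run the encoder on the reactor's thread pool.
    json_bytes: bytes = await deferToThreadPool(
        reactor, reactor.getThreadPool(), encoder, json_object
    )
    # Back on the reactor thread: hand the bytes to the request. The real code
    # streams them through a producer in 4096-byte chunks instead of write().
    request.write(json_bytes)
    request.finish()


def start(request: Any, encoder: Callable[[Any], bytes], json_object: Any) -> Deferred:
    # Fire-and-forget from the caller's point of view, like run_in_background.
    return ensureDeferred(write_json_off_thread(request, encoder, json_object))
```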
```diff
@@ -14,14 +14,15 @@
 import contextlib
 import logging
 import time
-from typing import Optional, Tuple, Union
+from typing import Generator, Optional, Tuple, Union

 import attr
 from zope.interface import implementer

 from twisted.internet.interfaces import IAddress, IReactorTime
 from twisted.python.failure import Failure
-from twisted.web.resource import IResource
+from twisted.web.http import HTTPChannel
+from twisted.web.resource import IResource, Resource
 from twisted.web.server import Request, Site

 from synapse.config.server import ListenerConfig
@@ -61,10 +62,18 @@ class SynapseRequest(Request):
         logcontext: the log context for this request
     """

-    def __init__(self, channel, *args, max_request_body_size=1024, **kw):
-        Request.__init__(self, channel, *args, **kw)
+    def __init__(
+        self,
+        channel: HTTPChannel,
+        site: "SynapseSite",
+        *args,
+        max_request_body_size: int = 1024,
+        **kw,
+    ):
+        super().__init__(channel, *args, **kw)
         self._max_request_body_size = max_request_body_size
-        self.site: SynapseSite = channel.site
+        self.synapse_site = site
+        self.reactor = site.reactor
         self._channel = channel  # this is used by the tests
         self.start_time = 0.0

@@ -83,13 +92,13 @@ class SynapseRequest(Request):
         self._is_processing = False

         # the time when the asynchronous request handler completed its processing
-        self._processing_finished_time = None
+        self._processing_finished_time: Optional[float] = None

         # what time we finished sending the response to the client (or the connection
         # dropped)
-        self.finish_time = None
+        self.finish_time: Optional[float] = None

-    def __repr__(self):
+    def __repr__(self) -> str:
         # We overwrite this so that we don't log ``access_token``
         return "<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>" % (
             self.__class__.__name__,
@@ -97,10 +106,10 @@ class SynapseRequest(Request):
             self.get_method(),
             self.get_redacted_uri(),
             self.clientproto.decode("ascii", errors="replace"),
-            self.site.site_tag,
+            self.synapse_site.site_tag,
         )

-    def handleContentChunk(self, data):
+    def handleContentChunk(self, data: bytes) -> None:
         # we should have a `content` by now.
         assert self.content, "handleContentChunk() called before gotLength()"
         if self.content.tell() + len(data) > self._max_request_body_size:
@@ -139,7 +148,7 @@ class SynapseRequest(Request):
         # If there's no authenticated entity, it was the requester.
         self.logcontext.request.authenticated_entity = authenticated_entity or requester

-    def get_request_id(self):
+    def get_request_id(self) -> str:
         return "%s-%i" % (self.get_method(), self.request_seq)

     def get_redacted_uri(self) -> str:
@@ -205,7 +214,7 @@ class SynapseRequest(Request):

         return None, None

-    def render(self, resrc):
+    def render(self, resrc: Resource) -> None:
         # this is called once a Resource has been found to serve the request; in our
         # case the Resource in question will normally be a JsonResource.

@@ -216,7 +225,7 @@ class SynapseRequest(Request):
             request=ContextRequest(
                 request_id=request_id,
                 ip_address=self.getClientIP(),
-                site_tag=self.site.site_tag,
+                site_tag=self.synapse_site.site_tag,
                 # The requester is going to be unknown at this point.
                 requester=None,
                 authenticated_entity=None,
@@ -228,7 +237,7 @@ class SynapseRequest(Request):
         )

         # override the Server header which is set by twisted
-        self.setHeader("Server", self.site.server_version_string)
+        self.setHeader("Server", self.synapse_site.server_version_string)

         with PreserveLoggingContext(self.logcontext):
             # we start the request metrics timer here with an initial stab
@@ -247,7 +256,7 @@ class SynapseRequest(Request):
             requests_counter.labels(self.get_method(), self.request_metrics.name).inc()

     @contextlib.contextmanager
-    def processing(self):
+    def processing(self) -> Generator[None, None, None]:
         """Record the fact that we are processing this request.

         Returns a context manager; the correct way to use this is:
@@ -282,7 +291,7 @@ class SynapseRequest(Request):
             if self.finish_time is not None:
                 self._finished_processing()

-    def finish(self):
+    def finish(self) -> None:
         """Called when all response data has been written to this Request.

         Overrides twisted.web.server.Request.finish to record the finish time and do
@@ -295,7 +304,7 @@ class SynapseRequest(Request):
             with PreserveLoggingContext(self.logcontext):
                 self._finished_processing()

-    def connectionLost(self, reason):
+    def connectionLost(self, reason: Union[Failure, Exception]) -> None:
         """Called when the client connection is closed before the response is written.

         Overrides twisted.web.server.Request.connectionLost to record the finish time and
@@ -327,7 +336,7 @@ class SynapseRequest(Request):
             if not self._is_processing:
                 self._finished_processing()

-    def _started_processing(self, servlet_name):
+    def _started_processing(self, servlet_name: str) -> None:
         """Record the fact that we are processing this request.

         This will log the request's arrival. Once the request completes,
@@ -346,17 +355,19 @@ class SynapseRequest(Request):
             self.start_time, name=servlet_name, method=self.get_method()
         )

-        self.site.access_logger.debug(
+        self.synapse_site.access_logger.debug(
             "%s - %s - Received request: %s %s",
             self.getClientIP(),
-            self.site.site_tag,
+            self.synapse_site.site_tag,
             self.get_method(),
             self.get_redacted_uri(),
         )

-    def _finished_processing(self):
+    def _finished_processing(self) -> None:
         """Log the completion of this request and update the metrics"""
         assert self.logcontext is not None
+        assert self.finish_time is not None
+
         usage = self.logcontext.get_resource_usage()

         if self._processing_finished_time is None:
@@ -386,13 +397,13 @@ class SynapseRequest(Request):
         if authenticated_entity:
             requester = f"{authenticated_entity}|{requester}"

-        self.site.access_logger.log(
+        self.synapse_site.access_logger.log(
             log_level,
             "%s - %s - {%s}"
             " Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
             ' %sB %s "%s %s %s" "%s" [%d dbevts]',
             self.getClientIP(),
-            self.site.site_tag,
+            self.synapse_site.site_tag,
             requester,
             processing_time,
             response_send_time,
@@ -437,7 +448,7 @@ class XForwardedForRequest(SynapseRequest):
     _forwarded_for: "Optional[_XForwardedForAddress]" = None
     _forwarded_https: bool = False

-    def requestReceived(self, command, path, version):
+    def requestReceived(self, command: bytes, path: bytes, version: bytes) -> None:
         # this method is called by the Channel once the full request has been
         # received, to dispatch the request to a resource.
         # We can use it to set the IP address and protocol according to the
@@ -445,7 +456,7 @@ class XForwardedForRequest(SynapseRequest):
         self._process_forwarded_headers()
         return super().requestReceived(command, path, version)

-    def _process_forwarded_headers(self):
+    def _process_forwarded_headers(self) -> None:
         headers = self.requestHeaders.getRawHeaders(b"x-forwarded-for")
         if not headers:
             return
@@ -470,7 +481,7 @@ class XForwardedForRequest(SynapseRequest):
             )
             self._forwarded_https = True

-    def isSecure(self):
+    def isSecure(self) -> bool:
         if self._forwarded_https:
             return True
         return super().isSecure()
@@ -520,7 +531,7 @@ class SynapseSite(Site):
         site_tag: str,
         config: ListenerConfig,
         resource: IResource,
-        server_version_string,
+        server_version_string: str,
         max_request_body_size: int,
         reactor: IReactorTime,
     ):
@@ -540,19 +551,23 @@ class SynapseSite(Site):
         Site.__init__(self, resource, reactor=reactor)

         self.site_tag = site_tag
+        self.reactor = reactor

         assert config.http_options is not None
         proxied = config.http_options.x_forwarded
         request_class = XForwardedForRequest if proxied else SynapseRequest

-        def request_factory(channel, queued) -> Request:
+        def request_factory(channel: HTTPChannel, queued: bool) -> Request:
             return request_class(
-                channel, max_request_body_size=max_request_body_size, queued=queued
+                channel,
+                self,
+                max_request_body_size=max_request_body_size,
+                queued=queued,
             )

         self.requestFactory = request_factory  # type: ignore
         self.access_logger = logging.getLogger(logger_name)
         self.server_version_string = server_version_string.encode("ascii")

-    def log(self, request):
+    def log(self, request: SynapseRequest) -> None:
         pass
```
```diff
@@ -184,7 +184,7 @@ class EmailPusher(Pusher):

         should_notify_at = max(notif_ready_at, room_ready_at)

-        if should_notify_at < self.clock.time_msec():
+        if should_notify_at <= self.clock.time_msec():
             # one of our notifications is ready for sending, so we send
             # *one* email updating the user on their notifications,
             # we then consider all previously outstanding notifications
```
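A one-character fix rides along in this hunk: `should_notify_at` is compared with `<=` rather than `<`, so an email whose due time equals the current clock reading is sent now instead of being pushed to the next wake-up. Presumably this matters most with stubbed clocks, where timers fire at exactly the scheduled millisecond; a self-contained illustration with made-up timestamps:

```python
# Boundary behaviour of the comparison changed above (values are hypothetical).
now_ms = 1_700_000_000_000
should_notify_at = now_ms  # the timer fired exactly at the scheduled time

send_old = should_notify_at < now_ms    # False: email waits for the next loop
send_new = should_notify_at <= now_ms   # True: email goes out immediately
assert (send_old, send_new) == (False, True)
```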
```diff
@@ -17,12 +17,11 @@ from typing import TYPE_CHECKING, Dict

 from signedjson.sign import sign_json

-from twisted.web.server import Request
-
 from synapse.api.errors import Codes, SynapseError
 from synapse.crypto.keyring import ServerKeyFetcher
 from synapse.http.server import DirectServeJsonResource, respond_with_json
 from synapse.http.servlet import parse_integer, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
 from synapse.types import JsonDict
 from synapse.util import json_decoder
 from synapse.util.async_helpers import yieldable_gather_results
@@ -100,7 +99,7 @@ class RemoteKey(DirectServeJsonResource):
         self.federation_domain_whitelist = hs.config.federation_domain_whitelist
         self.config = hs.config

-    async def _async_render_GET(self, request: Request) -> None:
+    async def _async_render_GET(self, request: SynapseRequest) -> None:
         assert request.postpath is not None
         if len(request.postpath) == 1:
             (server,) = request.postpath
@@ -117,7 +116,7 @@ class RemoteKey(DirectServeJsonResource):

         await self.query_keys(request, query, query_remote_on_cache_miss=True)

-    async def _async_render_POST(self, request: Request) -> None:
+    async def _async_render_POST(self, request: SynapseRequest) -> None:
         content = parse_json_object_from_request(request)

         query = content["server_keys"]
@@ -126,7 +125,7 @@ class RemoteKey(DirectServeJsonResource):

     async def query_keys(
         self,
-        request: Request,
+        request: SynapseRequest,
         query: JsonDict,
         query_remote_on_cache_miss: bool = False,
     ) -> None:
@@ -27,6 +27,7 @@ from twisted.web.server import Request

 from synapse.api.errors import Codes, SynapseError, cs_error
 from synapse.http.server import finish_request, respond_with_json
+from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
 from synapse.util.stringutils import is_ascii

@@ -74,7 +75,7 @@ def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
     )


-def respond_404(request: Request) -> None:
+def respond_404(request: SynapseRequest) -> None:
     respond_with_json(
         request,
         404,
@@ -84,7 +85,7 @@ def respond_404(request: Request) -> None:


 async def respond_with_file(
-    request: Request,
+    request: SynapseRequest,
     media_type: str,
     file_path: str,
     file_size: Optional[int] = None,
@@ -221,7 +222,7 @@ def _can_encode_filename_as_token(x: str) -> bool:


 async def respond_with_responder(
-    request: Request,
+    request: SynapseRequest,
     responder: "Optional[Responder]",
     media_type: str,
     file_size: Optional[int],
@@ -16,8 +16,6 @@

 from typing import TYPE_CHECKING

-from twisted.web.server import Request
-
 from synapse.http.server import DirectServeJsonResource, respond_with_json
 from synapse.http.site import SynapseRequest

@@ -39,5 +37,5 @@ class MediaConfigResource(DirectServeJsonResource):
         await self.auth.get_user_by_req(request)
         respond_with_json(request, 200, self.limits_dict, send_cors=True)

-    async def _async_render_OPTIONS(self, request: Request) -> None:
+    async def _async_render_OPTIONS(self, request: SynapseRequest) -> None:
         respond_with_json(request, 200, {}, send_cors=True)
@@ -15,10 +15,9 @@
 import logging
 from typing import TYPE_CHECKING

-from twisted.web.server import Request
-
 from synapse.http.server import DirectServeJsonResource, set_cors_headers
 from synapse.http.servlet import parse_boolean
+from synapse.http.site import SynapseRequest

 from ._base import parse_media_id, respond_404

@@ -37,7 +36,7 @@ class DownloadResource(DirectServeJsonResource):
         self.media_repo = media_repo
         self.server_name = hs.hostname

-    async def _async_render_GET(self, request: Request) -> None:
+    async def _async_render_GET(self, request: SynapseRequest) -> None:
         set_cors_headers(request)
         request.setHeader(
             b"Content-Security-Policy",
@@ -23,7 +23,6 @@ import twisted.internet.error
 import twisted.web.http
 from twisted.internet.defer import Deferred
 from twisted.web.resource import Resource
-from twisted.web.server import Request

 from synapse.api.errors import (
     FederationDeniedError,
@@ -34,6 +33,7 @@ from synapse.api.errors import (
 )
 from synapse.config._base import ConfigError
 from synapse.config.repository import ThumbnailRequirement
+from synapse.http.site import SynapseRequest
 from synapse.logging.context import defer_to_thread
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.types import UserID
@@ -187,7 +187,7 @@ class MediaRepository:
         return "mxc://%s/%s" % (self.server_name, media_id)

     async def get_local_media(
-        self, request: Request, media_id: str, name: Optional[str]
+        self, request: SynapseRequest, media_id: str, name: Optional[str]
     ) -> None:
         """Responds to requests for local media, if exists, or returns 404.

@@ -221,7 +221,11 @@ class MediaRepository:
         )

     async def get_remote_media(
-        self, request: Request, server_name: str, media_id: str, name: Optional[str]
+        self,
+        request: SynapseRequest,
+        server_name: str,
+        media_id: str,
+        name: Optional[str],
     ) -> None:
         """Respond to requests for remote media.

@@ -29,7 +29,6 @@ import attr

 from twisted.internet.defer import Deferred
 from twisted.internet.error import DNSLookupError
-from twisted.web.server import Request

 from synapse.api.errors import Codes, SynapseError
 from synapse.http.client import SimpleHttpClient
@@ -167,7 +166,7 @@ class PreviewUrlResource(DirectServeJsonResource):
             self._start_expire_url_cache_data, 10 * 1000
         )

-    async def _async_render_OPTIONS(self, request: Request) -> None:
+    async def _async_render_OPTIONS(self, request: SynapseRequest) -> None:
         request.setHeader(b"Allow", b"OPTIONS, GET")
         respond_with_json(request, 200, {}, send_cors=True)

@@ -17,11 +17,10 @@
 import logging
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

-from twisted.web.server import Request
-
 from synapse.api.errors import SynapseError
 from synapse.http.server import DirectServeJsonResource, set_cors_headers
 from synapse.http.servlet import parse_integer, parse_string
+from synapse.http.site import SynapseRequest
 from synapse.rest.media.v1.media_storage import MediaStorage

 from ._base import (
@@ -57,7 +56,7 @@ class ThumbnailResource(DirectServeJsonResource):
         self.dynamic_thumbnails = hs.config.dynamic_thumbnails
         self.server_name = hs.hostname

-    async def _async_render_GET(self, request: Request) -> None:
+    async def _async_render_GET(self, request: SynapseRequest) -> None:
         set_cors_headers(request)
         server_name, media_id, _ = parse_media_id(request)
         width = parse_integer(request, "width", required=True)
@@ -88,7 +87,7 @@ class ThumbnailResource(DirectServeJsonResource):

     async def _respond_local_thumbnail(
         self,
-        request: Request,
+        request: SynapseRequest,
         media_id: str,
         width: int,
         height: int,
@@ -121,7 +120,7 @@ class ThumbnailResource(DirectServeJsonResource):

     async def _select_or_generate_local_thumbnail(
         self,
-        request: Request,
+        request: SynapseRequest,
         media_id: str,
         desired_width: int,
         desired_height: int,
@@ -186,7 +185,7 @@ class ThumbnailResource(DirectServeJsonResource):

     async def _select_or_generate_remote_thumbnail(
         self,
-        request: Request,
+        request: SynapseRequest,
         server_name: str,
         media_id: str,
         desired_width: int,
@@ -249,7 +248,7 @@ class ThumbnailResource(DirectServeJsonResource):

     async def _respond_remote_thumbnail(
         self,
-        request: Request,
+        request: SynapseRequest,
         server_name: str,
         media_id: str,
         width: int,
@@ -280,7 +279,7 @@ class ThumbnailResource(DirectServeJsonResource):

     async def _select_and_respond_with_thumbnail(
         self,
-        request: Request,
+        request: SynapseRequest,
         desired_width: int,
         desired_height: int,
         desired_method: str,
@@ -16,8 +16,6 @@
 import logging
 from typing import IO, TYPE_CHECKING, Dict, List, Optional

-from twisted.web.server import Request
-
 from synapse.api.errors import Codes, SynapseError
 from synapse.http.server import DirectServeJsonResource, respond_with_json
 from synapse.http.servlet import parse_bytes_from_args
@@ -46,7 +44,7 @@ class UploadResource(DirectServeJsonResource):
         self.max_upload_size = hs.config.max_upload_size
         self.clock = hs.get_clock()

-    async def _async_render_OPTIONS(self, request: Request) -> None:
+    async def _async_render_OPTIONS(self, request: SynapseRequest) -> None:
         respond_with_json(request, 200, {}, send_cors=True)

     async def _async_render_POST(self, request: SynapseRequest) -> None:
```
```diff
@@ -21,13 +21,28 @@ from typing import (
     Iterable,
     Iterator,
     Mapping,
-    Sequence,
     Set,
+    Sized,
     Tuple,
     TypeVar,
 )

+from typing_extensions import Protocol
+
 T = TypeVar("T")
+S = TypeVar("S", bound="_SelfSlice")
+
+
+class _SelfSlice(Sized, Protocol):
+    """A helper protocol that matches types where taking a slice results in the
+    same type being returned.
+
+    This is more specific than `Sequence`, which allows another `Sequence` to be
+    returned.
+    """
+
+    def __getitem__(self: S, i: slice) -> S:
+        ...
+

 def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]:
@@ -46,7 +61,7 @@ def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]:
     return iter(lambda: tuple(islice(sourceiter, size)), ())


-def chunk_seq(iseq: Sequence[T], maxlen: int) -> Iterable[Sequence[T]]:
+def chunk_seq(iseq: S, maxlen: int) -> Iterator[S]:
     """Split the given sequence into chunks of the given size

     The last chunk may be shorter than the given size.
```
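The retyping of `chunk_seq` via the new `_SelfSlice` protocol is what lets the HTTP code above feed raw `bytes` through it and get `bytes` chunks back out, rather than the looser `Sequence[T]` it previously advertised. A hedged re-implementation plus usage, to show the property being relied on (the body follows the docstring, not necessarily the exact Synapse source):

```python
# Sketch of a self-slicing chunker equivalent in spirit to chunk_seq above.
from typing import Iterator, Protocol, Sized, TypeVar

S = TypeVar("S", bound="SelfSlice")


class SelfSlice(Sized, Protocol):
    """Anything whose slices have the same type as the original (bytes, str, list, ...)."""

    def __getitem__(self: S, i: slice) -> S:
        ...


def chunk_seq(iseq: S, maxlen: int) -> Iterator[S]:
    # Yield maxlen-sized slices; the last one may be shorter, empty input yields nothing.
    return (iseq[i : i + maxlen] for i in range(0, len(iseq), maxlen))


json_bytes = b'{"a": 1, "b": [2, 3, 4]}'
chunks = list(chunk_seq(json_bytes, 8))
assert all(isinstance(c, bytes) for c in chunks)  # bytes in, bytes out
assert b"".join(chunks) == json_bytes
```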
```diff
@@ -45,7 +45,9 @@ class AdditionalResourceTests(HomeserverTestCase):
         handler = _AsyncTestCustomEndpoint({}, None).handle_request
         resource = AdditionalResource(self.hs, handler)

-        channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
+        channel = make_request(
+            self.reactor, FakeSite(resource, self.reactor), "GET", "/"
+        )

         self.assertEqual(channel.code, 200)
         self.assertEqual(channel.json_body, {"some_key": "some_value_async"})
@@ -54,7 +56,9 @@ class AdditionalResourceTests(HomeserverTestCase):
         handler = _SyncTestCustomEndpoint({}, None).handle_request
         resource = AdditionalResource(self.hs, handler)

-        channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
+        channel = make_request(
+            self.reactor, FakeSite(resource, self.reactor), "GET", "/"
+        )

         self.assertEqual(channel.code, 200)
         self.assertEqual(channel.json_body, {"some_key": "some_value_sync"})
@@ -152,7 +152,8 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
         site = Mock(spec=["site_tag", "server_version_string", "getResourceFor"])
         site.site_tag = "test-site"
         site.server_version_string = "Server v1"
-        request = SynapseRequest(FakeChannel(site, None))
+        site.reactor = Mock()
+        request = SynapseRequest(FakeChannel(site, None), site)
         # Call requestReceived to finish instantiating the object.
         request.content = BytesIO()
         # Partially skip some of the internal processing of SynapseRequest.
@@ -68,7 +68,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
         resource = hs.get_media_repository_resource().children[b"download"]
         channel = make_request(
             self.reactor,
-            FakeSite(resource),
+            FakeSite(resource, self.reactor),
             "GET",
             f"/{target}/{media_id}",
             shorthand=False,
@@ -201,7 +201,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
         """Ensure a piece of media is quarantined when trying to access it."""
         channel = make_request(
             self.reactor,
-            FakeSite(self.download_resource),
+            FakeSite(self.download_resource, self.reactor),
             "GET",
             server_and_media_id,
             shorthand=False,
@@ -271,7 +271,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
         # Attempt to access the media
         channel = make_request(
             self.reactor,
-            FakeSite(self.download_resource),
+            FakeSite(self.download_resource, self.reactor),
             "GET",
             server_name_and_media_id,
             shorthand=False,
@@ -458,7 +458,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
         # Attempt to access each piece of media
         channel = make_request(
             self.reactor,
-            FakeSite(self.download_resource),
+            FakeSite(self.download_resource, self.reactor),
             "GET",
             server_and_media_id_2,
             shorthand=False,
@@ -125,7 +125,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
         # Attempt to access media
         channel = make_request(
             self.reactor,
-            FakeSite(download_resource),
+            FakeSite(download_resource, self.reactor),
             "GET",
             server_and_media_id,
             shorthand=False,
@@ -164,7 +164,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
         # Attempt to access media
         channel = make_request(
             self.reactor,
-            FakeSite(download_resource),
+            FakeSite(download_resource, self.reactor),
             "GET",
             server_and_media_id,
             shorthand=False,
@@ -525,7 +525,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):

         channel = make_request(
             self.reactor,
-            FakeSite(download_resource),
+            FakeSite(download_resource, self.reactor),
             "GET",
             server_and_media_id,
             shorthand=False,
@@ -2973,7 +2973,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
         # Try to access a media and to create `last_access_ts`
         channel = make_request(
             self.reactor,
-            FakeSite(download_resource),
+            FakeSite(download_resource, self.reactor),
             "GET",
             server_and_media_id,
             shorthand=False,
@@ -312,7 +312,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         # Load the password reset confirmation page
         channel = make_request(
             self.reactor,
-            FakeSite(self.submit_token_resource),
+            FakeSite(self.submit_token_resource, self.reactor),
             "GET",
             path,
             shorthand=False,
@@ -326,7 +326,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         # Confirm the password reset
         channel = make_request(
             self.reactor,
-            FakeSite(self.submit_token_resource),
+            FakeSite(self.submit_token_resource, self.reactor),
             "POST",
             path,
             content=b"",
@@ -61,7 +61,11 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
         """You can observe the terms form without specifying a user"""
         resource = consent_resource.ConsentResource(self.hs)
         channel = make_request(
-            self.reactor, FakeSite(resource), "GET", "/consent?v=1", shorthand=False
+            self.reactor,
+            FakeSite(resource, self.reactor),
+            "GET",
+            "/consent?v=1",
+            shorthand=False,
         )
         self.assertEqual(channel.code, 200)

@@ -83,7 +87,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
         )
         channel = make_request(
             self.reactor,
-            FakeSite(resource),
+            FakeSite(resource, self.reactor),
             "GET",
             consent_uri,
             access_token=access_token,
@@ -98,7 +102,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
         # POST to the consent page, saying we've agreed
         channel = make_request(
             self.reactor,
-            FakeSite(resource),
+            FakeSite(resource, self.reactor),
             "POST",
             consent_uri + "&v=" + version,
             access_token=access_token,
@@ -110,7 +114,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase):
         # changed
         channel = make_request(
             self.reactor,
-            FakeSite(resource),
+            FakeSite(resource, self.reactor),
             "GET",
             consent_uri,
             access_token=access_token,
@@ -372,7 +372,7 @@ class RestHelper:
         path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
         channel = make_request(
             self.hs.get_reactor(),
-            FakeSite(resource),
+            FakeSite(resource, self.hs.get_reactor()),
             "POST",
             path,
             content=image_data,
@@ -84,7 +84,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
         Checks that the response is a 200 and returns the decoded json body.
         """
         channel = FakeChannel(self.site, self.reactor)
-        req = SynapseRequest(channel)
+        req = SynapseRequest(channel, self.site)
         req.content = BytesIO(b"")
         req.requestReceived(
             b"GET",
@@ -183,7 +183,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
         )

         channel = FakeChannel(self.site, self.reactor)
-        req = SynapseRequest(channel)
+        req = SynapseRequest(channel, self.site)
         req.content = BytesIO(encode_canonical_json(data))

         req.requestReceived(
@@ -252,7 +252,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):

         channel = make_request(
             self.reactor,
-            FakeSite(self.download_resource),
+            FakeSite(self.download_resource, self.reactor),
             "GET",
             self.media_id,
             shorthand=False,
@@ -384,7 +384,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         params = "?width=32&height=32&method=scale"
         channel = make_request(
             self.reactor,
-            FakeSite(self.thumbnail_resource),
+            FakeSite(self.thumbnail_resource, self.reactor),
             "GET",
             self.media_id + params,
             shorthand=False,
@@ -413,7 +413,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):

         channel = make_request(
             self.reactor,
-            FakeSite(self.thumbnail_resource),
+            FakeSite(self.thumbnail_resource, self.reactor),
             "GET",
             self.media_id + params,
             shorthand=False,
@@ -433,7 +433,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
         params = "?width=32&height=32&method=" + method
         channel = make_request(
             self.reactor,
-            FakeSite(self.thumbnail_resource),
+            FakeSite(self.thumbnail_resource, self.reactor),
             "GET",
             self.media_id + params,
             shorthand=False,
```
```diff
@@ -19,6 +19,7 @@ from twisted.internet.interfaces import (
     IPullProducer,
     IPushProducer,
     IReactorPluggableNameResolver,
+    IReactorTime,
     IResolverSimple,
     ITransport,
 )
@@ -181,13 +182,14 @@ class FakeSite:
     site_tag = "test"
     access_logger = logging.getLogger("synapse.access.http.fake")

-    def __init__(self, resource: IResource):
+    def __init__(self, resource: IResource, reactor: IReactorTime):
        """

        Args:
            resource: the resource to be used for rendering all requests
        """
         self._resource = resource
+        self.reactor = reactor

     def getResourceFor(self, request):
         return self._resource
@@ -268,7 +270,7 @@ def make_request(

     channel = FakeChannel(site, reactor, ip=client_ip)

-    req = request(channel)
+    req = request(channel, site)
     req.content = BytesIO(content)
     # Twisted expects to be at the end of the content when parsing the request.
     req.content.seek(SEEK_END)
```
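These two hunks explain most of the churn in the test files above: `SynapseRequest` now takes the site explicitly and reads `site.reactor` from it (so the JSON encode can be deferred to that reactor's thread pool), which means every `FakeSite` has to be built with a reactor. A stripped-down stand-in showing the attribute being relied on; these are not the real test helpers:

```python
# Minimal stand-in for the FakeSite/SynapseRequest relationship exercised in
# the tests above; class names here are illustrative, not Synapse's.
from twisted.internet.testing import MemoryReactorClock


class TinyFakeSite:
    site_tag = "test"

    def __init__(self, resource: object, reactor: MemoryReactorClock):
        self._resource = resource
        self.reactor = reactor  # the request reads this for defer_to_thread


class TinyRequest:
    def __init__(self, channel: object, site: TinyFakeSite):
        # Mirrors the new SynapseRequest.__init__(channel, site, ...) shape:
        # the reactor comes from the site instead of from channel.site.
        self.synapse_site = site
        self.reactor = site.reactor


reactor = MemoryReactorClock()
site = TinyFakeSite(resource=object(), reactor=reactor)
request = TinyRequest(channel=object(), site=site)
assert request.reactor is reactor
```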
```diff
@@ -65,7 +65,10 @@ class JsonResourceTests(unittest.TestCase):
         )

         make_request(
-            self.reactor, FakeSite(res), b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83"
+            self.reactor,
+            FakeSite(res, self.reactor),
+            b"GET",
+            b"/_matrix/foo/%E2%98%83?a=%E2%98%83",
         )

         self.assertEqual(got_kwargs, {"room_id": "\N{SNOWMAN}"})
@@ -84,7 +87,9 @@ class JsonResourceTests(unittest.TestCase):
             "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
         )

-        channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
+        channel = make_request(
+            self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo"
+        )

         self.assertEqual(channel.result["code"], b"500")

@@ -100,7 +105,7 @@ class JsonResourceTests(unittest.TestCase):
         def _callback(request, **kwargs):
             d = Deferred()
             d.addCallback(_throw)
-            self.reactor.callLater(1, d.callback, True)
+            self.reactor.callLater(0.5, d.callback, True)
             return make_deferred_yieldable(d)

         res = JsonResource(self.homeserver)
@@ -108,7 +113,9 @@ class JsonResourceTests(unittest.TestCase):
             "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
         )

-        channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
+        channel = make_request(
+            self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo"
+        )

         self.assertEqual(channel.result["code"], b"500")

@@ -126,7 +133,9 @@ class JsonResourceTests(unittest.TestCase):
             "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
         )

-        channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo")
+        channel = make_request(
+            self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo"
+        )

         self.assertEqual(channel.result["code"], b"403")
         self.assertEqual(channel.json_body["error"], "Forbidden!!one!")
@@ -148,7 +157,9 @@ class JsonResourceTests(unittest.TestCase):
             "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet"
         )

-        channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foobar")
+        channel = make_request(
+            self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foobar"
+        )

         self.assertEqual(channel.result["code"], b"400")
         self.assertEqual(channel.json_body["error"], "Unrecognized request")
@@ -173,7 +184,9 @@ class JsonResourceTests(unittest.TestCase):
         )

         # The path was registered as GET, but this is a HEAD request.
-        channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/_matrix/foo")
+        channel = make_request(
+            self.reactor, FakeSite(res, self.reactor), b"HEAD", b"/_matrix/foo"
+        )

         self.assertEqual(channel.result["code"], b"200")
         self.assertNotIn("body", channel.result)
@@ -280,7 +293,9 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
         res = WrapHtmlRequestHandlerTests.TestResource()
         res.callback = callback

-        channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
+        channel = make_request(
+            self.reactor, FakeSite(res, self.reactor), b"GET", b"/path"
+        )

         self.assertEqual(channel.result["code"], b"200")
         body = channel.result["body"]
@@ -298,7 +313,9 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
         res = WrapHtmlRequestHandlerTests.TestResource()
         res.callback = callback

-        channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
+        channel = make_request(
+            self.reactor, FakeSite(res, self.reactor), b"GET", b"/path"
+        )

         self.assertEqual(channel.result["code"], b"301")
         headers = channel.result["headers"]
@@ -319,7 +336,9 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
         res = WrapHtmlRequestHandlerTests.TestResource()
         res.callback = callback

-        channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path")
+        channel = make_request(
+            self.reactor, FakeSite(res, self.reactor), b"GET", b"/path"
+        )

         self.assertEqual(channel.result["code"], b"304")
         headers = channel.result["headers"]
@@ -338,7 +357,9 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
         res = WrapHtmlRequestHandlerTests.TestResource()
         res.callback = callback

-        channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/path")
+        channel = make_request(
+            self.reactor, FakeSite(res, self.reactor), b"HEAD", b"/path"
+        )

         self.assertEqual(channel.result["code"], b"200")
         self.assertNotIn("body", channel.result)
```