Compare commits

...

9 Commits

Author SHA1 Message Date
David Robertson
c929b8a073 Changelog 2022-08-21 23:30:39 +01:00
David Robertson
9b6764b2ef A bit more paramspec 2022-08-21 23:28:48 +01:00
David Robertson
dd70d11373 Remove unused config.ldap_enabled 2022-08-21 23:17:08 +01:00
David Robertson
5126d867b1 WIP: annotate setup_test_homeserver 2022-08-21 23:16:51 +01:00
David Robertson
9d4da69ffd Annotate FakeTransport 2022-08-21 23:02:58 +01:00
David Robertson
c9e80bc772 Annotate ThreadPool 2022-08-21 22:35:14 +01:00
David Robertson
48ae00e5bd Annotate ThreadedMemoryReactorClock 2022-08-21 22:27:04 +01:00
David Robertson
db1c5ffce9 annotate getResourceFor 2022-08-21 22:12:24 +01:00
David Robertson
895c09b6e4 annotate writeHeaders 2022-08-21 22:12:11 +01:00
2 changed files with 125 additions and 81 deletions

1
changelog.d/13578.misc Normal file
View File

@@ -0,0 +1 @@
Improve the type annotations in `tests.server`.

View File

@@ -22,20 +22,24 @@ import warnings
from collections import deque
from io import SEEK_END, BytesIO
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
MutableMapping,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
cast,
)
from unittest.mock import Mock
import attr
from typing_extensions import Deque
from typing_extensions import Deque, ParamSpec
from zope.interface import implementer
from twisted.internet import address, threads, udp
@@ -44,23 +48,28 @@ from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import (
IAddress,
IConnector,
IConsumer,
IHostnameResolver,
IProtocol,
IPullProducer,
IPushProducer,
IReactorFromThreads,
IReactorPluggableNameResolver,
IReactorTime,
IResolverSimple,
ITransport,
)
from twisted.internet.protocol import ClientFactory, DatagramProtocol
from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
from twisted.web.http_headers import Headers
from twisted.web.iweb import IRequest
from twisted.web.resource import IResource
from twisted.web.server import Request, Site
from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
from synapse.events.presence_router import load_legacy_presence_router
from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
@@ -70,7 +79,7 @@ from synapse.logging.context import ContextResourceUsage
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.engines import PostgresEngine, create_engine
from synapse.types import JsonDict
from synapse.types import ISynapseReactor, JsonDict
from synapse.util import Clock
from tests.utils import (
@@ -90,6 +99,8 @@ logger = logging.getLogger(__name__)
# the type of thing that can be passed into `make_request` in the headers list
CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]]
P = ParamSpec("P")
R = TypeVar("R")
class TimedOutException(Exception):
@@ -165,7 +176,9 @@ class FakeChannel:
h.addRawHeader(*i)
return h
def writeHeaders(self, version, code, reason, headers):
def writeHeaders(
self, version: bytes, code: bytes, reason: bytes, headers: Headers
) -> None:
self.result["version"] = version
self.result["code"] = code
self.result["reason"] = reason
@@ -275,7 +288,7 @@ class FakeSite:
self._resource = resource
self.reactor = reactor
def getResourceFor(self, request):
def getResourceFor(self, request: IRequest) -> IResource:
return self._resource
@@ -389,17 +402,17 @@ def make_request(
return channel
@implementer(IReactorPluggableNameResolver)
@implementer(IReactorPluggableNameResolver, IReactorFromThreads)
class ThreadedMemoryReactorClock(MemoryReactorClock):
"""
A MemoryReactorClock that supports callFromThread.
"""
def __init__(self):
def __init__(self) -> None:
self.threadpool = ThreadPool(self)
self._tcp_callbacks: Dict[Tuple[str, int], Callable] = {}
self._udp = []
self._tcp_callbacks: Dict[Tuple[str, int], Callable[[], None]] = {}
self._udp: List[udp.Port] = []
self.lookups: Dict[str, str] = {}
self._thread_callbacks: Deque[Callable[[], None]] = deque()
@@ -407,7 +420,9 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
@implementer(IResolverSimple)
class FakeResolver:
def getHostByName(self, name, timeout=None):
def getHostByName(
self, name: str, timeout: Sequence[int] = ()
) -> "Deferred[str]":
if name not in lookups:
return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
return succeed(lookups[name])
@@ -418,13 +433,22 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver:
raise NotImplementedError()
def listenUDP(self, port, protocol, interface="", maxPacketSize=8196):
def listenUDP(
self,
port: int,
protocol: DatagramProtocol,
interface: str = "",
maxPacketSize: int = 8196,
) -> udp.Port:
p = udp.Port(port, protocol, interface, maxPacketSize, self)
p.startListening()
self._udp.append(p)
return p
def callFromThread(self, callback, *args, **kwargs):
# Type-ignore: IReactorFromThreads doesn't use paramspec here.
def callFromThread( # type: ignore[override]
self, callback: Callable[P, Any], *args: P.args, **kwargs: P.kwargs
) -> None:
"""
Make the callback fire in the next reactor iteration.
"""
@@ -433,10 +457,12 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
# separate queue.
self._thread_callbacks.append(cb)
def getThreadPool(self):
def getThreadPool(self) -> "ThreadPool":
return self.threadpool
def add_tcp_client_callback(self, host: str, port: int, callback: Callable):
def add_tcp_client_callback(
self, host: str, port: int, callback: Callable[[], None]
) -> None:
"""Add a callback that will be invoked when we receive a connection
attempt to the given IP/port using `connectTCP`.
@@ -445,7 +471,14 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
"""
self._tcp_callbacks[(host, port)] = callback
def connectTCP(self, host: str, port: int, factory, timeout=30, bindAddress=None):
def connectTCP(
self,
host: str,
port: int,
factory: ClientFactory,
timeout: float = 30,
bindAddress: Optional[Tuple[str, int]] = None,
) -> IConnector:
"""Fake L{IReactorTCP.connectTCP}."""
conn = super().connectTCP(
@@ -458,7 +491,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
return conn
def advance(self, amount):
def advance(self, amount: float) -> None:
# first advance our reactor's time, and run any "callLater" callbacks that
# makes ready
super().advance(amount)
@@ -485,26 +518,32 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
class ThreadPool:
"""
Threadless thread pool.
Threadless thread pool. A stand-in for twisted.python.threadpool.ThreadPool.
"""
def __init__(self, reactor):
def __init__(self, reactor: IReactorTime):
self._reactor = reactor
def start(self):
def start(self) -> None:
pass
def stop(self):
def stop(self) -> None:
pass
def callInThreadWithCallback(self, onResult, function, *args, **kwargs):
def _(res):
def callInThreadWithCallback(
self,
onResult: Callable[[bool, object], Any],
function: Callable[P, Any],
*args: P.args,
**kwargs: P.kwargs,
) -> "Deferred[bool]":
def _(res: object) -> None:
if isinstance(res, Failure):
onResult(False, res)
else:
onResult(True, res)
d = Deferred()
d: "Deferred[bool]" = Deferred()
d.addCallback(lambda x: function(*args, **kwargs))
d.addBoth(_)
self._reactor.callLater(0, d.callback, True)
@@ -521,7 +560,7 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None:
for database in server.get_datastores().databases:
pool = database._db_pool
def runWithConnection(func, *args, **kwargs):
def runWithConnection(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> "Deferred[R]":
return threads.deferToThreadPool(
pool._reactor,
pool.threadpool,
@@ -531,7 +570,7 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None:
**kwargs,
)
def runInteraction(interaction, *args, **kwargs):
def runInteraction(interaction: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> "Deferred[R]":
return threads.deferToThreadPool(
pool._reactor,
pool.threadpool,
@@ -559,7 +598,7 @@ def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]:
@implementer(ITransport)
@attr.s(cmp=False)
@attr.s(cmp=False, auto_attribs=True)
class FakeTransport:
"""
A twisted.internet.interfaces.ITransport implementation which sends all its data
@@ -574,35 +613,29 @@ class FakeTransport:
If you want bidirectional communication, you'll need two instances.
"""
other = attr.ib()
"""The Protocol object which will receive any data written to this transport.
other: IProtocol
"""The Protocol object which will receive any data written to this transport."""
:type: twisted.internet.interfaces.IProtocol
"""
_reactor: IReactorTime
"""Test reactor """
_reactor = attr.ib()
"""Test reactor
:type: twisted.internet.interfaces.IReactorTime
"""
_protocol = attr.ib(default=None)
_protocol: Optional[IProtocol] = None
"""The Protocol which is producing data for this transport. Optional, but if set
will get called back for connectionLost() notifications etc.
"""
_peer_address: Optional[IAddress] = attr.ib(default=None)
_peer_address: Optional[IAddress] = None
"""The value to be returned by getPeer"""
_host_address: Optional[IAddress] = attr.ib(default=None)
_host_address: Optional[IAddress] = None
"""The value to be returned by getHost"""
disconnecting = False
disconnected = False
connected = True
buffer = attr.ib(default=b"")
producer = attr.ib(default=None)
autoflush = attr.ib(default=True)
disconnecting: bool = False
disconnected: bool = False
connected: bool = True
buffer: bytes = b""
producer: Optional[IPushProducer] = None
autoflush: bool = True
def getPeer(self) -> Optional[IAddress]:
return self._peer_address
@@ -610,7 +643,7 @@ class FakeTransport:
def getHost(self) -> Optional[IAddress]:
return self._host_address
def loseConnection(self, reason=None):
def loseConnection(self, reason: Optional[Failure] = None) -> None:
if not self.disconnecting:
logger.info("FakeTransport: loseConnection(%s)", reason)
self.disconnecting = True
@@ -626,7 +659,7 @@ class FakeTransport:
self.connected = False
self.disconnected = True
def abortConnection(self):
def abortConnection(self) -> None:
logger.info("FakeTransport: abortConnection()")
if not self.disconnecting:
@@ -636,28 +669,28 @@ class FakeTransport:
self.disconnected = True
def pauseProducing(self):
def pauseProducing(self) -> None:
if not self.producer:
return
self.producer.pauseProducing()
def resumeProducing(self):
def resumeProducing(self) -> None:
if not self.producer:
return
self.producer.resumeProducing()
def unregisterProducer(self):
def unregisterProducer(self) -> None:
if not self.producer:
return
self.producer = None
def registerProducer(self, producer, streaming):
def registerProducer(self, producer: IPushProducer, streaming: bool) -> None:
self.producer = producer
self.producerStreaming = streaming
def _produce():
def _produce() -> None:
if not self.producer:
# we've been unregistered
return
@@ -669,7 +702,7 @@ class FakeTransport:
if not streaming:
self._reactor.callLater(0.0, _produce)
def write(self, byt):
def write(self, byt: bytes) -> None:
if self.disconnecting:
raise Exception("Writing to disconnecting FakeTransport")
@@ -681,11 +714,11 @@ class FakeTransport:
if self.autoflush:
self._reactor.callLater(0.0, self.flush)
def writeSequence(self, seq):
def writeSequence(self, seq: Iterable[bytes]) -> None:
for x in seq:
self.write(x)
def flush(self, maxbytes=None):
def flush(self, maxbytes: Optional[int] = None) -> None:
if not self.buffer:
# nothing to do. Don't write empty buffers: it upsets the
# TLSMemoryBIOProtocol
@@ -739,14 +772,17 @@ class TestHomeServer(HomeServer):
DATASTORE_CLASS = DataStore
HS = TypeVar("HS", bound=HomeServer)
def setup_test_homeserver(
cleanup_func,
name="test",
config=None,
reactor=None,
homeserver_to_use: Type[HomeServer] = TestHomeServer,
**kwargs,
):
cleanup_func: Callable[[Callable[[], None]], Any],
name: str = "test",
config: Union[HomeServerConfig, None] = None,
reactor: Optional[ISynapseReactor] = None,
homeserver_to_use: Type[HS] = TestHomeServer,
**kwargs: object,
) -> HS:
"""
Setup a homeserver suitable for running tests against. Keyword arguments
are passed to the Homeserver constructor.
@@ -761,13 +797,12 @@ def setup_test_homeserver(
HomeserverTestCase.
"""
if reactor is None:
from twisted.internet import reactor
from twisted.internet import reactor # type: ignore[no-redef]
if config is None:
config = default_config(name, parse=True)
config.caches.resize_all_caches()
config.ldap_enabled = False
if "clock" not in kwargs:
kwargs["clock"] = MockClock()
@@ -810,20 +845,25 @@ def setup_test_homeserver(
if "db_txn_limit" in kwargs:
database_config["txn_limit"] = kwargs["db_txn_limit"]
database = DatabaseConnectionConfig("master", database_config)
config.database.databases = [database]
database_conn_config = DatabaseConnectionConfig("master", database_config)
config.database.databases = [database_conn_config]
db_engine = create_engine(database.config)
db_engine = create_engine(database_conn_config.config)
# Create the database before we actually try and connect to it, based off
# the template database we generate in setupdb()
if isinstance(db_engine, PostgresEngine):
db_conn = db_engine.module.connect(
database=POSTGRES_BASE_DB,
user=POSTGRES_USER,
host=POSTGRES_HOST,
port=POSTGRES_PORT,
password=POSTGRES_PASSWORD,
import psycopg2
db_conn = cast(
psycopg2.connection,
db_engine.module.connect(
database=POSTGRES_BASE_DB,
user=POSTGRES_USER,
host=POSTGRES_HOST,
port=POSTGRES_PORT,
password=POSTGRES_PASSWORD,
),
)
db_conn.autocommit = True
cur = db_conn.cursor()
@@ -856,7 +896,7 @@ def setup_test_homeserver(
database = hs.get_datastores().databases[0]
# We need to do cleanup on PostgreSQL
def cleanup():
def cleanup() -> None:
import psycopg2
# Close all the db pools
@@ -865,12 +905,15 @@ def setup_test_homeserver(
dropped = False
# Drop the test database
db_conn = db_engine.module.connect(
database=POSTGRES_BASE_DB,
user=POSTGRES_USER,
host=POSTGRES_HOST,
port=POSTGRES_PORT,
password=POSTGRES_PASSWORD,
db_conn = cast(
psycopg2.connection,
db_engine.module.connect(
database=POSTGRES_BASE_DB,
user=POSTGRES_USER,
host=POSTGRES_HOST,
port=POSTGRES_PORT,
password=POSTGRES_PASSWORD,
),
)
db_conn.autocommit = True
cur = db_conn.cursor()
@@ -904,12 +947,12 @@ def setup_test_homeserver(
# Need to let the HS build an auth handler and then mess with it
# because AuthHandler's constructor requires the HS, so we can't make one
# beforehand and pass it in to the HS's constructor (chicken / egg)
async def hash(p):
async def hash(p: str) -> str:
return hashlib.md5(p.encode("utf8")).hexdigest()
hs.get_auth_handler().hash = hash
async def validate_hash(p, h):
async def validate_hash(p: str, h: str) -> bool:
return hashlib.md5(p.encode("utf8")).hexdigest() == h
hs.get_auth_handler().validate_hash = validate_hash