mirror of
https://github.com/element-hq/synapse.git
synced 2025-12-11 01:40:27 +00:00
Compare commits
14 Commits
v1.111.0
...
anoa/typeh
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
70dc5da72d | ||
|
|
e532a8f77d | ||
|
|
33af486813 | ||
|
|
3dd34bba19 | ||
|
|
dfffd5ad06 | ||
|
|
88c6999840 | ||
|
|
89add577d7 | ||
|
|
970555dc9e | ||
|
|
ed833e3bc9 | ||
|
|
510b30f7e6 | ||
|
|
ca2cfa58d8 | ||
|
|
76f5ce9537 | ||
|
|
69bfd46158 | ||
|
|
d8448a0414 |
1
changelog.d/11349.misc
Normal file
1
changelog.d/11349.misc
Normal file
@@ -0,0 +1 @@
|
||||
Fix type hints to allow `tests.utils` to pass `mypy`.
|
||||
1
mypy.ini
1
mypy.ini
@@ -142,7 +142,6 @@ exclude = (?x)
|
||||
|tests/util/test_lrucache.py
|
||||
|tests/util/test_rwlock.py
|
||||
|tests/util/test_wheel_timer.py
|
||||
|tests/utils.py
|
||||
)$
|
||||
|
||||
[mypy-synapse.api.*]
|
||||
|
||||
1
setup.py
1
setup.py
@@ -108,6 +108,7 @@ CONDITIONAL_REQUIREMENTS["mypy"] = [
|
||||
"types-bleach>=4.1.0",
|
||||
"types-jsonschema>=3.2.0",
|
||||
"types-Pillow>=8.3.4",
|
||||
"types-psycopg2>=2.9.1",
|
||||
"types-pyOpenSSL>=20.0.7",
|
||||
"types-PyYAML>=5.4.10",
|
||||
"types-requests>=2.26.0",
|
||||
|
||||
@@ -224,7 +224,10 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||
# This is overridden in derived application classes
|
||||
# (such as synapse.app.homeserver.SynapseHomeServer) and gives the class to be
|
||||
# instantiated during setup() for future return by get_datastore()
|
||||
DATASTORE_CLASS = abc.abstractproperty()
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def DATASTORE_CLASS(self):
|
||||
pass
|
||||
|
||||
tls_server_context_factory: Optional[IOpenSSLContextFactory]
|
||||
|
||||
|
||||
@@ -300,7 +300,9 @@ class LoggingTransaction:
|
||||
from psycopg2.extras import execute_values # type: ignore
|
||||
|
||||
return self._do_execute(
|
||||
lambda *x: execute_values(self.txn, *x, fetch=fetch), sql, *args
|
||||
lambda sql, argslist: execute_values(self.txn, sql, argslist, fetch=fetch),
|
||||
sql,
|
||||
*args,
|
||||
)
|
||||
|
||||
def execute(self, sql: str, *args: Any) -> None:
|
||||
|
||||
@@ -52,7 +52,7 @@ class Databases(Generic[DataStoreT]):
|
||||
# Note we pass in the main store class here as workers use a different main
|
||||
# store.
|
||||
|
||||
self.databases = []
|
||||
self.databases: List[DatabasePool] = []
|
||||
main: Optional[DataStoreT] = None
|
||||
state: Optional[StateGroupDataStore] = None
|
||||
persist_events: Optional[PersistEventsStore] = None
|
||||
|
||||
@@ -67,7 +67,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
||||
|
||||
self.mock_resolver = Mock()
|
||||
|
||||
config_dict = default_config("test", parse=False)
|
||||
config_dict = default_config("test")
|
||||
config_dict["federation_custom_ca_list"] = [get_test_ca_cert_file()]
|
||||
|
||||
self._config = config = HomeServerConfig()
|
||||
@@ -957,7 +957,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
||||
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
|
||||
self.reactor.lookups["testserv"] = "1.2.3.4"
|
||||
|
||||
config = default_config("test", parse=True)
|
||||
config_dict = default_config("test")
|
||||
config = HomeServerConfig()
|
||||
config.parse_config_dict(config_dict)
|
||||
|
||||
# Build a new agent and WellKnownResolver with a different tls factory
|
||||
tls_factory = FederationPolicyForHTTPS(config)
|
||||
|
||||
@@ -18,9 +18,11 @@ from unittest.mock import Mock
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.engines import create_engine
|
||||
from synapse.storage.types import Connection
|
||||
|
||||
from tests import unittest
|
||||
from tests.utils import TestHomeServer, default_config
|
||||
@@ -47,7 +49,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
||||
|
||||
self.db_pool.runWithConnection = runWithConnection
|
||||
|
||||
config = default_config(name="test", parse=True)
|
||||
config_dict = default_config(name="test")
|
||||
config = HomeServerConfig()
|
||||
config.parse_config_dict(config_dict)
|
||||
|
||||
hs = TestHomeServer("test", config=config)
|
||||
|
||||
sqlite_config = {"name": "sqlite3"}
|
||||
@@ -59,7 +64,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
||||
db = DatabasePool(Mock(), Mock(config=sqlite_config), fake_engine)
|
||||
db._db_pool = self.db_pool
|
||||
|
||||
self.datastore = SQLBaseStore(db, None, hs)
|
||||
self.datastore = SQLBaseStore(db, Mock(spec=Connection), hs)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_insert_1col(self):
|
||||
|
||||
@@ -19,6 +19,7 @@ from twisted.internet import defer
|
||||
from synapse.api.auth import Auth
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.room_versions import RoomVersions
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.events import make_event_from_dict
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.state import StateHandler, StateResolutionHandler
|
||||
@@ -172,7 +173,11 @@ class StateTestCase(unittest.TestCase):
|
||||
"hostname",
|
||||
]
|
||||
)
|
||||
hs.config = default_config("tesths", True)
|
||||
|
||||
config_dict = default_config("tesths")
|
||||
hs.config = HomeServerConfig()
|
||||
hs.config.parse_config_dict(config_dict)
|
||||
|
||||
hs.get_datastore.return_value = self.store
|
||||
hs.get_state_handler.return_value = None
|
||||
hs.get_clock.return_value = MockClock()
|
||||
|
||||
240
tests/utils.py
240
tests/utils.py
@@ -1,5 +1,4 @@
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2018-2019 New Vector Ltd
|
||||
# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -19,14 +18,11 @@ import os
|
||||
import time
|
||||
import uuid
|
||||
import warnings
|
||||
from typing import Type
|
||||
from unittest.mock import Mock, patch
|
||||
from urllib import parse as urlparse
|
||||
|
||||
from twisted.internet import defer
|
||||
from types import ModuleType
|
||||
from typing import Any, Callable, Dict, Optional, Type
|
||||
from unittest.mock import Mock
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.errors import CodeMessageException, cs_error
|
||||
from synapse.api.room_versions import RoomVersions
|
||||
from synapse.config.database import DatabaseConnectionConfig
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
@@ -54,13 +50,47 @@ POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),)
|
||||
POSTGRES_DBNAME_FOR_INITIAL_CREATE = "postgres"
|
||||
|
||||
|
||||
def setupdb():
|
||||
# If we're using PostgreSQL, set up the db once
|
||||
if USE_POSTGRES_FOR_TESTS:
|
||||
# create a PostgresEngine
|
||||
db_engine = create_engine({"name": "psycopg2", "args": {}})
|
||||
def setupdb() -> None:
|
||||
"""
|
||||
Set up a temporary database to run tests in. Only applicable to postgres,
|
||||
which uses a persistent database server rather than a database in memory.
|
||||
"""
|
||||
# Setting up the database is only required when using postgres
|
||||
if not USE_POSTGRES_FOR_TESTS:
|
||||
return
|
||||
|
||||
# connect to postgres to create the base database.
|
||||
# create a PostgresEngine
|
||||
db_engine = create_engine({"name": "psycopg2", "args": {}})
|
||||
|
||||
# connect to postgres to create the base database.
|
||||
db_conn = db_engine.module.connect(
|
||||
user=POSTGRES_USER,
|
||||
host=POSTGRES_HOST,
|
||||
password=POSTGRES_PASSWORD,
|
||||
dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
|
||||
)
|
||||
db_conn.autocommit = True
|
||||
cur = db_conn.cursor()
|
||||
cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
|
||||
cur.execute(
|
||||
"CREATE DATABASE %s ENCODING 'UTF8' LC_COLLATE='C' LC_CTYPE='C' "
|
||||
"template=template0;" % (POSTGRES_BASE_DB,)
|
||||
)
|
||||
cur.close()
|
||||
db_conn.close()
|
||||
|
||||
# Set up in the db
|
||||
db_conn = db_engine.module.connect(
|
||||
database=POSTGRES_BASE_DB,
|
||||
user=POSTGRES_USER,
|
||||
host=POSTGRES_HOST,
|
||||
password=POSTGRES_PASSWORD,
|
||||
)
|
||||
db_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests")
|
||||
prepare_database(db_conn, db_engine, None)
|
||||
db_conn.close()
|
||||
|
||||
def _cleanup():
|
||||
db_conn = db_engine.module.connect(
|
||||
user=POSTGRES_USER,
|
||||
host=POSTGRES_HOST,
|
||||
@@ -70,43 +100,21 @@ def setupdb():
|
||||
db_conn.autocommit = True
|
||||
cur = db_conn.cursor()
|
||||
cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
|
||||
cur.execute(
|
||||
"CREATE DATABASE %s ENCODING 'UTF8' LC_COLLATE='C' LC_CTYPE='C' "
|
||||
"template=template0;" % (POSTGRES_BASE_DB,)
|
||||
)
|
||||
cur.close()
|
||||
db_conn.close()
|
||||
|
||||
# Set up in the db
|
||||
db_conn = db_engine.module.connect(
|
||||
database=POSTGRES_BASE_DB,
|
||||
user=POSTGRES_USER,
|
||||
host=POSTGRES_HOST,
|
||||
password=POSTGRES_PASSWORD,
|
||||
)
|
||||
db_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests")
|
||||
prepare_database(db_conn, db_engine, None)
|
||||
db_conn.close()
|
||||
|
||||
def _cleanup():
|
||||
db_conn = db_engine.module.connect(
|
||||
user=POSTGRES_USER,
|
||||
host=POSTGRES_HOST,
|
||||
password=POSTGRES_PASSWORD,
|
||||
dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
|
||||
)
|
||||
db_conn.autocommit = True
|
||||
cur = db_conn.cursor()
|
||||
cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
|
||||
cur.close()
|
||||
db_conn.close()
|
||||
|
||||
atexit.register(_cleanup)
|
||||
atexit.register(_cleanup)
|
||||
|
||||
|
||||
def default_config(name, parse=False):
|
||||
def default_config(name: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Create a reasonable test config.
|
||||
|
||||
Args:
|
||||
name: The value of the 'server_name' option in the returned config.
|
||||
|
||||
Returns:
|
||||
A sensible, default homeserver config.
|
||||
"""
|
||||
config_dict = {
|
||||
"server_name": name,
|
||||
@@ -175,11 +183,6 @@ def default_config(name, parse=False):
|
||||
"listeners": [{"port": 0, "type": "http"}],
|
||||
}
|
||||
|
||||
if parse:
|
||||
config = HomeServerConfig()
|
||||
config.parse_config_dict(config_dict, "", "")
|
||||
return config
|
||||
|
||||
return config_dict
|
||||
|
||||
|
||||
@@ -188,10 +191,10 @@ class TestHomeServer(HomeServer):
|
||||
|
||||
|
||||
def setup_test_homeserver(
|
||||
cleanup_func,
|
||||
name="test",
|
||||
config=None,
|
||||
reactor=None,
|
||||
cleanup_func: Callable[[Callable], Any],
|
||||
name: str = "test",
|
||||
config: Optional[HomeServerConfig] = None,
|
||||
reactor: Optional[ModuleType] = None,
|
||||
homeserver_to_use: Type[HomeServer] = TestHomeServer,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -209,12 +212,15 @@ def setup_test_homeserver(
|
||||
HomeserverTestCase.
|
||||
"""
|
||||
if reactor is None:
|
||||
from twisted.internet import reactor
|
||||
from twisted.internet import reactor as _reactor
|
||||
|
||||
reactor = _reactor
|
||||
|
||||
if config is None:
|
||||
config = default_config(name, parse=True)
|
||||
config_dict = default_config(name)
|
||||
|
||||
config.ldap_enabled = False
|
||||
config = HomeServerConfig()
|
||||
config.parse_config_dict(config_dict)
|
||||
|
||||
if "clock" not in kwargs:
|
||||
kwargs["clock"] = MockClock()
|
||||
@@ -222,7 +228,7 @@ def setup_test_homeserver(
|
||||
if USE_POSTGRES_FOR_TESTS:
|
||||
test_db = "synapse_test_%s" % uuid.uuid4().hex
|
||||
|
||||
database_config = {
|
||||
database_config_dict = {
|
||||
"name": "psycopg2",
|
||||
"args": {
|
||||
"database": test_db,
|
||||
@@ -234,18 +240,18 @@ def setup_test_homeserver(
|
||||
},
|
||||
}
|
||||
else:
|
||||
database_config = {
|
||||
database_config_dict = {
|
||||
"name": "sqlite3",
|
||||
"args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
|
||||
}
|
||||
|
||||
if "db_txn_limit" in kwargs:
|
||||
database_config["txn_limit"] = kwargs["db_txn_limit"]
|
||||
database_config_dict["txn_limit"] = kwargs["db_txn_limit"]
|
||||
|
||||
database = DatabaseConnectionConfig("master", database_config)
|
||||
config.database.databases = [database]
|
||||
database_config = DatabaseConnectionConfig("master", database_config_dict)
|
||||
config.database.databases = [database_config]
|
||||
|
||||
db_engine = create_engine(database.config)
|
||||
db_engine = create_engine(database_config.config)
|
||||
|
||||
# Create the database before we actually try and connect to it, based off
|
||||
# the template database we generate in setupdb()
|
||||
@@ -278,7 +284,6 @@ def setup_test_homeserver(
|
||||
|
||||
# Mock TLS
|
||||
hs.tls_server_context_factory = Mock()
|
||||
hs.tls_client_options_factory = Mock()
|
||||
|
||||
hs.setup()
|
||||
if homeserver_to_use == TestHomeServer:
|
||||
@@ -338,12 +343,12 @@ def setup_test_homeserver(
|
||||
async def hash(p):
|
||||
return hashlib.md5(p.encode("utf8")).hexdigest()
|
||||
|
||||
hs.get_auth_handler().hash = hash
|
||||
hs.get_auth_handler().hash = hash # type: ignore
|
||||
|
||||
async def validate_hash(p, h):
|
||||
return hashlib.md5(p.encode("utf8")).hexdigest() == h
|
||||
|
||||
hs.get_auth_handler().validate_hash = validate_hash
|
||||
hs.get_auth_handler().validate_hash = validate_hash # type: ignore
|
||||
|
||||
return hs
|
||||
|
||||
@@ -357,111 +362,6 @@ def mock_getRawHeaders(headers=None):
|
||||
return getRawHeaders
|
||||
|
||||
|
||||
# This is a mock /resource/ not an entire server
|
||||
class MockHttpResource:
|
||||
def __init__(self, prefix=""):
|
||||
self.callbacks = [] # 3-tuple of method/pattern/function
|
||||
self.prefix = prefix
|
||||
|
||||
def trigger_get(self, path):
|
||||
return self.trigger(b"GET", path, None)
|
||||
|
||||
@patch("twisted.web.http.Request")
|
||||
@defer.inlineCallbacks
|
||||
def trigger(
|
||||
self, http_method, path, content, mock_request, federation_auth_origin=None
|
||||
):
|
||||
"""Fire an HTTP event.
|
||||
|
||||
Args:
|
||||
http_method : The HTTP method
|
||||
path : The HTTP path
|
||||
content : The HTTP body
|
||||
mock_request : Mocked request to pass to the event so it can get
|
||||
content.
|
||||
federation_auth_origin (bytes|None): domain to authenticate as, for federation
|
||||
Returns:
|
||||
A tuple of (code, response)
|
||||
Raises:
|
||||
KeyError If no event is found which will handle the path.
|
||||
"""
|
||||
path = self.prefix + path
|
||||
|
||||
# annoyingly we return a twisted http request which has chained calls
|
||||
# to get at the http content, hence mock it here.
|
||||
mock_content = Mock()
|
||||
config = {"read.return_value": content}
|
||||
mock_content.configure_mock(**config)
|
||||
mock_request.content = mock_content
|
||||
|
||||
mock_request.method = http_method.encode("ascii")
|
||||
mock_request.uri = path.encode("ascii")
|
||||
|
||||
mock_request.getClientIP.return_value = "-"
|
||||
|
||||
headers = {}
|
||||
if federation_auth_origin is not None:
|
||||
headers[b"Authorization"] = [
|
||||
b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,)
|
||||
]
|
||||
mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
|
||||
|
||||
# return the right path if the event requires it
|
||||
mock_request.path = path
|
||||
|
||||
# add in query params to the right place
|
||||
try:
|
||||
mock_request.args = urlparse.parse_qs(path.split("?")[1])
|
||||
mock_request.path = path.split("?")[0]
|
||||
path = mock_request.path
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if isinstance(path, bytes):
|
||||
path = path.decode("utf8")
|
||||
|
||||
for (method, pattern, func) in self.callbacks:
|
||||
if http_method != method:
|
||||
continue
|
||||
|
||||
matcher = pattern.match(path)
|
||||
if matcher:
|
||||
try:
|
||||
args = [urlparse.unquote(u) for u in matcher.groups()]
|
||||
|
||||
(code, response) = yield defer.ensureDeferred(
|
||||
func(mock_request, *args)
|
||||
)
|
||||
return code, response
|
||||
except CodeMessageException as e:
|
||||
return e.code, cs_error(e.msg, code=e.errcode)
|
||||
|
||||
raise KeyError("No event can handle %s" % path)
|
||||
|
||||
def register_paths(self, method, path_patterns, callback, servlet_name):
|
||||
for path_pattern in path_patterns:
|
||||
self.callbacks.append((method, path_pattern, callback))
|
||||
|
||||
|
||||
class MockKey:
|
||||
alg = "mock_alg"
|
||||
version = "mock_version"
|
||||
signature = b"\x9a\x87$"
|
||||
|
||||
@property
|
||||
def verify_key(self):
|
||||
return self
|
||||
|
||||
def sign(self, message):
|
||||
return self
|
||||
|
||||
def verify(self, message, sig):
|
||||
assert sig == b"\x9a\x87$"
|
||||
|
||||
def encode(self):
|
||||
return b"<fake_encoded_key>"
|
||||
|
||||
|
||||
class MockClock:
|
||||
now = 1000
|
||||
|
||||
|
||||
Reference in New Issue
Block a user