Compare commits

...

14 Commits

Author SHA1 Message Date
Andrew Morgan
70dc5da72d Simply calling execute_values in order to appease mypy
mypy was getting upset as it thought 'fetch' was being passed multiple times
to execute_values. This cannot happen, as fetch would never be part of 'args',
and thus seems like a bug in mypy.

This code was a bit strange anyhow. Why collect values with *x only to
expand them again immediately afterwards when calling execute_values?

The code is simpler now IMO, and appeases mypy.
2021-11-17 01:04:29 +00:00
Andrew Morgan
e532a8f77d Add types-psycopg2 package to optional dev dependencies 2021-11-17 00:47:56 +00:00
Andrew Morgan
33af486813 Changelog 2021-11-17 00:47:56 +00:00
Andrew Morgan
3dd34bba19 Remove some unused classes. 2021-11-17 00:47:56 +00:00
Andrew Morgan
dfffd5ad06 Refactor 'setupdb' to remove an indent-level, and help explain its purpose a bit more. 2021-11-17 00:47:56 +00:00
Andrew Morgan
88c6999840 Pass a fake Connection mock to SQLBaseStore, instead of None.
SQLBaseStore expects the connection to be non-Optional, so we can't just pass none.

Note that this parameter isn't even used by the function anyhow. I believe it's only there to satisfy the inherited class it's overriding.
2021-11-17 00:47:56 +00:00
Andrew Morgan
89add577d7 Ignore monkey-patching functions. mypy doesn't currently allow this.
See https://github.com/python/mypy/issues/2427 for details/complaints.
2021-11-17 00:47:56 +00:00
Andrew Morgan
970555dc9e Remove mock of non-existent property. 2021-11-17 00:47:56 +00:00
Andrew Morgan
ed833e3bc9 Again, don't re-use variable names.
Interestingly I noticed that the reactor argument is never actually set by any calling functions. Should we just remove it?
2021-11-17 00:47:56 +00:00
Andrew Morgan
510b30f7e6 Don't re-use variable names with differing types. 2021-11-17 00:47:56 +00:00
Andrew Morgan
ca2cfa58d8 Make default_config only return a dict representation
This does make mypy happy, and does reduce a bit of confusion, though it's
a shame we have to duplicate the parsing code around everywhere now.

Is there a better way to solve this?
2021-11-17 00:47:56 +00:00
Andrew Morgan
76f5ce9537 Add type annotation to Databases.databases
Awkwardly, this is a list of database configs, and should probably be renamed?
2021-11-17 00:47:56 +00:00
Andrew Morgan
69bfd46158 Denote HomeServer.DATASTORE_CLASS as an abstract property the Python 3 way
This seems to be required to make mypy happy about using it in inheriting classes
2021-11-17 00:47:56 +00:00
Andrew Morgan
d8448a0414 Remove tests/utils.py from mypy exclude list 2021-11-17 00:47:56 +00:00
10 changed files with 97 additions and 179 deletions

1
changelog.d/11349.misc Normal file
View File

@@ -0,0 +1 @@
Fix type hints to allow `tests.utils` to pass `mypy`.

View File

@@ -142,7 +142,6 @@ exclude = (?x)
|tests/util/test_lrucache.py
|tests/util/test_rwlock.py
|tests/util/test_wheel_timer.py
|tests/utils.py
)$
[mypy-synapse.api.*]

View File

@@ -108,6 +108,7 @@ CONDITIONAL_REQUIREMENTS["mypy"] = [
"types-bleach>=4.1.0",
"types-jsonschema>=3.2.0",
"types-Pillow>=8.3.4",
"types-psycopg2>=2.9.1",
"types-pyOpenSSL>=20.0.7",
"types-PyYAML>=5.4.10",
"types-requests>=2.26.0",

View File

@@ -224,7 +224,10 @@ class HomeServer(metaclass=abc.ABCMeta):
# This is overridden in derived application classes
# (such as synapse.app.homeserver.SynapseHomeServer) and gives the class to be
# instantiated during setup() for future return by get_datastore()
DATASTORE_CLASS = abc.abstractproperty()
@property
@abc.abstractmethod
def DATASTORE_CLASS(self):
pass
tls_server_context_factory: Optional[IOpenSSLContextFactory]

View File

@@ -300,7 +300,9 @@ class LoggingTransaction:
from psycopg2.extras import execute_values # type: ignore
return self._do_execute(
lambda *x: execute_values(self.txn, *x, fetch=fetch), sql, *args
lambda sql, argslist: execute_values(self.txn, sql, argslist, fetch=fetch),
sql,
*args,
)
def execute(self, sql: str, *args: Any) -> None:

View File

@@ -52,7 +52,7 @@ class Databases(Generic[DataStoreT]):
# Note we pass in the main store class here as workers use a different main
# store.
self.databases = []
self.databases: List[DatabasePool] = []
main: Optional[DataStoreT] = None
state: Optional[StateGroupDataStore] = None
persist_events: Optional[PersistEventsStore] = None

View File

@@ -67,7 +67,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.mock_resolver = Mock()
config_dict = default_config("test", parse=False)
config_dict = default_config("test")
config_dict["federation_custom_ca_list"] = [get_test_ca_cert_file()]
self._config = config = HomeServerConfig()
@@ -957,7 +957,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"
config = default_config("test", parse=True)
config_dict = default_config("test")
config = HomeServerConfig()
config.parse_config_dict(config_dict)
# Build a new agent and WellKnownResolver with a different tls factory
tls_factory = FederationPolicyForHTTPS(config)

View File

@@ -18,9 +18,11 @@ from unittest.mock import Mock
from twisted.internet import defer
from synapse.config.homeserver import HomeServerConfig
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
from synapse.storage.engines import create_engine
from synapse.storage.types import Connection
from tests import unittest
from tests.utils import TestHomeServer, default_config
@@ -47,7 +49,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.db_pool.runWithConnection = runWithConnection
config = default_config(name="test", parse=True)
config_dict = default_config(name="test")
config = HomeServerConfig()
config.parse_config_dict(config_dict)
hs = TestHomeServer("test", config=config)
sqlite_config = {"name": "sqlite3"}
@@ -59,7 +64,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
db = DatabasePool(Mock(), Mock(config=sqlite_config), fake_engine)
db._db_pool = self.db_pool
self.datastore = SQLBaseStore(db, None, hs)
self.datastore = SQLBaseStore(db, Mock(spec=Connection), hs)
@defer.inlineCallbacks
def test_insert_1col(self):

View File

@@ -19,6 +19,7 @@ from twisted.internet import defer
from synapse.api.auth import Auth
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.config.homeserver import HomeServerConfig
from synapse.events import make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.state import StateHandler, StateResolutionHandler
@@ -172,7 +173,11 @@ class StateTestCase(unittest.TestCase):
"hostname",
]
)
hs.config = default_config("tesths", True)
config_dict = default_config("tesths")
hs.config = HomeServerConfig()
hs.config.parse_config_dict(config_dict)
hs.get_datastore.return_value = self.store
hs.get_state_handler.return_value = None
hs.get_clock.return_value = MockClock()

View File

@@ -1,5 +1,4 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018-2019 New Vector Ltd
# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,14 +18,11 @@ import os
import time
import uuid
import warnings
from typing import Type
from unittest.mock import Mock, patch
from urllib import parse as urlparse
from twisted.internet import defer
from types import ModuleType
from typing import Any, Callable, Dict, Optional, Type
from unittest.mock import Mock
from synapse.api.constants import EventTypes
from synapse.api.errors import CodeMessageException, cs_error
from synapse.api.room_versions import RoomVersions
from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
@@ -54,13 +50,47 @@ POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),)
POSTGRES_DBNAME_FOR_INITIAL_CREATE = "postgres"
def setupdb():
# If we're using PostgreSQL, set up the db once
if USE_POSTGRES_FOR_TESTS:
# create a PostgresEngine
db_engine = create_engine({"name": "psycopg2", "args": {}})
def setupdb() -> None:
"""
Set up a temporary database to run tests in. Only applicable to postgres,
which uses a persistent database server rather than a database in memory.
"""
# Setting up the database is only required when using postgres
if not USE_POSTGRES_FOR_TESTS:
return
# connect to postgres to create the base database.
# create a PostgresEngine
db_engine = create_engine({"name": "psycopg2", "args": {}})
# connect to postgres to create the base database.
db_conn = db_engine.module.connect(
user=POSTGRES_USER,
host=POSTGRES_HOST,
password=POSTGRES_PASSWORD,
dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
)
db_conn.autocommit = True
cur = db_conn.cursor()
cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
cur.execute(
"CREATE DATABASE %s ENCODING 'UTF8' LC_COLLATE='C' LC_CTYPE='C' "
"template=template0;" % (POSTGRES_BASE_DB,)
)
cur.close()
db_conn.close()
# Set up in the db
db_conn = db_engine.module.connect(
database=POSTGRES_BASE_DB,
user=POSTGRES_USER,
host=POSTGRES_HOST,
password=POSTGRES_PASSWORD,
)
db_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests")
prepare_database(db_conn, db_engine, None)
db_conn.close()
def _cleanup():
db_conn = db_engine.module.connect(
user=POSTGRES_USER,
host=POSTGRES_HOST,
@@ -70,43 +100,21 @@ def setupdb():
db_conn.autocommit = True
cur = db_conn.cursor()
cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
cur.execute(
"CREATE DATABASE %s ENCODING 'UTF8' LC_COLLATE='C' LC_CTYPE='C' "
"template=template0;" % (POSTGRES_BASE_DB,)
)
cur.close()
db_conn.close()
# Set up in the db
db_conn = db_engine.module.connect(
database=POSTGRES_BASE_DB,
user=POSTGRES_USER,
host=POSTGRES_HOST,
password=POSTGRES_PASSWORD,
)
db_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests")
prepare_database(db_conn, db_engine, None)
db_conn.close()
def _cleanup():
db_conn = db_engine.module.connect(
user=POSTGRES_USER,
host=POSTGRES_HOST,
password=POSTGRES_PASSWORD,
dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
)
db_conn.autocommit = True
cur = db_conn.cursor()
cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
cur.close()
db_conn.close()
atexit.register(_cleanup)
atexit.register(_cleanup)
def default_config(name, parse=False):
def default_config(name: str) -> Dict[str, Any]:
"""
Create a reasonable test config.
Args:
name: The value of the 'server_name' option in the returned config.
Returns:
A sensible, default homeserver config.
"""
config_dict = {
"server_name": name,
@@ -175,11 +183,6 @@ def default_config(name, parse=False):
"listeners": [{"port": 0, "type": "http"}],
}
if parse:
config = HomeServerConfig()
config.parse_config_dict(config_dict, "", "")
return config
return config_dict
@@ -188,10 +191,10 @@ class TestHomeServer(HomeServer):
def setup_test_homeserver(
cleanup_func,
name="test",
config=None,
reactor=None,
cleanup_func: Callable[[Callable], Any],
name: str = "test",
config: Optional[HomeServerConfig] = None,
reactor: Optional[ModuleType] = None,
homeserver_to_use: Type[HomeServer] = TestHomeServer,
**kwargs,
):
@@ -209,12 +212,15 @@ def setup_test_homeserver(
HomeserverTestCase.
"""
if reactor is None:
from twisted.internet import reactor
from twisted.internet import reactor as _reactor
reactor = _reactor
if config is None:
config = default_config(name, parse=True)
config_dict = default_config(name)
config.ldap_enabled = False
config = HomeServerConfig()
config.parse_config_dict(config_dict)
if "clock" not in kwargs:
kwargs["clock"] = MockClock()
@@ -222,7 +228,7 @@ def setup_test_homeserver(
if USE_POSTGRES_FOR_TESTS:
test_db = "synapse_test_%s" % uuid.uuid4().hex
database_config = {
database_config_dict = {
"name": "psycopg2",
"args": {
"database": test_db,
@@ -234,18 +240,18 @@ def setup_test_homeserver(
},
}
else:
database_config = {
database_config_dict = {
"name": "sqlite3",
"args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
}
if "db_txn_limit" in kwargs:
database_config["txn_limit"] = kwargs["db_txn_limit"]
database_config_dict["txn_limit"] = kwargs["db_txn_limit"]
database = DatabaseConnectionConfig("master", database_config)
config.database.databases = [database]
database_config = DatabaseConnectionConfig("master", database_config_dict)
config.database.databases = [database_config]
db_engine = create_engine(database.config)
db_engine = create_engine(database_config.config)
# Create the database before we actually try and connect to it, based off
# the template database we generate in setupdb()
@@ -278,7 +284,6 @@ def setup_test_homeserver(
# Mock TLS
hs.tls_server_context_factory = Mock()
hs.tls_client_options_factory = Mock()
hs.setup()
if homeserver_to_use == TestHomeServer:
@@ -338,12 +343,12 @@ def setup_test_homeserver(
async def hash(p):
return hashlib.md5(p.encode("utf8")).hexdigest()
hs.get_auth_handler().hash = hash
hs.get_auth_handler().hash = hash # type: ignore
async def validate_hash(p, h):
return hashlib.md5(p.encode("utf8")).hexdigest() == h
hs.get_auth_handler().validate_hash = validate_hash
hs.get_auth_handler().validate_hash = validate_hash # type: ignore
return hs
@@ -357,111 +362,6 @@ def mock_getRawHeaders(headers=None):
return getRawHeaders
# This is a mock /resource/ not an entire server
class MockHttpResource:
def __init__(self, prefix=""):
self.callbacks = [] # 3-tuple of method/pattern/function
self.prefix = prefix
def trigger_get(self, path):
return self.trigger(b"GET", path, None)
@patch("twisted.web.http.Request")
@defer.inlineCallbacks
def trigger(
self, http_method, path, content, mock_request, federation_auth_origin=None
):
"""Fire an HTTP event.
Args:
http_method : The HTTP method
path : The HTTP path
content : The HTTP body
mock_request : Mocked request to pass to the event so it can get
content.
federation_auth_origin (bytes|None): domain to authenticate as, for federation
Returns:
A tuple of (code, response)
Raises:
KeyError If no event is found which will handle the path.
"""
path = self.prefix + path
# annoyingly we return a twisted http request which has chained calls
# to get at the http content, hence mock it here.
mock_content = Mock()
config = {"read.return_value": content}
mock_content.configure_mock(**config)
mock_request.content = mock_content
mock_request.method = http_method.encode("ascii")
mock_request.uri = path.encode("ascii")
mock_request.getClientIP.return_value = "-"
headers = {}
if federation_auth_origin is not None:
headers[b"Authorization"] = [
b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,)
]
mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
# return the right path if the event requires it
mock_request.path = path
# add in query params to the right place
try:
mock_request.args = urlparse.parse_qs(path.split("?")[1])
mock_request.path = path.split("?")[0]
path = mock_request.path
except Exception:
pass
if isinstance(path, bytes):
path = path.decode("utf8")
for (method, pattern, func) in self.callbacks:
if http_method != method:
continue
matcher = pattern.match(path)
if matcher:
try:
args = [urlparse.unquote(u) for u in matcher.groups()]
(code, response) = yield defer.ensureDeferred(
func(mock_request, *args)
)
return code, response
except CodeMessageException as e:
return e.code, cs_error(e.msg, code=e.errcode)
raise KeyError("No event can handle %s" % path)
def register_paths(self, method, path_patterns, callback, servlet_name):
for path_pattern in path_patterns:
self.callbacks.append((method, path_pattern, callback))
class MockKey:
alg = "mock_alg"
version = "mock_version"
signature = b"\x9a\x87$"
@property
def verify_key(self):
return self
def sign(self, message):
return self
def verify(self, message, sig):
assert sig == b"\x9a\x87$"
def encode(self):
return b"<fake_encoded_key>"
class MockClock:
now = 1000