Compare commits

...

12 Commits

Author SHA1 Message Date
H. Shay
c5802ea6b1 add tests 2023-08-16 15:06:03 -07:00
H. Shay
072ae11c5c add an admin endpoint for authorizing server to signal token revocations 2023-08-16 15:05:54 -07:00
Mathieu Velten
a432ae0997 Rename pagination&purge locks and add comments explaining them (#16112) 2023-08-16 12:02:11 -07:00
axel simon
98ec5257aa Add link explaining ELK stack to structured_logging.md (#16091) 2023-08-16 12:02:11 -07:00
David Robertson
e05c7ce208 Attempt to fix twisted trunk (#16115) 2023-08-16 12:02:11 -07:00
Patrick Cloke
b150b3626d Run pyupgrade for python 3.7 & 3.8. (#16110) 2023-08-16 12:02:11 -07:00
Olivier Wilkinson (reivilibre)
ddbb346124 1.90.0 2023-08-16 12:02:11 -07:00
dependabot[bot]
d3e46a739e Bump log from 0.4.19 to 0.4.20 (#16109)
Bumps [log](https://github.com/rust-lang/log) from 0.4.19 to 0.4.20.
- [Release notes](https://github.com/rust-lang/log/releases)
- [Changelog](https://github.com/rust-lang/log/blob/master/CHANGELOG.md)
- [Commits](https://github.com/rust-lang/log/compare/0.4.19...0.4.20)

---
updated-dependencies:
- dependency-name: log
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-08-16 12:02:11 -07:00
H. Shay
400c90d0a7 requested changes 2023-08-16 09:23:11 -07:00
H. Shay
d3841eb337 newsfragment 2023-08-15 15:04:21 -07:00
H. Shay
e4dfba4425 add some tests 2023-08-15 14:27:44 -07:00
H. Shay
9db3a90782 add an expiring cache to _introspect_token 2023-08-15 14:27:39 -07:00
66 changed files with 360 additions and 149 deletions

View File

@@ -5,6 +5,9 @@ on:
- cron: 0 8 * * * - cron: 0 8 * * *
workflow_dispatch: workflow_dispatch:
# NB: inputs are only present when this workflow is dispatched manually.
# (The default below is the default field value in the form to trigger
# a manual dispatch). Otherwise the inputs will evaluate to null.
inputs: inputs:
twisted_ref: twisted_ref:
description: Commit, branch or tag to checkout from upstream Twisted. description: Commit, branch or tag to checkout from upstream Twisted.
@@ -49,7 +52,7 @@ jobs:
extras: "all" extras: "all"
- run: | - run: |
poetry remove twisted poetry remove twisted
poetry add --extras tls git+https://github.com/twisted/twisted.git#${{ inputs.twisted_ref }} poetry add --extras tls git+https://github.com/twisted/twisted.git#${{ inputs.twisted_ref || 'trunk' }}
poetry install --no-interaction --extras "all test" poetry install --no-interaction --extras "all test"
- name: Remove warn_unused_ignores from mypy config - name: Remove warn_unused_ignores from mypy config
run: sed '/warn_unused_ignores = True/d' -i mypy.ini run: sed '/warn_unused_ignores = True/d' -i mypy.ini

View File

@@ -1,3 +1,8 @@
# Synapse 1.90.0 (2023-08-15)
No significant changes since 1.90.0rc1.
# Synapse 1.90.0rc1 (2023-08-08) # Synapse 1.90.0rc1 (2023-08-08)
### Features ### Features

4
Cargo.lock generated
View File

@@ -132,9 +132,9 @@ dependencies = [
[[package]] [[package]]
name = "log" name = "log"
version = "0.4.19" version = "0.4.20"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b06a4cde4c0f271a446782e3eff8de789548ce57dbc8eca9292c27f4a42004b4" checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f"
[[package]] [[package]]
name = "memchr" name = "memchr"

1
changelog.d/16091.doc Normal file
View File

@@ -0,0 +1 @@
Structured logging docs: add a link to explain the ELK stack

1
changelog.d/16110.misc Normal file
View File

@@ -0,0 +1 @@
Run `pyupgrade` for Python 3.8+.

1
changelog.d/16112.misc Normal file
View File

@@ -0,0 +1 @@
Rename pagination and purge locks and add comments to explain why they exist and how they work.

1
changelog.d/16115.misc Normal file
View File

@@ -0,0 +1 @@
Attempt to fix the twisted trunk job.

1
changelog.d/16117.misc Normal file
View File

@@ -0,0 +1 @@
Cache token introspection response from OIDC provider.

View File

@@ -769,7 +769,7 @@ def main(server_url, identity_server_url, username, token, config_path):
global CONFIG_JSON global CONFIG_JSON
CONFIG_JSON = config_path # bit cheeky, but just overwrite the global CONFIG_JSON = config_path # bit cheeky, but just overwrite the global
try: try:
with open(config_path, "r") as config: with open(config_path) as config:
syn_cmd.config = json.load(config) syn_cmd.config = json.load(config)
try: try:
http_client.verbose = "on" == syn_cmd.config["verbose"] http_client.verbose = "on" == syn_cmd.config["verbose"]

6
debian/changelog vendored
View File

@@ -1,3 +1,9 @@
matrix-synapse-py3 (1.90.0) stable; urgency=medium
* New Synapse release 1.90.0.
-- Synapse Packaging team <packages@matrix.org> Tue, 15 Aug 2023 11:17:34 +0100
matrix-synapse-py3 (1.90.0~rc1) stable; urgency=medium matrix-synapse-py3 (1.90.0~rc1) stable; urgency=medium
* New Synapse release 1.90.0rc1. * New Synapse release 1.90.0rc1.

View File

@@ -861,7 +861,7 @@ def generate_worker_files(
# Then a worker config file # Then a worker config file
convert( convert(
"/conf/worker.yaml.j2", "/conf/worker.yaml.j2",
"/conf/workers/{name}.yaml".format(name=worker_name), f"/conf/workers/{worker_name}.yaml",
**worker_config, **worker_config,
worker_log_config_filepath=log_config_filepath, worker_log_config_filepath=log_config_filepath,
using_unix_sockets=using_unix_sockets, using_unix_sockets=using_unix_sockets,

View File

@@ -82,7 +82,7 @@ def generate_config_from_template(
with open(filename) as handle: with open(filename) as handle:
value = handle.read() value = handle.read()
else: else:
log("Generating a random secret for {}".format(secret)) log(f"Generating a random secret for {secret}")
value = codecs.encode(os.urandom(32), "hex").decode() value = codecs.encode(os.urandom(32), "hex").decode()
with open(filename, "w") as handle: with open(filename, "w") as handle:
handle.write(value) handle.write(value)

View File

@@ -3,7 +3,7 @@
A structured logging system can be useful when your logs are destined for a A structured logging system can be useful when your logs are destined for a
machine to parse and process. By maintaining its machine-readable characteristics, machine to parse and process. By maintaining its machine-readable characteristics,
it enables more efficient searching and aggregations when consumed by software it enables more efficient searching and aggregations when consumed by software
such as the "ELK stack". such as the [ELK stack](https://opensource.com/article/18/9/open-source-log-aggregation-tools).
Synapse's structured logging system is configured via the file that Synapse's Synapse's structured logging system is configured via the file that Synapse's
`log_config` config option points to. The file should include a formatter which `log_config` config option points to. The file should include a formatter which

View File

@@ -45,6 +45,13 @@ warn_unused_ignores = False
disallow_untyped_defs = False disallow_untyped_defs = False
disallow_incomplete_defs = False disallow_incomplete_defs = False
[mypy-synapse.util.manhole]
# This module imports something from Twisted which has a bad annotation in Twisted trunk,
# but is unannotated in Twisted's latest release. We want to type-ignore the problem
# in the twisted trunk job, even though it has no effect on normal mypy runs.
warn_unused_ignores = False
;; Dependencies without annotations ;; Dependencies without annotations
;; Before ignoring a module, check to see if type stubs are available. ;; Before ignoring a module, check to see if type stubs are available.
;; The `typeshed` project maintains stubs here: ;; The `typeshed` project maintains stubs here:

View File

@@ -89,7 +89,7 @@ manifest-path = "rust/Cargo.toml"
[tool.poetry] [tool.poetry]
name = "matrix-synapse" name = "matrix-synapse"
version = "1.90.0rc1" version = "1.90.0"
description = "Homeserver for the Matrix decentralised comms protocol" description = "Homeserver for the Matrix decentralised comms protocol"
authors = ["Matrix.org Team and Contributors <packages@matrix.org>"] authors = ["Matrix.org Team and Contributors <packages@matrix.org>"]
license = "Apache-2.0" license = "Apache-2.0"

View File

@@ -47,7 +47,7 @@ can be passed on the commandline for debugging.
projdir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) projdir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
class Builder(object): class Builder:
def __init__( def __init__(
self, self,
redirect_stdout: bool = False, redirect_stdout: bool = False,

View File

@@ -43,7 +43,7 @@ def main(force_colors: bool) -> None:
diffs: List[git.Diff] = repo.remote().refs.develop.commit.diff(None) diffs: List[git.Diff] = repo.remote().refs.develop.commit.diff(None)
# Get the schema version of the local file to check against current schema on develop # Get the schema version of the local file to check against current schema on develop
with open("synapse/storage/schema/__init__.py", "r") as file: with open("synapse/storage/schema/__init__.py") as file:
local_schema = file.read() local_schema = file.read()
new_locals: Dict[str, Any] = {} new_locals: Dict[str, Any] = {}
exec(local_schema, new_locals) exec(local_schema, new_locals)

View File

@@ -247,7 +247,7 @@ def main() -> None:
def read_args_from_config(args: argparse.Namespace) -> None: def read_args_from_config(args: argparse.Namespace) -> None:
with open(args.config, "r") as fh: with open(args.config) as fh:
config = yaml.safe_load(fh) config = yaml.safe_load(fh)
if not args.server_name: if not args.server_name:

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C. # Copyright 2020 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -145,7 +145,7 @@ Example usage:
def read_args_from_config(args: argparse.Namespace) -> None: def read_args_from_config(args: argparse.Namespace) -> None:
with open(args.config, "r") as fh: with open(args.config) as fh:
config = yaml.safe_load(fh) config = yaml.safe_load(fh)
if not args.server_name: if not args.server_name:
args.server_name = config["server_name"] args.server_name = config["server_name"]

View File

@@ -25,7 +25,11 @@ from synapse.util.rust import check_rust_lib_up_to_date
from synapse.util.stringutils import strtobool from synapse.util.stringutils import strtobool
# Check that we're not running on an unsupported Python version. # Check that we're not running on an unsupported Python version.
if sys.version_info < (3, 8): #
# Note that we use an (unneeded) variable here so that pyupgrade doesn't nuke the
# if-statement completely.
py_version = sys.version_info
if py_version < (3, 8):
print("Synapse requires Python 3.8 or above.") print("Synapse requires Python 3.8 or above.")
sys.exit(1) sys.exit(1)
@@ -78,7 +82,7 @@ try:
except ImportError: except ImportError:
pass pass
import synapse.util import synapse.util # noqa: E402
__version__ = synapse.util.SYNAPSE_VERSION __version__ = synapse.util.SYNAPSE_VERSION

View File

@@ -1205,10 +1205,10 @@ class CursesProgress(Progress):
self.total_processed = 0 self.total_processed = 0
self.total_remaining = 0 self.total_remaining = 0
super(CursesProgress, self).__init__() super().__init__()
def update(self, table: str, num_done: int) -> None: def update(self, table: str, num_done: int) -> None:
super(CursesProgress, self).update(table, num_done) super().update(table, num_done)
self.total_processed = 0 self.total_processed = 0
self.total_remaining = 0 self.total_remaining = 0
@@ -1304,7 +1304,7 @@ class TerminalProgress(Progress):
"""Just prints progress to the terminal""" """Just prints progress to the terminal"""
def update(self, table: str, num_done: int) -> None: def update(self, table: str, num_done: int) -> None:
super(TerminalProgress, self).update(table, num_done) super().update(table, num_done)
data = self.tables[table] data = self.tables[table]

View File

@@ -38,7 +38,7 @@ class MockHomeserver(HomeServer):
DATASTORE_CLASS = DataStore # type: ignore [assignment] DATASTORE_CLASS = DataStore # type: ignore [assignment]
def __init__(self, config: HomeServerConfig): def __init__(self, config: HomeServerConfig):
super(MockHomeserver, self).__init__( super().__init__(
hostname=config.server.server_name, hostname=config.server.server_name,
config=config, config=config,
reactor=reactor, reactor=reactor,

View File

@@ -39,6 +39,7 @@ from synapse.logging.context import make_deferred_yieldable
from synapse.types import Requester, UserID, create_requester from synapse.types import Requester, UserID, create_requester
from synapse.util import json_decoder from synapse.util import json_decoder
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
from synapse.util.caches.expiringcache import ExpiringCache
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@@ -106,6 +107,14 @@ class MSC3861DelegatedAuth(BaseAuth):
self._issuer_metadata = RetryOnExceptionCachedCall(self._load_metadata) self._issuer_metadata = RetryOnExceptionCachedCall(self._load_metadata)
self._clock = hs.get_clock()
self._token_cache: ExpiringCache[str, IntrospectionToken] = ExpiringCache(
cache_name="introspection_token_cache",
clock=self._clock,
max_len=10000,
expiry_ms=5 * 60 * 1000,
)
if isinstance(auth_method, PrivateKeyJWTWithKid): if isinstance(auth_method, PrivateKeyJWTWithKid):
# Use the JWK as the client secret when using the private_key_jwt method # Use the JWK as the client secret when using the private_key_jwt method
assert self._config.jwk, "No JWK provided" assert self._config.jwk, "No JWK provided"
@@ -144,6 +153,20 @@ class MSC3861DelegatedAuth(BaseAuth):
Returns: Returns:
The introspection response The introspection response
""" """
# check the cache before doing a request
introspection_token = self._token_cache.get(token, None)
if introspection_token:
# check the expiration field of the token (if it exists)
exp = introspection_token.get("exp", None)
if exp:
time_now = self._clock.time()
expired = time_now > exp
if not expired:
return introspection_token
else:
return introspection_token
metadata = await self._issuer_metadata.get() metadata = await self._issuer_metadata.get()
introspection_endpoint = metadata.get("introspection_endpoint") introspection_endpoint = metadata.get("introspection_endpoint")
raw_headers: Dict[str, str] = { raw_headers: Dict[str, str] = {
@@ -157,7 +180,10 @@ class MSC3861DelegatedAuth(BaseAuth):
# Fill the body/headers with credentials # Fill the body/headers with credentials
uri, raw_headers, body = self._client_auth.prepare( uri, raw_headers, body = self._client_auth.prepare(
method="POST", uri=introspection_endpoint, headers=raw_headers, body=body method="POST",
uri=introspection_endpoint,
headers=raw_headers,
body=body,
) )
headers = Headers({k: [v] for (k, v) in raw_headers.items()}) headers = Headers({k: [v] for (k, v) in raw_headers.items()})
@@ -187,7 +213,17 @@ class MSC3861DelegatedAuth(BaseAuth):
"The introspection endpoint returned an invalid JSON response." "The introspection endpoint returned an invalid JSON response."
) )
return IntrospectionToken(**resp) expiration = resp.get("exp", None)
if expiration:
if self._clock.time() > expiration:
raise InvalidClientTokenError("Token is expired.")
introspection_token = IntrospectionToken(**resp)
# add token to cache
self._token_cache[token] = introspection_token
return introspection_token
async def is_server_admin(self, requester: Requester) -> bool: async def is_server_admin(self, requester: Requester) -> bool:
return "urn:synapse:admin:*" in requester.scope return "urn:synapse:admin:*" in requester.scope

View File

@@ -18,8 +18,7 @@
"""Contains constants from the specification.""" """Contains constants from the specification."""
import enum import enum
from typing import Final
from typing_extensions import Final
# the max size of a (canonical-json-encoded) event # the max size of a (canonical-json-encoded) event
MAX_PDU_SIZE = 65536 MAX_PDU_SIZE = 65536

View File

@@ -63,7 +63,7 @@ from synapse.federation.federation_base import (
) )
from synapse.federation.persistence import TransactionActions from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction from synapse.federation.units import Edu, Transaction
from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
from synapse.http.servlet import assert_params_in_dict from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import ( from synapse.logging.context import (
make_deferred_yieldable, make_deferred_yieldable,
@@ -1245,7 +1245,7 @@ class FederationServer(FederationBase):
# while holding the `_INBOUND_EVENT_HANDLING_LOCK_NAME` # while holding the `_INBOUND_EVENT_HANDLING_LOCK_NAME`
# lock. # lock.
async with self._worker_lock_handler.acquire_read_write_lock( async with self._worker_lock_handler.acquire_read_write_lock(
DELETE_ROOM_LOCK_NAME, room_id, write=False NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False
): ):
await self._federation_event_handler.on_receive_pdu( await self._federation_event_handler.on_receive_pdu(
origin, event origin, event

View File

@@ -53,7 +53,7 @@ from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
from synapse.events.utils import SerializeEventConfig, maybe_upsert_event_field from synapse.events.utils import SerializeEventConfig, maybe_upsert_event_field
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.handlers.directory import DirectoryHandler from synapse.handlers.directory import DirectoryHandler
from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
from synapse.logging import opentracing from synapse.logging import opentracing
from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
@@ -1034,7 +1034,7 @@ class EventCreationHandler:
) )
async with self._worker_lock_handler.acquire_read_write_lock( async with self._worker_lock_handler.acquire_read_write_lock(
DELETE_ROOM_LOCK_NAME, room_id, write=False NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False
): ):
return await self._create_and_send_nonmember_event_locked( return await self._create_and_send_nonmember_event_locked(
requester=requester, requester=requester,
@@ -1978,7 +1978,7 @@ class EventCreationHandler:
for room_id in room_ids: for room_id in room_ids:
async with self._worker_lock_handler.acquire_read_write_lock( async with self._worker_lock_handler.acquire_read_write_lock(
DELETE_ROOM_LOCK_NAME, room_id, write=False NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False
): ):
dummy_event_sent = await self._send_dummy_event_for_room(room_id) dummy_event_sent = await self._send_dummy_event_for_room(room_id)

View File

@@ -24,6 +24,7 @@ from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.events.utils import SerializeEventConfig from synapse.events.utils import SerializeEventConfig
from synapse.handlers.room import ShutdownRoomResponse from synapse.handlers.room import ShutdownRoomResponse
from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
from synapse.logging.opentracing import trace from synapse.logging.opentracing import trace
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.admin._base import assert_user_is_admin from synapse.rest.admin._base import assert_user_is_admin
@@ -46,9 +47,10 @@ logger = logging.getLogger(__name__)
BACKFILL_BECAUSE_TOO_MANY_GAPS_THRESHOLD = 3 BACKFILL_BECAUSE_TOO_MANY_GAPS_THRESHOLD = 3
PURGE_HISTORY_LOCK_NAME = "purge_history_lock" # This is used to avoid purging a room several time at the same moment,
# and also paginating during a purge. Pagination can trigger backfill,
DELETE_ROOM_LOCK_NAME = "delete_room_lock" # which would create old events locally, and would potentially clash with the room delete.
PURGE_PAGINATION_LOCK_NAME = "purge_pagination_lock"
@attr.s(slots=True, auto_attribs=True) @attr.s(slots=True, auto_attribs=True)
@@ -363,7 +365,7 @@ class PaginationHandler:
self._purges_in_progress_by_room.add(room_id) self._purges_in_progress_by_room.add(room_id)
try: try:
async with self._worker_locks.acquire_read_write_lock( async with self._worker_locks.acquire_read_write_lock(
PURGE_HISTORY_LOCK_NAME, room_id, write=True PURGE_PAGINATION_LOCK_NAME, room_id, write=True
): ):
await self._storage_controllers.purge_events.purge_history( await self._storage_controllers.purge_events.purge_history(
room_id, token, delete_local_events room_id, token, delete_local_events
@@ -421,7 +423,10 @@ class PaginationHandler:
force: set true to skip checking for joined users. force: set true to skip checking for joined users.
""" """
async with self._worker_locks.acquire_multi_read_write_lock( async with self._worker_locks.acquire_multi_read_write_lock(
[(PURGE_HISTORY_LOCK_NAME, room_id), (DELETE_ROOM_LOCK_NAME, room_id)], [
(PURGE_PAGINATION_LOCK_NAME, room_id),
(NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id),
],
write=True, write=True,
): ):
# first check that we have no users in this room # first check that we have no users in this room
@@ -483,7 +488,7 @@ class PaginationHandler:
room_token = from_token.room_key room_token = from_token.room_key
async with self._worker_locks.acquire_read_write_lock( async with self._worker_locks.acquire_read_write_lock(
PURGE_HISTORY_LOCK_NAME, room_id, write=False PURGE_PAGINATION_LOCK_NAME, room_id, write=False
): ):
(membership, member_event_id) = (None, None) (membership, member_event_id) = (None, None)
if not use_admin_priviledge: if not use_admin_priviledge:
@@ -761,7 +766,7 @@ class PaginationHandler:
self._purges_in_progress_by_room.add(room_id) self._purges_in_progress_by_room.add(room_id)
try: try:
async with self._worker_locks.acquire_read_write_lock( async with self._worker_locks.acquire_read_write_lock(
PURGE_HISTORY_LOCK_NAME, room_id, write=True PURGE_PAGINATION_LOCK_NAME, room_id, write=True
): ):
self._delete_by_id[delete_id].status = DeleteStatus.STATUS_SHUTTING_DOWN self._delete_by_id[delete_id].status = DeleteStatus.STATUS_SHUTTING_DOWN
self._delete_by_id[ self._delete_by_id[

View File

@@ -32,6 +32,7 @@ from typing import (
Any, Any,
Callable, Callable,
Collection, Collection,
ContextManager,
Dict, Dict,
Generator, Generator,
Iterable, Iterable,
@@ -43,7 +44,6 @@ from typing import (
) )
from prometheus_client import Counter from prometheus_client import Counter
from typing_extensions import ContextManager
import synapse.metrics import synapse.metrics
from synapse.api.constants import EduTypes, EventTypes, Membership, PresenceState from synapse.api.constants import EduTypes, EventTypes, Membership, PresenceState

View File

@@ -39,7 +39,7 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler
from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
from synapse.logging import opentracing from synapse.logging import opentracing
from synapse.metrics import event_processing_positions from synapse.metrics import event_processing_positions
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
@@ -621,7 +621,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
async with self.member_as_limiter.queue(as_id): async with self.member_as_limiter.queue(as_id):
async with self.member_linearizer.queue(key): async with self.member_linearizer.queue(key):
async with self._worker_lock_handler.acquire_read_write_lock( async with self._worker_lock_handler.acquire_read_write_lock(
DELETE_ROOM_LOCK_NAME, room_id, write=False NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False
): ):
with opentracing.start_active_span("update_membership_locked"): with opentracing.start_active_span("update_membership_locked"):
result = await self.update_membership_locked( result = await self.update_membership_locked(

View File

@@ -24,13 +24,14 @@ from typing import (
Iterable, Iterable,
List, List,
Mapping, Mapping,
NoReturn,
Optional, Optional,
Set, Set,
) )
from urllib.parse import urlencode from urllib.parse import urlencode
import attr import attr
from typing_extensions import NoReturn, Protocol from typing_extensions import Protocol
from twisted.web.iweb import IRequest from twisted.web.iweb import IRequest
from twisted.web.server import Request from twisted.web.server import Request
@@ -791,7 +792,7 @@ class SsoHandler:
if code != 200: if code != 200:
raise Exception( raise Exception(
"GET request to download sso avatar image returned {}".format(code) f"GET request to download sso avatar image returned {code}"
) )
# upload name includes hash of the image file's content so that we can # upload name includes hash of the image file's content so that we can

View File

@@ -14,9 +14,15 @@
# limitations under the License. # limitations under the License.
import logging import logging
from collections import Counter from collections import Counter
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple from typing import (
TYPE_CHECKING,
from typing_extensions import Counter as CounterType Any,
Counter as CounterType,
Dict,
Iterable,
Optional,
Tuple,
)
from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.metrics import event_processing_positions from synapse.metrics import event_processing_positions

View File

@@ -1442,11 +1442,9 @@ class SyncHandler:
# Now we have our list of joined room IDs, exclude as configured and freeze # Now we have our list of joined room IDs, exclude as configured and freeze
joined_room_ids = frozenset( joined_room_ids = frozenset(
( room_id
room_id for room_id in mutable_joined_room_ids
for room_id in mutable_joined_room_ids if room_id not in mutable_rooms_to_exclude
if room_id not in mutable_rooms_to_exclude
)
) )
logger.debug( logger.debug(

View File

@@ -42,7 +42,11 @@ if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
DELETE_ROOM_LOCK_NAME = "delete_room_lock" # This lock is used to avoid creating an event while we are purging the room.
# We take a read lock when creating an event, and a write one when purging a room.
# This is because it is fine to create several events concurrently, since referenced events
# will not disappear under our feet as long as we don't delete the room.
NEW_EVENT_DURING_PURGE_LOCK_NAME = "new_event_during_purge_lock"
class WorkerLocksHandler: class WorkerLocksHandler:

View File

@@ -18,10 +18,9 @@ import traceback
from collections import deque from collections import deque
from ipaddress import IPv4Address, IPv6Address, ip_address from ipaddress import IPv4Address, IPv6Address, ip_address
from math import floor from math import floor
from typing import Callable, Optional from typing import Callable, Deque, Optional
import attr import attr
from typing_extensions import Deque
from zope.interface import implementer from zope.interface import implementer
from twisted.application.internet import ClientService from twisted.application.internet import ClientService

View File

@@ -426,9 +426,7 @@ class SpamCheckerModuleApiCallbacks:
generally discouraged as it doesn't support internationalization. generally discouraged as it doesn't support internationalization.
""" """
for callback in self._check_event_for_spam_callbacks: for callback in self._check_event_for_spam_callbacks:
with Measure( with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
res = await delay_cancellation(callback(event)) res = await delay_cancellation(callback(event))
if res is False or res == self.NOT_SPAM: if res is False or res == self.NOT_SPAM:
# This spam-checker accepts the event. # This spam-checker accepts the event.
@@ -481,9 +479,7 @@ class SpamCheckerModuleApiCallbacks:
True if the event should be silently dropped True if the event should be silently dropped
""" """
for callback in self._should_drop_federated_event_callbacks: for callback in self._should_drop_federated_event_callbacks:
with Measure( with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
res: Union[bool, str] = await delay_cancellation(callback(event)) res: Union[bool, str] = await delay_cancellation(callback(event))
if res: if res:
return res return res
@@ -505,9 +501,7 @@ class SpamCheckerModuleApiCallbacks:
NOT_SPAM if the operation is permitted, [Codes, Dict] otherwise. NOT_SPAM if the operation is permitted, [Codes, Dict] otherwise.
""" """
for callback in self._user_may_join_room_callbacks: for callback in self._user_may_join_room_callbacks:
with Measure( with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
res = await delay_cancellation(callback(user_id, room_id, is_invited)) res = await delay_cancellation(callback(user_id, room_id, is_invited))
# Normalize return values to `Codes` or `"NOT_SPAM"`. # Normalize return values to `Codes` or `"NOT_SPAM"`.
if res is True or res is self.NOT_SPAM: if res is True or res is self.NOT_SPAM:
@@ -546,9 +540,7 @@ class SpamCheckerModuleApiCallbacks:
NOT_SPAM if the operation is permitted, Codes otherwise. NOT_SPAM if the operation is permitted, Codes otherwise.
""" """
for callback in self._user_may_invite_callbacks: for callback in self._user_may_invite_callbacks:
with Measure( with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
res = await delay_cancellation( res = await delay_cancellation(
callback(inviter_userid, invitee_userid, room_id) callback(inviter_userid, invitee_userid, room_id)
) )
@@ -593,9 +585,7 @@ class SpamCheckerModuleApiCallbacks:
NOT_SPAM if the operation is permitted, Codes otherwise. NOT_SPAM if the operation is permitted, Codes otherwise.
""" """
for callback in self._user_may_send_3pid_invite_callbacks: for callback in self._user_may_send_3pid_invite_callbacks:
with Measure( with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
res = await delay_cancellation( res = await delay_cancellation(
callback(inviter_userid, medium, address, room_id) callback(inviter_userid, medium, address, room_id)
) )
@@ -630,9 +620,7 @@ class SpamCheckerModuleApiCallbacks:
userid: The ID of the user attempting to create a room userid: The ID of the user attempting to create a room
""" """
for callback in self._user_may_create_room_callbacks: for callback in self._user_may_create_room_callbacks:
with Measure( with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
res = await delay_cancellation(callback(userid)) res = await delay_cancellation(callback(userid))
if res is True or res is self.NOT_SPAM: if res is True or res is self.NOT_SPAM:
continue continue
@@ -666,9 +654,7 @@ class SpamCheckerModuleApiCallbacks:
""" """
for callback in self._user_may_create_room_alias_callbacks: for callback in self._user_may_create_room_alias_callbacks:
with Measure( with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
res = await delay_cancellation(callback(userid, room_alias)) res = await delay_cancellation(callback(userid, room_alias))
if res is True or res is self.NOT_SPAM: if res is True or res is self.NOT_SPAM:
continue continue
@@ -701,9 +687,7 @@ class SpamCheckerModuleApiCallbacks:
room_id: The ID of the room that would be published room_id: The ID of the room that would be published
""" """
for callback in self._user_may_publish_room_callbacks: for callback in self._user_may_publish_room_callbacks:
with Measure( with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
res = await delay_cancellation(callback(userid, room_id)) res = await delay_cancellation(callback(userid, room_id))
if res is True or res is self.NOT_SPAM: if res is True or res is self.NOT_SPAM:
continue continue
@@ -742,9 +726,7 @@ class SpamCheckerModuleApiCallbacks:
True if the user is spammy. True if the user is spammy.
""" """
for callback in self._check_username_for_spam_callbacks: for callback in self._check_username_for_spam_callbacks:
with Measure( with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
# Make a copy of the user profile object to ensure the spam checker cannot # Make a copy of the user profile object to ensure the spam checker cannot
# modify it. # modify it.
res = await delay_cancellation(callback(user_profile.copy())) res = await delay_cancellation(callback(user_profile.copy()))
@@ -776,9 +758,7 @@ class SpamCheckerModuleApiCallbacks:
""" """
for callback in self._check_registration_for_spam_callbacks: for callback in self._check_registration_for_spam_callbacks:
with Measure( with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
behaviour = await delay_cancellation( behaviour = await delay_cancellation(
callback(email_threepid, username, request_info, auth_provider_id) callback(email_threepid, username, request_info, auth_provider_id)
) )
@@ -820,9 +800,7 @@ class SpamCheckerModuleApiCallbacks:
""" """
for callback in self._check_media_file_for_spam_callbacks: for callback in self._check_media_file_for_spam_callbacks:
with Measure( with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
res = await delay_cancellation(callback(file_wrapper, file_info)) res = await delay_cancellation(callback(file_wrapper, file_info))
# Normalize return values to `Codes` or `"NOT_SPAM"`. # Normalize return values to `Codes` or `"NOT_SPAM"`.
if res is False or res is self.NOT_SPAM: if res is False or res is self.NOT_SPAM:
@@ -869,9 +847,7 @@ class SpamCheckerModuleApiCallbacks:
""" """
for callback in self._check_login_for_spam_callbacks: for callback in self._check_login_for_spam_callbacks:
with Measure( with Measure(self.clock, f"{callback.__module__}.{callback.__qualname__}"):
self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
):
res = await delay_cancellation( res = await delay_cancellation(
callback( callback(
user_id, user_id,

View File

@@ -17,6 +17,7 @@ from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Awaitable, Awaitable,
Deque,
Dict, Dict,
Iterable, Iterable,
Iterator, Iterator,
@@ -29,7 +30,6 @@ from typing import (
) )
from prometheus_client import Counter from prometheus_client import Counter
from typing_extensions import Deque
from twisted.internet.protocol import ReconnectingClientFactory from twisted.internet.protocol import ReconnectingClientFactory

View File

@@ -47,6 +47,7 @@ from synapse.rest.admin.federation import (
ListDestinationsRestServlet, ListDestinationsRestServlet,
) )
from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
from synapse.rest.admin.oidc import OIDCTokenRevocationRestServlet
from synapse.rest.admin.registration_tokens import ( from synapse.rest.admin.registration_tokens import (
ListRegistrationTokensRestServlet, ListRegistrationTokensRestServlet,
NewRegistrationTokenRestServlet, NewRegistrationTokenRestServlet,
@@ -297,6 +298,8 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
BackgroundUpdateRestServlet(hs).register(http_server) BackgroundUpdateRestServlet(hs).register(http_server)
BackgroundUpdateStartJobRestServlet(hs).register(http_server) BackgroundUpdateStartJobRestServlet(hs).register(http_server)
ExperimentalFeaturesRestServlet(hs).register(http_server) ExperimentalFeaturesRestServlet(hs).register(http_server)
if hs.config.experimental.msc3861.enabled:
OIDCTokenRevocationRestServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource( def register_servlets_for_client_rest_resource(

View File

@@ -0,0 +1,50 @@
# Copyright 2023 The Matrix.org Foundation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from http import HTTPStatus
from typing import TYPE_CHECKING, Dict, Tuple
from synapse.api.errors import InvalidClientTokenError
from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
if TYPE_CHECKING:
from synapse.server import HomeServer
class OIDCTokenRevocationRestServlet(RestServlet):
"""
Delete a given token introspection response - identified by the `jti` field - from the
introspection token cache when a token is revoked at the authorizing server
"""
PATTERNS = admin_patterns("/OIDC_token_revocation/(?P<token_id>[^/]*)")
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
async def on_DELETE(
self, request: SynapseRequest, token_id: str
) -> Tuple[HTTPStatus, Dict]:
await assert_requester_is_admin(self.auth, request)
try:
# mypy ignore - this attribute is defined on MSC3861DelegatedAuth, which is loaded via a config flag
# this endpoint will only be loaded if the same config flag is present
self.auth._token_cache.pop(token_id) # type: ignore[attr-defined]
except KeyError:
raise InvalidClientTokenError("Token not found.")
return HTTPStatus.OK, {}

View File

@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, ShadowBanError, SynapseError from synapse.api.errors import Codes, ShadowBanError, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
@@ -81,7 +81,7 @@ class RoomUpgradeRestServlet(RestServlet):
try: try:
async with self._worker_lock_handler.acquire_read_write_lock( async with self._worker_lock_handler.acquire_read_write_lock(
DELETE_ROOM_LOCK_NAME, room_id, write=False NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False
): ):
new_room_id = await self._room_creation_handler.upgrade_room( new_room_id = await self._room_creation_handler.upgrade_room(
requester, room_id, new_version requester, room_id, new_version

View File

@@ -45,7 +45,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.logging.opentracing import ( from synapse.logging.opentracing import (
SynapseTags, SynapseTags,
@@ -357,7 +357,7 @@ class EventsPersistenceStorageController:
# it. We might already have taken out the lock, but since this is just a # it. We might already have taken out the lock, but since this is just a
# "read" lock its inherently reentrant. # "read" lock its inherently reentrant.
async with self.hs.get_worker_locks_handler().acquire_read_write_lock( async with self.hs.get_worker_locks_handler().acquire_read_write_lock(
DELETE_ROOM_LOCK_NAME, room_id, write=False NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False
): ):
if isinstance(task, _PersistEventsTask): if isinstance(task, _PersistEventsTask):
return await self._persist_event_batch(room_id, task) return await self._persist_event_batch(room_id, task)

View File

@@ -13,10 +13,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional, Tuple, Union, cast from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from typing_extensions import TYPE_CHECKING
from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json

View File

@@ -188,7 +188,7 @@ class KeyStore(SQLBaseStore):
# invalidate takes a tuple corresponding to the params of # invalidate takes a tuple corresponding to the params of
# _get_server_keys_json. _get_server_keys_json only takes one # _get_server_keys_json. _get_server_keys_json only takes one
# param, which is itself the 2-tuple (server_name, key_id). # param, which is itself the 2-tuple (server_name, key_id).
self._get_server_keys_json.invalidate((((server_name, key_id),))) self._get_server_keys_json.invalidate(((server_name, key_id),))
@cached() @cached()
def _get_server_keys_json( def _get_server_keys_json(

View File

@@ -19,6 +19,7 @@ from itertools import chain
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Counter,
Dict, Dict,
Iterable, Iterable,
List, List,
@@ -28,8 +29,6 @@ from typing import (
cast, cast,
) )
from typing_extensions import Counter
from twisted.internet.defer import DeferredLock from twisted.internet.defer import DeferredLock
from synapse.api.constants import Direction, EventContentFields, EventTypes, Membership from synapse.api.constants import Direction, EventContentFields, EventTypes, Membership

View File

@@ -145,5 +145,5 @@ class BaseDatabaseEngine(Generic[ConnectionType, CursorType], metaclass=abc.ABCM
This is not provided by DBAPI2, and so needs engine-specific support. This is not provided by DBAPI2, and so needs engine-specific support.
""" """
with open(filepath, "rt") as f: with open(filepath) as f:
cls.executescript(cursor, f.read()) cls.executescript(cursor, f.read())

View File

@@ -16,10 +16,18 @@ import logging
import os import os
import re import re
from collections import Counter from collections import Counter
from typing import Collection, Generator, Iterable, List, Optional, TextIO, Tuple from typing import (
Collection,
Counter as CounterType,
Generator,
Iterable,
List,
Optional,
TextIO,
Tuple,
)
import attr import attr
from typing_extensions import Counter as CounterType
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import LoggingDatabaseConnection, LoggingTransaction from synapse.storage.database import LoggingDatabaseConnection, LoggingTransaction

View File

@@ -21,6 +21,7 @@ from typing import (
Any, Any,
ClassVar, ClassVar,
Dict, Dict,
Final,
List, List,
Mapping, Mapping,
Match, Match,
@@ -38,7 +39,7 @@ import attr
from immutabledict import immutabledict from immutabledict import immutabledict
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
from signedjson.types import VerifyKey from signedjson.types import VerifyKey
from typing_extensions import Final, TypedDict from typing_extensions import TypedDict
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
from zope.interface import Interface from zope.interface import Interface

View File

@@ -22,6 +22,7 @@ import logging
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import ( from typing import (
Any, Any,
AsyncContextManager,
AsyncIterator, AsyncIterator,
Awaitable, Awaitable,
Callable, Callable,
@@ -42,7 +43,7 @@ from typing import (
) )
import attr import attr
from typing_extensions import AsyncContextManager, Concatenate, Literal, ParamSpec from typing_extensions import Concatenate, Literal, ParamSpec
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.defer import CancelledError from twisted.internet.defer import CancelledError

View File

@@ -218,7 +218,7 @@ class MacaroonGenerator:
# to avoid validating those as guest tokens, we explicitely verify if # to avoid validating those as guest tokens, we explicitely verify if
# the macaroon includes the "guest = true" caveat. # the macaroon includes the "guest = true" caveat.
is_guest = any( is_guest = any(
(caveat.caveat_id == "guest = true" for caveat in macaroon.caveats) caveat.caveat_id == "guest = true" for caveat in macaroon.caveats
) )
if not is_guest: if not is_guest:

View File

@@ -98,7 +98,9 @@ def manhole(settings: ManholeConfig, globals: Dict[str, Any]) -> ServerFactory:
SynapseManhole, dict(globals, __name__="__console__") SynapseManhole, dict(globals, __name__="__console__")
) )
factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker])) # type-ignore: This is an error in Twisted's annotations. See
# https://github.com/twisted/twisted/issues/11812 and /11813 .
factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker])) # type: ignore[arg-type]
# conch has the wrong type on these dicts (says bytes to bytes, # conch has the wrong type on these dicts (says bytes to bytes,
# should be bytes to Keys judging by how it's used). # should be bytes to Keys judging by how it's used).

View File

@@ -20,6 +20,7 @@ import typing
from typing import ( from typing import (
Any, Any,
Callable, Callable,
ContextManager,
DefaultDict, DefaultDict,
Dict, Dict,
Iterator, Iterator,
@@ -33,7 +34,6 @@ from typing import (
from weakref import WeakSet from weakref import WeakSet
from prometheus_client.core import Counter from prometheus_client.core import Counter
from typing_extensions import ContextManager
from twisted.internet import defer from twisted.internet import defer

View File

@@ -17,6 +17,7 @@ from enum import Enum, auto
from typing import ( from typing import (
Collection, Collection,
Dict, Dict,
Final,
FrozenSet, FrozenSet,
List, List,
Mapping, Mapping,
@@ -27,7 +28,6 @@ from typing import (
) )
import attr import attr
from typing_extensions import Final
from synapse.api.constants import EventTypes, HistoryVisibility, Membership from synapse.api.constants import EventTypes, HistoryVisibility, Membership
from synapse.events import EventBase from synapse.events import EventBase

View File

@@ -26,7 +26,7 @@ class PhoneHomeR30V2TestCase(HomeserverTestCase):
def make_homeserver( def make_homeserver(
self, reactor: ThreadedMemoryReactorClock, clock: Clock self, reactor: ThreadedMemoryReactorClock, clock: Clock
) -> HomeServer: ) -> HomeServer:
hs = super(PhoneHomeR30V2TestCase, self).make_homeserver(reactor, clock) hs = super().make_homeserver(reactor, clock)
# We don't want our tests to actually report statistics, so check # We don't want our tests to actually report statistics, so check
# that it's not enabled # that it's not enabled

View File

@@ -312,7 +312,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
[("server9", get_key_id(key1))] [("server9", get_key_id(key1))]
) )
result = self.get_success(d) result = self.get_success(d)
self.assertEquals(result[("server9", get_key_id(key1))].valid_until_ts, 0) self.assertEqual(result[("server9", get_key_id(key1))].valid_until_ts, 0)
def test_verify_json_dedupes_key_requests(self) -> None: def test_verify_json_dedupes_key_requests(self) -> None:
"""Two requests for the same key should be deduped.""" """Two requests for the same key should be deduped."""

View File

@@ -14,7 +14,7 @@
from http import HTTPStatus from http import HTTPStatus
from typing import Any, Dict, Union from typing import Any, Dict, Union
from unittest.mock import ANY, Mock from unittest.mock import ANY, AsyncMock, Mock
from urllib.parse import parse_qs from urllib.parse import parse_qs
from signedjson.key import ( from signedjson.key import (
@@ -491,6 +491,100 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError) error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
self.assertEqual(error.value.code, 503) self.assertEqual(error.value.code, 503)
def test_introspection_token_cache(self) -> None:
access_token = "open_sesame"
self.http_client.request = simple_async_mock(
return_value=FakeResponse.json(
code=200,
payload={"active": "true", "scope": "guest", "jti": access_token},
)
)
# first call should cache response
# Mpyp ignores below are due to mypy not understanding the dynamic substitution of msc3861 auth code
# for regular auth code via the config
self.get_success(
self.auth._introspect_token(access_token) # type: ignore[attr-defined]
)
introspection_token = self.auth._token_cache.get(access_token) # type: ignore[attr-defined]
self.assertEqual(introspection_token["jti"], access_token)
# there's been one http request
self.http_client.request.assert_called_once()
# second call should pull from cache, there should still be only one http request
token = self.get_success(self.auth._introspect_token(access_token)) # type: ignore[attr-defined]
self.http_client.request.assert_called_once()
self.assertEqual(token["jti"], access_token)
# advance past five minutes and check that cache expired - there should be more than one http call now
self.reactor.advance(360)
token_2 = self.get_success(self.auth._introspect_token(access_token)) # type: ignore[attr-defined]
self.assertEqual(self.http_client.request.call_count, 2)
self.assertEqual(token_2["jti"], access_token)
# test that if a cached token is expired, a fresh token will be pulled from authorizing server - first add a
# token with a soon-to-expire `exp` field to the cache
self.http_client.request = simple_async_mock(
return_value=FakeResponse.json(
code=200,
payload={
"active": "true",
"scope": "guest",
"jti": "stale",
"exp": self.clock.time() + 100,
},
)
)
self.get_success(
self.auth._introspect_token("stale") # type: ignore[attr-defined]
)
introspection_token = self.auth._token_cache.get("stale") # type: ignore[attr-defined]
self.assertEqual(introspection_token["jti"], "stale")
self.assertEqual(self.http_client.request.call_count, 1)
# advance the reactor past the token expiry but less than the cache expiry
self.reactor.advance(120)
self.assertEqual(self.auth._token_cache.get("stale"), introspection_token) # type: ignore[attr-defined]
# check that the next call causes another http request (which will fail because the token is technically expired
# but the important thing is we discard the token from the cache and try the network)
self.get_failure(
self.auth._introspect_token("stale"), InvalidClientTokenError # type: ignore[attr-defined]
)
self.assertEqual(self.http_client.request.call_count, 2)
def test_revocation_endpoint(self) -> None:
# mock introspection response and then admin verification response
self.http_client.request = AsyncMock(
side_effect=[
FakeResponse.json(
code=200, payload={"active": True, "jti": "open_sesame"}
),
FakeResponse.json(
code=200,
payload={
"active": True,
"sub": SUBJECT,
"scope": " ".join([SYNAPSE_ADMIN_SCOPE, MATRIX_USER_SCOPE]),
"username": USERNAME,
},
),
]
)
# cache a token to delete
introspection_token = self.get_success(
self.auth._introspect_token("open_sesame") # type: ignore[attr-defined]
)
self.assertEqual(self.auth._token_cache.get("open_sesame"), introspection_token) # type: ignore[attr-defined]
# delete the revoked token
introspection_token_id = "open_sesame"
url = f"/_synapse/admin/v1/OIDC_token_revocation/{introspection_token_id}"
channel = self.make_request("DELETE", url, access_token="mockAccessToken")
self.assertEqual(channel.code, 200)
self.assertEqual(self.auth._token_cache.get("open_sesame"), None) # type: ignore[attr-defined]
def make_device_keys(self, user_id: str, device_id: str) -> JsonDict: def make_device_keys(self, user_id: str, device_id: str) -> JsonDict:
# We only generate a master key to simplify the test. # We only generate a master key to simplify the test.
master_signing_key = generate_signing_key(device_id) master_signing_key = generate_signing_key(device_id)

View File

@@ -514,7 +514,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.assertEqual(response.code, 200) self.assertEqual(response.code, 200)
# Send the body # Send the body
request.write('{ "a": 1 }'.encode("ascii")) request.write(b'{ "a": 1 }')
request.finish() request.finish()
self.reactor.pump((0.1,)) self.reactor.pump((0.1,))

View File

@@ -757,7 +757,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
self.assertEqual(channel.json_body["creator"], user_id) self.assertEqual(channel.json_body["creator"], user_id)
# Check room alias. # Check room alias.
self.assertEquals(room_alias, f"#foo-bar:{self.module_api.server_name}") self.assertEqual(room_alias, f"#foo-bar:{self.module_api.server_name}")
# Let's try a room with no alias. # Let's try a room with no alias.
room_id, room_alias = self.get_success( room_id, room_alias = self.get_success(

View File

@@ -116,7 +116,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
self.assertEqual(request.method, b"GET") self.assertEqual(request.method, b"GET")
self.assertEqual( self.assertEqual(
request.path, request.path,
f"/_matrix/media/r0/download/{target}/{media_id}".encode("utf-8"), f"/_matrix/media/r0/download/{target}/{media_id}".encode(),
) )
self.assertEqual( self.assertEqual(
request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")] request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")]

View File

@@ -627,8 +627,8 @@ class RedactionsTestCase(HomeserverTestCase):
redact_event = timeline[-1] redact_event = timeline[-1]
self.assertEqual(redact_event["type"], EventTypes.Redaction) self.assertEqual(redact_event["type"], EventTypes.Redaction)
# The redacts key should be in the content and the redacts keys. # The redacts key should be in the content and the redacts keys.
self.assertEquals(redact_event["content"]["redacts"], event_id) self.assertEqual(redact_event["content"]["redacts"], event_id)
self.assertEquals(redact_event["redacts"], event_id) self.assertEqual(redact_event["redacts"], event_id)
# But it isn't actually part of the event. # But it isn't actually part of the event.
def get_event(txn: LoggingTransaction) -> JsonDict: def get_event(txn: LoggingTransaction) -> JsonDict:
@@ -642,10 +642,10 @@ class RedactionsTestCase(HomeserverTestCase):
event_json = self.get_success( event_json = self.get_success(
main_datastore.db_pool.runInteraction("get_event", get_event) main_datastore.db_pool.runInteraction("get_event", get_event)
) )
self.assertEquals(event_json["type"], EventTypes.Redaction) self.assertEqual(event_json["type"], EventTypes.Redaction)
if expect_content: if expect_content:
self.assertNotIn("redacts", event_json) self.assertNotIn("redacts", event_json)
self.assertEquals(event_json["content"]["redacts"], event_id) self.assertEqual(event_json["content"]["redacts"], event_id)
else: else:
self.assertEquals(event_json["redacts"], event_id) self.assertEqual(event_json["redacts"], event_id)
self.assertNotIn("redacts", event_json["content"]) self.assertNotIn("redacts", event_json["content"])

View File

@@ -129,7 +129,7 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase):
f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}", f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}",
access_token=self.user_token, access_token=self.user_token,
) )
self.assertEquals(200, channel.code, channel.json_body) self.assertEqual(200, channel.code, channel.json_body)
return [ev["event_id"] for ev in channel.json_body["chunk"]] return [ev["event_id"] for ev in channel.json_body["chunk"]]
def _get_bundled_aggregations(self) -> JsonDict: def _get_bundled_aggregations(self) -> JsonDict:
@@ -142,7 +142,7 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase):
f"/_matrix/client/v3/rooms/{self.room}/event/{self.parent_id}", f"/_matrix/client/v3/rooms/{self.room}/event/{self.parent_id}",
access_token=self.user_token, access_token=self.user_token,
) )
self.assertEquals(200, channel.code, channel.json_body) self.assertEqual(200, channel.code, channel.json_body)
return channel.json_body["unsigned"].get("m.relations", {}) return channel.json_body["unsigned"].get("m.relations", {})
def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict: def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict:
@@ -1602,7 +1602,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
f"/_matrix/client/v1/rooms/{self.room}/threads", f"/_matrix/client/v1/rooms/{self.room}/threads",
access_token=self.user_token, access_token=self.user_token,
) )
self.assertEquals(200, channel.code, channel.json_body) self.assertEqual(200, channel.code, channel.json_body)
threads = channel.json_body["chunk"] threads = channel.json_body["chunk"]
return [ return [
( (
@@ -1634,7 +1634,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
################################################## ##################################################
# Check the test data is configured as expected. # # Check the test data is configured as expected. #
################################################## ##################################################
self.assertEquals(self._get_related_events(), list(reversed(thread_replies))) self.assertEqual(self._get_related_events(), list(reversed(thread_replies)))
relations = self._get_bundled_aggregations() relations = self._get_bundled_aggregations()
self.assertDictContainsSubset( self.assertDictContainsSubset(
{"count": 3, "current_user_participated": True}, {"count": 3, "current_user_participated": True},
@@ -1655,7 +1655,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
self._redact(thread_replies.pop()) self._redact(thread_replies.pop())
# The thread should still exist, but the latest event should be updated. # The thread should still exist, but the latest event should be updated.
self.assertEquals(self._get_related_events(), list(reversed(thread_replies))) self.assertEqual(self._get_related_events(), list(reversed(thread_replies)))
relations = self._get_bundled_aggregations() relations = self._get_bundled_aggregations()
self.assertDictContainsSubset( self.assertDictContainsSubset(
{"count": 2, "current_user_participated": True}, {"count": 2, "current_user_participated": True},
@@ -1674,7 +1674,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
self._redact(thread_replies.pop(0)) self._redact(thread_replies.pop(0))
# Nothing should have changed (except the thread count). # Nothing should have changed (except the thread count).
self.assertEquals(self._get_related_events(), thread_replies) self.assertEqual(self._get_related_events(), thread_replies)
relations = self._get_bundled_aggregations() relations = self._get_bundled_aggregations()
self.assertDictContainsSubset( self.assertDictContainsSubset(
{"count": 1, "current_user_participated": True}, {"count": 1, "current_user_participated": True},
@@ -1691,11 +1691,11 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
# Redact the last remaining event. # # Redact the last remaining event. #
#################################### ####################################
self._redact(thread_replies.pop(0)) self._redact(thread_replies.pop(0))
self.assertEquals(thread_replies, []) self.assertEqual(thread_replies, [])
# The event should no longer be considered a thread. # The event should no longer be considered a thread.
self.assertEquals(self._get_related_events(), []) self.assertEqual(self._get_related_events(), [])
self.assertEquals(self._get_bundled_aggregations(), {}) self.assertEqual(self._get_bundled_aggregations(), {})
self.assertEqual(self._get_threads(), []) self.assertEqual(self._get_threads(), [])
def test_redact_parent_edit(self) -> None: def test_redact_parent_edit(self) -> None:
@@ -1749,8 +1749,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
# The relations are returned. # The relations are returned.
event_ids = self._get_related_events() event_ids = self._get_related_events()
relations = self._get_bundled_aggregations() relations = self._get_bundled_aggregations()
self.assertEquals(event_ids, [related_event_id]) self.assertEqual(event_ids, [related_event_id])
self.assertEquals( self.assertEqual(
relations[RelationTypes.REFERENCE], relations[RelationTypes.REFERENCE],
{"chunk": [{"event_id": related_event_id}]}, {"chunk": [{"event_id": related_event_id}]},
) )
@@ -1772,7 +1772,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
# The unredacted relation should still exist. # The unredacted relation should still exist.
event_ids = self._get_related_events() event_ids = self._get_related_events()
relations = self._get_bundled_aggregations() relations = self._get_bundled_aggregations()
self.assertEquals(len(event_ids), 1) self.assertEqual(len(event_ids), 1)
self.assertDictContainsSubset( self.assertDictContainsSubset(
{ {
"count": 1, "count": 1,
@@ -1816,7 +1816,7 @@ class ThreadsTestCase(BaseRelationsTestCase):
f"/_matrix/client/v1/rooms/{self.room}/threads", f"/_matrix/client/v1/rooms/{self.room}/threads",
access_token=self.user_token, access_token=self.user_token,
) )
self.assertEquals(200, channel.code, channel.json_body) self.assertEqual(200, channel.code, channel.json_body)
threads = self._get_threads(channel.json_body) threads = self._get_threads(channel.json_body)
self.assertEqual(threads, [(thread_2, reply_2), (thread_1, reply_1)]) self.assertEqual(threads, [(thread_2, reply_2), (thread_1, reply_1)])
@@ -1829,7 +1829,7 @@ class ThreadsTestCase(BaseRelationsTestCase):
f"/_matrix/client/v1/rooms/{self.room}/threads", f"/_matrix/client/v1/rooms/{self.room}/threads",
access_token=self.user_token, access_token=self.user_token,
) )
self.assertEquals(200, channel.code, channel.json_body) self.assertEqual(200, channel.code, channel.json_body)
# Tuple of (thread ID, latest event ID) for each thread. # Tuple of (thread ID, latest event ID) for each thread.
threads = self._get_threads(channel.json_body) threads = self._get_threads(channel.json_body)
self.assertEqual(threads, [(thread_1, reply_3), (thread_2, reply_2)]) self.assertEqual(threads, [(thread_1, reply_3), (thread_2, reply_2)])
@@ -1850,7 +1850,7 @@ class ThreadsTestCase(BaseRelationsTestCase):
f"/_matrix/client/v1/rooms/{self.room}/threads?limit=1", f"/_matrix/client/v1/rooms/{self.room}/threads?limit=1",
access_token=self.user_token, access_token=self.user_token,
) )
self.assertEquals(200, channel.code, channel.json_body) self.assertEqual(200, channel.code, channel.json_body)
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(thread_roots, [thread_2]) self.assertEqual(thread_roots, [thread_2])
@@ -1864,7 +1864,7 @@ class ThreadsTestCase(BaseRelationsTestCase):
f"/_matrix/client/v1/rooms/{self.room}/threads?limit=1&from={next_batch}", f"/_matrix/client/v1/rooms/{self.room}/threads?limit=1&from={next_batch}",
access_token=self.user_token, access_token=self.user_token,
) )
self.assertEquals(200, channel.code, channel.json_body) self.assertEqual(200, channel.code, channel.json_body)
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(thread_roots, [thread_1], channel.json_body) self.assertEqual(thread_roots, [thread_1], channel.json_body)
@@ -1899,7 +1899,7 @@ class ThreadsTestCase(BaseRelationsTestCase):
f"/_matrix/client/v1/rooms/{self.room}/threads", f"/_matrix/client/v1/rooms/{self.room}/threads",
access_token=self.user_token, access_token=self.user_token,
) )
self.assertEquals(200, channel.code, channel.json_body) self.assertEqual(200, channel.code, channel.json_body)
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual( self.assertEqual(
thread_roots, [thread_3, thread_2, thread_1], channel.json_body thread_roots, [thread_3, thread_2, thread_1], channel.json_body
@@ -1911,7 +1911,7 @@ class ThreadsTestCase(BaseRelationsTestCase):
f"/_matrix/client/v1/rooms/{self.room}/threads?include=participated", f"/_matrix/client/v1/rooms/{self.room}/threads?include=participated",
access_token=self.user_token, access_token=self.user_token,
) )
self.assertEquals(200, channel.code, channel.json_body) self.assertEqual(200, channel.code, channel.json_body)
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(thread_roots, [thread_2, thread_1], channel.json_body) self.assertEqual(thread_roots, [thread_2, thread_1], channel.json_body)
@@ -1943,6 +1943,6 @@ class ThreadsTestCase(BaseRelationsTestCase):
f"/_matrix/client/v1/rooms/{self.room}/threads", f"/_matrix/client/v1/rooms/{self.room}/threads",
access_token=self.user_token, access_token=self.user_token,
) )
self.assertEquals(200, channel.code, channel.json_body) self.assertEqual(200, channel.code, channel.json_body)
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(thread_roots, [thread_1], channel.json_body) self.assertEqual(thread_roots, [thread_1], channel.json_body)

View File

@@ -1362,7 +1362,7 @@ class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase):
# Ensure the event was persisted with the correct timestamp. # Ensure the event was persisted with the correct timestamp.
res = self.get_success(self.main_store.get_event(event_id)) res = self.get_success(self.main_store.get_event(event_id))
self.assertEquals(ts, res.origin_server_ts) self.assertEqual(ts, res.origin_server_ts)
def test_send_state_event_ts(self) -> None: def test_send_state_event_ts(self) -> None:
"""Test sending a state event with a custom timestamp.""" """Test sending a state event with a custom timestamp."""
@@ -1384,7 +1384,7 @@ class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase):
# Ensure the event was persisted with the correct timestamp. # Ensure the event was persisted with the correct timestamp.
res = self.get_success(self.main_store.get_event(event_id)) res = self.get_success(self.main_store.get_event(event_id))
self.assertEquals(ts, res.origin_server_ts) self.assertEqual(ts, res.origin_server_ts)
def test_send_membership_event_ts(self) -> None: def test_send_membership_event_ts(self) -> None:
"""Test sending a membership event with a custom timestamp.""" """Test sending a membership event with a custom timestamp."""
@@ -1406,7 +1406,7 @@ class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase):
# Ensure the event was persisted with the correct timestamp. # Ensure the event was persisted with the correct timestamp.
res = self.get_success(self.main_store.get_event(event_id)) res = self.get_success(self.main_store.get_event(event_id))
self.assertEquals(ts, res.origin_server_ts) self.assertEqual(ts, res.origin_server_ts)
class RoomJoinRatelimitTestCase(RoomBase): class RoomJoinRatelimitTestCase(RoomBase):

View File

@@ -26,6 +26,7 @@ from typing import (
Any, Any,
Awaitable, Awaitable,
Callable, Callable,
Deque,
Dict, Dict,
Iterable, Iterable,
List, List,
@@ -41,7 +42,7 @@ from typing import (
from unittest.mock import Mock from unittest.mock import Mock
import attr import attr
from typing_extensions import Deque, ParamSpec from typing_extensions import ParamSpec
from zope.interface import implementer from zope.interface import implementer
from twisted.internet import address, threads, udp from twisted.internet import address, threads, udp

View File

@@ -40,7 +40,7 @@ from tests.test_utils import make_awaitable
class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase): class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
def setUp(self) -> None: def setUp(self) -> None:
super(ApplicationServiceStoreTestCase, self).setUp() super().setUp()
self.as_yaml_files: List[str] = [] self.as_yaml_files: List[str] = []
@@ -71,7 +71,7 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
except Exception: except Exception:
pass pass
super(ApplicationServiceStoreTestCase, self).tearDown() super().tearDown()
def _add_appservice( def _add_appservice(
self, as_token: str, id: str, url: str, hs_token: str, sender: str self, as_token: str, id: str, url: str, hs_token: str, sender: str
@@ -110,7 +110,7 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
def setUp(self) -> None: def setUp(self) -> None:
super(ApplicationServiceTransactionStoreTestCase, self).setUp() super().setUp()
self.as_yaml_files: List[str] = [] self.as_yaml_files: List[str] = []
self.hs.config.appservice.app_service_config_files = self.as_yaml_files self.hs.config.appservice.app_service_config_files = self.as_yaml_files

View File

@@ -20,7 +20,7 @@ from tests import unittest
class DataStoreTestCase(unittest.HomeserverTestCase): class DataStoreTestCase(unittest.HomeserverTestCase):
def setUp(self) -> None: def setUp(self) -> None:
super(DataStoreTestCase, self).setUp() super().setUp()
self.store = self.hs.get_datastores().main self.store = self.hs.get_datastores().main

View File

@@ -318,14 +318,14 @@ class MessageSearchTest(HomeserverTestCase):
result = self.get_success( result = self.get_success(
store.search_msgs([self.room_id], query, ["content.body"]) store.search_msgs([self.room_id], query, ["content.body"])
) )
self.assertEquals( self.assertEqual(
result["count"], result["count"],
1 if expect_to_contain else 0, 1 if expect_to_contain else 0,
f"expected '{query}' to match '{self.PHRASE}'" f"expected '{query}' to match '{self.PHRASE}'"
if expect_to_contain if expect_to_contain
else f"'{query}' unexpectedly matched '{self.PHRASE}'", else f"'{query}' unexpectedly matched '{self.PHRASE}'",
) )
self.assertEquals( self.assertEqual(
len(result["results"]), len(result["results"]),
1 if expect_to_contain else 0, 1 if expect_to_contain else 0,
"results array length should match count", "results array length should match count",
@@ -336,14 +336,14 @@ class MessageSearchTest(HomeserverTestCase):
result = self.get_success( result = self.get_success(
store.search_rooms([self.room_id], query, ["content.body"], 10) store.search_rooms([self.room_id], query, ["content.body"], 10)
) )
self.assertEquals( self.assertEqual(
result["count"], result["count"],
1 if expect_to_contain else 0, 1 if expect_to_contain else 0,
f"expected '{query}' to match '{self.PHRASE}'" f"expected '{query}' to match '{self.PHRASE}'"
if expect_to_contain if expect_to_contain
else f"'{query}' unexpectedly matched '{self.PHRASE}'", else f"'{query}' unexpectedly matched '{self.PHRASE}'",
) )
self.assertEquals( self.assertEqual(
len(result["results"]), len(result["results"]),
1 if expect_to_contain else 0, 1 if expect_to_contain else 0,
"results array length should match count", "results array length should match count",

View File

@@ -31,7 +31,7 @@ TEST_ROOM_ID = "!TEST:ROOM"
class FilterEventsForServerTestCase(unittest.HomeserverTestCase): class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
def setUp(self) -> None: def setUp(self) -> None:
super(FilterEventsForServerTestCase, self).setUp() super().setUp()
self.event_creation_handler = self.hs.get_event_creation_handler() self.event_creation_handler = self.hs.get_event_creation_handler()
self.event_builder_factory = self.hs.get_event_builder_factory() self.event_builder_factory = self.hs.get_event_builder_factory()
self._storage_controllers = self.hs.get_storage_controllers() self._storage_controllers = self.hs.get_storage_controllers()