Use type hinting generics in standard collections (#19046)

aka PEP 585, added in Python 3.9

 - https://peps.python.org/pep-0585/
 - https://docs.astral.sh/ruff/rules/non-pep585-annotation/
This commit is contained in:
Andrew Ferrazzutti
2025-10-22 17:48:19 -04:00
committed by GitHub
parent cba3a814c6
commit fc244bb592
539 changed files with 4599 additions and 5066 deletions

View File

@@ -2,13 +2,13 @@
import itertools
import os
from typing import Any, Dict
from typing import Any
from packaging.specifiers import SpecifierSet
from setuptools_rust import Binding, RustExtension
def build(setup_kwargs: Dict[str, Any]) -> None:
def build(setup_kwargs: dict[str, Any]) -> None:
original_project_dir = os.path.dirname(os.path.realpath(__file__))
cargo_toml_path = os.path.join(original_project_dir, "rust", "Cargo.toml")

1
changelog.d/19046.misc Normal file
View File

@@ -0,0 +1 @@
Use type hinting generics in standard collections, as per PEP 585, added in Python 3.9.

View File

@@ -24,7 +24,6 @@ import datetime
import html
import json
import urllib.request
from typing import List
import pydot
@@ -33,7 +32,7 @@ def make_name(pdu_id: str, origin: str) -> str:
return f"{pdu_id}@{origin}"
def make_graph(pdus: List[dict], filename_prefix: str) -> None:
def make_graph(pdus: list[dict], filename_prefix: str) -> None:
"""
Generate a dot and SVG file for a graph of events in the room based on the
topological ordering by querying a homeserver.
@@ -127,7 +126,7 @@ def make_graph(pdus: List[dict], filename_prefix: str) -> None:
graph.write_svg("%s.svg" % filename_prefix, prog="dot")
def get_pdus(host: str, room: str) -> List[dict]:
def get_pdus(host: str, room: str) -> list[dict]:
transaction = json.loads(
urllib.request.urlopen(
f"http://{host}/_matrix/federation/v1/context/{room}/"

View File

@@ -65,13 +65,10 @@ from itertools import chain
from pathlib import Path
from typing import (
Any,
Dict,
List,
Mapping,
MutableMapping,
NoReturn,
Optional,
Set,
SupportsIndex,
)
@@ -96,7 +93,7 @@ WORKER_PLACEHOLDER_NAME = "placeholder_name"
# Watching /_matrix/media and related needs a "media" listener
# Stream Writers require "client" and "replication" listeners because they
# have to attach by instance_map to the master process and have client endpoints.
WORKERS_CONFIG: Dict[str, Dict[str, Any]] = {
WORKERS_CONFIG: dict[str, dict[str, Any]] = {
"pusher": {
"app": "synapse.app.generic_worker",
"listener_resources": [],
@@ -408,7 +405,7 @@ def convert(src: str, dst: str, **template_vars: object) -> None:
def add_worker_roles_to_shared_config(
shared_config: dict,
worker_types_set: Set[str],
worker_types_set: set[str],
worker_name: str,
worker_port: int,
) -> None:
@@ -471,9 +468,9 @@ def add_worker_roles_to_shared_config(
def merge_worker_template_configs(
existing_dict: Optional[Dict[str, Any]],
to_be_merged_dict: Dict[str, Any],
) -> Dict[str, Any]:
existing_dict: Optional[dict[str, Any]],
to_be_merged_dict: dict[str, Any],
) -> dict[str, Any]:
"""When given an existing dict of worker template configuration consisting with both
dicts and lists, merge new template data from WORKERS_CONFIG(or create) and
return new dict.
@@ -484,7 +481,7 @@ def merge_worker_template_configs(
existing_dict.
Returns: The newly merged together dict values.
"""
new_dict: Dict[str, Any] = {}
new_dict: dict[str, Any] = {}
if not existing_dict:
# It doesn't exist yet, just use the new dict(but take a copy not a reference)
new_dict = to_be_merged_dict.copy()
@@ -509,8 +506,8 @@ def merge_worker_template_configs(
def insert_worker_name_for_worker_config(
existing_dict: Dict[str, Any], worker_name: str
) -> Dict[str, Any]:
existing_dict: dict[str, Any], worker_name: str
) -> dict[str, Any]:
"""Insert a given worker name into the worker's configuration dict.
Args:
@@ -526,7 +523,7 @@ def insert_worker_name_for_worker_config(
return dict_to_edit
def apply_requested_multiplier_for_worker(worker_types: List[str]) -> List[str]:
def apply_requested_multiplier_for_worker(worker_types: list[str]) -> list[str]:
"""
Apply multiplier(if found) by returning a new expanded list with some basic error
checking.
@@ -587,7 +584,7 @@ def is_sharding_allowed_for_worker_type(worker_type: str) -> bool:
def split_and_strip_string(
given_string: str, split_char: str, max_split: SupportsIndex = -1
) -> List[str]:
) -> list[str]:
"""
Helper to split a string on split_char and strip whitespace from each end of each
element.
@@ -616,8 +613,8 @@ def generate_base_homeserver_config() -> None:
def parse_worker_types(
requested_worker_types: List[str],
) -> Dict[str, Set[str]]:
requested_worker_types: list[str],
) -> dict[str, set[str]]:
"""Read the desired list of requested workers and prepare the data for use in
generating worker config files while also checking for potential gotchas.
@@ -633,14 +630,14 @@ def parse_worker_types(
# A counter of worker_base_name -> int. Used for determining the name for a given
# worker when generating its config file, as each worker's name is just
# worker_base_name followed by instance number
worker_base_name_counter: Dict[str, int] = defaultdict(int)
worker_base_name_counter: dict[str, int] = defaultdict(int)
# Similar to above, but more finely grained. This is used to determine we don't have
# more than a single worker for cases where multiples would be bad(e.g. presence).
worker_type_shard_counter: Dict[str, int] = defaultdict(int)
worker_type_shard_counter: dict[str, int] = defaultdict(int)
# The final result of all this processing
dict_to_return: Dict[str, Set[str]] = {}
dict_to_return: dict[str, set[str]] = {}
# Handle any multipliers requested for given workers.
multiple_processed_worker_types = apply_requested_multiplier_for_worker(
@@ -684,7 +681,7 @@ def parse_worker_types(
# Split the worker_type_string on "+", remove whitespace from ends then make
# the list a set so it's deduplicated.
worker_types_set: Set[str] = set(
worker_types_set: set[str] = set(
split_and_strip_string(worker_type_string, "+")
)
@@ -743,7 +740,7 @@ def generate_worker_files(
environ: Mapping[str, str],
config_path: str,
data_dir: str,
requested_worker_types: Dict[str, Set[str]],
requested_worker_types: dict[str, set[str]],
) -> None:
"""Read the desired workers(if any) that is passed in and generate shared
homeserver, nginx and supervisord configs.
@@ -764,7 +761,7 @@ def generate_worker_files(
# First read the original config file and extract the listeners block. Then we'll
# add another listener for replication. Later we'll write out the result to the
# shared config file.
listeners: List[Any]
listeners: list[Any]
if using_unix_sockets:
listeners = [
{
@@ -792,12 +789,12 @@ def generate_worker_files(
# base shared worker jinja2 template. This config file will be passed to all
# workers, included Synapse's main process. It is intended mainly for disabling
# functionality when certain workers are spun up, and adding a replication listener.
shared_config: Dict[str, Any] = {"listeners": listeners}
shared_config: dict[str, Any] = {"listeners": listeners}
# List of dicts that describe workers.
# We pass this to the Supervisor template later to generate the appropriate
# program blocks.
worker_descriptors: List[Dict[str, Any]] = []
worker_descriptors: list[dict[str, Any]] = []
# Upstreams for load-balancing purposes. This dict takes the form of the worker
# type to the ports of each worker. For example:
@@ -805,14 +802,14 @@ def generate_worker_files(
# worker_type: {1234, 1235, ...}}
# }
# and will be used to construct 'upstream' nginx directives.
nginx_upstreams: Dict[str, Set[int]] = {}
nginx_upstreams: dict[str, set[int]] = {}
# A map of: {"endpoint": "upstream"}, where "upstream" is a str representing what
# will be placed after the proxy_pass directive. The main benefit to representing
# this data as a dict over a str is that we can easily deduplicate endpoints
# across multiple instances of the same worker. The final rendering will be combined
# with nginx_upstreams and placed in /etc/nginx/conf.d.
nginx_locations: Dict[str, str] = {}
nginx_locations: dict[str, str] = {}
# Create the worker configuration directory if it doesn't already exist
os.makedirs("/conf/workers", exist_ok=True)
@@ -846,7 +843,7 @@ def generate_worker_files(
# yaml config file
for worker_name, worker_types_set in requested_worker_types.items():
# The collected and processed data will live here.
worker_config: Dict[str, Any] = {}
worker_config: dict[str, Any] = {}
# Merge all worker config templates for this worker into a single config
for worker_type in worker_types_set:
@@ -1029,7 +1026,7 @@ def generate_worker_log_config(
Returns: the path to the generated file
"""
# Check whether we should write worker logs to disk, in addition to the console
extra_log_template_args: Dict[str, Optional[str]] = {}
extra_log_template_args: dict[str, Optional[str]] = {}
if environ.get("SYNAPSE_WORKERS_WRITE_LOGS_TO_DISK"):
extra_log_template_args["LOG_FILE_PATH"] = f"{data_dir}/logs/{worker_name}.log"
@@ -1053,7 +1050,7 @@ def generate_worker_log_config(
return log_config_filepath
def main(args: List[str], environ: MutableMapping[str, str]) -> None:
def main(args: list[str], environ: MutableMapping[str, str]) -> None:
parser = ArgumentParser()
parser.add_argument(
"--generate-only",
@@ -1087,7 +1084,7 @@ def main(args: List[str], environ: MutableMapping[str, str]) -> None:
if not worker_types_env:
# No workers, just the main process
worker_types = []
requested_worker_types: Dict[str, Any] = {}
requested_worker_types: dict[str, Any] = {}
else:
# Split type names by comma, ignoring whitespace.
worker_types = split_and_strip_string(worker_types_env, ",")

View File

@@ -6,7 +6,7 @@ import os
import platform
import subprocess
import sys
from typing import Any, Dict, List, Mapping, MutableMapping, NoReturn, Optional
from typing import Any, Mapping, MutableMapping, NoReturn, Optional
import jinja2
@@ -69,7 +69,7 @@ def generate_config_from_template(
)
# populate some params from data files (if they exist, else create new ones)
environ: Dict[str, Any] = dict(os_environ)
environ: dict[str, Any] = dict(os_environ)
secrets = {
"registration": "SYNAPSE_REGISTRATION_SHARED_SECRET",
"macaroon": "SYNAPSE_MACAROON_SECRET_KEY",
@@ -200,7 +200,7 @@ def run_generate_config(environ: Mapping[str, str], ownership: Optional[str]) ->
subprocess.run(args, check=True)
def main(args: List[str], environ: MutableMapping[str, str]) -> None:
def main(args: list[str], environ: MutableMapping[str, str]) -> None:
mode = args[1] if len(args) > 1 else "run"
# if we were given an explicit user to switch to, do so

View File

@@ -78,6 +78,12 @@ select = [
"LOG",
# flake8-logging-format
"G",
# pyupgrade
"UP006",
]
extend-safe-fixes = [
# pyupgrade
"UP006"
]
[tool.ruff.lint.isort]

View File

@@ -18,7 +18,7 @@ import sys
import threading
from concurrent.futures import ThreadPoolExecutor
from types import FrameType
from typing import Collection, Optional, Sequence, Set
from typing import Collection, Optional, Sequence
# These are expanded inside the dockerfile to be a fully qualified image name.
# e.g. docker.io/library/debian:bullseye
@@ -54,7 +54,7 @@ class Builder:
):
self.redirect_stdout = redirect_stdout
self._docker_build_args = tuple(docker_build_args or ())
self.active_containers: Set[str] = set()
self.active_containers: set[str] = set()
self._lock = threading.Lock()
self._failed = False

View File

@@ -21,7 +21,6 @@
#
import sys
from pathlib import Path
from typing import Dict, List
import tomli
@@ -33,7 +32,7 @@ def main() -> None:
# Poetry 1.3+ lockfile format:
# There's a `files` inline table in each [[package]]
packages_to_assets: Dict[str, List[Dict[str, str]]] = {
packages_to_assets: dict[str, list[dict[str, str]]] = {
package["name"]: package["files"] for package in lockfile_content["package"]
}

View File

@@ -47,11 +47,7 @@ from contextlib import contextmanager
from typing import (
Any,
Callable,
Dict,
Generator,
List,
Set,
Type,
TypeVar,
)
@@ -69,7 +65,7 @@ from synapse._pydantic_compat import (
logger = logging.getLogger(__name__)
CONSTRAINED_TYPE_FACTORIES_WITH_STRICT_FLAG: List[Callable] = [
CONSTRAINED_TYPE_FACTORIES_WITH_STRICT_FLAG: list[Callable] = [
constr,
conbytes,
conint,
@@ -145,7 +141,7 @@ class PatchedBaseModel(PydanticBaseModel):
"""
@classmethod
def __init_subclass__(cls: Type[PydanticBaseModel], **kwargs: object):
def __init_subclass__(cls: type[PydanticBaseModel], **kwargs: object):
for field in cls.__fields__.values():
# Note that field.type_ and field.outer_type are computed based on the
# annotation type, see pydantic.fields.ModelField._type_analysis
@@ -212,7 +208,7 @@ def lint() -> int:
return os.EX_DATAERR if failures else os.EX_OK
def do_lint() -> Set[str]:
def do_lint() -> set[str]:
"""Try to import all of Synapse and see if we spot any Pydantic type coercions."""
failures = set()
@@ -258,8 +254,8 @@ def run_test_snippet(source: str) -> None:
# > Remember that at the module level, globals and locals are the same dictionary.
# > If exec gets two separate objects as globals and locals, the code will be
# > executed as if it were embedded in a class definition.
globals_: Dict[str, object]
locals_: Dict[str, object]
globals_: dict[str, object]
locals_: dict[str, object]
globals_ = locals_ = {}
exec(textwrap.dedent(source), globals_, locals_)
@@ -394,10 +390,10 @@ class TestFieldTypeInspection(unittest.TestCase):
("bool"),
("Optional[str]",),
("Union[None, str]",),
("List[str]",),
("List[List[str]]",),
("Dict[StrictStr, str]",),
("Dict[str, StrictStr]",),
("list[str]",),
("list[list[str]]",),
("dict[StrictStr, str]",),
("dict[str, StrictStr]",),
("TypedDict('D', x=int)",),
]
)
@@ -425,9 +421,9 @@ class TestFieldTypeInspection(unittest.TestCase):
("constr(strict=True, min_length=10)",),
("Optional[StrictStr]",),
("Union[None, StrictStr]",),
("List[StrictStr]",),
("List[List[StrictStr]]",),
("Dict[StrictStr, StrictStr]",),
("list[StrictStr]",),
("list[list[StrictStr]]",),
("dict[StrictStr, StrictStr]",),
("TypedDict('D', x=StrictInt)",),
]
)

View File

@@ -5,7 +5,7 @@
# Also checks that schema deltas do not try and create or drop indices.
import re
from typing import Any, Dict, List
from typing import Any
import click
import git
@@ -48,16 +48,16 @@ def main(force_colors: bool) -> None:
r = repo.git.show(f"origin/{DEVELOP_BRANCH}:synapse/storage/schema/__init__.py")
locals: Dict[str, Any] = {}
locals: dict[str, Any] = {}
exec(r, locals)
current_schema_version = locals["SCHEMA_VERSION"]
diffs: List[git.Diff] = repo.remote().refs[DEVELOP_BRANCH].commit.diff(None)
diffs: list[git.Diff] = repo.remote().refs[DEVELOP_BRANCH].commit.diff(None)
# Get the schema version of the local file to check against current schema on develop
with open("synapse/storage/schema/__init__.py") as file:
local_schema = file.read()
new_locals: Dict[str, Any] = {}
new_locals: dict[str, Any] = {}
exec(local_schema, new_locals)
local_schema_version = new_locals["SCHEMA_VERSION"]

View File

@@ -43,7 +43,7 @@ import argparse
import base64
import json
import sys
from typing import Any, Dict, Mapping, Optional, Tuple, Union
from typing import Any, Mapping, Optional, Union
from urllib import parse as urlparse
import requests
@@ -147,7 +147,7 @@ def request(
s = requests.Session()
s.mount("matrix-federation://", MatrixConnectionAdapter())
headers: Dict[str, str] = {
headers: dict[str, str] = {
"Authorization": authorization_headers[0],
}
@@ -303,7 +303,7 @@ class MatrixConnectionAdapter(HTTPAdapter):
request: PreparedRequest,
verify: Optional[Union[bool, str]],
proxies: Optional[Mapping[str, str]] = None,
cert: Optional[Union[Tuple[str, str], str]] = None,
cert: Optional[Union[tuple[str, str], str]] = None,
) -> HTTPConnectionPool:
# overrides the get_connection_with_tls_context() method in the base class
parsed = urlparse.urlsplit(request.url)
@@ -326,7 +326,7 @@ class MatrixConnectionAdapter(HTTPAdapter):
)
@staticmethod
def _lookup(server_name: str) -> Tuple[str, int, str]:
def _lookup(server_name: str) -> tuple[str, int, str]:
"""
Do an SRV lookup on a server name and return the host:port to connect to
Given the server_name (after any .well-known lookup), return the host, port and

View File

@@ -24,7 +24,7 @@ can crop up, e.g the cache descriptors.
"""
import enum
from typing import Callable, Mapping, Optional, Tuple, Type, Union
from typing import Callable, Mapping, Optional, Union
import attr
import mypy.types
@@ -184,8 +184,8 @@ should be in the source code.
# Unbound at this point because we don't know the mypy version yet.
# This is set in the `plugin(...)` function below.
MypyPydanticPluginClass: Type[Plugin]
MypyZopePluginClass: Type[Plugin]
MypyPydanticPluginClass: type[Plugin]
MypyZopePluginClass: type[Plugin]
class SynapsePlugin(Plugin):
@@ -795,7 +795,7 @@ AT_CACHED_MUTABLE_RETURN = ErrorCode(
def is_cacheable(
rt: mypy.types.Type, signature: CallableType, verbose: bool
) -> Tuple[bool, Optional[str]]:
) -> tuple[bool, Optional[str]]:
"""
Check if a particular type is cachable.
@@ -905,7 +905,7 @@ def is_cacheable(
return False, f"Don't know how to handle {type(rt).__qualname__} return type"
def plugin(version: str) -> Type[SynapsePlugin]:
def plugin(version: str) -> type[SynapsePlugin]:
global MypyPydanticPluginClass, MypyZopePluginClass
# This is the entry point of the plugin, and lets us deal with the fact
# that the mypy plugin interface is *not* stable by looking at the version

View File

@@ -32,7 +32,7 @@ import time
import urllib.request
from os import path
from tempfile import TemporaryDirectory
from typing import Any, List, Match, Optional, Union
from typing import Any, Match, Optional, Union
import attr
import click
@@ -884,7 +884,7 @@ def get_changes_for_version(wanted_version: version.Version) -> str:
start_line: int
end_line: Optional[int] = None # Is none if its the last entry
headings: List[VersionSection] = []
headings: list[VersionSection] = []
for i, token in enumerate(tokens):
# We look for level 1 headings (h1 tags).
if token.type != "heading_open" or token.tag != "h1":

View File

@@ -38,7 +38,7 @@ import io
import json
import sys
from collections import defaultdict
from typing import Any, Dict, Iterator, Optional, Tuple
from typing import Any, Iterator, Optional
import git
from packaging import version
@@ -57,7 +57,7 @@ SCHEMA_VERSION_FILES = (
OLDEST_SHOWN_VERSION = version.parse("v1.0")
def get_schema_versions(tag: git.Tag) -> Tuple[Optional[int], Optional[int]]:
def get_schema_versions(tag: git.Tag) -> tuple[Optional[int], Optional[int]]:
"""Get the schema and schema compat versions for a tag."""
schema_version = None
schema_compat_version = None
@@ -81,7 +81,7 @@ def get_schema_versions(tag: git.Tag) -> Tuple[Optional[int], Optional[int]]:
# SCHEMA_COMPAT_VERSION is sometimes across multiple lines, the easist
# thing to do is exec the code. Luckily it has only ever existed in
# a file which imports nothing else from Synapse.
locals: Dict[str, Any] = {}
locals: dict[str, Any] = {}
exec(schema_file.data_stream.read().decode("utf-8"), {}, locals)
schema_version = locals["SCHEMA_VERSION"]
schema_compat_version = locals.get("SCHEMA_COMPAT_VERSION")

View File

@@ -7,18 +7,14 @@ from __future__ import annotations
from typing import (
Any,
Callable,
Dict,
Hashable,
ItemsView,
Iterable,
Iterator,
KeysView,
List,
Mapping,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
ValuesView,
@@ -35,14 +31,14 @@ _VT_co = TypeVar("_VT_co", covariant=True)
_SD = TypeVar("_SD", bound=SortedDict)
_Key = Callable[[_T], Any]
class SortedDict(Dict[_KT, _VT]):
class SortedDict(dict[_KT, _VT]):
@overload
def __init__(self, **kwargs: _VT) -> None: ...
@overload
def __init__(self, __map: Mapping[_KT, _VT], **kwargs: _VT) -> None: ...
@overload
def __init__(
self, __iterable: Iterable[Tuple[_KT, _VT]], **kwargs: _VT
self, __iterable: Iterable[tuple[_KT, _VT]], **kwargs: _VT
) -> None: ...
@overload
def __init__(self, __key: _Key[_KT], **kwargs: _VT) -> None: ...
@@ -52,7 +48,7 @@ class SortedDict(Dict[_KT, _VT]):
) -> None: ...
@overload
def __init__(
self, __key: _Key[_KT], __iterable: Iterable[Tuple[_KT, _VT]], **kwargs: _VT
self, __key: _Key[_KT], __iterable: Iterable[tuple[_KT, _VT]], **kwargs: _VT
) -> None: ...
@property
def key(self) -> Optional[_Key[_KT]]: ...
@@ -84,8 +80,8 @@ class SortedDict(Dict[_KT, _VT]):
def pop(self, key: _KT) -> _VT: ...
@overload
def pop(self, key: _KT, default: _T = ...) -> Union[_VT, _T]: ...
def popitem(self, index: int = ...) -> Tuple[_KT, _VT]: ...
def peekitem(self, index: int = ...) -> Tuple[_KT, _VT]: ...
def popitem(self, index: int = ...) -> tuple[_KT, _VT]: ...
def peekitem(self, index: int = ...) -> tuple[_KT, _VT]: ...
def setdefault(self, key: _KT, default: Optional[_VT] = ...) -> _VT: ...
# Mypy now reports the first overload as an error, because typeshed widened the type
# of `__map` to its internal `_typeshed.SupportsKeysAndGetItem` type in
@@ -102,9 +98,9 @@ class SortedDict(Dict[_KT, _VT]):
# def update(self, **kwargs: _VT) -> None: ...
def __reduce__(
self,
) -> Tuple[
Type[SortedDict[_KT, _VT]],
Tuple[Callable[[_KT], Any], List[Tuple[_KT, _VT]]],
) -> tuple[
type[SortedDict[_KT, _VT]],
tuple[Callable[[_KT], Any], list[tuple[_KT, _VT]]],
]: ...
def __repr__(self) -> str: ...
def _check(self) -> None: ...
@@ -121,20 +117,20 @@ class SortedKeysView(KeysView[_KT_co], Sequence[_KT_co]):
@overload
def __getitem__(self, index: int) -> _KT_co: ...
@overload
def __getitem__(self, index: slice) -> List[_KT_co]: ...
def __getitem__(self, index: slice) -> list[_KT_co]: ...
def __delitem__(self, index: Union[int, slice]) -> None: ...
class SortedItemsView(ItemsView[_KT_co, _VT_co], Sequence[Tuple[_KT_co, _VT_co]]):
def __iter__(self) -> Iterator[Tuple[_KT_co, _VT_co]]: ...
class SortedItemsView(ItemsView[_KT_co, _VT_co], Sequence[tuple[_KT_co, _VT_co]]):
def __iter__(self) -> Iterator[tuple[_KT_co, _VT_co]]: ...
@overload
def __getitem__(self, index: int) -> Tuple[_KT_co, _VT_co]: ...
def __getitem__(self, index: int) -> tuple[_KT_co, _VT_co]: ...
@overload
def __getitem__(self, index: slice) -> List[Tuple[_KT_co, _VT_co]]: ...
def __getitem__(self, index: slice) -> list[tuple[_KT_co, _VT_co]]: ...
def __delitem__(self, index: Union[int, slice]) -> None: ...
class SortedValuesView(ValuesView[_VT_co], Sequence[_VT_co]):
@overload
def __getitem__(self, index: int) -> _VT_co: ...
@overload
def __getitem__(self, index: slice) -> List[_VT_co]: ...
def __getitem__(self, index: slice) -> list[_VT_co]: ...
def __delitem__(self, index: Union[int, slice]) -> None: ...

View File

@@ -9,12 +9,9 @@ from typing import (
Callable,
Iterable,
Iterator,
List,
MutableSequence,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
overload,
@@ -37,11 +34,11 @@ class SortedList(MutableSequence[_T]):
): ...
# NB: currently mypy does not honour return type, see mypy #3307
@overload
def __new__(cls: Type[_SL], iterable: None, key: None) -> _SL: ...
def __new__(cls: type[_SL], iterable: None, key: None) -> _SL: ...
@overload
def __new__(cls: Type[_SL], iterable: None, key: _Key[_T]) -> SortedKeyList[_T]: ...
def __new__(cls: type[_SL], iterable: None, key: _Key[_T]) -> SortedKeyList[_T]: ...
@overload
def __new__(cls: Type[_SL], iterable: Iterable[_T], key: None) -> _SL: ...
def __new__(cls: type[_SL], iterable: Iterable[_T], key: None) -> _SL: ...
@overload
def __new__(cls, iterable: Iterable[_T], key: _Key[_T]) -> SortedKeyList[_T]: ...
@property
@@ -64,11 +61,11 @@ class SortedList(MutableSequence[_T]):
@overload
def __getitem__(self, index: int) -> _T: ...
@overload
def __getitem__(self, index: slice) -> List[_T]: ...
def __getitem__(self, index: slice) -> list[_T]: ...
@overload
def _getitem(self, index: int) -> _T: ...
@overload
def _getitem(self, index: slice) -> List[_T]: ...
def _getitem(self, index: slice) -> list[_T]: ...
@overload
def __setitem__(self, index: int, value: _T) -> None: ...
@overload
@@ -95,7 +92,7 @@ class SortedList(MutableSequence[_T]):
self,
minimum: Optional[int] = ...,
maximum: Optional[int] = ...,
inclusive: Tuple[bool, bool] = ...,
inclusive: tuple[bool, bool] = ...,
reverse: bool = ...,
) -> Iterator[_T]: ...
def bisect_left(self, value: _T) -> int: ...
@@ -151,14 +148,14 @@ class SortedKeyList(SortedList[_T]):
self,
minimum: Optional[int] = ...,
maximum: Optional[int] = ...,
inclusive: Tuple[bool, bool] = ...,
inclusive: tuple[bool, bool] = ...,
reverse: bool = ...,
) -> Iterator[_T]: ...
def irange_key(
self,
min_key: Optional[Any] = ...,
max_key: Optional[Any] = ...,
inclusive: Tuple[bool, bool] = ...,
inclusive: tuple[bool, bool] = ...,
reserve: bool = ...,
) -> Iterator[_T]: ...
def bisect_left(self, value: _T) -> int: ...

View File

@@ -10,13 +10,9 @@ from typing import (
Hashable,
Iterable,
Iterator,
List,
MutableSet,
Optional,
Sequence,
Set,
Tuple,
Type,
TypeVar,
Union,
overload,
@@ -37,7 +33,7 @@ class SortedSet(MutableSet[_T], Sequence[_T]):
) -> None: ...
@classmethod
def _fromset(
cls, values: Set[_T], key: Optional[_Key[_T]] = ...
cls, values: set[_T], key: Optional[_Key[_T]] = ...
) -> SortedSet[_T]: ...
@property
def key(self) -> Optional[_Key[_T]]: ...
@@ -45,7 +41,7 @@ class SortedSet(MutableSet[_T], Sequence[_T]):
@overload
def __getitem__(self, index: int) -> _T: ...
@overload
def __getitem__(self, index: slice) -> List[_T]: ...
def __getitem__(self, index: slice) -> list[_T]: ...
def __delitem__(self, index: Union[int, slice]) -> None: ...
def __eq__(self, other: Any) -> bool: ...
def __ne__(self, other: Any) -> bool: ...
@@ -94,7 +90,7 @@ class SortedSet(MutableSet[_T], Sequence[_T]):
def _update(self, *iterables: Iterable[_S]) -> SortedSet[Union[_T, _S]]: ...
def __reduce__(
self,
) -> Tuple[Type[SortedSet[_T]], Set[_T], Callable[[_T], Any]]: ...
) -> tuple[type[SortedSet[_T]], set[_T], Callable[[_T], Any]]: ...
def __repr__(self) -> str: ...
def _check(self) -> None: ...
def bisect_left(self, value: _T) -> int: ...
@@ -109,7 +105,7 @@ class SortedSet(MutableSet[_T], Sequence[_T]):
self,
minimum: Optional[_T] = ...,
maximum: Optional[_T] = ...,
inclusive: Tuple[bool, bool] = ...,
inclusive: tuple[bool, bool] = ...,
reverse: bool = ...,
) -> Iterator[_T]: ...
def index(

View File

@@ -15,7 +15,7 @@
"""Contains *incomplete* type hints for txredisapi."""
from typing import Any, List, Optional, Type, Union
from typing import Any, Optional, Union
from twisted.internet import protocol
from twisted.internet.defer import Deferred
@@ -39,7 +39,7 @@ class RedisProtocol(protocol.Protocol):
class SubscriberProtocol(RedisProtocol):
def __init__(self, *args: object, **kwargs: object): ...
password: Optional[str]
def subscribe(self, channels: Union[str, List[str]]) -> "Deferred[None]": ...
def subscribe(self, channels: Union[str, list[str]]) -> "Deferred[None]": ...
def connectionMade(self) -> None: ...
# type-ignore: twisted.internet.protocol.Protocol provides a default argument for
# `reason`. txredisapi's LineReceiver Protocol doesn't. But that's fine: it's what's
@@ -69,7 +69,7 @@ class UnixConnectionHandler(ConnectionHandler): ...
class RedisFactory(protocol.ReconnectingClientFactory):
continueTrying: bool
handler: ConnectionHandler
pool: List[RedisProtocol]
pool: list[RedisProtocol]
replyTimeout: Optional[int]
def __init__(
self,
@@ -77,7 +77,7 @@ class RedisFactory(protocol.ReconnectingClientFactory):
dbid: Optional[int],
poolsize: int,
isLazy: bool = False,
handler: Type = ConnectionHandler,
handler: type = ConnectionHandler,
charset: str = "utf-8",
password: Optional[str] = None,
replyTimeout: Optional[int] = None,

View File

@@ -24,7 +24,7 @@
import os
import sys
from typing import Any, Dict
from typing import Any
from PIL import ImageFile
@@ -70,7 +70,7 @@ try:
from canonicaljson import register_preserialisation_callback
from immutabledict import immutabledict
def _immutabledict_cb(d: immutabledict) -> Dict[str, Any]:
def _immutabledict_cb(d: immutabledict) -> dict[str, Any]:
try:
return d._dict
except Exception:

View File

@@ -25,7 +25,7 @@ import logging
import re
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, Iterable, Optional, Pattern, Set, Tuple
from typing import Iterable, Optional, Pattern
import yaml
@@ -81,7 +81,7 @@ class EnumerationResource(HttpServer):
"""
def __init__(self, is_worker: bool) -> None:
self.registrations: Dict[Tuple[str, str], EndpointDescription] = {}
self.registrations: dict[tuple[str, str], EndpointDescription] = {}
self._is_worker = is_worker
def register_paths(
@@ -115,7 +115,7 @@ class EnumerationResource(HttpServer):
def get_registered_paths_for_hs(
hs: HomeServer,
) -> Dict[Tuple[str, str], EndpointDescription]:
) -> dict[tuple[str, str], EndpointDescription]:
"""
Given a homeserver, get all registered endpoints and their descriptions.
"""
@@ -142,7 +142,7 @@ def get_registered_paths_for_hs(
def get_registered_paths_for_default(
worker_app: Optional[str], base_config: HomeServerConfig
) -> Dict[Tuple[str, str], EndpointDescription]:
) -> dict[tuple[str, str], EndpointDescription]:
"""
Given the name of a worker application and a base homeserver configuration,
returns:
@@ -168,9 +168,9 @@ def get_registered_paths_for_default(
def elide_http_methods_if_unconflicting(
registrations: Dict[Tuple[str, str], EndpointDescription],
all_possible_registrations: Dict[Tuple[str, str], EndpointDescription],
) -> Dict[Tuple[str, str], EndpointDescription]:
registrations: dict[tuple[str, str], EndpointDescription],
all_possible_registrations: dict[tuple[str, str], EndpointDescription],
) -> dict[tuple[str, str], EndpointDescription]:
"""
Elides HTTP methods (by replacing them with `*`) if all possible registered methods
can be handled by the worker whose registration map is `registrations`.
@@ -180,13 +180,13 @@ def elide_http_methods_if_unconflicting(
"""
def paths_to_methods_dict(
methods_and_paths: Iterable[Tuple[str, str]],
) -> Dict[str, Set[str]]:
methods_and_paths: Iterable[tuple[str, str]],
) -> dict[str, set[str]]:
"""
Given (method, path) pairs, produces a dict from path to set of methods
available at that path.
"""
result: Dict[str, Set[str]] = {}
result: dict[str, set[str]] = {}
for method, path in methods_and_paths:
result.setdefault(path, set()).add(method)
return result
@@ -210,8 +210,8 @@ def elide_http_methods_if_unconflicting(
def simplify_path_regexes(
registrations: Dict[Tuple[str, str], EndpointDescription],
) -> Dict[Tuple[str, str], EndpointDescription]:
registrations: dict[tuple[str, str], EndpointDescription],
) -> dict[tuple[str, str], EndpointDescription]:
"""
Simplify all the path regexes for the dict of endpoint descriptions,
so that we don't use the Python-specific regex extensions
@@ -270,8 +270,8 @@ def main() -> None:
# TODO SSO endpoints (pick_idp etc) NOT REGISTERED BY THIS SCRIPT
categories_to_methods_and_paths: Dict[
Optional[str], Dict[Tuple[str, str], EndpointDescription]
categories_to_methods_and_paths: dict[
Optional[str], dict[tuple[str, str], EndpointDescription]
] = defaultdict(dict)
for (method, path), desc in elided_worker_paths.items():
@@ -283,7 +283,7 @@ def main() -> None:
def print_category(
category_name: Optional[str],
elided_worker_paths: Dict[Tuple[str, str], EndpointDescription],
elided_worker_paths: dict[tuple[str, str], EndpointDescription],
) -> None:
"""
Prints out a category, in documentation page style.

View File

@@ -26,7 +26,7 @@ import hashlib
import hmac
import logging
import sys
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Optional
import requests
import yaml
@@ -262,7 +262,7 @@ def main() -> None:
args = parser.parse_args()
config: Optional[Dict[str, Any]] = None
config: Optional[dict[str, Any]] = None
if "config" in args and args.config:
config = yaml.safe_load(args.config)
@@ -350,7 +350,7 @@ def _read_file(file_path: Any, config_path: str) -> str:
sys.exit(1)
def _find_client_listener(config: Dict[str, Any]) -> Optional[str]:
def _find_client_listener(config: dict[str, Any]) -> Optional[str]:
# try to find a listener in the config. Returns a host:port pair
for listener in config.get("listeners", []):
if listener.get("type") != "http" or listener.get("tls", False):

View File

@@ -23,7 +23,6 @@ import argparse
import sys
import time
from datetime import datetime
from typing import List
import attr
@@ -50,15 +49,15 @@ class ReviewConfig(RootConfig):
class UserInfo:
user_id: str
creation_ts: int
emails: List[str] = attr.Factory(list)
private_rooms: List[str] = attr.Factory(list)
public_rooms: List[str] = attr.Factory(list)
ips: List[str] = attr.Factory(list)
emails: list[str] = attr.Factory(list)
private_rooms: list[str] = attr.Factory(list)
public_rooms: list[str] = attr.Factory(list)
ips: list[str] = attr.Factory(list)
def get_recent_users(
txn: LoggingTransaction, since_ms: int, exclude_app_service: bool
) -> List[UserInfo]:
) -> list[UserInfo]:
"""Fetches recently registered users and some info on them."""
sql = """

View File

@@ -33,15 +33,10 @@ from typing import (
Any,
Awaitable,
Callable,
Dict,
Generator,
Iterable,
List,
NoReturn,
Optional,
Set,
Tuple,
Type,
TypedDict,
TypeVar,
cast,
@@ -244,7 +239,7 @@ end_error: Optional[str] = None
# not the error then the script will show nothing outside of what's printed in the run
# function. If both are defined, the script will print both the error and the stacktrace.
end_error_exec_info: Optional[
Tuple[Type[BaseException], BaseException, TracebackType]
tuple[type[BaseException], BaseException, TracebackType]
] = None
R = TypeVar("R")
@@ -281,8 +276,8 @@ class Store(
def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]:
return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)
def execute_sql(self, sql: str, *args: object) -> Awaitable[List[Tuple]]:
def r(txn: LoggingTransaction) -> List[Tuple]:
def execute_sql(self, sql: str, *args: object) -> Awaitable[list[tuple]]:
def r(txn: LoggingTransaction) -> list[tuple]:
txn.execute(sql, args)
return txn.fetchall()
@@ -292,8 +287,8 @@ class Store(
self,
txn: LoggingTransaction,
table: str,
headers: List[str],
rows: List[Tuple],
headers: list[str],
rows: list[tuple],
override_system_value: bool = False,
) -> None:
sql = "INSERT INTO %s (%s) %s VALUES (%s)" % (
@@ -330,7 +325,7 @@ class MockHomeserver(HomeServer):
class Porter:
def __init__(
self,
sqlite_config: Dict[str, Any],
sqlite_config: dict[str, Any],
progress: "Progress",
batch_size: int,
hs: HomeServer,
@@ -340,7 +335,7 @@ class Porter:
self.batch_size = batch_size
self.hs = hs
async def setup_table(self, table: str) -> Tuple[str, int, int, int, int]:
async def setup_table(self, table: str) -> tuple[str, int, int, int, int]:
if table in APPEND_ONLY_TABLES:
# It's safe to just carry on inserting.
row = await self.postgres_store.db_pool.simple_select_one(
@@ -403,10 +398,10 @@ class Porter:
return table, already_ported, total_to_port, forward_chunk, backward_chunk
async def get_table_constraints(self) -> Dict[str, Set[str]]:
async def get_table_constraints(self) -> dict[str, set[str]]:
"""Returns a map of tables that have foreign key constraints to tables they depend on."""
def _get_constraints(txn: LoggingTransaction) -> Dict[str, Set[str]]:
def _get_constraints(txn: LoggingTransaction) -> dict[str, set[str]]:
# We can pull the information about foreign key constraints out from
# the postgres schema tables.
sql = """
@@ -422,7 +417,7 @@ class Porter:
"""
txn.execute(sql)
results: Dict[str, Set[str]] = {}
results: dict[str, set[str]] = {}
for table, foreign_table in txn:
results.setdefault(table, set()).add(foreign_table)
return results
@@ -490,7 +485,7 @@ class Porter:
def r(
txn: LoggingTransaction,
) -> Tuple[Optional[List[str]], List[Tuple], List[Tuple]]:
) -> tuple[Optional[list[str]], list[tuple], list[tuple]]:
forward_rows = []
backward_rows = []
if do_forward[0]:
@@ -507,7 +502,7 @@ class Porter:
if forward_rows or backward_rows:
assert txn.description is not None
headers: Optional[List[str]] = [
headers: Optional[list[str]] = [
column[0] for column in txn.description
]
else:
@@ -574,7 +569,7 @@ class Porter:
while True:
def r(txn: LoggingTransaction) -> Tuple[List[str], List[Tuple]]:
def r(txn: LoggingTransaction) -> tuple[list[str], list[tuple]]:
txn.execute(select, (forward_chunk, self.batch_size))
rows = txn.fetchall()
assert txn.description is not None
@@ -956,7 +951,7 @@ class Porter:
self.progress.set_state("Copying to postgres")
constraints = await self.get_table_constraints()
tables_ported = set() # type: Set[str]
tables_ported = set() # type: set[str]
while tables_to_port_info_map:
# Pulls out all tables that are still to be ported and which
@@ -995,8 +990,8 @@ class Porter:
reactor.stop()
def _convert_rows(
self, table: str, headers: List[str], rows: List[Tuple]
) -> List[Tuple]:
self, table: str, headers: list[str], rows: list[tuple]
) -> list[tuple]:
bool_col_names = BOOLEAN_COLUMNS.get(table, [])
bool_cols = [i for i, h in enumerate(headers) if h in bool_col_names]
@@ -1030,7 +1025,7 @@ class Porter:
return outrows
async def _setup_sent_transactions(self) -> Tuple[int, int, int]:
async def _setup_sent_transactions(self) -> tuple[int, int, int]:
# Only save things from the last day
yesterday = int(time.time() * 1000) - 86400000
@@ -1042,7 +1037,7 @@ class Porter:
")"
)
def r(txn: LoggingTransaction) -> Tuple[List[str], List[Tuple]]:
def r(txn: LoggingTransaction) -> tuple[list[str], list[tuple]]:
txn.execute(select)
rows = txn.fetchall()
assert txn.description is not None
@@ -1112,14 +1107,14 @@ class Porter:
self, table: str, forward_chunk: int, backward_chunk: int
) -> int:
frows = cast(
List[Tuple[int]],
list[tuple[int]],
await self.sqlite_store.execute_sql(
"SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk
),
)
brows = cast(
List[Tuple[int]],
list[tuple[int]],
await self.sqlite_store.execute_sql(
"SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), backward_chunk
),
@@ -1136,7 +1131,7 @@ class Porter:
async def _get_total_count_to_port(
self, table: str, forward_chunk: int, backward_chunk: int
) -> Tuple[int, int]:
) -> tuple[int, int]:
remaining, done = await make_deferred_yieldable(
defer.gatherResults(
[
@@ -1221,7 +1216,7 @@ class Porter:
async def _setup_sequence(
self,
sequence_name: str,
stream_id_tables: Iterable[Tuple[str, str]],
stream_id_tables: Iterable[tuple[str, str]],
) -> None:
"""Set a sequence to the correct value."""
current_stream_ids = []
@@ -1331,7 +1326,7 @@ class Progress:
"""Used to report progress of the port"""
def __init__(self) -> None:
self.tables: Dict[str, TableProgress] = {}
self.tables: dict[str, TableProgress] = {}
self.start_time = int(time.time())

View File

@@ -18,7 +18,7 @@
# [This file includes modifications made by New Vector Limited]
#
#
from typing import TYPE_CHECKING, Optional, Protocol, Tuple
from typing import TYPE_CHECKING, Optional, Protocol
from prometheus_client import Histogram
@@ -51,7 +51,7 @@ class Auth(Protocol):
room_id: str,
requester: Requester,
allow_departed_users: bool = False,
) -> Tuple[str, Optional[str]]:
) -> tuple[str, Optional[str]]:
"""Check if the user is in the room, or was at some point.
Args:
room_id: The room to check.
@@ -190,7 +190,7 @@ class Auth(Protocol):
async def check_user_in_room_or_world_readable(
self, room_id: str, requester: Requester, allow_departed_users: bool = False
) -> Tuple[str, Optional[str]]:
) -> tuple[str, Optional[str]]:
"""Checks that the user is or was in the room or the room is world
readable. If it isn't then an exception is raised.

View File

@@ -19,7 +19,7 @@
#
#
import logging
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, Optional
from netaddr import IPAddress
@@ -64,7 +64,7 @@ class BaseAuth:
room_id: str,
requester: Requester,
allow_departed_users: bool = False,
) -> Tuple[str, Optional[str]]:
) -> tuple[str, Optional[str]]:
"""Check if the user is in the room, or was at some point.
Args:
room_id: The room to check.
@@ -114,7 +114,7 @@ class BaseAuth:
@trace
async def check_user_in_room_or_world_readable(
self, room_id: str, requester: Requester, allow_departed_users: bool = False
) -> Tuple[str, Optional[str]]:
) -> tuple[str, Optional[str]]:
"""Checks that the user is or was in the room or the room is world
readable. If it isn't then an exception is raised.

View File

@@ -13,7 +13,7 @@
#
#
import logging
from typing import TYPE_CHECKING, Optional, Set
from typing import TYPE_CHECKING, Optional
from urllib.parse import urlencode
from synapse._pydantic_compat import (
@@ -369,7 +369,7 @@ class MasDelegatedAuth(BaseAuth):
# We only allow a single device_id in the scope, so we find them all in the
# scope list, and raise if there are more than one. The OIDC server should be
# the one enforcing valid scopes, so we raise a 500 if we find an invalid scope.
device_ids: Set[str] = set()
device_ids: set[str] = set()
for tok in scope:
if tok.startswith(UNSTABLE_SCOPE_MATRIX_DEVICE_PREFIX):
device_ids.add(tok[len(UNSTABLE_SCOPE_MATRIX_DEVICE_PREFIX) :])

View File

@@ -20,7 +20,7 @@
#
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set
from typing import TYPE_CHECKING, Any, Callable, Optional
from urllib.parse import urlencode
from authlib.oauth2 import ClientAuth
@@ -70,7 +70,7 @@ STABLE_SCOPE_MATRIX_DEVICE_PREFIX = "urn:matrix:client:device:"
SCOPE_SYNAPSE_ADMIN = "urn:synapse:admin:*"
def scope_to_list(scope: str) -> List[str]:
def scope_to_list(scope: str) -> list[str]:
"""Convert a scope string to a list of scope tokens"""
return scope.strip().split(" ")
@@ -96,7 +96,7 @@ class IntrospectionResult:
absolute_expiry_ms = expires_in * 1000 + self.retrieved_at_ms
return now_ms < absolute_expiry_ms
def get_scope_list(self) -> List[str]:
def get_scope_list(self) -> list[str]:
value = self._inner.get("scope")
if not isinstance(value, str):
return []
@@ -264,7 +264,7 @@ class MSC3861DelegatedAuth(BaseAuth):
logger.warning("Failed to load metadata:", exc_info=True)
return None
async def auth_metadata(self) -> Dict[str, Any]:
async def auth_metadata(self) -> dict[str, Any]:
"""
Returns the auth metadata dict
"""
@@ -303,7 +303,7 @@ class MSC3861DelegatedAuth(BaseAuth):
# By default, we shouldn't cache the result unless we know it's valid
cache_context.should_cache = False
introspection_endpoint = await self._introspection_endpoint()
raw_headers: Dict[str, str] = {
raw_headers: dict[str, str] = {
"Content-Type": "application/x-www-form-urlencoded",
"Accept": "application/json",
# Tell MAS that we support reading the device ID as an explicit
@@ -520,7 +520,7 @@ class MSC3861DelegatedAuth(BaseAuth):
raise InvalidClientTokenError("Token is not active")
# Let's look at the scope
scope: List[str] = introspection_result.get_scope_list()
scope: list[str] = introspection_result.get_scope_list()
# Determine type of user based on presence of particular scopes
has_user_scope = (
@@ -575,7 +575,7 @@ class MSC3861DelegatedAuth(BaseAuth):
# We only allow a single device_id in the scope, so we find them all in the
# scope list, and raise if there are more than one. The OIDC server should be
# the one enforcing valid scopes, so we raise a 500 if we find an invalid scope.
device_ids: Set[str] = set()
device_ids: set[str] = set()
for tok in scope:
if tok.startswith(UNSTABLE_SCOPE_MATRIX_DEVICE_PREFIX):
device_ids.add(tok[len(UNSTABLE_SCOPE_MATRIX_DEVICE_PREFIX) :])

View File

@@ -26,7 +26,7 @@ import math
import typing
from enum import Enum
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional, Union
from twisted.web import http
@@ -166,7 +166,7 @@ class CodeMessageException(RuntimeError):
self,
code: Union[int, HTTPStatus],
msg: str,
headers: Optional[Dict[str, str]] = None,
headers: Optional[dict[str, str]] = None,
):
super().__init__("%d: %s" % (code, msg))
@@ -201,7 +201,7 @@ class RedirectException(CodeMessageException):
super().__init__(code=http_code, msg=msg)
self.location = location
self.cookies: List[bytes] = []
self.cookies: list[bytes] = []
class SynapseError(CodeMessageException):
@@ -223,8 +223,8 @@ class SynapseError(CodeMessageException):
code: int,
msg: str,
errcode: str = Codes.UNKNOWN,
additional_fields: Optional[Dict] = None,
headers: Optional[Dict[str, str]] = None,
additional_fields: Optional[dict] = None,
headers: Optional[dict[str, str]] = None,
):
"""Constructs a synapse error.
@@ -236,7 +236,7 @@ class SynapseError(CodeMessageException):
super().__init__(code, msg, headers)
self.errcode = errcode
if additional_fields is None:
self._additional_fields: Dict = {}
self._additional_fields: dict = {}
else:
self._additional_fields = dict(additional_fields)
@@ -276,7 +276,7 @@ class ProxiedRequestError(SynapseError):
code: int,
msg: str,
errcode: str = Codes.UNKNOWN,
additional_fields: Optional[Dict] = None,
additional_fields: Optional[dict] = None,
):
super().__init__(code, msg, errcode, additional_fields)
@@ -409,7 +409,7 @@ class OAuthInsufficientScopeError(SynapseError):
def __init__(
self,
required_scopes: List[str],
required_scopes: list[str],
):
headers = {
"WWW-Authenticate": 'Bearer error="insufficient_scope", scope="%s"'

View File

@@ -26,12 +26,9 @@ from typing import (
Awaitable,
Callable,
Collection,
Dict,
Iterable,
List,
Mapping,
Optional,
Set,
TypeVar,
Union,
)
@@ -248,34 +245,34 @@ class FilterCollection:
async def filter_presence(
self, presence_states: Iterable[UserPresenceState]
) -> List[UserPresenceState]:
) -> list[UserPresenceState]:
return await self._presence_filter.filter(presence_states)
async def filter_global_account_data(
self, events: Iterable[JsonDict]
) -> List[JsonDict]:
) -> list[JsonDict]:
return await self._global_account_data_filter.filter(events)
async def filter_room_state(self, events: Iterable[EventBase]) -> List[EventBase]:
async def filter_room_state(self, events: Iterable[EventBase]) -> list[EventBase]:
return await self._room_state_filter.filter(
await self._room_filter.filter(events)
)
async def filter_room_timeline(
self, events: Iterable[EventBase]
) -> List[EventBase]:
) -> list[EventBase]:
return await self._room_timeline_filter.filter(
await self._room_filter.filter(events)
)
async def filter_room_ephemeral(self, events: Iterable[JsonDict]) -> List[JsonDict]:
async def filter_room_ephemeral(self, events: Iterable[JsonDict]) -> list[JsonDict]:
return await self._room_ephemeral_filter.filter(
await self._room_filter.filter(events)
)
async def filter_room_account_data(
self, events: Iterable[JsonDict]
) -> List[JsonDict]:
) -> list[JsonDict]:
return await self._room_account_data_filter.filter(
await self._room_filter.filter(events)
)
@@ -440,7 +437,7 @@ class Filter:
return True
def _check_fields(self, field_matchers: Dict[str, Callable[[str], bool]]) -> bool:
def _check_fields(self, field_matchers: dict[str, Callable[[str], bool]]) -> bool:
"""Checks whether the filter matches the given event fields.
Args:
@@ -474,7 +471,7 @@ class Filter:
# Otherwise, accept it.
return True
def filter_rooms(self, room_ids: Iterable[str]) -> Set[str]:
def filter_rooms(self, room_ids: Iterable[str]) -> set[str]:
"""Apply the 'rooms' filter to a given list of rooms.
Args:
@@ -496,7 +493,7 @@ class Filter:
async def _check_event_relations(
self, events: Collection[FilterEvent]
) -> List[FilterEvent]:
) -> list[FilterEvent]:
# The event IDs to check, mypy doesn't understand the isinstance check.
event_ids = [event.event_id for event in events if isinstance(event, EventBase)] # type: ignore[attr-defined]
event_ids_to_keep = set(
@@ -511,7 +508,7 @@ class Filter:
if not isinstance(event, EventBase) or event.event_id in event_ids_to_keep
]
async def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
async def filter(self, events: Iterable[FilterEvent]) -> list[FilterEvent]:
result = [event for event in events if self._check(event)]
if self.related_by_senders or self.related_by_rel_types:

View File

@@ -20,7 +20,7 @@
#
#
from typing import TYPE_CHECKING, Dict, Hashable, Optional, Tuple
from typing import TYPE_CHECKING, Hashable, Optional
from synapse.api.errors import LimitExceededError
from synapse.config.ratelimiting import RatelimitSettings
@@ -92,7 +92,7 @@ class Ratelimiter:
# * The number of tokens currently in the bucket,
# * The time point when the bucket was last completely empty, and
# * The rate_hz (leak rate) of this particular bucket.
self.actions: Dict[Hashable, Tuple[float, float, float]] = {}
self.actions: dict[Hashable, tuple[float, float, float]] = {}
self.clock.looping_call(self._prune_message_counts, 60 * 1000)
@@ -109,7 +109,7 @@ class Ratelimiter:
def _get_action_counts(
self, key: Hashable, time_now_s: float
) -> Tuple[float, float, float]:
) -> tuple[float, float, float]:
"""Retrieve the action counts, with a fallback representing an empty bucket."""
return self.actions.get(key, (0.0, time_now_s, 0.0))
@@ -122,7 +122,7 @@ class Ratelimiter:
update: bool = True,
n_actions: int = 1,
_time_now_s: Optional[float] = None,
) -> Tuple[bool, float]:
) -> tuple[bool, float]:
"""Can the entity (e.g. user or IP address) perform the action?
Checks if the user has ratelimiting disabled in the database by looking

View File

@@ -18,7 +18,7 @@
#
#
from typing import Callable, Dict, Optional, Tuple
from typing import Callable, Optional
import attr
@@ -109,7 +109,7 @@ class RoomVersion:
# is not enough to mark it "supported": the push rule evaluator also needs to
# support the flag. Unknown flags are ignored by the evaluator, making conditions
# fail if used.
msc3931_push_features: Tuple[str, ...] # values from PushRuleRoomFlag
msc3931_push_features: tuple[str, ...] # values from PushRuleRoomFlag
# MSC3757: Restricting who can overwrite a state event
msc3757_enabled: bool
# MSC4289: Creator power enabled
@@ -476,7 +476,7 @@ class RoomVersions:
)
KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = {
KNOWN_ROOM_VERSIONS: dict[str, RoomVersion] = {
v.identifier: v
for v in (
RoomVersions.V1,

View File

@@ -34,11 +34,8 @@ from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
NoReturn,
Optional,
Tuple,
cast,
)
from wsgiref.simple_server import WSGIServer
@@ -98,8 +95,8 @@ reactor = cast(ISynapseReactor, _reactor)
logger = logging.getLogger(__name__)
_instance_id_to_sighup_callbacks_map: Dict[
str, List[Tuple[Callable[..., None], Tuple[object, ...], Dict[str, object]]]
_instance_id_to_sighup_callbacks_map: dict[
str, list[tuple[Callable[..., None], tuple[object, ...], dict[str, object]]]
] = {}
"""
Map from homeserver instance_id to a list of callbacks.
@@ -176,7 +173,7 @@ def start_worker_reactor(
def start_reactor(
appname: str,
soft_file_limit: int,
gc_thresholds: Optional[Tuple[int, int, int]],
gc_thresholds: Optional[tuple[int, int, int]],
pid_file: Optional[str],
daemonize: bool,
print_pidfile: bool,
@@ -309,7 +306,7 @@ def register_start(
def listen_metrics(
bind_addresses: StrCollection, port: int
) -> List[Tuple[WSGIServer, Thread]]:
) -> list[tuple[WSGIServer, Thread]]:
"""
Start Prometheus metrics server.
@@ -330,7 +327,7 @@ def listen_metrics(
from synapse.metrics import RegistryProxy
servers: List[Tuple[WSGIServer, Thread]] = []
servers: list[tuple[WSGIServer, Thread]] = []
for host in bind_addresses:
logger.info("Starting metrics listener on %s:%d", host, port)
server, thread = start_http_server_prometheus(
@@ -345,7 +342,7 @@ def listen_manhole(
port: int,
manhole_settings: ManholeConfig,
manhole_globals: dict,
) -> List[Port]:
) -> list[Port]:
# twisted.conch.manhole 21.1.0 uses "int_from_bytes", which produces a confusing
# warning. It's fixed by https://github.com/twisted/twisted/pull/1522), so
# suppress the warning for now.
@@ -370,7 +367,7 @@ def listen_tcp(
factory: ServerFactory,
reactor: IReactorTCP = reactor,
backlog: int = 50,
) -> List[Port]:
) -> list[Port]:
"""
Create a TCP socket for a port and several addresses
@@ -395,7 +392,7 @@ def listen_unix(
factory: ServerFactory,
reactor: IReactorUNIX = reactor,
backlog: int = 50,
) -> List[Port]:
) -> list[Port]:
"""
Create a UNIX socket for a given path and 'mode' permission
@@ -419,7 +416,7 @@ def listen_http(
max_request_body_size: int,
context_factory: Optional[IOpenSSLContextFactory],
reactor: ISynapseReactor = reactor,
) -> List[Port]:
) -> list[Port]:
"""
Args:
listener_config: TODO
@@ -489,7 +486,7 @@ def listen_ssl(
context_factory: IOpenSSLContextFactory,
reactor: IReactorSSL = reactor,
backlog: int = 50,
) -> List[Port]:
) -> list[Port]:
"""
Create an TLS-over-TCP socket for a port and several addresses

View File

@@ -24,7 +24,7 @@ import logging
import os
import sys
import tempfile
from typing import List, Mapping, Optional, Sequence, Tuple
from typing import Mapping, Optional, Sequence
from twisted.internet import defer, task
@@ -150,7 +150,7 @@ class FileExfiltrationWriter(ExfiltrationWriter):
if list(os.listdir(self.base_directory)):
raise Exception("Directory must be empty")
def write_events(self, room_id: str, events: List[EventBase]) -> None:
def write_events(self, room_id: str, events: list[EventBase]) -> None:
room_directory = os.path.join(self.base_directory, "rooms", room_id)
os.makedirs(room_directory, exist_ok=True)
events_file = os.path.join(room_directory, "events")
@@ -255,7 +255,7 @@ class FileExfiltrationWriter(ExfiltrationWriter):
return self.base_directory
def load_config(argv_options: List[str]) -> Tuple[HomeServerConfig, argparse.Namespace]:
def load_config(argv_options: list[str]) -> tuple[HomeServerConfig, argparse.Namespace]:
parser = argparse.ArgumentParser(description="Synapse Admin Command")
HomeServerConfig.add_arguments_to_parser(parser)

View File

@@ -26,13 +26,13 @@ import os
import signal
import sys
from types import FrameType
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Optional
from twisted.internet.main import installReactor
# a list of the original signal handlers, before we installed our custom ones.
# We restore these in our child processes.
_original_signal_handlers: Dict[int, Any] = {}
_original_signal_handlers: dict[int, Any] = {}
class ProxiedReactor:
@@ -72,7 +72,7 @@ class ProxiedReactor:
def _worker_entrypoint(
func: Callable[[], None], proxy_reactor: ProxiedReactor, args: List[str]
func: Callable[[], None], proxy_reactor: ProxiedReactor, args: list[str]
) -> None:
"""
Entrypoint for a forked worker process.
@@ -128,7 +128,7 @@ def main() -> None:
# Split up the subsequent arguments into each workers' arguments;
# `--` is our delimiter of choice.
args_by_worker: List[List[str]] = [
args_by_worker: list[list[str]] = [
list(args)
for cond, args in itertools.groupby(ns.args, lambda ele: ele != "--")
if cond and args
@@ -167,7 +167,7 @@ def main() -> None:
update_proc.join()
print("===== PREPARED DATABASE =====", file=sys.stderr)
processes: List[multiprocessing.Process] = []
processes: list[multiprocessing.Process] = []
# Install signal handlers to propagate signals to all our children, so that they
# shut down cleanly. This also inhibits our own exit, but that's good: we want to

View File

@@ -21,7 +21,6 @@
#
import logging
import sys
from typing import Dict, List
from twisted.web.resource import Resource
@@ -181,7 +180,7 @@ class GenericWorkerServer(HomeServer):
# We always include an admin resource that we populate with servlets as needed
admin_resource = JsonResource(self, canonical_json=False)
resources: Dict[str, Resource] = {
resources: dict[str, Resource] = {
# We always include a health resource.
"/health": HealthResource(),
"/_synapse/admin": admin_resource,
@@ -314,7 +313,7 @@ class GenericWorkerServer(HomeServer):
self.get_replication_command_handler().start_replication(self)
def load_config(argv_options: List[str]) -> HomeServerConfig:
def load_config(argv_options: list[str]) -> HomeServerConfig:
"""
Parse the commandline and config files (does not generate config)

View File

@@ -22,7 +22,7 @@
import logging
import os
import sys
from typing import Dict, Iterable, List, Optional
from typing import Iterable, Optional
from twisted.internet.tcp import Port
from twisted.web.resource import EncodingResourceWrapper, Resource
@@ -99,7 +99,7 @@ class SynapseHomeServer(HomeServer):
site_tag = listener_config.get_site_tag()
# We always include a health resource.
resources: Dict[str, Resource] = {"/health": HealthResource()}
resources: dict[str, Resource] = {"/health": HealthResource()}
for res in listener_config.http_options.resources:
for name in res.names:
@@ -170,7 +170,7 @@ class SynapseHomeServer(HomeServer):
def _configure_named_resource(
self, name: str, compress: bool = False
) -> Dict[str, Resource]:
) -> dict[str, Resource]:
"""Build a resource map for a named resource
Args:
@@ -180,7 +180,7 @@ class SynapseHomeServer(HomeServer):
Returns:
map from path to HTTP resource
"""
resources: Dict[str, Resource] = {}
resources: dict[str, Resource] = {}
if name == "client":
client_resource: Resource = ClientRestResource(self)
if compress:
@@ -318,7 +318,7 @@ class SynapseHomeServer(HomeServer):
logger.warning("Unrecognized listener type: %s", listener.type)
def load_or_generate_config(argv_options: List[str]) -> HomeServerConfig:
def load_or_generate_config(argv_options: list[str]) -> HomeServerConfig:
"""
Parse the commandline and config files

View File

@@ -22,7 +22,7 @@ import logging
import math
import resource
import sys
from typing import TYPE_CHECKING, List, Mapping, Sized, Tuple
from typing import TYPE_CHECKING, Mapping, Sized
from prometheus_client import Gauge
@@ -54,7 +54,7 @@ Phone home stats are sent every 3 hours
# Contains the list of processes we will be monitoring
# currently either 0 or 1
_stats_process: List[Tuple[int, "resource.struct_rusage"]] = []
_stats_process: list[tuple[int, "resource.struct_rusage"]] = []
# Gauges to expose monthly active user control metrics
current_mau_gauge = Gauge(
@@ -82,12 +82,12 @@ registered_reserved_users_mau_gauge = Gauge(
def phone_stats_home(
hs: "HomeServer",
stats: JsonDict,
stats_process: List[Tuple[int, "resource.struct_rusage"]] = _stats_process,
stats_process: list[tuple[int, "resource.struct_rusage"]] = _stats_process,
) -> "defer.Deferred[None]":
async def _phone_stats_home(
hs: "HomeServer",
stats: JsonDict,
stats_process: List[Tuple[int, "resource.struct_rusage"]] = _stats_process,
stats_process: list[tuple[int, "resource.struct_rusage"]] = _stats_process,
) -> None:
"""Collect usage statistics and send them to the configured endpoint.

View File

@@ -25,9 +25,7 @@ import re
from enum import Enum
from typing import (
TYPE_CHECKING,
Dict,
Iterable,
List,
Optional,
Pattern,
Sequence,
@@ -59,11 +57,11 @@ logger = logging.getLogger(__name__)
# Type for the `device_one_time_keys_count` field in an appservice transaction
# user ID -> {device ID -> {algorithm -> count}}
TransactionOneTimeKeysCount = Dict[str, Dict[str, Dict[str, int]]]
TransactionOneTimeKeysCount = dict[str, dict[str, dict[str, int]]]
# Type for the `device_unused_fallback_key_types` field in an appservice transaction
# user ID -> {device ID -> [algorithm]}
TransactionUnusedFallbackKeys = Dict[str, Dict[str, List[str]]]
TransactionUnusedFallbackKeys = dict[str, dict[str, list[str]]]
class ApplicationServiceState(Enum):
@@ -145,7 +143,7 @@ class ApplicationService:
def _check_namespaces(
self, namespaces: Optional[JsonDict]
) -> Dict[str, List[Namespace]]:
) -> dict[str, list[Namespace]]:
# Sanity check that it is of the form:
# {
# users: [ {regex: "[A-z]+.*", exclusive: true}, ...],
@@ -155,7 +153,7 @@ class ApplicationService:
if namespaces is None:
namespaces = {}
result: Dict[str, List[Namespace]] = {}
result: dict[str, list[Namespace]] = {}
for ns in ApplicationService.NS_LIST:
result[ns] = []
@@ -388,7 +386,7 @@ class ApplicationService:
def is_exclusive_room(self, room_id: str) -> bool:
return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
def get_exclusive_user_regexes(self) -> List[Pattern[str]]:
def get_exclusive_user_regexes(self) -> list[Pattern[str]]:
"""Get the list of regexes used to determine if a user is exclusively
registered by the AS
"""
@@ -417,8 +415,8 @@ class AppServiceTransaction:
service: ApplicationService,
id: int,
events: Sequence[EventBase],
ephemeral: List[JsonMapping],
to_device_messages: List[JsonMapping],
ephemeral: list[JsonMapping],
to_device_messages: list[JsonMapping],
one_time_keys_count: TransactionOneTimeKeysCount,
unused_fallback_keys: TransactionUnusedFallbackKeys,
device_list_summary: DeviceListUpdates,

View File

@@ -23,13 +23,10 @@ import logging
import urllib.parse
from typing import (
TYPE_CHECKING,
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)
@@ -133,14 +130,14 @@ class ApplicationServiceApi(SimpleHttpClient):
self.clock = hs.get_clock()
self.config = hs.config.appservice
self.protocol_meta_cache: ResponseCache[Tuple[str, str]] = ResponseCache(
self.protocol_meta_cache: ResponseCache[tuple[str, str]] = ResponseCache(
clock=hs.get_clock(),
name="as_protocol_meta",
server_name=self.server_name,
timeout_ms=HOUR_IN_MS,
)
def _get_headers(self, service: "ApplicationService") -> Dict[bytes, List[bytes]]:
def _get_headers(self, service: "ApplicationService") -> dict[bytes, list[bytes]]:
"""This makes sure we have always the auth header and opentracing headers set."""
# This is also ensured before in the functions. However this is needed to please
@@ -210,8 +207,8 @@ class ApplicationServiceApi(SimpleHttpClient):
service: "ApplicationService",
kind: str,
protocol: str,
fields: Dict[bytes, List[bytes]],
) -> List[JsonDict]:
fields: dict[bytes, list[bytes]],
) -> list[JsonDict]:
if kind == ThirdPartyEntityKind.USER:
required_field = "userid"
elif kind == ThirdPartyEntityKind.LOCATION:
@@ -225,7 +222,7 @@ class ApplicationServiceApi(SimpleHttpClient):
assert service.hs_token is not None
try:
args: Mapping[bytes, Union[List[bytes], str]] = fields
args: Mapping[bytes, Union[list[bytes], str]] = fields
if self.config.use_appservice_legacy_authorization:
args = {
**fields,
@@ -320,8 +317,8 @@ class ApplicationServiceApi(SimpleHttpClient):
self,
service: "ApplicationService",
events: Sequence[EventBase],
ephemeral: List[JsonMapping],
to_device_messages: List[JsonMapping],
ephemeral: list[JsonMapping],
to_device_messages: list[JsonMapping],
one_time_keys_count: TransactionOneTimeKeysCount,
unused_fallback_keys: TransactionUnusedFallbackKeys,
device_list_summary: DeviceListUpdates,
@@ -429,9 +426,9 @@ class ApplicationServiceApi(SimpleHttpClient):
return False
async def claim_client_keys(
self, service: "ApplicationService", query: List[Tuple[str, str, str, int]]
) -> Tuple[
Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
self, service: "ApplicationService", query: list[tuple[str, str, str, int]]
) -> tuple[
dict[str, dict[str, dict[str, JsonDict]]], list[tuple[str, str, str, int]]
]:
"""Claim one time keys from an application service.
@@ -457,7 +454,7 @@ class ApplicationServiceApi(SimpleHttpClient):
assert service.hs_token is not None
# Create the expected payload shape.
body: Dict[str, Dict[str, List[str]]] = {}
body: dict[str, dict[str, list[str]]] = {}
for user_id, device, algorithm, count in query:
body.setdefault(user_id, {}).setdefault(device, []).extend(
[algorithm] * count
@@ -502,8 +499,8 @@ class ApplicationServiceApi(SimpleHttpClient):
return response, missing
async def query_keys(
self, service: "ApplicationService", query: Dict[str, List[str]]
) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
self, service: "ApplicationService", query: dict[str, list[str]]
) -> dict[str, dict[str, dict[str, JsonDict]]]:
"""Query the application service for keys.
Note that any error (including a timeout) is treated as the application
@@ -545,7 +542,7 @@ class ApplicationServiceApi(SimpleHttpClient):
def _serialize(
self, service: "ApplicationService", events: Iterable[EventBase]
) -> List[JsonDict]:
) -> list[JsonDict]:
time_now = self.clock.time_msec()
return [
serialize_event(

View File

@@ -61,13 +61,9 @@ from typing import (
Awaitable,
Callable,
Collection,
Dict,
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
)
from twisted.internet.interfaces import IDelayedCall
@@ -183,16 +179,16 @@ class _ServiceQueuer:
def __init__(self, txn_ctrl: "_TransactionController", hs: "HomeServer"):
# dict of {service_id: [events]}
self.queued_events: Dict[str, List[EventBase]] = {}
self.queued_events: dict[str, list[EventBase]] = {}
# dict of {service_id: [events]}
self.queued_ephemeral: Dict[str, List[JsonMapping]] = {}
self.queued_ephemeral: dict[str, list[JsonMapping]] = {}
# dict of {service_id: [to_device_message_json]}
self.queued_to_device_messages: Dict[str, List[JsonMapping]] = {}
self.queued_to_device_messages: dict[str, list[JsonMapping]] = {}
# dict of {service_id: [device_list_summary]}
self.queued_device_list_summaries: Dict[str, List[DeviceListUpdates]] = {}
self.queued_device_list_summaries: dict[str, list[DeviceListUpdates]] = {}
# the appservices which currently have a transaction in flight
self.requests_in_flight: Set[str] = set()
self.requests_in_flight: set[str] = set()
self.txn_ctrl = txn_ctrl
self._msc3202_transaction_extensions_enabled: bool = (
hs.config.experimental.msc3202_transaction_extensions
@@ -302,7 +298,7 @@ class _ServiceQueuer:
events: Iterable[EventBase],
ephemerals: Iterable[JsonMapping],
to_device_messages: Iterable[JsonMapping],
) -> Tuple[TransactionOneTimeKeysCount, TransactionUnusedFallbackKeys]:
) -> tuple[TransactionOneTimeKeysCount, TransactionUnusedFallbackKeys]:
"""
Given a list of the events, ephemeral messages and to-device messages,
- first computes a list of application services users that may have
@@ -313,14 +309,14 @@ class _ServiceQueuer:
"""
# Set of 'interesting' users who may have updates
users: Set[str] = set()
users: set[str] = set()
# The sender is always included
users.add(service.sender.to_string())
# All AS users that would receive the PDUs or EDUs sent to these rooms
# are classed as 'interesting'.
rooms_of_interesting_users: Set[str] = set()
rooms_of_interesting_users: set[str] = set()
# PDUs
rooms_of_interesting_users.update(event.room_id for event in events)
# EDUs
@@ -364,7 +360,7 @@ class _TransactionController:
self.as_api = hs.get_application_service_api()
# map from service id to recoverer instance
self.recoverers: Dict[str, "_Recoverer"] = {}
self.recoverers: dict[str, "_Recoverer"] = {}
# for UTs
self.RECOVERER_CLASS = _Recoverer
@@ -373,8 +369,8 @@ class _TransactionController:
self,
service: ApplicationService,
events: Sequence[EventBase],
ephemeral: Optional[List[JsonMapping]] = None,
to_device_messages: Optional[List[JsonMapping]] = None,
ephemeral: Optional[list[JsonMapping]] = None,
to_device_messages: Optional[list[JsonMapping]] = None,
one_time_keys_count: Optional[TransactionOneTimeKeysCount] = None,
unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None,
device_list_summary: Optional[DeviceListUpdates] = None,

View File

@@ -20,13 +20,12 @@
#
#
import sys
from typing import List
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
def main(args: List[str]) -> None:
def main(args: list[str]) -> None:
action = args[1] if len(args) > 1 and args[1] == "read" else None
# If we're reading a key in the config file, then `args[1]` will be `read` and `args[2]`
# will be the key to read.

View File

@@ -33,14 +33,10 @@ from textwrap import dedent
from typing import (
Any,
ClassVar,
Dict,
Iterable,
Iterator,
List,
MutableMapping,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
@@ -321,9 +317,9 @@ class Config:
def read_templates(
self,
filenames: List[str],
filenames: list[str],
custom_template_directories: Optional[Iterable[str]] = None,
) -> List[jinja2.Template]:
) -> list[jinja2.Template]:
"""Load a list of template files from disk using the given variables.
This function will attempt to load the given templates from the default Synapse
@@ -402,7 +398,7 @@ class RootConfig:
class, lower-cased and with "Config" removed.
"""
config_classes: List[Type[Config]] = []
config_classes: list[type[Config]] = []
def __init__(self, config_files: StrSequence = ()):
# Capture absolute paths here, so we can reload config after we daemonize.
@@ -471,7 +467,7 @@ class RootConfig:
generate_secrets: bool = False,
report_stats: Optional[bool] = None,
open_private_ports: bool = False,
listeners: Optional[List[dict]] = None,
listeners: Optional[list[dict]] = None,
tls_certificate_path: Optional[str] = None,
tls_private_key_path: Optional[str] = None,
) -> str:
@@ -545,7 +541,7 @@ class RootConfig:
@classmethod
def load_config(
cls: Type[TRootConfig], description: str, argv_options: List[str]
cls: type[TRootConfig], description: str, argv_options: list[str]
) -> TRootConfig:
"""Parse the commandline and config files
@@ -605,8 +601,8 @@ class RootConfig:
@classmethod
def load_config_with_parser(
cls: Type[TRootConfig], parser: argparse.ArgumentParser, argv_options: List[str]
) -> Tuple[TRootConfig, argparse.Namespace]:
cls: type[TRootConfig], parser: argparse.ArgumentParser, argv_options: list[str]
) -> tuple[TRootConfig, argparse.Namespace]:
"""Parse the commandline and config files with the given parser
Doesn't support config-file-generation: used by the worker apps.
@@ -658,7 +654,7 @@ class RootConfig:
@classmethod
def load_or_generate_config(
cls: Type[TRootConfig], description: str, argv_options: List[str]
cls: type[TRootConfig], description: str, argv_options: list[str]
) -> Optional[TRootConfig]:
"""Parse the commandline and config files
@@ -858,7 +854,7 @@ class RootConfig:
def parse_config_dict(
self,
config_dict: Dict[str, Any],
config_dict: dict[str, Any],
config_dir_path: str,
data_dir_path: str,
allow_secrets_in_config: bool = True,
@@ -883,7 +879,7 @@ class RootConfig:
)
def generate_missing_files(
self, config_dict: Dict[str, Any], config_dir_path: str
self, config_dict: dict[str, Any], config_dir_path: str
) -> None:
self.invoke_all("generate_files", config_dict, config_dir_path)
@@ -930,7 +926,7 @@ class RootConfig:
"""
def read_config_files(config_files: Iterable[str]) -> Dict[str, Any]:
def read_config_files(config_files: Iterable[str]) -> dict[str, Any]:
"""Read the config files and shallowly merge them into a dict.
Successive configurations are shallowly merged into ones provided earlier,
@@ -964,7 +960,7 @@ def read_config_files(config_files: Iterable[str]) -> Dict[str, Any]:
return specified_config
def find_config_files(search_paths: List[str]) -> List[str]:
def find_config_files(search_paths: list[str]) -> list[str]:
"""Finds config files using a list of search paths. If a path is a file
then that file path is added to the list. If a search path is a directory
then all the "*.yaml" files in that directory are added to the list in
@@ -1018,7 +1014,7 @@ class ShardedWorkerHandlingConfig:
below).
"""
instances: List[str]
instances: list[str]
def should_handle(self, instance_name: str, key: str) -> bool:
"""Whether this instance is responsible for handling the given key."""

View File

@@ -2,15 +2,11 @@ import argparse
from typing import (
Any,
Collection,
Dict,
Iterable,
Iterator,
List,
Literal,
MutableMapping,
Optional,
Tuple,
Type,
TypeVar,
Union,
overload,
@@ -129,8 +125,8 @@ class RootConfig:
mas: mas.MasConfig
matrix_rtc: matrixrtc.MatrixRtcConfig
config_classes: List[Type["Config"]] = ...
config_files: List[str]
config_classes: list[type["Config"]] = ...
config_files: list[str]
def __init__(self, config_files: Collection[str] = ...) -> None: ...
def invoke_all(
self, func_name: str, *args: Any, **kwargs: Any
@@ -139,7 +135,7 @@ class RootConfig:
def invoke_all_static(cls, func_name: str, *args: Any, **kwargs: Any) -> None: ...
def parse_config_dict(
self,
config_dict: Dict[str, Any],
config_dict: dict[str, Any],
config_dir_path: str,
data_dir_path: str,
allow_secrets_in_config: bool = ...,
@@ -158,11 +154,11 @@ class RootConfig:
) -> str: ...
@classmethod
def load_or_generate_config(
cls: Type[TRootConfig], description: str, argv_options: List[str]
cls: type[TRootConfig], description: str, argv_options: list[str]
) -> Optional[TRootConfig]: ...
@classmethod
def load_config(
cls: Type[TRootConfig], description: str, argv_options: List[str]
cls: type[TRootConfig], description: str, argv_options: list[str]
) -> TRootConfig: ...
@classmethod
def add_arguments_to_parser(
@@ -170,8 +166,8 @@ class RootConfig:
) -> None: ...
@classmethod
def load_config_with_parser(
cls: Type[TRootConfig], parser: argparse.ArgumentParser, argv_options: List[str]
) -> Tuple[TRootConfig, argparse.Namespace]: ...
cls: type[TRootConfig], parser: argparse.ArgumentParser, argv_options: list[str]
) -> tuple[TRootConfig, argparse.Namespace]: ...
def generate_missing_files(
self, config_dict: dict, config_dir_path: str
) -> None: ...
@@ -203,16 +199,16 @@ class Config:
def read_template(self, filenames: str) -> jinja2.Template: ...
def read_templates(
self,
filenames: List[str],
filenames: list[str],
custom_template_directories: Optional[Iterable[str]] = None,
) -> List[jinja2.Template]: ...
) -> list[jinja2.Template]: ...
def read_config_files(config_files: Iterable[str]) -> Dict[str, Any]: ...
def find_config_files(search_paths: List[str]) -> List[str]: ...
def read_config_files(config_files: Iterable[str]) -> dict[str, Any]: ...
def find_config_files(search_paths: list[str]) -> list[str]: ...
class ShardedWorkerHandlingConfig:
instances: List[str]
def __init__(self, instances: List[str]) -> None: ...
instances: list[str]
def __init__(self, instances: list[str]) -> None: ...
def should_handle(self, instance_name: str, key: str) -> bool: ... # noqa: F811
class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig):

View File

@@ -18,7 +18,7 @@
# [This file includes modifications made by New Vector Limited]
#
#
from typing import Any, Dict, Type, TypeVar
from typing import Any, TypeVar
import jsonschema
@@ -79,8 +79,8 @@ Model = TypeVar("Model", bound=BaseModel)
def parse_and_validate_mapping(
config: Any,
model_type: Type[Model],
) -> Dict[str, Model]:
model_type: type[Model],
) -> dict[str, Model]:
"""Parse `config` as a mapping from strings to a given `Model` type.
Args:
config: The configuration data to check
@@ -93,7 +93,7 @@ def parse_and_validate_mapping(
try:
# type-ignore: mypy doesn't like constructing `Dict[str, model_type]` because
# `model_type` is a runtime variable. Pydantic is fine with this.
instances = parse_obj_as(Dict[str, model_type], config) # type: ignore[valid-type]
instances = parse_obj_as(dict[str, model_type], config) # type: ignore[valid-type]
except ValidationError as e:
raise ConfigError(str(e)) from e
return instances

View File

@@ -20,7 +20,7 @@
#
import logging
from typing import Any, Iterable, Optional, Tuple
from typing import Any, Iterable, Optional
from synapse.api.constants import EventTypes
from synapse.config._base import Config, ConfigError
@@ -46,7 +46,7 @@ class ApiConfig(Config):
def _get_prejoin_state_entries(
self, config: JsonDict
) -> Iterable[Tuple[str, Optional[str]]]:
) -> Iterable[tuple[str, Optional[str]]]:
"""Get the event types and state keys to include in the prejoin state."""
room_prejoin_state_config = config.get("room_prejoin_state") or {}

View File

@@ -21,7 +21,7 @@
#
import logging
from typing import Any, Dict, List
from typing import Any
from urllib import parse as urlparse
import yaml
@@ -61,13 +61,13 @@ class AppServiceConfig(Config):
def load_appservices(
hostname: str, config_files: List[str]
) -> List[ApplicationService]:
hostname: str, config_files: list[str]
) -> list[ApplicationService]:
"""Returns a list of Application Services from the config files."""
# Dicts of value -> filename
seen_as_tokens: Dict[str, str] = {}
seen_ids: Dict[str, str] = {}
seen_as_tokens: dict[str, str] = {}
seen_ids: dict[str, str] = {}
appservices = []

View File

@@ -23,7 +23,7 @@ import logging
import os
import re
import threading
from typing import Any, Callable, Dict, Mapping, Optional
from typing import Any, Callable, Mapping, Optional
import attr
@@ -38,7 +38,7 @@ logger = logging.getLogger(__name__)
_CACHE_PREFIX = "SYNAPSE_CACHE_FACTOR"
# Map from canonicalised cache name to cache.
_CACHES: Dict[str, Callable[[float], None]] = {}
_CACHES: dict[str, Callable[[float], None]] = {}
# a lock on the contents of _CACHES
_CACHES_LOCK = threading.Lock()
@@ -104,7 +104,7 @@ class CacheConfig(Config):
_environ: Mapping[str, str] = os.environ
event_cache_size: int
cache_factors: Dict[str, float]
cache_factors: dict[str, float]
global_factor: float
track_memory_usage: bool
expiry_time_msec: Optional[int]

View File

@@ -20,7 +20,7 @@
#
#
from typing import Any, List, Optional
from typing import Any, Optional
from synapse.config.sso import SsoAttributeRequirement
from synapse.types import JsonDict
@@ -107,7 +107,7 @@ REQUIRED_ATTRIBUTES_SCHEMA = {
def _parsed_required_attributes_def(
required_attributes: Any,
) -> List[SsoAttributeRequirement]:
) -> list[SsoAttributeRequirement]:
validate_config(
REQUIRED_ATTRIBUTES_SCHEMA,
required_attributes,

View File

@@ -22,7 +22,7 @@
import argparse
import logging
import os
from typing import Any, List
from typing import Any
from synapse.config._base import Config, ConfigError
from synapse.types import JsonDict
@@ -83,7 +83,7 @@ class DatabaseConfig(Config):
def __init__(self, *args: Any):
super().__init__(*args)
self.databases: List[DatabaseConnectionConfig] = []
self.databases: list[DatabaseConnectionConfig] = []
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
# We *experimentally* support specifying multiple databases via the

View File

@@ -23,7 +23,7 @@
import hashlib
import logging
import os
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional
from typing import TYPE_CHECKING, Any, Iterator, Optional
import attr
import jsonschema
@@ -110,7 +110,7 @@ class TrustedKeyServer:
server_name: str
# map from key id to key object, or None to disable signature verification.
verify_keys: Optional[Dict[str, VerifyKey]] = None
verify_keys: Optional[dict[str, VerifyKey]] = None
class KeyConfig(Config):
@@ -250,7 +250,7 @@ class KeyConfig(Config):
- server_name: "matrix.org"
""" % locals()
def read_signing_keys(self, signing_key_path: str, name: str) -> List[SigningKey]:
def read_signing_keys(self, signing_key_path: str, name: str) -> list[SigningKey]:
"""Read the signing keys in the given path.
Args:
@@ -280,7 +280,7 @@ class KeyConfig(Config):
def read_old_signing_keys(
self, old_signing_keys: Optional[JsonDict]
) -> Dict[str, "VerifyKeyWithExpiry"]:
) -> dict[str, "VerifyKeyWithExpiry"]:
if old_signing_keys is None:
return {}
keys = {}
@@ -299,7 +299,7 @@ class KeyConfig(Config):
)
return keys
def generate_files(self, config: Dict[str, Any], config_dir_path: str) -> None:
def generate_files(self, config: dict[str, Any], config_dir_path: str) -> None:
if "signing_key" in config:
return
@@ -393,7 +393,7 @@ TRUSTED_KEY_SERVERS_SCHEMA = {
def _parse_key_servers(
key_servers: List[Any], federation_verify_certificates: bool
key_servers: list[Any], federation_verify_certificates: bool
) -> Iterator[TrustedKeyServer]:
try:
jsonschema.validate(key_servers, TRUSTED_KEY_SERVERS_SCHEMA)
@@ -408,7 +408,7 @@ def _parse_key_servers(
server_name = server["server_name"]
result = TrustedKeyServer(server_name=server_name)
verify_keys: Optional[Dict[str, str]] = server.get("verify_keys")
verify_keys: Optional[dict[str, str]] = server.get("verify_keys")
if verify_keys is not None:
result.verify_keys = {}
for key_id, key_base64 in verify_keys.items():

View File

@@ -26,7 +26,7 @@ import os
import sys
import threading
from string import Template
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Optional
import yaml
from zope.interface import implementer
@@ -186,7 +186,7 @@ class LoggingConfig(Config):
help=argparse.SUPPRESS,
)
def generate_files(self, config: Dict[str, Any], config_dir_path: str) -> None:
def generate_files(self, config: dict[str, Any], config_dir_path: str) -> None:
log_config = config.get("log_config")
if log_config and not os.path.exists(log_config):
log_file = self.abspath("homeserver.log")

View File

@@ -18,7 +18,7 @@
# [This file includes modifications made by New Vector Limited]
#
#
from typing import Any, Dict, List, Tuple
from typing import Any
from synapse.config._base import Config, ConfigError
from synapse.types import JsonDict
@@ -29,7 +29,7 @@ class ModulesConfig(Config):
section = "modules"
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.loaded_modules: List[Tuple[Any, Dict]] = []
self.loaded_modules: list[tuple[Any, dict]] = []
configured_modules = config.get("modules") or []
for i, module in enumerate(configured_modules):

View File

@@ -21,7 +21,7 @@
import importlib.resources as importlib_resources
import json
import re
from typing import Any, Dict, Iterable, List, Optional, Pattern
from typing import Any, Iterable, Optional, Pattern
from urllib import parse as urlparse
import attr
@@ -37,9 +37,9 @@ class OEmbedEndpointConfig:
# The API endpoint to fetch.
api_endpoint: str
# The patterns to match.
url_patterns: List[Pattern[str]]
url_patterns: list[Pattern[str]]
# The supported formats.
formats: Optional[List[str]]
formats: Optional[list[str]]
class OembedConfig(Config):
@@ -48,10 +48,10 @@ class OembedConfig(Config):
section = "oembed"
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
oembed_config: Dict[str, Any] = config.get("oembed") or {}
oembed_config: dict[str, Any] = config.get("oembed") or {}
# A list of patterns which will be used.
self.oembed_patterns: List[OEmbedEndpointConfig] = list(
self.oembed_patterns: list[OEmbedEndpointConfig] = list(
self._parse_and_validate_providers(oembed_config)
)
@@ -92,7 +92,7 @@ class OembedConfig(Config):
)
def _parse_and_validate_provider(
self, providers: List[JsonDict], config_path: StrSequence
self, providers: list[JsonDict], config_path: StrSequence
) -> Iterable[OEmbedEndpointConfig]:
# Ensure it is the proper form.
validate_config(

View File

@@ -21,7 +21,7 @@
#
from collections import Counter
from typing import Any, Collection, Iterable, List, Mapping, Optional, Tuple, Type
from typing import Any, Collection, Iterable, Mapping, Optional
import attr
@@ -213,7 +213,7 @@ def _parse_oidc_provider_configs(config: JsonDict) -> Iterable["OidcProviderConf
def _parse_oidc_config_dict(
oidc_config: JsonDict, config_path: Tuple[str, ...]
oidc_config: JsonDict, config_path: tuple[str, ...]
) -> "OidcProviderConfig":
"""Take the configuration dict and parse it into an OidcProviderConfig
@@ -416,7 +416,7 @@ class OidcProviderConfig:
# Valid values are 'auto', 'always', and 'never'.
pkce_method: str
id_token_signing_alg_values_supported: Optional[List[str]]
id_token_signing_alg_values_supported: Optional[list[str]]
"""
List of the JWS signing algorithms (`alg` values) that are supported for signing the
`id_token`.
@@ -491,13 +491,13 @@ class OidcProviderConfig:
allow_existing_users: bool
# the class of the user mapping provider
user_mapping_provider_class: Type
user_mapping_provider_class: type
# the config of the user mapping provider
user_mapping_provider_config: Any
# required attributes to require in userinfo to allow login/registration
attribute_requirements: List[SsoAttributeRequirement]
attribute_requirements: list[SsoAttributeRequirement]
# Whether automatic registrations are enabled in the ODIC flow. Defaults to True
enable_registration: bool

View File

@@ -19,7 +19,7 @@
#
#
from typing import Any, List, Tuple, Type
from typing import Any
from synapse.types import JsonDict
from synapse.util.module_loader import load_module
@@ -56,7 +56,7 @@ class PasswordAuthProviderConfig(Config):
for backwards compatibility.
"""
self.password_providers: List[Tuple[Type, Any]] = []
self.password_providers: list[tuple[type, Any]] = []
providers = []
# We want to be backwards compatible with the old `ldap_config`

View File

@@ -19,7 +19,7 @@
#
#
from typing import Any, Dict, Optional, cast
from typing import Any, Optional, cast
import attr
@@ -37,9 +37,9 @@ class RatelimitSettings:
@classmethod
def parse(
cls,
config: Dict[str, Any],
config: dict[str, Any],
key: str,
defaults: Optional[Dict[str, float]] = None,
defaults: Optional[dict[str, float]] = None,
) -> "RatelimitSettings":
"""Parse config[key] as a new-style rate limiter config.
@@ -62,7 +62,7 @@ class RatelimitSettings:
# By this point we should have hit the rate limiter parameters.
# We don't actually check this though!
rl_config = cast(Dict[str, float], rl_config)
rl_config = cast(dict[str, float], rl_config)
return cls(
key=key,

View File

@@ -20,7 +20,7 @@
#
#
import argparse
from typing import Any, Dict, Optional
from typing import Any, Optional
from synapse.api.constants import RoomCreationPreset
from synapse.config._base import Config, ConfigError, read_file
@@ -266,7 +266,7 @@ class RegistrationConfig(Config):
else:
return ""
def generate_files(self, config: Dict[str, Any], config_dir_path: str) -> None:
def generate_files(self, config: dict[str, Any], config_dir_path: str) -> None:
# if 'registration_shared_secret_path' is specified, and the target file
# does not exist, generate it.
registration_shared_secret_path = config.get("registration_shared_secret_path")

View File

@@ -21,7 +21,7 @@
import logging
import os
from typing import Any, Dict, List, Tuple
from typing import Any
import attr
@@ -80,8 +80,8 @@ class MediaStorageProviderConfig:
def parse_thumbnail_requirements(
thumbnail_sizes: List[JsonDict],
) -> Dict[str, Tuple[ThumbnailRequirement, ...]]:
thumbnail_sizes: list[JsonDict],
) -> dict[str, tuple[ThumbnailRequirement, ...]]:
"""Takes a list of dictionaries with "width", "height", and "method" keys
and creates a map from image media types to the thumbnail size, thumbnailing
method, and thumbnail media type to precalculate
@@ -92,7 +92,7 @@ def parse_thumbnail_requirements(
Returns:
Dictionary mapping from media type string to list of ThumbnailRequirement.
"""
requirements: Dict[str, List[ThumbnailRequirement]] = {}
requirements: dict[str, list[ThumbnailRequirement]] = {}
for size in thumbnail_sizes:
width = size["width"]
height = size["height"]
@@ -206,7 +206,7 @@ class ContentRepositoryConfig(Config):
#
# We don't create the storage providers here as not all workers need
# them to be started.
self.media_storage_providers: List[tuple] = []
self.media_storage_providers: list[tuple] = []
for i, provider_config in enumerate(storage_providers):
# We special case the module "file_system" so as not to need to
@@ -298,7 +298,7 @@ class ContentRepositoryConfig(Config):
self.enable_authenticated_media = config.get("enable_authenticated_media", True)
self.media_upload_limits: List[MediaUploadLimit] = []
self.media_upload_limits: list[MediaUploadLimit] = []
for limit_config in config.get("media_upload_limits", []):
time_period_ms = self.parse_duration(limit_config["time_period"])
max_bytes = self.parse_size(limit_config["max_size"])

View File

@@ -20,7 +20,7 @@
#
import logging
from typing import Any, List, Optional
from typing import Any, Optional
import attr
@@ -119,7 +119,7 @@ class RetentionConfig(Config):
" greater than 'allowed_lifetime_max'"
)
self.retention_purge_jobs: List[RetentionPurgeJob] = []
self.retention_purge_jobs: list[RetentionPurgeJob] = []
for purge_job_config in retention_config.get("purge_jobs", []):
interval_config = purge_job_config.get("interval")

View File

@@ -20,7 +20,7 @@
#
import logging
from typing import Any, List, Set
from typing import Any
from synapse.config.sso import SsoAttributeRequirement
from synapse.types import JsonDict
@@ -160,8 +160,11 @@ class SAML2Config(Config):
)
# Get the desired saml auth response attributes from the module
# type-ignore: the provider class was already checked for having the method being called
# with the runtime checks above, which mypy is not aware of, and treats as an error
# ever since the typehint of provider class was changed from "typing.Type" to "type"
saml2_config_dict = self._default_saml_config_dict(
*self.saml2_user_mapping_provider_class.get_saml_attributes(
*self.saml2_user_mapping_provider_class.get_saml_attributes( # type: ignore[attr-defined]
self.saml2_user_mapping_provider_config
)
)
@@ -191,7 +194,7 @@ class SAML2Config(Config):
)
def _default_saml_config_dict(
self, required_attributes: Set[str], optional_attributes: Set[str]
self, required_attributes: set[str], optional_attributes: set[str]
) -> JsonDict:
"""Generate a configuration dictionary with required and optional attributes that
will be needed to process new user registration
@@ -239,7 +242,7 @@ ATTRIBUTE_REQUIREMENTS_SCHEMA = {
def _parse_attribute_requirements_def(
attribute_requirements: Any,
) -> List[SsoAttributeRequirement]:
) -> list[SsoAttributeRequirement]:
validate_config(
ATTRIBUTE_REQUIREMENTS_SCHEMA,
attribute_requirements,

View File

@@ -25,7 +25,7 @@ import logging
import os.path
import urllib.parse
from textwrap import indent
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, TypedDict, Union
from typing import Any, Iterable, Optional, TypedDict, Union
from urllib.request import getproxies_environment
import attr
@@ -213,7 +213,7 @@ KNOWN_RESOURCES = {
@attr.s(frozen=True)
class HttpResourceConfig:
names: List[str] = attr.ib(
names: list[str] = attr.ib(
factory=list,
validator=attr.validators.deep_iterable(attr.validators.in_(KNOWN_RESOURCES)),
)
@@ -228,8 +228,8 @@ class HttpListenerConfig:
"""Object describing the http-specific parts of the config of a listener"""
x_forwarded: bool = False
resources: List[HttpResourceConfig] = attr.Factory(list)
additional_resources: Dict[str, dict] = attr.Factory(dict)
resources: list[HttpResourceConfig] = attr.Factory(list)
additional_resources: dict[str, dict] = attr.Factory(dict)
tag: Optional[str] = None
request_id_header: Optional[str] = None
@@ -239,7 +239,7 @@ class TCPListenerConfig:
"""Object describing the configuration of a single TCP listener."""
port: int = attr.ib(validator=attr.validators.instance_of(int))
bind_addresses: List[str] = attr.ib(validator=attr.validators.instance_of(List))
bind_addresses: list[str] = attr.ib(validator=attr.validators.instance_of(list))
type: str = attr.ib(validator=attr.validators.in_(KNOWN_LISTENER_TYPES))
tls: bool = False
@@ -344,7 +344,7 @@ class ProxyConfig:
"""
Proxy server to use for HTTPS requests.
"""
no_proxy_hosts: Optional[List[str]]
no_proxy_hosts: Optional[list[str]]
"""
List of hosts, IP addresses, or IP ranges in CIDR format which should not use the
proxy. Synapse will directly connect to these hosts.
@@ -864,11 +864,11 @@ class ServerConfig(Config):
)
# Whitelist of domain names that given next_link parameters must have
next_link_domain_whitelist: Optional[List[str]] = config.get(
next_link_domain_whitelist: Optional[list[str]] = config.get(
"next_link_domain_whitelist"
)
self.next_link_domain_whitelist: Optional[Set[str]] = None
self.next_link_domain_whitelist: Optional[set[str]] = None
if next_link_domain_whitelist is not None:
if not isinstance(next_link_domain_whitelist, list):
raise ConfigError("'next_link_domain_whitelist' must be a list")
@@ -892,7 +892,7 @@ class ServerConfig(Config):
config.get("use_account_validity_in_account_status") or False
)
self.rooms_to_exclude_from_sync: List[str] = (
self.rooms_to_exclude_from_sync: list[str] = (
config.get("exclude_rooms_from_sync") or []
)
@@ -927,7 +927,7 @@ class ServerConfig(Config):
data_dir_path: str,
server_name: str,
open_private_ports: bool,
listeners: Optional[List[dict]],
listeners: Optional[list[dict]],
**kwargs: Any,
) -> str:
_, bind_port = parse_and_validate_server_name(server_name)
@@ -1028,7 +1028,7 @@ class ServerConfig(Config):
help="Turn on the twisted telnet manhole service on the given port.",
)
def read_gc_intervals(self, durations: Any) -> Optional[Tuple[float, float, float]]:
def read_gc_intervals(self, durations: Any) -> Optional[tuple[float, float, float]]:
"""Reads the three durations for the GC min interval option, returning seconds."""
if durations is None:
return None
@@ -1048,7 +1048,7 @@ class ServerConfig(Config):
def is_threepid_reserved(
reserved_threepids: List[JsonDict], threepid: JsonDict
reserved_threepids: list[JsonDict], threepid: JsonDict
) -> bool:
"""Check the threepid against the reserved threepid config
Args:
@@ -1066,8 +1066,8 @@ def is_threepid_reserved(
def read_gc_thresholds(
thresholds: Optional[List[Any]],
) -> Optional[Tuple[int, int, int]]:
thresholds: Optional[list[Any]],
) -> Optional[tuple[int, int, int]]:
"""Reads the three integer thresholds for garbage collection. Ensures that
the thresholds are integers if thresholds are supplied.
"""

View File

@@ -19,7 +19,7 @@
#
import logging
from typing import Any, Dict, List, Tuple
from typing import Any
from synapse.config import ConfigError
from synapse.types import JsonDict
@@ -41,7 +41,7 @@ class SpamCheckerConfig(Config):
section = "spamchecker"
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
self.spam_checkers: List[Tuple[Any, Dict]] = []
self.spam_checkers: list[tuple[Any, dict]] = []
spam_checkers = config.get("spam_checker") or []
if isinstance(spam_checkers, dict):

View File

@@ -19,7 +19,7 @@
#
#
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import attr
@@ -45,7 +45,7 @@ class SsoAttributeRequirement:
attribute: str
# If neither `value` nor `one_of` is given, the attribute must simply exist.
value: Optional[str] = None
one_of: Optional[List[str]] = None
one_of: Optional[list[str]] = None
JSON_SCHEMA = {
"type": "object",
@@ -64,7 +64,7 @@ class SSOConfig(Config):
section = "sso"
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
sso_config: Dict[str, Any] = config.get("sso") or {}
sso_config: dict[str, Any] = config.get("sso") or {}
# The sso-specific template_dir
self.sso_template_dir = sso_config.get("template_dir")

View File

@@ -20,7 +20,7 @@
#
import logging
from typing import Any, List, Optional, Pattern
from typing import Any, Optional, Pattern
from matrix_common.regex import glob_to_regex
@@ -84,7 +84,7 @@ class TlsConfig(Config):
fed_whitelist_entries = []
# Support globs (*) in whitelist values
self.federation_certificate_verification_whitelist: List[Pattern] = []
self.federation_certificate_verification_whitelist: list[Pattern] = []
for entry in fed_whitelist_entries:
try:
entry_regex = glob_to_regex(entry.encode("ascii").decode("ascii"))

View File

@@ -19,7 +19,7 @@
#
#
from typing import Any, List, Set
from typing import Any
from synapse.types import JsonDict
from synapse.util.check_dependencies import check_requirements
@@ -42,7 +42,7 @@ class TracerConfig(Config):
{"sampler": {"type": "const", "param": 1}, "logging": False},
)
self.force_tracing_for_users: Set[str] = set()
self.force_tracing_for_users: set[str] = set()
if not self.opentracer_enabled:
return
@@ -51,7 +51,7 @@ class TracerConfig(Config):
# The tracer is enabled so sanitize the config
self.opentracer_whitelist: List[str] = opentracing_config.get(
self.opentracer_whitelist: list[str] = opentracing_config.get(
"homeserver_whitelist", []
)
if not isinstance(self.opentracer_whitelist, list):

View File

@@ -12,7 +12,7 @@
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
from typing import Any, List, Optional
from typing import Any, Optional
from synapse.api.constants import UserTypes
from synapse.types import JsonDict
@@ -29,9 +29,9 @@ class UserTypesConfig(Config):
self.default_user_type: Optional[str] = user_types.get(
"default_user_type", None
)
self.extra_user_types: List[str] = user_types.get("extra_user_types", [])
self.extra_user_types: list[str] = user_types.get("extra_user_types", [])
all_user_types: List[str] = []
all_user_types: list[str] = []
all_user_types.extend(UserTypes.ALL_BUILTIN_USER_TYPES)
all_user_types.extend(self.extra_user_types)

View File

@@ -22,7 +22,7 @@
import argparse
import logging
from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional, Union
import attr
@@ -79,7 +79,7 @@ MAIN_PROCESS_INSTANCE_MAP_NAME = "main"
logger = logging.getLogger(__name__)
def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
def _instance_to_list_converter(obj: Union[str, list[str]]) -> list[str]:
"""Helper for allowing parsing a string or list of strings to a config
option expecting a list of strings.
"""
@@ -142,39 +142,39 @@ class WriterLocations:
device_lists: The instances that write to the device list stream.
"""
events: List[str] = attr.ib(
events: list[str] = attr.ib(
default=[MAIN_PROCESS_INSTANCE_NAME],
converter=_instance_to_list_converter,
)
typing: List[str] = attr.ib(
typing: list[str] = attr.ib(
default=[MAIN_PROCESS_INSTANCE_NAME],
converter=_instance_to_list_converter,
)
to_device: List[str] = attr.ib(
to_device: list[str] = attr.ib(
default=[MAIN_PROCESS_INSTANCE_NAME],
converter=_instance_to_list_converter,
)
account_data: List[str] = attr.ib(
account_data: list[str] = attr.ib(
default=[MAIN_PROCESS_INSTANCE_NAME],
converter=_instance_to_list_converter,
)
receipts: List[str] = attr.ib(
receipts: list[str] = attr.ib(
default=[MAIN_PROCESS_INSTANCE_NAME],
converter=_instance_to_list_converter,
)
presence: List[str] = attr.ib(
presence: list[str] = attr.ib(
default=[MAIN_PROCESS_INSTANCE_NAME],
converter=_instance_to_list_converter,
)
push_rules: List[str] = attr.ib(
push_rules: list[str] = attr.ib(
default=[MAIN_PROCESS_INSTANCE_NAME],
converter=_instance_to_list_converter,
)
device_lists: List[str] = attr.ib(
device_lists: list[str] = attr.ib(
default=[MAIN_PROCESS_INSTANCE_NAME],
converter=_instance_to_list_converter,
)
thread_subscriptions: List[str] = attr.ib(
thread_subscriptions: list[str] = attr.ib(
default=["master"],
converter=_instance_to_list_converter,
)
@@ -190,8 +190,8 @@ class OutboundFederationRestrictedTo:
locations: list of instance locations to connect to proxy via.
"""
instances: Optional[List[str]]
locations: List[InstanceLocationConfig] = attr.Factory(list)
instances: Optional[list[str]]
locations: list[InstanceLocationConfig] = attr.Factory(list)
def __contains__(self, instance: str) -> bool:
# It feels a bit dirty to return `True` if `instances` is `None`, but it makes
@@ -295,7 +295,7 @@ class WorkerConfig(Config):
# A map from instance name to host/port of their HTTP replication endpoint.
# Check if the main process is declared. The main process itself doesn't need
# this data as it would never have to talk to itself.
instance_map: Dict[str, Any] = config.get("instance_map", {})
instance_map: dict[str, Any] = config.get("instance_map", {})
if self.instance_name is not MAIN_PROCESS_INSTANCE_NAME:
# TODO: The next 3 condition blocks can be deleted after some time has
@@ -342,7 +342,7 @@ class WorkerConfig(Config):
)
# type-ignore: the expression `Union[A, B]` is not a Type[Union[A, B]] currently
self.instance_map: Dict[str, InstanceLocationConfig] = (
self.instance_map: dict[str, InstanceLocationConfig] = (
parse_and_validate_mapping(
instance_map,
InstanceLocationConfig, # type: ignore[arg-type]
@@ -481,7 +481,7 @@ class WorkerConfig(Config):
def _should_this_worker_perform_duty(
self,
config: Dict[str, Any],
config: dict[str, Any],
legacy_master_option_name: str,
legacy_worker_app_name: str,
new_option_name: str,
@@ -574,11 +574,11 @@ class WorkerConfig(Config):
def _worker_names_performing_this_duty(
self,
config: Dict[str, Any],
config: dict[str, Any],
legacy_option_name: str,
legacy_app_name: str,
modern_instance_list_name: str,
) -> List[str]:
) -> list[str]:
"""
Retrieves the names of the workers handling a given duty, by either legacy
option or instance list.

View File

@@ -23,7 +23,7 @@
import collections.abc
import hashlib
import logging
from typing import Any, Callable, Dict, Tuple
from typing import Any, Callable
from canonicaljson import encode_canonical_json
from signedjson.sign import sign_json
@@ -80,8 +80,8 @@ def check_event_content_hash(
def compute_content_hash(
event_dict: Dict[str, Any], hash_algorithm: Hasher
) -> Tuple[str, bytes]:
event_dict: dict[str, Any], hash_algorithm: Hasher
) -> tuple[str, bytes]:
"""Compute the content hash of an event, which is the hash of the
unredacted event.
@@ -112,7 +112,7 @@ def compute_content_hash(
def compute_event_reference_hash(
event: EventBase, hash_algorithm: Hasher = hashlib.sha256
) -> Tuple[str, bytes]:
) -> tuple[str, bytes]:
"""Computes the event reference hash. This is the hash of the redacted
event.
@@ -139,7 +139,7 @@ def compute_event_signature(
event_dict: JsonDict,
signature_name: str,
signing_key: SigningKey,
) -> Dict[str, Dict[str, str]]:
) -> dict[str, dict[str, str]]:
"""Compute the signature of the event for the given name and key.
Args:

View File

@@ -21,7 +21,7 @@
import abc
import logging
from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple
from typing import TYPE_CHECKING, Callable, Iterable, Optional
import attr
from signedjson.key import (
@@ -82,7 +82,7 @@ class VerifyJsonRequest:
server_name: str
get_json_object: Callable[[], JsonDict]
minimum_valid_until_ts: int
key_ids: List[str]
key_ids: list[str]
@staticmethod
def from_json_object(
@@ -141,7 +141,7 @@ class _FetchKeyRequest:
server_name: str
minimum_valid_until_ts: int
key_ids: List[str]
key_ids: list[str]
class Keyring:
@@ -156,7 +156,7 @@ class Keyring:
if key_fetchers is None:
# Always fetch keys from the database.
mutable_key_fetchers: List[KeyFetcher] = [StoreKeyFetcher(hs)]
mutable_key_fetchers: list[KeyFetcher] = [StoreKeyFetcher(hs)]
# Fetch keys from configured trusted key servers, if any exist.
key_servers = hs.config.key.key_servers
if key_servers:
@@ -169,7 +169,7 @@ class Keyring:
self._key_fetchers = key_fetchers
self._fetch_keys_queue: BatchingQueue[
_FetchKeyRequest, Dict[str, Dict[str, FetchKeyResult]]
_FetchKeyRequest, dict[str, dict[str, FetchKeyResult]]
] = BatchingQueue(
name="keyring_server",
hs=hs,
@@ -182,7 +182,7 @@ class Keyring:
# build a FetchKeyResult for each of our own keys, to shortcircuit the
# fetcher.
self._local_verify_keys: Dict[str, FetchKeyResult] = {}
self._local_verify_keys: dict[str, FetchKeyResult] = {}
for key_id, key in hs.config.key.old_signing_keys.items():
self._local_verify_keys[key_id] = FetchKeyResult(
verify_key=key, valid_until_ts=key.expired
@@ -229,8 +229,8 @@ class Keyring:
return await self.process_request(request)
def verify_json_objects_for_server(
self, server_and_json: Iterable[Tuple[str, dict, int]]
) -> List["defer.Deferred[None]"]:
self, server_and_json: Iterable[tuple[str, dict, int]]
) -> list["defer.Deferred[None]"]:
"""Bulk verifies signatures of json objects, bulk fetching keys as
necessary.
@@ -286,7 +286,7 @@ class Keyring:
Codes.UNAUTHORIZED,
)
found_keys: Dict[str, FetchKeyResult] = {}
found_keys: dict[str, FetchKeyResult] = {}
# If we are the originating server, short-circuit the key-fetch for any keys
# we already have
@@ -368,8 +368,8 @@ class Keyring:
)
async def _inner_fetch_key_requests(
self, requests: List[_FetchKeyRequest]
) -> Dict[str, Dict[str, FetchKeyResult]]:
self, requests: list[_FetchKeyRequest]
) -> dict[str, dict[str, FetchKeyResult]]:
"""Processing function for the queue of `_FetchKeyRequest`.
Takes a list of key fetch requests, de-duplicates them and then carries out
@@ -387,7 +387,7 @@ class Keyring:
# First we need to deduplicate requests for the same key. We do this by
# taking the *maximum* requested `minimum_valid_until_ts` for each pair
# of server name/key ID.
server_to_key_to_ts: Dict[str, Dict[str, int]] = {}
server_to_key_to_ts: dict[str, dict[str, int]] = {}
for request in requests:
by_server = server_to_key_to_ts.setdefault(request.server_name, {})
for key_id in request.key_ids:
@@ -412,7 +412,7 @@ class Keyring:
# We now convert the returned list of results into a map from server
# name to key ID to FetchKeyResult, to return.
to_return: Dict[str, Dict[str, FetchKeyResult]] = {}
to_return: dict[str, dict[str, FetchKeyResult]] = {}
for request, results in zip(deduped_requests, results_per_request):
to_return_by_server = to_return.setdefault(request.server_name, {})
for key_id, key_result in results.items():
@@ -424,7 +424,7 @@ class Keyring:
async def _inner_fetch_key_request(
self, verify_request: _FetchKeyRequest
) -> Dict[str, FetchKeyResult]:
) -> dict[str, FetchKeyResult]:
"""Attempt to fetch the given key by calling each key fetcher one by one.
If a key is found, check whether its `valid_until_ts` attribute satisfies the
@@ -445,7 +445,7 @@ class Keyring:
"""
logger.debug("Starting fetch for %s", verify_request)
found_keys: Dict[str, FetchKeyResult] = {}
found_keys: dict[str, FetchKeyResult] = {}
missing_key_ids = set(verify_request.key_ids)
for fetcher in self._key_fetchers:
@@ -499,8 +499,8 @@ class KeyFetcher(metaclass=abc.ABCMeta):
self._queue.shutdown()
async def get_keys(
self, server_name: str, key_ids: List[str], minimum_valid_until_ts: int
) -> Dict[str, FetchKeyResult]:
self, server_name: str, key_ids: list[str], minimum_valid_until_ts: int
) -> dict[str, FetchKeyResult]:
results = await self._queue.add_to_queue(
_FetchKeyRequest(
server_name=server_name,
@@ -512,8 +512,8 @@ class KeyFetcher(metaclass=abc.ABCMeta):
@abc.abstractmethod
async def _fetch_keys(
self, keys_to_fetch: List[_FetchKeyRequest]
) -> Dict[str, Dict[str, FetchKeyResult]]:
self, keys_to_fetch: list[_FetchKeyRequest]
) -> dict[str, dict[str, FetchKeyResult]]:
pass
@@ -526,8 +526,8 @@ class StoreKeyFetcher(KeyFetcher):
self.store = hs.get_datastores().main
async def _fetch_keys(
self, keys_to_fetch: List[_FetchKeyRequest]
) -> Dict[str, Dict[str, FetchKeyResult]]:
self, keys_to_fetch: list[_FetchKeyRequest]
) -> dict[str, dict[str, FetchKeyResult]]:
key_ids_to_fetch = (
(queue_value.server_name, key_id)
for queue_value in keys_to_fetch
@@ -535,7 +535,7 @@ class StoreKeyFetcher(KeyFetcher):
)
res = await self.store.get_server_keys_json(key_ids_to_fetch)
keys: Dict[str, Dict[str, FetchKeyResult]] = {}
keys: dict[str, dict[str, FetchKeyResult]] = {}
for (server_name, key_id), key in res.items():
keys.setdefault(server_name, {})[key_id] = key
return keys
@@ -549,7 +549,7 @@ class BaseV2KeyFetcher(KeyFetcher):
async def process_v2_response(
self, from_server: str, response_json: JsonDict, time_added_ms: int
) -> Dict[str, FetchKeyResult]:
) -> dict[str, FetchKeyResult]:
"""Parse a 'Server Keys' structure from the result of a /key request
This is used to parse either the entirety of the response from
@@ -640,11 +640,11 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
self.key_servers = hs.config.key.key_servers
async def _fetch_keys(
self, keys_to_fetch: List[_FetchKeyRequest]
) -> Dict[str, Dict[str, FetchKeyResult]]:
self, keys_to_fetch: list[_FetchKeyRequest]
) -> dict[str, dict[str, FetchKeyResult]]:
"""see KeyFetcher._fetch_keys"""
async def get_key(key_server: TrustedKeyServer) -> Dict:
async def get_key(key_server: TrustedKeyServer) -> dict:
try:
return await self.get_server_verify_key_v2_indirect(
keys_to_fetch, key_server
@@ -670,7 +670,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
).addErrback(unwrapFirstError)
)
union_of_keys: Dict[str, Dict[str, FetchKeyResult]] = {}
union_of_keys: dict[str, dict[str, FetchKeyResult]] = {}
for result in results:
for server_name, keys in result.items():
union_of_keys.setdefault(server_name, {}).update(keys)
@@ -678,8 +678,8 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
return union_of_keys
async def get_server_verify_key_v2_indirect(
self, keys_to_fetch: List[_FetchKeyRequest], key_server: TrustedKeyServer
) -> Dict[str, Dict[str, FetchKeyResult]]:
self, keys_to_fetch: list[_FetchKeyRequest], key_server: TrustedKeyServer
) -> dict[str, dict[str, FetchKeyResult]]:
"""
Args:
keys_to_fetch:
@@ -731,8 +731,8 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
"Response from notary server %s: %s", perspective_name, query_response
)
keys: Dict[str, Dict[str, FetchKeyResult]] = {}
added_keys: Dict[Tuple[str, str], FetchKeyResult] = {}
keys: dict[str, dict[str, FetchKeyResult]] = {}
added_keys: dict[tuple[str, str], FetchKeyResult] = {}
time_now_ms = self.clock.time_msec()
@@ -836,8 +836,8 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
self.client = hs.get_federation_http_client()
async def get_keys(
self, server_name: str, key_ids: List[str], minimum_valid_until_ts: int
) -> Dict[str, FetchKeyResult]:
self, server_name: str, key_ids: list[str], minimum_valid_until_ts: int
) -> dict[str, FetchKeyResult]:
results = await self._queue.add_to_queue(
_FetchKeyRequest(
server_name=server_name,
@@ -849,8 +849,8 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
return results.get(server_name, {})
async def _fetch_keys(
self, keys_to_fetch: List[_FetchKeyRequest]
) -> Dict[str, Dict[str, FetchKeyResult]]:
self, keys_to_fetch: list[_FetchKeyRequest]
) -> dict[str, dict[str, FetchKeyResult]]:
"""
Args:
keys_to_fetch:
@@ -879,7 +879,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
async def get_server_verify_keys_v2_direct(
self, server_name: str
) -> Dict[str, FetchKeyResult]:
) -> dict[str, FetchKeyResult]:
"""
Args:

View File

@@ -26,15 +26,11 @@ import typing
from typing import (
Any,
ChainMap,
Dict,
Iterable,
List,
Mapping,
MutableMapping,
Optional,
Protocol,
Set,
Tuple,
Union,
cast,
)
@@ -91,7 +87,7 @@ class _EventSourceStore(Protocol):
redact_behaviour: EventRedactBehaviour,
get_prev_content: bool = False,
allow_rejected: bool = False,
) -> Dict[str, "EventBase"]: ...
) -> dict[str, "EventBase"]: ...
def validate_event_for_room_version(event: "EventBase") -> None:
@@ -993,7 +989,7 @@ def _check_power_levels(
user_level = get_user_power_level(event.user_id, auth_events)
# Check other levels:
levels_to_check: List[Tuple[str, Optional[str]]] = [
levels_to_check: list[tuple[str, Optional[str]]] = [
("users_default", None),
("events_default", None),
("state_default", None),
@@ -1191,7 +1187,7 @@ def _verify_third_party_invite(
return False
def get_public_keys(invite_event: "EventBase") -> List[Dict[str, Any]]:
def get_public_keys(invite_event: "EventBase") -> list[dict[str, Any]]:
public_keys = []
if "public_key" in invite_event.content:
o = {"public_key": invite_event.content["public_key"]}
@@ -1204,7 +1200,7 @@ def get_public_keys(invite_event: "EventBase") -> List[Dict[str, Any]]:
def auth_types_for_event(
room_version: RoomVersion, event: Union["EventBase", "EventBuilder"]
) -> Set[Tuple[str, str]]:
) -> set[tuple[str, str]]:
"""Given an event, return a list of (EventType, StateKey) that may be
needed to auth the event. The returned list may be a superset of what
would actually be required depending on the full state of the room.

View File

@@ -25,14 +25,10 @@ import collections.abc
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
Iterable,
List,
Literal,
Optional,
Tuple,
Type,
TypeVar,
Union,
overload,
@@ -94,20 +90,20 @@ class DictProperty(Generic[T]):
def __get__(
self,
instance: Literal[None],
owner: Optional[Type[_DictPropertyInstance]] = None,
owner: Optional[type[_DictPropertyInstance]] = None,
) -> "DictProperty": ...
@overload
def __get__(
self,
instance: _DictPropertyInstance,
owner: Optional[Type[_DictPropertyInstance]] = None,
owner: Optional[type[_DictPropertyInstance]] = None,
) -> T: ...
def __get__(
self,
instance: Optional[_DictPropertyInstance],
owner: Optional[Type[_DictPropertyInstance]] = None,
owner: Optional[type[_DictPropertyInstance]] = None,
) -> Union[T, "DictProperty"]:
# if the property is accessed as a class property rather than an instance
# property, return the property itself rather than the value
@@ -160,20 +156,20 @@ class DefaultDictProperty(DictProperty, Generic[T]):
def __get__(
self,
instance: Literal[None],
owner: Optional[Type[_DictPropertyInstance]] = None,
owner: Optional[type[_DictPropertyInstance]] = None,
) -> "DefaultDictProperty": ...
@overload
def __get__(
self,
instance: _DictPropertyInstance,
owner: Optional[Type[_DictPropertyInstance]] = None,
owner: Optional[type[_DictPropertyInstance]] = None,
) -> T: ...
def __get__(
self,
instance: Optional[_DictPropertyInstance],
owner: Optional[Type[_DictPropertyInstance]] = None,
owner: Optional[type[_DictPropertyInstance]] = None,
) -> Union[T, "DefaultDictProperty"]:
if instance is None:
return self
@@ -192,7 +188,7 @@ class EventBase(metaclass=abc.ABCMeta):
self,
event_dict: JsonDict,
room_version: RoomVersion,
signatures: Dict[str, Dict[str, str]],
signatures: dict[str, dict[str, str]],
unsigned: JsonDict,
internal_metadata_dict: JsonDict,
rejected_reason: Optional[str],
@@ -210,7 +206,7 @@ class EventBase(metaclass=abc.ABCMeta):
depth: DictProperty[int] = DictProperty("depth")
content: DictProperty[JsonDict] = DictProperty("content")
hashes: DictProperty[Dict[str, str]] = DictProperty("hashes")
hashes: DictProperty[dict[str, str]] = DictProperty("hashes")
origin_server_ts: DictProperty[int] = DictProperty("origin_server_ts")
sender: DictProperty[str] = DictProperty("sender")
# TODO state_key should be Optional[str]. This is generally asserted in Synapse
@@ -293,13 +289,13 @@ class EventBase(metaclass=abc.ABCMeta):
def __contains__(self, field: str) -> bool:
return field in self._dict
def items(self) -> List[Tuple[str, Optional[Any]]]:
def items(self) -> list[tuple[str, Optional[Any]]]:
return list(self._dict.items())
def keys(self) -> Iterable[str]:
return self._dict.keys()
def prev_event_ids(self) -> List[str]:
def prev_event_ids(self) -> list[str]:
"""Returns the list of prev event IDs. The order matches the order
specified in the event, though there is no meaning to it.
@@ -457,7 +453,7 @@ class FrozenEventV2(EventBase):
def room_id(self) -> str:
return self._dict["room_id"]
def prev_event_ids(self) -> List[str]:
def prev_event_ids(self) -> list[str]:
"""Returns the list of prev event IDs. The order matches the order
specified in the event, though there is no meaning to it.
@@ -558,7 +554,7 @@ class FrozenEventV4(FrozenEventV3):
def _event_type_from_format_version(
format_version: int,
) -> Type[Union[FrozenEvent, FrozenEventV2, FrozenEventV3]]:
) -> type[Union[FrozenEvent, FrozenEventV2, FrozenEventV3]]:
"""Returns the python type to use to construct an Event object for the
given event format version.
@@ -669,4 +665,4 @@ class StrippedStateEvent:
type: str
state_key: str
sender: str
content: Dict[str, Any]
content: dict[str, Any]

View File

@@ -20,7 +20,7 @@
#
import logging
from http import HTTPStatus
from typing import Any, Dict, Tuple
from typing import Any
from synapse.api.constants import AccountDataTypes, EventTypes, Membership
from synapse.api.errors import SynapseError
@@ -146,7 +146,7 @@ class InviteAutoAccepter:
# Be careful: we convert the outer frozendict into a dict here,
# but the contents of the dict are still frozen (tuples in lieu of lists,
# etc.)
dm_map: Dict[str, Tuple[str, ...]] = dict(
dm_map: dict[str, tuple[str, ...]] = dict(
await self._api.account_data_manager.get_global(
user_id, AccountDataTypes.DIRECT
)

View File

@@ -19,7 +19,7 @@
#
#
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Optional, Union
import attr
from signedjson.types import SigningKey
@@ -125,8 +125,8 @@ class EventBuilder:
async def build(
self,
prev_event_ids: List[str],
auth_event_ids: Optional[List[str]],
prev_event_ids: list[str],
auth_event_ids: Optional[list[str]],
depth: Optional[int] = None,
) -> EventBase:
"""Transform into a fully signed and hashed event
@@ -205,8 +205,8 @@ class EventBuilder:
format_version = self.room_version.event_format
# The types of auth/prev events changes between event versions.
prev_events: Union[StrCollection, List[Tuple[str, Dict[str, str]]]]
auth_events: Union[List[str], List[Tuple[str, Dict[str, str]]]]
prev_events: Union[StrCollection, list[tuple[str, dict[str, str]]]]
auth_events: Union[list[str], list[tuple[str, dict[str, str]]]]
if format_version == EventFormatVersions.ROOM_V1_V2:
auth_events = await self._store.add_event_hashes(auth_event_ids)
prev_events = await self._store.add_event_hashes(prev_event_ids)
@@ -228,7 +228,7 @@ class EventBuilder:
# the db)
depth = min(depth, MAX_DEPTH)
event_dict: Dict[str, Any] = {
event_dict: dict[str, Any] = {
"auth_events": auth_events,
"prev_events": prev_events,
"type": self.type,

View File

@@ -24,11 +24,8 @@ from typing import (
Any,
Awaitable,
Callable,
Dict,
Iterable,
List,
Optional,
Set,
TypeVar,
Union,
)
@@ -44,10 +41,10 @@ if TYPE_CHECKING:
from synapse.server import HomeServer
GET_USERS_FOR_STATES_CALLBACK = Callable[
[Iterable[UserPresenceState]], Awaitable[Dict[str, Set[UserPresenceState]]]
[Iterable[UserPresenceState]], Awaitable[dict[str, set[UserPresenceState]]]
]
# This must either return a set of strings or the constant PresenceRouter.ALL_USERS.
GET_INTERESTED_USERS_CALLBACK = Callable[[str], Awaitable[Union[Set[str], str]]]
GET_INTERESTED_USERS_CALLBACK = Callable[[str], Awaitable[Union[set[str], str]]]
logger = logging.getLogger(__name__)
@@ -98,7 +95,7 @@ def load_legacy_presence_router(hs: "HomeServer") -> None:
return run
# Register the hooks through the module API.
hooks: Dict[str, Optional[Callable[..., Any]]] = {
hooks: dict[str, Optional[Callable[..., Any]]] = {
hook: async_wrapper(getattr(presence_router, hook, None))
for hook in presence_router_methods
}
@@ -116,8 +113,8 @@ class PresenceRouter:
def __init__(self, hs: "HomeServer"):
# Initially there are no callbacks
self._get_users_for_states_callbacks: List[GET_USERS_FOR_STATES_CALLBACK] = []
self._get_interested_users_callbacks: List[GET_INTERESTED_USERS_CALLBACK] = []
self._get_users_for_states_callbacks: list[GET_USERS_FOR_STATES_CALLBACK] = []
self._get_interested_users_callbacks: list[GET_INTERESTED_USERS_CALLBACK] = []
def register_presence_router_callbacks(
self,
@@ -143,7 +140,7 @@ class PresenceRouter:
async def get_users_for_states(
self,
state_updates: Iterable[UserPresenceState],
) -> Dict[str, Set[UserPresenceState]]:
) -> dict[str, set[UserPresenceState]]:
"""
Given an iterable of user presence updates, determine where each one
needs to go.
@@ -161,7 +158,7 @@ class PresenceRouter:
# Don't include any extra destinations for presence updates
return {}
users_for_states: Dict[str, Set[UserPresenceState]] = {}
users_for_states: dict[str, set[UserPresenceState]] = {}
# run all the callbacks for get_users_for_states and combine the results
for callback in self._get_users_for_states_callbacks:
try:
@@ -174,7 +171,7 @@ class PresenceRouter:
logger.warning("Failed to run module API callback %s: %s", callback, e)
continue
if not isinstance(result, Dict):
if not isinstance(result, dict):
logger.warning(
"Wrong type returned by module API callback %s: %s, expected Dict",
callback,
@@ -183,7 +180,7 @@ class PresenceRouter:
continue
for key, new_entries in result.items():
if not isinstance(new_entries, Set):
if not isinstance(new_entries, set):
logger.warning(
"Wrong type returned by module API callback %s: %s, expected Set",
callback,
@@ -194,7 +191,7 @@ class PresenceRouter:
return users_for_states
async def get_interested_users(self, user_id: str) -> Union[Set[str], str]:
async def get_interested_users(self, user_id: str) -> Union[set[str], str]:
"""
Retrieve a list of users that `user_id` is interested in receiving the
presence of. This will be in addition to those they share a room with.
@@ -234,7 +231,7 @@ class PresenceRouter:
if result == PresenceRouter.ALL_USERS:
return PresenceRouter.ALL_USERS
if not isinstance(result, Set):
if not isinstance(result, set):
logger.warning(
"Wrong type returned by module API callback %s: %s, expected set",
callback,

View File

@@ -19,7 +19,7 @@
#
#
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Optional
import attr
from immutabledict import immutabledict
@@ -133,7 +133,7 @@ class EventContext(UnpersistedEventContextBase):
"""
_storage: "StorageControllers"
state_group_deltas: Dict[Tuple[int, int], StateMap[str]]
state_group_deltas: dict[tuple[int, int], StateMap[str]]
rejected: Optional[str] = None
_state_group: Optional[int] = None
state_group_before_event: Optional[int] = None
@@ -149,7 +149,7 @@ class EventContext(UnpersistedEventContextBase):
state_group_before_event: Optional[int],
state_delta_due_to_event: Optional[StateMap[str]],
partial_state: bool,
state_group_deltas: Dict[Tuple[int, int], StateMap[str]],
state_group_deltas: dict[tuple[int, int], StateMap[str]],
) -> "EventContext":
return EventContext(
storage=storage,
@@ -306,7 +306,7 @@ class EventContext(UnpersistedEventContextBase):
)
EventPersistencePair = Tuple[EventBase, EventContext]
EventPersistencePair = tuple[EventBase, EventContext]
"""
The combination of an event to be persisted and its context.
"""
@@ -365,11 +365,11 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
@classmethod
async def batch_persist_unpersisted_contexts(
cls,
events_and_context: List[Tuple[EventBase, "UnpersistedEventContextBase"]],
events_and_context: list[tuple[EventBase, "UnpersistedEventContextBase"]],
room_id: str,
last_known_state_group: int,
datastore: "StateGroupDataStore",
) -> List[EventPersistencePair]:
) -> list[EventPersistencePair]:
"""
Takes a list of events and their associated unpersisted contexts and persists
the unpersisted contexts, returning a list of events and persisted contexts.
@@ -472,7 +472,7 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
partial_state=self.partial_state,
)
def _build_state_group_deltas(self) -> Dict[Tuple[int, int], StateMap]:
def _build_state_group_deltas(self) -> dict[tuple[int, int], StateMap]:
"""
Collect deltas between the state groups associated with this context
"""
@@ -510,8 +510,8 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
def _encode_state_group_delta(
state_group_delta: Dict[Tuple[int, int], StateMap[str]],
) -> List[Tuple[int, int, Optional[List[Tuple[str, str, str]]]]]:
state_group_delta: dict[tuple[int, int], StateMap[str]],
) -> list[tuple[int, int, Optional[list[tuple[str, str, str]]]]]:
if not state_group_delta:
return []
@@ -523,8 +523,8 @@ def _encode_state_group_delta(
def _decode_state_group_delta(
input: List[Tuple[int, int, List[Tuple[str, str, str]]]],
) -> Dict[Tuple[int, int], StateMap[str]]:
input: list[tuple[int, int, list[tuple[str, str, str]]]],
) -> dict[tuple[int, int], StateMap[str]]:
if not input:
return {}
@@ -539,7 +539,7 @@ def _decode_state_group_delta(
def _encode_state_dict(
state_dict: Optional[StateMap[str]],
) -> Optional[List[Tuple[str, str, str]]]:
) -> Optional[list[tuple[str, str, str]]]:
"""Since dicts of (type, state_key) -> event_id cannot be serialized in
JSON we need to convert them to a form that can.
"""
@@ -550,7 +550,7 @@ def _encode_state_dict(
def _decode_state_dict(
input: Optional[List[Tuple[str, str, str]]],
input: Optional[list[tuple[str, str, str]]],
) -> Optional[StateMap[str]]:
"""Decodes a state dict encoded using `_encode_state_dict` above"""
if input is None:

View File

@@ -27,8 +27,6 @@ from typing import (
Awaitable,
Callable,
Collection,
Dict,
List,
Mapping,
Match,
MutableMapping,
@@ -239,7 +237,7 @@ def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDic
return allowed_fields
def _copy_field(src: JsonDict, dst: JsonDict, field: List[str]) -> None:
def _copy_field(src: JsonDict, dst: JsonDict, field: list[str]) -> None:
"""Copy the field in 'src' to 'dst'.
For example, if src={"foo":{"bar":5}} and dst={}, and field=["foo","bar"]
@@ -292,7 +290,7 @@ def _escape_slash(m: Match[str]) -> str:
return m.group(0)
def _split_field(field: str) -> List[str]:
def _split_field(field: str) -> list[str]:
"""
Splits strings on unescaped dots and removes escaping.
@@ -333,7 +331,7 @@ def _split_field(field: str) -> List[str]:
return result
def only_fields(dictionary: JsonDict, fields: List[str]) -> JsonDict:
def only_fields(dictionary: JsonDict, fields: list[str]) -> JsonDict:
"""Return a new dict with only the fields in 'dictionary' which are present
in 'fields'.
@@ -419,7 +417,7 @@ class SerializeEventConfig:
# the transaction_id in the unsigned section of the event.
requester: Optional[Requester] = None
# List of event fields to include. If empty, all fields will be returned.
only_event_fields: Optional[List[str]] = None
only_event_fields: Optional[list[str]] = None
# Some events can have stripped room state stored in the `unsigned` field.
# This is required for invite and knock functionality. If this option is
# False, that state will be removed from the event before it is returned.
@@ -573,7 +571,7 @@ class EventClientSerializer:
def __init__(self, hs: "HomeServer") -> None:
self._store = hs.get_datastores().main
self._auth = hs.get_auth()
self._add_extra_fields_to_unsigned_client_event_callbacks: List[
self._add_extra_fields_to_unsigned_client_event_callbacks: list[
ADD_EXTRA_FIELDS_TO_UNSIGNED_CLIENT_EVENT_CALLBACK
] = []
@@ -583,7 +581,7 @@ class EventClientSerializer:
time_now: int,
*,
config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,
bundle_aggregations: Optional[Dict[str, "BundledAggregations"]] = None,
bundle_aggregations: Optional[dict[str, "BundledAggregations"]] = None,
) -> JsonDict:
"""Serializes a single event.
@@ -641,7 +639,7 @@ class EventClientSerializer:
event: EventBase,
time_now: int,
config: SerializeEventConfig,
bundled_aggregations: Dict[str, "BundledAggregations"],
bundled_aggregations: dict[str, "BundledAggregations"],
serialized_event: JsonDict,
) -> None:
"""Potentially injects bundled aggregations into the unsigned portion of the serialized event.
@@ -718,8 +716,8 @@ class EventClientSerializer:
time_now: int,
*,
config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,
bundle_aggregations: Optional[Dict[str, "BundledAggregations"]] = None,
) -> List[JsonDict]:
bundle_aggregations: Optional[dict[str, "BundledAggregations"]] = None,
) -> list[JsonDict]:
"""Serializes multiple events.
Args:
@@ -763,7 +761,7 @@ PowerLevelsContent = Mapping[str, Union[_PowerLevel, Mapping[str, _PowerLevel]]]
def copy_and_fixup_power_levels_contents(
old_power_levels: PowerLevelsContent,
) -> Dict[str, Union[int, Dict[str, int]]]:
) -> dict[str, Union[int, dict[str, int]]]:
"""Copy the content of a power_levels event, unfreezing immutabledicts along the way.
We accept as input power level values which are strings, provided they represent an
@@ -779,11 +777,11 @@ def copy_and_fixup_power_levels_contents(
if not isinstance(old_power_levels, collections.abc.Mapping):
raise TypeError("Not a valid power-levels content: %r" % (old_power_levels,))
power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
power_levels: dict[str, Union[int, dict[str, int]]] = {}
for k, v in old_power_levels.items():
if isinstance(v, collections.abc.Mapping):
h: Dict[str, int] = {}
h: dict[str, int] = {}
power_levels[k] = h
for k1, v1 in v.items():
_copy_power_level_value_as_integer(v1, h, k1)

View File

@@ -19,7 +19,7 @@
#
#
import collections.abc
from typing import List, Type, Union, cast
from typing import Union, cast
import jsonschema
@@ -283,13 +283,13 @@ POWER_LEVELS_SCHEMA = {
class Mentions(RequestBodyModel):
user_ids: List[StrictStr] = Field(default_factory=list)
user_ids: list[StrictStr] = Field(default_factory=list)
room: StrictBool = False
# This could return something newer than Draft 7, but that's the current "latest"
# validator.
def _create_validator(schema: JsonDict) -> Type[jsonschema.Draft7Validator]:
def _create_validator(schema: JsonDict) -> type[jsonschema.Draft7Validator]:
validator = jsonschema.validators.validator_for(schema)
# by default jsonschema does not consider a immutabledict to be an object so

View File

@@ -20,7 +20,7 @@
#
#
import logging
from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Sequence
from typing import TYPE_CHECKING, Awaitable, Callable, Optional, Sequence
from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError
@@ -305,7 +305,7 @@ def _is_invite_via_3pid(event: EventBase) -> bool:
def parse_events_from_pdu_json(
pdus_json: Sequence[JsonDict], room_version: RoomVersion
) -> List[EventBase]:
) -> list[EventBase]:
return [
event_from_pdu_json(pdu_json, room_version)
for pdu_json in filter_pdus_for_valid_depth(pdus_json)

View File

@@ -32,13 +32,10 @@ from typing import (
Callable,
Collection,
Container,
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)
@@ -120,8 +117,8 @@ class SendJoinResult:
event: EventBase
# A string giving the server the event was sent to.
origin: str
state: List[EventBase]
auth_chain: List[EventBase]
state: list[EventBase]
auth_chain: list[EventBase]
# True if 'state' elides non-critical membership events
partial_state: bool
@@ -135,7 +132,7 @@ class FederationClient(FederationBase):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.pdu_destination_tried: Dict[str, Dict[str, int]] = {}
self.pdu_destination_tried: dict[str, dict[str, int]] = {}
self._clock.looping_call(self._clear_tried_cache, 60 * 1000)
self.state = hs.get_state_handler()
self.transport_layer = hs.get_federation_transport_client()
@@ -145,7 +142,7 @@ class FederationClient(FederationBase):
# Cache mapping `event_id` to a tuple of the event itself and the `pull_origin`
# (which server we pulled the event from)
self._get_pdu_cache: ExpiringCache[str, Tuple[EventBase, str]] = ExpiringCache(
self._get_pdu_cache: ExpiringCache[str, tuple[EventBase, str]] = ExpiringCache(
cache_name="get_pdu_cache",
server_name=self.server_name,
hs=self.hs,
@@ -163,8 +160,8 @@ class FederationClient(FederationBase):
# It is a map of (room ID, suggested-only) -> the response of
# get_room_hierarchy.
self._get_room_hierarchy_cache: ExpiringCache[
Tuple[str, bool],
Tuple[JsonDict, Sequence[JsonDict], Sequence[JsonDict], Sequence[str]],
tuple[str, bool],
tuple[JsonDict, Sequence[JsonDict], Sequence[JsonDict], Sequence[str]],
] = ExpiringCache(
cache_name="get_room_hierarchy_cache",
server_name=self.server_name,
@@ -265,7 +262,7 @@ class FederationClient(FederationBase):
self,
user: UserID,
destination: str,
query: Dict[str, Dict[str, Dict[str, int]]],
query: dict[str, dict[str, dict[str, int]]],
timeout: Optional[int],
) -> JsonDict:
"""Claims one-time keys for a device hosted on a remote server.
@@ -285,8 +282,8 @@ class FederationClient(FederationBase):
# Convert the query with counts into a stable and unstable query and check
# if attempting to claim more than 1 OTK.
content: Dict[str, Dict[str, str]] = {}
unstable_content: Dict[str, Dict[str, List[str]]] = {}
content: dict[str, dict[str, str]] = {}
unstable_content: dict[str, dict[str, list[str]]] = {}
use_unstable = False
for user_id, one_time_keys in query.items():
for device_id, algorithms in one_time_keys.items():
@@ -337,7 +334,7 @@ class FederationClient(FederationBase):
@tag_args
async def backfill(
self, dest: str, room_id: str, limit: int, extremities: Collection[str]
) -> Optional[List[EventBase]]:
) -> Optional[list[EventBase]]:
"""Requests some more historic PDUs for the given room from the
given destination server.
@@ -662,7 +659,7 @@ class FederationClient(FederationBase):
@tag_args
async def get_room_state_ids(
self, destination: str, room_id: str, event_id: str
) -> Tuple[List[str], List[str]]:
) -> tuple[list[str], list[str]]:
"""Calls the /state_ids endpoint to fetch the state at a particular point
in the room, and the auth events for the given event
@@ -711,7 +708,7 @@ class FederationClient(FederationBase):
room_id: str,
event_id: str,
room_version: RoomVersion,
) -> Tuple[List[EventBase], List[EventBase]]:
) -> tuple[list[EventBase], list[EventBase]]:
"""Calls the /state endpoint to fetch the state at a particular point
in the room.
@@ -772,7 +769,7 @@ class FederationClient(FederationBase):
origin: str,
pdus: Collection[EventBase],
room_version: RoomVersion,
) -> List[EventBase]:
) -> list[EventBase]:
"""
Checks the signatures and hashes of a list of pulled events we got from
federation and records any signature failures as failed pull attempts.
@@ -806,7 +803,7 @@ class FederationClient(FederationBase):
# We limit how many PDUs we check at once, as if we try to do hundreds
# of thousands of PDUs at once we see large memory spikes.
valid_pdus: List[EventBase] = []
valid_pdus: list[EventBase] = []
async def _record_failure_callback(event: EventBase, cause: str) -> None:
await self.store.record_event_failed_pull_attempt(
@@ -916,7 +913,7 @@ class FederationClient(FederationBase):
async def get_event_auth(
self, destination: str, room_id: str, event_id: str
) -> List[EventBase]:
) -> list[EventBase]:
res = await self.transport_layer.get_event_auth(destination, room_id, event_id)
room_version = await self.store.get_room_version(room_id)
@@ -1050,7 +1047,7 @@ class FederationClient(FederationBase):
membership: str,
content: dict,
params: Optional[Mapping[str, Union[str, Iterable[str]]]],
) -> Tuple[str, EventBase, RoomVersion]:
) -> tuple[str, EventBase, RoomVersion]:
"""
Creates an m.room.member event, with context, without participating in the room.
@@ -1092,7 +1089,7 @@ class FederationClient(FederationBase):
% (membership, ",".join(valid_memberships))
)
async def send_request(destination: str) -> Tuple[str, EventBase, RoomVersion]:
async def send_request(destination: str) -> tuple[str, EventBase, RoomVersion]:
ret = await self.transport_layer.make_membership_event(
destination, room_id, user_id, membership, params
)
@@ -1237,7 +1234,7 @@ class FederationClient(FederationBase):
# We now go and check the signatures and hashes for the event. Note
# that we limit how many events we process at a time to keep the
# memory overhead from exploding.
valid_pdus_map: Dict[str, EventBase] = {}
valid_pdus_map: dict[str, EventBase] = {}
async def _execute(pdu: EventBase) -> None:
valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
@@ -1507,7 +1504,7 @@ class FederationClient(FederationBase):
# content.
return resp[1]
async def send_knock(self, destinations: List[str], pdu: EventBase) -> JsonDict:
async def send_knock(self, destinations: list[str], pdu: EventBase) -> JsonDict:
"""Attempts to send a knock event to a given list of servers. Iterates
through the list until one attempt succeeds.
@@ -1568,7 +1565,7 @@ class FederationClient(FederationBase):
remote_server: str,
limit: Optional[int] = None,
since_token: Optional[str] = None,
search_filter: Optional[Dict] = None,
search_filter: Optional[dict] = None,
include_all_networks: bool = False,
third_party_instance_id: Optional[str] = None,
) -> JsonDict:
@@ -1612,7 +1609,7 @@ class FederationClient(FederationBase):
limit: int,
min_depth: int,
timeout: int,
) -> List[EventBase]:
) -> list[EventBase]:
"""Tries to fetch events we are missing. This is called when we receive
an event without having received all of its ancestors.
@@ -1718,7 +1715,7 @@ class FederationClient(FederationBase):
destinations: Iterable[str],
room_id: str,
suggested_only: bool,
) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[JsonDict], Sequence[str]]:
) -> tuple[JsonDict, Sequence[JsonDict], Sequence[JsonDict], Sequence[str]]:
"""
Call other servers to get a hierarchy of the given room.
@@ -1749,7 +1746,7 @@ class FederationClient(FederationBase):
async def send_request(
destination: str,
) -> Tuple[JsonDict, Sequence[JsonDict], Sequence[JsonDict], Sequence[str]]:
) -> tuple[JsonDict, Sequence[JsonDict], Sequence[JsonDict], Sequence[str]]:
try:
res = await self.transport_layer.get_room_hierarchy(
destination=destination,
@@ -1924,8 +1921,8 @@ class FederationClient(FederationBase):
raise InvalidResponseError(str(e))
async def get_account_status(
self, destination: str, user_ids: List[str]
) -> Tuple[JsonDict, List[str]]:
self, destination: str, user_ids: list[str]
) -> tuple[JsonDict, list[str]]:
"""Retrieves account statuses for a given list of users on a given remote
homeserver.
@@ -1991,8 +1988,8 @@ class FederationClient(FederationBase):
download_ratelimiter: Ratelimiter,
ip_address: str,
) -> Union[
Tuple[int, Dict[bytes, List[bytes]], bytes],
Tuple[int, Dict[bytes, List[bytes]]],
tuple[int, dict[bytes, list[bytes]], bytes],
tuple[int, dict[bytes, list[bytes]]],
]:
try:
return await self.transport_layer.federation_download_media(
@@ -2036,7 +2033,7 @@ class FederationClient(FederationBase):
max_timeout_ms: int,
download_ratelimiter: Ratelimiter,
ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
) -> tuple[int, dict[bytes, list[bytes]]]:
try:
return await self.transport_layer.download_media_v3(
destination,

View File

@@ -27,11 +27,8 @@ from typing import (
Awaitable,
Callable,
Collection,
Dict,
List,
Mapping,
Optional,
Tuple,
Union,
)
@@ -163,10 +160,10 @@ class FederationServer(FederationBase):
# origins that we are currently processing a transaction from.
# a dict from origin to txn id.
self._active_transactions: Dict[str, str] = {}
self._active_transactions: dict[str, str] = {}
# We cache results for transaction with the same ID
self._transaction_resp_cache: ResponseCache[Tuple[str, str]] = ResponseCache(
self._transaction_resp_cache: ResponseCache[tuple[str, str]] = ResponseCache(
clock=hs.get_clock(),
name="fed_txn_handler",
server_name=self.server_name,
@@ -179,7 +176,7 @@ class FederationServer(FederationBase):
# We cache responses to state queries, as they take a while and often
# come in waves.
self._state_resp_cache: ResponseCache[Tuple[str, Optional[str]]] = (
self._state_resp_cache: ResponseCache[tuple[str, Optional[str]]] = (
ResponseCache(
clock=hs.get_clock(),
name="state_resp",
@@ -187,7 +184,7 @@ class FederationServer(FederationBase):
timeout_ms=30000,
)
)
self._state_ids_resp_cache: ResponseCache[Tuple[str, str]] = ResponseCache(
self._state_ids_resp_cache: ResponseCache[tuple[str, str]] = ResponseCache(
clock=hs.get_clock(),
name="state_ids_resp",
server_name=self.server_name,
@@ -236,8 +233,8 @@ class FederationServer(FederationBase):
await self._clock.sleep(random.uniform(0, 0.1))
async def on_backfill_request(
self, origin: str, room_id: str, versions: List[str], limit: int
) -> Tuple[int, Dict[str, Any]]:
self, origin: str, room_id: str, versions: list[str], limit: int
) -> tuple[int, dict[str, Any]]:
async with self._server_linearizer.queue((origin, room_id)):
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)
@@ -252,7 +249,7 @@ class FederationServer(FederationBase):
async def on_timestamp_to_event_request(
self, origin: str, room_id: str, timestamp: int, direction: Direction
) -> Tuple[int, Dict[str, Any]]:
) -> tuple[int, dict[str, Any]]:
"""When we receive a federated `/timestamp_to_event` request,
handle all of the logic for validating and fetching the event.
@@ -298,7 +295,7 @@ class FederationServer(FederationBase):
transaction_id: str,
destination: str,
transaction_data: JsonDict,
) -> Tuple[int, JsonDict]:
) -> tuple[int, JsonDict]:
# If we receive a transaction we should make sure that kick off handling
# any old events in the staging area.
if not self._started_handling_of_staged_events:
@@ -365,7 +362,7 @@ class FederationServer(FederationBase):
async def _on_incoming_transaction_inner(
self, origin: str, transaction: Transaction, request_time: int
) -> Tuple[int, Dict[str, Any]]:
) -> tuple[int, dict[str, Any]]:
# CRITICAL SECTION: the first thing we must do (before awaiting) is
# add an entry to _active_transactions.
assert origin not in self._active_transactions
@@ -381,7 +378,7 @@ class FederationServer(FederationBase):
async def _handle_incoming_transaction(
self, origin: str, transaction: Transaction, request_time: int
) -> Tuple[int, Dict[str, Any]]:
) -> tuple[int, dict[str, Any]]:
"""Process an incoming transaction and return the HTTP response
Args:
@@ -429,7 +426,7 @@ class FederationServer(FederationBase):
async def _handle_pdus_in_txn(
self, origin: str, transaction: Transaction, request_time: int
) -> Dict[str, dict]:
) -> dict[str, dict]:
"""Process the PDUs in a received transaction.
Args:
@@ -448,7 +445,7 @@ class FederationServer(FederationBase):
origin_host, _ = parse_server_name(origin)
pdus_by_room: Dict[str, List[EventBase]] = {}
pdus_by_room: dict[str, list[EventBase]] = {}
newest_pdu_ts = 0
@@ -601,7 +598,7 @@ class FederationServer(FederationBase):
async def on_room_state_request(
self, origin: str, room_id: str, event_id: str
) -> Tuple[int, JsonDict]:
) -> tuple[int, JsonDict]:
await self._event_auth_handler.assert_host_in_room(room_id, origin)
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)
@@ -625,7 +622,7 @@ class FederationServer(FederationBase):
@tag_args
async def on_state_ids_request(
self, origin: str, room_id: str, event_id: str
) -> Tuple[int, JsonDict]:
) -> tuple[int, JsonDict]:
if not event_id:
raise NotImplementedError("Specify an event")
@@ -653,7 +650,7 @@ class FederationServer(FederationBase):
async def _on_context_state_request_compute(
self, room_id: str, event_id: str
) -> Dict[str, list]:
) -> dict[str, list]:
pdus: Collection[EventBase]
event_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
pdus = await self.store.get_events_as_list(event_ids)
@@ -669,7 +666,7 @@ class FederationServer(FederationBase):
async def on_pdu_request(
self, origin: str, event_id: str
) -> Tuple[int, Union[JsonDict, str]]:
) -> tuple[int, Union[JsonDict, str]]:
pdu = await self.handler.get_persisted_pdu(origin, event_id)
if pdu:
@@ -678,8 +675,8 @@ class FederationServer(FederationBase):
return 404, ""
async def on_query_request(
self, query_type: str, args: Dict[str, str]
) -> Tuple[int, Dict[str, Any]]:
self, query_type: str, args: dict[str, str]
) -> tuple[int, dict[str, Any]]:
received_queries_counter.labels(
type=query_type,
**{SERVER_NAME_LABEL: self.server_name},
@@ -688,8 +685,8 @@ class FederationServer(FederationBase):
return 200, resp
async def on_make_join_request(
self, origin: str, room_id: str, user_id: str, supported_versions: List[str]
) -> Dict[str, Any]:
self, origin: str, room_id: str, user_id: str, supported_versions: list[str]
) -> dict[str, Any]:
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)
@@ -714,7 +711,7 @@ class FederationServer(FederationBase):
async def on_invite_request(
self, origin: str, content: JsonDict, room_version_id: str
) -> Dict[str, Any]:
) -> dict[str, Any]:
room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
if not room_version:
raise SynapseError(
@@ -748,7 +745,7 @@ class FederationServer(FederationBase):
content: JsonDict,
room_id: str,
caller_supports_partial_state: bool = False,
) -> Dict[str, Any]:
) -> dict[str, Any]:
set_tag(
SynapseTags.SEND_JOIN_RESPONSE_IS_PARTIAL_STATE,
caller_supports_partial_state,
@@ -809,7 +806,7 @@ class FederationServer(FederationBase):
async def on_make_leave_request(
self, origin: str, room_id: str, user_id: str
) -> Dict[str, Any]:
) -> dict[str, Any]:
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)
pdu = await self.handler.on_make_leave_request(origin, room_id, user_id)
@@ -826,7 +823,7 @@ class FederationServer(FederationBase):
return {}
async def on_make_knock_request(
self, origin: str, room_id: str, user_id: str, supported_versions: List[str]
self, origin: str, room_id: str, user_id: str, supported_versions: list[str]
) -> JsonDict:
"""We've received a /make_knock/ request, so we create a partial knock
event for the room and hand that back, along with the room version, to the knocking
@@ -884,7 +881,7 @@ class FederationServer(FederationBase):
origin: str,
content: JsonDict,
room_id: str,
) -> Dict[str, List[JsonDict]]:
) -> dict[str, list[JsonDict]]:
"""
We have received a knock event for a room. Verify and send the event into the room
on the knocking homeserver's behalf. Then reply with some stripped state from the
@@ -1034,7 +1031,7 @@ class FederationServer(FederationBase):
async def on_event_auth(
self, origin: str, room_id: str, event_id: str
) -> Tuple[int, Dict[str, Any]]:
) -> tuple[int, dict[str, Any]]:
async with self._server_linearizer.queue((origin, room_id)):
await self._event_auth_handler.assert_host_in_room(room_id, origin)
origin_host, _ = parse_server_name(origin)
@@ -1046,20 +1043,20 @@ class FederationServer(FederationBase):
return 200, res
async def on_query_client_keys(
self, origin: str, content: Dict[str, str]
) -> Tuple[int, Dict[str, Any]]:
self, origin: str, content: dict[str, str]
) -> tuple[int, dict[str, Any]]:
return await self.on_query_request("client_keys", content)
async def on_query_user_devices(
self, origin: str, user_id: str
) -> Tuple[int, Dict[str, Any]]:
) -> tuple[int, dict[str, Any]]:
keys = await self.device_handler.on_federation_query_user_devices(user_id)
return 200, keys
@trace
async def on_claim_client_keys(
self, query: List[Tuple[str, str, str, int]], always_include_fallback_keys: bool
) -> Dict[str, Any]:
self, query: list[tuple[str, str, str, int]], always_include_fallback_keys: bool
) -> dict[str, Any]:
if any(
not self.hs.is_mine(UserID.from_string(user_id))
for user_id, _, _, _ in query
@@ -1071,7 +1068,7 @@ class FederationServer(FederationBase):
query, always_include_fallback_keys=always_include_fallback_keys
)
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
json_result: dict[str, dict[str, dict[str, JsonDict]]] = {}
for result in results:
for user_id, device_keys in result.items():
for device_id, keys in device_keys.items():
@@ -1098,10 +1095,10 @@ class FederationServer(FederationBase):
self,
origin: str,
room_id: str,
earliest_events: List[str],
latest_events: List[str],
earliest_events: list[str],
latest_events: list[str],
limit: int,
) -> Dict[str, list]:
) -> dict[str, list]:
async with self._server_linearizer.queue((origin, room_id)):
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)
@@ -1133,7 +1130,7 @@ class FederationServer(FederationBase):
ts_now_ms = self._clock.time_msec()
return await self.store.get_user_id_for_open_id_token(token, ts_now_ms)
def _transaction_dict_from_pdus(self, pdu_list: List[EventBase]) -> JsonDict:
def _transaction_dict_from_pdus(self, pdu_list: list[EventBase]) -> JsonDict:
"""Returns a new Transaction containing the given PDUs suitable for
transmission.
"""
@@ -1208,7 +1205,7 @@ class FederationServer(FederationBase):
async def _get_next_nonspam_staged_event_for_room(
self, room_id: str, room_version: RoomVersion
) -> Optional[Tuple[str, EventBase]]:
) -> Optional[tuple[str, EventBase]]:
"""Fetch the first non-spam event from staging queue.
Args:
@@ -1363,13 +1360,13 @@ class FederationServer(FederationBase):
lock = new_lock
async def exchange_third_party_invite(
self, sender_user_id: str, target_user_id: str, room_id: str, signed: Dict
self, sender_user_id: str, target_user_id: str, room_id: str, signed: dict
) -> None:
await self.handler.exchange_third_party_invite(
sender_user_id, target_user_id, room_id, signed
)
async def on_exchange_third_party_invite_request(self, event_dict: Dict) -> None:
async def on_exchange_third_party_invite_request(self, event_dict: dict) -> None:
await self.handler.on_exchange_third_party_invite_request(event_dict)
async def check_server_matches_acl(self, server_name: str, room_id: str) -> None:
@@ -1407,13 +1404,13 @@ class FederationHandlerRegistry:
# the case.
self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs)
self.edu_handlers: Dict[str, Callable[[str, dict], Awaitable[None]]] = {}
self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
self.edu_handlers: dict[str, Callable[[str, dict], Awaitable[None]]] = {}
self.query_handlers: dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
# Map from type to instance names that we should route EDU handling to.
# We randomly choose one instance from the list to route to for each new
# EDU received.
self._edu_type_to_instance: Dict[str, List[str]] = {}
self._edu_type_to_instance: dict[str, list[str]] = {}
def register_edu_handler(
self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]]
@@ -1455,7 +1452,7 @@ class FederationHandlerRegistry:
self.query_handlers[query_type] = handler
def register_instances_for_edu(
self, edu_type: str, instance_names: List[str]
self, edu_type: str, instance_names: list[str]
) -> None:
"""Register that the EDU handler is on multiple instances."""
self._edu_type_to_instance[edu_type] = instance_names

View File

@@ -27,7 +27,7 @@ These actions are mostly only used by the :py:mod:`.replication` module.
"""
import logging
from typing import Optional, Tuple
from typing import Optional
from synapse.federation.units import Transaction
from synapse.storage.databases.main import DataStore
@@ -44,7 +44,7 @@ class TransactionActions:
async def have_responded(
self, origin: str, transaction: Transaction
) -> Optional[Tuple[int, JsonDict]]:
) -> Optional[tuple[int, JsonDict]]:
"""Have we already responded to a transaction with the same id and
origin?

View File

@@ -40,14 +40,10 @@ import logging
from enum import Enum
from typing import (
TYPE_CHECKING,
Dict,
Hashable,
Iterable,
List,
Optional,
Sized,
Tuple,
Type,
)
import attr
@@ -77,7 +73,7 @@ class QueueNames(str, Enum):
PRESENCE_DESTINATIONS = "presence_destinations"
queue_name_to_gauge_map: Dict[QueueNames, LaterGauge] = {}
queue_name_to_gauge_map: dict[QueueNames, LaterGauge] = {}
for queue_name in QueueNames:
queue_name_to_gauge_map[queue_name] = LaterGauge(
@@ -100,23 +96,23 @@ class FederationRemoteSendQueue(AbstractFederationSender):
# We may have multiple federation sender instances, so we need to track
# their positions separately.
self._sender_instances = hs.config.worker.federation_shard_config.instances
self._sender_positions: Dict[str, int] = {}
self._sender_positions: dict[str, int] = {}
# Pending presence map user_id -> UserPresenceState
self.presence_map: Dict[str, UserPresenceState] = {}
self.presence_map: dict[str, UserPresenceState] = {}
# Stores the destinations we need to explicitly send presence to about a
# given user.
# Stream position -> (user_id, destinations)
self.presence_destinations: SortedDict[int, Tuple[str, Iterable[str]]] = (
self.presence_destinations: SortedDict[int, tuple[str, Iterable[str]]] = (
SortedDict()
)
# (destination, key) -> EDU
self.keyed_edu: Dict[Tuple[str, tuple], Edu] = {}
self.keyed_edu: dict[tuple[str, tuple], Edu] = {}
# stream position -> (destination, key)
self.keyed_edu_changed: SortedDict[int, Tuple[str, tuple]] = SortedDict()
self.keyed_edu_changed: SortedDict[int, tuple[str, tuple]] = SortedDict()
self.edus: SortedDict[int, Edu] = SortedDict()
@@ -295,7 +291,7 @@ class FederationRemoteSendQueue(AbstractFederationSender):
async def get_replication_rows(
self, instance_name: str, from_token: int, to_token: int, target_row_count: int
) -> Tuple[List[Tuple[int, Tuple]], int, bool]:
) -> tuple[list[tuple[int, tuple]], int, bool]:
"""Get rows to be sent over federation between the two tokens
Args:
@@ -318,7 +314,7 @@ class FederationRemoteSendQueue(AbstractFederationSender):
# list of tuple(int, BaseFederationRow), where the first is the position
# of the federation stream.
rows: List[Tuple[int, BaseFederationRow]] = []
rows: list[tuple[int, BaseFederationRow]] = []
# Fetch presence to send to destinations
i = self.presence_destinations.bisect_right(from_token)
@@ -413,7 +409,7 @@ class BaseFederationRow:
@attr.s(slots=True, frozen=True, auto_attribs=True)
class PresenceDestinationsRow(BaseFederationRow):
state: UserPresenceState
destinations: List[str]
destinations: list[str]
TypeId = "pd"
@@ -436,7 +432,7 @@ class KeyedEduRow(BaseFederationRow):
typing EDUs clobber based on room_id.
"""
key: Tuple[str, ...] # the edu key passed to send_edu
key: tuple[str, ...] # the edu key passed to send_edu
edu: Edu
TypeId = "k"
@@ -471,7 +467,7 @@ class EduRow(BaseFederationRow):
buff.edus.setdefault(self.edu.destination, []).append(self.edu)
_rowtypes: Tuple[Type[BaseFederationRow], ...] = (
_rowtypes: tuple[type[BaseFederationRow], ...] = (
PresenceDestinationsRow,
KeyedEduRow,
EduRow,
@@ -483,16 +479,16 @@ TypeToRow = {Row.TypeId: Row for Row in _rowtypes}
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ParsedFederationStreamData:
# list of tuples of UserPresenceState and destinations
presence_destinations: List[Tuple[UserPresenceState, List[str]]]
presence_destinations: list[tuple[UserPresenceState, list[str]]]
# dict of destination -> { key -> Edu }
keyed_edus: Dict[str, Dict[Tuple[str, ...], Edu]]
keyed_edus: dict[str, dict[tuple[str, ...], Edu]]
# dict of destination -> [Edu]
edus: Dict[str, List[Edu]]
edus: dict[str, list[Edu]]
async def process_rows_for_federation(
transaction_queue: FederationSender,
rows: List[FederationStream.FederationStreamRow],
rows: list[FederationStream.FederationStreamRow],
) -> None:
"""Parse a list of rows from the federation stream and put them in the
transaction queue ready for sending to the relevant homeservers.

View File

@@ -135,13 +135,10 @@ from collections import OrderedDict
from typing import (
TYPE_CHECKING,
Collection,
Dict,
Hashable,
Iterable,
List,
Literal,
Optional,
Tuple,
)
import attr
@@ -312,7 +309,7 @@ class AbstractFederationSender(metaclass=abc.ABCMeta):
@abc.abstractmethod
async def get_replication_rows(
self, instance_name: str, from_token: int, to_token: int, target_row_count: int
) -> Tuple[List[Tuple[int, Tuple]], int, bool]:
) -> tuple[list[tuple[int, tuple]], int, bool]:
raise NotImplementedError()
@@ -420,7 +417,7 @@ class FederationSender(AbstractFederationSender):
self._federation_shard_config = hs.config.worker.federation_shard_config
# map from destination to PerDestinationQueue
self._per_destination_queues: Dict[str, PerDestinationQueue] = {}
self._per_destination_queues: dict[str, PerDestinationQueue] = {}
transaction_queue_pending_destinations_gauge.register_hook(
homeserver_instance_id=hs.get_instance_id(),
@@ -724,7 +721,7 @@ class FederationSender(AbstractFederationSender):
**{SERVER_NAME_LABEL: self.server_name},
).observe((now - ts) / 1000)
async def handle_room_events(events: List[EventBase]) -> None:
async def handle_room_events(events: list[EventBase]) -> None:
logger.debug(
"Handling %i events in room %s", len(events), events[0].room_id
)
@@ -736,7 +733,7 @@ class FederationSender(AbstractFederationSender):
for event in events:
await handle_event(event)
events_by_room: Dict[str, List[EventBase]] = {}
events_by_room: dict[str, list[EventBase]] = {}
for event_id in event_ids:
# `event_entries` is unsorted, so we have to iterate over `event_ids`
@@ -1124,7 +1121,7 @@ class FederationSender(AbstractFederationSender):
@staticmethod
async def get_replication_rows(
instance_name: str, from_token: int, to_token: int, target_row_count: int
) -> Tuple[List[Tuple[int, Tuple]], int, bool]:
) -> tuple[list[tuple[int, tuple]], int, bool]:
# Dummy implementation for case where federation sender isn't offloaded
# to a worker.
return [], 0, False

View File

@@ -23,7 +23,7 @@ import datetime
import logging
from collections import OrderedDict
from types import TracebackType
from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, Type
from typing import TYPE_CHECKING, Hashable, Iterable, Optional
import attr
from prometheus_client import Counter
@@ -145,16 +145,16 @@ class PerDestinationQueue:
self._last_successful_stream_ordering: Optional[int] = None
# a queue of pending PDUs
self._pending_pdus: List[EventBase] = []
self._pending_pdus: list[EventBase] = []
# XXX this is never actually used: see
# https://github.com/matrix-org/synapse/issues/7549
self._pending_edus: List[Edu] = []
self._pending_edus: list[Edu] = []
# Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered
# based on their key (e.g. typing events by room_id)
# Map of (edu_type, key) -> Edu
self._pending_edus_keyed: Dict[Tuple[str, Hashable], Edu] = {}
self._pending_edus_keyed: dict[tuple[str, Hashable], Edu] = {}
# Map of user_id -> UserPresenceState of pending presence to be sent to this
# destination
@@ -164,7 +164,7 @@ class PerDestinationQueue:
#
# Each receipt can only have a single receipt per
# (room ID, receipt type, user ID, thread ID) tuple.
self._pending_receipt_edus: List[Dict[str, Dict[str, Dict[str, dict]]]] = []
self._pending_receipt_edus: list[dict[str, dict[str, dict[str, dict]]]] = []
# stream_id of last successfully sent to-device message.
# NB: may be a long or an int.
@@ -340,7 +340,7 @@ class PerDestinationQueue:
)
async def _transaction_transmission_loop(self) -> None:
pending_pdus: List[EventBase] = []
pending_pdus: list[EventBase] = []
try:
self.transmission_loop_running = True
# This will throw if we wouldn't retry. We do this here so we fail
@@ -665,12 +665,12 @@ class PerDestinationQueue:
if not self._pending_receipt_edus:
self._rrs_pending_flush = False
def _pop_pending_edus(self, limit: int) -> List[Edu]:
def _pop_pending_edus(self, limit: int) -> list[Edu]:
pending_edus = self._pending_edus
pending_edus, self._pending_edus = pending_edus[:limit], pending_edus[limit:]
return pending_edus
async def _get_device_update_edus(self, limit: int) -> Tuple[List[Edu], int]:
async def _get_device_update_edus(self, limit: int) -> tuple[list[Edu], int]:
last_device_list = self._last_device_list_stream_id
# Retrieve list of new device updates to send to the destination
@@ -691,7 +691,7 @@ class PerDestinationQueue:
return edus, now_stream_id
async def _get_to_device_message_edus(self, limit: int) -> Tuple[List[Edu], int]:
async def _get_to_device_message_edus(self, limit: int) -> tuple[list[Edu], int]:
last_device_stream_id = self._last_device_stream_id
to_device_stream_id = self._store.get_to_device_stream_token()
contents, stream_id = await self._store.get_new_device_msgs_for_remote(
@@ -745,9 +745,9 @@ class _TransactionQueueManager:
_device_stream_id: Optional[int] = None
_device_list_id: Optional[int] = None
_last_stream_ordering: Optional[int] = None
_pdus: List[EventBase] = attr.Factory(list)
_pdus: list[EventBase] = attr.Factory(list)
async def __aenter__(self) -> Tuple[List[EventBase], List[Edu]]:
async def __aenter__(self) -> tuple[list[EventBase], list[Edu]]:
# First we calculate the EDUs we want to send, if any.
# There's a maximum number of EDUs that can be sent with a transaction,
@@ -767,7 +767,7 @@ class _TransactionQueueManager:
if self.queue._pending_presence:
# Only send max 50 presence entries in the EDU, to bound the amount
# of data we're sending.
presence_to_add: List[JsonDict] = []
presence_to_add: list[JsonDict] = []
while (
self.queue._pending_presence
and len(presence_to_add) < MAX_PRESENCE_STATES_PER_EDU
@@ -845,7 +845,7 @@ class _TransactionQueueManager:
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_type: Optional[type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> None:

View File

@@ -18,7 +18,7 @@
#
#
import logging
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING
from prometheus_client import Gauge
@@ -82,8 +82,8 @@ class TransactionManager:
async def send_new_transaction(
self,
destination: str,
pdus: List[EventBase],
edus: List[Edu],
pdus: list[EventBase],
edus: list[Edu],
) -> None:
"""
Args:

View File

@@ -28,13 +28,10 @@ from typing import (
BinaryIO,
Callable,
Collection,
Dict,
Generator,
Iterable,
List,
Mapping,
Optional,
Tuple,
Union,
)
@@ -238,7 +235,7 @@ class TransportLayerClient:
async def timestamp_to_event(
self, destination: str, room_id: str, timestamp: int, direction: Direction
) -> Union[JsonDict, List]:
) -> Union[JsonDict, list]:
"""
Calls a remote federating server at `destination` asking for their
closest event to the given timestamp in the given direction.
@@ -428,7 +425,7 @@ class TransportLayerClient:
omit_members: bool,
) -> "SendJoinResponse":
path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
query_params: Dict[str, str] = {}
query_params: dict[str, str] = {}
# lazy-load state on join
query_params["omit_members"] = "true" if omit_members else "false"
@@ -442,7 +439,7 @@ class TransportLayerClient:
async def send_leave_v1(
self, destination: str, room_id: str, event_id: str, content: JsonDict
) -> Tuple[int, JsonDict]:
) -> tuple[int, JsonDict]:
path = _create_v1_path("/send_leave/%s/%s", room_id, event_id)
return await self.client.put_json(
@@ -508,7 +505,7 @@ class TransportLayerClient:
async def send_invite_v1(
self, destination: str, room_id: str, event_id: str, content: JsonDict
) -> Tuple[int, JsonDict]:
) -> tuple[int, JsonDict]:
path = _create_v1_path("/invite/%s/%s", room_id, event_id)
return await self.client.put_json(
@@ -533,7 +530,7 @@ class TransportLayerClient:
remote_server: str,
limit: Optional[int] = None,
since_token: Optional[str] = None,
search_filter: Optional[Dict] = None,
search_filter: Optional[dict] = None,
include_all_networks: bool = False,
third_party_instance_id: Optional[str] = None,
) -> JsonDict:
@@ -546,7 +543,7 @@ class TransportLayerClient:
if search_filter:
# this uses MSC2197 (Search Filtering over Federation)
data: Dict[str, Any] = {"include_all_networks": include_all_networks}
data: dict[str, Any] = {"include_all_networks": include_all_networks}
if third_party_instance_id:
data["third_party_instance_id"] = third_party_instance_id
if limit:
@@ -570,7 +567,7 @@ class TransportLayerClient:
)
raise
else:
args: Dict[str, Union[str, Iterable[str]]] = {
args: dict[str, Union[str, Iterable[str]]] = {
"include_all_networks": "true" if include_all_networks else "false"
}
if third_party_instance_id:
@@ -854,7 +851,7 @@ class TransportLayerClient:
)
async def get_account_status(
self, destination: str, user_ids: List[str]
self, destination: str, user_ids: list[str]
) -> JsonDict:
"""
Args:
@@ -878,7 +875,7 @@ class TransportLayerClient:
max_timeout_ms: int,
download_ratelimiter: Ratelimiter,
ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
) -> tuple[int, dict[bytes, list[bytes]]]:
path = f"/_matrix/media/r0/download/{destination}/{media_id}"
return await self.client.get_file(
destination,
@@ -905,7 +902,7 @@ class TransportLayerClient:
max_timeout_ms: int,
download_ratelimiter: Ratelimiter,
ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
) -> tuple[int, dict[bytes, list[bytes]]]:
path = f"/_matrix/media/v3/download/{destination}/{media_id}"
return await self.client.get_file(
destination,
@@ -936,7 +933,7 @@ class TransportLayerClient:
max_timeout_ms: int,
download_ratelimiter: Ratelimiter,
ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]], bytes]:
) -> tuple[int, dict[bytes, list[bytes]], bytes]:
path = f"/_matrix/federation/v1/media/download/{media_id}"
return await self.client.federation_get_file(
destination,
@@ -993,9 +990,9 @@ class SendJoinResponse:
"""The parsed response of a `/send_join` request."""
# The list of auth events from the /send_join response.
auth_events: List[EventBase]
auth_events: list[EventBase]
# The list of state from the /send_join response.
state: List[EventBase]
state: list[EventBase]
# The raw join event from the /send_join response.
event_dict: JsonDict
# The parsed join event from the /send_join response. This will be None if
@@ -1006,19 +1003,19 @@ class SendJoinResponse:
members_omitted: bool = False
# List of servers in the room
servers_in_room: Optional[List[str]] = None
servers_in_room: Optional[list[str]] = None
@attr.s(slots=True, auto_attribs=True)
class StateRequestResponse:
"""The parsed response of a `/state` request."""
auth_events: List[EventBase]
state: List[EventBase]
auth_events: list[EventBase]
state: list[EventBase]
@ijson.coroutine
def _event_parser(event_dict: JsonDict) -> Generator[None, Tuple[str, Any], None]:
def _event_parser(event_dict: JsonDict) -> Generator[None, tuple[str, Any], None]:
"""Helper function for use with `ijson.kvitems_coro` to parse key-value pairs
to add them to a given dictionary.
"""
@@ -1030,7 +1027,7 @@ def _event_parser(event_dict: JsonDict) -> Generator[None, Tuple[str, Any], None
@ijson.coroutine
def _event_list_parser(
room_version: RoomVersion, events: List[EventBase]
room_version: RoomVersion, events: list[EventBase]
) -> Generator[None, JsonDict, None]:
"""Helper function for use with `ijson.items_coro` to parse an array of
events and add them to the given list.
@@ -1086,7 +1083,7 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
def __init__(self, room_version: RoomVersion, v1_api: bool):
self._response = SendJoinResponse([], [], event_dict={})
self._room_version = room_version
self._coros: List[Generator[None, bytes, None]] = []
self._coros: list[Generator[None, bytes, None]] = []
# The V1 API has the shape of `[200, {...}]`, which we handle by
# prefixing with `item.*`.
@@ -1159,7 +1156,7 @@ class _StateParser(ByteParser[StateRequestResponse]):
def __init__(self, room_version: RoomVersion):
self._response = StateRequestResponse([], [])
self._room_version = room_version
self._coros: List[Generator[None, bytes, None]] = [
self._coros: list[Generator[None, bytes, None]] = [
ijson.items_coro(
_event_list_parser(room_version, self._response.state),
"pdus.item",

View File

@@ -20,7 +20,7 @@
#
#
import logging
from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Optional, Tuple, Type
from typing import TYPE_CHECKING, Iterable, Literal, Optional
from synapse.api.errors import FederationDeniedError, SynapseError
from synapse.federation.transport.server._base import (
@@ -52,7 +52,7 @@ logger = logging.getLogger(__name__)
class TransportLayerServer(JsonResource):
"""Handles incoming federation HTTP requests"""
def __init__(self, hs: "HomeServer", servlet_groups: Optional[List[str]] = None):
def __init__(self, hs: "HomeServer", servlet_groups: Optional[list[str]] = None):
"""Initialize the TransportLayerServer
Will by default register all servlets. For custom behaviour, pass in
@@ -130,8 +130,8 @@ class PublicRoomList(BaseFederationServlet):
self.allow_access = hs.config.server.allow_public_rooms_over_federation
async def on_GET(
self, origin: str, content: Literal[None], query: Dict[bytes, List[bytes]]
) -> Tuple[int, JsonDict]:
self, origin: str, content: Literal[None], query: dict[bytes, list[bytes]]
) -> tuple[int, JsonDict]:
if not self.allow_access:
raise FederationDeniedError(origin)
@@ -164,8 +164,8 @@ class PublicRoomList(BaseFederationServlet):
return 200, data
async def on_POST(
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
) -> Tuple[int, JsonDict]:
self, origin: str, content: JsonDict, query: dict[bytes, list[bytes]]
) -> tuple[int, JsonDict]:
# This implements MSC2197 (Search Filtering over Federation)
if not self.allow_access:
raise FederationDeniedError(origin)
@@ -242,8 +242,8 @@ class OpenIdUserInfo(BaseFederationServlet):
self,
origin: Optional[str],
content: Literal[None],
query: Dict[bytes, List[bytes]],
) -> Tuple[int, JsonDict]:
query: dict[bytes, list[bytes]],
) -> tuple[int, JsonDict]:
token = parse_string_from_args(query, "access_token")
if token is None:
return (
@@ -265,7 +265,7 @@ class OpenIdUserInfo(BaseFederationServlet):
return 200, {"sub": user_id}
SERVLET_GROUPS: Dict[str, Iterable[Type[BaseFederationServlet]]] = {
SERVLET_GROUPS: dict[str, Iterable[type[BaseFederationServlet]]] = {
"federation": FEDERATION_SERVLET_CLASSES,
"room_list": (PublicRoomList,),
"openid": (OpenIdUserInfo,),

View File

@@ -24,7 +24,7 @@ import logging
import re
import time
from http import HTTPStatus
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Tuple, cast
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, cast
from synapse.api.errors import Codes, FederationDeniedError, SynapseError
from synapse.api.urls import FEDERATION_V1_PREFIX
@@ -165,7 +165,7 @@ class Authenticator:
logger.exception("Error resetting retry timings on %s", origin)
def _parse_auth_header(header_bytes: bytes) -> Tuple[str, str, str, Optional[str]]:
def _parse_auth_header(header_bytes: bytes) -> tuple[str, str, str, Optional[str]]:
"""Parse an X-Matrix auth header
Args:
@@ -185,7 +185,7 @@ def _parse_auth_header(header_bytes: bytes) -> Tuple[str, str, str, Optional[str
rf"{space_or_tab}*,{space_or_tab}*",
re.split(r"^X-Matrix +", header_str, maxsplit=1)[1],
)
param_dict: Dict[str, str] = {
param_dict: dict[str, str] = {
k.lower(): v for k, v in [param.split("=", maxsplit=1) for param in params]
}
@@ -252,7 +252,7 @@ class BaseFederationServlet:
components as specified in the path match regexp.
Returns:
Optional[Tuple[int, object]]: either (response code, response object) to
Optional[tuple[int, object]]: either (response code, response object) to
return a JSON response, or None if the request has already been handled.
Raises:
@@ -282,14 +282,14 @@ class BaseFederationServlet:
self.ratelimiter = ratelimiter
self.server_name = server_name
def _wrap(self, func: Callable[..., Awaitable[Tuple[int, Any]]]) -> ServletCallback:
def _wrap(self, func: Callable[..., Awaitable[tuple[int, Any]]]) -> ServletCallback:
authenticator = self.authenticator
ratelimiter = self.ratelimiter
@functools.wraps(func)
async def new_func(
request: SynapseRequest, *args: Any, **kwargs: str
) -> Optional[Tuple[int, Any]]:
) -> Optional[tuple[int, Any]]:
"""A callback which can be passed to HttpServer.RegisterPaths
Args:

View File

@@ -22,14 +22,10 @@ import logging
from collections import Counter
from typing import (
TYPE_CHECKING,
Dict,
List,
Literal,
Mapping,
Optional,
Sequence,
Tuple,
Type,
Union,
)
@@ -93,9 +89,9 @@ class FederationSendServlet(BaseFederationServerServlet):
self,
origin: str,
content: JsonDict,
query: Dict[bytes, List[bytes]],
query: dict[bytes, list[bytes]],
transaction_id: str,
) -> Tuple[int, JsonDict]:
) -> tuple[int, JsonDict]:
"""Called on PUT /send/<transaction_id>/
Args:
@@ -158,9 +154,9 @@ class FederationEventServlet(BaseFederationServerServlet):
self,
origin: str,
content: Literal[None],
query: Dict[bytes, List[bytes]],
query: dict[bytes, list[bytes]],
event_id: str,
) -> Tuple[int, Union[JsonDict, str]]:
) -> tuple[int, Union[JsonDict, str]]:
return await self.handler.on_pdu_request(origin, event_id)
@@ -173,9 +169,9 @@ class FederationStateV1Servlet(BaseFederationServerServlet):
self,
origin: str,
content: Literal[None],
query: Dict[bytes, List[bytes]],
query: dict[bytes, list[bytes]],
room_id: str,
) -> Tuple[int, JsonDict]:
) -> tuple[int, JsonDict]:
return await self.handler.on_room_state_request(
origin,
room_id,
@@ -191,9 +187,9 @@ class FederationStateIdsServlet(BaseFederationServerServlet):
self,
origin: str,
content: Literal[None],
query: Dict[bytes, List[bytes]],
query: dict[bytes, list[bytes]],
room_id: str,
) -> Tuple[int, JsonDict]:
) -> tuple[int, JsonDict]:
return await self.handler.on_state_ids_request(
origin,
room_id,
@@ -209,9 +205,9 @@ class FederationBackfillServlet(BaseFederationServerServlet):
self,
origin: str,
content: Literal[None],
query: Dict[bytes, List[bytes]],
query: dict[bytes, list[bytes]],
room_id: str,
) -> Tuple[int, JsonDict]:
) -> tuple[int, JsonDict]:
versions = [x.decode("ascii") for x in query[b"v"]]
limit = parse_integer_from_args(query, "limit", None)
@@ -248,9 +244,9 @@ class FederationTimestampLookupServlet(BaseFederationServerServlet):
self,
origin: str,
content: Literal[None],
query: Dict[bytes, List[bytes]],
query: dict[bytes, list[bytes]],
room_id: str,
) -> Tuple[int, JsonDict]:
) -> tuple[int, JsonDict]:
timestamp = parse_integer_from_args(query, "ts", required=True)
direction_str = parse_string_from_args(
query, "dir", allowed_values=["f", "b"], required=True
@@ -271,9 +267,9 @@ class FederationQueryServlet(BaseFederationServerServlet):
self,
origin: str,
content: Literal[None],
query: Dict[bytes, List[bytes]],
query: dict[bytes, list[bytes]],
query_type: str,
) -> Tuple[int, JsonDict]:
) -> tuple[int, JsonDict]:
args = {k.decode("utf8"): v[0].decode("utf-8") for k, v in query.items()}
args["origin"] = origin
return await self.handler.on_query_request(query_type, args)
@@ -287,10 +283,10 @@ class FederationMakeJoinServlet(BaseFederationServerServlet):
self,
origin: str,
content: Literal[None],
query: Dict[bytes, List[bytes]],
query: dict[bytes, list[bytes]],
room_id: str,
user_id: str,
) -> Tuple[int, JsonDict]:
) -> tuple[int, JsonDict]:
"""
Args:
origin: The authenticated server_name of the calling server
@@ -323,10 +319,10 @@ class FederationMakeLeaveServlet(BaseFederationServerServlet):
self,
origin: str,
content: Literal[None],
query: Dict[bytes, List[bytes]],
query: dict[bytes, list[bytes]],
room_id: str,
user_id: str,
) -> Tuple[int, JsonDict]:
) -> tuple[int, JsonDict]:
result = await self.handler.on_make_leave_request(origin, room_id, user_id)
return 200, result
@@ -339,10 +335,10 @@ class FederationV1SendLeaveServlet(BaseFederationServerServlet):
self,
origin: str,
content: JsonDict,
query: Dict[bytes, List[bytes]],
query: dict[bytes, list[bytes]],
room_id: str,
event_id: str,
) -> Tuple[int, Tuple[int, JsonDict]]:
) -> tuple[int, tuple[int, JsonDict]]:
result = await self.handler.on_send_leave_request(origin, content, room_id)
return 200, (200, result)
@@ -357,10 +353,10 @@ class FederationV2SendLeaveServlet(BaseFederationServerServlet):
self,
origin: str,
content: JsonDict,
query: Dict[bytes, List[bytes]],
query: dict[bytes, list[bytes]],
room_id: str,
event_id: str,
) -> Tuple[int, JsonDict]:
) -> tuple[int, JsonDict]:
result = await self.handler.on_send_leave_request(origin, content, room_id)
return 200, result
@@ -373,10 +369,10 @@ class FederationMakeKnockServlet(BaseFederationServerServlet):
self,
origin: str,
content: Literal[None],
query: Dict[bytes, List[bytes]],
query: dict[bytes, list[bytes]],
room_id: str,
user_id: str,
) -> Tuple[int, JsonDict]:
) -> tuple[int, JsonDict]:
# Retrieve the room versions the remote homeserver claims to support
supported_versions = parse_strings_from_args(
query, "ver", required=True, encoding="utf-8"
@@ -396,10 +392,10 @@ class FederationV1SendKnockServlet(BaseFederationServerServlet):
self,
origin: str,
content: JsonDict,
query: Dict[bytes, List[bytes]],
query: dict[bytes, list[bytes]],
room_id: str,
event_id: str,
) -> Tuple[int, JsonDict]:
) -> tuple[int, JsonDict]:
result = await self.handler.on_send_knock_request(origin, content, room_id)
return 200, result
@@ -412,10 +408,10 @@ class FederationEventAuthServlet(BaseFederationServerServlet):
self,
origin: str,
content: Literal[None],
query: Dict[bytes, List[bytes]],
query: dict[bytes, list[bytes]],
room_id: str,
event_id: str,
) -> Tuple[int, JsonDict]:
) -> tuple[int, JsonDict]:
return await self.handler.on_event_auth(origin, room_id, event_id)
@@ -427,10 +423,10 @@ class FederationV1SendJoinServlet(BaseFederationServerServlet):
self,
origin: str,
content: JsonDict,
query: Dict[bytes, List[bytes]],
query: dict[bytes, list[bytes]],
room_id: str,
event_id: str,
) -> Tuple[int, Tuple[int, JsonDict]]:
) -> tuple[int, tuple[int, JsonDict]]:
# TODO(paul): assert that event_id parsed from path actually
# match those given in content
result = await self.handler.on_send_join_request(origin, content, room_id)
@@ -447,10 +443,10 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet):
self,
origin: str,
content: JsonDict,
query: Dict[bytes, List[bytes]],
query: dict[bytes, list[bytes]],
room_id: str,
event_id: str,
) -> Tuple[int, JsonDict]:
) -> tuple[int, JsonDict]:
# TODO(paul): assert that event_id parsed from path actually
# match those given in content
@@ -470,10 +466,10 @@ class FederationV1InviteServlet(BaseFederationServerServlet):
self,
origin: str,
content: JsonDict,
query: Dict[bytes, List[bytes]],
query: dict[bytes, list[bytes]],
room_id: str,
event_id: str,
) -> Tuple[int, Tuple[int, JsonDict]]:
) -> tuple[int, tuple[int, JsonDict]]:
# We don't get a room version, so we have to assume its EITHER v1 or
# v2. This is "fine" as the only difference between V1 and V2 is the
# state resolution algorithm, and we don't use that for processing
@@ -497,10 +493,10 @@ class FederationV2InviteServlet(BaseFederationServerServlet):
self,
origin: str,
content: JsonDict,
query: Dict[bytes, List[bytes]],
query: dict[bytes, list[bytes]],
room_id: str,
event_id: str,
) -> Tuple[int, JsonDict]:
) -> tuple[int, JsonDict]:
# TODO(paul): assert that room_id/event_id parsed from path actually
# match those given in content
@@ -535,9 +531,9 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServerServlet):
self,
origin: str,
content: JsonDict,
query: Dict[bytes, List[bytes]],
query: dict[bytes, list[bytes]],
room_id: str,
) -> Tuple[int, JsonDict]:
) -> tuple[int, JsonDict]:
await self.handler.on_exchange_third_party_invite_request(content)
return 200, {}
@@ -547,8 +543,8 @@ class FederationClientKeysQueryServlet(BaseFederationServerServlet):
CATEGORY = "Federation requests"
async def on_POST(
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
) -> Tuple[int, JsonDict]:
self, origin: str, content: JsonDict, query: dict[bytes, list[bytes]]
) -> tuple[int, JsonDict]:
return await self.handler.on_query_client_keys(origin, content)
@@ -560,9 +556,9 @@ class FederationUserDevicesQueryServlet(BaseFederationServerServlet):
self,
origin: str,
content: Literal[None],
query: Dict[bytes, List[bytes]],
query: dict[bytes, list[bytes]],
user_id: str,
) -> Tuple[int, JsonDict]:
) -> tuple[int, JsonDict]:
return await self.handler.on_query_user_devices(origin, user_id)
@@ -571,10 +567,10 @@ class FederationClientKeysClaimServlet(BaseFederationServerServlet):
CATEGORY = "Federation requests"
async def on_POST(
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
) -> Tuple[int, JsonDict]:
self, origin: str, content: JsonDict, query: dict[bytes, list[bytes]]
) -> tuple[int, JsonDict]:
# Generate a count for each algorithm, which is hard-coded to 1.
key_query: List[Tuple[str, str, str, int]] = []
key_query: list[tuple[str, str, str, int]] = []
for user_id, device_keys in content.get("one_time_keys", {}).items():
for device_id, algorithm in device_keys.items():
key_query.append((user_id, device_id, algorithm, 1))
@@ -597,10 +593,10 @@ class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet):
CATEGORY = "Federation requests"
async def on_POST(
self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]]
) -> Tuple[int, JsonDict]:
self, origin: str, content: JsonDict, query: dict[bytes, list[bytes]]
) -> tuple[int, JsonDict]:
# Generate a count for each algorithm.
key_query: List[Tuple[str, str, str, int]] = []
key_query: list[tuple[str, str, str, int]] = []
for user_id, device_keys in content.get("one_time_keys", {}).items():
for device_id, algorithms in device_keys.items():
counts = Counter(algorithms)
@@ -621,9 +617,9 @@ class FederationGetMissingEventsServlet(BaseFederationServerServlet):
self,
origin: str,
content: JsonDict,
query: Dict[bytes, List[bytes]],
query: dict[bytes, list[bytes]],
room_id: str,
) -> Tuple[int, JsonDict]:
) -> tuple[int, JsonDict]:
limit = int(content.get("limit", 10))
earliest_events = content.get("earliest_events", [])
latest_events = content.get("latest_events", [])
@@ -646,8 +642,8 @@ class On3pidBindServlet(BaseFederationServerServlet):
REQUIRE_AUTH = False
async def on_POST(
self, origin: Optional[str], content: JsonDict, query: Dict[bytes, List[bytes]]
) -> Tuple[int, JsonDict]:
self, origin: Optional[str], content: JsonDict, query: dict[bytes, list[bytes]]
) -> tuple[int, JsonDict]:
if "invites" in content:
last_exception = None
for invite in content["invites"]:
@@ -682,8 +678,8 @@ class FederationVersionServlet(BaseFederationServlet):
self,
origin: Optional[str],
content: Literal[None],
query: Dict[bytes, List[bytes]],
) -> Tuple[int, JsonDict]:
query: dict[bytes, list[bytes]],
) -> tuple[int, JsonDict]:
return (
200,
{
@@ -715,7 +711,7 @@ class FederationRoomHierarchyServlet(BaseFederationServlet):
content: Literal[None],
query: Mapping[bytes, Sequence[bytes]],
room_id: str,
) -> Tuple[int, JsonDict]:
) -> tuple[int, JsonDict]:
suggested_only = parse_boolean_from_args(query, "suggested_only", default=False)
return 200, await self.handler.get_federation_hierarchy(
origin, room_id, suggested_only
@@ -746,9 +742,9 @@ class RoomComplexityServlet(BaseFederationServlet):
self,
origin: str,
content: Literal[None],
query: Dict[bytes, List[bytes]],
query: dict[bytes, list[bytes]],
room_id: str,
) -> Tuple[int, JsonDict]:
) -> tuple[int, JsonDict]:
is_public = await self._store.is_room_world_readable_or_publicly_joinable(
room_id
)
@@ -780,7 +776,7 @@ class FederationAccountStatusServlet(BaseFederationServerServlet):
content: JsonDict,
query: Mapping[bytes, Sequence[bytes]],
room_id: str,
) -> Tuple[int, JsonDict]:
) -> tuple[int, JsonDict]:
if "user_ids" not in content:
raise SynapseError(
400, "Required parameter 'user_ids' is missing", Codes.MISSING_PARAM
@@ -882,7 +878,7 @@ class FederationMediaThumbnailServlet(BaseFederationServerServlet):
self.media_repo.mark_recently_accessed(None, media_id)
FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
FEDERATION_SERVLET_CLASSES: tuple[type[BaseFederationServlet], ...] = (
FederationSendServlet,
FederationEventServlet,
FederationStateV1Servlet,

View File

@@ -24,7 +24,7 @@ server protocol.
"""
import logging
from typing import List, Optional, Sequence
from typing import Optional, Sequence
import attr
@@ -70,7 +70,7 @@ class Edu:
getattr(self, "content", {})["org.matrix.opentracing_context"] = "{}"
def _none_to_list(edus: Optional[List[JsonDict]]) -> List[JsonDict]:
def _none_to_list(edus: Optional[list[JsonDict]]) -> list[JsonDict]:
if edus is None:
return []
return edus
@@ -98,8 +98,8 @@ class Transaction:
origin: str
destination: str
origin_server_ts: int
pdus: List[JsonDict] = attr.ib(factory=list, converter=_none_to_list)
edus: List[JsonDict] = attr.ib(factory=list, converter=_none_to_list)
pdus: list[JsonDict] = attr.ib(factory=list, converter=_none_to_list)
edus: list[JsonDict] = attr.ib(factory=list, converter=_none_to_list)
def get_dict(self) -> JsonDict:
"""A JSON-ready dictionary of valid keys which aren't internal."""
@@ -113,7 +113,7 @@ class Transaction:
return result
def filter_pdus_for_valid_depth(pdus: Sequence[JsonDict]) -> List[JsonDict]:
def filter_pdus_for_valid_depth(pdus: Sequence[JsonDict]) -> list[JsonDict]:
filtered_pdus = []
for pdu in pdus:
# Drop PDUs that have a depth that is outside of the range allowed
@@ -129,5 +129,5 @@ def filter_pdus_for_valid_depth(pdus: Sequence[JsonDict]) -> List[JsonDict]:
def serialize_and_filter_pdus(
pdus: Sequence[EventBase], time_now: Optional[int] = None
) -> List[JsonDict]:
) -> list[JsonDict]:
return filter_pdus_for_valid_depth([pdu.get_pdu_json(time_now) for pdu in pdus])

View File

@@ -19,7 +19,7 @@
#
#
from typing import TYPE_CHECKING, Dict, List, Tuple
from typing import TYPE_CHECKING
from synapse.api.errors import Codes, SynapseError
from synapse.types import JsonDict, UserID
@@ -40,9 +40,9 @@ class AccountHandler:
async def get_account_statuses(
self,
user_ids: List[str],
user_ids: list[str],
allow_remote: bool,
) -> Tuple[JsonDict, List[str]]:
) -> tuple[JsonDict, list[str]]:
"""Get account statuses for a list of user IDs.
If one or more account(s) belong to remote homeservers, retrieve their status(es)
@@ -63,7 +63,7 @@ class AccountHandler:
"""
statuses = {}
failures = []
remote_users: List[UserID] = []
remote_users: list[UserID] = []
for raw_user_id in user_ids:
try:
@@ -127,8 +127,8 @@ class AccountHandler:
return status
async def _get_remote_account_statuses(
self, remote_users: List[UserID]
) -> Tuple[JsonDict, List[str]]:
self, remote_users: list[UserID]
) -> tuple[JsonDict, list[str]]:
"""Send out federation requests to retrieve the statuses of remote accounts.
Args:
@@ -140,7 +140,7 @@ class AccountHandler:
"""
# Group remote users by destination, so we only send one request per remote
# homeserver.
by_destination: Dict[str, List[str]] = {}
by_destination: dict[str, list[str]] = {}
for user in remote_users:
if user.domain not in by_destination:
by_destination[user.domain] = []
@@ -149,7 +149,7 @@ class AccountHandler:
# Retrieve the statuses and failures for remote accounts.
final_statuses: JsonDict = {}
final_failures: List[str] = []
final_failures: list[str] = []
for destination, users in by_destination.items():
statuses, failures = await self._federation_client.get_account_status(
destination,

View File

@@ -21,7 +21,7 @@
#
import logging
import random
from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple
from typing import TYPE_CHECKING, Awaitable, Callable, Optional
from synapse.api.constants import AccountDataTypes
from synapse.replication.http.account_data import (
@@ -67,7 +67,7 @@ class AccountDataHandler:
self._remove_tag_client = ReplicationRemoveTagRestServlet.make_client(hs)
self._account_data_writers = hs.config.worker.writers.account_data
self._on_account_data_updated_callbacks: List[
self._on_account_data_updated_callbacks: list[
ON_ACCOUNT_DATA_UPDATED_CALLBACK
] = []
@@ -325,7 +325,7 @@ class AccountDataEventSource(EventSource[int, JsonDict]):
room_ids: StrCollection,
is_guest: bool,
explicit_room_id: Optional[str] = None,
) -> Tuple[List[JsonDict], int]:
) -> tuple[list[JsonDict], int]:
user_id = user.to_string()
last_stream_id = from_key

View File

@@ -21,7 +21,7 @@
import email.mime.multipart
import email.utils
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING, Optional
from synapse.api.errors import AuthError, StoreError, SynapseError
from synapse.metrics.background_process_metrics import wrap_as_background_process
@@ -222,7 +222,7 @@ class AccountValidityHandler:
await self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
async def _get_email_addresses_for_user(self, user_id: str) -> List[str]:
async def _get_email_addresses_for_user(self, user_id: str) -> list[str]:
"""Retrieve the list of email addresses attached to a user's account.
Args:
@@ -263,7 +263,7 @@ class AccountValidityHandler:
attempts += 1
raise StoreError(500, "Couldn't generate a unique string as refresh string.")
async def renew_account(self, renewal_token: str) -> Tuple[bool, bool, int]:
async def renew_account(self, renewal_token: str) -> tuple[bool, bool, int]:
"""Renews the account attached to a given renewal token by pushing back the
expiration date by the current validity period in the server's configuration.

View File

@@ -24,13 +24,9 @@ import logging
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
)
import attr
@@ -218,7 +214,7 @@ class AdminHandler:
to_key = RoomStreamToken(stream=stream_ordering)
# Events that we've processed in this room
written_events: Set[str] = set()
written_events: set[str] = set()
# We need to track gaps in the events stream so that we can then
# write out the state at those events. We do this by keeping track
@@ -231,7 +227,7 @@ class AdminHandler:
# The reverse mapping to above, i.e. map from unseen event to events
# that have the unseen event in their prev_events, i.e. the unseen
# events "children".
unseen_to_child_events: Dict[str, Set[str]] = {}
unseen_to_child_events: dict[str, set[str]] = {}
# We fetch events in the room the user could see by fetching *all*
# events that we have and then filtering, this isn't the most
@@ -412,7 +408,7 @@ class AdminHandler:
async def _redact_all_events(
self, task: ScheduledTask
) -> Tuple[TaskStatus, Optional[Mapping[str, Any]], Optional[str]]:
) -> tuple[TaskStatus, Optional[Mapping[str, Any]], Optional[str]]:
"""
Task to redact all of a users events in the given rooms, tracking which, if any, events
whose redaction failed
@@ -518,7 +514,7 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
"""Interface used to specify how to write exported data."""
@abc.abstractmethod
def write_events(self, room_id: str, events: List[EventBase]) -> None:
def write_events(self, room_id: str, events: list[EventBase]) -> None:
"""Write a batch of events for a room."""
raise NotImplementedError()

View File

@@ -22,12 +22,9 @@ import logging
from typing import (
TYPE_CHECKING,
Collection,
Dict,
Iterable,
List,
Mapping,
Optional,
Tuple,
Union,
)
@@ -143,7 +140,7 @@ class ApplicationServicesHandler:
event_to_received_ts.keys(), get_prev_content=True
)
events_by_room: Dict[str, List[EventBase]] = {}
events_by_room: dict[str, list[EventBase]] = {}
for event in events:
events_by_room.setdefault(event.room_id, []).append(event)
@@ -341,7 +338,7 @@ class ApplicationServicesHandler:
@wrap_as_background_process("notify_interested_services_ephemeral")
async def _notify_interested_services_ephemeral(
self,
services: List[ApplicationService],
services: list[ApplicationService],
stream_key: StreamKeyType,
new_token: Union[int, MultiWriterStreamToken],
users: Collection[Union[str, UserID]],
@@ -429,7 +426,7 @@ class ApplicationServicesHandler:
async def _handle_typing(
self, service: ApplicationService, new_token: int
) -> List[JsonMapping]:
) -> list[JsonMapping]:
"""
Return the typing events since the given stream token that the given application
service should receive.
@@ -464,7 +461,7 @@ class ApplicationServicesHandler:
async def _handle_receipts(
self, service: ApplicationService, new_token: MultiWriterStreamToken
) -> List[JsonMapping]:
) -> list[JsonMapping]:
"""
Return the latest read receipts that the given application service should receive.
@@ -503,7 +500,7 @@ class ApplicationServicesHandler:
service: ApplicationService,
users: Collection[Union[str, UserID]],
new_token: Optional[int],
) -> List[JsonMapping]:
) -> list[JsonMapping]:
"""
Return the latest presence updates that the given application service should receive.
@@ -523,7 +520,7 @@ class ApplicationServicesHandler:
A list of json dictionaries containing data derived from the presence events
that should be sent to the given application service.
"""
events: List[JsonMapping] = []
events: list[JsonMapping] = []
presence_source = self.event_sources.sources.presence
from_key = await self.store.get_type_stream_id_for_appservice(
service, "presence"
@@ -563,7 +560,7 @@ class ApplicationServicesHandler:
service: ApplicationService,
new_token: int,
users: Collection[Union[str, UserID]],
) -> List[JsonDict]:
) -> list[JsonDict]:
"""
Given an application service, determine which events it should receive
from those between the last-recorded to-device message stream token for this
@@ -585,7 +582,7 @@ class ApplicationServicesHandler:
)
# Filter out users that this appservice is not interested in
users_appservice_is_interested_in: List[str] = []
users_appservice_is_interested_in: list[str] = []
for user in users:
# FIXME: We should do this farther up the call stack. We currently repeat
# this operation in _handle_presence.
@@ -612,7 +609,7 @@ class ApplicationServicesHandler:
#
# So we mangle this dict into a flat list of to-device messages with the relevant
# user ID and device ID embedded inside each message dict.
message_payload: List[JsonDict] = []
message_payload: list[JsonDict] = []
for (
user_id,
device_id,
@@ -761,8 +758,8 @@ class ApplicationServicesHandler:
return None
async def query_3pe(
self, kind: str, protocol: str, fields: Dict[bytes, List[bytes]]
) -> List[JsonDict]:
self, kind: str, protocol: str, fields: dict[bytes, list[bytes]]
) -> list[JsonDict]:
services = self._get_services_for_3pn(protocol)
results = await make_deferred_yieldable(
@@ -786,9 +783,9 @@ class ApplicationServicesHandler:
async def get_3pe_protocols(
self, only_protocol: Optional[str] = None
) -> Dict[str, JsonDict]:
) -> dict[str, JsonDict]:
services = self.store.get_app_services()
protocols: Dict[str, List[JsonDict]] = {}
protocols: dict[str, list[JsonDict]] = {}
# Collect up all the individual protocol responses out of the ASes
for s in services:
@@ -804,7 +801,7 @@ class ApplicationServicesHandler:
if info is not None:
protocols[p].append(info)
def _merge_instances(infos: List[JsonDict]) -> JsonDict:
def _merge_instances(infos: list[JsonDict]) -> JsonDict:
# Merge the 'instances' lists of multiple results, but just take
# the other fields from the first as they ought to be identical
# copy the result so as not to corrupt the cached one
@@ -822,7 +819,7 @@ class ApplicationServicesHandler:
async def _get_services_for_event(
self, event: EventBase
) -> List[ApplicationService]:
) -> list[ApplicationService]:
"""Retrieve a list of application services interested in this event.
Args:
@@ -842,11 +839,11 @@ class ApplicationServicesHandler:
return interested_list
def _get_services_for_user(self, user_id: str) -> List[ApplicationService]:
def _get_services_for_user(self, user_id: str) -> list[ApplicationService]:
services = self.store.get_app_services()
return [s for s in services if (s.is_interested_in_user(user_id))]
def _get_services_for_3pn(self, protocol: str) -> List[ApplicationService]:
def _get_services_for_3pn(self, protocol: str) -> list[ApplicationService]:
services = self.store.get_app_services()
return [s for s in services if s.is_interested_in_protocol(protocol)]
@@ -872,9 +869,9 @@ class ApplicationServicesHandler:
return True
async def claim_e2e_one_time_keys(
self, query: Iterable[Tuple[str, str, str, int]]
) -> Tuple[
Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
self, query: Iterable[tuple[str, str, str, int]]
) -> tuple[
dict[str, dict[str, dict[str, JsonDict]]], list[tuple[str, str, str, int]]
]:
"""Claim one time keys from application services.
@@ -896,7 +893,7 @@ class ApplicationServicesHandler:
services = self.store.get_app_services()
# Partition the users by appservice.
query_by_appservice: Dict[str, List[Tuple[str, str, str, int]]] = {}
query_by_appservice: dict[str, list[tuple[str, str, str, int]]] = {}
missing = []
for user_id, device, algorithm, count in query:
if not self.store.get_if_app_services_interested_in_user(user_id):
@@ -929,7 +926,7 @@ class ApplicationServicesHandler:
# Patch together the results -- they are all independent (since they
# require exclusive control over the users, which is the outermost key).
claimed_keys: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
claimed_keys: dict[str, dict[str, dict[str, JsonDict]]] = {}
for success, result in results:
if success:
claimed_keys.update(result[0])
@@ -938,8 +935,8 @@ class ApplicationServicesHandler:
return claimed_keys, missing
async def query_keys(
self, query: Mapping[str, Optional[List[str]]]
) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
self, query: Mapping[str, Optional[list[str]]]
) -> dict[str, dict[str, dict[str, JsonDict]]]:
"""Query application services for device keys.
Users which are exclusively owned by an application service are queried
@@ -954,7 +951,7 @@ class ApplicationServicesHandler:
services = self.store.get_app_services()
# Partition the users by appservice.
query_by_appservice: Dict[str, Dict[str, List[str]]] = {}
query_by_appservice: dict[str, dict[str, list[str]]] = {}
for user_id, device_ids in query.items():
if not self.store.get_if_app_services_interested_in_user(user_id):
continue
@@ -986,7 +983,7 @@ class ApplicationServicesHandler:
# Patch together the results -- they are all independent (since they
# require exclusive control over the users). They get returned as a single
# dictionary.
key_queries: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
key_queries: dict[str, dict[str, dict[str, JsonDict]]] = {}
for success, result in results:
if success:
key_queries.update(result)

View File

@@ -31,13 +31,9 @@ from typing import (
Any,
Awaitable,
Callable,
Dict,
Iterable,
List,
Mapping,
Optional,
Tuple,
Type,
Union,
cast,
)
@@ -102,7 +98,7 @@ invalid_login_token_counter = Counter(
def convert_client_dict_legacy_fields_to_identifier(
submission: JsonDict,
) -> Dict[str, str]:
) -> dict[str, str]:
"""
Convert a legacy-formatted login submission to an identifier dict.
@@ -154,7 +150,7 @@ def convert_client_dict_legacy_fields_to_identifier(
return identifier
def login_id_phone_to_thirdparty(identifier: JsonDict) -> Dict[str, str]:
def login_id_phone_to_thirdparty(identifier: JsonDict) -> dict[str, str]:
"""
Convert a phone login identifier type to a generic threepid identifier.
@@ -205,7 +201,7 @@ class AuthHandler:
self.auth = hs.get_auth()
self.auth_blocking = hs.get_auth_blocking()
self.clock = hs.get_clock()
self.checkers: Dict[str, UserInteractiveAuthChecker] = {}
self.checkers: dict[str, UserInteractiveAuthChecker] = {}
for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
inst = auth_checker_class(hs)
if inst.is_enabled():
@@ -280,7 +276,7 @@ class AuthHandler:
# A mapping of user ID to extra attributes to include in the login
# response.
self._extra_attributes: Dict[str, SsoLoginExtraAttributes] = {}
self._extra_attributes: dict[str, SsoLoginExtraAttributes] = {}
self._auth_delegation_enabled = (
hs.config.mas.enabled or hs.config.experimental.msc3861.enabled
@@ -290,10 +286,10 @@ class AuthHandler:
self,
requester: Requester,
request: SynapseRequest,
request_body: Dict[str, Any],
request_body: dict[str, Any],
description: str,
can_skip_ui_auth: bool = False,
) -> Tuple[dict, Optional[str]]:
) -> tuple[dict, Optional[str]]:
"""
Checks that the user is who they claim to be, via a UI auth.
@@ -440,12 +436,12 @@ class AuthHandler:
async def check_ui_auth(
self,
flows: List[List[str]],
flows: list[list[str]],
request: SynapseRequest,
clientdict: Dict[str, Any],
clientdict: dict[str, Any],
description: str,
get_new_session_data: Optional[Callable[[], JsonDict]] = None,
) -> Tuple[dict, dict, str]:
) -> tuple[dict, dict, str]:
"""
Takes a dictionary sent by the client in the login / registration
protocol and handles the User-Interactive Auth flow.
@@ -579,7 +575,7 @@ class AuthHandler:
)
# check auth type currently being presented
errordict: Dict[str, Any] = {}
errordict: dict[str, Any] = {}
if "type" in authdict:
login_type: str = authdict["type"]
try:
@@ -617,7 +613,7 @@ class AuthHandler:
raise InteractiveAuthIncompleteError(session.session_id, ret)
async def add_oob_auth(
self, stagetype: str, authdict: Dict[str, Any], clientip: str
self, stagetype: str, authdict: dict[str, Any], clientip: str
) -> None:
"""
Adds the result of out-of-band authentication into an existing auth
@@ -641,7 +637,7 @@ class AuthHandler:
authdict["session"], stagetype, result
)
def get_session_id(self, clientdict: Dict[str, Any]) -> Optional[str]:
def get_session_id(self, clientdict: dict[str, Any]) -> Optional[str]:
"""
Gets the session ID for a client given the client dictionary
@@ -702,8 +698,8 @@ class AuthHandler:
await self.store.delete_old_ui_auth_sessions(expiration_time)
async def _check_auth_dict(
self, authdict: Dict[str, Any], clientip: str
) -> Union[Dict[str, Any], str]:
self, authdict: dict[str, Any], clientip: str
) -> Union[dict[str, Any], str]:
"""Attempt to validate the auth dict provided by a client
Args:
@@ -750,9 +746,9 @@ class AuthHandler:
def _auth_dict_for_flows(
self,
flows: List[List[str]],
flows: list[list[str]],
session_id: str,
) -> Dict[str, Any]:
) -> dict[str, Any]:
public_flows = []
for f in flows:
public_flows.append(f)
@@ -762,7 +758,7 @@ class AuthHandler:
LoginType.TERMS: self._get_params_terms,
}
params: Dict[str, Any] = {}
params: dict[str, Any] = {}
for f in public_flows:
for stage in f:
@@ -780,7 +776,7 @@ class AuthHandler:
refresh_token: str,
access_token_valid_until_ms: Optional[int],
refresh_token_valid_until_ms: Optional[int],
) -> Tuple[str, str, Optional[int]]:
) -> tuple[str, str, Optional[int]]:
"""
Consumes a refresh token and generate both a new access token and a new refresh token from it.
@@ -934,7 +930,7 @@ class AuthHandler:
device_id: str,
expiry_ts: Optional[int],
ultimate_session_expiry_ts: Optional[int],
) -> Tuple[str, int]:
) -> tuple[str, int]:
"""
Creates a new refresh token for the user with the given user ID.
@@ -1067,7 +1063,7 @@ class AuthHandler:
async def _find_user_id_and_pwd_hash(
self, user_id: str
) -> Optional[Tuple[str, str]]:
) -> Optional[tuple[str, str]]:
"""Checks to see if a user with the given id exists. Will check case
insensitively, but will return None if there are multiple inexact
matches.
@@ -1142,10 +1138,10 @@ class AuthHandler:
async def validate_login(
self,
login_submission: Dict[str, Any],
login_submission: dict[str, Any],
ratelimit: bool = False,
is_reauth: bool = False,
) -> Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
) -> tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
"""Authenticates the user for the /login API
Also used by the user-interactive auth flow to validate auth types which don't
@@ -1300,8 +1296,8 @@ class AuthHandler:
async def _validate_userid_login(
self,
username: str,
login_submission: Dict[str, Any],
) -> Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
login_submission: dict[str, Any],
) -> tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
"""Helper for validate_login
Handles login, once we've mapped 3pids onto userids
@@ -1390,7 +1386,7 @@ class AuthHandler:
async def check_password_provider_3pid(
self, medium: str, address: str, password: str
) -> Tuple[Optional[str], Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
) -> tuple[Optional[str], Optional[Callable[["LoginResponse"], Awaitable[None]]]]:
"""Check if a password provider is able to validate a thirdparty login
Args:
@@ -1905,7 +1901,7 @@ class AuthHandler:
extra_attributes = self._extra_attributes.get(login_result["user_id"])
if extra_attributes:
login_result_dict = cast(Dict[str, Any], login_result)
login_result_dict = cast(dict[str, Any], login_result)
login_result_dict.update(extra_attributes.extra_attributes)
def _expire_sso_extra_attributes(self) -> None:
@@ -1941,7 +1937,7 @@ def load_legacy_password_auth_providers(hs: "HomeServer") -> None:
def load_single_legacy_password_auth_provider(
module: Type,
module: type,
config: JsonDict,
api: "ModuleApi",
) -> None:
@@ -1966,7 +1962,7 @@ def load_single_legacy_password_auth_provider(
async def wrapped_check_password(
username: str, login_type: str, login_dict: JsonDict
) -> Optional[Tuple[str, Optional[Callable]]]:
) -> Optional[tuple[str, Optional[Callable]]]:
# We've already made sure f is not None above, but mypy doesn't do well
# across function boundaries so we need to tell it f is definitely not
# None.
@@ -1985,12 +1981,12 @@ def load_single_legacy_password_auth_provider(
return wrapped_check_password
# We need to wrap check_auth as in the old form it could return
# just a str, but now it must return Optional[Tuple[str, Optional[Callable]]
# just a str, but now it must return Optional[tuple[str, Optional[Callable]]
if f.__name__ == "check_auth":
async def wrapped_check_auth(
username: str, login_type: str, login_dict: JsonDict
) -> Optional[Tuple[str, Optional[Callable]]]:
) -> Optional[tuple[str, Optional[Callable]]]:
# We've already made sure f is not None above, but mypy doesn't do well
# across function boundaries so we need to tell it f is definitely not
# None.
@@ -2006,12 +2002,12 @@ def load_single_legacy_password_auth_provider(
return wrapped_check_auth
# We need to wrap check_3pid_auth as in the old form it could return
# just a str, but now it must return Optional[Tuple[str, Optional[Callable]]
# just a str, but now it must return Optional[tuple[str, Optional[Callable]]
if f.__name__ == "check_3pid_auth":
async def wrapped_check_3pid_auth(
medium: str, address: str, password: str
) -> Optional[Tuple[str, Optional[Callable]]]:
) -> Optional[tuple[str, Optional[Callable]]]:
# We've already made sure f is not None above, but mypy doesn't do well
# across function boundaries so we need to tell it f is definitely not
# None.
@@ -2026,7 +2022,7 @@ def load_single_legacy_password_auth_provider(
return wrapped_check_3pid_auth
def run(*args: Tuple, **kwargs: Dict) -> Awaitable:
def run(*args: tuple, **kwargs: dict) -> Awaitable:
# mypy doesn't do well across function boundaries so we need to tell it
# f is definitely not None.
assert f is not None
@@ -2079,14 +2075,14 @@ def load_single_legacy_password_auth_provider(
CHECK_3PID_AUTH_CALLBACK = Callable[
[str, str, str],
Awaitable[
Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
Optional[tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
],
]
ON_LOGGED_OUT_CALLBACK = Callable[[str, Optional[str], str], Awaitable]
CHECK_AUTH_CALLBACK = Callable[
[str, str, JsonDict],
Awaitable[
Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
Optional[tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
],
]
GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[
@@ -2108,21 +2104,21 @@ class PasswordAuthProvider:
def __init__(self) -> None:
# lists of callbacks
self.check_3pid_auth_callbacks: List[CHECK_3PID_AUTH_CALLBACK] = []
self.on_logged_out_callbacks: List[ON_LOGGED_OUT_CALLBACK] = []
self.get_username_for_registration_callbacks: List[
self.check_3pid_auth_callbacks: list[CHECK_3PID_AUTH_CALLBACK] = []
self.on_logged_out_callbacks: list[ON_LOGGED_OUT_CALLBACK] = []
self.get_username_for_registration_callbacks: list[
GET_USERNAME_FOR_REGISTRATION_CALLBACK
] = []
self.get_displayname_for_registration_callbacks: List[
self.get_displayname_for_registration_callbacks: list[
GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
] = []
self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = []
self.is_3pid_allowed_callbacks: list[IS_3PID_ALLOWED_CALLBACK] = []
# Mapping from login type to login parameters
self._supported_login_types: Dict[str, Tuple[str, ...]] = {}
self._supported_login_types: dict[str, tuple[str, ...]] = {}
# Mapping from login type to auth checker callbacks
self.auth_checker_callbacks: Dict[str, List[CHECK_AUTH_CALLBACK]] = {}
self.auth_checker_callbacks: dict[str, list[CHECK_AUTH_CALLBACK]] = {}
def register_password_auth_provider_callbacks(
self,
@@ -2130,7 +2126,7 @@ class PasswordAuthProvider:
on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None,
is_3pid_allowed: Optional[IS_3PID_ALLOWED_CALLBACK] = None,
auth_checkers: Optional[
Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK]
dict[tuple[str, tuple[str, ...]], CHECK_AUTH_CALLBACK]
] = None,
get_username_for_registration: Optional[
GET_USERNAME_FOR_REGISTRATION_CALLBACK
@@ -2207,7 +2203,7 @@ class PasswordAuthProvider:
async def check_auth(
self, username: str, login_type: str, login_dict: JsonDict
) -> Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]:
) -> Optional[tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]:
"""Check if the user has presented valid login credentials
Args:
@@ -2245,7 +2241,7 @@ class PasswordAuthProvider:
if not isinstance(result, tuple) or len(result) != 2:
logger.warning( # type: ignore[unreachable]
"Wrong type returned by module API callback %s: %s, expected"
" Optional[Tuple[str, Optional[Callable]]]",
" Optional[tuple[str, Optional[Callable]]]",
callback,
result,
)
@@ -2258,7 +2254,7 @@ class PasswordAuthProvider:
if not isinstance(str_result, str):
logger.warning( # type: ignore[unreachable]
"Wrong type returned by module API callback %s: %s, expected"
" Optional[Tuple[str, Optional[Callable]]]",
" Optional[tuple[str, Optional[Callable]]]",
callback,
result,
)
@@ -2269,7 +2265,7 @@ class PasswordAuthProvider:
if not callable(callback_result):
logger.warning( # type: ignore[unreachable]
"Wrong type returned by module API callback %s: %s, expected"
" Optional[Tuple[str, Optional[Callable]]]",
" Optional[tuple[str, Optional[Callable]]]",
callback,
result,
)
@@ -2284,7 +2280,7 @@ class PasswordAuthProvider:
async def check_3pid_auth(
self, medium: str, address: str, password: str
) -> Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]:
) -> Optional[tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]:
# This function is able to return a deferred that either
# resolves None, meaning authentication failure, or upon
# success, to a str (which is the user_id) or a tuple of
@@ -2308,7 +2304,7 @@ class PasswordAuthProvider:
if not isinstance(result, tuple) or len(result) != 2:
logger.warning( # type: ignore[unreachable]
"Wrong type returned by module API callback %s: %s, expected"
" Optional[Tuple[str, Optional[Callable]]]",
" Optional[tuple[str, Optional[Callable]]]",
callback,
result,
)
@@ -2321,7 +2317,7 @@ class PasswordAuthProvider:
if not isinstance(str_result, str):
logger.warning( # type: ignore[unreachable]
"Wrong type returned by module API callback %s: %s, expected"
" Optional[Tuple[str, Optional[Callable]]]",
" Optional[tuple[str, Optional[Callable]]]",
callback,
result,
)
@@ -2332,7 +2328,7 @@ class PasswordAuthProvider:
if not callable(callback_result):
logger.warning( # type: ignore[unreachable]
"Wrong type returned by module API callback %s: %s, expected"
" Optional[Tuple[str, Optional[Callable]]]",
" Optional[tuple[str, Optional[Callable]]]",
callback,
result,
)

View File

@@ -20,7 +20,7 @@
#
import logging
import urllib.parse
from typing import TYPE_CHECKING, Dict, List, Optional
from typing import TYPE_CHECKING, Optional
from xml.etree import ElementTree as ET
import attr
@@ -54,7 +54,7 @@ class CasError(Exception):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class CasResponse:
username: str
attributes: Dict[str, List[Optional[str]]]
attributes: dict[str, list[Optional[str]]]
class CasHandler:
@@ -99,7 +99,7 @@ class CasHandler:
self._sso_handler.register_identity_provider(self)
def _build_service_param(self, args: Dict[str, str]) -> str:
def _build_service_param(self, args: dict[str, str]) -> str:
"""
Generates a value to use as the "service" parameter when redirecting or
querying the CAS service.
@@ -116,7 +116,7 @@ class CasHandler:
)
async def _validate_ticket(
self, ticket: str, service_args: Dict[str, str]
self, ticket: str, service_args: dict[str, str]
) -> CasResponse:
"""
Validate a CAS ticket with the server, and return the parsed the response.
@@ -186,7 +186,7 @@ class CasHandler:
# Iterate through the nodes and pull out the user and any extra attributes.
user = None
attributes: Dict[str, List[Optional[str]]] = {}
attributes: dict[str, list[Optional[str]]] = {}
for child in root[0]:
if child.tag.endswith("user"):
user = child.text

View File

@@ -13,7 +13,7 @@
#
import logging
from typing import TYPE_CHECKING, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Optional
from twisted.internet.interfaces import IDelayedCall
@@ -226,7 +226,7 @@ class DelayedEventsHandler:
await self._store.update_delayed_events_stream_pos(max_pos)
async def _handle_state_deltas(self, deltas: List[StateDelta]) -> None:
async def _handle_state_deltas(self, deltas: list[StateDelta]) -> None:
"""
Process current state deltas to cancel other users' pending delayed events
that target the same state.
@@ -502,8 +502,8 @@ class DelayedEventsHandler:
await self._send_events(events)
async def _send_events(self, events: List[DelayedEventDetails]) -> None:
sent_state: Set[Tuple[RoomID, EventType, StateKey]] = set()
async def _send_events(self, events: list[DelayedEventDetails]) -> None:
sent_state: set[tuple[RoomID, EventType, StateKey]] = set()
for event in events:
if event.state_key is not None:
state_info = (event.room_id, event.type, event.state_key)
@@ -547,7 +547,7 @@ class DelayedEventsHandler:
else:
self._next_delayed_event_call.reset(delay_sec)
async def get_all_for_user(self, requester: Requester) -> List[JsonDict]:
async def get_all_for_user(self, requester: Requester) -> list[JsonDict]:
"""Return all pending delayed events requested by the given user."""
await self._delayed_event_mgmt_ratelimiter.ratelimit(
requester,

View File

@@ -25,13 +25,9 @@ from threading import Lock
from typing import (
TYPE_CHECKING,
AbstractSet,
Dict,
Iterable,
List,
Mapping,
Optional,
Set,
Tuple,
cast,
)
@@ -407,7 +403,7 @@ class DeviceHandler:
raise
@trace
async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
async def get_devices_by_user(self, user_id: str) -> list[JsonDict]:
"""
Retrieve the given user's devices
@@ -431,7 +427,7 @@ class DeviceHandler:
async def get_dehydrated_device(
self, user_id: str
) -> Optional[Tuple[str, JsonDict]]:
) -> Optional[tuple[str, JsonDict]]:
"""Retrieve the information for a dehydrated device.
Args:
@@ -568,7 +564,7 @@ class DeviceHandler:
room_ids: StrCollection,
from_token: StreamToken,
now_token: Optional[StreamToken] = None,
) -> Set[str]:
) -> set[str]:
"""Get the set of users whose devices have changed who share a room with
the given user.
"""
@@ -644,8 +640,8 @@ class DeviceHandler:
# Check for newly joined or left rooms. We need to make sure that we add
# to newly joined in the case membership goes from join -> leave -> join
# again.
newly_joined_rooms: Set[str] = set()
newly_left_rooms: Set[str] = set()
newly_joined_rooms: set[str] = set()
newly_left_rooms: set[str] = set()
for change in membership_changes:
# We check for changes in "joinedness", i.e. if the membership has
# changed to or from JOIN.
@@ -661,10 +657,10 @@ class DeviceHandler:
# the user is currently in.
# List of membership changes per room
room_to_deltas: Dict[str, List[StateDelta]] = {}
room_to_deltas: dict[str, list[StateDelta]] = {}
# The set of event IDs of membership events (so we can fetch their
# associated membership).
memberships_to_fetch: Set[str] = set()
memberships_to_fetch: set[str] = set()
# TODO: Only pull out membership events?
state_changes = await self.store.get_current_state_deltas_for_rooms(
@@ -695,8 +691,8 @@ class DeviceHandler:
# We now want to find any user that have newly joined/invited/knocked,
# or newly left, similarly to above.
newly_joined_or_invited_or_knocked_users: Set[str] = set()
newly_left_users: Set[str] = set()
newly_joined_or_invited_or_knocked_users: set[str] = set()
newly_left_users: set[str] = set()
for _, deltas in room_to_deltas.items():
for delta in deltas:
# Get the prev/new memberships for the delta
@@ -838,7 +834,7 @@ class DeviceHandler:
# Check if the application services have any results.
if self._query_appservices_for_keys:
# Query the appservice for all devices for this user.
query: Dict[str, Optional[List[str]]] = {user_id: None}
query: dict[str, Optional[list[str]]] = {user_id: None}
# Query the appservices for any keys.
appservice_results = await self._appservice_handler.query_keys(query)
@@ -898,7 +894,7 @@ class DeviceHandler:
async def notify_user_signature_update(
self,
from_user_id: str,
user_ids: List[str],
user_ids: list[str],
) -> None:
"""Notify a device writer that a user have made new signatures of other users.
@@ -927,7 +923,7 @@ class DeviceHandler:
async def _delete_device_messages(
self,
task: ScheduledTask,
) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
) -> tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
"""Scheduler task to delete device messages in batch of `DEVICE_MSGS_DELETE_BATCH_LIMIT`."""
assert task.params is not None
user_id = task.params["user_id"]
@@ -1051,7 +1047,7 @@ class DeviceWriterHandler(DeviceHandler):
await self.handle_new_device_update()
async def notify_user_signature_update(
self, from_user_id: str, user_ids: List[str]
self, from_user_id: str, user_ids: list[str]
) -> None:
"""Notify a user that they have made new signatures of other users.
@@ -1112,7 +1108,7 @@ class DeviceWriterHandler(DeviceHandler):
# hosts we've already poked about for this update. This is so that we
# don't poke the same remote server about the same update repeatedly.
current_stream_id = None
hosts_already_sent_to: Set[str] = set()
hosts_already_sent_to: set[str] = set()
try:
stream_id, room_id = await self.store.get_device_change_last_converted_pos()
@@ -1311,7 +1307,7 @@ class DeviceWriterHandler(DeviceHandler):
def _update_device_from_client_ips(
device: JsonDict, client_ips: Mapping[Tuple[str, str], DeviceLastConnectionInfo]
device: JsonDict, client_ips: Mapping[tuple[str, str], DeviceLastConnectionInfo]
) -> None:
ip = client_ips.get((device["user_id"], device["device_id"]))
device.update(
@@ -1338,8 +1334,8 @@ class DeviceListWorkerUpdater:
async def multi_user_device_resync(
self,
user_ids: List[str],
) -> Dict[str, Optional[JsonMapping]]:
user_ids: list[str],
) -> dict[str, Optional[JsonMapping]]:
"""
Like `user_device_resync` but operates on multiple users **from the same origin**
at once.
@@ -1365,7 +1361,7 @@ class DeviceListWorkerUpdater:
user_id: str,
master_key: Optional[JsonDict],
self_signing_key: Optional[JsonDict],
) -> List[str]:
) -> list[str]:
"""Process the given new master and self-signing key for the given remote user.
Args:
@@ -1455,14 +1451,14 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
)
# user_id -> list of updates waiting to be handled.
self._pending_updates: Dict[
str, List[Tuple[str, str, Iterable[str], JsonDict]]
self._pending_updates: dict[
str, list[tuple[str, str, Iterable[str], JsonDict]]
] = {}
# Recently seen stream ids. We don't bother keeping these in the DB,
# but they're useful to have them about to reduce the number of spurious
# resyncs.
self._seen_updates: ExpiringCache[str, Set[str]] = ExpiringCache(
self._seen_updates: ExpiringCache[str, set[str]] = ExpiringCache(
cache_name="device_update_edu",
server_name=self.server_name,
hs=self.hs,
@@ -1619,12 +1615,12 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
)
async def _need_to_do_resync(
self, user_id: str, updates: Iterable[Tuple[str, str, Iterable[str], JsonDict]]
self, user_id: str, updates: Iterable[tuple[str, str, Iterable[str], JsonDict]]
) -> bool:
"""Given a list of updates for a user figure out if we need to do a full
resync, or whether we have enough data that we can just apply the delta.
"""
seen_updates: Set[str] = self._seen_updates.get(user_id, set())
seen_updates: set[str] = self._seen_updates.get(user_id, set())
extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id)
@@ -1702,8 +1698,8 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
self._resync_retry_lock.release()
async def multi_user_device_resync(
self, user_ids: List[str], mark_failed_as_stale: bool = True
) -> Dict[str, Optional[JsonMapping]]:
self, user_ids: list[str], mark_failed_as_stale: bool = True
) -> dict[str, Optional[JsonMapping]]:
"""
Like `user_device_resync` but operates on multiple users **from the same origin**
at once.
@@ -1739,7 +1735,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
async def _user_device_resync_returning_failed(
self, user_id: str
) -> Tuple[Optional[JsonMapping], bool]:
) -> tuple[Optional[JsonMapping], bool]:
"""Fetches all devices for a user and updates the device cache with them.
Args:

View File

@@ -21,7 +21,7 @@
import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Optional
from synapse.api.constants import EduTypes, EventContentFields, ToDeviceEventTypes
from synapse.api.errors import Codes, SynapseError
@@ -158,7 +158,7 @@ class DeviceMessageHandler:
self,
message_type: str,
sender_user_id: str,
by_device: Dict[str, Dict[str, Any]],
by_device: dict[str, dict[str, Any]],
) -> None:
"""Checks inbound device messages for unknown remote devices, and if
found marks the remote cache for the user as stale.
@@ -207,7 +207,7 @@ class DeviceMessageHandler:
self,
requester: Requester,
message_type: str,
messages: Dict[str, Dict[str, JsonDict]],
messages: dict[str, dict[str, JsonDict]],
) -> None:
"""
Handle a request from a user to send to-device message(s).
@@ -222,7 +222,7 @@ class DeviceMessageHandler:
set_tag(SynapseTags.TO_DEVICE_TYPE, message_type)
set_tag(SynapseTags.TO_DEVICE_SENDER, sender_user_id)
local_messages = {}
remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
remote_messages: dict[str, dict[str, dict[str, JsonDict]]] = {}
for user_id, by_device in messages.items():
if not UserID.is_valid(user_id):
logger.warning(

Some files were not shown because too many files have changed in this diff Show More