Compare commits

...

3 Commits

Author SHA1 Message Date
David Robertson
36f47b37a9 Easier fn annotations for synapse.logging.context
Not yet passing `no-untyped-defs`. `make_deferred_yieldable` is tricky and needs more thought.
2021-10-04 14:07:56 +01:00
David Robertson
63c33aa4ce no-untyped-defs for synapse.logging.handlers 2021-10-04 14:06:08 +01:00
David Robertson
f984191b04 no-untyped-defs for synapse.logging.formatter 2021-10-04 14:05:43 +01:00
4 changed files with 75 additions and 30 deletions

View File

@@ -96,6 +96,12 @@ files =
[mypy-synapse.handlers.*]
disallow_untyped_defs = True
[mypy-synapse.logging.formatter]
disallow_untyped_defs = True
[mypy-synapse.logging.handlers]
disallow_untyped_defs = True
[mypy-synapse.rest.*]
disallow_untyped_defs = True

View File

@@ -22,21 +22,37 @@ them.
See doc/log_contexts.rst for details on how this works.
"""
import functools
import inspect
import logging
import threading
import typing
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
Callable,
Optional,
Tuple,
Type,
TypeVar,
Union,
cast,
)
import attr
from typing_extensions import Literal
from twisted.internet import defer, threads
from twisted.internet.interfaces import IReactorThreads, ThreadPool
if TYPE_CHECKING:
from synapse.logging.scopecontextmanager import _LogContextScope
T = TypeVar("T")
F = TypeVar("F", bound=Callable[..., Any])
logger = logging.getLogger(__name__)
try:
@@ -66,7 +82,7 @@ except Exception:
# a hook which can be set during testing to assert that we aren't abusing logcontexts.
def logcontext_error(msg: str):
def logcontext_error(msg: str) -> None:
logger.warning(msg)
@@ -220,28 +236,28 @@ class _Sentinel:
self.scope = None
self.tag = None
def __str__(self):
def __str__(self) -> str:
return "sentinel"
def copy_to(self, record):
def copy_to(self, record: "LoggingContext") -> None:
pass
def start(self, rusage: "Optional[resource._RUsage]"):
def start(self, rusage: "Optional[resource._RUsage]") -> None:
pass
def stop(self, rusage: "Optional[resource._RUsage]"):
def stop(self, rusage: "Optional[resource._RUsage]") -> None:
pass
def add_database_transaction(self, duration_sec):
def add_database_transaction(self, duration_sec: float) -> None:
pass
def add_database_scheduled(self, sched_sec):
def add_database_scheduled(self, sched_sec: float) -> None:
pass
def record_event_fetch(self, event_count):
def record_event_fetch(self, event_count: int) -> None:
pass
def __bool__(self):
def __bool__(self) -> bool:
return False
@@ -379,7 +395,12 @@ class LoggingContext:
)
return self
def __exit__(self, type, value, traceback) -> None:
def __exit__(
self,
type: Optional[Type[BaseException]],
value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
"""Restore the logging context in thread local storage to the state it
was before this context was entered.
Returns:
@@ -399,10 +420,8 @@ class LoggingContext:
# recorded against the correct metrics.
self.finished = True
def copy_to(self, record) -> None:
"""Copy logging fields from this context to a log record or
another LoggingContext
"""
def copy_to(self, record: "LoggingContext") -> None:
"""Copy logging fields from this context to another LoggingContext"""
# we track the current request
record.request = self.request
@@ -575,7 +594,7 @@ class LoggingContextFilter(logging.Filter):
record.
"""
def __init__(self, request: str = ""):
def __init__(self, request: str = "") -> None:
self._default_request = request
def filter(self, record: logging.LogRecord) -> Literal[True]:
@@ -626,7 +645,12 @@ class PreserveLoggingContext:
def __enter__(self) -> None:
self._old_context = set_current_context(self._new_context)
def __exit__(self, type, value, traceback) -> None:
def __exit__(
self,
type: Optional[Type[BaseException]],
value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
context = set_current_context(self._old_context)
if context != self._new_context:
@@ -711,16 +735,19 @@ def nested_logging_context(suffix: str) -> LoggingContext:
)
def preserve_fn(f):
def preserve_fn(f: F) -> F:
"""Function decorator which wraps the function with run_in_background"""
def g(*args, **kwargs):
@functools.wraps(f)
def g(*args: Any, **kwargs: Any) -> Any:
return run_in_background(f, *args, **kwargs)
return g
return cast(F, g)
def run_in_background(f, *args, **kwargs) -> defer.Deferred:
def run_in_background(
f: Callable[..., T], *args: Any, **kwargs: Any
) -> "defer.Deferred[T]":
"""Calls a function, ensuring that the current context is restored after
return from the function, and that the sentinel context is set once the
deferred returned by the function completes.
@@ -823,7 +850,9 @@ def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT:
return result
def defer_to_thread(reactor, f, *args, **kwargs):
def defer_to_thread(
reactor: IReactorThreads, f: Callable[..., T], *args: Any, **kwargs: Any
) -> "defer.Deferred[T]":
"""
Calls the function `f` using a thread from the reactor's default threadpool and
returns the result as a Deferred.
@@ -855,7 +884,13 @@ def defer_to_thread(reactor, f, *args, **kwargs):
return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs)
def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
def defer_to_threadpool(
reactor: IReactorThreads,
threadpool: ThreadPool,
f: Callable[..., T],
*args: Any,
**kwargs: Any,
) -> "defer.Deferred[T]":
"""
A wrapper for twisted.internet.threads.deferToThreadpool, which handles
logcontexts correctly.
@@ -897,7 +932,7 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
assert isinstance(curr_context, LoggingContext)
parent_context = curr_context
def g():
def g() -> T:
with LoggingContext(str(curr_context), parent_context=parent_context):
return f(*args, **kwargs)

View File

@@ -16,6 +16,13 @@
import logging
import traceback
from io import StringIO
from types import TracebackType
from typing import Optional, Tuple, Type, Union
ExceptionInfo = Union[
Tuple[Type[BaseException], BaseException, Optional[TracebackType]],
Tuple[None, None, None],
]
class LogFormatter(logging.Formatter):
@@ -28,10 +35,7 @@ class LogFormatter(logging.Formatter):
where it was caught are logged).
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def formatException(self, ei):
def formatException(self, ei: ExceptionInfo) -> str:
sio = StringIO()
(typ, val, tb) = ei

View File

@@ -49,7 +49,7 @@ class PeriodicallyFlushingMemoryHandler(MemoryHandler):
)
self._flushing_thread.start()
def on_reactor_running():
def on_reactor_running() -> None:
self._reactor_started = True
reactor_to_use: IReactorCore
@@ -74,7 +74,7 @@ class PeriodicallyFlushingMemoryHandler(MemoryHandler):
else:
return True
def _flush_periodically(self):
def _flush_periodically(self) -> None:
"""
Whilst this handler is active, flush the handler periodically.
"""