Compare commits

...

3 Commits

Author SHA1 Message Date
H. Shay
55cfb33644 some type hints 2022-01-05 13:15:50 -08:00
H. Shay
45688c3c0f Merge branch 'develop' into shay/add_types_opentracing.py 2022-01-05 09:16:25 -08:00
H. Shay
b06f415943 merge conflict 2021-12-20 10:20:18 -08:00

View File

@@ -168,9 +168,26 @@ import inspect
import logging
import re
from functools import wraps
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Pattern, Type
from typing import (
TYPE_CHECKING,
Any,
Callable,
Collection,
Dict,
Iterable,
List,
Match,
Optional,
Pattern,
Type,
TypeVar,
Union,
cast,
overload,
)
import attr
from mypy.nodes import JsonDict
from twisted.internet import defer
from twisted.web.http import Request
@@ -180,9 +197,14 @@ from synapse.config import ConfigError
from synapse.util import json_decoder, json_encoder
if TYPE_CHECKING:
from opentracing.span import Span, SpanContext
from opentracing.tracer import Reference
from synapse.http.site import SynapseRequest
from synapse.server import HomeServer
F = TypeVar("F", bound=Callable[..., Any])
# Helper class
@@ -253,10 +275,10 @@ try:
_reporter = attr.ib(type=Reporter, default=attr.Factory(Reporter))
def set_process(self, *args, **kwargs):
def set_process(self, *args: Any, **kwargs: Any) -> None:
return self._reporter.set_process(*args, **kwargs)
def report_span(self, span):
def report_span(self, span: "Span") -> None:
try:
return self._reporter.report_span(span)
except Exception:
@@ -303,21 +325,39 @@ _homeserver_whitelist: Optional[Pattern[str]] = None
Sentinel = object()
R = TypeVar("R")
def only_if_tracing(func):
def only_if_tracing(func: Callable[..., R]) -> Callable[..., Optional[R]]:
"""Executes the function only if we're tracing. Otherwise returns None."""
@wraps(func)
def _only_if_tracing_inner(*args, **kwargs):
def _only_if_tracing_inner(*args: Any, **kwargs: Any) -> Optional[R]:
if opentracing:
return func(*args, **kwargs)
else:
return
return None
return _only_if_tracing_inner
def ensure_active_span(message, ret=None):
@overload
def ensure_active_span(
message: str,
) -> Callable[[Callable[..., None]], Callable[..., None]]:
...
@overload
def ensure_active_span(
message: str, ret: R
) -> Callable[[Callable[..., R]], Callable[..., R]]:
...
def ensure_active_span(
message: str, ret: Any = None
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""Executes the operation only if opentracing is enabled and there is an active span.
If there is no active span it logs message at the error level.
@@ -330,9 +370,9 @@ def ensure_active_span(message, ret=None):
was no active span.
"""
def ensure_active_span_inner_1(func):
def ensure_active_span_inner_1(func: Callable[..., Any]) -> Callable[..., Any]:
@wraps(func)
def ensure_active_span_inner_2(*args, **kwargs):
def ensure_active_span_inner_2(*args, **kwargs) -> Any:
if not opentracing:
return ret
@@ -354,7 +394,7 @@ def ensure_active_span(message, ret=None):
@contextlib.contextmanager
def noop_context_manager(*args, **kwargs):
def noop_context_manager(*args: Any, **kwargs: Any):
"""Does exactly what it says on the tin"""
# TODO: replace with contextlib.nullcontext once we drop support for Python 3.6
yield
@@ -363,7 +403,7 @@ def noop_context_manager(*args, **kwargs):
# Setup
def init_tracer(hs: "HomeServer"):
def init_tracer(hs: "HomeServer") -> None:
"""Set the whitelists and initialise the JaegerClient tracer"""
global opentracing
if not hs.config.tracing.opentracer_enabled:
@@ -379,7 +419,7 @@ def init_tracer(hs: "HomeServer"):
# Pull out the jaeger config if it was given. Otherwise set it to something sensible.
# See https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/config.py
assert set_homeserver_whitelist is not None
set_homeserver_whitelist(hs.config.tracing.opentracer_whitelist)
from jaeger_client.metrics.prometheus import PrometheusMetricsFactory
@@ -405,7 +445,7 @@ def init_tracer(hs: "HomeServer"):
@only_if_tracing
def set_homeserver_whitelist(homeserver_whitelist):
def set_homeserver_whitelist(homeserver_whitelist: Iterable[str]) -> None:
"""Sets the homeserver whitelist
Args:
@@ -420,7 +460,7 @@ def set_homeserver_whitelist(homeserver_whitelist):
@only_if_tracing
def whitelisted_homeserver(destination):
def whitelisted_homeserver(destination: str) -> Union[bool, Optional[Match[str]]]:
"""Checks if a destination matches the whitelist
Args:
@@ -436,13 +476,13 @@ def whitelisted_homeserver(destination):
# Could use kwargs but I want these to be explicit
def start_active_span(
operation_name,
child_of=None,
references=None,
tags=None,
start_time=None,
ignore_active_span=False,
finish_on_close=True,
operation_name: str,
child_of: Union["Span", "SpanContext", None] = None,
references: Optional[List["Reference"]] = None,
tags: Optional[dict] = None,
start_time: Optional[float] = None,
ignore_active_span: bool = False,
finish_on_close: bool = True,
):
"""Starts an active opentracing span. Note, the scope doesn't become active
until it has been entered, however, the span starts from the time this
@@ -468,7 +508,7 @@ def start_active_span(
def start_active_span_follows_from(
operation_name: str, contexts: Collection, inherit_force_tracing=False
operation_name: str, contexts: Collection, inherit_force_tracing: bool = False
):
"""Starts an active opentracing span, with additional references to previous spans
@@ -487,19 +527,62 @@ def start_active_span_follows_from(
if inherit_force_tracing and any(
is_context_forced_tracing(ctx) for ctx in contexts
):
assert force_tracing is not None
force_tracing(scope.span)
return scope
def start_active_span_from_request(
request: Request,
operation_name: str,
references: Optional[List["Reference"]] = None,
tags: Optional[dict] = None,
start_time: Optional[float] = None,
ignore_active_span: bool = False,
finish_on_close: bool = True,
):
"""
Extracts a span context from a Twisted Request.
args:
headers (twisted.web.http.Request)
For the other args see opentracing.tracer
returns:
span_context (opentracing.span.SpanContext)
"""
# Twisted encodes the values as lists whereas opentracing doesn't.
# So, we take the first item in the list.
# Also, twisted uses byte arrays while opentracing expects strings.
if opentracing is None:
return noop_context_manager() # type: ignore[unreachable]
header_dict = {
k.decode(): v[0].decode() for k, v in request.requestHeaders.getAllRawHeaders()
}
context = opentracing.tracer.extract(opentracing.Format.HTTP_HEADERS, header_dict)
return opentracing.tracer.start_active_span(
operation_name,
child_of=context,
references=references,
tags=tags,
start_time=start_time,
ignore_active_span=ignore_active_span,
finish_on_close=finish_on_close,
)
def start_active_span_from_edu(
edu_content,
operation_name,
references: Optional[list] = None,
tags=None,
start_time=None,
ignore_active_span=False,
finish_on_close=True,
edu_content: JsonDict,
operation_name: str,
references: Optional[List["Reference"]] = None,
tags: Optional[dict] = None,
start_time: Optional[float] = None,
ignore_active_span: bool = False,
finish_on_close: bool = True,
):
"""
Extracts a span context from an edu and uses it to start a new active span
@@ -519,6 +602,7 @@ def start_active_span_from_edu(
"opentracing", {}
)
context = opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
assert span_context_from_string is not None
_references = [
opentracing.child_of(span_context_from_string(x))
for x in carrier.get("references", [])
@@ -546,53 +630,53 @@ def start_active_span_from_edu(
# Opentracing setters for tags, logs, etc
@only_if_tracing
def active_span():
def active_span() -> Optional[Span]:
"""Get the currently active span, if any"""
return opentracing.tracer.active_span
@ensure_active_span("set a tag")
def set_tag(key, value):
def set_tag(key: str, value: Union[str, bool, int, float]) -> None:
"""Sets a tag on the active span"""
assert opentracing.tracer.active_span is not None
opentracing.tracer.active_span.set_tag(key, value)
@ensure_active_span("log")
def log_kv(key_values, timestamp=None):
def log_kv(key_values: Dict[str, Any], timestamp: Optional[float] = None) -> None:
"""Log to the active span"""
assert opentracing.tracer.active_span is not None
opentracing.tracer.active_span.log_kv(key_values, timestamp)
@ensure_active_span("set the traces operation name")
def set_operation_name(operation_name):
def set_operation_name(operation_name: str) -> None:
"""Sets the operation name of the active span"""
assert opentracing.tracer.active_span is not None
opentracing.tracer.active_span.set_operation_name(operation_name)
@only_if_tracing
def force_tracing(span=Sentinel) -> None:
def force_tracing(span: Union[object, "Span", None] = Sentinel) -> None:
"""Force sampling for the active/given span and its children.
Args:
span: span to force tracing for. By default, the active span.
"""
if span is Sentinel:
span = opentracing.tracer.active_span
if span is None:
logger.error("No active span in force_tracing")
return
if span is Sentinel:
span = opentracing.tracer.active_span
span.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1)
span.set_tag(opentracing.tags.SAMPLING_PRIORITY, 1) # type: ignore[attr-defined]
# also set a bit of baggage, so that we have a way of figuring out if
# it is enabled later
span.set_baggage_item(SynapseBaggage.FORCE_TRACING, "1")
span.set_baggage_item(SynapseBaggage.FORCE_TRACING, "1") # type: ignore[attr-defined]
def is_context_forced_tracing(span_context) -> bool:
def is_context_forced_tracing(span_context: Optional["SpanContext"]) -> bool:
"""Check if sampling has been force for the given span context."""
if span_context is None:
return False
@@ -631,6 +715,7 @@ def inject_header_dict(
raise ValueError(
"destination must be given unless check_destination is False"
)
assert whitelisted_homeserver is not None
if not whitelisted_homeserver(destination):
return
@@ -663,8 +748,10 @@ def inject_response_headers(response_headers: Headers) -> None:
response_headers.addRawHeader("Synapse-Trace-Id", f"{trace_id:x}")
@ensure_active_span("get the active span context as a dict", ret={})
def get_active_span_text_map(destination=None):
@ensure_active_span(
"get the active span context as a dict", ret=cast(Dict[str, str], {})
)
def get_active_span_text_map(destination: Optional[str] = None) -> Dict[str, str]:
"""
Gets a span context as a dict. This can be used instead of manually
injecting a span into an empty carrier.
@@ -676,6 +763,7 @@ def get_active_span_text_map(destination=None):
dict: the active span's context if opentracing is enabled, otherwise empty.
"""
assert whitelisted_homeserver is not None
if destination and not whitelisted_homeserver(destination):
return {}
@@ -688,8 +776,8 @@ def get_active_span_text_map(destination=None):
return carrier
@ensure_active_span("get the span context as a string.", ret={})
def active_span_context_as_string():
@ensure_active_span("get the span context as a string.", ret="{}")
def active_span_context_as_string() -> str:
"""
Returns:
The active span context encoded as a string.
@@ -703,7 +791,7 @@ def active_span_context_as_string():
return json_encoder.encode(carrier)
def span_context_from_request(request: Request) -> "Optional[opentracing.SpanContext]":
def span_context_from_request(request: Request) -> "Optional['SpanContext']":
"""Extract an opentracing context from the headers on an HTTP request
This is useful when we have received an HTTP request from another part of our
@@ -718,17 +806,17 @@ def span_context_from_request(request: Request) -> "Optional[opentracing.SpanCon
@only_if_tracing
def span_context_from_string(carrier):
def span_context_from_string(carrier: str) -> "SpanContext":
"""
Returns:
The active span context decoded from a string.
"""
carrier = json_decoder.decode(carrier)
return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
dict_carrier = json_decoder.decode(carrier)
return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, dict_carrier)
@only_if_tracing
def extract_text_map(carrier):
def extract_text_map(carrier: dict) -> "SpanContext":
"""
Wrapper method for opentracing's tracer.extract for TEXT_MAP.
Args: