Compare commits

...

3 Commits

Author SHA1 Message Date
Erik Johnston
930f5b0a25 Use *args 2024-10-04 16:21:40 +01:00
Erik Johnston
d56af0d83c Fixup 2024-10-04 15:47:27 +01:00
Erik Johnston
5039693b0a Run extensions in parallel 2024-10-04 15:43:46 +01:00
3 changed files with 150 additions and 12 deletions

View File

@@ -49,7 +49,10 @@ from synapse.types.handlers.sliding_sync import (
SlidingSyncConfig,
SlidingSyncResult,
)
from synapse.util.async_helpers import concurrently_execute
from synapse.util.async_helpers import (
concurrently_execute,
gather_optional_coroutines,
)
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -97,26 +100,26 @@ class SlidingSyncExtensionHandler:
if sync_config.extensions is None:
return SlidingSyncResult.Extensions()
to_device_response = None
to_device_coro = None
if sync_config.extensions.to_device is not None:
to_device_response = await self.get_to_device_extension_response(
to_device_coro = self.get_to_device_extension_response(
sync_config=sync_config,
to_device_request=sync_config.extensions.to_device,
to_token=to_token,
)
e2ee_response = None
e2ee_coro = None
if sync_config.extensions.e2ee is not None:
e2ee_response = await self.get_e2ee_extension_response(
e2ee_coro = self.get_e2ee_extension_response(
sync_config=sync_config,
e2ee_request=sync_config.extensions.e2ee,
to_token=to_token,
from_token=from_token,
)
account_data_response = None
account_data_coro = None
if sync_config.extensions.account_data is not None:
account_data_response = await self.get_account_data_extension_response(
account_data_coro = self.get_account_data_extension_response(
sync_config=sync_config,
previous_connection_state=previous_connection_state,
new_connection_state=new_connection_state,
@@ -127,9 +130,9 @@ class SlidingSyncExtensionHandler:
from_token=from_token,
)
receipts_response = None
receipts_coro = None
if sync_config.extensions.receipts is not None:
receipts_response = await self.get_receipts_extension_response(
receipts_coro = self.get_receipts_extension_response(
sync_config=sync_config,
previous_connection_state=previous_connection_state,
new_connection_state=new_connection_state,
@@ -141,9 +144,9 @@ class SlidingSyncExtensionHandler:
from_token=from_token,
)
typing_response = None
typing_coro = None
if sync_config.extensions.typing is not None:
typing_response = await self.get_typing_extension_response(
typing_coro = self.get_typing_extension_response(
sync_config=sync_config,
actual_lists=actual_lists,
actual_room_ids=actual_room_ids,
@@ -153,6 +156,20 @@ class SlidingSyncExtensionHandler:
from_token=from_token,
)
(
to_device_response,
e2ee_response,
account_data_response,
receipts_response,
typing_response,
) = await gather_optional_coroutines(
to_device_coro,
e2ee_coro,
account_data_coro,
receipts_coro,
typing_coro,
)
return SlidingSyncResult.Extensions(
to_device=to_device_response,
e2ee=e2ee_response,

View File

@@ -37,6 +37,7 @@ import warnings
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Optional,
@@ -850,6 +851,32 @@ def run_in_background(
return d
def run_coroutine_in_background(
coroutine: typing.Coroutine[Any, Any, R],
) -> "defer.Deferred[R]":
current = current_context()
d = defer.ensureDeferred(coroutine)
# The function may have reset the context before returning, so
# we need to restore it now.
ctx = set_current_context(current)
# The original context will be restored when the deferred
# completes, but there is nothing waiting for it, so it will
# get leaked into the reactor or some other function which
# wasn't expecting it. We therefore need to reset the context
# here.
#
# (If this feels asymmetric, consider it this way: we are
# effectively forking a new thread of execution. We are
# probably currently within a ``with LoggingContext()`` block,
# which is supposed to have a single entry and exit point. But
# by spawning off another deferred, we are effectively
# adding a new exit point.)
d.addBoth(_set_context_cb, ctx)
return d
T = TypeVar("T")

View File

@@ -51,7 +51,7 @@ from typing import (
)
import attr
from typing_extensions import Concatenate, Literal, ParamSpec
from typing_extensions import Concatenate, Literal, ParamSpec, Unpack
from twisted.internet import defer
from twisted.internet.defer import CancelledError
@@ -61,6 +61,7 @@ from twisted.python.failure import Failure
from synapse.logging.context import (
PreserveLoggingContext,
make_deferred_yieldable,
run_coroutine_in_background,
run_in_background,
)
from synapse.util import Clock
@@ -344,6 +345,7 @@ T1 = TypeVar("T1")
T2 = TypeVar("T2")
T3 = TypeVar("T3")
T4 = TypeVar("T4")
T5 = TypeVar("T5")
@overload
@@ -402,6 +404,98 @@ def gather_results( # type: ignore[misc]
return deferred.addCallback(tuple)
@overload
async def gather_optional_coroutines(
*coroutines: Unpack[Tuple[Optional[Coroutine[Any, Any, T1]]]],
) -> Tuple[Optional[T1]]: ...
@overload
async def gather_optional_coroutines(
*coroutines: Unpack[
Tuple[
Optional[Coroutine[Any, Any, T1]],
Optional[Coroutine[Any, Any, T2]],
]
],
) -> Tuple[Optional[T1], Optional[T2]]: ...
@overload
async def gather_optional_coroutines(
*coroutines: Unpack[
Tuple[
Optional[Coroutine[Any, Any, T1]],
Optional[Coroutine[Any, Any, T2]],
Optional[Coroutine[Any, Any, T3]],
]
],
) -> Tuple[Optional[T1], Optional[T2], Optional[T3]]: ...
@overload
async def gather_optional_coroutines(
*coroutines: Unpack[
Tuple[
Optional[Coroutine[Any, Any, T1]],
Optional[Coroutine[Any, Any, T2]],
Optional[Coroutine[Any, Any, T3]],
Optional[Coroutine[Any, Any, T4]],
]
],
) -> Tuple[Optional[T1], Optional[T2], Optional[T3], Optional[T4]]: ...
@overload
async def gather_optional_coroutines(
*coroutines: Unpack[
Tuple[
Optional[Coroutine[Any, Any, T1]],
Optional[Coroutine[Any, Any, T2]],
Optional[Coroutine[Any, Any, T3]],
Optional[Coroutine[Any, Any, T4]],
Optional[Coroutine[Any, Any, T5]],
]
],
) -> Tuple[Optional[T1], Optional[T2], Optional[T3], Optional[T4], Optional[T5]]: ...
async def gather_optional_coroutines(
*coroutines: Unpack[Tuple[Optional[Coroutine[Any, Any, T1]], ...]],
) -> Tuple[Optional[T1], ...]:
try:
results = await make_deferred_yieldable(
defer.gatherResults(
[
run_coroutine_in_background(coroutine)
for coroutine in coroutines
if coroutine
],
consumeErrors=True,
)
)
results_iter = iter(results)
return tuple(
next(results_iter) if coroutine else None for coroutine in coroutines
)
except defer.FirstError as dfe:
# unwrap the error from defer.gatherResults.
# The raised exception's traceback only includes func() etc if
# the 'await' happens before the exception is thrown - ie if the failure
# happens *asynchronously* - otherwise Twisted throws away the traceback as it
# could be large.
#
# We could maybe reconstruct a fake traceback from Failure.frames. Or maybe
# we could throw Twisted into the fires of Mordor.
# suppress exception chaining, because the FirstError doesn't tell us anything
# very interesting.
assert isinstance(dfe.subFailure.value, BaseException)
raise dfe.subFailure.value from None
@attr.s(slots=True, auto_attribs=True)
class _LinearizerEntry:
# The number of things executing.