Mirror of https://github.com/element-hq/synapse.git, synced 2025-12-07 01:20:16 +00:00

Compare commits: madlittlem...quenting/r (5 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | a3962ace72 |  |
|  | db0f7f8a70 |  |
|  | 140b9d9b0f |  |
|  | 7d888eabc1 |  |
|  | 6474bd8951 |  |
changelog.d/18564.misc (new file, 1 line)

@@ -0,0 +1 @@
+Remove unnecessary HTTP replication calls and make retries round-robin across workers when possible.
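The changelog entry above summarises the mechanism of the whole change: replication clients now take a pool of candidate workers instead of a single pinned instance, and retries rotate through that pool. As a minimal, self-contained sketch of round-robin selection (illustrative only, not code from this diff):

```python
import itertools

# Cycle through the candidate workers so that consecutive attempts
# (or retries) land on different instances instead of overloading one.
workers = ["worker1", "worker2", "worker3"]
picker = itertools.cycle(workers)

attempts = [next(picker) for _ in range(4)]
assert attempts == ["worker1", "worker2", "worker3", "worker1"]
```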
@@ -85,7 +85,6 @@ from synapse.logging.opentracing import (
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.replication.http.federation import (
     ReplicationFederationSendEduRestServlet,
-    ReplicationGetQueryRestServlet,
 )
 from synapse.storage.databases.main.lock import Lock
 from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary
@@ -1380,7 +1379,6 @@ class FederationHandlerRegistry:
         # and use them. However we have guards before we use them to ensure that
         # we don't route to ourselves, and in monolith mode that will always be
         # the case.
-        self._get_query_client = ReplicationGetQueryRestServlet.make_client(hs)
         self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs)

         self.edu_handlers: Dict[str, Callable[[str, dict], Awaitable[None]]] = {}
@@ -1450,11 +1448,8 @@ class FederationHandlerRegistry:
         # Check if we can route it somewhere else that isn't us
         instances = self._edu_type_to_instance.get(edu_type, ["master"])
         if self._instance_name not in instances:
-            # Pick an instance randomly so that we don't overload one.
-            route_to = random.choice(instances)
-
             await self._send_edu(
-                instance_name=route_to,
+                instances=instances,
                 edu_type=edu_type,
                 origin=origin,
                 content=content,
@@ -1469,10 +1464,6 @@ class FederationHandlerRegistry:
         if handler:
             return await handler(args)

-        # Check if we can route it somewhere else that isn't us
-        if self._instance_name == "master":
-            return await self._get_query_client(query_type=query_type, args=args)
-
         # Uh oh, no handler! Let's raise an exception so the request returns an
         # error.
         logger.warning("No handler registered for query type %s", query_type)

@@ -20,7 +20,6 @@
 #
 #
 import logging
-import random
 from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple

 from synapse.api.constants import AccountDataTypes
@@ -133,7 +132,7 @@ class AccountDataHandler:
             return max_stream_id
         else:
             response = await self._add_room_data_client(
-                instance_name=random.choice(self._account_data_writers),
+                instances=self._account_data_writers,
                 user_id=user_id,
                 room_id=room_id,
                 account_data_type=account_data_type,
@@ -174,7 +173,7 @@ class AccountDataHandler:
             return max_stream_id
         else:
             response = await self._remove_room_data_client(
-                instance_name=random.choice(self._account_data_writers),
+                instances=self._account_data_writers,
                 user_id=user_id,
                 room_id=room_id,
                 account_data_type=account_data_type,
@@ -210,7 +209,7 @@ class AccountDataHandler:
             return max_stream_id
         else:
             response = await self._add_user_data_client(
-                instance_name=random.choice(self._account_data_writers),
+                instances=self._account_data_writers,
                 user_id=user_id,
                 account_data_type=account_data_type,
                 content=content,
@@ -246,7 +245,7 @@ class AccountDataHandler:
             return max_stream_id
         else:
             response = await self._remove_user_data_client(
-                instance_name=random.choice(self._account_data_writers),
+                instances=self._account_data_writers,
                 user_id=user_id,
                 account_data_type=account_data_type,
             )
@@ -277,7 +276,7 @@ class AccountDataHandler:
             return max_stream_id
         else:
             response = await self._add_tag_client(
-                instance_name=random.choice(self._account_data_writers),
+                instances=self._account_data_writers,
                 user_id=user_id,
                 room_id=room_id,
                 tag=tag,
@@ -302,7 +301,7 @@ class AccountDataHandler:
             return max_stream_id
         else:
             response = await self._remove_tag_client(
-                instance_name=random.choice(self._account_data_writers),
+                instances=self._account_data_writers,
                 user_id=user_id,
                 room_id=room_id,
                 tag=tag,

@@ -20,7 +20,6 @@ from twisted.internet.interfaces import IDelayedCall
 from synapse.api.constants import EventTypes
 from synapse.api.errors import ShadowBanError
 from synapse.api.ratelimiting import Ratelimiter
-from synapse.config.workers import MAIN_PROCESS_INSTANCE_NAME
 from synapse.logging.opentracing import set_tag
 from synapse.metrics import event_processing_positions
 from synapse.metrics.background_process_metrics import run_as_background_process
@@ -290,10 +289,7 @@ class DelayedEventsHandler:
         if self._repl_client is not None:
             # NOTE: If this throws, the delayed event will remain in the DB and
             # will be picked up once the main worker gets another delayed event.
-            await self._repl_client(
-                instance_name=MAIN_PROCESS_INSTANCE_NAME,
-                next_send_ts=next_send_ts,
-            )
+            await self._repl_client(next_send_ts=next_send_ts)
         elif self._next_send_ts_changed(next_send_ts):
             self._schedule_next_at(next_send_ts)

@@ -73,10 +73,6 @@ from synapse.logging.context import nested_logging_context
 from synapse.logging.opentracing import SynapseTags, set_tag, tag_args, trace
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.module_api import NOT_SPAM
-from synapse.replication.http.federation import (
-    ReplicationCleanRoomRestServlet,
-    ReplicationStoreRoomOnOutlierMembershipRestServlet,
-)
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.invite_rule import InviteRule
 from synapse.types import JsonDict, StrCollection, get_domain_from_id
@@ -163,19 +159,6 @@ class FederationHandler:
         self._notifier = hs.get_notifier()
         self._worker_locks = hs.get_worker_locks_handler()

-        self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client(
-            hs
-        )
-
-        if hs.config.worker.worker_app:
-            self._maybe_store_room_on_outlier_membership = (
-                ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client(hs)
-            )
-        else:
-            self._maybe_store_room_on_outlier_membership = (
-                self.store.maybe_store_room_on_outlier_membership
-            )
-
         self._room_backfill = Linearizer("room_backfill")

         self._third_party_event_rules = (
@@ -857,7 +840,7 @@ class FederationHandler:
         event.internal_metadata.out_of_band_membership = True

         # Record the room ID and its version so that we have a record of the room
-        await self._maybe_store_room_on_outlier_membership(
+        await self.store.maybe_store_room_on_outlier_membership(
             room_id=event.room_id, room_version=event_format_version
         )

@@ -1115,7 +1098,7 @@ class FederationHandler:
         # keep a record of the room version, if we don't yet know it.
         # (this may get overwritten if we later get a different room version in a
         # join dance).
-        await self._maybe_store_room_on_outlier_membership(
+        await self.store.maybe_store_room_on_outlier_membership(
             room_id=event.room_id, room_version=room_version
         )

@@ -1768,10 +1751,7 @@ class FederationHandler:
         Args:
             room_id
         """
-        if self.config.worker.worker_app:
-            await self._clean_room_for_join_client(room_id)
-        else:
-            await self.store.clean_room_for_join(room_id)
+        await self.store.clean_room_for_join(room_id)

     async def get_room_complexity(
         self, remote_room_hosts: List[str], room_id: str

@@ -2259,7 +2259,7 @@ class FederationEventHandler:
         try:
             for batch in batch_iter(event_and_contexts, 200):
                 result = await self._send_events(
-                    instance_name=instance,
+                    instances=[instance],
                     store=self._store,
                     room_id=room_id,
                     event_and_contexts=batch,

@@ -1578,7 +1578,7 @@ class EventCreationHandler:

         try:
             result = await self.send_events(
-                instance_name=writer_instance,
+                instances=[writer_instance],
                 events_and_context=events_and_context,
                 store=self.store,
                 requester=requester,

@@ -484,7 +484,7 @@ class _NullContextManager(ContextManager[None]):
 class WorkerPresenceHandler(BasePresenceHandler):
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
-        self._presence_writer_instance = hs.config.worker.writers.presence[0]
+        self._presence_writer_instances = hs.config.worker.writers.presence

         # Route presence EDUs to the right worker
         hs.get_federation_registry().register_instances_for_edu(
@@ -717,7 +717,7 @@ class WorkerPresenceHandler(BasePresenceHandler):

         # Proxy request to instance that writes presence
         await self._set_state_client(
-            instance_name=self._presence_writer_instance,
+            instances=self._presence_writer_instances,
             user_id=user_id,
             device_id=device_id,
             state=state,
@@ -738,7 +738,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
         # Proxy request to instance that writes presence
         user_id = user.to_string()
         await self._bump_active_client(
-            instance_name=self._presence_writer_instance,
+            instances=self._presence_writer_instances,
             user_id=user_id,
             device_id=device_id,
         )
@@ -2476,7 +2476,7 @@ class PresenceFederationQueue:
         # If not local we query over http replication from the presence
         # writer
         result = await self._repl_client(
-            instance_name=instance_name,
+            instances=[instance_name],
             stream_name=PresenceFederationStream.NAME,
             from_token=from_token,
             upto_token=upto_token,

@@ -196,7 +196,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         self._is_push_writer = (
             hs.get_instance_name() in hs.config.worker.writers.push_rules
         )
-        self._push_writer = hs.config.worker.writers.push_rules[0]
+        self._push_writers = hs.config.worker.writers.push_rules
         self._copy_push_client = ReplicationCopyPusherRestServlet.make_client(hs)

     def _on_user_joined_room(self, event_id: str, room_id: str) -> None:
@@ -1414,7 +1414,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             )
         else:
             await self._copy_push_client(
-                instance_name=self._push_writer,
+                instances=self._push_writers,
                 user_id=user_id,
                 old_room_id=old_room_id,
                 new_room_id=new_room_id,

@@ -23,7 +23,17 @@ import logging
 import re
 import urllib.parse
 from inspect import signature
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, ClassVar, Dict, List, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Callable,
+    ClassVar,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+)

 from prometheus_client import Counter, Gauge

@@ -85,7 +95,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):

     Requests can be sent by calling the client returned by `make_client`.
     Requests are sent to master process by default, but can be sent to other
-    named processes by specifying an `instance_name` keyword argument.
+    named processes by specifying an `instances` keyword argument.

     Attributes:
         NAME (str): A name for the endpoint, added to the path as well as used
@@ -126,15 +136,14 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
             hs.get_clock(), "repl." + self.NAME, timeout_ms=30 * 60 * 1000
         )

-        # We reserve `instance_name` as a parameter to sending requests, so we
+        # We reserve `instances` as a parameter to sending requests, so we
         # assert here that sub classes don't try and use the name.
-        assert "instance_name" not in self.PATH_ARGS, (
-            "`instance_name` is a reserved parameter name"
+        assert "instances" not in self.PATH_ARGS, (
+            "`instances` is a reserved parameter name"
         )
         assert (
-            "instance_name"
-            not in signature(self.__class__._serialize_payload).parameters
-        ), "`instance_name` is a reserved parameter name"
+            "instances" not in signature(self.__class__._serialize_payload).parameters
+        ), "`instances` is a reserved parameter name"

         assert self.METHOD in ("PUT", "POST", "GET")

@@ -163,8 +172,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):

         raise RuntimeError("Invalid Authorization header.")

     @staticmethod
+    @abc.abstractmethod
-    async def _serialize_payload(**kwargs) -> JsonDict:
+    async def _serialize_payload(**kwargs: Any) -> JsonDict:
         """Static method that is called when creating a request.

         Concrete implementations should have explicit parameters (rather than
@@ -196,14 +206,17 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
         """Create a client that makes requests.

         Returns a callable that accepts the same parameters as
-        `_serialize_payload`, and also accepts an optional `instance_name`
-        parameter to specify which instance to hit (the instance must be in
-        the `instance_map` config).
+        `_serialize_payload`, and also accepts an optional `instances` parameter
+        to specify which instances to hit (the instances must be in the
+        `instance_map` config).
         """
         clock = hs.get_clock()
         client = hs.get_replication_client()
         local_instance_name = hs.get_instance_name()

+        # This is the current index on the instance pool, so that we round-robin between instances
+        instance_pool_index = 0
+
         instance_map = hs.config.worker.instance_map

         outgoing_gauge = _pending_outgoing_requests.labels(cls.NAME)
@@ -216,19 +229,24 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):

        @trace_with_opname("outgoing_replication_request")
        async def send_request(
-            *, instance_name: str = MAIN_PROCESS_INSTANCE_NAME, **kwargs: Any
+            *, instances: Optional[list[str]] = None, **kwargs: Any
        ) -> Any:
            # We have to pull these out here to avoid circular dependencies...
            streams = hs.get_replication_command_handler().get_streams_to_replicate()
            replication = hs.get_replication_data_handler()

+            # If no instances were given, route to the main process
+            instances = instances or [MAIN_PROCESS_INSTANCE_NAME]
+
            with outgoing_gauge.track_inprogress():
-                if instance_name == local_instance_name:
-                    raise Exception("Trying to send HTTP request to self")
-                if instance_name not in instance_map:
-                    raise Exception(
-                        "Instance %r not in 'instance_map' config" % (instance_name,)
-                    )
+                for instance_name in instances:
+                    if instance_name == local_instance_name:
+                        raise Exception("Trying to send HTTP request to self")
+                    if instance_name not in instance_map:
+                        raise Exception(
+                            "Instance %r not in 'instance_map' config"
+                            % (instance_name,)
+                        )

                data = await cls._serialize_payload(**kwargs)

@@ -273,15 +291,6 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
                        "Unknown METHOD on %s replication endpoint" % (cls.NAME,)
                    )

-                # Hard code a special scheme to show this only used for replication. The
-                # instance_name will be passed into the ReplicationEndpointFactory to
-                # determine connection details from the instance_map.
-                uri = "synapse-replication://%s/_synapse/replication/%s/%s" % (
-                    instance_name,
-                    cls.NAME,
-                    "/".join(url_args),
-                )
-
                headers: Dict[bytes, List[bytes]] = {}
                # Add an authorization header, if configured.
                if replication_secret:
@@ -292,10 +301,30 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
                # Keep track of attempts made so we can bail if we don't manage to
                # connect to the target after N tries.
                attempts = 0
+
                # We keep retrying the same request for timeouts. This is so that we
                # have a good idea that the request has either succeeded or failed
                # on the master, and so whether we should clean up or not.
                while True:
+                    # We're modifying the variable on the upper scope. Note
+                    # that this isn't thread-safe, but we likely don't
+                    # really care if the round-robin isn't perfect.
+                    nonlocal instance_pool_index
+                    instance_pool_index += 1
+                    chosen_instance_name = instances[
+                        instance_pool_index % len(instances)
+                    ]
+
+                    # Hard code a special scheme to show this only used for
+                    # replication. The instance_name will be passed into the
+                    # ReplicationEndpointFactory to determine connection
+                    # details from the instance_map.
+                    uri = "synapse-replication://%s/_synapse/replication/%s/%s" % (
+                        chosen_instance_name,
+                        cls.NAME,
+                        "/".join(url_args),
+                    )
+
                    try:
                        result = await request_func(uri, data, headers=headers)
                        break
@@ -324,6 +353,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):

                        await clock.sleep(delay)
+                        attempts += 1

                    except HttpResponseException as e:
                        # We convert to SynapseError as we know that it was a SynapseError
                        # on the main process that we should send to the client. (And
@@ -333,7 +363,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
                    except Exception as e:
                        _outgoing_request_counter.labels(cls.NAME, "ERR").inc()
                        raise SynapseError(
-                            502, f"Failed to talk to {instance_name} process"
+                            502, f"Failed to talk to {instances} process"
                        ) from e

                _outgoing_request_counter.labels(cls.NAME, 200).inc()
@@ -343,7 +373,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
                    _STREAM_POSITION_KEY, {}
                ).items():
                    await replication.wait_for_stream_position(
-                        instance_name=instance_name,
+                        instance_name=chosen_instance_name,
                        stream_name=stream_name,
                        position=position,
                    )

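To make the new `make_client` contract above concrete, here is a stripped-down, self-contained sketch of the retry loop this file now implements: a shared round-robin index, a default of the main process when no `instances` are given, and retries that step through the pool. Everything outside the index/selection logic is stubbed, and names such as `transports` are illustrative, not from Synapse.

```python
import asyncio
from typing import Awaitable, Callable, Dict, List, Optional

MAIN_PROCESS_INSTANCE_NAME = "master"

def make_sender(
    transports: Dict[str, Callable[[str], Awaitable[str]]],
) -> Callable[..., Awaitable[str]]:
    # Shared across calls, like the closure variable in the diff.
    instance_pool_index = 0

    async def send_request(
        *, instances: Optional[List[str]] = None, data: str = ""
    ) -> str:
        nonlocal instance_pool_index
        # If no instances were given, route to the main process.
        instances = instances or [MAIN_PROCESS_INSTANCE_NAME]
        while True:
            instance_pool_index += 1
            chosen = instances[instance_pool_index % len(instances)]
            try:
                return await transports[chosen](data)
            except TimeoutError:
                # A timed-out attempt is retried on the next instance
                # in the pool, giving the round-robin behaviour.
                await asyncio.sleep(0.1)

    return send_request
```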
@@ -578,7 +578,7 @@ class StateHandler:
         writer_instance = self._events_shard_config.get_instance(room_id)
         if writer_instance != self._instance_name:
             await self._update_current_state_client(
-                instance_name=writer_instance,
+                instances=[writer_instance],
                 room_id=room_id,
             )
             return

@@ -41,7 +41,6 @@ from synapse.storage.database import (
     LoggingDatabaseConnection,
     LoggingTransaction,
 )
-from synapse.storage.databases.main.events import SLIDING_SYNC_RELEVANT_STATE_SET
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.util.caches.descriptors import CachedFunction
@@ -284,6 +283,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         super().process_replication_position(stream_name, instance_name, token)

     def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
+        # This is needed to avoid a circular import.
+        from synapse.storage.databases.main.events import (
+            SLIDING_SYNC_RELEVANT_STATE_SET,
+        )
+
         data = row.data

         if row.type == EventsStreamEventRow.TypeId:
@@ -347,6 +351,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         relates_to: Optional[str],
         backfilled: bool,
     ) -> None:
+        # This is needed to avoid a circular import.
+        from synapse.storage.databases.main.events import (
+            SLIDING_SYNC_RELEVANT_STATE_SET,
+        )
+
         # XXX: If you add something to this function make sure you add it to
         # `_invalidate_caches_for_room_events` as well.

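The two "needed to avoid a circular import" hunks above use the standard function-level import pattern: the module-level import of `SLIDING_SYNC_RELEVANT_STATE_SET` is deferred to call time, presumably because the new inheritance chain introduced later in this diff would otherwise create an import cycle with `events.py`. A generic illustration of the pattern, with made-up module names:

```python
# module_a.py -- imported by module_b at module load time.
def process(item: str) -> bool:
    # Deferring the import to call time breaks the load-time cycle:
    # by the time process() runs, module_b is fully initialised.
    from module_b import RELEVANT_SET

    return item in RELEVANT_SET
```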
@@ -46,13 +46,14 @@ from synapse.api.room_versions import EventFormatVersions, RoomVersion
 from synapse.events import EventBase, make_event_from_dict
 from synapse.logging.opentracing import tag_args, trace
 from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
+from synapse.storage._base import db_to_json, make_in_list_sql_clause
 from synapse.storage.background_updates import ForeignKeyConstraint
 from synapse.storage.database import (
     DatabasePool,
     LoggingDatabaseConnection,
     LoggingTransaction,
 )
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.signatures import SignatureWorkerStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
@@ -123,7 +124,9 @@ class _NoChainCoverIndex(Exception):
         super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,))


-class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBaseStore):
+class EventFederationWorkerStore(
+    SignatureWorkerStore, EventsWorkerStore, CacheInvalidationWorkerStore
+):
     # TODO: this attribute comes from EventPushActionWorkerStore. Should we inherit from
     # that store so that mypy can deduce this for itself?
     stream_ordering_month_ago: Optional[int]
@@ -2053,6 +2056,19 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         number_pdus_in_federation_queue.set(count)
         oldest_pdu_in_federation_staging.set(age)

+    async def clean_room_for_join(self, room_id: str) -> None:
+        await self.db_pool.runInteraction(
+            "clean_room_for_join", self._clean_room_for_join_txn, room_id
+        )
+
+    def _clean_room_for_join_txn(self, txn: LoggingTransaction, room_id: str) -> None:
+        query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
+
+        txn.execute(query, (room_id,))
+        self._invalidate_cache_and_stream(
+            txn, self.get_latest_event_ids_in_room, (room_id,)
+        )
+

 class EventFederationStore(EventFederationWorkerStore):
     """Responsible for storing and serving up the various graphs associated
@@ -2078,17 +2094,6 @@ class EventFederationStore(EventFederationWorkerStore):
             self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
         )

-    async def clean_room_for_join(self, room_id: str) -> None:
-        await self.db_pool.runInteraction(
-            "clean_room_for_join", self._clean_room_for_join_txn, room_id
-        )
-
-    def _clean_room_for_join_txn(self, txn: LoggingTransaction, room_id: str) -> None:
-        query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
-
-        txn.execute(query, (room_id,))
-        txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
-
     async def _background_delete_non_state_event_auth(
         self, progress: JsonDict, batch_size: int
     ) -> int:

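Moving `clean_room_for_join` down into `EventFederationWorkerStore` is what lets the federation handler call `self.store.clean_room_for_join(...)` directly on any worker, and the swap in its transaction changes the invalidation scope: `txn.call_after(...invalidate...)` only clears this process's cache, while `_invalidate_cache_and_stream` also publishes the invalidation to other workers. A schematic contrast, with both lines taken from the hunks above (the semantics described in the comments are my reading, not stated in the diff):

```python
# Before: worker-local invalidation only, fine when the write always
# happened on the main process.
txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))

# After: invalidate locally *and* stream the invalidation over
# replication so every other worker drops its cached entry too.
self._invalidate_cache_and_stream(
    txn, self.get_latest_event_ids_in_room, (room_id,)
)
```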
@@ -1935,6 +1935,65 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
             desc="set_room_is_public_appservice_false",
         )

+    async def has_auth_chain_index(self, room_id: str) -> bool:
+        """Check if the room has (or can have) a chain cover index.
+
+        Defaults to True if we don't have an entry in `rooms` table nor any
+        events for the room.
+        """
+
+        has_auth_chain_index = await self.db_pool.simple_select_one_onecol(
+            table="rooms",
+            keyvalues={"room_id": room_id},
+            retcol="has_auth_chain_index",
+            desc="has_auth_chain_index",
+            allow_none=True,
+        )
+
+        if has_auth_chain_index:
+            return True
+
+        # It's possible that we already have events for the room in our DB
+        # without a corresponding room entry. If we do then we don't want to
+        # mark the room as having an auth chain cover index.
+        max_ordering = await self.db_pool.simple_select_one_onecol(
+            table="events",
+            keyvalues={"room_id": room_id},
+            retcol="MAX(stream_ordering)",
+            allow_none=True,
+            desc="has_auth_chain_index_fallback",
+        )
+
+        return max_ordering is None
+
+    async def maybe_store_room_on_outlier_membership(
+        self, room_id: str, room_version: RoomVersion
+    ) -> None:
+        """
+        When we receive an invite or any other event over federation that may relate to a room
+        we are not in, store the version of the room if we don't already know the room version.
+        """
+        # It's possible that we already have events for the room in our DB
+        # without a corresponding room entry. If we do then we don't want to
+        # mark the room as having an auth chain cover index.
+        has_auth_chain_index = await self.has_auth_chain_index(room_id)
+
+        await self.db_pool.simple_upsert(
+            desc="maybe_store_room_on_outlier_membership",
+            table="rooms",
+            keyvalues={"room_id": room_id},
+            values={},
+            insertion_values={
+                "room_version": room_version.identifier,
+                "is_public": False,
+                # We don't worry about setting the `creator` here because
+                # we don't process any messages in a room while a user is
+                # invited (only after the join).
+                "creator": "",
+                "has_auth_chain_index": has_auth_chain_index,
+            },
+        )
+

 class _BackgroundUpdates:
     REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
@@ -2186,37 +2245,6 @@ class RoomBackgroundUpdateStore(RoomWorkerStore):

         return len(rooms)

-    async def has_auth_chain_index(self, room_id: str) -> bool:
-        """Check if the room has (or can have) a chain cover index.
-
-        Defaults to True if we don't have an entry in `rooms` table nor any
-        events for the room.
-        """
-
-        has_auth_chain_index = await self.db_pool.simple_select_one_onecol(
-            table="rooms",
-            keyvalues={"room_id": room_id},
-            retcol="has_auth_chain_index",
-            desc="has_auth_chain_index",
-            allow_none=True,
-        )
-
-        if has_auth_chain_index:
-            return True
-
-        # It's possible that we already have events for the room in our DB
-        # without a corresponding room entry. If we do then we don't want to
-        # mark the room as having an auth chain cover index.
-        max_ordering = await self.db_pool.simple_select_one_onecol(
-            table="events",
-            keyvalues={"room_id": room_id},
-            retcol="MAX(stream_ordering)",
-            allow_none=True,
-            desc="has_auth_chain_index_fallback",
-        )
-
-        return max_ordering is None
-
     async def _background_populate_room_depth_min_depth2(
         self, progress: JsonDict, batch_size: int
     ) -> int:
@@ -2566,34 +2594,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
             updatevalues={"join_event_id": join_event_id},
         )

-    async def maybe_store_room_on_outlier_membership(
-        self, room_id: str, room_version: RoomVersion
-    ) -> None:
-        """
-        When we receive an invite or any other event over federation that may relate to a room
-        we are not in, store the version of the room if we don't already know the room version.
-        """
-        # It's possible that we already have events for the room in our DB
-        # without a corresponding room entry. If we do then we don't want to
-        # mark the room as having an auth chain cover index.
-        has_auth_chain_index = await self.has_auth_chain_index(room_id)
-
-        await self.db_pool.simple_upsert(
-            desc="maybe_store_room_on_outlier_membership",
-            table="rooms",
-            keyvalues={"room_id": room_id},
-            values={},
-            insertion_values={
-                "room_version": room_version.identifier,
-                "is_public": False,
-                # We don't worry about setting the `creator` here because
-                # we don't process any messages in a room while a user is
-                # invited (only after the join).
-                "creator": "",
-                "has_auth_chain_index": has_auth_chain_index,
-            },
-        )
-
     async def add_event_report(
         self,
         room_id: str,