mirror of
https://github.com/element-hq/synapse.git
synced 2025-12-05 01:10:13 +00:00
Compare commits
3 Commits
quenting/l
...
erikj/pers
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1e05b033af | ||
|
|
4499d81adf | ||
|
|
a4dde1f23c |
@@ -29,7 +29,6 @@ from synapse.api.errors import (
|
||||
from synapse.logging.opentracing import log_kv, set_tag, trace
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.types import (
|
||||
RoomStreamToken,
|
||||
StreamToken,
|
||||
get_domain_from_id,
|
||||
get_verify_key_from_cross_signing_key,
|
||||
@@ -113,8 +112,7 @@ class DeviceWorkerHandler(BaseHandler):
|
||||
|
||||
set_tag("user_id", user_id)
|
||||
set_tag("from_token", from_token)
|
||||
now_room_id = self.store.get_room_max_stream_ordering()
|
||||
now_room_key = RoomStreamToken(None, now_room_id)
|
||||
now_room_key = self.store.get_room_max_token()
|
||||
|
||||
room_ids = await self.store.get_rooms_for_user(user_id)
|
||||
|
||||
|
||||
@@ -1141,7 +1141,7 @@ class RoomEventSource:
|
||||
return (events, end_key)
|
||||
|
||||
def get_current_key(self) -> RoomStreamToken:
|
||||
return RoomStreamToken(None, self.store.get_room_max_stream_ordering())
|
||||
return self.store.get_room_max_token()
|
||||
|
||||
def get_current_key_for_room(self, room_id: str) -> Awaitable[str]:
|
||||
return self.store.get_room_events_max_id(room_id)
|
||||
|
||||
@@ -163,7 +163,7 @@ class _NotifierUserStream:
|
||||
"""
|
||||
# Immediately wake up stream if something has already since happened
|
||||
# since their last token.
|
||||
if self.last_notified_token.is_after(token):
|
||||
if self.last_notified_token != token:
|
||||
return _NotificationListener(defer.succeed(self.current_token))
|
||||
else:
|
||||
return _NotificationListener(self.notify_deferred.observe())
|
||||
@@ -470,7 +470,7 @@ class Notifier:
|
||||
async def check_for_updates(
|
||||
before_token: StreamToken, after_token: StreamToken
|
||||
) -> EventStreamResult:
|
||||
if not after_token.is_after(before_token):
|
||||
if after_token == before_token:
|
||||
return EventStreamResult([], (from_token, from_token))
|
||||
|
||||
events = [] # type: List[EventBase]
|
||||
|
||||
@@ -77,6 +77,7 @@ REQUIREMENTS = [
|
||||
"Jinja2>=2.9",
|
||||
"bleach>=1.4.3",
|
||||
"typing-extensions>=3.7.4",
|
||||
"cbor2",
|
||||
]
|
||||
|
||||
CONDITIONAL_REQUIREMENTS = {
|
||||
|
||||
@@ -29,7 +29,7 @@ from synapse.replication.tcp.streams.events import (
|
||||
EventsStreamEventRow,
|
||||
EventsStreamRow,
|
||||
)
|
||||
from synapse.types import PersistedEventPosition, RoomStreamToken, UserID
|
||||
from synapse.types import PersistedEventPosition, UserID
|
||||
from synapse.util.async_helpers import timeout_deferred
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
@@ -152,9 +152,7 @@ class ReplicationDataHandler:
|
||||
if event.type == EventTypes.Member:
|
||||
extra_users = (UserID.from_string(event.state_key),)
|
||||
|
||||
max_token = RoomStreamToken(
|
||||
None, self.store.get_room_max_stream_ordering()
|
||||
)
|
||||
max_token = self.store.get_room_max_token()
|
||||
event_pos = PersistedEventPosition(instance_name, token)
|
||||
self.notifier.on_new_room_event(
|
||||
event, event_pos, max_token, extra_users
|
||||
|
||||
@@ -171,6 +171,37 @@ class PositionCommand(Command):
|
||||
return " ".join((self.stream_name, self.instance_name, str(self.token)))
|
||||
|
||||
|
||||
class PersistedToCommand(Command):
|
||||
"""Sent by writers to inform others that it has persisted up to the included
|
||||
token.
|
||||
|
||||
The included `token` will *not* have been persisted by the instance.
|
||||
|
||||
Format::
|
||||
|
||||
PERSISTED_TO <stream_name> <instance_name> <token>
|
||||
|
||||
On receipt the client should mark that the given instances has persisted
|
||||
everything up to the given token. Note: this does *not* mean that other
|
||||
instances have also persisted all their rows up to that point.
|
||||
"""
|
||||
|
||||
NAME = "PERSISTED_TO"
|
||||
|
||||
def __init__(self, stream_name, instance_name, token):
|
||||
self.stream_name = stream_name
|
||||
self.instance_name = instance_name
|
||||
self.token = token
|
||||
|
||||
@classmethod
|
||||
def from_line(cls, line):
|
||||
stream_name, instance_name, token = line.split(" ", 2)
|
||||
return cls(stream_name, instance_name, int(token))
|
||||
|
||||
def to_line(self):
|
||||
return " ".join((self.stream_name, self.instance_name, str(self.token)))
|
||||
|
||||
|
||||
class ErrorCommand(_SimpleCommand):
|
||||
"""Sent by either side if there was an ERROR. The data is a string describing
|
||||
the error.
|
||||
@@ -405,6 +436,7 @@ _COMMANDS = (
|
||||
UserIpCommand,
|
||||
RemoteServerUpCommand,
|
||||
ClearUserSyncsCommand,
|
||||
PersistedToCommand,
|
||||
) # type: Tuple[Type[Command], ...]
|
||||
|
||||
# Map of command name to command type.
|
||||
|
||||
@@ -47,6 +47,7 @@ from synapse.replication.tcp.commands import (
|
||||
ReplicateCommand,
|
||||
UserIpCommand,
|
||||
UserSyncCommand,
|
||||
PersistedToCommand,
|
||||
)
|
||||
from synapse.replication.tcp.protocol import AbstractConnection
|
||||
from synapse.replication.tcp.streams import (
|
||||
@@ -387,6 +388,9 @@ class ReplicationCommandHandler:
|
||||
assert self._server_notices_sender is not None
|
||||
await self._server_notices_sender.on_user_ip(cmd.user_id)
|
||||
|
||||
def on_PERSISTED_TO(self, conn: AbstractConnection, cmd: PersistedToCommand):
|
||||
pass
|
||||
|
||||
def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
|
||||
if cmd.instance_name == self._instance_name:
|
||||
# Ignore RDATA that are just our own echoes
|
||||
|
||||
@@ -24,6 +24,7 @@ from twisted.internet.protocol import Factory
|
||||
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
|
||||
from synapse.replication.tcp.streams import EventsStream
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
stream_updates_counter = Counter(
|
||||
@@ -84,6 +85,9 @@ class ReplicationStreamer:
|
||||
# Set of streams to replicate.
|
||||
self.streams = self.command_handler.get_streams_to_replicate()
|
||||
|
||||
if self.streams:
|
||||
self.clock.looping_call(self.on_notifier_poke, 1000.0)
|
||||
|
||||
def on_notifier_poke(self):
|
||||
"""Checks if there is actually any new data and sends it to the
|
||||
connections if there are.
|
||||
@@ -126,9 +130,7 @@ class ReplicationStreamer:
|
||||
random.shuffle(all_streams)
|
||||
|
||||
for stream in all_streams:
|
||||
if stream.last_token == stream.current_token(
|
||||
self._instance_name
|
||||
):
|
||||
if not stream.has_updates():
|
||||
continue
|
||||
|
||||
if self._replication_torture_level:
|
||||
@@ -174,6 +176,11 @@ class ReplicationStreamer:
|
||||
except Exception:
|
||||
logger.exception("Failed to replicate")
|
||||
|
||||
# for command in stream.extra_commands(
|
||||
# sent_updates=bool(updates)
|
||||
# ):
|
||||
# self.command_handler.send_command(command)
|
||||
|
||||
logger.debug("No more pending updates, breaking poke loop")
|
||||
finally:
|
||||
self.pending_updates = False
|
||||
|
||||
@@ -31,6 +31,7 @@ from typing import (
|
||||
import attr
|
||||
|
||||
from synapse.replication.http.streams import ReplicationGetStreamUpdates
|
||||
from synapse.replication.tcp.commands import Command
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import synapse.server
|
||||
@@ -187,6 +188,12 @@ class Stream:
|
||||
)
|
||||
return updates, upto_token, limited
|
||||
|
||||
def has_updates(self) -> bool:
|
||||
return self.current_token(self.local_instance_name) != self.last_token
|
||||
|
||||
def extra_commands(self, sent_updates: bool) -> List[Command]:
|
||||
return []
|
||||
|
||||
|
||||
def current_token_without_instance(
|
||||
current_token: Callable[[], int]
|
||||
|
||||
@@ -19,7 +19,8 @@ from typing import List, Tuple, Type
|
||||
|
||||
import attr
|
||||
|
||||
from ._base import Stream, StreamUpdateResult, Token
|
||||
from synapse.replication.tcp.streams._base import Stream, StreamUpdateResult, Token
|
||||
from synapse.replication.tcp.commands import Command, PersistedToCommand
|
||||
|
||||
"""Handling of the 'events' replication stream
|
||||
|
||||
@@ -222,3 +223,18 @@ class EventsStream(Stream):
|
||||
(typ, data) = row
|
||||
data = TypeToRow[typ].from_data(data)
|
||||
return EventsStreamRow(typ, data)
|
||||
|
||||
def has_updates(self) -> bool:
|
||||
return True
|
||||
|
||||
def extra_commands(self, sent_updates: bool) -> List[Command]:
|
||||
if sent_updates:
|
||||
return []
|
||||
|
||||
return [
|
||||
PersistedToCommand(
|
||||
self.NAME,
|
||||
self.local_instance_name,
|
||||
self._store._stream_id_gen.get_max_persisted_position_for_self(),
|
||||
)
|
||||
]
|
||||
|
||||
@@ -178,6 +178,8 @@ class PersistEventsStore:
|
||||
)
|
||||
persist_event_counter.inc(len(events_and_contexts))
|
||||
|
||||
logger.debug("Finished persisting 1")
|
||||
|
||||
if not backfilled:
|
||||
# backfilled events have negative stream orderings, so we don't
|
||||
# want to set the event_persisted_position to that.
|
||||
@@ -185,6 +187,8 @@ class PersistEventsStore:
|
||||
events_and_contexts[-1][0].internal_metadata.stream_ordering
|
||||
)
|
||||
|
||||
logger.debug("Finished persisting 2")
|
||||
|
||||
for event, context in events_and_contexts:
|
||||
if context.app_service:
|
||||
origin_type = "local"
|
||||
@@ -198,6 +202,8 @@ class PersistEventsStore:
|
||||
|
||||
event_counter.labels(event.type, origin_type, origin_entity).inc()
|
||||
|
||||
logger.debug("Finished persisting 3")
|
||||
|
||||
for room_id, new_state in current_state_for_room.items():
|
||||
self.store.get_current_state_ids.prefill((room_id,), new_state)
|
||||
|
||||
@@ -206,6 +212,9 @@ class PersistEventsStore:
|
||||
(room_id,), list(latest_event_ids)
|
||||
)
|
||||
|
||||
logger.debug("Finished persisting 4")
|
||||
logger.debug("Finished persisting 5")
|
||||
|
||||
async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
|
||||
"""Filter the supplied list of event_ids to get those which are prev_events of
|
||||
existing (non-outlier/rejected) events.
|
||||
|
||||
@@ -35,11 +35,10 @@ what sort order was used:
|
||||
- topological tokems: "t%d-%d", where the integers map to the topological
|
||||
and stream ordering columns respectively.
|
||||
"""
|
||||
|
||||
import abc
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
@@ -54,6 +53,7 @@ from synapse.storage.database import (
|
||||
)
|
||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||
from synapse.types import Collection, RoomStreamToken
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
|
||||
@@ -76,6 +76,18 @@ _EventDictReturn = namedtuple(
|
||||
)
|
||||
|
||||
|
||||
def _filter_result(
|
||||
instance_name: str,
|
||||
stream_id: int,
|
||||
from_token: RoomStreamToken,
|
||||
to_token: RoomStreamToken,
|
||||
) -> bool:
|
||||
from_id = from_token.instance_map.get(instance_name, from_token.stream)
|
||||
to_id = to_token.instance_map.get(instance_name, to_token.stream)
|
||||
|
||||
return from_id < stream_id <= to_id
|
||||
|
||||
|
||||
def generate_pagination_where_clause(
|
||||
direction: str,
|
||||
column_names: Tuple[str, str],
|
||||
@@ -209,6 +221,71 @@ def _make_generic_sql_bound(
|
||||
)
|
||||
|
||||
|
||||
def _make_instance_filter_clause(
|
||||
direction: str,
|
||||
from_token: Optional[RoomStreamToken],
|
||||
to_token: Optional[RoomStreamToken],
|
||||
) -> Tuple[str, List[Any]]:
|
||||
if from_token and from_token.topological:
|
||||
from_token = None
|
||||
if to_token and to_token.topological:
|
||||
to_token = None
|
||||
|
||||
if not from_token and not to_token:
|
||||
return "", []
|
||||
|
||||
from_bound = ">=" if direction == "b" else "<"
|
||||
to_bound = "<" if direction == "b" else ">="
|
||||
|
||||
filter_clauses = []
|
||||
filter_args = [] # type: List[Any]
|
||||
|
||||
from_map = from_token.instance_map if from_token else {}
|
||||
to_map = to_token.instance_map if to_token else {}
|
||||
|
||||
default_from = from_token.stream if from_token else None
|
||||
default_to = to_token.stream if to_token else None
|
||||
|
||||
if default_from and default_to:
|
||||
filter_clauses.append(
|
||||
"(? %s stream_ordering AND ? %s stream_ordering)" % (from_bound, to_bound)
|
||||
)
|
||||
filter_args.extend((default_from, default_to,))
|
||||
elif default_from:
|
||||
filter_clauses.append("(? %s stream_ordering)" % (from_bound,))
|
||||
filter_args.extend((default_from,))
|
||||
elif default_to:
|
||||
filter_clauses.append("(? %s stream_ordering)" % (to_bound,))
|
||||
filter_args.extend((default_to,))
|
||||
|
||||
for instance in set(from_map).union(to_map):
|
||||
from_id = from_map.get(instance, default_from)
|
||||
to_id = to_map.get(instance, default_to)
|
||||
|
||||
if from_id and to_id:
|
||||
filter_clauses.append(
|
||||
"(instance_name = ? AND ? %s stream_ordering AND ? %s stream_ordering)"
|
||||
% (from_bound, to_bound)
|
||||
)
|
||||
filter_args.extend((instance, from_id, to_id,))
|
||||
elif from_id:
|
||||
filter_clauses.append(
|
||||
"(instance_name = ? AND ? %s stream_ordering)" % (from_bound,)
|
||||
)
|
||||
filter_args.extend((instance, from_id,))
|
||||
elif to_id:
|
||||
filter_clauses.append(
|
||||
"(instance_name = ? AND ? %s stream_ordering)" % (to_bound,)
|
||||
)
|
||||
filter_args.extend((instance, to_id,))
|
||||
|
||||
filter_clause = ""
|
||||
if filter_clauses:
|
||||
filter_clause = "(%s)" % (" OR ".join(filter_clauses),)
|
||||
|
||||
return filter_clause, filter_args
|
||||
|
||||
|
||||
def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
|
||||
# NB: This may create SQL clauses that don't optimise well (and we don't
|
||||
# have indices on all possible clauses). E.g. it may create
|
||||
@@ -305,6 +382,22 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
|
||||
def get_room_min_stream_ordering(self) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_room_max_token(self) -> RoomStreamToken:
|
||||
min_pos = self._stream_id_gen.get_current_token()
|
||||
|
||||
positions = {}
|
||||
if isinstance(self._stream_id_gen, MultiWriterIdGenerator):
|
||||
positions = {
|
||||
i: p
|
||||
for i, p in self._stream_id_gen.get_positions().items()
|
||||
if p >= min_pos
|
||||
}
|
||||
|
||||
if set(positions.values()) == {min_pos}:
|
||||
positions = {}
|
||||
|
||||
return RoomStreamToken(None, min_pos, positions)
|
||||
|
||||
async def get_room_events_stream_for_rooms(
|
||||
self,
|
||||
room_ids: Collection[str],
|
||||
@@ -402,25 +495,50 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
|
||||
if from_key == to_key:
|
||||
return [], from_key
|
||||
|
||||
from_id = from_key.stream
|
||||
to_id = to_key.stream
|
||||
|
||||
has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
|
||||
has_changed = self._events_stream_cache.has_entity_changed(
|
||||
room_id, from_key.stream
|
||||
)
|
||||
|
||||
if not has_changed:
|
||||
return [], from_key
|
||||
|
||||
def f(txn):
|
||||
sql = (
|
||||
"SELECT event_id, stream_ordering FROM events WHERE"
|
||||
" room_id = ?"
|
||||
" AND not outlier"
|
||||
" AND stream_ordering > ? AND stream_ordering <= ?"
|
||||
" ORDER BY stream_ordering %s LIMIT ?"
|
||||
) % (order,)
|
||||
txn.execute(sql, (room_id, from_id, to_id, limit))
|
||||
filter_clause, filter_args = _make_instance_filter_clause(
|
||||
"f", from_key, to_key
|
||||
)
|
||||
if filter_clause:
|
||||
filter_clause = " AND " + filter_clause
|
||||
|
||||
rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
|
||||
min_from_id = min(from_key.instance_map.values(), default=from_key.stream)
|
||||
max_to_id = max(to_key.instance_map.values(), default=to_key.stream)
|
||||
|
||||
sql = """
|
||||
SELECT event_id, instance_name, stream_ordering
|
||||
FROM events
|
||||
WHERE
|
||||
room_id = ?
|
||||
AND not outlier
|
||||
AND stream_ordering > ? AND stream_ordering <= ?
|
||||
%s
|
||||
ORDER BY stream_ordering %s LIMIT ?
|
||||
""" % (
|
||||
filter_clause,
|
||||
order,
|
||||
)
|
||||
args = [room_id, min_from_id, max_to_id]
|
||||
args.extend(filter_args)
|
||||
args.append(limit)
|
||||
txn.execute(sql, args)
|
||||
|
||||
# rows = [
|
||||
# _EventDictReturn(event_id, None, stream_ordering)
|
||||
# for event_id, instance_name, stream_ordering in txn
|
||||
# if _filter_result(instance_name, stream_ordering, from_key, to_key)
|
||||
# ]
|
||||
rows = [
|
||||
_EventDictReturn(event_id, None, stream_ordering)
|
||||
for event_id, instance_name, stream_ordering in txn
|
||||
]
|
||||
return rows
|
||||
|
||||
rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f)
|
||||
@@ -429,7 +547,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
|
||||
[r.event_id for r in rows], get_prev_content=True
|
||||
)
|
||||
|
||||
self._set_before_and_after(ret, rows, topo_order=from_id is None)
|
||||
self._set_before_and_after(ret, rows, topo_order=from_key.stream is None)
|
||||
|
||||
if order.lower() == "desc":
|
||||
ret.reverse()
|
||||
@@ -446,29 +564,40 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
|
||||
async def get_membership_changes_for_user(
|
||||
self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
|
||||
) -> List[EventBase]:
|
||||
from_id = from_key.stream
|
||||
to_id = to_key.stream
|
||||
|
||||
if from_key == to_key:
|
||||
return []
|
||||
|
||||
if from_id:
|
||||
if from_key:
|
||||
has_changed = self._membership_stream_cache.has_entity_changed(
|
||||
user_id, int(from_id)
|
||||
user_id, int(from_key.stream)
|
||||
)
|
||||
if not has_changed:
|
||||
return []
|
||||
|
||||
def f(txn):
|
||||
sql = (
|
||||
"SELECT m.event_id, stream_ordering FROM events AS e,"
|
||||
" room_memberships AS m"
|
||||
" WHERE e.event_id = m.event_id"
|
||||
" AND m.user_id = ?"
|
||||
" AND e.stream_ordering > ? AND e.stream_ordering <= ?"
|
||||
" ORDER BY e.stream_ordering ASC"
|
||||
filter_clause, filter_args = _make_instance_filter_clause(
|
||||
"f", from_key, to_key
|
||||
)
|
||||
txn.execute(sql, (user_id, from_id, to_id))
|
||||
if filter_clause:
|
||||
filter_clause = " AND " + filter_clause
|
||||
|
||||
min_from_id = min(from_key.instance_map.values(), default=from_key.stream)
|
||||
max_to_id = max(to_key.instance_map.values(), default=to_key.stream)
|
||||
|
||||
sql = """
|
||||
SELECT m.event_id, stream_ordering
|
||||
FROM events AS e, room_memberships AS m
|
||||
WHERE e.event_id = m.event_id
|
||||
AND m.user_id = ?
|
||||
AND e.stream_ordering > ? AND e.stream_ordering <= ?
|
||||
%s
|
||||
ORDER BY e.stream_ordering ASC
|
||||
""" % (
|
||||
filter_clause,
|
||||
)
|
||||
args = [user_id, min_from_id, max_to_id]
|
||||
args.extend(filter_args)
|
||||
txn.execute(sql, args)
|
||||
|
||||
rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
|
||||
|
||||
@@ -975,11 +1104,39 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
|
||||
else:
|
||||
order = "ASC"
|
||||
|
||||
if from_token.topological is not None:
|
||||
from_bound = from_token.as_tuple()
|
||||
elif direction == "b":
|
||||
from_bound = (
|
||||
None,
|
||||
max(from_token.instance_map.values(), default=from_token.stream),
|
||||
)
|
||||
else:
|
||||
from_bound = (
|
||||
None,
|
||||
min(from_token.instance_map.values(), default=from_token.stream),
|
||||
)
|
||||
|
||||
to_bound = None
|
||||
if to_token:
|
||||
if to_token.topological is not None:
|
||||
to_bound = to_token.as_tuple()
|
||||
elif direction == "b":
|
||||
to_bound = (
|
||||
None,
|
||||
min(to_token.instance_map.values(), default=to_token.stream),
|
||||
)
|
||||
else:
|
||||
to_bound = (
|
||||
None,
|
||||
max(to_token.instance_map.values(), default=to_token.stream),
|
||||
)
|
||||
|
||||
bounds = generate_pagination_where_clause(
|
||||
direction=direction,
|
||||
column_names=("topological_ordering", "stream_ordering"),
|
||||
from_token=from_token.as_tuple(),
|
||||
to_token=to_token.as_tuple() if to_token else None,
|
||||
from_token=from_bound,
|
||||
to_token=to_bound,
|
||||
engine=self.database_engine,
|
||||
)
|
||||
|
||||
@@ -989,6 +1146,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
|
||||
bounds += " AND " + filter_clause
|
||||
args.extend(filter_args)
|
||||
|
||||
stream_filter_clause, stream_filter_args = _make_instance_filter_clause(
|
||||
direction, from_token, to_token
|
||||
)
|
||||
if stream_filter_clause:
|
||||
bounds += " AND " + stream_filter_clause
|
||||
args.extend(stream_filter_args)
|
||||
|
||||
args.append(int(limit))
|
||||
|
||||
select_keywords = "SELECT"
|
||||
|
||||
@@ -229,7 +229,7 @@ class EventsPersistenceStorage:
|
||||
defer.gatherResults(deferreds, consumeErrors=True)
|
||||
)
|
||||
|
||||
return RoomStreamToken(None, self.main_store.get_current_events_token())
|
||||
return self.main_store.get_room_max_token()
|
||||
|
||||
async def persist_event(
|
||||
self, event: EventBase, context: EventContext, backfilled: bool = False
|
||||
@@ -247,11 +247,10 @@ class EventsPersistenceStorage:
|
||||
|
||||
await make_deferred_yieldable(deferred)
|
||||
|
||||
max_persisted_id = self.main_store.get_current_events_token()
|
||||
event_stream_id = event.internal_metadata.stream_ordering
|
||||
|
||||
pos = PersistedEventPosition(self._instance_name, event_stream_id)
|
||||
return pos, RoomStreamToken(None, max_persisted_id)
|
||||
return pos, self.main_store.get_room_max_token()
|
||||
|
||||
def _maybe_start_persisting(self, room_id: str):
|
||||
async def persisting_queue(item):
|
||||
|
||||
@@ -217,6 +217,7 @@ class MultiWriterIdGenerator:
|
||||
self._instance_name = instance_name
|
||||
self._positive = positive
|
||||
self._writers = writers
|
||||
self._sequence_name = sequence_name
|
||||
self._return_factor = 1 if positive else -1
|
||||
|
||||
# We lock as some functions may be called from DB threads.
|
||||
@@ -227,6 +228,8 @@ class MultiWriterIdGenerator:
|
||||
# return them.
|
||||
self._current_positions = {} # type: Dict[str, int]
|
||||
|
||||
self._max_persisted_positions = dict(self._current_positions)
|
||||
|
||||
# Set of local IDs that we're still processing. The current position
|
||||
# should be less than the minimum of this set (if not empty).
|
||||
self._unfinished_ids = set() # type: Set[int]
|
||||
@@ -404,6 +407,12 @@ class MultiWriterIdGenerator:
|
||||
current position if possible.
|
||||
"""
|
||||
|
||||
logger.debug(
|
||||
"Mark as finished 1 _current_positions %s: %s",
|
||||
self._sequence_name,
|
||||
self._current_positions,
|
||||
)
|
||||
|
||||
with self._lock:
|
||||
self._unfinished_ids.discard(next_id)
|
||||
self._finished_ids.add(next_id)
|
||||
@@ -439,6 +448,16 @@ class MultiWriterIdGenerator:
|
||||
if new_cur:
|
||||
curr = self._current_positions.get(self._instance_name, 0)
|
||||
self._current_positions[self._instance_name] = max(curr, new_cur)
|
||||
self._max_persisted_positions[self._instance_name] = max(
|
||||
self._current_positions[self._instance_name],
|
||||
self._max_persisted_positions.get(self._instance_name, 0),
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Mark as finished _current_positions %s: %s",
|
||||
self._sequence_name,
|
||||
self._current_positions,
|
||||
)
|
||||
|
||||
self._add_persisted_position(next_id)
|
||||
|
||||
@@ -454,6 +473,11 @@ class MultiWriterIdGenerator:
|
||||
"""
|
||||
|
||||
with self._lock:
|
||||
logger.debug(
|
||||
"get_current_token_for_writer %s: %s",
|
||||
self._sequence_name,
|
||||
self._current_positions,
|
||||
)
|
||||
return self._return_factor * self._current_positions.get(instance_name, 0)
|
||||
|
||||
def get_positions(self) -> Dict[str, int]:
|
||||
@@ -478,6 +502,12 @@ class MultiWriterIdGenerator:
|
||||
new_id, self._current_positions.get(instance_name, 0)
|
||||
)
|
||||
|
||||
self._max_persisted_positions[instance_name] = max(
|
||||
new_id,
|
||||
self._current_positions.get(instance_name, 0),
|
||||
self._max_persisted_positions.get(instance_name, 0),
|
||||
)
|
||||
|
||||
self._add_persisted_position(new_id)
|
||||
|
||||
def get_persisted_upto_position(self) -> int:
|
||||
@@ -492,10 +522,29 @@ class MultiWriterIdGenerator:
|
||||
with self._lock:
|
||||
return self._return_factor * self._persisted_upto_position
|
||||
|
||||
def get_max_persisted_position_for_self(self) -> int:
|
||||
with self._lock:
|
||||
if self._unfinished_ids:
|
||||
return self.get_current_token_for_writer(self._instance_name)
|
||||
|
||||
return self._return_factor * max(
|
||||
self._current_positions.values(), default=1
|
||||
)
|
||||
|
||||
def advance_persisted_to(self, instance_name: str, new_id: int):
|
||||
new_id *= self._return_factor
|
||||
|
||||
with self._lock:
|
||||
self._max_persisted_positions[instance_name] = max(
|
||||
new_id,
|
||||
self._current_positions.get(instance_name, 0),
|
||||
self._max_persisted_positions.get(instance_name, 0),
|
||||
)
|
||||
|
||||
def _add_persisted_position(self, new_id: int):
|
||||
"""Record that we have persisted a position.
|
||||
|
||||
This is used to keep the `_current_positions` up to date.
|
||||
This is used to keep the `_persisted_upto_position` up to date.
|
||||
"""
|
||||
|
||||
# We require that the lock is locked by caller
|
||||
@@ -506,7 +555,7 @@ class MultiWriterIdGenerator:
|
||||
# We move the current min position up if the minimum current positions
|
||||
# of all instances is higher (since by definition all positions less
|
||||
# that that have been persisted).
|
||||
min_curr = min(self._current_positions.values(), default=0)
|
||||
min_curr = min(self._max_persisted_positions.values(), default=0)
|
||||
self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
|
||||
|
||||
# We now iterate through the seen positions, discarding those that are
|
||||
|
||||
@@ -21,8 +21,9 @@ from collections import namedtuple
|
||||
from typing import Any, Dict, Mapping, MutableMapping, Optional, Tuple, Type, TypeVar
|
||||
|
||||
import attr
|
||||
import cbor2
|
||||
from signedjson.key import decode_verify_key_bytes
|
||||
from unpaddedbase64 import decode_base64
|
||||
from unpaddedbase64 import decode_base64, encode_base64
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
|
||||
@@ -362,7 +363,7 @@ def map_username_to_mxid_localpart(username, case_sensitive=False):
|
||||
return username.decode("ascii")
|
||||
|
||||
|
||||
@attr.s(frozen=True, slots=True)
|
||||
@attr.s(frozen=True, slots=True, cmp=False)
|
||||
class RoomStreamToken:
|
||||
"""Tokens are positions between events. The token "s1" comes after event 1.
|
||||
|
||||
@@ -392,6 +393,8 @@ class RoomStreamToken:
|
||||
)
|
||||
stream = attr.ib(type=int, validator=attr.validators.instance_of(int))
|
||||
|
||||
instance_map = attr.ib(type=Dict[str, int], factory=dict)
|
||||
|
||||
@classmethod
|
||||
def parse(cls, string: str) -> "RoomStreamToken":
|
||||
try:
|
||||
@@ -400,6 +403,11 @@ class RoomStreamToken:
|
||||
if string[0] == "t":
|
||||
parts = string[1:].split("-", 1)
|
||||
return cls(topological=int(parts[0]), stream=int(parts[1]))
|
||||
if string[0] == "m":
|
||||
payload = cbor2.loads(decode_base64(string[1:]))
|
||||
return cls(
|
||||
topological=None, stream=payload["s"], instance_map=payload["p"],
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
raise SynapseError(400, "Invalid token %r" % (string,))
|
||||
@@ -413,15 +421,49 @@ class RoomStreamToken:
|
||||
pass
|
||||
raise SynapseError(400, "Invalid token %r" % (string,))
|
||||
|
||||
def copy_and_advance(self, other: "RoomStreamToken") -> "RoomStreamToken":
|
||||
if self.topological or other.topological:
|
||||
raise Exception("Can't advance topological tokens")
|
||||
|
||||
max_stream = max(self.stream, other.stream)
|
||||
|
||||
instance_map = {
|
||||
instance: max(
|
||||
self.instance_map.get(instance, self.stream),
|
||||
other.instance_map.get(instance, other.stream),
|
||||
)
|
||||
for instance in set(self.instance_map).union(other.instance_map)
|
||||
}
|
||||
|
||||
return RoomStreamToken(None, max_stream, instance_map)
|
||||
|
||||
def as_tuple(self) -> Tuple[Optional[int], int]:
|
||||
return (self.topological, self.stream)
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.topological is not None:
|
||||
return "t%d-%d" % (self.topological, self.stream)
|
||||
elif self.instance_map:
|
||||
return "m" + encode_base64(
|
||||
cbor2.dumps({"s": self.stream, "p": self.instance_map}),
|
||||
)
|
||||
else:
|
||||
return "s%d" % (self.stream,)
|
||||
|
||||
def __lt__(self, other: "RoomStreamToken"):
|
||||
if self.stream != other.stream:
|
||||
return self.stream < other.stream
|
||||
|
||||
for instance in set(self.instance_map).union(other.instance_map):
|
||||
if self.instance_map.get(instance, self.stream) != other.instance_map.get(
|
||||
instance, other.stream
|
||||
):
|
||||
return self.instance_map.get(
|
||||
instance, self.stream
|
||||
) < other.instance_map.get(instance, other.stream)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True)
|
||||
class StreamToken:
|
||||
@@ -461,7 +503,7 @@ class StreamToken:
|
||||
def is_after(self, other):
|
||||
"""Does this token contain events that the other doesn't?"""
|
||||
return (
|
||||
(other.room_stream_id < self.room_stream_id)
|
||||
(other.room_key < self.room_key)
|
||||
or (int(other.presence_key) < int(self.presence_key))
|
||||
or (int(other.typing_key) < int(self.typing_key))
|
||||
or (int(other.receipt_key) < int(self.receipt_key))
|
||||
@@ -476,13 +518,16 @@ class StreamToken:
|
||||
"""Advance the given key in the token to a new value if and only if the
|
||||
new value is after the old value.
|
||||
"""
|
||||
new_token = self.copy_and_replace(key, new_value)
|
||||
if key == "room_key":
|
||||
new_id = new_token.room_stream_id
|
||||
old_id = self.room_stream_id
|
||||
else:
|
||||
new_id = int(getattr(new_token, key))
|
||||
old_id = int(getattr(self, key))
|
||||
new_token = self.copy_and_replace(
|
||||
"room_key", self.room_key.copy_and_advance(new_value)
|
||||
)
|
||||
return new_token
|
||||
|
||||
new_token = self.copy_and_replace(key, new_value)
|
||||
new_id = int(getattr(new_token, key))
|
||||
old_id = int(getattr(self, key))
|
||||
|
||||
if old_id < new_id:
|
||||
return new_token
|
||||
else:
|
||||
@@ -507,7 +552,7 @@ class PersistedEventPosition:
|
||||
stream = attr.ib(type=int)
|
||||
|
||||
def persisted_after(self, token: RoomStreamToken) -> bool:
|
||||
return token.stream < self.stream
|
||||
return token.instance_map.get(self.instance_name, token.stream) < self.stream
|
||||
|
||||
|
||||
class ThirdPartyInstanceID(
|
||||
|
||||
Reference in New Issue
Block a user