Compare commits

...

3 Commits

Author          SHA1         Message                                        Date
Erik Johnston   1e05b033af   Persisted up to command                        2020-09-29 14:45:42 +01:00
Erik Johnston   4499d81adf   Wire up token                                  2020-09-29 14:43:28 +01:00
Erik Johnston   a4dde1f23c   Reduce usages of RoomStreamToken constructor   2020-09-29 14:43:28 +01:00
15 changed files with 389 additions and 60 deletions

View File

@@ -29,7 +29,6 @@ from synapse.api.errors import (
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import (
RoomStreamToken,
StreamToken,
get_domain_from_id,
get_verify_key_from_cross_signing_key,
@@ -113,8 +112,7 @@ class DeviceWorkerHandler(BaseHandler):
set_tag("user_id", user_id)
set_tag("from_token", from_token)
now_room_id = self.store.get_room_max_stream_ordering()
now_room_key = RoomStreamToken(None, now_room_id)
now_room_key = self.store.get_room_max_token()
room_ids = await self.store.get_rooms_for_user(user_id)

View File

@@ -1141,7 +1141,7 @@ class RoomEventSource:
return (events, end_key)
def get_current_key(self) -> RoomStreamToken:
return RoomStreamToken(None, self.store.get_room_max_stream_ordering())
return self.store.get_room_max_token()
def get_current_key_for_room(self, room_id: str) -> Awaitable[str]:
return self.store.get_room_events_max_id(room_id)

View File

@@ -163,7 +163,7 @@ class _NotifierUserStream:
"""
# Immediately wake up the stream if something has already happened
# since their last token.
if self.last_notified_token.is_after(token):
if self.last_notified_token != token:
return _NotificationListener(defer.succeed(self.current_token))
else:
return _NotificationListener(self.notify_deferred.observe())
@@ -470,7 +470,7 @@ class Notifier:
async def check_for_updates(
before_token: StreamToken, after_token: StreamToken
) -> EventStreamResult:
if not after_token.is_after(before_token):
if after_token == before_token:
return EventStreamResult([], (from_token, from_token))
events = [] # type: List[EventBase]

View File

@@ -77,6 +77,7 @@ REQUIREMENTS = [
"Jinja2>=2.9",
"bleach>=1.4.3",
"typing-extensions>=3.7.4",
"cbor2",
]
CONDITIONAL_REQUIREMENTS = {

View File

@@ -29,7 +29,7 @@ from synapse.replication.tcp.streams.events import (
EventsStreamEventRow,
EventsStreamRow,
)
from synapse.types import PersistedEventPosition, RoomStreamToken, UserID
from synapse.types import PersistedEventPosition, UserID
from synapse.util.async_helpers import timeout_deferred
from synapse.util.metrics import Measure
@@ -152,9 +152,7 @@ class ReplicationDataHandler:
if event.type == EventTypes.Member:
extra_users = (UserID.from_string(event.state_key),)
max_token = RoomStreamToken(
None, self.store.get_room_max_stream_ordering()
)
max_token = self.store.get_room_max_token()
event_pos = PersistedEventPosition(instance_name, token)
self.notifier.on_new_room_event(
event, event_pos, max_token, extra_users

View File

@@ -171,6 +171,37 @@ class PositionCommand(Command):
return " ".join((self.stream_name, self.instance_name, str(self.token)))
class PersistedToCommand(Command):
"""Sent by writers to inform others that it has persisted up to the included
token.
The included `token` will *not* have been persisted by the instance.
Format::
PERSISTED_TO <stream_name> <instance_name> <token>
On receipt the client should mark that the given instance has persisted
everything up to the given token. Note: this does *not* mean that other
instances have also persisted all their rows up to that point.
"""
NAME = "PERSISTED_TO"
def __init__(self, stream_name, instance_name, token):
self.stream_name = stream_name
self.instance_name = instance_name
self.token = token
@classmethod
def from_line(cls, line):
stream_name, instance_name, token = line.split(" ", 2)
return cls(stream_name, instance_name, int(token))
def to_line(self):
return " ".join((self.stream_name, self.instance_name, str(self.token)))
class ErrorCommand(_SimpleCommand):
"""Sent by either side if there was an ERROR. The data is a string describing
the error.
@@ -405,6 +436,7 @@ _COMMANDS = (
UserIpCommand,
RemoteServerUpCommand,
ClearUserSyncsCommand,
PersistedToCommand,
) # type: Tuple[Type[Command], ...]
# Map of command name to command type.
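As a quick sanity check on the PERSISTED_TO wire format documented above, a minimal sketch (the stream name, instance name and token value are made up; the NAME prefix is expected to be added by the protocol layer when the line is actually sent):

cmd = PersistedToCommand("events", "worker1", 2620)
cmd.to_line()                # "events worker1 2620"
parsed = PersistedToCommand.from_line("events worker1 2620")
(parsed.stream_name, parsed.instance_name, parsed.token)  # ("events", "worker1", 2620)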

View File

@@ -47,6 +47,7 @@ from synapse.replication.tcp.commands import (
ReplicateCommand,
UserIpCommand,
UserSyncCommand,
PersistedToCommand,
)
from synapse.replication.tcp.protocol import AbstractConnection
from synapse.replication.tcp.streams import (
@@ -387,6 +388,9 @@ class ReplicationCommandHandler:
assert self._server_notices_sender is not None
await self._server_notices_sender.on_user_ip(cmd.user_id)
def on_PERSISTED_TO(self, conn: AbstractConnection, cmd: PersistedToCommand):
pass
def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
if cmd.instance_name == self._instance_name:
# Ignore RDATA that are just our own echoes

View File

@@ -24,6 +24,7 @@ from twisted.internet.protocol import Factory
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
from synapse.replication.tcp.streams import EventsStream
from synapse.util.metrics import Measure
stream_updates_counter = Counter(
@@ -84,6 +85,9 @@ class ReplicationStreamer:
# Set of streams to replicate.
self.streams = self.command_handler.get_streams_to_replicate()
if self.streams:
self.clock.looping_call(self.on_notifier_poke, 1000.0)
def on_notifier_poke(self):
"""Checks if there is actually any new data and sends it to the
connections if there are.
@@ -126,9 +130,7 @@ class ReplicationStreamer:
random.shuffle(all_streams)
for stream in all_streams:
if stream.last_token == stream.current_token(
self._instance_name
):
if not stream.has_updates():
continue
if self._replication_torture_level:
@@ -174,6 +176,11 @@ class ReplicationStreamer:
except Exception:
logger.exception("Failed to replicate")
# for command in stream.extra_commands(
# sent_updates=bool(updates)
# ):
# self.command_handler.send_command(command)
logger.debug("No more pending updates, breaking poke loop")
finally:
self.pending_updates = False

View File

@@ -31,6 +31,7 @@ from typing import (
import attr
from synapse.replication.http.streams import ReplicationGetStreamUpdates
from synapse.replication.tcp.commands import Command
if TYPE_CHECKING:
import synapse.server
@@ -187,6 +188,12 @@ class Stream:
)
return updates, upto_token, limited
def has_updates(self) -> bool:
return self.current_token(self.local_instance_name) != self.last_token
def extra_commands(self, sent_updates: bool) -> List[Command]:
return []
def current_token_without_instance(
current_token: Callable[[], int]

View File

@@ -19,7 +19,8 @@ from typing import List, Tuple, Type
import attr
from ._base import Stream, StreamUpdateResult, Token
from synapse.replication.tcp.streams._base import Stream, StreamUpdateResult, Token
from synapse.replication.tcp.commands import Command, PersistedToCommand
"""Handling of the 'events' replication stream
@@ -222,3 +223,18 @@ class EventsStream(Stream):
(typ, data) = row
data = TypeToRow[typ].from_data(data)
return EventsStreamRow(typ, data)
def has_updates(self) -> bool:
return True
def extra_commands(self, sent_updates: bool) -> List[Command]:
if sent_updates:
return []
return [
PersistedToCommand(
self.NAME,
self.local_instance_name,
self._store._stream_id_gen.get_max_persisted_position_for_self(),
)
]

View File

@@ -178,6 +178,8 @@ class PersistEventsStore:
)
persist_event_counter.inc(len(events_and_contexts))
logger.debug("Finished persisting 1")
if not backfilled:
# backfilled events have negative stream orderings, so we don't
# want to set the event_persisted_position to that.
@@ -185,6 +187,8 @@ class PersistEventsStore:
events_and_contexts[-1][0].internal_metadata.stream_ordering
)
logger.debug("Finished persisting 2")
for event, context in events_and_contexts:
if context.app_service:
origin_type = "local"
@@ -198,6 +202,8 @@ class PersistEventsStore:
event_counter.labels(event.type, origin_type, origin_entity).inc()
logger.debug("Finished persisting 3")
for room_id, new_state in current_state_for_room.items():
self.store.get_current_state_ids.prefill((room_id,), new_state)
@@ -206,6 +212,9 @@ class PersistEventsStore:
(room_id,), list(latest_event_ids)
)
logger.debug("Finished persisting 4")
logger.debug("Finished persisting 5")
async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
"""Filter the supplied list of event_ids to get those which are prev_events of
existing (non-outlier/rejected) events.

View File

@@ -35,11 +35,10 @@ what sort order was used:
- topological tokens: "t%d-%d", where the integers map to the topological
and stream ordering columns respectively.
"""
import abc
import logging
from collections import namedtuple
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
from twisted.internet import defer
@@ -54,6 +53,7 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import Collection, RoomStreamToken
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -76,6 +76,18 @@ _EventDictReturn = namedtuple(
)
def _filter_result(
instance_name: str,
stream_id: int,
from_token: RoomStreamToken,
to_token: RoomStreamToken,
) -> bool:
from_id = from_token.instance_map.get(instance_name, from_token.stream)
to_id = to_token.instance_map.get(instance_name, to_token.stream)
return from_id < stream_id <= to_id
def generate_pagination_where_clause(
direction: str,
column_names: Tuple[str, str],
@@ -209,6 +221,71 @@ def _make_generic_sql_bound(
)
def _make_instance_filter_clause(
direction: str,
from_token: Optional[RoomStreamToken],
to_token: Optional[RoomStreamToken],
) -> Tuple[str, List[Any]]:
if from_token and from_token.topological:
from_token = None
if to_token and to_token.topological:
to_token = None
if not from_token and not to_token:
return "", []
from_bound = ">=" if direction == "b" else "<"
to_bound = "<" if direction == "b" else ">="
filter_clauses = []
filter_args = [] # type: List[Any]
from_map = from_token.instance_map if from_token else {}
to_map = to_token.instance_map if to_token else {}
default_from = from_token.stream if from_token else None
default_to = to_token.stream if to_token else None
if default_from and default_to:
filter_clauses.append(
"(? %s stream_ordering AND ? %s stream_ordering)" % (from_bound, to_bound)
)
filter_args.extend((default_from, default_to,))
elif default_from:
filter_clauses.append("(? %s stream_ordering)" % (from_bound,))
filter_args.extend((default_from,))
elif default_to:
filter_clauses.append("(? %s stream_ordering)" % (to_bound,))
filter_args.extend((default_to,))
for instance in set(from_map).union(to_map):
from_id = from_map.get(instance, default_from)
to_id = to_map.get(instance, default_to)
if from_id and to_id:
filter_clauses.append(
"(instance_name = ? AND ? %s stream_ordering AND ? %s stream_ordering)"
% (from_bound, to_bound)
)
filter_args.extend((instance, from_id, to_id,))
elif from_id:
filter_clauses.append(
"(instance_name = ? AND ? %s stream_ordering)" % (from_bound,)
)
filter_args.extend((instance, from_id,))
elif to_id:
filter_clauses.append(
"(instance_name = ? AND ? %s stream_ordering)" % (to_bound,)
)
filter_args.extend((instance, to_id,))
filter_clause = ""
if filter_clauses:
filter_clause = "(%s)" % (" OR ".join(filter_clauses),)
return filter_clause, filter_args
def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
# NB: This may create SQL clauses that don't optimise well (and we don't
# have indices on all possible clauses). E.g. it may create
@@ -305,6 +382,22 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
def get_room_min_stream_ordering(self) -> int:
raise NotImplementedError()
def get_room_max_token(self) -> RoomStreamToken:
min_pos = self._stream_id_gen.get_current_token()
positions = {}
if isinstance(self._stream_id_gen, MultiWriterIdGenerator):
positions = {
i: p
for i, p in self._stream_id_gen.get_positions().items()
if p >= min_pos
}
if set(positions.values()) == {min_pos}:
positions = {}
return RoomStreamToken(None, min_pos, positions)
async def get_room_events_stream_for_rooms(
self,
room_ids: Collection[str],
@@ -402,25 +495,50 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
if from_key == to_key:
return [], from_key
from_id = from_key.stream
to_id = to_key.stream
has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
has_changed = self._events_stream_cache.has_entity_changed(
room_id, from_key.stream
)
if not has_changed:
return [], from_key
def f(txn):
sql = (
"SELECT event_id, stream_ordering FROM events WHERE"
" room_id = ?"
" AND not outlier"
" AND stream_ordering > ? AND stream_ordering <= ?"
" ORDER BY stream_ordering %s LIMIT ?"
) % (order,)
txn.execute(sql, (room_id, from_id, to_id, limit))
filter_clause, filter_args = _make_instance_filter_clause(
"f", from_key, to_key
)
if filter_clause:
filter_clause = " AND " + filter_clause
rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
min_from_id = min(from_key.instance_map.values(), default=from_key.stream)
max_to_id = max(to_key.instance_map.values(), default=to_key.stream)
sql = """
SELECT event_id, instance_name, stream_ordering
FROM events
WHERE
room_id = ?
AND not outlier
AND stream_ordering > ? AND stream_ordering <= ?
%s
ORDER BY stream_ordering %s LIMIT ?
""" % (
filter_clause,
order,
)
args = [room_id, min_from_id, max_to_id]
args.extend(filter_args)
args.append(limit)
txn.execute(sql, args)
# rows = [
# _EventDictReturn(event_id, None, stream_ordering)
# for event_id, instance_name, stream_ordering in txn
# if _filter_result(instance_name, stream_ordering, from_key, to_key)
# ]
rows = [
_EventDictReturn(event_id, None, stream_ordering)
for event_id, instance_name, stream_ordering in txn
]
return rows
rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f)
@@ -429,7 +547,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
[r.event_id for r in rows], get_prev_content=True
)
self._set_before_and_after(ret, rows, topo_order=from_id is None)
self._set_before_and_after(ret, rows, topo_order=from_key.stream is None)
if order.lower() == "desc":
ret.reverse()
@@ -446,29 +564,40 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
async def get_membership_changes_for_user(
self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
) -> List[EventBase]:
from_id = from_key.stream
to_id = to_key.stream
if from_key == to_key:
return []
if from_id:
if from_key:
has_changed = self._membership_stream_cache.has_entity_changed(
user_id, int(from_id)
user_id, int(from_key.stream)
)
if not has_changed:
return []
def f(txn):
sql = (
"SELECT m.event_id, stream_ordering FROM events AS e,"
" room_memberships AS m"
" WHERE e.event_id = m.event_id"
" AND m.user_id = ?"
" AND e.stream_ordering > ? AND e.stream_ordering <= ?"
" ORDER BY e.stream_ordering ASC"
filter_clause, filter_args = _make_instance_filter_clause(
"f", from_key, to_key
)
txn.execute(sql, (user_id, from_id, to_id))
if filter_clause:
filter_clause = " AND " + filter_clause
min_from_id = min(from_key.instance_map.values(), default=from_key.stream)
max_to_id = max(to_key.instance_map.values(), default=to_key.stream)
sql = """
SELECT m.event_id, stream_ordering
FROM events AS e, room_memberships AS m
WHERE e.event_id = m.event_id
AND m.user_id = ?
AND e.stream_ordering > ? AND e.stream_ordering <= ?
%s
ORDER BY e.stream_ordering ASC
""" % (
filter_clause,
)
args = [user_id, min_from_id, max_to_id]
args.extend(filter_args)
txn.execute(sql, args)
rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
@@ -975,11 +1104,39 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
else:
order = "ASC"
if from_token.topological is not None:
from_bound = from_token.as_tuple()
elif direction == "b":
from_bound = (
None,
max(from_token.instance_map.values(), default=from_token.stream),
)
else:
from_bound = (
None,
min(from_token.instance_map.values(), default=from_token.stream),
)
to_bound = None
if to_token:
if to_token.topological is not None:
to_bound = to_token.as_tuple()
elif direction == "b":
to_bound = (
None,
min(to_token.instance_map.values(), default=to_token.stream),
)
else:
to_bound = (
None,
max(to_token.instance_map.values(), default=to_token.stream),
)
bounds = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
from_token=from_token.as_tuple(),
to_token=to_token.as_tuple() if to_token else None,
from_token=from_bound,
to_token=to_bound,
engine=self.database_engine,
)
@@ -989,6 +1146,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
bounds += " AND " + filter_clause
args.extend(filter_args)
stream_filter_clause, stream_filter_args = _make_instance_filter_clause(
direction, from_token, to_token
)
if stream_filter_clause:
bounds += " AND " + stream_filter_clause
args.extend(stream_filter_args)
args.append(int(limit))
select_keywords = "SELECT"

View File

@@ -229,7 +229,7 @@ class EventsPersistenceStorage:
defer.gatherResults(deferreds, consumeErrors=True)
)
return RoomStreamToken(None, self.main_store.get_current_events_token())
return self.main_store.get_room_max_token()
async def persist_event(
self, event: EventBase, context: EventContext, backfilled: bool = False
@@ -247,11 +247,10 @@ class EventsPersistenceStorage:
await make_deferred_yieldable(deferred)
max_persisted_id = self.main_store.get_current_events_token()
event_stream_id = event.internal_metadata.stream_ordering
pos = PersistedEventPosition(self._instance_name, event_stream_id)
return pos, RoomStreamToken(None, max_persisted_id)
return pos, self.main_store.get_room_max_token()
def _maybe_start_persisting(self, room_id: str):
async def persisting_queue(item):

View File

@@ -217,6 +217,7 @@ class MultiWriterIdGenerator:
self._instance_name = instance_name
self._positive = positive
self._writers = writers
self._sequence_name = sequence_name
self._return_factor = 1 if positive else -1
# We lock as some functions may be called from DB threads.
@@ -227,6 +228,8 @@ class MultiWriterIdGenerator:
# return them.
self._current_positions = {} # type: Dict[str, int]
self._max_persisted_positions = dict(self._current_positions)
# Set of local IDs that we're still processing. The current position
# should be less than the minimum of this set (if not empty).
self._unfinished_ids = set() # type: Set[int]
@@ -404,6 +407,12 @@ class MultiWriterIdGenerator:
current position if possible.
"""
logger.debug(
"Mark as finished 1 _current_positions %s: %s",
self._sequence_name,
self._current_positions,
)
with self._lock:
self._unfinished_ids.discard(next_id)
self._finished_ids.add(next_id)
@@ -439,6 +448,16 @@ class MultiWriterIdGenerator:
if new_cur:
curr = self._current_positions.get(self._instance_name, 0)
self._current_positions[self._instance_name] = max(curr, new_cur)
self._max_persisted_positions[self._instance_name] = max(
self._current_positions[self._instance_name],
self._max_persisted_positions.get(self._instance_name, 0),
)
logger.debug(
"Mark as finished _current_positions %s: %s",
self._sequence_name,
self._current_positions,
)
self._add_persisted_position(next_id)
@@ -454,6 +473,11 @@ class MultiWriterIdGenerator:
"""
with self._lock:
logger.debug(
"get_current_token_for_writer %s: %s",
self._sequence_name,
self._current_positions,
)
return self._return_factor * self._current_positions.get(instance_name, 0)
def get_positions(self) -> Dict[str, int]:
@@ -478,6 +502,12 @@ class MultiWriterIdGenerator:
new_id, self._current_positions.get(instance_name, 0)
)
self._max_persisted_positions[instance_name] = max(
new_id,
self._current_positions.get(instance_name, 0),
self._max_persisted_positions.get(instance_name, 0),
)
self._add_persisted_position(new_id)
def get_persisted_upto_position(self) -> int:
@@ -492,10 +522,29 @@ class MultiWriterIdGenerator:
with self._lock:
return self._return_factor * self._persisted_upto_position
def get_max_persisted_position_for_self(self) -> int:
with self._lock:
if self._unfinished_ids:
return self.get_current_token_for_writer(self._instance_name)
return self._return_factor * max(
self._current_positions.values(), default=1
)
def advance_persisted_to(self, instance_name: str, new_id: int):
new_id *= self._return_factor
with self._lock:
self._max_persisted_positions[instance_name] = max(
new_id,
self._current_positions.get(instance_name, 0),
self._max_persisted_positions.get(instance_name, 0),
)
def _add_persisted_position(self, new_id: int):
"""Record that we have persisted a position.
This is used to keep the `_current_positions` up to date.
This is used to keep the `_persisted_upto_position` up to date.
"""
# We require that the lock is locked by caller
@@ -506,7 +555,7 @@ class MultiWriterIdGenerator:
# We move the current min position up if the minimum current positions
# of all instances is higher (since by definition all positions less
# than that have been persisted).
min_curr = min(self._current_positions.values(), default=0)
min_curr = min(self._max_persisted_positions.values(), default=0)
self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
# We now iterate through the seen positions, discarding those that are
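As a hypothetical illustration (plain dicts rather than the real class, which needs a database connection) of the bookkeeping changed above: advance_persisted_to raises a writer's entry in _max_persisted_positions, and _add_persisted_position now takes the minimum over those entries instead of over _current_positions.

current_positions = {"master": 10, "worker1": 8}
max_persisted_positions = dict(current_positions)

# A PERSISTED_TO from worker1 claims everything up to 12 has been persisted:
new_id = 12
max_persisted_positions["worker1"] = max(
    new_id,
    current_positions.get("worker1", 0),
    max_persisted_positions.get("worker1", 0),
)

# The persisted-up-to position can now advance past worker1's own current position:
persisted_upto_position = min(max_persisted_positions.values(), default=0)  # 10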

View File

@@ -21,8 +21,9 @@ from collections import namedtuple
from typing import Any, Dict, Mapping, MutableMapping, Optional, Tuple, Type, TypeVar
import attr
import cbor2
from signedjson.key import decode_verify_key_bytes
from unpaddedbase64 import decode_base64
from unpaddedbase64 import decode_base64, encode_base64
from synapse.api.errors import Codes, SynapseError
@@ -362,7 +363,7 @@ def map_username_to_mxid_localpart(username, case_sensitive=False):
return username.decode("ascii")
@attr.s(frozen=True, slots=True)
@attr.s(frozen=True, slots=True, cmp=False)
class RoomStreamToken:
"""Tokens are positions between events. The token "s1" comes after event 1.
@@ -392,6 +393,8 @@ class RoomStreamToken:
)
stream = attr.ib(type=int, validator=attr.validators.instance_of(int))
instance_map = attr.ib(type=Dict[str, int], factory=dict)
@classmethod
def parse(cls, string: str) -> "RoomStreamToken":
try:
@@ -400,6 +403,11 @@ class RoomStreamToken:
if string[0] == "t":
parts = string[1:].split("-", 1)
return cls(topological=int(parts[0]), stream=int(parts[1]))
if string[0] == "m":
payload = cbor2.loads(decode_base64(string[1:]))
return cls(
topological=None, stream=payload["s"], instance_map=payload["p"],
)
except Exception:
pass
raise SynapseError(400, "Invalid token %r" % (string,))
@@ -413,15 +421,49 @@ class RoomStreamToken:
pass
raise SynapseError(400, "Invalid token %r" % (string,))
def copy_and_advance(self, other: "RoomStreamToken") -> "RoomStreamToken":
if self.topological or other.topological:
raise Exception("Can't advance topological tokens")
max_stream = max(self.stream, other.stream)
instance_map = {
instance: max(
self.instance_map.get(instance, self.stream),
other.instance_map.get(instance, other.stream),
)
for instance in set(self.instance_map).union(other.instance_map)
}
return RoomStreamToken(None, max_stream, instance_map)
def as_tuple(self) -> Tuple[Optional[int], int]:
return (self.topological, self.stream)
def __str__(self) -> str:
if self.topological is not None:
return "t%d-%d" % (self.topological, self.stream)
elif self.instance_map:
return "m" + encode_base64(
cbor2.dumps({"s": self.stream, "p": self.instance_map}),
)
else:
return "s%d" % (self.stream,)
def __lt__(self, other: "RoomStreamToken"):
if self.stream != other.stream:
return self.stream < other.stream
for instance in set(self.instance_map).union(other.instance_map):
if self.instance_map.get(instance, self.stream) != other.instance_map.get(
instance, other.stream
):
return self.instance_map.get(
instance, self.stream
) < other.instance_map.get(instance, other.stream)
return False
@attr.s(slots=True, frozen=True)
class StreamToken:
@@ -461,7 +503,7 @@ class StreamToken:
def is_after(self, other):
"""Does this token contain events that the other doesn't?"""
return (
(other.room_stream_id < self.room_stream_id)
(other.room_key < self.room_key)
or (int(other.presence_key) < int(self.presence_key))
or (int(other.typing_key) < int(self.typing_key))
or (int(other.receipt_key) < int(self.receipt_key))
@@ -476,13 +518,16 @@ class StreamToken:
"""Advance the given key in the token to a new value if and only if the
new value is after the old value.
"""
new_token = self.copy_and_replace(key, new_value)
if key == "room_key":
new_id = new_token.room_stream_id
old_id = self.room_stream_id
else:
new_id = int(getattr(new_token, key))
old_id = int(getattr(self, key))
new_token = self.copy_and_replace(
"room_key", self.room_key.copy_and_advance(new_value)
)
return new_token
new_token = self.copy_and_replace(key, new_value)
new_id = int(getattr(new_token, key))
old_id = int(getattr(self, key))
if old_id < new_id:
return new_token
else:
@@ -507,7 +552,7 @@ class PersistedEventPosition:
stream = attr.ib(type=int)
def persisted_after(self, token: RoomStreamToken) -> bool:
return token.stream < self.stream
return token.instance_map.get(self.instance_name, token.stream) < self.stream
class ThirdPartyInstanceID(
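A minimal sketch of the new RoomStreamToken behaviour added in this file (the positions are made up; with cmp=False the fields are compared explicitly rather than via ==):

from synapse.types import RoomStreamToken

token = RoomStreamToken(None, 5, {"worker1": 7})

# Tokens carrying an instance_map serialise to the new "m..." form,
# i.e. "m" + unpadded base64 of cbor2.dumps({"s": 5, "p": {"worker1": 7}}):
serialised = str(token)
parsed = RoomStreamToken.parse(serialised)
(parsed.stream, parsed.instance_map)  # (5, {"worker1": 7})

# copy_and_advance keeps the maximum default stream and per-writer positions:
merged = RoomStreamToken(None, 6).copy_and_advance(token)
(merged.stream, merged.instance_map)  # (6, {"worker1": 7})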