Compare commits

...

4 Commits

Author         SHA1        Message                                      Date
Erik Johnston  d9a29506b6  Add EventStreamPosition type                 2020-09-10 16:10:01 +01:00
Erik Johnston  e92ad193be  Newsfile                                     2020-09-10 13:54:21 +01:00
Erik Johnston  2352b0522e  Make StreamToken.room_key use proper type.   2020-09-10 13:54:20 +01:00
Erik Johnston  86dfefdf8c  Add type hints for persist events            2020-09-10 13:42:13 +01:00
18 changed files with 184 additions and 132 deletions

changelog.d/8281.misc (new file)

@@ -0,0 +1 @@
+Change `StreamToken.room_key` to be a `RoomStreamToken` instance.
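
For context: after this change, `room_key` is carried as a structured value in memory rather than a raw "s<stream>" / "t<topological>-<stream>" string that every call site had to re-parse. A minimal standalone sketch of that token shape, for illustration only (Synapse's real RoomStreamToken lives in synapse/types.py):

from typing import Optional

import attr


@attr.s(slots=True, frozen=True)
class RoomStreamToken:
    # "t<topological>-<stream>" tokens carry a topological part (used for
    # pagination); "s<stream>" tokens carry only the stream ordering.
    topological = attr.ib(type=Optional[int])
    stream = attr.ib(type=int)

    @classmethod
    def parse(cls, string: str) -> "RoomStreamToken":
        if string.startswith("s"):
            return cls(topological=None, stream=int(string[1:]))
        if string.startswith("t"):
            topo, _, stream = string[1:].partition("-")
            return cls(topological=int(topo), stream=int(stream))
        raise ValueError("invalid room stream token %r" % (string,))

    def __str__(self) -> str:
        if self.topological is not None:
            return "t%d-%d" % (self.topological, self.stream)
        return "s%d" % (self.stream,)


# Round trip: the string form is unchanged, only the in-memory type is richer.
assert str(RoomStreamToken.parse("s2633508")) == "s2633508"
assert RoomStreamToken.parse("t426-2633508").stream == 2633508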

mypy.ini

@@ -46,10 +46,12 @@ files =
   synapse/server_notices,
   synapse/spam_checker_api,
   synapse/state,
+  synapse/storage/databases/main/events.py,
   synapse/storage/databases/main/stream.py,
   synapse/storage/databases/main/ui_auth.py,
   synapse/storage/database.py,
   synapse/storage/engines,
+  synapse/storage/persist_events.py,
   synapse/storage/state.py,
   synapse/storage/util,
   synapse/streams,

synapse/handlers/admin.py

@@ -125,8 +125,8 @@ class AdminHandler(BaseHandler):
             else:
                 stream_ordering = room.stream_ordering

-            from_key = str(RoomStreamToken(0, 0))
-            to_key = str(RoomStreamToken(None, stream_ordering))
+            from_key = RoomStreamToken(0, 0)
+            to_key = RoomStreamToken(None, stream_ordering)

             written_events = set()  # Events that we've processed in this room
@@ -153,7 +153,7 @@ class AdminHandler(BaseHandler):
                 if not events:
                     break

-                from_key = events[-1].internal_metadata.after
+                from_key = RoomStreamToken.parse(events[-1].internal_metadata.after)

                 events = await filter_events_for_client(self.storage, user_id, events)

synapse/handlers/device.py

@@ -29,6 +29,7 @@ from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.types import (
     RoomStreamToken,
+    StreamToken,
     get_domain_from_id,
     get_verify_key_from_cross_signing_key,
 )
@@ -104,7 +105,7 @@ class DeviceWorkerHandler(BaseHandler):
     @trace
     @measure_func("device.get_user_ids_changed")
-    async def get_user_ids_changed(self, user_id, from_token):
+    async def get_user_ids_changed(self, user_id: str, from_token: StreamToken):
         """Get list of users that have had the devices updated, or have newly
         joined a room, that `user_id` may be interested in.
@@ -115,7 +116,8 @@ class DeviceWorkerHandler(BaseHandler):
         set_tag("user_id", user_id)
         set_tag("from_token", from_token)
-        now_room_key = await self.store.get_room_events_max_id()
+        now_room_id = self.store.get_room_max_stream_ordering()
+        now_room_key = RoomStreamToken(None, now_room_id)

         room_ids = await self.store.get_rooms_for_user(user_id)
@@ -142,7 +144,7 @@ class DeviceWorkerHandler(BaseHandler):
             )
             rooms_changed.update(event.room_id for event in member_events)

-        stream_ordering = RoomStreamToken.parse_stream_token(from_token.room_key).stream
+        stream_ordering = from_token.room_key.stream

         possibly_changed = set(changed)
         possibly_left = set()
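
This is the pattern repeated across the handlers in this PR: where code previously parsed the string `room_key` at each use site, it now reads the `.stream` attribute directly. A hypothetical before/after, reusing the simplified token sketch from the changelog section above:

# Before: room_key was a string, parsed at every use site.
stream_ordering = RoomStreamToken.parse("s2633508").stream

# After: room_key is already a RoomStreamToken; just read the attribute.
from_room_key = RoomStreamToken(None, 2633508)
assert from_room_key.stream == stream_ordering == 2633508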

synapse/handlers/federation.py

@@ -74,6 +74,8 @@ from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.types import (
     JsonDict,
     MutableStateMap,
+    PersistedEventPosition,
+    RoomStreamToken,
     StateMap,
     UserID,
     get_domain_from_id,
@@ -2891,7 +2893,7 @@ class FederationHandler(BaseHandler):
             )
             return result["max_stream_id"]
         else:
-            max_stream_id = await self.storage.persistence.persist_events(
+            max_stream_token = await self.storage.persistence.persist_events(
                 event_and_contexts, backfilled=backfilled
             )
@@ -2902,12 +2904,12 @@ class FederationHandler(BaseHandler):
             if not backfilled:  # Never notify for backfilled events
                 for event, _ in event_and_contexts:
-                    await self._notify_persisted_event(event, max_stream_id)
+                    await self._notify_persisted_event(event, max_stream_token)

-        return max_stream_id
+        return max_stream_token.stream

     async def _notify_persisted_event(
-        self, event: EventBase, max_stream_id: int
+        self, event: EventBase, max_stream_token: RoomStreamToken
     ) -> None:
         """Checks to see if notifier/pushers should be notified about the
         event or not.
@@ -2933,9 +2935,11 @@ class FederationHandler(BaseHandler):
         elif event.internal_metadata.is_outlier():
             return

-        event_stream_id = event.internal_metadata.stream_ordering
+        event_pos = PersistedEventPosition(
+            self._instance_name, event.internal_metadata.stream_ordering
+        )
         self.notifier.on_new_room_event(
-            event, event_stream_id, max_stream_id, extra_users=extra_users
+            event, event_pos, max_stream_token, extra_users=extra_users
         )

     async def _clean_room_for_join(self, room_id: str) -> None:

synapse/handlers/initial_sync.py

@@ -25,7 +25,7 @@ from synapse.handlers.presence import format_user_presence_state
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.storage.roommember import RoomsForUser
 from synapse.streams.config import PaginationConfig
-from synapse.types import JsonDict, Requester, StreamToken, UserID
+from synapse.types import JsonDict, Requester, RoomStreamToken, StreamToken, UserID
 from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import concurrently_execute
 from synapse.util.caches.response_cache import ResponseCache
@@ -167,7 +167,7 @@ class InitialSyncHandler(BaseHandler):
                     self.state_handler.get_current_state, event.room_id
                 )
             elif event.membership == Membership.LEAVE:
-                room_end_token = "s%d" % (event.stream_ordering,)
+                room_end_token = RoomStreamToken(None, event.stream_ordering,)
                 deferred_room_state = run_in_background(
                     self.state_store.get_state_for_events, [event.event_id]
                 )

synapse/handlers/message.py

@@ -973,6 +973,7 @@ class EventCreationHandler:
         This should only be run on the instance in charge of persisting events.
         """
         assert self._is_event_writer
+        assert self.storage.persistence is not None

         if ratelimit:
             # We check if this is a room admin redacting an event so that we
@@ -1135,7 +1136,7 @@ class EventCreationHandler:
             if prev_state_ids:
                 raise AuthError(403, "Changing the room create event is forbidden")

-        event_stream_id, max_stream_id = await self.storage.persistence.persist_event(
+        event_pos, max_stream_token = await self.storage.persistence.persist_event(
             event, context=context
         )
@@ -1146,7 +1147,7 @@ class EventCreationHandler:
         def _notify():
             try:
                 self.notifier.on_new_room_event(
-                    event, event_stream_id, max_stream_id, extra_users=extra_users
+                    event, event_pos, max_stream_token, extra_users=extra_users
                 )
             except Exception:
                 logger.exception("Error notifying about new room event")
@@ -1158,7 +1159,7 @@ class EventCreationHandler:
             # matters as sometimes presence code can take a while.
             run_in_background(self._bump_active_time, requester.user)

-        return event_stream_id
+        return event_pos.stream

     async def _bump_active_time(self, user: UserID) -> None:
         try:

synapse/handlers/pagination.py

@@ -344,7 +344,7 @@ class PaginationHandler:
             # gets called.
             raise Exception("limit not set")

-        room_token = RoomStreamToken.parse(from_token.room_key)
+        room_token = from_token.room_key

        with await self.pagination_lock.read(room_id):
            (
@@ -381,7 +381,7 @@ class PaginationHandler:
                    if leave_token.topological < max_topo:
                        from_token = from_token.copy_and_replace(
-                            "room_key", leave_token_str
+                            "room_key", leave_token
                        )

            await self.hs.get_handlers().federation_handler.maybe_backfill(

synapse/handlers/room.py

@@ -1091,20 +1091,19 @@ class RoomEventSource:
     async def get_new_events(
         self,
         user: UserID,
-        from_key: str,
+        from_key: RoomStreamToken,
         limit: int,
         room_ids: List[str],
         is_guest: bool,
         explicit_room_id: Optional[str] = None,
-    ) -> Tuple[List[EventBase], str]:
+    ) -> Tuple[List[EventBase], RoomStreamToken]:
         # We just ignore the key for now.

         to_key = self.get_current_key()

-        from_token = RoomStreamToken.parse(from_key)
-        if from_token.topological:
+        if from_key.topological:
             logger.warning("Stream has topological part!!!! %r", from_key)
-            from_key = "s%s" % (from_token.stream,)
+            from_key = RoomStreamToken(None, from_key.stream)

         app_service = self.store.get_app_service_by_user_id(user.to_string())
         if app_service:
@@ -1133,14 +1132,14 @@ class RoomEventSource:
             events[:] = events[:limit]

             if events:
-                end_key = events[-1].internal_metadata.after
+                end_key = RoomStreamToken.parse(events[-1].internal_metadata.after)
             else:
                 end_key = to_key

         return (events, end_key)

-    def get_current_key(self) -> str:
-        return "s%d" % (self.store.get_room_max_stream_ordering(),)
+    def get_current_key(self) -> RoomStreamToken:
+        return RoomStreamToken(None, self.store.get_room_max_stream_ordering())

     def get_current_key_for_room(self, room_id: str) -> Awaitable[str]:
         return self.store.get_room_events_max_id(room_id)
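
The guard in `get_new_events` flattens any topological part before live streaming, since live streams are ordered purely by stream ordering. That normalization in isolation, again reusing the token sketch above:

def normalize_for_stream(from_key: RoomStreamToken) -> RoomStreamToken:
    # Tokens that came from pagination may carry a topological part; a live
    # stream only understands the stream ordering, so drop the rest.
    if from_key.topological:
        return RoomStreamToken(None, from_key.stream)
    return from_key


assert normalize_for_stream(RoomStreamToken.parse("t426-2633508")) == RoomStreamToken(None, 2633508)
assert normalize_for_stream(RoomStreamToken.parse("s99")) == RoomStreamToken(None, 99)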

synapse/handlers/sync.py

@@ -378,7 +378,7 @@ class SyncHandler:
         sync_config = sync_result_builder.sync_config

         with Measure(self.clock, "ephemeral_by_room"):
-            typing_key = since_token.typing_key if since_token else "0"
+            typing_key = since_token.typing_key if since_token else 0

             room_ids = sync_result_builder.joined_room_ids
@@ -402,7 +402,7 @@ class SyncHandler:
                     event_copy = {k: v for (k, v) in event.items() if k != "room_id"}
                     ephemeral_by_room.setdefault(room_id, []).append(event_copy)

-            receipt_key = since_token.receipt_key if since_token else "0"
+            receipt_key = since_token.receipt_key if since_token else 0

             receipt_source = self.event_sources.sources["receipt"]
             receipts, receipt_key = await receipt_source.get_new_events(
@@ -533,7 +533,7 @@ class SyncHandler:
                 if len(recents) > timeline_limit:
                     limited = True
                     recents = recents[-timeline_limit:]
-                    room_key = recents[0].internal_metadata.before
+                    room_key = RoomStreamToken.parse(recents[0].internal_metadata.before)

             prev_batch_token = now_token.copy_and_replace("room_key", room_key)
@@ -1322,6 +1322,7 @@ class SyncHandler:
             is_guest=sync_config.is_guest,
             include_offline=include_offline,
         )
+        assert presence_key
         sync_result_builder.now_token = now_token.copy_and_replace(
             "presence_key", presence_key
         )
@@ -1484,7 +1485,7 @@ class SyncHandler:
             if rooms_changed:
                 return True

-            stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream
+            stream_id = since_token.room_key.stream
             for room_id in sync_result_builder.joined_room_ids:
                 if self.store.has_room_changed_since(room_id, stream_id):
                     return True
@@ -1750,7 +1751,7 @@ class SyncHandler:
                     continue

                 leave_token = now_token.copy_and_replace(
-                    "room_key", "s%d" % (event.stream_ordering,)
+                    "room_key", RoomStreamToken(None, event.stream_ordering)
                 )
                 room_entries.append(
                     RoomSyncResultBuilder(

synapse/notifier.py

@@ -25,6 +25,7 @@ from typing import (
     Set,
     Tuple,
     TypeVar,
+    Union,
 )

 from prometheus_client import Counter
@@ -41,7 +42,13 @@ from synapse.logging.utils import log_function
 from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.streams.config import PaginationConfig
-from synapse.types import Collection, StreamToken, UserID
+from synapse.types import (
+    Collection,
+    PersistedEventPosition,
+    RoomStreamToken,
+    StreamToken,
+    UserID,
+)
 from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
 from synapse.util.metrics import Measure
 from synapse.visibility import filter_events_for_client
@@ -111,7 +118,9 @@ class _NotifierUserStream:
         with PreserveLoggingContext():
             self.notify_deferred = ObservableDeferred(defer.Deferred())

-    def notify(self, stream_key: str, stream_id: int, time_now_ms: int):
+    def notify(
+        self, stream_key: str, stream_id: Union[int, RoomStreamToken], time_now_ms: int,
+    ):
         """Notify any listeners for this user of a new event from an
         event source.
         Args:
@@ -186,7 +195,7 @@ class Notifier:
         self.store = hs.get_datastore()
         self.pending_new_room_events = (
             []
-        )  # type: List[Tuple[int, EventBase, Collection[UserID]]]
+        )  # type: List[Tuple[PersistedEventPosition, EventBase, Collection[UserID]]]

         # Called when there are new things to stream over replication
         self.replication_callbacks = []  # type: List[Callable[[], None]]
@@ -245,8 +254,8 @@ class Notifier:
     def on_new_room_event(
         self,
         event: EventBase,
-        room_stream_id: int,
-        max_room_stream_id: int,
+        event_pos: PersistedEventPosition,
+        max_room_stream_token: RoomStreamToken,
         extra_users: Collection[UserID] = [],
     ):
         """ Used by handlers to inform the notifier something has happened
@@ -260,16 +269,16 @@ class Notifier:
         until all previous events have been persisted before notifying
         the client streams.
         """
-        self.pending_new_room_events.append((room_stream_id, event, extra_users))
-        self._notify_pending_new_room_events(max_room_stream_id)
+        self.pending_new_room_events.append((event_pos, event, extra_users))
+        self._notify_pending_new_room_events(max_room_stream_token)

         self.notify_replication()

-    def _notify_pending_new_room_events(self, max_room_stream_id: int):
+    def _notify_pending_new_room_events(self, max_room_stream_token: RoomStreamToken):
         """Notify for the room events that were queued waiting for a previous
         event to be persisted.
         Args:
-            max_room_stream_id: The highest stream_id below which all
+            max_room_stream_token: The highest stream_id below which all
                 events have been persisted.
         """
         pending = self.pending_new_room_events
@@ -278,11 +287,9 @@ class Notifier:
         users = set()  # type: Set[UserID]
         rooms = set()  # type: Set[str]

-        for room_stream_id, event, extra_users in pending:
-            if room_stream_id > max_room_stream_id:
-                self.pending_new_room_events.append(
-                    (room_stream_id, event, extra_users)
-                )
+        for event_pos, event, extra_users in pending:
+            if event_pos.persisted_after(max_room_stream_token):
+                self.pending_new_room_events.append((event_pos, event, extra_users))
             else:
                 if (
                     event.type == EventTypes.Member
@@ -294,42 +301,46 @@ class Notifier:
                     rooms.add(event.room_id)

         if users or rooms:
-            self.on_new_event("room_key", max_room_stream_id, users=users, rooms=rooms)
-            self._on_updated_room_token(max_room_stream_id)
+            self.on_new_event(
+                "room_key", max_room_stream_token, users=users, rooms=rooms,
+            )
+            self._on_updated_room_token(max_room_stream_token)

-    def _on_updated_room_token(self, max_room_stream_id: int):
+    def _on_updated_room_token(self, max_room_stream_token: RoomStreamToken):
         """Poke services that might care that the room position has been
         updated.
         """

         # poke any interested application service.
         run_as_background_process(
-            "_notify_app_services", self._notify_app_services, max_room_stream_id
+            "_notify_app_services", self._notify_app_services, max_room_stream_token
         )

         run_as_background_process(
-            "_notify_pusher_pool", self._notify_pusher_pool, max_room_stream_id
+            "_notify_pusher_pool", self._notify_pusher_pool, max_room_stream_token
         )

         if self.federation_sender:
-            self.federation_sender.notify_new_events(max_room_stream_id)
+            self.federation_sender.notify_new_events(max_room_stream_token.stream)

-    async def _notify_app_services(self, max_room_stream_id: int):
+    async def _notify_app_services(self, max_room_stream_token: RoomStreamToken):
         try:
-            await self.appservice_handler.notify_interested_services(max_room_stream_id)
+            await self.appservice_handler.notify_interested_services(
+                max_room_stream_token.stream
+            )
         except Exception:
             logger.exception("Error notifying application services of event")

-    async def _notify_pusher_pool(self, max_room_stream_id: int):
+    async def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken):
         try:
-            await self._pusher_pool.on_new_notifications(max_room_stream_id)
+            await self._pusher_pool.on_new_notifications(max_room_stream_token.stream)
         except Exception:
             logger.exception("Error pusher pool of event")

     def on_new_event(
         self,
         stream_key: str,
-        new_token: int,
+        new_token: Union[int, RoomStreamToken],
         users: Collection[UserID] = [],
         rooms: Collection[str] = [],
     ):
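
`on_new_event` now accepts `Union[int, RoomStreamToken]` because only the `room_key` stream carries the structured token; presence, typing, receipts and the rest still pass plain integers. A toy dispatcher showing the shape of that contract (`current_stream_id` is a hypothetical helper, not Synapse code):

from typing import Union


def current_stream_id(stream_key: str, new_token: Union[int, RoomStreamToken]) -> int:
    # Hypothetical: normalize whatever a stream hands us to a plain integer.
    if isinstance(new_token, RoomStreamToken):
        assert stream_key == "room_key"  # only room_key uses structured tokens
        return new_token.stream
    return new_token


assert current_stream_id("room_key", RoomStreamToken(None, 7)) == 7
assert current_stream_id("typing_key", 3) == 3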

synapse/replication/tcp/client.py

@@ -29,7 +29,7 @@ from synapse.replication.tcp.streams.events import (
     EventsStreamEventRow,
     EventsStreamRow,
 )
-from synapse.types import UserID
+from synapse.types import PersistedEventPosition, RoomStreamToken, UserID
 from synapse.util.async_helpers import timeout_deferred
 from synapse.util.metrics import Measure
@@ -151,8 +151,14 @@ class ReplicationDataHandler:
                 extra_users = ()  # type: Tuple[UserID, ...]
                 if event.type == EventTypes.Member:
                     extra_users = (UserID.from_string(event.state_key),)
-                max_token = self.store.get_room_max_stream_ordering()
-                self.notifier.on_new_room_event(event, token, max_token, extra_users)
+
+                max_token = RoomStreamToken(
+                    None, self.store.get_room_max_stream_ordering()
+                )
+                event_pos = PersistedEventPosition(instance_name, token)
+                self.notifier.on_new_room_event(
+                    event, event_pos, max_token, extra_users
+                )

         # Notify any waiting deferreds. The list is ordered by position so we
         # just iterate through the list until we reach a position that is

synapse/storage/__init__.py

@@ -47,6 +47,9 @@ class Storage:
         # interfaces.
         self.main = stores.main

-        self.persistence = EventsPersistenceStorage(hs, stores)
         self.purge_events = PurgeEventsStorage(hs, stores)
         self.state = StateGroupStorage(hs, stores)
+
+        self.persistence = None
+        if stores.persist_events:
+            self.persistence = EventsPersistenceStorage(hs, stores)
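
`Storage.persistence` is now `None` on workers that cannot persist events, which is why call sites elsewhere in this PR gained `assert ... is not None` checks. The shape of that pattern with stand-in classes (a sketch, not the real constructor signature):

from typing import Optional


class EventsPersistenceStorage:
    pass


class Storage:
    def __init__(self, can_persist_events: bool):
        # None on workers without an events persister; callers must narrow
        # the Optional (assert or branch) before use.
        self.persistence = None  # type: Optional[EventsPersistenceStorage]
        if can_persist_events:
            self.persistence = EventsPersistenceStorage()


storage = Storage(can_persist_events=True)
assert storage.persistence is not None  # narrow the Optional before using it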

synapse/storage/databases/main/events.py

@@ -213,7 +213,7 @@ class PersistEventsStore:
         Returns:
             Filtered event ids
         """
-        results = []
+        results = []  # type: List[str]

         def _get_events_which_are_prevs_txn(txn, batch):
             sql = """
@@ -631,7 +631,9 @@ class PersistEventsStore:
         )

     @classmethod
-    def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts):
+    def _filter_events_and_contexts_for_duplicates(
+        cls, events_and_contexts: List[Tuple[EventBase, EventContext]]
+    ) -> List[Tuple[EventBase, EventContext]]:
         """Ensure that we don't have the same event twice.

         Pick the earliest non-outlier if there is one, else the earliest one.
@@ -641,7 +643,9 @@ class PersistEventsStore:
         Returns:
             list[(EventBase, EventContext)]: filtered list
         """
-        new_events_and_contexts = OrderedDict()
+        new_events_and_contexts = (
+            OrderedDict()
+        )  # type: OrderedDict[str, Tuple[EventBase, EventContext]]
         for event, context in events_and_contexts:
             prev_event_context = new_events_and_contexts.get(event.event_id)
             if prev_event_context:
@@ -655,7 +659,12 @@ class PersistEventsStore:
                 new_events_and_contexts[event.event_id] = (event, context)
         return list(new_events_and_contexts.values())

-    def _update_room_depths_txn(self, txn, events_and_contexts, backfilled):
+    def _update_room_depths_txn(
+        self,
+        txn,
+        events_and_contexts: List[Tuple[EventBase, EventContext]],
+        backfilled: bool,
+    ):
         """Update min_depth for each room

         Args:
@@ -664,7 +673,7 @@ class PersistEventsStore:
                 we are persisting
             backfilled (bool): True if the events were backfilled
         """
-        depth_updates = {}
+        depth_updates = {}  # type: Dict[str, int]
         for event, context in events_and_contexts:
             # Remove the any existing cache entries for the event_ids
             txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
@@ -1436,7 +1445,7 @@ class PersistEventsStore:
         Forward extremities are handled when we first start persisting the events.
         """
-        events_by_room = {}
+        events_by_room = {}  # type: Dict[str, List[EventBase]]
         for ev in events:
             events_by_room.setdefault(ev.room_id, []).append(ev)
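
The annotations above use mypy's `# type:` comment form rather than variable annotations; mypy reads the two identically, and the comment form matches the existing style of this codebase. In isolation:

from typing import Dict, List

# mypy treats these comments exactly like inline variable annotations.
results = []  # type: List[str]
depth_updates = {}  # type: Dict[str, int]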

synapse/storage/databases/main/stream.py

@@ -310,11 +310,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
     async def get_room_events_stream_for_rooms(
         self,
         room_ids: Collection[str],
-        from_key: str,
-        to_key: str,
+        from_key: RoomStreamToken,
+        to_key: RoomStreamToken,
         limit: int = 0,
         order: str = "DESC",
-    ) -> Dict[str, Tuple[List[EventBase], str]]:
+    ) -> Dict[str, Tuple[List[EventBase], RoomStreamToken]]:
         """Get new room events in stream ordering since `from_key`.

         Args:
@@ -333,9 +333,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             - list of recent events in the room
             - stream ordering key for the start of the chunk of events returned.
         """
-        from_id = RoomStreamToken.parse_stream_token(from_key).stream
-
-        room_ids = self._events_stream_cache.get_entities_changed(room_ids, from_id)
+        room_ids = self._events_stream_cache.get_entities_changed(
+            room_ids, from_key.stream
+        )

         if not room_ids:
             return {}
@@ -364,16 +364,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         return results

     def get_rooms_that_changed(
-        self, room_ids: Collection[str], from_key: str
+        self, room_ids: Collection[str], from_key: RoomStreamToken
     ) -> Set[str]:
         """Given a list of rooms and a token, return rooms where there may have
         been changes.
-
-        Args:
-            room_ids
-            from_key: The room_key portion of a StreamToken
         """
-        from_id = RoomStreamToken.parse_stream_token(from_key).stream
+        from_id = from_key.stream
         return {
             room_id
             for room_id in room_ids
@@ -383,11 +379,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
     async def get_room_events_stream_for_room(
         self,
         room_id: str,
-        from_key: str,
-        to_key: str,
+        from_key: RoomStreamToken,
+        to_key: RoomStreamToken,
         limit: int = 0,
         order: str = "DESC",
-    ) -> Tuple[List[EventBase], str]:
+    ) -> Tuple[List[EventBase], RoomStreamToken]:
         """Get new room events in stream ordering since `from_key`.

         Args:
@@ -408,8 +404,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         if from_key == to_key:
             return [], from_key

-        from_id = RoomStreamToken.parse_stream_token(from_key).stream
-        to_id = RoomStreamToken.parse_stream_token(to_key).stream
+        from_id = from_key.stream
+        to_id = to_key.stream

         has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
@@ -441,7 +437,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         ret.reverse()

         if rows:
-            key = "s%d" % min(r.stream_ordering for r in rows)
+            key = RoomStreamToken(None, min(r.stream_ordering for r in rows))
         else:
             # Assume we didn't get anything because there was nothing to
             # get.
@@ -450,10 +446,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         return ret, key

     async def get_membership_changes_for_user(
-        self, user_id: str, from_key: str, to_key: str
+        self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
     ) -> List[EventBase]:
-        from_id = RoomStreamToken.parse_stream_token(from_key).stream
-        to_id = RoomStreamToken.parse_stream_token(to_key).stream
+        from_id = from_key.stream
+        to_id = to_key.stream

         if from_key == to_key:
             return []
@@ -491,8 +487,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         return ret

     async def get_recent_events_for_room(
-        self, room_id: str, limit: int, end_token: str
-    ) -> Tuple[List[EventBase], str]:
+        self, room_id: str, limit: int, end_token: RoomStreamToken
+    ) -> Tuple[List[EventBase], RoomStreamToken]:
         """Get the most recent events in the room in topological ordering.

         Args:
@@ -518,8 +514,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         return (events, token)

     async def get_recent_event_ids_for_room(
-        self, room_id: str, limit: int, end_token: str
-    ) -> Tuple[List[_EventDictReturn], str]:
+        self, room_id: str, limit: int, end_token: RoomStreamToken
+    ) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
         """Get the most recent events in the room in topological ordering.

         Args:
@@ -535,13 +531,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         if limit == 0:
             return [], end_token

-        parsed_end_token = RoomStreamToken.parse(end_token)
-
         rows, token = await self.db_pool.runInteraction(
             "get_recent_event_ids_for_room",
             self._paginate_room_events_txn,
             room_id,
-            from_token=parsed_end_token,
+            from_token=end_token,
             limit=limit,
         )
@@ -619,7 +613,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             allow_none=allow_none,
         )

-    async def get_stream_token_for_event(self, event_id: str) -> str:
+    async def get_stream_token_for_event(self, event_id: str) -> RoomStreamToken:
         """The stream token for an event
         Args:
             event_id: The id of the event to look up a stream token for.
@@ -629,7 +623,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             A "s%d" stream token.
         """
         stream_id = await self.get_stream_id_for_event(event_id)
-        return "s%d" % (stream_id,)
+        return RoomStreamToken(None, stream_id)

     async def get_topological_token_for_event(self, event_id: str) -> str:
         """The stream token for an event
@@ -954,7 +948,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         direction: str = "b",
         limit: int = -1,
         event_filter: Optional[Filter] = None,
-    ) -> Tuple[List[_EventDictReturn], str]:
+    ) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
         """Returns list of events before or after a given token.

         Args:
@@ -1054,17 +1048,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             # TODO (erikj): We should work out what to do here instead.
             next_token = to_token if to_token else from_token

-        return rows, str(next_token)
+        return rows, next_token

     async def paginate_room_events(
         self,
         room_id: str,
-        from_key: str,
-        to_key: Optional[str] = None,
+        from_key: RoomStreamToken,
+        to_key: Optional[RoomStreamToken] = None,
         direction: str = "b",
         limit: int = -1,
         event_filter: Optional[Filter] = None,
-    ) -> Tuple[List[EventBase], str]:
+    ) -> Tuple[List[EventBase], RoomStreamToken]:
         """Returns list of events before or after a given token.

         Args:
@@ -1083,17 +1077,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             and `to_key`).
         """

-        parsed_from_key = RoomStreamToken.parse(from_key)
-        parsed_to_key = None
-        if to_key:
-            parsed_to_key = RoomStreamToken.parse(to_key)
-
         rows, token = await self.db_pool.runInteraction(
             "paginate_room_events",
             self._paginate_room_events_txn,
             room_id,
-            parsed_from_key,
-            parsed_to_key,
+            from_key,
+            to_key,
             direction,
             limit,
             event_filter,

synapse/storage/persist_events.py

@@ -18,7 +18,7 @@
 import itertools
 import logging
 from collections import deque, namedtuple
-from typing import Iterable, List, Optional, Set, Tuple
+from typing import Dict, Iterable, List, Optional, Set, Tuple

 from prometheus_client import Counter, Histogram
@@ -31,7 +31,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.databases import Databases
 from synapse.storage.databases.main.events import DeltaState
-from synapse.types import StateMap
+from synapse.types import Collection, PersistedEventPosition, RoomStreamToken, StateMap
 from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.metrics import Measure
@@ -185,9 +185,12 @@ class EventsPersistenceStorage:
         # store for now.
         self.main_store = stores.main
         self.state_store = stores.state
+
+        assert stores.persist_events
         self.persist_events_store = stores.persist_events

         self._clock = hs.get_clock()
+        self._instance_name = hs.get_instance_name()
         self.is_mine_id = hs.is_mine_id
         self._event_persist_queue = _EventPeristenceQueue()
         self._state_resolution_handler = hs.get_state_resolution_handler()
@@ -196,7 +199,7 @@ class EventsPersistenceStorage:
         self,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         backfilled: bool = False,
-    ) -> int:
+    ) -> RoomStreamToken:
         """
         Write events to the database
         Args:
@@ -208,7 +211,7 @@ class EventsPersistenceStorage:
         Returns:
             the stream ordering of the latest persisted event
         """
-        partitioned = {}
+        partitioned = {}  # type: Dict[str, List[Tuple[EventBase, EventContext]]]
         for event, ctx in events_and_contexts:
             partitioned.setdefault(event.room_id, []).append((event, ctx))
@@ -226,11 +229,11 @@ class EventsPersistenceStorage:
             defer.gatherResults(deferreds, consumeErrors=True)
         )

-        return self.main_store.get_current_events_token()
+        return RoomStreamToken(None, self.main_store.get_current_events_token())

     async def persist_event(
         self, event: EventBase, context: EventContext, backfilled: bool = False
-    ) -> Tuple[int, int]:
+    ) -> Tuple[PersistedEventPosition, RoomStreamToken]:
         """
         Returns:
             The stream ordering of `event`, and the stream ordering of the
@@ -245,7 +248,10 @@ class EventsPersistenceStorage:
         await make_deferred_yieldable(deferred)

         max_persisted_id = self.main_store.get_current_events_token()
-        return (event.internal_metadata.stream_ordering, max_persisted_id)
+        event_stream_id = event.internal_metadata.stream_ordering
+
+        pos = PersistedEventPosition(self._instance_name, event_stream_id)
+        return pos, RoomStreamToken(None, max_persisted_id)

     def _maybe_start_persisting(self, room_id: str):
         async def persisting_queue(item):
@@ -305,7 +311,9 @@ class EventsPersistenceStorage:
                 # Work out the new "current state" for each room.
                 # We do this by working out what the new extremities are and then
                 # calculating the state from that.
-                events_by_room = {}
+                events_by_room = (
+                    {}
+                )  # type: Dict[str, List[Tuple[EventBase, EventContext]]]
                 for event, context in chunk:
                     events_by_room.setdefault(event.room_id, []).append(
                         (event, context)
                     )
@@ -436,7 +444,7 @@ class EventsPersistenceStorage:
         self,
         room_id: str,
         event_contexts: List[Tuple[EventBase, EventContext]],
-        latest_event_ids: List[str],
+        latest_event_ids: Collection[str],
     ):
         """Calculates the new forward extremities for a room given events to
         persist.
@@ -470,7 +478,7 @@ class EventsPersistenceStorage:
         # Remove any events which are prev_events of any existing events.
         existing_prevs = await self.persist_events_store._get_events_which_are_prevs(
             result
-        )
+        )  # type: Collection[str]
         result.difference_update(existing_prevs)

         # Finally handle the case where the new events have soft-failed prev

synapse/types.py

@@ -425,7 +425,9 @@ class RoomStreamToken:

 @attr.s(slots=True, frozen=True)
 class StreamToken:
-    room_key = attr.ib(type=str)
+    room_key = attr.ib(
+        type=RoomStreamToken, validator=attr.validators.instance_of(RoomStreamToken)
+    )
     presence_key = attr.ib(type=int)
     typing_key = attr.ib(type=int)
     receipt_key = attr.ib(type=int)
@@ -445,21 +447,17 @@ class StreamToken:
             while len(keys) < len(attr.fields(cls)):
                 # i.e. old token from before receipt_key
                 keys.append("0")
-            return cls(keys[0], *(int(k) for k in keys[1:]))
+            return cls(RoomStreamToken.parse(keys[0]), *(int(k) for k in keys[1:]))
         except Exception:
-            raise SynapseError(400, "Invalid Token")
+            # raise SynapseError(400, "Invalid Token")
+            raise

     def to_string(self):
-        return self._SEPARATOR.join([str(k) for k in attr.astuple(self)])
+        return self._SEPARATOR.join([str(k) for k in attr.astuple(self, recurse=False)])

     @property
     def room_stream_id(self):
-        # TODO(markjh): Awful hack to work around hacks in the presence tests
-        # which assume that the keys are integers.
-        if type(self.room_key) is int:
-            return self.room_key
-        else:
-            return int(self.room_key[1:].split("-")[-1])
+        return self.room_key.stream

     def is_after(self, other):
         """Does this token contain events that the other doesn't?"""
@@ -475,7 +473,7 @@ class StreamToken:
             or (int(other.groups_key) < int(self.groups_key))
         )

-    def copy_and_advance(self, key, new_value):
+    def copy_and_advance(self, key, new_value) -> "StreamToken":
         """Advance the given key in the token to a new value if and only if the
         new value is after the old value.
         """
@@ -491,13 +489,28 @@ class StreamToken:
         else:
             return self

-    def copy_and_replace(self, key, new_value):
+    def copy_and_replace(self, key, new_value) -> "StreamToken":
         return attr.evolve(self, **{key: new_value})


 StreamToken.START = StreamToken.from_string("s0_0")


+@attr.s(slots=True, frozen=True)
+class PersistedEventPosition:
+    """Position of a newly persisted event with instance that persisted it.
+
+    This can be used to test whether the event is persisted before or after a
+    RoomStreamToken.
+    """
+
+    instance_name = attr.ib(type=str)
+    stream = attr.ib(type=int)
+
+    def persisted_after(self, token: RoomStreamToken) -> bool:
+        return token.stream < self.stream
+
+
 class ThirdPartyInstanceID(
     namedtuple("ThirdPartyInstanceID", ("appservice_id", "network_id"))
 ):
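
A quick usage sketch of the new ordering check: `persisted_after` is true exactly when the event's stream position lies beyond what the token has seen. This reuses the token sketch from the changelog section and restates `PersistedEventPosition` as defined in the hunk above:

@attr.s(slots=True, frozen=True)
class PersistedEventPosition:
    instance_name = attr.ib(type=str)
    stream = attr.ib(type=int)

    def persisted_after(self, token: RoomStreamToken) -> bool:
        return token.stream < self.stream


# An event persisted at stream position 7 is after a token that has seen up
# to 5, but not after one that has already seen 7.
pos = PersistedEventPosition(instance_name="master", stream=7)
assert pos.persisted_after(RoomStreamToken(None, 5))
assert not pos.persisted_after(RoomStreamToken(None, 7))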

tests/test_utils/event_injection.py

@@ -71,7 +71,10 @@ async def inject_event(
     """
     event, context = await create_event(hs, room_version, prev_event_ids, **kwargs)

-    await hs.get_storage().persistence.persist_event(event, context)
+    persistence = hs.get_storage().persistence
+    assert persistence is not None
+    await persistence.persist_event(event, context)

     return event