Compare commits

...

4 Commits

Author SHA1 Message Date
Erik Johnston
e5c6aafe61 Fix redactions 2021-05-10 15:48:27 +01:00
Erik Johnston
c0a7348580 Handle deduplicating multiple fetch event requests 2021-05-10 14:49:39 +01:00
Erik Johnston
875a8fec34 Newsfile 2021-05-06 14:45:20 +01:00
Erik Johnston
9d1118dde8 Ensure we only have one copy of an event in memory at a time
This ensures that if the get event cache overflows we don't end up with
multiple copies of the event in RAM at the same time (which could lead
to memory bloat)
2021-05-06 14:42:42 +01:00
3 changed files with 42 additions and 6 deletions

1
changelog.d/9939.misc Normal file
View File

@@ -0,0 +1 @@
Ensure we only have one copy of an event in memory at a time.

View File

@@ -181,7 +181,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
# changed its content in the database. We can't call # changed its content in the database. We can't call
# self._invalidate_cache_and_stream because self.get_event_cache isn't of the # self._invalidate_cache_and_stream because self.get_event_cache isn't of the
# right type. # right type.
txn.call_after(self._get_event_cache.invalidate, (event.event_id,)) txn.call_after(self._invalidate_get_event_cache, event.event_id)
# Send that invalidation to replication so that other workers also invalidate # Send that invalidation to replication so that other workers also invalidate
# the event cache. # the event cache.
self._send_invalidation_to_replication( self._send_invalidation_to_replication(

View File

@@ -14,7 +14,6 @@
import logging import logging
import threading import threading
from collections import namedtuple
from typing import ( from typing import (
Collection, Collection,
Container, Container,
@@ -25,7 +24,9 @@ from typing import (
Tuple, Tuple,
overload, overload,
) )
from weakref import WeakValueDictionary
import attr
from constantly import NamedConstant, Names from constantly import NamedConstant, Names
from typing_extensions import Literal from typing_extensions import Literal
@@ -73,7 +74,10 @@ EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events
EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) @attr.s(slots=True, frozen=True, auto_attribs=True)
class _EventCacheEntry:
event: EventBase
redacted_event: Optional[EventBase]
class EventRedactBehaviour(Names): class EventRedactBehaviour(Names):
@@ -157,9 +161,14 @@ class EventsWorkerStore(SQLBaseStore):
self._get_event_cache = LruCache( self._get_event_cache = LruCache(
cache_name="*getEvent*", cache_name="*getEvent*",
keylen=3,
max_size=hs.config.caches.event_cache_size, max_size=hs.config.caches.event_cache_size,
) )
# We separately track which events we have in memory. This is mainly to
# guard against loading the same event into memory multiple times when
# `_get_event_cache` overflows.
self._in_memory_events = (
WeakValueDictionary()
) # type: WeakValueDictionary[str, _EventCacheEntry]
self._event_fetch_lock = threading.Condition() self._event_fetch_lock = threading.Condition()
self._event_fetch_list = [] self._event_fetch_list = []
@@ -519,6 +528,7 @@ class EventsWorkerStore(SQLBaseStore):
def _invalidate_get_event_cache(self, event_id): def _invalidate_get_event_cache(self, event_id):
self._get_event_cache.invalidate((event_id,)) self._get_event_cache.invalidate((event_id,))
self._in_memory_events.pop(event_id, None)
def _get_events_from_cache(self, events, allow_rejected, update_metrics=True): def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
"""Fetch events from the caches """Fetch events from the caches
@@ -539,6 +549,9 @@ class EventsWorkerStore(SQLBaseStore):
ret = self._get_event_cache.get( ret = self._get_event_cache.get(
(event_id,), None, update_metrics=update_metrics (event_id,), None, update_metrics=update_metrics
) )
if not ret:
ret = self._in_memory_events.get(event_id)
if not ret: if not ret:
continue continue
@@ -708,6 +721,9 @@ class EventsWorkerStore(SQLBaseStore):
if events_to_fetch: if events_to_fetch:
logger.debug("Also fetching redaction events %s", events_to_fetch) logger.debug("Also fetching redaction events %s", events_to_fetch)
# The events to return
result_map = {} # type: Dict[str, _EventCacheEntry]
# build a map from event_id to EventBase # build a map from event_id to EventBase
event_map = {} event_map = {}
for event_id, row in fetched_events.items(): for event_id, row in fetched_events.items():
@@ -720,6 +736,18 @@ class EventsWorkerStore(SQLBaseStore):
if not allow_rejected and rejected_reason: if not allow_rejected and rejected_reason:
continue continue
# Check whether we already have this event in memory. This can
# happen when multiple requests for the same event happen at the
# same time. (Ideally we'd make it so that this doesn't happen, but
# that would require a larger refactor).
cached_entry = self._in_memory_events.get(event_id)
if cached_entry is not None:
# We need to add to the event_map as we read from it to fetch redactions.
event_map[event_id] = cached_entry.event
result_map[event_id] = cached_entry
self._get_event_cache.set((event_id,), cached_entry)
continue
# If the event or metadata cannot be parsed, log the error and act # If the event or metadata cannot be parsed, log the error and act
# as if the event is unknown. # as if the event is unknown.
try: try:
@@ -813,8 +841,10 @@ class EventsWorkerStore(SQLBaseStore):
# finally, we can decide whether each one needs redacting, and build # finally, we can decide whether each one needs redacting, and build
# the cache entries. # the cache entries.
result_map = {}
for event_id, original_ev in event_map.items(): for event_id, original_ev in event_map.items():
if event_id in result_map:
continue
redactions = fetched_events[event_id]["redactions"] redactions = fetched_events[event_id]["redactions"]
redacted_event = self._maybe_redact_event_row( redacted_event = self._maybe_redact_event_row(
original_ev, redactions, event_map original_ev, redactions, event_map
@@ -825,6 +855,7 @@ class EventsWorkerStore(SQLBaseStore):
) )
self._get_event_cache.set((event_id,), cache_entry) self._get_event_cache.set((event_id,), cache_entry)
self._in_memory_events[event_id] = cache_entry
result_map[event_id] = cache_entry result_map[event_id] = cache_entry
return result_map return result_map
@@ -1056,7 +1087,11 @@ class EventsWorkerStore(SQLBaseStore):
set[str]: The events we have already seen. set[str]: The events we have already seen.
""" """
# if the event cache contains the event, obviously we've seen it. # if the event cache contains the event, obviously we've seen it.
results = {x for x in event_ids if self._get_event_cache.contains(x)} results = {
x
for x in event_ids
if self._get_event_cache.contains((x,)) or x in self._in_memory_events
}
def have_seen_events_txn(txn, chunk): def have_seen_events_txn(txn, chunk):
sql = "SELECT event_id FROM events as e WHERE " sql = "SELECT event_id FROM events as e WHERE "