Move unique snowflake homeserver background tasks to start_background_tasks (#19037)

(the standard pattern for this kind of thing)
This commit is contained in:
Eric Eastwood
2025-10-13 10:19:09 -05:00
committed by GitHub
parent 2d07bd7fd2
commit d2c582ef3c
14 changed files with 331 additions and 182 deletions

1
changelog.d/19037.misc Normal file
View File

@@ -0,0 +1 @@
Move unique snowflake homeserver background tasks to `start_background_tasks` (the standard pattern for this kind of thing).

View File

@@ -64,7 +64,6 @@ from twisted.web.resource import Resource
import synapse.util.caches
from synapse.api.constants import MAX_PDU_SIZE
from synapse.app import check_bind_error
from synapse.app.phone_stats_home import start_phone_stats_home
from synapse.config import ConfigError
from synapse.config._base import format_config_error
from synapse.config.homeserver import HomeServerConfig
@@ -683,15 +682,6 @@ async def start(hs: "HomeServer", freeze: bool = True) -> None:
if hs.config.worker.run_background_tasks:
hs.start_background_tasks()
# TODO: This should be moved to same pattern we use for other background tasks:
# Add to `REQUIRED_ON_BACKGROUND_TASK_STARTUP` and rely on
# `start_background_tasks` to start it.
await hs.get_common_usage_metrics_manager().setup()
# TODO: This feels like another pattern that should refactored as one of the
# `REQUIRED_ON_BACKGROUND_TASK_STARTUP`
start_phone_stats_home(hs)
if freeze:
# We now freeze all allocated objects in the hopes that (almost)
# everything currently allocated are things that will be used for the

View File

@@ -62,7 +62,7 @@ class CommonUsageMetricsManager:
"""
return await self._collect()
async def setup(self) -> None:
def setup(self) -> None:
"""Keep the gauges for common usage metrics up to date."""
self._hs.run_as_background_process(
desc="common_usage_metrics_update_gauges",

View File

@@ -62,6 +62,7 @@ from synapse.api.auth_blocking import AuthBlocking
from synapse.api.filtering import Filtering
from synapse.api.ratelimiting import Ratelimiter, RequestRatelimiter
from synapse.app._base import unregister_sighups
from synapse.app.phone_stats_home import start_phone_stats_home
from synapse.appservice.api import ApplicationServiceApi
from synapse.appservice.scheduler import ApplicationServiceScheduler
from synapse.config.homeserver import HomeServerConfig
@@ -643,6 +644,8 @@ class HomeServer(metaclass=abc.ABCMeta):
for i in self.REQUIRED_ON_BACKGROUND_TASK_STARTUP:
getattr(self, "get_" + i + "_handler")()
self.get_task_scheduler()
self.get_common_usage_metrics_manager().setup()
start_phone_stats_home(self)
def get_reactor(self) -> ISynapseReactor:
"""

View File

@@ -214,7 +214,12 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
client_to_server_transport.loseConnection()
# there should have been exactly one request
self.assertEqual(len(requests), 1)
self.assertEqual(
len(requests),
1,
"Expected to handle exactly one HTTP replication request but saw %d - requests=%s"
% (len(requests), requests),
)
return requests[0]

View File

@@ -46,28 +46,39 @@ class AccountDataStreamTestCase(BaseStreamTestCase):
# check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual([], self.test_handler.received_rdata_rows)
received_account_data_rows = [
row
for row in self.test_handler.received_rdata_rows
if row[0] == AccountDataStream.NAME
]
self.assertEqual([], received_account_data_rows)
# now reconnect to pull the updates
self.reconnect()
self.replicate()
# we should have received all the expected rows in the right order
received_rows = self.test_handler.received_rdata_rows
# We should have received all the expected rows in the right order
#
# Filter the updates to only include account data changes
received_account_data_rows = [
row
for row in self.test_handler.received_rdata_rows
if row[0] == AccountDataStream.NAME
]
for t in updates:
(stream_name, token, row) = received_rows.pop(0)
(stream_name, token, row) = received_account_data_rows.pop(0)
self.assertEqual(stream_name, AccountDataStream.NAME)
self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
self.assertEqual(row.data_type, t)
self.assertEqual(row.room_id, "test_room")
(stream_name, token, row) = received_rows.pop(0)
(stream_name, token, row) = received_account_data_rows.pop(0)
self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
self.assertEqual(row.data_type, "m.global")
self.assertIsNone(row.room_id)
self.assertEqual([], received_rows)
self.assertEqual([], received_account_data_rows)
def test_update_function_global_account_data_limit(self) -> None:
"""Test replication with many global account data updates"""
@@ -85,32 +96,38 @@ class AccountDataStreamTestCase(BaseStreamTestCase):
store.add_account_data_to_room("test_user", "test_room", "m.per_room", {})
)
# tell the notifier to catch up to avoid duplicate rows.
# workaround for https://github.com/matrix-org/synapse/issues/7360
# FIXME remove this when the above is fixed
self.replicate()
# check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual([], self.test_handler.received_rdata_rows)
received_account_data_rows = [
row
for row in self.test_handler.received_rdata_rows
if row[0] == AccountDataStream.NAME
]
self.assertEqual([], received_account_data_rows)
# now reconnect to pull the updates
self.reconnect()
self.replicate()
# we should have received all the expected rows in the right order
received_rows = self.test_handler.received_rdata_rows
#
# Filter the updates to only include typing changes
received_account_data_rows = [
row
for row in self.test_handler.received_rdata_rows
if row[0] == AccountDataStream.NAME
]
for t in updates:
(stream_name, token, row) = received_rows.pop(0)
(stream_name, token, row) = received_account_data_rows.pop(0)
self.assertEqual(stream_name, AccountDataStream.NAME)
self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
self.assertEqual(row.data_type, t)
self.assertIsNone(row.room_id)
(stream_name, token, row) = received_rows.pop(0)
(stream_name, token, row) = received_account_data_rows.pop(0)
self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
self.assertEqual(row.data_type, "m.per_room")
self.assertEqual(row.room_id, "test_room")
self.assertEqual([], received_rows)
self.assertEqual([], received_account_data_rows)

View File

@@ -30,6 +30,7 @@ from synapse.replication.tcp.commands import RdataCommand
from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT
from synapse.replication.tcp.streams.events import (
_MAX_STATE_UPDATES_PER_ROOM,
EventsStream,
EventsStreamAllStateRow,
EventsStreamCurrentStateRow,
EventsStreamEventRow,
@@ -82,7 +83,12 @@ class EventsStreamTestCase(BaseStreamTestCase):
# check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual([], self.test_handler.received_rdata_rows)
received_event_rows = [
row
for row in self.test_handler.received_rdata_rows
if row[0] == EventsStream.NAME
]
self.assertEqual([], received_event_rows)
# now reconnect to pull the updates
self.reconnect()
@@ -90,31 +96,34 @@ class EventsStreamTestCase(BaseStreamTestCase):
# we should have received all the expected rows in the right order (as
# well as various cache invalidation updates which we ignore)
received_rows = [
row for row in self.test_handler.received_rdata_rows if row[0] == "events"
#
# Filter the updates to only include event changes
received_event_rows = [
row
for row in self.test_handler.received_rdata_rows
if row[0] == EventsStream.NAME
]
for event in events:
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
stream_name, token, row = received_event_rows.pop(0)
self.assertEqual(EventsStream.NAME, stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "ev")
self.assertIsInstance(row.data, EventsStreamEventRow)
self.assertEqual(row.data.event_id, event.event_id)
stream_name, token, row = received_rows.pop(0)
stream_name, token, row = received_event_rows.pop(0)
self.assertIsInstance(row, EventsStreamRow)
self.assertIsInstance(row.data, EventsStreamEventRow)
self.assertEqual(row.data.event_id, state_event.event_id)
stream_name, token, row = received_rows.pop(0)
stream_name, token, row = received_event_rows.pop(0)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "state")
self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
self.assertEqual(row.data.event_id, state_event.event_id)
self.assertEqual([], received_rows)
self.assertEqual([], received_event_rows)
@parameterized.expand(
[(_STREAM_UPDATE_TARGET_ROW_COUNT, False), (_MAX_STATE_UPDATES_PER_ROOM, True)]
@@ -170,9 +179,12 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.replicate()
# all those events and state changes should have landed
self.assertGreaterEqual(
len(self.test_handler.received_rdata_rows), 2 * len(events)
)
received_event_rows = [
row
for row in self.test_handler.received_rdata_rows
if row[0] == EventsStream.NAME
]
self.assertGreaterEqual(len(received_event_rows), 2 * len(events))
# disconnect, so that we can stack up the changes
self.disconnect()
@@ -202,7 +214,12 @@ class EventsStreamTestCase(BaseStreamTestCase):
# check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual([], self.test_handler.received_rdata_rows)
received_event_rows = [
row
for row in self.test_handler.received_rdata_rows
if row[0] == EventsStream.NAME
]
self.assertEqual([], received_event_rows)
# now reconnect to pull the updates
self.reconnect()
@@ -218,33 +235,34 @@ class EventsStreamTestCase(BaseStreamTestCase):
# of the states that got reverted.
# - two rows for state2
received_rows = [
row for row in self.test_handler.received_rdata_rows if row[0] == "events"
received_event_rows = [
row
for row in self.test_handler.received_rdata_rows
if row[0] == EventsStream.NAME
]
# first check the first two rows, which should be the state1 event.
stream_name, token, row = received_rows.pop(0)
stream_name, token, row = received_event_rows.pop(0)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "ev")
self.assertIsInstance(row.data, EventsStreamEventRow)
self.assertEqual(row.data.event_id, state1.event_id)
stream_name, token, row = received_rows.pop(0)
stream_name, token, row = received_event_rows.pop(0)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "state")
self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
self.assertEqual(row.data.event_id, state1.event_id)
# now the last two rows, which should be the state2 event.
stream_name, token, row = received_rows.pop(-2)
stream_name, token, row = received_event_rows.pop(-2)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "ev")
self.assertIsInstance(row.data, EventsStreamEventRow)
self.assertEqual(row.data.event_id, state2.event_id)
stream_name, token, row = received_rows.pop(-1)
stream_name, token, row = received_event_rows.pop(-1)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "state")
self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
@@ -254,16 +272,16 @@ class EventsStreamTestCase(BaseStreamTestCase):
if collapse_state_changes:
# that should leave us with the rows for the PL event, the state changes
# get collapsed into a single row.
self.assertEqual(len(received_rows), 2)
self.assertEqual(len(received_event_rows), 2)
stream_name, token, row = received_rows.pop(0)
stream_name, token, row = received_event_rows.pop(0)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "ev")
self.assertIsInstance(row.data, EventsStreamEventRow)
self.assertEqual(row.data.event_id, pl_event.event_id)
stream_name, token, row = received_rows.pop(0)
stream_name, token, row = received_event_rows.pop(0)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "state-all")
self.assertIsInstance(row.data, EventsStreamAllStateRow)
@@ -271,9 +289,9 @@ class EventsStreamTestCase(BaseStreamTestCase):
else:
# that should leave us with the rows for the PL event
self.assertEqual(len(received_rows), len(events) + 2)
self.assertEqual(len(received_event_rows), len(events) + 2)
stream_name, token, row = received_rows.pop(0)
stream_name, token, row = received_event_rows.pop(0)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "ev")
@@ -282,7 +300,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
# the state rows are unsorted
state_rows: List[EventsStreamCurrentStateRow] = []
for stream_name, _, row in received_rows:
for stream_name, _, row in received_event_rows:
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "state")
@@ -346,9 +364,12 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.replicate()
# all those events and state changes should have landed
self.assertGreaterEqual(
len(self.test_handler.received_rdata_rows), 2 * len(events)
)
received_event_rows = [
row
for row in self.test_handler.received_rdata_rows
if row[0] == EventsStream.NAME
]
self.assertGreaterEqual(len(received_event_rows), 2 * len(events))
# disconnect, so that we can stack up the changes
self.disconnect()
@@ -375,7 +396,12 @@ class EventsStreamTestCase(BaseStreamTestCase):
# check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual([], self.test_handler.received_rdata_rows)
received_event_rows = [
row
for row in self.test_handler.received_rdata_rows
if row[0] == EventsStream.NAME
]
self.assertEqual([], received_event_rows)
# now reconnect to pull the updates
self.reconnect()
@@ -383,14 +409,16 @@ class EventsStreamTestCase(BaseStreamTestCase):
# we should have received all the expected rows in the right order (as
# well as various cache invalidation updates which we ignore)
received_rows = [
row for row in self.test_handler.received_rdata_rows if row[0] == "events"
received_event_rows = [
row
for row in self.test_handler.received_rdata_rows
if row[0] == EventsStream.NAME
]
self.assertGreaterEqual(len(received_rows), len(events))
self.assertGreaterEqual(len(received_event_rows), len(events))
for i in range(NUM_USERS):
# for each user, we expect the PL event row, followed by state rows for
# the PL event and each of the states that got reverted.
stream_name, token, row = received_rows.pop(0)
stream_name, token, row = received_event_rows.pop(0)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "ev")
@@ -400,7 +428,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
# the state rows are unsorted
state_rows: List[EventsStreamCurrentStateRow] = []
for _ in range(STATES_PER_USER + 1):
stream_name, token, row = received_rows.pop(0)
stream_name, token, row = received_event_rows.pop(0)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "state")
@@ -417,7 +445,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
# "None" indicates the state has been deleted
self.assertIsNone(sr.event_id)
self.assertEqual([], received_rows)
self.assertEqual([], received_event_rows)
def test_backwards_stream_id(self) -> None:
"""
@@ -432,7 +460,12 @@ class EventsStreamTestCase(BaseStreamTestCase):
# check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual([], self.test_handler.received_rdata_rows)
received_event_rows = [
row
for row in self.test_handler.received_rdata_rows
if row[0] == EventsStream.NAME
]
self.assertEqual([], received_event_rows)
# now reconnect to pull the updates
self.reconnect()
@@ -440,14 +473,16 @@ class EventsStreamTestCase(BaseStreamTestCase):
# We should have received the expected single row (as well as various
# cache invalidation updates which we ignore).
received_rows = [
row for row in self.test_handler.received_rdata_rows if row[0] == "events"
received_event_rows = [
row
for row in self.test_handler.received_rdata_rows
if row[0] == EventsStream.NAME
]
# There should be a single received row.
self.assertEqual(len(received_rows), 1)
self.assertEqual(len(received_event_rows), 1)
stream_name, token, row = received_rows[0]
stream_name, token, row = received_event_rows[0]
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "ev")
@@ -468,10 +503,12 @@ class EventsStreamTestCase(BaseStreamTestCase):
)
# No updates have been received (because it was discard as old).
received_rows = [
row for row in self.test_handler.received_rdata_rows if row[0] == "events"
received_event_rows = [
row
for row in self.test_handler.received_rdata_rows
if row[0] == EventsStream.NAME
]
self.assertEqual(len(received_rows), 0)
self.assertEqual(len(received_event_rows), 0)
# Ensure the stream has not gone backwards.
current_token = worker_events_stream.current_token("master")

View File

@@ -38,24 +38,45 @@ class FederationStreamTestCase(BaseStreamTestCase):
Makes sure that updates sent while we are offline are received later.
"""
fed_sender = self.hs.get_federation_sender()
received_rows = self.test_handler.received_rdata_rows
# Send an update before we connect
fed_sender.build_and_send_edu("testdest", "m.test_edu", {"a": "b"})
# Now reconnect and pull the updates
self.reconnect()
# FIXME: This seems odd, why aren't we calling `self.replicate()` here? but also
# doing so, causes other assumptions to fail (multiple HTTP replication attempts
# are made).
self.reactor.advance(0)
# check we're testing what we think we are: no rows should yet have been
# Check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual(received_rows, [])
#
# Filter the updates to only include typing changes
received_federation_rows = [
row
for row in self.test_handler.received_rdata_rows
if row[0] == FederationStream.NAME
]
self.assertEqual(received_federation_rows, [])
# We should now see an attempt to connect to the master
request = self.handle_http_replication_attempt()
self.assert_request_is_get_repl_stream_updates(request, "federation")
self.assert_request_is_get_repl_stream_updates(request, FederationStream.NAME)
# we should have received an update row
stream_name, token, row = received_rows.pop()
self.assertEqual(stream_name, "federation")
received_federation_rows = [
row
for row in self.test_handler.received_rdata_rows
if row[0] == FederationStream.NAME
]
self.assertEqual(
len(received_federation_rows),
1,
"Expected exactly one row for the federation stream",
)
(stream_name, token, row) = received_federation_rows[0]
self.assertEqual(stream_name, FederationStream.NAME)
self.assertIsInstance(row, FederationStream.FederationStreamRow)
self.assertEqual(row.type, EduRow.TypeId)
edurow = EduRow.from_data(row.data)
@@ -63,19 +84,30 @@ class FederationStreamTestCase(BaseStreamTestCase):
self.assertEqual(edurow.edu.origin, self.hs.hostname)
self.assertEqual(edurow.edu.destination, "testdest")
self.assertEqual(edurow.edu.content, {"a": "b"})
self.assertEqual(received_rows, [])
# Clear out the received rows that we've checked so we can check for new ones later
self.test_handler.received_rdata_rows.clear()
# additional updates should be transferred without an HTTP hit
fed_sender.build_and_send_edu("testdest", "m.test1", {"c": "d"})
self.reactor.advance(0)
# Pull in the updates
self.replicate()
# there should be no http hit
self.assertEqual(len(self.reactor.tcpClients), 0)
# ... but we should have a row
self.assertEqual(len(received_rows), 1)
stream_name, token, row = received_rows.pop()
self.assertEqual(stream_name, "federation")
# ... but we should have a row
received_federation_rows = [
row
for row in self.test_handler.received_rdata_rows
if row[0] == FederationStream.NAME
]
self.assertEqual(
len(received_federation_rows),
1,
"Expected exactly one row for the federation stream",
)
(stream_name, token, row) = received_federation_rows[0]
self.assertEqual(stream_name, FederationStream.NAME)
self.assertIsInstance(row, FederationStream.FederationStreamRow)
self.assertEqual(row.type, EduRow.TypeId)
edurow = EduRow.from_data(row.data)

View File

@@ -20,7 +20,6 @@
# type: ignore
from unittest.mock import Mock
from synapse.replication.tcp.streams._base import ReceiptsStream
@@ -30,9 +29,6 @@ USER_ID = "@feeling:blue"
class ReceiptsStreamTestCase(BaseStreamTestCase):
def _build_replication_data_handler(self):
return Mock(wraps=super()._build_replication_data_handler())
def test_receipt(self):
self.reconnect()
@@ -50,23 +46,30 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
self.replicate()
# there should be one RDATA command
self.test_handler.on_rdata.assert_called_once()
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "receipts")
self.assertEqual(1, len(rdata_rows))
row: ReceiptsStream.ReceiptsStreamRow = rdata_rows[0]
received_receipt_rows = [
row
for row in self.test_handler.received_rdata_rows
if row[0] == ReceiptsStream.NAME
]
self.assertEqual(
len(received_receipt_rows),
1,
"Expected exactly one row for the receipts stream",
)
(stream_name, token, row) = received_receipt_rows[0]
self.assertEqual(stream_name, ReceiptsStream.NAME)
self.assertEqual("!room:blue", row.room_id)
self.assertEqual("m.read", row.receipt_type)
self.assertEqual(USER_ID, row.user_id)
self.assertEqual("$event:blue", row.event_id)
self.assertIsNone(row.thread_id)
self.assertEqual({"a": 1}, row.data)
# Clear out the received rows that we've checked so we can check for new ones later
self.test_handler.received_rdata_rows.clear()
# Now let's disconnect and insert some data.
self.disconnect()
self.test_handler.on_rdata.reset_mock()
self.get_success(
self.hs.get_datastores().main.insert_receipt(
"!room2:blue",
@@ -79,20 +82,27 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
)
self.replicate()
# Nothing should have happened as we are disconnected
self.test_handler.on_rdata.assert_not_called()
# Not yet connected: no rows should yet have been received
self.assertEqual([], self.test_handler.received_rdata_rows)
# Now reconnect and pull the updates
self.reconnect()
self.pump(0.1)
self.replicate()
# We should now have caught up and get the missing data
self.test_handler.on_rdata.assert_called_once()
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "receipts")
received_receipt_rows = [
row
for row in self.test_handler.received_rdata_rows
if row[0] == ReceiptsStream.NAME
]
self.assertEqual(
len(received_receipt_rows),
1,
"Expected exactly one row for the receipts stream",
)
(stream_name, token, row) = received_receipt_rows[0]
self.assertEqual(stream_name, ReceiptsStream.NAME)
self.assertEqual(token, 3)
self.assertEqual(1, len(rdata_rows))
row: ReceiptsStream.ReceiptsStreamRow = rdata_rows[0]
self.assertEqual("!room2:blue", row.room_id)
self.assertEqual("m.read", row.receipt_type)
self.assertEqual(USER_ID, row.user_id)

View File

@@ -88,15 +88,15 @@ class ThreadSubscriptionsStreamTestCase(BaseStreamTestCase):
# We should have received all the expected rows in the right order
# Filter the updates to only include thread subscription changes
received_rows = [
upd
for upd in self.test_handler.received_rdata_rows
if upd[0] == ThreadSubscriptionsStream.NAME
received_thread_subscription_rows = [
row
for row in self.test_handler.received_rdata_rows
if row[0] == ThreadSubscriptionsStream.NAME
]
# Verify all the thread subscription updates
for thread_id in updates:
(stream_name, token, row) = received_rows.pop(0)
(stream_name, token, row) = received_thread_subscription_rows.pop(0)
self.assertEqual(stream_name, ThreadSubscriptionsStream.NAME)
self.assertIsInstance(row, ThreadSubscriptionsStream.ROW_TYPE)
self.assertEqual(row.user_id, "@test_user:example.org")
@@ -104,14 +104,14 @@ class ThreadSubscriptionsStreamTestCase(BaseStreamTestCase):
self.assertEqual(row.event_id, thread_id)
# Verify the last update in the different room
(stream_name, token, row) = received_rows.pop(0)
(stream_name, token, row) = received_thread_subscription_rows.pop(0)
self.assertEqual(stream_name, ThreadSubscriptionsStream.NAME)
self.assertIsInstance(row, ThreadSubscriptionsStream.ROW_TYPE)
self.assertEqual(row.user_id, "@test_user:example.org")
self.assertEqual(row.room_id, other_room_id)
self.assertEqual(row.event_id, other_thread_root_id)
self.assertEqual([], received_rows)
self.assertEqual([], received_thread_subscription_rows)
def test_multiple_users_thread_subscription_updates(self) -> None:
"""Test replication with thread subscription updates for multiple users"""
@@ -138,18 +138,18 @@ class ThreadSubscriptionsStreamTestCase(BaseStreamTestCase):
# We should have received all the expected rows
# Filter the updates to only include thread subscription changes
received_rows = [
upd
for upd in self.test_handler.received_rdata_rows
if upd[0] == ThreadSubscriptionsStream.NAME
received_thread_subscription_rows = [
row
for row in self.test_handler.received_rdata_rows
if row[0] == ThreadSubscriptionsStream.NAME
]
# Should have one update per user
self.assertEqual(len(received_rows), len(users))
self.assertEqual(len(received_thread_subscription_rows), len(users))
# Verify all updates
for i, user_id in enumerate(users):
(stream_name, token, row) = received_rows[i]
(stream_name, token, row) = received_thread_subscription_rows[i]
self.assertEqual(stream_name, ThreadSubscriptionsStream.NAME)
self.assertIsInstance(row, ThreadSubscriptionsStream.ROW_TYPE)
self.assertEqual(row.user_id, user_id)

View File

@@ -21,7 +21,10 @@
import logging
import synapse
from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT
from synapse.replication.tcp.streams._base import (
_STREAM_UPDATE_TARGET_ROW_COUNT,
ToDeviceStream,
)
from synapse.types import JsonDict
from tests.replication._base import BaseStreamTestCase
@@ -82,7 +85,12 @@ class ToDeviceStreamTestCase(BaseStreamTestCase):
)
# replication is disconnected so we shouldn't get any updates yet
self.assertEqual([], self.test_handler.received_rdata_rows)
received_to_device_rows = [
row
for row in self.test_handler.received_rdata_rows
if row[0] == ToDeviceStream.NAME
]
self.assertEqual([], received_to_device_rows)
# now reconnect to pull the updates
self.reconnect()
@@ -90,7 +98,15 @@ class ToDeviceStreamTestCase(BaseStreamTestCase):
# we should receive the fact that we have to_device updates
# for user1 and user2
received_rows = self.test_handler.received_rdata_rows
self.assertEqual(len(received_rows), 2)
self.assertEqual(received_rows[0][2].entity, user1)
self.assertEqual(received_rows[1][2].entity, user2)
received_to_device_rows = [
row
for row in self.test_handler.received_rdata_rows
if row[0] == ToDeviceStream.NAME
]
self.assertEqual(
len(received_to_device_rows),
2,
"Expected two rows in the to_device stream",
)
self.assertEqual(received_to_device_rows[0][2].entity, user1)
self.assertEqual(received_to_device_rows[1][2].entity, user2)

View File

@@ -19,7 +19,6 @@
#
#
import logging
from unittest.mock import Mock
from synapse.handlers.typing import RoomMember, TypingWriterHandler
from synapse.replication.tcp.streams import TypingStream
@@ -27,6 +26,8 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
from tests.replication._base import BaseStreamTestCase
logger = logging.getLogger(__name__)
USER_ID = "@feeling:blue"
USER_ID_2 = "@da-ba-dee:blue"
@@ -35,10 +36,6 @@ ROOM_ID_2 = "!foo:blue"
class TypingStreamTestCase(BaseStreamTestCase):
def _build_replication_data_handler(self) -> Mock:
self.mock_handler = Mock(wraps=super()._build_replication_data_handler())
return self.mock_handler
def test_typing(self) -> None:
typing = self.hs.get_typing_handler()
assert isinstance(typing, TypingWriterHandler)
@@ -47,51 +44,74 @@ class TypingStreamTestCase(BaseStreamTestCase):
# update to fetch.
typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=True)
# Not yet connected: no rows should yet have been received
self.assertEqual([], self.test_handler.received_rdata_rows)
# Reconnect
self.reconnect()
typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=True)
self.reactor.advance(0)
# Pull in the updates
self.replicate()
# We should now see an attempt to connect to the master
request = self.handle_http_replication_attempt()
self.assert_request_is_get_repl_stream_updates(request, "typing")
self.assert_request_is_get_repl_stream_updates(request, TypingStream.NAME)
self.mock_handler.on_rdata.assert_called_once()
stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row: TypingStream.TypingStreamRow = rdata_rows[0]
self.assertEqual(ROOM_ID, row.room_id)
self.assertEqual([USER_ID], row.user_ids)
# Filter the updates to only include typing changes
received_typing_rows = [
row
for row in self.test_handler.received_rdata_rows
if row[0] == TypingStream.NAME
]
self.assertEqual(
len(received_typing_rows),
1,
"Expected exactly one row for the typing stream",
)
(stream_name, token, row) = received_typing_rows[0]
self.assertEqual(stream_name, TypingStream.NAME)
self.assertIsInstance(row, TypingStream.ROW_TYPE)
self.assertEqual(row.room_id, ROOM_ID)
self.assertEqual(row.user_ids, [USER_ID])
# Clear out the received rows that we've checked so we can check for new ones later
self.test_handler.received_rdata_rows.clear()
# Now let's disconnect and insert some data.
self.disconnect()
self.mock_handler.on_rdata.reset_mock()
typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=False)
self.mock_handler.on_rdata.assert_not_called()
# Not yet connected: no rows should yet have been received
self.assertEqual([], self.test_handler.received_rdata_rows)
# Now reconnect and pull the updates
self.reconnect()
self.pump(0.1)
self.replicate()
# We should now see an attempt to connect to the master
request = self.handle_http_replication_attempt()
self.assert_request_is_get_repl_stream_updates(request, "typing")
self.assert_request_is_get_repl_stream_updates(request, TypingStream.NAME)
# The from token should be the token from the last RDATA we got.
assert request.args is not None
self.assertEqual(int(request.args[b"from_token"][0]), token)
self.mock_handler.on_rdata.assert_called_once()
stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0]
self.assertEqual(ROOM_ID, row.room_id)
self.assertEqual([], row.user_ids)
received_typing_rows = [
row
for row in self.test_handler.received_rdata_rows
if row[0] == TypingStream.NAME
]
self.assertEqual(
len(received_typing_rows),
1,
"Expected exactly one row for the typing stream",
)
(stream_name, token, row) = received_typing_rows[0]
self.assertEqual(stream_name, TypingStream.NAME)
self.assertIsInstance(row, TypingStream.ROW_TYPE)
self.assertEqual(row.room_id, ROOM_ID)
self.assertEqual(row.user_ids, [])
def test_reset(self) -> None:
"""
@@ -116,33 +136,47 @@ class TypingStreamTestCase(BaseStreamTestCase):
# update to fetch.
typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=True)
# Not yet connected: no rows should yet have been received
self.assertEqual([], self.test_handler.received_rdata_rows)
# Now reconnect to pull the updates
self.reconnect()
typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=True)
self.reactor.advance(0)
# Pull in the updates
self.replicate()
# We should now see an attempt to connect to the master
request = self.handle_http_replication_attempt()
self.assert_request_is_get_repl_stream_updates(request, "typing")
self.mock_handler.on_rdata.assert_called_once()
stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row: TypingStream.TypingStreamRow = rdata_rows[0]
self.assertEqual(ROOM_ID, row.room_id)
self.assertEqual([USER_ID], row.user_ids)
received_typing_rows = [
row
for row in self.test_handler.received_rdata_rows
if row[0] == TypingStream.NAME
]
self.assertEqual(
len(received_typing_rows),
1,
"Expected exactly one row for the typing stream",
)
(stream_name, token, row) = received_typing_rows[0]
self.assertEqual(stream_name, TypingStream.NAME)
self.assertIsInstance(row, TypingStream.ROW_TYPE)
self.assertEqual(row.room_id, ROOM_ID)
self.assertEqual(row.user_ids, [USER_ID])
# Push the stream forward a bunch so it can be reset.
for i in range(100):
typing._push_update(
member=RoomMember(ROOM_ID, "@test%s:blue" % i), typing=True
)
self.reactor.advance(0)
# Pull in the updates
self.replicate()
# Disconnect.
self.disconnect()
self.test_handler.received_rdata_rows.clear()
# Reset the typing handler
self.hs.get_replication_streams()["typing"].last_token = 0
@@ -155,30 +189,34 @@ class TypingStreamTestCase(BaseStreamTestCase):
)
typing._reset()
# Reconnect.
# Now reconnect and pull the updates
self.reconnect()
self.pump(0.1)
self.replicate()
# We should now see an attempt to connect to the master
request = self.handle_http_replication_attempt()
self.assert_request_is_get_repl_stream_updates(request, "typing")
# Reset the test code.
self.mock_handler.on_rdata.reset_mock()
self.mock_handler.on_rdata.assert_not_called()
# Push additional data.
typing._push_update(member=RoomMember(ROOM_ID_2, USER_ID_2), typing=False)
self.reactor.advance(0)
self.mock_handler.on_rdata.assert_called_once()
stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0]
self.assertEqual(ROOM_ID_2, row.room_id)
self.assertEqual([], row.user_ids)
# Pull the updates
self.replicate()
received_typing_rows = [
row
for row in self.test_handler.received_rdata_rows
if row[0] == TypingStream.NAME
]
self.assertEqual(
len(received_typing_rows),
1,
"Expected exactly one row for the typing stream",
)
(stream_name, token, row) = received_typing_rows[0]
self.assertEqual(stream_name, TypingStream.NAME)
self.assertIsInstance(row, TypingStream.ROW_TYPE)
self.assertEqual(row.room_id, ROOM_ID_2)
self.assertEqual(row.user_ids, [])
# The token should have been reset.
self.assertEqual(token, 1)
finally:

View File

@@ -110,13 +110,13 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.assertGreater(timestamp, 0)
# Test that users with reserved 3pids are not removed from the MAU table
# XXX some of this is redundant. poking things into the config shouldn't
# work, and in any case it's not obvious what we expect to happen when
# we advance the reactor.
self.hs.config.server.max_mau_value = 0
#
# The `start_phone_stats_home()` looping call will cause us to run
# `reap_monthly_active_users` after the time has advanced
self.reactor.advance(FORTY_DAYS)
self.hs.config.server.max_mau_value = 5
# I guess we call this one more time for good measure? Perhaps because
# previously, the phone home stats weren't running in tests?
self.get_success(self.store.reap_monthly_active_users())
active_count = self.get_success(self.store.get_monthly_active_count())

View File

@@ -75,7 +75,7 @@ class CommonMetricsTestCase(HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.metrics_manager = hs.get_common_usage_metrics_manager()
self.get_success(self.metrics_manager.setup())
self.metrics_manager.setup()
def test_dau(self) -> None:
"""Tests that the daily active users count is correctly updated."""