Compare commits


9 Commits

Author SHA1 Message Date
Half-Shot
059274de05 Merge remote-tracking branch 'origin/rav/state_stream_limit_assertion' into hs/sssh-testing-redis-things 2020-04-28 18:36:16 +01:00
Richard van der Hoff
8b6468cc15 changelog 2020-04-28 17:54:46 +01:00
Richard van der Hoff
3778424b0e Fix AssertionErrors being thrown by EventsStream
Part of the problem was that there was an off-by-one error in the assertion,
but also the limit logic was too simple. Fix it all up and add some tests.
2020-04-28 17:54:46 +01:00
Richard van der Hoff
b21490b656 Rework TestReplicationDataHandler
This wasn't very easy to work with: the mock wrapping was largely superfluous,
and it's useful to be able to inspect the received rows and clear out the
received list.
2020-04-28 17:54:46 +01:00
Richard van der Hoff
dc3e7e16fb Factor out functions for injecting events into database
I want to add some more flexibility to the tools for injecting events into the
database, and I don't want to clutter up HomeserverTestCase with them, so let's
factor them out to a new file.
2020-04-28 17:43:35 +01:00
Erik Johnston
0877107001 Add test 2020-04-27 15:29:37 +01:00
Erik Johnston
929dbb0fb3 Newsfile 2020-04-27 12:16:34 +01:00
Erik Johnston
659b6bec35 Don't relay REMOTE_SERVER_UP cmds to same conn.
For direct TCP connections we need the master to relay REMOTE_SERVER_UP
commands to the other connections so that all instances get notified
about it. The old implementation just relayed to all connections,
assuming that sending back to the original sender of the command was
safe. This is not true for redis, where commands sent get echoed back to
the sender, which caused the master to loop indefinitely, sending and
then re-receiving REMOTE_SERVER_UP commands that it had itself sent.

The fix is to ensure that we only relay to *other* connections and not
to the connection we received the notification from.

Fixes #7334.
2020-04-27 12:05:57 +01:00
Erik Johnston
70227045ba Pass connection to 'on_*' replication handler 2020-04-27 10:41:15 +01:00
19 changed files with 706 additions and 83 deletions

changelog.d/7352.feature

@@ -0,0 +1 @@
Add support for running replication over Redis when using workers.

changelog.d/7358.bugfix

@@ -0,0 +1 @@
Fix a bug where event updates might not be sent over replication to worker processes after the stream falls behind.

synapse/notifier.py

@@ -220,12 +220,6 @@ class Notifier(object):
"""
self.replication_callbacks.append(cb)
def add_remote_server_up_callback(self, cb: Callable[[str], None]):
"""Add a callback that will be called when synapse detects a server
has been
"""
self.remote_server_up_callbacks.append(cb)
def on_new_room_event(
self, event, room_stream_id, max_room_stream_id, extra_users=[]
):
@@ -544,6 +538,3 @@ class Notifier(object):
# circular dependencies.
if self.federation_sender:
self.federation_sender.wake_destination(server)
for cb in self.remote_server_up_callbacks:
cb(server)

synapse/replication/tcp/handler.py

@@ -115,7 +115,6 @@ class ReplicationCommandHandler:
self._server_notices_sender = None
if self._is_master:
self._server_notices_sender = hs.get_server_notices_sender()
self._notifier.add_remote_server_up_callback(self.send_remote_server_up)
def start_replication(self, hs):
"""Helper method to start a replication connection to the remote server
@@ -161,7 +160,7 @@ class ReplicationCommandHandler:
port = hs.config.worker_replication_port
hs.get_reactor().connectTCP(host, port, self._factory)
async def on_REPLICATE(self, cmd: ReplicateCommand):
async def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
# We only want to announce positions by the writer of the streams.
# Currently this is just the master process.
if not self._is_master:
@@ -171,7 +170,7 @@ class ReplicationCommandHandler:
current_token = stream.current_token()
self.send_command(PositionCommand(stream_name, current_token))
async def on_USER_SYNC(self, cmd: UserSyncCommand):
async def on_USER_SYNC(self, conn: AbstractConnection, cmd: UserSyncCommand):
user_sync_counter.inc()
if self._is_master:
@@ -179,17 +178,23 @@ class ReplicationCommandHandler:
cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
)
async def on_CLEAR_USER_SYNC(self, cmd: ClearUserSyncsCommand):
async def on_CLEAR_USER_SYNC(
self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
):
if self._is_master:
await self._presence_handler.update_external_syncs_clear(cmd.instance_id)
async def on_FEDERATION_ACK(self, cmd: FederationAckCommand):
async def on_FEDERATION_ACK(
self, conn: AbstractConnection, cmd: FederationAckCommand
):
federation_ack_counter.inc()
if self._federation_sender:
self._federation_sender.federation_ack(cmd.token)
async def on_REMOVE_PUSHER(self, cmd: RemovePusherCommand):
async def on_REMOVE_PUSHER(
self, conn: AbstractConnection, cmd: RemovePusherCommand
):
remove_pusher_counter.inc()
if self._is_master:
@@ -199,7 +204,9 @@ class ReplicationCommandHandler:
self._notifier.on_new_replication_data()
async def on_INVALIDATE_CACHE(self, cmd: InvalidateCacheCommand):
async def on_INVALIDATE_CACHE(
self, conn: AbstractConnection, cmd: InvalidateCacheCommand
):
invalidate_cache_counter.inc()
if self._is_master:
@@ -209,7 +216,7 @@ class ReplicationCommandHandler:
cmd.cache_func, tuple(cmd.keys)
)
async def on_USER_IP(self, cmd: UserIpCommand):
async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
user_ip_cache_counter.inc()
if self._is_master:
@@ -225,7 +232,7 @@ class ReplicationCommandHandler:
if self._server_notices_sender:
await self._server_notices_sender.on_user_ip(cmd.user_id)
async def on_RDATA(self, cmd: RdataCommand):
async def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
stream_name = cmd.stream_name
inbound_rdata_count.labels(stream_name).inc()
@@ -276,7 +283,7 @@ class ReplicationCommandHandler:
logger.debug("Received rdata %s -> %s", stream_name, token)
await self._replication_data_handler.on_rdata(stream_name, token, rows)
async def on_POSITION(self, cmd: PositionCommand):
async def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
stream = self._streams.get(cmd.stream_name)
if not stream:
logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
@@ -330,12 +337,30 @@ class ReplicationCommandHandler:
self._streams_connected.add(cmd.stream_name)
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
async def on_REMOTE_SERVER_UP(
self, conn: AbstractConnection, cmd: RemoteServerUpCommand
):
""""Called when get a new REMOTE_SERVER_UP command."""
self._replication_data_handler.on_remote_server_up(cmd.data)
if self._is_master:
self._notifier.notify_remote_server_up(cmd.data)
self._notifier.notify_remote_server_up(cmd.data)
# We relay to all other connections to ensure every instance gets the
# notification.
#
# When configured to use redis we'll always only have one connection and
# so this is a no-op (all instances will have already received the same
# REMOTE_SERVER_UP command).
#
# For direct TCP connections this will relay to all other connections
# connected to us. When on master this will correctly fan out to all
# other direct TCP clients and on workers there'll only be the one
# connection to master.
#
# (The logic here should also be sound if we have a mix of Redis and
# direct TCP connections so long as there is only one traffic route
# between two instances, but that is not currently supported).
self.send_command(cmd, ignore_conn=conn)
def new_connection(self, connection: AbstractConnection):
"""Called when we have a new connection.
@@ -380,11 +405,21 @@ class ReplicationCommandHandler:
"""
return bool(self._connections)
def send_command(self, cmd: Command):
def send_command(
self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None
):
"""Send a command to all connected connections.
Args:
cmd
ignore_conn: If set don't send command to the given connection.
Used when relaying commands from one connection to all others.
"""
if self._connections:
for connection in self._connections:
if connection == ignore_conn:
continue
try:
connection.send_command(cmd)
except Exception:
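That `ignore_conn` parameter is the heart of the #7334 fix. Here is a minimal, runnable sketch of the failure mode and the fix; `ToyConnection` and `ToyHandler` are illustrative stand-ins, not Synapse's real classes:

```python
from typing import List, Optional


class ToyConnection:
    """Stand-in for AbstractConnection: just records what it is sent."""

    def __init__(self, name: str):
        self.name = name
        self.sent = []  # type: List[str]

    def send_command(self, cmd: str) -> None:
        self.sent.append(cmd)


class ToyHandler:
    """Stand-in for ReplicationCommandHandler's relay logic."""

    def __init__(self, connections: List[ToyConnection]):
        self._connections = connections

    def send_command(
        self, cmd: str, ignore_conn: Optional[ToyConnection] = None
    ) -> None:
        for connection in self._connections:
            if connection == ignore_conn:
                # the fix: never relay back to the connection the command
                # arrived on. With redis, commands are echoed back to the
                # sender, so relaying to the originating connection would
                # loop forever.
                continue
            connection.send_command(cmd)

    def on_REMOTE_SERVER_UP(self, conn: ToyConnection, cmd: str) -> None:
        # relay to all *other* connections only
        self.send_command(cmd, ignore_conn=conn)


redis_conn = ToyConnection("redis")
tcp_conn = ToyConnection("direct-tcp")
handler = ToyHandler([redis_conn, tcp_conn])

handler.on_REMOTE_SERVER_UP(redis_conn, "REMOTE_SERVER_UP example.com")
assert redis_conn.sent == []  # nothing echoed back to the sender
assert tcp_conn.sent == ["REMOTE_SERVER_UP example.com"]
```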

synapse/replication/tcp/protocol.py

@@ -260,7 +260,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# Then call out to the handler.
cmd_func = getattr(self.command_handler, "on_%s" % (cmd.NAME,), None)
if cmd_func:
await cmd_func(cmd)
await cmd_func(self, cmd)
handled = True
if not handled:

synapse/replication/tcp/redis.py

@@ -112,7 +112,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
# Then call out to the handler.
cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
if cmd_func:
await cmd_func(cmd)
await cmd_func(self, cmd)
handled = True
if not handled:
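Both the TCP protocol and the redis subscriber dispatch commands the same way: look up an `on_<NAME>` coroutine on the handler and, after this change, pass the receiving connection as the first argument. A minimal sketch of that convention, using toy names rather than the real command classes:

```python
import asyncio


class PingCommand:
    """Toy command; the real commands define NAME the same way."""

    NAME = "PING"


class ToyCommandHandler:
    # one "on_<NAME>" coroutine per command; the protocol finds it via
    # getattr and passes itself (the connection) as the first argument
    async def on_PING(self, conn, cmd) -> None:
        print("got %s on %s" % (cmd.NAME, conn))


async def handle(handler, conn, cmd) -> None:
    cmd_func = getattr(handler, "on_%s" % (cmd.NAME,), None)
    if cmd_func:
        await cmd_func(conn, cmd)
    else:
        print("unhandled command: %s" % (cmd.NAME,))


asyncio.run(handle(ToyCommandHandler(), "conn-1", PingCommand()))
```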

synapse/replication/tcp/streams/events.py

@@ -176,16 +176,30 @@ class EventsStream(Stream):
from_token, upper_limit, target_row_count
) # type: List[Tuple]
# again, if we've hit the limit there, we'll need to limit the other sources
assert len(state_rows) < target_row_count
assert len(state_rows) <= target_row_count
# there can be more than one row per stream_id in that table, so if we hit
# the limit there, we'll need to truncate the results so that we have a complete
# set of changes for all the stream IDs we include.
if len(state_rows) == target_row_count:
assert state_rows[-1][0] <= upper_limit
upper_limit = state_rows[-1][0]
limited = True
upper_limit = state_rows[-1][0] - 1
# FIXME: is it a given that there is only one row per stream_id in the
# state_deltas table (so that we can be sure that we have got all of the
# rows for upper_limit)?
# search for the point to truncate the list
for idx in range(len(state_rows) - 1, 0, -1):
if state_rows[idx - 1][0] <= upper_limit:
state_rows = state_rows[:idx]
break
else:
# bother. We didn't get a full set of changes for even a single
# stream id. let's run the query again, without a row limit, but for
# just one stream id.
upper_limit += 1
state_rows = await self._store.get_all_updated_current_state_deltas(
from_token, upper_limit, limit=None
)
limited = True
# finally, fetch the ex-outliers rows. We assume there are few enough of these
# not to bother with the limit.
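The truncation logic above is the subtle part of the EventsStream fix: once we hit the row limit, we cannot know whether the last stream_id's rows are complete, so we drop back one stream_id and truncate. A toy re-run of that step with made-up rows:

```python
# Each row is a (stream_id, ...) tuple, as returned by the
# current_state_deltas query; note that stream_id 2 has two rows.
state_rows = [(1, "a"), (2, "b"), (2, "c"), (3, "d")]
target_row_count = 4

assert len(state_rows) <= target_row_count
if len(state_rows) == target_row_count:
    # we hit the limit, so the rows for the last stream_id may be
    # incomplete: drop back one stream_id
    upper_limit = state_rows[-1][0] - 1  # -> 2

    # search backwards for the point to truncate the list
    for idx in range(len(state_rows) - 1, 0, -1):
        if state_rows[idx - 1][0] <= upper_limit:
            state_rows = state_rows[:idx]
            break
    # (the real code's else-branch re-runs the query with no row limit
    # when not even one stream_id's rows fit within the target)

print(state_rows)   # [(1, 'a'), (2, 'b'), (2, 'c')] -- only complete sets
print(upper_limit)  # 2
```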

synapse/server.pyi

@@ -25,6 +25,7 @@ import synapse.server_notices.server_notices_manager
import synapse.server_notices.server_notices_sender
import synapse.state
import synapse.storage
from synapse.events.builder import EventBuilderFactory
class HomeServer(object):
@property
@@ -121,3 +122,7 @@ class HomeServer(object):
pass
def get_instance_id(self) -> str:
pass
def get_event_builder_factory(self) -> EventBuilderFactory:
pass
def get_storage(self) -> synapse.storage.Storage:
pass

synapse/storage/data_stores/main/events_worker.py

@@ -1084,15 +1084,23 @@ class EventsWorkerStore(SQLBaseStore):
"get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
)
def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
def get_all_updated_current_state_deltas(
self, from_token: int, to_token: int, limit: Optional[int]
):
def get_all_updated_current_state_deltas_txn(txn):
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC LIMIT ?
ORDER BY stream_id ASC
"""
txn.execute(sql, (from_token, to_token, limit))
params = [from_token, to_token]
if limit is not None:
sql += "LIMIT ?"
params.append(limit)
txn.execute(sql, params)
return txn.fetchall()
return self.db.runInteraction(

tests/replication/tcp/streams/_base.py

@@ -12,10 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Optional
from mock import Mock
import logging
from typing import Any, Dict, List, Optional, Tuple
import attr
@@ -25,6 +24,7 @@ from twisted.web.http import HTTPChannel
from synapse.app.generic_worker import GenericWorkerServer
from synapse.http.site import SynapseRequest
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.tcp.client import ReplicationDataHandler
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
@@ -65,9 +65,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# databases objects are the same.
self.worker_hs.get_datastore().db = hs.get_datastore().db
self.test_handler = Mock(
wraps=TestReplicationDataHandler(self.worker_hs.get_datastore())
)
self.test_handler = self._build_replication_data_handler()
self.worker_hs.replication_data_handler = self.test_handler
repl_handler = ReplicationCommandHandler(self.worker_hs)
@@ -78,6 +76,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self._client_transport = None
self._server_transport = None
def _build_replication_data_handler(self):
return TestReplicationDataHandler(self.worker_hs.get_datastore())
def reconnect(self):
if self._client_transport:
self.client.close()
@@ -174,22 +175,28 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
class TestReplicationDataHandler(ReplicationDataHandler):
"""Drop-in for ReplicationDataHandler which just collects RDATA rows"""
def __init__(self, hs):
super().__init__(hs)
self.streams = set()
self._received_rdata_rows = []
def __init__(self, store: BaseSlavedStore):
super().__init__(store)
# streams to subscribe to: map from stream id to position
self.stream_positions = {} # type: Dict[str, int]
# list of received (stream_name, token, row) tuples
self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]]
def get_streams_to_replicate(self):
positions = {s: 0 for s in self.streams}
for stream, token, _ in self._received_rdata_rows:
if stream in self.streams:
positions[stream] = max(token, positions.get(stream, 0))
return positions
return self.stream_positions
async def on_rdata(self, stream_name, token, rows):
await super().on_rdata(stream_name, token, rows)
for r in rows:
self._received_rdata_rows.append((stream_name, token, r))
self.received_rdata_rows.append((stream_name, token, r))
if (
stream_name in self.stream_positions
and token > self.stream_positions[stream_name]
):
self.stream_positions[stream_name] = token
@attr.s()
@@ -221,7 +228,7 @@ class _PushHTTPChannel(HTTPChannel):
super().__init__()
self.reactor = reactor
self._pull_to_push_producer = None
self._pull_to_push_producer = None # type: Optional[_PullToPushProducer]
def registerProducer(self, producer, streaming):
# Convert pull producers to push producer.

tests/replication/tcp/streams/test_events.py

@@ -0,0 +1,390 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT
from synapse.replication.tcp.streams.events import (
EventsStreamCurrentStateRow,
EventsStreamEventRow,
EventsStreamRow,
)
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from tests.replication.tcp.streams._base import BaseStreamTestCase
from tests.test_utils.event_injection import inject_event, inject_member_event
class EventsStreamTestCase(BaseStreamTestCase):
servlets = [
admin.register_servlets,
login.register_servlets,
room.register_servlets,
]
def prepare(self, reactor, clock, hs):
super().prepare(reactor, clock, hs)
self.user_id = self.register_user("u1", "pass")
self.user_tok = self.login("u1", "pass")
self.reconnect()
self.test_handler.stream_positions["events"] = 0
self.room_id = self.helper.create_room_as(tok=self.user_tok)
self.test_handler.received_rdata_rows.clear()
def test_update_function_event_row_limit(self):
"""Test replication with many non-state events
Checks that all events are correctly replicated when there are lots of
event rows to be replicated.
"""
# generate lots of non-state events. We inject them using inject_event
# so that they are not sent out over replication until we call self.replicate().
events = [
self._inject_test_event()
for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 1)
]
# also one state event
state_event = self._inject_state_event()
# check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual([], self.test_handler.received_rdata_rows)
# now fire up the replicator
self.replicate()
# we should have received all the expected rows in the right order
received_rows = self.test_handler.received_rdata_rows
for event in events:
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "ev")
self.assertIsInstance(row.data, EventsStreamEventRow)
self.assertEqual(row.data.event_id, event.event_id)
stream_name, token, row = received_rows.pop(0)
self.assertIsInstance(row, EventsStreamRow)
self.assertIsInstance(row.data, EventsStreamEventRow)
self.assertEqual(row.data.event_id, state_event.event_id)
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "state")
self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
self.assertEqual(row.data.event_id, state_event.event_id)
self.assertEqual([], received_rows)
def test_update_function_huge_state_change(self):
"""Test replication with many state events
Ensures that all events are correctly replicated when there are lots of
state change rows to be replicated.
"""
# we want to generate lots of state changes at a single stream ID.
#
# We do this by having two branches in the DAG. On one, we have a moderator
# who generates lots of state; on the other, we de-op the moderator,
# thus invalidating all the state.
OTHER_USER = "@other_user:localhost"
# have the user join
inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN)
# Update existing power levels with mod at PL50
pls = self.helper.get_state(
self.room_id, EventTypes.PowerLevels, tok=self.user_tok
)
pls["users"][OTHER_USER] = 50
self.helper.send_state(
self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok,
)
# this is the point in the DAG where we make a fork
fork_point = self.get_success(
self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id)
) # type: List[str]
events = [
self._inject_state_event(sender=OTHER_USER)
for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT)
]
self.replicate()
# all those events and state changes should have landed
self.assertGreaterEqual(
len(self.test_handler.received_rdata_rows), 2 * len(events)
)
self.test_handler.received_rdata_rows.clear()
# a state event which doesn't get rolled back, to check that the state
# before the huge update comes through ok
state1 = self._inject_state_event()
# roll back all the state by de-modding the user
prev_events = fork_point
pls["users"][OTHER_USER] = 0
pl_event = inject_event(
self.hs,
prev_event_ids=prev_events,
type=EventTypes.PowerLevels,
state_key="",
sender=self.user_id,
room_id=self.room_id,
content=pls,
)
# one more bit of state that doesn't get rolled back
state2 = self._inject_state_event()
# check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual([], self.test_handler.received_rdata_rows)
# now fire up the replicator
self.replicate()
# now we should have received all the expected rows in the right order.
#
# we expect:
#
# - two rows for state1
# - the PL event row, plus state rows for the PL event and each
# of the states that got reverted.
# - two rows for state2
received_rows = self.test_handler.received_rdata_rows
# first check the first two rows, which should be state1
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "ev")
self.assertIsInstance(row.data, EventsStreamEventRow)
self.assertEqual(row.data.event_id, state1.event_id)
stream_name, token, row = received_rows.pop(0)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "state")
self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
self.assertEqual(row.data.event_id, state1.event_id)
# now the last two rows, which should be state2
stream_name, token, row = received_rows.pop(-2)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "ev")
self.assertIsInstance(row.data, EventsStreamEventRow)
self.assertEqual(row.data.event_id, state2.event_id)
stream_name, token, row = received_rows.pop(-1)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "state")
self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
self.assertEqual(row.data.event_id, state2.event_id)
# that should leave us with the rows for the PL event
self.assertEqual(len(received_rows), len(events) + 2)
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "ev")
self.assertIsInstance(row.data, EventsStreamEventRow)
self.assertEqual(row.data.event_id, pl_event.event_id)
# the state rows are unsorted
state_rows = [] # type: List[EventsStreamCurrentStateRow]
for stream_name, token, row in received_rows:
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "state")
self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
state_rows.append(row.data)
state_rows.sort(key=lambda r: r.state_key)
sr = state_rows.pop(0)
self.assertEqual(sr.type, EventTypes.PowerLevels)
self.assertEqual(sr.event_id, pl_event.event_id)
for sr in state_rows:
self.assertEqual(sr.type, "test_state_event")
# "None" indicates the state has been deleted
self.assertIsNone(sr.event_id)
def test_update_function_state_row_limit(self):
"""Test replication with many state events over several stream ids.
"""
# we want to generate lots of state changes, but for this test, we want to
# spread out the state changes over a few stream IDs.
#
# We do this by having two branches in the DAG. On one, we have four moderators,
# each of whom generates lots of state; on the other, we de-op the users,
# thus invalidating all the state.
NUM_USERS = 4
STATES_PER_USER = _STREAM_UPDATE_TARGET_ROW_COUNT // 4 + 1
user_ids = ["@user%i:localhost" % (i,) for i in range(NUM_USERS)]
# have the users join
for u in user_ids:
inject_member_event(self.hs, self.room_id, u, Membership.JOIN)
# Update existing power levels with mod at PL50
pls = self.helper.get_state(
self.room_id, EventTypes.PowerLevels, tok=self.user_tok
)
pls["users"].update({u: 50 for u in user_ids})
self.helper.send_state(
self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok,
)
# this is the point in the DAG where we make a fork
fork_point = self.get_success(
self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id)
) # type: List[str]
events = [] # type: List[EventBase]
for user in user_ids:
events.extend(
self._inject_state_event(sender=user) for _ in range(STATES_PER_USER)
)
self.replicate()
# all those events and state changes should have landed
self.assertGreaterEqual(
len(self.test_handler.received_rdata_rows), 2 * len(events)
)
self.test_handler.received_rdata_rows.clear()
# now roll back all that state by de-modding the users
prev_events = fork_point
pl_events = []
for u in user_ids:
pls["users"][u] = 0
e = inject_event(
self.hs,
prev_event_ids=prev_events,
type=EventTypes.PowerLevels,
state_key="",
sender=self.user_id,
room_id=self.room_id,
content=pls,
)
prev_events = [e.event_id]
pl_events.append(e)
# check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual([], self.test_handler.received_rdata_rows)
# now fire up the replicator
self.replicate()
# we should have received all the expected rows in the right order
received_rows = self.test_handler.received_rdata_rows
self.assertGreaterEqual(len(received_rows), len(events))
for i in range(NUM_USERS):
# for each user, we expect the PL event row, followed by state rows for
# the PL event and each of the states that got reverted.
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "ev")
self.assertIsInstance(row.data, EventsStreamEventRow)
self.assertEqual(row.data.event_id, pl_events[i].event_id)
# the state rows are unsorted
state_rows = [] # type: List[EventsStreamCurrentStateRow]
for j in range(STATES_PER_USER + 1):
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "state")
self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
state_rows.append(row.data)
state_rows.sort(key=lambda r: r.state_key)
sr = state_rows.pop(0)
self.assertEqual(sr.type, EventTypes.PowerLevels)
self.assertEqual(sr.event_id, pl_events[i].event_id)
for sr in state_rows:
self.assertEqual(sr.type, "test_state_event")
# "None" indicates the state has been deleted
self.assertIsNone(sr.event_id)
self.assertEqual([], received_rows)
event_count = 0
def _inject_test_event(
self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs
) -> EventBase:
if sender is None:
sender = self.user_id
if body is None:
body = "event %i" % (self.event_count,)
self.event_count += 1
return inject_event(
self.hs,
room_id=self.room_id,
sender=sender,
type="test_event",
content={"body": body},
**kwargs
)
def _inject_state_event(
self,
body: Optional[str] = None,
state_key: Optional[str] = None,
sender: Optional[str] = None,
) -> EventBase:
if sender is None:
sender = self.user_id
if state_key is None:
state_key = "state_%i" % (self.event_count,)
self.event_count += 1
if body is None:
body = "state event %s" % (state_key,)
return inject_event(
self.hs,
room_id=self.room_id,
sender=sender,
type="test_state_event",
state_key=state_key,
content={"body": body},
)

tests/replication/tcp/streams/test_receipts.py

@@ -12,6 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# type: ignore
from mock import Mock
from synapse.replication.tcp.streams._base import ReceiptsStream
from tests.replication.tcp.streams._base import BaseStreamTestCase
@@ -20,11 +25,14 @@ USER_ID = "@feeling:blue"
class ReceiptsStreamTestCase(BaseStreamTestCase):
def _build_replication_data_handler(self):
return Mock(wraps=super()._build_replication_data_handler())
def test_receipt(self):
self.reconnect()
# make the client subscribe to the receipts stream
self.test_handler.streams.add("receipts")
self.test_handler.stream_positions.update({"receipts": 0})
# tell the master to send a new receipt
self.get_success(

tests/replication/tcp/streams/test_typing.py

@@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import Mock
from synapse.handlers.typing import RoomMember
from synapse.replication.http import streams
from synapse.replication.tcp.streams import TypingStream
@@ -26,6 +28,9 @@ class TypingStreamTestCase(BaseStreamTestCase):
streams.register_servlets,
]
def _build_replication_data_handler(self):
return Mock(wraps=super()._build_replication_data_handler())
def test_typing(self):
typing = self.hs.get_typing_handler()
@@ -33,8 +38,8 @@ class TypingStreamTestCase(BaseStreamTestCase):
self.reconnect()
# make the client subscribe to the receipts stream
self.test_handler.streams.add("typing")
# make the client subscribe to the typing stream
self.test_handler.stream_positions.update({"typing": 0})
typing._push_update(member=RoomMember(room_id, USER_ID), typing=True)
@@ -75,6 +80,6 @@ class TypingStreamTestCase(BaseStreamTestCase):
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0] # type: TypingStream.TypingStreamRow
row = rdata_rows[0]
self.assertEqual(room_id, row.room_id)
self.assertEqual([], row.user_ids)

tests/replication/tcp/test_remote_server_up.py

@@ -0,0 +1,62 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
from twisted.internet.interfaces import IProtocol
from twisted.test.proto_helpers import StringTransport
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from tests.unittest import HomeserverTestCase
class RemoteServerUpTestCase(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.factory = ReplicationStreamProtocolFactory(hs)
def _make_client(self) -> Tuple[IProtocol, StringTransport]:
"""Create a new direct TCP replication connection
"""
proto = self.factory.buildProtocol(("127.0.0.1", 0))
transport = StringTransport()
proto.makeConnection(transport)
# We can safely ignore the commands received during connection.
self.pump()
transport.clear()
return proto, transport
def test_relay(self):
"""Test that Synapse will relay REMOTE_SERVER_UP commands to all
other connections, but not the one that sent it.
"""
proto1, transport1 = self._make_client()
# We shouldn't receive an echo.
proto1.dataReceived(b"REMOTE_SERVER_UP example.com\n")
self.pump()
self.assertEqual(transport1.value(), b"")
# But we should see an echo if we connect another client
proto2, transport2 = self._make_client()
proto1.dataReceived(b"REMOTE_SERVER_UP example.com\n")
self.pump()
self.assertEqual(transport1.value(), b"")
self.assertEqual(transport2.value(), b"REMOTE_SERVER_UP example.com\n")

tests/rest/client/v1/utils.py

@@ -39,7 +39,7 @@ class RestHelper(object):
resource = attr.ib()
auth_user_id = attr.ib()
def create_room_as(self, room_creator, is_public=True, tok=None):
def create_room_as(self, room_creator=None, is_public=True, tok=None):
temp_id = self.auth_user_id
self.auth_user_id = room_creator
path = "/_matrix/client/r0/createRoom"

tests/test_utils/__init__.py

@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
# Copyright 2020 The Matrix.org Foundation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,3 +17,22 @@
"""
Utilities for running the unit tests
"""
from typing import Awaitable, TypeVar
TV = TypeVar("TV")
def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
"""Get the result from an Awaitable which should have completed
Asserts that the given awaitable has a result ready, and returns its value
"""
i = awaitable.__await__()
try:
next(i)
except StopIteration as e:
# awaitable returned a result
return e.value
# if next didn't raise, the awaitable hasn't completed.
raise Exception("awaitable has not yet completed")
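get_awaitable_result only works on an awaitable that has already completed without suspending; a quick illustration (assumes it is run from the repository root so the tests package import resolves):

```python
from tests.test_utils import get_awaitable_result


async def already_done() -> int:
    # never awaits anything, so the result is ready on the first step
    return 42


assert get_awaitable_result(already_done()) == 42

# an awaitable that is still waiting on something (e.g. an unfired
# Deferred) would instead raise "awaitable has not yet completed"
```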

tests/test_utils/event_injection.py

@@ -0,0 +1,96 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
# Copyright 2020 The Matrix.org Foundation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import synapse.server
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.types import Collection
from tests.test_utils import get_awaitable_result
"""
Utility functions for poking events into the storage of the server under test.
"""
def inject_member_event(
hs: synapse.server.HomeServer,
room_id: str,
sender: str,
membership: str,
target: Optional[str] = None,
extra_content: Optional[dict] = None,
**kwargs
) -> EventBase:
"""Inject a membership event into a room."""
if target is None:
target = sender
content = {"membership": membership}
if extra_content:
content.update(extra_content)
return inject_event(
hs,
room_id=room_id,
type=EventTypes.Member,
sender=sender,
state_key=target,
content=content,
**kwargs
)
def inject_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
prev_event_ids: Optional[Collection[str]] = None,
**kwargs
) -> EventBase:
"""Inject a generic event into a room
Args:
hs: the homeserver under test
room_version: the version of the room we're inserting into.
if not specified, will be looked up
prev_event_ids: prev_events for the event. If not specified, will be looked up
kwargs: fields for the event to be created
"""
test_reactor = hs.get_reactor()
if room_version is None:
d = hs.get_datastore().get_room_version_id(kwargs["room_id"])
test_reactor.advance(0)
room_version = get_awaitable_result(d)
builder = hs.get_event_builder_factory().for_room_version(
KNOWN_ROOM_VERSIONS[room_version], kwargs
)
d = hs.get_event_creation_handler().create_new_client_event(
builder, prev_event_ids=prev_event_ids
)
test_reactor.advance(0)
event, context = get_awaitable_result(d)
d = hs.get_storage().persistence.persist_event(event, context)
test_reactor.advance(0)
get_awaitable_result(d)
return event
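A sketch of how the new helpers are meant to be called from a HomeserverTestCase, mirroring the events-stream tests above (the test class and the user/event values here are illustrative, not part of the changeset):

```python
from synapse.api.constants import Membership
from synapse.rest import admin
from synapse.rest.client.v1 import login, room

from tests.test_utils.event_injection import inject_event, inject_member_event
from tests.unittest import HomeserverTestCase


class EventInjectionExampleTestCase(HomeserverTestCase):
    """Illustrative only: shows how the injection helpers are called."""

    servlets = [
        admin.register_servlets,
        login.register_servlets,
        room.register_servlets,
    ]

    def test_inject(self):
        user_id = self.register_user("u1", "pass")
        tok = self.login("u1", "pass")
        room_id = self.helper.create_room_as(tok=tok)

        # a plain (non-state) event, persisted directly to the database
        event = inject_event(
            self.hs,
            room_id=room_id,
            sender=user_id,
            type="test_event",
            content={"body": "hello"},
        )
        self.assertIsNotNone(event.event_id)

        # membership changes have their own wrapper
        inject_member_event(self.hs, room_id, "@other:localhost", Membership.JOIN)
```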

tests/unittest.py

@@ -32,7 +32,6 @@ from twisted.python.threadpool import ThreadPool
from twisted.trial import unittest
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.config.homeserver import HomeServerConfig
from synapse.config.ratelimiting import FederationRateLimitConfig
from synapse.federation.transport import server as federation_server
@@ -55,6 +54,7 @@ from tests.server import (
render,
setup_test_homeserver,
)
from tests.test_utils import event_injection
from tests.test_utils.logging_setup import setup_logging
from tests.utils import default_config, setupdb
@@ -596,36 +596,14 @@ class HomeserverTestCase(TestCase):
"""
Inject a membership event into a room.
Deprecated: use event_injection.inject_member_event directly
Args:
room: Room ID to inject the event into.
user: MXID of the user to inject the membership for.
membership: The membership type.
"""
event_builder_factory = self.hs.get_event_builder_factory()
event_creation_handler = self.hs.get_event_creation_handler()
room_version = self.get_success(
self.hs.get_datastore().get_room_version_id(room)
)
builder = event_builder_factory.for_room_version(
KNOWN_ROOM_VERSIONS[room_version],
{
"type": EventTypes.Member,
"sender": user,
"state_key": user,
"room_id": room,
"content": {"membership": membership},
},
)
event, context = self.get_success(
event_creation_handler.create_new_client_event(builder)
)
self.get_success(
self.hs.get_storage().persistence.persist_event(event, context)
)
event_injection.inject_member_event(self.hs, room, user, membership)
class FederatingHomeserverTestCase(HomeserverTestCase):

tox.ini

@@ -204,6 +204,8 @@ commands = mypy \
synapse/storage/database.py \
synapse/streams \
synapse/util/caches/stream_change_cache.py \
tests/replication/tcp/streams \
tests/test_utils \
tests/util/test_stream_change_cache.py
# To find all folders that pass mypy you run: