Compare commits

36 Commits

Author SHA1 Message Date
Erik Johnston
83ecaeecbf dkjfhsdklfhsdlkjf 2020-03-25 14:55:02 +00:00
Erik Johnston
0473f87a17 Pass instance name through to rdata 2020-03-25 14:05:53 +00:00
Erik Johnston
092b62ee7b fixup! Thread through instance name to replication client 2020-03-25 11:41:38 +00:00
Erik Johnston
b6f6f5c399 Add replication listeners to wall workers 2020-03-25 11:34:56 +00:00
Erik Johnston
f7da931d62 PEP8 ??? 2020-03-25 11:34:43 +00:00
Erik Johnston
9f15bffd72 Thread through instance name to replication client 2020-03-25 11:34:10 +00:00
Erik Johnston
6da24f2d5f Merge branch 'erikj/catchup_on_worker' of github.com:matrix-org/synapse into erikj/split_out_fed_stream 2020-03-25 10:55:23 +00:00
Erik Johnston
5473f1806a Change stream_positions to include instance name 2020-03-25 10:51:46 +00:00
Erik Johnston
f6e7daaac3 Add instance name to command 2020-03-25 10:21:22 +00:00
Erik Johnston
309c7eb1a1 Add some type aliases 2020-03-24 17:43:42 +00:00
Erik Johnston
f8038f4670 Fix HTTP update_function 2020-03-24 17:31:51 +00:00
Erik Johnston
9ea391054f DFSDJFDSLKF 2020-03-24 17:27:50 +00:00
Erik Johnston
604f57f1bd Merge branch 'erikj/catchup_on_worker' into erikj/split_out_typing 2020-03-24 17:21:26 +00:00
Erik Johnston
bd64b8fcd5 Fixup push rules stream 2020-03-24 16:52:17 +00:00
Erik Johnston
309aee4636 Move calling http replication out of base stream 2020-03-24 16:20:05 +00:00
Erik Johnston
e4c5b1d9d6 Review comments 2020-03-24 16:00:54 +00:00
Erik Johnston
7eec84bfbe Shuffle around code typing handlers 2020-03-24 15:54:38 +00:00
Erik Johnston
4dd08f2501 Make ReplicationStreamer work on workers 2020-03-24 15:53:52 +00:00
Erik Johnston
55dfcd2f09 Add redis support 2020-03-24 15:04:18 +00:00
Erik Johnston
11fb08ffa9 mypy 2020-03-24 15:03:59 +00:00
Erik Johnston
ef4f063687 Move command processing out of transport 2020-03-24 14:17:18 +00:00
Erik Johnston
2380e401e4 Remove import loop 2020-03-24 11:47:57 +00:00
Erik Johnston
5d810c36a8 mypy 2020-03-24 10:06:15 +00:00
Erik Johnston
ea17e939df Add CLEAR_USER_SYNCS command that is sent on shutdown.
This should help with the case where a synchrotron gets restarted
gracefully, rather than rely on 5 minute timeout.
2020-03-23 18:55:58 +00:00
Erik Johnston
225b993cf6 Remove conn_id usage for UserSyncCommand.
Each tcp replication connection is assigned a "conn_id", which is used
to give an ID to a remotely connected worker. In a redis world, there
will no longer be a one to one mapping between connection and instance,
so instead we need to replace such usages with an ID generated by the
remote instances and included in the replication commands.

This really only affects UserSyncCommand.
2020-03-23 18:52:24 +00:00
Erik Johnston
3204b0e79f Handle connection closing under us 2020-03-23 18:29:21 +00:00
Erik Johnston
ba1a8be930 Review comments 2020-03-23 16:13:12 +00:00
Erik Johnston
a2070a2c4e Remove unused 'stream' param of REPLICATE and update docs 2020-03-23 14:56:22 +00:00
Erik Johnston
4f2a803c66 Merge branch 'develop' of github.com:matrix-org/synapse into erikj/catchup_on_worker 2020-03-23 14:49:07 +00:00
Erik Johnston
259cdffa96 Newsfile 2020-03-20 15:31:53 +00:00
Erik Johnston
32c656865a Always subscribe to all streams.
This already happens since the worker merge.
2020-03-20 15:31:52 +00:00
Erik Johnston
8734b75ca8 Remove unused token param from REPLICATE cmd 2020-03-20 15:31:51 +00:00
Erik Johnston
1f83255de1 Move stream catchup to workers. 2020-03-20 15:31:49 +00:00
Erik Johnston
ba90596687 Add ability to catchup on stream by talking to master. 2020-03-20 15:31:47 +00:00
Erik Johnston
811d2ecf2e Don't panic if streams get behind.
The catchup will in future happen on workers, so master process won't
need to protect itself by dropping the connection.
2020-03-20 15:31:45 +00:00
Erik Johnston
7233d38690 Move stream fetch DB queries to worker stores. 2020-03-20 15:31:43 +00:00
49 changed files with 1536 additions and 1176 deletions

1
changelog.d/7024.misc Normal file
@@ -0,0 +1 @@
Move catchup of replication streams logic to worker.

@@ -14,16 +14,16 @@ example flow would be (where '>' indicates master to worker and
'<' worker to master flows):
> SERVER example.com
< REPLICATE events 53
< REPLICATE
> POSITION events 53
> RDATA events 54 ["$foo1:bar.com", ...]
> RDATA events 55 ["$foo4:bar.com", ...]
The example shows the server accepting a new connection and sending its
identity with the `SERVER` command, followed by the client asking to
subscribe to the `events` stream from the token `53`. The server then
periodically sends `RDATA` commands which have the format
`RDATA <stream_name> <token> <row>`, where the format of `<row>` is
defined by the individual streams.
The example shows the server accepting a new connection and sending its identity
with the `SERVER` command, followed by the client asking the server to respond with the
position of all streams. The server then periodically sends `RDATA` commands
which have the format `RDATA <stream_name> <token> <row>`, where the format of
`<row>` is defined by the individual streams.
Error reporting happens by either the client or server sending an ERROR
command, and usually the connection will be closed.
@@ -32,9 +32,6 @@ Since the protocol is a simple line based, its possible to manually
connect to the server using a tool like netcat. A few things should be
noted when manually using the protocol:
- When subscribing to a stream using `REPLICATE`, the special token
`NOW` can be used to get all future updates. The special stream name
`ALL` can be used with `NOW` to subscribe to all available streams.
- The federation stream is only available if federation sending has
been disabled on the main process.
- The server will only time connections out that have sent a `PING`
@@ -91,9 +88,7 @@ The client:
- Sends a `NAME` command, allowing the server to associate a human
friendly name with the connection. This is optional.
- Sends a `PING` as above
- For each stream the client wishes to subscribe to it sends a
`REPLICATE` with the `stream_name` and token it wants to subscribe
from.
- Sends a `REPLICATE` to get the current position of all streams.
- On receipt of a `SERVER` command, checks that the server name
matches the expected server name.
@@ -140,9 +135,7 @@ the wire:
> PING 1490197665618
< NAME synapse.app.appservice
< PING 1490197665618
< REPLICATE events 1
< REPLICATE backfill 1
< REPLICATE caches 1
< REPLICATE
> POSITION events 1
> POSITION backfill 1
> POSITION caches 1
@@ -181,9 +174,9 @@ client (C):
#### POSITION (S)
The position of the stream has been updated. Sent to the client
after all missing updates for a stream have been sent to the client
and they're now up to date.
On receipt of a POSITION command clients should check if they have missed any
updates, and if so then fetch them out of band. Sent in response to a
REPLICATE command (but can happen at any time).
#### ERROR (S, C)
@@ -199,25 +192,16 @@ client (C):
#### REPLICATE (C)
Asks the server to replicate a given stream. The syntax is:
```
REPLICATE <stream_name> <token>
```
Where `<token>` may be either:
* a numeric stream_id to stream updates since (exclusive)
* `NOW` to stream all subsequent updates.
The `<stream_name>` is the name of a replication stream to subscribe
to (see [here](../synapse/replication/tcp/streams/_base.py) for a list
of streams). It can also be `ALL` to subscribe to all known streams,
in which case the `<token>` must be set to `NOW`.
Asks the server for the current position of all streams.
#### USER_SYNC (C)
A user has started or stopped syncing
#### CLEAR_USER_SYNC (C)
The server should clear all associated user sync data from the worker.
#### FEDERATION_ACK (C)
Acknowledge receipt of some federation data
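
The line-based format documented above can be illustrated with a small parser. This is a minimal sketch, not Synapse's own command classes: `ParsedCommand` and `parse_line` are hypothetical names, and it assumes only the `RDATA <stream_name> <token> <row>` and `POSITION <stream_name> <token>` layouts shown in the example flow, with a numeric token and a JSON-encoded row.

```python
import json
from typing import Any, NamedTuple, Optional


class ParsedCommand(NamedTuple):
    """Hypothetical container for one parsed replication line."""

    name: str
    stream: Optional[str]
    token: Optional[int]
    row: Any


def parse_line(line: str) -> ParsedCommand:
    """Split a replication line into its command name and arguments."""
    name, _, rest = line.strip().partition(" ")
    if name == "RDATA":
        # RDATA <stream_name> <token> <row>; the row is assumed to be JSON.
        stream, token, row_json = rest.split(" ", 2)
        return ParsedCommand(name, stream, int(token), json.loads(row_json))
    if name == "POSITION":
        # POSITION <stream_name> <token>
        stream, token = rest.split(" ", 1)
        return ParsedCommand(name, stream, int(token), None)
    # Any other command (e.g. a bare REPLICATE) keeps its raw argument string.
    return ParsedCommand(name, None, None, rest or None)
```

With the new protocol a client would send a bare `REPLICATE` and then compare each `POSITION` reply against the token it last processed for that stream.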

@@ -75,3 +75,6 @@ ignore_missing_imports = True
[mypy-jwt.*]
ignore_missing_imports = True
[mypy-txredisapi]
ignore_missing_imports = True

@@ -45,6 +45,7 @@ from synapse.http.site import SynapseSite
from synapse.logging.context import LoggingContext, run_in_background
from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
from synapse.replication.slave.storage._base import BaseSlavedStore, __func__
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
@@ -64,7 +65,9 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.transactions import SlavedTransactionStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.replication.tcp.client import ReplicationClientFactory
from synapse.replication.tcp.commands import ClearUserSyncsCommand
from synapse.replication.tcp.handler import ReplicationDataHandler
from synapse.replication.tcp.streams import (
AccountDataStream,
DeviceListsStream,
@@ -75,7 +78,6 @@ from synapse.replication.tcp.streams import (
ReceiptsStream,
TagAccountDataStream,
ToDeviceStream,
TypingStream,
)
from synapse.replication.tcp.streams.events import (
EventsStream,
@@ -104,6 +106,7 @@ from synapse.rest.client.v1.room import (
RoomSendEventRestServlet,
RoomStateEventRestServlet,
RoomStateRestServlet,
RoomTypingRestServlet,
)
from synapse.rest.client.v1.voip import VoipRestServlet
from synapse.rest.client.v2_alpha import groups, sync, user_directory
@@ -124,7 +127,6 @@ from synapse.types import ReadReceipt
from synapse.util.async_helpers import Linearizer
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.stringutils import random_string
from synapse.util.versionstring import get_version_string
logger = logging.getLogger("synapse.app.generic_worker")
@@ -233,6 +235,7 @@ class GenericWorkerPresence(object):
self.user_to_num_current_syncs = {}
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
self.instance_id = hs.get_instance_id()
active_presence = self.store.take_presence_startup_info()
self.user_to_current_state = {state.user_id: state for state in active_presence}
@@ -245,13 +248,24 @@ class GenericWorkerPresence(object):
self.send_stop_syncing, UPDATE_SYNCING_USERS_MS
)
self.process_id = random_string(16)
logger.info("Presence process_id is %r", self.process_id)
hs.get_reactor().addSystemEventTrigger(
"before",
"shutdown",
run_as_background_process,
"generic_presence.on_shutdown",
self._on_shutdown,
)
def _on_shutdown(self):
if self.hs.config.use_presence:
self.hs.get_tcp_replication().send_command(
ClearUserSyncsCommand(self.instance_id)
)
def send_user_sync(self, user_id, is_syncing, last_sync_ms):
if self.hs.config.use_presence:
self.hs.get_tcp_replication().send_user_sync(
user_id, is_syncing, last_sync_ms
self.instance_id, user_id, is_syncing, last_sync_ms
)
def mark_as_coming_online(self, user_id):
@@ -368,40 +382,6 @@ class GenericWorkerPresence(object):
return set()
class GenericWorkerTyping(object):
def __init__(self, hs):
self._latest_room_serial = 0
self._reset()
def _reset(self):
"""
Reset the typing handler's data caches.
"""
# map room IDs to serial numbers
self._room_serials = {}
# map room IDs to sets of users currently typing
self._room_typing = {}
def stream_positions(self):
# We must update this typing token from the response of the previous
# sync. In particular, the stream id may "reset" back to zero/a low
# value which we *must* use for the next replication request.
return {"typing": self._latest_room_serial}
def process_replication_rows(self, token, rows):
if self._latest_room_serial > token:
# The master has gone backwards. To prevent inconsistent data, just
# clear everything.
self._reset()
# Set the latest serial token to whatever the server gave us.
self._latest_room_serial = token
for row in rows:
self._room_serials[row.room_id] = token
self._room_typing[row.room_id] = row.user_ids
class GenericWorkerSlavedStore(
# FIXME(#3714): We need to add UserDirectoryStore as we write directly
# rather than going via the correct worker.
@@ -486,6 +466,7 @@ class GenericWorkerServer(HomeServer):
ProfileDisplaynameRestServlet(self).register(resource)
ProfileRestServlet(self).register(resource)
KeyUploadServlet(self).register(resource)
RoomTypingRestServlet(self).register(resource)
sync.register_servlets(self, resource)
events.register_servlets(self, resource)
@@ -541,6 +522,9 @@ class GenericWorkerServer(HomeServer):
if name in ["keys", "federation"]:
resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self)
if name == "replication":
resources[REPLICATION_PREFIX] = ReplicationRestResource(self)
root_resource = create_resource_tree(resources, NoResource())
_base.listen_tcp(
@@ -583,27 +567,35 @@ class GenericWorkerServer(HomeServer):
else:
logger.warning("Unrecognized listener type: %s", listener["type"])
self.get_tcp_replication().start_replication(self)
if self.config.redis.redis_enabled:
from synapse.replication.tcp.redis import RedisFactory
logger.info("Connecting to redis.")
factory = RedisFactory(self)
self.get_reactor().connectTCP(
self.config.redis.redis_host, self.config.redis.redis_port, factory
)
else:
factory = ReplicationClientFactory(self, self.config.worker_name)
host = self.config.worker_replication_host
port = self.config.worker_replication_port
self.get_reactor().connectTCP(host, port, factory)
def remove_pusher(self, app_id, push_key, user_id):
self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id)
def build_tcp_replication(self):
return GenericWorkerReplicationHandler(self)
def build_presence_handler(self):
return GenericWorkerPresence(self)
def build_typing_handler(self):
return GenericWorkerTyping(self)
def build_replication_data_handler(self):
return GenericWorkerReplicationHandler(self)
class GenericWorkerReplicationHandler(ReplicationClientHandler):
class GenericWorkerReplicationHandler(ReplicationDataHandler):
def __init__(self, hs):
super(GenericWorkerReplicationHandler, self).__init__(hs.get_datastore())
super().__init__(hs)
self.store = hs.get_datastore()
self.typing_handler = hs.get_typing_handler()
# NB this is a SynchrotronPresence, not a normal PresenceHandler
self.presence_handler = hs.get_presence_handler()
self.notifier = hs.get_notifier()
@@ -612,30 +604,29 @@ class GenericWorkerReplicationHandler(ReplicationClientHandler):
self.pusher_pool = hs.get_pusherpool()
if hs.config.send_federation:
self.send_handler = FederationSenderHandler(hs, self)
self.send_handler = FederationSenderHandler(hs)
else:
self.send_handler = None
async def on_rdata(self, stream_name, token, rows):
await super(GenericWorkerReplicationHandler, self).on_rdata(
stream_name, token, rows
async def on_rdata(self, stream_name, instance_name, token, rows):
await super().on_rdata(stream_name, instance_name, token, rows)
run_in_background(
self.process_and_notify, stream_name, instance_name, token, rows
)
run_in_background(self.process_and_notify, stream_name, token, rows)
def get_streams_to_replicate(self):
args = super(GenericWorkerReplicationHandler, self).get_streams_to_replicate()
args.update(self.typing_handler.stream_positions())
args = super().get_streams_to_replicate()
if self.send_handler:
args.update(self.send_handler.stream_positions())
return args
def get_currently_syncing_users(self):
return self.presence_handler.get_currently_syncing_users()
async def process_and_notify(self, stream_name, token, rows):
async def process_and_notify(self, stream_name, instance_name, token, rows):
try:
if self.send_handler:
self.send_handler.process_replication_rows(stream_name, token, rows)
self.send_handler.process_replication_rows(
stream_name, instance_name, token, rows
)
if stream_name == EventsStream.NAME:
# We shouldn't get multiple rows per token for events stream, so
@@ -675,11 +666,6 @@ class GenericWorkerReplicationHandler(ReplicationClientHandler):
await self.pusher_pool.on_new_receipts(
token, token, {row.room_id for row in rows}
)
elif stream_name == TypingStream.NAME:
self.typing_handler.process_replication_rows(token, rows)
self.notifier.on_new_event(
"typing_key", token, rooms=[row.room_id for row in rows]
)
elif stream_name == ToDeviceStream.NAME:
entities = [row.entity for row in rows if row.entity.startswith("@")]
if entities:
@@ -740,13 +726,14 @@ class FederationSenderHandler(object):
to the federation sender.
"""
def __init__(self, hs: GenericWorkerServer, replication_client):
def __init__(self, hs: GenericWorkerServer):
self.hs = hs
self.store = hs.get_datastore()
self._is_mine_id = hs.is_mine_id
self.federation_sender = hs.get_federation_sender()
self.replication_client = replication_client
# self.replication_client = hs.get_tcp_replication()
self.federation_position = self.store.federation_out_pos_startup
self.federation_position = {"master": self.store.federation_out_pos_startup}
self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
self._last_ack = self.federation_position
@@ -767,12 +754,12 @@ class FederationSenderHandler(object):
def stream_positions(self):
return {"federation": self.federation_position}
def process_replication_rows(self, stream_name, token, rows):
def process_replication_rows(self, stream_name, instance_name, token, rows):
# The federation stream contains things that we want to send out, e.g.
# presence, typing, etc.
if stream_name == "federation":
send_queue.process_rows_for_federation(self.federation_sender, rows)
run_in_background(self.update_token, token)
run_in_background(self.update_token, instance_name, token)
# We also need to poke the federation sender when new events happen
elif stream_name == "events":
@@ -820,9 +807,12 @@ class FederationSenderHandler(object):
)
await self.federation_sender.send_read_receipt(receipt_info)
async def update_token(self, token):
async def update_token(self, instance_name, token):
try:
self.federation_position = token
self.federation_position[instance_name] = token
return
# FIXME
# We linearize here to ensure we don't have races updating the token
with (await self._fed_position_linearizer.queue(None)):
@@ -833,7 +823,7 @@ class FederationSenderHandler(object):
# We ACK this token over replication so that the master can drop
# its in memory queues
self.replication_client.send_federation_ack(
self.hs.get_tcp_replication().send_federation_ack(
self.federation_position
)
self._last_ack = self.federation_position
@@ -915,6 +905,10 @@ def start(config_options):
# Force the pushers to start since they will be disabled in the main config
config.send_federation = True
config.server.handle_typing = False
if config.worker_app == "synapse.app.client_reader":
config.server.handle_typing = True
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
ss = GenericWorkerServer(
@@ -930,6 +924,8 @@ def start(config_options):
"before", "startup", _base.start, ss, config.worker_listeners
)
ss.get_replication_streamer()
_base.start_worker_reactor("synapse-generic-worker", config)

@@ -263,6 +263,15 @@ class SynapseHomeServer(HomeServer):
def start_listening(self, listeners):
config = self.get_config()
if config.redis_enabled:
from synapse.replication.tcp.redis import RedisFactory
logger.info("Connecting to redis.")
factory = RedisFactory(self)
self.get_reactor().connectTCP(
self.config.redis.redis_host, self.config.redis.redis_port, factory
)
for listener in listeners:
if listener["type"] == "http":
self._listening_services.extend(self._listener_http(config, listener))
@@ -282,6 +291,7 @@ class SynapseHomeServer(HomeServer):
)
for s in services:
reactor.addSystemEventTrigger("before", "shutdown", s.stopListening)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
logger.warning(

@@ -31,6 +31,7 @@ from .password import PasswordConfig
from .password_auth_providers import PasswordAuthProviderConfig
from .push import PushConfig
from .ratelimiting import RatelimitConfig
from .redis import RedisConfig
from .registration import RegistrationConfig
from .repository import ContentRepositoryConfig
from .room_directory import RoomDirectoryConfig
@@ -82,4 +83,5 @@ class HomeServerConfig(RootConfig):
RoomDirectoryConfig,
ThirdPartyRulesConfig,
TracerConfig,
RedisConfig,
]

47
synapse/config/redis.py Normal file
@@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.config._base import Config, ConfigError
try:
import txredisapi
except ImportError:
txredisapi = None
MISSING_REDIS = """Missing 'txredisapi' library. This is required for redis support.
Install by running:
pip install txredisapi
"""
class RedisConfig(Config):
section = "redis"
def read_config(self, config, **kwargs):
redis_config = config.get("redis", {})
self.redis_enabled = redis_config.get("enabled", False)
if not self.redis_enabled:
return
if txredisapi is None:
raise ConfigError(MISSING_REDIS)
self.redis_host = redis_config.get("host", "localhost")
self.redis_port = redis_config.get("port", 6379)
self.redis_dbid = redis_config.get("dbid")
self.redis_password = redis_config.get("password")
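
The defaults applied by `read_config` can be seen in isolation with a plain function. This is a sketch under the assumption that the parsed config is an ordinary dict; `read_redis_settings` is a hypothetical helper, not part of Synapse, and it skips the `txredisapi` import check.

```python
from typing import Any, Dict, Optional


def read_redis_settings(config: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Return the redis settings with defaults applied, or None if disabled."""
    redis_config = config.get("redis", {})
    if not redis_config.get("enabled", False):
        # Mirrors the early return in RedisConfig.read_config.
        return None
    return {
        "host": redis_config.get("host", "localhost"),
        "port": redis_config.get("port", 6379),
        "dbid": redis_config.get("dbid"),
        "password": redis_config.get("password"),
    }
```

A config containing only `redis: {enabled: true}` therefore resolves to `localhost:6379` with no database ID or password.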

@@ -83,6 +83,8 @@ class ServerConfig(Config):
# "disable" federation
self.send_federation = config.get("send_federation", True)
self.handle_typing = config.get("handle_typing", True)
# Whether to enable user presence.
self.use_presence = config.get("use_presence", True)

@@ -48,6 +48,8 @@ class WorkerConfig(Config):
self.worker_main_http_uri = config.get("worker_main_http_uri", None)
self.instance_http_map = config.get("instance_http_map", {})
# This option is really only here to support `--manhole` command line
# argument.
manhole = config.get("worker_manhole")

@@ -819,7 +819,16 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
edu_type, origin, content
)
return await self._send_edu(edu_type=edu_type, origin=origin, content=content)
if edu_type == "m.typing":
instance_name = "synapse.app.client_reader"
else:
instance_name = "master"
return await self._send_edu(
instance_name=instance_name,
edu_type=edu_type,
origin=origin,
content=content,
)
async def on_query(self, query_type, args):
"""Overrides FederationHandlerRegistry

@@ -499,4 +499,13 @@ class FederationSender(object):
self._get_per_destination_queue(destination).attempt_new_transaction()
def get_current_token(self) -> int:
# Dummy implementation for case where federation sender isn't offloaded
# to a worker.
return 0
async def get_replication_rows(
self, from_token, to_token, limit, federation_ack=None
):
# Dummy implementation for case where federation sender isn't offloaded
# to a worker.
return []

@@ -21,6 +21,7 @@ from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError
from synapse.logging.context import run_in_background
from synapse.replication.tcp.streams import TypingStream
from synapse.types import UserID, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.metrics import Measure
@@ -288,6 +289,54 @@ class TypingHandler(object):
return self._latest_room_serial
class TypingSlaveHandler(object):
def __init__(self, hs):
self.notifier = hs.get_notifier()
self._latest_room_serial = 0
self._reset()
def _reset(self):
"""
Reset the typing handler's data caches.
"""
# map room IDs to serial numbers
self._room_serials = {}
# map room IDs to sets of users currently typing
self._room_typing = {}
def stream_positions(self):
# We must update this typing token from the response of the previous
# sync. In particular, the stream id may "reset" back to zero/a low
# value which we *must* use for the next replication request.
return {
"typing": {"synapse.app.client_reader": self._latest_room_serial}
} # FIXME
def process_replication_rows(self, stream_name, token, rows):
if stream_name != TypingStream.NAME:
return
if self._latest_room_serial > token:
# The master has gone backwards. To prevent inconsistent data, just
# clear everything.
self._reset()
# Set the latest serial token to whatever the server gave us.
self._latest_room_serial = token
for row in rows:
self._room_serials[row.room_id] = token
self._room_typing[row.room_id] = row.user_ids
self.notifier.on_new_event(
"typing_key", token, rooms=[row.room_id for row in rows]
)
def get_current_token(self) -> int:
return self._latest_room_serial
class TypingNotificationEventSource(object):
def __init__(self, hs):
self.hs = hs
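
The reset-on-backwards-token behaviour of `TypingSlaveHandler.process_replication_rows` can be sketched without any Synapse dependencies. `TypingCache` is a hypothetical stand-in, and rows are assumed here to be simple `(room_id, user_ids)` tuples rather than the real row objects.

```python
from typing import Dict, Iterable, List, Tuple


class TypingCache:
    """Hypothetical, dependency-free model of the typing caches."""

    def __init__(self) -> None:
        self.latest_serial = 0
        self.room_serials: Dict[str, int] = {}
        self.room_typing: Dict[str, List[str]] = {}

    def process(self, token: int, rows: Iterable[Tuple[str, List[str]]]) -> None:
        if self.latest_serial > token:
            # The upstream token went backwards, so drop everything cached
            # to avoid serving inconsistent data.
            self.room_serials.clear()
            self.room_typing.clear()
        # Always adopt the token the server gave us.
        self.latest_serial = token
        for room_id, user_ids in rows:
            self.room_serials[room_id] = token
            self.room_typing[room_id] = user_ids
```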

@@ -98,6 +98,7 @@ CONDITIONAL_REQUIREMENTS = {
"sentry": ["sentry-sdk>=0.7.2"],
"opentracing": ["jaeger-client>=4.0.0", "opentracing>=2.2.0"],
"jwt": ["pyjwt>=1.6.4"],
"redis": ["txredisapi>=1.4.7"],
}
ALL_OPTIONAL_REQUIREMENTS = set() # type: Set[str]

@@ -21,6 +21,7 @@ from synapse.replication.http import (
membership,
register,
send_event,
streams,
)
REPLICATION_PREFIX = "/_synapse/replication"
@@ -32,9 +33,12 @@ class ReplicationRestResource(JsonResource):
self.register_servlets(hs)
def register_servlets(self, hs):
send_event.register_servlets(hs, self)
membership.register_servlets(hs, self)
if hs.config.worker_app is None:
send_event.register_servlets(hs, self)
membership.register_servlets(hs, self)
login.register_servlets(hs, self)
register.register_servlets(hs, self)
devices.register_servlets(hs, self)
streams.register_servlets(hs, self)
federation.register_servlets(hs, self)
login.register_servlets(hs, self)
register.register_servlets(hs, self)
devices.register_servlets(hs, self)

@@ -128,14 +128,25 @@ class ReplicationEndpoint(object):
Returns a callable that accepts the same parameters as `_serialize_payload`.
"""
clock = hs.get_clock()
host = hs.config.worker_replication_host
port = hs.config.worker_replication_http_port
master_host = hs.config.worker_replication_host
master_port = hs.config.worker_replication_http_port
instance_http_map = hs.config.instance_http_map
client = hs.get_simple_http_client()
@trace(opname="outgoing_replication_request")
@defer.inlineCallbacks
def send_request(**kwargs):
def send_request(instance_name="master", **kwargs):
if instance_name == "master":
host = master_host
port = master_port
elif instance_name in instance_http_map:
host = instance_http_map[instance_name]["host"]
port = instance_http_map[instance_name]["port"]
else:
raise Exception("Unknown instance")
data = yield cls._serialize_payload(**kwargs)
url_args = [
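
The host/port selection added to `send_request` amounts to a lookup with a special case for `"master"`. The sketch below isolates that logic; `resolve_instance` is a hypothetical helper, and raising `ValueError` instead of a bare `Exception` is a choice made for this example only.

```python
from typing import Dict, Tuple


def resolve_instance(
    instance_name: str,
    master_host: str,
    master_port: int,
    instance_http_map: Dict[str, Dict[str, object]],
) -> Tuple[object, object]:
    """Pick the replication HTTP endpoint for the named instance."""
    if instance_name == "master":
        return master_host, master_port
    if instance_name in instance_http_map:
        entry = instance_http_map[instance_name]
        return entry["host"], entry["port"]
    # The diff raises a generic Exception here; ValueError is an assumption.
    raise ValueError("Unknown instance: %r" % (instance_name,))
```

Under this reading, an `instance_http_map` entry would look like `{"synapse.app.client_reader": {"host": ..., "port": ...}}`, matching the keys the diff reads.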

@@ -277,8 +277,10 @@ class ReplicationStoreRoomOnInviteRestServlet(ReplicationEndpoint):
def register_servlets(hs, http_server):
ReplicationFederationSendEventsRestServlet(hs).register(http_server)
if hs.config.worker_app is None:
ReplicationFederationSendEventsRestServlet(hs).register(http_server)
ReplicationGetQueryRestServlet(hs).register(http_server)
ReplicationCleanRoomRestServlet(hs).register(http_server)
ReplicationStoreRoomOnInviteRestServlet(hs).register(http_server)
ReplicationFederationSendEduRestServlet(hs).register(http_server)
ReplicationGetQueryRestServlet(hs).register(http_server)
ReplicationCleanRoomRestServlet(hs).register(http_server)
ReplicationStoreRoomOnInviteRestServlet(hs).register(http_server)

@@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from synapse.api.errors import SynapseError
from synapse.http.servlet import parse_integer
from synapse.replication.http._base import ReplicationEndpoint
logger = logging.getLogger(__name__)
class ReplicationGetStreamUpdates(ReplicationEndpoint):
"""Fetches stream updates from a server. Used for streams not persisted to
the database, e.g. typing notifications.
The API looks like:
GET /_synapse/replication/get_repl_stream_updates/events?from_token=0&upto_token=10&limit=100
200 OK
{
updates: [ ... ],
upto_token: 10,
limited: False,
}
"""
NAME = "get_repl_stream_updates"
PATH_ARGS = ("stream_name",)
METHOD = "GET"
def __init__(self, hs):
super().__init__(hs)
# We pull the streams from the replication streamer (if we try and make
# them ourselves we end up in an import loop).
self.streams = hs.get_replication_streamer().get_streams()
self.instance_name = hs.config.worker_name or "master"
@staticmethod
def _serialize_payload(stream_name, from_token, upto_token, limit):
return {"from_token": from_token, "upto_token": upto_token, "limit": limit}
async def _handle_request(self, request, stream_name):
stream = self.streams.get(stream_name)
if stream is None:
raise SynapseError(400, "Unknown stream")
from_token = parse_integer(request, "from_token", required=True)
upto_token = parse_integer(request, "upto_token", required=True)
limit = parse_integer(request, "limit", required=True)
updates, upto_token, limited = await stream.get_updates_since(
self.instance_name, from_token, upto_token, limit
)
return (
200,
{"updates": updates, "upto_token": upto_token, "limited": limited},
)
def register_servlets(hs, http_server):
ReplicationGetStreamUpdates(hs).register(http_server)
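
A caller of this endpoint presumably needs to page until `limited` is false. The loop below is a sketch of that idea only: `catch_up` and the `fetch` callable are hypothetical, and it assumes `from_token` is exclusive, `upto_token` is inclusive, and the returned token is where the next page should resume.

```python
from typing import Callable, List, Tuple

# fetch(from_token, upto_token, limit) -> (updates, upto_token, limited)
Fetch = Callable[[int, int, int], Tuple[List[tuple], int, bool]]


def catch_up(fetch: Fetch, from_token: int, upto_token: int, limit: int) -> List[tuple]:
    """Collect all updates in (from_token, upto_token], one page at a time."""
    collected: List[tuple] = []
    current = from_token
    while True:
        updates, new_token, limited = fetch(current, upto_token, limit)
        collected.extend(updates)
        if not limited:
            return collected
        if new_token <= current:
            # Guard against looping forever if the server makes no progress.
            raise RuntimeError("stream did not advance past %d" % current)
        current = new_token
```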

@@ -18,8 +18,10 @@ from typing import Dict, Optional
import six
from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME
from synapse.storage.data_stores.main.cache import (
CURRENT_STATE_CACHE_NAME,
CacheInvalidationWorkerStore,
)
from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine
@@ -35,7 +37,7 @@ def __func__(inp):
return inp.__func__
class BaseSlavedStore(SQLBaseStore):
class BaseSlavedStore(CacheInvalidationWorkerStore):
def __init__(self, database: Database, db_conn, hs):
super(BaseSlavedStore, self).__init__(database, db_conn, hs)
if isinstance(self.database_engine, PostgresEngine):
@@ -57,9 +59,15 @@ class BaseSlavedStore(SQLBaseStore):
"""
pos = {}
if self._cache_id_gen:
pos["caches"] = self._cache_id_gen.get_current_token()
pos["caches"] = {"master": self._cache_id_gen.get_current_token()}
return pos
def get_cache_stream_token(self):
if self._cache_id_gen:
return self._cache_id_gen.get_current_token()
else:
return 0
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "caches":
if self._cache_id_gen:

@@ -35,9 +35,7 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
def stream_positions(self):
result = super(SlavedAccountDataStore, self).stream_positions()
position = self._account_data_id_gen.get_current_token()
result["user_account_data"] = position
result["room_account_data"] = position
result["tag_account_data"] = position
result["account_data"] = {"master": position}
return result
def process_replication_rows(self, stream_name, token, rows):
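
The slaved-store hunks all make the same change of shape: `stream_positions()` goes from `{stream_name: token}` to `{stream_name: {instance_name: token}}`. The helper below only illustrates that shape; `key_by_instance` is a hypothetical name, and the default of `"master"` reflects the literal used throughout these hunks.

```python
from typing import Dict


def key_by_instance(
    positions: Dict[str, int], instance_name: str = "master"
) -> Dict[str, Dict[str, int]]:
    """Wrap each stream token in a per-instance mapping."""
    return {stream: {instance_name: token} for stream, token in positions.items()}
```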

@@ -45,7 +45,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
def stream_positions(self):
result = super(SlavedDeviceInboxStore, self).stream_positions()
result["to_device"] = self._device_inbox_id_gen.get_current_token()
result["to_device"] = {"master": self._device_inbox_id_gen.get_current_token()}
return result
def process_replication_rows(self, stream_name, token, rows):


@@ -54,8 +54,8 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
# device list stream, so set them both to the device list ID
# generator's current token.
current_token = self._device_list_id_gen.get_current_token()
result[DeviceListsStream.NAME] = current_token
result[UserSignatureStream.NAME] = current_token
result[DeviceListsStream.NAME] = {"master": current_token}
result[UserSignatureStream.NAME] = {"master": current_token}
return result
def process_replication_rows(self, stream_name, token, rows):


@@ -95,8 +95,8 @@ class SlavedEventStore(
def stream_positions(self):
result = super(SlavedEventStore, self).stream_positions()
result["events"] = self._stream_id_gen.get_current_token()
result["backfill"] = -self._backfill_id_gen.get_current_token()
result["events"] = {"master": self._stream_id_gen.get_current_token()}
result["backfill"] = {"master": -self._backfill_id_gen.get_current_token()}
return result
def process_replication_rows(self, stream_name, token, rows):


@@ -39,7 +39,7 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
def stream_positions(self):
result = super(SlavedGroupServerStore, self).stream_positions()
result["groups"] = self._group_updates_id_gen.get_current_token()
result["groups"] = {"master": self._group_updates_id_gen.get_current_token()}
return result
def process_replication_rows(self, stream_name, token, rows):


@@ -46,7 +46,7 @@ class SlavedPresenceStore(BaseSlavedStore):
if self.hs.config.use_presence:
position = self._presence_id_gen.get_current_token()
result["presence"] = position
result["presence"] = {"master": position}
return result


@@ -39,7 +39,9 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
def stream_positions(self):
result = super(SlavedPushRuleStore, self).stream_positions()
result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
result["push_rules"] = {
"master": self._push_rules_stream_id_gen.get_current_token()
}
return result
def process_replication_rows(self, stream_name, token, rows):


@@ -30,9 +30,12 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
def stream_positions(self):
result = super(SlavedPusherStore, self).stream_positions()
result["pushers"] = self._pushers_id_gen.get_current_token()
result["pushers"] = {"master": self._pushers_id_gen.get_current_token()}
return result
def get_pushers_stream_token(self):
return self._pushers_id_gen.get_current_token()
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "pushers":
self._pushers_id_gen.advance(token)


@@ -44,7 +44,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
def stream_positions(self):
result = super(SlavedReceiptsStore, self).stream_positions()
result["receipts"] = self._receipts_id_gen.get_current_token()
result["receipts"] = {"master": self._receipts_id_gen.get_current_token()}
return result
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):


@@ -32,7 +32,9 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
def stream_positions(self):
result = super(RoomStore, self).stream_positions()
result["public_rooms"] = self._public_room_id_gen.get_current_token()
result["public_rooms"] = {
"master": self._public_room_id_gen.get_current_token()
}
return result
def process_replication_rows(self, stream_name, token, rows):
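The `stream_positions()` changes above all follow one pattern: each stream name now maps to a dict keyed by the writing instance (currently always "master") instead of a bare token. A minimal sketch of reading such a structure follows; the `StreamPositions` alias and the `get_position` helper are hypothetical names used only for illustration and are not part of this change.

```python
from typing import Dict

# Hypothetical alias for illustration: stream name -> instance name -> token.
StreamPositions = Dict[str, Dict[str, int]]


def get_position(
    positions: StreamPositions, stream_name: str, instance_name: str
) -> int:
    """Return the last token seen for a stream from a given writer.

    Falls back to 0 when the stream or the instance is unknown, mirroring
    the `current_tokens.get(cmd.instance_name, 0)` lookup in the diff.
    """
    return positions.get(stream_name, {}).get(instance_name, 0)


# Example data shaped like the new return value of stream_positions().
positions = {
    "events": {"master": 1532},
    "backfill": {"master": -17},
}

print(get_position(positions, "events", "master"))
print(get_position(positions, "events", "worker1"))
```

Keying by instance name leaves room for more than one writer per stream without changing the shape again.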


@@ -16,26 +16,10 @@
"""
import logging
from typing import Dict, List, Optional
from twisted.internet import defer
from twisted.internet.protocol import ReconnectingClientFactory
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.tcp.protocol import (
AbstractReplicationClientHandler,
ClientReplicationStreamProtocol,
)
from .commands import (
Command,
FederationAckCommand,
InvalidateCacheCommand,
RemoteServerUpCommand,
RemovePusherCommand,
UserIpCommand,
UserSyncCommand,
)
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
logger = logging.getLogger(__name__)
@@ -51,10 +35,11 @@ class ReplicationClientFactory(ReconnectingClientFactory):
initialDelay = 0.1
maxDelay = 1 # Try at least once every N seconds
def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler):
def __init__(self, hs, client_name):
self.client_name = client_name
self.handler = handler
self.handler = hs.get_tcp_replication()
self.server_name = hs.config.server_name
self.hs = hs
self._clock = hs.get_clock() # As self.clock is defined in super class
hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)
@@ -65,7 +50,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
def buildProtocol(self, addr):
logger.info("Connected to replication: %r", addr)
return ClientReplicationStreamProtocol(
self.client_name, self.server_name, self._clock, self.handler
self.hs, self.client_name, self.server_name, self._clock, self.handler,
)
def clientConnectionLost(self, connector, reason):
@@ -75,170 +60,3 @@ class ReplicationClientFactory(ReconnectingClientFactory):
def clientConnectionFailed(self, connector, reason):
logger.error("Failed to connect to replication: %r", reason)
ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
class ReplicationClientHandler(AbstractReplicationClientHandler):
"""A base handler that can be passed to the ReplicationClientFactory.
By default proxies incoming replication data to the SlaveStore.
"""
def __init__(self, store: BaseSlavedStore):
self.store = store
# The current connection. None if we are currently (re)connecting
self.connection = None
# Any pending commands to be sent once a new connection has been
# established
self.pending_commands = [] # type: List[Command]
# Map from string -> deferred, to wake up when receiving a SYNC with
# the given string.
# Used for tests.
self.awaiting_syncs = {} # type: Dict[str, defer.Deferred]
# The factory used to create connections.
self.factory = None # type: Optional[ReplicationClientFactory]
def start_replication(self, hs):
"""Helper method to start a replication connection to the remote server
using TCP.
"""
client_name = hs.config.worker_name
self.factory = ReplicationClientFactory(hs, client_name, self)
host = hs.config.worker_replication_host
port = hs.config.worker_replication_port
hs.get_reactor().connectTCP(host, port, self.factory)
async def on_rdata(self, stream_name, token, rows):
"""Called to handle a batch of replication data with a given stream token.
By default this just pokes the slave store. Can be overridden in subclasses to
handle more.
Args:
stream_name (str): name of the replication stream for this batch of rows
token (int): stream token for this batch of rows
rows (list): a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
"""
logger.debug("Received rdata %s -> %s", stream_name, token)
self.store.process_replication_rows(stream_name, token, rows)
async def on_position(self, stream_name, token):
"""Called when we get new position data. By default this just pokes
the slave store.
Can be overridden in subclasses to handle more.
"""
self.store.process_replication_rows(stream_name, token, [])
def on_sync(self, data):
"""When we receive a SYNC we wake up any deferreds that were waiting
for the sync with the given data.
Used by tests.
"""
d = self.awaiting_syncs.pop(data, None)
if d:
d.callback(data)
def on_remote_server_up(self, server: str):
"""Called when we get a new REMOTE_SERVER_UP command."""
def get_streams_to_replicate(self) -> Dict[str, int]:
"""Called when a new connection has been established and we need to
subscribe to streams.
Returns:
map from stream name to the most recent update we have for
that stream (ie, the point we want to start replicating from)
"""
args = self.store.stream_positions()
user_account_data = args.pop("user_account_data", None)
room_account_data = args.pop("room_account_data", None)
if user_account_data:
args["account_data"] = user_account_data
elif room_account_data:
args["account_data"] = room_account_data
return args
def get_currently_syncing_users(self):
"""Get the list of currently syncing users (if any). This is called
when a connection has been established and we need to send the
currently syncing users. (Overridden by the synchrotron only.)
"""
return []
def send_command(self, cmd):
"""Send a command to master. If we are not currently connected, the
command is queued and sent once a connection is established.
"""
if self.connection:
self.connection.send_command(cmd)
else:
logger.warning("Queuing command as not connected: %r", cmd.NAME)
self.pending_commands.append(cmd)
def send_federation_ack(self, token):
"""Ack data for the federation stream. This allows the master to drop
data stored purely in memory.
"""
self.send_command(FederationAckCommand(token))
def send_user_sync(self, user_id, is_syncing, last_sync_ms):
"""Poke the master that a user has started/stopped syncing.
"""
self.send_command(UserSyncCommand(user_id, is_syncing, last_sync_ms))
def send_remove_pusher(self, app_id, push_key, user_id):
"""Poke the master to remove a pusher for a user
"""
cmd = RemovePusherCommand(app_id, push_key, user_id)
self.send_command(cmd)
def send_invalidate_cache(self, cache_func, keys):
"""Poke the master to invalidate a cache.
"""
cmd = InvalidateCacheCommand(cache_func.__name__, keys)
self.send_command(cmd)
def send_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
"""Tell the master that the user made a request.
"""
cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
self.send_command(cmd)
def send_remote_server_up(self, server: str):
self.send_command(RemoteServerUpCommand(server))
def await_sync(self, data):
"""Returns a deferred that is resolved when we receive a SYNC command
with given data.
[Not currently] used by tests.
"""
return self.awaiting_syncs.setdefault(data, defer.Deferred())
def update_connection(self, connection):
"""Called when a connection has been established (or lost with None).
"""
self.connection = connection
if connection:
for cmd in self.pending_commands:
connection.send_command(cmd)
self.pending_commands = []
def finished_connecting(self):
"""Called when we have successfully subscribed and caught up to all
streams we're interested in.
"""
logger.info("Finished connecting to server")
# We don't reset the delay any earlier as otherwise if there is a
# problem during start up we'll end up tight looping connecting to the
# server.
if self.factory:
self.factory.resetDelay()


@@ -86,7 +86,7 @@ class RdataCommand(Command):
Format::
RDATA <stream_name> <token> <row_json>
RDATA <stream_name> <instance_name> <token> <row_json>
The `<token>` may either be a numeric stream id OR "batch". The latter case
is used to support sending multiple updates with the same stream ID. This
@@ -107,22 +107,27 @@ class RdataCommand(Command):
NAME = "RDATA"
def __init__(self, stream_name, token, row):
def __init__(self, stream_name, instance_name, token, row):
self.stream_name = stream_name
self.instance_name = instance_name
self.token = token
self.row = row
@classmethod
def from_line(cls, line):
stream_name, token, row_json = line.split(" ", 2)
stream_name, instance_name, token, row_json = line.split(" ", 3)
return cls(
stream_name, None if token == "batch" else int(token), json.loads(row_json)
stream_name,
instance_name,
None if token == "batch" else int(token),
json.loads(row_json),
)
def to_line(self):
return " ".join(
(
self.stream_name,
self.instance_name,
str(self.token) if self.token is not None else "batch",
_json_encoder.encode(self.row),
)
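The new RDATA wire format and the "batch" token can be illustrated with a small standalone sketch. It assumes only what the hunk above shows: four space-separated fields, with the row JSON last so it may itself contain spaces. The `parse_rdata`, `encode_rdata` and `Batcher` names are hypothetical; the real code uses `RdataCommand` and `pending_batches`, and encodes rows with `_json_encoder` rather than `json.dumps`.

```python
import json
from typing import Any, Dict, List, Optional, Tuple

Parsed = Tuple[str, str, Optional[int], Any]


def parse_rdata(line: str) -> Parsed:
    """Split 'stream instance token row_json' into its four fields."""
    stream_name, instance_name, token, row_json = line.split(" ", 3)
    return (
        stream_name,
        instance_name,
        None if token == "batch" else int(token),
        json.loads(row_json),
    )


def encode_rdata(
    stream_name: str, instance_name: str, token: Optional[int], row: Any
) -> str:
    """Inverse of parse_rdata; a None token is sent as 'batch'."""
    return " ".join(
        (
            stream_name,
            instance_name,
            str(token) if token is not None else "batch",
            json.dumps(row),
        )
    )


class Batcher:
    """Hypothetical helper: buffers 'batch' rows until a numeric token arrives."""

    def __init__(self) -> None:
        self.pending = {}  # type: Dict[str, List[Any]]

    def feed(self, line: str):
        stream_name, instance_name, token, row = parse_rdata(line)
        if token is None:
            # Part of a batch: hold the row until the closing update.
            self.pending.setdefault(stream_name, []).append(row)
            return None
        # A numeric token closes the batch (or stands alone).
        rows = self.pending.pop(stream_name, [])
        rows.append(row)
        return (stream_name, instance_name, token, rows)


batcher = Batcher()
print(batcher.feed('events master batch ["a"]'))
print(batcher.feed('events master 5 ["b"]'))
```

Splitting with a maximum of three splits is what keeps spaces inside the JSON payload intact.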
@@ -136,23 +141,24 @@ class PositionCommand(Command):
"""Sent by the server to tell the client the stream position without
needing to send an RDATA.
Sent to the client after all missing updates for a stream have been sent
to the client and they're now up to date.
On receipt of a POSITION command clients should check if they have missed
any updates, and if so then fetch them out of band.
"""
NAME = "POSITION"
def __init__(self, stream_name, token):
def __init__(self, stream_name, instance_name, token):
self.stream_name = stream_name
self.instance_name = instance_name
self.token = token
@classmethod
def from_line(cls, line):
stream_name, token = line.split(" ", 1)
return cls(stream_name, int(token))
stream_name, instance_name, token = line.split(" ", 2)
return cls(stream_name, instance_name, int(token))
def to_line(self):
return " ".join((self.stream_name, str(self.token)))
return " ".join((self.stream_name, self.instance_name, str(self.token)))
class ErrorCommand(Command):
@@ -179,42 +185,24 @@ class NameCommand(Command):
class ReplicateCommand(Command):
"""Sent by the client to subscribe to the stream.
"""Sent by the client to subscribe to streams.
Format::
REPLICATE <stream_name> <token>
Where <token> may be either:
* a numeric stream_id to stream updates from
* "NOW" to stream all subsequent updates.
The <stream_name> can be "ALL" to subscribe to all known streams, in which
case the <token> must be set to "NOW", i.e.::
REPLICATE ALL NOW
REPLICATE
"""
NAME = "REPLICATE"
def __init__(self, stream_name, token):
self.stream_name = stream_name
self.token = token
def __init__(self):
pass
@classmethod
def from_line(cls, line):
stream_name, token = line.split(" ", 1)
if token in ("NOW", "now"):
token = "NOW"
else:
token = int(token)
return cls(stream_name, token)
return cls()
def to_line(self):
return " ".join((self.stream_name, str(self.token)))
def get_logcontext_id(self):
return "REPLICATE-" + self.stream_name
return ""
class UserSyncCommand(Command):
@@ -225,30 +213,32 @@ class UserSyncCommand(Command):
Format::
USER_SYNC <user_id> <state> <last_sync_ms>
USER_SYNC <instance_id> <user_id> <state> <last_sync_ms>
Where <state> is either "start" or "end"
"""
NAME = "USER_SYNC"
def __init__(self, user_id, is_syncing, last_sync_ms):
def __init__(self, instance_id, user_id, is_syncing, last_sync_ms):
self.instance_id = instance_id
self.user_id = user_id
self.is_syncing = is_syncing
self.last_sync_ms = last_sync_ms
@classmethod
def from_line(cls, line):
user_id, state, last_sync_ms = line.split(" ", 2)
instance_id, user_id, state, last_sync_ms = line.split(" ", 3)
if state not in ("start", "end"):
raise Exception("Invalid USER_SYNC state %r" % (state,))
return cls(user_id, state == "start", int(last_sync_ms))
return cls(instance_id, user_id, state == "start", int(last_sync_ms))
def to_line(self):
return " ".join(
(
self.instance_id,
self.user_id,
"start" if self.is_syncing else "end",
str(self.last_sync_ms),
@@ -256,6 +246,30 @@ class UserSyncCommand(Command):
)
class ClearUserSyncsCommand(Command):
"""Sent by the client to inform the server that it should drop all
information about syncing users sent by the client.
Mainly used when client is about to shut down.
Format::
CLEAR_USER_SYNC <instance_id>
"""
NAME = "CLEAR_USER_SYNC"
def __init__(self, instance_id):
self.instance_id = instance_id
@classmethod
def from_line(cls, line):
return cls(line)
def to_line(self):
return self.instance_id
class FederationAckCommand(Command):
"""Sent by the client when it has processed up to a given point in the
federation stream. This allows the master to drop in-memory caches of the
@@ -416,6 +430,7 @@ _COMMANDS = (
InvalidateCacheCommand,
UserIpCommand,
RemoteServerUpCommand,
ClearUserSyncsCommand,
) # type: Tuple[Type[Command], ...]
# Map of command name to command type.
@@ -438,6 +453,7 @@ VALID_CLIENT_COMMANDS = (
ReplicateCommand.NAME,
PingCommand.NAME,
UserSyncCommand.NAME,
ClearUserSyncsCommand.NAME,
FederationAckCommand.NAME,
RemovePusherCommand.NAME,
InvalidateCacheCommand.NAME,


@@ -0,0 +1,399 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A replication client for use by synapse workers.
"""
import logging
from typing import Any, Callable, Dict, List
from prometheus_client import Counter
from synapse.metrics import LaterGauge
from synapse.replication.tcp.commands import (
ClearUserSyncsCommand,
Command,
FederationAckCommand,
InvalidateCacheCommand,
PositionCommand,
RdataCommand,
RemoteServerUpCommand,
RemovePusherCommand,
ReplicateCommand,
UserIpCommand,
UserSyncCommand,
)
from synapse.replication.tcp.streams import STREAMS_MAP, Stream
logger = logging.getLogger(__name__)
user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
invalidate_cache_counter = Counter(
"synapse_replication_tcp_resource_invalidate_cache", ""
)
user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
class ReplicationClientHandler:
"""Handles incoming commands from replication.
Proxies data to `HomeServer.get_replication_data_handler()`.
"""
def __init__(self, hs):
self.replication_data_handler = hs.get_replication_data_handler()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self.clock = hs.get_clock()
self.presence_handler = hs.get_presence_handler()
self.instance_id = hs.get_instance_id()
self.instance_name = hs.config.worker.worker_name or "master"
self.connections = [] # type: List[Any]
self.streams = {
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
} # type: Dict[str, Stream]
LaterGauge(
"synapse_replication_tcp_resource_total_connections",
"",
[],
lambda: len(self.connections),
)
LaterGauge(
"synapse_replication_tcp_resource_connections_per_stream",
"",
["stream_name"],
lambda: {
(stream_name,): len(
[
conn
for conn in self.connections
if stream_name in conn.replication_streams
]
)
for stream_name in self.streams
},
)
# Map of stream to batched updates. See RdataCommand for info on how
# batching works.
self.pending_batches = {} # type: Dict[str, List[Any]]
self.is_master = hs.config.worker_app is None
self.federation_sender = None
if self.is_master and not hs.config.send_federation:
self.federation_sender = hs.get_federation_sender()
self._server_notices_sender = None
if self.is_master:
self._server_notices_sender = hs.get_server_notices_sender()
self.notifier.add_remote_server_up_callback(self.send_remote_server_up)
def new_connection(self, connection):
self.connections.append(connection)
def lost_connection(self, connection):
try:
self.connections.remove(connection)
except ValueError:
pass
def connected(self) -> bool:
"""Do we have any replication connections open?
Used to no-op if nothing is connected.
"""
return bool(self.connections)
async def on_REPLICATE(self, cmd: ReplicateCommand):
# We only want to announce positions by the writer of the streams.
# Currently this is just the master process.
if not self.is_master:
return
if not self.connections:
raise Exception("Not connected")
for stream_name, stream in self.streams.items():
current_token = stream.current_token()
self.send_command(
PositionCommand(stream_name, self.instance_name, current_token)
)
async def on_USER_SYNC(self, cmd: UserSyncCommand):
user_sync_counter.inc()
if self.is_master:
await self.presence_handler.update_external_syncs_row(
cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
)
async def on_CLEAR_USER_SYNC(self, cmd: ClearUserSyncsCommand):
if self.is_master:
await self.presence_handler.update_external_syncs_clear(cmd.instance_id)
async def on_FEDERATION_ACK(self, cmd: FederationAckCommand):
federation_ack_counter.inc()
if self.federation_sender:
self.federation_sender.federation_ack(cmd.token)
async def on_REMOVE_PUSHER(self, cmd: RemovePusherCommand):
remove_pusher_counter.inc()
if self.is_master:
await self.store.delete_pusher_by_app_id_pushkey_user_id(
app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
)
self.notifier.on_new_replication_data()
async def on_INVALIDATE_CACHE(self, cmd: InvalidateCacheCommand):
invalidate_cache_counter.inc()
if self.is_master:
# We invalidate the cache locally, but then also stream that to other
# workers.
await self.store.invalidate_cache_and_stream(
cmd.cache_func, tuple(cmd.keys)
)
async def on_USER_IP(self, cmd: UserIpCommand):
user_ip_cache_counter.inc()
if self.is_master:
await self.store.insert_client_ip(
cmd.user_id,
cmd.access_token,
cmd.ip,
cmd.user_agent,
cmd.device_id,
cmd.last_seen,
)
if self._server_notices_sender:
await self._server_notices_sender.on_user_ip(cmd.user_id)
async def on_RDATA(self, cmd: RdataCommand):
stream_name = cmd.stream_name
try:
row = STREAMS_MAP[stream_name].parse_row(cmd.row)
except Exception:
logger.exception("[%s] Failed to parse RDATA: %r", stream_name, cmd.row)
raise
if cmd.token is None:
# I.e. this is part of a batch of updates for this stream. Batch
# until we get an update for the stream with a non None token
self.pending_batches.setdefault(stream_name, []).append(row)
else:
# Check if this is the last of a batch of updates
rows = self.pending_batches.pop(stream_name, [])
rows.append(row)
await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
):
"""Called to handle a batch of replication data with a given stream token.
Args:
stream_name: name of the replication stream for this batch of rows
token: stream token for this batch of rows
rows: a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
"""
logger.info("Received rdata %s %s -> %s", stream_name, instance_name, token)
await self.replication_data_handler.on_rdata(
stream_name, instance_name, token, rows
)
async def on_POSITION(self, cmd: PositionCommand):
stream = self.streams.get(cmd.stream_name)
if not stream:
logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
return
# Find where we previously streamed up to.
current_tokens = self.replication_data_handler.get_streams_to_replicate().get(
cmd.stream_name
)
if current_tokens is None:
logger.debug(
"Got POSITION for stream we're not subscribed to: %s", cmd.stream_name
)
return
current_token = current_tokens.get(cmd.instance_name, 0)
# Fetch all updates between then and now.
limited = cmd.token != current_token
while limited:
updates, current_token, limited = await stream.get_updates_since(
cmd.instance_name, current_token, cmd.token
)
if updates:
await self.on_rdata(
cmd.stream_name,
cmd.instance_name,
current_token,
[stream.parse_row(update[1]) for update in updates],
)
# We've now caught up to position sent to us, notify handler.
await self.replication_data_handler.on_position(cmd.stream_name, cmd.token)
# Handle any RDATA that came in while we were catching up.
rows = self.pending_batches.pop(cmd.stream_name, [])
if rows:
await self.on_rdata(
cmd.stream_name, cmd.instance_name, rows[-1].token, rows
)
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
"""Called when we get a new REMOTE_SERVER_UP command."""
if self.is_master:
self.notifier.notify_remote_server_up(cmd.data)
def get_currently_syncing_users(self):
"""Get the list of currently syncing users (if any). This is called
when a connection has been established and we need to send the
currently syncing users.
"""
return self.presence_handler.get_currently_syncing_users()
def send_command(self, cmd: Command):
"""Send a command to every currently open replication connection.
Nothing is sent if there are no connections.
"""
for conn in self.connections:
conn.send_command(cmd)
def send_federation_ack(self, token: int):
"""Ack data for the federation stream. This allows the master to drop
data stored purely in memory.
"""
self.send_command(FederationAckCommand(token))
def send_user_sync(
self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
):
"""Poke the master that a user has started/stopped syncing.
"""
self.send_command(
UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
)
def send_remove_pusher(self, app_id: str, push_key: str, user_id: str):
"""Poke the master to remove a pusher for a user
"""
cmd = RemovePusherCommand(app_id, push_key, user_id)
self.send_command(cmd)
def send_invalidate_cache(self, cache_func: Callable, keys: tuple):
"""Poke the master to invalidate a cache.
"""
cmd = InvalidateCacheCommand(cache_func.__name__, keys)
self.send_command(cmd)
def send_user_ip(
self,
user_id: str,
access_token: str,
ip: str,
user_agent: str,
device_id: str,
last_seen: int,
):
"""Tell the master that the user made a request.
"""
cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
self.send_command(cmd)
def send_remote_server_up(self, server: str):
self.send_command(RemoteServerUpCommand(server))
def stream_update(self, stream_name: str, token: str, data: Any):
"""Called when a new update is available to stream to clients.
We need to check if the client is interested in the stream or not
"""
self.send_command(RdataCommand(stream_name, self.instance_name, token, data))
class ReplicationDataHandler:
"""A replication data handler that passes incoming data to the slaved
store and the typing handler, where those are enabled.
"""
def __init__(self, hs):
self.store = hs.get_datastore()
self.typing_handler = hs.get_typing_handler()
self.slaved_store = hs.config.worker_app is not None
self.slaved_typing = not hs.config.server.handle_typing
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
):
"""Called to handle a batch of replication data with a given stream token.
By default this just pokes the slave store. Can be overridden in subclasses to
handle more.
Args:
stream_name (str): name of the replication stream for this batch of rows
token (int): stream token for this batch of rows
rows (list): a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
"""
if self.slaved_store:
self.store.process_replication_rows(stream_name, token, rows)
if self.slaved_typing:
self.typing_handler.process_replication_rows(stream_name, token, rows)
def get_streams_to_replicate(self) -> Dict[str, Dict[str, int]]:
"""Called when a new connection has been established and we need to
subscribe to streams.
Returns:
map from stream name to a map from instance name to the most recent
update we have for that stream (ie, the point we want to start
replicating from)
"""
args = {} # type: Dict[str, Dict[str, int]]
if self.slaved_store:
args = self.store.stream_positions()
if self.slaved_typing:
args.update(self.typing_handler.stream_positions())
return args
async def on_position(self, stream_name: str, token: int):
if self.slaved_store:
self.store.process_replication_rows(stream_name, token, [])
if self.slaved_typing:
self.typing_handler.process_replication_rows(stream_name, token, [])
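The catch-up loop in `on_POSITION` above keeps calling `get_updates_since` until the result is no longer limited. The sketch below reproduces that control flow against an in-memory stand-in. `FakeStream`, its `page_size` parameter and the `catch_up` function are assumptions made for illustration; only the `(updates, new_token, limited)` return shape is taken from the diff.

```python
import asyncio
from typing import List, Tuple

Update = Tuple[int, str]  # (token, row)


class FakeStream:
    """Hypothetical stand-in for a replication stream, backed by a list."""

    def __init__(self, rows: List[Update], page_size: int) -> None:
        self.rows = rows
        self.page_size = page_size

    async def get_updates_since(self, instance_name, from_token, upto_token):
        matching = [r for r in self.rows if from_token < r[0] <= upto_token]
        page = matching[: self.page_size]
        limited = len(matching) > len(page)
        # When limited, resume from the last row returned; otherwise we
        # have reached the requested upper bound.
        new_token = page[-1][0] if limited else upto_token
        return page, new_token, limited


async def catch_up(stream, instance_name, current_token, target_token):
    """Fetch everything in (current_token, target_token], page by page."""
    batches = []
    limited = target_token != current_token
    while limited:
        updates, current_token, limited = await stream.get_updates_since(
            instance_name, current_token, target_token
        )
        if updates:
            batches.append((current_token, [row for _, row in updates]))
    return batches


rows = [(1, "a"), (2, "b"), (3, "c"), (4, "d"), (5, "e")]
print(asyncio.run(catch_up(FakeStream(rows, 2), "master", 0, 5)))
```

Seeding `limited` with `target_token != current_token` means a client that is already up to date makes no requests at all.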


@@ -35,9 +35,7 @@ indicate which side is sending, these are *not* included on the wire::
> PING 1490197665618
< NAME synapse.app.appservice
< PING 1490197665618
< REPLICATE events 1
< REPLICATE backfill 1
< REPLICATE caches 1
< REPLICATE
> POSITION events 1
> POSITION backfill 1
> POSITION caches 1
@@ -48,45 +46,40 @@ indicate which side is sending, these are *not* included on the wire::
> ERROR server stopping
* connection closed by server *
"""
import abc
import fcntl
import logging
import struct
from collections import defaultdict
from typing import Any, DefaultDict, Dict, List, Set, Tuple
from typing import Any, DefaultDict, Dict, List, Set
from six import iteritems, iterkeys
from six import iteritems
from prometheus_client import Counter
from twisted.internet import defer
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.commands import (
COMMAND_MAP,
VALID_CLIENT_COMMANDS,
VALID_SERVER_COMMANDS,
Command,
ErrorCommand,
NameCommand,
PingCommand,
PositionCommand,
RdataCommand,
RemoteServerUpCommand,
ReplicateCommand,
ServerCommand,
SyncCommand,
UserSyncCommand,
)
from synapse.replication.tcp.streams import STREAMS_MAP
from synapse.types import Collection
from synapse.replication.tcp.streams import STREAMS_MAP, Stream
from synapse.util import Clock
from synapse.util.stringutils import random_string
MYPY = False
if MYPY:
from synapse.server import HomeServer
connection_close_counter = Counter(
"synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
)
@@ -127,16 +120,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
delimiter = b"\n"
# Valid commands we expect to receive
VALID_INBOUND_COMMANDS = [] # type: Collection[str]
# Valid commands we can send
VALID_OUTBOUND_COMMANDS = [] # type: Collection[str]
max_line_buffer = 10000
def __init__(self, clock):
def __init__(self, clock, handler):
self.clock = clock
self.handler = handler
self.last_received_command = self.clock.time_msec()
self.last_sent_command = 0
@@ -176,6 +164,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# can time us out.
self.send_command(PingCommand(self.clock.time_msec()))
self.handler.new_connection(self)
def send_ping(self):
"""Periodically sends a ping and checks if we should close the connection
due to the other side timing out.
@@ -213,11 +203,6 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
line = line.decode("utf-8")
cmd_name, rest_of_line = line.split(" ", 1)
if cmd_name not in self.VALID_INBOUND_COMMANDS:
logger.error("[%s] invalid command %s", self.id(), cmd_name)
self.send_error("invalid command: %s", cmd_name)
return
self.last_received_command = self.clock.time_msec()
self.inbound_commands_counter[cmd_name] = (
@@ -249,8 +234,23 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
Args:
cmd: received command
"""
handler = getattr(self, "on_%s" % (cmd.NAME,))
await handler(cmd)
handled = False
# First call any command handlers on this instance. These are for TCP
# specific handling.
cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
if cmd_func:
await cmd_func(cmd)
handled = True
# Then call out to the handler.
cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
if cmd_func:
await cmd_func(cmd)
handled = True
if not handled:
logger.warning("Unhandled command: %r", cmd)
def close(self):
logger.warning("[%s] Closing connection", self.id())
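The reworked command dispatch in the hunk above looks up `on_<NAME>` first on the protocol and then on the shared handler, and treats the command as unhandled only if neither defines it. A minimal standalone sketch of that lookup follows. The `Cmd`, `Handler` and `Protocol` classes are hypothetical stand-ins, and the real method logs a warning rather than returning a flag.

```python
import asyncio


class Cmd:
    """Hypothetical command object; only the NAME attribute matters here."""

    def __init__(self, name: str) -> None:
        self.NAME = name


class Handler:
    """Stand-in for the shared, transport-independent command handler."""

    def __init__(self) -> None:
        self.seen = []

    async def on_PING(self, cmd):
        self.seen.append(cmd.NAME)


class Protocol:
    """Stand-in for the connection, which handles transport-specific logic."""

    def __init__(self, handler: Handler) -> None:
        self.handler = handler
        self.seen = []

    async def on_PING(self, cmd):
        self.seen.append(cmd.NAME)

    async def on_NAME(self, cmd):
        self.seen.append(cmd.NAME)

    async def handle_command(self, cmd) -> bool:
        handled = False
        # Protocol first, then the handler; both may run for one command.
        for target in (self, self.handler):
            func = getattr(target, "on_%s" % (cmd.NAME,), None)
            if func:
                await func(cmd)
                handled = True
        return handled


handler = Handler()
protocol = Protocol(handler)
print(asyncio.run(protocol.handle_command(Cmd("PING"))))
print(asyncio.run(protocol.handle_command(Cmd("SYNC"))))
```

Using `getattr` with a `None` default is what lets either side omit a handler without raising `AttributeError`.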
@@ -258,6 +258,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.transport.loseConnection()
self.on_connection_closed()
def send_remote_server_up(self, server: str):
self.send_command(RemoteServerUpCommand(server))
def send_error(self, error_string, *args):
"""Send an error to remote and close the connection.
"""
@@ -379,6 +382,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.state = ConnectionStates.CLOSED
self.pending_commands = []
self.handler.lost_connection(self)
if self.transport:
self.transport.unregisterProducer()
@@ -402,346 +407,66 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
VALID_INBOUND_COMMANDS = VALID_CLIENT_COMMANDS
VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS
def __init__(self, server_name, clock, streamer):
BaseReplicationStreamProtocol.__init__(self, clock) # Old style class
def __init__(self, hs, server_name, clock, handler):
BaseReplicationStreamProtocol.__init__(self, clock, handler) # Old style class
self.server_name = server_name
self.streamer = streamer
# The streams the client has subscribed to and is up to date with
self.replication_streams = set() # type: Set[str]
# The streams the client is currently subscribing to.
self.connecting_streams = set() # type: Set[str]
# Map from stream name to list of updates to send once we've finished
# subscribing the client to the stream.
self.pending_rdata = {} # type: Dict[str, List[Tuple[int, Any]]]
def connectionMade(self):
self.send_command(ServerCommand(self.server_name))
BaseReplicationStreamProtocol.connectionMade(self)
self.streamer.new_connection(self)
async def on_NAME(self, cmd):
logger.info("[%s] Renamed to %r", self.id(), cmd.data)
self.name = cmd.data
async def on_USER_SYNC(self, cmd):
await self.streamer.on_user_sync(
self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
)
async def on_REPLICATE(self, cmd):
stream_name = cmd.stream_name
token = cmd.token
if stream_name == "ALL":
# Subscribe to all streams we're publishing to.
deferreds = [
run_in_background(self.subscribe_to_stream, stream, token)
for stream in iterkeys(self.streamer.streams_by_name)
]
await make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True)
)
else:
await self.subscribe_to_stream(stream_name, token)
async def on_FEDERATION_ACK(self, cmd):
self.streamer.federation_ack(cmd.token)
async def on_REMOVE_PUSHER(self, cmd):
await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
async def on_INVALIDATE_CACHE(self, cmd):
await self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
self.streamer.on_remote_server_up(cmd.data)
async def on_USER_IP(self, cmd):
self.streamer.on_user_ip(
cmd.user_id,
cmd.access_token,
cmd.ip,
cmd.user_agent,
cmd.device_id,
cmd.last_seen,
)
async def subscribe_to_stream(self, stream_name, token):
"""Subscribe the remote to a stream.
This involves checking if they've missed anything and sending those
updates down if they have. During that time new updates for the stream
are queued and sent once we've sent down any missed updates.
"""
self.replication_streams.discard(stream_name)
self.connecting_streams.add(stream_name)
try:
# Get missing updates
updates, current_token = await self.streamer.get_stream_updates(
stream_name, token
)
# Send all the missing updates
for update in updates:
token, row = update[0], update[1]
self.send_command(RdataCommand(stream_name, token, row))
# We send a POSITION command to ensure that they have an up to
# date token (especially useful if we didn't send any updates
# above)
self.send_command(PositionCommand(stream_name, current_token))
# Now we can send any updates that came in while we were subscribing
pending_rdata = self.pending_rdata.pop(stream_name, [])
updates = []
for token, update in pending_rdata:
# If the token is null, it is part of a batch update. Batches
# are multiple updates that share a single token. To denote
# this, the token is set to None for all tokens in the batch
# except for the last. If we find a None token, we keep looking
# through tokens until we find one that is not None and then
# process all previous updates in the batch as if they had the
# final token.
if token is None:
# Store this update as part of a batch
updates.append(update)
continue
if token <= current_token:
# This update or batch of updates is older than
# current_token, dismiss it
updates = []
continue
updates.append(update)
# Send all updates that are part of this batch with the
# found token
for update in updates:
self.send_command(RdataCommand(stream_name, token, update))
# Clear stored updates
updates = []
# They're now fully subscribed
self.replication_streams.add(stream_name)
except Exception as e:
logger.exception("[%s] Failed to handle REPLICATE command", self.id())
self.send_error("failed to handle replicate: %r", e)
finally:
self.connecting_streams.discard(stream_name)
def stream_update(self, stream_name, token, data):
"""Called when a new update is available to stream to clients.
We need to check if the client is interested in the stream or not
"""
if stream_name in self.replication_streams:
# The client is subscribed to the stream
self.send_command(RdataCommand(stream_name, token, data))
elif stream_name in self.connecting_streams:
# The client is being subscribed to the stream
logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token)
self.pending_rdata.setdefault(stream_name, []).append((token, data))
else:
# The client isn't subscribed
logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token)
def send_sync(self, data):
self.send_command(SyncCommand(data))
def send_remote_server_up(self, server: str):
self.send_command(RemoteServerUpCommand(server))
def on_connection_closed(self):
BaseReplicationStreamProtocol.on_connection_closed(self)
self.streamer.lost_connection(self)
class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
"""
The interface for the handler that should be passed to
ClientReplicationStreamProtocol
"""
@abc.abstractmethod
async def on_rdata(self, stream_name, token, rows):
"""Called to handle a batch of replication data with a given stream token.
Args:
stream_name (str): name of the replication stream for this batch of rows
token (int): stream token for this batch of rows
rows (list): a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
"""
raise NotImplementedError()
@abc.abstractmethod
async def on_position(self, stream_name, token):
"""Called when we get new position data."""
raise NotImplementedError()
@abc.abstractmethod
def on_sync(self, data):
"""Called when we get a new SYNC command."""
raise NotImplementedError()
@abc.abstractmethod
async def on_remote_server_up(self, server: str):
"""Called when we get a new REMOTE_SERVER_UP command."""
raise NotImplementedError()
@abc.abstractmethod
def get_streams_to_replicate(self):
"""Called when a new connection has been established and we need to
subscribe to streams.
Returns:
map from stream name to the most recent update we have for
that stream (ie, the point we want to start replicating from)
"""
raise NotImplementedError()
@abc.abstractmethod
def get_currently_syncing_users(self):
"""Get the list of currently syncing users (if any). This is called
when a connection has been established and we need to send the
currently syncing users."""
raise NotImplementedError()
@abc.abstractmethod
def update_connection(self, connection):
"""Called when a connection has been established (or lost with None).
"""
raise NotImplementedError()
@abc.abstractmethod
def finished_connecting(self):
"""Called when we have successfully subscribed and caught up to all
streams we're interested in.
"""
raise NotImplementedError()
class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS
def __init__(
self,
hs: "HomeServer",
client_name: str,
server_name: str,
clock: Clock,
handler: AbstractReplicationClientHandler,
handler,
):
BaseReplicationStreamProtocol.__init__(self, clock)
BaseReplicationStreamProtocol.__init__(self, clock, handler)
self.instance_id = hs.get_instance_id()
self.client_name = client_name
self.server_name = server_name
self.handler = handler
self.streams = {
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
} # type: Dict[str, Stream]
# Set of stream names that have been subscribed to, but haven't yet
# caught up with. This is used to track when the client has been fully
# connected to the remote.
self.streams_connecting = set() # type: Set[str]
self.streams_connecting = set(STREAMS_MAP) # type: Set[str]
# Map of stream to batched updates. See RdataCommand for info on how
# batching works.
self.pending_batches = {} # type: Dict[str, Any]
self.pending_batches = {} # type: Dict[str, List[Any]]
def connectionMade(self):
self.send_command(NameCommand(self.client_name))
BaseReplicationStreamProtocol.connectionMade(self)
# Once we've connected subscribe to the necessary streams
for stream_name, token in iteritems(self.handler.get_streams_to_replicate()):
self.replicate(stream_name, token)
# Tell the server if we have any users currently syncing (should only
# happen on synchrotrons)
currently_syncing = self.handler.get_currently_syncing_users()
now = self.clock.time_msec()
for user_id in currently_syncing:
self.send_command(UserSyncCommand(user_id, True, now))
# We've now finished connecting, so inform the client handler
self.handler.update_connection(self)
# This will happen if we don't actually subscribe to any streams
if not self.streams_connecting:
self.handler.finished_connecting()
self.send_command(NameCommand(self.client_name))
self.replicate()
async def on_SERVER(self, cmd):
if cmd.data != self.server_name:
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
self.send_error("Wrong remote")
async def on_RDATA(self, cmd):
stream_name = cmd.stream_name
inbound_rdata_count.labels(stream_name).inc()
try:
row = STREAMS_MAP[stream_name].parse_row(cmd.row)
except Exception:
logger.exception(
"[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row
)
raise
if cmd.token is None:
# I.e. this is part of a batch of updates for this stream. Batch
# until we get an update for the stream with a non None token
self.pending_batches.setdefault(stream_name, []).append(row)
else:
# Check if this is the last of a batch of updates
rows = self.pending_batches.pop(stream_name, [])
rows.append(row)
await self.handler.on_rdata(stream_name, cmd.token, rows)
async def on_POSITION(self, cmd):
# When we get a `POSITION` command it means we've finished getting
# missing updates for the given stream, and are now up to date.
self.streams_connecting.discard(cmd.stream_name)
if not self.streams_connecting:
self.handler.finished_connecting()
await self.handler.on_position(cmd.stream_name, cmd.token)
async def on_SYNC(self, cmd):
self.handler.on_sync(cmd.data)
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
self.handler.on_remote_server_up(cmd.data)
def replicate(self, stream_name, token):
def replicate(self):
"""Send the subscription request to the server
"""
if stream_name not in STREAMS_MAP:
raise Exception("Invalid stream name %r" % (stream_name,))
logger.info("[%s] Subscribing to replication streams", self.id())
logger.info(
"[%s] Subscribing to replication stream: %r from %r",
self.id(),
stream_name,
token,
)
self.streams_connecting.add(stream_name)
self.send_command(ReplicateCommand(stream_name, token))
def on_connection_closed(self):
BaseReplicationStreamProtocol.on_connection_closed(self)
self.handler.update_connection(None)
self.send_command(ReplicateCommand())
# The following simply registers metrics for the replication connections
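The batching convention used by `on_RDATA` above (rows with a `None` token are held back until a row with a real token closes the batch) can be sketched in isolation. This is a minimal illustration only: `BatchAssembler` and `feed` are hypothetical names that do not exist in Synapse, and the rows are plain strings rather than parsed `ROW_TYPE` objects.

```python
from typing import Any, Dict, List, Optional, Tuple


class BatchAssembler:
    """Hypothetical helper mirroring the pending_batches logic in on_RDATA."""

    def __init__(self) -> None:
        # Map from stream name to the rows received so far with a None token.
        self.pending_batches: Dict[str, List[Any]] = {}

    def feed(
        self, stream_name: str, token: Optional[int], row: Any
    ) -> Optional[Tuple[int, List[Any]]]:
        """Return (token, rows) once a batch is complete, else None."""
        if token is None:
            # Part of a batch: hold the row until a non-None token arrives.
            self.pending_batches.setdefault(stream_name, []).append(row)
            return None

        # A real token closes the batch (or is a batch of one).
        rows = self.pending_batches.pop(stream_name, [])
        rows.append(row)
        return token, rows
```

Batches are tracked per stream, so rows for one stream can be interleaved with complete updates for another without the two being mixed up.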


@@ -0,0 +1,158 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import txredisapi
from synapse.logging.context import PreserveLoggingContext
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.commands import (
COMMAND_MAP,
Command,
RdataCommand,
ReplicateCommand,
)
from synapse.util.stringutils import random_string
logger = logging.getLogger(__name__)
class RedisSubscriber(txredisapi.SubscriberProtocol):
"""Connection to redis subscribed to replication stream.
"""
def connectionMade(self):
logger.info("Connected to redis instance")
self.subscribe(self.stream_name)
self.send_command(ReplicateCommand())
self.handler.new_connection(self)
def messageReceived(self, pattern: str, channel: str, message: str):
"""Received a message from redis.
"""
if message.strip() == "":
# Ignore blank lines
return
line = message
cmd_name, rest_of_line = line.split(" ", 1)
cmd_cls = COMMAND_MAP[cmd_name]
try:
cmd = cmd_cls.from_line(rest_of_line)
except Exception as e:
logger.exception(
"[%s] failed to parse line %r: %r", self.id(), cmd_name, rest_of_line
)
self.send_error(
"failed to parse line for %r: %r (%r):" % (cmd_name, e, rest_of_line)
)
return
# Now let's try to call the on_<CMD_NAME> function
run_as_background_process(
"replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
)
async def handle_command(self, cmd: Command):
"""Handle a command we have received over the replication stream.
By default delegates to on_<COMMAND>, which should return an awaitable.
Args:
cmd: received command
"""
handled = False
# First call any command handlers on this instance. These are for redis
# specific handling.
cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
if cmd_func:
await cmd_func(cmd)
handled = True
# Then call out to the handler.
cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
if cmd_func:
await cmd_func(cmd)
handled = True
if not handled:
logger.warning("Unhandled command: %r", cmd)
def connectionLost(self, reason):
logger.info("Lost connection to redis instance")
self.handler.lost_connection(self)
def send_command(self, cmd):
"""Send a command if connection has been established.
Args:
cmd (Command)
"""
string = "%s %s" % (cmd.NAME, cmd.to_line())
if "\n" in string:
raise Exception("Unexpected newline in command: %r" % (string,))
encoded_string = string.encode("utf-8")
async def _send():
with PreserveLoggingContext():
await self.redis_connection.publish(self.stream_name, encoded_string)
run_as_background_process("send-cmd", _send)
def stream_update(self, stream_name, token, data):
"""Called when a new update is available to stream to clients.
We need to check if the client is interested in the stream or not
"""
self.send_command(RdataCommand(stream_name, "master", token, data))
class RedisFactory(txredisapi.SubscriberFactory):
maxDelay = 5
continueTrying = True
protocol = RedisSubscriber
def __init__(self, hs):
super(RedisFactory, self).__init__()
self.password = hs.config.redis.redis_password
self.handler = hs.get_tcp_replication()
self.stream_name = hs.hostname
self.redis_connection = txredisapi.lazyConnection(
host=hs.config.redis_host,
port=hs.config.redis_port,
dbid=hs.config.redis_dbid,
password=hs.config.redis.redis_password,
reconnect=True,
)
self.conn_id = random_string(5)
def buildProtocol(self, addr):
p = super(RedisFactory, self).buildProtocol(addr)
p.handler = self.handler
p.redis_connection = self.redis_connection
p.conn_id = self.conn_id
p.stream_name = self.stream_name
return p
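Both the TCP protocol and the redis subscriber above dispatch a command in two stages: first to an `on_<NAME>` method on the connection itself, then to one on the shared handler. The sketch below reproduces that lookup with stand-in classes. `Command`, `Handler` and `Connection` here are simplified placeholders rather than the real Synapse classes, and returning `handled` is an addition made only so the behaviour can be checked.

```python
import asyncio
import logging

logger = logging.getLogger(__name__)


class Command:
    """Stand-in for a replication command; only NAME is used here."""

    def __init__(self, name):
        self.NAME = name


class Handler:
    """Stand-in for the transport-independent command handler."""

    def __init__(self):
        self.seen = []

    async def on_RDATA(self, cmd):
        self.seen.append(("handler", cmd.NAME))


class Connection:
    """Stand-in for a transport (TCP or redis) connection."""

    def __init__(self, handler):
        self.handler = handler
        self.seen = []

    async def on_RDATA(self, cmd):
        # Transport-specific handling runs first.
        self.seen.append(("connection", cmd.NAME))

    async def handle_command(self, cmd):
        handled = False
        # Look for on_<NAME> on the connection, then on the handler.
        for target in (self, self.handler):
            cmd_func = getattr(target, "on_%s" % (cmd.NAME,), None)
            if cmd_func:
                await cmd_func(cmd)
                handled = True
        if not handled:
            logger.warning("Unhandled command: %r", cmd)
        return handled
```

Because the lookup uses `getattr` with a default of `None`, a command that neither object implements is logged rather than raising `AttributeError`, which is the behavioural change from the old single `getattr(self, ...)` call.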


@@ -17,32 +17,21 @@
import logging
import random
from typing import Any, List
from six import itervalues
from typing import Dict, List
from prometheus_client import Counter
from twisted.internet.protocol import Factory
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.metrics import Measure, measure_func
from .protocol import ServerReplicationStreamProtocol
from .streams import STREAMS_MAP
from .streams.federation import FederationStream
from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
from synapse.replication.tcp.streams import STREAMS_MAP, Stream, TypingStream
from synapse.replication.tcp.streams.federation import FederationStream
from synapse.util.metrics import Measure
stream_updates_counter = Counter(
"synapse_replication_tcp_resource_stream_updates", "", ["stream_name"]
)
user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
invalidate_cache_counter = Counter(
"synapse_replication_tcp_resource_invalidate_cache", ""
)
user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
logger = logging.getLogger(__name__)
@@ -52,13 +41,18 @@ class ReplicationStreamProtocolFactory(Factory):
"""
def __init__(self, hs):
self.streamer = ReplicationStreamer(hs)
self.handler = hs.get_tcp_replication()
self.clock = hs.get_clock()
self.server_name = hs.config.server_name
self.hs = hs
# Ensure the replication streamer is started if we register a
# replication server endpoint.
hs.get_replication_streamer()
def buildProtocol(self, addr):
return ServerReplicationStreamProtocol(
self.server_name, self.clock, self.streamer
self.hs, self.server_name, self.clock, self.handler
)
@@ -71,67 +65,43 @@ class ReplicationStreamer(object):
def __init__(self, hs):
self.store = hs.get_datastore()
self.presence_handler = hs.get_presence_handler()
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
self._server_notices_sender = hs.get_server_notices_sender()
self._replication_torture_level = hs.config.replication_torture_level
# Current connections.
self.connections = [] # type: List[ServerReplicationStreamProtocol]
# Work out list of streams that this instance is the source of.
self.streams = [] # type: List[Stream]
if hs.config.worker_app is None:
for stream in STREAMS_MAP.values():
if stream == FederationStream:
continue
LaterGauge(
"synapse_replication_tcp_resource_total_connections",
"",
[],
lambda: len(self.connections),
)
if stream == TypingStream:
continue
# List of streams that clients can subscribe to.
# We only support federation stream if federation sending has been
# disabled on the master.
self.streams = [
stream(hs)
for stream in itervalues(STREAMS_MAP)
if stream != FederationStream or not hs.config.send_federation
]
self.streams.append(stream(hs))
if hs.config.server.handle_typing:
self.streams.append(TypingStream(hs))
# We always add federation stream
self.streams.append(FederationStream(hs))
self.streams_by_name = {stream.NAME: stream for stream in self.streams}
LaterGauge(
"synapse_replication_tcp_resource_connections_per_stream",
"",
["stream_name"],
lambda: {
(stream_name,): len(
[
conn
for conn in self.connections
if stream_name in conn.replication_streams
]
)
for stream_name in self.streams_by_name
},
)
self.federation_sender = None
if not hs.config.send_federation:
self.federation_sender = hs.get_federation_sender()
self.notifier.add_replication_callback(self.on_notifier_poke)
self.notifier.add_remote_server_up_callback(self.send_remote_server_up)
# Keeps track of whether we are currently checking for updates
self.is_looping = False
self.pending_updates = False
hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.on_shutdown)
self.client = hs.get_tcp_replication()
def on_shutdown(self):
# close all connections on shutdown
for conn in self.connections:
conn.send_error("server shutting down")
def get_streams(self) -> Dict[str, Stream]:
"""Get a map from stream name to stream instance.
"""
return self.streams_by_name
def on_notifier_poke(self):
"""Checks if there is actually any new data and sends it to the
@@ -140,7 +110,7 @@ class ReplicationStreamer(object):
This should get called each time new data is available, even if it
is currently being executed, so that nothing gets missed
"""
if not self.connections:
if not self.client.connected():
# Don't bother if nothing is listening. We still need to advance
# the stream tokens otherwise they'll fall behind forever
for stream in self.streams:
@@ -190,15 +160,14 @@ class ReplicationStreamer(object):
stream.current_token(),
)
try:
updates, current_token = await stream.get_updates()
updates, current_token, limited = await stream.get_updates()
self.pending_updates |= limited
except Exception:
logger.info("Failed to handle stream %s", stream.NAME)
raise
logger.debug(
"Sending %d updates to %d connections",
len(updates),
len(self.connections),
"Sending %d updates", len(updates),
)
if updates:
@@ -214,116 +183,17 @@ class ReplicationStreamer(object):
# token. See RdataCommand for more details.
batched_updates = _batch_updates(updates)
for conn in self.connections:
for token, row in batched_updates:
try:
conn.stream_update(stream.NAME, token, row)
except Exception:
logger.exception("Failed to replicate")
for token, row in batched_updates:
try:
self.client.stream_update(stream.NAME, token, row)
except Exception:
logger.exception("Failed to replicate")
logger.debug("No more pending updates, breaking poke loop")
finally:
self.pending_updates = False
self.is_looping = False
@measure_func("repl.get_stream_updates")
async def get_stream_updates(self, stream_name, token):
"""For a given stream get all updates since token. This is called when
a client first subscribes to a stream.
"""
stream = self.streams_by_name.get(stream_name, None)
if not stream:
raise Exception("unknown stream %s" % (stream_name,))
return await stream.get_updates_since(token)
@measure_func("repl.federation_ack")
def federation_ack(self, token):
"""We've received an ack for federation stream from a client.
"""
federation_ack_counter.inc()
if self.federation_sender:
self.federation_sender.federation_ack(token)
@measure_func("repl.on_user_sync")
async def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
"""A client has started/stopped syncing on a worker.
"""
user_sync_counter.inc()
await self.presence_handler.update_external_syncs_row(
conn_id, user_id, is_syncing, last_sync_ms
)
@measure_func("repl.on_remove_pusher")
async def on_remove_pusher(self, app_id, push_key, user_id):
"""A client has asked us to remove a pusher
"""
remove_pusher_counter.inc()
await self.store.delete_pusher_by_app_id_pushkey_user_id(
app_id=app_id, pushkey=push_key, user_id=user_id
)
self.notifier.on_new_replication_data()
@measure_func("repl.on_invalidate_cache")
async def on_invalidate_cache(self, cache_func: str, keys: List[Any]):
"""The client has asked us to invalidate a cache
"""
invalidate_cache_counter.inc()
# We invalidate the cache locally, but then also stream that to other
# workers.
await self.store.invalidate_cache_and_stream(cache_func, tuple(keys))
@measure_func("repl.on_user_ip")
async def on_user_ip(
self, user_id, access_token, ip, user_agent, device_id, last_seen
):
"""The client saw a user request
"""
user_ip_cache_counter.inc()
await self.store.insert_client_ip(
user_id, access_token, ip, user_agent, device_id, last_seen
)
await self._server_notices_sender.on_user_ip(user_id)
@measure_func("repl.on_remote_server_up")
def on_remote_server_up(self, server: str):
self.notifier.notify_remote_server_up(server)
def send_remote_server_up(self, server: str):
for conn in self.connections:
conn.send_remote_server_up(server)
def send_sync_to_all_connections(self, data):
"""Sends a SYNC command to all clients.
Used in tests.
"""
for conn in self.connections:
conn.send_sync(data)
def new_connection(self, connection):
"""A new client connection has been established
"""
self.connections.append(connection)
def lost_connection(self, connection):
"""A client connection has been lost
"""
try:
self.connections.remove(connection)
except ValueError:
pass
# We need to tell the presence handler that the connection has been
# lost so that it can handle any ongoing syncs on that connection.
run_as_background_process(
"update_external_syncs_clear",
self.presence_handler.update_external_syncs_clear,
connection.conn_id,
)
def _batch_updates(updates):
"""Takes a list of updates of form [(token, row)] and sets the token to


@@ -24,6 +24,9 @@ Each stream is defined by the following information:
current_token: The function that returns the current token for the stream
update_function: The function that returns a list of updates between two tokens
"""
from typing import Dict, Type
from synapse.replication.tcp.streams._base import (
AccountDataStream,
BackfillStream,
@@ -35,6 +38,7 @@ from synapse.replication.tcp.streams._base import (
PushersStream,
PushRulesStream,
ReceiptsStream,
Stream,
TagAccountDataStream,
ToDeviceStream,
TypingStream,
@@ -63,10 +67,12 @@ STREAMS_MAP = {
GroupServerStream,
UserSignatureStream,
)
}
} # type: Dict[str, Type[Stream]]
__all__ = [
"STREAMS_MAP",
"Stream",
"BackfillStream",
"PresenceStream",
"TypingStream",


@@ -14,13 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import logging
from collections import namedtuple
from typing import Any, List, Optional, Tuple
from typing import Any, Awaitable, Callable, List, Optional, Tuple
import attr
from synapse.replication.http.streams import ReplicationGetStreamUpdates
from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@@ -29,6 +29,15 @@ logger = logging.getLogger(__name__)
MAX_EVENTS_BEHIND = 500000
# Some type aliases to make things a bit easier.
# A stream position token
Token = int
# A pair of position in stream and args used to create an instance of `ROW_TYPE`.
StreamRow = Tuple[Token, tuple]
class Stream(object):
"""Base class for the streams.
@@ -56,70 +65,58 @@ class Stream(object):
return cls.ROW_TYPE(*row)
def __init__(self, hs):
# The token from which we last asked for updates
self.last_token = self.current_token()
self.local_instance_name = hs.config.worker_name or "master"
def discard_updates_and_advance(self):
"""Called when the stream should advance but the updates would be discarded,
e.g. when there are no currently connected workers.
"""
self.last_token = self.current_token()
async def get_updates(self):
async def get_updates(self) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
"""Gets all updates since the last time this function was called (or
since the stream was constructed if it hadn't been called before).
Returns:
Deferred[Tuple[List[Tuple[int, Any]], int]]:
Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
list of ``(token, row)`` entries. ``row`` will be json-serialised and
sent over the replication stream.
A triplet `(updates, new_last_token, limited)`, where `updates` is
a list of `(token, row)` entries, `new_last_token` is the new
position in stream, and `limited` is whether there are more updates
to fetch.
"""
updates, current_token = await self.get_updates_since(self.last_token)
current_token = self.current_token()
updates, current_token, limited = await self.get_updates_since(
self.local_instance_name, self.last_token, current_token
)
self.last_token = current_token
return updates, current_token
return updates, current_token, limited
async def get_updates_since(
self, from_token: int
) -> Tuple[List[Tuple[int, JsonDict]], int]:
self, instance_name: str, from_token: Token, upto_token: Token, limit: int = 100
) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
"""Like get_updates except allows specifying from when we should
stream updates
Returns:
Resolves to a pair `(updates, new_last_token)`, where `updates` is
a list of `(token, row)` entries and `new_last_token` is the new
position in stream.
A triplet `(updates, new_last_token, limited)`, where `updates` is
a list of `(token, row)` entries, `new_last_token` is the new
position in stream, and `limited` is whether there are more updates
to fetch.
"""
if from_token in ("NOW", "now"):
return [], self.current_token()
current_token = self.current_token()
from_token = int(from_token)
if from_token == current_token:
return [], current_token
if from_token == upto_token:
return [], upto_token, False
rows = await self.update_function(
from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
updates, upto_token, limited = await self.update_function(
instance_name, from_token, upto_token, limit=limit,
)
# never turn more than MAX_EVENTS_BEHIND + 1 into updates.
rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
updates = [(row[0], row[1:]) for row in rows]
# check we didn't get more rows than the limit.
# doing it like this allows the update_function to be a generator.
if len(updates) >= MAX_EVENTS_BEHIND:
raise Exception("stream %s has fallen behind" % (self.NAME))
# The update function didn't hit the limit, so we must have got all
# the updates to `current_token`, and can return that as our new
# stream position.
return updates, current_token
return updates, upto_token, limited
def current_token(self):
"""Gets the current token of the underlying streams. Should be provided
@@ -141,6 +138,49 @@ class Stream(object):
raise NotImplementedError()
def db_query_to_update_function(
query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
) -> Callable[[str, Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
"""Wraps a db query function which returns a list of rows to make it
suitable for use as an `update_function` for the Stream class
"""
async def update_function(instance_name, from_token, upto_token, limit):
rows = await query_function(from_token, upto_token, limit)
updates = [(row[0], row[1:]) for row in rows]
limited = False
if len(updates) == limit:
upto_token = rows[-1][0]
limited = True
return updates, upto_token, limited
return update_function
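The `limited` flag returned by the wrapper above tells the caller that the query hit its row limit and that the returned token is only the position of the last row fetched, not the requested `upto_token`. A caller can therefore page through a backlog by feeding that token back in. The sketch below assumes a query that returns rows with `from_token < id <= upto_token`; `ROWS`, `query_function` and `drain` are hypothetical names made up for the illustration, and `update_function` copies the wrapper body shown above.

```python
import asyncio
from typing import List

# Fake stream contents: (stream_id, payload) for stream ids 1..7.
ROWS = [(i, "row-%d" % i) for i in range(1, 8)]


async def query_function(from_token: int, upto_token: int, limit: int) -> List[tuple]:
    """Hypothetical db query: rows with from_token < id <= upto_token."""
    matching = [row for row in ROWS if from_token < row[0] <= upto_token]
    return matching[:limit]


async def update_function(instance_name, from_token, upto_token, limit):
    """Same logic as the wrapper produced by db_query_to_update_function."""
    rows = await query_function(from_token, upto_token, limit)
    updates = [(row[0], row[1:]) for row in rows]
    limited = False
    if len(updates) == limit:
        # We may not have reached upto_token, so report the last row's id.
        upto_token = rows[-1][0]
        limited = True
    return updates, upto_token, limited


async def drain(from_token: int, upto_token: int, limit: int):
    """Keep fetching pages until the update function reports limited=False."""
    pages = []
    limited = True
    while limited:
        updates, from_token, limited = await update_function(
            "master", from_token, upto_token, limit
        )
        pages.append([token for token, _ in updates])
    return pages, from_token
```

With seven rows and a limit of three, `drain(0, 7, 3)` fetches three pages and finishes at token 7. When a page happens to fill the limit exactly at the end of the backlog, the loop makes one extra call that returns no rows, which is the cost of inferring `limited` from the row count alone.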
def make_http_update_function(
hs, stream_name: str
) -> Callable[[str, Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
"""Makes a suitable function for use as an `update_function` that queries
the master process for updates.
"""
client = ReplicationGetStreamUpdates.make_client(hs)
async def update_function(
instance_name: str, from_token: int, upto_token: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
return await client(
instance_name=instance_name,
stream_name=stream_name,
from_token=from_token,
upto_token=upto_token,
limit=limit,
)
return update_function
class BackfillStream(Stream):
"""We fetched some old events and either we had never seen that event before
or it went from being an outlier to not.
@@ -164,7 +204,7 @@ class BackfillStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_current_backfill_token # type: ignore
self.update_function = store.get_all_new_backfill_event_rows # type: ignore
self.update_function = db_query_to_update_function(store.get_all_new_backfill_event_rows) # type: ignore
super(BackfillStream, self).__init__(hs)
@@ -190,8 +230,15 @@ class PresenceStream(Stream):
store = hs.get_datastore()
presence_handler = hs.get_presence_handler()
self._is_worker = hs.config.worker_app is not None
self.current_token = store.get_current_presence_token # type: ignore
self.update_function = presence_handler.get_all_presence_updates # type: ignore
if hs.config.worker_app is None:
self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates) # type: ignore
else:
# Query master process
self.update_function = make_http_update_function(hs, self.NAME) # type: ignore
super(PresenceStream, self).__init__(hs)
@@ -208,7 +255,12 @@ class TypingStream(Stream):
typing_handler = hs.get_typing_handler()
self.current_token = typing_handler.get_current_token # type: ignore
self.update_function = typing_handler.get_all_typing_updates # type: ignore
if hs.config.handle_typing:
self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates) # type: ignore
else:
# Query master process
self.update_function = make_http_update_function(hs, self.NAME) # type: ignore
super(TypingStream, self).__init__(hs)
@@ -232,7 +284,7 @@ class ReceiptsStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_max_receipt_stream_id # type: ignore
self.update_function = store.get_all_updated_receipts # type: ignore
self.update_function = db_query_to_update_function(store.get_all_updated_receipts) # type: ignore
super(ReceiptsStream, self).__init__(hs)
@@ -256,7 +308,13 @@ class PushRulesStream(Stream):
async def update_function(self, instance_name, from_token, to_token, limit):
rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
return [(row[0], row[2]) for row in rows]
limited = False
if len(rows) == limit:
to_token = rows[-1][0]
limited = True
return [(row[0], (row[2],)) for row in rows], to_token, limited
class PushersStream(Stream):
@@ -275,7 +333,7 @@ class PushersStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_pushers_stream_token # type: ignore
self.update_function = store.get_all_updated_pushers_rows # type: ignore
self.update_function = db_query_to_update_function(store.get_all_updated_pushers_rows) # type: ignore
super(PushersStream, self).__init__(hs)
@@ -307,7 +365,7 @@ class CachesStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_cache_stream_token # type: ignore
self.update_function = store.get_all_updated_caches # type: ignore
self.update_function = db_query_to_update_function(store.get_all_updated_caches) # type: ignore
super(CachesStream, self).__init__(hs)
@@ -333,7 +391,7 @@ class PublicRoomsStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_current_public_room_stream_id # type: ignore
self.update_function = store.get_all_new_public_rooms # type: ignore
self.update_function = db_query_to_update_function(store.get_all_new_public_rooms) # type: ignore
super(PublicRoomsStream, self).__init__(hs)
@@ -354,7 +412,7 @@ class DeviceListsStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_device_stream_token # type: ignore
self.update_function = store.get_all_device_list_changes_for_remotes # type: ignore
self.update_function = db_query_to_update_function(store.get_all_device_list_changes_for_remotes) # type: ignore
super(DeviceListsStream, self).__init__(hs)
@@ -372,7 +430,7 @@ class ToDeviceStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_to_device_stream_token # type: ignore
self.update_function = store.get_all_new_device_messages # type: ignore
self.update_function = db_query_to_update_function(store.get_all_new_device_messages) # type: ignore
super(ToDeviceStream, self).__init__(hs)
@@ -392,7 +450,7 @@ class TagAccountDataStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_max_account_data_stream_id # type: ignore
self.update_function = store.get_all_updated_tags # type: ignore
self.update_function = db_query_to_update_function(store.get_all_updated_tags) # type: ignore
super(TagAccountDataStream, self).__init__(hs)
@@ -412,10 +470,11 @@ class AccountDataStream(Stream):
self.store = hs.get_datastore()
self.current_token = self.store.get_max_account_data_stream_id # type: ignore
self.update_function = db_query_to_update_function(self._update_function) # type: ignore
super(AccountDataStream, self).__init__(hs)
async def update_function(self, from_token, to_token, limit):
async def _update_function(self, from_token, to_token, limit):
global_results, room_results = await self.store.get_all_updated_account_data(
from_token, from_token, to_token, limit
)
@@ -442,7 +501,7 @@ class GroupServerStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_group_stream_token # type: ignore
self.update_function = store.get_all_groups_changes # type: ignore
self.update_function = db_query_to_update_function(store.get_all_groups_changes) # type: ignore
super(GroupServerStream, self).__init__(hs)
@@ -460,6 +519,6 @@ class UserSignatureStream(Stream):
store = hs.get_datastore()
self.current_token = store.get_device_stream_token # type: ignore
self.update_function = store.get_all_user_signature_changes_for_remotes # type: ignore
self.update_function = db_query_to_update_function(store.get_all_user_signature_changes_for_remotes) # type: ignore
super(UserSignatureStream, self).__init__(hs)
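
The definition of `db_query_to_update_function` is not visible in this part of the diff. The following is only a minimal sketch of what such a wrapper might look like, assuming it produces the same `(updates, to_token, limited)` shape that `PushRulesStream.update_function` returns above. `fake_query` is a hypothetical stand-in for a store method such as `get_all_updated_receipts`.

```python
import asyncio


def db_query_to_update_function(query_function):
    """Sketch only: adapt a row-returning query so that it returns
    (updates, upto_token, limited). Not the definition from this branch."""

    async def update_function(from_token, upto_token, limit):
        rows = await query_function(from_token, upto_token, limit)
        # Assumption: the first column is the stream token, the rest is the row.
        updates = [(row[0], tuple(row[1:])) for row in rows]
        limited = False
        if len(rows) == limit:
            # The query may have been cut short, so only claim progress up to
            # the last row that was actually returned.
            upto_token = rows[-1][0]
            limited = True
        return updates, upto_token, limited

    return update_function


async def fake_query(from_token, upto_token, limit):
    # Hypothetical stand-in for a store method.
    rows = [(i, "row-%d" % i) for i in range(from_token + 1, upto_token + 1)]
    return rows[:limit]


wrapped = db_query_to_update_function(fake_query)
result = asyncio.run(wrapped(0, 5, 3))
```

Here `result` reports `limited=True` and a token of 3, so a caller would know to ask again from token 3 rather than assume it has caught up to 5.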

View File

@@ -19,7 +19,7 @@ from typing import Tuple, Type
import attr
from ._base import Stream
from ._base import Stream, db_query_to_update_function
"""Handling of the 'events' replication stream
@@ -117,10 +117,11 @@ class EventsStream(Stream):
def __init__(self, hs):
self._store = hs.get_datastore()
self.current_token = self._store.get_current_events_token # type: ignore
self.update_function = db_query_to_update_function(self._update_function) # type: ignore
super(EventsStream, self).__init__(hs)
async def update_function(self, from_token, current_token, limit=None):
async def _update_function(self, from_token, current_token, limit=None):
event_rows = await self._store.get_all_new_forward_event_rows(
from_token, current_token, limit
)

View File

@@ -15,7 +15,7 @@
# limitations under the License.
from collections import namedtuple
from ._base import Stream
from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
class FederationStream(Stream):
@@ -33,11 +33,14 @@ class FederationStream(Stream):
NAME = "federation"
ROW_TYPE = FederationStreamRow
_QUERY_MASTER = True
def __init__(self, hs):
# Not all synapse instances will have a federation sender instance,
# whether that's a `FederationSender` or a `FederationRemoteSendQueue`,
# so we stub the stream out when that is the case.
federation_sender = hs.get_federation_sender()
self.current_token = federation_sender.get_current_token # type: ignore
self.update_function = federation_sender.get_replication_rows # type: ignore
self.update_function = db_query_to_update_function(federation_sender.get_replication_rows) # type: ignore
super(FederationStream, self).__init__(hs)

View File

@@ -816,7 +816,7 @@ class RoomTypingRestServlet(RestServlet):
content = parse_json_object_from_request(request)
await self.presence_handler.bump_presence_active_time(requester.user)
# await self.presence_handler.bump_presence_active_time(requester.user)
# Limit timeout to stop people from setting silly typing timeouts.
timeout = min(content.get("timeout", 30000), 120000)

View File

@@ -78,13 +78,18 @@ from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
from synapse.handlers.set_password import SetPasswordHandler
from synapse.handlers.stats import StatsHandler
from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import TypingHandler
from synapse.handlers.typing import TypingHandler, TypingSlaveHandler
from synapse.handlers.user_directory import UserDirectoryHandler
from synapse.http.client import InsecureInterceptableContextFactory, SimpleHttpClient
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.notifier import Notifier
from synapse.push.action_generator import ActionGenerator
from synapse.push.pusherpool import PusherPool
from synapse.replication.tcp.handler import (
ReplicationClientHandler,
ReplicationDataHandler,
)
from synapse.replication.tcp.resource import ReplicationStreamer
from synapse.rest.media.v1.media_repository import (
MediaRepository,
MediaRepositoryResource,
@@ -100,6 +105,7 @@ from synapse.storage import DataStores, Storage
from synapse.streams.events import EventSources
from synapse.util import Clock
from synapse.util.distributor import Distributor
from synapse.util.stringutils import random_string
logger = logging.getLogger(__name__)
@@ -199,6 +205,8 @@ class HomeServer(object):
"saml_handler",
"event_client_serializer",
"storage",
"replication_streamer",
"replication_data_handler",
]
REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"]
@@ -224,6 +232,8 @@ class HomeServer(object):
self._listening_services = []
self.start_time = None
self.instance_id = random_string(5)
self.clock = Clock(reactor)
self.distributor = Distributor()
self.ratelimiter = Ratelimiter()
@@ -236,6 +246,11 @@ class HomeServer(object):
for depname in kwargs:
setattr(self, depname, kwargs[depname])
def get_instance_id(self):
"""A unique ID for this synapse process instance.
"""
return self.instance_id
def setup(self):
logger.info("Setting up.")
self.start_time = int(self.get_clock().time())
@@ -339,7 +354,10 @@ class HomeServer(object):
return PresenceHandler(self)
def build_typing_handler(self):
return TypingHandler(self)
if self.config.handle_typing:
return TypingHandler(self)
else:
return TypingSlaveHandler(self)
def build_sync_handler(self):
return SyncHandler(self)
@@ -439,10 +457,8 @@ class HomeServer(object):
def build_federation_sender(self):
if self.should_send_federation():
return FederationSender(self)
elif not self.config.worker_app:
return FederationRemoteSendQueue(self)
else:
raise Exception("Workers cannot send federation traffic")
return FederationRemoteSendQueue(self)
def build_receipts_handler(self):
return ReceiptsHandler(self)
@@ -451,7 +467,7 @@ class HomeServer(object):
return ReadMarkerHandler(self)
def build_tcp_replication(self):
raise NotImplementedError()
return ReplicationClientHandler(self)
def build_action_generator(self):
return ActionGenerator(self)
@@ -536,6 +552,12 @@ class HomeServer(object):
def build_storage(self) -> Storage:
return Storage(self, self.datastores)
def build_replication_streamer(self) -> ReplicationStreamer:
return ReplicationStreamer(self)
def build_replication_data_handler(self):
return ReplicationDataHandler(self)
def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)

View File

@@ -106,7 +106,7 @@ class HomeServer(object):
pass
def get_tcp_replication(
self,
) -> synapse.replication.tcp.client.ReplicationClientHandler:
) -> synapse.replication.tcp.handler.ReplicationClientHandler:
pass
def get_federation_registry(
self,
@@ -114,3 +114,5 @@ class HomeServer(object):
pass
def is_mine_id(self, domain_id: str) -> bool:
pass
def get_instance_id(self) -> str:
pass

View File

@@ -32,7 +32,29 @@ logger = logging.getLogger(__name__)
CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
class CacheInvalidationStore(SQLBaseStore):
class CacheInvalidationWorkerStore(SQLBaseStore):
def get_all_updated_caches(self, last_id, current_id, limit):
if last_id == current_id:
return defer.succeed([])
def get_all_updated_caches_txn(txn):
# We purposefully don't bound by the current token, as we want to
# send across cache invalidations as quickly as possible. Cache
# invalidations are idempotent, so duplicates are fine.
sql = (
"SELECT stream_id, cache_func, keys, invalidation_ts"
" FROM cache_invalidation_stream"
" WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_id, limit))
return txn.fetchall()
return self.db.runInteraction(
"get_all_updated_caches", get_all_updated_caches_txn
)
class CacheInvalidationStore(CacheInvalidationWorkerStore):
async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.
@@ -145,26 +167,6 @@ class CacheInvalidationStore(SQLBaseStore):
},
)
def get_all_updated_caches(self, last_id, current_id, limit):
if last_id == current_id:
return defer.succeed([])
def get_all_updated_caches_txn(txn):
# We purposefully don't bound by the current token, as we want to
# send across cache invalidations as quickly as possible. Cache
# invalidations are idempotent, so duplicates are fine.
sql = (
"SELECT stream_id, cache_func, keys, invalidation_ts"
" FROM cache_invalidation_stream"
" WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_id, limit))
return txn.fetchall()
return self.db.runInteraction(
"get_all_updated_caches", get_all_updated_caches_txn
)
def get_cache_stream_token(self):
if self._cache_id_gen:
return self._cache_id_gen.get_current_token()

View File

@@ -207,6 +207,50 @@ class DeviceInboxWorkerStore(SQLBaseStore):
"delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
)
def get_all_new_device_messages(self, last_pos, current_pos, limit):
"""
Args:
last_pos(int):
current_pos(int):
limit(int):
Returns:
A deferred list of rows from the device inbox
"""
if last_pos == current_pos:
return defer.succeed([])
def get_all_new_device_messages_txn(txn):
# We limit like this as we might have multiple rows per stream_id, and
# we want to make sure we always get all entries for any stream_id
# we return.
upper_pos = min(current_pos, last_pos + limit)
sql = (
"SELECT max(stream_id), user_id"
" FROM device_inbox"
" WHERE ? < stream_id AND stream_id <= ?"
" GROUP BY user_id"
)
txn.execute(sql, (last_pos, upper_pos))
rows = txn.fetchall()
sql = (
"SELECT max(stream_id), destination"
" FROM device_federation_outbox"
" WHERE ? < stream_id AND stream_id <= ?"
" GROUP BY destination"
)
txn.execute(sql, (last_pos, upper_pos))
rows.extend(txn)
# Order by ascending stream ordering
rows.sort()
return rows
return self.db.runInteraction(
"get_all_new_device_messages", get_all_new_device_messages_txn
)
class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
@@ -411,47 +455,3 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
rows.append((user_id, device_id, stream_id, message_json))
txn.executemany(sql, rows)
def get_all_new_device_messages(self, last_pos, current_pos, limit):
"""
Args:
last_pos(int):
current_pos(int):
limit(int):
Returns:
A deferred list of rows from the device inbox
"""
if last_pos == current_pos:
return defer.succeed([])
def get_all_new_device_messages_txn(txn):
# We limit like this as we might have multiple rows per stream_id, and
# we want to make sure we always get all entries for any stream_id
# we return.
upper_pos = min(current_pos, last_pos + limit)
sql = (
"SELECT max(stream_id), user_id"
" FROM device_inbox"
" WHERE ? < stream_id AND stream_id <= ?"
" GROUP BY user_id"
)
txn.execute(sql, (last_pos, upper_pos))
rows = txn.fetchall()
sql = (
"SELECT max(stream_id), destination"
" FROM device_federation_outbox"
" WHERE ? < stream_id AND stream_id <= ?"
" GROUP BY destination"
)
txn.execute(sql, (last_pos, upper_pos))
rows.extend(txn)
# Order by ascending stream ordering
rows.sort()
return rows
return self.db.runInteraction(
"get_all_new_device_messages", get_all_new_device_messages_txn
)
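
The comment in `get_all_new_device_messages_txn` explains that the limit is applied to the stream position window rather than to the row count. A simplified sketch of that idea follows, using an in-memory SQLite table. The sample rows and the `new_device_messages` helper are illustrative assumptions, and the `device_federation_outbox` query is left out for brevity.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE device_inbox (stream_id INTEGER, user_id TEXT)")
conn.executemany(
    "INSERT INTO device_inbox VALUES (?, ?)",
    [(1, "@a:x"), (2, "@a:x"), (2, "@b:x"), (5, "@c:x")],
)


def new_device_messages(conn, last_pos, current_pos, limit):
    # Bound the query by stream position instead of using SQL LIMIT, so that
    # every row sharing a returned stream_id falls inside the window.
    upper_pos = min(current_pos, last_pos + limit)
    cur = conn.execute(
        "SELECT max(stream_id), user_id"
        " FROM device_inbox"
        " WHERE ? < stream_id AND stream_id <= ?"
        " GROUP BY user_id",
        (last_pos, upper_pos),
    )
    rows = cur.fetchall()
    # Order by ascending stream ordering
    rows.sort()
    return rows


rows = new_device_messages(conn, last_pos=0, current_pos=5, limit=2)
```

With `limit=2` the window is `(0, 2]`, so both users with a message at `stream_id` 2 are returned together, while the row at `stream_id` 5 is left for a later call.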

View File

@@ -1267,104 +1267,6 @@ class EventsStore(
ret = yield self.db.runInteraction("count_daily_active_rooms", _count)
return ret
def get_current_backfill_token(self):
"""The current minimum token that backfilled events have reached"""
return -self._backfill_id_gen.get_current_token()
def get_current_events_token(self):
"""The current maximum token that events have reached"""
return self._stream_id_gen.get_current_token()
def get_all_new_forward_event_rows(self, last_id, current_id, limit):
if last_id == current_id:
return defer.succeed([])
def get_all_new_forward_event_rows(txn):
sql = (
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? < stream_ordering AND stream_ordering <= ?"
" ORDER BY stream_ordering ASC"
" LIMIT ?"
)
txn.execute(sql, (last_id, current_id, limit))
new_event_updates = txn.fetchall()
if len(new_event_updates) == limit:
upper_bound = new_event_updates[-1][0]
else:
upper_bound = current_id
sql = (
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" FROM events AS e"
" INNER JOIN ex_outlier_stream USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? < event_stream_ordering"
" AND event_stream_ordering <= ?"
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (last_id, upper_bound))
new_event_updates.extend(txn)
return new_event_updates
return self.db.runInteraction(
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
)
def get_all_new_backfill_event_rows(self, last_id, current_id, limit):
if last_id == current_id:
return defer.succeed([])
def get_all_new_backfill_event_rows(txn):
sql = (
"SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? > stream_ordering AND stream_ordering >= ?"
" ORDER BY stream_ordering ASC"
" LIMIT ?"
)
txn.execute(sql, (-last_id, -current_id, limit))
new_event_updates = txn.fetchall()
if len(new_event_updates) == limit:
upper_bound = new_event_updates[-1][0]
else:
upper_bound = current_id
sql = (
"SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" FROM events AS e"
" INNER JOIN ex_outlier_stream USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? > event_stream_ordering"
" AND event_stream_ordering >= ?"
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (-last_id, -upper_bound))
new_event_updates.extend(txn.fetchall())
return new_event_updates
return self.db.runInteraction(
"get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
)
@cached(num_args=5, max_entries=10)
def get_all_new_events(
self,
@@ -1850,22 +1752,6 @@ class EventsStore(
return (int(res["topological_ordering"]), int(res["stream_ordering"]))
def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
def get_all_updated_current_state_deltas_txn(txn):
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC LIMIT ?
"""
txn.execute(sql, (from_token, to_token, limit))
return txn.fetchall()
return self.db.runInteraction(
"get_all_updated_current_state_deltas",
get_all_updated_current_state_deltas_txn,
)
def insert_labels_for_event_txn(
self, txn, event_id, labels, room_id, topological_ordering
):

View File

@@ -963,3 +963,117 @@ class EventsWorkerStore(SQLBaseStore):
complexity_v1 = round(state_events / 500, 2)
return {"v1": complexity_v1}
def get_current_backfill_token(self):
"""The current minimum token that backfilled events have reached"""
return -self._backfill_id_gen.get_current_token()
def get_current_events_token(self):
"""The current maximum token that events have reached"""
return self._stream_id_gen.get_current_token()
def get_all_new_forward_event_rows(self, last_id, current_id, limit):
if last_id == current_id:
return defer.succeed([])
def get_all_new_forward_event_rows(txn):
sql = (
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? < stream_ordering AND stream_ordering <= ?"
" ORDER BY stream_ordering ASC"
" LIMIT ?"
)
txn.execute(sql, (last_id, current_id, limit))
new_event_updates = txn.fetchall()
if len(new_event_updates) == limit:
upper_bound = new_event_updates[-1][0]
else:
upper_bound = current_id
sql = (
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" FROM events AS e"
" INNER JOIN ex_outlier_stream USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? < event_stream_ordering"
" AND event_stream_ordering <= ?"
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (last_id, upper_bound))
new_event_updates.extend(txn)
return new_event_updates
return self.db.runInteraction(
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
)
def get_all_new_backfill_event_rows(self, last_id, current_id, limit):
if last_id == current_id:
return defer.succeed([])
def get_all_new_backfill_event_rows(txn):
sql = (
"SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? > stream_ordering AND stream_ordering >= ?"
" ORDER BY stream_ordering ASC"
" LIMIT ?"
)
txn.execute(sql, (-last_id, -current_id, limit))
new_event_updates = txn.fetchall()
if len(new_event_updates) == limit:
upper_bound = new_event_updates[-1][0]
else:
upper_bound = current_id
sql = (
"SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
" FROM events AS e"
" INNER JOIN ex_outlier_stream USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? > event_stream_ordering"
" AND event_stream_ordering >= ?"
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (-last_id, -upper_bound))
new_event_updates.extend(txn.fetchall())
return new_event_updates
return self.db.runInteraction(
"get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
)
def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
def get_all_updated_current_state_deltas_txn(txn):
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC LIMIT ?
"""
txn.execute(sql, (from_token, to_token, limit))
return txn.fetchall()
return self.db.runInteraction(
"get_all_updated_current_state_deltas",
get_all_updated_current_state_deltas_txn,
)

View File

@@ -732,6 +732,26 @@ class RoomWorkerStore(SQLBaseStore):
return total_media_quarantined
def get_all_new_public_rooms(self, prev_id, current_id, limit):
def get_all_new_public_rooms(txn):
sql = """
SELECT stream_id, room_id, visibility, appservice_id, network_id
FROM public_room_list_stream
WHERE stream_id > ? AND stream_id <= ?
ORDER BY stream_id ASC
LIMIT ?
"""
txn.execute(sql, (prev_id, current_id, limit))
return txn.fetchall()
if prev_id == current_id:
return defer.succeed([])
return self.db.runInteraction(
"get_all_new_public_rooms", get_all_new_public_rooms
)
class RoomBackgroundUpdateStore(SQLBaseStore):
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
@@ -1249,26 +1269,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
def get_current_public_room_stream_id(self):
return self._public_room_id_gen.get_current_token()
def get_all_new_public_rooms(self, prev_id, current_id, limit):
def get_all_new_public_rooms(txn):
sql = """
SELECT stream_id, room_id, visibility, appservice_id, network_id
FROM public_room_list_stream
WHERE stream_id > ? AND stream_id <= ?
ORDER BY stream_id ASC
LIMIT ?
"""
txn.execute(sql, (prev_id, current_id, limit))
return txn.fetchall()
if prev_id == current_id:
return defer.succeed([])
return self.db.runInteraction(
"get_all_new_public_rooms", get_all_new_public_rooms
)
@defer.inlineCallbacks
def block_room(self, room_id, user_id):
"""Marks the room as blocked. Can be called multiple times.

View File

@@ -15,9 +15,10 @@
from mock import Mock, NonCallableMock
from synapse.replication.tcp.client import (
ReplicationClientFactory,
from synapse.replication.tcp.client import ReplicationClientFactory
from synapse.replication.tcp.handler import (
ReplicationClientHandler,
WorkerReplicationDataHandler,
)
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.storage.database import make_conn
@@ -51,16 +52,19 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
self.event_id = 0
server_factory = ReplicationStreamProtocolFactory(self.hs)
self.streamer = server_factory.streamer
self.streamer = hs.get_replication_streamer()
handler_factory = Mock()
self.replication_handler = ReplicationClientHandler(self.slaved_store)
self.replication_handler.factory = handler_factory
client_factory = ReplicationClientFactory(
self.hs, "client_name", self.replication_handler
# We now do some gut wrenching so that we have a client that is based
# off of the slave store rather than the main store.
self.replication_handler = ReplicationClientHandler(self.hs)
self.replication_handler.store = self.slaved_store
self.replication_handler.replication_data_handler = WorkerReplicationDataHandler(
self.slaved_store
)
client_factory = ReplicationClientFactory(self.hs, "client_name")
client_factory.handler = self.replication_handler
server = server_factory.buildProtocol(None)
client = client_factory.buildProtocol(None)

View File

@@ -12,9 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mock import Mock
from synapse.replication.tcp.commands import ReplicateCommand
from synapse.replication.tcp.handler import ReplicationClientHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -25,23 +26,46 @@ from tests.server import FakeTransport
class BaseStreamTestCase(unittest.HomeserverTestCase):
"""Base class for tests of the replication streams"""
def make_homeserver(self, reactor, clock):
self.test_handler = Mock(wraps=TestReplicationClientHandler())
return self.setup_test_homeserver(replication_data_handler=self.test_handler)
def prepare(self, reactor, clock, hs):
# build a replication server
server_factory = ReplicationStreamProtocolFactory(self.hs)
self.streamer = server_factory.streamer
server = server_factory.buildProtocol(None)
server_factory = ReplicationStreamProtocolFactory(hs)
self.streamer = hs.get_replication_streamer()
self.server = server_factory.buildProtocol(None)
# build a replication client, with a dummy handler
handler_factory = Mock()
self.test_handler = TestReplicationClientHandler()
self.test_handler.factory = handler_factory
repl_handler = ReplicationClientHandler(hs)
repl_handler.handler = self.test_handler
self.client = ClientReplicationStreamProtocol(
"client", "test", clock, self.test_handler
hs, "client", "test", clock, repl_handler,
)
# wire them together
self.client.makeConnection(FakeTransport(server, reactor))
server.makeConnection(FakeTransport(self.client, reactor))
self._client_transport = None
self._server_transport = None
def reconnect(self):
if self._client_transport:
self.client.close()
if self._server_transport:
self.server.close()
self._client_transport = FakeTransport(self.server, self.reactor)
self.client.makeConnection(self._client_transport)
self._server_transport = FakeTransport(self.client, self.reactor)
self.server.makeConnection(self._server_transport)
def disconnect(self):
if self._client_transport:
self._client_transport = None
self.client.close()
if self._server_transport:
self._server_transport = None
self.server.close()
def replicate(self):
"""Tell the master side of replication that something has happened, and then
@@ -50,29 +74,22 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.streamer.on_notifier_poke()
self.pump(0.1)
def replicate_stream(self, stream, token="NOW"):
"""Make the client send a REPLICATE command to set up a subscription to a stream"""
self.client.send_command(ReplicateCommand(stream, token))
class TestReplicationClientHandler(object):
"""Drop-in for ReplicationClientHandler which just collects RDATA rows"""
class TestReplicationClientHandler:
def __init__(self):
self.received_rdata_rows = []
self.streams = set()
self._received_rdata_rows = []
def get_streams_to_replicate(self):
return {}
def get_currently_syncing_users(self):
return []
def update_connection(self, connection):
pass
def finished_connecting(self):
pass
positions = {s: 0 for s in self.streams}
for stream, token, _ in self._received_rdata_rows:
if stream in self.streams:
positions[stream] = max(token, positions.get(stream, 0))
return positions
async def on_rdata(self, stream_name, token, rows):
for r in rows:
self.received_rdata_rows.append((stream_name, token, r))
self._received_rdata_rows.append((stream_name, token, r))
async def on_position(self, stream_name, token):
pass
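
The position bookkeeping in `TestReplicationClientHandler` appears to derive, for each subscribed stream, the highest token seen so far from the collected RDATA rows. Since the diff markers are stripped here, which method it belongs to is an assumption. A standalone sketch of that calculation, with the hypothetical helper name `streams_to_replicate` and made-up sample rows:

```python
def streams_to_replicate(streams, received_rdata_rows):
    # Start every subscribed stream at position 0.
    positions = {s: 0 for s in streams}
    for stream, token, _ in received_rdata_rows:
        # Rows for streams we are not subscribed to are ignored.
        if stream in streams:
            positions[stream] = max(token, positions.get(stream, 0))
    return positions


sample_rows = [
    ("receipts", 2, "r1"),
    ("receipts", 3, "r2"),
    ("typing", 7, "t1"),
]
positions = streams_to_replicate({"receipts", "caches"}, sample_rows)
```

After a reconnect, positions like these would let the client resume each stream from the last token it received, which is the catch-up behaviour the receipts test exercises.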

View File

@@ -17,30 +17,63 @@ from synapse.replication.tcp.streams._base import ReceiptsStream
from tests.replication.tcp.streams._base import BaseStreamTestCase
USER_ID = "@feeling:blue"
ROOM_ID = "!room:blue"
EVENT_ID = "$event:blue"
class ReceiptsStreamTestCase(BaseStreamTestCase):
def test_receipt(self):
self.reconnect()
# make the client subscribe to the receipts stream
self.replicate_stream("receipts", "NOW")
self.test_handler.streams.add("receipts")
# tell the master to send a new receipt
self.get_success(
self.hs.get_datastore().insert_receipt(
ROOM_ID, "m.read", USER_ID, [EVENT_ID], {"a": 1}
"!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1}
)
)
self.replicate()
# there should be one RDATA command
rdata_rows = self.test_handler.received_rdata_rows
self.test_handler.on_rdata.assert_called_once()
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "receipts")
self.assertEqual(1, len(rdata_rows))
self.assertEqual(rdata_rows[0][0], "receipts")
row = rdata_rows[0][2] # type: ReceiptsStream.ReceiptsStreamRow
self.assertEqual(ROOM_ID, row.room_id)
row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow
self.assertEqual("!room:blue", row.room_id)
self.assertEqual("m.read", row.receipt_type)
self.assertEqual(USER_ID, row.user_id)
self.assertEqual(EVENT_ID, row.event_id)
self.assertEqual("$event:blue", row.event_id)
self.assertEqual({"a": 1}, row.data)
# Now let's disconnect and insert some data.
self.disconnect()
self.test_handler.on_rdata.reset_mock()
self.get_success(
self.hs.get_datastore().insert_receipt(
"!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2}
)
)
self.replicate()
# Nothing should have happened as we are disconnected
self.test_handler.on_rdata.assert_not_called()
self.reconnect()
self.pump(0.1)
# We should now have caught up and get the missing data
self.test_handler.on_rdata.assert_called_once()
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "receipts")
self.assertEqual(token, 3)
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow
self.assertEqual("!room2:blue", row.room_id)
self.assertEqual("m.read", row.receipt_type)
self.assertEqual(USER_ID, row.user_id)
self.assertEqual("$event2:foo", row.event_id)
self.assertEqual({"a": 2}, row.data)