mirror of
https://github.com/element-hq/synapse.git
synced 2025-12-09 01:30:18 +00:00
Compare commits
36 Commits
develop
...
erikj/spli
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
83ecaeecbf | ||
|
|
0473f87a17 | ||
|
|
092b62ee7b | ||
|
|
b6f6f5c399 | ||
|
|
f7da931d62 | ||
|
|
9f15bffd72 | ||
|
|
6da24f2d5f | ||
|
|
5473f1806a | ||
|
|
f6e7daaac3 | ||
|
|
309c7eb1a1 | ||
|
|
f8038f4670 | ||
|
|
9ea391054f | ||
|
|
604f57f1bd | ||
|
|
bd64b8fcd5 | ||
|
|
309aee4636 | ||
|
|
e4c5b1d9d6 | ||
|
|
7eec84bfbe | ||
|
|
4dd08f2501 | ||
|
|
55dfcd2f09 | ||
|
|
11fb08ffa9 | ||
|
|
ef4f063687 | ||
|
|
2380e401e4 | ||
|
|
5d810c36a8 | ||
|
|
ea17e939df | ||
|
|
225b993cf6 | ||
|
|
3204b0e79f | ||
|
|
ba1a8be930 | ||
|
|
a2070a2c4e | ||
|
|
4f2a803c66 | ||
|
|
259cdffa96 | ||
|
|
32c656865a | ||
|
|
8734b75ca8 | ||
|
|
1f83255de1 | ||
|
|
ba90596687 | ||
|
|
811d2ecf2e | ||
|
|
7233d38690 |
1
changelog.d/7024.misc
Normal file
1
changelog.d/7024.misc
Normal file
@@ -0,0 +1 @@
|
||||
Move catchup of replication streams logic to worker.
|
||||
@@ -14,16 +14,16 @@ example flow would be (where '>' indicates master to worker and
|
||||
'<' worker to master flows):
|
||||
|
||||
> SERVER example.com
|
||||
< REPLICATE events 53
|
||||
< REPLICATE
|
||||
> POSITION events 53
|
||||
> RDATA events 54 ["$foo1:bar.com", ...]
|
||||
> RDATA events 55 ["$foo4:bar.com", ...]
|
||||
|
||||
The example shows the server accepting a new connection and sending its
|
||||
identity with the `SERVER` command, followed by the client asking to
|
||||
subscribe to the `events` stream from the token `53`. The server then
|
||||
periodically sends `RDATA` commands which have the format
|
||||
`RDATA <stream_name> <token> <row>`, where the format of `<row>` is
|
||||
defined by the individual streams.
|
||||
The example shows the server accepting a new connection and sending its identity
|
||||
with the `SERVER` command, followed by the client server to respond with the
|
||||
position of all streams. The server then periodically sends `RDATA` commands
|
||||
which have the format `RDATA <stream_name> <token> <row>`, where the format of
|
||||
`<row>` is defined by the individual streams.
|
||||
|
||||
Error reporting happens by either the client or server sending an ERROR
|
||||
command, and usually the connection will be closed.
|
||||
@@ -32,9 +32,6 @@ Since the protocol is a simple line based, its possible to manually
|
||||
connect to the server using a tool like netcat. A few things should be
|
||||
noted when manually using the protocol:
|
||||
|
||||
- When subscribing to a stream using `REPLICATE`, the special token
|
||||
`NOW` can be used to get all future updates. The special stream name
|
||||
`ALL` can be used with `NOW` to subscribe to all available streams.
|
||||
- The federation stream is only available if federation sending has
|
||||
been disabled on the main process.
|
||||
- The server will only time connections out that have sent a `PING`
|
||||
@@ -91,9 +88,7 @@ The client:
|
||||
- Sends a `NAME` command, allowing the server to associate a human
|
||||
friendly name with the connection. This is optional.
|
||||
- Sends a `PING` as above
|
||||
- For each stream the client wishes to subscribe to it sends a
|
||||
`REPLICATE` with the `stream_name` and token it wants to subscribe
|
||||
from.
|
||||
- Sends a `REPLICATE` to get the current position of all streams.
|
||||
- On receipt of a `SERVER` command, checks that the server name
|
||||
matches the expected server name.
|
||||
|
||||
@@ -140,9 +135,7 @@ the wire:
|
||||
> PING 1490197665618
|
||||
< NAME synapse.app.appservice
|
||||
< PING 1490197665618
|
||||
< REPLICATE events 1
|
||||
< REPLICATE backfill 1
|
||||
< REPLICATE caches 1
|
||||
< REPLICATE
|
||||
> POSITION events 1
|
||||
> POSITION backfill 1
|
||||
> POSITION caches 1
|
||||
@@ -181,9 +174,9 @@ client (C):
|
||||
|
||||
#### POSITION (S)
|
||||
|
||||
The position of the stream has been updated. Sent to the client
|
||||
after all missing updates for a stream have been sent to the client
|
||||
and they're now up to date.
|
||||
On receipt of a POSITION command clients should check if they have missed any
|
||||
updates, and if so then fetch them out of band. Sent in response to a
|
||||
REPLICATE command (but can happen at any time).
|
||||
|
||||
#### ERROR (S, C)
|
||||
|
||||
@@ -199,25 +192,16 @@ client (C):
|
||||
|
||||
#### REPLICATE (C)
|
||||
|
||||
Asks the server to replicate a given stream. The syntax is:
|
||||
|
||||
```
|
||||
REPLICATE <stream_name> <token>
|
||||
```
|
||||
|
||||
Where `<token>` may be either:
|
||||
* a numeric stream_id to stream updates since (exclusive)
|
||||
* `NOW` to stream all subsequent updates.
|
||||
|
||||
The `<stream_name>` is the name of a replication stream to subscribe
|
||||
to (see [here](../synapse/replication/tcp/streams/_base.py) for a list
|
||||
of streams). It can also be `ALL` to subscribe to all known streams,
|
||||
in which case the `<token>` must be set to `NOW`.
|
||||
Asks the server for the current position of all streams.
|
||||
|
||||
#### USER_SYNC (C)
|
||||
|
||||
A user has started or stopped syncing
|
||||
|
||||
#### CLEAR_USER_SYNC (C)
|
||||
|
||||
The server should clear all associated user sync data from the worker.
|
||||
|
||||
#### FEDERATION_ACK (C)
|
||||
|
||||
Acknowledge receipt of some federation data
|
||||
|
||||
3
mypy.ini
3
mypy.ini
@@ -75,3 +75,6 @@ ignore_missing_imports = True
|
||||
|
||||
[mypy-jwt.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-txredisapi]
|
||||
ignore_missing_imports = True
|
||||
|
||||
@@ -45,6 +45,7 @@ from synapse.http.site import SynapseSite
|
||||
from synapse.logging.context import LoggingContext, run_in_background
|
||||
from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
|
||||
from synapse.replication.slave.storage._base import BaseSlavedStore, __func__
|
||||
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
|
||||
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
|
||||
@@ -64,7 +65,9 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
||||
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||
from synapse.replication.slave.storage.room import RoomStore
|
||||
from synapse.replication.slave.storage.transactions import SlavedTransactionStore
|
||||
from synapse.replication.tcp.client import ReplicationClientHandler
|
||||
from synapse.replication.tcp.client import ReplicationClientFactory
|
||||
from synapse.replication.tcp.commands import ClearUserSyncsCommand
|
||||
from synapse.replication.tcp.handler import ReplicationDataHandler
|
||||
from synapse.replication.tcp.streams import (
|
||||
AccountDataStream,
|
||||
DeviceListsStream,
|
||||
@@ -75,7 +78,6 @@ from synapse.replication.tcp.streams import (
|
||||
ReceiptsStream,
|
||||
TagAccountDataStream,
|
||||
ToDeviceStream,
|
||||
TypingStream,
|
||||
)
|
||||
from synapse.replication.tcp.streams.events import (
|
||||
EventsStream,
|
||||
@@ -104,6 +106,7 @@ from synapse.rest.client.v1.room import (
|
||||
RoomSendEventRestServlet,
|
||||
RoomStateEventRestServlet,
|
||||
RoomStateRestServlet,
|
||||
RoomTypingRestServlet,
|
||||
)
|
||||
from synapse.rest.client.v1.voip import VoipRestServlet
|
||||
from synapse.rest.client.v2_alpha import groups, sync, user_directory
|
||||
@@ -124,7 +127,6 @@ from synapse.types import ReadReceipt
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.httpresourcetree import create_resource_tree
|
||||
from synapse.util.manhole import manhole
|
||||
from synapse.util.stringutils import random_string
|
||||
from synapse.util.versionstring import get_version_string
|
||||
|
||||
logger = logging.getLogger("synapse.app.generic_worker")
|
||||
@@ -233,6 +235,7 @@ class GenericWorkerPresence(object):
|
||||
self.user_to_num_current_syncs = {}
|
||||
self.clock = hs.get_clock()
|
||||
self.notifier = hs.get_notifier()
|
||||
self.instance_id = hs.get_instance_id()
|
||||
|
||||
active_presence = self.store.take_presence_startup_info()
|
||||
self.user_to_current_state = {state.user_id: state for state in active_presence}
|
||||
@@ -245,13 +248,24 @@ class GenericWorkerPresence(object):
|
||||
self.send_stop_syncing, UPDATE_SYNCING_USERS_MS
|
||||
)
|
||||
|
||||
self.process_id = random_string(16)
|
||||
logger.info("Presence process_id is %r", self.process_id)
|
||||
hs.get_reactor().addSystemEventTrigger(
|
||||
"before",
|
||||
"shutdown",
|
||||
run_as_background_process,
|
||||
"generic_presence.on_shutdown",
|
||||
self._on_shutdown,
|
||||
)
|
||||
|
||||
def _on_shutdown(self):
|
||||
if self.hs.config.use_presence:
|
||||
self.hs.get_tcp_replication().send_command(
|
||||
ClearUserSyncsCommand(self.instance_id)
|
||||
)
|
||||
|
||||
def send_user_sync(self, user_id, is_syncing, last_sync_ms):
|
||||
if self.hs.config.use_presence:
|
||||
self.hs.get_tcp_replication().send_user_sync(
|
||||
user_id, is_syncing, last_sync_ms
|
||||
self.instance_id, user_id, is_syncing, last_sync_ms
|
||||
)
|
||||
|
||||
def mark_as_coming_online(self, user_id):
|
||||
@@ -368,40 +382,6 @@ class GenericWorkerPresence(object):
|
||||
return set()
|
||||
|
||||
|
||||
class GenericWorkerTyping(object):
|
||||
def __init__(self, hs):
|
||||
self._latest_room_serial = 0
|
||||
self._reset()
|
||||
|
||||
def _reset(self):
|
||||
"""
|
||||
Reset the typing handler's data caches.
|
||||
"""
|
||||
# map room IDs to serial numbers
|
||||
self._room_serials = {}
|
||||
# map room IDs to sets of users currently typing
|
||||
self._room_typing = {}
|
||||
|
||||
def stream_positions(self):
|
||||
# We must update this typing token from the response of the previous
|
||||
# sync. In particular, the stream id may "reset" back to zero/a low
|
||||
# value which we *must* use for the next replication request.
|
||||
return {"typing": self._latest_room_serial}
|
||||
|
||||
def process_replication_rows(self, token, rows):
|
||||
if self._latest_room_serial > token:
|
||||
# The master has gone backwards. To prevent inconsistent data, just
|
||||
# clear everything.
|
||||
self._reset()
|
||||
|
||||
# Set the latest serial token to whatever the server gave us.
|
||||
self._latest_room_serial = token
|
||||
|
||||
for row in rows:
|
||||
self._room_serials[row.room_id] = token
|
||||
self._room_typing[row.room_id] = row.user_ids
|
||||
|
||||
|
||||
class GenericWorkerSlavedStore(
|
||||
# FIXME(#3714): We need to add UserDirectoryStore as we write directly
|
||||
# rather than going via the correct worker.
|
||||
@@ -486,6 +466,7 @@ class GenericWorkerServer(HomeServer):
|
||||
ProfileDisplaynameRestServlet(self).register(resource)
|
||||
ProfileRestServlet(self).register(resource)
|
||||
KeyUploadServlet(self).register(resource)
|
||||
RoomTypingRestServlet(self).register(resource)
|
||||
|
||||
sync.register_servlets(self, resource)
|
||||
events.register_servlets(self, resource)
|
||||
@@ -541,6 +522,9 @@ class GenericWorkerServer(HomeServer):
|
||||
if name in ["keys", "federation"]:
|
||||
resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self)
|
||||
|
||||
if name == "replication":
|
||||
resources[REPLICATION_PREFIX] = ReplicationRestResource(self)
|
||||
|
||||
root_resource = create_resource_tree(resources, NoResource())
|
||||
|
||||
_base.listen_tcp(
|
||||
@@ -583,27 +567,35 @@ class GenericWorkerServer(HomeServer):
|
||||
else:
|
||||
logger.warning("Unrecognized listener type: %s", listener["type"])
|
||||
|
||||
self.get_tcp_replication().start_replication(self)
|
||||
if self.config.redis.redis_enabled:
|
||||
from synapse.replication.tcp.redis import RedisFactory
|
||||
|
||||
logger.info("Connecting to redis.")
|
||||
factory = RedisFactory(self)
|
||||
self.get_reactor().connectTCP(
|
||||
self.config.redis.redis_host, self.config.redis.redis_port, factory
|
||||
)
|
||||
else:
|
||||
factory = ReplicationClientFactory(self, self.config.worker_name)
|
||||
host = self.config.worker_replication_host
|
||||
port = self.config.worker_replication_port
|
||||
self.get_reactor().connectTCP(host, port, factory)
|
||||
|
||||
def remove_pusher(self, app_id, push_key, user_id):
|
||||
self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id)
|
||||
|
||||
def build_tcp_replication(self):
|
||||
return GenericWorkerReplicationHandler(self)
|
||||
|
||||
def build_presence_handler(self):
|
||||
return GenericWorkerPresence(self)
|
||||
|
||||
def build_typing_handler(self):
|
||||
return GenericWorkerTyping(self)
|
||||
def build_replication_data_handler(self):
|
||||
return GenericWorkerReplicationHandler(self)
|
||||
|
||||
|
||||
class GenericWorkerReplicationHandler(ReplicationClientHandler):
|
||||
class GenericWorkerReplicationHandler(ReplicationDataHandler):
|
||||
def __init__(self, hs):
|
||||
super(GenericWorkerReplicationHandler, self).__init__(hs.get_datastore())
|
||||
super().__init__(hs)
|
||||
|
||||
self.store = hs.get_datastore()
|
||||
self.typing_handler = hs.get_typing_handler()
|
||||
# NB this is a SynchrotronPresence, not a normal PresenceHandler
|
||||
self.presence_handler = hs.get_presence_handler()
|
||||
self.notifier = hs.get_notifier()
|
||||
@@ -612,30 +604,29 @@ class GenericWorkerReplicationHandler(ReplicationClientHandler):
|
||||
self.pusher_pool = hs.get_pusherpool()
|
||||
|
||||
if hs.config.send_federation:
|
||||
self.send_handler = FederationSenderHandler(hs, self)
|
||||
self.send_handler = FederationSenderHandler(hs)
|
||||
else:
|
||||
self.send_handler = None
|
||||
|
||||
async def on_rdata(self, stream_name, token, rows):
|
||||
await super(GenericWorkerReplicationHandler, self).on_rdata(
|
||||
stream_name, token, rows
|
||||
async def on_rdata(self, stream_name, instance_name, token, rows):
|
||||
await super().on_rdata(stream_name, instance_name, token, rows)
|
||||
run_in_background(
|
||||
self.process_and_notify, stream_name, instance_name, token, rows
|
||||
)
|
||||
run_in_background(self.process_and_notify, stream_name, token, rows)
|
||||
|
||||
def get_streams_to_replicate(self):
|
||||
args = super(GenericWorkerReplicationHandler, self).get_streams_to_replicate()
|
||||
args.update(self.typing_handler.stream_positions())
|
||||
args = super().get_streams_to_replicate()
|
||||
|
||||
if self.send_handler:
|
||||
args.update(self.send_handler.stream_positions())
|
||||
return args
|
||||
|
||||
def get_currently_syncing_users(self):
|
||||
return self.presence_handler.get_currently_syncing_users()
|
||||
|
||||
async def process_and_notify(self, stream_name, token, rows):
|
||||
async def process_and_notify(self, stream_name, instance_name, token, rows):
|
||||
try:
|
||||
if self.send_handler:
|
||||
self.send_handler.process_replication_rows(stream_name, token, rows)
|
||||
self.send_handler.process_replication_rows(
|
||||
stream_name, instance_name, token, rows
|
||||
)
|
||||
|
||||
if stream_name == EventsStream.NAME:
|
||||
# We shouldn't get multiple rows per token for events stream, so
|
||||
@@ -675,11 +666,6 @@ class GenericWorkerReplicationHandler(ReplicationClientHandler):
|
||||
await self.pusher_pool.on_new_receipts(
|
||||
token, token, {row.room_id for row in rows}
|
||||
)
|
||||
elif stream_name == TypingStream.NAME:
|
||||
self.typing_handler.process_replication_rows(token, rows)
|
||||
self.notifier.on_new_event(
|
||||
"typing_key", token, rooms=[row.room_id for row in rows]
|
||||
)
|
||||
elif stream_name == ToDeviceStream.NAME:
|
||||
entities = [row.entity for row in rows if row.entity.startswith("@")]
|
||||
if entities:
|
||||
@@ -740,13 +726,14 @@ class FederationSenderHandler(object):
|
||||
to the federation sender.
|
||||
"""
|
||||
|
||||
def __init__(self, hs: GenericWorkerServer, replication_client):
|
||||
def __init__(self, hs: GenericWorkerServer):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
self._is_mine_id = hs.is_mine_id
|
||||
self.federation_sender = hs.get_federation_sender()
|
||||
self.replication_client = replication_client
|
||||
# self.replication_client = hs.get_tcp_replication()
|
||||
|
||||
self.federation_position = self.store.federation_out_pos_startup
|
||||
self.federation_position = {"master": self.store.federation_out_pos_startup}
|
||||
self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
|
||||
|
||||
self._last_ack = self.federation_position
|
||||
@@ -767,12 +754,12 @@ class FederationSenderHandler(object):
|
||||
def stream_positions(self):
|
||||
return {"federation": self.federation_position}
|
||||
|
||||
def process_replication_rows(self, stream_name, token, rows):
|
||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||
# The federation stream contains things that we want to send out, e.g.
|
||||
# presence, typing, etc.
|
||||
if stream_name == "federation":
|
||||
send_queue.process_rows_for_federation(self.federation_sender, rows)
|
||||
run_in_background(self.update_token, token)
|
||||
run_in_background(self.update_token, instance_name, token)
|
||||
|
||||
# We also need to poke the federation sender when new events happen
|
||||
elif stream_name == "events":
|
||||
@@ -820,9 +807,12 @@ class FederationSenderHandler(object):
|
||||
)
|
||||
await self.federation_sender.send_read_receipt(receipt_info)
|
||||
|
||||
async def update_token(self, token):
|
||||
async def update_token(self, instance_name, token):
|
||||
try:
|
||||
self.federation_position = token
|
||||
self.federation_position[instance_name] = token
|
||||
return
|
||||
|
||||
# FIXME
|
||||
|
||||
# We linearize here to ensure we don't have races updating the token
|
||||
with (await self._fed_position_linearizer.queue(None)):
|
||||
@@ -833,7 +823,7 @@ class FederationSenderHandler(object):
|
||||
|
||||
# We ACK this token over replication so that the master can drop
|
||||
# its in memory queues
|
||||
self.replication_client.send_federation_ack(
|
||||
self.hs.get_tcp_replication().send_federation_ack(
|
||||
self.federation_position
|
||||
)
|
||||
self._last_ack = self.federation_position
|
||||
@@ -915,6 +905,10 @@ def start(config_options):
|
||||
# Force the pushers to start since they will be disabled in the main config
|
||||
config.send_federation = True
|
||||
|
||||
config.server.handle_typing = False
|
||||
if config.worker_app == "synapse.app.client_reader":
|
||||
config.server.handle_typing = True
|
||||
|
||||
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||
|
||||
ss = GenericWorkerServer(
|
||||
@@ -930,6 +924,8 @@ def start(config_options):
|
||||
"before", "startup", _base.start, ss, config.worker_listeners
|
||||
)
|
||||
|
||||
ss.get_replication_streamer()
|
||||
|
||||
_base.start_worker_reactor("synapse-generic-worker", config)
|
||||
|
||||
|
||||
|
||||
@@ -263,6 +263,15 @@ class SynapseHomeServer(HomeServer):
|
||||
def start_listening(self, listeners):
|
||||
config = self.get_config()
|
||||
|
||||
if config.redis_enabled:
|
||||
from synapse.replication.tcp.redis import RedisFactory
|
||||
|
||||
logger.info("Connecting to redis.")
|
||||
factory = RedisFactory(self)
|
||||
self.get_reactor().connectTCP(
|
||||
self.config.redis.redis_host, self.config.redis.redis_port, factory
|
||||
)
|
||||
|
||||
for listener in listeners:
|
||||
if listener["type"] == "http":
|
||||
self._listening_services.extend(self._listener_http(config, listener))
|
||||
@@ -282,6 +291,7 @@ class SynapseHomeServer(HomeServer):
|
||||
)
|
||||
for s in services:
|
||||
reactor.addSystemEventTrigger("before", "shutdown", s.stopListening)
|
||||
|
||||
elif listener["type"] == "metrics":
|
||||
if not self.get_config().enable_metrics:
|
||||
logger.warning(
|
||||
|
||||
@@ -31,6 +31,7 @@ from .password import PasswordConfig
|
||||
from .password_auth_providers import PasswordAuthProviderConfig
|
||||
from .push import PushConfig
|
||||
from .ratelimiting import RatelimitConfig
|
||||
from .redis import RedisConfig
|
||||
from .registration import RegistrationConfig
|
||||
from .repository import ContentRepositoryConfig
|
||||
from .room_directory import RoomDirectoryConfig
|
||||
@@ -82,4 +83,5 @@ class HomeServerConfig(RootConfig):
|
||||
RoomDirectoryConfig,
|
||||
ThirdPartyRulesConfig,
|
||||
TracerConfig,
|
||||
RedisConfig,
|
||||
]
|
||||
|
||||
47
synapse/config/redis.py
Normal file
47
synapse/config/redis.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.config._base import Config, ConfigError
|
||||
|
||||
try:
|
||||
import txredisapi
|
||||
except ImportError:
|
||||
txredisapi = None
|
||||
|
||||
|
||||
MISSING_REDIS = """Missing 'txredisapi' library. This is required for redis support.
|
||||
|
||||
Install by running:
|
||||
pip install txredisapi
|
||||
"""
|
||||
|
||||
|
||||
class RedisConfig(Config):
|
||||
section = "redis"
|
||||
|
||||
def read_config(self, config, **kwargs):
|
||||
redis_config = config.get("redis", {})
|
||||
self.redis_enabled = redis_config.get("enabled", False)
|
||||
|
||||
if not self.redis_enabled:
|
||||
return
|
||||
|
||||
if txredisapi is None:
|
||||
raise ConfigError(MISSING_REDIS)
|
||||
|
||||
self.redis_host = redis_config.get("host", "localhost")
|
||||
self.redis_port = redis_config.get("port", 6379)
|
||||
self.redis_dbid = redis_config.get("dbid")
|
||||
self.redis_password = redis_config.get("password")
|
||||
@@ -83,6 +83,8 @@ class ServerConfig(Config):
|
||||
# "disable" federation
|
||||
self.send_federation = config.get("send_federation", True)
|
||||
|
||||
self.handle_typing = config.get("handle_typing", True)
|
||||
|
||||
# Whether to enable user presence.
|
||||
self.use_presence = config.get("use_presence", True)
|
||||
|
||||
|
||||
@@ -48,6 +48,8 @@ class WorkerConfig(Config):
|
||||
|
||||
self.worker_main_http_uri = config.get("worker_main_http_uri", None)
|
||||
|
||||
self.instance_http_map = config.get("instance_http_map", {})
|
||||
|
||||
# This option is really only here to support `--manhole` command line
|
||||
# argument.
|
||||
manhole = config.get("worker_manhole")
|
||||
|
||||
@@ -819,7 +819,16 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
|
||||
edu_type, origin, content
|
||||
)
|
||||
|
||||
return await self._send_edu(edu_type=edu_type, origin=origin, content=content)
|
||||
if edu_type == "m.typing":
|
||||
instance_name = "synapse.app.client_reader"
|
||||
else:
|
||||
instance_name = "master"
|
||||
return await self._send_edu(
|
||||
instance_name=instance_name,
|
||||
edu_type=edu_type,
|
||||
origin=origin,
|
||||
content=content,
|
||||
)
|
||||
|
||||
async def on_query(self, query_type, args):
|
||||
"""Overrides FederationHandlerRegistry
|
||||
|
||||
@@ -499,4 +499,13 @@ class FederationSender(object):
|
||||
self._get_per_destination_queue(destination).attempt_new_transaction()
|
||||
|
||||
def get_current_token(self) -> int:
|
||||
# Dummy implementation for case where federation sender isn't offloaded
|
||||
# to a worker.
|
||||
return 0
|
||||
|
||||
async def get_replication_rows(
|
||||
self, from_token, to_token, limit, federation_ack=None
|
||||
):
|
||||
# Dummy implementation for case where federation sender isn't offloaded
|
||||
# to a worker.
|
||||
return []
|
||||
|
||||
@@ -21,6 +21,7 @@ from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import AuthError, SynapseError
|
||||
from synapse.logging.context import run_in_background
|
||||
from synapse.replication.tcp.streams import TypingStream
|
||||
from synapse.types import UserID, get_domain_from_id
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
from synapse.util.metrics import Measure
|
||||
@@ -288,6 +289,54 @@ class TypingHandler(object):
|
||||
return self._latest_room_serial
|
||||
|
||||
|
||||
class TypingSlaveHandler(object):
|
||||
def __init__(self, hs):
|
||||
self.notifier = hs.get_notifier()
|
||||
|
||||
self._latest_room_serial = 0
|
||||
self._reset()
|
||||
|
||||
def _reset(self):
|
||||
"""
|
||||
Reset the typing handler's data caches.
|
||||
"""
|
||||
# map room IDs to serial numbers
|
||||
self._room_serials = {}
|
||||
# map room IDs to sets of users currently typing
|
||||
self._room_typing = {}
|
||||
|
||||
def stream_positions(self):
|
||||
# We must update this typing token from the response of the previous
|
||||
# sync. In particular, the stream id may "reset" back to zero/a low
|
||||
# value which we *must* use for the next replication request.
|
||||
return {
|
||||
"typing": {"synapse.app.client_reader": self._latest_room_serial}
|
||||
} # FIXME
|
||||
|
||||
def process_replication_rows(self, stream_name, token, rows):
|
||||
if stream_name != TypingStream.NAME:
|
||||
return
|
||||
|
||||
if self._latest_room_serial > token:
|
||||
# The master has gone backwards. To prevent inconsistent data, just
|
||||
# clear everything.
|
||||
self._reset()
|
||||
|
||||
# Set the latest serial token to whatever the server gave us.
|
||||
self._latest_room_serial = token
|
||||
|
||||
for row in rows:
|
||||
self._room_serials[row.room_id] = token
|
||||
self._room_typing[row.room_id] = row.user_ids
|
||||
|
||||
self.notifier.on_new_event(
|
||||
"typing_key", token, rooms=[row.room_id for row in rows]
|
||||
)
|
||||
|
||||
def get_current_token(self) -> int:
|
||||
return self._latest_room_serial
|
||||
|
||||
|
||||
class TypingNotificationEventSource(object):
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
|
||||
@@ -98,6 +98,7 @@ CONDITIONAL_REQUIREMENTS = {
|
||||
"sentry": ["sentry-sdk>=0.7.2"],
|
||||
"opentracing": ["jaeger-client>=4.0.0", "opentracing>=2.2.0"],
|
||||
"jwt": ["pyjwt>=1.6.4"],
|
||||
"redis": ["txredisapi>=1.4.7"],
|
||||
}
|
||||
|
||||
ALL_OPTIONAL_REQUIREMENTS = set() # type: Set[str]
|
||||
|
||||
@@ -21,6 +21,7 @@ from synapse.replication.http import (
|
||||
membership,
|
||||
register,
|
||||
send_event,
|
||||
streams,
|
||||
)
|
||||
|
||||
REPLICATION_PREFIX = "/_synapse/replication"
|
||||
@@ -32,9 +33,12 @@ class ReplicationRestResource(JsonResource):
|
||||
self.register_servlets(hs)
|
||||
|
||||
def register_servlets(self, hs):
|
||||
send_event.register_servlets(hs, self)
|
||||
membership.register_servlets(hs, self)
|
||||
if hs.config.worker_app is None:
|
||||
send_event.register_servlets(hs, self)
|
||||
membership.register_servlets(hs, self)
|
||||
login.register_servlets(hs, self)
|
||||
register.register_servlets(hs, self)
|
||||
devices.register_servlets(hs, self)
|
||||
|
||||
streams.register_servlets(hs, self)
|
||||
federation.register_servlets(hs, self)
|
||||
login.register_servlets(hs, self)
|
||||
register.register_servlets(hs, self)
|
||||
devices.register_servlets(hs, self)
|
||||
|
||||
@@ -128,14 +128,25 @@ class ReplicationEndpoint(object):
|
||||
Returns a callable that accepts the same parameters as `_serialize_payload`.
|
||||
"""
|
||||
clock = hs.get_clock()
|
||||
host = hs.config.worker_replication_host
|
||||
port = hs.config.worker_replication_http_port
|
||||
master_host = hs.config.worker_replication_host
|
||||
master_port = hs.config.worker_replication_http_port
|
||||
|
||||
instance_http_map = hs.config.instance_http_map
|
||||
|
||||
client = hs.get_simple_http_client()
|
||||
|
||||
@trace(opname="outgoing_replication_request")
|
||||
@defer.inlineCallbacks
|
||||
def send_request(**kwargs):
|
||||
def send_request(instance_name="master", **kwargs):
|
||||
if instance_name == "master":
|
||||
host = master_host
|
||||
port = master_port
|
||||
elif instance_name in instance_http_map:
|
||||
host = instance_http_map[instance_name]["host"]
|
||||
port = instance_http_map[instance_name]["port"]
|
||||
else:
|
||||
raise Exception("Unknown instance")
|
||||
|
||||
data = yield cls._serialize_payload(**kwargs)
|
||||
|
||||
url_args = [
|
||||
|
||||
@@ -277,8 +277,10 @@ class ReplicationStoreRoomOnInviteRestServlet(ReplicationEndpoint):
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
ReplicationFederationSendEventsRestServlet(hs).register(http_server)
|
||||
if hs.config.worker_app is None:
|
||||
ReplicationFederationSendEventsRestServlet(hs).register(http_server)
|
||||
ReplicationGetQueryRestServlet(hs).register(http_server)
|
||||
ReplicationCleanRoomRestServlet(hs).register(http_server)
|
||||
ReplicationStoreRoomOnInviteRestServlet(hs).register(http_server)
|
||||
|
||||
ReplicationFederationSendEduRestServlet(hs).register(http_server)
|
||||
ReplicationGetQueryRestServlet(hs).register(http_server)
|
||||
ReplicationCleanRoomRestServlet(hs).register(http_server)
|
||||
ReplicationStoreRoomOnInviteRestServlet(hs).register(http_server)
|
||||
|
||||
80
synapse/replication/http/streams.py
Normal file
80
synapse/replication/http/streams.py
Normal file
@@ -0,0 +1,80 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.http.servlet import parse_integer
|
||||
from synapse.replication.http._base import ReplicationEndpoint
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ReplicationGetStreamUpdates(ReplicationEndpoint):
|
||||
"""Fetches stream updates from a server. Used for streams not persisted to
|
||||
the database, e.g. typing notifications.
|
||||
|
||||
The API looks like:
|
||||
|
||||
GET /_synapse/replication/get_repl_stream_updates/events?from_token=0&to_token=10&limit=100
|
||||
|
||||
200 OK
|
||||
|
||||
{
|
||||
updates: [ ... ],
|
||||
upto_token: 10,
|
||||
limited: False,
|
||||
}
|
||||
|
||||
"""
|
||||
|
||||
NAME = "get_repl_stream_updates"
|
||||
PATH_ARGS = ("stream_name",)
|
||||
METHOD = "GET"
|
||||
|
||||
def __init__(self, hs):
|
||||
super().__init__(hs)
|
||||
|
||||
# We pull the streams from the replication steamer (if we try and make
|
||||
# them ourselves we end up in an import loop).
|
||||
self.streams = hs.get_replication_streamer().get_streams()
|
||||
|
||||
self.instance_name = hs.config.worker_name or "master"
|
||||
|
||||
@staticmethod
|
||||
def _serialize_payload(stream_name, from_token, upto_token, limit):
|
||||
return {"from_token": from_token, "upto_token": upto_token, "limit": limit}
|
||||
|
||||
async def _handle_request(self, request, stream_name):
|
||||
stream = self.streams.get(stream_name)
|
||||
if stream is None:
|
||||
raise SynapseError(400, "Unknown stream")
|
||||
|
||||
from_token = parse_integer(request, "from_token", required=True)
|
||||
upto_token = parse_integer(request, "upto_token", required=True)
|
||||
limit = parse_integer(request, "limit", required=True)
|
||||
|
||||
updates, upto_token, limited = await stream.get_updates_since(
|
||||
self.instance_name, from_token, upto_token, limit
|
||||
)
|
||||
|
||||
return (
|
||||
200,
|
||||
{"updates": updates, "upto_token": upto_token, "limited": limited},
|
||||
)
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
ReplicationGetStreamUpdates(hs).register(http_server)
|
||||
@@ -18,8 +18,10 @@ from typing import Dict, Optional
|
||||
|
||||
import six
|
||||
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME
|
||||
from synapse.storage.data_stores.main.cache import (
|
||||
CURRENT_STATE_CACHE_NAME,
|
||||
CacheInvalidationWorkerStore,
|
||||
)
|
||||
from synapse.storage.database import Database
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
|
||||
@@ -35,7 +37,7 @@ def __func__(inp):
|
||||
return inp.__func__
|
||||
|
||||
|
||||
class BaseSlavedStore(SQLBaseStore):
|
||||
class BaseSlavedStore(CacheInvalidationWorkerStore):
|
||||
def __init__(self, database: Database, db_conn, hs):
|
||||
super(BaseSlavedStore, self).__init__(database, db_conn, hs)
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
@@ -57,9 +59,15 @@ class BaseSlavedStore(SQLBaseStore):
|
||||
"""
|
||||
pos = {}
|
||||
if self._cache_id_gen:
|
||||
pos["caches"] = self._cache_id_gen.get_current_token()
|
||||
pos["caches"] = {"master": self._cache_id_gen.get_current_token()}
|
||||
return pos
|
||||
|
||||
def get_cache_stream_token(self):
|
||||
if self._cache_id_gen:
|
||||
return self._cache_id_gen.get_current_token()
|
||||
else:
|
||||
return 0
|
||||
|
||||
def process_replication_rows(self, stream_name, token, rows):
|
||||
if stream_name == "caches":
|
||||
if self._cache_id_gen:
|
||||
|
||||
@@ -35,9 +35,7 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
|
||||
def stream_positions(self):
|
||||
result = super(SlavedAccountDataStore, self).stream_positions()
|
||||
position = self._account_data_id_gen.get_current_token()
|
||||
result["user_account_data"] = position
|
||||
result["room_account_data"] = position
|
||||
result["tag_account_data"] = position
|
||||
result["account_data"] = {"master": position}
|
||||
return result
|
||||
|
||||
def process_replication_rows(self, stream_name, token, rows):
|
||||
|
||||
@@ -45,7 +45,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedDeviceInboxStore, self).stream_positions()
|
||||
result["to_device"] = self._device_inbox_id_gen.get_current_token()
|
||||
result["to_device"] = {"master": self._device_inbox_id_gen.get_current_token()}
|
||||
return result
|
||||
|
||||
def process_replication_rows(self, stream_name, token, rows):
|
||||
|
||||
@@ -54,8 +54,8 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
|
||||
# device list stream, so set them both to the device list ID
|
||||
# generator's current token.
|
||||
current_token = self._device_list_id_gen.get_current_token()
|
||||
result[DeviceListsStream.NAME] = current_token
|
||||
result[UserSignatureStream.NAME] = current_token
|
||||
result[DeviceListsStream.NAME] = {"master": current_token}
|
||||
result[UserSignatureStream.NAME] = {"master": current_token}
|
||||
return result
|
||||
|
||||
def process_replication_rows(self, stream_name, token, rows):
|
||||
|
||||
@@ -95,8 +95,8 @@ class SlavedEventStore(
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedEventStore, self).stream_positions()
|
||||
result["events"] = self._stream_id_gen.get_current_token()
|
||||
result["backfill"] = -self._backfill_id_gen.get_current_token()
|
||||
result["events"] = {"master": self._stream_id_gen.get_current_token()}
|
||||
result["backfill"] = {"master": -self._backfill_id_gen.get_current_token()}
|
||||
return result
|
||||
|
||||
def process_replication_rows(self, stream_name, token, rows):
|
||||
|
||||
@@ -39,7 +39,7 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedGroupServerStore, self).stream_positions()
|
||||
result["groups"] = self._group_updates_id_gen.get_current_token()
|
||||
result["groups"] = {"master": self._group_updates_id_gen.get_current_token()}
|
||||
return result
|
||||
|
||||
def process_replication_rows(self, stream_name, token, rows):
|
||||
|
||||
@@ -46,7 +46,7 @@ class SlavedPresenceStore(BaseSlavedStore):
|
||||
|
||||
if self.hs.config.use_presence:
|
||||
position = self._presence_id_gen.get_current_token()
|
||||
result["presence"] = position
|
||||
result["presence"] = {"master": position}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -39,7 +39,9 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedPushRuleStore, self).stream_positions()
|
||||
result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
|
||||
result["push_rules"] = {
|
||||
"master": self._push_rules_stream_id_gen.get_current_token()
|
||||
}
|
||||
return result
|
||||
|
||||
def process_replication_rows(self, stream_name, token, rows):
|
||||
|
||||
@@ -30,9 +30,12 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedPusherStore, self).stream_positions()
|
||||
result["pushers"] = self._pushers_id_gen.get_current_token()
|
||||
result["pushers"] = {"master": self._pushers_id_gen.get_current_token()}
|
||||
return result
|
||||
|
||||
def get_pushers_stream_token(self):
|
||||
return self._pushers_id_gen.get_current_token()
|
||||
|
||||
def process_replication_rows(self, stream_name, token, rows):
|
||||
if stream_name == "pushers":
|
||||
self._pushers_id_gen.advance(token)
|
||||
|
||||
@@ -44,7 +44,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedReceiptsStore, self).stream_positions()
|
||||
result["receipts"] = self._receipts_id_gen.get_current_token()
|
||||
result["receipts"] = {"master": self._receipts_id_gen.get_current_token()}
|
||||
return result
|
||||
|
||||
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
|
||||
|
||||
@@ -32,7 +32,9 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(RoomStore, self).stream_positions()
|
||||
result["public_rooms"] = self._public_room_id_gen.get_current_token()
|
||||
result["public_rooms"] = {
|
||||
"master": self._public_room_id_gen.get_current_token()
|
||||
}
|
||||
return result
|
||||
|
||||
def process_replication_rows(self, stream_name, token, rows):
|
||||
|
||||
@@ -16,26 +16,10 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.protocol import ReconnectingClientFactory
|
||||
|
||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||
from synapse.replication.tcp.protocol import (
|
||||
AbstractReplicationClientHandler,
|
||||
ClientReplicationStreamProtocol,
|
||||
)
|
||||
|
||||
from .commands import (
|
||||
Command,
|
||||
FederationAckCommand,
|
||||
InvalidateCacheCommand,
|
||||
RemoteServerUpCommand,
|
||||
RemovePusherCommand,
|
||||
UserIpCommand,
|
||||
UserSyncCommand,
|
||||
)
|
||||
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -51,10 +35,11 @@ class ReplicationClientFactory(ReconnectingClientFactory):
|
||||
initialDelay = 0.1
|
||||
maxDelay = 1 # Try at least once every N seconds
|
||||
|
||||
def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler):
|
||||
def __init__(self, hs, client_name):
|
||||
self.client_name = client_name
|
||||
self.handler = handler
|
||||
self.handler = hs.get_tcp_replication()
|
||||
self.server_name = hs.config.server_name
|
||||
self.hs = hs
|
||||
self._clock = hs.get_clock() # As self.clock is defined in super class
|
||||
|
||||
hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)
|
||||
@@ -65,7 +50,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
|
||||
def buildProtocol(self, addr):
|
||||
logger.info("Connected to replication: %r", addr)
|
||||
return ClientReplicationStreamProtocol(
|
||||
self.client_name, self.server_name, self._clock, self.handler
|
||||
self.hs, self.client_name, self.server_name, self._clock, self.handler,
|
||||
)
|
||||
|
||||
def clientConnectionLost(self, connector, reason):
|
||||
@@ -75,170 +60,3 @@ class ReplicationClientFactory(ReconnectingClientFactory):
|
||||
def clientConnectionFailed(self, connector, reason):
|
||||
logger.error("Failed to connect to replication: %r", reason)
|
||||
ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
|
||||
|
||||
|
||||
class ReplicationClientHandler(AbstractReplicationClientHandler):
|
||||
"""A base handler that can be passed to the ReplicationClientFactory.
|
||||
|
||||
By default proxies incoming replication data to the SlaveStore.
|
||||
"""
|
||||
|
||||
def __init__(self, store: BaseSlavedStore):
|
||||
self.store = store
|
||||
|
||||
# The current connection. None if we are currently (re)connecting
|
||||
self.connection = None
|
||||
|
||||
# Any pending commands to be sent once a new connection has been
|
||||
# established
|
||||
self.pending_commands = [] # type: List[Command]
|
||||
|
||||
# Map from string -> deferred, to wake up when receiveing a SYNC with
|
||||
# the given string.
|
||||
# Used for tests.
|
||||
self.awaiting_syncs = {} # type: Dict[str, defer.Deferred]
|
||||
|
||||
# The factory used to create connections.
|
||||
self.factory = None # type: Optional[ReplicationClientFactory]
|
||||
|
||||
def start_replication(self, hs):
|
||||
"""Helper method to start a replication connection to the remote server
|
||||
using TCP.
|
||||
"""
|
||||
client_name = hs.config.worker_name
|
||||
self.factory = ReplicationClientFactory(hs, client_name, self)
|
||||
host = hs.config.worker_replication_host
|
||||
port = hs.config.worker_replication_port
|
||||
hs.get_reactor().connectTCP(host, port, self.factory)
|
||||
|
||||
async def on_rdata(self, stream_name, token, rows):
|
||||
"""Called to handle a batch of replication data with a given stream token.
|
||||
|
||||
By default this just pokes the slave store. Can be overridden in subclasses to
|
||||
handle more.
|
||||
|
||||
Args:
|
||||
stream_name (str): name of the replication stream for this batch of rows
|
||||
token (int): stream token for this batch of rows
|
||||
rows (list): a list of Stream.ROW_TYPE objects as returned by
|
||||
Stream.parse_row.
|
||||
"""
|
||||
logger.debug("Received rdata %s -> %s", stream_name, token)
|
||||
self.store.process_replication_rows(stream_name, token, rows)
|
||||
|
||||
async def on_position(self, stream_name, token):
|
||||
"""Called when we get new position data. By default this just pokes
|
||||
the slave store.
|
||||
|
||||
Can be overriden in subclasses to handle more.
|
||||
"""
|
||||
self.store.process_replication_rows(stream_name, token, [])
|
||||
|
||||
def on_sync(self, data):
|
||||
"""When we received a SYNC we wake up any deferreds that were waiting
|
||||
for the sync with the given data.
|
||||
|
||||
Used by tests.
|
||||
"""
|
||||
d = self.awaiting_syncs.pop(data, None)
|
||||
if d:
|
||||
d.callback(data)
|
||||
|
||||
def on_remote_server_up(self, server: str):
|
||||
"""Called when get a new REMOTE_SERVER_UP command."""
|
||||
|
||||
def get_streams_to_replicate(self) -> Dict[str, int]:
|
||||
"""Called when a new connection has been established and we need to
|
||||
subscribe to streams.
|
||||
|
||||
Returns:
|
||||
map from stream name to the most recent update we have for
|
||||
that stream (ie, the point we want to start replicating from)
|
||||
"""
|
||||
args = self.store.stream_positions()
|
||||
user_account_data = args.pop("user_account_data", None)
|
||||
room_account_data = args.pop("room_account_data", None)
|
||||
if user_account_data:
|
||||
args["account_data"] = user_account_data
|
||||
elif room_account_data:
|
||||
args["account_data"] = room_account_data
|
||||
|
||||
return args
|
||||
|
||||
def get_currently_syncing_users(self):
|
||||
"""Get the list of currently syncing users (if any). This is called
|
||||
when a connection has been established and we need to send the
|
||||
currently syncing users. (Overriden by the synchrotron's only)
|
||||
"""
|
||||
return []
|
||||
|
||||
def send_command(self, cmd):
|
||||
"""Send a command to master (when we get establish a connection if we
|
||||
don't have one already.)
|
||||
"""
|
||||
if self.connection:
|
||||
self.connection.send_command(cmd)
|
||||
else:
|
||||
logger.warning("Queuing command as not connected: %r", cmd.NAME)
|
||||
self.pending_commands.append(cmd)
|
||||
|
||||
def send_federation_ack(self, token):
|
||||
"""Ack data for the federation stream. This allows the master to drop
|
||||
data stored purely in memory.
|
||||
"""
|
||||
self.send_command(FederationAckCommand(token))
|
||||
|
||||
def send_user_sync(self, user_id, is_syncing, last_sync_ms):
|
||||
"""Poke the master that a user has started/stopped syncing.
|
||||
"""
|
||||
self.send_command(UserSyncCommand(user_id, is_syncing, last_sync_ms))
|
||||
|
||||
def send_remove_pusher(self, app_id, push_key, user_id):
|
||||
"""Poke the master to remove a pusher for a user
|
||||
"""
|
||||
cmd = RemovePusherCommand(app_id, push_key, user_id)
|
||||
self.send_command(cmd)
|
||||
|
||||
def send_invalidate_cache(self, cache_func, keys):
|
||||
"""Poke the master to invalidate a cache.
|
||||
"""
|
||||
cmd = InvalidateCacheCommand(cache_func.__name__, keys)
|
||||
self.send_command(cmd)
|
||||
|
||||
def send_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
|
||||
"""Tell the master that the user made a request.
|
||||
"""
|
||||
cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
|
||||
self.send_command(cmd)
|
||||
|
||||
def send_remote_server_up(self, server: str):
|
||||
self.send_command(RemoteServerUpCommand(server))
|
||||
|
||||
def await_sync(self, data):
|
||||
"""Returns a deferred that is resolved when we receive a SYNC command
|
||||
with given data.
|
||||
|
||||
[Not currently] used by tests.
|
||||
"""
|
||||
return self.awaiting_syncs.setdefault(data, defer.Deferred())
|
||||
|
||||
def update_connection(self, connection):
|
||||
"""Called when a connection has been established (or lost with None).
|
||||
"""
|
||||
self.connection = connection
|
||||
if connection:
|
||||
for cmd in self.pending_commands:
|
||||
connection.send_command(cmd)
|
||||
self.pending_commands = []
|
||||
|
||||
def finished_connecting(self):
|
||||
"""Called when we have successfully subscribed and caught up to all
|
||||
streams we're interested in.
|
||||
"""
|
||||
logger.info("Finished connecting to server")
|
||||
|
||||
# We don't reset the delay any earlier as otherwise if there is a
|
||||
# problem during start up we'll end up tight looping connecting to the
|
||||
# server.
|
||||
if self.factory:
|
||||
self.factory.resetDelay()
|
||||
|
||||
@@ -86,7 +86,7 @@ class RdataCommand(Command):
|
||||
|
||||
Format::
|
||||
|
||||
RDATA <stream_name> <token> <row_json>
|
||||
RDATA <stream_name> <instance_name> <token> <row_json>
|
||||
|
||||
The `<token>` may either be a numeric stream id OR "batch". The latter case
|
||||
is used to support sending multiple updates with the same stream ID. This
|
||||
@@ -107,22 +107,27 @@ class RdataCommand(Command):
|
||||
|
||||
NAME = "RDATA"
|
||||
|
||||
def __init__(self, stream_name, token, row):
|
||||
def __init__(self, stream_name, instance_name, token, row):
|
||||
self.stream_name = stream_name
|
||||
self.instance_name = instance_name
|
||||
self.token = token
|
||||
self.row = row
|
||||
|
||||
@classmethod
|
||||
def from_line(cls, line):
|
||||
stream_name, token, row_json = line.split(" ", 2)
|
||||
stream_name, instance_name, token, row_json = line.split(" ", 3)
|
||||
return cls(
|
||||
stream_name, None if token == "batch" else int(token), json.loads(row_json)
|
||||
stream_name,
|
||||
instance_name,
|
||||
None if token == "batch" else int(token),
|
||||
json.loads(row_json),
|
||||
)
|
||||
|
||||
def to_line(self):
|
||||
return " ".join(
|
||||
(
|
||||
self.stream_name,
|
||||
self.instance_name,
|
||||
str(self.token) if self.token is not None else "batch",
|
||||
_json_encoder.encode(self.row),
|
||||
)
|
||||
@@ -136,23 +141,24 @@ class PositionCommand(Command):
|
||||
"""Sent by the server to tell the client the stream postition without
|
||||
needing to send an RDATA.
|
||||
|
||||
Sent to the client after all missing updates for a stream have been sent
|
||||
to the client and they're now up to date.
|
||||
On receipt of a POSITION command clients should check if they have missed
|
||||
any updates, and if so then fetch them out of band.
|
||||
"""
|
||||
|
||||
NAME = "POSITION"
|
||||
|
||||
def __init__(self, stream_name, token):
|
||||
def __init__(self, stream_name, instance_name, token):
|
||||
self.stream_name = stream_name
|
||||
self.instance_name = instance_name
|
||||
self.token = token
|
||||
|
||||
@classmethod
|
||||
def from_line(cls, line):
|
||||
stream_name, token = line.split(" ", 1)
|
||||
return cls(stream_name, int(token))
|
||||
stream_name, instance_name, token = line.split(" ", 2)
|
||||
return cls(stream_name, instance_name, int(token))
|
||||
|
||||
def to_line(self):
|
||||
return " ".join((self.stream_name, str(self.token)))
|
||||
return " ".join((self.stream_name, self.instance_name, str(self.token)))
|
||||
|
||||
|
||||
class ErrorCommand(Command):
|
||||
@@ -179,42 +185,24 @@ class NameCommand(Command):
|
||||
|
||||
|
||||
class ReplicateCommand(Command):
|
||||
"""Sent by the client to subscribe to the stream.
|
||||
"""Sent by the client to subscribe to streams.
|
||||
|
||||
Format::
|
||||
|
||||
REPLICATE <stream_name> <token>
|
||||
|
||||
Where <token> may be either:
|
||||
* a numeric stream_id to stream updates from
|
||||
* "NOW" to stream all subsequent updates.
|
||||
|
||||
The <stream_name> can be "ALL" to subscribe to all known streams, in which
|
||||
case the <token> must be set to "NOW", i.e.::
|
||||
|
||||
REPLICATE ALL NOW
|
||||
REPLICATE
|
||||
"""
|
||||
|
||||
NAME = "REPLICATE"
|
||||
|
||||
def __init__(self, stream_name, token):
|
||||
self.stream_name = stream_name
|
||||
self.token = token
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_line(cls, line):
|
||||
stream_name, token = line.split(" ", 1)
|
||||
if token in ("NOW", "now"):
|
||||
token = "NOW"
|
||||
else:
|
||||
token = int(token)
|
||||
return cls(stream_name, token)
|
||||
return cls()
|
||||
|
||||
def to_line(self):
|
||||
return " ".join((self.stream_name, str(self.token)))
|
||||
|
||||
def get_logcontext_id(self):
|
||||
return "REPLICATE-" + self.stream_name
|
||||
return ""
|
||||
|
||||
|
||||
class UserSyncCommand(Command):
|
||||
@@ -225,30 +213,32 @@ class UserSyncCommand(Command):
|
||||
|
||||
Format::
|
||||
|
||||
USER_SYNC <user_id> <state> <last_sync_ms>
|
||||
USER_SYNC <instance_id> <user_id> <state> <last_sync_ms>
|
||||
|
||||
Where <state> is either "start" or "stop"
|
||||
"""
|
||||
|
||||
NAME = "USER_SYNC"
|
||||
|
||||
def __init__(self, user_id, is_syncing, last_sync_ms):
|
||||
def __init__(self, instance_id, user_id, is_syncing, last_sync_ms):
|
||||
self.instance_id = instance_id
|
||||
self.user_id = user_id
|
||||
self.is_syncing = is_syncing
|
||||
self.last_sync_ms = last_sync_ms
|
||||
|
||||
@classmethod
|
||||
def from_line(cls, line):
|
||||
user_id, state, last_sync_ms = line.split(" ", 2)
|
||||
instance_id, user_id, state, last_sync_ms = line.split(" ", 3)
|
||||
|
||||
if state not in ("start", "end"):
|
||||
raise Exception("Invalid USER_SYNC state %r" % (state,))
|
||||
|
||||
return cls(user_id, state == "start", int(last_sync_ms))
|
||||
return cls(instance_id, user_id, state == "start", int(last_sync_ms))
|
||||
|
||||
def to_line(self):
|
||||
return " ".join(
|
||||
(
|
||||
self.instance_id,
|
||||
self.user_id,
|
||||
"start" if self.is_syncing else "end",
|
||||
str(self.last_sync_ms),
|
||||
@@ -256,6 +246,30 @@ class UserSyncCommand(Command):
|
||||
)
|
||||
|
||||
|
||||
class ClearUserSyncsCommand(Command):
|
||||
"""Sent by the client to inform the server that it should drop all
|
||||
information about syncing users sent by the client.
|
||||
|
||||
Mainly used when client is about to shut down.
|
||||
|
||||
Format::
|
||||
|
||||
CLEAR_USER_SYNC <instance_id>
|
||||
"""
|
||||
|
||||
NAME = "CLEAR_USER_SYNC"
|
||||
|
||||
def __init__(self, instance_id):
|
||||
self.instance_id = instance_id
|
||||
|
||||
@classmethod
|
||||
def from_line(cls, line):
|
||||
return cls(line)
|
||||
|
||||
def to_line(self):
|
||||
return self.instance_id
|
||||
|
||||
|
||||
class FederationAckCommand(Command):
|
||||
"""Sent by the client when it has processed up to a given point in the
|
||||
federation stream. This allows the master to drop in-memory caches of the
|
||||
@@ -416,6 +430,7 @@ _COMMANDS = (
|
||||
InvalidateCacheCommand,
|
||||
UserIpCommand,
|
||||
RemoteServerUpCommand,
|
||||
ClearUserSyncsCommand,
|
||||
) # type: Tuple[Type[Command], ...]
|
||||
|
||||
# Map of command name to command type.
|
||||
@@ -438,6 +453,7 @@ VALID_CLIENT_COMMANDS = (
|
||||
ReplicateCommand.NAME,
|
||||
PingCommand.NAME,
|
||||
UserSyncCommand.NAME,
|
||||
ClearUserSyncsCommand.NAME,
|
||||
FederationAckCommand.NAME,
|
||||
RemovePusherCommand.NAME,
|
||||
InvalidateCacheCommand.NAME,
|
||||
|
||||
399
synapse/replication/tcp/handler.py
Normal file
399
synapse/replication/tcp/handler.py
Normal file
@@ -0,0 +1,399 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2017 Vector Creations Ltd
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""A replication client for use by synapse workers.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
||||
from prometheus_client import Counter
|
||||
|
||||
from synapse.metrics import LaterGauge
|
||||
from synapse.replication.tcp.commands import (
|
||||
ClearUserSyncsCommand,
|
||||
Command,
|
||||
FederationAckCommand,
|
||||
InvalidateCacheCommand,
|
||||
PositionCommand,
|
||||
RdataCommand,
|
||||
RemoteServerUpCommand,
|
||||
RemovePusherCommand,
|
||||
ReplicateCommand,
|
||||
UserIpCommand,
|
||||
UserSyncCommand,
|
||||
)
|
||||
from synapse.replication.tcp.streams import STREAMS_MAP, Stream
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
|
||||
federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
|
||||
remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
|
||||
invalidate_cache_counter = Counter(
|
||||
"synapse_replication_tcp_resource_invalidate_cache", ""
|
||||
)
|
||||
user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
|
||||
|
||||
|
||||
class ReplicationClientHandler:
|
||||
"""Handles incoming commands from replication.
|
||||
|
||||
Proxies data to `HomeServer.get_replication_data_handler()`.
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
self.replication_data_handler = hs.get_replication_data_handler()
|
||||
self.store = hs.get_datastore()
|
||||
self.notifier = hs.get_notifier()
|
||||
self.clock = hs.get_clock()
|
||||
self.presence_handler = hs.get_presence_handler()
|
||||
self.instance_id = hs.get_instance_id()
|
||||
|
||||
self.instance_name = hs.config.worker.worker_name or "master"
|
||||
|
||||
self.connections = [] # type: List[Any]
|
||||
|
||||
self.streams = {
|
||||
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
|
||||
} # type: Dict[str, Stream]
|
||||
|
||||
LaterGauge(
|
||||
"synapse_replication_tcp_resource_total_connections",
|
||||
"",
|
||||
[],
|
||||
lambda: len(self.connections),
|
||||
)
|
||||
|
||||
LaterGauge(
|
||||
"synapse_replication_tcp_resource_connections_per_stream",
|
||||
"",
|
||||
["stream_name"],
|
||||
lambda: {
|
||||
(stream_name,): len(
|
||||
[
|
||||
conn
|
||||
for conn in self.connections
|
||||
if stream_name in conn.replication_streams
|
||||
]
|
||||
)
|
||||
for stream_name in self.streams
|
||||
},
|
||||
)
|
||||
|
||||
# Map of stream to batched updates. See RdataCommand for info on how
|
||||
# batching works.
|
||||
self.pending_batches = {} # type: Dict[str, List[Any]]
|
||||
|
||||
self.is_master = hs.config.worker_app is None
|
||||
|
||||
self.federation_sender = None
|
||||
if self.is_master and not hs.config.send_federation:
|
||||
self.federation_sender = hs.get_federation_sender()
|
||||
|
||||
self._server_notices_sender = None
|
||||
if self.is_master:
|
||||
self._server_notices_sender = hs.get_server_notices_sender()
|
||||
self.notifier.add_remote_server_up_callback(self.send_remote_server_up)
|
||||
|
||||
def new_connection(self, connection):
|
||||
self.connections.append(connection)
|
||||
|
||||
def lost_connection(self, connection):
|
||||
try:
|
||||
self.connections.remove(connection)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def connected(self) -> bool:
|
||||
"""Do we have any replication connections open?
|
||||
|
||||
Used to no-op if nothing is connected.
|
||||
"""
|
||||
return bool(self.connections)
|
||||
|
||||
async def on_REPLICATE(self, cmd: ReplicateCommand):
|
||||
# We only want to announce positions by the writer of the streams.
|
||||
# Currently this is just the master process.
|
||||
if not self.is_master:
|
||||
return
|
||||
|
||||
if not self.connections:
|
||||
raise Exception("Not connected")
|
||||
|
||||
for stream_name, stream in self.streams.items():
|
||||
current_token = stream.current_token()
|
||||
self.send_command(
|
||||
PositionCommand(stream_name, self.instance_name, current_token)
|
||||
)
|
||||
|
||||
async def on_USER_SYNC(self, cmd: UserSyncCommand):
|
||||
user_sync_counter.inc()
|
||||
|
||||
if self.is_master:
|
||||
await self.presence_handler.update_external_syncs_row(
|
||||
cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
|
||||
)
|
||||
|
||||
async def on_CLEAR_USER_SYNC(self, cmd: ClearUserSyncsCommand):
|
||||
if self.is_master:
|
||||
await self.presence_handler.update_external_syncs_clear(cmd.instance_id)
|
||||
|
||||
async def on_FEDERATION_ACK(self, cmd: FederationAckCommand):
|
||||
federation_ack_counter.inc()
|
||||
|
||||
if self.federation_sender:
|
||||
self.federation_sender.federation_ack(cmd.token)
|
||||
|
||||
async def on_REMOVE_PUSHER(self, cmd: RemovePusherCommand):
|
||||
remove_pusher_counter.inc()
|
||||
|
||||
if self.is_master:
|
||||
await self.store.delete_pusher_by_app_id_pushkey_user_id(
|
||||
app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
|
||||
)
|
||||
|
||||
self.notifier.on_new_replication_data()
|
||||
|
||||
async def on_INVALIDATE_CACHE(self, cmd: InvalidateCacheCommand):
|
||||
invalidate_cache_counter.inc()
|
||||
|
||||
if self.is_master:
|
||||
# We invalidate the cache locally, but then also stream that to other
|
||||
# workers.
|
||||
await self.store.invalidate_cache_and_stream(
|
||||
cmd.cache_func, tuple(cmd.keys)
|
||||
)
|
||||
|
||||
async def on_USER_IP(self, cmd: UserIpCommand):
|
||||
user_ip_cache_counter.inc()
|
||||
|
||||
if self.is_master:
|
||||
await self.store.insert_client_ip(
|
||||
cmd.user_id,
|
||||
cmd.access_token,
|
||||
cmd.ip,
|
||||
cmd.user_agent,
|
||||
cmd.device_id,
|
||||
cmd.last_seen,
|
||||
)
|
||||
|
||||
if self._server_notices_sender:
|
||||
await self._server_notices_sender.on_user_ip(cmd.user_id)
|
||||
|
||||
async def on_RDATA(self, cmd: RdataCommand):
|
||||
stream_name = cmd.stream_name
|
||||
|
||||
try:
|
||||
row = STREAMS_MAP[stream_name].parse_row(cmd.row)
|
||||
except Exception:
|
||||
logger.exception("[%s] Failed to parse RDATA: %r", stream_name, cmd.row)
|
||||
raise
|
||||
|
||||
if cmd.token is None:
|
||||
# I.e. this is part of a batch of updates for this stream. Batch
|
||||
# until we get an update for the stream with a non None token
|
||||
self.pending_batches.setdefault(stream_name, []).append(row)
|
||||
else:
|
||||
# Check if this is the last of a batch of updates
|
||||
rows = self.pending_batches.pop(stream_name, [])
|
||||
rows.append(row)
|
||||
await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
|
||||
|
||||
async def on_rdata(
|
||||
self, stream_name: str, instance_name: str, token: int, rows: list
|
||||
):
|
||||
"""Called to handle a batch of replication data with a given stream token.
|
||||
|
||||
Args:
|
||||
stream_name: name of the replication stream for this batch of rows
|
||||
token: stream token for this batch of rows
|
||||
rows: a list of Stream.ROW_TYPE objects as returned by
|
||||
Stream.parse_row.
|
||||
"""
|
||||
logger.info("Received rdata %s %s -> %s", stream_name, instance_name, token)
|
||||
await self.replication_data_handler.on_rdata(
|
||||
stream_name, instance_name, token, rows
|
||||
)
|
||||
|
||||
async def on_POSITION(self, cmd: PositionCommand):
|
||||
stream = self.streams.get(cmd.stream_name)
|
||||
if not stream:
|
||||
logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
|
||||
return
|
||||
|
||||
# Find where we previously streamed up to.
|
||||
current_tokens = self.replication_data_handler.get_streams_to_replicate().get(
|
||||
cmd.stream_name
|
||||
)
|
||||
if current_tokens is None:
|
||||
logger.debug(
|
||||
"Got POSITION for stream we're not subscribed to: %s", cmd.stream_name
|
||||
)
|
||||
return
|
||||
|
||||
current_token = current_tokens.get(cmd.instance_name, 0)
|
||||
|
||||
# Fetch all updates between then and now.
|
||||
limited = cmd.token != current_token
|
||||
while limited:
|
||||
updates, current_token, limited = await stream.get_updates_since(
|
||||
cmd.instance_name, current_token, cmd.token
|
||||
)
|
||||
if updates:
|
||||
await self.on_rdata(
|
||||
cmd.stream_name,
|
||||
cmd.instance_name,
|
||||
current_token,
|
||||
[stream.parse_row(update[1]) for update in updates],
|
||||
)
|
||||
|
||||
# We've now caught up to position sent to us, notify handler.
|
||||
await self.replication_data_handler.on_position(cmd.stream_name, cmd.token)
|
||||
|
||||
# Handle any RDATA that came in while we were catching up.
|
||||
rows = self.pending_batches.pop(cmd.stream_name, [])
|
||||
if rows:
|
||||
await self.on_rdata(
|
||||
cmd.stream_name, cmd.instance_name, rows[-1].token, rows
|
||||
)
|
||||
|
||||
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
|
||||
"""Called when get a new REMOTE_SERVER_UP command."""
|
||||
if self.is_master:
|
||||
self.notifier.notify_remote_server_up(cmd.data)
|
||||
|
||||
def get_currently_syncing_users(self):
|
||||
"""Get the list of currently syncing users (if any). This is called
|
||||
when a connection has been established and we need to send the
|
||||
currently syncing users.
|
||||
"""
|
||||
return self.presence_handler.get_currently_syncing_users()
|
||||
|
||||
def send_command(self, cmd: Command):
|
||||
"""Send a command to master (when we get establish a connection if we
|
||||
don't have one already.)
|
||||
"""
|
||||
for conn in self.connections:
|
||||
conn.send_command(cmd)
|
||||
|
||||
def send_federation_ack(self, token: int):
|
||||
"""Ack data for the federation stream. This allows the master to drop
|
||||
data stored purely in memory.
|
||||
"""
|
||||
self.send_command(FederationAckCommand(token))
|
||||
|
||||
def send_user_sync(
|
||||
self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
|
||||
):
|
||||
"""Poke the master that a user has started/stopped syncing.
|
||||
"""
|
||||
self.send_command(
|
||||
UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
|
||||
)
|
||||
|
||||
def send_remove_pusher(self, app_id: str, push_key: str, user_id: str):
|
||||
"""Poke the master to remove a pusher for a user
|
||||
"""
|
||||
cmd = RemovePusherCommand(app_id, push_key, user_id)
|
||||
self.send_command(cmd)
|
||||
|
||||
def send_invalidate_cache(self, cache_func: Callable, keys: tuple):
|
||||
"""Poke the master to invalidate a cache.
|
||||
"""
|
||||
cmd = InvalidateCacheCommand(cache_func.__name__, keys)
|
||||
self.send_command(cmd)
|
||||
|
||||
def send_user_ip(
|
||||
self,
|
||||
user_id: str,
|
||||
access_token: str,
|
||||
ip: str,
|
||||
user_agent: str,
|
||||
device_id: str,
|
||||
last_seen: int,
|
||||
):
|
||||
"""Tell the master that the user made a request.
|
||||
"""
|
||||
cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
|
||||
self.send_command(cmd)
|
||||
|
||||
def send_remote_server_up(self, server: str):
|
||||
self.send_command(RemoteServerUpCommand(server))
|
||||
|
||||
def stream_update(self, stream_name: str, token: str, data: Any):
|
||||
"""Called when a new update is available to stream to clients.
|
||||
|
||||
We need to check if the client is interested in the stream or not
|
||||
"""
|
||||
self.send_command(RdataCommand(stream_name, self.instance_name, token, data))
|
||||
|
||||
|
||||
class ReplicationDataHandler:
|
||||
"""A replication data handler that simply discards all data.
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
self.typing_handler = hs.get_typing_handler()
|
||||
|
||||
self.slaved_store = hs.config.worker_app is not None
|
||||
self.slaved_typing = not hs.config.server.handle_typing
|
||||
|
||||
async def on_rdata(
|
||||
self, stream_name: str, instance_name: str, token: int, rows: list
|
||||
):
|
||||
"""Called to handle a batch of replication data with a given stream token.
|
||||
|
||||
By default this just pokes the slave store. Can be overridden in subclasses to
|
||||
handle more.
|
||||
|
||||
Args:
|
||||
stream_name (str): name of the replication stream for this batch of rows
|
||||
token (int): stream token for this batch of rows
|
||||
rows (list): a list of Stream.ROW_TYPE objects as returned by
|
||||
Stream.parse_row.
|
||||
"""
|
||||
if self.slaved_store:
|
||||
self.store.process_replication_rows(stream_name, token, rows)
|
||||
|
||||
if self.slaved_typing:
|
||||
self.typing_handler.process_replication_rows(stream_name, token, rows)
|
||||
|
||||
def get_streams_to_replicate(self) -> Dict[str, int]:
|
||||
"""Called when a new connection has been established and we need to
|
||||
subscribe to streams.
|
||||
|
||||
Returns:
|
||||
map from stream name to the most recent update we have for
|
||||
that stream (ie, the point we want to start replicating from)
|
||||
"""
|
||||
args = {} # type: Dict[str, int]
|
||||
|
||||
if self.slaved_store:
|
||||
args = self.store.stream_positions()
|
||||
|
||||
if self.slaved_typing:
|
||||
args.update(self.typing_handler.stream_positions())
|
||||
|
||||
return args
|
||||
|
||||
async def on_position(self, stream_name: str, token: int):
|
||||
if self.slaved_store:
|
||||
self.store.process_replication_rows(stream_name, token, [])
|
||||
|
||||
if self.slaved_typing:
|
||||
self.typing_handler.process_replication_rows(stream_name, token, [])
|
||||
@@ -35,9 +35,7 @@ indicate which side is sending, these are *not* included on the wire::
|
||||
> PING 1490197665618
|
||||
< NAME synapse.app.appservice
|
||||
< PING 1490197665618
|
||||
< REPLICATE events 1
|
||||
< REPLICATE backfill 1
|
||||
< REPLICATE caches 1
|
||||
< REPLICATE
|
||||
> POSITION events 1
|
||||
> POSITION backfill 1
|
||||
> POSITION caches 1
|
||||
@@ -48,45 +46,40 @@ indicate which side is sending, these are *not* included on the wire::
|
||||
> ERROR server stopping
|
||||
* connection closed by server *
|
||||
"""
|
||||
import abc
|
||||
import fcntl
|
||||
import logging
|
||||
import struct
|
||||
from collections import defaultdict
|
||||
from typing import Any, DefaultDict, Dict, List, Set, Tuple
|
||||
from typing import Any, DefaultDict, Dict, List, Set
|
||||
|
||||
from six import iteritems, iterkeys
|
||||
from six import iteritems
|
||||
|
||||
from prometheus_client import Counter
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.protocols.basic import LineOnlyReceiver
|
||||
from twisted.python.failure import Failure
|
||||
|
||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||
from synapse.metrics import LaterGauge
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.replication.tcp.commands import (
|
||||
COMMAND_MAP,
|
||||
VALID_CLIENT_COMMANDS,
|
||||
VALID_SERVER_COMMANDS,
|
||||
Command,
|
||||
ErrorCommand,
|
||||
NameCommand,
|
||||
PingCommand,
|
||||
PositionCommand,
|
||||
RdataCommand,
|
||||
RemoteServerUpCommand,
|
||||
ReplicateCommand,
|
||||
ServerCommand,
|
||||
SyncCommand,
|
||||
UserSyncCommand,
|
||||
)
|
||||
from synapse.replication.tcp.streams import STREAMS_MAP
|
||||
from synapse.types import Collection
|
||||
from synapse.replication.tcp.streams import STREAMS_MAP, Stream
|
||||
from synapse.util import Clock
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
MYPY = False
|
||||
if MYPY:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
|
||||
connection_close_counter = Counter(
|
||||
"synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
|
||||
)
|
||||
@@ -127,16 +120,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||
|
||||
delimiter = b"\n"
|
||||
|
||||
# Valid commands we expect to receive
|
||||
VALID_INBOUND_COMMANDS = [] # type: Collection[str]
|
||||
|
||||
# Valid commands we can send
|
||||
VALID_OUTBOUND_COMMANDS = [] # type: Collection[str]
|
||||
|
||||
max_line_buffer = 10000
|
||||
|
||||
def __init__(self, clock):
|
||||
def __init__(self, clock, handler):
|
||||
self.clock = clock
|
||||
self.handler = handler
|
||||
|
||||
self.last_received_command = self.clock.time_msec()
|
||||
self.last_sent_command = 0
|
||||
@@ -176,6 +164,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||
# can time us out.
|
||||
self.send_command(PingCommand(self.clock.time_msec()))
|
||||
|
||||
self.handler.new_connection(self)
|
||||
|
||||
def send_ping(self):
|
||||
"""Periodically sends a ping and checks if we should close the connection
|
||||
due to the other side timing out.
|
||||
@@ -213,11 +203,6 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||
line = line.decode("utf-8")
|
||||
cmd_name, rest_of_line = line.split(" ", 1)
|
||||
|
||||
if cmd_name not in self.VALID_INBOUND_COMMANDS:
|
||||
logger.error("[%s] invalid command %s", self.id(), cmd_name)
|
||||
self.send_error("invalid command: %s", cmd_name)
|
||||
return
|
||||
|
||||
self.last_received_command = self.clock.time_msec()
|
||||
|
||||
self.inbound_commands_counter[cmd_name] = (
|
||||
@@ -249,8 +234,23 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||
Args:
|
||||
cmd: received command
|
||||
"""
|
||||
handler = getattr(self, "on_%s" % (cmd.NAME,))
|
||||
await handler(cmd)
|
||||
handled = False
|
||||
|
||||
# First call any command handlers on this instance. These are for TCP
|
||||
# specific handling.
|
||||
cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
|
||||
if cmd_func:
|
||||
await cmd_func(cmd)
|
||||
handled = True
|
||||
|
||||
# Then call out to the handler.
|
||||
cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
|
||||
if cmd_func:
|
||||
await cmd_func(cmd)
|
||||
handled = True
|
||||
|
||||
if not handled:
|
||||
logger.warning("Unhandled command: %r", cmd)
|
||||
|
||||
def close(self):
|
||||
logger.warning("[%s] Closing connection", self.id())
|
||||
@@ -258,6 +258,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||
self.transport.loseConnection()
|
||||
self.on_connection_closed()
|
||||
|
||||
def send_remote_server_up(self, server: str):
|
||||
self.send_command(RemoteServerUpCommand(server))
|
||||
|
||||
def send_error(self, error_string, *args):
|
||||
"""Send an error to remote and close the connection.
|
||||
"""
|
||||
@@ -379,6 +382,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||
self.state = ConnectionStates.CLOSED
|
||||
self.pending_commands = []
|
||||
|
||||
self.handler.lost_connection(self)
|
||||
|
||||
if self.transport:
|
||||
self.transport.unregisterProducer()
|
||||
|
||||
@@ -402,346 +407,66 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||
|
||||
|
||||
class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
||||
VALID_INBOUND_COMMANDS = VALID_CLIENT_COMMANDS
|
||||
VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS
|
||||
|
||||
def __init__(self, server_name, clock, streamer):
|
||||
BaseReplicationStreamProtocol.__init__(self, clock) # Old style class
|
||||
def __init__(self, hs, server_name, clock, handler):
|
||||
BaseReplicationStreamProtocol.__init__(self, clock, handler) # Old style class
|
||||
|
||||
self.server_name = server_name
|
||||
self.streamer = streamer
|
||||
|
||||
# The streams the client has subscribed to and is up to date with
|
||||
self.replication_streams = set() # type: Set[str]
|
||||
|
||||
# The streams the client is currently subscribing to.
|
||||
self.connecting_streams = set() # type: Set[str]
|
||||
|
||||
# Map from stream name to list of updates to send once we've finished
|
||||
# subscribing the client to the stream.
|
||||
self.pending_rdata = {} # type: Dict[str, List[Tuple[int, Any]]]
|
||||
|
||||
def connectionMade(self):
|
||||
self.send_command(ServerCommand(self.server_name))
|
||||
BaseReplicationStreamProtocol.connectionMade(self)
|
||||
self.streamer.new_connection(self)
|
||||
|
||||
async def on_NAME(self, cmd):
|
||||
logger.info("[%s] Renamed to %r", self.id(), cmd.data)
|
||||
self.name = cmd.data
|
||||
|
||||
async def on_USER_SYNC(self, cmd):
|
||||
await self.streamer.on_user_sync(
|
||||
self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
|
||||
)
|
||||
|
||||
async def on_REPLICATE(self, cmd):
|
||||
stream_name = cmd.stream_name
|
||||
token = cmd.token
|
||||
|
||||
if stream_name == "ALL":
|
||||
# Subscribe to all streams we're publishing to.
|
||||
deferreds = [
|
||||
run_in_background(self.subscribe_to_stream, stream, token)
|
||||
for stream in iterkeys(self.streamer.streams_by_name)
|
||||
]
|
||||
|
||||
await make_deferred_yieldable(
|
||||
defer.gatherResults(deferreds, consumeErrors=True)
|
||||
)
|
||||
else:
|
||||
await self.subscribe_to_stream(stream_name, token)
|
||||
|
||||
async def on_FEDERATION_ACK(self, cmd):
|
||||
self.streamer.federation_ack(cmd.token)
|
||||
|
||||
async def on_REMOVE_PUSHER(self, cmd):
|
||||
await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
|
||||
|
||||
async def on_INVALIDATE_CACHE(self, cmd):
|
||||
await self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
|
||||
|
||||
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
|
||||
self.streamer.on_remote_server_up(cmd.data)
|
||||
|
||||
async def on_USER_IP(self, cmd):
|
||||
self.streamer.on_user_ip(
|
||||
cmd.user_id,
|
||||
cmd.access_token,
|
||||
cmd.ip,
|
||||
cmd.user_agent,
|
||||
cmd.device_id,
|
||||
cmd.last_seen,
|
||||
)
|
||||
|
||||
async def subscribe_to_stream(self, stream_name, token):
|
||||
"""Subscribe the remote to a stream.
|
||||
|
||||
This invloves checking if they've missed anything and sending those
|
||||
updates down if they have. During that time new updates for the stream
|
||||
are queued and sent once we've sent down any missed updates.
|
||||
"""
|
||||
self.replication_streams.discard(stream_name)
|
||||
self.connecting_streams.add(stream_name)
|
||||
|
||||
try:
|
||||
# Get missing updates
|
||||
updates, current_token = await self.streamer.get_stream_updates(
|
||||
stream_name, token
|
||||
)
|
||||
|
||||
# Send all the missing updates
|
||||
for update in updates:
|
||||
token, row = update[0], update[1]
|
||||
self.send_command(RdataCommand(stream_name, token, row))
|
||||
|
||||
# We send a POSITION command to ensure that they have an up to
|
||||
# date token (especially useful if we didn't send any updates
|
||||
# above)
|
||||
self.send_command(PositionCommand(stream_name, current_token))
|
||||
|
||||
# Now we can send any updates that came in while we were subscribing
|
||||
pending_rdata = self.pending_rdata.pop(stream_name, [])
|
||||
updates = []
|
||||
for token, update in pending_rdata:
|
||||
# If the token is null, it is part of a batch update. Batches
|
||||
# are multiple updates that share a single token. To denote
|
||||
# this, the token is set to None for all tokens in the batch
|
||||
# except for the last. If we find a None token, we keep looking
|
||||
# through tokens until we find one that is not None and then
|
||||
# process all previous updates in the batch as if they had the
|
||||
# final token.
|
||||
if token is None:
|
||||
# Store this update as part of a batch
|
||||
updates.append(update)
|
||||
continue
|
||||
|
||||
if token <= current_token:
|
||||
# This update or batch of updates is older than
|
||||
# current_token, dismiss it
|
||||
updates = []
|
||||
continue
|
||||
|
||||
updates.append(update)
|
||||
|
||||
# Send all updates that are part of this batch with the
|
||||
# found token
|
||||
for update in updates:
|
||||
self.send_command(RdataCommand(stream_name, token, update))
|
||||
|
||||
# Clear stored updates
|
||||
updates = []
|
||||
|
||||
# They're now fully subscribed
|
||||
self.replication_streams.add(stream_name)
|
||||
except Exception as e:
|
||||
logger.exception("[%s] Failed to handle REPLICATE command", self.id())
|
||||
self.send_error("failed to handle replicate: %r", e)
|
||||
finally:
|
||||
self.connecting_streams.discard(stream_name)
|
||||
|
||||
def stream_update(self, stream_name, token, data):
|
||||
"""Called when a new update is available to stream to clients.
|
||||
|
||||
We need to check if the client is interested in the stream or not
|
||||
"""
|
||||
if stream_name in self.replication_streams:
|
||||
# The client is subscribed to the stream
|
||||
self.send_command(RdataCommand(stream_name, token, data))
|
||||
elif stream_name in self.connecting_streams:
|
||||
# The client is being subscribed to the stream
|
||||
logger.debug("[%s] Queuing RDATA %r %r", self.id(), stream_name, token)
|
||||
self.pending_rdata.setdefault(stream_name, []).append((token, data))
|
||||
else:
|
||||
# The client isn't subscribed
|
||||
logger.debug("[%s] Dropping RDATA %r %r", self.id(), stream_name, token)
|
||||
|
||||
def send_sync(self, data):
|
||||
self.send_command(SyncCommand(data))
|
||||
|
||||
def send_remote_server_up(self, server: str):
|
||||
self.send_command(RemoteServerUpCommand(server))
|
||||
|
||||
def on_connection_closed(self):
|
||||
BaseReplicationStreamProtocol.on_connection_closed(self)
|
||||
self.streamer.lost_connection(self)
|
||||
|
||||
|
||||
class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
The interface for the handler that should be passed to
|
||||
ClientReplicationStreamProtocol
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def on_rdata(self, stream_name, token, rows):
|
||||
"""Called to handle a batch of replication data with a given stream token.
|
||||
|
||||
Args:
|
||||
stream_name (str): name of the replication stream for this batch of rows
|
||||
token (int): stream token for this batch of rows
|
||||
rows (list): a list of Stream.ROW_TYPE objects as returned by
|
||||
Stream.parse_row.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
async def on_position(self, stream_name, token):
|
||||
"""Called when we get new position data."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_sync(self, data):
|
||||
"""Called when get a new SYNC command."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
async def on_remote_server_up(self, server: str):
|
||||
"""Called when get a new REMOTE_SERVER_UP command."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_streams_to_replicate(self):
|
||||
"""Called when a new connection has been established and we need to
|
||||
subscribe to streams.
|
||||
|
||||
Returns:
|
||||
map from stream name to the most recent update we have for
|
||||
that stream (ie, the point we want to start replicating from)
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_currently_syncing_users(self):
|
||||
"""Get the list of currently syncing users (if any). This is called
|
||||
when a connection has been established and we need to send the
|
||||
currently syncing users."""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_connection(self, connection):
|
||||
"""Called when a connection has been established (or lost with None).
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def finished_connecting(self):
|
||||
"""Called when we have successfully subscribed and caught up to all
|
||||
streams we're interested in.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
||||
VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
|
||||
VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hs: "HomeServer",
|
||||
client_name: str,
|
||||
server_name: str,
|
||||
clock: Clock,
|
||||
handler: AbstractReplicationClientHandler,
|
||||
handler,
|
||||
):
|
||||
BaseReplicationStreamProtocol.__init__(self, clock)
|
||||
BaseReplicationStreamProtocol.__init__(self, clock, handler)
|
||||
|
||||
self.instance_id = hs.get_instance_id()
|
||||
|
||||
self.client_name = client_name
|
||||
self.server_name = server_name
|
||||
self.handler = handler
|
||||
|
||||
self.streams = {
|
||||
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
|
||||
} # type: Dict[str, Stream]
|
||||
|
||||
# Set of stream names that have been subscribe to, but haven't yet
|
||||
# caught up with. This is used to track when the client has been fully
|
||||
# connected to the remote.
|
||||
self.streams_connecting = set() # type: Set[str]
|
||||
self.streams_connecting = set(STREAMS_MAP) # type: Set[str]
|
||||
|
||||
# Map of stream to batched updates. See RdataCommand for info on how
|
||||
# batching works.
|
||||
self.pending_batches = {} # type: Dict[str, Any]
|
||||
self.pending_batches = {} # type: Dict[str, List[Any]]
|
||||
|
||||
def connectionMade(self):
|
||||
self.send_command(NameCommand(self.client_name))
|
||||
BaseReplicationStreamProtocol.connectionMade(self)
|
||||
|
||||
# Once we've connected subscribe to the necessary streams
|
||||
for stream_name, token in iteritems(self.handler.get_streams_to_replicate()):
|
||||
self.replicate(stream_name, token)
|
||||
|
||||
# Tell the server if we have any users currently syncing (should only
|
||||
# happen on synchrotrons)
|
||||
currently_syncing = self.handler.get_currently_syncing_users()
|
||||
now = self.clock.time_msec()
|
||||
for user_id in currently_syncing:
|
||||
self.send_command(UserSyncCommand(user_id, True, now))
|
||||
|
||||
# We've now finished connecting to so inform the client handler
|
||||
self.handler.update_connection(self)
|
||||
|
||||
# This will happen if we don't actually subscribe to any streams
|
||||
if not self.streams_connecting:
|
||||
self.handler.finished_connecting()
|
||||
self.send_command(NameCommand(self.client_name))
|
||||
self.replicate()
|
||||
|
||||
async def on_SERVER(self, cmd):
|
||||
if cmd.data != self.server_name:
|
||||
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
|
||||
self.send_error("Wrong remote")
|
||||
|
||||
async def on_RDATA(self, cmd):
|
||||
stream_name = cmd.stream_name
|
||||
inbound_rdata_count.labels(stream_name).inc()
|
||||
|
||||
try:
|
||||
row = STREAMS_MAP[stream_name].parse_row(cmd.row)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row
|
||||
)
|
||||
raise
|
||||
|
||||
if cmd.token is None:
|
||||
# I.e. this is part of a batch of updates for this stream. Batch
|
||||
# until we get an update for the stream with a non None token
|
||||
self.pending_batches.setdefault(stream_name, []).append(row)
|
||||
else:
|
||||
# Check if this is the last of a batch of updates
|
||||
rows = self.pending_batches.pop(stream_name, [])
|
||||
rows.append(row)
|
||||
await self.handler.on_rdata(stream_name, cmd.token, rows)
|
||||
|
||||
async def on_POSITION(self, cmd):
|
||||
# When we get a `POSITION` command it means we've finished getting
|
||||
# missing updates for the given stream, and are now up to date.
|
||||
self.streams_connecting.discard(cmd.stream_name)
|
||||
if not self.streams_connecting:
|
||||
self.handler.finished_connecting()
|
||||
|
||||
await self.handler.on_position(cmd.stream_name, cmd.token)
|
||||
|
||||
async def on_SYNC(self, cmd):
|
||||
self.handler.on_sync(cmd.data)
|
||||
|
||||
async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
|
||||
self.handler.on_remote_server_up(cmd.data)
|
||||
|
||||
def replicate(self, stream_name, token):
|
||||
def replicate(self):
|
||||
"""Send the subscription request to the server
|
||||
"""
|
||||
if stream_name not in STREAMS_MAP:
|
||||
raise Exception("Invalid stream name %r" % (stream_name,))
|
||||
logger.info("[%s] Subscribing to replication streams", self.id())
|
||||
|
||||
logger.info(
|
||||
"[%s] Subscribing to replication stream: %r from %r",
|
||||
self.id(),
|
||||
stream_name,
|
||||
token,
|
||||
)
|
||||
|
||||
self.streams_connecting.add(stream_name)
|
||||
|
||||
self.send_command(ReplicateCommand(stream_name, token))
|
||||
|
||||
def on_connection_closed(self):
|
||||
BaseReplicationStreamProtocol.on_connection_closed(self)
|
||||
self.handler.update_connection(None)
|
||||
self.send_command(ReplicateCommand())
|
||||
|
||||
|
||||
# The following simply registers metrics for the replication connections
|
||||
|
||||
158
synapse/replication/tcp/redis.py
Normal file
158
synapse/replication/tcp/redis.py
Normal file
@@ -0,0 +1,158 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
import txredisapi
|
||||
|
||||
from synapse.logging.context import PreserveLoggingContext
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.replication.tcp.commands import (
|
||||
COMMAND_MAP,
|
||||
Command,
|
||||
RdataCommand,
|
||||
ReplicateCommand,
|
||||
)
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RedisSubscriber(txredisapi.SubscriberProtocol):
|
||||
"""Connection to redis subscribed to replication stream.
|
||||
"""
|
||||
|
||||
def connectionMade(self):
|
||||
logger.info("Connected to redis instance")
|
||||
self.subscribe(self.stream_name)
|
||||
self.send_command(ReplicateCommand())
|
||||
|
||||
self.handler.new_connection(self)
|
||||
|
||||
def messageReceived(self, pattern: str, channel: str, message: str):
|
||||
"""Received a message from redis.
|
||||
"""
|
||||
|
||||
if message.strip() == "":
|
||||
# Ignore blank lines
|
||||
return
|
||||
|
||||
line = message
|
||||
cmd_name, rest_of_line = line.split(" ", 1)
|
||||
|
||||
cmd_cls = COMMAND_MAP[cmd_name]
|
||||
try:
|
||||
cmd = cmd_cls.from_line(rest_of_line)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"[%s] failed to parse line %r: %r", self.id(), cmd_name, rest_of_line
|
||||
)
|
||||
self.send_error(
|
||||
"failed to parse line for %r: %r (%r):" % (cmd_name, e, rest_of_line)
|
||||
)
|
||||
return
|
||||
|
||||
# Now lets try and call on_<CMD_NAME> function
|
||||
run_as_background_process(
|
||||
"replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
|
||||
)
|
||||
|
||||
async def handle_command(self, cmd: Command):
|
||||
"""Handle a command we have received over the replication stream.
|
||||
|
||||
By default delegates to on_<COMMAND>, which should return an awaitable.
|
||||
|
||||
Args:
|
||||
cmd: received command
|
||||
"""
|
||||
handled = False
|
||||
|
||||
# First call any command handlers on this instance. These are for redis
|
||||
# specific handling.
|
||||
cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
|
||||
if cmd_func:
|
||||
await cmd_func(cmd)
|
||||
handled = True
|
||||
|
||||
# Then call out to the handler.
|
||||
cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
|
||||
if cmd_func:
|
||||
await cmd_func(cmd)
|
||||
handled = True
|
||||
|
||||
if not handled:
|
||||
logger.warning("Unhandled command: %r", cmd)
|
||||
|
||||
def connectionLost(self, reason):
|
||||
logger.info("Lost connection to redis instance")
|
||||
self.handler.lost_connection(self)
|
||||
|
||||
def send_command(self, cmd):
|
||||
"""Send a command if connection has been established.
|
||||
|
||||
Args:
|
||||
cmd (Command)
|
||||
"""
|
||||
string = "%s %s" % (cmd.NAME, cmd.to_line())
|
||||
if "\n" in string:
|
||||
raise Exception("Unexpected newline in command: %r", string)
|
||||
|
||||
encoded_string = string.encode("utf-8")
|
||||
|
||||
async def _send():
|
||||
with PreserveLoggingContext():
|
||||
await self.redis_connection.publish(self.stream_name, encoded_string)
|
||||
|
||||
run_as_background_process("send-cmd", _send)
|
||||
|
||||
def stream_update(self, stream_name, token, data):
|
||||
"""Called when a new update is available to stream to clients.
|
||||
|
||||
We need to check if the client is interested in the stream or not
|
||||
"""
|
||||
self.send_command(RdataCommand(stream_name, "master", token, data))
|
||||
|
||||
|
||||
class RedisFactory(txredisapi.SubscriberFactory):
|
||||
|
||||
maxDelay = 5
|
||||
continueTrying = True
|
||||
protocol = RedisSubscriber
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RedisFactory, self).__init__()
|
||||
|
||||
self.password = hs.config.redis.redis_password
|
||||
|
||||
self.handler = hs.get_tcp_replication()
|
||||
self.stream_name = hs.hostname
|
||||
|
||||
self.redis_connection = txredisapi.lazyConnection(
|
||||
host=hs.config.redis_host,
|
||||
port=hs.config.redis_port,
|
||||
dbid=hs.config.redis_dbid,
|
||||
password=hs.config.redis.redis_password,
|
||||
reconnect=True,
|
||||
)
|
||||
|
||||
self.conn_id = random_string(5)
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
p = super(RedisFactory, self).buildProtocol(addr)
|
||||
p.handler = self.handler
|
||||
p.redis_connection = self.redis_connection
|
||||
p.conn_id = self.conn_id
|
||||
p.stream_name = self.stream_name
|
||||
return p
|
||||
@@ -17,32 +17,21 @@
|
||||
|
||||
import logging
|
||||
import random
|
||||
from typing import Any, List
|
||||
|
||||
from six import itervalues
|
||||
from typing import Dict, List
|
||||
|
||||
from prometheus_client import Counter
|
||||
|
||||
from twisted.internet.protocol import Factory
|
||||
|
||||
from synapse.metrics import LaterGauge
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.util.metrics import Measure, measure_func
|
||||
|
||||
from .protocol import ServerReplicationStreamProtocol
|
||||
from .streams import STREAMS_MAP
|
||||
from .streams.federation import FederationStream
|
||||
from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
|
||||
from synapse.replication.tcp.streams import STREAMS_MAP, Stream, TypingStream
|
||||
from synapse.replication.tcp.streams.federation import FederationStream
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
stream_updates_counter = Counter(
|
||||
"synapse_replication_tcp_resource_stream_updates", "", ["stream_name"]
|
||||
)
|
||||
user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
|
||||
federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
|
||||
remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
|
||||
invalidate_cache_counter = Counter(
|
||||
"synapse_replication_tcp_resource_invalidate_cache", ""
|
||||
)
|
||||
user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -52,13 +41,18 @@ class ReplicationStreamProtocolFactory(Factory):
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
self.streamer = ReplicationStreamer(hs)
|
||||
self.handler = hs.get_tcp_replication()
|
||||
self.clock = hs.get_clock()
|
||||
self.server_name = hs.config.server_name
|
||||
self.hs = hs
|
||||
|
||||
# Ensure the replication streamer is started if we register a
|
||||
# replication server endpoint.
|
||||
hs.get_replication_streamer()
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
return ServerReplicationStreamProtocol(
|
||||
self.server_name, self.clock, self.streamer
|
||||
self.hs, self.server_name, self.clock, self.handler
|
||||
)
|
||||
|
||||
|
||||
@@ -71,67 +65,43 @@ class ReplicationStreamer(object):
|
||||
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
self.presence_handler = hs.get_presence_handler()
|
||||
self.clock = hs.get_clock()
|
||||
self.notifier = hs.get_notifier()
|
||||
self._server_notices_sender = hs.get_server_notices_sender()
|
||||
|
||||
self._replication_torture_level = hs.config.replication_torture_level
|
||||
|
||||
# Current connections.
|
||||
self.connections = [] # type: List[ServerReplicationStreamProtocol]
|
||||
# Work out list of streams that this instance is the source of.
|
||||
self.streams = [] # type: List[Stream]
|
||||
if hs.config.worker_app is None:
|
||||
for stream in STREAMS_MAP.values():
|
||||
if stream == FederationStream:
|
||||
continue
|
||||
|
||||
LaterGauge(
|
||||
"synapse_replication_tcp_resource_total_connections",
|
||||
"",
|
||||
[],
|
||||
lambda: len(self.connections),
|
||||
)
|
||||
if stream == TypingStream:
|
||||
continue
|
||||
|
||||
# List of streams that clients can subscribe to.
|
||||
# We only support federation stream if federation sending hase been
|
||||
# disabled on the master.
|
||||
self.streams = [
|
||||
stream(hs)
|
||||
for stream in itervalues(STREAMS_MAP)
|
||||
if stream != FederationStream or not hs.config.send_federation
|
||||
]
|
||||
self.streams.append(stream(hs))
|
||||
|
||||
if hs.config.server.handle_typing:
|
||||
self.streams.append(TypingStream(hs))
|
||||
|
||||
# We always add federation stream
|
||||
self.streams.append(FederationStream(hs))
|
||||
|
||||
self.streams_by_name = {stream.NAME: stream for stream in self.streams}
|
||||
|
||||
LaterGauge(
|
||||
"synapse_replication_tcp_resource_connections_per_stream",
|
||||
"",
|
||||
["stream_name"],
|
||||
lambda: {
|
||||
(stream_name,): len(
|
||||
[
|
||||
conn
|
||||
for conn in self.connections
|
||||
if stream_name in conn.replication_streams
|
||||
]
|
||||
)
|
||||
for stream_name in self.streams_by_name
|
||||
},
|
||||
)
|
||||
|
||||
self.federation_sender = None
|
||||
if not hs.config.send_federation:
|
||||
self.federation_sender = hs.get_federation_sender()
|
||||
|
||||
self.notifier.add_replication_callback(self.on_notifier_poke)
|
||||
self.notifier.add_remote_server_up_callback(self.send_remote_server_up)
|
||||
|
||||
# Keeps track of whether we are currently checking for updates
|
||||
self.is_looping = False
|
||||
self.pending_updates = False
|
||||
|
||||
hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.on_shutdown)
|
||||
self.client = hs.get_tcp_replication()
|
||||
|
||||
def on_shutdown(self):
|
||||
# close all connections on shutdown
|
||||
for conn in self.connections:
|
||||
conn.send_error("server shutting down")
|
||||
def get_streams(self) -> Dict[str, Stream]:
|
||||
"""Get a mapp from stream name to stream instance.
|
||||
"""
|
||||
return self.streams_by_name
|
||||
|
||||
def on_notifier_poke(self):
|
||||
"""Checks if there is actually any new data and sends it to the
|
||||
@@ -140,7 +110,7 @@ class ReplicationStreamer(object):
|
||||
This should get called each time new data is available, even if it
|
||||
is currently being executed, so that nothing gets missed
|
||||
"""
|
||||
if not self.connections:
|
||||
if not self.client.connected():
|
||||
# Don't bother if nothing is listening. We still need to advance
|
||||
# the stream tokens otherwise they'll fall beihind forever
|
||||
for stream in self.streams:
|
||||
@@ -190,15 +160,14 @@ class ReplicationStreamer(object):
|
||||
stream.current_token(),
|
||||
)
|
||||
try:
|
||||
updates, current_token = await stream.get_updates()
|
||||
updates, current_token, limited = await stream.get_updates()
|
||||
self.pending_updates |= limited
|
||||
except Exception:
|
||||
logger.info("Failed to handle stream %s", stream.NAME)
|
||||
raise
|
||||
|
||||
logger.debug(
|
||||
"Sending %d updates to %d connections",
|
||||
len(updates),
|
||||
len(self.connections),
|
||||
"Sending %d updates", len(updates),
|
||||
)
|
||||
|
||||
if updates:
|
||||
@@ -214,116 +183,17 @@ class ReplicationStreamer(object):
|
||||
# token. See RdataCommand for more details.
|
||||
batched_updates = _batch_updates(updates)
|
||||
|
||||
for conn in self.connections:
|
||||
for token, row in batched_updates:
|
||||
try:
|
||||
conn.stream_update(stream.NAME, token, row)
|
||||
except Exception:
|
||||
logger.exception("Failed to replicate")
|
||||
for token, row in batched_updates:
|
||||
try:
|
||||
self.client.stream_update(stream.NAME, token, row)
|
||||
except Exception:
|
||||
logger.exception("Failed to replicate")
|
||||
|
||||
logger.debug("No more pending updates, breaking poke loop")
|
||||
finally:
|
||||
self.pending_updates = False
|
||||
self.is_looping = False
|
||||
|
||||
@measure_func("repl.get_stream_updates")
|
||||
async def get_stream_updates(self, stream_name, token):
|
||||
"""For a given stream get all updates since token. This is called when
|
||||
a client first subscribes to a stream.
|
||||
"""
|
||||
stream = self.streams_by_name.get(stream_name, None)
|
||||
if not stream:
|
||||
raise Exception("unknown stream %s", stream_name)
|
||||
|
||||
return await stream.get_updates_since(token)
|
||||
|
||||
@measure_func("repl.federation_ack")
|
||||
def federation_ack(self, token):
|
||||
"""We've received an ack for federation stream from a client.
|
||||
"""
|
||||
federation_ack_counter.inc()
|
||||
if self.federation_sender:
|
||||
self.federation_sender.federation_ack(token)
|
||||
|
||||
@measure_func("repl.on_user_sync")
|
||||
async def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
|
||||
"""A client has started/stopped syncing on a worker.
|
||||
"""
|
||||
user_sync_counter.inc()
|
||||
await self.presence_handler.update_external_syncs_row(
|
||||
conn_id, user_id, is_syncing, last_sync_ms
|
||||
)
|
||||
|
||||
@measure_func("repl.on_remove_pusher")
|
||||
async def on_remove_pusher(self, app_id, push_key, user_id):
|
||||
"""A client has asked us to remove a pusher
|
||||
"""
|
||||
remove_pusher_counter.inc()
|
||||
await self.store.delete_pusher_by_app_id_pushkey_user_id(
|
||||
app_id=app_id, pushkey=push_key, user_id=user_id
|
||||
)
|
||||
|
||||
self.notifier.on_new_replication_data()
|
||||
|
||||
@measure_func("repl.on_invalidate_cache")
|
||||
async def on_invalidate_cache(self, cache_func: str, keys: List[Any]):
|
||||
"""The client has asked us to invalidate a cache
|
||||
"""
|
||||
invalidate_cache_counter.inc()
|
||||
|
||||
# We invalidate the cache locally, but then also stream that to other
|
||||
# workers.
|
||||
await self.store.invalidate_cache_and_stream(cache_func, tuple(keys))
|
||||
|
||||
@measure_func("repl.on_user_ip")
|
||||
async def on_user_ip(
|
||||
self, user_id, access_token, ip, user_agent, device_id, last_seen
|
||||
):
|
||||
"""The client saw a user request
|
||||
"""
|
||||
user_ip_cache_counter.inc()
|
||||
await self.store.insert_client_ip(
|
||||
user_id, access_token, ip, user_agent, device_id, last_seen
|
||||
)
|
||||
await self._server_notices_sender.on_user_ip(user_id)
|
||||
|
||||
@measure_func("repl.on_remote_server_up")
|
||||
def on_remote_server_up(self, server: str):
|
||||
self.notifier.notify_remote_server_up(server)
|
||||
|
||||
def send_remote_server_up(self, server: str):
|
||||
for conn in self.connections:
|
||||
conn.send_remote_server_up(server)
|
||||
|
||||
def send_sync_to_all_connections(self, data):
|
||||
"""Sends a SYNC command to all clients.
|
||||
|
||||
Used in tests.
|
||||
"""
|
||||
for conn in self.connections:
|
||||
conn.send_sync(data)
|
||||
|
||||
def new_connection(self, connection):
|
||||
"""A new client connection has been established
|
||||
"""
|
||||
self.connections.append(connection)
|
||||
|
||||
def lost_connection(self, connection):
|
||||
"""A client connection has been lost
|
||||
"""
|
||||
try:
|
||||
self.connections.remove(connection)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# We need to tell the presence handler that the connection has been
|
||||
# lost so that it can handle any ongoing syncs on that connection.
|
||||
run_as_background_process(
|
||||
"update_external_syncs_clear",
|
||||
self.presence_handler.update_external_syncs_clear,
|
||||
connection.conn_id,
|
||||
)
|
||||
|
||||
|
||||
def _batch_updates(updates):
|
||||
"""Takes a list of updates of form [(token, row)] and sets the token to
|
||||
|
||||
@@ -24,6 +24,9 @@ Each stream is defined by the following information:
|
||||
current_token: The function that returns the current token for the stream
|
||||
update_function: The function that returns a list of updates between two tokens
|
||||
"""
|
||||
|
||||
from typing import Dict, Type
|
||||
|
||||
from synapse.replication.tcp.streams._base import (
|
||||
AccountDataStream,
|
||||
BackfillStream,
|
||||
@@ -35,6 +38,7 @@ from synapse.replication.tcp.streams._base import (
|
||||
PushersStream,
|
||||
PushRulesStream,
|
||||
ReceiptsStream,
|
||||
Stream,
|
||||
TagAccountDataStream,
|
||||
ToDeviceStream,
|
||||
TypingStream,
|
||||
@@ -63,10 +67,12 @@ STREAMS_MAP = {
|
||||
GroupServerStream,
|
||||
UserSignatureStream,
|
||||
)
|
||||
}
|
||||
} # type: Dict[str, Type[Stream]]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"STREAMS_MAP",
|
||||
"Stream",
|
||||
"BackfillStream",
|
||||
"PresenceStream",
|
||||
"TypingStream",
|
||||
|
||||
@@ -14,13 +14,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from typing import Any, List, Optional, Tuple
|
||||
from typing import Any, Awaitable, Callable, List, Optional, Tuple
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.replication.http.streams import ReplicationGetStreamUpdates
|
||||
from synapse.types import JsonDict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -29,6 +29,15 @@ logger = logging.getLogger(__name__)
|
||||
MAX_EVENTS_BEHIND = 500000
|
||||
|
||||
|
||||
# Some type aliases to make things a bit easier.
|
||||
|
||||
# A stream position token
|
||||
Token = int
|
||||
|
||||
# A pair of position in stream and args used to create an instance of `ROW_TYPE`.
|
||||
StreamRow = Tuple[Token, tuple]
|
||||
|
||||
|
||||
class Stream(object):
|
||||
"""Base class for the streams.
|
||||
|
||||
@@ -56,70 +65,58 @@ class Stream(object):
|
||||
return cls.ROW_TYPE(*row)
|
||||
|
||||
def __init__(self, hs):
|
||||
|
||||
# The token from which we last asked for updates
|
||||
self.last_token = self.current_token()
|
||||
|
||||
self.local_instance_name = hs.config.worker_name or "master"
|
||||
|
||||
def discard_updates_and_advance(self):
|
||||
"""Called when the stream should advance but the updates would be discarded,
|
||||
e.g. when there are no currently connected workers.
|
||||
"""
|
||||
self.last_token = self.current_token()
|
||||
|
||||
async def get_updates(self):
|
||||
async def get_updates(self) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
|
||||
"""Gets all updates since the last time this function was called (or
|
||||
since the stream was constructed if it hadn't been called before).
|
||||
|
||||
Returns:
|
||||
Deferred[Tuple[List[Tuple[int, Any]], int]:
|
||||
Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
|
||||
list of ``(token, row)`` entries. ``row`` will be json-serialised and
|
||||
sent over the replication steam.
|
||||
A triplet `(updates, new_last_token, limited)`, where `updates` is
|
||||
a list of `(token, row)` entries, `new_last_token` is the new
|
||||
position in stream, and `limited` is whether there are more updates
|
||||
to fetch.
|
||||
"""
|
||||
updates, current_token = await self.get_updates_since(self.last_token)
|
||||
current_token = self.current_token()
|
||||
updates, current_token, limited = await self.get_updates_since(
|
||||
self.local_instance_name, self.last_token, current_token
|
||||
)
|
||||
self.last_token = current_token
|
||||
|
||||
return updates, current_token
|
||||
return updates, current_token, limited
|
||||
|
||||
async def get_updates_since(
|
||||
self, from_token: int
|
||||
) -> Tuple[List[Tuple[int, JsonDict]], int]:
|
||||
self, instance_name: str, from_token: Token, upto_token: Token, limit: int = 100
|
||||
) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]:
|
||||
"""Like get_updates except allows specifying from when we should
|
||||
stream updates
|
||||
|
||||
Returns:
|
||||
Resolves to a pair `(updates, new_last_token)`, where `updates` is
|
||||
a list of `(token, row)` entries and `new_last_token` is the new
|
||||
position in stream.
|
||||
A triplet `(updates, new_last_token, limited)`, where `updates` is
|
||||
a list of `(token, row)` entries, `new_last_token` is the new
|
||||
position in stream, and `limited` is whether there are more updates
|
||||
to fetch.
|
||||
"""
|
||||
|
||||
if from_token in ("NOW", "now"):
|
||||
return [], self.current_token()
|
||||
|
||||
current_token = self.current_token()
|
||||
|
||||
from_token = int(from_token)
|
||||
|
||||
if from_token == current_token:
|
||||
return [], current_token
|
||||
if from_token == upto_token:
|
||||
return [], upto_token, False
|
||||
|
||||
rows = await self.update_function(
|
||||
from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
|
||||
updates, upto_token, limited = await self.update_function(
|
||||
instance_name, from_token, upto_token, limit=limit,
|
||||
)
|
||||
|
||||
# never turn more than MAX_EVENTS_BEHIND + 1 into updates.
|
||||
rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
|
||||
|
||||
updates = [(row[0], row[1:]) for row in rows]
|
||||
|
||||
# check we didn't get more rows than the limit.
|
||||
# doing it like this allows the update_function to be a generator.
|
||||
if len(updates) >= MAX_EVENTS_BEHIND:
|
||||
raise Exception("stream %s has fallen behind" % (self.NAME))
|
||||
|
||||
# The update function didn't hit the limit, so we must have got all
|
||||
# the updates to `current_token`, and can return that as our new
|
||||
# stream position.
|
||||
return updates, current_token
|
||||
return updates, upto_token, limited
|
||||
|
||||
def current_token(self):
|
||||
"""Gets the current token of the underlying streams. Should be provided
|
||||
@@ -141,6 +138,49 @@ class Stream(object):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def db_query_to_update_function(
|
||||
query_function: Callable[[str, Token, Token, int], Awaitable[List[tuple]]]
|
||||
) -> Callable[[str, Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
|
||||
"""Wraps a db query function which returns a list of rows to make it
|
||||
suitable for use as an `update_function` for the Stream class
|
||||
"""
|
||||
|
||||
async def update_function(instance_name, from_token, upto_token, limit):
|
||||
rows = await query_function(from_token, upto_token, limit)
|
||||
updates = [(row[0], row[1:]) for row in rows]
|
||||
limited = False
|
||||
if len(updates) == limit:
|
||||
upto_token = rows[-1][0]
|
||||
limited = True
|
||||
|
||||
return updates, upto_token, limited
|
||||
|
||||
return update_function
|
||||
|
||||
|
||||
def make_http_update_function(
|
||||
hs, stream_name: str
|
||||
) -> Callable[[Token, Token, Token], Awaitable[Tuple[List[StreamRow], Token, bool]]]:
|
||||
"""Makes a suitable function for use as an `update_function` that queries
|
||||
the master process for updates.
|
||||
"""
|
||||
|
||||
client = ReplicationGetStreamUpdates.make_client(hs)
|
||||
|
||||
async def update_function(
|
||||
instance_name: str, from_token: int, upto_token: int, limit: int
|
||||
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
|
||||
return await client(
|
||||
instance_name=instance_name,
|
||||
stream_name=stream_name,
|
||||
from_token=from_token,
|
||||
upto_token=upto_token,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
return update_function
|
||||
|
||||
|
||||
class BackfillStream(Stream):
|
||||
"""We fetched some old events and either we had never seen that event before
|
||||
or it went from being an outlier to not.
|
||||
@@ -164,7 +204,7 @@ class BackfillStream(Stream):
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
self.current_token = store.get_current_backfill_token # type: ignore
|
||||
self.update_function = store.get_all_new_backfill_event_rows # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_new_backfill_event_rows) # type: ignore
|
||||
|
||||
super(BackfillStream, self).__init__(hs)
|
||||
|
||||
@@ -190,8 +230,15 @@ class PresenceStream(Stream):
|
||||
store = hs.get_datastore()
|
||||
presence_handler = hs.get_presence_handler()
|
||||
|
||||
self._is_worker = hs.config.worker_app is not None
|
||||
|
||||
self.current_token = store.get_current_presence_token # type: ignore
|
||||
self.update_function = presence_handler.get_all_presence_updates # type: ignore
|
||||
|
||||
if hs.config.worker_app is None:
|
||||
self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates) # type: ignore
|
||||
else:
|
||||
# Query master process
|
||||
self.update_function = make_http_update_function(hs, self.NAME) # type: ignore
|
||||
|
||||
super(PresenceStream, self).__init__(hs)
|
||||
|
||||
@@ -208,7 +255,12 @@ class TypingStream(Stream):
|
||||
typing_handler = hs.get_typing_handler()
|
||||
|
||||
self.current_token = typing_handler.get_current_token # type: ignore
|
||||
self.update_function = typing_handler.get_all_typing_updates # type: ignore
|
||||
|
||||
if hs.config.handle_typing:
|
||||
self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates) # type: ignore
|
||||
else:
|
||||
# Query master process
|
||||
self.update_function = make_http_update_function(hs, self.NAME) # type: ignore
|
||||
|
||||
super(TypingStream, self).__init__(hs)
|
||||
|
||||
@@ -232,7 +284,7 @@ class ReceiptsStream(Stream):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_max_receipt_stream_id # type: ignore
|
||||
self.update_function = store.get_all_updated_receipts # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_updated_receipts) # type: ignore
|
||||
|
||||
super(ReceiptsStream, self).__init__(hs)
|
||||
|
||||
@@ -256,7 +308,13 @@ class PushRulesStream(Stream):
|
||||
|
||||
async def update_function(self, from_token, to_token, limit):
|
||||
rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
|
||||
return [(row[0], row[2]) for row in rows]
|
||||
|
||||
limited = False
|
||||
if len(rows) == limit:
|
||||
to_token = rows[-1][0]
|
||||
limited = True
|
||||
|
||||
return [(row[0], (row[2],)) for row in rows], to_token, limited
|
||||
|
||||
|
||||
class PushersStream(Stream):
|
||||
@@ -275,7 +333,7 @@ class PushersStream(Stream):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_pushers_stream_token # type: ignore
|
||||
self.update_function = store.get_all_updated_pushers_rows # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_updated_pushers_rows) # type: ignore
|
||||
|
||||
super(PushersStream, self).__init__(hs)
|
||||
|
||||
@@ -307,7 +365,7 @@ class CachesStream(Stream):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_cache_stream_token # type: ignore
|
||||
self.update_function = store.get_all_updated_caches # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_updated_caches) # type: ignore
|
||||
|
||||
super(CachesStream, self).__init__(hs)
|
||||
|
||||
@@ -333,7 +391,7 @@ class PublicRoomsStream(Stream):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_current_public_room_stream_id # type: ignore
|
||||
self.update_function = store.get_all_new_public_rooms # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_new_public_rooms) # type: ignore
|
||||
|
||||
super(PublicRoomsStream, self).__init__(hs)
|
||||
|
||||
@@ -354,7 +412,7 @@ class DeviceListsStream(Stream):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_device_stream_token # type: ignore
|
||||
self.update_function = store.get_all_device_list_changes_for_remotes # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_device_list_changes_for_remotes) # type: ignore
|
||||
|
||||
super(DeviceListsStream, self).__init__(hs)
|
||||
|
||||
@@ -372,7 +430,7 @@ class ToDeviceStream(Stream):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_to_device_stream_token # type: ignore
|
||||
self.update_function = store.get_all_new_device_messages # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_new_device_messages) # type: ignore
|
||||
|
||||
super(ToDeviceStream, self).__init__(hs)
|
||||
|
||||
@@ -392,7 +450,7 @@ class TagAccountDataStream(Stream):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_max_account_data_stream_id # type: ignore
|
||||
self.update_function = store.get_all_updated_tags # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_updated_tags) # type: ignore
|
||||
|
||||
super(TagAccountDataStream, self).__init__(hs)
|
||||
|
||||
@@ -412,10 +470,11 @@ class AccountDataStream(Stream):
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
self.current_token = self.store.get_max_account_data_stream_id # type: ignore
|
||||
self.update_function = db_query_to_update_function(self._update_function) # type: ignore
|
||||
|
||||
super(AccountDataStream, self).__init__(hs)
|
||||
|
||||
async def update_function(self, from_token, to_token, limit):
|
||||
async def _update_function(self, from_token, to_token, limit):
|
||||
global_results, room_results = await self.store.get_all_updated_account_data(
|
||||
from_token, from_token, to_token, limit
|
||||
)
|
||||
@@ -442,7 +501,7 @@ class GroupServerStream(Stream):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_group_stream_token # type: ignore
|
||||
self.update_function = store.get_all_groups_changes # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_groups_changes) # type: ignore
|
||||
|
||||
super(GroupServerStream, self).__init__(hs)
|
||||
|
||||
@@ -460,6 +519,6 @@ class UserSignatureStream(Stream):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_device_stream_token # type: ignore
|
||||
self.update_function = store.get_all_user_signature_changes_for_remotes # type: ignore
|
||||
self.update_function = db_query_to_update_function(store.get_all_user_signature_changes_for_remotes) # type: ignore
|
||||
|
||||
super(UserSignatureStream, self).__init__(hs)
|
||||
|
||||
@@ -19,7 +19,7 @@ from typing import Tuple, Type
|
||||
|
||||
import attr
|
||||
|
||||
from ._base import Stream
|
||||
from ._base import Stream, db_query_to_update_function
|
||||
|
||||
|
||||
"""Handling of the 'events' replication stream
|
||||
@@ -117,10 +117,11 @@ class EventsStream(Stream):
|
||||
def __init__(self, hs):
|
||||
self._store = hs.get_datastore()
|
||||
self.current_token = self._store.get_current_events_token # type: ignore
|
||||
self.update_function = db_query_to_update_function(self._update_function) # type: ignore
|
||||
|
||||
super(EventsStream, self).__init__(hs)
|
||||
|
||||
async def update_function(self, from_token, current_token, limit=None):
|
||||
async def _update_function(self, from_token, current_token, limit=None):
|
||||
event_rows = await self._store.get_all_new_forward_event_rows(
|
||||
from_token, current_token, limit
|
||||
)
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
# limitations under the License.
|
||||
from collections import namedtuple
|
||||
|
||||
from ._base import Stream
|
||||
from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
|
||||
|
||||
|
||||
class FederationStream(Stream):
|
||||
@@ -33,11 +33,14 @@ class FederationStream(Stream):
|
||||
|
||||
NAME = "federation"
|
||||
ROW_TYPE = FederationStreamRow
|
||||
_QUERY_MASTER = True
|
||||
|
||||
def __init__(self, hs):
|
||||
# Not all synapse instances will have a federation sender instance,
|
||||
# whether that's a `FederationSender` or a `FederationRemoteSendQueue`,
|
||||
# so we stub the stream out when that is the case.
|
||||
federation_sender = hs.get_federation_sender()
|
||||
|
||||
self.current_token = federation_sender.get_current_token # type: ignore
|
||||
self.update_function = federation_sender.get_replication_rows # type: ignore
|
||||
self.update_function = db_query_to_update_function(federation_sender.get_replication_rows) # type: ignore
|
||||
|
||||
super(FederationStream, self).__init__(hs)
|
||||
|
||||
@@ -816,7 +816,7 @@ class RoomTypingRestServlet(RestServlet):
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
await self.presence_handler.bump_presence_active_time(requester.user)
|
||||
# await self.presence_handler.bump_presence_active_time(requester.user)
|
||||
|
||||
# Limit timeout to stop people from setting silly typing timeouts.
|
||||
timeout = min(content.get("timeout", 30000), 120000)
|
||||
|
||||
@@ -78,13 +78,18 @@ from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
|
||||
from synapse.handlers.set_password import SetPasswordHandler
|
||||
from synapse.handlers.stats import StatsHandler
|
||||
from synapse.handlers.sync import SyncHandler
|
||||
from synapse.handlers.typing import TypingHandler
|
||||
from synapse.handlers.typing import TypingHandler, TypingSlaveHandler
|
||||
from synapse.handlers.user_directory import UserDirectoryHandler
|
||||
from synapse.http.client import InsecureInterceptableContextFactory, SimpleHttpClient
|
||||
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
||||
from synapse.notifier import Notifier
|
||||
from synapse.push.action_generator import ActionGenerator
|
||||
from synapse.push.pusherpool import PusherPool
|
||||
from synapse.replication.tcp.handler import (
|
||||
ReplicationClientHandler,
|
||||
ReplicationDataHandler,
|
||||
)
|
||||
from synapse.replication.tcp.resource import ReplicationStreamer
|
||||
from synapse.rest.media.v1.media_repository import (
|
||||
MediaRepository,
|
||||
MediaRepositoryResource,
|
||||
@@ -100,6 +105,7 @@ from synapse.storage import DataStores, Storage
|
||||
from synapse.streams.events import EventSources
|
||||
from synapse.util import Clock
|
||||
from synapse.util.distributor import Distributor
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -199,6 +205,8 @@ class HomeServer(object):
|
||||
"saml_handler",
|
||||
"event_client_serializer",
|
||||
"storage",
|
||||
"replication_streamer",
|
||||
"replication_data_handler",
|
||||
]
|
||||
|
||||
REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"]
|
||||
@@ -224,6 +232,8 @@ class HomeServer(object):
|
||||
self._listening_services = []
|
||||
self.start_time = None
|
||||
|
||||
self.instance_id = random_string(5)
|
||||
|
||||
self.clock = Clock(reactor)
|
||||
self.distributor = Distributor()
|
||||
self.ratelimiter = Ratelimiter()
|
||||
@@ -236,6 +246,11 @@ class HomeServer(object):
|
||||
for depname in kwargs:
|
||||
setattr(self, depname, kwargs[depname])
|
||||
|
||||
def get_instance_id(self):
|
||||
"""A unique ID for this synapse process instance.
|
||||
"""
|
||||
return self.instance_id
|
||||
|
||||
def setup(self):
|
||||
logger.info("Setting up.")
|
||||
self.start_time = int(self.get_clock().time())
|
||||
@@ -339,7 +354,10 @@ class HomeServer(object):
|
||||
return PresenceHandler(self)
|
||||
|
||||
def build_typing_handler(self):
|
||||
return TypingHandler(self)
|
||||
if self.config.handle_typing:
|
||||
return TypingHandler(self)
|
||||
else:
|
||||
return TypingSlaveHandler(self)
|
||||
|
||||
def build_sync_handler(self):
|
||||
return SyncHandler(self)
|
||||
@@ -439,10 +457,8 @@ class HomeServer(object):
|
||||
def build_federation_sender(self):
|
||||
if self.should_send_federation():
|
||||
return FederationSender(self)
|
||||
elif not self.config.worker_app:
|
||||
return FederationRemoteSendQueue(self)
|
||||
else:
|
||||
raise Exception("Workers cannot send federation traffic")
|
||||
return FederationRemoteSendQueue(self)
|
||||
|
||||
def build_receipts_handler(self):
|
||||
return ReceiptsHandler(self)
|
||||
@@ -451,7 +467,7 @@ class HomeServer(object):
|
||||
return ReadMarkerHandler(self)
|
||||
|
||||
def build_tcp_replication(self):
|
||||
raise NotImplementedError()
|
||||
return ReplicationClientHandler(self)
|
||||
|
||||
def build_action_generator(self):
|
||||
return ActionGenerator(self)
|
||||
@@ -536,6 +552,12 @@ class HomeServer(object):
|
||||
def build_storage(self) -> Storage:
|
||||
return Storage(self, self.datastores)
|
||||
|
||||
def build_replication_streamer(self) -> ReplicationStreamer:
|
||||
return ReplicationStreamer(self)
|
||||
|
||||
def build_replication_data_handler(self):
|
||||
return ReplicationDataHandler(self)
|
||||
|
||||
def remove_pusher(self, app_id, push_key, user_id):
|
||||
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
|
||||
|
||||
|
||||
@@ -106,7 +106,7 @@ class HomeServer(object):
|
||||
pass
|
||||
def get_tcp_replication(
|
||||
self,
|
||||
) -> synapse.replication.tcp.client.ReplicationClientHandler:
|
||||
) -> synapse.replication.tcp.handler.ReplicationClientHandler:
|
||||
pass
|
||||
def get_federation_registry(
|
||||
self,
|
||||
@@ -114,3 +114,5 @@ class HomeServer(object):
|
||||
pass
|
||||
def is_mine_id(self, domain_id: str) -> bool:
|
||||
pass
|
||||
def get_instance_id(self) -> str:
|
||||
pass
|
||||
|
||||
@@ -32,7 +32,29 @@ logger = logging.getLogger(__name__)
|
||||
CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
|
||||
|
||||
|
||||
class CacheInvalidationStore(SQLBaseStore):
|
||||
class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||
def get_all_updated_caches(self, last_id, current_id, limit):
|
||||
if last_id == current_id:
|
||||
return defer.succeed([])
|
||||
|
||||
def get_all_updated_caches_txn(txn):
|
||||
# We purposefully don't bound by the current token, as we want to
|
||||
# send across cache invalidations as quickly as possible. Cache
|
||||
# invalidations are idempotent, so duplicates are fine.
|
||||
sql = (
|
||||
"SELECT stream_id, cache_func, keys, invalidation_ts"
|
||||
" FROM cache_invalidation_stream"
|
||||
" WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
|
||||
)
|
||||
txn.execute(sql, (last_id, limit))
|
||||
return txn.fetchall()
|
||||
|
||||
return self.db.runInteraction(
|
||||
"get_all_updated_caches", get_all_updated_caches_txn
|
||||
)
|
||||
|
||||
|
||||
class CacheInvalidationStore(CacheInvalidationWorkerStore):
|
||||
async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
|
||||
"""Invalidates the cache and adds it to the cache stream so slaves
|
||||
will know to invalidate their caches.
|
||||
@@ -145,26 +167,6 @@ class CacheInvalidationStore(SQLBaseStore):
|
||||
},
|
||||
)
|
||||
|
||||
def get_all_updated_caches(self, last_id, current_id, limit):
|
||||
if last_id == current_id:
|
||||
return defer.succeed([])
|
||||
|
||||
def get_all_updated_caches_txn(txn):
|
||||
# We purposefully don't bound by the current token, as we want to
|
||||
# send across cache invalidations as quickly as possible. Cache
|
||||
# invalidations are idempotent, so duplicates are fine.
|
||||
sql = (
|
||||
"SELECT stream_id, cache_func, keys, invalidation_ts"
|
||||
" FROM cache_invalidation_stream"
|
||||
" WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
|
||||
)
|
||||
txn.execute(sql, (last_id, limit))
|
||||
return txn.fetchall()
|
||||
|
||||
return self.db.runInteraction(
|
||||
"get_all_updated_caches", get_all_updated_caches_txn
|
||||
)
|
||||
|
||||
def get_cache_stream_token(self):
|
||||
if self._cache_id_gen:
|
||||
return self._cache_id_gen.get_current_token()
|
||||
|
||||
@@ -207,6 +207,50 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
||||
"delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
|
||||
)
|
||||
|
||||
def get_all_new_device_messages(self, last_pos, current_pos, limit):
|
||||
"""
|
||||
Args:
|
||||
last_pos(int):
|
||||
current_pos(int):
|
||||
limit(int):
|
||||
Returns:
|
||||
A deferred list of rows from the device inbox
|
||||
"""
|
||||
if last_pos == current_pos:
|
||||
return defer.succeed([])
|
||||
|
||||
def get_all_new_device_messages_txn(txn):
|
||||
# We limit like this as we might have multiple rows per stream_id, and
|
||||
# we want to make sure we always get all entries for any stream_id
|
||||
# we return.
|
||||
upper_pos = min(current_pos, last_pos + limit)
|
||||
sql = (
|
||||
"SELECT max(stream_id), user_id"
|
||||
" FROM device_inbox"
|
||||
" WHERE ? < stream_id AND stream_id <= ?"
|
||||
" GROUP BY user_id"
|
||||
)
|
||||
txn.execute(sql, (last_pos, upper_pos))
|
||||
rows = txn.fetchall()
|
||||
|
||||
sql = (
|
||||
"SELECT max(stream_id), destination"
|
||||
" FROM device_federation_outbox"
|
||||
" WHERE ? < stream_id AND stream_id <= ?"
|
||||
" GROUP BY destination"
|
||||
)
|
||||
txn.execute(sql, (last_pos, upper_pos))
|
||||
rows.extend(txn)
|
||||
|
||||
# Order by ascending stream ordering
|
||||
rows.sort()
|
||||
|
||||
return rows
|
||||
|
||||
return self.db.runInteraction(
|
||||
"get_all_new_device_messages", get_all_new_device_messages_txn
|
||||
)
|
||||
|
||||
|
||||
class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
|
||||
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
|
||||
@@ -411,47 +455,3 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
|
||||
rows.append((user_id, device_id, stream_id, message_json))
|
||||
|
||||
txn.executemany(sql, rows)
|
||||
|
||||
def get_all_new_device_messages(self, last_pos, current_pos, limit):
|
||||
"""
|
||||
Args:
|
||||
last_pos(int):
|
||||
current_pos(int):
|
||||
limit(int):
|
||||
Returns:
|
||||
A deferred list of rows from the device inbox
|
||||
"""
|
||||
if last_pos == current_pos:
|
||||
return defer.succeed([])
|
||||
|
||||
def get_all_new_device_messages_txn(txn):
|
||||
# We limit like this as we might have multiple rows per stream_id, and
|
||||
# we want to make sure we always get all entries for any stream_id
|
||||
# we return.
|
||||
upper_pos = min(current_pos, last_pos + limit)
|
||||
sql = (
|
||||
"SELECT max(stream_id), user_id"
|
||||
" FROM device_inbox"
|
||||
" WHERE ? < stream_id AND stream_id <= ?"
|
||||
" GROUP BY user_id"
|
||||
)
|
||||
txn.execute(sql, (last_pos, upper_pos))
|
||||
rows = txn.fetchall()
|
||||
|
||||
sql = (
|
||||
"SELECT max(stream_id), destination"
|
||||
" FROM device_federation_outbox"
|
||||
" WHERE ? < stream_id AND stream_id <= ?"
|
||||
" GROUP BY destination"
|
||||
)
|
||||
txn.execute(sql, (last_pos, upper_pos))
|
||||
rows.extend(txn)
|
||||
|
||||
# Order by ascending stream ordering
|
||||
rows.sort()
|
||||
|
||||
return rows
|
||||
|
||||
return self.db.runInteraction(
|
||||
"get_all_new_device_messages", get_all_new_device_messages_txn
|
||||
)
|
||||
|
||||
@@ -1267,104 +1267,6 @@ class EventsStore(
|
||||
ret = yield self.db.runInteraction("count_daily_active_rooms", _count)
|
||||
return ret
|
||||
|
||||
def get_current_backfill_token(self):
|
||||
"""The current minimum token that backfilled events have reached"""
|
||||
return -self._backfill_id_gen.get_current_token()
|
||||
|
||||
def get_current_events_token(self):
|
||||
"""The current maximum token that events have reached"""
|
||||
return self._stream_id_gen.get_current_token()
|
||||
|
||||
def get_all_new_forward_event_rows(self, last_id, current_id, limit):
|
||||
if last_id == current_id:
|
||||
return defer.succeed([])
|
||||
|
||||
def get_all_new_forward_event_rows(txn):
|
||||
sql = (
|
||||
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
|
||||
" state_key, redacts, relates_to_id"
|
||||
" FROM events AS e"
|
||||
" LEFT JOIN redactions USING (event_id)"
|
||||
" LEFT JOIN state_events USING (event_id)"
|
||||
" LEFT JOIN event_relations USING (event_id)"
|
||||
" WHERE ? < stream_ordering AND stream_ordering <= ?"
|
||||
" ORDER BY stream_ordering ASC"
|
||||
" LIMIT ?"
|
||||
)
|
||||
txn.execute(sql, (last_id, current_id, limit))
|
||||
new_event_updates = txn.fetchall()
|
||||
|
||||
if len(new_event_updates) == limit:
|
||||
upper_bound = new_event_updates[-1][0]
|
||||
else:
|
||||
upper_bound = current_id
|
||||
|
||||
sql = (
|
||||
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
|
||||
" state_key, redacts, relates_to_id"
|
||||
" FROM events AS e"
|
||||
" INNER JOIN ex_outlier_stream USING (event_id)"
|
||||
" LEFT JOIN redactions USING (event_id)"
|
||||
" LEFT JOIN state_events USING (event_id)"
|
||||
" LEFT JOIN event_relations USING (event_id)"
|
||||
" WHERE ? < event_stream_ordering"
|
||||
" AND event_stream_ordering <= ?"
|
||||
" ORDER BY event_stream_ordering DESC"
|
||||
)
|
||||
txn.execute(sql, (last_id, upper_bound))
|
||||
new_event_updates.extend(txn)
|
||||
|
||||
return new_event_updates
|
||||
|
||||
return self.db.runInteraction(
|
||||
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
|
||||
)
|
||||
|
||||
def get_all_new_backfill_event_rows(self, last_id, current_id, limit):
|
||||
if last_id == current_id:
|
||||
return defer.succeed([])
|
||||
|
||||
def get_all_new_backfill_event_rows(txn):
|
||||
sql = (
|
||||
"SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
|
||||
" state_key, redacts, relates_to_id"
|
||||
" FROM events AS e"
|
||||
" LEFT JOIN redactions USING (event_id)"
|
||||
" LEFT JOIN state_events USING (event_id)"
|
||||
" LEFT JOIN event_relations USING (event_id)"
|
||||
" WHERE ? > stream_ordering AND stream_ordering >= ?"
|
||||
" ORDER BY stream_ordering ASC"
|
||||
" LIMIT ?"
|
||||
)
|
||||
txn.execute(sql, (-last_id, -current_id, limit))
|
||||
new_event_updates = txn.fetchall()
|
||||
|
||||
if len(new_event_updates) == limit:
|
||||
upper_bound = new_event_updates[-1][0]
|
||||
else:
|
||||
upper_bound = current_id
|
||||
|
||||
sql = (
|
||||
"SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
|
||||
" state_key, redacts, relates_to_id"
|
||||
" FROM events AS e"
|
||||
" INNER JOIN ex_outlier_stream USING (event_id)"
|
||||
" LEFT JOIN redactions USING (event_id)"
|
||||
" LEFT JOIN state_events USING (event_id)"
|
||||
" LEFT JOIN event_relations USING (event_id)"
|
||||
" WHERE ? > event_stream_ordering"
|
||||
" AND event_stream_ordering >= ?"
|
||||
" ORDER BY event_stream_ordering DESC"
|
||||
)
|
||||
txn.execute(sql, (-last_id, -upper_bound))
|
||||
new_event_updates.extend(txn.fetchall())
|
||||
|
||||
return new_event_updates
|
||||
|
||||
return self.db.runInteraction(
|
||||
"get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
|
||||
)
|
||||
|
||||
@cached(num_args=5, max_entries=10)
|
||||
def get_all_new_events(
|
||||
self,
|
||||
@@ -1850,22 +1752,6 @@ class EventsStore(
|
||||
|
||||
return (int(res["topological_ordering"]), int(res["stream_ordering"]))
|
||||
|
||||
def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
|
||||
def get_all_updated_current_state_deltas_txn(txn):
|
||||
sql = """
|
||||
SELECT stream_id, room_id, type, state_key, event_id
|
||||
FROM current_state_delta_stream
|
||||
WHERE ? < stream_id AND stream_id <= ?
|
||||
ORDER BY stream_id ASC LIMIT ?
|
||||
"""
|
||||
txn.execute(sql, (from_token, to_token, limit))
|
||||
return txn.fetchall()
|
||||
|
||||
return self.db.runInteraction(
|
||||
"get_all_updated_current_state_deltas",
|
||||
get_all_updated_current_state_deltas_txn,
|
||||
)
|
||||
|
||||
def insert_labels_for_event_txn(
|
||||
self, txn, event_id, labels, room_id, topological_ordering
|
||||
):
|
||||
|
||||
@@ -963,3 +963,117 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
complexity_v1 = round(state_events / 500, 2)
|
||||
|
||||
return {"v1": complexity_v1}
|
||||
|
||||
def get_current_backfill_token(self):
|
||||
"""The current minimum token that backfilled events have reached"""
|
||||
return -self._backfill_id_gen.get_current_token()
|
||||
|
||||
def get_current_events_token(self):
|
||||
"""The current maximum token that events have reached"""
|
||||
return self._stream_id_gen.get_current_token()
|
||||
|
||||
def get_all_new_forward_event_rows(self, last_id, current_id, limit):
|
||||
if last_id == current_id:
|
||||
return defer.succeed([])
|
||||
|
||||
def get_all_new_forward_event_rows(txn):
|
||||
sql = (
|
||||
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
|
||||
" state_key, redacts, relates_to_id"
|
||||
" FROM events AS e"
|
||||
" LEFT JOIN redactions USING (event_id)"
|
||||
" LEFT JOIN state_events USING (event_id)"
|
||||
" LEFT JOIN event_relations USING (event_id)"
|
||||
" WHERE ? < stream_ordering AND stream_ordering <= ?"
|
||||
" ORDER BY stream_ordering ASC"
|
||||
" LIMIT ?"
|
||||
)
|
||||
txn.execute(sql, (last_id, current_id, limit))
|
||||
new_event_updates = txn.fetchall()
|
||||
|
||||
if len(new_event_updates) == limit:
|
||||
upper_bound = new_event_updates[-1][0]
|
||||
else:
|
||||
upper_bound = current_id
|
||||
|
||||
sql = (
|
||||
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
|
||||
" state_key, redacts, relates_to_id"
|
||||
" FROM events AS e"
|
||||
" INNER JOIN ex_outlier_stream USING (event_id)"
|
||||
" LEFT JOIN redactions USING (event_id)"
|
||||
" LEFT JOIN state_events USING (event_id)"
|
||||
" LEFT JOIN event_relations USING (event_id)"
|
||||
" WHERE ? < event_stream_ordering"
|
||||
" AND event_stream_ordering <= ?"
|
||||
" ORDER BY event_stream_ordering DESC"
|
||||
)
|
||||
txn.execute(sql, (last_id, upper_bound))
|
||||
new_event_updates.extend(txn)
|
||||
|
||||
return new_event_updates
|
||||
|
||||
return self.db.runInteraction(
|
||||
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
|
||||
)
|
||||
|
||||
def get_all_new_backfill_event_rows(self, last_id, current_id, limit):
|
||||
if last_id == current_id:
|
||||
return defer.succeed([])
|
||||
|
||||
def get_all_new_backfill_event_rows(txn):
|
||||
sql = (
|
||||
"SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
|
||||
" state_key, redacts, relates_to_id"
|
||||
" FROM events AS e"
|
||||
" LEFT JOIN redactions USING (event_id)"
|
||||
" LEFT JOIN state_events USING (event_id)"
|
||||
" LEFT JOIN event_relations USING (event_id)"
|
||||
" WHERE ? > stream_ordering AND stream_ordering >= ?"
|
||||
" ORDER BY stream_ordering ASC"
|
||||
" LIMIT ?"
|
||||
)
|
||||
txn.execute(sql, (-last_id, -current_id, limit))
|
||||
new_event_updates = txn.fetchall()
|
||||
|
||||
if len(new_event_updates) == limit:
|
||||
upper_bound = new_event_updates[-1][0]
|
||||
else:
|
||||
upper_bound = current_id
|
||||
|
||||
sql = (
|
||||
"SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
|
||||
" state_key, redacts, relates_to_id"
|
||||
" FROM events AS e"
|
||||
" INNER JOIN ex_outlier_stream USING (event_id)"
|
||||
" LEFT JOIN redactions USING (event_id)"
|
||||
" LEFT JOIN state_events USING (event_id)"
|
||||
" LEFT JOIN event_relations USING (event_id)"
|
||||
" WHERE ? > event_stream_ordering"
|
||||
" AND event_stream_ordering >= ?"
|
||||
" ORDER BY event_stream_ordering DESC"
|
||||
)
|
||||
txn.execute(sql, (-last_id, -upper_bound))
|
||||
new_event_updates.extend(txn.fetchall())
|
||||
|
||||
return new_event_updates
|
||||
|
||||
return self.db.runInteraction(
|
||||
"get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
|
||||
)
|
||||
|
||||
def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
|
||||
def get_all_updated_current_state_deltas_txn(txn):
|
||||
sql = """
|
||||
SELECT stream_id, room_id, type, state_key, event_id
|
||||
FROM current_state_delta_stream
|
||||
WHERE ? < stream_id AND stream_id <= ?
|
||||
ORDER BY stream_id ASC LIMIT ?
|
||||
"""
|
||||
txn.execute(sql, (from_token, to_token, limit))
|
||||
return txn.fetchall()
|
||||
|
||||
return self.db.runInteraction(
|
||||
"get_all_updated_current_state_deltas",
|
||||
get_all_updated_current_state_deltas_txn,
|
||||
)
|
||||
|
||||
@@ -732,6 +732,26 @@ class RoomWorkerStore(SQLBaseStore):
|
||||
|
||||
return total_media_quarantined
|
||||
|
||||
def get_all_new_public_rooms(self, prev_id, current_id, limit):
|
||||
def get_all_new_public_rooms(txn):
|
||||
sql = """
|
||||
SELECT stream_id, room_id, visibility, appservice_id, network_id
|
||||
FROM public_room_list_stream
|
||||
WHERE stream_id > ? AND stream_id <= ?
|
||||
ORDER BY stream_id ASC
|
||||
LIMIT ?
|
||||
"""
|
||||
|
||||
txn.execute(sql, (prev_id, current_id, limit))
|
||||
return txn.fetchall()
|
||||
|
||||
if prev_id == current_id:
|
||||
return defer.succeed([])
|
||||
|
||||
return self.db.runInteraction(
|
||||
"get_all_new_public_rooms", get_all_new_public_rooms
|
||||
)
|
||||
|
||||
|
||||
class RoomBackgroundUpdateStore(SQLBaseStore):
|
||||
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
|
||||
@@ -1249,26 +1269,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
||||
def get_current_public_room_stream_id(self):
|
||||
return self._public_room_id_gen.get_current_token()
|
||||
|
||||
def get_all_new_public_rooms(self, prev_id, current_id, limit):
|
||||
def get_all_new_public_rooms(txn):
|
||||
sql = """
|
||||
SELECT stream_id, room_id, visibility, appservice_id, network_id
|
||||
FROM public_room_list_stream
|
||||
WHERE stream_id > ? AND stream_id <= ?
|
||||
ORDER BY stream_id ASC
|
||||
LIMIT ?
|
||||
"""
|
||||
|
||||
txn.execute(sql, (prev_id, current_id, limit))
|
||||
return txn.fetchall()
|
||||
|
||||
if prev_id == current_id:
|
||||
return defer.succeed([])
|
||||
|
||||
return self.db.runInteraction(
|
||||
"get_all_new_public_rooms", get_all_new_public_rooms
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def block_room(self, room_id, user_id):
|
||||
"""Marks the room as blocked. Can be called multiple times.
|
||||
|
||||
@@ -15,9 +15,10 @@
|
||||
|
||||
from mock import Mock, NonCallableMock
|
||||
|
||||
from synapse.replication.tcp.client import (
|
||||
ReplicationClientFactory,
|
||||
from synapse.replication.tcp.client import ReplicationClientFactory
|
||||
from synapse.replication.tcp.handler import (
|
||||
ReplicationClientHandler,
|
||||
WorkerReplicationDataHandler,
|
||||
)
|
||||
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
|
||||
from synapse.storage.database import make_conn
|
||||
@@ -51,16 +52,19 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
|
||||
self.event_id = 0
|
||||
|
||||
server_factory = ReplicationStreamProtocolFactory(self.hs)
|
||||
self.streamer = server_factory.streamer
|
||||
self.streamer = hs.get_replication_streamer()
|
||||
|
||||
handler_factory = Mock()
|
||||
self.replication_handler = ReplicationClientHandler(self.slaved_store)
|
||||
self.replication_handler.factory = handler_factory
|
||||
|
||||
client_factory = ReplicationClientFactory(
|
||||
self.hs, "client_name", self.replication_handler
|
||||
# We now do some gut wrenching so that we have a client that is based
|
||||
# off of the slave store rather than the main store.
|
||||
self.replication_handler = ReplicationClientHandler(self.hs)
|
||||
self.replication_handler.store = self.slaved_store
|
||||
self.replication_handler.replication_data_handler = WorkerReplicationDataHandler(
|
||||
self.slaved_store
|
||||
)
|
||||
|
||||
client_factory = ReplicationClientFactory(self.hs, "client_name")
|
||||
client_factory.handler = self.replication_handler
|
||||
|
||||
server = server_factory.buildProtocol(None)
|
||||
client = client_factory.buildProtocol(None)
|
||||
|
||||
|
||||
@@ -12,9 +12,10 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from mock import Mock
|
||||
|
||||
from synapse.replication.tcp.commands import ReplicateCommand
|
||||
from synapse.replication.tcp.handler import ReplicationClientHandler
|
||||
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
|
||||
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
|
||||
|
||||
@@ -25,23 +26,46 @@ from tests.server import FakeTransport
|
||||
class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||
"""Base class for tests of the replication streams"""
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
self.test_handler = Mock(wraps=TestReplicationClientHandler())
|
||||
return self.setup_test_homeserver(replication_data_handler=self.test_handler)
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
# build a replication server
|
||||
server_factory = ReplicationStreamProtocolFactory(self.hs)
|
||||
self.streamer = server_factory.streamer
|
||||
server = server_factory.buildProtocol(None)
|
||||
server_factory = ReplicationStreamProtocolFactory(hs)
|
||||
self.streamer = hs.get_replication_streamer()
|
||||
self.server = server_factory.buildProtocol(None)
|
||||
|
||||
# build a replication client, with a dummy handler
|
||||
handler_factory = Mock()
|
||||
self.test_handler = TestReplicationClientHandler()
|
||||
self.test_handler.factory = handler_factory
|
||||
repl_handler = ReplicationClientHandler(hs)
|
||||
repl_handler.handler = self.test_handler
|
||||
self.client = ClientReplicationStreamProtocol(
|
||||
"client", "test", clock, self.test_handler
|
||||
hs, "client", "test", clock, repl_handler,
|
||||
)
|
||||
|
||||
# wire them together
|
||||
self.client.makeConnection(FakeTransport(server, reactor))
|
||||
server.makeConnection(FakeTransport(self.client, reactor))
|
||||
self._client_transport = None
|
||||
self._server_transport = None
|
||||
|
||||
def reconnect(self):
|
||||
if self._client_transport:
|
||||
self.client.close()
|
||||
|
||||
if self._server_transport:
|
||||
self.server.close()
|
||||
|
||||
self._client_transport = FakeTransport(self.server, self.reactor)
|
||||
self.client.makeConnection(self._client_transport)
|
||||
|
||||
self._server_transport = FakeTransport(self.client, self.reactor)
|
||||
self.server.makeConnection(self._server_transport)
|
||||
|
||||
def disconnect(self):
|
||||
if self._client_transport:
|
||||
self._client_transport = None
|
||||
self.client.close()
|
||||
|
||||
if self._server_transport:
|
||||
self._server_transport = None
|
||||
self.server.close()
|
||||
|
||||
def replicate(self):
|
||||
"""Tell the master side of replication that something has happened, and then
|
||||
@@ -50,29 +74,22 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||
self.streamer.on_notifier_poke()
|
||||
self.pump(0.1)
|
||||
|
||||
def replicate_stream(self, stream, token="NOW"):
|
||||
"""Make the client end a REPLICATE command to set up a subscription to a stream"""
|
||||
self.client.send_command(ReplicateCommand(stream, token))
|
||||
|
||||
|
||||
class TestReplicationClientHandler(object):
|
||||
"""Drop-in for ReplicationClientHandler which just collects RDATA rows"""
|
||||
|
||||
class TestReplicationClientHandler:
|
||||
def __init__(self):
|
||||
self.received_rdata_rows = []
|
||||
self.streams = set()
|
||||
self._received_rdata_rows = []
|
||||
|
||||
def get_streams_to_replicate(self):
|
||||
return {}
|
||||
|
||||
def get_currently_syncing_users(self):
|
||||
return []
|
||||
|
||||
def update_connection(self, connection):
|
||||
pass
|
||||
|
||||
def finished_connecting(self):
|
||||
pass
|
||||
positions = {s: 0 for s in self.streams}
|
||||
for stream, token, _ in self._received_rdata_rows:
|
||||
if stream in self.streams:
|
||||
positions[stream] = max(token, positions.get(stream, 0))
|
||||
return positions
|
||||
|
||||
async def on_rdata(self, stream_name, token, rows):
|
||||
for r in rows:
|
||||
self.received_rdata_rows.append((stream_name, token, r))
|
||||
self._received_rdata_rows.append((stream_name, token, r))
|
||||
|
||||
async def on_position(self, stream_name, token):
|
||||
pass
|
||||
|
||||
@@ -17,30 +17,63 @@ from synapse.replication.tcp.streams._base import ReceiptsStream
|
||||
from tests.replication.tcp.streams._base import BaseStreamTestCase
|
||||
|
||||
USER_ID = "@feeling:blue"
|
||||
ROOM_ID = "!room:blue"
|
||||
EVENT_ID = "$event:blue"
|
||||
|
||||
|
||||
class ReceiptsStreamTestCase(BaseStreamTestCase):
|
||||
def test_receipt(self):
|
||||
self.reconnect()
|
||||
|
||||
# make the client subscribe to the receipts stream
|
||||
self.replicate_stream("receipts", "NOW")
|
||||
self.test_handler.streams.add("receipts")
|
||||
|
||||
# tell the master to send a new receipt
|
||||
self.get_success(
|
||||
self.hs.get_datastore().insert_receipt(
|
||||
ROOM_ID, "m.read", USER_ID, [EVENT_ID], {"a": 1}
|
||||
"!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1}
|
||||
)
|
||||
)
|
||||
self.replicate()
|
||||
|
||||
# there should be one RDATA command
|
||||
rdata_rows = self.test_handler.received_rdata_rows
|
||||
self.test_handler.on_rdata.assert_called_once()
|
||||
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
|
||||
self.assertEqual(stream_name, "receipts")
|
||||
self.assertEqual(1, len(rdata_rows))
|
||||
self.assertEqual(rdata_rows[0][0], "receipts")
|
||||
row = rdata_rows[0][2] # type: ReceiptsStream.ReceiptsStreamRow
|
||||
self.assertEqual(ROOM_ID, row.room_id)
|
||||
row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow
|
||||
self.assertEqual("!room:blue", row.room_id)
|
||||
self.assertEqual("m.read", row.receipt_type)
|
||||
self.assertEqual(USER_ID, row.user_id)
|
||||
self.assertEqual(EVENT_ID, row.event_id)
|
||||
self.assertEqual("$event:blue", row.event_id)
|
||||
self.assertEqual({"a": 1}, row.data)
|
||||
|
||||
# Now let's disconnect and insert some data.
|
||||
self.disconnect()
|
||||
|
||||
self.test_handler.on_rdata.reset_mock()
|
||||
|
||||
self.get_success(
|
||||
self.hs.get_datastore().insert_receipt(
|
||||
"!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2}
|
||||
)
|
||||
)
|
||||
self.replicate()
|
||||
|
||||
# Nothing should have happened as we are disconnected
|
||||
self.test_handler.on_rdata.assert_not_called()
|
||||
|
||||
self.reconnect()
|
||||
self.pump(0.1)
|
||||
|
||||
# We should now have caught up and get the missing data
|
||||
self.test_handler.on_rdata.assert_called_once()
|
||||
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
|
||||
self.assertEqual(stream_name, "receipts")
|
||||
self.assertEqual(token, 3)
|
||||
self.assertEqual(1, len(rdata_rows))
|
||||
|
||||
row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow
|
||||
self.assertEqual("!room2:blue", row.room_id)
|
||||
self.assertEqual("m.read", row.receipt_type)
|
||||
self.assertEqual(USER_ID, row.user_id)
|
||||
self.assertEqual("$event2:foo", row.event_id)
|
||||
self.assertEqual({"a": 2}, row.data)
|
||||
|
||||
Reference in New Issue
Block a user