Compare commits

..

1 Commits

Author SHA1 Message Date
Erik Johnston
171829bb94 Add logging to initialSync 2015-05-01 11:24:03 +01:00
119 changed files with 2302 additions and 4164 deletions

View File

@@ -35,6 +35,3 @@ Turned to Dust <dwinslow86 at gmail.com>
Brabo <brabo at riseup.net>
* Installation instruction fixes
Ivan Shapovalov <intelfx100 at gmail.com>
* contrib/systemd: a sample systemd unit file and a logger configuration

View File

@@ -1,110 +1,9 @@
Changes in synapse v0.9.2 (2015-06-12)
======================================
Changes in synapse vX
=====================
General:
* Changed config option from ``disable_registration`` to
``enable_registration``. Old option will be ignored.
* Use ultrajson for json (de)serialisation when a canonical encoding is not
required. Ultrajson is significantly faster than simplejson in certain
circumstances.
* Use connection pools for outgoing HTTP connections.
* Process thumbnails on separate threads.
Configuration:
* Add option, ``gzip_responses``, to disable HTTP response compression.
Federation:
* Improve resilience of backfill by ensuring we fetch any missing auth events.
* Improve performance of backfill and joining remote rooms by removing
unnecessary computations. This included handling events we'd previously
handled as well as attempting to compute the current state for outliers.
Changes in synapse v0.9.1 (2015-05-26)
======================================
General:
* Add support for backfilling when a client paginates. This allows servers to
request history for a room from remote servers when a client tries to
paginate history the server does not have - SYN-36
* Fix bug where you couldn't disable non-default pushrules - SYN-378
* Fix ``register_new_user`` script - SYN-359
* Improve performance of fetching events from the database, this improves both
initialSync and sending of events.
* Improve performance of event streams, allowing synapse to handle more
simultaneous connected clients.
Federation:
* Fix bug with existing backfill implementation where it returned the wrong
selection of events in some circumstances.
* Improve performance of joining remote rooms.
Configuration:
* Add support for changing the bind host of the metrics listener via the
``metrics_bind_host`` option.
Changes in synapse v0.9.0-r5 (2015-05-21)
=========================================
* Add more database caches to reduce amount of work done for each pusher. This
radically reduces CPU usage when multiple pushers are set up in the same room.
Changes in synapse v0.9.0 (2015-05-07)
======================================
General:
* Add support for using a PostgreSQL database instead of SQLite. See
`docs/postgres.rst`_ for details.
* Add password change and reset APIs. See `Registration`_ in the spec.
* Fix memory leak due to not releasing stale notifiers - SYN-339.
* Fix race in caches that occasionally caused some presence updates to be
dropped - SYN-369.
* Check server name has not changed on restart.
* Add a sample systemd unit file and a logger configuration in
contrib/systemd. Contributed Ivan Shapovalov.
Federation:
* Add key distribution mechanisms for fetching public keys of unavailable
remote home servers. See `Retrieving Server Keys`_ in the spec.
Configuration:
* Add support for multiple config files.
* Add support for dictionaries in config files.
* Remove support for specifying config options on the command line, except
for:
* ``--daemonize`` - Daemonize the home server.
* ``--manhole`` - Turn on the twisted telnet manhole service on the given
port.
* ``--database-path`` - The path to a sqlite database to use.
* ``--verbose`` - The verbosity level.
* ``--log-file`` - File to log to.
* ``--log-config`` - Python logging config file.
* ``--enable-registration`` - Enable registration for new users.
Application services:
* Reliably retry sending of events from Synapse to application services, as per
`Application Services`_ spec.
* Application services can no longer register via the ``/register`` API,
instead their configuration should be saved to a file and listed in the
synapse ``app_service_config_files`` config option. The AS configuration file
has the same format as the old ``/register`` request.
See `docs/application_services.rst`_ for more information.
.. _`docs/postgres.rst`: docs/postgres.rst
.. _`docs/application_services.rst`: docs/application_services.rst
.. _`Registration`: https://github.com/matrix-org/matrix-doc/blob/master/specification/10_client_server_api.rst#registration
.. _`Retrieving Server Keys`: https://github.com/matrix-org/matrix-doc/blob/6f2698/specification/30_server_server_api.rst#retrieving-server-keys
.. _`Application Services`: https://github.com/matrix-org/matrix-doc/blob/0c6bd9/specification/25_application_service_api.rst#home-server---application-service-api
Changes in synapse v0.8.1 (2015-03-18)
======================================

View File

@@ -117,7 +117,7 @@ Installing prerequisites on Mac OS X::
To install the synapse homeserver run::
$ virtualenv -p python2.7 ~/.synapse
$ virtualenv ~/.synapse
$ source ~/.synapse/bin/activate
$ pip install --process-dependency-links https://github.com/matrix-org/synapse/tarball/master
@@ -318,7 +318,7 @@ ArchLinux
If running `$ synctl start` fails with 'returned non-zero exit status 1',
you will need to explicitly call Python2.7 - either running as::
$ python2.7 -m synapse.app.homeserver --daemonize -c homeserver.yaml
$ python2.7 -m synapse.app.homeserver --daemonize -c homeserver.yaml --pid-file homeserver.pid
...or by editing synctl with the correct python executable.
@@ -409,6 +409,7 @@ SRV record, as that is the name other machines will expect it to have::
$ python -m synapse.app.homeserver \
--server-name YOURDOMAIN \
--bind-port 8448 \
--config-path homeserver.yaml \
--generate-config
$ python -m synapse.app.homeserver --config-path homeserver.yaml

View File

@@ -1,4 +1,4 @@
Upgrading to v0.9.0
Upgrading to v0.x.x
===================
Application services have had a breaking API change in this version.

View File

@@ -21,5 +21,3 @@ handlers:
root:
level: INFO
handlers: [journal]
disable_existing_loggers: False

View File

@@ -16,31 +16,30 @@ if [ $# -eq 1 ]; then
fi
fi
export PYTHONPATH=$(readlink -f $(pwd))
echo $PYTHONPATH
for port in 8080 8081 8082; do
echo "Starting server on port $port... "
https_port=$((port + 400))
mkdir -p demo/$port
pushd demo/$port
#rm $DIR/etc/$port.config
python -m synapse.app.homeserver \
--generate-config \
--enable_registration \
--config-path "demo/etc/$port.config" \
-p "$https_port" \
--unsecure-port "$port" \
-H "localhost:$https_port" \
--config-path "$DIR/etc/$port.config" \
-f "$DIR/$port.log" \
-d "$DIR/$port.db" \
-D --pid-file "$DIR/$port.pid" \
--manhole $((port + 1000)) \
--tls-dh-params-path "demo/demo.tls.dh" \
--media-store-path "demo/media_store.$port" \
$PARAMS $SYNAPSE_PARAMS \
--enable-registration
python -m synapse.app.homeserver \
--config-path "$DIR/etc/$port.config" \
-D \
--config-path "demo/etc/$port.config" \
-vv \
popd
done
cd "$CWD"

View File

@@ -1,36 +0,0 @@
Registering an Application Service
==================================
The registration of new application services depends on the homeserver used.
In synapse, you need to create a new configuration file for your AS and add it
to the list specified under the ``app_service_config_files`` config
option in your synapse config.
For example:
.. code-block:: yaml
app_service_config_files:
- /home/matrix/.synapse/<your-AS>.yaml
The format of the AS configuration file is as follows:
.. code-block:: yaml
url: <base url of AS>
as_token: <token AS will add to requests to HS>
hs_token: <token HS will add to requests to AS>
sender_localpart: <localpart of AS user>
namespaces:
users: # List of users we're interested in
- exclusive: <bool>
regex: <regex>
- ...
aliases: [] # List of aliases we're interested in
rooms: [] # List of room ids we're interested in
See the spec_ for further details on how application services work.
.. _spec: https://github.com/matrix-org/matrix-doc/blob/master/specification/25_application_service_api.rst#application-service-api

View File

@@ -34,15 +34,19 @@ Synapse config
When you are ready to start using PostgreSQL, add the following line to your
config file::
database:
name: psycopg2
args:
user: <user>
password: <pass>
database: <db>
host: <host>
cp_min: 5
cp_max: 10
database_config: <db_config_file>
Where ``<db_config_file>`` is the file name that points to a yaml file of the
following form::
name: psycopg2
args:
user: <user>
password: <pass>
database: <db>
host: <host>
cp_min: 5
cp_max: 10
All key, values in ``args`` are passed to the ``psycopg2.connect(..)``
function, except keys beginning with ``cp_``, which are consumed by the twisted
@@ -82,13 +86,13 @@ complete, restart synapse. For instance::
cp homeserver.db homeserver.db.snapshot
./synctl start
Assuming your new config file (as described in the section *Synapse config*)
is named ``homeserver-postgres.yaml`` and the SQLite snapshot is at
Assuming your database config file (as described in the section *Synapse
config*) is named ``database_config.yaml`` and the SQLite snapshot is at
``homeserver.db.snapshot`` then simply run::
python scripts/port_from_sqlite_to_postgres.py \
--sqlite-database homeserver.db.snapshot \
--postgres-config homeserver-postgres.yaml
--postgres-config database_config.yaml
The flag ``--curses`` displays a coloured curses progress UI.

View File

@@ -33,10 +33,9 @@ def request_registration(user, password, server_location, shared_secret):
).hexdigest()
data = {
"user": user,
"username": user,
"password": password,
"mac": mac,
"type": "org.matrix.login.shared_secret",
}
server_location = server_location.rstrip("/")
@@ -44,7 +43,7 @@ def request_registration(user, password, server_location, shared_secret):
print "Sending registration request..."
req = urllib2.Request(
"%s/_matrix/client/api/v1/register" % (server_location,),
"%s/_matrix/client/v2_alpha/register" % (server_location,),
data=json.dumps(data),
headers={'Content-Type': 'application/json'}
)

View File

@@ -1,116 +0,0 @@
import psycopg2
import yaml
import sys
import json
import time
import hashlib
from syutil.base64util import encode_base64
from syutil.crypto.signing_key import read_signing_keys
from syutil.crypto.jsonsign import sign_json
from syutil.jsonutil import encode_canonical_json
def select_v1_keys(connection):
cursor = connection.cursor()
cursor.execute("SELECT server_name, key_id, verify_key FROM server_signature_keys")
rows = cursor.fetchall()
cursor.close()
results = {}
for server_name, key_id, verify_key in rows:
results.setdefault(server_name, {})[key_id] = encode_base64(verify_key)
return results
def select_v1_certs(connection):
cursor = connection.cursor()
cursor.execute("SELECT server_name, tls_certificate FROM server_tls_certificates")
rows = cursor.fetchall()
cursor.close()
results = {}
for server_name, tls_certificate in rows:
results[server_name] = tls_certificate
return results
def select_v2_json(connection):
cursor = connection.cursor()
cursor.execute("SELECT server_name, key_id, key_json FROM server_keys_json")
rows = cursor.fetchall()
cursor.close()
results = {}
for server_name, key_id, key_json in rows:
results.setdefault(server_name, {})[key_id] = json.loads(str(key_json).decode("utf-8"))
return results
def convert_v1_to_v2(server_name, valid_until, keys, certificate):
return {
"old_verify_keys": {},
"server_name": server_name,
"verify_keys": {
key_id: {"key": key}
for key_id, key in keys.items()
},
"valid_until_ts": valid_until,
"tls_fingerprints": [fingerprint(certificate)],
}
def fingerprint(certificate):
finger = hashlib.sha256(certificate)
return {"sha256": encode_base64(finger.digest())}
def rows_v2(server, json):
valid_until = json["valid_until_ts"]
key_json = encode_canonical_json(json)
for key_id in json["verify_keys"]:
yield (server, key_id, "-", valid_until, valid_until, buffer(key_json))
def main():
config = yaml.load(open(sys.argv[1]))
valid_until = int(time.time() / (3600 * 24)) * 1000 * 3600 * 24
server_name = config["server_name"]
signing_key = read_signing_keys(open(config["signing_key_path"]))[0]
database = config["database"]
assert database["name"] == "psycopg2", "Can only convert for postgresql"
args = database["args"]
args.pop("cp_max")
args.pop("cp_min")
connection = psycopg2.connect(**args)
keys = select_v1_keys(connection)
certificates = select_v1_certs(connection)
json = select_v2_json(connection)
result = {}
for server in keys:
if not server in json:
v2_json = convert_v1_to_v2(
server, valid_until, keys[server], certificates[server]
)
v2_json = sign_json(v2_json, server_name, signing_key)
result[server] = v2_json
yaml.safe_dump(result, sys.stdout, default_flow_style=False)
rows = list(
row for server, json in result.items()
for row in rows_v2(server, json)
)
cursor = connection.cursor()
cursor.executemany(
"INSERT INTO server_keys_json ("
" server_name, key_id, from_server,"
" ts_added_ms, ts_valid_until_ms, key_json"
") VALUES (%s, %s, %s, %s, %s, %s)",
rows
)
connection.commit()
if __name__ == '__main__':
main()

10
scripts/port_from_sqlite_to_postgres.py Executable file → Normal file
View File

@@ -1,4 +1,3 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
@@ -106,7 +105,7 @@ class Store(object):
try:
txn = conn.cursor()
return func(
LoggingTransaction(txn, desc, self.database_engine, []),
LoggingTransaction(txn, desc, self.database_engine),
*args, **kwargs
)
except self.database_engine.module.DatabaseError as e:
@@ -378,7 +377,9 @@ class Porter(object):
for i, row in enumerate(rows):
rows[i] = tuple(
conv(j, col)
self.postgres_store.database_engine.encode_parameter(
conv(j, col)
)
for j, col in enumerate(row)
if j > 0
)
@@ -723,9 +724,6 @@ if __name__ == "__main__":
postgres_config = yaml.safe_load(args.postgres_config)
if "database" in postgres_config:
postgres_config = postgres_config["database"]
if "name" not in postgres_config:
sys.stderr.write("Malformed database config: no 'name'")
sys.exit(2)

2
scripts/upgrade_db_to_v0.6.0.py Executable file → Normal file
View File

@@ -1,4 +1,4 @@
#!/usr/bin/env python
from synapse.storage import SCHEMA_VERSION, read_schema
from synapse.storage._base import SQLBaseStore
from synapse.storage.signatures import SignatureStore

View File

@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import os
from setuptools import setup, find_packages
@@ -56,5 +55,5 @@ setup(
include_package_data=True,
zip_safe=False,
long_description=long_description,
scripts=["synctl"] + glob.glob("scripts/*"),
scripts=["synctl", "register_new_matrix_user"],
)

View File

@@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
__version__ = "0.9.2"
__version__ = "0.8.1-r4"

View File

@@ -20,6 +20,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor
from synapse.types import UserID, ClientInfo
import logging
@@ -64,10 +65,7 @@ class Auth(object):
if event.type == EventTypes.Aliases:
return True
logger.debug(
"Auth events: %s",
[a.event_id for a in auth_events.values()]
)
logger.debug("Auth events: %s", auth_events)
if event.type == EventTypes.Member:
allowed = self.is_membership_change_allowed(
@@ -362,7 +360,7 @@ class Auth(object):
default=[""]
)[0]
if user and access_token and ip_addr:
self.store.insert_client_ip(
yield self.store.insert_client_ip(
user=user,
access_token=access_token,
device_id=user_info["device_id"],
@@ -426,6 +424,8 @@ class Auth(object):
@defer.inlineCallbacks
def add_auth_events(self, builder, context):
yield run_on_reactor()
auth_ids = self.compute_auth_events(builder, context.current_state)
auth_events_entries = yield self.store.add_event_hashes(

View File

@@ -32,9 +32,9 @@ from synapse.server import HomeServer
from twisted.internet import reactor
from twisted.application import service
from twisted.enterprise import adbapi
from twisted.web.resource import Resource, EncodingResourceWrapper
from twisted.web.resource import Resource
from twisted.web.static import File
from twisted.web.server import Site, GzipEncoderFactory
from twisted.web.server import Site
from twisted.web.http import proxiedLogFormatter, combinedLogFormatter
from synapse.http.server import JsonResource, RootRedirect
from synapse.rest.media.v0.content_repository import ContentRepoResource
@@ -54,8 +54,6 @@ from synapse.rest.client.v1 import ClientV1RestResource
from synapse.rest.client.v2_alpha import ClientV2AlphaRestResource
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse import events
from daemonize import Daemonize
import twisted.manhole.telnet
@@ -71,32 +69,16 @@ import subprocess
logger = logging.getLogger("synapse.app.homeserver")
class GzipFile(File):
def getChild(self, path, request):
child = File.getChild(self, path, request)
return EncodingResourceWrapper(child, [GzipEncoderFactory()])
def gz_wrap(r):
return EncodingResourceWrapper(r, [GzipEncoderFactory()])
class SynapseHomeServer(HomeServer):
def build_http_client(self):
return MatrixFederationHttpClient(self)
def build_resource_for_client(self):
res = ClientV1RestResource(self)
if self.config.gzip_responses:
res = gz_wrap(res)
return res
return ClientV1RestResource(self)
def build_resource_for_client_v2_alpha(self):
res = ClientV2AlphaRestResource(self)
if self.config.gzip_responses:
res = gz_wrap(res)
return res
return ClientV2AlphaRestResource(self)
def build_resource_for_federation(self):
return JsonResource(self)
@@ -105,16 +87,9 @@ class SynapseHomeServer(HomeServer):
import syweb
syweb_path = os.path.dirname(syweb.__file__)
webclient_path = os.path.join(syweb_path, "webclient")
# GZip is disabled here due to
# https://twistedmatrix.com/trac/ticket/7678
# (It can stay enabled for the API resources: they call
# write() with the whole body and then finish() straight
# after and so do not trigger the bug.
# return GzipFile(webclient_path) # TODO configurable?
return File(webclient_path) # TODO configurable?
def build_resource_for_static_content(self):
# This is old and should go away: not going to bother adding gzip
return File("static")
def build_resource_for_content_repo(self):
@@ -285,12 +260,9 @@ class SynapseHomeServer(HomeServer):
config,
metrics_resource,
),
interface=config.metrics_bind_host,
)
logger.info(
"Metrics now running on %s port %d",
config.metrics_bind_host, config.metrics_port,
interface="127.0.0.1",
)
logger.info("Metrics now running on 127.0.0.1 port %d", config.metrics_port)
def run_startup_checks(self, db_conn, database_engine):
all_users_native = are_all_users_on_domain(
@@ -423,8 +395,6 @@ def setup(config_options):
logger.info("Server hostname: %s", config.server_name)
logger.info("Server version: %s", version_string)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
if re.search(":[0-9]+$", config.server_name):
domain_with_port = config.server_name
else:
@@ -439,6 +409,7 @@ def setup(config_options):
config.server_name,
domain_with_port=domain_with_port,
upload_dir=os.path.abspath("uploads"),
db_name=config.database_path,
db_config=config.database_config,
tls_context_factory=tls_context_factory,
config=config,
@@ -451,7 +422,9 @@ def setup(config_options):
redirect_root_to_web_client=True,
)
logger.info("Preparing database: %r...", config.database_config)
db_name = hs.get_db_name()
logger.info("Preparing database: %s...", db_name)
try:
db_conn = database_engine.module.connect(
@@ -473,7 +446,7 @@ def setup(config_options):
)
sys.exit(1)
logger.info("Database prepared in %r.", config.database_config)
logger.info("Database prepared in %s.", db_name)
if config.manhole:
f = twisted.manhole.telnet.ShellFactory()
@@ -526,31 +499,11 @@ class SynapseSite(Site):
def run(hs):
PROFILE_SYNAPSE = False
if PROFILE_SYNAPSE:
def profile(func):
from cProfile import Profile
from threading import current_thread
def profiled(*args, **kargs):
profile = Profile()
profile.enable()
func(*args, **kargs)
profile.disable()
ident = current_thread().ident
profile.dump_stats("/tmp/%s.%s.%i.pstat" % (
hs.hostname, func.__name__, ident
))
return profiled
from twisted.python.threadpool import ThreadPool
ThreadPool._worker = profile(ThreadPool._worker)
reactor.run = profile(reactor.run)
def in_thread():
with LoggingContext("run"):
change_resource_limit(hs.config.soft_file_limit)
reactor.run()
if hs.config.daemonize:

View File

@@ -18,33 +18,29 @@ import sys
import os
import subprocess
import signal
import yaml
SYNAPSE = ["python", "-B", "-m", "synapse.app.homeserver"]
CONFIGFILE = "homeserver.yaml"
PIDFILE = "homeserver.pid"
GREEN = "\x1b[1;32m"
NORMAL = "\x1b[m"
if not os.path.exists(CONFIGFILE):
sys.stderr.write(
"No config file found\n"
"To generate a config file, run '%s -c %s --generate-config"
" --server-name=<server name>'\n" % (
" ".join(SYNAPSE), CONFIGFILE
)
)
sys.exit(1)
CONFIG = yaml.load(open(CONFIGFILE))
PIDFILE = CONFIG["pid_file"]
def start():
if not os.path.exists(CONFIGFILE):
sys.stderr.write(
"No config file found\n"
"To generate a config file, run '%s -c %s --generate-config"
" --server-name=<server name>'\n" % (
" ".join(SYNAPSE), CONFIGFILE
)
)
sys.exit(1)
print "Starting ...",
args = SYNAPSE
args.extend(["--daemonize", "-c", CONFIGFILE])
args.extend(["--daemonize", "-c", CONFIGFILE, "--pid-file", PIDFILE])
subprocess.check_call(args)
print GREEN + "started" + NORMAL

View File

@@ -148,8 +148,8 @@ class ApplicationService(object):
and self.is_interested_in_user(event.state_key)):
return True
# check joined member events
for user_id in member_list:
if self.is_interested_in_user(user_id):
for member in member_list:
if self.is_interested_in_user(member.state_key):
return True
return False
@@ -173,7 +173,7 @@ class ApplicationService(object):
restrict_to(str): The namespace to restrict regex tests to.
aliases_for_event(list): A list of all the known room aliases for
this event.
member_list(list): A list of all joined user_ids in this room.
member_list(list): A list of all joined room members in this room.
Returns:
bool: True if this service would like to know about this event.
"""

View File

@@ -14,10 +14,9 @@
# limitations under the License.
import argparse
import sys
import os
import yaml
import sys
from textwrap import dedent
class ConfigError(Exception):
@@ -25,35 +24,18 @@ class ConfigError(Exception):
class Config(object):
def __init__(self, args):
pass
@staticmethod
def parse_size(value):
if isinstance(value, int) or isinstance(value, long):
return value
def parse_size(string):
sizes = {"K": 1024, "M": 1024 * 1024}
size = 1
suffix = value[-1]
suffix = string[-1]
if suffix in sizes:
value = value[:-1]
string = string[:-1]
size = sizes[suffix]
return int(value) * size
@staticmethod
def parse_duration(value):
if isinstance(value, int) or isinstance(value, long):
return value
second = 1000
hour = 60 * 60 * second
day = 24 * hour
week = 7 * day
year = 365 * day
sizes = {"s": second, "h": hour, "d": day, "w": week, "y": year}
size = 1
suffix = value[-1]
if suffix in sizes:
value = value[:-1]
size = sizes[suffix]
return int(value) * size
return int(string) * size
@staticmethod
def abspath(file_path):
@@ -95,6 +77,17 @@ class Config(object):
with open(file_path) as file_stream:
return file_stream.read()
@classmethod
def read_yaml_file(cls, file_path, config_name):
cls.check_file(file_path, config_name)
with open(file_path) as file_stream:
try:
return yaml.load(file_stream)
except:
raise ConfigError(
"Error parsing yaml in file %r" % (file_path,)
)
@staticmethod
def default_path(name):
return os.path.abspath(os.path.join(os.path.curdir, name))
@@ -104,130 +97,84 @@ class Config(object):
with open(file_path) as file_stream:
return yaml.load(file_stream)
def invoke_all(self, name, *args, **kargs):
results = []
for cls in type(self).mro():
if name in cls.__dict__:
results.append(getattr(cls, name)(self, *args, **kargs))
return results
@classmethod
def add_arguments(cls, parser):
pass
def generate_config(self, config_dir_path, server_name):
default_config = "# vim:ft=yaml\n"
default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all(
"default_config", config_dir_path, server_name
))
config = yaml.load(default_config)
return default_config, config
@classmethod
def generate_config(cls, args, config_dir_path):
pass
@classmethod
def load_config(cls, description, argv, generate_section=None):
obj = cls()
config_parser = argparse.ArgumentParser(add_help=False)
config_parser.add_argument(
"-c", "--config-path",
action="append",
metavar="CONFIG_FILE",
help="Specify config file"
)
config_parser.add_argument(
"--generate-config",
action="store_true",
help="Generate a config file for the server name"
)
config_parser.add_argument(
"-H", "--server-name",
help="The server name to generate a config file for"
help="Generate config file"
)
config_args, remaining_args = config_parser.parse_known_args(argv)
if config_args.generate_config:
if not config_args.config_path:
config_parser.error(
"Must supply a config file.\nA config file can be automatically"
" generated using \"--generate-config -h SERVER_NAME"
" -c CONFIG-FILE\""
"Must specify where to generate the config file"
)
config_dir_path = os.path.dirname(config_args.config_path[0])
config_dir_path = os.path.abspath(config_dir_path)
server_name = config_args.server_name
if not server_name:
print "Must specify a server_name to a generate config for."
sys.exit(1)
(config_path,) = config_args.config_path
if not os.path.exists(config_dir_path):
os.makedirs(config_dir_path)
if os.path.exists(config_path):
print "Config file %r already exists" % (config_path,)
yaml_config = cls.read_config_file(config_path)
yaml_name = yaml_config["server_name"]
if server_name != yaml_name:
print (
"Config file %r has a different server_name: "
" %r != %r" % (config_path, server_name, yaml_name)
)
sys.exit(1)
config_bytes, config = obj.generate_config(
config_dir_path, server_name
)
config.update(yaml_config)
print "Generating any missing keys for %r" % (server_name,)
obj.invoke_all("generate_files", config)
sys.exit(0)
with open(config_path, "wb") as config_file:
config_bytes, config = obj.generate_config(
config_dir_path, server_name
)
obj.invoke_all("generate_files", config)
config_file.write(config_bytes)
print (
"A config file has been generated in %s for server name"
" '%s' with corresponding SSL keys and self-signed"
" certificates. Please review this file and customise it to"
" your needs."
) % (config_path, server_name)
print (
"If this server name is incorrect, you will need to regenerate"
" the SSL certificates"
)
sys.exit(0)
config_dir_path = os.path.dirname(config_args.config_path)
if os.path.exists(config_args.config_path):
defaults = cls.read_config_file(config_args.config_path)
else:
defaults = {}
else:
if config_args.config_path:
defaults = cls.read_config_file(config_args.config_path)
else:
defaults = {}
parser = argparse.ArgumentParser(
parents=[config_parser],
description=description,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
cls.add_arguments(parser)
parser.set_defaults(**defaults)
obj.invoke_all("add_arguments", parser)
args = parser.parse_args(remaining_args)
if not config_args.config_path:
config_parser.error(
"Must supply a config file.\nA config file can be automatically"
" generated using \"--generate-config -h SERVER_NAME"
" -c CONFIG-FILE\""
if config_args.generate_config:
config_dir_path = os.path.dirname(config_args.config_path)
config_dir_path = os.path.abspath(config_dir_path)
if not os.path.exists(config_dir_path):
os.makedirs(config_dir_path)
cls.generate_config(args, config_dir_path)
config = {}
for key, value in vars(args).items():
if (key not in set(["config_path", "generate_config"])
and value is not None):
config[key] = value
with open(config_args.config_path, "w") as config_file:
# TODO(mark/paul) We might want to output emacs-style mode
# markers as well as vim-style mode markers into the file,
# to further hint to people this is a YAML file.
config_file.write("# vim:ft=yaml\n")
yaml.dump(config, config_file, default_flow_style=False)
print (
"A config file has been generated in %s for server name"
" '%s' with corresponding SSL keys and self-signed"
" certificates. Please review this file and customise it to"
" your needs."
) % (
config_args.config_path, config['server_name']
)
print (
"If this server name is incorrect, you will need to regenerate"
" the SSL certificates"
)
sys.exit(0)
config_dir_path = os.path.dirname(config_args.config_path[0])
config_dir_path = os.path.abspath(config_dir_path)
specified_config = {}
for config_path in config_args.config_path:
yaml_config = cls.read_config_file(config_path)
specified_config.update(yaml_config)
server_name = specified_config["server_name"]
_, config = obj.generate_config(config_dir_path, server_name)
config.pop("log_config")
config.update(specified_config)
obj.invoke_all("read_config", config)
obj.invoke_all("read_arguments", args)
return obj
return cls(args)

View File

@@ -17,11 +17,15 @@ from ._base import Config
class AppServiceConfig(Config):
def read_config(self, config):
self.app_service_config_files = config.get("app_service_config_files", [])
def __init__(self, args):
super(AppServiceConfig, self).__init__(args)
self.app_service_config_files = args.app_service_config_files
def default_config(cls, config_dir_path, server_name):
return """\
# A list of application service config file to use
app_service_config_files: []
"""
@classmethod
def add_arguments(cls, parser):
super(AppServiceConfig, cls).add_arguments(parser)
group = parser.add_argument_group("appservice")
group.add_argument(
"--app-service-config-files", type=str, nargs='+',
help="A list of application service config files to use."
)

View File

@@ -17,39 +17,42 @@ from ._base import Config
class CaptchaConfig(Config):
def read_config(self, config):
self.recaptcha_private_key = config["recaptcha_private_key"]
self.recaptcha_public_key = config["recaptcha_public_key"]
self.enable_registration_captcha = config["enable_registration_captcha"]
def __init__(self, args):
super(CaptchaConfig, self).__init__(args)
self.recaptcha_private_key = args.recaptcha_private_key
self.recaptcha_public_key = args.recaptcha_public_key
self.enable_registration_captcha = args.enable_registration_captcha
# XXX: This is used for more than just captcha
self.captcha_ip_origin_is_x_forwarded = (
config["captcha_ip_origin_is_x_forwarded"]
args.captcha_ip_origin_is_x_forwarded
)
self.captcha_bypass_secret = config.get("captcha_bypass_secret")
self.recaptcha_siteverify_api = config["recaptcha_siteverify_api"]
self.captcha_bypass_secret = args.captcha_bypass_secret
def default_config(self, config_dir_path, server_name):
return """\
## Captcha ##
# This Home Server's ReCAPTCHA public key.
recaptcha_private_key: "YOUR_PUBLIC_KEY"
# This Home Server's ReCAPTCHA private key.
recaptcha_public_key: "YOUR_PRIVATE_KEY"
# Enables ReCaptcha checks when registering, preventing signup
# unless a captcha is answered. Requires a valid ReCaptcha
# public/private key.
enable_registration_captcha: False
# When checking captchas, use the X-Forwarded-For (XFF) header
# as the client IP and not the actual client IP.
captcha_ip_origin_is_x_forwarded: False
# A secret key used to bypass the captcha test entirely.
#captcha_bypass_secret: "YOUR_SECRET_HERE"
# The API endpoint to use for verifying m.login.recaptcha responses.
recaptcha_siteverify_api: "https://www.google.com/recaptcha/api/siteverify"
"""
@classmethod
def add_arguments(cls, parser):
super(CaptchaConfig, cls).add_arguments(parser)
group = parser.add_argument_group("recaptcha")
group.add_argument(
"--recaptcha-public-key", type=str, default="YOUR_PUBLIC_KEY",
help="This Home Server's ReCAPTCHA public key."
)
group.add_argument(
"--recaptcha-private-key", type=str, default="YOUR_PRIVATE_KEY",
help="This Home Server's ReCAPTCHA private key."
)
group.add_argument(
"--enable-registration-captcha", type=bool, default=False,
help="Enables ReCaptcha checks when registering, preventing signup"
+ " unless a captcha is answered. Requires a valid ReCaptcha "
+ "public/private key."
)
group.add_argument(
"--captcha_ip_origin_is_x_forwarded", type=bool, default=False,
help="When checking captchas, use the X-Forwarded-For (XFF) header"
+ " as the client IP and not the actual client IP."
)
group.add_argument(
"--captcha_bypass_secret", type=str,
help="A secret key used to bypass the captcha test entirely."
)

View File

@@ -14,21 +14,28 @@
# limitations under the License.
from ._base import Config
import os
import yaml
class DatabaseConfig(Config):
def __init__(self, args):
super(DatabaseConfig, self).__init__(args)
if args.database_path == ":memory:":
self.database_path = ":memory:"
else:
self.database_path = self.abspath(args.database_path)
self.event_cache_size = self.parse_size(args.event_cache_size)
def read_config(self, config):
self.event_cache_size = self.parse_size(
config.get("event_cache_size", "10K")
)
self.database_config = config.get("database")
if self.database_config is None:
if args.database_config:
with open(args.database_config) as f:
self.database_config = yaml.safe_load(f)
else:
self.database_config = {
"name": "sqlite3",
"args": {},
"args": {
"database": self.database_path,
},
}
name = self.database_config.get("name", None)
@@ -43,37 +50,24 @@ class DatabaseConfig(Config):
else:
raise RuntimeError("Unsupported database type '%s'" % (name,))
self.set_databasepath(config.get("database_path"))
def default_config(self, config, config_dir_path):
database_path = self.abspath("homeserver.db")
return """\
# Database configuration
database:
# The database engine name
name: "sqlite3"
# Arguments to pass to the engine
args:
# Path to the database
database: "%(database_path)s"
# Number of events to cache in memory.
event_cache_size: "10K"
""" % locals()
def read_arguments(self, args):
self.set_databasepath(args.database_path)
def set_databasepath(self, database_path):
if database_path != ":memory:":
database_path = self.abspath(database_path)
if self.database_config.get("name", None) == "sqlite3":
if database_path is not None:
self.database_config["args"]["database"] = database_path
def add_arguments(self, parser):
@classmethod
def add_arguments(cls, parser):
super(DatabaseConfig, cls).add_arguments(parser)
db_group = parser.add_argument_group("database")
db_group.add_argument(
"-d", "--database-path", metavar="SQLITE_DATABASE_PATH",
help="The path to a sqlite database to use."
"-d", "--database-path", default="homeserver.db",
metavar="SQLITE_DATABASE_PATH", help="The database name."
)
db_group.add_argument(
"--event-cache-size", default="100K",
help="Number of events to cache in memory."
)
db_group.add_argument(
"--database-config", default=None,
help="Location of the database configuration file."
)
@classmethod
def generate_config(cls, args, config_dir_path):
super(DatabaseConfig, cls).generate_config(args, config_dir_path)
args.database_path = os.path.abspath(args.database_path)

View File

@@ -36,6 +36,4 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
if __name__ == '__main__':
import sys
sys.stdout.write(
HomeServerConfig().generate_config(sys.argv[1], sys.argv[2])[0]
)
HomeServerConfig.load_config("Generate config", sys.argv[1:], "HomeServer")

View File

@@ -20,58 +20,48 @@ from syutil.crypto.signing_key import (
is_signing_algorithm_supported, decode_verify_key_bytes
)
from syutil.base64util import decode_base64
from synapse.util.stringutils import random_string
class KeyConfig(Config):
def read_config(self, config):
self.signing_key = self.read_signing_key(config["signing_key_path"])
def __init__(self, args):
super(KeyConfig, self).__init__(args)
self.signing_key = self.read_signing_key(args.signing_key_path)
self.old_signing_keys = self.read_old_signing_keys(
config["old_signing_keys"]
)
self.key_refresh_interval = self.parse_duration(
config["key_refresh_interval"]
args.old_signing_key_path
)
self.key_refresh_interval = args.key_refresh_interval
self.perspectives = self.read_perspectives(
config["perspectives"]
args.perspectives_config_path
)
def default_config(self, config_dir_path, server_name):
base_key_name = os.path.join(config_dir_path, server_name)
return """\
## Signing Keys ##
@classmethod
def add_arguments(cls, parser):
super(KeyConfig, cls).add_arguments(parser)
key_group = parser.add_argument_group("keys")
key_group.add_argument("--signing-key-path",
help="The signing key to sign messages with")
key_group.add_argument("--old-signing-key-path",
help="The keys that the server used to sign"
" sign messages with but won't use"
" to sign new messages. E.g. it has"
" lost its private key")
key_group.add_argument("--key-refresh-interval",
default=24 * 60 * 60 * 1000, # 1 Day
help="How long a key response is valid for."
" Used to set the exipiry in /key/v2/."
" Controls how frequently servers will"
" query what keys are still valid")
key_group.add_argument("--perspectives-config-path",
help="The trusted servers to download signing"
" keys from")
# Path to the signing key to sign messages with
signing_key_path: "%(base_key_name)s.signing.key"
# The keys that the server used to sign messages with but won't use
# to sign new messages. E.g. it has lost its private key
old_signing_keys: {}
# "ed25519:auto":
# # Base64 encoded public key
# key: "The public part of your old signing key."
# # Millisecond POSIX timestamp when the key expired.
# expired_ts: 123456789123
# How long key response published by this server is valid for.
# Used to set the valid_until_ts in /key/v2 APIs.
# Determines how quickly servers will query to check which keys
# are still valid.
key_refresh_interval: "1d" # 1 Day.
# The trusted servers to download signing keys from.
perspectives:
servers:
"matrix.org":
verify_keys:
"ed25519:auto":
key: "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
""" % locals()
def read_perspectives(self, perspectives_config):
def read_perspectives(self, perspectives_config_path):
config = self.read_yaml_file(
perspectives_config_path, "perspectives_config_path"
)
servers = {}
for server_name, server_config in perspectives_config["servers"].items():
for server_name, server_config in config["servers"].items():
for key_id, key_data in server_config["verify_keys"].items():
if is_signing_algorithm_supported(key_id):
key_base64 = key_data["key"]
@@ -92,42 +82,66 @@ class KeyConfig(Config):
" Try running again with --generate-config"
)
def read_old_signing_keys(self, old_signing_keys):
keys = {}
for key_id, key_data in old_signing_keys.items():
if is_signing_algorithm_supported(key_id):
key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
verify_key.expired_ts = key_data["expired_ts"]
keys[key_id] = verify_key
else:
raise ConfigError(
"Unsupported signing algorithm for old key: %r" % (key_id,)
)
return keys
def read_old_signing_keys(self, old_signing_key_path):
old_signing_keys = self.read_file(
old_signing_key_path, "old_signing_key"
)
try:
return syutil.crypto.signing_key.read_old_signing_keys(
old_signing_keys.splitlines(True)
)
except Exception:
raise ConfigError(
"Error reading old signing keys."
)
def generate_files(self, config):
signing_key_path = config["signing_key_path"]
if not os.path.exists(signing_key_path):
with open(signing_key_path, "w") as signing_key_file:
key_id = "a_" + random_string(4)
@classmethod
def generate_config(cls, args, config_dir_path):
super(KeyConfig, cls).generate_config(args, config_dir_path)
base_key_name = os.path.join(config_dir_path, args.server_name)
args.pid_file = os.path.abspath(args.pid_file)
if not args.signing_key_path:
args.signing_key_path = base_key_name + ".signing.key"
if not os.path.exists(args.signing_key_path):
with open(args.signing_key_path, "w") as signing_key_file:
syutil.crypto.signing_key.write_signing_keys(
signing_key_file,
(syutil.crypto.signing_key.generate_signing_key(key_id),),
(syutil.crypto.signing_key.generate_signing_key("auto"),),
)
else:
signing_keys = self.read_file(signing_key_path, "signing_key")
signing_keys = cls.read_file(args.signing_key_path, "signing_key")
if len(signing_keys.split("\n")[0].split()) == 1:
# handle keys in the old format.
key_id = "a_" + random_string(4)
key = syutil.crypto.signing_key.decode_signing_key_base64(
syutil.crypto.signing_key.NACL_ED25519,
key_id,
"auto",
signing_keys.split("\n")[0]
)
with open(signing_key_path, "w") as signing_key_file:
with open(args.signing_key_path, "w") as signing_key_file:
syutil.crypto.signing_key.write_signing_keys(
signing_key_file,
(key,),
)
if not args.old_signing_key_path:
args.old_signing_key_path = base_key_name + ".old.signing.keys"
if not os.path.exists(args.old_signing_key_path):
with open(args.old_signing_key_path, "w"):
pass
if not args.perspectives_config_path:
args.perspectives_config_path = base_key_name + ".perspectives"
if not os.path.exists(args.perspectives_config_path):
with open(args.perspectives_config_path, "w") as perspectives_file:
perspectives_file.write(
'servers:\n'
' matrix.org:\n'
' verify_keys:\n'
' "ed25519:auto":\n'
' key: "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"\n'
)

View File

@@ -19,88 +19,25 @@ from twisted.python.log import PythonLoggingObserver
import logging
import logging.config
import yaml
from string import Template
import os
DEFAULT_LOG_CONFIG = Template("""
version: 1
formatters:
precise:
format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s\
- %(message)s'
filters:
context:
(): synapse.util.logcontext.LoggingContextFilter
request: ""
handlers:
file:
class: logging.handlers.RotatingFileHandler
formatter: precise
filename: ${log_file}
maxBytes: 104857600
backupCount: 10
filters: [context]
level: INFO
console:
class: logging.StreamHandler
formatter: precise
loggers:
synapse:
level: INFO
synapse.storage.SQL:
level: INFO
root:
level: INFO
handlers: [file, console]
""")
class LoggingConfig(Config):
def __init__(self, args):
super(LoggingConfig, self).__init__(args)
self.verbosity = int(args.verbose) if args.verbose else None
self.log_config = self.abspath(args.log_config)
self.log_file = self.abspath(args.log_file)
def read_config(self, config):
self.verbosity = config.get("verbose", 0)
self.log_config = self.abspath(config.get("log_config"))
self.log_file = self.abspath(config.get("log_file"))
def default_config(self, config_dir_path, server_name):
log_file = self.abspath("homeserver.log")
log_config = self.abspath(
os.path.join(config_dir_path, server_name + ".log.config")
)
return """
# Logging verbosity level.
verbose: 0
# File to write logging to
log_file: "%(log_file)s"
# A yaml python logging config file
log_config: "%(log_config)s"
""" % locals()
def read_arguments(self, args):
if args.verbose is not None:
self.verbosity = args.verbose
if args.log_config is not None:
self.log_config = args.log_config
if args.log_file is not None:
self.log_file = args.log_file
@classmethod
def add_arguments(cls, parser):
super(LoggingConfig, cls).add_arguments(parser)
logging_group = parser.add_argument_group("logging")
logging_group.add_argument(
'-v', '--verbose', dest="verbose", action='count',
help="The verbosity level."
)
logging_group.add_argument(
'-f', '--log-file', dest="log_file",
'-f', '--log-file', dest="log_file", default="homeserver.log",
help="File to log to."
)
logging_group.add_argument(
@@ -108,14 +45,6 @@ class LoggingConfig(Config):
help="Python logging config file"
)
def generate_files(self, config):
log_config = config.get("log_config")
if log_config and not os.path.exists(log_config):
with open(log_config, "wb") as log_config_file:
log_config_file.write(
DEFAULT_LOG_CONFIG.substitute(log_file=config["log_file"])
)
def setup_logging(self):
log_format = (
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"

View File

@@ -17,21 +17,20 @@ from ._base import Config
class MetricsConfig(Config):
def read_config(self, config):
self.enable_metrics = config["enable_metrics"]
self.metrics_port = config.get("metrics_port")
self.metrics_bind_host = config.get("metrics_bind_host", "127.0.0.1")
def __init__(self, args):
super(MetricsConfig, self).__init__(args)
self.enable_metrics = args.enable_metrics
self.metrics_port = args.metrics_port
def default_config(self, config_dir_path, server_name):
return """\
## Metrics ###
# Enable collection and rendering of performance metrics
enable_metrics: False
# Separate port to accept metrics requests on
# metrics_port: 8081
# Which host to bind the metric listener to
# metrics_bind_host: 127.0.0.1
"""
@classmethod
def add_arguments(cls, parser):
super(MetricsConfig, cls).add_arguments(parser)
metrics_group = parser.add_argument_group("metrics")
metrics_group.add_argument(
'--enable-metrics', dest="enable_metrics", action="store_true",
help="Enable collection and rendering of performance metrics"
)
metrics_group.add_argument(
'--metrics-port', metavar="PORT", type=int,
help="Separate port to accept metrics requests on (on localhost)"
)

View File

@@ -17,42 +17,56 @@ from ._base import Config
class RatelimitConfig(Config):
def read_config(self, config):
self.rc_messages_per_second = config["rc_messages_per_second"]
self.rc_message_burst_count = config["rc_message_burst_count"]
def __init__(self, args):
super(RatelimitConfig, self).__init__(args)
self.rc_messages_per_second = args.rc_messages_per_second
self.rc_message_burst_count = args.rc_message_burst_count
self.federation_rc_window_size = config["federation_rc_window_size"]
self.federation_rc_sleep_limit = config["federation_rc_sleep_limit"]
self.federation_rc_sleep_delay = config["federation_rc_sleep_delay"]
self.federation_rc_reject_limit = config["federation_rc_reject_limit"]
self.federation_rc_concurrent = config["federation_rc_concurrent"]
self.federation_rc_window_size = args.federation_rc_window_size
self.federation_rc_sleep_limit = args.federation_rc_sleep_limit
self.federation_rc_sleep_delay = args.federation_rc_sleep_delay
self.federation_rc_reject_limit = args.federation_rc_reject_limit
self.federation_rc_concurrent = args.federation_rc_concurrent
def default_config(self, config_dir_path, server_name):
return """\
## Ratelimiting ##
@classmethod
def add_arguments(cls, parser):
super(RatelimitConfig, cls).add_arguments(parser)
rc_group = parser.add_argument_group("ratelimiting")
rc_group.add_argument(
"--rc-messages-per-second", type=float, default=0.2,
help="number of messages a client can send per second"
)
rc_group.add_argument(
"--rc-message-burst-count", type=float, default=10,
help="number of message a client can send before being throttled"
)
# Number of messages a client can send per second
rc_messages_per_second: 0.2
rc_group.add_argument(
"--federation-rc-window-size", type=int, default=10000,
help="The federation window size in milliseconds",
)
# Number of message a client can send before being throttled
rc_message_burst_count: 10.0
rc_group.add_argument(
"--federation-rc-sleep-limit", type=int, default=10,
help="The number of federation requests from a single server"
" in a window before the server will delay processing the"
" request.",
)
# The federation window size in milliseconds
federation_rc_window_size: 1000
rc_group.add_argument(
"--federation-rc-sleep-delay", type=int, default=500,
help="The duration in milliseconds to delay processing events from"
" remote servers by if they go over the sleep limit.",
)
# The number of federation requests from a single server in a window
# before the server will delay processing the request.
federation_rc_sleep_limit: 10
rc_group.add_argument(
"--federation-rc-reject-limit", type=int, default=50,
help="The maximum number of concurrent federation requests allowed"
" from a single server",
)
# The duration in milliseconds to delay processing events from
# remote servers by if they go over the sleep limit.
federation_rc_sleep_delay: 500
# The maximum number of concurrent federation requests allowed
# from a single server
federation_rc_reject_limit: 50
# The number of federation requests to concurrently process from a
# single server
federation_rc_concurrent: 3
"""
rc_group.add_argument(
"--federation-rc-concurrent", type=int, default=3,
help="The number of federation requests to concurrently process"
" from a single server",
)

View File

@@ -17,44 +17,45 @@ from ._base import Config
from synapse.util.stringutils import random_string_with_symbols
from distutils.util import strtobool
import distutils.util
class RegistrationConfig(Config):
def read_config(self, config):
def __init__(self, args):
super(RegistrationConfig, self).__init__(args)
# `args.enable_registration` may either be a bool or a string depending
# on if the option was given a value (e.g. --enable-registration=true
# would set `args.enable_registration` to "true" not True.)
self.disable_registration = not bool(
strtobool(str(config["enable_registration"]))
distutils.util.strtobool(str(args.enable_registration))
)
if "disable_registration" in config:
self.disable_registration = bool(
strtobool(str(config["disable_registration"]))
)
self.registration_shared_secret = args.registration_shared_secret
self.registration_shared_secret = config.get("registration_shared_secret")
def default_config(self, config_dir, server_name):
registration_shared_secret = random_string_with_symbols(50)
return """\
## Registration ##
# Enable registration for new users.
enable_registration: False
# If set, allows registration by anyone who also has the shared
# secret, even if registration is otherwise disabled.
registration_shared_secret: "%(registration_shared_secret)s"
""" % locals()
def add_arguments(self, parser):
@classmethod
def add_arguments(cls, parser):
super(RegistrationConfig, cls).add_arguments(parser)
reg_group = parser.add_argument_group("registration")
reg_group.add_argument(
"--enable-registration", action="store_true", default=None,
help="Enable registration for new users."
"--enable-registration",
const=True,
default=False,
nargs='?',
help="Enable registration for new users.",
)
reg_group.add_argument(
"--registration-shared-secret", type=str,
help="If set, allows registration by anyone who also has the shared"
" secret, even if registration is otherwise disabled.",
)
def read_arguments(self, args):
if args.enable_registration is not None:
self.disable_registration = not bool(
strtobool(str(args.enable_registration))
)
@classmethod
def generate_config(cls, args, config_dir_path):
super(RegistrationConfig, cls).generate_config(args, config_dir_path)
if args.enable_registration is None:
args.enable_registration = False
if args.registration_shared_secret is None:
args.registration_shared_secret = random_string_with_symbols(50)

View File

@@ -17,20 +17,32 @@ from ._base import Config
class ContentRepositoryConfig(Config):
def read_config(self, config):
self.max_upload_size = self.parse_size(config["max_upload_size"])
self.max_image_pixels = self.parse_size(config["max_image_pixels"])
self.media_store_path = self.ensure_directory(config["media_store_path"])
def __init__(self, args):
super(ContentRepositoryConfig, self).__init__(args)
self.max_upload_size = self.parse_size(args.max_upload_size)
self.max_image_pixels = self.parse_size(args.max_image_pixels)
self.media_store_path = self.ensure_directory(args.media_store_path)
def default_config(self, config_dir_path, server_name):
media_store = self.default_path("media_store")
return """
# Directory where uploaded images and attachments are stored.
media_store_path: "%(media_store)s"
def parse_size(self, string):
sizes = {"K": 1024, "M": 1024 * 1024}
size = 1
suffix = string[-1]
if suffix in sizes:
string = string[:-1]
size = sizes[suffix]
return int(string) * size
# The largest allowed upload size in bytes
max_upload_size: "10M"
# Maximum number of pixels that will be thumbnailed
max_image_pixels: "32M"
""" % locals()
@classmethod
def add_arguments(cls, parser):
super(ContentRepositoryConfig, cls).add_arguments(parser)
db_group = parser.add_argument_group("content_repository")
db_group.add_argument(
"--max-upload-size", default="10M"
)
db_group.add_argument(
"--media-store-path", default=cls.default_path("media_store")
)
db_group.add_argument(
"--max-image-pixels", default="32M",
help="Maximum number of pixels that will be thumbnailed"
)

View File

@@ -17,95 +17,64 @@ from ._base import Config
class ServerConfig(Config):
def __init__(self, args):
super(ServerConfig, self).__init__(args)
self.server_name = args.server_name
self.bind_port = args.bind_port
self.bind_host = args.bind_host
self.unsecure_port = args.unsecure_port
self.daemonize = args.daemonize
self.pid_file = self.abspath(args.pid_file)
self.web_client = args.web_client
self.manhole = args.manhole
self.soft_file_limit = args.soft_file_limit
def read_config(self, config):
self.server_name = config["server_name"]
self.bind_port = config["bind_port"]
self.bind_host = config["bind_host"]
self.unsecure_port = config["unsecure_port"]
self.manhole = config.get("manhole")
self.pid_file = self.abspath(config.get("pid_file"))
self.web_client = config["web_client"]
self.soft_file_limit = config["soft_file_limit"]
self.daemonize = config.get("daemonize")
self.use_frozen_dicts = config.get("use_frozen_dicts", True)
self.gzip_responses = config["gzip_responses"]
# Attempt to guess the content_addr for the v0 content repostitory
content_addr = config.get("content_addr")
if not content_addr:
host = self.server_name
if not args.content_addr:
host = args.server_name
if ':' not in host:
host = "%s:%d" % (host, self.unsecure_port)
host = "%s:%d" % (host, args.unsecure_port)
else:
host = host.split(':')[0]
host = "%s:%d" % (host, self.unsecure_port)
content_addr = "http://%s" % (host,)
host = "%s:%d" % (host, args.unsecure_port)
args.content_addr = "http://%s" % (host,)
self.content_addr = content_addr
self.content_addr = args.content_addr
def default_config(self, config_dir_path, server_name):
if ":" in server_name:
bind_port = int(server_name.split(":")[1])
unsecure_port = bind_port - 400
else:
bind_port = 8448
unsecure_port = 8008
pid_file = self.abspath("homeserver.pid")
return """\
## Server ##
# The domain name of the server, with optional explicit port.
# This is used by remote servers to connect to this server,
# e.g. matrix.org, localhost:8080, etc.
server_name: "%(server_name)s"
# The port to listen for HTTPS requests on.
# For when matrix traffic is sent directly to synapse.
bind_port: %(bind_port)s
# The port to listen for HTTP requests on.
# For when matrix traffic passes through loadbalancer that unwraps TLS.
unsecure_port: %(unsecure_port)s
# Local interface to listen on.
# The empty string will cause synapse to listen on all interfaces.
bind_host: ""
# When running as a daemon, the file to store the pid in
pid_file: %(pid_file)s
# Whether to serve a web client from the HTTP/HTTPS root resource.
web_client: True
# Set the soft limit on the number of file descriptors synapse can use
# Zero is used to indicate synapse should set the soft limit to the
# hard limit.
soft_file_limit: 0
# Turn on the twisted telnet manhole service on localhost on the given
# port.
#manhole: 9000
# Should synapse compress HTTP responses to clients that support it?
# This should be disabled if running synapse behind a load balancer
# that can do automatic compression.
gzip_responses: True
""" % locals()
def read_arguments(self, args):
if args.manhole is not None:
self.manhole = args.manhole
if args.daemonize is not None:
self.daemonize = args.daemonize
def add_arguments(self, parser):
@classmethod
def add_arguments(cls, parser):
super(ServerConfig, cls).add_arguments(parser)
server_group = parser.add_argument_group("server")
server_group.add_argument(
"-H", "--server-name", default="localhost",
help="The domain name of the server, with optional explicit port. "
"This is used by remote servers to connect to this server, "
"e.g. matrix.org, localhost:8080, etc."
)
server_group.add_argument("-p", "--bind-port", metavar="PORT",
type=int, help="https port to listen on",
default=8448)
server_group.add_argument("--unsecure-port", metavar="PORT",
type=int, help="http port to listen on",
default=8008)
server_group.add_argument("--bind-host", default="",
help="Local interface to listen on")
server_group.add_argument("-D", "--daemonize", action='store_true',
default=None,
help="Daemonize the home server")
server_group.add_argument('--pid-file', default="homeserver.pid",
help="When running as a daemon, the file to"
" store the pid in")
server_group.add_argument('--web_client', default=True, type=bool,
help="Whether or not to serve a web client")
server_group.add_argument("--manhole", metavar="PORT", dest="manhole",
type=int,
help="Turn on the twisted telnet manhole"
" service on the given port.")
server_group.add_argument("--content-addr", default=None,
help="The host and scheme to use for the "
"content repository")
server_group.add_argument("--soft-file-limit", type=int, default=0,
help="Set the soft limit on the number of "
"file descriptors synapse can use. "
"Zero is used to indicate synapse "
"should set the soft limit to the hard"
"limit.")

View File

@@ -23,44 +23,37 @@ GENERATE_DH_PARAMS = False
class TlsConfig(Config):
def read_config(self, config):
def __init__(self, args):
super(TlsConfig, self).__init__(args)
self.tls_certificate = self.read_tls_certificate(
config.get("tls_certificate_path")
args.tls_certificate_path
)
self.no_tls = config.get("no_tls", False)
self.no_tls = args.no_tls
if self.no_tls:
self.tls_private_key = None
else:
self.tls_private_key = self.read_tls_private_key(
config.get("tls_private_key_path")
args.tls_private_key_path
)
self.tls_dh_params_path = self.check_file(
config.get("tls_dh_params_path"), "tls_dh_params"
args.tls_dh_params_path, "tls_dh_params"
)
def default_config(self, config_dir_path, server_name):
base_key_name = os.path.join(config_dir_path, server_name)
tls_certificate_path = base_key_name + ".tls.crt"
tls_private_key_path = base_key_name + ".tls.key"
tls_dh_params_path = base_key_name + ".tls.dh"
return """\
# PEM encoded X509 certificate for TLS
tls_certificate_path: "%(tls_certificate_path)s"
# PEM encoded private key for TLS
tls_private_key_path: "%(tls_private_key_path)s"
# PEM dh parameters for ephemeral keys
tls_dh_params_path: "%(tls_dh_params_path)s"
# Don't bind to the https port
no_tls: False
""" % locals()
@classmethod
def add_arguments(cls, parser):
super(TlsConfig, cls).add_arguments(parser)
tls_group = parser.add_argument_group("tls")
tls_group.add_argument("--tls-certificate-path",
help="PEM encoded X509 certificate for TLS")
tls_group.add_argument("--tls-private-key-path",
help="PEM encoded private key for TLS")
tls_group.add_argument("--tls-dh-params-path",
help="PEM dh parameters for ephemeral keys")
tls_group.add_argument("--no-tls", action='store_true',
help="Don't bind to the https port.")
def read_tls_certificate(self, cert_path):
cert_pem = self.read_file(cert_path, "tls_certificate")
@@ -70,13 +63,22 @@ class TlsConfig(Config):
private_key_pem = self.read_file(private_key_path, "tls_private_key")
return crypto.load_privatekey(crypto.FILETYPE_PEM, private_key_pem)
def generate_files(self, config):
tls_certificate_path = config["tls_certificate_path"]
tls_private_key_path = config["tls_private_key_path"]
tls_dh_params_path = config["tls_dh_params_path"]
@classmethod
def generate_config(cls, args, config_dir_path):
super(TlsConfig, cls).generate_config(args, config_dir_path)
base_key_name = os.path.join(config_dir_path, args.server_name)
if not os.path.exists(tls_private_key_path):
with open(tls_private_key_path, "w") as private_key_file:
if args.tls_certificate_path is None:
args.tls_certificate_path = base_key_name + ".tls.crt"
if args.tls_private_key_path is None:
args.tls_private_key_path = base_key_name + ".tls.key"
if args.tls_dh_params_path is None:
args.tls_dh_params_path = base_key_name + ".tls.dh"
if not os.path.exists(args.tls_private_key_path):
with open(args.tls_private_key_path, "w") as private_key_file:
tls_private_key = crypto.PKey()
tls_private_key.generate_key(crypto.TYPE_RSA, 2048)
private_key_pem = crypto.dump_privatekey(
@@ -84,17 +86,17 @@ class TlsConfig(Config):
)
private_key_file.write(private_key_pem)
else:
with open(tls_private_key_path) as private_key_file:
with open(args.tls_private_key_path) as private_key_file:
private_key_pem = private_key_file.read()
tls_private_key = crypto.load_privatekey(
crypto.FILETYPE_PEM, private_key_pem
)
if not os.path.exists(tls_certificate_path):
with open(tls_certificate_path, "w") as certifcate_file:
if not os.path.exists(args.tls_certificate_path):
with open(args.tls_certificate_path, "w") as certifcate_file:
cert = crypto.X509()
subject = cert.get_subject()
subject.CN = config["server_name"]
subject.CN = args.server_name
cert.set_serial_number(1000)
cert.gmtime_adj_notBefore(0)
@@ -108,16 +110,16 @@ class TlsConfig(Config):
certifcate_file.write(cert_pem)
if not os.path.exists(tls_dh_params_path):
if not os.path.exists(args.tls_dh_params_path):
if GENERATE_DH_PARAMS:
subprocess.check_call([
"openssl", "dhparam",
"-outform", "PEM",
"-out", tls_dh_params_path,
"-out", args.tls_dh_params_path,
"2048"
])
else:
with open(tls_dh_params_path, "w") as dh_params_file:
with open(args.tls_dh_params_path, "w") as dh_params_file:
dh_params_file.write(
"2048-bit DH parameters taken from rfc3526\n"
"-----BEGIN DH PARAMETERS-----\n"

View File

@@ -17,21 +17,28 @@ from ._base import Config
class VoipConfig(Config):
def read_config(self, config):
self.turn_uris = config.get("turn_uris", [])
self.turn_shared_secret = config["turn_shared_secret"]
self.turn_user_lifetime = self.parse_duration(config["turn_user_lifetime"])
def __init__(self, args):
super(VoipConfig, self).__init__(args)
self.turn_uris = args.turn_uris
self.turn_shared_secret = args.turn_shared_secret
self.turn_user_lifetime = args.turn_user_lifetime
def default_config(self, config_dir_path, server_name):
return """\
## Turn ##
# The public URIs of the TURN server to give to clients
turn_uris: []
# The shared secret used to compute passwords for the TURN server
turn_shared_secret: "YOUR_SHARED_SECRET"
# How long generated TURN credentials last
turn_user_lifetime: "1h"
"""
@classmethod
def add_arguments(cls, parser):
super(VoipConfig, cls).add_arguments(parser)
group = parser.add_argument_group("voip")
group.add_argument(
"--turn-uris", type=str, default=None, action='append',
help="The public URIs of the TURN server to give to clients"
)
group.add_argument(
"--turn-shared-secret", type=str, default=None,
help=(
"The shared secret used to compute passwords for the TURN"
" server"
)
)
group.add_argument(
"--turn-user-lifetime", type=int, default=(1000 * 60 * 60),
help="How long generated TURN credentials last, in ms"
)

View File

@@ -18,9 +18,7 @@ from twisted.web.http import HTTPClient
from twisted.internet.protocol import Factory
from twisted.internet import defer, reactor
from synapse.http.endpoint import matrix_federation_endpoint
from synapse.util.logcontext import (
preserve_context_over_fn, preserve_context_over_deferred
)
from synapse.util.logcontext import PreserveLoggingContext
import simplejson as json
import logging
@@ -42,14 +40,11 @@ def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
for i in range(5):
try:
protocol = yield preserve_context_over_fn(
endpoint.connect, factory
)
server_response, server_certificate = yield preserve_context_over_deferred(
protocol.remote_key
)
defer.returnValue((server_response, server_certificate))
return
with PreserveLoggingContext():
protocol = yield endpoint.connect(factory)
server_response, server_certificate = yield protocol.remote_key
defer.returnValue((server_response, server_certificate))
return
except SynapseKeyClientError as e:
logger.exception("Error getting key for %r" % (server_name,))
if e.status.startswith("4"):

View File

@@ -26,7 +26,7 @@ from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter
from synapse.util.async import ObservableDeferred
from synapse.util.async import create_observer
from OpenSSL import crypto
@@ -111,10 +111,6 @@ class Keyring(object):
if download is None:
download = self._get_server_verify_key_impl(server_name, key_ids)
download = ObservableDeferred(
download,
consumeErrors=True
)
self.key_downloads[server_name] = download
@download.addBoth
@@ -122,31 +118,30 @@ class Keyring(object):
del self.key_downloads[server_name]
return ret
r = yield download.observe()
r = yield create_observer(download)
defer.returnValue(r)
@defer.inlineCallbacks
def _get_server_verify_key_impl(self, server_name, key_ids):
keys = None
@defer.inlineCallbacks
def get_key(perspective_name, perspective_keys):
try:
result = yield self.get_server_verify_key_v2_indirect(
server_name, key_ids, perspective_name, perspective_keys
)
defer.returnValue(result)
except Exception as e:
logging.info(
"Unable to getting key %r for %r from %r: %s %s",
key_ids, server_name, perspective_name,
type(e).__name__, str(e.message),
)
perspective_results = []
for perspective_name, perspective_keys in self.perspective_servers.items():
@defer.inlineCallbacks
def get_key():
try:
result = yield self.get_server_verify_key_v2_indirect(
server_name, key_ids, perspective_name, perspective_keys
)
defer.returnValue(result)
except:
logging.info(
"Unable to getting key %r for %r from %r",
key_ids, server_name, perspective_name,
)
perspective_results.append(get_key())
perspective_results = yield defer.gatherResults([
get_key(p_name, p_keys)
for p_name, p_keys in self.perspective_servers.items()
])
perspective_results = yield defer.gatherResults(perspective_results)
for results in perspective_results:
if results is not None:
@@ -159,22 +154,17 @@ class Keyring(object):
)
with limiter:
if not keys:
if keys is None:
try:
keys = yield self.get_server_verify_key_v2_direct(
server_name, key_ids
)
except Exception as e:
logging.info(
"Unable to getting key %r for %r directly: %s %s",
key_ids, server_name,
type(e).__name__, str(e.message),
)
except:
pass
if not keys:
keys = yield self.get_server_verify_key_v1_direct(
server_name, key_ids
)
keys = yield self.get_server_verify_key_v1_direct(
server_name, key_ids
)
for key_id in key_ids:
if key_id in keys:
@@ -194,7 +184,7 @@ class Keyring(object):
# TODO(mark): Set the minimum_valid_until_ts to that needed by
# the events being validated or the current time if validating
# an incoming request.
query_response = yield self.client.post_json(
responses = yield self.client.post_json(
destination=perspective_name,
path=b"/_matrix/key/v2/query",
data={
@@ -210,8 +200,6 @@ class Keyring(object):
keys = {}
responses = query_response["server_keys"]
for response in responses:
if (u"signatures" not in response
or perspective_name not in response[u"signatures"]):
@@ -335,7 +323,7 @@ class Keyring(object):
verify_key.time_added = time_now_ms
old_verify_keys[key_id] = verify_key
for key_id in response_json["signatures"].get(server_name, {}):
for key_id in response_json["signatures"][server_name]:
if key_id not in response_json["verify_keys"]:
raise ValueError(
"Key response must include verification keys for all"

View File

@@ -16,12 +16,6 @@
from synapse.util.frozenutils import freeze
# Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents
# bugs where we accidentally share e.g. signature dicts. However, converting
# a dict to frozen_dicts is expensive.
USE_FROZEN_DICTS = True
class _EventInternalMetadata(object):
def __init__(self, internal_metadata_dict):
self.__dict__ = dict(internal_metadata_dict)
@@ -128,10 +122,7 @@ class FrozenEvent(EventBase):
unsigned = dict(event_dict.pop("unsigned", {}))
if USE_FROZEN_DICTS:
frozen_dict = freeze(event_dict)
else:
frozen_dict = event_dict
frozen_dict = freeze(event_dict)
super(FrozenEvent, self).__init__(
frozen_dict,

View File

@@ -18,12 +18,12 @@ from twisted.internet import defer
from synapse.events.utils import prune_event
from syutil.jsonutil import encode_canonical_json
from synapse.crypto.event_signing import check_event_content_hash
from synapse.api.errors import SynapseError
from synapse.util import unwrapFirstError
import logging
@@ -78,7 +78,6 @@ class FederationBase(object):
destinations=[pdu.origin],
event_id=pdu.event_id,
outlier=outlier,
timeout=10000,
)
if new_pdu:
@@ -95,7 +94,7 @@ class FederationBase(object):
yield defer.gatherResults(
[do(pdu) for pdu in pdus],
consumeErrors=True
).addErrback(unwrapFirstError)
)
defer.returnValue(signed_pdus)
@@ -118,15 +117,16 @@ class FederationBase(object):
)
except SynapseError:
logger.warn(
"Signature check failed for %s",
pdu.event_id,
"Signature check failed for %s redacted to %s",
encode_canonical_json(pdu.get_pdu_json()),
encode_canonical_json(redacted_pdu_json),
)
raise
if not check_event_content_hash(pdu):
logger.warn(
"Event content has been tampered, redacting.",
pdu.event_id,
"Event content has been tampered, redacting %s, %s",
pdu.event_id, encode_canonical_json(pdu.get_dict())
)
defer.returnValue(redacted_event)

View File

@@ -22,7 +22,6 @@ from .units import Edu
from synapse.api.errors import (
CodeMessageException, HttpResponseException, SynapseError,
)
from synapse.util import unwrapFirstError
from synapse.util.expiringcache import ExpiringCache
from synapse.util.logutils import log_function
from synapse.events import FrozenEvent
@@ -165,17 +164,16 @@ class FederationClient(FederationBase):
for p in transaction_data["pdus"]
]
# FIXME: We should handle signature failures more gracefully.
pdus[:] = yield defer.gatherResults(
[self._check_sigs_and_hash(pdu) for pdu in pdus],
consumeErrors=True,
).addErrback(unwrapFirstError)
for i, pdu in enumerate(pdus):
pdus[i] = yield self._check_sigs_and_hash(pdu)
# FIXME: We should handle signature failures more gracefully.
defer.returnValue(pdus)
@defer.inlineCallbacks
@log_function
def get_pdu(self, destinations, event_id, outlier=False, timeout=None):
def get_pdu(self, destinations, event_id, outlier=False):
"""Requests the PDU with given origin and ID from the remote home
servers.
@@ -191,8 +189,6 @@ class FederationClient(FederationBase):
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
it's from an arbitary point in the context as opposed to part
of the current block of PDUs. Defaults to `False`
timeout (int): How long to try (in ms) each destination for before
moving to the next destination. None indicates no timeout.
Returns:
Deferred: Results in the requested PDU.
@@ -216,7 +212,7 @@ class FederationClient(FederationBase):
with limiter:
transaction_data = yield self.transport_layer.get_event(
destination, event_id, timeout=timeout,
destination, event_id
)
logger.debug("transaction_data %r", transaction_data)
@@ -226,7 +222,7 @@ class FederationClient(FederationBase):
for p in transaction_data["pdus"]
]
if pdu_list and pdu_list[0]:
if pdu_list:
pdu = pdu_list[0]
# Check signatures are correct.
@@ -259,7 +255,7 @@ class FederationClient(FederationBase):
)
continue
if self._get_pdu_cache is not None and pdu:
if self._get_pdu_cache is not None:
self._get_pdu_cache[event_id] = pdu
defer.returnValue(pdu)
@@ -374,17 +370,13 @@ class FederationClient(FederationBase):
for p in content.get("auth_chain", [])
]
signed_state, signed_auth = yield defer.gatherResults(
[
self._check_sigs_and_hash_and_fetch(
destination, state, outlier=True
),
self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True
)
],
consumeErrors=True
).addErrback(unwrapFirstError)
signed_state = yield self._check_sigs_and_hash_and_fetch(
destination, state, outlier=True
)
signed_auth = yield self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True
)
auth_chain.sort(key=lambda e: e.depth)
@@ -499,7 +491,7 @@ class FederationClient(FederationBase):
]
signed_events = yield self._check_sigs_and_hash_and_fetch(
destination, events, outlier=False
destination, events, outlier=True
)
have_gotten_all_from_destination = True
@@ -526,7 +518,7 @@ class FederationClient(FederationBase):
# Are we missing any?
seen_events = set(earliest_events_ids)
seen_events.update(e.event_id for e in signed_events if e)
seen_events.update(e.event_id for e in signed_events)
missing_events = {}
for e in itertools.chain(latest_events, signed_events):
@@ -569,7 +561,7 @@ class FederationClient(FederationBase):
res = yield defer.DeferredList(deferreds, consumeErrors=True)
for (result, val), (e_id, _) in zip(res, ordered_missing):
if result and val:
if result:
signed_events.append(val)
else:
failed_to_fetch.add(e_id)

View File

@@ -20,6 +20,7 @@ from .federation_base import FederationBase
from .units import Transaction, Edu
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
from synapse.events import FrozenEvent
import synapse.metrics
@@ -122,28 +123,29 @@ class FederationServer(FederationBase):
logger.debug("[%s] Transaction is new", transaction.transaction_id)
results = []
with PreserveLoggingContext():
results = []
for pdu in pdu_list:
d = self._handle_new_pdu(transaction.origin, pdu)
for pdu in pdu_list:
d = self._handle_new_pdu(transaction.origin, pdu)
try:
yield d
results.append({})
except FederationError as e:
self.send_failure(e, transaction.origin)
results.append({"error": str(e)})
except Exception as e:
results.append({"error": str(e)})
logger.exception("Failed to handle PDU")
try:
yield d
results.append({})
except FederationError as e:
self.send_failure(e, transaction.origin)
results.append({"error": str(e)})
except Exception as e:
results.append({"error": str(e)})
logger.exception("Failed to handle PDU")
if hasattr(transaction, "edus"):
for edu in [Edu(**x) for x in transaction.edus]:
self.received_edu(
transaction.origin,
edu.edu_type,
edu.content
)
if hasattr(transaction, "edus"):
for edu in [Edu(**x) for x in transaction.edus]:
self.received_edu(
transaction.origin,
edu.edu_type,
edu.content
)
for failure in getattr(transaction, "pdu_failures", []):
logger.info("Got failure %r", failure)

View File

@@ -23,6 +23,8 @@ from twisted.internet import defer
from synapse.util.logutils import log_function
from syutil.jsonutil import encode_canonical_json
import logging
@@ -69,7 +71,7 @@ class TransactionActions(object):
transaction.transaction_id,
transaction.origin,
code,
response,
encode_canonical_json(response)
)
@defer.inlineCallbacks
@@ -99,5 +101,5 @@ class TransactionActions(object):
transaction.transaction_id,
transaction.destination,
response_code,
response_dict,
encode_canonical_json(response_dict)
)

View File

@@ -104,6 +104,7 @@ class TransactionQueue(object):
return not destination.startswith("localhost")
@defer.inlineCallbacks
@log_function
def enqueue_pdu(self, pdu, destinations, order):
# We loop through all destinations to see whether we already have
# a transaction in progress. If we do, stick it in the pending_pdus
@@ -207,13 +208,13 @@ class TransactionQueue(object):
# request at which point pending_pdus_by_dest just keeps growing.
# we need application-layer timeouts of some flavour of these
# requests
logger.debug(
logger.info(
"TX [%s] Transaction already in progress",
destination
)
return
logger.debug("TX [%s] _attempt_new_transaction", destination)
logger.info("TX [%s] _attempt_new_transaction", destination)
# list of (pending_pdu, deferred, order)
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
@@ -221,11 +222,11 @@ class TransactionQueue(object):
pending_failures = self.pending_failures_by_dest.pop(destination, [])
if pending_pdus:
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
destination, len(pending_pdus))
logger.info("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
destination, len(pending_pdus))
if not pending_pdus and not pending_edus and not pending_failures:
logger.debug("TX [%s] Nothing to send", destination)
logger.info("TX [%s] Nothing to send", destination)
return
# Sort based on the order field
@@ -242,8 +243,6 @@ class TransactionQueue(object):
try:
self.pending_transactions[destination] = 1
txn_id = str(self._next_txn_id)
limiter = yield get_retry_limiter(
destination,
self._clock,
@@ -251,9 +250,9 @@ class TransactionQueue(object):
)
logger.debug(
"TX [%s] {%s} Attempting new transaction"
"TX [%s] Attempting new transaction"
" (pdus: %d, edus: %d, failures: %d)",
destination, txn_id,
destination,
len(pending_pdus),
len(pending_edus),
len(pending_failures)
@@ -263,7 +262,7 @@ class TransactionQueue(object):
transaction = Transaction.create_new(
origin_server_ts=int(self._clock.time_msec()),
transaction_id=txn_id,
transaction_id=str(self._next_txn_id),
origin=self.server_name,
destination=destination,
pdus=pdus,
@@ -277,13 +276,9 @@ class TransactionQueue(object):
logger.debug("TX [%s] Persisted transaction", destination)
logger.info(
"TX [%s] {%s} Sending transaction [%s],"
" (PDUs: %d, EDUs: %d, failures: %d)",
destination, txn_id,
"TX [%s] Sending transaction [%s]",
destination,
transaction.transaction_id,
len(pending_pdus),
len(pending_edus),
len(pending_failures),
)
with limiter:
@@ -319,10 +314,7 @@ class TransactionQueue(object):
code = e.code
response = e.response
logger.info(
"TX [%s] {%s} got %d response",
destination, txn_id, code
)
logger.info("TX [%s] got %d response", destination, code)
logger.debug("TX [%s] Sent transaction", destination)
logger.debug("TX [%s] Marking as delivered...", destination)

View File

@@ -50,15 +50,13 @@ class TransportLayerClient(object):
)
@log_function
def get_event(self, destination, event_id, timeout=None):
def get_event(self, destination, event_id):
""" Requests the pdu with give id and origin from the given server.
Args:
destination (str): The host name of the remote home server we want
to get the state from.
event_id (str): The id of the event being requested.
timeout (int): How long to try (in ms) the destination for before
giving up. None indicates no timeout.
Returns:
Deferred: Results in a dict received from the remote homeserver.
@@ -67,7 +65,7 @@ class TransportLayerClient(object):
destination, event_id)
path = PREFIX + "/event/%s/" % (event_id, )
return self.client.get_json(destination, path=path, timeout=timeout)
return self.client.get_json(destination, path=path)
@log_function
def backfill(self, destination, room_id, event_tuples, limit):

View File

@@ -93,8 +93,6 @@ class TransportLayerServer(object):
yield self.keyring.verify_json_for_server(origin, json_request)
logger.info("Request from %s", origin)
defer.returnValue((origin, content))
@log_function
@@ -198,14 +196,6 @@ class FederationSendServlet(BaseFederationServlet):
transaction_id, str(transaction_data)
)
logger.info(
"Received txn %s from %s. (PDUs: %d, EDUs: %d, failures: %d)",
transaction_id, origin,
len(transaction_data.get("pdus", [])),
len(transaction_data.get("edus", [])),
len(transaction_data.get("failures", [])),
)
# We should ideally be getting this from the security layer.
# origin = body["origin"]

View File

@@ -20,8 +20,6 @@ from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.api.constants import Membership, EventTypes
from synapse.types import UserID
from synapse.util.logcontext import PreserveLoggingContext
import logging
@@ -78,9 +76,7 @@ class BaseHandler(object):
context = yield state_handler.compute_event_context(builder)
if builder.is_state():
builder.prev_state = yield self.store.add_event_hashes(
context.prev_state_events
)
builder.prev_state = context.prev_state_events
yield self.auth.add_auth_events(builder, context)
@@ -107,9 +103,7 @@ class BaseHandler(object):
if not suppress_auth:
self.auth.check(event, auth_events=context.current_state)
(event_stream_id, max_stream_id) = yield self.store.persist_event(
event, context=context
)
yield self.store.persist_event(event, context=context)
federation_handler = self.hs.get_handlers().federation_handler
@@ -143,12 +137,10 @@ class BaseHandler(object):
"Failed to get destination from event %s", s.event_id
)
with PreserveLoggingContext():
# Don't block waiting on waking up all the listeners.
notify_d = self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=extra_users
)
# Don't block waiting on waking up all the listeners.
notify_d = self.notifier.on_new_room_event(
event, extra_users=extra_users
)
def log_failure(f):
logger.warn(
@@ -158,6 +150,8 @@ class BaseHandler(object):
notify_d.addErrback(log_failure)
federation_handler.handle_new_event(
fed_d = federation_handler.handle_new_event(
event, destinations=destinations,
)
fed_d.addErrback(log_failure)

View File

@@ -15,7 +15,7 @@
from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.api.constants import EventTypes, Membership
from synapse.appservice import ApplicationService
from synapse.types import UserID
@@ -147,7 +147,10 @@ class ApplicationServicesHandler(object):
)
# We need to know the members associated with this event.room_id,
# if any.
member_list = yield self.store.get_users_in_room(event.room_id)
member_list = yield self.store.get_room_members(
room_id=event.room_id,
membership=Membership.JOIN
)
services = yield self.store.get_app_services()
interested_list = [
@@ -177,7 +180,7 @@ class ApplicationServicesHandler(object):
return
user_info = yield self.store.get_user_by_id(user_id)
if not user_info:
if len(user_info) > 0:
defer.returnValue(False)
return

View File

@@ -159,7 +159,7 @@ class AuthHandler(BaseHandler):
logger.warn("Attempted to login as %s but they do not exist", user)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
stored_hash = user_info["password_hash"]
stored_hash = user_info[0]["password_hash"]
if bcrypt.checkpw(password, stored_hash):
defer.returnValue(user)
else:
@@ -187,8 +187,8 @@ class AuthHandler(BaseHandler):
# each request
try:
client = SimpleHttpClient(self.hs)
resp_body = yield client.post_urlencoded_get_json(
self.hs.config.recaptcha_siteverify_api,
data = yield client.post_urlencoded_get_json(
"https://www.google.com/recaptcha/api/siteverify",
args={
'secret': self.hs.config.recaptcha_private_key,
'response': user_response,
@@ -198,8 +198,7 @@ class AuthHandler(BaseHandler):
except PartialDownloadError as pde:
# Twisted is silly
data = pde.response
resp_body = simplejson.loads(data)
resp_body = simplejson.loads(data)
if 'success' in resp_body and resp_body['success']:
defer.returnValue(True)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)

View File

@@ -22,7 +22,6 @@ from synapse.api.constants import EventTypes
from synapse.types import RoomAlias
import logging
import string
logger = logging.getLogger(__name__)
@@ -41,10 +40,6 @@ class DirectoryHandler(BaseHandler):
def _create_association(self, room_alias, room_id, servers=None):
# general association creation for both human users and app services
for wchar in string.whitespace:
if wchar in room_alias.localpart:
raise SynapseError(400, "Invalid characters in room alias")
if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room alias must be local")
# TODO(erikj): Change this.

View File

@@ -15,6 +15,7 @@
from twisted.internet import defer
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.logutils import log_function
from synapse.types import UserID
from synapse.events.utils import serialize_event
@@ -80,9 +81,10 @@ class EventStreamHandler(BaseHandler):
# thundering herds on restart.
timeout = random.randint(int(timeout*0.9), int(timeout*1.1))
events, tokens = yield self.notifier.get_events_for(
auth_user, room_ids, pagin_config, timeout
)
with PreserveLoggingContext():
events, tokens = yield self.notifier.get_events_for(
auth_user, room_ids, pagin_config, timeout
)
time_now = self.clock.time_msec()

View File

@@ -18,11 +18,9 @@
from ._base import BaseHandler
from synapse.api.errors import (
AuthError, FederationError, StoreError, CodeMessageException, SynapseError,
AuthError, FederationError, StoreError,
)
from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.util import unwrapFirstError
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor
from synapse.util.frozenutils import unfreeze
@@ -31,8 +29,6 @@ from synapse.crypto.event_signing import (
)
from synapse.types import UserID
from synapse.util.retryutils import NotRetryingDestination
from twisted.internet import defer
import itertools
@@ -77,6 +73,8 @@ class FederationHandler(BaseHandler):
# When joining a room we need to queue any events for that room up
self.room_queues = {}
@log_function
@defer.inlineCallbacks
def handle_new_event(self, event, destinations):
""" Takes in an event from the client to server side, that has already
been authed and handled by the state module, and sends it to any
@@ -91,7 +89,9 @@ class FederationHandler(BaseHandler):
processing.
"""
return self.replication_layer.send_pdu(event, destinations)
yield run_on_reactor()
self.replication_layer.send_pdu(event, destinations)
@log_function
@defer.inlineCallbacks
@@ -160,7 +160,7 @@ class FederationHandler(BaseHandler):
)
try:
_, event_stream_id, max_stream_id = yield self._handle_new_event(
yield self._handle_new_event(
origin,
event,
state=state,
@@ -201,11 +201,9 @@ class FederationHandler(BaseHandler):
target_user = UserID.from_string(target_user_id)
extra_users.append(target_user)
with PreserveLoggingContext():
d = self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=extra_users
)
d = self.notifier.on_new_room_event(
event, extra_users=extra_users
)
def log_failure(f):
logger.warn(
@@ -224,254 +222,36 @@ class FederationHandler(BaseHandler):
@log_function
@defer.inlineCallbacks
def backfill(self, dest, room_id, limit, extremities=[]):
def backfill(self, dest, room_id, limit):
""" Trigger a backfill request to `dest` for the given `room_id`
"""
if not extremities:
extremities = yield self.store.get_oldest_events_in_room(room_id)
extremities = yield self.store.get_oldest_events_in_room(room_id)
events = yield self.replication_layer.backfill(
pdus = yield self.replication_layer.backfill(
dest,
room_id,
limit=limit,
limit,
extremities=extremities,
)
event_map = {e.event_id: e for e in events}
events = []
event_ids = set(e.event_id for e in events)
for pdu in pdus:
event = pdu
edges = [
ev.event_id
for ev in events
if set(e_id for e_id, _ in ev.prev_events) - event_ids
]
# FIXME (erikj): Not sure this actually works :/
context = yield self.state_handler.compute_event_context(event)
logger.info(
"backfill: Got %d events with %d edges",
len(events), len(edges),
)
events.append((event, context))
# For each edge get the current state.
auth_events = {}
state_events = {}
events_to_state = {}
for e_id in edges:
state, auth = yield self.replication_layer.get_state_for_room(
destination=dest,
room_id=room_id,
event_id=e_id
)
auth_events.update({a.event_id: a for a in auth})
auth_events.update({s.event_id: s for s in state})
state_events.update({s.event_id: s for s in state})
events_to_state[e_id] = state
seen_events = yield self.store.have_events(
set(auth_events.keys()) | set(state_events.keys())
)
all_events = events + state_events.values() + auth_events.values()
required_auth = set(
a_id for event in all_events for a_id, _ in event.auth_events
)
missing_auth = required_auth - set(auth_events)
results = yield defer.gatherResults(
[
self.replication_layer.get_pdu(
[dest],
event_id,
outlier=True,
timeout=10000,
)
for event_id in missing_auth
],
consumeErrors=True
).addErrback(unwrapFirstError)
auth_events.update({a.event_id: a for a in results})
yield defer.gatherResults(
[
self._handle_new_event(
dest, a,
auth_events={
(auth_events[a_id].type, auth_events[a_id].state_key):
auth_events[a_id]
for a_id, _ in a.auth_events
},
)
for a in auth_events.values()
if a.event_id not in seen_events
],
consumeErrors=True,
).addErrback(unwrapFirstError)
yield defer.gatherResults(
[
self._handle_new_event(
dest, event_map[e_id],
state=events_to_state[e_id],
backfilled=True,
auth_events={
(auth_events[a_id].type, auth_events[a_id].state_key):
auth_events[a_id]
for a_id, _ in event_map[e_id].auth_events
},
)
for e_id in events_to_state
],
consumeErrors=True
).addErrback(unwrapFirstError)
events.sort(key=lambda e: e.depth)
for event in events:
if event in events_to_state:
continue
yield self._handle_new_event(
dest, event,
backfilled=True,
yield self.store.persist_event(
event,
context=context,
backfilled=True
)
defer.returnValue(events)
@defer.inlineCallbacks
def maybe_backfill(self, room_id, current_depth):
"""Checks the database to see if we should backfill before paginating,
and if so do.
"""
extremities = yield self.store.get_oldest_events_with_depth_in_room(
room_id
)
if not extremities:
logger.debug("Not backfilling as no extremeties found.")
return
# Check if we reached a point where we should start backfilling.
sorted_extremeties_tuple = sorted(
extremities.items(),
key=lambda e: -int(e[1])
)
max_depth = sorted_extremeties_tuple[0][1]
if current_depth > max_depth:
logger.debug(
"Not backfilling as we don't need to. %d < %d",
max_depth, current_depth,
)
return
# Now we need to decide which hosts to hit first.
# First we try hosts that are already in the room
# TODO: HEURISTIC ALERT.
curr_state = yield self.state_handler.get_current_state(room_id)
def get_domains_from_state(state):
joined_users = [
(state_key, int(event.depth))
for (e_type, state_key), event in state.items()
if e_type == EventTypes.Member
and event.membership == Membership.JOIN
]
joined_domains = {}
for u, d in joined_users:
try:
dom = UserID.from_string(u).domain
old_d = joined_domains.get(dom)
if old_d:
joined_domains[dom] = min(d, old_d)
else:
joined_domains[dom] = d
except:
pass
return sorted(joined_domains.items(), key=lambda d: d[1])
curr_domains = get_domains_from_state(curr_state)
likely_domains = [
domain for domain, depth in curr_domains
if domain is not self.server_name
]
@defer.inlineCallbacks
def try_backfill(domains):
# TODO: Should we try multiple of these at a time?
for dom in domains:
try:
events = yield self.backfill(
dom, room_id,
limit=100,
extremities=[e for e in extremities.keys()]
)
except SynapseError:
logger.info(
"Failed to backfill from %s because %s",
dom, e,
)
continue
except CodeMessageException as e:
if 400 <= e.code < 500:
raise
logger.info(
"Failed to backfill from %s because %s",
dom, e,
)
continue
except NotRetryingDestination as e:
logger.info(e.message)
continue
except Exception as e:
logger.exception(
"Failed to backfill from %s because %s",
dom, e,
)
continue
if events:
defer.returnValue(True)
defer.returnValue(False)
success = yield try_backfill(likely_domains)
if success:
defer.returnValue(True)
# Huh, well *those* domains didn't work out. Lets try some domains
# from the time.
tried_domains = set(likely_domains)
tried_domains.add(self.server_name)
event_ids = list(extremities.keys())
states = yield defer.gatherResults([
self.state_handler.resolve_state_groups([e])
for e in event_ids
])
states = dict(zip(event_ids, [s[1] for s in states]))
for e_id, _ in sorted_extremeties_tuple:
likely_domains = get_domains_from_state(states[e_id])
success = yield try_backfill([
dom for dom in likely_domains
if dom not in tried_domains
])
if success:
defer.returnValue(True)
tried_domains.update(likely_domains)
defer.returnValue(False)
@defer.inlineCallbacks
def send_invite(self, target_host, event):
""" Sends the invite to the remote server for signing.
@@ -600,14 +380,30 @@ class FederationHandler(BaseHandler):
# FIXME
pass
yield self._handle_auth_events(
origin, [e for e in auth_chain if e.event_id != event.event_id]
)
for e in auth_chain:
e.internal_metadata.outlier = True
@defer.inlineCallbacks
def handle_state(e):
if e.event_id == event.event_id:
return
continue
try:
auth_ids = [e_id for e_id, _ in e.auth_events]
auth = {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids
}
yield self._handle_new_event(
origin, e, auth_events=auth
)
except:
logger.exception(
"Failed to handle auth event %s",
e.event_id,
)
for e in state:
if e.event_id == event.event_id:
continue
e.internal_metadata.outlier = True
try:
@@ -625,15 +421,13 @@ class FederationHandler(BaseHandler):
e.event_id,
)
yield defer.DeferredList([handle_state(e) for e in state])
auth_ids = [e_id for e_id, _ in event.auth_events]
auth_events = {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids
}
_, event_stream_id, max_stream_id = yield self._handle_new_event(
yield self._handle_new_event(
origin,
new_event,
state=state,
@@ -641,11 +435,9 @@ class FederationHandler(BaseHandler):
auth_events=auth_events,
)
with PreserveLoggingContext():
d = self.notifier.on_new_room_event(
new_event, event_stream_id, max_stream_id,
extra_users=[joinee]
)
d = self.notifier.on_new_room_event(
new_event, extra_users=[joinee]
)
def log_failure(f):
logger.warn(
@@ -710,9 +502,7 @@ class FederationHandler(BaseHandler):
event.internal_metadata.outlier = False
context, event_stream_id, max_stream_id = yield self._handle_new_event(
origin, event
)
context = yield self._handle_new_event(origin, event)
logger.debug(
"on_send_join_request: After _handle_new_event: %s, sigs: %s",
@@ -726,10 +516,9 @@ class FederationHandler(BaseHandler):
target_user = UserID.from_string(target_user_id)
extra_users.append(target_user)
with PreserveLoggingContext():
d = self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, extra_users=extra_users
)
d = self.notifier.on_new_room_event(
event, extra_users=extra_users
)
def log_failure(f):
logger.warn(
@@ -802,18 +591,16 @@ class FederationHandler(BaseHandler):
context = yield self.state_handler.compute_event_context(event)
event_stream_id, max_stream_id = yield self.store.persist_event(
yield self.store.persist_event(
event,
context=context,
backfilled=False,
)
target_user = UserID.from_string(event.state_key)
with PreserveLoggingContext():
d = self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=[target_user],
)
d = self.notifier.on_new_room_event(
event, extra_users=[target_user],
)
def log_failure(f):
logger.warn(
@@ -945,10 +732,8 @@ class FederationHandler(BaseHandler):
event.event_id, event.signatures,
)
outlier = event.internal_metadata.is_outlier()
context = yield self.state_handler.compute_event_context(
event, old_state=state, outlier=outlier,
event, old_state=state
)
if not auth_events:
@@ -959,17 +744,14 @@ class FederationHandler(BaseHandler):
event.event_id, auth_events,
)
is_new_state = not outlier
is_new_state = not event.internal_metadata.is_outlier()
# This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events.
if event.type == EventTypes.Member and not event.auth_events:
if len(event.prev_events) == 1 and event.depth < 5:
c = yield self.store.get_event(
event.prev_events[0][0],
allow_none=True,
)
if c and c.type == EventTypes.Create:
if len(event.prev_events) == 1:
c = yield self.store.get_event(event.prev_events[0][0])
if c.type == EventTypes.Create:
auth_events[(c.type, c.state_key)] = c
try:
@@ -995,7 +777,7 @@ class FederationHandler(BaseHandler):
)
raise
event_stream_id, max_stream_id = yield self.store.persist_event(
yield self.store.persist_event(
event,
context=context,
backfilled=backfilled,
@@ -1003,7 +785,7 @@ class FederationHandler(BaseHandler):
current_state=current_state,
)
defer.returnValue((context, event_stream_id, max_stream_id))
defer.returnValue(context)
@defer.inlineCallbacks
def on_query_auth(self, origin, event_id, remote_auth_chain, rejects,
@@ -1143,7 +925,7 @@ class FederationHandler(BaseHandler):
if d in have_events and not have_events[d]
],
consumeErrors=True
).addErrback(unwrapFirstError)
)
if different_events:
local_view = dict(auth_events)
@@ -1388,52 +1170,3 @@ class FederationHandler(BaseHandler):
},
"missing": [e.event_id for e in missing_locals],
})
@defer.inlineCallbacks
def _handle_auth_events(self, origin, auth_events):
auth_ids_to_deferred = {}
def process_auth_ev(ev):
auth_ids = [e_id for e_id, _ in ev.auth_events]
prev_ds = [
auth_ids_to_deferred[i]
for i in auth_ids
if i in auth_ids_to_deferred
]
d = defer.Deferred()
auth_ids_to_deferred[ev.event_id] = d
@defer.inlineCallbacks
def f(*_):
ev.internal_metadata.outlier = True
try:
auth = {
(e.type, e.state_key): e for e in auth_events
if e.event_id in auth_ids
}
yield self._handle_new_event(
origin, ev, auth_events=auth
)
except:
logger.exception(
"Failed to handle auth event %s",
ev.event_id,
)
d.callback(None)
if prev_ds:
dx = defer.DeferredList(prev_ds)
dx.addBoth(f)
else:
f()
for e in auth_events:
process_auth_ev(e)
yield defer.DeferredList(auth_ids_to_deferred.values())

View File

@@ -20,9 +20,8 @@ from synapse.api.errors import RoomError, SynapseError
from synapse.streams.config import PaginationConfig
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
from synapse.util import unwrapFirstError
from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import UserID, RoomStreamToken
from synapse.types import UserID
from ._base import BaseHandler
@@ -90,19 +89,9 @@ class MessageHandler(BaseHandler):
if not pagin_config.from_token:
pagin_config.from_token = (
yield self.hs.get_event_sources().get_current_token(
direction='b'
)
yield self.hs.get_event_sources().get_current_token()
)
room_token = RoomStreamToken.parse(pagin_config.from_token.room_key)
if room_token.topological is None:
raise SynapseError(400, "Invalid token")
yield self.hs.get_handlers().federation_handler.maybe_backfill(
room_id, room_token.topological
)
user = UserID.from_string(user_id)
events, next_key = yield data_source.get_pagination_rows(
@@ -261,31 +250,47 @@ class MessageHandler(BaseHandler):
is joined on, may return a "messages" key with messages, depending
on the specified PaginationConfig.
"""
start_time = self.clock.time_msec()
def delta():
return self.clock.time_msec() - start_time
logger.info("initial_sync: start")
room_list = yield self.store.get_rooms_for_user_where_membership_is(
user_id=user_id,
membership_list=[Membership.INVITE, Membership.JOIN]
)
logger.info("initial_sync: got_rooms %d", delta())
user = UserID.from_string(user_id)
rooms_ret = []
now_token = yield self.hs.get_event_sources().get_current_token()
logger.info("initial_sync: now_token %d", delta())
presence_stream = self.hs.get_event_sources().sources["presence"]
pagination_config = PaginationConfig(from_token=now_token)
presence, _ = yield presence_stream.get_pagination_rows(
user, pagination_config.get_source_config("presence"), None
)
logger.info("initial_sync: presence_done %d", delta())
public_room_ids = yield self.store.get_public_room_ids()
logger.info("initial_sync: public_rooms %d", delta())
limit = pagin_config.limit
if limit is None:
limit = 10
@defer.inlineCallbacks
def handle_room(event):
logger.info("initial_sync: start: %s %d", event.room_id, delta())
d = {
"room_id": event.room_id,
"membership": event.membership,
@@ -314,7 +319,7 @@ class MessageHandler(BaseHandler):
event.room_id
),
]
).addErrback(unwrapFirstError)
)
start_token = now_token.copy_and_replace("room_key", token[0])
end_token = now_token.copy_and_replace("room_key", token[1])
@@ -336,10 +341,14 @@ class MessageHandler(BaseHandler):
except:
logger.exception("Failed to get snapshot")
logger.info("initial_sync: end: %s %d", event.room_id, delta())
yield defer.gatherResults(
[handle_room(e) for e in room_list],
consumeErrors=True
).addErrback(unwrapFirstError)
)
logger.info("initial_sync: done", delta())
ret = {
"rooms": rooms_ret,

View File

@@ -18,8 +18,8 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError
from synapse.api.constants import PresenceState
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import UserID
import synapse.metrics
@@ -146,10 +146,6 @@ class PresenceHandler(BaseHandler):
self._user_cachemap = {}
self._user_cachemap_latest_serial = 0
# map room_ids to the latest presence serial for a member of that
# room
self._room_serials = {}
metrics.register_callback(
"userCachemap:size",
lambda: len(self._user_cachemap),
@@ -282,14 +278,15 @@ class PresenceHandler(BaseHandler):
now_online = state["presence"] != PresenceState.OFFLINE
was_polling = target_user in self._user_cachemap
if now_online and not was_polling:
self.start_polling_presence(target_user, state=state)
elif not now_online and was_polling:
self.stop_polling_presence(target_user)
with PreserveLoggingContext():
if now_online and not was_polling:
self.start_polling_presence(target_user, state=state)
elif not now_online and was_polling:
self.stop_polling_presence(target_user)
# TODO(paul): perform a presence push as part of start/stop poll so
# we don't have to do this all the time
self.changed_presencelike_data(target_user, state)
# TODO(paul): perform a presence push as part of start/stop poll so
# we don't have to do this all the time
self.changed_presencelike_data(target_user, state)
def bump_presence_active_time(self, user, now=None):
if now is None:
@@ -301,34 +298,13 @@ class PresenceHandler(BaseHandler):
self.changed_presencelike_data(user, {"last_active": now})
def get_joined_rooms_for_user(self, user):
"""Get the list of rooms a user is joined to.
Args:
user(UserID): The user.
Returns:
A Deferred of a list of room id strings.
"""
rm_handler = self.homeserver.get_handlers().room_member_handler
return rm_handler.get_joined_rooms_for_user(user)
def get_joined_users_for_room_id(self, room_id):
rm_handler = self.homeserver.get_handlers().room_member_handler
return rm_handler.get_room_members(room_id)
@defer.inlineCallbacks
def changed_presencelike_data(self, user, state):
"""Updates the presence state of a local user.
statuscache = self._get_or_make_usercache(user)
Args:
user(UserID): The user being updated.
state(dict): The new presence state for the user.
Returns:
A Deferred
"""
self._user_cachemap_latest_serial += 1
statuscache = yield self.update_presence_cache(user, state)
yield self.push_presence(user, statuscache=statuscache)
statuscache.update(state, serial=self._user_cachemap_latest_serial)
return self.push_presence(user, statuscache=statuscache)
@log_function
def started_user_eventstream(self, user):
@@ -342,21 +318,14 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks
def user_joined_room(self, user, room_id):
"""Called via the distributor whenever a user joins a room.
Notifies the new member of the presence of the current members.
Notifies the current members of the room of the new member's presence.
Args:
user(UserID): The user who joined the room.
room_id(str): The room id the user joined.
"""
if self.hs.is_mine(user):
statuscache = self._get_or_make_usercache(user)
# No actual update but we need to bump the serial anyway for the
# event source
self._user_cachemap_latest_serial += 1
statuscache = yield self.update_presence_cache(
user, room_ids=[room_id]
)
statuscache.update({}, serial=self._user_cachemap_latest_serial)
self.push_update_to_local_and_remote(
observed_user=user,
room_ids=[room_id],
@@ -364,22 +333,18 @@ class PresenceHandler(BaseHandler):
)
# We also want to tell them about current presence of people.
curr_users = yield self.get_joined_users_for_room_id(room_id)
rm_handler = self.homeserver.get_handlers().room_member_handler
curr_users = yield rm_handler.get_room_members(room_id)
for local_user in [c for c in curr_users if self.hs.is_mine(c)]:
statuscache = yield self.update_presence_cache(
local_user, room_ids=[room_id], add_to_cache=False
)
self.push_update_to_local_and_remote(
observed_user=local_user,
users_to_push=[user],
statuscache=statuscache,
statuscache=self._get_or_offline_usercache(local_user),
)
@defer.inlineCallbacks
def send_invite(self, observer_user, observed_user):
"""Request the presence of a local or remote user for a local user"""
if not self.hs.is_mine(observer_user):
raise SynapseError(400, "User is not hosted on this Home Server")
@@ -414,15 +379,6 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks
def invite_presence(self, observed_user, observer_user):
"""Handles a m.presence_invite EDU. A remote or local user has
requested presence updates for a local user. If the invite is accepted
then allow the local or remote user to see the presence of the local
user.
Args:
observed_user(UserID): The local user whose presence is requested.
observer_user(UserID): The remote or local user requesting presence.
"""
accept = yield self._should_accept_invite(observed_user, observer_user)
if accept:
@@ -449,34 +405,16 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks
def accept_presence(self, observed_user, observer_user):
"""Handles a m.presence_accept EDU. Mark a presence invite from a
local or remote user as accepted in a local user's presence list.
Starts polling for presence updates from the local or remote user.
Args:
observed_user(UserID): The user to update in the presence list.
observer_user(UserID): The owner of the presence list to update.
"""
yield self.store.set_presence_list_accepted(
observer_user.localpart, observed_user.to_string()
)
self.start_polling_presence(
observer_user, target_user=observed_user
)
with PreserveLoggingContext():
self.start_polling_presence(
observer_user, target_user=observed_user
)
@defer.inlineCallbacks
def deny_presence(self, observed_user, observer_user):
"""Handle a m.presence_deny EDU. Removes a local or remote user from a
local user's presence list.
Args:
observed_user(UserID): The local or remote user to remove from the
list.
observer_user(UserID): The local owner of the presence list.
Returns:
A Deferred.
"""
yield self.store.del_presence_list(
observer_user.localpart, observed_user.to_string()
)
@@ -485,16 +423,6 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks
def drop(self, observed_user, observer_user):
"""Remove a local or remote user from a local user's presence list and
unsubscribe the local user from updates that user.
Args:
observed_user(UserId): The local or remote user to remove from the
list.
observer_user(UserId): The local owner of the presence list.
Returns:
A Deferred.
"""
if not self.hs.is_mine(observer_user):
raise SynapseError(400, "User is not hosted on this Home Server")
@@ -502,66 +430,34 @@ class PresenceHandler(BaseHandler):
observer_user.localpart, observed_user.to_string()
)
self.stop_polling_presence(
observer_user, target_user=observed_user
)
with PreserveLoggingContext():
self.stop_polling_presence(
observer_user, target_user=observed_user
)
@defer.inlineCallbacks
def get_presence_list(self, observer_user, accepted=None):
"""Get the presence list for a local user. The retured list includes
the current presence state for each user listed.
Args:
observer_user(UserID): The local user whose presence list to fetch.
accepted(bool or None): If not none then only include users who
have or have not accepted the presence invite request.
Returns:
A Deferred list of presence state events.
"""
if not self.hs.is_mine(observer_user):
raise SynapseError(400, "User is not hosted on this Home Server")
presence_list = yield self.store.get_presence_list(
presence = yield self.store.get_presence_list(
observer_user.localpart, accepted=accepted
)
results = []
for row in presence_list:
observed_user = UserID.from_string(row["observed_user_id"])
result = {
"observed_user": observed_user, "accepted": row["accepted"]
}
result.update(
self._get_or_offline_usercache(observed_user).get_state()
)
if "last_active" in result:
result["last_active_ago"] = int(
self.clock.time_msec() - result.pop("last_active")
for p in presence:
observed_user = UserID.from_string(p.pop("observed_user_id"))
p["observed_user"] = observed_user
p.update(self._get_or_offline_usercache(observed_user).get_state())
if "last_active" in p:
p["last_active_ago"] = int(
self.clock.time_msec() - p.pop("last_active")
)
results.append(result)
defer.returnValue(results)
defer.returnValue(presence)
@defer.inlineCallbacks
@log_function
def start_polling_presence(self, user, target_user=None, state=None):
"""Subscribe a local user to presence updates from a local or remote
user. If no target_user is supplied then subscribe to all users stored
in the presence list for the local user.
Additonally this pushes the current presence state of this user to all
target_users. That state can be provided directly or will be read from
the stored state for the local user.
Also this attempts to notify the local user of the current state of
any local target users.
Args:
user(UserID): The local user that whishes for presence updates.
target_user(UserID): The local or remote user whose updates are
wanted.
state(dict): Optional presence state for the local user.
"""
logger.debug("Start polling for presence from %s", user)
if target_user:
@@ -577,7 +473,8 @@ class PresenceHandler(BaseHandler):
# Also include people in all my rooms
room_ids = yield self.get_joined_rooms_for_user(user)
rm_handler = self.homeserver.get_handlers().room_member_handler
room_ids = yield rm_handler.get_joined_rooms_for_user(user)
if state is None:
state = yield self.store.get_presence_state(user.localpart)
@@ -601,7 +498,9 @@ class PresenceHandler(BaseHandler):
# We want to tell the person that just came online
# presence state of people they are interested in?
self.push_update_to_clients(
observed_user=target_user,
users_to_push=[user],
statuscache=self._get_or_offline_usercache(target_user),
)
deferreds = []
@@ -618,12 +517,6 @@ class PresenceHandler(BaseHandler):
yield defer.DeferredList(deferreds, consumeErrors=True)
def _start_polling_local(self, user, target_user):
"""Subscribe a local user to presence updates for a local user
Args:
user(UserId): The local user that wishes for updates.
target_user(UserId): The local users whose updates are wanted.
"""
target_localpart = target_user.localpart
if target_localpart not in self._local_pushmap:
@@ -632,17 +525,6 @@ class PresenceHandler(BaseHandler):
self._local_pushmap[target_localpart].add(user)
def _start_polling_remote(self, user, domain, remoteusers):
"""Subscribe a local user to presence updates for remote users on a
given remote domain.
Args:
user(UserID): The local user that wishes for updates.
domain(str): The remote server the local user wants updates from.
remoteusers(UserID): The remote users that local user wants to be
told about.
Returns:
A Deferred.
"""
to_poll = set()
for u in remoteusers:
@@ -663,17 +545,6 @@ class PresenceHandler(BaseHandler):
@log_function
def stop_polling_presence(self, user, target_user=None):
"""Unsubscribe a local user from presence updates from a local or
remote user. If no target user is supplied then unsubscribe the user
from all presence updates that the user had subscribed to.
Args:
user(UserID): The local user that no longer wishes for updates.
target_user(UserID or None): The user whose updates are no longer
wanted.
Returns:
A Deferred.
"""
logger.debug("Stop polling for presence from %s", user)
if not target_user or self.hs.is_mine(target_user):
@@ -702,13 +573,6 @@ class PresenceHandler(BaseHandler):
return defer.DeferredList(deferreds, consumeErrors=True)
def _stop_polling_local(self, user, target_user):
"""Unsubscribe a local user from presence updates from a local user on
this server.
Args:
user(UserID): The local user that no longer wishes for updates.
target_user(UserID): The user whose updates are no longer wanted.
"""
for localpart in self._local_pushmap.keys():
if target_user and localpart != target_user.localpart:
continue
@@ -721,17 +585,6 @@ class PresenceHandler(BaseHandler):
@log_function
def _stop_polling_remote(self, user, domain, remoteusers):
"""Unsubscribe a local user from presence updates from remote users on
a given domain.
Args:
user(UserID): The local user that no longer wishes for updates.
domain(str): The remote server to unsubscribe from.
remoteusers([UserID]): The users on that remote server that the
local user no longer wishes to be updated about.
Returns:
A Deferred.
"""
to_unpoll = set()
for u in remoteusers:
@@ -753,19 +606,6 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks
@log_function
def push_presence(self, user, statuscache):
"""
Notify local and remote users of a change in presence of a local user.
Pushes the update to local clients and remote domains that are directly
subscribed to the presence of the local user.
Also pushes that update to any local user or remote domain that shares
a room with the local user.
Args:
user(UserID): The local user whose presence was updated.
statuscache(UserPresenceCache): Cache of the user's presence state
Returns:
A Deferred.
"""
assert(self.hs.is_mine(user))
logger.debug("Pushing presence update from %s", user)
@@ -777,7 +617,8 @@ class PresenceHandler(BaseHandler):
# and also user is informed of server-forced pushes
localusers.add(user)
room_ids = yield self.get_joined_rooms_for_user(user)
rm_handler = self.homeserver.get_handlers().room_member_handler
room_ids = yield rm_handler.get_joined_rooms_for_user(user)
if not localusers and not room_ids:
defer.returnValue(None)
@@ -792,23 +633,44 @@ class PresenceHandler(BaseHandler):
yield self.distributor.fire("user_presence_changed", user, statuscache)
@defer.inlineCallbacks
def incoming_presence(self, origin, content):
"""Handle an incoming m.presence EDU.
For each presence update in the "push" list update our local cache and
notify the appropriate local clients. Only clients that share a room
or are directly subscribed to the presence for a user should be
notified of the update.
For each subscription request in the "poll" list start pushing presence
updates to the remote server.
For unsubscribe request in the "unpoll" list stop pushing presence
updates to the remote server.
def _push_presence_remote(self, user, destination, state=None):
if state is None:
state = yield self.store.get_presence_state(user.localpart)
del state["mtime"]
state["presence"] = state.pop("state")
Args:
orgin(str): The source of this m.presence EDU.
content(dict): The content of this m.presence EDU.
Returns:
A Deferred.
"""
if user in self._user_cachemap:
state["last_active"] = (
self._user_cachemap[user].get_state()["last_active"]
)
yield self.distributor.fire(
"collect_presencelike_data", user, state
)
if "last_active" in state:
state = dict(state)
state["last_active_ago"] = int(
self.clock.time_msec() - state.pop("last_active")
)
user_state = {
"user_id": user.to_string(),
}
user_state.update(**state)
yield self.federation.send_edu(
destination=destination,
edu_type="m.presence",
content={
"push": [
user_state,
],
}
)
@defer.inlineCallbacks
def incoming_presence(self, origin, content):
deferreds = []
for push in content.get("push", []):
@@ -822,7 +684,8 @@ class PresenceHandler(BaseHandler):
" | %d interested local observers %r", len(observers), observers
)
room_ids = yield self.get_joined_rooms_for_user(user)
rm_handler = self.homeserver.get_handlers().room_member_handler
room_ids = yield rm_handler.get_joined_rooms_for_user(user)
if room_ids:
logger.debug(" | %d interested room IDs %r", len(room_ids), room_ids)
@@ -841,15 +704,20 @@ class PresenceHandler(BaseHandler):
self.clock.time_msec() - state.pop("last_active_ago")
)
statuscache = self._get_or_make_usercache(user)
self._user_cachemap_latest_serial += 1
yield self.update_presence_cache(user, state, room_ids=room_ids)
statuscache.update(state, serial=self._user_cachemap_latest_serial)
if not observers and not room_ids:
logger.debug(" | no interested observers or room IDs")
continue
self.push_update_to_clients(
users_to_push=observers, room_ids=room_ids
observed_user=user,
users_to_push=observers,
room_ids=room_ids,
statuscache=statuscache,
)
user_id = user.to_string()
@@ -898,58 +766,13 @@ class PresenceHandler(BaseHandler):
if not self._remote_sendmap[user]:
del self._remote_sendmap[user]
yield defer.DeferredList(deferreds, consumeErrors=True)
@defer.inlineCallbacks
def update_presence_cache(self, user, state={}, room_ids=None,
add_to_cache=True):
"""Update the presence cache for a user with a new state and bump the
serial to the latest value.
Args:
user(UserID): The user being updated
state(dict): The presence state being updated
room_ids(None or list of str): A list of room_ids to update. If
room_ids is None then fetch the list of room_ids the user is
joined to.
add_to_cache: Whether to add an entry to the presence cache if the
user isn't already in the cache.
Returns:
A Deferred UserPresenceCache for the user being updated.
"""
if room_ids is None:
room_ids = yield self.get_joined_rooms_for_user(user)
for room_id in room_ids:
self._room_serials[room_id] = self._user_cachemap_latest_serial
if add_to_cache:
statuscache = self._get_or_make_usercache(user)
else:
statuscache = self._get_or_offline_usercache(user)
statuscache.update(state, serial=self._user_cachemap_latest_serial)
defer.returnValue(statuscache)
with PreserveLoggingContext():
yield defer.DeferredList(deferreds, consumeErrors=True)
@defer.inlineCallbacks
def push_update_to_local_and_remote(self, observed_user, statuscache,
users_to_push=[], room_ids=[],
remote_domains=[]):
"""Notify local clients and remote servers of a change in the presence
of a user.
Args:
observed_user(UserID): The user to push the presence state for.
statuscache(UserPresenceCache): The cache for the presence state to
push.
users_to_push([UserID]): A list of local and remote users to
notify.
room_ids([str]): Notify the local and remote occupants of these
rooms.
remote_domains([str]): A list of remote servers to notify in
addition to those implied by the users_to_push and the
room_ids.
Returns:
A Deferred.
"""
localusers, remoteusers = partitionbool(
users_to_push,
@@ -959,7 +782,10 @@ class PresenceHandler(BaseHandler):
localusers = set(localusers)
self.push_update_to_clients(
users_to_push=localusers, room_ids=room_ids
observed_user=observed_user,
users_to_push=localusers,
room_ids=room_ids,
statuscache=statuscache,
)
remote_domains = set(remote_domains)
@@ -984,65 +810,11 @@ class PresenceHandler(BaseHandler):
defer.returnValue((localusers, remote_domains))
def push_update_to_clients(self, users_to_push=[], room_ids=[]):
"""Notify clients of a new presence event.
Args:
users_to_push([UserID]): List of users to notify.
room_ids([str]): List of room_ids to notify.
"""
with PreserveLoggingContext():
self.notifier.on_new_user_event(
"presence_key",
self._user_cachemap_latest_serial,
users_to_push,
room_ids,
)
@defer.inlineCallbacks
def _push_presence_remote(self, user, destination, state=None):
"""Push a user's presence to a remote server. If a presence state event
that event is sent. Otherwise a new state event is constructed from the
stored presence state.
The last_active is replaced with last_active_ago in case the wallclock
time on the remote server is different to the time on this server.
Sends an EDU to the remote server with the current presence state.
Args:
user(UserID): The user to push the presence state for.
destination(str): The remote server to send state to.
state(dict): The state to push, or None to use the current stored
state.
Returns:
A Deferred.
"""
if state is None:
state = yield self.store.get_presence_state(user.localpart)
del state["mtime"]
state["presence"] = state.pop("state")
if user in self._user_cachemap:
state["last_active"] = (
self._user_cachemap[user].get_state()["last_active"]
)
yield self.distributor.fire(
"collect_presencelike_data", user, state
)
if "last_active" in state:
state = dict(state)
state["last_active_ago"] = int(
self.clock.time_msec() - state.pop("last_active")
)
user_state = {"user_id": user.to_string(), }
user_state.update(state)
yield self.federation.send_edu(
destination=destination,
edu_type="m.presence",
content={"push": [user_state, ], }
def push_update_to_clients(self, observed_user, users_to_push=[],
room_ids=[], statuscache=None):
self.notifier.on_new_user_event(
users_to_push,
room_ids,
)
@@ -1051,11 +823,39 @@ class PresenceEventSource(object):
self.hs = hs
self.clock = hs.get_clock()
@defer.inlineCallbacks
def is_visible(self, observer_user, observed_user):
if observer_user == observed_user:
defer.returnValue(True)
presence = self.hs.get_handlers().presence_handler
if (yield presence.store.user_rooms_intersect(
[u.to_string() for u in observer_user, observed_user])):
defer.returnValue(True)
if self.hs.is_mine(observed_user):
pushmap = presence._local_pushmap
defer.returnValue(
observed_user.localpart in pushmap and
observer_user in pushmap[observed_user.localpart]
)
else:
recvmap = presence._remote_recvmap
defer.returnValue(
observed_user in recvmap and
observer_user in recvmap[observed_user]
)
@defer.inlineCallbacks
@log_function
def get_new_events_for_user(self, user, from_key, limit):
from_key = int(from_key)
observer_user = user
presence = self.hs.get_handlers().presence_handler
cachemap = presence._user_cachemap
@@ -1064,27 +864,17 @@ class PresenceEventSource(object):
clock = self.clock
latest_serial = 0
user_ids_to_check = {user}
presence_list = yield presence.store.get_presence_list(
user.localpart, accepted=True
)
if presence_list is not None:
user_ids_to_check |= set(
UserID.from_string(p["observed_user_id"]) for p in presence_list
)
room_ids = yield presence.get_joined_rooms_for_user(user)
for room_id in set(room_ids) & set(presence._room_serials):
if presence._room_serials[room_id] > from_key:
joined = yield presence.get_joined_users_for_room_id(room_id)
user_ids_to_check |= set(joined)
updates = []
for observed_user in user_ids_to_check & set(cachemap):
# TODO(paul): use a DeferredList ? How to limit concurrency.
for observed_user in cachemap.keys():
cached = cachemap[observed_user]
if cached.serial <= from_key or cached.serial > max_serial:
continue
if not (yield self.is_visible(observer_user, observed_user)):
continue
latest_serial = max(cached.serial, latest_serial)
updates.append(cached.make_event(user=observed_user, clock=clock))
@@ -1121,6 +911,8 @@ class PresenceEventSource(object):
def get_pagination_rows(self, user, pagination_config, key):
# TODO (erikj): Does this make sense? Ordering?
observer_user = user
from_key = int(pagination_config.from_key)
if pagination_config.to_key:
@@ -1131,26 +923,14 @@ class PresenceEventSource(object):
presence = self.hs.get_handlers().presence_handler
cachemap = presence._user_cachemap
user_ids_to_check = {user}
presence_list = yield presence.store.get_presence_list(
user.localpart, accepted=True
)
if presence_list is not None:
user_ids_to_check |= set(
UserID.from_string(p["observed_user_id"]) for p in presence_list
)
room_ids = yield presence.get_joined_rooms_for_user(user)
for room_id in set(room_ids) & set(presence._room_serials):
if presence._room_serials[room_id] >= from_key:
joined = yield presence.get_joined_users_for_room_id(room_id)
user_ids_to_check |= set(joined)
updates = []
for observed_user in user_ids_to_check & set(cachemap):
# TODO(paul): use a DeferredList ? How to limit concurrency.
for observed_user in cachemap.keys():
if not (to_key < cachemap[observed_user].serial <= from_key):
continue
updates.append((observed_user, cachemap[observed_user]))
if (yield self.is_visible(observer_user, observed_user)):
updates.append((observed_user, cachemap[observed_user]))
# TODO(paul): limit

View File

@@ -17,8 +17,8 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
from synapse.api.constants import EventTypes, Membership
from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import UserID
from synapse.util import unwrapFirstError
from ._base import BaseHandler
@@ -88,9 +88,6 @@ class ProfileHandler(BaseHandler):
if target_user != auth_user:
raise AuthError(400, "Cannot set another user's displayname")
if new_displayname == '':
new_displayname = None
yield self.store.set_profile_displayname(
target_user.localpart, new_displayname
)
@@ -157,13 +154,14 @@ class ProfileHandler(BaseHandler):
if not self.hs.is_mine(user):
defer.returnValue(None)
(displayname, avatar_url) = yield defer.gatherResults(
[
self.store.get_profile_displayname(user.localpart),
self.store.get_profile_avatar_url(user.localpart),
],
consumeErrors=True
).addErrback(unwrapFirstError)
with PreserveLoggingContext():
(displayname, avatar_url) = yield defer.gatherResults(
[
self.store.get_profile_displayname(user.localpart),
self.store.get_profile_avatar_url(user.localpart),
],
consumeErrors=True
)
state["displayname"] = displayname
state["avatar_url"] = avatar_url

View File

@@ -21,12 +21,11 @@ from ._base import BaseHandler
from synapse.types import UserID, RoomAlias, RoomID
from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import StoreError, SynapseError
from synapse.util import stringutils, unwrapFirstError
from synapse.util import stringutils
from synapse.util.async import run_on_reactor
from synapse.events.utils import serialize_event
import logging
import string
logger = logging.getLogger(__name__)
@@ -51,10 +50,6 @@ class RoomCreationHandler(BaseHandler):
self.ratelimit(user_id)
if "room_alias_name" in config:
for wchar in string.whitespace:
if wchar in config["room_alias_name"]:
raise SynapseError(400, "Invalid characters in room alias")
room_alias = RoomAlias.create(
config["room_alias_name"],
self.hs.hostname,
@@ -534,17 +529,11 @@ class RoomListHandler(BaseHandler):
@defer.inlineCallbacks
def get_public_room_list(self):
chunk = yield self.store.get_rooms(is_public=True)
results = yield defer.gatherResults(
[
self.store.get_users_in_room(room["room_id"])
for room in chunk
],
consumeErrors=True,
).addErrback(unwrapFirstError)
for i, room in enumerate(chunk):
room["num_joined_members"] = len(results[i])
for room in chunk:
joined_users = yield self.store.get_users_in_room(
room_id=room["room_id"],
)
room["num_joined_members"] = len(joined_users)
# FIXME (erikj): START is no longer a valid value
defer.returnValue({"start": "START", "end": "END", "chunk": chunk})
@@ -580,8 +569,8 @@ class RoomEventSource(object):
defer.returnValue((events, end_key))
def get_current_key(self, direction='f'):
return self.store.get_room_events_max_id(direction)
def get_current_key(self):
return self.store.get_room_events_max_id()
@defer.inlineCallbacks
def get_pagination_rows(self, user, config, key):

View File

@@ -92,7 +92,7 @@ class SyncHandler(BaseHandler):
result = yield self.current_sync_for_user(sync_config, since_token)
defer.returnValue(result)
else:
def current_sync_callback(before_token, after_token):
def current_sync_callback():
return self.current_sync_for_user(sync_config, since_token)
rm_handler = self.hs.get_handlers().room_member_handler

View File

@@ -18,7 +18,6 @@ from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.errors import SynapseError, AuthError
from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import UserID
import logging
@@ -217,10 +216,7 @@ class TypingNotificationHandler(BaseHandler):
self._latest_room_serial += 1
self._room_serials[room_id] = self._latest_room_serial
with PreserveLoggingContext():
self.notifier.on_new_user_event(
"typing_key", self._latest_room_serial, rooms=[room_id]
)
self.notifier.on_new_user_event(rooms=[room_id])
class TypingNotificationEventSource(object):

View File

@@ -14,14 +14,12 @@
# limitations under the License.
from synapse.api.errors import CodeMessageException
from synapse.util.logcontext import preserve_context_over_fn
from syutil.jsonutil import encode_canonical_json
import synapse.metrics
from twisted.internet import defer, reactor
from twisted.web.client import (
Agent, readBody, FileBodyProducer, PartialDownloadError,
HTTPConnectionPool,
Agent, readBody, FileBodyProducer, PartialDownloadError
)
from twisted.web.http_headers import Headers
@@ -56,19 +54,14 @@ class SimpleHttpClient(object):
# The default context factory in Twisted 14.0.0 (which we require) is
# BrowserLikePolicyForHTTPS which will do regular cert validation
# 'like a browser'
pool = HTTPConnectionPool(reactor)
pool.maxPersistentPerHost = 10
self.agent = Agent(reactor, pool=pool)
self.agent = Agent(reactor)
self.version_string = hs.version_string
def request(self, method, *args, **kwargs):
# A small wrapper around self.agent.request() so we can easily attach
# counters to it
outgoing_requests_counter.inc(method)
d = preserve_context_over_fn(
self.agent.request,
method, *args, **kwargs
)
d = self.agent.request(method, *args, **kwargs)
def _cb(response):
incoming_responses_counter.inc(method, response.code)

View File

@@ -16,13 +16,13 @@
from twisted.internet import defer, reactor, protocol
from twisted.internet.error import DNSLookupError
from twisted.web.client import readBody, _AgentBase, _URI, HTTPConnectionPool
from twisted.web.client import readBody, _AgentBase, _URI
from twisted.web.http_headers import Headers
from twisted.web._newclient import ResponseDone
from synapse.http.endpoint import matrix_federation_endpoint
from synapse.util.async import sleep
from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.logcontext import PreserveLoggingContext
import synapse.metrics
from syutil.jsonutil import encode_canonical_json
@@ -103,17 +103,14 @@ class MatrixFederationHttpClient(object):
self.hs = hs
self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname
pool = HTTPConnectionPool(reactor)
pool.maxPersistentPerHost = 10
self.agent = MatrixFederationHttpAgent(reactor, pool=pool)
self.agent = MatrixFederationHttpAgent(reactor)
self.clock = hs.get_clock()
self.version_string = hs.version_string
@defer.inlineCallbacks
def _create_request(self, destination, method, path_bytes,
body_callback, headers_dict={}, param_bytes=b"",
query_bytes=b"", retry_on_dns_fail=True,
timeout=None):
query_bytes=b"", retry_on_dns_fail=True):
""" Creates and sends a request to the given url
"""
headers_dict[b"User-Agent"] = [self.version_string]
@@ -147,22 +144,22 @@ class MatrixFederationHttpClient(object):
producer = body_callback(method, url_bytes, headers_dict)
try:
request_deferred = preserve_context_over_fn(
self.agent.request,
destination,
endpoint,
method,
path_bytes,
param_bytes,
query_bytes,
Headers(headers_dict),
producer
)
with PreserveLoggingContext():
request_deferred = self.agent.request(
destination,
endpoint,
method,
path_bytes,
param_bytes,
query_bytes,
Headers(headers_dict),
producer
)
response = yield self.clock.time_bound_deferred(
request_deferred,
time_out=timeout/1000. if timeout else 60,
)
response = yield self.clock.time_bound_deferred(
request_deferred,
time_out=60,
)
logger.debug("Got response to %s", method)
break
@@ -184,7 +181,7 @@ class MatrixFederationHttpClient(object):
_flatten_response_never_received(e),
)
if retries_left and not timeout:
if retries_left:
yield sleep(2 ** (5 - retries_left))
retries_left -= 1
else:
@@ -337,8 +334,7 @@ class MatrixFederationHttpClient(object):
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
def get_json(self, destination, path, args={}, retry_on_dns_fail=True,
timeout=None):
def get_json(self, destination, path, args={}, retry_on_dns_fail=True):
""" GETs some json from the given host homeserver and path
Args:
@@ -347,9 +343,6 @@ class MatrixFederationHttpClient(object):
path (str): The HTTP path.
args (dict): A dictionary used to create query strings, defaults to
None.
timeout (int): How long to try (in ms) the destination for before
giving up. None indicates no timeout and that the request will
be retried.
Returns:
Deferred: Succeeds when we get *any* HTTP response.
@@ -377,8 +370,7 @@ class MatrixFederationHttpClient(object):
path.encode("ascii"),
query_bytes=query_bytes,
body_callback=body_callback,
retry_on_dns_fail=retry_on_dns_fail,
timeout=timeout,
retry_on_dns_fail=retry_on_dns_fail
)
if 200 <= response.code < 300:

View File

@@ -17,12 +17,11 @@
from synapse.api.errors import (
cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError
)
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.logcontext import LoggingContext
import synapse.metrics
import synapse.events
from syutil.jsonutil import (
encode_canonical_json, encode_pretty_printed_json, encode_json
encode_canonical_json, encode_pretty_printed_json
)
from twisted.internet import defer
@@ -86,9 +85,7 @@ def request_handler(request_handler):
"Received request: %s %s",
request.method, request.path
)
d = request_handler(self, request)
with PreserveLoggingContext():
yield d
yield request_handler(self, request)
code = request.code
except CodeMessageException as e:
code = e.code
@@ -169,10 +166,9 @@ class JsonResource(HttpServer, resource.Resource):
_PathEntry = collections.namedtuple("_PathEntry", ["pattern", "callback"])
def __init__(self, hs, canonical_json=True):
def __init__(self, hs):
resource.Resource.__init__(self)
self.canonical_json = canonical_json
self.clock = hs.get_clock()
self.path_regexs = {}
self.version_string = hs.version_string
@@ -258,7 +254,6 @@ class JsonResource(HttpServer, resource.Resource):
response_code_message=response_code_message,
pretty_print=_request_user_agent_is_curl(request),
version_string=self.version_string,
canonical_json=self.canonical_json,
)
@@ -280,16 +275,11 @@ class RootRedirect(resource.Resource):
def respond_with_json(request, code, json_object, send_cors=False,
response_code_message=None, pretty_print=False,
version_string="", canonical_json=True):
version_string=""):
if pretty_print:
json_bytes = encode_pretty_printed_json(json_object) + "\n"
else:
if canonical_json:
json_bytes = encode_canonical_json(json_object)
else:
json_bytes = encode_json(
json_object, using_frozen_dicts=synapse.events.USE_FROZEN_DICTS
)
json_bytes = encode_canonical_json(json_object)
return respond_with_json_bytes(
request, code, json_bytes,

View File

@@ -16,6 +16,7 @@
from twisted.internet import defer
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.async import run_on_reactor
from synapse.types import StreamToken
import synapse.metrics
@@ -42,78 +43,63 @@ def count(func, l):
class _NotificationListener(object):
""" This represents a single client connection to the events stream.
The events stream handler will have yielded to the deferred, so to
notify the handler it is sufficient to resolve the deferred.
"""
def __init__(self, deferred):
self.deferred = deferred
def notified(self):
return self.deferred.called
def notify(self, token):
""" Inform whoever is listening about the new events.
"""
try:
self.deferred.callback(token)
except defer.AlreadyCalledError:
pass
class _NotifierUserStream(object):
"""This represents a user connected to the event stream.
It tracks the most recent stream token for that user.
At a given point a user may have a number of streams listening for
events.
This listener will also keep track of which rooms it is listening in
so that it can remove itself from the indexes in the Notifier class.
"""
def __init__(self, user, rooms, current_token, time_now_ms,
def __init__(self, user, rooms, from_token, limit, timeout, deferred,
appservice=None):
self.user = str(user)
self.user = user
self.appservice = appservice
self.listeners = set()
self.rooms = set(rooms)
self.current_token = current_token
self.last_notified_ms = time_now_ms
self.from_token = from_token
self.limit = limit
self.timeout = timeout
self.deferred = deferred
self.rooms = rooms
self.timer = None
def notify(self, stream_key, stream_id, time_now_ms):
"""Notify any listeners for this user of a new event from an
event source.
Args:
stream_key(str): The stream the event came from.
stream_id(str): The new id for the stream the event came from.
time_now_ms(int): The current time in milliseconds.
"""
self.current_token = self.current_token.copy_and_advance(
stream_key, stream_id
)
if self.listeners:
self.last_notified_ms = time_now_ms
listeners = self.listeners
self.listeners = set()
for listener in listeners:
listener.notify(self.current_token)
def notified(self):
return self.deferred.called
def remove(self, notifier):
""" Remove this listener from all the indexes in the Notifier
def notify(self, notifier, events, start_token, end_token):
""" Inform whoever is listening about the new events. This will
also remove this listener from all the indexes in the Notifier
it knows about.
"""
result = (events, (start_token, end_token))
try:
self.deferred.callback(result)
notified_events_counter.inc_by(len(events))
except defer.AlreadyCalledError:
pass
# Should the following be done be using intrusively linked lists?
# -- erikj
for room in self.rooms:
lst = notifier.room_to_user_streams.get(room, set())
lst = notifier.room_to_listeners.get(room, set())
lst.discard(self)
notifier.user_to_user_stream.pop(self.user)
notifier.user_to_listeners.get(self.user, set()).discard(self)
if self.appservice:
notifier.appservice_to_user_streams.get(
notifier.appservice_to_listeners.get(
self.appservice, set()
).discard(self)
# Cancel the timeout for this notifer if one exists.
if self.timer is not None:
try:
notifier.clock.cancel_call_later(self.timer)
except:
logger.warn("Failed to cancel notifier timer")
class Notifier(object):
""" This class is responsible for notifying any listeners when there are
@@ -122,18 +108,14 @@ class Notifier(object):
Primarily used from the /events stream.
"""
UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000
def __init__(self, hs):
self.hs = hs
self.user_to_user_stream = {}
self.room_to_user_streams = {}
self.appservice_to_user_streams = {}
self.room_to_listeners = {}
self.user_to_listeners = {}
self.appservice_to_listeners = {}
self.event_sources = hs.get_event_sources()
self.store = hs.get_datastore()
self.pending_new_room_events = []
self.clock = hs.get_clock()
@@ -141,80 +123,47 @@ class Notifier(object):
"user_joined_room", self._user_joined_room
)
self.clock.looping_call(
self.remove_expired_streams, self.UNUSED_STREAM_EXPIRY_MS
)
# This is not a very cheap test to perform, but it's only executed
# when rendering the metrics page, which is likely once per minute at
# most when scraping it.
def count_listeners():
all_user_streams = set()
all_listeners = set()
for x in self.room_to_user_streams.values():
all_user_streams |= x
for x in self.user_to_user_stream.values():
all_user_streams.add(x)
for x in self.appservice_to_user_streams.values():
all_user_streams |= x
for x in self.room_to_listeners.values():
all_listeners |= x
for x in self.user_to_listeners.values():
all_listeners |= x
for x in self.appservice_to_listeners.values():
all_listeners |= x
return sum(len(stream.listeners) for stream in all_user_streams)
return len(all_listeners)
metrics.register_callback("listeners", count_listeners)
metrics.register_callback(
"rooms",
lambda: count(bool, self.room_to_user_streams.values()),
lambda: count(bool, self.room_to_listeners.values()),
)
metrics.register_callback(
"users",
lambda: len(self.user_to_user_stream),
lambda: count(bool, self.user_to_listeners.values()),
)
metrics.register_callback(
"appservices",
lambda: count(bool, self.appservice_to_user_streams.values()),
lambda: count(bool, self.appservice_to_listeners.values()),
)
@log_function
@defer.inlineCallbacks
def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
extra_users=[]):
def on_new_room_event(self, event, extra_users=[]):
""" Used by handlers to inform the notifier something has happened
in the room, room event wise.
This triggers the notifier to wake up any listeners that are
listening to the room, and any listeners for the users in the
`extra_users` param.
The events can be peristed out of order. The notifier will wait
until all previous events have been persisted before notifying
the client streams.
"""
yield run_on_reactor()
self.pending_new_room_events.append((
room_stream_id, event, extra_users
))
self._notify_pending_new_room_events(max_room_stream_id)
def _notify_pending_new_room_events(self, max_room_stream_id):
"""Notify for the room events that were queued waiting for a previous
event to be persisted.
Args:
max_room_stream_id(int): The highest stream_id below which all
events have been persisted.
"""
pending = self.pending_new_room_events
self.pending_new_room_events = []
for room_stream_id, event, extra_users in pending:
if room_stream_id > max_room_stream_id:
self.pending_new_room_events.append((
room_stream_id, event, extra_users
))
else:
self._on_new_room_event(event, room_stream_id, extra_users)
def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
"""Notify any user streams that are interested in this room event"""
# poke any interested application service.
self.hs.get_handlers().appservice_handler.notify_interested_services(
event
@@ -222,129 +171,194 @@ class Notifier(object):
room_id = event.room_id
room_user_streams = self.room_to_user_streams.get(room_id, set())
room_source = self.event_sources.sources["room"]
user_streams = room_user_streams.copy()
room_listeners = self.room_to_listeners.get(room_id, set())
_discard_if_notified(room_listeners)
listeners = room_listeners.copy()
for user in extra_users:
user_stream = self.user_to_user_stream.get(str(user))
if user_stream is not None:
user_streams.add(user_stream)
user_listeners = self.user_to_listeners.get(user, set())
for appservice in self.appservice_to_user_streams:
_discard_if_notified(user_listeners)
listeners |= user_listeners
for appservice in self.appservice_to_listeners:
# TODO (kegan): Redundant appservice listener checks?
# App services will already be in the room_to_user_streams set, but
# App services will already be in the room_to_listeners set, but
# that isn't enough. They need to be checked here in order to
# receive *invites* for users they are interested in. Does this
# make the room_to_user_streams check somewhat obselete?
# make the room_to_listeners check somewhat obselete?
if appservice.is_interested(event):
app_user_streams = self.appservice_to_user_streams.get(
app_listeners = self.appservice_to_listeners.get(
appservice, set()
)
user_streams |= app_user_streams
logger.debug("on_new_room_event listeners %s", user_streams)
_discard_if_notified(app_listeners)
time_now_ms = self.clock.time_msec()
for user_stream in user_streams:
try:
user_stream.notify(
"room_key", "s%d" % (room_stream_id,), time_now_ms
listeners |= app_listeners
logger.debug("on_new_room_event listeners %s", listeners)
# TODO (erikj): Can we make this more efficient by hitting the
# db once?
@defer.inlineCallbacks
def notify(listener):
events, end_key = yield room_source.get_new_events_for_user(
listener.user,
listener.from_token.room_key,
listener.limit,
)
if events:
end_token = listener.from_token.copy_and_replace(
"room_key", end_key
)
except:
logger.exception("Failed to notify listener")
listener.notify(
self, events, listener.from_token, end_token
)
def eb(failure):
logger.exception("Failed to notify listener", failure)
with PreserveLoggingContext():
yield defer.DeferredList(
[notify(l).addErrback(eb) for l in listeners],
consumeErrors=True,
)
@defer.inlineCallbacks
@log_function
def on_new_user_event(self, stream_key, new_token, users=[], rooms=[]):
def on_new_user_event(self, users=[], rooms=[]):
""" Used to inform listeners that something has happend
presence/user event wise.
Will wake up all listeners for the given users and rooms.
"""
yield run_on_reactor()
user_streams = set()
# TODO(paul): This is horrible, having to manually list every event
# source here individually
presence_source = self.event_sources.sources["presence"]
typing_source = self.event_sources.sources["typing"]
listeners = set()
for user in users:
user_stream = self.user_to_user_stream.get(str(user))
if user_stream is not None:
user_streams.add(user_stream)
user_listeners = self.user_to_listeners.get(user, set())
_discard_if_notified(user_listeners)
listeners |= user_listeners
for room in rooms:
user_streams |= self.room_to_user_streams.get(room, set())
room_listeners = self.room_to_listeners.get(room, set())
time_now_ms = self.clock.time_msec()
for user_stream in user_streams:
try:
user_stream.notify(stream_key, new_token, time_now_ms)
except:
logger.exception("Failed to notify listener")
_discard_if_notified(room_listeners)
listeners |= room_listeners
@defer.inlineCallbacks
def notify(listener):
presence_events, presence_end_key = (
yield presence_source.get_new_events_for_user(
listener.user,
listener.from_token.presence_key,
listener.limit,
)
)
typing_events, typing_end_key = (
yield typing_source.get_new_events_for_user(
listener.user,
listener.from_token.typing_key,
listener.limit,
)
)
if presence_events or typing_events:
end_token = listener.from_token.copy_and_replace(
"presence_key", presence_end_key
).copy_and_replace(
"typing_key", typing_end_key
)
listener.notify(
self,
presence_events + typing_events,
listener.from_token,
end_token
)
def eb(failure):
logger.error(
"Failed to notify listener",
exc_info=(
failure.type,
failure.value,
failure.getTracebackObject())
)
with PreserveLoggingContext():
yield defer.DeferredList(
[notify(l).addErrback(eb) for l in listeners],
consumeErrors=True,
)
@defer.inlineCallbacks
def wait_for_events(self, user, rooms, timeout, callback,
from_token=StreamToken("s0", "0", "0")):
def wait_for_events(self, user, rooms, filter, timeout, callback):
"""Wait until the callback returns a non empty response or the
timeout fires.
"""
deferred = defer.Deferred()
time_now_ms = self.clock.time_msec()
user = str(user)
user_stream = self.user_to_user_stream.get(user)
if user_stream is None:
appservice = yield self.store.get_app_service_by_user_id(user)
current_token = yield self.event_sources.get_current_token()
rooms = yield self.store.get_rooms_for_user(user)
rooms = [room.room_id for room in rooms]
user_stream = _NotifierUserStream(
user=user,
rooms=rooms,
appservice=appservice,
current_token=current_token,
time_now_ms=time_now_ms,
)
self._register_with_keys(user_stream)
else:
current_token = user_stream.current_token
from_token = StreamToken("s0", "0", "0")
listener = [_NotificationListener(deferred)]
listener = [_NotificationListener(
user=user,
rooms=rooms,
from_token=from_token,
limit=1,
timeout=timeout,
deferred=deferred,
)]
if timeout and not current_token.is_after(from_token):
user_stream.listeners.add(listener[0])
if current_token.is_after(from_token):
result = yield callback(from_token, current_token)
else:
result = None
if timeout:
self._register_with_keys(listener[0])
result = yield callback()
timer = [None]
if result:
user_stream.listeners.discard(listener[0])
defer.returnValue(result)
return
if timeout:
timed_out = [False]
def _timeout_listener():
timed_out[0] = True
timer[0] = None
user_stream.listeners.discard(listener[0])
listener[0].notify(current_token)
listener[0].notify(self, [], from_token, from_token)
# We create multiple notification listeners so we have to manage
# canceling the timeout ourselves.
timer[0] = self.clock.call_later(timeout/1000., _timeout_listener)
while not result and not timed_out[0]:
new_token = yield deferred
yield deferred
deferred = defer.Deferred()
listener[0] = _NotificationListener(deferred)
user_stream.listeners.add(listener[0])
result = yield callback(current_token, new_token)
current_token = new_token
listener[0] = _NotificationListener(
user=user,
rooms=rooms,
from_token=from_token,
limit=1,
timeout=timeout,
deferred=deferred,
)
self._register_with_keys(listener[0])
result = yield callback()
if timer[0] is not None:
try:
@@ -354,79 +368,125 @@ class Notifier(object):
defer.returnValue(result)
@defer.inlineCallbacks
def get_events_for(self, user, rooms, pagination_config, timeout):
""" For the given user and rooms, return any new events for them. If
there are no new events wait for up to `timeout` milliseconds for any
new events to happen before returning.
"""
from_token = pagination_config.from_token
deferred = defer.Deferred()
self._get_events(
deferred, user, rooms, pagination_config.from_token,
pagination_config.limit, timeout
).addErrback(deferred.errback)
return deferred
@defer.inlineCallbacks
def _get_events(self, deferred, user, rooms, from_token, limit, timeout):
if not from_token:
from_token = yield self.event_sources.get_current_token()
limit = pagination_config.limit
@defer.inlineCallbacks
def check_for_updates(before_token, after_token):
events = []
end_token = from_token
for name, source in self.event_sources.sources.items():
keyname = "%s_key" % name
before_id = getattr(before_token, keyname)
after_id = getattr(after_token, keyname)
if before_id == after_id:
continue
stuff, new_key = yield source.get_new_events_for_user(
user, getattr(from_token, keyname), limit,
)
events.extend(stuff)
end_token = end_token.copy_and_replace(keyname, new_key)
if events:
defer.returnValue((events, (from_token, end_token)))
else:
defer.returnValue(None)
result = yield self.wait_for_events(
user, rooms, timeout, check_for_updates, from_token=from_token
appservice = yield self.hs.get_datastore().get_app_service_by_user_id(
user.to_string()
)
if result is None:
result = ([], (from_token, from_token))
listener = _NotificationListener(
user,
rooms,
from_token,
limit,
timeout,
deferred,
appservice=appservice
)
defer.returnValue(result)
def _timeout_listener():
# TODO (erikj): We should probably set to_token to the current
# max rather than reusing from_token.
# Remove the timer from the listener so we don't try to cancel it.
listener.timer = None
listener.notify(
self,
[],
listener.from_token,
listener.from_token,
)
if timeout:
self._register_with_keys(listener)
yield self._check_for_updates(listener)
if not timeout:
_timeout_listener()
else:
# Only add the timer if the listener hasn't been notified
if not listener.notified():
listener.timer = self.clock.call_later(
timeout/1000.0, _timeout_listener
)
return
@log_function
def remove_expired_streams(self):
time_now_ms = self.clock.time_msec()
expired_streams = []
expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS
for stream in self.user_to_user_stream.values():
if stream.listeners:
continue
if stream.last_notified_ms < expire_before_ts:
expired_streams.append(stream)
def _register_with_keys(self, listener):
for room in listener.rooms:
s = self.room_to_listeners.setdefault(room, set())
s.add(listener)
for expired_stream in expired_streams:
expired_stream.remove(self)
self.user_to_listeners.setdefault(listener.user, set()).add(listener)
if listener.appservice:
self.appservice_to_listeners.setdefault(
listener.appservice, set()
).add(listener)
@defer.inlineCallbacks
@log_function
def _register_with_keys(self, user_stream):
self.user_to_user_stream[user_stream.user] = user_stream
def _check_for_updates(self, listener):
# TODO (erikj): We need to think about limits across multiple sources
events = []
for room in user_stream.rooms:
s = self.room_to_user_streams.setdefault(room, set())
s.add(user_stream)
from_token = listener.from_token
limit = listener.limit
if user_stream.appservice:
self.appservice_to_user_stream.setdefault(
user_stream.appservice, set()
).add(user_stream)
# TODO (erikj): DeferredList?
for name, source in self.event_sources.sources.items():
keyname = "%s_key" % name
stuff, new_key = yield source.get_new_events_for_user(
listener.user,
getattr(from_token, keyname),
limit,
)
events.extend(stuff)
from_token = from_token.copy_and_replace(keyname, new_key)
end_token = from_token
if events:
listener.notify(self, events, listener.from_token, end_token)
defer.returnValue(listener)
def _user_joined_room(self, user, room_id):
user = str(user)
new_user_stream = self.user_to_user_stream.get(user)
if new_user_stream is not None:
room_streams = self.room_to_user_streams.setdefault(room_id, set())
room_streams.add(new_user_stream)
new_user_stream.rooms.add(room_id)
new_listeners = self.user_to_listeners.get(user, set())
listeners = self.room_to_listeners.setdefault(room_id, set())
listeners |= new_listeners
for l in new_listeners:
l.rooms.add(room_id)
def _discard_if_notified(listener_set):
"""Remove any 'stale' listeners from the given set.
"""
to_discard = set()
for l in listener_set:
if l.notified():
to_discard.add(l)
listener_set -= to_discard

View File

@@ -24,7 +24,6 @@ import baserules
import logging
import simplejson as json
import re
import random
logger = logging.getLogger(__name__)
@@ -75,33 +74,35 @@ class Pusher(object):
rawrules = yield self.store.get_push_rules_for_user(self.user_name)
rules = []
for rawrule in rawrules:
rule = dict(rawrule)
rule['conditions'] = json.loads(rawrule['conditions'])
rule['actions'] = json.loads(rawrule['actions'])
rules.append(rule)
for r in rawrules:
r['conditions'] = json.loads(r['conditions'])
r['actions'] = json.loads(r['actions'])
enabled_map = yield self.store.get_push_rules_enabled_for_user(self.user_name)
user = UserID.from_string(self.user_name)
rules = baserules.list_with_base_rules(rules, user)
room_id = ev['room_id']
rules = baserules.list_with_base_rules(rawrules, user)
# get *our* member event for display name matching
my_display_name = None
our_member_event = yield self.store.get_current_state(
room_id=room_id,
member_events_for_room = yield self.store.get_current_state(
room_id=ev['room_id'],
event_type='m.room.member',
state_key=self.user_name,
state_key=None
)
if our_member_event:
my_display_name = our_member_event[0].content.get("displayname")
my_display_name = None
room_member_count = 0
for mev in member_events_for_room:
if mev.content['membership'] != 'join':
continue
room_members = yield self.store.get_users_in_room(room_id)
room_member_count = len(room_members)
# This loop does two things:
# 1) Find our current display name
if mev.state_key == self.user_name and 'displayname' in mev.content:
my_display_name = mev.content['displayname']
# and 2) Get the number of people in that room
room_member_count += 1
for r in rules:
if r['rule_id'] in enabled_map:
@@ -257,154 +258,132 @@ class Pusher(object):
logger.info("Pusher %s for user %s starting from token %s",
self.pushkey, self.user_name, self.last_token)
wait = 0
while self.alive:
try:
if wait > 0:
yield synapse.util.async.sleep(wait)
yield self.get_and_dispatch()
wait = 0
except:
if wait == 0:
wait = 1
else:
wait = min(wait * 2, 1800)
logger.exception(
"Exception in pusher loop for pushkey %s. Pausing for %ds",
self.pushkey, wait
)
@defer.inlineCallbacks
def get_and_dispatch(self):
from_tok = StreamToken.from_string(self.last_token)
config = PaginationConfig(from_token=from_tok, limit='1')
timeout = (300 + random.randint(-60, 60)) * 1000
chunk = yield self.evStreamHandler.get_stream(
self.user_name, config,
timeout=timeout, affect_presence=False
)
# limiting to 1 may get 1 event plus 1 presence event, so
# pick out the actual event
single_event = None
for c in chunk['chunk']:
if 'event_id' in c: # Hmmm...
single_event = c
break
if not single_event:
self.last_token = chunk['end']
logger.debug("Event stream timeout for pushkey %s", self.pushkey)
return
if not self.alive:
return
processed = False
actions = yield self._actions_for_event(single_event)
tweaks = _tweaks_for_actions(actions)
if len(actions) == 0:
logger.warn("Empty actions! Using default action.")
actions = Pusher.DEFAULT_ACTIONS
if 'notify' not in actions and 'dont_notify' not in actions:
logger.warn("Neither notify nor dont_notify in actions: adding default")
actions.extend(Pusher.DEFAULT_ACTIONS)
if 'dont_notify' in actions:
logger.debug(
"%s for %s: dont_notify",
single_event['event_id'], self.user_name
from_tok = StreamToken.from_string(self.last_token)
config = PaginationConfig(from_token=from_tok, limit='1')
chunk = yield self.evStreamHandler.get_stream(
self.user_name, config,
timeout=100*365*24*60*60*1000, affect_presence=False
)
processed = True
else:
rejected = yield self.dispatch_push(single_event, tweaks)
self.has_unread = True
if isinstance(rejected, list) or isinstance(rejected, tuple):
# limiting to 1 may get 1 event plus 1 presence event, so
# pick out the actual event
single_event = None
for c in chunk['chunk']:
if 'event_id' in c: # Hmmm...
single_event = c
break
if not single_event:
self.last_token = chunk['end']
continue
if not self.alive:
continue
processed = False
actions = yield self._actions_for_event(single_event)
tweaks = _tweaks_for_actions(actions)
if len(actions) == 0:
logger.warn("Empty actions! Using default action.")
actions = Pusher.DEFAULT_ACTIONS
if 'notify' not in actions and 'dont_notify' not in actions:
logger.warn("Neither notify nor dont_notify in actions: adding default")
actions.extend(Pusher.DEFAULT_ACTIONS)
if 'dont_notify' in actions:
logger.debug(
"%s for %s: dont_notify",
single_event['event_id'], self.user_name
)
processed = True
for pk in rejected:
if pk != self.pushkey:
# for sanity, we only remove the pushkey if it
# was the one we actually sent...
logger.warn(
("Ignoring rejected pushkey %s because we"
" didn't send it"), pk
)
else:
logger.info(
"Pushkey %s was rejected: removing",
pk
)
yield self.hs.get_pusherpool().remove_pusher(
self.app_id, pk, self.user_name
)
else:
rejected = yield self.dispatch_push(single_event, tweaks)
self.has_unread = True
if isinstance(rejected, list) or isinstance(rejected, tuple):
processed = True
for pk in rejected:
if pk != self.pushkey:
# for sanity, we only remove the pushkey if it
# was the one we actually sent...
logger.warn(
("Ignoring rejected pushkey %s because we"
" didn't send it"), pk
)
else:
logger.info(
"Pushkey %s was rejected: removing",
pk
)
yield self.hs.get_pusherpool().remove_pusher(
self.app_id, pk, self.user_name
)
if not self.alive:
return
if not self.alive:
continue
if processed:
self.backoff_delay = Pusher.INITIAL_BACKOFF
self.last_token = chunk['end']
self.store.update_pusher_last_token_and_success(
self.app_id,
self.pushkey,
self.user_name,
self.last_token,
self.clock.time_msec()
)
if self.failing_since:
self.failing_since = None
self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
self.user_name,
self.failing_since)
else:
if not self.failing_since:
self.failing_since = self.clock.time_msec()
self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
self.user_name,
self.failing_since
)
if (self.failing_since and
self.failing_since <
self.clock.time_msec() - Pusher.GIVE_UP_AFTER):
# we really only give up so that if the URL gets
# fixed, we don't suddenly deliver a load
# of old notifications.
logger.warn("Giving up on a notification to user %s, "
"pushkey %s",
self.user_name, self.pushkey)
if processed:
self.backoff_delay = Pusher.INITIAL_BACKOFF
self.last_token = chunk['end']
self.store.update_pusher_last_token(
self.store.update_pusher_last_token_and_success(
self.app_id,
self.pushkey,
self.user_name,
self.last_token
)
self.failing_since = None
self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
self.user_name,
self.failing_since
self.last_token,
self.clock.time_msec()
)
if self.failing_since:
self.failing_since = None
self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
self.user_name,
self.failing_since)
else:
logger.warn("Failed to dispatch push for user %s "
"(failing for %dms)."
"Trying again in %dms",
self.user_name,
self.clock.time_msec() - self.failing_since,
self.backoff_delay)
yield synapse.util.async.sleep(self.backoff_delay / 1000.0)
self.backoff_delay *= 2
if self.backoff_delay > Pusher.MAX_BACKOFF:
self.backoff_delay = Pusher.MAX_BACKOFF
if not self.failing_since:
self.failing_since = self.clock.time_msec()
self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
self.user_name,
self.failing_since
)
if (self.failing_since and
self.failing_since <
self.clock.time_msec() - Pusher.GIVE_UP_AFTER):
# we really only give up so that if the URL gets
# fixed, we don't suddenly deliver a load
# of old notifications.
logger.warn("Giving up on a notification to user %s, "
"pushkey %s",
self.user_name, self.pushkey)
self.backoff_delay = Pusher.INITIAL_BACKOFF
self.last_token = chunk['end']
self.store.update_pusher_last_token(
self.app_id,
self.pushkey,
self.user_name,
self.last_token
)
self.failing_since = None
self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
self.user_name,
self.failing_since
)
else:
logger.warn("Failed to dispatch push for user %s "
"(failing for %dms)."
"Trying again in %dms",
self.user_name,
self.clock.time_msec() - self.failing_since,
self.backoff_delay)
yield synapse.util.async.sleep(self.backoff_delay / 1000.0)
self.backoff_delay *= 2
if self.backoff_delay > Pusher.MAX_BACKOFF:
self.backoff_delay = Pusher.MAX_BACKOFF
def stop(self):
self.alive = False

View File

@@ -18,23 +18,22 @@ from distutils.version import LooseVersion
logger = logging.getLogger(__name__)
REQUIREMENTS = {
"syutil>=0.0.7": ["syutil>=0.0.7"],
"syutil>=0.0.6": ["syutil>=0.0.6"],
"Twisted==14.0.2": ["twisted==14.0.2"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"],
"pyyaml": ["yaml"],
"pyasn1": ["pyasn1"],
"pynacl>=0.0.3": ["nacl>=0.0.3"],
"pynacl": ["nacl"],
"daemonize": ["daemonize"],
"py-bcrypt": ["bcrypt"],
"frozendict>=0.4": ["frozendict"],
"pillow": ["PIL"],
"pydenticon": ["pydenticon"],
"ujson": ["ujson"],
}
CONDITIONAL_REQUIREMENTS = {
"web_client": {
"matrix_angular_sdk>=0.6.6": ["syweb>=0.6.6"],
"matrix_angular_sdk>=0.6.5": ["syweb>=0.6.5"],
}
}
@@ -51,15 +50,20 @@ def github_link(project, version, egg):
return "https://github.com/%s/tarball/%s/#egg=%s" % (project, version, egg)
DEPENDENCY_LINKS = [
github_link(
project="pyca/pynacl",
version="d4d3175589b892f6ea7c22f466e0e223853516fa",
egg="pynacl-0.3.0",
),
github_link(
project="matrix-org/syutil",
version="v0.0.7",
egg="syutil-0.0.7",
version="v0.0.6",
egg="syutil-0.0.6",
),
github_link(
project="matrix-org/matrix-angular-sdk",
version="v0.6.6",
egg="matrix_angular_sdk-0.6.6",
version="v0.6.5",
egg="matrix_angular_sdk-0.6.5",
),
]

View File

@@ -25,7 +25,7 @@ class ClientV1RestResource(JsonResource):
"""A resource for version 1 of the matrix client API."""
def __init__(self, hs):
JsonResource.__init__(self, hs, canonical_json=False)
JsonResource.__init__(self, hs)
self.register_servlets(self, hs)
@staticmethod

View File

@@ -118,14 +118,11 @@ class PushRuleRestServlet(ClientV1RestServlet):
user.to_string()
)
ruleslist = []
for rawrule in rawrules:
rule = dict(rawrule)
rule["conditions"] = json.loads(rawrule["conditions"])
rule["actions"] = json.loads(rawrule["actions"])
ruleslist.append(rule)
for r in rawrules:
r["conditions"] = json.loads(r["conditions"])
r["actions"] = json.loads(r["actions"])
ruleslist = baserules.list_with_base_rules(ruleslist, user)
ruleslist = baserules.list_with_base_rules(rawrules, user)
rules = {'global': {}, 'device': {}}

View File

@@ -28,7 +28,7 @@ class ClientV2AlphaRestResource(JsonResource):
"""A resource for version 2 alpha of the matrix client API."""
def __init__(self, hs):
JsonResource.__init__(self, hs, canonical_json=False)
JsonResource.__init__(self, hs)
self.register_servlets(self, hs)
@staticmethod

View File

@@ -65,12 +65,12 @@ class PasswordRestServlet(RestServlet):
if 'medium' not in threepid or 'address' not in threepid:
raise SynapseError(500, "Malformed threepid")
# if using email, we must know about the email they're authing with!
threepid_user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
threepid_user = yield self.hs.get_datastore().get_user_by_threepid(
threepid['medium'], threepid['address']
)
if not threepid_user_id:
if not threepid_user:
raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
user_id = threepid_user_id
user_id = threepid_user
else:
logger.error("Auth succeeded but no known type!", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN)

View File

@@ -82,10 +82,8 @@ class RegisterRestServlet(RestServlet):
[LoginType.EMAIL_IDENTITY]
]
result = None
if service:
is_application_server = True
params = body
elif 'mac' in body:
# Check registration-specific shared secret auth
if 'username' not in body:
@@ -94,7 +92,6 @@ class RegisterRestServlet(RestServlet):
body['username'], body['mac']
)
is_using_shared_secret = True
params = body
else:
authed, result, params = yield self.auth_handler.check_auth(
flows, body, self.hs.get_ip_from_request(request)
@@ -121,7 +118,7 @@ class RegisterRestServlet(RestServlet):
password=new_password
)
if result and LoginType.EMAIL_IDENTITY in result:
if LoginType.EMAIL_IDENTITY in result:
threepid = result[LoginType.EMAIL_IDENTITY]
for reqd in ['medium', 'address', 'validated_at']:

View File

@@ -15,18 +15,17 @@
from .thumbnailer import Thumbnailer
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.http.server import respond_with_json
from synapse.util.stringutils import random_string
from synapse.api.errors import (
cs_error, Codes, SynapseError
)
from twisted.internet import defer, threads
from twisted.internet import defer
from twisted.web.resource import Resource
from twisted.protocols.basic import FileSender
from synapse.util.async import ObservableDeferred
from synapse.util.async import create_observer
import os
@@ -53,7 +52,7 @@ class BaseMediaResource(Resource):
def __init__(self, hs, filepaths):
Resource.__init__(self)
self.auth = hs.get_auth()
self.client = MatrixFederationHttpClient(hs)
self.client = hs.get_http_client()
self.clock = hs.get_clock()
self.server_name = hs.hostname
self.store = hs.get_datastore()
@@ -84,17 +83,13 @@ class BaseMediaResource(Resource):
download = self.downloads.get(key)
if download is None:
download = self._get_remote_media_impl(server_name, media_id)
download = ObservableDeferred(
download,
consumeErrors=True
)
self.downloads[key] = download
@download.addBoth
def callback(media_info):
del self.downloads[key]
return media_info
return download.observe()
return create_observer(download)
@defer.inlineCallbacks
def _get_remote_media_impl(self, server_name, media_id):
@@ -274,65 +269,57 @@ class BaseMediaResource(Resource):
if not requirements:
return
remote_thumbnails = []
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width
m_height = thumbnailer.height
def generate_thumbnails():
if m_width * m_height >= self.max_image_pixels:
logger.info(
"Image too large to thumbnail %r x %r > %r",
m_width, m_height, self.max_image_pixels
)
return
if m_width * m_height >= self.max_image_pixels:
logger.info(
"Image too large to thumbnail %r x %r > %r",
m_width, m_height, self.max_image_pixels
)
return
scales = set()
crops = set()
for r_width, r_height, r_method, r_type in requirements:
if r_method == "scale":
t_width, t_height = thumbnailer.aspect(r_width, r_height)
scales.add((
min(m_width, t_width), min(m_height, t_height), r_type,
))
elif r_method == "crop":
crops.add((r_width, r_height, r_type))
scales = set()
crops = set()
for r_width, r_height, r_method, r_type in requirements:
if r_method == "scale":
t_width, t_height = thumbnailer.aspect(r_width, r_height)
scales.add((
min(m_width, t_width), min(m_height, t_height), r_type,
))
elif r_method == "crop":
crops.add((r_width, r_height, r_type))
for t_width, t_height, t_type in scales:
t_method = "scale"
t_path = self.filepaths.remote_media_thumbnail(
server_name, file_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
remote_thumbnails.append([
server_name, media_id, file_id,
t_width, t_height, t_type, t_method, t_len
])
for t_width, t_height, t_type in scales:
t_method = "scale"
t_path = self.filepaths.remote_media_thumbnail(
server_name, file_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
yield self.store.store_remote_media_thumbnail(
server_name, media_id, file_id,
t_width, t_height, t_type, t_method, t_len
)
for t_width, t_height, t_type in crops:
if (t_width, t_height, t_type) in scales:
# If the aspect ratio of the cropped thumbnail matches a purely
# scaled one then there is no point in calculating a separate
# thumbnail.
continue
t_method = "crop"
t_path = self.filepaths.remote_media_thumbnail(
server_name, file_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
remote_thumbnails.append([
server_name, media_id, file_id,
t_width, t_height, t_type, t_method, t_len
])
yield threads.deferToThread(generate_thumbnails)
for r in remote_thumbnails:
yield self.store.store_remote_media_thumbnail(*r)
for t_width, t_height, t_type in crops:
if (t_width, t_height, t_type) in scales:
# If the aspect ratio of the cropped thumbnail matches a purely
# scaled one then there is no point in calculating a separate
# thumbnail.
continue
t_method = "crop"
t_path = self.filepaths.remote_media_thumbnail(
server_name, file_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
yield self.store.store_remote_media_thumbnail(
server_name, media_id, file_id,
t_width, t_height, t_type, t_method, t_len
)
defer.returnValue({
"width": m_width,

View File

@@ -59,6 +59,7 @@ class BaseHomeServer(object):
'config',
'clock',
'http_client',
'db_name',
'db_pool',
'persistence_service',
'replication_layer',

View File

@@ -106,7 +106,7 @@ class StateHandler(object):
defer.returnValue(state)
@defer.inlineCallbacks
def compute_event_context(self, event, old_state=None, outlier=False):
def compute_event_context(self, event, old_state=None):
""" Fills out the context with the `current state` of the graph. The
`current state` here is defined to be the state of the event graph
just before the event - i.e. it never includes `event`
@@ -119,23 +119,9 @@ class StateHandler(object):
Returns:
an EventContext
"""
yield run_on_reactor()
context = EventContext()
if outlier:
# If this is an outlier, then we know it shouldn't have any current
# state. Certainly store.get_current_state won't return any, and
# persisting the event won't store the state group.
if old_state:
context.current_state = {
(s.type, s.state_key): s for s in old_state
}
else:
context.current_state = {}
context.prev_state_events = []
context.state_group = None
defer.returnValue(context)
yield run_on_reactor()
if old_state:
context.current_state = {
@@ -169,6 +155,10 @@ class StateHandler(object):
context.current_state = curr_state
context.state_group = group if not event.is_state() else None
prev_state = yield self.store.add_event_hashes(
prev_state
)
if event.is_state():
key = (event.type, event.state_key)
if key in context.current_state:

View File

@@ -51,7 +51,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 20
SCHEMA_VERSION = 17
dir_path = os.path.abspath(os.path.dirname(__file__))
@@ -348,7 +348,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
module_name, absolute_path, python_file
)
logger.debug("Running script %s", relative_path)
module.run_upgrade(cur, database_engine)
module.run_upgrade(cur)
elif ext == ".sql":
# A plain old .sql file, just read and execute it
logger.debug("Applying schema %s", relative_path)

View File

@@ -15,8 +15,10 @@
import logging
from synapse.api.errors import StoreError
from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
from synapse.util.lrucache import LruCache
import synapse.metrics
@@ -25,13 +27,11 @@ from util.id_generators import IdGenerator, StreamIdGenerator
from twisted.internet import defer
from collections import namedtuple, OrderedDict
import functools
import simplejson as json
import sys
import time
import threading
DEBUG_CACHES = False
logger = logging.getLogger(__name__)
@@ -46,6 +46,7 @@ sql_scheduling_timer = metrics.register_distribution("schedule_time")
sql_query_timer = metrics.register_distribution("query_time", labels=["verb"])
sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"])
sql_getevents_timer = metrics.register_distribution("getEvents_time", labels=["desc"])
caches_by_name = {}
cache_counter = metrics.register_cache(
@@ -67,19 +68,8 @@ class Cache(object):
self.name = name
self.keylen = keylen
self.sequence = 0
self.thread = None
caches_by_name[name] = self.cache
def check_thread(self):
expected_thread = self.thread
if expected_thread is None:
self.thread = threading.current_thread()
else:
if expected_thread is not threading.current_thread():
raise ValueError(
"Cache objects can only be accessed from the main thread"
)
caches_by_name[name] = self.cache
def get(self, *keyargs):
if len(keyargs) != self.keylen:
@@ -92,13 +82,6 @@ class Cache(object):
cache_counter.inc_misses(self.name)
raise KeyError()
def update(self, sequence, *args):
self.check_thread()
if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
self.prefill(*args)
def prefill(self, *args): # because I can't *keyargs, value
keyargs = args[:-1]
value = args[-1]
@@ -113,21 +96,13 @@ class Cache(object):
self.cache[keyargs] = value
def invalidate(self, *keyargs):
self.check_thread()
if len(keyargs) != self.keylen:
raise ValueError("Expected a key to have %d items", self.keylen)
# Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369)
self.sequence += 1
self.cache.pop(keyargs, None)
def invalidate_all(self):
self.check_thread()
self.sequence += 1
self.cache.clear()
class CacheDescriptor(object):
def cached(max_entries=1000, num_args=1, lru=False):
""" A method decorator that applies a memoizing cache around the function.
The function is presumed to take zero or more arguments, which are used in
@@ -141,84 +116,43 @@ class CacheDescriptor(object):
which can be used to insert values into the cache specifically, without
calling the calculation function.
"""
def __init__(self, orig, max_entries=1000, num_args=1, lru=False):
self.orig = orig
self.max_entries = max_entries
self.num_args = num_args
self.lru = lru
def __get__(self, obj, objtype=None):
def wrap(orig):
cache = Cache(
name=self.orig.__name__,
max_entries=self.max_entries,
keylen=self.num_args,
lru=self.lru,
name=orig.__name__,
max_entries=max_entries,
keylen=num_args,
lru=lru,
)
@functools.wraps(self.orig)
@functools.wraps(orig)
@defer.inlineCallbacks
def wrapped(*keyargs):
def wrapped(self, *keyargs):
try:
cached_result = cache.get(*keyargs[:self.num_args])
if DEBUG_CACHES:
actual_result = yield self.orig(obj, *keyargs)
if actual_result != cached_result:
logger.error(
"Stale cache entry %s%r: cached: %r, actual %r",
self.orig.__name__, keyargs,
cached_result, actual_result,
)
raise ValueError("Stale cache entry")
defer.returnValue(cached_result)
defer.returnValue(cache.get(*keyargs))
except KeyError:
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
# while the SELECT is executing (SYN-369)
sequence = cache.sequence
ret = yield orig(self, *keyargs)
ret = yield self.orig(obj, *keyargs)
cache.update(sequence, *keyargs[:self.num_args] + (ret,))
cache.prefill(*keyargs + (ret,))
defer.returnValue(ret)
wrapped.invalidate = cache.invalidate
wrapped.invalidate_all = cache.invalidate_all
wrapped.prefill = cache.prefill
obj.__dict__[self.orig.__name__] = wrapped
return wrapped
def cached(max_entries=1000, num_args=1, lru=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
lru=lru
)
return wrap
class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
method."""
__slots__ = ["txn", "name", "database_engine", "after_callbacks"]
__slots__ = ["txn", "name", "database_engine"]
def __init__(self, txn, name, database_engine, after_callbacks):
def __init__(self, txn, name, database_engine):
object.__setattr__(self, "txn", txn)
object.__setattr__(self, "name", name)
object.__setattr__(self, "database_engine", database_engine)
object.__setattr__(self, "after_callbacks", after_callbacks)
def call_after(self, callback, *args):
"""Call the given callback on the main twisted thread after the
transaction has finished. Used to invalidate the caches on the
correct thread.
"""
self.after_callbacks.append((callback, args))
def __getattr__(self, name):
return getattr(self.txn, name)
@@ -226,23 +160,22 @@ class LoggingTransaction(object):
def __setattr__(self, name, value):
setattr(self.txn, name, value)
def execute(self, sql, *args):
self._do_execute(self.txn.execute, sql, *args)
def executemany(self, sql, *args):
self._do_execute(self.txn.executemany, sql, *args)
def _do_execute(self, func, sql, *args):
def execute(self, sql, *args, **kwargs):
# TODO(paul): Maybe use 'info' and 'debug' for values?
sql_logger.debug("[SQL] {%s} %s", self.name, sql)
sql = self.database_engine.convert_param_style(sql)
if args:
if args and args[0]:
args = list(args)
args[0] = [
self.database_engine.encode_parameter(a) for a in args[0]
]
try:
sql_logger.debug(
"[SQL values] {%s} %r",
self.name, args[0]
"[SQL values] {%s} " + ", ".join(("<%r>",) * len(args[0])),
self.name,
*args[0]
)
except:
# Don't let logging failures stop SQL from working
@@ -251,8 +184,8 @@ class LoggingTransaction(object):
start = time.time() * 1000
try:
return func(
sql, *args
return self.txn.execute(
sql, *args, **kwargs
)
except Exception as e:
logger.debug("[SQL FAIL] {%s} %s", self.name, e)
@@ -321,12 +254,6 @@ class SQLBaseStore(object):
self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
max_entries=hs.config.event_cache_size)
self._event_fetch_lock = threading.Condition()
self._event_fetch_list = []
self._event_fetch_ongoing = 0
self._pending_ds = []
self.database_engine = hs.database_engine
self._stream_id_gen = StreamIdGenerator()
@@ -334,8 +261,6 @@ class SQLBaseStore(object):
self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
self._pushers_id_gen = IdGenerator("pushers", "id", self)
self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec()
@@ -366,75 +291,6 @@ class SQLBaseStore(object):
self._clock.looping_call(loop, 10000)
def _new_transaction(self, conn, desc, after_callbacks, func, *args, **kwargs):
start = time.time() * 1000
txn_id = self._TXN_ID
# We don't really need these to be unique, so lets stop it from
# growing really large.
self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
name = "%s-%x" % (desc, txn_id, )
transaction_logger.debug("[TXN START] {%s}", name)
try:
i = 0
N = 5
while True:
try:
txn = conn.cursor()
txn = LoggingTransaction(
txn, name, self.database_engine, after_callbacks
)
r = func(txn, *args, **kwargs)
conn.commit()
return r
except self.database_engine.module.OperationalError as e:
# This can happen if the database disappears mid
# transaction.
logger.warn(
"[TXN OPERROR] {%s} %s %d/%d",
name, e, i, N
)
if i < N:
i += 1
try:
conn.rollback()
except self.database_engine.module.Error as e1:
logger.warn(
"[TXN EROLL] {%s} %s",
name, e1,
)
continue
raise
except self.database_engine.module.DatabaseError as e:
if self.database_engine.is_deadlock(e):
logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
if i < N:
i += 1
try:
conn.rollback()
except self.database_engine.module.Error as e1:
logger.warn(
"[TXN EROLL] {%s} %s",
name, e1,
)
continue
raise
except Exception as e:
logger.debug("[TXN FAIL] {%s} %s", name, e)
raise
finally:
end = time.time() * 1000
duration = end - start
transaction_logger.debug("[TXN END] {%s} %f", name, duration)
self._current_txn_total_time += duration
self._txn_perf_counters.update(desc, start, end)
sql_txn_timer.inc_by(duration, desc)
@defer.inlineCallbacks
def runInteraction(self, desc, func, *args, **kwargs):
"""Wraps the .runInteraction() method on the underlying db_pool."""
@@ -442,54 +298,82 @@ class SQLBaseStore(object):
start_time = time.time() * 1000
after_callbacks = []
def inner_func(conn, *args, **kwargs):
with LoggingContext("runInteraction") as context:
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
if self.database_engine.is_connection_closed(conn):
logger.debug("Reconnecting closed database connection")
conn.reconnect()
current_context.copy_to(context)
return self._new_transaction(
conn, desc, after_callbacks, func, *args, **kwargs
)
start = time.time() * 1000
txn_id = self._TXN_ID
result = yield preserve_context_over_fn(
self._db_pool.runWithConnection,
inner_func, *args, **kwargs
)
# We don't really need these to be unique, so lets stop it from
# growing really large.
self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
for after_callback, after_args in after_callbacks:
after_callback(*after_args)
defer.returnValue(result)
name = "%s-%x" % (desc, txn_id, )
@defer.inlineCallbacks
def runWithConnection(self, func, *args, **kwargs):
"""Wraps the .runInteraction() method on the underlying db_pool."""
current_context = LoggingContext.current_context()
start_time = time.time() * 1000
def inner_func(conn, *args, **kwargs):
with LoggingContext("runWithConnection") as context:
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
transaction_logger.debug("[TXN START] {%s}", name)
try:
i = 0
N = 5
while True:
try:
txn = conn.cursor()
return func(
LoggingTransaction(txn, name, self.database_engine),
*args, **kwargs
)
except self.database_engine.module.OperationalError as e:
# This can happen if the database disappears mid
# transaction.
logger.warn(
"[TXN OPERROR] {%s} %s %d/%d",
name, e, i, N
)
if i < N:
i += 1
try:
conn.rollback()
except self.database_engine.module.Error as e1:
logger.warn(
"[TXN EROLL] {%s} %s",
name, e1,
)
continue
except self.database_engine.module.DatabaseError as e:
if self.database_engine.is_deadlock(e):
logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
if i < N:
i += 1
try:
conn.rollback()
except self.database_engine.module.Error as e1:
logger.warn(
"[TXN EROLL] {%s} %s",
name, e1,
)
continue
raise
except Exception as e:
logger.debug("[TXN FAIL] {%s} %s", name, e)
raise
finally:
end = time.time() * 1000
duration = end - start
if self.database_engine.is_connection_closed(conn):
logger.debug("Reconnecting closed database connection")
conn.reconnect()
transaction_logger.debug("[TXN END] {%s} %f", name, duration)
current_context.copy_to(context)
return func(conn, *args, **kwargs)
result = yield preserve_context_over_fn(
self._db_pool.runWithConnection,
inner_func, *args, **kwargs
)
self._current_txn_total_time += duration
self._txn_perf_counters.update(desc, start, end)
sql_txn_timer.inc_by(duration, desc)
with PreserveLoggingContext():
result = yield self._db_pool.runWithConnection(
inner_func, *args, **kwargs
)
defer.returnValue(result)
def cursor_to_dict(self, cursor):
@@ -554,49 +438,18 @@ class SQLBaseStore(object):
@log_function
def _simple_insert_txn(self, txn, table, values):
keys, vals = zip(*values.items())
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
table,
", ".join(k for k in keys),
", ".join("?" for _ in keys)
", ".join(k for k in values),
", ".join("?" for k in values)
)
txn.execute(sql, vals)
def _simple_insert_many_txn(self, txn, table, values):
if not values:
return
# This is a *slight* abomination to get a list of tuples of key names
# and a list of tuples of value names.
#
# i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}]
# => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)]
#
# The sort is to ensure that we don't rely on dictionary iteration
# order.
keys, vals = zip(*[
zip(
*(sorted(i.items(), key=lambda kv: kv[0]))
)
for i in values
if i
])
for k in keys:
if k != keys[0]:
raise RuntimeError(
"All items must have the same keys"
)
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
table,
", ".join(k for k in keys[0]),
", ".join("?" for _ in keys[0])
logger.debug(
"[SQL] %s Args=%s",
sql, values.values(),
)
txn.executemany(sql, vals)
txn.execute(sql, values.values())
def _simple_upsert(self, table, keyvalues, values,
insertion_values={}, desc="_simple_upsert", lock=True):
@@ -929,6 +782,158 @@ class SQLBaseStore(object):
return self.runInteraction("_simple_max_id", func)
def _get_events(self, event_ids, check_redacted=True,
get_prev_content=False):
return self.runInteraction(
"_get_events", self._get_events_txn, event_ids,
check_redacted=check_redacted, get_prev_content=get_prev_content,
)
def _get_events_txn(self, txn, event_ids, check_redacted=True,
get_prev_content=False):
if not event_ids:
return []
events = [
self._get_event_txn(
txn, event_id,
check_redacted=check_redacted,
get_prev_content=get_prev_content
)
for event_id in event_ids
]
return [e for e in events if e]
def _invalidate_get_event_cache(self, event_id):
for check_redacted in (False, True):
for get_prev_content in (False, True):
self._get_event_cache.invalidate(event_id, check_redacted,
get_prev_content)
def _get_event_txn(self, txn, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False):
start_time = time.time() * 1000
def update_counter(desc, last_time):
curr_time = self._get_event_counters.update(desc, last_time)
sql_getevents_timer.inc_by(curr_time - last_time, desc)
return curr_time
try:
ret = self._get_event_cache.get(event_id, check_redacted, get_prev_content)
if allow_rejected or not ret.rejected_reason:
return ret
else:
return None
except KeyError:
pass
finally:
start_time = update_counter("event_cache", start_time)
sql = (
"SELECT e.internal_metadata, e.json, r.event_id, rej.reason "
"FROM event_json as e "
"LEFT JOIN redactions as r ON e.event_id = r.redacts "
"LEFT JOIN rejections as rej on rej.event_id = e.event_id "
"WHERE e.event_id = ? "
"LIMIT 1 "
)
txn.execute(sql, (event_id,))
res = txn.fetchone()
if not res:
return None
internal_metadata, js, redacted, rejected_reason = res
start_time = update_counter("select_event", start_time)
result = self._get_event_from_row_txn(
txn, internal_metadata, js, redacted,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
rejected_reason=rejected_reason,
)
self._get_event_cache.prefill(event_id, check_redacted, get_prev_content, result)
if allow_rejected or not rejected_reason:
return result
else:
return None
def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted,
check_redacted=True, get_prev_content=False,
rejected_reason=None):
start_time = time.time() * 1000
def update_counter(desc, last_time):
curr_time = self._get_event_counters.update(desc, last_time)
sql_getevents_timer.inc_by(curr_time - last_time, desc)
return curr_time
d = json.loads(js)
start_time = update_counter("decode_json", start_time)
internal_metadata = json.loads(internal_metadata)
start_time = update_counter("decode_internal", start_time)
ev = FrozenEvent(
d,
internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason,
)
start_time = update_counter("build_frozen_event", start_time)
if check_redacted and redacted:
ev = prune_event(ev)
ev.unsigned["redacted_by"] = redacted
# Get the redaction event.
because = self._get_event_txn(
txn,
redacted,
check_redacted=False
)
if because:
ev.unsigned["redacted_because"] = because
start_time = update_counter("redact_event", start_time)
if get_prev_content and "replaces_state" in ev.unsigned:
prev = self._get_event_txn(
txn,
ev.unsigned["replaces_state"],
get_prev_content=False,
)
if prev:
ev.unsigned["prev_content"] = prev.get_dict()["content"]
start_time = update_counter("get_prev_content", start_time)
return ev
def _parse_events(self, rows):
return self.runInteraction(
"_parse_events", self._parse_events_txn, rows
)
def _parse_events_txn(self, txn, rows):
event_ids = [r["event_id"] for r in rows]
return self._get_events_txn(txn, event_ids)
def _has_been_redacted_txn(self, txn, event):
sql = "SELECT event_id FROM redactions WHERE redacts = ?"
txn.execute(sql, (event.event_id,))
result = txn.fetchone()
return result[0] if result else None
def get_next_stream_id(self):
with self._next_stream_id_lock:
i = self._next_stream_id

View File

@@ -19,8 +19,6 @@ from ._base import IncorrectDatabaseSetup
class PostgresEngine(object):
single_threaded = False
def __init__(self, database_module):
self.module = database_module
self.module.extensions.register_type(self.module.extensions.UNICODE)
@@ -38,6 +36,9 @@ class PostgresEngine(object):
def convert_param_style(self, sql):
return sql.replace("?", "%s")
def encode_parameter(self, param):
return param
def on_new_connection(self, db_conn):
db_conn.set_isolation_level(
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ

View File

@@ -17,8 +17,6 @@ from synapse.storage import prepare_database, prepare_sqlite3_database
class Sqlite3Engine(object):
single_threaded = True
def __init__(self, database_module):
self.module = database_module
@@ -28,6 +26,9 @@ class Sqlite3Engine(object):
def convert_param_style(self, sql):
return sql
def encode_parameter(self, param):
return param
def on_new_connection(self, db_conn):
self.prepare_database(db_conn)

View File

@@ -13,13 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import SQLBaseStore, cached
from ._base import SQLBaseStore
from syutil.base64util import encode_base64
import logging
from Queue import PriorityQueue, Empty
logger = logging.getLogger(__name__)
@@ -36,7 +33,16 @@ class EventFederationStore(SQLBaseStore):
"""
def get_auth_chain(self, event_ids):
return self.get_auth_chain_ids(event_ids).addCallback(self._get_events)
return self.runInteraction(
"get_auth_chain",
self._get_auth_chain_txn,
event_ids
)
def _get_auth_chain_txn(self, txn, event_ids):
results = self._get_auth_chain_ids_txn(txn, event_ids)
return self._get_events_txn(txn, results)
def get_auth_chain_ids(self, event_ids):
return self.runInteraction(
@@ -73,28 +79,6 @@ class EventFederationStore(SQLBaseStore):
room_id,
)
def get_oldest_events_with_depth_in_room(self, room_id):
return self.runInteraction(
"get_oldest_events_with_depth_in_room",
self.get_oldest_events_with_depth_in_room_txn,
room_id,
)
def get_oldest_events_with_depth_in_room_txn(self, txn, room_id):
sql = (
"SELECT b.event_id, MAX(e.depth) FROM events as e"
" INNER JOIN event_edges as g"
" ON g.event_id = e.event_id AND g.room_id = e.room_id"
" INNER JOIN event_backward_extremities as b"
" ON g.prev_event_id = b.event_id AND g.room_id = b.room_id"
" WHERE b.room_id = ? AND g.is_state is ?"
" GROUP BY b.event_id"
)
txn.execute(sql, (room_id, False,))
return dict(txn.fetchall())
def _get_oldest_events_in_room_txn(self, txn, room_id):
return self._simple_select_onecol_txn(
txn,
@@ -112,7 +96,6 @@ class EventFederationStore(SQLBaseStore):
room_id,
)
@cached()
def get_latest_event_ids_in_room(self, room_id):
return self._simple_select_onecol(
table="event_forward_extremities",
@@ -120,7 +103,7 @@ class EventFederationStore(SQLBaseStore):
"room_id": room_id,
},
retcol="event_id",
desc="get_latest_event_ids_in_room",
desc="get_latest_events_in_room",
)
def _get_latest_events_in_room(self, txn, room_id):
@@ -263,13 +246,11 @@ class EventFederationStore(SQLBaseStore):
do_insert = depth < min_depth if min_depth else True
if do_insert:
self._simple_upsert_txn(
self._simple_insert_txn(
txn,
table="room_depth",
keyvalues={
"room_id": room_id,
},
values={
"room_id": room_id,
"min_depth": depth,
},
)
@@ -280,19 +261,18 @@ class EventFederationStore(SQLBaseStore):
For the given event, update the event edges table and forward and
backward extremities tables.
"""
self._simple_insert_many_txn(
txn,
table="event_edges",
values=[
{
for e_id, _ in prev_events:
# TODO (erikj): This could be done as a bulk insert
self._simple_insert_txn(
txn,
table="event_edges",
values={
"event_id": event_id,
"prev_event_id": e_id,
"room_id": room_id,
"is_state": False,
}
for e_id, _ in prev_events
],
)
},
)
# Update the extremities table if this is not an outlier.
if not outlier:
@@ -324,32 +304,30 @@ class EventFederationStore(SQLBaseStore):
txn.execute(query, (event_id, room_id))
# Insert all the prev_events as a backwards thing, they'll get
# deleted in a second if they're incorrect anyway.
for e_id, _ in prev_events:
# TODO (erikj): This could be done as a bulk insert
self._simple_insert_txn(
txn,
table="event_backward_extremities",
values={
"event_id": e_id,
"room_id": room_id,
},
)
# Also delete from the backwards extremities table all ones that
# reference events that we have already seen
query = (
"INSERT INTO event_backward_extremities (event_id, room_id)"
" SELECT ?, ? WHERE NOT EXISTS ("
" SELECT 1 FROM event_backward_extremities"
" WHERE event_id = ? AND room_id = ?"
" )"
" AND NOT EXISTS ("
" SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
" AND outlier = ?"
" )"
)
txn.executemany(query, [
(e_id, room_id, e_id, room_id, e_id, room_id, False)
for e_id, _ in prev_events
])
query = (
"DELETE FROM event_backward_extremities"
" WHERE event_id = ? AND room_id = ?"
)
txn.execute(query, (event_id, room_id))
txn.call_after(
self.get_latest_event_ids_in_room.invalidate, room_id
"DELETE FROM event_backward_extremities WHERE EXISTS ("
"SELECT 1 FROM events "
"WHERE "
"event_backward_extremities.event_id = events.event_id "
"AND not events.outlier "
")"
)
txn.execute(query)
def get_backfill_events(self, room_id, event_list, limit):
"""Get a list of Events for a given topic that occurred before (and
@@ -364,10 +342,6 @@ class EventFederationStore(SQLBaseStore):
return self.runInteraction(
"get_backfill_events",
self._get_backfill_events, room_id, event_list, limit
).addCallback(
self._get_events
).addCallback(
lambda l: sorted(l, key=lambda e: -e.depth)
)
def _get_backfill_events(self, txn, room_id, event_list, limit):
@@ -376,75 +350,54 @@ class EventFederationStore(SQLBaseStore):
room_id, repr(event_list), limit
)
event_results = set()
event_results = event_list
# We want to make sure that we do a breadth-first, "depth" ordered
# search.
front = event_list
query = (
"SELECT depth, prev_event_id FROM event_edges"
" INNER JOIN events"
" ON prev_event_id = events.event_id"
" AND event_edges.room_id = events.room_id"
" WHERE event_edges.room_id = ? AND event_edges.event_id = ?"
" AND event_edges.is_state = ?"
" LIMIT ?"
"SELECT prev_event_id FROM event_edges "
"WHERE room_id = ? AND event_id = ? "
"LIMIT ?"
)
queue = PriorityQueue()
# We iterate through all event_ids in `front` to select their previous
# events. These are dumped in `new_front`.
# We continue until we reach the limit *or* new_front is empty (i.e.,
# we've run out of things to select
while front and len(event_results) < limit:
for event_id in event_list:
depth = self._simple_select_one_onecol_txn(
txn,
table="events",
keyvalues={
"event_id": event_id,
},
retcol="depth"
)
new_front = []
for event_id in front:
logger.debug(
"_backfill_interaction: id=%s",
event_id
)
queue.put((-depth, event_id))
txn.execute(
query,
(room_id, event_id, limit - len(event_results))
)
while not queue.empty() and len(event_results) < limit:
try:
_, event_id = queue.get_nowait()
except Empty:
break
for row in txn.fetchall():
logger.debug(
"_backfill_interaction: got id=%s",
*row
)
new_front.append(row[0])
if event_id in event_results:
continue
front = new_front
event_results += new_front
event_results.add(event_id)
return self._get_events_txn(txn, event_results)
txn.execute(
query,
(room_id, event_id, False, limit - len(event_results))
)
for row in txn.fetchall():
if row[1] not in event_results:
queue.put((-row[0], row[1]))
return event_results
@defer.inlineCallbacks
def get_missing_events(self, room_id, earliest_events, latest_events,
limit, min_depth):
ids = yield self.runInteraction(
return self.runInteraction(
"get_missing_events",
self._get_missing_events,
room_id, earliest_events, latest_events, limit, min_depth
)
events = yield self._get_events(ids)
events = sorted(
[ev for ev in events if ev.depth >= min_depth],
key=lambda e: e.depth,
)
defer.returnValue(events[:limit])
def _get_missing_events(self, txn, room_id, earliest_events, latest_events,
limit, min_depth):
@@ -476,7 +429,14 @@ class EventFederationStore(SQLBaseStore):
front = new_front
event_results |= new_front
return event_results
events = self._get_events_txn(txn, event_results)
events = sorted(
[ev for ev in events if ev.depth >= min_depth],
key=lambda e: e.depth,
)
return events[:limit]
def clean_room_for_join(self, room_id):
return self.runInteraction(
@@ -489,4 +449,3 @@ class EventFederationStore(SQLBaseStore):
query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
txn.execute(query, (room_id,))
txn.call_after(self.get_latest_event_ids_in_room.invalidate, room_id)

View File

@@ -15,36 +15,20 @@
from _base import SQLBaseStore, _RollbackButIsFineException
from twisted.internet import defer, reactor
from twisted.internet import defer
from synapse.events import FrozenEvent, USE_FROZEN_DICTS
from synapse.events.utils import prune_event
from synapse.util.logcontext import preserve_context_over_deferred
from synapse.util.logutils import log_function
from synapse.api.constants import EventTypes
from synapse.crypto.event_signing import compute_event_reference_hash
from syutil.base64util import decode_base64
from syutil.jsonutil import encode_json
from contextlib import contextmanager
from syutil.jsonutil import encode_canonical_json
import logging
import ujson as json
logger = logging.getLogger(__name__)
# These values are used in the `enqueus_event` and `_do_fetch` methods to
# control how we batch/bulk fetch events from the database.
# The values are plucked out of thing air to make initial sync run faster
# on jki.re
# TODO: Make these configurable.
EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events
EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events
EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
class EventsStore(SQLBaseStore):
@defer.inlineCallbacks
@log_function
@@ -57,32 +41,20 @@ class EventsStore(SQLBaseStore):
self.min_token -= 1
stream_ordering = self.min_token
if stream_ordering is None:
stream_ordering_manager = yield self._stream_id_gen.get_next(self)
else:
@contextmanager
def stream_ordering_manager():
yield stream_ordering
stream_ordering_manager = stream_ordering_manager()
try:
with stream_ordering_manager as stream_ordering:
yield self.runInteraction(
"persist_event",
self._persist_event_txn,
event=event,
context=context,
backfilled=backfilled,
stream_ordering=stream_ordering,
is_new_state=is_new_state,
current_state=current_state,
)
yield self.runInteraction(
"persist_event",
self._persist_event_txn,
event=event,
context=context,
backfilled=backfilled,
stream_ordering=stream_ordering,
is_new_state=is_new_state,
current_state=current_state,
)
except _RollbackButIsFineException:
pass
max_persisted_id = yield self._stream_id_gen.get_max_token(self)
defer.returnValue((stream_ordering, max_persisted_id))
@defer.inlineCallbacks
def get_event(self, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False,
@@ -102,17 +74,18 @@ class EventsStore(SQLBaseStore):
Returns:
Deferred : A FrozenEvent.
"""
events = yield self._get_events(
[event_id],
event = yield self.runInteraction(
"get_event", self._get_event_txn,
event_id,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
if not events and not allow_none:
if not event and not allow_none:
raise RuntimeError("Could not find event %s" % (event_id,))
defer.returnValue(events[0] if events else None)
defer.returnValue(event)
@log_function
def _persist_event_txn(self, txn, event, context, backfilled,
@@ -120,17 +93,20 @@ class EventsStore(SQLBaseStore):
current_state=None):
# Remove the any existing cache entries for the event_id
txn.call_after(self._invalidate_get_event_cache, event.event_id)
self._invalidate_get_event_cache(event.event_id)
if stream_ordering is None:
with self._stream_id_gen.get_next_txn(txn) as stream_ordering:
return self._persist_event_txn(
txn, event, context, backfilled,
stream_ordering=stream_ordering,
is_new_state=is_new_state,
current_state=current_state,
)
# We purposefully do this first since if we include a `current_state`
# key, we *want* to update the `current_state_events` table
if current_state:
txn.call_after(self.get_current_state_for_key.invalidate_all)
txn.call_after(self.get_rooms_for_user.invalidate_all)
txn.call_after(self.get_users_in_room.invalidate, event.room_id)
txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
txn.call_after(self.get_room_name_and_aliases, event.room_id)
self._simple_delete_txn(
txn,
table="current_state_events",
@@ -146,29 +122,52 @@ class EventsStore(SQLBaseStore):
"room_id": s.room_id,
"type": s.type,
"state_key": s.state_key,
}
},
)
if event.is_state() and is_new_state:
if not backfilled and not context.rejected:
self._simple_insert_txn(
txn,
table="state_forward_extremities",
values={
"event_id": event.event_id,
"room_id": event.room_id,
"type": event.type,
"state_key": event.state_key,
},
)
for prev_state_id, _ in event.prev_state:
self._simple_delete_txn(
txn,
table="state_forward_extremities",
keyvalues={
"event_id": prev_state_id,
}
)
outlier = event.internal_metadata.is_outlier()
if not outlier:
self._store_state_groups_txn(txn, event, context)
self._update_min_depth_for_room_txn(
txn,
event.room_id,
event.depth
)
have_persisted = self._simple_select_one_txn(
have_persisted = self._simple_select_one_onecol_txn(
txn,
table="events",
table="event_json",
keyvalues={"event_id": event.event_id},
retcols=["event_id", "outlier"],
retcol="event_id",
allow_none=True,
)
metadata_json = encode_json(
event.internal_metadata.get_dict(),
using_frozen_dicts=USE_FROZEN_DICTS
metadata_json = encode_canonical_json(
event.internal_metadata.get_dict()
).decode("UTF-8")
# If we have already persisted this event, we don't need to do any
@@ -178,9 +177,7 @@ class EventsStore(SQLBaseStore):
# if we are persisting an event that we had persisted as an outlier,
# but is no longer one.
if have_persisted:
if not outlier and have_persisted["outlier"]:
self._store_state_groups_txn(txn, event, context)
if not outlier:
sql = (
"UPDATE event_json SET internal_metadata = ?"
" WHERE event_id = ?"
@@ -200,9 +197,6 @@ class EventsStore(SQLBaseStore):
)
return
if not outlier:
self._store_state_groups_txn(txn, event, context)
self._handle_prev_events(
txn,
outlier=outlier,
@@ -236,14 +230,12 @@ class EventsStore(SQLBaseStore):
"event_id": event.event_id,
"room_id": event.room_id,
"internal_metadata": metadata_json,
"json": encode_json(
event_dict, using_frozen_dicts=USE_FROZEN_DICTS
).decode("UTF-8"),
"json": encode_canonical_json(event_dict).decode("UTF-8"),
},
)
content = encode_json(
event.content, using_frozen_dicts=USE_FROZEN_DICTS
content = encode_canonical_json(
event.content
).decode("UTF-8")
vals = {
@@ -269,8 +261,8 @@ class EventsStore(SQLBaseStore):
]
}
vals["unrecognized_keys"] = encode_json(
unrec, using_frozen_dicts=USE_FROZEN_DICTS
vals["unrecognized_keys"] = encode_canonical_json(
unrec
).decode("UTF-8")
sql = (
@@ -289,9 +281,7 @@ class EventsStore(SQLBaseStore):
)
if context.rejected:
self._store_rejections_txn(
txn, event.event_id, context.rejected
)
self._store_rejections_txn(txn, event.event_id, context.rejected)
for hash_alg, hash_base64 in event.hashes.items():
hash_bytes = decode_base64(hash_base64)
@@ -303,22 +293,19 @@ class EventsStore(SQLBaseStore):
for alg, hash_base64 in prev_hashes.items():
hash_bytes = decode_base64(hash_base64)
self._store_prev_event_hash_txn(
txn, event.event_id, prev_event_id, alg,
hash_bytes
txn, event.event_id, prev_event_id, alg, hash_bytes
)
self._simple_insert_many_txn(
txn,
table="event_auth",
values=[
{
for auth_id, _ in event.auth_events:
self._simple_insert_txn(
txn,
table="event_auth",
values={
"event_id": event.event_id,
"room_id": event.room_id,
"auth_id": auth_id,
}
for auth_id, _ in event.auth_events
],
)
},
)
(ref_alg, ref_hash_bytes) = compute_event_reference_hash(event)
self._store_event_reference_hash_txn(
@@ -343,33 +330,19 @@ class EventsStore(SQLBaseStore):
vals,
)
self._simple_insert_many_txn(
txn,
table="event_edges",
values=[
{
for e_id, h in event.prev_state:
self._simple_insert_txn(
txn,
table="event_edges",
values={
"event_id": event.event_id,
"prev_event_id": e_id,
"room_id": event.room_id,
"is_state": True,
}
for e_id, h in event.prev_state
],
)
if is_new_state and not context.rejected:
txn.call_after(
self.get_current_state_for_key.invalidate,
event.room_id, event.type, event.state_key
},
)
if (event.type == EventTypes.Name
or event.type == EventTypes.Aliases):
txn.call_after(
self.get_room_name_and_aliases.invalidate,
event.room_id
)
if is_new_state and not context.rejected:
self._simple_upsert_txn(
txn,
"current_state_events",
@@ -383,11 +356,9 @@ class EventsStore(SQLBaseStore):
}
)
return
def _store_redaction(self, txn, event):
# invalidate the cache for the redacted event
txn.call_after(self._invalidate_get_event_cache, event.redacts)
self._invalidate_get_event_cache(event.redacts)
txn.execute(
"INSERT INTO redactions (event_id, redacts) VALUES (?,?)",
(event.event_id, event.redacts)
@@ -424,409 +395,3 @@ class EventsStore(SQLBaseStore):
return self.runInteraction(
"have_events", f,
)
@defer.inlineCallbacks
def _get_events(self, event_ids, check_redacted=True,
get_prev_content=False, allow_rejected=False):
if not event_ids:
defer.returnValue([])
event_map = self._get_events_from_cache(
event_ids,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
missing_events_ids = [e for e in event_ids if e not in event_map]
if not missing_events_ids:
defer.returnValue([
event_map[e_id] for e_id in event_ids
if e_id in event_map and event_map[e_id]
])
missing_events = yield self._enqueue_events(
missing_events_ids,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
event_map.update(missing_events)
defer.returnValue([
event_map[e_id] for e_id in event_ids
if e_id in event_map and event_map[e_id]
])
def _get_events_txn(self, txn, event_ids, check_redacted=True,
get_prev_content=False, allow_rejected=False):
if not event_ids:
return []
event_map = self._get_events_from_cache(
event_ids,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
missing_events_ids = [e for e in event_ids if e not in event_map]
if not missing_events_ids:
return [
event_map[e_id] for e_id in event_ids
if e_id in event_map and event_map[e_id]
]
missing_events = self._fetch_events_txn(
txn,
missing_events_ids,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
event_map.update(missing_events)
return [
event_map[e_id] for e_id in event_ids
if e_id in event_map and event_map[e_id]
]
def _invalidate_get_event_cache(self, event_id):
for check_redacted in (False, True):
for get_prev_content in (False, True):
self._get_event_cache.invalidate(event_id, check_redacted,
get_prev_content)
def _get_event_txn(self, txn, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False):
events = self._get_events_txn(
txn, [event_id],
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
return events[0] if events else None
def _get_events_from_cache(self, events, check_redacted, get_prev_content,
allow_rejected):
event_map = {}
for event_id in events:
try:
ret = self._get_event_cache.get(
event_id, check_redacted, get_prev_content
)
if allow_rejected or not ret.rejected_reason:
event_map[event_id] = ret
else:
event_map[event_id] = None
except KeyError:
pass
return event_map
def _do_fetch(self, conn):
"""Takes a database connection and waits for requests for events from
the _event_fetch_list queue.
"""
event_list = []
i = 0
while True:
try:
with self._event_fetch_lock:
event_list = self._event_fetch_list
self._event_fetch_list = []
if not event_list:
single_threaded = self.database_engine.single_threaded
if single_threaded or i > EVENT_QUEUE_ITERATIONS:
self._event_fetch_ongoing -= 1
return
else:
self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
i += 1
continue
i = 0
event_id_lists = zip(*event_list)[0]
event_ids = [
item for sublist in event_id_lists for item in sublist
]
rows = self._new_transaction(
conn, "do_fetch", [], self._fetch_event_rows, event_ids
)
row_dict = {
r["event_id"]: r
for r in rows
}
# We only want to resolve deferreds from the main thread
def fire(lst, res):
for ids, d in lst:
if not d.called:
try:
d.callback([
res[i]
for i in ids
if i in res
])
except:
logger.exception("Failed to callback")
reactor.callFromThread(fire, event_list, row_dict)
except Exception as e:
logger.exception("do_fetch")
# We only want to resolve deferreds from the main thread
def fire(evs):
for _, d in evs:
if not d.called:
d.errback(e)
if event_list:
reactor.callFromThread(fire, event_list)
@defer.inlineCallbacks
def _enqueue_events(self, events, check_redacted=True,
get_prev_content=False, allow_rejected=False):
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
"""
if not events:
defer.returnValue({})
events_d = defer.Deferred()
with self._event_fetch_lock:
self._event_fetch_list.append(
(events, events_d)
)
self._event_fetch_lock.notify()
if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
self._event_fetch_ongoing += 1
should_start = True
else:
should_start = False
if should_start:
self.runWithConnection(
self._do_fetch
)
rows = yield preserve_context_over_deferred(events_d)
if not allow_rejected:
rows[:] = [r for r in rows if not r["rejects"]]
res = yield defer.gatherResults(
[
self._get_event_from_row(
row["internal_metadata"], row["json"], row["redacts"],
check_redacted=check_redacted,
get_prev_content=get_prev_content,
rejected_reason=row["rejects"],
)
for row in rows
],
consumeErrors=True
)
defer.returnValue({
e.event_id: e
for e in res if e
})
def _fetch_event_rows(self, txn, events):
rows = []
N = 200
for i in range(1 + len(events) / N):
evs = events[i*N:(i + 1)*N]
if not evs:
break
sql = (
"SELECT "
" e.event_id as event_id, "
" e.internal_metadata,"
" e.json,"
" r.redacts as redacts,"
" rej.event_id as rejects "
" FROM event_json as e"
" LEFT JOIN rejections as rej USING (event_id)"
" LEFT JOIN redactions as r ON e.event_id = r.redacts"
" WHERE e.event_id IN (%s)"
) % (",".join(["?"]*len(evs)),)
txn.execute(sql, evs)
rows.extend(self.cursor_to_dict(txn))
return rows
def _fetch_events_txn(self, txn, events, check_redacted=True,
get_prev_content=False, allow_rejected=False):
if not events:
return {}
rows = self._fetch_event_rows(
txn, events,
)
if not allow_rejected:
rows[:] = [r for r in rows if not r["rejects"]]
res = [
self._get_event_from_row_txn(
txn,
row["internal_metadata"], row["json"], row["redacts"],
check_redacted=check_redacted,
get_prev_content=get_prev_content,
rejected_reason=row["rejects"],
)
for row in rows
]
return {
r.event_id: r
for r in res
}
@defer.inlineCallbacks
def _get_event_from_row(self, internal_metadata, js, redacted,
check_redacted=True, get_prev_content=False,
rejected_reason=None):
d = json.loads(js)
internal_metadata = json.loads(internal_metadata)
if rejected_reason:
rejected_reason = yield self._simple_select_one_onecol(
table="rejections",
keyvalues={"event_id": rejected_reason},
retcol="reason",
desc="_get_event_from_row",
)
ev = FrozenEvent(
d,
internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason,
)
if check_redacted and redacted:
ev = prune_event(ev)
redaction_id = yield self._simple_select_one_onecol(
table="redactions",
keyvalues={"redacts": ev.event_id},
retcol="event_id",
desc="_get_event_from_row",
)
ev.unsigned["redacted_by"] = redaction_id
# Get the redaction event.
because = yield self.get_event(
redaction_id,
check_redacted=False,
allow_none=True,
)
if because:
ev.unsigned["redacted_because"] = because
if get_prev_content and "replaces_state" in ev.unsigned:
prev = yield self.get_event(
ev.unsigned["replaces_state"],
get_prev_content=False,
allow_none=True,
)
if prev:
ev.unsigned["prev_content"] = prev.get_dict()["content"]
self._get_event_cache.prefill(
ev.event_id, check_redacted, get_prev_content, ev
)
defer.returnValue(ev)
def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted,
check_redacted=True, get_prev_content=False,
rejected_reason=None):
d = json.loads(js)
internal_metadata = json.loads(internal_metadata)
if rejected_reason:
rejected_reason = self._simple_select_one_onecol_txn(
txn,
table="rejections",
keyvalues={"event_id": rejected_reason},
retcol="reason",
)
ev = FrozenEvent(
d,
internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason,
)
if check_redacted and redacted:
ev = prune_event(ev)
redaction_id = self._simple_select_one_onecol_txn(
txn,
table="redactions",
keyvalues={"redacts": ev.event_id},
retcol="event_id",
)
ev.unsigned["redacted_by"] = redaction_id
# Get the redaction event.
because = self._get_event_txn(
txn,
redaction_id,
check_redacted=False
)
if because:
ev.unsigned["redacted_because"] = because
if get_prev_content and "replaces_state" in ev.unsigned:
prev = self._get_event_txn(
txn,
ev.unsigned["replaces_state"],
get_prev_content=False,
)
if prev:
ev.unsigned["prev_content"] = prev.get_dict()["content"]
self._get_event_cache.prefill(
ev.event_id, check_redacted, get_prev_content, ev
)
return ev
def _parse_events(self, rows):
return self.runInteraction(
"_parse_events", self._parse_events_txn, rows
)
def _parse_events_txn(self, txn, rows):
event_ids = [r["event_id"] for r in rows]
return self._get_events_txn(txn, event_ids)
def _has_been_redacted_txn(self, txn, event):
sql = "SELECT event_id FROM redactions WHERE redacts = ?"
txn.execute(sql, (event.event_id,))
result = txn.fetchone()
return result[0] if result else None

View File

@@ -13,9 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore, cached
from twisted.internet import defer
from ._base import SQLBaseStore
class PresenceStore(SQLBaseStore):
@@ -89,48 +87,31 @@ class PresenceStore(SQLBaseStore):
desc="add_presence_list_pending",
)
@defer.inlineCallbacks
def set_presence_list_accepted(self, observer_localpart, observed_userid):
result = yield self._simple_update_one(
return self._simple_update_one(
table="presence_list",
keyvalues={"user_id": observer_localpart,
"observed_user_id": observed_userid},
updatevalues={"accepted": True},
desc="set_presence_list_accepted",
)
self.get_presence_list_accepted.invalidate(observer_localpart)
defer.returnValue(result)
def get_presence_list(self, observer_localpart, accepted=None):
if accepted:
return self.get_presence_list_accepted(observer_localpart)
else:
keyvalues = {"user_id": observer_localpart}
if accepted is not None:
keyvalues["accepted"] = accepted
keyvalues = {"user_id": observer_localpart}
if accepted is not None:
keyvalues["accepted"] = accepted
return self._simple_select_list(
table="presence_list",
keyvalues=keyvalues,
retcols=["observed_user_id", "accepted"],
desc="get_presence_list",
)
@cached()
def get_presence_list_accepted(self, observer_localpart):
return self._simple_select_list(
table="presence_list",
keyvalues={"user_id": observer_localpart, "accepted": True},
keyvalues=keyvalues,
retcols=["observed_user_id", "accepted"],
desc="get_presence_list_accepted",
desc="get_presence_list",
)
@defer.inlineCallbacks
def del_presence_list(self, observer_localpart, observed_userid):
yield self._simple_delete_one(
return self._simple_delete_one(
table="presence_list",
keyvalues={"user_id": observer_localpart,
"observed_user_id": observed_userid},
desc="del_presence_list",
)
self.get_presence_list_accepted.invalidate(observer_localpart)

View File

@@ -13,60 +13,61 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore, cached
import collections
from ._base import SQLBaseStore, Table
from twisted.internet import defer
import logging
import copy
import simplejson as json
logger = logging.getLogger(__name__)
class PushRuleStore(SQLBaseStore):
@cached()
@defer.inlineCallbacks
def get_push_rules_for_user(self, user_name):
rows = yield self._simple_select_list(
table=PushRuleTable.table_name,
keyvalues={
"user_name": user_name,
},
retcols=PushRuleTable.fields,
desc="get_push_rules_enabled_for_user",
sql = (
"SELECT "+",".join(PushRuleTable.fields)+" "
"FROM "+PushRuleTable.table_name+" "
"WHERE user_name = ? "
"ORDER BY priority_class DESC, priority DESC"
)
rows = yield self._execute("get_push_rules_for_user", None, sql, user_name)
rows.sort(
key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))
)
dicts = []
for r in rows:
d = {}
for i, f in enumerate(PushRuleTable.fields):
d[f] = r[i]
dicts.append(d)
defer.returnValue(rows)
defer.returnValue(dicts)
@cached()
@defer.inlineCallbacks
def get_push_rules_enabled_for_user(self, user_name):
results = yield self._simple_select_list(
table=PushRuleEnableTable.table_name,
keyvalues={
'user_name': user_name
},
retcols=PushRuleEnableTable.fields,
PushRuleEnableTable.table_name,
{'user_name': user_name},
PushRuleEnableTable.fields,
desc="get_push_rules_enabled_for_user",
)
defer.returnValue({
r['rule_id']: False if r['enabled'] == 0 else True for r in results
})
defer.returnValue(
{r['rule_id']: False if r['enabled'] == 0 else True for r in results}
)
@defer.inlineCallbacks
def add_push_rule(self, before, after, **kwargs):
vals = kwargs
vals = copy.copy(kwargs)
if 'conditions' in vals:
vals['conditions'] = json.dumps(vals['conditions'])
if 'actions' in vals:
vals['actions'] = json.dumps(vals['actions'])
# we could check the rest of the keys are valid column names
# but sqlite will do that anyway so I think it's just pointless.
vals.pop("id", None)
if 'id' in vals:
del vals['id']
if before or after:
ret = yield self.runInteraction(
@@ -86,39 +87,39 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(ret)
def _add_push_rule_relative_txn(self, txn, user_name, **kwargs):
after = kwargs.pop("after", None)
relative_to_rule = kwargs.pop("before", after)
after = None
relative_to_rule = None
if 'after' in kwargs and kwargs['after']:
after = kwargs['after']
relative_to_rule = after
if 'before' in kwargs and kwargs['before']:
relative_to_rule = kwargs['before']
res = self._simple_select_one_txn(
txn,
table=PushRuleTable.table_name,
keyvalues={
"user_name": user_name,
"rule_id": relative_to_rule,
},
retcols=["priority_class", "priority"],
allow_none=True,
# get the priority of the rule we're inserting after/before
sql = (
"SELECT priority_class, priority FROM ? "
"WHERE user_name = ? and rule_id = ?" % (PushRuleTable.table_name,)
)
txn.execute(sql, (user_name, relative_to_rule))
res = txn.fetchall()
if not res:
raise RuleNotFoundException(
"before/after rule not found: %s" % (relative_to_rule,)
)
priority_class = res["priority_class"]
base_rule_priority = res["priority"]
priority_class, base_rule_priority = res[0]
if 'priority_class' in kwargs and kwargs['priority_class'] != priority_class:
raise InconsistentRuleException(
"Given priority class does not match class of relative rule"
)
new_rule = kwargs
new_rule.pop("before", None)
new_rule.pop("after", None)
new_rule = copy.copy(kwargs)
if 'before' in new_rule:
del new_rule['before']
if 'after' in new_rule:
del new_rule['after']
new_rule['priority_class'] = priority_class
new_rule['user_name'] = user_name
new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn)
# check if the priority before/after is free
new_rule_priority = base_rule_priority
@@ -152,19 +153,12 @@ class PushRuleStore(SQLBaseStore):
txn.execute(sql, (user_name, priority_class, new_rule_priority))
txn.call_after(
self.get_push_rules_for_user.invalidate, user_name
)
# now insert the new rule
sql = "INSERT INTO "+PushRuleTable.table_name+" ("
sql += ",".join(new_rule.keys())+") VALUES ("
sql += ", ".join(["?" for _ in new_rule.keys()])+")"
txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, user_name
)
self._simple_insert_txn(
txn,
table=PushRuleTable.table_name,
values=new_rule,
)
txn.execute(sql, new_rule.values())
def _add_push_rule_highest_priority_txn(self, txn, user_name,
priority_class, **kwargs):
@@ -182,24 +176,18 @@ class PushRuleStore(SQLBaseStore):
new_prio = highest_prio + 1
# and insert the new rule
new_rule = kwargs
new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn)
new_rule = copy.copy(kwargs)
if 'id' in new_rule:
del new_rule['id']
new_rule['user_name'] = user_name
new_rule['priority_class'] = priority_class
new_rule['priority'] = new_prio
txn.call_after(
self.get_push_rules_for_user.invalidate, user_name
)
txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, user_name
)
sql = "INSERT INTO "+PushRuleTable.table_name+" ("
sql += ",".join(new_rule.keys())+") VALUES ("
sql += ", ".join(["?" for _ in new_rule.keys()])+")"
self._simple_insert_txn(
txn,
table=PushRuleTable.table_name,
values=new_rule,
)
txn.execute(sql, new_rule.values())
@defer.inlineCallbacks
def delete_push_rule(self, user_name, rule_id):
@@ -218,32 +206,13 @@ class PushRuleStore(SQLBaseStore):
desc="delete_push_rule",
)
self.get_push_rules_for_user.invalidate(user_name)
self.get_push_rules_enabled_for_user.invalidate(user_name)
@defer.inlineCallbacks
def set_push_rule_enabled(self, user_name, rule_id, enabled):
ret = yield self.runInteraction(
"_set_push_rule_enabled_txn",
self._set_push_rule_enabled_txn,
user_name, rule_id, enabled
)
defer.returnValue(ret)
def _set_push_rule_enabled_txn(self, txn, user_name, rule_id, enabled):
new_id = self._push_rules_enable_id_gen.get_next_txn(txn)
self._simple_upsert_txn(
txn,
yield self._simple_upsert(
PushRuleEnableTable.table_name,
{'user_name': user_name, 'rule_id': rule_id},
{'enabled': 1 if enabled else 0},
{'id': new_id},
)
txn.call_after(
self.get_push_rules_for_user.invalidate, user_name
)
txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, user_name
{'enabled': enabled},
desc="set_push_rule_enabled",
)
@@ -255,7 +224,7 @@ class InconsistentRuleException(Exception):
pass
class PushRuleTable(object):
class PushRuleTable(Table):
table_name = "push_rules"
fields = [
@@ -268,8 +237,10 @@ class PushRuleTable(object):
"actions",
]
EntryType = collections.namedtuple("PushRuleEntry", fields)
class PushRuleEnableTable(object):
class PushRuleEnableTable(Table):
table_name = "push_rules_enable"
fields = [

View File

@@ -112,15 +112,14 @@ class RegistrationStore(SQLBaseStore):
@defer.inlineCallbacks
def user_delete_access_tokens_apart_from(self, user_id, token_id):
yield self.runInteraction(
"user_delete_access_tokens_apart_from",
self._user_delete_access_tokens_apart_from, user_id, token_id
)
rows = yield self.get_user_by_id(user_id)
if len(rows) == 0:
raise Exception("No such user!")
def _user_delete_access_tokens_apart_from(self, txn, user_id, token_id):
txn.execute(
yield self._execute(
"delete_access_tokens_apart_from", None,
"DELETE FROM access_tokens WHERE user_id = ? AND id != ?",
(user_id, token_id)
rows[0]['id'], token_id
)
@defer.inlineCallbacks
@@ -202,15 +201,15 @@ class RegistrationStore(SQLBaseStore):
defer.returnValue(ret)
@defer.inlineCallbacks
def get_user_id_by_threepid(self, medium, address):
def get_user_by_threepid(self, medium, address):
ret = yield self._simple_select_one(
"user_threepids",
{
"medium": medium,
"address": address
},
['user_id'], True, 'get_user_id_by_threepid'
['user'], True, 'get_user_by_threepid'
)
if ret:
defer.returnValue(ret['user_id'])
defer.returnValue(ret['user'])
defer.returnValue(None)

View File

@@ -17,7 +17,7 @@ from twisted.internet import defer
from synapse.api.errors import StoreError
from ._base import SQLBaseStore, cached
from ._base import SQLBaseStore
import collections
import logging
@@ -186,7 +186,6 @@ class RoomStore(SQLBaseStore):
}
)
@cached()
@defer.inlineCallbacks
def get_room_name_and_aliases(self, room_id):
def f(txn):

View File

@@ -64,9 +64,8 @@ class RoomMemberStore(SQLBaseStore):
}
)
txn.call_after(self.get_rooms_for_user.invalidate, target_user_id)
txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
txn.call_after(self.get_users_in_room.invalidate, event.room_id)
self.get_rooms_for_user.invalidate(target_user_id)
self.get_joined_hosts_for_room.invalidate(event.room_id)
def get_room_member(self, user_id, room_id):
"""Retrieve the current state of a room member.
@@ -77,18 +76,17 @@ class RoomMemberStore(SQLBaseStore):
Returns:
Deferred: Results in a MembershipEvent or None.
"""
return self.runInteraction(
"get_room_member",
self._get_members_events_txn,
room_id,
user_id=user_id,
).addCallback(
self._get_events
).addCallback(
lambda events: events[0] if events else None
)
def f(txn):
events = self._get_members_events_txn(
txn,
room_id,
user_id=user_id,
)
return events[0] if events else None
return self.runInteraction("get_room_member", f)
@cached()
def get_users_in_room(self, room_id):
def f(txn):
@@ -112,12 +110,15 @@ class RoomMemberStore(SQLBaseStore):
Returns:
list of namedtuples representing the members in this room.
"""
return self.runInteraction(
"get_room_members",
self._get_members_events_txn,
room_id,
membership=membership,
).addCallback(self._get_events)
def f(txn):
return self._get_members_events_txn(
txn,
room_id,
membership=membership,
)
return self.runInteraction("get_room_members", f)
def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
""" Get all the rooms for this user where the membership for this user
@@ -189,14 +190,14 @@ class RoomMemberStore(SQLBaseStore):
return self.runInteraction(
"get_members_query", self._get_members_events_txn,
where_clause, where_values
).addCallbacks(self._get_events)
)
def _get_members_events_txn(self, txn, room_id, membership=None, user_id=None):
rows = self._get_members_rows_txn(
txn,
room_id, membership, user_id,
)
return [r["event_id"] for r in rows]
return self._get_events_txn(txn, [r["event_id"] for r in rows])
def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None):
where_clause = "c.room_id = ?"

View File

@@ -18,7 +18,7 @@ import logging
logger = logging.getLogger(__name__)
def run_upgrade(cur, *args, **kwargs):
def run_upgrade(cur):
cur.execute("SELECT id, regex FROM application_services_regex")
for row in cur.fetchall():
try:

View File

@@ -1,18 +0,0 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
DROP INDEX IF EXISTS sent_transaction_dest;
DROP INDEX IF EXISTS sent_transaction_sent;
DROP INDEX IF EXISTS user_ips_user;

View File

@@ -1,32 +0,0 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS new_server_keys_json (
server_name TEXT NOT NULL, -- Server name.
key_id TEXT NOT NULL, -- Requested key id.
from_server TEXT NOT NULL, -- Which server the keys were fetched from.
ts_added_ms BIGINT NOT NULL, -- When the keys were fetched
ts_valid_until_ms BIGINT NOT NULL, -- When this version of the keys exipires.
key_json bytea NOT NULL, -- JSON certificate for the remote server.
CONSTRAINT server_keys_json_uniqueness UNIQUE (server_name, key_id, from_server)
);
INSERT INTO new_server_keys_json
SELECT server_name, key_id, from_server,ts_added_ms, ts_valid_until_ms, key_json FROM server_keys_json ;
DROP TABLE server_keys_json;
ALTER TABLE new_server_keys_json RENAME TO server_keys_json;

View File

@@ -1,19 +0,0 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE INDEX events_order_topo_stream_room ON events(
topological_ordering, stream_ordering, room_id
);

View File

@@ -1,76 +0,0 @@
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Main purpose of this upgrade is to change the unique key on the
pushers table again (it was missed when the v16 full schema was
made) but this also changes the pushkey and data columns to text.
When selecting a bytea column into a text column, postgres inserts
the hex encoded data, and there's no portable way of getting the
UTF-8 bytes, so we have to do it in Python.
"""
import logging
logger = logging.getLogger(__name__)
def run_upgrade(cur, database_engine, *args, **kwargs):
logger.info("Porting pushers table...")
cur.execute("""
CREATE TABLE IF NOT EXISTS pushers2 (
id BIGINT PRIMARY KEY,
user_name TEXT NOT NULL,
access_token BIGINT DEFAULT NULL,
profile_tag VARCHAR(32) NOT NULL,
kind VARCHAR(8) NOT NULL,
app_id VARCHAR(64) NOT NULL,
app_display_name VARCHAR(64) NOT NULL,
device_display_name VARCHAR(128) NOT NULL,
pushkey TEXT NOT NULL,
ts BIGINT NOT NULL,
lang VARCHAR(8),
data TEXT,
last_token TEXT,
last_success BIGINT,
failing_since BIGINT,
UNIQUE (app_id, pushkey, user_name)
)
""")
cur.execute("""SELECT
id, user_name, access_token, profile_tag, kind,
app_id, app_display_name, device_display_name,
pushkey, ts, lang, data, last_token, last_success,
failing_since
FROM pushers
""")
count = 0
for row in cur.fetchall():
row = list(row)
row[8] = bytes(row[8]).decode("utf-8")
row[11] = bytes(row[11]).decode("utf-8")
cur.execute(database_engine.convert_param_style("""
INSERT into pushers2 (
id, user_name, access_token, profile_tag, kind,
app_id, app_display_name, device_display_name,
pushkey, ts, lang, data, last_token, last_success,
failing_since
) values (%s)""" % (','.join(['?' for _ in range(len(row))]))),
row
)
count += 1
cur.execute("DROP TABLE pushers")
cur.execute("ALTER TABLE pushers2 RENAME TO pushers")
logger.info("Moved %d pushers to new table", count)

View File

@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore, cached
from ._base import SQLBaseStore
from twisted.internet import defer
@@ -43,7 +43,6 @@ class StateStore(SQLBaseStore):
* `state_groups_state`: Maps state group to state events.
"""
@defer.inlineCallbacks
def get_state_groups(self, event_ids):
""" Get the state groups for the given list of event_ids
@@ -72,33 +71,17 @@ class StateStore(SQLBaseStore):
retcol="event_id",
)
res[group] = state_ids
state = self._get_events_txn(txn, state_ids)
res[group] = state
return res
states = yield self.runInteraction(
return self.runInteraction(
"get_state_groups",
f,
)
state_list = yield defer.gatherResults(
[
self._fetch_events_for_group(group, vals)
for group, vals in states.items()
],
consumeErrors=True,
)
defer.returnValue(dict(state_list))
@cached(num_args=1)
def _fetch_events_for_group(self, state_group, events):
return self._get_events(
events, get_prev_content=False
).addCallback(
lambda evs: (state_group, evs)
)
def _store_state_groups_txn(self, txn, event, context):
if context.current_state is None:
return
@@ -121,20 +104,18 @@ class StateStore(SQLBaseStore):
},
)
self._simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
{
for state in state_events.values():
self._simple_insert_txn(
txn,
table="state_groups_state",
values={
"state_group": state_group,
"room_id": state.room_id,
"type": state.type,
"state_key": state.state_key,
"event_id": state.event_id,
}
for state in state_events.values()
],
)
},
)
self._simple_insert_txn(
txn,
@@ -147,12 +128,6 @@ class StateStore(SQLBaseStore):
@defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key=""):
if event_type and state_key is not None:
result = yield self.get_current_state_for_key(
room_id, event_type, state_key
)
defer.returnValue(result)
def f(txn):
sql = (
"SELECT event_id FROM current_state_events"
@@ -169,29 +144,11 @@ class StateStore(SQLBaseStore):
args = (room_id, )
txn.execute(sql, args)
results = txn.fetchall()
results = self.cursor_to_dict(txn)
return [r[0] for r in results]
return self._parse_events_txn(txn, results)
event_ids = yield self.runInteraction("get_current_state", f)
events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events)
@cached(num_args=3)
@defer.inlineCallbacks
def get_current_state_for_key(self, room_id, event_type, state_key):
def f(txn):
sql = (
"SELECT event_id FROM current_state_events"
" WHERE room_id = ? AND type = ? AND state_key = ?"
)
args = (room_id, event_type, state_key)
txn.execute(sql, args)
results = txn.fetchall()
return [r[0] for r in results]
event_ids = yield self.runInteraction("get_current_state_for_key", f)
events = yield self._get_events(event_ids, get_prev_content=False)
events = yield self.runInteraction("get_current_state", f)
defer.returnValue(events)

View File

@@ -37,9 +37,11 @@ from twisted.internet import defer
from ._base import SQLBaseStore
from synapse.api.constants import EventTypes
from synapse.types import RoomStreamToken
from synapse.api.errors import SynapseError
from synapse.util.logutils import log_function
from collections import namedtuple
import logging
@@ -53,26 +55,76 @@ _STREAM_TOKEN = "stream"
_TOPOLOGICAL_TOKEN = "topological"
def lower_bound(token):
if token.topological is None:
return "(%d < %s)" % (token.stream, "stream_ordering")
else:
return "(%d < %s OR (%d = %s AND %d < %s))" % (
token.topological, "topological_ordering",
token.topological, "topological_ordering",
token.stream, "stream_ordering",
)
class _StreamToken(namedtuple("_StreamToken", "topological stream")):
"""Tokens are positions between events. The token "s1" comes after event 1.
s0 s1
| |
[0] V [1] V [2]
def upper_bound(token):
if token.topological is None:
return "(%d >= %s)" % (token.stream, "stream_ordering")
else:
return "(%d > %s OR (%d = %s AND %d >= %s))" % (
token.topological, "topological_ordering",
token.topological, "topological_ordering",
token.stream, "stream_ordering",
)
Tokens can either be a point in the live event stream or a cursor going
through historic events.
When traversing the live event stream events are ordered by when they
arrived at the homeserver.
When traversing historic events the events are ordered by their depth in
the event graph "topological_ordering" and then by when they arrived at the
homeserver "stream_ordering".
Live tokens start with an "s" followed by the "stream_ordering" id of the
event it comes after. Historic tokens start with a "t" followed by the
"topological_ordering" id of the event it comes after, follewed by "-",
followed by the "stream_ordering" id of the event it comes after.
"""
__slots__ = []
@classmethod
def parse(cls, string):
try:
if string[0] == 's':
return cls(topological=None, stream=int(string[1:]))
if string[0] == 't':
parts = string[1:].split('-', 1)
return cls(topological=int(parts[0]), stream=int(parts[1]))
except:
pass
raise SynapseError(400, "Invalid token %r" % (string,))
@classmethod
def parse_stream_token(cls, string):
try:
if string[0] == 's':
return cls(topological=None, stream=int(string[1:]))
except:
pass
raise SynapseError(400, "Invalid token %r" % (string,))
def __str__(self):
if self.topological is not None:
return "t%d-%d" % (self.topological, self.stream)
else:
return "s%d" % (self.stream,)
def lower_bound(self):
if self.topological is None:
return "(%d < %s)" % (self.stream, "stream_ordering")
else:
return "(%d < %s OR (%d = %s AND %d < %s))" % (
self.topological, "topological_ordering",
self.topological, "topological_ordering",
self.stream, "stream_ordering",
)
def upper_bound(self):
if self.topological is None:
return "(%d >= %s)" % (self.stream, "stream_ordering")
else:
return "(%d > %s OR (%d = %s AND %d >= %s))" % (
self.topological, "topological_ordering",
self.topological, "topological_ordering",
self.stream, "stream_ordering",
)
class StreamStore(SQLBaseStore):
@@ -87,8 +139,8 @@ class StreamStore(SQLBaseStore):
limit = MAX_STREAM_SIZE
# From and to keys should be integers from ordering.
from_id = RoomStreamToken.parse_stream_token(from_key)
to_id = RoomStreamToken.parse_stream_token(to_key)
from_id = _StreamToken.parse_stream_token(from_key)
to_id = _StreamToken.parse_stream_token(to_key)
if from_key == to_key:
defer.returnValue(([], to_key))
@@ -182,8 +234,8 @@ class StreamStore(SQLBaseStore):
limit = MAX_STREAM_SIZE
# From and to keys should be integers from ordering.
from_id = RoomStreamToken.parse_stream_token(from_key)
to_id = RoomStreamToken.parse_stream_token(to_key)
from_id = _StreamToken.parse_stream_token(from_key)
to_id = _StreamToken.parse_stream_token(to_key)
if from_key == to_key:
return defer.succeed(([], to_key))
@@ -224,7 +276,7 @@ class StreamStore(SQLBaseStore):
return self.runInteraction("get_room_events_stream", f)
@defer.inlineCallbacks
@log_function
def paginate_room_events(self, room_id, from_key, to_key=None,
direction='b', limit=-1,
with_feedback=False):
@@ -236,17 +288,17 @@ class StreamStore(SQLBaseStore):
args = [False, room_id]
if direction == 'b':
order = "DESC"
bounds = upper_bound(RoomStreamToken.parse(from_key))
bounds = _StreamToken.parse(from_key).upper_bound()
if to_key:
bounds = "%s AND %s" % (
bounds, lower_bound(RoomStreamToken.parse(to_key))
bounds, _StreamToken.parse(to_key).lower_bound()
)
else:
order = "ASC"
bounds = lower_bound(RoomStreamToken.parse(from_key))
bounds = _StreamToken.parse(from_key).lower_bound()
if to_key:
bounds = "%s AND %s" % (
bounds, upper_bound(RoomStreamToken.parse(to_key))
bounds, _StreamToken.parse(to_key).upper_bound()
)
if int(limit) > 0:
@@ -281,30 +333,28 @@ class StreamStore(SQLBaseStore):
# when we are going backwards so we subtract one from the
# stream part.
toke -= 1
next_token = str(RoomStreamToken(topo, toke))
next_token = str(_StreamToken(topo, toke))
else:
# TODO (erikj): We should work out what to do here instead.
next_token = to_key if to_key else from_key
return rows, next_token,
events = self._get_events_txn(
txn,
[r["event_id"] for r in rows],
get_prev_content=True
)
rows, token = yield self.runInteraction("paginate_room_events", f)
self._set_before_and_after(events, rows)
events = yield self._get_events(
[r["event_id"] for r in rows],
get_prev_content=True
)
return events, next_token,
self._set_before_and_after(events, rows)
return self.runInteraction("paginate_room_events", f)
defer.returnValue((events, token))
@defer.inlineCallbacks
def get_recent_events_for_room(self, room_id, limit, end_token,
with_feedback=False, from_token=None):
# TODO (erikj): Handle compressed feedback
end_token = RoomStreamToken.parse_stream_token(end_token)
end_token = _StreamToken.parse_stream_token(end_token)
if from_token is None:
sql = (
@@ -315,7 +365,7 @@ class StreamStore(SQLBaseStore):
" LIMIT ?"
)
else:
from_token = RoomStreamToken.parse_stream_token(from_token)
from_token = _StreamToken.parse_stream_token(from_token)
sql = (
"SELECT stream_ordering, topological_ordering, event_id"
" FROM events"
@@ -345,49 +395,30 @@ class StreamStore(SQLBaseStore):
# stream part.
topo = rows[0]["topological_ordering"]
toke = rows[0]["stream_ordering"] - 1
start_token = str(RoomStreamToken(topo, toke))
start_token = str(_StreamToken(topo, toke))
token = (start_token, str(end_token))
else:
token = (str(end_token), str(end_token))
return rows, token
events = self._get_events_txn(
txn,
[r["event_id"] for r in rows],
get_prev_content=True
)
rows, token = yield self.runInteraction(
self._set_before_and_after(events, rows)
return events, token
return self.runInteraction(
"get_recent_events_for_room", get_recent_events_for_room_txn
)
logger.debug("stream before")
events = yield self._get_events(
[r["event_id"] for r in rows],
get_prev_content=True
)
logger.debug("stream after")
self._set_before_and_after(events, rows)
defer.returnValue((events, token))
@defer.inlineCallbacks
def get_room_events_max_id(self, direction='f'):
def get_room_events_max_id(self):
token = yield self._stream_id_gen.get_max_token(self)
if direction != 'b':
defer.returnValue("s%d" % (token,))
else:
topo = yield self.runInteraction(
"_get_max_topological_txn", self._get_max_topological_txn
)
defer.returnValue("t%d-%d" % (topo, token))
def _get_max_topological_txn(self, txn):
txn.execute(
"SELECT MAX(topological_ordering) FROM events"
" WHERE outlier = ?",
(False,)
)
rows = txn.fetchall()
return rows[0][0] if rows else 0
defer.returnValue("s%d" % (token,))
@defer.inlineCallbacks
def _get_min_token(self):
@@ -408,5 +439,5 @@ class StreamStore(SQLBaseStore):
stream = row["stream_ordering"]
topo = event.depth
internal = event.internal_metadata
internal.before = str(RoomStreamToken(topo, stream - 1))
internal.after = str(RoomStreamToken(topo, stream))
internal.before = str(_StreamToken(topo, stream - 1))
internal.after = str(_StreamToken(topo, stream))

View File

@@ -17,7 +17,6 @@ from ._base import SQLBaseStore, cached
from collections import namedtuple
from syutil.jsonutil import encode_canonical_json
import logging
logger = logging.getLogger(__name__)
@@ -83,7 +82,7 @@ class TransactionStore(SQLBaseStore):
"transaction_id": transaction_id,
"origin": origin,
"response_code": code,
"response_json": buffer(encode_canonical_json(response_dict)),
"response_json": response_dict,
},
or_ignore=True,
desc="set_received_txn_response",
@@ -162,8 +161,7 @@ class TransactionStore(SQLBaseStore):
return self.runInteraction(
"delivered_txn",
self._delivered_txn,
transaction_id, destination, code,
buffer(encode_canonical_json(response_dict)),
transaction_id, destination, code, response_dict
)
def _delivered_txn(self, txn, transaction_id, destination,

View File

@@ -78,18 +78,14 @@ class StreamIdGenerator(object):
self._current_max = None
self._unfinished_ids = deque()
@defer.inlineCallbacks
def get_next(self, store):
def get_next_txn(self, txn):
"""
Usage:
with yield stream_id_gen.get_next as stream_id:
with stream_id_gen.get_next_txn(txn) as stream_id:
# ... persist event ...
"""
if not self._current_max:
yield store.runInteraction(
"_compute_current_max",
self._get_or_compute_current_max,
)
self._get_or_compute_current_max(txn)
with self._lock:
self._current_max += 1
@@ -105,7 +101,7 @@ class StreamIdGenerator(object):
with self._lock:
self._unfinished_ids.remove(next_id)
defer.returnValue(manager())
return manager()
@defer.inlineCallbacks
def get_max_token(self, store):

Some files were not shown because too many files have changed in this diff Show More