Mirror of https://github.com/element-hq/synapse.git

Comparing commits: v0.9.2 ... erikj/init (1 commit)

Commit: 171829bb94
@@ -35,6 +35,3 @@ Turned to Dust <dwinslow86 at gmail.com>

Brabo <brabo at riseup.net>
 * Installation instruction fixes

Ivan Shapovalov <intelfx100 at gmail.com>
 * contrib/systemd: a sample systemd unit file and a logger configuration
CHANGES.rst (109 lines changed)

@@ -1,110 +1,9 @@

Changes in synapse v0.9.2 (2015-06-12)
======================================
Changes in synapse vX
=====================

General:

* Changed config option from ``disable_registration`` to
  ``enable_registration``. Old option will be ignored.

* Use ultrajson for json (de)serialisation when a canonical encoding is not
  required. Ultrajson is significantly faster than simplejson in certain
  circumstances.
* Use connection pools for outgoing HTTP connections.
* Process thumbnails on separate threads.

Configuration:

* Add option, ``gzip_responses``, to disable HTTP response compression.

Federation:

* Improve resilience of backfill by ensuring we fetch any missing auth events.
* Improve performance of backfill and joining remote rooms by removing
  unnecessary computations. This included handling events we'd previously
  handled as well as attempting to compute the current state for outliers.


Changes in synapse v0.9.1 (2015-05-26)
======================================

General:

* Add support for backfilling when a client paginates. This allows servers to
  request history for a room from remote servers when a client tries to
  paginate history the server does not have - SYN-36
* Fix bug where you couldn't disable non-default pushrules - SYN-378
* Fix ``register_new_user`` script - SYN-359
* Improve performance of fetching events from the database, this improves both
  initialSync and sending of events.
* Improve performance of event streams, allowing synapse to handle more
  simultaneous connected clients.

Federation:

* Fix bug with existing backfill implementation where it returned the wrong
  selection of events in some circumstances.
* Improve performance of joining remote rooms.

Configuration:

* Add support for changing the bind host of the metrics listener via the
  ``metrics_bind_host`` option.


Changes in synapse v0.9.0-r5 (2015-05-21)
=========================================

* Add more database caches to reduce amount of work done for each pusher. This
  radically reduces CPU usage when multiple pushers are set up in the same room.

Changes in synapse v0.9.0 (2015-05-07)
======================================

General:

* Add support for using a PostgreSQL database instead of SQLite. See
  `docs/postgres.rst`_ for details.
* Add password change and reset APIs. See `Registration`_ in the spec.
* Fix memory leak due to not releasing stale notifiers - SYN-339.
* Fix race in caches that occasionally caused some presence updates to be
  dropped - SYN-369.
* Check server name has not changed on restart.
* Add a sample systemd unit file and a logger configuration in
  contrib/systemd. Contributed Ivan Shapovalov.

Federation:

* Add key distribution mechanisms for fetching public keys of unavailable
  remote home servers. See `Retrieving Server Keys`_ in the spec.

Configuration:

* Add support for multiple config files.
* Add support for dictionaries in config files.
* Remove support for specifying config options on the command line, except
  for:

  * ``--daemonize`` - Daemonize the home server.
  * ``--manhole`` - Turn on the twisted telnet manhole service on the given
    port.
  * ``--database-path`` - The path to a sqlite database to use.
  * ``--verbose`` - The verbosity level.
  * ``--log-file`` - File to log to.
  * ``--log-config`` - Python logging config file.
  * ``--enable-registration`` - Enable registration for new users.

Application services:

* Reliably retry sending of events from Synapse to application services, as per
  `Application Services`_ spec.
* Application services can no longer register via the ``/register`` API,
  instead their configuration should be saved to a file and listed in the
  synapse ``app_service_config_files`` config option. The AS configuration file
  has the same format as the old ``/register`` request.
  See `docs/application_services.rst`_ for more information.

.. _`docs/postgres.rst`: docs/postgres.rst
.. _`docs/application_services.rst`: docs/application_services.rst
.. _`Registration`: https://github.com/matrix-org/matrix-doc/blob/master/specification/10_client_server_api.rst#registration
.. _`Retrieving Server Keys`: https://github.com/matrix-org/matrix-doc/blob/6f2698/specification/30_server_server_api.rst#retrieving-server-keys
.. _`Application Services`: https://github.com/matrix-org/matrix-doc/blob/0c6bd9/specification/25_application_service_api.rst#home-server---application-service-api


Changes in synapse v0.8.1 (2015-03-18)
======================================
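The v0.9.2 notes above mention switching to ultrajson where a canonical encoding is not required. As a rough illustration only (not synapse's actual code), the split looks like the sketch below; it assumes the ``ujson`` package and the ``syutil.jsonutil.encode_canonical_json`` helper that appears later in this diff, and the ``encode_json`` wrapper itself is hypothetical.

.. code-block:: python

    import ujson
    from syutil.jsonutil import encode_canonical_json

    def encode_json(obj, canonical=False):
        # Canonical output (stable key order and formatting) is needed when the
        # bytes will be signed or hashed; everywhere else the faster ujson
        # serialiser is sufficient.
        if canonical:
            return encode_canonical_json(obj)
        return ujson.dumps(obj)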
@@ -117,7 +117,7 @@ Installing prerequisites on Mac OS X::

To install the synapse homeserver run::

    $ virtualenv -p python2.7 ~/.synapse
    $ virtualenv ~/.synapse
    $ source ~/.synapse/bin/activate
    $ pip install --process-dependency-links https://github.com/matrix-org/synapse/tarball/master

@@ -318,7 +318,7 @@ ArchLinux

If running `$ synctl start` fails with 'returned non-zero exit status 1',
you will need to explicitly call Python2.7 - either running as::

    $ python2.7 -m synapse.app.homeserver --daemonize -c homeserver.yaml
    $ python2.7 -m synapse.app.homeserver --daemonize -c homeserver.yaml --pid-file homeserver.pid

...or by editing synctl with the correct python executable.

@@ -409,6 +409,7 @@ SRV record, as that is the name other machines will expect it to have::

    $ python -m synapse.app.homeserver \
        --server-name YOURDOMAIN \
        --bind-port 8448 \
        --config-path homeserver.yaml \
        --generate-config
    $ python -m synapse.app.homeserver --config-path homeserver.yaml
@@ -1,4 +1,4 @@

Upgrading to v0.9.0
Upgrading to v0.x.x
===================

Application services have had a breaking API change in this version.
@@ -21,5 +21,3 @@ handlers:

root:
    level: INFO
    handlers: [journal]

disable_existing_loggers: False
@@ -16,31 +16,30 @@ if [ $# -eq 1 ]; then
    fi
fi

export PYTHONPATH=$(readlink -f $(pwd))

echo $PYTHONPATH

for port in 8080 8081 8082; do
    echo "Starting server on port $port... "

    https_port=$((port + 400))
    mkdir -p demo/$port
    pushd demo/$port

    #rm $DIR/etc/$port.config
    python -m synapse.app.homeserver \
        --generate-config \
        --enable_registration \
        --config-path "demo/etc/$port.config" \
        -p "$https_port" \
        --unsecure-port "$port" \
        -H "localhost:$https_port" \
        --config-path "$DIR/etc/$port.config" \
        -f "$DIR/$port.log" \
        -d "$DIR/$port.db" \
        -D --pid-file "$DIR/$port.pid" \
        --manhole $((port + 1000)) \
        --tls-dh-params-path "demo/demo.tls.dh" \
        --media-store-path "demo/media_store.$port" \
        $PARAMS $SYNAPSE_PARAMS \
        --enable-registration

    python -m synapse.app.homeserver \
        --config-path "$DIR/etc/$port.config" \
        -D \
        --config-path "demo/etc/$port.config" \
        -vv \

    popd
done

cd "$CWD"
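For reference, the loop above gives each demo homeserver three related ports: the client port itself, an HTTPS port 400 above it, and a manhole 1000 above it. A small derived summary, illustrative only:

.. code-block:: python

    # Derived from the start.sh loop above: 8080/8081/8082 are the unsecure
    # client ports, +400 the HTTPS ports, +1000 the manhole ports.
    DEMO_PORTS = [
        {"unsecure": p, "https": p + 400, "manhole": p + 1000}
        for p in (8080, 8081, 8082)
    ]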
@@ -1,36 +0,0 @@

Registering an Application Service
==================================

The registration of new application services depends on the homeserver used.
In synapse, you need to create a new configuration file for your AS and add it
to the list specified under the ``app_service_config_files`` config
option in your synapse config.

For example:

.. code-block:: yaml

  app_service_config_files:
  - /home/matrix/.synapse/<your-AS>.yaml


The format of the AS configuration file is as follows:

.. code-block:: yaml

  url: <base url of AS>
  as_token: <token AS will add to requests to HS>
  hs_token: <token HS will add to requests to AS>
  sender_localpart: <localpart of AS user>
  namespaces:
    users:  # List of users we're interested in
      - exclusive: <bool>
        regex: <regex>
      - ...
    aliases: []  # List of aliases we're interested in
    rooms: []  # List of room ids we're interested in

See the spec_ for further details on how application services work.

.. _spec: https://github.com/matrix-org/matrix-doc/blob/master/specification/25_application_service_api.rst#application-service-api
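A minimal sketch (not part of synapse) of how the files listed under ``app_service_config_files`` could be loaded and sanity-checked against the format above; the helper name and the required-key list are assumptions for illustration.

.. code-block:: python

    import yaml

    REQUIRED_KEYS = ("url", "as_token", "hs_token", "sender_localpart", "namespaces")

    def load_appservice_configs(paths):
        configs = []
        for path in paths:
            with open(path) as f:
                as_config = yaml.safe_load(f)
            missing = [key for key in REQUIRED_KEYS if key not in as_config]
            if missing:
                raise ValueError("%s is missing keys: %s" % (path, missing))
            configs.append(as_config)
        return configs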
@@ -34,15 +34,19 @@ Synapse config

When you are ready to start using PostgreSQL, add the following line to your
config file::

    database:
        name: psycopg2
        args:
            user: <user>
            password: <pass>
            database: <db>
            host: <host>
            cp_min: 5
            cp_max: 10
    database_config: <db_config_file>

Where ``<db_config_file>`` is the file name that points to a yaml file of the
following form::

    name: psycopg2
    args:
        user: <user>
        password: <pass>
        database: <db>
        host: <host>
        cp_min: 5
        cp_max: 10

All key, values in ``args`` are passed to the ``psycopg2.connect(..)``
function, except keys beginning with ``cp_``, which are consumed by the twisted

@@ -82,13 +86,13 @@ complete, restart synapse. For instance::

    cp homeserver.db homeserver.db.snapshot
    ./synctl start

Assuming your new config file (as described in the section *Synapse config*)
is named ``homeserver-postgres.yaml`` and the SQLite snapshot is at
Assuming your database config file (as described in the section *Synapse
config*) is named ``database_config.yaml`` and the SQLite snapshot is at
``homeserver.db.snapshot`` then simply run::

    python scripts/port_from_sqlite_to_postgres.py \
        --sqlite-database homeserver.db.snapshot \
        --postgres-config homeserver-postgres.yaml
        --postgres-config database_config.yaml

The flag ``--curses`` displays a coloured curses progress UI.
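A minimal sketch of the sentence above, assuming twisted's ``adbapi`` pool: the pool keeps the ``cp_``-prefixed keys for itself and forwards the remaining ``args`` to ``psycopg2.connect(..)`` whenever it opens a connection. This is illustrative wiring, not synapse's storage layer.

.. code-block:: python

    import yaml
    from twisted.enterprise import adbapi

    # database_config.yaml is the <db_config_file> described above.
    with open("database_config.yaml") as f:
        db_config = yaml.safe_load(f)

    # cp_min/cp_max configure the pool; everything else reaches psycopg2.connect.
    pool = adbapi.ConnectionPool(db_config["name"], **db_config["args"])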
@@ -33,10 +33,9 @@ def request_registration(user, password, server_location, shared_secret):

    ).hexdigest()

    data = {
        "user": user,
        "username": user,
        "password": password,
        "mac": mac,
        "type": "org.matrix.login.shared_secret",
    }

    server_location = server_location.rstrip("/")

@@ -44,7 +43,7 @@ def request_registration(user, password, server_location, shared_secret):

    print "Sending registration request..."

    req = urllib2.Request(
        "%s/_matrix/client/api/v1/register" % (server_location,),
        "%s/_matrix/client/v2_alpha/register" % (server_location,),
        data=json.dumps(data),
        headers={'Content-Type': 'application/json'}
    )
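A hypothetical invocation of the ``request_registration`` function shown above; the username, password, URL and secret are placeholders, not values from this repository.

.. code-block:: python

    request_registration(
        "alice",                                        # user
        "correct horse battery staple",                 # password
        "https://localhost:8448",                       # server_location
        "<shared secret from your homeserver config>",  # shared_secret
    )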
@@ -1,116 +0,0 @@

import psycopg2
import yaml
import sys
import json
import time
import hashlib
from syutil.base64util import encode_base64
from syutil.crypto.signing_key import read_signing_keys
from syutil.crypto.jsonsign import sign_json
from syutil.jsonutil import encode_canonical_json


def select_v1_keys(connection):
    cursor = connection.cursor()
    cursor.execute("SELECT server_name, key_id, verify_key FROM server_signature_keys")
    rows = cursor.fetchall()
    cursor.close()
    results = {}
    for server_name, key_id, verify_key in rows:
        results.setdefault(server_name, {})[key_id] = encode_base64(verify_key)
    return results


def select_v1_certs(connection):
    cursor = connection.cursor()
    cursor.execute("SELECT server_name, tls_certificate FROM server_tls_certificates")
    rows = cursor.fetchall()
    cursor.close()
    results = {}
    for server_name, tls_certificate in rows:
        results[server_name] = tls_certificate
    return results


def select_v2_json(connection):
    cursor = connection.cursor()
    cursor.execute("SELECT server_name, key_id, key_json FROM server_keys_json")
    rows = cursor.fetchall()
    cursor.close()
    results = {}
    for server_name, key_id, key_json in rows:
        results.setdefault(server_name, {})[key_id] = json.loads(str(key_json).decode("utf-8"))
    return results


def convert_v1_to_v2(server_name, valid_until, keys, certificate):
    return {
        "old_verify_keys": {},
        "server_name": server_name,
        "verify_keys": {
            key_id: {"key": key}
            for key_id, key in keys.items()
        },
        "valid_until_ts": valid_until,
        "tls_fingerprints": [fingerprint(certificate)],
    }


def fingerprint(certificate):
    finger = hashlib.sha256(certificate)
    return {"sha256": encode_base64(finger.digest())}


def rows_v2(server, json):
    valid_until = json["valid_until_ts"]
    key_json = encode_canonical_json(json)
    for key_id in json["verify_keys"]:
        yield (server, key_id, "-", valid_until, valid_until, buffer(key_json))


def main():
    config = yaml.load(open(sys.argv[1]))
    valid_until = int(time.time() / (3600 * 24)) * 1000 * 3600 * 24

    server_name = config["server_name"]
    signing_key = read_signing_keys(open(config["signing_key_path"]))[0]

    database = config["database"]
    assert database["name"] == "psycopg2", "Can only convert for postgresql"
    args = database["args"]
    args.pop("cp_max")
    args.pop("cp_min")
    connection = psycopg2.connect(**args)
    keys = select_v1_keys(connection)
    certificates = select_v1_certs(connection)
    json = select_v2_json(connection)

    result = {}
    for server in keys:
        if not server in json:
            v2_json = convert_v1_to_v2(
                server, valid_until, keys[server], certificates[server]
            )
            v2_json = sign_json(v2_json, server_name, signing_key)
            result[server] = v2_json

    yaml.safe_dump(result, sys.stdout, default_flow_style=False)

    rows = list(
        row for server, json in result.items()
        for row in rows_v2(server, json)
    )

    cursor = connection.cursor()
    cursor.executemany(
        "INSERT INTO server_keys_json ("
        " server_name, key_id, from_server,"
        " ts_added_ms, ts_valid_until_ms, key_json"
        ") VALUES (%s, %s, %s, %s, %s, %s)",
        rows
    )
    connection.commit()


if __name__ == '__main__':
    main()
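For illustration only, this is the shape of the dictionary ``convert_v1_to_v2()`` builds for a single server; the key, timestamp and fingerprint values below are placeholders, not real data.

.. code-block:: python

    example_v2_key = {
        "server_name": "example.com",
        "old_verify_keys": {},
        "verify_keys": {
            "ed25519:auto": {"key": "<base64-encoded verify key>"},
        },
        "valid_until_ts": 1434067200000,  # a day boundary, in milliseconds
        "tls_fingerprints": [{"sha256": "<base64-encoded certificate digest>"}],
    }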
scripts/port_from_sqlite_to_postgres.py (10 lines changed; Executable file → Normal file)

@@ -1,4 +1,3 @@

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#

@@ -106,7 +105,7 @@ class Store(object):

        try:
            txn = conn.cursor()
            return func(
                LoggingTransaction(txn, desc, self.database_engine, []),
                LoggingTransaction(txn, desc, self.database_engine),
                *args, **kwargs
            )
        except self.database_engine.module.DatabaseError as e:

@@ -378,7 +377,9 @@ class Porter(object):

            for i, row in enumerate(rows):
                rows[i] = tuple(
                    conv(j, col)
                    self.postgres_store.database_engine.encode_parameter(
                        conv(j, col)
                    )
                    for j, col in enumerate(row)
                    if j > 0
                )

@@ -723,9 +724,6 @@ if __name__ == "__main__":

    postgres_config = yaml.safe_load(args.postgres_config)

    if "database" in postgres_config:
        postgres_config = postgres_config["database"]

    if "name" not in postgres_config:
        sys.stderr.write("Malformed database config: no 'name'")
        sys.exit(2)
scripts/upgrade_db_to_v0.6.0.py (2 lines changed; Executable file → Normal file)

@@ -1,4 +1,4 @@

#!/usr/bin/env python

from synapse.storage import SCHEMA_VERSION, read_schema
from synapse.storage._base import SQLBaseStore
from synapse.storage.signatures import SignatureStore
setup.py (3 lines changed)

@@ -14,7 +14,6 @@

# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import os
from setuptools import setup, find_packages

@@ -56,5 +55,5 @@ setup(

    include_package_data=True,
    zip_safe=False,
    long_description=long_description,
    scripts=["synctl"] + glob.glob("scripts/*"),
    scripts=["synctl", "register_new_matrix_user"],
)
@@ -16,4 +16,4 @@

""" This is a reference implementation of a Matrix home server.
"""

__version__ = "0.9.2"
__version__ = "0.8.1-r4"
@@ -20,6 +20,7 @@ from twisted.internet import defer

from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor
from synapse.types import UserID, ClientInfo

import logging

@@ -64,10 +65,7 @@ class Auth(object):

        if event.type == EventTypes.Aliases:
            return True

        logger.debug(
            "Auth events: %s",
            [a.event_id for a in auth_events.values()]
        )
        logger.debug("Auth events: %s", auth_events)

        if event.type == EventTypes.Member:
            allowed = self.is_membership_change_allowed(

@@ -362,7 +360,7 @@ class Auth(object):

            default=[""]
        )[0]
        if user and access_token and ip_addr:
            self.store.insert_client_ip(
            yield self.store.insert_client_ip(
                user=user,
                access_token=access_token,
                device_id=user_info["device_id"],

@@ -426,6 +424,8 @@ class Auth(object):

    @defer.inlineCallbacks
    def add_auth_events(self, builder, context):
        yield run_on_reactor()

        auth_ids = self.compute_auth_events(builder, context.current_state)

        auth_events_entries = yield self.store.add_event_hashes(
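The second hunk above changes ``self.store.insert_client_ip(...)`` to ``yield self.store.insert_client_ip(...)``. As a generic illustration of why that matters (this is not synapse code, and ``store`` here is a stand-in object): inside a ``@defer.inlineCallbacks`` generator, a Deferred-returning call must be yielded, otherwise its completion is not waited for and its errors never propagate to the caller.

.. code-block:: python

    from twisted.internet import defer

    @defer.inlineCallbacks
    def record_access(store, user, access_token):
        # Without the yield the write becomes fire-and-forget and any failure
        # in insert_client_ip is silently dropped.
        yield store.insert_client_ip(user=user, access_token=access_token)
        defer.returnValue(True)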
@@ -32,9 +32,9 @@ from synapse.server import HomeServer
|
||||
from twisted.internet import reactor
|
||||
from twisted.application import service
|
||||
from twisted.enterprise import adbapi
|
||||
from twisted.web.resource import Resource, EncodingResourceWrapper
|
||||
from twisted.web.resource import Resource
|
||||
from twisted.web.static import File
|
||||
from twisted.web.server import Site, GzipEncoderFactory
|
||||
from twisted.web.server import Site
|
||||
from twisted.web.http import proxiedLogFormatter, combinedLogFormatter
|
||||
from synapse.http.server import JsonResource, RootRedirect
|
||||
from synapse.rest.media.v0.content_repository import ContentRepoResource
|
||||
@@ -54,8 +54,6 @@ from synapse.rest.client.v1 import ClientV1RestResource
|
||||
from synapse.rest.client.v2_alpha import ClientV2AlphaRestResource
|
||||
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
||||
|
||||
from synapse import events
|
||||
|
||||
from daemonize import Daemonize
|
||||
import twisted.manhole.telnet
|
||||
|
||||
@@ -71,32 +69,16 @@ import subprocess
|
||||
logger = logging.getLogger("synapse.app.homeserver")
|
||||
|
||||
|
||||
class GzipFile(File):
|
||||
def getChild(self, path, request):
|
||||
child = File.getChild(self, path, request)
|
||||
return EncodingResourceWrapper(child, [GzipEncoderFactory()])
|
||||
|
||||
|
||||
def gz_wrap(r):
|
||||
return EncodingResourceWrapper(r, [GzipEncoderFactory()])
|
||||
|
||||
|
||||
class SynapseHomeServer(HomeServer):
|
||||
|
||||
def build_http_client(self):
|
||||
return MatrixFederationHttpClient(self)
|
||||
|
||||
def build_resource_for_client(self):
|
||||
res = ClientV1RestResource(self)
|
||||
if self.config.gzip_responses:
|
||||
res = gz_wrap(res)
|
||||
return res
|
||||
return ClientV1RestResource(self)
|
||||
|
||||
def build_resource_for_client_v2_alpha(self):
|
||||
res = ClientV2AlphaRestResource(self)
|
||||
if self.config.gzip_responses:
|
||||
res = gz_wrap(res)
|
||||
return res
|
||||
return ClientV2AlphaRestResource(self)
|
||||
|
||||
def build_resource_for_federation(self):
|
||||
return JsonResource(self)
|
||||
@@ -105,16 +87,9 @@ class SynapseHomeServer(HomeServer):
|
||||
import syweb
|
||||
syweb_path = os.path.dirname(syweb.__file__)
|
||||
webclient_path = os.path.join(syweb_path, "webclient")
|
||||
# GZip is disabled here due to
|
||||
# https://twistedmatrix.com/trac/ticket/7678
|
||||
# (It can stay enabled for the API resources: they call
|
||||
# write() with the whole body and then finish() straight
|
||||
# after and so do not trigger the bug.
|
||||
# return GzipFile(webclient_path) # TODO configurable?
|
||||
return File(webclient_path) # TODO configurable?
|
||||
|
||||
def build_resource_for_static_content(self):
|
||||
# This is old and should go away: not going to bother adding gzip
|
||||
return File("static")
|
||||
|
||||
def build_resource_for_content_repo(self):
|
||||
@@ -285,12 +260,9 @@ class SynapseHomeServer(HomeServer):
|
||||
config,
|
||||
metrics_resource,
|
||||
),
|
||||
interface=config.metrics_bind_host,
|
||||
)
|
||||
logger.info(
|
||||
"Metrics now running on %s port %d",
|
||||
config.metrics_bind_host, config.metrics_port,
|
||||
interface="127.0.0.1",
|
||||
)
|
||||
logger.info("Metrics now running on 127.0.0.1 port %d", config.metrics_port)
|
||||
|
||||
def run_startup_checks(self, db_conn, database_engine):
|
||||
all_users_native = are_all_users_on_domain(
|
||||
@@ -423,8 +395,6 @@ def setup(config_options):
|
||||
logger.info("Server hostname: %s", config.server_name)
|
||||
logger.info("Server version: %s", version_string)
|
||||
|
||||
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||
|
||||
if re.search(":[0-9]+$", config.server_name):
|
||||
domain_with_port = config.server_name
|
||||
else:
|
||||
@@ -439,6 +409,7 @@ def setup(config_options):
|
||||
config.server_name,
|
||||
domain_with_port=domain_with_port,
|
||||
upload_dir=os.path.abspath("uploads"),
|
||||
db_name=config.database_path,
|
||||
db_config=config.database_config,
|
||||
tls_context_factory=tls_context_factory,
|
||||
config=config,
|
||||
@@ -451,7 +422,9 @@ def setup(config_options):
|
||||
redirect_root_to_web_client=True,
|
||||
)
|
||||
|
||||
logger.info("Preparing database: %r...", config.database_config)
|
||||
db_name = hs.get_db_name()
|
||||
|
||||
logger.info("Preparing database: %s...", db_name)
|
||||
|
||||
try:
|
||||
db_conn = database_engine.module.connect(
|
||||
@@ -473,7 +446,7 @@ def setup(config_options):
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
logger.info("Database prepared in %r.", config.database_config)
|
||||
logger.info("Database prepared in %s.", db_name)
|
||||
|
||||
if config.manhole:
|
||||
f = twisted.manhole.telnet.ShellFactory()
|
||||
@@ -526,31 +499,11 @@ class SynapseSite(Site):
|
||||
|
||||
|
||||
def run(hs):
|
||||
PROFILE_SYNAPSE = False
|
||||
if PROFILE_SYNAPSE:
|
||||
def profile(func):
|
||||
from cProfile import Profile
|
||||
from threading import current_thread
|
||||
|
||||
def profiled(*args, **kargs):
|
||||
profile = Profile()
|
||||
profile.enable()
|
||||
func(*args, **kargs)
|
||||
profile.disable()
|
||||
ident = current_thread().ident
|
||||
profile.dump_stats("/tmp/%s.%s.%i.pstat" % (
|
||||
hs.hostname, func.__name__, ident
|
||||
))
|
||||
|
||||
return profiled
|
||||
|
||||
from twisted.python.threadpool import ThreadPool
|
||||
ThreadPool._worker = profile(ThreadPool._worker)
|
||||
reactor.run = profile(reactor.run)
|
||||
|
||||
def in_thread():
|
||||
with LoggingContext("run"):
|
||||
change_resource_limit(hs.config.soft_file_limit)
|
||||
|
||||
reactor.run()
|
||||
|
||||
if hs.config.daemonize:
|
||||
|
||||
@@ -18,33 +18,29 @@ import sys

import os
import subprocess
import signal
import yaml

SYNAPSE = ["python", "-B", "-m", "synapse.app.homeserver"]

CONFIGFILE = "homeserver.yaml"
PIDFILE = "homeserver.pid"

GREEN = "\x1b[1;32m"
NORMAL = "\x1b[m"

if not os.path.exists(CONFIGFILE):
    sys.stderr.write(
        "No config file found\n"
        "To generate a config file, run '%s -c %s --generate-config"
        " --server-name=<server name>'\n" % (
            " ".join(SYNAPSE), CONFIGFILE
        )
    )
    sys.exit(1)

CONFIG = yaml.load(open(CONFIGFILE))
PIDFILE = CONFIG["pid_file"]


def start():
    if not os.path.exists(CONFIGFILE):
        sys.stderr.write(
            "No config file found\n"
            "To generate a config file, run '%s -c %s --generate-config"
            " --server-name=<server name>'\n" % (
                " ".join(SYNAPSE), CONFIGFILE
            )
        )
        sys.exit(1)
    print "Starting ...",
    args = SYNAPSE
    args.extend(["--daemonize", "-c", CONFIGFILE])
    args.extend(["--daemonize", "-c", CONFIGFILE, "--pid-file", PIDFILE])
    subprocess.check_call(args)
    print GREEN + "started" + NORMAL
@@ -148,8 +148,8 @@ class ApplicationService(object):

                and self.is_interested_in_user(event.state_key)):
            return True
        # check joined member events
        for user_id in member_list:
            if self.is_interested_in_user(user_id):
        for member in member_list:
            if self.is_interested_in_user(member.state_key):
                return True
        return False

@@ -173,7 +173,7 @@ class ApplicationService(object):

            restrict_to(str): The namespace to restrict regex tests to.
            aliases_for_event(list): A list of all the known room aliases for
                this event.
            member_list(list): A list of all joined user_ids in this room.
            member_list(list): A list of all joined room members in this room.
        Returns:
            bool: True if this service would like to know about this event.
        """
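A hypothetical sketch (not synapse's implementation) of how a user ID can be matched against the ``users`` namespace entries (``exclusive``/``regex`` pairs) from an application service config such as the one in docs/application_services.rst above.

.. code-block:: python

    import re

    def is_interested_in_user(user_id, user_namespaces):
        return any(
            re.match(entry["regex"], user_id) is not None
            for entry in user_namespaces
        )

    # e.g. is_interested_in_user("@irc_alice:example.com",
    #                            [{"exclusive": True, "regex": "@irc_.*"}])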
@@ -14,10 +14,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import os
|
||||
import yaml
|
||||
import sys
|
||||
from textwrap import dedent
|
||||
|
||||
|
||||
class ConfigError(Exception):
|
||||
@@ -25,35 +24,18 @@ class ConfigError(Exception):
|
||||
|
||||
|
||||
class Config(object):
|
||||
def __init__(self, args):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def parse_size(value):
|
||||
if isinstance(value, int) or isinstance(value, long):
|
||||
return value
|
||||
def parse_size(string):
|
||||
sizes = {"K": 1024, "M": 1024 * 1024}
|
||||
size = 1
|
||||
suffix = value[-1]
|
||||
suffix = string[-1]
|
||||
if suffix in sizes:
|
||||
value = value[:-1]
|
||||
string = string[:-1]
|
||||
size = sizes[suffix]
|
||||
return int(value) * size
|
||||
|
||||
@staticmethod
|
||||
def parse_duration(value):
|
||||
if isinstance(value, int) or isinstance(value, long):
|
||||
return value
|
||||
second = 1000
|
||||
hour = 60 * 60 * second
|
||||
day = 24 * hour
|
||||
week = 7 * day
|
||||
year = 365 * day
|
||||
sizes = {"s": second, "h": hour, "d": day, "w": week, "y": year}
|
||||
size = 1
|
||||
suffix = value[-1]
|
||||
if suffix in sizes:
|
||||
value = value[:-1]
|
||||
size = sizes[suffix]
|
||||
return int(value) * size
|
||||
return int(string) * size
|
||||
|
||||
@staticmethod
|
||||
def abspath(file_path):
|
||||
@@ -95,6 +77,17 @@ class Config(object):
|
||||
with open(file_path) as file_stream:
|
||||
return file_stream.read()
|
||||
|
||||
@classmethod
|
||||
def read_yaml_file(cls, file_path, config_name):
|
||||
cls.check_file(file_path, config_name)
|
||||
with open(file_path) as file_stream:
|
||||
try:
|
||||
return yaml.load(file_stream)
|
||||
except:
|
||||
raise ConfigError(
|
||||
"Error parsing yaml in file %r" % (file_path,)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def default_path(name):
|
||||
return os.path.abspath(os.path.join(os.path.curdir, name))
|
||||
@@ -104,130 +97,84 @@ class Config(object):
|
||||
with open(file_path) as file_stream:
|
||||
return yaml.load(file_stream)
|
||||
|
||||
def invoke_all(self, name, *args, **kargs):
|
||||
results = []
|
||||
for cls in type(self).mro():
|
||||
if name in cls.__dict__:
|
||||
results.append(getattr(cls, name)(self, *args, **kargs))
|
||||
return results
|
||||
@classmethod
|
||||
def add_arguments(cls, parser):
|
||||
pass
|
||||
|
||||
def generate_config(self, config_dir_path, server_name):
|
||||
default_config = "# vim:ft=yaml\n"
|
||||
|
||||
default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all(
|
||||
"default_config", config_dir_path, server_name
|
||||
))
|
||||
|
||||
config = yaml.load(default_config)
|
||||
|
||||
return default_config, config
|
||||
@classmethod
|
||||
def generate_config(cls, args, config_dir_path):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def load_config(cls, description, argv, generate_section=None):
|
||||
obj = cls()
|
||||
|
||||
config_parser = argparse.ArgumentParser(add_help=False)
|
||||
config_parser.add_argument(
|
||||
"-c", "--config-path",
|
||||
action="append",
|
||||
metavar="CONFIG_FILE",
|
||||
help="Specify config file"
|
||||
)
|
||||
config_parser.add_argument(
|
||||
"--generate-config",
|
||||
action="store_true",
|
||||
help="Generate a config file for the server name"
|
||||
)
|
||||
config_parser.add_argument(
|
||||
"-H", "--server-name",
|
||||
help="The server name to generate a config file for"
|
||||
help="Generate config file"
|
||||
)
|
||||
config_args, remaining_args = config_parser.parse_known_args(argv)
|
||||
|
||||
if config_args.generate_config:
|
||||
if not config_args.config_path:
|
||||
config_parser.error(
|
||||
"Must supply a config file.\nA config file can be automatically"
|
||||
" generated using \"--generate-config -h SERVER_NAME"
|
||||
" -c CONFIG-FILE\""
|
||||
"Must specify where to generate the config file"
|
||||
)
|
||||
|
||||
config_dir_path = os.path.dirname(config_args.config_path[0])
|
||||
config_dir_path = os.path.abspath(config_dir_path)
|
||||
|
||||
server_name = config_args.server_name
|
||||
if not server_name:
|
||||
print "Must specify a server_name to a generate config for."
|
||||
sys.exit(1)
|
||||
(config_path,) = config_args.config_path
|
||||
if not os.path.exists(config_dir_path):
|
||||
os.makedirs(config_dir_path)
|
||||
if os.path.exists(config_path):
|
||||
print "Config file %r already exists" % (config_path,)
|
||||
yaml_config = cls.read_config_file(config_path)
|
||||
yaml_name = yaml_config["server_name"]
|
||||
if server_name != yaml_name:
|
||||
print (
|
||||
"Config file %r has a different server_name: "
|
||||
" %r != %r" % (config_path, server_name, yaml_name)
|
||||
)
|
||||
sys.exit(1)
|
||||
config_bytes, config = obj.generate_config(
|
||||
config_dir_path, server_name
|
||||
)
|
||||
config.update(yaml_config)
|
||||
print "Generating any missing keys for %r" % (server_name,)
|
||||
obj.invoke_all("generate_files", config)
|
||||
sys.exit(0)
|
||||
with open(config_path, "wb") as config_file:
|
||||
config_bytes, config = obj.generate_config(
|
||||
config_dir_path, server_name
|
||||
)
|
||||
obj.invoke_all("generate_files", config)
|
||||
config_file.write(config_bytes)
|
||||
print (
|
||||
"A config file has been generated in %s for server name"
|
||||
" '%s' with corresponding SSL keys and self-signed"
|
||||
" certificates. Please review this file and customise it to"
|
||||
" your needs."
|
||||
) % (config_path, server_name)
|
||||
print (
|
||||
"If this server name is incorrect, you will need to regenerate"
|
||||
" the SSL certificates"
|
||||
)
|
||||
sys.exit(0)
|
||||
config_dir_path = os.path.dirname(config_args.config_path)
|
||||
if os.path.exists(config_args.config_path):
|
||||
defaults = cls.read_config_file(config_args.config_path)
|
||||
else:
|
||||
defaults = {}
|
||||
else:
|
||||
if config_args.config_path:
|
||||
defaults = cls.read_config_file(config_args.config_path)
|
||||
else:
|
||||
defaults = {}
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
parents=[config_parser],
|
||||
description=description,
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
cls.add_arguments(parser)
|
||||
parser.set_defaults(**defaults)
|
||||
|
||||
obj.invoke_all("add_arguments", parser)
|
||||
args = parser.parse_args(remaining_args)
|
||||
|
||||
if not config_args.config_path:
|
||||
config_parser.error(
|
||||
"Must supply a config file.\nA config file can be automatically"
|
||||
" generated using \"--generate-config -h SERVER_NAME"
|
||||
" -c CONFIG-FILE\""
|
||||
if config_args.generate_config:
|
||||
config_dir_path = os.path.dirname(config_args.config_path)
|
||||
config_dir_path = os.path.abspath(config_dir_path)
|
||||
if not os.path.exists(config_dir_path):
|
||||
os.makedirs(config_dir_path)
|
||||
cls.generate_config(args, config_dir_path)
|
||||
config = {}
|
||||
for key, value in vars(args).items():
|
||||
if (key not in set(["config_path", "generate_config"])
|
||||
and value is not None):
|
||||
config[key] = value
|
||||
with open(config_args.config_path, "w") as config_file:
|
||||
# TODO(mark/paul) We might want to output emacs-style mode
|
||||
# markers as well as vim-style mode markers into the file,
|
||||
# to further hint to people this is a YAML file.
|
||||
config_file.write("# vim:ft=yaml\n")
|
||||
yaml.dump(config, config_file, default_flow_style=False)
|
||||
print (
|
||||
"A config file has been generated in %s for server name"
|
||||
" '%s' with corresponding SSL keys and self-signed"
|
||||
" certificates. Please review this file and customise it to"
|
||||
" your needs."
|
||||
) % (
|
||||
config_args.config_path, config['server_name']
|
||||
)
|
||||
print (
|
||||
"If this server name is incorrect, you will need to regenerate"
|
||||
" the SSL certificates"
|
||||
)
|
||||
sys.exit(0)
|
||||
|
||||
config_dir_path = os.path.dirname(config_args.config_path[0])
|
||||
config_dir_path = os.path.abspath(config_dir_path)
|
||||
|
||||
specified_config = {}
|
||||
for config_path in config_args.config_path:
|
||||
yaml_config = cls.read_config_file(config_path)
|
||||
specified_config.update(yaml_config)
|
||||
|
||||
server_name = specified_config["server_name"]
|
||||
_, config = obj.generate_config(config_dir_path, server_name)
|
||||
config.pop("log_config")
|
||||
config.update(specified_config)
|
||||
|
||||
obj.invoke_all("read_config", config)
|
||||
|
||||
obj.invoke_all("read_arguments", args)
|
||||
|
||||
return obj
|
||||
return cls(args)
|
||||
|
||||
@@ -17,11 +17,15 @@ from ._base import Config
|
||||
|
||||
class AppServiceConfig(Config):
|
||||
|
||||
def read_config(self, config):
|
||||
self.app_service_config_files = config.get("app_service_config_files", [])
|
||||
def __init__(self, args):
|
||||
super(AppServiceConfig, self).__init__(args)
|
||||
self.app_service_config_files = args.app_service_config_files
|
||||
|
||||
def default_config(cls, config_dir_path, server_name):
|
||||
return """\
|
||||
# A list of application service config file to use
|
||||
app_service_config_files: []
|
||||
"""
|
||||
@classmethod
|
||||
def add_arguments(cls, parser):
|
||||
super(AppServiceConfig, cls).add_arguments(parser)
|
||||
group = parser.add_argument_group("appservice")
|
||||
group.add_argument(
|
||||
"--app-service-config-files", type=str, nargs='+',
|
||||
help="A list of application service config files to use."
|
||||
)
|
||||
|
||||
@@ -17,39 +17,42 @@ from ._base import Config
|
||||
|
||||
class CaptchaConfig(Config):
|
||||
|
||||
def read_config(self, config):
|
||||
self.recaptcha_private_key = config["recaptcha_private_key"]
|
||||
self.recaptcha_public_key = config["recaptcha_public_key"]
|
||||
self.enable_registration_captcha = config["enable_registration_captcha"]
|
||||
def __init__(self, args):
|
||||
super(CaptchaConfig, self).__init__(args)
|
||||
self.recaptcha_private_key = args.recaptcha_private_key
|
||||
self.recaptcha_public_key = args.recaptcha_public_key
|
||||
self.enable_registration_captcha = args.enable_registration_captcha
|
||||
|
||||
# XXX: This is used for more than just captcha
|
||||
self.captcha_ip_origin_is_x_forwarded = (
|
||||
config["captcha_ip_origin_is_x_forwarded"]
|
||||
args.captcha_ip_origin_is_x_forwarded
|
||||
)
|
||||
self.captcha_bypass_secret = config.get("captcha_bypass_secret")
|
||||
self.recaptcha_siteverify_api = config["recaptcha_siteverify_api"]
|
||||
self.captcha_bypass_secret = args.captcha_bypass_secret
|
||||
|
||||
def default_config(self, config_dir_path, server_name):
|
||||
return """\
|
||||
## Captcha ##
|
||||
|
||||
# This Home Server's ReCAPTCHA public key.
|
||||
recaptcha_private_key: "YOUR_PUBLIC_KEY"
|
||||
|
||||
# This Home Server's ReCAPTCHA private key.
|
||||
recaptcha_public_key: "YOUR_PRIVATE_KEY"
|
||||
|
||||
# Enables ReCaptcha checks when registering, preventing signup
|
||||
# unless a captcha is answered. Requires a valid ReCaptcha
|
||||
# public/private key.
|
||||
enable_registration_captcha: False
|
||||
|
||||
# When checking captchas, use the X-Forwarded-For (XFF) header
|
||||
# as the client IP and not the actual client IP.
|
||||
captcha_ip_origin_is_x_forwarded: False
|
||||
|
||||
# A secret key used to bypass the captcha test entirely.
|
||||
#captcha_bypass_secret: "YOUR_SECRET_HERE"
|
||||
|
||||
# The API endpoint to use for verifying m.login.recaptcha responses.
|
||||
recaptcha_siteverify_api: "https://www.google.com/recaptcha/api/siteverify"
|
||||
"""
|
||||
@classmethod
|
||||
def add_arguments(cls, parser):
|
||||
super(CaptchaConfig, cls).add_arguments(parser)
|
||||
group = parser.add_argument_group("recaptcha")
|
||||
group.add_argument(
|
||||
"--recaptcha-public-key", type=str, default="YOUR_PUBLIC_KEY",
|
||||
help="This Home Server's ReCAPTCHA public key."
|
||||
)
|
||||
group.add_argument(
|
||||
"--recaptcha-private-key", type=str, default="YOUR_PRIVATE_KEY",
|
||||
help="This Home Server's ReCAPTCHA private key."
|
||||
)
|
||||
group.add_argument(
|
||||
"--enable-registration-captcha", type=bool, default=False,
|
||||
help="Enables ReCaptcha checks when registering, preventing signup"
|
||||
+ " unless a captcha is answered. Requires a valid ReCaptcha "
|
||||
+ "public/private key."
|
||||
)
|
||||
group.add_argument(
|
||||
"--captcha_ip_origin_is_x_forwarded", type=bool, default=False,
|
||||
help="When checking captchas, use the X-Forwarded-For (XFF) header"
|
||||
+ " as the client IP and not the actual client IP."
|
||||
)
|
||||
group.add_argument(
|
||||
"--captcha_bypass_secret", type=str,
|
||||
help="A secret key used to bypass the captcha test entirely."
|
||||
)
|
||||
|
||||
@@ -14,21 +14,28 @@
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import Config
|
||||
import os
|
||||
import yaml
|
||||
|
||||
|
||||
class DatabaseConfig(Config):
|
||||
def __init__(self, args):
|
||||
super(DatabaseConfig, self).__init__(args)
|
||||
if args.database_path == ":memory:":
|
||||
self.database_path = ":memory:"
|
||||
else:
|
||||
self.database_path = self.abspath(args.database_path)
|
||||
self.event_cache_size = self.parse_size(args.event_cache_size)
|
||||
|
||||
def read_config(self, config):
|
||||
self.event_cache_size = self.parse_size(
|
||||
config.get("event_cache_size", "10K")
|
||||
)
|
||||
|
||||
self.database_config = config.get("database")
|
||||
|
||||
if self.database_config is None:
|
||||
if args.database_config:
|
||||
with open(args.database_config) as f:
|
||||
self.database_config = yaml.safe_load(f)
|
||||
else:
|
||||
self.database_config = {
|
||||
"name": "sqlite3",
|
||||
"args": {},
|
||||
"args": {
|
||||
"database": self.database_path,
|
||||
},
|
||||
}
|
||||
|
||||
name = self.database_config.get("name", None)
|
||||
@@ -43,37 +50,24 @@ class DatabaseConfig(Config):
|
||||
else:
|
||||
raise RuntimeError("Unsupported database type '%s'" % (name,))
|
||||
|
||||
self.set_databasepath(config.get("database_path"))
|
||||
|
||||
def default_config(self, config, config_dir_path):
|
||||
database_path = self.abspath("homeserver.db")
|
||||
return """\
|
||||
# Database configuration
|
||||
database:
|
||||
# The database engine name
|
||||
name: "sqlite3"
|
||||
# Arguments to pass to the engine
|
||||
args:
|
||||
# Path to the database
|
||||
database: "%(database_path)s"
|
||||
|
||||
# Number of events to cache in memory.
|
||||
event_cache_size: "10K"
|
||||
""" % locals()
|
||||
|
||||
def read_arguments(self, args):
|
||||
self.set_databasepath(args.database_path)
|
||||
|
||||
def set_databasepath(self, database_path):
|
||||
if database_path != ":memory:":
|
||||
database_path = self.abspath(database_path)
|
||||
if self.database_config.get("name", None) == "sqlite3":
|
||||
if database_path is not None:
|
||||
self.database_config["args"]["database"] = database_path
|
||||
|
||||
def add_arguments(self, parser):
|
||||
@classmethod
|
||||
def add_arguments(cls, parser):
|
||||
super(DatabaseConfig, cls).add_arguments(parser)
|
||||
db_group = parser.add_argument_group("database")
|
||||
db_group.add_argument(
|
||||
"-d", "--database-path", metavar="SQLITE_DATABASE_PATH",
|
||||
help="The path to a sqlite database to use."
|
||||
"-d", "--database-path", default="homeserver.db",
|
||||
metavar="SQLITE_DATABASE_PATH", help="The database name."
|
||||
)
|
||||
db_group.add_argument(
|
||||
"--event-cache-size", default="100K",
|
||||
help="Number of events to cache in memory."
|
||||
)
|
||||
db_group.add_argument(
|
||||
"--database-config", default=None,
|
||||
help="Location of the database configuration file."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def generate_config(cls, args, config_dir_path):
|
||||
super(DatabaseConfig, cls).generate_config(args, config_dir_path)
|
||||
args.database_path = os.path.abspath(args.database_path)
|
||||
|
||||
@@ -36,6 +36,4 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
|
||||
|
||||
if __name__ == '__main__':
|
||||
import sys
|
||||
sys.stdout.write(
|
||||
HomeServerConfig().generate_config(sys.argv[1], sys.argv[2])[0]
|
||||
)
|
||||
HomeServerConfig.load_config("Generate config", sys.argv[1:], "HomeServer")
|
||||
|
||||
@@ -20,58 +20,48 @@ from syutil.crypto.signing_key import (
|
||||
is_signing_algorithm_supported, decode_verify_key_bytes
|
||||
)
|
||||
from syutil.base64util import decode_base64
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
|
||||
class KeyConfig(Config):
|
||||
|
||||
def read_config(self, config):
|
||||
self.signing_key = self.read_signing_key(config["signing_key_path"])
|
||||
def __init__(self, args):
|
||||
super(KeyConfig, self).__init__(args)
|
||||
self.signing_key = self.read_signing_key(args.signing_key_path)
|
||||
self.old_signing_keys = self.read_old_signing_keys(
|
||||
config["old_signing_keys"]
|
||||
)
|
||||
self.key_refresh_interval = self.parse_duration(
|
||||
config["key_refresh_interval"]
|
||||
args.old_signing_key_path
|
||||
)
|
||||
self.key_refresh_interval = args.key_refresh_interval
|
||||
self.perspectives = self.read_perspectives(
|
||||
config["perspectives"]
|
||||
args.perspectives_config_path
|
||||
)
|
||||
|
||||
def default_config(self, config_dir_path, server_name):
|
||||
base_key_name = os.path.join(config_dir_path, server_name)
|
||||
return """\
|
||||
## Signing Keys ##
|
||||
@classmethod
|
||||
def add_arguments(cls, parser):
|
||||
super(KeyConfig, cls).add_arguments(parser)
|
||||
key_group = parser.add_argument_group("keys")
|
||||
key_group.add_argument("--signing-key-path",
|
||||
help="The signing key to sign messages with")
|
||||
key_group.add_argument("--old-signing-key-path",
|
||||
help="The keys that the server used to sign"
|
||||
" sign messages with but won't use"
|
||||
" to sign new messages. E.g. it has"
|
||||
" lost its private key")
|
||||
key_group.add_argument("--key-refresh-interval",
|
||||
default=24 * 60 * 60 * 1000, # 1 Day
|
||||
help="How long a key response is valid for."
|
||||
" Used to set the exipiry in /key/v2/."
|
||||
" Controls how frequently servers will"
|
||||
" query what keys are still valid")
|
||||
key_group.add_argument("--perspectives-config-path",
|
||||
help="The trusted servers to download signing"
|
||||
" keys from")
|
||||
|
||||
# Path to the signing key to sign messages with
|
||||
signing_key_path: "%(base_key_name)s.signing.key"
|
||||
|
||||
# The keys that the server used to sign messages with but won't use
|
||||
# to sign new messages. E.g. it has lost its private key
|
||||
old_signing_keys: {}
|
||||
# "ed25519:auto":
|
||||
# # Base64 encoded public key
|
||||
# key: "The public part of your old signing key."
|
||||
# # Millisecond POSIX timestamp when the key expired.
|
||||
# expired_ts: 123456789123
|
||||
|
||||
# How long key response published by this server is valid for.
|
||||
# Used to set the valid_until_ts in /key/v2 APIs.
|
||||
# Determines how quickly servers will query to check which keys
|
||||
# are still valid.
|
||||
key_refresh_interval: "1d" # 1 Day.
|
||||
|
||||
# The trusted servers to download signing keys from.
|
||||
perspectives:
|
||||
servers:
|
||||
"matrix.org":
|
||||
verify_keys:
|
||||
"ed25519:auto":
|
||||
key: "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
|
||||
""" % locals()
|
||||
|
||||
def read_perspectives(self, perspectives_config):
|
||||
def read_perspectives(self, perspectives_config_path):
|
||||
config = self.read_yaml_file(
|
||||
perspectives_config_path, "perspectives_config_path"
|
||||
)
|
||||
servers = {}
|
||||
for server_name, server_config in perspectives_config["servers"].items():
|
||||
for server_name, server_config in config["servers"].items():
|
||||
for key_id, key_data in server_config["verify_keys"].items():
|
||||
if is_signing_algorithm_supported(key_id):
|
||||
key_base64 = key_data["key"]
|
||||
@@ -92,42 +82,66 @@ class KeyConfig(Config):
|
||||
" Try running again with --generate-config"
|
||||
)
|
||||
|
||||
def read_old_signing_keys(self, old_signing_keys):
|
||||
keys = {}
|
||||
for key_id, key_data in old_signing_keys.items():
|
||||
if is_signing_algorithm_supported(key_id):
|
||||
key_base64 = key_data["key"]
|
||||
key_bytes = decode_base64(key_base64)
|
||||
verify_key = decode_verify_key_bytes(key_id, key_bytes)
|
||||
verify_key.expired_ts = key_data["expired_ts"]
|
||||
keys[key_id] = verify_key
|
||||
else:
|
||||
raise ConfigError(
|
||||
"Unsupported signing algorithm for old key: %r" % (key_id,)
|
||||
)
|
||||
return keys
|
||||
def read_old_signing_keys(self, old_signing_key_path):
|
||||
old_signing_keys = self.read_file(
|
||||
old_signing_key_path, "old_signing_key"
|
||||
)
|
||||
try:
|
||||
return syutil.crypto.signing_key.read_old_signing_keys(
|
||||
old_signing_keys.splitlines(True)
|
||||
)
|
||||
except Exception:
|
||||
raise ConfigError(
|
||||
"Error reading old signing keys."
|
||||
)
|
||||
|
||||
def generate_files(self, config):
|
||||
signing_key_path = config["signing_key_path"]
|
||||
if not os.path.exists(signing_key_path):
|
||||
with open(signing_key_path, "w") as signing_key_file:
|
||||
key_id = "a_" + random_string(4)
|
||||
@classmethod
|
||||
def generate_config(cls, args, config_dir_path):
|
||||
super(KeyConfig, cls).generate_config(args, config_dir_path)
|
||||
base_key_name = os.path.join(config_dir_path, args.server_name)
|
||||
|
||||
args.pid_file = os.path.abspath(args.pid_file)
|
||||
|
||||
if not args.signing_key_path:
|
||||
args.signing_key_path = base_key_name + ".signing.key"
|
||||
|
||||
if not os.path.exists(args.signing_key_path):
|
||||
with open(args.signing_key_path, "w") as signing_key_file:
|
||||
syutil.crypto.signing_key.write_signing_keys(
|
||||
signing_key_file,
|
||||
(syutil.crypto.signing_key.generate_signing_key(key_id),),
|
||||
(syutil.crypto.signing_key.generate_signing_key("auto"),),
|
||||
)
|
||||
else:
|
||||
signing_keys = self.read_file(signing_key_path, "signing_key")
|
||||
signing_keys = cls.read_file(args.signing_key_path, "signing_key")
|
||||
if len(signing_keys.split("\n")[0].split()) == 1:
|
||||
# handle keys in the old format.
|
||||
key_id = "a_" + random_string(4)
|
||||
key = syutil.crypto.signing_key.decode_signing_key_base64(
|
||||
syutil.crypto.signing_key.NACL_ED25519,
|
||||
key_id,
|
||||
"auto",
|
||||
signing_keys.split("\n")[0]
|
||||
)
|
||||
with open(signing_key_path, "w") as signing_key_file:
|
||||
with open(args.signing_key_path, "w") as signing_key_file:
|
||||
syutil.crypto.signing_key.write_signing_keys(
|
||||
signing_key_file,
|
||||
(key,),
|
||||
)
|
||||
|
||||
if not args.old_signing_key_path:
|
||||
args.old_signing_key_path = base_key_name + ".old.signing.keys"
|
||||
|
||||
if not os.path.exists(args.old_signing_key_path):
|
||||
with open(args.old_signing_key_path, "w"):
|
||||
pass
|
||||
|
||||
if not args.perspectives_config_path:
|
||||
args.perspectives_config_path = base_key_name + ".perspectives"
|
||||
|
||||
if not os.path.exists(args.perspectives_config_path):
|
||||
with open(args.perspectives_config_path, "w") as perspectives_file:
|
||||
perspectives_file.write(
|
||||
'servers:\n'
|
||||
' matrix.org:\n'
|
||||
' verify_keys:\n'
|
||||
' "ed25519:auto":\n'
|
||||
' key: "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"\n'
|
||||
)
|
||||
|
||||
@@ -19,88 +19,25 @@ from twisted.python.log import PythonLoggingObserver
|
||||
import logging
|
||||
import logging.config
|
||||
import yaml
|
||||
from string import Template
|
||||
import os
|
||||
|
||||
|
||||
DEFAULT_LOG_CONFIG = Template("""
|
||||
version: 1
|
||||
|
||||
formatters:
|
||||
precise:
|
||||
format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s\
|
||||
- %(message)s'
|
||||
|
||||
filters:
|
||||
context:
|
||||
(): synapse.util.logcontext.LoggingContextFilter
|
||||
request: ""
|
||||
|
||||
handlers:
|
||||
file:
|
||||
class: logging.handlers.RotatingFileHandler
|
||||
formatter: precise
|
||||
filename: ${log_file}
|
||||
maxBytes: 104857600
|
||||
backupCount: 10
|
||||
filters: [context]
|
||||
level: INFO
|
||||
console:
|
||||
class: logging.StreamHandler
|
||||
formatter: precise
|
||||
|
||||
loggers:
|
||||
synapse:
|
||||
level: INFO
|
||||
|
||||
synapse.storage.SQL:
|
||||
level: INFO
|
||||
|
||||
root:
|
||||
level: INFO
|
||||
handlers: [file, console]
|
||||
""")
|
||||
|
||||
|
||||
class LoggingConfig(Config):
|
||||
def __init__(self, args):
|
||||
super(LoggingConfig, self).__init__(args)
|
||||
self.verbosity = int(args.verbose) if args.verbose else None
|
||||
self.log_config = self.abspath(args.log_config)
|
||||
self.log_file = self.abspath(args.log_file)
|
||||
|
||||
def read_config(self, config):
|
||||
self.verbosity = config.get("verbose", 0)
|
||||
self.log_config = self.abspath(config.get("log_config"))
|
||||
self.log_file = self.abspath(config.get("log_file"))
|
||||
|
||||
def default_config(self, config_dir_path, server_name):
|
||||
log_file = self.abspath("homeserver.log")
|
||||
log_config = self.abspath(
|
||||
os.path.join(config_dir_path, server_name + ".log.config")
|
||||
)
|
||||
return """
|
||||
# Logging verbosity level.
|
||||
verbose: 0
|
||||
|
||||
# File to write logging to
|
||||
log_file: "%(log_file)s"
|
||||
|
||||
# A yaml python logging config file
|
||||
log_config: "%(log_config)s"
|
||||
""" % locals()
|
||||
|
||||
def read_arguments(self, args):
|
||||
if args.verbose is not None:
|
||||
self.verbosity = args.verbose
|
||||
if args.log_config is not None:
|
||||
self.log_config = args.log_config
|
||||
if args.log_file is not None:
|
||||
self.log_file = args.log_file
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser):
|
||||
super(LoggingConfig, cls).add_arguments(parser)
|
||||
logging_group = parser.add_argument_group("logging")
|
||||
logging_group.add_argument(
|
||||
'-v', '--verbose', dest="verbose", action='count',
|
||||
help="The verbosity level."
|
||||
)
|
||||
logging_group.add_argument(
|
||||
'-f', '--log-file', dest="log_file",
|
||||
'-f', '--log-file', dest="log_file", default="homeserver.log",
|
||||
help="File to log to."
|
||||
)
|
||||
logging_group.add_argument(
|
||||
@@ -108,14 +45,6 @@ class LoggingConfig(Config):
|
||||
help="Python logging config file"
|
||||
)
|
||||
|
||||
def generate_files(self, config):
|
||||
log_config = config.get("log_config")
|
||||
if log_config and not os.path.exists(log_config):
|
||||
with open(log_config, "wb") as log_config_file:
|
||||
log_config_file.write(
|
||||
DEFAULT_LOG_CONFIG.substitute(log_file=config["log_file"])
|
||||
)
|
||||
|
||||
def setup_logging(self):
|
||||
log_format = (
|
||||
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
|
||||
|
||||
@@ -17,21 +17,20 @@ from ._base import Config
|
||||
|
||||
|
||||
class MetricsConfig(Config):
|
||||
def read_config(self, config):
|
||||
self.enable_metrics = config["enable_metrics"]
|
||||
self.metrics_port = config.get("metrics_port")
|
||||
self.metrics_bind_host = config.get("metrics_bind_host", "127.0.0.1")
|
||||
def __init__(self, args):
|
||||
super(MetricsConfig, self).__init__(args)
|
||||
self.enable_metrics = args.enable_metrics
|
||||
self.metrics_port = args.metrics_port
|
||||
|
||||
def default_config(self, config_dir_path, server_name):
|
||||
return """\
|
||||
## Metrics ###
|
||||
|
||||
# Enable collection and rendering of performance metrics
|
||||
enable_metrics: False
|
||||
|
||||
# Separate port to accept metrics requests on
|
||||
# metrics_port: 8081
|
||||
|
||||
# Which host to bind the metric listener to
|
||||
# metrics_bind_host: 127.0.0.1
|
||||
"""
|
||||
@classmethod
|
||||
def add_arguments(cls, parser):
|
||||
super(MetricsConfig, cls).add_arguments(parser)
|
||||
metrics_group = parser.add_argument_group("metrics")
|
||||
metrics_group.add_argument(
|
||||
'--enable-metrics', dest="enable_metrics", action="store_true",
|
||||
help="Enable collection and rendering of performance metrics"
|
||||
)
|
||||
metrics_group.add_argument(
|
||||
'--metrics-port', metavar="PORT", type=int,
|
||||
help="Separate port to accept metrics requests on (on localhost)"
|
||||
)
|
||||
|
||||
@@ -17,42 +17,56 @@ from ._base import Config
|
||||
|
||||
class RatelimitConfig(Config):
|
||||
|
||||
def read_config(self, config):
|
||||
self.rc_messages_per_second = config["rc_messages_per_second"]
|
||||
self.rc_message_burst_count = config["rc_message_burst_count"]
|
||||
def __init__(self, args):
|
||||
super(RatelimitConfig, self).__init__(args)
|
||||
self.rc_messages_per_second = args.rc_messages_per_second
|
||||
self.rc_message_burst_count = args.rc_message_burst_count
|
||||
|
||||
self.federation_rc_window_size = config["federation_rc_window_size"]
|
||||
self.federation_rc_sleep_limit = config["federation_rc_sleep_limit"]
|
||||
self.federation_rc_sleep_delay = config["federation_rc_sleep_delay"]
|
||||
self.federation_rc_reject_limit = config["federation_rc_reject_limit"]
|
||||
self.federation_rc_concurrent = config["federation_rc_concurrent"]
|
||||
self.federation_rc_window_size = args.federation_rc_window_size
|
||||
self.federation_rc_sleep_limit = args.federation_rc_sleep_limit
|
||||
self.federation_rc_sleep_delay = args.federation_rc_sleep_delay
|
||||
self.federation_rc_reject_limit = args.federation_rc_reject_limit
|
||||
self.federation_rc_concurrent = args.federation_rc_concurrent
|
||||
|
||||
def default_config(self, config_dir_path, server_name):
|
||||
return """\
|
||||
## Ratelimiting ##
|
||||
@classmethod
|
||||
def add_arguments(cls, parser):
|
||||
super(RatelimitConfig, cls).add_arguments(parser)
|
||||
rc_group = parser.add_argument_group("ratelimiting")
|
||||
rc_group.add_argument(
|
||||
"--rc-messages-per-second", type=float, default=0.2,
|
||||
help="number of messages a client can send per second"
|
||||
)
|
||||
rc_group.add_argument(
|
||||
"--rc-message-burst-count", type=float, default=10,
|
||||
help="number of message a client can send before being throttled"
|
||||
)
|
||||
|
||||
# Number of messages a client can send per second
|
||||
rc_messages_per_second: 0.2
|
||||
rc_group.add_argument(
|
||||
"--federation-rc-window-size", type=int, default=10000,
|
||||
help="The federation window size in milliseconds",
|
||||
)
|
||||
|
||||
# Number of message a client can send before being throttled
|
||||
rc_message_burst_count: 10.0
|
||||
rc_group.add_argument(
|
||||
"--federation-rc-sleep-limit", type=int, default=10,
|
||||
help="The number of federation requests from a single server"
|
||||
" in a window before the server will delay processing the"
|
||||
" request.",
|
||||
)
|
||||
|
||||
# The federation window size in milliseconds
|
||||
federation_rc_window_size: 1000
|
||||
rc_group.add_argument(
|
||||
"--federation-rc-sleep-delay", type=int, default=500,
|
||||
help="The duration in milliseconds to delay processing events from"
|
||||
" remote servers by if they go over the sleep limit.",
|
||||
)
|
||||
|
||||
# The number of federation requests from a single server in a window
|
||||
# before the server will delay processing the request.
|
||||
federation_rc_sleep_limit: 10
|
||||
rc_group.add_argument(
|
||||
"--federation-rc-reject-limit", type=int, default=50,
|
||||
help="The maximum number of concurrent federation requests allowed"
|
||||
" from a single server",
|
||||
)
|
||||
|
||||
# The duration in milliseconds to delay processing events from
|
||||
# remote servers by if they go over the sleep limit.
|
||||
federation_rc_sleep_delay: 500
|
||||
|
||||
# The maximum number of concurrent federation requests allowed
|
||||
# from a single server
|
||||
federation_rc_reject_limit: 50
|
||||
|
||||
# The number of federation requests to concurrently process from a
|
||||
# single server
|
||||
federation_rc_concurrent: 3
|
||||
"""
|
||||
rc_group.add_argument(
|
||||
"--federation-rc-concurrent", type=int, default=3,
|
||||
help="The number of federation requests to concurrently process"
|
||||
" from a single server",
|
||||
)
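To make the two message-ratelimiting knobs concrete: rc_messages_per_second is the sustained rate and rc_message_burst_count is how far a client may run ahead of it. A simplified token-bucket sketch of that behaviour (illustrative only, not Synapse's Ratelimiter class):

class SimpleRatelimiter(object):
    def __init__(self, rate=0.2, burst=10.0):
        self.rate = rate          # tokens added per second
        self.burst = burst        # maximum bucket size
        self.tokens = burst
        self.last = 0.0

    def allow(self, now):
        # Refill proportionally to elapsed time, then spend one token if we can.
        self.tokens = min(self.burst, self.tokens + (now - self.last) * self.rate)
        self.last = now
        if self.tokens >= 1:
            self.tokens -= 1
            return True
        return False

rl = SimpleRatelimiter()
assert all(rl.allow(0.0) for _ in range(10))   # burst of 10 allowed
assert not rl.allow(0.0)                       # the 11th is throttled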
|
||||
|
||||
@@ -17,44 +17,45 @@ from ._base import Config
|
||||
|
||||
from synapse.util.stringutils import random_string_with_symbols
|
||||
|
||||
from distutils.util import strtobool
|
||||
import distutils.util
|
||||
|
||||
|
||||
class RegistrationConfig(Config):
|
||||
|
||||
def read_config(self, config):
|
||||
def __init__(self, args):
|
||||
super(RegistrationConfig, self).__init__(args)
|
||||
|
||||
# `args.enable_registration` may either be a bool or a string depending
|
||||
# on if the option was given a value (e.g. --enable-registration=true
|
||||
# would set `args.enable_registration` to "true" not True.)
|
||||
self.disable_registration = not bool(
|
||||
strtobool(str(config["enable_registration"]))
|
||||
distutils.util.strtobool(str(args.enable_registration))
|
||||
)
|
||||
if "disable_registration" in config:
|
||||
self.disable_registration = bool(
|
||||
strtobool(str(config["disable_registration"]))
|
||||
)
|
||||
self.registration_shared_secret = args.registration_shared_secret
|
||||
|
||||
self.registration_shared_secret = config.get("registration_shared_secret")
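The str()-then-strtobool dance above exists because the flag value may arrive either as a bool (the default) or as a string such as "true" (when --enable-registration=true is passed). A minimal sketch of the normalisation, using the same distutils helper imported at the top of the file:

from distutils.util import strtobool

def registration_enabled(raw):
    # strtobool accepts "y"/"yes"/"true"/"1" and friends; wrapping in str()
    # keeps the call safe when raw is already a bool.
    return bool(strtobool(str(raw)))

assert registration_enabled(True) is True
assert registration_enabled("true") is True
assert registration_enabled("False") is False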
|
||||
|
||||
def default_config(self, config_dir, server_name):
|
||||
registration_shared_secret = random_string_with_symbols(50)
|
||||
return """\
|
||||
## Registration ##
|
||||
|
||||
# Enable registration for new users.
|
||||
enable_registration: False
|
||||
|
||||
# If set, allows registration by anyone who also has the shared
|
||||
# secret, even if registration is otherwise disabled.
|
||||
registration_shared_secret: "%(registration_shared_secret)s"
|
||||
""" % locals()
|
||||
|
||||
def add_arguments(self, parser):
|
||||
@classmethod
|
||||
def add_arguments(cls, parser):
|
||||
super(RegistrationConfig, cls).add_arguments(parser)
|
||||
reg_group = parser.add_argument_group("registration")
|
||||
|
||||
reg_group.add_argument(
|
||||
"--enable-registration", action="store_true", default=None,
|
||||
help="Enable registration for new users."
|
||||
"--enable-registration",
|
||||
const=True,
|
||||
default=False,
|
||||
nargs='?',
|
||||
help="Enable registration for new users.",
|
||||
)
|
||||
reg_group.add_argument(
|
||||
"--registration-shared-secret", type=str,
|
||||
help="If set, allows registration by anyone who also has the shared"
|
||||
" secret, even if registration is otherwise disabled.",
|
||||
)
|
||||
|
||||
def read_arguments(self, args):
|
||||
if args.enable_registration is not None:
|
||||
self.disable_registration = not bool(
|
||||
strtobool(str(args.enable_registration))
|
||||
)
|
||||
@classmethod
|
||||
def generate_config(cls, args, config_dir_path):
|
||||
super(RegistrationConfig, cls).generate_config(args, config_dir_path)
|
||||
if args.enable_registration is None:
|
||||
args.enable_registration = False
|
||||
|
||||
if args.registration_shared_secret is None:
|
||||
args.registration_shared_secret = random_string_with_symbols(50)
|
||||
|
||||
@@ -17,20 +17,32 @@ from ._base import Config
|
||||
|
||||
|
||||
class ContentRepositoryConfig(Config):
|
||||
def read_config(self, config):
|
||||
self.max_upload_size = self.parse_size(config["max_upload_size"])
|
||||
self.max_image_pixels = self.parse_size(config["max_image_pixels"])
|
||||
self.media_store_path = self.ensure_directory(config["media_store_path"])
|
||||
def __init__(self, args):
|
||||
super(ContentRepositoryConfig, self).__init__(args)
|
||||
self.max_upload_size = self.parse_size(args.max_upload_size)
|
||||
self.max_image_pixels = self.parse_size(args.max_image_pixels)
|
||||
self.media_store_path = self.ensure_directory(args.media_store_path)
|
||||
|
||||
def default_config(self, config_dir_path, server_name):
|
||||
media_store = self.default_path("media_store")
|
||||
return """
|
||||
# Directory where uploaded images and attachments are stored.
|
||||
media_store_path: "%(media_store)s"
    def parse_size(self, string):
        sizes = {"K": 1024, "M": 1024 * 1024}
        size = 1
        suffix = string[-1]
        if suffix in sizes:
            string = string[:-1]
            size = sizes[suffix]
        return int(string) * size
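As a usage reference for the parser above, a standalone re-statement of the same logic with a few expected results (the assertions are illustrative, not part of the original file):

def parse_size(value):
    sizes = {"K": 1024, "M": 1024 * 1024}
    multiplier = 1
    if value and value[-1] in sizes:
        multiplier = sizes[value[-1]]
        value = value[:-1]
    return int(value) * multiplier

assert parse_size("10M") == 10 * 1024 * 1024
assert parse_size("512K") == 512 * 1024
assert parse_size("123") == 123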
# The largest allowed upload size in bytes
|
||||
max_upload_size: "10M"
|
||||
|
||||
# Maximum number of pixels that will be thumbnailed
|
||||
max_image_pixels: "32M"
|
||||
""" % locals()
|
||||
@classmethod
|
||||
def add_arguments(cls, parser):
|
||||
super(ContentRepositoryConfig, cls).add_arguments(parser)
|
||||
db_group = parser.add_argument_group("content_repository")
|
||||
db_group.add_argument(
|
||||
"--max-upload-size", default="10M"
|
||||
)
|
||||
db_group.add_argument(
|
||||
"--media-store-path", default=cls.default_path("media_store")
|
||||
)
|
||||
db_group.add_argument(
|
||||
"--max-image-pixels", default="32M",
|
||||
help="Maximum number of pixels that will be thumbnailed"
|
||||
)
|
||||
|
||||
@@ -17,95 +17,64 @@ from ._base import Config
|
||||
|
||||
|
||||
class ServerConfig(Config):
|
||||
def __init__(self, args):
|
||||
super(ServerConfig, self).__init__(args)
|
||||
self.server_name = args.server_name
|
||||
self.bind_port = args.bind_port
|
||||
self.bind_host = args.bind_host
|
||||
self.unsecure_port = args.unsecure_port
|
||||
self.daemonize = args.daemonize
|
||||
self.pid_file = self.abspath(args.pid_file)
|
||||
self.web_client = args.web_client
|
||||
self.manhole = args.manhole
|
||||
self.soft_file_limit = args.soft_file_limit
|
||||
|
||||
def read_config(self, config):
|
||||
self.server_name = config["server_name"]
|
||||
self.bind_port = config["bind_port"]
|
||||
self.bind_host = config["bind_host"]
|
||||
self.unsecure_port = config["unsecure_port"]
|
||||
self.manhole = config.get("manhole")
|
||||
self.pid_file = self.abspath(config.get("pid_file"))
|
||||
self.web_client = config["web_client"]
|
||||
self.soft_file_limit = config["soft_file_limit"]
|
||||
self.daemonize = config.get("daemonize")
|
||||
self.use_frozen_dicts = config.get("use_frozen_dicts", True)
|
||||
self.gzip_responses = config["gzip_responses"]
|
||||
|
||||
# Attempt to guess the content_addr for the v0 content repository
|
||||
content_addr = config.get("content_addr")
|
||||
if not content_addr:
|
||||
host = self.server_name
|
||||
if not args.content_addr:
|
||||
host = args.server_name
|
||||
if ':' not in host:
|
||||
host = "%s:%d" % (host, self.unsecure_port)
|
||||
host = "%s:%d" % (host, args.unsecure_port)
|
||||
else:
|
||||
host = host.split(':')[0]
|
||||
host = "%s:%d" % (host, self.unsecure_port)
|
||||
content_addr = "http://%s" % (host,)
|
||||
host = "%s:%d" % (host, args.unsecure_port)
|
||||
args.content_addr = "http://%s" % (host,)
|
||||
|
||||
self.content_addr = content_addr
|
||||
self.content_addr = args.content_addr
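The content_addr guess above can be read as a small pure function: keep the server_name, drop any explicit port, and append the unsecure (plain HTTP) port. A sketch that mirrors the logic shown:

def guess_content_addr(server_name, unsecure_port):
    host = server_name
    if ':' not in host:
        host = "%s:%d" % (host, unsecure_port)
    else:
        host = host.split(':')[0]
        host = "%s:%d" % (host, unsecure_port)
    return "http://%s" % (host,)

assert guess_content_addr("matrix.org", 8008) == "http://matrix.org:8008"
assert guess_content_addr("localhost:8080", 8008) == "http://localhost:8008"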
|
||||
|
||||
def default_config(self, config_dir_path, server_name):
|
||||
if ":" in server_name:
|
||||
bind_port = int(server_name.split(":")[1])
|
||||
unsecure_port = bind_port - 400
|
||||
else:
|
||||
bind_port = 8448
|
||||
unsecure_port = 8008
|
||||
|
||||
pid_file = self.abspath("homeserver.pid")
|
||||
        return """\
## Server ##

# The domain name of the server, with optional explicit port.
# This is used by remote servers to connect to this server,
# e.g. matrix.org, localhost:8080, etc.
server_name: "%(server_name)s"

# The port to listen for HTTPS requests on.
# For when matrix traffic is sent directly to synapse.
bind_port: %(bind_port)s

# The port to listen for HTTP requests on.
# For when matrix traffic passes through loadbalancer that unwraps TLS.
unsecure_port: %(unsecure_port)s

# Local interface to listen on.
# The empty string will cause synapse to listen on all interfaces.
bind_host: ""

# When running as a daemon, the file to store the pid in
pid_file: %(pid_file)s

# Whether to serve a web client from the HTTP/HTTPS root resource.
web_client: True

# Set the soft limit on the number of file descriptors synapse can use
# Zero is used to indicate synapse should set the soft limit to the
# hard limit.
soft_file_limit: 0

# Turn on the twisted telnet manhole service on localhost on the given
# port.
#manhole: 9000

# Should synapse compress HTTP responses to clients that support it?
# This should be disabled if running synapse behind a load balancer
# that can do automatic compression.
gzip_responses: True
        """ % locals()
def read_arguments(self, args):
|
||||
if args.manhole is not None:
|
||||
self.manhole = args.manhole
|
||||
if args.daemonize is not None:
|
||||
self.daemonize = args.daemonize
|
||||
|
||||
def add_arguments(self, parser):
|
||||
@classmethod
|
||||
def add_arguments(cls, parser):
|
||||
super(ServerConfig, cls).add_arguments(parser)
|
||||
server_group = parser.add_argument_group("server")
|
||||
server_group.add_argument(
|
||||
"-H", "--server-name", default="localhost",
|
||||
help="The domain name of the server, with optional explicit port. "
|
||||
"This is used by remote servers to connect to this server, "
|
||||
"e.g. matrix.org, localhost:8080, etc."
|
||||
)
|
||||
server_group.add_argument("-p", "--bind-port", metavar="PORT",
|
||||
type=int, help="https port to listen on",
|
||||
default=8448)
|
||||
server_group.add_argument("--unsecure-port", metavar="PORT",
|
||||
type=int, help="http port to listen on",
|
||||
default=8008)
|
||||
server_group.add_argument("--bind-host", default="",
|
||||
help="Local interface to listen on")
|
||||
server_group.add_argument("-D", "--daemonize", action='store_true',
|
||||
default=None,
|
||||
help="Daemonize the home server")
|
||||
server_group.add_argument('--pid-file', default="homeserver.pid",
|
||||
help="When running as a daemon, the file to"
|
||||
" store the pid in")
|
||||
server_group.add_argument('--web_client', default=True, type=bool,
|
||||
help="Whether or not to serve a web client")
|
||||
server_group.add_argument("--manhole", metavar="PORT", dest="manhole",
|
||||
type=int,
|
||||
help="Turn on the twisted telnet manhole"
|
||||
" service on the given port.")
|
||||
server_group.add_argument("--content-addr", default=None,
|
||||
help="The host and scheme to use for the "
|
||||
"content repository")
|
||||
server_group.add_argument("--soft-file-limit", type=int, default=0,
|
||||
help="Set the soft limit on the number of "
|
||||
"file descriptors synapse can use. "
|
||||
"Zero is used to indicate synapse "
|
||||
"should set the soft limit to the hard"
|
||||
"limit.")
|
||||
|
||||
@@ -23,44 +23,37 @@ GENERATE_DH_PARAMS = False
|
||||
|
||||
|
||||
class TlsConfig(Config):
|
||||
def read_config(self, config):
|
||||
def __init__(self, args):
|
||||
super(TlsConfig, self).__init__(args)
|
||||
self.tls_certificate = self.read_tls_certificate(
|
||||
config.get("tls_certificate_path")
|
||||
args.tls_certificate_path
|
||||
)
|
||||
|
||||
self.no_tls = config.get("no_tls", False)
|
||||
self.no_tls = args.no_tls
|
||||
|
||||
if self.no_tls:
|
||||
self.tls_private_key = None
|
||||
else:
|
||||
self.tls_private_key = self.read_tls_private_key(
|
||||
config.get("tls_private_key_path")
|
||||
args.tls_private_key_path
|
||||
)
|
||||
|
||||
self.tls_dh_params_path = self.check_file(
|
||||
config.get("tls_dh_params_path"), "tls_dh_params"
|
||||
args.tls_dh_params_path, "tls_dh_params"
|
||||
)
|
||||
|
||||
def default_config(self, config_dir_path, server_name):
|
||||
base_key_name = os.path.join(config_dir_path, server_name)
|
||||
|
||||
tls_certificate_path = base_key_name + ".tls.crt"
|
||||
tls_private_key_path = base_key_name + ".tls.key"
|
||||
tls_dh_params_path = base_key_name + ".tls.dh"
|
||||
|
||||
return """\
|
||||
# PEM encoded X509 certificate for TLS
|
||||
tls_certificate_path: "%(tls_certificate_path)s"
|
||||
|
||||
# PEM encoded private key for TLS
|
||||
tls_private_key_path: "%(tls_private_key_path)s"
|
||||
|
||||
# PEM dh parameters for ephemeral keys
|
||||
tls_dh_params_path: "%(tls_dh_params_path)s"
|
||||
|
||||
# Don't bind to the https port
|
||||
no_tls: False
|
||||
""" % locals()
|
||||
@classmethod
|
||||
def add_arguments(cls, parser):
|
||||
super(TlsConfig, cls).add_arguments(parser)
|
||||
tls_group = parser.add_argument_group("tls")
|
||||
tls_group.add_argument("--tls-certificate-path",
|
||||
help="PEM encoded X509 certificate for TLS")
|
||||
tls_group.add_argument("--tls-private-key-path",
|
||||
help="PEM encoded private key for TLS")
|
||||
tls_group.add_argument("--tls-dh-params-path",
|
||||
help="PEM dh parameters for ephemeral keys")
|
||||
tls_group.add_argument("--no-tls", action='store_true',
|
||||
help="Don't bind to the https port.")
|
||||
|
||||
def read_tls_certificate(self, cert_path):
|
||||
cert_pem = self.read_file(cert_path, "tls_certificate")
|
||||
@@ -70,13 +63,22 @@ class TlsConfig(Config):
|
||||
private_key_pem = self.read_file(private_key_path, "tls_private_key")
|
||||
return crypto.load_privatekey(crypto.FILETYPE_PEM, private_key_pem)
|
||||
|
||||
def generate_files(self, config):
|
||||
tls_certificate_path = config["tls_certificate_path"]
|
||||
tls_private_key_path = config["tls_private_key_path"]
|
||||
tls_dh_params_path = config["tls_dh_params_path"]
|
||||
@classmethod
|
||||
def generate_config(cls, args, config_dir_path):
|
||||
super(TlsConfig, cls).generate_config(args, config_dir_path)
|
||||
base_key_name = os.path.join(config_dir_path, args.server_name)
|
||||
|
||||
if not os.path.exists(tls_private_key_path):
|
||||
with open(tls_private_key_path, "w") as private_key_file:
|
||||
if args.tls_certificate_path is None:
|
||||
args.tls_certificate_path = base_key_name + ".tls.crt"
|
||||
|
||||
if args.tls_private_key_path is None:
|
||||
args.tls_private_key_path = base_key_name + ".tls.key"
|
||||
|
||||
if args.tls_dh_params_path is None:
|
||||
args.tls_dh_params_path = base_key_name + ".tls.dh"
|
||||
|
||||
if not os.path.exists(args.tls_private_key_path):
|
||||
with open(args.tls_private_key_path, "w") as private_key_file:
|
||||
tls_private_key = crypto.PKey()
|
||||
tls_private_key.generate_key(crypto.TYPE_RSA, 2048)
|
||||
private_key_pem = crypto.dump_privatekey(
|
||||
@@ -84,17 +86,17 @@ class TlsConfig(Config):
|
||||
)
|
||||
private_key_file.write(private_key_pem)
|
||||
else:
|
||||
with open(tls_private_key_path) as private_key_file:
|
||||
with open(args.tls_private_key_path) as private_key_file:
|
||||
private_key_pem = private_key_file.read()
|
||||
tls_private_key = crypto.load_privatekey(
|
||||
crypto.FILETYPE_PEM, private_key_pem
|
||||
)
|
||||
|
||||
if not os.path.exists(tls_certificate_path):
|
||||
with open(tls_certificate_path, "w") as certificate_file:
if not os.path.exists(args.tls_certificate_path):
with open(args.tls_certificate_path, "w") as certificate_file:
|
||||
cert = crypto.X509()
|
||||
subject = cert.get_subject()
|
||||
subject.CN = config["server_name"]
|
||||
subject.CN = args.server_name
|
||||
|
||||
cert.set_serial_number(1000)
|
||||
cert.gmtime_adj_notBefore(0)
|
||||
@@ -108,16 +110,16 @@ class TlsConfig(Config):
|
||||
|
||||
certificate_file.write(cert_pem)
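Pulling the certificate-generation steps above together, here is a compact sketch of producing a self-signed certificate with pyOpenSSL; the one-year validity and sha256 digest are assumptions, not values taken from this diff:

from OpenSSL import crypto

def make_self_signed_cert(server_name):
    key = crypto.PKey()
    key.generate_key(crypto.TYPE_RSA, 2048)

    cert = crypto.X509()
    cert.get_subject().CN = server_name
    cert.set_serial_number(1000)
    cert.gmtime_adj_notBefore(0)
    cert.gmtime_adj_notAfter(365 * 24 * 60 * 60)   # assumed 1-year validity
    cert.set_issuer(cert.get_subject())            # self-signed
    cert.set_pubkey(key)
    cert.sign(key, "sha256")

    return (
        crypto.dump_privatekey(crypto.FILETYPE_PEM, key),
        crypto.dump_certificate(crypto.FILETYPE_PEM, cert),
    )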
|
||||
|
||||
if not os.path.exists(tls_dh_params_path):
|
||||
if not os.path.exists(args.tls_dh_params_path):
|
||||
if GENERATE_DH_PARAMS:
|
||||
subprocess.check_call([
|
||||
"openssl", "dhparam",
|
||||
"-outform", "PEM",
|
||||
"-out", tls_dh_params_path,
|
||||
"-out", args.tls_dh_params_path,
|
||||
"2048"
|
||||
])
|
||||
else:
|
||||
with open(tls_dh_params_path, "w") as dh_params_file:
|
||||
with open(args.tls_dh_params_path, "w") as dh_params_file:
|
||||
dh_params_file.write(
|
||||
"2048-bit DH parameters taken from rfc3526\n"
|
||||
"-----BEGIN DH PARAMETERS-----\n"
|
||||
|
||||
@@ -17,21 +17,28 @@ from ._base import Config
|
||||
|
||||
class VoipConfig(Config):
|
||||
|
||||
def read_config(self, config):
|
||||
self.turn_uris = config.get("turn_uris", [])
|
||||
self.turn_shared_secret = config["turn_shared_secret"]
|
||||
self.turn_user_lifetime = self.parse_duration(config["turn_user_lifetime"])
|
||||
def __init__(self, args):
|
||||
super(VoipConfig, self).__init__(args)
|
||||
self.turn_uris = args.turn_uris
|
||||
self.turn_shared_secret = args.turn_shared_secret
|
||||
self.turn_user_lifetime = args.turn_user_lifetime
|
||||
|
||||
def default_config(self, config_dir_path, server_name):
|
||||
return """\
|
||||
## Turn ##
|
||||
|
||||
# The public URIs of the TURN server to give to clients
|
||||
turn_uris: []
|
||||
|
||||
# The shared secret used to compute passwords for the TURN server
|
||||
turn_shared_secret: "YOUR_SHARED_SECRET"
|
||||
|
||||
# How long generated TURN credentials last
|
||||
turn_user_lifetime: "1h"
|
||||
"""
|
||||
@classmethod
|
||||
def add_arguments(cls, parser):
|
||||
super(VoipConfig, cls).add_arguments(parser)
|
||||
group = parser.add_argument_group("voip")
|
||||
group.add_argument(
|
||||
"--turn-uris", type=str, default=None, action='append',
|
||||
help="The public URIs of the TURN server to give to clients"
|
||||
)
|
||||
group.add_argument(
|
||||
"--turn-shared-secret", type=str, default=None,
|
||||
help=(
|
||||
"The shared secret used to compute passwords for the TURN"
|
||||
" server"
|
||||
)
|
||||
)
|
||||
group.add_argument(
|
||||
"--turn-user-lifetime", type=int, default=(1000 * 60 * 60),
|
||||
help="How long generated TURN credentials last, in ms"
|
||||
)
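For background, a turn_shared_secret of this kind is normally used to mint short-lived TURN credentials (the ephemeral username/password scheme understood by coturn). A hedged sketch of that derivation; the exact field layout is an assumption rather than something stated in this diff:

import base64, hashlib, hmac, time

def turn_credentials(user_id, shared_secret, lifetime_ms=1000 * 60 * 60):
    # Username embeds the expiry time; password is an HMAC over the username.
    expiry = (int(time.time() * 1000) + lifetime_ms) // 1000
    username = "%d:%s" % (expiry, user_id)
    mac = hmac.new(shared_secret.encode(), username.encode(), hashlib.sha1)
    password = base64.b64encode(mac.digest()).decode()
    return username, password

username, password = turn_credentials("@alice:example.com", "YOUR_SHARED_SECRET")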
|
||||
|
||||
@@ -18,9 +18,7 @@ from twisted.web.http import HTTPClient
|
||||
from twisted.internet.protocol import Factory
|
||||
from twisted.internet import defer, reactor
|
||||
from synapse.http.endpoint import matrix_federation_endpoint
|
||||
from synapse.util.logcontext import (
|
||||
preserve_context_over_fn, preserve_context_over_deferred
|
||||
)
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
import simplejson as json
|
||||
import logging
|
||||
|
||||
@@ -42,14 +40,11 @@ def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
|
||||
|
||||
for i in range(5):
|
||||
try:
|
||||
protocol = yield preserve_context_over_fn(
|
||||
endpoint.connect, factory
|
||||
)
|
||||
server_response, server_certificate = yield preserve_context_over_deferred(
|
||||
protocol.remote_key
|
||||
)
|
||||
defer.returnValue((server_response, server_certificate))
|
||||
return
|
||||
with PreserveLoggingContext():
|
||||
protocol = yield endpoint.connect(factory)
|
||||
server_response, server_certificate = yield protocol.remote_key
|
||||
defer.returnValue((server_response, server_certificate))
|
||||
return
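Both variants above share the same shape: try the key fetch up to five times and return on the first success. A stripped-down synchronous sketch of that retry loop (names are illustrative):

def fetch_with_retries(fetch_once, attempts=5):
    last_error = None
    for _ in range(attempts):
        try:
            return fetch_once()
        except Exception as e:   # in the real code: SynapseKeyClientError
            last_error = e
    raise last_error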
|
||||
except SynapseKeyClientError as e:
|
||||
logger.exception("Error getting key for %r" % (server_name,))
|
||||
if e.status.startswith("4"):
|
||||
|
||||
@@ -26,7 +26,7 @@ from synapse.api.errors import SynapseError, Codes
|
||||
|
||||
from synapse.util.retryutils import get_retry_limiter
|
||||
|
||||
from synapse.util.async import ObservableDeferred
|
||||
from synapse.util.async import create_observer
|
||||
|
||||
from OpenSSL import crypto
|
||||
|
||||
@@ -111,10 +111,6 @@ class Keyring(object):
|
||||
|
||||
if download is None:
|
||||
download = self._get_server_verify_key_impl(server_name, key_ids)
|
||||
download = ObservableDeferred(
|
||||
download,
|
||||
consumeErrors=True
|
||||
)
|
||||
self.key_downloads[server_name] = download
|
||||
|
||||
@download.addBoth
|
||||
@@ -122,31 +118,30 @@ class Keyring(object):
|
||||
del self.key_downloads[server_name]
|
||||
return ret
|
||||
|
||||
r = yield download.observe()
|
||||
r = yield create_observer(download)
|
||||
defer.returnValue(r)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_server_verify_key_impl(self, server_name, key_ids):
|
||||
keys = None
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_key(perspective_name, perspective_keys):
|
||||
try:
|
||||
result = yield self.get_server_verify_key_v2_indirect(
|
||||
server_name, key_ids, perspective_name, perspective_keys
|
||||
)
|
||||
defer.returnValue(result)
|
||||
except Exception as e:
|
||||
logging.info(
"Unable to get key %r for %r from %r: %s %s",
|
||||
key_ids, server_name, perspective_name,
|
||||
type(e).__name__, str(e.message),
|
||||
)
|
||||
perspective_results = []
|
||||
for perspective_name, perspective_keys in self.perspective_servers.items():
|
||||
@defer.inlineCallbacks
|
||||
def get_key():
|
||||
try:
|
||||
result = yield self.get_server_verify_key_v2_indirect(
|
||||
server_name, key_ids, perspective_name, perspective_keys
|
||||
)
|
||||
defer.returnValue(result)
|
||||
except:
|
||||
logging.info(
"Unable to get key %r for %r from %r",
|
||||
key_ids, server_name, perspective_name,
|
||||
)
|
||||
perspective_results.append(get_key())
|
||||
|
||||
perspective_results = yield defer.gatherResults([
|
||||
get_key(p_name, p_keys)
|
||||
for p_name, p_keys in self.perspective_servers.items()
|
||||
])
|
||||
perspective_results = yield defer.gatherResults(perspective_results)
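The perspective lookup above fans out one deferred per configured perspective server and then waits on all of them. A minimal Twisted sketch of that fan-out (function names are illustrative; query_one is expected to return None rather than raise on failure, as get_key does above):

from twisted.internet import defer

@defer.inlineCallbacks
def query_all_perspectives(perspective_servers, query_one):
    results = yield defer.gatherResults(
        [query_one(name, keys) for name, keys in perspective_servers.items()],
        consumeErrors=True,
    )
    defer.returnValue([r for r in results if r is not None])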
|
||||
|
||||
for results in perspective_results:
|
||||
if results is not None:
|
||||
@@ -159,22 +154,17 @@ class Keyring(object):
|
||||
)
|
||||
|
||||
with limiter:
|
||||
if not keys:
|
||||
if keys is None:
|
||||
try:
|
||||
keys = yield self.get_server_verify_key_v2_direct(
|
||||
server_name, key_ids
|
||||
)
|
||||
except Exception as e:
|
||||
logging.info(
"Unable to get key %r for %r directly: %s %s",
|
||||
key_ids, server_name,
|
||||
type(e).__name__, str(e.message),
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
if not keys:
|
||||
keys = yield self.get_server_verify_key_v1_direct(
|
||||
server_name, key_ids
|
||||
)
|
||||
keys = yield self.get_server_verify_key_v1_direct(
|
||||
server_name, key_ids
|
||||
)
|
||||
|
||||
for key_id in key_ids:
|
||||
if key_id in keys:
|
||||
@@ -194,7 +184,7 @@ class Keyring(object):
|
||||
# TODO(mark): Set the minimum_valid_until_ts to that needed by
|
||||
# the events being validated or the current time if validating
|
||||
# an incoming request.
|
||||
query_response = yield self.client.post_json(
|
||||
responses = yield self.client.post_json(
|
||||
destination=perspective_name,
|
||||
path=b"/_matrix/key/v2/query",
|
||||
data={
|
||||
@@ -210,8 +200,6 @@ class Keyring(object):
|
||||
|
||||
keys = {}
|
||||
|
||||
responses = query_response["server_keys"]
|
||||
|
||||
for response in responses:
|
||||
if (u"signatures" not in response
|
||||
or perspective_name not in response[u"signatures"]):
|
||||
@@ -335,7 +323,7 @@ class Keyring(object):
|
||||
verify_key.time_added = time_now_ms
|
||||
old_verify_keys[key_id] = verify_key
|
||||
|
||||
for key_id in response_json["signatures"].get(server_name, {}):
|
||||
for key_id in response_json["signatures"][server_name]:
|
||||
if key_id not in response_json["verify_keys"]:
|
||||
raise ValueError(
|
||||
"Key response must include verification keys for all"
|
||||
|
||||
@@ -16,12 +16,6 @@
|
||||
from synapse.util.frozenutils import freeze
|
||||
|
||||
|
||||
# Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents
|
||||
# bugs where we accidentally share e.g. signature dicts. However, converting
|
||||
# a dict to frozen_dicts is expensive.
|
||||
USE_FROZEN_DICTS = True
|
||||
|
||||
|
||||
class _EventInternalMetadata(object):
|
||||
def __init__(self, internal_metadata_dict):
|
||||
self.__dict__ = dict(internal_metadata_dict)
|
||||
@@ -128,10 +122,7 @@ class FrozenEvent(EventBase):
|
||||
|
||||
unsigned = dict(event_dict.pop("unsigned", {}))
|
||||
|
||||
if USE_FROZEN_DICTS:
|
||||
frozen_dict = freeze(event_dict)
|
||||
else:
|
||||
frozen_dict = event_dict
|
||||
frozen_dict = freeze(event_dict)
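freeze() (from synapse.util.frozenutils) exists to stop two events from accidentally sharing a mutable dict such as signatures. A rough stand-in for the idea using read-only mapping views; this is not the real helper, which builds frozendicts:

from types import MappingProxyType

def freeze(obj):
    # Recursively wrap dicts in read-only views and turn lists into tuples.
    if isinstance(obj, dict):
        return MappingProxyType({k: freeze(v) for k, v in obj.items()})
    if isinstance(obj, (list, tuple)):
        return tuple(freeze(v) for v in obj)
    return obj

event_dict = {"signatures": {"example.org": {"ed25519:1": "sig"}}}
frozen = freeze(event_dict)
# frozen["signatures"]["example.org"]["ed25519:1"] = "x"  -> TypeError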
|
||||
|
||||
super(FrozenEvent, self).__init__(
|
||||
frozen_dict,
|
||||
|
||||
@@ -18,12 +18,12 @@ from twisted.internet import defer
|
||||
|
||||
from synapse.events.utils import prune_event
|
||||
|
||||
from syutil.jsonutil import encode_canonical_json
|
||||
|
||||
from synapse.crypto.event_signing import check_event_content_hash
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
|
||||
from synapse.util import unwrapFirstError
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
@@ -78,7 +78,6 @@ class FederationBase(object):
|
||||
destinations=[pdu.origin],
|
||||
event_id=pdu.event_id,
|
||||
outlier=outlier,
|
||||
timeout=10000,
|
||||
)
|
||||
|
||||
if new_pdu:
|
||||
@@ -95,7 +94,7 @@ class FederationBase(object):
|
||||
yield defer.gatherResults(
|
||||
[do(pdu) for pdu in pdus],
|
||||
consumeErrors=True
|
||||
).addErrback(unwrapFirstError)
|
||||
)
|
||||
|
||||
defer.returnValue(signed_pdus)
|
||||
|
||||
@@ -118,15 +117,16 @@ class FederationBase(object):
|
||||
)
|
||||
except SynapseError:
|
||||
logger.warn(
|
||||
"Signature check failed for %s",
|
||||
pdu.event_id,
|
||||
"Signature check failed for %s redacted to %s",
|
||||
encode_canonical_json(pdu.get_pdu_json()),
|
||||
encode_canonical_json(redacted_pdu_json),
|
||||
)
|
||||
raise
|
||||
|
||||
if not check_event_content_hash(pdu):
|
||||
logger.warn(
|
||||
"Event content has been tampered, redacting.",
|
||||
pdu.event_id,
|
||||
"Event content has been tampered, redacting %s, %s",
|
||||
pdu.event_id, encode_canonical_json(pdu.get_dict())
|
||||
)
|
||||
defer.returnValue(redacted_event)
|
||||
|
||||
|
||||
@@ -22,7 +22,6 @@ from .units import Edu
|
||||
from synapse.api.errors import (
|
||||
CodeMessageException, HttpResponseException, SynapseError,
|
||||
)
|
||||
from synapse.util import unwrapFirstError
|
||||
from synapse.util.expiringcache import ExpiringCache
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.events import FrozenEvent
|
||||
@@ -165,17 +164,16 @@ class FederationClient(FederationBase):
|
||||
for p in transaction_data["pdus"]
|
||||
]
|
||||
|
||||
# FIXME: We should handle signature failures more gracefully.
|
||||
pdus[:] = yield defer.gatherResults(
|
||||
[self._check_sigs_and_hash(pdu) for pdu in pdus],
|
||||
consumeErrors=True,
|
||||
).addErrback(unwrapFirstError)
|
||||
for i, pdu in enumerate(pdus):
|
||||
pdus[i] = yield self._check_sigs_and_hash(pdu)
|
||||
|
||||
# FIXME: We should handle signature failures more gracefully.
|
||||
|
||||
defer.returnValue(pdus)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def get_pdu(self, destinations, event_id, outlier=False, timeout=None):
|
||||
def get_pdu(self, destinations, event_id, outlier=False):
|
||||
"""Requests the PDU with given origin and ID from the remote home
|
||||
servers.
|
||||
|
||||
@@ -191,8 +189,6 @@ class FederationClient(FederationBase):
|
||||
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
|
||||
it's from an arbitrary point in the context as opposed to part
|
||||
of the current block of PDUs. Defaults to `False`
|
||||
timeout (int): How long to try (in ms) each destination for before
|
||||
moving to the next destination. None indicates no timeout.
|
||||
|
||||
Returns:
|
||||
Deferred: Results in the requested PDU.
|
||||
@@ -216,7 +212,7 @@ class FederationClient(FederationBase):
|
||||
|
||||
with limiter:
|
||||
transaction_data = yield self.transport_layer.get_event(
|
||||
destination, event_id, timeout=timeout,
|
||||
destination, event_id
|
||||
)
|
||||
|
||||
logger.debug("transaction_data %r", transaction_data)
|
||||
@@ -226,7 +222,7 @@ class FederationClient(FederationBase):
|
||||
for p in transaction_data["pdus"]
|
||||
]
|
||||
|
||||
if pdu_list and pdu_list[0]:
|
||||
if pdu_list:
|
||||
pdu = pdu_list[0]
|
||||
|
||||
# Check signatures are correct.
|
||||
@@ -259,7 +255,7 @@ class FederationClient(FederationBase):
|
||||
)
|
||||
continue
|
||||
|
||||
if self._get_pdu_cache is not None and pdu:
|
||||
if self._get_pdu_cache is not None:
|
||||
self._get_pdu_cache[event_id] = pdu
|
||||
|
||||
defer.returnValue(pdu)
|
||||
@@ -374,17 +370,13 @@ class FederationClient(FederationBase):
|
||||
for p in content.get("auth_chain", [])
|
||||
]
|
||||
|
||||
signed_state, signed_auth = yield defer.gatherResults(
|
||||
[
|
||||
self._check_sigs_and_hash_and_fetch(
|
||||
destination, state, outlier=True
|
||||
),
|
||||
self._check_sigs_and_hash_and_fetch(
|
||||
destination, auth_chain, outlier=True
|
||||
)
|
||||
],
|
||||
consumeErrors=True
|
||||
).addErrback(unwrapFirstError)
|
||||
signed_state = yield self._check_sigs_and_hash_and_fetch(
|
||||
destination, state, outlier=True
|
||||
)
|
||||
|
||||
signed_auth = yield self._check_sigs_and_hash_and_fetch(
|
||||
destination, auth_chain, outlier=True
|
||||
)
|
||||
|
||||
auth_chain.sort(key=lambda e: e.depth)
|
||||
|
||||
@@ -499,7 +491,7 @@ class FederationClient(FederationBase):
|
||||
]
|
||||
|
||||
signed_events = yield self._check_sigs_and_hash_and_fetch(
|
||||
destination, events, outlier=False
|
||||
destination, events, outlier=True
|
||||
)
|
||||
|
||||
have_gotten_all_from_destination = True
|
||||
@@ -526,7 +518,7 @@ class FederationClient(FederationBase):
|
||||
# Are we missing any?
|
||||
|
||||
seen_events = set(earliest_events_ids)
|
||||
seen_events.update(e.event_id for e in signed_events if e)
|
||||
seen_events.update(e.event_id for e in signed_events)
|
||||
|
||||
missing_events = {}
|
||||
for e in itertools.chain(latest_events, signed_events):
|
||||
@@ -569,7 +561,7 @@ class FederationClient(FederationBase):
|
||||
|
||||
res = yield defer.DeferredList(deferreds, consumeErrors=True)
|
||||
for (result, val), (e_id, _) in zip(res, ordered_missing):
|
||||
if result and val:
|
||||
if result:
|
||||
signed_events.append(val)
|
||||
else:
|
||||
failed_to_fetch.add(e_id)
|
||||
|
||||
@@ -20,6 +20,7 @@ from .federation_base import FederationBase
|
||||
from .units import Transaction, Edu
|
||||
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
from synapse.events import FrozenEvent
|
||||
import synapse.metrics
|
||||
|
||||
@@ -122,28 +123,29 @@ class FederationServer(FederationBase):
|
||||
|
||||
logger.debug("[%s] Transaction is new", transaction.transaction_id)
|
||||
|
||||
results = []
|
||||
with PreserveLoggingContext():
|
||||
results = []
|
||||
|
||||
for pdu in pdu_list:
|
||||
d = self._handle_new_pdu(transaction.origin, pdu)
|
||||
for pdu in pdu_list:
|
||||
d = self._handle_new_pdu(transaction.origin, pdu)
|
||||
|
||||
try:
|
||||
yield d
|
||||
results.append({})
|
||||
except FederationError as e:
|
||||
self.send_failure(e, transaction.origin)
|
||||
results.append({"error": str(e)})
|
||||
except Exception as e:
|
||||
results.append({"error": str(e)})
|
||||
logger.exception("Failed to handle PDU")
|
||||
try:
|
||||
yield d
|
||||
results.append({})
|
||||
except FederationError as e:
|
||||
self.send_failure(e, transaction.origin)
|
||||
results.append({"error": str(e)})
|
||||
except Exception as e:
|
||||
results.append({"error": str(e)})
|
||||
logger.exception("Failed to handle PDU")
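The loop above deliberately handles each PDU in isolation: one bad PDU adds an error entry to the results list instead of failing the whole transaction. The same shape, reduced to a synchronous sketch:

def handle_all(pdus, handle_one):
    results = []
    for pdu in pdus:
        try:
            handle_one(pdu)
            results.append({})
        except Exception as e:
            # Record the failure for this PDU and carry on with the rest.
            results.append({"error": str(e)})
    return results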
|
||||
|
||||
if hasattr(transaction, "edus"):
|
||||
for edu in [Edu(**x) for x in transaction.edus]:
|
||||
self.received_edu(
|
||||
transaction.origin,
|
||||
edu.edu_type,
|
||||
edu.content
|
||||
)
|
||||
if hasattr(transaction, "edus"):
|
||||
for edu in [Edu(**x) for x in transaction.edus]:
|
||||
self.received_edu(
|
||||
transaction.origin,
|
||||
edu.edu_type,
|
||||
edu.content
|
||||
)
|
||||
|
||||
for failure in getattr(transaction, "pdu_failures", []):
|
||||
logger.info("Got failure %r", failure)
|
||||
|
||||
@@ -23,6 +23,8 @@ from twisted.internet import defer
|
||||
|
||||
from synapse.util.logutils import log_function
|
||||
|
||||
from syutil.jsonutil import encode_canonical_json
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
@@ -69,7 +71,7 @@ class TransactionActions(object):
|
||||
transaction.transaction_id,
|
||||
transaction.origin,
|
||||
code,
|
||||
response,
|
||||
encode_canonical_json(response)
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@@ -99,5 +101,5 @@ class TransactionActions(object):
|
||||
transaction.transaction_id,
|
||||
transaction.destination,
|
||||
response_code,
|
||||
response_dict,
|
||||
encode_canonical_json(response_dict)
|
||||
)
|
||||
|
||||
@@ -104,6 +104,7 @@ class TransactionQueue(object):
|
||||
return not destination.startswith("localhost")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def enqueue_pdu(self, pdu, destinations, order):
|
||||
# We loop through all destinations to see whether we already have
|
||||
# a transaction in progress. If we do, stick it in the pending_pdus
|
||||
@@ -207,13 +208,13 @@ class TransactionQueue(object):
|
||||
# request at which point pending_pdus_by_dest just keeps growing.
|
||||
# we need application-layer timeouts of some flavour of these
|
||||
# requests
|
||||
logger.debug(
|
||||
logger.info(
|
||||
"TX [%s] Transaction already in progress",
|
||||
destination
|
||||
)
|
||||
return
|
||||
|
||||
logger.debug("TX [%s] _attempt_new_transaction", destination)
|
||||
logger.info("TX [%s] _attempt_new_transaction", destination)
|
||||
|
||||
# list of (pending_pdu, deferred, order)
|
||||
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
|
||||
@@ -221,11 +222,11 @@ class TransactionQueue(object):
|
||||
pending_failures = self.pending_failures_by_dest.pop(destination, [])
|
||||
|
||||
if pending_pdus:
|
||||
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
|
||||
destination, len(pending_pdus))
|
||||
logger.info("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
|
||||
destination, len(pending_pdus))
|
||||
|
||||
if not pending_pdus and not pending_edus and not pending_failures:
|
||||
logger.debug("TX [%s] Nothing to send", destination)
|
||||
logger.info("TX [%s] Nothing to send", destination)
|
||||
return
|
||||
|
||||
# Sort based on the order field
|
||||
@@ -242,8 +243,6 @@ class TransactionQueue(object):
|
||||
try:
|
||||
self.pending_transactions[destination] = 1
|
||||
|
||||
txn_id = str(self._next_txn_id)
|
||||
|
||||
limiter = yield get_retry_limiter(
|
||||
destination,
|
||||
self._clock,
|
||||
@@ -251,9 +250,9 @@ class TransactionQueue(object):
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"TX [%s] {%s} Attempting new transaction"
|
||||
"TX [%s] Attempting new transaction"
|
||||
" (pdus: %d, edus: %d, failures: %d)",
|
||||
destination, txn_id,
|
||||
destination,
|
||||
len(pending_pdus),
|
||||
len(pending_edus),
|
||||
len(pending_failures)
|
||||
@@ -263,7 +262,7 @@ class TransactionQueue(object):
|
||||
|
||||
transaction = Transaction.create_new(
|
||||
origin_server_ts=int(self._clock.time_msec()),
|
||||
transaction_id=txn_id,
|
||||
transaction_id=str(self._next_txn_id),
|
||||
origin=self.server_name,
|
||||
destination=destination,
|
||||
pdus=pdus,
|
||||
@@ -277,13 +276,9 @@ class TransactionQueue(object):
|
||||
|
||||
logger.debug("TX [%s] Persisted transaction", destination)
|
||||
logger.info(
|
||||
"TX [%s] {%s} Sending transaction [%s],"
|
||||
" (PDUs: %d, EDUs: %d, failures: %d)",
|
||||
destination, txn_id,
|
||||
"TX [%s] Sending transaction [%s]",
|
||||
destination,
|
||||
transaction.transaction_id,
|
||||
len(pending_pdus),
|
||||
len(pending_edus),
|
||||
len(pending_failures),
|
||||
)
|
||||
|
||||
with limiter:
|
||||
@@ -319,10 +314,7 @@ class TransactionQueue(object):
|
||||
code = e.code
|
||||
response = e.response
|
||||
|
||||
logger.info(
|
||||
"TX [%s] {%s} got %d response",
|
||||
destination, txn_id, code
|
||||
)
|
||||
logger.info("TX [%s] got %d response", destination, code)
|
||||
|
||||
logger.debug("TX [%s] Sent transaction", destination)
|
||||
logger.debug("TX [%s] Marking as delivered...", destination)
|
||||
|
||||
@@ -50,15 +50,13 @@ class TransportLayerClient(object):
|
||||
)
|
||||
|
||||
@log_function
|
||||
def get_event(self, destination, event_id, timeout=None):
|
||||
def get_event(self, destination, event_id):
""" Requests the pdu with the given id and origin from the given server.
|
||||
|
||||
Args:
|
||||
destination (str): The host name of the remote home server we want
|
||||
to get the state from.
|
||||
event_id (str): The id of the event being requested.
|
||||
timeout (int): How long to try (in ms) the destination for before
|
||||
giving up. None indicates no timeout.
|
||||
|
||||
Returns:
|
||||
Deferred: Results in a dict received from the remote homeserver.
|
||||
@@ -67,7 +65,7 @@ class TransportLayerClient(object):
|
||||
destination, event_id)
|
||||
|
||||
path = PREFIX + "/event/%s/" % (event_id, )
|
||||
return self.client.get_json(destination, path=path, timeout=timeout)
|
||||
return self.client.get_json(destination, path=path)
|
||||
|
||||
@log_function
|
||||
def backfill(self, destination, room_id, event_tuples, limit):
|
||||
|
||||
@@ -93,8 +93,6 @@ class TransportLayerServer(object):
|
||||
|
||||
yield self.keyring.verify_json_for_server(origin, json_request)
|
||||
|
||||
logger.info("Request from %s", origin)
|
||||
|
||||
defer.returnValue((origin, content))
|
||||
|
||||
@log_function
|
||||
@@ -198,14 +196,6 @@ class FederationSendServlet(BaseFederationServlet):
|
||||
transaction_id, str(transaction_data)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Received txn %s from %s. (PDUs: %d, EDUs: %d, failures: %d)",
|
||||
transaction_id, origin,
|
||||
len(transaction_data.get("pdus", [])),
|
||||
len(transaction_data.get("edus", [])),
|
||||
len(transaction_data.get("failures", [])),
|
||||
)
|
||||
|
||||
# We should ideally be getting this from the security layer.
|
||||
# origin = body["origin"]
|
||||
|
||||
|
||||
@@ -20,8 +20,6 @@ from synapse.crypto.event_signing import add_hashes_and_signatures
|
||||
from synapse.api.constants import Membership, EventTypes
|
||||
from synapse.types import UserID
|
||||
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
@@ -78,9 +76,7 @@ class BaseHandler(object):
|
||||
context = yield state_handler.compute_event_context(builder)
|
||||
|
||||
if builder.is_state():
|
||||
builder.prev_state = yield self.store.add_event_hashes(
|
||||
context.prev_state_events
|
||||
)
|
||||
builder.prev_state = context.prev_state_events
|
||||
|
||||
yield self.auth.add_auth_events(builder, context)
|
||||
|
||||
@@ -107,9 +103,7 @@ class BaseHandler(object):
|
||||
if not suppress_auth:
|
||||
self.auth.check(event, auth_events=context.current_state)
|
||||
|
||||
(event_stream_id, max_stream_id) = yield self.store.persist_event(
|
||||
event, context=context
|
||||
)
|
||||
yield self.store.persist_event(event, context=context)
|
||||
|
||||
federation_handler = self.hs.get_handlers().federation_handler
|
||||
|
||||
@@ -143,12 +137,10 @@ class BaseHandler(object):
|
||||
"Failed to get destination from event %s", s.event_id
|
||||
)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
# Don't block waiting on waking up all the listeners.
|
||||
notify_d = self.notifier.on_new_room_event(
|
||||
event, event_stream_id, max_stream_id,
|
||||
extra_users=extra_users
|
||||
)
|
||||
# Don't block waiting on waking up all the listeners.
|
||||
notify_d = self.notifier.on_new_room_event(
|
||||
event, extra_users=extra_users
|
||||
)
|
||||
|
||||
def log_failure(f):
|
||||
logger.warn(
|
||||
@@ -158,6 +150,8 @@ class BaseHandler(object):
|
||||
|
||||
notify_d.addErrback(log_failure)
|
||||
|
||||
federation_handler.handle_new_event(
|
||||
fed_d = federation_handler.handle_new_event(
|
||||
event, destinations=destinations,
|
||||
)
|
||||
|
||||
fed_d.addErrback(log_failure)
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.types import UserID
|
||||
|
||||
@@ -147,7 +147,10 @@ class ApplicationServicesHandler(object):
|
||||
)
|
||||
# We need to know the members associated with this event.room_id,
|
||||
# if any.
|
||||
member_list = yield self.store.get_users_in_room(event.room_id)
|
||||
member_list = yield self.store.get_room_members(
|
||||
room_id=event.room_id,
|
||||
membership=Membership.JOIN
|
||||
)
|
||||
|
||||
services = yield self.store.get_app_services()
|
||||
interested_list = [
|
||||
@@ -177,7 +180,7 @@ class ApplicationServicesHandler(object):
|
||||
return
|
||||
|
||||
user_info = yield self.store.get_user_by_id(user_id)
|
||||
if not user_info:
|
||||
if len(user_info) > 0:
|
||||
defer.returnValue(False)
|
||||
return
|
||||
|
||||
|
||||
@@ -159,7 +159,7 @@ class AuthHandler(BaseHandler):
|
||||
logger.warn("Attempted to login as %s but they do not exist", user)
|
||||
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
stored_hash = user_info["password_hash"]
|
||||
stored_hash = user_info[0]["password_hash"]
|
||||
if bcrypt.checkpw(password, stored_hash):
|
||||
defer.returnValue(user)
|
||||
else:
|
||||
@@ -187,8 +187,8 @@ class AuthHandler(BaseHandler):
|
||||
# each request
|
||||
try:
|
||||
client = SimpleHttpClient(self.hs)
|
||||
resp_body = yield client.post_urlencoded_get_json(
|
||||
self.hs.config.recaptcha_siteverify_api,
|
||||
data = yield client.post_urlencoded_get_json(
|
||||
"https://www.google.com/recaptcha/api/siteverify",
|
||||
args={
|
||||
'secret': self.hs.config.recaptcha_private_key,
|
||||
'response': user_response,
|
||||
@@ -198,8 +198,7 @@ class AuthHandler(BaseHandler):
|
||||
except PartialDownloadError as pde:
|
||||
# Twisted is silly
|
||||
data = pde.response
|
||||
resp_body = simplejson.loads(data)
|
||||
|
||||
resp_body = simplejson.loads(data)
|
||||
if 'success' in resp_body and resp_body['success']:
|
||||
defer.returnValue(True)
|
||||
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
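For reference, the CAPTCHA check above is a plain form POST to Google's siteverify endpoint followed by a look at the "success" field. A hedged sketch using the requests library instead of Synapse's SimpleHttpClient:

import requests

def check_recaptcha(private_key, user_response):
    resp = requests.post(
        "https://www.google.com/recaptcha/api/siteverify",
        data={"secret": private_key, "response": user_response},
    )
    body = resp.json()
    return bool(body.get("success"))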
|
||||
|
||||
@@ -22,7 +22,6 @@ from synapse.api.constants import EventTypes
|
||||
from synapse.types import RoomAlias
|
||||
|
||||
import logging
|
||||
import string
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -41,10 +40,6 @@ class DirectoryHandler(BaseHandler):
|
||||
def _create_association(self, room_alias, room_id, servers=None):
|
||||
# general association creation for both human users and app services
|
||||
|
||||
for wchar in string.whitespace:
|
||||
if wchar in room_alias.localpart:
|
||||
raise SynapseError(400, "Invalid characters in room alias")
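The loop above simply rejects an alias localpart that contains any whitespace character; an equivalent standalone check:

import string

def validate_alias_localpart(localpart):
    # Mirrors the check above: any whitespace character is rejected.
    if any(wchar in localpart for wchar in string.whitespace):
        raise ValueError("Invalid characters in room alias")

validate_alias_localpart("myroom")        # ok
# validate_alias_localpart("my room")     # would raise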
|
||||
|
||||
if not self.hs.is_mine(room_alias):
|
||||
raise SynapseError(400, "Room alias must be local")
|
||||
# TODO(erikj): Change this.
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.types import UserID
|
||||
from synapse.events.utils import serialize_event
|
||||
@@ -80,9 +81,10 @@ class EventStreamHandler(BaseHandler):
|
||||
# thundering herds on restart.
|
||||
timeout = random.randint(int(timeout*0.9), int(timeout*1.1))
|
||||
|
||||
events, tokens = yield self.notifier.get_events_for(
|
||||
auth_user, room_ids, pagin_config, timeout
|
||||
)
|
||||
with PreserveLoggingContext():
|
||||
events, tokens = yield self.notifier.get_events_for(
|
||||
auth_user, room_ids, pagin_config, timeout
|
||||
)
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
|
||||
|
||||
@@ -18,11 +18,9 @@
|
||||
from ._base import BaseHandler
|
||||
|
||||
from synapse.api.errors import (
|
||||
AuthError, FederationError, StoreError, CodeMessageException, SynapseError,
|
||||
AuthError, FederationError, StoreError,
|
||||
)
|
||||
from synapse.api.constants import EventTypes, Membership, RejectedReason
|
||||
from synapse.util import unwrapFirstError
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.util.frozenutils import unfreeze
|
||||
@@ -31,8 +29,6 @@ from synapse.crypto.event_signing import (
|
||||
)
|
||||
from synapse.types import UserID
|
||||
|
||||
from synapse.util.retryutils import NotRetryingDestination
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
import itertools
|
||||
@@ -77,6 +73,8 @@ class FederationHandler(BaseHandler):
|
||||
# When joining a room we need to queue any events for that room up
|
||||
self.room_queues = {}
|
||||
|
||||
@log_function
|
||||
@defer.inlineCallbacks
|
||||
def handle_new_event(self, event, destinations):
|
||||
""" Takes in an event from the client to server side, that has already
|
||||
been authed and handled by the state module, and sends it to any
|
||||
@@ -91,7 +89,9 @@ class FederationHandler(BaseHandler):
|
||||
processing.
|
||||
"""
|
||||
|
||||
return self.replication_layer.send_pdu(event, destinations)
|
||||
yield run_on_reactor()
|
||||
|
||||
self.replication_layer.send_pdu(event, destinations)
|
||||
|
||||
@log_function
|
||||
@defer.inlineCallbacks
|
||||
@@ -160,7 +160,7 @@ class FederationHandler(BaseHandler):
|
||||
)
|
||||
|
||||
try:
|
||||
_, event_stream_id, max_stream_id = yield self._handle_new_event(
|
||||
yield self._handle_new_event(
|
||||
origin,
|
||||
event,
|
||||
state=state,
|
||||
@@ -201,11 +201,9 @@ class FederationHandler(BaseHandler):
|
||||
target_user = UserID.from_string(target_user_id)
|
||||
extra_users.append(target_user)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
d = self.notifier.on_new_room_event(
|
||||
event, event_stream_id, max_stream_id,
|
||||
extra_users=extra_users
|
||||
)
|
||||
d = self.notifier.on_new_room_event(
|
||||
event, extra_users=extra_users
|
||||
)
|
||||
|
||||
def log_failure(f):
|
||||
logger.warn(
|
||||
@@ -224,254 +222,36 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
@log_function
|
||||
@defer.inlineCallbacks
|
||||
def backfill(self, dest, room_id, limit, extremities=[]):
|
||||
def backfill(self, dest, room_id, limit):
|
||||
""" Trigger a backfill request to `dest` for the given `room_id`
|
||||
"""
|
||||
if not extremities:
|
||||
extremities = yield self.store.get_oldest_events_in_room(room_id)
|
||||
extremities = yield self.store.get_oldest_events_in_room(room_id)
|
||||
|
||||
events = yield self.replication_layer.backfill(
|
||||
pdus = yield self.replication_layer.backfill(
|
||||
dest,
|
||||
room_id,
|
||||
limit=limit,
|
||||
limit,
|
||||
extremities=extremities,
|
||||
)
|
||||
|
||||
event_map = {e.event_id: e for e in events}
|
||||
events = []
|
||||
|
||||
event_ids = set(e.event_id for e in events)
|
||||
for pdu in pdus:
|
||||
event = pdu
|
||||
|
||||
edges = [
|
||||
ev.event_id
|
||||
for ev in events
|
||||
if set(e_id for e_id, _ in ev.prev_events) - event_ids
|
||||
]
|
||||
# FIXME (erikj): Not sure this actually works :/
|
||||
context = yield self.state_handler.compute_event_context(event)
|
||||
|
||||
logger.info(
|
||||
"backfill: Got %d events with %d edges",
|
||||
len(events), len(edges),
|
||||
)
|
||||
events.append((event, context))
|
||||
|
||||
# For each edge get the current state.
|
||||
|
||||
auth_events = {}
|
||||
state_events = {}
|
||||
events_to_state = {}
|
||||
for e_id in edges:
|
||||
state, auth = yield self.replication_layer.get_state_for_room(
|
||||
destination=dest,
|
||||
room_id=room_id,
|
||||
event_id=e_id
|
||||
)
|
||||
auth_events.update({a.event_id: a for a in auth})
|
||||
auth_events.update({s.event_id: s for s in state})
|
||||
state_events.update({s.event_id: s for s in state})
|
||||
events_to_state[e_id] = state
|
||||
|
||||
seen_events = yield self.store.have_events(
|
||||
set(auth_events.keys()) | set(state_events.keys())
|
||||
)
|
||||
|
||||
all_events = events + state_events.values() + auth_events.values()
|
||||
required_auth = set(
|
||||
a_id for event in all_events for a_id, _ in event.auth_events
|
||||
)
|
||||
|
||||
missing_auth = required_auth - set(auth_events)
|
||||
results = yield defer.gatherResults(
|
||||
[
|
||||
self.replication_layer.get_pdu(
|
||||
[dest],
|
||||
event_id,
|
||||
outlier=True,
|
||||
timeout=10000,
|
||||
)
|
||||
for event_id in missing_auth
|
||||
],
|
||||
consumeErrors=True
|
||||
).addErrback(unwrapFirstError)
|
||||
auth_events.update({a.event_id: a for a in results})
|
||||
|
||||
yield defer.gatherResults(
|
||||
[
|
||||
self._handle_new_event(
|
||||
dest, a,
|
||||
auth_events={
|
||||
(auth_events[a_id].type, auth_events[a_id].state_key):
|
||||
auth_events[a_id]
|
||||
for a_id, _ in a.auth_events
|
||||
},
|
||||
)
|
||||
for a in auth_events.values()
|
||||
if a.event_id not in seen_events
|
||||
],
|
||||
consumeErrors=True,
|
||||
).addErrback(unwrapFirstError)
|
||||
|
||||
yield defer.gatherResults(
|
||||
[
|
||||
self._handle_new_event(
|
||||
dest, event_map[e_id],
|
||||
state=events_to_state[e_id],
|
||||
backfilled=True,
|
||||
auth_events={
|
||||
(auth_events[a_id].type, auth_events[a_id].state_key):
|
||||
auth_events[a_id]
|
||||
for a_id, _ in event_map[e_id].auth_events
|
||||
},
|
||||
)
|
||||
for e_id in events_to_state
|
||||
],
|
||||
consumeErrors=True
|
||||
).addErrback(unwrapFirstError)
|
||||
|
||||
events.sort(key=lambda e: e.depth)
|
||||
|
||||
for event in events:
|
||||
if event in events_to_state:
|
||||
continue
|
||||
|
||||
yield self._handle_new_event(
|
||||
dest, event,
|
||||
backfilled=True,
|
||||
yield self.store.persist_event(
|
||||
event,
|
||||
context=context,
|
||||
backfilled=True
|
||||
)
|
||||
|
||||
defer.returnValue(events)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def maybe_backfill(self, room_id, current_depth):
|
||||
"""Checks the database to see if we should backfill before paginating,
|
||||
and if so do.
|
||||
"""
|
||||
extremities = yield self.store.get_oldest_events_with_depth_in_room(
|
||||
room_id
|
||||
)
|
||||
|
||||
if not extremities:
|
||||
logger.debug("Not backfilling as no extremities found.")
|
||||
return
|
||||
|
||||
        # Check if we reached a point where we should start backfilling.
        sorted_extremeties_tuple = sorted(
            extremities.items(),
            key=lambda e: -int(e[1])
        )
        max_depth = sorted_extremeties_tuple[0][1]

        if current_depth > max_depth:
            logger.debug(
                "Not backfilling as we don't need to. %d < %d",
                max_depth, current_depth,
            )
            return
# Now we need to decide which hosts to hit first.
|
||||
|
||||
# First we try hosts that are already in the room
|
||||
# TODO: HEURISTIC ALERT.
|
||||
|
||||
curr_state = yield self.state_handler.get_current_state(room_id)
|
||||
|
||||
        def get_domains_from_state(state):
            joined_users = [
                (state_key, int(event.depth))
                for (e_type, state_key), event in state.items()
                if e_type == EventTypes.Member
                and event.membership == Membership.JOIN
            ]

            joined_domains = {}
            for u, d in joined_users:
                try:
                    dom = UserID.from_string(u).domain
                    old_d = joined_domains.get(dom)
                    if old_d:
                        joined_domains[dom] = min(d, old_d)
                    else:
                        joined_domains[dom] = d
                except:
                    pass

            return sorted(joined_domains.items(), key=lambda d: d[1])
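For example, with two joined members the helper above yields each member's domain keyed by the shallowest depth seen for it, sorted shallowest-first (assuming the usual constants EventTypes.Member == "m.room.member" and Membership.JOIN == "join"):

import collections

FakeEvent = collections.namedtuple("FakeEvent", ["depth", "membership"])

state = {
    ("m.room.member", "@a:one.example"): FakeEvent(depth=5, membership="join"),
    ("m.room.member", "@b:two.example"): FakeEvent(depth=2, membership="join"),
}
# get_domains_from_state(state) -> [("two.example", 2), ("one.example", 5)]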
curr_domains = get_domains_from_state(curr_state)
|
||||
|
||||
likely_domains = [
|
||||
domain for domain, depth in curr_domains
|
||||
if domain is not self.server_name
|
||||
]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def try_backfill(domains):
|
||||
# TODO: Should we try multiple of these at a time?
|
||||
for dom in domains:
|
||||
try:
|
||||
events = yield self.backfill(
|
||||
dom, room_id,
|
||||
limit=100,
|
||||
extremities=[e for e in extremities.keys()]
|
||||
)
|
||||
except SynapseError as e:
|
||||
logger.info(
|
||||
"Failed to backfill from %s because %s",
|
||||
dom, e,
|
||||
)
|
||||
continue
|
||||
except CodeMessageException as e:
|
||||
if 400 <= e.code < 500:
|
||||
raise
|
||||
|
||||
logger.info(
|
||||
"Failed to backfill from %s because %s",
|
||||
dom, e,
|
||||
)
|
||||
continue
|
||||
except NotRetryingDestination as e:
|
||||
logger.info(e.message)
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to backfill from %s because %s",
|
||||
dom, e,
|
||||
)
|
||||
continue
|
||||
|
||||
if events:
|
||||
defer.returnValue(True)
|
||||
defer.returnValue(False)
|
||||
|
||||
success = yield try_backfill(likely_domains)
|
||||
if success:
|
||||
defer.returnValue(True)
|
||||
|
||||
# Huh, well *those* domains didn't work out. Lets try some domains
|
||||
# from the time.
|
||||
|
||||
tried_domains = set(likely_domains)
|
||||
tried_domains.add(self.server_name)
|
||||
|
||||
event_ids = list(extremities.keys())
|
||||
|
||||
states = yield defer.gatherResults([
|
||||
self.state_handler.resolve_state_groups([e])
|
||||
for e in event_ids
|
||||
])
|
||||
states = dict(zip(event_ids, [s[1] for s in states]))
|
||||
|
||||
for e_id, _ in sorted_extremeties_tuple:
|
||||
likely_domains = get_domains_from_state(states[e_id])
|
||||
|
||||
success = yield try_backfill([
|
||||
dom for dom in likely_domains
|
||||
if dom not in tried_domains
|
||||
])
|
||||
if success:
|
||||
defer.returnValue(True)
|
||||
|
||||
tried_domains.update(likely_domains)
|
||||
|
||||
defer.returnValue(False)
|
||||

    @defer.inlineCallbacks
    def send_invite(self, target_host, event):
        """ Sends the invite to the remote server for signing.
@@ -600,14 +380,30 @@ class FederationHandler(BaseHandler):
                # FIXME
                pass

        yield self._handle_auth_events(
            origin, [e for e in auth_chain if e.event_id != event.event_id]
        )
        for e in auth_chain:
            e.internal_metadata.outlier = True

        @defer.inlineCallbacks
        def handle_state(e):
            if e.event_id == event.event_id:
                return
                continue

            try:
                auth_ids = [e_id for e_id, _ in e.auth_events]
                auth = {
                    (e.type, e.state_key): e for e in auth_chain
                    if e.event_id in auth_ids
                }
                yield self._handle_new_event(
                    origin, e, auth_events=auth
                )
            except:
                logger.exception(
                    "Failed to handle auth event %s",
                    e.event_id,
                )

        for e in state:
            if e.event_id == event.event_id:
                continue

            e.internal_metadata.outlier = True
            try:
@@ -625,15 +421,13 @@ class FederationHandler(BaseHandler):
                    e.event_id,
                )

        yield defer.DeferredList([handle_state(e) for e in state])

        auth_ids = [e_id for e_id, _ in event.auth_events]
        auth_events = {
            (e.type, e.state_key): e for e in auth_chain
            if e.event_id in auth_ids
        }

        _, event_stream_id, max_stream_id = yield self._handle_new_event(
        yield self._handle_new_event(
            origin,
            new_event,
            state=state,
@@ -641,11 +435,9 @@ class FederationHandler(BaseHandler):
            auth_events=auth_events,
        )

        with PreserveLoggingContext():
            d = self.notifier.on_new_room_event(
                new_event, event_stream_id, max_stream_id,
                extra_users=[joinee]
            )
        d = self.notifier.on_new_room_event(
            new_event, extra_users=[joinee]
        )

        def log_failure(f):
            logger.warn(
@@ -710,9 +502,7 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
event.internal_metadata.outlier = False
|
||||
|
||||
context, event_stream_id, max_stream_id = yield self._handle_new_event(
|
||||
origin, event
|
||||
)
|
||||
context = yield self._handle_new_event(origin, event)
|
||||
|
||||
logger.debug(
|
||||
"on_send_join_request: After _handle_new_event: %s, sigs: %s",
|
||||
@@ -726,10 +516,9 @@ class FederationHandler(BaseHandler):
|
||||
target_user = UserID.from_string(target_user_id)
|
||||
extra_users.append(target_user)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
d = self.notifier.on_new_room_event(
|
||||
event, event_stream_id, max_stream_id, extra_users=extra_users
|
||||
)
|
||||
d = self.notifier.on_new_room_event(
|
||||
event, extra_users=extra_users
|
||||
)
|
||||
|
||||
def log_failure(f):
|
||||
logger.warn(
|
||||
@@ -802,18 +591,16 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
context = yield self.state_handler.compute_event_context(event)
|
||||
|
||||
event_stream_id, max_stream_id = yield self.store.persist_event(
|
||||
yield self.store.persist_event(
|
||||
event,
|
||||
context=context,
|
||||
backfilled=False,
|
||||
)
|
||||
|
||||
target_user = UserID.from_string(event.state_key)
|
||||
with PreserveLoggingContext():
|
||||
d = self.notifier.on_new_room_event(
|
||||
event, event_stream_id, max_stream_id,
|
||||
extra_users=[target_user],
|
||||
)
|
||||
d = self.notifier.on_new_room_event(
|
||||
event, extra_users=[target_user],
|
||||
)
|
||||
|
||||
def log_failure(f):
|
||||
logger.warn(
|
||||
@@ -945,10 +732,8 @@ class FederationHandler(BaseHandler):
|
||||
event.event_id, event.signatures,
|
||||
)
|
||||
|
||||
outlier = event.internal_metadata.is_outlier()
|
||||
|
||||
context = yield self.state_handler.compute_event_context(
|
||||
event, old_state=state, outlier=outlier,
|
||||
event, old_state=state
|
||||
)
|
||||
|
||||
if not auth_events:
|
||||
@@ -959,17 +744,14 @@ class FederationHandler(BaseHandler):
|
||||
event.event_id, auth_events,
|
||||
)
|
||||
|
||||
is_new_state = not outlier
|
||||
is_new_state = not event.internal_metadata.is_outlier()
|
||||
|
||||
# This is a hack to fix some old rooms where the initial join event
|
||||
# didn't reference the create event in its auth events.
|
||||
if event.type == EventTypes.Member and not event.auth_events:
|
||||
if len(event.prev_events) == 1 and event.depth < 5:
|
||||
c = yield self.store.get_event(
|
||||
event.prev_events[0][0],
|
||||
allow_none=True,
|
||||
)
|
||||
if c and c.type == EventTypes.Create:
|
||||
if len(event.prev_events) == 1:
|
||||
c = yield self.store.get_event(event.prev_events[0][0])
|
||||
if c.type == EventTypes.Create:
|
||||
auth_events[(c.type, c.state_key)] = c
|
||||
|
||||
try:
|
||||
@@ -995,7 +777,7 @@ class FederationHandler(BaseHandler):
|
||||
)
|
||||
raise
|
||||
|
||||
event_stream_id, max_stream_id = yield self.store.persist_event(
|
||||
yield self.store.persist_event(
|
||||
event,
|
||||
context=context,
|
||||
backfilled=backfilled,
|
||||
@@ -1003,7 +785,7 @@ class FederationHandler(BaseHandler):
|
||||
current_state=current_state,
|
||||
)
|
||||
|
||||
defer.returnValue((context, event_stream_id, max_stream_id))
|
||||
defer.returnValue(context)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_query_auth(self, origin, event_id, remote_auth_chain, rejects,
|
||||
@@ -1143,7 +925,7 @@ class FederationHandler(BaseHandler):
|
||||
if d in have_events and not have_events[d]
|
||||
],
|
||||
consumeErrors=True
|
||||
).addErrback(unwrapFirstError)
|
||||
)
|
||||
|
||||
if different_events:
|
||||
local_view = dict(auth_events)
|
||||
@@ -1388,52 +1170,3 @@ class FederationHandler(BaseHandler):
            },
            "missing": [e.event_id for e in missing_locals],
        })

    @defer.inlineCallbacks
    def _handle_auth_events(self, origin, auth_events):
        auth_ids_to_deferred = {}

        def process_auth_ev(ev):
            auth_ids = [e_id for e_id, _ in ev.auth_events]

            prev_ds = [
                auth_ids_to_deferred[i]
                for i in auth_ids
                if i in auth_ids_to_deferred
            ]

            d = defer.Deferred()

            auth_ids_to_deferred[ev.event_id] = d

            @defer.inlineCallbacks
            def f(*_):
                ev.internal_metadata.outlier = True

                try:
                    auth = {
                        (e.type, e.state_key): e for e in auth_events
                        if e.event_id in auth_ids
                    }

                    yield self._handle_new_event(
                        origin, ev, auth_events=auth
                    )
                except:
                    logger.exception(
                        "Failed to handle auth event %s",
                        ev.event_id,
                    )

                d.callback(None)

            if prev_ds:
                dx = defer.DeferredList(prev_ds)
                dx.addBoth(f)
            else:
                f()

        for e in auth_events:
            process_auth_ev(e)

        yield defer.DeferredList(auth_ids_to_deferred.values())

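The _handle_auth_events helper above (dropped on one side of this hunk) processes a batch of auth events in dependency order by chaining Deferreds: each event waits for any of its auth dependencies that are in the same batch before being handled. A rough, simplified sketch of that ordering pattern with Twisted — the helper name, the dict-shaped events, and the synchronous process() callback are assumptions for illustration, not the repository's API:

import logging
from twisted.internet import defer

logger = logging.getLogger(__name__)

def process_in_dependency_order(events, get_dep_ids, process):
    """Run process(event) for each event, after any in-batch events it depends on."""
    id_to_deferred = {}

    def schedule(ev):
        # Deferreds for the dependencies of this event that are also in the batch.
        prev_ds = [
            id_to_deferred[dep_id]
            for dep_id in get_dep_ids(ev)
            if dep_id in id_to_deferred
        ]

        d = defer.Deferred()
        id_to_deferred[ev["event_id"]] = d

        def run(*_):
            try:
                process(ev)
            except Exception:
                logger.exception("Failed to handle event %s", ev["event_id"])
            # Resolve unconditionally so dependent events are never blocked.
            d.callback(None)

        if prev_ds:
            defer.DeferredList(prev_ds).addBoth(run)
        else:
            run()

    for ev in events:
        schedule(ev)

    return defer.DeferredList(id_to_deferred.values())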
@@ -20,9 +20,8 @@ from synapse.api.errors import RoomError, SynapseError
from synapse.streams.config import PaginationConfig
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
from synapse.util import unwrapFirstError
from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import UserID, RoomStreamToken
from synapse.types import UserID

from ._base import BaseHandler

@@ -90,19 +89,9 @@ class MessageHandler(BaseHandler):

        if not pagin_config.from_token:
            pagin_config.from_token = (
                yield self.hs.get_event_sources().get_current_token(
                    direction='b'
                )
                yield self.hs.get_event_sources().get_current_token()
            )

        room_token = RoomStreamToken.parse(pagin_config.from_token.room_key)
        if room_token.topological is None:
            raise SynapseError(400, "Invalid token")

        yield self.hs.get_handlers().federation_handler.maybe_backfill(
            room_id, room_token.topological
        )

        user = UserID.from_string(user_id)

        events, next_key = yield data_source.get_pagination_rows(
@@ -261,31 +250,47 @@ class MessageHandler(BaseHandler):
        is joined on, may return a "messages" key with messages, depending
        on the specified PaginationConfig.
        """
        start_time = self.clock.time_msec()

        def delta():
            return self.clock.time_msec() - start_time

        logger.info("initial_sync: start")
        room_list = yield self.store.get_rooms_for_user_where_membership_is(
            user_id=user_id,
            membership_list=[Membership.INVITE, Membership.JOIN]
        )

        logger.info("initial_sync: got_rooms %d", delta())

        user = UserID.from_string(user_id)

        rooms_ret = []

        now_token = yield self.hs.get_event_sources().get_current_token()

        logger.info("initial_sync: now_token %d", delta())

        presence_stream = self.hs.get_event_sources().sources["presence"]
        pagination_config = PaginationConfig(from_token=now_token)
        presence, _ = yield presence_stream.get_pagination_rows(
            user, pagination_config.get_source_config("presence"), None
        )

        logger.info("initial_sync: presence_done %d", delta())

        public_room_ids = yield self.store.get_public_room_ids()

        logger.info("initial_sync: public_rooms %d", delta())

        limit = pagin_config.limit
        if limit is None:
            limit = 10

        @defer.inlineCallbacks
        def handle_room(event):
            logger.info("initial_sync: start: %s %d", event.room_id, delta())

            d = {
                "room_id": event.room_id,
                "membership": event.membership,
@@ -314,7 +319,7 @@ class MessageHandler(BaseHandler):
                        event.room_id
                    ),
                ]
            ).addErrback(unwrapFirstError)
            )

            start_token = now_token.copy_and_replace("room_key", token[0])
            end_token = now_token.copy_and_replace("room_key", token[1])
@@ -336,10 +341,14 @@ class MessageHandler(BaseHandler):
            except:
                logger.exception("Failed to get snapshot")

            logger.info("initial_sync: end: %s %d", event.room_id, delta())

        yield defer.gatherResults(
            [handle_room(e) for e in room_list],
            consumeErrors=True
        ).addErrback(unwrapFirstError)
        )

        logger.info("initial_sync: done %d", delta())

        ret = {
            "rooms": rooms_ret,

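The delta() closure above is a lightweight way to instrument the phases of initial_sync: capture a start timestamp once and log the elapsed milliseconds after each step. A small self-contained sketch of the same pattern, assuming plain time.time() rather than the homeserver's clock object:

import logging
import time

logger = logging.getLogger(__name__)

def initial_sync_with_timing():
    start_time = int(time.time() * 1000)

    def delta():
        # Milliseconds elapsed since this sync started.
        return int(time.time() * 1000) - start_time

    logger.info("initial_sync: start")
    # ... fetch the room list for the user ...
    logger.info("initial_sync: got_rooms %d", delta())
    # ... fetch presence, public rooms, and per-room snapshots ...
    logger.info("initial_sync: done %d", delta())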
@@ -18,8 +18,8 @@ from twisted.internet import defer
|
||||
from synapse.api.errors import SynapseError, AuthError
|
||||
from synapse.api.constants import PresenceState
|
||||
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
from synapse.types import UserID
|
||||
import synapse.metrics
|
||||
|
||||
@@ -146,10 +146,6 @@ class PresenceHandler(BaseHandler):
|
||||
self._user_cachemap = {}
|
||||
self._user_cachemap_latest_serial = 0
|
||||
|
||||
# map room_ids to the latest presence serial for a member of that
|
||||
# room
|
||||
self._room_serials = {}
|
||||
|
||||
metrics.register_callback(
|
||||
"userCachemap:size",
|
||||
lambda: len(self._user_cachemap),
|
||||
@@ -282,14 +278,15 @@ class PresenceHandler(BaseHandler):
|
||||
now_online = state["presence"] != PresenceState.OFFLINE
|
||||
was_polling = target_user in self._user_cachemap
|
||||
|
||||
if now_online and not was_polling:
|
||||
self.start_polling_presence(target_user, state=state)
|
||||
elif not now_online and was_polling:
|
||||
self.stop_polling_presence(target_user)
|
||||
with PreserveLoggingContext():
|
||||
if now_online and not was_polling:
|
||||
self.start_polling_presence(target_user, state=state)
|
||||
elif not now_online and was_polling:
|
||||
self.stop_polling_presence(target_user)
|
||||
|
||||
# TODO(paul): perform a presence push as part of start/stop poll so
|
||||
# we don't have to do this all the time
|
||||
self.changed_presencelike_data(target_user, state)
|
||||
# TODO(paul): perform a presence push as part of start/stop poll so
|
||||
# we don't have to do this all the time
|
||||
self.changed_presencelike_data(target_user, state)
|
||||
|
||||
def bump_presence_active_time(self, user, now=None):
|
||||
if now is None:
|
||||
@@ -301,34 +298,13 @@ class PresenceHandler(BaseHandler):
|
||||
|
||||
self.changed_presencelike_data(user, {"last_active": now})
|
||||
|
||||
def get_joined_rooms_for_user(self, user):
|
||||
"""Get the list of rooms a user is joined to.
|
||||
|
||||
Args:
|
||||
user(UserID): The user.
|
||||
Returns:
|
||||
A Deferred of a list of room id strings.
|
||||
"""
|
||||
rm_handler = self.homeserver.get_handlers().room_member_handler
|
||||
return rm_handler.get_joined_rooms_for_user(user)
|
||||
|
||||
def get_joined_users_for_room_id(self, room_id):
|
||||
rm_handler = self.homeserver.get_handlers().room_member_handler
|
||||
return rm_handler.get_room_members(room_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def changed_presencelike_data(self, user, state):
|
||||
"""Updates the presence state of a local user.
|
||||
statuscache = self._get_or_make_usercache(user)
|
||||
|
||||
Args:
|
||||
user(UserID): The user being updated.
|
||||
state(dict): The new presence state for the user.
|
||||
Returns:
|
||||
A Deferred
|
||||
"""
|
||||
self._user_cachemap_latest_serial += 1
|
||||
statuscache = yield self.update_presence_cache(user, state)
|
||||
yield self.push_presence(user, statuscache=statuscache)
|
||||
statuscache.update(state, serial=self._user_cachemap_latest_serial)
|
||||
|
||||
return self.push_presence(user, statuscache=statuscache)
|
||||
|
||||
@log_function
|
||||
def started_user_eventstream(self, user):
|
||||
@@ -342,21 +318,14 @@ class PresenceHandler(BaseHandler):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def user_joined_room(self, user, room_id):
|
||||
"""Called via the distributor whenever a user joins a room.
|
||||
Notifies the new member of the presence of the current members.
|
||||
Notifies the current members of the room of the new member's presence.
|
||||
|
||||
Args:
|
||||
user(UserID): The user who joined the room.
|
||||
room_id(str): The room id the user joined.
|
||||
"""
|
||||
if self.hs.is_mine(user):
|
||||
statuscache = self._get_or_make_usercache(user)
|
||||
|
||||
# No actual update but we need to bump the serial anyway for the
|
||||
# event source
|
||||
self._user_cachemap_latest_serial += 1
|
||||
statuscache = yield self.update_presence_cache(
|
||||
user, room_ids=[room_id]
|
||||
)
|
||||
statuscache.update({}, serial=self._user_cachemap_latest_serial)
|
||||
|
||||
self.push_update_to_local_and_remote(
|
||||
observed_user=user,
|
||||
room_ids=[room_id],
|
||||
@@ -364,22 +333,18 @@ class PresenceHandler(BaseHandler):
|
||||
)
|
||||
|
||||
# We also want to tell them about current presence of people.
|
||||
curr_users = yield self.get_joined_users_for_room_id(room_id)
|
||||
rm_handler = self.homeserver.get_handlers().room_member_handler
|
||||
curr_users = yield rm_handler.get_room_members(room_id)
|
||||
|
||||
for local_user in [c for c in curr_users if self.hs.is_mine(c)]:
|
||||
statuscache = yield self.update_presence_cache(
|
||||
local_user, room_ids=[room_id], add_to_cache=False
|
||||
)
|
||||
|
||||
self.push_update_to_local_and_remote(
|
||||
observed_user=local_user,
|
||||
users_to_push=[user],
|
||||
statuscache=statuscache,
|
||||
statuscache=self._get_or_offline_usercache(local_user),
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_invite(self, observer_user, observed_user):
|
||||
"""Request the presence of a local or remote user for a local user"""
|
||||
if not self.hs.is_mine(observer_user):
|
||||
raise SynapseError(400, "User is not hosted on this Home Server")
|
||||
|
||||
@@ -414,15 +379,6 @@ class PresenceHandler(BaseHandler):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def invite_presence(self, observed_user, observer_user):
|
||||
"""Handles a m.presence_invite EDU. A remote or local user has
|
||||
requested presence updates for a local user. If the invite is accepted
|
||||
then allow the local or remote user to see the presence of the local
|
||||
user.
|
||||
|
||||
Args:
|
||||
observed_user(UserID): The local user whose presence is requested.
|
||||
observer_user(UserID): The remote or local user requesting presence.
|
||||
"""
|
||||
accept = yield self._should_accept_invite(observed_user, observer_user)
|
||||
|
||||
if accept:
|
||||
@@ -449,34 +405,16 @@ class PresenceHandler(BaseHandler):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def accept_presence(self, observed_user, observer_user):
|
||||
"""Handles a m.presence_accept EDU. Mark a presence invite from a
|
||||
local or remote user as accepted in a local user's presence list.
|
||||
Starts polling for presence updates from the local or remote user.
|
||||
|
||||
Args:
|
||||
observed_user(UserID): The user to update in the presence list.
|
||||
observer_user(UserID): The owner of the presence list to update.
|
||||
"""
|
||||
yield self.store.set_presence_list_accepted(
|
||||
observer_user.localpart, observed_user.to_string()
|
||||
)
|
||||
|
||||
self.start_polling_presence(
|
||||
observer_user, target_user=observed_user
|
||||
)
|
||||
with PreserveLoggingContext():
|
||||
self.start_polling_presence(
|
||||
observer_user, target_user=observed_user
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def deny_presence(self, observed_user, observer_user):
|
||||
"""Handle a m.presence_deny EDU. Removes a local or remote user from a
|
||||
local user's presence list.
|
||||
|
||||
Args:
|
||||
observed_user(UserID): The local or remote user to remove from the
|
||||
list.
|
||||
observer_user(UserID): The local owner of the presence list.
|
||||
Returns:
|
||||
A Deferred.
|
||||
"""
|
||||
yield self.store.del_presence_list(
|
||||
observer_user.localpart, observed_user.to_string()
|
||||
)
|
||||
@@ -485,16 +423,6 @@ class PresenceHandler(BaseHandler):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def drop(self, observed_user, observer_user):
|
||||
"""Remove a local or remote user from a local user's presence list and
|
||||
unsubscribe the local user from updates from that user.
|
||||
|
||||
Args:
|
||||
observed_user(UserId): The local or remote user to remove from the
|
||||
list.
|
||||
observer_user(UserId): The local owner of the presence list.
|
||||
Returns:
|
||||
A Deferred.
|
||||
"""
|
||||
if not self.hs.is_mine(observer_user):
|
||||
raise SynapseError(400, "User is not hosted on this Home Server")
|
||||
|
||||
@@ -502,66 +430,34 @@ class PresenceHandler(BaseHandler):
|
||||
observer_user.localpart, observed_user.to_string()
|
||||
)
|
||||
|
||||
self.stop_polling_presence(
|
||||
observer_user, target_user=observed_user
|
||||
)
|
||||
with PreserveLoggingContext():
|
||||
self.stop_polling_presence(
|
||||
observer_user, target_user=observed_user
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_presence_list(self, observer_user, accepted=None):
|
||||
"""Get the presence list for a local user. The retured list includes
|
||||
the current presence state for each user listed.
|
||||
|
||||
Args:
|
||||
observer_user(UserID): The local user whose presence list to fetch.
|
||||
accepted(bool or None): If not none then only include users who
|
||||
have or have not accepted the presence invite request.
|
||||
Returns:
|
||||
A Deferred list of presence state events.
|
||||
"""
|
||||
if not self.hs.is_mine(observer_user):
|
||||
raise SynapseError(400, "User is not hosted on this Home Server")
|
||||
|
||||
presence_list = yield self.store.get_presence_list(
|
||||
presence = yield self.store.get_presence_list(
|
||||
observer_user.localpart, accepted=accepted
|
||||
)
|
||||
|
||||
results = []
|
||||
for row in presence_list:
|
||||
observed_user = UserID.from_string(row["observed_user_id"])
|
||||
result = {
|
||||
"observed_user": observed_user, "accepted": row["accepted"]
|
||||
}
|
||||
result.update(
|
||||
self._get_or_offline_usercache(observed_user).get_state()
|
||||
)
|
||||
if "last_active" in result:
|
||||
result["last_active_ago"] = int(
|
||||
self.clock.time_msec() - result.pop("last_active")
|
||||
for p in presence:
|
||||
observed_user = UserID.from_string(p.pop("observed_user_id"))
|
||||
p["observed_user"] = observed_user
|
||||
p.update(self._get_or_offline_usercache(observed_user).get_state())
|
||||
if "last_active" in p:
|
||||
p["last_active_ago"] = int(
|
||||
self.clock.time_msec() - p.pop("last_active")
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
defer.returnValue(results)
|
||||
defer.returnValue(presence)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def start_polling_presence(self, user, target_user=None, state=None):
|
||||
"""Subscribe a local user to presence updates from a local or remote
|
||||
user. If no target_user is supplied then subscribe to all users stored
|
||||
in the presence list for the local user.
|
||||
|
||||
Additionally this pushes the current presence state of this user to all
|
||||
target_users. That state can be provided directly or will be read from
|
||||
the stored state for the local user.
|
||||
|
||||
Also this attempts to notify the local user of the current state of
|
||||
any local target users.
|
||||
|
||||
Args:
|
||||
user(UserID): The local user that wishes for presence updates.
|
||||
target_user(UserID): The local or remote user whose updates are
|
||||
wanted.
|
||||
state(dict): Optional presence state for the local user.
|
||||
"""
|
||||
logger.debug("Start polling for presence from %s", user)
|
||||
|
||||
if target_user:
|
||||
@@ -577,7 +473,8 @@ class PresenceHandler(BaseHandler):
|
||||
|
||||
# Also include people in all my rooms
|
||||
|
||||
room_ids = yield self.get_joined_rooms_for_user(user)
|
||||
rm_handler = self.homeserver.get_handlers().room_member_handler
|
||||
room_ids = yield rm_handler.get_joined_rooms_for_user(user)
|
||||
|
||||
if state is None:
|
||||
state = yield self.store.get_presence_state(user.localpart)
|
||||
@@ -601,7 +498,9 @@ class PresenceHandler(BaseHandler):
|
||||
# We want to tell the person that just came online
|
||||
# presence state of people they are interested in?
|
||||
self.push_update_to_clients(
|
||||
observed_user=target_user,
|
||||
users_to_push=[user],
|
||||
statuscache=self._get_or_offline_usercache(target_user),
|
||||
)
|
||||
|
||||
deferreds = []
|
||||
@@ -618,12 +517,6 @@ class PresenceHandler(BaseHandler):
|
||||
yield defer.DeferredList(deferreds, consumeErrors=True)
|
||||
|
||||
def _start_polling_local(self, user, target_user):
|
||||
"""Subscribe a local user to presence updates for a local user
|
||||
|
||||
Args:
|
||||
user(UserId): The local user that wishes for updates.
|
||||
target_user(UserId): The local users whose updates are wanted.
|
||||
"""
|
||||
target_localpart = target_user.localpart
|
||||
|
||||
if target_localpart not in self._local_pushmap:
|
||||
@@ -632,17 +525,6 @@ class PresenceHandler(BaseHandler):
|
||||
self._local_pushmap[target_localpart].add(user)
|
||||
|
||||
def _start_polling_remote(self, user, domain, remoteusers):
|
||||
"""Subscribe a local user to presence updates for remote users on a
|
||||
given remote domain.
|
||||
|
||||
Args:
|
||||
user(UserID): The local user that wishes for updates.
|
||||
domain(str): The remote server the local user wants updates from.
|
||||
remoteusers(UserID): The remote users that local user wants to be
|
||||
told about.
|
||||
Returns:
|
||||
A Deferred.
|
||||
"""
|
||||
to_poll = set()
|
||||
|
||||
for u in remoteusers:
|
||||
@@ -663,17 +545,6 @@ class PresenceHandler(BaseHandler):
|
||||
|
||||
@log_function
|
||||
def stop_polling_presence(self, user, target_user=None):
|
||||
"""Unsubscribe a local user from presence updates from a local or
|
||||
remote user. If no target user is supplied then unsubscribe the user
|
||||
from all presence updates that the user had subscribed to.
|
||||
|
||||
Args:
|
||||
user(UserID): The local user that no longer wishes for updates.
|
||||
target_user(UserID or None): The user whose updates are no longer
|
||||
wanted.
|
||||
Returns:
|
||||
A Deferred.
|
||||
"""
|
||||
logger.debug("Stop polling for presence from %s", user)
|
||||
|
||||
if not target_user or self.hs.is_mine(target_user):
|
||||
@@ -702,13 +573,6 @@ class PresenceHandler(BaseHandler):
|
||||
return defer.DeferredList(deferreds, consumeErrors=True)
|
||||
|
||||
def _stop_polling_local(self, user, target_user):
|
||||
"""Unsubscribe a local user from presence updates from a local user on
|
||||
this server.
|
||||
|
||||
Args:
|
||||
user(UserID): The local user that no longer wishes for updates.
|
||||
target_user(UserID): The user whose updates are no longer wanted.
|
||||
"""
|
||||
for localpart in self._local_pushmap.keys():
|
||||
if target_user and localpart != target_user.localpart:
|
||||
continue
|
||||
@@ -721,17 +585,6 @@ class PresenceHandler(BaseHandler):
|
||||
|
||||
@log_function
|
||||
def _stop_polling_remote(self, user, domain, remoteusers):
|
||||
"""Unsubscribe a local user from presence updates from remote users on
|
||||
a given domain.
|
||||
|
||||
Args:
|
||||
user(UserID): The local user that no longer wishes for updates.
|
||||
domain(str): The remote server to unsubscribe from.
|
||||
remoteusers([UserID]): The users on that remote server that the
|
||||
local user no longer wishes to be updated about.
|
||||
Returns:
|
||||
A Deferred.
|
||||
"""
|
||||
to_unpoll = set()
|
||||
|
||||
for u in remoteusers:
|
||||
@@ -753,19 +606,6 @@ class PresenceHandler(BaseHandler):
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def push_presence(self, user, statuscache):
|
||||
"""
|
||||
Notify local and remote users of a change in presence of a local user.
|
||||
Pushes the update to local clients and remote domains that are directly
|
||||
subscribed to the presence of the local user.
|
||||
Also pushes that update to any local user or remote domain that shares
|
||||
a room with the local user.
|
||||
|
||||
Args:
|
||||
user(UserID): The local user whose presence was updated.
|
||||
statuscache(UserPresenceCache): Cache of the user's presence state
|
||||
Returns:
|
||||
A Deferred.
|
||||
"""
|
||||
assert(self.hs.is_mine(user))
|
||||
|
||||
logger.debug("Pushing presence update from %s", user)
|
||||
@@ -777,7 +617,8 @@ class PresenceHandler(BaseHandler):
|
||||
# and also user is informed of server-forced pushes
|
||||
localusers.add(user)
|
||||
|
||||
room_ids = yield self.get_joined_rooms_for_user(user)
|
||||
rm_handler = self.homeserver.get_handlers().room_member_handler
|
||||
room_ids = yield rm_handler.get_joined_rooms_for_user(user)
|
||||
|
||||
if not localusers and not room_ids:
|
||||
defer.returnValue(None)
|
||||
@@ -792,23 +633,44 @@ class PresenceHandler(BaseHandler):
|
||||
yield self.distributor.fire("user_presence_changed", user, statuscache)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def incoming_presence(self, origin, content):
|
||||
"""Handle an incoming m.presence EDU.
|
||||
For each presence update in the "push" list update our local cache and
|
||||
notify the appropriate local clients. Only clients that share a room
|
||||
or are directly subscribed to the presence for a user should be
|
||||
notified of the update.
|
||||
For each subscription request in the "poll" list start pushing presence
|
||||
updates to the remote server.
|
||||
For unsubscribe request in the "unpoll" list stop pushing presence
|
||||
updates to the remote server.
|
||||
def _push_presence_remote(self, user, destination, state=None):
|
||||
if state is None:
|
||||
state = yield self.store.get_presence_state(user.localpart)
|
||||
del state["mtime"]
|
||||
state["presence"] = state.pop("state")
|
||||
|
||||
Args:
|
||||
origin(str): The source of this m.presence EDU.
|
||||
content(dict): The content of this m.presence EDU.
|
||||
Returns:
|
||||
A Deferred.
|
||||
"""
|
||||
if user in self._user_cachemap:
|
||||
state["last_active"] = (
|
||||
self._user_cachemap[user].get_state()["last_active"]
|
||||
)
|
||||
|
||||
yield self.distributor.fire(
|
||||
"collect_presencelike_data", user, state
|
||||
)
|
||||
|
||||
if "last_active" in state:
|
||||
state = dict(state)
|
||||
state["last_active_ago"] = int(
|
||||
self.clock.time_msec() - state.pop("last_active")
|
||||
)
|
||||
|
||||
user_state = {
|
||||
"user_id": user.to_string(),
|
||||
}
|
||||
user_state.update(**state)
|
||||
|
||||
yield self.federation.send_edu(
|
||||
destination=destination,
|
||||
edu_type="m.presence",
|
||||
content={
|
||||
"push": [
|
||||
user_state,
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
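Before the m.presence EDU is sent, the absolute last_active timestamp is swapped for a relative last_active_ago, so the receiving homeserver does not need a synchronised wall clock. A minimal sketch of that conversion (the function name is hypothetical, not part of this diff):

import time

def presence_state_for_federation(state):
    # Work on a copy so the locally cached state keeps its absolute timestamp.
    out = dict(state)
    if "last_active" in out:
        now_ms = int(time.time() * 1000)
        out["last_active_ago"] = int(now_ms - out.pop("last_active"))
    return out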
@defer.inlineCallbacks
|
||||
def incoming_presence(self, origin, content):
|
||||
deferreds = []
|
||||
|
||||
for push in content.get("push", []):
|
||||
@@ -822,7 +684,8 @@ class PresenceHandler(BaseHandler):
|
||||
" | %d interested local observers %r", len(observers), observers
|
||||
)
|
||||
|
||||
room_ids = yield self.get_joined_rooms_for_user(user)
|
||||
rm_handler = self.homeserver.get_handlers().room_member_handler
|
||||
room_ids = yield rm_handler.get_joined_rooms_for_user(user)
|
||||
if room_ids:
|
||||
logger.debug(" | %d interested room IDs %r", len(room_ids), room_ids)
|
||||
|
||||
@@ -841,15 +704,20 @@ class PresenceHandler(BaseHandler):
|
||||
self.clock.time_msec() - state.pop("last_active_ago")
|
||||
)
|
||||
|
||||
statuscache = self._get_or_make_usercache(user)
|
||||
|
||||
self._user_cachemap_latest_serial += 1
|
||||
yield self.update_presence_cache(user, state, room_ids=room_ids)
|
||||
statuscache.update(state, serial=self._user_cachemap_latest_serial)
|
||||
|
||||
if not observers and not room_ids:
|
||||
logger.debug(" | no interested observers or room IDs")
|
||||
continue
|
||||
|
||||
self.push_update_to_clients(
|
||||
users_to_push=observers, room_ids=room_ids
|
||||
observed_user=user,
|
||||
users_to_push=observers,
|
||||
room_ids=room_ids,
|
||||
statuscache=statuscache,
|
||||
)
|
||||
|
||||
user_id = user.to_string()
|
||||
@@ -898,58 +766,13 @@ class PresenceHandler(BaseHandler):
|
||||
if not self._remote_sendmap[user]:
|
||||
del self._remote_sendmap[user]
|
||||
|
||||
yield defer.DeferredList(deferreds, consumeErrors=True)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_presence_cache(self, user, state={}, room_ids=None,
|
||||
add_to_cache=True):
|
||||
"""Update the presence cache for a user with a new state and bump the
|
||||
serial to the latest value.
|
||||
|
||||
Args:
|
||||
user(UserID): The user being updated
|
||||
state(dict): The presence state being updated
|
||||
room_ids(None or list of str): A list of room_ids to update. If
|
||||
room_ids is None then fetch the list of room_ids the user is
|
||||
joined to.
|
||||
add_to_cache: Whether to add an entry to the presence cache if the
|
||||
user isn't already in the cache.
|
||||
Returns:
|
||||
A Deferred UserPresenceCache for the user being updated.
|
||||
"""
|
||||
if room_ids is None:
|
||||
room_ids = yield self.get_joined_rooms_for_user(user)
|
||||
|
||||
for room_id in room_ids:
|
||||
self._room_serials[room_id] = self._user_cachemap_latest_serial
|
||||
if add_to_cache:
|
||||
statuscache = self._get_or_make_usercache(user)
|
||||
else:
|
||||
statuscache = self._get_or_offline_usercache(user)
|
||||
statuscache.update(state, serial=self._user_cachemap_latest_serial)
|
||||
defer.returnValue(statuscache)
|
||||
with PreserveLoggingContext():
|
||||
yield defer.DeferredList(deferreds, consumeErrors=True)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def push_update_to_local_and_remote(self, observed_user, statuscache,
|
||||
users_to_push=[], room_ids=[],
|
||||
remote_domains=[]):
|
||||
"""Notify local clients and remote servers of a change in the presence
|
||||
of a user.
|
||||
|
||||
Args:
|
||||
observed_user(UserID): The user to push the presence state for.
|
||||
statuscache(UserPresenceCache): The cache for the presence state to
|
||||
push.
|
||||
users_to_push([UserID]): A list of local and remote users to
|
||||
notify.
|
||||
room_ids([str]): Notify the local and remote occupants of these
|
||||
rooms.
|
||||
remote_domains([str]): A list of remote servers to notify in
|
||||
addition to those implied by the users_to_push and the
|
||||
room_ids.
|
||||
Returns:
|
||||
A Deferred.
|
||||
"""
|
||||
|
||||
localusers, remoteusers = partitionbool(
|
||||
users_to_push,
|
||||
@@ -959,7 +782,10 @@ class PresenceHandler(BaseHandler):
|
||||
localusers = set(localusers)
|
||||
|
||||
self.push_update_to_clients(
|
||||
users_to_push=localusers, room_ids=room_ids
|
||||
observed_user=observed_user,
|
||||
users_to_push=localusers,
|
||||
room_ids=room_ids,
|
||||
statuscache=statuscache,
|
||||
)
|
||||
|
||||
remote_domains = set(remote_domains)
|
||||
@@ -984,65 +810,11 @@ class PresenceHandler(BaseHandler):
|
||||
|
||||
defer.returnValue((localusers, remote_domains))
|
||||
|
||||
def push_update_to_clients(self, users_to_push=[], room_ids=[]):
|
||||
"""Notify clients of a new presence event.
|
||||
|
||||
Args:
|
||||
users_to_push([UserID]): List of users to notify.
|
||||
room_ids([str]): List of room_ids to notify.
|
||||
"""
|
||||
with PreserveLoggingContext():
|
||||
self.notifier.on_new_user_event(
|
||||
"presence_key",
|
||||
self._user_cachemap_latest_serial,
|
||||
users_to_push,
|
||||
room_ids,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _push_presence_remote(self, user, destination, state=None):
|
||||
"""Push a user's presence to a remote server. If a presence state event
|
||||
that event is sent. Otherwise a new state event is constructed from the
|
||||
stored presence state.
|
||||
The last_active is replaced with last_active_ago in case the wallclock
|
||||
time on the remote server is different to the time on this server.
|
||||
Sends an EDU to the remote server with the current presence state.
|
||||
|
||||
Args:
|
||||
user(UserID): The user to push the presence state for.
|
||||
destination(str): The remote server to send state to.
|
||||
state(dict): The state to push, or None to use the current stored
|
||||
state.
|
||||
Returns:
|
||||
A Deferred.
|
||||
"""
|
||||
if state is None:
|
||||
state = yield self.store.get_presence_state(user.localpart)
|
||||
del state["mtime"]
|
||||
state["presence"] = state.pop("state")
|
||||
|
||||
if user in self._user_cachemap:
|
||||
state["last_active"] = (
|
||||
self._user_cachemap[user].get_state()["last_active"]
|
||||
)
|
||||
|
||||
yield self.distributor.fire(
|
||||
"collect_presencelike_data", user, state
|
||||
)
|
||||
|
||||
if "last_active" in state:
|
||||
state = dict(state)
|
||||
state["last_active_ago"] = int(
|
||||
self.clock.time_msec() - state.pop("last_active")
|
||||
)
|
||||
|
||||
user_state = {"user_id": user.to_string(), }
|
||||
user_state.update(state)
|
||||
|
||||
yield self.federation.send_edu(
|
||||
destination=destination,
|
||||
edu_type="m.presence",
|
||||
content={"push": [user_state, ], }
|
||||
def push_update_to_clients(self, observed_user, users_to_push=[],
|
||||
room_ids=[], statuscache=None):
|
||||
self.notifier.on_new_user_event(
|
||||
users_to_push,
|
||||
room_ids,
|
||||
)
|
||||
|
||||
|
||||
@@ -1051,11 +823,39 @@ class PresenceEventSource(object):
|
||||
self.hs = hs
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def is_visible(self, observer_user, observed_user):
|
||||
if observer_user == observed_user:
|
||||
defer.returnValue(True)
|
||||
|
||||
presence = self.hs.get_handlers().presence_handler
|
||||
|
||||
if (yield presence.store.user_rooms_intersect(
|
||||
[u.to_string() for u in observer_user, observed_user])):
|
||||
defer.returnValue(True)
|
||||
|
||||
if self.hs.is_mine(observed_user):
|
||||
pushmap = presence._local_pushmap
|
||||
|
||||
defer.returnValue(
|
||||
observed_user.localpart in pushmap and
|
||||
observer_user in pushmap[observed_user.localpart]
|
||||
)
|
||||
else:
|
||||
recvmap = presence._remote_recvmap
|
||||
|
||||
defer.returnValue(
|
||||
observed_user in recvmap and
|
||||
observer_user in recvmap[observed_user]
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def get_new_events_for_user(self, user, from_key, limit):
|
||||
from_key = int(from_key)
|
||||
|
||||
observer_user = user
|
||||
|
||||
presence = self.hs.get_handlers().presence_handler
|
||||
cachemap = presence._user_cachemap
|
||||
|
||||
@@ -1064,27 +864,17 @@ class PresenceEventSource(object):
|
||||
clock = self.clock
|
||||
latest_serial = 0
|
||||
|
||||
user_ids_to_check = {user}
|
||||
presence_list = yield presence.store.get_presence_list(
|
||||
user.localpart, accepted=True
|
||||
)
|
||||
if presence_list is not None:
|
||||
user_ids_to_check |= set(
|
||||
UserID.from_string(p["observed_user_id"]) for p in presence_list
|
||||
)
|
||||
room_ids = yield presence.get_joined_rooms_for_user(user)
|
||||
for room_id in set(room_ids) & set(presence._room_serials):
|
||||
if presence._room_serials[room_id] > from_key:
|
||||
joined = yield presence.get_joined_users_for_room_id(room_id)
|
||||
user_ids_to_check |= set(joined)
|
||||
|
||||
updates = []
|
||||
for observed_user in user_ids_to_check & set(cachemap):
|
||||
# TODO(paul): use a DeferredList ? How to limit concurrency.
|
||||
for observed_user in cachemap.keys():
|
||||
cached = cachemap[observed_user]
|
||||
|
||||
if cached.serial <= from_key or cached.serial > max_serial:
|
||||
continue
|
||||
|
||||
if not (yield self.is_visible(observer_user, observed_user)):
|
||||
continue
|
||||
|
||||
latest_serial = max(cached.serial, latest_serial)
|
||||
updates.append(cached.make_event(user=observed_user, clock=clock))
|
||||
|
||||
@@ -1121,6 +911,8 @@ class PresenceEventSource(object):
|
||||
def get_pagination_rows(self, user, pagination_config, key):
|
||||
# TODO (erikj): Does this make sense? Ordering?
|
||||
|
||||
observer_user = user
|
||||
|
||||
from_key = int(pagination_config.from_key)
|
||||
|
||||
if pagination_config.to_key:
|
||||
@@ -1131,26 +923,14 @@ class PresenceEventSource(object):
|
||||
presence = self.hs.get_handlers().presence_handler
|
||||
cachemap = presence._user_cachemap
|
||||
|
||||
user_ids_to_check = {user}
|
||||
presence_list = yield presence.store.get_presence_list(
|
||||
user.localpart, accepted=True
|
||||
)
|
||||
if presence_list is not None:
|
||||
user_ids_to_check |= set(
|
||||
UserID.from_string(p["observed_user_id"]) for p in presence_list
|
||||
)
|
||||
room_ids = yield presence.get_joined_rooms_for_user(user)
|
||||
for room_id in set(room_ids) & set(presence._room_serials):
|
||||
if presence._room_serials[room_id] >= from_key:
|
||||
joined = yield presence.get_joined_users_for_room_id(room_id)
|
||||
user_ids_to_check |= set(joined)
|
||||
|
||||
updates = []
|
||||
for observed_user in user_ids_to_check & set(cachemap):
|
||||
# TODO(paul): use a DeferredList ? How to limit concurrency.
|
||||
for observed_user in cachemap.keys():
|
||||
if not (to_key < cachemap[observed_user].serial <= from_key):
|
||||
continue
|
||||
|
||||
updates.append((observed_user, cachemap[observed_user]))
|
||||
if (yield self.is_visible(observer_user, observed_user)):
|
||||
updates.append((observed_user, cachemap[observed_user]))
|
||||
|
||||
# TODO(paul): limit
|
||||
|
||||
|
||||
@@ -17,8 +17,8 @@ from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
from synapse.types import UserID
|
||||
from synapse.util import unwrapFirstError
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
@@ -88,9 +88,6 @@ class ProfileHandler(BaseHandler):
|
||||
if target_user != auth_user:
|
||||
raise AuthError(400, "Cannot set another user's displayname")
|
||||
|
||||
if new_displayname == '':
|
||||
new_displayname = None
|
||||
|
||||
yield self.store.set_profile_displayname(
|
||||
target_user.localpart, new_displayname
|
||||
)
|
||||
@@ -157,13 +154,14 @@ class ProfileHandler(BaseHandler):
|
||||
if not self.hs.is_mine(user):
|
||||
defer.returnValue(None)
|
||||
|
||||
(displayname, avatar_url) = yield defer.gatherResults(
|
||||
[
|
||||
self.store.get_profile_displayname(user.localpart),
|
||||
self.store.get_profile_avatar_url(user.localpart),
|
||||
],
|
||||
consumeErrors=True
|
||||
).addErrback(unwrapFirstError)
|
||||
with PreserveLoggingContext():
|
||||
(displayname, avatar_url) = yield defer.gatherResults(
|
||||
[
|
||||
self.store.get_profile_displayname(user.localpart),
|
||||
self.store.get_profile_avatar_url(user.localpart),
|
||||
],
|
||||
consumeErrors=True
|
||||
)
|
||||
|
||||
state["displayname"] = displayname
|
||||
state["avatar_url"] = avatar_url
|
||||
|
||||
@@ -21,12 +21,11 @@ from ._base import BaseHandler
|
||||
from synapse.types import UserID, RoomAlias, RoomID
|
||||
from synapse.api.constants import EventTypes, Membership, JoinRules
|
||||
from synapse.api.errors import StoreError, SynapseError
|
||||
from synapse.util import stringutils, unwrapFirstError
|
||||
from synapse.util import stringutils
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.events.utils import serialize_event
|
||||
|
||||
import logging
|
||||
import string
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -51,10 +50,6 @@ class RoomCreationHandler(BaseHandler):
|
||||
self.ratelimit(user_id)
|
||||
|
||||
if "room_alias_name" in config:
|
||||
for wchar in string.whitespace:
|
||||
if wchar in config["room_alias_name"]:
|
||||
raise SynapseError(400, "Invalid characters in room alias")
|
||||
|
||||
room_alias = RoomAlias.create(
|
||||
config["room_alias_name"],
|
||||
self.hs.hostname,
|
||||
@@ -534,17 +529,11 @@ class RoomListHandler(BaseHandler):
|
||||
@defer.inlineCallbacks
|
||||
def get_public_room_list(self):
|
||||
chunk = yield self.store.get_rooms(is_public=True)
|
||||
results = yield defer.gatherResults(
|
||||
[
|
||||
self.store.get_users_in_room(room["room_id"])
|
||||
for room in chunk
|
||||
],
|
||||
consumeErrors=True,
|
||||
).addErrback(unwrapFirstError)
|
||||
|
||||
for i, room in enumerate(chunk):
|
||||
room["num_joined_members"] = len(results[i])
|
||||
|
||||
for room in chunk:
|
||||
joined_users = yield self.store.get_users_in_room(
|
||||
room_id=room["room_id"],
|
||||
)
|
||||
room["num_joined_members"] = len(joined_users)
|
||||
# FIXME (erikj): START is no longer a valid value
|
||||
defer.returnValue({"start": "START", "end": "END", "chunk": chunk})
|
||||
|
||||
@@ -580,8 +569,8 @@ class RoomEventSource(object):
|
||||
|
||||
defer.returnValue((events, end_key))
|
||||
|
||||
def get_current_key(self, direction='f'):
|
||||
return self.store.get_room_events_max_id(direction)
|
||||
def get_current_key(self):
|
||||
return self.store.get_room_events_max_id()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_pagination_rows(self, user, config, key):
|
||||
|
||||
@@ -92,7 +92,7 @@ class SyncHandler(BaseHandler):
|
||||
result = yield self.current_sync_for_user(sync_config, since_token)
|
||||
defer.returnValue(result)
|
||||
else:
|
||||
def current_sync_callback(before_token, after_token):
|
||||
def current_sync_callback():
|
||||
return self.current_sync_for_user(sync_config, since_token)
|
||||
|
||||
rm_handler = self.hs.get_handlers().room_member_handler
|
||||
|
||||
@@ -18,7 +18,6 @@ from twisted.internet import defer
|
||||
from ._base import BaseHandler
|
||||
|
||||
from synapse.api.errors import SynapseError, AuthError
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
from synapse.types import UserID
|
||||
|
||||
import logging
|
||||
@@ -217,10 +216,7 @@ class TypingNotificationHandler(BaseHandler):
|
||||
self._latest_room_serial += 1
|
||||
self._room_serials[room_id] = self._latest_room_serial
|
||||
|
||||
with PreserveLoggingContext():
|
||||
self.notifier.on_new_user_event(
|
||||
"typing_key", self._latest_room_serial, rooms=[room_id]
|
||||
)
|
||||
self.notifier.on_new_user_event(rooms=[room_id])
|
||||
|
||||
|
||||
class TypingNotificationEventSource(object):
|
||||
|
||||
@@ -14,14 +14,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.api.errors import CodeMessageException
|
||||
from synapse.util.logcontext import preserve_context_over_fn
|
||||
from syutil.jsonutil import encode_canonical_json
|
||||
import synapse.metrics
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.web.client import (
|
||||
Agent, readBody, FileBodyProducer, PartialDownloadError,
|
||||
HTTPConnectionPool,
|
||||
Agent, readBody, FileBodyProducer, PartialDownloadError
|
||||
)
|
||||
from twisted.web.http_headers import Headers
|
||||
|
||||
@@ -56,19 +54,14 @@ class SimpleHttpClient(object):
|
||||
# The default context factory in Twisted 14.0.0 (which we require) is
|
||||
# BrowserLikePolicyForHTTPS which will do regular cert validation
|
||||
# 'like a browser'
|
||||
pool = HTTPConnectionPool(reactor)
|
||||
pool.maxPersistentPerHost = 10
|
||||
self.agent = Agent(reactor, pool=pool)
|
||||
self.agent = Agent(reactor)
|
||||
self.version_string = hs.version_string
|
||||
|
||||
def request(self, method, *args, **kwargs):
|
||||
# A small wrapper around self.agent.request() so we can easily attach
|
||||
# counters to it
|
||||
outgoing_requests_counter.inc(method)
|
||||
d = preserve_context_over_fn(
|
||||
self.agent.request,
|
||||
method, *args, **kwargs
|
||||
)
|
||||
d = self.agent.request(method, *args, **kwargs)
|
||||
|
||||
def _cb(response):
|
||||
incoming_responses_counter.inc(method, response.code)
|
||||
|
||||
@@ -16,13 +16,13 @@
|
||||
|
||||
from twisted.internet import defer, reactor, protocol
|
||||
from twisted.internet.error import DNSLookupError
|
||||
from twisted.web.client import readBody, _AgentBase, _URI, HTTPConnectionPool
|
||||
from twisted.web.client import readBody, _AgentBase, _URI
|
||||
from twisted.web.http_headers import Headers
|
||||
from twisted.web._newclient import ResponseDone
|
||||
|
||||
from synapse.http.endpoint import matrix_federation_endpoint
|
||||
from synapse.util.async import sleep
|
||||
from synapse.util.logcontext import preserve_context_over_fn
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
import synapse.metrics
|
||||
|
||||
from syutil.jsonutil import encode_canonical_json
|
||||
@@ -103,17 +103,14 @@ class MatrixFederationHttpClient(object):
|
||||
self.hs = hs
|
||||
self.signing_key = hs.config.signing_key[0]
|
||||
self.server_name = hs.hostname
|
||||
pool = HTTPConnectionPool(reactor)
|
||||
pool.maxPersistentPerHost = 10
|
||||
self.agent = MatrixFederationHttpAgent(reactor, pool=pool)
|
||||
self.agent = MatrixFederationHttpAgent(reactor)
|
||||
self.clock = hs.get_clock()
|
||||
self.version_string = hs.version_string
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _create_request(self, destination, method, path_bytes,
|
||||
body_callback, headers_dict={}, param_bytes=b"",
|
||||
query_bytes=b"", retry_on_dns_fail=True,
|
||||
timeout=None):
|
||||
query_bytes=b"", retry_on_dns_fail=True):
|
||||
""" Creates and sends a request to the given url
|
||||
"""
|
||||
headers_dict[b"User-Agent"] = [self.version_string]
|
||||
@@ -147,22 +144,22 @@ class MatrixFederationHttpClient(object):
|
||||
producer = body_callback(method, url_bytes, headers_dict)
|
||||
|
||||
try:
|
||||
request_deferred = preserve_context_over_fn(
|
||||
self.agent.request,
|
||||
destination,
|
||||
endpoint,
|
||||
method,
|
||||
path_bytes,
|
||||
param_bytes,
|
||||
query_bytes,
|
||||
Headers(headers_dict),
|
||||
producer
|
||||
)
|
||||
with PreserveLoggingContext():
|
||||
request_deferred = self.agent.request(
|
||||
destination,
|
||||
endpoint,
|
||||
method,
|
||||
path_bytes,
|
||||
param_bytes,
|
||||
query_bytes,
|
||||
Headers(headers_dict),
|
||||
producer
|
||||
)
|
||||
|
||||
response = yield self.clock.time_bound_deferred(
|
||||
request_deferred,
|
||||
time_out=timeout/1000. if timeout else 60,
|
||||
)
|
||||
response = yield self.clock.time_bound_deferred(
|
||||
request_deferred,
|
||||
time_out=60,
|
||||
)
|
||||
|
||||
logger.debug("Got response to %s", method)
|
||||
break
|
||||
@@ -184,7 +181,7 @@ class MatrixFederationHttpClient(object):
|
||||
_flatten_response_never_received(e),
|
||||
)
|
||||
|
||||
if retries_left and not timeout:
|
||||
if retries_left:
|
||||
yield sleep(2 ** (5 - retries_left))
|
||||
retries_left -= 1
|
||||
else:
|
||||
@@ -337,8 +334,7 @@ class MatrixFederationHttpClient(object):
|
||||
defer.returnValue(json.loads(body))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_json(self, destination, path, args={}, retry_on_dns_fail=True,
|
||||
timeout=None):
|
||||
def get_json(self, destination, path, args={}, retry_on_dns_fail=True):
|
||||
""" GETs some json from the given host homeserver and path
|
||||
|
||||
Args:
|
||||
@@ -347,9 +343,6 @@ class MatrixFederationHttpClient(object):
|
||||
path (str): The HTTP path.
|
||||
args (dict): A dictionary used to create query strings, defaults to
|
||||
None.
|
||||
timeout (int): How long to try (in ms) the destination for before
|
||||
giving up. None indicates no timeout and that the request will
|
||||
be retried.
|
||||
Returns:
|
||||
Deferred: Succeeds when we get *any* HTTP response.
|
||||
|
||||
@@ -377,8 +370,7 @@ class MatrixFederationHttpClient(object):
|
||||
path.encode("ascii"),
|
||||
query_bytes=query_bytes,
|
||||
body_callback=body_callback,
|
||||
retry_on_dns_fail=retry_on_dns_fail,
|
||||
timeout=timeout,
|
||||
retry_on_dns_fail=retry_on_dns_fail
|
||||
)
|
||||
|
||||
if 200 <= response.code < 300:
|
||||
|
||||
@@ -17,12 +17,11 @@
|
||||
from synapse.api.errors import (
|
||||
cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError
|
||||
)
|
||||
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
import synapse.metrics
|
||||
import synapse.events
|
||||
|
||||
from syutil.jsonutil import (
|
||||
encode_canonical_json, encode_pretty_printed_json, encode_json
|
||||
encode_canonical_json, encode_pretty_printed_json
|
||||
)
|
||||
|
||||
from twisted.internet import defer
|
||||
@@ -86,9 +85,7 @@ def request_handler(request_handler):
|
||||
"Received request: %s %s",
|
||||
request.method, request.path
|
||||
)
|
||||
d = request_handler(self, request)
|
||||
with PreserveLoggingContext():
|
||||
yield d
|
||||
yield request_handler(self, request)
|
||||
code = request.code
|
||||
except CodeMessageException as e:
|
||||
code = e.code
|
||||
@@ -169,10 +166,9 @@ class JsonResource(HttpServer, resource.Resource):
|
||||
|
||||
_PathEntry = collections.namedtuple("_PathEntry", ["pattern", "callback"])
|
||||
|
||||
def __init__(self, hs, canonical_json=True):
|
||||
def __init__(self, hs):
|
||||
resource.Resource.__init__(self)
|
||||
|
||||
self.canonical_json = canonical_json
|
||||
self.clock = hs.get_clock()
|
||||
self.path_regexs = {}
|
||||
self.version_string = hs.version_string
|
||||
@@ -258,7 +254,6 @@ class JsonResource(HttpServer, resource.Resource):
|
||||
response_code_message=response_code_message,
|
||||
pretty_print=_request_user_agent_is_curl(request),
|
||||
version_string=self.version_string,
|
||||
canonical_json=self.canonical_json,
|
||||
)
|
||||
|
||||
|
||||
@@ -280,16 +275,11 @@ class RootRedirect(resource.Resource):
|
||||
|
||||
def respond_with_json(request, code, json_object, send_cors=False,
|
||||
response_code_message=None, pretty_print=False,
|
||||
version_string="", canonical_json=True):
|
||||
version_string=""):
|
||||
if pretty_print:
|
||||
json_bytes = encode_pretty_printed_json(json_object) + "\n"
|
||||
else:
|
||||
if canonical_json:
|
||||
json_bytes = encode_canonical_json(json_object)
|
||||
else:
|
||||
json_bytes = encode_json(
|
||||
json_object, using_frozen_dicts=synapse.events.USE_FROZEN_DICTS
|
||||
)
|
||||
json_bytes = encode_canonical_json(json_object)
|
||||
|
||||
return respond_with_json_bytes(
|
||||
request, code, json_bytes,
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.logcontext import PreserveLoggingContext
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.types import StreamToken
|
||||
import synapse.metrics
|
||||
@@ -42,78 +43,63 @@ def count(func, l):
|
||||
|
||||
class _NotificationListener(object):
|
||||
""" This represents a single client connection to the events stream.
|
||||
|
||||
The events stream handler will have yielded to the deferred, so to
|
||||
notify the handler it is sufficient to resolve the deferred.
|
||||
"""
|
||||
|
||||
def __init__(self, deferred):
|
||||
self.deferred = deferred
|
||||
|
||||
def notified(self):
|
||||
return self.deferred.called
|
||||
|
||||
def notify(self, token):
|
||||
""" Inform whoever is listening about the new events.
|
||||
"""
|
||||
try:
|
||||
self.deferred.callback(token)
|
||||
except defer.AlreadyCalledError:
|
||||
pass
|
||||
|
||||
|
||||
class _NotifierUserStream(object):
|
||||
"""This represents a user connected to the event stream.
|
||||
It tracks the most recent stream token for that user.
|
||||
At a given point a user may have a number of streams listening for
|
||||
events.
|
||||
|
||||
This listener will also keep track of which rooms it is listening in
|
||||
so that it can remove itself from the indexes in the Notifier class.
|
||||
"""
|
||||
|
||||
def __init__(self, user, rooms, current_token, time_now_ms,
|
||||
def __init__(self, user, rooms, from_token, limit, timeout, deferred,
|
||||
appservice=None):
|
||||
self.user = str(user)
|
||||
self.user = user
|
||||
self.appservice = appservice
|
||||
self.listeners = set()
|
||||
self.rooms = set(rooms)
|
||||
self.current_token = current_token
|
||||
self.last_notified_ms = time_now_ms
|
||||
self.from_token = from_token
|
||||
self.limit = limit
|
||||
self.timeout = timeout
|
||||
self.deferred = deferred
|
||||
self.rooms = rooms
|
||||
self.timer = None
|
||||
|
||||
def notify(self, stream_key, stream_id, time_now_ms):
|
||||
"""Notify any listeners for this user of a new event from an
|
||||
event source.
|
||||
Args:
|
||||
stream_key(str): The stream the event came from.
|
||||
stream_id(str): The new id for the stream the event came from.
|
||||
time_now_ms(int): The current time in milliseconds.
|
||||
"""
|
||||
self.current_token = self.current_token.copy_and_advance(
|
||||
stream_key, stream_id
|
||||
)
|
||||
if self.listeners:
|
||||
self.last_notified_ms = time_now_ms
|
||||
listeners = self.listeners
|
||||
self.listeners = set()
|
||||
for listener in listeners:
|
||||
listener.notify(self.current_token)
|
||||
def notified(self):
|
||||
return self.deferred.called
|
||||
|
||||
def remove(self, notifier):
|
||||
""" Remove this listener from all the indexes in the Notifier
|
||||
def notify(self, notifier, events, start_token, end_token):
|
||||
""" Inform whoever is listening about the new events. This will
|
||||
also remove this listener from all the indexes in the Notifier
|
||||
it knows about.
|
||||
"""
|
||||
|
||||
result = (events, (start_token, end_token))
|
||||
|
||||
try:
|
||||
self.deferred.callback(result)
|
||||
notified_events_counter.inc_by(len(events))
|
||||
except defer.AlreadyCalledError:
|
||||
pass
|
||||
|
||||
# Should the following be done using intrusively linked lists?
|
||||
# -- erikj
|
||||
|
||||
for room in self.rooms:
|
||||
lst = notifier.room_to_user_streams.get(room, set())
|
||||
lst = notifier.room_to_listeners.get(room, set())
|
||||
lst.discard(self)
|
||||
|
||||
notifier.user_to_user_stream.pop(self.user)
|
||||
notifier.user_to_listeners.get(self.user, set()).discard(self)
|
||||
|
||||
if self.appservice:
|
||||
notifier.appservice_to_user_streams.get(
|
||||
notifier.appservice_to_listeners.get(
|
||||
self.appservice, set()
|
||||
).discard(self)
|
||||
|
||||
# Cancel the timeout for this notifier if one exists.
|
||||
if self.timer is not None:
|
||||
try:
|
||||
notifier.clock.cancel_call_later(self.timer)
|
||||
except:
|
||||
logger.warn("Failed to cancel notifier timer")
|
||||
|
||||
|
||||
class Notifier(object):
    """ This class is responsible for notifying any listeners when there are
@@ -122,18 +108,14 @@ class Notifier(object):
    Primarily used from the /events stream.
    """

    UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000

    def __init__(self, hs):
        self.hs = hs

        self.user_to_user_stream = {}
        self.room_to_user_streams = {}
        self.appservice_to_user_streams = {}
        self.room_to_listeners = {}
        self.user_to_listeners = {}
        self.appservice_to_listeners = {}

        self.event_sources = hs.get_event_sources()
        self.store = hs.get_datastore()
        self.pending_new_room_events = []

        self.clock = hs.get_clock()

@@ -141,80 +123,47 @@ class Notifier(object):
            "user_joined_room", self._user_joined_room
        )

        self.clock.looping_call(
            self.remove_expired_streams, self.UNUSED_STREAM_EXPIRY_MS
        )

        # This is not a very cheap test to perform, but it's only executed
        # when rendering the metrics page, which is likely once per minute at
        # most when scraping it.
        def count_listeners():
            all_user_streams = set()
            all_listeners = set()

            for x in self.room_to_user_streams.values():
                all_user_streams |= x
            for x in self.user_to_user_stream.values():
                all_user_streams.add(x)
            for x in self.appservice_to_user_streams.values():
                all_user_streams |= x
            for x in self.room_to_listeners.values():
                all_listeners |= x
            for x in self.user_to_listeners.values():
                all_listeners |= x
            for x in self.appservice_to_listeners.values():
                all_listeners |= x

            return sum(len(stream.listeners) for stream in all_user_streams)
            return len(all_listeners)
        metrics.register_callback("listeners", count_listeners)

        metrics.register_callback(
            "rooms",
            lambda: count(bool, self.room_to_user_streams.values()),
            lambda: count(bool, self.room_to_listeners.values()),
        )
        metrics.register_callback(
            "users",
            lambda: len(self.user_to_user_stream),
            lambda: count(bool, self.user_to_listeners.values()),
        )
        metrics.register_callback(
            "appservices",
            lambda: count(bool, self.appservice_to_user_streams.values()),
            lambda: count(bool, self.appservice_to_listeners.values()),
        )

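The ``count_listeners`` closure and the ``metrics.register_callback`` calls
above expose gauge-style metrics that are computed lazily, only when the metrics
page is rendered. A minimal stand-alone sketch of that pattern, where the
``register_callback`` registry is a stand-in rather than the real
``synapse.metrics`` API::

    # Gauge-style callbacks over the notifier's index maps (illustrative only).
    _callbacks = {}

    def register_callback(name, func):
        _callbacks[name] = func

    def render_metrics():
        # Each gauge is evaluated only at render time.
        return dict((name, func()) for name, func in _callbacks.items())

    room_to_listeners = {"!a:hs": set(["l1", "l2"]), "!b:hs": set()}
    user_to_listeners = {"@u:hs": set(["l1"])}

    register_callback(
        "listeners",
        lambda: len(set().union(*room_to_listeners.values())
                    | set().union(*user_to_listeners.values())),
    )
    # count(bool, ...) in the diff counts non-empty sets, i.e. rooms with at
    # least one active listener.
    register_callback(
        "rooms",
        lambda: sum(1 for s in room_to_listeners.values() if s),
    )

    print(render_metrics())  # {'listeners': 2, 'rooms': 1}
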
    @log_function
    @defer.inlineCallbacks
    def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
                          extra_users=[]):
    def on_new_room_event(self, event, extra_users=[]):
        """ Used by handlers to inform the notifier something has happened
        in the room, room event wise.

        This triggers the notifier to wake up any listeners that are
        listening to the room, and any listeners for the users in the
        `extra_users` param.

        The events can be peristed out of order. The notifier will wait
        until all previous events have been persisted before notifying
        the client streams.
        """
        yield run_on_reactor()

        self.pending_new_room_events.append((
            room_stream_id, event, extra_users
        ))
        self._notify_pending_new_room_events(max_room_stream_id)

    def _notify_pending_new_room_events(self, max_room_stream_id):
        """Notify for the room events that were queued waiting for a previous
        event to be persisted.
        Args:
            max_room_stream_id(int): The highest stream_id below which all
                events have been persisted.
        """
        pending = self.pending_new_room_events
        self.pending_new_room_events = []
        for room_stream_id, event, extra_users in pending:
            if room_stream_id > max_room_stream_id:
                self.pending_new_room_events.append((
                    room_stream_id, event, extra_users
                ))
            else:
                self._on_new_room_event(event, room_stream_id, extra_users)

    def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
        """Notify any user streams that are interested in this room event"""
        # poke any interested application service.
        self.hs.get_handlers().appservice_handler.notify_interested_services(
            event
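``on_new_room_event`` buffers events that are persisted out of order and only
notifies once everything up to ``max_room_stream_id`` has been persisted. A
simplified, synchronous sketch of that buffering, with invented function
names::

    # Events queued until the persistence high-water mark catches up.
    pending = []

    def on_new_room_event(event, room_stream_id, max_persisted_id, notify):
        pending.append((room_stream_id, event))
        flush(max_persisted_id, notify)

    def flush(max_persisted_id, notify):
        global pending
        ready = [(sid, ev) for sid, ev in pending if sid <= max_persisted_id]
        pending = [(sid, ev) for sid, ev in pending if sid > max_persisted_id]
        for sid, ev in sorted(ready):
            notify(ev, sid)

    seen = []
    # Event 7 arrives while only everything up to 5 has been persisted, so it
    # is held back until a later call raises the high-water mark.
    on_new_room_event({"id": "e7"}, 7, 5, lambda ev, sid: seen.append(sid))
    on_new_room_event({"id": "e6"}, 6, 7, lambda ev, sid: seen.append(sid))
    assert seen == [6, 7]
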
@@ -222,129 +171,194 @@ class Notifier(object):

        room_id = event.room_id

        room_user_streams = self.room_to_user_streams.get(room_id, set())
        room_source = self.event_sources.sources["room"]

        user_streams = room_user_streams.copy()
        room_listeners = self.room_to_listeners.get(room_id, set())

        _discard_if_notified(room_listeners)

        listeners = room_listeners.copy()

        for user in extra_users:
            user_stream = self.user_to_user_stream.get(str(user))
            if user_stream is not None:
                user_streams.add(user_stream)
            user_listeners = self.user_to_listeners.get(user, set())

        for appservice in self.appservice_to_user_streams:
            _discard_if_notified(user_listeners)

            listeners |= user_listeners

        for appservice in self.appservice_to_listeners:
            # TODO (kegan): Redundant appservice listener checks?
            # App services will already be in the room_to_user_streams set, but
            # App services will already be in the room_to_listeners set, but
            # that isn't enough. They need to be checked here in order to
            # receive *invites* for users they are interested in. Does this
            # make the room_to_user_streams check somewhat obselete?
            # make the room_to_listeners check somewhat obselete?
            if appservice.is_interested(event):
                app_user_streams = self.appservice_to_user_streams.get(
                app_listeners = self.appservice_to_listeners.get(
                    appservice, set()
                )
                user_streams |= app_user_streams

        logger.debug("on_new_room_event listeners %s", user_streams)
                _discard_if_notified(app_listeners)

        time_now_ms = self.clock.time_msec()
        for user_stream in user_streams:
            try:
                user_stream.notify(
                    "room_key", "s%d" % (room_stream_id,), time_now_ms
                listeners |= app_listeners

        logger.debug("on_new_room_event listeners %s", listeners)

        # TODO (erikj): Can we make this more efficient by hitting the
        # db once?

        @defer.inlineCallbacks
        def notify(listener):
            events, end_key = yield room_source.get_new_events_for_user(
                listener.user,
                listener.from_token.room_key,
                listener.limit,
            )

            if events:
                end_token = listener.from_token.copy_and_replace(
                    "room_key", end_key
                )
            except:
                logger.exception("Failed to notify listener")

                listener.notify(
                    self, events, listener.from_token, end_token
                )

        def eb(failure):
            logger.exception("Failed to notify listener", failure)

        with PreserveLoggingContext():
            yield defer.DeferredList(
                [notify(l).addErrback(eb) for l in listeners],
                consumeErrors=True,
            )

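The fan-out at the end of ``_on_new_room_event`` notifies each listener
independently and logs failures instead of letting one bad listener break the
rest. A rough Twisted sketch of that shape; the ``listener`` objects are assumed
to expose a Deferred-returning ``notify``, and none of this is the real Notifier
code::

    import logging

    from twisted.internet import defer

    logger = logging.getLogger(__name__)

    @defer.inlineCallbacks
    def notify_all(listeners, new_token):
        @defer.inlineCallbacks
        def notify(listener):
            # Assumed: listener.notify() returns a Deferred.
            yield listener.notify(new_token)

        def eb(failure):
            # Log and swallow so one failing listener does not affect others.
            logger.error("Failed to notify listener: %s", failure)

        yield defer.DeferredList(
            [notify(l).addErrback(eb) for l in listeners],
            consumeErrors=True,
        )
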
    @defer.inlineCallbacks
    @log_function
    def on_new_user_event(self, stream_key, new_token, users=[], rooms=[]):
    def on_new_user_event(self, users=[], rooms=[]):
        """ Used to inform listeners that something has happend
        presence/user event wise.

        Will wake up all listeners for the given users and rooms.
        """
        yield run_on_reactor()
        user_streams = set()

        # TODO(paul): This is horrible, having to manually list every event
        # source here individually
        presence_source = self.event_sources.sources["presence"]
        typing_source = self.event_sources.sources["typing"]

        listeners = set()

        for user in users:
            user_stream = self.user_to_user_stream.get(str(user))
            if user_stream is not None:
                user_streams.add(user_stream)
            user_listeners = self.user_to_listeners.get(user, set())

            _discard_if_notified(user_listeners)

            listeners |= user_listeners

        for room in rooms:
            user_streams |= self.room_to_user_streams.get(room, set())
            room_listeners = self.room_to_listeners.get(room, set())

        time_now_ms = self.clock.time_msec()
        for user_stream in user_streams:
            try:
                user_stream.notify(stream_key, new_token, time_now_ms)
            except:
                logger.exception("Failed to notify listener")
            _discard_if_notified(room_listeners)

            listeners |= room_listeners

        @defer.inlineCallbacks
        def notify(listener):
            presence_events, presence_end_key = (
                yield presence_source.get_new_events_for_user(
                    listener.user,
                    listener.from_token.presence_key,
                    listener.limit,
                )
            )
            typing_events, typing_end_key = (
                yield typing_source.get_new_events_for_user(
                    listener.user,
                    listener.from_token.typing_key,
                    listener.limit,
                )
            )

            if presence_events or typing_events:
                end_token = listener.from_token.copy_and_replace(
                    "presence_key", presence_end_key
                ).copy_and_replace(
                    "typing_key", typing_end_key
                )

                listener.notify(
                    self,
                    presence_events + typing_events,
                    listener.from_token,
                    end_token
                )

        def eb(failure):
            logger.error(
                "Failed to notify listener",
                exc_info=(
                    failure.type,
                    failure.value,
                    failure.getTracebackObject())
            )

        with PreserveLoggingContext():
            yield defer.DeferredList(
                [notify(l).addErrback(eb) for l in listeners],
                consumeErrors=True,
            )

    @defer.inlineCallbacks
    def wait_for_events(self, user, rooms, timeout, callback,
                        from_token=StreamToken("s0", "0", "0")):
    def wait_for_events(self, user, rooms, filter, timeout, callback):
        """Wait until the callback returns a non empty response or the
        timeout fires.
        """

        deferred = defer.Deferred()
        time_now_ms = self.clock.time_msec()

        user = str(user)
        user_stream = self.user_to_user_stream.get(user)
        if user_stream is None:
            appservice = yield self.store.get_app_service_by_user_id(user)
            current_token = yield self.event_sources.get_current_token()
            rooms = yield self.store.get_rooms_for_user(user)
            rooms = [room.room_id for room in rooms]
            user_stream = _NotifierUserStream(
                user=user,
                rooms=rooms,
                appservice=appservice,
                current_token=current_token,
                time_now_ms=time_now_ms,
            )
            self._register_with_keys(user_stream)
        else:
            current_token = user_stream.current_token
        from_token = StreamToken("s0", "0", "0")

        listener = [_NotificationListener(deferred)]
        listener = [_NotificationListener(
            user=user,
            rooms=rooms,
            from_token=from_token,
            limit=1,
            timeout=timeout,
            deferred=deferred,
        )]

        if timeout and not current_token.is_after(from_token):
            user_stream.listeners.add(listener[0])

        if current_token.is_after(from_token):
            result = yield callback(from_token, current_token)
        else:
            result = None
        if timeout:
            self._register_with_keys(listener[0])

        result = yield callback()
        timer = [None]

        if result:
            user_stream.listeners.discard(listener[0])
            defer.returnValue(result)
            return

        if timeout:
            timed_out = [False]

            def _timeout_listener():
                timed_out[0] = True
                timer[0] = None
                user_stream.listeners.discard(listener[0])
                listener[0].notify(current_token)
                listener[0].notify(self, [], from_token, from_token)

            # We create multiple notification listeners so we have to manage
            # canceling the timeout ourselves.
            timer[0] = self.clock.call_later(timeout/1000., _timeout_listener)

            while not result and not timed_out[0]:
                new_token = yield deferred
                yield deferred
                deferred = defer.Deferred()
                listener[0] = _NotificationListener(deferred)
                user_stream.listeners.add(listener[0])
                result = yield callback(current_token, new_token)
                current_token = new_token
                listener[0] = _NotificationListener(
                    user=user,
                    rooms=rooms,
                    from_token=from_token,
                    limit=1,
                    timeout=timeout,
                    deferred=deferred,
                )
                self._register_with_keys(listener[0])
                result = yield callback()

            if timer[0] is not None:
                try:
@@ -354,79 +368,125 @@ class Notifier(object):

        defer.returnValue(result)

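``wait_for_events`` parks the caller on a Deferred that is resolved either by a
new event or by a timeout timer, guarding against the race where both fire. A
stripped-down illustration of that Deferred/timer interplay, not the real method
signature::

    from twisted.internet import defer, reactor

    def wait_with_timeout(timeout_ms):
        """Return (deferred, notify): the deferred fires with a payload if
        notify() is called in time, or with None once the timeout expires."""
        d = defer.Deferred()

        def _timeout_listener():
            try:
                d.callback(None)           # empty result: the timeout won
            except defer.AlreadyCalledError:
                pass                       # an event got there first

        timer = reactor.callLater(timeout_ms / 1000.0, _timeout_listener)

        def notify(payload):
            if timer.active():
                timer.cancel()             # cancel the pending timeout
            try:
                d.callback(payload)
            except defer.AlreadyCalledError:
                pass                       # already timed out
            return d, notify
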
    @defer.inlineCallbacks
    def get_events_for(self, user, rooms, pagination_config, timeout):
        """ For the given user and rooms, return any new events for them. If
        there are no new events wait for up to `timeout` milliseconds for any
        new events to happen before returning.
        """
        from_token = pagination_config.from_token
        deferred = defer.Deferred()

        self._get_events(
            deferred, user, rooms, pagination_config.from_token,
            pagination_config.limit, timeout
        ).addErrback(deferred.errback)

        return deferred

    @defer.inlineCallbacks
    def _get_events(self, deferred, user, rooms, from_token, limit, timeout):
        if not from_token:
            from_token = yield self.event_sources.get_current_token()

        limit = pagination_config.limit

        @defer.inlineCallbacks
        def check_for_updates(before_token, after_token):
            events = []
            end_token = from_token
            for name, source in self.event_sources.sources.items():
                keyname = "%s_key" % name
                before_id = getattr(before_token, keyname)
                after_id = getattr(after_token, keyname)
                if before_id == after_id:
                    continue
                stuff, new_key = yield source.get_new_events_for_user(
                    user, getattr(from_token, keyname), limit,
                )
                events.extend(stuff)
                end_token = end_token.copy_and_replace(keyname, new_key)

            if events:
                defer.returnValue((events, (from_token, end_token)))
            else:
                defer.returnValue(None)

        result = yield self.wait_for_events(
            user, rooms, timeout, check_for_updates, from_token=from_token
        appservice = yield self.hs.get_datastore().get_app_service_by_user_id(
            user.to_string()
        )

        if result is None:
            result = ([], (from_token, from_token))
        listener = _NotificationListener(
            user,
            rooms,
            from_token,
            limit,
            timeout,
            deferred,
            appservice=appservice
        )

        defer.returnValue(result)
        def _timeout_listener():
            # TODO (erikj): We should probably set to_token to the current
            # max rather than reusing from_token.
            # Remove the timer from the listener so we don't try to cancel it.
            listener.timer = None
            listener.notify(
                self,
                [],
                listener.from_token,
                listener.from_token,
            )

        if timeout:
            self._register_with_keys(listener)

        yield self._check_for_updates(listener)

        if not timeout:
            _timeout_listener()
        else:
            # Only add the timer if the listener hasn't been notified
            if not listener.notified():
                listener.timer = self.clock.call_later(
                    timeout/1000.0, _timeout_listener
                )
        return

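``check_for_updates`` asks every event source for anything newer than the
corresponding key in ``from_token`` and stitches the results into a combined end
token. A simplified dictionary-based sketch; the source and token shapes below
are invented for the example::

    def check_for_updates(sources, from_token, limit):
        events = []
        end_token = dict(from_token)
        for name, source in sources.items():
            keyname = "%s_key" % name
            # Each source returns (new events, its new stream position).
            new_events, new_key = source(from_token[keyname], limit)
            events.extend(new_events)
            end_token[keyname] = new_key
        return (events, (from_token, end_token)) if events else None

    def fake_room_source(from_key, limit):
        return (["event-a", "event-b"][:limit], "s10")

    result = check_for_updates(
        {"room": fake_room_source},
        {"room_key": "s5"},
        limit=1,
    )
    print(result)  # (['event-a'], ({'room_key': 's5'}, {'room_key': 's10'}))
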
    @log_function
    def remove_expired_streams(self):
        time_now_ms = self.clock.time_msec()
        expired_streams = []
        expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS
        for stream in self.user_to_user_stream.values():
            if stream.listeners:
                continue
            if stream.last_notified_ms < expire_before_ts:
                expired_streams.append(stream)
    def _register_with_keys(self, listener):
        for room in listener.rooms:
            s = self.room_to_listeners.setdefault(room, set())
            s.add(listener)

        for expired_stream in expired_streams:
            expired_stream.remove(self)
        self.user_to_listeners.setdefault(listener.user, set()).add(listener)

        if listener.appservice:
            self.appservice_to_listeners.setdefault(
                listener.appservice, set()
            ).add(listener)

    @defer.inlineCallbacks
    @log_function
    def _register_with_keys(self, user_stream):
        self.user_to_user_stream[user_stream.user] = user_stream
    def _check_for_updates(self, listener):
        # TODO (erikj): We need to think about limits across multiple sources
        events = []

        for room in user_stream.rooms:
            s = self.room_to_user_streams.setdefault(room, set())
            s.add(user_stream)
        from_token = listener.from_token
        limit = listener.limit

        if user_stream.appservice:
            self.appservice_to_user_stream.setdefault(
                user_stream.appservice, set()
            ).add(user_stream)
        # TODO (erikj): DeferredList?
        for name, source in self.event_sources.sources.items():
            keyname = "%s_key" % name

            stuff, new_key = yield source.get_new_events_for_user(
                listener.user,
                getattr(from_token, keyname),
                limit,
            )

            events.extend(stuff)

            from_token = from_token.copy_and_replace(keyname, new_key)

        end_token = from_token

        if events:
            listener.notify(self, events, listener.from_token, end_token)

        defer.returnValue(listener)

    def _user_joined_room(self, user, room_id):
        user = str(user)
        new_user_stream = self.user_to_user_stream.get(user)
        if new_user_stream is not None:
            room_streams = self.room_to_user_streams.setdefault(room_id, set())
            room_streams.add(new_user_stream)
            new_user_stream.rooms.add(room_id)
        new_listeners = self.user_to_listeners.get(user, set())

        listeners = self.room_to_listeners.setdefault(room_id, set())
        listeners |= new_listeners

        for l in new_listeners:
            l.rooms.add(room_id)


def _discard_if_notified(listener_set):
    """Remove any 'stale' listeners from the given set.
    """
    to_discard = set()
    for l in listener_set:
        if l.notified():
            to_discard.add(l)

    listener_set -= to_discard

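``_discard_if_notified`` prunes listeners whose deferred has already fired
before the index sets are re-used for fan-out. A tiny stand-alone illustration
with a fake listener class::

    class FakeListener(object):
        def __init__(self, done):
            self._done = done

        def notified(self):
            return self._done

    def discard_if_notified(listener_set):
        # Drop every listener that has already been woken up.
        stale = set(l for l in listener_set if l.notified())
        listener_set -= stale

    live, old = FakeListener(False), FakeListener(True)
    listeners = set([live, old])
    discard_if_notified(listeners)
    assert listeners == set([live])
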
@@ -24,7 +24,6 @@ import baserules
|
||||
import logging
|
||||
import simplejson as json
|
||||
import re
|
||||
import random
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -75,33 +74,35 @@ class Pusher(object):
|
||||
|
||||
rawrules = yield self.store.get_push_rules_for_user(self.user_name)
|
||||
|
||||
rules = []
|
||||
for rawrule in rawrules:
|
||||
rule = dict(rawrule)
|
||||
rule['conditions'] = json.loads(rawrule['conditions'])
|
||||
rule['actions'] = json.loads(rawrule['actions'])
|
||||
rules.append(rule)
|
||||
for r in rawrules:
|
||||
r['conditions'] = json.loads(r['conditions'])
|
||||
r['actions'] = json.loads(r['actions'])
|
||||
|
||||
enabled_map = yield self.store.get_push_rules_enabled_for_user(self.user_name)
|
||||
|
||||
user = UserID.from_string(self.user_name)
|
||||
|
||||
rules = baserules.list_with_base_rules(rules, user)
|
||||
|
||||
room_id = ev['room_id']
|
||||
rules = baserules.list_with_base_rules(rawrules, user)
|
||||
|
||||
# get *our* member event for display name matching
|
||||
my_display_name = None
|
||||
our_member_event = yield self.store.get_current_state(
|
||||
room_id=room_id,
|
||||
member_events_for_room = yield self.store.get_current_state(
|
||||
room_id=ev['room_id'],
|
||||
event_type='m.room.member',
|
||||
state_key=self.user_name,
|
||||
state_key=None
|
||||
)
|
||||
if our_member_event:
|
||||
my_display_name = our_member_event[0].content.get("displayname")
|
||||
my_display_name = None
|
||||
room_member_count = 0
|
||||
for mev in member_events_for_room:
|
||||
if mev.content['membership'] != 'join':
|
||||
continue
|
||||
|
||||
room_members = yield self.store.get_users_in_room(room_id)
|
||||
room_member_count = len(room_members)
|
||||
# This loop does two things:
|
||||
# 1) Find our current display name
|
||||
if mev.state_key == self.user_name and 'displayname' in mev.content:
|
||||
my_display_name = mev.content['displayname']
|
||||
|
||||
# and 2) Get the number of people in that room
|
||||
room_member_count += 1
|
||||
|
||||
for r in rules:
|
||||
if r['rule_id'] in enabled_map:
|
||||
@@ -257,154 +258,132 @@ class Pusher(object):
        logger.info("Pusher %s for user %s starting from token %s",
                    self.pushkey, self.user_name, self.last_token)

        wait = 0
        while self.alive:
            try:
                if wait > 0:
                    yield synapse.util.async.sleep(wait)
                yield self.get_and_dispatch()
                wait = 0
            except:
                if wait == 0:
                    wait = 1
                else:
                    wait = min(wait * 2, 1800)
                logger.exception(
                    "Exception in pusher loop for pushkey %s. Pausing for %ds",
                    self.pushkey, wait
                )

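The loop above retries ``get_and_dispatch`` with an exponential backoff that
starts at one second, doubles on each failure, caps at 1800 seconds and resets
after a success. The policy in isolation::

    def next_wait(previous_wait, succeeded, cap=1800):
        """Return the delay (in seconds) before the next attempt."""
        if succeeded:
            return 0
        return 1 if previous_wait == 0 else min(previous_wait * 2, cap)

    waits = []
    wait = 0
    for ok in [False, False, False, True, False]:
        wait = next_wait(wait, ok)
        waits.append(wait)
    assert waits == [1, 2, 4, 0, 1]
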
@defer.inlineCallbacks
|
||||
def get_and_dispatch(self):
|
||||
from_tok = StreamToken.from_string(self.last_token)
|
||||
config = PaginationConfig(from_token=from_tok, limit='1')
|
||||
timeout = (300 + random.randint(-60, 60)) * 1000
|
||||
chunk = yield self.evStreamHandler.get_stream(
|
||||
self.user_name, config,
|
||||
timeout=timeout, affect_presence=False
|
||||
)
|
||||
|
||||
# limiting to 1 may get 1 event plus 1 presence event, so
|
||||
# pick out the actual event
|
||||
single_event = None
|
||||
for c in chunk['chunk']:
|
||||
if 'event_id' in c: # Hmmm...
|
||||
single_event = c
|
||||
break
|
||||
if not single_event:
|
||||
self.last_token = chunk['end']
|
||||
logger.debug("Event stream timeout for pushkey %s", self.pushkey)
|
||||
return
|
||||
|
||||
if not self.alive:
|
||||
return
|
||||
|
||||
processed = False
|
||||
actions = yield self._actions_for_event(single_event)
|
||||
tweaks = _tweaks_for_actions(actions)
|
||||
|
||||
if len(actions) == 0:
|
||||
logger.warn("Empty actions! Using default action.")
|
||||
actions = Pusher.DEFAULT_ACTIONS
|
||||
|
||||
if 'notify' not in actions and 'dont_notify' not in actions:
|
||||
logger.warn("Neither notify nor dont_notify in actions: adding default")
|
||||
actions.extend(Pusher.DEFAULT_ACTIONS)
|
||||
|
||||
if 'dont_notify' in actions:
|
||||
logger.debug(
|
||||
"%s for %s: dont_notify",
|
||||
single_event['event_id'], self.user_name
|
||||
from_tok = StreamToken.from_string(self.last_token)
|
||||
config = PaginationConfig(from_token=from_tok, limit='1')
|
||||
chunk = yield self.evStreamHandler.get_stream(
|
||||
self.user_name, config,
|
||||
timeout=100*365*24*60*60*1000, affect_presence=False
|
||||
)
|
||||
processed = True
|
||||
else:
|
||||
rejected = yield self.dispatch_push(single_event, tweaks)
|
||||
self.has_unread = True
|
||||
if isinstance(rejected, list) or isinstance(rejected, tuple):
|
||||
|
||||
# limiting to 1 may get 1 event plus 1 presence event, so
|
||||
# pick out the actual event
|
||||
single_event = None
|
||||
for c in chunk['chunk']:
|
||||
if 'event_id' in c: # Hmmm...
|
||||
single_event = c
|
||||
break
|
||||
if not single_event:
|
||||
self.last_token = chunk['end']
|
||||
continue
|
||||
|
||||
if not self.alive:
|
||||
continue
|
||||
|
||||
processed = False
|
||||
actions = yield self._actions_for_event(single_event)
|
||||
tweaks = _tweaks_for_actions(actions)
|
||||
|
||||
if len(actions) == 0:
|
||||
logger.warn("Empty actions! Using default action.")
|
||||
actions = Pusher.DEFAULT_ACTIONS
|
||||
if 'notify' not in actions and 'dont_notify' not in actions:
|
||||
logger.warn("Neither notify nor dont_notify in actions: adding default")
|
||||
actions.extend(Pusher.DEFAULT_ACTIONS)
|
||||
if 'dont_notify' in actions:
|
||||
logger.debug(
|
||||
"%s for %s: dont_notify",
|
||||
single_event['event_id'], self.user_name
|
||||
)
|
||||
processed = True
|
||||
for pk in rejected:
|
||||
if pk != self.pushkey:
|
||||
# for sanity, we only remove the pushkey if it
|
||||
# was the one we actually sent...
|
||||
logger.warn(
|
||||
("Ignoring rejected pushkey %s because we"
|
||||
" didn't send it"), pk
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Pushkey %s was rejected: removing",
|
||||
pk
|
||||
)
|
||||
yield self.hs.get_pusherpool().remove_pusher(
|
||||
self.app_id, pk, self.user_name
|
||||
)
|
||||
else:
|
||||
rejected = yield self.dispatch_push(single_event, tweaks)
|
||||
self.has_unread = True
|
||||
if isinstance(rejected, list) or isinstance(rejected, tuple):
|
||||
processed = True
|
||||
for pk in rejected:
|
||||
if pk != self.pushkey:
|
||||
# for sanity, we only remove the pushkey if it
|
||||
# was the one we actually sent...
|
||||
logger.warn(
|
||||
("Ignoring rejected pushkey %s because we"
|
||||
" didn't send it"), pk
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Pushkey %s was rejected: removing",
|
||||
pk
|
||||
)
|
||||
yield self.hs.get_pusherpool().remove_pusher(
|
||||
self.app_id, pk, self.user_name
|
||||
)
|
||||
|
||||
if not self.alive:
|
||||
return
|
||||
if not self.alive:
|
||||
continue
|
||||
|
||||
if processed:
|
||||
self.backoff_delay = Pusher.INITIAL_BACKOFF
|
||||
self.last_token = chunk['end']
|
||||
self.store.update_pusher_last_token_and_success(
|
||||
self.app_id,
|
||||
self.pushkey,
|
||||
self.user_name,
|
||||
self.last_token,
|
||||
self.clock.time_msec()
|
||||
)
|
||||
if self.failing_since:
|
||||
self.failing_since = None
|
||||
self.store.update_pusher_failing_since(
|
||||
self.app_id,
|
||||
self.pushkey,
|
||||
self.user_name,
|
||||
self.failing_since)
|
||||
else:
|
||||
if not self.failing_since:
|
||||
self.failing_since = self.clock.time_msec()
|
||||
self.store.update_pusher_failing_since(
|
||||
self.app_id,
|
||||
self.pushkey,
|
||||
self.user_name,
|
||||
self.failing_since
|
||||
)
|
||||
|
||||
if (self.failing_since and
|
||||
self.failing_since <
|
||||
self.clock.time_msec() - Pusher.GIVE_UP_AFTER):
|
||||
# we really only give up so that if the URL gets
|
||||
# fixed, we don't suddenly deliver a load
|
||||
# of old notifications.
|
||||
logger.warn("Giving up on a notification to user %s, "
|
||||
"pushkey %s",
|
||||
self.user_name, self.pushkey)
|
||||
if processed:
|
||||
self.backoff_delay = Pusher.INITIAL_BACKOFF
|
||||
self.last_token = chunk['end']
|
||||
self.store.update_pusher_last_token(
|
||||
self.store.update_pusher_last_token_and_success(
|
||||
self.app_id,
|
||||
self.pushkey,
|
||||
self.user_name,
|
||||
self.last_token
|
||||
)
|
||||
|
||||
self.failing_since = None
|
||||
self.store.update_pusher_failing_since(
|
||||
self.app_id,
|
||||
self.pushkey,
|
||||
self.user_name,
|
||||
self.failing_since
|
||||
self.last_token,
|
||||
self.clock.time_msec()
|
||||
)
|
||||
if self.failing_since:
|
||||
self.failing_since = None
|
||||
self.store.update_pusher_failing_since(
|
||||
self.app_id,
|
||||
self.pushkey,
|
||||
self.user_name,
|
||||
self.failing_since)
|
||||
else:
|
||||
logger.warn("Failed to dispatch push for user %s "
|
||||
"(failing for %dms)."
|
||||
"Trying again in %dms",
|
||||
self.user_name,
|
||||
self.clock.time_msec() - self.failing_since,
|
||||
self.backoff_delay)
|
||||
yield synapse.util.async.sleep(self.backoff_delay / 1000.0)
|
||||
self.backoff_delay *= 2
|
||||
if self.backoff_delay > Pusher.MAX_BACKOFF:
|
||||
self.backoff_delay = Pusher.MAX_BACKOFF
|
||||
if not self.failing_since:
|
||||
self.failing_since = self.clock.time_msec()
|
||||
self.store.update_pusher_failing_since(
|
||||
self.app_id,
|
||||
self.pushkey,
|
||||
self.user_name,
|
||||
self.failing_since
|
||||
)
|
||||
|
||||
if (self.failing_since and
|
||||
self.failing_since <
|
||||
self.clock.time_msec() - Pusher.GIVE_UP_AFTER):
|
||||
# we really only give up so that if the URL gets
|
||||
# fixed, we don't suddenly deliver a load
|
||||
# of old notifications.
|
||||
logger.warn("Giving up on a notification to user %s, "
|
||||
"pushkey %s",
|
||||
self.user_name, self.pushkey)
|
||||
self.backoff_delay = Pusher.INITIAL_BACKOFF
|
||||
self.last_token = chunk['end']
|
||||
self.store.update_pusher_last_token(
|
||||
self.app_id,
|
||||
self.pushkey,
|
||||
self.user_name,
|
||||
self.last_token
|
||||
)
|
||||
|
||||
self.failing_since = None
|
||||
self.store.update_pusher_failing_since(
|
||||
self.app_id,
|
||||
self.pushkey,
|
||||
self.user_name,
|
||||
self.failing_since
|
||||
)
|
||||
else:
|
||||
logger.warn("Failed to dispatch push for user %s "
|
||||
"(failing for %dms)."
|
||||
"Trying again in %dms",
|
||||
self.user_name,
|
||||
self.clock.time_msec() - self.failing_since,
|
||||
self.backoff_delay)
|
||||
yield synapse.util.async.sleep(self.backoff_delay / 1000.0)
|
||||
self.backoff_delay *= 2
|
||||
if self.backoff_delay > Pusher.MAX_BACKOFF:
|
||||
self.backoff_delay = Pusher.MAX_BACKOFF
|
||||
|
||||
def stop(self):
|
||||
self.alive = False
|
||||
|
||||
@@ -18,23 +18,22 @@ from distutils.version import LooseVersion
logger = logging.getLogger(__name__)

REQUIREMENTS = {
    "syutil>=0.0.7": ["syutil>=0.0.7"],
    "syutil>=0.0.6": ["syutil>=0.0.6"],
    "Twisted==14.0.2": ["twisted==14.0.2"],
    "service_identity>=1.0.0": ["service_identity>=1.0.0"],
    "pyopenssl>=0.14": ["OpenSSL>=0.14"],
    "pyyaml": ["yaml"],
    "pyasn1": ["pyasn1"],
    "pynacl>=0.0.3": ["nacl>=0.0.3"],
    "pynacl": ["nacl"],
    "daemonize": ["daemonize"],
    "py-bcrypt": ["bcrypt"],
    "frozendict>=0.4": ["frozendict"],
    "pillow": ["PIL"],
    "pydenticon": ["pydenticon"],
    "ujson": ["ujson"],
}
CONDITIONAL_REQUIREMENTS = {
    "web_client": {
        "matrix_angular_sdk>=0.6.6": ["syweb>=0.6.6"],
        "matrix_angular_sdk>=0.6.5": ["syweb>=0.6.5"],
    }
}

@@ -51,15 +50,20 @@ def github_link(project, version, egg):
    return "https://github.com/%s/tarball/%s/#egg=%s" % (project, version, egg)

DEPENDENCY_LINKS = [
    github_link(
        project="pyca/pynacl",
        version="d4d3175589b892f6ea7c22f466e0e223853516fa",
        egg="pynacl-0.3.0",
    ),
    github_link(
        project="matrix-org/syutil",
        version="v0.0.7",
        egg="syutil-0.0.7",
        version="v0.0.6",
        egg="syutil-0.0.6",
    ),
    github_link(
        project="matrix-org/matrix-angular-sdk",
        version="v0.6.6",
        egg="matrix_angular_sdk-0.6.6",
        version="v0.6.5",
        egg="matrix_angular_sdk-0.6.5",
    ),
]

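A mapping like ``REQUIREMENTS`` can be sanity-checked at start-up by asking
setuptools whether each specifier is satisfied. The sketch below uses
``pkg_resources`` and is only an illustration, not Synapse's own dependency
checker::

    import pkg_resources

    def check_requirements(requirements):
        """Return a list of (specifier, reason) pairs for unmet requirements."""
        missing = []
        for distribution_spec in requirements:
            try:
                pkg_resources.require(distribution_spec)
            except (pkg_resources.DistributionNotFound,
                    pkg_resources.VersionConflict) as e:
                missing.append((distribution_spec, str(e)))
        return missing

    if __name__ == "__main__":
        problems = check_requirements({"Twisted==14.0.2": ["twisted==14.0.2"]})
        for spec, reason in problems:
            print("unsatisfied requirement %s: %s" % (spec, reason))
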
@@ -25,7 +25,7 @@ class ClientV1RestResource(JsonResource):
|
||||
"""A resource for version 1 of the matrix client API."""
|
||||
|
||||
def __init__(self, hs):
|
||||
JsonResource.__init__(self, hs, canonical_json=False)
|
||||
JsonResource.__init__(self, hs)
|
||||
self.register_servlets(self, hs)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -118,14 +118,11 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
||||
user.to_string()
|
||||
)
|
||||
|
||||
ruleslist = []
|
||||
for rawrule in rawrules:
|
||||
rule = dict(rawrule)
|
||||
rule["conditions"] = json.loads(rawrule["conditions"])
|
||||
rule["actions"] = json.loads(rawrule["actions"])
|
||||
ruleslist.append(rule)
|
||||
for r in rawrules:
|
||||
r["conditions"] = json.loads(r["conditions"])
|
||||
r["actions"] = json.loads(r["actions"])
|
||||
|
||||
ruleslist = baserules.list_with_base_rules(ruleslist, user)
|
||||
ruleslist = baserules.list_with_base_rules(rawrules, user)
|
||||
|
||||
rules = {'global': {}, 'device': {}}
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ class ClientV2AlphaRestResource(JsonResource):
|
||||
"""A resource for version 2 alpha of the matrix client API."""
|
||||
|
||||
def __init__(self, hs):
|
||||
JsonResource.__init__(self, hs, canonical_json=False)
|
||||
JsonResource.__init__(self, hs)
|
||||
self.register_servlets(self, hs)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -65,12 +65,12 @@ class PasswordRestServlet(RestServlet):
|
||||
if 'medium' not in threepid or 'address' not in threepid:
|
||||
raise SynapseError(500, "Malformed threepid")
|
||||
# if using email, we must know about the email they're authing with!
|
||||
threepid_user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||
threepid_user = yield self.hs.get_datastore().get_user_by_threepid(
|
||||
threepid['medium'], threepid['address']
|
||||
)
|
||||
if not threepid_user_id:
|
||||
if not threepid_user:
|
||||
raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
|
||||
user_id = threepid_user_id
|
||||
user_id = threepid_user
|
||||
else:
|
||||
logger.error("Auth succeeded but no known type!", result.keys())
|
||||
raise SynapseError(500, "", Codes.UNKNOWN)
|
||||
|
||||
@@ -82,10 +82,8 @@ class RegisterRestServlet(RestServlet):
|
||||
[LoginType.EMAIL_IDENTITY]
|
||||
]
|
||||
|
||||
result = None
|
||||
if service:
|
||||
is_application_server = True
|
||||
params = body
|
||||
elif 'mac' in body:
|
||||
# Check registration-specific shared secret auth
|
||||
if 'username' not in body:
|
||||
@@ -94,7 +92,6 @@ class RegisterRestServlet(RestServlet):
|
||||
body['username'], body['mac']
|
||||
)
|
||||
is_using_shared_secret = True
|
||||
params = body
|
||||
else:
|
||||
authed, result, params = yield self.auth_handler.check_auth(
|
||||
flows, body, self.hs.get_ip_from_request(request)
|
||||
@@ -121,7 +118,7 @@ class RegisterRestServlet(RestServlet):
|
||||
password=new_password
|
||||
)
|
||||
|
||||
if result and LoginType.EMAIL_IDENTITY in result:
|
||||
if LoginType.EMAIL_IDENTITY in result:
|
||||
threepid = result[LoginType.EMAIL_IDENTITY]
|
||||
|
||||
for reqd in ['medium', 'address', 'validated_at']:
|
||||
|
||||
@@ -15,18 +15,17 @@
|
||||
|
||||
from .thumbnailer import Thumbnailer
|
||||
|
||||
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
||||
from synapse.http.server import respond_with_json
|
||||
from synapse.util.stringutils import random_string
|
||||
from synapse.api.errors import (
|
||||
cs_error, Codes, SynapseError
|
||||
)
|
||||
|
||||
from twisted.internet import defer, threads
|
||||
from twisted.internet import defer
|
||||
from twisted.web.resource import Resource
|
||||
from twisted.protocols.basic import FileSender
|
||||
|
||||
from synapse.util.async import ObservableDeferred
|
||||
from synapse.util.async import create_observer
|
||||
|
||||
import os
|
||||
|
||||
@@ -53,7 +52,7 @@ class BaseMediaResource(Resource):
|
||||
def __init__(self, hs, filepaths):
|
||||
Resource.__init__(self)
|
||||
self.auth = hs.get_auth()
|
||||
self.client = MatrixFederationHttpClient(hs)
|
||||
self.client = hs.get_http_client()
|
||||
self.clock = hs.get_clock()
|
||||
self.server_name = hs.hostname
|
||||
self.store = hs.get_datastore()
|
||||
@@ -84,17 +83,13 @@ class BaseMediaResource(Resource):
|
||||
download = self.downloads.get(key)
|
||||
if download is None:
|
||||
download = self._get_remote_media_impl(server_name, media_id)
|
||||
download = ObservableDeferred(
|
||||
download,
|
||||
consumeErrors=True
|
||||
)
|
||||
self.downloads[key] = download
|
||||
|
||||
@download.addBoth
|
||||
def callback(media_info):
|
||||
del self.downloads[key]
|
||||
return media_info
|
||||
return download.observe()
|
||||
return create_observer(download)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_remote_media_impl(self, server_name, media_id):
|
||||
@@ -274,65 +269,57 @@ class BaseMediaResource(Resource):
|
||||
if not requirements:
|
||||
return
|
||||
|
||||
remote_thumbnails = []
|
||||
|
||||
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
|
||||
thumbnailer = Thumbnailer(input_path)
|
||||
m_width = thumbnailer.width
|
||||
m_height = thumbnailer.height
|
||||
|
||||
def generate_thumbnails():
|
||||
if m_width * m_height >= self.max_image_pixels:
|
||||
logger.info(
|
||||
"Image too large to thumbnail %r x %r > %r",
|
||||
m_width, m_height, self.max_image_pixels
|
||||
)
|
||||
return
|
||||
if m_width * m_height >= self.max_image_pixels:
|
||||
logger.info(
|
||||
"Image too large to thumbnail %r x %r > %r",
|
||||
m_width, m_height, self.max_image_pixels
|
||||
)
|
||||
return
|
||||
|
||||
scales = set()
|
||||
crops = set()
|
||||
for r_width, r_height, r_method, r_type in requirements:
|
||||
if r_method == "scale":
|
||||
t_width, t_height = thumbnailer.aspect(r_width, r_height)
|
||||
scales.add((
|
||||
min(m_width, t_width), min(m_height, t_height), r_type,
|
||||
))
|
||||
elif r_method == "crop":
|
||||
crops.add((r_width, r_height, r_type))
|
||||
scales = set()
|
||||
crops = set()
|
||||
for r_width, r_height, r_method, r_type in requirements:
|
||||
if r_method == "scale":
|
||||
t_width, t_height = thumbnailer.aspect(r_width, r_height)
|
||||
scales.add((
|
||||
min(m_width, t_width), min(m_height, t_height), r_type,
|
||||
))
|
||||
elif r_method == "crop":
|
||||
crops.add((r_width, r_height, r_type))
|
||||
|
||||
for t_width, t_height, t_type in scales:
|
||||
t_method = "scale"
|
||||
t_path = self.filepaths.remote_media_thumbnail(
|
||||
server_name, file_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
self._makedirs(t_path)
|
||||
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
|
||||
remote_thumbnails.append([
|
||||
server_name, media_id, file_id,
|
||||
t_width, t_height, t_type, t_method, t_len
|
||||
])
|
||||
for t_width, t_height, t_type in scales:
|
||||
t_method = "scale"
|
||||
t_path = self.filepaths.remote_media_thumbnail(
|
||||
server_name, file_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
self._makedirs(t_path)
|
||||
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
|
||||
yield self.store.store_remote_media_thumbnail(
|
||||
server_name, media_id, file_id,
|
||||
t_width, t_height, t_type, t_method, t_len
|
||||
)
|
||||
|
||||
for t_width, t_height, t_type in crops:
|
||||
if (t_width, t_height, t_type) in scales:
|
||||
# If the aspect ratio of the cropped thumbnail matches a purely
|
||||
# scaled one then there is no point in calculating a separate
|
||||
# thumbnail.
|
||||
continue
|
||||
t_method = "crop"
|
||||
t_path = self.filepaths.remote_media_thumbnail(
|
||||
server_name, file_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
self._makedirs(t_path)
|
||||
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
|
||||
remote_thumbnails.append([
|
||||
server_name, media_id, file_id,
|
||||
t_width, t_height, t_type, t_method, t_len
|
||||
])
|
||||
|
||||
yield threads.deferToThread(generate_thumbnails)
|
||||
|
||||
for r in remote_thumbnails:
|
||||
yield self.store.store_remote_media_thumbnail(*r)
|
||||
for t_width, t_height, t_type in crops:
|
||||
if (t_width, t_height, t_type) in scales:
|
||||
# If the aspect ratio of the cropped thumbnail matches a purely
|
||||
# scaled one then there is no point in calculating a separate
|
||||
# thumbnail.
|
||||
continue
|
||||
t_method = "crop"
|
||||
t_path = self.filepaths.remote_media_thumbnail(
|
||||
server_name, file_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
self._makedirs(t_path)
|
||||
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
|
||||
yield self.store.store_remote_media_thumbnail(
|
||||
server_name, media_id, file_id,
|
||||
t_width, t_height, t_type, t_method, t_len
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
"width": m_width,
|
||||
|
||||
@@ -59,6 +59,7 @@ class BaseHomeServer(object):
|
||||
'config',
|
||||
'clock',
|
||||
'http_client',
|
||||
'db_name',
|
||||
'db_pool',
|
||||
'persistence_service',
|
||||
'replication_layer',
|
||||
|
||||
@@ -106,7 +106,7 @@ class StateHandler(object):
        defer.returnValue(state)

    @defer.inlineCallbacks
    def compute_event_context(self, event, old_state=None, outlier=False):
    def compute_event_context(self, event, old_state=None):
        """ Fills out the context with the `current state` of the graph. The
        `current state` here is defined to be the state of the event graph
        just before the event - i.e. it never includes `event`
@@ -119,23 +119,9 @@ class StateHandler(object):
        Returns:
            an EventContext
        """
        yield run_on_reactor()

        context = EventContext()

        if outlier:
            # If this is an outlier, then we know it shouldn't have any current
            # state. Certainly store.get_current_state won't return any, and
            # persisting the event won't store the state group.
            if old_state:
                context.current_state = {
                    (s.type, s.state_key): s for s in old_state
                }
            else:
                context.current_state = {}
            context.prev_state_events = []
            context.state_group = None
            defer.returnValue(context)
        yield run_on_reactor()

        if old_state:
            context.current_state = {
@@ -169,6 +155,10 @@ class StateHandler(object):
        context.current_state = curr_state
        context.state_group = group if not event.is_state() else None

        prev_state = yield self.store.add_event_hashes(
            prev_state
        )

        if event.is_state():
            key = (event.type, event.state_key)
            if key in context.current_state:

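The removed ``outlier`` branch gives an outlier event a context whose
``current_state`` is just the supplied ``old_state`` keyed by
``(type, state_key)``, with no state group. A stand-alone sketch using stand-in
classes rather than Synapse's real ``EventContext``::

    class EventContext(object):
        def __init__(self):
            self.current_state = {}
            self.prev_state_events = []
            self.state_group = None

    class StateEvent(object):
        def __init__(self, type, state_key):
            self.type = type
            self.state_key = state_key

    def outlier_context(old_state=None):
        # Outliers carry no state group; their "current state" is whatever
        # old_state the caller already had, or nothing at all.
        context = EventContext()
        if old_state:
            context.current_state = dict(
                ((s.type, s.state_key), s) for s in old_state
            )
        context.prev_state_events = []
        context.state_group = None
        return context

    ctx = outlier_context([StateEvent("m.room.create", "")])
    assert ("m.room.create", "") in ctx.current_state
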
@@ -51,7 +51,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# Remember to update this number every time a change is made to database
|
||||
# schema files, so the users will be informed on server restarts.
|
||||
SCHEMA_VERSION = 20
|
||||
SCHEMA_VERSION = 17
|
||||
|
||||
dir_path = os.path.abspath(os.path.dirname(__file__))
|
||||
|
||||
@@ -348,7 +348,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
|
||||
module_name, absolute_path, python_file
|
||||
)
|
||||
logger.debug("Running script %s", relative_path)
|
||||
module.run_upgrade(cur, database_engine)
|
||||
module.run_upgrade(cur)
|
||||
elif ext == ".sql":
|
||||
# A plain old .sql file, just read and execute it
|
||||
logger.debug("Applying schema %s", relative_path)
|
||||
|
||||
@@ -15,8 +15,10 @@
|
||||
import logging
|
||||
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.events import FrozenEvent
|
||||
from synapse.events.utils import prune_event
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
|
||||
from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
|
||||
from synapse.util.lrucache import LruCache
|
||||
import synapse.metrics
|
||||
|
||||
@@ -25,13 +27,11 @@ from util.id_generators import IdGenerator, StreamIdGenerator
|
||||
from twisted.internet import defer
|
||||
|
||||
from collections import namedtuple, OrderedDict
|
||||
|
||||
import functools
|
||||
import simplejson as json
|
||||
import sys
|
||||
import time
|
||||
import threading
|
||||
|
||||
DEBUG_CACHES = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -46,6 +46,7 @@ sql_scheduling_timer = metrics.register_distribution("schedule_time")
|
||||
|
||||
sql_query_timer = metrics.register_distribution("query_time", labels=["verb"])
|
||||
sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"])
|
||||
sql_getevents_timer = metrics.register_distribution("getEvents_time", labels=["desc"])
|
||||
|
||||
caches_by_name = {}
|
||||
cache_counter = metrics.register_cache(
|
||||
@@ -67,19 +68,8 @@ class Cache(object):
|
||||
|
||||
self.name = name
|
||||
self.keylen = keylen
|
||||
self.sequence = 0
|
||||
self.thread = None
|
||||
caches_by_name[name] = self.cache
|
||||
|
||||
def check_thread(self):
|
||||
expected_thread = self.thread
|
||||
if expected_thread is None:
|
||||
self.thread = threading.current_thread()
|
||||
else:
|
||||
if expected_thread is not threading.current_thread():
|
||||
raise ValueError(
|
||||
"Cache objects can only be accessed from the main thread"
|
||||
)
|
||||
caches_by_name[name] = self.cache
|
||||
|
||||
def get(self, *keyargs):
|
||||
if len(keyargs) != self.keylen:
|
||||
@@ -92,13 +82,6 @@ class Cache(object):
|
||||
cache_counter.inc_misses(self.name)
|
||||
raise KeyError()
|
||||
|
||||
def update(self, sequence, *args):
|
||||
self.check_thread()
|
||||
if self.sequence == sequence:
|
||||
# Only update the cache if the caches sequence number matches the
|
||||
# number that the cache had before the SELECT was started (SYN-369)
|
||||
self.prefill(*args)
|
||||
|
||||
def prefill(self, *args): # because I can't *keyargs, value
|
||||
keyargs = args[:-1]
|
||||
value = args[-1]
|
||||
@@ -113,21 +96,13 @@ class Cache(object):
|
||||
self.cache[keyargs] = value
|
||||
|
||||
def invalidate(self, *keyargs):
|
||||
self.check_thread()
|
||||
if len(keyargs) != self.keylen:
|
||||
raise ValueError("Expected a key to have %d items", self.keylen)
|
||||
# Increment the sequence number so that any SELECT statements that
|
||||
# raced with the INSERT don't update the cache (SYN-369)
|
||||
self.sequence += 1
|
||||
|
||||
self.cache.pop(keyargs, None)
|
||||
|
||||
def invalidate_all(self):
|
||||
self.check_thread()
|
||||
self.sequence += 1
|
||||
self.cache.clear()
|
||||
|
||||
|
||||
class CacheDescriptor(object):
|
||||
def cached(max_entries=1000, num_args=1, lru=False):
|
||||
""" A method decorator that applies a memoizing cache around the function.
|
||||
|
||||
The function is presumed to take zero or more arguments, which are used in
|
||||
@@ -141,84 +116,43 @@ class CacheDescriptor(object):
|
||||
which can be used to insert values into the cache specifically, without
|
||||
calling the calculation function.
|
||||
"""
|
||||
def __init__(self, orig, max_entries=1000, num_args=1, lru=False):
|
||||
self.orig = orig
|
||||
|
||||
self.max_entries = max_entries
|
||||
self.num_args = num_args
|
||||
self.lru = lru
|
||||
|
||||
def __get__(self, obj, objtype=None):
|
||||
def wrap(orig):
|
||||
cache = Cache(
|
||||
name=self.orig.__name__,
|
||||
max_entries=self.max_entries,
|
||||
keylen=self.num_args,
|
||||
lru=self.lru,
|
||||
name=orig.__name__,
|
||||
max_entries=max_entries,
|
||||
keylen=num_args,
|
||||
lru=lru,
|
||||
)
|
||||
|
||||
@functools.wraps(self.orig)
|
||||
@functools.wraps(orig)
|
||||
@defer.inlineCallbacks
|
||||
def wrapped(*keyargs):
|
||||
def wrapped(self, *keyargs):
|
||||
try:
|
||||
cached_result = cache.get(*keyargs[:self.num_args])
|
||||
if DEBUG_CACHES:
|
||||
actual_result = yield self.orig(obj, *keyargs)
|
||||
if actual_result != cached_result:
|
||||
logger.error(
|
||||
"Stale cache entry %s%r: cached: %r, actual %r",
|
||||
self.orig.__name__, keyargs,
|
||||
cached_result, actual_result,
|
||||
)
|
||||
raise ValueError("Stale cache entry")
|
||||
defer.returnValue(cached_result)
|
||||
defer.returnValue(cache.get(*keyargs))
|
||||
except KeyError:
|
||||
# Get the sequence number of the cache before reading from the
|
||||
# database so that we can tell if the cache is invalidated
|
||||
# while the SELECT is executing (SYN-369)
|
||||
sequence = cache.sequence
|
||||
ret = yield orig(self, *keyargs)
|
||||
|
||||
ret = yield self.orig(obj, *keyargs)
|
||||
|
||||
cache.update(sequence, *keyargs[:self.num_args] + (ret,))
|
||||
cache.prefill(*keyargs + (ret,))
|
||||
|
||||
defer.returnValue(ret)
|
||||
|
||||
wrapped.invalidate = cache.invalidate
|
||||
wrapped.invalidate_all = cache.invalidate_all
|
||||
wrapped.prefill = cache.prefill
|
||||
|
||||
obj.__dict__[self.orig.__name__] = wrapped
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
def cached(max_entries=1000, num_args=1, lru=False):
|
||||
return lambda orig: CacheDescriptor(
|
||||
orig,
|
||||
max_entries=max_entries,
|
||||
num_args=num_args,
|
||||
lru=lru
|
||||
)
|
||||
return wrap
|
||||
|
||||
|
||||
class LoggingTransaction(object):
|
||||
"""An object that almost-transparently proxies for the 'txn' object
|
||||
passed to the constructor. Adds logging and metrics to the .execute()
|
||||
method."""
|
||||
__slots__ = ["txn", "name", "database_engine", "after_callbacks"]
|
||||
__slots__ = ["txn", "name", "database_engine"]
|
||||
|
||||
def __init__(self, txn, name, database_engine, after_callbacks):
|
||||
def __init__(self, txn, name, database_engine):
|
||||
object.__setattr__(self, "txn", txn)
|
||||
object.__setattr__(self, "name", name)
|
||||
object.__setattr__(self, "database_engine", database_engine)
|
||||
object.__setattr__(self, "after_callbacks", after_callbacks)
|
||||
|
||||
def call_after(self, callback, *args):
|
||||
"""Call the given callback on the main twisted thread after the
|
||||
transaction has finished. Used to invalidate the caches on the
|
||||
correct thread.
|
||||
"""
|
||||
self.after_callbacks.append((callback, args))
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self.txn, name)
|
||||
@@ -226,23 +160,22 @@ class LoggingTransaction(object):
|
||||
def __setattr__(self, name, value):
|
||||
setattr(self.txn, name, value)
|
||||
|
||||
def execute(self, sql, *args):
|
||||
self._do_execute(self.txn.execute, sql, *args)
|
||||
|
||||
def executemany(self, sql, *args):
|
||||
self._do_execute(self.txn.executemany, sql, *args)
|
||||
|
||||
def _do_execute(self, func, sql, *args):
|
||||
def execute(self, sql, *args, **kwargs):
|
||||
# TODO(paul): Maybe use 'info' and 'debug' for values?
|
||||
sql_logger.debug("[SQL] {%s} %s", self.name, sql)
|
||||
|
||||
sql = self.database_engine.convert_param_style(sql)
|
||||
|
||||
if args:
|
||||
if args and args[0]:
|
||||
args = list(args)
|
||||
args[0] = [
|
||||
self.database_engine.encode_parameter(a) for a in args[0]
|
||||
]
|
||||
try:
|
||||
sql_logger.debug(
|
||||
"[SQL values] {%s} %r",
|
||||
self.name, args[0]
|
||||
"[SQL values] {%s} " + ", ".join(("<%r>",) * len(args[0])),
|
||||
self.name,
|
||||
*args[0]
|
||||
)
|
||||
except:
|
||||
# Don't let logging failures stop SQL from working
|
||||
@@ -251,8 +184,8 @@ class LoggingTransaction(object):
|
||||
start = time.time() * 1000
|
||||
|
||||
try:
|
||||
return func(
|
||||
sql, *args
|
||||
return self.txn.execute(
|
||||
sql, *args, **kwargs
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("[SQL FAIL] {%s} %s", self.name, e)
|
||||
@@ -321,12 +254,6 @@ class SQLBaseStore(object):
|
||||
self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
|
||||
max_entries=hs.config.event_cache_size)
|
||||
|
||||
self._event_fetch_lock = threading.Condition()
|
||||
self._event_fetch_list = []
|
||||
self._event_fetch_ongoing = 0
|
||||
|
||||
self._pending_ds = []
|
||||
|
||||
self.database_engine = hs.database_engine
|
||||
|
||||
self._stream_id_gen = StreamIdGenerator()
|
||||
@@ -334,8 +261,6 @@ class SQLBaseStore(object):
|
||||
self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
|
||||
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
|
||||
self._pushers_id_gen = IdGenerator("pushers", "id", self)
|
||||
self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
|
||||
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
|
||||
|
||||
def start_profiling(self):
|
||||
self._previous_loop_ts = self._clock.time_msec()
|
||||
@@ -366,75 +291,6 @@ class SQLBaseStore(object):
|
||||
|
||||
self._clock.looping_call(loop, 10000)
|
||||
|
||||
def _new_transaction(self, conn, desc, after_callbacks, func, *args, **kwargs):
|
||||
start = time.time() * 1000
|
||||
txn_id = self._TXN_ID
|
||||
|
||||
# We don't really need these to be unique, so lets stop it from
|
||||
# growing really large.
|
||||
self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
|
||||
|
||||
name = "%s-%x" % (desc, txn_id, )
|
||||
|
||||
transaction_logger.debug("[TXN START] {%s}", name)
|
||||
|
||||
try:
|
||||
i = 0
|
||||
N = 5
|
||||
while True:
|
||||
try:
|
||||
txn = conn.cursor()
|
||||
txn = LoggingTransaction(
|
||||
txn, name, self.database_engine, after_callbacks
|
||||
)
|
||||
r = func(txn, *args, **kwargs)
|
||||
conn.commit()
|
||||
return r
|
||||
except self.database_engine.module.OperationalError as e:
|
||||
# This can happen if the database disappears mid
|
||||
# transaction.
|
||||
logger.warn(
|
||||
"[TXN OPERROR] {%s} %s %d/%d",
|
||||
name, e, i, N
|
||||
)
|
||||
if i < N:
|
||||
i += 1
|
||||
try:
|
||||
conn.rollback()
|
||||
except self.database_engine.module.Error as e1:
|
||||
logger.warn(
|
||||
"[TXN EROLL] {%s} %s",
|
||||
name, e1,
|
||||
)
|
||||
continue
|
||||
raise
|
||||
except self.database_engine.module.DatabaseError as e:
|
||||
if self.database_engine.is_deadlock(e):
|
||||
logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
|
||||
if i < N:
|
||||
i += 1
|
||||
try:
|
||||
conn.rollback()
|
||||
except self.database_engine.module.Error as e1:
|
||||
logger.warn(
|
||||
"[TXN EROLL] {%s} %s",
|
||||
name, e1,
|
||||
)
|
||||
continue
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.debug("[TXN FAIL] {%s} %s", name, e)
|
||||
raise
|
||||
finally:
|
||||
end = time.time() * 1000
|
||||
duration = end - start
|
||||
|
||||
transaction_logger.debug("[TXN END] {%s} %f", name, duration)
|
||||
|
||||
self._current_txn_total_time += duration
|
||||
self._txn_perf_counters.update(desc, start, end)
|
||||
sql_txn_timer.inc_by(duration, desc)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def runInteraction(self, desc, func, *args, **kwargs):
|
||||
"""Wraps the .runInteraction() method on the underlying db_pool."""
|
||||
@@ -442,54 +298,82 @@ class SQLBaseStore(object):
|
||||
|
||||
start_time = time.time() * 1000
|
||||
|
||||
after_callbacks = []
|
||||
|
||||
def inner_func(conn, *args, **kwargs):
|
||||
with LoggingContext("runInteraction") as context:
|
||||
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
|
||||
|
||||
if self.database_engine.is_connection_closed(conn):
|
||||
logger.debug("Reconnecting closed database connection")
|
||||
conn.reconnect()
|
||||
|
||||
current_context.copy_to(context)
|
||||
return self._new_transaction(
|
||||
conn, desc, after_callbacks, func, *args, **kwargs
|
||||
)
|
||||
start = time.time() * 1000
|
||||
txn_id = self._TXN_ID
|
||||
|
||||
result = yield preserve_context_over_fn(
|
||||
self._db_pool.runWithConnection,
|
||||
inner_func, *args, **kwargs
|
||||
)
|
||||
# We don't really need these to be unique, so lets stop it from
|
||||
# growing really large.
|
||||
self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
|
||||
|
||||
for after_callback, after_args in after_callbacks:
|
||||
after_callback(*after_args)
|
||||
defer.returnValue(result)
|
||||
name = "%s-%x" % (desc, txn_id, )
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def runWithConnection(self, func, *args, **kwargs):
|
||||
"""Wraps the .runInteraction() method on the underlying db_pool."""
|
||||
current_context = LoggingContext.current_context()
|
||||
|
||||
start_time = time.time() * 1000
|
||||
|
||||
def inner_func(conn, *args, **kwargs):
|
||||
with LoggingContext("runWithConnection") as context:
|
||||
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
|
||||
transaction_logger.debug("[TXN START] {%s}", name)
|
||||
try:
|
||||
i = 0
|
||||
N = 5
|
||||
while True:
|
||||
try:
|
||||
txn = conn.cursor()
|
||||
return func(
|
||||
LoggingTransaction(txn, name, self.database_engine),
|
||||
*args, **kwargs
|
||||
)
|
||||
except self.database_engine.module.OperationalError as e:
|
||||
# This can happen if the database disappears mid
|
||||
# transaction.
|
||||
logger.warn(
|
||||
"[TXN OPERROR] {%s} %s %d/%d",
|
||||
name, e, i, N
|
||||
)
|
||||
if i < N:
|
||||
i += 1
|
||||
try:
|
||||
conn.rollback()
|
||||
except self.database_engine.module.Error as e1:
|
||||
logger.warn(
|
||||
"[TXN EROLL] {%s} %s",
|
||||
name, e1,
|
||||
)
|
||||
continue
|
||||
except self.database_engine.module.DatabaseError as e:
|
||||
if self.database_engine.is_deadlock(e):
|
||||
logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
|
||||
if i < N:
|
||||
i += 1
|
||||
try:
|
||||
conn.rollback()
|
||||
except self.database_engine.module.Error as e1:
|
||||
logger.warn(
|
||||
"[TXN EROLL] {%s} %s",
|
||||
name, e1,
|
||||
)
|
||||
continue
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.debug("[TXN FAIL] {%s} %s", name, e)
|
||||
raise
|
||||
finally:
|
||||
end = time.time() * 1000
|
||||
duration = end - start
|
||||
|
||||
if self.database_engine.is_connection_closed(conn):
|
||||
logger.debug("Reconnecting closed database connection")
|
||||
conn.reconnect()
|
||||
transaction_logger.debug("[TXN END] {%s} %f", name, duration)
|
||||
|
||||
current_context.copy_to(context)
|
||||
|
||||
return func(conn, *args, **kwargs)
|
||||
|
||||
result = yield preserve_context_over_fn(
|
||||
self._db_pool.runWithConnection,
|
||||
inner_func, *args, **kwargs
|
||||
)
|
||||
self._current_txn_total_time += duration
|
||||
self._txn_perf_counters.update(desc, start, end)
|
||||
sql_txn_timer.inc_by(duration, desc)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
result = yield self._db_pool.runWithConnection(
|
||||
inner_func, *args, **kwargs
|
||||
)
|
||||
defer.returnValue(result)
|
||||
|
||||
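The hunk above wraps each database interaction in a retry loop: up to N attempts on an ``OperationalError`` or a detected deadlock, rolling the connection back between attempts and re-raising once the budget is used up. A minimal, standalone sketch of that retry-and-rollback shape, using the standard sqlite3 module and a made-up ``run_txn`` helper rather than the Synapse API itself::

    import logging
    import sqlite3

    logger = logging.getLogger(__name__)

    MAX_RETRIES = 5  # mirrors the N = 5 retry budget in the hunk above


    def run_txn(conn, func, *args, **kwargs):
        """Run `func(cursor, ...)` in a transaction, retrying transient errors.

        `conn` is any DB-API connection; `func` must be safe to re-run, since
        it may execute several times before the transaction finally commits.
        """
        for attempt in range(MAX_RETRIES + 1):
            try:
                cur = conn.cursor()
                result = func(cur, *args, **kwargs)
                conn.commit()
                return result
            except sqlite3.OperationalError as e:
                # e.g. "database is locked": roll back and try again.
                logger.warning(
                    "txn failed (%s); attempt %d of %d",
                    e, attempt + 1, MAX_RETRIES + 1,
                )
                if attempt == MAX_RETRIES:
                    raise
                try:
                    conn.rollback()
                except sqlite3.Error as e1:
                    logger.warning("rollback failed: %s", e1)


    # Usage: insert a row, retrying if the database is briefly locked.
    def _insert(cur, user_id):
        cur.execute("INSERT INTO users (user_id) VALUES (?)", (user_id,))

    if __name__ == "__main__":
        conn = sqlite3.connect(":memory:")
        conn.execute("CREATE TABLE users (user_id TEXT)")
        run_txn(conn, _insert, "@alice:example.com")

Keeping side effects out of the retried function is the same reason the original collects ``after_callbacks`` and only runs them once the interaction has succeeded.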
def cursor_to_dict(self, cursor):
|
||||
@@ -554,49 +438,18 @@ class SQLBaseStore(object):
|
||||
|
||||
@log_function
def _simple_insert_txn(self, txn, table, values):
keys, vals = zip(*values.items())

sql = "INSERT INTO %s (%s) VALUES(%s)" % (
table,
", ".join(k for k in keys),
", ".join("?" for _ in keys)
", ".join(k for k in values),
", ".join("?" for k in values)
)

txn.execute(sql, vals)

def _simple_insert_many_txn(self, txn, table, values):
if not values:
return

# This is a *slight* abomination to get a list of tuples of key names
# and a list of tuples of value names.
#
# i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}]
# => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)]
#
# The sort is to ensure that we don't rely on dictionary iteration
# order.
keys, vals = zip(*[
zip(
*(sorted(i.items(), key=lambda kv: kv[0]))
)
for i in values
if i
])

for k in keys:
if k != keys[0]:
raise RuntimeError(
"All items must have the same keys"
)

sql = "INSERT INTO %s (%s) VALUES(%s)" % (
table,
", ".join(k for k in keys[0]),
", ".join("?" for _ in keys[0])
logger.debug(
"[SQL] %s Args=%s",
sql, values.values(),
)

txn.executemany(sql, vals)
txn.execute(sql, values.values())

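``_simple_insert_many_txn`` above flattens a list of dicts with identical keys into one parameterised ``INSERT`` plus a list of value tuples, and hands the whole batch to ``txn.executemany`` instead of issuing one ``INSERT`` per row. A rough standalone sketch of the same shape over plain sqlite3 (the ``event_edges`` table here is only an example)::

    import sqlite3


    def insert_many(cur, table, rows):
        """Insert a list of dicts (all with the same keys) with executemany."""
        if not rows:
            return

        # Sort the keys so we don't depend on dict iteration order, as the
        # original helper does.
        keys = sorted(rows[0])
        if any(sorted(r) != keys for r in rows):
            raise RuntimeError("All items must have the same keys")

        # Table and column names cannot be bound as parameters, so they are
        # interpolated directly, just as the original helper does.
        sql = "INSERT INTO %s (%s) VALUES (%s)" % (
            table,
            ", ".join(keys),
            ", ".join("?" for _ in keys),
        )
        cur.executemany(sql, [tuple(r[k] for k in keys) for r in rows])


    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE event_edges (event_id TEXT, prev_event_id TEXT)")
    insert_many(conn.cursor(), "event_edges", [
        {"event_id": "$b", "prev_event_id": "$a"},
        {"event_id": "$c", "prev_event_id": "$b"},
    ])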
def _simple_upsert(self, table, keyvalues, values,
|
||||
insertion_values={}, desc="_simple_upsert", lock=True):
|
||||
@@ -929,6 +782,158 @@ class SQLBaseStore(object):
|
||||
|
||||
return self.runInteraction("_simple_max_id", func)
|
||||
|
||||
def _get_events(self, event_ids, check_redacted=True,
|
||||
get_prev_content=False):
|
||||
return self.runInteraction(
|
||||
"_get_events", self._get_events_txn, event_ids,
|
||||
check_redacted=check_redacted, get_prev_content=get_prev_content,
|
||||
)
|
||||
|
||||
def _get_events_txn(self, txn, event_ids, check_redacted=True,
|
||||
get_prev_content=False):
|
||||
if not event_ids:
|
||||
return []
|
||||
|
||||
events = [
|
||||
self._get_event_txn(
|
||||
txn, event_id,
|
||||
check_redacted=check_redacted,
|
||||
get_prev_content=get_prev_content
|
||||
)
|
||||
for event_id in event_ids
|
||||
]
|
||||
|
||||
return [e for e in events if e]
|
||||
|
||||
def _invalidate_get_event_cache(self, event_id):
|
||||
for check_redacted in (False, True):
|
||||
for get_prev_content in (False, True):
|
||||
self._get_event_cache.invalidate(event_id, check_redacted,
|
||||
get_prev_content)
|
||||
|
||||
def _get_event_txn(self, txn, event_id, check_redacted=True,
|
||||
get_prev_content=False, allow_rejected=False):
|
||||
|
||||
start_time = time.time() * 1000
|
||||
|
||||
def update_counter(desc, last_time):
|
||||
curr_time = self._get_event_counters.update(desc, last_time)
|
||||
sql_getevents_timer.inc_by(curr_time - last_time, desc)
|
||||
return curr_time
|
||||
|
||||
try:
|
||||
ret = self._get_event_cache.get(event_id, check_redacted, get_prev_content)
|
||||
|
||||
if allow_rejected or not ret.rejected_reason:
|
||||
return ret
|
||||
else:
|
||||
return None
|
||||
except KeyError:
|
||||
pass
|
||||
finally:
|
||||
start_time = update_counter("event_cache", start_time)
|
||||
|
||||
sql = (
|
||||
"SELECT e.internal_metadata, e.json, r.event_id, rej.reason "
|
||||
"FROM event_json as e "
|
||||
"LEFT JOIN redactions as r ON e.event_id = r.redacts "
|
||||
"LEFT JOIN rejections as rej on rej.event_id = e.event_id "
|
||||
"WHERE e.event_id = ? "
|
||||
"LIMIT 1 "
|
||||
)
|
||||
|
||||
txn.execute(sql, (event_id,))
|
||||
|
||||
res = txn.fetchone()
|
||||
|
||||
if not res:
|
||||
return None
|
||||
|
||||
internal_metadata, js, redacted, rejected_reason = res
|
||||
|
||||
start_time = update_counter("select_event", start_time)
|
||||
|
||||
result = self._get_event_from_row_txn(
|
||||
txn, internal_metadata, js, redacted,
|
||||
check_redacted=check_redacted,
|
||||
get_prev_content=get_prev_content,
|
||||
rejected_reason=rejected_reason,
|
||||
)
|
||||
self._get_event_cache.prefill(event_id, check_redacted, get_prev_content, result)
|
||||
|
||||
if allow_rejected or not rejected_reason:
|
||||
return result
|
||||
else:
|
||||
return None
|
||||
|
||||
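``_get_event_txn`` above consults ``self._get_event_cache`` before touching the database, prefills it after a successful load, and ``_invalidate_get_event_cache`` clears every ``(check_redacted, get_prev_content)`` variant for an event id. A toy, dict-backed stand-in for that keying scheme (the real cache is more elaborate; this is only a sketch)::

    class EventCache(object):
        """Tiny dict-backed stand-in for the per-event cache in the hunk above."""

        def __init__(self):
            self._cache = {}

        def get(self, event_id, check_redacted, get_prev_content):
            # Raises KeyError on a miss, which is what the caller catches.
            return self._cache[(event_id, check_redacted, get_prev_content)]

        def prefill(self, event_id, check_redacted, get_prev_content, value):
            self._cache[(event_id, check_redacted, get_prev_content)] = value

        def invalidate(self, event_id, check_redacted, get_prev_content):
            self._cache.pop((event_id, check_redacted, get_prev_content), None)


    def invalidate_all_variants(cache, event_id):
        # Mirrors _invalidate_get_event_cache: drop every flag combination.
        for check_redacted in (False, True):
            for get_prev_content in (False, True):
                cache.invalidate(event_id, check_redacted, get_prev_content)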
def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted,
|
||||
check_redacted=True, get_prev_content=False,
|
||||
rejected_reason=None):
|
||||
|
||||
start_time = time.time() * 1000
|
||||
|
||||
def update_counter(desc, last_time):
|
||||
curr_time = self._get_event_counters.update(desc, last_time)
|
||||
sql_getevents_timer.inc_by(curr_time - last_time, desc)
|
||||
return curr_time
|
||||
|
||||
d = json.loads(js)
|
||||
start_time = update_counter("decode_json", start_time)
|
||||
|
||||
internal_metadata = json.loads(internal_metadata)
|
||||
start_time = update_counter("decode_internal", start_time)
|
||||
|
||||
ev = FrozenEvent(
|
||||
d,
|
||||
internal_metadata_dict=internal_metadata,
|
||||
rejected_reason=rejected_reason,
|
||||
)
|
||||
start_time = update_counter("build_frozen_event", start_time)
|
||||
|
||||
if check_redacted and redacted:
|
||||
ev = prune_event(ev)
|
||||
|
||||
ev.unsigned["redacted_by"] = redacted
|
||||
# Get the redaction event.
|
||||
|
||||
because = self._get_event_txn(
|
||||
txn,
|
||||
redacted,
|
||||
check_redacted=False
|
||||
)
|
||||
|
||||
if because:
|
||||
ev.unsigned["redacted_because"] = because
|
||||
start_time = update_counter("redact_event", start_time)
|
||||
|
||||
if get_prev_content and "replaces_state" in ev.unsigned:
|
||||
prev = self._get_event_txn(
|
||||
txn,
|
||||
ev.unsigned["replaces_state"],
|
||||
get_prev_content=False,
|
||||
)
|
||||
if prev:
|
||||
ev.unsigned["prev_content"] = prev.get_dict()["content"]
|
||||
start_time = update_counter("get_prev_content", start_time)
|
||||
|
||||
return ev
|
||||
|
||||
def _parse_events(self, rows):
|
||||
return self.runInteraction(
|
||||
"_parse_events", self._parse_events_txn, rows
|
||||
)
|
||||
|
||||
def _parse_events_txn(self, txn, rows):
|
||||
event_ids = [r["event_id"] for r in rows]
|
||||
|
||||
return self._get_events_txn(txn, event_ids)
|
||||
|
||||
def _has_been_redacted_txn(self, txn, event):
|
||||
sql = "SELECT event_id FROM redactions WHERE redacts = ?"
|
||||
txn.execute(sql, (event.event_id,))
|
||||
result = txn.fetchone()
|
||||
return result[0] if result else None
|
||||
|
||||
def get_next_stream_id(self):
|
||||
with self._next_stream_id_lock:
|
||||
i = self._next_stream_id
|
||||
|
||||
@@ -19,8 +19,6 @@ from ._base import IncorrectDatabaseSetup
|
||||
|
||||
|
||||
class PostgresEngine(object):
|
||||
single_threaded = False
|
||||
|
||||
def __init__(self, database_module):
|
||||
self.module = database_module
|
||||
self.module.extensions.register_type(self.module.extensions.UNICODE)
|
||||
@@ -38,6 +36,9 @@ class PostgresEngine(object):
|
||||
def convert_param_style(self, sql):
|
||||
return sql.replace("?", "%s")
|
||||
|
||||
def encode_parameter(self, param):
|
||||
return param
|
||||
|
||||
def on_new_connection(self, db_conn):
|
||||
db_conn.set_isolation_level(
|
||||
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
|
||||
|
||||
@@ -17,8 +17,6 @@ from synapse.storage import prepare_database, prepare_sqlite3_database
|
||||
|
||||
|
||||
class Sqlite3Engine(object):
|
||||
single_threaded = True
|
||||
|
||||
def __init__(self, database_module):
|
||||
self.module = database_module
|
||||
|
||||
@@ -28,6 +26,9 @@ class Sqlite3Engine(object):
|
||||
def convert_param_style(self, sql):
|
||||
return sql
|
||||
|
||||
def encode_parameter(self, param):
|
||||
return param
|
||||
|
||||
def on_new_connection(self, db_conn):
|
||||
self.prepare_database(db_conn)
|
||||
|
||||
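Both engine classes in the hunks above expose the same small surface: ``convert_param_style`` rewrites the generic ``?`` placeholders into whatever the driver expects (``%s`` for psycopg2, unchanged for sqlite3), and ``on_new_connection`` performs per-connection setup. A stripped-down sketch of just the placeholder part, with hypothetical class names rather than the full engine API::

    class PostgresParams(object):
        def convert_param_style(self, sql):
            # psycopg2 uses %s placeholders.
            return sql.replace("?", "%s")


    class Sqlite3Params(object):
        def convert_param_style(self, sql):
            # sqlite3 already uses ? placeholders.
            return sql


    def execute(txn, engine, sql, args):
        # Storage code always writes `?`; the engine adapts it for the driver.
        txn.execute(engine.convert_param_style(sql), args)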
|
||||
@@ -13,13 +13,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from ._base import SQLBaseStore, cached
|
||||
from ._base import SQLBaseStore
|
||||
from syutil.base64util import encode_base64
|
||||
|
||||
import logging
|
||||
from Queue import PriorityQueue, Empty
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -36,7 +33,16 @@ class EventFederationStore(SQLBaseStore):
|
||||
"""
|
||||
|
||||
def get_auth_chain(self, event_ids):
|
||||
return self.get_auth_chain_ids(event_ids).addCallback(self._get_events)
|
||||
return self.runInteraction(
|
||||
"get_auth_chain",
|
||||
self._get_auth_chain_txn,
|
||||
event_ids
|
||||
)
|
||||
|
||||
def _get_auth_chain_txn(self, txn, event_ids):
|
||||
results = self._get_auth_chain_ids_txn(txn, event_ids)
|
||||
|
||||
return self._get_events_txn(txn, results)
|
||||
|
||||
def get_auth_chain_ids(self, event_ids):
|
||||
return self.runInteraction(
|
||||
@@ -73,28 +79,6 @@ class EventFederationStore(SQLBaseStore):
|
||||
room_id,
|
||||
)
|
||||
|
||||
def get_oldest_events_with_depth_in_room(self, room_id):
|
||||
return self.runInteraction(
|
||||
"get_oldest_events_with_depth_in_room",
|
||||
self.get_oldest_events_with_depth_in_room_txn,
|
||||
room_id,
|
||||
)
|
||||
|
||||
def get_oldest_events_with_depth_in_room_txn(self, txn, room_id):
|
||||
sql = (
|
||||
"SELECT b.event_id, MAX(e.depth) FROM events as e"
|
||||
" INNER JOIN event_edges as g"
|
||||
" ON g.event_id = e.event_id AND g.room_id = e.room_id"
|
||||
" INNER JOIN event_backward_extremities as b"
|
||||
" ON g.prev_event_id = b.event_id AND g.room_id = b.room_id"
|
||||
" WHERE b.room_id = ? AND g.is_state is ?"
|
||||
" GROUP BY b.event_id"
|
||||
)
|
||||
|
||||
txn.execute(sql, (room_id, False,))
|
||||
|
||||
return dict(txn.fetchall())
|
||||
|
||||
def _get_oldest_events_in_room_txn(self, txn, room_id):
|
||||
return self._simple_select_onecol_txn(
|
||||
txn,
|
||||
@@ -112,7 +96,6 @@ class EventFederationStore(SQLBaseStore):
|
||||
room_id,
|
||||
)
|
||||
|
||||
@cached()
|
||||
def get_latest_event_ids_in_room(self, room_id):
|
||||
return self._simple_select_onecol(
|
||||
table="event_forward_extremities",
|
||||
@@ -120,7 +103,7 @@ class EventFederationStore(SQLBaseStore):
|
||||
"room_id": room_id,
|
||||
},
|
||||
retcol="event_id",
|
||||
desc="get_latest_event_ids_in_room",
|
||||
desc="get_latest_events_in_room",
|
||||
)
|
||||
|
||||
def _get_latest_events_in_room(self, txn, room_id):
|
||||
@@ -263,13 +246,11 @@ class EventFederationStore(SQLBaseStore):
|
||||
do_insert = depth < min_depth if min_depth else True
|
||||
|
||||
if do_insert:
|
||||
self._simple_upsert_txn(
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="room_depth",
|
||||
keyvalues={
|
||||
"room_id": room_id,
|
||||
},
|
||||
values={
|
||||
"room_id": room_id,
|
||||
"min_depth": depth,
|
||||
},
|
||||
)
|
||||
@@ -280,19 +261,18 @@ class EventFederationStore(SQLBaseStore):
|
||||
For the given event, update the event edges table and forward and
|
||||
backward extremities tables.
|
||||
"""
|
||||
self._simple_insert_many_txn(
|
||||
txn,
|
||||
table="event_edges",
|
||||
values=[
|
||||
{
|
||||
for e_id, _ in prev_events:
|
||||
# TODO (erikj): This could be done as a bulk insert
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="event_edges",
|
||||
values={
|
||||
"event_id": event_id,
|
||||
"prev_event_id": e_id,
|
||||
"room_id": room_id,
|
||||
"is_state": False,
|
||||
}
|
||||
for e_id, _ in prev_events
|
||||
],
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
# Update the extremities table if this is not an outlier.
|
||||
if not outlier:
|
||||
@@ -324,32 +304,30 @@ class EventFederationStore(SQLBaseStore):
|
||||
|
||||
txn.execute(query, (event_id, room_id))
|
||||
|
||||
# Insert all the prev_events as a backwards thing, they'll get
|
||||
# deleted in a second if they're incorrect anyway.
|
||||
for e_id, _ in prev_events:
|
||||
# TODO (erikj): This could be done as a bulk insert
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="event_backward_extremities",
|
||||
values={
|
||||
"event_id": e_id,
|
||||
"room_id": room_id,
|
||||
},
|
||||
)
|
||||
|
||||
# Also delete from the backwards extremities table all ones that
|
||||
# reference events that we have already seen
|
||||
query = (
|
||||
"INSERT INTO event_backward_extremities (event_id, room_id)"
|
||||
" SELECT ?, ? WHERE NOT EXISTS ("
|
||||
" SELECT 1 FROM event_backward_extremities"
|
||||
" WHERE event_id = ? AND room_id = ?"
|
||||
" )"
|
||||
" AND NOT EXISTS ("
|
||||
" SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
|
||||
" AND outlier = ?"
|
||||
" )"
|
||||
)
|
||||
|
||||
txn.executemany(query, [
|
||||
(e_id, room_id, e_id, room_id, e_id, room_id, False)
|
||||
for e_id, _ in prev_events
|
||||
])
|
||||
|
||||
query = (
|
||||
"DELETE FROM event_backward_extremities"
|
||||
" WHERE event_id = ? AND room_id = ?"
|
||||
)
|
||||
txn.execute(query, (event_id, room_id))
|
||||
|
||||
txn.call_after(
|
||||
self.get_latest_event_ids_in_room.invalidate, room_id
|
||||
"DELETE FROM event_backward_extremities WHERE EXISTS ("
|
||||
"SELECT 1 FROM events "
|
||||
"WHERE "
|
||||
"event_backward_extremities.event_id = events.event_id "
|
||||
"AND not events.outlier "
|
||||
")"
|
||||
)
|
||||
txn.execute(query)
|
||||
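The rewritten extremity handling above adds each prev event to ``event_backward_extremities`` only when it is not already tracked and the full (non-outlier) event is not already stored, using a single ``INSERT ... SELECT ... WHERE NOT EXISTS`` driven by ``executemany``, and then removes the extremity row for the event just persisted. A cut-down sqlite3 sketch of the conditional insert, with a simplified schema and illustrative ids::

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.executescript("""
        CREATE TABLE event_backward_extremities (event_id TEXT, room_id TEXT);
        CREATE TABLE events (event_id TEXT, room_id TEXT, outlier BOOLEAN);
    """)

    prev_events = ["$a", "$b"]
    room_id = "!room:example.com"

    # Insert each prev event as a backward extremity unless we already track
    # it, or we already hold the full (non-outlier) event.
    query = (
        "INSERT INTO event_backward_extremities (event_id, room_id)"
        " SELECT ?, ? WHERE NOT EXISTS ("
        "  SELECT 1 FROM event_backward_extremities"
        "  WHERE event_id = ? AND room_id = ?"
        " )"
        " AND NOT EXISTS ("
        "  SELECT 1 FROM events WHERE event_id = ? AND room_id = ? AND outlier = ?"
        " )"
    )
    conn.executemany(query, [
        (e_id, room_id, e_id, room_id, e_id, room_id, False)
        for e_id in prev_events
    ])
    conn.commit()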
|
||||
def get_backfill_events(self, room_id, event_list, limit):
|
||||
"""Get a list of Events for a given topic that occurred before (and
|
||||
@@ -364,10 +342,6 @@ class EventFederationStore(SQLBaseStore):
|
||||
return self.runInteraction(
|
||||
"get_backfill_events",
|
||||
self._get_backfill_events, room_id, event_list, limit
|
||||
).addCallback(
|
||||
self._get_events
|
||||
).addCallback(
|
||||
lambda l: sorted(l, key=lambda e: -e.depth)
|
||||
)
|
||||
|
||||
def _get_backfill_events(self, txn, room_id, event_list, limit):
|
||||
@@ -376,75 +350,54 @@ class EventFederationStore(SQLBaseStore):
|
||||
room_id, repr(event_list), limit
|
||||
)
|
||||
|
||||
event_results = set()
|
||||
event_results = event_list
|
||||
|
||||
# We want to make sure that we do a breadth-first, "depth" ordered
|
||||
# search.
|
||||
front = event_list
|
||||
|
||||
query = (
|
||||
"SELECT depth, prev_event_id FROM event_edges"
|
||||
" INNER JOIN events"
|
||||
" ON prev_event_id = events.event_id"
|
||||
" AND event_edges.room_id = events.room_id"
|
||||
" WHERE event_edges.room_id = ? AND event_edges.event_id = ?"
|
||||
" AND event_edges.is_state = ?"
|
||||
" LIMIT ?"
|
||||
"SELECT prev_event_id FROM event_edges "
|
||||
"WHERE room_id = ? AND event_id = ? "
|
||||
"LIMIT ?"
|
||||
)
|
||||
|
||||
queue = PriorityQueue()
|
||||
# We iterate through all event_ids in `front` to select their previous
|
||||
# events. These are dumped in `new_front`.
|
||||
# We continue until we reach the limit *or* new_front is empty (i.e.,
|
||||
# we've run out of things to select
|
||||
while front and len(event_results) < limit:
|
||||
|
||||
for event_id in event_list:
|
||||
depth = self._simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="events",
|
||||
keyvalues={
|
||||
"event_id": event_id,
|
||||
},
|
||||
retcol="depth"
|
||||
)
|
||||
new_front = []
|
||||
for event_id in front:
|
||||
logger.debug(
|
||||
"_backfill_interaction: id=%s",
|
||||
event_id
|
||||
)
|
||||
|
||||
queue.put((-depth, event_id))
|
||||
txn.execute(
|
||||
query,
|
||||
(room_id, event_id, limit - len(event_results))
|
||||
)
|
||||
|
||||
while not queue.empty() and len(event_results) < limit:
|
||||
try:
|
||||
_, event_id = queue.get_nowait()
|
||||
except Empty:
|
||||
break
|
||||
for row in txn.fetchall():
|
||||
logger.debug(
|
||||
"_backfill_interaction: got id=%s",
|
||||
*row
|
||||
)
|
||||
new_front.append(row[0])
|
||||
|
||||
if event_id in event_results:
|
||||
continue
|
||||
front = new_front
|
||||
event_results += new_front
|
||||
|
||||
event_results.add(event_id)
|
||||
return self._get_events_txn(txn, event_results)
|
||||
|
||||
txn.execute(
|
||||
query,
|
||||
(room_id, event_id, False, limit - len(event_results))
|
||||
)
|
||||
|
||||
for row in txn.fetchall():
|
||||
if row[1] not in event_results:
|
||||
queue.put((-row[0], row[1]))
|
||||
|
||||
return event_results
|
||||
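The new ``_get_backfill_events`` walks ``event_edges`` backwards, seeding a ``PriorityQueue`` with ``(-depth, event_id)`` so the deepest events are expanded first, and stops once ``limit`` events have been collected. The same traversal sketched against an in-memory graph instead of SQL (the dicts below are stand-ins, not the Synapse schema)::

    try:
        from Queue import PriorityQueue, Empty  # Python 2, as imported above
    except ImportError:
        from queue import PriorityQueue, Empty  # Python 3


    def backfill(depths, prev_edges, seed_ids, limit):
        """Collect up to `limit` ancestor event ids, deepest-first.

        depths:     {event_id: depth}
        prev_edges: {event_id: [prev_event_id, ...]}
        """
        queue = PriorityQueue()
        for event_id in seed_ids:
            # Negate the depth so the deepest event is popped first.
            queue.put((-depths[event_id], event_id))

        results = set()
        while not queue.empty() and len(results) < limit:
            try:
                _, event_id = queue.get_nowait()
            except Empty:
                break

            if event_id in results:
                continue
            results.add(event_id)

            for prev_id in prev_edges.get(event_id, []):
                if prev_id not in results:
                    queue.put((-depths[prev_id], prev_id))

        return results


    if __name__ == "__main__":
        depths = {"$a": 1, "$b": 2, "$c": 3}
        prev_edges = {"$c": ["$b"], "$b": ["$a"]}
        print(backfill(depths, prev_edges, ["$c"], 2))  # the two deepest events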
|
||||
@defer.inlineCallbacks
|
||||
def get_missing_events(self, room_id, earliest_events, latest_events,
|
||||
limit, min_depth):
|
||||
ids = yield self.runInteraction(
|
||||
return self.runInteraction(
|
||||
"get_missing_events",
|
||||
self._get_missing_events,
|
||||
room_id, earliest_events, latest_events, limit, min_depth
|
||||
)
|
||||
|
||||
events = yield self._get_events(ids)
|
||||
|
||||
events = sorted(
|
||||
[ev for ev in events if ev.depth >= min_depth],
|
||||
key=lambda e: e.depth,
|
||||
)
|
||||
|
||||
defer.returnValue(events[:limit])
|
||||
|
||||
def _get_missing_events(self, txn, room_id, earliest_events, latest_events,
|
||||
limit, min_depth):
|
||||
|
||||
@@ -476,7 +429,14 @@ class EventFederationStore(SQLBaseStore):
|
||||
front = new_front
|
||||
event_results |= new_front
|
||||
|
||||
return event_results
|
||||
events = self._get_events_txn(txn, event_results)
|
||||
|
||||
events = sorted(
|
||||
[ev for ev in events if ev.depth >= min_depth],
|
||||
key=lambda e: e.depth,
|
||||
)
|
||||
|
||||
return events[:limit]
|
||||
|
||||
def clean_room_for_join(self, room_id):
|
||||
return self.runInteraction(
|
||||
@@ -489,4 +449,3 @@ class EventFederationStore(SQLBaseStore):
|
||||
query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
|
||||
|
||||
txn.execute(query, (room_id,))
|
||||
txn.call_after(self.get_latest_event_ids_in_room.invalidate, room_id)
|
||||
|
||||
@@ -15,36 +15,20 @@
|
||||
|
||||
from _base import SQLBaseStore, _RollbackButIsFineException
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.events import FrozenEvent, USE_FROZEN_DICTS
|
||||
from synapse.events.utils import prune_event
|
||||
|
||||
from synapse.util.logcontext import preserve_context_over_deferred
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.crypto.event_signing import compute_event_reference_hash
|
||||
|
||||
from syutil.base64util import decode_base64
|
||||
from syutil.jsonutil import encode_json
|
||||
from contextlib import contextmanager
|
||||
from syutil.jsonutil import encode_canonical_json
|
||||
|
||||
import logging
import ujson as json

logger = logging.getLogger(__name__)


# These values are used in the `_enqueue_events` and `_do_fetch` methods to
# control how we batch/bulk fetch events from the database.
# The values are plucked out of thin air to make initial sync run faster
# on jki.re
# TODO: Make these configurable.
EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events
EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events
EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
|
||||
|
||||
class EventsStore(SQLBaseStore):
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
@@ -57,32 +41,20 @@ class EventsStore(SQLBaseStore):
|
||||
self.min_token -= 1
|
||||
stream_ordering = self.min_token
|
||||
|
||||
if stream_ordering is None:
|
||||
stream_ordering_manager = yield self._stream_id_gen.get_next(self)
|
||||
else:
|
||||
@contextmanager
|
||||
def stream_ordering_manager():
|
||||
yield stream_ordering
|
||||
stream_ordering_manager = stream_ordering_manager()
|
||||
|
||||
try:
|
||||
with stream_ordering_manager as stream_ordering:
|
||||
yield self.runInteraction(
|
||||
"persist_event",
|
||||
self._persist_event_txn,
|
||||
event=event,
|
||||
context=context,
|
||||
backfilled=backfilled,
|
||||
stream_ordering=stream_ordering,
|
||||
is_new_state=is_new_state,
|
||||
current_state=current_state,
|
||||
)
|
||||
yield self.runInteraction(
|
||||
"persist_event",
|
||||
self._persist_event_txn,
|
||||
event=event,
|
||||
context=context,
|
||||
backfilled=backfilled,
|
||||
stream_ordering=stream_ordering,
|
||||
is_new_state=is_new_state,
|
||||
current_state=current_state,
|
||||
)
|
||||
except _RollbackButIsFineException:
|
||||
pass
|
||||
|
||||
max_persisted_id = yield self._stream_id_gen.get_max_token(self)
|
||||
defer.returnValue((stream_ordering, max_persisted_id))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_event(self, event_id, check_redacted=True,
|
||||
get_prev_content=False, allow_rejected=False,
|
||||
@@ -102,17 +74,18 @@ class EventsStore(SQLBaseStore):
|
||||
Returns:
|
||||
Deferred : A FrozenEvent.
|
||||
"""
|
||||
events = yield self._get_events(
|
||||
[event_id],
|
||||
event = yield self.runInteraction(
|
||||
"get_event", self._get_event_txn,
|
||||
event_id,
|
||||
check_redacted=check_redacted,
|
||||
get_prev_content=get_prev_content,
|
||||
allow_rejected=allow_rejected,
|
||||
)
|
||||
|
||||
if not events and not allow_none:
|
||||
if not event and not allow_none:
|
||||
raise RuntimeError("Could not find event %s" % (event_id,))
|
||||
|
||||
defer.returnValue(events[0] if events else None)
|
||||
defer.returnValue(event)
|
||||
|
||||
@log_function
|
||||
def _persist_event_txn(self, txn, event, context, backfilled,
|
||||
@@ -120,17 +93,20 @@ class EventsStore(SQLBaseStore):
|
||||
current_state=None):
|
||||
|
||||
# Remove any existing cache entries for the event_id
|
||||
txn.call_after(self._invalidate_get_event_cache, event.event_id)
|
||||
self._invalidate_get_event_cache(event.event_id)
|
||||
|
||||
if stream_ordering is None:
|
||||
with self._stream_id_gen.get_next_txn(txn) as stream_ordering:
|
||||
return self._persist_event_txn(
|
||||
txn, event, context, backfilled,
|
||||
stream_ordering=stream_ordering,
|
||||
is_new_state=is_new_state,
|
||||
current_state=current_state,
|
||||
)
|
||||
|
||||
# We purposefully do this first since if we include a `current_state`
|
||||
# key, we *want* to update the `current_state_events` table
|
||||
if current_state:
|
||||
txn.call_after(self.get_current_state_for_key.invalidate_all)
|
||||
txn.call_after(self.get_rooms_for_user.invalidate_all)
|
||||
txn.call_after(self.get_users_in_room.invalidate, event.room_id)
|
||||
txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
|
||||
txn.call_after(self.get_room_name_and_aliases, event.room_id)
|
||||
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="current_state_events",
|
||||
@@ -146,29 +122,52 @@ class EventsStore(SQLBaseStore):
|
||||
"room_id": s.room_id,
|
||||
"type": s.type,
|
||||
"state_key": s.state_key,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
if event.is_state() and is_new_state:
|
||||
if not backfilled and not context.rejected:
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="state_forward_extremities",
|
||||
values={
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"type": event.type,
|
||||
"state_key": event.state_key,
|
||||
},
|
||||
)
|
||||
|
||||
for prev_state_id, _ in event.prev_state:
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="state_forward_extremities",
|
||||
keyvalues={
|
||||
"event_id": prev_state_id,
|
||||
}
|
||||
)
|
||||
|
||||
outlier = event.internal_metadata.is_outlier()
|
||||
|
||||
if not outlier:
|
||||
self._store_state_groups_txn(txn, event, context)
|
||||
|
||||
self._update_min_depth_for_room_txn(
|
||||
txn,
|
||||
event.room_id,
|
||||
event.depth
|
||||
)
|
||||
|
||||
have_persisted = self._simple_select_one_txn(
|
||||
have_persisted = self._simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="events",
|
||||
table="event_json",
|
||||
keyvalues={"event_id": event.event_id},
|
||||
retcols=["event_id", "outlier"],
|
||||
retcol="event_id",
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
metadata_json = encode_json(
|
||||
event.internal_metadata.get_dict(),
|
||||
using_frozen_dicts=USE_FROZEN_DICTS
|
||||
metadata_json = encode_canonical_json(
|
||||
event.internal_metadata.get_dict()
|
||||
).decode("UTF-8")
|
||||
|
||||
# If we have already persisted this event, we don't need to do any
|
||||
@@ -178,9 +177,7 @@ class EventsStore(SQLBaseStore):
|
||||
# if we are persisting an event that we had persisted as an outlier,
|
||||
# but is no longer one.
|
||||
if have_persisted:
|
||||
if not outlier and have_persisted["outlier"]:
|
||||
self._store_state_groups_txn(txn, event, context)
|
||||
|
||||
if not outlier:
|
||||
sql = (
|
||||
"UPDATE event_json SET internal_metadata = ?"
|
||||
" WHERE event_id = ?"
|
||||
@@ -200,9 +197,6 @@ class EventsStore(SQLBaseStore):
|
||||
)
|
||||
return
|
||||
|
||||
if not outlier:
|
||||
self._store_state_groups_txn(txn, event, context)
|
||||
|
||||
self._handle_prev_events(
|
||||
txn,
|
||||
outlier=outlier,
|
||||
@@ -236,14 +230,12 @@ class EventsStore(SQLBaseStore):
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"internal_metadata": metadata_json,
|
||||
"json": encode_json(
|
||||
event_dict, using_frozen_dicts=USE_FROZEN_DICTS
|
||||
).decode("UTF-8"),
|
||||
"json": encode_canonical_json(event_dict).decode("UTF-8"),
|
||||
},
|
||||
)
|
||||
|
||||
content = encode_json(
|
||||
event.content, using_frozen_dicts=USE_FROZEN_DICTS
|
||||
content = encode_canonical_json(
|
||||
event.content
|
||||
).decode("UTF-8")
|
||||
|
||||
vals = {
|
||||
@@ -269,8 +261,8 @@ class EventsStore(SQLBaseStore):
|
||||
]
|
||||
}
|
||||
|
||||
vals["unrecognized_keys"] = encode_json(
|
||||
unrec, using_frozen_dicts=USE_FROZEN_DICTS
|
||||
vals["unrecognized_keys"] = encode_canonical_json(
|
||||
unrec
|
||||
).decode("UTF-8")
|
||||
|
||||
sql = (
|
||||
@@ -289,9 +281,7 @@ class EventsStore(SQLBaseStore):
|
||||
)
|
||||
|
||||
if context.rejected:
|
||||
self._store_rejections_txn(
|
||||
txn, event.event_id, context.rejected
|
||||
)
|
||||
self._store_rejections_txn(txn, event.event_id, context.rejected)
|
||||
|
||||
for hash_alg, hash_base64 in event.hashes.items():
|
||||
hash_bytes = decode_base64(hash_base64)
|
||||
@@ -303,22 +293,19 @@ class EventsStore(SQLBaseStore):
|
||||
for alg, hash_base64 in prev_hashes.items():
|
||||
hash_bytes = decode_base64(hash_base64)
|
||||
self._store_prev_event_hash_txn(
|
||||
txn, event.event_id, prev_event_id, alg,
|
||||
hash_bytes
|
||||
txn, event.event_id, prev_event_id, alg, hash_bytes
|
||||
)
|
||||
|
||||
self._simple_insert_many_txn(
|
||||
txn,
|
||||
table="event_auth",
|
||||
values=[
|
||||
{
|
||||
for auth_id, _ in event.auth_events:
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="event_auth",
|
||||
values={
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"auth_id": auth_id,
|
||||
}
|
||||
for auth_id, _ in event.auth_events
|
||||
],
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
(ref_alg, ref_hash_bytes) = compute_event_reference_hash(event)
|
||||
self._store_event_reference_hash_txn(
|
||||
@@ -343,33 +330,19 @@ class EventsStore(SQLBaseStore):
|
||||
vals,
|
||||
)
|
||||
|
||||
self._simple_insert_many_txn(
|
||||
txn,
|
||||
table="event_edges",
|
||||
values=[
|
||||
{
|
||||
for e_id, h in event.prev_state:
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="event_edges",
|
||||
values={
|
||||
"event_id": event.event_id,
|
||||
"prev_event_id": e_id,
|
||||
"room_id": event.room_id,
|
||||
"is_state": True,
|
||||
}
|
||||
for e_id, h in event.prev_state
|
||||
],
|
||||
)
|
||||
|
||||
if is_new_state and not context.rejected:
|
||||
txn.call_after(
|
||||
self.get_current_state_for_key.invalidate,
|
||||
event.room_id, event.type, event.state_key
|
||||
},
|
||||
)
|
||||
|
||||
if (event.type == EventTypes.Name
|
||||
or event.type == EventTypes.Aliases):
|
||||
txn.call_after(
|
||||
self.get_room_name_and_aliases.invalidate,
|
||||
event.room_id
|
||||
)
|
||||
|
||||
if is_new_state and not context.rejected:
|
||||
self._simple_upsert_txn(
|
||||
txn,
|
||||
"current_state_events",
|
||||
@@ -383,11 +356,9 @@ class EventsStore(SQLBaseStore):
|
||||
}
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
def _store_redaction(self, txn, event):
|
||||
# invalidate the cache for the redacted event
|
||||
txn.call_after(self._invalidate_get_event_cache, event.redacts)
|
||||
self._invalidate_get_event_cache(event.redacts)
|
||||
txn.execute(
|
||||
"INSERT INTO redactions (event_id, redacts) VALUES (?,?)",
|
||||
(event.event_id, event.redacts)
|
||||
@@ -424,409 +395,3 @@ class EventsStore(SQLBaseStore):
|
||||
return self.runInteraction(
|
||||
"have_events", f,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_events(self, event_ids, check_redacted=True,
|
||||
get_prev_content=False, allow_rejected=False):
|
||||
if not event_ids:
|
||||
defer.returnValue([])
|
||||
|
||||
event_map = self._get_events_from_cache(
|
||||
event_ids,
|
||||
check_redacted=check_redacted,
|
||||
get_prev_content=get_prev_content,
|
||||
allow_rejected=allow_rejected,
|
||||
)
|
||||
|
||||
missing_events_ids = [e for e in event_ids if e not in event_map]
|
||||
|
||||
if not missing_events_ids:
|
||||
defer.returnValue([
|
||||
event_map[e_id] for e_id in event_ids
|
||||
if e_id in event_map and event_map[e_id]
|
||||
])
|
||||
|
||||
missing_events = yield self._enqueue_events(
|
||||
missing_events_ids,
|
||||
check_redacted=check_redacted,
|
||||
get_prev_content=get_prev_content,
|
||||
allow_rejected=allow_rejected,
|
||||
)
|
||||
|
||||
event_map.update(missing_events)
|
||||
|
||||
defer.returnValue([
|
||||
event_map[e_id] for e_id in event_ids
|
||||
if e_id in event_map and event_map[e_id]
|
||||
])
|
||||
|
||||
def _get_events_txn(self, txn, event_ids, check_redacted=True,
|
||||
get_prev_content=False, allow_rejected=False):
|
||||
if not event_ids:
|
||||
return []
|
||||
|
||||
event_map = self._get_events_from_cache(
|
||||
event_ids,
|
||||
check_redacted=check_redacted,
|
||||
get_prev_content=get_prev_content,
|
||||
allow_rejected=allow_rejected,
|
||||
)
|
||||
|
||||
missing_events_ids = [e for e in event_ids if e not in event_map]
|
||||
|
||||
if not missing_events_ids:
|
||||
return [
|
||||
event_map[e_id] for e_id in event_ids
|
||||
if e_id in event_map and event_map[e_id]
|
||||
]
|
||||
|
||||
missing_events = self._fetch_events_txn(
|
||||
txn,
|
||||
missing_events_ids,
|
||||
check_redacted=check_redacted,
|
||||
get_prev_content=get_prev_content,
|
||||
allow_rejected=allow_rejected,
|
||||
)
|
||||
|
||||
event_map.update(missing_events)
|
||||
|
||||
return [
|
||||
event_map[e_id] for e_id in event_ids
|
||||
if e_id in event_map and event_map[e_id]
|
||||
]
|
||||
|
||||
def _invalidate_get_event_cache(self, event_id):
|
||||
for check_redacted in (False, True):
|
||||
for get_prev_content in (False, True):
|
||||
self._get_event_cache.invalidate(event_id, check_redacted,
|
||||
get_prev_content)
|
||||
|
||||
def _get_event_txn(self, txn, event_id, check_redacted=True,
|
||||
get_prev_content=False, allow_rejected=False):
|
||||
|
||||
events = self._get_events_txn(
|
||||
txn, [event_id],
|
||||
check_redacted=check_redacted,
|
||||
get_prev_content=get_prev_content,
|
||||
allow_rejected=allow_rejected,
|
||||
)
|
||||
|
||||
return events[0] if events else None
|
||||
|
||||
def _get_events_from_cache(self, events, check_redacted, get_prev_content,
|
||||
allow_rejected):
|
||||
event_map = {}
|
||||
|
||||
for event_id in events:
|
||||
try:
|
||||
ret = self._get_event_cache.get(
|
||||
event_id, check_redacted, get_prev_content
|
||||
)
|
||||
|
||||
if allow_rejected or not ret.rejected_reason:
|
||||
event_map[event_id] = ret
|
||||
else:
|
||||
event_map[event_id] = None
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
return event_map
|
||||
|
||||
def _do_fetch(self, conn):
|
||||
"""Takes a database connection and waits for requests for events from
|
||||
the _event_fetch_list queue.
|
||||
"""
|
||||
event_list = []
|
||||
i = 0
|
||||
while True:
|
||||
try:
|
||||
with self._event_fetch_lock:
|
||||
event_list = self._event_fetch_list
|
||||
self._event_fetch_list = []
|
||||
|
||||
if not event_list:
|
||||
single_threaded = self.database_engine.single_threaded
|
||||
if single_threaded or i > EVENT_QUEUE_ITERATIONS:
|
||||
self._event_fetch_ongoing -= 1
|
||||
return
|
||||
else:
|
||||
self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
|
||||
i += 1
|
||||
continue
|
||||
i = 0
|
||||
|
||||
event_id_lists = zip(*event_list)[0]
|
||||
event_ids = [
|
||||
item for sublist in event_id_lists for item in sublist
|
||||
]
|
||||
|
||||
rows = self._new_transaction(
|
||||
conn, "do_fetch", [], self._fetch_event_rows, event_ids
|
||||
)
|
||||
|
||||
row_dict = {
|
||||
r["event_id"]: r
|
||||
for r in rows
|
||||
}
|
||||
|
||||
# We only want to resolve deferreds from the main thread
|
||||
def fire(lst, res):
|
||||
for ids, d in lst:
|
||||
if not d.called:
|
||||
try:
|
||||
d.callback([
|
||||
res[i]
|
||||
for i in ids
|
||||
if i in res
|
||||
])
|
||||
except:
|
||||
logger.exception("Failed to callback")
|
||||
reactor.callFromThread(fire, event_list, row_dict)
|
||||
except Exception as e:
|
||||
logger.exception("do_fetch")
|
||||
|
||||
# We only want to resolve deferreds from the main thread
|
||||
def fire(evs):
|
||||
for _, d in evs:
|
||||
if not d.called:
|
||||
d.errback(e)
|
||||
|
||||
if event_list:
|
||||
reactor.callFromThread(fire, event_list)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _enqueue_events(self, events, check_redacted=True,
|
||||
get_prev_content=False, allow_rejected=False):
|
||||
"""Fetches events from the database using the _event_fetch_list. This
|
||||
allows batch and bulk fetching of events - it allows us to fetch events
|
||||
without having to create a new transaction for each request for events.
|
||||
"""
|
||||
if not events:
|
||||
defer.returnValue({})
|
||||
|
||||
events_d = defer.Deferred()
|
||||
with self._event_fetch_lock:
|
||||
self._event_fetch_list.append(
|
||||
(events, events_d)
|
||||
)
|
||||
|
||||
self._event_fetch_lock.notify()
|
||||
|
||||
if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
|
||||
self._event_fetch_ongoing += 1
|
||||
should_start = True
|
||||
else:
|
||||
should_start = False
|
||||
|
||||
if should_start:
|
||||
self.runWithConnection(
|
||||
self._do_fetch
|
||||
)
|
||||
|
||||
rows = yield preserve_context_over_deferred(events_d)
|
||||
|
||||
if not allow_rejected:
|
||||
rows[:] = [r for r in rows if not r["rejects"]]
|
||||
|
||||
res = yield defer.gatherResults(
|
||||
[
|
||||
self._get_event_from_row(
|
||||
row["internal_metadata"], row["json"], row["redacts"],
|
||||
check_redacted=check_redacted,
|
||||
get_prev_content=get_prev_content,
|
||||
rejected_reason=row["rejects"],
|
||||
)
|
||||
for row in rows
|
||||
],
|
||||
consumeErrors=True
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
e.event_id: e
|
||||
for e in res if e
|
||||
})
|
||||
|
||||
def _fetch_event_rows(self, txn, events):
|
||||
rows = []
|
||||
N = 200
|
||||
for i in range(1 + len(events) / N):
|
||||
evs = events[i*N:(i + 1)*N]
|
||||
if not evs:
|
||||
break
|
||||
|
||||
sql = (
|
||||
"SELECT "
|
||||
" e.event_id as event_id, "
|
||||
" e.internal_metadata,"
|
||||
" e.json,"
|
||||
" r.redacts as redacts,"
|
||||
" rej.event_id as rejects "
|
||||
" FROM event_json as e"
|
||||
" LEFT JOIN rejections as rej USING (event_id)"
|
||||
" LEFT JOIN redactions as r ON e.event_id = r.redacts"
|
||||
" WHERE e.event_id IN (%s)"
|
||||
) % (",".join(["?"]*len(evs)),)
|
||||
|
||||
txn.execute(sql, evs)
|
||||
rows.extend(self.cursor_to_dict(txn))
|
||||
|
||||
return rows
|
||||
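``_fetch_event_rows`` above pulls events in slices of at most 200 ids, building a ``WHERE event_id IN (?, ?, ...)`` clause whose placeholder count matches the slice, so statement size and parameter count stay bounded. The chunking pattern on its own, as a generic sketch::

    def fetch_in_chunks(cur, table, ids, chunk_size=200):
        """Select rows whose event_id is in `ids`, one query per chunk."""
        rows = []
        for start in range(0, len(ids), chunk_size):
            chunk = ids[start:start + chunk_size]
            placeholders = ",".join("?" for _ in chunk)
            # Only the placeholder list is interpolated; the values themselves
            # are still bound as parameters.
            cur.execute(
                "SELECT * FROM %s WHERE event_id IN (%s)" % (table, placeholders),
                chunk,
            )
            rows.extend(cur.fetchall())
        return rows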
|
||||
def _fetch_events_txn(self, txn, events, check_redacted=True,
|
||||
get_prev_content=False, allow_rejected=False):
|
||||
if not events:
|
||||
return {}
|
||||
|
||||
rows = self._fetch_event_rows(
|
||||
txn, events,
|
||||
)
|
||||
|
||||
if not allow_rejected:
|
||||
rows[:] = [r for r in rows if not r["rejects"]]
|
||||
|
||||
res = [
|
||||
self._get_event_from_row_txn(
|
||||
txn,
|
||||
row["internal_metadata"], row["json"], row["redacts"],
|
||||
check_redacted=check_redacted,
|
||||
get_prev_content=get_prev_content,
|
||||
rejected_reason=row["rejects"],
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
return {
|
||||
r.event_id: r
|
||||
for r in res
|
||||
}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_event_from_row(self, internal_metadata, js, redacted,
|
||||
check_redacted=True, get_prev_content=False,
|
||||
rejected_reason=None):
|
||||
d = json.loads(js)
|
||||
internal_metadata = json.loads(internal_metadata)
|
||||
|
||||
if rejected_reason:
|
||||
rejected_reason = yield self._simple_select_one_onecol(
|
||||
table="rejections",
|
||||
keyvalues={"event_id": rejected_reason},
|
||||
retcol="reason",
|
||||
desc="_get_event_from_row",
|
||||
)
|
||||
|
||||
ev = FrozenEvent(
|
||||
d,
|
||||
internal_metadata_dict=internal_metadata,
|
||||
rejected_reason=rejected_reason,
|
||||
)
|
||||
|
||||
if check_redacted and redacted:
|
||||
ev = prune_event(ev)
|
||||
|
||||
redaction_id = yield self._simple_select_one_onecol(
|
||||
table="redactions",
|
||||
keyvalues={"redacts": ev.event_id},
|
||||
retcol="event_id",
|
||||
desc="_get_event_from_row",
|
||||
)
|
||||
|
||||
ev.unsigned["redacted_by"] = redaction_id
|
||||
# Get the redaction event.
|
||||
|
||||
because = yield self.get_event(
|
||||
redaction_id,
|
||||
check_redacted=False,
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
if because:
|
||||
ev.unsigned["redacted_because"] = because
|
||||
|
||||
if get_prev_content and "replaces_state" in ev.unsigned:
|
||||
prev = yield self.get_event(
|
||||
ev.unsigned["replaces_state"],
|
||||
get_prev_content=False,
|
||||
allow_none=True,
|
||||
)
|
||||
if prev:
|
||||
ev.unsigned["prev_content"] = prev.get_dict()["content"]
|
||||
|
||||
self._get_event_cache.prefill(
|
||||
ev.event_id, check_redacted, get_prev_content, ev
|
||||
)
|
||||
|
||||
defer.returnValue(ev)
|
||||
|
||||
def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted,
|
||||
check_redacted=True, get_prev_content=False,
|
||||
rejected_reason=None):
|
||||
d = json.loads(js)
|
||||
internal_metadata = json.loads(internal_metadata)
|
||||
|
||||
if rejected_reason:
|
||||
rejected_reason = self._simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="rejections",
|
||||
keyvalues={"event_id": rejected_reason},
|
||||
retcol="reason",
|
||||
)
|
||||
|
||||
ev = FrozenEvent(
|
||||
d,
|
||||
internal_metadata_dict=internal_metadata,
|
||||
rejected_reason=rejected_reason,
|
||||
)
|
||||
|
||||
if check_redacted and redacted:
|
||||
ev = prune_event(ev)
|
||||
|
||||
redaction_id = self._simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="redactions",
|
||||
keyvalues={"redacts": ev.event_id},
|
||||
retcol="event_id",
|
||||
)
|
||||
|
||||
ev.unsigned["redacted_by"] = redaction_id
|
||||
# Get the redaction event.
|
||||
|
||||
because = self._get_event_txn(
|
||||
txn,
|
||||
redaction_id,
|
||||
check_redacted=False
|
||||
)
|
||||
|
||||
if because:
|
||||
ev.unsigned["redacted_because"] = because
|
||||
|
||||
if get_prev_content and "replaces_state" in ev.unsigned:
|
||||
prev = self._get_event_txn(
|
||||
txn,
|
||||
ev.unsigned["replaces_state"],
|
||||
get_prev_content=False,
|
||||
)
|
||||
if prev:
|
||||
ev.unsigned["prev_content"] = prev.get_dict()["content"]
|
||||
|
||||
self._get_event_cache.prefill(
|
||||
ev.event_id, check_redacted, get_prev_content, ev
|
||||
)
|
||||
|
||||
return ev
|
||||
|
||||
def _parse_events(self, rows):
|
||||
return self.runInteraction(
|
||||
"_parse_events", self._parse_events_txn, rows
|
||||
)
|
||||
|
||||
def _parse_events_txn(self, txn, rows):
|
||||
event_ids = [r["event_id"] for r in rows]
|
||||
|
||||
return self._get_events_txn(txn, event_ids)
|
||||
|
||||
def _has_been_redacted_txn(self, txn, event):
|
||||
sql = "SELECT event_id FROM redactions WHERE redacts = ?"
|
||||
txn.execute(sql, (event.event_id,))
|
||||
result = txn.fetchone()
|
||||
return result[0] if result else None
|
||||
|
||||
@@ -13,9 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import SQLBaseStore, cached
|
||||
|
||||
from twisted.internet import defer
|
||||
from ._base import SQLBaseStore
|
||||
|
||||
|
||||
class PresenceStore(SQLBaseStore):
|
||||
@@ -89,48 +87,31 @@ class PresenceStore(SQLBaseStore):
|
||||
desc="add_presence_list_pending",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_presence_list_accepted(self, observer_localpart, observed_userid):
|
||||
result = yield self._simple_update_one(
|
||||
return self._simple_update_one(
|
||||
table="presence_list",
|
||||
keyvalues={"user_id": observer_localpart,
|
||||
"observed_user_id": observed_userid},
|
||||
updatevalues={"accepted": True},
|
||||
desc="set_presence_list_accepted",
|
||||
)
|
||||
self.get_presence_list_accepted.invalidate(observer_localpart)
|
||||
defer.returnValue(result)
|
||||
|
||||
def get_presence_list(self, observer_localpart, accepted=None):
|
||||
if accepted:
|
||||
return self.get_presence_list_accepted(observer_localpart)
|
||||
else:
|
||||
keyvalues = {"user_id": observer_localpart}
|
||||
if accepted is not None:
|
||||
keyvalues["accepted"] = accepted
|
||||
keyvalues = {"user_id": observer_localpart}
|
||||
if accepted is not None:
|
||||
keyvalues["accepted"] = accepted
|
||||
|
||||
return self._simple_select_list(
|
||||
table="presence_list",
|
||||
keyvalues=keyvalues,
|
||||
retcols=["observed_user_id", "accepted"],
|
||||
desc="get_presence_list",
|
||||
)
|
||||
|
||||
@cached()
|
||||
def get_presence_list_accepted(self, observer_localpart):
|
||||
return self._simple_select_list(
|
||||
table="presence_list",
|
||||
keyvalues={"user_id": observer_localpart, "accepted": True},
|
||||
keyvalues=keyvalues,
|
||||
retcols=["observed_user_id", "accepted"],
|
||||
desc="get_presence_list_accepted",
|
||||
desc="get_presence_list",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def del_presence_list(self, observer_localpart, observed_userid):
|
||||
yield self._simple_delete_one(
|
||||
return self._simple_delete_one(
|
||||
table="presence_list",
|
||||
keyvalues={"user_id": observer_localpart,
|
||||
"observed_user_id": observed_userid},
|
||||
desc="del_presence_list",
|
||||
)
|
||||
self.get_presence_list_accepted.invalidate(observer_localpart)
|
||||
|
||||
@@ -13,60 +13,61 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import SQLBaseStore, cached
|
||||
import collections
|
||||
|
||||
from ._base import SQLBaseStore, Table
|
||||
from twisted.internet import defer
|
||||
|
||||
import logging
|
||||
import copy
|
||||
import simplejson as json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PushRuleStore(SQLBaseStore):
|
||||
@cached()
|
||||
@defer.inlineCallbacks
|
||||
def get_push_rules_for_user(self, user_name):
|
||||
rows = yield self._simple_select_list(
|
||||
table=PushRuleTable.table_name,
|
||||
keyvalues={
|
||||
"user_name": user_name,
|
||||
},
|
||||
retcols=PushRuleTable.fields,
|
||||
desc="get_push_rules_enabled_for_user",
|
||||
sql = (
|
||||
"SELECT "+",".join(PushRuleTable.fields)+" "
|
||||
"FROM "+PushRuleTable.table_name+" "
|
||||
"WHERE user_name = ? "
|
||||
"ORDER BY priority_class DESC, priority DESC"
|
||||
)
|
||||
rows = yield self._execute("get_push_rules_for_user", None, sql, user_name)
|
||||
|
||||
rows.sort(
|
||||
key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))
|
||||
)
|
||||
dicts = []
|
||||
for r in rows:
|
||||
d = {}
|
||||
for i, f in enumerate(PushRuleTable.fields):
|
||||
d[f] = r[i]
|
||||
dicts.append(d)
|
||||
|
||||
defer.returnValue(rows)
|
||||
defer.returnValue(dicts)
|
||||
|
||||
@cached()
|
||||
@defer.inlineCallbacks
|
||||
def get_push_rules_enabled_for_user(self, user_name):
|
||||
results = yield self._simple_select_list(
|
||||
table=PushRuleEnableTable.table_name,
|
||||
keyvalues={
|
||||
'user_name': user_name
|
||||
},
|
||||
retcols=PushRuleEnableTable.fields,
|
||||
PushRuleEnableTable.table_name,
|
||||
{'user_name': user_name},
|
||||
PushRuleEnableTable.fields,
|
||||
desc="get_push_rules_enabled_for_user",
|
||||
)
|
||||
defer.returnValue({
|
||||
r['rule_id']: False if r['enabled'] == 0 else True for r in results
|
||||
})
|
||||
defer.returnValue(
|
||||
{r['rule_id']: False if r['enabled'] == 0 else True for r in results}
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_push_rule(self, before, after, **kwargs):
|
||||
vals = kwargs
|
||||
vals = copy.copy(kwargs)
|
||||
if 'conditions' in vals:
|
||||
vals['conditions'] = json.dumps(vals['conditions'])
|
||||
if 'actions' in vals:
|
||||
vals['actions'] = json.dumps(vals['actions'])
|
||||
|
||||
# we could check the rest of the keys are valid column names
|
||||
# but sqlite will do that anyway so I think it's just pointless.
|
||||
vals.pop("id", None)
|
||||
if 'id' in vals:
|
||||
del vals['id']
|
||||
|
||||
if before or after:
|
||||
ret = yield self.runInteraction(
|
||||
@@ -86,39 +87,39 @@ class PushRuleStore(SQLBaseStore):
|
||||
defer.returnValue(ret)
|
||||
|
||||
def _add_push_rule_relative_txn(self, txn, user_name, **kwargs):
|
||||
after = kwargs.pop("after", None)
|
||||
relative_to_rule = kwargs.pop("before", after)
|
||||
after = None
|
||||
relative_to_rule = None
|
||||
if 'after' in kwargs and kwargs['after']:
|
||||
after = kwargs['after']
|
||||
relative_to_rule = after
|
||||
if 'before' in kwargs and kwargs['before']:
|
||||
relative_to_rule = kwargs['before']
|
||||
|
||||
res = self._simple_select_one_txn(
|
||||
txn,
|
||||
table=PushRuleTable.table_name,
|
||||
keyvalues={
|
||||
"user_name": user_name,
|
||||
"rule_id": relative_to_rule,
|
||||
},
|
||||
retcols=["priority_class", "priority"],
|
||||
allow_none=True,
|
||||
# get the priority of the rule we're inserting after/before
|
||||
sql = (
|
||||
"SELECT priority_class, priority FROM ? "
|
||||
"WHERE user_name = ? and rule_id = ?" % (PushRuleTable.table_name,)
|
||||
)
|
||||
|
||||
txn.execute(sql, (user_name, relative_to_rule))
|
||||
res = txn.fetchall()
|
||||
if not res:
|
||||
raise RuleNotFoundException(
|
||||
"before/after rule not found: %s" % (relative_to_rule,)
|
||||
)
|
||||
|
||||
priority_class = res["priority_class"]
|
||||
base_rule_priority = res["priority"]
|
||||
priority_class, base_rule_priority = res[0]
|
||||
|
||||
if 'priority_class' in kwargs and kwargs['priority_class'] != priority_class:
|
||||
raise InconsistentRuleException(
|
||||
"Given priority class does not match class of relative rule"
|
||||
)
|
||||
|
||||
new_rule = kwargs
|
||||
new_rule.pop("before", None)
|
||||
new_rule.pop("after", None)
|
||||
new_rule = copy.copy(kwargs)
|
||||
if 'before' in new_rule:
|
||||
del new_rule['before']
|
||||
if 'after' in new_rule:
|
||||
del new_rule['after']
|
||||
new_rule['priority_class'] = priority_class
|
||||
new_rule['user_name'] = user_name
|
||||
new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn)
|
||||
|
||||
# check if the priority before/after is free
|
||||
new_rule_priority = base_rule_priority
|
||||
@@ -152,19 +153,12 @@ class PushRuleStore(SQLBaseStore):
|
||||
|
||||
txn.execute(sql, (user_name, priority_class, new_rule_priority))
|
||||
|
||||
txn.call_after(
|
||||
self.get_push_rules_for_user.invalidate, user_name
|
||||
)
|
||||
# now insert the new rule
|
||||
sql = "INSERT INTO "+PushRuleTable.table_name+" ("
|
||||
sql += ",".join(new_rule.keys())+") VALUES ("
|
||||
sql += ", ".join(["?" for _ in new_rule.keys()])+")"
|
||||
|
||||
txn.call_after(
|
||||
self.get_push_rules_enabled_for_user.invalidate, user_name
|
||||
)
|
||||
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table=PushRuleTable.table_name,
|
||||
values=new_rule,
|
||||
)
|
||||
txn.execute(sql, new_rule.values())
|
||||
|
||||
def _add_push_rule_highest_priority_txn(self, txn, user_name,
|
||||
priority_class, **kwargs):
|
||||
@@ -182,24 +176,18 @@ class PushRuleStore(SQLBaseStore):
|
||||
new_prio = highest_prio + 1
|
||||
|
||||
# and insert the new rule
|
||||
new_rule = kwargs
|
||||
new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn)
|
||||
new_rule = copy.copy(kwargs)
|
||||
if 'id' in new_rule:
|
||||
del new_rule['id']
|
||||
new_rule['user_name'] = user_name
|
||||
new_rule['priority_class'] = priority_class
|
||||
new_rule['priority'] = new_prio
|
||||
|
||||
txn.call_after(
|
||||
self.get_push_rules_for_user.invalidate, user_name
|
||||
)
|
||||
txn.call_after(
|
||||
self.get_push_rules_enabled_for_user.invalidate, user_name
|
||||
)
|
||||
sql = "INSERT INTO "+PushRuleTable.table_name+" ("
|
||||
sql += ",".join(new_rule.keys())+") VALUES ("
|
||||
sql += ", ".join(["?" for _ in new_rule.keys()])+")"
|
||||
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table=PushRuleTable.table_name,
|
||||
values=new_rule,
|
||||
)
|
||||
txn.execute(sql, new_rule.values())
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_push_rule(self, user_name, rule_id):
|
||||
@@ -218,32 +206,13 @@ class PushRuleStore(SQLBaseStore):
|
||||
desc="delete_push_rule",
|
||||
)
|
||||
|
||||
self.get_push_rules_for_user.invalidate(user_name)
|
||||
self.get_push_rules_enabled_for_user.invalidate(user_name)
|
||||
|
||||
@defer.inlineCallbacks
def set_push_rule_enabled(self, user_name, rule_id, enabled):
ret = yield self.runInteraction(
"_set_push_rule_enabled_txn",
self._set_push_rule_enabled_txn,
user_name, rule_id, enabled
)
defer.returnValue(ret)

def _set_push_rule_enabled_txn(self, txn, user_name, rule_id, enabled):
new_id = self._push_rules_enable_id_gen.get_next_txn(txn)
self._simple_upsert_txn(
txn,
yield self._simple_upsert(
PushRuleEnableTable.table_name,
{'user_name': user_name, 'rule_id': rule_id},
{'enabled': 1 if enabled else 0},
{'id': new_id},
)
txn.call_after(
self.get_push_rules_for_user.invalidate, user_name
)
txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, user_name
{'enabled': enabled},
desc="set_push_rule_enabled",
)
|
||||
|
||||
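The replacement ``_set_push_rule_enabled_txn`` records the flag with ``_simple_upsert_txn`` (update the row keyed by ``user_name`` and ``rule_id`` if it exists, otherwise insert it with a freshly generated id) and then invalidates the per-user push rule caches via ``txn.call_after``. A sketch of that update-then-insert shape over plain DB-API calls; the table and column names are taken from the hunk, the helper itself is hypothetical::

    def upsert_push_rule_enabled(cur, user_name, rule_id, enabled, new_id):
        """Update the enable flag if the row exists, otherwise insert it."""
        cur.execute(
            "UPDATE push_rules_enable SET enabled = ?"
            " WHERE user_name = ? AND rule_id = ?",
            (1 if enabled else 0, user_name, rule_id),
        )
        if cur.rowcount == 0:
            # No existing row: insert one, including the generated id the
            # original code obtains from its stream id generator.
            cur.execute(
                "INSERT INTO push_rules_enable (id, user_name, rule_id, enabled)"
                " VALUES (?, ?, ?, ?)",
                (new_id, user_name, rule_id, 1 if enabled else 0),
            )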
|
||||
@@ -255,7 +224,7 @@ class InconsistentRuleException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class PushRuleTable(object):
|
||||
class PushRuleTable(Table):
|
||||
table_name = "push_rules"
|
||||
|
||||
fields = [
|
||||
@@ -268,8 +237,10 @@ class PushRuleTable(object):
|
||||
"actions",
|
||||
]
|
||||
|
||||
EntryType = collections.namedtuple("PushRuleEntry", fields)
|
||||
|
||||
class PushRuleEnableTable(object):
|
||||
|
||||
class PushRuleEnableTable(Table):
|
||||
table_name = "push_rules_enable"
|
||||
|
||||
fields = [
|
||||
|
||||
@@ -112,15 +112,14 @@ class RegistrationStore(SQLBaseStore):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def user_delete_access_tokens_apart_from(self, user_id, token_id):
|
||||
yield self.runInteraction(
|
||||
"user_delete_access_tokens_apart_from",
|
||||
self._user_delete_access_tokens_apart_from, user_id, token_id
|
||||
)
|
||||
rows = yield self.get_user_by_id(user_id)
|
||||
if len(rows) == 0:
|
||||
raise Exception("No such user!")
|
||||
|
||||
def _user_delete_access_tokens_apart_from(self, txn, user_id, token_id):
|
||||
txn.execute(
|
||||
yield self._execute(
|
||||
"delete_access_tokens_apart_from", None,
|
||||
"DELETE FROM access_tokens WHERE user_id = ? AND id != ?",
|
||||
(user_id, token_id)
|
||||
rows[0]['id'], token_id
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@@ -202,15 +201,15 @@ class RegistrationStore(SQLBaseStore):
|
||||
defer.returnValue(ret)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_user_id_by_threepid(self, medium, address):
|
||||
def get_user_by_threepid(self, medium, address):
|
||||
ret = yield self._simple_select_one(
|
||||
"user_threepids",
|
||||
{
|
||||
"medium": medium,
|
||||
"address": address
|
||||
},
|
||||
['user_id'], True, 'get_user_id_by_threepid'
|
||||
['user'], True, 'get_user_by_threepid'
|
||||
)
|
||||
if ret:
|
||||
defer.returnValue(ret['user_id'])
|
||||
defer.returnValue(ret['user'])
|
||||
defer.returnValue(None)
|
||||
|
||||
@@ -17,7 +17,7 @@ from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import StoreError
|
||||
|
||||
from ._base import SQLBaseStore, cached
|
||||
from ._base import SQLBaseStore
|
||||
|
||||
import collections
|
||||
import logging
|
||||
@@ -186,7 +186,6 @@ class RoomStore(SQLBaseStore):
|
||||
}
|
||||
)
|
||||
|
||||
@cached()
|
||||
@defer.inlineCallbacks
|
||||
def get_room_name_and_aliases(self, room_id):
|
||||
def f(txn):
|
||||
|
||||
@@ -64,9 +64,8 @@ class RoomMemberStore(SQLBaseStore):
|
||||
}
|
||||
)
|
||||
|
||||
txn.call_after(self.get_rooms_for_user.invalidate, target_user_id)
|
||||
txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
|
||||
txn.call_after(self.get_users_in_room.invalidate, event.room_id)
|
||||
self.get_rooms_for_user.invalidate(target_user_id)
|
||||
self.get_joined_hosts_for_room.invalidate(event.room_id)
|
||||
|
||||
def get_room_member(self, user_id, room_id):
|
||||
"""Retrieve the current state of a room member.
|
||||
@@ -77,18 +76,17 @@ class RoomMemberStore(SQLBaseStore):
|
||||
Returns:
|
||||
Deferred: Results in a MembershipEvent or None.
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"get_room_member",
|
||||
self._get_members_events_txn,
|
||||
room_id,
|
||||
user_id=user_id,
|
||||
).addCallback(
|
||||
self._get_events
|
||||
).addCallback(
|
||||
lambda events: events[0] if events else None
|
||||
)
|
||||
def f(txn):
|
||||
events = self._get_members_events_txn(
|
||||
txn,
|
||||
room_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
return events[0] if events else None
|
||||
|
||||
return self.runInteraction("get_room_member", f)

@cached()
def get_users_in_room(self, room_id):
def f(txn):

@@ -112,12 +110,15 @@ class RoomMemberStore(SQLBaseStore):
Returns:
list of namedtuples representing the members in this room.
"""
return self.runInteraction(
"get_room_members",
self._get_members_events_txn,
room_id,
membership=membership,
).addCallback(self._get_events)

def f(txn):
return self._get_members_events_txn(
txn,
room_id,
membership=membership,
)

return self.runInteraction("get_room_members", f)

def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
""" Get all the rooms for this user where the membership for this user
@@ -189,14 +190,14 @@ class RoomMemberStore(SQLBaseStore):
return self.runInteraction(
"get_members_query", self._get_members_events_txn,
where_clause, where_values
).addCallbacks(self._get_events)
)

def _get_members_events_txn(self, txn, room_id, membership=None, user_id=None):
rows = self._get_members_rows_txn(
txn,
room_id, membership, user_id,
)
return [r["event_id"] for r in rows]
return self._get_events_txn(txn, [r["event_id"] for r in rows])

def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None):
where_clause = "c.room_id = ?"

@@ -18,7 +18,7 @@ import logging
logger = logging.getLogger(__name__)


def run_upgrade(cur, *args, **kwargs):
def run_upgrade(cur):
cur.execute("SELECT id, regex FROM application_services_regex")
for row in cur.fetchall():
try:

@@ -1,18 +0,0 @@
/* Copyright 2015 OpenMarket Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

DROP INDEX IF EXISTS sent_transaction_dest;
DROP INDEX IF EXISTS sent_transaction_sent;
DROP INDEX IF EXISTS user_ips_user;
@@ -1,32 +0,0 @@
/* Copyright 2015 OpenMarket Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

CREATE TABLE IF NOT EXISTS new_server_keys_json (
    server_name TEXT NOT NULL, -- Server name.
    key_id TEXT NOT NULL, -- Requested key id.
    from_server TEXT NOT NULL, -- Which server the keys were fetched from.
    ts_added_ms BIGINT NOT NULL, -- When the keys were fetched
    ts_valid_until_ms BIGINT NOT NULL, -- When this version of the keys expires.
    key_json bytea NOT NULL, -- JSON certificate for the remote server.
    CONSTRAINT server_keys_json_uniqueness UNIQUE (server_name, key_id, from_server)
);

INSERT INTO new_server_keys_json
SELECT server_name, key_id, from_server, ts_added_ms, ts_valid_until_ms, key_json FROM server_keys_json;

DROP TABLE server_keys_json;

ALTER TABLE new_server_keys_json RENAME TO server_keys_json;
@@ -1,19 +0,0 @@
/* Copyright 2015 OpenMarket Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */


CREATE INDEX events_order_topo_stream_room ON events(
    topological_ordering, stream_ordering, room_id
);
@@ -1,76 +0,0 @@
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
Main purpose of this upgrade is to change the unique key on the
pushers table again (it was missed when the v16 full schema was
made) but this also changes the pushkey and data columns to text.
When selecting a bytea column into a text column, postgres inserts
the hex encoded data, and there's no portable way of getting the
UTF-8 bytes, so we have to do it in Python.
"""

import logging

logger = logging.getLogger(__name__)


def run_upgrade(cur, database_engine, *args, **kwargs):
    logger.info("Porting pushers table...")
    cur.execute("""
        CREATE TABLE IF NOT EXISTS pushers2 (
          id BIGINT PRIMARY KEY,
          user_name TEXT NOT NULL,
          access_token BIGINT DEFAULT NULL,
          profile_tag VARCHAR(32) NOT NULL,
          kind VARCHAR(8) NOT NULL,
          app_id VARCHAR(64) NOT NULL,
          app_display_name VARCHAR(64) NOT NULL,
          device_display_name VARCHAR(128) NOT NULL,
          pushkey TEXT NOT NULL,
          ts BIGINT NOT NULL,
          lang VARCHAR(8),
          data TEXT,
          last_token TEXT,
          last_success BIGINT,
          failing_since BIGINT,
          UNIQUE (app_id, pushkey, user_name)
        )
    """)
    cur.execute("""SELECT
        id, user_name, access_token, profile_tag, kind,
        app_id, app_display_name, device_display_name,
        pushkey, ts, lang, data, last_token, last_success,
        failing_since
        FROM pushers
    """)
    count = 0
    for row in cur.fetchall():
        row = list(row)
        row[8] = bytes(row[8]).decode("utf-8")
        row[11] = bytes(row[11]).decode("utf-8")
        cur.execute(database_engine.convert_param_style("""
            INSERT into pushers2 (
            id, user_name, access_token, profile_tag, kind,
            app_id, app_display_name, device_display_name,
            pushkey, ts, lang, data, last_token, last_success,
            failing_since
            ) values (%s)""" % (','.join(['?' for _ in range(len(row))]))),
            row
        )
        count += 1
    cur.execute("DROP TABLE pushers")
    cur.execute("ALTER TABLE pushers2 RENAME TO pushers")
    logger.info("Moved %d pushers to new table", count)
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from ._base import SQLBaseStore, cached
from ._base import SQLBaseStore

from twisted.internet import defer

@@ -43,7 +43,6 @@ class StateStore(SQLBaseStore):
* `state_groups_state`: Maps state group to state events.
"""

@defer.inlineCallbacks
def get_state_groups(self, event_ids):
""" Get the state groups for the given list of event_ids

@@ -72,33 +71,17 @@ class StateStore(SQLBaseStore):
retcol="event_id",
)

res[group] = state_ids
state = self._get_events_txn(txn, state_ids)

res[group] = state

return res

states = yield self.runInteraction(
return self.runInteraction(
"get_state_groups",
f,
)

state_list = yield defer.gatherResults(
[
self._fetch_events_for_group(group, vals)
for group, vals in states.items()
],
consumeErrors=True,
)

defer.returnValue(dict(state_list))

@cached(num_args=1)
def _fetch_events_for_group(self, state_group, events):
return self._get_events(
events, get_prev_content=False
).addCallback(
lambda evs: (state_group, evs)
)

def _store_state_groups_txn(self, txn, event, context):
if context.current_state is None:
return
@@ -121,20 +104,18 @@ class StateStore(SQLBaseStore):
},
)

self._simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
{
for state in state_events.values():
self._simple_insert_txn(
txn,
table="state_groups_state",
values={
"state_group": state_group,
"room_id": state.room_id,
"type": state.type,
"state_key": state.state_key,
"event_id": state.event_id,
}
for state in state_events.values()
],
)
},
)

self._simple_insert_txn(
txn,
@@ -147,12 +128,6 @@ class StateStore(SQLBaseStore):

@defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key=""):
if event_type and state_key is not None:
result = yield self.get_current_state_for_key(
room_id, event_type, state_key
)
defer.returnValue(result)

def f(txn):
sql = (
"SELECT event_id FROM current_state_events"
@@ -169,29 +144,11 @@ class StateStore(SQLBaseStore):
args = (room_id, )

txn.execute(sql, args)
results = txn.fetchall()
results = self.cursor_to_dict(txn)

return [r[0] for r in results]
return self._parse_events_txn(txn, results)

event_ids = yield self.runInteraction("get_current_state", f)
events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events)

@cached(num_args=3)
@defer.inlineCallbacks
def get_current_state_for_key(self, room_id, event_type, state_key):
def f(txn):
sql = (
"SELECT event_id FROM current_state_events"
" WHERE room_id = ? AND type = ? AND state_key = ?"
)

args = (room_id, event_type, state_key)
txn.execute(sql, args)
results = txn.fetchall()
return [r[0] for r in results]
event_ids = yield self.runInteraction("get_current_state_for_key", f)
events = yield self._get_events(event_ids, get_prev_content=False)
events = yield self.runInteraction("get_current_state", f)
defer.returnValue(events)


@@ -37,9 +37,11 @@ from twisted.internet import defer

from ._base import SQLBaseStore
from synapse.api.constants import EventTypes
from synapse.types import RoomStreamToken
from synapse.api.errors import SynapseError
from synapse.util.logutils import log_function

from collections import namedtuple

import logging


@@ -53,26 +55,76 @@ _STREAM_TOKEN = "stream"
_TOPOLOGICAL_TOKEN = "topological"


def lower_bound(token):
if token.topological is None:
return "(%d < %s)" % (token.stream, "stream_ordering")
else:
return "(%d < %s OR (%d = %s AND %d < %s))" % (
token.topological, "topological_ordering",
token.topological, "topological_ordering",
token.stream, "stream_ordering",
)
class _StreamToken(namedtuple("_StreamToken", "topological stream")):
"""Tokens are positions between events. The token "s1" comes after event 1.

    s0    s1
    |     |
[0] V [1] V [2]

def upper_bound(token):
if token.topological is None:
return "(%d >= %s)" % (token.stream, "stream_ordering")
else:
return "(%d > %s OR (%d = %s AND %d >= %s))" % (
token.topological, "topological_ordering",
token.topological, "topological_ordering",
token.stream, "stream_ordering",
)
Tokens can either be a point in the live event stream or a cursor going
through historic events.

When traversing the live event stream events are ordered by when they
arrived at the homeserver.

When traversing historic events the events are ordered by their depth in
the event graph "topological_ordering" and then by when they arrived at the
homeserver "stream_ordering".

Live tokens start with an "s" followed by the "stream_ordering" id of the
event it comes after. Historic tokens start with a "t" followed by the
"topological_ordering" id of the event it comes after, followed by "-",
followed by the "stream_ordering" id of the event it comes after.
"""
__slots__ = []

@classmethod
def parse(cls, string):
try:
if string[0] == 's':
return cls(topological=None, stream=int(string[1:]))
if string[0] == 't':
parts = string[1:].split('-', 1)
return cls(topological=int(parts[0]), stream=int(parts[1]))
except:
pass
raise SynapseError(400, "Invalid token %r" % (string,))

@classmethod
def parse_stream_token(cls, string):
try:
if string[0] == 's':
return cls(topological=None, stream=int(string[1:]))
except:
pass
raise SynapseError(400, "Invalid token %r" % (string,))

def __str__(self):
if self.topological is not None:
return "t%d-%d" % (self.topological, self.stream)
else:
return "s%d" % (self.stream,)

def lower_bound(self):
if self.topological is None:
return "(%d < %s)" % (self.stream, "stream_ordering")
else:
return "(%d < %s OR (%d = %s AND %d < %s))" % (
self.topological, "topological_ordering",
self.topological, "topological_ordering",
self.stream, "stream_ordering",
)

def upper_bound(self):
if self.topological is None:
return "(%d >= %s)" % (self.stream, "stream_ordering")
else:
return "(%d > %s OR (%d = %s AND %d >= %s))" % (
self.topological, "topological_ordering",
self.topological, "topological_ordering",
self.stream, "stream_ordering",
)


class StreamStore(SQLBaseStore):
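
The docstring above pins down the token grammar: live tokens are ``s<stream_ordering>`` and historic tokens are ``t<topological_ordering>-<stream_ordering>``. A condensed, standalone restatement of just the parse/stringify round trip (the ``Token`` class here is illustrative; the real class is the ``_StreamToken`` shown above)::

    from collections import namedtuple

    class Token(namedtuple("Token", "topological stream")):
        @classmethod
        def parse(cls, string):
            # "s<stream>" -> live token; "t<topo>-<stream>" -> historic token.
            if string[0] == 's':
                return cls(topological=None, stream=int(string[1:]))
            if string[0] == 't':
                topo, _, stream = string[1:].partition('-')
                return cls(topological=int(topo), stream=int(stream))
            raise ValueError("Invalid token %r" % (string,))

        def __str__(self):
            if self.topological is not None:
                return "t%d-%d" % (self.topological, self.stream)
            return "s%d" % (self.stream,)

    print(Token.parse("s42"))         # Token(topological=None, stream=42)
    print(str(Token.parse("t5-42")))  # round-trips back to "t5-42"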
@@ -87,8 +139,8 @@ class StreamStore(SQLBaseStore):
limit = MAX_STREAM_SIZE

# From and to keys should be integers from ordering.
from_id = RoomStreamToken.parse_stream_token(from_key)
to_id = RoomStreamToken.parse_stream_token(to_key)
from_id = _StreamToken.parse_stream_token(from_key)
to_id = _StreamToken.parse_stream_token(to_key)

if from_key == to_key:
defer.returnValue(([], to_key))
@@ -182,8 +234,8 @@ class StreamStore(SQLBaseStore):
limit = MAX_STREAM_SIZE

# From and to keys should be integers from ordering.
from_id = RoomStreamToken.parse_stream_token(from_key)
to_id = RoomStreamToken.parse_stream_token(to_key)
from_id = _StreamToken.parse_stream_token(from_key)
to_id = _StreamToken.parse_stream_token(to_key)

if from_key == to_key:
return defer.succeed(([], to_key))
@@ -224,7 +276,7 @@ class StreamStore(SQLBaseStore):

return self.runInteraction("get_room_events_stream", f)

@defer.inlineCallbacks
@log_function
def paginate_room_events(self, room_id, from_key, to_key=None,
direction='b', limit=-1,
with_feedback=False):
@@ -236,17 +288,17 @@ class StreamStore(SQLBaseStore):
args = [False, room_id]
if direction == 'b':
order = "DESC"
bounds = upper_bound(RoomStreamToken.parse(from_key))
bounds = _StreamToken.parse(from_key).upper_bound()
if to_key:
bounds = "%s AND %s" % (
bounds, lower_bound(RoomStreamToken.parse(to_key))
bounds, _StreamToken.parse(to_key).lower_bound()
)
else:
order = "ASC"
bounds = lower_bound(RoomStreamToken.parse(from_key))
bounds = _StreamToken.parse(from_key).lower_bound()
if to_key:
bounds = "%s AND %s" % (
bounds, upper_bound(RoomStreamToken.parse(to_key))
bounds, _StreamToken.parse(to_key).upper_bound()
)

if int(limit) > 0:
@@ -281,30 +333,28 @@ class StreamStore(SQLBaseStore):
# when we are going backwards so we subtract one from the
# stream part.
toke -= 1
next_token = str(RoomStreamToken(topo, toke))
next_token = str(_StreamToken(topo, toke))
else:
# TODO (erikj): We should work out what to do here instead.
next_token = to_key if to_key else from_key

return rows, next_token,
events = self._get_events_txn(
txn,
[r["event_id"] for r in rows],
get_prev_content=True
)

rows, token = yield self.runInteraction("paginate_room_events", f)
self._set_before_and_after(events, rows)

events = yield self._get_events(
[r["event_id"] for r in rows],
get_prev_content=True
)
return events, next_token,

self._set_before_and_after(events, rows)
return self.runInteraction("paginate_room_events", f)

defer.returnValue((events, token))

@defer.inlineCallbacks
def get_recent_events_for_room(self, room_id, limit, end_token,
with_feedback=False, from_token=None):
# TODO (erikj): Handle compressed feedback

end_token = RoomStreamToken.parse_stream_token(end_token)
end_token = _StreamToken.parse_stream_token(end_token)

if from_token is None:
sql = (
@@ -315,7 +365,7 @@ class StreamStore(SQLBaseStore):
" LIMIT ?"
)
else:
from_token = RoomStreamToken.parse_stream_token(from_token)
from_token = _StreamToken.parse_stream_token(from_token)
sql = (
"SELECT stream_ordering, topological_ordering, event_id"
" FROM events"
@@ -345,49 +395,30 @@ class StreamStore(SQLBaseStore):
# stream part.
topo = rows[0]["topological_ordering"]
toke = rows[0]["stream_ordering"] - 1
start_token = str(RoomStreamToken(topo, toke))
start_token = str(_StreamToken(topo, toke))

token = (start_token, str(end_token))
else:
token = (str(end_token), str(end_token))

return rows, token
events = self._get_events_txn(
txn,
[r["event_id"] for r in rows],
get_prev_content=True
)

rows, token = yield self.runInteraction(
self._set_before_and_after(events, rows)

return events, token

return self.runInteraction(
"get_recent_events_for_room", get_recent_events_for_room_txn
)

logger.debug("stream before")
events = yield self._get_events(
[r["event_id"] for r in rows],
get_prev_content=True
)
logger.debug("stream after")

self._set_before_and_after(events, rows)

defer.returnValue((events, token))

@defer.inlineCallbacks
def get_room_events_max_id(self, direction='f'):
def get_room_events_max_id(self):
token = yield self._stream_id_gen.get_max_token(self)
if direction != 'b':
defer.returnValue("s%d" % (token,))
else:
topo = yield self.runInteraction(
"_get_max_topological_txn", self._get_max_topological_txn
)
defer.returnValue("t%d-%d" % (topo, token))

def _get_max_topological_txn(self, txn):
txn.execute(
"SELECT MAX(topological_ordering) FROM events"
" WHERE outlier = ?",
(False,)
)

rows = txn.fetchall()
return rows[0][0] if rows else 0
defer.returnValue("s%d" % (token,))

@defer.inlineCallbacks
def _get_min_token(self):
@@ -408,5 +439,5 @@ class StreamStore(SQLBaseStore):
stream = row["stream_ordering"]
topo = event.depth
internal = event.internal_metadata
internal.before = str(RoomStreamToken(topo, stream - 1))
internal.after = str(RoomStreamToken(topo, stream))
internal.before = str(_StreamToken(topo, stream - 1))
internal.after = str(_StreamToken(topo, stream))

@@ -17,7 +17,6 @@ from ._base import SQLBaseStore, cached

from collections import namedtuple

from syutil.jsonutil import encode_canonical_json
import logging

logger = logging.getLogger(__name__)
@@ -83,7 +82,7 @@ class TransactionStore(SQLBaseStore):
"transaction_id": transaction_id,
"origin": origin,
"response_code": code,
"response_json": buffer(encode_canonical_json(response_dict)),
"response_json": response_dict,
},
or_ignore=True,
desc="set_received_txn_response",
@@ -162,8 +161,7 @@ class TransactionStore(SQLBaseStore):
return self.runInteraction(
"delivered_txn",
self._delivered_txn,
transaction_id, destination, code,
buffer(encode_canonical_json(response_dict)),
transaction_id, destination, code, response_dict
)

def _delivered_txn(self, txn, transaction_id, destination,

@@ -78,18 +78,14 @@ class StreamIdGenerator(object):
self._current_max = None
self._unfinished_ids = deque()

@defer.inlineCallbacks
def get_next(self, store):
def get_next_txn(self, txn):
"""
Usage:
with yield stream_id_gen.get_next as stream_id:
with stream_id_gen.get_next_txn(txn) as stream_id:
# ... persist event ...
"""
if not self._current_max:
yield store.runInteraction(
"_compute_current_max",
self._get_or_compute_current_max,
)
self._get_or_compute_current_max(txn)

with self._lock:
self._current_max += 1
@@ -105,7 +101,7 @@ class StreamIdGenerator(object):
with self._lock:
self._unfinished_ids.remove(next_id)

defer.returnValue(manager())
return manager()

@defer.inlineCallbacks
def get_max_token(self, store):
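
The usage note in the docstring above is the important part: the caller takes the next id inside a ``with`` block, and the id only stops being tracked as unfinished when the block exits. A condensed, standalone sketch of that pattern (an illustration of the idea only, not synapse's ``StreamIdGenerator``)::

    import contextlib
    import threading

    class SimpleIdAllocator(object):
        def __init__(self, current_max=0):
            self._lock = threading.Lock()
            self._current_max = current_max
            self._unfinished_ids = set()

        @contextlib.contextmanager
        def get_next(self):
            with self._lock:
                self._current_max += 1
                next_id = self._current_max
                self._unfinished_ids.add(next_id)
            try:
                yield next_id          # the caller persists its event here
            finally:
                with self._lock:
                    self._unfinished_ids.discard(next_id)

    allocator = SimpleIdAllocator()
    with allocator.get_next() as stream_id:
        print(stream_id)  # 1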

Some files were not shown because too many files have changed in this diff