mirror of
https://github.com/element-hq/synapse.git
synced 2025-12-07 01:20:16 +00:00
Compare commits
158 Commits
v0.16.0
...
erikj/file
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9854f4c7ff | ||
|
|
518b3a3f89 | ||
|
|
10f4856b0c | ||
|
|
dfde67a6fe | ||
|
|
10c843fcfb | ||
|
|
58930da52b | ||
|
|
0870588c20 | ||
|
|
f90cf150e2 | ||
|
|
067596d341 | ||
|
|
70d650be2b | ||
|
|
b92e7955be | ||
|
|
c98e1479bd | ||
|
|
67f2c901ea | ||
|
|
eef7778af9 | ||
|
|
a17e7caeb7 | ||
|
|
f0c06ac65c | ||
|
|
76b18df3d9 | ||
|
|
0da24cac8b | ||
|
|
2e3c8acc68 | ||
|
|
8d9a884cee | ||
|
|
896bc6cd46 | ||
|
|
be3548f7e1 | ||
|
|
4adf93e0f7 | ||
|
|
651faee698 | ||
|
|
caf33b2d9b | ||
|
|
8f8798bc0d | ||
|
|
7335f0adda | ||
|
|
ef535178ff | ||
|
|
04dee11e97 | ||
|
|
dd2ccee27d | ||
|
|
b6b0132ac7 | ||
|
|
e34cb5e7dc | ||
|
|
252ee2d979 | ||
|
|
14362bf359 | ||
|
|
1ee2584307 | ||
|
|
507b8bb091 | ||
|
|
d44d11d864 | ||
|
|
2d21d43c34 | ||
|
|
0fb76c71ac | ||
|
|
8bdaf5f7af | ||
|
|
a67bf0b074 | ||
|
|
f18d7546c6 | ||
|
|
3de8168343 | ||
|
|
bb069079bb | ||
|
|
2e5a31f197 | ||
|
|
fc8007dbec | ||
|
|
1238203bc4 | ||
|
|
41f072fd0e | ||
|
|
5a6ef20ef6 | ||
|
|
be8be535f7 | ||
|
|
ab71589c0b | ||
|
|
f328d95cef | ||
|
|
aac546c978 | ||
|
|
f52cb4cd78 | ||
|
|
6783534a0f | ||
|
|
a70688445d | ||
|
|
314b146b2e | ||
|
|
7846ac3125 | ||
|
|
56ec5869c9 | ||
|
|
1ea358b28b | ||
|
|
db74dcda5b | ||
|
|
551fe80bed | ||
|
|
63bb8f0df9 | ||
|
|
70d820c875 | ||
|
|
3f7652c56f | ||
|
|
535b6bfacc | ||
|
|
f7fe0e5f67 | ||
|
|
2455ad8468 | ||
|
|
0b640aa56b | ||
|
|
aa3a4944d5 | ||
|
|
46b7362304 | ||
|
|
870c45913e | ||
|
|
05f1a4596a | ||
|
|
13517e2914 | ||
|
|
b5fb7458d5 | ||
|
|
f73fdb04a6 | ||
|
|
3a4120e49a | ||
|
|
9fe894402f | ||
|
|
0a32208e5d | ||
|
|
774f3a692c | ||
|
|
5cc7564c5c | ||
|
|
0fe0b0eeb6 | ||
|
|
0126a5d60a | ||
|
|
13e334506c | ||
|
|
6b40e4f52a | ||
|
|
d5fb561709 | ||
|
|
d8ec81cc31 | ||
|
|
bc72d381b2 | ||
|
|
4d362a61ea | ||
|
|
00c281f6a4 | ||
|
|
c8cd41cdd8 | ||
|
|
41e4b2efea | ||
|
|
0c13d45522 | ||
|
|
9f1800fba8 | ||
|
|
8f4a9bbc16 | ||
|
|
9ba2bf1570 | ||
|
|
120c238705 | ||
|
|
1c1f633b13 | ||
|
|
6660f37558 | ||
|
|
20e5b46b20 | ||
|
|
0113ad36ee | ||
|
|
3e41de05cc | ||
|
|
2884712ca7 | ||
|
|
ded01c3bf6 | ||
|
|
8c75040c25 | ||
|
|
f1073ad43d | ||
|
|
a352b68acf | ||
|
|
364d616792 | ||
|
|
bde13833cb | ||
|
|
80a1bc7db5 | ||
|
|
f1f70bf4b5 | ||
|
|
dbb5a39b64 | ||
|
|
885ee861f7 | ||
|
|
486b9a6a2d | ||
|
|
1f31381611 | ||
|
|
ed5f43a55a | ||
|
|
09a17f965c | ||
|
|
1e9026e484 | ||
|
|
a60169ea09 | ||
|
|
78a16d395c | ||
|
|
5436848955 | ||
|
|
0477368e9a | ||
|
|
0ef0655b83 | ||
|
|
a64dbae90b | ||
|
|
d41a1a91d3 | ||
|
|
15bf3e3376 | ||
|
|
d9f7fa2e57 | ||
|
|
b31c49d676 | ||
|
|
255c229f23 | ||
|
|
d12134ce37 | ||
|
|
36e2aade87 | ||
|
|
33546b58aa | ||
|
|
fdc015c6e9 | ||
|
|
e9892b4b26 | ||
|
|
50f69e2ef2 | ||
|
|
41b35412bf | ||
|
|
7dbb473339 | ||
|
|
16a8884233 | ||
|
|
5ee5b655b2 | ||
|
|
0a43219a27 | ||
|
|
eba4ff1bcb | ||
|
|
2eab219a70 | ||
|
|
95f305c35a | ||
|
|
6e7dc7c7dd | ||
|
|
b7fbc9bd95 | ||
|
|
919a2c74f6 | ||
|
|
81c07a32fd | ||
|
|
b063784b78 | ||
|
|
defa28efa1 | ||
|
|
690029d1a3 | ||
|
|
d88faf92d1 | ||
|
|
958c968d02 | ||
|
|
746b2f5657 | ||
|
|
efeabd3180 | ||
|
|
1fd6eb695d | ||
|
|
128360d4f0 | ||
|
|
17aab5827a | ||
|
|
1a815fb04f |
51
CHANGES.rst
51
CHANGES.rst
@@ -1,3 +1,52 @@
|
||||
Changes in synapse v0.16.1-r1 (2016-07-08)
|
||||
==========================================
|
||||
|
||||
THIS IS A CRITICAL SECURITY UPDATE.
|
||||
|
||||
This fixes a bug which allowed users' accounts to be accessed by unauthorised
|
||||
users.
|
||||
|
||||
Changes in synapse v0.16.1 (2016-06-20)
|
||||
=======================================
|
||||
|
||||
Bug fixes:
|
||||
|
||||
* Fix assorted bugs in ``/preview_url`` (PR #872)
|
||||
* Fix TypeError when setting unicode passwords (PR #873)
|
||||
|
||||
|
||||
Performance improvements:
|
||||
|
||||
* Turn ``use_frozen_events`` off by default (PR #877)
|
||||
* Disable responding with canonical json for federation (PR #878)
|
||||
|
||||
|
||||
Changes in synapse v0.16.1-rc1 (2016-06-15)
|
||||
===========================================
|
||||
|
||||
Features: None
|
||||
|
||||
Changes:
|
||||
|
||||
* Log requester for ``/publicRoom`` endpoints when possible (PR #856)
|
||||
* 502 on ``/thumbnail`` when can't connect to remote server (PR #862)
|
||||
* Linearize fetching of gaps on incoming events (PR #871)
|
||||
|
||||
|
||||
Bugs fixes:
|
||||
|
||||
* Fix bug where rooms where marked as published by default (PR #857)
|
||||
* Fix bug where joining room with an event with invalid sender (PR #868)
|
||||
* Fix bug where backfilled events were sent down sync streams (PR #869)
|
||||
* Fix bug where outgoing connections could wedge indefinitely, causing push
|
||||
notifications to be unreliable (PR #870)
|
||||
|
||||
|
||||
Performance improvements:
|
||||
|
||||
* Improve ``/publicRooms`` performance(PR #859)
|
||||
|
||||
|
||||
Changes in synapse v0.16.0 (2016-06-09)
|
||||
=======================================
|
||||
|
||||
@@ -28,7 +77,7 @@ Bug fixes:
|
||||
* Fix bug where synapse sent malformed transactions to AS's when retrying
|
||||
transactions (Commits 310197b, 8437906)
|
||||
|
||||
Performance Improvements:
|
||||
Performance improvements:
|
||||
|
||||
* Remove event fetching from DB threads (PR #835)
|
||||
* Change the way we cache events (PR #836)
|
||||
|
||||
@@ -43,7 +43,10 @@ Basically, PEP8
|
||||
together, or want to deliberately extend or preserve vertical/horizontal
|
||||
space)
|
||||
|
||||
Comments should follow the google code style. This is so that we can generate
|
||||
documentation with sphinx (http://sphinxcontrib-napoleon.readthedocs.org/en/latest/)
|
||||
Comments should follow the `google code style <http://google.github.io/styleguide/pyguide.html?showone=Comments#Comments>`_.
|
||||
This is so that we can generate documentation with
|
||||
`sphinx <http://sphinxcontrib-napoleon.readthedocs.org/en/latest/>`_. See the
|
||||
`examples <http://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html>`_
|
||||
in the sphinx documentation.
|
||||
|
||||
Code should pass pep8 --max-line-length=100 without any warnings.
|
||||
|
||||
@@ -9,31 +9,35 @@ the Home Server to generate credentials that are valid for use on the TURN
|
||||
server through the use of a secret shared between the Home Server and the
|
||||
TURN server.
|
||||
|
||||
This document described how to install coturn
|
||||
(https://code.google.com/p/coturn/) which also supports the TURN REST API,
|
||||
This document describes how to install coturn
|
||||
(https://github.com/coturn/coturn) which also supports the TURN REST API,
|
||||
and integrate it with synapse.
|
||||
|
||||
coturn Setup
|
||||
============
|
||||
|
||||
You may be able to setup coturn via your package manager, or set it up manually using the usual ``configure, make, make install`` process.
|
||||
|
||||
1. Check out coturn::
|
||||
svn checkout http://coturn.googlecode.com/svn/trunk/ coturn
|
||||
|
||||
git clone https://github.com/coturn/coturn.git coturn
|
||||
cd coturn
|
||||
|
||||
2. Configure it::
|
||||
|
||||
./configure
|
||||
|
||||
You may need to install libevent2: if so, you should do so
|
||||
You may need to install ``libevent2``: if so, you should do so
|
||||
in the way recommended by your operating system.
|
||||
You can ignore warnings about lack of database support: a
|
||||
database is unnecessary for this purpose.
|
||||
|
||||
3. Build and install it::
|
||||
|
||||
make
|
||||
make install
|
||||
|
||||
4. Make a config file in /etc/turnserver.conf. You can customise
|
||||
a config file from turnserver.conf.default. The relevant
|
||||
4. Create or edit the config file in ``/etc/turnserver.conf``. The relevant
|
||||
lines, with example values, are::
|
||||
|
||||
lt-cred-mech
|
||||
@@ -41,7 +45,7 @@ coturn Setup
|
||||
static-auth-secret=[your secret key here]
|
||||
realm=turn.myserver.org
|
||||
|
||||
See turnserver.conf.default for explanations of the options.
|
||||
See turnserver.conf for explanations of the options.
|
||||
One way to generate the static-auth-secret is with pwgen::
|
||||
|
||||
pwgen -s 64 1
|
||||
@@ -54,6 +58,7 @@ coturn Setup
|
||||
import your private key and certificate.
|
||||
|
||||
7. Start the turn server::
|
||||
|
||||
bin/turnserver -o
|
||||
|
||||
|
||||
|
||||
@@ -70,15 +70,18 @@ cd sytest
|
||||
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
|
||||
|
||||
: ${PORT_BASE:=8000}
|
||||
: ${PORT_COUNT=20}
|
||||
|
||||
./jenkins/prep_sytest_for_postgres.sh
|
||||
|
||||
mkdir -p var
|
||||
|
||||
echo >&2 "Running sytest with PostgreSQL";
|
||||
./jenkins/install_and_run.sh --python $TOX_BIN/python \
|
||||
--synapse-directory $WORKSPACE \
|
||||
--dendron $WORKSPACE/dendron/bin/dendron \
|
||||
--synchrotron \
|
||||
--pusher \
|
||||
--port-base $PORT_BASE
|
||||
--synchrotron \
|
||||
--port-range ${PORT_BASE}:$((PORT_BASE+PORT_COUNT-1))
|
||||
|
||||
cd ..
|
||||
|
||||
@@ -44,6 +44,7 @@ cd sytest
|
||||
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
|
||||
|
||||
: ${PORT_BASE:=8000}
|
||||
: ${PORT_COUNT=20}
|
||||
|
||||
./jenkins/prep_sytest_for_postgres.sh
|
||||
|
||||
@@ -51,7 +52,7 @@ echo >&2 "Running sytest with PostgreSQL";
|
||||
./jenkins/install_and_run.sh --coverage \
|
||||
--python $TOX_BIN/python \
|
||||
--synapse-directory $WORKSPACE \
|
||||
--port-base $PORT_BASE
|
||||
--port-range ${PORT_BASE}:$((PORT_BASE+PORT_COUNT-1)) \
|
||||
|
||||
cd ..
|
||||
cp sytest/.coverage.* .
|
||||
|
||||
@@ -41,11 +41,12 @@ cd sytest
|
||||
|
||||
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
|
||||
|
||||
: ${PORT_BASE:=8500}
|
||||
: ${PORT_COUNT=20}
|
||||
: ${PORT_BASE:=8000}
|
||||
./jenkins/install_and_run.sh --coverage \
|
||||
--python $TOX_BIN/python \
|
||||
--synapse-directory $WORKSPACE \
|
||||
--port-base $PORT_BASE
|
||||
--port-range ${PORT_BASE}:$((PORT_BASE+PORT_COUNT-1)) \
|
||||
|
||||
cd ..
|
||||
cp sytest/.coverage.* .
|
||||
|
||||
@@ -36,7 +36,7 @@
|
||||
<div class="debug">
|
||||
Sending email at {{ reason.now|format_ts("%c") }} due to activity in room {{ reason.room_name }} because
|
||||
an event was received at {{ reason.received_at|format_ts("%c") }}
|
||||
which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} (delay_before_mail_ms) mins ago,
|
||||
which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} ({{ reason.delay_before_mail_ms }}) mins ago,
|
||||
{% if reason.last_sent_ts %}
|
||||
and the last time we sent a mail for this room was {{ reason.last_sent_ts|format_ts("%c") }},
|
||||
which is more than {{ "%.1f"|format(reason.throttle_ms / (60*1000)) }} (current throttle_ms) mins ago.
|
||||
|
||||
@@ -1,10 +1,16 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import argparse
|
||||
|
||||
import sys
|
||||
|
||||
import bcrypt
|
||||
import getpass
|
||||
|
||||
import yaml
|
||||
|
||||
bcrypt_rounds=12
|
||||
password_pepper = ""
|
||||
|
||||
def prompt_for_pass():
|
||||
password = getpass.getpass("Password: ")
|
||||
@@ -28,12 +34,22 @@ if __name__ == "__main__":
|
||||
default=None,
|
||||
help="New password for user. Will prompt if omitted.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c", "--config",
|
||||
type=argparse.FileType('r'),
|
||||
help="Path to server config file. Used to read in bcrypt_rounds and password_pepper.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
if "config" in args and args.config:
|
||||
config = yaml.safe_load(args.config)
|
||||
bcrypt_rounds = config.get("bcrypt_rounds", bcrypt_rounds)
|
||||
password_config = config.get("password_config", {})
|
||||
password_pepper = password_config.get("pepper", password_pepper)
|
||||
password = args.password
|
||||
|
||||
if not password:
|
||||
password = prompt_for_pass()
|
||||
|
||||
print bcrypt.hashpw(password, bcrypt.gensalt(bcrypt_rounds))
|
||||
print bcrypt.hashpw(password + password_pepper, bcrypt.gensalt(bcrypt_rounds))
|
||||
|
||||
|
||||
@@ -25,18 +25,26 @@ import urllib2
|
||||
import yaml
|
||||
|
||||
|
||||
def request_registration(user, password, server_location, shared_secret):
|
||||
def request_registration(user, password, server_location, shared_secret, admin=False):
|
||||
mac = hmac.new(
|
||||
key=shared_secret,
|
||||
msg=user,
|
||||
digestmod=hashlib.sha1,
|
||||
).hexdigest()
|
||||
)
|
||||
|
||||
mac.update(user)
|
||||
mac.update("\x00")
|
||||
mac.update(password)
|
||||
mac.update("\x00")
|
||||
mac.update("admin" if admin else "notadmin")
|
||||
|
||||
mac = mac.hexdigest()
|
||||
|
||||
data = {
|
||||
"user": user,
|
||||
"password": password,
|
||||
"mac": mac,
|
||||
"type": "org.matrix.login.shared_secret",
|
||||
"admin": admin,
|
||||
}
|
||||
|
||||
server_location = server_location.rstrip("/")
|
||||
@@ -68,7 +76,7 @@ def request_registration(user, password, server_location, shared_secret):
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def register_new_user(user, password, server_location, shared_secret):
|
||||
def register_new_user(user, password, server_location, shared_secret, admin):
|
||||
if not user:
|
||||
try:
|
||||
default_user = getpass.getuser()
|
||||
@@ -99,7 +107,14 @@ def register_new_user(user, password, server_location, shared_secret):
|
||||
print "Passwords do not match"
|
||||
sys.exit(1)
|
||||
|
||||
request_registration(user, password, server_location, shared_secret)
|
||||
if not admin:
|
||||
admin = raw_input("Make admin [no]: ")
|
||||
if admin in ("y", "yes", "true"):
|
||||
admin = True
|
||||
else:
|
||||
admin = False
|
||||
|
||||
request_registration(user, password, server_location, shared_secret, bool(admin))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -119,6 +134,11 @@ if __name__ == "__main__":
|
||||
default=None,
|
||||
help="New password for user. Will prompt if omitted.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-a", "--admin",
|
||||
action="store_true",
|
||||
help="Register new user as an admin. Will prompt if omitted.",
|
||||
)
|
||||
|
||||
group = parser.add_mutually_exclusive_group(required=True)
|
||||
group.add_argument(
|
||||
@@ -151,4 +171,4 @@ if __name__ == "__main__":
|
||||
else:
|
||||
secret = args.shared_secret
|
||||
|
||||
register_new_user(args.user, args.password, args.server_url, secret)
|
||||
register_new_user(args.user, args.password, args.server_url, secret, args.admin)
|
||||
|
||||
@@ -16,4 +16,4 @@
|
||||
""" This is a reference implementation of a Matrix home server.
|
||||
"""
|
||||
|
||||
__version__ = "0.16.0"
|
||||
__version__ = "0.16.1-r1"
|
||||
|
||||
@@ -637,17 +637,22 @@ class Auth(object):
|
||||
try:
|
||||
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
|
||||
|
||||
self.validate_macaroon(macaroon, rights, self.hs.config.expire_access_token)
|
||||
|
||||
user_prefix = "user_id = "
|
||||
user = None
|
||||
user_id = None
|
||||
guest = False
|
||||
for caveat in macaroon.caveats:
|
||||
if caveat.caveat_id.startswith(user_prefix):
|
||||
user = UserID.from_string(caveat.caveat_id[len(user_prefix):])
|
||||
user_id = caveat.caveat_id[len(user_prefix):]
|
||||
user = UserID.from_string(user_id)
|
||||
elif caveat.caveat_id == "guest = true":
|
||||
guest = True
|
||||
|
||||
self.validate_macaroon(
|
||||
macaroon, rights, self.hs.config.expire_access_token,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
if user is None:
|
||||
raise AuthError(
|
||||
self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
|
||||
@@ -692,7 +697,7 @@ class Auth(object):
|
||||
errcode=Codes.UNKNOWN_TOKEN
|
||||
)
|
||||
|
||||
def validate_macaroon(self, macaroon, type_string, verify_expiry):
|
||||
def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id):
|
||||
"""
|
||||
validate that a Macaroon is understood by and was signed by this server.
|
||||
|
||||
@@ -707,7 +712,7 @@ class Auth(object):
|
||||
v = pymacaroons.Verifier()
|
||||
v.satisfy_exact("gen = 1")
|
||||
v.satisfy_exact("type = " + type_string)
|
||||
v.satisfy_general(lambda c: c.startswith("user_id = "))
|
||||
v.satisfy_exact("user_id = %s" % user_id)
|
||||
v.satisfy_exact("guest = true")
|
||||
if verify_expiry:
|
||||
v.satisfy_general(self._verify_expiry)
|
||||
|
||||
@@ -42,8 +42,9 @@ class Codes(object):
|
||||
TOO_LARGE = "M_TOO_LARGE"
|
||||
EXCLUSIVE = "M_EXCLUSIVE"
|
||||
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
|
||||
THREEPID_IN_USE = "THREEPID_IN_USE"
|
||||
THREEPID_IN_USE = "M_THREEPID_IN_USE"
|
||||
INVALID_USERNAME = "M_INVALID_USERNAME"
|
||||
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
|
||||
|
||||
|
||||
class CodeMessageException(RuntimeError):
|
||||
|
||||
@@ -147,7 +147,7 @@ class SynapseHomeServer(HomeServer):
|
||||
MEDIA_PREFIX: media_repo,
|
||||
LEGACY_MEDIA_PREFIX: media_repo,
|
||||
CONTENT_REPO_PREFIX: ContentRepoResource(
|
||||
self, self.config.uploads_path, self.auth, self.content_addr
|
||||
self, self.config.uploads_path
|
||||
),
|
||||
})
|
||||
|
||||
@@ -266,10 +266,9 @@ def setup(config_options):
|
||||
HomeServer
|
||||
"""
|
||||
try:
|
||||
config = HomeServerConfig.load_config(
|
||||
config = HomeServerConfig.load_or_generate_config(
|
||||
"Synapse Homeserver",
|
||||
config_options,
|
||||
generate_section="Homeserver"
|
||||
)
|
||||
except ConfigError as e:
|
||||
sys.stderr.write("\n" + e.message + "\n")
|
||||
@@ -302,7 +301,6 @@ def setup(config_options):
|
||||
db_config=config.database_config,
|
||||
tls_server_context_factory=tls_server_context_factory,
|
||||
config=config,
|
||||
content_addr=config.content_addr,
|
||||
version_string=version_string,
|
||||
database_engine=database_engine,
|
||||
)
|
||||
|
||||
@@ -18,10 +18,8 @@ import synapse
|
||||
|
||||
from synapse.server import HomeServer
|
||||
from synapse.config._base import ConfigError
|
||||
from synapse.config.database import DatabaseConfig
|
||||
from synapse.config.logger import LoggingConfig
|
||||
from synapse.config.emailconfig import EmailConfig
|
||||
from synapse.config.key import KeyConfig
|
||||
from synapse.config.logger import setup_logging
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.http.site import SynapseSite
|
||||
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
||||
from synapse.storage.roommember import RoomMemberStore
|
||||
@@ -43,98 +41,13 @@ from twisted.web.resource import Resource
|
||||
|
||||
from daemonize import Daemonize
|
||||
|
||||
import gc
|
||||
import sys
|
||||
import logging
|
||||
import gc
|
||||
|
||||
logger = logging.getLogger("synapse.app.pusher")
|
||||
|
||||
|
||||
class SlaveConfig(DatabaseConfig):
|
||||
def read_config(self, config):
|
||||
self.replication_url = config["replication_url"]
|
||||
self.server_name = config["server_name"]
|
||||
self.use_insecure_ssl_client_just_for_testing_do_not_use = config.get(
|
||||
"use_insecure_ssl_client_just_for_testing_do_not_use", False
|
||||
)
|
||||
self.user_agent_suffix = None
|
||||
self.start_pushers = True
|
||||
self.listeners = config["listeners"]
|
||||
self.soft_file_limit = config.get("soft_file_limit")
|
||||
self.daemonize = config.get("daemonize")
|
||||
self.pid_file = self.abspath(config.get("pid_file"))
|
||||
self.public_baseurl = config["public_baseurl"]
|
||||
|
||||
thresholds = config.get("gc_thresholds", None)
|
||||
if thresholds is not None:
|
||||
try:
|
||||
assert len(thresholds) == 3
|
||||
self.gc_thresholds = (
|
||||
int(thresholds[0]), int(thresholds[1]), int(thresholds[2]),
|
||||
)
|
||||
except:
|
||||
raise ConfigError(
|
||||
"Value of `gc_threshold` must be a list of three integers if set"
|
||||
)
|
||||
else:
|
||||
self.gc_thresholds = None
|
||||
|
||||
# some things used by the auth handler but not actually used in the
|
||||
# pusher codebase
|
||||
self.bcrypt_rounds = None
|
||||
self.ldap_enabled = None
|
||||
self.ldap_server = None
|
||||
self.ldap_port = None
|
||||
self.ldap_tls = None
|
||||
self.ldap_search_base = None
|
||||
self.ldap_search_property = None
|
||||
self.ldap_email_property = None
|
||||
self.ldap_full_name_property = None
|
||||
|
||||
# We would otherwise try to use the registration shared secret as the
|
||||
# macaroon shared secret if there was no macaroon_shared_secret, but
|
||||
# that means pulling in RegistrationConfig too. We don't need to be
|
||||
# backwards compaitible in the pusher codebase so just make people set
|
||||
# macaroon_shared_secret. We set this to None to prevent it referencing
|
||||
# an undefined key.
|
||||
self.registration_shared_secret = None
|
||||
|
||||
def default_config(self, server_name, **kwargs):
|
||||
pid_file = self.abspath("pusher.pid")
|
||||
return """\
|
||||
# Slave configuration
|
||||
|
||||
# The replication listener on the synapse to talk to.
|
||||
#replication_url: https://localhost:{replication_port}/_synapse/replication
|
||||
|
||||
server_name: "%(server_name)s"
|
||||
|
||||
listeners: []
|
||||
# Enable a ssh manhole listener on the pusher.
|
||||
# - type: manhole
|
||||
# port: {manhole_port}
|
||||
# bind_address: 127.0.0.1
|
||||
# Enable a metric listener on the pusher.
|
||||
# - type: http
|
||||
# port: {metrics_port}
|
||||
# bind_address: 127.0.0.1
|
||||
# resources:
|
||||
# - names: ["metrics"]
|
||||
# compress: False
|
||||
|
||||
report_stats: False
|
||||
|
||||
daemonize: False
|
||||
|
||||
pid_file: %(pid_file)s
|
||||
|
||||
""" % locals()
|
||||
|
||||
|
||||
class PusherSlaveConfig(SlaveConfig, LoggingConfig, EmailConfig, KeyConfig):
|
||||
pass
|
||||
|
||||
|
||||
class PusherSlaveStore(
|
||||
SlavedEventStore, SlavedPusherStore, SlavedReceiptsStore,
|
||||
SlavedAccountDataStore
|
||||
@@ -199,7 +112,7 @@ class PusherServer(HomeServer):
|
||||
|
||||
def remove_pusher(self, app_id, push_key, user_id):
|
||||
http_client = self.get_simple_http_client()
|
||||
replication_url = self.config.replication_url
|
||||
replication_url = self.config.worker_replication_url
|
||||
url = replication_url + "/remove_pushers"
|
||||
return http_client.post_json_get_json(url, {
|
||||
"remove": [{
|
||||
@@ -232,8 +145,8 @@ class PusherServer(HomeServer):
|
||||
)
|
||||
logger.info("Synapse pusher now listening on port %d", port)
|
||||
|
||||
def start_listening(self):
|
||||
for listener in self.config.listeners:
|
||||
def start_listening(self, listeners):
|
||||
for listener in listeners:
|
||||
if listener["type"] == "http":
|
||||
self._listen_http(listener)
|
||||
elif listener["type"] == "manhole":
|
||||
@@ -253,7 +166,7 @@ class PusherServer(HomeServer):
|
||||
def replicate(self):
|
||||
http_client = self.get_simple_http_client()
|
||||
store = self.get_datastore()
|
||||
replication_url = self.config.replication_url
|
||||
replication_url = self.config.worker_replication_url
|
||||
pusher_pool = self.get_pusherpool()
|
||||
clock = self.get_clock()
|
||||
|
||||
@@ -329,19 +242,30 @@ class PusherServer(HomeServer):
|
||||
yield sleep(30)
|
||||
|
||||
|
||||
def setup(config_options):
|
||||
def start(config_options):
|
||||
try:
|
||||
config = PusherSlaveConfig.load_config(
|
||||
config = HomeServerConfig.load_config(
|
||||
"Synapse pusher", config_options
|
||||
)
|
||||
except ConfigError as e:
|
||||
sys.stderr.write("\n" + e.message + "\n")
|
||||
sys.exit(1)
|
||||
|
||||
if not config:
|
||||
sys.exit(0)
|
||||
assert config.worker_app == "synapse.app.pusher"
|
||||
|
||||
config.setup_logging()
|
||||
setup_logging(config.worker_log_config, config.worker_log_file)
|
||||
|
||||
if config.start_pushers:
|
||||
sys.stderr.write(
|
||||
"\nThe pushers must be disabled in the main synapse process"
|
||||
"\nbefore they can be run in a separate worker."
|
||||
"\nPlease add ``start_pushers: false`` to the main config"
|
||||
"\n"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Force the pushers to start since they will be disabled in the main config
|
||||
config.start_pushers = True
|
||||
|
||||
database_engine = create_engine(config.database_config)
|
||||
|
||||
@@ -354,11 +278,15 @@ def setup(config_options):
|
||||
)
|
||||
|
||||
ps.setup()
|
||||
ps.start_listening()
|
||||
ps.start_listening(config.worker_listeners)
|
||||
|
||||
change_resource_limit(ps.config.soft_file_limit)
|
||||
if ps.config.gc_thresholds:
|
||||
gc.set_threshold(*ps.config.gc_thresholds)
|
||||
def run():
|
||||
with LoggingContext("run"):
|
||||
logger.info("Running")
|
||||
change_resource_limit(config.soft_file_limit)
|
||||
if config.gc_thresholds:
|
||||
gc.set_threshold(*config.gc_thresholds)
|
||||
reactor.run()
|
||||
|
||||
def start():
|
||||
ps.replicate()
|
||||
@@ -367,30 +295,20 @@ def setup(config_options):
|
||||
|
||||
reactor.callWhenRunning(start)
|
||||
|
||||
return ps
|
||||
if config.worker_daemonize:
|
||||
daemon = Daemonize(
|
||||
app="synapse-pusher",
|
||||
pid=config.worker_pid_file,
|
||||
action=run,
|
||||
auto_close_fds=False,
|
||||
verbose=True,
|
||||
logger=logger,
|
||||
)
|
||||
daemon.start()
|
||||
else:
|
||||
run()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
with LoggingContext("main"):
|
||||
ps = setup(sys.argv[1:])
|
||||
|
||||
if ps.config.daemonize:
|
||||
def run():
|
||||
with LoggingContext("run"):
|
||||
change_resource_limit(ps.config.soft_file_limit)
|
||||
if ps.config.gc_thresholds:
|
||||
gc.set_threshold(*ps.config.gc_thresholds)
|
||||
reactor.run()
|
||||
|
||||
daemon = Daemonize(
|
||||
app="synapse-pusher",
|
||||
pid=ps.config.pid_file,
|
||||
action=run,
|
||||
auto_close_fds=False,
|
||||
verbose=True,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
daemon.start()
|
||||
else:
|
||||
reactor.run()
|
||||
ps = start(sys.argv[1:])
|
||||
|
||||
@@ -18,9 +18,8 @@ import synapse
|
||||
|
||||
from synapse.api.constants import EventTypes, PresenceState
|
||||
from synapse.config._base import ConfigError
|
||||
from synapse.config.database import DatabaseConfig
|
||||
from synapse.config.logger import LoggingConfig
|
||||
from synapse.config.appservice import AppServiceConfig
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.config.logger import setup_logging
|
||||
from synapse.events import FrozenEvent
|
||||
from synapse.handlers.presence import PresenceHandler
|
||||
from synapse.http.site import SynapseSite
|
||||
@@ -63,70 +62,6 @@ import ujson as json
|
||||
logger = logging.getLogger("synapse.app.synchrotron")
|
||||
|
||||
|
||||
class SynchrotronConfig(DatabaseConfig, LoggingConfig, AppServiceConfig):
|
||||
def read_config(self, config):
|
||||
self.replication_url = config["replication_url"]
|
||||
self.server_name = config["server_name"]
|
||||
self.use_insecure_ssl_client_just_for_testing_do_not_use = config.get(
|
||||
"use_insecure_ssl_client_just_for_testing_do_not_use", False
|
||||
)
|
||||
self.user_agent_suffix = None
|
||||
self.listeners = config["listeners"]
|
||||
self.soft_file_limit = config.get("soft_file_limit")
|
||||
self.daemonize = config.get("daemonize")
|
||||
self.pid_file = self.abspath(config.get("pid_file"))
|
||||
self.macaroon_secret_key = config["macaroon_secret_key"]
|
||||
self.expire_access_token = config.get("expire_access_token", False)
|
||||
|
||||
thresholds = config.get("gc_thresholds", None)
|
||||
if thresholds is not None:
|
||||
try:
|
||||
assert len(thresholds) == 3
|
||||
self.gc_thresholds = (
|
||||
int(thresholds[0]), int(thresholds[1]), int(thresholds[2]),
|
||||
)
|
||||
except:
|
||||
raise ConfigError(
|
||||
"Value of `gc_threshold` must be a list of three integers if set"
|
||||
)
|
||||
else:
|
||||
self.gc_thresholds = None
|
||||
|
||||
def default_config(self, server_name, **kwargs):
|
||||
pid_file = self.abspath("synchroton.pid")
|
||||
return """\
|
||||
# Slave configuration
|
||||
|
||||
# The replication listener on the synapse to talk to.
|
||||
#replication_url: https://localhost:{replication_port}/_synapse/replication
|
||||
|
||||
server_name: "%(server_name)s"
|
||||
|
||||
listeners:
|
||||
# Enable a /sync listener on the synchrontron
|
||||
#- type: http
|
||||
# port: {http_port}
|
||||
# bind_address: ""
|
||||
# Enable a ssh manhole listener on the synchrotron
|
||||
# - type: manhole
|
||||
# port: {manhole_port}
|
||||
# bind_address: 127.0.0.1
|
||||
# Enable a metric listener on the synchrotron
|
||||
# - type: http
|
||||
# port: {metrics_port}
|
||||
# bind_address: 127.0.0.1
|
||||
# resources:
|
||||
# - names: ["metrics"]
|
||||
# compress: False
|
||||
|
||||
report_stats: False
|
||||
|
||||
daemonize: False
|
||||
|
||||
pid_file: %(pid_file)s
|
||||
""" % locals()
|
||||
|
||||
|
||||
class SynchrotronSlavedStore(
|
||||
SlavedPushRuleStore,
|
||||
SlavedEventStore,
|
||||
@@ -163,7 +98,7 @@ class SynchrotronPresence(object):
|
||||
self.http_client = hs.get_simple_http_client()
|
||||
self.store = hs.get_datastore()
|
||||
self.user_to_num_current_syncs = {}
|
||||
self.syncing_users_url = hs.config.replication_url + "/syncing_users"
|
||||
self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users"
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
active_presence = self.store.take_presence_startup_info()
|
||||
@@ -350,8 +285,8 @@ class SynchrotronServer(HomeServer):
|
||||
)
|
||||
logger.info("Synapse synchrotron now listening on port %d", port)
|
||||
|
||||
def start_listening(self):
|
||||
for listener in self.config.listeners:
|
||||
def start_listening(self, listeners):
|
||||
for listener in listeners:
|
||||
if listener["type"] == "http":
|
||||
self._listen_http(listener)
|
||||
elif listener["type"] == "manhole":
|
||||
@@ -371,7 +306,7 @@ class SynchrotronServer(HomeServer):
|
||||
def replicate(self):
|
||||
http_client = self.get_simple_http_client()
|
||||
store = self.get_datastore()
|
||||
replication_url = self.config.replication_url
|
||||
replication_url = self.config.worker_replication_url
|
||||
clock = self.get_clock()
|
||||
notifier = self.get_notifier()
|
||||
presence_handler = self.get_presence_handler()
|
||||
@@ -470,19 +405,18 @@ class SynchrotronServer(HomeServer):
|
||||
return SynchrotronTyping(self)
|
||||
|
||||
|
||||
def setup(config_options):
|
||||
def start(config_options):
|
||||
try:
|
||||
config = SynchrotronConfig.load_config(
|
||||
config = HomeServerConfig.load_config(
|
||||
"Synapse synchrotron", config_options
|
||||
)
|
||||
except ConfigError as e:
|
||||
sys.stderr.write("\n" + e.message + "\n")
|
||||
sys.exit(1)
|
||||
|
||||
if not config:
|
||||
sys.exit(0)
|
||||
assert config.worker_app == "synapse.app.synchrotron"
|
||||
|
||||
config.setup_logging()
|
||||
setup_logging(config.worker_log_config, config.worker_log_file)
|
||||
|
||||
database_engine = create_engine(config.database_config)
|
||||
|
||||
@@ -496,11 +430,15 @@ def setup(config_options):
|
||||
)
|
||||
|
||||
ss.setup()
|
||||
ss.start_listening()
|
||||
ss.start_listening(config.worker_listeners)
|
||||
|
||||
change_resource_limit(ss.config.soft_file_limit)
|
||||
if ss.config.gc_thresholds:
|
||||
ss.set_threshold(*ss.config.gc_thresholds)
|
||||
def run():
|
||||
with LoggingContext("run"):
|
||||
logger.info("Running")
|
||||
change_resource_limit(config.soft_file_limit)
|
||||
if config.gc_thresholds:
|
||||
gc.set_threshold(*config.gc_thresholds)
|
||||
reactor.run()
|
||||
|
||||
def start():
|
||||
ss.get_datastore().start_profiling()
|
||||
@@ -508,30 +446,20 @@ def setup(config_options):
|
||||
|
||||
reactor.callWhenRunning(start)
|
||||
|
||||
return ss
|
||||
if config.worker_daemonize:
|
||||
daemon = Daemonize(
|
||||
app="synapse-synchrotron",
|
||||
pid=config.worker_pid_file,
|
||||
action=run,
|
||||
auto_close_fds=False,
|
||||
verbose=True,
|
||||
logger=logger,
|
||||
)
|
||||
daemon.start()
|
||||
else:
|
||||
run()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
with LoggingContext("main"):
|
||||
ss = setup(sys.argv[1:])
|
||||
|
||||
if ss.config.daemonize:
|
||||
def run():
|
||||
with LoggingContext("run"):
|
||||
change_resource_limit(ss.config.soft_file_limit)
|
||||
if ss.config.gc_thresholds:
|
||||
gc.set_threshold(*ss.config.gc_thresholds)
|
||||
reactor.run()
|
||||
|
||||
daemon = Daemonize(
|
||||
app="synapse-synchrotron",
|
||||
pid=ss.config.pid_file,
|
||||
action=run,
|
||||
auto_close_fds=False,
|
||||
verbose=True,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
daemon.start()
|
||||
else:
|
||||
reactor.run()
|
||||
start(sys.argv[1:])
|
||||
|
||||
@@ -14,11 +14,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
import collections
|
||||
import glob
|
||||
import os
|
||||
import os.path
|
||||
import subprocess
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import yaml
|
||||
|
||||
SYNAPSE = ["python", "-B", "-m", "synapse.app.homeserver"]
|
||||
@@ -28,60 +31,181 @@ RED = "\x1b[1;31m"
|
||||
NORMAL = "\x1b[m"
|
||||
|
||||
|
||||
def write(message, colour=NORMAL, stream=sys.stdout):
|
||||
if colour == NORMAL:
|
||||
stream.write(message + "\n")
|
||||
else:
|
||||
stream.write(colour + message + NORMAL + "\n")
|
||||
|
||||
|
||||
def start(configfile):
|
||||
print ("Starting ...")
|
||||
write("Starting ...")
|
||||
args = SYNAPSE
|
||||
args.extend(["--daemonize", "-c", configfile])
|
||||
|
||||
try:
|
||||
subprocess.check_call(args)
|
||||
print (GREEN + "started" + NORMAL)
|
||||
write("started synapse.app.homeserver(%r)" % (configfile,), colour=GREEN)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print (
|
||||
RED +
|
||||
"error starting (exit code: %d); see above for logs" % e.returncode +
|
||||
NORMAL
|
||||
write(
|
||||
"error starting (exit code: %d); see above for logs" % e.returncode,
|
||||
colour=RED,
|
||||
)
|
||||
|
||||
|
||||
def stop(pidfile):
|
||||
def start_worker(app, configfile, worker_configfile):
|
||||
args = [
|
||||
"python", "-B",
|
||||
"-m", app,
|
||||
"-c", configfile,
|
||||
"-c", worker_configfile
|
||||
]
|
||||
|
||||
try:
|
||||
subprocess.check_call(args)
|
||||
write("started %s(%r)" % (app, worker_configfile), colour=GREEN)
|
||||
except subprocess.CalledProcessError as e:
|
||||
write(
|
||||
"error starting %s(%r) (exit code: %d); see above for logs" % (
|
||||
app, worker_configfile, e.returncode,
|
||||
),
|
||||
colour=RED,
|
||||
)
|
||||
|
||||
|
||||
def stop(pidfile, app):
|
||||
if os.path.exists(pidfile):
|
||||
pid = int(open(pidfile).read())
|
||||
os.kill(pid, signal.SIGTERM)
|
||||
print (GREEN + "stopped" + NORMAL)
|
||||
write("stopped %s" % (app,), colour=GREEN)
|
||||
|
||||
|
||||
Worker = collections.namedtuple("Worker", [
|
||||
"app", "configfile", "pidfile", "cache_factor"
|
||||
])
|
||||
|
||||
|
||||
def main():
|
||||
configfile = sys.argv[2] if len(sys.argv) == 3 else "homeserver.yaml"
|
||||
|
||||
if not os.path.exists(configfile):
|
||||
sys.stderr.write(
|
||||
"No config file found\n"
|
||||
"To generate a config file, run '%s -c %s --generate-config"
|
||||
" --server-name=<server name>'\n" % (
|
||||
" ".join(SYNAPSE), configfile
|
||||
)
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"action",
|
||||
choices=["start", "stop", "restart"],
|
||||
help="whether to start, stop or restart the synapse",
|
||||
)
|
||||
parser.add_argument(
|
||||
"configfile",
|
||||
nargs="?",
|
||||
default="homeserver.yaml",
|
||||
help="the homeserver config file, defaults to homserver.yaml",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-w", "--worker",
|
||||
metavar="WORKERCONFIG",
|
||||
help="start or stop a single worker",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-a", "--all-processes",
|
||||
metavar="WORKERCONFIGDIR",
|
||||
help="start or stop all the workers in the given directory"
|
||||
" and the main synapse process",
|
||||
)
|
||||
|
||||
options = parser.parse_args()
|
||||
|
||||
if options.worker and options.all_processes:
|
||||
write(
|
||||
'Cannot use "--worker" with "--all-processes"',
|
||||
stream=sys.stderr
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
config = yaml.load(open(configfile))
|
||||
configfile = options.configfile
|
||||
|
||||
if not os.path.exists(configfile):
|
||||
write(
|
||||
"No config file found\n"
|
||||
"To generate a config file, run '%s -c %s --generate-config"
|
||||
" --server-name=<server name>'\n" % (
|
||||
" ".join(SYNAPSE), options.configfile
|
||||
),
|
||||
stream=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
with open(configfile) as stream:
|
||||
config = yaml.load(stream)
|
||||
|
||||
pidfile = config["pid_file"]
|
||||
cache_factor = config.get("synctl_cache_factor", None)
|
||||
cache_factor = config.get("synctl_cache_factor")
|
||||
start_stop_synapse = True
|
||||
|
||||
if cache_factor:
|
||||
os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor)
|
||||
|
||||
action = sys.argv[1] if sys.argv[1:] else "usage"
|
||||
if action == "start":
|
||||
start(configfile)
|
||||
elif action == "stop":
|
||||
stop(pidfile)
|
||||
elif action == "restart":
|
||||
stop(pidfile)
|
||||
start(configfile)
|
||||
else:
|
||||
sys.stderr.write("Usage: %s [start|stop|restart] [configfile]\n" % (sys.argv[0],))
|
||||
sys.exit(1)
|
||||
worker_configfiles = []
|
||||
if options.worker:
|
||||
start_stop_synapse = False
|
||||
worker_configfile = options.worker
|
||||
if not os.path.exists(worker_configfile):
|
||||
write(
|
||||
"No worker config found at %r" % (worker_configfile,),
|
||||
stream=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
worker_configfiles.append(worker_configfile)
|
||||
|
||||
if options.all_processes:
|
||||
worker_configdir = options.all_processes
|
||||
if not os.path.isdir(worker_configdir):
|
||||
write(
|
||||
"No worker config directory found at %r" % (worker_configdir,),
|
||||
stream=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
worker_configfiles.extend(sorted(glob.glob(
|
||||
os.path.join(worker_configdir, "*.yaml")
|
||||
)))
|
||||
|
||||
workers = []
|
||||
for worker_configfile in worker_configfiles:
|
||||
with open(worker_configfile) as stream:
|
||||
worker_config = yaml.load(stream)
|
||||
worker_app = worker_config["worker_app"]
|
||||
worker_pidfile = worker_config["worker_pid_file"]
|
||||
worker_daemonize = worker_config["worker_daemonize"]
|
||||
assert worker_daemonize # TODO print something more user friendly
|
||||
worker_cache_factor = worker_config.get("synctl_cache_factor")
|
||||
workers.append(Worker(
|
||||
worker_app, worker_configfile, worker_pidfile, worker_cache_factor,
|
||||
))
|
||||
|
||||
action = options.action
|
||||
|
||||
if action == "stop" or action == "restart":
|
||||
for worker in workers:
|
||||
stop(worker.pidfile, worker.app)
|
||||
|
||||
if start_stop_synapse:
|
||||
stop(pidfile, "synapse.app.homeserver")
|
||||
|
||||
# TODO: Wait for synapse to actually shutdown before starting it again
|
||||
|
||||
if action == "start" or action == "restart":
|
||||
if start_stop_synapse:
|
||||
start(configfile)
|
||||
|
||||
for worker in workers:
|
||||
if worker.cache_factor:
|
||||
os.environ["SYNAPSE_CACHE_FACTOR"] = str(worker.cache_factor)
|
||||
|
||||
start_worker(worker.app, configfile, worker.configfile)
|
||||
|
||||
if cache_factor:
|
||||
os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor)
|
||||
else:
|
||||
os.environ.pop("SYNAPSE_CACHE_FACTOR", None)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -157,9 +157,40 @@ class Config(object):
|
||||
return default_config, config
|
||||
|
||||
@classmethod
|
||||
def load_config(cls, description, argv, generate_section=None):
|
||||
obj = cls()
|
||||
def load_config(cls, description, argv):
|
||||
config_parser = argparse.ArgumentParser(
|
||||
description=description,
|
||||
)
|
||||
config_parser.add_argument(
|
||||
"-c", "--config-path",
|
||||
action="append",
|
||||
metavar="CONFIG_FILE",
|
||||
help="Specify config file. Can be given multiple times and"
|
||||
" may specify directories containing *.yaml files."
|
||||
)
|
||||
|
||||
config_parser.add_argument(
|
||||
"--keys-directory",
|
||||
metavar="DIRECTORY",
|
||||
help="Where files such as certs and signing keys are stored when"
|
||||
" their location is given explicitly in the config."
|
||||
" Defaults to the directory containing the last config file",
|
||||
)
|
||||
|
||||
config_args = config_parser.parse_args(argv)
|
||||
|
||||
config_files = find_config_files(search_paths=config_args.config_path)
|
||||
|
||||
obj = cls()
|
||||
obj.read_config_files(
|
||||
config_files,
|
||||
keys_directory=config_args.keys_directory,
|
||||
generate_keys=False,
|
||||
)
|
||||
return obj
|
||||
|
||||
@classmethod
|
||||
def load_or_generate_config(cls, description, argv):
|
||||
config_parser = argparse.ArgumentParser(add_help=False)
|
||||
config_parser.add_argument(
|
||||
"-c", "--config-path",
|
||||
@@ -176,7 +207,7 @@ class Config(object):
|
||||
config_parser.add_argument(
|
||||
"--report-stats",
|
||||
action="store",
|
||||
help="Stuff",
|
||||
help="Whether the generated config reports anonymized usage statistics",
|
||||
choices=["yes", "no"]
|
||||
)
|
||||
config_parser.add_argument(
|
||||
@@ -197,36 +228,11 @@ class Config(object):
|
||||
)
|
||||
config_args, remaining_args = config_parser.parse_known_args(argv)
|
||||
|
||||
config_files = find_config_files(search_paths=config_args.config_path)
|
||||
|
||||
generate_keys = config_args.generate_keys
|
||||
|
||||
config_files = []
|
||||
if config_args.config_path:
|
||||
for config_path in config_args.config_path:
|
||||
if os.path.isdir(config_path):
|
||||
# We accept specifying directories as config paths, we search
|
||||
# inside that directory for all files matching *.yaml, and then
|
||||
# we apply them in *sorted* order.
|
||||
files = []
|
||||
for entry in os.listdir(config_path):
|
||||
entry_path = os.path.join(config_path, entry)
|
||||
if not os.path.isfile(entry_path):
|
||||
print (
|
||||
"Found subdirectory in config directory: %r. IGNORING."
|
||||
) % (entry_path, )
|
||||
continue
|
||||
|
||||
if not entry.endswith(".yaml"):
|
||||
print (
|
||||
"Found file in config directory that does not"
|
||||
" end in '.yaml': %r. IGNORING."
|
||||
) % (entry_path, )
|
||||
continue
|
||||
|
||||
files.append(entry_path)
|
||||
|
||||
config_files.extend(sorted(files))
|
||||
else:
|
||||
config_files.append(config_path)
|
||||
obj = cls()
|
||||
|
||||
if config_args.generate_config:
|
||||
if config_args.report_stats is None:
|
||||
@@ -299,28 +305,43 @@ class Config(object):
|
||||
" -c CONFIG-FILE\""
|
||||
)
|
||||
|
||||
if config_args.keys_directory:
|
||||
config_dir_path = config_args.keys_directory
|
||||
else:
|
||||
config_dir_path = os.path.dirname(config_args.config_path[-1])
|
||||
config_dir_path = os.path.abspath(config_dir_path)
|
||||
obj.read_config_files(
|
||||
config_files,
|
||||
keys_directory=config_args.keys_directory,
|
||||
generate_keys=generate_keys,
|
||||
)
|
||||
|
||||
if generate_keys:
|
||||
return None
|
||||
|
||||
obj.invoke_all("read_arguments", args)
|
||||
|
||||
return obj
|
||||
|
||||
def read_config_files(self, config_files, keys_directory=None,
|
||||
generate_keys=False):
|
||||
if not keys_directory:
|
||||
keys_directory = os.path.dirname(config_files[-1])
|
||||
|
||||
config_dir_path = os.path.abspath(keys_directory)
|
||||
|
||||
specified_config = {}
|
||||
for config_file in config_files:
|
||||
yaml_config = cls.read_config_file(config_file)
|
||||
yaml_config = self.read_config_file(config_file)
|
||||
specified_config.update(yaml_config)
|
||||
|
||||
if "server_name" not in specified_config:
|
||||
raise ConfigError(MISSING_SERVER_NAME)
|
||||
|
||||
server_name = specified_config["server_name"]
|
||||
_, config = obj.generate_config(
|
||||
_, config = self.generate_config(
|
||||
config_dir_path=config_dir_path,
|
||||
server_name=server_name,
|
||||
is_generating_file=False,
|
||||
)
|
||||
config.pop("log_config")
|
||||
config.update(specified_config)
|
||||
|
||||
if "report_stats" not in config:
|
||||
raise ConfigError(
|
||||
MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" +
|
||||
@@ -328,11 +349,51 @@ class Config(object):
|
||||
)
|
||||
|
||||
if generate_keys:
|
||||
obj.invoke_all("generate_files", config)
|
||||
self.invoke_all("generate_files", config)
|
||||
return
|
||||
|
||||
obj.invoke_all("read_config", config)
|
||||
self.invoke_all("read_config", config)
|
||||
|
||||
obj.invoke_all("read_arguments", args)
|
||||
|
||||
return obj
|
||||
def find_config_files(search_paths):
|
||||
"""Finds config files using a list of search paths. If a path is a file
|
||||
then that file path is added to the list. If a search path is a directory
|
||||
then all the "*.yaml" files in that directory are added to the list in
|
||||
sorted order.
|
||||
|
||||
Args:
|
||||
search_paths(list(str)): A list of paths to search.
|
||||
|
||||
Returns:
|
||||
list(str): A list of file paths.
|
||||
"""
|
||||
|
||||
config_files = []
|
||||
if search_paths:
|
||||
for config_path in search_paths:
|
||||
if os.path.isdir(config_path):
|
||||
# We accept specifying directories as config paths, we search
|
||||
# inside that directory for all files matching *.yaml, and then
|
||||
# we apply them in *sorted* order.
|
||||
files = []
|
||||
for entry in os.listdir(config_path):
|
||||
entry_path = os.path.join(config_path, entry)
|
||||
if not os.path.isfile(entry_path):
|
||||
print (
|
||||
"Found subdirectory in config directory: %r. IGNORING."
|
||||
) % (entry_path, )
|
||||
continue
|
||||
|
||||
if not entry.endswith(".yaml"):
|
||||
print (
|
||||
"Found file in config directory that does not"
|
||||
" end in '.yaml': %r. IGNORING."
|
||||
) % (entry_path, )
|
||||
continue
|
||||
|
||||
files.append(entry_path)
|
||||
|
||||
config_files.extend(sorted(files))
|
||||
else:
|
||||
config_files.append(config_path)
|
||||
return config_files
|
||||
|
||||
@@ -27,6 +27,7 @@ class CaptchaConfig(Config):
|
||||
def default_config(self, **kwargs):
|
||||
return """\
|
||||
## Captcha ##
|
||||
# See docs/CAPTCHA_SETUP for full details of configuring this.
|
||||
|
||||
# This Home Server's ReCAPTCHA public key.
|
||||
recaptcha_public_key: "YOUR_PUBLIC_KEY"
|
||||
|
||||
@@ -32,13 +32,15 @@ from .password import PasswordConfig
|
||||
from .jwt import JWTConfig
|
||||
from .ldap import LDAPConfig
|
||||
from .emailconfig import EmailConfig
|
||||
from .workers import WorkerConfig
|
||||
|
||||
|
||||
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
|
||||
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
|
||||
VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
|
||||
AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
|
||||
JWTConfig, LDAPConfig, PasswordConfig, EmailConfig,):
|
||||
JWTConfig, LDAPConfig, PasswordConfig, EmailConfig,
|
||||
WorkerConfig,):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
@@ -13,40 +13,88 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import Config
|
||||
from ._base import Config, ConfigError
|
||||
|
||||
|
||||
MISSING_LDAP3 = (
|
||||
"Missing ldap3 library. This is required for LDAP Authentication."
|
||||
)
|
||||
|
||||
|
||||
class LDAPMode(object):
|
||||
SIMPLE = "simple",
|
||||
SEARCH = "search",
|
||||
|
||||
LIST = (SIMPLE, SEARCH)
|
||||
|
||||
|
||||
class LDAPConfig(Config):
|
||||
def read_config(self, config):
|
||||
ldap_config = config.get("ldap_config", None)
|
||||
if ldap_config:
|
||||
self.ldap_enabled = ldap_config.get("enabled", False)
|
||||
self.ldap_server = ldap_config["server"]
|
||||
self.ldap_port = ldap_config["port"]
|
||||
self.ldap_tls = ldap_config.get("tls", False)
|
||||
self.ldap_search_base = ldap_config["search_base"]
|
||||
self.ldap_search_property = ldap_config["search_property"]
|
||||
self.ldap_email_property = ldap_config["email_property"]
|
||||
self.ldap_full_name_property = ldap_config["full_name_property"]
|
||||
else:
|
||||
self.ldap_enabled = False
|
||||
self.ldap_server = None
|
||||
self.ldap_port = None
|
||||
self.ldap_tls = False
|
||||
self.ldap_search_base = None
|
||||
self.ldap_search_property = None
|
||||
self.ldap_email_property = None
|
||||
self.ldap_full_name_property = None
|
||||
ldap_config = config.get("ldap_config", {})
|
||||
|
||||
self.ldap_enabled = ldap_config.get("enabled", False)
|
||||
|
||||
if self.ldap_enabled:
|
||||
# verify dependencies are available
|
||||
try:
|
||||
import ldap3
|
||||
ldap3 # to stop unused lint
|
||||
except ImportError:
|
||||
raise ConfigError(MISSING_LDAP3)
|
||||
|
||||
self.ldap_mode = LDAPMode.SIMPLE
|
||||
|
||||
# verify config sanity
|
||||
self.require_keys(ldap_config, [
|
||||
"uri",
|
||||
"base",
|
||||
"attributes",
|
||||
])
|
||||
|
||||
self.ldap_uri = ldap_config["uri"]
|
||||
self.ldap_start_tls = ldap_config.get("start_tls", False)
|
||||
self.ldap_base = ldap_config["base"]
|
||||
self.ldap_attributes = ldap_config["attributes"]
|
||||
|
||||
if "bind_dn" in ldap_config:
|
||||
self.ldap_mode = LDAPMode.SEARCH
|
||||
self.require_keys(ldap_config, [
|
||||
"bind_dn",
|
||||
"bind_password",
|
||||
])
|
||||
|
||||
self.ldap_bind_dn = ldap_config["bind_dn"]
|
||||
self.ldap_bind_password = ldap_config["bind_password"]
|
||||
self.ldap_filter = ldap_config.get("filter", None)
|
||||
|
||||
# verify attribute lookup
|
||||
self.require_keys(ldap_config['attributes'], [
|
||||
"uid",
|
||||
"name",
|
||||
"mail",
|
||||
])
|
||||
|
||||
def require_keys(self, config, required):
|
||||
missing = [key for key in required if key not in config]
|
||||
if missing:
|
||||
raise ConfigError(
|
||||
"LDAP enabled but missing required config values: {}".format(
|
||||
", ".join(missing)
|
||||
)
|
||||
)
|
||||
|
||||
def default_config(self, **kwargs):
|
||||
return """\
|
||||
# ldap_config:
|
||||
# enabled: true
|
||||
# server: "ldap://localhost"
|
||||
# port: 389
|
||||
# tls: false
|
||||
# search_base: "ou=Users,dc=example,dc=com"
|
||||
# search_property: "cn"
|
||||
# email_property: "email"
|
||||
# full_name_property: "givenName"
|
||||
# uri: "ldap://ldap.example.com:389"
|
||||
# start_tls: true
|
||||
# base: "ou=users,dc=example,dc=com"
|
||||
# attributes:
|
||||
# uid: "cn"
|
||||
# mail: "email"
|
||||
# name: "givenName"
|
||||
# #bind_dn:
|
||||
# #bind_password:
|
||||
# #filter: "(objectClass=posixAccount)"
|
||||
"""
|
||||
|
||||
@@ -126,54 +126,58 @@ class LoggingConfig(Config):
|
||||
)
|
||||
|
||||
def setup_logging(self):
|
||||
log_format = (
|
||||
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
|
||||
" - %(message)s"
|
||||
)
|
||||
if self.log_config is None:
|
||||
setup_logging(self.log_config, self.log_file, self.verbosity)
|
||||
|
||||
level = logging.INFO
|
||||
level_for_storage = logging.INFO
|
||||
if self.verbosity:
|
||||
level = logging.DEBUG
|
||||
if self.verbosity > 1:
|
||||
level_for_storage = logging.DEBUG
|
||||
|
||||
# FIXME: we need a logging.WARN for a -q quiet option
|
||||
logger = logging.getLogger('')
|
||||
logger.setLevel(level)
|
||||
def setup_logging(log_config=None, log_file=None, verbosity=None):
|
||||
log_format = (
|
||||
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
|
||||
" - %(message)s"
|
||||
)
|
||||
if log_config is None:
|
||||
|
||||
logging.getLogger('synapse.storage').setLevel(level_for_storage)
|
||||
level = logging.INFO
|
||||
level_for_storage = logging.INFO
|
||||
if verbosity:
|
||||
level = logging.DEBUG
|
||||
if verbosity > 1:
|
||||
level_for_storage = logging.DEBUG
|
||||
|
||||
formatter = logging.Formatter(log_format)
|
||||
if self.log_file:
|
||||
# TODO: Customisable file size / backup count
|
||||
handler = logging.handlers.RotatingFileHandler(
|
||||
self.log_file, maxBytes=(1000 * 1000 * 100), backupCount=3
|
||||
)
|
||||
# FIXME: we need a logging.WARN for a -q quiet option
|
||||
logger = logging.getLogger('')
|
||||
logger.setLevel(level)
|
||||
|
||||
def sighup(signum, stack):
|
||||
logger.info("Closing log file due to SIGHUP")
|
||||
handler.doRollover()
|
||||
logger.info("Opened new log file due to SIGHUP")
|
||||
logging.getLogger('synapse.storage').setLevel(level_for_storage)
|
||||
|
||||
# TODO(paul): obviously this is a terrible mechanism for
|
||||
# stealing SIGHUP, because it means no other part of synapse
|
||||
# can use it instead. If we want to catch SIGHUP anywhere
|
||||
# else as well, I'd suggest we find a nicer way to broadcast
|
||||
# it around.
|
||||
if getattr(signal, "SIGHUP"):
|
||||
signal.signal(signal.SIGHUP, sighup)
|
||||
else:
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(formatter)
|
||||
formatter = logging.Formatter(log_format)
|
||||
if log_file:
|
||||
# TODO: Customisable file size / backup count
|
||||
handler = logging.handlers.RotatingFileHandler(
|
||||
log_file, maxBytes=(1000 * 1000 * 100), backupCount=3
|
||||
)
|
||||
|
||||
handler.addFilter(LoggingContextFilter(request=""))
|
||||
def sighup(signum, stack):
|
||||
logger.info("Closing log file due to SIGHUP")
|
||||
handler.doRollover()
|
||||
logger.info("Opened new log file due to SIGHUP")
|
||||
|
||||
logger.addHandler(handler)
|
||||
# TODO(paul): obviously this is a terrible mechanism for
|
||||
# stealing SIGHUP, because it means no other part of synapse
|
||||
# can use it instead. If we want to catch SIGHUP anywhere
|
||||
# else as well, I'd suggest we find a nicer way to broadcast
|
||||
# it around.
|
||||
if getattr(signal, "SIGHUP"):
|
||||
signal.signal(signal.SIGHUP, sighup)
|
||||
else:
|
||||
with open(self.log_config, 'r') as f:
|
||||
logging.config.dictConfig(yaml.load(f))
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
observer = PythonLoggingObserver()
|
||||
observer.start()
|
||||
handler.addFilter(LoggingContextFilter(request=""))
|
||||
|
||||
logger.addHandler(handler)
|
||||
else:
|
||||
with open(log_config, 'r') as f:
|
||||
logging.config.dictConfig(yaml.load(f))
|
||||
|
||||
observer = PythonLoggingObserver()
|
||||
observer.start()
|
||||
|
||||
@@ -23,10 +23,14 @@ class PasswordConfig(Config):
|
||||
def read_config(self, config):
|
||||
password_config = config.get("password_config", {})
|
||||
self.password_enabled = password_config.get("enabled", True)
|
||||
self.password_pepper = password_config.get("pepper", "")
|
||||
|
||||
def default_config(self, config_dir_path, server_name, **kwargs):
|
||||
return """
|
||||
# Enable password for login.
|
||||
password_config:
|
||||
enabled: true
|
||||
# Uncomment and change to a secret random string for extra security.
|
||||
# DO NOT CHANGE THIS AFTER INITIAL SETUP!
|
||||
#pepper: ""
|
||||
"""
|
||||
|
||||
@@ -27,7 +27,7 @@ class ServerConfig(Config):
|
||||
self.daemonize = config.get("daemonize")
|
||||
self.print_pidfile = config.get("print_pidfile")
|
||||
self.user_agent_suffix = config.get("user_agent_suffix")
|
||||
self.use_frozen_dicts = config.get("use_frozen_dicts", True)
|
||||
self.use_frozen_dicts = config.get("use_frozen_dicts", False)
|
||||
self.public_baseurl = config.get("public_baseurl")
|
||||
self.secondary_directory_servers = config.get("secondary_directory_servers", [])
|
||||
|
||||
@@ -38,19 +38,7 @@ class ServerConfig(Config):
|
||||
|
||||
self.listeners = config.get("listeners", [])
|
||||
|
||||
thresholds = config.get("gc_thresholds", None)
|
||||
if thresholds is not None:
|
||||
try:
|
||||
assert len(thresholds) == 3
|
||||
self.gc_thresholds = (
|
||||
int(thresholds[0]), int(thresholds[1]), int(thresholds[2]),
|
||||
)
|
||||
except:
|
||||
raise ConfigError(
|
||||
"Value of `gc_threshold` must be a list of three integers if set"
|
||||
)
|
||||
else:
|
||||
self.gc_thresholds = None
|
||||
self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None))
|
||||
|
||||
bind_port = config.get("bind_port")
|
||||
if bind_port:
|
||||
@@ -119,26 +107,6 @@ class ServerConfig(Config):
|
||||
]
|
||||
})
|
||||
|
||||
# Attempt to guess the content_addr for the v0 content repostitory
|
||||
content_addr = config.get("content_addr")
|
||||
if not content_addr:
|
||||
for listener in self.listeners:
|
||||
if listener["type"] == "http" and not listener.get("tls", False):
|
||||
unsecure_port = listener["port"]
|
||||
break
|
||||
else:
|
||||
raise RuntimeError("Could not determine 'content_addr'")
|
||||
|
||||
host = self.server_name
|
||||
if ':' not in host:
|
||||
host = "%s:%d" % (host, unsecure_port)
|
||||
else:
|
||||
host = host.split(':')[0]
|
||||
host = "%s:%d" % (host, unsecure_port)
|
||||
content_addr = "http://%s" % (host,)
|
||||
|
||||
self.content_addr = content_addr
|
||||
|
||||
def default_config(self, server_name, **kwargs):
|
||||
if ":" in server_name:
|
||||
bind_port = int(server_name.split(":")[1])
|
||||
@@ -181,7 +149,6 @@ class ServerConfig(Config):
|
||||
# room directory.
|
||||
# secondary_directory_servers:
|
||||
# - matrix.org
|
||||
# - vector.im
|
||||
|
||||
# List of ports that Synapse should listen on, their purpose and their
|
||||
# configuration.
|
||||
@@ -264,3 +231,20 @@ class ServerConfig(Config):
|
||||
type=int,
|
||||
help="Turn on the twisted telnet manhole"
|
||||
" service on the given port.")
|
||||
|
||||
|
||||
def read_gc_thresholds(thresholds):
|
||||
"""Reads the three integer thresholds for garbage collection. Ensures that
|
||||
the thresholds are integers if thresholds are supplied.
|
||||
"""
|
||||
if thresholds is None:
|
||||
return None
|
||||
try:
|
||||
assert len(thresholds) == 3
|
||||
return (
|
||||
int(thresholds[0]), int(thresholds[1]), int(thresholds[2]),
|
||||
)
|
||||
except:
|
||||
raise ConfigError(
|
||||
"Value of `gc_threshold` must be a list of three integers if set"
|
||||
)
|
||||
|
||||
31
synapse/config/workers.py
Normal file
31
synapse/config/workers.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import Config
|
||||
|
||||
|
||||
class WorkerConfig(Config):
|
||||
"""The workers are processes run separately to the main synapse process.
|
||||
They have their own pid_file and listener configuration. They use the
|
||||
replication_url to talk to the main synapse process."""
|
||||
|
||||
def read_config(self, config):
|
||||
self.worker_app = config.get("worker_app")
|
||||
self.worker_listeners = config.get("worker_listeners")
|
||||
self.worker_daemonize = config.get("worker_daemonize")
|
||||
self.worker_pid_file = config.get("worker_pid_file")
|
||||
self.worker_log_file = config.get("worker_log_file")
|
||||
self.worker_log_config = config.get("worker_log_config")
|
||||
self.worker_replication_url = config.get("worker_replication_url")
|
||||
@@ -31,6 +31,9 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FederationBase(object):
|
||||
def __init__(self, hs):
|
||||
pass
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
|
||||
include_none=False):
|
||||
|
||||
@@ -52,6 +52,8 @@ sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"])
|
||||
|
||||
|
||||
class FederationClient(FederationBase):
|
||||
def __init__(self, hs):
|
||||
super(FederationClient, self).__init__(hs)
|
||||
|
||||
def start_get_pdu_cache(self):
|
||||
self._get_pdu_cache = ExpiringCache(
|
||||
|
||||
@@ -19,6 +19,7 @@ from twisted.internet import defer
|
||||
from .federation_base import FederationBase
|
||||
from .units import Transaction, Edu
|
||||
|
||||
from synapse.util.async import Linearizer
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.events import FrozenEvent
|
||||
import synapse.metrics
|
||||
@@ -44,6 +45,12 @@ received_queries_counter = metrics.register_counter("received_queries", labels=[
|
||||
|
||||
|
||||
class FederationServer(FederationBase):
|
||||
def __init__(self, hs):
|
||||
super(FederationServer, self).__init__(hs)
|
||||
|
||||
self._room_pdu_linearizer = Linearizer()
|
||||
self._server_linearizer = Linearizer()
|
||||
|
||||
def set_handler(self, handler):
|
||||
"""Sets the handler that the replication layer will use to communicate
|
||||
receipt of new PDUs from other home servers. The required methods are
|
||||
@@ -83,11 +90,14 @@ class FederationServer(FederationBase):
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def on_backfill_request(self, origin, room_id, versions, limit):
|
||||
pdus = yield self.handler.on_backfill_request(
|
||||
origin, room_id, versions, limit
|
||||
)
|
||||
with (yield self._server_linearizer.queue((origin, room_id))):
|
||||
pdus = yield self.handler.on_backfill_request(
|
||||
origin, room_id, versions, limit
|
||||
)
|
||||
|
||||
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
|
||||
res = self._transaction_from_pdus(pdus).get_dict()
|
||||
|
||||
defer.returnValue((200, res))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
@@ -178,24 +188,28 @@ class FederationServer(FederationBase):
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def on_context_state_request(self, origin, room_id, event_id):
|
||||
if event_id:
|
||||
pdus = yield self.handler.get_state_for_pdu(
|
||||
origin, room_id, event_id,
|
||||
)
|
||||
auth_chain = yield self.store.get_auth_chain(
|
||||
[pdu.event_id for pdu in pdus]
|
||||
)
|
||||
|
||||
for event in auth_chain:
|
||||
event.signatures.update(
|
||||
compute_event_signature(
|
||||
event,
|
||||
self.hs.hostname,
|
||||
self.hs.config.signing_key[0]
|
||||
)
|
||||
with (yield self._server_linearizer.queue((origin, room_id))):
|
||||
if event_id:
|
||||
pdus = yield self.handler.get_state_for_pdu(
|
||||
origin, room_id, event_id,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Specify an event")
|
||||
auth_chain = yield self.store.get_auth_chain(
|
||||
[pdu.event_id for pdu in pdus]
|
||||
)
|
||||
|
||||
for event in auth_chain:
|
||||
# We sign these again because there was a bug where we
|
||||
# incorrectly signed things the first time round
|
||||
if self.hs.is_mine_id(event.event_id):
|
||||
event.signatures.update(
|
||||
compute_event_signature(
|
||||
event,
|
||||
self.hs.hostname,
|
||||
self.hs.config.signing_key[0]
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Specify an event")
|
||||
|
||||
defer.returnValue((200, {
|
||||
"pdus": [pdu.get_pdu_json() for pdu in pdus],
|
||||
@@ -274,14 +288,16 @@ class FederationServer(FederationBase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_event_auth(self, origin, room_id, event_id):
|
||||
time_now = self._clock.time_msec()
|
||||
auth_pdus = yield self.handler.on_event_auth(event_id)
|
||||
defer.returnValue((200, {
|
||||
"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
|
||||
}))
|
||||
with (yield self._server_linearizer.queue((origin, room_id))):
|
||||
time_now = self._clock.time_msec()
|
||||
auth_pdus = yield self.handler.on_event_auth(event_id)
|
||||
res = {
|
||||
"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
|
||||
}
|
||||
defer.returnValue((200, res))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_query_auth_request(self, origin, content, event_id):
|
||||
def on_query_auth_request(self, origin, content, room_id, event_id):
|
||||
"""
|
||||
Content is a dict with keys::
|
||||
auth_chain (list): A list of events that give the auth chain.
|
||||
@@ -300,32 +316,33 @@ class FederationServer(FederationBase):
|
||||
Returns:
|
||||
Deferred: Results in `dict` with the same format as `content`
|
||||
"""
|
||||
auth_chain = [
|
||||
self.event_from_pdu_json(e)
|
||||
for e in content["auth_chain"]
|
||||
]
|
||||
with (yield self._server_linearizer.queue((origin, room_id))):
|
||||
auth_chain = [
|
||||
self.event_from_pdu_json(e)
|
||||
for e in content["auth_chain"]
|
||||
]
|
||||
|
||||
signed_auth = yield self._check_sigs_and_hash_and_fetch(
|
||||
origin, auth_chain, outlier=True
|
||||
)
|
||||
signed_auth = yield self._check_sigs_and_hash_and_fetch(
|
||||
origin, auth_chain, outlier=True
|
||||
)
|
||||
|
||||
ret = yield self.handler.on_query_auth(
|
||||
origin,
|
||||
event_id,
|
||||
signed_auth,
|
||||
content.get("rejects", []),
|
||||
content.get("missing", []),
|
||||
)
|
||||
ret = yield self.handler.on_query_auth(
|
||||
origin,
|
||||
event_id,
|
||||
signed_auth,
|
||||
content.get("rejects", []),
|
||||
content.get("missing", []),
|
||||
)
|
||||
|
||||
time_now = self._clock.time_msec()
|
||||
send_content = {
|
||||
"auth_chain": [
|
||||
e.get_pdu_json(time_now)
|
||||
for e in ret["auth_chain"]
|
||||
],
|
||||
"rejects": ret.get("rejects", []),
|
||||
"missing": ret.get("missing", []),
|
||||
}
|
||||
time_now = self._clock.time_msec()
|
||||
send_content = {
|
||||
"auth_chain": [
|
||||
e.get_pdu_json(time_now)
|
||||
for e in ret["auth_chain"]
|
||||
],
|
||||
"rejects": ret.get("rejects", []),
|
||||
"missing": ret.get("missing", []),
|
||||
}
|
||||
|
||||
defer.returnValue(
|
||||
(200, send_content)
|
||||
@@ -377,11 +394,24 @@ class FederationServer(FederationBase):
|
||||
@log_function
|
||||
def on_get_missing_events(self, origin, room_id, earliest_events,
|
||||
latest_events, limit, min_depth):
|
||||
missing_events = yield self.handler.on_get_missing_events(
|
||||
origin, room_id, earliest_events, latest_events, limit, min_depth
|
||||
)
|
||||
with (yield self._server_linearizer.queue((origin, room_id))):
|
||||
logger.info(
|
||||
"on_get_missing_events: earliest_events: %r, latest_events: %r,"
|
||||
" limit: %d, min_depth: %d",
|
||||
earliest_events, latest_events, limit, min_depth
|
||||
)
|
||||
missing_events = yield self.handler.on_get_missing_events(
|
||||
origin, room_id, earliest_events, latest_events, limit, min_depth
|
||||
)
|
||||
|
||||
time_now = self._clock.time_msec()
|
||||
if len(missing_events) < 5:
|
||||
logger.info(
|
||||
"Returning %d events: %r", len(missing_events), missing_events
|
||||
)
|
||||
else:
|
||||
logger.info("Returning %d events", len(missing_events))
|
||||
|
||||
time_now = self._clock.time_msec()
|
||||
|
||||
defer.returnValue({
|
||||
"events": [ev.get_pdu_json(time_now) for ev in missing_events],
|
||||
@@ -481,42 +511,59 @@ class FederationServer(FederationBase):
|
||||
pdu.internal_metadata.outlier = True
|
||||
elif min_depth and pdu.depth > min_depth:
|
||||
if get_missing and prevs - seen:
|
||||
latest = yield self.store.get_latest_event_ids_in_room(
|
||||
pdu.room_id
|
||||
)
|
||||
# If we're missing stuff, ensure we only fetch stuff one
|
||||
# at a time.
|
||||
with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
|
||||
# We recalculate seen, since it may have changed.
|
||||
have_seen = yield self.store.have_events(prevs)
|
||||
seen = set(have_seen.keys())
|
||||
|
||||
# We add the prev events that we have seen to the latest
|
||||
# list to ensure the remote server doesn't give them to us
|
||||
latest = set(latest)
|
||||
latest |= seen
|
||||
if prevs - seen:
|
||||
latest = yield self.store.get_latest_event_ids_in_room(
|
||||
pdu.room_id
|
||||
)
|
||||
|
||||
missing_events = yield self.get_missing_events(
|
||||
origin,
|
||||
pdu.room_id,
|
||||
earliest_events_ids=list(latest),
|
||||
latest_events=[pdu],
|
||||
limit=10,
|
||||
min_depth=min_depth,
|
||||
)
|
||||
# We add the prev events that we have seen to the latest
|
||||
# list to ensure the remote server doesn't give them to us
|
||||
latest = set(latest)
|
||||
latest |= seen
|
||||
|
||||
# We want to sort these by depth so we process them and
|
||||
# tell clients about them in order.
|
||||
missing_events.sort(key=lambda x: x.depth)
|
||||
logger.info(
|
||||
"Missing %d events for room %r: %r...",
|
||||
len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
|
||||
)
|
||||
|
||||
for e in missing_events:
|
||||
yield self._handle_new_pdu(
|
||||
origin,
|
||||
e,
|
||||
get_missing=False
|
||||
)
|
||||
missing_events = yield self.get_missing_events(
|
||||
origin,
|
||||
pdu.room_id,
|
||||
earliest_events_ids=list(latest),
|
||||
latest_events=[pdu],
|
||||
limit=10,
|
||||
min_depth=min_depth,
|
||||
)
|
||||
|
||||
have_seen = yield self.store.have_events(
|
||||
[ev for ev, _ in pdu.prev_events]
|
||||
)
|
||||
# We want to sort these by depth so we process them and
|
||||
# tell clients about them in order.
|
||||
missing_events.sort(key=lambda x: x.depth)
|
||||
|
||||
for e in missing_events:
|
||||
yield self._handle_new_pdu(
|
||||
origin,
|
||||
e,
|
||||
get_missing=False
|
||||
)
|
||||
|
||||
have_seen = yield self.store.have_events(
|
||||
[ev for ev, _ in pdu.prev_events]
|
||||
)
|
||||
|
||||
prevs = {e_id for e_id, _ in pdu.prev_events}
|
||||
seen = set(have_seen.keys())
|
||||
if prevs - seen:
|
||||
logger.info(
|
||||
"Still missing %d events for room %r: %r...",
|
||||
len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
|
||||
)
|
||||
fetch_state = True
|
||||
|
||||
if fetch_state:
|
||||
|
||||
@@ -72,5 +72,7 @@ class ReplicationLayer(FederationClient, FederationServer):
|
||||
|
||||
self.hs = hs
|
||||
|
||||
super(ReplicationLayer, self).__init__(hs)
|
||||
|
||||
def __str__(self):
|
||||
return "<ReplicationLayer(%s)>" % self.server_name
|
||||
|
||||
@@ -37,7 +37,7 @@ class TransportLayerServer(JsonResource):
|
||||
self.hs = hs
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
super(TransportLayerServer, self).__init__(hs)
|
||||
super(TransportLayerServer, self).__init__(hs, canonical_json=False)
|
||||
|
||||
self.authenticator = Authenticator(hs)
|
||||
self.ratelimiter = FederationRateLimiter(
|
||||
@@ -388,7 +388,7 @@ class FederationQueryAuthServlet(BaseFederationServlet):
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, origin, content, query, context, event_id):
|
||||
new_content = yield self.handler.on_query_auth_request(
|
||||
origin, content, event_id
|
||||
origin, content, context, event_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, new_content))
|
||||
@@ -528,15 +528,10 @@ class PublicRoomList(BaseFederationServlet):
|
||||
PATH = "/publicRooms"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
def on_GET(self, origin, content, query):
|
||||
data = yield self.room_list_handler.get_local_public_room_list()
|
||||
defer.returnValue((200, data))
|
||||
|
||||
# Avoid doing remote HS authorization checks which are done by default by
|
||||
# BaseFederationServlet.
|
||||
def _wrap(self, code):
|
||||
return code
|
||||
|
||||
|
||||
SERVLET_CLASSES = (
|
||||
FederationSendServlet,
|
||||
|
||||
@@ -20,6 +20,7 @@ from synapse.api.constants import LoginType
|
||||
from synapse.types import UserID
|
||||
from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.config.ldap import LDAPMode
|
||||
|
||||
from twisted.web.client import PartialDownloadError
|
||||
|
||||
@@ -28,6 +29,12 @@ import bcrypt
|
||||
import pymacaroons
|
||||
import simplejson
|
||||
|
||||
try:
|
||||
import ldap3
|
||||
except ImportError:
|
||||
ldap3 = None
|
||||
pass
|
||||
|
||||
import synapse.util.stringutils as stringutils
|
||||
|
||||
|
||||
@@ -50,17 +57,20 @@ class AuthHandler(BaseHandler):
|
||||
self.INVALID_TOKEN_HTTP_STATUS = 401
|
||||
|
||||
self.ldap_enabled = hs.config.ldap_enabled
|
||||
self.ldap_server = hs.config.ldap_server
|
||||
self.ldap_port = hs.config.ldap_port
|
||||
self.ldap_tls = hs.config.ldap_tls
|
||||
self.ldap_search_base = hs.config.ldap_search_base
|
||||
self.ldap_search_property = hs.config.ldap_search_property
|
||||
self.ldap_email_property = hs.config.ldap_email_property
|
||||
self.ldap_full_name_property = hs.config.ldap_full_name_property
|
||||
|
||||
if self.ldap_enabled is True:
|
||||
import ldap
|
||||
logger.info("Import ldap version: %s", ldap.__version__)
|
||||
if self.ldap_enabled:
|
||||
if not ldap3:
|
||||
raise RuntimeError(
|
||||
'Missing ldap3 library. This is required for LDAP Authentication.'
|
||||
)
|
||||
self.ldap_mode = hs.config.ldap_mode
|
||||
self.ldap_uri = hs.config.ldap_uri
|
||||
self.ldap_start_tls = hs.config.ldap_start_tls
|
||||
self.ldap_base = hs.config.ldap_base
|
||||
self.ldap_filter = hs.config.ldap_filter
|
||||
self.ldap_attributes = hs.config.ldap_attributes
|
||||
if self.ldap_mode == LDAPMode.SEARCH:
|
||||
self.ldap_bind_dn = hs.config.ldap_bind_dn
|
||||
self.ldap_bind_password = hs.config.ldap_bind_password
|
||||
|
||||
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
||||
|
||||
@@ -452,40 +462,167 @@ class AuthHandler(BaseHandler):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_ldap_password(self, user_id, password):
|
||||
if not self.ldap_enabled:
|
||||
logger.debug("LDAP not configured")
|
||||
""" Attempt to authenticate a user against an LDAP Server
|
||||
and register an account if none exists.
|
||||
|
||||
Returns:
|
||||
True if authentication against LDAP was successful
|
||||
"""
|
||||
|
||||
if not ldap3 or not self.ldap_enabled:
|
||||
defer.returnValue(False)
|
||||
|
||||
import ldap
|
||||
if self.ldap_mode not in LDAPMode.LIST:
|
||||
raise RuntimeError(
|
||||
'Invalid ldap mode specified: {mode}'.format(
|
||||
mode=self.ldap_mode
|
||||
)
|
||||
)
|
||||
|
||||
logger.info("Authenticating %s with LDAP" % user_id)
|
||||
try:
|
||||
ldap_url = "%s:%s" % (self.ldap_server, self.ldap_port)
|
||||
logger.debug("Connecting LDAP server at %s" % ldap_url)
|
||||
l = ldap.initialize(ldap_url)
|
||||
if self.ldap_tls:
|
||||
logger.debug("Initiating TLS")
|
||||
self._connection.start_tls_s()
|
||||
server = ldap3.Server(self.ldap_uri)
|
||||
logger.debug(
|
||||
"Attempting ldap connection with %s",
|
||||
self.ldap_uri
|
||||
)
|
||||
|
||||
local_name = UserID.from_string(user_id).localpart
|
||||
|
||||
dn = "%s=%s, %s" % (
|
||||
self.ldap_search_property,
|
||||
local_name,
|
||||
self.ldap_search_base)
|
||||
logger.debug("DN for LDAP authentication: %s" % dn)
|
||||
|
||||
l.simple_bind_s(dn.encode('utf-8'), password.encode('utf-8'))
|
||||
|
||||
if not (yield self.does_user_exist(user_id)):
|
||||
handler = self.hs.get_handlers().registration_handler
|
||||
user_id, access_token = (
|
||||
yield handler.register(localpart=local_name)
|
||||
localpart = UserID.from_string(user_id).localpart
|
||||
if self.ldap_mode == LDAPMode.SIMPLE:
|
||||
# bind with the the local users ldap credentials
|
||||
bind_dn = "{prop}={value},{base}".format(
|
||||
prop=self.ldap_attributes['uid'],
|
||||
value=localpart,
|
||||
base=self.ldap_base
|
||||
)
|
||||
conn = ldap3.Connection(server, bind_dn, password)
|
||||
logger.debug(
|
||||
"Established ldap connection in simple mode: %s",
|
||||
conn
|
||||
)
|
||||
|
||||
if self.ldap_start_tls:
|
||||
conn.start_tls()
|
||||
logger.debug(
|
||||
"Upgraded ldap connection in simple mode through StartTLS: %s",
|
||||
conn
|
||||
)
|
||||
|
||||
conn.bind()
|
||||
|
||||
elif self.ldap_mode == LDAPMode.SEARCH:
|
||||
# connect with preconfigured credentials and search for local user
|
||||
conn = ldap3.Connection(
|
||||
server,
|
||||
self.ldap_bind_dn,
|
||||
self.ldap_bind_password
|
||||
)
|
||||
logger.debug(
|
||||
"Established ldap connection in search mode: %s",
|
||||
conn
|
||||
)
|
||||
|
||||
if self.ldap_start_tls:
|
||||
conn.start_tls()
|
||||
logger.debug(
|
||||
"Upgraded ldap connection in search mode through StartTLS: %s",
|
||||
conn
|
||||
)
|
||||
|
||||
conn.bind()
|
||||
|
||||
# find matching dn
|
||||
query = "({prop}={value})".format(
|
||||
prop=self.ldap_attributes['uid'],
|
||||
value=localpart
|
||||
)
|
||||
if self.ldap_filter:
|
||||
query = "(&{query}{filter})".format(
|
||||
query=query,
|
||||
filter=self.ldap_filter
|
||||
)
|
||||
logger.debug("ldap search filter: %s", query)
|
||||
result = conn.search(self.ldap_base, query)
|
||||
|
||||
if result and len(conn.response) == 1:
|
||||
# found exactly one result
|
||||
user_dn = conn.response[0]['dn']
|
||||
logger.debug('ldap search found dn: %s', user_dn)
|
||||
|
||||
# unbind and reconnect, rebind with found dn
|
||||
conn.unbind()
|
||||
conn = ldap3.Connection(
|
||||
server,
|
||||
user_dn,
|
||||
password,
|
||||
auto_bind=True
|
||||
)
|
||||
else:
|
||||
# found 0 or > 1 results, abort!
|
||||
logger.warn(
|
||||
"ldap search returned unexpected (%d!=1) amount of results",
|
||||
len(conn.response)
|
||||
)
|
||||
defer.returnValue(False)
|
||||
|
||||
logger.info(
|
||||
"User authenticated against ldap server: %s",
|
||||
conn
|
||||
)
|
||||
|
||||
# check for existing account, if none exists, create one
|
||||
if not (yield self.does_user_exist(user_id)):
|
||||
# query user metadata for account creation
|
||||
query = "({prop}={value})".format(
|
||||
prop=self.ldap_attributes['uid'],
|
||||
value=localpart
|
||||
)
|
||||
|
||||
if self.ldap_mode == LDAPMode.SEARCH and self.ldap_filter:
|
||||
query = "(&{filter}{user_filter})".format(
|
||||
filter=query,
|
||||
user_filter=self.ldap_filter
|
||||
)
|
||||
logger.debug("ldap registration filter: %s", query)
|
||||
|
||||
result = conn.search(
|
||||
search_base=self.ldap_base,
|
||||
search_filter=query,
|
||||
attributes=[
|
||||
self.ldap_attributes['name'],
|
||||
self.ldap_attributes['mail']
|
||||
]
|
||||
)
|
||||
|
||||
if len(conn.response) == 1:
|
||||
attrs = conn.response[0]['attributes']
|
||||
mail = attrs[self.ldap_attributes['mail']][0]
|
||||
name = attrs[self.ldap_attributes['name']][0]
|
||||
|
||||
# create account
|
||||
registration_handler = self.hs.get_handlers().registration_handler
|
||||
user_id, access_token = (
|
||||
yield registration_handler.register(localpart=localpart)
|
||||
)
|
||||
|
||||
# TODO: bind email, set displayname with data from ldap directory
|
||||
|
||||
logger.info(
|
||||
"ldap registration successful: %d: %s (%s, %)",
|
||||
user_id,
|
||||
localpart,
|
||||
name,
|
||||
mail
|
||||
)
|
||||
else:
|
||||
logger.warn(
|
||||
"ldap registration failed: unexpected (%d!=1) amount of results",
|
||||
len(result)
|
||||
)
|
||||
defer.returnValue(False)
|
||||
|
||||
defer.returnValue(True)
|
||||
except ldap.LDAPError, e:
|
||||
logger.warn("LDAP error: %s", e)
|
||||
except ldap3.core.exceptions.LDAPException as e:
|
||||
logger.warn("Error during ldap authentication: %s", e)
|
||||
defer.returnValue(False)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@@ -613,7 +750,8 @@ class AuthHandler(BaseHandler):
|
||||
Returns:
|
||||
Hashed password (str).
|
||||
"""
|
||||
return bcrypt.hashpw(password, bcrypt.gensalt(self.bcrypt_rounds))
|
||||
return bcrypt.hashpw(password + self.hs.config.password_pepper,
|
||||
bcrypt.gensalt(self.bcrypt_rounds))
|
||||
|
||||
def validate_hash(self, password, stored_hash):
|
||||
"""Validates that self.hash(password) == stored_hash.
|
||||
@@ -626,6 +764,7 @@ class AuthHandler(BaseHandler):
|
||||
Whether self.hash(password) == stored_hash (bool).
|
||||
"""
|
||||
if stored_hash:
|
||||
return bcrypt.hashpw(password, stored_hash) == stored_hash
|
||||
return bcrypt.hashpw(password + self.hs.config.password_pepper,
|
||||
stored_hash.encode('utf-8')) == stored_hash
|
||||
else:
|
||||
return False
|
||||
|
||||
@@ -345,19 +345,21 @@ class FederationHandler(BaseHandler):
|
||||
)
|
||||
|
||||
missing_auth = required_auth - set(auth_events)
|
||||
results = yield defer.gatherResults(
|
||||
[
|
||||
self.replication_layer.get_pdu(
|
||||
[dest],
|
||||
event_id,
|
||||
outlier=True,
|
||||
timeout=10000,
|
||||
)
|
||||
for event_id in missing_auth
|
||||
],
|
||||
consumeErrors=True
|
||||
).addErrback(unwrapFirstError)
|
||||
auth_events.update({a.event_id: a for a in results})
|
||||
if missing_auth:
|
||||
logger.info("Missing auth for backfill: %r", missing_auth)
|
||||
results = yield defer.gatherResults(
|
||||
[
|
||||
self.replication_layer.get_pdu(
|
||||
[dest],
|
||||
event_id,
|
||||
outlier=True,
|
||||
timeout=10000,
|
||||
)
|
||||
for event_id in missing_auth
|
||||
],
|
||||
consumeErrors=True
|
||||
).addErrback(unwrapFirstError)
|
||||
auth_events.update({a.event_id: a for a in results})
|
||||
|
||||
ev_infos = []
|
||||
for a in auth_events.values():
|
||||
@@ -399,7 +401,7 @@ class FederationHandler(BaseHandler):
|
||||
# previous to work out the state.
|
||||
# TODO: We can probably do something more clever here.
|
||||
yield self._handle_new_event(
|
||||
dest, event
|
||||
dest, event, backfilled=True,
|
||||
)
|
||||
|
||||
defer.returnValue(events)
|
||||
@@ -1016,13 +1018,16 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
res = results.values()
|
||||
for event in res:
|
||||
event.signatures.update(
|
||||
compute_event_signature(
|
||||
event,
|
||||
self.hs.hostname,
|
||||
self.hs.config.signing_key[0]
|
||||
# We sign these again because there was a bug where we
|
||||
# incorrectly signed things the first time round
|
||||
if self.hs.is_mine_id(event.event_id):
|
||||
event.signatures.update(
|
||||
compute_event_signature(
|
||||
event,
|
||||
self.hs.hostname,
|
||||
self.hs.config.signing_key[0]
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
defer.returnValue(res)
|
||||
else:
|
||||
@@ -1408,7 +1413,7 @@ class FederationHandler(BaseHandler):
|
||||
local_view = dict(auth_events)
|
||||
remote_view = dict(auth_events)
|
||||
remote_view.update({
|
||||
(d.type, d.state_key): d for d in different_events
|
||||
(d.type, d.state_key): d for d in different_events if d
|
||||
})
|
||||
|
||||
new_state, prev_state = self.state_handler.resolve_events(
|
||||
|
||||
@@ -21,7 +21,7 @@ from synapse.api.errors import (
|
||||
)
|
||||
from ._base import BaseHandler
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.errors import SynapseError, Codes
|
||||
|
||||
import json
|
||||
import logging
|
||||
@@ -41,6 +41,20 @@ class IdentityHandler(BaseHandler):
|
||||
hs.config.use_insecure_ssl_client_just_for_testing_do_not_use
|
||||
)
|
||||
|
||||
def _should_trust_id_server(self, id_server):
|
||||
if id_server not in self.trusted_id_servers:
|
||||
if self.trust_any_id_server_just_for_testing_do_not_use:
|
||||
logger.warn(
|
||||
"Trusting untrustworthy ID server %r even though it isn't"
|
||||
" in the trusted id list for testing because"
|
||||
" 'use_insecure_ssl_client_just_for_testing_do_not_use'"
|
||||
" is set in the config",
|
||||
id_server,
|
||||
)
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def threepid_from_creds(self, creds):
|
||||
yield run_on_reactor()
|
||||
@@ -59,19 +73,12 @@ class IdentityHandler(BaseHandler):
|
||||
else:
|
||||
raise SynapseError(400, "No client_secret in creds")
|
||||
|
||||
if id_server not in self.trusted_id_servers:
|
||||
if self.trust_any_id_server_just_for_testing_do_not_use:
|
||||
logger.warn(
|
||||
"Trusting untrustworthy ID server %r even though it isn't"
|
||||
" in the trusted id list for testing because"
|
||||
" 'use_insecure_ssl_client_just_for_testing_do_not_use'"
|
||||
" is set in the config",
|
||||
id_server,
|
||||
)
|
||||
else:
|
||||
logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
|
||||
'credentials', id_server)
|
||||
defer.returnValue(None)
|
||||
if not self._should_trust_id_server(id_server):
|
||||
logger.warn(
|
||||
'%s is not a trusted ID server: rejecting 3pid ' +
|
||||
'credentials', id_server
|
||||
)
|
||||
defer.returnValue(None)
|
||||
|
||||
data = {}
|
||||
try:
|
||||
@@ -129,6 +136,12 @@ class IdentityHandler(BaseHandler):
|
||||
def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs):
|
||||
yield run_on_reactor()
|
||||
|
||||
if not self._should_trust_id_server(id_server):
|
||||
raise SynapseError(
|
||||
400, "Untrusted ID server '%s'" % id_server,
|
||||
Codes.SERVER_NOT_TRUSTED
|
||||
)
|
||||
|
||||
params = {
|
||||
'email': email,
|
||||
'client_secret': client_secret,
|
||||
|
||||
@@ -26,7 +26,7 @@ from synapse.types import (
|
||||
UserID, RoomAlias, RoomStreamToken, StreamToken, get_domain_from_id
|
||||
)
|
||||
from synapse.util import unwrapFirstError
|
||||
from synapse.util.async import concurrently_execute, run_on_reactor
|
||||
from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLock
|
||||
from synapse.util.caches.snapshot_cache import SnapshotCache
|
||||
from synapse.util.logcontext import preserve_fn
|
||||
from synapse.visibility import filter_events_for_client
|
||||
@@ -50,6 +50,20 @@ class MessageHandler(BaseHandler):
|
||||
self.validator = EventValidator()
|
||||
self.snapshot_cache = SnapshotCache()
|
||||
|
||||
self.pagination_lock = ReadWriteLock()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def purge_history(self, room_id, event_id):
|
||||
event = yield self.store.get_event(event_id)
|
||||
|
||||
if event.room_id != room_id:
|
||||
raise SynapseError(400, "Event is for wrong room.")
|
||||
|
||||
depth = event.depth
|
||||
|
||||
with (yield self.pagination_lock.write(room_id)):
|
||||
yield self.store.delete_old_state(room_id, depth)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_messages(self, requester, room_id=None, pagin_config=None,
|
||||
as_client_event=True):
|
||||
@@ -85,17 +99,113 @@ class MessageHandler(BaseHandler):
|
||||
|
||||
source_config = pagin_config.get_source_config("room")
|
||||
|
||||
with (yield self.pagination_lock.read(room_id)):
|
||||
membership, member_event_id = yield self._check_in_room_or_world_readable(
|
||||
room_id, user_id
|
||||
)
|
||||
|
||||
if source_config.direction == 'b':
|
||||
# if we're going backwards, we might need to backfill. This
|
||||
# requires that we have a topo token.
|
||||
if room_token.topological:
|
||||
max_topo = room_token.topological
|
||||
else:
|
||||
max_topo = yield self.store.get_max_topological_token(
|
||||
room_id, room_token.stream
|
||||
)
|
||||
|
||||
if membership == Membership.LEAVE:
|
||||
# If they have left the room then clamp the token to be before
|
||||
# they left the room, to save the effort of loading from the
|
||||
# database.
|
||||
leave_token = yield self.store.get_topological_token_for_event(
|
||||
member_event_id
|
||||
)
|
||||
leave_token = RoomStreamToken.parse(leave_token)
|
||||
if leave_token.topological < max_topo:
|
||||
source_config.from_key = str(leave_token)
|
||||
|
||||
yield self.hs.get_handlers().federation_handler.maybe_backfill(
|
||||
room_id, max_topo
|
||||
)
|
||||
|
||||
events, next_key = yield data_source.get_pagination_rows(
|
||||
requester.user, source_config, room_id
|
||||
)
|
||||
|
||||
next_token = pagin_config.from_token.copy_and_replace(
|
||||
"room_key", next_key
|
||||
)
|
||||
|
||||
if not events:
|
||||
defer.returnValue({
|
||||
"chunk": [],
|
||||
"start": pagin_config.from_token.to_string(),
|
||||
"end": next_token.to_string(),
|
||||
})
|
||||
|
||||
events = yield filter_events_for_client(
|
||||
self.store,
|
||||
user_id,
|
||||
events,
|
||||
is_peeking=(member_event_id is None),
|
||||
)
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
|
||||
chunk = {
|
||||
"chunk": [
|
||||
serialize_event(e, time_now, as_client_event)
|
||||
for e in events
|
||||
],
|
||||
"start": pagin_config.from_token.to_string(),
|
||||
"end": next_token.to_string(),
|
||||
}
|
||||
|
||||
defer.returnValue(chunk)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_files(self, requester, room_id, pagin_config):
|
||||
"""Get files in a room.
|
||||
|
||||
Args:
|
||||
requester (Requester): The user requesting files.
|
||||
room_id (str): The room they want files from.
|
||||
pagin_config (synapse.api.streams.PaginationConfig): The pagination
|
||||
config rules to apply, if any.
|
||||
as_client_event (bool): True to get events in client-server format.
|
||||
Returns:
|
||||
dict: Pagination API results
|
||||
"""
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
if pagin_config.from_token:
|
||||
room_token = pagin_config.from_token.room_key
|
||||
else:
|
||||
pagin_config.from_token = (
|
||||
yield self.hs.get_event_sources().get_current_token(
|
||||
direction='b'
|
||||
)
|
||||
)
|
||||
room_token = pagin_config.from_token.room_key
|
||||
|
||||
room_token = RoomStreamToken.parse(room_token)
|
||||
|
||||
pagin_config.from_token = pagin_config.from_token.copy_and_replace(
|
||||
"room_key", str(room_token)
|
||||
)
|
||||
|
||||
source_config = pagin_config.get_source_config("room")
|
||||
|
||||
membership, member_event_id = yield self._check_in_room_or_world_readable(
|
||||
room_id, user_id
|
||||
)
|
||||
|
||||
if source_config.direction == 'b':
|
||||
# if we're going backwards, we might need to backfill. This
|
||||
# requires that we have a topo token.
|
||||
if room_token.topological:
|
||||
max_topo = room_token.topological
|
||||
else:
|
||||
max_topo = yield self.store.get_max_topological_token_for_stream_and_room(
|
||||
max_topo = yield self.store.get_max_topological_token(
|
||||
room_id, room_token.stream
|
||||
)
|
||||
|
||||
@@ -110,12 +220,12 @@ class MessageHandler(BaseHandler):
|
||||
if leave_token.topological < max_topo:
|
||||
source_config.from_key = str(leave_token)
|
||||
|
||||
yield self.hs.get_handlers().federation_handler.maybe_backfill(
|
||||
room_id, max_topo
|
||||
)
|
||||
|
||||
events, next_key = yield data_source.get_pagination_rows(
|
||||
requester.user, source_config, room_id
|
||||
events, next_key = yield self.store.paginate_room_file_events(
|
||||
room_id,
|
||||
from_key=source_config.from_key,
|
||||
to_key=source_config.to_key,
|
||||
direction=source_config.direction,
|
||||
limit=source_config.limit,
|
||||
)
|
||||
|
||||
next_token = pagin_config.from_token.copy_and_replace(
|
||||
@@ -140,7 +250,7 @@ class MessageHandler(BaseHandler):
|
||||
|
||||
chunk = {
|
||||
"chunk": [
|
||||
serialize_event(e, time_now, as_client_event)
|
||||
serialize_event(e, time_now)
|
||||
for e in events
|
||||
],
|
||||
"start": pagin_config.from_token.to_string(),
|
||||
|
||||
@@ -36,13 +36,6 @@ class ProfileHandler(BaseHandler):
|
||||
"profile", self.on_profile_query
|
||||
)
|
||||
|
||||
distributor = hs.get_distributor()
|
||||
|
||||
distributor.observe("registered_user", self.registered_user)
|
||||
|
||||
def registered_user(self, user):
|
||||
return self.store.create_profile(user.localpart)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_displayname(self, target_user):
|
||||
if self.hs.is_mine(target_user):
|
||||
|
||||
@@ -23,7 +23,6 @@ from synapse.api.errors import (
|
||||
from ._base import BaseHandler
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.http.client import CaptchaServerHttpClient
|
||||
from synapse.util.distributor import registered_user
|
||||
|
||||
import logging
|
||||
import urllib
|
||||
@@ -37,8 +36,6 @@ class RegistrationHandler(BaseHandler):
|
||||
super(RegistrationHandler, self).__init__(hs)
|
||||
|
||||
self.auth = hs.get_auth()
|
||||
self.distributor = hs.get_distributor()
|
||||
self.distributor.declare("registered_user")
|
||||
self.captcha_client = CaptchaServerHttpClient(hs)
|
||||
|
||||
self._next_generated_user_id = None
|
||||
@@ -93,7 +90,8 @@ class RegistrationHandler(BaseHandler):
|
||||
password=None,
|
||||
generate_token=True,
|
||||
guest_access_token=None,
|
||||
make_guest=False
|
||||
make_guest=False,
|
||||
admin=False,
|
||||
):
|
||||
"""Registers a new client on the server.
|
||||
|
||||
@@ -140,9 +138,12 @@ class RegistrationHandler(BaseHandler):
|
||||
password_hash=password_hash,
|
||||
was_guest=was_guest,
|
||||
make_guest=make_guest,
|
||||
create_profile_with_localpart=(
|
||||
# If the user was a guest then they already have a profile
|
||||
None if was_guest else user.localpart
|
||||
),
|
||||
admin=admin,
|
||||
)
|
||||
|
||||
yield registered_user(self.distributor, user)
|
||||
else:
|
||||
# autogen a sequential user ID
|
||||
attempts = 0
|
||||
@@ -160,7 +161,8 @@ class RegistrationHandler(BaseHandler):
|
||||
user_id=user_id,
|
||||
token=token,
|
||||
password_hash=password_hash,
|
||||
make_guest=make_guest
|
||||
make_guest=make_guest,
|
||||
create_profile_with_localpart=user.localpart,
|
||||
)
|
||||
except SynapseError:
|
||||
# if user id is taken, just generate another
|
||||
@@ -168,7 +170,6 @@ class RegistrationHandler(BaseHandler):
|
||||
user_id = None
|
||||
token = None
|
||||
attempts += 1
|
||||
yield registered_user(self.distributor, user)
|
||||
|
||||
# We used to generate default identicons here, but nowadays
|
||||
# we want clients to generate their own as part of their branding
|
||||
@@ -201,8 +202,8 @@ class RegistrationHandler(BaseHandler):
|
||||
token=token,
|
||||
password_hash="",
|
||||
appservice_id=service_id,
|
||||
create_profile_with_localpart=user.localpart,
|
||||
)
|
||||
yield registered_user(self.distributor, user)
|
||||
defer.returnValue((user_id, token))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@@ -248,9 +249,9 @@ class RegistrationHandler(BaseHandler):
|
||||
yield self.store.register(
|
||||
user_id=user_id,
|
||||
token=token,
|
||||
password_hash=None
|
||||
password_hash=None,
|
||||
create_profile_with_localpart=user.localpart,
|
||||
)
|
||||
yield registered_user(self.distributor, user)
|
||||
except Exception as e:
|
||||
yield self.store.add_access_token_to_user(user_id, token)
|
||||
# Ignore Registration errors
|
||||
@@ -359,7 +360,8 @@ class RegistrationHandler(BaseHandler):
|
||||
defer.returnValue(data)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_or_create_user(self, localpart, displayname, duration_seconds):
|
||||
def get_or_create_user(self, localpart, displayname, duration_seconds,
|
||||
password_hash=None):
|
||||
"""Creates a new user if the user does not exist,
|
||||
else revokes all previous access tokens and generates a new one.
|
||||
|
||||
@@ -388,17 +390,16 @@ class RegistrationHandler(BaseHandler):
|
||||
|
||||
user = UserID(localpart, self.hs.hostname)
|
||||
user_id = user.to_string()
|
||||
auth_handler = self.hs.get_handlers().auth_handler
|
||||
token = auth_handler.generate_short_term_login_token(user_id, duration_seconds)
|
||||
token = self.auth_handler().generate_short_term_login_token(
|
||||
user_id, duration_seconds)
|
||||
|
||||
if need_register:
|
||||
yield self.store.register(
|
||||
user_id=user_id,
|
||||
token=token,
|
||||
password_hash=None
|
||||
password_hash=password_hash,
|
||||
create_profile_with_localpart=user.localpart,
|
||||
)
|
||||
|
||||
yield registered_user(self.distributor, user)
|
||||
else:
|
||||
yield self.store.user_delete_access_tokens(user_id=user_id)
|
||||
yield self.store.add_access_token_to_user(user_id=user_id, token=token)
|
||||
|
||||
@@ -20,7 +20,7 @@ from ._base import BaseHandler
|
||||
|
||||
from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken
|
||||
from synapse.api.constants import (
|
||||
EventTypes, JoinRules, RoomCreationPreset,
|
||||
EventTypes, JoinRules, RoomCreationPreset, Membership,
|
||||
)
|
||||
from synapse.api.errors import AuthError, StoreError, SynapseError
|
||||
from synapse.util import stringutils
|
||||
@@ -367,14 +367,10 @@ class RoomListHandler(BaseHandler):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def handle_room(room_id):
|
||||
# We pull each bit of state out indvidually to avoid pulling the
|
||||
# full state into memory. Due to how the caching works this should
|
||||
# be fairly quick, even if not originally in the cache.
|
||||
def get_state(etype, state_key):
|
||||
return self.state_handler.get_current_state(room_id, etype, state_key)
|
||||
current_state = yield self.state_handler.get_current_state(room_id)
|
||||
|
||||
# Double check that this is actually a public room.
|
||||
join_rules_event = yield get_state(EventTypes.JoinRules, "")
|
||||
join_rules_event = current_state.get((EventTypes.JoinRules, ""))
|
||||
if join_rules_event:
|
||||
join_rule = join_rules_event.content.get("join_rule", None)
|
||||
if join_rule and join_rule != JoinRules.PUBLIC:
|
||||
@@ -382,47 +378,51 @@ class RoomListHandler(BaseHandler):
|
||||
|
||||
result = {"room_id": room_id}
|
||||
|
||||
joined_users = yield self.store.get_users_in_room(room_id)
|
||||
if len(joined_users) == 0:
|
||||
num_joined_users = len([
|
||||
1 for _, event in current_state.items()
|
||||
if event.type == EventTypes.Member
|
||||
and event.membership == Membership.JOIN
|
||||
])
|
||||
if num_joined_users == 0:
|
||||
return
|
||||
|
||||
result["num_joined_members"] = len(joined_users)
|
||||
result["num_joined_members"] = num_joined_users
|
||||
|
||||
aliases = yield self.store.get_aliases_for_room(room_id)
|
||||
if aliases:
|
||||
result["aliases"] = aliases
|
||||
|
||||
name_event = yield get_state(EventTypes.Name, "")
|
||||
name_event = yield current_state.get((EventTypes.Name, ""))
|
||||
if name_event:
|
||||
name = name_event.content.get("name", None)
|
||||
if name:
|
||||
result["name"] = name
|
||||
|
||||
topic_event = yield get_state(EventTypes.Topic, "")
|
||||
topic_event = current_state.get((EventTypes.Topic, ""))
|
||||
if topic_event:
|
||||
topic = topic_event.content.get("topic", None)
|
||||
if topic:
|
||||
result["topic"] = topic
|
||||
|
||||
canonical_event = yield get_state(EventTypes.CanonicalAlias, "")
|
||||
canonical_event = current_state.get((EventTypes.CanonicalAlias, ""))
|
||||
if canonical_event:
|
||||
canonical_alias = canonical_event.content.get("alias", None)
|
||||
if canonical_alias:
|
||||
result["canonical_alias"] = canonical_alias
|
||||
|
||||
visibility_event = yield get_state(EventTypes.RoomHistoryVisibility, "")
|
||||
visibility_event = current_state.get((EventTypes.RoomHistoryVisibility, ""))
|
||||
visibility = None
|
||||
if visibility_event:
|
||||
visibility = visibility_event.content.get("history_visibility", None)
|
||||
result["world_readable"] = visibility == "world_readable"
|
||||
|
||||
guest_event = yield get_state(EventTypes.GuestAccess, "")
|
||||
guest_event = current_state.get((EventTypes.GuestAccess, ""))
|
||||
guest = None
|
||||
if guest_event:
|
||||
guest = guest_event.content.get("guest_access", None)
|
||||
result["guest_can_join"] = guest == "can_join"
|
||||
|
||||
avatar_event = yield get_state("m.room.avatar", "")
|
||||
avatar_event = current_state.get(("m.room.avatar", ""))
|
||||
if avatar_event:
|
||||
avatar_url = avatar_event.content.get("url", None)
|
||||
if avatar_url:
|
||||
|
||||
@@ -221,6 +221,9 @@ class TypingHandler(object):
|
||||
|
||||
def get_all_typing_updates(self, last_id, current_id):
|
||||
# TODO: Work out a way to do this without scanning the entire state.
|
||||
if last_id == current_id:
|
||||
return []
|
||||
|
||||
rows = []
|
||||
for room_id, serial in self._room_serials.items():
|
||||
if last_id < serial and serial <= current_id:
|
||||
|
||||
@@ -24,12 +24,13 @@ from synapse.http.endpoint import SpiderEndpoint
|
||||
|
||||
from canonicaljson import encode_canonical_json
|
||||
|
||||
from twisted.internet import defer, reactor, ssl, protocol
|
||||
from twisted.internet import defer, reactor, ssl, protocol, task
|
||||
from twisted.internet.endpoints import SSL4ClientEndpoint, TCP4ClientEndpoint
|
||||
from twisted.web.client import (
|
||||
BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent,
|
||||
readBody, FileBodyProducer, PartialDownloadError,
|
||||
readBody, PartialDownloadError,
|
||||
)
|
||||
from twisted.web.client import FileBodyProducer as TwistedFileBodyProducer
|
||||
from twisted.web.http import PotentialDataLoss
|
||||
from twisted.web.http_headers import Headers
|
||||
from twisted.web._newclient import ResponseDone
|
||||
@@ -468,3 +469,26 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory):
|
||||
|
||||
def creatorForNetloc(self, hostname, port):
|
||||
return self
|
||||
|
||||
|
||||
class FileBodyProducer(TwistedFileBodyProducer):
|
||||
"""Workaround for https://twistedmatrix.com/trac/ticket/8473
|
||||
|
||||
We override the pauseProducing and resumeProducing methods in twisted's
|
||||
FileBodyProducer so that they do not raise exceptions if the task has
|
||||
already completed.
|
||||
"""
|
||||
|
||||
def pauseProducing(self):
|
||||
try:
|
||||
super(FileBodyProducer, self).pauseProducing()
|
||||
except task.TaskDone:
|
||||
# task has already completed
|
||||
pass
|
||||
|
||||
def resumeProducing(self):
|
||||
try:
|
||||
super(FileBodyProducer, self).resumeProducing()
|
||||
except task.NotPaused:
|
||||
# task was not paused (probably because it had already completed)
|
||||
pass
|
||||
|
||||
@@ -38,6 +38,7 @@ class HttpPusher(object):
|
||||
self.hs = hs
|
||||
self.store = self.hs.get_datastore()
|
||||
self.clock = self.hs.get_clock()
|
||||
self.state_handler = self.hs.get_state_handler()
|
||||
self.user_id = pusherdict['user_name']
|
||||
self.app_id = pusherdict['app_id']
|
||||
self.app_display_name = pusherdict['app_display_name']
|
||||
@@ -237,7 +238,9 @@ class HttpPusher(object):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _build_notification_dict(self, event, tweaks, badge):
|
||||
ctx = yield push_tools.get_context_for_event(self.hs.get_datastore(), event)
|
||||
ctx = yield push_tools.get_context_for_event(
|
||||
self.state_handler, event, self.user_id
|
||||
)
|
||||
|
||||
d = {
|
||||
'notification': {
|
||||
@@ -269,8 +272,8 @@ class HttpPusher(object):
|
||||
if 'content' in event:
|
||||
d['notification']['content'] = event.content
|
||||
|
||||
if len(ctx['aliases']):
|
||||
d['notification']['room_alias'] = ctx['aliases'][0]
|
||||
# We no longer send aliases separately, instead, we send the human
|
||||
# readable name of the room, which may be an alias.
|
||||
if 'sender_display_name' in ctx and len(ctx['sender_display_name']) > 0:
|
||||
d['notification']['sender_display_name'] = ctx['sender_display_name']
|
||||
if 'name' in ctx and len(ctx['name']) > 0:
|
||||
|
||||
@@ -273,16 +273,16 @@ class Mailer(object):
|
||||
|
||||
sender_state_event = room_state[("m.room.member", event.sender)]
|
||||
sender_name = name_from_member_event(sender_state_event)
|
||||
sender_avatar_url = None
|
||||
if "avatar_url" in sender_state_event.content:
|
||||
sender_avatar_url = sender_state_event.content["avatar_url"]
|
||||
sender_avatar_url = sender_state_event.content.get("avatar_url")
|
||||
|
||||
# 'hash' for deterministically picking default images: use
|
||||
# sender_hash % the number of default images to choose from
|
||||
sender_hash = string_ordinal_total(event.sender)
|
||||
|
||||
msgtype = event.content.get("msgtype")
|
||||
|
||||
ret = {
|
||||
"msgtype": event.content["msgtype"],
|
||||
"msgtype": msgtype,
|
||||
"is_historical": event.event_id != notif['event_id'],
|
||||
"id": event.event_id,
|
||||
"ts": event.origin_server_ts,
|
||||
@@ -291,9 +291,9 @@ class Mailer(object):
|
||||
"sender_hash": sender_hash,
|
||||
}
|
||||
|
||||
if event.content["msgtype"] == "m.text":
|
||||
if msgtype == "m.text":
|
||||
self.add_text_message_vars(ret, event)
|
||||
elif event.content["msgtype"] == "m.image":
|
||||
elif msgtype == "m.image":
|
||||
self.add_image_message_vars(ret, event)
|
||||
|
||||
if "body" in event.content:
|
||||
@@ -302,16 +302,17 @@ class Mailer(object):
|
||||
return ret
|
||||
|
||||
def add_text_message_vars(self, messagevars, event):
|
||||
if "format" in event.content:
|
||||
msgformat = event.content["format"]
|
||||
else:
|
||||
msgformat = None
|
||||
msgformat = event.content.get("format")
|
||||
|
||||
messagevars["format"] = msgformat
|
||||
|
||||
if msgformat == "org.matrix.custom.html":
|
||||
messagevars["body_text_html"] = safe_markup(event.content["formatted_body"])
|
||||
else:
|
||||
messagevars["body_text_html"] = safe_text(event.content["body"])
|
||||
formatted_body = event.content.get("formatted_body")
|
||||
body = event.content.get("body")
|
||||
|
||||
if msgformat == "org.matrix.custom.html" and formatted_body:
|
||||
messagevars["body_text_html"] = safe_markup(formatted_body)
|
||||
elif body:
|
||||
messagevars["body_text_html"] = safe_text(body)
|
||||
|
||||
return messagevars
|
||||
|
||||
|
||||
@@ -14,6 +14,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
from synapse.util.presentable_names import (
|
||||
calculate_room_name, name_from_member_event
|
||||
)
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@@ -45,24 +48,21 @@ def get_badge_count(store, user_id):
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_context_for_event(store, ev):
|
||||
name_aliases = yield store.get_room_name_and_aliases(
|
||||
ev.room_id
|
||||
)
|
||||
def get_context_for_event(state_handler, ev, user_id):
|
||||
ctx = {}
|
||||
|
||||
ctx = {'aliases': name_aliases[1]}
|
||||
if name_aliases[0] is not None:
|
||||
ctx['name'] = name_aliases[0]
|
||||
room_state = yield state_handler.get_current_state(ev.room_id)
|
||||
|
||||
their_member_events_for_room = yield store.get_current_state(
|
||||
room_id=ev.room_id,
|
||||
event_type='m.room.member',
|
||||
state_key=ev.user_id
|
||||
# we no longer bother setting room_alias, and make room_name the
|
||||
# human-readable name instead, be that m.room.namer, an alias or
|
||||
# a list of people in the room
|
||||
name = calculate_room_name(
|
||||
room_state, user_id, fallback_to_single_member=False
|
||||
)
|
||||
for mev in their_member_events_for_room:
|
||||
if mev.content['membership'] == 'join' and 'displayname' in mev.content:
|
||||
dn = mev.content['displayname']
|
||||
if dn is not None:
|
||||
ctx['sender_display_name'] = dn
|
||||
if name:
|
||||
ctx['name'] = name
|
||||
|
||||
sender_state_event = room_state[("m.room.member", ev.sender)]
|
||||
ctx['sender_display_name'] = name_from_member_event(sender_state_event)
|
||||
|
||||
defer.returnValue(ctx)
|
||||
|
||||
@@ -48,6 +48,9 @@ CONDITIONAL_REQUIREMENTS = {
|
||||
"Jinja2>=2.8": ["Jinja2>=2.8"],
|
||||
"bleach>=1.4.2": ["bleach>=1.4.2"],
|
||||
},
|
||||
"ldap": {
|
||||
"ldap3>=1.0": ["ldap3>=1.0"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -18,7 +18,6 @@ from ._slaved_id_tracker import SlavedIdTracker
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.events import FrozenEvent
|
||||
from synapse.storage import DataStore
|
||||
from synapse.storage.room import RoomStore
|
||||
from synapse.storage.roommember import RoomMemberStore
|
||||
from synapse.storage.event_federation import EventFederationStore
|
||||
from synapse.storage.event_push_actions import EventPushActionsStore
|
||||
@@ -64,7 +63,6 @@ class SlavedEventStore(BaseSlavedStore):
|
||||
|
||||
# Cached functions can't be accessed through a class instance so we need
|
||||
# to reach inside the __dict__ to extract them.
|
||||
get_room_name_and_aliases = RoomStore.__dict__["get_room_name_and_aliases"]
|
||||
get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"]
|
||||
get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"]
|
||||
get_latest_event_ids_in_room = EventFederationStore.__dict__[
|
||||
@@ -202,7 +200,6 @@ class SlavedEventStore(BaseSlavedStore):
|
||||
self.get_rooms_for_user.invalidate_all()
|
||||
self.get_users_in_room.invalidate((event.room_id,))
|
||||
# self.get_joined_hosts_for_room.invalidate((event.room_id,))
|
||||
self.get_room_name_and_aliases.invalidate((event.room_id,))
|
||||
|
||||
self._invalidate_get_event_cache(event.event_id)
|
||||
|
||||
@@ -246,9 +243,3 @@ class SlavedEventStore(BaseSlavedStore):
|
||||
self._get_current_state_for_key.invalidate((
|
||||
event.room_id, event.type, event.state_key
|
||||
))
|
||||
|
||||
if event.type in [EventTypes.Name, EventTypes.Aliases]:
|
||||
self.get_room_name_and_aliases.invalidate(
|
||||
(event.room_id,)
|
||||
)
|
||||
pass
|
||||
|
||||
@@ -46,5 +46,82 @@ class WhoisRestServlet(ClientV1RestServlet):
|
||||
defer.returnValue((200, ret))
|
||||
|
||||
|
||||
class PurgeMediaCacheRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/admin/purge_media_cache")
|
||||
|
||||
def __init__(self, hs):
|
||||
self.media_repository = hs.get_media_repository()
|
||||
super(PurgeMediaCacheRestServlet, self).__init__(hs)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
is_admin = yield self.auth.is_server_admin(requester.user)
|
||||
|
||||
if not is_admin:
|
||||
raise AuthError(403, "You are not a server admin")
|
||||
|
||||
before_ts = request.args.get("before_ts", None)
|
||||
if not before_ts:
|
||||
raise SynapseError(400, "Missing 'before_ts' arg")
|
||||
|
||||
logger.info("before_ts: %r", before_ts[0])
|
||||
|
||||
try:
|
||||
before_ts = int(before_ts[0])
|
||||
except Exception:
|
||||
raise SynapseError(400, "Invalid 'before_ts' arg")
|
||||
|
||||
ret = yield self.media_repository.delete_old_remote_media(before_ts)
|
||||
|
||||
defer.returnValue((200, ret))
|
||||
|
||||
|
||||
class PurgeHistoryRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns(
|
||||
"/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, room_id, event_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
is_admin = yield self.auth.is_server_admin(requester.user)
|
||||
|
||||
if not is_admin:
|
||||
raise AuthError(403, "You are not a server admin")
|
||||
|
||||
yield self.handlers.message_handler.purge_history(room_id, event_id)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
class DeactivateAccountRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)")
|
||||
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
super(DeactivateAccountRestServlet, self).__init__(hs)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, target_user_id):
|
||||
UserID.from_string(target_user_id)
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
is_admin = yield self.auth.is_server_admin(requester.user)
|
||||
|
||||
if not is_admin:
|
||||
raise AuthError(403, "You are not a server admin")
|
||||
|
||||
# FIXME: Theoretically there is a race here wherein user resets password
|
||||
# using threepid.
|
||||
yield self.store.user_delete_access_tokens(target_user_id)
|
||||
yield self.store.user_delete_threepids(target_user_id)
|
||||
yield self.store.user_set_password_hash(target_user_id, None)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
WhoisRestServlet(hs).register(http_server)
|
||||
PurgeMediaCacheRestServlet(hs).register(http_server)
|
||||
DeactivateAccountRestServlet(hs).register(http_server)
|
||||
PurgeHistoryRestServlet(hs).register(http_server)
|
||||
|
||||
@@ -45,30 +45,27 @@ class EventStreamRestServlet(ClientV1RestServlet):
|
||||
raise SynapseError(400, "Guest users must specify room_id param")
|
||||
if "room_id" in request.args:
|
||||
room_id = request.args["room_id"][0]
|
||||
try:
|
||||
handler = self.handlers.event_stream_handler
|
||||
pagin_config = PaginationConfig.from_request(request)
|
||||
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
|
||||
if "timeout" in request.args:
|
||||
try:
|
||||
timeout = int(request.args["timeout"][0])
|
||||
except ValueError:
|
||||
raise SynapseError(400, "timeout must be in milliseconds.")
|
||||
|
||||
as_client_event = "raw" not in request.args
|
||||
handler = self.handlers.event_stream_handler
|
||||
pagin_config = PaginationConfig.from_request(request)
|
||||
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
|
||||
if "timeout" in request.args:
|
||||
try:
|
||||
timeout = int(request.args["timeout"][0])
|
||||
except ValueError:
|
||||
raise SynapseError(400, "timeout must be in milliseconds.")
|
||||
|
||||
chunk = yield handler.get_stream(
|
||||
requester.user.to_string(),
|
||||
pagin_config,
|
||||
timeout=timeout,
|
||||
as_client_event=as_client_event,
|
||||
affect_presence=(not is_guest),
|
||||
room_id=room_id,
|
||||
is_guest=is_guest,
|
||||
)
|
||||
except:
|
||||
logger.exception("Event stream failed")
|
||||
raise
|
||||
as_client_event = "raw" not in request.args
|
||||
|
||||
chunk = yield handler.get_stream(
|
||||
requester.user.to_string(),
|
||||
pagin_config,
|
||||
timeout=timeout,
|
||||
as_client_event=as_client_event,
|
||||
affect_presence=(not is_guest),
|
||||
room_id=room_id,
|
||||
is_guest=is_guest,
|
||||
)
|
||||
|
||||
defer.returnValue((200, chunk))
|
||||
|
||||
|
||||
@@ -324,6 +324,14 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||
raise SynapseError(400, "Shared secret registration is not enabled")
|
||||
|
||||
user = register_json["user"].encode("utf-8")
|
||||
password = register_json["password"].encode("utf-8")
|
||||
admin = register_json.get("admin", None)
|
||||
|
||||
# Its important to check as we use null bytes as HMAC field separators
|
||||
if "\x00" in user:
|
||||
raise SynapseError(400, "Invalid user")
|
||||
if "\x00" in password:
|
||||
raise SynapseError(400, "Invalid password")
|
||||
|
||||
# str() because otherwise hmac complains that 'unicode' does not
|
||||
# have the buffer interface
|
||||
@@ -331,17 +339,21 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||
|
||||
want_mac = hmac.new(
|
||||
key=self.hs.config.registration_shared_secret,
|
||||
msg=user,
|
||||
digestmod=sha1,
|
||||
).hexdigest()
|
||||
|
||||
password = register_json["password"].encode("utf-8")
|
||||
)
|
||||
want_mac.update(user)
|
||||
want_mac.update("\x00")
|
||||
want_mac.update(password)
|
||||
want_mac.update("\x00")
|
||||
want_mac.update("admin" if admin else "notadmin")
|
||||
want_mac = want_mac.hexdigest()
|
||||
|
||||
if compare_digest(want_mac, got_mac):
|
||||
handler = self.handlers.registration_handler
|
||||
user_id, token = yield handler.register(
|
||||
localpart=user,
|
||||
password=password,
|
||||
admin=bool(admin),
|
||||
)
|
||||
self._remove_session(session)
|
||||
defer.returnValue({
|
||||
@@ -410,12 +422,15 @@ class CreateUserRestServlet(ClientV1RestServlet):
|
||||
raise SynapseError(400, "Failed to parse 'duration_seconds'")
|
||||
if duration_seconds > self.direct_user_creation_max_duration:
|
||||
duration_seconds = self.direct_user_creation_max_duration
|
||||
password_hash = user_json["password_hash"].encode("utf-8") \
|
||||
if user_json.get("password_hash") else None
|
||||
|
||||
handler = self.handlers.registration_handler
|
||||
user_id, token = yield handler.get_or_create_user(
|
||||
localpart=localpart,
|
||||
displayname=displayname,
|
||||
duration_seconds=duration_seconds
|
||||
duration_seconds=duration_seconds,
|
||||
password_hash=password_hash
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
|
||||
@@ -72,8 +72,6 @@ class RoomCreateRestServlet(ClientV1RestServlet):
|
||||
|
||||
def get_room_config(self, request):
|
||||
user_supplied_config = parse_json_object_from_request(request)
|
||||
# default visibility
|
||||
user_supplied_config.setdefault("visibility", "public")
|
||||
return user_supplied_config
|
||||
|
||||
def on_OPTIONS(self, request):
|
||||
@@ -279,6 +277,13 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
try:
|
||||
yield self.auth.get_user_by_req(request)
|
||||
except AuthError:
|
||||
# This endpoint isn't authed, but its useful to know who's hitting
|
||||
# it if they *do* supply an access token
|
||||
pass
|
||||
|
||||
handler = self.hs.get_room_list_handler()
|
||||
data = yield handler.get_aggregated_public_room_list()
|
||||
|
||||
@@ -333,6 +338,25 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
|
||||
defer.returnValue((200, msgs))
|
||||
|
||||
|
||||
class RoomFileListRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/files$")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
pagination_config = PaginationConfig.from_request(
|
||||
request, default_limit=10, default_dir='b',
|
||||
)
|
||||
handler = self.handlers.message_handler
|
||||
msgs = yield handler.get_files(
|
||||
room_id=room_id,
|
||||
requester=requester,
|
||||
pagin_config=pagination_config,
|
||||
)
|
||||
|
||||
defer.returnValue((200, msgs))
|
||||
|
||||
|
||||
# TODO: Needs unit testing
|
||||
class RoomStateRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/state$")
|
||||
@@ -662,6 +686,7 @@ def register_servlets(hs, http_server):
|
||||
RoomCreateRestServlet(hs).register(http_server)
|
||||
RoomMemberListRestServlet(hs).register(http_server)
|
||||
RoomMessageListRestServlet(hs).register(http_server)
|
||||
RoomFileListRestServlet(hs).register(http_server)
|
||||
JoinRoomAliasServlet(hs).register(http_server)
|
||||
RoomForgetRestServlet(hs).register(http_server)
|
||||
RoomMembershipRestServlet(hs).register(http_server)
|
||||
|
||||
@@ -15,14 +15,12 @@
|
||||
|
||||
from synapse.http.server import respond_with_json_bytes, finish_request
|
||||
|
||||
from synapse.util.stringutils import random_string
|
||||
from synapse.api.errors import (
|
||||
cs_exception, SynapseError, CodeMessageException, Codes, cs_error
|
||||
Codes, cs_error
|
||||
)
|
||||
|
||||
from twisted.protocols.basic import FileSender
|
||||
from twisted.web import server, resource
|
||||
from twisted.internet import defer
|
||||
|
||||
import base64
|
||||
import simplejson as json
|
||||
@@ -50,64 +48,10 @@ class ContentRepoResource(resource.Resource):
|
||||
"""
|
||||
isLeaf = True
|
||||
|
||||
def __init__(self, hs, directory, auth, external_addr):
|
||||
def __init__(self, hs, directory):
|
||||
resource.Resource.__init__(self)
|
||||
self.hs = hs
|
||||
self.directory = directory
|
||||
self.auth = auth
|
||||
self.external_addr = external_addr.rstrip('/')
|
||||
self.max_upload_size = hs.config.max_upload_size
|
||||
|
||||
if not os.path.isdir(self.directory):
|
||||
os.mkdir(self.directory)
|
||||
logger.info("ContentRepoResource : Created %s directory.",
|
||||
self.directory)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def map_request_to_name(self, request):
|
||||
# auth the user
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
# namespace all file uploads on the user
|
||||
prefix = base64.urlsafe_b64encode(
|
||||
requester.user.to_string()
|
||||
).replace('=', '')
|
||||
|
||||
# use a random string for the main portion
|
||||
main_part = random_string(24)
|
||||
|
||||
# suffix with a file extension if we can make one. This is nice to
|
||||
# provide a hint to clients on the file information. We will also reuse
|
||||
# this info to spit back the content type to the client.
|
||||
suffix = ""
|
||||
if request.requestHeaders.hasHeader("Content-Type"):
|
||||
content_type = request.requestHeaders.getRawHeaders(
|
||||
"Content-Type")[0]
|
||||
suffix = "." + base64.urlsafe_b64encode(content_type)
|
||||
if (content_type.split("/")[0].lower() in
|
||||
["image", "video", "audio"]):
|
||||
file_ext = content_type.split("/")[-1]
|
||||
# be a little paranoid and only allow a-z
|
||||
file_ext = re.sub("[^a-z]", "", file_ext)
|
||||
suffix += "." + file_ext
|
||||
|
||||
file_name = prefix + main_part + suffix
|
||||
file_path = os.path.join(self.directory, file_name)
|
||||
logger.info("User %s is uploading a file to path %s",
|
||||
request.user.user_id.to_string(),
|
||||
file_path)
|
||||
|
||||
# keep trying to make a non-clashing file, with a sensible max attempts
|
||||
attempts = 0
|
||||
while os.path.exists(file_path):
|
||||
main_part = random_string(24)
|
||||
file_name = prefix + main_part + suffix
|
||||
file_path = os.path.join(self.directory, file_name)
|
||||
attempts += 1
|
||||
if attempts > 25: # really? Really?
|
||||
raise SynapseError(500, "Unable to create file.")
|
||||
|
||||
defer.returnValue(file_path)
|
||||
|
||||
def render_GET(self, request):
|
||||
# no auth here on purpose, to allow anyone to view, even across home
|
||||
@@ -155,58 +99,6 @@ class ContentRepoResource(resource.Resource):
|
||||
|
||||
return server.NOT_DONE_YET
|
||||
|
||||
def render_POST(self, request):
|
||||
self._async_render(request)
|
||||
return server.NOT_DONE_YET
|
||||
|
||||
def render_OPTIONS(self, request):
|
||||
respond_with_json_bytes(request, 200, {}, send_cors=True)
|
||||
return server.NOT_DONE_YET
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _async_render(self, request):
|
||||
try:
|
||||
# TODO: The checks here are a bit late. The content will have
|
||||
# already been uploaded to a tmp file at this point
|
||||
content_length = request.getHeader("Content-Length")
|
||||
if content_length is None:
|
||||
raise SynapseError(
|
||||
msg="Request must specify a Content-Length", code=400
|
||||
)
|
||||
if int(content_length) > self.max_upload_size:
|
||||
raise SynapseError(
|
||||
msg="Upload request body is too large",
|
||||
code=413,
|
||||
)
|
||||
|
||||
fname = yield self.map_request_to_name(request)
|
||||
|
||||
# TODO I have a suspicious feeling this is just going to block
|
||||
with open(fname, "wb") as f:
|
||||
f.write(request.content.read())
|
||||
|
||||
# FIXME (erikj): These should use constants.
|
||||
file_name = os.path.basename(fname)
|
||||
# FIXME: we can't assume what the repo's public mounted path is
|
||||
# ...plus self-signed SSL won't work to remote clients anyway
|
||||
# ...and we can't assume that it's SSL anyway, as we might want to
|
||||
# serve it via the non-SSL listener...
|
||||
url = "%s/_matrix/content/%s" % (
|
||||
self.external_addr, file_name
|
||||
)
|
||||
|
||||
respond_with_json_bytes(request, 200,
|
||||
json.dumps({"content_token": url}),
|
||||
send_cors=True)
|
||||
|
||||
except CodeMessageException as e:
|
||||
logger.exception(e)
|
||||
respond_with_json_bytes(request, e.code,
|
||||
json.dumps(cs_exception(e)))
|
||||
except Exception as e:
|
||||
logger.error("Failed to store file: %s" % e)
|
||||
respond_with_json_bytes(
|
||||
request,
|
||||
500,
|
||||
json.dumps({"error": "Internal server error"}),
|
||||
send_cors=True)
|
||||
|
||||
@@ -65,3 +65,9 @@ class MediaFilePaths(object):
|
||||
file_id[0:2], file_id[2:4], file_id[4:],
|
||||
file_name
|
||||
)
|
||||
|
||||
def remote_media_thumbnail_dir(self, server_name, file_id):
|
||||
return os.path.join(
|
||||
self.base_path, "remote_thumbnail", server_name,
|
||||
file_id[0:2], file_id[2:4], file_id[4:],
|
||||
)
|
||||
|
||||
@@ -26,14 +26,17 @@ from .thumbnailer import Thumbnailer
|
||||
|
||||
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
||||
from synapse.util.stringutils import random_string
|
||||
from synapse.api.errors import SynapseError
|
||||
|
||||
from twisted.internet import defer, threads
|
||||
|
||||
from synapse.util.async import ObservableDeferred
|
||||
from synapse.util.async import Linearizer
|
||||
from synapse.util.stringutils import is_ascii
|
||||
from synapse.util.logcontext import preserve_context_over_fn
|
||||
|
||||
import os
|
||||
import errno
|
||||
import shutil
|
||||
|
||||
import cgi
|
||||
import logging
|
||||
@@ -42,8 +45,11 @@ import urlparse
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
UPDATE_RECENTLY_ACCESSED_REMOTES_TS = 60 * 1000
|
||||
|
||||
|
||||
class MediaRepository(object):
|
||||
def __init__(self, hs, filepaths):
|
||||
def __init__(self, hs):
|
||||
self.auth = hs.get_auth()
|
||||
self.client = MatrixFederationHttpClient(hs)
|
||||
self.clock = hs.get_clock()
|
||||
@@ -51,11 +57,28 @@ class MediaRepository(object):
|
||||
self.store = hs.get_datastore()
|
||||
self.max_upload_size = hs.config.max_upload_size
|
||||
self.max_image_pixels = hs.config.max_image_pixels
|
||||
self.filepaths = filepaths
|
||||
self.downloads = {}
|
||||
self.filepaths = MediaFilePaths(hs.config.media_store_path)
|
||||
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
|
||||
self.thumbnail_requirements = hs.config.thumbnail_requirements
|
||||
|
||||
self.remote_media_linearizer = Linearizer()
|
||||
|
||||
self.recently_accessed_remotes = set()
|
||||
|
||||
self.clock.looping_call(
|
||||
self._update_recently_accessed_remotes,
|
||||
UPDATE_RECENTLY_ACCESSED_REMOTES_TS
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _update_recently_accessed_remotes(self):
|
||||
media = self.recently_accessed_remotes
|
||||
self.recently_accessed_remotes = set()
|
||||
|
||||
yield self.store.update_cached_last_access_time(
|
||||
media, self.clock.time_msec()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _makedirs(filepath):
|
||||
dirname = os.path.dirname(filepath)
|
||||
@@ -92,22 +115,12 @@ class MediaRepository(object):
|
||||
|
||||
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_remote_media(self, server_name, media_id):
|
||||
key = (server_name, media_id)
|
||||
download = self.downloads.get(key)
|
||||
if download is None:
|
||||
download = self._get_remote_media_impl(server_name, media_id)
|
||||
download = ObservableDeferred(
|
||||
download,
|
||||
consumeErrors=True
|
||||
)
|
||||
self.downloads[key] = download
|
||||
|
||||
@download.addBoth
|
||||
def callback(media_info):
|
||||
del self.downloads[key]
|
||||
return media_info
|
||||
return download.observe()
|
||||
with (yield self.remote_media_linearizer.queue(key)):
|
||||
media_info = yield self._get_remote_media_impl(server_name, media_id)
|
||||
defer.returnValue(media_info)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_remote_media_impl(self, server_name, media_id):
|
||||
@@ -118,6 +131,11 @@ class MediaRepository(object):
|
||||
media_info = yield self._download_remote_file(
|
||||
server_name, media_id
|
||||
)
|
||||
else:
|
||||
self.recently_accessed_remotes.add((server_name, media_id))
|
||||
yield self.store.update_cached_last_access_time(
|
||||
[(server_name, media_id)], self.clock.time_msec()
|
||||
)
|
||||
defer.returnValue(media_info)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@@ -134,10 +152,15 @@ class MediaRepository(object):
|
||||
request_path = "/".join((
|
||||
"/_matrix/media/v1/download", server_name, media_id,
|
||||
))
|
||||
length, headers = yield self.client.get_file(
|
||||
server_name, request_path, output_stream=f,
|
||||
max_size=self.max_upload_size,
|
||||
)
|
||||
try:
|
||||
length, headers = yield self.client.get_file(
|
||||
server_name, request_path, output_stream=f,
|
||||
max_size=self.max_upload_size,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warn("Failed to fetch remoted media %r", e)
|
||||
raise SynapseError(502, "Failed to fetch remoted media")
|
||||
|
||||
media_type = headers["Content-Type"][0]
|
||||
time_now_ms = self.clock.time_msec()
|
||||
|
||||
@@ -410,6 +433,41 @@ class MediaRepository(object):
|
||||
"height": m_height,
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_old_remote_media(self, before_ts):
|
||||
old_media = yield self.store.get_remote_media_before(before_ts)
|
||||
|
||||
deleted = 0
|
||||
|
||||
for media in old_media:
|
||||
origin = media["media_origin"]
|
||||
media_id = media["media_id"]
|
||||
file_id = media["filesystem_id"]
|
||||
key = (origin, media_id)
|
||||
|
||||
logger.info("Deleting: %r", key)
|
||||
|
||||
with (yield self.remote_media_linearizer.queue(key)):
|
||||
full_path = self.filepaths.remote_media_filepath(origin, file_id)
|
||||
try:
|
||||
os.remove(full_path)
|
||||
except OSError as e:
|
||||
logger.warn("Failed to remove file: %r", full_path)
|
||||
if e.errno == errno.ENOENT:
|
||||
pass
|
||||
else:
|
||||
continue
|
||||
|
||||
thumbnail_dir = self.filepaths.remote_media_thumbnail_dir(
|
||||
origin, file_id
|
||||
)
|
||||
shutil.rmtree(thumbnail_dir, ignore_errors=True)
|
||||
|
||||
yield self.store.delete_remote_media(origin, media_id)
|
||||
deleted += 1
|
||||
|
||||
defer.returnValue({"deleted": deleted})
|
||||
|
||||
|
||||
class MediaRepositoryResource(Resource):
|
||||
"""File uploading and downloading.
|
||||
@@ -458,9 +516,8 @@ class MediaRepositoryResource(Resource):
|
||||
|
||||
def __init__(self, hs):
|
||||
Resource.__init__(self)
|
||||
filepaths = MediaFilePaths(hs.config.media_store_path)
|
||||
|
||||
media_repo = MediaRepository(hs, filepaths)
|
||||
media_repo = hs.get_media_repository()
|
||||
|
||||
self.putChild("upload", UploadResource(hs, media_repo))
|
||||
self.putChild("download", DownloadResource(hs, media_repo))
|
||||
|
||||
@@ -252,7 +252,8 @@ class PreviewUrlResource(Resource):
|
||||
|
||||
og = {}
|
||||
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
|
||||
og[tag.attrib['property']] = tag.attrib['content']
|
||||
if 'content' in tag.attrib:
|
||||
og[tag.attrib['property']] = tag.attrib['content']
|
||||
|
||||
# TODO: grab article: meta tags too, e.g.:
|
||||
|
||||
@@ -279,7 +280,7 @@ class PreviewUrlResource(Resource):
|
||||
# TODO: consider inlined CSS styles as well as width & height attribs
|
||||
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
|
||||
images = sorted(images, key=lambda i: (
|
||||
-1 * int(i.attrib['width']) * int(i.attrib['height'])
|
||||
-1 * float(i.attrib['width']) * float(i.attrib['height'])
|
||||
))
|
||||
if not images:
|
||||
images = tree.xpath("//img[@src]")
|
||||
@@ -287,9 +288,9 @@ class PreviewUrlResource(Resource):
|
||||
og['og:image'] = images[0].attrib['src']
|
||||
|
||||
# pre-cache the image for posterity
|
||||
# FIXME: it might be cleaner to use the same flow as the main /preview_url request
|
||||
# itself and benefit from the same caching etc. But for now we just rely on the
|
||||
# caching on the master request to speed things up.
|
||||
# FIXME: it might be cleaner to use the same flow as the main /preview_url
|
||||
# request itself and benefit from the same caching etc. But for now we
|
||||
# just rely on the caching on the master request to speed things up.
|
||||
if 'og:image' in og and og['og:image']:
|
||||
image_info = yield self._download_url(
|
||||
self._rebase_url(og['og:image'], media_info['uri']), requester.user
|
||||
|
||||
@@ -45,6 +45,7 @@ from synapse.crypto.keyring import Keyring
|
||||
from synapse.push.pusherpool import PusherPool
|
||||
from synapse.events.builder import EventBuilderFactory
|
||||
from synapse.api.filtering import Filtering
|
||||
from synapse.rest.media.v1.media_repository import MediaRepository
|
||||
|
||||
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
||||
|
||||
@@ -113,6 +114,7 @@ class HomeServer(object):
|
||||
'filtering',
|
||||
'http_client_context_factory',
|
||||
'simple_http_client',
|
||||
'media_repository',
|
||||
]
|
||||
|
||||
def __init__(self, hostname, **kwargs):
|
||||
@@ -233,6 +235,9 @@ class HomeServer(object):
|
||||
**self.db_config.get("args", {})
|
||||
)
|
||||
|
||||
def build_media_repository(self):
|
||||
return MediaRepository(self)
|
||||
|
||||
def remove_pusher(self, app_id, push_key, user_id):
|
||||
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
|
||||
|
||||
|
||||
@@ -807,6 +807,11 @@ class SQLBaseStore(object):
|
||||
if txn.rowcount > 1:
|
||||
raise StoreError(500, "more than one row matched")
|
||||
|
||||
def _simple_delete(self, table, keyvalues, desc):
|
||||
return self.runInteraction(
|
||||
desc, self._simple_delete_txn, table, keyvalues
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _simple_delete_txn(txn, table, keyvalues):
|
||||
sql = "DELETE FROM %s WHERE %s" % (
|
||||
|
||||
@@ -138,6 +138,9 @@ class AccountDataStore(SQLBaseStore):
|
||||
A deferred pair of lists of tuples of stream_id int, user_id string,
|
||||
room_id string, type string, and content string.
|
||||
"""
|
||||
if last_room_id == current_id and last_global_id == current_id:
|
||||
return defer.succeed(([], []))
|
||||
|
||||
def get_updated_account_data_txn(txn):
|
||||
sql = (
|
||||
"SELECT stream_id, user_id, account_data_type, content"
|
||||
|
||||
@@ -32,7 +32,7 @@ def create_engine(database_config):
|
||||
|
||||
if engine_class:
|
||||
module = importlib.import_module(name)
|
||||
return engine_class(module)
|
||||
return engine_class(module, database_config)
|
||||
|
||||
raise RuntimeError(
|
||||
"Unsupported database engine '%s'" % (name,)
|
||||
|
||||
@@ -19,9 +19,10 @@ from ._base import IncorrectDatabaseSetup
|
||||
class PostgresEngine(object):
|
||||
single_threaded = False
|
||||
|
||||
def __init__(self, database_module):
|
||||
def __init__(self, database_module, database_config):
|
||||
self.module = database_module
|
||||
self.module.extensions.register_type(self.module.extensions.UNICODE)
|
||||
self.synchronous_commit = database_config.get("synchronous_commit", True)
|
||||
|
||||
def check_database(self, txn):
|
||||
txn.execute("SHOW SERVER_ENCODING")
|
||||
@@ -40,9 +41,19 @@ class PostgresEngine(object):
|
||||
db_conn.set_isolation_level(
|
||||
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
|
||||
)
|
||||
# Asynchronous commit, don't wait for the server to call fsync before
|
||||
# ending the transaction.
|
||||
# https://www.postgresql.org/docs/current/static/wal-async-commit.html
|
||||
if not self.synchronous_commit:
|
||||
cursor = db_conn.cursor()
|
||||
cursor.execute("SET synchronous_commit TO OFF")
|
||||
cursor.close()
|
||||
|
||||
def is_deadlock(self, error):
|
||||
if isinstance(error, self.module.DatabaseError):
|
||||
# https://www.postgresql.org/docs/current/static/errcodes-appendix.html
|
||||
# "40001" serialization_failure
|
||||
# "40P01" deadlock_detected
|
||||
return error.pgcode in ["40001", "40P01"]
|
||||
return False
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ import struct
|
||||
class Sqlite3Engine(object):
|
||||
single_threaded = True
|
||||
|
||||
def __init__(self, database_module):
|
||||
def __init__(self, database_module, database_config):
|
||||
self.module = database_module
|
||||
|
||||
def check_database(self, txn):
|
||||
|
||||
@@ -16,6 +16,8 @@
|
||||
from ._base import SQLBaseStore
|
||||
from twisted.internet import defer
|
||||
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||
from synapse.types import RoomStreamToken
|
||||
from .stream import lower_bound
|
||||
|
||||
import logging
|
||||
import ujson as json
|
||||
@@ -73,6 +75,9 @@ class EventPushActionsStore(SQLBaseStore):
|
||||
|
||||
stream_ordering = results[0][0]
|
||||
topological_ordering = results[0][1]
|
||||
token = RoomStreamToken(
|
||||
topological_ordering, stream_ordering
|
||||
)
|
||||
|
||||
sql = (
|
||||
"SELECT sum(notif), sum(highlight)"
|
||||
@@ -80,15 +85,10 @@ class EventPushActionsStore(SQLBaseStore):
|
||||
" WHERE"
|
||||
" user_id = ?"
|
||||
" AND room_id = ?"
|
||||
" AND ("
|
||||
" topological_ordering > ?"
|
||||
" OR (topological_ordering = ? AND stream_ordering > ?)"
|
||||
")"
|
||||
)
|
||||
txn.execute(sql, (
|
||||
user_id, room_id,
|
||||
topological_ordering, topological_ordering, stream_ordering
|
||||
))
|
||||
" AND %s"
|
||||
) % (lower_bound(token, self.database_engine, inclusive=False),)
|
||||
|
||||
txn.execute(sql, (user_id, room_id))
|
||||
row = txn.fetchone()
|
||||
if row:
|
||||
return {
|
||||
@@ -152,7 +152,7 @@ class EventPushActionsStore(SQLBaseStore):
|
||||
if max_stream_ordering is not None:
|
||||
sql += " AND ep.stream_ordering <= ?"
|
||||
args.append(max_stream_ordering)
|
||||
sql += " ORDER BY ep.stream_ordering ASC LIMIT ?"
|
||||
sql += " ORDER BY ep.stream_ordering DESC LIMIT ?"
|
||||
args.append(limit)
|
||||
txn.execute(sql, args)
|
||||
return txn.fetchall()
|
||||
@@ -176,14 +176,16 @@ class EventPushActionsStore(SQLBaseStore):
|
||||
if max_stream_ordering is not None:
|
||||
sql += " AND ep.stream_ordering <= ?"
|
||||
args.append(max_stream_ordering)
|
||||
sql += " ORDER BY ep.stream_ordering ASC"
|
||||
sql += " ORDER BY ep.stream_ordering DESC LIMIT ?"
|
||||
args.append(limit)
|
||||
txn.execute(sql, args)
|
||||
return txn.fetchall()
|
||||
no_read_receipt = yield self.runInteraction(
|
||||
"get_unread_push_actions_for_user_in_range", get_no_receipt
|
||||
)
|
||||
|
||||
defer.returnValue([
|
||||
# Make a list of dicts from the two sets of results.
|
||||
notifs = [
|
||||
{
|
||||
"event_id": row[0],
|
||||
"room_id": row[1],
|
||||
@@ -191,7 +193,16 @@ class EventPushActionsStore(SQLBaseStore):
|
||||
"actions": json.loads(row[3]),
|
||||
"received_ts": row[4],
|
||||
} for row in after_read_receipt + no_read_receipt
|
||||
])
|
||||
]
|
||||
|
||||
# Now sort it so it's ordered correctly, since currently it will
|
||||
# contain results from the first query, correctly ordered, followed
|
||||
# by results from the second query, but we want them all ordered
|
||||
# by received_ts
|
||||
notifs.sort(key=lambda r: -(r['received_ts'] or 0))
|
||||
|
||||
# Now return the first `limit`
|
||||
defer.returnValue(notifs[:limit])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_time_of_last_push_action_before(self, stream_ordering):
|
||||
|
||||
@@ -23,6 +23,7 @@ from synapse.util.async import ObservableDeferred
|
||||
from synapse.util.logcontext import preserve_fn, PreserveLoggingContext
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.errors import SynapseError
|
||||
|
||||
from canonicaljson import encode_canonical_json
|
||||
from collections import deque, namedtuple
|
||||
@@ -355,7 +356,6 @@ class EventsStore(SQLBaseStore):
|
||||
txn.call_after(self.get_rooms_for_user.invalidate_all)
|
||||
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
|
||||
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
|
||||
txn.call_after(self.get_room_name_and_aliases.invalidate, (event.room_id,))
|
||||
|
||||
# Add an entry to the current_state_resets table to record the point
|
||||
# where we clobbered the current state
|
||||
@@ -666,12 +666,6 @@ class EventsStore(SQLBaseStore):
|
||||
(event.room_id, event.type, event.state_key,)
|
||||
)
|
||||
|
||||
if event.type in [EventTypes.Name, EventTypes.Aliases]:
|
||||
txn.call_after(
|
||||
self.get_room_name_and_aliases.invalidate,
|
||||
(event.room_id,)
|
||||
)
|
||||
|
||||
self._simple_upsert_txn(
|
||||
txn,
|
||||
"current_state_events",
|
||||
@@ -1288,6 +1282,156 @@ class EventsStore(SQLBaseStore):
|
||||
)
|
||||
return self.runInteraction("get_all_new_events", get_all_new_events_txn)
|
||||
|
||||
def delete_old_state(self, room_id, topological_ordering):
|
||||
return self.runInteraction(
|
||||
"delete_old_state",
|
||||
self._delete_old_state_txn, room_id, topological_ordering
|
||||
)
|
||||
|
||||
def _delete_old_state_txn(self, txn, room_id, topological_ordering):
|
||||
"""Deletes old room state
|
||||
"""
|
||||
|
||||
# Tables that should be pruned:
|
||||
# event_auth
|
||||
# event_backward_extremities
|
||||
# event_content_hashes
|
||||
# event_destinations
|
||||
# event_edge_hashes
|
||||
# event_edges
|
||||
# event_forward_extremities
|
||||
# event_json
|
||||
# event_push_actions
|
||||
# event_reference_hashes
|
||||
# event_search
|
||||
# event_signatures
|
||||
# event_to_state_groups
|
||||
# events
|
||||
# rejections
|
||||
# room_depth
|
||||
# state_groups
|
||||
# state_groups_state
|
||||
|
||||
# First ensure that we're not about to delete all the forward extremeties
|
||||
txn.execute(
|
||||
"SELECT e.event_id, e.depth FROM events as e "
|
||||
"INNER JOIN event_forward_extremities as f "
|
||||
"ON e.event_id = f.event_id "
|
||||
"AND e.room_id = f.room_id "
|
||||
"WHERE f.room_id = ?",
|
||||
(room_id,)
|
||||
)
|
||||
rows = txn.fetchall()
|
||||
max_depth = max(row[0] for row in rows)
|
||||
|
||||
if max_depth <= topological_ordering:
|
||||
# We need to ensure we don't delete all the events from the datanase
|
||||
# otherwise we wouldn't be able to send any events (due to not
|
||||
# having any backwards extremeties)
|
||||
raise SynapseError(
|
||||
400, "topological_ordering is greater than forward extremeties"
|
||||
)
|
||||
|
||||
txn.execute(
|
||||
"SELECT event_id, state_key FROM events"
|
||||
" LEFT JOIN state_events USING (room_id, event_id)"
|
||||
" WHERE room_id = ? AND topological_ordering < ?",
|
||||
(room_id, topological_ordering,)
|
||||
)
|
||||
event_rows = txn.fetchall()
|
||||
|
||||
# We calculate the new entries for the backward extremeties by finding
|
||||
# all events that point to events that are to be purged
|
||||
txn.execute(
|
||||
"SELECT e.event_id FROM events as e"
|
||||
" INNER JOIN event_edges as ed ON e.event_id = ed.prev_event_id"
|
||||
" INNER JOIN events as e2 ON e2.event_id = ed.event_id"
|
||||
" WHERE e.room_id = ? AND e.topological_ordering < ?"
|
||||
" AND e2.topological_ordering >= ?",
|
||||
(room_id, topological_ordering, topological_ordering)
|
||||
)
|
||||
new_backwards_extrems = txn.fetchall()
|
||||
|
||||
# Get all state groups that are only referenced by events that are
|
||||
# to be deleted.
|
||||
txn.execute(
|
||||
"SELECT state_group FROM event_to_state_groups"
|
||||
" INNER JOIN events USING (event_id)"
|
||||
" WHERE state_group IN ("
|
||||
" SELECT DISTINCT state_group FROM events"
|
||||
" INNER JOIN event_to_state_groups USING (event_id)"
|
||||
" WHERE room_id = ? AND topological_ordering < ?"
|
||||
" )"
|
||||
" GROUP BY state_group HAVING MAX(topological_ordering) < ?",
|
||||
(room_id, topological_ordering, topological_ordering)
|
||||
)
|
||||
state_rows = txn.fetchall()
|
||||
txn.executemany(
|
||||
"DELETE FROM state_groups_state WHERE state_group = ?",
|
||||
state_rows
|
||||
)
|
||||
txn.executemany(
|
||||
"DELETE FROM state_groups WHERE id = ?",
|
||||
state_rows
|
||||
)
|
||||
# Delete all non-state
|
||||
txn.executemany(
|
||||
"DELETE FROM event_to_state_groups WHERE event_id = ?",
|
||||
[(event_id,) for event_id, _ in event_rows]
|
||||
)
|
||||
|
||||
txn.execute(
|
||||
"UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
|
||||
(topological_ordering, room_id,)
|
||||
)
|
||||
|
||||
# Delete all remote non-state events
|
||||
to_delete = [
|
||||
(event_id,) for event_id, state_key in event_rows
|
||||
if state_key is None and not self.hs.is_mine_id(event_id)
|
||||
]
|
||||
for table in (
|
||||
"events",
|
||||
"event_json",
|
||||
"event_auth",
|
||||
"event_content_hashes",
|
||||
"event_destinations",
|
||||
"event_edge_hashes",
|
||||
"event_edges",
|
||||
"event_forward_extremities",
|
||||
"event_push_actions",
|
||||
"event_reference_hashes",
|
||||
"event_search",
|
||||
"event_signatures",
|
||||
"rejections",
|
||||
"event_backward_extremities",
|
||||
):
|
||||
txn.executemany(
|
||||
"DELETE FROM %s WHERE event_id = ?" % (table,),
|
||||
to_delete
|
||||
)
|
||||
|
||||
# Update backward extremeties
|
||||
txn.executemany(
|
||||
"INSERT INTO event_backward_extremities (room_id, event_id)"
|
||||
" VALUES (?, ?)",
|
||||
[(room_id, event_id) for event_id, in new_backwards_extrems]
|
||||
)
|
||||
|
||||
txn.executemany(
|
||||
"DELETE FROM events WHERE event_id = ?",
|
||||
to_delete
|
||||
)
|
||||
# Mark all state and own events as outliers
|
||||
txn.executemany(
|
||||
"UPDATE events SET outlier = ?"
|
||||
" WHERE event_id = ?",
|
||||
[
|
||||
(True, event_id,) for event_id, state_key in event_rows
|
||||
if state_key is not None or self.hs.is_mine_id(event_id)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
AllNewEventsResult = namedtuple("AllNewEventsResult", [
|
||||
"new_forward_events", "new_backfill_events",
|
||||
|
||||
@@ -157,10 +157,25 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||
"created_ts": time_now_ms,
|
||||
"upload_name": upload_name,
|
||||
"filesystem_id": filesystem_id,
|
||||
"last_access_ts": time_now_ms,
|
||||
},
|
||||
desc="store_cached_remote_media",
|
||||
)
|
||||
|
||||
def update_cached_last_access_time(self, origin_id_tuples, time_ts):
|
||||
def update_cache_txn(txn):
|
||||
sql = (
|
||||
"UPDATE remote_media_cache SET last_access_ts = ?"
|
||||
" WHERE media_origin = ? AND media_id = ?"
|
||||
)
|
||||
|
||||
txn.executemany(sql, (
|
||||
(time_ts, media_origin, media_id)
|
||||
for media_origin, media_id in origin_id_tuples
|
||||
))
|
||||
|
||||
return self.runInteraction("update_cached_last_access_time", update_cache_txn)
|
||||
|
||||
def get_remote_media_thumbnails(self, origin, media_id):
|
||||
return self._simple_select_list(
|
||||
"remote_media_cache_thumbnails",
|
||||
@@ -190,3 +205,32 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||
},
|
||||
desc="store_remote_media_thumbnail",
|
||||
)
|
||||
|
||||
def get_remote_media_before(self, before_ts):
|
||||
sql = (
|
||||
"SELECT media_origin, media_id, filesystem_id"
|
||||
" FROM remote_media_cache"
|
||||
" WHERE last_access_ts < ?"
|
||||
)
|
||||
|
||||
return self._execute(
|
||||
"get_remote_media_before", self.cursor_to_dict, sql, before_ts
|
||||
)
|
||||
|
||||
def delete_remote_media(self, media_origin, media_id):
|
||||
def delete_remote_media_txn(txn):
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
"remote_media_cache",
|
||||
keyvalues={
|
||||
"media_origin": media_origin, "media_id": media_id
|
||||
},
|
||||
)
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
"remote_media_cache_thumbnails",
|
||||
keyvalues={
|
||||
"media_origin": media_origin, "media_id": media_id
|
||||
},
|
||||
)
|
||||
return self.runInteraction("delete_remote_media", delete_remote_media_txn)
|
||||
|
||||
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# Remember to update this number every time a change is made to database
|
||||
# schema files, so the users will be informed on server restarts.
|
||||
SCHEMA_VERSION = 32
|
||||
SCHEMA_VERSION = 33
|
||||
|
||||
dir_path = os.path.abspath(os.path.dirname(__file__))
|
||||
|
||||
|
||||
@@ -118,6 +118,9 @@ class PresenceStore(SQLBaseStore):
|
||||
)
|
||||
|
||||
def get_all_presence_updates(self, last_id, current_id):
|
||||
if last_id == current_id:
|
||||
return defer.succeed([])
|
||||
|
||||
def get_all_presence_updates_txn(txn):
|
||||
sql = (
|
||||
"SELECT stream_id, user_id, state, last_active_ts,"
|
||||
|
||||
@@ -421,6 +421,9 @@ class PushRuleStore(SQLBaseStore):
|
||||
|
||||
def get_all_push_rule_updates(self, last_id, current_id, limit):
|
||||
"""Get all the push rules changes that have happend on the server"""
|
||||
if last_id == current_id:
|
||||
return defer.succeed([])
|
||||
|
||||
def get_all_push_rule_updates_txn(txn):
|
||||
sql = (
|
||||
"SELECT stream_id, event_stream_ordering, user_id, rule_id,"
|
||||
|
||||
@@ -76,7 +76,8 @@ class RegistrationStore(SQLBaseStore):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def register(self, user_id, token, password_hash,
|
||||
was_guest=False, make_guest=False, appservice_id=None):
|
||||
was_guest=False, make_guest=False, appservice_id=None,
|
||||
create_profile_with_localpart=None, admin=False):
|
||||
"""Attempts to register an account.
|
||||
|
||||
Args:
|
||||
@@ -88,6 +89,8 @@ class RegistrationStore(SQLBaseStore):
|
||||
make_guest (boolean): True if the the new user should be guest,
|
||||
false to add a regular user account.
|
||||
appservice_id (str): The ID of the appservice registering the user.
|
||||
create_profile_with_localpart (str): Optionally create a profile for
|
||||
the given localpart.
|
||||
Raises:
|
||||
StoreError if the user_id could not be registered.
|
||||
"""
|
||||
@@ -99,7 +102,9 @@ class RegistrationStore(SQLBaseStore):
|
||||
password_hash,
|
||||
was_guest,
|
||||
make_guest,
|
||||
appservice_id
|
||||
appservice_id,
|
||||
create_profile_with_localpart,
|
||||
admin
|
||||
)
|
||||
self.get_user_by_id.invalidate((user_id,))
|
||||
self.is_guest.invalidate((user_id,))
|
||||
@@ -112,7 +117,9 @@ class RegistrationStore(SQLBaseStore):
|
||||
password_hash,
|
||||
was_guest,
|
||||
make_guest,
|
||||
appservice_id
|
||||
appservice_id,
|
||||
create_profile_with_localpart,
|
||||
admin,
|
||||
):
|
||||
now = int(self.clock.time())
|
||||
|
||||
@@ -120,29 +127,48 @@ class RegistrationStore(SQLBaseStore):
|
||||
|
||||
try:
|
||||
if was_guest:
|
||||
txn.execute("UPDATE users SET"
|
||||
" password_hash = ?,"
|
||||
" upgrade_ts = ?,"
|
||||
" is_guest = ?"
|
||||
" WHERE name = ?",
|
||||
[password_hash, now, 1 if make_guest else 0, user_id])
|
||||
# Ensure that the guest user actually exists
|
||||
# ``allow_none=False`` makes this raise an exception
|
||||
# if the row isn't in the database.
|
||||
self._simple_select_one_txn(
|
||||
txn,
|
||||
"users",
|
||||
keyvalues={
|
||||
"name": user_id,
|
||||
"is_guest": 1,
|
||||
},
|
||||
retcols=("name",),
|
||||
allow_none=False,
|
||||
)
|
||||
|
||||
self._simple_update_one_txn(
|
||||
txn,
|
||||
"users",
|
||||
keyvalues={
|
||||
"name": user_id,
|
||||
"is_guest": 1,
|
||||
},
|
||||
updatevalues={
|
||||
"password_hash": password_hash,
|
||||
"upgrade_ts": now,
|
||||
"is_guest": 1 if make_guest else 0,
|
||||
"appservice_id": appservice_id,
|
||||
"admin": 1 if admin else 0,
|
||||
}
|
||||
)
|
||||
else:
|
||||
txn.execute("INSERT INTO users "
|
||||
"("
|
||||
" name,"
|
||||
" password_hash,"
|
||||
" creation_ts,"
|
||||
" is_guest,"
|
||||
" appservice_id"
|
||||
") "
|
||||
"VALUES (?,?,?,?,?)",
|
||||
[
|
||||
user_id,
|
||||
password_hash,
|
||||
now,
|
||||
1 if make_guest else 0,
|
||||
appservice_id,
|
||||
])
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
"users",
|
||||
values={
|
||||
"name": user_id,
|
||||
"password_hash": password_hash,
|
||||
"creation_ts": now,
|
||||
"is_guest": 1 if make_guest else 0,
|
||||
"appservice_id": appservice_id,
|
||||
"admin": 1 if admin else 0,
|
||||
}
|
||||
)
|
||||
except self.database_engine.module.IntegrityError:
|
||||
raise StoreError(
|
||||
400, "User ID already taken.", errcode=Codes.USER_IN_USE
|
||||
@@ -157,6 +183,12 @@ class RegistrationStore(SQLBaseStore):
|
||||
(next_id, user_id, token,)
|
||||
)
|
||||
|
||||
if create_profile_with_localpart:
|
||||
txn.execute(
|
||||
"INSERT INTO profiles(user_id) VALUES (?)",
|
||||
(create_profile_with_localpart,)
|
||||
)
|
||||
|
||||
@cached()
|
||||
def get_user_by_id(self, user_id):
|
||||
return self._simple_select_one(
|
||||
@@ -373,6 +405,15 @@ class RegistrationStore(SQLBaseStore):
|
||||
defer.returnValue(ret['user_id'])
|
||||
defer.returnValue(None)
|
||||
|
||||
def user_delete_threepids(self, user_id):
|
||||
return self._simple_delete(
|
||||
"user_threepids",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
},
|
||||
desc="user_delete_threepids",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def count_all_users(self):
|
||||
"""Counts all users registered on the homeserver."""
|
||||
|
||||
@@ -18,7 +18,6 @@ from twisted.internet import defer
|
||||
from synapse.api.errors import StoreError
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||
from .engines import PostgresEngine, Sqlite3Engine
|
||||
|
||||
import collections
|
||||
@@ -35,6 +34,108 @@ OpsLevel = collections.namedtuple(
|
||||
|
||||
|
||||
class RoomStore(SQLBaseStore):
|
||||
EVENT_FILES_UPDATE_NAME = "event_files"
|
||||
|
||||
FILE_MSGTYPES = (
|
||||
"m.image",
|
||||
"m.video",
|
||||
"m.file",
|
||||
"m.audio",
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RoomStore, self).__init__(hs)
|
||||
self.register_background_update_handler(
|
||||
self.EVENT_FILES_UPDATE_NAME, self._background_reindex_files
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _background_reindex_files(self, progress, batch_size):
|
||||
target_min_stream_id = progress["target_min_stream_id_inclusive"]
|
||||
max_stream_id = progress["max_stream_id_exclusive"]
|
||||
rows_inserted = progress.get("rows_inserted", 0)
|
||||
|
||||
def reindex_txn(txn):
|
||||
sql = (
|
||||
"SELECT topological_ordering, stream_ordering, event_id, room_id,"
|
||||
" type, content FROM events"
|
||||
" WHERE ? <= stream_ordering AND stream_ordering < ?"
|
||||
" AND type = 'm.room.message'"
|
||||
" AND content LIKE ?"
|
||||
" ORDER BY stream_ordering DESC"
|
||||
" LIMIT ?"
|
||||
)
|
||||
|
||||
txn.execute(sql, (target_min_stream_id, max_stream_id, '%url%', batch_size))
|
||||
|
||||
rows = self.cursor_to_dict(txn)
|
||||
if not rows:
|
||||
return 0
|
||||
|
||||
min_stream_id = rows[-1]["stream_ordering"]
|
||||
|
||||
event_files_rows = []
|
||||
for row in rows:
|
||||
try:
|
||||
so = row["stream_ordering"]
|
||||
to = row["topological_ordering"]
|
||||
event_id = row["event_id"]
|
||||
room_id = row["room_id"]
|
||||
try:
|
||||
content = json.loads(row["content"])
|
||||
except:
|
||||
continue
|
||||
|
||||
msgtype = content["msgtype"]
|
||||
if msgtype not in self.FILE_MSGTYPES:
|
||||
continue
|
||||
|
||||
url = content["url"]
|
||||
|
||||
if not isinstance(url, basestring):
|
||||
continue
|
||||
if not isinstance(msgtype, basestring):
|
||||
continue
|
||||
except (KeyError, AttributeError):
|
||||
# If the event is missing a necessary field then
|
||||
# skip over it.
|
||||
continue
|
||||
|
||||
event_files_rows.append({
|
||||
"topological_ordering": to,
|
||||
"stream_ordering": so,
|
||||
"event_id": event_id,
|
||||
"room_id": room_id,
|
||||
"msgtype": msgtype,
|
||||
"url": url,
|
||||
})
|
||||
|
||||
self._simple_insert_many_txn(
|
||||
txn,
|
||||
table="event_files",
|
||||
values=event_files_rows,
|
||||
)
|
||||
|
||||
progress = {
|
||||
"target_min_stream_id_inclusive": target_min_stream_id,
|
||||
"max_stream_id_exclusive": min_stream_id,
|
||||
"rows_inserted": rows_inserted + len(event_files_rows)
|
||||
}
|
||||
|
||||
self._background_update_progress_txn(
|
||||
txn, self.EVENT_FILES_UPDATE_NAME, progress
|
||||
)
|
||||
|
||||
return len(rows)
|
||||
|
||||
result = yield self.runInteraction(
|
||||
self.EVENT_FILES_UPDATE_NAME, reindex_txn
|
||||
)
|
||||
|
||||
if not result:
|
||||
yield self._end_background_update(self.EVENT_FILES_UPDATE_NAME)
|
||||
|
||||
defer.returnValue(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def store_room(self, room_id, room_creator_user_id, is_public):
|
||||
@@ -143,6 +244,22 @@ class RoomStore(SQLBaseStore):
|
||||
)
|
||||
|
||||
def _store_room_message_txn(self, txn, event):
|
||||
msgtype = event.content.get("msgtype")
|
||||
url = event.content.get("url")
|
||||
if msgtype in self.FILE_MSGTYPES and url:
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="event_files",
|
||||
values={
|
||||
"topological_ordering": event.depth,
|
||||
"stream_ordering": event.internal_metadata.stream_ordering,
|
||||
"room_id": event.room_id,
|
||||
"event_id": event.event_id,
|
||||
"msgtype": msgtype,
|
||||
"url": url,
|
||||
}
|
||||
)
|
||||
|
||||
if hasattr(event, "content") and "body" in event.content:
|
||||
self._store_event_search_txn(
|
||||
txn, event, "content.body", event.content["body"]
|
||||
@@ -192,49 +309,6 @@ class RoomStore(SQLBaseStore):
|
||||
# This should be unreachable.
|
||||
raise Exception("Unrecognized database engine")
|
||||
|
||||
@cachedInlineCallbacks()
|
||||
def get_room_name_and_aliases(self, room_id):
|
||||
def get_room_name(txn):
|
||||
sql = (
|
||||
"SELECT name FROM room_names"
|
||||
" INNER JOIN current_state_events USING (room_id, event_id)"
|
||||
" WHERE room_id = ?"
|
||||
" LIMIT 1"
|
||||
)
|
||||
|
||||
txn.execute(sql, (room_id,))
|
||||
rows = txn.fetchall()
|
||||
if rows:
|
||||
return rows[0][0]
|
||||
else:
|
||||
return None
|
||||
|
||||
return [row[0] for row in txn.fetchall()]
|
||||
|
||||
def get_room_aliases(txn):
|
||||
sql = (
|
||||
"SELECT content FROM current_state_events"
|
||||
" INNER JOIN events USING (room_id, event_id)"
|
||||
" WHERE room_id = ?"
|
||||
)
|
||||
txn.execute(sql, (room_id,))
|
||||
return [row[0] for row in txn.fetchall()]
|
||||
|
||||
name = yield self.runInteraction("get_room_name", get_room_name)
|
||||
alias_contents = yield self.runInteraction("get_room_aliases", get_room_aliases)
|
||||
|
||||
aliases = []
|
||||
|
||||
for c in alias_contents:
|
||||
try:
|
||||
content = json.loads(c)
|
||||
except:
|
||||
continue
|
||||
|
||||
aliases.extend(content.get('aliases', []))
|
||||
|
||||
defer.returnValue((name, aliases))
|
||||
|
||||
def add_event_report(self, room_id, event_id, user_id, reason, content,
|
||||
received_ts):
|
||||
next_id = self._event_reports_id_gen.get_next()
|
||||
|
||||
72
synapse/storage/schema/delta/33/msgtype.py
Normal file
72
synapse/storage/schema/delta/33/msgtype.py
Normal file
@@ -0,0 +1,72 @@
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.storage.prepare_database import get_statements
|
||||
|
||||
import logging
|
||||
import ujson
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
CREATE_TABLE = """
|
||||
CREATE TABLE event_files(
|
||||
topological_ordering BIGINT NOT NULL,
|
||||
stream_ordering BIGINT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
event_id TEXT NOT NULL,
|
||||
msgtype TEXT NOT NULL,
|
||||
url TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX event_files_rm_id ON event_files(room_id, event_id);
|
||||
CREATE INDEX event_files_order ON event_files(
|
||||
room_id, topological_ordering, stream_ordering
|
||||
);
|
||||
CREATE INDEX event_files_order_stream ON event_files(room_id, stream_ordering);
|
||||
"""
|
||||
|
||||
|
||||
def run_create(cur, database_engine, *args, **kwargs):
|
||||
for statement in get_statements(CREATE_TABLE.splitlines()):
|
||||
cur.execute(statement)
|
||||
|
||||
cur.execute("SELECT MIN(stream_ordering) FROM events")
|
||||
rows = cur.fetchall()
|
||||
min_stream_id = rows[0][0]
|
||||
|
||||
cur.execute("SELECT MAX(stream_ordering) FROM events")
|
||||
rows = cur.fetchall()
|
||||
max_stream_id = rows[0][0]
|
||||
|
||||
if min_stream_id is not None and max_stream_id is not None:
|
||||
progress = {
|
||||
"target_min_stream_id_inclusive": min_stream_id,
|
||||
"max_stream_id_exclusive": max_stream_id + 1,
|
||||
"rows_inserted": 0,
|
||||
}
|
||||
progress_json = ujson.dumps(progress)
|
||||
|
||||
sql = (
|
||||
"INSERT into background_updates (update_name, progress_json)"
|
||||
" VALUES (?, ?)"
|
||||
)
|
||||
|
||||
sql = database_engine.convert_param_style(sql)
|
||||
|
||||
cur.execute(sql, ("event_files", progress_json))
|
||||
|
||||
|
||||
def run_upgrade(cur, database_engine, *args, **kwargs):
|
||||
pass
|
||||
31
synapse/storage/schema/delta/33/remote_media_ts.py
Normal file
31
synapse/storage/schema/delta/33/remote_media_ts.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import time
|
||||
|
||||
|
||||
ALTER_TABLE = "ALTER TABLE remote_media_cache ADD COLUMN last_access_ts BIGINT"
|
||||
|
||||
|
||||
def run_create(cur, database_engine, *args, **kwargs):
|
||||
cur.execute(ALTER_TABLE)
|
||||
|
||||
|
||||
def run_upgrade(cur, database_engine, *args, **kwargs):
|
||||
cur.execute(
|
||||
database_engine.convert_param_style(
|
||||
"UPDATE remote_media_cache SET last_access_ts = ?"
|
||||
),
|
||||
(int(time.time() * 1000),)
|
||||
)
|
||||
@@ -40,6 +40,7 @@ from synapse.util.caches.descriptors import cached
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.types import RoomStreamToken
|
||||
from synapse.util.logcontext import preserve_fn
|
||||
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
||||
|
||||
import logging
|
||||
|
||||
@@ -54,25 +55,43 @@ _STREAM_TOKEN = "stream"
|
||||
_TOPOLOGICAL_TOKEN = "topological"
|
||||
|
||||
|
||||
def lower_bound(token):
|
||||
def lower_bound(token, engine, inclusive=False):
|
||||
inclusive = "=" if inclusive else ""
|
||||
if token.topological is None:
|
||||
return "(%d < %s)" % (token.stream, "stream_ordering")
|
||||
return "(%d <%s %s)" % (token.stream, inclusive, "stream_ordering")
|
||||
else:
|
||||
return "(%d < %s OR (%d = %s AND %d < %s))" % (
|
||||
if isinstance(engine, PostgresEngine):
|
||||
# Postgres doesn't optimise ``(x < a) OR (x=a AND y<b)`` as well
|
||||
# as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we
|
||||
# use the later form when running against postgres.
|
||||
return "((%d,%d) <%s (%s,%s))" % (
|
||||
token.topological, token.stream, inclusive,
|
||||
"topological_ordering", "stream_ordering",
|
||||
)
|
||||
return "(%d < %s OR (%d = %s AND %d <%s %s))" % (
|
||||
token.topological, "topological_ordering",
|
||||
token.topological, "topological_ordering",
|
||||
token.stream, "stream_ordering",
|
||||
token.stream, inclusive, "stream_ordering",
|
||||
)
|
||||
|
||||
|
||||
def upper_bound(token):
|
||||
def upper_bound(token, engine, inclusive=True):
|
||||
inclusive = "=" if inclusive else ""
|
||||
if token.topological is None:
|
||||
return "(%d >= %s)" % (token.stream, "stream_ordering")
|
||||
return "(%d >%s %s)" % (token.stream, inclusive, "stream_ordering")
|
||||
else:
|
||||
return "(%d > %s OR (%d = %s AND %d >= %s))" % (
|
||||
if isinstance(engine, PostgresEngine):
|
||||
# Postgres doesn't optimise ``(x > a) OR (x=a AND y>b)`` as well
|
||||
# as it optimises ``(x,y) > (a,b)`` on multicolumn indexes. So we
|
||||
# use the later form when running against postgres.
|
||||
return "((%d,%d) >%s (%s,%s))" % (
|
||||
token.topological, token.stream, inclusive,
|
||||
"topological_ordering", "stream_ordering",
|
||||
)
|
||||
return "(%d > %s OR (%d = %s AND %d >%s %s))" % (
|
||||
token.topological, "topological_ordering",
|
||||
token.topological, "topological_ordering",
|
||||
token.stream, "stream_ordering",
|
||||
token.stream, inclusive, "stream_ordering",
|
||||
)
|
||||
|
||||
|
||||
@@ -308,18 +327,22 @@ class StreamStore(SQLBaseStore):
|
||||
args = [False, room_id]
|
||||
if direction == 'b':
|
||||
order = "DESC"
|
||||
bounds = upper_bound(RoomStreamToken.parse(from_key))
|
||||
bounds = upper_bound(
|
||||
RoomStreamToken.parse(from_key), self.database_engine
|
||||
)
|
||||
if to_key:
|
||||
bounds = "%s AND %s" % (
|
||||
bounds, lower_bound(RoomStreamToken.parse(to_key))
|
||||
)
|
||||
bounds = "%s AND %s" % (bounds, lower_bound(
|
||||
RoomStreamToken.parse(to_key), self.database_engine
|
||||
))
|
||||
else:
|
||||
order = "ASC"
|
||||
bounds = lower_bound(RoomStreamToken.parse(from_key))
|
||||
bounds = lower_bound(
|
||||
RoomStreamToken.parse(from_key), self.database_engine
|
||||
)
|
||||
if to_key:
|
||||
bounds = "%s AND %s" % (
|
||||
bounds, upper_bound(RoomStreamToken.parse(to_key))
|
||||
)
|
||||
bounds = "%s AND %s" % (bounds, upper_bound(
|
||||
RoomStreamToken.parse(to_key), self.database_engine
|
||||
))
|
||||
|
||||
if int(limit) > 0:
|
||||
args.append(int(limit))
|
||||
@@ -371,6 +394,82 @@ class StreamStore(SQLBaseStore):
|
||||
|
||||
defer.returnValue((events, token))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def paginate_room_file_events(self, room_id, from_key, to_key=None,
|
||||
direction='b', limit=-1):
|
||||
# Tokens really represent positions between elements, but we use
|
||||
# the convention of pointing to the event before the gap. Hence
|
||||
# we have a bit of asymmetry when it comes to equalities.
|
||||
args = [room_id]
|
||||
if direction == 'b':
|
||||
order = "DESC"
|
||||
bounds = upper_bound(
|
||||
RoomStreamToken.parse(from_key), self.database_engine
|
||||
)
|
||||
if to_key:
|
||||
bounds = "%s AND %s" % (bounds, lower_bound(
|
||||
RoomStreamToken.parse(to_key), self.database_engine
|
||||
))
|
||||
else:
|
||||
order = "ASC"
|
||||
bounds = lower_bound(
|
||||
RoomStreamToken.parse(from_key), self.database_engine
|
||||
)
|
||||
if to_key:
|
||||
bounds = "%s AND %s" % (bounds, upper_bound(
|
||||
RoomStreamToken.parse(to_key), self.database_engine
|
||||
))
|
||||
|
||||
if int(limit) > 0:
|
||||
args.append(int(limit))
|
||||
limit_str = " LIMIT ?"
|
||||
else:
|
||||
limit_str = ""
|
||||
|
||||
sql = (
|
||||
"SELECT * FROM event_files"
|
||||
" WHERE room_id = ? AND %(bounds)s"
|
||||
" ORDER BY topological_ordering %(order)s,"
|
||||
" stream_ordering %(order)s %(limit)s"
|
||||
) % {
|
||||
"bounds": bounds,
|
||||
"order": order,
|
||||
"limit": limit_str
|
||||
}
|
||||
|
||||
def f(txn):
|
||||
txn.execute(sql, args)
|
||||
|
||||
rows = self.cursor_to_dict(txn)
|
||||
|
||||
if rows:
|
||||
topo = rows[-1]["topological_ordering"]
|
||||
toke = rows[-1]["stream_ordering"]
|
||||
if direction == 'b':
|
||||
# Tokens are positions between events.
|
||||
# This token points *after* the last event in the chunk.
|
||||
# We need it to point to the event before it in the chunk
|
||||
# when we are going backwards so we subtract one from the
|
||||
# stream part.
|
||||
toke -= 1
|
||||
next_token = str(RoomStreamToken(topo, toke))
|
||||
else:
|
||||
# TODO (erikj): We should work out what to do here instead.
|
||||
next_token = to_key if to_key else from_key
|
||||
|
||||
return rows, next_token,
|
||||
|
||||
rows, token = yield self.runInteraction("paginate_file_events", f)
|
||||
|
||||
events = yield self._get_events(
|
||||
[r["event_id"] for r in rows],
|
||||
get_prev_content=True
|
||||
)
|
||||
|
||||
self._set_before_and_after(events, rows)
|
||||
|
||||
defer.returnValue((events, token))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None):
|
||||
rows, token = yield self.get_recent_event_ids_for_room(
|
||||
@@ -487,13 +586,13 @@ class StreamStore(SQLBaseStore):
|
||||
row["topological_ordering"], row["stream_ordering"],)
|
||||
)
|
||||
|
||||
def get_max_topological_token_for_stream_and_room(self, room_id, stream_key):
|
||||
def get_max_topological_token(self, room_id, stream_key):
|
||||
sql = (
|
||||
"SELECT max(topological_ordering) FROM events"
|
||||
" WHERE room_id = ? AND stream_ordering < ?"
|
||||
)
|
||||
return self._execute(
|
||||
"get_max_topological_token_for_stream_and_room", None,
|
||||
"get_max_topological_token", None,
|
||||
sql, room_id, stream_key,
|
||||
).addCallback(
|
||||
lambda r: r[0][0] if r else 0
|
||||
@@ -586,32 +685,60 @@ class StreamStore(SQLBaseStore):
|
||||
retcols=["stream_ordering", "topological_ordering"],
|
||||
)
|
||||
|
||||
stream_ordering = results["stream_ordering"]
|
||||
topological_ordering = results["topological_ordering"]
|
||||
|
||||
query_before = (
|
||||
"SELECT topological_ordering, stream_ordering, event_id FROM events"
|
||||
" WHERE room_id = ? AND (topological_ordering < ?"
|
||||
" OR (topological_ordering = ? AND stream_ordering < ?))"
|
||||
" ORDER BY topological_ordering DESC, stream_ordering DESC"
|
||||
" LIMIT ?"
|
||||
token = RoomStreamToken(
|
||||
results["topological_ordering"],
|
||||
results["stream_ordering"],
|
||||
)
|
||||
|
||||
query_after = (
|
||||
"SELECT topological_ordering, stream_ordering, event_id FROM events"
|
||||
" WHERE room_id = ? AND (topological_ordering > ?"
|
||||
" OR (topological_ordering = ? AND stream_ordering > ?))"
|
||||
" ORDER BY topological_ordering ASC, stream_ordering ASC"
|
||||
" LIMIT ?"
|
||||
)
|
||||
if isinstance(self.database_engine, Sqlite3Engine):
|
||||
# SQLite3 doesn't optimise ``(x < a) OR (x = a AND y < b)``
|
||||
# So we give pass it to SQLite3 as the UNION ALL of the two queries.
|
||||
|
||||
txn.execute(
|
||||
query_before,
|
||||
(
|
||||
room_id, topological_ordering, topological_ordering,
|
||||
stream_ordering, before_limit,
|
||||
query_before = (
|
||||
"SELECT topological_ordering, stream_ordering, event_id FROM events"
|
||||
" WHERE room_id = ? AND topological_ordering < ?"
|
||||
" UNION ALL"
|
||||
" SELECT topological_ordering, stream_ordering, event_id FROM events"
|
||||
" WHERE room_id = ? AND topological_ordering = ? AND stream_ordering < ?"
|
||||
" ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?"
|
||||
)
|
||||
)
|
||||
before_args = (
|
||||
room_id, token.topological,
|
||||
room_id, token.topological, token.stream,
|
||||
before_limit,
|
||||
)
|
||||
|
||||
query_after = (
|
||||
"SELECT topological_ordering, stream_ordering, event_id FROM events"
|
||||
" WHERE room_id = ? AND topological_ordering > ?"
|
||||
" UNION ALL"
|
||||
" SELECT topological_ordering, stream_ordering, event_id FROM events"
|
||||
" WHERE room_id = ? AND topological_ordering = ? AND stream_ordering > ?"
|
||||
" ORDER BY topological_ordering ASC, stream_ordering ASC LIMIT ?"
|
||||
)
|
||||
after_args = (
|
||||
room_id, token.topological,
|
||||
room_id, token.topological, token.stream,
|
||||
after_limit,
|
||||
)
|
||||
else:
|
||||
query_before = (
|
||||
"SELECT topological_ordering, stream_ordering, event_id FROM events"
|
||||
" WHERE room_id = ? AND %s"
|
||||
" ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?"
|
||||
) % (upper_bound(token, self.database_engine, inclusive=False),)
|
||||
|
||||
before_args = (room_id, before_limit)
|
||||
|
||||
query_after = (
|
||||
"SELECT topological_ordering, stream_ordering, event_id FROM events"
|
||||
" WHERE room_id = ? AND %s"
|
||||
" ORDER BY topological_ordering ASC, stream_ordering ASC LIMIT ?"
|
||||
) % (lower_bound(token, self.database_engine, inclusive=False),)
|
||||
|
||||
after_args = (room_id, after_limit)
|
||||
|
||||
txn.execute(query_before, before_args)
|
||||
|
||||
rows = self.cursor_to_dict(txn)
|
||||
events_before = [r["event_id"] for r in rows]
|
||||
@@ -623,17 +750,11 @@ class StreamStore(SQLBaseStore):
|
||||
))
|
||||
else:
|
||||
start_token = str(RoomStreamToken(
|
||||
topological_ordering,
|
||||
stream_ordering - 1,
|
||||
token.topological,
|
||||
token.stream - 1,
|
||||
))
|
||||
|
||||
txn.execute(
|
||||
query_after,
|
||||
(
|
||||
room_id, topological_ordering, topological_ordering,
|
||||
stream_ordering, after_limit,
|
||||
)
|
||||
)
|
||||
txn.execute(query_after, after_args)
|
||||
|
||||
rows = self.cursor_to_dict(txn)
|
||||
events_after = [r["event_id"] for r in rows]
|
||||
@@ -644,10 +765,7 @@ class StreamStore(SQLBaseStore):
|
||||
rows[-1]["stream_ordering"],
|
||||
))
|
||||
else:
|
||||
end_token = str(RoomStreamToken(
|
||||
topological_ordering,
|
||||
stream_ordering,
|
||||
))
|
||||
end_token = str(token)
|
||||
|
||||
return {
|
||||
"before": {
|
||||
|
||||
@@ -68,6 +68,9 @@ class TagsStore(SQLBaseStore):
|
||||
A deferred list of tuples of stream_id int, user_id string,
|
||||
room_id string, tag string and content string.
|
||||
"""
|
||||
if last_id == current_id:
|
||||
defer.returnValue([])
|
||||
|
||||
def get_all_updated_tags_txn(txn):
|
||||
sql = (
|
||||
"SELECT stream_id, user_id, room_id"
|
||||
|
||||
@@ -56,7 +56,7 @@ class PaginationConfig(object):
|
||||
|
||||
@classmethod
|
||||
def from_request(cls, request, raise_invalid_params=True,
|
||||
default_limit=None):
|
||||
default_limit=None, default_dir='f'):
|
||||
def get_param(name, default=None):
|
||||
lst = request.args.get(name, [])
|
||||
if len(lst) > 1:
|
||||
@@ -68,7 +68,7 @@ class PaginationConfig(object):
|
||||
else:
|
||||
return default
|
||||
|
||||
direction = get_param("dir", 'f')
|
||||
direction = get_param("dir", default_dir)
|
||||
if direction not in ['f', 'b']:
|
||||
raise SynapseError(400, "'dir' parameter is invalid.")
|
||||
|
||||
|
||||
@@ -22,7 +22,10 @@ Requester = namedtuple("Requester", ["user", "access_token_id", "is_guest"])
|
||||
|
||||
|
||||
def get_domain_from_id(string):
|
||||
return string.split(":", 1)[1]
|
||||
try:
|
||||
return string.split(":", 1)[1]
|
||||
except IndexError:
|
||||
raise SynapseError(400, "Invalid ID: %r", string)
|
||||
|
||||
|
||||
class DomainSpecificString(
|
||||
|
||||
@@ -194,3 +194,85 @@ class Linearizer(object):
|
||||
self.key_to_defer.pop(key, None)
|
||||
|
||||
defer.returnValue(_ctx_manager())
|
||||
|
||||
|
||||
class ReadWriteLock(object):
|
||||
"""A deferred style read write lock.
|
||||
|
||||
Example:
|
||||
|
||||
with (yield read_write_lock.read("test_key")):
|
||||
# do some work
|
||||
"""
|
||||
|
||||
# IMPLEMENTATION NOTES
|
||||
#
|
||||
# We track the most recent queued reader and writer deferreds (which get
|
||||
# resolved when they release the lock).
|
||||
#
|
||||
# Read: We know its safe to acquire a read lock when the latest writer has
|
||||
# been resolved. The new reader is appeneded to the list of latest readers.
|
||||
#
|
||||
# Write: We know its safe to acquire the write lock when both the latest
|
||||
# writers and readers have been resolved. The new writer replaces the latest
|
||||
# writer.
|
||||
|
||||
def __init__(self):
|
||||
# Latest readers queued
|
||||
self.key_to_current_readers = {}
|
||||
|
||||
# Latest writer queued
|
||||
self.key_to_current_writer = {}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def read(self, key):
|
||||
new_defer = defer.Deferred()
|
||||
|
||||
curr_readers = self.key_to_current_readers.setdefault(key, set())
|
||||
curr_writer = self.key_to_current_writer.get(key, None)
|
||||
|
||||
curr_readers.add(new_defer)
|
||||
|
||||
# We wait for the latest writer to finish writing. We can safely ignore
|
||||
# any existing readers... as they're readers.
|
||||
yield curr_writer
|
||||
|
||||
@contextmanager
|
||||
def _ctx_manager():
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
new_defer.callback(None)
|
||||
self.key_to_current_readers.get(key, set()).discard(new_defer)
|
||||
|
||||
defer.returnValue(_ctx_manager())
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def write(self, key):
|
||||
new_defer = defer.Deferred()
|
||||
|
||||
curr_readers = self.key_to_current_readers.get(key, set())
|
||||
curr_writer = self.key_to_current_writer.get(key, None)
|
||||
|
||||
# We wait on all latest readers and writer.
|
||||
to_wait_on = list(curr_readers)
|
||||
if curr_writer:
|
||||
to_wait_on.append(curr_writer)
|
||||
|
||||
# We can clear the list of current readers since the new writer waits
|
||||
# for them to finish.
|
||||
curr_readers.clear()
|
||||
self.key_to_current_writer[key] = new_defer
|
||||
|
||||
yield defer.gatherResults(to_wait_on)
|
||||
|
||||
@contextmanager
|
||||
def _ctx_manager():
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
new_defer.callback(None)
|
||||
if self.key_to_current_writer[key] == new_defer:
|
||||
self.key_to_current_writer.pop(key)
|
||||
|
||||
defer.returnValue(_ctx_manager())
|
||||
|
||||
@@ -27,10 +27,6 @@ import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def registered_user(distributor, user):
|
||||
return distributor.fire("registered_user", user)
|
||||
|
||||
|
||||
def user_left_room(distributor, user, room_id):
|
||||
return preserve_context_over_fn(
|
||||
distributor.fire,
|
||||
|
||||
@@ -25,7 +25,8 @@ ALIAS_RE = re.compile(r"^#.*:.+$")
|
||||
ALL_ALONE = "Empty Room"
|
||||
|
||||
|
||||
def calculate_room_name(room_state, user_id, fallback_to_members=True):
|
||||
def calculate_room_name(room_state, user_id, fallback_to_members=True,
|
||||
fallback_to_single_member=True):
|
||||
"""
|
||||
Works out a user-facing name for the given room as per Matrix
|
||||
spec recommendations.
|
||||
@@ -129,6 +130,8 @@ def calculate_room_name(room_state, user_id, fallback_to_members=True):
|
||||
return name_from_member_event(all_members[0])
|
||||
else:
|
||||
return ALL_ALONE
|
||||
elif len(other_members) == 1 and not fallback_to_single_member:
|
||||
return None
|
||||
else:
|
||||
return descriptor_from_member_events(other_members)
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ class ConfigGenerationTestCase(unittest.TestCase):
|
||||
shutil.rmtree(self.dir)
|
||||
|
||||
def test_generate_config_generates_files(self):
|
||||
HomeServerConfig.load_config("", [
|
||||
HomeServerConfig.load_or_generate_config("", [
|
||||
"--generate-config",
|
||||
"-c", self.file,
|
||||
"--report-stats=yes",
|
||||
|
||||
@@ -34,6 +34,8 @@ class ConfigLoadingTestCase(unittest.TestCase):
|
||||
self.generate_config_and_remove_lines_containing("server_name")
|
||||
with self.assertRaises(Exception):
|
||||
HomeServerConfig.load_config("", ["-c", self.file])
|
||||
with self.assertRaises(Exception):
|
||||
HomeServerConfig.load_or_generate_config("", ["-c", self.file])
|
||||
|
||||
def test_generates_and_loads_macaroon_secret_key(self):
|
||||
self.generate_config()
|
||||
@@ -54,11 +56,24 @@ class ConfigLoadingTestCase(unittest.TestCase):
|
||||
"was: %r" % (config.macaroon_secret_key,)
|
||||
)
|
||||
|
||||
config = HomeServerConfig.load_or_generate_config("", ["-c", self.file])
|
||||
self.assertTrue(
|
||||
hasattr(config, "macaroon_secret_key"),
|
||||
"Want config to have attr macaroon_secret_key"
|
||||
)
|
||||
if len(config.macaroon_secret_key) < 5:
|
||||
self.fail(
|
||||
"Want macaroon secret key to be string of at least length 5,"
|
||||
"was: %r" % (config.macaroon_secret_key,)
|
||||
)
|
||||
|
||||
def test_load_succeeds_if_macaroon_secret_key_missing(self):
|
||||
self.generate_config_and_remove_lines_containing("macaroon")
|
||||
config1 = HomeServerConfig.load_config("", ["-c", self.file])
|
||||
config2 = HomeServerConfig.load_config("", ["-c", self.file])
|
||||
config3 = HomeServerConfig.load_or_generate_config("", ["-c", self.file])
|
||||
self.assertEqual(config1.macaroon_secret_key, config2.macaroon_secret_key)
|
||||
self.assertEqual(config1.macaroon_secret_key, config3.macaroon_secret_key)
|
||||
|
||||
def test_disable_registration(self):
|
||||
self.generate_config()
|
||||
@@ -70,14 +85,17 @@ class ConfigLoadingTestCase(unittest.TestCase):
|
||||
config = HomeServerConfig.load_config("", ["-c", self.file])
|
||||
self.assertFalse(config.enable_registration)
|
||||
|
||||
config = HomeServerConfig.load_or_generate_config("", ["-c", self.file])
|
||||
self.assertFalse(config.enable_registration)
|
||||
|
||||
# Check that either config value is clobbered by the command line.
|
||||
config = HomeServerConfig.load_config("", [
|
||||
config = HomeServerConfig.load_or_generate_config("", [
|
||||
"-c", self.file, "--enable-registration"
|
||||
])
|
||||
self.assertTrue(config.enable_registration)
|
||||
|
||||
def generate_config(self):
|
||||
HomeServerConfig.load_config("", [
|
||||
HomeServerConfig.load_or_generate_config("", [
|
||||
"--generate-config",
|
||||
"-c", self.file,
|
||||
"--report-stats=yes",
|
||||
|
||||
@@ -41,14 +41,15 @@ class RegistrationTestCase(unittest.TestCase):
|
||||
handlers=None,
|
||||
http_client=None,
|
||||
expire_access_token=True)
|
||||
self.auth_handler = Mock(
|
||||
generate_short_term_login_token=Mock(return_value='secret'))
|
||||
self.hs.handlers = RegistrationHandlers(self.hs)
|
||||
self.handler = self.hs.get_handlers().registration_handler
|
||||
self.hs.get_handlers().profile_handler = Mock()
|
||||
self.mock_handler = Mock(spec=[
|
||||
"generate_short_term_login_token",
|
||||
])
|
||||
|
||||
self.hs.get_handlers().auth_handler = self.mock_handler
|
||||
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_user_is_created_and_logged_in_if_doesnt_exist(self):
|
||||
@@ -56,8 +57,6 @@ class RegistrationTestCase(unittest.TestCase):
|
||||
local_part = "someone"
|
||||
display_name = "someone"
|
||||
user_id = "@someone:test"
|
||||
mock_token = self.mock_handler.generate_short_term_login_token
|
||||
mock_token.return_value = 'secret'
|
||||
result_user_id, result_token = yield self.handler.get_or_create_user(
|
||||
local_part, display_name, duration_ms)
|
||||
self.assertEquals(result_user_id, user_id)
|
||||
@@ -75,8 +74,6 @@ class RegistrationTestCase(unittest.TestCase):
|
||||
local_part = "frank"
|
||||
display_name = "Frank"
|
||||
user_id = "@frank:test"
|
||||
mock_token = self.mock_handler.generate_short_term_login_token
|
||||
mock_token.return_value = 'secret'
|
||||
result_user_id, result_token = yield self.handler.get_or_create_user(
|
||||
local_part, display_name, duration_ms)
|
||||
self.assertEquals(result_user_id, user_id)
|
||||
|
||||
@@ -58,47 +58,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
||||
def tearDown(self):
|
||||
[unpatch() for unpatch in self.unpatches]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_room_name_and_aliases(self):
|
||||
create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
|
||||
yield self.persist(type="m.room.member", key=USER_ID, membership="join")
|
||||
yield self.persist(type="m.room.name", key="", name="name1")
|
||||
yield self.persist(
|
||||
type="m.room.aliases", key="blue", aliases=["#1:blue"]
|
||||
)
|
||||
yield self.replicate()
|
||||
yield self.check(
|
||||
"get_room_name_and_aliases", (ROOM_ID,), ("name1", ["#1:blue"])
|
||||
)
|
||||
|
||||
# Set the room name.
|
||||
yield self.persist(type="m.room.name", key="", name="name2")
|
||||
yield self.replicate()
|
||||
yield self.check(
|
||||
"get_room_name_and_aliases", (ROOM_ID,), ("name2", ["#1:blue"])
|
||||
)
|
||||
|
||||
# Set the room aliases.
|
||||
yield self.persist(
|
||||
type="m.room.aliases", key="blue", aliases=["#2:blue"]
|
||||
)
|
||||
yield self.replicate()
|
||||
yield self.check(
|
||||
"get_room_name_and_aliases", (ROOM_ID,), ("name2", ["#2:blue"])
|
||||
)
|
||||
|
||||
# Leave and join the room clobbering the state.
|
||||
yield self.persist(type="m.room.member", key=USER_ID, membership="leave")
|
||||
yield self.persist(
|
||||
type="m.room.member", key=USER_ID, membership="join",
|
||||
reset_state=[create]
|
||||
)
|
||||
yield self.replicate()
|
||||
|
||||
yield self.check(
|
||||
"get_room_name_and_aliases", (ROOM_ID,), (None, [])
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_room_members(self):
|
||||
create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
|
||||
|
||||
85
tests/util/test_rwlock.py
Normal file
85
tests/util/test_rwlock.py
Normal file
@@ -0,0 +1,85 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from tests import unittest
|
||||
|
||||
from synapse.util.async import ReadWriteLock
|
||||
|
||||
|
||||
class ReadWriteLockTestCase(unittest.TestCase):
|
||||
|
||||
def _assert_called_before_not_after(self, lst, first_false):
|
||||
for i, d in enumerate(lst[:first_false]):
|
||||
self.assertTrue(d.called, msg="%d was unexpectedly false" % i)
|
||||
|
||||
for i, d in enumerate(lst[first_false:]):
|
||||
self.assertFalse(
|
||||
d.called, msg="%d was unexpectedly true" % (i + first_false)
|
||||
)
|
||||
|
||||
def test_rwlock(self):
|
||||
rwlock = ReadWriteLock()
|
||||
|
||||
key = object()
|
||||
|
||||
ds = [
|
||||
rwlock.read(key), # 0
|
||||
rwlock.read(key), # 1
|
||||
rwlock.write(key), # 2
|
||||
rwlock.write(key), # 3
|
||||
rwlock.read(key), # 4
|
||||
rwlock.read(key), # 5
|
||||
rwlock.write(key), # 6
|
||||
]
|
||||
|
||||
self._assert_called_before_not_after(ds, 2)
|
||||
|
||||
with ds[0].result:
|
||||
self._assert_called_before_not_after(ds, 2)
|
||||
self._assert_called_before_not_after(ds, 2)
|
||||
|
||||
with ds[1].result:
|
||||
self._assert_called_before_not_after(ds, 2)
|
||||
self._assert_called_before_not_after(ds, 3)
|
||||
|
||||
with ds[2].result:
|
||||
self._assert_called_before_not_after(ds, 3)
|
||||
self._assert_called_before_not_after(ds, 4)
|
||||
|
||||
with ds[3].result:
|
||||
self._assert_called_before_not_after(ds, 4)
|
||||
self._assert_called_before_not_after(ds, 6)
|
||||
|
||||
with ds[5].result:
|
||||
self._assert_called_before_not_after(ds, 6)
|
||||
self._assert_called_before_not_after(ds, 6)
|
||||
|
||||
with ds[4].result:
|
||||
self._assert_called_before_not_after(ds, 6)
|
||||
self._assert_called_before_not_after(ds, 7)
|
||||
|
||||
with ds[6].result:
|
||||
pass
|
||||
|
||||
d = rwlock.write(key)
|
||||
self.assertTrue(d.called)
|
||||
with d.result:
|
||||
pass
|
||||
|
||||
d = rwlock.read(key)
|
||||
self.assertTrue(d.called)
|
||||
with d.result:
|
||||
pass
|
||||
@@ -54,7 +54,9 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
|
||||
config.trusted_third_party_id_servers = []
|
||||
config.room_invite_state_types = []
|
||||
|
||||
config.use_frozen_dicts = True
|
||||
config.database_config = {"name": "sqlite3"}
|
||||
config.ldap_enabled = False
|
||||
|
||||
if "clock" not in kargs:
|
||||
kargs["clock"] = MockClock()
|
||||
|
||||
4
tox.ini
4
tox.ini
@@ -11,7 +11,7 @@ deps =
|
||||
setenv =
|
||||
PYTHONDONTWRITEBYTECODE = no_byte_code
|
||||
commands =
|
||||
/bin/bash -c "find {toxinidir} -name '*.pyc' -delete ; coverage run {env:COVERAGE_OPTS:} --source={toxinidir}/synapse \
|
||||
/bin/sh -c "find {toxinidir} -name '*.pyc' -delete ; coverage run {env:COVERAGE_OPTS:} --source={toxinidir}/synapse \
|
||||
{envbindir}/trial {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:}"
|
||||
{env:DUMP_COVERAGE_COMMAND:coverage report -m}
|
||||
|
||||
@@ -26,4 +26,4 @@ skip_install = True
|
||||
basepython = python2.7
|
||||
deps =
|
||||
flake8
|
||||
commands = /bin/bash -c "flake8 synapse tests {env:PEP8SUFFIX:}"
|
||||
commands = /bin/sh -c "flake8 synapse tests {env:PEP8SUFFIX:}"
|
||||
|
||||
Reference in New Issue
Block a user