Compare commits


2 Commits

Author         SHA1        Message                          Date
Erik Johnston  9854f4c7ff  Add basic file API               2016-07-13 16:29:35 +01:00
Erik Johnston  518b3a3f89  Track in DB file message events  2016-07-13 15:16:02 +01:00
172 changed files with 2687 additions and 7645 deletions

View File

@@ -1,158 +1,3 @@
Changes in synapse v0.17.1-rc1 (2016-08-22)
===========================================
Features:
* Add notification API (PR #1028)
Changes:
* Don't print stack traces when failing to get remote keys (PR #996)
* Various federation /event/ perf improvements (PR #998)
* Only process one local membership event per room at a time (PR #1005)
* Move default display name push rule (PR #1011, #1023)
* Fix up preview URL API. Add tests. (PR #1015)
* Set ``Content-Security-Policy`` on media repo (PR #1021)
* Make notify_interested_services faster (PR #1022)
* Add usage stats to prometheus monitoring (PR #1037)
Bug fixes:
* Fix token login (PR #993)
* Fix CAS login (PR #994, #995)
* Fix /sync to not clobber status_msg (PR #997)
* Fix redacted state events to include prev_content (PR #1003)
* Fix some bugs in the auth/ldap handler (PR #1007)
* Fix backfill request to limit URI length, so that remotes don't reject the
requests due to path length limits (PR #1012)
* Fix AS push code to not send duplicate events (PR #1025)
Changes in synapse v0.17.0 (2016-08-08)
=======================================
This release contains significant security bug fixes regarding authenticating
events received over federation. PLEASE UPGRADE.
This release changes the LDAP configuration format in a backwards incompatible
way, see PR #843 for details.
Changes:
* Add federation /version API (PR #990)
* Make psutil dependency optional (PR #992)
Bug fixes:
* Fix URL preview API to exclude HTML comments in description (PR #988)
* Fix error handling of remote joins (PR #991)
Changes in synapse v0.17.0-rc4 (2016-08-05)
===========================================
Changes:
* Change the way we summarize URLs when previewing (PR #973)
* Add new ``/state_ids/`` federation API (PR #979)
* Speed up processing of ``/state/`` response (PR #986)
Bug fixes:
* Fix event persistence when event has already been partially persisted
(PR #975, #983, #985)
* Fix port script to also copy across backfilled events (PR #982)
Changes in synapse v0.17.0-rc3 (2016-08-02)
===========================================
Changes:
* Forbid non-ASes from registering users whose names begin with '_' (PR #958)
* Add some basic admin API docs (PR #963)
Bug fixes:
* Send the correct host header when fetching keys (PR #941)
* Fix joining a room that has missing auth events (PR #964)
* Fix various push bugs (PR #966, #970)
* Fix adding emails on registration (PR #968)
Changes in synapse v0.17.0-rc2 (2016-08-02)
===========================================
(This release did not include the changes advertised and was identical to RC1)
Changes in synapse v0.17.0-rc1 (2016-07-28)
===========================================
This release changes the LDAP configuration format in a backwards incompatible
way, see PR #843 for details.
Features:
* Add purge_media_cache admin API (PR #902)
* Add deactivate account admin API (PR #903)
* Add optional pepper to password hashing (PR #907, #910 by KentShikama)
* Add an admin option to shared secret registration (breaks backwards compat)
(PR #909)
* Add purge local room history API (PR #911, #923, #924)
* Add requestToken endpoints (PR #915)
* Add an /account/deactivate endpoint (PR #921)
* Add filter param to /messages. Add 'contains_url' to filter. (PR #922)
* Add device_id support to /login (PR #929)
* Add device_id support to /v2/register flow. (PR #937, #942)
* Add GET /devices endpoint (PR #939, #944)
* Add GET /device/{deviceId} (PR #943)
* Add update and delete APIs for devices (PR #949)
Changes:
* Rewrite LDAP Authentication against ldap3 (PR #843 by mweinelt)
* Linearize some federation endpoints based on (origin, room_id) (PR #879)
* Remove the legacy v0 content upload API. (PR #888)
* Use similar naming we use in email notifs for push (PR #894)
* Optionally include password hash in createUser endpoint (PR #905 by
KentShikama)
* Use a query that postgresql optimises better for get_events_around (PR #906)
* Fall back to 'username' if 'user' is not given for appservice registration.
(PR #927 by Half-Shot)
* Add metrics for psutil derived memory usage (PR #936)
* Record device_id in client_ips (PR #938)
* Send the correct host header when fetching keys (PR #941)
* Log the hostname the reCAPTCHA was completed on (PR #946)
* Make the device id on e2e key upload optional (PR #956)
* Add r0.2.0 to the "supported versions" list (PR #960)
* Don't include name of room for invites in push (PR #961)
Bug fixes:
* Fix substitution failure in mail template (PR #887)
* Put most recent 20 messages in email notif (PR #892)
* Ensure that the guest user is in the database when upgrading accounts
(PR #914)
* Fix various edge cases in auth handling (PR #919)
* Fix 500 ISE when sending alias event without a state_key (PR #925)
* Fix bug where we stored rejections in the state_group, persist all
rejections (PR #948)
* Fix lack of check of if the user is banned when handling 3pid invites
(PR #952)
* Fix a couple of bugs in the transaction and keyring code (PR #954, #955)
Changes in synapse v0.16.1-r1 (2016-07-08)
==========================================

View File

@@ -14,7 +14,6 @@ recursive-include docs *
 recursive-include res *
 recursive-include scripts *
 recursive-include scripts-dev *
-recursive-include synapse *.pyi
 recursive-include tests *.py
 recursive-include synapse/static *.css
@@ -24,7 +23,5 @@ recursive-include synapse/static *.js
 exclude jenkins.sh
 exclude jenkins*.sh
-exclude jenkins*
-recursive-exclude jenkins *.sh
 prune demo/etc

View File

@@ -11,8 +11,8 @@ VoIP. The basics you need to know to get up and running are:
   like ``#matrix:matrix.org`` or ``#test:localhost:8448``.
 - Matrix user IDs look like ``@matthew:matrix.org`` (although in the future
-  you will normally refer to yourself and others using a third party identifier
-  (3PID): email address, phone number, etc rather than manipulating Matrix user IDs)
+  you will normally refer to yourself and others using a 3PID: email
+  address, phone number, etc rather than manipulating Matrix user IDs)
 The overall architecture is::
@@ -95,7 +95,7 @@ Synapse is the reference python/twisted Matrix homeserver implementation.
 System requirements:
 - POSIX-compliant system (tested on Linux & OS X)
 - Python 2.7
-- At least 1GB of free RAM if you want to join large public rooms like #matrix:matrix.org
+- At least 512 MB RAM.
 Synapse is written in python but some of the libraries it uses are written in
 C. So before we can install synapse itself we need a working C compiler and the
@@ -445,7 +445,7 @@ You have two choices here, which will influence the form of your Matrix user
 IDs:
 1) Use the machine's own hostname as available on public DNS in the form of
-   its A records. This is easier to set up initially, perhaps for
+   its A or AAAA records. This is easier to set up initially, perhaps for
    testing, but lacks the flexibility of SRV.
 2) Set up a SRV record for your domain name. This requires you create a SRV

View File

@@ -27,7 +27,7 @@ running:
     # Pull the latest version of the master branch.
     git pull
     # Update the versions of synapse's python dependencies.
-    python synapse/python_dependencies.py | xargs -n1 pip install --upgrade
+    python synapse/python_dependencies.py | xargs -n1 pip install
 Upgrading to v0.15.0

View File

@@ -1,12 +0,0 @@
Admin APIs
==========
This directory includes documentation for the various synapse specific admin
APIs available.
Only users that are server admins can use these APIs. A user can be marked as a
server admin by updating the database directly, e.g.:
``UPDATE users SET admin = 1 WHERE name = '@foo:bar.com'``
Restarting may be required for the changes to register.
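For illustration, a minimal sketch of applying that update from Python on a
sqlite-backed install (the ``homeserver.db`` path is an assumption; postgres
deployments would use their own client instead)::

    import sqlite3

    # "homeserver.db" is the customary sqlite database filename; adjust to
    # wherever your deployment actually keeps it.
    conn = sqlite3.connect("homeserver.db")
    conn.execute("UPDATE users SET admin = 1 WHERE name = '@foo:bar.com'")
    conn.commit()
    conn.close()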

View File

@@ -1,15 +0,0 @@
Purge History API
=================
The purge history API allows server admins to purge historic events from their
database, reclaiming disk space.
Depending on the amount of history being purged a call to the API may take
several minutes or longer. During this period users will not be able to
paginate further back in the room from the point being purged from.
The API is simply:
``POST /_matrix/client/r0/admin/purge_history/<room_id>/<event_id>``
including an ``access_token`` of a server admin.
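For illustration, a minimal sketch of driving this endpoint from Python with
the ``requests`` library (the homeserver URL, room ID, event ID and token
below are hypothetical placeholders)::

    import requests

    HOMESERVER = "https://localhost:8448"   # hypothetical homeserver URL
    ROOM_ID = "!someroom:localhost"         # hypothetical room ID
    EVENT_ID = "$someevent:localhost"       # purge events up to this event
    ADMIN_TOKEN = "secret_admin_token"      # access token of a server admin

    resp = requests.post(
        "%s/_matrix/client/r0/admin/purge_history/%s/%s"
        % (HOMESERVER, ROOM_ID, EVENT_ID),
        params={"access_token": ADMIN_TOKEN},
        verify=False,  # self-signed TLS certs, as in scripts-dev/federation_client.py
    )
    print resp.status_code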

View File

@@ -1,19 +0,0 @@
Purge Remote Media API
======================
The purge remote media API allows server admins to purge old cached remote
media.
The API is::

    POST /_matrix/client/r0/admin/purge_media_cache

    {
        "before_ts": <unix_timestamp_in_ms>
    }
Which will remove all cached media that was last accessed before
``<unix_timestamp_in_ms>``.
If the user re-requests purged remote media, synapse will re-request the media
from the originating server.
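A similar hedged sketch for this endpoint (values are again placeholders;
``before_ts`` is in milliseconds, as above)::

    import json
    import time

    import requests

    HOMESERVER = "https://localhost:8448"   # hypothetical homeserver URL
    ADMIN_TOKEN = "secret_admin_token"      # access token of a server admin

    # e.g. purge remote media not accessed in the last 30 days
    before_ts = int(time.time() * 1000) - 30 * 24 * 60 * 60 * 1000

    resp = requests.post(
        "%s/_matrix/client/r0/admin/purge_media_cache" % (HOMESERVER,),
        params={"access_token": ADMIN_TOKEN},
        data=json.dumps({"before_ts": before_ts}),
        verify=False,  # self-signed TLS certs
    )
    print resp.json()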

View File

@@ -1,97 +0,0 @@
Scaling synapse via workers
---------------------------
Synapse has experimental support for splitting out functionality into
multiple separate python processes, helping greatly with scalability. These
processes are called 'workers', and are (eventually) intended to scale
horizontally independently.
All processes continue to share the same database instance, and as such, workers
only work with postgres based synapse deployments (sharing a single sqlite
across multiple processes is a recipe for disaster, plus you should be using
postgres anyway if you care about scalability).
The workers communicate with the master synapse process via a synapse-specific
HTTP protocol called 'replication', analogous to MySQL- or Postgres-style
database replication: a stream of relevant data is fed to the workers so they
can be kept in sync with the main synapse process and database state.
To enable workers, you need to add a replication listener to the master synapse, e.g.::

    listeners:
     - port: 9092
       bind_address: '127.0.0.1'
       type: http
       tls: false
       x_forwarded: false
       resources:
        - names: [replication]
          compress: false
Under **no circumstances** should this replication API listener be exposed to the
public internet; it currently implements no authentication whatsoever and is
unencrypted HTTP.
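As a purely illustrative sketch of the replication flow described above, a
client on the same host could long-poll the example listener for new stream
rows, much as the worker processes do internally (the stream names, query
parameters and response shape are synapse internals and may differ)::

    import requests

    # the example replication listener configured above
    REPLICATION_URL = "http://127.0.0.1:9092/_synapse/replication"

    def poll_once(stream_positions):
        # Ask the master for rows newer than our current stream positions;
        # the timeout (in milliseconds) turns this into a long-poll.
        args = dict(stream_positions)
        args["timeout"] = 30000
        return requests.get(REPLICATION_URL, params=args).json()

    # e.g. start from position 0 of the "events" stream (name assumed)
    result = poll_once({"events": 0})
    print result.get("events")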
You then create a set of configs for the various worker processes. These
worker configuration files should be stored in a dedicated subdirectory, to allow
synctl to manipulate them.
The currently available worker applications are:
* synapse.app.pusher - handles sending push notifications to sygnal and email
* synapse.app.synchrotron - handles /sync endpoints. Can scale horizontally across multiple instances.
* synapse.app.appservice - handles output traffic to Application Services
* synapse.app.federation_reader - handles receiving federation traffic (including public_rooms API)
* synapse.app.media_repository - handles the media repository.
Each worker configuration file inherits the configuration of the main homeserver
configuration file. You can then override configuration specific to that worker,
e.g. the HTTP listener that it provides (if any); logging configuration; etc.
You should minimise the number of overrides though to maintain a usable config.
You must specify the type of worker application (worker_app) and the replication
endpoint that it's talking to on the main synapse process (worker_replication_url).
For instance::

    worker_app: synapse.app.synchrotron

    # The replication listener on the synapse to talk to.
    worker_replication_url: http://127.0.0.1:9092/_synapse/replication

    worker_listeners:
     - type: http
       port: 8083
       resources:
         - names:
            - client

    worker_daemonize: True
    worker_pid_file: /home/matrix/synapse/synchrotron.pid
    worker_log_config: /home/matrix/synapse/config/synchrotron_log_config.yaml
...is a full configuration for a synchrotron worker instance, which will expose a
plain HTTP /sync endpoint on port 8083 separately from the /sync endpoint provided
by the main synapse.
Obviously you should configure your load balancer to route the /sync endpoint to
the synchrotron instance(s) in this case.
Finally, to actually run your worker-based synapse, you must pass synctl the -a
commandline option to tell it to operate on all the worker configurations found
in the given directory, e.g.::
    synctl -a $CONFIG/workers start
Currently one should always restart all workers when restarting or upgrading
synapse, unless you explicitly know it's safe not to. For instance, restarting
synapse without restarting all the synchrotrons may result in broken typing
notifications.
To manipulate a specific worker, you pass the -w option to synctl::
    synctl -w $CONFIG/workers/synchrotron.yaml restart
All of the above is highly experimental and subject to change as Synapse evolves,
but is documented here to help folks needing highly scalable Synapses similar
to the one running matrix.org!

View File

@@ -4,19 +4,84 @@ set -eux
 : ${WORKSPACE:="$(pwd)"}
-export WORKSPACE
 export PYTHONDONTWRITEBYTECODE=yep
 export SYNAPSE_CACHE_FACTOR=1
-./jenkins/prepare_synapse.sh
-./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git
-./jenkins/clone.sh dendron https://github.com/matrix-org/dendron.git
-./dendron/jenkins/build_dendron.sh
-./sytest/jenkins/prep_sytest_for_postgres.sh
-./sytest/jenkins/install_and_run.sh \
-    --synapse-directory $WORKSPACE \
-    --dendron $WORKSPACE/dendron/bin/dendron \
-    --pusher \
-    --synchrotron \
-    --federation-reader \
+# Output test results as junit xml
+export TRIAL_FLAGS="--reporter=subunit"
+export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
+# Write coverage reports to a separate file for each process
+export COVERAGE_OPTS="-p"
+export DUMP_COVERAGE_COMMAND="coverage help"
+# Output flake8 violations to violations.flake8.log
+# Don't exit with non-0 status code on Jenkins,
+# so that the build steps continue and a later step can decide whether to
+# UNSTABLE or FAILURE this build.
+export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
+rm .coverage* || echo "No coverage files to remove"
+tox --notest -e py27
+TOX_BIN=$WORKSPACE/.tox/py27/bin
+python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
+$TOX_BIN/pip install psycopg2
+$TOX_BIN/pip install lxml
+: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
+if [[ ! -e .dendron-base ]]; then
+  git clone https://github.com/matrix-org/dendron.git .dendron-base --mirror
+else
+  (cd .dendron-base; git fetch -p)
+fi
+rm -rf dendron
+git clone .dendron-base dendron --shared
+cd dendron
+: ${GOPATH:=${WORKSPACE}/.gopath}
+if [[ "${GOPATH}" != *:* ]]; then
+  mkdir -p "${GOPATH}"
+  export PATH="${GOPATH}/bin:${PATH}"
+fi
+export GOPATH
+git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
+go get github.com/constabulary/gb/...
+gb generate
+gb build
+cd ..
+if [[ ! -e .sytest-base ]]; then
+  git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
+else
+  (cd .sytest-base; git fetch -p)
+fi
+rm -rf sytest
+git clone .sytest-base sytest --shared
+cd sytest
+git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
+: ${PORT_BASE:=8000}
+: ${PORT_COUNT=20}
+./jenkins/prep_sytest_for_postgres.sh
+mkdir -p var
+echo >&2 "Running sytest with PostgreSQL";
+./jenkins/install_and_run.sh --python $TOX_BIN/python \
+    --synapse-directory $WORKSPACE \
+    --dendron $WORKSPACE/dendron/bin/dendron \
+    --pusher \
+    --synchrotron \
+    --port-range ${PORT_BASE}:$((PORT_BASE+PORT_COUNT-1))
+cd ..

View File

@@ -4,14 +4,61 @@ set -eux
 : ${WORKSPACE:="$(pwd)"}
-export WORKSPACE
 export PYTHONDONTWRITEBYTECODE=yep
 export SYNAPSE_CACHE_FACTOR=1
-./jenkins/prepare_synapse.sh
-./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git
-./sytest/jenkins/prep_sytest_for_postgres.sh
-./sytest/jenkins/install_and_run.sh \
-    --synapse-directory $WORKSPACE \
+# Output test results as junit xml
+export TRIAL_FLAGS="--reporter=subunit"
+export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
+# Write coverage reports to a separate file for each process
+export COVERAGE_OPTS="-p"
+export DUMP_COVERAGE_COMMAND="coverage help"
+# Output flake8 violations to violations.flake8.log
+# Don't exit with non-0 status code on Jenkins,
+# so that the build steps continue and a later step can decide whether to
+# UNSTABLE or FAILURE this build.
+export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
+rm .coverage* || echo "No coverage files to remove"
+tox --notest -e py27
+TOX_BIN=$WORKSPACE/.tox/py27/bin
+python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
+$TOX_BIN/pip install psycopg2
+$TOX_BIN/pip install lxml
+: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
+if [[ ! -e .sytest-base ]]; then
+  git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
+else
+  (cd .sytest-base; git fetch -p)
+fi
+rm -rf sytest
+git clone .sytest-base sytest --shared
+cd sytest
+git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
+: ${PORT_BASE:=8000}
+: ${PORT_COUNT=20}
+./jenkins/prep_sytest_for_postgres.sh
+echo >&2 "Running sytest with PostgreSQL";
+./jenkins/install_and_run.sh --coverage \
+    --python $TOX_BIN/python \
+    --synapse-directory $WORKSPACE \
+    --port-range ${PORT_BASE}:$((PORT_BASE+PORT_COUNT-1)) \
+cd ..
+cp sytest/.coverage.* .
+# Combine the coverage reports
+echo "Combining:" .coverage.*
+$TOX_BIN/python -m coverage combine
+# Output coverage to coverage.xml
+$TOX_BIN/coverage xml -o coverage.xml

View File

@@ -4,12 +4,55 @@ set -eux
 : ${WORKSPACE:="$(pwd)"}
-export WORKSPACE
 export PYTHONDONTWRITEBYTECODE=yep
 export SYNAPSE_CACHE_FACTOR=1
-./jenkins/prepare_synapse.sh
-./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git
-./sytest/jenkins/install_and_run.sh \
-    --synapse-directory $WORKSPACE \
+# Output test results as junit xml
+export TRIAL_FLAGS="--reporter=subunit"
+export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
+# Write coverage reports to a separate file for each process
+export COVERAGE_OPTS="-p"
+export DUMP_COVERAGE_COMMAND="coverage help"
+# Output flake8 violations to violations.flake8.log
+# Don't exit with non-0 status code on Jenkins,
+# so that the build steps continue and a later step can decide whether to
+# UNSTABLE or FAILURE this build.
+export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
+rm .coverage* || echo "No coverage files to remove"
+tox --notest -e py27
+TOX_BIN=$WORKSPACE/.tox/py27/bin
+python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
+$TOX_BIN/pip install lxml
+: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
+if [[ ! -e .sytest-base ]]; then
+  git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
+else
+  (cd .sytest-base; git fetch -p)
+fi
+rm -rf sytest
+git clone .sytest-base sytest --shared
+cd sytest
+git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
+: ${PORT_COUNT=20}
+: ${PORT_BASE:=8000}
+./jenkins/install_and_run.sh --coverage \
+    --python $TOX_BIN/python \
+    --synapse-directory $WORKSPACE \
+    --port-range ${PORT_BASE}:$((PORT_BASE+PORT_COUNT-1)) \
+cd ..
+cp sytest/.coverage.* .
+# Combine the coverage reports
+echo "Combining:" .coverage.*
+$TOX_BIN/python -m coverage combine
+# Output coverage to coverage.xml
+$TOX_BIN/coverage xml -o coverage.xml

View File

@@ -22,9 +22,4 @@ export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished w
 rm .coverage* || echo "No coverage files to remove"
-tox --notest -e py27
-TOX_BIN=$WORKSPACE/.tox/py27/bin
-python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
-$TOX_BIN/pip install lxml
 tox -e py27

View File

@@ -1,44 +0,0 @@
#! /bin/bash
# This clones a project from github into a named subdirectory
# If the project has a branch with the same name as this branch
# then it will checkout that branch after cloning.
# Otherwise it will checkout "origin/develop."
# The first argument is the name of the directory to checkout
# the branch into.
# The second argument is the URL of the remote repository to checkout.
# Usually something like https://github.com/matrix-org/sytest.git
set -eux
NAME=$1
PROJECT=$2
BASE=".$NAME-base"
# Update our mirror.
if [ ! -d ".$NAME-base" ]; then
# Create a local mirror of the source repository.
# This saves us from having to download the entire repository
# when this script is next run.
git clone "$PROJECT" "$BASE" --mirror
else
# Fetch any updates from the source repository.
(cd "$BASE"; git fetch -p)
fi
# Remove the existing repository so that we have a clean copy
rm -rf "$NAME"
# Cloning with --shared means that we will share portions of the
# .git directory with our local mirror.
git clone "$BASE" "$NAME" --shared
# Jenkins may have supplied us with the name of the branch in the
# environment. Otherwise we will have to guess based on the current
# commit.
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
cd "$NAME"
# check out the relevant branch
git checkout "${GIT_BRANCH}" || (
echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop"
git checkout "origin/develop"
)

View File

@@ -1,20 +0,0 @@
#! /bin/bash
cd "`dirname $0`/.."
TOX_DIR=$WORKSPACE/.tox
mkdir -p $TOX_DIR
if ! [ $TOX_DIR -ef .tox ]; then
  ln -s "$TOX_DIR" .tox
fi
# set up the virtualenv
tox -e py27 --notest -v
TOX_BIN=$TOX_DIR/py27/bin
$TOX_BIN/pip install setuptools
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
$TOX_BIN/pip install lxml
$TOX_BIN/pip install psycopg2

View File

@@ -116,19 +116,17 @@ def get_json(origin_name, origin_key, destination, path):
     authorization_headers = []
     for key, sig in signed_json["signatures"][origin_name].items():
-        header = "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
-            origin_name, key, sig,
-        )
-        authorization_headers.append(bytes(header))
-        sys.stderr.write(header)
-        sys.stderr.write("\n")
+        authorization_headers.append(bytes(
+            "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
+                origin_name, key, sig,
+            )
+        ))
     result = requests.get(
         lookup(destination, path),
         headers={"Authorization": authorization_headers[0]},
         verify=False,
     )
-    sys.stderr.write("Status Code: %d\n" % (result.status_code,))
     return result.json()
@@ -143,7 +141,6 @@ def main():
     )
     json.dump(result, sys.stdout)
-    print ""
 if __name__ == "__main__":
     main()

View File

@@ -34,7 +34,7 @@ logger = logging.getLogger("synapse_port_db")
 BOOLEAN_COLUMNS = {
-    "events": ["processed", "outlier", "contains_url"],
+    "events": ["processed", "outlier"],
     "rooms": ["is_public"],
     "event_edges": ["is_state"],
     "presence_list": ["accepted"],
@@ -92,12 +92,8 @@ class Store(object):
     _simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"]
     _simple_select_onecol = SQLBaseStore.__dict__["_simple_select_onecol"]
-    _simple_select_one = SQLBaseStore.__dict__["_simple_select_one"]
-    _simple_select_one_txn = SQLBaseStore.__dict__["_simple_select_one_txn"]
     _simple_select_one_onecol = SQLBaseStore.__dict__["_simple_select_one_onecol"]
-    _simple_select_one_onecol_txn = SQLBaseStore.__dict__[
-        "_simple_select_one_onecol_txn"
-    ]
+    _simple_select_one_onecol_txn = SQLBaseStore.__dict__["_simple_select_one_onecol_txn"]
     _simple_update_one = SQLBaseStore.__dict__["_simple_update_one"]
     _simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"]
@@ -162,40 +158,31 @@ class Porter(object):
     def setup_table(self, table):
         if table in APPEND_ONLY_TABLES:
             # It's safe to just carry on inserting.
-            row = yield self.postgres_store._simple_select_one(
+            next_chunk = yield self.postgres_store._simple_select_one_onecol(
                 table="port_from_sqlite3",
                 keyvalues={"table_name": table},
-                retcols=("forward_rowid", "backward_rowid"),
+                retcol="rowid",
                 allow_none=True,
             )
             total_to_port = None
-            if row is None:
+            if next_chunk is None:
                 if table == "sent_transactions":
-                    forward_chunk, already_ported, total_to_port = (
+                    next_chunk, already_ported, total_to_port = (
                         yield self._setup_sent_transactions()
                     )
-                    backward_chunk = 0
                 else:
                     yield self.postgres_store._simple_insert(
                         table="port_from_sqlite3",
-                        values={
-                            "table_name": table,
-                            "forward_rowid": 1,
-                            "backward_rowid": 0,
-                        }
+                        values={"table_name": table, "rowid": 1}
                     )
-                    forward_chunk = 1
-                    backward_chunk = 0
+                    next_chunk = 1
                     already_ported = 0
-            else:
-                forward_chunk = row["forward_rowid"]
-                backward_chunk = row["backward_rowid"]
             if total_to_port is None:
                 already_ported, total_to_port = yield self._get_total_count_to_port(
-                    table, forward_chunk, backward_chunk
+                    table, next_chunk
                 )
         else:
             def delete_all(txn):
@@ -209,85 +196,46 @@ class Porter(object):
             yield self.postgres_store._simple_insert(
                 table="port_from_sqlite3",
-                values={
-                    "table_name": table,
-                    "forward_rowid": 1,
-                    "backward_rowid": 0,
-                }
+                values={"table_name": table, "rowid": 0}
             )
-            forward_chunk = 1
-            backward_chunk = 0
             already_ported, total_to_port = yield self._get_total_count_to_port(
-                table, forward_chunk, backward_chunk
+                table, next_chunk
             )
-        defer.returnValue(
-            (table, already_ported, total_to_port, forward_chunk, backward_chunk)
-        )
+        defer.returnValue((table, already_ported, total_to_port, next_chunk))
     @defer.inlineCallbacks
-    def handle_table(self, table, postgres_size, table_size, forward_chunk,
-                     backward_chunk):
+    def handle_table(self, table, postgres_size, table_size, next_chunk):
         if not table_size:
             return
         self.progress.add_table(table, postgres_size, table_size)
         if table == "event_search":
-            yield self.handle_search_table(
-                postgres_size, table_size, forward_chunk, backward_chunk
-            )
+            yield self.handle_search_table(postgres_size, table_size, next_chunk)
             return
-        forward_select = (
+        select = (
             "SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?"
             % (table,)
         )
-        backward_select = (
-            "SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid LIMIT ?"
-            % (table,)
-        )
-        do_forward = [True]
-        do_backward = [True]
         while True:
             def r(txn):
-                forward_rows = []
-                backward_rows = []
-                if do_forward[0]:
-                    txn.execute(forward_select, (forward_chunk, self.batch_size,))
-                    forward_rows = txn.fetchall()
-                    if not forward_rows:
-                        do_forward[0] = False
-                if do_backward[0]:
-                    txn.execute(backward_select, (backward_chunk, self.batch_size,))
-                    backward_rows = txn.fetchall()
-                    if not backward_rows:
-                        do_backward[0] = False
-                if forward_rows or backward_rows:
-                    headers = [column[0] for column in txn.description]
-                else:
-                    headers = None
-                return headers, forward_rows, backward_rows
-            headers, frows, brows = yield self.sqlite_store.runInteraction(
-                "select", r
-            )
-            if frows or brows:
-                if frows:
-                    forward_chunk = max(row[0] for row in frows) + 1
-                if brows:
-                    backward_chunk = min(row[0] for row in brows) - 1
-                rows = frows + brows
+                txn.execute(select, (next_chunk, self.batch_size,))
+                rows = txn.fetchall()
+                headers = [column[0] for column in txn.description]
+                return headers, rows
+            headers, rows = yield self.sqlite_store.runInteraction("select", r)
+            if rows:
+                next_chunk = rows[-1][0] + 1
                 self._convert_rows(table, headers, rows)
             def insert(txn):
@@ -299,10 +247,7 @@ class Porter(object):
                     txn,
                     table="port_from_sqlite3",
                     keyvalues={"table_name": table},
-                    updatevalues={
-                        "forward_rowid": forward_chunk,
-                        "backward_rowid": backward_chunk,
-                    },
+                    updatevalues={"rowid": next_chunk},
                 )
                 yield self.postgres_store.execute(insert)
@@ -314,8 +259,7 @@ class Porter(object):
                 return
     @defer.inlineCallbacks
-    def handle_search_table(self, postgres_size, table_size, forward_chunk,
-                            backward_chunk):
+    def handle_search_table(self, postgres_size, table_size, next_chunk):
         select = (
             "SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering"
             " FROM event_search as es"
@@ -326,7 +270,7 @@ class Porter(object):
         while True:
             def r(txn):
-                txn.execute(select, (forward_chunk, self.batch_size,))
+                txn.execute(select, (next_chunk, self.batch_size,))
                 rows = txn.fetchall()
                 headers = [column[0] for column in txn.description]
@@ -335,7 +279,7 @@ class Porter(object):
             headers, rows = yield self.sqlite_store.runInteraction("select", r)
             if rows:
-                forward_chunk = rows[-1][0] + 1
+                next_chunk = rows[-1][0] + 1
                 # We have to treat event_search differently since it has a
                 # different structure in the two different databases.
@@ -368,10 +312,7 @@ class Porter(object):
                     txn,
                     table="port_from_sqlite3",
                     keyvalues={"table_name": "event_search"},
-                    updatevalues={
-                        "forward_rowid": forward_chunk,
-                        "backward_rowid": backward_chunk,
-                    },
+                    updatevalues={"rowid": next_chunk},
                 )
                 yield self.postgres_store.execute(insert)
@@ -383,6 +324,7 @@ class Porter(object):
         else:
             return
+
     def setup_db(self, db_config, database_engine):
         db_conn = database_engine.module.connect(
             **{
@@ -453,32 +395,10 @@ class Porter(object):
             txn.execute(
                 "CREATE TABLE port_from_sqlite3 ("
                 " table_name varchar(100) NOT NULL UNIQUE,"
-                " forward_rowid bigint NOT NULL,"
-                " backward_rowid bigint NOT NULL"
+                " rowid bigint NOT NULL"
                 ")"
             )
-        # The old port script created a table with just a "rowid" column.
-        # We want people to be able to rerun this script from an old port
-        # so that they can pick up any missing events that were not
-        # ported across.
-        def alter_table(txn):
-            txn.execute(
-                "ALTER TABLE IF EXISTS port_from_sqlite3"
-                " RENAME rowid TO forward_rowid"
-            )
-            txn.execute(
-                "ALTER TABLE IF EXISTS port_from_sqlite3"
-                " ADD backward_rowid bigint NOT NULL DEFAULT 0"
-            )
-        try:
-            yield self.postgres_store.runInteraction(
-                "alter_table", alter_table
-            )
-        except Exception as e:
-            logger.info("Failed to create port table: %s", e)
         try:
             yield self.postgres_store.runInteraction(
                 "create_port_table", create_port_table
@@ -538,7 +458,7 @@ class Porter(object):
     @defer.inlineCallbacks
     def _setup_sent_transactions(self):
         # Only save things from the last day
-        yesterday = int(time.time() * 1000) - 86400000
+        yesterday = int(time.time()*1000) - 86400000
         # And save the max transaction id from each destination
         select = (
@@ -594,11 +514,7 @@ class Porter(object):
         yield self.postgres_store._simple_insert(
             table="port_from_sqlite3",
-            values={
-                "table_name": "sent_transactions",
-                "forward_rowid": next_chunk,
-                "backward_rowid": 0,
-            }
+            values={"table_name": "sent_transactions", "rowid": next_chunk}
         )
         def get_sent_table_size(txn):
@@ -619,18 +535,13 @@ class Porter(object):
         defer.returnValue((next_chunk, inserted_rows, total_count))
     @defer.inlineCallbacks
-    def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk):
-        frows = yield self.sqlite_store.execute_sql(
+    def _get_remaining_count_to_port(self, table, next_chunk):
+        rows = yield self.sqlite_store.execute_sql(
             "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,),
-            forward_chunk,
+            next_chunk,
         )
-        brows = yield self.sqlite_store.execute_sql(
-            "SELECT count(*) FROM %s WHERE rowid <= ?" % (table,),
-            backward_chunk,
-        )
-        defer.returnValue(frows[0][0] + brows[0][0])
+        defer.returnValue(rows[0][0])
     @defer.inlineCallbacks
     def _get_already_ported_count(self, table):
@@ -641,10 +552,10 @@ class Porter(object):
         defer.returnValue(rows[0][0])
     @defer.inlineCallbacks
-    def _get_total_count_to_port(self, table, forward_chunk, backward_chunk):
+    def _get_total_count_to_port(self, table, next_chunk):
         remaining, done = yield defer.gatherResults(
             [
-                self._get_remaining_count_to_port(table, forward_chunk, backward_chunk),
+                self._get_remaining_count_to_port(table, next_chunk),
                 self._get_already_ported_count(table),
             ],
             consumeErrors=True,
@@ -775,7 +686,7 @@ class CursesProgress(Progress):
             color = curses.color_pair(2) if perc == 100 else curses.color_pair(1)
             self.stdscr.addstr(
-                i + 2, left_margin + max_len - len(table),
+                i+2, left_margin + max_len - len(table),
                 table,
                 curses.A_BOLD | color,
             )
@@ -783,18 +694,18 @@ class CursesProgress(Progress):
         size = 20
         progress = "[%s%s]" % (
-            "#" * int(perc * size / 100),
-            " " * (size - int(perc * size / 100)),
+            "#" * int(perc*size/100),
+            " " * (size - int(perc*size/100)),
         )
         self.stdscr.addstr(
-            i + 2, left_margin + max_len + middle_space,
+            i+2, left_margin + max_len + middle_space,
             "%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]),
         )
         if self.finished:
             self.stdscr.addstr(
-                rows - 1, 0,
+                rows-1, 0,
                 "Press any key to exit...",
            )

View File

@@ -16,5 +16,7 @@ ignore =
 [flake8]
 max-line-length = 90
-# W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
-ignore = W503
+ignore = W503 ; W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
+
+[pep8]
+max-line-length = 90

View File

@@ -16,4 +16,4 @@
 """ This is a reference implementation of a Matrix home server.
 """
-__version__ = "0.17.1-rc1"
+__version__ = "0.16.1-r1"

View File

@@ -13,22 +13,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import logging
-import pymacaroons
 from canonicaljson import encode_canonical_json
 from signedjson.key import decode_verify_key_bytes
 from signedjson.sign import verify_signed_json, SignatureVerifyException
-from twisted.internet import defer
-from unpaddedbase64 import decode_base64
-import synapse.types
+from twisted.internet import defer
 from synapse.api.constants import EventTypes, Membership, JoinRules
 from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
-from synapse.types import UserID, get_domain_from_id
-from synapse.util.logcontext import preserve_context_over_fn
+from synapse.types import Requester, UserID, get_domain_from_id
 from synapse.util.logutils import log_function
+from synapse.util.logcontext import preserve_context_over_fn
 from synapse.util.metrics import Measure
+from unpaddedbase64 import decode_base64
+import logging
+import pymacaroons
 logger = logging.getLogger(__name__)
@@ -63,7 +63,7 @@ class Auth(object):
             "user_id = ",
         ])
-    def check(self, event, auth_events, do_sig_check=True):
+    def check(self, event, auth_events):
         """ Checks if this event is correctly authed.
         Args:
@@ -79,13 +79,6 @@ class Auth(object):
         if not hasattr(event, "room_id"):
             raise AuthError(500, "Event has no room_id: %s" % event)
-        sender_domain = get_domain_from_id(event.sender)
-        # Check the sender's domain has signed the event
-        if do_sig_check and not event.signatures.get(sender_domain):
-            raise AuthError(403, "Event not signed by sending server")
         if auth_events is None:
             # Oh, we don't know what the state of the room was, so we
             # are trusting that this is allowed (at least for now)
@@ -93,12 +86,6 @@ class Auth(object):
             return True
         if event.type == EventTypes.Create:
-            room_id_domain = get_domain_from_id(event.room_id)
-            if room_id_domain != sender_domain:
-                raise AuthError(
-                    403,
-                    "Creation event's room_id domain does not match sender's"
-                )
             # FIXME
             return True
@@ -121,22 +108,6 @@ class Auth(object):
         # FIXME: Temp hack
         if event.type == EventTypes.Aliases:
-            if not event.is_state():
-                raise AuthError(
-                    403,
-                    "Alias event must be a state event",
-                )
-            if not event.state_key:
-                raise AuthError(
-                    403,
-                    "Alias event must have non-empty state_key"
-                )
-            sender_domain = get_domain_from_id(event.sender)
-            if event.state_key != sender_domain:
-                raise AuthError(
-                    403,
-                    "Alias event's state_key does not match sender's domain"
-                )
             return True
         logger.debug(
@@ -376,10 +347,6 @@ class Auth(object):
         if Membership.INVITE == membership and "third_party_invite" in event.content:
             if not self._verify_third_party_invite(event, auth_events):
                 raise AuthError(403, "You are not invited to this room.")
-            if target_banned:
-                raise AuthError(
-                    403, "%s is banned from the room" % (target_user_id,)
-                )
             return True
         if Membership.JOIN != membership:
@@ -570,7 +537,9 @@ class Auth(object):
         Args:
             request - An HTTP request with an access_token query parameter.
         Returns:
-            defer.Deferred: resolves to a ``synapse.types.Requester`` object
+            tuple of:
+                UserID (str)
+                Access token ID (str)
         Raises:
             AuthError if no user by that token exists or the token is invalid.
         """
@@ -579,7 +548,9 @@ class Auth(object):
             user_id = yield self._get_appservice_user_id(request.args)
             if user_id:
                 request.authenticated_entity = user_id
-                defer.returnValue(synapse.types.create_requester(user_id))
+                defer.returnValue(
+                    Requester(UserID.from_string(user_id), "", False)
+                )
             access_token = request.args["access_token"][0]
             user_info = yield self.get_user_by_access_token(access_token, rights)
@@ -587,10 +558,6 @@ class Auth(object):
             token_id = user_info["token_id"]
             is_guest = user_info["is_guest"]
-            # device_id may not be present if get_user_by_access_token has been
-            # stubbed out.
-            device_id = user_info.get("device_id")
             ip_addr = self.hs.get_ip_from_request(request)
             user_agent = request.requestHeaders.getRawHeaders(
                 "User-Agent",
@@ -602,8 +569,7 @@ class Auth(object):
                     user=user,
                     access_token=access_token,
                     ip=ip_addr,
-                    user_agent=user_agent,
-                    device_id=device_id,
+                    user_agent=user_agent
                 )
             if is_guest and not allow_guest:
@@ -613,8 +579,7 @@ class Auth(object):
             request.authenticated_entity = user.to_string()
-            defer.returnValue(synapse.types.create_requester(
-                user, token_id, is_guest, device_id))
+            defer.returnValue(Requester(user, token_id, is_guest))
         except KeyError:
             raise AuthError(
                 self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
@@ -664,10 +629,7 @@ class Auth(object):
         except AuthError:
             # TODO(daniel): Remove this fallback when all existing access tokens
             # have been re-issued as macaroons.
-            if self.hs.config.expire_access_token:
-                raise
             ret = yield self._look_up_user_by_access_token(token)
             defer.returnValue(ret)
     @defer.inlineCallbacks
@@ -675,25 +637,33 @@ class Auth(object):
         try:
             macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
-            user_id = self.get_user_id_from_macaroon(macaroon)
-            user = UserID.from_string(user_id)
+            user_prefix = "user_id = "
+            user = None
+            user_id = None
+            guest = False
+            for caveat in macaroon.caveats:
+                if caveat.caveat_id.startswith(user_prefix):
+                    user_id = caveat.caveat_id[len(user_prefix):]
+                    user = UserID.from_string(user_id)
+                elif caveat.caveat_id == "guest = true":
+                    guest = True
             self.validate_macaroon(
                 macaroon, rights, self.hs.config.expire_access_token,
                 user_id=user_id,
             )
-            guest = False
-            for caveat in macaroon.caveats:
-                if caveat.caveat_id == "guest = true":
-                    guest = True
+            if user is None:
+                raise AuthError(
+                    self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
+                    errcode=Codes.UNKNOWN_TOKEN
+                )
             if guest:
                 ret = {
                     "user": user,
                     "is_guest": True,
                     "token_id": None,
-                    "device_id": None,
                 }
             elif rights == "delete_pusher":
                 # We don't store these tokens in the database
@@ -701,20 +671,13 @@ class Auth(object):
                     "user": user,
                     "is_guest": False,
                     "token_id": None,
-                    "device_id": None,
                 }
             else:
-                # This codepath exists for several reasons:
-                #   * so that we can actually return a token ID, which is used
-                #     in some parts of the schema (where we probably ought to
-                #     use device IDs instead)
-                #   * the only way we currently have to invalidate an
-                #     access_token is by removing it from the database, so we
-                #     have to check here that it is still in the db
-                #   * some attributes (notably device_id) aren't stored in the
-                #     macaroon. They probably should be.
-                # TODO: build the dictionary from the macaroon once the
-                # above are fixed
+                # This codepath exists so that we can actually return a
+                # token ID, because we use token IDs in place of device
+                # identifiers throughout the codebase.
+                # TODO(daniel): Remove this fallback when device IDs are
+                # properly implemented.
                 ret = yield self._look_up_user_by_access_token(macaroon_str)
                 if ret["user"] != user:
                     logger.error(
@@ -734,29 +697,6 @@ class Auth(object):
                 errcode=Codes.UNKNOWN_TOKEN
             )
-    def get_user_id_from_macaroon(self, macaroon):
-        """Retrieve the user_id given by the caveats on the macaroon.
-        Does *not* validate the macaroon.
-        Args:
-            macaroon (pymacaroons.Macaroon): The macaroon to validate
-        Returns:
-            (str) user id
-        Raises:
-            AuthError if there is no user_id caveat in the macaroon
-        """
-        user_prefix = "user_id = "
-        for caveat in macaroon.caveats:
-            if caveat.caveat_id.startswith(user_prefix):
-                return caveat.caveat_id[len(user_prefix):]
-        raise AuthError(
-            self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
-            errcode=Codes.UNKNOWN_TOKEN
-        )
     def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id):
         """
         validate that a Macaroon is understood by and was signed by this server.
@@ -768,7 +708,6 @@ class Auth(object):
             verify_expiry(bool): Whether to verify whether the macaroon has expired.
                 This should really always be True, but no clients currently implement
                 token refresh, so we can't enforce expiry yet.
-            user_id (str): The user_id required
         """
         v = pymacaroons.Verifier()
         v.satisfy_exact("gen = 1")
@@ -812,14 +751,10 @@ class Auth(object):
                 self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
                 errcode=Codes.UNKNOWN_TOKEN
             )
-        # we use ret.get() below because *lots* of unit tests stub out
-        # get_user_by_access_token in a way where it only returns a couple of
-        # the fields.
         user_info = {
             "user": UserID.from_string(ret.get("name")),
             "token_id": ret.get("token_id", None),
            "is_guest": False,
-            "device_id": ret.get("device_id"),
         }
         defer.returnValue(user_info)

View File

@@ -43,7 +43,6 @@ class Codes(object):
     EXCLUSIVE = "M_EXCLUSIVE"
     THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
     THREEPID_IN_USE = "M_THREEPID_IN_USE"
-    THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND"
     INVALID_USERNAME = "M_INVALID_USERNAME"
     SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"

View File

@@ -191,17 +191,6 @@ class Filter(object):
     def __init__(self, filter_json):
         self.filter_json = filter_json
-        self.types = self.filter_json.get("types", None)
-        self.not_types = self.filter_json.get("not_types", [])
-        self.rooms = self.filter_json.get("rooms", None)
-        self.not_rooms = self.filter_json.get("not_rooms", [])
-        self.senders = self.filter_json.get("senders", None)
-        self.not_senders = self.filter_json.get("not_senders", [])
-        self.contains_url = self.filter_json.get("contains_url", None)
     def check(self, event):
         """Checks whether the filter matches the given event.
@@ -220,10 +209,9 @@ class Filter(object):
             event.get("room_id", None),
             sender,
             event.get("type", None),
-            "url" in event.get("content", {})
         )
-    def check_fields(self, room_id, sender, event_type, contains_url):
+    def check_fields(self, room_id, sender, event_type):
         """Checks whether the filter matches the given event fields.
         Returns:
@@ -237,20 +225,15 @@ class Filter(object):
         for name, match_func in literal_keys.items():
             not_name = "not_%s" % (name,)
-            disallowed_values = getattr(self, not_name)
+            disallowed_values = self.filter_json.get(not_name, [])
             if any(map(match_func, disallowed_values)):
                 return False
-            allowed_values = getattr(self, name)
+            allowed_values = self.filter_json.get(name, None)
             if allowed_values is not None:
                 if not any(map(match_func, allowed_values)):
                     return False
-        contains_url_filter = self.filter_json.get("contains_url")
-        if contains_url_filter is not None:
-            if contains_url_filter != contains_url:
-                return False
         return True
     def filter_rooms(self, room_ids):

View File

@@ -16,11 +16,13 @@
 import sys
 sys.dont_write_bytecode = True
-from synapse import python_dependencies  # noqa: E402
+from synapse.python_dependencies import (
+    check_requirements, MissingRequirementError
+)  # NOQA
 try:
-    python_dependencies.check_requirements()
-except python_dependencies.MissingRequirementError as e:
+    check_requirements()
+except MissingRequirementError as e:
     message = "\n".join([
         "Missing Requirement: %s" % (e.message,),
         "To install run:",
View File

@@ -1,209 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse
from synapse.server import HomeServer
from synapse.config._base import ConfigError
from synapse.config.logger import setup_logging
from synapse.config.homeserver import HomeServerConfig
from synapse.http.site import SynapseSite
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.storage.engines import create_engine
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from twisted.internet import reactor, defer
from twisted.web.resource import Resource
from daemonize import Daemonize
import sys
import logging
import gc
logger = logging.getLogger("synapse.app.appservice")
class AppserviceSlaveStore(
DirectoryStore, SlavedEventStore, SlavedApplicationServiceStore,
SlavedRegistrationStore,
):
pass
class AppserviceServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = AppserviceSlaveStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
logger.info("Synapse appservice now listening on port %d", port)
def start_listening(self, listeners):
for listener in listeners:
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
appservice_handler = self.get_application_service_handler()
@defer.inlineCallbacks
def replicate(results):
stream = results.get("events")
if stream:
max_stream_id = stream["position"]
yield appservice_handler.notify_interested_services(max_stream_id)
while True:
try:
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
yield store.process_replication(result)
replicate(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(30)
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse appservice", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.appservice"
setup_logging(config.worker_log_config, config.worker_log_file)
database_engine = create_engine(config.database_config)
if config.notify_appservices:
sys.stderr.write(
"\nThe appservices must be disabled in the main synapse process"
"\nbefore they can be run in a separate worker."
"\nPlease add ``notify_appservices: false`` to the main config"
"\n"
)
sys.exit(1)
# Force the appservice handler to start since it will be disabled in the main config
config.notify_appservices = True
ps = AppserviceServer(
config.server_name,
db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ps.setup()
ps.start_listening(config.worker_listeners)
def run():
with LoggingContext("run"):
logger.info("Running")
change_resource_limit(config.soft_file_limit)
if config.gc_thresholds:
gc.set_threshold(*config.gc_thresholds)
reactor.run()
def start():
ps.replicate()
ps.get_datastore().start_profiling()
reactor.callWhenRunning(start)
if config.worker_daemonize:
daemon = Daemonize(
app="synapse-appservice",
pid=config.worker_pid_file,
action=run,
auto_close_fds=False,
verbose=True,
logger=logger,
)
daemon.start()
else:
run()
if __name__ == '__main__':
with LoggingContext("main"):
start(sys.argv[1:])
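
All of the worker entry points in this compare share the replication shape shown above: long-poll the master's replication URL with the slaved store's stream positions, feed the result back into the store, then hand it to whatever the worker cares about. A minimal synchronous sketch of that loop, with hypothetical stand-ins for the store and HTTP client (the real code is Twisted and yields Deferreds):

import time

def replication_loop(store, http_client, replication_url, on_result,
                     retry_secs=30):
    # store and http_client are illustrative stand-ins for the slaved
    # store and SimpleHttpClient used above, not Synapse's actual API
    while True:
        try:
            args = store.stream_positions()    # e.g. {"events": 1234}
            args["timeout"] = 30000            # long-poll for up to 30s
            result = http_client.get_json(replication_url, args=args)
            store.process_replication(result)  # advance caches and positions
            on_result(result)                  # e.g. notify interested ASes
        except Exception:
            time.sleep(retry_secs)             # back off, then keep polling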


@@ -1,206 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.http.site import SynapseSite
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.keys import SlavedKeyStore
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from synapse.api.urls import FEDERATION_PREFIX
from synapse.federation.transport.server import TransportLayerServer
from synapse.crypto import context_factory
from twisted.internet import reactor, defer
from twisted.web.resource import Resource
from daemonize import Daemonize
import sys
import logging
import gc
logger = logging.getLogger("synapse.app.federation_reader")
class FederationReaderSlavedStore(
SlavedEventStore,
SlavedKeyStore,
RoomStore,
DirectoryStore,
TransactionStore,
BaseSlavedStore,
):
pass
class FederationReaderServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
elif name == "federation":
resources.update({
FEDERATION_PREFIX: TransportLayerServer(self),
})
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
logger.info("Synapse federation reader now listening on port %d", port)
def start_listening(self, listeners):
for listener in listeners:
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
while True:
try:
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
yield store.process_replication(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(5)
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse federation reader", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.federation_reader"
setup_logging(config.worker_log_config, config.worker_log_file)
database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config)
ss = FederationReaderServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ss.setup()
ss.get_handlers()
ss.start_listening(config.worker_listeners)
def run():
with LoggingContext("run"):
logger.info("Running")
change_resource_limit(config.soft_file_limit)
if config.gc_thresholds:
gc.set_threshold(*config.gc_thresholds)
reactor.run()
def start():
ss.get_datastore().start_profiling()
ss.replicate()
reactor.callWhenRunning(start)
if config.worker_daemonize:
daemon = Daemonize(
app="synapse-federation-reader",
pid=config.worker_pid_file,
action=run,
auto_close_fds=False,
verbose=True,
logger=logger,
)
daemon.start()
else:
run()
if __name__ == '__main__':
with LoggingContext("main"):
start(sys.argv[1:])


@@ -51,7 +51,6 @@ from synapse.api.urls import (
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory from synapse.crypto import context_factory
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
from synapse.metrics import register_memory_metrics, get_metrics_for
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX
from synapse.federation.transport.server import TransportLayerServer from synapse.federation.transport.server import TransportLayerServer
@@ -285,7 +284,7 @@ def setup(config_options):
# check any extra requirements we have now we have a config # check any extra requirements we have now we have a config
check_requirements(config) check_requirements(config)
version_string = "Synapse/" + get_version_string(synapse) version_string = get_version_string("Synapse", synapse)
logger.info("Server hostname: %s", config.server_name) logger.info("Server hostname: %s", config.server_name)
logger.info("Server version: %s", version_string) logger.info("Server version: %s", version_string)
@@ -336,8 +335,6 @@ def setup(config_options):
hs.get_datastore().start_doing_background_updates() hs.get_datastore().start_doing_background_updates()
hs.get_replication_layer().start_get_pdu_cache() hs.get_replication_layer().start_get_pdu_cache()
register_memory_metrics(hs)
reactor.callWhenRunning(start) reactor.callWhenRunning(start)
return hs return hs
@@ -385,8 +382,6 @@ def run(hs):
start_time = hs.get_clock().time() start_time = hs.get_clock().time()
stats = {}
@defer.inlineCallbacks @defer.inlineCallbacks
def phone_stats_home(): def phone_stats_home():
logger.info("Gathering stats for reporting") logger.info("Gathering stats for reporting")
@@ -395,10 +390,7 @@ def run(hs):
if uptime < 0: if uptime < 0:
uptime = 0 uptime = 0
# If the stats dictionary is empty then this is the first time we've stats = {}
# reported stats.
first_time = not stats
stats["homeserver"] = hs.config.server_name stats["homeserver"] = hs.config.server_name
stats["timestamp"] = now stats["timestamp"] = now
stats["uptime_seconds"] = uptime stats["uptime_seconds"] = uptime
@@ -411,25 +403,6 @@ def run(hs):
daily_messages = yield hs.get_datastore().count_daily_messages() daily_messages = yield hs.get_datastore().count_daily_messages()
if daily_messages is not None: if daily_messages is not None:
stats["daily_messages"] = daily_messages stats["daily_messages"] = daily_messages
else:
stats.pop("daily_messages", None)
if first_time:
# Add callbacks to report the synapse stats as metrics whenever
# prometheus requests them, typically every 30s.
# As some of the stats are expensive to calculate we only update
# them when synapse phones home to matrix.org every 24 hours.
metrics = get_metrics_for("synapse.usage")
metrics.add_callback("timestamp", lambda: stats["timestamp"])
metrics.add_callback("uptime_seconds", lambda: stats["uptime_seconds"])
metrics.add_callback("total_users", lambda: stats["total_users"])
metrics.add_callback("total_room_count", lambda: stats["total_room_count"])
metrics.add_callback(
"daily_active_users", lambda: stats["daily_active_users"]
)
metrics.add_callback(
"daily_messages", lambda: stats.get("daily_messages", 0)
)
logger.info("Reporting stats to matrix.org: %s" % (stats,)) logger.info("Reporting stats to matrix.org: %s" % (stats,))
try: try:
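
The removed block above registers the Prometheus callbacks once, on the first phone-home, and lets them serve whatever the last phone-home pass left in the stats dict. A condensed sketch of that pattern, assuming the metrics.add_callback(name, fn) interface shown in the diff (register_stat_callbacks is an illustrative wrapper):

stats = {}  # refreshed in place by each phone-home pass

def register_stat_callbacks(metrics):
    for name in ("timestamp", "uptime_seconds", "total_users",
                 "total_room_count", "daily_active_users", "daily_messages"):
        # bind name via a default argument: a bare lambda would close over
        # the loop variable and every callback would report the last stat
        metrics.add_callback(name, lambda name=name: stats.get(name, 0))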


@@ -1,212 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.http.site import SynapseSite
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore
from synapse.storage.engines import create_engine
from synapse.storage.media_repository import MediaRepositoryStore
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from synapse.api.urls import (
CONTENT_REPO_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX
)
from synapse.crypto import context_factory
from twisted.internet import reactor, defer
from twisted.web.resource import Resource
from daemonize import Daemonize
import sys
import logging
import gc
logger = logging.getLogger("synapse.app.media_repository")
class MediaRepositorySlavedStore(
SlavedApplicationServiceStore,
SlavedRegistrationStore,
BaseSlavedStore,
MediaRepositoryStore,
ClientIpStore,
):
pass
class MediaRepositoryServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
elif name == "media":
media_repo = MediaRepositoryResource(self)
resources.update({
MEDIA_PREFIX: media_repo,
LEGACY_MEDIA_PREFIX: media_repo,
CONTENT_REPO_PREFIX: ContentRepoResource(
self, self.config.uploads_path
),
})
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
logger.info("Synapse media repository now listening on port %d", port)
def start_listening(self, listeners):
for listener in listeners:
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
while True:
try:
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
yield store.process_replication(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(5)
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse media repository", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.media_repository"
setup_logging(config.worker_log_config, config.worker_log_file)
database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config)
ss = MediaRepositoryServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ss.setup()
ss.get_handlers()
ss.start_listening(config.worker_listeners)
def run():
with LoggingContext("run"):
logger.info("Running")
change_resource_limit(config.soft_file_limit)
if config.gc_thresholds:
gc.set_threshold(*config.gc_thresholds)
reactor.run()
def start():
ss.get_datastore().start_profiling()
ss.replicate()
reactor.callWhenRunning(start)
if config.worker_daemonize:
daemon = Daemonize(
app="synapse-media-repository",
pid=config.worker_pid_file,
action=run,
auto_close_fds=False,
verbose=True,
logger=logger,
)
daemon.start()
else:
run()
if __name__ == '__main__':
with LoggingContext("main"):
start(sys.argv[1:])


@@ -80,6 +80,11 @@ class PusherSlaveStore(
DataStore.get_profile_displayname.__func__ DataStore.get_profile_displayname.__func__
) )
# XXX: This is a bit broken because we don't persist forgotten rooms
# in a way that they can be streamed. This means that we don't have a
# way to invalidate the forgotten rooms cache correctly.
# For now we expire the cache every hour.
BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
who_forgot_in_room = ( who_forgot_in_room = (
RoomMemberStore.__dict__["who_forgot_in_room"] RoomMemberStore.__dict__["who_forgot_in_room"]
) )
@@ -163,6 +168,7 @@ class PusherServer(HomeServer):
store = self.get_datastore() store = self.get_datastore()
replication_url = self.config.worker_replication_url replication_url = self.config.worker_replication_url
pusher_pool = self.get_pusherpool() pusher_pool = self.get_pusherpool()
clock = self.get_clock()
def stop_pusher(user_id, app_id, pushkey): def stop_pusher(user_id, app_id, pushkey):
key = "%s:%s" % (app_id, pushkey) key = "%s:%s" % (app_id, pushkey)
@@ -214,11 +220,21 @@ class PusherServer(HomeServer):
min_stream_id, max_stream_id, affected_room_ids min_stream_id, max_stream_id, affected_room_ids
) )
def expire_broken_caches():
store.who_forgot_in_room.invalidate_all()
next_expire_broken_caches_ms = 0
while True: while True:
try: try:
args = store.stream_positions() args = store.stream_positions()
args["timeout"] = 30000 args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args) result = yield http_client.get_json(replication_url, args=args)
now_ms = clock.time_msec()
if now_ms > next_expire_broken_caches_ms:
expire_broken_caches()
next_expire_broken_caches_ms = (
now_ms + store.BROKEN_CACHE_EXPIRY_MS
)
yield store.process_replication(result) yield store.process_replication(result)
poke_pushers(result) poke_pushers(result)
except: except:
@@ -257,7 +273,7 @@ def start(config_options):
config.server_name, config.server_name,
db_config=config.database_config, db_config=config.database_config,
config=config, config=config,
version_string="Synapse/" + get_version_string(synapse), version_string=get_version_string("Synapse", synapse),
database_engine=database_engine, database_engine=database_engine,
) )
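
Both the pusher and synchrotron hunks in this compare bolt the same workaround onto their replication loops: since forgotten-room updates are not streamed, the who_forgot_in_room cache is invalidated wholesale once the expiry window passes. A standalone sketch of that check (clock, store and the state dict are illustrative stand-ins):

def maybe_expire_broken_caches(clock, store, state):
    # state holds the next deadline so the caller's loop needs no global
    now_ms = clock.time_msec()
    if now_ms > state.get("next_expiry_ms", 0):
        store.who_forgot_in_room.invalidate_all()
        state["next_expiry_ms"] = now_ms + store.BROKEN_CACHE_EXPIRY_MS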


@@ -26,7 +26,6 @@ from synapse.http.site import SynapseSite
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.rest.client.v2_alpha import sync from synapse.rest.client.v2_alpha import sync
from synapse.rest.client.v1 import events
from synapse.replication.slave.storage._base import BaseSlavedStore from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
@@ -75,6 +74,11 @@ class SynchrotronSlavedStore(
BaseSlavedStore, BaseSlavedStore,
ClientIpStore, # After BaseSlavedStore because the constructor is different ClientIpStore, # After BaseSlavedStore because the constructor is different
): ):
# XXX: This is a bit broken because we don't persist forgotten rooms
# in a way that they can be streamed. This means that we don't have a
# way to invalidate the forgotten rooms cache correctly.
# For now we expire the cache every hour.
BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
who_forgot_in_room = ( who_forgot_in_room = (
RoomMemberStore.__dict__["who_forgot_in_room"] RoomMemberStore.__dict__["who_forgot_in_room"]
) )
@@ -85,23 +89,17 @@ class SynchrotronSlavedStore(
get_presence_list_accepted = PresenceStore.__dict__[ get_presence_list_accepted = PresenceStore.__dict__[
"get_presence_list_accepted" "get_presence_list_accepted"
] ]
get_presence_list_observers_accepted = PresenceStore.__dict__[
"get_presence_list_observers_accepted"
]
UPDATE_SYNCING_USERS_MS = 10 * 1000 UPDATE_SYNCING_USERS_MS = 10 * 1000
class SynchrotronPresence(object): class SynchrotronPresence(object):
def __init__(self, hs): def __init__(self, hs):
self.is_mine_id = hs.is_mine_id
self.http_client = hs.get_simple_http_client() self.http_client = hs.get_simple_http_client()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.user_to_num_current_syncs = {} self.user_to_num_current_syncs = {}
self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users" self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users"
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
active_presence = self.store.take_presence_startup_info() active_presence = self.store.take_presence_startup_info()
self.user_to_current_state = { self.user_to_current_state = {
@@ -121,13 +119,11 @@ class SynchrotronPresence(object):
reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown) reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown)
def set_state(self, user, state, ignore_status_msg=False): def set_state(self, user, state):
# TODO How's this supposed to work? # TODO How's this supposed to work?
pass pass
get_states = PresenceHandler.get_states.__func__ get_states = PresenceHandler.get_states.__func__
get_state = PresenceHandler.get_state.__func__
_get_interested_parties = PresenceHandler._get_interested_parties.__func__
current_state_for_users = PresenceHandler.current_state_for_users.__func__ current_state_for_users = PresenceHandler.current_state_for_users.__func__
@defer.inlineCallbacks @defer.inlineCallbacks
@@ -198,39 +194,19 @@ class SynchrotronPresence(object):
self._need_to_send_sync = False self._need_to_send_sync = False
yield self._send_syncing_users_now() yield self._send_syncing_users_now()
@defer.inlineCallbacks
def notify_from_replication(self, states, stream_id):
parties = yield self._get_interested_parties(
states, calculate_remote_hosts=False
)
room_ids_to_states, users_to_states, _ = parties
self.notifier.on_new_event(
"presence_key", stream_id, rooms=room_ids_to_states.keys(),
users=users_to_states.keys()
)
@defer.inlineCallbacks
def process_replication(self, result): def process_replication(self, result):
stream = result.get("presence", {"rows": []}) stream = result.get("presence", {"rows": []})
states = []
for row in stream["rows"]: for row in stream["rows"]:
( (
position, user_id, state, last_active_ts, position, user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts, status_msg, last_federation_update_ts, last_user_sync_ts, status_msg,
currently_active currently_active
) = row ) = row
state = UserPresenceState( self.user_to_current_state[user_id] = UserPresenceState(
user_id, state, last_active_ts, user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts, status_msg, last_federation_update_ts, last_user_sync_ts, status_msg,
currently_active currently_active
) )
self.user_to_current_state[user_id] = state
states.append(state)
if states and "position" in stream:
stream_id = int(stream["position"])
yield self.notify_from_replication(states, stream_id)
class SynchrotronTyping(object): class SynchrotronTyping(object):
@@ -290,12 +266,10 @@ class SynchrotronServer(HomeServer):
elif name == "client": elif name == "client":
resource = JsonResource(self, canonical_json=False) resource = JsonResource(self, canonical_json=False)
sync.register_servlets(self, resource) sync.register_servlets(self, resource)
events.register_servlets(self, resource)
resources.update({ resources.update({
"/_matrix/client/r0": resource, "/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource, "/_matrix/client/unstable": resource,
"/_matrix/client/v2_alpha": resource, "/_matrix/client/v2_alpha": resource,
"/_matrix/client/api/v1": resource,
}) })
root_resource = create_resource_tree(resources, Resource()) root_resource = create_resource_tree(resources, Resource())
@@ -333,10 +307,15 @@ class SynchrotronServer(HomeServer):
http_client = self.get_simple_http_client() http_client = self.get_simple_http_client()
store = self.get_datastore() store = self.get_datastore()
replication_url = self.config.worker_replication_url replication_url = self.config.worker_replication_url
clock = self.get_clock()
notifier = self.get_notifier() notifier = self.get_notifier()
presence_handler = self.get_presence_handler() presence_handler = self.get_presence_handler()
typing_handler = self.get_typing_handler() typing_handler = self.get_typing_handler()
def expire_broken_caches():
store.who_forgot_in_room.invalidate_all()
store.get_presence_list_accepted.invalidate_all()
def notify_from_stream( def notify_from_stream(
result, stream_name, stream_key, room=None, user=None result, stream_name, stream_key, room=None, user=None
): ):
@@ -398,15 +377,22 @@ class SynchrotronServer(HomeServer):
result, "typing", "typing_key", room="room_id" result, "typing", "typing_key", room="room_id"
) )
next_expire_broken_caches_ms = 0
while True: while True:
try: try:
args = store.stream_positions() args = store.stream_positions()
args.update(typing_handler.stream_positions()) args.update(typing_handler.stream_positions())
args["timeout"] = 30000 args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args) result = yield http_client.get_json(replication_url, args=args)
now_ms = clock.time_msec()
if now_ms > next_expire_broken_caches_ms:
expire_broken_caches()
next_expire_broken_caches_ms = (
now_ms + store.BROKEN_CACHE_EXPIRY_MS
)
yield store.process_replication(result) yield store.process_replication(result)
typing_handler.process_replication(result) typing_handler.process_replication(result)
yield presence_handler.process_replication(result) presence_handler.process_replication(result)
notify(result) notify(result)
except: except:
logger.exception("Error replicating from %r", replication_url) logger.exception("Error replicating from %r", replication_url)
@@ -438,7 +424,7 @@ def start(config_options):
config.server_name, config.server_name,
db_config=config.database_config, db_config=config.database_config,
config=config, config=config,
version_string="Synapse/" + get_version_string(synapse), version_string=get_version_string("Synapse", synapse),
database_engine=database_engine, database_engine=database_engine,
application_service_handler=SynchrotronApplicationService(), application_service_handler=SynchrotronApplicationService(),
) )


@@ -14,8 +14,6 @@
# limitations under the License. # limitations under the License.
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from twisted.internet import defer
import logging import logging
import re import re
@@ -81,17 +79,13 @@ class ApplicationService(object):
NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS] NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
def __init__(self, token, url=None, namespaces=None, hs_token=None, def __init__(self, token, url=None, namespaces=None, hs_token=None,
sender=None, id=None, protocols=None): sender=None, id=None):
self.token = token self.token = token
self.url = url self.url = url
self.hs_token = hs_token self.hs_token = hs_token
self.sender = sender self.sender = sender
self.namespaces = self._check_namespaces(namespaces) self.namespaces = self._check_namespaces(namespaces)
self.id = id self.id = id
if protocols:
self.protocols = set(protocols)
else:
self.protocols = set()
def _check_namespaces(self, namespaces): def _check_namespaces(self, namespaces):
# Sanity check that it is of the form: # Sanity check that it is of the form:
@@ -144,66 +138,65 @@ class ApplicationService(object):
return regex_obj["exclusive"] return regex_obj["exclusive"]
return False return False
@defer.inlineCallbacks def _matches_user(self, event, member_list):
def _matches_user(self, event, store): if (hasattr(event, "sender") and
if not event: self.is_interested_in_user(event.sender)):
defer.returnValue(False) return True
if self.is_interested_in_user(event.sender):
defer.returnValue(True)
# also check m.room.member state key # also check m.room.member state key
if (event.type == EventTypes.Member and if (hasattr(event, "type") and event.type == EventTypes.Member
self.is_interested_in_user(event.state_key)): and hasattr(event, "state_key")
defer.returnValue(True) and self.is_interested_in_user(event.state_key)):
return True
if not store:
defer.returnValue(False)
member_list = yield store.get_users_in_room(event.room_id)
# check joined member events # check joined member events
for user_id in member_list: for user_id in member_list:
if self.is_interested_in_user(user_id): if self.is_interested_in_user(user_id):
defer.returnValue(True) return True
defer.returnValue(False) return False
def _matches_room_id(self, event): def _matches_room_id(self, event):
if hasattr(event, "room_id"): if hasattr(event, "room_id"):
return self.is_interested_in_room(event.room_id) return self.is_interested_in_room(event.room_id)
return False return False
@defer.inlineCallbacks def _matches_aliases(self, event, alias_list):
def _matches_aliases(self, event, store):
if not store or not event:
defer.returnValue(False)
alias_list = yield store.get_aliases_for_room(event.room_id)
for alias in alias_list: for alias in alias_list:
if self.is_interested_in_alias(alias): if self.is_interested_in_alias(alias):
defer.returnValue(True) return True
defer.returnValue(False) return False
@defer.inlineCallbacks def is_interested(self, event, restrict_to=None, aliases_for_event=None,
def is_interested(self, event, store=None): member_list=None):
"""Check if this service is interested in this event. """Check if this service is interested in this event.
Args: Args:
event(Event): The event to check. event(Event): The event to check.
store(DataStore) restrict_to(str): The namespace to restrict regex tests to.
aliases_for_event(list): A list of all the known room aliases for
this event.
member_list(list): A list of all joined user_ids in this room.
Returns: Returns:
bool: True if this service would like to know about this event. bool: True if this service would like to know about this event.
""" """
# Do cheap checks first if aliases_for_event is None:
if self._matches_room_id(event): aliases_for_event = []
defer.returnValue(True) if member_list is None:
member_list = []
if (yield self._matches_aliases(event, store)): if restrict_to and restrict_to not in ApplicationService.NS_LIST:
defer.returnValue(True) # this is a programming error, so fail early and raise a general
# exception
raise Exception("Unexpected restrict_to value: %s" % restrict_to)
if (yield self._matches_user(event, store)): if not restrict_to:
defer.returnValue(True) return (self._matches_user(event, member_list)
or self._matches_aliases(event, aliases_for_event)
defer.returnValue(False) or self._matches_room_id(event))
elif restrict_to == ApplicationService.NS_ALIASES:
return self._matches_aliases(event, aliases_for_event)
elif restrict_to == ApplicationService.NS_ROOMS:
return self._matches_room_id(event)
elif restrict_to == ApplicationService.NS_USERS:
return self._matches_user(event, member_list)
def is_interested_in_user(self, user_id): def is_interested_in_user(self, user_id):
return ( return (
@@ -223,9 +216,6 @@ class ApplicationService(object):
or user_id == self.sender or user_id == self.sender
) )
def is_interested_in_protocol(self, protocol):
return protocol in self.protocols
def is_exclusive_alias(self, alias): def is_exclusive_alias(self, alias):
return self._is_exclusive(ApplicationService.NS_ALIASES, alias) return self._is_exclusive(ApplicationService.NS_ALIASES, alias)
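
For the synchronous form on the right-hand side of the hunk above, a hypothetical caller passes the room's aliases and membership in directly rather than handing over a store (the values below are illustrative):

def service_wants_event(service, event, room_aliases, room_members):
    # room_aliases e.g. ["#bridge:example.org"], room_members e.g.
    # ["@alice:example.org"]; restrict_to=None checks users, aliases, rooms
    return service.is_interested(
        event,
        restrict_to=None,
        aliases_for_event=room_aliases,
        member_list=room_members,
    )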


@@ -17,7 +17,6 @@ from twisted.internet import defer
from synapse.api.errors import CodeMessageException from synapse.api.errors import CodeMessageException
from synapse.http.client import SimpleHttpClient from synapse.http.client import SimpleHttpClient
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.types import ThirdPartyEntityKind
import logging import logging
import urllib import urllib
@@ -25,28 +24,6 @@ import urllib
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _is_valid_3pe_result(r, field):
if not isinstance(r, dict):
return False
for k in (field, "protocol"):
if k not in r:
return False
if not isinstance(r[k], str):
return False
if "fields" not in r:
return False
fields = r["fields"]
if not isinstance(fields, dict):
return False
for k in fields.keys():
if not isinstance(fields[k], str):
return False
return True
class ApplicationServiceApi(SimpleHttpClient): class ApplicationServiceApi(SimpleHttpClient):
"""This class manages HS -> AS communications, including querying and """This class manages HS -> AS communications, including querying and
pushing. pushing.
@@ -94,43 +71,6 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning("query_alias to %s threw exception %s", uri, ex) logger.warning("query_alias to %s threw exception %s", uri, ex)
defer.returnValue(False) defer.returnValue(False)
@defer.inlineCallbacks
def query_3pe(self, service, kind, protocol, fields):
if kind == ThirdPartyEntityKind.USER:
uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol))
required_field = "userid"
elif kind == ThirdPartyEntityKind.LOCATION:
uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol))
required_field = "alias"
else:
raise ValueError(
"Unrecognised 'kind' argument %r to query_3pe()", kind
)
try:
response = yield self.get_json(uri, fields)
if not isinstance(response, list):
logger.warning(
"query_3pe to %s returned an invalid response %r",
uri, response
)
defer.returnValue([])
ret = []
for r in response:
if _is_valid_3pe_result(r, field=required_field):
ret.append(r)
else:
logger.warning(
"query_3pe to %s returned an invalid result %r",
uri, r
)
defer.returnValue(ret)
except Exception as ex:
logger.warning("query_3pe to %s threw exception %s", uri, ex)
defer.returnValue([])
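
_is_valid_3pe_result above is what lets query_3pe drop malformed entries instead of failing the whole lookup. A hypothetical response being filtered the same way, using the helper as defined in the diff:

response = [
    {"protocol": "irc", "userid": "@irc_alice:example.org",
     "fields": {"nick": "alice"}},
    {"protocol": "irc"},  # invalid: missing "userid" and "fields"
]
valid = [r for r in response if _is_valid_3pe_result(r, field="userid")]
# valid now holds only the first, well-formed entry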
@defer.inlineCallbacks @defer.inlineCallbacks
def push_bulk(self, service, events, txn_id=None): def push_bulk(self, service, events, txn_id=None):
events = self._serialize(events) events = self._serialize(events)


@@ -48,12 +48,9 @@ UP & quit +---------- YES SUCCESS
This is all tied together by the AppServiceScheduler which DIs the required This is all tied together by the AppServiceScheduler which DIs the required
components. components.
""" """
from twisted.internet import defer
from synapse.appservice import ApplicationServiceState from synapse.appservice import ApplicationServiceState
from synapse.util.logcontext import preserve_fn from twisted.internet import defer
from synapse.util.metrics import Measure
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -76,7 +73,7 @@ class ApplicationServiceScheduler(object):
self.txn_ctrl = _TransactionController( self.txn_ctrl = _TransactionController(
self.clock, self.store, self.as_api, create_recoverer self.clock, self.store, self.as_api, create_recoverer
) )
self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock) self.queuer = _ServiceQueuer(self.txn_ctrl)
@defer.inlineCallbacks @defer.inlineCallbacks
def start(self): def start(self):
@@ -97,36 +94,38 @@ class _ServiceQueuer(object):
this schedules any other events in the queue to run. this schedules any other events in the queue to run.
""" """
def __init__(self, txn_ctrl, clock): def __init__(self, txn_ctrl):
self.queued_events = {} # dict of {service_id: [events]} self.queued_events = {} # dict of {service_id: [events]}
self.requests_in_flight = set() self.pending_requests = {} # dict of {service_id: Deferred}
self.txn_ctrl = txn_ctrl self.txn_ctrl = txn_ctrl
self.clock = clock
def enqueue(self, service, event): def enqueue(self, service, event):
# if this service isn't being sent something # if this service isn't being sent something
self.queued_events.setdefault(service.id, []).append(event) if not self.pending_requests.get(service.id):
preserve_fn(self._send_request)(service) self._send_request(service, [event])
else:
# add to queue for this service
if service.id not in self.queued_events:
self.queued_events[service.id] = []
self.queued_events[service.id].append(event)
@defer.inlineCallbacks def _send_request(self, service, events):
def _send_request(self, service): # send request and add callbacks
if service.id in self.requests_in_flight: d = self.txn_ctrl.send(service, events)
return d.addBoth(self._on_request_finish)
d.addErrback(self._on_request_fail)
self.pending_requests[service.id] = d
self.requests_in_flight.add(service.id) def _on_request_finish(self, service):
try: self.pending_requests[service.id] = None
while True: # if there are queued events, then send them.
events = self.queued_events.pop(service.id, []) if (service.id in self.queued_events
if not events: and len(self.queued_events[service.id]) > 0):
return self._send_request(service, self.queued_events[service.id])
self.queued_events[service.id] = []
with Measure(self.clock, "servicequeuer.send"): def _on_request_fail(self, err):
try: logger.error("AS request failed: %s", err)
yield self.txn_ctrl.send(service, events)
except:
logger.exception("AS request failed")
finally:
self.requests_in_flight.discard(service.id)
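
The left-hand _ServiceQueuer above guarantees one in-flight transaction per appservice by keeping a requests_in_flight set and draining the queue in a loop. Stripped of Twisted, the invariant looks roughly like this (send is a hypothetical blocking stand-in for txn_ctrl.send):

class ServiceQueuer(object):
    def __init__(self, send):
        self.queued_events = {}        # service_id -> [events]
        self.requests_in_flight = set()
        self.send = send

    def enqueue(self, service_id, event):
        self.queued_events.setdefault(service_id, []).append(event)
        if service_id in self.requests_in_flight:
            return                     # the active drain will pick it up
        self.requests_in_flight.add(service_id)
        try:
            while True:
                events = self.queued_events.pop(service_id, [])
                if not events:
                    return
                self.send(service_id, events)
        finally:
            self.requests_in_flight.discard(service_id)
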
class _TransactionController(object): class _TransactionController(object):
@@ -156,6 +155,8 @@ class _TransactionController(object):
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
self._start_recoverer(service) self._start_recoverer(service)
# request has finished
defer.returnValue(service)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_recovered(self, recoverer): def on_recovered(self, recoverer):


@@ -28,7 +28,6 @@ class AppServiceConfig(Config):
def read_config(self, config): def read_config(self, config):
self.app_service_config_files = config.get("app_service_config_files", []) self.app_service_config_files = config.get("app_service_config_files", [])
self.notify_appservices = config.get("notify_appservices", True)
def default_config(cls, **kwargs): def default_config(cls, **kwargs):
return """\ return """\
@@ -123,15 +122,6 @@ def _load_appservice(hostname, as_info, config_filename):
raise ValueError( raise ValueError(
"Missing/bad type 'exclusive' key in %s", regex_obj "Missing/bad type 'exclusive' key in %s", regex_obj
) )
# protocols check
protocols = as_info.get("protocols")
if protocols:
# Because strings are lists in python
if isinstance(protocols, str) or not isinstance(protocols, list):
raise KeyError("Optional 'protocols' must be a list if present.")
for p in protocols:
if not isinstance(p, str):
raise KeyError("Bad value for 'protocols' item")
return ApplicationService( return ApplicationService(
token=as_info["as_token"], token=as_info["as_token"],
url=as_info["url"], url=as_info["url"],
@@ -139,5 +129,4 @@ def _load_appservice(hostname, as_info, config_filename):
hs_token=as_info["hs_token"], hs_token=as_info["hs_token"],
sender=user_id, sender=user_id,
id=as_info["id"], id=as_info["id"],
protocols=protocols,
) )


@@ -77,12 +77,10 @@ class SynapseKeyClientProtocol(HTTPClient):
def __init__(self): def __init__(self):
self.remote_key = defer.Deferred() self.remote_key = defer.Deferred()
self.host = None self.host = None
self._peer = None
def connectionMade(self): def connectionMade(self):
self._peer = self.transport.getPeer() self.host = self.transport.getHost()
logger.debug("Connected to %s", self._peer) logger.debug("Connected to %s", self.host)
self.sendCommand(b"GET", self.path) self.sendCommand(b"GET", self.path)
if self.host: if self.host:
self.sendHeader(b"Host", self.host) self.sendHeader(b"Host", self.host)
@@ -126,10 +124,7 @@ class SynapseKeyClientProtocol(HTTPClient):
self.timer.cancel() self.timer.cancel()
def on_timeout(self): def on_timeout(self):
logger.debug( logger.debug("Timeout waiting for response from %s", self.host)
"Timeout waiting for response from %s: %s",
self.host, self._peer,
)
self.errback(IOError("Timeout waiting for response")) self.errback(IOError("Timeout waiting for response"))
self.transport.abortConnection() self.transport.abortConnection()
@@ -138,5 +133,4 @@ class SynapseKeyClientFactory(Factory):
def protocol(self): def protocol(self):
protocol = SynapseKeyClientProtocol() protocol = SynapseKeyClientProtocol()
protocol.path = self.path protocol.path = self.path
protocol.host = self.host
return protocol return protocol


@@ -22,7 +22,6 @@ from synapse.util.logcontext import (
preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext,
preserve_fn preserve_fn
) )
from synapse.util.metrics import Measure
from twisted.internet import defer from twisted.internet import defer
@@ -45,25 +44,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
VerifyKeyRequest = namedtuple("VerifyKeyRequest", ( KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids"))
"server_name", "key_ids", "json_object", "deferred"
))
"""
A request for a verify key to verify a JSON object.
Attributes:
server_name(str): The name of the server to verify against.
key_ids(set(str)): The set of key_ids to that could be used to verify the
JSON object
json_object(dict): The JSON object to verify.
deferred(twisted.internet.defer.Deferred):
A deferred (server_name, key_id, verify_key) tuple that resolves when
a verify key has been fetched
"""
class KeyLookupError(ValueError):
pass
class Keyring(object): class Keyring(object):
@@ -93,32 +74,39 @@ class Keyring(object):
list of deferreds indicating success or failure to verify each list of deferreds indicating success or failure to verify each
json object's signature for the given server_name. json object's signature for the given server_name.
""" """
verify_requests = [] group_id_to_json = {}
group_id_to_group = {}
group_ids = []
next_group_id = 0
deferreds = {}
for server_name, json_object in server_and_json: for server_name, json_object in server_and_json:
logger.debug("Verifying for %s", server_name) logger.debug("Verifying for %s", server_name)
group_id = next_group_id
next_group_id += 1
group_ids.append(group_id)
key_ids = signature_ids(json_object, server_name) key_ids = signature_ids(json_object, server_name)
if not key_ids: if not key_ids:
deferred = defer.fail(SynapseError( deferreds[group_id] = defer.fail(SynapseError(
400, 400,
"Not signed with a supported algorithm", "Not signed with a supported algorithm",
Codes.UNAUTHORIZED, Codes.UNAUTHORIZED,
)) ))
else: else:
deferred = defer.Deferred() deferreds[group_id] = defer.Deferred()
verify_request = VerifyKeyRequest( group = KeyGroup(server_name, group_id, key_ids)
server_name, key_ids, json_object, deferred
)
verify_requests.append(verify_request) group_id_to_group[group_id] = group
group_id_to_json[group_id] = json_object
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_key_deferred(verify_request): def handle_key_deferred(group, deferred):
server_name = verify_request.server_name server_name = group.server_name
try: try:
_, key_id, verify_key = yield verify_request.deferred _, _, key_id, verify_key = yield deferred
except IOError as e: except IOError as e:
logger.warn( logger.warn(
"Got IOError when downloading keys for %s: %s %s", "Got IOError when downloading keys for %s: %s %s",
@@ -140,7 +128,7 @@ class Keyring(object):
Codes.UNAUTHORIZED, Codes.UNAUTHORIZED,
) )
json_object = verify_request.json_object json_object = group_id_to_json[group.group_id]
try: try:
verify_signed_json(json_object, server_name, verify_key) verify_signed_json(json_object, server_name, verify_key)
@@ -169,34 +157,36 @@ class Keyring(object):
# Actually start fetching keys. # Actually start fetching keys.
wait_on_deferred.addBoth( wait_on_deferred.addBoth(
lambda _: self.get_server_verify_keys(verify_requests) lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
) )
# When we've finished fetching all the keys for a given server_name, # When we've finished fetching all the keys for a given server_name,
# resolve the deferred passed to `wait_for_previous_lookups` so that # resolve the deferred passed to `wait_for_previous_lookups` so that
# any lookups waiting will proceed. # any lookups waiting will proceed.
server_to_request_ids = {} server_to_gids = {}
def remove_deferreds(res, server_name, verify_request): def remove_deferreds(res, server_name, group_id):
request_id = id(verify_request) server_to_gids[server_name].discard(group_id)
server_to_request_ids[server_name].discard(request_id) if not server_to_gids[server_name]:
if not server_to_request_ids[server_name]:
d = server_to_deferred.pop(server_name, None) d = server_to_deferred.pop(server_name, None)
if d: if d:
d.callback(None) d.callback(None)
return res return res
for verify_request in verify_requests: for g_id, deferred in deferreds.items():
server_name = verify_request.server_name server_name = group_id_to_group[g_id].server_name
request_id = id(verify_request) server_to_gids.setdefault(server_name, set()).add(g_id)
server_to_request_ids.setdefault(server_name, set()).add(request_id) deferred.addBoth(remove_deferreds, server_name, g_id)
deferred.addBoth(remove_deferreds, server_name, verify_request)
# Pass those keys to handle_key_deferred so that the json object # Pass those keys to handle_key_deferred so that the json object
# signatures can be verified # signatures can be verified
return [ return [
preserve_context_over_fn(handle_key_deferred, verify_request) preserve_context_over_fn(
for verify_request in verify_requests handle_key_deferred,
group_id_to_group[g_id],
deferreds[g_id],
)
for g_id in group_ids
] ]
@defer.inlineCallbacks @defer.inlineCallbacks
@@ -230,7 +220,7 @@ class Keyring(object):
d.addBoth(rm, server_name) d.addBoth(rm, server_name)
def get_server_verify_keys(self, verify_requests): def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
"""Takes a dict of KeyGroups and tries to find at least one key for """Takes a dict of KeyGroups and tries to find at least one key for
each group. each group.
""" """
@@ -244,68 +234,65 @@ class Keyring(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def do_iterations(): def do_iterations():
with Measure(self.clock, "get_server_verify_keys"): merged_results = {}
merged_results = {}
missing_keys = {} missing_keys = {}
for verify_request in verify_requests: for group in group_id_to_group.values():
missing_keys.setdefault(verify_request.server_name, set()).update( missing_keys.setdefault(group.server_name, set()).update(
verify_request.key_ids group.key_ids
)
for fn in key_fetch_fns:
results = yield fn(missing_keys.items())
merged_results.update(results)
# We now need to figure out which groups we have keys for
# and which we don't
missing_groups = {}
for group in group_id_to_group.values():
for key_id in group.key_ids:
if key_id in merged_results[group.server_name]:
with PreserveLoggingContext():
group_id_to_deferred[group.group_id].callback((
group.group_id,
group.server_name,
key_id,
merged_results[group.server_name][key_id],
))
break
else:
missing_groups.setdefault(
group.server_name, []
).append(group)
if not missing_groups:
break
missing_keys = {
server_name: set(
key_id for group in groups for key_id in group.key_ids
) )
for server_name, groups in missing_groups.items()
}
for fn in key_fetch_fns: for group in missing_groups.values():
results = yield fn(missing_keys.items()) group_id_to_deferred[group.group_id].errback(SynapseError(
merged_results.update(results) 401,
"No key for %s with id %s" % (
# We now need to figure out which verify requests we have keys group.server_name, group.key_ids,
# for and which we don't ),
missing_keys = {} Codes.UNAUTHORIZED,
requests_missing_keys = [] ))
for verify_request in verify_requests:
server_name = verify_request.server_name
result_keys = merged_results[server_name]
if verify_request.deferred.called:
# We've already called this deferred, which probably
# means that we've already found a key for it.
continue
for key_id in verify_request.key_ids:
if key_id in result_keys:
with PreserveLoggingContext():
verify_request.deferred.callback((
server_name,
key_id,
result_keys[key_id],
))
break
else:
# The else block is only reached if the loop above
# doesn't break.
missing_keys.setdefault(server_name, set()).update(
verify_request.key_ids
)
requests_missing_keys.append(verify_request)
if not missing_keys:
break
for verify_request in requests_missing_keys:
verify_request.deferred.errback(SynapseError(
401,
"No key for %s with id %s" % (
verify_request.server_name, verify_request.key_ids,
),
Codes.UNAUTHORIZED,
))
def on_err(err): def on_err(err):
for verify_request in verify_requests: for deferred in group_id_to_deferred.values():
if not verify_request.deferred.called: if not deferred.called:
verify_request.deferred.errback(err) deferred.errback(err)
do_iterations().addErrback(on_err) do_iterations().addErrback(on_err)
return group_id_to_deferred
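
The do_iterations body above tries each key-fetch function in turn, resolving a request as soon as any of its key ids turns up and carrying only the still-missing ids to the next fetcher. The control flow, reduced to plain Python (all names here are illustrative, not Synapse's API):

def fetch_keys(groups, fetchers):
    # groups: objects with .server_name and .key_ids; each fetcher takes
    # [(server_name, key_ids)] and returns {server_name: {key_id: key}}
    results = {}
    resolved = {}
    missing = {}
    for g in groups:
        missing.setdefault(g.server_name, set()).update(g.key_ids)
    for fetch in fetchers:
        results.update(fetch(missing.items()))
        missing = {}
        for g in groups:
            if id(g) in resolved:
                continue
            found = next((k for k in g.key_ids
                          if k in results.get(g.server_name, {})), None)
            if found is not None:
                resolved[id(g)] = results[g.server_name][found]
            else:
                missing.setdefault(g.server_name, set()).update(g.key_ids)
        if not missing:
            break
    return resolved, missing  # non-empty missing -> errback those requests
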
@defer.inlineCallbacks @defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids): def get_keys_from_store(self, server_name_and_key_ids):
res = yield defer.gatherResults( res = yield defer.gatherResults(
@@ -369,7 +356,7 @@ class Keyring(object):
) )
except Exception as e: except Exception as e:
logger.info( logger.info(
"Unable to get key %r for %r directly: %s %s", "Unable to getting key %r for %r directly: %s %s",
key_ids, server_name, key_ids, server_name,
type(e).__name__, str(e.message), type(e).__name__, str(e.message),
) )
@@ -431,7 +418,7 @@ class Keyring(object):
for response in responses: for response in responses:
if (u"signatures" not in response if (u"signatures" not in response
or perspective_name not in response[u"signatures"]): or perspective_name not in response[u"signatures"]):
raise KeyLookupError( raise ValueError(
"Key response not signed by perspective server" "Key response not signed by perspective server"
" %r" % (perspective_name,) " %r" % (perspective_name,)
) )
@@ -454,13 +441,13 @@ class Keyring(object):
list(response[u"signatures"][perspective_name]), list(response[u"signatures"][perspective_name]),
list(perspective_keys) list(perspective_keys)
) )
raise KeyLookupError( raise ValueError(
"Response not signed with a known key for perspective" "Response not signed with a known key for perspective"
" server %r" % (perspective_name,) " server %r" % (perspective_name,)
) )
processed_response = yield self.process_v2_response( processed_response = yield self.process_v2_response(
perspective_name, response, only_from_server=False perspective_name, response
) )
for server_name, response_keys in processed_response.items(): for server_name, response_keys in processed_response.items():
@@ -497,10 +484,10 @@ class Keyring(object):
if (u"signatures" not in response if (u"signatures" not in response
or server_name not in response[u"signatures"]): or server_name not in response[u"signatures"]):
raise KeyLookupError("Key response not signed by remote server") raise ValueError("Key response not signed by remote server")
if "tls_fingerprints" not in response: if "tls_fingerprints" not in response:
raise KeyLookupError("Key response missing TLS fingerprints") raise ValueError("Key response missing TLS fingerprints")
certificate_bytes = crypto.dump_certificate( certificate_bytes = crypto.dump_certificate(
crypto.FILETYPE_ASN1, tls_certificate crypto.FILETYPE_ASN1, tls_certificate
@@ -514,7 +501,7 @@ class Keyring(object):
response_sha256_fingerprints.add(fingerprint[u"sha256"]) response_sha256_fingerprints.add(fingerprint[u"sha256"])
if sha256_fingerprint_b64 not in response_sha256_fingerprints: if sha256_fingerprint_b64 not in response_sha256_fingerprints:
raise KeyLookupError("TLS certificate not allowed by fingerprints") raise ValueError("TLS certificate not allowed by fingerprints")
response_keys = yield self.process_v2_response( response_keys = yield self.process_v2_response(
from_server=server_name, from_server=server_name,
@@ -540,7 +527,7 @@ class Keyring(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def process_v2_response(self, from_server, response_json, def process_v2_response(self, from_server, response_json,
requested_ids=[], only_from_server=True): requested_ids=[]):
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()
response_keys = {} response_keys = {}
verify_keys = {} verify_keys = {}
@@ -564,16 +551,9 @@ class Keyring(object):
results = {} results = {}
server_name = response_json["server_name"] server_name = response_json["server_name"]
if only_from_server:
if server_name != from_server:
raise KeyLookupError(
"Expected a response for server %r not %r" % (
from_server, server_name
)
)
for key_id in response_json["signatures"].get(server_name, {}): for key_id in response_json["signatures"].get(server_name, {}):
if key_id not in response_json["verify_keys"]: if key_id not in response_json["verify_keys"]:
raise KeyLookupError( raise ValueError(
"Key response must include verification keys for all" "Key response must include verification keys for all"
" signatures" " signatures"
) )
@@ -641,15 +621,15 @@ class Keyring(object):
if ("signatures" not in response if ("signatures" not in response
or server_name not in response["signatures"]): or server_name not in response["signatures"]):
raise KeyLookupError("Key response not signed by remote server") raise ValueError("Key response not signed by remote server")
if "tls_certificate" not in response: if "tls_certificate" not in response:
raise KeyLookupError("Key response missing TLS certificate") raise ValueError("Key response missing TLS certificate")
tls_certificate_b64 = response["tls_certificate"] tls_certificate_b64 = response["tls_certificate"]
if encode_base64(x509_certificate_bytes) != tls_certificate_b64: if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
raise KeyLookupError("TLS certificate doesn't match") raise ValueError("TLS certificate doesn't match")
# Cache the result in the datastore. # Cache the result in the datastore.
@@ -665,7 +645,7 @@ class Keyring(object):
for key_id in response["signatures"][server_name]: for key_id in response["signatures"][server_name]:
if key_id not in response["verify_keys"]: if key_id not in response["verify_keys"]:
raise KeyLookupError( raise ValueError(
"Key response must include verification keys for all" "Key response must include verification keys for all"
" signatures" " signatures"
) )


@@ -88,8 +88,6 @@ def prune_event(event):
if "age_ts" in event.unsigned: if "age_ts" in event.unsigned:
allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"] allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"]
if "replaces_state" in event.unsigned:
allowed_fields["unsigned"]["replaces_state"] = event.unsigned["replaces_state"]
return type(event)( return type(event)(
allowed_fields, allowed_fields,


@@ -51,34 +51,10 @@ sent_edus_counter = metrics.register_counter("sent_edus")
sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"]) sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"])
PDU_RETRY_TIME_MS = 1 * 60 * 1000
class FederationClient(FederationBase): class FederationClient(FederationBase):
def __init__(self, hs): def __init__(self, hs):
super(FederationClient, self).__init__(hs) super(FederationClient, self).__init__(hs)
self.pdu_destination_tried = {}
self._clock.looping_call(
self._clear_tried_cache, 60 * 1000,
)
def _clear_tried_cache(self):
"""Clear pdu_destination_tried cache"""
now = self._clock.time_msec()
old_dict = self.pdu_destination_tried
self.pdu_destination_tried = {}
for event_id, destination_dict in old_dict.items():
destination_dict = {
dest: time
for dest, time in destination_dict.items()
if time + PDU_RETRY_TIME_MS > now
}
if destination_dict:
self.pdu_destination_tried[event_id] = destination_dict
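
The retry cache above remembers when each destination was last asked for a given PDU and is swept once a minute, so entries age out after PDU_RETRY_TIME_MS. The sweep as a standalone function, mirroring _clear_tried_cache (the dict argument is an illustrative stand-in for the instance attribute):

PDU_RETRY_TIME_MS = 1 * 60 * 1000

def clear_tried_cache(pdu_destination_tried, now_ms):
    # drop destination -> last_attempt_ms entries older than the window
    for event_id, dests in list(pdu_destination_tried.items()):
        fresh = {dest: t for dest, t in dests.items()
                 if t + PDU_RETRY_TIME_MS > now_ms}
        if fresh:
            pdu_destination_tried[event_id] = fresh
        else:
            del pdu_destination_tried[event_id]
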
def start_get_pdu_cache(self): def start_get_pdu_cache(self):
self._get_pdu_cache = ExpiringCache( self._get_pdu_cache = ExpiringCache(
cache_name="get_pdu_cache", cache_name="get_pdu_cache",
@@ -260,19 +236,12 @@ class FederationClient(FederationBase):
         # TODO: Rate limit the number of times we try and get the same event.

         if self._get_pdu_cache:
-            ev = self._get_pdu_cache.get(event_id)
-            if ev:
-                defer.returnValue(ev)
-
-        pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
+            e = self._get_pdu_cache.get(event_id)
+            if e:
+                defer.returnValue(e)

         pdu = None
         for destination in destinations:
-            now = self._clock.time_msec()
-            last_attempt = pdu_attempts.get(destination, 0)
-            if last_attempt + PDU_RETRY_TIME_MS > now:
-                continue
-
             try:
                 limiter = yield get_retry_limiter(
                     destination,
@@ -300,19 +269,25 @@ class FederationClient(FederationBase):
                     break

-                pdu_attempts[destination] = now
-            except SynapseError as e:
+            except SynapseError:
                 logger.info(
                     "Failed to get PDU %s from %s because %s",
                     event_id, destination, e,
                 )
-                continue
-            except CodeMessageException as e:
-                if 400 <= e.code < 500:
-                    raise
-
-                logger.info(
-                    "Failed to get PDU %s from %s because %s",
-                    event_id, destination, e,
-                )
-                continue
             except NotRetryingDestination as e:
                 logger.info(e.message)
                 continue
             except Exception as e:
-                pdu_attempts[destination] = now
-
                 logger.info(
                     "Failed to get PDU %s from %s because %s",
                     event_id, destination, e,
@@ -339,42 +314,6 @@ class FederationClient(FederationBase):
             Deferred: Results in a list of PDUs.
         """
-        try:
-            # First we try and ask for just the IDs, as thats far quicker if
-            # we have most of the state and auth_chain already.
-            # However, this may 404 if the other side has an old synapse.
-            result = yield self.transport_layer.get_room_state_ids(
-                destination, room_id, event_id=event_id,
-            )
-
-            state_event_ids = result["pdu_ids"]
-            auth_event_ids = result.get("auth_chain_ids", [])
-
-            fetched_events, failed_to_fetch = yield self.get_events(
-                [destination], room_id, set(state_event_ids + auth_event_ids)
-            )
-
-            if failed_to_fetch:
-                logger.warn("Failed to get %r", failed_to_fetch)
-
-            event_map = {
-                ev.event_id: ev for ev in fetched_events
-            }
-
-            pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
-            auth_chain = [
-                event_map[e_id] for e_id in auth_event_ids if e_id in event_map
-            ]
-
-            auth_chain.sort(key=lambda e: e.depth)
-
-            defer.returnValue((pdus, auth_chain))
-        except HttpResponseException as e:
-            if e.code == 400 or e.code == 404:
-                logger.info("Failed to use get_room_state_ids API, falling back")
-            else:
-                raise e
-
         result = yield self.transport_layer.get_room_state(
             destination, room_id, event_id=event_id,
         )
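The removed block tries the cheaper /state_ids/ federation API first and falls back to the full /state/ API when the remote answers 400 or 404 (i.e. an older synapse that doesn't know the route). A hedged sketch of that fallback shape using the requests library and hypothetical URLs:

import requests

def get_room_state(base_url, room_id, event_id):
    try:
        # Fast path: ask only for event ids; far cheaper if we already
        # hold most of the state locally.
        resp = requests.get(
            "%s/state_ids/%s/" % (base_url, room_id),
            params={"event_id": event_id},
        )
        resp.raise_for_status()
        return resp.json()  # {"pdu_ids": [...], "auth_chain_ids": [...]}
    except requests.HTTPError as e:
        if e.response.status_code not in (400, 404):
            raise
        # Old server: fall back to fetching the full state events.
    resp = requests.get(
        "%s/state/%s/" % (base_url, room_id),
        params={"event_id": event_id},
    )
    resp.raise_for_status()
    return resp.json()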
@@ -388,93 +327,18 @@ class FederationClient(FederationBase):
             for p in result.get("auth_chain", [])
         ]

-        seen_events = yield self.store.get_events([
-            ev.event_id for ev in itertools.chain(pdus, auth_chain)
-        ])
-
         signed_pdus = yield self._check_sigs_and_hash_and_fetch(
-            destination,
-            [p for p in pdus if p.event_id not in seen_events],
-            outlier=True
-        )
-        signed_pdus.extend(
-            seen_events[p.event_id] for p in pdus if p.event_id in seen_events
+            destination, pdus, outlier=True
         )

         signed_auth = yield self._check_sigs_and_hash_and_fetch(
-            destination,
-            [p for p in auth_chain if p.event_id not in seen_events],
-            outlier=True
-        )
-        signed_auth.extend(
-            seen_events[p.event_id] for p in auth_chain if p.event_id in seen_events
+            destination, auth_chain, outlier=True
         )

         signed_auth.sort(key=lambda e: e.depth)

         defer.returnValue((signed_pdus, signed_auth))

-    @defer.inlineCallbacks
-    def get_events(self, destinations, room_id, event_ids, return_local=True):
-        """Fetch events from some remote destinations, checking if we already
-        have them.
-
-        Args:
-            destinations (list)
-            room_id (str)
-            event_ids (list)
-            return_local (bool): Whether to include events we already have in
-                the DB in the returned list of events
-
-        Returns:
-            Deferred: A deferred resolving to a 2-tuple where the first is a list of
-            events and the second is a list of event ids that we failed to fetch.
-        """
-        if return_local:
-            seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
-            signed_events = seen_events.values()
-        else:
-            seen_events = yield self.store.have_events(event_ids)
-            signed_events = []
-
-        failed_to_fetch = set()
-
-        missing_events = set(event_ids)
-        for k in seen_events:
-            missing_events.discard(k)
-
-        if not missing_events:
-            defer.returnValue((signed_events, failed_to_fetch))
-
-        def random_server_list():
-            srvs = list(destinations)
-            random.shuffle(srvs)
-            return srvs
-
-        batch_size = 20
-        missing_events = list(missing_events)
-        for i in xrange(0, len(missing_events), batch_size):
-            batch = set(missing_events[i:i + batch_size])
-
-            deferreds = [
-                self.get_pdu(
-                    destinations=random_server_list(),
-                    event_id=e_id,
-                )
-                for e_id in batch
-            ]
-
-            res = yield defer.DeferredList(deferreds, consumeErrors=True)
-            for success, result in res:
-                if success:
-                    signed_events.append(result)
-                    batch.discard(result.event_id)
-
-            # We removed all events we successfully fetched from `batch`
-            failed_to_fetch.update(batch)
-
-        defer.returnValue((signed_events, failed_to_fetch))
-
     @defer.inlineCallbacks
     @log_function
     def get_event_auth(self, destination, room_id, event_id):
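The removed get_events helper fetches missing events in batches of 20, each from a freshly shuffled server list, and reports what it could not fetch rather than failing the whole request. A hedged, Twisted-free sketch of the same strategy (concurrent.futures stands in for DeferredList; fetch_one is a hypothetical callable):

import random
from concurrent.futures import ThreadPoolExecutor, as_completed

def fetch_events(fetch_one, destinations, event_ids, batch_size=20):
    """fetch_one(shuffled_destinations, event_id) -> event (or raises)."""
    fetched, failed = [], set()
    missing = list(event_ids)
    with ThreadPoolExecutor(max_workers=batch_size) as pool:
        for i in range(0, len(missing), batch_size):
            batch = set(missing[i:i + batch_size])
            futures = {
                pool.submit(
                    fetch_one, random.sample(destinations, len(destinations)), e_id
                ): e_id
                for e_id in batch
            }
            for fut in as_completed(futures):
                try:
                    fetched.append(fut.result())
                    batch.discard(futures[fut])
                except Exception:
                    pass  # leave the id in `batch`; it counts as failed
            # Everything successfully fetched was discarded from `batch`.
            failed.update(batch)
    return fetched, failed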
@@ -550,19 +414,14 @@ class FederationClient(FederationBase):
                     (destination, self.event_from_pdu_json(pdu_dict))
                 )
                 break
-            except CodeMessageException as e:
-                if not 500 <= e.code < 600:
-                    raise
-                else:
-                    logger.warn(
-                        "Failed to make_%s via %s: %s",
-                        membership, destination, e.message
-                    )
+            except CodeMessageException:
+                raise
             except Exception as e:
                 logger.warn(
                     "Failed to make_%s via %s: %s",
                     membership, destination, e.message
                 )
+                raise

         raise RuntimeError("Failed to send to any server.")
@@ -634,14 +493,8 @@ class FederationClient(FederationBase):
                     "auth_chain": signed_auth,
                     "origin": destination,
                 })
-            except CodeMessageException as e:
-                if not 500 <= e.code < 600:
-                    raise
-                else:
-                    logger.exception(
-                        "Failed to send_join via %s: %s",
-                        destination, e.message
-                    )
+            except CodeMessageException:
+                raise
             except Exception as e:
                 logger.exception(
                     "Failed to send_join via %s: %s",
View File

@@ -21,11 +21,10 @@ from .units import Transaction, Edu

 from synapse.util.async import Linearizer
 from synapse.util.logutils import log_function
-from synapse.util.caches.response_cache import ResponseCache
 from synapse.events import FrozenEvent
 import synapse.metrics

-from synapse.api.errors import AuthError, FederationError, SynapseError
+from synapse.api.errors import FederationError, SynapseError

 from synapse.crypto.event_signing import compute_event_signature
@@ -49,15 +48,9 @@ class FederationServer(FederationBase):
     def __init__(self, hs):
         super(FederationServer, self).__init__(hs)

-        self.auth = hs.get_auth()
-
         self._room_pdu_linearizer = Linearizer()
         self._server_linearizer = Linearizer()

-        # We cache responses to state queries, as they take a while and often
-        # come in waves.
-        self._state_resp_cache = ResponseCache(hs, timeout_ms=30000)
-
     def set_handler(self, handler):
         """Sets the handler that the replication layer will use to communicate
         receipt of new PDUs from other home servers. The required methods are
@@ -195,71 +188,33 @@ class FederationServer(FederationBase):
     @defer.inlineCallbacks
     @log_function
     def on_context_state_request(self, origin, room_id, event_id):
-        if not event_id:
-            raise NotImplementedError("Specify an event")
-
-        in_room = yield self.auth.check_host_in_room(room_id, origin)
-        if not in_room:
-            raise AuthError(403, "Host not in room.")
-
-        result = self._state_resp_cache.get((room_id, event_id))
-        if not result:
-            with (yield self._server_linearizer.queue((origin, room_id))):
-                resp = yield self._state_resp_cache.set(
-                    (room_id, event_id),
-                    self._on_context_state_request_compute(room_id, event_id)
-                )
-        else:
-            resp = yield result
-
-        defer.returnValue((200, resp))
-
-    @defer.inlineCallbacks
-    def on_state_ids_request(self, origin, room_id, event_id):
-        if not event_id:
-            raise NotImplementedError("Specify an event")
-
-        in_room = yield self.auth.check_host_in_room(room_id, origin)
-        if not in_room:
-            raise AuthError(403, "Host not in room.")
-
-        pdus = yield self.handler.get_state_for_pdu(
-            room_id, event_id,
-        )
-        auth_chain = yield self.store.get_auth_chain(
-            [pdu.event_id for pdu in pdus]
-        )
-
-        defer.returnValue((200, {
-            "pdu_ids": [pdu.event_id for pdu in pdus],
-            "auth_chain_ids": [pdu.event_id for pdu in auth_chain],
-        }))
-
-    @defer.inlineCallbacks
-    def _on_context_state_request_compute(self, room_id, event_id):
-        pdus = yield self.handler.get_state_for_pdu(
-            room_id, event_id,
-        )
-        auth_chain = yield self.store.get_auth_chain(
-            [pdu.event_id for pdu in pdus]
-        )
-
-        for event in auth_chain:
-            # We sign these again because there was a bug where we
-            # incorrectly signed things the first time round
-            if self.hs.is_mine_id(event.event_id):
-                event.signatures.update(
-                    compute_event_signature(
-                        event,
-                        self.hs.hostname,
-                        self.hs.config.signing_key[0]
-                    )
-                )
-
-        defer.returnValue({
-            "pdus": [pdu.get_pdu_json() for pdu in pdus],
-            "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
-        })
+        with (yield self._server_linearizer.queue((origin, room_id))):
+            if event_id:
+                pdus = yield self.handler.get_state_for_pdu(
+                    origin, room_id, event_id,
+                )
+                auth_chain = yield self.store.get_auth_chain(
+                    [pdu.event_id for pdu in pdus]
+                )
+
+                for event in auth_chain:
+                    # We sign these again because there was a bug where we
+                    # incorrectly signed things the first time round
+                    if self.hs.is_mine_id(event.event_id):
+                        event.signatures.update(
+                            compute_event_signature(
+                                event,
+                                self.hs.hostname,
+                                self.hs.config.signing_key[0]
+                            )
+                        )
+            else:
+                raise NotImplementedError("Specify an event")
+
+        defer.returnValue((200, {
+            "pdus": [pdu.get_pdu_json() for pdu in pdus],
+            "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
+        }))

     @defer.inlineCallbacks
     @log_function
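The removed _state_resp_cache is a ResponseCache: expensive state queries that arrive in waves share a single in-flight computation per key instead of recomputing it for every caller. A hedged asyncio sketch of the core idea (the real cache also keeps completed results around for a timeout, which this omits):

import asyncio

class ResponseCache(object):
    def __init__(self):
        self._pending = {}  # key -> in-flight task

    async def wrap(self, key, compute, *args):
        task = self._pending.get(key)
        if task is None:
            # First caller for this key starts the computation...
            task = asyncio.ensure_future(compute(*args))
            self._pending[key] = task
            task.add_done_callback(lambda _: self._pending.pop(key, None))
        # ...every concurrent caller awaits the same task.
        return await task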
@@ -393,9 +348,27 @@ class FederationServer(FederationBase):
             (200, send_content)
         )

+    @defer.inlineCallbacks
     @log_function
     def on_query_client_keys(self, origin, content):
-        return self.on_query_request("client_keys", content)
+        query = []
+        for user_id, device_ids in content.get("device_keys", {}).items():
+            if not device_ids:
+                query.append((user_id, None))
+            else:
+                for device_id in device_ids:
+                    query.append((user_id, device_id))
+
+        results = yield self.store.get_e2e_device_keys(query)
+
+        json_result = {}
+        for user_id, device_keys in results.items():
+            for device_id, json_bytes in device_keys.items():
+                json_result.setdefault(user_id, {})[device_id] = json.loads(
+                    json_bytes
+                )
+
+        defer.returnValue({"device_keys": json_result})

     @defer.inlineCallbacks
     @log_function
@@ -605,7 +578,7 @@ class FederationServer(FederationBase):
                 origin, pdu.room_id, pdu.event_id,
             )
         except:
-            logger.exception("Failed to get state for event: %s", pdu.event_id)
+            logger.warn("Failed to get state for event: %s", pdu.event_id)

         yield self.handler.on_receive_pdu(
             origin,
View File

@@ -21,11 +21,11 @@ from .units import Transaction

 from synapse.api.errors import HttpResponseException
 from synapse.util.async import run_on_reactor
-from synapse.util.logcontext import preserve_context_over_fn
+from synapse.util.logutils import log_function
+from synapse.util.logcontext import PreserveLoggingContext
 from synapse.util.retryutils import (
     get_retry_limiter, NotRetryingDestination,
 )
-from synapse.util.metrics import measure_func
 import synapse.metrics

 import logging
@@ -51,7 +51,7 @@ class TransactionQueue(object):
         self.transport_layer = transport_layer

-        self.clock = hs.get_clock()
+        self._clock = hs.get_clock()

         # Is a mapping from destinations -> deferreds. Used to keep track
         # of which destinations have transactions in flight and when they are
@@ -82,7 +82,7 @@ class TransactionQueue(object):
         self.pending_failures_by_dest = {}

         # HACK to get unique tx id
-        self._next_txn_id = int(self.clock.time_msec())
+        self._next_txn_id = int(self._clock.time_msec())

     def can_send_to(self, destination):
         """Can we send messages to the given server?
@@ -119,215 +119,266 @@ class TransactionQueue(object):
         if not destinations:
             return

+        deferreds = []
+
         for destination in destinations:
+            deferred = defer.Deferred()
             self.pending_pdus_by_dest.setdefault(destination, []).append(
-                (pdu, order)
+                (pdu, deferred, order)
             )

-            preserve_context_over_fn(
-                self._attempt_new_transaction, destination
-            )
+            def chain(failure):
+                if not deferred.called:
+                    deferred.errback(failure)
+
+            def log_failure(f):
+                logger.warn("Failed to send pdu to %s: %s", destination, f.value)
+
+            deferred.addErrback(log_failure)
+
+            with PreserveLoggingContext():
+                self._attempt_new_transaction(destination).addErrback(chain)
+
+            deferreds.append(deferred)

+    # NO inlineCallbacks
     def enqueue_edu(self, edu):
         destination = edu.destination

         if not self.can_send_to(destination):
             return

-        self.pending_edus_by_dest.setdefault(destination, []).append(edu)
-
-        preserve_context_over_fn(
-            self._attempt_new_transaction, destination
-        )
+        deferred = defer.Deferred()
+        self.pending_edus_by_dest.setdefault(destination, []).append(
+            (edu, deferred)
+        )
+
+        def chain(failure):
+            if not deferred.called:
+                deferred.errback(failure)
+
+        def log_failure(f):
+            logger.warn("Failed to send edu to %s: %s", destination, f.value)
+
+        deferred.addErrback(log_failure)
+
+        with PreserveLoggingContext():
+            self._attempt_new_transaction(destination).addErrback(chain)
+
+        return deferred

+    @defer.inlineCallbacks
     def enqueue_failure(self, failure, destination):
         if destination == self.server_name or destination == "localhost":
             return

+        deferred = defer.Deferred()
+
         if not self.can_send_to(destination):
             return

         self.pending_failures_by_dest.setdefault(
             destination, []
-        ).append(failure)
-
-        preserve_context_over_fn(
-            self._attempt_new_transaction, destination
-        )
+        ).append(
+            (failure, deferred)
+        )
+
+        def chain(f):
+            if not deferred.called:
+                deferred.errback(f)
+
+        def log_failure(f):
+            logger.warn("Failed to send failure to %s: %s", destination, f.value)
+
+        deferred.addErrback(log_failure)
+
+        with PreserveLoggingContext():
+            self._attempt_new_transaction(destination).addErrback(chain)
+
+        yield deferred

     @defer.inlineCallbacks
-    @log_function
     def _attempt_new_transaction(self, destination):
         yield run_on_reactor()
-        while True:
-            if destination in self.pending_transactions:
-                # XXX: pending_transactions can get stuck on by a never-ending
-                # request at which point pending_pdus_by_dest just keeps growing.
-                # we need application-layer timeouts of some flavour of these
-                # requests
-                logger.debug(
-                    "TX [%s] Transaction already in progress",
-                    destination
-                )
-                return
-
-            pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
-            pending_edus = self.pending_edus_by_dest.pop(destination, [])
-            pending_failures = self.pending_failures_by_dest.pop(destination, [])
-
-            if pending_pdus:
-                logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
-                             destination, len(pending_pdus))
-
-            if not pending_pdus and not pending_edus and not pending_failures:
-                logger.debug("TX [%s] Nothing to send", destination)
-                return
-
-            yield self._send_new_transaction(
-                destination, pending_pdus, pending_edus, pending_failures
-            )
-
-    @measure_func("_send_new_transaction")
-    @defer.inlineCallbacks
-    def _send_new_transaction(self, destination, pending_pdus, pending_edus,
-                              pending_failures):
-
-        # Sort based on the order field
-        pending_pdus.sort(key=lambda t: t[1])
-        pdus = [x[0] for x in pending_pdus]
-        edus = pending_edus
-        failures = [x.get_dict() for x in pending_failures]
-
-        try:
-            self.pending_transactions[destination] = 1
-
-            logger.debug("TX [%s] _attempt_new_transaction", destination)
-
-            txn_id = str(self._next_txn_id)
-
-            limiter = yield get_retry_limiter(
-                destination,
-                self.clock,
-                self.store,
-            )
-
-            logger.debug(
-                "TX [%s] {%s} Attempting new transaction"
-                " (pdus: %d, edus: %d, failures: %d)",
-                destination, txn_id,
-                len(pending_pdus),
-                len(pending_edus),
-                len(pending_failures)
-            )
-
-            logger.debug("TX [%s] Persisting transaction...", destination)
-
-            transaction = Transaction.create_new(
-                origin_server_ts=int(self.clock.time_msec()),
-                transaction_id=txn_id,
-                origin=self.server_name,
-                destination=destination,
-                pdus=pdus,
-                edus=edus,
-                pdu_failures=failures,
-            )
-
-            self._next_txn_id += 1
-
-            yield self.transaction_actions.prepare_to_send(transaction)
-
-            logger.debug("TX [%s] Persisted transaction", destination)
-            logger.info(
-                "TX [%s] {%s} Sending transaction [%s],"
-                " (PDUs: %d, EDUs: %d, failures: %d)",
-                destination, txn_id,
-                transaction.transaction_id,
-                len(pending_pdus),
-                len(pending_edus),
-                len(pending_failures),
-            )
-
-            with limiter:
-                # Actually send the transaction
-
-                # FIXME (erikj): This is a bit of a hack to make the Pdu age
-                # keys work
-                def json_data_cb():
-                    data = transaction.get_dict()
-                    now = int(self.clock.time_msec())
-                    if "pdus" in data:
-                        for p in data["pdus"]:
-                            if "age_ts" in p:
-                                unsigned = p.setdefault("unsigned", {})
-                                unsigned["age"] = now - int(p["age_ts"])
-                                del p["age_ts"]
-                    return data
-
-                try:
-                    response = yield self.transport_layer.send_transaction(
-                        transaction, json_data_cb
-                    )
-                    code = 200
-
-                    if response:
-                        for e_id, r in response.get("pdus", {}).items():
-                            if "error" in r:
-                                logger.warn(
-                                    "Transaction returned error for %s: %s",
-                                    e_id, r,
-                                )
-                except HttpResponseException as e:
-                    code = e.code
-                    response = e.response
-
-            logger.info(
-                "TX [%s] {%s} got %d response",
-                destination, txn_id, code
-            )
-
-            logger.debug("TX [%s] Sent transaction", destination)
-            logger.debug("TX [%s] Marking as delivered...", destination)
-
-            yield self.transaction_actions.delivered(
-                transaction, code, response
-            )
-
-            logger.debug("TX [%s] Marked as delivered", destination)
-
-            if code != 200:
-                for p in pdus:
-                    logger.info(
-                        "Failed to send event %s to %s", p.event_id, destination
-                    )
-        except NotRetryingDestination:
-            logger.info(
-                "TX [%s] not ready for retry yet - "
-                "dropping transaction for now",
-                destination,
-            )
-        except RuntimeError as e:
-            # We capture this here as there as nothing actually listens
-            # for this finishing functions deferred.
-            logger.warn(
-                "TX [%s] Problem in _attempt_transaction: %s",
-                destination,
-                e,
-            )
-
-            for p in pdus:
-                logger.info("Failed to send event %s to %s", p.event_id, destination)
-        except Exception as e:
-            # We capture this here as there as nothing actually listens
-            # for this finishing functions deferred.
-            logger.warn(
-                "TX [%s] Problem in _attempt_transaction: %s",
-                destination,
-                e,
-            )
-
-            for p in pdus:
-                logger.info("Failed to send event %s to %s", p.event_id, destination)
-        finally:
-            # We want to be *very* sure we delete this after we stop processing
-            self.pending_transactions.pop(destination, None)
+
+        # list of (pending_pdu, deferred, order)
+        if destination in self.pending_transactions:
+            # XXX: pending_transactions can get stuck on by a never-ending
+            # request at which point pending_pdus_by_dest just keeps growing.
+            # we need application-layer timeouts of some flavour of these
+            # requests
+            logger.debug(
+                "TX [%s] Transaction already in progress",
+                destination
+            )
+            return
+
+        pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
+        pending_edus = self.pending_edus_by_dest.pop(destination, [])
+        pending_failures = self.pending_failures_by_dest.pop(destination, [])
+
+        if pending_pdus:
+            logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
+                         destination, len(pending_pdus))
+
+        if not pending_pdus and not pending_edus and not pending_failures:
+            logger.debug("TX [%s] Nothing to send", destination)
+            return
+
+        try:
+            self.pending_transactions[destination] = 1
+
+            logger.debug("TX [%s] _attempt_new_transaction", destination)
+
+            # Sort based on the order field
+            pending_pdus.sort(key=lambda t: t[2])
+            pdus = [x[0] for x in pending_pdus]
+            edus = [x[0] for x in pending_edus]
+            failures = [x[0].get_dict() for x in pending_failures]
+            deferreds = [
+                x[1]
+                for x in pending_pdus + pending_edus + pending_failures
+            ]
+
+            txn_id = str(self._next_txn_id)
+
+            limiter = yield get_retry_limiter(
+                destination,
+                self._clock,
+                self.store,
+            )
+
+            logger.debug(
+                "TX [%s] {%s} Attempting new transaction"
+                " (pdus: %d, edus: %d, failures: %d)",
+                destination, txn_id,
+                len(pending_pdus),
+                len(pending_edus),
+                len(pending_failures)
+            )
+
+            logger.debug("TX [%s] Persisting transaction...", destination)
+
+            transaction = Transaction.create_new(
+                origin_server_ts=int(self._clock.time_msec()),
+                transaction_id=txn_id,
+                origin=self.server_name,
+                destination=destination,
+                pdus=pdus,
+                edus=edus,
+                pdu_failures=failures,
+            )
+
+            self._next_txn_id += 1
+
+            yield self.transaction_actions.prepare_to_send(transaction)
+
+            logger.debug("TX [%s] Persisted transaction", destination)
+            logger.info(
+                "TX [%s] {%s} Sending transaction [%s],"
+                " (PDUs: %d, EDUs: %d, failures: %d)",
+                destination, txn_id,
+                transaction.transaction_id,
+                len(pending_pdus),
+                len(pending_edus),
+                len(pending_failures),
+            )
+
+            with limiter:
+                # Actually send the transaction
+
+                # FIXME (erikj): This is a bit of a hack to make the Pdu age
+                # keys work
+                def json_data_cb():
+                    data = transaction.get_dict()
+                    now = int(self._clock.time_msec())
+                    if "pdus" in data:
+                        for p in data["pdus"]:
+                            if "age_ts" in p:
+                                unsigned = p.setdefault("unsigned", {})
+                                unsigned["age"] = now - int(p["age_ts"])
+                                del p["age_ts"]
+                    return data
+
+                try:
+                    response = yield self.transport_layer.send_transaction(
+                        transaction, json_data_cb
+                    )
+                    code = 200
+
+                    if response:
+                        for e_id, r in response.get("pdus", {}).items():
+                            if "error" in r:
+                                logger.warn(
+                                    "Transaction returned error for %s: %s",
+                                    e_id, r,
+                                )
+                except HttpResponseException as e:
+                    code = e.code
+                    response = e.response
+
+                logger.info(
+                    "TX [%s] {%s} got %d response",
+                    destination, txn_id, code
+                )
+
+                logger.debug("TX [%s] Sent transaction", destination)
+                logger.debug("TX [%s] Marking as delivered...", destination)
+
+            yield self.transaction_actions.delivered(
+                transaction, code, response
+            )
+
+            logger.debug("TX [%s] Marked as delivered", destination)
+
+            logger.debug("TX [%s] Yielding to callbacks...", destination)
+
+            for deferred in deferreds:
+                if code == 200:
+                    deferred.callback(None)
+                else:
+                    deferred.errback(RuntimeError("Got status %d" % code))
+
+                # Ensures we don't continue until all callbacks on that
+                # deferred have fired
+                try:
+                    yield deferred
+                except:
+                    pass
+
+            logger.debug("TX [%s] Yielded to callbacks", destination)
+        except NotRetryingDestination:
+            logger.info(
+                "TX [%s] not ready for retry yet - "
+                "dropping transaction for now",
+                destination,
+            )
+        except RuntimeError as e:
+            # We capture this here as there as nothing actually listens
+            # for this finishing functions deferred.
+            logger.warn(
+                "TX [%s] Problem in _attempt_transaction: %s",
+                destination,
+                e,
+            )
+        except Exception as e:
+            # We capture this here as there as nothing actually listens
+            # for this finishing functions deferred.
+            logger.warn(
+                "TX [%s] Problem in _attempt_transaction: %s",
+                destination,
+                e,
+            )
+
+            for deferred in deferreds:
+                if not deferred.called:
+                    deferred.errback(e)
+        finally:
+            # We want to be *very* sure we delete this after we stop processing
+            self.pending_transactions.pop(destination, None)
+
+            # Check to see if there is anything else to send.
+            self._attempt_new_transaction(destination)
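The json_data_cb hack in both versions above converts each PDU's absolute "age_ts" into a relative "age" at the moment the transaction is serialised, so the receiver learns how stale the event is regardless of how long it sat in the queue. A standalone sketch of just that conversion:

import time

def render_transaction(data):
    now = int(time.time() * 1000)
    for p in data.get("pdus", []):
        if "age_ts" in p:
            # Replace the absolute origin timestamp with a relative age.
            unsigned = p.setdefault("unsigned", {})
            unsigned["age"] = now - int(p["age_ts"])
            del p["age_ts"]
    return data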

View File

@@ -54,28 +54,6 @@ class TransportLayerClient(object):
             destination, path=path, args={"event_id": event_id},
         )

-    @log_function
-    def get_room_state_ids(self, destination, room_id, event_id):
-        """ Requests all state for a given room from the given server at the
-        given event. Returns the state's event_id's
-
-        Args:
-            destination (str): The host name of the remote home server we want
-                to get the state from.
-            context (str): The name of the context we want the state of
-            event_id (str): The event we want the context at.
-
-        Returns:
-            Deferred: Results in a dict received from the remote homeserver.
-        """
-        logger.debug("get_room_state_ids dest=%s, room=%s",
-                     destination, room_id)
-
-        path = PREFIX + "/state_ids/%s/" % room_id
-        return self.client.get_json(
-            destination, path=path, args={"event_id": event_id},
-        )
-
     @log_function
     def get_event(self, destination, event_id, timeout=None):
         """ Requests the pdu with give id and origin from the given server.

View File

@@ -18,14 +18,13 @@ from twisted.internet import defer
 from synapse.api.urls import FEDERATION_PREFIX as PREFIX
 from synapse.api.errors import Codes, SynapseError
 from synapse.http.server import JsonResource
-from synapse.http.servlet import parse_json_object_from_request
+from synapse.http.servlet import parse_json_object_from_request, parse_string
 from synapse.util.ratelimitutils import FederationRateLimiter
-from synapse.util.versionstring import get_version_string

 import functools
 import logging
+import simplejson as json
 import re
-import synapse

 logger = logging.getLogger(__name__)
@@ -61,16 +60,6 @@ class TransportLayerServer(JsonResource):
         )

-class AuthenticationError(SynapseError):
-    """There was a problem authenticating the request"""
-    pass
-
-
-class NoAuthenticationError(AuthenticationError):
-    """The request had no authentication information"""
-    pass
-
-
 class Authenticator(object):
     def __init__(self, hs):
         self.keyring = hs.get_keyring()
@@ -78,7 +67,7 @@ class Authenticator(object):
     # A method just so we can pass 'self' as the authenticator to the Servlets
     @defer.inlineCallbacks
-    def authenticate_request(self, request, content):
+    def authenticate_request(self, request):
         json_request = {
             "method": request.method,
             "uri": request.uri,
@@ -86,11 +75,18 @@ class Authenticator(object):
             "signatures": {},
         }

-        if content is not None:
-            json_request["content"] = content
-
+        content = None
         origin = None

+        if request.method in ["PUT", "POST"]:
+            # TODO: Handle other method types? other content types?
+            try:
+                content_bytes = request.content.read()
+                content = json.loads(content_bytes)
+                json_request["content"] = content
+            except:
+                raise SynapseError(400, "Unable to parse JSON", Codes.BAD_JSON)
+
         def parse_auth_header(header_str):
             try:
                 params = auth.split(" ")[1].split(",")
@@ -107,14 +103,14 @@ class Authenticator(object):
                 sig = strip_quotes(param_dict["sig"])
                 return (origin, key, sig)
             except:
-                raise AuthenticationError(
+                raise SynapseError(
                     400, "Malformed Authorization header", Codes.UNAUTHORIZED
                 )

         auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")

         if not auth_headers:
-            raise NoAuthenticationError(
+            raise SynapseError(
                 401, "Missing Authorization headers", Codes.UNAUTHORIZED,
             )
@@ -125,7 +121,7 @@ class Authenticator(object):
             json_request["signatures"].setdefault(origin, {})[key] = sig

         if not json_request["signatures"]:
-            raise NoAuthenticationError(
+            raise SynapseError(
                 401, "Missing Authorization headers", Codes.UNAUTHORIZED,
             )
@@ -134,12 +130,10 @@ class Authenticator(object):
         logger.info("Request from %s", origin)
         request.authenticated_entity = origin

-        defer.returnValue(origin)
+        defer.returnValue((origin, content))


 class BaseFederationServlet(object):
-    REQUIRE_AUTH = True
-
     def __init__(self, handler, authenticator, ratelimiter, server_name,
                  room_list_handler):
         self.handler = handler
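parse_auth_header above unpacks a federation Authorization header into its parts. A hedged sketch of the parsing (assuming the X-Matrix scheme of the era; not the exact Synapse helper):

def parse_auth_header(header_str):
    """'X-Matrix origin=their.server,key="ed25519:1",sig="..."'
    -> (origin, key_id, signature)."""
    def strip_quotes(value):
        return value[1:-1] if value.startswith('"') else value

    params = header_str.split(" ")[1].split(",")
    param_dict = dict(kv.split("=", 1) for kv in params)
    origin = strip_quotes(param_dict["origin"])
    key = strip_quotes(param_dict["key"])
    sig = strip_quotes(param_dict["sig"])
    return origin, key, sig

# parse_auth_header('X-Matrix origin=matrix.org,key="ed25519:1",sig="AbC"')
# -> ("matrix.org", "ed25519:1", "AbC")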
@@ -147,46 +141,29 @@ class BaseFederationServlet(object):
         self.ratelimiter = ratelimiter
         self.room_list_handler = room_list_handler

-    def _wrap(self, func):
+    def _wrap(self, code):
         authenticator = self.authenticator
         ratelimiter = self.ratelimiter

         @defer.inlineCallbacks
-        @functools.wraps(func)
-        def new_func(request, *args, **kwargs):
-            content = None
-            if request.method in ["PUT", "POST"]:
-                # TODO: Handle other method types? other content types?
-                content = parse_json_object_from_request(request)
-
+        @functools.wraps(code)
+        def new_code(request, *args, **kwargs):
             try:
-                origin = yield authenticator.authenticate_request(request, content)
-            except NoAuthenticationError:
-                origin = None
-                if self.REQUIRE_AUTH:
-                    logger.exception("authenticate_request failed")
-                    raise
+                (origin, content) = yield authenticator.authenticate_request(request)
+                with ratelimiter.ratelimit(origin) as d:
+                    yield d
+                    response = yield code(
+                        origin, content, request.args, *args, **kwargs
+                    )
             except:
                 logger.exception("authenticate_request failed")
                 raise

-            if origin:
-                with ratelimiter.ratelimit(origin) as d:
-                    yield d
-                    response = yield func(
-                        origin, content, request.args, *args, **kwargs
-                    )
-            else:
-                response = yield func(
-                    origin, content, request.args, *args, **kwargs
-                )
-
             defer.returnValue(response)

         # Extra logic that functools.wraps() doesn't finish
-        new_func.__self__ = func.__self__
+        new_code.__self__ = code.__self__

-        return new_func
+        return new_code

     def register(self, server):
         pattern = re.compile("^" + PREFIX + self.PATH + "$")
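_wrap is a decorator factory: every servlet method gets authentication and per-origin rate limiting bolted on before the handler runs. A hedged synchronous sketch of that shape (authenticate and ratelimit are stand-ins for the Twisted machinery):

import functools

def wrap(handler, authenticate, ratelimit):
    @functools.wraps(handler)
    def wrapped(request, *args, **kwargs):
        origin, content = authenticate(request)  # raises on bad auth
        with ratelimit(origin):                  # throttles noisy servers
            return handler(origin, content, request.args, *args, **kwargs)
    return wrapped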
@@ -294,17 +271,6 @@ class FederationStateServlet(BaseFederationServlet):
         )


-class FederationStateIdsServlet(BaseFederationServlet):
-    PATH = "/state_ids/(?P<room_id>[^/]*)/"
-
-    def on_GET(self, origin, content, query, room_id):
-        return self.handler.on_state_ids_request(
-            origin,
-            room_id,
-            query.get("event_id", [None])[0],
-        )
-
-
 class FederationBackfillServlet(BaseFederationServlet):
     PATH = "/backfill/(?P<context>[^/]*)/"
@@ -401,8 +367,10 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
 class FederationClientKeysQueryServlet(BaseFederationServlet):
     PATH = "/user/keys/query"

+    @defer.inlineCallbacks
     def on_POST(self, origin, content, query):
-        return self.handler.on_query_client_keys(origin, content)
+        response = yield self.handler.on_query_client_keys(origin, content)
+        defer.returnValue((200, response))


 class FederationClientKeysClaimServlet(BaseFederationServlet):
@@ -452,10 +420,9 @@ class FederationGetMissingEventsServlet(BaseFederationServlet):
 class On3pidBindServlet(BaseFederationServlet):
     PATH = "/3pid/onbind"

-    REQUIRE_AUTH = False
-
     @defer.inlineCallbacks
-    def on_POST(self, origin, content, query):
+    def on_POST(self, request):
+        content = parse_json_object_from_request(request)
         if "invites" in content:
             last_exception = None
             for invite in content["invites"]:
@@ -477,6 +444,11 @@ class On3pidBindServlet(BaseFederationServlet):
                     raise last_exception
         defer.returnValue((200, {}))

+    # Avoid doing remote HS authorization checks which are done by default by
+    # BaseFederationServlet.
+    def _wrap(self, code):
+        return code
+

 class OpenIdUserInfo(BaseFederationServlet):
     """
@@ -497,11 +469,9 @@ class OpenIdUserInfo(BaseFederationServlet):
     PATH = "/openid/userinfo"

-    REQUIRE_AUTH = False
-
     @defer.inlineCallbacks
-    def on_GET(self, origin, content, query):
-        token = query.get("access_token", [None])[0]
+    def on_GET(self, request):
+        token = parse_string(request, "access_token")
         if token is None:
             defer.returnValue((401, {
                 "errcode": "M_MISSING_TOKEN", "error": "Access Token required"
@@ -518,6 +488,11 @@ class OpenIdUserInfo(BaseFederationServlet):
         defer.returnValue((200, {"sub": user_id}))

+    # Avoid doing remote HS authorization checks which are done by default by
+    # BaseFederationServlet.
+    def _wrap(self, code):
+        return code
+

 class PublicRoomList(BaseFederationServlet):
     """
@@ -558,26 +533,11 @@ class PublicRoomList(BaseFederationServlet):
         defer.returnValue((200, data))


-class FederationVersionServlet(BaseFederationServlet):
-    PATH = "/version"
-
-    REQUIRE_AUTH = False
-
-    def on_GET(self, origin, content, query):
-        return defer.succeed((200, {
-            "server": {
-                "name": "Synapse",
-                "version": get_version_string(synapse)
-            },
-        }))
-
-
 SERVLET_CLASSES = (
     FederationSendServlet,
     FederationPullServlet,
     FederationEventServlet,
     FederationStateServlet,
-    FederationStateIdsServlet,
     FederationBackfillServlet,
     FederationQueryServlet,
     FederationMakeJoinServlet,
@@ -595,7 +555,6 @@ SERVLET_CLASSES = (
     On3pidBindServlet,
     OpenIdUserInfo,
     PublicRoomList,
-    FederationVersionServlet,
 )

View File

@@ -19,6 +19,7 @@ from .room import (
 )
 from .room_member import RoomMemberHandler
 from .message import MessageHandler
+from .events import EventStreamHandler, EventHandler
 from .federation import FederationHandler
 from .profile import ProfileHandler
 from .directory import DirectoryHandler
@@ -30,21 +31,10 @@ from .search import SearchHandler

 class Handlers(object):

-    """ Deprecated. A collection of handlers.
+    """ A collection of all the event handlers.

-    At some point most of the classes whose name ended "Handler" were
-    accessed through this class.
-
-    However this makes it painful to unit test the handlers and to run cut
-    down versions of synapse that only use specific handlers because using a
-    single handler required creating all of the handlers. So some of the
-    handlers have been lifted out of the Handlers object and are now accessed
-    directly through the homeserver object itself.
-
-    Any new handlers should follow the new pattern of being accessed through
-    the homeserver object and should not be added to the Handlers object.
-
-    The remaining handlers should be moved out of the handlers object.
+    There's no need to lazily create these; we'll just make them all eagerly
+    at construction time.
     """

     def __init__(self, hs):
@@ -52,6 +42,8 @@ class Handlers(object):
         self.message_handler = MessageHandler(hs)
         self.room_creation_handler = RoomCreationHandler(hs)
         self.room_member_handler = RoomMemberHandler(hs)
+        self.event_stream_handler = EventStreamHandler(hs)
+        self.event_handler = EventHandler(hs)
         self.federation_handler = FederationHandler(hs)
         self.profile_handler = ProfileHandler(hs)
         self.directory_handler = DirectoryHandler(hs)

View File

@@ -13,14 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import logging
-
 from twisted.internet import defer

-import synapse.types
-from synapse.api.constants import Membership, EventTypes
 from synapse.api.errors import LimitExceededError
-from synapse.types import UserID
+from synapse.api.constants import Membership, EventTypes
+from synapse.types import UserID, Requester
+
+import logging


 logger = logging.getLogger(__name__)
@@ -31,15 +31,11 @@ class BaseHandler(object):
     Common base class for the event handlers.

     Attributes:
-        store (synapse.storage.DataStore):
+        store (synapse.storage.events.StateStore):
         state_handler (synapse.state.StateHandler):
     """

     def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer):
-        """
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.notifier = hs.get_notifier()
@@ -124,8 +120,7 @@ class BaseHandler(object):
         # and having homeservers have their own users leave keeps more
         # of that decision-making and control local to the guest-having
         # homeserver.
-        requester = synapse.types.create_requester(
-            target_user, is_guest=True)
+        requester = Requester(target_user, "", True)
         handler = self.hs.get_handlers().room_member_handler
         yield handler.update_membership(
             requester,

View File

@@ -16,8 +16,7 @@
 from twisted.internet import defer

 from synapse.api.constants import EventTypes
-from synapse.util.metrics import Measure
-from synapse.util.logcontext import preserve_fn
+from synapse.appservice import ApplicationService

 import logging
@@ -43,73 +42,36 @@ class ApplicationServicesHandler(object):
         self.appservice_api = hs.get_application_service_api()
         self.scheduler = hs.get_application_service_scheduler()
         self.started_scheduler = False
-        self.clock = hs.get_clock()
-        self.notify_appservices = hs.config.notify_appservices
-
-        self.current_max = 0
-        self.is_processing = False

     @defer.inlineCallbacks
-    def notify_interested_services(self, current_id):
+    def notify_interested_services(self, event):
         """Notifies (pushes) all application services interested in this event.

         Pushing is done asynchronously, so this method won't block for any
         prolonged length of time.

         Args:
-            current_id(int): The current maximum ID.
+            event(Event): The event to push out to interested services.
         """
-        services = yield self.store.get_app_services()
-        if not services or not self.notify_appservices:
-            return
-
-        self.current_max = max(self.current_max, current_id)
-        if self.is_processing:
-            return
-
-        with Measure(self.clock, "notify_interested_services"):
-            self.is_processing = True
-            try:
-                upper_bound = self.current_max
-                limit = 100
-                while True:
-                    upper_bound, events = yield self.store.get_new_events_for_appservice(
-                        upper_bound, limit
-                    )
-
-                    if not events:
-                        break
-
-                    for event in events:
-                        # Gather interested services
-                        services = yield self._get_services_for_event(event)
-                        if len(services) == 0:
-                            continue  # no services need notifying
-
-                        # Do we know this user exists? If not, poke the user
-                        # query API for all services which match that user regex.
-                        # This needs to block as these user queries need to be
-                        # made BEFORE pushing the event.
-                        yield self._check_user_exists(event.sender)
-                        if event.type == EventTypes.Member:
-                            yield self._check_user_exists(event.state_key)
-
-                        if not self.started_scheduler:
-                            self.scheduler.start().addErrback(log_failure)
-                            self.started_scheduler = True
-
-                        # Fork off pushes to these services
-                        for service in services:
-                            preserve_fn(self.scheduler.submit_event_for_as)(
-                                service, event
-                            )
-
-                    yield self.store.set_appservice_last_pos(upper_bound)
-
-                    if len(events) < limit:
-                        break
-            finally:
-                self.is_processing = False
+        # Gather interested services
+        services = yield self._get_services_for_event(event)
+        if len(services) == 0:
+            return  # no services need notifying
+
+        # Do we know this user exists? If not, poke the user query API for
+        # all services which match that user regex. This needs to block as these
+        # user queries need to be made BEFORE pushing the event.
+        yield self._check_user_exists(event.sender)
+        if event.type == EventTypes.Member:
+            yield self._check_user_exists(event.state_key)
+
+        if not self.started_scheduler:
+            self.scheduler.start().addErrback(log_failure)
+            self.started_scheduler = True
+
+        # Fork off pushes to these services
+        for service in services:
+            self.scheduler.submit_event_for_as(service, event)

     @defer.inlineCallbacks
     def query_user_exists(self, user_id):
@@ -142,12 +104,11 @@ class ApplicationServicesHandler(object):
             association can be found.
         """
         room_alias_str = room_alias.to_string()
-        services = yield self.store.get_app_services()
-        alias_query_services = [
-            s for s in services if (
-                s.is_interested_in_alias(room_alias_str)
-            )
-        ]
+        alias_query_services = yield self._get_services_for_event(
+            event=None,
+            restrict_to=ApplicationService.NS_ALIASES,
+            alias_list=[room_alias_str]
+        )
         for alias_service in alias_query_services:
             is_known_alias = yield self.appservice_api.query_alias(
                 alias_service, room_alias_str
@@ -160,35 +121,34 @@ class ApplicationServicesHandler(object):
         defer.returnValue(result)

     @defer.inlineCallbacks
-    def query_3pe(self, kind, protocol, fields):
-        services = yield self._get_services_for_3pn(protocol)
-
-        results = yield defer.DeferredList([
-            self.appservice_api.query_3pe(service, kind, protocol, fields)
-            for service in services
-        ], consumeErrors=True)
-
-        ret = []
-        for (success, result) in results:
-            if success:
-                ret.extend(result)
-
-        defer.returnValue(ret)
-
-    @defer.inlineCallbacks
-    def _get_services_for_event(self, event):
+    def _get_services_for_event(self, event, restrict_to="", alias_list=None):
         """Retrieve a list of application services interested in this event.

         Args:
             event(Event): The event to check. Can be None if alias_list is not.
+            restrict_to(str): The namespace to restrict regex tests to.
+            alias_list: A list of aliases to get services for. If None, this
+                list is obtained from the database.
         Returns:
             list<ApplicationService>: A list of services interested in this
                 event based on the service regex.
         """
+        member_list = None
+        if hasattr(event, "room_id"):
+            # We need to know the aliases associated with this event.room_id,
+            # if any.
+            if not alias_list:
+                alias_list = yield self.store.get_aliases_for_room(
+                    event.room_id
+                )
+            # We need to know the members associated with this event.room_id,
+            # if any.
+            member_list = yield self.store.get_users_in_room(event.room_id)
+
         services = yield self.store.get_app_services()
         interested_list = [
             s for s in services if (
-                yield s.is_interested(event, self.store)
+                s.is_interested(event, restrict_to, alias_list, member_list)
             )
         ]
         defer.returnValue(interested_list)
@@ -203,14 +163,6 @@ class ApplicationServicesHandler(object):
         ]
         defer.returnValue(interested_list)

-    @defer.inlineCallbacks
-    def _get_services_for_3pn(self, protocol):
-        services = yield self.store.get_app_services()
-        interested_list = [
-            s for s in services if s.is_interested_in_protocol(protocol)
-        ]
-        defer.returnValue(interested_list)
-
     @defer.inlineCallbacks
     def _is_unknown_user(self, user_id):
         if not self.is_mine_id(user_id):
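The newer notify_interested_services body removed above is a single-processor stream loop: callers only advance a high-water mark, and one drain loop pages through events (100 at a time) and persists its position, so concurrent notifications never run the loop twice. A hedged, framework-free sketch of that pattern (fetch_page, handle_event and save_pos are hypothetical callables):

class StreamProcessor(object):
    def __init__(self, fetch_page, handle_event, save_pos, limit=100):
        self.fetch_page = fetch_page  # (from_pos, limit) -> (new_pos, events)
        self.handle_event = handle_event
        self.save_pos = save_pos
        self.limit = limit
        self.current_max = 0
        self.is_processing = False

    def notify(self, current_id):
        # Bump the high-water mark; an already-running drain picks it up.
        self.current_max = max(self.current_max, current_id)
        if self.is_processing:
            return
        self.is_processing = True
        try:
            upper_bound = self.current_max
            while True:
                upper_bound, events = self.fetch_page(upper_bound, self.limit)
                if not events:
                    break
                for event in events:
                    self.handle_event(event)
                self.save_pos(upper_bound)  # remember how far we got
                if len(events) < self.limit:
                    break
        finally:
            self.is_processing = False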

View File

@@ -45,10 +45,6 @@ class AuthHandler(BaseHandler):
     SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000

     def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer):
-        """
         super(AuthHandler, self).__init__(hs)
         self.checkers = {
             LoginType.PASSWORD: self._check_password_auth,
@@ -70,14 +66,13 @@ class AuthHandler(BaseHandler):
             self.ldap_uri = hs.config.ldap_uri
             self.ldap_start_tls = hs.config.ldap_start_tls
             self.ldap_base = hs.config.ldap_base
-            self.ldap_filter = hs.config.ldap_filter
             self.ldap_attributes = hs.config.ldap_attributes
             if self.ldap_mode == LDAPMode.SEARCH:
                 self.ldap_bind_dn = hs.config.ldap_bind_dn
                 self.ldap_bind_password = hs.config.ldap_bind_password
+                self.ldap_filter = hs.config.ldap_filter

         self.hs = hs  # FIXME better possibility to access registrationHandler later?
-        self.device_handler = hs.get_device_handler()

     @defer.inlineCallbacks
     def check_auth(self, flows, clientdict, clientip):
@@ -235,6 +230,7 @@ class AuthHandler(BaseHandler):
         sess = self._get_session_info(session_id)
         return sess.setdefault('serverdict', {}).get(key, default)

+    @defer.inlineCallbacks
     def _check_password_auth(self, authdict, _):
         if "user" not in authdict or "password" not in authdict:
             raise LoginError(400, "", Codes.MISSING_PARAM)
@@ -244,7 +240,11 @@ class AuthHandler(BaseHandler):
         if not user_id.startswith('@'):
             user_id = UserID.create(user_id, self.hs.hostname).to_string()

-        return self._check_password(user_id, password)
+        if not (yield self._check_password(user_id, password)):
+            logger.warn("Failed password login for user %s", user_id)
+            raise LoginError(403, "", errcode=Codes.FORBIDDEN)
+
+        defer.returnValue(user_id)

     @defer.inlineCallbacks
     def _check_recaptcha(self, authdict, clientip):
@@ -280,17 +280,8 @@ class AuthHandler(BaseHandler):
             data = pde.response
         resp_body = simplejson.loads(data)

-        if 'success' in resp_body:
-            # Note that we do NOT check the hostname here: we explicitly
-            # intend the CAPTCHA to be presented by whatever client the
-            # user is using, we just care that they have completed a CAPTCHA.
-            logger.info(
-                "%s reCAPTCHA from hostname %s",
-                "Successful" if resp_body['success'] else "Failed",
-                resp_body.get('hostname')
-            )
-            if resp_body['success']:
-                defer.returnValue(True)
+        if 'success' in resp_body and resp_body['success']:
+            defer.returnValue(True)
         raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)

     @defer.inlineCallbacks
@@ -357,84 +348,67 @@ class AuthHandler(BaseHandler):
         return self.sessions[session_id]

-    def validate_password_login(self, user_id, password):
+    @defer.inlineCallbacks
+    def login_with_password(self, user_id, password):
         """
         Authenticates the user with their username and password.

         Used only by the v1 login API.

         Args:
-            user_id (str): complete @user:id
+            user_id (str): User ID
             password (str): Password
-        Returns:
-            defer.Deferred: (str) canonical user id
-        Raises:
-            StoreError if there was a problem accessing the database
-            LoginError if there was an authentication problem.
-        """
-        return self._check_password(user_id, password)
-
-    @defer.inlineCallbacks
-    def get_login_tuple_for_user_id(self, user_id, device_id=None,
-                                    initial_display_name=None):
-        """
-        Gets login tuple for the user with the given user ID.
-
-        Creates a new access/refresh token for the user.
-
-        The user is assumed to have been authenticated by some other
-        machanism (e.g. CAS), and the user_id converted to the canonical case.
-
-        The device will be recorded in the table if it is not there already.
-
-        Args:
-            user_id (str): canonical User ID
-            device_id (str|None): the device ID to associate with the tokens.
-                None to leave the tokens unassociated with a device (deprecated:
-                we should always have a device ID)
-            initial_display_name (str): display name to associate with the
-                device if it needs re-registering
         Returns:
             A tuple of:
+                The user's ID.
                 The access token for the user's session.
                 The refresh token for the user's session.
         Raises:
             StoreError if there was a problem storing the token.
             LoginError if there was an authentication problem.
         """
-        logger.info("Logging in user %s on device %s", user_id, device_id)
-        access_token = yield self.issue_access_token(user_id, device_id)
-        refresh_token = yield self.issue_refresh_token(user_id, device_id)
-
-        # the device *should* have been registered before we got here; however,
-        # it's possible we raced against a DELETE operation. The thing we
-        # really don't want is active access_tokens without a record of the
-        # device, so we double-check it here.
-        if device_id is not None:
-            yield self.device_handler.check_device_registered(
-                user_id, device_id, initial_display_name
-            )
-
-        defer.returnValue((access_token, refresh_token))
+        if not (yield self._check_password(user_id, password)):
+            logger.warn("Failed password login for user %s", user_id)
+            raise LoginError(403, "", errcode=Codes.FORBIDDEN)
+
+        logger.info("Logging in user %s", user_id)
+        access_token = yield self.issue_access_token(user_id)
+        refresh_token = yield self.issue_refresh_token(user_id)
+        defer.returnValue((user_id, access_token, refresh_token))

     @defer.inlineCallbacks
-    def check_user_exists(self, user_id):
+    def get_login_tuple_for_user_id(self, user_id):
         """
-        Checks to see if a user with the given id exists. Will check case
-        insensitively, but return None if there are multiple inexact matches.
+        Gets login tuple for the user with the given user ID.
+        The user is assumed to have been authenticated by some other
+        machanism (e.g. CAS)

         Args:
-            (str) user_id: complete @user:id
+            user_id (str): User ID
         Returns:
-            defer.Deferred: (str) canonical_user_id, or None if zero or
-            multiple matches
+            A tuple of:
+                The user's ID.
+                The access token for the user's session.
+                The refresh token for the user's session.
+        Raises:
+            StoreError if there was a problem storing the token.
+            LoginError if there was an authentication problem.
         """
-        user_id, ignored = yield self._find_user_id_and_pwd_hash(user_id)
+        logger.info("Logging in user %s", user_id)
+        access_token = yield self.issue_access_token(user_id)
+        refresh_token = yield self.issue_refresh_token(user_id)
+        defer.returnValue((user_id, access_token, refresh_token))

+    @defer.inlineCallbacks
+    def does_user_exist(self, user_id):
         try:
-            res = yield self._find_user_id_and_pwd_hash(user_id)
-            defer.returnValue(res[0])
+            yield self._find_user_id_and_pwd_hash(user_id)
+            defer.returnValue(True)
         except LoginError:
-            defer.returnValue(None)
+            defer.returnValue(False)

     @defer.inlineCallbacks
     def _find_user_id_and_pwd_hash(self, user_id):
@@ -464,45 +438,27 @@ class AuthHandler(BaseHandler):
     @defer.inlineCallbacks
     def _check_password(self, user_id, password):
-        """Authenticate a user against the LDAP and local databases.
-
-        user_id is checked case insensitively against the local database, but
-        will throw if there are multiple inexact matches.
-
-        Args:
-            user_id (str): complete @user:id
+        """
         Returns:
-            (str) the canonical_user_id
-        Raises:
-            LoginError if the password was incorrect
+            True if the user_id successfully authenticated
         """
         valid_ldap = yield self._check_ldap_password(user_id, password)
         if valid_ldap:
-            defer.returnValue(user_id)
+            defer.returnValue(True)

-        result = yield self._check_local_password(user_id, password)
-        defer.returnValue(result)
+        valid_local_password = yield self._check_local_password(user_id, password)
+        if valid_local_password:
+            defer.returnValue(True)
+
+        defer.returnValue(False)

     @defer.inlineCallbacks
     def _check_local_password(self, user_id, password):
-        """Authenticate a user against the local password database.
-
-        user_id is checked case insensitively, but will throw if there are
-        multiple inexact matches.
-
-        Args:
-            user_id (str): complete @user:id
-        Returns:
-            (str) the canonical_user_id
-        Raises:
-            LoginError if the password was incorrect
-        """
-        user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
-        result = self.validate_hash(password, password_hash)
-        if not result:
-            logger.warn("Failed password login for user %s", user_id)
-            raise LoginError(403, "", errcode=Codes.FORBIDDEN)
-        defer.returnValue(user_id)
+        try:
+            user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
+            defer.returnValue(self.validate_hash(password, password_hash))
+        except LoginError:
+            defer.returnValue(False)

     @defer.inlineCallbacks
     def _check_ldap_password(self, user_id, password):
@@ -614,7 +570,7 @@ class AuthHandler(BaseHandler):
                 )

                 # check for existing account, if none exists, create one
-                if not (yield self.check_user_exists(user_id)):
+                if not (yield self.does_user_exist(user_id)):
                     # query user metadata for account creation
                     query = "({prop}={value})".format(
                         prop=self.ldap_attributes['uid'],
@@ -660,7 +616,7 @@ class AuthHandler(BaseHandler):
                 else:
                     logger.warn(
                         "ldap registration failed: unexpected (%d!=1) amount of results",
-                        len(conn.response)
+                        len(result)
                     )
                     defer.returnValue(False)
@@ -670,26 +626,23 @@ class AuthHandler(BaseHandler):
             defer.returnValue(False)

     @defer.inlineCallbacks
-    def issue_access_token(self, user_id, device_id=None):
+    def issue_access_token(self, user_id):
         access_token = self.generate_access_token(user_id)
-        yield self.store.add_access_token_to_user(user_id, access_token,
-                                                  device_id)
+        yield self.store.add_access_token_to_user(user_id, access_token)
         defer.returnValue(access_token)

     @defer.inlineCallbacks
-    def issue_refresh_token(self, user_id, device_id=None):
+    def issue_refresh_token(self, user_id):
         refresh_token = self.generate_refresh_token(user_id)
-        yield self.store.add_refresh_token_to_user(user_id, refresh_token,
-                                                   device_id)
+        yield self.store.add_refresh_token_to_user(user_id, refresh_token)
        defer.returnValue(refresh_token)

-    def generate_access_token(self, user_id, extra_caveats=None,
-                              duration_in_ms=(60 * 60 * 1000)):
+    def generate_access_token(self, user_id, extra_caveats=None):
         extra_caveats = extra_caveats or []
         macaroon = self._generate_base_macaroon(user_id)
         macaroon.add_first_party_caveat("type = access")
         now = self.hs.get_clock().time_msec()
-        expiry = now + duration_in_ms
+        expiry = now + (60 * 60 * 1000)
         macaroon.add_first_party_caveat("time < %d" % (expiry,))
         for caveat in extra_caveats:
             macaroon.add_first_party_caveat(caveat)
@@ -719,14 +672,13 @@ class AuthHandler(BaseHandler):
         return macaroon.serialize()

     def validate_short_term_login_token_and_get_user_id(self, login_token):
-        auth_api = self.hs.get_auth()
         try:
             macaroon = pymacaroons.Macaroon.deserialize(login_token)
-            user_id = auth_api.get_user_id_from_macaroon(macaroon)
-            auth_api.validate_macaroon(macaroon, "login", True, user_id)
-            return user_id
-        except Exception:
+            auth_api = self.hs.get_auth()
+            auth_api.validate_macaroon(macaroon, "login", True)
+            return self.get_user_from_macaroon(macaroon)
+        except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN)
def _generate_base_macaroon(self, user_id): def _generate_base_macaroon(self, user_id):
macaroon = pymacaroons.Macaroon( macaroon = pymacaroons.Macaroon(
@@ -737,11 +689,21 @@ class AuthHandler(BaseHandler):
macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon return macaroon
def get_user_from_macaroon(self, macaroon):
user_prefix = "user_id = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(user_prefix):
return caveat.caveat_id[len(user_prefix):]
raise AuthError(
self.INVALID_TOKEN_HTTP_STATUS, "No user_id found in token",
errcode=Codes.UNKNOWN_TOKEN
)
@defer.inlineCallbacks @defer.inlineCallbacks
def set_password(self, user_id, newpassword, requester=None): def set_password(self, user_id, newpassword, requester=None):
password_hash = self.hash(newpassword) password_hash = self.hash(newpassword)
except_access_token_id = requester.access_token_id if requester else None except_access_token_ids = [requester.access_token_id] if requester else []
try: try:
yield self.store.user_set_password_hash(user_id, password_hash) yield self.store.user_set_password_hash(user_id, password_hash)
@@ -750,10 +712,10 @@ class AuthHandler(BaseHandler):
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND) raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
raise e raise e
yield self.store.user_delete_access_tokens( yield self.store.user_delete_access_tokens(
user_id, except_access_token_id user_id, except_access_token_ids
) )
yield self.hs.get_pusherpool().remove_pushers_by_user( yield self.hs.get_pusherpool().remove_pushers_by_user(
user_id, except_access_token_id user_id, except_access_token_ids
) )
@defer.inlineCallbacks @defer.inlineCallbacks
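Both branches mint and validate macaroon-based tokens; the right-hand branch recovers the user ID by scanning the macaroon's first-party caveats. A minimal, self-contained sketch of that caveat scan (not part of the compare; the location, secret, and user ID below are invented):

import pymacaroons

# Build a token shaped like the ones the handler mints (values invented).
macaroon = pymacaroons.Macaroon(
    location="example.com",
    identifier="key",
    key="secret",
)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = @alice:example.com")
macaroon.add_first_party_caveat("type = access")

def get_user_from_macaroon(macaroon):
    # Scan first-party caveats for the "user_id = ..." entry, as the
    # right-hand branch's helper does.
    user_prefix = "user_id = "
    for caveat in macaroon.caveats:
        if caveat.caveat_id.startswith(user_prefix):
            return caveat.caveat_id[len(user_prefix):]
    raise ValueError("No user_id found in token")

assert get_user_from_macaroon(macaroon) == "@alice:example.com"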

View File

@@ -1,181 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from synapse.api import errors
from synapse.util import stringutils
from twisted.internet import defer
from ._base import BaseHandler

import logging

logger = logging.getLogger(__name__)


class DeviceHandler(BaseHandler):
    def __init__(self, hs):
        super(DeviceHandler, self).__init__(hs)

    @defer.inlineCallbacks
    def check_device_registered(self, user_id, device_id,
                                initial_device_display_name=None):
        """
        If the given device has not been registered, register it with the
        supplied display name.

        If no device_id is supplied, we make one up.

        Args:
            user_id (str): @user:id
            device_id (str | None): device id supplied by client
            initial_device_display_name (str | None): device display name from
                client
        Returns:
            str: device id (generated if none was supplied)
        """
        if device_id is not None:
            yield self.store.store_device(
                user_id=user_id,
                device_id=device_id,
                initial_device_display_name=initial_device_display_name,
                ignore_if_known=True,
            )
            defer.returnValue(device_id)

        # if the device id is not specified, we'll autogen one, but loop a few
        # times in case of a clash.
        attempts = 0
        while attempts < 5:
            try:
                device_id = stringutils.random_string_with_symbols(16)
                yield self.store.store_device(
                    user_id=user_id,
                    device_id=device_id,
                    initial_device_display_name=initial_device_display_name,
                    ignore_if_known=False,
                )
                defer.returnValue(device_id)
            except errors.StoreError:
                attempts += 1

        raise errors.StoreError(500, "Couldn't generate a device ID.")

    @defer.inlineCallbacks
    def get_devices_by_user(self, user_id):
        """
        Retrieve the given user's devices

        Args:
            user_id (str):
        Returns:
            defer.Deferred: list[dict[str, X]]: info on each device
        """
        device_map = yield self.store.get_devices_by_user(user_id)

        ips = yield self.store.get_last_client_ip_by_device(
            devices=((user_id, device_id) for device_id in device_map.keys())
        )

        devices = device_map.values()
        for device in devices:
            _update_device_from_client_ips(device, ips)

        defer.returnValue(devices)

    @defer.inlineCallbacks
    def get_device(self, user_id, device_id):
        """ Retrieve the given device

        Args:
            user_id (str):
            device_id (str):

        Returns:
            defer.Deferred: dict[str, X]: info on the device
        Raises:
            errors.NotFoundError: if the device was not found
        """
        try:
            device = yield self.store.get_device(user_id, device_id)
        except errors.StoreError:
            raise errors.NotFoundError
        ips = yield self.store.get_last_client_ip_by_device(
            devices=((user_id, device_id),)
        )
        _update_device_from_client_ips(device, ips)
        defer.returnValue(device)

    @defer.inlineCallbacks
    def delete_device(self, user_id, device_id):
        """ Delete the given device

        Args:
            user_id (str):
            device_id (str):

        Returns:
            defer.Deferred:
        """
        try:
            yield self.store.delete_device(user_id, device_id)
        except errors.StoreError, e:
            if e.code == 404:
                # no match
                pass
            else:
                raise

        yield self.store.user_delete_access_tokens(
            user_id, device_id=device_id,
            delete_refresh_tokens=True,
        )

        yield self.store.delete_e2e_keys_by_device(
            user_id=user_id, device_id=device_id
        )

    @defer.inlineCallbacks
    def update_device(self, user_id, device_id, content):
        """ Update the given device

        Args:
            user_id (str):
            device_id (str):
            content (dict): body of update request

        Returns:
            defer.Deferred:
        """
        try:
            yield self.store.update_device(
                user_id,
                device_id,
                new_display_name=content.get("display_name")
            )
        except errors.StoreError, e:
            if e.code == 404:
                raise errors.NotFoundError()
            else:
                raise


def _update_device_from_client_ips(device, client_ips):
    ip = client_ips.get((device["user_id"], device["device_id"]), {})
    device.update({
        "last_seen_ts": ip.get("last_seen"),
        "last_seen_ip": ip.get("ip"),
    })
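check_device_registered above allocates a device ID by generating a random candidate and retrying a few times on a clash. A toy, dependency-free sketch of that allocation loop, with the Synapse store replaced by an in-memory set (all names here are illustrative, not Synapse APIs):

import random
import string

_known_device_ids = set()  # stands in for the device table

def _store_device(device_id, ignore_if_known):
    if device_id in _known_device_ids and not ignore_if_known:
        raise KeyError("duplicate device_id")
    _known_device_ids.add(device_id)

def check_device_registered(device_id=None, attempts=5):
    # Client-supplied IDs are stored idempotently; otherwise autogenerate,
    # retrying a bounded number of times in case of a clash.
    if device_id is not None:
        _store_device(device_id, ignore_if_known=True)
        return device_id
    for _ in range(attempts):
        candidate = "".join(random.choice(string.ascii_uppercase) for _ in range(10))
        try:
            _store_device(candidate, ignore_if_known=False)
            return candidate
        except KeyError:
            continue
    raise RuntimeError("Couldn't generate a device ID.")

print(check_device_registered())         # autogenerated ID
print(check_device_registered("PHONE"))  # client-supplied ID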

View File

@@ -1,139 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import json
import logging

from twisted.internet import defer

from synapse.api import errors
import synapse.types

logger = logging.getLogger(__name__)


class E2eKeysHandler(object):
    def __init__(self, hs):
        self.store = hs.get_datastore()
        self.federation = hs.get_replication_layer()
        self.is_mine_id = hs.is_mine_id
        self.server_name = hs.hostname

        # doesn't really work as part of the generic query API, because the
        # query request requires an object POST, but we abuse the
        # "query handler" interface.
        self.federation.register_query_handler(
            "client_keys", self.on_federation_query_client_keys
        )

    @defer.inlineCallbacks
    def query_devices(self, query_body):
        """ Handle a device key query from a client

        {
            "device_keys": {
                "<user_id>": ["<device_id>"]
            }
        }
        ->
        {
            "device_keys": {
                "<user_id>": {
                    "<device_id>": {
                        ...
                    }
                }
            }
        }
        """
        device_keys_query = query_body.get("device_keys", {})

        # separate users by domain.
        # make a map from domain to user_id to device_ids
        queries_by_domain = collections.defaultdict(dict)
        for user_id, device_ids in device_keys_query.items():
            user = synapse.types.UserID.from_string(user_id)
            queries_by_domain[user.domain][user_id] = device_ids

        # do the queries
        # TODO: do these in parallel
        results = {}
        for destination, destination_query in queries_by_domain.items():
            if destination == self.server_name:
                res = yield self.query_local_devices(destination_query)
            else:
                res = yield self.federation.query_client_keys(
                    destination, {"device_keys": destination_query}
                )
                res = res["device_keys"]
            for user_id, keys in res.items():
                if user_id in destination_query:
                    results[user_id] = keys

        defer.returnValue((200, {"device_keys": results}))

    @defer.inlineCallbacks
    def query_local_devices(self, query):
        """Get E2E device keys for local users

        Args:
            query (dict[string, list[string]|None): map from user_id to a list
                of devices to query (None for all devices)

        Returns:
            defer.Deferred: (resolves to dict[string, dict[string, dict]]):
                map from user_id -> device_id -> device details
        """
        local_query = []

        result_dict = {}
        for user_id, device_ids in query.items():
            if not self.is_mine_id(user_id):
                logger.warning("Request for keys for non-local user %s",
                               user_id)
                raise errors.SynapseError(400, "Not a user here")

            if not device_ids:
                local_query.append((user_id, None))
            else:
                for device_id in device_ids:
                    local_query.append((user_id, device_id))

            # make sure that each queried user appears in the result dict
            result_dict[user_id] = {}

        results = yield self.store.get_e2e_device_keys(local_query)

        # Build the result structure, un-jsonify the results, and add the
        # "unsigned" section
        for user_id, device_keys in results.items():
            for device_id, device_info in device_keys.items():
                r = json.loads(device_info["key_json"])
                r["unsigned"] = {}
                display_name = device_info["device_display_name"]
                if display_name is not None:
                    r["unsigned"]["device_display_name"] = display_name
                result_dict[user_id][device_id] = r

        defer.returnValue(result_dict)

    @defer.inlineCallbacks
    def on_federation_query_client_keys(self, query_body):
        """ Handle a device key query from a federated server
        """
        device_keys_query = query_body.get("device_keys", {})
        res = yield self.query_local_devices(device_keys_query)
        defer.returnValue({"device_keys": res})
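query_devices above buckets the queried user IDs by server name before dispatching local and federation queries. A small stand-alone illustration of just that bucketing step (the user IDs are invented):

import collections

def split_query_by_domain(device_keys_query):
    # Map domain -> user_id -> requested device_ids (None means all devices).
    queries_by_domain = collections.defaultdict(dict)
    for user_id, device_ids in device_keys_query.items():
        domain = user_id.split(":", 1)[1]  # "@user:domain" -> "domain"
        queries_by_domain[domain][user_id] = device_ids
    return queries_by_domain

print(dict(split_query_by_domain({
    "@alice:example.com": ["DEVICE1"],
    "@bob:other.org": None,  # None means "all devices"
})))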

View File

@@ -124,7 +124,7 @@ class FederationHandler(BaseHandler):
         try:
             event_stream_id, max_stream_id = yield self._persist_auth_tree(
-                origin, auth_chain, state, event
+                auth_chain, state, event
             )
         except AuthError as e:
             raise FederationError(

@@ -249,7 +249,7 @@ class FederationHandler(BaseHandler):
             if ev.type != EventTypes.Member:
                 continue
             try:
-                domain = get_domain_from_id(ev.state_key)
+                domain = UserID.from_string(ev.state_key).domain
             except:
                 continue

@@ -274,7 +274,7 @@ class FederationHandler(BaseHandler):
     @log_function
     @defer.inlineCallbacks
-    def backfill(self, dest, room_id, limit, extremities):
+    def backfill(self, dest, room_id, limit, extremities=[]):
         """ Trigger a backfill request to `dest` for the given `room_id`

         This will attempt to get more events from the remote. This may return

@@ -284,6 +284,9 @@ class FederationHandler(BaseHandler):
         if dest == self.server_name:
             raise SynapseError(400, "Can't backfill from self.")

+        if not extremities:
+            extremities = yield self.store.get_oldest_events_in_room(room_id)
+
         events = yield self.replication_layer.backfill(
             dest,
             room_id,

@@ -332,59 +335,32 @@ class FederationHandler(BaseHandler):
                 state_events.update({s.event_id: s for s in state})
                 events_to_state[e_id] = state

-        required_auth = set(
-            a_id
-            for event in events + state_events.values() + auth_events.values()
-            for a_id, _ in event.auth_events
-        )
-        auth_events.update({
-            e_id: event_map[e_id] for e_id in required_auth if e_id in event_map
-        })
-        missing_auth = required_auth - set(auth_events)
-        failed_to_fetch = set()
-
-        # Try and fetch any missing auth events from both DB and remote servers.
-        # We repeatedly do this until we stop finding new auth events.
-        while missing_auth - failed_to_fetch:
-            logger.info("Missing auth for backfill: %r", missing_auth)
-            ret_events = yield self.store.get_events(missing_auth - failed_to_fetch)
-            auth_events.update(ret_events)
-
-            required_auth.update(
-                a_id for event in ret_events.values() for a_id, _ in event.auth_events
-            )
-            missing_auth = required_auth - set(auth_events)
-
-            if missing_auth - failed_to_fetch:
-                logger.info(
-                    "Fetching missing auth for backfill: %r",
-                    missing_auth - failed_to_fetch
-                )
-
-                results = yield defer.gatherResults(
-                    [
-                        self.replication_layer.get_pdu(
-                            [dest],
-                            event_id,
-                            outlier=True,
-                            timeout=10000,
-                        )
-                        for event_id in missing_auth - failed_to_fetch
-                    ],
-                    consumeErrors=True
-                ).addErrback(unwrapFirstError)
-                auth_events.update({a.event_id: a for a in results})
-                required_auth.update(
-                    a_id for event in results for a_id, _ in event.auth_events
-                )
-                missing_auth = required_auth - set(auth_events)
-
-            failed_to_fetch = missing_auth - set(auth_events)
-
         seen_events = yield self.store.have_events(
             set(auth_events.keys()) | set(state_events.keys())
         )

+        all_events = events + state_events.values() + auth_events.values()
+        required_auth = set(
+            a_id for event in all_events for a_id, _ in event.auth_events
+        )
+
+        missing_auth = required_auth - set(auth_events)
+        if missing_auth:
+            logger.info("Missing auth for backfill: %r", missing_auth)
+            results = yield defer.gatherResults(
+                [
+                    self.replication_layer.get_pdu(
+                        [dest],
+                        event_id,
+                        outlier=True,
+                        timeout=10000,
+                    )
+                    for event_id in missing_auth
+                ],
+                consumeErrors=True
+            ).addErrback(unwrapFirstError)
+            auth_events.update({a.event_id: a for a in results})
+
         ev_infos = []
         for a in auth_events.values():
             if a.event_id in seen_events:

@@ -396,7 +372,6 @@ class FederationHandler(BaseHandler):
                     (auth_events[a_id].type, auth_events[a_id].state_key):
                     auth_events[a_id]
                     for a_id, _ in a.auth_events
-                    if a_id in auth_events
                 }
             })

@@ -408,7 +383,6 @@ class FederationHandler(BaseHandler):
                     (auth_events[a_id].type, auth_events[a_id].state_key):
                     auth_events[a_id]
                     for a_id, _ in event_map[e_id].auth_events
-                    if a_id in auth_events
                 }
             })

@@ -452,10 +426,6 @@ class FederationHandler(BaseHandler):
         )
         max_depth = sorted_extremeties_tuple[0][1]

-        # We don't want to specify too many extremities as it causes the backfill
-        # request URI to be too long.
-        extremities = dict(sorted_extremeties_tuple[:5])
-
         if current_depth > max_depth:
             logger.debug(
                 "Not backfilling as we don't need to. %d < %d",

@@ -667,7 +637,7 @@ class FederationHandler(BaseHandler):
                 pass

             event_stream_id, max_stream_id = yield self._persist_auth_tree(
-                origin, auth_chain, state, event
+                auth_chain, state, event
             )

             with PreserveLoggingContext():

@@ -718,9 +688,7 @@ class FederationHandler(BaseHandler):
             logger.warn("Failed to create join %r because %s", event, e)
             raise e

-        # The remote hasn't signed it yet, obviously. We'll do the full checks
-        # when we get the event back in `on_send_join_request`
-        self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
+        self.auth.check(event, auth_events=context.current_state)

         defer.returnValue(event)

@@ -950,9 +918,7 @@ class FederationHandler(BaseHandler):
         )

         try:
-            # The remote hasn't signed it yet, obviously. We'll do the full checks
-            # when we get the event back in `on_send_leave_request`
-            self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
+            self.auth.check(event, auth_events=context.current_state)
         except AuthError as e:
             logger.warn("Failed to create new leave %r because %s", event, e)
             raise e

@@ -1021,9 +987,14 @@ class FederationHandler(BaseHandler):
         defer.returnValue(None)

     @defer.inlineCallbacks
-    def get_state_for_pdu(self, room_id, event_id):
+    def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True):
         yield run_on_reactor()

+        if do_auth:
+            in_room = yield self.auth.check_host_in_room(room_id, origin)
+            if not in_room:
+                raise AuthError(403, "Host not in room.")
+
         state_groups = yield self.store.get_state_groups(
             room_id, [event_id]
         )

@@ -1094,17 +1065,16 @@ class FederationHandler(BaseHandler):
         )

         if event:
-            if self.hs.is_mine_id(event.event_id):
-                # FIXME: This is a temporary work around where we occasionally
-                # return events slightly differently than when they were
-                # originally signed
-                event.signatures.update(
-                    compute_event_signature(
-                        event,
-                        self.hs.hostname,
-                        self.hs.config.signing_key[0]
-                    )
-                )
+            # FIXME: This is a temporary work around where we occasionally
+            # return events slightly differently than when they were
+            # originally signed
+            event.signatures.update(
+                compute_event_signature(
+                    event,
+                    self.hs.hostname,
+                    self.hs.config.signing_key[0]
+                )
+            )

             if do_auth:
                 in_room = yield self.auth.check_host_in_room(

@@ -1114,12 +1084,6 @@ class FederationHandler(BaseHandler):
                 if not in_room:
                     raise AuthError(403, "Host not in room.")

-            events = yield self._filter_events_for_server(
-                origin, event.room_id, [event]
-            )
-            event = events[0]
-
             defer.returnValue(event)
         else:
             defer.returnValue(None)

@@ -1150,12 +1114,11 @@ class FederationHandler(BaseHandler):
             backfilled=backfilled,
         )

-        if not backfilled:
-            # this intentionally does not yield: we don't care about the result
-            # and don't need to wait for it.
-            preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
-                event_stream_id, max_stream_id
-            )
+        # this intentionally does not yield: we don't care about the result
+        # and don't need to wait for it.
+        preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
+            event_stream_id, max_stream_id
+        )

         defer.returnValue((context, event_stream_id, max_stream_id))

@@ -1187,19 +1150,11 @@ class FederationHandler(BaseHandler):
         )

     @defer.inlineCallbacks
-    def _persist_auth_tree(self, origin, auth_events, state, event):
+    def _persist_auth_tree(self, auth_events, state, event):
         """Checks the auth chain is valid (and passes auth checks) for the
         state and event. Then persists the auth chain and state atomically.
         Persists the event seperately.

-        Will attempt to fetch missing auth events.
-
-        Args:
-            origin (str): Where the events came from
-            auth_events (list)
-            state (list)
-            event (Event)
-
         Returns:
             2-tuple of (event_stream_id, max_stream_id) from the persist_event
             call for `event`

@@ -1212,7 +1167,7 @@ class FederationHandler(BaseHandler):
         event_map = {
             e.event_id: e
-            for e in itertools.chain(auth_events, state, [event])
+            for e in auth_events
         }

         create_event = None

@@ -1221,29 +1176,10 @@ class FederationHandler(BaseHandler):
                 create_event = e
                 break

-        missing_auth_events = set()
-        for e in itertools.chain(auth_events, state, [event]):
-            for e_id, _ in e.auth_events:
-                if e_id not in event_map:
-                    missing_auth_events.add(e_id)
-
-        for e_id in missing_auth_events:
-            m_ev = yield self.replication_layer.get_pdu(
-                [origin],
-                e_id,
-                outlier=True,
-                timeout=10000,
-            )
-            if m_ev and m_ev.event_id == e_id:
-                event_map[e_id] = m_ev
-            else:
-                logger.info("Failed to find auth event %r", e_id)
-
         for e in itertools.chain(auth_events, state, [event]):
             auth_for_e = {
                 (event_map[e_id].type, event_map[e_id].state_key): event_map[e_id]
                 for e_id, _ in e.auth_events
-                if e_id in event_map
             }
             if create_event:
                 auth_for_e[(EventTypes.Create, "")] = create_event
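The left-hand branch's backfill keeps fetching auth events until no new ones turn up, remembering IDs that could not be fetched so it does not loop on them. A schematic, synchronous rendering of that fixed-point loop (fetch_remote and the event shape are stand-ins, not Synapse APIs):

def resolve_missing_auth(events, auth_events, fetch_remote):
    # events / auth_events values are dicts with an "auth_ids" list;
    # fetch_remote(event_id) returns such a dict, or None on failure.
    required = {a_id for ev in events for a_id in ev["auth_ids"]}
    failed_to_fetch = set()
    missing = required - set(auth_events)
    while missing - failed_to_fetch:
        for e_id in list(missing - failed_to_fetch):
            ev = fetch_remote(e_id)
            if ev is None:
                failed_to_fetch.add(e_id)
            else:
                auth_events[e_id] = ev
                # newly fetched events may themselves require more auth events
                required |= set(ev["auth_ids"])
        missing = required - set(auth_events)
    return auth_events, missing  # anything still missing failed to fetch

The right-hand branch does a single fetch pass instead, which is why the two `if a_id in auth_events` guards disappear on that side.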

View File

@@ -66,7 +66,7 @@ class MessageHandler(BaseHandler):
     @defer.inlineCallbacks
     def get_messages(self, requester, room_id=None, pagin_config=None,
-                     as_client_event=True, event_filter=None):
+                     as_client_event=True):
         """Get messages in a room.

         Args:

@@ -75,11 +75,11 @@ class MessageHandler(BaseHandler):
             pagin_config (synapse.api.streams.PaginationConfig): The pagination
                 config rules to apply, if any.
             as_client_event (bool): True to get events in client-server format.
-            event_filter (Filter): Filter to apply to results or None
         Returns:
             dict: Pagination API results
         """
         user_id = requester.user.to_string()
+        data_source = self.hs.get_event_sources().sources["room"]

         if pagin_config.from_token:
             room_token = pagin_config.from_token.room_key

@@ -129,13 +129,8 @@ class MessageHandler(BaseHandler):
                     room_id, max_topo
                 )

-        events, next_key = yield self.store.paginate_room_events(
-            room_id=room_id,
-            from_key=source_config.from_key,
-            to_key=source_config.to_key,
-            direction=source_config.direction,
-            limit=source_config.limit,
-            event_filter=event_filter,
+        events, next_key = yield data_source.get_pagination_rows(
+            requester.user, source_config, room_id
         )

         next_token = pagin_config.from_token.copy_and_replace(

@@ -149,9 +144,6 @@ class MessageHandler(BaseHandler):
                 "end": next_token.to_string(),
             })

-        if event_filter:
-            events = event_filter.filter(events)
-
         events = yield filter_events_for_client(
             self.store,
             user_id,

@@ -172,6 +164,101 @@ class MessageHandler(BaseHandler):

         defer.returnValue(chunk)

+    @defer.inlineCallbacks
+    def get_files(self, requester, room_id, pagin_config):
+        """Get files in a room.
+
+        Args:
+            requester (Requester): The user requesting files.
+            room_id (str): The room they want files from.
+            pagin_config (synapse.api.streams.PaginationConfig): The pagination
+                config rules to apply, if any.
+            as_client_event (bool): True to get events in client-server format.
+        Returns:
+            dict: Pagination API results
+        """
+        user_id = requester.user.to_string()
+
+        if pagin_config.from_token:
+            room_token = pagin_config.from_token.room_key
+        else:
+            pagin_config.from_token = (
+                yield self.hs.get_event_sources().get_current_token(
+                    direction='b'
+                )
+            )
+            room_token = pagin_config.from_token.room_key
+
+        room_token = RoomStreamToken.parse(room_token)
+
+        pagin_config.from_token = pagin_config.from_token.copy_and_replace(
+            "room_key", str(room_token)
+        )
+
+        source_config = pagin_config.get_source_config("room")
+
+        membership, member_event_id = yield self._check_in_room_or_world_readable(
+            room_id, user_id
+        )
+
+        if source_config.direction == 'b':
+            if room_token.topological:
+                max_topo = room_token.topological
+            else:
+                max_topo = yield self.store.get_max_topological_token(
+                    room_id, room_token.stream
+                )
+
+            if membership == Membership.LEAVE:
+                # If they have left the room then clamp the token to be before
+                # they left the room, to save the effort of loading from the
+                # database.
+                leave_token = yield self.store.get_topological_token_for_event(
+                    member_event_id
+                )
+                leave_token = RoomStreamToken.parse(leave_token)
+                if leave_token.topological < max_topo:
+                    source_config.from_key = str(leave_token)
+
+        events, next_key = yield self.store.paginate_room_file_events(
+            room_id,
+            from_key=source_config.from_key,
+            to_key=source_config.to_key,
+            direction=source_config.direction,
+            limit=source_config.limit,
+        )
+
+        next_token = pagin_config.from_token.copy_and_replace(
+            "room_key", next_key
+        )
+
+        if not events:
+            defer.returnValue({
+                "chunk": [],
+                "start": pagin_config.from_token.to_string(),
+                "end": next_token.to_string(),
+            })
+
+        events = yield filter_events_for_client(
+            self.store,
+            user_id,
+            events,
+            is_peeking=(member_event_id is None),
+        )
+
+        time_now = self.clock.time_msec()
+
+        chunk = {
+            "chunk": [
+                serialize_event(e, time_now)
+                for e in events
+            ],
+            "start": pagin_config.from_token.to_string(),
+            "end": next_token.to_string(),
+        }
+
+        defer.returnValue(chunk)
+
     @defer.inlineCallbacks
     def create_event(self, event_dict, token_id=None, txn_id=None, prev_event_ids=None):
         """

View File

@@ -503,7 +503,7 @@ class PresenceHandler(object):
         defer.returnValue(states)

     @defer.inlineCallbacks
-    def _get_interested_parties(self, states, calculate_remote_hosts=True):
+    def _get_interested_parties(self, states):
        """Given a list of states return which entities (rooms, users, servers)
        are interested in the given states.

@@ -526,15 +526,14 @@ class PresenceHandler(object):
             users_to_states.setdefault(state.user_id, []).append(state)

         hosts_to_states = {}
-        if calculate_remote_hosts:
-            for room_id, states in room_ids_to_states.items():
-                local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
-                if not local_states:
-                    continue
-
-                hosts = yield self.store.get_joined_hosts_for_room(room_id)
-                for host in hosts:
-                    hosts_to_states.setdefault(host, []).extend(local_states)
+        for room_id, states in room_ids_to_states.items():
+            local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
+            if not local_states:
+                continue
+
+            hosts = yield self.store.get_joined_hosts_for_room(room_id)
+            for host in hosts:
+                hosts_to_states.setdefault(host, []).extend(local_states)

         for user_id, states in users_to_states.items():
             local_states = filter(lambda s: self.is_mine_id(s.user_id), states)

@@ -566,16 +565,6 @@ class PresenceHandler(object):
         self._push_to_remotes(hosts_to_states)

-    @defer.inlineCallbacks
-    def notify_for_states(self, state, stream_id):
-        parties = yield self._get_interested_parties([state])
-        room_ids_to_states, users_to_states, hosts_to_states = parties
-
-        self.notifier.on_new_event(
-            "presence_key", stream_id, rooms=room_ids_to_states.keys(),
-            users=[UserID.from_string(u) for u in users_to_states.keys()]
-        )
-
     def _push_to_remotes(self, hosts_to_states):
         """Sends state updates to remote servers.

@@ -683,7 +672,7 @@ class PresenceHandler(object):
         ])

     @defer.inlineCallbacks
-    def set_state(self, target_user, state, ignore_status_msg=False):
+    def set_state(self, target_user, state):
         """Set the presence state of the user.
         """
         status_msg = state.get("status_msg", None)

@@ -700,13 +689,10 @@ class PresenceHandler(object):
         prev_state = yield self.current_state_for_user(user_id)

         new_fields = {
-            "state": presence
+            "state": presence,
+            "status_msg": status_msg if presence != PresenceState.OFFLINE else None
         }

-        if not ignore_status_msg:
-            msg = status_msg if presence != PresenceState.OFFLINE else None
-            new_fields["status_msg"] = msg
-
         if presence == PresenceState.ONLINE:
             new_fields["last_active_ts"] = self.clock.time_msec()

View File

@@ -13,15 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import logging
+
 from twisted.internet import defer

-import synapse.types
 from synapse.api.errors import SynapseError, AuthError, CodeMessageException
-from synapse.types import UserID
+from synapse.types import UserID, Requester
 from ._base import BaseHandler

-import logging
-
 logger = logging.getLogger(__name__)

@@ -165,9 +165,7 @@ class ProfileHandler(BaseHandler):
                 try:
                     # Assume the user isn't a guest because we don't let guests set
                     # profile or avatar data.
-                    # XXX why are we recreating `requester` here for each room?
-                    # what was wrong with the `requester` we were passed?
-                    requester = synapse.types.create_requester(user)
+                    requester = Requester(user, "", False)
                     yield handler.update_membership(
                         requester,
                         user,
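The two branches construct the Requester differently: create_requester(user) on the left versus a hand-built Requester(user, "", False) on the right. A stand-in namedtuple showing the three-field shape both call sites rely on (an approximation of synapse.types of that era, not the real definition):

import collections

# Approximate shape only: user, the access token row id, and a guest flag.
Requester = collections.namedtuple(
    "Requester", ["user", "access_token_id", "is_guest"]
)

def create_requester(user, access_token_id=None, is_guest=False):
    # The helper the left-hand branch calls instead of building the tuple
    # by hand at every call site.
    return Requester(user, access_token_id, is_guest)

old_style = Requester("@alice:example.com", "", False)   # right-hand branch
new_style = create_requester("@alice:example.com")       # left-hand branch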

View File

@@ -14,19 +14,18 @@
 # limitations under the License.

 """Contains functions for registering clients."""

-import logging
-import urllib
-
 from twisted.internet import defer

-import synapse.types
+from synapse.types import UserID, Requester
 from synapse.api.errors import (
     AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
 )
-from synapse.http.client import CaptchaServerHttpClient
-from synapse.types import UserID
-from synapse.util.async import run_on_reactor
-
 from ._base import BaseHandler
+from synapse.util.async import run_on_reactor
+from synapse.http.client import CaptchaServerHttpClient
+
+import logging
+import urllib

 logger = logging.getLogger(__name__)

@@ -53,13 +52,6 @@ class RegistrationHandler(BaseHandler):
                 Codes.INVALID_USERNAME
             )

-        if localpart[0] == '_':
-            raise SynapseError(
-                400,
-                "User ID may not begin with _",
-                Codes.INVALID_USERNAME
-            )
-
         user = UserID(localpart, self.hs.hostname)
         user_id = user.to_string()

@@ -107,13 +99,8 @@ class RegistrationHandler(BaseHandler):
             localpart : The local part of the user ID to register. If None,
               one will be generated.
             password (str) : The password to assign to this user so they can
              login again. This can be None which means they cannot login again
              via a password (e.g. the user is an application service user).
-            generate_token (bool): Whether a new access token should be
-              generated. Having this be True should be considered deprecated,
-              since it offers no means of associating a device_id with the
-              access_token. Instead you should call auth_handler.issue_access_token
-              after registration.
         Returns:
             A tuple of (user_id, access_token).
         Raises:

@@ -209,13 +196,15 @@ class RegistrationHandler(BaseHandler):
                 user_id, allowed_appservice=service
             )

+        token = self.auth_handler().generate_access_token(user_id)
         yield self.store.register(
             user_id=user_id,
+            token=token,
             password_hash="",
             appservice_id=service_id,
             create_profile_with_localpart=user.localpart,
         )
-        defer.returnValue(user_id)
+        defer.returnValue((user_id, token))

     @defer.inlineCallbacks
     def check_recaptcha(self, ip, private_key, challenge, response):

@@ -371,7 +360,7 @@ class RegistrationHandler(BaseHandler):
         defer.returnValue(data)

     @defer.inlineCallbacks
-    def get_or_create_user(self, localpart, displayname, duration_in_ms,
+    def get_or_create_user(self, localpart, displayname, duration_seconds,
                            password_hash=None):
         """Creates a new user if the user does not exist,
         else revokes all previous access tokens and generates a new one.

@@ -401,8 +390,8 @@ class RegistrationHandler(BaseHandler):
         user = UserID(localpart, self.hs.hostname)
         user_id = user.to_string()

-        token = self.auth_handler().generate_access_token(
-            user_id, None, duration_in_ms)
+        token = self.auth_handler().generate_short_term_login_token(
+            user_id, duration_seconds)

         if need_register:
             yield self.store.register(

@@ -418,9 +407,8 @@ class RegistrationHandler(BaseHandler):
         if displayname is not None:
             logger.info("setting user display name: %s -> %s", user_id, displayname)
             profile_handler = self.hs.get_handlers().profile_handler
-            requester = synapse.types.create_requester(user)
             yield profile_handler.set_displayname(
-                user, requester, displayname
+                user, Requester(user, token, False), displayname
             )

         defer.returnValue((user_id, token))
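Note the unit change in get_or_create_user: duration_in_ms on the left, duration_seconds on the right, both ultimately feeding a "time < ..." expiry caveat expressed in milliseconds. A one-line sketch of the conversion a caller has to get right (hypothetical helper, not a Synapse API):

def expiry_caveat(now_ms, duration_seconds):
    # generate_short_term_login_token takes seconds; generate_access_token
    # took milliseconds, hence the * 1000 here.
    return "time < %d" % (now_ms + duration_seconds * 1000,)

assert expiry_caveat(0, 60) == "time < 60000"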

View File

@@ -345,8 +345,8 @@ class RoomCreationHandler(BaseHandler):
 class RoomListHandler(BaseHandler):
     def __init__(self, hs):
         super(RoomListHandler, self).__init__(hs)
-        self.response_cache = ResponseCache(hs)
-        self.remote_list_request_cache = ResponseCache(hs)
+        self.response_cache = ResponseCache()
+        self.remote_list_request_cache = ResponseCache()
         self.remote_list_cache = {}
         self.fetch_looping_call = hs.get_clock().looping_call(
             self.fetch_all_remote_lists, REMOTE_ROOM_LIST_POLL_INTERVAL

View File

@@ -14,22 +14,24 @@
 # limitations under the License.

-import logging
-
-from signedjson.key import decode_verify_key_bytes
-from signedjson.sign import verify_signed_json
 from twisted.internet import defer
-from unpaddedbase64 import decode_base64

-import synapse.types
+from ._base import BaseHandler
+
+from synapse.types import UserID, RoomID, Requester
 from synapse.api.constants import (
     EventTypes, Membership,
 )
 from synapse.api.errors import AuthError, SynapseError, Codes
-from synapse.types import UserID, RoomID
 from synapse.util.async import Linearizer
 from synapse.util.distributor import user_left_room, user_joined_room

-from ._base import BaseHandler
+from signedjson.sign import verify_signed_json
+from signedjson.key import decode_verify_key_bytes
+
+from unpaddedbase64 import decode_base64
+
+import logging

 logger = logging.getLogger(__name__)

@@ -141,7 +143,7 @@ class RoomMemberHandler(BaseHandler):
             third_party_signed=None,
             ratelimit=True,
     ):
-        key = (room_id,)
+        key = (target, room_id,)

         with (yield self.member_linearizer.queue(key)):
             result = yield self._update_membership(

@@ -313,7 +315,7 @@ class RoomMemberHandler(BaseHandler):
             )
             assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
         else:
-            requester = synapse.types.create_requester(target_user)
+            requester = Requester(target_user, None, False)

         message_handler = self.hs.get_handlers().message_handler
         prev_event = message_handler.deduplicate_state_event(event, context)
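The member_linearizer key is the behavioural change here (this is "only process one local membership event per room at a time"): keyed on (room_id,), all membership updates in a room serialise; keyed on (target, room_id), updates for different users in the same room can proceed concurrently. A toy illustration with a per-key lock table (threading stands in for Synapse's Linearizer):

from collections import defaultdict
import threading

locks = defaultdict(threading.Lock)  # one lock per linearizer key

def update_membership(target, room_id, key_on_target=True):
    # (room_id,) serialises every membership change in the room;
    # (target, room_id) only serialises changes for the same user.
    key = (target, room_id) if key_on_target else (room_id,)
    with locks[key]:
        pass  # perform the membership update under the per-key lock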

View File

@@ -138,7 +138,7 @@ class SyncHandler(object):
         self.presence_handler = hs.get_presence_handler()
         self.event_sources = hs.get_event_sources()
         self.clock = hs.get_clock()
-        self.response_cache = ResponseCache(hs)
+        self.response_cache = ResponseCache()

     def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
                                full_state=False):

@@ -464,10 +464,10 @@ class SyncHandler(object):
         else:
             state = {}

         defer.returnValue({
             (e.type, e.state_key): e
             for e in sync_config.filter_collection.filter_room_state(state.values())
         })

     @defer.inlineCallbacks
     def unread_notifs_for_room_id(self, room_id, sync_config):

@@ -485,9 +485,9 @@ class SyncHandler(object):
             )
             defer.returnValue(notifs)

         # There is no new information in this period, so your notification
         # count is whatever it was last time.
         defer.returnValue(None)

     @defer.inlineCallbacks
     def generate_sync_result(self, sync_config, since_token=None, full_state=False):

View File

@@ -155,7 +155,9 @@ class MatrixFederationHttpClient(object):
                 time_out=timeout / 1000. if timeout else 60,
             )

-            response = yield preserve_context_over_fn(send_request)
+            response = yield preserve_context_over_fn(
+                send_request,
+            )

             log_result = "%d %s" % (response.code, response.phrase,)
             break

View File

@@ -19,7 +19,6 @@ from synapse.api.errors import (
 )
 from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
 from synapse.util.caches import intern_dict
-from synapse.util.metrics import Measure
 import synapse.metrics
 import synapse.events

@@ -75,12 +74,12 @@ response_db_txn_duration = metrics.register_distribution(
 _next_request_id = 0


-def request_handler(include_metrics=False):
+def request_handler(report_metrics=True):
     """Decorator for ``wrap_request_handler``"""
-    return lambda request_handler: wrap_request_handler(request_handler, include_metrics)
+    return lambda request_handler: wrap_request_handler(request_handler, report_metrics)


-def wrap_request_handler(request_handler, include_metrics=False):
+def wrap_request_handler(request_handler, report_metrics):
     """Wraps a method that acts as a request handler with the necessary logging
     and exception handling.

@@ -104,56 +103,54 @@ def wrap_request_handler(request_handler, include_metrics=False):
         _next_request_id += 1

         with LoggingContext(request_id) as request_context:
-            with Measure(self.clock, "wrapped_request_handler"):
+            if report_metrics:
                 request_metrics = RequestMetrics()
-                request_metrics.start(self.clock, name=self.__class__.__name__)
+                request_metrics.start(self.clock)

-                request_context.request = request_id
-                with request.processing():
-                    try:
-                        with PreserveLoggingContext(request_context):
-                            if include_metrics:
-                                yield request_handler(self, request, request_metrics)
-                            else:
-                                yield request_handler(self, request)
-                    except CodeMessageException as e:
-                        code = e.code
-                        if isinstance(e, SynapseError):
-                            logger.info(
-                                "%s SynapseError: %s - %s", request, code, e.msg
-                            )
-                        else:
-                            logger.exception(e)
-                        outgoing_responses_counter.inc(request.method, str(code))
-                        respond_with_json(
-                            request, code, cs_exception(e), send_cors=True,
-                            pretty_print=_request_user_agent_is_curl(request),
-                            version_string=self.version_string,
-                        )
-                    except:
-                        logger.exception(
-                            "Failed handle request %s.%s on %r: %r",
-                            request_handler.__module__,
-                            request_handler.__name__,
-                            self,
-                            request
-                        )
-                        respond_with_json(
-                            request,
-                            500,
-                            {
-                                "error": "Internal server error",
-                                "errcode": Codes.UNKNOWN,
-                            },
-                            send_cors=True
-                        )
-                    finally:
-                        try:
-                            request_metrics.stop(
-                                self.clock, request
-                            )
-                        except Exception as e:
-                            logger.warn("Failed to stop metrics: %r", e)
+            request_context.request = request_id
+            with request.processing():
+                try:
+                    with PreserveLoggingContext(request_context):
+                        yield request_handler(self, request)
+                except CodeMessageException as e:
+                    code = e.code
+                    if isinstance(e, SynapseError):
+                        logger.info(
+                            "%s SynapseError: %s - %s", request, code, e.msg
+                        )
+                    else:
+                        logger.exception(e)
+                    outgoing_responses_counter.inc(request.method, str(code))
+                    respond_with_json(
+                        request, code, cs_exception(e), send_cors=True,
+                        pretty_print=_request_user_agent_is_curl(request),
+                        version_string=self.version_string,
+                    )
+                except:
+                    logger.exception(
+                        "Failed handle request %s.%s on %r: %r",
+                        request_handler.__module__,
+                        request_handler.__name__,
+                        self,
+                        request
+                    )
+                    respond_with_json(
+                        request,
+                        500,
+                        {
+                            "error": "Internal server error",
+                            "errcode": Codes.UNKNOWN,
+                        },
+                        send_cors=True
+                    )
+                finally:
+                    try:
+                        if report_metrics:
+                            request_metrics.stop(
+                                self.clock, request, self.__class__.__name__
+                            )
+                    except:
+                        pass

     return wrapped_request_handler

@@ -208,7 +205,6 @@ class JsonResource(HttpServer, resource.Resource):
     def register_paths(self, method, path_patterns, callback):
         for path_pattern in path_patterns:
-            logger.debug("Registering for %s %s", method, path_pattern.pattern)
             self.path_regexs.setdefault(method, []).append(
                 self._PathEntry(path_pattern, callback)
             )

@@ -223,9 +219,9 @@ class JsonResource(HttpServer, resource.Resource):
     # It does its own metric reporting because _async_render dispatches to
     # a callback and it's the class name of that callback we want to report
     # against rather than the JsonResource itself.
-    @request_handler(include_metrics=True)
+    @request_handler(report_metrics=False)
     @defer.inlineCallbacks
-    def _async_render(self, request, request_metrics):
+    def _async_render(self, request):
         """ This gets called from render() every time someone sends us a request.
             This checks if anyone has registered a callback for that method and
             path.

@@ -234,6 +230,9 @@ class JsonResource(HttpServer, resource.Resource):
             self._send_response(request, 200, {})
             return

+        request_metrics = RequestMetrics()
+        request_metrics.start(self.clock)
+
         # Loop through all the registered callbacks to check if the method
         # and path regex match
         for path_entry in self.path_regexs.get(request.method, []):

@@ -247,6 +246,12 @@ class JsonResource(HttpServer, resource.Resource):
             callback = path_entry.callback

+            servlet_instance = getattr(callback, "__self__", None)
+            if servlet_instance is not None:
+                servlet_classname = servlet_instance.__class__.__name__
+            else:
+                servlet_classname = "%r" % callback
+
             kwargs = intern_dict({
                 name: urllib.unquote(value).decode("UTF-8") if value else value
                 for name, value in m.groupdict().items()

@@ -257,13 +262,10 @@ class JsonResource(HttpServer, resource.Resource):
                 code, response = callback_return
                 self._send_response(request, code, response)

-            servlet_instance = getattr(callback, "__self__", None)
-            if servlet_instance is not None:
-                servlet_classname = servlet_instance.__class__.__name__
-            else:
-                servlet_classname = "%r" % callback
-            request_metrics.name = servlet_classname
+            try:
+                request_metrics.stop(self.clock, request, servlet_classname)
+            except:
+                pass

             return

@@ -295,12 +297,11 @@ class JsonResource(HttpServer, resource.Resource):

 class RequestMetrics(object):
-    def start(self, clock, name):
+    def start(self, clock):
         self.start = clock.time_msec()
         self.start_context = LoggingContext.current_context()
-        self.name = name

-    def stop(self, clock, request):
+    def stop(self, clock, request, servlet_classname):
         context = LoggingContext.current_context()

         tag = ""

@@ -314,26 +315,26 @@ class RequestMetrics(object):
             )
             return

-        incoming_requests_counter.inc(request.method, self.name, tag)
+        incoming_requests_counter.inc(request.method, servlet_classname, tag)

         response_timer.inc_by(
             clock.time_msec() - self.start, request.method,
-            self.name, tag
+            servlet_classname, tag
         )

         ru_utime, ru_stime = context.get_resource_usage()

         response_ru_utime.inc_by(
-            ru_utime, request.method, self.name, tag
+            ru_utime, request.method, servlet_classname, tag
         )
         response_ru_stime.inc_by(
-            ru_stime, request.method, self.name, tag
+            ru_stime, request.method, servlet_classname, tag
         )
         response_db_txn_count.inc_by(
-            context.db_txn_count, request.method, self.name, tag
+            context.db_txn_count, request.method, servlet_classname, tag
         )
         response_db_txn_duration.inc_by(
-            context.db_txn_duration, request.method, self.name, tag
+            context.db_txn_duration, request.method, servlet_classname, tag
         )
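Both branches bracket request handling with a RequestMetrics start/stop pair; they differ only in where the servlet name is supplied (at start on the left, at stop on the right). A stripped-down sketch of the bracket, with the Prometheus counters replaced by a plain dict and the servlet name invented:

import time

class RequestMetrics(object):
    # Counters reduced to a dict; the real code feeds Prometheus metrics.
    def start(self, name):
        self.name = name
        self.started = time.time()

    def stop(self, stats):
        elapsed = time.time() - self.started
        stats[self.name] = stats.get(self.name, 0.0) + elapsed

stats = {}
metrics = RequestMetrics()
metrics.start("RoomSendEventRestServlet")  # servlet name is hypothetical
# ... handle the request ...
metrics.stop(stats)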

View File

@@ -27,8 +27,7 @@ import gc

 from twisted.internet import reactor

 from .metric import (
-    CounterMetric, CallbackMetric, DistributionMetric, CacheMetric,
-    MemoryUsageMetric,
+    CounterMetric, CallbackMetric, DistributionMetric, CacheMetric
 )

@@ -67,21 +66,6 @@ class Metrics(object):
         return self._register(CacheMetric, *args, **kwargs)


-def register_memory_metrics(hs):
-    try:
-        import psutil
-        process = psutil.Process()
-        process.memory_info().rss
-    except (ImportError, AttributeError):
-        logger.warn(
-            "psutil is not installed or incorrect version."
-            " Disabling memory metrics."
-        )
-        return
-    metric = MemoryUsageMetric(hs, psutil)
-    all_metrics.append(metric)
-
-
 def get_metrics_for(pkg_name):
     """ Returns a Metrics instance for conveniently creating metrics
     namespaced with the given name prefix. """

View File

@@ -153,43 +153,3 @@ class CacheMetric(object):
             """%s:total{name="%s"} %d""" % (self.name, self.cache_name, total),
             """%s:size{name="%s"} %d""" % (self.name, self.cache_name, size),
         ]
-
-
-class MemoryUsageMetric(object):
-    """Keeps track of the current memory usage, using psutil.
-
-    The class will keep the current min/max/sum/counts of rss over the last
-    WINDOW_SIZE_SEC, by polling UPDATE_HZ times per second
-    """
-
-    UPDATE_HZ = 2  # number of times to get memory per second
-    WINDOW_SIZE_SEC = 30  # the size of the window in seconds
-
-    def __init__(self, hs, psutil):
-        clock = hs.get_clock()
-        self.memory_snapshots = []
-
-        self.process = psutil.Process()
-
-        clock.looping_call(self._update_curr_values, 1000 / self.UPDATE_HZ)
-
-    def _update_curr_values(self):
-        max_size = self.UPDATE_HZ * self.WINDOW_SIZE_SEC
-        self.memory_snapshots.append(self.process.memory_info().rss)
-        self.memory_snapshots[:] = self.memory_snapshots[-max_size:]
-
-    def render(self):
-        if not self.memory_snapshots:
-            return []
-
-        max_rss = max(self.memory_snapshots)
-        min_rss = min(self.memory_snapshots)
-        sum_rss = sum(self.memory_snapshots)
-        len_rss = len(self.memory_snapshots)
-
-        return [
-            "process_psutil_rss:max %d" % max_rss,
-            "process_psutil_rss:min %d" % min_rss,
-            "process_psutil_rss:total %d" % sum_rss,
-            "process_psutil_rss:count %d" % len_rss,
-        ]
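The removed MemoryUsageMetric keeps a bounded sliding window of rss samples (UPDATE_HZ samples per second over WINDOW_SIZE_SEC); the windowing itself is just a negative-index slice. A minimal sketch, with psutil and the looping clock omitted:

UPDATE_HZ = 2         # samples per second
WINDOW_SIZE_SEC = 30  # window length in seconds

snapshots = []

def record_rss(rss_bytes):
    # Append, then keep only the newest UPDATE_HZ * WINDOW_SIZE_SEC samples.
    snapshots.append(rss_bytes)
    snapshots[:] = snapshots[-(UPDATE_HZ * WINDOW_SIZE_SEC):]

for sample in range(100):
    record_rss(sample)
assert len(snapshots) == 60 and snapshots[-1] == 99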

View File

@@ -20,7 +20,6 @@ from synapse.api.errors import AuthError
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.metrics import Measure
from synapse.types import StreamToken from synapse.types import StreamToken
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
import synapse.metrics import synapse.metrics
@@ -68,8 +67,10 @@ class _NotifierUserStream(object):
so that it can remove itself from the indexes in the Notifier class. so that it can remove itself from the indexes in the Notifier class.
""" """
def __init__(self, user_id, rooms, current_token, time_now_ms): def __init__(self, user_id, rooms, current_token, time_now_ms,
appservice=None):
self.user_id = user_id self.user_id = user_id
self.appservice = appservice
self.rooms = set(rooms) self.rooms = set(rooms)
self.current_token = current_token self.current_token = current_token
self.last_notified_ms = time_now_ms self.last_notified_ms = time_now_ms
@@ -106,6 +107,11 @@ class _NotifierUserStream(object):
notifier.user_to_user_stream.pop(self.user_id) notifier.user_to_user_stream.pop(self.user_id)
if self.appservice:
notifier.appservice_to_user_streams.get(
self.appservice, set()
).discard(self)
def count_listeners(self): def count_listeners(self):
return len(self.notify_deferred.observers()) return len(self.notify_deferred.observers())
@@ -136,6 +142,7 @@ class Notifier(object):
     def __init__(self, hs):
         self.user_to_user_stream = {}
         self.room_to_user_streams = {}
+        self.appservice_to_user_streams = {}

         self.event_sources = hs.get_event_sources()
         self.store = hs.get_datastore()
@@ -161,6 +168,8 @@ class Notifier(object):
                 all_user_streams |= x
             for x in self.user_to_user_stream.values():
                 all_user_streams.add(x)
+            for x in self.appservice_to_user_streams.values():
+                all_user_streams |= x

             return sum(stream.count_listeners() for stream in all_user_streams)
         metrics.register_callback("listeners", count_listeners)
@@ -173,6 +182,10 @@ class Notifier(object):
             "users",
             lambda: len(self.user_to_user_stream),
         )
+        metrics.register_callback(
+            "appservices",
+            lambda: count(bool, self.appservice_to_user_streams.values()),
+        )

     def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
                           extra_users=[]):
@@ -215,7 +228,21 @@ class Notifier(object):
     def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
         """Notify any user streams that are interested in this room event"""
         # poke any interested application service.
-        self.appservice_handler.notify_interested_services(room_stream_id)
+        self.appservice_handler.notify_interested_services(event)
+
+        app_streams = set()
+
+        for appservice in self.appservice_to_user_streams:
+            # TODO (kegan): Redundant appservice listener checks?
+            # App services will already be in the room_to_user_streams set, but
+            # that isn't enough. They need to be checked here in order to
+            # receive *invites* for users they are interested in. Does this
+            # make the room_to_user_streams check somewhat obsolete?
+            if appservice.is_interested(event):
+                app_user_streams = self.appservice_to_user_streams.get(
+                    appservice, set()
+                )
+                app_streams |= app_user_streams

         if event.type == EventTypes.Member and event.membership == Membership.JOIN:
             self._user_joined_room(event.state_key, event.room_id)
@@ -224,33 +251,34 @@ class Notifier(object):
             "room_key", room_stream_id,
             users=extra_users,
             rooms=[event.room_id],
+            extra_streams=app_streams,
         )

-    def on_new_event(self, stream_key, new_token, users=[], rooms=[]):
+    def on_new_event(self, stream_key, new_token, users=[], rooms=[],
+                     extra_streams=set()):
         """ Used to inform listeners that something has happened event wise.

         Will wake up all listeners for the given users and rooms.
         """
         with PreserveLoggingContext():
-            with Measure(self.clock, "on_new_event"):
-                user_streams = set()
-
-                for user in users:
-                    user_stream = self.user_to_user_stream.get(str(user))
-                    if user_stream is not None:
-                        user_streams.add(user_stream)
-
-                for room in rooms:
-                    user_streams |= self.room_to_user_streams.get(room, set())
-
-                time_now_ms = self.clock.time_msec()
-                for user_stream in user_streams:
-                    try:
-                        user_stream.notify(stream_key, new_token, time_now_ms)
-                    except:
-                        logger.exception("Failed to notify listener")
-
-                self.notify_replication()
+            user_streams = set()
+
+            for user in users:
+                user_stream = self.user_to_user_stream.get(str(user))
+                if user_stream is not None:
+                    user_streams.add(user_stream)
+
+            for room in rooms:
+                user_streams |= self.room_to_user_streams.get(room, set())
+
+            time_now_ms = self.clock.time_msec()
+            for user_stream in user_streams:
+                try:
+                    user_stream.notify(stream_key, new_token, time_now_ms)
+                except:
+                    logger.exception("Failed to notify listener")
+
+            self.notify_replication()

     def on_new_replication_data(self):
         """Used to inform replication listeners that something has happened
@@ -266,6 +294,7 @@ class Notifier(object):
         """
         user_stream = self.user_to_user_stream.get(user_id)
         if user_stream is None:
+            appservice = yield self.store.get_app_service_by_user_id(user_id)
             current_token = yield self.event_sources.get_current_token()
             if room_ids is None:
                 rooms = yield self.store.get_rooms_for_user(user_id)
@@ -273,6 +302,7 @@ class Notifier(object):
             user_stream = _NotifierUserStream(
                 user_id=user_id,
                 rooms=room_ids,
                appservice=appservice,
                 current_token=current_token,
                 time_now_ms=self.clock.time_msec(),
             )
@@ -447,6 +477,11 @@ class Notifier(object):
             s = self.room_to_user_streams.setdefault(room, set())
             s.add(user_stream)

+        if user_stream.appservice:
+            self.appservice_to_user_stream.setdefault(
+                user_stream.appservice, set()
+            ).add(user_stream)
+
     def _user_joined_room(self, user_id, room_id):
         new_user_stream = self.user_to_user_stream.get(user_id)
         if new_user_stream is not None:

View File

@@ -38,16 +38,15 @@ class ActionGenerator:
     @defer.inlineCallbacks
     def handle_push_actions_for_event(self, event, context):
-        with Measure(self.clock, "evaluator_for_event"):
+        with Measure(self.clock, "handle_push_actions_for_event"):
             bulk_evaluator = yield evaluator_for_event(
-                event, self.hs, self.store, context.state_group, context.current_state
+                event, self.hs, self.store, context.current_state
             )

-        with Measure(self.clock, "action_for_event_by_user"):
             actions_by_user = yield bulk_evaluator.action_for_event_by_user(
                 event, context.current_state
             )

             context.push_actions = [
                 (uid, actions) for uid, actions in actions_by_user.items()
             ]

View File

@@ -217,27 +217,6 @@ BASE_APPEND_OVERRIDE_RULES = [
             'dont_notify'
         ]
     },
-    # This was changed from underride to override so it's closer in priority
-    # to the content rules where the user name highlight rule lives. This
-    # way a room rule is lower priority than both but a custom override rule
-    # is higher priority than both.
-    {
-        'rule_id': 'global/override/.m.rule.contains_display_name',
-        'conditions': [
-            {
-                'kind': 'contains_display_name'
-            }
-        ],
-        'actions': [
-            'notify',
-            {
-                'set_tweak': 'sound',
-                'value': 'default'
-            }, {
-                'set_tweak': 'highlight'
-            }
-        ]
-    },
 ]
@@ -263,6 +242,23 @@ BASE_APPEND_UNDERRIDE_RULES = [
             }
         ]
     },
+    {
+        'rule_id': 'global/underride/.m.rule.contains_display_name',
+        'conditions': [
+            {
+                'kind': 'contains_display_name'
+            }
+        ],
+        'actions': [
+            'notify',
+            {
+                'set_tweak': 'sound',
+                'value': 'default'
+            }, {
+                'set_tweak': 'highlight'
+            }
+        ]
+    },
     {
         'rule_id': 'global/underride/.m.rule.room_one_to_one',
         'conditions': [

View File

@@ -36,11 +36,35 @@ def _get_rules(room_id, user_ids, store):
 @defer.inlineCallbacks
-def evaluator_for_event(event, hs, store, state_group, current_state):
-    rules_by_user = yield store.bulk_get_push_rules_for_room(
-        event.room_id, state_group, current_state
+def evaluator_for_event(event, hs, store, current_state):
+    room_id = event.room_id
+
+    # We also will want to generate notifs for other people in the room so
+    # their unread counts are correct in the event stream, but to avoid
+    # generating them for bot / AS users etc, we only do so for people who've
+    # sent a read receipt into the room.
+    local_users_in_room = set(
+        e.state_key for e in current_state.values()
+        if e.type == EventTypes.Member and e.membership == Membership.JOIN
+        and hs.is_mine_id(e.state_key)
     )

+    # users in the room who have pushers need to get push rules run because
+    # that's how their pushers work
+    if_users_with_pushers = yield store.get_if_users_have_pushers(
+        local_users_in_room
+    )
+    user_ids = set(
+        uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
+    )
+
+    users_with_receipts = yield store.get_users_with_read_receipts_in_room(room_id)
+
+    # any users with pushers must be ours: they have pushers
+    for uid in users_with_receipts:
+        if uid in local_users_in_room:
+            user_ids.add(uid)
+
     # if this event is an invite event, we may need to run rules for the user
     # who's been invited, otherwise they won't get told they've been invited
     if event.type == 'm.room.member' and event.content['membership'] == 'invite':
@@ -48,12 +72,12 @@ def evaluator_for_event(event, hs, store, current_state):
         if invited_user and hs.is_mine_id(invited_user):
             has_pusher = yield store.user_has_pusher(invited_user)
             if has_pusher:
-                rules_by_user[invited_user] = yield store.get_push_rules_for_user(
-                    invited_user
-                )
+                user_ids.add(invited_user)
+
+    rules_by_user = yield _get_rules(room_id, user_ids, store)

     defer.returnValue(BulkPushRuleEvaluator(
-        event.room_id, rules_by_user, store
+        room_id, rules_by_user, user_ids, store
     ))
@@ -66,9 +90,10 @@ class BulkPushRuleEvaluator:
     the same logic to run the actual rules, but could be optimised further
     (see https://matrix.org/jira/browse/SYN-562)
     """
-    def __init__(self, room_id, rules_by_user, store):
+    def __init__(self, room_id, rules_by_user, users_in_room, store):
         self.room_id = room_id
         self.rules_by_user = rules_by_user
+        self.users_in_room = users_in_room
         self.store = store

     @defer.inlineCallbacks
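
Taken together, the right-hand evaluator_for_event reduces to set arithmetic over local joined members: users with pushers, plus local users who have sent a read receipt, plus the invited local user when the event is an invite. A self-contained sketch of that selection (the names mirror the hunk; the data is invented):

# Illustrative data; in the real code these come from current_state and the store.
local_users_in_room = {"@alice:hs", "@bob:hs", "@bot:hs"}
if_users_with_pushers = {"@alice:hs": True, "@bob:hs": False, "@bot:hs": False}
users_with_receipts = {"@bob:hs", "@remote:elsewhere"}
invited_user = None

user_ids = {uid for uid, have in if_users_with_pushers.items() if have}
# the receipt requirement is what filters out bot / AS members
user_ids |= users_with_receipts & local_users_in_room
if invited_user:
    user_ids.add(invited_user)

print(sorted(user_ids))  # ['@alice:hs', '@bob:hs']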

View File

@@ -14,7 +14,6 @@
 # limitations under the License.

 from twisted.internet import defer, reactor
-from twisted.internet.error import AlreadyCalled, AlreadyCancelled

 import logging
@@ -93,11 +92,7 @@ class EmailPusher(object):
     def on_stop(self):
         if self.timed_call:
-            try:
-                self.timed_call.cancel()
-            except (AlreadyCalled, AlreadyCancelled):
-                pass
-            self.timed_call = None
+            self.timed_call.cancel()

     @defer.inlineCallbacks
     def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
@@ -145,8 +140,9 @@ class EmailPusher(object):
         being run.
         """
         start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
-        fn = self.store.get_unread_push_actions_for_user_in_range_for_email
-        unprocessed = yield fn(self.user_id, start, self.max_stream_ordering)
+        unprocessed = yield self.store.get_unread_push_actions_for_user_in_range(
+            self.user_id, start, self.max_stream_ordering
+        )

         soonest_due_at = None
@@ -194,10 +190,7 @@ class EmailPusher(object):
                 soonest_due_at = should_notify_at

             if self.timed_call is not None:
-                try:
-                    self.timed_call.cancel()
-                except (AlreadyCalled, AlreadyCancelled):
-                    pass
+                self.timed_call.cancel()
                 self.timed_call = None

         if soonest_due_at is not None:
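
The try/except being dropped here (and in the HTTP pusher below) is not decorative: Twisted's IDelayedCall.cancel() raises if the call has already fired or was already cancelled, which is why the left-hand code guards it. A minimal demonstration of the failure mode the guard covers:

from twisted.internet import reactor
from twisted.internet.error import AlreadyCalled, AlreadyCancelled

timed_call = reactor.callLater(0, lambda: None)  # fires almost immediately

def stop():
    try:
        timed_call.cancel()  # raises AlreadyCalled once the call has fired
    except (AlreadyCalled, AlreadyCancelled):
        pass  # safe to ignore: there is nothing left to cancel

reactor.callLater(0.1, stop)
reactor.callLater(0.2, reactor.stop)
reactor.run()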

View File

@@ -16,7 +16,6 @@
 from synapse.push import PusherConfigException

 from twisted.internet import defer, reactor
-from twisted.internet.error import AlreadyCalled, AlreadyCancelled

 import logging
 import push_rule_evaluator
@@ -110,11 +109,7 @@ class HttpPusher(object):
     def on_stop(self):
         if self.timed_call:
-            try:
-                self.timed_call.cancel()
-            except (AlreadyCalled, AlreadyCancelled):
-                pass
-            self.timed_call = None
+            self.timed_call.cancel()

     @defer.inlineCallbacks
     def _process(self):
@@ -146,8 +141,7 @@ class HttpPusher(object):
         run once per pusher.
         """
-        fn = self.store.get_unread_push_actions_for_user_in_range_for_http
-        unprocessed = yield fn(
+        unprocessed = yield self.store.get_unread_push_actions_for_user_in_range(
             self.user_id, self.last_stream_ordering, self.max_stream_ordering
         )

View File

@@ -54,7 +54,7 @@ def get_context_for_event(state_handler, ev, user_id):
     room_state = yield state_handler.get_current_state(ev.room_id)

     # we no longer bother setting room_alias, and make room_name the
-    # human-readable name instead, be that m.room.name, an alias or
+    # human-readable name instead, be that m.room.namer, an alias or
     # a list of people in the room
     name = calculate_room_name(
         room_state, user_id, fallback_to_single_member=False

View File

@@ -102,14 +102,14 @@ class PusherPool:
             yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])

     @defer.inlineCallbacks
-    def remove_pushers_by_user(self, user_id, except_access_token_id=None):
+    def remove_pushers_by_user(self, user_id, except_token_ids=[]):
         all = yield self.store.get_all_pushers()
         logger.info(
-            "Removing all pushers for user %s except access tokens id %r",
-            user_id, except_access_token_id
+            "Removing all pushers for user %s except access tokens ids %r",
+            user_id, except_token_ids
         )
         for p in all:
-            if p['user_name'] == user_id and p['access_token'] != except_access_token_id:
+            if p['user_name'] == user_id and p['access_token'] not in except_token_ids:
                 logger.info(
                     "Removing pusher for app id %s, pushkey %s, user %s",
                     p['app_id'], p['pushkey'], p['user_name']

View File

@@ -51,9 +51,6 @@ CONDITIONAL_REQUIREMENTS = {
     "ldap": {
         "ldap3>=1.0": ["ldap3>=1.0"],
     },
-    "psutil": {
-        "psutil>=2.0.0": ["psutil>=2.0.0"],
-    },
 }

View File

@@ -41,7 +41,6 @@ STREAM_NAMES = (
     ("push_rules",),
     ("pushers",),
     ("state",),
-    ("caches",),
 )
@@ -71,7 +70,6 @@ class ReplicationResource(Resource):
     * "backfill": Old events that have been backfilled from other servers.
     * "push_rules": Per user changes to push rules.
     * "pushers": Per user changes to their pushers.
-    * "caches": Cache invalidations.

     The API takes two additional query parameters:
@@ -131,7 +129,6 @@ class ReplicationResource(Resource):
         push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
         pushers_token = self.store.get_pushers_stream_token()
         state_token = self.store.get_state_stream_token()
-        caches_token = self.store.get_cache_stream_token()

         defer.returnValue(_ReplicationToken(
             room_stream_token,
@@ -143,7 +140,6 @@ class ReplicationResource(Resource):
             push_rules_token,
             pushers_token,
             state_token,
-            caches_token,
         ))

     @request_handler()
@@ -192,7 +188,6 @@ class ReplicationResource(Resource):
             yield self.push_rules(writer, current_token, limit, request_streams)
             yield self.pushers(writer, current_token, limit, request_streams)
             yield self.state(writer, current_token, limit, request_streams)
-            yield self.caches(writer, current_token, limit, request_streams)
             self.streams(writer, current_token, request_streams)

         logger.info("Replicated %d rows", writer.total)
@@ -384,20 +379,6 @@ class ReplicationResource(Resource):
             "position", "type", "state_key", "event_id"
         ))

-    @defer.inlineCallbacks
-    def caches(self, writer, current_token, limit, request_streams):
-        current_position = current_token.caches
-
-        caches = request_streams.get("caches")
-
-        if caches is not None:
-            updated_caches = yield self.store.get_all_updated_caches(
-                caches, current_position, limit
-            )
-            writer.write_header_and_rows("caches", updated_caches, (
-                "position", "cache_func", "keys", "invalidation_ts"
-            ))
-

 class _Writer(object):
     """Writes the streams as a JSON object as the response to the request"""
@@ -426,7 +407,7 @@ class _Writer(object):
 class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
     "events", "presence", "typing", "receipts", "account_data", "backfill",
-    "push_rules", "pushers", "state", "caches",
+    "push_rules", "pushers", "state"
 ))):
     __slots__ = []

View File

@@ -14,43 +14,15 @@
 # limitations under the License.

 from synapse.storage._base import SQLBaseStore
-from synapse.storage.engines import PostgresEngine
 from twisted.internet import defer
-from ._slaved_id_tracker import SlavedIdTracker
-
-import logging
-
-logger = logging.getLogger(__name__)


 class BaseSlavedStore(SQLBaseStore):
     def __init__(self, db_conn, hs):
         super(BaseSlavedStore, self).__init__(hs)
-        if isinstance(self.database_engine, PostgresEngine):
-            self._cache_id_gen = SlavedIdTracker(
-                db_conn, "cache_invalidation_stream", "stream_id",
-            )
-        else:
-            self._cache_id_gen = None

     def stream_positions(self):
-        pos = {}
-        if self._cache_id_gen:
-            pos["caches"] = self._cache_id_gen.get_current_token()
-        return pos
+        return {}

     def process_replication(self, result):
-        stream = result.get("caches")
-        if stream:
-            for row in stream["rows"]:
-                (
-                    position, cache_func, keys, invalidation_ts,
-                ) = row
-
-                try:
-                    getattr(self, cache_func).invalidate(tuple(keys))
-                except AttributeError:
-                    logger.info("Got unexpected cache_func: %r", cache_func)
-            self._cache_id_gen.advance(int(stream["position"]))
         return defer.succeed(None)
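
The deleted process_replication body documents the contract of the caches stream: each row names a cached function and the key to drop, and the slave dispatches by attribute lookup. A reduced sketch of that dispatch pattern (FakeCache and the row data are invented stand-ins for Synapse's cache descriptors):

class FakeCache(object):
    def invalidate(self, key):
        print("invalidated %r" % (key,))

class SlaveStore(object):
    get_user_by_id = FakeCache()  # stands in for a @cached method's cache

    def process_caches_stream(self, rows):
        for position, cache_func, keys, invalidation_ts in rows:
            try:
                getattr(self, cache_func).invalidate(tuple(keys))
            except AttributeError:
                pass  # unknown cache on this worker; the real code logs it

SlaveStore().process_caches_stream([(1, "get_user_by_id", ["@alice:hs"], 0)])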

View File

@@ -28,13 +28,3 @@ class SlavedApplicationServiceStore(BaseSlavedStore):
     get_app_service_by_token = DataStore.get_app_service_by_token.__func__
     get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__
-    get_app_services = DataStore.get_app_services.__func__
-    get_new_events_for_appservice = DataStore.get_new_events_for_appservice.__func__
-    create_appservice_txn = DataStore.create_appservice_txn.__func__
-    get_appservices_by_state = DataStore.get_appservices_by_state.__func__
-    get_oldest_unsent_txn = DataStore.get_oldest_unsent_txn.__func__
-    _get_last_txn = DataStore._get_last_txn.__func__
-    complete_appservice_txn = DataStore.complete_appservice_txn.__func__
-    get_appservice_state = DataStore.get_appservice_state.__func__
-    set_appservice_last_pos = DataStore.set_appservice_last_pos.__func__
-    set_appservice_state = DataStore.set_appservice_state.__func__

View File

@@ -1,23 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from ._base import BaseSlavedStore
-from synapse.storage.directory import DirectoryStore
-
-
-class DirectoryStore(BaseSlavedStore):
-    get_aliases_for_room = DirectoryStore.__dict__[
-        "get_aliases_for_room"
-    ]

View File

@@ -93,11 +93,8 @@ class SlavedEventStore(BaseSlavedStore):
         StreamStore.__dict__["get_recent_event_ids_for_room"]
     )

-    get_unread_push_actions_for_user_in_range_for_http = (
-        DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__
-    )
-    get_unread_push_actions_for_user_in_range_for_email = (
-        DataStore.get_unread_push_actions_for_user_in_range_for_email.__func__
+    get_unread_push_actions_for_user_in_range = (
+        DataStore.get_unread_push_actions_for_user_in_range.__func__
     )
     get_push_action_users_in_range = (
         DataStore.get_push_action_users_in_range.__func__
@@ -145,15 +142,6 @@ class SlavedEventStore(BaseSlavedStore):
     _get_events_around_txn = DataStore._get_events_around_txn.__func__
     _get_some_state_from_cache = DataStore._get_some_state_from_cache.__func__

-    get_backfill_events = DataStore.get_backfill_events.__func__
-    _get_backfill_events = DataStore._get_backfill_events.__func__
-    get_missing_events = DataStore.get_missing_events.__func__
-    _get_missing_events = DataStore._get_missing_events.__func__
-
-    get_auth_chain = DataStore.get_auth_chain.__func__
-    get_auth_chain_ids = DataStore.get_auth_chain_ids.__func__
-    _get_auth_chain_ids_txn = DataStore._get_auth_chain_ids_txn.__func__
-
     def stream_positions(self):
         result = super(SlavedEventStore, self).stream_positions()
         result["events"] = self._stream_id_gen.get_current_token()

View File

@@ -1,33 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from ._base import BaseSlavedStore
-from synapse.storage import DataStore
-from synapse.storage.keys import KeyStore
-
-
-class SlavedKeyStore(BaseSlavedStore):
-    _get_server_verify_key = KeyStore.__dict__[
-        "_get_server_verify_key"
-    ]
-
-    get_server_verify_keys = DataStore.get_server_verify_keys.__func__
-    store_server_verify_key = DataStore.store_server_verify_key.__func__
-
-    get_server_certificate = DataStore.get_server_certificate.__func__
-    store_server_certificate = DataStore.store_server_certificate.__func__
-
-    get_server_keys_json = DataStore.get_server_keys_json.__func__
-    store_server_keys_json = DataStore.store_server_keys_json.__func__

View File

@@ -25,9 +25,6 @@ class SlavedRegistrationStore(BaseSlavedStore):
     # TODO: use the cached version and invalidate deleted tokens
     get_user_by_access_token = RegistrationStore.__dict__[
         "get_user_by_access_token"
-    ]
+    ].orig

     _query_for_auth = DataStore._query_for_auth.__func__
-
-    get_user_by_id = RegistrationStore.__dict__[
-        "get_user_by_id"
-    ]

View File

@@ -1,21 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from ._base import BaseSlavedStore
-from synapse.storage import DataStore
-
-
-class RoomStore(BaseSlavedStore):
-    get_public_room_ids = DataStore.get_public_room_ids.__func__

View File

@@ -1,30 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from twisted.internet import defer
-
-from ._base import BaseSlavedStore
-from synapse.storage import DataStore
-from synapse.storage.transactions import TransactionStore
-
-
-class TransactionStore(BaseSlavedStore):
-    get_destination_retry_timings = TransactionStore.__dict__[
-        "get_destination_retry_timings"
-    ].orig
-    _get_destination_retry_timings = DataStore._get_destination_retry_timings.__func__
-
-    # For now, don't record the destination retry timings
-    def set_destination_retry_timings(*args, **kwargs):
-        return defer.succeed(None)

View File

@@ -46,9 +46,6 @@ from synapse.rest.client.v2_alpha import (
     account_data,
     report_event,
     openid,
-    notifications,
-    devices,
-    thirdparty,
 )

 from synapse.http.server import JsonResource
@@ -93,6 +90,3 @@ class ClientRestResource(JsonResource):
         account_data.register_servlets(hs, client_resource)
         report_event.register_servlets(hs, client_resource)
         openid.register_servlets(hs, client_resource)
-        notifications.register_servlets(hs, client_resource)
-        devices.register_servlets(hs, client_resource)
-        thirdparty.register_servlets(hs, client_resource)

View File

@@ -28,10 +28,6 @@ logger = logging.getLogger(__name__)
 class WhoisRestServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)")

-    def __init__(self, hs):
-        super(WhoisRestServlet, self).__init__(hs)
-        self.handlers = hs.get_handlers()
-
     @defer.inlineCallbacks
     def on_GET(self, request, user_id):
         target_user = UserID.from_string(user_id)
@@ -86,10 +82,6 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
         "/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
     )

-    def __init__(self, hs):
-        super(PurgeHistoryRestServlet, self).__init__(hs)
-        self.handlers = hs.get_handlers()
-
     @defer.inlineCallbacks
     def on_POST(self, request, room_id, event_id):
         requester = yield self.auth.get_user_by_req(request)

View File

@@ -52,11 +52,8 @@ class ClientV1RestServlet(RestServlet):
     """

     def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer):
-        """
         self.hs = hs
+        self.handlers = hs.get_handlers()
         self.builder_factory = hs.get_event_builder_factory()
         self.auth = hs.get_v1auth()
         self.txns = HttpTransactionStore()

View File

@@ -36,10 +36,6 @@ def register_servlets(hs, http_server):
 class ClientDirectoryServer(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/directory/room/(?P<room_alias>[^/]*)$")

-    def __init__(self, hs):
-        super(ClientDirectoryServer, self).__init__(hs)
-        self.handlers = hs.get_handlers()
-
     @defer.inlineCallbacks
     def on_GET(self, request, room_alias):
         room_alias = RoomAlias.from_string(room_alias)
@@ -150,7 +146,6 @@ class ClientDirectoryListServer(ClientV1RestServlet):
     def __init__(self, hs):
         super(ClientDirectoryListServer, self).__init__(hs)
         self.store = hs.get_datastore()
-        self.handlers = hs.get_handlers()

     @defer.inlineCallbacks
     def on_GET(self, request, room_id):

View File

@@ -32,10 +32,6 @@ class EventStreamRestServlet(ClientV1RestServlet):
     DEFAULT_LONGPOLL_TIME_MS = 30000

-    def __init__(self, hs):
-        super(EventStreamRestServlet, self).__init__(hs)
-        self.event_stream_handler = hs.get_event_stream_handler()
-
     @defer.inlineCallbacks
     def on_GET(self, request):
         requester = yield self.auth.get_user_by_req(
@@ -50,6 +46,7 @@ class EventStreamRestServlet(ClientV1RestServlet):
             if "room_id" in request.args:
                 room_id = request.args["room_id"][0]

+            handler = self.handlers.event_stream_handler
             pagin_config = PaginationConfig.from_request(request)
             timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
             if "timeout" in request.args:
@@ -60,7 +57,7 @@ class EventStreamRestServlet(ClientV1RestServlet):
             as_client_event = "raw" not in request.args

-            chunk = yield self.event_stream_handler.get_stream(
+            chunk = yield handler.get_stream(
                 requester.user.to_string(),
                 pagin_config,
                 timeout=timeout,
@@ -83,12 +80,12 @@ class EventRestServlet(ClientV1RestServlet):
     def __init__(self, hs):
         super(EventRestServlet, self).__init__(hs)
         self.clock = hs.get_clock()
-        self.event_handler = hs.get_event_handler()

     @defer.inlineCallbacks
     def on_GET(self, request, event_id):
         requester = yield self.auth.get_user_by_req(request)
-        event = yield self.event_handler.get_event(requester.user, event_id)
+        handler = self.handlers.event_handler
+        event = yield handler.get_event(requester.user, event_id)

         time_now = self.clock.time_msec()
         if event:

View File

@@ -23,10 +23,6 @@ from .base import ClientV1RestServlet, client_path_patterns
 class InitialSyncRestServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/initialSync$")

-    def __init__(self, hs):
-        super(InitialSyncRestServlet, self).__init__(hs)
-        self.handlers = hs.get_handlers()
-
     @defer.inlineCallbacks
     def on_GET(self, request):
         requester = yield self.auth.get_user_by_req(request)

View File

@@ -54,9 +54,11 @@ class LoginRestServlet(ClientV1RestServlet):
         self.jwt_secret = hs.config.jwt_secret
         self.jwt_algorithm = hs.config.jwt_algorithm
         self.cas_enabled = hs.config.cas_enabled
+        self.cas_server_url = hs.config.cas_server_url
+        self.cas_required_attributes = hs.config.cas_required_attributes
+        self.servername = hs.config.server_name
+        self.http_client = hs.get_simple_http_client()
         self.auth_handler = self.hs.get_auth_handler()
-        self.device_handler = self.hs.get_device_handler()
-        self.handlers = hs.get_handlers()

     def on_GET(self, request):
         flows = []
@@ -107,6 +109,17 @@ class LoginRestServlet(ClientV1RestServlet):
                                              LoginRestServlet.JWT_TYPE):
                 result = yield self.do_jwt_login(login_submission)
                 defer.returnValue(result)
+            # TODO Delete this after all CAS clients switch to token login instead
+            elif self.cas_enabled and (login_submission["type"] ==
+                                       LoginRestServlet.CAS_TYPE):
+                uri = "%s/proxyValidate" % (self.cas_server_url,)
+                args = {
+                    "ticket": login_submission["ticket"],
+                    "service": login_submission["service"]
+                }
+                body = yield self.http_client.get_raw(uri, args)
+                result = yield self.do_cas_login(body)
+                defer.returnValue(result)
             elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
                 result = yield self.do_token_login(login_submission)
                 defer.returnValue(result)
@@ -132,23 +145,15 @@ class LoginRestServlet(ClientV1RestServlet):
         ).to_string()

         auth_handler = self.auth_handler
-        user_id = yield auth_handler.validate_password_login(
+        user_id, access_token, refresh_token = yield auth_handler.login_with_password(
             user_id=user_id,
-            password=login_submission["password"],
-        )
-        device_id = yield self._register_device(user_id, login_submission)
-        access_token, refresh_token = (
-            yield auth_handler.get_login_tuple_for_user_id(
-                user_id, device_id,
-                login_submission.get("initial_device_display_name")
-            )
-        )
+            password=login_submission["password"])

         result = {
             "user_id": user_id,  # may have changed
             "access_token": access_token,
             "refresh_token": refresh_token,
             "home_server": self.hs.hostname,
-            "device_id": device_id,
         }

         defer.returnValue((200, result))
@@ -160,23 +165,61 @@ class LoginRestServlet(ClientV1RestServlet):
         user_id = (
             yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
         )
-        device_id = yield self._register_device(user_id, login_submission)
-        access_token, refresh_token = (
-            yield auth_handler.get_login_tuple_for_user_id(
-                user_id, device_id,
-                login_submission.get("initial_device_display_name")
-            )
+        user_id, access_token, refresh_token = (
+            yield auth_handler.get_login_tuple_for_user_id(user_id)
         )
         result = {
             "user_id": user_id,  # may have changed
             "access_token": access_token,
             "refresh_token": refresh_token,
             "home_server": self.hs.hostname,
-            "device_id": device_id,
         }

         defer.returnValue((200, result))

+    # TODO Delete this after all CAS clients switch to token login instead
+    @defer.inlineCallbacks
+    def do_cas_login(self, cas_response_body):
+        user, attributes = self.parse_cas_response(cas_response_body)
+
+        for required_attribute, required_value in self.cas_required_attributes.items():
+            # If required attribute was not in CAS Response - Forbidden
+            if required_attribute not in attributes:
+                raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
+
+            # Also need to check value
+            if required_value is not None:
+                actual_value = attributes[required_attribute]
+                # If required attribute value does not match expected - Forbidden
+                if required_value != actual_value:
+                    raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
+
+        user_id = UserID.create(user, self.hs.hostname).to_string()
+        auth_handler = self.auth_handler
+        user_exists = yield auth_handler.does_user_exist(user_id)
+        if user_exists:
+            user_id, access_token, refresh_token = (
+                yield auth_handler.get_login_tuple_for_user_id(user_id)
+            )
+            result = {
+                "user_id": user_id,  # may have changed
+                "access_token": access_token,
+                "refresh_token": refresh_token,
+                "home_server": self.hs.hostname,
+            }
+        else:
+            user_id, access_token = (
+                yield self.handlers.registration_handler.register(localpart=user)
+            )
+            result = {
+                "user_id": user_id,  # may have changed
+                "access_token": access_token,
+                "home_server": self.hs.hostname,
+            }
+
+        defer.returnValue((200, result))
+
     @defer.inlineCallbacks
     def do_jwt_login(self, login_submission):
         token = login_submission.get("token", None)
@@ -202,27 +245,18 @@ class LoginRestServlet(ClientV1RestServlet):
         user_id = UserID.create(user, self.hs.hostname).to_string()

         auth_handler = self.auth_handler
-        registered_user_id = yield auth_handler.check_user_exists(user_id)
-        if registered_user_id:
-            device_id = yield self._register_device(
-                registered_user_id, login_submission
-            )
-            access_token, refresh_token = (
-                yield auth_handler.get_login_tuple_for_user_id(
-                    registered_user_id, device_id,
-                    login_submission.get("initial_device_display_name")
-                )
+        user_exists = yield auth_handler.does_user_exist(user_id)
+        if user_exists:
+            user_id, access_token, refresh_token = (
+                yield auth_handler.get_login_tuple_for_user_id(user_id)
             )
             result = {
-                "user_id": registered_user_id,
+                "user_id": user_id,  # may have changed
                 "access_token": access_token,
                 "refresh_token": refresh_token,
                 "home_server": self.hs.hostname,
             }
         else:
-            # TODO: we should probably check that the register isn't going
-            # to fonx/change our user_id before registering the device
-            device_id = yield self._register_device(user_id, login_submission)
             user_id, access_token = (
                 yield self.handlers.registration_handler.register(localpart=user)
             )
@@ -234,25 +268,32 @@ class LoginRestServlet(ClientV1RestServlet):
         defer.returnValue((200, result))

-    def _register_device(self, user_id, login_submission):
-        """Register a device for a user.
-
-        This is called after the user's credentials have been validated, but
-        before the access token has been issued.
-
-        Args:
-            (str) user_id: full canonical @user:id
-            (object) login_submission: dictionary supplied to /login call, from
-                which we pull device_id and initial_device_name
-        Returns:
-            defer.Deferred: (str) device_id
-        """
-        device_id = login_submission.get("device_id")
-        initial_display_name = login_submission.get(
-            "initial_device_display_name")
-        return self.device_handler.check_device_registered(
-            user_id, device_id, initial_display_name
-        )
+    # TODO Delete this after all CAS clients switch to token login instead
+    def parse_cas_response(self, cas_response_body):
+        root = ET.fromstring(cas_response_body)
+        if not root.tag.endswith("serviceResponse"):
+            raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
+        if not root[0].tag.endswith("authenticationSuccess"):
+            raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
+        for child in root[0]:
+            if child.tag.endswith("user"):
+                user = child.text
+            if child.tag.endswith("attributes"):
+                attributes = {}
+                for attribute in child:
+                    # ElementTree library expands the namespace in attribute tags
+                    # to the full URL of the namespace.
+                    # See (https://docs.python.org/2/library/xml.etree.elementtree.html)
+                    # We don't care about namespace here and it will always be encased in
+                    # curly braces, so we remove them.
+                    if "}" in attribute.tag:
+                        attributes[attribute.tag.split("}")[1]] = attribute.text
+                    else:
+                        attributes[attribute.tag] = attribute.text
+        if user is None or attributes is None:
+            raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
+
+        return (user, attributes)


 class SAML2RestServlet(ClientV1RestServlet):
@@ -261,7 +302,6 @@ class SAML2RestServlet(ClientV1RestServlet):
     def __init__(self, hs):
         super(SAML2RestServlet, self).__init__(hs)
         self.sp_config = hs.config.saml2_config_path
-        self.handlers = hs.get_handlers()

     @defer.inlineCallbacks
     def on_POST(self, request):
@@ -299,6 +339,18 @@ class SAML2RestServlet(ClientV1RestServlet):
         defer.returnValue((200, {"status": "not_authenticated"}))


+# TODO Delete this after all CAS clients switch to token login instead
+class CasRestServlet(ClientV1RestServlet):
+    PATTERNS = client_path_patterns("/login/cas", releases=())
+
+    def __init__(self, hs):
+        super(CasRestServlet, self).__init__(hs)
+        self.cas_server_url = hs.config.cas_server_url
+
+    def on_GET(self, request):
+        return (200, {"serverUrl": self.cas_server_url})
+
+
 class CasRedirectServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/login/cas/redirect", releases=())
@@ -330,8 +382,6 @@ class CasTicketServlet(ClientV1RestServlet):
         self.cas_server_url = hs.config.cas_server_url
         self.cas_service_url = hs.config.cas_service_url
         self.cas_required_attributes = hs.config.cas_required_attributes
-        self.auth_handler = hs.get_auth_handler()
-        self.handlers = hs.get_handlers()

     @defer.inlineCallbacks
     def on_GET(self, request):
@@ -364,13 +414,13 @@ class CasTicketServlet(ClientV1RestServlet):
         user_id = UserID.create(user, self.hs.hostname).to_string()
         auth_handler = self.auth_handler
-        registered_user_id = yield auth_handler.check_user_exists(user_id)
-        if not registered_user_id:
-            registered_user_id, _ = (
+        user_exists = yield auth_handler.does_user_exist(user_id)
+        if not user_exists:
+            user_id, _ = (
                 yield self.handlers.registration_handler.register(localpart=user)
             )

-        login_token = auth_handler.generate_short_term_login_token(registered_user_id)
+        login_token = auth_handler.generate_short_term_login_token(user_id)
         redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
                                                             login_token)
         request.redirect(redirect_url)
@@ -384,39 +434,30 @@ class CasTicketServlet(ClientV1RestServlet):
         return urlparse.urlunparse(url_parts)

     def parse_cas_response(self, cas_response_body):
-        user = None
-        attributes = None
-        try:
-            root = ET.fromstring(cas_response_body)
-            if not root.tag.endswith("serviceResponse"):
-                raise Exception("root of CAS response is not serviceResponse")
-            success = (root[0].tag.endswith("authenticationSuccess"))
-            for child in root[0]:
-                if child.tag.endswith("user"):
-                    user = child.text
-                if child.tag.endswith("attributes"):
-                    attributes = {}
-                    for attribute in child:
-                        # ElementTree library expands the namespace in
-                        # attribute tags to the full URL of the namespace.
-                        # We don't care about namespace here and it will always
-                        # be encased in curly braces, so we remove them.
-                        tag = attribute.tag
-                        if "}" in tag:
-                            tag = tag.split("}")[1]
-                        attributes[tag] = attribute.text
-            if user is None:
-                raise Exception("CAS response does not contain user")
-            if attributes is None:
-                raise Exception("CAS response does not contain attributes")
-        except Exception:
-            logger.error("Error parsing CAS response", exc_info=1)
-            raise LoginError(401, "Invalid CAS response",
-                             errcode=Codes.UNAUTHORIZED)
-        if not success:
-            raise LoginError(401, "Unsuccessful CAS response",
-                             errcode=Codes.UNAUTHORIZED)
-        return user, attributes
+        root = ET.fromstring(cas_response_body)
+        if not root.tag.endswith("serviceResponse"):
+            raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
+        if not root[0].tag.endswith("authenticationSuccess"):
+            raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
+        for child in root[0]:
+            if child.tag.endswith("user"):
+                user = child.text
+            if child.tag.endswith("attributes"):
+                attributes = {}
+                for attribute in child:
+                    # ElementTree library expands the namespace in attribute tags
+                    # to the full URL of the namespace.
+                    # See (https://docs.python.org/2/library/xml.etree.elementtree.html)
+                    # We don't care about namespace here and it will always be encased in
+                    # curly braces, so we remove them.
+                    if "}" in attribute.tag:
+                        attributes[attribute.tag.split("}")[1]] = attribute.text
+                    else:
+                        attributes[attribute.tag] = attribute.text
+        if user is None or attributes is None:
+            raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
+
+        return (user, attributes)


 def register_servlets(hs, http_server):
@@ -426,3 +467,5 @@ def register_servlets(hs, http_server):
     if hs.config.cas_enabled:
         CasRedirectServlet(hs).register(http_server)
         CasTicketServlet(hs).register(http_server)
+        CasRestServlet(hs).register(http_server)
+    # TODO PasswordResetRestServlet(hs).register(http_server)
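
The right-hand parse_cas_response is a plain ElementTree walk; the tag.split("}") dance is needed because ElementTree expands a namespaced tag to '{namespace-uri}tag'. A standalone run against a made-up CAS 2.0 success body shows the shape it handles:

import xml.etree.ElementTree as ET

# Illustrative CAS response body; the real one comes back from /proxyValidate.
body = """<cas:serviceResponse xmlns:cas='http://www.yale.edu/tp/cas'>
  <cas:authenticationSuccess>
    <cas:user>alice</cas:user>
    <cas:attributes><cas:email>alice@example.com</cas:email></cas:attributes>
  </cas:authenticationSuccess>
</cas:serviceResponse>"""

root = ET.fromstring(body)
print(root.tag)  # '{http://www.yale.edu/tp/cas}serviceResponse'
for child in root[0]:
    if child.tag.endswith("user"):
        print(child.text)  # 'alice'
    if child.tag.endswith("attributes"):
        for attribute in child:
            # strip the '{uri}' prefix, exactly as the servlet does
            print(attribute.tag.split("}")[1], attribute.text)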

View File

@@ -24,10 +24,6 @@ from synapse.http.servlet import parse_json_object_from_request
 class ProfileDisplaynameRestServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/displayname")

-    def __init__(self, hs):
-        super(ProfileDisplaynameRestServlet, self).__init__(hs)
-        self.handlers = hs.get_handlers()
-
     @defer.inlineCallbacks
     def on_GET(self, request, user_id):
         user = UserID.from_string(user_id)
@@ -66,10 +62,6 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
 class ProfileAvatarURLRestServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/avatar_url")

-    def __init__(self, hs):
-        super(ProfileAvatarURLRestServlet, self).__init__(hs)
-        self.handlers = hs.get_handlers()
-
     @defer.inlineCallbacks
     def on_GET(self, request, user_id):
         user = UserID.from_string(user_id)
@@ -107,10 +99,6 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
 class ProfileRestServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)")

-    def __init__(self, hs):
-        super(ProfileRestServlet, self).__init__(hs)
-        self.handlers = hs.get_handlers()
-
     @defer.inlineCallbacks
     def on_GET(self, request, user_id):
         user = UserID.from_string(user_id)

View File

@@ -52,10 +52,6 @@ class RegisterRestServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/register$", releases=(), include_in_unstable=False)

     def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
         super(RegisterRestServlet, self).__init__(hs)
         # sessions are stored as:
         # self.sessions = {
@@ -64,8 +60,6 @@ class RegisterRestServlet(ClientV1RestServlet):
         # TODO: persistent storage
         self.sessions = {}
         self.enable_registration = hs.config.enable_registration
-        self.auth_handler = hs.get_auth_handler()
-        self.handlers = hs.get_handlers()

     def on_GET(self, request):
         if self.hs.config.enable_registration_captcha:
@@ -305,10 +299,9 @@ class RegisterRestServlet(ClientV1RestServlet):
         user_localpart = register_json["user"].encode("utf-8")

         handler = self.handlers.registration_handler
-        user_id = yield handler.appservice_register(
+        (user_id, token) = yield handler.appservice_register(
             user_localpart, as_token
         )
-        token = yield self.auth_handler.issue_access_token(user_id)
         self._remove_session(session)
         defer.returnValue({
             "user_id": user_id,
@@ -384,7 +377,6 @@ class CreateUserRestServlet(ClientV1RestServlet):
         super(CreateUserRestServlet, self).__init__(hs)
         self.store = hs.get_datastore()
         self.direct_user_creation_max_duration = hs.config.user_creation_max_duration
-        self.handlers = hs.get_handlers()

     @defer.inlineCallbacks
     def on_POST(self, request):
@@ -437,7 +429,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
         user_id, token = yield handler.get_or_create_user(
             localpart=localpart,
             displayname=displayname,
-            duration_in_ms=(duration_seconds * 1000),
+            duration_seconds=duration_seconds,
             password_hash=password_hash
         )

View File

@@ -20,14 +20,12 @@ from .base import ClientV1RestServlet, client_path_patterns
from synapse.api.errors import SynapseError, Codes, AuthError from synapse.api.errors import SynapseError, Codes, AuthError
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.filtering import Filter
from synapse.types import UserID, RoomID, RoomAlias from synapse.types import UserID, RoomID, RoomAlias
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
import logging import logging
import urllib import urllib
import ujson as json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -35,10 +33,6 @@ logger = logging.getLogger(__name__)
class RoomCreateRestServlet(ClientV1RestServlet): class RoomCreateRestServlet(ClientV1RestServlet):
# No PATTERN; we have custom dispatch rules here # No PATTERN; we have custom dispatch rules here
def __init__(self, hs):
super(RoomCreateRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server): def register(self, http_server):
PATTERNS = "/createRoom" PATTERNS = "/createRoom"
register_txn_path(self, PATTERNS, http_server) register_txn_path(self, PATTERNS, http_server)
@@ -86,10 +80,6 @@ class RoomCreateRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for generic events # TODO: Needs unit testing for generic events
class RoomStateEventRestServlet(ClientV1RestServlet): class RoomStateEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomStateEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server): def register(self, http_server):
# /room/$roomid/state/$eventtype # /room/$roomid/state/$eventtype
no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$" no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$"
@@ -174,10 +164,6 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for generic events + feedback # TODO: Needs unit testing for generic events + feedback
class RoomSendEventRestServlet(ClientV1RestServlet): class RoomSendEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomSendEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server): def register(self, http_server):
# /rooms/$roomid/send/$event_type[/$txn_id] # /rooms/$roomid/send/$event_type[/$txn_id]
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)") PATTERNS = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)")
@@ -222,9 +208,6 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for room ID + alias joins # TODO: Needs unit testing for room ID + alias joins
class JoinRoomAliasServlet(ClientV1RestServlet): class JoinRoomAliasServlet(ClientV1RestServlet):
def __init__(self, hs):
super(JoinRoomAliasServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server): def register(self, http_server):
# /join/$room_identifier[/$txn_id] # /join/$room_identifier[/$txn_id]
@@ -311,10 +294,6 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
class RoomMemberListRestServlet(ClientV1RestServlet): class RoomMemberListRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/members$") PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/members$")
def __init__(self, hs):
super(RoomMemberListRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens) # TODO support Pagination stream API (limit/tokens)
@@ -341,10 +320,6 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
class RoomMessageListRestServlet(ClientV1RestServlet): class RoomMessageListRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$") PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$")
def __init__(self, hs):
super(RoomMessageListRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
@@ -352,19 +327,31 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
            request, default_limit=10,
        )
        as_client_event = "raw" not in request.args
-        filter_bytes = request.args.get("filter", None)
-        if filter_bytes:
-            filter_json = urllib.unquote(filter_bytes[-1]).decode("UTF-8")
-            event_filter = Filter(json.loads(filter_json))
-        else:
-            event_filter = None
        handler = self.handlers.message_handler
        msgs = yield handler.get_messages(
            room_id=room_id,
            requester=requester,
            pagin_config=pagination_config,
-            as_client_event=as_client_event,
-            event_filter=event_filter,
+            as_client_event=as_client_event
+        )
+
+        defer.returnValue((200, msgs))
+
+
+class RoomFileListRestServlet(ClientV1RestServlet):
+    PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/files$")
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, room_id):
+        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+        pagination_config = PaginationConfig.from_request(
+            request, default_limit=10, default_dir='b',
+        )
+        handler = self.handlers.message_handler
+        msgs = yield handler.get_files(
+            room_id=room_id,
+            requester=requester,
+            pagin_config=pagination_config,
        )

        defer.returnValue((200, msgs))
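
Note: the new /rooms/{roomId}/files endpoint added above mirrors the /messages pagination (it reuses PaginationConfig, defaulting to backwards pagination). A minimal sketch of how a client might call it, assuming a hypothetical homeserver URL and access token:

    import requests

    def fetch_room_files(base_url, access_token, room_id, limit=10):
        # Hypothetical client-side call; the endpoint shape is inferred from
        # the servlet above (GET with pagination params, guests allowed).
        resp = requests.get(
            "%s/_matrix/client/api/v1/rooms/%s/files" % (base_url, room_id),
            params={"access_token": access_token, "limit": limit, "dir": "b"},
        )
        resp.raise_for_status()
        return resp.json()  # a paginated chunk of file message events
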
@@ -374,10 +361,6 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
class RoomStateRestServlet(ClientV1RestServlet):
    PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/state$")

-    def __init__(self, hs):
-        super(RoomStateRestServlet, self).__init__(hs)
-        self.handlers = hs.get_handlers()

    @defer.inlineCallbacks
    def on_GET(self, request, room_id):
        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
@@ -395,10 +378,6 @@ class RoomStateRestServlet(ClientV1RestServlet):
class RoomInitialSyncRestServlet(ClientV1RestServlet):
    PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$")

-    def __init__(self, hs):
-        super(RoomInitialSyncRestServlet, self).__init__(hs)
-        self.handlers = hs.get_handlers()

    @defer.inlineCallbacks
    def on_GET(self, request, room_id):
        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
@@ -419,7 +398,6 @@ class RoomEventContext(ClientV1RestServlet):
    def __init__(self, hs):
        super(RoomEventContext, self).__init__(hs)
        self.clock = hs.get_clock()
-        self.handlers = hs.get_handlers()

    @defer.inlineCallbacks
    def on_GET(self, request, room_id, event_id):
@@ -456,10 +434,6 @@ class RoomEventContext(ClientV1RestServlet):
class RoomForgetRestServlet(ClientV1RestServlet):
-    def __init__(self, hs):
-        super(RoomForgetRestServlet, self).__init__(hs)
-        self.handlers = hs.get_handlers()

    def register(self, http_server):
        PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
        register_txn_path(self, PATTERNS, http_server)
@@ -498,10 +472,6 @@ class RoomForgetRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing
class RoomMembershipRestServlet(ClientV1RestServlet):
-    def __init__(self, hs):
-        super(RoomMembershipRestServlet, self).__init__(hs)
-        self.handlers = hs.get_handlers()

    def register(self, http_server):
        # /rooms/$roomid/[invite|join|leave]
        PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
@@ -582,10 +552,6 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
class RoomRedactEventRestServlet(ClientV1RestServlet):
-    def __init__(self, hs):
-        super(RoomRedactEventRestServlet, self).__init__(hs)
-        self.handlers = hs.get_handlers()

    def register(self, http_server):
        PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
        register_txn_path(self, PATTERNS, http_server)
@@ -668,10 +634,6 @@ class SearchRestServlet(ClientV1RestServlet):
"/search$" "/search$"
) )
def __init__(self, hs):
super(SearchRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
@@ -724,6 +686,7 @@ def register_servlets(hs, http_server):
    RoomCreateRestServlet(hs).register(http_server)
    RoomMemberListRestServlet(hs).register(http_server)
    RoomMessageListRestServlet(hs).register(http_server)
+    RoomFileListRestServlet(hs).register(http_server)
    JoinRoomAliasServlet(hs).register(http_server)
    RoomForgetRestServlet(hs).register(http_server)
    RoomMembershipRestServlet(hs).register(http_server)


@@ -25,9 +25,7 @@ import logging
logger = logging.getLogger(__name__)
-def client_v2_patterns(path_regex, releases=(0,),
-                       v2_alpha=True,
-                       unstable=True):
+def client_v2_patterns(path_regex, releases=(0,)):
    """Creates a regex compiled client path with the correct client path
    prefix.
@@ -37,12 +35,9 @@ def client_v2_patterns(path_regex, releases=(0,),
    Returns:
        SRE_Pattern
    """
-    patterns = []
-    if v2_alpha:
-        patterns.append(re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex))
-    if unstable:
-        unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable")
-        patterns.append(re.compile("^" + unstable_prefix + path_regex))
+    patterns = [re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)]
+    unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable")
+    patterns.append(re.compile("^" + unstable_prefix + path_regex))
    for release in releases:
        new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release)
        patterns.append(re.compile("^" + new_prefix + path_regex))
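
For reference, a standalone sketch of what the simplified helper now produces (CLIENT_V2_ALPHA_PREFIX is /_matrix/client/v2_alpha in Synapse; the trailing return is added here for illustration):

    import re

    CLIENT_V2_ALPHA_PREFIX = "/_matrix/client/v2_alpha"

    def client_v2_patterns(path_regex, releases=(0,)):
        # Every v2 endpoint is now served under /v2_alpha, /unstable and
        # each /r<N> release prefix, with no opt-out flags.
        patterns = [re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)]
        unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable")
        patterns.append(re.compile("^" + unstable_prefix + path_regex))
        for release in releases:
            new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release)
            patterns.append(re.compile("^" + new_prefix + path_regex))
        return patterns

    # e.g. client_v2_patterns("/account/password") matches
    #   /_matrix/client/v2_alpha/account/password
    #   /_matrix/client/unstable/account/password
    #   /_matrix/client/r0/account/password
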


@@ -28,40 +28,8 @@ import logging
logger = logging.getLogger(__name__)
class PasswordRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/password/email/requestToken$")
def __init__(self, hs):
super(PasswordRequestTokenRestServlet, self).__init__()
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
required = ['id_server', 'client_secret', 'email', 'send_attempt']
absent = []
for k in required:
if k not in body:
absent.append(k)
if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email']
)
if existingUid is None:
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
ret = yield self.identity_handler.requestEmailToken(**body)
defer.returnValue((200, ret))
class PasswordRestServlet(RestServlet):
-    PATTERNS = client_v2_patterns("/account/password$")
+    PATTERNS = client_v2_patterns("/account/password")

    def __init__(self, hs):
        super(PasswordRestServlet, self).__init__()
@@ -121,83 +89,8 @@ class PasswordRestServlet(RestServlet):
        return 200, {}
class DeactivateAccountRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/deactivate$")
def __init__(self, hs):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
super(DeactivateAccountRestServlet, self).__init__()
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
authed, result, params, _ = yield self.auth_handler.check_auth([
[LoginType.PASSWORD],
], body, self.hs.get_ip_from_request(request))
if not authed:
defer.returnValue((401, result))
user_id = None
requester = None
if LoginType.PASSWORD in result:
# if using password, they should also be logged in
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
if user_id != result[LoginType.PASSWORD]:
raise LoginError(400, "", Codes.UNKNOWN)
else:
logger.error("Auth succeeded but no known type!", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN)
# FIXME: Theoretically there is a race here wherein user resets password
# using threepid.
yield self.store.user_delete_access_tokens(user_id)
yield self.store.user_delete_threepids(user_id)
yield self.store.user_set_password_hash(user_id, None)
defer.returnValue((200, {}))
class ThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$")
def __init__(self, hs):
self.hs = hs
super(ThreepidRequestTokenRestServlet, self).__init__()
self.identity_handler = hs.get_handlers().identity_handler
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
required = ['id_server', 'client_secret', 'email', 'send_attempt']
absent = []
for k in required:
if k not in body:
absent.append(k)
if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email']
)
if existingUid is not None:
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
ret = yield self.identity_handler.requestEmailToken(**body)
defer.returnValue((200, ret))
class ThreepidRestServlet(RestServlet):
-    PATTERNS = client_v2_patterns("/account/3pid$")
+    PATTERNS = client_v2_patterns("/account/3pid")

    def __init__(self, hs):
        super(ThreepidRestServlet, self).__init__()
@@ -264,8 +157,5 @@ class ThreepidRestServlet(RestServlet):
def register_servlets(hs, http_server):
PasswordRequestTokenRestServlet(hs).register(http_server)
    PasswordRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server)
ThreepidRequestTokenRestServlet(hs).register(http_server)
    ThreepidRestServlet(hs).register(http_server)
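
Note: the requestToken servlets removed above all validate the same four fields before calling requestEmailToken. A sketch of the expected request body, reconstructed from the required-params check in the code (the values are made up):

    # POST body for .../email/requestToken; all four keys are required,
    # and a missing key produces a 400 with Codes.MISSING_PARAM.
    request_token_body = {
        "id_server": "matrix.org",        # identity server to send the mail
        "client_secret": "some_secret",   # opaque secret chosen by the client
        "email": "user@example.com",      # address to validate
        "send_attempt": 1,                # counter to de-duplicate resends
    }
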


@@ -1,100 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.http import servlet
from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
class DevicesRestServlet(servlet.RestServlet):
PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False)
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super(DevicesRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request)
devices = yield self.device_handler.get_devices_by_user(
requester.user.to_string()
)
defer.returnValue((200, {"devices": devices}))
class DeviceRestServlet(servlet.RestServlet):
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
releases=[], v2_alpha=False)
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super(DeviceRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
@defer.inlineCallbacks
def on_GET(self, request, device_id):
requester = yield self.auth.get_user_by_req(request)
device = yield self.device_handler.get_device(
requester.user.to_string(),
device_id,
)
defer.returnValue((200, device))
@defer.inlineCallbacks
def on_DELETE(self, request, device_id):
# XXX: it's not completely obvious we want to expose this endpoint.
# It allows the client to delete access tokens, which feels like a
# thing which merits extra auth. But if we want to do the interactive-
# auth dance, we should really make it possible to delete more than one
# device at a time.
requester = yield self.auth.get_user_by_req(request)
yield self.device_handler.delete_device(
requester.user.to_string(),
device_id,
)
defer.returnValue((200, {}))
@defer.inlineCallbacks
def on_PUT(self, request, device_id):
requester = yield self.auth.get_user_by_req(request)
body = servlet.parse_json_object_from_request(request)
yield self.device_handler.update_device(
requester.user.to_string(),
device_id,
body
)
defer.returnValue((200, {}))
def register_servlets(hs, http_server):
DevicesRestServlet(hs).register(http_server)
DeviceRestServlet(hs).register(http_server)
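
Note: the removed devices servlets exposed list/get/update/delete endpoints, registered only under the /unstable prefix (releases=[], v2_alpha=False). A hedged sketch of exercising them, with a placeholder homeserver and token, and assuming the body key for renames is display_name as handled by update_device:

    import requests

    BASE = "https://localhost:8008/_matrix/client/unstable"  # placeholder
    AUTH = {"access_token": "MDAx..."}                       # placeholder token

    # List this user's devices (GET /devices returns {"devices": [...]})
    devices = requests.get(BASE + "/devices", params=AUTH).json()["devices"]

    # Rename one of them (PUT /devices/<device_id>)
    requests.put(
        BASE + "/devices/%s" % devices[0]["device_id"],
        params=AUTH,
        json={"display_name": "laptop"},
    )
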


@@ -13,25 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
-
-import simplejson as json
-from canonicaljson import encode_canonical_json
from twisted.internet import defer

-import synapse.api.errors
-import synapse.server
-import synapse.types
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import UserID

+from canonicaljson import encode_canonical_json

from ._base import client_v2_patterns

+import logging
+import simplejson as json

logger = logging.getLogger(__name__)
class KeyUploadServlet(RestServlet):
    """
-    POST /keys/upload HTTP/1.1
+    POST /keys/upload/<device_id> HTTP/1.1
    Content-Type: application/json

    {
@@ -54,45 +53,23 @@ class KeyUploadServlet(RestServlet):
        },
    }
    """
PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$", PATTERNS = client_v2_patterns("/keys/upload/(?P<device_id>[^/]*)", releases=())
releases=())
    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
        super(KeyUploadServlet, self).__init__()
        self.store = hs.get_datastore()
        self.clock = hs.get_clock()
        self.auth = hs.get_auth()
-        self.device_handler = hs.get_device_handler()

    @defer.inlineCallbacks
    def on_POST(self, request, device_id):
        requester = yield self.auth.get_user_by_req(request)

        user_id = requester.user.to_string()
+        # TODO: Check that the device_id matches that in the authentication
+        # or derive the device_id from the authentication instead.
        body = parse_json_object_from_request(request)

-        if device_id is not None:
-            # passing the device_id here is deprecated; however, we allow it
-            # for now for compatibility with older clients.
-            if (requester.device_id is not None and
-                    device_id != requester.device_id):
-                logger.warning("Client uploading keys for a different device "
-                               "(logged in as %s, uploading for %s)",
-                               requester.device_id, device_id)
-        else:
-            device_id = requester.device_id
-
-        if device_id is None:
-            raise synapse.api.errors.SynapseError(
-                400,
-                "To upload keys, you must pass device_id when authenticating"
-            )

        time_now = self.clock.time_msec()

        # TODO: Validate the JSON to make sure it has the right keys.
@@ -125,12 +102,13 @@ class KeyUploadServlet(RestServlet):
            user_id, device_id, time_now, key_list
        )

-        # the device should have been registered already, but it may have been
-        # deleted due to a race with a DELETE request. Or we may be using an
-        # old access_token without an associated device_id. Either way, we
-        # need to double-check the device is registered to avoid ending up with
-        # keys without a corresponding device.
-        self.device_handler.check_device_registered(user_id, device_id)
+        result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
+        defer.returnValue((200, {"one_time_key_counts": result}))
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, device_id):
+        requester = yield self.auth.get_user_by_req(request)
+        user_id = requester.user.to_string()

        result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
        defer.returnValue((200, {"one_time_key_counts": result}))
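
Note: both versions of on_POST end by returning the current one_time_key_counts, and the older code also gains an on_GET for polling those counts. A sketch of the upload body shape, following the (elided) docstring above; the identifiers and algorithm names are illustrative:

    # Body for POST /keys/upload/<device_id>; both sections are optional.
    upload_body = {
        "device_keys": {
            "user_id": "@alice:example.com",   # illustrative
            "device_id": "ABCDEFG",
            "algorithms": ["m.olm.curve25519-aes-sha256"],
            "keys": {"curve25519:ABCDEFG": "..."},
            "signatures": {},
        },
        "one_time_keys": {
            "curve25519:AAAAAQ": "...",        # <algorithm>:<key_id> entries
        },
    }
    # Response: {"one_time_key_counts": {"curve25519": 10}}
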
@@ -184,19 +162,17 @@ class KeyQueryServlet(RestServlet):
    )

    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer):
-        """
        super(KeyQueryServlet, self).__init__()
+        self.store = hs.get_datastore()
        self.auth = hs.get_auth()
-        self.e2e_keys_handler = hs.get_e2e_keys_handler()
+        self.federation = hs.get_replication_layer()
+        self.is_mine = hs.is_mine

    @defer.inlineCallbacks
    def on_POST(self, request, user_id, device_id):
        yield self.auth.get_user_by_req(request)
        body = parse_json_object_from_request(request)
-        result = yield self.e2e_keys_handler.query_devices(body)
+        result = yield self.handle_request(body)
        defer.returnValue(result)
    @defer.inlineCallbacks
@@ -205,11 +181,45 @@ class KeyQueryServlet(RestServlet):
        auth_user_id = requester.user.to_string()
        user_id = user_id if user_id else auth_user_id
        device_ids = [device_id] if device_id else []
-        result = yield self.e2e_keys_handler.query_devices(
+        result = yield self.handle_request(
            {"device_keys": {user_id: device_ids}}
        )
        defer.returnValue(result)
@defer.inlineCallbacks
def handle_request(self, body):
local_query = []
remote_queries = {}
for user_id, device_ids in body.get("device_keys", {}).items():
user = UserID.from_string(user_id)
if self.is_mine(user):
if not device_ids:
local_query.append((user_id, None))
else:
for device_id in device_ids:
local_query.append((user_id, device_id))
else:
remote_queries.setdefault(user.domain, {})[user_id] = list(
device_ids
)
results = yield self.store.get_e2e_device_keys(local_query)
json_result = {}
for user_id, device_keys in results.items():
for device_id, json_bytes in device_keys.items():
json_result.setdefault(user_id, {})[device_id] = json.loads(
json_bytes
)
for destination, device_keys in remote_queries.items():
remote_result = yield self.federation.query_client_keys(
destination, {"device_keys": device_keys}
)
for user_id, keys in remote_result["device_keys"].items():
if user_id in device_keys:
json_result[user_id] = keys
defer.returnValue((200, {"device_keys": json_result}))
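
The inlined handle_request above re-implements, inside the servlet, what the newer code delegated to the E2eKeysHandler: split the device_keys query into local lookups and per-domain federation queries. A self-contained sketch of just the partitioning step (is_mine_domain stands in for hs.is_mine, which takes a parsed UserID in the real code):

    def partition_device_keys(query, is_mine_domain):
        # query: {user_id: [device_id, ...]} as in the /keys/query body
        local_query, remote_queries = [], {}
        for user_id, device_ids in query.items():
            domain = user_id.split(":", 1)[1]
            if is_mine_domain(domain):
                if not device_ids:
                    local_query.append((user_id, None))  # all devices
                else:
                    local_query.extend((user_id, d) for d in device_ids)
            else:
                remote_queries.setdefault(domain, {})[user_id] = list(device_ids)
        return local_query, remote_queries

    local, remote = partition_device_keys(
        {"@alice:example.com": [], "@bob:remote.org": ["DEV1"]},
        lambda d: d == "example.com",
    )
    # local  == [("@alice:example.com", None)]
    # remote == {"remote.org": {"@bob:remote.org": ["DEV1"]}}
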
class OneTimeKeyServlet(RestServlet):
    """


@@ -1,99 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.http.servlet import (
RestServlet, parse_string, parse_integer
)
from synapse.events.utils import (
serialize_event, format_event_for_client_v2_without_room_id,
)
from ._base import client_v2_patterns
import logging
logger = logging.getLogger(__name__)
class NotificationsServlet(RestServlet):
PATTERNS = client_v2_patterns("/notifications$", releases=())
def __init__(self, hs):
super(NotificationsServlet, self).__init__()
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
from_token = parse_string(request, "from", required=False)
limit = parse_integer(request, "limit", default=50)
limit = min(limit, 500)
push_actions = yield self.store.get_push_actions_for_user(
user_id, from_token, limit
)
receipts_by_room = yield self.store.get_receipts_for_user_with_orderings(
user_id, 'm.read'
)
notif_event_ids = [pa["event_id"] for pa in push_actions]
notif_events = yield self.store.get_events(notif_event_ids)
returned_push_actions = []
next_token = None
for pa in push_actions:
returned_pa = {
"room_id": pa["room_id"],
"profile_tag": pa["profile_tag"],
"actions": pa["actions"],
"ts": pa["received_ts"],
"event": serialize_event(
notif_events[pa["event_id"]],
self.clock.time_msec(),
event_format=format_event_for_client_v2_without_room_id,
),
}
if pa["room_id"] not in receipts_by_room:
returned_pa["read"] = False
else:
receipt = receipts_by_room[pa["room_id"]]
returned_pa["read"] = (
receipt["topological_ordering"], receipt["stream_ordering"]
) >= (
pa["topological_ordering"], pa["stream_ordering"]
)
returned_push_actions.append(returned_pa)
next_token = pa["stream_ordering"]
defer.returnValue((200, {
"notifications": returned_push_actions,
"next_token": next_token,
}))
def register_servlets(hs, http_server):
NotificationsServlet(hs).register(http_server)
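
Note: from the removed servlet above, the /notifications response pairs each push action with its serialized event and a read flag computed against the user's m.read receipt orderings. Its shape, reconstructed from the code with illustrative field values:

    notifications_response = {
        "notifications": [
            {
                "room_id": "!abc:example.com",   # illustrative
                "profile_tag": None,
                "actions": ["notify", {"set_tweak": "highlight"}],
                "ts": 1468419000000,
                "event": {"type": "m.room.message", "content": {}},  # trimmed
                "read": True,   # receipt ordering >= action ordering
            },
        ],
        "next_token": 12345,    # stream_ordering of the last action returned
    }
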


@@ -41,59 +41,17 @@ else:
logger = logging.getLogger(__name__)
class RegisterRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/register/email/requestToken$")
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super(RegisterRequestTokenRestServlet, self).__init__()
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
required = ['id_server', 'client_secret', 'email', 'send_attempt']
absent = []
for k in required:
if k not in body:
absent.append(k)
if len(absent) > 0:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email']
)
if existingUid is not None:
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
ret = yield self.identity_handler.requestEmailToken(**body)
defer.returnValue((200, ret))
class RegisterRestServlet(RestServlet):
-    PATTERNS = client_v2_patterns("/register$")
+    PATTERNS = client_v2_patterns("/register")
    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
        super(RegisterRestServlet, self).__init__()
        self.hs = hs
        self.auth = hs.get_auth()
        self.store = hs.get_datastore()
        self.auth_handler = hs.get_auth_handler()
        self.registration_handler = hs.get_handlers().registration_handler
        self.identity_handler = hs.get_handlers().identity_handler
-        self.device_handler = hs.get_device_handler()

    @defer.inlineCallbacks
    def on_POST(self, request):
@@ -112,6 +70,10 @@ class RegisterRestServlet(RestServlet):
"Do not understand membership kind: %s" % (kind,) "Do not understand membership kind: %s" % (kind,)
) )
if '/register/email/requestToken' in request.path:
ret = yield self.onEmailTokenRequest(request)
defer.returnValue(ret)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
# we do basic sanity checks here because the auth layer will store these # we do basic sanity checks here because the auth layer will store these
@@ -142,12 +104,11 @@ class RegisterRestServlet(RestServlet):
            # Set the desired user according to the AS API (which uses the
            # 'user' key not 'username'). Since this is a new addition, we'll
            # fallback to 'username' if they gave one.
-            desired_username = body.get("user", desired_username)
-
-            if isinstance(desired_username, basestring):
-                result = yield self._do_appservice_registration(
-                    desired_username, request.args["access_token"][0], body
-                )
+            if isinstance(body.get("user"), basestring):
+                desired_username = body["user"]
+            result = yield self._do_appservice_registration(
+                desired_username, request.args["access_token"][0]
+            )
            defer.returnValue((200, result))  # we throw for non 200 responses
            return
@@ -156,7 +117,7 @@ class RegisterRestServlet(RestServlet):
        # FIXME: Should we really be determining if this is shared secret
        # auth based purely on the 'mac' key?
        result = yield self._do_shared_secret_registration(
-            desired_username, desired_password, body
+            desired_username, desired_password, body["mac"]
        )
        defer.returnValue((200, result))  # we throw for non 200 responses
        return
@@ -196,12 +157,12 @@ class RegisterRestServlet(RestServlet):
            [LoginType.EMAIL_IDENTITY]
        ]

-        authed, auth_result, params, session_id = yield self.auth_handler.check_auth(
+        authed, result, params, session_id = yield self.auth_handler.check_auth(
            flows, body, self.hs.get_ip_from_request(request)
        )

        if not authed:
-            defer.returnValue((401, auth_result))
+            defer.returnValue((401, result))
            return

        if registered_user_id is not None:
@@ -209,58 +170,106 @@ class RegisterRestServlet(RestServlet):
"Already registered user ID %r for this session", "Already registered user ID %r for this session",
registered_user_id registered_user_id
) )
# don't re-register the email address access_token = yield self.auth_handler.issue_access_token(registered_user_id)
add_email = False refresh_token = yield self.auth_handler.issue_refresh_token(
else: registered_user_id
# NB: This may be from the auth handler and NOT from the POST
if 'password' not in params:
raise SynapseError(400, "Missing password.",
Codes.MISSING_PARAM)
desired_username = params.get("username", None)
new_password = params.get("password", None)
guest_access_token = params.get("guest_access_token", None)
(registered_user_id, _) = yield self.registration_handler.register(
localpart=desired_username,
password=new_password,
guest_access_token=guest_access_token,
generate_token=False,
) )
defer.returnValue((200, {
"user_id": registered_user_id,
"access_token": access_token,
"home_server": self.hs.hostname,
"refresh_token": refresh_token,
}))
# remember that we've now registered that user account, and with # NB: This may be from the auth handler and NOT from the POST
# what user ID (since the user may not have specified) if 'password' not in params:
self.auth_handler.set_session_data( raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM)
session_id, "registered_user_id", registered_user_id
)
add_email = True desired_username = params.get("username", None)
new_password = params.get("password", None)
guest_access_token = params.get("guest_access_token", None)
return_dict = yield self._create_registration_details( (user_id, token) = yield self.registration_handler.register(
registered_user_id, params localpart=desired_username,
password=new_password,
guest_access_token=guest_access_token,
) )
if add_email and auth_result and LoginType.EMAIL_IDENTITY in auth_result: # remember that we've now registered that user account, and with what
threepid = auth_result[LoginType.EMAIL_IDENTITY] # user ID (since the user may not have specified)
yield self._register_email_threepid( self.auth_handler.set_session_data(
registered_user_id, threepid, return_dict["access_token"], session_id, "registered_user_id", user_id
params.get("bind_email") )
)
defer.returnValue((200, return_dict)) if result and LoginType.EMAIL_IDENTITY in result:
threepid = result[LoginType.EMAIL_IDENTITY]
for reqd in ['medium', 'address', 'validated_at']:
if reqd not in threepid:
logger.info("Can't add incomplete 3pid")
else:
yield self.auth_handler.add_threepid(
user_id,
threepid['medium'],
threepid['address'],
threepid['validated_at'],
)
# And we add an email pusher for them by default, but only
# if email notifications are enabled (so people don't start
# getting mail spam where they weren't before if email
# notifs are set up on a home server)
if (
self.hs.config.email_enable_notifs and
self.hs.config.email_notif_for_new_users
):
# Pull the ID of the access token back out of the db
# It would really make more sense for this to be passed
# up when the access token is saved, but that's quite an
# invasive change I'd rather do separately.
user_tuple = yield self.store.get_user_by_access_token(
token
)
yield self.hs.get_pusherpool().add_pusher(
user_id=user_id,
access_token=user_tuple["token_id"],
kind="email",
app_id="m.email",
app_display_name="Email Notifications",
device_display_name=threepid["address"],
pushkey=threepid["address"],
lang=None, # We don't know a user's language here
data={},
)
if 'bind_email' in params and params['bind_email']:
logger.info("bind_email specified: binding")
emailThreepid = result[LoginType.EMAIL_IDENTITY]
threepid_creds = emailThreepid['threepid_creds']
logger.debug("Binding emails %s to %s" % (
emailThreepid, user_id
))
yield self.identity_handler.bind_threepid(threepid_creds, user_id)
else:
logger.info("bind_email not specified: not binding email")
result = yield self._create_registration_details(user_id, token)
defer.returnValue((200, result))
    def on_OPTIONS(self, _):
        return 200, {}

    @defer.inlineCallbacks
-    def _do_appservice_registration(self, username, as_token, body):
-        user_id = yield self.registration_handler.appservice_register(
+    def _do_appservice_registration(self, username, as_token):
+        (user_id, token) = yield self.registration_handler.appservice_register(
            username, as_token
        )
-        defer.returnValue((yield self._create_registration_details(user_id, body)))
+        defer.returnValue((yield self._create_registration_details(user_id, token)))
    @defer.inlineCallbacks
-    def _do_shared_secret_registration(self, username, password, body):
+    def _do_shared_secret_registration(self, username, password, mac):
        if not self.hs.config.registration_shared_secret:
            raise SynapseError(400, "Shared secret registration is not enabled")
@@ -268,7 +277,7 @@ class RegisterRestServlet(RestServlet):
        # str() because otherwise hmac complains that 'unicode' does not
        # have the buffer interface
-        got_mac = str(body["mac"])
+        got_mac = str(mac)

        want_mac = hmac.new(
            key=self.hs.config.registration_shared_secret,
@@ -281,132 +290,43 @@ class RegisterRestServlet(RestServlet):
403, "HMAC incorrect", 403, "HMAC incorrect",
) )
(user_id, _) = yield self.registration_handler.register( (user_id, token) = yield self.registration_handler.register(
localpart=username, password=password, generate_token=False, localpart=username, password=password
) )
defer.returnValue((yield self._create_registration_details(user_id, token)))
result = yield self._create_registration_details(user_id, body)
defer.returnValue(result)
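
For context, shared-secret registration authenticates the request with an HMAC derived from the homeserver's registration_shared_secret; the servlet compares got_mac against the computed want_mac. A sketch of producing such a MAC client-side, assuming (as in Synapse of this era) HMAC-SHA1 over the desired username:

    import hmac
    from hashlib import sha1

    def registration_mac(shared_secret, username):
        # Assumption: the MAC covers only the username, matching the
        # 0.17-era shared-secret /register flow; later versions cover
        # more fields (password, admin flag, ...).
        return hmac.new(key=shared_secret, msg=username, digestmod=sha1).hexdigest()

    body_mac = registration_mac(b"secret_from_homeserver_config", b"alice")
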
    @defer.inlineCallbacks
-    def _register_email_threepid(self, user_id, threepid, token, bind_email):
-        """Add an email address as a 3pid identifier
-
-        Also adds an email pusher for the email address, if configured in the
-        HS config
-
-        Also optionally binds emails to the given user_id on the identity server
-
-        Args:
-            user_id (str): id of user
-            threepid (object): m.login.email.identity auth response
-            token (str): access_token for the user
-            bind_email (bool): true if the client requested the email to be
-                bound at the identity server
-        Returns:
-            defer.Deferred:
-        """
-        reqd = ('medium', 'address', 'validated_at')
-        if any(x not in threepid for x in reqd):
-            logger.info("Can't add incomplete 3pid")
-            defer.returnValue()
-
-        yield self.auth_handler.add_threepid(
-            user_id,
-            threepid['medium'],
-            threepid['address'],
-            threepid['validated_at'],
-        )
-
-        # And we add an email pusher for them by default, but only
-        # if email notifications are enabled (so people don't start
-        # getting mail spam where they weren't before if email
-        # notifs are set up on a home server)
-        if (self.hs.config.email_enable_notifs and
-                self.hs.config.email_notif_for_new_users):
-            # Pull the ID of the access token back out of the db
-            # It would really make more sense for this to be passed
-            # up when the access token is saved, but that's quite an
-            # invasive change I'd rather do separately.
-            user_tuple = yield self.store.get_user_by_access_token(
-                token
-            )
-            token_id = user_tuple["token_id"]
-
-            yield self.hs.get_pusherpool().add_pusher(
-                user_id=user_id,
-                access_token=token_id,
-                kind="email",
-                app_id="m.email",
-                app_display_name="Email Notifications",
-                device_display_name=threepid["address"],
-                pushkey=threepid["address"],
-                lang=None,  # We don't know a user's language here
-                data={},
-            )
-
-        if bind_email:
-            logger.info("bind_email specified: binding")
-            logger.debug("Binding emails %s to %s" % (
-                threepid, user_id
-            ))
-            yield self.identity_handler.bind_threepid(
-                threepid['threepid_creds'], user_id
-            )
-        else:
-            logger.info("bind_email not specified: not binding email")
-
-    @defer.inlineCallbacks
-    def _create_registration_details(self, user_id, params):
-        """Complete registration of newly-registered user
-
-        Allocates device_id if one was not given; also creates access_token
-        and refresh_token.
-
-        Args:
-            (str) user_id: full canonical @user:id
-            (object) params: registration parameters, from which we pull
-                device_id and initial_device_name
-        Returns:
-            defer.Deferred: (object) dictionary for response from /register
-        """
-        device_id = yield self._register_device(user_id, params)
-
-        access_token, refresh_token = (
-            yield self.auth_handler.get_login_tuple_for_user_id(
-                user_id, device_id=device_id,
-                initial_display_name=params.get("initial_device_display_name")
-            )
-        )
+    def _create_registration_details(self, user_id, token):
+        refresh_token = yield self.auth_handler.issue_refresh_token(user_id)

        defer.returnValue({
            "user_id": user_id,
-            "access_token": access_token,
+            "access_token": token,
            "home_server": self.hs.hostname,
            "refresh_token": refresh_token,
-            "device_id": device_id,
        })

-    def _register_device(self, user_id, params):
-        """Register a device for a user.
-
-        This is called after the user's credentials have been validated, but
-        before the access token has been issued.
-
-        Args:
-            (str) user_id: full canonical @user:id
-            (object) params: registration parameters, from which we pull
-                device_id and initial_device_name
-        Returns:
-            defer.Deferred: (str) device_id
-        """
-        # register the user's device
-        device_id = params.get("device_id")
-        initial_display_name = params.get("initial_device_display_name")
-        device_id = self.device_handler.check_device_registered(
-            user_id, device_id, initial_display_name
-        )
-        return device_id
+    @defer.inlineCallbacks
+    def onEmailTokenRequest(self, request):
+        body = parse_json_object_from_request(request)
+
+        required = ['id_server', 'client_secret', 'email', 'send_attempt']
+        absent = []
+        for k in required:
+            if k not in body:
+                absent.append(k)
+
+        if len(absent) > 0:
+            raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+
+        existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
+            'email', body['email']
+        )
+
+        if existingUid is not None:
+            raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
+
+        ret = yield self.identity_handler.requestEmailToken(**body)
+        defer.returnValue((200, ret))
    @defer.inlineCallbacks
    def _do_guest_registration(self):
@@ -416,11 +336,7 @@ class RegisterRestServlet(RestServlet):
            generate_token=False,
            make_guest=True
        )
-        access_token = self.auth_handler.generate_access_token(
-            user_id, ["guest = true"]
-        )
-        # XXX the "guest" caveat is not copied by /tokenrefresh. That's ok
-        # so long as we don't return a refresh_token here.
+        access_token = self.auth_handler.generate_access_token(user_id, ["guest = true"])
        defer.returnValue((200, {
            "user_id": user_id,
            "access_token": access_token,
@@ -429,5 +345,4 @@ class RegisterRestServlet(RestServlet):
def register_servlets(hs, http_server):
-    RegisterRequestTokenRestServlet(hs).register(http_server)
    RegisterRestServlet(hs).register(http_server)


@@ -146,7 +146,7 @@ class SyncRestServlet(RestServlet):
        affect_presence = set_presence != PresenceState.OFFLINE

        if affect_presence:
-            yield self.presence_handler.set_state(user, {"presence": set_presence}, True)
+            yield self.presence_handler.set_state(user, {"presence": set_presence})

        context = yield self.presence_handler.user_syncing(
            user.to_string(), affect_presence=affect_presence,


@@ -1,78 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.http.servlet import RestServlet
from synapse.types import ThirdPartyEntityKind
from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
class ThirdPartyUserServlet(RestServlet):
PATTERNS = client_v2_patterns("/3pu(/(?P<protocol>[^/]+))?$",
releases=())
def __init__(self, hs):
super(ThirdPartyUserServlet, self).__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
@defer.inlineCallbacks
def on_GET(self, request, protocol):
yield self.auth.get_user_by_req(request)
fields = request.args
del fields["access_token"]
results = yield self.appservice_handler.query_3pe(
ThirdPartyEntityKind.USER, protocol, fields
)
defer.returnValue((200, results))
class ThirdPartyLocationServlet(RestServlet):
PATTERNS = client_v2_patterns("/3pl(/(?P<protocol>[^/]+))?$",
releases=())
def __init__(self, hs):
super(ThirdPartyLocationServlet, self).__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
@defer.inlineCallbacks
def on_GET(self, request, protocol):
yield self.auth.get_user_by_req(request)
fields = request.args
del fields["access_token"]
results = yield self.appservice_handler.query_3pe(
ThirdPartyEntityKind.LOCATION, protocol, fields
)
defer.returnValue((200, results))
def register_servlets(hs, http_server):
ThirdPartyUserServlet(hs).register(http_server)
ThirdPartyLocationServlet(hs).register(http_server)


@@ -39,13 +39,9 @@ class TokenRefreshRestServlet(RestServlet):
        try:
            old_refresh_token = body["refresh_token"]
            auth_handler = self.hs.get_auth_handler()
-            refresh_result = yield self.store.exchange_refresh_token(
-                old_refresh_token, auth_handler.generate_refresh_token
-            )
-            (user_id, new_refresh_token, device_id) = refresh_result
-            new_access_token = yield auth_handler.issue_access_token(
-                user_id, device_id
-            )
+            (user_id, new_refresh_token) = yield self.store.exchange_refresh_token(
+                old_refresh_token, auth_handler.generate_refresh_token)
+            new_access_token = yield auth_handler.issue_access_token(user_id)
            defer.returnValue((200, {
                "access_token": new_access_token,
                "refresh_token": new_refresh_token,


@@ -26,11 +26,7 @@ class VersionsRestServlet(RestServlet):
    def on_GET(self, request):
        return (200, {
-            "versions": [
-                "r0.0.1",
-                "r0.1.0",
-                "r0.2.0",
-            ]
+            "versions": ["r0.0.1"]
        })


@@ -15,7 +15,6 @@
from synapse.http.server import request_handler, respond_with_json_bytes
from synapse.http.servlet import parse_integer, parse_json_object_from_request
from synapse.api.errors import SynapseError, Codes
-from synapse.crypto.keyring import KeyLookupError

from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
@@ -211,10 +210,9 @@ class RemoteKey(Resource):
            yield self.keyring.get_server_verify_key_v2_direct(
                server_name, key_ids
            )
-        except KeyLookupError as e:
-            logger.info("Failed to fetch key: %s", e)
        except:
            logger.exception("Failed to get key for %r", server_name)
+            pass
        yield self.query_keys(
            request, query, query_remote_on_cache_miss=False
        )


@@ -45,7 +45,6 @@ class DownloadResource(Resource):
    @request_handler()
    @defer.inlineCallbacks
    def _async_render_GET(self, request):
-        request.setHeader("Content-Security-Policy", "sandbox")
        server_name, media_id, name = parse_media_id(request)
        if server_name == self.server_name:
            yield self._respond_local_file(request, media_id, name)


@@ -35,7 +35,6 @@ import fnmatch
import cgi
import ujson as json
import urlparse
-import itertools
import logging

logger = logging.getLogger(__name__)
@@ -162,7 +161,7 @@ class PreviewUrlResource(Resource):
logger.debug("got media_info of '%s'" % media_info) logger.debug("got media_info of '%s'" % media_info)
if _is_media(media_info['media_type']): if self._is_media(media_info['media_type']):
dims = yield self.media_repo._generate_local_thumbnails( dims = yield self.media_repo._generate_local_thumbnails(
media_info['filesystem_id'], media_info media_info['filesystem_id'], media_info
) )
@@ -183,9 +182,11 @@ class PreviewUrlResource(Resource):
logger.warn("Couldn't get dims for %s" % url) logger.warn("Couldn't get dims for %s" % url)
# define our OG response for this media # define our OG response for this media
elif _is_html(media_info['media_type']): elif self._is_html(media_info['media_type']):
# TODO: somehow stop a big HTML tree from exploding synapse's RAM # TODO: somehow stop a big HTML tree from exploding synapse's RAM
from lxml import etree
file = open(media_info['filename']) file = open(media_info['filename'])
body = file.read() body = file.read()
file.close() file.close()
@@ -196,35 +197,17 @@ class PreviewUrlResource(Resource):
            match = re.match(r'.*; *charset=(.*?)(;|$)', media_info['media_type'], re.I)
            encoding = match.group(1) if match else "utf-8"

-            og = decode_and_calc_og(body, media_info['uri'], encoding)
-
-            # pre-cache the image for posterity
-            # FIXME: it might be cleaner to use the same flow as the main /preview_url
-            # request itself and benefit from the same caching etc. But for now we
-            # just rely on the caching on the master request to speed things up.
-            if 'og:image' in og and og['og:image']:
-                image_info = yield self._download_url(
-                    _rebase_url(og['og:image'], media_info['uri']), requester.user
-                )
-
-                if _is_media(image_info['media_type']):
-                    # TODO: make sure we don't choke on white-on-transparent images
-                    dims = yield self.media_repo._generate_local_thumbnails(
-                        image_info['filesystem_id'], image_info
-                    )
-                    if dims:
-                        og["og:image:width"] = dims['width']
-                        og["og:image:height"] = dims['height']
-                    else:
-                        logger.warn("Couldn't get dims for %s" % og["og:image"])
-
-                    og["og:image"] = "mxc://%s/%s" % (
-                        self.server_name, image_info['filesystem_id']
-                    )
-                    og["og:image:type"] = image_info['media_type']
-                    og["matrix:image:size"] = image_info['media_length']
-                else:
-                    del og["og:image"]
+            try:
+                parser = etree.HTMLParser(recover=True, encoding=encoding)
+                tree = etree.fromstring(body, parser)
+                og = yield self._calc_og(tree, media_info, requester)
+            except UnicodeDecodeError:
+                # blindly try decoding the body as utf-8, which seems to fix
+                # the charset mismatches on https://google.com
+                parser = etree.HTMLParser(recover=True, encoding=encoding)
+                tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser)
+                og = yield self._calc_og(tree, media_info, requester)
        else:
            logger.warn("Failed to find any OG data in %s", url)
            og = {}
@@ -247,6 +230,135 @@ class PreviewUrlResource(Resource):
        respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
@defer.inlineCallbacks
def _calc_og(self, tree, media_info, requester):
# suck our tree into lxml and define our OG response.
# if we see any image URLs in the OG response, then spider them
# (although the client could choose to do this by asking for previews of those
# URLs to avoid DoSing the server)
# "og:type" : "video",
# "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
# "og:site_name" : "YouTube",
# "og:video:type" : "application/x-shockwave-flash",
# "og:description" : "Fun stuff happening here",
# "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
# "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
# "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
# "og:video:width" : "1280"
# "og:video:height" : "720",
# "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
og = {}
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
if 'content' in tag.attrib:
og[tag.attrib['property']] = tag.attrib['content']
# TODO: grab article: meta tags too, e.g.:
# "article:publisher" : "https://www.facebook.com/thethudonline" />
# "article:author" content="https://www.facebook.com/thethudonline" />
# "article:tag" content="baby" />
# "article:section" content="Breaking News" />
# "article:published_time" content="2016-03-31T19:58:24+00:00" />
# "article:modified_time" content="2016-04-01T18:31:53+00:00" />
if 'og:title' not in og:
# do some basic spidering of the HTML
title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
og['og:title'] = title[0].text.strip() if title else None
if 'og:image' not in og:
# TODO: extract a favicon failing all else
meta_image = tree.xpath(
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
)
if meta_image:
og['og:image'] = self._rebase_url(meta_image[0], media_info['uri'])
else:
# TODO: consider inlined CSS styles as well as width & height attribs
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
images = sorted(images, key=lambda i: (
-1 * float(i.attrib['width']) * float(i.attrib['height'])
))
if not images:
images = tree.xpath("//img[@src]")
if images:
og['og:image'] = images[0].attrib['src']
# pre-cache the image for posterity
# FIXME: it might be cleaner to use the same flow as the main /preview_url
# request itself and benefit from the same caching etc. But for now we
# just rely on the caching on the master request to speed things up.
if 'og:image' in og and og['og:image']:
image_info = yield self._download_url(
self._rebase_url(og['og:image'], media_info['uri']), requester.user
)
if self._is_media(image_info['media_type']):
# TODO: make sure we don't choke on white-on-transparent images
dims = yield self.media_repo._generate_local_thumbnails(
image_info['filesystem_id'], image_info
)
if dims:
og["og:image:width"] = dims['width']
og["og:image:height"] = dims['height']
else:
logger.warn("Couldn't get dims for %s" % og["og:image"])
og["og:image"] = "mxc://%s/%s" % (
self.server_name, image_info['filesystem_id']
)
og["og:image:type"] = image_info['media_type']
og["matrix:image:size"] = image_info['media_length']
else:
del og["og:image"]
if 'og:description' not in og:
meta_description = tree.xpath(
"//*/meta"
"[translate(@name, 'DESCRIPTION', 'description')='description']"
"/@content")
if meta_description:
og['og:description'] = meta_description[0]
else:
# grab any text nodes which are inside the <body/> tag...
# unless they are within an HTML5 semantic markup tag...
# <header/>, <nav/>, <aside/>, <footer/>
# ...or if they are within a <script/> or <style/> tag.
# This is a very very very coarse approximation to a plain text
# render of the page.
text_nodes = tree.xpath("//text()[not(ancestor::header | ancestor::nav | "
"ancestor::aside | ancestor::footer | "
"ancestor::script | ancestor::style)]" +
"[ancestor::body]")
text = ''
for text_node in text_nodes:
if len(text) < 500:
text += text_node + ' '
else:
break
text = re.sub(r'[\t ]+', ' ', text)
text = re.sub(r'[\t \r\n]*[\r\n]+', '\n', text)
text = text.strip()[:500]
og['og:description'] = text if text else None
# TODO: delete the url downloads to stop diskfilling,
# as we only ever cared about its OG
defer.returnValue(og)
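
The _calc_og method added above works on an lxml tree; the core extraction is just reading og: meta properties with an XPath. A minimal standalone illustration (requires lxml; the HTML snippet is made up):

    from lxml import etree

    html = """<html><head>
    <meta property="og:title" content="Example page"/>
    <meta property="og:image" content="/logo.png"/>
    </head><body><h1>Hello</h1></body></html>"""

    parser = etree.HTMLParser(recover=True)
    tree = etree.fromstring(html, parser)

    og = {}
    for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
        if 'content' in tag.attrib:
            og[tag.attrib['property']] = tag.attrib['content']

    print(og)  # {'og:title': 'Example page', 'og:image': '/logo.png'}
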
def _rebase_url(self, url, base):
base = list(urlparse.urlparse(base))
url = list(urlparse.urlparse(url))
if not url[0]: # fix up schema
url[0] = base[0] or "http"
if not url[1]: # fix up hostname
url[1] = base[1]
if not url[2].startswith('/'):
url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2]
return urlparse.urlunparse(url)
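
_rebase_url resolves a possibly relative URL against the page it came from: it fills in a missing scheme, a missing host, and re-anchors relative paths on the base URL's directory. The same behaviour, shown standalone with illustrative inputs (Python 2 urlparse, as in this module):

    import re
    import urlparse

    def rebase_url(url, base):
        base = list(urlparse.urlparse(base))
        url = list(urlparse.urlparse(url))
        if not url[0]:  # fix up schema
            url[0] = base[0] or "http"
        if not url[1]:  # fix up hostname
            url[1] = base[1]
        if not url[2].startswith('/'):
            url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2]
        return urlparse.urlunparse(url)

    # rebase_url("logo.png", "https://example.com/blog/post")
    #   -> "https://example.com/blog/logo.png"
    # rebase_url("//cdn.example.com/x.png", "https://example.com/")
    #   -> "https://cdn.example.com/x.png"
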
    @defer.inlineCallbacks
    def _download_url(self, url, user):
        # TODO: we should probably honour robots.txt... except in practice
@@ -327,221 +439,14 @@ class PreviewUrlResource(Resource):
"etag": headers["ETag"][0] if "ETag" in headers else None, "etag": headers["ETag"][0] if "ETag" in headers else None,
}) })
+    def _is_media(self, content_type):
+        if content_type.lower().startswith("image/"):
+            return True
+
+    def _is_html(self, content_type):
+        content_type = content_type.lower()
+        if (
+            content_type.startswith("text/html") or
+            content_type.startswith("application/xhtml")
+        ):
+            return True

-def decode_and_calc_og(body, media_uri, request_encoding=None):
-    from lxml import etree
-
-    try:
-        parser = etree.HTMLParser(recover=True, encoding=request_encoding)
-        tree = etree.fromstring(body, parser)
-        og = _calc_og(tree, media_uri)
-    except UnicodeDecodeError:
-        # blindly try decoding the body as utf-8, which seems to fix
-        # the charset mismatches on https://google.com
-        parser = etree.HTMLParser(recover=True, encoding=request_encoding)
-        tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser)
-        og = _calc_og(tree, media_uri)
-
-    return og
def _calc_og(tree, media_uri):
    # suck our tree into lxml and define our OG response.

    # if we see any image URLs in the OG response, then spider them
    # (although the client could choose to do this by asking for previews of those
    # URLs to avoid DoSing the server)

    # "og:type" : "video",
    # "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
    # "og:site_name" : "YouTube",
    # "og:video:type" : "application/x-shockwave-flash",
    # "og:description" : "Fun stuff happening here",
    # "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
    # "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
    # "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
    # "og:video:width" : "1280",
    # "og:video:height" : "720",
    # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",

    og = {}
    for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
        if 'content' in tag.attrib:
            og[tag.attrib['property']] = tag.attrib['content']

    # TODO: grab article: meta tags too, e.g.:
    # "article:publisher" : "https://www.facebook.com/thethudonline" />
    # "article:author" content="https://www.facebook.com/thethudonline" />
    # "article:tag" content="baby" />
    # "article:section" content="Breaking News" />
    # "article:published_time" content="2016-03-31T19:58:24+00:00" />
    # "article:modified_time" content="2016-04-01T18:31:53+00:00" />

    if 'og:title' not in og:
        # do some basic spidering of the HTML
        title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
        og['og:title'] = title[0].text.strip() if title else None

    if 'og:image' not in og:
        # TODO: extract a favicon failing all else
        meta_image = tree.xpath(
            "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
        )
        if meta_image:
            og['og:image'] = _rebase_url(meta_image[0], media_uri)
        else:
            # TODO: consider inlined CSS styles as well as width & height attribs
            images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
            images = sorted(images, key=lambda i: (
                -1 * float(i.attrib['width']) * float(i.attrib['height'])
            ))
            if not images:
                images = tree.xpath("//img[@src]")
            if images:
                og['og:image'] = images[0].attrib['src']

    if 'og:description' not in og:
        meta_description = tree.xpath(
            "//*/meta"
            "[translate(@name, 'DESCRIPTION', 'description')='description']"
            "/@content")
        if meta_description:
            og['og:description'] = meta_description[0]
        else:
            # grab any text nodes which are inside the <body/> tag...
            # unless they are within an HTML5 semantic markup tag...
            # <header/>, <nav/>, <aside/>, <footer/>
            # ...or if they are within a <script/> or <style/> tag.
            # This is a very very very coarse approximation to a plain text
            # render of the page.

            # We don't just use XPATH here as that is slow on some machines.

            from lxml import etree

            TAGS_TO_REMOVE = (
                "header", "nav", "aside", "footer", "script", "style", etree.Comment
            )

            # Split all the text nodes into paragraphs (by splitting on new
            # lines)
            text_nodes = (
                re.sub(r'\s+', '\n', el).strip()
                for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
            )
            og['og:description'] = summarize_paragraphs(text_nodes)

    # TODO: delete the url downloads to stop diskfilling,
    # as we only ever cared about its OG
    return og
def _iterate_over_text(tree, *tags_to_ignore):
    """Iterate over the tree returning text nodes in a depth first fashion,
    skipping text nodes inside certain tags.
    """
    # This is basically a stack that we extend using itertools.chain.
    # This will either consist of an element to iterate over *or* a string
    # to be returned.
    elements = iter([tree])
    while True:
        el = elements.next()
        if isinstance(el, basestring):
            yield el
        elif el is not None and el.tag not in tags_to_ignore:
            # el.text is the text before the first child, so we can immediately
            # return it if the text exists.
            if el.text:
                yield el.text

            # We add to the stack all the element's children, interspersed with
            # each child's tail text (if it exists). The tail text of a node
            # is text that comes *after* the node, so we always include it even
            # if we ignore the child node.
            elements = itertools.chain(
                itertools.chain.from_iterable(  # Basically a flatmap
                    [child, child.tail] if child.tail else [child]
                    for child in el.iterchildren()
                ),
                elements
            )
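
A short sketch of the traversal on a hypothetical snippet: an ignored subtree is skipped wholesale, but its tail text (which sits outside it) is kept.

from lxml import etree

html = "<body>one<script>skip()</script>mid<p>two</p></body>"
tree = etree.fromstring(html, etree.HTMLParser(recover=True))
print list(_iterate_over_text(tree.find("body"), "script"))
# -> ['one', 'mid', 'two']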
def _rebase_url(url, base):
    base = list(urlparse.urlparse(base))
    url = list(urlparse.urlparse(url))
    if not url[0]:  # fix up scheme
        url[0] = base[0] or "http"
    if not url[1]:  # fix up hostname
        url[1] = base[1]
    if not url[2].startswith('/'):
        url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2]
    return urlparse.urlunparse(url)
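
A few illustrative calls with hypothetical URLs, assuming the module-level urlparse and re imports this code relies on:

print _rebase_url("/favicon.ico", "http://example.com/blog/post.html")
# -> http://example.com/favicon.ico
print _rebase_url("img/cat.png", "http://example.com/blog/post.html")
# -> http://example.com/blog/img/cat.png
print _rebase_url("//cdn.example.com/x.png", "https://example.com/")
# -> https://cdn.example.com/x.png  (scheme inherited from the base)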
def _is_media(content_type):
    if content_type.lower().startswith("image/"):
        return True


def _is_html(content_type):
    content_type = content_type.lower()
    if (
        content_type.startswith("text/html") or
        content_type.startswith("application/xhtml")
    ):
        return True
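
Note that both predicates return True or fall through to None, so callers should treat the result as truthy rather than strictly boolean; a minimal sketch:

print bool(_is_html("text/html; charset=UTF-8"))  # -> True
print bool(_is_media("application/json"))         # -> False (falls through to None)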
def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
    # Try to get a summary of between 200 and 500 characters, respecting
    # first paragraph and then word boundaries.
    # TODO: Respect sentences?

    description = ''

    # Keep adding paragraphs until we get to the MIN_SIZE.
    for text_node in text_nodes:
        if len(description) < min_size:
            text_node = re.sub(r'[\t \r\n]+', ' ', text_node)
            description += text_node + '\n\n'
        else:
            break

    description = description.strip()
    description = re.sub(r'[\t ]+', ' ', description)
    description = re.sub(r'[\t \r\n]*[\r\n]+', '\n\n', description)

    # If the concatenation of paragraphs to get above MIN_SIZE
    # took us over MAX_SIZE, then we need to truncate mid paragraph
    if len(description) > max_size:
        new_desc = ""

        # This splits the paragraph into words, but keeping the
        # (preceding) whitespace intact so we can easily concat
        # words back together.
        for match in re.finditer(r"\s*\S+", description):
            word = match.group()

            # Keep adding words while the total length is less than
            # MAX_SIZE.
            if len(word) + len(new_desc) < max_size:
                new_desc += word
            else:
                # At this point the next word *will* take us over
                # MAX_SIZE, but we also want to ensure that it's not
                # a huge word. If it is, add it anyway and we'll
                # truncate later.
                if len(new_desc) < min_size:
                    new_desc += word
                break

        # Double check that we're not over the limit
        if len(new_desc) > max_size:
            new_desc = new_desc[:max_size]

        # We always add an ellipsis because at the very least
        # we chopped mid paragraph.
        description = new_desc.strip() + u"…"

    return description if description else None
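
A sketch of the truncation path with hypothetical input; note that min_size and max_size count characters:

paragraphs = iter(["short opening paragraph", "filler text " * 50])
summary = summarize_paragraphs(paragraphs, min_size=200, max_size=500)
print summary.endswith(u"…")  # -> True, since we chopped mid paragraph
print len(summary) <= 501     # bounded by max_size plus the one-char ellipsis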

View File

@@ -19,39 +19,38 @@
 # partial one for unit test mocking.
 
 # Imports required for the default HomeServer() implementation
-import logging
-
-from twisted.enterprise import adbapi
 from twisted.web.client import BrowserLikePolicyForHTTPS
+from twisted.enterprise import adbapi
 
-from synapse.api.auth import Auth
-from synapse.api.filtering import Filtering
-from synapse.api.ratelimiting import Ratelimiter
-from synapse.appservice.api import ApplicationServiceApi
 from synapse.appservice.scheduler import ApplicationServiceScheduler
-from synapse.crypto.keyring import Keyring
-from synapse.events.builder import EventBuilderFactory
+from synapse.appservice.api import ApplicationServiceApi
 from synapse.federation import initialize_http_replication
+from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
+from synapse.notifier import Notifier
+from synapse.api.auth import Auth
 from synapse.handlers import Handlers
-from synapse.handlers.appservice import ApplicationServicesHandler
-from synapse.handlers.auth import AuthHandler
-from synapse.handlers.device import DeviceHandler
-from synapse.handlers.e2e_keys import E2eKeysHandler
 from synapse.handlers.presence import PresenceHandler
-from synapse.handlers.room import RoomListHandler
 from synapse.handlers.sync import SyncHandler
 from synapse.handlers.typing import TypingHandler
-from synapse.handlers.events import EventHandler, EventStreamHandler
-from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
-from synapse.http.matrixfederationclient import MatrixFederationHttpClient
-from synapse.notifier import Notifier
-from synapse.push.pusherpool import PusherPool
-from synapse.rest.media.v1.media_repository import MediaRepository
+from synapse.handlers.room import RoomListHandler
+from synapse.handlers.auth import AuthHandler
+from synapse.handlers.appservice import ApplicationServicesHandler
 from synapse.state import StateHandler
 from synapse.storage import DataStore
-from synapse.streams.events import EventSources
 from synapse.util import Clock
 from synapse.util.distributor import Distributor
+from synapse.streams.events import EventSources
+from synapse.api.ratelimiting import Ratelimiter
+from synapse.crypto.keyring import Keyring
+from synapse.push.pusherpool import PusherPool
+from synapse.events.builder import EventBuilderFactory
+from synapse.api.filtering import Filtering
+from synapse.rest.media.v1.media_repository import MediaRepository
+from synapse.http.matrixfederationclient import MatrixFederationHttpClient
+
+import logging
 
 logger = logging.getLogger(__name__)
@@ -93,10 +92,6 @@ class HomeServer(object):
         'typing_handler',
         'room_list_handler',
         'auth_handler',
-        'device_handler',
-        'e2e_keys_handler',
-        'event_handler',
-        'event_stream_handler',
         'application_service_api',
         'application_service_scheduler',
         'application_service_handler',
@@ -202,12 +197,6 @@ class HomeServer(object):
     def build_auth_handler(self):
         return AuthHandler(self)
 
-    def build_device_handler(self):
-        return DeviceHandler(self)
-
-    def build_e2e_keys_handler(self):
-        return E2eKeysHandler(self)
-
     def build_application_service_api(self):
         return ApplicationServiceApi(self)
@@ -217,12 +206,6 @@ class HomeServer(object):
     def build_application_service_handler(self):
         return ApplicationServicesHandler(self)
 
-    def build_event_handler(self):
-        return EventHandler(self)
-
-    def build_event_stream_handler(self):
-        return EventStreamHandler(self)
-
     def build_event_sources(self):
         return EventSources(self)

Some files were not shown because too many files have changed in this diff.