Compare commits

..

2 Commits

Author SHA1 Message Date
Erik Johnston
9854f4c7ff Add basic file API 2016-07-13 16:29:35 +01:00
Erik Johnston
518b3a3f89 Track in DB file message events 2016-07-13 15:16:02 +01:00
197 changed files with 3413 additions and 9599 deletions

View File

@@ -1,172 +1,3 @@
Changes in synapse v0.17.1 (2016-08-24)
=======================================
Changes:
* Delete old received_transactions rows (PR #1038)
* Pass through user-supplied content in /join/$room_id (PR #1039)
Bug fixes:
* Fix bug with backfill (PR #1040)
Changes in synapse v0.17.1-rc1 (2016-08-22)
===========================================
Features:
* Add notification API (PR #1028)
Changes:
* Don't print stack traces when failing to get remote keys (PR #996)
* Various federation /event/ perf improvements (PR #998)
* Only process one local membership event per room at a time (PR #1005)
* Move default display name push rule (PR #1011, #1023)
* Fix up preview URL API. Add tests. (PR #1015)
* Set ``Content-Security-Policy`` on media repo (PR #1021)
* Make notify_interested_services faster (PR #1022)
* Add usage stats to prometheus monitoring (PR #1037)
Bug fixes:
* Fix token login (PR #993)
* Fix CAS login (PR #994, #995)
* Fix /sync to not clobber status_msg (PR #997)
* Fix redacted state events to include prev_content (PR #1003)
* Fix some bugs in the auth/ldap handler (PR #1007)
* Fix backfill request to limit URI length, so that remotes don't reject the
requests due to path length limits (PR #1012)
* Fix AS push code to not send duplicate events (PR #1025)
Changes in synapse v0.17.0 (2016-08-08)
=======================================
This release contains significant security bug fixes regarding authenticating
events received over federation. PLEASE UPGRADE.
This release changes the LDAP configuration format in a backwards incompatible
way, see PR #843 for details.
Changes:
* Add federation /version API (PR #990)
* Make psutil dependency optional (PR #992)
Bug fixes:
* Fix URL preview API to exclude HTML comments in description (PR #988)
* Fix error handling of remote joins (PR #991)
Changes in synapse v0.17.0-rc4 (2016-08-05)
===========================================
Changes:
* Change the way we summarize URLs when previewing (PR #973)
* Add new ``/state_ids/`` federation API (PR #979)
* Speed up processing of ``/state/`` response (PR #986)
Bug fixes:
* Fix event persistence when event has already been partially persisted
(PR #975, #983, #985)
* Fix port script to also copy across backfilled events (PR #982)
Changes in synapse v0.17.0-rc3 (2016-08-02)
===========================================
Changes:
* Forbid non-ASes from registering users whose names begin with '_' (PR #958)
* Add some basic admin API docs (PR #963)
Bug fixes:
* Send the correct host header when fetching keys (PR #941)
* Fix joining a room that has missing auth events (PR #964)
* Fix various push bugs (PR #966, #970)
* Fix adding emails on registration (PR #968)
Changes in synapse v0.17.0-rc2 (2016-08-02)
===========================================
(This release did not include the changes advertised and was identical to RC1)
Changes in synapse v0.17.0-rc1 (2016-07-28)
===========================================
This release changes the LDAP configuration format in a backwards incompatible
way, see PR #843 for details.
Features:
* Add purge_media_cache admin API (PR #902)
* Add deactivate account admin API (PR #903)
* Add optional pepper to password hashing (PR #907, #910 by KentShikama)
* Add an admin option to shared secret registration (breaks backwards compat)
(PR #909)
* Add purge local room history API (PR #911, #923, #924)
* Add requestToken endpoints (PR #915)
* Add an /account/deactivate endpoint (PR #921)
* Add filter param to /messages. Add 'contains_url' to filter. (PR #922)
* Add device_id support to /login (PR #929)
* Add device_id support to /v2/register flow. (PR #937, #942)
* Add GET /devices endpoint (PR #939, #944)
* Add GET /device/{deviceId} (PR #943)
* Add update and delete APIs for devices (PR #949)
Changes:
* Rewrite LDAP Authentication against ldap3 (PR #843 by mweinelt)
* Linearize some federation endpoints based on (origin, room_id) (PR #879)
* Remove the legacy v0 content upload API. (PR #888)
* Use similar naming we use in email notifs for push (PR #894)
* Optionally include password hash in createUser endpoint (PR #905 by
KentShikama)
* Use a query that postgresql optimises better for get_events_around (PR #906)
* Fall back to 'username' if 'user' is not given for appservice registration.
(PR #927 by Half-Shot)
* Add metrics for psutil derived memory usage (PR #936)
* Record device_id in client_ips (PR #938)
* Send the correct host header when fetching keys (PR #941)
* Log the hostname the reCAPTCHA was completed on (PR #946)
* Make the device id on e2e key upload optional (PR #956)
* Add r0.2.0 to the "supported versions" list (PR #960)
* Don't include name of room for invites in push (PR #961)
Bug fixes:
* Fix substitution failure in mail template (PR #887)
* Put most recent 20 messages in email notif (PR #892)
* Ensure that the guest user is in the database when upgrading accounts
(PR #914)
* Fix various edge cases in auth handling (PR #919)
* Fix 500 ISE when sending alias event without a state_key (PR #925)
* Fix bug where we stored rejections in the state_group, persist all
rejections (PR #948)
* Fix lack of check of if the user is banned when handling 3pid invites
(PR #952)
* Fix a couple of bugs in the transaction and keyring code (PR #954, #955)
Changes in synapse v0.16.1-r1 (2016-07-08)
==========================================

View File

@@ -14,7 +14,6 @@ recursive-include docs *
recursive-include res *
recursive-include scripts *
recursive-include scripts-dev *
recursive-include synapse *.pyi
recursive-include tests *.py
recursive-include synapse/static *.css
@@ -24,7 +23,5 @@ recursive-include synapse/static *.js
exclude jenkins.sh
exclude jenkins*.sh
exclude jenkins*
recursive-exclude jenkins *.sh
prune demo/etc

View File

@@ -11,8 +11,8 @@ VoIP. The basics you need to know to get up and running are:
like ``#matrix:matrix.org`` or ``#test:localhost:8448``.
- Matrix user IDs look like ``@matthew:matrix.org`` (although in the future
you will normally refer to yourself and others using a third party identifier
(3PID): email address, phone number, etc rather than manipulating Matrix user IDs)
you will normally refer to yourself and others using a 3PID: email
address, phone number, etc rather than manipulating Matrix user IDs)
The overall architecture is::
@@ -95,7 +95,7 @@ Synapse is the reference python/twisted Matrix homeserver implementation.
System requirements:
- POSIX-compliant system (tested on Linux & OS X)
- Python 2.7
- At least 1GB of free RAM if you want to join large public rooms like #matrix:matrix.org
- At least 512 MB RAM.
Synapse is written in python but some of the libraries is uses are written in
C. So before we can install synapse itself we need a working C compiler and the
@@ -134,12 +134,6 @@ Installing prerequisites on Raspbian::
sudo pip install --upgrade ndg-httpsclient
sudo pip install --upgrade virtualenv
Installing prerequisites on openSUSE::
sudo zypper in -t pattern devel_basis
sudo zypper in python-pip python-setuptools sqlite3 python-virtualenv \
python-devel libffi-devel libopenssl-devel libjpeg62-devel
To install the synapse homeserver run::
virtualenv -p python2.7 ~/.synapse
@@ -205,21 +199,6 @@ run (e.g. ``~/.synapse``), and::
source ./bin/activate
synctl start
Security Note
=============
Matrix serves raw user generated data in some APIs - specifically the content
repository endpoints: http://matrix.org/docs/spec/client_server/r0.2.0.html#get-matrix-media-r0-download-servername-mediaid
Whilst we have tried to mitigate against possible XSS attacks (e.g.
https://github.com/matrix-org/synapse/pull/1021) we recommend running
matrix homeservers on a dedicated domain name, to limit any malicious user generated
content served to web browsers a matrix API from being able to attack webapps hosted
on the same domain. This is particularly true of sharing a matrix webclient and
server on the same domain.
See https://github.com/vector-im/vector-web/issues/1977 and
https://developer.github.com/changes/2014-04-25-user-content-security for more details.
Using PostgreSQL
================
@@ -236,6 +215,9 @@ The advantages of Postgres include:
pointing at the same DB master, as well as enabling DB replication in
synapse itself.
The only disadvantage is that the code is relatively new as of April 2015 and
may have a few regressions relative to SQLite.
For information on how to install and use PostgreSQL, please see
`docs/postgres.rst <docs/postgres.rst>`_.
@@ -463,7 +445,7 @@ You have two choices here, which will influence the form of your Matrix user
IDs:
1) Use the machine's own hostname as available on public DNS in the form of
its A records. This is easier to set up initially, perhaps for
its A or AAAA records. This is easier to set up initially, perhaps for
testing, but lacks the flexibility of SRV.
2) Set up a SRV record for your domain name. This requires you create a SRV

View File

@@ -27,7 +27,7 @@ running:
# Pull the latest version of the master branch.
git pull
# Update the versions of synapse's python dependencies.
python synapse/python_dependencies.py | xargs -n1 pip install --upgrade
python synapse/python_dependencies.py | xargs -n1 pip install
Upgrading to v0.15.0

View File

@@ -1,12 +0,0 @@
Admin APIs
==========
This directory includes documentation for the various synapse specific admin
APIs available.
Only users that are server admins can use these APIs. A user can be marked as a
server admin by updating the database directly, e.g.:
``UPDATE users SET admin = 1 WHERE name = '@foo:bar.com'``
Restarting may be required for the changes to register.

View File

@@ -1,15 +0,0 @@
Purge History API
=================
The purge history API allows server admins to purge historic events from their
database, reclaiming disk space.
Depending on the amount of history being purged a call to the API may take
several minutes or longer. During this period users will not be able to
paginate further back in the room from the point being purged from.
The API is simply:
``POST /_matrix/client/r0/admin/purge_history/<room_id>/<event_id>``
including an ``access_token`` of a server admin.

View File

@@ -1,19 +0,0 @@
Purge Remote Media API
======================
The purge remote media API allows server admins to purge old cached remote
media.
The API is::
POST /_matrix/client/r0/admin/purge_media_cache
{
"before_ts": <unix_timestamp_in_ms>
}
Which will remove all cached media that was last accessed before
``<unix_timestamp_in_ms>``.
If the user re-requests purged remote media, synapse will re-request the media
from the originating server.

View File

@@ -1,97 +0,0 @@
Scaling synapse via workers
---------------------------
Synapse has experimental support for splitting out functionality into
multiple separate python processes, helping greatly with scalability. These
processes are called 'workers', and are (eventually) intended to scale
horizontally independently.
All processes continue to share the same database instance, and as such, workers
only work with postgres based synapse deployments (sharing a single sqlite
across multiple processes is a recipe for disaster, plus you should be using
postgres anyway if you care about scalability).
The workers communicate with the master synapse process via a synapse-specific
HTTP protocol called 'replication' - analogous to MySQL or Postgres style
database replication; feeding a stream of relevant data to the workers so they
can be kept in sync with the main synapse process and database state.
To enable workers, you need to add a replication listener to the master synapse, e.g.::
listeners:
- port: 9092
bind_address: '127.0.0.1'
type: http
tls: false
x_forwarded: false
resources:
- names: [replication]
compress: false
Under **no circumstances** should this replication API listener be exposed to the
public internet; it currently implements no authentication whatsoever and is
unencrypted HTTP.
You then create a set of configs for the various worker processes. These should be
worker configuration files should be stored in a dedicated subdirectory, to allow
synctl to manipulate them.
The current available worker applications are:
* synapse.app.pusher - handles sending push notifications to sygnal and email
* synapse.app.synchrotron - handles /sync endpoints. can scales horizontally through multiple instances.
* synapse.app.appservice - handles output traffic to Application Services
* synapse.app.federation_reader - handles receiving federation traffic (including public_rooms API)
* synapse.app.media_repository - handles the media repository.
Each worker configuration file inherits the configuration of the main homeserver
configuration file. You can then override configuration specific to that worker,
e.g. the HTTP listener that it provides (if any); logging configuration; etc.
You should minimise the number of overrides though to maintain a usable config.
You must specify the type of worker application (worker_app) and the replication
endpoint that it's talking to on the main synapse process (worker_replication_url).
For instance::
worker_app: synapse.app.synchrotron
# The replication listener on the synapse to talk to.
worker_replication_url: http://127.0.0.1:9092/_synapse/replication
worker_listeners:
- type: http
port: 8083
resources:
- names:
- client
worker_daemonize: True
worker_pid_file: /home/matrix/synapse/synchrotron.pid
worker_log_config: /home/matrix/synapse/config/synchrotron_log_config.yaml
...is a full configuration for a synchrotron worker instance, which will expose a
plain HTTP /sync endpoint on port 8083 separately from the /sync endpoint provided
by the main synapse.
Obviously you should configure your loadbalancer to route the /sync endpoint to
the synchrotron instance(s) in this instance.
Finally, to actually run your worker-based synapse, you must pass synctl the -a
commandline option to tell it to operate on all the worker configurations found
in the given directory, e.g.::
synctl -a $CONFIG/workers start
Currently one should always restart all workers when restarting or upgrading
synapse, unless you explicitly know it's safe not to. For instance, restarting
synapse without restarting all the synchrotrons may result in broken typing
notifications.
To manipulate a specific worker, you pass the -w option to synctl::
synctl -w $CONFIG/workers/synchrotron.yaml restart
All of the above is highly experimental and subject to change as Synapse evolves,
but documenting it here to help folks needing highly scalable Synapses similar
to the one running matrix.org!

View File

@@ -4,19 +4,84 @@ set -eux
: ${WORKSPACE:="$(pwd)"}
export WORKSPACE
export PYTHONDONTWRITEBYTECODE=yep
export SYNAPSE_CACHE_FACTOR=1
./jenkins/prepare_synapse.sh
./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git
./jenkins/clone.sh dendron https://github.com/matrix-org/dendron.git
./dendron/jenkins/build_dendron.sh
./sytest/jenkins/prep_sytest_for_postgres.sh
# Output test results as junit xml
export TRIAL_FLAGS="--reporter=subunit"
export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
# Write coverage reports to a separate file for each process
export COVERAGE_OPTS="-p"
export DUMP_COVERAGE_COMMAND="coverage help"
./sytest/jenkins/install_and_run.sh \
# Output flake8 violations to violations.flake8.log
# Don't exit with non-0 status code on Jenkins,
# so that the build steps continue and a later step can decided whether to
# UNSTABLE or FAILURE this build.
export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
rm .coverage* || echo "No coverage files to remove"
tox --notest -e py27
TOX_BIN=$WORKSPACE/.tox/py27/bin
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
$TOX_BIN/pip install psycopg2
$TOX_BIN/pip install lxml
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
if [[ ! -e .dendron-base ]]; then
git clone https://github.com/matrix-org/dendron.git .dendron-base --mirror
else
(cd .dendron-base; git fetch -p)
fi
rm -rf dendron
git clone .dendron-base dendron --shared
cd dendron
: ${GOPATH:=${WORKSPACE}/.gopath}
if [[ "${GOPATH}" != *:* ]]; then
mkdir -p "${GOPATH}"
export PATH="${GOPATH}/bin:${PATH}"
fi
export GOPATH
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
go get github.com/constabulary/gb/...
gb generate
gb build
cd ..
if [[ ! -e .sytest-base ]]; then
git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
else
(cd .sytest-base; git fetch -p)
fi
rm -rf sytest
git clone .sytest-base sytest --shared
cd sytest
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
: ${PORT_BASE:=8000}
: ${PORT_COUNT=20}
./jenkins/prep_sytest_for_postgres.sh
mkdir -p var
echo >&2 "Running sytest with PostgreSQL";
./jenkins/install_and_run.sh --python $TOX_BIN/python \
--synapse-directory $WORKSPACE \
--dendron $WORKSPACE/dendron/bin/dendron \
--pusher \
--synchrotron \
--federation-reader \
--port-range ${PORT_BASE}:$((PORT_BASE+PORT_COUNT-1))
cd ..

View File

@@ -4,14 +4,61 @@ set -eux
: ${WORKSPACE:="$(pwd)"}
export WORKSPACE
export PYTHONDONTWRITEBYTECODE=yep
export SYNAPSE_CACHE_FACTOR=1
./jenkins/prepare_synapse.sh
./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git
# Output test results as junit xml
export TRIAL_FLAGS="--reporter=subunit"
export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
# Write coverage reports to a separate file for each process
export COVERAGE_OPTS="-p"
export DUMP_COVERAGE_COMMAND="coverage help"
./sytest/jenkins/prep_sytest_for_postgres.sh
# Output flake8 violations to violations.flake8.log
# Don't exit with non-0 status code on Jenkins,
# so that the build steps continue and a later step can decided whether to
# UNSTABLE or FAILURE this build.
export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
./sytest/jenkins/install_and_run.sh \
rm .coverage* || echo "No coverage files to remove"
tox --notest -e py27
TOX_BIN=$WORKSPACE/.tox/py27/bin
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
$TOX_BIN/pip install psycopg2
$TOX_BIN/pip install lxml
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
if [[ ! -e .sytest-base ]]; then
git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
else
(cd .sytest-base; git fetch -p)
fi
rm -rf sytest
git clone .sytest-base sytest --shared
cd sytest
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
: ${PORT_BASE:=8000}
: ${PORT_COUNT=20}
./jenkins/prep_sytest_for_postgres.sh
echo >&2 "Running sytest with PostgreSQL";
./jenkins/install_and_run.sh --coverage \
--python $TOX_BIN/python \
--synapse-directory $WORKSPACE \
--port-range ${PORT_BASE}:$((PORT_BASE+PORT_COUNT-1)) \
cd ..
cp sytest/.coverage.* .
# Combine the coverage reports
echo "Combining:" .coverage.*
$TOX_BIN/python -m coverage combine
# Output coverage to coverage.xml
$TOX_BIN/coverage xml -o coverage.xml

View File

@@ -4,12 +4,55 @@ set -eux
: ${WORKSPACE:="$(pwd)"}
export WORKSPACE
export PYTHONDONTWRITEBYTECODE=yep
export SYNAPSE_CACHE_FACTOR=1
./jenkins/prepare_synapse.sh
./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git
# Output test results as junit xml
export TRIAL_FLAGS="--reporter=subunit"
export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
# Write coverage reports to a separate file for each process
export COVERAGE_OPTS="-p"
export DUMP_COVERAGE_COMMAND="coverage help"
./sytest/jenkins/install_and_run.sh \
# Output flake8 violations to violations.flake8.log
# Don't exit with non-0 status code on Jenkins,
# so that the build steps continue and a later step can decided whether to
# UNSTABLE or FAILURE this build.
export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
rm .coverage* || echo "No coverage files to remove"
tox --notest -e py27
TOX_BIN=$WORKSPACE/.tox/py27/bin
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
$TOX_BIN/pip install lxml
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
if [[ ! -e .sytest-base ]]; then
git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
else
(cd .sytest-base; git fetch -p)
fi
rm -rf sytest
git clone .sytest-base sytest --shared
cd sytest
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
: ${PORT_COUNT=20}
: ${PORT_BASE:=8000}
./jenkins/install_and_run.sh --coverage \
--python $TOX_BIN/python \
--synapse-directory $WORKSPACE \
--port-range ${PORT_BASE}:$((PORT_BASE+PORT_COUNT-1)) \
cd ..
cp sytest/.coverage.* .
# Combine the coverage reports
echo "Combining:" .coverage.*
$TOX_BIN/python -m coverage combine
# Output coverage to coverage.xml
$TOX_BIN/coverage xml -o coverage.xml

View File

@@ -22,9 +22,4 @@ export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished w
rm .coverage* || echo "No coverage files to remove"
tox --notest -e py27
TOX_BIN=$WORKSPACE/.tox/py27/bin
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
$TOX_BIN/pip install lxml
tox -e py27

View File

@@ -1,44 +0,0 @@
#! /bin/bash
# This clones a project from github into a named subdirectory
# If the project has a branch with the same name as this branch
# then it will checkout that branch after cloning.
# Otherwise it will checkout "origin/develop."
# The first argument is the name of the directory to checkout
# the branch into.
# The second argument is the URL of the remote repository to checkout.
# Usually something like https://github.com/matrix-org/sytest.git
set -eux
NAME=$1
PROJECT=$2
BASE=".$NAME-base"
# Update our mirror.
if [ ! -d ".$NAME-base" ]; then
# Create a local mirror of the source repository.
# This saves us from having to download the entire repository
# when this script is next run.
git clone "$PROJECT" "$BASE" --mirror
else
# Fetch any updates from the source repository.
(cd "$BASE"; git fetch -p)
fi
# Remove the existing repository so that we have a clean copy
rm -rf "$NAME"
# Cloning with --shared means that we will share portions of the
# .git directory with our local mirror.
git clone "$BASE" "$NAME" --shared
# Jenkins may have supplied us with the name of the branch in the
# environment. Otherwise we will have to guess based on the current
# commit.
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
cd "$NAME"
# check out the relevant branch
git checkout "${GIT_BRANCH}" || (
echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop"
git checkout "origin/develop"
)

View File

@@ -1,20 +0,0 @@
#! /bin/bash
cd "`dirname $0`/.."
TOX_DIR=$WORKSPACE/.tox
mkdir -p $TOX_DIR
if ! [ $TOX_DIR -ef .tox ]; then
ln -s "$TOX_DIR" .tox
fi
# set up the virtualenv
tox -e py27 --notest -v
TOX_BIN=$TOX_DIR/py27/bin
$TOX_BIN/pip install setuptools
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
$TOX_BIN/pip install lxml
$TOX_BIN/pip install psycopg2

View File

@@ -116,19 +116,17 @@ def get_json(origin_name, origin_key, destination, path):
authorization_headers = []
for key, sig in signed_json["signatures"][origin_name].items():
header = "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
authorization_headers.append(bytes(
"X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
origin_name, key, sig,
)
authorization_headers.append(bytes(header))
sys.stderr.write(header)
sys.stderr.write("\n")
))
result = requests.get(
lookup(destination, path),
headers={"Authorization": authorization_headers[0]},
verify=False,
)
sys.stderr.write("Status Code: %d\n" % (result.status_code,))
return result.json()
@@ -143,7 +141,6 @@ def main():
)
json.dump(result, sys.stdout)
print ""
if __name__ == "__main__":
main()

View File

@@ -34,7 +34,7 @@ logger = logging.getLogger("synapse_port_db")
BOOLEAN_COLUMNS = {
"events": ["processed", "outlier", "contains_url"],
"events": ["processed", "outlier"],
"rooms": ["is_public"],
"event_edges": ["is_state"],
"presence_list": ["accepted"],
@@ -92,12 +92,8 @@ class Store(object):
_simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"]
_simple_select_onecol = SQLBaseStore.__dict__["_simple_select_onecol"]
_simple_select_one = SQLBaseStore.__dict__["_simple_select_one"]
_simple_select_one_txn = SQLBaseStore.__dict__["_simple_select_one_txn"]
_simple_select_one_onecol = SQLBaseStore.__dict__["_simple_select_one_onecol"]
_simple_select_one_onecol_txn = SQLBaseStore.__dict__[
"_simple_select_one_onecol_txn"
]
_simple_select_one_onecol_txn = SQLBaseStore.__dict__["_simple_select_one_onecol_txn"]
_simple_update_one = SQLBaseStore.__dict__["_simple_update_one"]
_simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"]
@@ -162,40 +158,31 @@ class Porter(object):
def setup_table(self, table):
if table in APPEND_ONLY_TABLES:
# It's safe to just carry on inserting.
row = yield self.postgres_store._simple_select_one(
next_chunk = yield self.postgres_store._simple_select_one_onecol(
table="port_from_sqlite3",
keyvalues={"table_name": table},
retcols=("forward_rowid", "backward_rowid"),
retcol="rowid",
allow_none=True,
)
total_to_port = None
if row is None:
if next_chunk is None:
if table == "sent_transactions":
forward_chunk, already_ported, total_to_port = (
next_chunk, already_ported, total_to_port = (
yield self._setup_sent_transactions()
)
backward_chunk = 0
else:
yield self.postgres_store._simple_insert(
table="port_from_sqlite3",
values={
"table_name": table,
"forward_rowid": 1,
"backward_rowid": 0,
}
values={"table_name": table, "rowid": 1}
)
forward_chunk = 1
backward_chunk = 0
next_chunk = 1
already_ported = 0
else:
forward_chunk = row["forward_rowid"]
backward_chunk = row["backward_rowid"]
if total_to_port is None:
already_ported, total_to_port = yield self._get_total_count_to_port(
table, forward_chunk, backward_chunk
table, next_chunk
)
else:
def delete_all(txn):
@@ -209,85 +196,46 @@ class Porter(object):
yield self.postgres_store._simple_insert(
table="port_from_sqlite3",
values={
"table_name": table,
"forward_rowid": 1,
"backward_rowid": 0,
}
values={"table_name": table, "rowid": 0}
)
forward_chunk = 1
backward_chunk = 0
next_chunk = 1
already_ported, total_to_port = yield self._get_total_count_to_port(
table, forward_chunk, backward_chunk
table, next_chunk
)
defer.returnValue(
(table, already_ported, total_to_port, forward_chunk, backward_chunk)
)
defer.returnValue((table, already_ported, total_to_port, next_chunk))
@defer.inlineCallbacks
def handle_table(self, table, postgres_size, table_size, forward_chunk,
backward_chunk):
def handle_table(self, table, postgres_size, table_size, next_chunk):
if not table_size:
return
self.progress.add_table(table, postgres_size, table_size)
if table == "event_search":
yield self.handle_search_table(
postgres_size, table_size, forward_chunk, backward_chunk
)
yield self.handle_search_table(postgres_size, table_size, next_chunk)
return
forward_select = (
select = (
"SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?"
% (table,)
)
backward_select = (
"SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid LIMIT ?"
% (table,)
)
do_forward = [True]
do_backward = [True]
while True:
def r(txn):
forward_rows = []
backward_rows = []
if do_forward[0]:
txn.execute(forward_select, (forward_chunk, self.batch_size,))
forward_rows = txn.fetchall()
if not forward_rows:
do_forward[0] = False
if do_backward[0]:
txn.execute(backward_select, (backward_chunk, self.batch_size,))
backward_rows = txn.fetchall()
if not backward_rows:
do_backward[0] = False
if forward_rows or backward_rows:
txn.execute(select, (next_chunk, self.batch_size,))
rows = txn.fetchall()
headers = [column[0] for column in txn.description]
else:
headers = None
return headers, forward_rows, backward_rows
return headers, rows
headers, frows, brows = yield self.sqlite_store.runInteraction(
"select", r
)
headers, rows = yield self.sqlite_store.runInteraction("select", r)
if frows or brows:
if frows:
forward_chunk = max(row[0] for row in frows) + 1
if brows:
backward_chunk = min(row[0] for row in brows) - 1
if rows:
next_chunk = rows[-1][0] + 1
rows = frows + brows
self._convert_rows(table, headers, rows)
def insert(txn):
@@ -299,10 +247,7 @@ class Porter(object):
txn,
table="port_from_sqlite3",
keyvalues={"table_name": table},
updatevalues={
"forward_rowid": forward_chunk,
"backward_rowid": backward_chunk,
},
updatevalues={"rowid": next_chunk},
)
yield self.postgres_store.execute(insert)
@@ -314,8 +259,7 @@ class Porter(object):
return
@defer.inlineCallbacks
def handle_search_table(self, postgres_size, table_size, forward_chunk,
backward_chunk):
def handle_search_table(self, postgres_size, table_size, next_chunk):
select = (
"SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering"
" FROM event_search as es"
@@ -326,7 +270,7 @@ class Porter(object):
while True:
def r(txn):
txn.execute(select, (forward_chunk, self.batch_size,))
txn.execute(select, (next_chunk, self.batch_size,))
rows = txn.fetchall()
headers = [column[0] for column in txn.description]
@@ -335,7 +279,7 @@ class Porter(object):
headers, rows = yield self.sqlite_store.runInteraction("select", r)
if rows:
forward_chunk = rows[-1][0] + 1
next_chunk = rows[-1][0] + 1
# We have to treat event_search differently since it has a
# different structure in the two different databases.
@@ -368,10 +312,7 @@ class Porter(object):
txn,
table="port_from_sqlite3",
keyvalues={"table_name": "event_search"},
updatevalues={
"forward_rowid": forward_chunk,
"backward_rowid": backward_chunk,
},
updatevalues={"rowid": next_chunk},
)
yield self.postgres_store.execute(insert)
@@ -383,6 +324,7 @@ class Porter(object):
else:
return
def setup_db(self, db_config, database_engine):
db_conn = database_engine.module.connect(
**{
@@ -453,32 +395,10 @@ class Porter(object):
txn.execute(
"CREATE TABLE port_from_sqlite3 ("
" table_name varchar(100) NOT NULL UNIQUE,"
" forward_rowid bigint NOT NULL,"
" backward_rowid bigint NOT NULL"
" rowid bigint NOT NULL"
")"
)
# The old port script created a table with just a "rowid" column.
# We want people to be able to rerun this script from an old port
# so that they can pick up any missing events that were not
# ported across.
def alter_table(txn):
txn.execute(
"ALTER TABLE IF EXISTS port_from_sqlite3"
" RENAME rowid TO forward_rowid"
)
txn.execute(
"ALTER TABLE IF EXISTS port_from_sqlite3"
" ADD backward_rowid bigint NOT NULL DEFAULT 0"
)
try:
yield self.postgres_store.runInteraction(
"alter_table", alter_table
)
except Exception as e:
logger.info("Failed to create port table: %s", e)
try:
yield self.postgres_store.runInteraction(
"create_port_table", create_port_table
@@ -538,7 +458,7 @@ class Porter(object):
@defer.inlineCallbacks
def _setup_sent_transactions(self):
# Only save things from the last day
yesterday = int(time.time() * 1000) - 86400000
yesterday = int(time.time()*1000) - 86400000
# And save the max transaction id from each destination
select = (
@@ -594,11 +514,7 @@ class Porter(object):
yield self.postgres_store._simple_insert(
table="port_from_sqlite3",
values={
"table_name": "sent_transactions",
"forward_rowid": next_chunk,
"backward_rowid": 0,
}
values={"table_name": "sent_transactions", "rowid": next_chunk}
)
def get_sent_table_size(txn):
@@ -619,18 +535,13 @@ class Porter(object):
defer.returnValue((next_chunk, inserted_rows, total_count))
@defer.inlineCallbacks
def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk):
frows = yield self.sqlite_store.execute_sql(
def _get_remaining_count_to_port(self, table, next_chunk):
rows = yield self.sqlite_store.execute_sql(
"SELECT count(*) FROM %s WHERE rowid >= ?" % (table,),
forward_chunk,
next_chunk,
)
brows = yield self.sqlite_store.execute_sql(
"SELECT count(*) FROM %s WHERE rowid <= ?" % (table,),
backward_chunk,
)
defer.returnValue(frows[0][0] + brows[0][0])
defer.returnValue(rows[0][0])
@defer.inlineCallbacks
def _get_already_ported_count(self, table):
@@ -641,10 +552,10 @@ class Porter(object):
defer.returnValue(rows[0][0])
@defer.inlineCallbacks
def _get_total_count_to_port(self, table, forward_chunk, backward_chunk):
def _get_total_count_to_port(self, table, next_chunk):
remaining, done = yield defer.gatherResults(
[
self._get_remaining_count_to_port(table, forward_chunk, backward_chunk),
self._get_remaining_count_to_port(table, next_chunk),
self._get_already_ported_count(table),
],
consumeErrors=True,
@@ -775,7 +686,7 @@ class CursesProgress(Progress):
color = curses.color_pair(2) if perc == 100 else curses.color_pair(1)
self.stdscr.addstr(
i + 2, left_margin + max_len - len(table),
i+2, left_margin + max_len - len(table),
table,
curses.A_BOLD | color,
)
@@ -783,18 +694,18 @@ class CursesProgress(Progress):
size = 20
progress = "[%s%s]" % (
"#" * int(perc * size / 100),
" " * (size - int(perc * size / 100)),
"#" * int(perc*size/100),
" " * (size - int(perc*size/100)),
)
self.stdscr.addstr(
i + 2, left_margin + max_len + middle_space,
i+2, left_margin + max_len + middle_space,
"%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]),
)
if self.finished:
self.stdscr.addstr(
rows - 1, 0,
rows-1, 0,
"Press any key to exit...",
)

View File

@@ -16,5 +16,7 @@ ignore =
[flake8]
max-line-length = 90
# W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
ignore = W503
ignore = W503 ; W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
[pep8]
max-line-length = 90

View File

@@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
__version__ = "0.17.1"
__version__ = "0.16.1-r1"

View File

@@ -13,22 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import pymacaroons
from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json, SignatureVerifyException
from twisted.internet import defer
from unpaddedbase64 import decode_base64
import synapse.types
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
from synapse.types import UserID, get_domain_from_id
from synapse.util.logcontext import preserve_context_over_fn
from synapse.types import Requester, UserID, get_domain_from_id
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.metrics import Measure
from unpaddedbase64 import decode_base64
import logging
import pymacaroons
logger = logging.getLogger(__name__)
@@ -52,7 +52,7 @@ class Auth(object):
self.state = hs.get_state_handler()
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
# Docs for these currently lives at
# github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst
# https://github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst
# In addition, we have type == delete_pusher which grants access only to
# delete pushers.
self._KNOWN_CAVEAT_PREFIXES = set([
@@ -63,18 +63,7 @@ class Auth(object):
"user_id = ",
])
@defer.inlineCallbacks
def check_from_context(self, event, context, do_sig_check=True):
auth_events_ids = yield self.compute_auth_events(
event, context.prev_state_ids, for_verification=True,
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = {
(e.type, e.state_key): e for e in auth_events.values()
}
self.check(event, auth_events=auth_events, do_sig_check=False)
def check(self, event, auth_events, do_sig_check=True):
def check(self, event, auth_events):
""" Checks if this event is correctly authed.
Args:
@@ -90,13 +79,6 @@ class Auth(object):
if not hasattr(event, "room_id"):
raise AuthError(500, "Event has no room_id: %s" % event)
sender_domain = get_domain_from_id(event.sender)
# Check the sender's domain has signed the event
if do_sig_check and not event.signatures.get(sender_domain):
raise AuthError(403, "Event not signed by sending server")
if auth_events is None:
# Oh, we don't know what the state of the room was, so we
# are trusting that this is allowed (at least for now)
@@ -104,12 +86,6 @@ class Auth(object):
return True
if event.type == EventTypes.Create:
room_id_domain = get_domain_from_id(event.room_id)
if room_id_domain != sender_domain:
raise AuthError(
403,
"Creation event's room_id domain does not match sender's"
)
# FIXME
return True
@@ -132,22 +108,6 @@ class Auth(object):
# FIXME: Temp hack
if event.type == EventTypes.Aliases:
if not event.is_state():
raise AuthError(
403,
"Alias event must be a state event",
)
if not event.state_key:
raise AuthError(
403,
"Alias event must have non-empty state_key"
)
sender_domain = get_domain_from_id(event.sender)
if event.state_key != sender_domain:
raise AuthError(
403,
"Alias event's state_key does not match sender's domain"
)
return True
logger.debug(
@@ -278,17 +238,21 @@ class Auth(object):
@defer.inlineCallbacks
def check_host_in_room(self, room_id, host):
with Measure(self.clock, "check_host_in_room"):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
curr_state = yield self.state.get_current_state(room_id)
entry = yield self.state.resolve_state_groups(
room_id, latest_event_ids
)
for event in curr_state.values():
if event.type == EventTypes.Member:
try:
if get_domain_from_id(event.state_key) != host:
continue
except:
logger.warn("state_key not user_id: %s", event.state_key)
continue
ret = yield self.store.is_host_joined(
room_id, host, entry.state_group, entry.state
)
defer.returnValue(ret)
if event.content["membership"] == Membership.JOIN:
defer.returnValue(True)
defer.returnValue(False)
def check_event_sender_in_room(self, event, auth_events):
key = (EventTypes.Member, event.user_id, )
@@ -383,10 +347,6 @@ class Auth(object):
if Membership.INVITE == membership and "third_party_invite" in event.content:
if not self._verify_third_party_invite(event, auth_events):
raise AuthError(403, "You are not invited to this room.")
if target_banned:
raise AuthError(
403, "%s is banned from the room" % (target_user_id,)
)
return True
if Membership.JOIN != membership:
@@ -577,7 +537,9 @@ class Auth(object):
Args:
request - An HTTP request with an access_token query parameter.
Returns:
defer.Deferred: resolves to a ``synapse.types.Requester`` object
tuple of:
UserID (str)
Access token ID (str)
Raises:
AuthError if no user by that token exists or the token is invalid.
"""
@@ -586,7 +548,9 @@ class Auth(object):
user_id = yield self._get_appservice_user_id(request.args)
if user_id:
request.authenticated_entity = user_id
defer.returnValue(synapse.types.create_requester(user_id))
defer.returnValue(
Requester(UserID.from_string(user_id), "", False)
)
access_token = request.args["access_token"][0]
user_info = yield self.get_user_by_access_token(access_token, rights)
@@ -594,10 +558,6 @@ class Auth(object):
token_id = user_info["token_id"]
is_guest = user_info["is_guest"]
# device_id may not be present if get_user_by_access_token has been
# stubbed out.
device_id = user_info.get("device_id")
ip_addr = self.hs.get_ip_from_request(request)
user_agent = request.requestHeaders.getRawHeaders(
"User-Agent",
@@ -609,8 +569,7 @@ class Auth(object):
user=user,
access_token=access_token,
ip=ip_addr,
user_agent=user_agent,
device_id=device_id,
user_agent=user_agent
)
if is_guest and not allow_guest:
@@ -620,8 +579,7 @@ class Auth(object):
request.authenticated_entity = user.to_string()
defer.returnValue(synapse.types.create_requester(
user, token_id, is_guest, device_id))
defer.returnValue(Requester(user, token_id, is_guest))
except KeyError:
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
@@ -671,10 +629,7 @@ class Auth(object):
except AuthError:
# TODO(daniel): Remove this fallback when all existing access tokens
# have been re-issued as macaroons.
if self.hs.config.expire_access_token:
raise
ret = yield self._look_up_user_by_access_token(token)
defer.returnValue(ret)
@defer.inlineCallbacks
@@ -682,25 +637,33 @@ class Auth(object):
try:
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
user_id = self.get_user_id_from_macaroon(macaroon)
user_prefix = "user_id = "
user = None
user_id = None
guest = False
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(user_prefix):
user_id = caveat.caveat_id[len(user_prefix):]
user = UserID.from_string(user_id)
elif caveat.caveat_id == "guest = true":
guest = True
self.validate_macaroon(
macaroon, rights, self.hs.config.expire_access_token,
user_id=user_id,
)
guest = False
for caveat in macaroon.caveats:
if caveat.caveat_id == "guest = true":
guest = True
if user is None:
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
errcode=Codes.UNKNOWN_TOKEN
)
if guest:
ret = {
"user": user,
"is_guest": True,
"token_id": None,
"device_id": None,
}
elif rights == "delete_pusher":
# We don't store these tokens in the database
@@ -708,20 +671,13 @@ class Auth(object):
"user": user,
"is_guest": False,
"token_id": None,
"device_id": None,
}
else:
# This codepath exists for several reasons:
# * so that we can actually return a token ID, which is used
# in some parts of the schema (where we probably ought to
# use device IDs instead)
# * the only way we currently have to invalidate an
# access_token is by removing it from the database, so we
# have to check here that it is still in the db
# * some attributes (notably device_id) aren't stored in the
# macaroon. They probably should be.
# TODO: build the dictionary from the macaroon once the
# above are fixed
# This codepath exists so that we can actually return a
# token ID, because we use token IDs in place of device
# identifiers throughout the codebase.
# TODO(daniel): Remove this fallback when device IDs are
# properly implemented.
ret = yield self._look_up_user_by_access_token(macaroon_str)
if ret["user"] != user:
logger.error(
@@ -741,29 +697,6 @@ class Auth(object):
errcode=Codes.UNKNOWN_TOKEN
)
def get_user_id_from_macaroon(self, macaroon):
"""Retrieve the user_id given by the caveats on the macaroon.
Does *not* validate the macaroon.
Args:
macaroon (pymacaroons.Macaroon): The macaroon to validate
Returns:
(str) user id
Raises:
AuthError if there is no user_id caveat in the macaroon
"""
user_prefix = "user_id = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(user_prefix):
return caveat.caveat_id[len(user_prefix):]
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
errcode=Codes.UNKNOWN_TOKEN
)
def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id):
"""
validate that a Macaroon is understood by and was signed by this server.
@@ -775,7 +708,6 @@ class Auth(object):
verify_expiry(bool): Whether to verify whether the macaroon has expired.
This should really always be True, but no clients currently implement
token refresh, so we can't enforce expiry yet.
user_id (str): The user_id required
"""
v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1")
@@ -819,14 +751,10 @@ class Auth(object):
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN
)
# we use ret.get() below because *lots* of unit tests stub out
# get_user_by_access_token in a way where it only returns a couple of
# the fields.
user_info = {
"user": UserID.from_string(ret.get("name")),
"token_id": ret.get("token_id", None),
"is_guest": False,
"device_id": ret.get("device_id"),
}
defer.returnValue(user_info)
@@ -854,7 +782,7 @@ class Auth(object):
@defer.inlineCallbacks
def add_auth_events(self, builder, context):
auth_ids = yield self.compute_auth_events(builder, context.prev_state_ids)
auth_ids = self.compute_auth_events(builder, context.current_state)
auth_events_entries = yield self.store.add_event_hashes(
auth_ids
@@ -862,32 +790,30 @@ class Auth(object):
builder.auth_events = auth_events_entries
@defer.inlineCallbacks
def compute_auth_events(self, event, current_state_ids, for_verification=False):
def compute_auth_events(self, event, current_state):
if event.type == EventTypes.Create:
defer.returnValue([])
return []
auth_ids = []
key = (EventTypes.PowerLevels, "", )
power_level_event_id = current_state_ids.get(key)
power_level_event = current_state.get(key)
if power_level_event_id:
auth_ids.append(power_level_event_id)
if power_level_event:
auth_ids.append(power_level_event.event_id)
key = (EventTypes.JoinRules, "", )
join_rule_event_id = current_state_ids.get(key)
join_rule_event = current_state.get(key)
key = (EventTypes.Member, event.user_id, )
member_event_id = current_state_ids.get(key)
member_event = current_state.get(key)
key = (EventTypes.Create, "", )
create_event_id = current_state_ids.get(key)
if create_event_id:
auth_ids.append(create_event_id)
create_event = current_state.get(key)
if create_event:
auth_ids.append(create_event.event_id)
if join_rule_event_id:
join_rule_event = yield self.store.get_event(join_rule_event_id)
if join_rule_event:
join_rule = join_rule_event.content.get("join_rule")
is_public = join_rule == JoinRules.PUBLIC if join_rule else False
else:
@@ -896,21 +822,15 @@ class Auth(object):
if event.type == EventTypes.Member:
e_type = event.content["membership"]
if e_type in [Membership.JOIN, Membership.INVITE]:
if join_rule_event_id:
auth_ids.append(join_rule_event_id)
if join_rule_event:
auth_ids.append(join_rule_event.event_id)
if e_type == Membership.JOIN:
if member_event_id and not is_public:
auth_ids.append(member_event_id)
if member_event and not is_public:
auth_ids.append(member_event.event_id)
else:
if member_event_id:
auth_ids.append(member_event_id)
if for_verification:
key = (EventTypes.Member, event.state_key, )
existing_event_id = current_state_ids.get(key)
if existing_event_id:
auth_ids.append(existing_event_id)
if member_event:
auth_ids.append(member_event.event_id)
if e_type == Membership.INVITE:
if "third_party_invite" in event.content:
@@ -918,15 +838,14 @@ class Auth(object):
EventTypes.ThirdPartyInvite,
event.content["third_party_invite"]["signed"]["token"]
)
third_party_invite_id = current_state_ids.get(key)
if third_party_invite_id:
auth_ids.append(third_party_invite_id)
elif member_event_id:
member_event = yield self.store.get_event(member_event_id)
third_party_invite = current_state.get(key)
if third_party_invite:
auth_ids.append(third_party_invite.event_id)
elif member_event:
if member_event.content["membership"] == Membership.JOIN:
auth_ids.append(member_event.event_id)
defer.returnValue(auth_ids)
return auth_ids
def _get_send_level(self, etype, state_key, auth_events):
key = (EventTypes.PowerLevels, "", )

View File

@@ -85,8 +85,3 @@ class RoomCreationPreset(object):
PRIVATE_CHAT = "private_chat"
PUBLIC_CHAT = "public_chat"
TRUSTED_PRIVATE_CHAT = "trusted_private_chat"
class ThirdPartyEntityKind(object):
USER = "user"
LOCATION = "location"

View File

@@ -43,7 +43,6 @@ class Codes(object):
EXCLUSIVE = "M_EXCLUSIVE"
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
THREEPID_IN_USE = "M_THREEPID_IN_USE"
THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND"
INVALID_USERNAME = "M_INVALID_USERNAME"
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"

View File

@@ -191,17 +191,6 @@ class Filter(object):
def __init__(self, filter_json):
self.filter_json = filter_json
self.types = self.filter_json.get("types", None)
self.not_types = self.filter_json.get("not_types", [])
self.rooms = self.filter_json.get("rooms", None)
self.not_rooms = self.filter_json.get("not_rooms", [])
self.senders = self.filter_json.get("senders", None)
self.not_senders = self.filter_json.get("not_senders", [])
self.contains_url = self.filter_json.get("contains_url", None)
def check(self, event):
"""Checks whether the filter matches the given event.
@@ -220,10 +209,9 @@ class Filter(object):
event.get("room_id", None),
sender,
event.get("type", None),
"url" in event.get("content", {})
)
def check_fields(self, room_id, sender, event_type, contains_url):
def check_fields(self, room_id, sender, event_type):
"""Checks whether the filter matches the given event fields.
Returns:
@@ -237,20 +225,15 @@ class Filter(object):
for name, match_func in literal_keys.items():
not_name = "not_%s" % (name,)
disallowed_values = getattr(self, not_name)
disallowed_values = self.filter_json.get(not_name, [])
if any(map(match_func, disallowed_values)):
return False
allowed_values = getattr(self, name)
allowed_values = self.filter_json.get(name, None)
if allowed_values is not None:
if not any(map(match_func, allowed_values)):
return False
contains_url_filter = self.filter_json.get("contains_url")
if contains_url_filter is not None:
if contains_url_filter != contains_url:
return False
return True
def filter_rooms(self, room_ids):

View File

@@ -25,3 +25,4 @@ SERVER_KEY_PREFIX = "/_matrix/key/v1"
SERVER_KEY_V2_PREFIX = "/_matrix/key/v2"
MEDIA_PREFIX = "/_matrix/media/r0"
LEGACY_MEDIA_PREFIX = "/_matrix/media/v1"
APP_SERVICE_PREFIX = "/_matrix/appservice/v1"

View File

@@ -16,11 +16,13 @@
import sys
sys.dont_write_bytecode = True
from synapse import python_dependencies # noqa: E402
from synapse.python_dependencies import (
check_requirements, MissingRequirementError
) # NOQA
try:
python_dependencies.check_requirements()
except python_dependencies.MissingRequirementError as e:
check_requirements()
except MissingRequirementError as e:
message = "\n".join([
"Missing Requirement: %s" % (e.message,),
"To install run:",

View File

@@ -1,209 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse
from synapse.server import HomeServer
from synapse.config._base import ConfigError
from synapse.config.logger import setup_logging
from synapse.config.homeserver import HomeServerConfig
from synapse.http.site import SynapseSite
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.storage.engines import create_engine
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from twisted.internet import reactor, defer
from twisted.web.resource import Resource
from daemonize import Daemonize
import sys
import logging
import gc
logger = logging.getLogger("synapse.app.appservice")
class AppserviceSlaveStore(
DirectoryStore, SlavedEventStore, SlavedApplicationServiceStore,
SlavedRegistrationStore,
):
pass
class AppserviceServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = AppserviceSlaveStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
logger.info("Synapse appservice now listening on port %d", port)
def start_listening(self, listeners):
for listener in listeners:
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
appservice_handler = self.get_application_service_handler()
@defer.inlineCallbacks
def replicate(results):
stream = results.get("events")
if stream:
max_stream_id = stream["position"]
yield appservice_handler.notify_interested_services(max_stream_id)
while True:
try:
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
yield store.process_replication(result)
replicate(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(30)
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse appservice", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.appservice"
setup_logging(config.worker_log_config, config.worker_log_file)
database_engine = create_engine(config.database_config)
if config.notify_appservices:
sys.stderr.write(
"\nThe appservices must be disabled in the main synapse process"
"\nbefore they can be run in a separate worker."
"\nPlease add ``notify_appservices: false`` to the main config"
"\n"
)
sys.exit(1)
# Force the pushers to start since they will be disabled in the main config
config.notify_appservices = True
ps = AppserviceServer(
config.server_name,
db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ps.setup()
ps.start_listening(config.worker_listeners)
def run():
with LoggingContext("run"):
logger.info("Running")
change_resource_limit(config.soft_file_limit)
if config.gc_thresholds:
gc.set_threshold(*config.gc_thresholds)
reactor.run()
def start():
ps.replicate()
ps.get_datastore().start_profiling()
reactor.callWhenRunning(start)
if config.worker_daemonize:
daemon = Daemonize(
app="synapse-appservice",
pid=config.worker_pid_file,
action=run,
auto_close_fds=False,
verbose=True,
logger=logger,
)
daemon.start()
else:
run()
if __name__ == '__main__':
with LoggingContext("main"):
start(sys.argv[1:])

View File

@@ -1,206 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.http.site import SynapseSite
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.keys import SlavedKeyStore
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from synapse.api.urls import FEDERATION_PREFIX
from synapse.federation.transport.server import TransportLayerServer
from synapse.crypto import context_factory
from twisted.internet import reactor, defer
from twisted.web.resource import Resource
from daemonize import Daemonize
import sys
import logging
import gc
logger = logging.getLogger("synapse.app.federation_reader")
class FederationReaderSlavedStore(
SlavedEventStore,
SlavedKeyStore,
RoomStore,
DirectoryStore,
TransactionStore,
BaseSlavedStore,
):
pass
class FederationReaderServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
elif name == "federation":
resources.update({
FEDERATION_PREFIX: TransportLayerServer(self),
})
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
logger.info("Synapse federation reader now listening on port %d", port)
def start_listening(self, listeners):
for listener in listeners:
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
while True:
try:
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
yield store.process_replication(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(5)
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse federation reader", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.federation_reader"
setup_logging(config.worker_log_config, config.worker_log_file)
database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config)
ss = FederationReaderServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ss.setup()
ss.get_handlers()
ss.start_listening(config.worker_listeners)
def run():
with LoggingContext("run"):
logger.info("Running")
change_resource_limit(config.soft_file_limit)
if config.gc_thresholds:
gc.set_threshold(*config.gc_thresholds)
reactor.run()
def start():
ss.get_datastore().start_profiling()
ss.replicate()
reactor.callWhenRunning(start)
if config.worker_daemonize:
daemon = Daemonize(
app="synapse-federation-reader",
pid=config.worker_pid_file,
action=run,
auto_close_fds=False,
verbose=True,
logger=logger,
)
daemon.start()
else:
run()
if __name__ == '__main__':
with LoggingContext("main"):
start(sys.argv[1:])

View File

@@ -51,7 +51,6 @@ from synapse.api.urls import (
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
from synapse.util.logcontext import LoggingContext
from synapse.metrics import register_memory_metrics, get_metrics_for
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX
from synapse.federation.transport.server import TransportLayerServer
@@ -285,7 +284,7 @@ def setup(config_options):
# check any extra requirements we have now we have a config
check_requirements(config)
version_string = "Synapse/" + get_version_string(synapse)
version_string = get_version_string("Synapse", synapse)
logger.info("Server hostname: %s", config.server_name)
logger.info("Server version: %s", version_string)
@@ -336,8 +335,6 @@ def setup(config_options):
hs.get_datastore().start_doing_background_updates()
hs.get_replication_layer().start_get_pdu_cache()
register_memory_metrics(hs)
reactor.callWhenRunning(start)
return hs
@@ -385,8 +382,6 @@ def run(hs):
start_time = hs.get_clock().time()
stats = {}
@defer.inlineCallbacks
def phone_stats_home():
logger.info("Gathering stats for reporting")
@@ -395,10 +390,7 @@ def run(hs):
if uptime < 0:
uptime = 0
# If the stats directory is empty then this is the first time we've
# reported stats.
first_time = not stats
stats = {}
stats["homeserver"] = hs.config.server_name
stats["timestamp"] = now
stats["uptime_seconds"] = uptime
@@ -411,25 +403,6 @@ def run(hs):
daily_messages = yield hs.get_datastore().count_daily_messages()
if daily_messages is not None:
stats["daily_messages"] = daily_messages
else:
stats.pop("daily_messages", None)
if first_time:
# Add callbacks to report the synapse stats as metrics whenever
# prometheus requests them, typically every 30s.
# As some of the stats are expensive to calculate we only update
# them when synapse phones home to matrix.org every 24 hours.
metrics = get_metrics_for("synapse.usage")
metrics.add_callback("timestamp", lambda: stats["timestamp"])
metrics.add_callback("uptime_seconds", lambda: stats["uptime_seconds"])
metrics.add_callback("total_users", lambda: stats["total_users"])
metrics.add_callback("total_room_count", lambda: stats["total_room_count"])
metrics.add_callback(
"daily_active_users", lambda: stats["daily_active_users"]
)
metrics.add_callback(
"daily_messages", lambda: stats.get("daily_messages", 0)
)
logger.info("Reporting stats to matrix.org: %s" % (stats,))
try:

View File

@@ -1,212 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.http.site import SynapseSite
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore
from synapse.storage.engines import create_engine
from synapse.storage.media_repository import MediaRepositoryStore
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from synapse.api.urls import (
CONTENT_REPO_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX
)
from synapse.crypto import context_factory
from twisted.internet import reactor, defer
from twisted.web.resource import Resource
from daemonize import Daemonize
import sys
import logging
import gc
logger = logging.getLogger("synapse.app.media_repository")
class MediaRepositorySlavedStore(
SlavedApplicationServiceStore,
SlavedRegistrationStore,
BaseSlavedStore,
MediaRepositoryStore,
ClientIpStore,
):
pass
class MediaRepositoryServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
elif name == "media":
media_repo = MediaRepositoryResource(self)
resources.update({
MEDIA_PREFIX: media_repo,
LEGACY_MEDIA_PREFIX: media_repo,
CONTENT_REPO_PREFIX: ContentRepoResource(
self, self.config.uploads_path
),
})
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
logger.info("Synapse media repository now listening on port %d", port)
def start_listening(self, listeners):
for listener in listeners:
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
while True:
try:
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
yield store.process_replication(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(5)
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse media repository", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.media_repository"
setup_logging(config.worker_log_config, config.worker_log_file)
database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config)
ss = MediaRepositoryServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ss.setup()
ss.get_handlers()
ss.start_listening(config.worker_listeners)
def run():
with LoggingContext("run"):
logger.info("Running")
change_resource_limit(config.soft_file_limit)
if config.gc_thresholds:
gc.set_threshold(*config.gc_thresholds)
reactor.run()
def start():
ss.get_datastore().start_profiling()
ss.replicate()
reactor.callWhenRunning(start)
if config.worker_daemonize:
daemon = Daemonize(
app="synapse-media-repository",
pid=config.worker_pid_file,
action=run,
auto_close_fds=False,
verbose=True,
logger=logger,
)
daemon.start()
else:
run()
if __name__ == '__main__':
with LoggingContext("main"):
start(sys.argv[1:])

View File

@@ -80,6 +80,11 @@ class PusherSlaveStore(
DataStore.get_profile_displayname.__func__
)
# XXX: This is a bit broken because we don't persist forgotten rooms
# in a way that they can be streamed. This means that we don't have a
# way to invalidate the forgotten rooms cache correctly.
# For now we expire the cache every 10 minutes.
BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
who_forgot_in_room = (
RoomMemberStore.__dict__["who_forgot_in_room"]
)
@@ -163,6 +168,7 @@ class PusherServer(HomeServer):
store = self.get_datastore()
replication_url = self.config.worker_replication_url
pusher_pool = self.get_pusherpool()
clock = self.get_clock()
def stop_pusher(user_id, app_id, pushkey):
key = "%s:%s" % (app_id, pushkey)
@@ -214,11 +220,21 @@ class PusherServer(HomeServer):
min_stream_id, max_stream_id, affected_room_ids
)
def expire_broken_caches():
store.who_forgot_in_room.invalidate_all()
next_expire_broken_caches_ms = 0
while True:
try:
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
now_ms = clock.time_msec()
if now_ms > next_expire_broken_caches_ms:
expire_broken_caches()
next_expire_broken_caches_ms = (
now_ms + store.BROKEN_CACHE_EXPIRY_MS
)
yield store.process_replication(result)
poke_pushers(result)
except:
@@ -257,7 +273,7 @@ def start(config_options):
config.server_name,
db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
version_string=get_version_string("Synapse", synapse),
database_engine=database_engine,
)

View File

@@ -26,7 +26,6 @@ from synapse.http.site import SynapseSite
from synapse.http.server import JsonResource
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.rest.client.v2_alpha import sync
from synapse.rest.client.v1 import events
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
@@ -36,7 +35,6 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
from synapse.replication.slave.storage.filtering import SlavedFilteringStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
from synapse.replication.slave.storage.presence import SlavedPresenceStore
from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore
from synapse.storage.engines import create_engine
@@ -73,10 +71,14 @@ class SynchrotronSlavedStore(
SlavedRegistrationStore,
SlavedFilteringStore,
SlavedPresenceStore,
SlavedDeviceInboxStore,
BaseSlavedStore,
ClientIpStore, # After BaseSlavedStore because the constructor is different
):
# XXX: This is a bit broken because we don't persist forgotten rooms
# in a way that they can be streamed. This means that we don't have a
# way to invalidate the forgotten rooms cache correctly.
# For now we expire the cache every 10 minutes.
BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
who_forgot_in_room = (
RoomMemberStore.__dict__["who_forgot_in_room"]
)
@@ -87,23 +89,17 @@ class SynchrotronSlavedStore(
get_presence_list_accepted = PresenceStore.__dict__[
"get_presence_list_accepted"
]
get_presence_list_observers_accepted = PresenceStore.__dict__[
"get_presence_list_observers_accepted"
]
UPDATE_SYNCING_USERS_MS = 10 * 1000
class SynchrotronPresence(object):
def __init__(self, hs):
self.is_mine_id = hs.is_mine_id
self.http_client = hs.get_simple_http_client()
self.store = hs.get_datastore()
self.user_to_num_current_syncs = {}
self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users"
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
active_presence = self.store.take_presence_startup_info()
self.user_to_current_state = {
@@ -123,13 +119,11 @@ class SynchrotronPresence(object):
reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown)
def set_state(self, user, state, ignore_status_msg=False):
def set_state(self, user, state):
# TODO Hows this supposed to work?
pass
get_states = PresenceHandler.get_states.__func__
get_state = PresenceHandler.get_state.__func__
_get_interested_parties = PresenceHandler._get_interested_parties.__func__
current_state_for_users = PresenceHandler.current_state_for_users.__func__
@defer.inlineCallbacks
@@ -200,39 +194,19 @@ class SynchrotronPresence(object):
self._need_to_send_sync = False
yield self._send_syncing_users_now()
@defer.inlineCallbacks
def notify_from_replication(self, states, stream_id):
parties = yield self._get_interested_parties(
states, calculate_remote_hosts=False
)
room_ids_to_states, users_to_states, _ = parties
self.notifier.on_new_event(
"presence_key", stream_id, rooms=room_ids_to_states.keys(),
users=users_to_states.keys()
)
@defer.inlineCallbacks
def process_replication(self, result):
stream = result.get("presence", {"rows": []})
states = []
for row in stream["rows"]:
(
position, user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts, status_msg,
currently_active
) = row
state = UserPresenceState(
self.user_to_current_state[user_id] = UserPresenceState(
user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts, status_msg,
currently_active
)
self.user_to_current_state[user_id] = state
states.append(state)
if states and "position" in stream:
stream_id = int(stream["position"])
yield self.notify_from_replication(states, stream_id)
class SynchrotronTyping(object):
@@ -292,12 +266,10 @@ class SynchrotronServer(HomeServer):
elif name == "client":
resource = JsonResource(self, canonical_json=False)
sync.register_servlets(self, resource)
events.register_servlets(self, resource)
resources.update({
"/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource,
"/_matrix/client/v2_alpha": resource,
"/_matrix/client/api/v1": resource,
})
root_resource = create_resource_tree(resources, Resource())
@@ -335,10 +307,15 @@ class SynchrotronServer(HomeServer):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
clock = self.get_clock()
notifier = self.get_notifier()
presence_handler = self.get_presence_handler()
typing_handler = self.get_typing_handler()
def expire_broken_caches():
store.who_forgot_in_room.invalidate_all()
store.get_presence_list_accepted.invalidate_all()
def notify_from_stream(
result, stream_name, stream_key, room=None, user=None
):
@@ -399,19 +376,23 @@ class SynchrotronServer(HomeServer):
notify_from_stream(
result, "typing", "typing_key", room="room_id"
)
notify_from_stream(
result, "to_device", "to_device_key", user="user_id"
)
next_expire_broken_caches_ms = 0
while True:
try:
args = store.stream_positions()
args.update(typing_handler.stream_positions())
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
now_ms = clock.time_msec()
if now_ms > next_expire_broken_caches_ms:
expire_broken_caches()
next_expire_broken_caches_ms = (
now_ms + store.BROKEN_CACHE_EXPIRY_MS
)
yield store.process_replication(result)
typing_handler.process_replication(result)
yield presence_handler.process_replication(result)
presence_handler.process_replication(result)
notify(result)
except:
logger.exception("Error replicating from %r", replication_url)
@@ -443,7 +424,7 @@ def start(config_options):
config.server_name,
db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
version_string=get_version_string("Synapse", synapse),
database_engine=database_engine,
application_service_handler=SynchrotronApplicationService(),
)

View File

@@ -14,8 +14,6 @@
# limitations under the License.
from synapse.api.constants import EventTypes
from twisted.internet import defer
import logging
import re
@@ -81,7 +79,7 @@ class ApplicationService(object):
NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
def __init__(self, token, url=None, namespaces=None, hs_token=None,
sender=None, id=None, protocols=None):
sender=None, id=None):
self.token = token
self.url = url
self.hs_token = hs_token
@@ -89,12 +87,6 @@ class ApplicationService(object):
self.namespaces = self._check_namespaces(namespaces)
self.id = id
# .protocols is a publicly visible field
if protocols:
self.protocols = set(protocols)
else:
self.protocols = set()
def _check_namespaces(self, namespaces):
# Sanity check that it is of the form:
# {
@@ -146,66 +138,65 @@ class ApplicationService(object):
return regex_obj["exclusive"]
return False
@defer.inlineCallbacks
def _matches_user(self, event, store):
if not event:
defer.returnValue(False)
if self.is_interested_in_user(event.sender):
defer.returnValue(True)
def _matches_user(self, event, member_list):
if (hasattr(event, "sender") and
self.is_interested_in_user(event.sender)):
return True
# also check m.room.member state key
if (event.type == EventTypes.Member and
self.is_interested_in_user(event.state_key)):
defer.returnValue(True)
if not store:
defer.returnValue(False)
member_list = yield store.get_users_in_room(event.room_id)
if (hasattr(event, "type") and event.type == EventTypes.Member
and hasattr(event, "state_key")
and self.is_interested_in_user(event.state_key)):
return True
# check joined member events
for user_id in member_list:
if self.is_interested_in_user(user_id):
defer.returnValue(True)
defer.returnValue(False)
return True
return False
def _matches_room_id(self, event):
if hasattr(event, "room_id"):
return self.is_interested_in_room(event.room_id)
return False
@defer.inlineCallbacks
def _matches_aliases(self, event, store):
if not store or not event:
defer.returnValue(False)
alias_list = yield store.get_aliases_for_room(event.room_id)
def _matches_aliases(self, event, alias_list):
for alias in alias_list:
if self.is_interested_in_alias(alias):
defer.returnValue(True)
defer.returnValue(False)
return True
return False
@defer.inlineCallbacks
def is_interested(self, event, store=None):
def is_interested(self, event, restrict_to=None, aliases_for_event=None,
member_list=None):
"""Check if this service is interested in this event.
Args:
event(Event): The event to check.
store(DataStore)
restrict_to(str): The namespace to restrict regex tests to.
aliases_for_event(list): A list of all the known room aliases for
this event.
member_list(list): A list of all joined user_ids in this room.
Returns:
bool: True if this service would like to know about this event.
"""
# Do cheap checks first
if self._matches_room_id(event):
defer.returnValue(True)
if aliases_for_event is None:
aliases_for_event = []
if member_list is None:
member_list = []
if (yield self._matches_aliases(event, store)):
defer.returnValue(True)
if restrict_to and restrict_to not in ApplicationService.NS_LIST:
# this is a programming error, so fail early and raise a general
# exception
raise Exception("Unexpected restrict_to value: %s". restrict_to)
if (yield self._matches_user(event, store)):
defer.returnValue(True)
defer.returnValue(False)
if not restrict_to:
return (self._matches_user(event, member_list)
or self._matches_aliases(event, aliases_for_event)
or self._matches_room_id(event))
elif restrict_to == ApplicationService.NS_ALIASES:
return self._matches_aliases(event, aliases_for_event)
elif restrict_to == ApplicationService.NS_ROOMS:
return self._matches_room_id(event)
elif restrict_to == ApplicationService.NS_USERS:
return self._matches_user(event, member_list)
def is_interested_in_user(self, user_id):
return (
@@ -225,9 +216,6 @@ class ApplicationService(object):
or user_id == self.sender
)
def is_interested_in_protocol(self, protocol):
return protocol in self.protocols
def is_exclusive_alias(self, alias):
return self._is_exclusive(ApplicationService.NS_ALIASES, alias)

View File

@@ -14,11 +14,9 @@
# limitations under the License.
from twisted.internet import defer
from synapse.api.constants import ThirdPartyEntityKind
from synapse.api.errors import CodeMessageException
from synapse.http.client import SimpleHttpClient
from synapse.events.utils import serialize_event
from synapse.util.caches.response_cache import ResponseCache
import logging
import urllib
@@ -26,34 +24,6 @@ import urllib
logger = logging.getLogger(__name__)
HOUR_IN_MS = 60 * 60 * 1000
APP_SERVICE_PREFIX = "/_matrix/app/unstable"
def _is_valid_3pe_result(r, field):
if not isinstance(r, dict):
return False
for k in (field, "protocol"):
if k not in r:
return False
if not isinstance(r[k], str):
return False
if "fields" not in r:
return False
fields = r["fields"]
if not isinstance(fields, dict):
return False
for k in fields.keys():
if not isinstance(fields[k], str):
return False
return True
class ApplicationServiceApi(SimpleHttpClient):
"""This class manages HS -> AS communications, including querying and
pushing.
@@ -63,12 +33,8 @@ class ApplicationServiceApi(SimpleHttpClient):
super(ApplicationServiceApi, self).__init__(hs)
self.clock = hs.get_clock()
self.protocol_meta_cache = ResponseCache(hs, timeout_ms=HOUR_IN_MS)
@defer.inlineCallbacks
def query_user(self, service, user_id):
if service.url is None:
defer.returnValue(False)
uri = service.url + ("/users/%s" % urllib.quote(user_id))
response = None
try:
@@ -88,8 +54,6 @@ class ApplicationServiceApi(SimpleHttpClient):
@defer.inlineCallbacks
def query_alias(self, service, alias):
if service.url is None:
defer.returnValue(False)
uri = service.url + ("/rooms/%s" % urllib.quote(alias))
response = None
try:
@@ -107,77 +71,8 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning("query_alias to %s threw exception %s", uri, ex)
defer.returnValue(False)
@defer.inlineCallbacks
def query_3pe(self, service, kind, protocol, fields):
if kind == ThirdPartyEntityKind.USER:
required_field = "userid"
elif kind == ThirdPartyEntityKind.LOCATION:
required_field = "alias"
else:
raise ValueError(
"Unrecognised 'kind' argument %r to query_3pe()", kind
)
if service.url is None:
defer.returnValue([])
uri = "%s%s/thirdparty/%s/%s" % (
service.url,
APP_SERVICE_PREFIX,
kind,
urllib.quote(protocol)
)
try:
response = yield self.get_json(uri, fields)
if not isinstance(response, list):
logger.warning(
"query_3pe to %s returned an invalid response %r",
uri, response
)
defer.returnValue([])
ret = []
for r in response:
if _is_valid_3pe_result(r, field=required_field):
ret.append(r)
else:
logger.warning(
"query_3pe to %s returned an invalid result %r",
uri, r
)
defer.returnValue(ret)
except Exception as ex:
logger.warning("query_3pe to %s threw exception %s", uri, ex)
defer.returnValue([])
def get_3pe_protocol(self, service, protocol):
if service.url is None:
defer.returnValue({})
@defer.inlineCallbacks
def _get():
uri = "%s%s/thirdparty/protocol/%s" % (
service.url,
APP_SERVICE_PREFIX,
urllib.quote(protocol)
)
try:
defer.returnValue((yield self.get_json(uri, {})))
except Exception as ex:
logger.warning("query_3pe_protocol to %s threw exception %s",
uri, ex)
defer.returnValue({})
key = (service.id, protocol)
return self.protocol_meta_cache.get(key) or (
self.protocol_meta_cache.set(key, _get())
)
@defer.inlineCallbacks
def push_bulk(self, service, events, txn_id=None):
if service.url is None:
defer.returnValue(True)
events = self._serialize(events)
if txn_id is None:

View File

@@ -48,12 +48,9 @@ UP & quit +---------- YES SUCCESS
This is all tied together by the AppServiceScheduler which DIs the required
components.
"""
from twisted.internet import defer
from synapse.appservice import ApplicationServiceState
from synapse.util.logcontext import preserve_fn
from synapse.util.metrics import Measure
from twisted.internet import defer
import logging
logger = logging.getLogger(__name__)
@@ -76,7 +73,7 @@ class ApplicationServiceScheduler(object):
self.txn_ctrl = _TransactionController(
self.clock, self.store, self.as_api, create_recoverer
)
self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock)
self.queuer = _ServiceQueuer(self.txn_ctrl)
@defer.inlineCallbacks
def start(self):
@@ -97,36 +94,38 @@ class _ServiceQueuer(object):
this schedules any other events in the queue to run.
"""
def __init__(self, txn_ctrl, clock):
def __init__(self, txn_ctrl):
self.queued_events = {} # dict of {service_id: [events]}
self.requests_in_flight = set()
self.pending_requests = {} # dict of {service_id: Deferred}
self.txn_ctrl = txn_ctrl
self.clock = clock
def enqueue(self, service, event):
# if this service isn't being sent something
self.queued_events.setdefault(service.id, []).append(event)
preserve_fn(self._send_request)(service)
if not self.pending_requests.get(service.id):
self._send_request(service, [event])
else:
# add to queue for this service
if service.id not in self.queued_events:
self.queued_events[service.id] = []
self.queued_events[service.id].append(event)
@defer.inlineCallbacks
def _send_request(self, service):
if service.id in self.requests_in_flight:
return
def _send_request(self, service, events):
# send request and add callbacks
d = self.txn_ctrl.send(service, events)
d.addBoth(self._on_request_finish)
d.addErrback(self._on_request_fail)
self.pending_requests[service.id] = d
self.requests_in_flight.add(service.id)
try:
while True:
events = self.queued_events.pop(service.id, [])
if not events:
return
def _on_request_finish(self, service):
self.pending_requests[service.id] = None
# if there are queued events, then send them.
if (service.id in self.queued_events
and len(self.queued_events[service.id]) > 0):
self._send_request(service, self.queued_events[service.id])
self.queued_events[service.id] = []
with Measure(self.clock, "servicequeuer.send"):
try:
yield self.txn_ctrl.send(service, events)
except:
logger.exception("AS request failed")
finally:
self.requests_in_flight.discard(service.id)
def _on_request_fail(self, err):
logger.error("AS request failed: %s", err)
class _TransactionController(object):
@@ -150,12 +149,14 @@ class _TransactionController(object):
if service_is_up:
sent = yield txn.send(self.as_api)
if sent:
yield txn.complete(self.store)
txn.complete(self.store)
else:
preserve_fn(self._start_recoverer)(service)
self._start_recoverer(service)
except Exception as e:
logger.exception(e)
preserve_fn(self._start_recoverer)(service)
self._start_recoverer(service)
# request has finished
defer.returnValue(service)
@defer.inlineCallbacks
def on_recovered(self, recoverer):

View File

@@ -28,7 +28,6 @@ class AppServiceConfig(Config):
def read_config(self, config):
self.app_service_config_files = config.get("app_service_config_files", [])
self.notify_appservices = config.get("notify_appservices", True)
def default_config(cls, **kwargs):
return """\
@@ -86,7 +85,7 @@ def load_appservices(hostname, config_files):
def _load_appservice(hostname, as_info, config_filename):
required_string_fields = [
"id", "as_token", "hs_token", "sender_localpart"
"id", "url", "as_token", "hs_token", "sender_localpart"
]
for field in required_string_fields:
if not isinstance(as_info.get(field), basestring):
@@ -94,14 +93,6 @@ def _load_appservice(hostname, as_info, config_filename):
field, config_filename,
))
# 'url' must either be a string or explicitly null, not missing
# to avoid accidentally turning off push for ASes.
if (not isinstance(as_info.get("url"), basestring) and
as_info.get("url", "") is not None):
raise KeyError(
"Required string field or explicit null: 'url' (%s)" % (config_filename,)
)
localpart = as_info["sender_localpart"]
if urllib.quote(localpart) != localpart:
raise ValueError(
@@ -131,22 +122,6 @@ def _load_appservice(hostname, as_info, config_filename):
raise ValueError(
"Missing/bad type 'exclusive' key in %s", regex_obj
)
# protocols check
protocols = as_info.get("protocols")
if protocols:
# Because strings are lists in python
if isinstance(protocols, str) or not isinstance(protocols, list):
raise KeyError("Optional 'protocols' must be a list if present.")
for p in protocols:
if not isinstance(p, str):
raise KeyError("Bad value for 'protocols' item")
if as_info["url"] is None:
logger.info(
"(%s) Explicitly empty 'url' provided. This application service"
" will not receive events or queries.",
config_filename,
)
return ApplicationService(
token=as_info["as_token"],
url=as_info["url"],
@@ -154,5 +129,4 @@ def _load_appservice(hostname, as_info, config_filename):
hs_token=as_info["hs_token"],
sender=user_id,
id=as_info["id"],
protocols=protocols,
)

View File

@@ -77,12 +77,10 @@ class SynapseKeyClientProtocol(HTTPClient):
def __init__(self):
self.remote_key = defer.Deferred()
self.host = None
self._peer = None
def connectionMade(self):
self._peer = self.transport.getPeer()
logger.debug("Connected to %s", self._peer)
self.host = self.transport.getHost()
logger.debug("Connected to %s", self.host)
self.sendCommand(b"GET", self.path)
if self.host:
self.sendHeader(b"Host", self.host)
@@ -126,10 +124,7 @@ class SynapseKeyClientProtocol(HTTPClient):
self.timer.cancel()
def on_timeout(self):
logger.debug(
"Timeout waiting for response from %s: %s",
self.host, self._peer,
)
logger.debug("Timeout waiting for response from %s", self.host)
self.errback(IOError("Timeout waiting for response"))
self.transport.abortConnection()
@@ -138,5 +133,4 @@ class SynapseKeyClientFactory(Factory):
def protocol(self):
protocol = SynapseKeyClientProtocol()
protocol.path = self.path
protocol.host = self.host
return protocol

View File

@@ -22,7 +22,6 @@ from synapse.util.logcontext import (
preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext,
preserve_fn
)
from synapse.util.metrics import Measure
from twisted.internet import defer
@@ -45,25 +44,7 @@ import logging
logger = logging.getLogger(__name__)
VerifyKeyRequest = namedtuple("VerifyRequest", (
"server_name", "key_ids", "json_object", "deferred"
))
"""
A request for a verify key to verify a JSON object.
Attributes:
server_name(str): The name of the server to verify against.
key_ids(set(str)): The set of key_ids to that could be used to verify the
JSON object
json_object(dict): The JSON object to verify.
deferred(twisted.internet.defer.Deferred):
A deferred (server_name, key_id, verify_key) tuple that resolves when
a verify key has been fetched
"""
class KeyLookupError(ValueError):
pass
KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids"))
class Keyring(object):
@@ -93,32 +74,39 @@ class Keyring(object):
list of deferreds indicating success or failure to verify each
json object's signature for the given server_name.
"""
verify_requests = []
group_id_to_json = {}
group_id_to_group = {}
group_ids = []
next_group_id = 0
deferreds = {}
for server_name, json_object in server_and_json:
logger.debug("Verifying for %s", server_name)
group_id = next_group_id
next_group_id += 1
group_ids.append(group_id)
key_ids = signature_ids(json_object, server_name)
if not key_ids:
deferred = defer.fail(SynapseError(
deferreds[group_id] = defer.fail(SynapseError(
400,
"Not signed with a supported algorithm",
Codes.UNAUTHORIZED,
))
else:
deferred = defer.Deferred()
deferreds[group_id] = defer.Deferred()
verify_request = VerifyKeyRequest(
server_name, key_ids, json_object, deferred
)
group = KeyGroup(server_name, group_id, key_ids)
verify_requests.append(verify_request)
group_id_to_group[group_id] = group
group_id_to_json[group_id] = json_object
@defer.inlineCallbacks
def handle_key_deferred(verify_request):
server_name = verify_request.server_name
def handle_key_deferred(group, deferred):
server_name = group.server_name
try:
_, key_id, verify_key = yield verify_request.deferred
_, _, key_id, verify_key = yield deferred
except IOError as e:
logger.warn(
"Got IOError when downloading keys for %s: %s %s",
@@ -140,7 +128,7 @@ class Keyring(object):
Codes.UNAUTHORIZED,
)
json_object = verify_request.json_object
json_object = group_id_to_json[group.group_id]
try:
verify_signed_json(json_object, server_name, verify_key)
@@ -169,34 +157,36 @@ class Keyring(object):
# Actually start fetching keys.
wait_on_deferred.addBoth(
lambda _: self.get_server_verify_keys(verify_requests)
lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
)
# When we've finished fetching all the keys for a given server_name,
# resolve the deferred passed to `wait_for_previous_lookups` so that
# any lookups waiting will proceed.
server_to_request_ids = {}
server_to_gids = {}
def remove_deferreds(res, server_name, verify_request):
request_id = id(verify_request)
server_to_request_ids[server_name].discard(request_id)
if not server_to_request_ids[server_name]:
def remove_deferreds(res, server_name, group_id):
server_to_gids[server_name].discard(group_id)
if not server_to_gids[server_name]:
d = server_to_deferred.pop(server_name, None)
if d:
d.callback(None)
return res
for verify_request in verify_requests:
server_name = verify_request.server_name
request_id = id(verify_request)
server_to_request_ids.setdefault(server_name, set()).add(request_id)
deferred.addBoth(remove_deferreds, server_name, verify_request)
for g_id, deferred in deferreds.items():
server_name = group_id_to_group[g_id].server_name
server_to_gids.setdefault(server_name, set()).add(g_id)
deferred.addBoth(remove_deferreds, server_name, g_id)
# Pass those keys to handle_key_deferred so that the json object
# signatures can be verified
return [
preserve_context_over_fn(handle_key_deferred, verify_request)
for verify_request in verify_requests
preserve_context_over_fn(
handle_key_deferred,
group_id_to_group[g_id],
deferreds[g_id],
)
for g_id in group_ids
]
@defer.inlineCallbacks
@@ -230,7 +220,7 @@ class Keyring(object):
d.addBoth(rm, server_name)
def get_server_verify_keys(self, verify_requests):
def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
"""Takes a dict of KeyGroups and tries to find at least one key for
each group.
"""
@@ -244,79 +234,76 @@ class Keyring(object):
@defer.inlineCallbacks
def do_iterations():
with Measure(self.clock, "get_server_verify_keys"):
merged_results = {}
missing_keys = {}
for verify_request in verify_requests:
missing_keys.setdefault(verify_request.server_name, set()).update(
verify_request.key_ids
for group in group_id_to_group.values():
missing_keys.setdefault(group.server_name, set()).update(
group.key_ids
)
for fn in key_fetch_fns:
results = yield fn(missing_keys.items())
merged_results.update(results)
# We now need to figure out which verify requests we have keys
# for and which we don't
missing_keys = {}
requests_missing_keys = []
for verify_request in verify_requests:
server_name = verify_request.server_name
result_keys = merged_results[server_name]
if verify_request.deferred.called:
# We've already called this deferred, which probably
# means that we've already found a key for it.
continue
for key_id in verify_request.key_ids:
if key_id in result_keys:
# We now need to figure out which groups we have keys for
# and which we don't
missing_groups = {}
for group in group_id_to_group.values():
for key_id in group.key_ids:
if key_id in merged_results[group.server_name]:
with PreserveLoggingContext():
verify_request.deferred.callback((
server_name,
group_id_to_deferred[group.group_id].callback((
group.group_id,
group.server_name,
key_id,
result_keys[key_id],
merged_results[group.server_name][key_id],
))
break
else:
# The else block is only reached if the loop above
# doesn't break.
missing_keys.setdefault(server_name, set()).update(
verify_request.key_ids
)
requests_missing_keys.append(verify_request)
missing_groups.setdefault(
group.server_name, []
).append(group)
if not missing_keys:
if not missing_groups:
break
for verify_request in requests_missing_keys.values():
verify_request.deferred.errback(SynapseError(
missing_keys = {
server_name: set(
key_id for group in groups for key_id in group.key_ids
)
for server_name, groups in missing_groups.items()
}
for group in missing_groups.values():
group_id_to_deferred[group.group_id].errback(SynapseError(
401,
"No key for %s with id %s" % (
verify_request.server_name, verify_request.key_ids,
group.server_name, group.key_ids,
),
Codes.UNAUTHORIZED,
))
def on_err(err):
for verify_request in verify_requests:
if not verify_request.deferred.called:
verify_request.deferred.errback(err)
for deferred in group_id_to_deferred.values():
if not deferred.called:
deferred.errback(err)
do_iterations().addErrback(on_err)
return group_id_to_deferred
@defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids):
res = yield preserve_context_over_deferred(defer.gatherResults(
res = yield defer.gatherResults(
[
preserve_fn(self.store.get_server_verify_keys)(
self.store.get_server_verify_keys(
server_name, key_ids
).addCallback(lambda ks, server: (server, ks), server_name)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
)).addErrback(unwrapFirstError)
).addErrback(unwrapFirstError)
defer.returnValue(dict(res))
@@ -337,13 +324,13 @@ class Keyring(object):
)
defer.returnValue({})
results = yield preserve_context_over_deferred(defer.gatherResults(
results = yield defer.gatherResults(
[
preserve_fn(get_key)(p_name, p_keys)
get_key(p_name, p_keys)
for p_name, p_keys in self.perspective_servers.items()
],
consumeErrors=True,
)).addErrback(unwrapFirstError)
).addErrback(unwrapFirstError)
union_of_keys = {}
for result in results:
@@ -369,7 +356,7 @@ class Keyring(object):
)
except Exception as e:
logger.info(
"Unable to get key %r for %r directly: %s %s",
"Unable to getting key %r for %r directly: %s %s",
key_ids, server_name,
type(e).__name__, str(e.message),
)
@@ -383,13 +370,13 @@ class Keyring(object):
defer.returnValue(keys)
results = yield preserve_context_over_deferred(defer.gatherResults(
results = yield defer.gatherResults(
[
preserve_fn(get_key)(server_name, key_ids)
get_key(server_name, key_ids)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
)).addErrback(unwrapFirstError)
).addErrback(unwrapFirstError)
merged = {}
for result in results:
@@ -431,7 +418,7 @@ class Keyring(object):
for response in responses:
if (u"signatures" not in response
or perspective_name not in response[u"signatures"]):
raise KeyLookupError(
raise ValueError(
"Key response not signed by perspective server"
" %r" % (perspective_name,)
)
@@ -454,21 +441,21 @@ class Keyring(object):
list(response[u"signatures"][perspective_name]),
list(perspective_keys)
)
raise KeyLookupError(
raise ValueError(
"Response not signed with a known key for perspective"
" server %r" % (perspective_name,)
)
processed_response = yield self.process_v2_response(
perspective_name, response, only_from_server=False
perspective_name, response
)
for server_name, response_keys in processed_response.items():
keys.setdefault(server_name, {}).update(response_keys)
yield preserve_context_over_deferred(defer.gatherResults(
yield defer.gatherResults(
[
preserve_fn(self.store_keys)(
self.store_keys(
server_name=server_name,
from_server=perspective_name,
verify_keys=response_keys,
@@ -476,7 +463,7 @@ class Keyring(object):
for server_name, response_keys in keys.items()
],
consumeErrors=True
)).addErrback(unwrapFirstError)
).addErrback(unwrapFirstError)
defer.returnValue(keys)
@@ -497,10 +484,10 @@ class Keyring(object):
if (u"signatures" not in response
or server_name not in response[u"signatures"]):
raise KeyLookupError("Key response not signed by remote server")
raise ValueError("Key response not signed by remote server")
if "tls_fingerprints" not in response:
raise KeyLookupError("Key response missing TLS fingerprints")
raise ValueError("Key response missing TLS fingerprints")
certificate_bytes = crypto.dump_certificate(
crypto.FILETYPE_ASN1, tls_certificate
@@ -514,7 +501,7 @@ class Keyring(object):
response_sha256_fingerprints.add(fingerprint[u"sha256"])
if sha256_fingerprint_b64 not in response_sha256_fingerprints:
raise KeyLookupError("TLS certificate not allowed by fingerprints")
raise ValueError("TLS certificate not allowed by fingerprints")
response_keys = yield self.process_v2_response(
from_server=server_name,
@@ -524,7 +511,7 @@ class Keyring(object):
keys.update(response_keys)
yield preserve_context_over_deferred(defer.gatherResults(
yield defer.gatherResults(
[
preserve_fn(self.store_keys)(
server_name=key_server_name,
@@ -534,13 +521,13 @@ class Keyring(object):
for key_server_name, verify_keys in keys.items()
],
consumeErrors=True
)).addErrback(unwrapFirstError)
).addErrback(unwrapFirstError)
defer.returnValue(keys)
@defer.inlineCallbacks
def process_v2_response(self, from_server, response_json,
requested_ids=[], only_from_server=True):
requested_ids=[]):
time_now_ms = self.clock.time_msec()
response_keys = {}
verify_keys = {}
@@ -564,16 +551,9 @@ class Keyring(object):
results = {}
server_name = response_json["server_name"]
if only_from_server:
if server_name != from_server:
raise KeyLookupError(
"Expected a response for server %r not %r" % (
from_server, server_name
)
)
for key_id in response_json["signatures"].get(server_name, {}):
if key_id not in response_json["verify_keys"]:
raise KeyLookupError(
raise ValueError(
"Key response must include verification keys for all"
" signatures"
)
@@ -600,7 +580,7 @@ class Keyring(object):
response_keys.update(verify_keys)
response_keys.update(old_verify_keys)
yield preserve_context_over_deferred(defer.gatherResults(
yield defer.gatherResults(
[
preserve_fn(self.store.store_server_keys_json)(
server_name=server_name,
@@ -613,7 +593,7 @@ class Keyring(object):
for key_id in updated_key_ids
],
consumeErrors=True,
)).addErrback(unwrapFirstError)
).addErrback(unwrapFirstError)
results[server_name] = response_keys
@@ -641,15 +621,15 @@ class Keyring(object):
if ("signatures" not in response
or server_name not in response["signatures"]):
raise KeyLookupError("Key response not signed by remote server")
raise ValueError("Key response not signed by remote server")
if "tls_certificate" not in response:
raise KeyLookupError("Key response missing TLS certificate")
raise ValueError("Key response missing TLS certificate")
tls_certificate_b64 = response["tls_certificate"]
if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
raise KeyLookupError("TLS certificate doesn't match")
raise ValueError("TLS certificate doesn't match")
# Cache the result in the datastore.
@@ -665,7 +645,7 @@ class Keyring(object):
for key_id in response["signatures"][server_name]:
if key_id not in response["verify_keys"]:
raise KeyLookupError(
raise ValueError(
"Key response must include verification keys for all"
" signatures"
)
@@ -702,7 +682,7 @@ class Keyring(object):
A deferred that completes when the keys are stored.
"""
# TODO(markjh): Store whether the keys have expired.
yield preserve_context_over_deferred(defer.gatherResults(
yield defer.gatherResults(
[
preserve_fn(self.store.store_server_verify_key)(
server_name, server_name, key.time_added, key
@@ -710,4 +690,4 @@ class Keyring(object):
for key_id, key in verify_keys.items()
],
consumeErrors=True,
)).addErrback(unwrapFirstError)
).addErrback(unwrapFirstError)

View File

@@ -99,7 +99,7 @@ class EventBase(object):
return d
def get(self, key, default=None):
def get(self, key, default):
return self._event_dict.get(key, default)
def get_internal_metadata_dict(self):

View File

@@ -15,9 +15,9 @@
class EventContext(object):
def __init__(self):
self.current_state_ids = None
self.prev_state_ids = None
def __init__(self, current_state=None):
self.current_state = current_state
self.state_group = None
self.rejected = False
self.push_actions = []

View File

@@ -88,8 +88,6 @@ def prune_event(event):
if "age_ts" in event.unsigned:
allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"]
if "replaces_state" in event.unsigned:
allowed_fields["unsigned"]["replaces_state"] = event.unsigned["replaces_state"]
return type(event)(
allowed_fields,

View File

@@ -23,7 +23,6 @@ from synapse.crypto.event_signing import check_event_content_hash
from synapse.api.errors import SynapseError
from synapse.util import unwrapFirstError
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
import logging
@@ -103,10 +102,10 @@ class FederationBase(object):
warn, pdu
)
valid_pdus = yield preserve_context_over_deferred(defer.gatherResults(
valid_pdus = yield defer.gatherResults(
deferreds,
consumeErrors=True
)).addErrback(unwrapFirstError)
).addErrback(unwrapFirstError)
if include_none:
defer.returnValue(valid_pdus)
@@ -130,7 +129,7 @@ class FederationBase(object):
for pdu in pdus
]
deferreds = preserve_fn(self.keyring.verify_json_objects_for_server)([
deferreds = self.keyring.verify_json_objects_for_server([
(p.origin, p.get_pdu_json())
for p in redacted_pdus
])

View File

@@ -27,9 +27,7 @@ from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.events import FrozenEvent
from synapse.types import get_domain_from_id
import synapse.metrics
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
@@ -53,35 +51,10 @@ sent_edus_counter = metrics.register_counter("sent_edus")
sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"])
PDU_RETRY_TIME_MS = 1 * 60 * 1000
class FederationClient(FederationBase):
def __init__(self, hs):
super(FederationClient, self).__init__(hs)
self.pdu_destination_tried = {}
self._clock.looping_call(
self._clear_tried_cache, 60 * 1000,
)
self.state = hs.get_state_handler()
def _clear_tried_cache(self):
"""Clear pdu_destination_tried cache"""
now = self._clock.time_msec()
old_dict = self.pdu_destination_tried
self.pdu_destination_tried = {}
for event_id, destination_dict in old_dict.items():
destination_dict = {
dest: time
for dest, time in destination_dict.items()
if time + PDU_RETRY_TIME_MS > now
}
if destination_dict:
self.pdu_destination_tried[event_id] = destination_dict
def start_get_pdu_cache(self):
self._get_pdu_cache = ExpiringCache(
cache_name="get_pdu_cache",
@@ -228,10 +201,10 @@ class FederationClient(FederationBase):
]
# FIXME: We should handle signature failures more gracefully.
pdus[:] = yield preserve_context_over_deferred(defer.gatherResults(
pdus[:] = yield defer.gatherResults(
self._check_sigs_and_hashes(pdus),
consumeErrors=True,
)).addErrback(unwrapFirstError)
).addErrback(unwrapFirstError)
defer.returnValue(pdus)
@@ -263,19 +236,12 @@ class FederationClient(FederationBase):
# TODO: Rate limit the number of times we try and get the same event.
if self._get_pdu_cache:
ev = self._get_pdu_cache.get(event_id)
if ev:
defer.returnValue(ev)
e = self._get_pdu_cache.get(event_id)
if e:
defer.returnValue(e)
pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
signed_pdu = None
pdu = None
for destination in destinations:
now = self._clock.time_msec()
last_attempt = pdu_attempts.get(destination, 0)
if last_attempt + PDU_RETRY_TIME_MS > now:
continue
try:
limiter = yield get_retry_limiter(
destination,
@@ -299,33 +265,39 @@ class FederationClient(FederationBase):
pdu = pdu_list[0]
# Check signatures are correct.
signed_pdu = yield self._check_sigs_and_hashes([pdu])[0]
pdu = yield self._check_sigs_and_hashes([pdu])[0]
break
pdu_attempts[destination] = now
except SynapseError as e:
except SynapseError:
logger.info(
"Failed to get PDU %s from %s because %s",
event_id, destination, e,
)
continue
except CodeMessageException as e:
if 400 <= e.code < 500:
raise
logger.info(
"Failed to get PDU %s from %s because %s",
event_id, destination, e,
)
continue
except NotRetryingDestination as e:
logger.info(e.message)
continue
except Exception as e:
pdu_attempts[destination] = now
logger.info(
"Failed to get PDU %s from %s because %s",
event_id, destination, e,
)
continue
if self._get_pdu_cache is not None and signed_pdu:
self._get_pdu_cache[event_id] = signed_pdu
if self._get_pdu_cache is not None and pdu:
self._get_pdu_cache[event_id] = pdu
defer.returnValue(signed_pdu)
defer.returnValue(pdu)
@defer.inlineCallbacks
@log_function
@@ -342,42 +314,6 @@ class FederationClient(FederationBase):
Deferred: Results in a list of PDUs.
"""
try:
# First we try and ask for just the IDs, as thats far quicker if
# we have most of the state and auth_chain already.
# However, this may 404 if the other side has an old synapse.
result = yield self.transport_layer.get_room_state_ids(
destination, room_id, event_id=event_id,
)
state_event_ids = result["pdu_ids"]
auth_event_ids = result.get("auth_chain_ids", [])
fetched_events, failed_to_fetch = yield self.get_events(
[destination], room_id, set(state_event_ids + auth_event_ids)
)
if failed_to_fetch:
logger.warn("Failed to get %r", failed_to_fetch)
event_map = {
ev.event_id: ev for ev in fetched_events
}
pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
auth_chain = [
event_map[e_id] for e_id in auth_event_ids if e_id in event_map
]
auth_chain.sort(key=lambda e: e.depth)
defer.returnValue((pdus, auth_chain))
except HttpResponseException as e:
if e.code == 400 or e.code == 404:
logger.info("Failed to use get_room_state_ids API, falling back")
else:
raise e
result = yield self.transport_layer.get_room_state(
destination, room_id, event_id=event_id,
)
@@ -391,95 +327,18 @@ class FederationClient(FederationBase):
for p in result.get("auth_chain", [])
]
seen_events = yield self.store.get_events([
ev.event_id for ev in itertools.chain(pdus, auth_chain)
])
signed_pdus = yield self._check_sigs_and_hash_and_fetch(
destination,
[p for p in pdus if p.event_id not in seen_events],
outlier=True
)
signed_pdus.extend(
seen_events[p.event_id] for p in pdus if p.event_id in seen_events
destination, pdus, outlier=True
)
signed_auth = yield self._check_sigs_and_hash_and_fetch(
destination,
[p for p in auth_chain if p.event_id not in seen_events],
outlier=True
)
signed_auth.extend(
seen_events[p.event_id] for p in auth_chain if p.event_id in seen_events
destination, auth_chain, outlier=True
)
signed_auth.sort(key=lambda e: e.depth)
defer.returnValue((signed_pdus, signed_auth))
@defer.inlineCallbacks
def get_events(self, destinations, room_id, event_ids, return_local=True):
"""Fetch events from some remote destinations, checking if we already
have them.
Args:
destinations (list)
room_id (str)
event_ids (list)
return_local (bool): Whether to include events we already have in
the DB in the returned list of events
Returns:
Deferred: A deferred resolving to a 2-tuple where the first is a list of
events and the second is a list of event ids that we failed to fetch.
"""
if return_local:
seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
signed_events = seen_events.values()
else:
seen_events = yield self.store.have_events(event_ids)
signed_events = []
failed_to_fetch = set()
missing_events = set(event_ids)
for k in seen_events:
missing_events.discard(k)
if not missing_events:
defer.returnValue((signed_events, failed_to_fetch))
def random_server_list():
srvs = list(destinations)
random.shuffle(srvs)
return srvs
batch_size = 20
missing_events = list(missing_events)
for i in xrange(0, len(missing_events), batch_size):
batch = set(missing_events[i:i + batch_size])
deferreds = [
preserve_fn(self.get_pdu)(
destinations=random_server_list(),
event_id=e_id,
)
for e_id in batch
]
res = yield preserve_context_over_deferred(
defer.DeferredList(deferreds, consumeErrors=True)
)
for success, result in res:
if success:
signed_events.append(result)
batch.discard(result.event_id)
# We removed all events we successfully fetched from `batch`
failed_to_fetch.update(batch)
defer.returnValue((signed_events, failed_to_fetch))
@defer.inlineCallbacks
@log_function
def get_event_auth(self, destination, room_id, event_id):
@@ -555,19 +414,14 @@ class FederationClient(FederationBase):
(destination, self.event_from_pdu_json(pdu_dict))
)
break
except CodeMessageException as e:
if not 500 <= e.code < 600:
except CodeMessageException:
raise
else:
logger.warn(
"Failed to make_%s via %s: %s",
membership, destination, e.message
)
except Exception as e:
logger.warn(
"Failed to make_%s via %s: %s",
membership, destination, e.message
)
raise
raise RuntimeError("Failed to send to any server.")
@@ -639,14 +493,8 @@ class FederationClient(FederationBase):
"auth_chain": signed_auth,
"origin": destination,
})
except CodeMessageException as e:
if not 500 <= e.code < 600:
except CodeMessageException:
raise
else:
logger.exception(
"Failed to send_join via %s: %s",
destination, e.message
)
except Exception as e:
logger.exception(
"Failed to send_join via %s: %s",
@@ -813,8 +661,7 @@ class FederationClient(FederationBase):
if len(signed_events) >= limit:
defer.returnValue(signed_events)
users = yield self.state.get_current_user_in_room(room_id)
servers = set(get_domain_from_id(u) for u in users)
servers = yield self.store.get_joined_hosts_for_room(room_id)
servers = set(servers)
servers.discard(self.server_name)
@@ -859,16 +706,14 @@ class FederationClient(FederationBase):
return srvs
deferreds = [
preserve_fn(self.get_pdu)(
self.get_pdu(
destinations=random_server_list(),
event_id=e_id,
)
for e_id, depth in ordered_missing[:limit - len(signed_events)]
]
res = yield preserve_context_over_deferred(
defer.DeferredList(deferreds, consumeErrors=True)
)
res = yield defer.DeferredList(deferreds, consumeErrors=True)
for (result, val), (e_id, _) in zip(res, ordered_missing):
if result and val:
signed_events.append(val)

View File

@@ -21,11 +21,10 @@ from .units import Transaction, Edu
from synapse.util.async import Linearizer
from synapse.util.logutils import log_function
from synapse.util.caches.response_cache import ResponseCache
from synapse.events import FrozenEvent
import synapse.metrics
from synapse.api.errors import AuthError, FederationError, SynapseError
from synapse.api.errors import FederationError, SynapseError
from synapse.crypto.event_signing import compute_event_signature
@@ -49,15 +48,9 @@ class FederationServer(FederationBase):
def __init__(self, hs):
super(FederationServer, self).__init__(hs)
self.auth = hs.get_auth()
self._room_pdu_linearizer = Linearizer()
self._server_linearizer = Linearizer()
# We cache responses to state queries, as they take a while and often
# come in waves.
self._state_resp_cache = ResponseCache(hs, timeout_ms=30000)
def set_handler(self, handler):
"""Sets the handler that the replication layer will use to communicate
receipt of new PDUs from other home servers. The required methods are
@@ -195,48 +188,10 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
@log_function
def on_context_state_request(self, origin, room_id, event_id):
if not event_id:
raise NotImplementedError("Specify an event")
in_room = yield self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
result = self._state_resp_cache.get((room_id, event_id))
if not result:
with (yield self._server_linearizer.queue((origin, room_id))):
resp = yield self._state_resp_cache.set(
(room_id, event_id),
self._on_context_state_request_compute(room_id, event_id)
)
else:
resp = yield result
defer.returnValue((200, resp))
@defer.inlineCallbacks
def on_state_ids_request(self, origin, room_id, event_id):
if not event_id:
raise NotImplementedError("Specify an event")
in_room = yield self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
state_ids = yield self.handler.get_state_ids_for_pdu(
room_id, event_id,
)
auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids)
defer.returnValue((200, {
"pdu_ids": state_ids,
"auth_chain_ids": auth_chain_ids,
}))
@defer.inlineCallbacks
def _on_context_state_request_compute(self, room_id, event_id):
if event_id:
pdus = yield self.handler.get_state_for_pdu(
room_id, event_id,
origin, room_id, event_id,
)
auth_chain = yield self.store.get_auth_chain(
[pdu.event_id for pdu in pdus]
@@ -253,11 +208,13 @@ class FederationServer(FederationBase):
self.hs.config.signing_key[0]
)
)
else:
raise NotImplementedError("Specify an event")
defer.returnValue({
defer.returnValue((200, {
"pdus": [pdu.get_pdu_json() for pdu in pdus],
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
})
}))
@defer.inlineCallbacks
@log_function
@@ -391,9 +348,27 @@ class FederationServer(FederationBase):
(200, send_content)
)
@defer.inlineCallbacks
@log_function
def on_query_client_keys(self, origin, content):
return self.on_query_request("client_keys", content)
query = []
for user_id, device_ids in content.get("device_keys", {}).items():
if not device_ids:
query.append((user_id, None))
else:
for device_id in device_ids:
query.append((user_id, device_id))
results = yield self.store.get_e2e_device_keys(query)
json_result = {}
for user_id, device_keys in results.items():
for device_id, json_bytes in device_keys.items():
json_result.setdefault(user_id, {})[device_id] = json.loads(
json_bytes
)
defer.returnValue({"device_keys": json_result})
@defer.inlineCallbacks
@log_function
@@ -603,7 +578,7 @@ class FederationServer(FederationBase):
origin, pdu.room_id, pdu.event_id,
)
except:
logger.exception("Failed to get state for event: %s", pdu.event_id)
logger.warn("Failed to get state for event: %s", pdu.event_id)
yield self.handler.on_receive_pdu(
origin,

View File

@@ -21,11 +21,11 @@ from .units import Transaction
from synapse.api.errors import HttpResponseException
from synapse.util.async import run_on_reactor
from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.retryutils import (
get_retry_limiter, NotRetryingDestination,
)
from synapse.util.metrics import measure_func
import synapse.metrics
import logging
@@ -51,7 +51,7 @@ class TransactionQueue(object):
self.transport_layer = transport_layer
self.clock = hs.get_clock()
self._clock = hs.get_clock()
# Is a mapping from destinations -> deferreds. Used to keep track
# of which destinations have transactions in flight and when they are
@@ -82,7 +82,7 @@ class TransactionQueue(object):
self.pending_failures_by_dest = {}
# HACK to get unique tx id
self._next_txn_id = int(self.clock.time_msec())
self._next_txn_id = int(self._clock.time_msec())
def can_send_to(self, destination):
"""Can we send messages to the given server?
@@ -119,46 +119,89 @@ class TransactionQueue(object):
if not destinations:
return
deferreds = []
for destination in destinations:
deferred = defer.Deferred()
self.pending_pdus_by_dest.setdefault(destination, []).append(
(pdu, order)
(pdu, deferred, order)
)
preserve_context_over_fn(
self._attempt_new_transaction, destination
)
def chain(failure):
if not deferred.called:
deferred.errback(failure)
def log_failure(f):
logger.warn("Failed to send pdu to %s: %s", destination, f.value)
deferred.addErrback(log_failure)
with PreserveLoggingContext():
self._attempt_new_transaction(destination).addErrback(chain)
deferreds.append(deferred)
# NO inlineCallbacks
def enqueue_edu(self, edu):
destination = edu.destination
if not self.can_send_to(destination):
return
self.pending_edus_by_dest.setdefault(destination, []).append(edu)
preserve_context_over_fn(
self._attempt_new_transaction, destination
deferred = defer.Deferred()
self.pending_edus_by_dest.setdefault(destination, []).append(
(edu, deferred)
)
def chain(failure):
if not deferred.called:
deferred.errback(failure)
def log_failure(f):
logger.warn("Failed to send edu to %s: %s", destination, f.value)
deferred.addErrback(log_failure)
with PreserveLoggingContext():
self._attempt_new_transaction(destination).addErrback(chain)
return deferred
@defer.inlineCallbacks
def enqueue_failure(self, failure, destination):
if destination == self.server_name or destination == "localhost":
return
deferred = defer.Deferred()
if not self.can_send_to(destination):
return
self.pending_failures_by_dest.setdefault(
destination, []
).append(failure)
preserve_context_over_fn(
self._attempt_new_transaction, destination
).append(
(failure, deferred)
)
def chain(f):
if not deferred.called:
deferred.errback(f)
def log_failure(f):
logger.warn("Failed to send failure to %s: %s", destination, f.value)
deferred.addErrback(log_failure)
with PreserveLoggingContext():
self._attempt_new_transaction(destination).addErrback(chain)
yield deferred
@defer.inlineCallbacks
@log_function
def _attempt_new_transaction(self, destination):
yield run_on_reactor()
while True:
# list of (pending_pdu, deferred, order)
if destination in self.pending_transactions:
# XXX: pending_transactions can get stuck on by a never-ending
@@ -183,31 +226,27 @@ class TransactionQueue(object):
logger.debug("TX [%s] Nothing to send", destination)
return
yield self._send_new_transaction(
destination, pending_pdus, pending_edus, pending_failures
)
@measure_func("_send_new_transaction")
@defer.inlineCallbacks
def _send_new_transaction(self, destination, pending_pdus, pending_edus,
pending_failures):
# Sort based on the order field
pending_pdus.sort(key=lambda t: t[1])
pdus = [x[0] for x in pending_pdus]
edus = pending_edus
failures = [x.get_dict() for x in pending_failures]
try:
self.pending_transactions[destination] = 1
logger.debug("TX [%s] _attempt_new_transaction", destination)
# Sort based on the order field
pending_pdus.sort(key=lambda t: t[2])
pdus = [x[0] for x in pending_pdus]
edus = [x[0] for x in pending_edus]
failures = [x[0].get_dict() for x in pending_failures]
deferreds = [
x[1]
for x in pending_pdus + pending_edus + pending_failures
]
txn_id = str(self._next_txn_id)
limiter = yield get_retry_limiter(
destination,
self.clock,
self._clock,
self.store,
)
@@ -223,7 +262,7 @@ class TransactionQueue(object):
logger.debug("TX [%s] Persisting transaction...", destination)
transaction = Transaction.create_new(
origin_server_ts=int(self.clock.time_msec()),
origin_server_ts=int(self._clock.time_msec()),
transaction_id=txn_id,
origin=self.server_name,
destination=destination,
@@ -254,7 +293,7 @@ class TransactionQueue(object):
# keys work
def json_data_cb():
data = transaction.get_dict()
now = int(self.clock.time_msec())
now = int(self._clock.time_msec())
if "pdus" in data:
for p in data["pdus"]:
if "age_ts" in p:
@@ -294,11 +333,22 @@ class TransactionQueue(object):
logger.debug("TX [%s] Marked as delivered", destination)
if code != 200:
for p in pdus:
logger.info(
"Failed to send event %s to %s", p.event_id, destination
)
logger.debug("TX [%s] Yielding to callbacks...", destination)
for deferred in deferreds:
if code == 200:
deferred.callback(None)
else:
deferred.errback(RuntimeError("Got status %d" % code))
# Ensures we don't continue until all callbacks on that
# deferred have fired
try:
yield deferred
except:
pass
logger.debug("TX [%s] Yielded to callbacks", destination)
except NotRetryingDestination:
logger.info(
"TX [%s] not ready for retry yet - "
@@ -313,9 +363,6 @@ class TransactionQueue(object):
destination,
e,
)
for p in pdus:
logger.info("Failed to send event %s to %s", p.event_id, destination)
except Exception as e:
# We capture this here as there as nothing actually listens
# for this finishing functions deferred.
@@ -325,9 +372,13 @@ class TransactionQueue(object):
e,
)
for p in pdus:
logger.info("Failed to send event %s to %s", p.event_id, destination)
for deferred in deferreds:
if not deferred.called:
deferred.errback(e)
finally:
# We want to be *very* sure we delete this after we stop processing
self.pending_transactions.pop(destination, None)
# Check to see if there is anything else to send.
self._attempt_new_transaction(destination)

View File

@@ -54,28 +54,6 @@ class TransportLayerClient(object):
destination, path=path, args={"event_id": event_id},
)
@log_function
def get_room_state_ids(self, destination, room_id, event_id):
""" Requests all state for a given room from the given server at the
given event. Returns the state's event_id's
Args:
destination (str): The host name of the remote home server we want
to get the state from.
context (str): The name of the context we want the state of
event_id (str): The event we want the context at.
Returns:
Deferred: Results in a dict received from the remote homeserver.
"""
logger.debug("get_room_state_ids dest=%s, room=%s",
destination, room_id)
path = PREFIX + "/state_ids/%s/" % room_id
return self.client.get_json(
destination, path=path, args={"event_id": event_id},
)
@log_function
def get_event(self, destination, event_id, timeout=None):
""" Requests the pdu with give id and origin from the given server.

View File

@@ -18,14 +18,13 @@ from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import JsonResource
from synapse.http.servlet import parse_json_object_from_request
from synapse.http.servlet import parse_json_object_from_request, parse_string
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.versionstring import get_version_string
import functools
import logging
import simplejson as json
import re
import synapse
logger = logging.getLogger(__name__)
@@ -61,16 +60,6 @@ class TransportLayerServer(JsonResource):
)
class AuthenticationError(SynapseError):
"""There was a problem authenticating the request"""
pass
class NoAuthenticationError(AuthenticationError):
"""The request had no authentication information"""
pass
class Authenticator(object):
def __init__(self, hs):
self.keyring = hs.get_keyring()
@@ -78,7 +67,7 @@ class Authenticator(object):
# A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks
def authenticate_request(self, request, content):
def authenticate_request(self, request):
json_request = {
"method": request.method,
"uri": request.uri,
@@ -86,11 +75,18 @@ class Authenticator(object):
"signatures": {},
}
if content is not None:
json_request["content"] = content
content = None
origin = None
if request.method in ["PUT", "POST"]:
# TODO: Handle other method types? other content types?
try:
content_bytes = request.content.read()
content = json.loads(content_bytes)
json_request["content"] = content
except:
raise SynapseError(400, "Unable to parse JSON", Codes.BAD_JSON)
def parse_auth_header(header_str):
try:
params = auth.split(" ")[1].split(",")
@@ -107,14 +103,14 @@ class Authenticator(object):
sig = strip_quotes(param_dict["sig"])
return (origin, key, sig)
except:
raise AuthenticationError(
raise SynapseError(
400, "Malformed Authorization header", Codes.UNAUTHORIZED
)
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
if not auth_headers:
raise NoAuthenticationError(
raise SynapseError(
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
)
@@ -125,7 +121,7 @@ class Authenticator(object):
json_request["signatures"].setdefault(origin, {})[key] = sig
if not json_request["signatures"]:
raise NoAuthenticationError(
raise SynapseError(
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
)
@@ -134,12 +130,10 @@ class Authenticator(object):
logger.info("Request from %s", origin)
request.authenticated_entity = origin
defer.returnValue(origin)
defer.returnValue((origin, content))
class BaseFederationServlet(object):
REQUIRE_AUTH = True
def __init__(self, handler, authenticator, ratelimiter, server_name,
room_list_handler):
self.handler = handler
@@ -147,46 +141,29 @@ class BaseFederationServlet(object):
self.ratelimiter = ratelimiter
self.room_list_handler = room_list_handler
def _wrap(self, func):
def _wrap(self, code):
authenticator = self.authenticator
ratelimiter = self.ratelimiter
@defer.inlineCallbacks
@functools.wraps(func)
def new_func(request, *args, **kwargs):
content = None
if request.method in ["PUT", "POST"]:
# TODO: Handle other method types? other content types?
content = parse_json_object_from_request(request)
@functools.wraps(code)
def new_code(request, *args, **kwargs):
try:
origin = yield authenticator.authenticate_request(request, content)
except NoAuthenticationError:
origin = None
if self.REQUIRE_AUTH:
logger.exception("authenticate_request failed")
raise
(origin, content) = yield authenticator.authenticate_request(request)
with ratelimiter.ratelimit(origin) as d:
yield d
response = yield code(
origin, content, request.args, *args, **kwargs
)
except:
logger.exception("authenticate_request failed")
raise
if origin:
with ratelimiter.ratelimit(origin) as d:
yield d
response = yield func(
origin, content, request.args, *args, **kwargs
)
else:
response = yield func(
origin, content, request.args, *args, **kwargs
)
defer.returnValue(response)
# Extra logic that functools.wraps() doesn't finish
new_func.__self__ = func.__self__
new_code.__self__ = code.__self__
return new_func
return new_code
def register(self, server):
pattern = re.compile("^" + PREFIX + self.PATH + "$")
@@ -294,17 +271,6 @@ class FederationStateServlet(BaseFederationServlet):
)
class FederationStateIdsServlet(BaseFederationServlet):
PATH = "/state_ids/(?P<room_id>[^/]*)/"
def on_GET(self, origin, content, query, room_id):
return self.handler.on_state_ids_request(
origin,
room_id,
query.get("event_id", [None])[0],
)
class FederationBackfillServlet(BaseFederationServlet):
PATH = "/backfill/(?P<context>[^/]*)/"
@@ -401,8 +367,10 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
class FederationClientKeysQueryServlet(BaseFederationServlet):
PATH = "/user/keys/query"
@defer.inlineCallbacks
def on_POST(self, origin, content, query):
return self.handler.on_query_client_keys(origin, content)
response = yield self.handler.on_query_client_keys(origin, content)
defer.returnValue((200, response))
class FederationClientKeysClaimServlet(BaseFederationServlet):
@@ -452,10 +420,9 @@ class FederationGetMissingEventsServlet(BaseFederationServlet):
class On3pidBindServlet(BaseFederationServlet):
PATH = "/3pid/onbind"
REQUIRE_AUTH = False
@defer.inlineCallbacks
def on_POST(self, origin, content, query):
def on_POST(self, request):
content = parse_json_object_from_request(request)
if "invites" in content:
last_exception = None
for invite in content["invites"]:
@@ -477,6 +444,11 @@ class On3pidBindServlet(BaseFederationServlet):
raise last_exception
defer.returnValue((200, {}))
# Avoid doing remote HS authorization checks which are done by default by
# BaseFederationServlet.
def _wrap(self, code):
return code
class OpenIdUserInfo(BaseFederationServlet):
"""
@@ -497,11 +469,9 @@ class OpenIdUserInfo(BaseFederationServlet):
PATH = "/openid/userinfo"
REQUIRE_AUTH = False
@defer.inlineCallbacks
def on_GET(self, origin, content, query):
token = query.get("access_token", [None])[0]
def on_GET(self, request):
token = parse_string(request, "access_token")
if token is None:
defer.returnValue((401, {
"errcode": "M_MISSING_TOKEN", "error": "Access Token required"
@@ -518,6 +488,11 @@ class OpenIdUserInfo(BaseFederationServlet):
defer.returnValue((200, {"sub": user_id}))
# Avoid doing remote HS authorization checks which are done by default by
# BaseFederationServlet.
def _wrap(self, code):
return code
class PublicRoomList(BaseFederationServlet):
"""
@@ -558,26 +533,11 @@ class PublicRoomList(BaseFederationServlet):
defer.returnValue((200, data))
class FederationVersionServlet(BaseFederationServlet):
PATH = "/version"
REQUIRE_AUTH = False
def on_GET(self, origin, content, query):
return defer.succeed((200, {
"server": {
"name": "Synapse",
"version": get_version_string(synapse)
},
}))
SERVLET_CLASSES = (
FederationSendServlet,
FederationPullServlet,
FederationEventServlet,
FederationStateServlet,
FederationStateIdsServlet,
FederationBackfillServlet,
FederationQueryServlet,
FederationMakeJoinServlet,
@@ -595,7 +555,6 @@ SERVLET_CLASSES = (
On3pidBindServlet,
OpenIdUserInfo,
PublicRoomList,
FederationVersionServlet,
)

View File

@@ -19,6 +19,7 @@ from .room import (
)
from .room_member import RoomMemberHandler
from .message import MessageHandler
from .events import EventStreamHandler, EventHandler
from .federation import FederationHandler
from .profile import ProfileHandler
from .directory import DirectoryHandler
@@ -30,21 +31,10 @@ from .search import SearchHandler
class Handlers(object):
""" Deprecated. A collection of handlers.
""" A collection of all the event handlers.
At some point most of the classes whose name ended "Handler" were
accessed through this class.
However this makes it painful to unit test the handlers and to run cut
down versions of synapse that only use specific handlers because using a
single handler required creating all of the handlers. So some of the
handlers have been lifted out of the Handlers object and are now accessed
directly through the homeserver object itself.
Any new handlers should follow the new pattern of being accessed through
the homeserver object and should not be added to the Handlers object.
The remaining handlers should be moved out of the handlers object.
There's no need to lazily create these; we'll just make them all eagerly
at construction time.
"""
def __init__(self, hs):
@@ -52,6 +42,8 @@ class Handlers(object):
self.message_handler = MessageHandler(hs)
self.room_creation_handler = RoomCreationHandler(hs)
self.room_member_handler = RoomMemberHandler(hs)
self.event_stream_handler = EventStreamHandler(hs)
self.event_handler = EventHandler(hs)
self.federation_handler = FederationHandler(hs)
self.profile_handler = ProfileHandler(hs)
self.directory_handler = DirectoryHandler(hs)

View File

@@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
import synapse.types
from synapse.api.constants import Membership, EventTypes
from synapse.api.errors import LimitExceededError
from synapse.types import UserID
from synapse.api.constants import Membership, EventTypes
from synapse.types import UserID, Requester
import logging
logger = logging.getLogger(__name__)
@@ -31,15 +31,11 @@ class BaseHandler(object):
Common base class for the event handlers.
Attributes:
store (synapse.storage.DataStore):
store (synapse.storage.events.StateStore):
state_handler (synapse.state.StateHandler):
"""
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer):
"""
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
@@ -65,21 +61,33 @@ class BaseHandler(object):
retry_after_ms=int(1000 * (time_allowed - time_now)),
)
def is_host_in_room(self, current_state):
room_members = [
(state_key, event.membership)
for ((event_type, state_key), event) in current_state.items()
if event_type == EventTypes.Member
]
if len(room_members) == 0:
# Have we just created the room, and is this about to be the very
# first member event?
create_event = current_state.get(("m.room.create", ""))
if create_event:
return True
for (state_key, membership) in room_members:
if (
self.hs.is_mine_id(state_key)
and membership == Membership.JOIN
):
return True
return False
@defer.inlineCallbacks
def maybe_kick_guest_users(self, event, context=None):
def maybe_kick_guest_users(self, event, current_state):
# Technically this function invalidates current_state by changing it.
# Hopefully this isn't that important to the caller.
if event.type == EventTypes.GuestAccess:
guest_access = event.content.get("guest_access", "forbidden")
if guest_access != "can_join":
if context:
current_state = yield self.store.get_events(
context.current_state_ids.values()
)
current_state = current_state.values()
else:
current_state = yield self.store.get_current_state(event.room_id)
logger.info("maybe_kick_guest_users %r", current_state)
yield self.kick_guest_users(current_state)
@defer.inlineCallbacks
@@ -112,8 +120,7 @@ class BaseHandler(object):
# and having homeservers have their own users leave keeps more
# of that decision-making and control local to the guest-having
# homeserver.
requester = synapse.types.create_requester(
target_user, is_guest=True)
requester = Requester(target_user, "", True)
handler = self.hs.get_handlers().room_member_handler
yield handler.update_membership(
requester,

View File

@@ -16,8 +16,7 @@
from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.util.metrics import Measure
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.appservice import ApplicationService
import logging
@@ -43,53 +42,25 @@ class ApplicationServicesHandler(object):
self.appservice_api = hs.get_application_service_api()
self.scheduler = hs.get_application_service_scheduler()
self.started_scheduler = False
self.clock = hs.get_clock()
self.notify_appservices = hs.config.notify_appservices
self.current_max = 0
self.is_processing = False
@defer.inlineCallbacks
def notify_interested_services(self, current_id):
def notify_interested_services(self, event):
"""Notifies (pushes) all application services interested in this event.
Pushing is done asynchronously, so this method won't block for any
prolonged length of time.
Args:
current_id(int): The current maximum ID.
event(Event): The event to push out to interested services.
"""
services = yield self.store.get_app_services()
if not services or not self.notify_appservices:
return
self.current_max = max(self.current_max, current_id)
if self.is_processing:
return
with Measure(self.clock, "notify_interested_services"):
self.is_processing = True
try:
upper_bound = self.current_max
limit = 100
while True:
upper_bound, events = yield self.store.get_new_events_for_appservice(
upper_bound, limit
)
if not events:
break
for event in events:
# Gather interested services
services = yield self._get_services_for_event(event)
if len(services) == 0:
continue # no services need notifying
return # no services need notifying
# Do we know this user exists? If not, poke the user
# query API for all services which match that user regex.
# This needs to block as these user queries need to be
# made BEFORE pushing the event.
# Do we know this user exists? If not, poke the user query API for
# all services which match that user regex. This needs to block as these
# user queries need to be made BEFORE pushing the event.
yield self._check_user_exists(event.sender)
if event.type == EventTypes.Member:
yield self._check_user_exists(event.state_key)
@@ -100,16 +71,7 @@ class ApplicationServicesHandler(object):
# Fork off pushes to these services
for service in services:
preserve_fn(self.scheduler.submit_event_for_as)(
service, event
)
yield self.store.set_appservice_last_pos(upper_bound)
if len(events) < limit:
break
finally:
self.is_processing = False
self.scheduler.submit_event_for_as(service, event)
@defer.inlineCallbacks
def query_user_exists(self, user_id):
@@ -142,12 +104,11 @@ class ApplicationServicesHandler(object):
association can be found.
"""
room_alias_str = room_alias.to_string()
services = yield self.store.get_app_services()
alias_query_services = [
s for s in services if (
s.is_interested_in_alias(room_alias_str)
alias_query_services = yield self._get_services_for_event(
event=None,
restrict_to=ApplicationService.NS_ALIASES,
alias_list=[room_alias_str]
)
]
for alias_service in alias_query_services:
is_known_alias = yield self.appservice_api.query_alias(
alias_service, room_alias_str
@@ -160,45 +121,34 @@ class ApplicationServicesHandler(object):
defer.returnValue(result)
@defer.inlineCallbacks
def query_3pe(self, kind, protocol, fields):
services = yield self._get_services_for_3pn(protocol)
results = yield preserve_context_over_deferred(defer.DeferredList([
preserve_fn(self.appservice_api.query_3pe)(service, kind, protocol, fields)
for service in services
], consumeErrors=True))
ret = []
for (success, result) in results:
if success:
ret.extend(result)
defer.returnValue(ret)
@defer.inlineCallbacks
def get_3pe_protocols(self):
services = yield self.store.get_app_services()
protocols = {}
for s in services:
for p in s.protocols:
protocols[p] = yield self.appservice_api.get_3pe_protocol(s, p)
defer.returnValue(protocols)
@defer.inlineCallbacks
def _get_services_for_event(self, event):
def _get_services_for_event(self, event, restrict_to="", alias_list=None):
"""Retrieve a list of application services interested in this event.
Args:
event(Event): The event to check. Can be None if alias_list is not.
restrict_to(str): The namespace to restrict regex tests to.
alias_list: A list of aliases to get services for. If None, this
list is obtained from the database.
Returns:
list<ApplicationService>: A list of services interested in this
event based on the service regex.
"""
member_list = None
if hasattr(event, "room_id"):
# We need to know the aliases associated with this event.room_id,
# if any.
if not alias_list:
alias_list = yield self.store.get_aliases_for_room(
event.room_id
)
# We need to know the members associated with this event.room_id,
# if any.
member_list = yield self.store.get_users_in_room(event.room_id)
services = yield self.store.get_app_services()
interested_list = [
s for s in services if (
yield s.is_interested(event, self.store)
s.is_interested(event, restrict_to, alias_list, member_list)
)
]
defer.returnValue(interested_list)
@@ -213,14 +163,6 @@ class ApplicationServicesHandler(object):
]
defer.returnValue(interested_list)
@defer.inlineCallbacks
def _get_services_for_3pn(self, protocol):
services = yield self.store.get_app_services()
interested_list = [
s for s in services if s.is_interested_in_protocol(protocol)
]
defer.returnValue(interested_list)
@defer.inlineCallbacks
def _is_unknown_user(self, user_id):
if not self.is_mine_id(user_id):

View File

@@ -45,10 +45,6 @@ class AuthHandler(BaseHandler):
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer):
"""
super(AuthHandler, self).__init__(hs)
self.checkers = {
LoginType.PASSWORD: self._check_password_auth,
@@ -70,14 +66,13 @@ class AuthHandler(BaseHandler):
self.ldap_uri = hs.config.ldap_uri
self.ldap_start_tls = hs.config.ldap_start_tls
self.ldap_base = hs.config.ldap_base
self.ldap_filter = hs.config.ldap_filter
self.ldap_attributes = hs.config.ldap_attributes
if self.ldap_mode == LDAPMode.SEARCH:
self.ldap_bind_dn = hs.config.ldap_bind_dn
self.ldap_bind_password = hs.config.ldap_bind_password
self.ldap_filter = hs.config.ldap_filter
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.device_handler = hs.get_device_handler()
@defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip):
@@ -235,6 +230,7 @@ class AuthHandler(BaseHandler):
sess = self._get_session_info(session_id)
return sess.setdefault('serverdict', {}).get(key, default)
@defer.inlineCallbacks
def _check_password_auth(self, authdict, _):
if "user" not in authdict or "password" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM)
@@ -244,7 +240,11 @@ class AuthHandler(BaseHandler):
if not user_id.startswith('@'):
user_id = UserID.create(user_id, self.hs.hostname).to_string()
return self._check_password(user_id, password)
if not (yield self._check_password(user_id, password)):
logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
defer.returnValue(user_id)
@defer.inlineCallbacks
def _check_recaptcha(self, authdict, clientip):
@@ -280,16 +280,7 @@ class AuthHandler(BaseHandler):
data = pde.response
resp_body = simplejson.loads(data)
if 'success' in resp_body:
# Note that we do NOT check the hostname here: we explicitly
# intend the CAPTCHA to be presented by whatever client the
# user is using, we just care that they have completed a CAPTCHA.
logger.info(
"%s reCAPTCHA from hostname %s",
"Successful" if resp_body['success'] else "Failed",
resp_body.get('hostname')
)
if resp_body['success']:
if 'success' in resp_body and resp_body['success']:
defer.returnValue(True)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
@@ -357,84 +348,67 @@ class AuthHandler(BaseHandler):
return self.sessions[session_id]
def validate_password_login(self, user_id, password):
@defer.inlineCallbacks
def login_with_password(self, user_id, password):
"""
Authenticates the user with their username and password.
Used only by the v1 login API.
Args:
user_id (str): complete @user:id
user_id (str): User ID
password (str): Password
Returns:
defer.Deferred: (str) canonical user id
Raises:
StoreError if there was a problem accessing the database
LoginError if there was an authentication problem.
"""
return self._check_password(user_id, password)
@defer.inlineCallbacks
def get_login_tuple_for_user_id(self, user_id, device_id=None,
initial_display_name=None):
"""
Gets login tuple for the user with the given user ID.
Creates a new access/refresh token for the user.
The user is assumed to have been authenticated by some other
machanism (e.g. CAS), and the user_id converted to the canonical case.
The device will be recorded in the table if it is not there already.
Args:
user_id (str): canonical User ID
device_id (str|None): the device ID to associate with the tokens.
None to leave the tokens unassociated with a device (deprecated:
we should always have a device ID)
initial_display_name (str): display name to associate with the
device if it needs re-registering
Returns:
A tuple of:
The user's ID.
The access token for the user's session.
The refresh token for the user's session.
Raises:
StoreError if there was a problem storing the token.
LoginError if there was an authentication problem.
"""
logger.info("Logging in user %s on device %s", user_id, device_id)
access_token = yield self.issue_access_token(user_id, device_id)
refresh_token = yield self.issue_refresh_token(user_id, device_id)
# the device *should* have been registered before we got here; however,
# it's possible we raced against a DELETE operation. The thing we
# really don't want is active access_tokens without a record of the
# device, so we double-check it here.
if device_id is not None:
yield self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
if not (yield self._check_password(user_id, password)):
logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
defer.returnValue((access_token, refresh_token))
logger.info("Logging in user %s", user_id)
access_token = yield self.issue_access_token(user_id)
refresh_token = yield self.issue_refresh_token(user_id)
defer.returnValue((user_id, access_token, refresh_token))
@defer.inlineCallbacks
def check_user_exists(self, user_id):
def get_login_tuple_for_user_id(self, user_id):
"""
Checks to see if a user with the given id exists. Will check case
insensitively, but return None if there are multiple inexact matches.
Gets login tuple for the user with the given user ID.
The user is assumed to have been authenticated by some other
machanism (e.g. CAS)
Args:
(str) user_id: complete @user:id
user_id (str): User ID
Returns:
defer.Deferred: (str) canonical_user_id, or None if zero or
multiple matches
A tuple of:
The user's ID.
The access token for the user's session.
The refresh token for the user's session.
Raises:
StoreError if there was a problem storing the token.
LoginError if there was an authentication problem.
"""
user_id, ignored = yield self._find_user_id_and_pwd_hash(user_id)
logger.info("Logging in user %s", user_id)
access_token = yield self.issue_access_token(user_id)
refresh_token = yield self.issue_refresh_token(user_id)
defer.returnValue((user_id, access_token, refresh_token))
@defer.inlineCallbacks
def does_user_exist(self, user_id):
try:
res = yield self._find_user_id_and_pwd_hash(user_id)
defer.returnValue(res[0])
yield self._find_user_id_and_pwd_hash(user_id)
defer.returnValue(True)
except LoginError:
defer.returnValue(None)
defer.returnValue(False)
@defer.inlineCallbacks
def _find_user_id_and_pwd_hash(self, user_id):
@@ -464,45 +438,27 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks
def _check_password(self, user_id, password):
"""Authenticate a user against the LDAP and local databases.
user_id is checked case insensitively against the local database, but
will throw if there are multiple inexact matches.
Args:
user_id (str): complete @user:id
"""
Returns:
(str) the canonical_user_id
Raises:
LoginError if the password was incorrect
True if the user_id successfully authenticated
"""
valid_ldap = yield self._check_ldap_password(user_id, password)
if valid_ldap:
defer.returnValue(user_id)
defer.returnValue(True)
result = yield self._check_local_password(user_id, password)
defer.returnValue(result)
valid_local_password = yield self._check_local_password(user_id, password)
if valid_local_password:
defer.returnValue(True)
defer.returnValue(False)
@defer.inlineCallbacks
def _check_local_password(self, user_id, password):
"""Authenticate a user against the local password database.
user_id is checked case insensitively, but will throw if there are
multiple inexact matches.
Args:
user_id (str): complete @user:id
Returns:
(str) the canonical_user_id
Raises:
LoginError if the password was incorrect
"""
try:
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
result = self.validate_hash(password, password_hash)
if not result:
logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
defer.returnValue(user_id)
defer.returnValue(self.validate_hash(password, password_hash))
except LoginError:
defer.returnValue(False)
@defer.inlineCallbacks
def _check_ldap_password(self, user_id, password):
@@ -614,7 +570,7 @@ class AuthHandler(BaseHandler):
)
# check for existing account, if none exists, create one
if not (yield self.check_user_exists(user_id)):
if not (yield self.does_user_exist(user_id)):
# query user metadata for account creation
query = "({prop}={value})".format(
prop=self.ldap_attributes['uid'],
@@ -660,7 +616,7 @@ class AuthHandler(BaseHandler):
else:
logger.warn(
"ldap registration failed: unexpected (%d!=1) amount of results",
len(conn.response)
len(result)
)
defer.returnValue(False)
@@ -670,26 +626,23 @@ class AuthHandler(BaseHandler):
defer.returnValue(False)
@defer.inlineCallbacks
def issue_access_token(self, user_id, device_id=None):
def issue_access_token(self, user_id):
access_token = self.generate_access_token(user_id)
yield self.store.add_access_token_to_user(user_id, access_token,
device_id)
yield self.store.add_access_token_to_user(user_id, access_token)
defer.returnValue(access_token)
@defer.inlineCallbacks
def issue_refresh_token(self, user_id, device_id=None):
def issue_refresh_token(self, user_id):
refresh_token = self.generate_refresh_token(user_id)
yield self.store.add_refresh_token_to_user(user_id, refresh_token,
device_id)
yield self.store.add_refresh_token_to_user(user_id, refresh_token)
defer.returnValue(refresh_token)
def generate_access_token(self, user_id, extra_caveats=None,
duration_in_ms=(60 * 60 * 1000)):
def generate_access_token(self, user_id, extra_caveats=None):
extra_caveats = extra_caveats or []
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = access")
now = self.hs.get_clock().time_msec()
expiry = now + duration_in_ms
expiry = now + (60 * 60 * 1000)
macaroon.add_first_party_caveat("time < %d" % (expiry,))
for caveat in extra_caveats:
macaroon.add_first_party_caveat(caveat)
@@ -719,14 +672,13 @@ class AuthHandler(BaseHandler):
return macaroon.serialize()
def validate_short_term_login_token_and_get_user_id(self, login_token):
auth_api = self.hs.get_auth()
try:
macaroon = pymacaroons.Macaroon.deserialize(login_token)
user_id = auth_api.get_user_id_from_macaroon(macaroon)
auth_api.validate_macaroon(macaroon, "login", True, user_id)
return user_id
except Exception:
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
auth_api = self.hs.get_auth()
auth_api.validate_macaroon(macaroon, "login", True)
return self.get_user_from_macaroon(macaroon)
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN)
def _generate_base_macaroon(self, user_id):
macaroon = pymacaroons.Macaroon(
@@ -737,11 +689,21 @@ class AuthHandler(BaseHandler):
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
def get_user_from_macaroon(self, macaroon):
user_prefix = "user_id = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(user_prefix):
return caveat.caveat_id[len(user_prefix):]
raise AuthError(
self.INVALID_TOKEN_HTTP_STATUS, "No user_id found in token",
errcode=Codes.UNKNOWN_TOKEN
)
@defer.inlineCallbacks
def set_password(self, user_id, newpassword, requester=None):
password_hash = self.hash(newpassword)
except_access_token_id = requester.access_token_id if requester else None
except_access_token_ids = [requester.access_token_id] if requester else []
try:
yield self.store.user_set_password_hash(user_id, password_hash)
@@ -750,10 +712,10 @@ class AuthHandler(BaseHandler):
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
raise e
yield self.store.user_delete_access_tokens(
user_id, except_access_token_id
user_id, except_access_token_ids
)
yield self.hs.get_pusherpool().remove_pushers_by_user(
user_id, except_access_token_id
user_id, except_access_token_ids
)
@defer.inlineCallbacks

View File

@@ -1,181 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api import errors
from synapse.util import stringutils
from twisted.internet import defer
from ._base import BaseHandler
import logging
logger = logging.getLogger(__name__)
class DeviceHandler(BaseHandler):
def __init__(self, hs):
super(DeviceHandler, self).__init__(hs)
@defer.inlineCallbacks
def check_device_registered(self, user_id, device_id,
initial_device_display_name=None):
"""
If the given device has not been registered, register it with the
supplied display name.
If no device_id is supplied, we make one up.
Args:
user_id (str): @user:id
device_id (str | None): device id supplied by client
initial_device_display_name (str | None): device display name from
client
Returns:
str: device id (generated if none was supplied)
"""
if device_id is not None:
yield self.store.store_device(
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
ignore_if_known=True,
)
defer.returnValue(device_id)
# if the device id is not specified, we'll autogen one, but loop a few
# times in case of a clash.
attempts = 0
while attempts < 5:
try:
device_id = stringutils.random_string_with_symbols(16)
yield self.store.store_device(
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
ignore_if_known=False,
)
defer.returnValue(device_id)
except errors.StoreError:
attempts += 1
raise errors.StoreError(500, "Couldn't generate a device ID.")
@defer.inlineCallbacks
def get_devices_by_user(self, user_id):
"""
Retrieve the given user's devices
Args:
user_id (str):
Returns:
defer.Deferred: list[dict[str, X]]: info on each device
"""
device_map = yield self.store.get_devices_by_user(user_id)
ips = yield self.store.get_last_client_ip_by_device(
devices=((user_id, device_id) for device_id in device_map.keys())
)
devices = device_map.values()
for device in devices:
_update_device_from_client_ips(device, ips)
defer.returnValue(devices)
@defer.inlineCallbacks
def get_device(self, user_id, device_id):
""" Retrieve the given device
Args:
user_id (str):
device_id (str):
Returns:
defer.Deferred: dict[str, X]: info on the device
Raises:
errors.NotFoundError: if the device was not found
"""
try:
device = yield self.store.get_device(user_id, device_id)
except errors.StoreError:
raise errors.NotFoundError
ips = yield self.store.get_last_client_ip_by_device(
devices=((user_id, device_id),)
)
_update_device_from_client_ips(device, ips)
defer.returnValue(device)
@defer.inlineCallbacks
def delete_device(self, user_id, device_id):
""" Delete the given device
Args:
user_id (str):
device_id (str):
Returns:
defer.Deferred:
"""
try:
yield self.store.delete_device(user_id, device_id)
except errors.StoreError, e:
if e.code == 404:
# no match
pass
else:
raise
yield self.store.user_delete_access_tokens(
user_id, device_id=device_id,
delete_refresh_tokens=True,
)
yield self.store.delete_e2e_keys_by_device(
user_id=user_id, device_id=device_id
)
@defer.inlineCallbacks
def update_device(self, user_id, device_id, content):
""" Update the given device
Args:
user_id (str):
device_id (str):
content (dict): body of update request
Returns:
defer.Deferred:
"""
try:
yield self.store.update_device(
user_id,
device_id,
new_display_name=content.get("display_name")
)
except errors.StoreError, e:
if e.code == 404:
raise errors.NotFoundError()
else:
raise
def _update_device_from_client_ips(device, client_ips):
ip = client_ips.get((device["user_id"], device["device_id"]), {})
device.update({
"last_seen_ts": ip.get("last_seen"),
"last_seen_ip": ip.get("ip"),
})

View File

@@ -19,7 +19,7 @@ from ._base import BaseHandler
from synapse.api.errors import SynapseError, Codes, CodeMessageException, AuthError
from synapse.api.constants import EventTypes
from synapse.types import RoomAlias, UserID, get_domain_from_id
from synapse.types import RoomAlias, UserID
import logging
import string
@@ -55,8 +55,7 @@ class DirectoryHandler(BaseHandler):
# TODO(erikj): Add transactions.
# TODO(erikj): Check if there is a current association.
if not servers:
users = yield self.state.get_current_user_in_room(room_id)
servers = set(get_domain_from_id(u) for u in users)
servers = yield self.store.get_joined_hosts_for_room(room_id)
if not servers:
raise SynapseError(400, "Failed to get server list")
@@ -194,8 +193,7 @@ class DirectoryHandler(BaseHandler):
Codes.NOT_FOUND
)
users = yield self.state.get_current_user_in_room(room_id)
extra_servers = set(get_domain_from_id(u) for u in users)
extra_servers = yield self.store.get_joined_hosts_for_room(room_id)
servers = set(extra_servers) | set(servers)
# If this server is in the list of servers, return it first.

View File

@@ -1,139 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import json
import logging
from twisted.internet import defer
from synapse.api import errors
import synapse.types
logger = logging.getLogger(__name__)
class E2eKeysHandler(object):
def __init__(self, hs):
self.store = hs.get_datastore()
self.federation = hs.get_replication_layer()
self.is_mine_id = hs.is_mine_id
self.server_name = hs.hostname
# doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the
# "query handler" interface.
self.federation.register_query_handler(
"client_keys", self.on_federation_query_client_keys
)
@defer.inlineCallbacks
def query_devices(self, query_body):
""" Handle a device key query from a client
{
"device_keys": {
"<user_id>": ["<device_id>"]
}
}
->
{
"device_keys": {
"<user_id>": {
"<device_id>": {
...
}
}
}
}
"""
device_keys_query = query_body.get("device_keys", {})
# separate users by domain.
# make a map from domain to user_id to device_ids
queries_by_domain = collections.defaultdict(dict)
for user_id, device_ids in device_keys_query.items():
user = synapse.types.UserID.from_string(user_id)
queries_by_domain[user.domain][user_id] = device_ids
# do the queries
# TODO: do these in parallel
results = {}
for destination, destination_query in queries_by_domain.items():
if destination == self.server_name:
res = yield self.query_local_devices(destination_query)
else:
res = yield self.federation.query_client_keys(
destination, {"device_keys": destination_query}
)
res = res["device_keys"]
for user_id, keys in res.items():
if user_id in destination_query:
results[user_id] = keys
defer.returnValue((200, {"device_keys": results}))
@defer.inlineCallbacks
def query_local_devices(self, query):
"""Get E2E device keys for local users
Args:
query (dict[string, list[string]|None): map from user_id to a list
of devices to query (None for all devices)
Returns:
defer.Deferred: (resolves to dict[string, dict[string, dict]]):
map from user_id -> device_id -> device details
"""
local_query = []
result_dict = {}
for user_id, device_ids in query.items():
if not self.is_mine_id(user_id):
logger.warning("Request for keys for non-local user %s",
user_id)
raise errors.SynapseError(400, "Not a user here")
if not device_ids:
local_query.append((user_id, None))
else:
for device_id in device_ids:
local_query.append((user_id, device_id))
# make sure that each queried user appears in the result dict
result_dict[user_id] = {}
results = yield self.store.get_e2e_device_keys(local_query)
# Build the result structure, un-jsonify the results, and add the
# "unsigned" section
for user_id, device_keys in results.items():
for device_id, device_info in device_keys.items():
r = json.loads(device_info["key_json"])
r["unsigned"] = {}
display_name = device_info["device_display_name"]
if display_name is not None:
r["unsigned"]["device_display_name"] = display_name
result_dict[user_id][device_id] = r
defer.returnValue(result_dict)
@defer.inlineCallbacks
def on_federation_query_client_keys(self, query_body):
""" Handle a device key query from a federated server
"""
device_keys_query = query_body.get("device_keys", {})
res = yield self.query_local_devices(device_keys_query)
defer.returnValue({"device_keys": res})

View File

@@ -47,7 +47,6 @@ class EventStreamHandler(BaseHandler):
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
self.state = hs.get_state_handler()
@defer.inlineCallbacks
@log_function
@@ -91,7 +90,7 @@ class EventStreamHandler(BaseHandler):
# Send down presence.
if event.state_key == auth_user_id:
# Send down presence for everyone in the room.
users = yield self.state.get_current_user_in_room(event.room_id)
users = yield self.store.get_users_in_room(event.room_id)
states = yield presence_handler.get_states(
users,
as_event=True,

View File

@@ -26,10 +26,7 @@ from synapse.api.errors import (
from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.events.validator import EventValidator
from synapse.util import unwrapFirstError
from synapse.util.logcontext import (
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred
)
from synapse.util.metrics import measure_func
from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor
from synapse.util.frozenutils import unfreeze
@@ -101,9 +98,6 @@ class FederationHandler(BaseHandler):
def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None):
""" Called by the ReplicationLayer when we have a new pdu. We need to
do auth checks and put it through the StateHandler.
auth_chain and state are None if we already have the necessary state
and prev_events in the db
"""
event = pdu
@@ -121,25 +115,16 @@ class FederationHandler(BaseHandler):
# FIXME (erikj): Awful hack to make the case where we are not currently
# in the room work
# If state and auth_chain are None, then we don't need to do this check
# as we already know we have enough state in the DB to handle this
# event.
if state and auth_chain and not event.internal_metadata.is_outlier():
is_in_room = yield self.auth.check_host_in_room(
event.room_id,
self.server_name
)
else:
is_in_room = True
if not is_in_room:
logger.info(
"Got event for room we're not in: %r %r",
event.room_id, event.event_id
)
if not is_in_room and not event.internal_metadata.is_outlier():
logger.debug("Got event for room we're not in.")
try:
event_stream_id, max_stream_id = yield self._persist_auth_tree(
origin, auth_chain, state, event
auth_chain, state, event
)
except AuthError as e:
raise FederationError(
@@ -230,28 +215,17 @@ class FederationHandler(BaseHandler):
if event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
prev_state = context.current_state.get((event.type, event.state_key))
if not prev_state or prev_state.membership != Membership.JOIN:
# Only fire user_joined_room if the user has acutally
# joined the room. Don't bother if the user is just
# changing their profile info.
newly_joined = True
prev_state_id = context.prev_state_ids.get(
(event.type, event.state_key)
)
if prev_state_id:
prev_state = yield self.store.get_event(
prev_state_id, allow_none=True,
)
if prev_state and prev_state.membership == Membership.JOIN:
newly_joined = False
if newly_joined:
user = UserID.from_string(event.state_key)
yield user_joined_room(self.distributor, user, event.room_id)
@measure_func("_filter_events_for_server")
@defer.inlineCallbacks
def _filter_events_for_server(self, server_name, room_id, events):
event_to_state_ids = yield self.store.get_state_ids_for_events(
event_to_state = yield self.store.get_state_for_events(
frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
@@ -259,30 +233,6 @@ class FederationHandler(BaseHandler):
)
)
# We only want to pull out member events that correspond to the
# server's domain.
def check_match(id):
try:
return server_name == get_domain_from_id(id)
except:
return False
event_map = yield self.store.get_events([
e_id for key_to_eid in event_to_state_ids.values()
for key, e_id in key_to_eid
if key[0] != EventTypes.Member or check_match(key[1])
])
event_to_state = {
e_id: {
key: event_map[inner_e_id]
for key, inner_e_id in key_to_eid.items()
if inner_e_id in event_map
}
for e_id, key_to_eid in event_to_state_ids.items()
}
def redact_disallowed(event, state):
if not state:
return event
@@ -299,7 +249,7 @@ class FederationHandler(BaseHandler):
if ev.type != EventTypes.Member:
continue
try:
domain = get_domain_from_id(ev.state_key)
domain = UserID.from_string(ev.state_key).domain
except:
continue
@@ -324,7 +274,7 @@ class FederationHandler(BaseHandler):
@log_function
@defer.inlineCallbacks
def backfill(self, dest, room_id, limit, extremities):
def backfill(self, dest, room_id, limit, extremities=[]):
""" Trigger a backfill request to `dest` for the given `room_id`
This will attempt to get more events from the remote. This may return
@@ -334,6 +284,9 @@ class FederationHandler(BaseHandler):
if dest == self.server_name:
raise SynapseError(400, "Can't backfill from self.")
if not extremities:
extremities = yield self.store.get_oldest_events_in_room(room_id)
events = yield self.replication_layer.backfill(
dest,
room_id,
@@ -382,60 +335,31 @@ class FederationHandler(BaseHandler):
state_events.update({s.event_id: s for s in state})
events_to_state[e_id] = state
seen_events = yield self.store.have_events(
set(auth_events.keys()) | set(state_events.keys())
)
all_events = events + state_events.values() + auth_events.values()
required_auth = set(
a_id
for event in events + state_events.values() + auth_events.values()
for a_id, _ in event.auth_events
a_id for event in all_events for a_id, _ in event.auth_events
)
auth_events.update({
e_id: event_map[e_id] for e_id in required_auth if e_id in event_map
})
missing_auth = required_auth - set(auth_events)
failed_to_fetch = set()
# Try and fetch any missing auth events from both DB and remote servers.
# We repeatedly do this until we stop finding new auth events.
while missing_auth - failed_to_fetch:
missing_auth = required_auth - set(auth_events)
if missing_auth:
logger.info("Missing auth for backfill: %r", missing_auth)
ret_events = yield self.store.get_events(missing_auth - failed_to_fetch)
auth_events.update(ret_events)
required_auth.update(
a_id for event in ret_events.values() for a_id, _ in event.auth_events
)
missing_auth = required_auth - set(auth_events)
if missing_auth - failed_to_fetch:
logger.info(
"Fetching missing auth for backfill: %r",
missing_auth - failed_to_fetch
)
results = yield preserve_context_over_deferred(defer.gatherResults(
results = yield defer.gatherResults(
[
preserve_fn(self.replication_layer.get_pdu)(
self.replication_layer.get_pdu(
[dest],
event_id,
outlier=True,
timeout=10000,
)
for event_id in missing_auth - failed_to_fetch
for event_id in missing_auth
],
consumeErrors=True
)).addErrback(unwrapFirstError)
auth_events.update({a.event_id: a for a in results if a})
required_auth.update(
a_id
for event in results if event
for a_id, _ in event.auth_events
)
missing_auth = required_auth - set(auth_events)
failed_to_fetch = missing_auth - set(auth_events)
seen_events = yield self.store.have_events(
set(auth_events.keys()) | set(state_events.keys())
)
).addErrback(unwrapFirstError)
auth_events.update({a.event_id: a for a in results})
ev_infos = []
for a in auth_events.values():
@@ -448,7 +372,6 @@ class FederationHandler(BaseHandler):
(auth_events[a_id].type, auth_events[a_id].state_key):
auth_events[a_id]
for a_id, _ in a.auth_events
if a_id in auth_events
}
})
@@ -460,20 +383,13 @@ class FederationHandler(BaseHandler):
(auth_events[a_id].type, auth_events[a_id].state_key):
auth_events[a_id]
for a_id, _ in event_map[e_id].auth_events
if a_id in auth_events
}
})
try:
yield self._handle_new_events(
dest, ev_infos,
backfilled=True,
)
except Exception as e:
logger.warn(
"Failed to handle auth events because: %s", e
)
raise
events.sort(key=lambda e: e.depth)
@@ -510,10 +426,6 @@ class FederationHandler(BaseHandler):
)
max_depth = sorted_extremeties_tuple[0][1]
# We don't want to specify too many extremities as it causes the backfill
# request URI to be too long.
extremities = dict(sorted_extremeties_tuple[:5])
if current_depth > max_depth:
logger.debug(
"Not backfilling as we don't need to. %d < %d",
@@ -610,24 +522,12 @@ class FederationHandler(BaseHandler):
event_ids = list(extremities.keys())
states = yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e])
states = yield defer.gatherResults([
self.state_handler.resolve_state_groups(room_id, [e])
for e in event_ids
]))
])
states = dict(zip(event_ids, [s[1] for s in states]))
state_map = yield self.store.get_events(
[e_id for ids in states.values() for e_id in ids],
get_prev_content=False
)
states = {
key: {
k: state_map[e_id]
for k, e_id in state_dict.items()
if e_id in state_map
} for key, state_dict in states.items()
}
for e_id, _ in sorted_extremeties_tuple:
likely_domains = get_domains_from_state(states[e_id])
@@ -737,7 +637,7 @@ class FederationHandler(BaseHandler):
pass
event_stream_id, max_stream_id = yield self._persist_auth_tree(
origin, auth_chain, state, event
auth_chain, state, event
)
with PreserveLoggingContext():
@@ -788,9 +688,7 @@ class FederationHandler(BaseHandler):
logger.warn("Failed to create join %r because %s", event, e)
raise e
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_join_request`
yield self.auth.check_from_context(event, context, do_sig_check=False)
self.auth.check(event, auth_events=context.current_state)
defer.returnValue(event)
@@ -838,11 +736,16 @@ class FederationHandler(BaseHandler):
new_pdu = event
users_in_room = yield self.store.get_joined_users_from_context(event, context)
destinations = set()
destinations = set(
get_domain_from_id(user_id) for user_id in users_in_room
if not self.hs.is_mine_id(user_id)
for k, s in context.current_state.items():
try:
if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.JOIN:
destinations.add(get_domain_from_id(s.state_key))
except:
logger.warn(
"Failed to get destination from event %s", s.event_id
)
destinations.discard(origin)
@@ -855,15 +758,13 @@ class FederationHandler(BaseHandler):
self.replication_layer.send_pdu(new_pdu, destinations)
state_ids = context.prev_state_ids.values()
state_ids = [e.event_id for e in context.current_state.values()]
auth_chain = yield self.store.get_auth_chain(set(
[event.event_id] + state_ids
))
state = yield self.store.get_events(context.prev_state_ids.values())
defer.returnValue({
"state": state.values(),
"state": context.current_state.values(),
"auth_chain": auth_chain,
})
@@ -1017,9 +918,7 @@ class FederationHandler(BaseHandler):
)
try:
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_leave_request`
yield self.auth.check_from_context(event, context, do_sig_check=False)
self.auth.check(event, auth_events=context.current_state)
except AuthError as e:
logger.warn("Failed to create new leave %r because %s", event, e)
raise e
@@ -1063,12 +962,18 @@ class FederationHandler(BaseHandler):
new_pdu = event
users_in_room = yield self.store.get_joined_users_from_context(event, context)
destinations = set()
destinations = set(
get_domain_from_id(user_id) for user_id in users_in_room
if not self.hs.is_mine_id(user_id)
for k, s in context.current_state.items():
try:
if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.LEAVE:
destinations.add(get_domain_from_id(s.state_key))
except:
logger.warn(
"Failed to get destination from event %s", s.event_id
)
destinations.discard(origin)
logger.debug(
@@ -1082,11 +987,14 @@ class FederationHandler(BaseHandler):
defer.returnValue(None)
@defer.inlineCallbacks
def get_state_for_pdu(self, room_id, event_id):
"""Returns the state at the event. i.e. not including said event.
"""
def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True):
yield run_on_reactor()
if do_auth:
in_room = yield self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
state_groups = yield self.store.get_state_groups(
room_id, [event_id]
)
@@ -1125,34 +1033,6 @@ class FederationHandler(BaseHandler):
else:
defer.returnValue([])
@defer.inlineCallbacks
def get_state_ids_for_pdu(self, room_id, event_id):
"""Returns the state at the event. i.e. not including said event.
"""
yield run_on_reactor()
state_groups = yield self.store.get_state_groups_ids(
room_id, [event_id]
)
if state_groups:
_, state = state_groups.items().pop()
results = state
event = yield self.store.get_event(event_id)
if event and event.is_state():
# Get previous state
if "replaces_state" in event.unsigned:
prev_id = event.unsigned["replaces_state"]
if prev_id != event.event_id:
results[(event.type, event.state_key)] = prev_id
else:
del results[(event.type, event.state_key)]
defer.returnValue(results.values())
else:
defer.returnValue([])
@defer.inlineCallbacks
@log_function
def on_backfill_request(self, origin, room_id, pdu_list, limit):
@@ -1185,7 +1065,6 @@ class FederationHandler(BaseHandler):
)
if event:
if self.hs.is_mine_id(event.event_id):
# FIXME: This is a temporary work around where we occasionally
# return events slightly differently than when they were
# originally signed
@@ -1205,12 +1084,6 @@ class FederationHandler(BaseHandler):
if not in_room:
raise AuthError(403, "Host not in room.")
events = yield self._filter_events_for_server(
origin, event.room_id, [event]
)
event = events[0]
defer.returnValue(event)
else:
defer.returnValue(None)
@@ -1241,7 +1114,6 @@ class FederationHandler(BaseHandler):
backfilled=backfilled,
)
if not backfilled:
# this intentionally does not yield: we don't care about the result
# and don't need to wait for it.
preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
@@ -1257,9 +1129,9 @@ class FederationHandler(BaseHandler):
a bunch of outliers, but not a chunk of individual events that depend
on each other for state calculations.
"""
contexts = yield preserve_context_over_deferred(defer.gatherResults(
contexts = yield defer.gatherResults(
[
preserve_fn(self._prep_event)(
self._prep_event(
origin,
ev_info["event"],
state=ev_info.get("state"),
@@ -1267,7 +1139,7 @@ class FederationHandler(BaseHandler):
)
for ev_info in event_infos
]
))
)
yield self.store.persist_events(
[
@@ -1278,19 +1150,11 @@ class FederationHandler(BaseHandler):
)
@defer.inlineCallbacks
def _persist_auth_tree(self, origin, auth_events, state, event):
def _persist_auth_tree(self, auth_events, state, event):
"""Checks the auth chain is valid (and passes auth checks) for the
state and event. Then persists the auth chain and state atomically.
Persists the event seperately.
Will attempt to fetch missing auth events.
Args:
origin (str): Where the events came from
auth_events (list)
state (list)
event (Event)
Returns:
2-tuple of (event_stream_id, max_stream_id) from the persist_event
call for `event`
@@ -1303,7 +1167,7 @@ class FederationHandler(BaseHandler):
event_map = {
e.event_id: e
for e in itertools.chain(auth_events, state, [event])
for e in auth_events
}
create_event = None
@@ -1312,29 +1176,10 @@ class FederationHandler(BaseHandler):
create_event = e
break
missing_auth_events = set()
for e in itertools.chain(auth_events, state, [event]):
for e_id, _ in e.auth_events:
if e_id not in event_map:
missing_auth_events.add(e_id)
for e_id in missing_auth_events:
m_ev = yield self.replication_layer.get_pdu(
[origin],
e_id,
outlier=True,
timeout=10000,
)
if m_ev and m_ev.event_id == e_id:
event_map[e_id] = m_ev
else:
logger.info("Failed to find auth event %r", e_id)
for e in itertools.chain(auth_events, state, [event]):
auth_for_e = {
(event_map[e_id].type, event_map[e_id].state_key): event_map[e_id]
for e_id, _ in e.auth_events
if e_id in event_map
}
if create_event:
auth_for_e[(EventTypes.Create, "")] = create_event
@@ -1383,13 +1228,7 @@ class FederationHandler(BaseHandler):
)
if not auth_events:
auth_events_ids = yield self.auth.compute_auth_events(
event, context.prev_state_ids, for_verification=True,
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = {
(e.type, e.state_key): e for e in auth_events.values()
}
auth_events = context.current_state
# This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events.
@@ -1413,15 +1252,10 @@ class FederationHandler(BaseHandler):
)
context.rejected = RejectedReason.AUTH_ERROR
except Exception as e:
logger.warn(
"Failed to auth event: %s because %s",
event.event_id, e.msg
)
raise
if event.type == EventTypes.GuestAccess:
yield self.maybe_kick_guest_users(event)
full_context = yield self.store.get_current_state(room_id=event.room_id)
yield self.maybe_kick_guest_users(event, full_context)
defer.returnValue(context)
@@ -1489,11 +1323,6 @@ class FederationHandler(BaseHandler):
current_state = set(e.event_id for e in auth_events.values())
event_auth_events = set(e_id for e_id, _ in event.auth_events)
if event.is_state():
event_key = (event.type, event.state_key)
else:
event_key = None
if event_auth_events - current_state:
have_events = yield self.store.have_events(
event_auth_events - current_state
@@ -1567,9 +1396,9 @@ class FederationHandler(BaseHandler):
# Do auth conflict res.
logger.info("Different auth: %s", different_auth)
different_events = yield preserve_context_over_deferred(defer.gatherResults(
different_events = yield defer.gatherResults(
[
preserve_fn(self.store.get_event)(
self.store.get_event(
d,
allow_none=True,
allow_rejected=False,
@@ -1578,7 +1407,7 @@ class FederationHandler(BaseHandler):
if d in have_events and not have_events[d]
],
consumeErrors=True
)).addErrback(unwrapFirstError)
).addErrback(unwrapFirstError)
if different_events:
local_view = dict(auth_events)
@@ -1597,14 +1426,8 @@ class FederationHandler(BaseHandler):
current_state = set(e.event_id for e in auth_events.values())
different_auth = event_auth_events - current_state
context.current_state_ids.update({
k: a.event_id for k, a in auth_events.items()
if k != event_key
})
context.prev_state_ids.update({
k: a.event_id for k, a in auth_events.items()
})
context.state_group = self.store.get_next_state_group()
context.current_state.update(auth_events)
context.state_group = None
if different_auth and not event.internal_metadata.is_outlier():
logger.info("Different auth after resolution: %s", different_auth)
@@ -1625,8 +1448,8 @@ class FederationHandler(BaseHandler):
if do_resolution:
# 1. Get what we think is the auth chain.
auth_ids = yield self.auth.compute_auth_events(
event, context.prev_state_ids
auth_ids = self.auth.compute_auth_events(
event, context.current_state
)
local_auth_chain = yield self.store.get_auth_chain(auth_ids)
@@ -1682,14 +1505,8 @@ class FederationHandler(BaseHandler):
# 4. Look at rejects and their proofs.
# TODO.
context.current_state_ids.update({
k: a.event_id for k, a in auth_events.items()
if k != event_key
})
context.prev_state_ids.update({
k: a.event_id for k, a in auth_events.items()
})
context.state_group = self.store.get_next_state_group()
context.current_state.update(auth_events)
context.state_group = None
try:
self.auth.check(event, auth_events=auth_events)
@@ -1875,12 +1692,12 @@ class FederationHandler(BaseHandler):
)
try:
yield self.auth.check_from_context(event, context)
self.auth.check(event, context.current_state)
except AuthError as e:
logger.warn("Denying new third party invite %r because %s", event, e)
raise e
yield self._check_signature(event, context)
yield self._check_signature(event, auth_events=context.current_state)
member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.send_membership_event(None, event, context)
else:
@@ -1906,11 +1723,11 @@ class FederationHandler(BaseHandler):
)
try:
self.auth.check_from_context(event, context)
self.auth.check(event, auth_events=context.current_state)
except AuthError as e:
logger.warn("Denying third party invite %r because %s", event, e)
raise e
yield self._check_signature(event, context)
yield self._check_signature(event, auth_events=context.current_state)
returned_invite = yield self.send_invite(origin, event)
# TODO: Make sure the signatures actually are correct.
@@ -1924,12 +1741,7 @@ class FederationHandler(BaseHandler):
EventTypes.ThirdPartyInvite,
event.content["third_party_invite"]["signed"]["token"]
)
original_invite = None
original_invite_id = context.prev_state_ids.get(key)
if original_invite_id:
original_invite = yield self.store.get_event(
original_invite_id, allow_none=True
)
original_invite = context.current_state.get(key)
if not original_invite:
logger.info(
"Could not find invite event for third_party_invite - "
@@ -1946,13 +1758,13 @@ class FederationHandler(BaseHandler):
defer.returnValue((event, context))
@defer.inlineCallbacks
def _check_signature(self, event, context):
def _check_signature(self, event, auth_events):
"""
Checks that the signature in the event is consistent with its invite.
Args:
event (Event): The m.room.member event to check
context (EventContext):
auth_events (dict<(event type, state_key), event>):
Raises:
AuthError: if signature didn't match any keys, or key has been
@@ -1963,14 +1775,10 @@ class FederationHandler(BaseHandler):
signed = event.content["third_party_invite"]["signed"]
token = signed["token"]
invite_event_id = context.prev_state_ids.get(
invite_event = auth_events.get(
(EventTypes.ThirdPartyInvite, token,)
)
invite_event = None
if invite_event_id:
invite_event = yield self.store.get_event(invite_event_id, allow_none=True)
if not invite_event:
raise AuthError(403, "Could not find invite")

View File

@@ -28,8 +28,7 @@ from synapse.types import (
from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLock
from synapse.util.caches.snapshot_cache import SnapshotCache
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.util.metrics import measure_func
from synapse.util.logcontext import preserve_fn
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
@@ -67,7 +66,7 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks
def get_messages(self, requester, room_id=None, pagin_config=None,
as_client_event=True, event_filter=None):
as_client_event=True):
"""Get messages in a room.
Args:
@@ -76,11 +75,11 @@ class MessageHandler(BaseHandler):
pagin_config (synapse.api.streams.PaginationConfig): The pagination
config rules to apply, if any.
as_client_event (bool): True to get events in client-server format.
event_filter (Filter): Filter to apply to results or None
Returns:
dict: Pagination API results
"""
user_id = requester.user.to_string()
data_source = self.hs.get_event_sources().sources["room"]
if pagin_config.from_token:
room_token = pagin_config.from_token.room_key
@@ -130,13 +129,8 @@ class MessageHandler(BaseHandler):
room_id, max_topo
)
events, next_key = yield self.store.paginate_room_events(
room_id=room_id,
from_key=source_config.from_key,
to_key=source_config.to_key,
direction=source_config.direction,
limit=source_config.limit,
event_filter=event_filter,
events, next_key = yield data_source.get_pagination_rows(
requester.user, source_config, room_id
)
next_token = pagin_config.from_token.copy_and_replace(
@@ -150,9 +144,6 @@ class MessageHandler(BaseHandler):
"end": next_token.to_string(),
})
if event_filter:
events = event_filter.filter(events)
events = yield filter_events_for_client(
self.store,
user_id,
@@ -173,6 +164,101 @@ class MessageHandler(BaseHandler):
defer.returnValue(chunk)
@defer.inlineCallbacks
def get_files(self, requester, room_id, pagin_config):
"""Get files in a room.
Args:
requester (Requester): The user requesting files.
room_id (str): The room they want files from.
pagin_config (synapse.api.streams.PaginationConfig): The pagination
config rules to apply, if any.
as_client_event (bool): True to get events in client-server format.
Returns:
dict: Pagination API results
"""
user_id = requester.user.to_string()
if pagin_config.from_token:
room_token = pagin_config.from_token.room_key
else:
pagin_config.from_token = (
yield self.hs.get_event_sources().get_current_token(
direction='b'
)
)
room_token = pagin_config.from_token.room_key
room_token = RoomStreamToken.parse(room_token)
pagin_config.from_token = pagin_config.from_token.copy_and_replace(
"room_key", str(room_token)
)
source_config = pagin_config.get_source_config("room")
membership, member_event_id = yield self._check_in_room_or_world_readable(
room_id, user_id
)
if source_config.direction == 'b':
if room_token.topological:
max_topo = room_token.topological
else:
max_topo = yield self.store.get_max_topological_token(
room_id, room_token.stream
)
if membership == Membership.LEAVE:
# If they have left the room then clamp the token to be before
# they left the room, to save the effort of loading from the
# database.
leave_token = yield self.store.get_topological_token_for_event(
member_event_id
)
leave_token = RoomStreamToken.parse(leave_token)
if leave_token.topological < max_topo:
source_config.from_key = str(leave_token)
events, next_key = yield self.store.paginate_room_file_events(
room_id,
from_key=source_config.from_key,
to_key=source_config.to_key,
direction=source_config.direction,
limit=source_config.limit,
)
next_token = pagin_config.from_token.copy_and_replace(
"room_key", next_key
)
if not events:
defer.returnValue({
"chunk": [],
"start": pagin_config.from_token.to_string(),
"end": next_token.to_string(),
})
events = yield filter_events_for_client(
self.store,
user_id,
events,
is_peeking=(member_event_id is None),
)
time_now = self.clock.time_msec()
chunk = {
"chunk": [
serialize_event(e, time_now)
for e in events
],
"start": pagin_config.from_token.to_string(),
"end": next_token.to_string(),
}
defer.returnValue(chunk)
@defer.inlineCallbacks
def create_event(self, event_dict, token_id=None, txn_id=None, prev_event_ids=None):
"""
@@ -248,7 +334,7 @@ class MessageHandler(BaseHandler):
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if event.is_state():
prev_state = yield self.deduplicate_state_event(event, context)
prev_state = self.deduplicate_state_event(event, context)
if prev_state is not None:
defer.returnValue(prev_state)
@@ -263,7 +349,6 @@ class MessageHandler(BaseHandler):
presence = self.hs.get_presence_handler()
yield presence.bump_presence_active_time(user)
@defer.inlineCallbacks
def deduplicate_state_event(self, event, context):
"""
Checks whether event is in the latest resolved state in context.
@@ -271,17 +356,13 @@ class MessageHandler(BaseHandler):
If so, returns the version of the event in context.
Otherwise, returns None.
"""
prev_event_id = context.prev_state_ids.get((event.type, event.state_key))
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
if not prev_event:
return
prev_event = context.current_state.get((event.type, event.state_key))
if prev_event and event.user_id == prev_event.user_id:
prev_content = encode_canonical_json(prev_event.content)
next_content = encode_canonical_json(event.content)
if prev_content == next_content:
defer.returnValue(prev_event)
return
return prev_event
return None
@defer.inlineCallbacks
def create_and_send_nonmember_event(
@@ -508,17 +589,15 @@ class MessageHandler(BaseHandler):
lambda states: states[event.event_id]
)
(messages, token), current_state = yield preserve_context_over_deferred(
defer.gatherResults(
(messages, token), current_state = yield defer.gatherResults(
[
preserve_fn(self.store.get_recent_events_for_room)(
self.store.get_recent_events_for_room(
event.room_id,
limit=limit,
end_token=room_end_token,
),
deferred_room_state,
]
)
).addErrback(unwrapFirstError)
messages = yield filter_events_for_client(
@@ -727,9 +806,9 @@ class MessageHandler(BaseHandler):
presence, receipts, (messages, token) = yield defer.gatherResults(
[
preserve_fn(get_presence)(),
preserve_fn(get_receipts)(),
preserve_fn(self.store.get_recent_events_for_room)(
get_presence(),
get_receipts(),
self.store.get_recent_events_for_room(
room_id,
limit=limit,
end_token=now_token.room_key,
@@ -763,7 +842,6 @@ class MessageHandler(BaseHandler):
defer.returnValue(ret)
@measure_func("_create_new_client_event")
@defer.inlineCallbacks
def _create_new_client_event(self, builder, prev_event_ids=None):
if prev_event_ids:
@@ -807,15 +885,14 @@ class MessageHandler(BaseHandler):
event = builder.build()
logger.debug(
"Created event %s with state: %s",
event.event_id, context.prev_state_ids,
"Created event %s with current state: %s",
event.event_id, context.current_state,
)
defer.returnValue(
(event, context,)
)
@measure_func("handle_new_client_event")
@defer.inlineCallbacks
def handle_new_client_event(
self,
@@ -831,12 +908,12 @@ class MessageHandler(BaseHandler):
self.ratelimit(requester)
try:
yield self.auth.check_from_context(event, context)
self.auth.check(event, auth_events=context.current_state)
except AuthError as err:
logger.warn("Denying new event %r because %s", event, err)
raise err
yield self.maybe_kick_guest_users(event, context)
yield self.maybe_kick_guest_users(event, context.current_state.values())
if event.type == EventTypes.CanonicalAlias:
# Check the alias is acually valid (at this time at least)
@@ -864,15 +941,6 @@ class MessageHandler(BaseHandler):
e.sender == event.sender
)
state_to_include_ids = [
e_id
for k, e_id in context.current_state_ids.items()
if k[0] in self.hs.config.room_invite_state_types
or k[0] == EventTypes.Member and k[1] == event.sender
]
state_to_include = yield self.store.get_events(state_to_include_ids)
event.unsigned["invite_room_state"] = [
{
"type": e.type,
@@ -880,7 +948,9 @@ class MessageHandler(BaseHandler):
"content": e.content,
"sender": e.sender,
}
for e in state_to_include.values()
for k, e in context.current_state.items()
if e.type in self.hs.config.room_invite_state_types
or is_inviter_member_event(e)
]
invitee = UserID.from_string(event.state_key)
@@ -902,14 +972,7 @@ class MessageHandler(BaseHandler):
)
if event.type == EventTypes.Redaction:
auth_events_ids = yield self.auth.compute_auth_events(
event, context.prev_state_ids, for_verification=True,
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = {
(e.type, e.state_key): e for e in auth_events.values()
}
if self.auth.check_redaction(event, auth_events=auth_events):
if self.auth.check_redaction(event, auth_events=context.current_state):
original_event = yield self.store.get_event(
event.redacts,
check_redacted=False,
@@ -923,7 +986,7 @@ class MessageHandler(BaseHandler):
"You don't have permission to redact events"
)
if event.type == EventTypes.Create and context.prev_state_ids:
if event.type == EventTypes.Create and context.current_state:
raise AuthError(
403,
"Changing the room create event is forbidden",
@@ -944,17 +1007,21 @@ class MessageHandler(BaseHandler):
event_stream_id, max_stream_id
)
users_in_room = yield self.store.get_joined_users_from_context(event, context)
destinations = [
get_domain_from_id(user_id) for user_id in users_in_room
if not self.hs.is_mine_id(user_id)
]
destinations = set()
for k, s in context.current_state.items():
try:
if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.JOIN:
destinations.add(get_domain_from_id(s.state_key))
except SynapseError:
logger.warn(
"Failed to get destination from event %s", s.event_id
)
@defer.inlineCallbacks
def _notify():
yield run_on_reactor()
yield self.notifier.on_new_room_event(
self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=extra_users
)
@@ -964,6 +1031,6 @@ class MessageHandler(BaseHandler):
# If invite, remove room_state from unsigned before sending.
event.unsigned.pop("invite_room_state", None)
preserve_fn(federation_handler.handle_new_event)(
federation_handler.handle_new_event(
event, destinations=destinations,
)

View File

@@ -52,11 +52,6 @@ bump_active_time_counter = metrics.register_counter("bump_active_time")
get_updates_counter = metrics.register_counter("get_updates", labels=["type"])
notify_reason_counter = metrics.register_counter("notify_reason", labels=["reason"])
state_transition_counter = metrics.register_counter(
"state_transition", labels=["from", "to"]
)
# If a user was last active in the last LAST_ACTIVE_GRANULARITY, consider them
# "currently_active"
@@ -93,8 +88,6 @@ class PresenceHandler(object):
self.notifier = hs.get_notifier()
self.federation = hs.get_replication_layer()
self.state = hs.get_state_handler()
self.federation.register_edu_handler(
"m.presence", self.incoming_presence
)
@@ -196,13 +189,6 @@ class PresenceHandler(object):
5000,
)
self.clock.call_later(
60,
self.clock.looping_call,
self._persist_unpersisted_changes,
60 * 1000,
)
metrics.register_callback("wheel_timer_size", lambda: len(self.wheel_timer))
@defer.inlineCallbacks
@@ -228,27 +214,6 @@ class PresenceHandler(object):
])
logger.info("Finished _on_shutdown")
@defer.inlineCallbacks
def _persist_unpersisted_changes(self):
"""We periodically persist the unpersisted changes, as otherwise they
may stack up and slow down shutdown times.
"""
logger.info(
"Performing _persist_unpersisted_changes. Persiting %d unpersisted changes",
len(self.unpersisted_users_changes)
)
unpersisted = self.unpersisted_users_changes
self.unpersisted_users_changes = set()
if unpersisted:
yield self.store.update_presence([
self.user_to_current_state[user_id]
for user_id in unpersisted
])
logger.info("Finished _persist_unpersisted_changes")
@defer.inlineCallbacks
def _update_states(self, new_states):
"""Updates presence of users. Sets the appropriate timeouts. Pokes
@@ -538,7 +503,7 @@ class PresenceHandler(object):
defer.returnValue(states)
@defer.inlineCallbacks
def _get_interested_parties(self, states, calculate_remote_hosts=True):
def _get_interested_parties(self, states):
"""Given a list of states return which entities (rooms, users, servers)
are interested in the given states.
@@ -561,15 +526,12 @@ class PresenceHandler(object):
users_to_states.setdefault(state.user_id, []).append(state)
hosts_to_states = {}
if calculate_remote_hosts:
for room_id, states in room_ids_to_states.items():
local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
if not local_states:
continue
users = yield self.state.get_current_user_in_room(room_id)
hosts = set(get_domain_from_id(u) for u in users)
hosts = yield self.store.get_joined_hosts_for_room(room_id)
for host in hosts:
hosts_to_states.setdefault(host, []).extend(local_states)
@@ -603,16 +565,6 @@ class PresenceHandler(object):
self._push_to_remotes(hosts_to_states)
@defer.inlineCallbacks
def notify_for_states(self, state, stream_id):
parties = yield self._get_interested_parties([state])
room_ids_to_states, users_to_states, hosts_to_states = parties
self.notifier.on_new_event(
"presence_key", stream_id, rooms=room_ids_to_states.keys(),
users=[UserID.from_string(u) for u in users_to_states.keys()]
)
def _push_to_remotes(self, hosts_to_states):
"""Sends state updates to remote servers.
@@ -720,7 +672,7 @@ class PresenceHandler(object):
])
@defer.inlineCallbacks
def set_state(self, target_user, state, ignore_status_msg=False):
def set_state(self, target_user, state):
"""Set the presence state of the user.
"""
status_msg = state.get("status_msg", None)
@@ -737,13 +689,10 @@ class PresenceHandler(object):
prev_state = yield self.current_state_for_user(user_id)
new_fields = {
"state": presence
"state": presence,
"status_msg": status_msg if presence != PresenceState.OFFLINE else None
}
if not ignore_status_msg:
msg = status_msg if presence != PresenceState.OFFLINE else None
new_fields["status_msg"] = msg
if presence == PresenceState.ONLINE:
new_fields["last_active_ts"] = self.clock.time_msec()
@@ -762,13 +711,13 @@ class PresenceHandler(object):
# don't need to send to local clients here, as that is done as part
# of the event stream/sync.
# TODO: Only send to servers not already in the room.
user_ids = yield self.state.get_current_user_in_room(room_id)
if self.is_mine(user):
state = yield self.current_state_for_user(user.to_string())
hosts = set(get_domain_from_id(u) for u in user_ids)
hosts = yield self.store.get_joined_hosts_for_room(room_id)
self._push_to_remotes({host: (state,) for host in hosts})
else:
user_ids = yield self.store.get_users_in_room(room_id)
user_ids = filter(self.is_mine_id, user_ids)
states = yield self.current_state_for_users(user_ids)
@@ -944,32 +893,22 @@ class PresenceHandler(object):
def should_notify(old_state, new_state):
"""Decides if a presence state change should be sent to interested parties.
"""
if old_state == new_state:
return False
if old_state.status_msg != new_state.status_msg:
notify_reason_counter.inc("status_msg_change")
return True
if old_state.state != new_state.state:
notify_reason_counter.inc("state_change")
state_transition_counter.inc(old_state.state, new_state.state)
return True
if old_state.state == PresenceState.ONLINE:
if new_state.state != PresenceState.ONLINE:
# Always notify for online -> anything
return True
if new_state.currently_active != old_state.currently_active:
notify_reason_counter.inc("current_active_change")
return True
if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY:
# Only notify about last active bumps if we're not currently acive
if not new_state.currently_active:
notify_reason_counter.inc("last_active_change_online")
# Always notify for a transition where last active gets bumped.
return True
elif new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY:
# Always notify for a transition where last active gets bumped.
notify_reason_counter.inc("last_active_change_not_online")
if old_state.state != new_state.state:
return True
return False
@@ -1002,7 +941,6 @@ class PresenceEventSource(object):
self.get_presence_handler = hs.get_presence_handler
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
@defer.inlineCallbacks
@log_function
@@ -1065,7 +1003,7 @@ class PresenceEventSource(object):
user_ids_to_check = set()
for room_id in room_ids:
users = yield self.state.get_current_user_in_room(room_id)
users = yield self.store.get_users_in_room(room_id)
user_ids_to_check.update(users)
user_ids_to_check.update(friends)

View File

@@ -13,15 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
import synapse.types
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
from synapse.types import UserID
from synapse.types import UserID, Requester
from ._base import BaseHandler
import logging
logger = logging.getLogger(__name__)
@@ -165,9 +165,7 @@ class ProfileHandler(BaseHandler):
try:
# Assume the user isn't a guest because we don't let guests set
# profile or avatar data.
# XXX why are we recreating `requester` here for each room?
# what was wrong with the `requester` we were passed?
requester = synapse.types.create_requester(user)
requester = Requester(user, "", False)
yield handler.update_membership(
requester,
user,

View File

@@ -18,7 +18,6 @@ from ._base import BaseHandler
from twisted.internet import defer
from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import get_domain_from_id
import logging
@@ -38,7 +37,6 @@ class ReceiptsHandler(BaseHandler):
"m.receipt", self._received_remote_receipt
)
self.clock = self.hs.get_clock()
self.state = hs.get_state_handler()
@defer.inlineCallbacks
def received_client_receipt(self, room_id, receipt_type, user_id,
@@ -135,8 +133,7 @@ class ReceiptsHandler(BaseHandler):
event_ids = receipt["event_ids"]
data = receipt["data"]
users = yield self.state.get_current_user_in_room(room_id)
remotedomains = set(get_domain_from_id(u) for u in users)
remotedomains = yield self.store.get_joined_hosts_for_room(room_id)
remotedomains = remotedomains.copy()
remotedomains.discard(self.server_name)

View File

@@ -14,19 +14,18 @@
# limitations under the License.
"""Contains functions for registering clients."""
import logging
import urllib
from twisted.internet import defer
import synapse.types
from synapse.types import UserID, Requester
from synapse.api.errors import (
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
)
from synapse.http.client import CaptchaServerHttpClient
from synapse.types import UserID
from synapse.util.async import run_on_reactor
from ._base import BaseHandler
from synapse.util.async import run_on_reactor
from synapse.http.client import CaptchaServerHttpClient
import logging
import urllib
logger = logging.getLogger(__name__)
@@ -53,13 +52,6 @@ class RegistrationHandler(BaseHandler):
Codes.INVALID_USERNAME
)
if localpart[0] == '_':
raise SynapseError(
400,
"User ID may not begin with _",
Codes.INVALID_USERNAME
)
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
@@ -109,11 +101,6 @@ class RegistrationHandler(BaseHandler):
password (str) : The password to assign to this user so they can
login again. This can be None which means they cannot login again
via a password (e.g. the user is an application service user).
generate_token (bool): Whether a new access token should be
generated. Having this be True should be considered deprecated,
since it offers no means of associating a device_id with the
access_token. Instead you should call auth_handler.issue_access_token
after registration.
Returns:
A tuple of (user_id, access_token).
Raises:
@@ -209,13 +196,15 @@ class RegistrationHandler(BaseHandler):
user_id, allowed_appservice=service
)
token = self.auth_handler().generate_access_token(user_id)
yield self.store.register(
user_id=user_id,
token=token,
password_hash="",
appservice_id=service_id,
create_profile_with_localpart=user.localpart,
)
defer.returnValue(user_id)
defer.returnValue((user_id, token))
@defer.inlineCallbacks
def check_recaptcha(self, ip, private_key, challenge, response):
@@ -371,7 +360,7 @@ class RegistrationHandler(BaseHandler):
defer.returnValue(data)
@defer.inlineCallbacks
def get_or_create_user(self, localpart, displayname, duration_in_ms,
def get_or_create_user(self, localpart, displayname, duration_seconds,
password_hash=None):
"""Creates a new user if the user does not exist,
else revokes all previous access tokens and generates a new one.
@@ -401,8 +390,8 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
token = self.auth_handler().generate_access_token(
user_id, None, duration_in_ms)
token = self.auth_handler().generate_short_term_login_token(
user_id, duration_seconds)
if need_register:
yield self.store.register(
@@ -418,9 +407,8 @@ class RegistrationHandler(BaseHandler):
if displayname is not None:
logger.info("setting user display name: %s -> %s", user_id, displayname)
profile_handler = self.hs.get_handlers().profile_handler
requester = synapse.types.create_requester(user)
yield profile_handler.set_displayname(
user, requester, displayname
user, Requester(user, token, False), displayname
)
defer.returnValue((user_id, token))

View File

@@ -345,8 +345,8 @@ class RoomCreationHandler(BaseHandler):
class RoomListHandler(BaseHandler):
def __init__(self, hs):
super(RoomListHandler, self).__init__(hs)
self.response_cache = ResponseCache(hs)
self.remote_list_request_cache = ResponseCache(hs)
self.response_cache = ResponseCache()
self.remote_list_request_cache = ResponseCache()
self.remote_list_cache = {}
self.fetch_looping_call = hs.get_clock().looping_call(
self.fetch_all_remote_lists, REMOTE_ROOM_LIST_POLL_INTERVAL

View File

@@ -14,22 +14,24 @@
# limitations under the License.
import logging
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
from twisted.internet import defer
from unpaddedbase64 import decode_base64
import synapse.types
from ._base import BaseHandler
from synapse.types import UserID, RoomID, Requester
from synapse.api.constants import (
EventTypes, Membership,
)
from synapse.api.errors import AuthError, SynapseError, Codes
from synapse.types import UserID, RoomID
from synapse.util.async import Linearizer
from synapse.util.distributor import user_left_room, user_joined_room
from ._base import BaseHandler
from signedjson.sign import verify_signed_json
from signedjson.key import decode_verify_key_bytes
from unpaddedbase64 import decode_base64
import logging
logger = logging.getLogger(__name__)
@@ -59,13 +61,10 @@ class RoomMemberHandler(BaseHandler):
prev_event_ids,
txn_id=None,
ratelimit=True,
content=None,
):
if content is None:
content = {}
msg_handler = self.hs.get_handlers().message_handler
content["membership"] = membership
content = {"membership": membership}
if requester.is_guest:
content["kind"] = "guest"
@@ -85,12 +84,6 @@ class RoomMemberHandler(BaseHandler):
prev_event_ids=prev_event_ids,
)
# Check if this event matches the previous membership event for the user.
duplicate = yield msg_handler.deduplicate_state_event(event, context)
if duplicate is not None:
# Discard the new event since this membership change is a no-op.
return
yield msg_handler.handle_new_client_event(
requester,
event,
@@ -99,25 +92,19 @@ class RoomMemberHandler(BaseHandler):
ratelimit=ratelimit,
)
prev_member_event_id = context.prev_state_ids.get(
prev_member_event = context.current_state.get(
(EventTypes.Member, target.to_string()),
None
)
if event.membership == Membership.JOIN:
if not prev_member_event or prev_member_event.membership != Membership.JOIN:
# Only fire user_joined_room if the user has acutally joined the
# room. Don't bother if the user is just changing their profile
# info.
newly_joined = True
if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined:
yield user_joined_room(self.distributor, target, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
if prev_member_event.membership == Membership.JOIN:
if prev_member_event and prev_member_event.membership == Membership.JOIN:
user_left_room(self.distributor, target, room_id)
@defer.inlineCallbacks
@@ -155,9 +142,8 @@ class RoomMemberHandler(BaseHandler):
remote_room_hosts=None,
third_party_signed=None,
ratelimit=True,
content=None,
):
key = (room_id,)
key = (target, room_id,)
with (yield self.member_linearizer.queue(key)):
result = yield self._update_membership(
@@ -169,7 +155,6 @@ class RoomMemberHandler(BaseHandler):
remote_room_hosts=remote_room_hosts,
third_party_signed=third_party_signed,
ratelimit=ratelimit,
content=content,
)
defer.returnValue(result)
@@ -185,11 +170,7 @@ class RoomMemberHandler(BaseHandler):
remote_room_hosts=None,
third_party_signed=None,
ratelimit=True,
content=None,
):
if content is None:
content = {}
effective_membership_state = action
if action in ["kick", "unban"]:
effective_membership_state = "leave"
@@ -207,19 +188,16 @@ class RoomMemberHandler(BaseHandler):
remote_room_hosts = []
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
current_state_ids = yield self.state_handler.get_current_state_ids(
current_state = yield self.state_handler.get_current_state(
room_id, latest_event_ids=latest_event_ids,
)
old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
if old_state_id:
old_state = yield self.store.get_event(old_state_id, allow_none=True)
old_state = current_state.get((EventTypes.Member, target.to_string()))
old_membership = old_state.content.get("membership") if old_state else None
if action == "unban" and old_membership != "ban":
raise SynapseError(
403,
"Cannot unban user who was not banned"
" (membership=%s)" % old_membership,
"Cannot unban user who was not banned (membership=%s)" % old_membership,
errcode=Codes.BAD_STATE
)
if old_membership == "ban" and action != "unban":
@@ -229,10 +207,10 @@ class RoomMemberHandler(BaseHandler):
errcode=Codes.BAD_STATE
)
is_host_in_room = yield self._is_host_in_room(current_state_ids)
is_host_in_room = self.is_host_in_room(current_state)
if effective_membership_state == Membership.JOIN:
if requester.is_guest and not self._can_guest_join(current_state_ids):
if requester.is_guest and not self._can_guest_join(current_state):
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
@@ -242,7 +220,7 @@ class RoomMemberHandler(BaseHandler):
if inviter and not self.hs.is_mine(inviter):
remote_room_hosts.append(inviter.domain)
content["membership"] = Membership.JOIN
content = {"membership": Membership.JOIN}
profile = self.hs.get_handlers().profile_handler
content["displayname"] = yield profile.get_displayname(target)
@@ -296,7 +274,6 @@ class RoomMemberHandler(BaseHandler):
txn_id=txn_id,
ratelimit=ratelimit,
prev_event_ids=latest_event_ids,
content=content,
)
@defer.inlineCallbacks
@@ -338,17 +315,15 @@ class RoomMemberHandler(BaseHandler):
)
assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
else:
requester = synapse.types.create_requester(target_user)
requester = Requester(target_user, None, False)
message_handler = self.hs.get_handlers().message_handler
prev_event = yield message_handler.deduplicate_state_event(event, context)
prev_event = message_handler.deduplicate_state_event(event, context)
if prev_event is not None:
return
if event.membership == Membership.JOIN:
if requester.is_guest:
guest_can_join = yield self._can_guest_join(context.prev_state_ids)
if not guest_can_join:
if requester.is_guest and not self._can_guest_join(context.current_state):
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
@@ -361,39 +336,27 @@ class RoomMemberHandler(BaseHandler):
ratelimit=ratelimit,
)
prev_member_event_id = context.prev_state_ids.get(
(EventTypes.Member, event.state_key),
prev_member_event = context.current_state.get(
(EventTypes.Member, target_user.to_string()),
None
)
if event.membership == Membership.JOIN:
if not prev_member_event or prev_member_event.membership != Membership.JOIN:
# Only fire user_joined_room if the user has acutally joined the
# room. Don't bother if the user is just changing their profile
# info.
newly_joined = True
if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined:
yield user_joined_room(self.distributor, target_user, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
if prev_member_event.membership == Membership.JOIN:
if prev_member_event and prev_member_event.membership == Membership.JOIN:
user_left_room(self.distributor, target_user, room_id)
@defer.inlineCallbacks
def _can_guest_join(self, current_state_ids):
def _can_guest_join(self, current_state):
"""
Returns whether a guest can join a room based on its current state.
"""
guest_access_id = current_state_ids.get((EventTypes.GuestAccess, ""), None)
if not guest_access_id:
defer.returnValue(False)
guest_access = yield self.store.get_event(guest_access_id)
defer.returnValue(
guest_access = current_state.get((EventTypes.GuestAccess, ""), None)
return (
guest_access
and guest_access.content
and "guest_access" in guest_access.content
@@ -712,24 +675,3 @@ class RoomMemberHandler(BaseHandler):
if membership:
yield self.store.forget(user_id, room_id)
@defer.inlineCallbacks
def _is_host_in_room(self, current_state_ids):
# Have we just created the room, and is this about to be the very
# first member event?
create_event_id = current_state_ids.get(("m.room.create", ""))
if len(current_state_ids) == 1 and create_event_id:
defer.returnValue(self.hs.is_mine_id(create_event_id))
for (etype, state_key), event_id in current_state_ids.items():
if etype != EventTypes.Member or not self.hs.is_mine_id(state_key):
continue
event = yield self.store.get_event(event_id, allow_none=True)
if not event:
continue
if event.membership == Membership.JOIN:
defer.returnValue(True)
defer.returnValue(False)

View File

@@ -35,7 +35,6 @@ SyncConfig = collections.namedtuple("SyncConfig", [
"filter_collection",
"is_guest",
"request_key",
"device_id",
])
@@ -114,7 +113,6 @@ class SyncResult(collections.namedtuple("SyncResult", [
"joined", # JoinedSyncResult for each joined room.
"invited", # InvitedSyncResult for each invited room.
"archived", # ArchivedSyncResult for each archived room.
"to_device", # List of direct messages for the device.
])):
__slots__ = []
@@ -128,8 +126,7 @@ class SyncResult(collections.namedtuple("SyncResult", [
self.joined or
self.invited or
self.archived or
self.account_data or
self.to_device
self.account_data
)
@@ -141,8 +138,7 @@ class SyncHandler(object):
self.presence_handler = hs.get_presence_handler()
self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock()
self.response_cache = ResponseCache(hs)
self.state = hs.get_state_handler()
self.response_cache = ResponseCache()
def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
full_state=False):
@@ -359,11 +355,11 @@ class SyncHandler(object):
Returns:
A Deferred map from ((type, state_key)->Event)
"""
state_ids = yield self.store.get_state_ids_for_event(event.event_id)
state = yield self.store.get_state_for_event(event.event_id)
if event.is_state():
state_ids = state_ids.copy()
state_ids[(event.type, event.state_key)] = event.event_id
defer.returnValue(state_ids)
state = state.copy()
state[(event.type, event.state_key)] = event
defer.returnValue(state)
@defer.inlineCallbacks
def get_state_at(self, room_id, stream_position):
@@ -416,61 +412,57 @@ class SyncHandler(object):
with Measure(self.clock, "compute_state_delta"):
if full_state:
if batch:
current_state_ids = yield self.store.get_state_ids_for_event(
current_state = yield self.store.get_state_for_event(
batch.events[-1].event_id
)
state_ids = yield self.store.get_state_ids_for_event(
state = yield self.store.get_state_for_event(
batch.events[0].event_id
)
else:
current_state_ids = yield self.get_state_at(
current_state = yield self.get_state_at(
room_id, stream_position=now_token
)
state_ids = current_state_ids
state = current_state
timeline_state = {
(event.type, event.state_key): event.event_id
(event.type, event.state_key): event
for event in batch.events if event.is_state()
}
state_ids = _calculate_state(
state = _calculate_state(
timeline_contains=timeline_state,
timeline_start=state_ids,
timeline_start=state,
previous={},
current=current_state_ids,
current=current_state,
)
elif batch.limited:
state_at_previous_sync = yield self.get_state_at(
room_id, stream_position=since_token
)
current_state_ids = yield self.store.get_state_ids_for_event(
current_state = yield self.store.get_state_for_event(
batch.events[-1].event_id
)
state_at_timeline_start = yield self.store.get_state_ids_for_event(
state_at_timeline_start = yield self.store.get_state_for_event(
batch.events[0].event_id
)
timeline_state = {
(event.type, event.state_key): event.event_id
(event.type, event.state_key): event
for event in batch.events if event.is_state()
}
state_ids = _calculate_state(
state = _calculate_state(
timeline_contains=timeline_state,
timeline_start=state_at_timeline_start,
previous=state_at_previous_sync,
current=current_state_ids,
current=current_state,
)
else:
state_ids = {}
state = {}
if state_ids:
state = yield self.store.get_events(state_ids.values())
defer.returnValue({
(e.type, e.state_key): e
@@ -535,57 +527,15 @@ class SyncHandler(object):
sync_result_builder, newly_joined_rooms, newly_joined_users
)
yield self._generate_sync_entry_for_to_device(sync_result_builder)
defer.returnValue(SyncResult(
presence=sync_result_builder.presence,
account_data=sync_result_builder.account_data,
joined=sync_result_builder.joined,
invited=sync_result_builder.invited,
archived=sync_result_builder.archived,
to_device=sync_result_builder.to_device,
next_batch=sync_result_builder.now_token,
))
@defer.inlineCallbacks
def _generate_sync_entry_for_to_device(self, sync_result_builder):
"""Generates the portion of the sync response. Populates
`sync_result_builder` with the result.
Args:
sync_result_builder(SyncResultBuilder)
Returns:
Deferred(dict): A dictionary containing the per room account data.
"""
user_id = sync_result_builder.sync_config.user.to_string()
device_id = sync_result_builder.sync_config.device_id
now_token = sync_result_builder.now_token
since_stream_id = 0
if sync_result_builder.since_token is not None:
since_stream_id = int(sync_result_builder.since_token.to_device_key)
if since_stream_id != int(now_token.to_device_key):
# We only delete messages when a new message comes in, but that's
# fine so long as we delete them at some point.
logger.debug("Deleting messages up to %d", since_stream_id)
yield self.store.delete_messages_for_device(
user_id, device_id, since_stream_id
)
logger.debug("Getting messages up to %d", now_token.to_device_key)
messages, stream_id = yield self.store.get_new_messages_for_device(
user_id, device_id, since_stream_id, now_token.to_device_key
)
logger.debug("Got messages up to %d: %r", stream_id, messages)
sync_result_builder.now_token = now_token.copy_and_replace(
"to_device_key", stream_id
)
sync_result_builder.to_device = messages
else:
sync_result_builder.to_device = []
@defer.inlineCallbacks
def _generate_sync_entry_for_account_data(self, sync_result_builder):
"""Generates the account data portion of the sync response. Populates
@@ -676,7 +626,7 @@ class SyncHandler(object):
extra_users_ids = set(newly_joined_users)
for room_id in newly_joined_rooms:
users = yield self.state.get_current_user_in_room(room_id)
users = yield self.store.get_users_in_room(room_id)
extra_users_ids.update(users)
extra_users_ids.discard(user.to_string())
@@ -816,13 +766,8 @@ class SyncHandler(object):
# the last sync (even if we have since left). This is to make sure
# we do send down the room, and with full state, where necessary
if room_id in joined_room_ids or has_join:
old_state_ids = yield self.get_state_at(room_id, since_token)
old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
old_mem_ev = None
if old_mem_ev_id:
old_mem_ev = yield self.store.get_event(
old_mem_ev_id, allow_none=True
)
old_state = yield self.get_state_at(room_id, since_token)
old_mem_ev = old_state.get((EventTypes.Member, user_id), None)
if not old_mem_ev or old_mem_ev.membership != Membership.JOIN:
newly_joined_rooms.append(room_id)
@@ -1114,25 +1059,27 @@ def _calculate_state(timeline_contains, timeline_start, previous, current):
Returns:
dict
"""
event_id_to_key = {
e: key
for key, e in itertools.chain(
timeline_contains.items(),
previous.items(),
timeline_start.items(),
current.items(),
event_id_to_state = {
e.event_id: e
for e in itertools.chain(
timeline_contains.values(),
previous.values(),
timeline_start.values(),
current.values(),
)
}
c_ids = set(e for e in current.values())
tc_ids = set(e for e in timeline_contains.values())
p_ids = set(e for e in previous.values())
ts_ids = set(e for e in timeline_start.values())
c_ids = set(e.event_id for e in current.values())
tc_ids = set(e.event_id for e in timeline_contains.values())
p_ids = set(e.event_id for e in previous.values())
ts_ids = set(e.event_id for e in timeline_start.values())
state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids
evs = (event_id_to_state[e] for e in state_ids)
return {
event_id_to_key[e]: e for e in state_ids
(e.type, e.state_key): e
for e in evs
}
@@ -1156,7 +1103,6 @@ class SyncResultBuilder(object):
self.joined = []
self.invited = []
self.archived = []
self.device = []
class RoomSyncResultBuilder(object):

View File

@@ -16,11 +16,9 @@
from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError
from synapse.util.logcontext import (
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
)
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.metrics import Measure
from synapse.types import UserID, get_domain_from_id
from synapse.types import UserID
import logging
@@ -42,7 +40,6 @@ class TypingHandler(object):
self.auth = hs.get_auth()
self.is_mine_id = hs.is_mine_id
self.notifier = hs.get_notifier()
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
@@ -167,19 +164,18 @@ class TypingHandler(object):
@defer.inlineCallbacks
def _push_update(self, room_id, user_id, typing):
users = yield self.state.get_current_user_in_room(room_id)
domains = set(get_domain_from_id(u) for u in users)
domains = yield self.store.get_joined_hosts_for_room(room_id)
deferreds = []
for domain in domains:
if domain == self.server_name:
preserve_fn(self._push_update_local)(
self._push_update_local(
room_id=room_id,
user_id=user_id,
typing=typing
)
else:
deferreds.append(preserve_fn(self.federation.send_edu)(
deferreds.append(self.federation.send_edu(
destination=domain,
edu_type="m.typing",
content={
@@ -189,9 +185,7 @@ class TypingHandler(object):
},
))
yield preserve_context_over_deferred(
defer.DeferredList(deferreds, consumeErrors=True)
)
yield defer.DeferredList(deferreds, consumeErrors=True)
@defer.inlineCallbacks
def _recv_edu(self, origin, content):
@@ -201,8 +195,7 @@ class TypingHandler(object):
# Check that the string is a valid user id
UserID.from_string(user_id)
users = yield self.state.get_current_user_in_room(room_id)
domains = set(get_domain_from_id(u) for u in users)
domains = yield self.store.get_joined_hosts_for_room(room_id)
if self.server_name in domains:
self._push_update_local(

View File

@@ -155,7 +155,9 @@ class MatrixFederationHttpClient(object):
time_out=timeout / 1000. if timeout else 60,
)
response = yield preserve_context_over_fn(send_request)
response = yield preserve_context_over_fn(
send_request,
)
log_result = "%d %s" % (response.code, response.phrase,)
break

View File

@@ -19,7 +19,6 @@ from synapse.api.errors import (
)
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.caches import intern_dict
from synapse.util.metrics import Measure
import synapse.metrics
import synapse.events
@@ -75,12 +74,12 @@ response_db_txn_duration = metrics.register_distribution(
_next_request_id = 0
def request_handler(include_metrics=False):
def request_handler(report_metrics=True):
"""Decorator for ``wrap_request_handler``"""
return lambda request_handler: wrap_request_handler(request_handler, include_metrics)
return lambda request_handler: wrap_request_handler(request_handler, report_metrics)
def wrap_request_handler(request_handler, include_metrics=False):
def wrap_request_handler(request_handler, report_metrics):
"""Wraps a method that acts as a request handler with the necessary logging
and exception handling.
@@ -104,17 +103,14 @@ def wrap_request_handler(request_handler, include_metrics=False):
_next_request_id += 1
with LoggingContext(request_id) as request_context:
with Measure(self.clock, "wrapped_request_handler"):
if report_metrics:
request_metrics = RequestMetrics()
request_metrics.start(self.clock, name=self.__class__.__name__)
request_metrics.start(self.clock)
request_context.request = request_id
with request.processing():
try:
with PreserveLoggingContext(request_context):
if include_metrics:
yield request_handler(self, request, request_metrics)
else:
yield request_handler(self, request)
except CodeMessageException as e:
code = e.code
@@ -149,11 +145,12 @@ def wrap_request_handler(request_handler, include_metrics=False):
)
finally:
try:
if report_metrics:
request_metrics.stop(
self.clock, request
self.clock, request, self.__class__.__name__
)
except Exception as e:
logger.warn("Failed to stop metrics: %r", e)
except:
pass
return wrapped_request_handler
@@ -208,7 +205,6 @@ class JsonResource(HttpServer, resource.Resource):
def register_paths(self, method, path_patterns, callback):
for path_pattern in path_patterns:
logger.debug("Registering for %s %s", method, path_pattern.pattern)
self.path_regexs.setdefault(method, []).append(
self._PathEntry(path_pattern, callback)
)
@@ -223,9 +219,9 @@ class JsonResource(HttpServer, resource.Resource):
# It does its own metric reporting because _async_render dispatches to
# a callback and it's the class name of that callback we want to report
# against rather than the JsonResource itself.
@request_handler(include_metrics=True)
@request_handler(report_metrics=False)
@defer.inlineCallbacks
def _async_render(self, request, request_metrics):
def _async_render(self, request):
""" This gets called from render() every time someone sends us a request.
This checks if anyone has registered a callback for that method and
path.
@@ -234,6 +230,9 @@ class JsonResource(HttpServer, resource.Resource):
self._send_response(request, 200, {})
return
request_metrics = RequestMetrics()
request_metrics.start(self.clock)
# Loop through all the registered callbacks to check if the method
# and path regex match
for path_entry in self.path_regexs.get(request.method, []):
@@ -247,6 +246,12 @@ class JsonResource(HttpServer, resource.Resource):
callback = path_entry.callback
servlet_instance = getattr(callback, "__self__", None)
if servlet_instance is not None:
servlet_classname = servlet_instance.__class__.__name__
else:
servlet_classname = "%r" % callback
kwargs = intern_dict({
name: urllib.unquote(value).decode("UTF-8") if value else value
for name, value in m.groupdict().items()
@@ -257,13 +262,10 @@ class JsonResource(HttpServer, resource.Resource):
code, response = callback_return
self._send_response(request, code, response)
servlet_instance = getattr(callback, "__self__", None)
if servlet_instance is not None:
servlet_classname = servlet_instance.__class__.__name__
else:
servlet_classname = "%r" % callback
request_metrics.name = servlet_classname
try:
request_metrics.stop(self.clock, request, servlet_classname)
except:
pass
return
@@ -295,12 +297,11 @@ class JsonResource(HttpServer, resource.Resource):
class RequestMetrics(object):
def start(self, clock, name):
def start(self, clock):
self.start = clock.time_msec()
self.start_context = LoggingContext.current_context()
self.name = name
def stop(self, clock, request):
def stop(self, clock, request, servlet_classname):
context = LoggingContext.current_context()
tag = ""
@@ -314,26 +315,26 @@ class RequestMetrics(object):
)
return
incoming_requests_counter.inc(request.method, self.name, tag)
incoming_requests_counter.inc(request.method, servlet_classname, tag)
response_timer.inc_by(
clock.time_msec() - self.start, request.method,
self.name, tag
servlet_classname, tag
)
ru_utime, ru_stime = context.get_resource_usage()
response_ru_utime.inc_by(
ru_utime, request.method, self.name, tag
ru_utime, request.method, servlet_classname, tag
)
response_ru_stime.inc_by(
ru_stime, request.method, self.name, tag
ru_stime, request.method, servlet_classname, tag
)
response_db_txn_count.inc_by(
context.db_txn_count, request.method, self.name, tag
context.db_txn_count, request.method, servlet_classname, tag
)
response_db_txn_duration.inc_by(
context.db_txn_duration, request.method, self.name, tag
context.db_txn_duration, request.method, servlet_classname, tag
)

View File

@@ -27,8 +27,7 @@ import gc
from twisted.internet import reactor
from .metric import (
CounterMetric, CallbackMetric, DistributionMetric, CacheMetric,
MemoryUsageMetric,
CounterMetric, CallbackMetric, DistributionMetric, CacheMetric
)
@@ -67,21 +66,6 @@ class Metrics(object):
return self._register(CacheMetric, *args, **kwargs)
def register_memory_metrics(hs):
try:
import psutil
process = psutil.Process()
process.memory_info().rss
except (ImportError, AttributeError):
logger.warn(
"psutil is not installed or incorrect version."
" Disabling memory metrics."
)
return
metric = MemoryUsageMetric(hs, psutil)
all_metrics.append(metric)
def get_metrics_for(pkg_name):
""" Returns a Metrics instance for conveniently creating metrics
namespaced with the given name prefix. """

View File

@@ -153,43 +153,3 @@ class CacheMetric(object):
"""%s:total{name="%s"} %d""" % (self.name, self.cache_name, total),
"""%s:size{name="%s"} %d""" % (self.name, self.cache_name, size),
]
class MemoryUsageMetric(object):
"""Keeps track of the current memory usage, using psutil.
The class will keep the current min/max/sum/counts of rss over the last
WINDOW_SIZE_SEC, by polling UPDATE_HZ times per second
"""
UPDATE_HZ = 2 # number of times to get memory per second
WINDOW_SIZE_SEC = 30 # the size of the window in seconds
def __init__(self, hs, psutil):
clock = hs.get_clock()
self.memory_snapshots = []
self.process = psutil.Process()
clock.looping_call(self._update_curr_values, 1000 / self.UPDATE_HZ)
def _update_curr_values(self):
max_size = self.UPDATE_HZ * self.WINDOW_SIZE_SEC
self.memory_snapshots.append(self.process.memory_info().rss)
self.memory_snapshots[:] = self.memory_snapshots[-max_size:]
def render(self):
if not self.memory_snapshots:
return []
max_rss = max(self.memory_snapshots)
min_rss = min(self.memory_snapshots)
sum_rss = sum(self.memory_snapshots)
len_rss = len(self.memory_snapshots)
return [
"process_psutil_rss:max %d" % max_rss,
"process_psutil_rss:min %d" % min_rss,
"process_psutil_rss:total %d" % sum_rss,
"process_psutil_rss:count %d" % len_rss,
]

View File

@@ -19,8 +19,7 @@ from synapse.api.errors import AuthError
from synapse.util.logutils import log_function
from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
from synapse.util.metrics import Measure
from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import StreamToken
from synapse.visibility import filter_events_for_client
import synapse.metrics
@@ -68,8 +67,10 @@ class _NotifierUserStream(object):
so that it can remove itself from the indexes in the Notifier class.
"""
def __init__(self, user_id, rooms, current_token, time_now_ms):
def __init__(self, user_id, rooms, current_token, time_now_ms,
appservice=None):
self.user_id = user_id
self.appservice = appservice
self.rooms = set(rooms)
self.current_token = current_token
self.last_notified_ms = time_now_ms
@@ -106,6 +107,11 @@ class _NotifierUserStream(object):
notifier.user_to_user_stream.pop(self.user_id)
if self.appservice:
notifier.appservice_to_user_streams.get(
self.appservice, set()
).discard(self)
def count_listeners(self):
return len(self.notify_deferred.observers())
@@ -136,6 +142,7 @@ class Notifier(object):
def __init__(self, hs):
self.user_to_user_stream = {}
self.room_to_user_streams = {}
self.appservice_to_user_streams = {}
self.event_sources = hs.get_event_sources()
self.store = hs.get_datastore()
@@ -161,6 +168,8 @@ class Notifier(object):
all_user_streams |= x
for x in self.user_to_user_stream.values():
all_user_streams.add(x)
for x in self.appservice_to_user_streams.values():
all_user_streams |= x
return sum(stream.count_listeners() for stream in all_user_streams)
metrics.register_callback("listeners", count_listeners)
@@ -173,8 +182,11 @@ class Notifier(object):
"users",
lambda: len(self.user_to_user_stream),
)
metrics.register_callback(
"appservices",
lambda: count(bool, self.appservice_to_user_streams.values()),
)
@preserve_fn
def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
extra_users=[]):
""" Used by handlers to inform the notifier something has happened
@@ -196,7 +208,6 @@ class Notifier(object):
self.notify_replication()
@preserve_fn
def _notify_pending_new_room_events(self, max_room_stream_id):
"""Notify for the room events that were queued waiting for a previous
event to be persisted.
@@ -214,11 +225,24 @@ class Notifier(object):
else:
self._on_new_room_event(event, room_stream_id, extra_users)
@preserve_fn
def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
"""Notify any user streams that are interested in this room event"""
# poke any interested application service.
self.appservice_handler.notify_interested_services(room_stream_id)
self.appservice_handler.notify_interested_services(event)
app_streams = set()
for appservice in self.appservice_to_user_streams:
# TODO (kegan): Redundant appservice listener checks?
# App services will already be in the room_to_user_streams set, but
# that isn't enough. They need to be checked here in order to
# receive *invites* for users they are interested in. Does this
# make the room_to_user_streams check somewhat obselete?
if appservice.is_interested(event):
app_user_streams = self.appservice_to_user_streams.get(
appservice, set()
)
app_streams |= app_user_streams
if event.type == EventTypes.Member and event.membership == Membership.JOIN:
self._user_joined_room(event.state_key, event.room_id)
@@ -227,16 +251,16 @@ class Notifier(object):
"room_key", room_stream_id,
users=extra_users,
rooms=[event.room_id],
extra_streams=app_streams,
)
@preserve_fn
def on_new_event(self, stream_key, new_token, users=[], rooms=[]):
def on_new_event(self, stream_key, new_token, users=[], rooms=[],
extra_streams=set()):
""" Used to inform listeners that something has happend event wise.
Will wake up all listeners for the given users and rooms.
"""
with PreserveLoggingContext():
with Measure(self.clock, "on_new_event"):
user_streams = set()
for user in users:
@@ -256,7 +280,6 @@ class Notifier(object):
self.notify_replication()
@preserve_fn
def on_new_replication_data(self):
"""Used to inform replication listeners that something has happend
without waking up any of the normal user event streams"""
@@ -271,6 +294,7 @@ class Notifier(object):
"""
user_stream = self.user_to_user_stream.get(user_id)
if user_stream is None:
appservice = yield self.store.get_app_service_by_user_id(user_id)
current_token = yield self.event_sources.get_current_token()
if room_ids is None:
rooms = yield self.store.get_rooms_for_user(user_id)
@@ -278,6 +302,7 @@ class Notifier(object):
user_stream = _NotifierUserStream(
user_id=user_id,
rooms=room_ids,
appservice=appservice,
current_token=current_token,
time_now_ms=self.clock.time_msec(),
)
@@ -423,8 +448,7 @@ class Notifier(object):
def _is_world_readable(self, room_id):
state = yield self.state_handler.get_current_state(
room_id,
EventTypes.RoomHistoryVisibility,
"",
EventTypes.RoomHistoryVisibility
)
if state and "history_visibility" in state.content:
defer.returnValue(state.content["history_visibility"] == "world_readable")
@@ -453,6 +477,11 @@ class Notifier(object):
s = self.room_to_user_streams.setdefault(room, set())
s.add(user_stream)
if user_stream.appservice:
self.appservice_to_user_stream.setdefault(
user_stream.appservice, set()
).add(user_stream)
def _user_joined_room(self, user_id, room_id):
new_user_stream = self.user_to_user_stream.get(user_id)
if new_user_stream is not None:

View File

@@ -38,14 +38,13 @@ class ActionGenerator:
@defer.inlineCallbacks
def handle_push_actions_for_event(self, event, context):
with Measure(self.clock, "evaluator_for_event"):
with Measure(self.clock, "handle_push_actions_for_event"):
bulk_evaluator = yield evaluator_for_event(
event, self.hs, self.store, context
event, self.hs, self.store, context.current_state
)
with Measure(self.clock, "action_for_event_by_user"):
actions_by_user = yield bulk_evaluator.action_for_event_by_user(
event, context
event, context.current_state
)
context.push_actions = [

View File

@@ -217,27 +217,6 @@ BASE_APPEND_OVERRIDE_RULES = [
'dont_notify'
]
},
# This was changed from underride to override so it's closer in priority
# to the content rules where the user name highlight rule lives. This
# way a room rule is lower priority than both but a custom override rule
# is higher priority than both.
{
'rule_id': 'global/override/.m.rule.contains_display_name',
'conditions': [
{
'kind': 'contains_display_name'
}
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'default'
}, {
'set_tweak': 'highlight'
}
]
},
]
@@ -263,6 +242,23 @@ BASE_APPEND_UNDERRIDE_RULES = [
}
]
},
{
'rule_id': 'global/underride/.m.rule.contains_display_name',
'conditions': [
{
'kind': 'contains_display_name'
}
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'default'
}, {
'set_tweak': 'highlight'
}
]
},
{
'rule_id': 'global/underride/.m.rule.room_one_to_one',
'conditions': [

View File

@@ -19,8 +19,8 @@ from twisted.internet import defer
from .push_rule_evaluator import PushRuleEvaluatorForEvent
from synapse.api.constants import EventTypes
from synapse.visibility import filter_events_for_clients_context
from synapse.api.constants import EventTypes, Membership
from synapse.visibility import filter_events_for_clients
logger = logging.getLogger(__name__)
@@ -36,11 +36,35 @@ def _get_rules(room_id, user_ids, store):
@defer.inlineCallbacks
def evaluator_for_event(event, hs, store, context):
rules_by_user = yield store.bulk_get_push_rules_for_room(
event, context
def evaluator_for_event(event, hs, store, current_state):
room_id = event.room_id
# We also will want to generate notifs for other people in the room so
# their unread countss are correct in the event stream, but to avoid
# generating them for bot / AS users etc, we only do so for people who've
# sent a read receipt into the room.
local_users_in_room = set(
e.state_key for e in current_state.values()
if e.type == EventTypes.Member and e.membership == Membership.JOIN
and hs.is_mine_id(e.state_key)
)
# users in the room who have pushers need to get push rules run because
# that's how their pushers work
if_users_with_pushers = yield store.get_if_users_have_pushers(
local_users_in_room
)
user_ids = set(
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
)
users_with_receipts = yield store.get_users_with_read_receipts_in_room(room_id)
# any users with pushers must be ours: they have pushers
for uid in users_with_receipts:
if uid in local_users_in_room:
user_ids.add(uid)
# if this event is an invite event, we may need to run rules for the user
# who's been invited, otherwise they won't get told they've been invited
if event.type == 'm.room.member' and event.content['membership'] == 'invite':
@@ -48,12 +72,12 @@ def evaluator_for_event(event, hs, store, context):
if invited_user and hs.is_mine_id(invited_user):
has_pusher = yield store.user_has_pusher(invited_user)
if has_pusher:
rules_by_user[invited_user] = yield store.get_push_rules_for_user(
invited_user
)
user_ids.add(invited_user)
rules_by_user = yield _get_rules(room_id, user_ids, store)
defer.returnValue(BulkPushRuleEvaluator(
event.room_id, rules_by_user, store
room_id, rules_by_user, user_ids, store
))
@@ -66,13 +90,14 @@ class BulkPushRuleEvaluator:
the same logic to run the actual rules, but could be optimised further
(see https://matrix.org/jira/browse/SYN-562)
"""
def __init__(self, room_id, rules_by_user, store):
def __init__(self, room_id, rules_by_user, users_in_room, store):
self.room_id = room_id
self.rules_by_user = rules_by_user
self.users_in_room = users_in_room
self.store = store
@defer.inlineCallbacks
def action_for_event_by_user(self, event, context):
def action_for_event_by_user(self, event, current_state):
actions_by_user = {}
# None of these users can be peeking since this list of users comes
@@ -82,25 +107,27 @@ class BulkPushRuleEvaluator:
(u, False) for u in self.rules_by_user.keys()
]
filtered_by_user = yield filter_events_for_clients_context(
self.store, user_tuples, [event], {event.event_id: context}
filtered_by_user = yield filter_events_for_clients(
self.store, user_tuples, [event], {event.event_id: current_state}
)
room_members = yield self.store.get_joined_users_from_context(
event, context
room_members = set(
e.state_key for e in current_state.values()
if e.type == EventTypes.Member and e.membership == Membership.JOIN
)
evaluator = PushRuleEvaluatorForEvent(event, len(room_members))
condition_cache = {}
display_names = {}
for ev in current_state.values():
nm = ev.content.get("displayname", None)
if nm and ev.type == EventTypes.Member:
display_names[ev.state_key] = nm
for uid, rules in self.rules_by_user.items():
display_name = None
member_ev_id = context.current_state_ids.get((EventTypes.Member, uid))
if member_ev_id:
member_ev = yield self.store.get_event(member_ev_id, allow_none=True)
if member_ev:
display_name = member_ev.content.get("displayname", None)
display_name = display_names.get(uid, None)
filtered = filtered_by_user[uid]
if len(filtered) == 0:

View File

@@ -14,7 +14,6 @@
# limitations under the License.
from twisted.internet import defer, reactor
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
import logging
@@ -93,11 +92,7 @@ class EmailPusher(object):
def on_stop(self):
if self.timed_call:
try:
self.timed_call.cancel()
except (AlreadyCalled, AlreadyCancelled):
pass
self.timed_call = None
@defer.inlineCallbacks
def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
@@ -145,8 +140,9 @@ class EmailPusher(object):
being run.
"""
start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
fn = self.store.get_unread_push_actions_for_user_in_range_for_email
unprocessed = yield fn(self.user_id, start, self.max_stream_ordering)
unprocessed = yield self.store.get_unread_push_actions_for_user_in_range(
self.user_id, start, self.max_stream_ordering
)
soonest_due_at = None
@@ -194,10 +190,7 @@ class EmailPusher(object):
soonest_due_at = should_notify_at
if self.timed_call is not None:
try:
self.timed_call.cancel()
except (AlreadyCalled, AlreadyCancelled):
pass
self.timed_call = None
if soonest_due_at is not None:

View File

@@ -16,7 +16,6 @@
from synapse.push import PusherConfigException
from twisted.internet import defer, reactor
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
import logging
import push_rule_evaluator
@@ -110,11 +109,7 @@ class HttpPusher(object):
def on_stop(self):
if self.timed_call:
try:
self.timed_call.cancel()
except (AlreadyCalled, AlreadyCancelled):
pass
self.timed_call = None
@defer.inlineCallbacks
def _process(self):
@@ -146,8 +141,7 @@ class HttpPusher(object):
run once per pusher.
"""
fn = self.store.get_unread_push_actions_for_user_in_range_for_http
unprocessed = yield fn(
unprocessed = yield self.store.get_unread_push_actions_for_user_in_range(
self.user_id, self.last_stream_ordering, self.max_stream_ordering
)
@@ -245,7 +239,7 @@ class HttpPusher(object):
@defer.inlineCallbacks
def _build_notification_dict(self, event, tweaks, badge):
ctx = yield push_tools.get_context_for_event(
self.store, self.state_handler, event, self.user_id
self.state_handler, event, self.user_id
)
d = {

View File

@@ -22,7 +22,7 @@ from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from synapse.util.async import concurrently_execute
from synapse.push.presentable_names import (
from synapse.util.presentable_names import (
calculate_room_name, name_from_member_event, descriptor_from_member_events
)
from synapse.types import UserID
@@ -139,7 +139,7 @@ class Mailer(object):
@defer.inlineCallbacks
def _fetch_room_state(room_id):
room_state = yield self.state_handler.get_current_state_ids(room_id)
room_state = yield self.state_handler.get_current_state(room_id)
state_by_room[room_id] = room_state
# Run at most 3 of these at once: sync does 10 at a time but email
@@ -159,12 +159,11 @@ class Mailer(object):
)
rooms.append(roomvars)
reason['room_name'] = yield calculate_room_name(
self.store, state_by_room[reason['room_id']], user_id,
fallback_to_members=True
reason['room_name'] = calculate_room_name(
state_by_room[reason['room_id']], user_id, fallback_to_members=True
)
summary_text = yield self.make_summary_text(
summary_text = self.make_summary_text(
notifs_by_room, state_by_room, notif_events, user_id, reason
)
@@ -204,15 +203,12 @@ class Mailer(object):
)
@defer.inlineCallbacks
def get_room_vars(self, room_id, user_id, notifs, notif_events, room_state_ids):
my_member_event_id = room_state_ids[("m.room.member", user_id)]
my_member_event = yield self.store.get_event(my_member_event_id)
def get_room_vars(self, room_id, user_id, notifs, notif_events, room_state):
my_member_event = room_state[("m.room.member", user_id)]
is_invite = my_member_event.content["membership"] == "invite"
room_name = yield calculate_room_name(self.store, room_state_ids, user_id)
room_vars = {
"title": room_name,
"title": calculate_room_name(room_state, user_id),
"hash": string_ordinal_total(room_id), # See sender avatar hash
"notifs": [],
"invite": is_invite,
@@ -222,7 +218,7 @@ class Mailer(object):
if not is_invite:
for n in notifs:
notifvars = yield self.get_notif_vars(
n, user_id, notif_events[n['event_id']], room_state_ids
n, user_id, notif_events[n['event_id']], room_state
)
# merge overlapping notifs together.
@@ -247,7 +243,7 @@ class Mailer(object):
defer.returnValue(room_vars)
@defer.inlineCallbacks
def get_notif_vars(self, notif, user_id, notif_event, room_state_ids):
def get_notif_vars(self, notif, user_id, notif_event, room_state):
results = yield self.store.get_events_around(
notif['room_id'], notif['event_id'],
before_limit=CONTEXT_BEFORE, after_limit=CONTEXT_AFTER
@@ -265,19 +261,17 @@ class Mailer(object):
the_events.append(notif_event)
for event in the_events:
messagevars = yield self.get_message_vars(notif, event, room_state_ids)
messagevars = self.get_message_vars(notif, event, room_state)
if messagevars is not None:
ret['messages'].append(messagevars)
defer.returnValue(ret)
@defer.inlineCallbacks
def get_message_vars(self, notif, event, room_state_ids):
def get_message_vars(self, notif, event, room_state):
if event.type != EventTypes.Message:
return
return None
sender_state_event_id = room_state_ids[("m.room.member", event.sender)]
sender_state_event = yield self.store.get_event(sender_state_event_id)
sender_state_event = room_state[("m.room.member", event.sender)]
sender_name = name_from_member_event(sender_state_event)
sender_avatar_url = sender_state_event.content.get("avatar_url")
@@ -305,7 +299,7 @@ class Mailer(object):
if "body" in event.content:
ret["body_text_plain"] = event.content["body"]
defer.returnValue(ret)
return ret
def add_text_message_vars(self, messagevars, event):
msgformat = event.content.get("format")
@@ -327,7 +321,6 @@ class Mailer(object):
return messagevars
@defer.inlineCallbacks
def make_summary_text(self, notifs_by_room, state_by_room,
notif_events, user_id, reason):
if len(notifs_by_room) == 1:
@@ -337,8 +330,8 @@ class Mailer(object):
# If the room has some kind of name, use it, but we don't
# want the generated-from-names one here otherwise we'll
# end up with, "new message from Bob in the Bob room"
room_name = yield calculate_room_name(
self.store, state_by_room[room_id], user_id, fallback_to_members=False
room_name = calculate_room_name(
state_by_room[room_id], user_id, fallback_to_members=False
)
my_member_event = state_by_room[room_id][("m.room.member", user_id)]
@@ -349,16 +342,16 @@ class Mailer(object):
inviter_name = name_from_member_event(inviter_member_event)
if room_name is None:
defer.returnValue(INVITE_FROM_PERSON % {
return INVITE_FROM_PERSON % {
"person": inviter_name,
"app": self.app_name
})
}
else:
defer.returnValue(INVITE_FROM_PERSON_TO_ROOM % {
return INVITE_FROM_PERSON_TO_ROOM % {
"person": inviter_name,
"room": room_name,
"app": self.app_name,
})
}
sender_name = None
if len(notifs_by_room[room_id]) == 1:
@@ -369,24 +362,24 @@ class Mailer(object):
sender_name = name_from_member_event(state_event)
if sender_name is not None and room_name is not None:
defer.returnValue(MESSAGE_FROM_PERSON_IN_ROOM % {
return MESSAGE_FROM_PERSON_IN_ROOM % {
"person": sender_name,
"room": room_name,
"app": self.app_name,
})
}
elif sender_name is not None:
defer.returnValue(MESSAGE_FROM_PERSON % {
return MESSAGE_FROM_PERSON % {
"person": sender_name,
"app": self.app_name,
})
}
else:
# There's more than one notification for this room, so just
# say there are several
if room_name is not None:
defer.returnValue(MESSAGES_IN_ROOM % {
return MESSAGES_IN_ROOM % {
"room": room_name,
"app": self.app_name,
})
}
else:
# If the room doesn't have a name, say who the messages
# are from explicitly to avoid, "messages in the Bob room"
@@ -395,22 +388,22 @@ class Mailer(object):
for n in notifs_by_room[room_id]
]))
defer.returnValue(MESSAGES_FROM_PERSON % {
return MESSAGES_FROM_PERSON % {
"person": descriptor_from_member_events([
state_by_room[room_id][("m.room.member", s)]
for s in sender_ids
]),
"app": self.app_name,
})
}
else:
# Stuff's happened in multiple different rooms
# ...but we still refer to the 'reason' room which triggered the mail
if reason['room_name'] is not None:
defer.returnValue(MESSAGES_IN_ROOM_AND_OTHERS % {
return MESSAGES_IN_ROOM_AND_OTHERS % {
"room": reason['room_name'],
"app": self.app_name,
})
}
else:
# If the reason room doesn't have a name, say who the messages
# are from explicitly to avoid, "messages in the Bob room"
@@ -419,13 +412,13 @@ class Mailer(object):
for n in notifs_by_room[reason['room_id']]
]))
defer.returnValue(MESSAGES_FROM_PERSON_AND_OTHERS % {
return MESSAGES_FROM_PERSON_AND_OTHERS % {
"person": descriptor_from_member_events([
state_by_room[reason['room_id']][("m.room.member", s)]
for s in sender_ids
]),
"app": self.app_name,
})
}
def make_room_link(self, room_id):
# need /beta for Universal Links to work on iOS

View File

@@ -14,18 +14,17 @@
# limitations under the License.
from twisted.internet import defer
from synapse.push.presentable_names import (
from synapse.util.presentable_names import (
calculate_room_name, name_from_member_event
)
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
@defer.inlineCallbacks
def get_badge_count(store, user_id):
invites, joins = yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(store.get_invited_rooms_for_user)(user_id),
preserve_fn(store.get_rooms_for_user)(user_id),
], consumeErrors=True))
invites, joins = yield defer.gatherResults([
store.get_invited_rooms_for_user(user_id),
store.get_rooms_for_user(user_id),
], consumeErrors=True)
my_receipts_by_room = yield store.get_receipts_for_user(
user_id, "m.read",
@@ -49,22 +48,21 @@ def get_badge_count(store, user_id):
@defer.inlineCallbacks
def get_context_for_event(store, state_handler, ev, user_id):
def get_context_for_event(state_handler, ev, user_id):
ctx = {}
room_state_ids = yield state_handler.get_current_state_ids(ev.room_id)
room_state = yield state_handler.get_current_state(ev.room_id)
# we no longer bother setting room_alias, and make room_name the
# human-readable name instead, be that m.room.name, an alias or
# human-readable name instead, be that m.room.namer, an alias or
# a list of people in the room
name = yield calculate_room_name(
store, room_state_ids, user_id, fallback_to_single_member=False
name = calculate_room_name(
room_state, user_id, fallback_to_single_member=False
)
if name:
ctx['name'] = name
sender_state_event_id = room_state_ids[("m.room.member", ev.sender)]
sender_state_event = yield store.get_event(sender_state_event_id)
sender_state_event = room_state[("m.room.member", ev.sender)]
ctx['sender_display_name'] = name_from_member_event(sender_state_event)
defer.returnValue(ctx)

View File

@@ -17,7 +17,7 @@
from twisted.internet import defer
import pusher
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.util.logcontext import preserve_fn
from synapse.util.async import run_on_reactor
import logging
@@ -102,14 +102,14 @@ class PusherPool:
yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks
def remove_pushers_by_user(self, user_id, except_access_token_id=None):
def remove_pushers_by_user(self, user_id, except_token_ids=[]):
all = yield self.store.get_all_pushers()
logger.info(
"Removing all pushers for user %s except access tokens id %r",
user_id, except_access_token_id
"Removing all pushers for user %s except access tokens ids %r",
user_id, except_token_ids
)
for p in all:
if p['user_name'] == user_id and p['access_token'] != except_access_token_id:
if p['user_name'] == user_id and p['access_token'] not in except_token_ids:
logger.info(
"Removing pusher for app id %s, pushkey %s, user %s",
p['app_id'], p['pushkey'], p['user_name']
@@ -130,12 +130,10 @@ class PusherPool:
if u in self.pushers:
for p in self.pushers[u].values():
deferreds.append(
preserve_fn(p.on_new_notifications)(
min_stream_id, max_stream_id
)
p.on_new_notifications(min_stream_id, max_stream_id)
)
yield preserve_context_over_deferred(defer.gatherResults(deferreds))
yield defer.gatherResults(deferreds)
except:
logger.exception("Exception in pusher on_new_notifications")
@@ -157,10 +155,10 @@ class PusherPool:
if u in self.pushers:
for p in self.pushers[u].values():
deferreds.append(
preserve_fn(p.on_new_receipts)(min_stream_id, max_stream_id)
p.on_new_receipts(min_stream_id, max_stream_id)
)
yield preserve_context_over_deferred(defer.gatherResults(deferreds))
yield defer.gatherResults(deferreds)
except:
logger.exception("Exception in pusher on_new_receipts")

View File

@@ -51,9 +51,6 @@ CONDITIONAL_REQUIREMENTS = {
"ldap": {
"ldap3>=1.0": ["ldap3>=1.0"],
},
"psutil": {
"psutil>=2.0.0": ["psutil>=2.0.0"],
},
}

View File

@@ -40,8 +40,7 @@ STREAM_NAMES = (
("backfill",),
("push_rules",),
("pushers",),
("caches",),
("to_device",),
("state",),
)
@@ -71,7 +70,6 @@ class ReplicationResource(Resource):
* "backfill": Old events that have been backfilled from other servers.
* "push_rules": Per user changes to push rules.
* "pushers": Per user changes to their pushers.
* "caches": Cache invalidations.
The API takes two additional query parameters:
@@ -130,7 +128,7 @@ class ReplicationResource(Resource):
backfill_token = yield self.store.get_current_backfill_token()
push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
pushers_token = self.store.get_pushers_stream_token()
caches_token = self.store.get_cache_stream_token()
state_token = self.store.get_state_stream_token()
defer.returnValue(_ReplicationToken(
room_stream_token,
@@ -141,9 +139,7 @@ class ReplicationResource(Resource):
backfill_token,
push_rules_token,
pushers_token,
0, # State stream is no longer a thing
caches_token,
int(stream_token.to_device_key),
state_token,
))
@request_handler()
@@ -191,8 +187,7 @@ class ReplicationResource(Resource):
yield self.receipts(writer, current_token, limit, request_streams)
yield self.push_rules(writer, current_token, limit, request_streams)
yield self.pushers(writer, current_token, limit, request_streams)
yield self.caches(writer, current_token, limit, request_streams)
yield self.to_device(writer, current_token, limit, request_streams)
yield self.state(writer, current_token, limit, request_streams)
self.streams(writer, current_token, request_streams)
logger.info("Replicated %d rows", writer.total)
@@ -366,31 +361,22 @@ class ReplicationResource(Resource):
))
@defer.inlineCallbacks
def caches(self, writer, current_token, limit, request_streams):
current_position = current_token.caches
def state(self, writer, current_token, limit, request_streams):
current_position = current_token.state
caches = request_streams.get("caches")
state = request_streams.get("state")
if caches is not None:
updated_caches = yield self.store.get_all_updated_caches(
caches, current_position, limit
if state is not None:
state_groups, state_group_state = (
yield self.store.get_all_new_state_groups(
state, current_position, limit
)
writer.write_header_and_rows("caches", updated_caches, (
"position", "cache_func", "keys", "invalidation_ts"
)
writer.write_header_and_rows("state_groups", state_groups, (
"position", "room_id", "event_id"
))
@defer.inlineCallbacks
def to_device(self, writer, current_token, limit, request_streams):
current_position = current_token.to_device
to_device = request_streams.get("to_device")
if to_device is not None:
to_device_rows = yield self.store.get_all_new_device_messages(
to_device, current_position, limit
)
writer.write_header_and_rows("to_device", to_device_rows, (
"position", "user_id", "device_id", "message_json"
writer.write_header_and_rows("state_group_state", state_group_state, (
"position", "type", "state_key", "event_id"
))
@@ -421,7 +407,7 @@ class _Writer(object):
class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
"events", "presence", "typing", "receipts", "account_data", "backfill",
"push_rules", "pushers", "state", "caches", "to_device",
"push_rules", "pushers", "state"
))):
__slots__ = []

View File

@@ -14,43 +14,15 @@
# limitations under the License.
from synapse.storage._base import SQLBaseStore
from synapse.storage.engines import PostgresEngine
from twisted.internet import defer
from ._slaved_id_tracker import SlavedIdTracker
import logging
logger = logging.getLogger(__name__)
class BaseSlavedStore(SQLBaseStore):
def __init__(self, db_conn, hs):
super(BaseSlavedStore, self).__init__(hs)
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = SlavedIdTracker(
db_conn, "cache_invalidation_stream", "stream_id",
)
else:
self._cache_id_gen = None
def stream_positions(self):
pos = {}
if self._cache_id_gen:
pos["caches"] = self._cache_id_gen.get_current_token()
return pos
return {}
def process_replication(self, result):
stream = result.get("caches")
if stream:
for row in stream["rows"]:
(
position, cache_func, keys, invalidation_ts,
) = row
try:
getattr(self, cache_func).invalidate(tuple(keys))
except AttributeError:
logger.info("Got unexpected cache_func: %r", cache_func)
self._cache_id_gen.advance(int(stream["position"]))
return defer.succeed(None)

View File

@@ -28,13 +28,3 @@ class SlavedApplicationServiceStore(BaseSlavedStore):
get_app_service_by_token = DataStore.get_app_service_by_token.__func__
get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__
get_app_services = DataStore.get_app_services.__func__
get_new_events_for_appservice = DataStore.get_new_events_for_appservice.__func__
create_appservice_txn = DataStore.create_appservice_txn.__func__
get_appservices_by_state = DataStore.get_appservices_by_state.__func__
get_oldest_unsent_txn = DataStore.get_oldest_unsent_txn.__func__
_get_last_txn = DataStore._get_last_txn.__func__
complete_appservice_txn = DataStore.complete_appservice_txn.__func__
get_appservice_state = DataStore.get_appservice_state.__func__
set_appservice_last_pos = DataStore.set_appservice_last_pos.__func__
set_appservice_state = DataStore.set_appservice_state.__func__

View File

@@ -1,42 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
class SlavedDeviceInboxStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedDeviceInboxStore, self).__init__(db_conn, hs)
self._device_inbox_id_gen = SlavedIdTracker(
db_conn, "device_inbox", "stream_id",
)
get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__
get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__
delete_messages_for_device = DataStore.delete_messages_for_device.__func__
def stream_positions(self):
result = super(SlavedDeviceInboxStore, self).stream_positions()
result["to_device"] = self._device_inbox_id_gen.get_current_token()
return result
def process_replication(self, result):
stream = result.get("to_device")
if stream:
self._device_inbox_id_gen.advance(int(stream["position"]))
return super(SlavedDeviceInboxStore, self).process_replication(result)

View File

@@ -1,23 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from synapse.storage.directory import DirectoryStore
class DirectoryStore(BaseSlavedStore):
get_aliases_for_room = DirectoryStore.__dict__[
"get_aliases_for_room"
]

View File

@@ -93,11 +93,8 @@ class SlavedEventStore(BaseSlavedStore):
StreamStore.__dict__["get_recent_event_ids_for_room"]
)
get_unread_push_actions_for_user_in_range_for_http = (
DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__
)
get_unread_push_actions_for_user_in_range_for_email = (
DataStore.get_unread_push_actions_for_user_in_range_for_email.__func__
get_unread_push_actions_for_user_in_range = (
DataStore.get_unread_push_actions_for_user_in_range.__func__
)
get_push_action_users_in_range = (
DataStore.get_push_action_users_in_range.__func__
@@ -120,21 +117,10 @@ class SlavedEventStore(BaseSlavedStore):
get_state_for_event = DataStore.get_state_for_event.__func__
get_state_for_events = DataStore.get_state_for_events.__func__
get_state_groups = DataStore.get_state_groups.__func__
get_state_groups_ids = DataStore.get_state_groups_ids.__func__
get_state_ids_for_event = DataStore.get_state_ids_for_event.__func__
get_state_ids_for_events = DataStore.get_state_ids_for_events.__func__
get_joined_users_from_state = DataStore.get_joined_users_from_state.__func__
get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__
_get_joined_users_from_context = (
RoomMemberStore.__dict__["_get_joined_users_from_context"]
)
get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__
get_room_events_stream_for_rooms = (
DataStore.get_room_events_stream_for_rooms.__func__
)
is_host_joined = DataStore.is_host_joined.__func__
_is_host_joined = RoomMemberStore.__dict__["_is_host_joined"]
get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__
_set_before_and_after = staticmethod(DataStore._set_before_and_after)
@@ -156,15 +142,6 @@ class SlavedEventStore(BaseSlavedStore):
_get_events_around_txn = DataStore._get_events_around_txn.__func__
_get_some_state_from_cache = DataStore._get_some_state_from_cache.__func__
get_backfill_events = DataStore.get_backfill_events.__func__
_get_backfill_events = DataStore._get_backfill_events.__func__
get_missing_events = DataStore.get_missing_events.__func__
_get_missing_events = DataStore._get_missing_events.__func__
get_auth_chain = DataStore.get_auth_chain.__func__
get_auth_chain_ids = DataStore.get_auth_chain_ids.__func__
_get_auth_chain_ids_txn = DataStore._get_auth_chain_ids_txn.__func__
def stream_positions(self):
result = super(SlavedEventStore, self).stream_positions()
result["events"] = self._stream_id_gen.get_current_token()
@@ -222,6 +199,7 @@ class SlavedEventStore(BaseSlavedStore):
self._get_current_state_for_key.invalidate_all()
self.get_rooms_for_user.invalidate_all()
self.get_users_in_room.invalidate((event.room_id,))
# self.get_joined_hosts_for_room.invalidate((event.room_id,))
self._invalidate_get_event_cache(event.event_id)
@@ -245,6 +223,7 @@ class SlavedEventStore(BaseSlavedStore):
if event.type == EventTypes.Member:
self.get_rooms_for_user.invalidate((event.state_key,))
# self.get_joined_hosts_for_room.invalidate((event.room_id,))
self.get_users_in_room.invalidate((event.room_id,))
self._membership_stream_cache.entity_has_changed(
event.state_key, event.internal_metadata.stream_ordering

View File

@@ -1,33 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from synapse.storage import DataStore
from synapse.storage.keys import KeyStore
class SlavedKeyStore(BaseSlavedStore):
_get_server_verify_key = KeyStore.__dict__[
"_get_server_verify_key"
]
get_server_verify_keys = DataStore.get_server_verify_keys.__func__
store_server_verify_key = DataStore.store_server_verify_key.__func__
get_server_certificate = DataStore.get_server_certificate.__func__
store_server_certificate = DataStore.store_server_certificate.__func__
get_server_keys_json = DataStore.get_server_keys_json.__func__
store_server_keys_json = DataStore.store_server_keys_json.__func__

View File

@@ -25,9 +25,6 @@ class SlavedRegistrationStore(BaseSlavedStore):
# TODO: use the cached version and invalidate deleted tokens
get_user_by_access_token = RegistrationStore.__dict__[
"get_user_by_access_token"
]
].orig
_query_for_auth = DataStore._query_for_auth.__func__
get_user_by_id = RegistrationStore.__dict__[
"get_user_by_id"
]

View File

@@ -1,21 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from synapse.storage import DataStore
class RoomStore(BaseSlavedStore):
get_public_room_ids = DataStore.get_public_room_ids.__func__

View File

@@ -1,30 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import BaseSlavedStore
from synapse.storage import DataStore
from synapse.storage.transactions import TransactionStore
class TransactionStore(BaseSlavedStore):
get_destination_retry_timings = TransactionStore.__dict__[
"get_destination_retry_timings"
].orig
_get_destination_retry_timings = DataStore._get_destination_retry_timings.__func__
# For now, don't record the destination rety timings
def set_destination_retry_timings(*args, **kwargs):
return defer.succeed(None)

View File

@@ -46,10 +46,6 @@ from synapse.rest.client.v2_alpha import (
account_data,
report_event,
openid,
notifications,
devices,
thirdparty,
sendtodevice,
)
from synapse.http.server import JsonResource
@@ -94,7 +90,3 @@ class ClientRestResource(JsonResource):
account_data.register_servlets(hs, client_resource)
report_event.register_servlets(hs, client_resource)
openid.register_servlets(hs, client_resource)
notifications.register_servlets(hs, client_resource)
devices.register_servlets(hs, client_resource)
thirdparty.register_servlets(hs, client_resource)
sendtodevice.register_servlets(hs, client_resource)

View File

@@ -28,10 +28,6 @@ logger = logging.getLogger(__name__)
class WhoisRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)")
def __init__(self, hs):
super(WhoisRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
target_user = UserID.from_string(user_id)
@@ -86,10 +82,6 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
"/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
)
def __init__(self, hs):
super(PurgeHistoryRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_POST(self, request, room_id, event_id):
requester = yield self.auth.get_user_by_req(request)

View File

@@ -52,11 +52,8 @@ class ClientV1RestServlet(RestServlet):
"""
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer):
"""
self.hs = hs
self.handlers = hs.get_handlers()
self.builder_factory = hs.get_event_builder_factory()
self.auth = hs.get_v1auth()
self.txns = HttpTransactionStore()

View File

@@ -36,10 +36,6 @@ def register_servlets(hs, http_server):
class ClientDirectoryServer(ClientV1RestServlet):
PATTERNS = client_path_patterns("/directory/room/(?P<room_alias>[^/]*)$")
def __init__(self, hs):
super(ClientDirectoryServer, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, room_alias):
room_alias = RoomAlias.from_string(room_alias)
@@ -150,7 +146,6 @@ class ClientDirectoryListServer(ClientV1RestServlet):
def __init__(self, hs):
super(ClientDirectoryListServer, self).__init__(hs)
self.store = hs.get_datastore()
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, room_id):

View File

@@ -32,10 +32,6 @@ class EventStreamRestServlet(ClientV1RestServlet):
DEFAULT_LONGPOLL_TIME_MS = 30000
def __init__(self, hs):
super(EventStreamRestServlet, self).__init__(hs)
self.event_stream_handler = hs.get_event_stream_handler()
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(
@@ -50,6 +46,7 @@ class EventStreamRestServlet(ClientV1RestServlet):
if "room_id" in request.args:
room_id = request.args["room_id"][0]
handler = self.handlers.event_stream_handler
pagin_config = PaginationConfig.from_request(request)
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
if "timeout" in request.args:
@@ -60,7 +57,7 @@ class EventStreamRestServlet(ClientV1RestServlet):
as_client_event = "raw" not in request.args
chunk = yield self.event_stream_handler.get_stream(
chunk = yield handler.get_stream(
requester.user.to_string(),
pagin_config,
timeout=timeout,
@@ -83,12 +80,12 @@ class EventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(EventRestServlet, self).__init__(hs)
self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
@defer.inlineCallbacks
def on_GET(self, request, event_id):
requester = yield self.auth.get_user_by_req(request)
event = yield self.event_handler.get_event(requester.user, event_id)
handler = self.handlers.event_handler
event = yield handler.get_event(requester.user, event_id)
time_now = self.clock.time_msec()
if event:

View File

@@ -23,10 +23,6 @@ from .base import ClientV1RestServlet, client_path_patterns
class InitialSyncRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/initialSync$")
def __init__(self, hs):
super(InitialSyncRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request)

View File

@@ -54,9 +54,11 @@ class LoginRestServlet(ClientV1RestServlet):
self.jwt_secret = hs.config.jwt_secret
self.jwt_algorithm = hs.config.jwt_algorithm
self.cas_enabled = hs.config.cas_enabled
self.cas_server_url = hs.config.cas_server_url
self.cas_required_attributes = hs.config.cas_required_attributes
self.servername = hs.config.server_name
self.http_client = hs.get_simple_http_client()
self.auth_handler = self.hs.get_auth_handler()
self.device_handler = self.hs.get_device_handler()
self.handlers = hs.get_handlers()
def on_GET(self, request):
flows = []
@@ -107,6 +109,17 @@ class LoginRestServlet(ClientV1RestServlet):
LoginRestServlet.JWT_TYPE):
result = yield self.do_jwt_login(login_submission)
defer.returnValue(result)
# TODO Delete this after all CAS clients switch to token login instead
elif self.cas_enabled and (login_submission["type"] ==
LoginRestServlet.CAS_TYPE):
uri = "%s/proxyValidate" % (self.cas_server_url,)
args = {
"ticket": login_submission["ticket"],
"service": login_submission["service"]
}
body = yield self.http_client.get_raw(uri, args)
result = yield self.do_cas_login(body)
defer.returnValue(result)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
result = yield self.do_token_login(login_submission)
defer.returnValue(result)
@@ -132,23 +145,15 @@ class LoginRestServlet(ClientV1RestServlet):
).to_string()
auth_handler = self.auth_handler
user_id = yield auth_handler.validate_password_login(
user_id, access_token, refresh_token = yield auth_handler.login_with_password(
user_id=user_id,
password=login_submission["password"],
)
device_id = yield self._register_device(user_id, login_submission)
access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(
user_id, device_id,
login_submission.get("initial_device_display_name")
)
)
password=login_submission["password"])
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
"device_id": device_id,
}
defer.returnValue((200, result))
@@ -160,19 +165,57 @@ class LoginRestServlet(ClientV1RestServlet):
user_id = (
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
)
device_id = yield self._register_device(user_id, login_submission)
access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(
user_id, device_id,
login_submission.get("initial_device_display_name")
)
user_id, access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(user_id)
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
"device_id": device_id,
}
defer.returnValue((200, result))
# TODO Delete this after all CAS clients switch to token login instead
@defer.inlineCallbacks
def do_cas_login(self, cas_response_body):
user, attributes = self.parse_cas_response(cas_response_body)
for required_attribute, required_value in self.cas_required_attributes.items():
# If required attribute was not in CAS Response - Forbidden
if required_attribute not in attributes:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
# Also need to check value
if required_value is not None:
actual_value = attributes[required_attribute]
# If required attribute value does not match expected - Forbidden
if required_value != actual_value:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id)
if user_exists:
user_id, access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(user_id)
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
}
else:
user_id, access_token = (
yield self.handlers.registration_handler.register(localpart=user)
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"home_server": self.hs.hostname,
}
defer.returnValue((200, result))
@@ -202,27 +245,18 @@ class LoginRestServlet(ClientV1RestServlet):
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.auth_handler
registered_user_id = yield auth_handler.check_user_exists(user_id)
if registered_user_id:
device_id = yield self._register_device(
registered_user_id, login_submission
)
access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(
registered_user_id, device_id,
login_submission.get("initial_device_display_name")
)
user_exists = yield auth_handler.does_user_exist(user_id)
if user_exists:
user_id, access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(user_id)
)
result = {
"user_id": registered_user_id,
"user_id": user_id, # may have changed
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
}
else:
# TODO: we should probably check that the register isn't going
# to fonx/change our user_id before registering the device
device_id = yield self._register_device(user_id, login_submission)
user_id, access_token = (
yield self.handlers.registration_handler.register(localpart=user)
)
@@ -234,25 +268,32 @@ class LoginRestServlet(ClientV1RestServlet):
defer.returnValue((200, result))
def _register_device(self, user_id, login_submission):
"""Register a device for a user.
# TODO Delete this after all CAS clients switch to token login instead
def parse_cas_response(self, cas_response_body):
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
if not root[0].tag.endswith("authenticationSuccess"):
raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
if child.tag.endswith("attributes"):
attributes = {}
for attribute in child:
# ElementTree library expands the namespace in attribute tags
# to the full URL of the namespace.
# See (https://docs.python.org/2/library/xml.etree.elementtree.html)
# We don't care about namespace here and it will always be encased in
# curly braces, so we remove them.
if "}" in attribute.tag:
attributes[attribute.tag.split("}")[1]] = attribute.text
else:
attributes[attribute.tag] = attribute.text
if user is None or attributes is None:
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
This is called after the user's credentials have been validated, but
before the access token has been issued.
Args:
(str) user_id: full canonical @user:id
(object) login_submission: dictionary supplied to /login call, from
which we pull device_id and initial_device_name
Returns:
defer.Deferred: (str) device_id
"""
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get(
"initial_device_display_name")
return self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
return (user, attributes)
class SAML2RestServlet(ClientV1RestServlet):
@@ -261,7 +302,6 @@ class SAML2RestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(SAML2RestServlet, self).__init__(hs)
self.sp_config = hs.config.saml2_config_path
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_POST(self, request):
@@ -299,6 +339,18 @@ class SAML2RestServlet(ClientV1RestServlet):
defer.returnValue((200, {"status": "not_authenticated"}))
# TODO Delete this after all CAS clients switch to token login instead
class CasRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/login/cas", releases=())
def __init__(self, hs):
super(CasRestServlet, self).__init__(hs)
self.cas_server_url = hs.config.cas_server_url
def on_GET(self, request):
return (200, {"serverUrl": self.cas_server_url})
class CasRedirectServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/login/cas/redirect", releases=())
@@ -330,8 +382,6 @@ class CasTicketServlet(ClientV1RestServlet):
self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url
self.cas_required_attributes = hs.config.cas_required_attributes
self.auth_handler = hs.get_auth_handler()
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request):
@@ -364,13 +414,13 @@ class CasTicketServlet(ClientV1RestServlet):
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.auth_handler
registered_user_id = yield auth_handler.check_user_exists(user_id)
if not registered_user_id:
registered_user_id, _ = (
user_exists = yield auth_handler.does_user_exist(user_id)
if not user_exists:
user_id, _ = (
yield self.handlers.registration_handler.register(localpart=user)
)
login_token = auth_handler.generate_short_term_login_token(registered_user_id)
login_token = auth_handler.generate_short_term_login_token(user_id)
redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
login_token)
request.redirect(redirect_url)
@@ -384,39 +434,30 @@ class CasTicketServlet(ClientV1RestServlet):
return urlparse.urlunparse(url_parts)
def parse_cas_response(self, cas_response_body):
user = None
attributes = None
try:
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise Exception("root of CAS response is not serviceResponse")
success = (root[0].tag.endswith("authenticationSuccess"))
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
if not root[0].tag.endswith("authenticationSuccess"):
raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
if child.tag.endswith("attributes"):
attributes = {}
for attribute in child:
# ElementTree library expands the namespace in
# attribute tags to the full URL of the namespace.
# We don't care about namespace here and it will always
# be encased in curly braces, so we remove them.
tag = attribute.tag
if "}" in tag:
tag = tag.split("}")[1]
attributes[tag] = attribute.text
if user is None:
raise Exception("CAS response does not contain user")
if attributes is None:
raise Exception("CAS response does not contain attributes")
except Exception:
logger.error("Error parsing CAS response", exc_info=1)
raise LoginError(401, "Invalid CAS response",
errcode=Codes.UNAUTHORIZED)
if not success:
raise LoginError(401, "Unsuccessful CAS response",
errcode=Codes.UNAUTHORIZED)
return user, attributes
# ElementTree library expands the namespace in attribute tags
# to the full URL of the namespace.
# See (https://docs.python.org/2/library/xml.etree.elementtree.html)
# We don't care about namespace here and it will always be encased in
# curly braces, so we remove them.
if "}" in attribute.tag:
attributes[attribute.tag.split("}")[1]] = attribute.text
else:
attributes[attribute.tag] = attribute.text
if user is None or attributes is None:
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
return (user, attributes)
def register_servlets(hs, http_server):
@@ -426,3 +467,5 @@ def register_servlets(hs, http_server):
if hs.config.cas_enabled:
CasRedirectServlet(hs).register(http_server)
CasTicketServlet(hs).register(http_server)
CasRestServlet(hs).register(http_server)
# TODO PasswordResetRestServlet(hs).register(http_server)

View File

@@ -24,10 +24,6 @@ from synapse.http.servlet import parse_json_object_from_request
class ProfileDisplaynameRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/displayname")
def __init__(self, hs):
super(ProfileDisplaynameRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user = UserID.from_string(user_id)
@@ -66,10 +62,6 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
class ProfileAvatarURLRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/avatar_url")
def __init__(self, hs):
super(ProfileAvatarURLRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user = UserID.from_string(user_id)
@@ -107,10 +99,6 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
class ProfileRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)")
def __init__(self, hs):
super(ProfileRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user = UserID.from_string(user_id)

View File

@@ -52,10 +52,6 @@ class RegisterRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/register$", releases=(), include_in_unstable=False)
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super(RegisterRestServlet, self).__init__(hs)
# sessions are stored as:
# self.sessions = {
@@ -64,8 +60,6 @@ class RegisterRestServlet(ClientV1RestServlet):
# TODO: persistent storage
self.sessions = {}
self.enable_registration = hs.config.enable_registration
self.auth_handler = hs.get_auth_handler()
self.handlers = hs.get_handlers()
def on_GET(self, request):
if self.hs.config.enable_registration_captcha:
@@ -305,10 +299,9 @@ class RegisterRestServlet(ClientV1RestServlet):
user_localpart = register_json["user"].encode("utf-8")
handler = self.handlers.registration_handler
user_id = yield handler.appservice_register(
(user_id, token) = yield handler.appservice_register(
user_localpart, as_token
)
token = yield self.auth_handler.issue_access_token(user_id)
self._remove_session(session)
defer.returnValue({
"user_id": user_id,
@@ -384,7 +377,6 @@ class CreateUserRestServlet(ClientV1RestServlet):
super(CreateUserRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
self.direct_user_creation_max_duration = hs.config.user_creation_max_duration
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_POST(self, request):
@@ -437,7 +429,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
user_id, token = yield handler.get_or_create_user(
localpart=localpart,
displayname=displayname,
duration_in_ms=(duration_seconds * 1000),
duration_seconds=duration_seconds,
password_hash=password_hash
)

View File

@@ -20,14 +20,12 @@ from .base import ClientV1RestServlet, client_path_patterns
from synapse.api.errors import SynapseError, Codes, AuthError
from synapse.streams.config import PaginationConfig
from synapse.api.constants import EventTypes, Membership
from synapse.api.filtering import Filter
from synapse.types import UserID, RoomID, RoomAlias
from synapse.events.utils import serialize_event
from synapse.http.servlet import parse_json_object_from_request
import logging
import urllib
import ujson as json
logger = logging.getLogger(__name__)
@@ -35,10 +33,6 @@ logger = logging.getLogger(__name__)
class RoomCreateRestServlet(ClientV1RestServlet):
# No PATTERN; we have custom dispatch rules here
def __init__(self, hs):
super(RoomCreateRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server):
PATTERNS = "/createRoom"
register_txn_path(self, PATTERNS, http_server)
@@ -86,10 +80,6 @@ class RoomCreateRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for generic events
class RoomStateEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomStateEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server):
# /room/$roomid/state/$eventtype
no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$"
@@ -174,10 +164,6 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for generic events + feedback
class RoomSendEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomSendEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server):
# /rooms/$roomid/send/$event_type[/$txn_id]
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)")
@@ -222,9 +208,6 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for room ID + alias joins
class JoinRoomAliasServlet(ClientV1RestServlet):
def __init__(self, hs):
super(JoinRoomAliasServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server):
# /join/$room_identifier[/$txn_id]
@@ -268,7 +251,6 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
action="join",
txn_id=txn_id,
remote_room_hosts=remote_room_hosts,
content=content,
third_party_signed=content.get("third_party_signed", None),
)
@@ -312,10 +294,6 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
class RoomMemberListRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/members$")
def __init__(self, hs):
super(RoomMemberListRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens)
@@ -342,10 +320,6 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
class RoomMessageListRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$")
def __init__(self, hs):
super(RoomMessageListRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
@@ -353,19 +327,31 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
request, default_limit=10,
)
as_client_event = "raw" not in request.args
filter_bytes = request.args.get("filter", None)
if filter_bytes:
filter_json = urllib.unquote(filter_bytes[-1]).decode("UTF-8")
event_filter = Filter(json.loads(filter_json))
else:
event_filter = None
handler = self.handlers.message_handler
msgs = yield handler.get_messages(
room_id=room_id,
requester=requester,
pagin_config=pagination_config,
as_client_event=as_client_event,
event_filter=event_filter,
as_client_event=as_client_event
)
defer.returnValue((200, msgs))
class RoomFileListRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/files$")
@defer.inlineCallbacks
def on_GET(self, request, room_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = PaginationConfig.from_request(
request, default_limit=10, default_dir='b',
)
handler = self.handlers.message_handler
msgs = yield handler.get_files(
room_id=room_id,
requester=requester,
pagin_config=pagination_config,
)
defer.returnValue((200, msgs))
@@ -375,10 +361,6 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
class RoomStateRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/state$")
def __init__(self, hs):
super(RoomStateRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
@@ -396,10 +378,6 @@ class RoomStateRestServlet(ClientV1RestServlet):
class RoomInitialSyncRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$")
def __init__(self, hs):
super(RoomInitialSyncRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
@@ -420,7 +398,6 @@ class RoomEventContext(ClientV1RestServlet):
def __init__(self, hs):
super(RoomEventContext, self).__init__(hs)
self.clock = hs.get_clock()
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, room_id, event_id):
@@ -457,10 +434,6 @@ class RoomEventContext(ClientV1RestServlet):
class RoomForgetRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomForgetRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
register_txn_path(self, PATTERNS, http_server)
@@ -499,10 +472,6 @@ class RoomForgetRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing
class RoomMembershipRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomMembershipRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server):
# /rooms/$roomid/[invite|join|leave]
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
@@ -583,10 +552,6 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
class RoomRedactEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomRedactEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
register_txn_path(self, PATTERNS, http_server)
@@ -669,10 +634,6 @@ class SearchRestServlet(ClientV1RestServlet):
"/search$"
)
def __init__(self, hs):
super(SearchRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request)
@@ -725,6 +686,7 @@ def register_servlets(hs, http_server):
RoomCreateRestServlet(hs).register(http_server)
RoomMemberListRestServlet(hs).register(http_server)
RoomMessageListRestServlet(hs).register(http_server)
RoomFileListRestServlet(hs).register(http_server)
JoinRoomAliasServlet(hs).register(http_server)
RoomForgetRestServlet(hs).register(http_server)
RoomMembershipRestServlet(hs).register(http_server)

View File

@@ -25,9 +25,7 @@ import logging
logger = logging.getLogger(__name__)
def client_v2_patterns(path_regex, releases=(0,),
v2_alpha=True,
unstable=True):
def client_v2_patterns(path_regex, releases=(0,)):
"""Creates a regex compiled client path with the correct client path
prefix.
@@ -37,10 +35,7 @@ def client_v2_patterns(path_regex, releases=(0,),
Returns:
SRE_Pattern
"""
patterns = []
if v2_alpha:
patterns.append(re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex))
if unstable:
patterns = [re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)]
unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable")
patterns.append(re.compile("^" + unstable_prefix + path_regex))
for release in releases:

View File

@@ -28,40 +28,8 @@ import logging
logger = logging.getLogger(__name__)
class PasswordRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/password/email/requestToken$")
def __init__(self, hs):
super(PasswordRequestTokenRestServlet, self).__init__()
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
required = ['id_server', 'client_secret', 'email', 'send_attempt']
absent = []
for k in required:
if k not in body:
absent.append(k)
if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email']
)
if existingUid is None:
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
ret = yield self.identity_handler.requestEmailToken(**body)
defer.returnValue((200, ret))
class PasswordRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/password$")
PATTERNS = client_v2_patterns("/account/password")
def __init__(self, hs):
super(PasswordRestServlet, self).__init__()
@@ -121,83 +89,8 @@ class PasswordRestServlet(RestServlet):
return 200, {}
class DeactivateAccountRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/deactivate$")
def __init__(self, hs):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
super(DeactivateAccountRestServlet, self).__init__()
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
authed, result, params, _ = yield self.auth_handler.check_auth([
[LoginType.PASSWORD],
], body, self.hs.get_ip_from_request(request))
if not authed:
defer.returnValue((401, result))
user_id = None
requester = None
if LoginType.PASSWORD in result:
# if using password, they should also be logged in
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
if user_id != result[LoginType.PASSWORD]:
raise LoginError(400, "", Codes.UNKNOWN)
else:
logger.error("Auth succeeded but no known type!", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN)
# FIXME: Theoretically there is a race here wherein user resets password
# using threepid.
yield self.store.user_delete_access_tokens(user_id)
yield self.store.user_delete_threepids(user_id)
yield self.store.user_set_password_hash(user_id, None)
defer.returnValue((200, {}))
class ThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$")
def __init__(self, hs):
self.hs = hs
super(ThreepidRequestTokenRestServlet, self).__init__()
self.identity_handler = hs.get_handlers().identity_handler
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
required = ['id_server', 'client_secret', 'email', 'send_attempt']
absent = []
for k in required:
if k not in body:
absent.append(k)
if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email']
)
if existingUid is not None:
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
ret = yield self.identity_handler.requestEmailToken(**body)
defer.returnValue((200, ret))
class ThreepidRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid$")
PATTERNS = client_v2_patterns("/account/3pid")
def __init__(self, hs):
super(ThreepidRestServlet, self).__init__()
@@ -264,8 +157,5 @@ class ThreepidRestServlet(RestServlet):
def register_servlets(hs, http_server):
PasswordRequestTokenRestServlet(hs).register(http_server)
PasswordRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server)
ThreepidRequestTokenRestServlet(hs).register(http_server)
ThreepidRestServlet(hs).register(http_server)

View File

@@ -1,100 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.http import servlet
from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
class DevicesRestServlet(servlet.RestServlet):
PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False)
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super(DevicesRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request)
devices = yield self.device_handler.get_devices_by_user(
requester.user.to_string()
)
defer.returnValue((200, {"devices": devices}))
class DeviceRestServlet(servlet.RestServlet):
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
releases=[], v2_alpha=False)
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super(DeviceRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
@defer.inlineCallbacks
def on_GET(self, request, device_id):
requester = yield self.auth.get_user_by_req(request)
device = yield self.device_handler.get_device(
requester.user.to_string(),
device_id,
)
defer.returnValue((200, device))
@defer.inlineCallbacks
def on_DELETE(self, request, device_id):
# XXX: it's not completely obvious we want to expose this endpoint.
# It allows the client to delete access tokens, which feels like a
# thing which merits extra auth. But if we want to do the interactive-
# auth dance, we should really make it possible to delete more than one
# device at a time.
requester = yield self.auth.get_user_by_req(request)
yield self.device_handler.delete_device(
requester.user.to_string(),
device_id,
)
defer.returnValue((200, {}))
@defer.inlineCallbacks
def on_PUT(self, request, device_id):
requester = yield self.auth.get_user_by_req(request)
body = servlet.parse_json_object_from_request(request)
yield self.device_handler.update_device(
requester.user.to_string(),
device_id,
body
)
defer.returnValue((200, {}))
def register_servlets(hs, http_server):
DevicesRestServlet(hs).register(http_server)
DeviceRestServlet(hs).register(http_server)

Some files were not shown because too many files have changed in this diff Show More