Compare commits


2 Commits

Author          SHA1        Message                           Date
Erik Johnston   9854f4c7ff  Add basic file API                2016-07-13 16:29:35 +01:00
Erik Johnston   518b3a3f89  Track in DB file message events   2016-07-13 15:16:02 +01:00
252 changed files with 5167 additions and 14747 deletions

.gitignore

@@ -24,10 +24,10 @@ homeserver*.yaml
 .coverage
 htmlcov
-demo/*/*.db
-demo/*/*.log
-demo/*/*.log.*
-demo/*/*.pid
+demo/*.db
+demo/*.log
+demo/*.log.*
+demo/*.pid
 demo/media_store.*
 demo/etc


@@ -1,408 +1,3 @@
Changes in synapse v0.18.4-rc1 (2016-11-14)
===========================================
Changes:
* Various database efficiency improvements (PR #1188, #1192)
* Update default config to blacklist more internal IPs, thanks to Euan Kemp (PR
#1198)
* Allow specifying duration in minutes in config, thanks to Daniel Dent (PR
#1625)
Bug fixes:
* Fix media repo to set CORS headers on responses (PR #1190)
* Fix registration to not error on non-ascii passwords (PR #1191)
* Fix create event code to limit the number of prev_events (PR #1615)
* Fix bug in transaction ID deduplication (PR #1624)
Changes in synapse v0.18.3 (2016-11-08)
=======================================
SECURITY UPDATE
Explicitly require authentication when using LDAP3. This is the default on
versions of ``ldap3`` above 1.0, but some distributions will package an older
version.
If you are using LDAP3 login and have a version of ``ldap3`` older than 1.0 it
is **CRITICAL to upgrade**.
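A quick way to check which version of ``ldap3`` is installed is a sketch along
these lines (it assumes setuptools' ``pkg_resources`` is available, which
synapse's own dependencies already pull in)::

    import pkg_resources

    # Raises DistributionNotFound if ldap3 is not installed at all.
    version = pkg_resources.get_distribution("ldap3").version
    print(version)

    # Versions below 1.0 do not require authentication by default -- upgrade.
    if int(version.split(".")[0]) < 1:
        print("ldap3 %s is too old: upgrade to 1.0 or later" % version)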
Changes in synapse v0.18.2 (2016-11-01)
=======================================
No changes since v0.18.2-rc5
Changes in synapse v0.18.2-rc5 (2016-10-28)
===========================================
Bug fixes:
* Fix prometheus process metrics in worker processes (PR #1184)
Changes in synapse v0.18.2-rc4 (2016-10-27)
===========================================
Bug fixes:
* Fix ``user_threepids`` schema delta, which in some instances prevented
startup after upgrade (PR #1183)
Changes in synapse v0.18.2-rc3 (2016-10-27)
===========================================
Changes:
* Allow clients to supply access tokens as headers (PR #1098)
* Clarify error codes for GET /filter/, thanks to Alexander Maznev (PR #1164)
* Make password reset email field case insensitive (PR #1170)
* Reduce redundant database work in email pusher (PR #1174)
* Allow configurable rate limiting per AS (PR #1175)
* Check whether to ratelimit sooner to avoid work (PR #1176)
* Standardise prometheus metrics (PR #1177)
Bug fixes:
* Fix incredibly slow back pagination query (PR #1178)
* Fix infinite typing bug (PR #1179)
Changes in synapse v0.18.2-rc2 (2016-10-25)
===========================================
(This release did not include the changes advertised and was identical to RC1)
Changes in synapse v0.18.2-rc1 (2016-10-17)
===========================================
Changes:
* Remove redundant event_auth index (PR #1113)
* Reduce DB hits for replication (PR #1141)
* Implement pluggable password auth (PR #1155)
* Remove rate limiting from app service senders and fix get_or_create_user
requester, thanks to Patrik Oldsberg (PR #1157)
* window.postmessage for Interactive Auth fallback (PR #1159)
* Use sys.executable instead of hardcoded python, thanks to Pedro Larroy
(PR #1162)
* Add config option for adding additional TLS fingerprints (PR #1167)
* User-interactive auth on delete device (PR #1168)
Bug fixes:
* Fix not being allowed to set your own state_key, thanks to Patrik Oldsberg
(PR #1150)
* Fix interactive auth to return 401 for an incorrect password (PR #1160,
#1166)
* Fix email push notifs being dropped (PR #1169)
Changes in synapse v0.18.1 (2016-10-05)
=======================================
No changes since v0.18.1-rc1
Changes in synapse v0.18.1-rc1 (2016-09-30)
===========================================
Features:
* Add total_room_count_estimate to ``/publicRooms`` (PR #1133)
Changes:
* Time out typing over federation (PR #1140)
* Restructure LDAP authentication (PR #1153)
Bug fixes:
* Fix 3pid invites when server is already in the room (PR #1136)
* Fix upgrading with SQLite taking lots of CPU for a few days
after upgrade (PR #1144)
* Fix upgrading from very old database versions (PR #1145)
* Fix port script to work with recently added tables (PR #1146)
Changes in synapse v0.18.0 (2016-09-19)
=======================================
The release includes major changes to the state storage database schemas, which
significantly reduce database size. Synapse will attempt to upgrade the current
data in the background. Servers with large SQLite databases may experience
degradation of performance while this upgrade is in progress, therefore you may
want to consider migrating to Postgres before upgrading very large SQLite
databases.
Changes:
* Make public room search case insensitive (PR #1127)
Bug fixes:
* Fix and clean up publicRooms pagination (PR #1129)
Changes in synapse v0.18.0-rc1 (2016-09-16)
===========================================
Features:
* Add ``only=highlight`` on ``/notifications`` (PR #1081)
* Add server param to /publicRooms (PR #1082)
* Allow clients to ask for the whole of a single state event (PR #1094)
* Add is_direct param to /createRoom (PR #1108)
* Add pagination support to publicRooms (PR #1121)
* Add very basic filter API to /publicRooms (PR #1126)
* Add basic direct to device messaging support for E2E (PR #1074, #1084, #1104,
#1111)
Changes:
* Move to storing state_groups_state as deltas, greatly reducing DB size (PR
#1065)
* Reduce amount of state pulled out of the DB during common requests (PR #1069)
* Allow PDF to be rendered from media repo (PR #1071)
* Reindex state_groups_state after pruning (PR #1085)
* Clobber EDUs in send queue (PR #1095)
* Conform better to the CAS protocol specification (PR #1100)
* Limit how often we ask for keys from dead servers (PR #1114)
Bug fixes:
* Fix /notifications API when used with ``from`` param (PR #1080)
* Fix backfill when cannot find an event. (PR #1107)
Changes in synapse v0.17.3 (2016-09-09)
=======================================
This release fixes a major bug that stopped servers from handling rooms with
over 1000 members.
Changes in synapse v0.17.2 (2016-09-08)
=======================================
This release contains security bug fixes. Please upgrade.
No changes since v0.17.2-rc1
Changes in synapse v0.17.2-rc1 (2016-09-05)
===========================================
Features:
* Start adding store-and-forward direct-to-device messaging (PR #1046, #1050,
#1062, #1066)
Changes:
* Avoid pulling the full state of a room out so often (PR #1047, #1049, #1063,
#1068)
* Don't notify for online to online presence transitions. (PR #1054)
* Occasionally persist unpersisted presence updates (PR #1055)
* Allow application services to have an optional 'url' (PR #1056)
* Clean up old sent transactions from DB (PR #1059)
Bug fixes:
* Fix None check in backfill (PR #1043)
* Fix membership changes to be idempotent (PR #1067)
* Fix bug in get_pdu where it would sometimes return events with incorrect
signature
Changes in synapse v0.17.1 (2016-08-24)
=======================================
Changes:
* Delete old received_transactions rows (PR #1038)
* Pass through user-supplied content in /join/$room_id (PR #1039)
Bug fixes:
* Fix bug with backfill (PR #1040)
Changes in synapse v0.17.1-rc1 (2016-08-22)
===========================================
Features:
* Add notification API (PR #1028)
Changes:
* Don't print stack traces when failing to get remote keys (PR #996)
* Various federation /event/ perf improvements (PR #998)
* Only process one local membership event per room at a time (PR #1005)
* Move default display name push rule (PR #1011, #1023)
* Fix up preview URL API. Add tests. (PR #1015)
* Set ``Content-Security-Policy`` on media repo (PR #1021)
* Make notify_interested_services faster (PR #1022)
* Add usage stats to prometheus monitoring (PR #1037)
Bug fixes:
* Fix token login (PR #993)
* Fix CAS login (PR #994, #995)
* Fix /sync to not clobber status_msg (PR #997)
* Fix redacted state events to include prev_content (PR #1003)
* Fix some bugs in the auth/ldap handler (PR #1007)
* Fix backfill request to limit URI length, so that remotes don't reject the
requests due to path length limits (PR #1012)
* Fix AS push code to not send duplicate events (PR #1025)
Changes in synapse v0.17.0 (2016-08-08)
=======================================
This release contains significant security bug fixes regarding authenticating
events received over federation. PLEASE UPGRADE.
This release changes the LDAP configuration format in a backwards incompatible
way, see PR #843 for details.
Changes:
* Add federation /version API (PR #990)
* Make psutil dependency optional (PR #992)
Bug fixes:
* Fix URL preview API to exclude HTML comments in description (PR #988)
* Fix error handling of remote joins (PR #991)
Changes in synapse v0.17.0-rc4 (2016-08-05)
===========================================
Changes:
* Change the way we summarize URLs when previewing (PR #973)
* Add new ``/state_ids/`` federation API (PR #979)
* Speed up processing of ``/state/`` response (PR #986)
Bug fixes:
* Fix event persistence when event has already been partially persisted
(PR #975, #983, #985)
* Fix port script to also copy across backfilled events (PR #982)
Changes in synapse v0.17.0-rc3 (2016-08-02)
===========================================
Changes:
* Forbid non-ASes from registering users whose names begin with '_' (PR #958)
* Add some basic admin API docs (PR #963)
Bug fixes:
* Send the correct host header when fetching keys (PR #941)
* Fix joining a room that has missing auth events (PR #964)
* Fix various push bugs (PR #966, #970)
* Fix adding emails on registration (PR #968)
Changes in synapse v0.17.0-rc2 (2016-08-02)
===========================================
(This release did not include the changes advertised and was identical to RC1)
Changes in synapse v0.17.0-rc1 (2016-07-28)
===========================================
This release changes the LDAP configuration format in a backwards incompatible
way, see PR #843 for details.
Features:
* Add purge_media_cache admin API (PR #902)
* Add deactivate account admin API (PR #903)
* Add optional pepper to password hashing (PR #907, #910 by KentShikama)
* Add an admin option to shared secret registration (breaks backwards compat)
(PR #909)
* Add purge local room history API (PR #911, #923, #924)
* Add requestToken endpoints (PR #915)
* Add an /account/deactivate endpoint (PR #921)
* Add filter param to /messages. Add 'contains_url' to filter. (PR #922)
* Add device_id support to /login (PR #929)
* Add device_id support to /v2/register flow. (PR #937, #942)
* Add GET /devices endpoint (PR #939, #944)
* Add GET /device/{deviceId} (PR #943)
* Add update and delete APIs for devices (PR #949)
Changes:
* Rewrite LDAP Authentication against ldap3 (PR #843 by mweinelt)
* Linearize some federation endpoints based on (origin, room_id) (PR #879)
* Remove the legacy v0 content upload API. (PR #888)
* Use similar naming we use in email notifs for push (PR #894)
* Optionally include password hash in createUser endpoint (PR #905 by
KentShikama)
* Use a query that postgresql optimises better for get_events_around (PR #906)
* Fall back to 'username' if 'user' is not given for appservice registration.
(PR #927 by Half-Shot)
* Add metrics for psutil derived memory usage (PR #936)
* Record device_id in client_ips (PR #938)
* Send the correct host header when fetching keys (PR #941)
* Log the hostname the reCAPTCHA was completed on (PR #946)
* Make the device id on e2e key upload optional (PR #956)
* Add r0.2.0 to the "supported versions" list (PR #960)
* Don't include name of room for invites in push (PR #961)
Bug fixes:
* Fix substitution failure in mail template (PR #887)
* Put most recent 20 messages in email notif (PR #892)
* Ensure that the guest user is in the database when upgrading accounts
(PR #914)
* Fix various edge cases in auth handling (PR #919)
* Fix 500 ISE when sending alias event without a state_key (PR #925)
* Fix bug where we stored rejections in the state_group, persist all
rejections (PR #948)
* Fix lack of check of if the user is banned when handling 3pid invites
(PR #952)
* Fix a couple of bugs in the transaction and keyring code (PR #954, #955)
Changes in synapse v0.16.1-r1 (2016-07-08)
==========================================


@@ -14,7 +14,6 @@ recursive-include docs *
 recursive-include res *
 recursive-include scripts *
 recursive-include scripts-dev *
-recursive-include synapse *.pyi
 recursive-include tests *.py
 recursive-include synapse/static *.css

@@ -24,7 +23,5 @@ recursive-include synapse/static *.js
 exclude jenkins.sh
 exclude jenkins*.sh
-exclude jenkins*
-recursive-exclude jenkins *.sh
 prune demo/etc


@@ -11,8 +11,8 @@ VoIP. The basics you need to know to get up and running are:
   like ``#matrix:matrix.org`` or ``#test:localhost:8448``.

 - Matrix user IDs look like ``@matthew:matrix.org`` (although in the future
-  you will normally refer to yourself and others using a third party identifier
-  (3PID): email address, phone number, etc rather than manipulating Matrix user IDs)
+  you will normally refer to yourself and others using a 3PID: email
+  address, phone number, etc rather than manipulating Matrix user IDs)

 The overall architecture is::

@@ -95,7 +95,7 @@ Synapse is the reference python/twisted Matrix homeserver implementation.
 System requirements:
 - POSIX-compliant system (tested on Linux & OS X)
 - Python 2.7
-- At least 1GB of free RAM if you want to join large public rooms like #matrix:matrix.org
+- At least 512 MB RAM.

 Synapse is written in python but some of the libraries it uses are written in
 C. So before we can install synapse itself we need a working C compiler and the

@@ -134,12 +134,6 @@ Installing prerequisites on Raspbian::
     sudo pip install --upgrade ndg-httpsclient
     sudo pip install --upgrade virtualenv

-Installing prerequisites on openSUSE::
-
-    sudo zypper in -t pattern devel_basis
-    sudo zypper in python-pip python-setuptools sqlite3 python-virtualenv \
-                python-devel libffi-devel libopenssl-devel libjpeg62-devel
-
 To install the synapse homeserver run::

     virtualenv -p python2.7 ~/.synapse

@@ -205,21 +199,6 @@ run (e.g. ``~/.synapse``), and::
     source ./bin/activate
     synctl start

-Security Note
-=============
-
-Matrix serves raw user generated data in some APIs - specifically the content
-repository endpoints:
-http://matrix.org/docs/spec/client_server/r0.2.0.html#get-matrix-media-r0-download-servername-mediaid
-
-Whilst we have tried to mitigate against possible XSS attacks (e.g.
-https://github.com/matrix-org/synapse/pull/1021) we recommend running
-matrix homeservers on a dedicated domain name, to prevent malicious user generated
-content served via a matrix API from being able to attack webapps hosted
-on the same domain. This is particularly true of sharing a matrix webclient and
-server on the same domain.
-
-See https://github.com/vector-im/vector-web/issues/1977 and
-https://developer.github.com/changes/2014-04-25-user-content-security for more details.
-
 Using PostgreSQL
 ================

@@ -236,6 +215,9 @@ The advantages of Postgres include:
    pointing at the same DB master, as well as enabling DB replication in
    synapse itself.

+The only disadvantage is that the code is relatively new as of April 2015 and
+may have a few regressions relative to SQLite.
+
 For information on how to install and use PostgreSQL, please see
 `docs/postgres.rst <docs/postgres.rst>`_.

@@ -463,7 +445,7 @@ You have two choices here, which will influence the form of your Matrix user
 IDs:

 1) Use the machine's own hostname as available on public DNS in the form of
-   its A records. This is easier to set up initially, perhaps for
+   its A or AAAA records. This is easier to set up initially, perhaps for
    testing, but lacks the flexibility of SRV.

 2) Set up a SRV record for your domain name. This requires you create a SRV


@@ -27,7 +27,6 @@ running:
     # Pull the latest version of the master branch.
     git pull

     # Update the versions of synapse's python dependencies.
-    python synapse/python_dependencies.py | xargs -n1 pip install --upgrade
+    python synapse/python_dependencies.py | xargs -n1 pip install

 Upgrading to v0.15.0


@@ -1,12 +0,0 @@
Admin APIs
==========
This directory includes documentation for the various synapse specific admin
APIs available.
Only users that are server admins can use these APIs. A user can be marked as a
server admin by updating the database directly, e.g.:
``UPDATE users SET admin = 1 WHERE name = '@foo:bar.com'``
Restarting may be required for the changes to register.
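For a SQLite-backed homeserver, the same update can be applied with a short
script along these lines (the database path and user ID below are placeholders;
Postgres deployments should run the SQL via ``psql`` instead)::

    import sqlite3

    DB_PATH = "/path/to/homeserver.db"  # placeholder: your homeserver database
    USER_ID = "@foo:bar.com"            # placeholder: the user to promote

    conn = sqlite3.connect(DB_PATH)
    # The same statement as above, with the user ID passed as a parameter.
    conn.execute("UPDATE users SET admin = 1 WHERE name = ?", (USER_ID,))
    conn.commit()
    conn.close()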


@@ -1,15 +0,0 @@
Purge History API
=================
The purge history API allows server admins to purge historic events from their
database, reclaiming disk space.
Depending on the amount of history being purged, a call to the API may take
several minutes or longer. During this period users will not be able to
paginate further back in the room from the point being purged from.
The API is simply:
``POST /_matrix/client/r0/admin/purge_history/<room_id>/<event_id>``
including an ``access_token`` of a server admin.
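As a rough sketch, the call can be made with any HTTP client; for example,
using Python's ``requests`` (the homeserver URL, room ID, event ID and token
below are all placeholders)::

    import requests
    from requests.utils import quote

    BASE_URL = "https://localhost:8448"         # placeholder: your homeserver
    ROOM_ID = "!somewhere:example.org"          # placeholder
    EVENT_ID = "$1450123456789abc:example.org"  # placeholder: purge up to here
    ACCESS_TOKEN = "..."                        # placeholder: a server admin's token

    resp = requests.post(
        "%s/_matrix/client/r0/admin/purge_history/%s/%s"
        % (BASE_URL, quote(ROOM_ID, safe=""), quote(EVENT_ID, safe="")),
        params={"access_token": ACCESS_TOKEN},
    )
    resp.raise_for_status()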


@@ -1,19 +0,0 @@
Purge Remote Media API
======================
The purge remote media API allows server admins to purge old cached remote
media.
The API is::
    POST /_matrix/client/r0/admin/purge_media_cache

    {
        "before_ts": <unix_timestamp_in_ms>
    }
Which will remove all cached media that was last accessed before
``<unix_timestamp_in_ms>``.
If the user re-requests purged remote media, synapse will re-request the media
from the originating server.
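A minimal sketch of driving this from Python, again with placeholder values
(here ``before_ts`` is set to 30 days ago, an arbitrary choice)::

    import time

    import requests

    BASE_URL = "https://localhost:8448"  # placeholder: your homeserver
    ACCESS_TOKEN = "..."                 # placeholder: a server admin's token

    # Unix timestamp in milliseconds, 30 days in the past.
    before_ts = int(time.time() * 1000) - 30 * 24 * 60 * 60 * 1000

    resp = requests.post(
        BASE_URL + "/_matrix/client/r0/admin/purge_media_cache",
        params={"access_token": ACCESS_TOKEN},
        json={"before_ts": before_ts},
    )
    resp.raise_for_status()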


@@ -15,45 +15,36 @@ How to monitor Synapse metrics using Prometheus

    Restart synapse

-3: Add a prometheus target for synapse. It needs to set the ``metrics_path``
-   to a non-default value::
-
-    - job_name: "synapse"
-      metrics_path: "/_synapse/metrics"
-      static_configs:
-        - targets:
-          "my.server.here:9092"
-
-Standard Metric Names
----------------------
-
-As of synapse version 0.18.2, the format of the process-wide metrics has been
-changed to fit prometheus standard naming conventions. Additionally the units
-have been changed to seconds, from milliseconds.
-
-================================== =============================
-New name                           Old name
----------------------------------- -----------------------------
-process_cpu_user_seconds_total     process_resource_utime / 1000
-process_cpu_system_seconds_total   process_resource_stime / 1000
-process_open_fds (no 'type' label) process_fds
-================================== =============================
-
-The python-specific counts of garbage collector performance have been renamed.
-
-=========================== ======================
-New name                    Old name
---------------------------- ----------------------
-python_gc_time              reactor_gc_time
-python_gc_unreachable_total reactor_gc_unreachable
-python_gc_counts            reactor_gc_counts
-=========================== ======================
-
-The twisted-specific reactor metrics have been renamed.
-
-==================================== =====================
-New name                             Old name
------------------------------------- ---------------------
-python_twisted_reactor_pending_calls reactor_pending_calls
-python_twisted_reactor_tick_time     reactor_tick_time
-==================================== =====================
+3: Check out synapse-prometheus-config
+   https://github.com/matrix-org/synapse-prometheus-config
+
+4: Add ``synapse.html`` and ``synapse.rules``
+   The ``.html`` file needs to appear in prometheus's ``consoles`` directory,
+   and the ``.rules`` file needs to be invoked somewhere in the main config
+   file. A symlink to each from the git checkout into the prometheus directory
+   might be easiest to ensure ``git pull`` keeps it updated.
+
+5: Add a prometheus target for synapse
+   This is easiest if prometheus runs on the same machine as synapse, as it can
+   then just use localhost::
+
+    global: {
+      rule_file: "synapse.rules"
+    }
+
+    job: {
+      name: "synapse"
+
+      target_group: {
+        target: "http://localhost:9092/"
+      }
+    }
+
+6: Start prometheus::
+
+    ./prometheus -config.file=prometheus.conf
+
+7: Wait a few seconds for it to start and perform the first scrape,
+   then visit the console:
+
+   http://server-where-prometheus-runs:9090/consoles/synapse.html
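To sanity-check a metrics listener before pointing prometheus at it, something
like the following fetches the exposition text directly (the host and port are
placeholders matching the examples above, and the metric names assume a synapse
recent enough, 0.18.2+, to use the renamed metrics)::

    import requests

    METRICS_URL = "http://localhost:9092/_synapse/metrics"  # placeholder

    text = requests.get(METRICS_URL).text

    # Print the renamed process-wide metrics from the table above.
    for line in text.splitlines():
        if line.startswith(("process_cpu_user_seconds_total",
                            "process_cpu_system_seconds_total",
                            "process_open_fds")):
            print(line)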


@@ -1,98 +0,0 @@
Scaling synapse via workers
---------------------------
Synapse has experimental support for splitting out functionality into
multiple separate python processes, helping greatly with scalability. These
processes are called 'workers', and are (eventually) intended to scale
horizontally independently.
All processes continue to share the same database instance, and as such, workers
only work with postgres based synapse deployments (sharing a single sqlite
across multiple processes is a recipe for disaster, plus you should be using
postgres anyway if you care about scalability).
The workers communicate with the master synapse process via a synapse-specific
HTTP protocol called 'replication' - analogous to MySQL or Postgres style
database replication; feeding a stream of relevant data to the workers so they
can be kept in sync with the main synapse process and database state.
To enable workers, you need to add a replication listener to the master synapse, e.g.::
    listeners:
      - port: 9092
        bind_address: '127.0.0.1'
        type: http
        tls: false
        x_forwarded: false
        resources:
          - names: [replication]
            compress: false
Under **no circumstances** should this replication API listener be exposed to the
public internet; it currently implements no authentication whatsoever and is
unencrypted HTTP.
You then create a set of configs for the various worker processes.  These
worker configuration files should be stored in a dedicated subdirectory, to
allow synctl to manipulate them.
The current available worker applications are:
* synapse.app.pusher - handles sending push notifications to sygnal and email
* synapse.app.synchrotron - handles /sync endpoints, and can scale horizontally across multiple instances.
* synapse.app.appservice - handles output traffic to Application Services
* synapse.app.federation_reader - handles receiving federation traffic (including public_rooms API)
* synapse.app.media_repository - handles the media repository.
* synapse.app.client_reader - handles client API endpoints like /publicRooms
Each worker configuration file inherits the configuration of the main homeserver
configuration file. You can then override configuration specific to that worker,
e.g. the HTTP listener that it provides (if any); logging configuration; etc.
You should minimise the number of overrides though to maintain a usable config.
You must specify the type of worker application (worker_app) and the replication
endpoint that it's talking to on the main synapse process (worker_replication_url).
For instance::
    worker_app: synapse.app.synchrotron

    # The replication listener on the synapse to talk to.
    worker_replication_url: http://127.0.0.1:9092/_synapse/replication

    worker_listeners:
      - type: http
        port: 8083
        resources:
          - names:
              - client

    worker_daemonize: True
    worker_pid_file: /home/matrix/synapse/synchrotron.pid
    worker_log_config: /home/matrix/synapse/config/synchrotron_log_config.yaml
...is a full configuration for a synchrotron worker instance, which will expose a
plain HTTP /sync endpoint on port 8083 separately from the /sync endpoint provided
by the main synapse.
You should then configure your loadbalancer to route the /sync endpoint to
the synchrotron instance(s).
Finally, to actually run your worker-based synapse, you must pass synctl the -a
commandline option to tell it to operate on all the worker configurations found
in the given directory, e.g.::
synctl -a $CONFIG/workers start
Currently one should always restart all workers when restarting or upgrading
synapse, unless you explicitly know it's safe not to. For instance, restarting
synapse without restarting all the synchrotrons may result in broken typing
notifications.
To manipulate a specific worker, you pass the -w option to synctl::
synctl -w $CONFIG/workers/synchrotron.yaml restart
All of the above is highly experimental and subject to change as Synapse evolves,
but is documented here to help folks needing highly scalable Synapses similar
to the one running matrix.org!
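As a quick check that a synchrotron is actually serving traffic, you can hit
its listener directly rather than going through the loadbalancer (port 8083
matches the example configuration above; the access token is a placeholder)::

    import requests

    SYNC_URL = "http://localhost:8083/_matrix/client/r0/sync"  # the worker
    ACCESS_TOKEN = "..."  # placeholder: any valid user's access token

    resp = requests.get(
        SYNC_URL,
        params={"access_token": ACCESS_TOKEN, "timeout": 0},
    )
    print(resp.status_code)            # expect 200 if the worker is up
    print(sorted(resp.json().keys()))  # top-level sync response keys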


@@ -4,21 +4,84 @@ set -eux

 : ${WORKSPACE:="$(pwd)"}
+export WORKSPACE

 export PYTHONDONTWRITEBYTECODE=yep
 export SYNAPSE_CACHE_FACTOR=1

-./jenkins/prepare_synapse.sh
-./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git
-./jenkins/clone.sh dendron https://github.com/matrix-org/dendron.git
-./dendron/jenkins/build_dendron.sh
-./sytest/jenkins/prep_sytest_for_postgres.sh
-
-./sytest/jenkins/install_and_run.sh \
-    --synapse-directory $WORKSPACE \
-    --dendron $WORKSPACE/dendron/bin/dendron \
-    --pusher \
-    --synchrotron \
-    --federation-reader \
-    --client-reader \
-    --appservice \
+# Output test results as junit xml
+export TRIAL_FLAGS="--reporter=subunit"
+export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
+
+# Write coverage reports to a separate file for each process
+export COVERAGE_OPTS="-p"
+export DUMP_COVERAGE_COMMAND="coverage help"
+
+# Output flake8 violations to violations.flake8.log
+# Don't exit with non-0 status code on Jenkins,
+# so that the build steps continue and a later step can decide whether to
+# UNSTABLE or FAILURE this build.
+export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
+
+rm .coverage* || echo "No coverage files to remove"
+
+tox --notest -e py27
+TOX_BIN=$WORKSPACE/.tox/py27/bin
+python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
+$TOX_BIN/pip install psycopg2
+$TOX_BIN/pip install lxml
+
+: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
+
+if [[ ! -e .dendron-base ]]; then
+  git clone https://github.com/matrix-org/dendron.git .dendron-base --mirror
+else
+  (cd .dendron-base; git fetch -p)
+fi
+
+rm -rf dendron
+git clone .dendron-base dendron --shared
+cd dendron
+
+: ${GOPATH:=${WORKSPACE}/.gopath}
+if [[ "${GOPATH}" != *:* ]]; then
+  mkdir -p "${GOPATH}"
+  export PATH="${GOPATH}/bin:${PATH}"
+fi
+export GOPATH
+
+git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
+
+go get github.com/constabulary/gb/...
+gb generate
+gb build
+
+cd ..
+
+if [[ ! -e .sytest-base ]]; then
+  git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
+else
+  (cd .sytest-base; git fetch -p)
+fi
+
+rm -rf sytest
+git clone .sytest-base sytest --shared
+cd sytest
+
+git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
+
+: ${PORT_BASE:=8000}
+: ${PORT_COUNT=20}
+
+./jenkins/prep_sytest_for_postgres.sh
+
+mkdir -p var
+
+echo >&2 "Running sytest with PostgreSQL";
+./jenkins/install_and_run.sh --python $TOX_BIN/python \
+    --synapse-directory $WORKSPACE \
+    --dendron $WORKSPACE/dendron/bin/dendron \
+    --pusher \
+    --synchrotron \
+    --port-range ${PORT_BASE}:$((PORT_BASE+PORT_COUNT-1))
+
+cd ..


@@ -4,14 +4,61 @@ set -eux

 : ${WORKSPACE:="$(pwd)"}
+export WORKSPACE

 export PYTHONDONTWRITEBYTECODE=yep
 export SYNAPSE_CACHE_FACTOR=1

-./jenkins/prepare_synapse.sh
-./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git
-./sytest/jenkins/prep_sytest_for_postgres.sh
-
-./sytest/jenkins/install_and_run.sh \
-    --synapse-directory $WORKSPACE \
+# Output test results as junit xml
+export TRIAL_FLAGS="--reporter=subunit"
+export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
+
+# Write coverage reports to a separate file for each process
+export COVERAGE_OPTS="-p"
+export DUMP_COVERAGE_COMMAND="coverage help"
+
+# Output flake8 violations to violations.flake8.log
+# Don't exit with non-0 status code on Jenkins,
+# so that the build steps continue and a later step can decide whether to
+# UNSTABLE or FAILURE this build.
+export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
+
+rm .coverage* || echo "No coverage files to remove"
+
+tox --notest -e py27
+TOX_BIN=$WORKSPACE/.tox/py27/bin
+python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
+$TOX_BIN/pip install psycopg2
+$TOX_BIN/pip install lxml
+
+: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
+
+if [[ ! -e .sytest-base ]]; then
+  git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
+else
+  (cd .sytest-base; git fetch -p)
+fi
+
+rm -rf sytest
+git clone .sytest-base sytest --shared
+cd sytest
+
+git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
+
+: ${PORT_BASE:=8000}
+: ${PORT_COUNT=20}
+
+./jenkins/prep_sytest_for_postgres.sh
+
+echo >&2 "Running sytest with PostgreSQL";
+./jenkins/install_and_run.sh --coverage \
+    --python $TOX_BIN/python \
+    --synapse-directory $WORKSPACE \
+    --port-range ${PORT_BASE}:$((PORT_BASE+PORT_COUNT-1)) \
+
+cd ..
+
+cp sytest/.coverage.* .
+
+# Combine the coverage reports
+echo "Combining:" .coverage.*
+$TOX_BIN/python -m coverage combine
+
+# Output coverage to coverage.xml
+$TOX_BIN/coverage xml -o coverage.xml


@@ -4,12 +4,55 @@ set -eux

 : ${WORKSPACE:="$(pwd)"}
+export WORKSPACE

 export PYTHONDONTWRITEBYTECODE=yep
 export SYNAPSE_CACHE_FACTOR=1

-./jenkins/prepare_synapse.sh
-./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git
-
-./sytest/jenkins/install_and_run.sh \
-    --synapse-directory $WORKSPACE \
+# Output test results as junit xml
+export TRIAL_FLAGS="--reporter=subunit"
+export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
+
+# Write coverage reports to a separate file for each process
+export COVERAGE_OPTS="-p"
+export DUMP_COVERAGE_COMMAND="coverage help"
+
+# Output flake8 violations to violations.flake8.log
+# Don't exit with non-0 status code on Jenkins,
+# so that the build steps continue and a later step can decide whether to
+# UNSTABLE or FAILURE this build.
+export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
+
+rm .coverage* || echo "No coverage files to remove"
+
+tox --notest -e py27
+TOX_BIN=$WORKSPACE/.tox/py27/bin
+python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
+$TOX_BIN/pip install lxml
+
+: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
+
+if [[ ! -e .sytest-base ]]; then
+  git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
+else
+  (cd .sytest-base; git fetch -p)
+fi
+
+rm -rf sytest
+git clone .sytest-base sytest --shared
+cd sytest
+
+git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
+
+: ${PORT_COUNT=20}
+: ${PORT_BASE:=8000}
+
+./jenkins/install_and_run.sh --coverage \
+    --python $TOX_BIN/python \
+    --synapse-directory $WORKSPACE \
+    --port-range ${PORT_BASE}:$((PORT_BASE+PORT_COUNT-1)) \
+
+cd ..
+
+cp sytest/.coverage.* .
+
+# Combine the coverage reports
+echo "Combining:" .coverage.*
+$TOX_BIN/python -m coverage combine
+
+# Output coverage to coverage.xml
+$TOX_BIN/coverage xml -o coverage.xml


@@ -22,9 +22,4 @@ export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished w
 rm .coverage* || echo "No coverage files to remove"

-tox --notest -e py27
-TOX_BIN=$WORKSPACE/.tox/py27/bin
-python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
-$TOX_BIN/pip install lxml
-
 tox -e py27


@@ -1,44 +0,0 @@
#! /bin/bash
# This clones a project from github into a named subdirectory
# If the project has a branch with the same name as this branch
# then it will checkout that branch after cloning.
# Otherwise it will checkout "origin/develop."
# The first argument is the name of the directory to checkout
# the branch into.
# The second argument is the URL of the remote repository to checkout.
# Usually something like https://github.com/matrix-org/sytest.git
set -eux
NAME=$1
PROJECT=$2
BASE=".$NAME-base"
# Update our mirror.
if [ ! -d ".$NAME-base" ]; then
# Create a local mirror of the source repository.
# This saves us from having to download the entire repository
# when this script is next run.
git clone "$PROJECT" "$BASE" --mirror
else
# Fetch any updates from the source repository.
(cd "$BASE"; git fetch -p)
fi
# Remove the existing repository so that we have a clean copy
rm -rf "$NAME"
# Cloning with --shared means that we will share portions of the
# .git directory with our local mirror.
git clone "$BASE" "$NAME" --shared
# Jenkins may have supplied us with the name of the branch in the
# environment. Otherwise we will have to guess based on the current
# commit.
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
cd "$NAME"
# check out the relevant branch
git checkout "${GIT_BRANCH}" || (
echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop"
git checkout "origin/develop"
)


@@ -1,20 +0,0 @@
#! /bin/bash
cd "`dirname $0`/.."
TOX_DIR=$WORKSPACE/.tox
mkdir -p $TOX_DIR
if ! [ $TOX_DIR -ef .tox ]; then
ln -s "$TOX_DIR" .tox
fi
# set up the virtualenv
tox -e py27 --notest -v
TOX_BIN=$TOX_DIR/py27/bin
$TOX_BIN/pip install setuptools
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
$TOX_BIN/pip install lxml
$TOX_BIN/pip install psycopg2


@@ -18,9 +18,7 @@
             <div class="summarytext">{{ summary_text }}</div>
         </td>
         <td class="logo">
-            {% if app_name == "Riot" %}
-                <img src="http://matrix.org/img/riot-logo-email.png" width="83" height="83" alt="[Riot]"/>
-            {% elif app_name == "Vector" %}
+            {% if app_name == "Vector" %}
                 <img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/>
             {% else %}
                 <img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/>


@@ -116,19 +116,17 @@ def get_json(origin_name, origin_key, destination, path):
     authorization_headers = []

     for key, sig in signed_json["signatures"][origin_name].items():
-        header = "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
-            origin_name, key, sig,
-        )
-        authorization_headers.append(bytes(header))
-        sys.stderr.write(header)
-        sys.stderr.write("\n")
+        authorization_headers.append(bytes(
+            "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
+                origin_name, key, sig,
+            )
+        ))

     result = requests.get(
         lookup(destination, path),
         headers={"Authorization": authorization_headers[0]},
         verify=False,
     )
-    sys.stderr.write("Status Code: %d\n" % (result.status_code,))
     return result.json()

@@ -143,7 +141,6 @@ def main():
     )

     json.dump(result, sys.stdout)
-    print ""

 if __name__ == "__main__":
     main()


@@ -34,12 +34,11 @@ logger = logging.getLogger("synapse_port_db")

 BOOLEAN_COLUMNS = {
-    "events": ["processed", "outlier", "contains_url"],
+    "events": ["processed", "outlier"],
     "rooms": ["is_public"],
     "event_edges": ["is_state"],
     "presence_list": ["accepted"],
     "presence_stream": ["currently_active"],
-    "public_room_list_stream": ["visibility"],
 }

@@ -72,14 +71,6 @@ APPEND_ONLY_TABLES = [
     "event_to_state_groups",
     "rejections",
     "event_search",
-    "presence_stream",
-    "push_rules_stream",
-    "current_state_resets",
-    "ex_outlier_stream",
-    "cache_invalidation_stream",
-    "public_room_list_stream",
-    "state_group_edges",
-    "stream_ordering_to_exterm",
 ]

@@ -101,12 +92,8 @@ class Store(object):
     _simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"]
     _simple_select_onecol = SQLBaseStore.__dict__["_simple_select_onecol"]
-    _simple_select_one = SQLBaseStore.__dict__["_simple_select_one"]
-    _simple_select_one_txn = SQLBaseStore.__dict__["_simple_select_one_txn"]
     _simple_select_one_onecol = SQLBaseStore.__dict__["_simple_select_one_onecol"]
-    _simple_select_one_onecol_txn = SQLBaseStore.__dict__[
-        "_simple_select_one_onecol_txn"
-    ]
+    _simple_select_one_onecol_txn = SQLBaseStore.__dict__["_simple_select_one_onecol_txn"]

     _simple_update_one = SQLBaseStore.__dict__["_simple_update_one"]
     _simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"]

@@ -171,40 +158,31 @@ class Porter(object):
     def setup_table(self, table):
         if table in APPEND_ONLY_TABLES:
             # It's safe to just carry on inserting.
-            row = yield self.postgres_store._simple_select_one(
+            next_chunk = yield self.postgres_store._simple_select_one_onecol(
                 table="port_from_sqlite3",
                 keyvalues={"table_name": table},
-                retcols=("forward_rowid", "backward_rowid"),
+                retcol="rowid",
                 allow_none=True,
             )

             total_to_port = None
-            if row is None:
+            if next_chunk is None:
                 if table == "sent_transactions":
-                    forward_chunk, already_ported, total_to_port = (
+                    next_chunk, already_ported, total_to_port = (
                         yield self._setup_sent_transactions()
                     )
-                    backward_chunk = 0
                 else:
                     yield self.postgres_store._simple_insert(
                         table="port_from_sqlite3",
-                        values={
-                            "table_name": table,
-                            "forward_rowid": 1,
-                            "backward_rowid": 0,
-                        }
+                        values={"table_name": table, "rowid": 1}
                     )

-                    forward_chunk = 1
-                    backward_chunk = 0
+                    next_chunk = 1
                     already_ported = 0
-            else:
-                forward_chunk = row["forward_rowid"]
-                backward_chunk = row["backward_rowid"]

             if total_to_port is None:
                 already_ported, total_to_port = yield self._get_total_count_to_port(
-                    table, forward_chunk, backward_chunk
+                    table, next_chunk
                 )
         else:
             def delete_all(txn):

@@ -218,85 +196,46 @@ class Porter(object):
             yield self.postgres_store._simple_insert(
                 table="port_from_sqlite3",
-                values={
-                    "table_name": table,
-                    "forward_rowid": 1,
-                    "backward_rowid": 0,
-                }
+                values={"table_name": table, "rowid": 0}
             )

-            forward_chunk = 1
-            backward_chunk = 0
+            next_chunk = 1
             already_ported, total_to_port = yield self._get_total_count_to_port(
-                table, forward_chunk, backward_chunk
+                table, next_chunk
             )

-        defer.returnValue(
-            (table, already_ported, total_to_port, forward_chunk, backward_chunk)
-        )
+        defer.returnValue((table, already_ported, total_to_port, next_chunk))

     @defer.inlineCallbacks
-    def handle_table(self, table, postgres_size, table_size, forward_chunk,
-                     backward_chunk):
+    def handle_table(self, table, postgres_size, table_size, next_chunk):
         if not table_size:
             return

         self.progress.add_table(table, postgres_size, table_size)

         if table == "event_search":
-            yield self.handle_search_table(
-                postgres_size, table_size, forward_chunk, backward_chunk
-            )
+            yield self.handle_search_table(postgres_size, table_size, next_chunk)
             return

-        forward_select = (
+        select = (
             "SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?"
             % (table,)
         )

-        backward_select = (
-            "SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid LIMIT ?"
-            % (table,)
-        )
-
-        do_forward = [True]
-        do_backward = [True]
-
         while True:
             def r(txn):
-                forward_rows = []
-                backward_rows = []
-                if do_forward[0]:
-                    txn.execute(forward_select, (forward_chunk, self.batch_size,))
-                    forward_rows = txn.fetchall()
-                    if not forward_rows:
-                        do_forward[0] = False
+                txn.execute(select, (next_chunk, self.batch_size,))
+                rows = txn.fetchall()
+                headers = [column[0] for column in txn.description]

-                if do_backward[0]:
-                    txn.execute(backward_select, (backward_chunk, self.batch_size,))
-                    backward_rows = txn.fetchall()
-                    if not backward_rows:
-                        do_backward[0] = False
+                return headers, rows

-                if forward_rows or backward_rows:
-                    headers = [column[0] for column in txn.description]
-                else:
-                    headers = None
+            headers, rows = yield self.sqlite_store.runInteraction("select", r)

-                return headers, forward_rows, backward_rows
+            if rows:
+                next_chunk = rows[-1][0] + 1

-            headers, frows, brows = yield self.sqlite_store.runInteraction(
-                "select", r
-            )
-
-            if frows or brows:
-                if frows:
-                    forward_chunk = max(row[0] for row in frows) + 1
-                if brows:
-                    backward_chunk = min(row[0] for row in brows) - 1
-
-                rows = frows + brows
                 self._convert_rows(table, headers, rows)

                 def insert(txn):

@@ -308,10 +247,7 @@ class Porter(object):
                         txn,
                         table="port_from_sqlite3",
                         keyvalues={"table_name": table},
-                        updatevalues={
-                            "forward_rowid": forward_chunk,
-                            "backward_rowid": backward_chunk,
-                        },
+                        updatevalues={"rowid": next_chunk},
                     )

                 yield self.postgres_store.execute(insert)

@@ -323,8 +259,7 @@ class Porter(object):
                 return

     @defer.inlineCallbacks
-    def handle_search_table(self, postgres_size, table_size, forward_chunk,
-                            backward_chunk):
+    def handle_search_table(self, postgres_size, table_size, next_chunk):
         select = (
             "SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering"
             " FROM event_search as es"

@@ -335,7 +270,7 @@ class Porter(object):
         while True:
             def r(txn):
-                txn.execute(select, (forward_chunk, self.batch_size,))
+                txn.execute(select, (next_chunk, self.batch_size,))
                 rows = txn.fetchall()
                 headers = [column[0] for column in txn.description]

@@ -344,7 +279,7 @@ class Porter(object):
             headers, rows = yield self.sqlite_store.runInteraction("select", r)

             if rows:
-                forward_chunk = rows[-1][0] + 1
+                next_chunk = rows[-1][0] + 1

                 # We have to treat event_search differently since it has a
                 # different structure in the two different databases.

@@ -377,10 +312,7 @@ class Porter(object):
                         txn,
                         table="port_from_sqlite3",
                         keyvalues={"table_name": "event_search"},
-                        updatevalues={
-                            "forward_rowid": forward_chunk,
-                            "backward_rowid": backward_chunk,
-                        },
+                        updatevalues={"rowid": next_chunk},
                     )

                 yield self.postgres_store.execute(insert)

@@ -392,6 +324,7 @@ class Porter(object):
             else:
                 return

+
     def setup_db(self, db_config, database_engine):
         db_conn = database_engine.module.connect(
             **{

@@ -462,32 +395,10 @@ class Porter(object):
             txn.execute(
                 "CREATE TABLE port_from_sqlite3 ("
                 " table_name varchar(100) NOT NULL UNIQUE,"
-                " forward_rowid bigint NOT NULL,"
-                " backward_rowid bigint NOT NULL"
+                " rowid bigint NOT NULL"
                 ")"
             )

-        # The old port script created a table with just a "rowid" column.
-        # We want people to be able to rerun this script from an old port
-        # so that they can pick up any missing events that were not
-        # ported across.
-        def alter_table(txn):
-            txn.execute(
-                "ALTER TABLE IF EXISTS port_from_sqlite3"
-                " RENAME rowid TO forward_rowid"
-            )
-            txn.execute(
-                "ALTER TABLE IF EXISTS port_from_sqlite3"
-                " ADD backward_rowid bigint NOT NULL DEFAULT 0"
-            )
-
-        try:
-            yield self.postgres_store.runInteraction(
-                "alter_table", alter_table
-            )
-        except Exception as e:
-            logger.info("Failed to create port table: %s", e)
-
         try:
             yield self.postgres_store.runInteraction(
                 "create_port_table", create_port_table

@@ -547,7 +458,7 @@ class Porter(object):
     @defer.inlineCallbacks
     def _setup_sent_transactions(self):
         # Only save things from the last day
-        yesterday = int(time.time() * 1000) - 86400000
+        yesterday = int(time.time()*1000) - 86400000

         # And save the max transaction id from each destination
         select = (

@@ -603,11 +514,7 @@ class Porter(object):
         yield self.postgres_store._simple_insert(
             table="port_from_sqlite3",
-            values={
-                "table_name": "sent_transactions",
-                "forward_rowid": next_chunk,
-                "backward_rowid": 0,
-            }
+            values={"table_name": "sent_transactions", "rowid": next_chunk}
         )

         def get_sent_table_size(txn):

@@ -628,18 +535,13 @@ class Porter(object):
         defer.returnValue((next_chunk, inserted_rows, total_count))

     @defer.inlineCallbacks
-    def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk):
-        frows = yield self.sqlite_store.execute_sql(
+    def _get_remaining_count_to_port(self, table, next_chunk):
+        rows = yield self.sqlite_store.execute_sql(
             "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,),
-            forward_chunk,
+            next_chunk,
         )

-        brows = yield self.sqlite_store.execute_sql(
-            "SELECT count(*) FROM %s WHERE rowid <= ?" % (table,),
-            backward_chunk,
-        )
-
-        defer.returnValue(frows[0][0] + brows[0][0])
+        defer.returnValue(rows[0][0])

     @defer.inlineCallbacks
     def _get_already_ported_count(self, table):

@@ -650,10 +552,10 @@ class Porter(object):
         defer.returnValue(rows[0][0])

     @defer.inlineCallbacks
-    def _get_total_count_to_port(self, table, forward_chunk, backward_chunk):
+    def _get_total_count_to_port(self, table, next_chunk):
         remaining, done = yield defer.gatherResults(
             [
-                self._get_remaining_count_to_port(table, forward_chunk, backward_chunk),
+                self._get_remaining_count_to_port(table, next_chunk),
                 self._get_already_ported_count(table),
             ],
             consumeErrors=True,

@@ -784,7 +686,7 @@ class CursesProgress(Progress):
             color = curses.color_pair(2) if perc == 100 else curses.color_pair(1)

             self.stdscr.addstr(
-                i + 2, left_margin + max_len - len(table),
+                i+2, left_margin + max_len - len(table),
                 table,
                 curses.A_BOLD | color,
             )

@@ -792,18 +694,18 @@ class CursesProgress(Progress):
             size = 20

             progress = "[%s%s]" % (
-                "#" * int(perc * size / 100),
-                " " * (size - int(perc * size / 100)),
+                "#" * int(perc*size/100),
+                " " * (size - int(perc*size/100)),
             )

             self.stdscr.addstr(
-                i + 2, left_margin + max_len + middle_space,
+                i+2, left_margin + max_len + middle_space,
                 "%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]),
             )

         if self.finished:
             self.stdscr.addstr(
-                rows - 1, 0,
+                rows-1, 0,
                 "Press any key to exit...",
             )


@@ -16,5 +16,7 @@ ignore =

 [flake8]
 max-line-length = 90
-# W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
-ignore = W503
+ignore = W503 ; W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
+
+[pep8]
+max-line-length = 90


@@ -16,4 +16,4 @@
 """ This is a reference implementation of a Matrix home server.
 """

-__version__ = "0.18.4"
+__version__ = "0.16.1-r1"


@@ -13,22 +13,22 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging
import pymacaroons
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json, SignatureVerifyException from signedjson.sign import verify_signed_json, SignatureVerifyException
from twisted.internet import defer
from unpaddedbase64 import decode_base64
import synapse.types from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
from synapse.types import UserID, get_domain_from_id from synapse.types import Requester, UserID, get_domain_from_id
from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from unpaddedbase64 import decode_base64
import logging
import pymacaroons
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -52,7 +52,7 @@ class Auth(object):
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
# Docs for these currently lives at # Docs for these currently lives at
# github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst # https://github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst
# In addition, we have type == delete_pusher which grants access only to # In addition, we have type == delete_pusher which grants access only to
# delete pushers. # delete pushers.
self._KNOWN_CAVEAT_PREFIXES = set([ self._KNOWN_CAVEAT_PREFIXES = set([
@@ -63,18 +63,7 @@ class Auth(object):
"user_id = ", "user_id = ",
]) ])
@defer.inlineCallbacks def check(self, event, auth_events):
def check_from_context(self, event, context, do_sig_check=True):
auth_events_ids = yield self.compute_auth_events(
event, context.prev_state_ids, for_verification=True,
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = {
(e.type, e.state_key): e for e in auth_events.values()
}
self.check(event, auth_events=auth_events, do_sig_check=do_sig_check)
def check(self, event, auth_events, do_sig_check=True):
""" Checks if this event is correctly authed. """ Checks if this event is correctly authed.
Args: Args:
@@ -90,30 +79,6 @@ class Auth(object):
if not hasattr(event, "room_id"): if not hasattr(event, "room_id"):
raise AuthError(500, "Event has no room_id: %s" % event) raise AuthError(500, "Event has no room_id: %s" % event)
if do_sig_check:
sender_domain = get_domain_from_id(event.sender)
event_id_domain = get_domain_from_id(event.event_id)
is_invite_via_3pid = (
event.type == EventTypes.Member
and event.membership == Membership.INVITE
and "third_party_invite" in event.content
)
# Check the sender's domain has signed the event
if not event.signatures.get(sender_domain):
# We allow invites via 3pid to have a sender from a different
# HS, as the sender must match the sender of the original
# 3pid invite. This is checked further down with the
# other dedicated membership checks.
if not is_invite_via_3pid:
raise AuthError(403, "Event not signed by sender's server")
# Check the event_id's domain has signed the event
if not event.signatures.get(event_id_domain):
raise AuthError(403, "Event not signed by sending server")
if auth_events is None: if auth_events is None:
# Oh, we don't know what the state of the room was, so we # Oh, we don't know what the state of the room was, so we
# are trusting that this is allowed (at least for now) # are trusting that this is allowed (at least for now)
@@ -121,12 +86,6 @@ class Auth(object):
return True return True
if event.type == EventTypes.Create: if event.type == EventTypes.Create:
room_id_domain = get_domain_from_id(event.room_id)
if room_id_domain != sender_domain:
raise AuthError(
403,
"Creation event's room_id domain does not match sender's"
)
# FIXME # FIXME
return True return True
@@ -149,22 +108,6 @@ class Auth(object):
# FIXME: Temp hack # FIXME: Temp hack
if event.type == EventTypes.Aliases: if event.type == EventTypes.Aliases:
if not event.is_state():
raise AuthError(
403,
"Alias event must be a state event",
)
if not event.state_key:
raise AuthError(
403,
"Alias event must have non-empty state_key"
)
sender_domain = get_domain_from_id(event.sender)
if event.state_key != sender_domain:
raise AuthError(
403,
"Alias event's state_key does not match sender's domain"
)
return True return True
logger.debug( logger.debug(
@@ -295,17 +238,21 @@ class Auth(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def check_host_in_room(self, room_id, host): def check_host_in_room(self, room_id, host):
with Measure(self.clock, "check_host_in_room"): curr_state = yield self.state.get_current_state(room_id)
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
entry = yield self.state.resolve_state_groups( for event in curr_state.values():
room_id, latest_event_ids if event.type == EventTypes.Member:
) try:
if get_domain_from_id(event.state_key) != host:
continue
except:
logger.warn("state_key not user_id: %s", event.state_key)
continue
ret = yield self.store.is_host_joined( if event.content["membership"] == Membership.JOIN:
room_id, host, entry.state_group, entry.state defer.returnValue(True)
)
defer.returnValue(ret) defer.returnValue(False)
def check_event_sender_in_room(self, event, auth_events): def check_event_sender_in_room(self, event, auth_events):
key = (EventTypes.Member, event.user_id, ) key = (EventTypes.Member, event.user_id, )
@@ -400,10 +347,6 @@ class Auth(object):
        if Membership.INVITE == membership and "third_party_invite" in event.content:
            if not self._verify_third_party_invite(event, auth_events):
                raise AuthError(403, "You are not invited to this room.")
-           if target_banned:
-               raise AuthError(
-                   403, "%s is banned from the room" % (target_user_id,)
-               )
            return True

        if Membership.JOIN != membership:
@@ -508,9 +451,6 @@ class Auth(object):
        if not invite_event:
            return False

-       if invite_event.sender != event.sender:
-           return False
-
        if event.user_id != invite_event.user_id:
            return False
@@ -597,32 +537,27 @@ class Auth(object):
        Args:
            request - An HTTP request with an access_token query parameter.
        Returns:
-           defer.Deferred: resolves to a ``synapse.types.Requester`` object
+           tuple of:
+               UserID (str)
+               Access token ID (str)
        Raises:
            AuthError if no user by that token exists or the token is invalid.
        """
        # Can optionally look elsewhere in the request (e.g. headers)
        try:
-           user_id, app_service = yield self._get_appservice_user_id(request)
+           user_id = yield self._get_appservice_user_id(request.args)
            if user_id:
                request.authenticated_entity = user_id
                defer.returnValue(
-                   synapse.types.create_requester(user_id, app_service=app_service)
+                   Requester(UserID.from_string(user_id), "", False)
                )

-           access_token = get_access_token_from_request(
-               request, self.TOKEN_NOT_FOUND_HTTP_STATUS
-           )
+           access_token = request.args["access_token"][0]

            user_info = yield self.get_user_by_access_token(access_token, rights)
            user = user_info["user"]
            token_id = user_info["token_id"]
            is_guest = user_info["is_guest"]

-           # device_id may not be present if get_user_by_access_token has been
-           # stubbed out.
-           device_id = user_info.get("device_id")
-
            ip_addr = self.hs.get_ip_from_request(request)
            user_agent = request.requestHeaders.getRawHeaders(
                "User-Agent",
@@ -634,8 +569,7 @@ class Auth(object):
                user=user,
                access_token=access_token,
                ip=ip_addr,
-               user_agent=user_agent,
-               device_id=device_id,
+               user_agent=user_agent
            )

        if is_guest and not allow_guest:
@@ -645,9 +579,7 @@ class Auth(object):
            request.authenticated_entity = user.to_string()

-           defer.returnValue(synapse.types.create_requester(
-               user, token_id, is_guest, device_id, app_service=app_service)
-           )
+           defer.returnValue(Requester(user, token_id, is_guest))
        except KeyError:
            raise AuthError(
                self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
@@ -655,21 +587,19 @@ class Auth(object):
            )
    @defer.inlineCallbacks
-   def _get_appservice_user_id(self, request):
-       app_service = self.store.get_app_service_by_token(
-           get_access_token_from_request(
-               request, self.TOKEN_NOT_FOUND_HTTP_STATUS
-           )
+   def _get_appservice_user_id(self, request_args):
+       app_service = yield self.store.get_app_service_by_token(
+           request_args["access_token"][0]
        )
        if app_service is None:
-           defer.returnValue((None, None))
+           defer.returnValue(None)

-       if "user_id" not in request.args:
-           defer.returnValue((app_service.sender, app_service))
+       if "user_id" not in request_args:
+           defer.returnValue(app_service.sender)

-       user_id = request.args["user_id"][0]
+       user_id = request_args["user_id"][0]
        if app_service.sender == user_id:
-           defer.returnValue((app_service.sender, app_service))
+           defer.returnValue(app_service.sender)

        if not app_service.is_interested_in_user(user_id):
            raise AuthError(
@@ -681,7 +611,7 @@ class Auth(object):
            403,
            "Application service has not registered this user"
        )
-       defer.returnValue((user_id, app_service))
+       defer.returnValue(user_id)
    @defer.inlineCallbacks
    def get_user_by_access_token(self, token, rights="access"):
@@ -699,10 +629,7 @@ class Auth(object):
        except AuthError:
            # TODO(daniel): Remove this fallback when all existing access tokens
            # have been re-issued as macaroons.
-           if self.hs.config.expire_access_token:
-               raise
            ret = yield self._look_up_user_by_access_token(token)
            defer.returnValue(ret)

    @defer.inlineCallbacks
@@ -710,25 +637,33 @@ class Auth(object):
        try:
            macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)

-           user_id = self.get_user_id_from_macaroon(macaroon)
-           user = UserID.from_string(user_id)
+           user_prefix = "user_id = "
+           user = None
+           user_id = None
+           guest = False
+           for caveat in macaroon.caveats:
+               if caveat.caveat_id.startswith(user_prefix):
+                   user_id = caveat.caveat_id[len(user_prefix):]
+                   user = UserID.from_string(user_id)
+               elif caveat.caveat_id == "guest = true":
+                   guest = True

            self.validate_macaroon(
                macaroon, rights, self.hs.config.expire_access_token,
                user_id=user_id,
            )

-           guest = False
-           for caveat in macaroon.caveats:
-               if caveat.caveat_id == "guest = true":
-                   guest = True
+           if user is None:
+               raise AuthError(
+                   self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
+                   errcode=Codes.UNKNOWN_TOKEN
+               )

            if guest:
                ret = {
                    "user": user,
                    "is_guest": True,
                    "token_id": None,
-                   "device_id": None,
                }
            elif rights == "delete_pusher":
                # We don't store these tokens in the database
@@ -736,20 +671,13 @@ class Auth(object):
"user": user, "user": user,
"is_guest": False, "is_guest": False,
"token_id": None, "token_id": None,
"device_id": None,
} }
            else:
-               # This codepath exists for several reasons:
-               #   * so that we can actually return a token ID, which is used
-               #     in some parts of the schema (where we probably ought to
-               #     use device IDs instead)
-               #   * the only way we currently have to invalidate an
-               #     access_token is by removing it from the database, so we
-               #     have to check here that it is still in the db
-               #   * some attributes (notably device_id) aren't stored in the
-               #     macaroon. They probably should be.
-               # TODO: build the dictionary from the macaroon once the
-               # above are fixed
+               # This codepath exists so that we can actually return a
+               # token ID, because we use token IDs in place of device
+               # identifiers throughout the codebase.
+               # TODO(daniel): Remove this fallback when device IDs are
+               # properly implemented.
                ret = yield self._look_up_user_by_access_token(macaroon_str)
                if ret["user"] != user:
                    logger.error(
@@ -769,29 +697,6 @@ class Auth(object):
                errcode=Codes.UNKNOWN_TOKEN
            )

-   def get_user_id_from_macaroon(self, macaroon):
-       """Retrieve the user_id given by the caveats on the macaroon.
-
-       Does *not* validate the macaroon.
-
-       Args:
-           macaroon (pymacaroons.Macaroon): The macaroon to validate
-
-       Returns:
-           (str) user id
-
-       Raises:
-           AuthError if there is no user_id caveat in the macaroon
-       """
-       user_prefix = "user_id = "
-       for caveat in macaroon.caveats:
-           if caveat.caveat_id.startswith(user_prefix):
-               return caveat.caveat_id[len(user_prefix):]
-       raise AuthError(
-           self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
-           errcode=Codes.UNKNOWN_TOKEN
-       )
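Both versions read the user out of first-party caveats of the form "user_id = @alice:example.com". A standalone sketch of that caveat convention using pymacaroons (the location, identifier, and key here are made up for the example):

import pymacaroons

# Illustrative values; a real homeserver derives these from its config.
macaroon = pymacaroons.Macaroon(
    location="example.com", identifier="key1", key="not-a-real-secret"
)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = @alice:example.com")
macaroon.add_first_party_caveat("type = access")

serialized = macaroon.serialize()

# Parsing mirrors the loop above: scan the caveats for the user_id prefix.
parsed = pymacaroons.Macaroon.deserialize(serialized)
user_prefix = "user_id = "
user_id = None
for caveat in parsed.caveats:
    if caveat.caveat_id.startswith(user_prefix):
        user_id = caveat.caveat_id[len(user_prefix):]

print(user_id)  # @alice:example.com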
    def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id):
        """
        validate that a Macaroon is understood by and was signed by this server.
@@ -803,7 +708,6 @@ class Auth(object):
            verify_expiry(bool): Whether to verify whether the macaroon has expired.
                This should really always be True, but no clients currently implement
                token refresh, so we can't enforce expiry yet.
-           user_id (str): The user_id required
        """

        v = pymacaroons.Verifier()
        v.satisfy_exact("gen = 1")
@@ -847,23 +751,18 @@ class Auth(object):
                self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
                errcode=Codes.UNKNOWN_TOKEN
            )
-       # we use ret.get() below because *lots* of unit tests stub out
-       # get_user_by_access_token in a way where it only returns a couple of
-       # the fields.
        user_info = {
            "user": UserID.from_string(ret.get("name")),
            "token_id": ret.get("token_id", None),
            "is_guest": False,
-           "device_id": ret.get("device_id"),
        }
        defer.returnValue(user_info)
+   @defer.inlineCallbacks
    def get_appservice_by_req(self, request):
        try:
-           token = get_access_token_from_request(
-               request, self.TOKEN_NOT_FOUND_HTTP_STATUS
-           )
-           service = self.store.get_app_service_by_token(token)
+           token = request.args["access_token"][0]
+           service = yield self.store.get_app_service_by_token(token)
            if not service:
                logger.warn("Unrecognised appservice access token: %s" % (token,))
                raise AuthError(
@@ -872,7 +771,7 @@ class Auth(object):
                    errcode=Codes.UNKNOWN_TOKEN
                )
            request.authenticated_entity = service.sender
-           return defer.succeed(service)
+           defer.returnValue(service)
        except KeyError:
            raise AuthError(
                self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token."
@@ -883,7 +782,7 @@ class Auth(object):
    @defer.inlineCallbacks
    def add_auth_events(self, builder, context):
-       auth_ids = yield self.compute_auth_events(builder, context.prev_state_ids)
+       auth_ids = self.compute_auth_events(builder, context.current_state)

        auth_events_entries = yield self.store.add_event_hashes(
            auth_ids
@@ -891,32 +790,30 @@ class Auth(object):
        builder.auth_events = auth_events_entries
-   @defer.inlineCallbacks
-   def compute_auth_events(self, event, current_state_ids, for_verification=False):
+   def compute_auth_events(self, event, current_state):
        if event.type == EventTypes.Create:
-           defer.returnValue([])
+           return []

        auth_ids = []

        key = (EventTypes.PowerLevels, "", )
-       power_level_event_id = current_state_ids.get(key)
+       power_level_event = current_state.get(key)

-       if power_level_event_id:
-           auth_ids.append(power_level_event_id)
+       if power_level_event:
+           auth_ids.append(power_level_event.event_id)

        key = (EventTypes.JoinRules, "", )
-       join_rule_event_id = current_state_ids.get(key)
+       join_rule_event = current_state.get(key)

        key = (EventTypes.Member, event.user_id, )
-       member_event_id = current_state_ids.get(key)
+       member_event = current_state.get(key)

        key = (EventTypes.Create, "", )
-       create_event_id = current_state_ids.get(key)
+       create_event = current_state.get(key)

-       if create_event_id:
-           auth_ids.append(create_event_id)
+       if create_event:
+           auth_ids.append(create_event.event_id)

-       if join_rule_event_id:
-           join_rule_event = yield self.store.get_event(join_rule_event_id)
+       if join_rule_event:
            join_rule = join_rule_event.content.get("join_rule")
            is_public = join_rule == JoinRules.PUBLIC if join_rule else False
        else:
@@ -925,21 +822,15 @@ class Auth(object):
        if event.type == EventTypes.Member:
            e_type = event.content["membership"]
            if e_type in [Membership.JOIN, Membership.INVITE]:
-               if join_rule_event_id:
-                   auth_ids.append(join_rule_event_id)
+               if join_rule_event:
+                   auth_ids.append(join_rule_event.event_id)

            if e_type == Membership.JOIN:
-               if member_event_id and not is_public:
-                   auth_ids.append(member_event_id)
+               if member_event and not is_public:
+                   auth_ids.append(member_event.event_id)
            else:
-               if member_event_id:
-                   auth_ids.append(member_event_id)
-
-               if for_verification:
-                   key = (EventTypes.Member, event.state_key, )
-                   existing_event_id = current_state_ids.get(key)
-                   if existing_event_id:
-                       auth_ids.append(existing_event_id)
+               if member_event:
+                   auth_ids.append(member_event.event_id)

                if e_type == Membership.INVITE:
                    if "third_party_invite" in event.content:
@@ -947,15 +838,14 @@ class Auth(object):
                            EventTypes.ThirdPartyInvite,
                            event.content["third_party_invite"]["signed"]["token"]
                        )
-                       third_party_invite_id = current_state_ids.get(key)
-                       if third_party_invite_id:
-                           auth_ids.append(third_party_invite_id)
-               elif member_event_id:
-                   member_event = yield self.store.get_event(member_event_id)
+                       third_party_invite = current_state.get(key)
+                       if third_party_invite:
+                           auth_ids.append(third_party_invite.event_id)
+               elif member_event:
                    if member_event.content["membership"] == Membership.JOIN:
                        auth_ids.append(member_event.event_id)

-       defer.returnValue(auth_ids)
+       return auth_ids
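The two versions differ mainly in the shape of the state they walk: the newer one takes a map from (event type, state key) to event IDs, so it only hits the store when it actually needs an event body, while the older one takes full events. A toy illustration of the ID-map shape (all values invented):

# Newer shape: state as event IDs; fetching a body is an extra store lookup.
current_state_ids = {
    ("m.room.create", ""): "$create:example.com",
    ("m.room.power_levels", ""): "$power:example.com",
    ("m.room.member", "@alice:example.com"): "$alice_join:example.com",
}

auth_ids = []
for key in [("m.room.power_levels", ""), ("m.room.create", "")]:
    event_id = current_state_ids.get(key)
    if event_id:
        auth_ids.append(event_id)

print(auth_ids)
# ['$power:example.com', '$create:example.com']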
    def _get_send_level(self, etype, state_key, auth_events):
        key = (EventTypes.PowerLevels, "", )
@@ -1004,6 +894,16 @@ class Auth(object):
                403,
                "You are not allowed to set others state"
            )
+       else:
+           sender_domain = UserID.from_string(
+               event.user_id
+           ).domain
+
+           if sender_domain != event.state_key:
+               raise AuthError(
+                   403,
+                   "You are not allowed to set others state"
+               )

        return True
@@ -1161,68 +1061,3 @@ class Auth(object):
"This server requires you to be a moderator in the room to" "This server requires you to be a moderator in the room to"
" edit its room list entry" " edit its room list entry"
) )
def has_access_token(request):
"""Checks if the request has an access_token.
Returns:
bool: False if no access_token was given, True otherwise.
"""
query_params = request.args.get("access_token")
auth_headers = request.requestHeaders.getRawHeaders("Authorization")
return bool(query_params) or bool(auth_headers)
def get_access_token_from_request(request, token_not_found_http_status=401):
"""Extracts the access_token from the request.
Args:
request: The http request.
token_not_found_http_status(int): The HTTP status code to set in the
AuthError if the token isn't found. This is used in some of the
legacy APIs to change the status code to 403 from the default of
401 since some of the old clients depended on auth errors returning
403.
Returns:
str: The access_token
Raises:
AuthError: If there isn't an access_token in the request.
"""
auth_headers = request.requestHeaders.getRawHeaders("Authorization")
query_params = request.args.get("access_token")
if auth_headers:
# Try the get the access_token from a "Authorization: Bearer"
# header
if query_params is not None:
raise AuthError(
token_not_found_http_status,
"Mixing Authorization headers and access_token query parameters.",
errcode=Codes.MISSING_TOKEN,
)
if len(auth_headers) > 1:
raise AuthError(
token_not_found_http_status,
"Too many Authorization headers.",
errcode=Codes.MISSING_TOKEN,
)
parts = auth_headers[0].split(" ")
if parts[0] == "Bearer" and len(parts) == 2:
return parts[1]
else:
raise AuthError(
token_not_found_http_status,
"Invalid Authorization header.",
errcode=Codes.MISSING_TOKEN,
)
else:
# Try to get the access_token from the query params.
if not query_params:
raise AuthError(
token_not_found_http_status,
"Missing access token.",
errcode=Codes.MISSING_TOKEN
)
return query_params[0]
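These removed helpers formalise the two places a client may put its token: an "Authorization: Bearer <token>" header or an "?access_token=" query parameter, treating the presence of both as an error. A self-contained sketch of the same precedence rules, assuming plain list inputs rather than a Twisted request object:

def extract_access_token(auth_headers, query_params):
    """auth_headers: list of Authorization header values; query_params:
    list of access_token query values. Mirrors the precedence above."""
    if auth_headers:
        if query_params:
            raise ValueError("Mixing Authorization headers and access_token "
                             "query parameters.")
        if len(auth_headers) > 1:
            raise ValueError("Too many Authorization headers.")
        parts = auth_headers[0].split(" ")
        if parts[0] == "Bearer" and len(parts) == 2:
            return parts[1]
        raise ValueError("Invalid Authorization header.")
    if not query_params:
        raise ValueError("Missing access token.")
    return query_params[0]

print(extract_access_token(["Bearer syt_secret"], []))  # syt_secret
print(extract_access_token([], ["syt_secret"]))         # syt_secret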

View File

@@ -85,8 +85,3 @@ class RoomCreationPreset(object):
    PRIVATE_CHAT = "private_chat"
    PUBLIC_CHAT = "public_chat"
    TRUSTED_PRIVATE_CHAT = "trusted_private_chat"
class ThirdPartyEntityKind(object):
USER = "user"
LOCATION = "location"

View File

@@ -43,7 +43,6 @@ class Codes(object):
    EXCLUSIVE = "M_EXCLUSIVE"
    THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
    THREEPID_IN_USE = "M_THREEPID_IN_USE"
THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND"
    INVALID_USERNAME = "M_INVALID_USERNAME"
    SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"

View File

@@ -191,17 +191,6 @@ class Filter(object):
    def __init__(self, filter_json):
        self.filter_json = filter_json

-       self.types = self.filter_json.get("types", None)
-       self.not_types = self.filter_json.get("not_types", [])
-
-       self.rooms = self.filter_json.get("rooms", None)
-       self.not_rooms = self.filter_json.get("not_rooms", [])
-
-       self.senders = self.filter_json.get("senders", None)
-       self.not_senders = self.filter_json.get("not_senders", [])
-
-       self.contains_url = self.filter_json.get("contains_url", None)

    def check(self, event):
        """Checks whether the filter matches the given event.
@@ -220,10 +209,9 @@ class Filter(object):
event.get("room_id", None), event.get("room_id", None),
sender, sender,
event.get("type", None), event.get("type", None),
"url" in event.get("content", {})
) )
def check_fields(self, room_id, sender, event_type, contains_url): def check_fields(self, room_id, sender, event_type):
"""Checks whether the filter matches the given event fields. """Checks whether the filter matches the given event fields.
Returns: Returns:
@@ -237,20 +225,15 @@ class Filter(object):
        for name, match_func in literal_keys.items():
            not_name = "not_%s" % (name,)
-           disallowed_values = getattr(self, not_name)
+           disallowed_values = self.filter_json.get(not_name, [])
            if any(map(match_func, disallowed_values)):
                return False

-           allowed_values = getattr(self, name)
+           allowed_values = self.filter_json.get(name, None)
            if allowed_values is not None:
                if not any(map(match_func, allowed_values)):
                    return False

-       contains_url_filter = self.filter_json.get("contains_url")
-       if contains_url_filter is not None:
-           if contains_url_filter != contains_url:
-               return False
-
        return True

    def filter_rooms(self, room_ids):
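Either way, check_fields ANDs together per-field allow/deny lists from the filter JSON; the newer version additionally matches on whether the event content carries a "url" key. A rough standalone version of the matching rule for one field (the field names follow Matrix filter JSON; the helper itself is invented):

def field_matches(filter_json, name, value, match_func):
    """Apply the "name"/"not_name" pair from a Matrix filter JSON."""
    disallowed = filter_json.get("not_%s" % (name,), [])
    if any(match_func(value, d) for d in disallowed):
        return False
    allowed = filter_json.get(name)
    if allowed is not None:
        return any(match_func(value, a) for a in allowed)
    return True

literal = lambda actual, pattern: actual == pattern

filter_json = {"types": ["m.room.message"], "not_senders": ["@spam:example.com"]}
ok = (field_matches(filter_json, "types", "m.room.message", literal)
      and field_matches(filter_json, "senders", "@alice:example.com", literal))
print(ok)  # True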

View File

@@ -23,7 +23,7 @@ class Ratelimiter(object):
    def __init__(self):
        self.message_counts = collections.OrderedDict()

-   def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count, update=True):
+   def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count):
        """Can the user send a message?
        Args:
            user_id: The user sending a message.
@@ -32,15 +32,12 @@ class Ratelimiter(object):
                second.
            burst_count: How many messages the user can send before being
                limited.
-           update (bool): Whether to update the message rates or not. This is
-               useful to check if a message would be allowed to be sent before
-               its ready to be actually sent.
        Returns:
            A pair of a bool indicating if they can send a message now and a
            time in seconds of when they can next send a message.
        """
        self.prune_message_counts(time_now_s)
-       message_count, time_start, _ignored = self.message_counts.get(
+       message_count, time_start, _ignored = self.message_counts.pop(
            user_id, (0., time_now_s, None),
        )
        time_delta = time_now_s - time_start
@@ -55,10 +52,9 @@ class Ratelimiter(object):
            allowed = True
            message_count += 1
-           if update:
-               self.message_counts[user_id] = (
-                   message_count, time_start, msg_rate_hz
-               )
+           self.message_counts[user_id] = (
+               message_count, time_start, msg_rate_hz
+           )

        if msg_rate_hz > 0:
            time_allowed = (
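The `update` flag on the left lets callers ask "would this message be allowed?" without charging it against the user's budget; the right-hand version always records the attempt. A minimal sketch of that dry-run pattern on a toy leaky-bucket limiter (not Synapse's class):

class ToyRatelimiter(object):
    def __init__(self):
        self.counts = {}  # user_id -> (message_count, window_start)

    def send_message(self, user_id, now, rate_hz, burst, update=True):
        count, start = self.counts.get(user_id, (0.0, now))
        # Leak: credit back allowance for time elapsed since the window began.
        count = max(0.0, count - (now - start) * rate_hz)
        allowed = count < burst
        if allowed and update:
            # Only a real send spends budget; a dry run leaves state alone.
            self.counts[user_id] = (count + 1, now)
        return allowed

limiter = ToyRatelimiter()
# Dry-run first, then actually send if it would be allowed.
if limiter.send_message("@alice:example.com", 0.0, 0.2, 10, update=False):
    limiter.send_message("@alice:example.com", 0.0, 0.2, 10)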

View File

@@ -25,3 +25,4 @@ SERVER_KEY_PREFIX = "/_matrix/key/v1"
SERVER_KEY_V2_PREFIX = "/_matrix/key/v2"
MEDIA_PREFIX = "/_matrix/media/r0"
LEGACY_MEDIA_PREFIX = "/_matrix/media/v1"
+APP_SERVICE_PREFIX = "/_matrix/appservice/v1"

View File

@@ -16,11 +16,13 @@
import sys
sys.dont_write_bytecode = True

-from synapse import python_dependencies  # noqa: E402
+from synapse.python_dependencies import (
+    check_requirements, MissingRequirementError
+)  # NOQA

try:
-    python_dependencies.check_requirements()
-except python_dependencies.MissingRequirementError as e:
+    check_requirements()
+except MissingRequirementError as e:
    message = "\n".join([
        "Missing Requirement: %s" % (e.message,),
        "To install run:",

View File

@@ -1,214 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse
from synapse.server import HomeServer
from synapse.config._base import ConfigError
from synapse.config.logger import setup_logging
from synapse.config.homeserver import HomeServerConfig
from synapse.http.site import SynapseSite
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.storage.engines import create_engine
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from synapse import events
from twisted.internet import reactor, defer
from twisted.web.resource import Resource
from daemonize import Daemonize
import sys
import logging
import gc
logger = logging.getLogger("synapse.app.appservice")
class AppserviceSlaveStore(
DirectoryStore, SlavedEventStore, SlavedApplicationServiceStore,
SlavedRegistrationStore,
):
pass
class AppserviceServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = AppserviceSlaveStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
logger.info("Synapse appservice now listening on port %d", port)
def start_listening(self, listeners):
for listener in listeners:
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
appservice_handler = self.get_application_service_handler()
@defer.inlineCallbacks
def replicate(results):
stream = results.get("events")
if stream:
max_stream_id = stream["position"]
yield appservice_handler.notify_interested_services(max_stream_id)
while True:
try:
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
yield store.process_replication(result)
replicate(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(30)
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse appservice", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.appservice"
setup_logging(config.worker_log_config, config.worker_log_file)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
database_engine = create_engine(config.database_config)
if config.notify_appservices:
sys.stderr.write(
"\nThe appservices must be disabled in the main synapse process"
"\nbefore they can be run in a separate worker."
"\nPlease add ``notify_appservices: false`` to the main config"
"\n"
)
sys.exit(1)
# Force the pushers to start since they will be disabled in the main config
config.notify_appservices = True
ps = AppserviceServer(
config.server_name,
db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ps.setup()
ps.start_listening(config.worker_listeners)
def run():
with LoggingContext("run"):
logger.info("Running")
change_resource_limit(config.soft_file_limit)
if config.gc_thresholds:
gc.set_threshold(*config.gc_thresholds)
reactor.run()
def start():
ps.replicate()
ps.get_datastore().start_profiling()
ps.get_state_handler().start_caching()
reactor.callWhenRunning(start)
if config.worker_daemonize:
daemon = Daemonize(
app="synapse-appservice",
pid=config.worker_pid_file,
action=run,
auto_close_fds=False,
verbose=True,
logger=logger,
)
daemon.start()
else:
run()
if __name__ == '__main__':
with LoggingContext("main"):
start(sys.argv[1:])
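All of these removed worker entry points share the same long-poll skeleton: ask the master for updates after our current stream positions, apply them to the slaved store, and back off on error. A stripped-down sketch of that loop in Twisted's inlineCallbacks style, with `http_client`, `store`, and `replication_url` standing in for the real objects:

from twisted.internet import defer, reactor
from twisted.internet.task import deferLater

@defer.inlineCallbacks
def replication_loop(http_client, store, replication_url):
    """Long-poll the master for replication updates, as the workers above do."""
    while True:
        try:
            args = store.stream_positions()   # e.g. {"events": 1234, ...}
            args["timeout"] = 30000           # long-poll for up to 30s
            result = yield http_client.get_json(replication_url, args=args)
            yield store.process_replication(result)
        except Exception:
            # On any error, wait before retrying so we don't spin.
            yield deferLater(reactor, 5, lambda: None)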

View File

@@ -1,220 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.http.site import SynapseSite
from synapse.http.server import JsonResource
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.keys import SlavedKeyStore
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.rest.client.v1.room import PublicRoomListRestServlet
from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore
from synapse.storage.engines import create_engine
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from synapse.crypto import context_factory
from synapse import events
from twisted.internet import reactor, defer
from twisted.web.resource import Resource
from daemonize import Daemonize
import sys
import logging
import gc
logger = logging.getLogger("synapse.app.client_reader")
class ClientReaderSlavedStore(
SlavedEventStore,
SlavedKeyStore,
RoomStore,
DirectoryStore,
SlavedApplicationServiceStore,
SlavedRegistrationStore,
BaseSlavedStore,
ClientIpStore, # After BaseSlavedStore because the constructor is different
):
pass
class ClientReaderServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = ClientReaderSlavedStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
elif name == "client":
resource = JsonResource(self, canonical_json=False)
PublicRoomListRestServlet(self).register(resource)
resources.update({
"/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource,
"/_matrix/client/v2_alpha": resource,
"/_matrix/client/api/v1": resource,
})
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
logger.info("Synapse client reader now listening on port %d", port)
def start_listening(self, listeners):
for listener in listeners:
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
while True:
try:
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
yield store.process_replication(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(5)
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse client reader", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.client_reader"
setup_logging(config.worker_log_config, config.worker_log_file)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config)
ss = ClientReaderServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ss.setup()
ss.get_handlers()
ss.start_listening(config.worker_listeners)
def run():
with LoggingContext("run"):
logger.info("Running")
change_resource_limit(config.soft_file_limit)
if config.gc_thresholds:
gc.set_threshold(*config.gc_thresholds)
reactor.run()
def start():
ss.get_state_handler().start_caching()
ss.get_datastore().start_profiling()
ss.replicate()
reactor.callWhenRunning(start)
if config.worker_daemonize:
daemon = Daemonize(
app="synapse-client-reader",
pid=config.worker_pid_file,
action=run,
auto_close_fds=False,
verbose=True,
logger=logger,
)
daemon.start()
else:
run()
if __name__ == '__main__':
with LoggingContext("main"):
start(sys.argv[1:])

View File

@@ -1,211 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.http.site import SynapseSite
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.keys import SlavedKeyStore
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from synapse.api.urls import FEDERATION_PREFIX
from synapse.federation.transport.server import TransportLayerServer
from synapse.crypto import context_factory
from synapse import events
from twisted.internet import reactor, defer
from twisted.web.resource import Resource
from daemonize import Daemonize
import sys
import logging
import gc
logger = logging.getLogger("synapse.app.federation_reader")
class FederationReaderSlavedStore(
SlavedEventStore,
SlavedKeyStore,
RoomStore,
DirectoryStore,
TransactionStore,
BaseSlavedStore,
):
pass
class FederationReaderServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
elif name == "federation":
resources.update({
FEDERATION_PREFIX: TransportLayerServer(self),
})
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
logger.info("Synapse federation reader now listening on port %d", port)
def start_listening(self, listeners):
for listener in listeners:
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
while True:
try:
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
yield store.process_replication(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(5)
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse federation reader", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.federation_reader"
setup_logging(config.worker_log_config, config.worker_log_file)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config)
ss = FederationReaderServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ss.setup()
ss.get_handlers()
ss.start_listening(config.worker_listeners)
def run():
with LoggingContext("run"):
logger.info("Running")
change_resource_limit(config.soft_file_limit)
if config.gc_thresholds:
gc.set_threshold(*config.gc_thresholds)
reactor.run()
def start():
ss.get_state_handler().start_caching()
ss.get_datastore().start_profiling()
ss.replicate()
reactor.callWhenRunning(start)
if config.worker_daemonize:
daemon = Daemonize(
app="synapse-federation-reader",
pid=config.worker_pid_file,
action=run,
auto_close_fds=False,
verbose=True,
logger=logger,
)
daemon.start()
else:
run()
if __name__ == '__main__':
with LoggingContext("main"):
start(sys.argv[1:])

View File

@@ -51,7 +51,6 @@ from synapse.api.urls import (
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
from synapse.util.logcontext import LoggingContext
-from synapse.metrics import register_memory_metrics, get_metrics_for
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX
from synapse.federation.transport.server import TransportLayerServer
@@ -285,7 +284,7 @@ def setup(config_options):
    # check any extra requirements we have now we have a config
    check_requirements(config)

-   version_string = "Synapse/" + get_version_string(synapse)
+   version_string = get_version_string("Synapse", synapse)

    logger.info("Server hostname: %s", config.server_name)
    logger.info("Server version: %s", version_string)
@@ -336,8 +335,6 @@ def setup(config_options):
        hs.get_datastore().start_doing_background_updates()
        hs.get_replication_layer().start_get_pdu_cache()

-       register_memory_metrics(hs)
-
        reactor.callWhenRunning(start)

    return hs
@@ -385,8 +382,6 @@ def run(hs):
    start_time = hs.get_clock().time()

-   stats = {}
-
    @defer.inlineCallbacks
    def phone_stats_home():
        logger.info("Gathering stats for reporting")
@@ -395,10 +390,7 @@ def run(hs):
        if uptime < 0:
            uptime = 0

-       # If the stats directory is empty then this is the first time we've
-       # reported stats.
-       first_time = not stats
-
+       stats = {}
        stats["homeserver"] = hs.config.server_name
        stats["timestamp"] = now
        stats["uptime_seconds"] = uptime
@@ -411,25 +403,6 @@ def run(hs):
        daily_messages = yield hs.get_datastore().count_daily_messages()
        if daily_messages is not None:
            stats["daily_messages"] = daily_messages
else:
stats.pop("daily_messages", None)
if first_time:
# Add callbacks to report the synapse stats as metrics whenever
# prometheus requests them, typically every 30s.
# As some of the stats are expensive to calculate we only update
# them when synapse phones home to matrix.org every 24 hours.
metrics = get_metrics_for("synapse.usage")
metrics.add_callback("timestamp", lambda: stats["timestamp"])
metrics.add_callback("uptime_seconds", lambda: stats["uptime_seconds"])
metrics.add_callback("total_users", lambda: stats["total_users"])
metrics.add_callback("total_room_count", lambda: stats["total_room_count"])
metrics.add_callback(
"daily_active_users", lambda: stats["daily_active_users"]
)
metrics.add_callback(
"daily_messages", lambda: stats.get("daily_messages", 0)
)
logger.info("Reporting stats to matrix.org: %s" % (stats,)) logger.info("Reporting stats to matrix.org: %s" % (stats,))
try: try:

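The removed block wires the phone-home stats into the process's Prometheus-style metrics by registering callbacks that are evaluated at scrape time, so the expensive stats are only recomputed once a day while the metrics page always reflects the latest snapshot. A generic sketch of that callback-gauge idea (the registry class here is invented, not Synapse's metrics module):

class CallbackGauge(object):
    """A gauge whose values are computed by callbacks at scrape time."""
    def __init__(self):
        self.callbacks = {}

    def add_callback(self, name, fn):
        self.callbacks[name] = fn

    def render(self):
        return {name: fn() for name, fn in self.callbacks.items()}

stats = {"total_users": 42, "daily_messages": 1000}
metrics = CallbackGauge()
# Register once; later scrapes see whatever is currently in `stats`.
metrics.add_callback("total_users", lambda: stats["total_users"])
metrics.add_callback("daily_messages", lambda: stats.get("daily_messages", 0))

stats["total_users"] = 43
print(metrics.render())  # {'total_users': 43, 'daily_messages': 1000}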
View File

@@ -1,217 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.http.site import SynapseSite
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore
from synapse.storage.engines import create_engine
from synapse.storage.media_repository import MediaRepositoryStore
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from synapse.api.urls import (
CONTENT_REPO_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX
)
from synapse.crypto import context_factory
from synapse import events
from twisted.internet import reactor, defer
from twisted.web.resource import Resource
from daemonize import Daemonize
import sys
import logging
import gc
logger = logging.getLogger("synapse.app.media_repository")
class MediaRepositorySlavedStore(
SlavedApplicationServiceStore,
SlavedRegistrationStore,
BaseSlavedStore,
MediaRepositoryStore,
ClientIpStore,
):
pass
class MediaRepositoryServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
elif name == "media":
media_repo = MediaRepositoryResource(self)
resources.update({
MEDIA_PREFIX: media_repo,
LEGACY_MEDIA_PREFIX: media_repo,
CONTENT_REPO_PREFIX: ContentRepoResource(
self, self.config.uploads_path
),
})
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
logger.info("Synapse media repository now listening on port %d", port)
def start_listening(self, listeners):
for listener in listeners:
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
while True:
try:
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
yield store.process_replication(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(5)
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse media repository", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.media_repository"
setup_logging(config.worker_log_config, config.worker_log_file)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config)
ss = MediaRepositoryServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ss.setup()
ss.get_handlers()
ss.start_listening(config.worker_listeners)
def run():
with LoggingContext("run"):
logger.info("Running")
change_resource_limit(config.soft_file_limit)
if config.gc_thresholds:
gc.set_threshold(*config.gc_thresholds)
reactor.run()
def start():
ss.get_state_handler().start_caching()
ss.get_datastore().start_profiling()
ss.replicate()
reactor.callWhenRunning(start)
if config.worker_daemonize:
daemon = Daemonize(
app="synapse-media-repository",
pid=config.worker_pid_file,
action=run,
auto_close_fds=False,
verbose=True,
logger=logger,
)
daemon.start()
else:
run()
if __name__ == '__main__':
with LoggingContext("main"):
start(sys.argv[1:])
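Each worker builds a small Twisted resource tree, mapping URL prefixes (metrics, media, client API) to resources before listening on the configured port. A bare-bones sketch of the same pattern with stock Twisted and no Synapse helpers:

from twisted.internet import reactor
from twisted.web.resource import Resource
from twisted.web.server import Site

class HelloResource(Resource):
    isLeaf = True
    def render_GET(self, request):
        return b"hello from this prefix\n"

# Root with one child per path segment; deeper prefixes nest further children.
root = Resource()
matrix = Resource()
root.putChild(b"_matrix", matrix)
matrix.putChild(b"media", HelloResource())

reactor.listenTCP(8080, Site(root))
# reactor.run()  # uncomment to actually serve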

View File

@@ -36,8 +36,6 @@ from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string

-from synapse import events
-
from twisted.internet import reactor, defer
from twisted.web.resource import Resource
@@ -82,6 +80,11 @@ class PusherSlaveStore(
        DataStore.get_profile_displayname.__func__
    )

+   # XXX: This is a bit broken because we don't persist forgotten rooms
+   # in a way that they can be streamed. This means that we don't have a
+   # way to invalidate the forgotten rooms cache correctly.
+   # For now we expire the cache every 10 minutes.
+   BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
    who_forgot_in_room = (
        RoomMemberStore.__dict__["who_forgot_in_room"]
    )
@@ -165,6 +168,7 @@ class PusherServer(HomeServer):
        store = self.get_datastore()
        replication_url = self.config.worker_replication_url
        pusher_pool = self.get_pusherpool()
+       clock = self.get_clock()

        def stop_pusher(user_id, app_id, pushkey):
            key = "%s:%s" % (app_id, pushkey)
@@ -199,7 +203,7 @@ class PusherServer(HomeServer):
                    yield start_pusher(user_id, app_id, pushkey)

            stream = results.get("events")
-           if stream and stream["rows"]:
+           if stream:
                min_stream_id = stream["rows"][0][0]
                max_stream_id = stream["position"]
                preserve_fn(pusher_pool.on_new_notifications)(
@@ -207,7 +211,7 @@ class PusherServer(HomeServer):
                )

            stream = results.get("receipts")
-           if stream and stream["rows"]:
+           if stream:
                rows = stream["rows"]
                affected_room_ids = set(row[1] for row in rows)
                min_stream_id = rows[0][0]
@@ -216,11 +220,21 @@ class PusherServer(HomeServer):
                    min_stream_id, max_stream_id, affected_room_ids
                )

+       def expire_broken_caches():
+           store.who_forgot_in_room.invalidate_all()
+
+       next_expire_broken_caches_ms = 0
        while True:
            try:
                args = store.stream_positions()
                args["timeout"] = 30000
                result = yield http_client.get_json(replication_url, args=args)
+               now_ms = clock.time_msec()
+               if now_ms > next_expire_broken_caches_ms:
+                   expire_broken_caches()
+                   next_expire_broken_caches_ms = (
+                       now_ms + store.BROKEN_CACHE_EXPIRY_MS
+                   )
                yield store.process_replication(result)
                poke_pushers(result)
            except:
@@ -241,8 +255,6 @@ def start(config_options):
    setup_logging(config.worker_log_config, config.worker_log_file)

-   events.USE_FROZEN_DICTS = config.use_frozen_dicts
-
    if config.start_pushers:
        sys.stderr.write(
            "\nThe pushers must be disabled in the main synapse process"
@@ -261,7 +273,7 @@ def start(config_options):
        config.server_name,
        db_config=config.database_config,
        config=config,
-       version_string="Synapse/" + get_version_string(synapse),
+       version_string=get_version_string("Synapse", synapse),
        database_engine=database_engine,
    )
@@ -280,7 +292,6 @@ def start(config_options):
        ps.replicate()
        ps.get_pusherpool().start()
        ps.get_datastore().start_profiling()
-       ps.get_state_handler().start_caching()

    reactor.callWhenRunning(start)
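Because the forgotten-rooms cache cannot be invalidated from the replication stream, the right-hand version simply blows the whole cache away on a timer inside the poll loop. A tiny sketch of that deadline pattern (clock values in milliseconds, as in the code above):

import time

CACHE_EXPIRY_MS = 10 * 60 * 1000  # the comment in the diff says 10 minutes

cache = {"!room:example.com": {"@bob:example.com"}}
next_expire_ms = 0

def poll_once():
    global next_expire_ms
    now_ms = int(time.time() * 1000)
    if now_ms > next_expire_ms:
        cache.clear()  # crude but safe: drop everything we can't invalidate
        next_expire_ms = now_ms + CACHE_EXPIRY_MS
    # ... process replication results here ...

poll_once()
print(cache)  # {} after the first expiry pass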

View File

@@ -26,9 +26,6 @@ from synapse.http.site import SynapseSite
from synapse.http.server import JsonResource
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.rest.client.v2_alpha import sync
-from synapse.rest.client.v1 import events
-from synapse.rest.client.v1.room import RoomInitialSyncRestServlet
-from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
@@ -38,8 +35,6 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
from synapse.replication.slave.storage.filtering import SlavedFilteringStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
from synapse.replication.slave.storage.presence import SlavedPresenceStore
-from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
-from synapse.replication.slave.storage.room import RoomStore
from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore
from synapse.storage.engines import create_engine
@@ -76,11 +71,14 @@ class SynchrotronSlavedStore(
SlavedRegistrationStore, SlavedRegistrationStore,
SlavedFilteringStore, SlavedFilteringStore,
SlavedPresenceStore, SlavedPresenceStore,
SlavedDeviceInboxStore,
RoomStore,
BaseSlavedStore, BaseSlavedStore,
ClientIpStore, # After BaseSlavedStore because the constructor is different ClientIpStore, # After BaseSlavedStore because the constructor is different
): ):
# XXX: This is a bit broken because we don't persist forgotten rooms
# in a way that they can be streamed. This means that we don't have a
# way to invalidate the forgotten rooms cache correctly.
# For now we expire the cache every 10 minutes.
BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
who_forgot_in_room = ( who_forgot_in_room = (
RoomMemberStore.__dict__["who_forgot_in_room"] RoomMemberStore.__dict__["who_forgot_in_room"]
) )
@@ -91,23 +89,17 @@ class SynchrotronSlavedStore(
get_presence_list_accepted = PresenceStore.__dict__[ get_presence_list_accepted = PresenceStore.__dict__[
"get_presence_list_accepted" "get_presence_list_accepted"
] ]
get_presence_list_observers_accepted = PresenceStore.__dict__[
"get_presence_list_observers_accepted"
]
UPDATE_SYNCING_USERS_MS = 10 * 1000 UPDATE_SYNCING_USERS_MS = 10 * 1000
class SynchrotronPresence(object): class SynchrotronPresence(object):
def __init__(self, hs): def __init__(self, hs):
self.is_mine_id = hs.is_mine_id
self.http_client = hs.get_simple_http_client() self.http_client = hs.get_simple_http_client()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.user_to_num_current_syncs = {} self.user_to_num_current_syncs = {}
self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users" self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users"
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
active_presence = self.store.take_presence_startup_info() active_presence = self.store.take_presence_startup_info()
self.user_to_current_state = { self.user_to_current_state = {
@@ -127,13 +119,11 @@ class SynchrotronPresence(object):
reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown) reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown)
def set_state(self, user, state, ignore_status_msg=False): def set_state(self, user, state):
# TODO Hows this supposed to work? # TODO Hows this supposed to work?
pass pass
get_states = PresenceHandler.get_states.__func__ get_states = PresenceHandler.get_states.__func__
get_state = PresenceHandler.get_state.__func__
_get_interested_parties = PresenceHandler._get_interested_parties.__func__
current_state_for_users = PresenceHandler.current_state_for_users.__func__ current_state_for_users = PresenceHandler.current_state_for_users.__func__
@defer.inlineCallbacks @defer.inlineCallbacks
@@ -204,39 +194,19 @@ class SynchrotronPresence(object):
self._need_to_send_sync = False self._need_to_send_sync = False
yield self._send_syncing_users_now() yield self._send_syncing_users_now()
@defer.inlineCallbacks
def notify_from_replication(self, states, stream_id):
parties = yield self._get_interested_parties(
states, calculate_remote_hosts=False
)
room_ids_to_states, users_to_states, _ = parties
self.notifier.on_new_event(
"presence_key", stream_id, rooms=room_ids_to_states.keys(),
users=users_to_states.keys()
)
@defer.inlineCallbacks
def process_replication(self, result): def process_replication(self, result):
stream = result.get("presence", {"rows": []}) stream = result.get("presence", {"rows": []})
states = []
for row in stream["rows"]: for row in stream["rows"]:
( (
position, user_id, state, last_active_ts, position, user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts, status_msg, last_federation_update_ts, last_user_sync_ts, status_msg,
currently_active currently_active
) = row ) = row
state = UserPresenceState( self.user_to_current_state[user_id] = UserPresenceState(
user_id, state, last_active_ts, user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts, status_msg, last_federation_update_ts, last_user_sync_ts, status_msg,
currently_active currently_active
) )
self.user_to_current_state[user_id] = state
states.append(state)
if states and "position" in stream:
stream_id = int(stream["position"])
yield self.notify_from_replication(states, stream_id)
class SynchrotronTyping(object): class SynchrotronTyping(object):
@@ -246,9 +216,6 @@ class SynchrotronTyping(object):
self._room_typing = {} self._room_typing = {}
def stream_positions(self): def stream_positions(self):
# We must update this typing token from the response of the previous
# sync. In particular, the stream id may "reset" back to zero/a low
# value which we *must* use for the next replication request.
return {"typing": self._latest_room_serial} return {"typing": self._latest_room_serial}
def process_replication(self, result): def process_replication(self, result):
@@ -299,14 +266,10 @@ class SynchrotronServer(HomeServer):
elif name == "client": elif name == "client":
resource = JsonResource(self, canonical_json=False) resource = JsonResource(self, canonical_json=False)
sync.register_servlets(self, resource) sync.register_servlets(self, resource)
events.register_servlets(self, resource)
InitialSyncRestServlet(self).register(resource)
RoomInitialSyncRestServlet(self).register(resource)
resources.update({ resources.update({
"/_matrix/client/r0": resource, "/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource, "/_matrix/client/unstable": resource,
"/_matrix/client/v2_alpha": resource, "/_matrix/client/v2_alpha": resource,
"/_matrix/client/api/v1": resource,
}) })
root_resource = create_resource_tree(resources, Resource()) root_resource = create_resource_tree(resources, Resource())
@@ -344,10 +307,15 @@ class SynchrotronServer(HomeServer):
http_client = self.get_simple_http_client() http_client = self.get_simple_http_client()
store = self.get_datastore() store = self.get_datastore()
replication_url = self.config.worker_replication_url replication_url = self.config.worker_replication_url
clock = self.get_clock()
notifier = self.get_notifier() notifier = self.get_notifier()
presence_handler = self.get_presence_handler() presence_handler = self.get_presence_handler()
typing_handler = self.get_typing_handler() typing_handler = self.get_typing_handler()
def expire_broken_caches():
store.who_forgot_in_room.invalidate_all()
store.get_presence_list_accepted.invalidate_all()
def notify_from_stream( def notify_from_stream(
result, stream_name, stream_key, room=None, user=None result, stream_name, stream_key, room=None, user=None
): ):
@@ -408,19 +376,23 @@ class SynchrotronServer(HomeServer):
notify_from_stream( notify_from_stream(
result, "typing", "typing_key", room="room_id" result, "typing", "typing_key", room="room_id"
) )
notify_from_stream(
result, "to_device", "to_device_key", user="user_id"
)
next_expire_broken_caches_ms = 0
while True: while True:
try: try:
args = store.stream_positions() args = store.stream_positions()
args.update(typing_handler.stream_positions()) args.update(typing_handler.stream_positions())
args["timeout"] = 30000 args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args) result = yield http_client.get_json(replication_url, args=args)
now_ms = clock.time_msec()
if now_ms > next_expire_broken_caches_ms:
expire_broken_caches()
next_expire_broken_caches_ms = (
now_ms + store.BROKEN_CACHE_EXPIRY_MS
)
yield store.process_replication(result) yield store.process_replication(result)
typing_handler.process_replication(result) typing_handler.process_replication(result)
yield presence_handler.process_replication(result) presence_handler.process_replication(result)
notify(result) notify(result)
except: except:
logger.exception("Error replicating from %r", replication_url) logger.exception("Error replicating from %r", replication_url)
@@ -446,15 +418,13 @@ def start(config_options):
setup_logging(config.worker_log_config, config.worker_log_file) setup_logging(config.worker_log_config, config.worker_log_file)
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
database_engine = create_engine(config.database_config) database_engine = create_engine(config.database_config)
ss = SynchrotronServer( ss = SynchrotronServer(
config.server_name, config.server_name,
db_config=config.database_config, db_config=config.database_config,
config=config, config=config,
version_string="Synapse/" + get_version_string(synapse), version_string=get_version_string("Synapse", synapse),
database_engine=database_engine, database_engine=database_engine,
application_service_handler=SynchrotronApplicationService(), application_service_handler=SynchrotronApplicationService(),
) )
@@ -473,7 +443,6 @@ def start(config_options):
def start(): def start():
ss.get_datastore().start_profiling() ss.get_datastore().start_profiling()
ss.replicate() ss.replicate()
ss.get_state_handler().start_caching()
reactor.callWhenRunning(start) reactor.callWhenRunning(start)
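
The BROKEN_CACHE_EXPIRY_MS machinery on the `+` side of the hunks above works around caches that cannot be invalidated over the replication stream by flushing them wholesale on a timer. A minimal self-contained sketch of that pattern (the class and the fetch_updates callable are hypothetical stand-ins, not Synapse APIs):

    import time

    BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000  # flush once an hour, as above

    class BrokenCache(object):
        """A cache with no precise invalidation; periodically drop it all."""
        def __init__(self):
            self._entries = {}

        def invalidate_all(self):
            self._entries.clear()

    def replication_loop(cache, fetch_updates):
        next_expire_ms = 0
        while True:
            now_ms = int(time.time() * 1000)
            if now_ms > next_expire_ms:
                cache.invalidate_all()
                next_expire_ms = now_ms + BROKEN_CACHE_EXPIRY_MS
            fetch_updates()  # stand-in for the replication HTTP poll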


@@ -24,7 +24,7 @@ import subprocess
 import sys
 import yaml
 
-SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"]
+SYNAPSE = ["python", "-B", "-m", "synapse.app.homeserver"]
 
 GREEN = "\x1b[1;32m"
 RED = "\x1b[1;31m"
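
The `-` side of this hunk launches the homeserver with the same interpreter that is running synctl, rather than whichever "python" happens to be first on PATH. A one-line illustration:

    import sys
    # Absolute path of the running interpreter; the exact value varies
    # by installation (virtualenv, system Python, etc.).
    print(sys.executable)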


@@ -14,8 +14,6 @@
 # limitations under the License.
 from synapse.api.constants import EventTypes
 
-from twisted.internet import defer
-
 import logging
 import re
@@ -81,7 +79,7 @@ class ApplicationService(object):
     NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
 
     def __init__(self, token, url=None, namespaces=None, hs_token=None,
-                 sender=None, id=None, protocols=None, rate_limited=True):
+                 sender=None, id=None):
         self.token = token
         self.url = url
         self.hs_token = hs_token
@@ -89,14 +87,6 @@ class ApplicationService(object):
         self.namespaces = self._check_namespaces(namespaces)
         self.id = id
 
-        # .protocols is a publicly visible field
-        if protocols:
-            self.protocols = set(protocols)
-        else:
-            self.protocols = set()
-
-        self.rate_limited = rate_limited
-
     def _check_namespaces(self, namespaces):
         # Sanity check that it is of the form:
         # {
@@ -148,66 +138,65 @@ class ApplicationService(object):
                 return regex_obj["exclusive"]
         return False
 
-    @defer.inlineCallbacks
-    def _matches_user(self, event, store):
-        if not event:
-            defer.returnValue(False)
-
-        if self.is_interested_in_user(event.sender):
-            defer.returnValue(True)
+    def _matches_user(self, event, member_list):
+        if (hasattr(event, "sender") and
+                self.is_interested_in_user(event.sender)):
+            return True
         # also check m.room.member state key
-        if (event.type == EventTypes.Member and
-                self.is_interested_in_user(event.state_key)):
-            defer.returnValue(True)
-
-        if not store:
-            defer.returnValue(False)
-
-        member_list = yield store.get_users_in_room(event.room_id)
+        if (hasattr(event, "type") and event.type == EventTypes.Member
+                and hasattr(event, "state_key")
+                and self.is_interested_in_user(event.state_key)):
+            return True
 
         # check joined member events
         for user_id in member_list:
             if self.is_interested_in_user(user_id):
-                defer.returnValue(True)
-        defer.returnValue(False)
+                return True
+        return False
 
     def _matches_room_id(self, event):
         if hasattr(event, "room_id"):
             return self.is_interested_in_room(event.room_id)
         return False
 
-    @defer.inlineCallbacks
-    def _matches_aliases(self, event, store):
-        if not store or not event:
-            defer.returnValue(False)
-
-        alias_list = yield store.get_aliases_for_room(event.room_id)
+    def _matches_aliases(self, event, alias_list):
         for alias in alias_list:
             if self.is_interested_in_alias(alias):
-                defer.returnValue(True)
-        defer.returnValue(False)
+                return True
+        return False
 
-    @defer.inlineCallbacks
-    def is_interested(self, event, store=None):
+    def is_interested(self, event, restrict_to=None, aliases_for_event=None,
+                      member_list=None):
         """Check if this service is interested in this event.
 
         Args:
             event(Event): The event to check.
-            store(DataStore)
+            restrict_to(str): The namespace to restrict regex tests to.
+            aliases_for_event(list): A list of all the known room aliases for
+                this event.
+            member_list(list): A list of all joined user_ids in this room.
         Returns:
             bool: True if this service would like to know about this event.
         """
-        # Do cheap checks first
-        if self._matches_room_id(event):
-            defer.returnValue(True)
+        if aliases_for_event is None:
+            aliases_for_event = []
+        if member_list is None:
+            member_list = []
 
-        if (yield self._matches_aliases(event, store)):
-            defer.returnValue(True)
+        if restrict_to and restrict_to not in ApplicationService.NS_LIST:
+            # this is a programming error, so fail early and raise a general
+            # exception
+            raise Exception("Unexpected restrict_to value: %s". restrict_to)
 
-        if (yield self._matches_user(event, store)):
-            defer.returnValue(True)
-
-        defer.returnValue(False)
+        if not restrict_to:
+            return (self._matches_user(event, member_list)
+                    or self._matches_aliases(event, aliases_for_event)
+                    or self._matches_room_id(event))
+        elif restrict_to == ApplicationService.NS_ALIASES:
+            return self._matches_aliases(event, aliases_for_event)
+        elif restrict_to == ApplicationService.NS_ROOMS:
+            return self._matches_room_id(event)
+        elif restrict_to == ApplicationService.NS_USERS:
+            return self._matches_user(event, member_list)
 
     def is_interested_in_user(self, user_id):
         return (
@@ -227,17 +216,11 @@ class ApplicationService(object):
             or user_id == self.sender
         )
 
-    def is_interested_in_protocol(self, protocol):
-        return protocol in self.protocols
-
     def is_exclusive_alias(self, alias):
         return self._is_exclusive(ApplicationService.NS_ALIASES, alias)
 
     def is_exclusive_room(self, room_id):
         return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
 
-    def is_rate_limited(self):
-        return self.rate_limited
-
     def __str__(self):
         return "ApplicationService: %s" % (self.__dict__,)


@@ -14,11 +14,9 @@
 # limitations under the License.
 from twisted.internet import defer
 
-from synapse.api.constants import ThirdPartyEntityKind
 from synapse.api.errors import CodeMessageException
 from synapse.http.client import SimpleHttpClient
 from synapse.events.utils import serialize_event
-from synapse.util.caches.response_cache import ResponseCache
 
 import logging
 import urllib
@@ -26,42 +24,6 @@ import urllib
 
 logger = logging.getLogger(__name__)
 
-HOUR_IN_MS = 60 * 60 * 1000
-
-APP_SERVICE_PREFIX = "/_matrix/app/unstable"
-
-
-def _is_valid_3pe_metadata(info):
-    if "instances" not in info:
-        return False
-    if not isinstance(info["instances"], list):
-        return False
-    return True
-
-
-def _is_valid_3pe_result(r, field):
-    if not isinstance(r, dict):
-        return False
-
-    for k in (field, "protocol"):
-        if k not in r:
-            return False
-        if not isinstance(r[k], str):
-            return False
-
-    if "fields" not in r:
-        return False
-    fields = r["fields"]
-    if not isinstance(fields, dict):
-        return False
-    for k in fields.keys():
-        if not isinstance(fields[k], str):
-            return False
-
-    return True
-
 
 class ApplicationServiceApi(SimpleHttpClient):
     """This class manages HS -> AS communications, including querying and
     pushing.
@@ -71,12 +33,8 @@ class ApplicationServiceApi(SimpleHttpClient):
         super(ApplicationServiceApi, self).__init__(hs)
         self.clock = hs.get_clock()
-        self.protocol_meta_cache = ResponseCache(hs, timeout_ms=HOUR_IN_MS)
 
     @defer.inlineCallbacks
     def query_user(self, service, user_id):
-        if service.url is None:
-            defer.returnValue(False)
         uri = service.url + ("/users/%s" % urllib.quote(user_id))
         response = None
         try:
@@ -96,8 +54,6 @@ class ApplicationServiceApi(SimpleHttpClient):
 
     @defer.inlineCallbacks
     def query_alias(self, service, alias):
-        if service.url is None:
-            defer.returnValue(False)
         uri = service.url + ("/rooms/%s" % urllib.quote(alias))
         response = None
         try:
@@ -115,84 +71,8 @@ class ApplicationServiceApi(SimpleHttpClient):
             logger.warning("query_alias to %s threw exception %s", uri, ex)
             defer.returnValue(False)
 
-    @defer.inlineCallbacks
-    def query_3pe(self, service, kind, protocol, fields):
-        if kind == ThirdPartyEntityKind.USER:
-            required_field = "userid"
-        elif kind == ThirdPartyEntityKind.LOCATION:
-            required_field = "alias"
-        else:
-            raise ValueError(
-                "Unrecognised 'kind' argument %r to query_3pe()", kind
-            )
-        if service.url is None:
-            defer.returnValue([])
-
-        uri = "%s%s/thirdparty/%s/%s" % (
-            service.url,
-            APP_SERVICE_PREFIX,
-            kind,
-            urllib.quote(protocol)
-        )
-        try:
-            response = yield self.get_json(uri, fields)
-            if not isinstance(response, list):
-                logger.warning(
-                    "query_3pe to %s returned an invalid response %r",
-                    uri, response
-                )
-                defer.returnValue([])
-
-            ret = []
-            for r in response:
-                if _is_valid_3pe_result(r, field=required_field):
-                    ret.append(r)
-                else:
-                    logger.warning(
-                        "query_3pe to %s returned an invalid result %r",
-                        uri, r
-                    )
-
-            defer.returnValue(ret)
-        except Exception as ex:
-            logger.warning("query_3pe to %s threw exception %s", uri, ex)
-            defer.returnValue([])
-
-    def get_3pe_protocol(self, service, protocol):
-        if service.url is None:
-            defer.returnValue({})
-
-        @defer.inlineCallbacks
-        def _get():
-            uri = "%s%s/thirdparty/protocol/%s" % (
-                service.url,
-                APP_SERVICE_PREFIX,
-                urllib.quote(protocol)
-            )
-            try:
-                info = yield self.get_json(uri, {})
-                if not _is_valid_3pe_metadata(info):
-                    logger.warning("query_3pe_protocol to %s did not return a"
-                                   " valid result", uri)
-                    defer.returnValue(None)
-
-                defer.returnValue(info)
-            except Exception as ex:
-                logger.warning("query_3pe_protocol to %s threw exception %s",
-                               uri, ex)
-                defer.returnValue(None)
-
-        key = (service.id, protocol)
-        return self.protocol_meta_cache.get(key) or (
-            self.protocol_meta_cache.set(key, _get())
-        )
-
     @defer.inlineCallbacks
     def push_bulk(self, service, events, txn_id=None):
-        if service.url is None:
-            defer.returnValue(True)
         events = self._serialize(events)
 
         if txn_id is None:
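
The deleted _is_valid_3pe_result helper is self-contained, so its behaviour is easy to demonstrate; here it is restated compactly and exercised on sample lookup results (the sample dicts are made up):

    def _is_valid_3pe_result(r, field):
        # A result must be a dict with string `field` and "protocol"
        # entries, plus a "fields" dict whose values are all strings.
        if not isinstance(r, dict):
            return False
        for k in (field, "protocol"):
            if k not in r or not isinstance(r[k], str):
                return False
        fields = r.get("fields")
        if not isinstance(fields, dict):
            return False
        return all(isinstance(v, str) for v in fields.values())

    good = {"userid": "@gw_alice:example.com", "protocol": "irc",
            "fields": {"nick": "alice"}}
    bad = {"protocol": "irc", "fields": {"nick": "alice"}}  # no "userid"

    print(_is_valid_3pe_result(good, field="userid"))  # True
    print(_is_valid_3pe_result(bad, field="userid"))   # False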


@@ -48,12 +48,9 @@ UP & quit +---------- YES SUCCESS
This is all tied together by the AppServiceScheduler which DIs the required This is all tied together by the AppServiceScheduler which DIs the required
components. components.
""" """
from twisted.internet import defer
from synapse.appservice import ApplicationServiceState from synapse.appservice import ApplicationServiceState
from synapse.util.logcontext import preserve_fn from twisted.internet import defer
from synapse.util.metrics import Measure
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -76,7 +73,7 @@ class ApplicationServiceScheduler(object):
self.txn_ctrl = _TransactionController( self.txn_ctrl = _TransactionController(
self.clock, self.store, self.as_api, create_recoverer self.clock, self.store, self.as_api, create_recoverer
) )
self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock) self.queuer = _ServiceQueuer(self.txn_ctrl)
@defer.inlineCallbacks @defer.inlineCallbacks
def start(self): def start(self):
@@ -97,36 +94,38 @@ class _ServiceQueuer(object):
this schedules any other events in the queue to run. this schedules any other events in the queue to run.
""" """
def __init__(self, txn_ctrl, clock): def __init__(self, txn_ctrl):
self.queued_events = {} # dict of {service_id: [events]} self.queued_events = {} # dict of {service_id: [events]}
self.requests_in_flight = set() self.pending_requests = {} # dict of {service_id: Deferred}
self.txn_ctrl = txn_ctrl self.txn_ctrl = txn_ctrl
self.clock = clock
def enqueue(self, service, event): def enqueue(self, service, event):
# if this service isn't being sent something # if this service isn't being sent something
self.queued_events.setdefault(service.id, []).append(event) if not self.pending_requests.get(service.id):
preserve_fn(self._send_request)(service) self._send_request(service, [event])
else:
# add to queue for this service
if service.id not in self.queued_events:
self.queued_events[service.id] = []
self.queued_events[service.id].append(event)
@defer.inlineCallbacks def _send_request(self, service, events):
def _send_request(self, service): # send request and add callbacks
if service.id in self.requests_in_flight: d = self.txn_ctrl.send(service, events)
return d.addBoth(self._on_request_finish)
d.addErrback(self._on_request_fail)
self.pending_requests[service.id] = d
self.requests_in_flight.add(service.id) def _on_request_finish(self, service):
try: self.pending_requests[service.id] = None
while True: # if there are queued events, then send them.
events = self.queued_events.pop(service.id, []) if (service.id in self.queued_events
if not events: and len(self.queued_events[service.id]) > 0):
return self._send_request(service, self.queued_events[service.id])
self.queued_events[service.id] = []
with Measure(self.clock, "servicequeuer.send"): def _on_request_fail(self, err):
try: logger.error("AS request failed: %s", err)
yield self.txn_ctrl.send(service, events)
except:
logger.exception("AS request failed")
finally:
self.requests_in_flight.discard(service.id)
class _TransactionController(object): class _TransactionController(object):
@@ -150,12 +149,14 @@ class _TransactionController(object):
if service_is_up: if service_is_up:
sent = yield txn.send(self.as_api) sent = yield txn.send(self.as_api)
if sent: if sent:
yield txn.complete(self.store) txn.complete(self.store)
else: else:
preserve_fn(self._start_recoverer)(service) self._start_recoverer(service)
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
preserve_fn(self._start_recoverer)(service) self._start_recoverer(service)
# request has finished
defer.returnValue(service)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_recovered(self, recoverer): def on_recovered(self, recoverer):
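
The `-` side's _ServiceQueuer enforces one invariant: at most one transaction in flight per application service, with events that arrive in the meantime batched into the next send. A synchronous sketch of just that invariant (send_fn is a hypothetical stand-in for txn_ctrl.send):

    class ServiceQueuer(object):
        def __init__(self, send_fn):
            self.queued_events = {}        # service_id -> [events]
            self.requests_in_flight = set()
            self.send_fn = send_fn

        def enqueue(self, service_id, event):
            self.queued_events.setdefault(service_id, []).append(event)
            self._send_request(service_id)

        def _send_request(self, service_id):
            if service_id in self.requests_in_flight:
                return  # the in-flight loop below will pick the event up
            self.requests_in_flight.add(service_id)
            try:
                while True:
                    events = self.queued_events.pop(service_id, [])
                    if not events:
                        return
                    self.send_fn(service_id, events)
            finally:
                self.requests_in_flight.discard(service_id)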


@@ -64,12 +64,11 @@ class Config(object):
         if isinstance(value, int) or isinstance(value, long):
             return value
         second = 1000
-        minute = 60 * second
-        hour = 60 * minute
+        hour = 60 * 60 * second
         day = 24 * hour
         week = 7 * day
         year = 365 * day
-        sizes = {"s": second, "m": minute, "h": hour, "d": day, "w": week, "y": year}
+        sizes = {"s": second, "h": hour, "d": day, "w": week, "y": year}
         size = 1
         suffix = value[-1]
         if suffix in sizes:
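
The `-` side adds a minutes suffix to the duration parser. A runnable sketch of the parsing logic, returning milliseconds (the Python 2 `long` branch is dropped here):

    def parse_duration(value):
        # Bare ints are taken as ms; otherwise a trailing s/m/h/d/w/y
        # suffix scales the leading number.
        if isinstance(value, int):
            return value
        second = 1000
        minute = 60 * second
        hour = 60 * minute
        day = 24 * hour
        week = 7 * day
        year = 365 * day
        sizes = {"s": second, "m": minute, "h": hour,
                 "d": day, "w": week, "y": year}
        size = 1
        suffix = value[-1]
        if suffix in sizes:
            value = value[:-1]
            size = sizes[suffix]
        return int(value) * size

    print(parse_duration("5m"))                           # 300000
    print(parse_duration("24h") == parse_duration("1d"))  # True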


@@ -28,7 +28,6 @@ class AppServiceConfig(Config):
     def read_config(self, config):
         self.app_service_config_files = config.get("app_service_config_files", [])
-        self.notify_appservices = config.get("notify_appservices", True)
 
     def default_config(cls, **kwargs):
         return """\
@@ -86,7 +85,7 @@ def load_appservices(hostname, config_files):
 
 def _load_appservice(hostname, as_info, config_filename):
     required_string_fields = [
-        "id", "as_token", "hs_token", "sender_localpart"
+        "id", "url", "as_token", "hs_token", "sender_localpart"
     ]
     for field in required_string_fields:
         if not isinstance(as_info.get(field), basestring):
@@ -94,14 +93,6 @@ def _load_appservice(hostname, as_info, config_filename):
                 field, config_filename,
             ))
 
-    # 'url' must either be a string or explicitly null, not missing
-    # to avoid accidentally turning off push for ASes.
-    if (not isinstance(as_info.get("url"), basestring) and
-            as_info.get("url", "") is not None):
-        raise KeyError(
-            "Required string field or explicit null: 'url' (%s)" % (config_filename,)
-        )
-
     localpart = as_info["sender_localpart"]
     if urllib.quote(localpart) != localpart:
         raise ValueError(
@@ -110,11 +101,6 @@ def _load_appservice(hostname, as_info, config_filename):
     user = UserID(localpart, hostname)
     user_id = user.to_string()
 
-    # Rate limiting for users of this AS is on by default (excludes sender)
-    rate_limited = True
-    if isinstance(as_info.get("rate_limited"), bool):
-        rate_limited = as_info.get("rate_limited")
-
     # namespace checks
     if not isinstance(as_info.get("namespaces"), dict):
         raise KeyError("Requires 'namespaces' object.")
@@ -136,22 +122,6 @@ def _load_appservice(hostname, as_info, config_filename):
             raise ValueError(
                 "Missing/bad type 'exclusive' key in %s", regex_obj
             )
 
-    # protocols check
-    protocols = as_info.get("protocols")
-    if protocols:
-        # Because strings are lists in python
-        if isinstance(protocols, str) or not isinstance(protocols, list):
-            raise KeyError("Optional 'protocols' must be a list if present.")
-        for p in protocols:
-            if not isinstance(p, str):
-                raise KeyError("Bad value for 'protocols' item")
-
-    if as_info["url"] is None:
-        logger.info(
-            "(%s) Explicitly empty 'url' provided. This application service"
-            " will not receive events or queries.",
-            config_filename,
-        )
-
     return ApplicationService(
         token=as_info["as_token"],
         url=as_info["url"],
@@ -159,6 +129,4 @@ def _load_appservice(hostname, as_info, config_filename):
         hs_token=as_info["hs_token"],
         sender=user_id,
         id=as_info["id"],
-        protocols=protocols,
-        rate_limited=rate_limited
     )
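
The deleted 'url' check encodes a deliberate convention: the key must be a string or an explicit null, never simply missing, so that push cannot be switched off by accident. A sketch of that rule (using str rather than Python 2's basestring):

    def check_url(as_info):
        # "url" must be a string, or explicitly None; a missing key errors.
        if (not isinstance(as_info.get("url"), str) and
                as_info.get("url", "") is not None):
            raise KeyError("Required string field or explicit null: 'url'")

    check_url({"url": "http://localhost:9000"})  # ok
    check_url({"url": None})                     # ok: push explicitly off
    # check_url({}) would raise KeyError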


@@ -30,7 +30,7 @@ from .saml2 import SAML2Config
 from .cas import CasConfig
 from .password import PasswordConfig
 from .jwt import JWTConfig
-from .password_auth_providers import PasswordAuthProviderConfig
+from .ldap import LDAPConfig
 from .emailconfig import EmailConfig
 from .workers import WorkerConfig
@@ -39,8 +39,8 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
                        RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
                        VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
                        AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
-                       JWTConfig, PasswordConfig, EmailConfig,
-                       WorkerConfig, PasswordAuthProviderConfig,):
+                       JWTConfig, LDAPConfig, PasswordConfig, EmailConfig,
+                       WorkerConfig,):
     pass

synapse/config/ldap.py (new file, 100 lines)

@@ -0,0 +1,100 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 Niklas Riekenbrauck
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import Config, ConfigError
+
+
+MISSING_LDAP3 = (
+    "Missing ldap3 library. This is required for LDAP Authentication."
+)
+
+
+class LDAPMode(object):
+    SIMPLE = "simple",
+    SEARCH = "search",
+
+    LIST = (SIMPLE, SEARCH)
+
+
+class LDAPConfig(Config):
+    def read_config(self, config):
+        ldap_config = config.get("ldap_config", {})
+
+        self.ldap_enabled = ldap_config.get("enabled", False)
+
+        if self.ldap_enabled:
+            # verify dependencies are available
+            try:
+                import ldap3
+                ldap3  # to stop unused lint
+            except ImportError:
+                raise ConfigError(MISSING_LDAP3)
+
+            self.ldap_mode = LDAPMode.SIMPLE
+
+            # verify config sanity
+            self.require_keys(ldap_config, [
+                "uri",
+                "base",
+                "attributes",
+            ])
+
+            self.ldap_uri = ldap_config["uri"]
+            self.ldap_start_tls = ldap_config.get("start_tls", False)
+            self.ldap_base = ldap_config["base"]
+            self.ldap_attributes = ldap_config["attributes"]
+
+            if "bind_dn" in ldap_config:
+                self.ldap_mode = LDAPMode.SEARCH
+                self.require_keys(ldap_config, [
+                    "bind_dn",
+                    "bind_password",
+                ])
+
+                self.ldap_bind_dn = ldap_config["bind_dn"]
+                self.ldap_bind_password = ldap_config["bind_password"]
+                self.ldap_filter = ldap_config.get("filter", None)
+
+            # verify attribute lookup
+            self.require_keys(ldap_config['attributes'], [
+                "uid",
+                "name",
+                "mail",
+            ])
+
+    def require_keys(self, config, required):
+        missing = [key for key in required if key not in config]
+        if missing:
+            raise ConfigError(
+                "LDAP enabled but missing required config values: {}".format(
+                    ", ".join(missing)
+                )
+            )
+
+    def default_config(self, **kwargs):
+        return """\
+        # ldap_config:
+        #   enabled: true
+        #   uri: "ldap://ldap.example.com:389"
+        #   start_tls: true
+        #   base: "ou=users,dc=example,dc=com"
+        #   attributes:
+        #      uid: "cn"
+        #      mail: "email"
+        #      name: "givenName"
+        #   #bind_dn:
+        #   #bind_password:
+        #   #filter: "(objectClass=posixAccount)"
+        """


@@ -1,66 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2016 Openmarket
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from ._base import Config, ConfigError
-
-import importlib
-
-
-class PasswordAuthProviderConfig(Config):
-    def read_config(self, config):
-        self.password_providers = []
-
-        # We want to be backwards compatible with the old `ldap_config`
-        # param.
-        ldap_config = config.get("ldap_config", {})
-        self.ldap_enabled = ldap_config.get("enabled", False)
-        if self.ldap_enabled:
-            from synapse.util.ldap_auth_provider import LdapAuthProvider
-            parsed_config = LdapAuthProvider.parse_config(ldap_config)
-            self.password_providers.append((LdapAuthProvider, parsed_config))
-
-        providers = config.get("password_providers", [])
-        for provider in providers:
-            # We need to import the module, and then pick the class out of
-            # that, so we split based on the last dot.
-            module, clz = provider['module'].rsplit(".", 1)
-            module = importlib.import_module(module)
-            provider_class = getattr(module, clz)
-
-            try:
-                provider_config = provider_class.parse_config(provider["config"])
-            except Exception as e:
-                raise ConfigError(
-                    "Failed to parse config for %r: %r" % (provider['module'], e)
-                )
-            self.password_providers.append((provider_class, provider_config))
-
-    def default_config(self, **kwargs):
-        return """\
-        # password_providers:
-        #     - module: "synapse.util.ldap_auth_provider.LdapAuthProvider"
-        #       config:
-        #         enabled: true
-        #         uri: "ldap://ldap.example.com:389"
-        #         start_tls: true
-        #         base: "ou=users,dc=example,dc=com"
-        #         attributes:
-        #            uid: "cn"
-        #            mail: "email"
-        #            name: "givenName"
-        #         #bind_dn:
-        #         #bind_password:
-        #         #filter: "(objectClass=posixAccount)"
-        """


@@ -167,8 +167,6 @@ class ContentRepositoryConfig(Config):
         #   - '10.0.0.0/8'
         #   - '172.16.0.0/12'
         #   - '192.168.0.0/16'
-        #   - '100.64.0.0/10'
-        #   - '169.254.0.0/16'
         #
         # List of IP address CIDR ranges that the URL preview spider is allowed
         # to access even if they are specified in url_preview_ip_range_blacklist.


@@ -29,6 +29,7 @@ class ServerConfig(Config):
         self.user_agent_suffix = config.get("user_agent_suffix")
         self.use_frozen_dicts = config.get("use_frozen_dicts", False)
         self.public_baseurl = config.get("public_baseurl")
+        self.secondary_directory_servers = config.get("secondary_directory_servers", [])
 
         if self.public_baseurl is not None:
             if self.public_baseurl[-1] != '/':
@@ -141,6 +142,14 @@ class ServerConfig(Config):
         # The GC threshold parameters to pass to `gc.set_threshold`, if defined
         # gc_thresholds: [700, 10, 10]
 
+        # A list of other Home Servers to fetch the public room directory from
+        # and include in the public room directory of this home server
+        # This is a temporary stopgap solution to populate new server with a
+        # list of rooms until there exists a good solution of a decentralized
+        # room directory.
+        # secondary_directory_servers:
+        #     - matrix.org
+
         # List of ports that Synapse should listen on, their purpose and their
         # configuration.
         listeners:


@@ -19,9 +19,6 @@ from OpenSSL import crypto
 import subprocess
 import os
 
-from hashlib import sha256
-from unpaddedbase64 import encode_base64
-
 GENERATE_DH_PARAMS = False
@@ -45,19 +42,6 @@ class TlsConfig(Config):
             config.get("tls_dh_params_path"), "tls_dh_params"
         )
 
-        self.tls_fingerprints = config["tls_fingerprints"]
-
-        # Check that our own certificate is included in the list of fingerprints
-        # and include it if it is not.
-        x509_certificate_bytes = crypto.dump_certificate(
-            crypto.FILETYPE_ASN1,
-            self.tls_certificate
-        )
-        sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest())
-        sha256_fingerprints = set(f["sha256"] for f in self.tls_fingerprints)
-        if sha256_fingerprint not in sha256_fingerprints:
-            self.tls_fingerprints.append({u"sha256": sha256_fingerprint})
-
         # This config option applies to non-federation HTTP clients
         # (e.g. for talking to recaptcha, identity servers, and such)
         # It should never be used in production, and is intended for
@@ -89,28 +73,6 @@ class TlsConfig(Config):
         # Don't bind to the https port
         no_tls: False
 
-        # List of allowed TLS fingerprints for this server to publish along
-        # with the signing keys for this server. Other matrix servers that
-        # make HTTPS requests to this server will check that the TLS
-        # certificates returned by this server match one of the fingerprints.
-        #
-        # Synapse automatically adds its the fingerprint of its own certificate
-        # to the list. So if federation traffic is handle directly by synapse
-        # then no modification to the list is required.
-        #
-        # If synapse is run behind a load balancer that handles the TLS then it
-        # will be necessary to add the fingerprints of the certificates used by
-        # the loadbalancers to this list if they are different to the one
-        # synapse is using.
-        #
-        # Homeservers are permitted to cache the list of TLS fingerprints
-        # returned in the key responses up to the "valid_until_ts" returned in
-        # key. It may be necessary to publish the fingerprints of a new
-        # certificate and wait until the "valid_until_ts" of the previous key
-        # responses have passed before deploying it.
-        tls_fingerprints: []
-        # tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}]
         """ % locals()
 
     def read_tls_certificate(self, cert_path):
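
The removed fingerprint code boils down to one computation: an unpadded-base64 SHA-256 digest of the DER-encoded certificate. A sketch of the same computation without the OpenSSL plumbing (the input bytes here are a placeholder, not a real certificate):

    from base64 import b64encode
    from hashlib import sha256

    def sha256_fingerprint(x509_certificate_bytes):
        # Equivalent of unpaddedbase64.encode_base64(sha256(der).digest())
        digest = sha256(x509_certificate_bytes).digest()
        return b64encode(digest).rstrip(b"=")

    print(sha256_fingerprint(b"not-a-real-der-certificate"))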


@@ -77,12 +77,10 @@ class SynapseKeyClientProtocol(HTTPClient):
     def __init__(self):
         self.remote_key = defer.Deferred()
         self.host = None
-        self._peer = None
 
     def connectionMade(self):
-        self._peer = self.transport.getPeer()
-        logger.debug("Connected to %s", self._peer)
+        self.host = self.transport.getHost()
+        logger.debug("Connected to %s", self.host)
         self.sendCommand(b"GET", self.path)
         if self.host:
             self.sendHeader(b"Host", self.host)
@@ -126,10 +124,7 @@ class SynapseKeyClientProtocol(HTTPClient):
             self.timer.cancel()
 
     def on_timeout(self):
-        logger.debug(
-            "Timeout waiting for response from %s: %s",
-            self.host, self._peer,
-        )
+        logger.debug("Timeout waiting for response from %s", self.host)
         self.errback(IOError("Timeout waiting for response"))
         self.transport.abortConnection()
@@ -138,5 +133,4 @@ class SynapseKeyClientFactory(Factory):
     def protocol(self):
         protocol = SynapseKeyClientProtocol()
         protocol.path = self.path
-        protocol.host = self.host
        return protocol
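
The `+` (older) connectionMade has a subtle bug that the `-` side fixes: Twisted's transport.getHost() is the local endpoint of the connection, not the remote server, so the old code both logs the wrong address and clobbers the self.host used for the HTTP Host header. A plain-socket analogue of the distinction (requires outbound network access to run):

    import socket

    # getsockname() ~ Twisted's transport.getHost() (our local address);
    # getpeername() ~ transport.getPeer() (the remote server).
    s = socket.create_connection(("example.com", 80))
    print("host (local):", s.getsockname())
    print("peer (remote):", s.getpeername())
    s.close()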


@@ -22,7 +22,6 @@ from synapse.util.logcontext import (
preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext,
preserve_fn preserve_fn
) )
from synapse.util.metrics import Measure
from twisted.internet import defer from twisted.internet import defer
@@ -45,25 +44,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
VerifyKeyRequest = namedtuple("VerifyRequest", ( KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids"))
"server_name", "key_ids", "json_object", "deferred"
))
"""
A request for a verify key to verify a JSON object.
Attributes:
server_name(str): The name of the server to verify against.
key_ids(set(str)): The set of key_ids to that could be used to verify the
JSON object
json_object(dict): The JSON object to verify.
deferred(twisted.internet.defer.Deferred):
A deferred (server_name, key_id, verify_key) tuple that resolves when
a verify key has been fetched
"""
class KeyLookupError(ValueError):
pass
class Keyring(object): class Keyring(object):
@@ -93,32 +74,39 @@ class Keyring(object):
list of deferreds indicating success or failure to verify each list of deferreds indicating success or failure to verify each
json object's signature for the given server_name. json object's signature for the given server_name.
""" """
verify_requests = [] group_id_to_json = {}
group_id_to_group = {}
group_ids = []
next_group_id = 0
deferreds = {}
for server_name, json_object in server_and_json: for server_name, json_object in server_and_json:
logger.debug("Verifying for %s", server_name) logger.debug("Verifying for %s", server_name)
group_id = next_group_id
next_group_id += 1
group_ids.append(group_id)
key_ids = signature_ids(json_object, server_name) key_ids = signature_ids(json_object, server_name)
if not key_ids: if not key_ids:
deferred = defer.fail(SynapseError( deferreds[group_id] = defer.fail(SynapseError(
400, 400,
"Not signed with a supported algorithm", "Not signed with a supported algorithm",
Codes.UNAUTHORIZED, Codes.UNAUTHORIZED,
)) ))
else: else:
deferred = defer.Deferred() deferreds[group_id] = defer.Deferred()
verify_request = VerifyKeyRequest( group = KeyGroup(server_name, group_id, key_ids)
server_name, key_ids, json_object, deferred
)
verify_requests.append(verify_request) group_id_to_group[group_id] = group
group_id_to_json[group_id] = json_object
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_key_deferred(verify_request): def handle_key_deferred(group, deferred):
server_name = verify_request.server_name server_name = group.server_name
try: try:
_, key_id, verify_key = yield verify_request.deferred _, _, key_id, verify_key = yield deferred
except IOError as e: except IOError as e:
logger.warn( logger.warn(
"Got IOError when downloading keys for %s: %s %s", "Got IOError when downloading keys for %s: %s %s",
@@ -140,7 +128,7 @@ class Keyring(object):
Codes.UNAUTHORIZED, Codes.UNAUTHORIZED,
) )
json_object = verify_request.json_object json_object = group_id_to_json[group.group_id]
try: try:
verify_signed_json(json_object, server_name, verify_key) verify_signed_json(json_object, server_name, verify_key)
@@ -169,34 +157,36 @@ class Keyring(object):
# Actually start fetching keys. # Actually start fetching keys.
wait_on_deferred.addBoth( wait_on_deferred.addBoth(
lambda _: self.get_server_verify_keys(verify_requests) lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
) )
# When we've finished fetching all the keys for a given server_name, # When we've finished fetching all the keys for a given server_name,
# resolve the deferred passed to `wait_for_previous_lookups` so that # resolve the deferred passed to `wait_for_previous_lookups` so that
# any lookups waiting will proceed. # any lookups waiting will proceed.
server_to_request_ids = {} server_to_gids = {}
def remove_deferreds(res, server_name, verify_request): def remove_deferreds(res, server_name, group_id):
request_id = id(verify_request) server_to_gids[server_name].discard(group_id)
server_to_request_ids[server_name].discard(request_id) if not server_to_gids[server_name]:
if not server_to_request_ids[server_name]:
d = server_to_deferred.pop(server_name, None) d = server_to_deferred.pop(server_name, None)
if d: if d:
d.callback(None) d.callback(None)
return res return res
for verify_request in verify_requests: for g_id, deferred in deferreds.items():
server_name = verify_request.server_name server_name = group_id_to_group[g_id].server_name
request_id = id(verify_request) server_to_gids.setdefault(server_name, set()).add(g_id)
server_to_request_ids.setdefault(server_name, set()).add(request_id) deferred.addBoth(remove_deferreds, server_name, g_id)
deferred.addBoth(remove_deferreds, server_name, verify_request)
# Pass those keys to handle_key_deferred so that the json object # Pass those keys to handle_key_deferred so that the json object
# signatures can be verified # signatures can be verified
return [ return [
preserve_context_over_fn(handle_key_deferred, verify_request) preserve_context_over_fn(
for verify_request in verify_requests handle_key_deferred,
group_id_to_group[g_id],
deferreds[g_id],
)
for g_id in group_ids
] ]
@defer.inlineCallbacks @defer.inlineCallbacks
@@ -230,7 +220,7 @@ class Keyring(object):
d.addBoth(rm, server_name) d.addBoth(rm, server_name)
def get_server_verify_keys(self, verify_requests): def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
"""Takes a dict of KeyGroups and tries to find at least one key for """Takes a dict of KeyGroups and tries to find at least one key for
each group. each group.
""" """
@@ -244,79 +234,76 @@ class Keyring(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def do_iterations(): def do_iterations():
with Measure(self.clock, "get_server_verify_keys"): merged_results = {}
merged_results = {}
missing_keys = {} missing_keys = {}
for verify_request in verify_requests: for group in group_id_to_group.values():
missing_keys.setdefault(verify_request.server_name, set()).update( missing_keys.setdefault(group.server_name, set()).update(
verify_request.key_ids group.key_ids
)
for fn in key_fetch_fns:
results = yield fn(missing_keys.items())
merged_results.update(results)
# We now need to figure out which groups we have keys for
# and which we don't
missing_groups = {}
for group in group_id_to_group.values():
for key_id in group.key_ids:
if key_id in merged_results[group.server_name]:
with PreserveLoggingContext():
group_id_to_deferred[group.group_id].callback((
group.group_id,
group.server_name,
key_id,
merged_results[group.server_name][key_id],
))
break
else:
missing_groups.setdefault(
group.server_name, []
).append(group)
if not missing_groups:
break
missing_keys = {
server_name: set(
key_id for group in groups for key_id in group.key_ids
) )
for server_name, groups in missing_groups.items()
}
for fn in key_fetch_fns: for group in missing_groups.values():
results = yield fn(missing_keys.items()) group_id_to_deferred[group.group_id].errback(SynapseError(
merged_results.update(results) 401,
"No key for %s with id %s" % (
# We now need to figure out which verify requests we have keys group.server_name, group.key_ids,
# for and which we don't ),
missing_keys = {} Codes.UNAUTHORIZED,
requests_missing_keys = [] ))
for verify_request in verify_requests:
server_name = verify_request.server_name
result_keys = merged_results[server_name]
if verify_request.deferred.called:
# We've already called this deferred, which probably
# means that we've already found a key for it.
continue
for key_id in verify_request.key_ids:
if key_id in result_keys:
with PreserveLoggingContext():
verify_request.deferred.callback((
server_name,
key_id,
result_keys[key_id],
))
break
else:
# The else block is only reached if the loop above
# doesn't break.
missing_keys.setdefault(server_name, set()).update(
verify_request.key_ids
)
requests_missing_keys.append(verify_request)
if not missing_keys:
break
for verify_request in requests_missing_keys.values():
verify_request.deferred.errback(SynapseError(
401,
"No key for %s with id %s" % (
verify_request.server_name, verify_request.key_ids,
),
Codes.UNAUTHORIZED,
))
def on_err(err): def on_err(err):
for verify_request in verify_requests: for deferred in group_id_to_deferred.values():
if not verify_request.deferred.called: if not deferred.called:
verify_request.deferred.errback(err) deferred.errback(err)
do_iterations().addErrback(on_err) do_iterations().addErrback(on_err)
return group_id_to_deferred
@defer.inlineCallbacks @defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids): def get_keys_from_store(self, server_name_and_key_ids):
res = yield preserve_context_over_deferred(defer.gatherResults( res = yield defer.gatherResults(
[ [
preserve_fn(self.store.get_server_verify_keys)( self.store.get_server_verify_keys(
server_name, key_ids server_name, key_ids
).addCallback(lambda ks, server: (server, ks), server_name) ).addCallback(lambda ks, server: (server, ks), server_name)
for server_name, key_ids in server_name_and_key_ids for server_name, key_ids in server_name_and_key_ids
], ],
consumeErrors=True, consumeErrors=True,
)).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
defer.returnValue(dict(res)) defer.returnValue(dict(res))
@@ -337,13 +324,13 @@ class Keyring(object):
) )
defer.returnValue({}) defer.returnValue({})
results = yield preserve_context_over_deferred(defer.gatherResults( results = yield defer.gatherResults(
[ [
preserve_fn(get_key)(p_name, p_keys) get_key(p_name, p_keys)
for p_name, p_keys in self.perspective_servers.items() for p_name, p_keys in self.perspective_servers.items()
], ],
consumeErrors=True, consumeErrors=True,
)).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
union_of_keys = {} union_of_keys = {}
for result in results: for result in results:
@@ -369,7 +356,7 @@ class Keyring(object):
) )
except Exception as e: except Exception as e:
logger.info( logger.info(
"Unable to get key %r for %r directly: %s %s", "Unable to getting key %r for %r directly: %s %s",
key_ids, server_name, key_ids, server_name,
type(e).__name__, str(e.message), type(e).__name__, str(e.message),
) )
@@ -383,13 +370,13 @@ class Keyring(object):
defer.returnValue(keys) defer.returnValue(keys)
results = yield preserve_context_over_deferred(defer.gatherResults( results = yield defer.gatherResults(
[ [
preserve_fn(get_key)(server_name, key_ids) get_key(server_name, key_ids)
for server_name, key_ids in server_name_and_key_ids for server_name, key_ids in server_name_and_key_ids
], ],
consumeErrors=True, consumeErrors=True,
)).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
merged = {} merged = {}
for result in results: for result in results:
@@ -431,7 +418,7 @@ class Keyring(object):
for response in responses: for response in responses:
if (u"signatures" not in response if (u"signatures" not in response
or perspective_name not in response[u"signatures"]): or perspective_name not in response[u"signatures"]):
raise KeyLookupError( raise ValueError(
"Key response not signed by perspective server" "Key response not signed by perspective server"
" %r" % (perspective_name,) " %r" % (perspective_name,)
) )
@@ -454,21 +441,21 @@ class Keyring(object):
list(response[u"signatures"][perspective_name]), list(response[u"signatures"][perspective_name]),
list(perspective_keys) list(perspective_keys)
) )
raise KeyLookupError( raise ValueError(
"Response not signed with a known key for perspective" "Response not signed with a known key for perspective"
" server %r" % (perspective_name,) " server %r" % (perspective_name,)
) )
processed_response = yield self.process_v2_response( processed_response = yield self.process_v2_response(
perspective_name, response, only_from_server=False perspective_name, response
) )
for server_name, response_keys in processed_response.items(): for server_name, response_keys in processed_response.items():
keys.setdefault(server_name, {}).update(response_keys) keys.setdefault(server_name, {}).update(response_keys)
yield preserve_context_over_deferred(defer.gatherResults( yield defer.gatherResults(
[ [
preserve_fn(self.store_keys)( self.store_keys(
server_name=server_name, server_name=server_name,
from_server=perspective_name, from_server=perspective_name,
verify_keys=response_keys, verify_keys=response_keys,
@@ -476,7 +463,7 @@ class Keyring(object):
for server_name, response_keys in keys.items() for server_name, response_keys in keys.items()
], ],
consumeErrors=True consumeErrors=True
)).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
defer.returnValue(keys) defer.returnValue(keys)
@@ -497,10 +484,10 @@ class Keyring(object):
if (u"signatures" not in response if (u"signatures" not in response
or server_name not in response[u"signatures"]): or server_name not in response[u"signatures"]):
raise KeyLookupError("Key response not signed by remote server") raise ValueError("Key response not signed by remote server")
if "tls_fingerprints" not in response: if "tls_fingerprints" not in response:
raise KeyLookupError("Key response missing TLS fingerprints") raise ValueError("Key response missing TLS fingerprints")
certificate_bytes = crypto.dump_certificate( certificate_bytes = crypto.dump_certificate(
crypto.FILETYPE_ASN1, tls_certificate crypto.FILETYPE_ASN1, tls_certificate
@@ -514,7 +501,7 @@ class Keyring(object):
response_sha256_fingerprints.add(fingerprint[u"sha256"]) response_sha256_fingerprints.add(fingerprint[u"sha256"])
if sha256_fingerprint_b64 not in response_sha256_fingerprints: if sha256_fingerprint_b64 not in response_sha256_fingerprints:
raise KeyLookupError("TLS certificate not allowed by fingerprints") raise ValueError("TLS certificate not allowed by fingerprints")
response_keys = yield self.process_v2_response( response_keys = yield self.process_v2_response(
from_server=server_name, from_server=server_name,
@@ -524,7 +511,7 @@ class Keyring(object):
keys.update(response_keys) keys.update(response_keys)
yield preserve_context_over_deferred(defer.gatherResults( yield defer.gatherResults(
[ [
preserve_fn(self.store_keys)( preserve_fn(self.store_keys)(
server_name=key_server_name, server_name=key_server_name,
@@ -534,13 +521,13 @@ class Keyring(object):
                 for key_server_name, verify_keys in keys.items()
             ],
             consumeErrors=True
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError)
         defer.returnValue(keys)
     @defer.inlineCallbacks
     def process_v2_response(self, from_server, response_json,
-                            requested_ids=[], only_from_server=True):
+                            requested_ids=[]):
         time_now_ms = self.clock.time_msec()
         response_keys = {}
         verify_keys = {}
@@ -564,16 +551,9 @@ class Keyring(object):
         results = {}
         server_name = response_json["server_name"]
-        if only_from_server:
-            if server_name != from_server:
-                raise KeyLookupError(
-                    "Expected a response for server %r not %r" % (
-                        from_server, server_name
-                    )
-                )
         for key_id in response_json["signatures"].get(server_name, {}):
             if key_id not in response_json["verify_keys"]:
-                raise KeyLookupError(
+                raise ValueError(
                     "Key response must include verification keys for all"
                     " signatures"
                 )
@@ -600,7 +580,7 @@ class Keyring(object):
         response_keys.update(verify_keys)
         response_keys.update(old_verify_keys)
-        yield preserve_context_over_deferred(defer.gatherResults(
+        yield defer.gatherResults(
             [
                 preserve_fn(self.store.store_server_keys_json)(
                     server_name=server_name,
@@ -613,7 +593,7 @@ class Keyring(object):
                 for key_id in updated_key_ids
             ],
             consumeErrors=True,
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError)
         results[server_name] = response_keys
@@ -641,15 +621,15 @@ class Keyring(object):
         if ("signatures" not in response
                 or server_name not in response["signatures"]):
-            raise KeyLookupError("Key response not signed by remote server")
+            raise ValueError("Key response not signed by remote server")
         if "tls_certificate" not in response:
-            raise KeyLookupError("Key response missing TLS certificate")
+            raise ValueError("Key response missing TLS certificate")
         tls_certificate_b64 = response["tls_certificate"]
         if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
-            raise KeyLookupError("TLS certificate doesn't match")
+            raise ValueError("TLS certificate doesn't match")
         # Cache the result in the datastore.
@@ -665,7 +645,7 @@ class Keyring(object):
         for key_id in response["signatures"][server_name]:
             if key_id not in response["verify_keys"]:
-                raise KeyLookupError(
+                raise ValueError(
                     "Key response must include verification keys for all"
                     " signatures"
                 )
@@ -702,7 +682,7 @@ class Keyring(object):
             A deferred that completes when the keys are stored.
         """
         # TODO(markjh): Store whether the keys have expired.
-        yield preserve_context_over_deferred(defer.gatherResults(
+        yield defer.gatherResults(
             [
                 preserve_fn(self.store.store_server_verify_key)(
                     server_name, server_name, key.time_added, key
@@ -710,4 +690,4 @@ class Keyring(object):
                 for key_id, key in verify_keys.items()
             ],
             consumeErrors=True,
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError)
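The recurring edit in this file swaps the logcontext-preserving wrapper for a bare defer.gatherResults. A minimal sketch of the wrapped fan-out pattern on the "-" side, assuming a hypothetical fetch_key callable in place of the real store calls:

from twisted.internet import defer

from synapse.util import unwrapFirstError
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred


@defer.inlineCallbacks
def fetch_all_keys(server_names, fetch_key):
    # preserve_fn keeps the caller's logging context from leaking into
    # each child call; preserve_context_over_deferred restores it when
    # the gathered deferred finally fires.
    results = yield preserve_context_over_deferred(defer.gatherResults(
        [preserve_fn(fetch_key)(name) for name in server_names],
        consumeErrors=True,
    )).addErrback(unwrapFirstError)
    defer.returnValue(dict(zip(server_names, results)))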

synapse/events/__init__.py View File

@@ -99,7 +99,7 @@ class EventBase(object):
         return d
-    def get(self, key, default=None):
+    def get(self, key, default):
         return self._event_dict.get(key, default)
     def get_internal_metadata_dict(self):
synapse/events/snapshot.py View File

@@ -15,30 +15,9 @@
 class EventContext(object):
-    __slots__ = [
-        "current_state_ids",
-        "prev_state_ids",
-        "state_group",
-        "rejected",
-        "push_actions",
-        "prev_group",
-        "delta_ids",
-        "prev_state_events",
-    ]
-    def __init__(self):
-        # The current state including the current event
-        self.current_state_ids = None
-        # The current state excluding the current event
-        self.prev_state_ids = None
+    def __init__(self, current_state=None):
+        self.current_state = current_state
         self.state_group = None
         self.rejected = False
         self.push_actions = []
-        # A previously persisted state group and a delta between that
-        # and this state.
-        self.prev_group = None
-        self.delta_ids = None
-        self.prev_state_events = None
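The "-" side gives EventContext a __slots__ declaration. A busy homeserver keeps many contexts alive at once, so dropping the per-instance __dict__ saves real memory. A small illustration with made-up class names, not code from the diff:

import sys


class ContextWithDict(object):
    def __init__(self):
        self.state_group = None
        self.rejected = False


class ContextWithSlots(object):
    __slots__ = ["state_group", "rejected"]

    def __init__(self):
        self.state_group = None
        self.rejected = False


# Slotted instances carry no __dict__, so each one is smaller, and a
# typo such as ctx.rejcted = True raises AttributeError instead of
# silently creating a new attribute.
assert not hasattr(ContextWithSlots(), "__dict__")
print(sys.getsizeof(ContextWithDict().__dict__))  # per-instance dict cost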

synapse/events/utils.py View File

@@ -88,8 +88,6 @@ def prune_event(event):
     if "age_ts" in event.unsigned:
         allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"]
-    if "replaces_state" in event.unsigned:
-        allowed_fields["unsigned"]["replaces_state"] = event.unsigned["replaces_state"]
     return type(event)(
         allowed_fields,
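prune_event implements redaction by copying only whitelisted fields into a fresh event. A rough sketch of the idea over plain dicts; the field list here is abbreviated, not synapse's full per-event-type whitelist:

def prune_event_dict(event_dict):
    """Return a redacted copy keeping only protocol-critical fields.

    Sketch only: the real whitelist also varies by event type.
    """
    keep = ["event_id", "sender", "room_id", "type", "state_key", "depth",
            "prev_events", "auth_events", "origin", "origin_server_ts"]
    allowed = {k: event_dict[k] for k in keep if k in event_dict}
    allowed["content"] = {}
    unsigned = event_dict.get("unsigned", {})
    allowed["unsigned"] = {
        k: unsigned[k] for k in ("age_ts",) if k in unsigned
    }
    return allowed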

synapse/federation/federation_base.py View File

@@ -23,7 +23,6 @@ from synapse.crypto.event_signing import check_event_content_hash
 from synapse.api.errors import SynapseError
 from synapse.util import unwrapFirstError
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
 import logging
@@ -103,10 +102,10 @@ class FederationBase(object):
                 warn, pdu
             )
-        valid_pdus = yield preserve_context_over_deferred(defer.gatherResults(
+        valid_pdus = yield defer.gatherResults(
             deferreds,
             consumeErrors=True
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError)
         if include_none:
             defer.returnValue(valid_pdus)
@@ -130,7 +129,7 @@ class FederationBase(object):
             for pdu in pdus
         ]
-        deferreds = preserve_fn(self.keyring.verify_json_objects_for_server)([
+        deferreds = self.keyring.verify_json_objects_for_server([
             (p.origin, p.get_pdu_json())
             for p in redacted_pdus
         ])

synapse/federation/federation_client.py View File

@@ -24,11 +24,10 @@ from synapse.api.errors import (
     CodeMessageException, HttpResponseException, SynapseError,
 )
 from synapse.util import unwrapFirstError
+from synapse.util.async import concurrently_execute
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.logutils import log_function
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
 from synapse.events import FrozenEvent
-from synapse.types import get_domain_from_id
 import synapse.metrics
 from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
@@ -52,35 +51,10 @@ sent_edus_counter = metrics.register_counter("sent_edus")
 sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"])
-PDU_RETRY_TIME_MS = 1 * 60 * 1000
 class FederationClient(FederationBase):
     def __init__(self, hs):
         super(FederationClient, self).__init__(hs)
-        self.pdu_destination_tried = {}
-        self._clock.looping_call(
-            self._clear_tried_cache, 60 * 1000,
-        )
-        self.state = hs.get_state_handler()
-    def _clear_tried_cache(self):
-        """Clear pdu_destination_tried cache"""
-        now = self._clock.time_msec()
-        old_dict = self.pdu_destination_tried
-        self.pdu_destination_tried = {}
-        for event_id, destination_dict in old_dict.items():
-            destination_dict = {
-                dest: time
-                for dest, time in destination_dict.items()
-                if time + PDU_RETRY_TIME_MS > now
-            }
-            if destination_dict:
-                self.pdu_destination_tried[event_id] = destination_dict
     def start_get_pdu_cache(self):
         self._get_pdu_cache = ExpiringCache(
             cache_name="get_pdu_cache",
@@ -121,12 +95,8 @@ class FederationClient(FederationBase):
             pdu.event_id
         )
-    def send_presence(self, destination, states):
-        if destination != self.server_name:
-            self._transaction_queue.enqueue_presence(destination, states)
     @log_function
-    def send_edu(self, destination, edu_type, content, key=None):
+    def send_edu(self, destination, edu_type, content):
         edu = Edu(
             origin=self.server_name,
             destination=destination,
@@ -136,13 +106,9 @@ class FederationClient(FederationBase):
         sent_edus_counter.inc()
-        self._transaction_queue.enqueue_edu(edu, key=key)
-    @log_function
-    def send_device_messages(self, destination):
-        """Sends the device messages in the local database to the remote
-        destination"""
-        self._transaction_queue.enqueue_device_messages(destination)
+        # TODO, add errback, etc.
+        self._transaction_queue.enqueue_edu(edu)
+        return defer.succeed(None)
     @log_function
     def send_failure(self, failure, destination):
@@ -173,7 +139,7 @@ class FederationClient(FederationBase):
         )
     @log_function
-    def query_client_keys(self, destination, content, timeout):
+    def query_client_keys(self, destination, content):
         """Query device keys for a device hosted on a remote server.
         Args:
@@ -185,12 +151,10 @@ class FederationClient(FederationBase):
             response
         """
         sent_queries_counter.inc("client_device_keys")
-        return self.transport_layer.query_client_keys(
-            destination, content, timeout
-        )
+        return self.transport_layer.query_client_keys(destination, content)
     @log_function
-    def claim_client_keys(self, destination, content, timeout):
+    def claim_client_keys(self, destination, content):
         """Claims one-time keys for a device hosted on a remote server.
         Args:
@@ -202,9 +166,7 @@ class FederationClient(FederationBase):
             response
         """
         sent_queries_counter.inc("client_one_time_keys")
-        return self.transport_layer.claim_client_keys(
-            destination, content, timeout
-        )
+        return self.transport_layer.claim_client_keys(destination, content)
     @defer.inlineCallbacks
     @log_function
@@ -239,10 +201,10 @@ class FederationClient(FederationBase):
         ]
         # FIXME: We should handle signature failures more gracefully.
-        pdus[:] = yield preserve_context_over_deferred(defer.gatherResults(
+        pdus[:] = yield defer.gatherResults(
             self._check_sigs_and_hashes(pdus),
             consumeErrors=True,
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError)
         defer.returnValue(pdus)
@@ -274,19 +236,12 @@ class FederationClient(FederationBase):
         # TODO: Rate limit the number of times we try and get the same event.
         if self._get_pdu_cache:
-            ev = self._get_pdu_cache.get(event_id)
-            if ev:
-                defer.returnValue(ev)
+            e = self._get_pdu_cache.get(event_id)
+            if e:
+                defer.returnValue(e)
-        pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
-        signed_pdu = None
+        pdu = None
         for destination in destinations:
-            now = self._clock.time_msec()
-            last_attempt = pdu_attempts.get(destination, 0)
-            if last_attempt + PDU_RETRY_TIME_MS > now:
-                continue
             try:
                 limiter = yield get_retry_limiter(
                     destination,
@@ -310,33 +265,39 @@ class FederationClient(FederationBase):
                 pdu = pdu_list[0]
                 # Check signatures are correct.
-                signed_pdu = yield self._check_sigs_and_hashes([pdu])[0]
+                pdu = yield self._check_sigs_and_hashes([pdu])[0]
                 break
-                pdu_attempts[destination] = now
-            except SynapseError as e:
+            except SynapseError:
                 logger.info(
                     "Failed to get PDU %s from %s because %s",
                     event_id, destination, e,
                 )
-                continue
-            except CodeMessageException as e:
-                if 400 <= e.code < 500:
-                    raise
-                logger.info(
-                    "Failed to get PDU %s from %s because %s",
-                    event_id, destination, e,
-                )
-                continue
             except NotRetryingDestination as e:
                 logger.info(e.message)
                 continue
             except Exception as e:
-                pdu_attempts[destination] = now
                 logger.info(
                     "Failed to get PDU %s from %s because %s",
                     event_id, destination, e,
                 )
                 continue
-        if self._get_pdu_cache is not None and signed_pdu:
-            self._get_pdu_cache[event_id] = signed_pdu
-        defer.returnValue(signed_pdu)
+        if self._get_pdu_cache is not None and pdu:
+            self._get_pdu_cache[event_id] = pdu
+        defer.returnValue(pdu)
     @defer.inlineCallbacks
     @log_function
@@ -353,42 +314,6 @@ class FederationClient(FederationBase):
             Deferred: Results in a list of PDUs.
         """
-        try:
-            # First we try and ask for just the IDs, as thats far quicker if
-            # we have most of the state and auth_chain already.
-            # However, this may 404 if the other side has an old synapse.
-            result = yield self.transport_layer.get_room_state_ids(
-                destination, room_id, event_id=event_id,
-            )
-            state_event_ids = result["pdu_ids"]
-            auth_event_ids = result.get("auth_chain_ids", [])
-            fetched_events, failed_to_fetch = yield self.get_events(
-                [destination], room_id, set(state_event_ids + auth_event_ids)
-            )
-            if failed_to_fetch:
-                logger.warn("Failed to get %r", failed_to_fetch)
-            event_map = {
-                ev.event_id: ev for ev in fetched_events
-            }
-            pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
-            auth_chain = [
-                event_map[e_id] for e_id in auth_event_ids if e_id in event_map
-            ]
-            auth_chain.sort(key=lambda e: e.depth)
-            defer.returnValue((pdus, auth_chain))
-        except HttpResponseException as e:
-            if e.code == 400 or e.code == 404:
-                logger.info("Failed to use get_room_state_ids API, falling back")
-            else:
-                raise e
         result = yield self.transport_layer.get_room_state(
             destination, room_id, event_id=event_id,
         )
@@ -402,95 +327,18 @@ class FederationClient(FederationBase):
             for p in result.get("auth_chain", [])
         ]
-        seen_events = yield self.store.get_events([
-            ev.event_id for ev in itertools.chain(pdus, auth_chain)
-        ])
         signed_pdus = yield self._check_sigs_and_hash_and_fetch(
-            destination,
-            [p for p in pdus if p.event_id not in seen_events],
-            outlier=True
-        )
-        signed_pdus.extend(
-            seen_events[p.event_id] for p in pdus if p.event_id in seen_events
+            destination, pdus, outlier=True
         )
         signed_auth = yield self._check_sigs_and_hash_and_fetch(
-            destination,
-            [p for p in auth_chain if p.event_id not in seen_events],
-            outlier=True
-        )
-        signed_auth.extend(
-            seen_events[p.event_id] for p in auth_chain if p.event_id in seen_events
+            destination, auth_chain, outlier=True
         )
         signed_auth.sort(key=lambda e: e.depth)
         defer.returnValue((signed_pdus, signed_auth))
-    @defer.inlineCallbacks
-    def get_events(self, destinations, room_id, event_ids, return_local=True):
-        """Fetch events from some remote destinations, checking if we already
-        have them.
-        Args:
-            destinations (list)
-            room_id (str)
-            event_ids (list)
-            return_local (bool): Whether to include events we already have in
-                the DB in the returned list of events
-        Returns:
-            Deferred: A deferred resolving to a 2-tuple where the first is a list of
-            events and the second is a list of event ids that we failed to fetch.
-        """
-        if return_local:
-            seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
-            signed_events = seen_events.values()
-        else:
-            seen_events = yield self.store.have_events(event_ids)
-            signed_events = []
-        failed_to_fetch = set()
-        missing_events = set(event_ids)
-        for k in seen_events:
-            missing_events.discard(k)
-        if not missing_events:
-            defer.returnValue((signed_events, failed_to_fetch))
-        def random_server_list():
-            srvs = list(destinations)
-            random.shuffle(srvs)
-            return srvs
-        batch_size = 20
-        missing_events = list(missing_events)
-        for i in xrange(0, len(missing_events), batch_size):
-            batch = set(missing_events[i:i + batch_size])
-            deferreds = [
-                preserve_fn(self.get_pdu)(
-                    destinations=random_server_list(),
-                    event_id=e_id,
-                )
-                for e_id in batch
-            ]
-            res = yield preserve_context_over_deferred(
-                defer.DeferredList(deferreds, consumeErrors=True)
-            )
-            for success, result in res:
-                if success and result:
-                    signed_events.append(result)
-                    batch.discard(result.event_id)
-            # We removed all events we successfully fetched from `batch`
-            failed_to_fetch.update(batch)
-        defer.returnValue((signed_events, failed_to_fetch))
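The removed get_events batches missing IDs into groups of 20 and fans each batch out over a shuffled server list with DeferredList, so one failed fetch cannot abort the rest. A hedged usage sketch; the caller shown here is hypothetical, not from the diff:

import logging

from twisted.internet import defer

logger = logging.getLogger(__name__)


@defer.inlineCallbacks
def fetch_missing(client, destinations, room_id, event_ids):
    # get_events returns (events, failed_ids) instead of raising on the
    # first failure, so callers can decide how to handle partial results.
    events, failed = yield client.get_events(
        destinations, room_id, event_ids, return_local=False,
    )
    if failed:
        logger.warn("Could not fetch %d events: %r", len(failed), failed)
    defer.returnValue(events)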
     @defer.inlineCallbacks
     @log_function
     def get_event_auth(self, destination, room_id, event_id):
@@ -566,19 +414,14 @@ class FederationClient(FederationBase):
                     (destination, self.event_from_pdu_json(pdu_dict))
                 )
                 break
-            except CodeMessageException as e:
-                if not 500 <= e.code < 600:
-                    raise
-                else:
-                    logger.warn(
-                        "Failed to make_%s via %s: %s",
-                        membership, destination, e.message
-                    )
+            except CodeMessageException:
+                raise
             except Exception as e:
                 logger.warn(
                     "Failed to make_%s via %s: %s",
                     membership, destination, e.message
                 )
+                raise
         raise RuntimeError("Failed to send to any server.")
@@ -650,14 +493,8 @@ class FederationClient(FederationBase):
                     "auth_chain": signed_auth,
                     "origin": destination,
                 })
-            except CodeMessageException as e:
-                if not 500 <= e.code < 600:
-                    raise
-                else:
-                    logger.exception(
-                        "Failed to send_join via %s: %s",
-                        destination, e.message
-                    )
+            except CodeMessageException:
+                raise
             except Exception as e:
                 logger.exception(
                     "Failed to send_join via %s: %s",
@@ -716,14 +553,24 @@ class FederationClient(FederationBase):
         raise RuntimeError("Failed to send to any server.")
-    def get_public_rooms(self, destination, limit=None, since_token=None,
-                         search_filter=None):
-        if destination == self.server_name:
-            return
-        return self.transport_layer.get_public_rooms(
-            destination, limit, since_token, search_filter
-        )
+    @defer.inlineCallbacks
+    def get_public_rooms(self, destinations):
+        results_by_server = {}
+        @defer.inlineCallbacks
+        def _get_result(s):
+            if s == self.server_name:
+                defer.returnValue()
+            try:
+                result = yield self.transport_layer.get_public_rooms(s)
+                results_by_server[s] = result
+            except:
+                logger.exception("Error getting room list from server %r", s)
+        yield concurrently_execute(_get_result, destinations, 3)
+        defer.returnValue(results_by_server)
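concurrently_execute(func, args, limit) runs func over args with at most limit calls in flight, bounding the fan-out to remote servers. A minimal self-contained sketch of the same idea, simplified from the helper in synapse.util.async:

from twisted.internet import defer


@defer.inlineCallbacks
def concurrently_execute_sketch(func, args, limit):
    # Share one iterator between `limit` workers: each worker pulls the
    # next argument as soon as it finishes the previous one, so at most
    # `limit` calls are outstanding at any time.
    it = iter(args)

    @defer.inlineCallbacks
    def _worker():
        for arg in it:
            yield func(arg)

    yield defer.gatherResults(
        [_worker() for _ in range(limit)],
        consumeErrors=True,
    )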
     @defer.inlineCallbacks
     def query_auth(self, destination, room_id, event_id, local_auth):
@@ -814,8 +661,7 @@ class FederationClient(FederationBase):
         if len(signed_events) >= limit:
             defer.returnValue(signed_events)
-        users = yield self.state.get_current_user_in_room(room_id)
-        servers = set(get_domain_from_id(u) for u in users)
+        servers = yield self.store.get_joined_hosts_for_room(room_id)
         servers = set(servers)
         servers.discard(self.server_name)
@@ -860,16 +706,14 @@ class FederationClient(FederationBase):
             return srvs
         deferreds = [
-            preserve_fn(self.get_pdu)(
+            self.get_pdu(
                 destinations=random_server_list(),
                 event_id=e_id,
             )
             for e_id, depth in ordered_missing[:limit - len(signed_events)]
         ]
-        res = yield preserve_context_over_deferred(
-            defer.DeferredList(deferreds, consumeErrors=True)
-        )
+        res = yield defer.DeferredList(deferreds, consumeErrors=True)
         for (result, val), (e_id, _) in zip(res, ordered_missing):
             if result and val:
                 signed_events.append(val)

synapse/federation/federation_server.py View File

@@ -21,11 +21,10 @@ from .units import Transaction, Edu
 from synapse.util.async import Linearizer
 from synapse.util.logutils import log_function
-from synapse.util.caches.response_cache import ResponseCache
 from synapse.events import FrozenEvent
 import synapse.metrics
-from synapse.api.errors import AuthError, FederationError, SynapseError
+from synapse.api.errors import FederationError, SynapseError
 from synapse.crypto.event_signing import compute_event_signature
@@ -49,15 +48,9 @@ class FederationServer(FederationBase):
     def __init__(self, hs):
         super(FederationServer, self).__init__(hs)
-        self.auth = hs.get_auth()
         self._room_pdu_linearizer = Linearizer()
         self._server_linearizer = Linearizer()
-        # We cache responses to state queries, as they take a while and often
-        # come in waves.
-        self._state_resp_cache = ResponseCache(hs, timeout_ms=30000)
     def set_handler(self, handler):
         """Sets the handler that the replication layer will use to communicate
         receipt of new PDUs from other home servers. The required methods are
@@ -188,76 +181,40 @@ class FederationServer(FederationBase):
         except SynapseError as e:
             logger.info("Failed to handle edu %r: %r", edu_type, e)
         except Exception as e:
-            logger.exception("Failed to handle edu %r", edu_type)
+            logger.exception("Failed to handle edu %r", edu_type, e)
         else:
             logger.warn("Received EDU of type %s with no handler", edu_type)
     @defer.inlineCallbacks
     @log_function
     def on_context_state_request(self, origin, room_id, event_id):
-        if not event_id:
-            raise NotImplementedError("Specify an event")
-        in_room = yield self.auth.check_host_in_room(room_id, origin)
-        if not in_room:
-            raise AuthError(403, "Host not in room.")
-        result = self._state_resp_cache.get((room_id, event_id))
-        if not result:
-            with (yield self._server_linearizer.queue((origin, room_id))):
-                resp = yield self._state_resp_cache.set(
-                    (room_id, event_id),
-                    self._on_context_state_request_compute(room_id, event_id)
-                )
-        else:
-            resp = yield result
-        defer.returnValue((200, resp))
-    @defer.inlineCallbacks
-    def on_state_ids_request(self, origin, room_id, event_id):
-        if not event_id:
-            raise NotImplementedError("Specify an event")
-        in_room = yield self.auth.check_host_in_room(room_id, origin)
-        if not in_room:
-            raise AuthError(403, "Host not in room.")
-        state_ids = yield self.handler.get_state_ids_for_pdu(
-            room_id, event_id,
-        )
-        auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids)
-        defer.returnValue((200, {
-            "pdu_ids": state_ids,
-            "auth_chain_ids": auth_chain_ids,
-        }))
-    @defer.inlineCallbacks
-    def _on_context_state_request_compute(self, room_id, event_id):
-        pdus = yield self.handler.get_state_for_pdu(
-            room_id, event_id,
-        )
-        auth_chain = yield self.store.get_auth_chain(
-            [pdu.event_id for pdu in pdus]
-        )
-        for event in auth_chain:
-            # We sign these again because there was a bug where we
-            # incorrectly signed things the first time round
-            if self.hs.is_mine_id(event.event_id):
-                event.signatures.update(
-                    compute_event_signature(
-                        event,
-                        self.hs.hostname,
-                        self.hs.config.signing_key[0]
-                    )
-                )
-        defer.returnValue({
+        with (yield self._server_linearizer.queue((origin, room_id))):
+            if event_id:
+                pdus = yield self.handler.get_state_for_pdu(
+                    origin, room_id, event_id,
+                )
+                auth_chain = yield self.store.get_auth_chain(
+                    [pdu.event_id for pdu in pdus]
+                )
+                for event in auth_chain:
+                    # We sign these again because there was a bug where we
+                    # incorrectly signed things the first time round
+                    if self.hs.is_mine_id(event.event_id):
+                        event.signatures.update(
+                            compute_event_signature(
+                                event,
+                                self.hs.hostname,
+                                self.hs.config.signing_key[0]
+                            )
+                        )
+            else:
+                raise NotImplementedError("Specify an event")
+        defer.returnValue((200, {
             "pdus": [pdu.get_pdu_json() for pdu in pdus],
             "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
-        })
+        }))
     @defer.inlineCallbacks
     @log_function
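The "-" side deduplicates concurrent state requests through a ResponseCache: the first caller for a key computes the response, and every duplicate request that arrives while it is in flight shares the result. A minimal sketch of that pattern; the real cache also expires entries after timeout_ms, which is omitted here:

from twisted.internet import defer


class ResponseCacheSketch(object):
    """First request for a key computes; concurrent duplicates share it."""

    def __init__(self):
        self._pending = {}  # key -> list of waiting Deferreds

    def wrap(self, key, compute):
        waiter = defer.Deferred()
        if key in self._pending:
            # A computation for this key is already in flight: piggy-back.
            self._pending[key].append(waiter)
            return waiter
        self._pending[key] = [waiter]

        def _done(result):
            # Fire every waiter with the shared result; a Failure result
            # propagates down each waiter's errback chain.
            for w in self._pending.pop(key):
                w.callback(result)

        defer.maybeDeferred(compute).addBoth(_done)
        return waiter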
@@ -391,9 +348,27 @@ class FederationServer(FederationBase):
             (200, send_content)
         )
+    @defer.inlineCallbacks
     @log_function
     def on_query_client_keys(self, origin, content):
-        return self.on_query_request("client_keys", content)
+        query = []
+        for user_id, device_ids in content.get("device_keys", {}).items():
+            if not device_ids:
+                query.append((user_id, None))
+            else:
+                for device_id in device_ids:
+                    query.append((user_id, device_id))
+        results = yield self.store.get_e2e_device_keys(query)
+        json_result = {}
+        for user_id, device_keys in results.items():
+            for device_id, json_bytes in device_keys.items():
+                json_result.setdefault(user_id, {})[device_id] = json.loads(
+                    json_bytes
+                )
+        defer.returnValue({"device_keys": json_result})
     @defer.inlineCallbacks
     @log_function
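The "+" side answers key queries inline. On the wire the request maps user_id to a list of device IDs, with an empty list meaning "all devices". A sketch of flattening that request shape as a pure function, following the names in the diff:

def flatten_device_key_query(device_keys):
    """Turn {user_id: [device_id, ...]} into (user_id, device_id) pairs.

    An empty device list means "all devices for this user", encoded as
    (user_id, None).
    """
    query = []
    for user_id, device_ids in device_keys.items():
        if not device_ids:
            query.append((user_id, None))
        else:
            for device_id in device_ids:
                query.append((user_id, device_id))
    return query


# flatten_device_key_query({"@a:hs": [], "@b:hs": ["D1", "D2"]})
# -> [("@a:hs", None), ("@b:hs", "D1"), ("@b:hs", "D2")]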
@@ -603,7 +578,7 @@ class FederationServer(FederationBase):
                 origin, pdu.room_id, pdu.event_id,
             )
         except:
-            logger.exception("Failed to get state for event: %s", pdu.event_id)
+            logger.warn("Failed to get state for event: %s", pdu.event_id)
         yield self.handler.on_receive_pdu(
             origin,

synapse/federation/transaction_queue.py View File

@@ -17,16 +17,13 @@
 from twisted.internet import defer
 from .persistence import TransactionActions
-from .units import Transaction, Edu
+from .units import Transaction
 from synapse.api.errors import HttpResponseException
 from synapse.util.async import run_on_reactor
-from synapse.util.logcontext import preserve_context_over_fn
+from synapse.util.logutils import log_function
+from synapse.util.logcontext import PreserveLoggingContext
 from synapse.util.retryutils import (
     get_retry_limiter, NotRetryingDestination,
 )
-from synapse.util.metrics import measure_func
-from synapse.handlers.presence import format_user_presence_state
 import synapse.metrics
 import logging
@@ -52,7 +51,7 @@ class TransactionQueue(object):
         self.transport_layer = transport_layer
-        self.clock = hs.get_clock()
+        self._clock = hs.get_clock()
         # Is a mapping from destinations -> deferreds. Used to keep track
         # of which destinations have transactions in flight and when they are
@@ -70,30 +69,20 @@ class TransactionQueue(object):
         # destination -> list of tuple(edu, deferred)
         self.pending_edus_by_dest = edus = {}
-        # Presence needs to be separate as we send single aggragate EDUs
-        self.pending_presence_by_dest = presence = {}
-        self.pending_edus_keyed_by_dest = edus_keyed = {}
         metrics.register_callback(
             "pending_pdus",
             lambda: sum(map(len, pdus.values())),
         )
         metrics.register_callback(
             "pending_edus",
-            lambda: (
-                sum(map(len, edus.values()))
-                + sum(map(len, presence.values()))
-                + sum(map(len, edus_keyed.values()))
-            ),
+            lambda: sum(map(len, edus.values())),
         )
         # destination -> list of tuple(failure, deferred)
         self.pending_failures_by_dest = {}
-        self.last_device_stream_id_by_dest = {}
         # HACK to get unique tx id
-        self._next_txn_id = int(self.clock.time_msec())
+        self._next_txn_id = int(self._clock.time_msec())
     def can_send_to(self, destination):
         """Can we send messages to the given server?
@@ -130,69 +119,89 @@ class TransactionQueue(object):
         if not destinations:
             return
+        deferreds = []
         for destination in destinations:
+            deferred = defer.Deferred()
             self.pending_pdus_by_dest.setdefault(destination, []).append(
-                (pdu, order)
+                (pdu, deferred, order)
             )
-            preserve_context_over_fn(
-                self._attempt_new_transaction, destination
-            )
-    def enqueue_presence(self, destination, states):
-        self.pending_presence_by_dest.setdefault(destination, {}).update({
-            state.user_id: state for state in states
-        })
-        preserve_context_over_fn(
-            self._attempt_new_transaction, destination
-        )
-    def enqueue_edu(self, edu, key=None):
+            def chain(failure):
+                if not deferred.called:
+                    deferred.errback(failure)
+            def log_failure(f):
+                logger.warn("Failed to send pdu to %s: %s", destination, f.value)
+            deferred.addErrback(log_failure)
+            with PreserveLoggingContext():
+                self._attempt_new_transaction(destination).addErrback(chain)
+            deferreds.append(deferred)
+    # NO inlineCallbacks
+    def enqueue_edu(self, edu):
         destination = edu.destination
         if not self.can_send_to(destination):
             return
-        if key:
-            self.pending_edus_keyed_by_dest.setdefault(
-                destination, {}
-            )[(edu.edu_type, key)] = edu
-        else:
-            self.pending_edus_by_dest.setdefault(destination, []).append(edu)
-        preserve_context_over_fn(
-            self._attempt_new_transaction, destination
-        )
+        deferred = defer.Deferred()
+        self.pending_edus_by_dest.setdefault(destination, []).append(
+            (edu, deferred)
+        )
+        def chain(failure):
+            if not deferred.called:
+                deferred.errback(failure)
+        def log_failure(f):
+            logger.warn("Failed to send edu to %s: %s", destination, f.value)
+        deferred.addErrback(log_failure)
+        with PreserveLoggingContext():
+            self._attempt_new_transaction(destination).addErrback(chain)
+        return deferred
+    @defer.inlineCallbacks
     def enqueue_failure(self, failure, destination):
         if destination == self.server_name or destination == "localhost":
             return
+        deferred = defer.Deferred()
         if not self.can_send_to(destination):
             return
         self.pending_failures_by_dest.setdefault(
             destination, []
-        ).append(failure)
-        preserve_context_over_fn(
-            self._attempt_new_transaction, destination
-        )
-    def enqueue_device_messages(self, destination):
-        if destination == self.server_name or destination == "localhost":
-            return
-        if not self.can_send_to(destination):
-            return
-        preserve_context_over_fn(
-            self._attempt_new_transaction, destination
-        )
+        ).append(
+            (failure, deferred)
+        )
+        def chain(f):
+            if not deferred.called:
+                deferred.errback(f)
+        def log_failure(f):
+            logger.warn("Failed to send failure to %s: %s", destination, f.value)
+        deferred.addErrback(log_failure)
+        with PreserveLoggingContext():
+            self._attempt_new_transaction(destination).addErrback(chain)
+        yield deferred
     @defer.inlineCallbacks
+    @log_function
     def _attempt_new_transaction(self, destination):
+        yield run_on_reactor()
         # list of (pending_pdu, deferred, order)
         if destination in self.pending_transactions:
             # XXX: pending_transactions can get stuck on by a never-ending
@@ -205,128 +214,55 @@ class TransactionQueue(object):
             )
             return
+        pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
+        pending_edus = self.pending_edus_by_dest.pop(destination, [])
+        pending_failures = self.pending_failures_by_dest.pop(destination, [])
+        if pending_pdus:
+            logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
+                         destination, len(pending_pdus))
+        if not pending_pdus and not pending_edus and not pending_failures:
+            logger.debug("TX [%s] Nothing to send", destination)
+            return
         try:
             self.pending_transactions[destination] = 1
-            yield run_on_reactor()
-            while True:
-                pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
-                pending_edus = self.pending_edus_by_dest.pop(destination, [])
-                pending_presence = self.pending_presence_by_dest.pop(destination, {})
-                pending_failures = self.pending_failures_by_dest.pop(destination, [])
-                pending_edus.extend(
-                    self.pending_edus_keyed_by_dest.pop(destination, {}).values()
-                )
-                limiter = yield get_retry_limiter(
-                    destination,
-                    self.clock,
-                    self.store,
-                )
-                device_message_edus, device_stream_id = (
-                    yield self._get_new_device_messages(destination)
-                )
-                pending_edus.extend(device_message_edus)
-                if pending_presence:
-                    pending_edus.append(
-                        Edu(
-                            origin=self.server_name,
-                            destination=destination,
-                            edu_type="m.presence",
-                            content={
-                                "push": [
-                                    format_user_presence_state(
-                                        presence, self.clock.time_msec()
-                                    )
-                                    for presence in pending_presence.values()
-                                ]
-                            },
-                        )
-                    )
-                if pending_pdus:
-                    logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
-                                 destination, len(pending_pdus))
-                if not pending_pdus and not pending_edus and not pending_failures:
-                    logger.debug("TX [%s] Nothing to send", destination)
-                    self.last_device_stream_id_by_dest[destination] = (
-                        device_stream_id
-                    )
-                    return
-                success = yield self._send_new_transaction(
-                    destination, pending_pdus, pending_edus, pending_failures,
-                    device_stream_id,
-                    should_delete_from_device_stream=bool(device_message_edus),
-                    limiter=limiter,
-                )
-                if not success:
-                    break
-        except NotRetryingDestination:
-            logger.info(
-                "TX [%s] not ready for retry yet - "
-                "dropping transaction for now",
-                destination,
-            )
-        finally:
-            # We want to be *very* sure we delete this after we stop processing
-            self.pending_transactions.pop(destination, None)
-    @defer.inlineCallbacks
-    def _get_new_device_messages(self, destination):
-        last_device_stream_id = self.last_device_stream_id_by_dest.get(destination, 0)
-        to_device_stream_id = self.store.get_to_device_stream_token()
-        contents, stream_id = yield self.store.get_new_device_msgs_for_remote(
-            destination, last_device_stream_id, to_device_stream_id
-        )
-        edus = [
-            Edu(
-                origin=self.server_name,
-                destination=destination,
-                edu_type="m.direct_to_device",
-                content=content,
-            )
-            for content in contents
-        ]
-        defer.returnValue((edus, stream_id))
-    @measure_func("_send_new_transaction")
-    @defer.inlineCallbacks
-    def _send_new_transaction(self, destination, pending_pdus, pending_edus,
-                              pending_failures, device_stream_id,
-                              should_delete_from_device_stream, limiter):
-        # Sort based on the order field
-        pending_pdus.sort(key=lambda t: t[1])
-        pdus = [x[0] for x in pending_pdus]
-        edus = pending_edus
-        failures = [x.get_dict() for x in pending_failures]
-        success = True
-        try:
             logger.debug("TX [%s] _attempt_new_transaction", destination)
+            # Sort based on the order field
+            pending_pdus.sort(key=lambda t: t[2])
+            pdus = [x[0] for x in pending_pdus]
+            edus = [x[0] for x in pending_edus]
+            failures = [x[0].get_dict() for x in pending_failures]
+            deferreds = [
+                x[1]
+                for x in pending_pdus + pending_edus + pending_failures
+            ]
             txn_id = str(self._next_txn_id)
+            limiter = yield get_retry_limiter(
+                destination,
+                self._clock,
+                self.store,
+            )
             logger.debug(
                 "TX [%s] {%s} Attempting new transaction"
                 " (pdus: %d, edus: %d, failures: %d)",
                 destination, txn_id,
-                len(pdus),
-                len(edus),
-                len(failures)
+                len(pending_pdus),
+                len(pending_edus),
+                len(pending_failures)
             )
             logger.debug("TX [%s] Persisting transaction...", destination)
             transaction = Transaction.create_new(
-                origin_server_ts=int(self.clock.time_msec()),
+                origin_server_ts=int(self._clock.time_msec()),
                 transaction_id=txn_id,
                 origin=self.server_name,
                 destination=destination,
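The "-" side restructures _attempt_new_transaction as a drain loop: one worker per destination keeps popping queued PDUs and EDUs and sending until everything that arrived during the send has been flushed. A condensed sketch of that control flow; the queue object is duck-typed and the send signature is simplified, so treat this as an illustration rather than the real method:

def drain_destination(queue, destination, send):
    """Keep sending until nothing is left queued for `destination`.

    `queue` carries pending_pdus_by_dest / pending_edus_by_dest dicts and
    a pending_transactions marker dict; `send` returns True on success.
    """
    if destination in queue.pending_transactions:
        return  # another worker is already draining this destination
    try:
        queue.pending_transactions[destination] = 1
        while True:
            pdus = queue.pending_pdus_by_dest.pop(destination, [])
            edus = queue.pending_edus_by_dest.pop(destination, [])
            if not pdus and not edus:
                return  # fully drained
            if not send(destination, pdus, edus):
                break  # destination failed; the retry limiter gates retries
    finally:
        # Always clear the marker, or the destination wedges forever.
        queue.pending_transactions.pop(destination, None)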
@@ -345,9 +281,9 @@ class TransactionQueue(object):
                 " (PDUs: %d, EDUs: %d, failures: %d)",
                 destination, txn_id,
                 transaction.transaction_id,
-                len(pdus),
-                len(edus),
-                len(failures),
+                len(pending_pdus),
+                len(pending_edus),
+                len(pending_failures),
             )
             with limiter:
@@ -357,7 +293,7 @@ class TransactionQueue(object):
                 # keys work
                 def json_data_cb():
                     data = transaction.get_dict()
-                    now = int(self.clock.time_msec())
+                    now = int(self._clock.time_msec())
                     if "pdus" in data:
                         for p in data["pdus"]:
                             if "age_ts" in p:
@@ -397,19 +333,28 @@ class TransactionQueue(object):
             logger.debug("TX [%s] Marked as delivered", destination)
-            if code != 200:
-                for p in pdus:
-                    logger.info(
-                        "Failed to send event %s to %s", p.event_id, destination
-                    )
-                success = False
-            else:
-                # Remove the acknowledged device messages from the database
-                if should_delete_from_device_stream:
-                    yield self.store.delete_device_msgs_for_remote(
-                        destination, device_stream_id
-                    )
-                self.last_device_stream_id_by_dest[destination] = device_stream_id
+            logger.debug("TX [%s] Yielding to callbacks...", destination)
+            for deferred in deferreds:
+                if code == 200:
+                    deferred.callback(None)
+                else:
+                    deferred.errback(RuntimeError("Got status %d" % code))
+                # Ensures we don't continue until all callbacks on that
+                # deferred have fired
+                try:
+                    yield deferred
+                except:
+                    pass
+            logger.debug("TX [%s] Yielded to callbacks", destination)
+        except NotRetryingDestination:
+            logger.info(
+                "TX [%s] not ready for retry yet - "
+                "dropping transaction for now",
+                destination,
+            )
         except RuntimeError as e:
             # We capture this here as there as nothing actually listens
             # for this finishing functions deferred.
@@ -418,11 +363,6 @@ class TransactionQueue(object):
                 destination,
                 e,
             )
-            success = False
-            for p in pdus:
-                logger.info("Failed to send event %s to %s", p.event_id, destination)
         except Exception as e:
             # We capture this here as there as nothing actually listens
             # for this finishing functions deferred.
@@ -432,9 +372,13 @@ class TransactionQueue(object):
                 e,
             )
-            success = False
-            for p in pdus:
-                logger.info("Failed to send event %s to %s", p.event_id, destination)
-        defer.returnValue(success)
+            for deferred in deferreds:
+                if not deferred.called:
+                    deferred.errback(e)
+        finally:
+            # We want to be *very* sure we delete this after we stop processing
+            self.pending_transactions.pop(destination, None)
+            # Check to see if there is anything else to send.
+            self._attempt_new_transaction(destination)

synapse/federation/transport/client.py View File

@@ -54,28 +54,6 @@ class TransportLayerClient(object):
             destination, path=path, args={"event_id": event_id},
         )
-    @log_function
-    def get_room_state_ids(self, destination, room_id, event_id):
-        """ Requests all state for a given room from the given server at the
-        given event. Returns the state's event_id's
-        Args:
-            destination (str): The host name of the remote home server we want
-                to get the state from.
-            context (str): The name of the context we want the state of
-            event_id (str): The event we want the context at.
-        Returns:
-            Deferred: Results in a dict received from the remote homeserver.
-        """
-        logger.debug("get_room_state_ids dest=%s, room=%s",
-                     destination, room_id)
-        path = PREFIX + "/state_ids/%s/" % room_id
-        return self.client.get_json(
-            destination, path=path, args={"event_id": event_id},
-        )
     @log_function
     def get_event(self, destination, event_id, timeout=None):
         """ Requests the pdu with give id and origin from the given server.
@@ -248,22 +226,12 @@ class TransportLayerClient(object):
     @defer.inlineCallbacks
     @log_function
-    def get_public_rooms(self, remote_server, limit, since_token,
-                         search_filter=None):
+    def get_public_rooms(self, remote_server):
         path = PREFIX + "/publicRooms"
-        args = {}
-        if limit:
-            args["limit"] = [str(limit)]
-        if since_token:
-            args["since"] = [since_token]
-        # TODO(erikj): Actually send the search_filter across federation.
         response = yield self.client.get_json(
             destination=remote_server,
             path=path,
-            args=args,
         )
         defer.returnValue(response)
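The "-" side threads limit/since pagination through as HTTP query arguments, with each value wrapped in a single-element list to match how Twisted represents query parameters in request.args. A small sketch of building those args as a pure function, following the names in the diff:

def build_public_rooms_args(limit=None, since_token=None):
    """Build the query args for /publicRooms pagination.

    Values are single-element lists to mirror Twisted's request.args.
    """
    args = {}
    if limit:
        args["limit"] = [str(limit)]
    if since_token:
        args["since"] = [since_token]
    return args


# build_public_rooms_args(limit=20, since_token="t123")
# -> {"limit": ["20"], "since": ["t123"]}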
@@ -308,7 +276,7 @@ class TransportLayerClient(object):
     @defer.inlineCallbacks
     @log_function
-    def query_client_keys(self, destination, query_content, timeout):
+    def query_client_keys(self, destination, query_content):
         """Query the device keys for a list of user ids hosted on a remote
         server.
@@ -337,13 +305,12 @@ class TransportLayerClient(object):
             destination=destination,
             path=path,
             data=query_content,
-            timeout=timeout,
         )
         defer.returnValue(content)
     @defer.inlineCallbacks
     @log_function
-    def claim_client_keys(self, destination, query_content, timeout):
+    def claim_client_keys(self, destination, query_content):
         """Claim one-time keys for a list of devices hosted on a remote server.
         Request:
@@ -374,7 +341,6 @@ class TransportLayerClient(object):
             destination=destination,
             path=path,
             data=query_content,
-            timeout=timeout,
        )
         defer.returnValue(content)

synapse/federation/transport/server.py View File

@@ -18,16 +18,13 @@ from twisted.internet import defer
 from synapse.api.urls import FEDERATION_PREFIX as PREFIX
 from synapse.api.errors import Codes, SynapseError
 from synapse.http.server import JsonResource
-from synapse.http.servlet import (
-    parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
-)
+from synapse.http.servlet import parse_json_object_from_request, parse_string
 from synapse.util.ratelimitutils import FederationRateLimiter
-from synapse.util.versionstring import get_version_string
 import functools
 import logging
+import simplejson as json
 import re
-import synapse
 logger = logging.getLogger(__name__)
@@ -63,16 +60,6 @@ class TransportLayerServer(JsonResource):
         )
-class AuthenticationError(SynapseError):
-    """There was a problem authenticating the request"""
-    pass
-class NoAuthenticationError(AuthenticationError):
-    """The request had no authentication information"""
-    pass
 class Authenticator(object):
     def __init__(self, hs):
         self.keyring = hs.get_keyring()
@@ -80,7 +67,7 @@ class Authenticator(object):
     # A method just so we can pass 'self' as the authenticator to the Servlets
     @defer.inlineCallbacks
-    def authenticate_request(self, request, content):
+    def authenticate_request(self, request):
         json_request = {
             "method": request.method,
             "uri": request.uri,
@@ -88,11 +75,18 @@ class Authenticator(object):
             "signatures": {},
         }
-        if content is not None:
-            json_request["content"] = content
+        content = None
         origin = None
+        if request.method in ["PUT", "POST"]:
+            # TODO: Handle other method types? other content types?
+            try:
+                content_bytes = request.content.read()
+                content = json.loads(content_bytes)
+                json_request["content"] = content
+            except:
+                raise SynapseError(400, "Unable to parse JSON", Codes.BAD_JSON)
         def parse_auth_header(header_str):
             try:
                 params = auth.split(" ")[1].split(",")
@@ -109,14 +103,14 @@ class Authenticator(object):
                 sig = strip_quotes(param_dict["sig"])
                 return (origin, key, sig)
             except:
-                raise AuthenticationError(
+                raise SynapseError(
                     400, "Malformed Authorization header", Codes.UNAUTHORIZED
                 )
         auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
         if not auth_headers:
-            raise NoAuthenticationError(
+            raise SynapseError(
                 401, "Missing Authorization headers", Codes.UNAUTHORIZED,
             )
@@ -127,7 +121,7 @@ class Authenticator(object):
             json_request["signatures"].setdefault(origin, {})[key] = sig
         if not json_request["signatures"]:
-            raise NoAuthenticationError(
+            raise SynapseError(
                 401, "Missing Authorization headers", Codes.UNAUTHORIZED,
             )
@@ -136,12 +130,10 @@ class Authenticator(object):
         logger.info("Request from %s", origin)
         request.authenticated_entity = origin
-        defer.returnValue(origin)
+        defer.returnValue((origin, content))
 class BaseFederationServlet(object):
-    REQUIRE_AUTH = True
     def __init__(self, handler, authenticator, ratelimiter, server_name,
                  room_list_handler):
         self.handler = handler
@@ -149,46 +141,29 @@ class BaseFederationServlet(object):
         self.ratelimiter = ratelimiter
         self.room_list_handler = room_list_handler
-    def _wrap(self, func):
+    def _wrap(self, code):
         authenticator = self.authenticator
         ratelimiter = self.ratelimiter
         @defer.inlineCallbacks
-        @functools.wraps(func)
-        def new_func(request, *args, **kwargs):
-            content = None
-            if request.method in ["PUT", "POST"]:
-                # TODO: Handle other method types? other content types?
-                content = parse_json_object_from_request(request)
+        @functools.wraps(code)
+        def new_code(request, *args, **kwargs):
             try:
-                origin = yield authenticator.authenticate_request(request, content)
-            except NoAuthenticationError:
-                origin = None
-                if self.REQUIRE_AUTH:
-                    logger.exception("authenticate_request failed")
-                    raise
+                (origin, content) = yield authenticator.authenticate_request(request)
+                with ratelimiter.ratelimit(origin) as d:
+                    yield d
+                    response = yield code(
+                        origin, content, request.args, *args, **kwargs
+                    )
             except:
                 logger.exception("authenticate_request failed")
                 raise
-            if origin:
-                with ratelimiter.ratelimit(origin) as d:
-                    yield d
-                    response = yield func(
-                        origin, content, request.args, *args, **kwargs
-                    )
-            else:
-                response = yield func(
-                    origin, content, request.args, *args, **kwargs
-                )
             defer.returnValue(response)
         # Extra logic that functools.wraps() doesn't finish
-        new_func.__self__ = func.__self__
-        return new_func
+        new_code.__self__ = code.__self__
+        return new_code
     def register(self, server):
         pattern = re.compile("^" + PREFIX + self.PATH + "$")
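_wrap is the single choke point where every federation servlet gets authentication and per-origin rate limiting bolted on. A stripped-down sketch of that decorator shape, using plain functions and hypothetical authenticate/ratelimit helpers rather than the real Twisted machinery:

import functools


def wrap_endpoint(handler, authenticate, ratelimit):
    """Wrap a servlet method so auth and rate limiting always run first."""

    @functools.wraps(handler)
    def wrapped(request, *args, **kwargs):
        origin, content = authenticate(request)   # raises on a bad signature
        with ratelimit(origin):                   # throttles per origin
            return handler(origin, content, request.args, *args, **kwargs)

    return wrapped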
@@ -296,17 +271,6 @@ class FederationStateServlet(BaseFederationServlet):
         )
-class FederationStateIdsServlet(BaseFederationServlet):
-    PATH = "/state_ids/(?P<room_id>[^/]*)/"
-    def on_GET(self, origin, content, query, room_id):
-        return self.handler.on_state_ids_request(
-            origin,
-            room_id,
-            query.get("event_id", [None])[0],
-        )
 class FederationBackfillServlet(BaseFederationServlet):
     PATH = "/backfill/(?P<context>[^/]*)/"
@@ -403,8 +367,10 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
 class FederationClientKeysQueryServlet(BaseFederationServlet):
     PATH = "/user/keys/query"
+    @defer.inlineCallbacks
     def on_POST(self, origin, content, query):
-        return self.handler.on_query_client_keys(origin, content)
+        response = yield self.handler.on_query_client_keys(origin, content)
+        defer.returnValue((200, response))
 class FederationClientKeysClaimServlet(BaseFederationServlet):
@@ -454,10 +420,9 @@ class FederationGetMissingEventsServlet(BaseFederationServlet):
 class On3pidBindServlet(BaseFederationServlet):
     PATH = "/3pid/onbind"
-    REQUIRE_AUTH = False
     @defer.inlineCallbacks
-    def on_POST(self, origin, content, query):
+    def on_POST(self, request):
+        content = parse_json_object_from_request(request)
         if "invites" in content:
             last_exception = None
             for invite in content["invites"]:
@@ -479,6 +444,11 @@ class On3pidBindServlet(BaseFederationServlet):
                     raise last_exception
         defer.returnValue((200, {}))
+    # Avoid doing remote HS authorization checks which are done by default by
+    # BaseFederationServlet.
+    def _wrap(self, code):
+        return code
 class OpenIdUserInfo(BaseFederationServlet):
     """
@@ -499,11 +469,9 @@ class OpenIdUserInfo(BaseFederationServlet):
     PATH = "/openid/userinfo"
-    REQUIRE_AUTH = False
     @defer.inlineCallbacks
-    def on_GET(self, origin, content, query):
-        token = query.get("access_token", [None])[0]
+    def on_GET(self, request):
+        token = parse_string(request, "access_token")
         if token is None:
             defer.returnValue((401, {
                 "errcode": "M_MISSING_TOKEN", "error": "Access Token required"
@@ -520,6 +488,11 @@ class OpenIdUserInfo(BaseFederationServlet):
         defer.returnValue((200, {"sub": user_id}))
+    # Avoid doing remote HS authorization checks which are done by default by
+    # BaseFederationServlet.
+    def _wrap(self, code):
+        return code
 class PublicRoomList(BaseFederationServlet):
     """
@@ -556,34 +529,15 @@ class PublicRoomList(BaseFederationServlet):
     @defer.inlineCallbacks
     def on_GET(self, origin, content, query):
-        limit = parse_integer_from_args(query, "limit", 0)
-        since_token = parse_string_from_args(query, "since", None)
-        data = yield self.room_list_handler.get_local_public_room_list(
-            limit, since_token
-        )
+        data = yield self.room_list_handler.get_local_public_room_list()
         defer.returnValue((200, data))
-class FederationVersionServlet(BaseFederationServlet):
-    PATH = "/version"
-    REQUIRE_AUTH = False
-    def on_GET(self, origin, content, query):
-        return defer.succeed((200, {
-            "server": {
-                "name": "Synapse",
-                "version": get_version_string(synapse)
-            },
-        }))
 SERVLET_CLASSES = (
     FederationSendServlet,
     FederationPullServlet,
     FederationEventServlet,
     FederationStateServlet,
-    FederationStateIdsServlet,
     FederationBackfillServlet,
     FederationQueryServlet,
     FederationMakeJoinServlet,
@@ -601,7 +555,6 @@ SERVLET_CLASSES = (
     On3pidBindServlet,
     OpenIdUserInfo,
     PublicRoomList,
-    FederationVersionServlet,
 )

synapse/handlers/__init__.py View File

@@ -19,6 +19,7 @@ from .room import (
 )
 from .room_member import RoomMemberHandler
 from .message import MessageHandler
+from .events import EventStreamHandler, EventHandler
 from .federation import FederationHandler
 from .profile import ProfileHandler
 from .directory import DirectoryHandler
@@ -30,21 +31,10 @@ from .search import SearchHandler
 class Handlers(object):
-    """ Deprecated. A collection of handlers.
-    At some point most of the classes whose name ended "Handler" were
-    accessed through this class.
-    However this makes it painful to unit test the handlers and to run cut
-    down versions of synapse that only use specific handlers because using a
-    single handler required creating all of the handlers. So some of the
-    handlers have been lifted out of the Handlers object and are now accessed
-    directly through the homeserver object itself.
-    Any new handlers should follow the new pattern of being accessed through
-    the homeserver object and should not be added to the Handlers object.
-    The remaining handlers should be moved out of the handlers object.
+    """ A collection of all the event handlers.
+    There's no need to lazily create these; we'll just make them all eagerly
+    at construction time.
     """
     def __init__(self, hs):
@@ -52,6 +42,8 @@ class Handlers(object):
         self.message_handler = MessageHandler(hs)
         self.room_creation_handler = RoomCreationHandler(hs)
         self.room_member_handler = RoomMemberHandler(hs)
+        self.event_stream_handler = EventStreamHandler(hs)
+        self.event_handler = EventHandler(hs)
         self.federation_handler = FederationHandler(hs)
         self.profile_handler = ProfileHandler(hs)
         self.directory_handler = DirectoryHandler(hs)

synapse/handlers/_base.py View File

@@ -13,14 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import logging
 from twisted.internet import defer
-import synapse.types
-from synapse.api.constants import Membership, EventTypes
 from synapse.api.errors import LimitExceededError
-from synapse.types import UserID
+from synapse.api.constants import Membership, EventTypes
+from synapse.types import UserID, Requester
+import logging
 logger = logging.getLogger(__name__)
@@ -31,15 +31,11 @@ class BaseHandler(object):
     Common base class for the event handlers.
     Attributes:
-        store (synapse.storage.DataStore):
+        store (synapse.storage.events.StateStore):
         state_handler (synapse.state.StateHandler):
     """
     def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer):
-        """
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.notifier = hs.get_notifier()
@@ -55,20 +51,8 @@ class BaseHandler(object):
     def ratelimit(self, requester):
         time_now = self.clock.time()
-        user_id = requester.user.to_string()
-        # The AS user itself is never rate limited.
-        app_service = self.store.get_app_service_by_user_id(user_id)
-        if app_service is not None:
-            return  # do not ratelimit app service senders
-        # Disable rate limiting of users belonging to any AS that is configured
-        # not to be rate limited in its registration file (rate_limited: true|false).
-        if requester.app_service and not requester.app_service.is_rate_limited():
-            return
         allowed, time_allowed = self.ratelimiter.send_message(
-            user_id, time_now,
+            requester.user.to_string(), time_now,
             msg_rate_hz=self.hs.config.rc_messages_per_second,
             burst_count=self.hs.config.rc_message_burst_count,
         )
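The removed lines exempt application services from message rate limiting in two ways: the AS sender user itself, and any user covered by an AS registered with rate_limited: false. A condensed sketch of that early-return guard, with duck-typed store/requester stand-ins:

def should_ratelimit(store, requester):
    """Return False for senders that application-service rules exempt."""
    user_id = requester.user.to_string()
    # The AS sender user itself is never rate limited.
    if store.get_app_service_by_user_id(user_id) is not None:
        return False
    # Users under an AS registered with rate_limited: false are exempt too.
    if requester.app_service and not requester.app_service.is_rate_limited():
        return False
    return True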
@@ -77,21 +61,33 @@ class BaseHandler(object):
retry_after_ms=int(1000 * (time_allowed - time_now)), retry_after_ms=int(1000 * (time_allowed - time_now)),
) )
def is_host_in_room(self, current_state):
room_members = [
(state_key, event.membership)
for ((event_type, state_key), event) in current_state.items()
if event_type == EventTypes.Member
]
if len(room_members) == 0:
# Have we just created the room, and is this about to be the very
# first member event?
create_event = current_state.get(("m.room.create", ""))
if create_event:
return True
for (state_key, membership) in room_members:
if (
self.hs.is_mine_id(state_key)
and membership == Membership.JOIN
):
return True
return False
     @defer.inlineCallbacks
-    def maybe_kick_guest_users(self, event, context=None):
+    def maybe_kick_guest_users(self, event, current_state):
         # Technically this function invalidates current_state by changing it.
         # Hopefully this isn't that important to the caller.
         if event.type == EventTypes.GuestAccess:
             guest_access = event.content.get("guest_access", "forbidden")
             if guest_access != "can_join":
-                if context:
-                    current_state = yield self.store.get_events(
-                        context.current_state_ids.values()
-                    )
-                    current_state = current_state.values()
-                else:
-                    current_state = yield self.store.get_current_state(event.room_id)
-
-                logger.info("maybe_kick_guest_users %r", current_state)
                 yield self.kick_guest_users(current_state)
@defer.inlineCallbacks @defer.inlineCallbacks
@@ -124,8 +120,7 @@ class BaseHandler(object):
# and having homeservers have their own users leave keeps more # and having homeservers have their own users leave keeps more
# of that decision-making and control local to the guest-having # of that decision-making and control local to the guest-having
# homeserver. # homeserver.
-            requester = synapse.types.create_requester(
-                target_user, is_guest=True)
+            requester = Requester(target_user, "", True)
             handler = self.hs.get_handlers().room_member_handler
             yield handler.update_membership(
                 requester,

View File

@@ -16,8 +16,7 @@
 from twisted.internet import defer

 from synapse.api.constants import EventTypes
-from synapse.util.metrics import Measure
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
+from synapse.appservice import ApplicationService

 import logging
@@ -43,73 +42,36 @@ class ApplicationServicesHandler(object):
         self.appservice_api = hs.get_application_service_api()
         self.scheduler = hs.get_application_service_scheduler()
         self.started_scheduler = False
-        self.clock = hs.get_clock()
-        self.notify_appservices = hs.config.notify_appservices
-
-        self.current_max = 0
-        self.is_processing = False

     @defer.inlineCallbacks
-    def notify_interested_services(self, current_id):
+    def notify_interested_services(self, event):
         """Notifies (pushes) all application services interested in this event.

         Pushing is done asynchronously, so this method won't block for any
         prolonged length of time.

         Args:
-            current_id(int): The current maximum ID.
+            event(Event): The event to push out to interested services.
         """
-        services = self.store.get_app_services()
-        if not services or not self.notify_appservices:
-            return
-
-        self.current_max = max(self.current_max, current_id)
-        if self.is_processing:
-            return
-
-        with Measure(self.clock, "notify_interested_services"):
-            self.is_processing = True
-            try:
-                upper_bound = self.current_max
-                limit = 100
-                while True:
-                    upper_bound, events = yield self.store.get_new_events_for_appservice(
-                        upper_bound, limit
-                    )
-
-                    if not events:
-                        break
-
-                    for event in events:
-                        # Gather interested services
-                        services = yield self._get_services_for_event(event)
-                        if len(services) == 0:
-                            continue  # no services need notifying
-
-                        # Do we know this user exists? If not, poke the user
-                        # query API for all services which match that user regex.
-                        # This needs to block as these user queries need to be
-                        # made BEFORE pushing the event.
-                        yield self._check_user_exists(event.sender)
-                        if event.type == EventTypes.Member:
-                            yield self._check_user_exists(event.state_key)
-
-                        if not self.started_scheduler:
-                            self.scheduler.start().addErrback(log_failure)
-                            self.started_scheduler = True
-
-                        # Fork off pushes to these services
-                        for service in services:
-                            preserve_fn(self.scheduler.submit_event_for_as)(
-                                service, event
-                            )
-
-                    yield self.store.set_appservice_last_pos(upper_bound)
-
-                    if len(events) < limit:
-                        break
-            finally:
-                self.is_processing = False
+        # Gather interested services
+        services = yield self._get_services_for_event(event)
+        if len(services) == 0:
+            return  # no services need notifying
+
+        # Do we know this user exists? If not, poke the user query API for
+        # all services which match that user regex. This needs to block as these
+        # user queries need to be made BEFORE pushing the event.
+        yield self._check_user_exists(event.sender)
+        if event.type == EventTypes.Member:
+            yield self._check_user_exists(event.state_key)
+
+        if not self.started_scheduler:
+            self.scheduler.start().addErrback(log_failure)
+            self.started_scheduler = True
+
+        # Fork off pushes to these services
+        for service in services:
+            self.scheduler.submit_event_for_as(service, event)
@defer.inlineCallbacks @defer.inlineCallbacks
def query_user_exists(self, user_id): def query_user_exists(self, user_id):
@@ -142,12 +104,11 @@ class ApplicationServicesHandler(object):
association can be found. association can be found.
""" """
         room_alias_str = room_alias.to_string()
-        services = self.store.get_app_services()
-        alias_query_services = [
-            s for s in services if (
-                s.is_interested_in_alias(room_alias_str)
-            )
-        ]
+        alias_query_services = yield self._get_services_for_event(
+            event=None,
+            restrict_to=ApplicationService.NS_ALIASES,
+            alias_list=[room_alias_str]
+        )
         for alias_service in alias_query_services:
             is_known_alias = yield self.appservice_api.query_alias(
                 alias_service, room_alias_str
@@ -160,93 +121,47 @@ class ApplicationServicesHandler(object):
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
def query_3pe(self, kind, protocol, fields): def _get_services_for_event(self, event, restrict_to="", alias_list=None):
services = yield self._get_services_for_3pn(protocol)
results = yield preserve_context_over_deferred(defer.DeferredList([
preserve_fn(self.appservice_api.query_3pe)(service, kind, protocol, fields)
for service in services
], consumeErrors=True))
ret = []
for (success, result) in results:
if success:
ret.extend(result)
defer.returnValue(ret)
@defer.inlineCallbacks
def get_3pe_protocols(self, only_protocol=None):
services = self.store.get_app_services()
protocols = {}
# Collect up all the individual protocol responses out of the ASes
for s in services:
for p in s.protocols:
if only_protocol is not None and p != only_protocol:
continue
if p not in protocols:
protocols[p] = []
info = yield self.appservice_api.get_3pe_protocol(s, p)
if info is not None:
protocols[p].append(info)
def _merge_instances(infos):
if not infos:
return {}
# Merge the 'instances' lists of multiple results, but just take
# the other fields from the first as they ought to be identical
# copy the result so as not to corrupt the cached one
combined = dict(infos[0])
combined["instances"] = list(combined["instances"])
for info in infos[1:]:
combined["instances"].extend(info["instances"])
return combined
for p in protocols.keys():
protocols[p] = _merge_instances(protocols[p])
defer.returnValue(protocols)
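The merge performed by _merge_instances keeps every field of the first response and only concatenates the 'instances' lists. A small worked example with invented protocol metadata:

    infos = [
        {"user_fields": ["channel"], "instances": [{"desc": "Freenode"}]},
        {"user_fields": ["channel"], "instances": [{"desc": "OFTC"}]},
    ]

    # Copy the first result so the cached original is not mutated.
    combined = dict(infos[0])
    combined["instances"] = list(combined["instances"])
    for info in infos[1:]:
        combined["instances"].extend(info["instances"])

    # combined == {"user_fields": ["channel"],
    #              "instances": [{"desc": "Freenode"}, {"desc": "OFTC"}]}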
@defer.inlineCallbacks
def _get_services_for_event(self, event):
"""Retrieve a list of application services interested in this event. """Retrieve a list of application services interested in this event.
Args: Args:
event(Event): The event to check. Can be None if alias_list is not. event(Event): The event to check. Can be None if alias_list is not.
restrict_to(str): The namespace to restrict regex tests to.
alias_list: A list of aliases to get services for. If None, this
list is obtained from the database.
Returns: Returns:
list<ApplicationService>: A list of services interested in this list<ApplicationService>: A list of services interested in this
event based on the service regex. event based on the service regex.
""" """
-        services = self.store.get_app_services()
+        member_list = None
+        if hasattr(event, "room_id"):
+            # We need to know the aliases associated with this event.room_id,
+            # if any.
+            if not alias_list:
+                alias_list = yield self.store.get_aliases_for_room(
+                    event.room_id
+                )
+            # We need to know the members associated with this event.room_id,
+            # if any.
+            member_list = yield self.store.get_users_in_room(event.room_id)
+
+        services = yield self.store.get_app_services()
         interested_list = [
             s for s in services if (
-                yield s.is_interested(event, self.store)
+                s.is_interested(event, restrict_to, alias_list, member_list)
             )
         ]
         defer.returnValue(interested_list)
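The interest test above ultimately reduces to regex matches against the service's registered namespaces. A toy illustration of that filtering step; FakeService and its single user regex are stand-ins, not the real ApplicationService API:

    import re

    class FakeService(object):
        def __init__(self, user_regex):
            self.user_regex = re.compile(user_regex)

        def is_interested_in_user(self, user_id):
            # A service is interested if the user id matches its namespace.
            return bool(self.user_regex.match(user_id))

    services = [FakeService(r"@irc_.*"), FakeService(r"@gitter_.*")]
    interested = [
        s for s in services if s.is_interested_in_user("@irc_alice:example.com")
    ]
    # only the first (IRC) service matches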
+    @defer.inlineCallbacks
     def _get_services_for_user(self, user_id):
-        services = self.store.get_app_services()
+        services = yield self.store.get_app_services()
         interested_list = [
             s for s in services if (
                 s.is_interested_in_user(user_id)
             )
         ]
-        return defer.succeed(interested_list)
+        defer.returnValue(interested_list)
def _get_services_for_3pn(self, protocol):
services = self.store.get_app_services()
interested_list = [
s for s in services if s.is_interested_in_protocol(protocol)
]
return defer.succeed(interested_list)
@defer.inlineCallbacks @defer.inlineCallbacks
def _is_unknown_user(self, user_id): def _is_unknown_user(self, user_id):
@@ -262,7 +177,7 @@ class ApplicationServicesHandler(object):
return return
         # user not found; could be the AS though, so check.
-        services = self.store.get_app_services()
+        services = yield self.store.get_app_services()
         service_list = [s for s in services if s.sender == user_id]
         defer.returnValue(len(service_list) == 0)

View File

@@ -20,6 +20,7 @@ from synapse.api.constants import LoginType
from synapse.types import UserID from synapse.types import UserID
from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.config.ldap import LDAPMode
from twisted.web.client import PartialDownloadError from twisted.web.client import PartialDownloadError
@@ -28,6 +29,12 @@ import bcrypt
import pymacaroons import pymacaroons
import simplejson import simplejson
try:
import ldap3
except ImportError:
ldap3 = None
import synapse.util.stringutils as stringutils import synapse.util.stringutils as stringutils
@@ -38,10 +45,6 @@ class AuthHandler(BaseHandler):
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
     def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer):
-        """
         super(AuthHandler, self).__init__(hs)
         self.checkers = {
             LoginType.PASSWORD: self._check_password_auth,
@@ -51,18 +54,25 @@ class AuthHandler(BaseHandler):
         }
         self.bcrypt_rounds = hs.config.bcrypt_rounds
         self.sessions = {}
+        self.INVALID_TOKEN_HTTP_STATUS = 401

-        account_handler = _AccountHandler(
-            hs, check_user_exists=self.check_user_exists
-        )
-
-        self.password_providers = [
-            module(config=config, account_handler=account_handler)
-            for module, config in hs.config.password_providers
-        ]
+        self.ldap_enabled = hs.config.ldap_enabled
+        if self.ldap_enabled:
+            if not ldap3:
+                raise RuntimeError(
+                    'Missing ldap3 library. This is required for LDAP Authentication.'
+                )
+            self.ldap_mode = hs.config.ldap_mode
+            self.ldap_uri = hs.config.ldap_uri
+            self.ldap_start_tls = hs.config.ldap_start_tls
+            self.ldap_base = hs.config.ldap_base
+            self.ldap_filter = hs.config.ldap_filter
+            self.ldap_attributes = hs.config.ldap_attributes
+            if self.ldap_mode == LDAPMode.SEARCH:
+                self.ldap_bind_dn = hs.config.ldap_bind_dn
+                self.ldap_bind_password = hs.config.ldap_bind_password

         self.hs = hs  # FIXME better possibility to access registrationHandler later?
-        self.device_handler = hs.get_device_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip): def check_auth(self, flows, clientdict, clientip):
@@ -133,30 +143,13 @@ class AuthHandler(BaseHandler):
         creds = session['creds']

         # check auth type currently being presented
-        errordict = {}
         if 'type' in authdict:
-            login_type = authdict['type']
-            if login_type not in self.checkers:
+            if authdict['type'] not in self.checkers:
                 raise LoginError(400, "", Codes.UNRECOGNIZED)
-            try:
-                result = yield self.checkers[login_type](authdict, clientip)
-                if result:
-                    creds[login_type] = result
-                    self._save_session(session)
-            except LoginError as e:
-                if login_type == LoginType.EMAIL_IDENTITY:
-                    # riot used to have a bug where it would request a new
-                    # validation token (thus sending a new email) each time it
-                    # got a 401 with a 'flows' field.
-                    # (https://github.com/vector-im/vector-web/issues/2447).
-                    #
-                    # Grandfather in the old behaviour for now to avoid
-                    # breaking old riot deployments.
-                    raise e
-
-                # this step failed. Merge the error dict into the response
-                # so that the client can have another go.
-                errordict = e.error_dict()
+            result = yield self.checkers[authdict['type']](authdict, clientip)
+            if result:
+                creds[authdict['type']] = result
+                self._save_session(session)
for f in flows: for f in flows:
if len(set(f) - set(creds.keys())) == 0: if len(set(f) - set(creds.keys())) == 0:
@@ -165,7 +158,6 @@ class AuthHandler(BaseHandler):
         ret = self._auth_dict_for_flows(flows, session)
         ret['completed'] = creds.keys()
-        ret.update(errordict)
         defer.returnValue((False, ret, clientdict, session['id']))
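A flow counts as completed once every stage it names is present in creds. A worked example of the set arithmetic used above, with made-up stage data:

    flows = [
        ["m.login.recaptcha", "m.login.email.identity"],
        ["m.login.password"],
    ]
    creds = {"m.login.password": "@alice:example.com"}

    # A flow is satisfied when no stage in it is missing from creds.
    completed = [f for f in flows if len(set(f) - set(creds.keys())) == 0]
    # completed == [["m.login.password"]], so this auth session has succeeded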
@defer.inlineCallbacks @defer.inlineCallbacks
@@ -238,6 +230,7 @@ class AuthHandler(BaseHandler):
sess = self._get_session_info(session_id) sess = self._get_session_info(session_id)
return sess.setdefault('serverdict', {}).get(key, default) return sess.setdefault('serverdict', {}).get(key, default)
@defer.inlineCallbacks
def _check_password_auth(self, authdict, _): def _check_password_auth(self, authdict, _):
if "user" not in authdict or "password" not in authdict: if "user" not in authdict or "password" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM) raise LoginError(400, "", Codes.MISSING_PARAM)
@@ -247,7 +240,11 @@ class AuthHandler(BaseHandler):
         if not user_id.startswith('@'):
             user_id = UserID.create(user_id, self.hs.hostname).to_string()

-        return self._check_password(user_id, password)
+        if not (yield self._check_password(user_id, password)):
+            logger.warn("Failed password login for user %s", user_id)
+            raise LoginError(403, "", errcode=Codes.FORBIDDEN)
+        defer.returnValue(user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_recaptcha(self, authdict, clientip): def _check_recaptcha(self, authdict, clientip):
@@ -283,17 +280,8 @@ class AuthHandler(BaseHandler):
             data = pde.response
         resp_body = simplejson.loads(data)

-        if 'success' in resp_body:
-            # Note that we do NOT check the hostname here: we explicitly
-            # intend the CAPTCHA to be presented by whatever client the
-            # user is using, we just care that they have completed a CAPTCHA.
-            logger.info(
-                "%s reCAPTCHA from hostname %s",
-                "Successful" if resp_body['success'] else "Failed",
-                resp_body.get('hostname')
-            )
-            if resp_body['success']:
-                defer.returnValue(True)
+        if 'success' in resp_body and resp_body['success']:
+            defer.returnValue(True)
         raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
@defer.inlineCallbacks @defer.inlineCallbacks
@@ -360,189 +348,301 @@ class AuthHandler(BaseHandler):
return self.sessions[session_id] return self.sessions[session_id]
def validate_password_login(self, user_id, password): @defer.inlineCallbacks
def login_with_password(self, user_id, password):
""" """
Authenticates the user with their username and password. Authenticates the user with their username and password.
Used only by the v1 login API. Used only by the v1 login API.
Args: Args:
user_id (str): complete @user:id user_id (str): User ID
password (str): Password password (str): Password
Returns:
defer.Deferred: (str) canonical user id
Raises:
StoreError if there was a problem accessing the database
LoginError if there was an authentication problem.
"""
return self._check_password(user_id, password)
@defer.inlineCallbacks
def get_login_tuple_for_user_id(self, user_id, device_id=None,
initial_display_name=None):
"""
Gets login tuple for the user with the given user ID.
Creates a new access/refresh token for the user.
The user is assumed to have been authenticated by some other
mechanism (e.g. CAS), and the user_id converted to the canonical case.
The device will be recorded in the table if it is not there already.
Args:
user_id (str): canonical User ID
device_id (str|None): the device ID to associate with the tokens.
None to leave the tokens unassociated with a device (deprecated:
we should always have a device ID)
initial_display_name (str): display name to associate with the
device if it needs re-registering
Returns: Returns:
A tuple of: A tuple of:
The user's ID.
The access token for the user's session. The access token for the user's session.
The refresh token for the user's session. The refresh token for the user's session.
Raises: Raises:
StoreError if there was a problem storing the token. StoreError if there was a problem storing the token.
LoginError if there was an authentication problem. LoginError if there was an authentication problem.
""" """
logger.info("Logging in user %s on device %s", user_id, device_id)
access_token = yield self.issue_access_token(user_id, device_id)
refresh_token = yield self.issue_refresh_token(user_id, device_id)
# the device *should* have been registered before we got here; however, if not (yield self._check_password(user_id, password)):
# it's possible we raced against a DELETE operation. The thing we logger.warn("Failed password login for user %s", user_id)
# really don't want is active access_tokens without a record of the raise LoginError(403, "", errcode=Codes.FORBIDDEN)
# device, so we double-check it here.
if device_id is not None:
yield self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
defer.returnValue((access_token, refresh_token)) logger.info("Logging in user %s", user_id)
access_token = yield self.issue_access_token(user_id)
refresh_token = yield self.issue_refresh_token(user_id)
defer.returnValue((user_id, access_token, refresh_token))
@defer.inlineCallbacks @defer.inlineCallbacks
def check_user_exists(self, user_id): def get_login_tuple_for_user_id(self, user_id):
""" """
Checks to see if a user with the given id exists. Will check case Gets login tuple for the user with the given user ID.
insensitively, but return None if there are multiple inexact matches. The user is assumed to have been authenticated by some other
machanism (e.g. CAS)
Args: Args:
(str) user_id: complete @user:id user_id (str): User ID
Returns: Returns:
defer.Deferred: (str) canonical_user_id, or None if zero or A tuple of:
multiple matches The user's ID.
The access token for the user's session.
The refresh token for the user's session.
Raises:
StoreError if there was a problem storing the token.
LoginError if there was an authentication problem.
""" """
res = yield self._find_user_id_and_pwd_hash(user_id) user_id, ignored = yield self._find_user_id_and_pwd_hash(user_id)
if res is not None:
defer.returnValue(res[0]) logger.info("Logging in user %s", user_id)
defer.returnValue(None) access_token = yield self.issue_access_token(user_id)
refresh_token = yield self.issue_refresh_token(user_id)
defer.returnValue((user_id, access_token, refresh_token))
@defer.inlineCallbacks
def does_user_exist(self, user_id):
try:
yield self._find_user_id_and_pwd_hash(user_id)
defer.returnValue(True)
except LoginError:
defer.returnValue(False)
@defer.inlineCallbacks @defer.inlineCallbacks
def _find_user_id_and_pwd_hash(self, user_id): def _find_user_id_and_pwd_hash(self, user_id):
"""Checks to see if a user with the given id exists. Will check case """Checks to see if a user with the given id exists. Will check case
insensitively, but will return None if there are multiple inexact insensitively, but will throw if there are multiple inexact matches.
matches.
Returns: Returns:
tuple: A 2-tuple of `(canonical_user_id, password_hash)` tuple: A 2-tuple of `(canonical_user_id, password_hash)`
None: if there is not exactly one match
""" """
user_infos = yield self.store.get_users_by_id_case_insensitive(user_id) user_infos = yield self.store.get_users_by_id_case_insensitive(user_id)
result = None
if not user_infos: if not user_infos:
logger.warn("Attempted to login as %s but they do not exist", user_id) logger.warn("Attempted to login as %s but they do not exist", user_id)
elif len(user_infos) == 1: raise LoginError(403, "", errcode=Codes.FORBIDDEN)
# a single match (possibly not exact)
result = user_infos.popitem() if len(user_infos) > 1:
elif user_id in user_infos: if user_id not in user_infos:
# multiple matches, but one is exact logger.warn(
result = (user_id, user_infos[user_id]) "Attempted to login as %s but it matches more than one user "
"inexactly: %r",
user_id, user_infos.keys()
)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
defer.returnValue((user_id, user_infos[user_id]))
else: else:
# multiple matches, none of them exact defer.returnValue(user_infos.popitem())
logger.warn(
"Attempted to login as %s but it matches more than one user "
"inexactly: %r",
user_id, user_infos.keys()
)
defer.returnValue(result)
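In other words: a single match wins even if inexact, several matches require an exact one, and anything else yields None. A worked example against a fabricated store result:

    user_infos = {"@Alice:example.com": "hash1", "@alice:example.com": "hash2"}
    user_id = "@alice:example.com"

    if len(user_infos) == 1:
        result = user_infos.popitem()             # single, possibly inexact
    elif user_id in user_infos:
        result = (user_id, user_infos[user_id])   # several matches, one exact
    else:
        result = None                             # several matches, none exact

    # result == ("@alice:example.com", "hash2")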
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_password(self, user_id, password): def _check_password(self, user_id, password):
"""Authenticate a user against the LDAP and local databases.
user_id is checked case insensitively against the local database, but
will throw if there are multiple inexact matches.
Args:
user_id (str): complete @user:id
Returns:
(str) the canonical_user_id
Raises:
LoginError if login fails
""" """
for provider in self.password_providers: Returns:
is_valid = yield provider.check_password(user_id, password) True if the user_id successfully authenticated
if is_valid: """
defer.returnValue(user_id) valid_ldap = yield self._check_ldap_password(user_id, password)
if valid_ldap:
defer.returnValue(True)
canonical_user_id = yield self._check_local_password(user_id, password) valid_local_password = yield self._check_local_password(user_id, password)
if valid_local_password:
defer.returnValue(True)
if canonical_user_id: defer.returnValue(False)
defer.returnValue(canonical_user_id)
# unknown username or invalid password. We raise a 403 here, but note
# that if we're doing user-interactive login, it turns all LoginErrors
# into a 401 anyway.
raise LoginError(
403, "Invalid password",
errcode=Codes.FORBIDDEN
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_local_password(self, user_id, password): def _check_local_password(self, user_id, password):
"""Authenticate a user against the local password database. try:
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
user_id is checked case insensitively, but will return None if there are defer.returnValue(self.validate_hash(password, password_hash))
multiple inexact matches. except LoginError:
defer.returnValue(False)
Args:
user_id (str): complete @user:id
Returns:
(str) the canonical_user_id, or None if unknown user / bad password
"""
lookupres = yield self._find_user_id_and_pwd_hash(user_id)
if not lookupres:
defer.returnValue(None)
(user_id, password_hash) = lookupres
result = self.validate_hash(password, password_hash)
if not result:
logger.warn("Failed password login for user %s", user_id)
defer.returnValue(None)
defer.returnValue(user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def issue_access_token(self, user_id, device_id=None): def _check_ldap_password(self, user_id, password):
""" Attempt to authenticate a user against an LDAP Server
and register an account if none exists.
Returns:
True if authentication against LDAP was successful
"""
if not ldap3 or not self.ldap_enabled:
defer.returnValue(False)
if self.ldap_mode not in LDAPMode.LIST:
raise RuntimeError(
'Invalid ldap mode specified: {mode}'.format(
mode=self.ldap_mode
)
)
try:
server = ldap3.Server(self.ldap_uri)
logger.debug(
"Attempting ldap connection with %s",
self.ldap_uri
)
localpart = UserID.from_string(user_id).localpart
if self.ldap_mode == LDAPMode.SIMPLE:
# bind with the the local users ldap credentials
bind_dn = "{prop}={value},{base}".format(
prop=self.ldap_attributes['uid'],
value=localpart,
base=self.ldap_base
)
conn = ldap3.Connection(server, bind_dn, password)
logger.debug(
"Established ldap connection in simple mode: %s",
conn
)
if self.ldap_start_tls:
conn.start_tls()
logger.debug(
"Upgraded ldap connection in simple mode through StartTLS: %s",
conn
)
conn.bind()
elif self.ldap_mode == LDAPMode.SEARCH:
# connect with preconfigured credentials and search for local user
conn = ldap3.Connection(
server,
self.ldap_bind_dn,
self.ldap_bind_password
)
logger.debug(
"Established ldap connection in search mode: %s",
conn
)
if self.ldap_start_tls:
conn.start_tls()
logger.debug(
"Upgraded ldap connection in search mode through StartTLS: %s",
conn
)
conn.bind()
# find matching dn
query = "({prop}={value})".format(
prop=self.ldap_attributes['uid'],
value=localpart
)
if self.ldap_filter:
query = "(&{query}{filter})".format(
query=query,
filter=self.ldap_filter
)
logger.debug("ldap search filter: %s", query)
result = conn.search(self.ldap_base, query)
if result and len(conn.response) == 1:
# found exactly one result
user_dn = conn.response[0]['dn']
logger.debug('ldap search found dn: %s', user_dn)
# unbind and reconnect, rebind with found dn
conn.unbind()
conn = ldap3.Connection(
server,
user_dn,
password,
auto_bind=True
)
else:
# found 0 or > 1 results, abort!
logger.warn(
"ldap search returned unexpected (%d!=1) amount of results",
len(conn.response)
)
defer.returnValue(False)
logger.info(
"User authenticated against ldap server: %s",
conn
)
# check for existing account, if none exists, create one
if not (yield self.does_user_exist(user_id)):
# query user metadata for account creation
query = "({prop}={value})".format(
prop=self.ldap_attributes['uid'],
value=localpart
)
if self.ldap_mode == LDAPMode.SEARCH and self.ldap_filter:
query = "(&{filter}{user_filter})".format(
filter=query,
user_filter=self.ldap_filter
)
logger.debug("ldap registration filter: %s", query)
result = conn.search(
search_base=self.ldap_base,
search_filter=query,
attributes=[
self.ldap_attributes['name'],
self.ldap_attributes['mail']
]
)
if len(conn.response) == 1:
attrs = conn.response[0]['attributes']
mail = attrs[self.ldap_attributes['mail']][0]
name = attrs[self.ldap_attributes['name']][0]
# create account
registration_handler = self.hs.get_handlers().registration_handler
user_id, access_token = (
yield registration_handler.register(localpart=localpart)
)
# TODO: bind email, set displayname with data from ldap directory
logger.info(
"ldap registration successful: %d: %s (%s, %)",
user_id,
localpart,
name,
mail
)
else:
logger.warn(
"ldap registration failed: unexpected (%d!=1) amount of results",
len(result)
)
defer.returnValue(False)
defer.returnValue(True)
except ldap3.core.exceptions.LDAPException as e:
logger.warn("Error during ldap authentication: %s", e)
defer.returnValue(False)
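In simple mode the connection binds directly as the user, with the bind DN assembled from the configured uid attribute and search base; search mode instead binds with service credentials and looks the DN up. With assumed example settings, the simple-mode DN comes out as:

    ldap_attributes = {"uid": "uid"}            # example config values only
    ldap_base = "ou=users,dc=example,dc=com"
    localpart = "alice"

    # Mirrors the bind_dn construction in _check_ldap_password above.
    bind_dn = "{prop}={value},{base}".format(
        prop=ldap_attributes["uid"],
        value=localpart,
        base=ldap_base,
    )
    # bind_dn == "uid=alice,ou=users,dc=example,dc=com"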
@defer.inlineCallbacks
def issue_access_token(self, user_id):
         access_token = self.generate_access_token(user_id)
-        yield self.store.add_access_token_to_user(user_id, access_token,
-                                                  device_id)
+        yield self.store.add_access_token_to_user(user_id, access_token)
         defer.returnValue(access_token)

     @defer.inlineCallbacks
-    def issue_refresh_token(self, user_id, device_id=None):
+    def issue_refresh_token(self, user_id):
         refresh_token = self.generate_refresh_token(user_id)
-        yield self.store.add_refresh_token_to_user(user_id, refresh_token,
-                                                   device_id)
+        yield self.store.add_refresh_token_to_user(user_id, refresh_token)
         defer.returnValue(refresh_token)

-    def generate_access_token(self, user_id, extra_caveats=None,
-                              duration_in_ms=(60 * 60 * 1000)):
+    def generate_access_token(self, user_id, extra_caveats=None):
         extra_caveats = extra_caveats or []
         macaroon = self._generate_base_macaroon(user_id)
         macaroon.add_first_party_caveat("type = access")
         now = self.hs.get_clock().time_msec()
-        expiry = now + duration_in_ms
+        expiry = now + (60 * 60 * 1000)
         macaroon.add_first_party_caveat("time < %d" % (expiry,))
         for caveat in extra_caveats:
             macaroon.add_first_party_caveat(caveat)
@@ -572,14 +672,13 @@ class AuthHandler(BaseHandler):
return macaroon.serialize() return macaroon.serialize()
     def validate_short_term_login_token_and_get_user_id(self, login_token):
-        auth_api = self.hs.get_auth()
         try:
             macaroon = pymacaroons.Macaroon.deserialize(login_token)
-            user_id = auth_api.get_user_id_from_macaroon(macaroon)
-            auth_api.validate_macaroon(macaroon, "login", True, user_id)
-            return user_id
-        except Exception:
-            raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
+            auth_api = self.hs.get_auth()
+            auth_api.validate_macaroon(macaroon, "login", True)
+            return self.get_user_from_macaroon(macaroon)
+        except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
+            raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN)
def _generate_base_macaroon(self, user_id): def _generate_base_macaroon(self, user_id):
macaroon = pymacaroons.Macaroon( macaroon = pymacaroons.Macaroon(
@@ -590,11 +689,21 @@ class AuthHandler(BaseHandler):
macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon return macaroon
def get_user_from_macaroon(self, macaroon):
user_prefix = "user_id = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(user_prefix):
return caveat.caveat_id[len(user_prefix):]
raise AuthError(
self.INVALID_TOKEN_HTTP_STATUS, "No user_id found in token",
errcode=Codes.UNKNOWN_TOKEN
)
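For orientation, a short sketch of how such a login token can be built and picked apart with pymacaroons; the location, identifier and key below are placeholders, the real values come from the homeserver's macaroon configuration:

    import pymacaroons

    macaroon = pymacaroons.Macaroon(
        location="example.com",   # placeholder server name
        identifier="key",
        key="macaroon_secret",    # placeholder secret
    )
    macaroon.add_first_party_caveat("gen = 1")
    macaroon.add_first_party_caveat("user_id = @alice:example.com")
    macaroon.add_first_party_caveat("type = access")

    token = macaroon.serialize()

    # Recover the user_id caveat, mirroring get_user_from_macaroon above.
    m = pymacaroons.Macaroon.deserialize(token)
    user_id = next(
        c.caveat_id[len("user_id = "):]
        for c in m.caveats
        if c.caveat_id.startswith("user_id = ")
    )
    assert user_id == "@alice:example.com"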
@defer.inlineCallbacks @defer.inlineCallbacks
     def set_password(self, user_id, newpassword, requester=None):
         password_hash = self.hash(newpassword)

-        except_access_token_id = requester.access_token_id if requester else None
+        except_access_token_ids = [requester.access_token_id] if requester else []

         try:
             yield self.store.user_set_password_hash(user_id, password_hash)
@@ -603,26 +712,14 @@ class AuthHandler(BaseHandler):
                 raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
             raise e

         yield self.store.user_delete_access_tokens(
-            user_id, except_access_token_id
+            user_id, except_access_token_ids
         )
         yield self.hs.get_pusherpool().remove_pushers_by_user(
-            user_id, except_access_token_id
+            user_id, except_access_token_ids
         )
@defer.inlineCallbacks @defer.inlineCallbacks
def add_threepid(self, user_id, medium, address, validated_at): def add_threepid(self, user_id, medium, address, validated_at):
        # 'Canonicalise' email addresses down to lower case.
        # We're now moving towards the Home Server being the entity that
        # is responsible for validating threepids used for resetting passwords
        # on accounts, so in future Synapse will gain knowledge of specific
        # types (mediums) of threepid. For now, we still use the existing
        # infrastructure, but this is the start of synapse gaining knowledge
        # of specific types of threepid (and fixes the fact that checking
        # for the presence of an email address during password reset was
        # case sensitive).
if medium == 'email':
address = address.lower()
yield self.store.user_add_threepid( yield self.store.user_add_threepid(
user_id, medium, address, validated_at, user_id, medium, address, validated_at,
self.hs.get_clock().time_msec() self.hs.get_clock().time_msec()
@@ -653,7 +750,7 @@ class AuthHandler(BaseHandler):
         Returns:
             Hashed password (str).
         """
-        return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
+        return bcrypt.hashpw(password + self.hs.config.password_pepper,
                              bcrypt.gensalt(self.bcrypt_rounds))
def validate_hash(self, password, stored_hash): def validate_hash(self, password, stored_hash):
@@ -671,30 +768,3 @@ class AuthHandler(BaseHandler):
                                  stored_hash.encode('utf-8')) == stored_hash
         else:
             return False
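Both hash and validate_hash rely on bcrypt's property that hashing a password with an existing hash as the salt reproduces that hash. A minimal sketch under Python 2 (where bcrypt accepts str); the password and pepper values are examples only:

    import bcrypt

    password = "correct horse battery staple"  # example credentials only
    pepper = "myPepper"                        # hs.config.password_pepper in practice

    stored_hash = bcrypt.hashpw(password + pepper, bcrypt.gensalt(12))

    # Verification re-hashes with the stored hash acting as the salt.
    assert bcrypt.hashpw(password + pepper, stored_hash) == stored_hash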
class _AccountHandler(object):
"""A proxy object that gets passed to password auth providers so they
can register new users etc if necessary.
"""
def __init__(self, hs, check_user_exists):
self.hs = hs
self._check_user_exists = check_user_exists
def check_user_exists(self, user_id):
"""Check if user exissts.
Returns:
Deferred(bool)
"""
return self._check_user_exists(user_id)
def register(self, localpart):
"""Registers a new user with given localpart
Returns:
Deferred: a 2-tuple of (user_id, access_token)
"""
reg = self.hs.get_handlers().registration_handler
return reg.register(localpart=localpart)

View File

@@ -1,181 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api import errors
from synapse.util import stringutils
from twisted.internet import defer
from ._base import BaseHandler
import logging
logger = logging.getLogger(__name__)
class DeviceHandler(BaseHandler):
def __init__(self, hs):
super(DeviceHandler, self).__init__(hs)
@defer.inlineCallbacks
def check_device_registered(self, user_id, device_id,
initial_device_display_name=None):
"""
If the given device has not been registered, register it with the
supplied display name.
If no device_id is supplied, we make one up.
Args:
user_id (str): @user:id
device_id (str | None): device id supplied by client
initial_device_display_name (str | None): device display name from
client
Returns:
str: device id (generated if none was supplied)
"""
if device_id is not None:
yield self.store.store_device(
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
ignore_if_known=True,
)
defer.returnValue(device_id)
# if the device id is not specified, we'll autogen one, but loop a few
# times in case of a clash.
attempts = 0
while attempts < 5:
try:
device_id = stringutils.random_string(10).upper()
yield self.store.store_device(
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
ignore_if_known=False,
)
defer.returnValue(device_id)
except errors.StoreError:
attempts += 1
raise errors.StoreError(500, "Couldn't generate a device ID.")
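A rough usage sketch, assuming a wired-up HomeServer instance exposing the get_device_handler() accessor referenced elsewhere in this changeset; the user id and display name are illustrative:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def register_device(hs):
        # `hs` is assumed to be a configured synapse.server.HomeServer.
        handler = hs.get_device_handler()
        device_id = yield handler.check_device_registered(
            user_id="@alice:example.com",
            device_id=None,  # None => a random 10-character id is generated
            initial_device_display_name="Alice's phone",
        )
        defer.returnValue(device_id)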
@defer.inlineCallbacks
def get_devices_by_user(self, user_id):
"""
Retrieve the given user's devices
Args:
user_id (str):
Returns:
defer.Deferred: list[dict[str, X]]: info on each device
"""
device_map = yield self.store.get_devices_by_user(user_id)
ips = yield self.store.get_last_client_ip_by_device(
devices=((user_id, device_id) for device_id in device_map.keys())
)
devices = device_map.values()
for device in devices:
_update_device_from_client_ips(device, ips)
defer.returnValue(devices)
@defer.inlineCallbacks
def get_device(self, user_id, device_id):
""" Retrieve the given device
Args:
user_id (str):
device_id (str):
Returns:
defer.Deferred: dict[str, X]: info on the device
Raises:
errors.NotFoundError: if the device was not found
"""
try:
device = yield self.store.get_device(user_id, device_id)
except errors.StoreError:
raise errors.NotFoundError
ips = yield self.store.get_last_client_ip_by_device(
devices=((user_id, device_id),)
)
_update_device_from_client_ips(device, ips)
defer.returnValue(device)
@defer.inlineCallbacks
def delete_device(self, user_id, device_id):
""" Delete the given device
Args:
user_id (str):
device_id (str):
Returns:
defer.Deferred:
"""
try:
yield self.store.delete_device(user_id, device_id)
except errors.StoreError as e:
if e.code == 404:
# no match
pass
else:
raise
yield self.store.user_delete_access_tokens(
user_id, device_id=device_id,
delete_refresh_tokens=True,
)
yield self.store.delete_e2e_keys_by_device(
user_id=user_id, device_id=device_id
)
@defer.inlineCallbacks
def update_device(self, user_id, device_id, content):
""" Update the given device
Args:
user_id (str):
device_id (str):
content (dict): body of update request
Returns:
defer.Deferred:
"""
try:
yield self.store.update_device(
user_id,
device_id,
new_display_name=content.get("display_name")
)
except errors.StoreError as e:
if e.code == 404:
raise errors.NotFoundError()
else:
raise
def _update_device_from_client_ips(device, client_ips):
ip = client_ips.get((device["user_id"], device["device_id"]), {})
device.update({
"last_seen_ts": ip.get("last_seen"),
"last_seen_ip": ip.get("ip"),
})

View File

@@ -1,117 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.types import get_domain_from_id
from synapse.util.stringutils import random_string
logger = logging.getLogger(__name__)
class DeviceMessageHandler(object):
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
self.federation = hs.get_replication_layer()
self.federation.register_edu_handler(
"m.direct_to_device", self.on_direct_to_device_edu
)
@defer.inlineCallbacks
def on_direct_to_device_edu(self, origin, content):
local_messages = {}
sender_user_id = content["sender"]
if origin != get_domain_from_id(sender_user_id):
logger.warn(
"Dropping device message from %r with spoofed sender %r",
origin, sender_user_id
)
message_type = content["type"]
message_id = content["message_id"]
for user_id, by_device in content["messages"].items():
messages_by_device = {
device_id: {
"content": message_content,
"type": message_type,
"sender": sender_user_id,
}
for device_id, message_content in by_device.items()
}
if messages_by_device:
local_messages[user_id] = messages_by_device
stream_id = yield self.store.add_messages_from_remote_to_device_inbox(
origin, message_id, local_messages
)
self.notifier.on_new_event(
"to_device_key", stream_id, users=local_messages.keys()
)
@defer.inlineCallbacks
def send_device_message(self, sender_user_id, message_type, messages):
local_messages = {}
remote_messages = {}
for user_id, by_device in messages.items():
if self.is_mine_id(user_id):
messages_by_device = {
device_id: {
"content": message_content,
"type": message_type,
"sender": sender_user_id,
}
for device_id, message_content in by_device.items()
}
if messages_by_device:
local_messages[user_id] = messages_by_device
else:
destination = get_domain_from_id(user_id)
remote_messages.setdefault(destination, {})[user_id] = by_device
message_id = random_string(16)
remote_edu_contents = {}
for destination, messages in remote_messages.items():
remote_edu_contents[destination] = {
"messages": messages,
"sender": sender_user_id,
"type": message_type,
"message_id": message_id,
}
stream_id = yield self.store.add_messages_to_device_inbox(
local_messages, remote_edu_contents
)
self.notifier.on_new_event(
"to_device_key", stream_id, users=local_messages.keys()
)
for destination in remote_messages.keys():
# Enqueue a new federation transaction to send the new
# device messages to each remote destination.
self.federation.send_device_messages(destination)
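For reference, the shape of the messages argument that the loop above consumes; user and device ids are illustrative:

    messages = {
        "@alice:example.com": {    # local user: written straight to the inbox
            "DEVICEONE": {"example_key": "example_value"},
        },
        "@bob:remote.org": {       # remote user: batched into a federation EDU
            "DEVICETWO": {"example_key": "example_value"},
        },
    }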

View File

@@ -19,7 +19,7 @@ from ._base import BaseHandler
from synapse.api.errors import SynapseError, Codes, CodeMessageException, AuthError from synapse.api.errors import SynapseError, Codes, CodeMessageException, AuthError
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
-from synapse.types import RoomAlias, UserID, get_domain_from_id
+from synapse.types import RoomAlias, UserID
import logging import logging
import string import string
@@ -55,8 +55,7 @@ class DirectoryHandler(BaseHandler):
# TODO(erikj): Add transactions. # TODO(erikj): Add transactions.
# TODO(erikj): Check if there is a current association. # TODO(erikj): Check if there is a current association.
         if not servers:
-            users = yield self.state.get_current_user_in_room(room_id)
-            servers = set(get_domain_from_id(u) for u in users)
+            servers = yield self.store.get_joined_hosts_for_room(room_id)

         if not servers:
             raise SynapseError(400, "Failed to get server list")
@@ -194,8 +193,7 @@ class DirectoryHandler(BaseHandler):
Codes.NOT_FOUND Codes.NOT_FOUND
) )
-        users = yield self.state.get_current_user_in_room(room_id)
-        extra_servers = set(get_domain_from_id(u) for u in users)
+        extra_servers = yield self.store.get_joined_hosts_for_room(room_id)
         servers = set(extra_servers) | set(servers)
# If this server is in the list of servers, return it first. # If this server is in the list of servers, return it first.
@@ -288,12 +286,13 @@ class DirectoryHandler(BaseHandler):
result = yield as_handler.query_room_alias_exists(room_alias) result = yield as_handler.query_room_alias_exists(room_alias)
defer.returnValue(result) defer.returnValue(result)
+    @defer.inlineCallbacks
     def can_modify_alias(self, alias, user_id=None):
         # Any application service "interested" in an alias they are regexing on
         # can modify the alias.
         # Users can only modify the alias if ALL the interested services have
         # non-exclusive locks on the alias (or there are no interested services)
-        services = self.store.get_app_services()
+        services = yield self.store.get_app_services()
         interested_services = [
             s for s in services if s.is_interested_in_alias(alias.to_string())
         ]
@@ -301,12 +300,14 @@ class DirectoryHandler(BaseHandler):
         for service in interested_services:
             if user_id == service.sender:
                 # this user IS the app service so they can do whatever they like
-                return defer.succeed(True)
+                defer.returnValue(True)
+                return
             elif service.is_exclusive_alias(alias.to_string()):
                 # another service has an exclusive lock on this alias.
-                return defer.succeed(False)
+                defer.returnValue(False)
+                return
         # either no interested services, or no service with an exclusive lock
-        return defer.succeed(True)
+        defer.returnValue(True)
@defer.inlineCallbacks @defer.inlineCallbacks
def _user_can_delete_alias(self, alias, user_id): def _user_can_delete_alias(self, alias, user_id):

View File

@@ -1,279 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ujson as json
import logging
from canonicaljson import encode_canonical_json
from twisted.internet import defer
from synapse.api.errors import SynapseError, CodeMessageException
from synapse.types import get_domain_from_id
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
logger = logging.getLogger(__name__)
class E2eKeysHandler(object):
def __init__(self, hs):
self.store = hs.get_datastore()
self.federation = hs.get_replication_layer()
self.device_handler = hs.get_device_handler()
self.is_mine_id = hs.is_mine_id
self.clock = hs.get_clock()
# doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the
# "query handler" interface.
self.federation.register_query_handler(
"client_keys", self.on_federation_query_client_keys
)
@defer.inlineCallbacks
def query_devices(self, query_body, timeout):
""" Handle a device key query from a client
{
"device_keys": {
"<user_id>": ["<device_id>"]
}
}
->
{
"device_keys": {
"<user_id>": {
"<device_id>": {
...
}
}
}
}
"""
device_keys_query = query_body.get("device_keys", {})
# separate users by domain.
# make a map from domain to user_id to device_ids
local_query = {}
remote_queries = {}
for user_id, device_ids in device_keys_query.items():
if self.is_mine_id(user_id):
local_query[user_id] = device_ids
else:
domain = get_domain_from_id(user_id)
remote_queries.setdefault(domain, {})[user_id] = device_ids
# do the queries
failures = {}
results = {}
if local_query:
local_result = yield self.query_local_devices(local_query)
for user_id, keys in local_result.items():
if user_id in local_query:
results[user_id] = keys
@defer.inlineCallbacks
def do_remote_query(destination):
destination_query = remote_queries[destination]
try:
limiter = yield get_retry_limiter(
destination, self.clock, self.store
)
with limiter:
remote_result = yield self.federation.query_client_keys(
destination,
{"device_keys": destination_query},
timeout=timeout
)
for user_id, keys in remote_result["device_keys"].items():
if user_id in destination_query:
results[user_id] = keys
except CodeMessageException as e:
failures[destination] = {
"status": e.code, "message": e.message
}
except NotRetryingDestination as e:
failures[destination] = {
"status": 503, "message": "Not ready for retry",
}
yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(do_remote_query)(destination)
for destination in remote_queries
]))
defer.returnValue({
"device_keys": results, "failures": failures,
})
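The local/remote split at the top of query_devices can be illustrated in isolation; the server name and query contents below are assumptions:

    from synapse.types import get_domain_from_id

    server_name = "example.com"   # assumed local server name
    query = {"@alice:example.com": [], "@bob:remote.org": ["DEVICE1"]}

    local_query, remote_queries = {}, {}
    for user_id, device_ids in query.items():
        if get_domain_from_id(user_id) == server_name:
            local_query[user_id] = device_ids
        else:
            # Group remote users by their home server's domain.
            remote_queries.setdefault(
                get_domain_from_id(user_id), {}
            )[user_id] = device_ids

    # local_query    == {"@alice:example.com": []}
    # remote_queries == {"remote.org": {"@bob:remote.org": ["DEVICE1"]}}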
@defer.inlineCallbacks
def query_local_devices(self, query):
"""Get E2E device keys for local users
Args:
query (dict[string, list[string]|None): map from user_id to a list
of devices to query (None for all devices)
Returns:
defer.Deferred: (resolves to dict[string, dict[string, dict]]):
map from user_id -> device_id -> device details
"""
local_query = []
result_dict = {}
for user_id, device_ids in query.items():
if not self.is_mine_id(user_id):
logger.warning("Request for keys for non-local user %s",
user_id)
raise SynapseError(400, "Not a user here")
if not device_ids:
local_query.append((user_id, None))
else:
for device_id in device_ids:
local_query.append((user_id, device_id))
# make sure that each queried user appears in the result dict
result_dict[user_id] = {}
results = yield self.store.get_e2e_device_keys(local_query)
# Build the result structure, un-jsonify the results, and add the
# "unsigned" section
for user_id, device_keys in results.items():
for device_id, device_info in device_keys.items():
r = json.loads(device_info["key_json"])
r["unsigned"] = {}
display_name = device_info["device_display_name"]
if display_name is not None:
r["unsigned"]["device_display_name"] = display_name
result_dict[user_id][device_id] = r
defer.returnValue(result_dict)
@defer.inlineCallbacks
def on_federation_query_client_keys(self, query_body):
""" Handle a device key query from a federated server
"""
device_keys_query = query_body.get("device_keys", {})
res = yield self.query_local_devices(device_keys_query)
defer.returnValue({"device_keys": res})
@defer.inlineCallbacks
def claim_one_time_keys(self, query, timeout):
local_query = []
remote_queries = {}
for user_id, device_keys in query.get("one_time_keys", {}).items():
if self.is_mine_id(user_id):
for device_id, algorithm in device_keys.items():
local_query.append((user_id, device_id, algorithm))
else:
domain = get_domain_from_id(user_id)
remote_queries.setdefault(domain, {})[user_id] = device_keys
results = yield self.store.claim_e2e_one_time_keys(local_query)
json_result = {}
failures = {}
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
for key_id, json_bytes in keys.items():
json_result.setdefault(user_id, {})[device_id] = {
key_id: json.loads(json_bytes)
}
@defer.inlineCallbacks
def claim_client_keys(destination):
device_keys = remote_queries[destination]
try:
limiter = yield get_retry_limiter(
destination, self.clock, self.store
)
with limiter:
remote_result = yield self.federation.claim_client_keys(
destination,
{"one_time_keys": device_keys},
timeout=timeout
)
for user_id, keys in remote_result["one_time_keys"].items():
if user_id in device_keys:
json_result[user_id] = keys
except CodeMessageException as e:
failures[destination] = {
"status": e.code, "message": e.message
}
except NotRetryingDestination as e:
failures[destination] = {
"status": 503, "message": "Not ready for retry",
}
yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(claim_client_keys)(destination)
for destination in remote_queries
]))
defer.returnValue({
"one_time_keys": json_result,
"failures": failures
})
@defer.inlineCallbacks
def upload_keys_for_user(self, user_id, device_id, keys):
time_now = self.clock.time_msec()
# TODO: Validate the JSON to make sure it has the right keys.
device_keys = keys.get("device_keys", None)
if device_keys:
logger.info(
"Updating device_keys for device %r for user %s at %d",
device_id, user_id, time_now
)
# TODO: Sign the JSON with the server key
yield self.store.set_e2e_device_keys(
user_id, device_id, time_now,
encode_canonical_json(device_keys)
)
one_time_keys = keys.get("one_time_keys", None)
if one_time_keys:
logger.info(
"Adding %d one_time_keys for device %r for user %r at %d",
len(one_time_keys), device_id, user_id, time_now
)
key_list = []
for key_id, key_json in one_time_keys.items():
algorithm, key_id = key_id.split(":")
key_list.append((
algorithm, key_id, encode_canonical_json(key_json)
))
yield self.store.add_e2e_one_time_keys(
user_id, device_id, time_now, key_list
)
# the device should have been registered already, but it may have been
# deleted due to a race with a DELETE request. Or we may be using an
# old access_token without an associated device_id. Either way, we
# need to double-check the device is registered to avoid ending up with
# keys without a corresponding device.
self.device_handler.check_device_registered(user_id, device_id)
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
defer.returnValue({"one_time_key_counts": result})

View File

@@ -47,7 +47,6 @@ class EventStreamHandler(BaseHandler):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.state = hs.get_state_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
@@ -91,7 +90,7 @@ class EventStreamHandler(BaseHandler):
             # Send down presence.
             if event.state_key == auth_user_id:
                 # Send down presence for everyone in the room.
-                users = yield self.state.get_current_user_in_room(event.room_id)
+                users = yield self.store.get_users_in_room(event.room_id)
                 states = yield presence_handler.get_states(
                     users,
                     as_event=True,

View File

@@ -26,10 +26,7 @@ from synapse.api.errors import (
from synapse.api.constants import EventTypes, Membership, RejectedReason from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
-from synapse.util.logcontext import (
-    PreserveLoggingContext, preserve_fn, preserve_context_over_deferred
-)
-from synapse.util.metrics import measure_func
+from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
 from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.frozenutils import unfreeze from synapse.util.frozenutils import unfreeze
@@ -101,9 +98,6 @@ class FederationHandler(BaseHandler):
def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None): def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None):
""" Called by the ReplicationLayer when we have a new pdu. We need to """ Called by the ReplicationLayer when we have a new pdu. We need to
do auth checks and put it through the StateHandler. do auth checks and put it through the StateHandler.
auth_chain and state are None if we already have the necessary state
and prev_events in the db
""" """
event = pdu event = pdu
@@ -121,25 +115,16 @@ class FederationHandler(BaseHandler):
             # FIXME (erikj): Awful hack to make the case where we are not currently
             # in the room work
-            # If state and auth_chain are None, then we don't need to do this check
-            # as we already know we have enough state in the DB to handle this
-            # event.
-            if state and auth_chain and not event.internal_metadata.is_outlier():
-                is_in_room = yield self.auth.check_host_in_room(
-                    event.room_id,
-                    self.server_name
-                )
-            else:
-                is_in_room = True
-
-            if not is_in_room:
-                logger.info(
-                    "Got event for room we're not in: %r %r",
-                    event.room_id, event.event_id
-                )
+            is_in_room = yield self.auth.check_host_in_room(
+                event.room_id,
+                self.server_name
+            )
+            if not is_in_room and not event.internal_metadata.is_outlier():
+                logger.debug("Got event for room we're not in.")

                 try:
                     event_stream_id, max_stream_id = yield self._persist_auth_tree(
-                        origin, auth_chain, state, event
+                        auth_chain, state, event
                     )
                 except AuthError as e:
                     raise FederationError(
@@ -230,28 +215,17 @@ class FederationHandler(BaseHandler):
         if event.type == EventTypes.Member:
             if event.membership == Membership.JOIN:
-                # Only fire user_joined_room if the user has actually
-                # joined the room. Don't bother if the user is just
-                # changing their profile info.
-                newly_joined = True
-                prev_state_id = context.prev_state_ids.get(
-                    (event.type, event.state_key)
-                )
-                if prev_state_id:
-                    prev_state = yield self.store.get_event(
-                        prev_state_id, allow_none=True,
-                    )
-                    if prev_state and prev_state.membership == Membership.JOIN:
-                        newly_joined = False
-
-                if newly_joined:
+                prev_state = context.current_state.get((event.type, event.state_key))
+                if not prev_state or prev_state.membership != Membership.JOIN:
+                    # Only fire user_joined_room if the user has actually
+                    # joined the room. Don't bother if the user is just
+                    # changing their profile info.
                     user = UserID.from_string(event.state_key)
                     yield user_joined_room(self.distributor, user, event.room_id)
yield user_joined_room(self.distributor, user, event.room_id) yield user_joined_room(self.distributor, user, event.room_id)
@measure_func("_filter_events_for_server")
@defer.inlineCallbacks @defer.inlineCallbacks
def _filter_events_for_server(self, server_name, room_id, events): def _filter_events_for_server(self, server_name, room_id, events):
event_to_state_ids = yield self.store.get_state_ids_for_events( event_to_state = yield self.store.get_state_for_events(
frozenset(e.event_id for e in events), frozenset(e.event_id for e in events),
types=( types=(
(EventTypes.RoomHistoryVisibility, ""), (EventTypes.RoomHistoryVisibility, ""),
@@ -259,30 +233,6 @@ class FederationHandler(BaseHandler):
             )
         )

-        # We only want to pull out member events that correspond to the
-        # server's domain.
-        def check_match(id):
-            try:
-                return server_name == get_domain_from_id(id)
-            except:
-                return False
-
-        event_map = yield self.store.get_events([
-            e_id for key_to_eid in event_to_state_ids.values()
-            for key, e_id in key_to_eid.items()
-            if key[0] != EventTypes.Member or check_match(key[1])
-        ])
-
-        event_to_state = {
-            e_id: {
-                key: event_map[inner_e_id]
-                for key, inner_e_id in key_to_eid.items()
-                if inner_e_id in event_map
-            }
-            for e_id, key_to_eid in event_to_state_ids.items()
-        }
-
         def redact_disallowed(event, state):
             if not state:
                 return event
@@ -299,7 +249,7 @@ class FederationHandler(BaseHandler):
                 if ev.type != EventTypes.Member:
                     continue
                 try:
-                    domain = get_domain_from_id(ev.state_key)
+                    domain = UserID.from_string(ev.state_key).domain
                 except:
                     continue
@@ -324,7 +274,7 @@ class FederationHandler(BaseHandler):
     @log_function
     @defer.inlineCallbacks
-    def backfill(self, dest, room_id, limit, extremities):
+    def backfill(self, dest, room_id, limit, extremities=[]):
         """ Trigger a backfill request to `dest` for the given `room_id`

         This will attempt to get more events from the remote. This may return
@@ -334,6 +284,9 @@ class FederationHandler(BaseHandler):
         if dest == self.server_name:
             raise SynapseError(400, "Can't backfill from self.")

+        if not extremities:
+            extremities = yield self.store.get_oldest_events_in_room(room_id)
+
         events = yield self.replication_layer.backfill(
             dest,
             room_id,
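
One thing worth flagging on the branch side: `extremities=[]` reintroduces a mutable default argument. The code above never mutates the list, so it is harmless here, but the pattern is a classic Python footgun; a minimal standalone sketch (names are illustrative, not Synapse APIs):

    def backfill_bad(room_id, extremities=[]):
        # The same list object is shared by every call that omits the
        # argument, so results leak from one call into the next.
        extremities.append(room_id)
        return extremities

    def backfill_good(room_id, extremities=None):
        # None as the sentinel gets a fresh list on each call.
        if not extremities:
            extremities = []
        extremities.append(room_id)
        return extremities

    assert backfill_bad("!a") == ["!a"]
    assert backfill_bad("!b") == ["!a", "!b"]    # surprise: leftover state
    assert backfill_good("!a") == ["!a"]
    assert backfill_good("!b") == ["!b"]
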
@@ -382,61 +335,32 @@ class FederationHandler(BaseHandler):
             state_events.update({s.event_id: s for s in state})
             events_to_state[e_id] = state

-        required_auth = set(
-            a_id
-            for event in events + state_events.values() + auth_events.values()
-            for a_id, _ in event.auth_events
-        )
-        auth_events.update({
-            e_id: event_map[e_id] for e_id in required_auth if e_id in event_map
-        })
-        missing_auth = required_auth - set(auth_events)
-        failed_to_fetch = set()
-
-        # Try and fetch any missing auth events from both DB and remote servers.
-        # We repeatedly do this until we stop finding new auth events.
-        while missing_auth - failed_to_fetch:
-            logger.info("Missing auth for backfill: %r", missing_auth)
-            ret_events = yield self.store.get_events(missing_auth - failed_to_fetch)
-            auth_events.update(ret_events)
-            required_auth.update(
-                a_id for event in ret_events.values() for a_id, _ in event.auth_events
-            )
-            missing_auth = required_auth - set(auth_events)
-
-            if missing_auth - failed_to_fetch:
-                logger.info(
-                    "Fetching missing auth for backfill: %r",
-                    missing_auth - failed_to_fetch
-                )
-
-                results = yield preserve_context_over_deferred(defer.gatherResults(
-                    [
-                        preserve_fn(self.replication_layer.get_pdu)(
-                            [dest],
-                            event_id,
-                            outlier=True,
-                            timeout=10000,
-                        )
-                        for event_id in missing_auth - failed_to_fetch
-                    ],
-                    consumeErrors=True
-                )).addErrback(unwrapFirstError)
-
-                auth_events.update({a.event_id: a for a in results if a})
-                required_auth.update(
-                    a_id
-                    for event in results if event
-                    for a_id, _ in event.auth_events
-                )
-                missing_auth = required_auth - set(auth_events)
-
-            failed_to_fetch = missing_auth - set(auth_events)
-
         seen_events = yield self.store.have_events(
             set(auth_events.keys()) | set(state_events.keys())
         )

+        all_events = events + state_events.values() + auth_events.values()
+        required_auth = set(
+            a_id for event in all_events for a_id, _ in event.auth_events
+        )
+
+        missing_auth = required_auth - set(auth_events)
+        if missing_auth:
+            logger.info("Missing auth for backfill: %r", missing_auth)
+            results = yield defer.gatherResults(
+                [
+                    self.replication_layer.get_pdu(
+                        [dest],
+                        event_id,
+                        outlier=True,
+                        timeout=10000,
+                    )
+                    for event_id in missing_auth
+                ],
+                consumeErrors=True
+            ).addErrback(unwrapFirstError)
+            auth_events.update({a.event_id: a for a in results})
+
         ev_infos = []
         for a in auth_events.values():
             if a.event_id in seen_events:
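
The develop-side loop removed here chases auth events to a fixed point: every batch fetched from the DB or a remote server can itself reference further auth events, and fetching stops only once everything still missing has already failed. A rough synchronous sketch of that pattern (`fetch` and the event shape are stand-ins, not Synapse APIs):

    def fetch_auth_to_fixed_point(events, fetch):
        """events: dict of event_id -> event, where event["auth_events"]
        is a list of event_ids. fetch(ids) returns a dict of those found."""
        auth_events = dict(events)
        required = set(
            a_id for ev in auth_events.values() for a_id in ev["auth_events"]
        )
        missing = required - set(auth_events)
        failed = set()
        while missing - failed:
            requested = missing - failed
            found = fetch(requested)
            auth_events.update(found)
            # Newly fetched events may point at yet more auth events.
            required.update(
                a_id for ev in found.values() for a_id in ev["auth_events"]
            )
            failed |= requested - set(found)   # ids nobody could supply
            missing = required - set(auth_events)
        return auth_events
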
@@ -448,7 +372,6 @@ class FederationHandler(BaseHandler):
                     (auth_events[a_id].type, auth_events[a_id].state_key):
                     auth_events[a_id]
                     for a_id, _ in a.auth_events
-                    if a_id in auth_events
                 }
             })
@@ -460,7 +383,6 @@ class FederationHandler(BaseHandler):
                     (auth_events[a_id].type, auth_events[a_id].state_key):
                     auth_events[a_id]
                     for a_id, _ in event_map[e_id].auth_events
-                    if a_id in auth_events
                 }
             })
@@ -504,10 +426,6 @@ class FederationHandler(BaseHandler):
         )
         max_depth = sorted_extremeties_tuple[0][1]

-        # We don't want to specify too many extremities as it causes the backfill
-        # request URI to be too long.
-        extremities = dict(sorted_extremeties_tuple[:5])
-
         if current_depth > max_depth:
             logger.debug(
                 "Not backfilling as we don't need to. %d < %d",
@@ -604,24 +522,12 @@ class FederationHandler(BaseHandler):
         event_ids = list(extremities.keys())

-        states = yield preserve_context_over_deferred(defer.gatherResults([
-            preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e])
+        states = yield defer.gatherResults([
+            self.state_handler.resolve_state_groups(room_id, [e])
             for e in event_ids
-        ]))
+        ])
         states = dict(zip(event_ids, [s[1] for s in states]))

-        state_map = yield self.store.get_events(
-            [e_id for ids in states.values() for e_id in ids],
-            get_prev_content=False
-        )
-        states = {
-            key: {
-                k: state_map[e_id]
-                for k, e_id in state_dict.items()
-                if e_id in state_map
-            } for key, state_dict in states.items()
-        }
-
         for e_id, _ in sorted_extremeties_tuple:
             likely_domains = get_domains_from_state(states[e_id])
@@ -731,7 +637,7 @@ class FederationHandler(BaseHandler):
                 pass

             event_stream_id, max_stream_id = yield self._persist_auth_tree(
-                origin, auth_chain, state, event
+                auth_chain, state, event
             )

             with PreserveLoggingContext():
@@ -782,9 +688,7 @@ class FederationHandler(BaseHandler):
             logger.warn("Failed to create join %r because %s", event, e)
             raise e

-        # The remote hasn't signed it yet, obviously. We'll do the full checks
-        # when we get the event back in `on_send_join_request`
-        yield self.auth.check_from_context(event, context, do_sig_check=False)
+        self.auth.check(event, auth_events=context.current_state)

         defer.returnValue(event)
@@ -832,12 +736,17 @@ class FederationHandler(BaseHandler):
         new_pdu = event

-        users_in_room = yield self.store.get_joined_users_from_context(event, context)
-
-        destinations = set(
-            get_domain_from_id(user_id) for user_id in users_in_room
-            if not self.hs.is_mine_id(user_id)
-        )
+        destinations = set()
+        for k, s in context.current_state.items():
+            try:
+                if k[0] == EventTypes.Member:
+                    if s.content["membership"] == Membership.JOIN:
+                        destinations.add(get_domain_from_id(s.state_key))
+            except:
+                logger.warn(
+                    "Failed to get destination from event %s", s.event_id
+                )

         destinations.discard(origin)
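
Both versions compute the set of remote servers that should receive the new join: develop derives it from the joined members (minus local users), the branch from member events in `current_state`. The develop-side idea, self-contained (the domain helper here is a simplified stand-in for `synapse.types.get_domain_from_id`):

    def get_domain_from_id(user_id):
        # Matrix user IDs look like "@localpart:example.com"; everything
        # after the first colon is the server name.
        return user_id.split(":", 1)[1]

    def destinations_for(users_in_room, my_domain, origin=None):
        destinations = set(
            get_domain_from_id(u) for u in users_in_room
            if get_domain_from_id(u) != my_domain
        )
        destinations.discard(origin)    # never echo back to the sender
        return destinations

    assert destinations_for(
        ["@a:here.org", "@b:there.org", "@c:there.org"],
        "here.org", origin="old.org",
    ) == {"there.org"}
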
@@ -849,15 +758,13 @@ class FederationHandler(BaseHandler):
             self.replication_layer.send_pdu(new_pdu, destinations)

-        state_ids = context.prev_state_ids.values()
+        state_ids = [e.event_id for e in context.current_state.values()]
         auth_chain = yield self.store.get_auth_chain(set(
             [event.event_id] + state_ids
         ))

-        state = yield self.store.get_events(context.prev_state_ids.values())
-
         defer.returnValue({
-            "state": state.values(),
+            "state": context.current_state.values(),
             "auth_chain": auth_chain,
         })
@@ -1011,9 +918,7 @@ class FederationHandler(BaseHandler):
         )

         try:
-            # The remote hasn't signed it yet, obviously. We'll do the full checks
-            # when we get the event back in `on_send_leave_request`
-            yield self.auth.check_from_context(event, context, do_sig_check=False)
+            self.auth.check(event, auth_events=context.current_state)
         except AuthError as e:
             logger.warn("Failed to create new leave %r because %s", event, e)
             raise e
@@ -1057,12 +962,18 @@ class FederationHandler(BaseHandler):
         new_pdu = event

-        users_in_room = yield self.store.get_joined_users_from_context(event, context)
-
-        destinations = set(
-            get_domain_from_id(user_id) for user_id in users_in_room
-            if not self.hs.is_mine_id(user_id)
-        )
+        destinations = set()
+        for k, s in context.current_state.items():
+            try:
+                if k[0] == EventTypes.Member:
+                    if s.content["membership"] == Membership.LEAVE:
+                        destinations.add(get_domain_from_id(s.state_key))
+            except:
+                logger.warn(
+                    "Failed to get destination from event %s", s.event_id
+                )

         destinations.discard(origin)

         logger.debug(
@@ -1076,11 +987,14 @@ class FederationHandler(BaseHandler):
         defer.returnValue(None)

     @defer.inlineCallbacks
-    def get_state_for_pdu(self, room_id, event_id):
-        """Returns the state at the event. i.e. not including said event.
-        """
+    def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True):
         yield run_on_reactor()

+        if do_auth:
+            in_room = yield self.auth.check_host_in_room(room_id, origin)
+            if not in_room:
+                raise AuthError(403, "Host not in room.")
+
         state_groups = yield self.store.get_state_groups(
             room_id, [event_id]
         )
@@ -1119,34 +1033,6 @@ class FederationHandler(BaseHandler):
         else:
             defer.returnValue([])

-    @defer.inlineCallbacks
-    def get_state_ids_for_pdu(self, room_id, event_id):
-        """Returns the state at the event. i.e. not including said event.
-        """
-        yield run_on_reactor()
-
-        state_groups = yield self.store.get_state_groups_ids(
-            room_id, [event_id]
-        )
-
-        if state_groups:
-            _, state = state_groups.items().pop()
-            results = state
-
-            event = yield self.store.get_event(event_id)
-            if event and event.is_state():
-                # Get previous state
-                if "replaces_state" in event.unsigned:
-                    prev_id = event.unsigned["replaces_state"]
-                    if prev_id != event.event_id:
-                        results[(event.type, event.state_key)] = prev_id
-                else:
-                    del results[(event.type, event.state_key)]
-
-            defer.returnValue(results.values())
-        else:
-            defer.returnValue([])
-
     @defer.inlineCallbacks
     @log_function
     def on_backfill_request(self, origin, room_id, pdu_list, limit):
@@ -1179,17 +1065,16 @@ class FederationHandler(BaseHandler):
         )

         if event:
-            if self.hs.is_mine_id(event.event_id):
-                # FIXME: This is a temporary work around where we occasionally
-                # return events slightly differently than when they were
-                # originally signed
-                event.signatures.update(
-                    compute_event_signature(
-                        event,
-                        self.hs.hostname,
-                        self.hs.config.signing_key[0]
-                    )
-                )
+            # FIXME: This is a temporary work around where we occasionally
+            # return events slightly differently than when they were
+            # originally signed
+            event.signatures.update(
+                compute_event_signature(
+                    event,
+                    self.hs.hostname,
+                    self.hs.config.signing_key[0]
+                )
+            )

             if do_auth:
                 in_room = yield self.auth.check_host_in_room(
@@ -1199,12 +1084,6 @@ class FederationHandler(BaseHandler):
                 if not in_room:
                     raise AuthError(403, "Host not in room.")

-                events = yield self._filter_events_for_server(
-                    origin, event.room_id, [event]
-                )
-
-                event = events[0]
-
             defer.returnValue(event)
         else:
             defer.returnValue(None)
@@ -1235,12 +1114,11 @@ class FederationHandler(BaseHandler):
             backfilled=backfilled,
         )

-        if not backfilled:
-            # this intentionally does not yield: we don't care about the result
-            # and don't need to wait for it.
-            preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
-                event_stream_id, max_stream_id
-            )
+        # this intentionally does not yield: we don't care about the result
+        # and don't need to wait for it.
+        preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
+            event_stream_id, max_stream_id
+        )

         defer.returnValue((context, event_stream_id, max_stream_id))
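
The surviving comment is the point of this hunk: the pusher notification is deliberately not yielded, so event persistence does not block on push processing. A toy Twisted sketch of the same fire-and-forget shape (without Synapse's log-context wrapping, which is what `preserve_fn` adds):

    import logging

    from twisted.internet import defer, reactor

    logger = logging.getLogger(__name__)

    def notify_pushers(stream_id):
        # Stand-in for on_new_notifications: slow background work.
        d = defer.Deferred()
        reactor.callLater(0.1, d.callback, stream_id)
        return d

    @defer.inlineCallbacks
    def persist_and_return(event, stream_id):
        d = notify_pushers(stream_id)      # intentionally not yielded
        d.addErrback(lambda f: logger.error("push notify failed: %s", f))
        yield defer.succeed(None)          # any real persistence work goes here
        defer.returnValue(event)           # respond without waiting on pushers
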
@@ -1251,9 +1129,9 @@ class FederationHandler(BaseHandler):
         a bunch of outliers, but not a chunk of individual events that depend
         on each other for state calculations.
         """

-        contexts = yield preserve_context_over_deferred(defer.gatherResults(
+        contexts = yield defer.gatherResults(
             [
-                preserve_fn(self._prep_event)(
+                self._prep_event(
                     origin,
                     ev_info["event"],
                     state=ev_info.get("state"),
@@ -1261,7 +1139,7 @@ class FederationHandler(BaseHandler):
                 )
                 for ev_info in event_infos
             ]
-        ))
+        )

         yield self.store.persist_events(
             [
@@ -1272,19 +1150,11 @@ class FederationHandler(BaseHandler):
         )

     @defer.inlineCallbacks
-    def _persist_auth_tree(self, origin, auth_events, state, event):
+    def _persist_auth_tree(self, auth_events, state, event):
         """Checks the auth chain is valid (and passes auth checks) for the
         state and event. Then persists the auth chain and state atomically.
         Persists the event separately.

-        Will attempt to fetch missing auth events.
-
-        Args:
-            origin (str): Where the events came from
-            auth_events (list)
-            state (list)
-            event (Event)
-
         Returns:
             2-tuple of (event_stream_id, max_stream_id) from the persist_event
             call for `event`
@@ -1297,7 +1167,7 @@ class FederationHandler(BaseHandler):
         event_map = {
             e.event_id: e
-            for e in itertools.chain(auth_events, state, [event])
+            for e in auth_events
         }

         create_event = None
@@ -1306,29 +1176,10 @@ class FederationHandler(BaseHandler):
                 create_event = e
                 break

-        missing_auth_events = set()
-        for e in itertools.chain(auth_events, state, [event]):
-            for e_id, _ in e.auth_events:
-                if e_id not in event_map:
-                    missing_auth_events.add(e_id)
-
-        for e_id in missing_auth_events:
-            m_ev = yield self.replication_layer.get_pdu(
-                [origin],
-                e_id,
-                outlier=True,
-                timeout=10000,
-            )
-            if m_ev and m_ev.event_id == e_id:
-                event_map[e_id] = m_ev
-            else:
-                logger.info("Failed to find auth event %r", e_id)
-
         for e in itertools.chain(auth_events, state, [event]):
             auth_for_e = {
                 (event_map[e_id].type, event_map[e_id].state_key): event_map[e_id]
                 for e_id, _ in e.auth_events
-                if e_id in event_map
             }
             if create_event:
                 auth_for_e[(EventTypes.Create, "")] = create_event
@@ -1377,13 +1228,7 @@ class FederationHandler(BaseHandler):
         )

         if not auth_events:
-            auth_events_ids = yield self.auth.compute_auth_events(
-                event, context.prev_state_ids, for_verification=True,
-            )
-            auth_events = yield self.store.get_events(auth_events_ids)
-            auth_events = {
-                (e.type, e.state_key): e for e in auth_events.values()
-            }
+            auth_events = context.current_state

         # This is a hack to fix some old rooms where the initial join event
         # didn't reference the create event in its auth events.
@@ -1409,7 +1254,8 @@ class FederationHandler(BaseHandler):
             context.rejected = RejectedReason.AUTH_ERROR

         if event.type == EventTypes.GuestAccess:
-            yield self.maybe_kick_guest_users(event)
+            full_context = yield self.store.get_current_state(room_id=event.room_id)
+            yield self.maybe_kick_guest_users(event, full_context)

         defer.returnValue(context)
@@ -1477,11 +1323,6 @@ class FederationHandler(BaseHandler):
         current_state = set(e.event_id for e in auth_events.values())
         event_auth_events = set(e_id for e_id, _ in event.auth_events)

-        if event.is_state():
-            event_key = (event.type, event.state_key)
-        else:
-            event_key = None
-
         if event_auth_events - current_state:
             have_events = yield self.store.have_events(
                 event_auth_events - current_state
@@ -1555,9 +1396,9 @@ class FederationHandler(BaseHandler):
             # Do auth conflict res.
             logger.info("Different auth: %s", different_auth)

-            different_events = yield preserve_context_over_deferred(defer.gatherResults(
+            different_events = yield defer.gatherResults(
                 [
-                    preserve_fn(self.store.get_event)(
+                    self.store.get_event(
                         d,
                         allow_none=True,
                         allow_rejected=False,
@@ -1566,7 +1407,7 @@ class FederationHandler(BaseHandler):
                     if d in have_events and not have_events[d]
                 ],
                 consumeErrors=True
-            )).addErrback(unwrapFirstError)
+            ).addErrback(unwrapFirstError)

             if different_events:
                 local_view = dict(auth_events)
@@ -1585,16 +1426,8 @@ class FederationHandler(BaseHandler):
                 current_state = set(e.event_id for e in auth_events.values())
                 different_auth = event_auth_events - current_state

-                context.current_state_ids = dict(context.current_state_ids)
-                context.current_state_ids.update({
-                    k: a.event_id for k, a in auth_events.items()
-                    if k != event_key
-                })
-                context.prev_state_ids = dict(context.prev_state_ids)
-                context.prev_state_ids.update({
-                    k: a.event_id for k, a in auth_events.items()
-                })
-                context.state_group = self.store.get_next_state_group()
+                context.current_state.update(auth_events)
+                context.state_group = None

         if different_auth and not event.internal_metadata.is_outlier():
             logger.info("Different auth after resolution: %s", different_auth)
@@ -1615,8 +1448,8 @@ class FederationHandler(BaseHandler):
             if do_resolution:
                 # 1. Get what we think is the auth chain.
-                auth_ids = yield self.auth.compute_auth_events(
-                    event, context.prev_state_ids
+                auth_ids = self.auth.compute_auth_events(
+                    event, context.current_state
                 )
                 local_auth_chain = yield self.store.get_auth_chain(auth_ids)
@@ -1672,16 +1505,8 @@ class FederationHandler(BaseHandler):
                 # 4. Look at rejects and their proofs.
                 # TODO.

-                context.current_state_ids = dict(context.current_state_ids)
-                context.current_state_ids.update({
-                    k: a.event_id for k, a in auth_events.items()
-                    if k != event_key
-                })
-                context.prev_state_ids = dict(context.prev_state_ids)
-                context.prev_state_ids.update({
-                    k: a.event_id for k, a in auth_events.items()
-                })
-                context.state_group = self.store.get_next_state_group()
+                context.current_state.update(auth_events)
+                context.state_group = None

         try:
             self.auth.check(event, auth_events=auth_events)
@@ -1867,12 +1692,12 @@ class FederationHandler(BaseHandler):
             )

             try:
-                yield self.auth.check_from_context(event, context)
+                self.auth.check(event, context.current_state)
             except AuthError as e:
                 logger.warn("Denying new third party invite %r because %s", event, e)
                 raise e

-            yield self._check_signature(event, context)
+            yield self._check_signature(event, auth_events=context.current_state)
             member_handler = self.hs.get_handlers().room_member_handler
             yield member_handler.send_membership_event(None, event, context)
         else:
@@ -1898,11 +1723,11 @@ class FederationHandler(BaseHandler):
         )

         try:
-            self.auth.check_from_context(event, context)
+            self.auth.check(event, auth_events=context.current_state)
         except AuthError as e:
             logger.warn("Denying third party invite %r because %s", event, e)
             raise e

-        yield self._check_signature(event, context)
+        yield self._check_signature(event, auth_events=context.current_state)
         returned_invite = yield self.send_invite(origin, event)
         # TODO: Make sure the signatures actually are correct.
@@ -1916,24 +1741,16 @@ class FederationHandler(BaseHandler):
             EventTypes.ThirdPartyInvite,
             event.content["third_party_invite"]["signed"]["token"]
         )

-        original_invite = None
-        original_invite_id = context.prev_state_ids.get(key)
-        if original_invite_id:
-            original_invite = yield self.store.get_event(
-                original_invite_id, allow_none=True
-            )
-        if original_invite:
-            display_name = original_invite.content["display_name"]
-            event_dict["content"]["third_party_invite"]["display_name"] = display_name
-        else:
+        original_invite = context.current_state.get(key)
+        if not original_invite:
             logger.info(
-                "Could not find invite event for third_party_invite: %r",
-                event_dict
+                "Could not find invite event for third_party_invite - "
+                "discarding: %s" % (event_dict,)
             )
-            # We don't discard here as this is not the appropriate place to do
-            # auth checks. If we need the invite and don't have it then the
-            # auth check code will explode appropriately.
+            return
+
+        display_name = original_invite.content["display_name"]
+        event_dict["content"]["third_party_invite"]["display_name"] = display_name

         builder = self.event_builder_factory.new(event_dict)
         EventValidator().validate_new(builder)
         message_handler = self.hs.get_handlers().message_handler
@@ -1941,13 +1758,13 @@ class FederationHandler(BaseHandler):
         defer.returnValue((event, context))

     @defer.inlineCallbacks
-    def _check_signature(self, event, context):
+    def _check_signature(self, event, auth_events):
         """
         Checks that the signature in the event is consistent with its invite.

         Args:
             event (Event): The m.room.member event to check
-            context (EventContext):
+            auth_events (dict<(event type, state_key), event>):

         Raises:
             AuthError: if signature didn't match any keys, or key has been
@@ -1958,14 +1775,10 @@ class FederationHandler(BaseHandler):
         signed = event.content["third_party_invite"]["signed"]
         token = signed["token"]

-        invite_event_id = context.prev_state_ids.get(
+        invite_event = auth_events.get(
             (EventTypes.ThirdPartyInvite, token,)
         )

-        invite_event = None
-        if invite_event_id:
-            invite_event = yield self.store.get_event(invite_event_id, allow_none=True)
-
         if not invite_event:
             raise AuthError(403, "Could not find invite")
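
What `_check_signature` ultimately verifies is the `signed` block of the third-party invite against the identity server's published ed25519 key. Roughly, using the `signedjson` and `unpaddedbase64` libraries (key discovery is elided and the parameter names are illustrative):

    from signedjson.key import decode_verify_key_bytes
    from signedjson.sign import SignatureVerifyException, verify_signed_json
    from unpaddedbase64 import decode_base64

    def check_invite_signature(signed, server_name, key_id, key_b64):
        """signed: the event.content["third_party_invite"]["signed"] dict;
        key_b64: the identity server's unpadded-base64 public key."""
        verify_key = decode_verify_key_bytes(key_id, decode_base64(key_b64))
        try:
            verify_signed_json(signed, server_name, verify_key)
            return True
        except SignatureVerifyException:
            return False
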

View File

@@ -1,443 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
from synapse.streams.config import PaginationConfig
from synapse.types import (
UserID, StreamToken,
)
from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute
from synapse.util.caches.snapshot_cache import SnapshotCache
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
import logging
logger = logging.getLogger(__name__)
class InitialSyncHandler(BaseHandler):
def __init__(self, hs):
super(InitialSyncHandler, self).__init__(hs)
self.hs = hs
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
self.validator = EventValidator()
self.snapshot_cache = SnapshotCache()
def snapshot_all_rooms(self, user_id=None, pagin_config=None,
as_client_event=True, include_archived=False):
"""Retrieve a snapshot of all rooms the user is invited or has joined.
This snapshot may include messages for all rooms where the user is
joined, depending on the pagination config.
Args:
user_id (str): The ID of the user making the request.
pagin_config (synapse.api.streams.PaginationConfig): The pagination
config used to determine how many messages *PER ROOM* to return.
as_client_event (bool): True to get events in client-server format.
include_archived (bool): True to get rooms that the user has left
Returns:
A list of dicts with "room_id" and "membership" keys for all rooms
the user is currently invited or joined in on. Rooms where the user
is joined on, may return a "messages" key with messages, depending
on the specified PaginationConfig.
"""
key = (
user_id,
pagin_config.from_token,
pagin_config.to_token,
pagin_config.direction,
pagin_config.limit,
as_client_event,
include_archived,
)
now_ms = self.clock.time_msec()
result = self.snapshot_cache.get(now_ms, key)
if result is not None:
return result
return self.snapshot_cache.set(now_ms, key, self._snapshot_all_rooms(
user_id, pagin_config, as_client_event, include_archived
))
@defer.inlineCallbacks
def _snapshot_all_rooms(self, user_id=None, pagin_config=None,
as_client_event=True, include_archived=False):
memberships = [Membership.INVITE, Membership.JOIN]
if include_archived:
memberships.append(Membership.LEAVE)
room_list = yield self.store.get_rooms_for_user_where_membership_is(
user_id=user_id, membership_list=memberships
)
user = UserID.from_string(user_id)
rooms_ret = []
now_token = yield self.hs.get_event_sources().get_current_token()
presence_stream = self.hs.get_event_sources().sources["presence"]
pagination_config = PaginationConfig(from_token=now_token)
presence, _ = yield presence_stream.get_pagination_rows(
user, pagination_config.get_source_config("presence"), None
)
receipt_stream = self.hs.get_event_sources().sources["receipt"]
receipt, _ = yield receipt_stream.get_pagination_rows(
user, pagination_config.get_source_config("receipt"), None
)
tags_by_room = yield self.store.get_tags_for_user(user_id)
account_data, account_data_by_room = (
yield self.store.get_account_data_for_user(user_id)
)
public_room_ids = yield self.store.get_public_room_ids()
limit = pagin_config.limit
if limit is None:
limit = 10
@defer.inlineCallbacks
def handle_room(event):
d = {
"room_id": event.room_id,
"membership": event.membership,
"visibility": (
"public" if event.room_id in public_room_ids
else "private"
),
}
if event.membership == Membership.INVITE:
time_now = self.clock.time_msec()
d["inviter"] = event.sender
invite_event = yield self.store.get_event(event.event_id)
d["invite"] = serialize_event(invite_event, time_now, as_client_event)
rooms_ret.append(d)
if event.membership not in (Membership.JOIN, Membership.LEAVE):
return
try:
if event.membership == Membership.JOIN:
room_end_token = now_token.room_key
deferred_room_state = self.state_handler.get_current_state(
event.room_id
)
elif event.membership == Membership.LEAVE:
room_end_token = "s%d" % (event.stream_ordering,)
deferred_room_state = self.store.get_state_for_events(
[event.event_id], None
)
deferred_room_state.addCallback(
lambda states: states[event.event_id]
)
(messages, token), current_state = yield preserve_context_over_deferred(
defer.gatherResults(
[
preserve_fn(self.store.get_recent_events_for_room)(
event.room_id,
limit=limit,
end_token=room_end_token,
),
deferred_room_state,
]
)
).addErrback(unwrapFirstError)
messages = yield filter_events_for_client(
self.store, user_id, messages
)
start_token = now_token.copy_and_replace("room_key", token[0])
end_token = now_token.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec()
d["messages"] = {
"chunk": [
serialize_event(m, time_now, as_client_event)
for m in messages
],
"start": start_token.to_string(),
"end": end_token.to_string(),
}
d["state"] = [
serialize_event(c, time_now, as_client_event)
for c in current_state.values()
]
account_data_events = []
tags = tags_by_room.get(event.room_id)
if tags:
account_data_events.append({
"type": "m.tag",
"content": {"tags": tags},
})
account_data = account_data_by_room.get(event.room_id, {})
for account_data_type, content in account_data.items():
account_data_events.append({
"type": account_data_type,
"content": content,
})
d["account_data"] = account_data_events
except:
logger.exception("Failed to get snapshot")
yield concurrently_execute(handle_room, room_list, 10)
account_data_events = []
for account_data_type, content in account_data.items():
account_data_events.append({
"type": account_data_type,
"content": content,
})
ret = {
"rooms": rooms_ret,
"presence": presence,
"account_data": account_data_events,
"receipts": receipt,
"end": now_token.to_string(),
}
defer.returnValue(ret)
@defer.inlineCallbacks
def room_initial_sync(self, requester, room_id, pagin_config=None):
"""Capture the a snapshot of a room. If user is currently a member of
the room this will be what is currently in the room. If the user left
the room this will be what was in the room when they left.
Args:
requester(Requester): The user to get a snapshot for.
room_id(str): The room to get a snapshot of.
pagin_config(synapse.streams.config.PaginationConfig):
The pagination config used to determine how many messages to
return.
Raises:
AuthError if the user wasn't in the room.
Returns:
A JSON serialisable dict with the snapshot of the room.
"""
user_id = requester.user.to_string()
membership, member_event_id = yield self._check_in_room_or_world_readable(
room_id, user_id,
)
is_peeking = member_event_id is None
if membership == Membership.JOIN:
result = yield self._room_initial_sync_joined(
user_id, room_id, pagin_config, membership, is_peeking
)
elif membership == Membership.LEAVE:
result = yield self._room_initial_sync_parted(
user_id, room_id, pagin_config, membership, member_event_id, is_peeking
)
account_data_events = []
tags = yield self.store.get_tags_for_room(user_id, room_id)
if tags:
account_data_events.append({
"type": "m.tag",
"content": {"tags": tags},
})
account_data = yield self.store.get_account_data_for_room(user_id, room_id)
for account_data_type, content in account_data.items():
account_data_events.append({
"type": account_data_type,
"content": content,
})
result["account_data"] = account_data_events
defer.returnValue(result)
@defer.inlineCallbacks
def _room_initial_sync_parted(self, user_id, room_id, pagin_config,
membership, member_event_id, is_peeking):
room_state = yield self.store.get_state_for_events(
[member_event_id], None
)
room_state = room_state[member_event_id]
limit = pagin_config.limit if pagin_config else None
if limit is None:
limit = 10
stream_token = yield self.store.get_stream_token_for_event(
member_event_id
)
messages, token = yield self.store.get_recent_events_for_room(
room_id,
limit=limit,
end_token=stream_token
)
messages = yield filter_events_for_client(
self.store, user_id, messages, is_peeking=is_peeking
)
start_token = StreamToken.START.copy_and_replace("room_key", token[0])
end_token = StreamToken.START.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec()
defer.returnValue({
"membership": membership,
"room_id": room_id,
"messages": {
"chunk": [serialize_event(m, time_now) for m in messages],
"start": start_token.to_string(),
"end": end_token.to_string(),
},
"state": [serialize_event(s, time_now) for s in room_state.values()],
"presence": [],
"receipts": [],
})
@defer.inlineCallbacks
def _room_initial_sync_joined(self, user_id, room_id, pagin_config,
membership, is_peeking):
current_state = yield self.state.get_current_state(
room_id=room_id,
)
# TODO: These concurrently
time_now = self.clock.time_msec()
state = [
serialize_event(x, time_now)
for x in current_state.values()
]
now_token = yield self.hs.get_event_sources().get_current_token()
limit = pagin_config.limit if pagin_config else None
if limit is None:
limit = 10
room_members = [
m for m in current_state.values()
if m.type == EventTypes.Member
and m.content["membership"] == Membership.JOIN
]
presence_handler = self.hs.get_presence_handler()
@defer.inlineCallbacks
def get_presence():
states = yield presence_handler.get_states(
[m.user_id for m in room_members],
as_event=True,
)
defer.returnValue(states)
@defer.inlineCallbacks
def get_receipts():
receipts_handler = self.hs.get_handlers().receipts_handler
receipts = yield receipts_handler.get_receipts_for_room(
room_id,
now_token.receipt_key
)
defer.returnValue(receipts)
presence, receipts, (messages, token) = yield defer.gatherResults(
[
preserve_fn(get_presence)(),
preserve_fn(get_receipts)(),
preserve_fn(self.store.get_recent_events_for_room)(
room_id,
limit=limit,
end_token=now_token.room_key,
)
],
consumeErrors=True,
).addErrback(unwrapFirstError)
messages = yield filter_events_for_client(
self.store, user_id, messages, is_peeking=is_peeking,
)
start_token = now_token.copy_and_replace("room_key", token[0])
end_token = now_token.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec()
ret = {
"room_id": room_id,
"messages": {
"chunk": [serialize_event(m, time_now) for m in messages],
"start": start_token.to_string(),
"end": end_token.to_string(),
},
"state": state,
"presence": presence,
"receipts": receipts,
}
if not is_peeking:
ret["membership"] = membership
defer.returnValue(ret)
@defer.inlineCallbacks
def _check_in_room_or_world_readable(self, room_id, user_id):
try:
# check_user_was_in_room will return the most recent membership
# event for the user if:
# * The user is a non-guest user, and was ever in the room
# * The user is a guest user, and has joined the room
# else it will throw.
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
defer.returnValue((member_event.membership, member_event.event_id))
return
except AuthError:
visibility = yield self.state_handler.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, ""
)
if (
visibility and
visibility.content["history_visibility"] == "world_readable"
):
defer.returnValue((Membership.JOIN, None))
return
raise AuthError(
403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
)
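
The handler deleted above leans on `SnapshotCache` to memoise whole initial-sync responses, keyed on the complete request tuple, so identical concurrent requests share one expensive computation. The idea in miniature (a simplified stand-in, not the real SnapshotCache, which also tracks in-flight Deferreds):

    class TinySnapshotCache(object):
        """Remember recent results for a short, fixed window."""

        DURATION_MS = 5 * 60 * 1000

        def __init__(self):
            self._cache = {}    # key -> (expiry_ms, result)

        def get(self, now_ms, key):
            entry = self._cache.get(key)
            if entry and entry[0] > now_ms:
                return entry[1]
            return None

        def set(self, now_ms, key, result):
            self._cache[key] = (now_ms + self.DURATION_MS, result)
            return result

The key built by `snapshot_all_rooms` folds in the user, both pagination tokens, direction, limit and the flags, so any difference in the request bypasses the cache.
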

View File

@@ -16,17 +16,19 @@
 from twisted.internet import defer

 from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import AuthError, Codes, SynapseError, LimitExceededError
+from synapse.api.errors import AuthError, Codes, SynapseError
 from synapse.crypto.event_signing import add_hashes_and_signatures
 from synapse.events.utils import serialize_event
 from synapse.events.validator import EventValidator
 from synapse.push.action_generator import ActionGenerator
+from synapse.streams.config import PaginationConfig
 from synapse.types import (
-    UserID, RoomAlias, RoomStreamToken, get_domain_from_id
+    UserID, RoomAlias, RoomStreamToken, StreamToken, get_domain_from_id
 )
-from synapse.util.async import run_on_reactor, ReadWriteLock
+from synapse.util import unwrapFirstError
+from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLock
+from synapse.util.caches.snapshot_cache import SnapshotCache
 from synapse.util.logcontext import preserve_fn
-from synapse.util.metrics import measure_func
 from synapse.visibility import filter_events_for_client

 from ._base import BaseHandler
@@ -34,7 +36,6 @@ from ._base import BaseHandler
 from canonicaljson import encode_canonical_json

 import logging
-import random

 logger = logging.getLogger(__name__)
@@ -47,6 +48,7 @@ class MessageHandler(BaseHandler):
         self.state = hs.get_state_handler()
         self.clock = hs.get_clock()
         self.validator = EventValidator()
+        self.snapshot_cache = SnapshotCache()
         self.pagination_lock = ReadWriteLock()
@@ -64,7 +66,7 @@ class MessageHandler(BaseHandler):
     @defer.inlineCallbacks
     def get_messages(self, requester, room_id=None, pagin_config=None,
-                     as_client_event=True, event_filter=None):
+                     as_client_event=True):
         """Get messages in a room.

         Args:
@@ -73,18 +75,18 @@ class MessageHandler(BaseHandler):
             pagin_config (synapse.api.streams.PaginationConfig): The pagination
                 config rules to apply, if any.
             as_client_event (bool): True to get events in client-server format.
-            event_filter (Filter): Filter to apply to results or None
         Returns:
             dict: Pagination API results
         """
         user_id = requester.user.to_string()
+        data_source = self.hs.get_event_sources().sources["room"]

         if pagin_config.from_token:
             room_token = pagin_config.from_token.room_key
         else:
             pagin_config.from_token = (
-                yield self.hs.get_event_sources().get_current_token_for_room(
-                    room_id=room_id
+                yield self.hs.get_event_sources().get_current_token(
+                    direction='b'
                 )
             )
             room_token = pagin_config.from_token.room_key
@@ -127,13 +129,8 @@ class MessageHandler(BaseHandler):
                     room_id, max_topo
                 )

-            events, next_key = yield self.store.paginate_room_events(
-                room_id=room_id,
-                from_key=source_config.from_key,
-                to_key=source_config.to_key,
-                direction=source_config.direction,
-                limit=source_config.limit,
-                event_filter=event_filter,
+            events, next_key = yield data_source.get_pagination_rows(
+                requester.user, source_config, room_id
             )

             next_token = pagin_config.from_token.copy_and_replace(
@@ -147,9 +144,6 @@ class MessageHandler(BaseHandler):
                 "end": next_token.to_string(),
             })

-        if event_filter:
-            events = event_filter.filter(events)
-
         events = yield filter_events_for_client(
             self.store,
             user_id,
@@ -170,6 +164,101 @@ class MessageHandler(BaseHandler):
         defer.returnValue(chunk)
@defer.inlineCallbacks
def get_files(self, requester, room_id, pagin_config):
"""Get files in a room.
Args:
requester (Requester): The user requesting files.
room_id (str): The room they want files from.
pagin_config (synapse.api.streams.PaginationConfig): The pagination
config rules to apply, if any.
Returns:
dict: Pagination API results
"""
user_id = requester.user.to_string()
if pagin_config.from_token:
room_token = pagin_config.from_token.room_key
else:
pagin_config.from_token = (
yield self.hs.get_event_sources().get_current_token(
direction='b'
)
)
room_token = pagin_config.from_token.room_key
room_token = RoomStreamToken.parse(room_token)
pagin_config.from_token = pagin_config.from_token.copy_and_replace(
"room_key", str(room_token)
)
source_config = pagin_config.get_source_config("room")
membership, member_event_id = yield self._check_in_room_or_world_readable(
room_id, user_id
)
if source_config.direction == 'b':
if room_token.topological:
max_topo = room_token.topological
else:
max_topo = yield self.store.get_max_topological_token(
room_id, room_token.stream
)
if membership == Membership.LEAVE:
# If they have left the room then clamp the token to be before
# they left the room, to save the effort of loading from the
# database.
leave_token = yield self.store.get_topological_token_for_event(
member_event_id
)
leave_token = RoomStreamToken.parse(leave_token)
if leave_token.topological < max_topo:
source_config.from_key = str(leave_token)
events, next_key = yield self.store.paginate_room_file_events(
room_id,
from_key=source_config.from_key,
to_key=source_config.to_key,
direction=source_config.direction,
limit=source_config.limit,
)
next_token = pagin_config.from_token.copy_and_replace(
"room_key", next_key
)
if not events:
defer.returnValue({
"chunk": [],
"start": pagin_config.from_token.to_string(),
"end": next_token.to_string(),
})
events = yield filter_events_for_client(
self.store,
user_id,
events,
is_peeking=(member_event_id is None),
)
time_now = self.clock.time_msec()
chunk = {
"chunk": [
serialize_event(e, time_now)
for e in events
],
"start": pagin_config.from_token.to_string(),
"end": next_token.to_string(),
}
defer.returnValue(chunk)
     @defer.inlineCallbacks
     def create_event(self, event_dict, token_id=None, txn_id=None, prev_event_ids=None):
         """
@@ -240,27 +329,12 @@ class MessageHandler(BaseHandler):
                 "Tried to send member event through non-member codepath"
             )

-        # We check here if we are currently being rate limited, so that we
-        # don't do unnecessary work. We check again just before we actually
-        # send the event.
-        time_now = self.clock.time()
-        allowed, time_allowed = self.ratelimiter.send_message(
-            event.sender, time_now,
-            msg_rate_hz=self.hs.config.rc_messages_per_second,
-            burst_count=self.hs.config.rc_message_burst_count,
-            update=False,
-        )
-        if not allowed:
-            raise LimitExceededError(
-                retry_after_ms=int(1000 * (time_allowed - time_now)),
-            )
-
         user = UserID.from_string(event.sender)
         assert self.hs.is_mine(user), "User must be our own: %s" % (user,)

         if event.is_state():
-            prev_state = yield self.deduplicate_state_event(event, context)
+            prev_state = self.deduplicate_state_event(event, context)
             if prev_state is not None:
                 defer.returnValue(prev_state)
@@ -275,7 +349,6 @@ class MessageHandler(BaseHandler):
             presence = self.hs.get_presence_handler()
             yield presence.bump_presence_active_time(user)

-    @defer.inlineCallbacks
     def deduplicate_state_event(self, event, context):
         """
         Checks whether event is in the latest resolved state in context.
@@ -283,17 +356,13 @@ class MessageHandler(BaseHandler):
         If so, returns the version of the event in context.
         Otherwise, returns None.
         """
-        prev_event_id = context.prev_state_ids.get((event.type, event.state_key))
-        prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
-        if not prev_event:
-            return
-
+        prev_event = context.current_state.get((event.type, event.state_key))
         if prev_event and event.user_id == prev_event.user_id:
             prev_content = encode_canonical_json(prev_event.content)
             next_content = encode_canonical_json(event.content)
             if prev_content == next_content:
-                defer.returnValue(prev_event)
-        return
+                return prev_event
+        return None

     @defer.inlineCallbacks
     def create_and_send_nonmember_event(
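
`deduplicate_state_event` treats a state event as a no-op when the canonical JSON of its content matches the current state's, which normalises key order and whitespace before comparing. With the real `canonicaljson` package, for example:

    from canonicaljson import encode_canonical_json

    old_content = {"name": "Snafu", "topic": None}
    new_content = {"topic": None, "name": "Snafu"}   # same data, new order

    # Canonical encoding makes logically equal dicts byte-identical.
    assert encode_canonical_json(old_content) == encode_canonical_json(new_content)
    assert encode_canonical_json(old_content) == b'{"name":"Snafu","topic":null}'
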
@@ -404,7 +473,375 @@ class MessageHandler(BaseHandler):
             [serialize_event(c, now) for c in room_state.values()]
         )

-    @measure_func("_create_new_client_event")
    def snapshot_all_rooms(self, user_id=None, pagin_config=None,
as_client_event=True, include_archived=False):
"""Retrieve a snapshot of all rooms the user is invited or has joined.
This snapshot may include messages for all rooms where the user is
joined, depending on the pagination config.
Args:
user_id (str): The ID of the user making the request.
pagin_config (synapse.api.streams.PaginationConfig): The pagination
config used to determine how many messages *PER ROOM* to return.
as_client_event (bool): True to get events in client-server format.
include_archived (bool): True to get rooms that the user has left
Returns:
A list of dicts with "room_id" and "membership" keys for all rooms
the user is currently invited or joined in on. Rooms where the user
is joined on, may return a "messages" key with messages, depending
on the specified PaginationConfig.
"""
key = (
user_id,
pagin_config.from_token,
pagin_config.to_token,
pagin_config.direction,
pagin_config.limit,
as_client_event,
include_archived,
)
now_ms = self.clock.time_msec()
result = self.snapshot_cache.get(now_ms, key)
if result is not None:
return result
return self.snapshot_cache.set(now_ms, key, self._snapshot_all_rooms(
user_id, pagin_config, as_client_event, include_archived
))
@defer.inlineCallbacks
def _snapshot_all_rooms(self, user_id=None, pagin_config=None,
as_client_event=True, include_archived=False):
memberships = [Membership.INVITE, Membership.JOIN]
if include_archived:
memberships.append(Membership.LEAVE)
room_list = yield self.store.get_rooms_for_user_where_membership_is(
user_id=user_id, membership_list=memberships
)
user = UserID.from_string(user_id)
rooms_ret = []
now_token = yield self.hs.get_event_sources().get_current_token()
presence_stream = self.hs.get_event_sources().sources["presence"]
pagination_config = PaginationConfig(from_token=now_token)
presence, _ = yield presence_stream.get_pagination_rows(
user, pagination_config.get_source_config("presence"), None
)
receipt_stream = self.hs.get_event_sources().sources["receipt"]
receipt, _ = yield receipt_stream.get_pagination_rows(
user, pagination_config.get_source_config("receipt"), None
)
tags_by_room = yield self.store.get_tags_for_user(user_id)
account_data, account_data_by_room = (
yield self.store.get_account_data_for_user(user_id)
)
public_room_ids = yield self.store.get_public_room_ids()
limit = pagin_config.limit
if limit is None:
limit = 10
@defer.inlineCallbacks
def handle_room(event):
d = {
"room_id": event.room_id,
"membership": event.membership,
"visibility": (
"public" if event.room_id in public_room_ids
else "private"
),
}
if event.membership == Membership.INVITE:
time_now = self.clock.time_msec()
d["inviter"] = event.sender
invite_event = yield self.store.get_event(event.event_id)
d["invite"] = serialize_event(invite_event, time_now, as_client_event)
rooms_ret.append(d)
if event.membership not in (Membership.JOIN, Membership.LEAVE):
return
try:
if event.membership == Membership.JOIN:
room_end_token = now_token.room_key
deferred_room_state = self.state_handler.get_current_state(
event.room_id
)
elif event.membership == Membership.LEAVE:
room_end_token = "s%d" % (event.stream_ordering,)
deferred_room_state = self.store.get_state_for_events(
[event.event_id], None
)
deferred_room_state.addCallback(
lambda states: states[event.event_id]
)
(messages, token), current_state = yield defer.gatherResults(
[
self.store.get_recent_events_for_room(
event.room_id,
limit=limit,
end_token=room_end_token,
),
deferred_room_state,
]
).addErrback(unwrapFirstError)
messages = yield filter_events_for_client(
self.store, user_id, messages
)
start_token = now_token.copy_and_replace("room_key", token[0])
end_token = now_token.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec()
d["messages"] = {
"chunk": [
serialize_event(m, time_now, as_client_event)
for m in messages
],
"start": start_token.to_string(),
"end": end_token.to_string(),
}
d["state"] = [
serialize_event(c, time_now, as_client_event)
for c in current_state.values()
]
account_data_events = []
tags = tags_by_room.get(event.room_id)
if tags:
account_data_events.append({
"type": "m.tag",
"content": {"tags": tags},
})
account_data = account_data_by_room.get(event.room_id, {})
for account_data_type, content in account_data.items():
account_data_events.append({
"type": account_data_type,
"content": content,
})
d["account_data"] = account_data_events
except:
logger.exception("Failed to get snapshot")
yield concurrently_execute(handle_room, room_list, 10)
account_data_events = []
for account_data_type, content in account_data.items():
account_data_events.append({
"type": account_data_type,
"content": content,
})
ret = {
"rooms": rooms_ret,
"presence": presence,
"account_data": account_data_events,
"receipts": receipt,
"end": now_token.to_string(),
}
defer.returnValue(ret)
@defer.inlineCallbacks
def room_initial_sync(self, requester, room_id, pagin_config=None):
"""Capture the a snapshot of a room. If user is currently a member of
the room this will be what is currently in the room. If the user left
the room this will be what was in the room when they left.
Args:
requester(Requester): The user to get a snapshot for.
room_id(str): The room to get a snapshot of.
pagin_config(synapse.streams.config.PaginationConfig):
The pagination config used to determine how many messages to
return.
Raises:
AuthError if the user wasn't in the room.
Returns:
A JSON serialisable dict with the snapshot of the room.
"""
user_id = requester.user.to_string()
membership, member_event_id = yield self._check_in_room_or_world_readable(
room_id, user_id,
)
is_peeking = member_event_id is None
if membership == Membership.JOIN:
result = yield self._room_initial_sync_joined(
user_id, room_id, pagin_config, membership, is_peeking
)
elif membership == Membership.LEAVE:
result = yield self._room_initial_sync_parted(
user_id, room_id, pagin_config, membership, member_event_id, is_peeking
)
account_data_events = []
tags = yield self.store.get_tags_for_room(user_id, room_id)
if tags:
account_data_events.append({
"type": "m.tag",
"content": {"tags": tags},
})
account_data = yield self.store.get_account_data_for_room(user_id, room_id)
for account_data_type, content in account_data.items():
account_data_events.append({
"type": account_data_type,
"content": content,
})
result["account_data"] = account_data_events
defer.returnValue(result)
@defer.inlineCallbacks
def _room_initial_sync_parted(self, user_id, room_id, pagin_config,
membership, member_event_id, is_peeking):
room_state = yield self.store.get_state_for_events(
[member_event_id], None
)
room_state = room_state[member_event_id]
limit = pagin_config.limit if pagin_config else None
if limit is None:
limit = 10
stream_token = yield self.store.get_stream_token_for_event(
member_event_id
)
messages, token = yield self.store.get_recent_events_for_room(
room_id,
limit=limit,
end_token=stream_token
)
messages = yield filter_events_for_client(
self.store, user_id, messages, is_peeking=is_peeking
)
start_token = StreamToken.START.copy_and_replace("room_key", token[0])
end_token = StreamToken.START.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec()
defer.returnValue({
"membership": membership,
"room_id": room_id,
"messages": {
"chunk": [serialize_event(m, time_now) for m in messages],
"start": start_token.to_string(),
"end": end_token.to_string(),
},
"state": [serialize_event(s, time_now) for s in room_state.values()],
"presence": [],
"receipts": [],
})
@defer.inlineCallbacks
def _room_initial_sync_joined(self, user_id, room_id, pagin_config,
membership, is_peeking):
current_state = yield self.state.get_current_state(
room_id=room_id,
)
# TODO: These concurrently
time_now = self.clock.time_msec()
state = [
serialize_event(x, time_now)
for x in current_state.values()
]
now_token = yield self.hs.get_event_sources().get_current_token()
limit = pagin_config.limit if pagin_config else None
if limit is None:
limit = 10
room_members = [
m for m in current_state.values()
if m.type == EventTypes.Member
and m.content["membership"] == Membership.JOIN
]
presence_handler = self.hs.get_presence_handler()
@defer.inlineCallbacks
def get_presence():
states = yield presence_handler.get_states(
[m.user_id for m in room_members],
as_event=True,
)
defer.returnValue(states)
@defer.inlineCallbacks
def get_receipts():
receipts_handler = self.hs.get_handlers().receipts_handler
receipts = yield receipts_handler.get_receipts_for_room(
room_id,
now_token.receipt_key
)
defer.returnValue(receipts)
presence, receipts, (messages, token) = yield defer.gatherResults(
[
get_presence(),
get_receipts(),
self.store.get_recent_events_for_room(
room_id,
limit=limit,
end_token=now_token.room_key,
)
],
consumeErrors=True,
).addErrback(unwrapFirstError)
messages = yield filter_events_for_client(
self.store, user_id, messages, is_peeking=is_peeking,
)
start_token = now_token.copy_and_replace("room_key", token[0])
end_token = now_token.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec()
ret = {
"room_id": room_id,
"messages": {
"chunk": [serialize_event(m, time_now) for m in messages],
"start": start_token.to_string(),
"end": end_token.to_string(),
},
"state": state,
"presence": presence,
"receipts": receipts,
}
if not is_peeking:
ret["membership"] = membership
defer.returnValue(ret)
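The joined path above fetches presence, receipts and recent messages in parallel with defer.gatherResults. A runnable sketch of that pattern, including the FirstError unwrapping that synapse's unwrapFirstError errback performs; the two fetch functions are stand-ins so the example needs no homeserver:

from twisted.internet import defer

def get_presence():
    # Stand-in for the real presence fetch; succeed() keeps the sketch
    # synchronous and runnable without a reactor.
    return defer.succeed([{"type": "m.presence"}])

def get_receipts():
    return defer.succeed([])

@defer.inlineCallbacks
def snapshot():
    # consumeErrors=True stops a failure being reported twice; the
    # errback unwraps the FirstError wrapper that gatherResults raises,
    # which is what the unwrapFirstError helper does above.
    presence, receipts = yield defer.gatherResults(
        [get_presence(), get_receipts()],
        consumeErrors=True,
    ).addErrback(lambda f: f.value.subFailure)
    defer.returnValue({"presence": presence, "receipts": receipts})

results = []
snapshot().addCallback(results.append)
assert results[0]["presence"][0]["type"] == "m.presence"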
@defer.inlineCallbacks
def _create_new_client_event(self, builder, prev_event_ids=None):
if prev_event_ids:
@@ -416,20 +853,6 @@ class MessageHandler(BaseHandler):
builder.room_id,
)
-# We want to limit the max number of prev events we point to in our
-# new event
-if len(latest_ret) > 10:
-# Sort by reverse depth, so we point to the most recent.
-latest_ret.sort(key=lambda a: -a[2])
-new_latest_ret = latest_ret[:5]
-# We also randomly point to some of the older events, to make
-# sure that we don't completely ignore the older events.
-if latest_ret[5:]:
-sample_size = min(5, len(latest_ret[5:]))
-new_latest_ret.extend(random.sample(latest_ret[5:], sample_size))
-latest_ret = new_latest_ret
if latest_ret:
depth = max([d for _, _, d in latest_ret]) + 1
else:
@@ -462,15 +885,14 @@ class MessageHandler(BaseHandler):
event = builder.build()
logger.debug(
-"Created event %s with state: %s",
-event.event_id, context.prev_state_ids,
+"Created event %s with current state: %s",
+event.event_id, context.current_state,
)
defer.returnValue(
(event, context,)
)
-@measure_func("handle_new_client_event")
@defer.inlineCallbacks
def handle_new_client_event(
self,
@@ -486,12 +908,12 @@ class MessageHandler(BaseHandler):
self.ratelimit(requester)
try:
-yield self.auth.check_from_context(event, context)
+self.auth.check(event, auth_events=context.current_state)
except AuthError as err:
logger.warn("Denying new event %r because %s", event, err)
raise err
-yield self.maybe_kick_guest_users(event, context)
+yield self.maybe_kick_guest_users(event, context.current_state.values())
if event.type == EventTypes.CanonicalAlias:
# Check the alias is actually valid (at this time at least)
@@ -519,15 +941,6 @@ class MessageHandler(BaseHandler):
e.sender == event.sender
)
-state_to_include_ids = [
-e_id
-for k, e_id in context.current_state_ids.items()
-if k[0] in self.hs.config.room_invite_state_types
-or k[0] == EventTypes.Member and k[1] == event.sender
-]
-state_to_include = yield self.store.get_events(state_to_include_ids)
event.unsigned["invite_room_state"] = [
{
"type": e.type,
@@ -535,7 +948,9 @@ class MessageHandler(BaseHandler):
"content": e.content,
"sender": e.sender,
}
-for e in state_to_include.values()
+for k, e in context.current_state.items()
+if e.type in self.hs.config.room_invite_state_types
+or is_inviter_member_event(e)
]
invitee = UserID.from_string(event.state_key)
@@ -557,14 +972,7 @@ class MessageHandler(BaseHandler):
)
if event.type == EventTypes.Redaction:
-auth_events_ids = yield self.auth.compute_auth_events(
-event, context.prev_state_ids, for_verification=True,
-)
-auth_events = yield self.store.get_events(auth_events_ids)
-auth_events = {
-(e.type, e.state_key): e for e in auth_events.values()
-}
-if self.auth.check_redaction(event, auth_events=auth_events):
+if self.auth.check_redaction(event, auth_events=context.current_state):
original_event = yield self.store.get_event(
event.redacts,
check_redacted=False,
@@ -578,7 +986,7 @@ class MessageHandler(BaseHandler):
"You don't have permission to redact events"
)
-if event.type == EventTypes.Create and context.prev_state_ids:
+if event.type == EventTypes.Create and context.current_state:
raise AuthError(
403,
"Changing the room create event is forbidden",
@@ -599,17 +1007,21 @@ class MessageHandler(BaseHandler):
event_stream_id, max_stream_id
)
-users_in_room = yield self.store.get_joined_users_from_context(event, context)
-destinations = [
-get_domain_from_id(user_id) for user_id in users_in_room
-if not self.hs.is_mine_id(user_id)
-]
+destinations = set()
+for k, s in context.current_state.items():
+try:
+if k[0] == EventTypes.Member:
+if s.content["membership"] == Membership.JOIN:
+destinations.add(get_domain_from_id(s.state_key))
+except SynapseError:
+logger.warn(
+"Failed to get destination from event %s", s.event_id
+)
@defer.inlineCallbacks
def _notify():
yield run_on_reactor()
-yield self.notifier.on_new_room_event(
+self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=extra_users
)
@@ -619,6 +1031,6 @@ class MessageHandler(BaseHandler):
# If invite, remove room_state from unsigned before sending.
event.unsigned.pop("invite_room_state", None)
-preserve_fn(federation_handler.handle_new_event)(
+federation_handler.handle_new_event(
event, destinations=destinations,
)
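The removed block above caps how many forward extremities a new event references: keep the five most recent by depth, plus a random sample of up to five older ones so stragglers still get referenced eventually. A standalone sketch of that strategy, assuming (event_id, hashes, depth) tuples as in the handler:

import random

def cap_prev_events(latest, max_total=10, newest=5, sampled=5):
    # Keep the `newest` most recent extremities (largest depth), plus a
    # random sample of the older ones so they are not starved forever.
    if len(latest) <= max_total:
        return latest
    latest = sorted(latest, key=lambda a: -a[2])
    capped = latest[:newest]
    older = latest[newest:]
    capped.extend(random.sample(older, min(sampled, len(older))))
    return capped

extremities = [("$ev%d" % i, {}, i) for i in range(20)]
capped = cap_prev_events(extremities)
assert len(capped) == 10
assert [e[2] for e in capped[:5]] == [19, 18, 17, 16, 15]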


@@ -52,11 +52,6 @@ bump_active_time_counter = metrics.register_counter("bump_active_time")
get_updates_counter = metrics.register_counter("get_updates", labels=["type"])
-notify_reason_counter = metrics.register_counter("notify_reason", labels=["reason"])
-state_transition_counter = metrics.register_counter(
-"state_transition", labels=["from", "to"]
-)
# If a user was last active in the last LAST_ACTIVE_GRANULARITY, consider them
# "currently_active"
@@ -93,8 +88,6 @@ class PresenceHandler(object):
self.notifier = hs.get_notifier()
self.federation = hs.get_replication_layer()
-self.state = hs.get_state_handler()
self.federation.register_edu_handler(
"m.presence", self.incoming_presence
)
@@ -196,13 +189,6 @@ class PresenceHandler(object):
5000,
)
-self.clock.call_later(
-60,
-self.clock.looping_call,
-self._persist_unpersisted_changes,
-60 * 1000,
-)
metrics.register_callback("wheel_timer_size", lambda: len(self.wheel_timer))
@defer.inlineCallbacks
@@ -217,7 +203,7 @@ class PresenceHandler(object):
is some spurious presence changes that will self-correct.
"""
logger.info(
-"Performing _on_shutdown. Persisting %d unpersisted changes",
+"Performing _on_shutdown. Persiting %d unpersisted changes",
len(self.user_to_current_state)
)
@@ -228,27 +214,6 @@ class PresenceHandler(object):
])
logger.info("Finished _on_shutdown")
-@defer.inlineCallbacks
-def _persist_unpersisted_changes(self):
-"""We periodically persist the unpersisted changes, as otherwise they
-may stack up and slow down shutdown times.
-"""
-logger.info(
-"Performing _persist_unpersisted_changes. Persisting %d unpersisted changes",
-len(self.unpersisted_users_changes)
-)
-unpersisted = self.unpersisted_users_changes
-self.unpersisted_users_changes = set()
-if unpersisted:
-yield self.store.update_presence([
-self.user_to_current_state[user_id]
-for user_id in unpersisted
-])
-logger.info("Finished _persist_unpersisted_changes")
@defer.inlineCallbacks
def _update_states(self, new_states):
"""Updates presence of users. Sets the appropriate timeouts. Pokes
@@ -265,12 +230,6 @@ class PresenceHandler(object):
to_notify = {} # Changes we want to notify everyone about
to_federation_ping = {} # These need sending keep-alives
-# Only bother handling the last presence change for each user
-new_states_dict = {}
-for new_state in new_states:
-new_states_dict[new_state.user_id] = new_state
-new_state = new_states_dict.values()
for new_state in new_states:
user_id = new_state.user_id
@@ -544,7 +503,7 @@ class PresenceHandler(object):
defer.returnValue(states)
@defer.inlineCallbacks
-def _get_interested_parties(self, states, calculate_remote_hosts=True):
+def _get_interested_parties(self, states):
"""Given a list of states return which entities (rooms, users, servers)
are interested in the given states.
@@ -567,17 +526,14 @@ class PresenceHandler(object):
users_to_states.setdefault(state.user_id, []).append(state)
hosts_to_states = {}
-if calculate_remote_hosts:
-for room_id, states in room_ids_to_states.items():
-local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
-if not local_states:
-continue
-users = yield self.state.get_current_user_in_room(room_id)
-hosts = set(get_domain_from_id(u) for u in users)
-for host in hosts:
-hosts_to_states.setdefault(host, []).extend(local_states)
+for room_id, states in room_ids_to_states.items():
+local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
+if not local_states:
+continue
+hosts = yield self.store.get_joined_hosts_for_room(room_id)
+for host in hosts:
+hosts_to_states.setdefault(host, []).extend(local_states)
for user_id, states in users_to_states.items():
local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
@@ -609,24 +565,24 @@ class PresenceHandler(object):
self._push_to_remotes(hosts_to_states)
-@defer.inlineCallbacks
-def notify_for_states(self, state, stream_id):
-parties = yield self._get_interested_parties([state])
-room_ids_to_states, users_to_states, hosts_to_states = parties
-self.notifier.on_new_event(
-"presence_key", stream_id, rooms=room_ids_to_states.keys(),
-users=[UserID.from_string(u) for u in users_to_states.keys()]
-)
def _push_to_remotes(self, hosts_to_states):
"""Sends state updates to remote servers.
Args:
hosts_to_states (dict): Mapping `server_name` -> `[UserPresenceState]`
"""
+now = self.clock.time_msec()
for host, states in hosts_to_states.items():
-self.federation.send_presence(host, states)
+self.federation.send_edu(
+destination=host,
+edu_type="m.presence",
+content={
+"push": [
+_format_user_presence_state(state, now)
+for state in states
+]
+}
+)
@defer.inlineCallbacks
def incoming_presence(self, origin, content):
@@ -647,13 +603,6 @@ class PresenceHandler(object):
)
continue
-if get_domain_from_id(user_id) != origin:
-logger.info(
-"Got presence update from %r with bad 'user_id': %r",
-origin, user_id,
-)
-continue
presence_state = push.get("presence", None)
if not presence_state:
logger.info(
@@ -713,17 +662,17 @@ class PresenceHandler(object):
defer.returnValue([
{
"type": "m.presence",
-"content": format_user_presence_state(state, now),
+"content": _format_user_presence_state(state, now),
}
for state in updates
])
else:
defer.returnValue([
-format_user_presence_state(state, now) for state in updates
+_format_user_presence_state(state, now) for state in updates
])
@defer.inlineCallbacks
-def set_state(self, target_user, state, ignore_status_msg=False):
+def set_state(self, target_user, state):
"""Set the presence state of the user.
"""
status_msg = state.get("status_msg", None)
@@ -740,13 +689,10 @@ class PresenceHandler(object):
prev_state = yield self.current_state_for_user(user_id)
new_fields = {
-"state": presence
+"state": presence,
+"status_msg": status_msg if presence != PresenceState.OFFLINE else None
}
-if not ignore_status_msg:
-msg = status_msg if presence != PresenceState.OFFLINE else None
-new_fields["status_msg"] = msg
if presence == PresenceState.ONLINE:
new_fields["last_active_ts"] = self.clock.time_msec()
@@ -765,13 +711,13 @@ class PresenceHandler(object):
# don't need to send to local clients here, as that is done as part
# of the event stream/sync.
# TODO: Only send to servers not already in the room.
-user_ids = yield self.state.get_current_user_in_room(room_id)
if self.is_mine(user):
state = yield self.current_state_for_user(user.to_string())
-hosts = set(get_domain_from_id(u) for u in user_ids)
+hosts = yield self.store.get_joined_hosts_for_room(room_id)
self._push_to_remotes({host: (state,) for host in hosts})
else:
+user_ids = yield self.store.get_users_in_room(room_id)
user_ids = filter(self.is_mine_id, user_ids)
states = yield self.current_state_for_users(user_ids)
@@ -947,38 +893,28 @@ class PresenceHandler(object):
def should_notify(old_state, new_state):
"""Decides if a presence state change should be sent to interested parties.
"""
-if old_state == new_state:
-return False
if old_state.status_msg != new_state.status_msg:
-notify_reason_counter.inc("status_msg_change")
-return True
-if old_state.state != new_state.state:
-notify_reason_counter.inc("state_change")
-state_transition_counter.inc(old_state.state, new_state.state)
return True
if old_state.state == PresenceState.ONLINE:
-if new_state.currently_active != old_state.currently_active:
-notify_reason_counter.inc("current_active_change")
+if new_state.state != PresenceState.ONLINE:
+# Always notify for online -> anything
return True
-if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY:
-# Only notify about last active bumps if we're not currently active
-if not new_state.currently_active:
-notify_reason_counter.inc("last_active_change_online")
-return True
+if new_state.currently_active != old_state.currently_active:
+return True
-elif new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY:
+if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY:
# Always notify for a transition where last active gets bumped.
-notify_reason_counter.inc("last_active_change_not_online")
+return True
+if old_state.state != new_state.state:
return True
return False
-def format_user_presence_state(state, now):
+def _format_user_presence_state(state, now):
"""Convert UserPresenceState to a format that can be sent down to clients
and to other servers.
"""
@@ -1005,7 +941,6 @@ class PresenceEventSource(object):
self.get_presence_handler = hs.get_presence_handler
self.clock = hs.get_clock()
self.store = hs.get_datastore()
-self.state = hs.get_state_handler()
@defer.inlineCallbacks
@log_function
@@ -1068,7 +1003,7 @@ class PresenceEventSource(object):
user_ids_to_check = set()
for room_id in room_ids:
-users = yield self.state.get_current_user_in_room(room_id)
+users = yield self.store.get_users_in_room(room_id)
user_ids_to_check.update(users)
user_ids_to_check.update(friends)
@@ -1091,7 +1026,7 @@ class PresenceEventSource(object):
defer.returnValue(([
{
"type": "m.presence",
-"content": format_user_presence_state(s, now),
+"content": _format_user_presence_state(s, now),
}
for s in updates.values()
if include_offline or s.state != PresenceState.OFFLINE
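The newer should_notify above weighs several signals: a status message or state change always notifies; while online, becoming active or inactive notifies; and a last-active bump beyond LAST_ACTIVE_GRANULARITY notifies only when it is actually user-visible. A condensed, runnable restatement with the prometheus counters stripped out (a sketch, not the handler itself):

from collections import namedtuple

LAST_ACTIVE_GRANULARITY = 60 * 1000  # ms, as in the handler

State = namedtuple(
    "State", ["state", "status_msg", "last_active_ts", "currently_active"]
)

def should_notify(old, new):
    if old == new:
        return False
    if old.status_msg != new.status_msg:
        return True
    if old.state != new.state:
        return True
    if old.state == "online":
        if new.currently_active != old.currently_active:
            return True
        if new.last_active_ts - old.last_active_ts > LAST_ACTIVE_GRANULARITY:
            # last-active bumps only matter while not currently_active;
            # otherwise the currently_active flag already covers it
            return not new.currently_active
    elif new.last_active_ts - old.last_active_ts > LAST_ACTIVE_GRANULARITY:
        return True
    return False

a = State("online", None, 0, True)
assert should_notify(a, a._replace(status_msg="afk"))
assert not should_notify(a, a._replace(last_active_ts=30 * 1000))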


@@ -13,15 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
from twisted.internet import defer
-import synapse.types
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
-from synapse.types import UserID
+from synapse.types import UserID, Requester
from ._base import BaseHandler
+import logging
logger = logging.getLogger(__name__)
@@ -65,13 +65,13 @@ class ProfileHandler(BaseHandler):
defer.returnValue(result["displayname"])
@defer.inlineCallbacks
-def set_displayname(self, target_user, requester, new_displayname, by_admin=False):
+def set_displayname(self, target_user, requester, new_displayname):
"""target_user is the user whose displayname is to be changed;
auth_user is the user attempting to make this change."""
if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this Home Server")
-if not by_admin and target_user != requester.user:
+if target_user != requester.user:
raise AuthError(400, "Cannot set another user's displayname")
if new_displayname == '':
@@ -111,13 +111,13 @@ class ProfileHandler(BaseHandler):
defer.returnValue(result["avatar_url"])
@defer.inlineCallbacks
-def set_avatar_url(self, target_user, requester, new_avatar_url, by_admin=False):
+def set_avatar_url(self, target_user, requester, new_avatar_url):
"""target_user is the user whose avatar_url is to be changed;
auth_user is the user attempting to make this change."""
if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this Home Server")
-if not by_admin and target_user != requester.user:
+if target_user != requester.user:
raise AuthError(400, "Cannot set another user's avatar_url")
yield self.store.set_profile_avatar_url(
@@ -165,9 +165,7 @@ class ProfileHandler(BaseHandler):
try:
# Assume the user isn't a guest because we don't let guests set
# profile or avatar data.
-# XXX why are we recreating `requester` here for each room?
-# what was wrong with the `requester` we were passed?
-requester = synapse.types.create_requester(user)
+requester = Requester(user, "", False)
yield handler.update_membership(
requester,
user,
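On the newer side, both setters gain a by_admin flag so that registration code can set a display name on a user's behalf. The permission rule they share, written as a single predicate (a sketch; the function name is illustrative):

def may_change_profile(requester_user_id, target_user_id, by_admin=False):
    # You may change a profile if it is your own, or if the change is
    # made with admin rights (by_admin=True, as get_or_create_user does).
    return by_admin or requester_user_id == target_user_id

assert may_change_profile("@a:hs", "@a:hs")
assert not may_change_profile("@a:hs", "@b:hs")
assert may_change_profile("@admin:hs", "@b:hs", by_admin=True)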


@@ -18,7 +18,6 @@ from ._base import BaseHandler
from twisted.internet import defer
from synapse.util.logcontext import PreserveLoggingContext
-from synapse.types import get_domain_from_id
import logging
@@ -38,7 +37,6 @@ class ReceiptsHandler(BaseHandler):
"m.receipt", self._received_remote_receipt
)
self.clock = self.hs.get_clock()
-self.state = hs.get_state_handler()
@defer.inlineCallbacks
def received_client_receipt(self, room_id, receipt_type, user_id,
@@ -135,8 +133,7 @@ class ReceiptsHandler(BaseHandler):
event_ids = receipt["event_ids"]
data = receipt["data"]
-users = yield self.state.get_current_user_in_room(room_id)
-remotedomains = set(get_domain_from_id(u) for u in users)
+remotedomains = yield self.store.get_joined_hosts_for_room(room_id)
remotedomains = remotedomains.copy()
remotedomains.discard(self.server_name)
@@ -156,7 +153,6 @@ class ReceiptsHandler(BaseHandler):
}
},
},
-key=(room_id, receipt_type, user_id),
)
@defer.inlineCallbacks
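The older side asks the store for the room's joined hosts directly; the newer side builds the same set from the current joined user IDs. A minimal version of that user-ID-to-domain step, along the lines of synapse.types.get_domain_from_id (a Matrix user ID has the shape "@localpart:domain"):

def get_domain_from_id(user_id):
    # Split on the first colon only, since the domain may contain a port.
    return user_id.split(":", 1)[1]

users = ["@alice:example.org", "@bob:matrix.org", "@carol:example.org"]
remote_domains = set(get_domain_from_id(u) for u in users)
remote_domains.discard("example.org")  # drop our own server_name
assert remote_domains == set(["matrix.org"])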


@@ -14,18 +14,18 @@
# limitations under the License.
"""Contains functions for registering clients."""
-import logging
-import urllib
from twisted.internet import defer
+from synapse.types import UserID, Requester
from synapse.api.errors import (
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
)
-from synapse.http.client import CaptchaServerHttpClient
-from synapse.types import UserID
-from synapse.util.async import run_on_reactor
from ._base import BaseHandler
+from synapse.util.async import run_on_reactor
+from synapse.http.client import CaptchaServerHttpClient
+import logging
+import urllib
logger = logging.getLogger(__name__)
@@ -52,13 +52,6 @@ class RegistrationHandler(BaseHandler):
Codes.INVALID_USERNAME
)
-if localpart[0] == '_':
-raise SynapseError(
-400,
-"User ID may not begin with _",
-Codes.INVALID_USERNAME
-)
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
@@ -106,13 +99,8 @@ class RegistrationHandler(BaseHandler):
localpart : The local part of the user ID to register. If None,
one will be generated.
password (str) : The password to assign to this user so they can
login again. This can be None which means they cannot login again
via a password (e.g. the user is an application service user).
-generate_token (bool): Whether a new access token should be
-generated. Having this be True should be considered deprecated,
-since it offers no means of associating a device_id with the
-access_token. Instead you should call auth_handler.issue_access_token
-after registration.
Returns:
A tuple of (user_id, access_token).
Raises:
@@ -193,7 +181,7 @@ class RegistrationHandler(BaseHandler):
def appservice_register(self, user_localpart, as_token):
user = UserID(user_localpart, self.hs.hostname)
user_id = user.to_string()
-service = self.store.get_app_service_by_token(as_token)
+service = yield self.store.get_app_service_by_token(as_token)
if not service:
raise AuthError(403, "Invalid application service token.")
if not service.is_interested_in_user(user_id):
@@ -208,13 +196,15 @@ class RegistrationHandler(BaseHandler):
user_id, allowed_appservice=service
)
+token = self.auth_handler().generate_access_token(user_id)
yield self.store.register(
user_id=user_id,
+token=token,
password_hash="",
appservice_id=service_id,
create_profile_with_localpart=user.localpart,
)
-defer.returnValue(user_id)
+defer.returnValue((user_id, token))
@defer.inlineCallbacks
def check_recaptcha(self, ip, private_key, challenge, response):
@@ -304,10 +294,11 @@ class RegistrationHandler(BaseHandler):
# XXX: This should be a deferred list, shouldn't it?
yield identity_handler.bind_threepid(c, user_id)
+@defer.inlineCallbacks
def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None):
# valid user IDs must not clash with any user ID namespaces claimed by
# application services.
-services = self.store.get_app_services()
+services = yield self.store.get_app_services()
interested_services = [
s for s in services
if s.is_interested_in_user(user_id)
@@ -369,7 +360,7 @@ class RegistrationHandler(BaseHandler):
defer.returnValue(data)
@defer.inlineCallbacks
-def get_or_create_user(self, requester, localpart, displayname, duration_in_ms,
+def get_or_create_user(self, localpart, displayname, duration_seconds,
password_hash=None):
"""Creates a new user if the user does not exist,
else revokes all previous access tokens and generates a new one.
@@ -399,8 +390,8 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
-token = self.auth_handler().generate_access_token(
-user_id, None, duration_in_ms)
+token = self.auth_handler().generate_short_term_login_token(
+user_id, duration_seconds)
if need_register:
yield self.store.register(
@@ -417,7 +408,7 @@ class RegistrationHandler(BaseHandler):
logger.info("setting user display name: %s -> %s", user_id, displayname)
profile_handler = self.hs.get_handlers().profile_handler
yield profile_handler.set_displayname(
-user, requester, displayname, by_admin=True,
+user, Requester(user, token, False), displayname
)
defer.returnValue((user_id, token))
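check_user_id_not_appservice_exclusive leans on each application service's is_interested_in_user, which tests the user ID against the regex namespaces in the AS registration. A toy illustration of that test (the pattern below is made up for the example):

import re

# An application service claims user namespaces as regexes in its
# registration file; this one claims all IRC-bridged users.
namespaces = [re.compile(r"@irc_.*:example\.org")]

def is_interested_in_user(user_id):
    return any(ns.match(user_id) for ns in namespaces)

assert is_interested_in_user("@irc_alice:example.org")
assert not is_interested_in_user("@alice:example.org")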


@@ -20,10 +20,12 @@ from ._base import BaseHandler
from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken
from synapse.api.constants import (
-EventTypes, JoinRules, RoomCreationPreset
+EventTypes, JoinRules, RoomCreationPreset, Membership,
)
from synapse.api.errors import AuthError, StoreError, SynapseError
from synapse.util import stringutils
+from synapse.util.async import concurrently_execute
+from synapse.util.caches.response_cache import ResponseCache
from synapse.visibility import filter_events_for_client
from collections import OrderedDict
@@ -34,6 +36,8 @@ import string
logger = logging.getLogger(__name__)
+REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
id_server_scheme = "https://"
@@ -192,11 +196,6 @@ class RoomCreationHandler(BaseHandler):
},
ratelimit=False)
-content = {}
-is_direct = config.get("is_direct", None)
-if is_direct:
-content["is_direct"] = is_direct
for invitee in invite_list:
yield room_member_handler.update_membership(
requester,
@@ -204,7 +203,6 @@ class RoomCreationHandler(BaseHandler):
room_id,
"invite",
ratelimit=False,
-content=content,
)
for invite_3pid in invite_3pid_list:
@@ -344,6 +342,149 @@ class RoomCreationHandler(BaseHandler):
)
class RoomListHandler(BaseHandler):
def __init__(self, hs):
super(RoomListHandler, self).__init__(hs)
self.response_cache = ResponseCache()
self.remote_list_request_cache = ResponseCache()
self.remote_list_cache = {}
self.fetch_looping_call = hs.get_clock().looping_call(
self.fetch_all_remote_lists, REMOTE_ROOM_LIST_POLL_INTERVAL
)
self.fetch_all_remote_lists()
def get_local_public_room_list(self):
result = self.response_cache.get(())
if not result:
result = self.response_cache.set((), self._get_public_room_list())
return result
@defer.inlineCallbacks
def _get_public_room_list(self):
room_ids = yield self.store.get_public_room_ids()
results = []
@defer.inlineCallbacks
def handle_room(room_id):
current_state = yield self.state_handler.get_current_state(room_id)
# Double check that this is actually a public room.
join_rules_event = current_state.get((EventTypes.JoinRules, ""))
if join_rules_event:
join_rule = join_rules_event.content.get("join_rule", None)
if join_rule and join_rule != JoinRules.PUBLIC:
defer.returnValue(None)
result = {"room_id": room_id}
num_joined_users = len([
1 for _, event in current_state.items()
if event.type == EventTypes.Member
and event.membership == Membership.JOIN
])
if num_joined_users == 0:
return
result["num_joined_members"] = num_joined_users
aliases = yield self.store.get_aliases_for_room(room_id)
if aliases:
result["aliases"] = aliases
name_event = current_state.get((EventTypes.Name, ""))
if name_event:
name = name_event.content.get("name", None)
if name:
result["name"] = name
topic_event = current_state.get((EventTypes.Topic, ""))
if topic_event:
topic = topic_event.content.get("topic", None)
if topic:
result["topic"] = topic
canonical_event = current_state.get((EventTypes.CanonicalAlias, ""))
if canonical_event:
canonical_alias = canonical_event.content.get("alias", None)
if canonical_alias:
result["canonical_alias"] = canonical_alias
visibility_event = current_state.get((EventTypes.RoomHistoryVisibility, ""))
visibility = None
if visibility_event:
visibility = visibility_event.content.get("history_visibility", None)
result["world_readable"] = visibility == "world_readable"
guest_event = current_state.get((EventTypes.GuestAccess, ""))
guest = None
if guest_event:
guest = guest_event.content.get("guest_access", None)
result["guest_can_join"] = guest == "can_join"
avatar_event = current_state.get(("m.room.avatar", ""))
if avatar_event:
avatar_url = avatar_event.content.get("url", None)
if avatar_url:
result["avatar_url"] = avatar_url
results.append(result)
yield concurrently_execute(handle_room, room_ids, 10)
# FIXME (erikj): START is no longer a valid value
defer.returnValue({"start": "START", "end": "END", "chunk": results})
@defer.inlineCallbacks
def fetch_all_remote_lists(self):
deferred = self.hs.get_replication_layer().get_public_rooms(
self.hs.config.secondary_directory_servers
)
self.remote_list_request_cache.set((), deferred)
self.remote_list_cache = yield deferred
@defer.inlineCallbacks
def get_aggregated_public_room_list(self):
"""
Get the public room list from this server and the servers
specified in the secondary_directory_servers config option.
XXX: Pagination...
"""
# We return the results from our cache which is updated by a looping call,
# unless we're missing a cache entry, in which case wait for the result
# of the fetch if there's one in progress. If not, omit that server.
wait = False
for s in self.hs.config.secondary_directory_servers:
if s not in self.remote_list_cache:
logger.warn("No cached room list from %s: waiting for fetch", s)
wait = True
break
if wait and self.remote_list_request_cache.get(()):
yield self.remote_list_request_cache.get(())
public_rooms = yield self.get_local_public_room_list()
# keep track of which room IDs we've seen so we can de-dup
room_ids = set()
# tag all the ones in our list with our server name.
# Also add them to the de-duping set
for room in public_rooms['chunk']:
room["server_name"] = self.hs.hostname
room_ids.add(room["room_id"])
# Now add the results from federation
for server_name, server_result in self.remote_list_cache.items():
for room in server_result["chunk"]:
if room["room_id"] not in room_ids:
room["server_name"] = server_name
public_rooms["chunk"].append(room)
room_ids.add(room["room_id"])
defer.returnValue(public_rooms)
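get_aggregated_public_room_list merges the local list with the cached federated lists, tagging each room with the server it came from and de-duplicating by room ID so locally-known rooms win. The merge step, pulled out into a runnable sketch:

def merge_room_lists(local_chunk, remote_chunks, local_server):
    # Tag each local room with our server name, then append federated
    # rooms we have not already seen, tagged with their origin server.
    seen = set()
    merged = []
    for room in local_chunk:
        room["server_name"] = local_server
        seen.add(room["room_id"])
        merged.append(room)
    for server_name, chunk in remote_chunks.items():
        for room in chunk:
            if room["room_id"] not in seen:
                room["server_name"] = server_name
                seen.add(room["room_id"])
                merged.append(room)
    return merged

merged = merge_room_lists(
    [{"room_id": "!a:hs"}],
    {"other.org": [{"room_id": "!a:hs"}, {"room_id": "!b:other.org"}]},
    "hs",
)
assert [r["room_id"] for r in merged] == ["!a:hs", "!b:other.org"]
assert merged[1]["server_name"] == "other.org"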
class RoomContextHandler(BaseHandler):
@defer.inlineCallbacks
def get_event_context(self, user, room_id, event_id, limit, is_guest):
@@ -437,7 +578,7 @@ class RoomEventSource(object):
logger.warn("Stream has topological part!!!! %r", from_key)
from_key = "s%s" % (from_token.stream,)
-app_service = self.store.get_app_service_by_user_id(
+app_service = yield self.store.get_app_service_by_user_id(
user.to_string()
)
if app_service:
@@ -475,11 +616,8 @@ class RoomEventSource(object):
defer.returnValue((events, end_key))
-def get_current_key(self):
-return self.store.get_room_events_max_id()
+def get_current_key(self, direction='f'):
+return self.store.get_room_events_max_id(direction)
-def get_current_key_for_room(self, room_id):
-return self.store.get_room_events_max_id(room_id)
@defer.inlineCallbacks
def get_pagination_rows(self, user, config, key):


@@ -1,403 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.constants import (
EventTypes, JoinRules,
)
from synapse.util.async import concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
from collections import namedtuple
from unpaddedbase64 import encode_base64, decode_base64
import logging
import msgpack
logger = logging.getLogger(__name__)
REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
class RoomListHandler(BaseHandler):
def __init__(self, hs):
super(RoomListHandler, self).__init__(hs)
self.response_cache = ResponseCache(hs)
self.remote_response_cache = ResponseCache(hs, timeout_ms=30 * 1000)
def get_local_public_room_list(self, limit=None, since_token=None,
search_filter=None):
if search_filter:
# We explicitly don't bother caching searches.
return self._get_public_room_list(limit, since_token, search_filter)
result = self.response_cache.get((limit, since_token))
if not result:
result = self.response_cache.set(
(limit, since_token),
self._get_public_room_list(limit, since_token)
)
return result
@defer.inlineCallbacks
def _get_public_room_list(self, limit=None, since_token=None,
search_filter=None):
if since_token and since_token != "END":
since_token = RoomListNextBatch.from_token(since_token)
else:
since_token = None
rooms_to_order_value = {}
rooms_to_num_joined = {}
rooms_to_latest_event_ids = {}
newly_visible = []
newly_unpublished = []
if since_token:
stream_token = since_token.stream_ordering
current_public_id = yield self.store.get_current_public_room_stream_id()
public_room_stream_id = since_token.public_room_stream_id
newly_visible, newly_unpublished = yield self.store.get_public_room_changes(
public_room_stream_id, current_public_id
)
else:
stream_token = yield self.store.get_room_max_stream_ordering()
public_room_stream_id = yield self.store.get_current_public_room_stream_id()
room_ids = yield self.store.get_public_room_ids_at_stream_id(
public_room_stream_id
)
# We want to return rooms in a particular order: the number of joined
# users. We then arbitrarily use the room_id as a tie breaker.
@defer.inlineCallbacks
def get_order_for_room(room_id):
latest_event_ids = rooms_to_latest_event_ids.get(room_id, None)
if not latest_event_ids:
latest_event_ids = yield self.store.get_forward_extremeties_for_room(
room_id, stream_token
)
rooms_to_latest_event_ids[room_id] = latest_event_ids
if not latest_event_ids:
return
joined_users = yield self.state_handler.get_current_user_in_room(
room_id, latest_event_ids,
)
num_joined_users = len(joined_users)
rooms_to_num_joined[room_id] = num_joined_users
if num_joined_users == 0:
return
# We want larger rooms to be first, hence negating num_joined_users
rooms_to_order_value[room_id] = (-num_joined_users, room_id)
yield concurrently_execute(get_order_for_room, room_ids, 10)
sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1])
sorted_rooms = [room_id for room_id, _ in sorted_entries]
# `sorted_rooms` should now be a list of all public room ids that is
# stable across pagination. Therefore, we can use indices into this
# list as our pagination tokens.
# Filter out rooms that we don't want to return
rooms_to_scan = [
r for r in sorted_rooms
if r not in newly_unpublished and rooms_to_num_joined[r] > 0
]
total_room_count = len(rooms_to_scan)
if since_token:
# Filter out rooms we've already returned previously
# `since_token.current_limit` is the index of the last room we
# sent down, so we exclude it and everything before/after it.
if since_token.direction_is_forward:
rooms_to_scan = rooms_to_scan[since_token.current_limit + 1:]
else:
rooms_to_scan = rooms_to_scan[:since_token.current_limit]
rooms_to_scan.reverse()
# Actually generate the entries. _generate_room_entry will append to
# chunk but will stop if len(chunk) > limit
chunk = []
if limit and not search_filter:
step = limit + 1
for i in xrange(0, len(rooms_to_scan), step):
# We iterate here because in the vast majority of cases we'll stop
# at the first iteration, but occasionally _generate_room_entry
# won't append to the chunk and so we need to loop again.
# We don't want to scan over the entire range either as that
# would potentially waste a lot of work.
yield concurrently_execute(
lambda r: self._generate_room_entry(
r, rooms_to_num_joined[r],
chunk, limit, search_filter
),
rooms_to_scan[i:i + step], 10
)
if len(chunk) >= limit + 1:
break
else:
yield concurrently_execute(
lambda r: self._generate_room_entry(
r, rooms_to_num_joined[r],
chunk, limit, search_filter
),
rooms_to_scan, 5
)
chunk.sort(key=lambda e: (-e["num_joined_members"], e["room_id"]))
# Work out the new limit of the batch for pagination, or None if we
# know there are no more results that would be returned.
# i.e., [since_token.current_limit..new_limit] is the batch of rooms
# we've returned (or the reverse if we paginated backwards)
# We tried to pull out limit + 1 rooms above, so if we have <= limit
# then we know there are no more results to return
new_limit = None
if chunk and (not limit or len(chunk) > limit):
if not since_token or since_token.direction_is_forward:
if limit:
chunk = chunk[:limit]
last_room_id = chunk[-1]["room_id"]
else:
if limit:
chunk = chunk[-limit:]
last_room_id = chunk[0]["room_id"]
new_limit = sorted_rooms.index(last_room_id)
results = {
"chunk": chunk,
"total_room_count_estimate": total_room_count,
}
if since_token:
results["new_rooms"] = bool(newly_visible)
if not since_token or since_token.direction_is_forward:
if new_limit is not None:
results["next_batch"] = RoomListNextBatch(
stream_ordering=stream_token,
public_room_stream_id=public_room_stream_id,
current_limit=new_limit,
direction_is_forward=True,
).to_token()
if since_token:
results["prev_batch"] = since_token.copy_and_replace(
direction_is_forward=False,
current_limit=since_token.current_limit + 1,
).to_token()
else:
if new_limit is not None:
results["prev_batch"] = RoomListNextBatch(
stream_ordering=stream_token,
public_room_stream_id=public_room_stream_id,
current_limit=new_limit,
direction_is_forward=False,
).to_token()
if since_token:
results["next_batch"] = since_token.copy_and_replace(
direction_is_forward=True,
current_limit=since_token.current_limit - 1,
).to_token()
defer.returnValue(results)
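The pagination above works on indices into a stable, sorted room list: current_limit is the index of the last room already sent, so the next batch is the slice strictly after it and the previous batch the slice strictly before it, walked in reverse. A toy version of the slicing:

def page(sorted_rooms, current_limit=None, forward=True, limit=3):
    # `current_limit` is the index of the last room already returned.
    rooms = list(sorted_rooms)
    if current_limit is not None:
        if forward:
            rooms = rooms[current_limit + 1:]
        else:
            rooms = rooms[:current_limit]
            rooms.reverse()
    return rooms[:limit]

rooms = ["!r%d" % i for i in range(10)]
assert page(rooms) == ["!r0", "!r1", "!r2"]
assert page(rooms, current_limit=2) == ["!r3", "!r4", "!r5"]
assert page(rooms, current_limit=3, forward=False) == ["!r2", "!r1", "!r0"]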
@defer.inlineCallbacks
def _generate_room_entry(self, room_id, num_joined_users, chunk, limit,
search_filter):
if limit and len(chunk) > limit + 1:
# We've already got enough, so lets just drop it.
return
result = {
"room_id": room_id,
"num_joined_members": num_joined_users,
}
current_state_ids = yield self.state_handler.get_current_state_ids(room_id)
event_map = yield self.store.get_events([
event_id for key, event_id in current_state_ids.items()
if key[0] in (
EventTypes.JoinRules,
EventTypes.Name,
EventTypes.Topic,
EventTypes.CanonicalAlias,
EventTypes.RoomHistoryVisibility,
EventTypes.GuestAccess,
"m.room.avatar",
)
])
current_state = {
(ev.type, ev.state_key): ev
for ev in event_map.values()
}
# Double check that this is actually a public room.
join_rules_event = current_state.get((EventTypes.JoinRules, ""))
if join_rules_event:
join_rule = join_rules_event.content.get("join_rule", None)
if join_rule and join_rule != JoinRules.PUBLIC:
defer.returnValue(None)
aliases = yield self.store.get_aliases_for_room(room_id)
if aliases:
result["aliases"] = aliases
name_event = current_state.get((EventTypes.Name, ""))
if name_event:
name = name_event.content.get("name", None)
if name:
result["name"] = name
topic_event = current_state.get((EventTypes.Topic, ""))
if topic_event:
topic = topic_event.content.get("topic", None)
if topic:
result["topic"] = topic
canonical_event = current_state.get((EventTypes.CanonicalAlias, ""))
if canonical_event:
canonical_alias = canonical_event.content.get("alias", None)
if canonical_alias:
result["canonical_alias"] = canonical_alias
visibility_event = current_state.get((EventTypes.RoomHistoryVisibility, ""))
visibility = None
if visibility_event:
visibility = visibility_event.content.get("history_visibility", None)
result["world_readable"] = visibility == "world_readable"
guest_event = current_state.get((EventTypes.GuestAccess, ""))
guest = None
if guest_event:
guest = guest_event.content.get("guest_access", None)
result["guest_can_join"] = guest == "can_join"
avatar_event = current_state.get(("m.room.avatar", ""))
if avatar_event:
avatar_url = avatar_event.content.get("url", None)
if avatar_url:
result["avatar_url"] = avatar_url
if _matches_room_entry(result, search_filter):
chunk.append(result)
@defer.inlineCallbacks
def get_remote_public_room_list(self, server_name, limit=None, since_token=None,
search_filter=None):
if search_filter:
# We currently don't support searching across federation, so we have
# to do it manually without pagination
limit = None
since_token = None
res = yield self._get_remote_list_cached(
server_name, limit=limit, since_token=since_token,
)
if search_filter:
res = {"chunk": [
entry
for entry in list(res.get("chunk", []))
if _matches_room_entry(entry, search_filter)
]}
defer.returnValue(res)
def _get_remote_list_cached(self, server_name, limit=None, since_token=None,
search_filter=None):
repl_layer = self.hs.get_replication_layer()
if search_filter:
# We can't cache when asking for search
return repl_layer.get_public_rooms(
server_name, limit=limit, since_token=since_token,
search_filter=search_filter,
)
result = self.remote_response_cache.get((server_name, limit, since_token))
if not result:
result = self.remote_response_cache.set(
(server_name, limit, since_token),
repl_layer.get_public_rooms(
server_name, limit=limit, since_token=since_token,
search_filter=search_filter,
)
)
return result
class RoomListNextBatch(namedtuple("RoomListNextBatch", (
"stream_ordering", # stream_ordering of the first public room list
"public_room_stream_id", # public room stream id for first public room list
"current_limit", # The number of previous rooms returned
"direction_is_forward", # Bool if this is a next_batch, false if prev_batch
))):
KEY_DICT = {
"stream_ordering": "s",
"public_room_stream_id": "p",
"current_limit": "n",
"direction_is_forward": "d",
}
REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()}
@classmethod
def from_token(cls, token):
return RoomListNextBatch(**{
cls.REVERSE_KEY_DICT[key]: val
for key, val in msgpack.loads(decode_base64(token)).items()
})
def to_token(self):
return encode_base64(msgpack.dumps({
self.KEY_DICT[key]: val
for key, val in self._asdict().items()
}))
def copy_and_replace(self, **kwds):
return self._replace(
**kwds
)
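A round trip of this token format, using the same libraries the file imports (msgpack plus unpaddedbase64): the batch fields are shortened to single-letter keys via KEY_DICT, msgpacked, then base64'd without padding to keep the token compact.

import msgpack
from unpaddedbase64 import decode_base64, encode_base64

batch = {"s": 12345, "p": 67, "n": 10, "d": True}
token = encode_base64(msgpack.dumps(batch))
# raw=False (msgpack >= 0.5) decodes the keys back to str; older
# versions of msgpack return byte keys instead.
assert msgpack.loads(decode_base64(token), raw=False) == batch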
def _matches_room_entry(room_entry, search_filter):
if search_filter and search_filter.get("generic_search_term", None):
generic_search_term = search_filter["generic_search_term"].upper()
if generic_search_term in room_entry.get("name", "").upper():
return True
elif generic_search_term in room_entry.get("topic", "").upper():
return True
elif generic_search_term in room_entry.get("canonical_alias", "").upper():
return True
else:
return True
return False


@@ -14,22 +14,24 @@
# limitations under the License.
-import logging
-from signedjson.key import decode_verify_key_bytes
-from signedjson.sign import verify_signed_json
from twisted.internet import defer
-from unpaddedbase64 import decode_base64
-import synapse.types
+from ._base import BaseHandler
+from synapse.types import UserID, RoomID, Requester
from synapse.api.constants import (
EventTypes, Membership,
)
from synapse.api.errors import AuthError, SynapseError, Codes
-from synapse.types import UserID, RoomID
from synapse.util.async import Linearizer
from synapse.util.distributor import user_left_room, user_joined_room
-from ._base import BaseHandler
+from signedjson.sign import verify_signed_json
+from signedjson.key import decode_verify_key_bytes
+from unpaddedbase64 import decode_base64
+import logging
logger = logging.getLogger(__name__)
@@ -59,13 +61,10 @@ class RoomMemberHandler(BaseHandler):
prev_event_ids,
txn_id=None,
ratelimit=True,
-content=None,
):
-if content is None:
-content = {}
msg_handler = self.hs.get_handlers().message_handler
-content["membership"] = membership
+content = {"membership": membership}
if requester.is_guest:
content["kind"] = "guest"
@@ -85,12 +84,6 @@ class RoomMemberHandler(BaseHandler):
prev_event_ids=prev_event_ids,
)
-# Check if this event matches the previous membership event for the user.
-duplicate = yield msg_handler.deduplicate_state_event(event, context)
-if duplicate is not None:
-# Discard the new event since this membership change is a no-op.
-return
yield msg_handler.handle_new_client_event(
requester,
event,
@@ -99,26 +92,20 @@ class RoomMemberHandler(BaseHandler):
ratelimit=ratelimit,
)
-prev_member_event_id = context.prev_state_ids.get(
+prev_member_event = context.current_state.get(
(EventTypes.Member, target.to_string()),
None
)
if event.membership == Membership.JOIN:
-# Only fire user_joined_room if the user has actually joined the
-# room. Don't bother if the user is just changing their profile
-# info.
-newly_joined = True
-if prev_member_event_id:
-prev_member_event = yield self.store.get_event(prev_member_event_id)
-newly_joined = prev_member_event.membership != Membership.JOIN
-if newly_joined:
+if not prev_member_event or prev_member_event.membership != Membership.JOIN:
+# Only fire user_joined_room if the user has actually joined the
+# room. Don't bother if the user is just changing their profile
+# info.
yield user_joined_room(self.distributor, target, room_id)
elif event.membership == Membership.LEAVE:
-if prev_member_event_id:
-prev_member_event = yield self.store.get_event(prev_member_event_id)
-if prev_member_event.membership == Membership.JOIN:
-user_left_room(self.distributor, target, room_id)
+if prev_member_event and prev_member_event.membership == Membership.JOIN:
+user_left_room(self.distributor, target, room_id)
@defer.inlineCallbacks
def remote_join(self, remote_room_hosts, room_id, user, content):
@@ -155,9 +142,8 @@ class RoomMemberHandler(BaseHandler):
remote_room_hosts=None,
third_party_signed=None,
ratelimit=True,
-content=None,
):
-key = (room_id,)
+key = (target, room_id,)
with (yield self.member_linearizer.queue(key)):
result = yield self._update_membership(
@@ -169,7 +155,6 @@ class RoomMemberHandler(BaseHandler):
remote_room_hosts=remote_room_hosts,
third_party_signed=third_party_signed,
ratelimit=ratelimit,
-content=content,
)
defer.returnValue(result)
@@ -185,11 +170,7 @@ class RoomMemberHandler(BaseHandler):
remote_room_hosts=None,
third_party_signed=None,
ratelimit=True,
-content=None,
):
-if content is None:
-content = {}
effective_membership_state = action
if action in ["kick", "unban"]:
effective_membership_state = "leave"
@@ -207,32 +188,29 @@ class RoomMemberHandler(BaseHandler):
remote_room_hosts = []
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
-current_state_ids = yield self.state_handler.get_current_state_ids(
+current_state = yield self.state_handler.get_current_state(
room_id, latest_event_ids=latest_event_ids,
)
-old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
-if old_state_id:
-old_state = yield self.store.get_event(old_state_id, allow_none=True)
-old_membership = old_state.content.get("membership") if old_state else None
-if action == "unban" and old_membership != "ban":
-raise SynapseError(
-403,
-"Cannot unban user who was not banned"
-" (membership=%s)" % old_membership,
-errcode=Codes.BAD_STATE
-)
-if old_membership == "ban" and action != "unban":
-raise SynapseError(
-403,
-"Cannot %s user who was banned" % (action,),
-errcode=Codes.BAD_STATE
-)
+old_state = current_state.get((EventTypes.Member, target.to_string()))
+old_membership = old_state.content.get("membership") if old_state else None
+if action == "unban" and old_membership != "ban":
+raise SynapseError(
+403,
+"Cannot unban user who was not banned (membership=%s)" % old_membership,
+errcode=Codes.BAD_STATE
+)
+if old_membership == "ban" and action != "unban":
+raise SynapseError(
+403,
+"Cannot %s user who was banned" % (action,),
+errcode=Codes.BAD_STATE
+)
-is_host_in_room = yield self._is_host_in_room(current_state_ids)
+is_host_in_room = self.is_host_in_room(current_state)
if effective_membership_state == Membership.JOIN:
-if requester.is_guest and not self._can_guest_join(current_state_ids):
+if requester.is_guest and not self._can_guest_join(current_state):
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
@@ -242,7 +220,7 @@ class RoomMemberHandler(BaseHandler):
if inviter and not self.hs.is_mine(inviter):
remote_room_hosts.append(inviter.domain)
-content["membership"] = Membership.JOIN
+content = {"membership": Membership.JOIN}
profile = self.hs.get_handlers().profile_handler
content["displayname"] = yield profile.get_displayname(target)
@@ -296,7 +274,6 @@ class RoomMemberHandler(BaseHandler):
txn_id=txn_id,
ratelimit=ratelimit,
prev_event_ids=latest_event_ids,
-content=content,
)
@defer.inlineCallbacks
@@ -338,20 +315,18 @@ class RoomMemberHandler(BaseHandler):
)
assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
else:
-requester = synapse.types.create_requester(target_user)
+requester = Requester(target_user, None, False)
message_handler = self.hs.get_handlers().message_handler
-prev_event = yield message_handler.deduplicate_state_event(event, context)
+prev_event = message_handler.deduplicate_state_event(event, context)
if prev_event is not None:
return
if event.membership == Membership.JOIN:
-if requester.is_guest:
-guest_can_join = yield self._can_guest_join(context.prev_state_ids)
-if not guest_can_join:
-# This should be an auth check, but guests are a local concept,
-# so don't really fit into the general auth process.
-raise AuthError(403, "Guest access not allowed")
+if requester.is_guest and not self._can_guest_join(context.current_state):
+# This should be an auth check, but guests are a local concept,
+# so don't really fit into the general auth process.
+raise AuthError(403, "Guest access not allowed")
yield message_handler.handle_new_client_event(
requester,
@@ -361,39 +336,27 @@ class RoomMemberHandler(BaseHandler):
ratelimit=ratelimit,
)
-prev_member_event_id = context.prev_state_ids.get(
-(EventTypes.Member, event.state_key),
+prev_member_event = context.current_state.get(
+(EventTypes.Member, target_user.to_string()),
None
)
if event.membership == Membership.JOIN:
-# Only fire user_joined_room if the user has actually joined the
-# room. Don't bother if the user is just changing their profile
-# info.
-newly_joined = True
-if prev_member_event_id:
-prev_member_event = yield self.store.get_event(prev_member_event_id)
-newly_joined = prev_member_event.membership != Membership.JOIN
-if newly_joined:
+if not prev_member_event or prev_member_event.membership != Membership.JOIN:
+# Only fire user_joined_room if the user has actually joined the
+# room. Don't bother if the user is just changing their profile
+# info.
yield user_joined_room(self.distributor, target_user, room_id)
elif event.membership == Membership.LEAVE:
-if prev_member_event_id:
-prev_member_event = yield self.store.get_event(prev_member_event_id)
-if prev_member_event.membership == Membership.JOIN:
-user_left_room(self.distributor, target_user, room_id)
+if prev_member_event and prev_member_event.membership == Membership.JOIN:
+user_left_room(self.distributor, target_user, room_id)
-@defer.inlineCallbacks
-def _can_guest_join(self, current_state_ids):
+def _can_guest_join(self, current_state):
"""
Returns whether a guest can join a room based on its current state.
"""
-guest_access_id = current_state_ids.get((EventTypes.GuestAccess, ""), None)
-if not guest_access_id:
-defer.returnValue(False)
-guest_access = yield self.store.get_event(guest_access_id)
-defer.returnValue(
+guest_access = current_state.get((EventTypes.GuestAccess, ""), None)
+return (
guest_access
and guest_access.content
and "guest_access" in guest_access.content
@@ -712,24 +675,3 @@ class RoomMemberHandler(BaseHandler):
if membership:
yield self.store.forget(user_id, room_id)
-@defer.inlineCallbacks
-def _is_host_in_room(self, current_state_ids):
-# Have we just created the room, and is this about to be the very
-# first member event?
-create_event_id = current_state_ids.get(("m.room.create", ""))
-if len(current_state_ids) == 1 and create_event_id:
-defer.returnValue(self.hs.is_mine_id(create_event_id))
-for (etype, state_key), event_id in current_state_ids.items():
-if etype != EventTypes.Member or not self.hs.is_mine_id(state_key):
continue
event = yield self.store.get_event(event_id, allow_none=True)
if not event:
continue
if event.membership == Membership.JOIN:
defer.returnValue(True)
defer.returnValue(False)


@@ -35,7 +35,6 @@ SyncConfig = collections.namedtuple("SyncConfig", [
     "filter_collection",
     "is_guest",
     "request_key",
-    "device_id",
 ])

@@ -114,7 +113,6 @@ class SyncResult(collections.namedtuple("SyncResult", [
     "joined",  # JoinedSyncResult for each joined room.
     "invited",  # InvitedSyncResult for each invited room.
     "archived",  # ArchivedSyncResult for each archived room.
-    "to_device",  # List of direct messages for the device.
 ])):
     __slots__ = []

@@ -128,8 +126,7 @@ class SyncResult(collections.namedtuple("SyncResult", [
             self.joined or
             self.invited or
             self.archived or
-            self.account_data or
-            self.to_device
+            self.account_data
         )

@@ -141,8 +138,7 @@ class SyncHandler(object):
         self.presence_handler = hs.get_presence_handler()
         self.event_sources = hs.get_event_sources()
         self.clock = hs.get_clock()
-        self.response_cache = ResponseCache(hs)
-        self.state = hs.get_state_handler()
+        self.response_cache = ResponseCache()

     def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
                                full_state=False):

@@ -359,11 +355,11 @@ class SyncHandler(object):
         Returns:
             A Deferred map from ((type, state_key)->Event)
         """
-        state_ids = yield self.store.get_state_ids_for_event(event.event_id)
+        state = yield self.store.get_state_for_event(event.event_id)
         if event.is_state():
-            state_ids = state_ids.copy()
-            state_ids[(event.type, event.state_key)] = event.event_id
-        defer.returnValue(state_ids)
+            state = state.copy()
+            state[(event.type, event.state_key)] = event
+        defer.returnValue(state)

     @defer.inlineCallbacks
     def get_state_at(self, room_id, stream_position):

@@ -416,66 +412,62 @@ class SyncHandler(object):
         with Measure(self.clock, "compute_state_delta"):
             if full_state:
                 if batch:
-                    current_state_ids = yield self.store.get_state_ids_for_event(
+                    current_state = yield self.store.get_state_for_event(
                         batch.events[-1].event_id
                     )

-                    state_ids = yield self.store.get_state_ids_for_event(
+                    state = yield self.store.get_state_for_event(
                         batch.events[0].event_id
                     )
                 else:
-                    current_state_ids = yield self.get_state_at(
+                    current_state = yield self.get_state_at(
                         room_id, stream_position=now_token
                     )

-                    state_ids = current_state_ids
+                    state = current_state

                 timeline_state = {
-                    (event.type, event.state_key): event.event_id
+                    (event.type, event.state_key): event
                     for event in batch.events if event.is_state()
                 }

-                state_ids = _calculate_state(
+                state = _calculate_state(
                     timeline_contains=timeline_state,
-                    timeline_start=state_ids,
+                    timeline_start=state,
                     previous={},
-                    current=current_state_ids,
+                    current=current_state,
                 )
             elif batch.limited:
                 state_at_previous_sync = yield self.get_state_at(
                     room_id, stream_position=since_token
                 )

-                current_state_ids = yield self.store.get_state_ids_for_event(
+                current_state = yield self.store.get_state_for_event(
                     batch.events[-1].event_id
                 )

-                state_at_timeline_start = yield self.store.get_state_ids_for_event(
+                state_at_timeline_start = yield self.store.get_state_for_event(
                     batch.events[0].event_id
                 )

                 timeline_state = {
-                    (event.type, event.state_key): event.event_id
+                    (event.type, event.state_key): event
                     for event in batch.events if event.is_state()
                 }

-                state_ids = _calculate_state(
+                state = _calculate_state(
                     timeline_contains=timeline_state,
                     timeline_start=state_at_timeline_start,
                     previous=state_at_previous_sync,
-                    current=current_state_ids,
+                    current=current_state,
                 )
             else:
-                state_ids = {}
-
-            state = {}
-            if state_ids:
-                state = yield self.store.get_events(state_ids.values())
+                state = {}

         defer.returnValue({
             (e.type, e.state_key): e
             for e in sync_config.filter_collection.filter_room_state(state.values())
         })

@@ -493,9 +485,9 @@ class SyncHandler(object):
             )
             defer.returnValue(notifs)

         # There is no new information in this period, so your notification
         # count is whatever it was last time.
         defer.returnValue(None)

     @defer.inlineCallbacks
     def generate_sync_result(self, sync_config, since_token=None, full_state=False):

@@ -535,57 +527,15 @@ class SyncHandler(object):
             sync_result_builder, newly_joined_rooms, newly_joined_users
         )

-        yield self._generate_sync_entry_for_to_device(sync_result_builder)
-
         defer.returnValue(SyncResult(
             presence=sync_result_builder.presence,
             account_data=sync_result_builder.account_data,
             joined=sync_result_builder.joined,
             invited=sync_result_builder.invited,
             archived=sync_result_builder.archived,
-            to_device=sync_result_builder.to_device,
             next_batch=sync_result_builder.now_token,
         ))

-    @defer.inlineCallbacks
-    def _generate_sync_entry_for_to_device(self, sync_result_builder):
-        """Generates the portion of the sync response. Populates
-        `sync_result_builder` with the result.
-
-        Args:
-            sync_result_builder(SyncResultBuilder)
-
-        Returns:
-            Deferred(dict): A dictionary containing the per room account data.
-        """
-        user_id = sync_result_builder.sync_config.user.to_string()
-        device_id = sync_result_builder.sync_config.device_id
-        now_token = sync_result_builder.now_token
-
-        since_stream_id = 0
-        if sync_result_builder.since_token is not None:
-            since_stream_id = int(sync_result_builder.since_token.to_device_key)
-
-        if since_stream_id != int(now_token.to_device_key):
-            # We only delete messages when a new message comes in, but that's
-            # fine so long as we delete them at some point.
-
-            logger.debug("Deleting messages up to %d", since_stream_id)
-            yield self.store.delete_messages_for_device(
-                user_id, device_id, since_stream_id
-            )
-
-            logger.debug("Getting messages up to %d", now_token.to_device_key)
-            messages, stream_id = yield self.store.get_new_messages_for_device(
-                user_id, device_id, since_stream_id, now_token.to_device_key
-            )
-            logger.debug("Got messages up to %d: %r", stream_id, messages)
-            sync_result_builder.now_token = now_token.copy_and_replace(
-                "to_device_key", stream_id
-            )
-            sync_result_builder.to_device = messages
-        else:
-            sync_result_builder.to_device = []
-
     @defer.inlineCallbacks
     def _generate_sync_entry_for_account_data(self, sync_result_builder):
         """Generates the account data portion of the sync response. Populates

@@ -676,7 +626,7 @@ class SyncHandler(object):
         extra_users_ids = set(newly_joined_users)
         for room_id in newly_joined_rooms:
-            users = yield self.state.get_current_user_in_room(room_id)
+            users = yield self.store.get_users_in_room(room_id)
             extra_users_ids.update(users)
         extra_users_ids.discard(user.to_string())

@@ -788,7 +738,7 @@ class SyncHandler(object):
         assert since_token

-        app_service = self.store.get_app_service_by_user_id(user_id)
+        app_service = yield self.store.get_app_service_by_user_id(user_id)
         if app_service:
             rooms = yield self.store.get_app_service_rooms(app_service)
             joined_room_ids = set(r.room_id for r in rooms)

@@ -816,13 +766,8 @@ class SyncHandler(object):
                 # the last sync (even if we have since left). This is to make sure
                 # we do send down the room, and with full state, where necessary
                 if room_id in joined_room_ids or has_join:
-                    old_state_ids = yield self.get_state_at(room_id, since_token)
-                    old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
-                    old_mem_ev = None
-                    if old_mem_ev_id:
-                        old_mem_ev = yield self.store.get_event(
-                            old_mem_ev_id, allow_none=True
-                        )
+                    old_state = yield self.get_state_at(room_id, since_token)
+                    old_mem_ev = old_state.get((EventTypes.Member, user_id), None)
                     if not old_mem_ev or old_mem_ev.membership != Membership.JOIN:
                         newly_joined_rooms.append(room_id)

@@ -1114,25 +1059,27 @@ def _calculate_state(timeline_contains, timeline_start, previous, current):
     Returns:
         dict
     """
-    event_id_to_key = {
-        e: key
-        for key, e in itertools.chain(
-            timeline_contains.items(),
-            previous.items(),
-            timeline_start.items(),
-            current.items(),
+    event_id_to_state = {
+        e.event_id: e
+        for e in itertools.chain(
+            timeline_contains.values(),
+            previous.values(),
+            timeline_start.values(),
+            current.values(),
         )
     }

-    c_ids = set(e for e in current.values())
-    tc_ids = set(e for e in timeline_contains.values())
-    p_ids = set(e for e in previous.values())
-    ts_ids = set(e for e in timeline_start.values())
+    c_ids = set(e.event_id for e in current.values())
+    tc_ids = set(e.event_id for e in timeline_contains.values())
+    p_ids = set(e.event_id for e in previous.values())
+    ts_ids = set(e.event_id for e in timeline_start.values())

     state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids

+    evs = (event_id_to_state[e] for e in state_ids)
     return {
-        event_id_to_key[e]: e for e in state_ids
+        (e.type, e.state_key): e
+        for e in evs
     }

@@ -1156,7 +1103,6 @@ class SyncResultBuilder(object):
         self.joined = []
         self.invited = []
         self.archived = []
-        self.device = []


 class RoomSyncResultBuilder(object):


@@ -16,10 +16,9 @@
 from twisted.internet import defer

 from synapse.api.errors import SynapseError, AuthError
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import PreserveLoggingContext
 from synapse.util.metrics import Measure
-from synapse.util.wheel_timer import WheelTimer
-from synapse.types import UserID, get_domain_from_id
+from synapse.types import UserID

 import logging

@@ -34,13 +33,6 @@ logger = logging.getLogger(__name__)
 RoomMember = namedtuple("RoomMember", ("room_id", "user_id"))


-# How often we expect remote servers to resend us presence.
-FEDERATION_TIMEOUT = 60 * 1000
-
-# How often to resend typing across federation.
-FEDERATION_PING_INTERVAL = 40 * 1000
-
-
 class TypingHandler(object):
     def __init__(self, hs):
         self.store = hs.get_datastore()

@@ -48,12 +40,8 @@ class TypingHandler(object):
         self.auth = hs.get_auth()
         self.is_mine_id = hs.is_mine_id
         self.notifier = hs.get_notifier()
-        self.state = hs.get_state_handler()
-
-        self.hs = hs

         self.clock = hs.get_clock()
-        self.wheel_timer = WheelTimer(bucket_size=5000)

         self.federation = hs.get_replication_layer()

@@ -62,7 +50,7 @@ class TypingHandler(object):
         hs.get_distributor().observe("user_left_room", self.user_left_room)

         self._member_typing_until = {}  # clock time we expect to stop
-        self._member_last_federation_poke = {}
+        self._member_typing_timer = {}  # deferreds to manage theabove

         # map room IDs to serial numbers
         self._room_serials = {}

@@ -70,49 +58,12 @@ class TypingHandler(object):
         # map room IDs to sets of users currently typing
         self._room_typing = {}

-        self.clock.looping_call(
-            self._handle_timeouts,
-            5000,
-        )
-
-    def _handle_timeouts(self):
-        logger.info("Checking for typing timeouts")
-
-        now = self.clock.time_msec()
-
-        members = set(self.wheel_timer.fetch(now))
-
-        for member in members:
-            if not self.is_typing(member):
-                # Nothing to do if they're no longer typing
-                continue
-
-            until = self._member_typing_until.get(member, None)
-            if not until or until <= now:
-                logger.info("Timing out typing for: %s", member.user_id)
-                preserve_fn(self._stopped_typing)(member)
-                continue
-
-            # Check if we need to resend a keep alive over federation for this
-            # user.
-            if self.hs.is_mine_id(member.user_id):
-                last_fed_poke = self._member_last_federation_poke.get(member, None)
-                if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now:
-                    preserve_fn(self._push_remote)(
-                        member=member,
-                        typing=True
-                    )
-
-            # Add a paranoia timer to ensure that we always have a timer for
-            # each person typing.
-            self.wheel_timer.insert(
-                now=now,
-                obj=member,
-                then=now + 60 * 1000,
-            )
-
-    def is_typing(self, member):
-        return member.user_id in self._room_typing.get(member.room_id, [])
+    def tearDown(self):
+        """Cancels all the pending timers.
+        Normally this shouldn't be needed, but it's required from unit tests
+        to avoid a "Reactor was unclean" warning."""
+        for t in self._member_typing_timer.values():
+            self.clock.cancel_call_later(t)

     @defer.inlineCallbacks
     def started_typing(self, target_user, auth_user, room_id, timeout):

@@ -131,17 +82,23 @@ class TypingHandler(object):
             "%s has started typing in %s", target_user_id, room_id
         )

+        until = self.clock.time_msec() + timeout
         member = RoomMember(room_id=room_id, user_id=target_user_id)

-        was_present = member.user_id in self._room_typing.get(room_id, set())
+        was_present = member in self._member_typing_until

-        now = self.clock.time_msec()
-        self._member_typing_until[member] = now + timeout
+        if member in self._member_typing_timer:
+            self.clock.cancel_call_later(self._member_typing_timer[member])

-        self.wheel_timer.insert(
-            now=now,
-            obj=member,
-            then=now + timeout,
+        def _cb():
+            logger.debug(
+                "%s has timed out in %s", target_user.to_string(), room_id
+            )
+            self._stopped_typing(member)
+
+        self._member_typing_until[member] = until
+        self._member_typing_timer[member] = self.clock.call_later(
+            timeout / 1000.0, _cb
         )

         if was_present:

@@ -149,7 +106,8 @@ class TypingHandler(object):
             defer.returnValue(None)

         yield self._push_update(
-            member=member,
+            room_id=room_id,
+            user_id=target_user_id,
             typing=True,
         )

@@ -172,6 +130,10 @@ class TypingHandler(object):
         member = RoomMember(room_id=room_id, user_id=target_user_id)

+        if member in self._member_typing_timer:
+            self.clock.cancel_call_later(self._member_typing_timer[member])
+            del self._member_typing_timer[member]
+
         yield self._stopped_typing(member)

     @defer.inlineCallbacks

@@ -183,101 +145,79 @@ class TypingHandler(object):
     @defer.inlineCallbacks
     def _stopped_typing(self, member):
-        if member.user_id not in self._room_typing.get(member.room_id, set()):
+        if member not in self._member_typing_until:
             # No point
             defer.returnValue(None)

-        self._member_typing_until.pop(member, None)
-        self._member_last_federation_poke.pop(member, None)
-
         yield self._push_update(
-            member=member,
+            room_id=member.room_id,
+            user_id=member.user_id,
             typing=False,
         )

-    @defer.inlineCallbacks
-    def _push_update(self, member, typing):
-        if self.hs.is_mine_id(member.user_id):
-            # Only send updates for changes to our own users.
-            yield self._push_remote(member, typing)
-
-        self._push_update_local(
-            member=member,
-            typing=typing
-        )
+        del self._member_typing_until[member]
+
+        if member in self._member_typing_timer:
+            # Don't cancel it - either it already expired, or the real
+            # stopped_typing() will cancel it
+            del self._member_typing_timer[member]

     @defer.inlineCallbacks
-    def _push_remote(self, member, typing):
-        users = yield self.state.get_current_user_in_room(member.room_id)
-        self._member_last_federation_poke[member] = self.clock.time_msec()
-
-        now = self.clock.time_msec()
-        self.wheel_timer.insert(
-            now=now,
-            obj=member,
-            then=now + FEDERATION_PING_INTERVAL,
-        )
-
-        for domain in set(get_domain_from_id(u) for u in users):
-            if domain != self.server_name:
-                self.federation.send_edu(
+    def _push_update(self, room_id, user_id, typing):
+        domains = yield self.store.get_joined_hosts_for_room(room_id)
+
+        deferreds = []
+        for domain in domains:
+            if domain == self.server_name:
+                self._push_update_local(
+                    room_id=room_id,
+                    user_id=user_id,
+                    typing=typing
+                )
+            else:
+                deferreds.append(self.federation.send_edu(
                     destination=domain,
                     edu_type="m.typing",
                     content={
-                        "room_id": member.room_id,
-                        "user_id": member.user_id,
+                        "room_id": room_id,
+                        "user_id": user_id,
                         "typing": typing,
                     },
-                    key=member,
-                )
+                ))
+
+        yield defer.DeferredList(deferreds, consumeErrors=True)

     @defer.inlineCallbacks
     def _recv_edu(self, origin, content):
         room_id = content["room_id"]
         user_id = content["user_id"]

-        member = RoomMember(user_id=user_id, room_id=room_id)
-
         # Check that the string is a valid user id
-        user = UserID.from_string(user_id)
-
-        if user.domain != origin:
-            logger.info(
-                "Got typing update from %r with bad 'user_id': %r",
-                origin, user_id,
-            )
-            return
-
-        users = yield self.state.get_current_user_in_room(room_id)
-        domains = set(get_domain_from_id(u) for u in users)
+        UserID.from_string(user_id)
+
+        domains = yield self.store.get_joined_hosts_for_room(room_id)

         if self.server_name in domains:
-            logger.info("Got typing update from %s: %r", user_id, content)
-            now = self.clock.time_msec()
-            self._member_typing_until[member] = now + FEDERATION_TIMEOUT
-            self.wheel_timer.insert(
-                now=now,
-                obj=member,
-                then=now + FEDERATION_TIMEOUT,
-            )
-
             self._push_update_local(
-                member=member,
+                room_id=room_id,
+                user_id=user_id,
                 typing=content["typing"]
             )

-    def _push_update_local(self, member, typing):
-        room_set = self._room_typing.setdefault(member.room_id, set())
+    def _push_update_local(self, room_id, user_id, typing):
+        room_set = self._room_typing.setdefault(room_id, set())
         if typing:
-            room_set.add(member.user_id)
+            room_set.add(user_id)
         else:
-            room_set.discard(member.user_id)
+            room_set.discard(user_id)

         self._latest_room_serial += 1
-        self._room_serials[member.room_id] = self._latest_room_serial
+        self._room_serials[room_id] = self._latest_room_serial

-        self.notifier.on_new_event(
-            "typing_key", self._latest_room_serial, rooms=[member.room_id]
-        )
+        with PreserveLoggingContext():
+            self.notifier.on_new_event(
+                "typing_key", self._latest_room_serial, rooms=[room_id]
+            )

     def get_all_typing_updates(self, last_id, current_id):
         # TODO: Work out a way to do this without scanning the entire state.


@@ -155,7 +155,9 @@ class MatrixFederationHttpClient(object):
                     time_out=timeout / 1000. if timeout else 60,
                 )

-                response = yield preserve_context_over_fn(send_request)
+                response = yield preserve_context_over_fn(
+                    send_request,
+                )

                 log_result = "%d %s" % (response.code, response.phrase,)
                 break

@@ -246,7 +248,7 @@ class MatrixFederationHttpClient(object):
     @defer.inlineCallbacks
     def put_json(self, destination, path, data={}, json_data_callback=None,
-                 long_retries=False, timeout=None):
+                 long_retries=False):
         """ Sends the specifed json data using PUT

         Args:

@@ -259,8 +261,6 @@ class MatrixFederationHttpClient(object):
                 use as the request body.
             long_retries (bool): A boolean that indicates whether we should
                 retry for a short or long time.
-            timeout(int): How long to try (in ms) the destination for before
-                giving up. None indicates no timeout.

         Returns:
             Deferred: Succeeds when we get a 2xx HTTP response. The result

@@ -287,7 +287,6 @@ class MatrixFederationHttpClient(object):
             body_callback=body_callback,
             headers_dict={"Content-Type": ["application/json"]},
             long_retries=long_retries,
-            timeout=timeout,
         )

         if 200 <= response.code < 300:

@@ -303,8 +302,7 @@ class MatrixFederationHttpClient(object):
         defer.returnValue(json.loads(body))

     @defer.inlineCallbacks
-    def post_json(self, destination, path, data={}, long_retries=True,
-                  timeout=None):
+    def post_json(self, destination, path, data={}, long_retries=True):
         """ Sends the specifed json data using POST

         Args:

@@ -315,8 +313,6 @@ class MatrixFederationHttpClient(object):
                 the request body. This will be encoded as JSON.
             long_retries (bool): A boolean that indicates whether we should
                 retry for a short or long time.
-            timeout(int): How long to try (in ms) the destination for before
-                giving up. None indicates no timeout.

         Returns:
             Deferred: Succeeds when we get a 2xx HTTP response. The result

@@ -337,7 +333,6 @@ class MatrixFederationHttpClient(object):
             body_callback=body_callback,
             headers_dict={"Content-Type": ["application/json"]},
             long_retries=True,
-            timeout=timeout,
         )

         if 200 <= response.code < 300:


@@ -19,7 +19,6 @@ from synapse.api.errors import (
 )
 from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
 from synapse.util.caches import intern_dict
-from synapse.util.metrics import Measure
 import synapse.metrics
 import synapse.events

@@ -75,12 +74,12 @@ response_db_txn_duration = metrics.register_distribution(
 _next_request_id = 0


-def request_handler(include_metrics=False):
+def request_handler(report_metrics=True):
     """Decorator for ``wrap_request_handler``"""
-    return lambda request_handler: wrap_request_handler(request_handler, include_metrics)
+    return lambda request_handler: wrap_request_handler(request_handler, report_metrics)


-def wrap_request_handler(request_handler, include_metrics=False):
+def wrap_request_handler(request_handler, report_metrics):
     """Wraps a method that acts as a request handler with the necessary logging
     and exception handling.

@@ -104,56 +103,54 @@ def wrap_request_handler(request_handler, include_metrics=False):
         _next_request_id += 1

         with LoggingContext(request_id) as request_context:
-            with Measure(self.clock, "wrapped_request_handler"):
+            if report_metrics:
                 request_metrics = RequestMetrics()
-                request_metrics.start(self.clock, name=self.__class__.__name__)
+                request_metrics.start(self.clock)

-                request_context.request = request_id
-                with request.processing():
-                    try:
-                        with PreserveLoggingContext(request_context):
-                            if include_metrics:
-                                yield request_handler(self, request, request_metrics)
-                            else:
-                                yield request_handler(self, request)
-                    except CodeMessageException as e:
-                        code = e.code
-                        if isinstance(e, SynapseError):
-                            logger.info(
-                                "%s SynapseError: %s - %s", request, code, e.msg
-                            )
-                        else:
-                            logger.exception(e)
-                        outgoing_responses_counter.inc(request.method, str(code))
-                        respond_with_json(
-                            request, code, cs_exception(e), send_cors=True,
-                            pretty_print=_request_user_agent_is_curl(request),
-                            version_string=self.version_string,
-                        )
-                    except:
-                        logger.exception(
-                            "Failed handle request %s.%s on %r: %r",
-                            request_handler.__module__,
-                            request_handler.__name__,
-                            self,
-                            request
-                        )
-                        respond_with_json(
-                            request,
-                            500,
-                            {
-                                "error": "Internal server error",
-                                "errcode": Codes.UNKNOWN,
-                            },
-                            send_cors=True
-                        )
-                    finally:
-                        try:
-                            request_metrics.stop(
-                                self.clock, request
-                            )
-                        except Exception as e:
-                            logger.warn("Failed to stop metrics: %r", e)
+            request_context.request = request_id
+            with request.processing():
+                try:
+                    with PreserveLoggingContext(request_context):
+                        yield request_handler(self, request)
+                except CodeMessageException as e:
+                    code = e.code
+                    if isinstance(e, SynapseError):
+                        logger.info(
+                            "%s SynapseError: %s - %s", request, code, e.msg
+                        )
+                    else:
+                        logger.exception(e)
+                    outgoing_responses_counter.inc(request.method, str(code))
+                    respond_with_json(
+                        request, code, cs_exception(e), send_cors=True,
+                        pretty_print=_request_user_agent_is_curl(request),
+                        version_string=self.version_string,
+                    )
+                except:
+                    logger.exception(
+                        "Failed handle request %s.%s on %r: %r",
+                        request_handler.__module__,
+                        request_handler.__name__,
+                        self,
+                        request
+                    )
+                    respond_with_json(
+                        request,
+                        500,
+                        {
+                            "error": "Internal server error",
+                            "errcode": Codes.UNKNOWN,
+                        },
+                        send_cors=True
+                    )
+                finally:
+                    try:
+                        if report_metrics:
+                            request_metrics.stop(
+                                self.clock, request, self.__class__.__name__
+                            )
+                    except:
+                        pass
     return wrapped_request_handler

@@ -208,7 +205,6 @@ class JsonResource(HttpServer, resource.Resource):
     def register_paths(self, method, path_patterns, callback):
         for path_pattern in path_patterns:
-            logger.debug("Registering for %s %s", method, path_pattern.pattern)
             self.path_regexs.setdefault(method, []).append(
                 self._PathEntry(path_pattern, callback)
             )

@@ -223,9 +219,9 @@ class JsonResource(HttpServer, resource.Resource):
     # It does its own metric reporting because _async_render dispatches to
     # a callback and it's the class name of that callback we want to report
     # against rather than the JsonResource itself.
-    @request_handler(include_metrics=True)
+    @request_handler(report_metrics=False)
     @defer.inlineCallbacks
-    def _async_render(self, request, request_metrics):
+    def _async_render(self, request):
         """ This gets called from render() every time someone sends us a request.
             This checks if anyone has registered a callback for that method and
             path.

@@ -234,6 +230,9 @@ class JsonResource(HttpServer, resource.Resource):
             self._send_response(request, 200, {})
             return

+        request_metrics = RequestMetrics()
+        request_metrics.start(self.clock)
+
         # Loop through all the registered callbacks to check if the method
         # and path regex match
         for path_entry in self.path_regexs.get(request.method, []):

@@ -247,6 +246,12 @@ class JsonResource(HttpServer, resource.Resource):
             callback = path_entry.callback

+            servlet_instance = getattr(callback, "__self__", None)
+            if servlet_instance is not None:
+                servlet_classname = servlet_instance.__class__.__name__
+            else:
+                servlet_classname = "%r" % callback
+
             kwargs = intern_dict({
                 name: urllib.unquote(value).decode("UTF-8") if value else value
                 for name, value in m.groupdict().items()

@@ -257,13 +262,10 @@ class JsonResource(HttpServer, resource.Resource):
                 code, response = callback_return
                 self._send_response(request, code, response)

-            servlet_instance = getattr(callback, "__self__", None)
-            if servlet_instance is not None:
-                servlet_classname = servlet_instance.__class__.__name__
-            else:
-                servlet_classname = "%r" % callback
-            request_metrics.name = servlet_classname
+            try:
+                request_metrics.stop(self.clock, request, servlet_classname)
+            except:
+                pass

             return

@@ -295,12 +297,11 @@ class JsonResource(HttpServer, resource.Resource):
 class RequestMetrics(object):
-    def start(self, clock, name):
+    def start(self, clock):
         self.start = clock.time_msec()
         self.start_context = LoggingContext.current_context()
-        self.name = name

-    def stop(self, clock, request):
+    def stop(self, clock, request, servlet_classname):
         context = LoggingContext.current_context()

         tag = ""

@@ -314,26 +315,26 @@ class RequestMetrics(object):
             )
             return

-        incoming_requests_counter.inc(request.method, self.name, tag)
+        incoming_requests_counter.inc(request.method, servlet_classname, tag)

         response_timer.inc_by(
             clock.time_msec() - self.start, request.method,
-            self.name, tag
+            servlet_classname, tag
         )

         ru_utime, ru_stime = context.get_resource_usage()

         response_ru_utime.inc_by(
-            ru_utime, request.method, self.name, tag
+            ru_utime, request.method, servlet_classname, tag
         )
         response_ru_stime.inc_by(
-            ru_stime, request.method, self.name, tag
+            ru_stime, request.method, servlet_classname, tag
         )
         response_db_txn_count.inc_by(
-            context.db_txn_count, request.method, self.name, tag
+            context.db_txn_count, request.method, servlet_classname, tag
         )
         response_db_txn_duration.inc_by(
-            context.db_txn_duration, request.method, self.name, tag
+            context.db_txn_duration, request.method, servlet_classname, tag
         )

@@ -392,30 +393,17 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
     request.setHeader(b"Content-Length", b"%d" % (len(json_bytes),))

     if send_cors:
-        set_cors_headers(request)
+        request.setHeader("Access-Control-Allow-Origin", "*")
+        request.setHeader("Access-Control-Allow-Methods",
+                          "GET, POST, PUT, DELETE, OPTIONS")
+        request.setHeader("Access-Control-Allow-Headers",
+                          "Origin, X-Requested-With, Content-Type, Accept")

     request.write(json_bytes)
     finish_request(request)
     return NOT_DONE_YET


-def set_cors_headers(request):
-    """Set the CORs headers so that javascript running in a web browsers can
-    use this API
-
-    Args:
-        request (twisted.web.http.Request): The http request to add CORs to.
-    """
-    request.setHeader("Access-Control-Allow-Origin", "*")
-    request.setHeader(
-        "Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS"
-    )
-    request.setHeader(
-        "Access-Control-Allow-Headers",
-        "Origin, X-Requested-With, Content-Type, Accept"
-    )
-
-
 def finish_request(request):
     """ Finish writing the response to the request.


@@ -41,13 +41,9 @@ def parse_integer(request, name, default=None, required=False):
         SynapseError: if the parameter is absent and required, or if the
             parameter is present and not an integer.
     """
-    return parse_integer_from_args(request.args, name, default, required)
-
-
-def parse_integer_from_args(args, name, default=None, required=False):
-    if name in args:
+    if name in request.args:
         try:
-            return int(args[name][0])
+            return int(request.args[name][0])
         except:
             message = "Query parameter %r must be an integer" % (name,)
             raise SynapseError(400, message)

@@ -120,15 +116,9 @@ def parse_string(request, name, default=None, required=False,
             parameter is present, must be one of a list of allowed values and
             is not one of those allowed values.
     """
-    return parse_string_from_args(
-        request.args, name, default, required, allowed_values, param_type,
-    )
-
-
-def parse_string_from_args(args, name, default=None, required=False,
-                           allowed_values=None, param_type="string"):
-    if name in args:
-        value = args[name][0]
+    if name in request.args:
+        value = request.args[name][0]
         if allowed_values is not None and value not in allowed_values:
             message = "Query parameter %r must be one of [%s]" % (
                 name, ", ".join(repr(v) for v in allowed_values)


@@ -13,25 +13,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+# Because otherwise 'resource' collides with synapse.metrics.resource
+from __future__ import absolute_import
+
 import logging
+from resource import getrusage, RUSAGE_SELF
 import functools
+import os
+import stat
 import time
 import gc

 from twisted.internet import reactor

 from .metric import (
-    CounterMetric, CallbackMetric, DistributionMetric, CacheMetric,
-    MemoryUsageMetric,
+    CounterMetric, CallbackMetric, DistributionMetric, CacheMetric
 )
-from .process_collector import register_process_collector


 logger = logging.getLogger(__name__)


 all_metrics = []
-all_collectors = []


 class Metrics(object):

@@ -42,12 +45,6 @@ class Metrics(object):
     def __init__(self, name):
         self.name_prefix = name

-    def make_subspace(self, name):
-        return Metrics("%s_%s" % (self.name_prefix, name))
-
-    def register_collector(self, func):
-        all_collectors.append(func)
-
     def _register(self, metric_class, name, *args, **kwargs):
         full_name = "%s_%s" % (self.name_prefix, name)

@@ -69,21 +66,6 @@ class Metrics(object):
         return self._register(CacheMetric, *args, **kwargs)


-def register_memory_metrics(hs):
-    try:
-        import psutil
-        process = psutil.Process()
-        process.memory_info().rss
-    except (ImportError, AttributeError):
-        logger.warn(
-            "psutil is not installed or incorrect version."
-            " Disabling memory metrics."
-        )
-        return
-    metric = MemoryUsageMetric(hs, psutil)
-    all_metrics.append(metric)
-
-
 def get_metrics_for(pkg_name):
     """ Returns a Metrics instance for conveniently creating metrics
     namespaced with the given name prefix. """

@@ -96,8 +78,8 @@ def get_metrics_for(pkg_name):
 def render_all():
     strs = []

-    for collector in all_collectors:
-        collector()
+    # TODO(paul): Internal hack
+    update_resource_metrics()

     for metric in all_metrics:
         try:

@@ -111,21 +93,73 @@ def render_all():
     return "\n".join(strs)


-register_process_collector(get_metrics_for("process"))
+# Now register some standard process-wide state metrics, to give indications of
+# process resource usage
+
+rusage = None
+
+
+def update_resource_metrics():
+    global rusage
+    rusage = getrusage(RUSAGE_SELF)

+resource_metrics = get_metrics_for("process.resource")

-python_metrics = get_metrics_for("python")
+# msecs
+resource_metrics.register_callback("utime", lambda: rusage.ru_utime * 1000)
+resource_metrics.register_callback("stime", lambda: rusage.ru_stime * 1000)

-gc_time = python_metrics.register_distribution("gc_time", labels=["gen"])
-gc_unreachable = python_metrics.register_counter("gc_unreachable_total", labels=["gen"])
-python_metrics.register_callback(
-    "gc_counts", lambda: {(i,): v for i, v in enumerate(gc.get_count())}, labels=["gen"]
-)
+# kilobytes
+resource_metrics.register_callback("maxrss", lambda: rusage.ru_maxrss * 1024)
+
+TYPES = {
+    stat.S_IFSOCK: "SOCK",
+    stat.S_IFLNK: "LNK",
+    stat.S_IFREG: "REG",
+    stat.S_IFBLK: "BLK",
+    stat.S_IFDIR: "DIR",
+    stat.S_IFCHR: "CHR",
+    stat.S_IFIFO: "FIFO",
+}
+
+
+def _process_fds():
+    counts = {(k,): 0 for k in TYPES.values()}
+    counts[("other",)] = 0
+
+    # Not every OS will have a /proc/self/fd directory
+    if not os.path.exists("/proc/self/fd"):
+        return counts
+
+    for fd in os.listdir("/proc/self/fd"):
+        try:
+            s = os.stat("/proc/self/fd/%s" % (fd))
+            fmt = stat.S_IFMT(s.st_mode)
+            if fmt in TYPES:
+                t = TYPES[fmt]
+            else:
+                t = "other"
+
+            counts[(t,)] += 1
+        except OSError:
+            # the dirh itself used by listdir() is usually missing by now
+            pass
+
+    return counts
+
+get_metrics_for("process").register_callback("fds", _process_fds, labels=["type"])

-reactor_metrics = get_metrics_for("python.twisted.reactor")
+reactor_metrics = get_metrics_for("reactor")
 tick_time = reactor_metrics.register_distribution("tick_time")
 pending_calls_metric = reactor_metrics.register_distribution("pending_calls")

+gc_time = reactor_metrics.register_distribution("gc_time", labels=["gen"])
+gc_unreachable = reactor_metrics.register_counter("gc_unreachable", labels=["gen"])
+
+reactor_metrics.register_callback(
+    "gc_counts", lambda: {(i,): v for i, v in enumerate(gc.get_count())}, labels=["gen"]
+)
+

 def runUntilCurrentTimer(func):


@@ -98,9 +98,9 @@ class CallbackMetric(BaseMetric):
         value = self.callback()

         if self.is_scalar():
-            return ["%s %.12g" % (self.name, value)]
+            return ["%s %d" % (self.name, value)]

-        return ["%s%s %.12g" % (self.name, self._render_key(k), value[k])
+        return ["%s%s %d" % (self.name, self._render_key(k), value[k])
                 for k in sorted(value.keys())]

@@ -153,43 +153,3 @@ class CacheMetric(object):
             """%s:total{name="%s"} %d""" % (self.name, self.cache_name, total),
             """%s:size{name="%s"} %d""" % (self.name, self.cache_name, size),
         ]
-
-
-class MemoryUsageMetric(object):
-    """Keeps track of the current memory usage, using psutil.
-
-    The class will keep the current min/max/sum/counts of rss over the last
-    WINDOW_SIZE_SEC, by polling UPDATE_HZ times per second
-    """
-
-    UPDATE_HZ = 2  # number of times to get memory per second
-    WINDOW_SIZE_SEC = 30  # the size of the window in seconds
-
-    def __init__(self, hs, psutil):
-        clock = hs.get_clock()
-        self.memory_snapshots = []
-
-        self.process = psutil.Process()
-
-        clock.looping_call(self._update_curr_values, 1000 / self.UPDATE_HZ)
-
-    def _update_curr_values(self):
-        max_size = self.UPDATE_HZ * self.WINDOW_SIZE_SEC
-        self.memory_snapshots.append(self.process.memory_info().rss)
-        self.memory_snapshots[:] = self.memory_snapshots[-max_size:]
-
-    def render(self):
-        if not self.memory_snapshots:
-            return []
-
-        max_rss = max(self.memory_snapshots)
-        min_rss = min(self.memory_snapshots)
-        sum_rss = sum(self.memory_snapshots)
-        len_rss = len(self.memory_snapshots)
-
-        return [
-            "process_psutil_rss:max %d" % max_rss,
-            "process_psutil_rss:min %d" % min_rss,
-            "process_psutil_rss:total %d" % sum_rss,
-            "process_psutil_rss:count %d" % len_rss,
-        ]


@@ -1,122 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-
-
-TICKS_PER_SEC = 100
-BYTES_PER_PAGE = 4096
-
-HAVE_PROC_STAT = os.path.exists("/proc/stat")
-HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat")
-HAVE_PROC_SELF_LIMITS = os.path.exists("/proc/self/limits")
-HAVE_PROC_SELF_FD = os.path.exists("/proc/self/fd")
-
-# Field indexes from /proc/self/stat, taken from the proc(5) manpage
-STAT_FIELDS = {
-    "utime": 14,
-    "stime": 15,
-    "starttime": 22,
-    "vsize": 23,
-    "rss": 24,
-}
-
-
-stats = {}
-
-# In order to report process_start_time_seconds we need to know the
-# machine's boot time, because the value in /proc/self/stat is relative to
-# this
-boot_time = None
-if HAVE_PROC_STAT:
-    with open("/proc/stat") as _procstat:
-        for line in _procstat:
-            if line.startswith("btime "):
-                boot_time = int(line.split()[1])
-
-
-def update_resource_metrics():
-    if HAVE_PROC_SELF_STAT:
-        global stats
-        with open("/proc/self/stat") as s:
-            line = s.read()
-            # line is PID (command) more stats go here ...
-            raw_stats = line.split(") ", 1)[1].split(" ")
-
-            for (name, index) in STAT_FIELDS.iteritems():
-                # subtract 3 from the index, because proc(5) is 1-based, and
-                # we've lost the first two fields in PID and COMMAND above
-                stats[name] = int(raw_stats[index - 3])
-
-
-def _count_fds():
-    # Not every OS will have a /proc/self/fd directory
-    if not HAVE_PROC_SELF_FD:
-        return 0
-
-    return len(os.listdir("/proc/self/fd"))
-
-
-def register_process_collector(process_metrics):
-    process_metrics.register_collector(update_resource_metrics)
-
-    if HAVE_PROC_SELF_STAT:
-        process_metrics.register_callback(
-            "cpu_user_seconds_total",
-            lambda: float(stats["utime"]) / TICKS_PER_SEC
-        )
-        process_metrics.register_callback(
-            "cpu_system_seconds_total",
-            lambda: float(stats["stime"]) / TICKS_PER_SEC
-        )
-        process_metrics.register_callback(
-            "cpu_seconds_total",
-            lambda: (float(stats["utime"] + stats["stime"])) / TICKS_PER_SEC
-        )
-        process_metrics.register_callback(
-            "virtual_memory_bytes",
-            lambda: int(stats["vsize"])
-        )
-        process_metrics.register_callback(
-            "resident_memory_bytes",
-            lambda: int(stats["rss"]) * BYTES_PER_PAGE
-        )
-        process_metrics.register_callback(
-            "start_time_seconds",
-            lambda: boot_time + int(stats["starttime"]) / TICKS_PER_SEC
-        )
-
-    if HAVE_PROC_SELF_FD:
-        process_metrics.register_callback(
-            "open_fds",
-            lambda: _count_fds()
-        )
-
-    if HAVE_PROC_SELF_LIMITS:
-        def _get_max_fds():
-            with open("/proc/self/limits") as limits:
-                for line in limits:
-                    if not line.startswith("Max open files "):
-                        continue
-                    # Line is  Max open files  $SOFT  $HARD
-                    return int(line.split()[3])
-            return None
-
-        process_metrics.register_callback(
-            "max_fds",
-            lambda: _get_max_fds()
-        )


@@ -19,8 +19,7 @@ from synapse.api.errors import AuthError
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import PreserveLoggingContext, preserve_fn from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.metrics import Measure
from synapse.types import StreamToken from synapse.types import StreamToken
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
import synapse.metrics import synapse.metrics
@@ -68,8 +67,10 @@ class _NotifierUserStream(object):
so that it can remove itself from the indexes in the Notifier class. so that it can remove itself from the indexes in the Notifier class.
""" """
def __init__(self, user_id, rooms, current_token, time_now_ms): def __init__(self, user_id, rooms, current_token, time_now_ms,
appservice=None):
self.user_id = user_id self.user_id = user_id
self.appservice = appservice
self.rooms = set(rooms) self.rooms = set(rooms)
self.current_token = current_token self.current_token = current_token
self.last_notified_ms = time_now_ms self.last_notified_ms = time_now_ms
@@ -106,6 +107,11 @@ class _NotifierUserStream(object):
notifier.user_to_user_stream.pop(self.user_id) notifier.user_to_user_stream.pop(self.user_id)
if self.appservice:
notifier.appservice_to_user_streams.get(
self.appservice, set()
).discard(self)
def count_listeners(self): def count_listeners(self):
return len(self.notify_deferred.observers()) return len(self.notify_deferred.observers())
@@ -136,6 +142,7 @@ class Notifier(object):
def __init__(self, hs): def __init__(self, hs):
self.user_to_user_stream = {} self.user_to_user_stream = {}
self.room_to_user_streams = {} self.room_to_user_streams = {}
self.appservice_to_user_streams = {}
self.event_sources = hs.get_event_sources() self.event_sources = hs.get_event_sources()
self.store = hs.get_datastore() self.store = hs.get_datastore()
@@ -161,6 +168,8 @@ class Notifier(object):
all_user_streams |= x all_user_streams |= x
for x in self.user_to_user_stream.values(): for x in self.user_to_user_stream.values():
all_user_streams.add(x) all_user_streams.add(x)
for x in self.appservice_to_user_streams.values():
all_user_streams |= x
return sum(stream.count_listeners() for stream in all_user_streams) return sum(stream.count_listeners() for stream in all_user_streams)
metrics.register_callback("listeners", count_listeners) metrics.register_callback("listeners", count_listeners)
@@ -173,8 +182,11 @@ class Notifier(object):
"users", "users",
lambda: len(self.user_to_user_stream), lambda: len(self.user_to_user_stream),
) )
metrics.register_callback(
"appservices",
lambda: count(bool, self.appservice_to_user_streams.values()),
)
@preserve_fn
def on_new_room_event(self, event, room_stream_id, max_room_stream_id, def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
extra_users=[]): extra_users=[]):
""" Used by handlers to inform the notifier something has happened """ Used by handlers to inform the notifier something has happened
@@ -196,7 +208,6 @@ class Notifier(object):
self.notify_replication() self.notify_replication()
@preserve_fn
def _notify_pending_new_room_events(self, max_room_stream_id): def _notify_pending_new_room_events(self, max_room_stream_id):
"""Notify for the room events that were queued waiting for a previous """Notify for the room events that were queued waiting for a previous
event to be persisted. event to be persisted.
@@ -214,11 +225,24 @@ class Notifier(object):
else: else:
self._on_new_room_event(event, room_stream_id, extra_users) self._on_new_room_event(event, room_stream_id, extra_users)
@preserve_fn
def _on_new_room_event(self, event, room_stream_id, extra_users=[]): def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
"""Notify any user streams that are interested in this room event""" """Notify any user streams that are interested in this room event"""
# poke any interested application service. # poke any interested application service.
self.appservice_handler.notify_interested_services(room_stream_id) self.appservice_handler.notify_interested_services(event)
app_streams = set()
for appservice in self.appservice_to_user_streams:
# TODO (kegan): Redundant appservice listener checks?
# App services will already be in the room_to_user_streams set, but
# that isn't enough. They need to be checked here in order to
# receive *invites* for users they are interested in. Does this
# make the room_to_user_streams check somewhat obselete?
if appservice.is_interested(event):
app_user_streams = self.appservice_to_user_streams.get(
appservice, set()
)
app_streams |= app_user_streams
if event.type == EventTypes.Member and event.membership == Membership.JOIN: if event.type == EventTypes.Member and event.membership == Membership.JOIN:
self._user_joined_room(event.state_key, event.room_id) self._user_joined_room(event.state_key, event.room_id)
@@ -227,36 +251,35 @@ class Notifier(object):
"room_key", room_stream_id, "room_key", room_stream_id,
users=extra_users, users=extra_users,
rooms=[event.room_id], rooms=[event.room_id],
extra_streams=app_streams,
) )
@preserve_fn def on_new_event(self, stream_key, new_token, users=[], rooms=[],
def on_new_event(self, stream_key, new_token, users=[], rooms=[]): extra_streams=set()):
""" Used to inform listeners that something has happend event wise. """ Used to inform listeners that something has happend event wise.
Will wake up all listeners for the given users and rooms. Will wake up all listeners for the given users and rooms.
""" """
with PreserveLoggingContext(): with PreserveLoggingContext():
with Measure(self.clock, "on_new_event"): user_streams = set()
user_streams = set()
for user in users: for user in users:
user_stream = self.user_to_user_stream.get(str(user)) user_stream = self.user_to_user_stream.get(str(user))
if user_stream is not None: if user_stream is not None:
user_streams.add(user_stream) user_streams.add(user_stream)
for room in rooms: for room in rooms:
user_streams |= self.room_to_user_streams.get(room, set()) user_streams |= self.room_to_user_streams.get(room, set())
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()
for user_stream in user_streams: for user_stream in user_streams:
try: try:
user_stream.notify(stream_key, new_token, time_now_ms) user_stream.notify(stream_key, new_token, time_now_ms)
except: except:
logger.exception("Failed to notify listener") logger.exception("Failed to notify listener")
self.notify_replication() self.notify_replication()
@preserve_fn
def on_new_replication_data(self): def on_new_replication_data(self):
"""Used to inform replication listeners that something has happend """Used to inform replication listeners that something has happend
without waking up any of the normal user event streams""" without waking up any of the normal user event streams"""
@@ -271,6 +294,7 @@ class Notifier(object):
""" """
user_stream = self.user_to_user_stream.get(user_id) user_stream = self.user_to_user_stream.get(user_id)
if user_stream is None: if user_stream is None:
appservice = yield self.store.get_app_service_by_user_id(user_id)
current_token = yield self.event_sources.get_current_token() current_token = yield self.event_sources.get_current_token()
if room_ids is None: if room_ids is None:
rooms = yield self.store.get_rooms_for_user(user_id) rooms = yield self.store.get_rooms_for_user(user_id)
@@ -278,6 +302,7 @@ class Notifier(object):
user_stream = _NotifierUserStream( user_stream = _NotifierUserStream(
user_id=user_id, user_id=user_id,
rooms=room_ids, rooms=room_ids,
appservice=appservice,
current_token=current_token, current_token=current_token,
time_now_ms=self.clock.time_msec(), time_now_ms=self.clock.time_msec(),
) )
@@ -423,8 +448,7 @@ class Notifier(object):
def _is_world_readable(self, room_id): def _is_world_readable(self, room_id):
state = yield self.state_handler.get_current_state( state = yield self.state_handler.get_current_state(
room_id, room_id,
EventTypes.RoomHistoryVisibility, EventTypes.RoomHistoryVisibility
"",
) )
if state and "history_visibility" in state.content: if state and "history_visibility" in state.content:
defer.returnValue(state.content["history_visibility"] == "world_readable") defer.returnValue(state.content["history_visibility"] == "world_readable")
@@ -453,6 +477,11 @@ class Notifier(object):
s = self.room_to_user_streams.setdefault(room, set()) s = self.room_to_user_streams.setdefault(room, set())
s.add(user_stream) s.add(user_stream)
if user_stream.appservice:
self.appservice_to_user_stream.setdefault(
user_stream.appservice, set()
).add(user_stream)
def _user_joined_room(self, user_id, room_id): def _user_joined_room(self, user_id, room_id):
new_user_stream = self.user_to_user_stream.get(user_id) new_user_stream = self.user_to_user_stream.get(user_id)
if new_user_stream is not None: if new_user_stream is not None:
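The right-hand side of this diff reintroduces per-appservice user streams: on a new room event the notifier wakes the streams of users in the room plus any streams registered for an interested application service, which is how AS users learn about invites to rooms they have not yet joined. A minimal sketch of that wake-up set, assuming the notifier attributes named above; streams_to_wake is a hypothetical helper, not code from the tree:

    def streams_to_wake(notifier, event, extra_users=()):
        # Streams of explicitly named users...
        streams = set()
        for user_id in extra_users:
            stream = notifier.user_to_user_stream.get(str(user_id))
            if stream is not None:
                streams.add(stream)
        # ...plus every stream attached to the event's room...
        streams |= notifier.room_to_user_streams.get(event.room_id, set())
        # ...plus streams registered per application service, checked against
        # the event itself so that invites for AS users are not missed.
        for appservice, app_streams in notifier.appservice_to_user_streams.items():
            if appservice.is_interested(event):
                streams |= app_streams
        return streams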

synapse/push/action_generator.py

@@ -38,16 +38,15 @@ class ActionGenerator:
    @defer.inlineCallbacks
    def handle_push_actions_for_event(self, event, context):
-        with Measure(self.clock, "evaluator_for_event"):
+        with Measure(self.clock, "handle_push_actions_for_event"):
            bulk_evaluator = yield evaluator_for_event(
-                event, self.hs, self.store, context
+                event, self.hs, self.store, context.current_state
            )

-        with Measure(self.clock, "action_for_event_by_user"):
            actions_by_user = yield bulk_evaluator.action_for_event_by_user(
-                event, context
+                event, context.current_state
            )

        context.push_actions = [
            (uid, actions) for uid, actions in actions_by_user.items()
        ]

synapse/push/baserules.py

@@ -217,27 +217,6 @@ BASE_APPEND_OVERRIDE_RULES = [
            'dont_notify'
        ]
    },
-    # This was changed from underride to override so it's closer in priority
-    # to the content rules where the user name highlight rule lives. This
-    # way a room rule is lower priority than both but a custom override rule
-    # is higher priority than both.
-    {
-        'rule_id': 'global/override/.m.rule.contains_display_name',
-        'conditions': [
-            {
-                'kind': 'contains_display_name'
-            }
-        ],
-        'actions': [
-            'notify',
-            {
-                'set_tweak': 'sound',
-                'value': 'default'
-            }, {
-                'set_tweak': 'highlight'
-            }
-        ]
-    },
]

@@ -263,8 +242,23 @@ BASE_APPEND_UNDERRIDE_RULES = [
            }
        ]
    },
-    # XXX: once m.direct is standardised everywhere, we should use it to detect
-    # a DM from the user's perspective rather than this heuristic.
+    {
+        'rule_id': 'global/underride/.m.rule.contains_display_name',
+        'conditions': [
+            {
+                'kind': 'contains_display_name'
+            }
+        ],
+        'actions': [
+            'notify',
+            {
+                'set_tweak': 'sound',
+                'value': 'default'
+            }, {
+                'set_tweak': 'highlight'
+            }
+        ]
+    },
    {
        'rule_id': 'global/underride/.m.rule.room_one_to_one',
        'conditions': [

@@ -291,34 +285,6 @@ BASE_APPEND_UNDERRIDE_RULES = [
            }
        ]
    },
-    # XXX: this is going to fire for events which aren't m.room.messages
-    # but are encrypted (e.g. m.call.*)...
-    {
-        'rule_id': 'global/underride/.m.rule.encrypted_room_one_to_one',
-        'conditions': [
-            {
-                'kind': 'room_member_count',
-                'is': '2',
-                '_id': 'member_count',
-            },
-            {
-                'kind': 'event_match',
-                'key': 'type',
-                'pattern': 'm.room.encrypted',
-                '_id': '_encrypted',
-            }
-        ],
-        'actions': [
-            'notify',
-            {
-                'set_tweak': 'sound',
-                'value': 'default'
-            }, {
-                'set_tweak': 'highlight',
-                'value': False
-            }
-        ]
-    },
    {
        'rule_id': 'global/underride/.m.rule.message',
        'conditions': [

@@ -335,25 +301,6 @@ BASE_APPEND_UNDERRIDE_RULES = [
                'value': False
            }
        ]
-    },
-    # XXX: this is going to fire for events which aren't m.room.messages
-    # but are encrypted (e.g. m.call.*)...
-    {
-        'rule_id': 'global/underride/.m.rule.encrypted',
-        'conditions': [
-            {
-                'kind': 'event_match',
-                'key': 'type',
-                'pattern': 'm.room.encrypted',
-                '_id': '_encrypted',
-            }
-        ],
-        'actions': [
-            'notify', {
-                'set_tweak': 'highlight',
-                'value': False
-            }
-        ]
-    }
+    }
]
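The comment deleted in the first hunk explains what is at stake in these moves: push rules are evaluated kind by kind in a fixed precedence order, so relocating .m.rule.contains_display_name between override and underride changes which other rules can shadow it. A rough sketch of that precedence, assuming the kind names embedded in the rule_ids above; matches() is a simplified stand-in for the real condition evaluator:

    RULE_KINDS = ['override', 'content', 'room', 'sender', 'underride']

    def matches(condition, event):
        # Stand-in: only handles exact-match event_match, for illustration.
        if condition.get('kind') == 'event_match':
            return event.get(condition['key']) == condition['pattern']
        return False

    def first_matching_rule(rules_by_kind, event):
        # Earlier kinds win outright: an override rule beats a room rule,
        # which beats an underride rule, however specific each one is.
        for kind in RULE_KINDS:
            for rule in rules_by_kind.get(kind, []):
                if all(matches(c, event) for c in rule['conditions']):
                    return rule
        return None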

synapse/push/bulk_push_rule_evaluator.py

@@ -19,19 +19,52 @@ from twisted.internet import defer
from .push_rule_evaluator import PushRuleEvaluatorForEvent

-from synapse.api.constants import EventTypes
-from synapse.visibility import filter_events_for_clients_context
+from synapse.api.constants import EventTypes, Membership
+from synapse.visibility import filter_events_for_clients

logger = logging.getLogger(__name__)

@defer.inlineCallbacks
-def evaluator_for_event(event, hs, store, context):
-    rules_by_user = yield store.bulk_get_push_rules_for_room(
-        event, context
-    )
+def _get_rules(room_id, user_ids, store):
+    rules_by_user = yield store.bulk_get_push_rules(user_ids)
+
+    rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
+
+    defer.returnValue(rules_by_user)
+
+
+@defer.inlineCallbacks
+def evaluator_for_event(event, hs, store, current_state):
+    room_id = event.room_id
+
+    # We also will want to generate notifs for other people in the room so
+    # their unread counts are correct in the event stream, but to avoid
+    # generating them for bot / AS users etc, we only do so for people who've
+    # sent a read receipt into the room.
+
+    local_users_in_room = set(
+        e.state_key for e in current_state.values()
+        if e.type == EventTypes.Member and e.membership == Membership.JOIN
+        and hs.is_mine_id(e.state_key)
+    )
+
+    # users in the room who have pushers need to get push rules run because
+    # that's how their pushers work
+    if_users_with_pushers = yield store.get_if_users_have_pushers(
+        local_users_in_room
+    )
+    user_ids = set(
+        uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
+    )
+
+    users_with_receipts = yield store.get_users_with_read_receipts_in_room(room_id)
+
+    # any users with pushers must be ours: they have pushers
+    for uid in users_with_receipts:
+        if uid in local_users_in_room:
+            user_ids.add(uid)

    # if this event is an invite event, we may need to run rules for the user
    # who's been invited, otherwise they won't get told they've been invited
    if event.type == 'm.room.member' and event.content['membership'] == 'invite':

@@ -39,13 +72,12 @@ def evaluator_for_event(event, hs, store, context):
        if invited_user and hs.is_mine_id(invited_user):
            has_pusher = yield store.user_has_pusher(invited_user)
            if has_pusher:
-                rules_by_user = dict(rules_by_user)
-                rules_by_user[invited_user] = yield store.get_push_rules_for_user(
-                    invited_user
-                )
+                user_ids.add(invited_user)
+
+    rules_by_user = yield _get_rules(room_id, user_ids, store)

    defer.returnValue(BulkPushRuleEvaluator(
-        event.room_id, rules_by_user, store
+        room_id, rules_by_user, user_ids, store
    ))

@@ -58,13 +90,14 @@ class BulkPushRuleEvaluator:
    the same logic to run the actual rules, but could be optimised further
    (see https://matrix.org/jira/browse/SYN-562)
    """
-    def __init__(self, room_id, rules_by_user, store):
+    def __init__(self, room_id, rules_by_user, users_in_room, store):
        self.room_id = room_id
        self.rules_by_user = rules_by_user
+        self.users_in_room = users_in_room
        self.store = store

    @defer.inlineCallbacks
-    def action_for_event_by_user(self, event, context):
+    def action_for_event_by_user(self, event, current_state):
        actions_by_user = {}

        # None of these users can be peeking since this list of users comes

@@ -74,25 +107,27 @@ class BulkPushRuleEvaluator:
            (u, False) for u in self.rules_by_user.keys()
        ]

-        filtered_by_user = yield filter_events_for_clients_context(
-            self.store, user_tuples, [event], {event.event_id: context}
+        filtered_by_user = yield filter_events_for_clients(
+            self.store, user_tuples, [event], {event.event_id: current_state}
        )

-        room_members = yield self.store.get_joined_users_from_context(
-            event, context
+        room_members = set(
+            e.state_key for e in current_state.values()
+            if e.type == EventTypes.Member and e.membership == Membership.JOIN
        )

        evaluator = PushRuleEvaluatorForEvent(event, len(room_members))

        condition_cache = {}

+        display_names = {}
+        for ev in current_state.values():
+            nm = ev.content.get("displayname", None)
+            if nm and ev.type == EventTypes.Member:
+                display_names[ev.state_key] = nm
+
        for uid, rules in self.rules_by_user.items():
-            display_name = None
-            member_ev_id = context.current_state_ids.get((EventTypes.Member, uid))
-            if member_ev_id:
-                member_ev = yield self.store.get_event(member_ev_id, allow_none=True)
-                if member_ev:
-                    display_name = member_ev.content.get("displayname", None)
+            display_name = display_names.get(uid, None)

            filtered = filtered_by_user[uid]
            if len(filtered) == 0:
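Both sides of this file assume current_state is a dict mapping (event type, state key) to the full state event, which is why membership, display names and the invited user can all be read straight off it. An illustrative lookup, with hypothetical values:

    # current_state[(type, state_key)] -> state event
    member_ev = current_state[("m.room.member", "@alice:example.com")]
    is_joined = member_ev.membership == Membership.JOIN
    display_name = member_ev.content.get("displayname")  # e.g. "Alice"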

synapse/push/emailpusher.py

@@ -14,7 +14,6 @@
# limitations under the License.

from twisted.internet import defer, reactor
-from twisted.internet.error import AlreadyCalled, AlreadyCancelled

import logging

@@ -93,11 +92,7 @@ class EmailPusher(object):
    def on_stop(self):
        if self.timed_call:
-            try:
-                self.timed_call.cancel()
-            except (AlreadyCalled, AlreadyCancelled):
-                pass
-            self.timed_call = None
+            self.timed_call.cancel()

    @defer.inlineCallbacks
    def on_new_notifications(self, min_stream_ordering, max_stream_ordering):

@@ -145,15 +140,12 @@ class EmailPusher(object):
        being run.
        """
        start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
-        fn = self.store.get_unread_push_actions_for_user_in_range_for_email
-        unprocessed = yield fn(self.user_id, start, self.max_stream_ordering)
+        unprocessed = yield self.store.get_unread_push_actions_for_user_in_range(
+            self.user_id, start, self.max_stream_ordering
+        )

        soonest_due_at = None

-        if not unprocessed:
-            yield self.save_last_stream_ordering_and_success(self.max_stream_ordering)
-            return
-
        for push_action in unprocessed:
            received_at = push_action['received_ts']
            if received_at is None:

@@ -198,10 +190,7 @@ class EmailPusher(object):
                soonest_due_at = should_notify_at

        if self.timed_call is not None:
-            try:
-                self.timed_call.cancel()
-            except (AlreadyCalled, AlreadyCancelled):
-                pass
+            self.timed_call.cancel()
            self.timed_call = None

        if soonest_due_at is not None:
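The try/except being dropped on the left exists because cancelling a Twisted IDelayedCall that has already fired or already been cancelled raises, and a pusher processing loop should survive that race. A self-contained sketch of the guarded cancel, using only the Twisted exceptions imported above:

    from twisted.internet.error import AlreadyCalled, AlreadyCancelled

    def cancel_quietly(timed_call):
        # timed_call is the IDelayedCall returned by reactor.callLater();
        # cancel() raises if the call already ran or was already cancelled.
        if timed_call is not None:
            try:
                timed_call.cancel()
            except (AlreadyCalled, AlreadyCancelled):
                pass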

synapse/push/httppusher.py

@@ -16,7 +16,6 @@
from synapse.push import PusherConfigException

from twisted.internet import defer, reactor
-from twisted.internet.error import AlreadyCalled, AlreadyCancelled

import logging
import push_rule_evaluator

@@ -110,11 +109,7 @@ class HttpPusher(object):
    def on_stop(self):
        if self.timed_call:
-            try:
-                self.timed_call.cancel()
-            except (AlreadyCalled, AlreadyCancelled):
-                pass
-            self.timed_call = None
+            self.timed_call.cancel()

    @defer.inlineCallbacks
    def _process(self):

@@ -146,8 +141,7 @@ class HttpPusher(object):
        run once per pusher.
        """
-        fn = self.store.get_unread_push_actions_for_user_in_range_for_http
-        unprocessed = yield fn(
+        unprocessed = yield self.store.get_unread_push_actions_for_user_in_range(
            self.user_id, self.last_stream_ordering, self.max_stream_ordering
        )

@@ -245,7 +239,7 @@ class HttpPusher(object):
    @defer.inlineCallbacks
    def _build_notification_dict(self, event, tweaks, badge):
        ctx = yield push_tools.get_context_for_event(
-            self.store, self.state_handler, event, self.user_id
+            self.state_handler, event, self.user_id
        )

        d = {

synapse/push/mailer.py

@@ -22,7 +22,7 @@ from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart

from synapse.util.async import concurrently_execute
-from synapse.push.presentable_names import (
+from synapse.util.presentable_names import (
    calculate_room_name, name_from_member_event, descriptor_from_member_events
)
from synapse.types import UserID

@@ -139,7 +139,7 @@ class Mailer(object):
        @defer.inlineCallbacks
        def _fetch_room_state(room_id):
-            room_state = yield self.state_handler.get_current_state_ids(room_id)
+            room_state = yield self.state_handler.get_current_state(room_id)
            state_by_room[room_id] = room_state

        # Run at most 3 of these at once: sync does 10 at a time but email

@@ -159,12 +159,11 @@ class Mailer(object):
            )
            rooms.append(roomvars)

-        reason['room_name'] = yield calculate_room_name(
-            self.store, state_by_room[reason['room_id']], user_id,
-            fallback_to_members=True
+        reason['room_name'] = calculate_room_name(
+            state_by_room[reason['room_id']], user_id, fallback_to_members=True
        )

-        summary_text = yield self.make_summary_text(
+        summary_text = self.make_summary_text(
            notifs_by_room, state_by_room, notif_events, user_id, reason
        )

@@ -204,15 +203,12 @@ class Mailer(object):
        )

    @defer.inlineCallbacks
-    def get_room_vars(self, room_id, user_id, notifs, notif_events, room_state_ids):
-        my_member_event_id = room_state_ids[("m.room.member", user_id)]
-        my_member_event = yield self.store.get_event(my_member_event_id)
+    def get_room_vars(self, room_id, user_id, notifs, notif_events, room_state):
+        my_member_event = room_state[("m.room.member", user_id)]
        is_invite = my_member_event.content["membership"] == "invite"

-        room_name = yield calculate_room_name(self.store, room_state_ids, user_id)
-
        room_vars = {
-            "title": room_name,
+            "title": calculate_room_name(room_state, user_id),
            "hash": string_ordinal_total(room_id),  # See sender avatar hash
            "notifs": [],
            "invite": is_invite,

@@ -222,7 +218,7 @@ class Mailer(object):
        if not is_invite:
            for n in notifs:
                notifvars = yield self.get_notif_vars(
-                    n, user_id, notif_events[n['event_id']], room_state_ids
+                    n, user_id, notif_events[n['event_id']], room_state
                )

                # merge overlapping notifs together.

@@ -247,7 +243,7 @@ class Mailer(object):
        defer.returnValue(room_vars)

    @defer.inlineCallbacks
-    def get_notif_vars(self, notif, user_id, notif_event, room_state_ids):
+    def get_notif_vars(self, notif, user_id, notif_event, room_state):
        results = yield self.store.get_events_around(
            notif['room_id'], notif['event_id'],
            before_limit=CONTEXT_BEFORE, after_limit=CONTEXT_AFTER

@@ -265,19 +261,17 @@ class Mailer(object):
            the_events.append(notif_event)

        for event in the_events:
-            messagevars = yield self.get_message_vars(notif, event, room_state_ids)
+            messagevars = self.get_message_vars(notif, event, room_state)
            if messagevars is not None:
                ret['messages'].append(messagevars)

        defer.returnValue(ret)

-    @defer.inlineCallbacks
-    def get_message_vars(self, notif, event, room_state_ids):
+    def get_message_vars(self, notif, event, room_state):
        if event.type != EventTypes.Message:
-            return
+            return None

-        sender_state_event_id = room_state_ids[("m.room.member", event.sender)]
-        sender_state_event = yield self.store.get_event(sender_state_event_id)
+        sender_state_event = room_state[("m.room.member", event.sender)]
        sender_name = name_from_member_event(sender_state_event)
        sender_avatar_url = sender_state_event.content.get("avatar_url")

@@ -305,7 +299,7 @@ class Mailer(object):
        if "body" in event.content:
            ret["body_text_plain"] = event.content["body"]

-        defer.returnValue(ret)
+        return ret

    def add_text_message_vars(self, messagevars, event):
        msgformat = event.content.get("format")

@@ -327,8 +321,7 @@ class Mailer(object):
        return messagevars

-    @defer.inlineCallbacks
-    def make_summary_text(self, notifs_by_room, room_state_ids,
+    def make_summary_text(self, notifs_by_room, state_by_room,
                          notif_events, user_id, reason):
        if len(notifs_by_room) == 1:
            # Only one room has new stuff

@@ -337,63 +330,56 @@ class Mailer(object):
            # If the room has some kind of name, use it, but we don't
            # want the generated-from-names one here otherwise we'll
            # end up with, "new message from Bob in the Bob room"
-            room_name = yield calculate_room_name(
-                self.store, room_state_ids[room_id], user_id, fallback_to_members=False
+            room_name = calculate_room_name(
+                state_by_room[room_id], user_id, fallback_to_members=False
            )

-            my_member_event_id = room_state_ids[room_id][("m.room.member", user_id)]
-            my_member_event = yield self.store.get_event(my_member_event_id)
+            my_member_event = state_by_room[room_id][("m.room.member", user_id)]
            if my_member_event.content["membership"] == "invite":
-                inviter_member_event_id = room_state_ids[room_id][
+                inviter_member_event = state_by_room[room_id][
                    ("m.room.member", my_member_event.sender)
                ]
-                inviter_member_event = yield self.store.get_event(
-                    inviter_member_event_id
-                )
                inviter_name = name_from_member_event(inviter_member_event)

                if room_name is None:
-                    defer.returnValue(INVITE_FROM_PERSON % {
+                    return INVITE_FROM_PERSON % {
                        "person": inviter_name,
                        "app": self.app_name
-                    })
+                    }
                else:
-                    defer.returnValue(INVITE_FROM_PERSON_TO_ROOM % {
+                    return INVITE_FROM_PERSON_TO_ROOM % {
                        "person": inviter_name,
                        "room": room_name,
                        "app": self.app_name,
-                    })
+                    }

            sender_name = None
            if len(notifs_by_room[room_id]) == 1:
                # There is just the one notification, so give some detail
                event = notif_events[notifs_by_room[room_id][0]["event_id"]]
-                if ("m.room.member", event.sender) in room_state_ids[room_id]:
-                    state_event_id = room_state_ids[room_id][
-                        ("m.room.member", event.sender)
-                    ]
-                    state_event = yield self.store.get_event(state_event_id)
+                if ("m.room.member", event.sender) in state_by_room[room_id]:
+                    state_event = state_by_room[room_id][("m.room.member", event.sender)]
                    sender_name = name_from_member_event(state_event)

                if sender_name is not None and room_name is not None:
-                    defer.returnValue(MESSAGE_FROM_PERSON_IN_ROOM % {
+                    return MESSAGE_FROM_PERSON_IN_ROOM % {
                        "person": sender_name,
                        "room": room_name,
                        "app": self.app_name,
-                    })
+                    }
                elif sender_name is not None:
-                    defer.returnValue(MESSAGE_FROM_PERSON % {
+                    return MESSAGE_FROM_PERSON % {
                        "person": sender_name,
                        "app": self.app_name,
-                    })
+                    }
            else:
                # There's more than one notification for this room, so just
                # say there are several
                if room_name is not None:
-                    defer.returnValue(MESSAGES_IN_ROOM % {
+                    return MESSAGES_IN_ROOM % {
                        "room": room_name,
                        "app": self.app_name,
-                    })
+                    }
                else:
                    # If the room doesn't have a name, say who the messages
                    # are from explicitly to avoid, "messages in the Bob room"

@@ -402,24 +388,22 @@ class Mailer(object):
                        for n in notifs_by_room[room_id]
                    ]))

-                    member_events = yield self.store.get_events([
-                        room_state_ids[room_id][("m.room.member", s)]
-                        for s in sender_ids
-                    ])
-
-                    defer.returnValue(MESSAGES_FROM_PERSON % {
-                        "person": descriptor_from_member_events(member_events.values()),
+                    return MESSAGES_FROM_PERSON % {
+                        "person": descriptor_from_member_events([
+                            state_by_room[room_id][("m.room.member", s)]
+                            for s in sender_ids
+                        ]),
                        "app": self.app_name,
-                    })
+                    }
        else:
            # Stuff's happened in multiple different rooms

            # ...but we still refer to the 'reason' room which triggered the mail
            if reason['room_name'] is not None:
-                defer.returnValue(MESSAGES_IN_ROOM_AND_OTHERS % {
+                return MESSAGES_IN_ROOM_AND_OTHERS % {
                    "room": reason['room_name'],
                    "app": self.app_name,
-                })
+                }
            else:
                # If the reason room doesn't have a name, say who the messages
                # are from explicitly to avoid, "messages in the Bob room"

@@ -428,15 +412,13 @@ class Mailer(object):
                    for n in notifs_by_room[reason['room_id']]
                ]))

-                member_events = yield self.store.get_events([
-                    room_state_ids[room_id][("m.room.member", s)]
-                    for s in sender_ids
-                ])
-
-                defer.returnValue(MESSAGES_FROM_PERSON_AND_OTHERS % {
-                    "person": descriptor_from_member_events(member_events.values()),
+                return MESSAGES_FROM_PERSON_AND_OTHERS % {
+                    "person": descriptor_from_member_events([
+                        state_by_room[reason['room_id']][("m.room.member", s)]
+                        for s in sender_ids
+                    ]),
                    "app": self.app_name,
-                })
+                }

    def make_room_link(self, room_id):
        # need /beta for Universal Links to work on iOS
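Most of this file's churn is a single API substitution: the left-hand code fetches state event IDs and loads full events from the store only when needed, while the right-hand code fetches fully hydrated state events in one call and does plain dict lookups afterwards. A side-by-side sketch of the two patterns, assuming the handler and store methods exactly as they appear in the call sites above:

    # Left (newer) pattern: IDs first, then fetch the events you need.
    state_ids = yield self.state_handler.get_current_state_ids(room_id)
    member_event_id = state_ids[("m.room.member", user_id)]
    member_event = yield self.store.get_event(member_event_id)

    # Right (older) pattern: full events up front, plain lookups after.
    room_state = yield self.state_handler.get_current_state(room_id)
    member_event = room_state[("m.room.member", user_id)]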

synapse/push/push_tools.py

@@ -14,18 +14,17 @@
# limitations under the License.

from twisted.internet import defer

-from synapse.push.presentable_names import (
+from synapse.util.presentable_names import (
    calculate_room_name, name_from_member_event
)
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred


@defer.inlineCallbacks
def get_badge_count(store, user_id):
-    invites, joins = yield preserve_context_over_deferred(defer.gatherResults([
-        preserve_fn(store.get_invited_rooms_for_user)(user_id),
-        preserve_fn(store.get_rooms_for_user)(user_id),
-    ], consumeErrors=True))
+    invites, joins = yield defer.gatherResults([
+        store.get_invited_rooms_for_user(user_id),
+        store.get_rooms_for_user(user_id),
+    ], consumeErrors=True)

    my_receipts_by_room = yield store.get_receipts_for_user(
        user_id, "m.read",

@@ -49,22 +48,21 @@ def get_badge_count(store, user_id):
@defer.inlineCallbacks
-def get_context_for_event(store, state_handler, ev, user_id):
+def get_context_for_event(state_handler, ev, user_id):
    ctx = {}

-    room_state_ids = yield state_handler.get_current_state_ids(ev.room_id)
+    room_state = yield state_handler.get_current_state(ev.room_id)

    # we no longer bother setting room_alias, and make room_name the
-    # human-readable name instead, be that m.room.name, an alias or
+    # human-readable name instead, be that m.room.namer, an alias or
    # a list of people in the room
-    name = yield calculate_room_name(
-        store, room_state_ids, user_id, fallback_to_single_member=False
+    name = calculate_room_name(
+        room_state, user_id, fallback_to_single_member=False
    )
    if name:
        ctx['name'] = name

-    sender_state_event_id = room_state_ids[("m.room.member", ev.sender)]
-    sender_state_event = yield store.get_event(sender_state_event_id)
+    sender_state_event = room_state[("m.room.member", ev.sender)]
    ctx['sender_display_name'] = name_from_member_event(sender_state_event)

    defer.returnValue(ctx)

synapse/push/pusherpool.py

@@ -17,7 +17,7 @@
from twisted.internet import defer

import pusher
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
+from synapse.util.logcontext import preserve_fn
from synapse.util.async import run_on_reactor

import logging

@@ -102,14 +102,14 @@ class PusherPool:
            yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])

    @defer.inlineCallbacks
-    def remove_pushers_by_user(self, user_id, except_access_token_id=None):
+    def remove_pushers_by_user(self, user_id, except_token_ids=[]):
        all = yield self.store.get_all_pushers()
        logger.info(
-            "Removing all pushers for user %s except access tokens id %r",
-            user_id, except_access_token_id
+            "Removing all pushers for user %s except access tokens ids %r",
+            user_id, except_token_ids
        )
        for p in all:
-            if p['user_name'] == user_id and p['access_token'] != except_access_token_id:
+            if p['user_name'] == user_id and p['access_token'] not in except_token_ids:
                logger.info(
                    "Removing pusher for app id %s, pushkey %s, user %s",
                    p['app_id'], p['pushkey'], p['user_name']

@@ -130,12 +130,10 @@ class PusherPool:
                if u in self.pushers:
                    for p in self.pushers[u].values():
                        deferreds.append(
-                            preserve_fn(p.on_new_notifications)(
-                                min_stream_id, max_stream_id
-                            )
+                            p.on_new_notifications(min_stream_id, max_stream_id)
                        )

-            yield preserve_context_over_deferred(defer.gatherResults(deferreds))
+            yield defer.gatherResults(deferreds)
        except:
            logger.exception("Exception in pusher on_new_notifications")

@@ -157,10 +155,10 @@ class PusherPool:
                if u in self.pushers:
                    for p in self.pushers[u].values():
                        deferreds.append(
-                            preserve_fn(p.on_new_receipts)(min_stream_id, max_stream_id)
+                            p.on_new_receipts(min_stream_id, max_stream_id)
                        )

-            yield preserve_context_over_deferred(defer.gatherResults(deferreds))
+            yield defer.gatherResults(deferreds)
        except:
            logger.exception("Exception in pusher on_new_receipts")
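The preserve_fn and preserve_context_over_deferred wrappers stripped here are synapse's logcontext plumbing: the first keeps the current logging context attached to a deferred that is fired off rather than awaited immediately, and the second restores the context around a deferred you are about to yield on. A minimal sketch of the newer (left) pattern under that assumption; pushers_for_user is a hypothetical name:

    deferreds = [
        preserve_fn(p.on_new_notifications)(min_stream_id, max_stream_id)
        for p in pushers_for_user
    ]
    yield preserve_context_over_deferred(defer.gatherResults(deferreds))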

synapse/python_dependencies.py

@@ -36,7 +36,6 @@ REQUIREMENTS = {
    "blist": ["blist"],
    "pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"],
    "pymacaroons-pynacl": ["pymacaroons"],
-    "msgpack-python>=0.3.0": ["msgpack"],
}
CONDITIONAL_REQUIREMENTS = {
    "web_client": {

@@ -52,9 +51,6 @@ CONDITIONAL_REQUIREMENTS = {
    "ldap": {
        "ldap3>=1.0": ["ldap3>=1.0"],
    },
-    "psutil": {
-        "psutil>=2.0.0": ["psutil>=2.0.0"],
-    },
}
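Each entry in these dicts maps a pip requirement string to the module specs used to verify it at startup, which is why msgpack-python (the pip name) pairs with msgpack (the import name). A hypothetical checker mirroring that shape; check_requirements here is illustrative, not synapse's actual implementation:

    import importlib
    import re

    def check_requirements(requirements):
        # Keys are pip requirement strings; values list importable module
        # specs (module name first, optional version constraint after).
        for dependency, modules in requirements.items():
            for module in modules:
                name = re.split(r"[<>=]", module)[0]
                importlib.import_module(name)  # ImportError if missing

    check_requirements({"msgpack-python>=0.3.0": ["msgpack"]})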

synapse/replication/resource.py

@@ -17,7 +17,6 @@ from synapse.http.servlet import parse_integer, parse_string
from synapse.http.server import request_handler, finish_request
from synapse.replication.pusher_resource import PusherResource
from synapse.replication.presence_resource import PresenceResource
-from synapse.api.errors import SynapseError

from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET

@@ -41,9 +40,7 @@ STREAM_NAMES = (
    ("backfill",),
    ("push_rules",),
    ("pushers",),
-    ("caches",),
-    ("to_device",),
-    ("public_rooms",),
+    ("state",),
)

@@ -73,7 +70,6 @@ class ReplicationResource(Resource):
    * "backfill": Old events that have been backfilled from other servers.
    * "push_rules": Per user changes to push rules.
    * "pushers": Per user changes to their pushers.
-    * "caches": Cache invalidations.

    The API takes two additional query parameters:

@@ -132,8 +128,7 @@ class ReplicationResource(Resource):
        backfill_token = yield self.store.get_current_backfill_token()
        push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
        pushers_token = self.store.get_pushers_stream_token()
-        caches_token = self.store.get_cache_stream_token()
-        public_rooms_token = self.store.get_current_public_room_stream_id()
+        state_token = self.store.get_state_stream_token()

        defer.returnValue(_ReplicationToken(
            room_stream_token,

@@ -144,10 +139,7 @@ class ReplicationResource(Resource):
            backfill_token,
            push_rules_token,
            pushers_token,
-            0,  # State stream is no longer a thing
-            caches_token,
-            int(stream_token.to_device_key),
-            int(public_rooms_token),
+            state_token,
        ))

    @request_handler()

@@ -167,8 +159,7 @@ class ReplicationResource(Resource):
        def replicate():
            return self.replicate(request_streams, limit)

-        writer = yield self.notifier.wait_for_replication(replicate, timeout)
-        result = writer.finish()
+        result = yield self.notifier.wait_for_replication(replicate, timeout)

        for stream_name, stream_content in result.items():
            logger.info(
@@ -186,10 +177,7 @@ class ReplicationResource(Resource):
    def replicate(self, request_streams, limit):
        writer = _Writer()
        current_token = yield self.current_replication_token()
-        logger.debug("Replicating up to %r", current_token)
-
-        if limit == 0:
-            raise SynapseError(400, "Limit cannot be 0")
+        logger.info("Replicating up to %r", current_token)

        yield self.account_data(writer, current_token, limit, request_streams)
        yield self.events(writer, current_token, limit, request_streams)

@@ -199,13 +187,11 @@ class ReplicationResource(Resource):
        yield self.receipts(writer, current_token, limit, request_streams)
        yield self.push_rules(writer, current_token, limit, request_streams)
        yield self.pushers(writer, current_token, limit, request_streams)
-        yield self.caches(writer, current_token, limit, request_streams)
-        yield self.to_device(writer, current_token, limit, request_streams)
-        yield self.public_rooms(writer, current_token, limit, request_streams)
+        yield self.state(writer, current_token, limit, request_streams)
        self.streams(writer, current_token, request_streams)

-        logger.debug("Replicated %d rows", writer.total)
-        defer.returnValue(writer)
+        logger.info("Replicated %d rows", writer.total)
+        defer.returnValue(writer.finish())

    def streams(self, writer, current_token, request_streams):
        request_token = request_streams.get("streams")

@@ -242,48 +228,27 @@ class ReplicationResource(Resource):
        request_events = current_token.events
        if request_backfill is None:
            request_backfill = current_token.backfill
-
-        no_new_tokens = (
-            request_events == current_token.events
-            and request_backfill == current_token.backfill
-        )
-        if no_new_tokens:
-            return
-
        res = yield self.store.get_all_new_events(
            request_backfill, request_events,
            current_token.backfill, current_token.events,
            limit
        )
-
-        upto_events_token = _position_from_rows(
-            res.new_forward_events, current_token.events
-        )
-
-        upto_backfill_token = _position_from_rows(
-            res.new_backfill_events, current_token.backfill
-        )
-
-        if request_events != upto_events_token:
-            writer.write_header_and_rows("events", res.new_forward_events, (
-                "position", "internal", "json", "state_group"
-            ), position=upto_events_token)
-
-        if request_backfill != upto_backfill_token:
-            writer.write_header_and_rows("backfill", res.new_backfill_events, (
-                "position", "internal", "json", "state_group",
-            ), position=upto_backfill_token)
-
+        writer.write_header_and_rows("events", res.new_forward_events, (
+            "position", "internal", "json", "state_group"
+        ))
+        writer.write_header_and_rows("backfill", res.new_backfill_events, (
+            "position", "internal", "json", "state_group"
+        ))
        writer.write_header_and_rows(
            "forward_ex_outliers", res.forward_ex_outliers,
-            ("position", "event_id", "state_group"),
+            ("position", "event_id", "state_group")
        )
        writer.write_header_and_rows(
            "backward_ex_outliers", res.backward_ex_outliers,
-            ("position", "event_id", "state_group"),
+            ("position", "event_id", "state_group")
        )
        writer.write_header_and_rows(
-            "state_resets", res.state_resets, ("position",),
+            "state_resets", res.state_resets, ("position",)
        )

    @defer.inlineCallbacks
@@ -292,38 +257,29 @@ class ReplicationResource(Resource):
        request_presence = request_streams.get("presence")

-        if request_presence is not None and request_presence != current_position:
+        if request_presence is not None:
            presence_rows = yield self.presence_handler.get_all_presence_updates(
                request_presence, current_position
            )
-            upto_token = _position_from_rows(presence_rows, current_position)
            writer.write_header_and_rows("presence", presence_rows, (
                "position", "user_id", "state", "last_active_ts",
                "last_federation_update_ts", "last_user_sync_ts",
                "status_msg", "currently_active",
-            ), position=upto_token)
+            ))

    @defer.inlineCallbacks
    def typing(self, writer, current_token, request_streams):
-        current_position = current_token.typing
+        current_position = current_token.presence

        request_typing = request_streams.get("typing")

-        if request_typing is not None and request_typing != current_position:
-            # If they have a higher token than current max, we can assume that
-            # they had been talking to a previous instance of the master. Since
-            # we reset the token on restart, the best (but hacky) thing we can
-            # do is to simply resend down all the typing notifications.
-            if request_typing > current_position:
-                request_typing = 0
-
+        if request_typing is not None:
            typing_rows = yield self.typing_handler.get_all_typing_updates(
                request_typing, current_position
            )
-            upto_token = _position_from_rows(typing_rows, current_position)
            writer.write_header_and_rows("typing", typing_rows, (
                "position", "room_id", "typing"
-            ), position=upto_token)
+            ))

    @defer.inlineCallbacks
    def receipts(self, writer, current_token, limit, request_streams):

@@ -331,14 +287,13 @@ class ReplicationResource(Resource):
        request_receipts = request_streams.get("receipts")

-        if request_receipts is not None and request_receipts != current_position:
+        if request_receipts is not None:
            receipts_rows = yield self.store.get_all_updated_receipts(
                request_receipts, current_position, limit
            )
-            upto_token = _position_from_rows(receipts_rows, current_position)
            writer.write_header_and_rows("receipts", receipts_rows, (
                "position", "room_id", "receipt_type", "user_id", "event_id", "data"
-            ), position=upto_token)
+            ))

    @defer.inlineCallbacks
    def account_data(self, writer, current_token, limit, request_streams):

@@ -353,36 +308,23 @@ class ReplicationResource(Resource):
                user_account_data = current_position
            if room_account_data is None:
                room_account_data = current_position
-
-            no_new_tokens = (
-                user_account_data == current_position
-                and room_account_data == current_position
-            )
-            if no_new_tokens:
-                return
-
            user_rows, room_rows = yield self.store.get_all_updated_account_data(
                user_account_data, room_account_data, current_position, limit
            )
-
-            upto_users_token = _position_from_rows(user_rows, current_position)
-            upto_rooms_token = _position_from_rows(room_rows, current_position)
-
            writer.write_header_and_rows("user_account_data", user_rows, (
                "position", "user_id", "type", "content"
-            ), position=upto_users_token)
+            ))
            writer.write_header_and_rows("room_account_data", room_rows, (
                "position", "user_id", "room_id", "type", "content"
-            ), position=upto_rooms_token)
+            ))

        if tag_account_data is not None:
            tag_rows = yield self.store.get_all_updated_tags(
                tag_account_data, current_position, limit
            )
-            upto_tag_token = _position_from_rows(tag_rows, current_position)
            writer.write_header_and_rows("tag_account_data", tag_rows, (
                "position", "user_id", "room_id", "tags"
-            ), position=upto_tag_token)
+            ))

    @defer.inlineCallbacks
    def push_rules(self, writer, current_token, limit, request_streams):
@@ -390,15 +332,14 @@ class ReplicationResource(Resource):
        push_rules = request_streams.get("push_rules")

-        if push_rules is not None and push_rules != current_position:
+        if push_rules is not None:
            rows = yield self.store.get_all_push_rule_updates(
                push_rules, current_position, limit
            )
-            upto_token = _position_from_rows(rows, current_position)
            writer.write_header_and_rows("push_rules", rows, (
                "position", "event_stream_ordering", "user_id", "rule_id", "op",
                "priority_class", "priority", "conditions", "actions"
-            ), position=upto_token)
+            ))

    @defer.inlineCallbacks
    def pushers(self, writer, current_token, limit, request_streams):

@@ -406,64 +347,37 @@ class ReplicationResource(Resource):
        pushers = request_streams.get("pushers")

-        if pushers is not None and pushers != current_position:
+        if pushers is not None:
            updated, deleted = yield self.store.get_all_updated_pushers(
                pushers, current_position, limit
            )
-            upto_token = _position_from_rows(updated, current_position)
            writer.write_header_and_rows("pushers", updated, (
                "position", "user_id", "access_token", "profile_tag", "kind",
                "app_id", "app_display_name", "device_display_name", "pushkey",
                "ts", "lang", "data"
-            ), position=upto_token)
+            ))
            writer.write_header_and_rows("deleted_pushers", deleted, (
                "position", "user_id", "app_id", "pushkey"
-            ), position=upto_token)
+            ))

    @defer.inlineCallbacks
-    def caches(self, writer, current_token, limit, request_streams):
-        current_position = current_token.caches
-
-        caches = request_streams.get("caches")
-
-        if caches is not None and caches != current_position:
-            updated_caches = yield self.store.get_all_updated_caches(
-                caches, current_position, limit
-            )
-            upto_token = _position_from_rows(updated_caches, current_position)
-            writer.write_header_and_rows("caches", updated_caches, (
-                "position", "cache_func", "keys", "invalidation_ts"
-            ), position=upto_token)
-
-    @defer.inlineCallbacks
-    def to_device(self, writer, current_token, limit, request_streams):
-        current_position = current_token.to_device
-
-        to_device = request_streams.get("to_device")
-
-        if to_device is not None and to_device != current_position:
-            to_device_rows = yield self.store.get_all_new_device_messages(
-                to_device, current_position, limit
-            )
-            upto_token = _position_from_rows(to_device_rows, current_position)
-            writer.write_header_and_rows("to_device", to_device_rows, (
-                "position", "user_id", "device_id", "message_json"
-            ), position=upto_token)
-
-    @defer.inlineCallbacks
-    def public_rooms(self, writer, current_token, limit, request_streams):
-        current_position = current_token.public_rooms
-
-        public_rooms = request_streams.get("public_rooms")
-
-        if public_rooms is not None and public_rooms != current_position:
-            public_rooms_rows = yield self.store.get_all_new_public_rooms(
-                public_rooms, current_position, limit
-            )
-            upto_token = _position_from_rows(public_rooms_rows, current_position)
-            writer.write_header_and_rows("public_rooms", public_rooms_rows, (
-                "position", "room_id", "visibility"
-            ), position=upto_token)
+    def state(self, writer, current_token, limit, request_streams):
+        current_position = current_token.state
+
+        state = request_streams.get("state")
+
+        if state is not None:
+            state_groups, state_group_state = (
+                yield self.store.get_all_new_state_groups(
+                    state, current_position, limit
+                )
+            )
+            writer.write_header_and_rows("state_groups", state_groups, (
+                "position", "room_id", "event_id"
+            ))
+            writer.write_header_and_rows("state_group_state", state_group_state, (
+                "position", "type", "state_key", "event_id"
+            ))


class _Writer(object):
@@ -473,11 +387,11 @@ class _Writer(object):
        self.total = 0

    def write_header_and_rows(self, name, rows, fields, position=None):
-        if not rows:
-            return
-
        if position is None:
-            if rows:
-                position = rows[-1][0]
-            else:
-                return
+            position = rows[-1][0]

        self.streams[name] = {
            "position": position if type(position) is int else str(position),

@@ -487,16 +401,13 @@ class _Writer(object):
        self.total += len(rows)

-    def __nonzero__(self):
-        return bool(self.total)
-
    def finish(self):
        return self.streams


class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
    "events", "presence", "typing", "receipts", "account_data", "backfill",
-    "push_rules", "pushers", "state", "caches", "to_device", "public_rooms",
+    "push_rules", "pushers", "state"
))):
    __slots__ = []

@@ -511,20 +422,3 @@ class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
    def __str__(self):
        return "_".join(str(value) for value in self)

-
-def _position_from_rows(rows, current_position):
-    """Calculates a position to return for a stream. Ideally we want to return the
-    position of the last row, as that will be the most correct. However, if there
-    are no rows we fall back to using the current position to stop us from
-    repeatedly hitting the storage layer unnecessarily thinking there are updates.
-    (Not all advances of the token correspond to an actual update)
-
-    We can't just always return the current position, as we often limit the
-    number of rows we replicate, and so the stream may lag. The assumption is
-    that if the storage layer returns no new rows then we are not lagging and
-    we are at the `current_position`.
-    """
-    if rows:
-        return rows[-1][0]
-    return current_position
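The deleted _position_from_rows helper is the heart of the left-hand token handling: each stream reports the position of the last row it actually sent, falling back to current_position only when there was nothing new to send. A small worked example, assuming rows whose first element is the stream position, as in the write_header_and_rows calls above:

    rows = [(101, "row a"), (104, "row b")]
    assert _position_from_rows(rows, current_position=110) == 104  # last row
    assert _position_from_rows([], current_position=110) == 110    # no rows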

synapse/replication/slave/storage/_base.py

@@ -14,43 +14,15 @@
# limitations under the License.

from synapse.storage._base import SQLBaseStore
-from synapse.storage.engines import PostgresEngine
from twisted.internet import defer
-from ._slaved_id_tracker import SlavedIdTracker
-
-import logging
-
-logger = logging.getLogger(__name__)


class BaseSlavedStore(SQLBaseStore):
    def __init__(self, db_conn, hs):
        super(BaseSlavedStore, self).__init__(hs)
-        if isinstance(self.database_engine, PostgresEngine):
-            self._cache_id_gen = SlavedIdTracker(
-                db_conn, "cache_invalidation_stream", "stream_id",
-            )
-        else:
-            self._cache_id_gen = None

    def stream_positions(self):
-        pos = {}
-        if self._cache_id_gen:
-            pos["caches"] = self._cache_id_gen.get_current_token()
-        return pos
+        return {}

    def process_replication(self, result):
-        stream = result.get("caches")
-        if stream:
-            for row in stream["rows"]:
-                (
-                    position, cache_func, keys, invalidation_ts,
-                ) = row
-
-                try:
-                    getattr(self, cache_func).invalidate(tuple(keys))
-                except AttributeError:
-                    logger.info("Got unexpected cache_func: %r", cache_func)
-            self._cache_id_gen.advance(int(stream["position"]))
        return defer.succeed(None)
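The deleted process_replication body documents the shape of a "caches" stream row: (position, cache_func, keys, invalidation_ts), where cache_func names a cached store method on the slave. One illustrative row and the eviction it triggers, with hypothetical values:

    position, cache_func, keys, invalidation_ts = (
        1423, "get_rooms_for_user", ["@alice:example.com"], 1478000000000
    )
    # Look the cached method up by name and evict the matching entry:
    getattr(store, cache_func).invalidate(tuple(keys))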

synapse/replication/slave/storage/appservice.py

@@ -28,13 +28,3 @@ class SlavedApplicationServiceStore(BaseSlavedStore):
    get_app_service_by_token = DataStore.get_app_service_by_token.__func__
    get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__
-    get_app_services = DataStore.get_app_services.__func__
-    get_new_events_for_appservice = DataStore.get_new_events_for_appservice.__func__
-    create_appservice_txn = DataStore.create_appservice_txn.__func__
-    get_appservices_by_state = DataStore.get_appservices_by_state.__func__
-    get_oldest_unsent_txn = DataStore.get_oldest_unsent_txn.__func__
-    _get_last_txn = DataStore._get_last_txn.__func__
-    complete_appservice_txn = DataStore.complete_appservice_txn.__func__
-    get_appservice_state = DataStore.get_appservice_state.__func__
-    set_appservice_last_pos = DataStore.set_appservice_last_pos.__func__
-    set_appservice_state = DataStore.set_appservice_state.__func__
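The repeated .__func__ access is a Python 2 idiom: DataStore.get_app_services is an unbound method, and extracting __func__ yields the underlying plain function so it can be re-bound as a method of the slave class. The same trick in miniature:

    class Master(object):
        def greet(self):
            return "hello from %s" % type(self).__name__

    class Slave(object):
        # Python 2: Master.greet is an unbound method; __func__ is the raw
        # function, which binds cleanly to Slave instances.
        greet = Master.greet.__func__

    print(Slave().greet())  # "hello from Slave"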

synapse/replication/slave/storage/deviceinbox.py

@@ -1,53 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from ._base import BaseSlavedStore
-from ._slaved_id_tracker import SlavedIdTracker
-from synapse.storage import DataStore
-from synapse.util.caches.stream_change_cache import StreamChangeCache
-
-
-class SlavedDeviceInboxStore(BaseSlavedStore):
-    def __init__(self, db_conn, hs):
-        super(SlavedDeviceInboxStore, self).__init__(db_conn, hs)
-        self._device_inbox_id_gen = SlavedIdTracker(
-            db_conn, "device_max_stream_id", "stream_id",
-        )
-        self._device_inbox_stream_cache = StreamChangeCache(
-            "DeviceInboxStreamChangeCache",
-            self._device_inbox_id_gen.get_current_token()
-        )
-
-    get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__
-    get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__
-    delete_messages_for_device = DataStore.delete_messages_for_device.__func__
-
-    def stream_positions(self):
-        result = super(SlavedDeviceInboxStore, self).stream_positions()
-        result["to_device"] = self._device_inbox_id_gen.get_current_token()
-        return result
-
-    def process_replication(self, result):
-        stream = result.get("to_device")
-        if stream:
-            self._device_inbox_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                stream_id = row[0]
-                user_id = row[1]
-                self._device_inbox_stream_cache.entity_has_changed(
-                    user_id, stream_id
-                )
-        return super(SlavedDeviceInboxStore, self).process_replication(result)

synapse/replication/slave/storage/directory.py

@@ -1,23 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from ._base import BaseSlavedStore
-from synapse.storage.directory import DirectoryStore
-
-
-class DirectoryStore(BaseSlavedStore):
-    get_aliases_for_room = DirectoryStore.__dict__[
-        "get_aliases_for_room"
-    ]

Some files were not shown because too many files have changed in this diff.