Compare commits


3 Commits

Author          SHA1        Message                                   Date
Erik Johnston   6c11bb956c  Fix pypy repr for DomainSpecificString    2016-04-11 10:58:24 +01:00
Erik Johnston   58de897c77  Merge branch 'develop' into HEAD          2016-04-11 10:40:49 +01:00
TimePath        06602461e7  Support PyPy                              2016-04-10 18:30:56 +10:00
325 changed files with 7522 additions and 24417 deletions

View File

@@ -1,471 +1,3 @@
Changes in synapse v0.18.1 (2016-10-05)
========================================
No changes since v0.18.1-rc1
Changes in synapse v0.18.1-rc1 (2016-09-30)
===========================================
Features:
* Add total_room_count_estimate to ``/publicRooms`` (PR #1133)
Changes:
* Time out typing over federation (PR #1140)
* Restructure LDAP authentication (PR #1153)
Bug fixes:
* Fix 3pid invites when server is already in the room (PR #1136)
* Fix upgrading with SQLite taking lots of CPU for a few days
after upgrade (PR #1144)
* Fix upgrading from very old database versions (PR #1145)
* Fix port script to work with recently added tables (PR #1146)
Changes in synapse v0.18.0 (2016-09-19)
=======================================
The release includes major changes to the state storage database schemas, which
significantly reduce database size. Synapse will attempt to upgrade the current
data in the background. Servers with large SQLite databases may experience
degraded performance while this upgrade is in progress, so you may want to
consider migrating to Postgres before upgrading very large SQLite databases.
Changes:
* Make public room search case insensitive (PR #1127)
Bug fixes:
* Fix and clean up publicRooms pagination (PR #1129)
Changes in synapse v0.18.0-rc1 (2016-09-16)
===========================================
Features:
* Add ``only=highlight`` on ``/notifications`` (PR #1081)
* Add server param to /publicRooms (PR #1082)
* Allow clients to ask for the whole of a single state event (PR #1094)
* Add is_direct param to /createRoom (PR #1108)
* Add pagination support to publicRooms (PR #1121)
* Add very basic filter API to /publicRooms (PR #1126)
* Add basic direct to device messaging support for E2E (PR #1074, #1084, #1104,
#1111)
Changes:
* Move to storing state_groups_state as deltas, greatly reducing DB size (PR
#1065)
* Reduce amount of state pulled out of the DB during common requests (PR #1069)
* Allow PDF to be rendered from media repo (PR #1071)
* Reindex state_groups_state after pruning (PR #1085)
* Clobber EDUs in send queue (PR #1095)
* Conform better to the CAS protocol specification (PR #1100)
* Limit how often we ask for keys from dead servers (PR #1114)
Bug fixes:
* Fix /notifications API when used with ``from`` param (PR #1080)
* Fix backfill when cannot find an event. (PR #1107)
Changes in synapse v0.17.3 (2016-09-09)
=======================================
This release fixes a major bug that stopped servers from handling rooms with
over 1000 members.
Changes in synapse v0.17.2 (2016-09-08)
=======================================
This release contains security bug fixes. Please upgrade.
No changes since v0.17.2-rc1
Changes in synapse v0.17.2-rc1 (2016-09-05)
===========================================
Features:
* Start adding store-and-forward direct-to-device messaging (PR #1046, #1050,
#1062, #1066)
Changes:
* Avoid pulling the full state of a room out so often (PR #1047, #1049, #1063,
#1068)
* Don't notify for online to online presence transitions. (PR #1054)
* Occasionally persist unpersisted presence updates (PR #1055)
* Allow application services to have an optional 'url' (PR #1056)
* Clean up old sent transactions from DB (PR #1059)
Bug fixes:
* Fix None check in backfill (PR #1043)
* Fix membership changes to be idempotent (PR #1067)
* Fix bug in get_pdu where it would sometimes return events with incorrect
signature
Changes in synapse v0.17.1 (2016-08-24)
=======================================
Changes:
* Delete old received_transactions rows (PR #1038)
* Pass through user-supplied content in /join/$room_id (PR #1039)
Bug fixes:
* Fix bug with backfill (PR #1040)
Changes in synapse v0.17.1-rc1 (2016-08-22)
===========================================
Features:
* Add notification API (PR #1028)
Changes:
* Don't print stack traces when failing to get remote keys (PR #996)
* Various federation /event/ perf improvements (PR #998)
* Only process one local membership event per room at a time (PR #1005)
* Move default display name push rule (PR #1011, #1023)
* Fix up preview URL API. Add tests. (PR #1015)
* Set ``Content-Security-Policy`` on media repo (PR #1021)
* Make notify_interested_services faster (PR #1022)
* Add usage stats to prometheus monitoring (PR #1037)
Bug fixes:
* Fix token login (PR #993)
* Fix CAS login (PR #994, #995)
* Fix /sync to not clobber status_msg (PR #997)
* Fix redacted state events to include prev_content (PR #1003)
* Fix some bugs in the auth/ldap handler (PR #1007)
* Fix backfill request to limit URI length, so that remotes don't reject the
requests due to path length limits (PR #1012)
* Fix AS push code to not send duplicate events (PR #1025)
Changes in synapse v0.17.0 (2016-08-08)
=======================================
This release contains significant security bug fixes regarding authenticating
events received over federation. PLEASE UPGRADE.
This release changes the LDAP configuration format in a backwards incompatible
way, see PR #843 for details.
Changes:
* Add federation /version API (PR #990)
* Make psutil dependency optional (PR #992)
Bug fixes:
* Fix URL preview API to exclude HTML comments in description (PR #988)
* Fix error handling of remote joins (PR #991)
Changes in synapse v0.17.0-rc4 (2016-08-05)
===========================================
Changes:
* Change the way we summarize URLs when previewing (PR #973)
* Add new ``/state_ids/`` federation API (PR #979)
* Speed up processing of ``/state/`` response (PR #986)
Bug fixes:
* Fix event persistence when event has already been partially persisted
(PR #975, #983, #985)
* Fix port script to also copy across backfilled events (PR #982)
Changes in synapse v0.17.0-rc3 (2016-08-02)
===========================================
Changes:
* Forbid non-ASes from registering users whose names begin with '_' (PR #958)
* Add some basic admin API docs (PR #963)
Bug fixes:
* Send the correct host header when fetching keys (PR #941)
* Fix joining a room that has missing auth events (PR #964)
* Fix various push bugs (PR #966, #970)
* Fix adding emails on registration (PR #968)
Changes in synapse v0.17.0-rc2 (2016-08-02)
===========================================
(This release did not include the changes advertised and was identical to RC1)
Changes in synapse v0.17.0-rc1 (2016-07-28)
===========================================
This release changes the LDAP configuration format in a backwards incompatible
way, see PR #843 for details.
Features:
* Add purge_media_cache admin API (PR #902)
* Add deactivate account admin API (PR #903)
* Add optional pepper to password hashing (PR #907, #910 by KentShikama)
* Add an admin option to shared secret registration (breaks backwards compat)
(PR #909)
* Add purge local room history API (PR #911, #923, #924)
* Add requestToken endpoints (PR #915)
* Add an /account/deactivate endpoint (PR #921)
* Add filter param to /messages. Add 'contains_url' to filter. (PR #922)
* Add device_id support to /login (PR #929)
* Add device_id support to /v2/register flow. (PR #937, #942)
* Add GET /devices endpoint (PR #939, #944)
* Add GET /device/{deviceId} (PR #943)
* Add update and delete APIs for devices (PR #949)
Changes:
* Rewrite LDAP Authentication against ldap3 (PR #843 by mweinelt)
* Linearize some federation endpoints based on (origin, room_id) (PR #879)
* Remove the legacy v0 content upload API. (PR #888)
* Use similar naming we use in email notifs for push (PR #894)
* Optionally include password hash in createUser endpoint (PR #905 by
KentShikama)
* Use a query that postgresql optimises better for get_events_around (PR #906)
* Fall back to 'username' if 'user' is not given for appservice registration.
(PR #927 by Half-Shot)
* Add metrics for psutil derived memory usage (PR #936)
* Record device_id in client_ips (PR #938)
* Send the correct host header when fetching keys (PR #941)
* Log the hostname the reCAPTCHA was completed on (PR #946)
* Make the device id on e2e key upload optional (PR #956)
* Add r0.2.0 to the "supported versions" list (PR #960)
* Don't include name of room for invites in push (PR #961)
Bug fixes:
* Fix substitution failure in mail template (PR #887)
* Put most recent 20 messages in email notif (PR #892)
* Ensure that the guest user is in the database when upgrading accounts
(PR #914)
* Fix various edge cases in auth handling (PR #919)
* Fix 500 ISE when sending alias event without a state_key (PR #925)
* Fix bug where we stored rejections in the state_group, persist all
rejections (PR #948)
* Fix lack of check of if the user is banned when handling 3pid invites
(PR #952)
* Fix a couple of bugs in the transaction and keyring code (PR #954, #955)
Changes in synapse v0.16.1-r1 (2016-07-08)
==========================================
THIS IS A CRITICAL SECURITY UPDATE.
This fixes a bug which allowed users' accounts to be accessed by unauthorised
users.
Changes in synapse v0.16.1 (2016-06-20)
=======================================
Bug fixes:
* Fix assorted bugs in ``/preview_url`` (PR #872)
* Fix TypeError when setting unicode passwords (PR #873)
Performance improvements:
* Turn ``use_frozen_events`` off by default (PR #877)
* Disable responding with canonical json for federation (PR #878)
Changes in synapse v0.16.1-rc1 (2016-06-15)
===========================================
Features: None
Changes:
* Log requester for ``/publicRoom`` endpoints when possible (PR #856)
* 502 on ``/thumbnail`` when can't connect to remote server (PR #862)
* Linearize fetching of gaps on incoming events (PR #871)
Bug fixes:
* Fix bug where rooms were marked as published by default (PR #857)
* Fix bug where joining room with an event with invalid sender (PR #868)
* Fix bug where backfilled events were sent down sync streams (PR #869)
* Fix bug where outgoing connections could wedge indefinitely, causing push
notifications to be unreliable (PR #870)
Performance improvements:
* Improve ``/publicRooms`` performance (PR #859)
Changes in synapse v0.16.0 (2016-06-09)
=======================================
NB: As of v0.14 all AS config files must have an ID field.
Bug fixes:
* Don't make rooms published by default (PR #857)
Changes in synapse v0.16.0-rc2 (2016-06-08)
===========================================
Features:
* Add configuration option for tuning GC via ``gc.set_threshold`` (PR #849)
Changes:
* Record metrics about GC (PR #771, #847, #852)
* Add metric counter for number of persisted events (PR #841)
Bug fixes:
* Fix 'From' header in email notifications (PR #843)
* Fix presence where timeouts were not being fired for the first 8h after
restarts (PR #842)
* Fix bug where synapse sent malformed transactions to AS's when retrying
transactions (Commits 310197b, 8437906)
Performance improvements:
* Remove event fetching from DB threads (PR #835)
* Change the way we cache events (PR #836)
* Add events to cache when we persist them (PR #840)
Changes in synapse v0.16.0-rc1 (2016-06-03)
===========================================
Version 0.15 was not released. See v0.15.0-rc1 below for additional changes.
Features:
* Add email notifications for missed messages (PR #759, #786, #799, #810, #815,
#821)
* Add a ``url_preview_ip_range_whitelist`` config param (PR #760)
* Add /report endpoint (PR #762)
* Add basic ignore user API (PR #763)
* Add an openidish mechanism for proving that you own a given user_id (PR #765)
* Allow clients to specify a server_name to avoid 'No known servers' (PR #794)
* Add secondary_directory_servers option to fetch room list from other servers
(PR #808, #813)
Changes:
* Report per request metrics for all of the things using request_handler (PR
#756)
* Correctly handle ``NULL`` password hashes from the database (PR #775)
* Allow receipts for events we haven't seen in the db (PR #784)
* Make synctl read a cache factor from config file (PR #785)
* Increment badge count per missed convo, not per msg (PR #793)
* Special case m.room.third_party_invite event auth to match invites (PR #814)
Bug fixes:
* Fix typo in event_auth servlet path (PR #757)
* Fix password reset (PR #758)
Performance improvements:
* Reduce database inserts when sending transactions (PR #767)
* Queue events by room for persistence (PR #768)
* Add cache to ``get_user_by_id`` (PR #772)
* Add and use ``get_domain_from_id`` (PR #773)
* Use tree cache for ``get_linearized_receipts_for_room`` (PR #779)
* Remove unused indices (PR #782)
* Add caches to ``bulk_get_push_rules*`` (PR #804)
* Cache ``get_event_reference_hashes`` (PR #806)
* Add ``get_users_with_read_receipts_in_room`` cache (PR #809)
* Use state to calculate ``get_users_in_room`` (PR #811)
* Load push rules in storage layer so that they get cached (PR #825)
* Make ``get_joined_hosts_for_room`` use get_users_in_room (PR #828)
* Poke notifier on next reactor tick (PR #829)
* Change CacheMetrics to be quicker (PR #830)
Changes in synapse v0.15.0-rc1 (2016-04-26)
===========================================
Features:
* Add login support for Javascript Web Tokens, thanks to Niklas Riekenbrauck
(PR #671,#687)
* Add URL previewing support (PR #688)
* Add login support for LDAP, thanks to Christoph Witzany (PR #701)
* Add GET endpoint for pushers (PR #716)
Changes:
* Never notify for member events (PR #667)
* Deduplicate identical ``/sync`` requests (PR #668)
* Require user to have left room to forget room (PR #673)
* Use DNS cache if within TTL (PR #677)
* Let users see their own leave events (PR #699)
* Deduplicate membership changes (PR #700)
* Increase performance of pusher code (PR #705)
* Respond with error status 504 if failed to talk to remote server (PR #731)
* Increase search performance on postgres (PR #745)
Bug fixes:
* Fix bug where disabling all notifications still resulted in push (PR #678)
* Fix bug where users couldn't reject remote invites if remote refused (PR #691)
* Fix bug where synapse attempted to backfill from itself (PR #693)
* Fix bug where profile information was not correctly added when joining remote
rooms (PR #703)
* Fix bug where register API required incorrect key name for AS registration
(PR #727)
Changes in synapse v0.14.0 (2016-03-30)
=======================================
@@ -979,7 +511,7 @@ Configuration:
* Add support for changing the bind host of the metrics listener via the
``metrics_bind_host`` option.
Changes in synapse v0.9.0-r5 (2015-05-21)
=========================================
@@ -1321,7 +853,7 @@ See UPGRADE for information about changes to the client server API, including
breaking backwards compatibility with VoIP calls and registration API.
Homeserver:
* When a user changes their displayname or avatar the server will now update
all their join states to reflect this.
* The server now adds "age" key to events to indicate how old they are. This
is clock independent, so at no point does any server or webclient have to
@@ -1379,7 +911,7 @@ Changes in synapse 0.2.2 (2014-09-06)
=====================================
Homeserver:
* When the server returns state events it now also includes the previous
content.
* Add support for inviting people when creating a new room.
* Make the homeserver inform the room via `m.room.aliases` when a new alias
@@ -1391,7 +923,7 @@ Webclient:
* Handle `m.room.aliases` events.
* Asynchronously send messages and show a local echo.
* Inform the UI when a message failed to send.
* Only autoscroll on receiving a new message if the user was already at the
bottom of the screen.
* Add support for ban/kick reasons.

View File

@@ -11,10 +11,8 @@ recursive-include synapse/storage/schema *.sql
 recursive-include synapse/storage/schema *.py
 recursive-include docs *
-recursive-include res *
 recursive-include scripts *
 recursive-include scripts-dev *
-recursive-include synapse *.pyi
 recursive-include tests *.py
 recursive-include synapse/static *.css
@@ -24,7 +22,5 @@ recursive-include synapse/static *.js
 exclude jenkins.sh
 exclude jenkins*.sh
-exclude jenkins*
-recursive-exclude jenkins *.sh
 prune demo/etc

View File

@@ -11,8 +11,8 @@ VoIP. The basics you need to know to get up and running are:
 like ``#matrix:matrix.org`` or ``#test:localhost:8448``.
 - Matrix user IDs look like ``@matthew:matrix.org`` (although in the future
-you will normally refer to yourself and others using a third party identifier
-(3PID): email address, phone number, etc rather than manipulating Matrix user IDs)
+you will normally refer to yourself and others using a 3PID: email
+address, phone number, etc rather than manipulating Matrix user IDs)
 The overall architecture is::
@@ -58,13 +58,12 @@ the spec in the context of a codebase and let you run your own homeserver and
 generally help bootstrap the ecosystem.
 In Matrix, every user runs one or more Matrix clients, which connect through to
-a Matrix homeserver. The homeserver stores all their personal chat history and
-user account information - much as a mail client connects through to an
-IMAP/SMTP server. Just like email, you can either run your own Matrix
-homeserver and control and own your own communications and history or use one
-hosted by someone else (e.g. matrix.org) - there is no single point of control
-or mandatory service provider in Matrix, unlike WhatsApp, Facebook, Hangouts,
-etc.
+a Matrix homeserver which stores all their personal chat history and user
+account information - much as a mail client connects through to an IMAP/SMTP
+server. Just like email, you can either run your own Matrix homeserver and
+control and own your own communications and history or use one hosted by
+someone else (e.g. matrix.org) - there is no single point of control or
+mandatory service provider in Matrix, unlike WhatsApp, Facebook, Hangouts, etc.
 Synapse ships with two basic demo Matrix clients: webclient (a basic group chat
 web client demo implemented in AngularJS) and cmdclient (a basic Python
@@ -95,7 +94,7 @@ Synapse is the reference python/twisted Matrix homeserver implementation.
 System requirements:
 - POSIX-compliant system (tested on Linux & OS X)
 - Python 2.7
-- At least 1GB of free RAM if you want to join large public rooms like #matrix:matrix.org
+- At least 512 MB RAM.
 Synapse is written in python but some of the libraries is uses are written in
 C. So before we can install synapse itself we need a working C compiler and the
@@ -105,7 +104,7 @@ Installing prerequisites on Ubuntu or Debian::
 sudo apt-get install build-essential python2.7-dev libffi-dev \
 python-pip python-setuptools sqlite3 \
-libssl-dev python-virtualenv libjpeg-dev libxslt1-dev
+libssl-dev python-virtualenv libjpeg-dev
 Installing prerequisites on ArchLinux::
@@ -134,12 +133,6 @@ Installing prerequisites on Raspbian::
 sudo pip install --upgrade ndg-httpsclient
 sudo pip install --upgrade virtualenv
-Installing prerequisites on openSUSE::
-sudo zypper in -t pattern devel_basis
-sudo zypper in python-pip python-setuptools sqlite3 python-virtualenv \
-python-devel libffi-devel libopenssl-devel libjpeg62-devel
 To install the synapse homeserver run::
 virtualenv -p python2.7 ~/.synapse
@@ -205,21 +198,6 @@ run (e.g. ``~/.synapse``), and::
 source ./bin/activate
 synctl start
-Security Note
-=============
-Matrix serves raw user generated data in some APIs - specifically the content
-repository endpoints: http://matrix.org/docs/spec/client_server/r0.2.0.html#get-matrix-media-r0-download-servername-mediaid
-Whilst we have tried to mitigate against possible XSS attacks (e.g.
-https://github.com/matrix-org/synapse/pull/1021) we recommend running
-matrix homeservers on a dedicated domain name, to limit any malicious user generated
-content served to web browsers a matrix API from being able to attack webapps hosted
-on the same domain. This is particularly true of sharing a matrix webclient and
-server on the same domain.
-See https://github.com/vector-im/vector-web/issues/1977 and
-https://developer.github.com/changes/2014-04-25-user-content-security for more details.
 Using PostgreSQL
 ================
@@ -236,6 +214,9 @@ The advantages of Postgres include:
 pointing at the same DB master, as well as enabling DB replication in
 synapse itself.
+The only disadvantage is that the code is relatively new as of April 2015 and
+may have a few regressions relative to SQLite.
 For information on how to install and use PostgreSQL, please see
 `docs/postgres.rst <docs/postgres.rst>`_.
@@ -463,7 +444,7 @@ You have two choices here, which will influence the form of your Matrix user
 IDs:
 1) Use the machine's own hostname as available on public DNS in the form of
-its A records. This is easier to set up initially, perhaps for
+its A or AAAA records. This is easier to set up initially, perhaps for
 testing, but lacks the flexibility of SRV.
 2) Set up a SRV record for your domain name. This requires you create a SRV
@@ -576,23 +557,6 @@ as the primary means of identity and E2E encryption is not complete. As such,
 we are running a single identity server (https://matrix.org) at the current
 time.
-URL Previews
-============
-Synapse 0.15.0 introduces an experimental new API for previewing URLs at
-/_matrix/media/r0/preview_url. This is disabled by default. To turn it on
-you must enable the `url_preview_enabled: True` config parameter and explicitly
-specify the IP ranges that Synapse is not allowed to spider for previewing in
-the `url_preview_ip_range_blacklist` configuration parameter. This is critical
-from a security perspective to stop arbitrary Matrix users spidering 'internal'
-URLs on your network. At the very least we recommend that your loopback and
-RFC1918 IP addresses are blacklisted.
-This also requires the optional lxml and netaddr python dependencies to be
-installed.
 Password reset
 ==============
@@ -636,7 +600,7 @@ Building internal API documentation::
-Help!! Synapse eats all my RAM!
+Halp!! Synapse eats all my RAM!
 ===============================
 Synapse's architecture is quite RAM hungry currently - we deliberately

View File

@@ -27,15 +27,7 @@ running:
 # Pull the latest version of the master branch.
 git pull
 # Update the versions of synapse's python dependencies.
-python synapse/python_dependencies.py | xargs -n1 pip install --upgrade
+python synapse/python_dependencies.py | xargs -n1 pip install
-Upgrading to v0.15.0
-====================
-If you want to use the new URL previewing API (/_matrix/media/r0/preview_url)
-then you have to explicitly enable it in the config and update your
-dependencies. See README.rst for details.
 Upgrading to v0.11.0

View File

@@ -9,7 +9,6 @@ Description=Synapse Matrix homeserver
 Type=simple
 User=synapse
 Group=synapse
-EnvironmentFile=-/etc/sysconfig/synapse
 WorkingDirectory=/var/lib/synapse
 ExecStart=/usr/bin/python2.7 -m synapse.app.homeserver --config-path=/etc/synapse/homeserver.yaml --log-config=/etc/synapse/log_config.yaml

View File

@@ -1,12 +0,0 @@
Admin APIs
==========
This directory includes documentation for the various synapse specific admin
APIs available.
Only users that are server admins can use these APIs. A user can be marked as a
server admin by updating the database directly, e.g.:
``UPDATE users SET admin = 1 WHERE name = '@foo:bar.com'``
Restarting may be required for the changes to register.

View File

@@ -1,15 +0,0 @@
Purge History API
=================
The purge history API allows server admins to purge historic events from their
database, reclaiming disk space.
Depending on the amount of history being purged a call to the API may take
several minutes or longer. During this period users will not be able to
paginate further back in the room from the point being purged from.
The API is simply:
``POST /_matrix/client/r0/admin/purge_history/<room_id>/<event_id>``
including an ``access_token`` of a server admin.
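
For illustration, the call could be scripted along these lines (a minimal
sketch using the third-party ``requests`` library; the homeserver URL, room ID,
event ID and token below are placeholders, and real room/event IDs should be
URL-escaped)::

    import requests

    HOMESERVER = "https://localhost:8448"         # placeholder
    ADMIN_TOKEN = "MDAx..."                       # access_token of a server admin
    ROOM_ID = "!someroom:example.com"             # placeholder
    EVENT_ID = "$1471814573222UBGwg:example.com"  # purge events before this one

    # Admin APIs authenticate with an ordinary access_token query parameter.
    url = "%s/_matrix/client/r0/admin/purge_history/%s/%s" % (
        HOMESERVER, ROOM_ID, EVENT_ID,
    )
    resp = requests.post(url, params={"access_token": ADMIN_TOKEN})
    resp.raise_for_status()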

View File

@@ -1,19 +0,0 @@
Purge Remote Media API
======================
The purge remote media API allows server admins to purge old cached remote
media.
The API is::
POST /_matrix/client/r0/admin/purge_media_cache
{
"before_ts": <unix_timestamp_in_ms>
}
Which will remove all cached media that was last accessed before
``<unix_timestamp_in_ms>``.
If the user re-requests purged remote media, synapse will re-request the media
from the originating server.
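
For example, to purge everything not accessed in the last 30 days (a
hypothetical sketch, again using ``requests``; URL and token are placeholders)::

    import time
    import requests

    HOMESERVER = "https://localhost:8448"  # placeholder
    ADMIN_TOKEN = "MDAx..."                # access_token of a server admin

    # Purge remote media last accessed more than 30 days ago.
    before_ts = int(time.time() * 1000) - 30 * 24 * 60 * 60 * 1000

    resp = requests.post(
        HOMESERVER + "/_matrix/client/r0/admin/purge_media_cache",
        params={"access_token": ADMIN_TOKEN},
        json={"before_ts": before_ts},
    )
    resp.raise_for_status()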

View File

@@ -32,4 +32,5 @@ The format of the AS configuration file is as follows:
 See the spec_ for further details on how application services work.
-.. _spec: https://matrix.org/docs/spec/application_service/unstable.html
+.. _spec: https://github.com/matrix-org/matrix-doc/blob/master/specification/25_application_service_api.rst#application-service-api

View File

@@ -43,10 +43,7 @@ Basically, PEP8
 together, or want to deliberately extend or preserve vertical/horizontal
 space)
-Comments should follow the `google code style <http://google.github.io/styleguide/pyguide.html?showone=Comments#Comments>`_.
-This is so that we can generate documentation with
-`sphinx <http://sphinxcontrib-napoleon.readthedocs.org/en/latest/>`_. See the
-`examples <http://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html>`_
-in the sphinx documentation.
+Comments should follow the google code style. This is so that we can generate
+documentation with sphinx (http://sphinxcontrib-napoleon.readthedocs.org/en/latest/)
 Code should pass pep8 --max-line-length=100 without any warnings.

View File

@@ -1,10 +0,0 @@
What do I do about "Unexpected logging context" debug log-lines everywhere?
<Mjark> The logging context lives in thread local storage
<Mjark> Sometimes it gets out of sync with what it should actually be, usually because something scheduled something to run on the reactor without preserving the logging context.
<Matthew> what is the impact of it getting out of sync? and how and when should we preserve log context?
<Mjark> The impact is that some of the CPU and database metrics will be under-reported, and some log lines will be mis-attributed.
<Mjark> It should happen auto-magically in all the APIs that do IO or otherwise defer to the reactor.
<Erik> Mjark: the other place is if we branch, e.g. using defer.gatherResults
Unanswered: how and when should we preserve log context?
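
A toy illustration of the failure mode described above (generic Twisted code,
not Synapse's actual logcontext machinery; names are invented for the example)::

    import threading
    from twisted.internet import reactor

    _context = threading.local()  # stand-in for the thread-local logging context

    def handle_request():
        _context.request = "req-1"        # context is set while handling a request
        reactor.callLater(0, log_result)  # schedule work WITHOUT preserving context
        _context.request = None           # handler finishes, context is reset

    def log_result():
        # Runs later on the reactor with whatever context happens to be set
        # (here None), so this "log line" is mis-attributed.
        print("request context is now:", getattr(_context, "request", None))

    reactor.callWhenRunning(handle_request)
    reactor.callLater(0.1, reactor.stop)
    reactor.run()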

View File

@@ -1,58 +0,0 @@
Replication Architecture
========================
Motivation
----------
We'd like to be able to split some of the work that synapse does into multiple
python processes. In theory multiple synapse processes could share a single
postgresql database and we'd scale up by running more synapse processes.
However much of synapse assumes that only one process is interacting with the
database, both for assigning unique identifiers when inserting into tables,
notifying components about new updates, and for invalidating its caches.
So running multiple copies of the current code isn't an option. One way to
run multiple processes would be to have a single writer process and multiple
reader processes connected to the same database. In order to do this we'd need
a way for the reader process to invalidate its in-memory caches when an update
happens on the writer. One way to do this is for the writer to present an
append-only log of updates which the readers can consume to invalidate their
caches and to push updates to listening clients or pushers.
Synapse already stores much of its data as an append-only log so that it can
correctly respond to /sync requests so the amount of code changes needed to
expose the append-only log to the readers should be fairly minimal.
Architecture
------------
The Replication API
~~~~~~~~~~~~~~~~~~~
Synapse will optionally expose a long poll HTTP API for extracting updates. The
API will have a similar shape to /sync in that clients provide tokens
indicating where in the log they have reached and a timeout. The synapse server
then either responds with updates immediately if it already has updates or it
waits until the timeout for more updates. If the timeout expires and nothing
happened then the server returns an empty response.
However unlike the /sync API this replication API is returning synapse specific
data rather than trying to implement a matrix specification. The replication
results are returned as arrays of rows where the rows are mostly lifted
directly from the database. This avoids unnecessary JSON parsing on the server
and hopefully avoids an impedance mismatch between the data returned and the
required updates to the datastore.
This does not replicate all the database tables as many of the database tables
are indexes that can be recovered from the contents of other tables.
The format and parameters for the api are documented in
``synapse/replication/resource.py``.
The Slaved DataStore
~~~~~~~~~~~~~~~~~~~~
There are read-only versions of the synapse storage layer in
``synapse/replication/slave/storage`` that use the response of the replication
API to invalidate their caches.
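
To make the shape of the API concrete, a reader's poll loop might look
something like the following sketch (hypothetical only: the real stream, field
and parameter names are those documented in ``synapse/replication/resource.py``)::

    import requests

    REPLICATION_URL = "http://127.0.0.1:9092/_synapse/replication"  # assumed listener

    def replicate_forever():
        positions = {}  # per-stream tokens indicating where in the log we are
        while True:
            # Ask for anything past our current positions, waiting up to 30s.
            params = dict(positions)
            params["timeout"] = 30000
            resp = requests.get(REPLICATION_URL, params=params)
            resp.raise_for_status()
            # An empty response just means the timeout expired.
            for stream_name, stream in resp.json().items():
                handle_rows(stream_name, stream.get("rows", []))
                positions[stream_name] = stream["position"]

    def handle_rows(stream_name, rows):
        # Rows are lifted more or less directly from the database; use them
        # to invalidate caches and notify listening clients or pushers.
        for row in rows:
            print(stream_name, row)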

View File

@@ -9,35 +9,31 @@ the Home Server to generate credentials that are valid for use on the TURN
 server through the use of a secret shared between the Home Server and the
 TURN server.
-This document describes how to install coturn
-(https://github.com/coturn/coturn) which also supports the TURN REST API,
+This document described how to install coturn
+(https://code.google.com/p/coturn/) which also supports the TURN REST API,
 and integrate it with synapse.
 coturn Setup
 ============
-You may be able to setup coturn via your package manager, or set it up manually using the usual ``configure, make, make install`` process.
 1. Check out coturn::
-git clone https://github.com/coturn/coturn.git coturn
+svn checkout http://coturn.googlecode.com/svn/trunk/ coturn
 cd coturn
 2. Configure it::
 ./configure
-You may need to install ``libevent2``: if so, you should do so
+You may need to install libevent2: if so, you should do so
 in the way recommended by your operating system.
 You can ignore warnings about lack of database support: a
 database is unnecessary for this purpose.
 3. Build and install it::
 make
 make install
-4. Create or edit the config file in ``/etc/turnserver.conf``. The relevant
+4. Make a config file in /etc/turnserver.conf. You can customise
+a config file from turnserver.conf.default. The relevant
 lines, with example values, are::
 lt-cred-mech
@@ -45,7 +41,7 @@ You may be able to setup coturn via your package manager, or set it up manually
 static-auth-secret=[your secret key here]
 realm=turn.myserver.org
-See turnserver.conf for explanations of the options.
+See turnserver.conf.default for explanations of the options.
 One way to generate the static-auth-secret is with pwgen::
 pwgen -s 64 1
@@ -58,7 +54,6 @@ You may be able to setup coturn via your package manager, or set it up manually
 import your private key and certificate.
 7. Start the turn server::
 bin/turnserver -o

View File

@@ -1,74 +0,0 @@
URL Previews
============
Design notes on a URL previewing service for Matrix:
Options are:
1. Have an AS which listens for URLs, downloads them, and inserts an event that describes their metadata.
* Pros:
* Decouples the implementation entirely from Synapse.
* Uses existing Matrix events & content repo to store the metadata.
* Cons:
* Which AS should provide this service for a room, and why should you trust it?
* Doesn't work well with E2E; you'd have to cut the AS into every room
* the AS would end up subscribing to every room anyway.
2. Have a generic preview API (nothing to do with Matrix) that provides a previewing service:
* Pros:
* Simple and flexible; can be used by any clients at any point
* Cons:
* If each HS provides one of these independently, all the HSes in a room may needlessly DoS the target URI
* We need somewhere to store the URL metadata rather than just using Matrix itself
* We can't piggyback on matrix to distribute the metadata between HSes.
3. Make the synapse of the sending user responsible for spidering the URL and inserting an event asynchronously which describes the metadata.
* Pros:
* Works transparently for all clients
* Piggy-backs nicely on using Matrix for distributing the metadata.
* No confusion as to which AS
* Cons:
* Doesn't work with E2E
* We might want to decouple the implementation of the spider from the HS, given spider behaviour can be quite complicated and evolve much more rapidly than the HS. It's more like a bot than a core part of the server.
4. Make the sending client use the preview API and insert the event itself when successful.
* Pros:
* Works well with E2E
* No custom server functionality
* Lets the client customise the preview that they send (like on FB)
* Cons:
* Entirely specific to the sending client, whereas it'd be nice if /any/ URL was correctly previewed if clients support it.
5. Have the option of specifying a shared (centralised) previewing service used by a room, to avoid all the different HSes in the room DoSing the target.
Best solution is probably a combination of both 2 and 4.
* Sending clients do their best to create and send a preview at the point of sending the message, perhaps delaying the message until the preview is computed? (This also lets the user validate the preview before sending)
* Receiving clients have the option of going and creating their own preview if one doesn't arrive soon enough (or if the original sender didn't create one)
This is a bit magical though in that the preview could come from two entirely different sources - the sending HS or your local one. However, this can always be exposed to users: "Generate your own URL previews if none are available?"
This is tantamount also to senders calculating their own thumbnails for sending in advance of the main content - we are trusting the sender not to lie about the content in the thumbnail. Whereas currently thumbnails are calculated by the receiving homeserver to avoid this attack.
However, this kind of phishing attack does exist whether we let senders pick their thumbnails or not, in that a malicious sender can send normal text messages around the attachment claiming it to be legitimate. We could rely on (future) reputation/abuse management to punish users who phish (be it with bogus metadata or bogus descriptions). Bogus metadata is particularly bad though, especially if it's avoidable.
As a first cut, let's do #2 and have the receiver hit the API to calculate its own previews (as it does currently for image thumbnails). We can then extend/optimise this to option 4 as a special extra if needed.
API
---
GET /_matrix/media/r0/preview_url?url=http://wherever.com
200 OK
{
"og:type" : "article"
"og:url" : "https://twitter.com/matrixdotorg/status/684074366691356672"
"og:title" : "Matrix on Twitter"
"og:image" : "https://pbs.twimg.com/profile_images/500400952029888512/yI0qtFi7_400x400.png"
"og:description" : "“Synapse 0.12 is out! Lots of polishing, performance &amp;amp; bugfixes: /sync API, /r0 prefix, fulltext search, 3PID invites https://t.co/5alhXLLEGP”"
"og:site_name" : "Twitter"
}
* Downloads the URL
* If HTML, just stores it in RAM and parses it for OG meta tags
* Download any media OG meta tags to the media repo, and refer to them in the OG via mxc:// URIs.
* If a media filetype we know we can thumbnail: store it on disk, and hand it to the thumbnailer. Generate OG meta tags from the thumbnailer contents.
* Otherwise, don't bother downloading further.
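
For reference, the "parse it for OG meta tags" step can be sketched in a few
lines of Python using the optional lxml dependency (error handling, response
size limits, and the IP blacklist checks that a real implementation needs are
omitted)::

    import requests
    from lxml import html

    def fetch_og_tags(url):
        # Download the URL. A real spider must cap the download size and
        # refuse to fetch blacklisted (e.g. internal) addresses.
        resp = requests.get(url, timeout=10)
        tree = html.fromstring(resp.content)
        og = {}
        # OpenGraph metadata lives in <meta property="og:..." content="..."> tags.
        for tag in tree.xpath('//meta[starts-with(@property, "og:")]'):
            og[tag.get("property")] = tag.get("content")
        # Fall back to the <title> element if no og:title was declared.
        if "og:title" not in og and tree.findtext(".//title"):
            og["og:title"] = tree.findtext(".//title").strip()
        return og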

View File

@@ -1,98 +0,0 @@
Scaling synapse via workers
---------------------------
Synapse has experimental support for splitting out functionality into
multiple separate python processes, helping greatly with scalability. These
processes are called 'workers', and are (eventually) intended to scale
horizontally independently.
All processes continue to share the same database instance, and as such, workers
only work with postgres based synapse deployments (sharing a single sqlite
across multiple processes is a recipe for disaster, plus you should be using
postgres anyway if you care about scalability).
The workers communicate with the master synapse process via a synapse-specific
HTTP protocol called 'replication' - analogous to MySQL or Postgres style
database replication; feeding a stream of relevant data to the workers so they
can be kept in sync with the main synapse process and database state.
To enable workers, you need to add a replication listener to the master synapse, e.g.::
listeners:
- port: 9092
bind_address: '127.0.0.1'
type: http
tls: false
x_forwarded: false
resources:
- names: [replication]
compress: false
Under **no circumstances** should this replication API listener be exposed to the
public internet; it currently implements no authentication whatsoever and is
unencrypted HTTP.
You then create a set of configs for the various worker processes. These worker
configuration files should be stored in a dedicated subdirectory, to allow
synctl to manipulate them.
The current available worker applications are:
* synapse.app.pusher - handles sending push notifications to sygnal and email
* synapse.app.synchrotron - handles /sync endpoints; can scale horizontally through multiple instances.
* synapse.app.appservice - handles output traffic to Application Services
* synapse.app.federation_reader - handles receiving federation traffic (including public_rooms API)
* synapse.app.media_repository - handles the media repository.
* synapse.app.client_reader - handles client API endpoints like /publicRooms
Each worker configuration file inherits the configuration of the main homeserver
configuration file. You can then override configuration specific to that worker,
e.g. the HTTP listener that it provides (if any); logging configuration; etc.
You should minimise the number of overrides though to maintain a usable config.
You must specify the type of worker application (worker_app) and the replication
endpoint that it's talking to on the main synapse process (worker_replication_url).
For instance::
worker_app: synapse.app.synchrotron
# The replication listener on the synapse to talk to.
worker_replication_url: http://127.0.0.1:9092/_synapse/replication
worker_listeners:
- type: http
port: 8083
resources:
- names:
- client
worker_daemonize: True
worker_pid_file: /home/matrix/synapse/synchrotron.pid
worker_log_config: /home/matrix/synapse/config/synchrotron_log_config.yaml
...is a full configuration for a synchrotron worker instance, which will expose a
plain HTTP /sync endpoint on port 8083 separately from the /sync endpoint provided
by the main synapse.
Obviously you should configure your loadbalancer to route the /sync endpoint to
the synchrotron instance(s) in this instance.
Finally, to actually run your worker-based synapse, you must pass synctl the -a
commandline option to tell it to operate on all the worker configurations found
in the given directory, e.g.::
synctl -a $CONFIG/workers start
Currently one should always restart all workers when restarting or upgrading
synapse, unless you explicitly know it's safe not to. For instance, restarting
synapse without restarting all the synchrotrons may result in broken typing
notifications.
To manipulate a specific worker, you pass the -w option to synctl::
synctl -w $CONFIG/workers/synchrotron.yaml restart
All of the above is highly experimental and subject to change as Synapse evolves,
but documenting it here to help folks needing highly scalable Synapses similar
to the one running matrix.org!

View File

@@ -1,24 +0,0 @@
#!/bin/bash
set -eux
: ${WORKSPACE:="$(pwd)"}
export WORKSPACE
export PYTHONDONTWRITEBYTECODE=yep
export SYNAPSE_CACHE_FACTOR=1
./jenkins/prepare_synapse.sh
./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git
./jenkins/clone.sh dendron https://github.com/matrix-org/dendron.git
./dendron/jenkins/build_dendron.sh
./sytest/jenkins/prep_sytest_for_postgres.sh
./sytest/jenkins/install_and_run.sh \
--synapse-directory $WORKSPACE \
--dendron $WORKSPACE/dendron/bin/dendron \
--pusher \
--synchrotron \
--federation-reader \
--client-reader \
--appservice \

View File

@@ -4,14 +4,58 @@ set -eux
 : ${WORKSPACE:="$(pwd)"}
-export WORKSPACE
 export PYTHONDONTWRITEBYTECODE=yep
 export SYNAPSE_CACHE_FACTOR=1
-./jenkins/prepare_synapse.sh
-./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git
-./sytest/jenkins/prep_sytest_for_postgres.sh
-./sytest/jenkins/install_and_run.sh \
---synapse-directory $WORKSPACE \
+# Output test results as junit xml
+export TRIAL_FLAGS="--reporter=subunit"
+export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
+# Write coverage reports to a separate file for each process
+export COVERAGE_OPTS="-p"
+export DUMP_COVERAGE_COMMAND="coverage help"
+# Output flake8 violations to violations.flake8.log
+# Don't exit with non-0 status code on Jenkins,
+# so that the build steps continue and a later step can decided whether to
+# UNSTABLE or FAILURE this build.
+export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
+rm .coverage* || echo "No coverage files to remove"
+tox --notest -e py27
+TOX_BIN=$WORKSPACE/.tox/py27/bin
+$TOX_BIN/pip install psycopg2
+: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
+if [[ ! -e .sytest-base ]]; then
+git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
+else
+(cd .sytest-base; git fetch -p)
+fi
+rm -rf sytest
+git clone .sytest-base sytest --shared
+cd sytest
+git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
+: ${PORT_BASE:=8000}
+./jenkins/prep_sytest_for_postgres.sh
+echo >&2 "Running sytest with PostgreSQL";
+./jenkins/install_and_run.sh --coverage \
+--python $TOX_BIN/python \
+--synapse-directory $WORKSPACE \
+--port-base $PORT_BASE
+cd ..
+cp sytest/.coverage.* .
+# Combine the coverage reports
+echo "Combining:" .coverage.*
+$TOX_BIN/python -m coverage combine
+# Output coverage to coverage.xml
+$TOX_BIN/coverage xml -o coverage.xml

View File

@@ -4,12 +4,52 @@ set -eux
 : ${WORKSPACE:="$(pwd)"}
-export WORKSPACE
 export PYTHONDONTWRITEBYTECODE=yep
 export SYNAPSE_CACHE_FACTOR=1
-./jenkins/prepare_synapse.sh
-./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git
-./sytest/jenkins/install_and_run.sh \
---synapse-directory $WORKSPACE \
+# Output test results as junit xml
+export TRIAL_FLAGS="--reporter=subunit"
+export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
+# Write coverage reports to a separate file for each process
+export COVERAGE_OPTS="-p"
+export DUMP_COVERAGE_COMMAND="coverage help"
+# Output flake8 violations to violations.flake8.log
+# Don't exit with non-0 status code on Jenkins,
+# so that the build steps continue and a later step can decided whether to
+# UNSTABLE or FAILURE this build.
+export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
+rm .coverage* || echo "No coverage files to remove"
+tox --notest -e py27
+TOX_BIN=$WORKSPACE/.tox/py27/bin
+: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
+if [[ ! -e .sytest-base ]]; then
+git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
+else
+(cd .sytest-base; git fetch -p)
+fi
+rm -rf sytest
+git clone .sytest-base sytest --shared
+cd sytest
+git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
+: ${PORT_BASE:=8500}
+./jenkins/install_and_run.sh --coverage \
+--python $TOX_BIN/python \
+--synapse-directory $WORKSPACE \
+--port-base $PORT_BASE
+cd ..
+cp sytest/.coverage.* .
+# Combine the coverage reports
+echo "Combining:" .coverage.*
+$TOX_BIN/python -m coverage combine
+# Output coverage to coverage.xml
+$TOX_BIN/coverage xml -o coverage.xml

View File

@@ -22,9 +22,4 @@ export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished w
 rm .coverage* || echo "No coverage files to remove"
-tox --notest -e py27
-TOX_BIN=$WORKSPACE/.tox/py27/bin
-python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
-$TOX_BIN/pip install lxml
 tox -e py27

jenkins.sh (new executable file)
View File

@@ -0,0 +1,86 @@
#!/bin/bash
set -eux
: ${WORKSPACE:="$(pwd)"}
export PYTHONDONTWRITEBYTECODE=yep
export SYNAPSE_CACHE_FACTOR=1
# Output test results as junit xml
export TRIAL_FLAGS="--reporter=subunit"
export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
# Write coverage reports to a separate file for each process
export COVERAGE_OPTS="-p"
export DUMP_COVERAGE_COMMAND="coverage help"
# Output flake8 violations to violations.flake8.log
# Don't exit with non-0 status code on Jenkins,
# so that the build steps continue and a later step can decided whether to
# UNSTABLE or FAILURE this build.
export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
rm .coverage* || echo "No coverage files to remove"
tox
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
TOX_BIN=$WORKSPACE/.tox/py27/bin
if [[ ! -e .sytest-base ]]; then
git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
else
(cd .sytest-base; git fetch -p)
fi
rm -rf sytest
git clone .sytest-base sytest --shared
cd sytest
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
: ${PERL5LIB:=$WORKSPACE/perl5/lib/perl5}
: ${PERL_MB_OPT:=--install_base=$WORKSPACE/perl5}
: ${PERL_MM_OPT:=INSTALL_BASE=$WORKSPACE/perl5}
export PERL5LIB PERL_MB_OPT PERL_MM_OPT
./install-deps.pl
: ${PORT_BASE:=8000}
echo >&2 "Running sytest with SQLite3";
./run-tests.pl --coverage -O tap --synapse-directory $WORKSPACE \
--python $TOX_BIN/python --all --port-base $PORT_BASE > results-sqlite3.tap
RUN_POSTGRES=""
for port in $(($PORT_BASE + 1)) $(($PORT_BASE + 2)); do
if psql synapse_jenkins_$port <<< ""; then
RUN_POSTGRES="$RUN_POSTGRES:$port"
cat > localhost-$port/database.yaml << EOF
name: psycopg2
args:
database: synapse_jenkins_$port
EOF
fi
done
# Run if both postgresql databases exist
if test "$RUN_POSTGRES" = ":$(($PORT_BASE + 1)):$(($PORT_BASE + 2))"; then
echo >&2 "Running sytest with PostgreSQL";
$TOX_BIN/pip install psycopg2
./run-tests.pl --coverage -O tap --synapse-directory $WORKSPACE \
--python $TOX_BIN/python --all --port-base $PORT_BASE > results-postgresql.tap
else
echo >&2 "Skipping running sytest with PostgreSQL, $RUN_POSTGRES"
fi
cd ..
cp sytest/.coverage.* .
# Combine the coverage reports
echo "Combining:" .coverage.*
$TOX_BIN/python -m coverage combine
# Output coverage to coverage.xml
$TOX_BIN/coverage xml -o coverage.xml

View File

@@ -1,44 +0,0 @@
#! /bin/bash
# This clones a project from github into a named subdirectory
# If the project has a branch with the same name as this branch
# then it will checkout that branch after cloning.
# Otherwise it will checkout "origin/develop."
# The first argument is the name of the directory to checkout
# the branch into.
# The second argument is the URL of the remote repository to checkout.
# Usually something like https://github.com/matrix-org/sytest.git
set -eux
NAME=$1
PROJECT=$2
BASE=".$NAME-base"
# Update our mirror.
if [ ! -d ".$NAME-base" ]; then
# Create a local mirror of the source repository.
# This saves us from having to download the entire repository
# when this script is next run.
git clone "$PROJECT" "$BASE" --mirror
else
# Fetch any updates from the source repository.
(cd "$BASE"; git fetch -p)
fi
# Remove the existing repository so that we have a clean copy
rm -rf "$NAME"
# Cloning with --shared means that we will share portions of the
# .git directory with our local mirror.
git clone "$BASE" "$NAME" --shared
# Jenkins may have supplied us with the name of the branch in the
# environment. Otherwise we will have to guess based on the current
# commit.
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
cd "$NAME"
# check out the relevant branch
git checkout "${GIT_BRANCH}" || (
echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop"
git checkout "origin/develop"
)

View File

@@ -1,20 +0,0 @@
#! /bin/bash
cd "`dirname $0`/.."
TOX_DIR=$WORKSPACE/.tox
mkdir -p $TOX_DIR
if ! [ $TOX_DIR -ef .tox ]; then
ln -s "$TOX_DIR" .tox
fi
# set up the virtualenv
tox -e py27 --notest -v
TOX_BIN=$TOX_DIR/py27/bin
$TOX_BIN/pip install setuptools
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
$TOX_BIN/pip install lxml
$TOX_BIN/pip install psycopg2

View File

@@ -1,7 +0,0 @@
.header {
border-bottom: 4px solid #e4f7ed ! important;
}
.notif_link a, .footer a {
color: #76CFA6 ! important;
}

View File

@@ -1,156 +0,0 @@
body {
margin: 0px;
}
pre, code {
word-break: break-word;
white-space: pre-wrap;
}
#page {
font-family: 'Open Sans', Helvetica, Arial, Sans-Serif;
font-color: #454545;
font-size: 12pt;
width: 100%;
padding: 20px;
}
#inner {
width: 640px;
}
.header {
width: 100%;
height: 87px;
color: #454545;
border-bottom: 4px solid #e5e5e5;
}
.logo {
text-align: right;
margin-left: 20px;
}
.salutation {
padding-top: 10px;
font-weight: bold;
}
.summarytext {
}
.room {
width: 100%;
color: #454545;
border-bottom: 1px solid #e5e5e5;
}
.room_header td {
padding-top: 38px;
padding-bottom: 10px;
border-bottom: 1px solid #e5e5e5;
}
.room_name {
vertical-align: middle;
font-size: 18px;
font-weight: bold;
}
.room_header h2 {
margin-top: 0px;
margin-left: 75px;
font-size: 20px;
}
.room_avatar {
width: 56px;
line-height: 0px;
text-align: center;
vertical-align: middle;
}
.room_avatar img {
width: 48px;
height: 48px;
object-fit: cover;
border-radius: 24px;
}
.notif {
border-bottom: 1px solid #e5e5e5;
margin-top: 16px;
padding-bottom: 16px;
}
.historical_message .sender_avatar {
opacity: 0.3;
}
/* spell out opacity and historical_message class names for Outlook aka Word */
.historical_message .sender_name {
color: #e3e3e3;
}
.historical_message .message_time {
color: #e3e3e3;
}
.historical_message .message_body {
color: #c7c7c7;
}
.historical_message td,
.message td {
padding-top: 10px;
}
.sender_avatar {
width: 56px;
text-align: center;
vertical-align: top;
}
.sender_avatar img {
margin-top: -2px;
width: 32px;
height: 32px;
border-radius: 16px;
}
.sender_name {
display: inline;
font-size: 13px;
color: #a2a2a2;
}
.message_time {
text-align: right;
width: 100px;
font-size: 11px;
color: #a2a2a2;
}
.message_body {
}
.notif_link td {
padding-top: 10px;
padding-bottom: 10px;
font-weight: bold;
}
.notif_link a, .footer a {
color: #454545;
text-decoration: none;
}
.debug {
font-size: 10px;
color: #888;
}
.footer {
margin-top: 20px;
text-align: center;
}

View File

@@ -1,45 +0,0 @@
{% for message in notif.messages %}
<tr class="{{ "historical_message" if message.is_historical else "message" }}">
<td class="sender_avatar">
{% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
{% if message.sender_avatar_url %}
<img alt="" class="sender_avatar" src="{{ message.sender_avatar_url|mxc_to_http(32,32) }}" />
{% else %}
{% if message.sender_hash % 3 == 0 %}
<img class="sender_avatar" src="https://vector.im/beta/img/76cfa6.png" />
{% elif message.sender_hash % 3 == 1 %}
<img class="sender_avatar" src="https://vector.im/beta/img/50e2c2.png" />
{% else %}
<img class="sender_avatar" src="https://vector.im/beta/img/f4c371.png" />
{% endif %}
{% endif %}
{% endif %}
</td>
<td class="message_contents">
{% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
<div class="sender_name">{% if message.msgtype == "m.emote" %}*{% endif %} {{ message.sender_name }}</div>
{% endif %}
<div class="message_body">
{% if message.msgtype == "m.text" %}
{{ message.body_text_html }}
{% elif message.msgtype == "m.emote" %}
{{ message.body_text_html }}
{% elif message.msgtype == "m.notice" %}
{{ message.body_text_html }}
{% elif message.msgtype == "m.image" %}
<img src="{{ message.image_url|mxc_to_http(640, 480, scale) }}" />
{% elif message.msgtype == "m.file" %}
<span class="filename">{{ message.body_text_plain }}</span>
{% endif %}
</div>
</td>
<td class="message_time">{{ message.ts|format_ts("%H:%M") }}</td>
</tr>
{% endfor %}
<tr class="notif_link">
<td></td>
<td>
<a href="{{ notif.link }}">View {{ room.title }}</a>
</td>
<td></td>
</tr>

View File

@@ -1,16 +0,0 @@
{% for message in notif.messages %}
{% if message.msgtype == "m.emote" %}* {% endif %}{{ message.sender_name }} ({{ message.ts|format_ts("%H:%M") }})
{% if message.msgtype == "m.text" %}
{{ message.body_text_plain }}
{% elif message.msgtype == "m.emote" %}
{{ message.body_text_plain }}
{% elif message.msgtype == "m.notice" %}
{{ message.body_text_plain }}
{% elif message.msgtype == "m.image" %}
{{ message.body_text_plain }}
{% elif message.msgtype == "m.file" %}
{{ message.body_text_plain }}
{% endif %}
{% endfor %}
View {{ room.title }} at {{ notif.link }}


@@ -1,53 +0,0 @@
<!doctype html>
<html lang="en">
<head>
<style type="text/css">
{% include 'mail.css' without context %}
{% include "mail-%s.css" % app_name ignore missing without context %}
</style>
</head>
<body>
<table id="page">
<tr>
<td> </td>
<td id="inner">
<table class="header">
<tr>
<td>
<div class="salutation">Hi {{ user_display_name }},</div>
<div class="summarytext">{{ summary_text }}</div>
</td>
<td class="logo">
{% if app_name == "Vector" %}
<img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/>
{% else %}
<img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/>
{% endif %}
</td>
</tr>
</table>
{% for room in rooms %}
{% include 'room.html' with context %}
{% endfor %}
<div class="footer">
<a href="{{ unsubscribe_link }}">Unsubscribe</a>
<br/>
<br/>
<div class="debug">
Sending email at {{ reason.now|format_ts("%c") }} due to activity in room {{ reason.room_name }} because
an event was received at {{ reason.received_at|format_ts("%c") }}
which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} ({{ reason.delay_before_mail_ms }}) mins ago,
{% if reason.last_sent_ts %}
and the last time we sent a mail for this room was {{ reason.last_sent_ts|format_ts("%c") }},
which is more than {{ "%.1f"|format(reason.throttle_ms / (60*1000)) }} (current throttle_ms) mins ago.
{% else %}
and we don't have a last time we sent a mail for this room.
{% endif %}
</div>
</div>
</td>
<td> </td>
</tr>
</table>
</body>
</html>
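These templates are plain Jinja2, so they can be rendered outside synapse. A rough sketch, assuming the css/html/txt files above are saved under res/templates; the filter stubs and context values are illustrative, not synapse's actual plumbing:

    import time
    import jinja2

    env = jinja2.Environment(loader=jinja2.FileSystemLoader("res/templates"))
    # synapse registers its own format_ts/mxc_to_http filters; these stubs
    # just make the templates renderable for a smoke test.
    env.filters["format_ts"] = lambda ts, fmt: time.strftime(fmt, time.localtime(ts / 1000.0))
    env.filters["mxc_to_http"] = lambda mxc, w, h, method="crop": mxc

    html = env.get_template("notif_mail.html").render(
        app_name="Vector",
        user_display_name="Alice",
        summary_text="You have new messages.",
        rooms=[],  # each entry would feed the included room.html
        unsubscribe_link="https://example.com/unsubscribe",
        reason={
            "now": int(time.time() * 1000),
            "room_name": "#example:example.com",
            "received_at": int(time.time() * 1000),
            "delay_before_mail_ms": 600000,
            "last_sent_ts": None,
            "throttle_ms": 600000,
        },
    )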


@@ -1,10 +0,0 @@
Hi {{ user_display_name }},
{{ summary_text }}
{% for room in rooms %}
{% include 'room.txt' with context %}
{% endfor %}
You can disable these notifications at {{ unsubscribe_link }}


@@ -1,33 +0,0 @@
<table class="room">
<tr class="room_header">
<td class="room_avatar">
{% if room.avatar_url %}
<img alt="" src="{{ room.avatar_url|mxc_to_http(48,48) }}" />
{% else %}
{% if room.hash % 3 == 0 %}
<img alt="" src="https://vector.im/beta/img/76cfa6.png" />
{% elif room.hash % 3 == 1 %}
<img alt="" src="https://vector.im/beta/img/50e2c2.png" />
{% else %}
<img alt="" src="https://vector.im/beta/img/f4c371.png" />
{% endif %}
{% endif %}
</td>
<td class="room_name" colspan="2">
{{ room.title }}
</td>
</tr>
{% if room.invite %}
<tr>
<td></td>
<td>
<a href="{{ room.link }}">Join the conversation.</a>
</td>
<td></td>
</tr>
{% else %}
{% for notif in room.notifs %}
{% include 'notif.html' with context %}
{% endfor %}
{% endif %}
</table>


@@ -1,9 +0,0 @@
{{ room.title }}
{% if room.invite %}
You've been invited, join at {{ room.link }}
{% else %}
{% for notif in room.notifs %}
{% include 'notif.txt' with context %}
{% endfor %}
{% endif %}


@@ -116,19 +116,17 @@ def get_json(origin_name, origin_key, destination, path):
     authorization_headers = []

     for key, sig in signed_json["signatures"][origin_name].items():
-        header = "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
-            origin_name, key, sig,
-        )
-        authorization_headers.append(bytes(header))
-        sys.stderr.write(header)
-        sys.stderr.write("\n")
+        authorization_headers.append(bytes(
+            "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
+                origin_name, key, sig,
+            )
+        ))

     result = requests.get(
         lookup(destination, path),
         headers={"Authorization": authorization_headers[0]},
         verify=False,
     )
-    sys.stderr.write("Status Code: %d\n" % (result.status_code,))
     return result.json()

@@ -143,7 +141,6 @@ def main():
     )

     json.dump(result, sys.stdout)
-    print ""

 if __name__ == "__main__":
     main()
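For reference, the header built on both sides of this hunk is the federation "X-Matrix" Authorization header: one per signing key found in the signed request JSON. A standalone sketch of just that step, taken directly from the code above:

    def build_auth_headers(origin_name, signed_json):
        # signed_json["signatures"] maps server name -> {key_id: signature},
        # as produced by signedjson's sign_json().
        headers = []
        for key, sig in signed_json["signatures"][origin_name].items():
            headers.append(
                "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (origin_name, key, sig)
            )
        return headers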


@@ -1,16 +1,10 @@
 #!/usr/bin/env python

 import argparse
-import sys

 import bcrypt
 import getpass
-import yaml

 bcrypt_rounds=12
-password_pepper = ""

 def prompt_for_pass():
     password = getpass.getpass("Password: ")
@@ -34,22 +28,12 @@ if __name__ == "__main__":
         default=None,
         help="New password for user. Will prompt if omitted.",
     )
-    parser.add_argument(
-        "-c", "--config",
-        type=argparse.FileType('r'),
-        help="Path to server config file. Used to read in bcrypt_rounds and password_pepper.",
-    )

     args = parser.parse_args()

-    if "config" in args and args.config:
-        config = yaml.safe_load(args.config)
-        bcrypt_rounds = config.get("bcrypt_rounds", bcrypt_rounds)
-        password_config = config.get("password_config", {})
-        password_pepper = password_config.get("pepper", password_pepper)
-
     password = args.password
     if not password:
         password = prompt_for_pass()

-    print bcrypt.hashpw(password + password_pepper, bcrypt.gensalt(bcrypt_rounds))
+    print bcrypt.hashpw(password, bcrypt.gensalt(bcrypt_rounds))
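The removed side of this diff peppers the password before hashing: the configured password_pepper is appended to the user's password, so a hash only verifies on a server configured with the same pepper. A small sketch of both operations using the modern bcrypt API; the function names are illustrative:

    import bcrypt

    def hash_password(password, pepper="", rounds=12):
        # Appending the server-wide pepper before bcrypt, as the newer
        # hash_password script does when a config file is supplied.
        return bcrypt.hashpw((password + pepper).encode("utf8"), bcrypt.gensalt(rounds))

    def check_password(password, stored_hash, pepper=""):
        return bcrypt.checkpw((password + pepper).encode("utf8"), stored_hash)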


@@ -25,26 +25,18 @@ import urllib2
 import yaml

-def request_registration(user, password, server_location, shared_secret, admin=False):
+def request_registration(user, password, server_location, shared_secret):
     mac = hmac.new(
         key=shared_secret,
+        msg=user,
         digestmod=hashlib.sha1,
-    )
-
-    mac.update(user)
-    mac.update("\x00")
-    mac.update(password)
-    mac.update("\x00")
-    mac.update("admin" if admin else "notadmin")
-
-    mac = mac.hexdigest()
+    ).hexdigest()

     data = {
         "user": user,
         "password": password,
         "mac": mac,
         "type": "org.matrix.login.shared_secret",
-        "admin": admin,
     }

     server_location = server_location.rstrip("/")
@@ -76,7 +68,7 @@ def request_registration(user, password, server_location, shared_secret, admin=F
         sys.exit(1)

-def register_new_user(user, password, server_location, shared_secret, admin):
+def register_new_user(user, password, server_location, shared_secret):
     if not user:
         try:
             default_user = getpass.getuser()
@@ -107,14 +99,7 @@ def register_new_user(user, password, server_location, shared_secret, admin):
         print "Passwords do not match"
         sys.exit(1)

-    if not admin:
-        admin = raw_input("Make admin [no]: ")
-        if admin in ("y", "yes", "true"):
-            admin = True
-        else:
-            admin = False
-
-    request_registration(user, password, server_location, shared_secret, bool(admin))
+    request_registration(user, password, server_location, shared_secret)

 if __name__ == "__main__":
@@ -134,11 +119,6 @@ if __name__ == "__main__":
         default=None,
         help="New password for user. Will prompt if omitted.",
     )
-    parser.add_argument(
-        "-a", "--admin",
-        action="store_true",
-        help="Register new user as an admin. Will prompt if omitted.",
-    )

     group = parser.add_mutually_exclusive_group(required=True)
     group.add_argument(
@@ -171,4 +151,4 @@ if __name__ == "__main__":
     else:
         secret = args.shared_secret

-    register_new_user(args.user, args.password, args.server_url, secret, args.admin)
+    register_new_user(args.user, args.password, args.server_url, secret)
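The removed (newer) version authenticates the registration request with an HMAC-SHA1 over the user, password and admin flag, NUL-separated and keyed with the server's shared registration secret, where the older version MACs only the user name. A sketch of the newer construction (bytes in, hex digest out; the wrapper name is illustrative):

    import hashlib
    import hmac

    def registration_mac(shared_secret, user, password, admin=False):
        # NUL separators prevent ambiguity between the concatenated fields.
        mac = hmac.new(key=shared_secret, digestmod=hashlib.sha1)
        mac.update(user)
        mac.update(b"\x00")
        mac.update(password)
        mac.update(b"\x00")
        mac.update(b"admin" if admin else b"notadmin")
        return mac.hexdigest()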


@@ -34,12 +34,11 @@ logger = logging.getLogger("synapse_port_db")
 BOOLEAN_COLUMNS = {
-    "events": ["processed", "outlier", "contains_url"],
+    "events": ["processed", "outlier"],
     "rooms": ["is_public"],
     "event_edges": ["is_state"],
     "presence_list": ["accepted"],
     "presence_stream": ["currently_active"],
-    "public_room_list_stream": ["visibility"],
 }

@@ -72,14 +71,6 @@ APPEND_ONLY_TABLES = [
     "event_to_state_groups",
     "rejections",
     "event_search",
-    "presence_stream",
-    "push_rules_stream",
-    "current_state_resets",
-    "ex_outlier_stream",
-    "cache_invalidation_stream",
-    "public_room_list_stream",
-    "state_group_edges",
-    "stream_ordering_to_exterm",
 ]

@@ -101,12 +92,8 @@ class Store(object):
     _simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"]
     _simple_select_onecol = SQLBaseStore.__dict__["_simple_select_onecol"]
-    _simple_select_one = SQLBaseStore.__dict__["_simple_select_one"]
-    _simple_select_one_txn = SQLBaseStore.__dict__["_simple_select_one_txn"]
     _simple_select_one_onecol = SQLBaseStore.__dict__["_simple_select_one_onecol"]
-    _simple_select_one_onecol_txn = SQLBaseStore.__dict__[
-        "_simple_select_one_onecol_txn"
-    ]
+    _simple_select_one_onecol_txn = SQLBaseStore.__dict__["_simple_select_one_onecol_txn"]

     _simple_update_one = SQLBaseStore.__dict__["_simple_update_one"]
     _simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"]

@@ -171,40 +158,31 @@ class Porter(object):
     def setup_table(self, table):
         if table in APPEND_ONLY_TABLES:
             # It's safe to just carry on inserting.
-            row = yield self.postgres_store._simple_select_one(
+            next_chunk = yield self.postgres_store._simple_select_one_onecol(
                 table="port_from_sqlite3",
                 keyvalues={"table_name": table},
-                retcols=("forward_rowid", "backward_rowid"),
+                retcol="rowid",
                 allow_none=True,
             )

             total_to_port = None
-            if row is None:
+            if next_chunk is None:
                 if table == "sent_transactions":
-                    forward_chunk, already_ported, total_to_port = (
+                    next_chunk, already_ported, total_to_port = (
                         yield self._setup_sent_transactions()
                     )
-                    backward_chunk = 0
                 else:
                     yield self.postgres_store._simple_insert(
                         table="port_from_sqlite3",
-                        values={
-                            "table_name": table,
-                            "forward_rowid": 1,
-                            "backward_rowid": 0,
-                        }
+                        values={"table_name": table, "rowid": 1}
                     )

-                    forward_chunk = 1
-                    backward_chunk = 0
+                    next_chunk = 1
                     already_ported = 0
-            else:
-                forward_chunk = row["forward_rowid"]
-                backward_chunk = row["backward_rowid"]

             if total_to_port is None:
                 already_ported, total_to_port = yield self._get_total_count_to_port(
-                    table, forward_chunk, backward_chunk
+                    table, next_chunk
                 )
         else:
             def delete_all(txn):
@@ -218,124 +196,32 @@ class Porter(object):
             yield self.postgres_store._simple_insert(
                 table="port_from_sqlite3",
-                values={
-                    "table_name": table,
-                    "forward_rowid": 1,
-                    "backward_rowid": 0,
-                }
+                values={"table_name": table, "rowid": 0}
             )

-            forward_chunk = 1
-            backward_chunk = 0
+            next_chunk = 1

             already_ported, total_to_port = yield self._get_total_count_to_port(
-                table, forward_chunk, backward_chunk
+                table, next_chunk
             )

-        defer.returnValue(
-            (table, already_ported, total_to_port, forward_chunk, backward_chunk)
-        )
+        defer.returnValue((table, already_ported, total_to_port, next_chunk))

     @defer.inlineCallbacks
-    def handle_table(self, table, postgres_size, table_size, forward_chunk,
-                     backward_chunk):
+    def handle_table(self, table, postgres_size, table_size, next_chunk):
         if not table_size:
             return

         self.progress.add_table(table, postgres_size, table_size)

-        if table == "event_search":
-            yield self.handle_search_table(
-                postgres_size, table_size, forward_chunk, backward_chunk
-            )
-            return
-
-        forward_select = (
+        select = (
             "SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?"
             % (table,)
         )

-        backward_select = (
-            "SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid LIMIT ?"
-            % (table,)
-        )
-
-        do_forward = [True]
-        do_backward = [True]
-
         while True:
             def r(txn):
-                forward_rows = []
-                backward_rows = []
-                if do_forward[0]:
-                    txn.execute(forward_select, (forward_chunk, self.batch_size,))
-                    forward_rows = txn.fetchall()
-                    if not forward_rows:
-                        do_forward[0] = False
-
-                if do_backward[0]:
-                    txn.execute(backward_select, (backward_chunk, self.batch_size,))
-                    backward_rows = txn.fetchall()
-                    if not backward_rows:
-                        do_backward[0] = False
-
-                if forward_rows or backward_rows:
-                    headers = [column[0] for column in txn.description]
-                else:
-                    headers = None
-
-                return headers, forward_rows, backward_rows
-
-            headers, frows, brows = yield self.sqlite_store.runInteraction(
-                "select", r
-            )
-
-            if frows or brows:
-                if frows:
-                    forward_chunk = max(row[0] for row in frows) + 1
-                if brows:
-                    backward_chunk = min(row[0] for row in brows) - 1
-
-                rows = frows + brows
-                self._convert_rows(table, headers, rows)
-
-                def insert(txn):
-                    self.postgres_store.insert_many_txn(
-                        txn, table, headers[1:], rows
-                    )
-
-                    self.postgres_store._simple_update_one_txn(
-                        txn,
-                        table="port_from_sqlite3",
-                        keyvalues={"table_name": table},
-                        updatevalues={
-                            "forward_rowid": forward_chunk,
-                            "backward_rowid": backward_chunk,
-                        },
-                    )
-
-                yield self.postgres_store.execute(insert)
-
-                postgres_size += len(rows)
-
-                self.progress.update(table, postgres_size)
-            else:
-                return
-
-    @defer.inlineCallbacks
-    def handle_search_table(self, postgres_size, table_size, forward_chunk,
-                            backward_chunk):
-        select = (
-            "SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering"
-            " FROM event_search as es"
-            " INNER JOIN events AS e USING (event_id, room_id)"
-            " WHERE es.rowid >= ?"
-            " ORDER BY es.rowid LIMIT ?"
-        )
-
-        while True:
-            def r(txn):
-                txn.execute(select, (forward_chunk, self.batch_size,))
+                txn.execute(select, (next_chunk, self.batch_size,))
                 rows = txn.fetchall()
                 headers = [column[0] for column in txn.description]

@@ -344,51 +230,59 @@ class Porter(object):
             headers, rows = yield self.sqlite_store.runInteraction("select", r)

             if rows:
-                forward_chunk = rows[-1][0] + 1
+                next_chunk = rows[-1][0] + 1

-                # We have to treat event_search differently since it has a
-                # different structure in the two different databases.
-                def insert(txn):
-                    sql = (
-                        "INSERT INTO event_search (event_id, room_id, key,"
-                        " sender, vector, origin_server_ts, stream_ordering)"
-                        " VALUES (?,?,?,?,to_tsvector('english', ?),?,?)"
-                    )
+                if table == "event_search":
+                    # We have to treat event_search differently since it has a
+                    # different structure in the two different databases.
+                    def insert(txn):
+                        sql = (
+                            "INSERT INTO event_search (event_id, room_id, key, sender, vector)"
+                            " VALUES (?,?,?,?,to_tsvector('english', ?))"
+                        )

-                    rows_dict = [
-                        dict(zip(headers, row))
-                        for row in rows
-                    ]
-
-                    txn.executemany(sql, [
-                        (
-                            row["event_id"],
-                            row["room_id"],
-                            row["key"],
-                            row["sender"],
-                            row["value"],
-                            row["origin_server_ts"],
-                            row["stream_ordering"],
-                        )
-                        for row in rows_dict
-                    ])
-
-                    self.postgres_store._simple_update_one_txn(
-                        txn,
-                        table="port_from_sqlite3",
-                        keyvalues={"table_name": "event_search"},
-                        updatevalues={
-                            "forward_rowid": forward_chunk,
-                            "backward_rowid": backward_chunk,
-                        },
-                    )
+                        rows_dict = [
+                            dict(zip(headers, row))
+                            for row in rows
+                        ]
+
+                        txn.executemany(sql, [
+                            (
+                                row["event_id"],
+                                row["room_id"],
+                                row["key"],
+                                row["sender"],
+                                row["value"],
+                            )
+                            for row in rows_dict
+                        ])
+
+                        self.postgres_store._simple_update_one_txn(
+                            txn,
+                            table="port_from_sqlite3",
+                            keyvalues={"table_name": table},
+                            updatevalues={"rowid": next_chunk},
+                        )
+                else:
+                    self._convert_rows(table, headers, rows)
+
+                    def insert(txn):
+                        self.postgres_store.insert_many_txn(
+                            txn, table, headers[1:], rows
+                        )
+
+                        self.postgres_store._simple_update_one_txn(
+                            txn,
+                            table="port_from_sqlite3",
+                            keyvalues={"table_name": table},
+                            updatevalues={"rowid": next_chunk},
+                        )

                 yield self.postgres_store.execute(insert)

                 postgres_size += len(rows)

-                self.progress.update("event_search", postgres_size)
+                self.progress.update(table, postgres_size)
             else:
                 return

@@ -462,32 +356,10 @@ class Porter(object):
             txn.execute(
                 "CREATE TABLE port_from_sqlite3 ("
                 " table_name varchar(100) NOT NULL UNIQUE,"
-                " forward_rowid bigint NOT NULL,"
-                " backward_rowid bigint NOT NULL"
+                " rowid bigint NOT NULL"
                 ")"
            )

-        # The old port script created a table with just a "rowid" column.
-        # We want people to be able to rerun this script from an old port
-        # so that they can pick up any missing events that were not
-        # ported across.
-        def alter_table(txn):
-            txn.execute(
-                "ALTER TABLE IF EXISTS port_from_sqlite3"
-                " RENAME rowid TO forward_rowid"
-            )
-            txn.execute(
-                "ALTER TABLE IF EXISTS port_from_sqlite3"
-                " ADD backward_rowid bigint NOT NULL DEFAULT 0"
-            )
-
-        try:
-            yield self.postgres_store.runInteraction(
-                "alter_table", alter_table
-            )
-        except Exception as e:
-            logger.info("Failed to create port table: %s", e)
-
         try:
             yield self.postgres_store.runInteraction(
                 "create_port_table", create_port_table

@@ -547,7 +419,7 @@ class Porter(object):
     @defer.inlineCallbacks
     def _setup_sent_transactions(self):
         # Only save things from the last day
-        yesterday = int(time.time() * 1000) - 86400000
+        yesterday = int(time.time()*1000) - 86400000

         # And save the max transaction id from each destination
         select = (

@@ -603,11 +475,7 @@ class Porter(object):
         yield self.postgres_store._simple_insert(
             table="port_from_sqlite3",
-            values={
-                "table_name": "sent_transactions",
-                "forward_rowid": next_chunk,
-                "backward_rowid": 0,
-            }
+            values={"table_name": "sent_transactions", "rowid": next_chunk}
         )

         def get_sent_table_size(txn):

@@ -628,18 +496,13 @@ class Porter(object):
         defer.returnValue((next_chunk, inserted_rows, total_count))

     @defer.inlineCallbacks
-    def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk):
-        frows = yield self.sqlite_store.execute_sql(
+    def _get_remaining_count_to_port(self, table, next_chunk):
+        rows = yield self.sqlite_store.execute_sql(
             "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,),
-            forward_chunk,
+            next_chunk,
         )

-        brows = yield self.sqlite_store.execute_sql(
-            "SELECT count(*) FROM %s WHERE rowid <= ?" % (table,),
-            backward_chunk,
-        )
-
-        defer.returnValue(frows[0][0] + brows[0][0])
+        defer.returnValue(rows[0][0])

     @defer.inlineCallbacks
     def _get_already_ported_count(self, table):

@@ -650,10 +513,10 @@ class Porter(object):
         defer.returnValue(rows[0][0])

     @defer.inlineCallbacks
-    def _get_total_count_to_port(self, table, forward_chunk, backward_chunk):
+    def _get_total_count_to_port(self, table, next_chunk):
         remaining, done = yield defer.gatherResults(
             [
-                self._get_remaining_count_to_port(table, forward_chunk, backward_chunk),
+                self._get_remaining_count_to_port(table, next_chunk),
                 self._get_already_ported_count(table),
             ],
             consumeErrors=True,

@@ -784,7 +647,7 @@ class CursesProgress(Progress):
             color = curses.color_pair(2) if perc == 100 else curses.color_pair(1)

             self.stdscr.addstr(
-                i + 2, left_margin + max_len - len(table),
+                i+2, left_margin + max_len - len(table),
                 table,
                 curses.A_BOLD | color,
             )

@@ -792,18 +655,18 @@ class CursesProgress(Progress):
             size = 20

             progress = "[%s%s]" % (
-                "#" * int(perc * size / 100),
-                " " * (size - int(perc * size / 100)),
+                "#" * int(perc*size/100),
+                " " * (size - int(perc*size/100)),
             )

             self.stdscr.addstr(
-                i + 2, left_margin + max_len + middle_space,
+                i+2, left_margin + max_len + middle_space,
                 "%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]),
             )

             if self.finished:
                 self.stdscr.addstr(
-                    rows - 1, 0,
+                    rows-1, 0,
                     "Press any key to exit...",
                 )
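The key idea on the newer side of this diff is two-directional chunking: the script remembers a (forward_rowid, backward_rowid) pair per table and scans outwards in both directions, so rerunning it against a database ported by the old single-rowid version can back-fill rows the old run never reached. A simplified, synchronous sketch of that loop, assuming a sqlite3 connection; insert_into_postgres is a hypothetical sink, and the backward scan is ordered DESC here for clarity, unlike the Twisted original:

    def port_table(conn, table, forward_chunk, backward_chunk, batch_size=100):
        do_forward, do_backward = True, True
        while do_forward or do_backward:
            rows = []
            if do_forward:
                frows = conn.execute(
                    "SELECT rowid, * FROM %s WHERE rowid >= ?"
                    " ORDER BY rowid ASC LIMIT ?" % (table,),
                    (forward_chunk, batch_size),
                ).fetchall()
                if frows:
                    # Advance past the largest rowid we have seen.
                    forward_chunk = max(r[0] for r in frows) + 1
                    rows.extend(frows)
                else:
                    do_forward = False
            if do_backward:
                brows = conn.execute(
                    "SELECT rowid, * FROM %s WHERE rowid <= ?"
                    " ORDER BY rowid DESC LIMIT ?" % (table,),
                    (backward_chunk, batch_size),
                ).fetchall()
                if brows:
                    # Retreat below the smallest rowid we have seen.
                    backward_chunk = min(r[0] for r in brows) - 1
                    rows.extend(brows)
                else:
                    do_backward = False
            if rows:
                insert_into_postgres(table, rows)  # hypothetical sink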


@@ -16,5 +16,7 @@ ignore =
 [flake8]
 max-line-length = 90
-# W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
-ignore = W503
+ignore = W503 ; W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
+
+[pep8]
+max-line-length = 90


@@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.18.1" __version__ = "0.14.0"


@@ -13,22 +13,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import logging
-import pymacaroons
-
+"""This module contains classes for authenticating the user."""
 from canonicaljson import encode_canonical_json
 from signedjson.key import decode_verify_key_bytes
 from signedjson.sign import verify_signed_json, SignatureVerifyException
-from twisted.internet import defer
-from unpaddedbase64 import decode_base64

-import synapse.types
+from twisted.internet import defer
 from synapse.api.constants import EventTypes, Membership, JoinRules
 from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
-from synapse.types import UserID, get_domain_from_id
-from synapse.util.logcontext import preserve_context_over_fn
+from synapse.types import Requester, RoomID, UserID, EventID
 from synapse.util.logutils import log_function
-from synapse.util.metrics import Measure
+from synapse.util.logcontext import preserve_context_over_fn
+from unpaddedbase64 import decode_base64
+
+import logging
+import pymacaroons

 logger = logging.getLogger(__name__)

@@ -41,20 +41,12 @@ AuthEventTypes = (
 class Auth(object):
-    """
-    FIXME: This class contains a mix of functions for authenticating users
-    of our client-server API and authenticating events added to room graphs.
-    """
     def __init__(self, hs):
         self.hs = hs
-        self.clock = hs.get_clock()
         self.store = hs.get_datastore()
         self.state = hs.get_state_handler()
         self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
-        # Docs for these currently lives at
-        # github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst
-        # In addition, we have type == delete_pusher which grants access only to
-        # delete pushers.
         self._KNOWN_CAVEAT_PREFIXES = set([
             "gen = ",
             "guest = ",

@@ -63,18 +55,7 @@ class Auth(object):
             "user_id = ",
         ])

-    @defer.inlineCallbacks
-    def check_from_context(self, event, context, do_sig_check=True):
-        auth_events_ids = yield self.compute_auth_events(
-            event, context.prev_state_ids, for_verification=True,
-        )
-        auth_events = yield self.store.get_events(auth_events_ids)
-        auth_events = {
-            (e.type, e.state_key): e for e in auth_events.values()
-        }
-        self.check(event, auth_events=auth_events, do_sig_check=do_sig_check)
-
-    def check(self, event, auth_events, do_sig_check=True):
+    def check(self, event, auth_events):
         """ Checks if this event is correctly authed.

         Args:

@@ -85,35 +66,11 @@ class Auth(object):
         Returns:
             True if the auth checks pass.
         """
-        with Measure(self.clock, "auth.check"):
-            self.check_size_limits(event)
+        self.check_size_limits(event)

+        try:
             if not hasattr(event, "room_id"):
                 raise AuthError(500, "Event has no room_id: %s" % event)

-            if do_sig_check:
-                sender_domain = get_domain_from_id(event.sender)
-                event_id_domain = get_domain_from_id(event.event_id)
-
-                is_invite_via_3pid = (
-                    event.type == EventTypes.Member
-                    and event.membership == Membership.INVITE
-                    and "third_party_invite" in event.content
-                )
-
-                # Check the sender's domain has signed the event
-                if not event.signatures.get(sender_domain):
-                    # We allow invites via 3pid to have a sender from a different
-                    # HS, as the sender must match the sender of the original
-                    # 3pid invite. This is checked further down with the
-                    # other dedicated membership checks.
-                    if not is_invite_via_3pid:
-                        raise AuthError(403, "Event not signed by sender's server")
-
-                # Check the event_id's domain has signed the event
-                if not event.signatures.get(event_id_domain):
-                    raise AuthError(403, "Event not signed by sending server")
-
             if auth_events is None:
                 # Oh, we don't know what the state of the room was, so we
                 # are trusting that this is allowed (at least for now)

@@ -121,12 +78,6 @@ class Auth(object):
                 return True

             if event.type == EventTypes.Create:
-                room_id_domain = get_domain_from_id(event.room_id)
-                if room_id_domain != sender_domain:
-                    raise AuthError(
-                        403,
-                        "Creation event's room_id domain does not match sender's"
-                    )
                 # FIXME
                 return True

@@ -138,8 +89,8 @@ class Auth(object):
                     "Room %r does not exist" % (event.room_id,)
                 )

-            creating_domain = get_domain_from_id(event.room_id)
-            originating_domain = get_domain_from_id(event.sender)
+            creating_domain = RoomID.from_string(event.room_id).domain
+            originating_domain = UserID.from_string(event.sender).domain
             if creating_domain != originating_domain:
                 if not self.can_federate(event, auth_events):
                     raise AuthError(

@@ -149,22 +100,6 @@ class Auth(object):
             # FIXME: Temp hack
             if event.type == EventTypes.Aliases:
-                if not event.is_state():
-                    raise AuthError(
-                        403,
-                        "Alias event must be a state event",
-                    )
-                if not event.state_key:
-                    raise AuthError(
-                        403,
-                        "Alias event must have non-empty state_key"
-                    )
-                sender_domain = get_domain_from_id(event.sender)
-                if event.state_key != sender_domain:
-                    raise AuthError(
-                        403,
-                        "Alias event's state_key does not match sender's domain"
-                    )
                 return True

             logger.debug(

@@ -183,24 +118,6 @@ class Auth(object):
                 return allowed

             self.check_event_sender_in_room(event, auth_events)
-
-            # Special case to allow m.room.third_party_invite events wherever
-            # a user is allowed to issue invites. Fixes
-            # https://github.com/vector-im/vector-web/issues/1208 hopefully
-            if event.type == EventTypes.ThirdPartyInvite:
-                user_level = self._get_user_power_level(event.user_id, auth_events)
-                invite_level = self._get_named_level(auth_events, "invite", 0)
-
-                if user_level < invite_level:
-                    raise AuthError(
-                        403, (
-                            "You cannot issue a third party invite for %s." %
-                            (event.content.display_name,)
-                        )
-                    )
-                else:
-                    return True
-
             self._can_send_event(event, auth_events)

             if event.type == EventTypes.PowerLevels:

@@ -210,6 +127,13 @@ class Auth(object):
                 self.check_redaction(event, auth_events)

             logger.debug("Allowing! %s", event)
+        except AuthError as e:
+            logger.info(
+                "Event auth check failed on event %s with msg: %s",
+                event, e.msg
+            )
+            logger.info("Denying! %s", event)
+            raise

     def check_size_limits(self, event):
         def too_big(field):

@@ -295,17 +219,21 @@ class Auth(object):
     @defer.inlineCallbacks
     def check_host_in_room(self, room_id, host):
-        with Measure(self.clock, "check_host_in_room"):
-            latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
-
-            entry = yield self.state.resolve_state_groups(
-                room_id, latest_event_ids
-            )
-
-            ret = yield self.store.is_host_joined(
-                room_id, host, entry.state_group, entry.state
-            )
-            defer.returnValue(ret)
+        curr_state = yield self.state.get_current_state(room_id)
+
+        for event in curr_state.values():
+            if event.type == EventTypes.Member:
+                try:
+                    if UserID.from_string(event.state_key).domain != host:
+                        continue
+                except:
+                    logger.warn("state_key not user_id: %s", event.state_key)
+                    continue
+
+                if event.content["membership"] == Membership.JOIN:
+                    defer.returnValue(True)
+
+        defer.returnValue(False)

     def check_event_sender_in_room(self, event, auth_events):
         key = (EventTypes.Member, event.user_id, )

@@ -343,8 +271,8 @@ class Auth(object):
         target_user_id = event.state_key

-        creating_domain = get_domain_from_id(event.room_id)
-        target_domain = get_domain_from_id(target_user_id)
+        creating_domain = RoomID.from_string(event.room_id).domain
+        target_domain = UserID.from_string(target_user_id).domain
         if creating_domain != target_domain:
             if not self.can_federate(event, auth_events):
                 raise AuthError(

@@ -400,10 +328,6 @@ class Auth(object):
         if Membership.INVITE == membership and "third_party_invite" in event.content:
             if not self._verify_third_party_invite(event, auth_events):
                 raise AuthError(403, "You are not invited to this room.")
-            if target_banned:
-                raise AuthError(
-                    403, "%s is banned from the room" % (target_user_id,)
-                )
             return True

         if Membership.JOIN != membership:

@@ -508,9 +432,6 @@ class Auth(object):
         if not invite_event:
             return False

-        if invite_event.sender != event.sender:
-            return False
-
         if event.user_id != invite_event.user_id:
             return False

@@ -591,36 +512,33 @@ class Auth(object):
         return default

     @defer.inlineCallbacks
-    def get_user_by_req(self, request, allow_guest=False, rights="access"):
+    def get_user_by_req(self, request, allow_guest=False):
         """ Get a registered user's ID.

         Args:
             request - An HTTP request with an access_token query parameter.
         Returns:
-            defer.Deferred: resolves to a ``synapse.types.Requester`` object
+            tuple of:
+                UserID (str)
+                Access token ID (str)
         Raises:
             AuthError if no user by that token exists or the token is invalid.
         """
         # Can optionally look elsewhere in the request (e.g. headers)
         try:
-            user_id = yield self._get_appservice_user_id(request)
+            user_id = yield self._get_appservice_user_id(request.args)
             if user_id:
                 request.authenticated_entity = user_id
-                defer.returnValue(synapse.types.create_requester(user_id))
+                defer.returnValue(
+                    Requester(UserID.from_string(user_id), "", False)
+                )

-            access_token = get_access_token_from_request(
-                request, self.TOKEN_NOT_FOUND_HTTP_STATUS
-            )
-
-            user_info = yield self.get_user_by_access_token(access_token, rights)
+            access_token = request.args["access_token"][0]
+            user_info = yield self.get_user_by_access_token(access_token)
             user = user_info["user"]
             token_id = user_info["token_id"]
             is_guest = user_info["is_guest"]

-            # device_id may not be present if get_user_by_access_token has been
-            # stubbed out.
-            device_id = user_info.get("device_id")
-
             ip_addr = self.hs.get_ip_from_request(request)
             user_agent = request.requestHeaders.getRawHeaders(
                 "User-Agent",

@@ -632,8 +550,7 @@ class Auth(object):
                     user=user,
                     access_token=access_token,
                     ip=ip_addr,
-                    user_agent=user_agent,
-                    device_id=device_id,
+                    user_agent=user_agent
                 )

             if is_guest and not allow_guest:

@@ -643,8 +560,7 @@ class Auth(object):
             request.authenticated_entity = user.to_string()

-            defer.returnValue(synapse.types.create_requester(
-                user, token_id, is_guest, device_id))
+            defer.returnValue(Requester(user, token_id, is_guest))
         except KeyError:
             raise AuthError(
                 self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",

@@ -652,19 +568,17 @@ class Auth(object):
             )

     @defer.inlineCallbacks
-    def _get_appservice_user_id(self, request):
+    def _get_appservice_user_id(self, request_args):
         app_service = yield self.store.get_app_service_by_token(
-            get_access_token_from_request(
-                request, self.TOKEN_NOT_FOUND_HTTP_STATUS
-            )
+            request_args["access_token"][0]
         )
         if app_service is None:
             defer.returnValue(None)

-        if "user_id" not in request.args:
+        if "user_id" not in request_args:
             defer.returnValue(app_service.sender)

-        user_id = request.args["user_id"][0]
+        user_id = request_args["user_id"][0]
         if app_service.sender == user_id:
             defer.returnValue(app_service.sender)

@@ -681,7 +595,7 @@ class Auth(object):
         defer.returnValue(user_id)

     @defer.inlineCallbacks
-    def get_user_by_access_token(self, token, rights="access"):
+    def get_user_by_access_token(self, token):
         """ Get a registered user's ID.

         Args:

@@ -692,61 +606,46 @@ class Auth(object):
             AuthError if no user by that token exists or the token is invalid.
         """
         try:
-            ret = yield self.get_user_from_macaroon(token, rights)
+            ret = yield self.get_user_from_macaroon(token)
         except AuthError:
             # TODO(daniel): Remove this fallback when all existing access tokens
             # have been re-issued as macaroons.
-            if self.hs.config.expire_access_token:
-                raise
             ret = yield self._look_up_user_by_access_token(token)
         defer.returnValue(ret)

     @defer.inlineCallbacks
-    def get_user_from_macaroon(self, macaroon_str, rights="access"):
+    def get_user_from_macaroon(self, macaroon_str):
         try:
             macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
+            self.validate_macaroon(macaroon, "access", False)

-            user_id = self.get_user_id_from_macaroon(macaroon)
-            user = UserID.from_string(user_id)
-
-            self.validate_macaroon(
-                macaroon, rights, self.hs.config.expire_access_token,
-                user_id=user_id,
-            )
+            user_prefix = "user_id = "
+            user = None
             guest = False
             for caveat in macaroon.caveats:
-                if caveat.caveat_id == "guest = true":
+                if caveat.caveat_id.startswith(user_prefix):
+                    user = UserID.from_string(caveat.caveat_id[len(user_prefix):])
+                elif caveat.caveat_id == "guest = true":
                     guest = True

+            if user is None:
+                raise AuthError(
+                    self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
+                    errcode=Codes.UNKNOWN_TOKEN
+                )
+
             if guest:
                 ret = {
                     "user": user,
                     "is_guest": True,
                     "token_id": None,
-                    "device_id": None,
-                }
-            elif rights == "delete_pusher":
-                # We don't store these tokens in the database
-                ret = {
-                    "user": user,
-                    "is_guest": False,
-                    "token_id": None,
-                    "device_id": None,
                 }
             else:
-                # This codepath exists for several reasons:
-                #   * so that we can actually return a token ID, which is used
-                #     in some parts of the schema (where we probably ought to
-                #     use device IDs instead)
-                #   * the only way we currently have to invalidate an
-                #     access_token is by removing it from the database, so we
-                #     have to check here that it is still in the db
-                #   * some attributes (notably device_id) aren't stored in the
-                #     macaroon. They probably should be.
-                # TODO: build the dictionary from the macaroon once the
-                # above are fixed
+                # This codepath exists so that we can actually return a
+                # token ID, because we use token IDs in place of device
+                # identifiers throughout the codebase.
+                # TODO(daniel): Remove this fallback when device IDs are
+                # properly implemented.
                 ret = yield self._look_up_user_by_access_token(macaroon_str)
                 if ret["user"] != user:
                     logger.error(

@@ -766,46 +665,21 @@ class Auth(object):
                 errcode=Codes.UNKNOWN_TOKEN
             )

-    def get_user_id_from_macaroon(self, macaroon):
-        """Retrieve the user_id given by the caveats on the macaroon.
-
-        Does *not* validate the macaroon.
-
-        Args:
-            macaroon (pymacaroons.Macaroon): The macaroon to validate
-
-        Returns:
-            (str) user id
-
-        Raises:
-            AuthError if there is no user_id caveat in the macaroon
-        """
-        user_prefix = "user_id = "
-        for caveat in macaroon.caveats:
-            if caveat.caveat_id.startswith(user_prefix):
-                return caveat.caveat_id[len(user_prefix):]
-        raise AuthError(
-            self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
-            errcode=Codes.UNKNOWN_TOKEN
-        )
-
-    def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id):
+    def validate_macaroon(self, macaroon, type_string, verify_expiry):
         """
         validate that a Macaroon is understood by and was signed by this server.

         Args:
             macaroon(pymacaroons.Macaroon): The macaroon to validate
-            type_string(str): The kind of token required (e.g. "access", "refresh",
-                              "delete_pusher")
+            type_string(str): The kind of token this is (e.g. "access", "refresh")
             verify_expiry(bool): Whether to verify whether the macaroon has expired.
                 This should really always be True, but no clients currently implement
                 token refresh, so we can't enforce expiry yet.
-            user_id (str): The user_id required
         """
         v = pymacaroons.Verifier()
         v.satisfy_exact("gen = 1")
         v.satisfy_exact("type = " + type_string)
-        v.satisfy_exact("user_id = %s" % user_id)
+        v.satisfy_general(lambda c: c.startswith("user_id = "))
         v.satisfy_exact("guest = true")
         if verify_expiry:
             v.satisfy_general(self._verify_expiry)

@@ -844,23 +718,17 @@ class Auth(object):
                 self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
                 errcode=Codes.UNKNOWN_TOKEN
             )
-        # we use ret.get() below because *lots* of unit tests stub out
-        # get_user_by_access_token in a way where it only returns a couple of
-        # the fields.
         user_info = {
             "user": UserID.from_string(ret.get("name")),
             "token_id": ret.get("token_id", None),
             "is_guest": False,
-            "device_id": ret.get("device_id"),
         }
         defer.returnValue(user_info)

     @defer.inlineCallbacks
     def get_appservice_by_req(self, request):
         try:
-            token = get_access_token_from_request(
-                request, self.TOKEN_NOT_FOUND_HTTP_STATUS
-            )
+            token = request.args["access_token"][0]
             service = yield self.store.get_app_service_by_token(token)
             if not service:
                 logger.warn("Unrecognised appservice access token: %s" % (token,))

@@ -881,7 +749,7 @@ class Auth(object):
     @defer.inlineCallbacks
     def add_auth_events(self, builder, context):
-        auth_ids = yield self.compute_auth_events(builder, context.prev_state_ids)
+        auth_ids = self.compute_auth_events(builder, context.current_state)

         auth_events_entries = yield self.store.add_event_hashes(
             auth_ids

@@ -889,32 +757,30 @@ class Auth(object):
         builder.auth_events = auth_events_entries

-    @defer.inlineCallbacks
-    def compute_auth_events(self, event, current_state_ids, for_verification=False):
+    def compute_auth_events(self, event, current_state):
         if event.type == EventTypes.Create:
-            defer.returnValue([])
+            return []

         auth_ids = []

         key = (EventTypes.PowerLevels, "", )
-        power_level_event_id = current_state_ids.get(key)
+        power_level_event = current_state.get(key)

-        if power_level_event_id:
-            auth_ids.append(power_level_event_id)
+        if power_level_event:
+            auth_ids.append(power_level_event.event_id)

         key = (EventTypes.JoinRules, "", )
-        join_rule_event_id = current_state_ids.get(key)
+        join_rule_event = current_state.get(key)

         key = (EventTypes.Member, event.user_id, )
-        member_event_id = current_state_ids.get(key)
+        member_event = current_state.get(key)

         key = (EventTypes.Create, "", )
-        create_event_id = current_state_ids.get(key)
-        if create_event_id:
-            auth_ids.append(create_event_id)
+        create_event = current_state.get(key)
+        if create_event:
+            auth_ids.append(create_event.event_id)

-        if join_rule_event_id:
-            join_rule_event = yield self.store.get_event(join_rule_event_id)
+        if join_rule_event:
             join_rule = join_rule_event.content.get("join_rule")
             is_public = join_rule == JoinRules.PUBLIC if join_rule else False
         else:

@@ -923,21 +789,15 @@ class Auth(object):
         if event.type == EventTypes.Member:
             e_type = event.content["membership"]
             if e_type in [Membership.JOIN, Membership.INVITE]:
-                if join_rule_event_id:
-                    auth_ids.append(join_rule_event_id)
+                if join_rule_event:
+                    auth_ids.append(join_rule_event.event_id)

             if e_type == Membership.JOIN:
-                if member_event_id and not is_public:
-                    auth_ids.append(member_event_id)
+                if member_event and not is_public:
+                    auth_ids.append(member_event.event_id)
             else:
-                if member_event_id:
-                    auth_ids.append(member_event_id)
-
-                if for_verification:
-                    key = (EventTypes.Member, event.state_key, )
-                    existing_event_id = current_state_ids.get(key)
-                    if existing_event_id:
-                        auth_ids.append(existing_event_id)
+                if member_event:
+                    auth_ids.append(member_event.event_id)

             if e_type == Membership.INVITE:
                 if "third_party_invite" in event.content:

@@ -945,15 +805,14 @@ class Auth(object):
                         EventTypes.ThirdPartyInvite,
                         event.content["third_party_invite"]["signed"]["token"]
                     )
-                    third_party_invite_id = current_state_ids.get(key)
-                    if third_party_invite_id:
-                        auth_ids.append(third_party_invite_id)
-        elif member_event_id:
-            member_event = yield self.store.get_event(member_event_id)
+                    third_party_invite = current_state.get(key)
+                    if third_party_invite:
+                        auth_ids.append(third_party_invite.event_id)
+        elif member_event:
             if member_event.content["membership"] == Membership.JOIN:
                 auth_ids.append(member_event.event_id)

-        defer.returnValue(auth_ids)
+        return auth_ids

     def _get_send_level(self, etype, state_key, auth_events):
         key = (EventTypes.PowerLevels, "", )

@@ -1035,8 +894,8 @@ class Auth(object):
         if user_level >= redact_level:
             return False

-        redacter_domain = get_domain_from_id(event.event_id)
-        redactee_domain = get_domain_from_id(event.redacts)
+        redacter_domain = EventID.from_string(event.event_id).domain
+        redactee_domain = EventID.from_string(event.redacts).domain
         if redacter_domain == redactee_domain:
             return True

@@ -1169,40 +1028,3 @@ class Auth(object):
             "This server requires you to be a moderator in the room to"
             " edit its room list entry"
         )
-
-
-def has_access_token(request):
-    """Checks if the request has an access_token.
-
-    Returns:
-        bool: False if no access_token was given, True otherwise.
-    """
-    query_params = request.args.get("access_token")
-    return bool(query_params)
-
-
-def get_access_token_from_request(request, token_not_found_http_status=401):
-    """Extracts the access_token from the request.
-
-    Args:
-        request: The http request.
-        token_not_found_http_status(int): The HTTP status code to set in the
-            AuthError if the token isn't found. This is used in some of the
-            legacy APIs to change the status code to 403 from the default of
-            401 since some of the old clients depended on auth errors returning
-            403.
-    Returns:
-        str: The access_token
-    Raises:
-        AuthError: If there isn't an access_token in the request.
-    """
-    query_params = request.args.get("access_token")
-    # Try to get the access_token from the query params.
-    if not query_params:
-        raise AuthError(
-            token_not_found_http_status,
-            "Missing access token.",
-            errcode=Codes.MISSING_TOKEN
-        )
-
-    return query_params[0]
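Both sides of this diff verify macaroon access tokens against the same caveat scheme with pymacaroons; the newer side additionally pins the user_id caveat to an exact value instead of accepting any "user_id = " prefix. A standalone sketch of the newer, user-pinned check, where the key argument is the server's macaroon secret and the function name is illustrative:

    import pymacaroons

    def verify_access_macaroon(macaroon_str, expected_user_id, key):
        macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
        v = pymacaroons.Verifier()
        v.satisfy_exact("gen = 1")
        v.satisfy_exact("type = access")
        v.satisfy_exact("user_id = %s" % (expected_user_id,))
        # Guest tokens carry an extra caveat; satisfying it here means the
        # verifier tolerates its presence, as in the code above.
        v.satisfy_exact("guest = true")
        # Raises on failure, returns True on success.
        return v.verify(macaroon, key)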


@@ -85,8 +85,3 @@ class RoomCreationPreset(object):
PRIVATE_CHAT = "private_chat" PRIVATE_CHAT = "private_chat"
PUBLIC_CHAT = "public_chat" PUBLIC_CHAT = "public_chat"
TRUSTED_PRIVATE_CHAT = "trusted_private_chat" TRUSTED_PRIVATE_CHAT = "trusted_private_chat"
class ThirdPartyEntityKind(object):
USER = "user"
LOCATION = "location"


@@ -42,10 +42,8 @@ class Codes(object):
TOO_LARGE = "M_TOO_LARGE" TOO_LARGE = "M_TOO_LARGE"
EXCLUSIVE = "M_EXCLUSIVE" EXCLUSIVE = "M_EXCLUSIVE"
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED" THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
THREEPID_IN_USE = "M_THREEPID_IN_USE" THREEPID_IN_USE = "THREEPID_IN_USE"
THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND"
INVALID_USERNAME = "M_INVALID_USERNAME" INVALID_USERNAME = "M_INVALID_USERNAME"
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
class CodeMessageException(RuntimeError): class CodeMessageException(RuntimeError):


@@ -15,9 +15,7 @@
 from synapse.api.errors import SynapseError
 from synapse.types import UserID, RoomID

-from twisted.internet import defer
-
-import ujson as json
+import json


 class Filtering(object):

@@ -26,10 +24,10 @@ class Filtering(object):
         super(Filtering, self).__init__()
         self.store = hs.get_datastore()

-    @defer.inlineCallbacks
     def get_user_filter(self, user_localpart, filter_id):
-        result = yield self.store.get_user_filter(user_localpart, filter_id)
-        defer.returnValue(FilterCollection(result))
+        result = self.store.get_user_filter(user_localpart, filter_id)
+        result.addCallback(FilterCollection)
+        return result

     def add_user_filter(self, user_localpart, user_filter):
         self.check_valid_filter(user_filter)

@@ -191,17 +189,6 @@ class Filter(object):
     def __init__(self, filter_json):
         self.filter_json = filter_json

-        self.types = self.filter_json.get("types", None)
-        self.not_types = self.filter_json.get("not_types", [])
-
-        self.rooms = self.filter_json.get("rooms", None)
-        self.not_rooms = self.filter_json.get("not_rooms", [])
-
-        self.senders = self.filter_json.get("senders", None)
-        self.not_senders = self.filter_json.get("not_senders", [])
-
-        self.contains_url = self.filter_json.get("contains_url", None)
-
     def check(self, event):
         """Checks whether the filter matches the given event.

@@ -220,10 +207,9 @@ class Filter(object):
             event.get("room_id", None),
             sender,
             event.get("type", None),
-            "url" in event.get("content", {})
         )

-    def check_fields(self, room_id, sender, event_type, contains_url):
+    def check_fields(self, room_id, sender, event_type):
         """Checks whether the filter matches the given event fields.

         Returns:

@@ -237,20 +223,15 @@ class Filter(object):
         for name, match_func in literal_keys.items():
             not_name = "not_%s" % (name,)
-            disallowed_values = getattr(self, not_name)
+            disallowed_values = self.filter_json.get(not_name, [])
             if any(map(match_func, disallowed_values)):
                 return False

-            allowed_values = getattr(self, name)
+            allowed_values = self.filter_json.get(name, None)
             if allowed_values is not None:
                 if not any(map(match_func, allowed_values)):
                     return False

-        contains_url_filter = self.filter_json.get("contains_url")
-        if contains_url_filter is not None:
-            if contains_url_filter != contains_url:
-                return False
-
         return True

     def filter_rooms(self, room_ids):
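The check_fields logic on both sides is the same veto-then-allow pattern: any match in a not_* list rejects the event, then each allow-list that is present must match (the newer side additionally compares a contains_url flag). A dict-based standalone sketch of that pattern, using exact-match predicates where the real code's match functions also handle details like m.* style type globs:

    def check_fields(filter_json, room_id, sender, event_type):
        literal_keys = {
            "rooms": lambda v: v == room_id,
            "senders": lambda v: v == sender,
            "types": lambda v: v == event_type,
        }
        for name, match_func in literal_keys.items():
            # A hit on the "not_" list vetoes the event outright.
            if any(map(match_func, filter_json.get("not_%s" % (name,), []))):
                return False
            # A present allow-list must contain a match.
            allowed_values = filter_json.get(name)
            if allowed_values is not None:
                if not any(map(match_func, allowed_values)):
                    return False
        return True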


@@ -25,3 +25,4 @@ SERVER_KEY_PREFIX = "/_matrix/key/v1"
SERVER_KEY_V2_PREFIX = "/_matrix/key/v2" SERVER_KEY_V2_PREFIX = "/_matrix/key/v2"
MEDIA_PREFIX = "/_matrix/media/r0" MEDIA_PREFIX = "/_matrix/media/r0"
LEGACY_MEDIA_PREFIX = "/_matrix/media/v1" LEGACY_MEDIA_PREFIX = "/_matrix/media/v1"
APP_SERVICE_PREFIX = "/_matrix/appservice/v1"


@@ -16,11 +16,13 @@
 import sys
 sys.dont_write_bytecode = True

-from synapse import python_dependencies  # noqa: E402
+from synapse.python_dependencies import (
+    check_requirements, MissingRequirementError
+)  # NOQA

 try:
-    python_dependencies.check_requirements()
-except python_dependencies.MissingRequirementError as e:
+    check_requirements()
+except MissingRequirementError as e:
     message = "\n".join([
         "Missing Requirement: %s" % (e.message,),
         "To install run:",


@@ -1,210 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse
from synapse.server import HomeServer
from synapse.config._base import ConfigError
from synapse.config.logger import setup_logging
from synapse.config.homeserver import HomeServerConfig
from synapse.http.site import SynapseSite
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.storage.engines import create_engine
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from twisted.internet import reactor, defer
from twisted.web.resource import Resource
from daemonize import Daemonize
import sys
import logging
import gc
logger = logging.getLogger("synapse.app.appservice")
class AppserviceSlaveStore(
DirectoryStore, SlavedEventStore, SlavedApplicationServiceStore,
SlavedRegistrationStore,
):
pass
class AppserviceServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = AppserviceSlaveStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
logger.info("Synapse appservice now listening on port %d", port)
def start_listening(self, listeners):
for listener in listeners:
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
appservice_handler = self.get_application_service_handler()
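        # NB: this inner callback shadows the enclosing `replicate` method;
        # it is only used below to poke the appservice handler with the
        # newest event stream position after each poll.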
@defer.inlineCallbacks
def replicate(results):
stream = results.get("events")
if stream:
max_stream_id = stream["position"]
yield appservice_handler.notify_interested_services(max_stream_id)
while True:
try:
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
yield store.process_replication(result)
replicate(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(30)
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse appservice", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.appservice"
setup_logging(config.worker_log_config, config.worker_log_file)
database_engine = create_engine(config.database_config)
if config.notify_appservices:
sys.stderr.write(
"\nThe appservices must be disabled in the main synapse process"
"\nbefore they can be run in a separate worker."
"\nPlease add ``notify_appservices: false`` to the main config"
"\n"
)
sys.exit(1)
# Force the appservices to be enabled here, since they are disabled in the main config
config.notify_appservices = True
ps = AppserviceServer(
config.server_name,
db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ps.setup()
ps.start_listening(config.worker_listeners)
def run():
with LoggingContext("run"):
logger.info("Running")
change_resource_limit(config.soft_file_limit)
if config.gc_thresholds:
gc.set_threshold(*config.gc_thresholds)
reactor.run()
def start():
ps.replicate()
ps.get_datastore().start_profiling()
ps.get_state_handler().start_caching()
reactor.callWhenRunning(start)
if config.worker_daemonize:
daemon = Daemonize(
app="synapse-appservice",
pid=config.worker_pid_file,
action=run,
auto_close_fds=False,
verbose=True,
logger=logger,
)
daemon.start()
else:
run()
if __name__ == '__main__':
with LoggingContext("main"):
start(sys.argv[1:])
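
For orientation: the listener entries that _listen_http above consumes have roughly the following shape. This is a minimal sketch with invented values, not part of the diff:

# Illustrative worker listener entry; only the keys read above are shown.
example_listener = {
    "type": "http",                         # dispatched on in start_listening
    "port": 9092,                           # required
    "bind_address": "127.0.0.1",            # optional, defaults to ""
    "tag": "appservice-metrics",            # optional site tag, defaults to the port
    "resources": [{"names": ["metrics"]}],  # this worker only serves "metrics"
}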


@@ -1,216 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.http.site import SynapseSite
from synapse.http.server import JsonResource
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.keys import SlavedKeyStore
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.rest.client.v1.room import PublicRoomListRestServlet
from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore
from synapse.storage.engines import create_engine
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from synapse.crypto import context_factory
from twisted.internet import reactor, defer
from twisted.web.resource import Resource
from daemonize import Daemonize
import sys
import logging
import gc
logger = logging.getLogger("synapse.app.client_reader")
class ClientReaderSlavedStore(
SlavedEventStore,
SlavedKeyStore,
RoomStore,
DirectoryStore,
SlavedApplicationServiceStore,
SlavedRegistrationStore,
BaseSlavedStore,
ClientIpStore, # After BaseSlavedStore because the constructor is different
):
pass
class ClientReaderServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = ClientReaderSlavedStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
elif name == "client":
resource = JsonResource(self, canonical_json=False)
PublicRoomListRestServlet(self).register(resource)
resources.update({
"/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource,
"/_matrix/client/v2_alpha": resource,
"/_matrix/client/api/v1": resource,
})
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
logger.info("Synapse client reader now listening on port %d", port)
def start_listening(self, listeners):
for listener in listeners:
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
while True:
try:
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
yield store.process_replication(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(5)
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse client reader", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.client_reader"
setup_logging(config.worker_log_config, config.worker_log_file)
database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config)
ss = ClientReaderServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ss.setup()
ss.get_handlers()
ss.start_listening(config.worker_listeners)
def run():
with LoggingContext("run"):
logger.info("Running")
change_resource_limit(config.soft_file_limit)
if config.gc_thresholds:
gc.set_threshold(*config.gc_thresholds)
reactor.run()
def start():
ss.get_state_handler().start_caching()
ss.get_datastore().start_profiling()
ss.replicate()
reactor.callWhenRunning(start)
if config.worker_daemonize:
daemon = Daemonize(
app="synapse-client-reader",
pid=config.worker_pid_file,
action=run,
auto_close_fds=False,
verbose=True,
logger=logger,
)
daemon.start()
else:
run()
if __name__ == '__main__':
with LoggingContext("main"):
start(sys.argv[1:])
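
As a rough usage sketch (the host, port and the unauthenticated call are assumptions, not part of the diff), a "client" listener on this worker answers the public room list on any of the four prefixes registered above:

# Illustrative only: fetch the public room list from a client reader worker.
import json
import urllib2

resp = urllib2.urlopen("http://127.0.0.1:8008/_matrix/client/r0/publicRooms")
print json.load(resp).get("chunk", [])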


@@ -1,207 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.http.site import SynapseSite
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.keys import SlavedKeyStore
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from synapse.api.urls import FEDERATION_PREFIX
from synapse.federation.transport.server import TransportLayerServer
from synapse.crypto import context_factory
from twisted.internet import reactor, defer
from twisted.web.resource import Resource
from daemonize import Daemonize
import sys
import logging
import gc
logger = logging.getLogger("synapse.app.federation_reader")
class FederationReaderSlavedStore(
SlavedEventStore,
SlavedKeyStore,
RoomStore,
DirectoryStore,
TransactionStore,
BaseSlavedStore,
):
pass
class FederationReaderServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
elif name == "federation":
resources.update({
FEDERATION_PREFIX: TransportLayerServer(self),
})
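                    # (FEDERATION_PREFIX is "/_matrix/federation/v1" in
                    # synapse.api.urls, i.e. the server-server API root.)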
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
logger.info("Synapse federation reader now listening on port %d", port)
def start_listening(self, listeners):
for listener in listeners:
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
while True:
try:
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
yield store.process_replication(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(5)
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse federation reader", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.federation_reader"
setup_logging(config.worker_log_config, config.worker_log_file)
database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config)
ss = FederationReaderServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ss.setup()
ss.get_handlers()
ss.start_listening(config.worker_listeners)
def run():
with LoggingContext("run"):
logger.info("Running")
change_resource_limit(config.soft_file_limit)
if config.gc_thresholds:
gc.set_threshold(*config.gc_thresholds)
reactor.run()
def start():
ss.get_state_handler().start_caching()
ss.get_datastore().start_profiling()
ss.replicate()
reactor.callWhenRunning(start)
if config.worker_daemonize:
daemon = Daemonize(
app="synapse-federation-reader",
pid=config.worker_pid_file,
action=run,
auto_close_fds=False,
verbose=True,
logger=logger,
)
daemon.start()
else:
run()
if __name__ == '__main__':
with LoggingContext("main"):
start(sys.argv[1:])


@@ -16,10 +16,14 @@
 import synapse

-import gc
+import contextlib
 import logging
 import os
+import re
+import resource
+import subprocess
 import sys
+import time

 from synapse.config._base import ConfigError
 from synapse.python_dependencies import (
@@ -33,11 +37,18 @@ from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_d
 from synapse.server import HomeServer

+from twisted.conch.manhole import ColoredManhole
+from twisted.conch.insults import insults
+from twisted.conch import manhole_ssh
+from twisted.cred import checkers, portal
+
 from twisted.internet import reactor, task, defer
 from twisted.application import service
 from twisted.web.resource import Resource, EncodingResourceWrapper
 from twisted.web.static import File
-from twisted.web.server import GzipEncoderFactory
+from twisted.web.server import Site, GzipEncoderFactory, Request
 from synapse.http.server import RootRedirect
 from synapse.rest.media.v0.content_repository import ContentRepoResource
 from synapse.rest.media.v1.media_repository import MediaRepositoryResource
@@ -51,18 +62,10 @@ from synapse.api.urls import (
 from synapse.config.homeserver import HomeServerConfig
 from synapse.crypto import context_factory
 from synapse.util.logcontext import LoggingContext
-from synapse.metrics import register_memory_metrics, get_metrics_for
 from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
 from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX
 from synapse.federation.transport.server import TransportLayerServer
-from synapse.util.rlimit import change_resource_limit
-from synapse.util.versionstring import get_version_string
-from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.manhole import manhole
-from synapse.http.site import SynapseSite
 from synapse import events

 from daemonize import Daemonize
@@ -70,6 +73,9 @@ from daemonize import Daemonize
 logger = logging.getLogger("synapse.app.homeserver")

+ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')
+

 def gz_wrap(r):
     return EncodingResourceWrapper(r, [GzipEncoderFactory()])
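
The redaction pattern added above can be sanity-checked in isolation; the URI and token value here are invented, the snippet is an illustration, not part of the diff:

# Illustrative only: what SynapseRequest.get_redacted_uri (added further
# down) does with ACCESS_TOKEN_RE.
import re

ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')
uri = "/_matrix/client/r0/sync?access_token=SEEKRIT&since=s72"
print ACCESS_TOKEN_RE.sub(r'\1<redacted>\3', uri)
# -> /_matrix/client/r0/sync?access_token=<redacted>&since=s72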
@@ -148,7 +154,7 @@ class SynapseHomeServer(HomeServer):
                 MEDIA_PREFIX: media_repo,
                 LEGACY_MEDIA_PREFIX: media_repo,
                 CONTENT_REPO_PREFIX: ContentRepoResource(
-                    self, self.config.uploads_path
+                    self, self.config.uploads_path, self.auth, self.content_addr
                 ),
             })
@@ -167,12 +173,7 @@ class SynapseHomeServer(HomeServer):
             if name == "replication":
                 resources[REPLICATION_PREFIX] = ReplicationResource(self)

-        if WEB_CLIENT_PREFIX in resources:
-            root_resource = RootRedirect(WEB_CLIENT_PREFIX)
-        else:
-            root_resource = Resource()
-
-        root_resource = create_resource_tree(resources, root_resource)
+        root_resource = create_resource_tree(resources)

         if tls:
             reactor.listenSSL(
                 port,
@@ -205,13 +206,24 @@ class SynapseHomeServer(HomeServer):
             if listener["type"] == "http":
                 self._listener_http(config, listener)
             elif listener["type"] == "manhole":
+                checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(
+                    matrix="rabbithole"
+                )
+
+                rlm = manhole_ssh.TerminalRealm()
+                rlm.chainedProtocolFactory = lambda: insults.ServerProtocol(
+                    ColoredManhole,
+                    {
+                        "__name__": "__console__",
+                        "hs": self,
+                    }
+                )
+
+                f = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker]))
+
                 reactor.listenTCP(
                     listener["port"],
-                    manhole(
-                        username="matrix",
-                        password="rabbithole",
-                        globals={"hs": self},
-                    ),
+                    f,
                     interface=listener.get("bind_address", '127.0.0.1')
                 )
             else:
@@ -257,6 +269,86 @@ def quit_with_error(error_string):
     sys.exit(1)


+def get_version_string():
+    try:
+        null = open(os.devnull, 'w')
+        cwd = os.path.dirname(os.path.abspath(__file__))
+        try:
+            git_branch = subprocess.check_output(
+                ['git', 'rev-parse', '--abbrev-ref', 'HEAD'],
+                stderr=null,
+                cwd=cwd,
+            ).strip()
+            git_branch = "b=" + git_branch
+        except subprocess.CalledProcessError:
+            git_branch = ""
+
+        try:
+            git_tag = subprocess.check_output(
+                ['git', 'describe', '--exact-match'],
+                stderr=null,
+                cwd=cwd,
+            ).strip()
+            git_tag = "t=" + git_tag
+        except subprocess.CalledProcessError:
+            git_tag = ""
+
+        try:
+            git_commit = subprocess.check_output(
+                ['git', 'rev-parse', '--short', 'HEAD'],
+                stderr=null,
+                cwd=cwd,
+            ).strip()
+        except subprocess.CalledProcessError:
+            git_commit = ""
+
+        try:
+            dirty_string = "-this_is_a_dirty_checkout"
+            is_dirty = subprocess.check_output(
+                ['git', 'describe', '--dirty=' + dirty_string],
+                stderr=null,
+                cwd=cwd,
+            ).strip().endswith(dirty_string)
+            git_dirty = "dirty" if is_dirty else ""
+        except subprocess.CalledProcessError:
+            git_dirty = ""
+
+        if git_branch or git_tag or git_commit or git_dirty:
+            git_version = ",".join(
+                s for s in
+                (git_branch, git_tag, git_commit, git_dirty,)
+                if s
+            )
+            return (
+                "Synapse/%s (%s)" % (
+                    synapse.__version__, git_version,
+                )
+            ).encode("ascii")
+    except Exception as e:
+        logger.info("Failed to check for git repository: %s", e)
+
+    return ("Synapse/%s" % (synapse.__version__,)).encode("ascii")
+
+
+def change_resource_limit(soft_file_no):
+    try:
+        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
+        if not soft_file_no:
+            soft_file_no = hard
+        resource.setrlimit(resource.RLIMIT_NOFILE, (soft_file_no, hard))
+        logger.info("Set file limit to: %d", soft_file_no)
+
+        resource.setrlimit(
+            resource.RLIMIT_CORE, (resource.RLIM_INFINITY, resource.RLIM_INFINITY)
+        )
+    except (ValueError, resource.error) as e:
+        logger.warn("Failed to set file or core limit: %s", e)
+
+
 def setup(config_options):
     """
     Args:
@@ -267,9 +359,10 @@ def setup(config_options):
         HomeServer
     """
     try:
-        config = HomeServerConfig.load_or_generate_config(
+        config = HomeServerConfig.load_config(
             "Synapse Homeserver",
             config_options,
+            generate_section="Homeserver"
         )
     except ConfigError as e:
         sys.stderr.write("\n" + e.message + "\n")
@@ -285,7 +378,7 @@ def setup(config_options):
     # check any extra requirements we have now we have a config
     check_requirements(config)

-    version_string = "Synapse/" + get_version_string(synapse)
+    version_string = get_version_string()

     logger.info("Server hostname: %s", config.server_name)
     logger.info("Server version: %s", version_string)
@@ -302,6 +395,7 @@ def setup(config_options):
         db_config=config.database_config,
         tls_server_context_factory=tls_server_context_factory,
         config=config,
+        content_addr=config.content_addr,
         version_string=version_string,
         database_engine=database_engine,
     )
@@ -336,8 +430,6 @@ def setup(config_options):
         hs.get_datastore().start_doing_background_updates()
         hs.get_replication_layer().start_get_pdu_cache()

-        register_memory_metrics(hs)
-
     reactor.callWhenRunning(start)

     return hs
@@ -353,13 +445,215 @@ class SynapseService(service.Service):
     def startService(self):
         hs = setup(self.config)
         change_resource_limit(hs.config.soft_file_limit)
-        if hs.config.gc_thresholds:
-            gc.set_threshold(*hs.config.gc_thresholds)

     def stopService(self):
         return self._port.stopListening()


+class SynapseRequest(Request):
+    def __init__(self, site, *args, **kw):
+        Request.__init__(self, *args, **kw)
+        self.site = site
+        self.authenticated_entity = None
+        self.start_time = 0
+
+    def __repr__(self):
+        # We overwrite this so that we don't log ``access_token``
+        return '<%s at 0x%x method=%s uri=%s clientproto=%s site=%s>' % (
+            self.__class__.__name__,
+            id(self),
+            self.method,
+            self.get_redacted_uri(),
+            self.clientproto,
+            self.site.site_tag,
+        )
+
+    def get_redacted_uri(self):
+        return ACCESS_TOKEN_RE.sub(
+            r'\1<redacted>\3',
+            self.uri
+        )
+
+    def get_user_agent(self):
+        return self.requestHeaders.getRawHeaders("User-Agent", [None])[-1]
+
+    def started_processing(self):
+        self.site.access_logger.info(
+            "%s - %s - Received request: %s %s",
+            self.getClientIP(),
+            self.site.site_tag,
+            self.method,
+            self.get_redacted_uri()
+        )
+        self.start_time = int(time.time() * 1000)
+
+    def finished_processing(self):
+        try:
+            context = LoggingContext.current_context()
+            ru_utime, ru_stime = context.get_resource_usage()
+            db_txn_count = context.db_txn_count
+            db_txn_duration = context.db_txn_duration
+        except:
+            ru_utime, ru_stime = (0, 0)
+            db_txn_count, db_txn_duration = (0, 0)
+
+        self.site.access_logger.info(
+            "%s - %s - {%s}"
+            " Processed request: %dms (%dms, %dms) (%dms/%d)"
+            " %sB %s \"%s %s %s\" \"%s\"",
+            self.getClientIP(),
+            self.site.site_tag,
+            self.authenticated_entity,
+            int(time.time() * 1000) - self.start_time,
+            int(ru_utime * 1000),
+            int(ru_stime * 1000),
+            int(db_txn_duration * 1000),
+            int(db_txn_count),
+            self.sentLength,
+            self.code,
+            self.method,
+            self.get_redacted_uri(),
+            self.clientproto,
+            self.get_user_agent(),
+        )
+
+    @contextlib.contextmanager
+    def processing(self):
+        self.started_processing()
+        yield
+        self.finished_processing()
+
+
+class XForwardedForRequest(SynapseRequest):
+    """
+    Add a layer on top of another request that only uses the value of an
+    X-Forwarded-For header as the result of C{getClientIP}.
+    """
+    def __init__(self, *args, **kw):
+        SynapseRequest.__init__(self, *args, **kw)
+
+    def getClientIP(self):
+        """
+        @return: The client address (the first address) in the value of the
+            I{X-Forwarded-For} header. If the header is not present, return
+            C{b"-"}.
+        """
+        return self.requestHeaders.getRawHeaders(
+            b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip()
+
+
+class SynapseRequestFactory(object):
+    def __init__(self, site, x_forwarded_for):
+        self.site = site
+        self.x_forwarded_for = x_forwarded_for
+
+    def __call__(self, *args, **kwargs):
+        if self.x_forwarded_for:
+            return XForwardedForRequest(self.site, *args, **kwargs)
+        else:
+            return SynapseRequest(self.site, *args, **kwargs)
+
+
+class SynapseSite(Site):
+    """
+    Subclass of a twisted http Site that does access logging with python's
+    standard logging
+    """
+    def __init__(self, logger_name, site_tag, config, resource, *args, **kwargs):
+        Site.__init__(self, resource, *args, **kwargs)
+
+        self.site_tag = site_tag
+
+        proxied = config.get("x_forwarded", False)
+        self.requestFactory = SynapseRequestFactory(self, proxied)
+        self.access_logger = logging.getLogger(logger_name)
+
+    def log(self, request):
+        pass
+
+
+def create_resource_tree(desired_tree, redirect_root_to_web_client=True):
+    """Create the resource tree for this Home Server.
+
+    This is unduly complicated because Twisted does not support putting
+    child resources more than 1 level deep at a time.
+
+    Args:
+        desired_tree (dict): Mapping from desired paths to Resource objects.
+        redirect_root_to_web_client (bool): True to redirect '/' to the
+            location of the web client. This does nothing if the web client
+            is not in desired_tree.
+    """
+    if redirect_root_to_web_client and WEB_CLIENT_PREFIX in desired_tree:
+        root_resource = RootRedirect(WEB_CLIENT_PREFIX)
+    else:
+        root_resource = Resource()
+
+    # ideally we'd just use getChild and putChild but getChild doesn't work
+    # unless you give it a Request object IN ADDITION to the name :/ So
+    # instead, we'll store a copy of this mapping so we can actually add
+    # extra resources to existing nodes. See self._resource_id for the key.
+    resource_mappings = {}
+    for full_path, res in desired_tree.items():
+        logger.info("Attaching %s to path %s", res, full_path)
+        last_resource = root_resource
+        for path_seg in full_path.split('/')[1:-1]:
+            if path_seg not in last_resource.listNames():
+                # resource doesn't exist, so make a "dummy resource"
+                child_resource = Resource()
+                last_resource.putChild(path_seg, child_resource)
+                res_id = _resource_id(last_resource, path_seg)
+                resource_mappings[res_id] = child_resource
+                last_resource = child_resource
+            else:
+                # we have an existing Resource, use that instead.
+                res_id = _resource_id(last_resource, path_seg)
+                last_resource = resource_mappings[res_id]
+
+        # ===========================
+        # now attach the actual desired resource
+        last_path_seg = full_path.split('/')[-1]
+
+        # if there is already a resource here, thieve its children and
+        # replace it
+        res_id = _resource_id(last_resource, last_path_seg)
+        if res_id in resource_mappings:
+            # there is a dummy resource at this path already, which needs
+            # to be replaced with the desired resource.
+            existing_dummy_resource = resource_mappings[res_id]
+            for child_name in existing_dummy_resource.listNames():
+                child_res_id = _resource_id(
+                    existing_dummy_resource, child_name
+                )
+                child_resource = resource_mappings[child_res_id]
+                # steal the children
+                res.putChild(child_name, child_resource)
+
+        # finally, insert the desired resource in the right place
+        last_resource.putChild(last_path_seg, res)
+        res_id = _resource_id(last_resource, last_path_seg)
+        resource_mappings[res_id] = res
+
+    return root_resource
+
+
+def _resource_id(resource, path_seg):
+    """Construct an arbitrary resource ID so you can retrieve the mapping
+    later.
+
+    If you want to represent resource A putChild resource B with path C,
+    the mapping should look like _resource_id(A,C) = B.
+
+    Args:
+        resource (Resource): The *parent* Resource.
+        path_seg (str): The name of the child Resource to be attached.
+    Returns:
+        str: A unique string which can be a key to the child Resource.
+    """
+    return "%s-%s" % (resource, path_seg)
+
+
 def run(hs):
     PROFILE_SYNAPSE = False
     if PROFILE_SYNAPSE:
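
The resource-tree helper restored above is easiest to see with a small example; the paths here are invented and the snippet is an illustration, not part of the diff:

# Illustrative only: two resources sharing the "_matrix" path segment.
# create_resource_tree makes a dummy Resource for the shared segment and
# hangs both children off it; if a real resource later lands on a dummy's
# path, the dummy's children are stolen across to it.
from twisted.web.resource import Resource

tree = {
    "/_matrix/client": Resource(),
    "/_matrix/federation": Resource(),
}
root = create_resource_tree(tree, redirect_root_to_web_client=False)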
@@ -385,8 +679,6 @@ def run(hs):
     start_time = hs.get_clock().time()

-    stats = {}
-
     @defer.inlineCallbacks
     def phone_stats_home():
         logger.info("Gathering stats for reporting")
@@ -395,10 +687,7 @@ def run(hs):
         if uptime < 0:
             uptime = 0

-        # If the stats directory is empty then this is the first time we've
-        # reported stats.
-        first_time = not stats
-
+        stats = {}
         stats["homeserver"] = hs.config.server_name
         stats["timestamp"] = now
         stats["uptime_seconds"] = uptime
@@ -411,25 +700,6 @@ def run(hs):
         daily_messages = yield hs.get_datastore().count_daily_messages()
         if daily_messages is not None:
             stats["daily_messages"] = daily_messages
-        else:
-            stats.pop("daily_messages", None)
-
-        if first_time:
-            # Add callbacks to report the synapse stats as metrics whenever
-            # prometheus requests them, typically every 30s.
-            # As some of the stats are expensive to calculate we only update
-            # them when synapse phones home to matrix.org every 24 hours.
-            metrics = get_metrics_for("synapse.usage")
-            metrics.add_callback("timestamp", lambda: stats["timestamp"])
-            metrics.add_callback("uptime_seconds", lambda: stats["uptime_seconds"])
-            metrics.add_callback("total_users", lambda: stats["total_users"])
-            metrics.add_callback("total_room_count", lambda: stats["total_room_count"])
-            metrics.add_callback(
-                "daily_active_users", lambda: stats["daily_active_users"]
-            )
-            metrics.add_callback(
-                "daily_messages", lambda: stats.get("daily_messages", 0)
-            )

         logger.info("Reporting stats to matrix.org: %s" % (stats,))
         try:
@@ -450,8 +720,6 @@ def run(hs):
     # sys.settrace(logcontext_tracer)
     with LoggingContext("run"):
         change_resource_limit(hs.config.soft_file_limit)
-        if hs.config.gc_thresholds:
-            gc.set_threshold(*hs.config.gc_thresholds)
         reactor.run()

     if hs.config.daemonize:


@@ -1,213 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.http.site import SynapseSite
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore
from synapse.storage.engines import create_engine
from synapse.storage.media_repository import MediaRepositoryStore
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from synapse.api.urls import (
CONTENT_REPO_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX
)
from synapse.crypto import context_factory
from twisted.internet import reactor, defer
from twisted.web.resource import Resource
from daemonize import Daemonize
import sys
import logging
import gc
logger = logging.getLogger("synapse.app.media_repository")
class MediaRepositorySlavedStore(
SlavedApplicationServiceStore,
SlavedRegistrationStore,
BaseSlavedStore,
MediaRepositoryStore,
ClientIpStore,
):
pass
class MediaRepositoryServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
elif name == "media":
media_repo = MediaRepositoryResource(self)
resources.update({
MEDIA_PREFIX: media_repo,
LEGACY_MEDIA_PREFIX: media_repo,
CONTENT_REPO_PREFIX: ContentRepoResource(
self, self.config.uploads_path
),
})
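                    # (MEDIA_PREFIX and LEGACY_MEDIA_PREFIX are the current
                    # and legacy media API roots in synapse.api.urls.)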
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
logger.info("Synapse media repository now listening on port %d", port)
def start_listening(self, listeners):
for listener in listeners:
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
while True:
try:
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
yield store.process_replication(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(5)
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse media repository", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.media_repository"
setup_logging(config.worker_log_config, config.worker_log_file)
database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config)
ss = MediaRepositoryServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ss.setup()
ss.get_handlers()
ss.start_listening(config.worker_listeners)
def run():
with LoggingContext("run"):
logger.info("Running")
change_resource_limit(config.soft_file_limit)
if config.gc_thresholds:
gc.set_threshold(*config.gc_thresholds)
reactor.run()
def start():
ss.get_state_handler().start_caching()
ss.get_datastore().start_profiling()
ss.replicate()
reactor.callWhenRunning(start)
if config.worker_daemonize:
daemon = Daemonize(
app="synapse-media-repository",
pid=config.worker_pid_file,
action=run,
auto_close_fds=False,
verbose=True,
logger=logger,
)
daemon.start()
else:
run()
if __name__ == '__main__':
with LoggingContext("main"):
start(sys.argv[1:])


@@ -1,299 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse
from synapse.server import HomeServer
from synapse.config._base import ConfigError
from synapse.config.logger import setup_logging
from synapse.config.homeserver import HomeServerConfig
from synapse.http.site import SynapseSite
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.storage.roommember import RoomMemberStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.pushers import SlavedPusherStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.storage.engines import create_engine
from synapse.storage import DataStore
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext, preserve_fn
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from twisted.internet import reactor, defer
from twisted.web.resource import Resource
from daemonize import Daemonize
import sys
import logging
import gc
logger = logging.getLogger("synapse.app.pusher")
class PusherSlaveStore(
SlavedEventStore, SlavedPusherStore, SlavedReceiptsStore,
SlavedAccountDataStore
):
update_pusher_last_stream_ordering_and_success = (
DataStore.update_pusher_last_stream_ordering_and_success.__func__
)
update_pusher_failing_since = (
DataStore.update_pusher_failing_since.__func__
)
update_pusher_last_stream_ordering = (
DataStore.update_pusher_last_stream_ordering.__func__
)
get_throttle_params_by_room = (
DataStore.get_throttle_params_by_room.__func__
)
set_throttle_params = (
DataStore.set_throttle_params.__func__
)
get_time_of_last_push_action_before = (
DataStore.get_time_of_last_push_action_before.__func__
)
get_profile_displayname = (
DataStore.get_profile_displayname.__func__
)
who_forgot_in_room = (
RoomMemberStore.__dict__["who_forgot_in_room"]
)
class PusherServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = PusherSlaveStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def remove_pusher(self, app_id, push_key, user_id):
http_client = self.get_simple_http_client()
replication_url = self.config.worker_replication_url
url = replication_url + "/remove_pushers"
return http_client.post_json_get_json(url, {
"remove": [{
"app_id": app_id,
"push_key": push_key,
"user_id": user_id,
}]
})
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
logger.info("Synapse pusher now listening on port %d", port)
def start_listening(self, listeners):
for listener in listeners:
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
pusher_pool = self.get_pusherpool()
def stop_pusher(user_id, app_id, pushkey):
key = "%s:%s" % (app_id, pushkey)
pushers_for_user = pusher_pool.pushers.get(user_id, {})
pusher = pushers_for_user.pop(key, None)
if pusher is None:
return
logger.info("Stopping pusher %r / %r", user_id, key)
pusher.on_stop()
def start_pusher(user_id, app_id, pushkey):
key = "%s:%s" % (app_id, pushkey)
logger.info("Starting pusher %r / %r", user_id, key)
return pusher_pool._refresh_pusher(app_id, pushkey, user_id)
@defer.inlineCallbacks
def poke_pushers(results):
pushers_rows = set(
map(tuple, results.get("pushers", {}).get("rows", []))
)
deleted_pushers_rows = set(
map(tuple, results.get("deleted_pushers", {}).get("rows", []))
)
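            # Rows are bare tuples from the replication streams, indexed
            # positionally below: deleted_pushers rows carry
            # (user_id, app_id, pushkey) at indices 1-3, while pushers rows
            # carry user_id, app_id and pushkey at indices 1, 5 and 8.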
for row in sorted(pushers_rows | deleted_pushers_rows):
if row in deleted_pushers_rows:
user_id, app_id, pushkey = row[1:4]
stop_pusher(user_id, app_id, pushkey)
elif row in pushers_rows:
user_id = row[1]
app_id = row[5]
pushkey = row[8]
yield start_pusher(user_id, app_id, pushkey)
stream = results.get("events")
if stream:
min_stream_id = stream["rows"][0][0]
max_stream_id = stream["position"]
preserve_fn(pusher_pool.on_new_notifications)(
min_stream_id, max_stream_id
)
stream = results.get("receipts")
if stream:
rows = stream["rows"]
affected_room_ids = set(row[1] for row in rows)
min_stream_id = rows[0][0]
max_stream_id = stream["position"]
preserve_fn(pusher_pool.on_new_receipts)(
min_stream_id, max_stream_id, affected_room_ids
)
while True:
try:
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
yield store.process_replication(result)
poke_pushers(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(30)
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse pusher", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.pusher"
setup_logging(config.worker_log_config, config.worker_log_file)
if config.start_pushers:
sys.stderr.write(
"\nThe pushers must be disabled in the main synapse process"
"\nbefore they can be run in a separate worker."
"\nPlease add ``start_pushers: false`` to the main config"
"\n"
)
sys.exit(1)
# Force the pushers to start since they will be disabled in the main config
config.start_pushers = True
database_engine = create_engine(config.database_config)
ps = PusherServer(
config.server_name,
db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ps.setup()
ps.start_listening(config.worker_listeners)
def run():
with LoggingContext("run"):
logger.info("Running")
change_resource_limit(config.soft_file_limit)
if config.gc_thresholds:
gc.set_threshold(*config.gc_thresholds)
reactor.run()
def start():
ps.replicate()
ps.get_pusherpool().start()
ps.get_datastore().start_profiling()
ps.get_state_handler().start_caching()
reactor.callWhenRunning(start)
if config.worker_daemonize:
daemon = Daemonize(
app="synapse-pusher",
pid=config.worker_pid_file,
action=run,
auto_close_fds=False,
verbose=True,
logger=logger,
)
daemon.start()
else:
run()
if __name__ == '__main__':
with LoggingContext("main"):
ps = start(sys.argv[1:])


@@ -1,494 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse
from synapse.api.constants import EventTypes, PresenceState
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.events import FrozenEvent
from synapse.handlers.presence import PresenceHandler
from synapse.http.site import SynapseSite
from synapse.http.server import JsonResource
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.rest.client.v2_alpha import sync
from synapse.rest.client.v1 import events
from synapse.rest.client.v1.room import RoomInitialSyncRestServlet
from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.filtering import SlavedFilteringStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
from synapse.replication.slave.storage.presence import SlavedPresenceStore
from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
from synapse.replication.slave.storage.room import RoomStore
from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore
from synapse.storage.engines import create_engine
from synapse.storage.presence import PresenceStore, UserPresenceState
from synapse.storage.roommember import RoomMemberStore
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext, preserve_fn
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.stringutils import random_string
from synapse.util.versionstring import get_version_string
from twisted.internet import reactor, defer
from twisted.web.resource import Resource
from daemonize import Daemonize
import sys
import logging
import contextlib
import gc
import ujson as json
logger = logging.getLogger("synapse.app.synchrotron")
class SynchrotronSlavedStore(
SlavedPushRuleStore,
SlavedEventStore,
SlavedReceiptsStore,
SlavedAccountDataStore,
SlavedApplicationServiceStore,
SlavedRegistrationStore,
SlavedFilteringStore,
SlavedPresenceStore,
SlavedDeviceInboxStore,
RoomStore,
BaseSlavedStore,
ClientIpStore, # After BaseSlavedStore because the constructor is different
):
who_forgot_in_room = (
RoomMemberStore.__dict__["who_forgot_in_room"]
)
# XXX: This is a bit broken because we don't persist the accepted list in a
# way that can be replicated. This means that we don't have a way to
# invalidate the cache correctly.
get_presence_list_accepted = PresenceStore.__dict__[
"get_presence_list_accepted"
]
get_presence_list_observers_accepted = PresenceStore.__dict__[
"get_presence_list_observers_accepted"
]
UPDATE_SYNCING_USERS_MS = 10 * 1000
class SynchrotronPresence(object):
def __init__(self, hs):
self.is_mine_id = hs.is_mine_id
self.http_client = hs.get_simple_http_client()
self.store = hs.get_datastore()
self.user_to_num_current_syncs = {}
self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users"
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
active_presence = self.store.take_presence_startup_info()
self.user_to_current_state = {
state.user_id: state
for state in active_presence
}
self.process_id = random_string(16)
logger.info("Presence process_id is %r", self.process_id)
self._sending_sync = False
self._need_to_send_sync = False
self.clock.looping_call(
self._send_syncing_users_regularly,
UPDATE_SYNCING_USERS_MS,
)
reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown)
def set_state(self, user, state, ignore_status_msg=False):
# TODO: How's this supposed to work?
pass
get_states = PresenceHandler.get_states.__func__
get_state = PresenceHandler.get_state.__func__
_get_interested_parties = PresenceHandler._get_interested_parties.__func__
current_state_for_users = PresenceHandler.current_state_for_users.__func__
@defer.inlineCallbacks
def user_syncing(self, user_id, affect_presence):
if affect_presence:
curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
self.user_to_num_current_syncs[user_id] = curr_sync + 1
prev_states = yield self.current_state_for_users([user_id])
if prev_states[user_id].state == PresenceState.OFFLINE:
# TODO: Don't block the sync request on this HTTP hit.
yield self._send_syncing_users_now()
def _end():
# We check that the user_id is in user_to_num_current_syncs because
# user_to_num_current_syncs may have been cleared if we are
# shutting down.
if affect_presence and user_id in self.user_to_num_current_syncs:
self.user_to_num_current_syncs[user_id] -= 1
@contextlib.contextmanager
def _user_syncing():
try:
yield
finally:
_end()
defer.returnValue(_user_syncing())
@defer.inlineCallbacks
def _on_shutdown(self):
# When the synchrotron is shutdown tell the master to clear the in
# progress syncs for this process
self.user_to_num_current_syncs.clear()
yield self._send_syncing_users_now()
def _send_syncing_users_regularly(self):
# Only send an update if we aren't in the middle of sending one.
if not self._sending_sync:
preserve_fn(self._send_syncing_users_now)()
@defer.inlineCallbacks
def _send_syncing_users_now(self):
if self._sending_sync:
# We don't want to race with sending another update.
# Instead we wait for that update to finish and send another
# update afterwards.
self._need_to_send_sync = True
return
# Flag that we are sending an update.
self._sending_sync = True
yield self.http_client.post_json_get_json(self.syncing_users_url, {
"process_id": self.process_id,
"syncing_users": [
user_id for user_id, count in self.user_to_num_current_syncs.items()
if count > 0
],
})
# Unset the flag as we are no longer sending an update.
self._sending_sync = False
if self._need_to_send_sync:
# If something happened while we were sending the update then
# we might need to send another update.
# TODO: Check if the update that was sent matches the current state
# as we only need to send an update if they are different.
self._need_to_send_sync = False
yield self._send_syncing_users_now()
@defer.inlineCallbacks
def notify_from_replication(self, states, stream_id):
parties = yield self._get_interested_parties(
states, calculate_remote_hosts=False
)
room_ids_to_states, users_to_states, _ = parties
self.notifier.on_new_event(
"presence_key", stream_id, rooms=room_ids_to_states.keys(),
users=users_to_states.keys()
)
@defer.inlineCallbacks
def process_replication(self, result):
stream = result.get("presence", {"rows": []})
states = []
for row in stream["rows"]:
(
position, user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts, status_msg,
currently_active
) = row
state = UserPresenceState(
user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts, status_msg,
currently_active
)
self.user_to_current_state[user_id] = state
states.append(state)
if states and "position" in stream:
stream_id = int(stream["position"])
yield self.notify_from_replication(states, stream_id)
class SynchrotronTyping(object):
def __init__(self, hs):
self._latest_room_serial = 0
self._room_serials = {}
self._room_typing = {}
def stream_positions(self):
# We must update this typing token from the response of the previous
# sync. In particular, the stream id may "reset" back to zero/a low
# value which we *must* use for the next replication request.
return {"typing": self._latest_room_serial}
def process_replication(self, result):
stream = result.get("typing")
if stream:
self._latest_room_serial = int(stream["position"])
for row in stream["rows"]:
position, room_id, typing_json = row
typing = json.loads(typing_json)
self._room_serials[room_id] = position
self._room_typing[room_id] = typing
class SynchrotronApplicationService(object):
def notify_interested_services(self, event):
pass
class SynchrotronServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
    def setup(self):
        logger.info("Setting up.")
        self.datastore = SynchrotronSlavedStore(self.get_db_conn(), self)
        logger.info("Finished setting up.")

    def _listen_http(self, listener_config):
        port = listener_config["port"]
        bind_address = listener_config.get("bind_address", "")
        site_tag = listener_config.get("tag", port)
        resources = {}
        for res in listener_config["resources"]:
            for name in res["names"]:
                if name == "metrics":
                    resources[METRICS_PREFIX] = MetricsResource(self)
                elif name == "client":
                    resource = JsonResource(self, canonical_json=False)
                    sync.register_servlets(self, resource)
                    events.register_servlets(self, resource)
                    InitialSyncRestServlet(self).register(resource)
                    RoomInitialSyncRestServlet(self).register(resource)
                    resources.update({
                        "/_matrix/client/r0": resource,
                        "/_matrix/client/unstable": resource,
                        "/_matrix/client/v2_alpha": resource,
                        "/_matrix/client/api/v1": resource,
                    })

        root_resource = create_resource_tree(resources, Resource())
        reactor.listenTCP(
            port,
            SynapseSite(
                "synapse.access.http.%s" % (site_tag,),
                site_tag,
                listener_config,
                root_resource,
            ),
            interface=bind_address
        )
        logger.info("Synapse synchrotron now listening on port %d", port)

    def start_listening(self, listeners):
        for listener in listeners:
            if listener["type"] == "http":
                self._listen_http(listener)
            elif listener["type"] == "manhole":
                reactor.listenTCP(
                    listener["port"],
                    manhole(
                        username="matrix",
                        password="rabbithole",
                        globals={"hs": self},
                    ),
                    interface=listener.get("bind_address", '127.0.0.1')
                )
            else:
                logger.warn("Unrecognized listener type: %s", listener["type"])

    @defer.inlineCallbacks
    def replicate(self):
        http_client = self.get_simple_http_client()
        store = self.get_datastore()
        replication_url = self.config.worker_replication_url
        notifier = self.get_notifier()
        presence_handler = self.get_presence_handler()
        typing_handler = self.get_typing_handler()

        def notify_from_stream(
            result, stream_name, stream_key, room=None, user=None
        ):
            stream = result.get(stream_name)
            if stream:
                position_index = stream["field_names"].index("position")
                if room:
                    room_index = stream["field_names"].index(room)
                if user:
                    user_index = stream["field_names"].index(user)

                users = ()
                rooms = ()
                for row in stream["rows"]:
                    position = row[position_index]

                    if user:
                        users = (row[user_index],)

                    if room:
                        rooms = (row[room_index],)

                    notifier.on_new_event(
                        stream_key, position, users=users, rooms=rooms
                    )

        def notify(result):
            stream = result.get("events")
            if stream:
                max_position = stream["position"]
                for row in stream["rows"]:
                    position = row[0]
                    internal = json.loads(row[1])
                    event_json = json.loads(row[2])
                    event = FrozenEvent(event_json, internal_metadata_dict=internal)
                    extra_users = ()
                    if event.type == EventTypes.Member:
                        extra_users = (event.state_key,)
                    notifier.on_new_room_event(
                        event, position, max_position, extra_users
                    )

            notify_from_stream(
                result, "push_rules", "push_rules_key", user="user_id"
            )
            notify_from_stream(
                result, "user_account_data", "account_data_key", user="user_id"
            )
            notify_from_stream(
                result, "room_account_data", "account_data_key", user="user_id"
            )
            notify_from_stream(
                result, "tag_account_data", "account_data_key", user="user_id"
            )
            notify_from_stream(
                result, "receipts", "receipt_key", room="room_id"
            )
            notify_from_stream(
                result, "typing", "typing_key", room="room_id"
            )
            notify_from_stream(
                result, "to_device", "to_device_key", user="user_id"
            )

        while True:
            try:
                args = store.stream_positions()
                args.update(typing_handler.stream_positions())
                args["timeout"] = 30000
                result = yield http_client.get_json(replication_url, args=args)
                yield store.process_replication(result)
                typing_handler.process_replication(result)
                yield presence_handler.process_replication(result)
                notify(result)
            except:
                logger.exception("Error replicating from %r", replication_url)
                yield sleep(5)
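The loop above amounts to a long-poll: every stream's current token is sent as a query parameter, and the master replies when any stream advances or the timeout expires. A request might therefore carry parameters like the following (stream names other than "typing" and all token values are illustrative):

args = {
    "events": "1042",   # from store.stream_positions()
    "typing": 3,        # from typing_handler.stream_positions()
    "timeout": 30000,   # milliseconds to block waiting for new data
}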
    def build_presence_handler(self):
        return SynchrotronPresence(self)

    def build_typing_handler(self):
        return SynchrotronTyping(self)


def start(config_options):
    try:
        config = HomeServerConfig.load_config(
            "Synapse synchrotron", config_options
        )
    except ConfigError as e:
        sys.stderr.write("\n" + e.message + "\n")
        sys.exit(1)

    assert config.worker_app == "synapse.app.synchrotron"

    setup_logging(config.worker_log_config, config.worker_log_file)

    database_engine = create_engine(config.database_config)

    ss = SynchrotronServer(
        config.server_name,
        db_config=config.database_config,
        config=config,
        version_string="Synapse/" + get_version_string(synapse),
        database_engine=database_engine,
        application_service_handler=SynchrotronApplicationService(),
    )

    ss.setup()
    ss.start_listening(config.worker_listeners)

    def run():
        with LoggingContext("run"):
            logger.info("Running")
            change_resource_limit(config.soft_file_limit)
            if config.gc_thresholds:
                gc.set_threshold(*config.gc_thresholds)
            reactor.run()

    def start():
        ss.get_datastore().start_profiling()
        ss.replicate()
        ss.get_state_handler().start_caching()

    reactor.callWhenRunning(start)

    if config.worker_daemonize:
        daemon = Daemonize(
            app="synapse-synchrotron",
            pid=config.worker_pid_file,
            action=run,
            auto_close_fds=False,
            verbose=True,
            logger=logger,
        )
        daemon.start()
    else:
        run()


if __name__ == '__main__':
    with LoggingContext("main"):
        start(sys.argv[1:])

View File

@@ -14,14 +14,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import argparse
-import collections
-import glob
+import sys
 import os
 import os.path
-import signal
 import subprocess
-import sys
+import signal
 import yaml

 SYNAPSE = ["python", "-B", "-m", "synapse.app.homeserver"]
@@ -31,181 +28,56 @@ RED = "\x1b[1;31m"
 NORMAL = "\x1b[m"

-def write(message, colour=NORMAL, stream=sys.stdout):
-    if colour == NORMAL:
-        stream.write(message + "\n")
-    else:
-        stream.write(colour + message + NORMAL + "\n")

 def start(configfile):
-    write("Starting ...")
+    print ("Starting ...")
     args = SYNAPSE
     args.extend(["--daemonize", "-c", configfile])
     try:
         subprocess.check_call(args)
-        write("started synapse.app.homeserver(%r)" % (configfile,), colour=GREEN)
+        print (GREEN + "started" + NORMAL)
     except subprocess.CalledProcessError as e:
-        write(
-            "error starting (exit code: %d); see above for logs" % e.returncode,
-            colour=RED,
-        )
+        print (
+            RED +
+            "error starting (exit code: %d); see above for logs" % e.returncode +
+            NORMAL
+        )

-def start_worker(app, configfile, worker_configfile):
-    args = [
-        "python", "-B",
-        "-m", app,
-        "-c", configfile,
-        "-c", worker_configfile
-    ]
-    try:
-        subprocess.check_call(args)
-        write("started %s(%r)" % (app, worker_configfile), colour=GREEN)
-    except subprocess.CalledProcessError as e:
-        write(
-            "error starting %s(%r) (exit code: %d); see above for logs" % (
-                app, worker_configfile, e.returncode,
-            ),
-            colour=RED,
-        )

-def stop(pidfile, app):
+def stop(pidfile):
     if os.path.exists(pidfile):
         pid = int(open(pidfile).read())
         os.kill(pid, signal.SIGTERM)
-        write("stopped %s" % (app,), colour=GREEN)
+        print (GREEN + "stopped" + NORMAL)

-Worker = collections.namedtuple("Worker", [
-    "app", "configfile", "pidfile", "cache_factor"
-])

 def main():
+    configfile = sys.argv[2] if len(sys.argv) == 3 else "homeserver.yaml"
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "action",
-        choices=["start", "stop", "restart"],
-        help="whether to start, stop or restart the synapse",
-    )
-    parser.add_argument(
-        "configfile",
-        nargs="?",
-        default="homeserver.yaml",
-        help="the homeserver config file, defaults to homserver.yaml",
-    )
-    parser.add_argument(
-        "-w", "--worker",
-        metavar="WORKERCONFIG",
-        help="start or stop a single worker",
-    )
-    parser.add_argument(
-        "-a", "--all-processes",
-        metavar="WORKERCONFIGDIR",
-        help="start or stop all the workers in the given directory"
-             " and the main synapse process",
-    )
-    options = parser.parse_args()
-    if options.worker and options.all_processes:
-        write(
-            'Cannot use "--worker" with "--all-processes"',
-            stream=sys.stderr
-        )
-        sys.exit(1)
-    configfile = options.configfile
     if not os.path.exists(configfile):
-        write(
+        sys.stderr.write(
             "No config file found\n"
             "To generate a config file, run '%s -c %s --generate-config"
             " --server-name=<server name>'\n" % (
-                " ".join(SYNAPSE), options.configfile
-            ),
-            stream=sys.stderr,
+                " ".join(SYNAPSE), configfile
+            )
         )
         sys.exit(1)
-    with open(configfile) as stream:
-        config = yaml.load(stream)
+    config = yaml.load(open(configfile))
     pidfile = config["pid_file"]
-    cache_factor = config.get("synctl_cache_factor")
-    start_stop_synapse = True
-    if cache_factor:
-        os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor)
-    worker_configfiles = []
-    if options.worker:
-        start_stop_synapse = False
-        worker_configfile = options.worker
-        if not os.path.exists(worker_configfile):
-            write(
-                "No worker config found at %r" % (worker_configfile,),
-                stream=sys.stderr,
-            )
-            sys.exit(1)
-        worker_configfiles.append(worker_configfile)
-    if options.all_processes:
-        worker_configdir = options.all_processes
-        if not os.path.isdir(worker_configdir):
-            write(
-                "No worker config directory found at %r" % (worker_configdir,),
-                stream=sys.stderr,
-            )
-            sys.exit(1)
-        worker_configfiles.extend(sorted(glob.glob(
-            os.path.join(worker_configdir, "*.yaml")
-        )))
-    workers = []
-    for worker_configfile in worker_configfiles:
-        with open(worker_configfile) as stream:
-            worker_config = yaml.load(stream)
-        worker_app = worker_config["worker_app"]
-        worker_pidfile = worker_config["worker_pid_file"]
-        worker_daemonize = worker_config["worker_daemonize"]
-        assert worker_daemonize  # TODO print something more user friendly
-        worker_cache_factor = worker_config.get("synctl_cache_factor")
-        workers.append(Worker(
-            worker_app, worker_configfile, worker_pidfile, worker_cache_factor,
-        ))
-    action = options.action
-    if action == "stop" or action == "restart":
-        for worker in workers:
-            stop(worker.pidfile, worker.app)
-        if start_stop_synapse:
-            stop(pidfile, "synapse.app.homeserver")
-    # TODO: Wait for synapse to actually shutdown before starting it again
-    if action == "start" or action == "restart":
-        if start_stop_synapse:
-            start(configfile)
-        for worker in workers:
-            if worker.cache_factor:
-                os.environ["SYNAPSE_CACHE_FACTOR"] = str(worker.cache_factor)
-            start_worker(worker.app, configfile, worker.configfile)
-            if cache_factor:
-                os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor)
-            else:
-                os.environ.pop("SYNAPSE_CACHE_FACTOR", None)
+    action = sys.argv[1] if sys.argv[1:] else "usage"
+    if action == "start":
+        start(configfile)
+    elif action == "stop":
+        stop(pidfile)
+    elif action == "restart":
+        stop(pidfile)
+        start(configfile)
+    else:
+        sys.stderr.write("Usage: %s [start|stop|restart] [configfile]\n" % (sys.argv[0],))
+        sys.exit(1)

 if __name__ == "__main__":
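For reference, the base-side synctl drives each worker from a small YAML file; parsed into Python, a config satisfying the keys it reads (and the worker_daemonize assertion) could look like this (all values illustrative):

worker_config = {
    "worker_app": "synapse.app.synchrotron",
    "worker_pid_file": "/var/run/matrix-synchrotron.pid",
    "worker_daemonize": True,     # synctl asserts this is truthy
    "synctl_cache_factor": 0.5,   # optional; exported as SYNAPSE_CACHE_FACTOR
}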

View File

@@ -14,8 +14,6 @@
 # limitations under the License.
 from synapse.api.constants import EventTypes
-from twisted.internet import defer

 import logging
 import re
@@ -81,7 +79,7 @@ class ApplicationService(object):
     NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]

     def __init__(self, token, url=None, namespaces=None, hs_token=None,
-                 sender=None, id=None, protocols=None):
+                 sender=None, id=None):
         self.token = token
         self.url = url
         self.hs_token = hs_token
@@ -89,12 +87,6 @@ class ApplicationService(object):
         self.namespaces = self._check_namespaces(namespaces)
         self.id = id

-        # .protocols is a publicly visible field
-        if protocols:
-            self.protocols = set(protocols)
-        else:
-            self.protocols = set()

     def _check_namespaces(self, namespaces):
         # Sanity check that it is of the form:
         # {
@@ -146,66 +138,65 @@ class ApplicationService(object):
                 return regex_obj["exclusive"]
         return False

-    @defer.inlineCallbacks
-    def _matches_user(self, event, store):
-        if not event:
-            defer.returnValue(False)
-        if self.is_interested_in_user(event.sender):
-            defer.returnValue(True)
+    def _matches_user(self, event, member_list):
+        if (hasattr(event, "sender") and
+                self.is_interested_in_user(event.sender)):
+            return True
         # also check m.room.member state key
-        if (event.type == EventTypes.Member and
-                self.is_interested_in_user(event.state_key)):
-            defer.returnValue(True)
-        if not store:
-            defer.returnValue(False)
-        member_list = yield store.get_users_in_room(event.room_id)
+        if (hasattr(event, "type") and event.type == EventTypes.Member
+                and hasattr(event, "state_key")
+                and self.is_interested_in_user(event.state_key)):
+            return True
         # check joined member events
         for user_id in member_list:
             if self.is_interested_in_user(user_id):
-                defer.returnValue(True)
-        defer.returnValue(False)
+                return True
+        return False

     def _matches_room_id(self, event):
         if hasattr(event, "room_id"):
             return self.is_interested_in_room(event.room_id)
         return False

-    @defer.inlineCallbacks
-    def _matches_aliases(self, event, store):
-        if not store or not event:
-            defer.returnValue(False)
-        alias_list = yield store.get_aliases_for_room(event.room_id)
+    def _matches_aliases(self, event, alias_list):
         for alias in alias_list:
             if self.is_interested_in_alias(alias):
-                defer.returnValue(True)
-        defer.returnValue(False)
+                return True
+        return False

-    @defer.inlineCallbacks
-    def is_interested(self, event, store=None):
+    def is_interested(self, event, restrict_to=None, aliases_for_event=None,
+                      member_list=None):
         """Check if this service is interested in this event.

         Args:
             event(Event): The event to check.
-            store(DataStore)
+            restrict_to(str): The namespace to restrict regex tests to.
+            aliases_for_event(list): A list of all the known room aliases for
+                this event.
+            member_list(list): A list of all joined user_ids in this room.
         Returns:
             bool: True if this service would like to know about this event.
         """
-        # Do cheap checks first
-        if self._matches_room_id(event):
-            defer.returnValue(True)
-        if (yield self._matches_aliases(event, store)):
-            defer.returnValue(True)
-        if (yield self._matches_user(event, store)):
-            defer.returnValue(True)
-        defer.returnValue(False)
+        if aliases_for_event is None:
+            aliases_for_event = []
+        if member_list is None:
+            member_list = []
+
+        if restrict_to and restrict_to not in ApplicationService.NS_LIST:
+            # this is a programming error, so fail early and raise a general
+            # exception
+            raise Exception("Unexpected restrict_to value: %s". restrict_to)
+
+        if not restrict_to:
+            return (self._matches_user(event, member_list)
+                    or self._matches_aliases(event, aliases_for_event)
+                    or self._matches_room_id(event))
+        elif restrict_to == ApplicationService.NS_ALIASES:
+            return self._matches_aliases(event, aliases_for_event)
+        elif restrict_to == ApplicationService.NS_ROOMS:
+            return self._matches_room_id(event)
+        elif restrict_to == ApplicationService.NS_USERS:
+            return self._matches_user(event, member_list)

     def is_interested_in_user(self, user_id):
         return (
@@ -225,9 +216,6 @@ class ApplicationService(object):
             or user_id == self.sender
         )

-    def is_interested_in_protocol(self, protocol):
-        return protocol in self.protocols

     def is_exclusive_alias(self, alias):
         return self._is_exclusive(ApplicationService.NS_ALIASES, alias)
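Both variants of the matchers walk the same namespaces structure that _check_namespaces validates: a dict keyed by the NS_* constants, each holding a list of {"regex", "exclusive"} objects. A minimal example (pattern invented):

namespaces = {
    "users": [{"regex": "@irc_.*:example\\.com", "exclusive": True}],
    "aliases": [],
    "rooms": [],
}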

View File

@@ -14,11 +14,9 @@
 # limitations under the License.
 from twisted.internet import defer

-from synapse.api.constants import ThirdPartyEntityKind
 from synapse.api.errors import CodeMessageException
 from synapse.http.client import SimpleHttpClient
 from synapse.events.utils import serialize_event
-from synapse.util.caches.response_cache import ResponseCache

 import logging
 import urllib
@@ -26,42 +24,6 @@ import urllib
 logger = logging.getLogger(__name__)

-HOUR_IN_MS = 60 * 60 * 1000

-APP_SERVICE_PREFIX = "/_matrix/app/unstable"

-def _is_valid_3pe_metadata(info):
-    if "instances" not in info:
-        return False
-    if not isinstance(info["instances"], list):
-        return False
-    return True

-def _is_valid_3pe_result(r, field):
-    if not isinstance(r, dict):
-        return False
-    for k in (field, "protocol"):
-        if k not in r:
-            return False
-        if not isinstance(r[k], str):
-            return False
-    if "fields" not in r:
-        return False
-    fields = r["fields"]
-    if not isinstance(fields, dict):
-        return False
-    for k in fields.keys():
-        if not isinstance(fields[k], str):
-            return False
-    return True

 class ApplicationServiceApi(SimpleHttpClient):
     """This class manages HS -> AS communications, including querying and
     pushing.
@@ -71,12 +33,8 @@ class ApplicationServiceApi(SimpleHttpClient):
         super(ApplicationServiceApi, self).__init__(hs)
         self.clock = hs.get_clock()
-        self.protocol_meta_cache = ResponseCache(hs, timeout_ms=HOUR_IN_MS)

     @defer.inlineCallbacks
     def query_user(self, service, user_id):
-        if service.url is None:
-            defer.returnValue(False)
         uri = service.url + ("/users/%s" % urllib.quote(user_id))
         response = None
         try:
@@ -96,8 +54,6 @@ class ApplicationServiceApi(SimpleHttpClient):
     @defer.inlineCallbacks
     def query_alias(self, service, alias):
-        if service.url is None:
-            defer.returnValue(False)
         uri = service.url + ("/rooms/%s" % urllib.quote(alias))
         response = None
         try:
@@ -115,84 +71,8 @@ class ApplicationServiceApi(SimpleHttpClient):
             logger.warning("query_alias to %s threw exception %s", uri, ex)
             defer.returnValue(False)

-    @defer.inlineCallbacks
-    def query_3pe(self, service, kind, protocol, fields):
-        if kind == ThirdPartyEntityKind.USER:
-            required_field = "userid"
-        elif kind == ThirdPartyEntityKind.LOCATION:
-            required_field = "alias"
-        else:
-            raise ValueError(
-                "Unrecognised 'kind' argument %r to query_3pe()", kind
-            )
-        if service.url is None:
-            defer.returnValue([])

-        uri = "%s%s/thirdparty/%s/%s" % (
-            service.url,
-            APP_SERVICE_PREFIX,
-            kind,
-            urllib.quote(protocol)
-        )
-        try:
-            response = yield self.get_json(uri, fields)
-            if not isinstance(response, list):
-                logger.warning(
-                    "query_3pe to %s returned an invalid response %r",
-                    uri, response
-                )
-                defer.returnValue([])

-            ret = []
-            for r in response:
-                if _is_valid_3pe_result(r, field=required_field):
-                    ret.append(r)
-                else:
-                    logger.warning(
-                        "query_3pe to %s returned an invalid result %r",
-                        uri, r
-                    )

-            defer.returnValue(ret)
-        except Exception as ex:
-            logger.warning("query_3pe to %s threw exception %s", uri, ex)
-            defer.returnValue([])

-    def get_3pe_protocol(self, service, protocol):
-        if service.url is None:
-            defer.returnValue({})

-        @defer.inlineCallbacks
-        def _get():
-            uri = "%s%s/thirdparty/protocol/%s" % (
-                service.url,
-                APP_SERVICE_PREFIX,
-                urllib.quote(protocol)
-            )
-            try:
-                info = yield self.get_json(uri, {})
-                if not _is_valid_3pe_metadata(info):
-                    logger.warning("query_3pe_protocol to %s did not return a"
-                                   " valid result", uri)
-                    defer.returnValue(None)

-                defer.returnValue(info)
-            except Exception as ex:
-                logger.warning("query_3pe_protocol to %s threw exception %s",
-                               uri, ex)
-                defer.returnValue(None)

-        key = (service.id, protocol)
-        return self.protocol_meta_cache.get(key) or (
-            self.protocol_meta_cache.set(key, _get())
-        )

     @defer.inlineCallbacks
     def push_bulk(self, service, events, txn_id=None):
-        if service.url is None:
-            defer.returnValue(True)
         events = self._serialize(events)
         if txn_id is None:
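Note the removed get_3pe_protocol also drops a small get-or-set caching idiom: look in the ResponseCache first and only start (and cache) the fetch on a miss, so concurrent callers share one in-flight request. Schematically (a sketch of the idiom, not the Synapse API):

result = cache.get(key)
if result is None:
    # set() stores the in-flight Deferred so concurrent callers share it
    result = cache.set(key, _get())
return result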

View File

@@ -48,35 +48,32 @@ UP & quit +---------- YES SUCCESS
 This is all tied together by the AppServiceScheduler which DIs the required
 components.
 """
-from twisted.internet import defer

 from synapse.appservice import ApplicationServiceState
-from synapse.util.logcontext import preserve_fn
-from synapse.util.metrics import Measure
+from twisted.internet import defer

 import logging

 logger = logging.getLogger(__name__)

-class ApplicationServiceScheduler(object):
+class AppServiceScheduler(object):
     """ Public facing API for this module. Does the required DI to tie the
     components together. This also serves as the "event_pool", which in this
     case is a simple array.
     """

-    def __init__(self, hs):
-        self.clock = hs.get_clock()
-        self.store = hs.get_datastore()
-        self.as_api = hs.get_application_service_api()
+    def __init__(self, clock, store, as_api):
+        self.clock = clock
+        self.store = store
+        self.as_api = as_api

         def create_recoverer(service, callback):
-            return _Recoverer(self.clock, self.store, self.as_api, service, callback)
+            return _Recoverer(clock, store, as_api, service, callback)

         self.txn_ctrl = _TransactionController(
-            self.clock, self.store, self.as_api, create_recoverer
+            clock, store, as_api, create_recoverer
         )
-        self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock)
+        self.queuer = _ServiceQueuer(self.txn_ctrl)

     @defer.inlineCallbacks
     def start(self):
@@ -97,36 +94,38 @@ class _ServiceQueuer(object):
     this schedules any other events in the queue to run.
     """

-    def __init__(self, txn_ctrl, clock):
+    def __init__(self, txn_ctrl):
         self.queued_events = {}  # dict of {service_id: [events]}
-        self.requests_in_flight = set()
+        self.pending_requests = {}  # dict of {service_id: Deferred}
         self.txn_ctrl = txn_ctrl
-        self.clock = clock

     def enqueue(self, service, event):
         # if this service isn't being sent something
-        self.queued_events.setdefault(service.id, []).append(event)
-        preserve_fn(self._send_request)(service)
+        if not self.pending_requests.get(service.id):
+            self._send_request(service, [event])
+        else:
+            # add to queue for this service
+            if service.id not in self.queued_events:
+                self.queued_events[service.id] = []
+            self.queued_events[service.id].append(event)

-    @defer.inlineCallbacks
-    def _send_request(self, service):
-        if service.id in self.requests_in_flight:
-            return
+    def _send_request(self, service, events):
+        # send request and add callbacks
+        d = self.txn_ctrl.send(service, events)
+        d.addBoth(self._on_request_finish)
+        d.addErrback(self._on_request_fail)
+        self.pending_requests[service.id] = d

-        self.requests_in_flight.add(service.id)
-        try:
-            while True:
-                events = self.queued_events.pop(service.id, [])
-                if not events:
-                    return
-                with Measure(self.clock, "servicequeuer.send"):
-                    try:
-                        yield self.txn_ctrl.send(service, events)
-                    except:
-                        logger.exception("AS request failed")
-        finally:
-            self.requests_in_flight.discard(service.id)
+    def _on_request_finish(self, service):
+        self.pending_requests[service.id] = None
+        # if there are queued events, then send them.
+        if (service.id in self.queued_events
+                and len(self.queued_events[service.id]) > 0):
+            self._send_request(service, self.queued_events[service.id])
+            self.queued_events[service.id] = []

+    def _on_request_fail(self, err):
+        logger.error("AS request failed: %s", err)

 class _TransactionController(object):
@@ -150,12 +149,14 @@ class _TransactionController(object):
             if service_is_up:
                 sent = yield txn.send(self.as_api)
                 if sent:
-                    yield txn.complete(self.store)
+                    txn.complete(self.store)
                 else:
-                    preserve_fn(self._start_recoverer)(service)
+                    self._start_recoverer(service)
         except Exception as e:
             logger.exception(e)
-            preserve_fn(self._start_recoverer)(service)
+            self._start_recoverer(service)
+        # request has finished
+        defer.returnValue(service)

     @defer.inlineCallbacks
     def on_recovered(self, recoverer):

View File

@@ -157,40 +157,9 @@ class Config(object):
         return default_config, config

     @classmethod
-    def load_config(cls, description, argv):
-        config_parser = argparse.ArgumentParser(
-            description=description,
-        )
-        config_parser.add_argument(
-            "-c", "--config-path",
-            action="append",
-            metavar="CONFIG_FILE",
-            help="Specify config file. Can be given multiple times and"
-                 " may specify directories containing *.yaml files."
-        )
-        config_parser.add_argument(
-            "--keys-directory",
-            metavar="DIRECTORY",
-            help="Where files such as certs and signing keys are stored when"
-                 " their location is given explicitly in the config."
-                 " Defaults to the directory containing the last config file",
-        )
-        config_args = config_parser.parse_args(argv)
-        config_files = find_config_files(search_paths=config_args.config_path)
+    def load_config(cls, description, argv, generate_section=None):
         obj = cls()
-        obj.read_config_files(
-            config_files,
-            keys_directory=config_args.keys_directory,
-            generate_keys=False,
-        )
-        return obj

-    @classmethod
-    def load_or_generate_config(cls, description, argv):
         config_parser = argparse.ArgumentParser(add_help=False)
         config_parser.add_argument(
             "-c", "--config-path",
@@ -207,7 +176,7 @@ class Config(object):
         config_parser.add_argument(
             "--report-stats",
             action="store",
-            help="Whether the generated config reports anonymized usage statistics",
+            help="Stuff",
             choices=["yes", "no"]
         )
         config_parser.add_argument(
@@ -228,11 +197,36 @@ class Config(object):
         )
         config_args, remaining_args = config_parser.parse_known_args(argv)

-        config_files = find_config_files(search_paths=config_args.config_path)
         generate_keys = config_args.generate_keys

-        obj = cls()
+        config_files = []
+        if config_args.config_path:
+            for config_path in config_args.config_path:
+                if os.path.isdir(config_path):
+                    # We accept specifying directories as config paths, we search
+                    # inside that directory for all files matching *.yaml, and then
+                    # we apply them in *sorted* order.
+                    files = []
+                    for entry in os.listdir(config_path):
+                        entry_path = os.path.join(config_path, entry)
+                        if not os.path.isfile(entry_path):
+                            print (
+                                "Found subdirectory in config directory: %r. IGNORING."
+                            ) % (entry_path, )
+                            continue
+                        if not entry.endswith(".yaml"):
+                            print (
+                                "Found file in config directory that does not"
+                                " end in '.yaml': %r. IGNORING."
+                            ) % (entry_path, )
+                            continue
+                        files.append(entry_path)
+                    config_files.extend(sorted(files))
+                else:
+                    config_files.append(config_path)

         if config_args.generate_config:
             if config_args.report_stats is None:
@@ -305,43 +299,28 @@ class Config(object):
                     " -c CONFIG-FILE\""
                 )

-        obj.read_config_files(
-            config_files,
-            keys_directory=config_args.keys_directory,
-            generate_keys=generate_keys,
-        )
-        if generate_keys:
-            return None
-        obj.invoke_all("read_arguments", args)
-        return obj

-    def read_config_files(self, config_files, keys_directory=None,
-                          generate_keys=False):
-        if not keys_directory:
-            keys_directory = os.path.dirname(config_files[-1])
-        config_dir_path = os.path.abspath(keys_directory)
+        if config_args.keys_directory:
+            config_dir_path = config_args.keys_directory
+        else:
+            config_dir_path = os.path.dirname(config_args.config_path[-1])
+        config_dir_path = os.path.abspath(config_dir_path)

         specified_config = {}
         for config_file in config_files:
-            yaml_config = self.read_config_file(config_file)
+            yaml_config = cls.read_config_file(config_file)
             specified_config.update(yaml_config)

         if "server_name" not in specified_config:
             raise ConfigError(MISSING_SERVER_NAME)

         server_name = specified_config["server_name"]
-        _, config = self.generate_config(
+        _, config = obj.generate_config(
             config_dir_path=config_dir_path,
             server_name=server_name,
             is_generating_file=False,
         )
         config.pop("log_config")
         config.update(specified_config)

         if "report_stats" not in config:
             raise ConfigError(
                 MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" +
@@ -349,51 +328,11 @@ class Config(object):
             )

         if generate_keys:
-            self.invoke_all("generate_files", config)
+            obj.invoke_all("generate_files", config)
             return

-        self.invoke_all("read_config", config)
+        obj.invoke_all("read_config", config)
+        obj.invoke_all("read_arguments", args)
+        return obj

-def find_config_files(search_paths):
-    """Finds config files using a list of search paths. If a path is a file
-    then that file path is added to the list. If a search path is a directory
-    then all the "*.yaml" files in that directory are added to the list in
-    sorted order.

-    Args:
-        search_paths(list(str)): A list of paths to search.

-    Returns:
-        list(str): A list of file paths.
-    """
-    config_files = []
-    if search_paths:
-        for config_path in search_paths:
-            if os.path.isdir(config_path):
-                # We accept specifying directories as config paths, we search
-                # inside that directory for all files matching *.yaml, and then
-                # we apply them in *sorted* order.
-                files = []
-                for entry in os.listdir(config_path):
-                    entry_path = os.path.join(config_path, entry)
-                    if not os.path.isfile(entry_path):
-                        print (
-                            "Found subdirectory in config directory: %r. IGNORING."
-                        ) % (entry_path, )
-                        continue
-                    if not entry.endswith(".yaml"):
-                        print (
-                            "Found file in config directory that does not"
-                            " end in '.yaml': %r. IGNORING."
-                        ) % (entry_path, )
-                        continue
-                    files.append(entry_path)
-                config_files.extend(sorted(files))
-            else:
-                config_files.append(config_path)
-    return config_files
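A usage sketch for the removed helper (paths illustrative): plain file paths pass through unchanged, while a directory contributes its *.yaml files in sorted order.

config_files = find_config_files(["homeserver.yaml", "/etc/synapse/conf.d"])
# -> ["homeserver.yaml",
#     "/etc/synapse/conf.d/01-database.yaml",
#     "/etc/synapse/conf.d/02-tls.yaml"]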

View File

@@ -12,147 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from ._base import Config, ConfigError
+from ._base import Config

-from synapse.appservice import ApplicationService
-from synapse.types import UserID

-import urllib
-import yaml
-import logging

-logger = logging.getLogger(__name__)

 class AppServiceConfig(Config):

     def read_config(self, config):
         self.app_service_config_files = config.get("app_service_config_files", [])
-        self.notify_appservices = config.get("notify_appservices", True)

     def default_config(cls, **kwargs):
         return """\
         # A list of application service config file to use
         app_service_config_files: []
         """

-def load_appservices(hostname, config_files):
-    """Returns a list of Application Services from the config files."""
-    if not isinstance(config_files, list):
-        logger.warning(
-            "Expected %s to be a list of AS config files.", config_files
-        )
-        return []

-    # Dicts of value -> filename
-    seen_as_tokens = {}
-    seen_ids = {}

-    appservices = []

-    for config_file in config_files:
-        try:
-            with open(config_file, 'r') as f:
-                appservice = _load_appservice(
-                    hostname, yaml.load(f), config_file
-                )
-                if appservice.id in seen_ids:
-                    raise ConfigError(
-                        "Cannot reuse ID across application services: "
-                        "%s (files: %s, %s)" % (
-                            appservice.id, config_file, seen_ids[appservice.id],
-                        )
-                    )
-                seen_ids[appservice.id] = config_file
-                if appservice.token in seen_as_tokens:
-                    raise ConfigError(
-                        "Cannot reuse as_token across application services: "
-                        "%s (files: %s, %s)" % (
-                            appservice.token,
-                            config_file,
-                            seen_as_tokens[appservice.token],
-                        )
-                    )
-                seen_as_tokens[appservice.token] = config_file
-                logger.info("Loaded application service: %s", appservice)
-                appservices.append(appservice)
-        except Exception as e:
-            logger.error("Failed to load appservice from '%s'", config_file)
-            logger.exception(e)
-            raise
-    return appservices

-def _load_appservice(hostname, as_info, config_filename):
-    required_string_fields = [
-        "id", "as_token", "hs_token", "sender_localpart"
-    ]
-    for field in required_string_fields:
-        if not isinstance(as_info.get(field), basestring):
-            raise KeyError("Required string field: '%s' (%s)" % (
-                field, config_filename,
-            ))

-    # 'url' must either be a string or explicitly null, not missing
-    # to avoid accidentally turning off push for ASes.
-    if (not isinstance(as_info.get("url"), basestring) and
-            as_info.get("url", "") is not None):
-        raise KeyError(
-            "Required string field or explicit null: 'url' (%s)" % (config_filename,)
-        )

-    localpart = as_info["sender_localpart"]
-    if urllib.quote(localpart) != localpart:
-        raise ValueError(
-            "sender_localpart needs characters which are not URL encoded."
-        )
-    user = UserID(localpart, hostname)
-    user_id = user.to_string()

-    # namespace checks
-    if not isinstance(as_info.get("namespaces"), dict):
-        raise KeyError("Requires 'namespaces' object.")
-    for ns in ApplicationService.NS_LIST:
-        # specific namespaces are optional
-        if ns in as_info["namespaces"]:
-            # expect a list of dicts with exclusive and regex keys
-            for regex_obj in as_info["namespaces"][ns]:
-                if not isinstance(regex_obj, dict):
-                    raise ValueError(
-                        "Expected namespace entry in %s to be an object,"
-                        " but got %s", ns, regex_obj
-                    )
-                if not isinstance(regex_obj.get("regex"), basestring):
-                    raise ValueError(
-                        "Missing/bad type 'regex' key in %s", regex_obj
-                    )
-                if not isinstance(regex_obj.get("exclusive"), bool):
-                    raise ValueError(
-                        "Missing/bad type 'exclusive' key in %s", regex_obj
-                    )

-    # protocols check
-    protocols = as_info.get("protocols")
-    if protocols:
-        # Because strings are lists in python
-        if isinstance(protocols, str) or not isinstance(protocols, list):
-            raise KeyError("Optional 'protocols' must be a list if present.")
-        for p in protocols:
-            if not isinstance(p, str):
-                raise KeyError("Bad value for 'protocols' item")

-    if as_info["url"] is None:
-        logger.info(
-            "(%s) Explicitly empty 'url' provided. This application service"
-            " will not receive events or queries.",
-            config_filename,
-        )
-    return ApplicationService(
-        token=as_info["as_token"],
-        url=as_info["url"],
-        namespaces=as_info["namespaces"],
-        hs_token=as_info["hs_token"],
-        sender=user_id,
-        id=as_info["id"],
-        protocols=protocols,
-    )
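Putting the removed validation together, the smallest as_info mapping that _load_appservice accepts looks roughly like this (every value illustrative; "url" may be a string or an explicit null, but not absent):

as_info = {
    "id": "irc-bridge",
    "as_token": "as_token_value",
    "hs_token": "hs_token_value",
    "sender_localpart": "ircbot",  # must survive urllib.quote unchanged
    "url": None,                   # explicit null: no events or queries pushed
    "namespaces": {
        "users": [{"regex": "@irc_.*", "exclusive": True}],
    },
    # "protocols": ["irc"],        # optional list of strings
}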

View File

@@ -27,7 +27,6 @@ class CaptchaConfig(Config):
     def default_config(self, **kwargs):
         return """\
         ## Captcha ##
-        # See docs/CAPTCHA_SETUP for full details of configuring this.

         # This Home Server's ReCAPTCHA public key.
         recaptcha_public_key: "YOUR_PUBLIC_KEY"

View File

@@ -1,98 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file can't be called email.py because if it is, we cannot:
import email.utils
from ._base import Config
class EmailConfig(Config):
    def read_config(self, config):
        self.email_enable_notifs = False

        email_config = config.get("email", {})
        self.email_enable_notifs = email_config.get("enable_notifs", False)

        if self.email_enable_notifs:
            # make sure we can import the required deps
            import jinja2
            import bleach
            # prevent unused warnings
            jinja2
            bleach

            required = [
                "smtp_host",
                "smtp_port",
                "notif_from",
                "template_dir",
                "notif_template_html",
                "notif_template_text",
            ]

            missing = []
            for k in required:
                if k not in email_config:
                    missing.append(k)

            if (len(missing) > 0):
                raise RuntimeError(
                    "email.enable_notifs is True but required keys are missing: %s" %
                    (", ".join(["email." + k for k in missing]),)
                )

            if config.get("public_baseurl") is None:
                raise RuntimeError(
                    "email.enable_notifs is True but no public_baseurl is set"
                )

            self.email_smtp_host = email_config["smtp_host"]
            self.email_smtp_port = email_config["smtp_port"]
            self.email_notif_from = email_config["notif_from"]
            self.email_template_dir = email_config["template_dir"]
            self.email_notif_template_html = email_config["notif_template_html"]
            self.email_notif_template_text = email_config["notif_template_text"]
            self.email_notif_for_new_users = email_config.get(
                "notif_for_new_users", True
            )
            if "app_name" in email_config:
                self.email_app_name = email_config["app_name"]
            else:
                self.email_app_name = "Matrix"

            # make sure it's valid
            parsed = email.utils.parseaddr(self.email_notif_from)
            if parsed[1] == '':
                raise RuntimeError("Invalid notif_from address")
        else:
            self.email_enable_notifs = False
            # Not much point setting defaults for the rest: it would be an
            # error for them to be used.

    def default_config(self, config_dir_path, server_name, **kwargs):
        return """
        # Enable sending emails for notification events
        #email:
        #   enable_notifs: false
        #   smtp_host: "localhost"
        #   smtp_port: 25
        #   notif_from: "Your Friendly %(app)s Home Server <noreply@example.com>"
        #   app_name: Matrix
        #   template_dir: res/templates
        #   notif_template_html: notif_mail.html
        #   notif_template_text: notif_mail.txt
        #   notif_for_new_users: True
        """

View File

@@ -31,16 +31,13 @@ from .cas import CasConfig
 from .password import PasswordConfig
 from .jwt import JWTConfig
 from .ldap import LDAPConfig
-from .emailconfig import EmailConfig
-from .workers import WorkerConfig

 class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
                        RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
                        VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
                        AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
-                       JWTConfig, LDAPConfig, PasswordConfig, EmailConfig,
-                       WorkerConfig,):
+                       JWTConfig, LDAPConfig, PasswordConfig,):
     pass

View File

@@ -13,16 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from ._base import Config, ConfigError
+from ._base import Config

-MISSING_JWT = (
-    """Missing jwt library. This is required for jwt login.
-    Install by running:
-        pip install pyjwt
-    """
-)

 class JWTConfig(Config):
@@ -32,12 +23,6 @@ class JWTConfig(Config):
             self.jwt_enabled = jwt_config.get("enabled", False)
             self.jwt_secret = jwt_config["secret"]
             self.jwt_algorithm = jwt_config["algorithm"]
-            try:
-                import jwt
-                jwt  # To stop unused lint.
-            except ImportError:
-                raise ConfigError(MISSING_JWT)
         else:
             self.jwt_enabled = False
             self.jwt_secret = None
@@ -45,8 +30,6 @@ class JWTConfig(Config):
     def default_config(self, **kwargs):
         return """\
-        # The JWT needs to contain a globally unique "sub" (subject) claim.
-        #
        # jwt_config:
        #    enabled: true
        #    secret: "a secret"

View File

@@ -57,8 +57,6 @@ class KeyConfig(Config):
             seed = self.signing_key[0].seed
             self.macaroon_secret_key = hashlib.sha256(seed)

-        self.expire_access_token = config.get("expire_access_token", False)

     def default_config(self, config_dir_path, server_name, is_generating_file=False,
                        **kwargs):
         base_key_name = os.path.join(config_dir_path, server_name)
@@ -71,9 +69,6 @@ class KeyConfig(Config):
         return """\
        macaroon_secret_key: "%(macaroon_secret_key)s"

-        # Used to enable access token expiration.
-        expire_access_token: False

        ## Signing Keys ##

        # Path to the signing key to sign messages with

View File

@@ -13,88 +13,40 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from ._base import Config, ConfigError
+from ._base import Config

-MISSING_LDAP3 = (
-    "Missing ldap3 library. This is required for LDAP Authentication."
-)

-class LDAPMode(object):
-    SIMPLE = "simple",
-    SEARCH = "search",

-    LIST = (SIMPLE, SEARCH)

 class LDAPConfig(Config):
     def read_config(self, config):
-        ldap_config = config.get("ldap_config", {})
-
-        self.ldap_enabled = ldap_config.get("enabled", False)
-
-        if self.ldap_enabled:
-            # verify dependencies are available
-            try:
-                import ldap3
-                ldap3  # to stop unused lint
-            except ImportError:
-                raise ConfigError(MISSING_LDAP3)
-
-            self.ldap_mode = LDAPMode.SIMPLE
-
-            # verify config sanity
-            self.require_keys(ldap_config, [
-                "uri",
-                "base",
-                "attributes",
-            ])
-
-            self.ldap_uri = ldap_config["uri"]
-            self.ldap_start_tls = ldap_config.get("start_tls", False)
-            self.ldap_base = ldap_config["base"]
-            self.ldap_attributes = ldap_config["attributes"]
-
-            if "bind_dn" in ldap_config:
-                self.ldap_mode = LDAPMode.SEARCH
-                self.require_keys(ldap_config, [
-                    "bind_dn",
-                    "bind_password",
-                ])
-
-                self.ldap_bind_dn = ldap_config["bind_dn"]
-                self.ldap_bind_password = ldap_config["bind_password"]
-                self.ldap_filter = ldap_config.get("filter", None)
-
-            # verify attribute lookup
-            self.require_keys(ldap_config['attributes'], [
-                "uid",
-                "name",
-                "mail",
-            ])
-
-    def require_keys(self, config, required):
-        missing = [key for key in required if key not in config]
-        if missing:
-            raise ConfigError(
-                "LDAP enabled but missing required config values: {}".format(
-                    ", ".join(missing)
-                )
-            )
+        ldap_config = config.get("ldap_config", None)
+        if ldap_config:
+            self.ldap_enabled = ldap_config.get("enabled", False)
+            self.ldap_server = ldap_config["server"]
+            self.ldap_port = ldap_config["port"]
+            self.ldap_tls = ldap_config.get("tls", False)
+            self.ldap_search_base = ldap_config["search_base"]
+            self.ldap_search_property = ldap_config["search_property"]
+            self.ldap_email_property = ldap_config["email_property"]
+            self.ldap_full_name_property = ldap_config["full_name_property"]
+        else:
+            self.ldap_enabled = False
+            self.ldap_server = None
+            self.ldap_port = None
+            self.ldap_tls = False
+            self.ldap_search_base = None
+            self.ldap_search_property = None
+            self.ldap_email_property = None
+            self.ldap_full_name_property = None

     def default_config(self, **kwargs):
         return """\
        # ldap_config:
        #   enabled: true
-        #   uri: "ldap://ldap.example.com:389"
-        #   start_tls: true
-        #   base: "ou=users,dc=example,dc=com"
-        #   attributes:
-        #      uid: "cn"
-        #      mail: "email"
-        #      name: "givenName"
-        #   #bind_dn:
-        #   #bind_password:
-        #   #filter: "(objectClass=posixAccount)"
+        #   server: "ldap://localhost"
+        #   port: 389
+        #   tls: false
+        #   search_base: "ou=Users,dc=example,dc=com"
+        #   search_property: "cn"
+        #   email_property: "email"
+        #   full_name_property: "givenName"
        """

View File

@@ -126,58 +126,54 @@ class LoggingConfig(Config):
         )

     def setup_logging(self):
-        setup_logging(self.log_config, self.log_file, self.verbosity)
+        log_format = (
+            "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
+            " - %(message)s"
+        )
+        if self.log_config is None:
+            level = logging.INFO
+            level_for_storage = logging.INFO
+            if self.verbosity:
+                level = logging.DEBUG
+                if self.verbosity > 1:
+                    level_for_storage = logging.DEBUG
+
+            # FIXME: we need a logging.WARN for a -q quiet option
+            logger = logging.getLogger('')
+            logger.setLevel(level)
+
+            logging.getLogger('synapse.storage').setLevel(level_for_storage)
+
+            formatter = logging.Formatter(log_format)
+            if self.log_file:
+                # TODO: Customisable file size / backup count
+                handler = logging.handlers.RotatingFileHandler(
+                    self.log_file, maxBytes=(1000 * 1000 * 100), backupCount=3
+                )
+
+                def sighup(signum, stack):
+                    logger.info("Closing log file due to SIGHUP")
+                    handler.doRollover()
+                    logger.info("Opened new log file due to SIGHUP")
+
+                # TODO(paul): obviously this is a terrible mechanism for
+                # stealing SIGHUP, because it means no other part of synapse
+                # can use it instead. If we want to catch SIGHUP anywhere
+                # else as well, I'd suggest we find a nicer way to broadcast
+                # it around.
+                if getattr(signal, "SIGHUP"):
+                    signal.signal(signal.SIGHUP, sighup)
+            else:
+                handler = logging.StreamHandler()
+            handler.setFormatter(formatter)
+            handler.addFilter(LoggingContextFilter(request=""))
+            logger.addHandler(handler)
+        else:
+            with open(self.log_config, 'r') as f:
+                logging.config.dictConfig(yaml.load(f))
+
+        observer = PythonLoggingObserver()
+        observer.start()

-def setup_logging(log_config=None, log_file=None, verbosity=None):
-    log_format = (
-        "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
-        " - %(message)s"
-    )
-    if log_config is None:
-        level = logging.INFO
-        level_for_storage = logging.INFO
-        if verbosity:
-            level = logging.DEBUG
-            if verbosity > 1:
-                level_for_storage = logging.DEBUG
-
-        # FIXME: we need a logging.WARN for a -q quiet option
-        logger = logging.getLogger('')
-        logger.setLevel(level)
-
-        logging.getLogger('synapse.storage').setLevel(level_for_storage)
-
-        formatter = logging.Formatter(log_format)
-        if log_file:
-            # TODO: Customisable file size / backup count
-            handler = logging.handlers.RotatingFileHandler(
-                log_file, maxBytes=(1000 * 1000 * 100), backupCount=3
-            )
-
-            def sighup(signum, stack):
-                logger.info("Closing log file due to SIGHUP")
-                handler.doRollover()
-                logger.info("Opened new log file due to SIGHUP")
-
-            # TODO(paul): obviously this is a terrible mechanism for
-            # stealing SIGHUP, because it means no other part of synapse
-            # can use it instead. If we want to catch SIGHUP anywhere
-            # else as well, I'd suggest we find a nicer way to broadcast
-            # it around.
-            if getattr(signal, "SIGHUP"):
-                signal.signal(signal.SIGHUP, sighup)
-        else:
-            handler = logging.StreamHandler()
-        handler.setFormatter(formatter)
-        handler.addFilter(LoggingContextFilter(request=""))
-
-        logger.addHandler(handler)
-    else:
-        with open(log_config, 'r') as f:
-            logging.config.dictConfig(yaml.load(f))
-
-    observer = PythonLoggingObserver()
-    observer.start()

View File

@@ -23,14 +23,10 @@ class PasswordConfig(Config):
     def read_config(self, config):
         password_config = config.get("password_config", {})
         self.password_enabled = password_config.get("enabled", True)
-        self.password_pepper = password_config.get("pepper", "")

     def default_config(self, config_dir_path, server_name, **kwargs):
         return """
        # Enable password for login.
        password_config:
           enabled: true
-           # Uncomment and change to a secret random string for extra security.
-           # DO NOT CHANGE THIS AFTER INITIAL SETUP!
-           #pepper: ""
        """

View File

@@ -32,7 +32,6 @@ class RegistrationConfig(Config):
         )

         self.registration_shared_secret = config.get("registration_shared_secret")
-        self.user_creation_max_duration = int(config["user_creation_max_duration"])
         self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
         self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"]
@@ -55,11 +54,6 @@ class RegistrationConfig(Config):
        # secret, even if registration is otherwise disabled.
        registration_shared_secret: "%(registration_shared_secret)s"

-        # Sets the expiry for the short term user creation in
-        # milliseconds. For instance the bellow duration is two weeks
-        # in milliseconds.
-        user_creation_max_duration: 1209600000

        # Set the number of bcrypt rounds used to generate password hash.
        # Larger numbers increase the work factor needed to generate the hash.
        # The default number of rounds is 12.
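The removed user_creation_max_duration default is simply two weeks expressed in milliseconds:

assert 14 * 24 * 60 * 60 * 1000 == 1209600000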

View File

@@ -13,25 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from ._base import Config, ConfigError
+from ._base import Config

 from collections import namedtuple

-MISSING_NETADDR = (
-    "Missing netaddr library. This is required for URL preview API."
-)

-MISSING_LXML = (
-    """Missing lxml library. This is required for URL preview API.
-    Install by running:
-        pip install lxml
-    Requires libxslt1-dev system package.
-    """
-)

 ThumbnailRequirement = namedtuple(
     "ThumbnailRequirement", ["width", "height", "method", "media_type"]
 )
@@ -39,7 +23,7 @@ ThumbnailRequirement = namedtuple(
 def parse_thumbnail_requirements(thumbnail_sizes):
     """ Takes a list of dictionaries with "width", "height", and "method" keys
-    and creates a map from image media types to the thumbnail size, thumbnailing
+    and creates a map from image media types to the thumbnail size, thumnailing
     method, and thumbnail media type to precalculate

     Args:
@@ -69,44 +53,12 @@ class ContentRepositoryConfig(Config):
     def read_config(self, config):
         self.max_upload_size = self.parse_size(config["max_upload_size"])
         self.max_image_pixels = self.parse_size(config["max_image_pixels"])
-        self.max_spider_size = self.parse_size(config["max_spider_size"])
         self.media_store_path = self.ensure_directory(config["media_store_path"])
         self.uploads_path = self.ensure_directory(config["uploads_path"])
         self.dynamic_thumbnails = config["dynamic_thumbnails"]
         self.thumbnail_requirements = parse_thumbnail_requirements(
             config["thumbnail_sizes"]
         )
-        self.url_preview_enabled = config.get("url_preview_enabled", False)
-        if self.url_preview_enabled:
-            try:
-                import lxml
-                lxml  # To stop unused lint.
-            except ImportError:
-                raise ConfigError(MISSING_LXML)
-            try:
-                from netaddr import IPSet
-            except ImportError:
-                raise ConfigError(MISSING_NETADDR)
-            if "url_preview_ip_range_blacklist" in config:
-                self.url_preview_ip_range_blacklist = IPSet(
-                    config["url_preview_ip_range_blacklist"]
-                )
-            else:
-                raise ConfigError(
-                    "For security, you must specify an explicit target IP address "
-                    "blacklist in url_preview_ip_range_blacklist for url previewing "
-                    "to work"
-                )
-            self.url_preview_ip_range_whitelist = IPSet(
-                config.get("url_preview_ip_range_whitelist", ())
-            )
-            self.url_preview_url_blacklist = config.get(
-                "url_preview_url_blacklist", ()
-            )

     def default_config(self, **kwargs):
         media_store = self.default_path("media_store")
@@ -128,7 +80,7 @@ class ContentRepositoryConfig(Config):
        # the resolution requested by the client. If true then whenever
        # a new resolution is requested by the client the server will
        # generate a new thumbnail. If false the server will pick a thumbnail
-        # from a precalculated list.
+        # from a precalcualted list.
        dynamic_thumbnails: false

        # List of thumbnail to precalculate when an image is uploaded.
@@ -148,71 +100,4 @@ class ContentRepositoryConfig(Config):
        - width: 800
          height: 600
          method: scale

-        # Is the preview URL API enabled? If enabled, you *must* specify
-        # an explicit url_preview_ip_range_blacklist of IPs that the spider is
-        # denied from accessing.
-        url_preview_enabled: False

-        # List of IP address CIDR ranges that the URL preview spider is denied
-        # from accessing. There are no defaults: you must explicitly
-        # specify a list for URL previewing to work. You should specify any
-        # internal services in your network that you do not want synapse to try
-        # to connect to, otherwise anyone in any Matrix room could cause your
-        # synapse to issue arbitrary GET requests to your internal services,
-        # causing serious security issues.
-        #
-        # url_preview_ip_range_blacklist:
-        # - '127.0.0.0/8'
-        # - '10.0.0.0/8'
-        # - '172.16.0.0/12'
-        # - '192.168.0.0/16'
-        #
-        # List of IP address CIDR ranges that the URL preview spider is allowed
-        # to access even if they are specified in url_preview_ip_range_blacklist.
-        # This is useful for specifying exceptions to wide-ranging blacklisted
-        # target IP ranges - e.g. for enabling URL previews for a specific private
-        # website only visible in your network.
-        #
-        # url_preview_ip_range_whitelist:
-        # - '192.168.1.1'

-        # Optional list of URL matches that the URL preview spider is
-        # denied from accessing.  You should use url_preview_ip_range_blacklist
-        # in preference to this, otherwise someone could define a public DNS
-        # entry that points to a private IP address and circumvent the blacklist.
-        # This is more useful if you know there is an entire shape of URL that
-        # you know that will never want synapse to try to spider.
-        #
-        # Each list entry is a dictionary of url component attributes as returned
-        # by urlparse.urlsplit as applied to the absolute form of the URL.  See
-        # https://docs.python.org/2/library/urlparse.html#urlparse.urlsplit
-        # The values of the dictionary are treated as an filename match pattern
-        # applied to that component of URLs, unless they start with a ^ in which
-        # case they are treated as a regular expression match.  If all the
-        # specified component matches for a given list item succeed, the URL is
-        # blacklisted.
-        #
-        # url_preview_url_blacklist:
-        # # blacklist any URL with a username in its URI
-        # - username: '*'
-        #
-        # # blacklist all *.google.com URLs
-        # - netloc: 'google.com'
-        # - netloc: '*.google.com'
-        #
-        # # blacklist all plain HTTP URLs
-        # - scheme: 'http'
-        #
-        # # blacklist http(s)://www.acme.com/foo
-        # - netloc: 'www.acme.com'
-        #   path: '/foo'
-        #
-        # # blacklist any URL with a literal IPv4 address
-        # - netloc: '^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$'

-        # The largest allowed URL preview spidering size in bytes
-        max_spider_size: "10M"

        """ % locals()

View File

@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from ._base import Config, ConfigError
+from ._base import Config
 class ServerConfig(Config):
@@ -27,18 +27,10 @@ class ServerConfig(Config):
         self.daemonize = config.get("daemonize")
         self.print_pidfile = config.get("print_pidfile")
         self.user_agent_suffix = config.get("user_agent_suffix")
-        self.use_frozen_dicts = config.get("use_frozen_dicts", False)
+        self.use_frozen_dicts = config.get("use_frozen_dicts", True)
-        self.public_baseurl = config.get("public_baseurl")
-        if self.public_baseurl is not None:
-            if self.public_baseurl[-1] != '/':
-                self.public_baseurl += '/'
-        self.start_pushers = config.get("start_pushers", True)
         self.listeners = config.get("listeners", [])
-        self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None))
         bind_port = config.get("bind_port")
         if bind_port:
             self.listeners = []
@@ -106,6 +98,26 @@ class ServerConfig(Config):
             ]
         })
+        # Attempt to guess the content_addr for the v0 content repostitory
+        content_addr = config.get("content_addr")
+        if not content_addr:
+            for listener in self.listeners:
+                if listener["type"] == "http" and not listener.get("tls", False):
+                    unsecure_port = listener["port"]
+                    break
+            else:
+                raise RuntimeError("Could not determine 'content_addr'")
+            host = self.server_name
+            if ':' not in host:
+                host = "%s:%d" % (host, unsecure_port)
+            else:
+                host = host.split(':')[0]
+                host = "%s:%d" % (host, unsecure_port)
+            content_addr = "http://%s" % (host,)
+        self.content_addr = content_addr
     def default_config(self, server_name, **kwargs):
         if ":" in server_name:
             bind_port = int(server_name.split(":")[1])
@@ -130,17 +142,11 @@ class ServerConfig(Config):
         # Whether to serve a web client from the HTTP/HTTPS root resource.
         web_client: True
-        # The public-facing base URL for the client API (not including _matrix/...)
-        # public_baseurl: https://example.com:8448/
         # Set the soft limit on the number of file descriptors synapse can use
         # Zero is used to indicate synapse should set the soft limit to the
         # hard limit.
         soft_file_limit: 0
-        # The GC threshold parameters to pass to `gc.set_threshold`, if defined
-        # gc_thresholds: [700, 10, 10]
         # List of ports that Synapse should listen on, their purpose and their
        # configuration.
         listeners:
@@ -222,20 +228,3 @@ class ServerConfig(Config):
                             type=int,
                             help="Turn on the twisted telnet manhole"
                                  " service on the given port.")
-def read_gc_thresholds(thresholds):
-    """Reads the three integer thresholds for garbage collection. Ensures that
-    the thresholds are integers if thresholds are supplied.
-    """
-    if thresholds is None:
-        return None
-    try:
-        assert len(thresholds) == 3
-        return (
-            int(thresholds[0]), int(thresholds[1]), int(thresholds[2]),
-        )
-    except:
-        raise ConfigError(
-            "Value of `gc_threshold` must be a list of three integers if set"
-        )
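As a quick illustration of how the parsed gc_thresholds tuple is meant to be used with Python's gc module (a sketch; the value shown is the hypothetical example from the deleted config comment above):

import gc

# hypothetical parsed value of "gc_thresholds: [700, 10, 10]"
gc_thresholds = (700, 10, 10)

if gc_thresholds is not None:
    # Thresholds for generations 0, 1 and 2 of CPython's collector.
    gc.set_threshold(*gc_thresholds)

print(gc.get_threshold())  # -> (700, 10, 10)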

View File

@@ -1,31 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2016 matrix.org
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from ._base import Config
-class WorkerConfig(Config):
-    """The workers are processes run separately to the main synapse process.
-    They have their own pid_file and listener configuration. They use the
-    replication_url to talk to the main synapse process."""
-    def read_config(self, config):
-        self.worker_app = config.get("worker_app")
-        self.worker_listeners = config.get("worker_listeners")
-        self.worker_daemonize = config.get("worker_daemonize")
-        self.worker_pid_file = config.get("worker_pid_file")
-        self.worker_log_file = config.get("worker_log_file")
-        self.worker_log_config = config.get("worker_log_config")
-        self.worker_replication_url = config.get("worker_replication_url")
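As a rough illustration of the shape of config this class expects, here is a hypothetical worker configuration expressed as the dict read_config would receive once the YAML is parsed — the app name, ports and paths are all made-up examples:

# Hypothetical parsed worker config; every value is illustrative only.
config = {
    "worker_app": "synapse.app.synchrotron",
    "worker_listeners": [
        {"type": "http", "port": 8083,
         "resources": [{"names": ["client"]}]},
    ],
    "worker_daemonize": True,
    "worker_pid_file": "/var/run/synapse-worker.pid",
    "worker_log_file": "/var/log/synapse-worker.log",
    "worker_log_config": None,
    "worker_replication_url": "http://127.0.0.1:9092/_synapse/replication",
}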

View File

@@ -77,12 +77,10 @@ class SynapseKeyClientProtocol(HTTPClient):
     def __init__(self):
         self.remote_key = defer.Deferred()
         self.host = None
-        self._peer = None
     def connectionMade(self):
-        self._peer = self.transport.getPeer()
-        logger.debug("Connected to %s", self._peer)
+        self.host = self.transport.getHost()
+        logger.debug("Connected to %s", self.host)
         self.sendCommand(b"GET", self.path)
         if self.host:
             self.sendHeader(b"Host", self.host)
@@ -126,10 +124,7 @@ class SynapseKeyClientProtocol(HTTPClient):
         self.timer.cancel()
     def on_timeout(self):
-        logger.debug(
-            "Timeout waiting for response from %s: %s",
-            self.host, self._peer,
-        )
+        logger.debug("Timeout waiting for response from %s", self.host)
         self.errback(IOError("Timeout waiting for response"))
         self.transport.abortConnection()
@@ -138,5 +133,4 @@ class SynapseKeyClientFactory(Factory):
     def protocol(self):
         protocol = SynapseKeyClientProtocol()
         protocol.path = self.path
-        protocol.host = self.host
         return protocol

View File

@@ -22,7 +22,6 @@ from synapse.util.logcontext import (
     preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext,
     preserve_fn
 )
-from synapse.util.metrics import Measure
 from twisted.internet import defer
@@ -45,25 +44,7 @@ import logging
 logger = logging.getLogger(__name__)
-VerifyKeyRequest = namedtuple("VerifyRequest", (
-    "server_name", "key_ids", "json_object", "deferred"
-))
-"""
-A request for a verify key to verify a JSON object.
-Attributes:
-    server_name(str): The name of the server to verify against.
-    key_ids(set(str)): The set of key_ids to that could be used to verify the
-        JSON object
-    json_object(dict): The JSON object to verify.
-    deferred(twisted.internet.defer.Deferred):
-        A deferred (server_name, key_id, verify_key) tuple that resolves when
-        a verify key has been fetched
-"""
-class KeyLookupError(ValueError):
-    pass
+KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids"))
 class Keyring(object):
@@ -93,32 +74,39 @@ class Keyring(object):
             list of deferreds indicating success or failure to verify each
             json object's signature for the given server_name.
         """
-        verify_requests = []
+        group_id_to_json = {}
+        group_id_to_group = {}
+        group_ids = []
+        next_group_id = 0
+        deferreds = {}
         for server_name, json_object in server_and_json:
             logger.debug("Verifying for %s", server_name)
+            group_id = next_group_id
+            next_group_id += 1
+            group_ids.append(group_id)
             key_ids = signature_ids(json_object, server_name)
             if not key_ids:
-                deferred = defer.fail(SynapseError(
+                deferreds[group_id] = defer.fail(SynapseError(
                     400,
                     "Not signed with a supported algorithm",
                     Codes.UNAUTHORIZED,
                 ))
             else:
-                deferred = defer.Deferred()
-            verify_request = VerifyKeyRequest(
-                server_name, key_ids, json_object, deferred
-            )
-            verify_requests.append(verify_request)
+                deferreds[group_id] = defer.Deferred()
+            group = KeyGroup(server_name, group_id, key_ids)
+            group_id_to_group[group_id] = group
+            group_id_to_json[group_id] = json_object
         @defer.inlineCallbacks
-        def handle_key_deferred(verify_request):
-            server_name = verify_request.server_name
+        def handle_key_deferred(group, deferred):
+            server_name = group.server_name
             try:
-                _, key_id, verify_key = yield verify_request.deferred
+                _, _, key_id, verify_key = yield deferred
             except IOError as e:
                 logger.warn(
                     "Got IOError when downloading keys for %s: %s %s",
@@ -140,7 +128,7 @@ class Keyring(object):
                     Codes.UNAUTHORIZED,
                 )
-            json_object = verify_request.json_object
+            json_object = group_id_to_json[group.group_id]
             try:
                 verify_signed_json(json_object, server_name, verify_key)
@@ -169,34 +157,36 @@ class Keyring(object):
         # Actually start fetching keys.
         wait_on_deferred.addBoth(
-            lambda _: self.get_server_verify_keys(verify_requests)
+            lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
         )
         # When we've finished fetching all the keys for a given server_name,
         # resolve the deferred passed to `wait_for_previous_lookups` so that
         # any lookups waiting will proceed.
-        server_to_request_ids = {}
-        def remove_deferreds(res, server_name, verify_request):
-            request_id = id(verify_request)
-            server_to_request_ids[server_name].discard(request_id)
-            if not server_to_request_ids[server_name]:
+        server_to_gids = {}
+        def remove_deferreds(res, server_name, group_id):
+            server_to_gids[server_name].discard(group_id)
+            if not server_to_gids[server_name]:
                 d = server_to_deferred.pop(server_name, None)
                 if d:
                     d.callback(None)
             return res
-        for verify_request in verify_requests:
-            server_name = verify_request.server_name
-            request_id = id(verify_request)
-            server_to_request_ids.setdefault(server_name, set()).add(request_id)
-            deferred.addBoth(remove_deferreds, server_name, verify_request)
+        for g_id, deferred in deferreds.items():
+            server_name = group_id_to_group[g_id].server_name
+            server_to_gids.setdefault(server_name, set()).add(g_id)
+            deferred.addBoth(remove_deferreds, server_name, g_id)
         # Pass those keys to handle_key_deferred so that the json object
         # signatures can be verified
         return [
-            preserve_context_over_fn(handle_key_deferred, verify_request)
-            for verify_request in verify_requests
+            preserve_context_over_fn(
+                handle_key_deferred,
+                group_id_to_group[g_id],
+                deferreds[g_id],
+            )
+            for g_id in group_ids
         ]
     @defer.inlineCallbacks
@@ -230,7 +220,7 @@ class Keyring(object):
             d.addBoth(rm, server_name)
-    def get_server_verify_keys(self, verify_requests):
+    def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
         """Takes a dict of KeyGroups and tries to find at least one key for
         each group.
         """
@@ -244,79 +234,76 @@ class Keyring(object):
         @defer.inlineCallbacks
         def do_iterations():
-            with Measure(self.clock, "get_server_verify_keys"):
-                merged_results = {}
-                missing_keys = {}
-                for verify_request in verify_requests:
-                    missing_keys.setdefault(verify_request.server_name, set()).update(
-                        verify_request.key_ids
-                    )
-                for fn in key_fetch_fns:
-                    results = yield fn(missing_keys.items())
-                    merged_results.update(results)
-                    # We now need to figure out which verify requests we have keys
-                    # for and which we don't
-                    missing_keys = {}
-                    requests_missing_keys = []
-                    for verify_request in verify_requests:
-                        server_name = verify_request.server_name
-                        result_keys = merged_results[server_name]
-                        if verify_request.deferred.called:
-                            # We've already called this deferred, which probably
-                            # means that we've already found a key for it.
-                            continue
-                        for key_id in verify_request.key_ids:
-                            if key_id in result_keys:
-                                with PreserveLoggingContext():
-                                    verify_request.deferred.callback((
-                                        server_name,
-                                        key_id,
-                                        result_keys[key_id],
-                                    ))
-                                break
-                        else:
-                            # The else block is only reached if the loop above
-                            # doesn't break.
-                            missing_keys.setdefault(server_name, set()).update(
-                                verify_request.key_ids
-                            )
-                            requests_missing_keys.append(verify_request)
-                    if not missing_keys:
-                        break
-                for verify_request in requests_missing_keys.values():
-                    verify_request.deferred.errback(SynapseError(
-                        401,
-                        "No key for %s with id %s" % (
-                            verify_request.server_name, verify_request.key_ids,
-                        ),
-                        Codes.UNAUTHORIZED,
-                    ))
+            merged_results = {}
+            missing_keys = {}
+            for group in group_id_to_group.values():
+                missing_keys.setdefault(group.server_name, set()).update(
+                    group.key_ids
+                )
+            for fn in key_fetch_fns:
+                results = yield fn(missing_keys.items())
+                merged_results.update(results)
+                # We now need to figure out which groups we have keys for
+                # and which we don't
+                missing_groups = {}
+                for group in group_id_to_group.values():
+                    for key_id in group.key_ids:
+                        if key_id in merged_results[group.server_name]:
+                            with PreserveLoggingContext():
+                                group_id_to_deferred[group.group_id].callback((
+                                    group.group_id,
+                                    group.server_name,
+                                    key_id,
+                                    merged_results[group.server_name][key_id],
+                                ))
+                            break
+                    else:
+                        missing_groups.setdefault(
+                            group.server_name, []
+                        ).append(group)
+                if not missing_groups:
+                    break
+                missing_keys = {
+                    server_name: set(
+                        key_id for group in groups for key_id in group.key_ids
+                    )
+                    for server_name, groups in missing_groups.items()
+                }
+            for group in missing_groups.values():
+                group_id_to_deferred[group.group_id].errback(SynapseError(
+                    401,
+                    "No key for %s with id %s" % (
+                        group.server_name, group.key_ids,
+                    ),
+                    Codes.UNAUTHORIZED,
+                ))
         def on_err(err):
-            for verify_request in verify_requests:
-                if not verify_request.deferred.called:
-                    verify_request.deferred.errback(err)
+            for deferred in group_id_to_deferred.values():
+                if not deferred.called:
+                    deferred.errback(err)
         do_iterations().addErrback(on_err)
+        return group_id_to_deferred
     @defer.inlineCallbacks
     def get_keys_from_store(self, server_name_and_key_ids):
-        res = yield preserve_context_over_deferred(defer.gatherResults(
+        res = yield defer.gatherResults(
             [
-                preserve_fn(self.store.get_server_verify_keys)(
+                self.store.get_server_verify_keys(
                     server_name, key_ids
                 ).addCallback(lambda ks, server: (server, ks), server_name)
                 for server_name, key_ids in server_name_and_key_ids
             ],
             consumeErrors=True,
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError)
         defer.returnValue(dict(res))
@@ -337,13 +324,13 @@ class Keyring(object):
             )
             defer.returnValue({})
-        results = yield preserve_context_over_deferred(defer.gatherResults(
+        results = yield defer.gatherResults(
             [
-                preserve_fn(get_key)(p_name, p_keys)
+                get_key(p_name, p_keys)
                 for p_name, p_keys in self.perspective_servers.items()
             ],
             consumeErrors=True,
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError)
         union_of_keys = {}
         for result in results:
@@ -369,7 +356,7 @@ class Keyring(object):
             )
         except Exception as e:
             logger.info(
-                "Unable to get key %r for %r directly: %s %s",
+                "Unable to getting key %r for %r directly: %s %s",
                 key_ids, server_name,
                 type(e).__name__, str(e.message),
             )
@@ -383,13 +370,13 @@ class Keyring(object):
             defer.returnValue(keys)
-        results = yield preserve_context_over_deferred(defer.gatherResults(
+        results = yield defer.gatherResults(
             [
-                preserve_fn(get_key)(server_name, key_ids)
+                get_key(server_name, key_ids)
                 for server_name, key_ids in server_name_and_key_ids
             ],
             consumeErrors=True,
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError)
         merged = {}
         for result in results:
@@ -431,7 +418,7 @@ class Keyring(object):
         for response in responses:
             if (u"signatures" not in response
                     or perspective_name not in response[u"signatures"]):
-                raise KeyLookupError(
+                raise ValueError(
                     "Key response not signed by perspective server"
                     " %r" % (perspective_name,)
                 )
@@ -454,21 +441,21 @@ class Keyring(object):
                     list(response[u"signatures"][perspective_name]),
                     list(perspective_keys)
                 )
-                raise KeyLookupError(
+                raise ValueError(
                     "Response not signed with a known key for perspective"
                     " server %r" % (perspective_name,)
                 )
             processed_response = yield self.process_v2_response(
-                perspective_name, response, only_from_server=False
+                perspective_name, response
             )
             for server_name, response_keys in processed_response.items():
                 keys.setdefault(server_name, {}).update(response_keys)
-        yield preserve_context_over_deferred(defer.gatherResults(
+        yield defer.gatherResults(
             [
-                preserve_fn(self.store_keys)(
+                self.store_keys(
                     server_name=server_name,
                     from_server=perspective_name,
                     verify_keys=response_keys,
@@ -476,7 +463,7 @@ class Keyring(object):
                 for server_name, response_keys in keys.items()
             ],
             consumeErrors=True
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError)
         defer.returnValue(keys)
@@ -497,10 +484,10 @@
         if (u"signatures" not in response
                 or server_name not in response[u"signatures"]):
-            raise KeyLookupError("Key response not signed by remote server")
+            raise ValueError("Key response not signed by remote server")
         if "tls_fingerprints" not in response:
-            raise KeyLookupError("Key response missing TLS fingerprints")
+            raise ValueError("Key response missing TLS fingerprints")
         certificate_bytes = crypto.dump_certificate(
             crypto.FILETYPE_ASN1, tls_certificate
@@ -514,7 +501,7 @@
             response_sha256_fingerprints.add(fingerprint[u"sha256"])
         if sha256_fingerprint_b64 not in response_sha256_fingerprints:
-            raise KeyLookupError("TLS certificate not allowed by fingerprints")
+            raise ValueError("TLS certificate not allowed by fingerprints")
         response_keys = yield self.process_v2_response(
             from_server=server_name,
@@ -524,7 +511,7 @@
         keys.update(response_keys)
-        yield preserve_context_over_deferred(defer.gatherResults(
+        yield defer.gatherResults(
             [
                 preserve_fn(self.store_keys)(
                     server_name=key_server_name,
@@ -534,13 +521,13 @@
                 for key_server_name, verify_keys in keys.items()
             ],
             consumeErrors=True
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError)
         defer.returnValue(keys)
     @defer.inlineCallbacks
     def process_v2_response(self, from_server, response_json,
-                            requested_ids=[], only_from_server=True):
+                            requested_ids=[]):
         time_now_ms = self.clock.time_msec()
         response_keys = {}
         verify_keys = {}
@@ -564,16 +551,9 @@
         results = {}
         server_name = response_json["server_name"]
-        if only_from_server:
-            if server_name != from_server:
-                raise KeyLookupError(
-                    "Expected a response for server %r not %r" % (
-                        from_server, server_name
-                    )
-                )
         for key_id in response_json["signatures"].get(server_name, {}):
             if key_id not in response_json["verify_keys"]:
-                raise KeyLookupError(
+                raise ValueError(
                     "Key response must include verification keys for all"
                     " signatures"
                 )
@@ -600,7 +580,7 @@
         response_keys.update(verify_keys)
         response_keys.update(old_verify_keys)
-        yield preserve_context_over_deferred(defer.gatherResults(
+        yield defer.gatherResults(
             [
                 preserve_fn(self.store.store_server_keys_json)(
                     server_name=server_name,
@@ -613,7 +593,7 @@
                 for key_id in updated_key_ids
             ],
             consumeErrors=True,
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError)
         results[server_name] = response_keys
@@ -641,15 +621,15 @@
         if ("signatures" not in response
                 or server_name not in response["signatures"]):
-            raise KeyLookupError("Key response not signed by remote server")
+            raise ValueError("Key response not signed by remote server")
         if "tls_certificate" not in response:
-            raise KeyLookupError("Key response missing TLS certificate")
+            raise ValueError("Key response missing TLS certificate")
         tls_certificate_b64 = response["tls_certificate"]
         if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
-            raise KeyLookupError("TLS certificate doesn't match")
+            raise ValueError("TLS certificate doesn't match")
         # Cache the result in the datastore.
@@ -665,7 +645,7 @@
         for key_id in response["signatures"][server_name]:
             if key_id not in response["verify_keys"]:
-                raise KeyLookupError(
+                raise ValueError(
                     "Key response must include verification keys for all"
                     " signatures"
                 )
@@ -702,7 +682,7 @@
             A deferred that completes when the keys are stored.
         """
         # TODO(markjh): Store whether the keys have expired.
-        yield preserve_context_over_deferred(defer.gatherResults(
+        yield defer.gatherResults(
             [
                 preserve_fn(self.store.store_server_verify_key)(
                     server_name, server_name, key.time_added, key
@@ -710,4 +690,4 @@ class Keyring(object):
                 for key_id, key in verify_keys.items()
             ],
             consumeErrors=True,
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError)
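Stepping back, both the old and new versions of get_server_verify_keys implement the same batching idea: gather the key IDs wanted per server, try each fetch function (store, perspectives, direct) in turn, and stop as soon as everything is satisfied. A toy, synchronous sketch of that loop — hypothetical names, not Synapse's API:

def fetch_keys(wanted, key_fetch_fns):
    """wanted: {server_name: set(key_ids)}.
    key_fetch_fns: callables taking that mapping and returning
    {server_name: {key_id: key}} for whatever they could find."""
    found = {}
    missing = {server: set(ids) for server, ids in wanted.items()}
    for fn in key_fetch_fns:
        results = fn(missing)
        for server, keys in results.items():
            found.setdefault(server, {}).update(keys)
            missing[server] -= set(keys)
        # Drop fully-satisfied servers; stop early if nothing is missing.
        missing = {s: ids for s, ids in missing.items() if ids}
        if not missing:
            break
    return found, missing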

View File

@@ -99,7 +99,7 @@ class EventBase(object):
         return d
-    def get(self, key, default=None):
+    def get(self, key, default):
         return self._event_dict.get(key, default)
     def get_internal_metadata_dict(self):
View File

@@ -15,30 +15,9 @@
 class EventContext(object):
-    __slots__ = [
-        "current_state_ids",
-        "prev_state_ids",
-        "state_group",
-        "rejected",
-        "push_actions",
-        "prev_group",
-        "delta_ids",
-        "prev_state_events",
-    ]
-    def __init__(self):
-        # The current state including the current event
-        self.current_state_ids = None
-        # The current state excluding the current event
-        self.prev_state_ids = None
+    def __init__(self, current_state=None):
+        self.current_state = current_state
         self.state_group = None
         self.rejected = False
         self.push_actions = []
-        # A previously persisted state group and a delta between that
-        # and this state.
-        self.prev_group = None
-        self.delta_ids = None
-        self.prev_state_events = None
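The removed comments describe the delta representation the newer EventContext carries: rather than storing the full state map per event, a context can reference a previously persisted state group (prev_group) plus only the entries that changed (delta_ids). A small hypothetical illustration of how the full map is recovered:

# Hypothetical state maps: (event_type, state_key) -> event_id.
prev_group_state = {
    ("m.room.create", ""): "$create",
    ("m.room.member", "@alice:example.org"): "$alice_join",
}
delta_ids = {
    ("m.room.member", "@bob:example.org"): "$bob_join",
}

# current_state_ids = state of the referenced group plus the delta.
current_state_ids = dict(prev_group_state)
current_state_ids.update(delta_ids)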

View File

@@ -88,8 +88,6 @@ def prune_event(event):
     if "age_ts" in event.unsigned:
         allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"]
-    if "replaces_state" in event.unsigned:
-        allowed_fields["unsigned"]["replaces_state"] = event.unsigned["replaces_state"]
     return type(event)(
         allowed_fields,

View File

@@ -23,7 +23,6 @@ from synapse.crypto.event_signing import check_event_content_hash
 from synapse.api.errors import SynapseError
 from synapse.util import unwrapFirstError
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
 import logging
@@ -32,9 +31,6 @@ logger = logging.getLogger(__name__)
 class FederationBase(object):
-    def __init__(self, hs):
-        pass
     @defer.inlineCallbacks
     def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
                                        include_none=False):
@@ -103,10 +99,10 @@ class FederationBase(object):
             warn, pdu
         )
-        valid_pdus = yield preserve_context_over_deferred(defer.gatherResults(
+        valid_pdus = yield defer.gatherResults(
             deferreds,
             consumeErrors=True
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError)
         if include_none:
             defer.returnValue(valid_pdus)
@@ -130,7 +126,7 @@ class FederationBase(object):
             for pdu in pdus
         ]
-        deferreds = preserve_fn(self.keyring.verify_json_objects_for_server)([
+        deferreds = self.keyring.verify_json_objects_for_server([
             (p.origin, p.get_pdu_json())
             for p in redacted_pdus
         ])
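The defer.gatherResults(...).addErrback(unwrapFirstError) pattern seen here (and throughout the keyring changes above) exists because gatherResults wraps the first failure in a FirstError. A small sketch of the unwrap step under that assumption — check_all is a hypothetical name; Synapse's real helper is synapse.util.unwrapFirstError:

from twisted.internet import defer

@defer.inlineCallbacks
def check_all(deferreds):
    # With consumeErrors=True, gatherResults fails with a FirstError
    # wrapping the first underlying failure; re-raise the original
    # exception instead of the wrapper.
    try:
        results = yield defer.gatherResults(deferreds, consumeErrors=True)
    except defer.FirstError as e:
        e.subFailure.raiseException()
    defer.returnValue(results)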

View File

@@ -26,9 +26,7 @@ from synapse.api.errors import (
 from synapse.util import unwrapFirstError
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.logutils import log_function
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
 from synapse.events import FrozenEvent
-from synapse.types import get_domain_from_id
 import synapse.metrics
 from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
@@ -52,34 +50,7 @@ sent_edus_counter = metrics.register_counter("sent_edus")
 sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"])
-PDU_RETRY_TIME_MS = 1 * 60 * 1000
 class FederationClient(FederationBase):
-    def __init__(self, hs):
-        super(FederationClient, self).__init__(hs)
-        self.pdu_destination_tried = {}
-        self._clock.looping_call(
-            self._clear_tried_cache, 60 * 1000,
-        )
-        self.state = hs.get_state_handler()
-    def _clear_tried_cache(self):
-        """Clear pdu_destination_tried cache"""
-        now = self._clock.time_msec()
-        old_dict = self.pdu_destination_tried
-        self.pdu_destination_tried = {}
-        for event_id, destination_dict in old_dict.items():
-            destination_dict = {
-                dest: time
-                for dest, time in destination_dict.items()
-                if time + PDU_RETRY_TIME_MS > now
-            }
-            if destination_dict:
-                self.pdu_destination_tried[event_id] = destination_dict
     def start_get_pdu_cache(self):
         self._get_pdu_cache = ExpiringCache(
@@ -121,12 +92,8 @@ class FederationClient(FederationBase):
             pdu.event_id
         )
-    def send_presence(self, destination, states):
-        if destination != self.server_name:
-            self._transaction_queue.enqueue_presence(destination, states)
     @log_function
-    def send_edu(self, destination, edu_type, content, key=None):
+    def send_edu(self, destination, edu_type, content):
         edu = Edu(
             origin=self.server_name,
             destination=destination,
@@ -136,13 +103,9 @@ class FederationClient(FederationBase):
         sent_edus_counter.inc()
-        self._transaction_queue.enqueue_edu(edu, key=key)
+        # TODO, add errback, etc.
+        self._transaction_queue.enqueue_edu(edu)
+        return defer.succeed(None)
-    @log_function
-    def send_device_messages(self, destination):
-        """Sends the device messages in the local database to the remote
-        destination"""
-        self._transaction_queue.enqueue_device_messages(destination)
     @log_function
     def send_failure(self, failure, destination):
@@ -173,7 +136,7 @@ class FederationClient(FederationBase):
         )
     @log_function
-    def query_client_keys(self, destination, content, timeout):
+    def query_client_keys(self, destination, content):
         """Query device keys for a device hosted on a remote server.
         Args:
@@ -185,12 +148,10 @@ class FederationClient(FederationBase):
             response
         """
         sent_queries_counter.inc("client_device_keys")
-        return self.transport_layer.query_client_keys(
-            destination, content, timeout
-        )
+        return self.transport_layer.query_client_keys(destination, content)
     @log_function
-    def claim_client_keys(self, destination, content, timeout):
+    def claim_client_keys(self, destination, content):
         """Claims one-time keys for a device hosted on a remote server.
         Args:
@@ -202,9 +163,7 @@ class FederationClient(FederationBase):
             response
         """
         sent_queries_counter.inc("client_one_time_keys")
-        return self.transport_layer.claim_client_keys(
-            destination, content, timeout
-        )
+        return self.transport_layer.claim_client_keys(destination, content)
     @defer.inlineCallbacks
     @log_function
@@ -239,10 +198,10 @@ class FederationClient(FederationBase):
         ]
         # FIXME: We should handle signature failures more gracefully.
-        pdus[:] = yield preserve_context_over_deferred(defer.gatherResults(
+        pdus[:] = yield defer.gatherResults(
             self._check_sigs_and_hashes(pdus),
             consumeErrors=True,
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError)
         defer.returnValue(pdus)
@@ -274,19 +233,12 @@
         # TODO: Rate limit the number of times we try and get the same event.
         if self._get_pdu_cache:
-            ev = self._get_pdu_cache.get(event_id)
-            if ev:
-                defer.returnValue(ev)
+            e = self._get_pdu_cache.get(event_id)
+            if e:
+                defer.returnValue(e)
-        pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
-        signed_pdu = None
+        pdu = None
         for destination in destinations:
-            now = self._clock.time_msec()
-            last_attempt = pdu_attempts.get(destination, 0)
-            if last_attempt + PDU_RETRY_TIME_MS > now:
-                continue
             try:
                 limiter = yield get_retry_limiter(
                     destination,
@@ -310,33 +262,39 @@
                     pdu = pdu_list[0]
                     # Check signatures are correct.
-                    signed_pdu = yield self._check_sigs_and_hashes([pdu])[0]
+                    pdu = yield self._check_sigs_and_hashes([pdu])[0]
                     break
-                pdu_attempts[destination] = now
-            except SynapseError as e:
+            except SynapseError:
                 logger.info(
                     "Failed to get PDU %s from %s because %s",
                     event_id, destination, e,
                 )
-                continue
-            except CodeMessageException as e:
-                if 400 <= e.code < 500:
-                    raise
-                logger.info(
-                    "Failed to get PDU %s from %s because %s",
-                    event_id, destination, e,
-                )
-                continue
             except NotRetryingDestination as e:
                 logger.info(e.message)
                 continue
             except Exception as e:
-                pdu_attempts[destination] = now
                 logger.info(
                     "Failed to get PDU %s from %s because %s",
                     event_id, destination, e,
                 )
                 continue
-        if self._get_pdu_cache is not None and signed_pdu:
-            self._get_pdu_cache[event_id] = signed_pdu
-        defer.returnValue(signed_pdu)
+        if self._get_pdu_cache is not None and pdu:
+            self._get_pdu_cache[event_id] = pdu
+        defer.returnValue(pdu)
     @defer.inlineCallbacks
     @log_function
@@ -353,42 +311,6 @@
             Deferred: Results in a list of PDUs.
         """
-        try:
-            # First we try and ask for just the IDs, as thats far quicker if
-            # we have most of the state and auth_chain already.
-            # However, this may 404 if the other side has an old synapse.
-            result = yield self.transport_layer.get_room_state_ids(
-                destination, room_id, event_id=event_id,
-            )
-            state_event_ids = result["pdu_ids"]
-            auth_event_ids = result.get("auth_chain_ids", [])
-            fetched_events, failed_to_fetch = yield self.get_events(
-                [destination], room_id, set(state_event_ids + auth_event_ids)
-            )
-            if failed_to_fetch:
-                logger.warn("Failed to get %r", failed_to_fetch)
-            event_map = {
-                ev.event_id: ev for ev in fetched_events
-            }
-            pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
-            auth_chain = [
-                event_map[e_id] for e_id in auth_event_ids if e_id in event_map
-            ]
-            auth_chain.sort(key=lambda e: e.depth)
-            defer.returnValue((pdus, auth_chain))
-        except HttpResponseException as e:
-            if e.code == 400 or e.code == 404:
-                logger.info("Failed to use get_room_state_ids API, falling back")
-            else:
-                raise e
         result = yield self.transport_layer.get_room_state(
             destination, room_id, event_id=event_id,
         )
@@ -402,95 +324,18 @@
             for p in result.get("auth_chain", [])
         ]
-        seen_events = yield self.store.get_events([
-            ev.event_id for ev in itertools.chain(pdus, auth_chain)
-        ])
         signed_pdus = yield self._check_sigs_and_hash_and_fetch(
-            destination,
-            [p for p in pdus if p.event_id not in seen_events],
-            outlier=True
-        )
-        signed_pdus.extend(
-            seen_events[p.event_id] for p in pdus if p.event_id in seen_events
+            destination, pdus, outlier=True
         )
         signed_auth = yield self._check_sigs_and_hash_and_fetch(
-            destination,
-            [p for p in auth_chain if p.event_id not in seen_events],
-            outlier=True
-        )
-        signed_auth.extend(
-            seen_events[p.event_id] for p in auth_chain if p.event_id in seen_events
+            destination, auth_chain, outlier=True
         )
         signed_auth.sort(key=lambda e: e.depth)
         defer.returnValue((signed_pdus, signed_auth))
-    @defer.inlineCallbacks
-    def get_events(self, destinations, room_id, event_ids, return_local=True):
-        """Fetch events from some remote destinations, checking if we already
-        have them.
-        Args:
-            destinations (list)
-            room_id (str)
-            event_ids (list)
-            return_local (bool): Whether to include events we already have in
-                the DB in the returned list of events
-        Returns:
-            Deferred: A deferred resolving to a 2-tuple where the first is a list of
-            events and the second is a list of event ids that we failed to fetch.
-        """
-        if return_local:
-            seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
-            signed_events = seen_events.values()
-        else:
-            seen_events = yield self.store.have_events(event_ids)
-            signed_events = []
-        failed_to_fetch = set()
-        missing_events = set(event_ids)
-        for k in seen_events:
-            missing_events.discard(k)
-        if not missing_events:
-            defer.returnValue((signed_events, failed_to_fetch))
-        def random_server_list():
-            srvs = list(destinations)
-            random.shuffle(srvs)
-            return srvs
-        batch_size = 20
-        missing_events = list(missing_events)
-        for i in xrange(0, len(missing_events), batch_size):
-            batch = set(missing_events[i:i + batch_size])
-            deferreds = [
-                preserve_fn(self.get_pdu)(
-                    destinations=random_server_list(),
-                    event_id=e_id,
-                )
-                for e_id in batch
-            ]
-            res = yield preserve_context_over_deferred(
-                defer.DeferredList(deferreds, consumeErrors=True)
-            )
-            for success, result in res:
-                if success and result:
-                    signed_events.append(result)
-                    batch.discard(result.event_id)
-            # We removed all events we successfully fetched from `batch`
-            failed_to_fetch.update(batch)
-        defer.returnValue((signed_events, failed_to_fetch))
     @defer.inlineCallbacks
     @log_function
     def get_event_auth(self, destination, room_id, event_id):
@@ -566,19 +411,14 @@
                     (destination, self.event_from_pdu_json(pdu_dict))
                 )
                 break
-            except CodeMessageException as e:
-                if not 500 <= e.code < 600:
-                    raise
-                else:
-                    logger.warn(
-                        "Failed to make_%s via %s: %s",
-                        membership, destination, e.message
-                    )
+            except CodeMessageException:
+                raise
             except Exception as e:
                 logger.warn(
                     "Failed to make_%s via %s: %s",
                     membership, destination, e.message
                 )
+                raise
         raise RuntimeError("Failed to send to any server.")
@@ -650,14 +490,8 @@
                     "auth_chain": signed_auth,
                     "origin": destination,
                 })
-            except CodeMessageException as e:
-                if not 500 <= e.code < 600:
-                    raise
-                else:
-                    logger.exception(
-                        "Failed to send_join via %s: %s",
-                        destination, e.message
-                    )
+            except CodeMessageException:
+                raise
             except Exception as e:
                 logger.exception(
                     "Failed to send_join via %s: %s",
@@ -716,15 +550,6 @@
         raise RuntimeError("Failed to send to any server.")
-    def get_public_rooms(self, destination, limit=None, since_token=None,
-                         search_filter=None):
-        if destination == self.server_name:
-            return
-        return self.transport_layer.get_public_rooms(
-            destination, limit, since_token, search_filter
-        )
     @defer.inlineCallbacks
     def query_auth(self, destination, room_id, event_id, local_auth):
         """
@@ -814,8 +639,7 @@
         if len(signed_events) >= limit:
             defer.returnValue(signed_events)
-        users = yield self.state.get_current_user_in_room(room_id)
-        servers = set(get_domain_from_id(u) for u in users)
+        servers = yield self.store.get_joined_hosts_for_room(room_id)
         servers = set(servers)
         servers.discard(self.server_name)
@@ -860,16 +684,14 @@
             return srvs
         deferreds = [
-            preserve_fn(self.get_pdu)(
+            self.get_pdu(
                 destinations=random_server_list(),
                 event_id=e_id,
             )
             for e_id, depth in ordered_missing[:limit - len(signed_events)]
         ]
-        res = yield preserve_context_over_deferred(
-            defer.DeferredList(deferreds, consumeErrors=True)
-        )
+        res = yield defer.DeferredList(deferreds, consumeErrors=True)
         for (result, val), (e_id, _) in zip(res, ordered_missing):
             if result and val:
                 signed_events.append(val)
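The deleted get_events above batch-fetches missing events, shuffling the server list per event and recording what could not be retrieved. A simplified synchronous sketch of that strategy — fetch_one is a hypothetical callback, not the Twisted original:

import random

def fetch_in_batches(event_ids, destinations, fetch_one, batch_size=20):
    """fetch_one(servers, event_id) returns an event object or None."""
    fetched, failed = [], set()
    event_ids = list(event_ids)
    for i in range(0, len(event_ids), batch_size):
        batch = set(event_ids[i:i + batch_size])
        for event_id in list(batch):
            servers = list(destinations)
            random.shuffle(servers)  # spread load across servers
            result = fetch_one(servers, event_id)
            if result is not None:
                fetched.append(result)
                batch.discard(event_id)
        # Whatever is left in the batch could not be fetched anywhere.
        failed.update(batch)
    return fetched, failed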

View File

@@ -19,13 +19,11 @@ from twisted.internet import defer
 from .federation_base import FederationBase
 from .units import Transaction, Edu
-from synapse.util.async import Linearizer
 from synapse.util.logutils import log_function
-from synapse.util.caches.response_cache import ResponseCache
 from synapse.events import FrozenEvent
 import synapse.metrics
-from synapse.api.errors import AuthError, FederationError, SynapseError
+from synapse.api.errors import FederationError, SynapseError
 from synapse.crypto.event_signing import compute_event_signature
@@ -46,18 +44,6 @@ received_queries_counter = metrics.register_counter("received_queries", labels=[
 class FederationServer(FederationBase):
-    def __init__(self, hs):
-        super(FederationServer, self).__init__(hs)
-        self.auth = hs.get_auth()
-        self._room_pdu_linearizer = Linearizer()
-        self._server_linearizer = Linearizer()
-        # We cache responses to state queries, as they take a while and often
-        # come in waves.
-        self._state_resp_cache = ResponseCache(hs, timeout_ms=30000)
     def set_handler(self, handler):
         """Sets the handler that the replication layer will use to communicate
         receipt of new PDUs from other home servers. The required methods are
@@ -97,14 +83,11 @@ class FederationServer(FederationBase):
     @defer.inlineCallbacks
     @log_function
     def on_backfill_request(self, origin, room_id, versions, limit):
-        with (yield self._server_linearizer.queue((origin, room_id))):
-            pdus = yield self.handler.on_backfill_request(
-                origin, room_id, versions, limit
-            )
-            res = self._transaction_from_pdus(pdus).get_dict()
-        defer.returnValue((200, res))
+        pdus = yield self.handler.on_backfill_request(
+            origin, room_id, versions, limit
+        )
+        defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
     @defer.inlineCallbacks
     @log_function
@@ -188,64 +171,22 @@ class FederationServer(FederationBase):
         except SynapseError as e:
             logger.info("Failed to handle edu %r: %r", edu_type, e)
         except Exception as e:
-            logger.exception("Failed to handle edu %r", edu_type)
+            logger.exception("Failed to handle edu %r", edu_type, e)
         else:
             logger.warn("Received EDU of type %s with no handler", edu_type)
     @defer.inlineCallbacks
     @log_function
     def on_context_state_request(self, origin, room_id, event_id):
-        if not event_id:
-            raise NotImplementedError("Specify an event")
-        in_room = yield self.auth.check_host_in_room(room_id, origin)
-        if not in_room:
-            raise AuthError(403, "Host not in room.")
-        result = self._state_resp_cache.get((room_id, event_id))
-        if not result:
-            with (yield self._server_linearizer.queue((origin, room_id))):
-                resp = yield self._state_resp_cache.set(
-                    (room_id, event_id),
-                    self._on_context_state_request_compute(room_id, event_id)
-                )
-        else:
-            resp = yield result
-        defer.returnValue((200, resp))
-    @defer.inlineCallbacks
-    def on_state_ids_request(self, origin, room_id, event_id):
-        if not event_id:
-            raise NotImplementedError("Specify an event")
-        in_room = yield self.auth.check_host_in_room(room_id, origin)
-        if not in_room:
-            raise AuthError(403, "Host not in room.")
-        state_ids = yield self.handler.get_state_ids_for_pdu(
-            room_id, event_id,
-        )
-        auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids)
-        defer.returnValue((200, {
-            "pdu_ids": state_ids,
-            "auth_chain_ids": auth_chain_ids,
-        }))
-    @defer.inlineCallbacks
-    def _on_context_state_request_compute(self, room_id, event_id):
-        pdus = yield self.handler.get_state_for_pdu(
-            room_id, event_id,
-        )
-        auth_chain = yield self.store.get_auth_chain(
-            [pdu.event_id for pdu in pdus]
-        )
-        for event in auth_chain:
-            # We sign these again because there was a bug where we
-            # incorrectly signed things the first time round
-            if self.hs.is_mine_id(event.event_id):
+        if event_id:
+            pdus = yield self.handler.get_state_for_pdu(
+                origin, room_id, event_id,
+            )
+            auth_chain = yield self.store.get_auth_chain(
+                [pdu.event_id for pdu in pdus]
+            )
+            for event in auth_chain:
                 event.signatures.update(
                     compute_event_signature(
                         event,
@@ -253,11 +194,13 @@ class FederationServer(FederationBase):
                         self.hs.config.signing_key[0]
                     )
                 )
+        else:
+            raise NotImplementedError("Specify an event")
-        defer.returnValue({
+        defer.returnValue((200, {
             "pdus": [pdu.get_pdu_json() for pdu in pdus],
             "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
-        })
+        }))
     @defer.inlineCallbacks
     @log_function
@@ -331,16 +274,14 @@ class FederationServer(FederationBase):
     @defer.inlineCallbacks
     def on_event_auth(self, origin, room_id, event_id):
-        with (yield self._server_linearizer.queue((origin, room_id))):
-            time_now = self._clock.time_msec()
-            auth_pdus = yield self.handler.on_event_auth(event_id)
-            res = {
-                "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
-            }
-        defer.returnValue((200, res))
+        time_now = self._clock.time_msec()
+        auth_pdus = yield self.handler.on_event_auth(event_id)
+        defer.returnValue((200, {
+            "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
+        }))
     @defer.inlineCallbacks
-    def on_query_auth_request(self, origin, content, room_id, event_id):
+    def on_query_auth_request(self, origin, content, event_id):
         """
         Content is a dict with keys::
             auth_chain (list): A list of events that give the auth chain.
@@ -359,41 +300,58 @@ class FederationServer(FederationBase):
         Returns:
             Deferred: Results in `dict` with the same format as `content`
         """
-        with (yield self._server_linearizer.queue((origin, room_id))):
-            auth_chain = [
-                self.event_from_pdu_json(e)
-                for e in content["auth_chain"]
-            ]
+        auth_chain = [
+            self.event_from_pdu_json(e)
+            for e in content["auth_chain"]
+        ]
         signed_auth = yield self._check_sigs_and_hash_and_fetch(
             origin, auth_chain, outlier=True
         )
         ret = yield self.handler.on_query_auth(
             origin,
             event_id,
             signed_auth,
            content.get("rejects", []),
            content.get("missing", []),
         )
         time_now = self._clock.time_msec()
         send_content = {
             "auth_chain": [
                 e.get_pdu_json(time_now)
                 for e in ret["auth_chain"]
             ],
             "rejects": ret.get("rejects", []),
             "missing": ret.get("missing", []),
         }
         defer.returnValue(
             (200, send_content)
         )
+    @defer.inlineCallbacks
     @log_function
     def on_query_client_keys(self, origin, content):
-        return self.on_query_request("client_keys", content)
+        query = []
+        for user_id, device_ids in content.get("device_keys", {}).items():
+            if not device_ids:
+                query.append((user_id, None))
+            else:
+                for device_id in device_ids:
+                    query.append((user_id, device_id))
+        results = yield self.store.get_e2e_device_keys(query)
+        json_result = {}
+        for user_id, device_keys in results.items():
+            for device_id, json_bytes in device_keys.items():
+                json_result.setdefault(user_id, {})[device_id] = json.loads(
+                    json_bytes
+                )
+        defer.returnValue({"device_keys": json_result})
     @defer.inlineCallbacks
     @log_function
@@ -419,34 +377,16 @@ class FederationServer(FederationBase):
     @log_function
     def on_get_missing_events(self, origin, room_id, earliest_events,
                               latest_events, limit, min_depth):
-        with (yield self._server_linearizer.queue((origin, room_id))):
-            logger.info(
-                "on_get_missing_events: earliest_events: %r, latest_events: %r,"
-                " limit: %d, min_depth: %d",
-                earliest_events, latest_events, limit, min_depth
-            )
-            missing_events = yield self.handler.on_get_missing_events(
-                origin, room_id, earliest_events, latest_events, limit, min_depth
-            )
-            if len(missing_events) < 5:
-                logger.info(
-                    "Returning %d events: %r", len(missing_events), missing_events
-                )
-            else:
-                logger.info("Returning %d events", len(missing_events))
-            time_now = self._clock.time_msec()
+        missing_events = yield self.handler.on_get_missing_events(
+            origin, room_id, earliest_events, latest_events, limit, min_depth
+        )
+        time_now = self._clock.time_msec()
         defer.returnValue({
             "events": [ev.get_pdu_json(time_now) for ev in missing_events],
         })
-    @log_function
-    def on_openid_userinfo(self, token):
-        ts_now_ms = self._clock.time_msec()
-        return self.store.get_user_id_for_open_id_token(token, ts_now_ms)
     @log_function
     def _get_persisted_pdu(self, origin, event_id, do_auth=True):
         """ Get a PDU from the database with given origin and id.
@@ -536,59 +476,42 @@ class FederationServer(FederationBase):
             pdu.internal_metadata.outlier = True
         elif min_depth and pdu.depth > min_depth:
             if get_missing and prevs - seen:
-                # If we're missing stuff, ensure we only fetch stuff one
-                # at a time.
-                with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
-                    # We recalculate seen, since it may have changed.
-                    have_seen = yield self.store.have_events(prevs)
-                    seen = set(have_seen.keys())
-                    if prevs - seen:
-                        latest = yield self.store.get_latest_event_ids_in_room(
-                            pdu.room_id
-                        )
-                        # We add the prev events that we have seen to the latest
-                        # list to ensure the remote server doesn't give them to us
-                        latest = set(latest)
-                        latest |= seen
-                        logger.info(
-                            "Missing %d events for room %r: %r...",
-                            len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
-                        )
-                        missing_events = yield self.get_missing_events(
-                            origin,
-                            pdu.room_id,
-                            earliest_events_ids=list(latest),
-                            latest_events=[pdu],
-                            limit=10,
-                            min_depth=min_depth,
-                        )
-                        # We want to sort these by depth so we process them and
-                        # tell clients about them in order.
-                        missing_events.sort(key=lambda x: x.depth)
-                        for e in missing_events:
-                            yield self._handle_new_pdu(
-                                origin,
-                                e,
-                                get_missing=False
-                            )
-                have_seen = yield self.store.have_events(
-                    [ev for ev, _ in pdu.prev_events]
-                )
+                latest = yield self.store.get_latest_event_ids_in_room(
+                    pdu.room_id
+                )
+                # We add the prev events that we have seen to the latest
+                # list to ensure the remote server doesn't give them to us
+                latest = set(latest)
+                latest |= seen
+                missing_events = yield self.get_missing_events(
+                    origin,
+                    pdu.room_id,
+                    earliest_events_ids=list(latest),
+                    latest_events=[pdu],
+                    limit=10,
+                    min_depth=min_depth,
+                )
+                # We want to sort these by depth so we process them and
+                # tell clients about them in order.
+                missing_events.sort(key=lambda x: x.depth)
+                for e in missing_events:
+                    yield self._handle_new_pdu(
+                        origin,
+                        e,
+                        get_missing=False
+                    )
+                have_seen = yield self.store.have_events(
+                    [ev for ev, _ in pdu.prev_events]
+                )
         prevs = {e_id for e_id, _ in pdu.prev_events}
         seen = set(have_seen.keys())
         if prevs - seen:
-            logger.info(
-                "Still missing %d events for room %r: %r...",
-                len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
-            )
             fetch_state = True
         if fetch_state:
@@ -603,7 +526,7 @@ class FederationServer(FederationBase):
                 origin, pdu.room_id, pdu.event_id,
             )
         except:
-            logger.exception("Failed to get state for event: %s", pdu.event_id)
+            logger.warn("Failed to get state for event: %s", pdu.event_id)
         yield self.handler.on_receive_pdu(
             origin,

View File

@@ -72,7 +72,5 @@ class ReplicationLayer(FederationClient, FederationServer):
         self.hs = hs
-        super(ReplicationLayer, self).__init__(hs)
     def __str__(self):
         return "<ReplicationLayer(%s)>" % self.server_name

View File

@@ -17,16 +17,14 @@
 from twisted.internet import defer
 from .persistence import TransactionActions
-from .units import Transaction, Edu
+from .units import Transaction
 from synapse.api.errors import HttpResponseException
-from synapse.util.async import run_on_reactor
+from synapse.util.logutils import log_function
-from synapse.util.logcontext import preserve_context_over_fn
+from synapse.util.logcontext import PreserveLoggingContext
 from synapse.util.retryutils import (
     get_retry_limiter, NotRetryingDestination,
 )
-from synapse.util.metrics import measure_func
-from synapse.handlers.presence import format_user_presence_state
 import synapse.metrics
 import logging
@@ -52,7 +50,7 @@ class TransactionQueue(object):
         self.transport_layer = transport_layer
-        self.clock = hs.get_clock()
+        self._clock = hs.get_clock()
         # Is a mapping from destinations -> deferreds. Used to keep track
         # of which destinations have transactions in flight and when they are
@@ -70,30 +68,20 @@ class TransactionQueue(object):
         # destination -> list of tuple(edu, deferred)
         self.pending_edus_by_dest = edus = {}
-        # Presence needs to be separate as we send single aggragate EDUs
-        self.pending_presence_by_dest = presence = {}
-        self.pending_edus_keyed_by_dest = edus_keyed = {}
         metrics.register_callback(
             "pending_pdus",
             lambda: sum(map(len, pdus.values())),
         )
         metrics.register_callback(
             "pending_edus",
-            lambda: (
-                sum(map(len, edus.values()))
-                + sum(map(len, presence.values()))
-                + sum(map(len, edus_keyed.values()))
-            ),
+            lambda: sum(map(len, edus.values())),
         )
         # destination -> list of tuple(failure, deferred)
         self.pending_failures_by_dest = {}
-        self.last_device_stream_id_by_dest = {}
         # HACK to get unique tx id
-        self._next_txn_id = int(self.clock.time_msec())
+        self._next_txn_id = int(self._clock.time_msec())
     def can_send_to(self, destination):
         """Can we send messages to the given server?
@@ -130,68 +118,86 @@ class TransactionQueue(object):
         if not destinations:
             return
+        deferreds = []
         for destination in destinations:
+            deferred = defer.Deferred()
             self.pending_pdus_by_dest.setdefault(destination, []).append(
-                (pdu, order)
+                (pdu, deferred, order)
             )
-            preserve_context_over_fn(
-                self._attempt_new_transaction, destination
-            )
-    def enqueue_presence(self, destination, states):
-        self.pending_presence_by_dest.setdefault(destination, {}).update({
-            state.user_id: state for state in states
-        })
-        preserve_context_over_fn(
-            self._attempt_new_transaction, destination
-        )
+            def chain(failure):
+                if not deferred.called:
+                    deferred.errback(failure)
+            def log_failure(f):
+                logger.warn("Failed to send pdu to %s: %s", destination, f.value)
+            deferred.addErrback(log_failure)
+            with PreserveLoggingContext():
+                self._attempt_new_transaction(destination).addErrback(chain)
+            deferreds.append(deferred)
+    # NO inlineCallbacks
-    def enqueue_edu(self, edu, key=None):
+    def enqueue_edu(self, edu):
         destination = edu.destination
         if not self.can_send_to(destination):
             return
-        if key:
-            self.pending_edus_keyed_by_dest.setdefault(
-                destination, {}
-            )[(edu.edu_type, key)] = edu
-        else:
-            self.pending_edus_by_dest.setdefault(destination, []).append(edu)
-        preserve_context_over_fn(
-            self._attempt_new_transaction, destination
-        )
+        deferred = defer.Deferred()
+        self.pending_edus_by_dest.setdefault(destination, []).append(
+            (edu, deferred)
+        )
+        def chain(failure):
+            if not deferred.called:
+                deferred.errback(failure)
+        def log_failure(f):
+            logger.warn("Failed to send edu to %s: %s", destination, f.value)
+        deferred.addErrback(log_failure)
+        with PreserveLoggingContext():
+            self._attempt_new_transaction(destination).addErrback(chain)
+        return deferred
+    @defer.inlineCallbacks
     def enqueue_failure(self, failure, destination):
         if destination == self.server_name or destination == "localhost":
             return
+        deferred = defer.Deferred()
         if not self.can_send_to(destination):
             return
         self.pending_failures_by_dest.setdefault(
             destination, []
-        ).append(failure)
-        preserve_context_over_fn(
-            self._attempt_new_transaction, destination
-        )
-    def enqueue_device_messages(self, destination):
-        if destination == self.server_name or destination == "localhost":
-            return
-        if not self.can_send_to(destination):
-            return
-        preserve_context_over_fn(
-            self._attempt_new_transaction, destination
-        )
+        ).append(
+            (failure, deferred)
+        )
+        def chain(f):
+            if not deferred.called:
+                deferred.errback(f)
+        def log_failure(f):
+            logger.warn("Failed to send failure to %s: %s", destination, f.value)
+        deferred.addErrback(log_failure)
+        with PreserveLoggingContext():
+            self._attempt_new_transaction(destination).addErrback(chain)
+        yield deferred
     @defer.inlineCallbacks
+    @log_function
     def _attempt_new_transaction(self, destination):
         # list of (pending_pdu, deferred, order)
         if destination in self.pending_transactions:
@@ -205,128 +211,55 @@ class TransactionQueue(object):
) )
return return
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, [])
pending_failures = self.pending_failures_by_dest.pop(destination, [])
if pending_pdus:
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
destination, len(pending_pdus))
if not pending_pdus and not pending_edus and not pending_failures:
logger.debug("TX [%s] Nothing to send", destination)
return
try: try:
self.pending_transactions[destination] = 1 self.pending_transactions[destination] = 1
yield run_on_reactor()
while True:
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, [])
pending_presence = self.pending_presence_by_dest.pop(destination, {})
pending_failures = self.pending_failures_by_dest.pop(destination, [])
pending_edus.extend(
self.pending_edus_keyed_by_dest.pop(destination, {}).values()
)
limiter = yield get_retry_limiter(
destination,
self.clock,
self.store,
)
device_message_edus, device_stream_id = (
yield self._get_new_device_messages(destination)
)
pending_edus.extend(device_message_edus)
if pending_presence:
pending_edus.append(
Edu(
origin=self.server_name,
destination=destination,
edu_type="m.presence",
content={
"push": [
format_user_presence_state(
presence, self.clock.time_msec()
)
for presence in pending_presence.values()
]
},
)
)
if pending_pdus:
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
destination, len(pending_pdus))
if not pending_pdus and not pending_edus and not pending_failures:
logger.debug("TX [%s] Nothing to send", destination)
self.last_device_stream_id_by_dest[destination] = (
device_stream_id
)
return
success = yield self._send_new_transaction(
destination, pending_pdus, pending_edus, pending_failures,
device_stream_id,
should_delete_from_device_stream=bool(device_message_edus),
limiter=limiter,
)
if not success:
break
except NotRetryingDestination:
logger.info(
"TX [%s] not ready for retry yet - "
"dropping transaction for now",
destination,
)
finally:
# We want to be *very* sure we delete this after we stop processing
self.pending_transactions.pop(destination, None)
@defer.inlineCallbacks
def _get_new_device_messages(self, destination):
last_device_stream_id = self.last_device_stream_id_by_dest.get(destination, 0)
to_device_stream_id = self.store.get_to_device_stream_token()
contents, stream_id = yield self.store.get_new_device_msgs_for_remote(
destination, last_device_stream_id, to_device_stream_id
)
edus = [
Edu(
origin=self.server_name,
destination=destination,
edu_type="m.direct_to_device",
content=content,
)
for content in contents
]
defer.returnValue((edus, stream_id))
@measure_func("_send_new_transaction")
@defer.inlineCallbacks
def _send_new_transaction(self, destination, pending_pdus, pending_edus,
pending_failures, device_stream_id,
should_delete_from_device_stream, limiter):
# Sort based on the order field
pending_pdus.sort(key=lambda t: t[1])
pdus = [x[0] for x in pending_pdus]
edus = pending_edus
failures = [x.get_dict() for x in pending_failures]
success = True
try:
logger.debug("TX [%s] _attempt_new_transaction", destination) logger.debug("TX [%s] _attempt_new_transaction", destination)
# Sort based on the order field
pending_pdus.sort(key=lambda t: t[2])
pdus = [x[0] for x in pending_pdus]
edus = [x[0] for x in pending_edus]
failures = [x[0].get_dict() for x in pending_failures]
deferreds = [
x[1]
for x in pending_pdus + pending_edus + pending_failures
]
txn_id = str(self._next_txn_id) txn_id = str(self._next_txn_id)
limiter = yield get_retry_limiter(
destination,
self._clock,
self.store,
)
logger.debug( logger.debug(
"TX [%s] {%s} Attempting new transaction" "TX [%s] {%s} Attempting new transaction"
" (pdus: %d, edus: %d, failures: %d)", " (pdus: %d, edus: %d, failures: %d)",
destination, txn_id, destination, txn_id,
len(pdus), len(pending_pdus),
len(edus), len(pending_edus),
len(failures) len(pending_failures)
) )
logger.debug("TX [%s] Persisting transaction...", destination) logger.debug("TX [%s] Persisting transaction...", destination)
transaction = Transaction.create_new( transaction = Transaction.create_new(
origin_server_ts=int(self.clock.time_msec()), origin_server_ts=int(self._clock.time_msec()),
transaction_id=txn_id, transaction_id=txn_id,
origin=self.server_name, origin=self.server_name,
destination=destination, destination=destination,
@@ -345,9 +278,9 @@ class TransactionQueue(object):
" (PDUs: %d, EDUs: %d, failures: %d)", " (PDUs: %d, EDUs: %d, failures: %d)",
destination, txn_id, destination, txn_id,
transaction.transaction_id, transaction.transaction_id,
len(pdus), len(pending_pdus),
len(edus), len(pending_edus),
len(failures), len(pending_failures),
) )
with limiter: with limiter:
@@ -357,7 +290,7 @@ class TransactionQueue(object):
# keys work # keys work
def json_data_cb(): def json_data_cb():
data = transaction.get_dict() data = transaction.get_dict()
now = int(self.clock.time_msec()) now = int(self._clock.time_msec())
if "pdus" in data: if "pdus" in data:
for p in data["pdus"]: for p in data["pdus"]:
if "age_ts" in p: if "age_ts" in p:
@@ -397,19 +330,28 @@ class TransactionQueue(object):
logger.debug("TX [%s] Marked as delivered", destination) logger.debug("TX [%s] Marked as delivered", destination)
if code != 200: logger.debug("TX [%s] Yielding to callbacks...", destination)
for p in pdus:
logger.info( for deferred in deferreds:
"Failed to send event %s to %s", p.event_id, destination if code == 200:
) deferred.callback(None)
success = False else:
else: deferred.errback(RuntimeError("Got status %d" % code))
# Remove the acknowledged device messages from the database
if should_delete_from_device_stream: # Ensures we don't continue until all callbacks on that
yield self.store.delete_device_msgs_for_remote( # deferred have fired
destination, device_stream_id try:
) yield deferred
self.last_device_stream_id_by_dest[destination] = device_stream_id except:
pass
logger.debug("TX [%s] Yielded to callbacks", destination)
except NotRetryingDestination:
logger.info(
"TX [%s] not ready for retry yet - "
"dropping transaction for now",
destination,
)
except RuntimeError as e: except RuntimeError as e:
# We capture this here as there as nothing actually listens # We capture this here as there as nothing actually listens
# for this finishing functions deferred. # for this finishing functions deferred.
@@ -418,11 +360,6 @@ class TransactionQueue(object):
destination, destination,
e, e,
) )
success = False
for p in pdus:
logger.info("Failed to send event %s to %s", p.event_id, destination)
except Exception as e: except Exception as e:
# We capture this here as there as nothing actually listens # We capture this here as there as nothing actually listens
# for this finishing functions deferred. # for this finishing functions deferred.
@@ -432,9 +369,13 @@ class TransactionQueue(object):
e, e,
) )
success = False for deferred in deferreds:
if not deferred.called:
deferred.errback(e)
for p in pdus: finally:
logger.info("Failed to send event %s to %s", p.event_id, destination) # We want to be *very* sure we delete this after we stop processing
self.pending_transactions.pop(destination, None)
defer.returnValue(success) # Check to see if there is anything else to send.
self._attempt_new_transaction(destination)
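Reviewer note: the `key` argument that the newer code (the `-` side of this compare) adds to `enqueue_edu` is what implements EDU clobbering (PR #1095 in the changelog): EDUs queued under the same `(edu_type, key)` replace one another, so only the latest typing or presence update per key is ever sent. A minimal, self-contained sketch of that idea, assuming illustrative stand-ins (`Edu` here is a plain namedtuple and `KeyedEduQueue` is not a real Synapse class):

```python
# Sketch of keyed-EDU clobbering; Edu is a stand-in for synapse's units.Edu.
from collections import namedtuple

Edu = namedtuple("Edu", ["origin", "destination", "edu_type", "content"])


class KeyedEduQueue(object):
    def __init__(self):
        # destination -> list of unkeyed EDUs, sent in arrival order
        self.pending_edus_by_dest = {}
        # destination -> {(edu_type, key): edu}; a newer EDU with the same
        # key silently replaces ("clobbers") the older one
        self.pending_edus_keyed_by_dest = {}

    def enqueue_edu(self, edu, key=None):
        if key:
            self.pending_edus_keyed_by_dest.setdefault(
                edu.destination, {}
            )[(edu.edu_type, key)] = edu
        else:
            self.pending_edus_by_dest.setdefault(edu.destination, []).append(edu)

    def drain(self, destination):
        # mirrors the pop() + extend() in _attempt_new_transaction above
        pending = self.pending_edus_by_dest.pop(destination, [])
        pending.extend(
            self.pending_edus_keyed_by_dest.pop(destination, {}).values()
        )
        return pending


queue = KeyedEduQueue()
# Two typing notifications for the same room: only the second survives.
queue.enqueue_edu(
    Edu("a.example", "b.example", "m.typing", {"typing": True}), key="!room:a"
)
queue.enqueue_edu(
    Edu("a.example", "b.example", "m.typing", {"typing": False}), key="!room:a"
)
assert len(queue.drain("b.example")) == 1
```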

View File

@@ -54,28 +54,6 @@ class TransportLayerClient(object):
             destination, path=path, args={"event_id": event_id},
         )

-    @log_function
-    def get_room_state_ids(self, destination, room_id, event_id):
-        """ Requests all state for a given room from the given server at the
-        given event. Returns the state's event_id's
-
-        Args:
-            destination (str): The host name of the remote home server we want
-                to get the state from.
-            context (str): The name of the context we want the state of
-            event_id (str): The event we want the context at.
-
-        Returns:
-            Deferred: Results in a dict received from the remote homeserver.
-        """
-        logger.debug("get_room_state_ids dest=%s, room=%s",
-                     destination, room_id)
-
-        path = PREFIX + "/state_ids/%s/" % room_id
-        return self.client.get_json(
-            destination, path=path, args={"event_id": event_id},
-        )
-
     @log_function
     def get_event(self, destination, event_id, timeout=None):
         """ Requests the pdu with give id and origin from the given server.
@@ -201,8 +179,7 @@ class TransportLayerClient(object):
         content = yield self.client.get_json(
             destination=destination,
             path=path,
-            retry_on_dns_fail=False,
-            timeout=20000,
+            retry_on_dns_fail=True,
         )

         defer.returnValue(content)
@@ -246,28 +223,6 @@ class TransportLayerClient(object):
         defer.returnValue(response)

-    @defer.inlineCallbacks
-    @log_function
-    def get_public_rooms(self, remote_server, limit, since_token,
-                         search_filter=None):
-        path = PREFIX + "/publicRooms"
-
-        args = {}
-        if limit:
-            args["limit"] = [str(limit)]
-        if since_token:
-            args["since"] = [since_token]
-
-        # TODO(erikj): Actually send the search_filter across federation.
-
-        response = yield self.client.get_json(
-            destination=remote_server,
-            path=path,
-            args=args,
-        )
-
-        defer.returnValue(response)
-
     @defer.inlineCallbacks
     @log_function
     def exchange_third_party_invite(self, destination, room_id, event_dict):
@@ -308,7 +263,7 @@ class TransportLayerClient(object):
     @defer.inlineCallbacks
     @log_function
-    def query_client_keys(self, destination, query_content, timeout):
+    def query_client_keys(self, destination, query_content):
         """Query the device keys for a list of user ids hosted on a remote
         server.
@@ -337,13 +292,12 @@ class TransportLayerClient(object):
             destination=destination,
             path=path,
             data=query_content,
-            timeout=timeout,
         )
         defer.returnValue(content)

     @defer.inlineCallbacks
     @log_function
-    def claim_client_keys(self, destination, query_content, timeout):
+    def claim_client_keys(self, destination, query_content):
         """Claim one-time keys for a list of devices hosted on a remote server.

         Request:
@@ -374,7 +328,6 @@ class TransportLayerClient(object):
             destination=destination,
             path=path,
             data=query_content,
-            timeout=timeout,
        )
         defer.returnValue(content)
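Reviewer note: the `get_public_rooms` helper removed on the `-` side pages through a remote server's room directory with `limit`/`since` query parameters (PRs #1082 and #1121 in the changelog). A rough sketch of how those arguments become a federation request URL; `build_public_rooms_request` is a hypothetical helper, whereas the real client hands the `args` dict straight to its HTTP layer:

```python
# Sketch of the paginated federation /publicRooms request shape.
try:  # Python 2, as in Synapse at the time
    from urllib import urlencode
except ImportError:  # Python 3
    from urllib.parse import urlencode

PREFIX = "/_matrix/federation/v1"


def build_public_rooms_request(remote_server, limit=None, since_token=None):
    path = PREFIX + "/publicRooms"
    # Values are lists, mirroring how Twisted exposes request.args server-side
    args = {}
    if limit:
        args["limit"] = [str(limit)]
    if since_token:
        args["since"] = [since_token]
    query = urlencode([(k, v) for k, vs in sorted(args.items()) for v in vs])
    return "https://%s%s?%s" % (remote_server, path, query)


# First page of 20 rooms, then the next page using the returned token.
print(build_public_rooms_request("example.org", limit=20))
print(build_public_rooms_request("example.org", limit=20, since_token="t123"))
```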

View File

@@ -18,16 +18,13 @@ from twisted.internet import defer
 from synapse.api.urls import FEDERATION_PREFIX as PREFIX
 from synapse.api.errors import Codes, SynapseError
 from synapse.http.server import JsonResource
-from synapse.http.servlet import (
-    parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
-)
+from synapse.http.servlet import parse_json_object_from_request
 from synapse.util.ratelimitutils import FederationRateLimiter
-from synapse.util.versionstring import get_version_string

 import functools
 import logging
+import simplejson as json
 import re
-import synapse

 logger = logging.getLogger(__name__)
@@ -40,7 +37,7 @@ class TransportLayerServer(JsonResource):
         self.hs = hs
         self.clock = hs.get_clock()

-        super(TransportLayerServer, self).__init__(hs, canonical_json=False)
+        super(TransportLayerServer, self).__init__(hs)

         self.authenticator = Authenticator(hs)
         self.ratelimiter = FederationRateLimiter(
@@ -63,16 +60,6 @@ class TransportLayerServer(JsonResource):
         )

-class AuthenticationError(SynapseError):
-    """There was a problem authenticating the request"""
-    pass
-
-
-class NoAuthenticationError(AuthenticationError):
-    """The request had no authentication information"""
-    pass
-
-
 class Authenticator(object):
     def __init__(self, hs):
         self.keyring = hs.get_keyring()
@@ -80,7 +67,7 @@ class Authenticator(object):
     # A method just so we can pass 'self' as the authenticator to the Servlets
     @defer.inlineCallbacks
-    def authenticate_request(self, request, content):
+    def authenticate_request(self, request):
         json_request = {
             "method": request.method,
             "uri": request.uri,
@@ -88,11 +75,18 @@ class Authenticator(object):
             "signatures": {},
         }

-        if content is not None:
-            json_request["content"] = content
-
+        content = None
         origin = None

+        if request.method in ["PUT", "POST"]:
+            # TODO: Handle other method types? other content types?
+            try:
+                content_bytes = request.content.read()
+                content = json.loads(content_bytes)
+                json_request["content"] = content
+            except:
+                raise SynapseError(400, "Unable to parse JSON", Codes.BAD_JSON)
+
         def parse_auth_header(header_str):
             try:
                 params = auth.split(" ")[1].split(",")
@@ -109,14 +103,14 @@ class Authenticator(object):
                 sig = strip_quotes(param_dict["sig"])
                 return (origin, key, sig)
             except:
-                raise AuthenticationError(
+                raise SynapseError(
                     400, "Malformed Authorization header", Codes.UNAUTHORIZED
                 )

         auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")

         if not auth_headers:
-            raise NoAuthenticationError(
+            raise SynapseError(
                 401, "Missing Authorization headers", Codes.UNAUTHORIZED,
             )
@@ -127,7 +121,7 @@ class Authenticator(object):
             json_request["signatures"].setdefault(origin, {})[key] = sig

         if not json_request["signatures"]:
-            raise NoAuthenticationError(
+            raise SynapseError(
                 401, "Missing Authorization headers", Codes.UNAUTHORIZED,
             )
@@ -136,59 +130,38 @@ class Authenticator(object):
         logger.info("Request from %s", origin)
         request.authenticated_entity = origin

-        defer.returnValue(origin)
+        defer.returnValue((origin, content))


 class BaseFederationServlet(object):
-    REQUIRE_AUTH = True
-
-    def __init__(self, handler, authenticator, ratelimiter, server_name,
-                 room_list_handler):
+    def __init__(self, handler, authenticator, ratelimiter, server_name):
         self.handler = handler
         self.authenticator = authenticator
         self.ratelimiter = ratelimiter
-        self.room_list_handler = room_list_handler

-    def _wrap(self, func):
+    def _wrap(self, code):
         authenticator = self.authenticator
         ratelimiter = self.ratelimiter

         @defer.inlineCallbacks
-        @functools.wraps(func)
-        def new_func(request, *args, **kwargs):
-            content = None
-            if request.method in ["PUT", "POST"]:
-                # TODO: Handle other method types? other content types?
-                content = parse_json_object_from_request(request)
-
+        @functools.wraps(code)
+        def new_code(request, *args, **kwargs):
             try:
-                origin = yield authenticator.authenticate_request(request, content)
-            except NoAuthenticationError:
-                origin = None
-                if self.REQUIRE_AUTH:
-                    logger.exception("authenticate_request failed")
-                    raise
+                (origin, content) = yield authenticator.authenticate_request(request)
+                with ratelimiter.ratelimit(origin) as d:
+                    yield d
+                    response = yield code(
+                        origin, content, request.args, *args, **kwargs
+                    )
             except:
                 logger.exception("authenticate_request failed")
                 raise
-
-            if origin:
-                with ratelimiter.ratelimit(origin) as d:
-                    yield d
-                    response = yield func(
-                        origin, content, request.args, *args, **kwargs
-                    )
-            else:
-                response = yield func(
-                    origin, content, request.args, *args, **kwargs
-                )
             defer.returnValue(response)

         # Extra logic that functools.wraps() doesn't finish
-        new_func.__self__ = func.__self__
+        new_code.__self__ = code.__self__

-        return new_func
+        return new_code

     def register(self, server):
         pattern = re.compile("^" + PREFIX + self.PATH + "$")
@@ -296,17 +269,6 @@ class FederationStateServlet(BaseFederationServlet):
         )

-class FederationStateIdsServlet(BaseFederationServlet):
-    PATH = "/state_ids/(?P<room_id>[^/]*)/"
-
-    def on_GET(self, origin, content, query, room_id):
-        return self.handler.on_state_ids_request(
-            origin,
-            room_id,
-            query.get("event_id", [None])[0],
-        )
-
-
 class FederationBackfillServlet(BaseFederationServlet):
     PATH = "/backfill/(?P<context>[^/]*)/"
@@ -361,7 +323,7 @@ class FederationSendLeaveServlet(BaseFederationServlet):

 class FederationEventAuthServlet(BaseFederationServlet):
-    PATH = "/event_auth/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
+    PATH = "/event_auth(?P<context>[^/]*)/(?P<event_id>[^/]*)"

     def on_GET(self, origin, content, query, context, event_id):
         return self.handler.on_event_auth(origin, context, event_id)
@@ -403,8 +365,10 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
 class FederationClientKeysQueryServlet(BaseFederationServlet):
     PATH = "/user/keys/query"

+    @defer.inlineCallbacks
     def on_POST(self, origin, content, query):
-        return self.handler.on_query_client_keys(origin, content)
+        response = yield self.handler.on_query_client_keys(origin, content)
+        defer.returnValue((200, response))


 class FederationClientKeysClaimServlet(BaseFederationServlet):
@@ -422,7 +386,7 @@ class FederationQueryAuthServlet(BaseFederationServlet):
     @defer.inlineCallbacks
     def on_POST(self, origin, content, query, context, event_id):
         new_content = yield self.handler.on_query_auth_request(
-            origin, content, context, event_id
+            origin, content, event_id
         )

         defer.returnValue((200, new_content))
@@ -454,10 +418,9 @@ class FederationGetMissingEventsServlet(BaseFederationServlet):
 class On3pidBindServlet(BaseFederationServlet):
     PATH = "/3pid/onbind"

-    REQUIRE_AUTH = False
-
     @defer.inlineCallbacks
-    def on_POST(self, origin, content, query):
+    def on_POST(self, request):
+        content = parse_json_object_from_request(request)
         if "invites" in content:
             last_exception = None
             for invite in content["invites"]:
@@ -479,103 +442,10 @@ class On3pidBindServlet(BaseFederationServlet):
                 raise last_exception
         defer.returnValue((200, {}))

-
-class OpenIdUserInfo(BaseFederationServlet):
-    """
-    Exchange a bearer token for information about a user.
-
-    The response format should be compatible with:
-        http://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse
-
-    GET /openid/userinfo?access_token=ABDEFGH HTTP/1.1
-
-    HTTP/1.1 200 OK
-    Content-Type: application/json
-
-    {
-        "sub": "@userpart:example.org",
-    }
-    """
-
-    PATH = "/openid/userinfo"
-
-    REQUIRE_AUTH = False
-
-    @defer.inlineCallbacks
-    def on_GET(self, origin, content, query):
-        token = query.get("access_token", [None])[0]
-        if token is None:
-            defer.returnValue((401, {
-                "errcode": "M_MISSING_TOKEN", "error": "Access Token required"
-            }))
-            return
-
-        user_id = yield self.handler.on_openid_userinfo(token)
-
-        if user_id is None:
-            defer.returnValue((401, {
-                "errcode": "M_UNKNOWN_TOKEN",
-                "error": "Access Token unknown or expired"
-            }))
-
-        defer.returnValue((200, {"sub": user_id}))
-
-
-class PublicRoomList(BaseFederationServlet):
-    """
-    Fetch the public room list for this server.
-
-    This API returns information in the same format as /publicRooms on the
-    client API, but will only ever include local public rooms and hence is
-    intended for consumption by other home servers.
-
-    GET /publicRooms HTTP/1.1
-
-    HTTP/1.1 200 OK
-    Content-Type: application/json
-
-    {
-        "chunk": [
-            {
-                "aliases": [
-                    "#test:localhost"
-                ],
-                "guest_can_join": false,
-                "name": "test room",
-                "num_joined_members": 3,
-                "room_id": "!whkydVegtvatLfXmPN:localhost",
-                "world_readable": false
-            }
-        ],
-        "end": "END",
-        "start": "START"
-    }
-    """
-
-    PATH = "/publicRooms"
-
-    @defer.inlineCallbacks
-    def on_GET(self, origin, content, query):
-        limit = parse_integer_from_args(query, "limit", 0)
-        since_token = parse_string_from_args(query, "since", None)
-        data = yield self.room_list_handler.get_local_public_room_list(
-            limit, since_token
-        )
-        defer.returnValue((200, data))
-
-
-class FederationVersionServlet(BaseFederationServlet):
-    PATH = "/version"
-
-    REQUIRE_AUTH = False
-
-    def on_GET(self, origin, content, query):
-        return defer.succeed((200, {
-            "server": {
-                "name": "Synapse",
-                "version": get_version_string(synapse)
-            },
-        }))
+    # Avoid doing remote HS authorization checks which are done by default by
+    # BaseFederationServlet.
+    def _wrap(self, code):
+        return code
@@ -583,7 +453,6 @@ SERVLET_CLASSES = (
     FederationPullServlet,
     FederationEventServlet,
     FederationStateServlet,
-    FederationStateIdsServlet,
     FederationBackfillServlet,
     FederationQueryServlet,
     FederationMakeJoinServlet,
@@ -599,9 +468,6 @@ SERVLET_CLASSES = (
     FederationClientKeysClaimServlet,
     FederationThirdPartyInviteExchangeServlet,
     On3pidBindServlet,
-    OpenIdUserInfo,
-    PublicRoomList,
-    FederationVersionServlet,
 )
@@ -612,5 +478,4 @@ def register_servlets(hs, resource, authenticator, ratelimiter):
         authenticator=authenticator,
         ratelimiter=ratelimiter,
         server_name=hs.hostname,
-        room_list_handler=hs.get_room_list_handler(),
     ).register(resource)
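Reviewer note: the `REQUIRE_AUTH` flag dropped in this file is how servlets such as `/version` and `/openid/userinfo` opt out of federation authentication while every other servlet keeps it by default; the older code instead overrides `_wrap` entirely. A stripped-down sketch of the flag-based pattern, where `authenticate` is an illustrative stand-in for `Authenticator.authenticate_request` (no Twisted, so it stays runnable on its own):

```python
# Sketch of the REQUIRE_AUTH opt-out pattern for federation servlets.
import functools


class NoAuthenticationError(Exception):
    """The request had no authentication information"""


def authenticate(request):
    # Stand-in: a real implementation would verify request signatures.
    origin = request.get("origin")
    if origin is None:
        raise NoAuthenticationError()
    return origin


class BaseServlet(object):
    REQUIRE_AUTH = True

    def _wrap(self, func):
        @functools.wraps(func)
        def new_func(request, *args, **kwargs):
            try:
                origin = authenticate(request)
            except NoAuthenticationError:
                # Unauthenticated callers are only allowed through if the
                # servlet explicitly opted out of auth.
                origin = None
                if self.REQUIRE_AUTH:
                    raise
            return func(origin, request, *args, **kwargs)
        return new_func


class VersionServlet(BaseServlet):
    REQUIRE_AUTH = False  # e.g. /version is world-readable

    def on_GET(self, origin, request):
        return (200, {"server": {"name": "Synapse"}})


servlet = VersionServlet()
handler = servlet._wrap(servlet.on_GET)
print(handler({}))  # no Authorization header, still served
```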

View File

@@ -13,16 +13,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from synapse.appservice.scheduler import AppServiceScheduler
+from synapse.appservice.api import ApplicationServiceApi
 from .register import RegistrationHandler
 from .room import (
-    RoomCreationHandler, RoomContextHandler,
+    RoomCreationHandler, RoomListHandler, RoomContextHandler,
 )
 from .room_member import RoomMemberHandler
 from .message import MessageHandler
+from .events import EventStreamHandler, EventHandler
 from .federation import FederationHandler
 from .profile import ProfileHandler
+from .presence import PresenceHandler
 from .directory import DirectoryHandler
+from .typing import TypingNotificationHandler
 from .admin import AdminHandler
+from .appservice import ApplicationServicesHandler
+from .sync import SyncHandler
+from .auth import AuthHandler
 from .identity import IdentityHandler
 from .receipts import ReceiptsHandler
 from .search import SearchHandler
@@ -30,21 +38,10 @@ from .search import SearchHandler

 class Handlers(object):

-    """ Deprecated. A collection of handlers.
+    """ A collection of all the event handlers.

-    At some point most of the classes whose name ended "Handler" were
-    accessed through this class.
-
-    However this makes it painful to unit test the handlers and to run cut
-    down versions of synapse that only use specific handlers because using a
-    single handler required creating all of the handlers. So some of the
-    handlers have been lifted out of the Handlers object and are now accessed
-    directly through the homeserver object itself.
-
-    Any new handlers should follow the new pattern of being accessed through
-    the homeserver object and should not be added to the Handlers object.
-
-    The remaining handlers should be moved out of the handlers object.
+    There's no need to lazily create these; we'll just make them all eagerly
+    at construction time.
     """

     def __init__(self, hs):
@@ -52,11 +49,26 @@ class Handlers(object):
         self.message_handler = MessageHandler(hs)
         self.room_creation_handler = RoomCreationHandler(hs)
         self.room_member_handler = RoomMemberHandler(hs)
+        self.event_stream_handler = EventStreamHandler(hs)
+        self.event_handler = EventHandler(hs)
         self.federation_handler = FederationHandler(hs)
         self.profile_handler = ProfileHandler(hs)
+        self.presence_handler = PresenceHandler(hs)
+        self.room_list_handler = RoomListHandler(hs)
         self.directory_handler = DirectoryHandler(hs)
+        self.typing_notification_handler = TypingNotificationHandler(hs)
         self.admin_handler = AdminHandler(hs)
         self.receipts_handler = ReceiptsHandler(hs)
+        asapi = ApplicationServiceApi(hs)
+        self.appservice_handler = ApplicationServicesHandler(
+            hs, asapi, AppServiceScheduler(
+                clock=hs.get_clock(),
+                store=hs.get_datastore(),
+                as_api=asapi
+            )
+        )
+        self.sync_handler = SyncHandler(hs)
+        self.auth_handler = AuthHandler(hs)
         self.identity_handler = IdentityHandler(hs)
         self.search_handler = SearchHandler(hs)
         self.room_context_handler = RoomContextHandler(hs)
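Reviewer note: the deprecation notice on the `-` side describes a migration away from this eagerly-constructed `Handlers` collection toward per-handler accessors on the homeserver object, so tests and cut-down deployments only build what they use. A toy illustration of that accessor pattern; the class and method names here are illustrative, not Synapse's real ones:

```python
# Sketch of lazy per-handler accessors on a homeserver-like object.
class HomeServer(object):
    def __init__(self):
        self._presence_handler = None

    def get_presence_handler(self):
        # Built on first use, and only if actually used.
        if self._presence_handler is None:
            self._presence_handler = PresenceHandler(self)
        return self._presence_handler


class PresenceHandler(object):
    def __init__(self, hs):
        self.hs = hs


hs = HomeServer()
assert hs.get_presence_handler() is hs.get_presence_handler()
```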

View File

@@ -13,33 +13,49 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import logging
-
 from twisted.internet import defer

-import synapse.types
+from synapse.api.errors import LimitExceededError, SynapseError, AuthError
+from synapse.crypto.event_signing import add_hashes_and_signatures
 from synapse.api.constants import Membership, EventTypes
-from synapse.api.errors import LimitExceededError
-from synapse.types import UserID
+from synapse.types import UserID, RoomAlias, Requester
+from synapse.push.action_generator import ActionGenerator
+from synapse.util.logcontext import PreserveLoggingContext
+
+import logging

 logger = logging.getLogger(__name__)

+
+VISIBILITY_PRIORITY = (
+    "world_readable",
+    "shared",
+    "invited",
+    "joined",
+)
+
+
+MEMBERSHIP_PRIORITY = (
+    Membership.JOIN,
+    Membership.INVITE,
+    Membership.KNOCK,
+    Membership.LEAVE,
+    Membership.BAN,
+)
+

 class BaseHandler(object):
     """
     Common base class for the event handlers.

     Attributes:
-        store (synapse.storage.DataStore):
+        store (synapse.storage.events.StateStore):
         state_handler (synapse.state.StateHandler):
     """

     def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer):
-        """
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.notifier = hs.get_notifier()
@@ -49,10 +65,161 @@ class BaseHandler(object):
         self.clock = hs.get_clock()
         self.hs = hs

+        self.signing_key = hs.config.signing_key[0]
         self.server_name = hs.hostname

         self.event_builder_factory = hs.get_event_builder_factory()

+    @defer.inlineCallbacks
+    def filter_events_for_clients(self, user_tuples, events, event_id_to_state):
+        """ Returns dict of user_id -> list of events that user is allowed to
+        see.
+
+        Args:
+            user_tuples (str, bool): (user id, is_peeking) for each user to be
+                checked. is_peeking should be true if:
+                * the user is not currently a member of the room, and:
+                * the user has not been a member of the room since the
+                given events
+            events ([synapse.events.EventBase]): list of events to filter
+        """
+        forgotten = yield defer.gatherResults([
+            self.store.who_forgot_in_room(
+                room_id,
+            )
+            for room_id in frozenset(e.room_id for e in events)
+        ], consumeErrors=True)
+
+        # Set of membership event_ids that have been forgotten
+        event_id_forgotten = frozenset(
+            row["event_id"] for rows in forgotten for row in rows
+        )
+
+        def allowed(event, user_id, is_peeking):
+            """
+            Args:
+                event (synapse.events.EventBase): event to check
+                user_id (str)
+                is_peeking (bool)
+            """
+            state = event_id_to_state[event.event_id]
+
+            # get the room_visibility at the time of the event.
+            visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
+            if visibility_event:
+                visibility = visibility_event.content.get("history_visibility", "shared")
+            else:
+                visibility = "shared"
+
+            if visibility not in VISIBILITY_PRIORITY:
+                visibility = "shared"
+
+            # if it was world_readable, it's easy: everyone can read it
+            if visibility == "world_readable":
+                return True
+
+            # Always allow history visibility events on boundaries. This is done
+            # by setting the effective visibility to the least restrictive
+            # of the old vs new.
+            if event.type == EventTypes.RoomHistoryVisibility:
+                prev_content = event.unsigned.get("prev_content", {})
+                prev_visibility = prev_content.get("history_visibility", None)
+
+                if prev_visibility not in VISIBILITY_PRIORITY:
+                    prev_visibility = "shared"
+
+                new_priority = VISIBILITY_PRIORITY.index(visibility)
+                old_priority = VISIBILITY_PRIORITY.index(prev_visibility)
+                if old_priority < new_priority:
+                    visibility = prev_visibility
+
+            # likewise, if the event is the user's own membership event, use
+            # the 'most joined' membership
+            membership = None
+            if event.type == EventTypes.Member and event.state_key == user_id:
+                membership = event.content.get("membership", None)
+                if membership not in MEMBERSHIP_PRIORITY:
+                    membership = "leave"
+
+                prev_content = event.unsigned.get("prev_content", {})
+                prev_membership = prev_content.get("membership", None)
+                if prev_membership not in MEMBERSHIP_PRIORITY:
+                    prev_membership = "leave"
+
+                new_priority = MEMBERSHIP_PRIORITY.index(membership)
+                old_priority = MEMBERSHIP_PRIORITY.index(prev_membership)
+                if old_priority < new_priority:
+                    membership = prev_membership
+
+            # otherwise, get the user's membership at the time of the event.
+            if membership is None:
+                membership_event = state.get((EventTypes.Member, user_id), None)
+                if membership_event:
+                    if membership_event.event_id not in event_id_forgotten:
+                        membership = membership_event.membership
+
+            # if the user was a member of the room at the time of the event,
+            # they can see it.
+            if membership == Membership.JOIN:
+                return True
+
+            if visibility == "joined":
+                # we weren't a member at the time of the event, so we can't
+                # see this event.
+                return False
+
+            elif visibility == "invited":
+                # user can also see the event if they were *invited* at the time
+                # of the event.
+                return membership == Membership.INVITE
+
+            else:
+                # visibility is shared: user can also see the event if they have
+                # become a member since the event
+                #
+                # XXX: if the user has subsequently joined and then left again,
+                # ideally we would share history up to the point they left. But
+                # we don't know when they left.
+                return not is_peeking
+
+        defer.returnValue({
+            user_id: [
+                event
+                for event in events
+                if allowed(event, user_id, is_peeking)
+            ]
+            for user_id, is_peeking in user_tuples
+        })
+
+    @defer.inlineCallbacks
+    def _filter_events_for_client(self, user_id, events, is_peeking=False):
+        """
+        Check which events a user is allowed to see
+
+        Args:
+            user_id(str): user id to be checked
+            events([synapse.events.EventBase]): list of events to be checked
+            is_peeking(bool): should be True if:
+               * the user is not currently a member of the room, and:
+               * the user has not been a member of the room since the given
+                 events
+
+        Returns:
+            [synapse.events.EventBase]
+        """
+        types = (
+            (EventTypes.RoomHistoryVisibility, ""),
+            (EventTypes.Member, user_id),
+        )
+        event_id_to_state = yield self.store.get_state_for_events(
+            frozenset(e.event_id for e in events),
+            types=types
+        )
+        res = yield self.filter_events_for_clients(
+            [(user_id, is_peeking)], events, event_id_to_state
+        )
+        defer.returnValue(res.get(user_id, []))
+
     def ratelimit(self, requester):
         time_now = self.clock.time()
         allowed, time_allowed = self.ratelimiter.send_message(
@@ -66,20 +233,213 @@ class BaseHandler(object):
         )

     @defer.inlineCallbacks
-    def maybe_kick_guest_users(self, event, context=None):
+    def _create_new_client_event(self, builder, prev_event_ids=None):
+        if prev_event_ids:
+            prev_events = yield self.store.add_event_hashes(prev_event_ids)
+            prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids)
+            depth = prev_max_depth + 1
+        else:
+            latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room(
+                builder.room_id,
+            )
+
+            if latest_ret:
+                depth = max([d for _, _, d in latest_ret]) + 1
+            else:
+                depth = 1
+
+            prev_events = [
+                (event_id, prev_hashes)
+                for event_id, prev_hashes, _ in latest_ret
+            ]
+
+        builder.prev_events = prev_events
+        builder.depth = depth
+
+        state_handler = self.state_handler
+        context = yield state_handler.compute_event_context(builder)
+
+        if builder.is_state():
+            builder.prev_state = yield self.store.add_event_hashes(
+                context.prev_state_events
+            )
+
+        yield self.auth.add_auth_events(builder, context)
+
+        add_hashes_and_signatures(
+            builder, self.server_name, self.signing_key
+        )
+
+        event = builder.build()
+
+        logger.debug(
+            "Created event %s with current state: %s",
+            event.event_id, context.current_state,
+        )
+
+        defer.returnValue(
+            (event, context,)
+        )
+
+    def is_host_in_room(self, current_state):
+        room_members = [
+            (state_key, event.membership)
+            for ((event_type, state_key), event) in current_state.items()
+            if event_type == EventTypes.Member
+        ]
+        if len(room_members) == 0:
+            # Have we just created the room, and is this about to be the very
+            # first member event?
+            create_event = current_state.get(("m.room.create", ""))
+            if create_event:
+                return True
+        for (state_key, membership) in room_members:
+            if (
+                UserID.from_string(state_key).domain == self.hs.hostname
+                and membership == Membership.JOIN
+            ):
+                return True
+        return False
+
+    @defer.inlineCallbacks
+    def handle_new_client_event(
+        self,
+        requester,
+        event,
+        context,
+        ratelimit=True,
+        extra_users=[]
+    ):
+        # We now need to go and hit out to wherever we need to hit out to.
+
+        if ratelimit:
+            self.ratelimit(requester)
+
+        self.auth.check(event, auth_events=context.current_state)
+
+        yield self.maybe_kick_guest_users(event, context.current_state.values())
+
+        if event.type == EventTypes.CanonicalAlias:
+            # Check the alias is acually valid (at this time at least)
+            room_alias_str = event.content.get("alias", None)
+            if room_alias_str:
+                room_alias = RoomAlias.from_string(room_alias_str)
+                directory_handler = self.hs.get_handlers().directory_handler
+                mapping = yield directory_handler.get_association(room_alias)
+
+                if mapping["room_id"] != event.room_id:
+                    raise SynapseError(
+                        400,
+                        "Room alias %s does not point to the room" % (
+                            room_alias_str,
+                        )
+                    )
+
+        federation_handler = self.hs.get_handlers().federation_handler
+
+        if event.type == EventTypes.Member:
+            if event.content["membership"] == Membership.INVITE:
+                def is_inviter_member_event(e):
+                    return (
+                        e.type == EventTypes.Member and
+                        e.sender == event.sender
+                    )
+
+                event.unsigned["invite_room_state"] = [
+                    {
+                        "type": e.type,
+                        "state_key": e.state_key,
+                        "content": e.content,
+                        "sender": e.sender,
+                    }
+                    for k, e in context.current_state.items()
+                    if e.type in self.hs.config.room_invite_state_types
+                    or is_inviter_member_event(e)
+                ]
+
+                invitee = UserID.from_string(event.state_key)
+                if not self.hs.is_mine(invitee):
+                    # TODO: Can we add signature from remote server in a nicer
+                    # way? If we have been invited by a remote server, we need
+                    # to get them to sign the event.
+
+                    returned_invite = yield federation_handler.send_invite(
+                        invitee.domain,
+                        event,
+                    )
+
+                    event.unsigned.pop("room_state", None)
+
+                    # TODO: Make sure the signatures actually are correct.
+                    event.signatures.update(
+                        returned_invite.signatures
+                    )
+
+        if event.type == EventTypes.Redaction:
+            if self.auth.check_redaction(event, auth_events=context.current_state):
+                original_event = yield self.store.get_event(
+                    event.redacts,
+                    check_redacted=False,
+                    get_prev_content=False,
+                    allow_rejected=False,
+                    allow_none=False
+                )
+                if event.user_id != original_event.user_id:
+                    raise AuthError(
+                        403,
+                        "You don't have permission to redact events"
+                    )
+
+        if event.type == EventTypes.Create and context.current_state:
+            raise AuthError(
+                403,
+                "Changing the room create event is forbidden",
+            )
+
+        action_generator = ActionGenerator(self.hs)
+        yield action_generator.handle_push_actions_for_event(
+            event, context, self
+        )
+
+        (event_stream_id, max_stream_id) = yield self.store.persist_event(
+            event, context=context
+        )
+
+        destinations = set()
+        for k, s in context.current_state.items():
+            try:
+                if k[0] == EventTypes.Member:
+                    if s.content["membership"] == Membership.JOIN:
+                        destinations.add(
+                            UserID.from_string(s.state_key).domain
+                        )
+            except SynapseError:
+                logger.warn(
+                    "Failed to get destination from event %s", s.event_id
+                )
+
+        with PreserveLoggingContext():
+            # Don't block waiting on waking up all the listeners.
+            self.notifier.on_new_room_event(
+                event, event_stream_id, max_stream_id,
+                extra_users=extra_users
+            )
+
+        # If invite, remove room_state from unsigned before sending.
+        event.unsigned.pop("invite_room_state", None)
+
+        federation_handler.handle_new_event(
+            event, destinations=destinations,
+        )
+
+    @defer.inlineCallbacks
+    def maybe_kick_guest_users(self, event, current_state):
         # Technically this function invalidates current_state by changing it.
         # Hopefully this isn't that important to the caller.
         if event.type == EventTypes.GuestAccess:
             guest_access = event.content.get("guest_access", "forbidden")
             if guest_access != "can_join":
-                if context:
-                    current_state = yield self.store.get_events(
-                        context.current_state_ids.values()
-                    )
-                    current_state = current_state.values()
-                else:
-                    current_state = yield self.store.get_current_state(event.room_id)
-                logger.info("maybe_kick_guest_users %r", current_state)
                 yield self.kick_guest_users(current_state)

     @defer.inlineCallbacks
@@ -112,8 +472,7 @@ class BaseHandler(object):
         # and having homeservers have their own users leave keeps more
         # of that decision-making and control local to the guest-having
         # homeserver.
-        requester = synapse.types.create_requester(
-            target_user, is_guest=True)
+        requester = Requester(target_user, "", True)
         handler = self.hs.get_handlers().room_member_handler
         yield handler.update_membership(
             requester,
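Reviewer note: the `allowed()` helper added above ranks visibility settings with `VISIBILITY_PRIORITY` and, for `m.room.history_visibility` events themselves, applies the *least restrictive* of the old and new settings, so the boundary event that changed visibility stays readable under both regimes. A standalone sketch of just that rule (a simplification of the real check, which also considers membership):

```python
# Sketch of the "least restrictive wins on boundaries" visibility rule.
VISIBILITY_PRIORITY = (
    "world_readable",
    "shared",
    "invited",
    "joined",
)


def effective_visibility(new_visibility, prev_visibility):
    if new_visibility not in VISIBILITY_PRIORITY:
        new_visibility = "shared"
    if prev_visibility not in VISIBILITY_PRIORITY:
        prev_visibility = "shared"
    # lower index == less restrictive; pick the smaller of the two
    if VISIBILITY_PRIORITY.index(prev_visibility) < VISIBILITY_PRIORITY.index(new_visibility):
        return prev_visibility
    return new_visibility


# Tightening a world_readable room to joined-only: the boundary event is
# still treated as world_readable.
assert effective_visibility("joined", "world_readable") == "world_readable"
# Loosening invited -> shared: the boundary event is treated as shared.
assert effective_visibility("shared", "invited") == "shared"
```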

View File

@@ -16,8 +16,8 @@
 from twisted.internet import defer

 from synapse.api.constants import EventTypes
-from synapse.util.metrics import Measure
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
+from synapse.appservice import ApplicationService
+from synapse.types import UserID

 import logging
@@ -35,81 +35,47 @@ def log_failure(failure):
     )

-# NB: Purposefully not inheriting BaseHandler since that contains way too much
-# setup code which this handler does not need or use. This makes testing a lot
-# easier.
 class ApplicationServicesHandler(object):

-    def __init__(self, hs):
+    def __init__(self, hs, appservice_api, appservice_scheduler):
         self.store = hs.get_datastore()
-        self.is_mine_id = hs.is_mine_id
-        self.appservice_api = hs.get_application_service_api()
-        self.scheduler = hs.get_application_service_scheduler()
+        self.hs = hs
+        self.appservice_api = appservice_api
+        self.scheduler = appservice_scheduler
         self.started_scheduler = False
-        self.clock = hs.get_clock()
-        self.notify_appservices = hs.config.notify_appservices
-
-        self.current_max = 0
-        self.is_processing = False

     @defer.inlineCallbacks
-    def notify_interested_services(self, current_id):
+    def notify_interested_services(self, event):
         """Notifies (pushes) all application services interested in this event.

        Pushing is done asynchronously, so this method won't block for any
        prolonged length of time.

        Args:
-            current_id(int): The current maximum ID.
+            event(Event): The event to push out to interested services.
        """
-        services = yield self.store.get_app_services()
-        if not services or not self.notify_appservices:
-            return
-
-        self.current_max = max(self.current_max, current_id)
-        if self.is_processing:
-            return
-
-        with Measure(self.clock, "notify_interested_services"):
-            self.is_processing = True
-            try:
-                upper_bound = self.current_max
-                limit = 100
-                while True:
-                    upper_bound, events = yield self.store.get_new_events_for_appservice(
-                        upper_bound, limit
-                    )
-
-                    if not events:
-                        break
-
-                    for event in events:
-                        # Gather interested services
-                        services = yield self._get_services_for_event(event)
-                        if len(services) == 0:
-                            continue  # no services need notifying
-
-                        # Do we know this user exists? If not, poke the user
-                        # query API for all services which match that user regex.
-                        # This needs to block as these user queries need to be
-                        # made BEFORE pushing the event.
-                        yield self._check_user_exists(event.sender)
-                        if event.type == EventTypes.Member:
-                            yield self._check_user_exists(event.state_key)
-
-                        if not self.started_scheduler:
-                            self.scheduler.start().addErrback(log_failure)
-                            self.started_scheduler = True
-
-                        # Fork off pushes to these services
-                        for service in services:
-                            preserve_fn(self.scheduler.submit_event_for_as)(
-                                service, event
-                            )
-
-                    yield self.store.set_appservice_last_pos(upper_bound)
-
-                    if len(events) < limit:
-                        break
-            finally:
-                self.is_processing = False
+        # Gather interested services
+        services = yield self._get_services_for_event(event)
+        if len(services) == 0:
+            return  # no services need notifying
+
+        # Do we know this user exists? If not, poke the user query API for
+        # all services which match that user regex. This needs to block as these
+        # user queries need to be made BEFORE pushing the event.
+        yield self._check_user_exists(event.sender)
+        if event.type == EventTypes.Member:
+            yield self._check_user_exists(event.state_key)
+
+        if not self.started_scheduler:
+            self.scheduler.start().addErrback(log_failure)
+            self.started_scheduler = True
+
+        # Fork off pushes to these services
+        for service in services:
+            self.scheduler.submit_event_for_as(service, event)

     @defer.inlineCallbacks
     def query_user_exists(self, user_id):
@@ -142,12 +108,11 @@ class ApplicationServicesHandler(object):
             association can be found.
         """
         room_alias_str = room_alias.to_string()
-        services = yield self.store.get_app_services()
-        alias_query_services = [
-            s for s in services if (
-                s.is_interested_in_alias(room_alias_str)
-            )
-        ]
+        alias_query_services = yield self._get_services_for_event(
+            event=None,
+            restrict_to=ApplicationService.NS_ALIASES,
+            alias_list=[room_alias_str]
+        )
         for alias_service in alias_query_services:
             is_known_alias = yield self.appservice_api.query_alias(
                 alias_service, room_alias_str
@@ -160,74 +125,34 @@ class ApplicationServicesHandler(object):
         defer.returnValue(result)

     @defer.inlineCallbacks
-    def query_3pe(self, kind, protocol, fields):
-        services = yield self._get_services_for_3pn(protocol)
-
-        results = yield preserve_context_over_deferred(defer.DeferredList([
-            preserve_fn(self.appservice_api.query_3pe)(service, kind, protocol, fields)
-            for service in services
-        ], consumeErrors=True))
-
-        ret = []
-        for (success, result) in results:
-            if success:
-                ret.extend(result)
-
-        defer.returnValue(ret)
-
-    @defer.inlineCallbacks
-    def get_3pe_protocols(self, only_protocol=None):
-        services = yield self.store.get_app_services()
-        protocols = {}
-
-        # Collect up all the individual protocol responses out of the ASes
-        for s in services:
-            for p in s.protocols:
-                if only_protocol is not None and p != only_protocol:
-                    continue
-
-                if p not in protocols:
-                    protocols[p] = []
-
-                info = yield self.appservice_api.get_3pe_protocol(s, p)
-
-                if info is not None:
-                    protocols[p].append(info)
-
-        def _merge_instances(infos):
-            if not infos:
-                return {}
-
-            # Merge the 'instances' lists of multiple results, but just take
-            # the other fields from the first as they ought to be identical
-            # copy the result so as not to corrupt the cached one
-            combined = dict(infos[0])
-            combined["instances"] = list(combined["instances"])
-
-            for info in infos[1:]:
-                combined["instances"].extend(info["instances"])
-
-            return combined
-
-        for p in protocols.keys():
-            protocols[p] = _merge_instances(protocols[p])
-
-        defer.returnValue(protocols)
-
-    @defer.inlineCallbacks
-    def _get_services_for_event(self, event):
+    def _get_services_for_event(self, event, restrict_to="", alias_list=None):
         """Retrieve a list of application services interested in this event.

         Args:
             event(Event): The event to check. Can be None if alias_list is not.
+            restrict_to(str): The namespace to restrict regex tests to.
+            alias_list: A list of aliases to get services for. If None, this
+                list is obtained from the database.
         Returns:
             list<ApplicationService>: A list of services interested in this
             event based on the service regex.
         """
+        member_list = None
+        if hasattr(event, "room_id"):
+            # We need to know the aliases associated with this event.room_id,
+            # if any.
+            if not alias_list:
+                alias_list = yield self.store.get_aliases_for_room(
+                    event.room_id
+                )
+            # We need to know the members associated with this event.room_id,
+            # if any.
+            member_list = yield self.store.get_users_in_room(event.room_id)
+
         services = yield self.store.get_app_services()
         interested_list = [
             s for s in services if (
-                yield s.is_interested(event, self.store)
+                s.is_interested(event, restrict_to, alias_list, member_list)
             )
         ]
         defer.returnValue(interested_list)
@@ -242,17 +167,10 @@ class ApplicationServicesHandler(object):
         ]
         defer.returnValue(interested_list)

-    @defer.inlineCallbacks
-    def _get_services_for_3pn(self, protocol):
-        services = yield self.store.get_app_services()
-        interested_list = [
-            s for s in services if s.is_interested_in_protocol(protocol)
-        ]
-        defer.returnValue(interested_list)
-
     @defer.inlineCallbacks
     def _is_unknown_user(self, user_id):
-        if not self.is_mine_id(user_id):
+        user = UserID.from_string(user_id)
+        if not self.hs.is_mine(user):
             # we don't know if they are unknown or not since it isn't one of our
             # users. We can't poke ASes.
             defer.returnValue(False)
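Reviewer note: the newer `notify_interested_services` on the `-` side is driven by a stream position rather than by individual events: callers only bump a high-water mark, and a single drain loop pages through pending events, so concurrent wake-ups coalesce into one pass. A self-contained sketch of that coalescing pattern, with `fetch_events` standing in for `store.get_new_events_for_appservice` (the real code also persists the position via `set_appservice_last_pos`):

```python
# Sketch of the stream-position-driven, coalescing notify loop.
class AppserviceNotifier(object):
    def __init__(self, fetch_events, push):
        # fetch_events(last_pos, limit) -> (new_pos, events)
        self.fetch_events = fetch_events
        self.push = push
        self.last_pos = 0
        self.current_max = 0
        self.is_processing = False

    def notify(self, current_id):
        self.current_max = max(self.current_max, current_id)
        if self.is_processing:
            return  # an already-running drain loop will reach current_max
        self.is_processing = True
        try:
            limit = 100
            while True:
                self.last_pos, events = self.fetch_events(self.last_pos, limit)
                if not events:
                    break
                for event in events:
                    self.push(event)
                if len(events) < limit:
                    break
        finally:
            self.is_processing = False


log = []
notifier = AppserviceNotifier(
    lambda pos, limit: (min(pos + limit, 250),
                        list(range(pos, min(pos + limit, 250)))),
    log.append,
)
notifier.notify(250)
assert len(log) == 250
```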

View File

@@ -18,9 +18,8 @@ from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.types import UserID from synapse.types import UserID
from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError from synapse.api.errors import AuthError, LoginError, Codes
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.config.ldap import LDAPMode
from twisted.web.client import PartialDownloadError from twisted.web.client import PartialDownloadError
@@ -29,13 +28,6 @@ import bcrypt
import pymacaroons import pymacaroons
import simplejson import simplejson
try:
import ldap3
import ldap3.core.exceptions
except ImportError:
ldap3 = None
pass
import synapse.util.stringutils as stringutils import synapse.util.stringutils as stringutils
@@ -46,10 +38,6 @@ class AuthHandler(BaseHandler):
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
def __init__(self, hs): def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer):
"""
super(AuthHandler, self).__init__(hs) super(AuthHandler, self).__init__(hs)
self.checkers = { self.checkers = {
LoginType.PASSWORD: self._check_password_auth, LoginType.PASSWORD: self._check_password_auth,
@@ -62,23 +50,19 @@ class AuthHandler(BaseHandler):
self.INVALID_TOKEN_HTTP_STATUS = 401 self.INVALID_TOKEN_HTTP_STATUS = 401
self.ldap_enabled = hs.config.ldap_enabled self.ldap_enabled = hs.config.ldap_enabled
if self.ldap_enabled: self.ldap_server = hs.config.ldap_server
if not ldap3: self.ldap_port = hs.config.ldap_port
raise RuntimeError( self.ldap_tls = hs.config.ldap_tls
'Missing ldap3 library. This is required for LDAP Authentication.' self.ldap_search_base = hs.config.ldap_search_base
) self.ldap_search_property = hs.config.ldap_search_property
self.ldap_mode = hs.config.ldap_mode self.ldap_email_property = hs.config.ldap_email_property
self.ldap_uri = hs.config.ldap_uri self.ldap_full_name_property = hs.config.ldap_full_name_property
self.ldap_start_tls = hs.config.ldap_start_tls
self.ldap_base = hs.config.ldap_base if self.ldap_enabled is True:
self.ldap_attributes = hs.config.ldap_attributes import ldap
if self.ldap_mode == LDAPMode.SEARCH: logger.info("Import ldap version: %s", ldap.__version__)
self.ldap_bind_dn = hs.config.ldap_bind_dn
self.ldap_bind_password = hs.config.ldap_bind_password
self.ldap_filter = hs.config.ldap_filter
self.hs = hs # FIXME better possibility to access registrationHandler later? self.hs = hs # FIXME better possibility to access registrationHandler later?
self.device_handler = hs.get_device_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip): def check_auth(self, flows, clientdict, clientip):
@@ -236,6 +220,7 @@ class AuthHandler(BaseHandler):
sess = self._get_session_info(session_id) sess = self._get_session_info(session_id)
return sess.setdefault('serverdict', {}).get(key, default) return sess.setdefault('serverdict', {}).get(key, default)
@defer.inlineCallbacks
def _check_password_auth(self, authdict, _): def _check_password_auth(self, authdict, _):
if "user" not in authdict or "password" not in authdict: if "user" not in authdict or "password" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM) raise LoginError(400, "", Codes.MISSING_PARAM)
@@ -245,7 +230,11 @@ class AuthHandler(BaseHandler):
if not user_id.startswith('@'): if not user_id.startswith('@'):
user_id = UserID.create(user_id, self.hs.hostname).to_string() user_id = UserID.create(user_id, self.hs.hostname).to_string()
return self._check_password(user_id, password) if not (yield self._check_password(user_id, password)):
logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
defer.returnValue(user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_recaptcha(self, authdict, clientip): def _check_recaptcha(self, authdict, clientip):
@@ -281,17 +270,8 @@ class AuthHandler(BaseHandler):
data = pde.response data = pde.response
resp_body = simplejson.loads(data) resp_body = simplejson.loads(data)
if 'success' in resp_body: if 'success' in resp_body and resp_body['success']:
# Note that we do NOT check the hostname here: we explicitly defer.returnValue(True)
# intend the CAPTCHA to be presented by whatever client the
# user is using, we just care that they have completed a CAPTCHA.
logger.info(
"%s reCAPTCHA from hostname %s",
"Successful" if resp_body['success'] else "Failed",
resp_body.get('hostname')
)
if resp_body['success']:
defer.returnValue(True)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
    @defer.inlineCallbacks
@@ -358,84 +338,67 @@ class AuthHandler(BaseHandler):
        return self.sessions[session_id]

-   def validate_password_login(self, user_id, password):
+   @defer.inlineCallbacks
+   def login_with_password(self, user_id, password):
        """
        Authenticates the user with their username and password.
        Used only by the v1 login API.

        Args:
-           user_id (str): complete @user:id
+           user_id (str): User ID
            password (str): Password
-       Returns:
-           defer.Deferred: (str) canonical user id
-       Raises:
-           StoreError if there was a problem accessing the database
-           LoginError if there was an authentication problem.
-       """
-       return self._check_password(user_id, password)
-
-   @defer.inlineCallbacks
-   def get_login_tuple_for_user_id(self, user_id, device_id=None,
-                                   initial_display_name=None):
-       """
-       Gets login tuple for the user with the given user ID.
-       Creates a new access/refresh token for the user.
-
-       The user is assumed to have been authenticated by some other
-       mechanism (e.g. CAS), and the user_id converted to the canonical case.
-
-       The device will be recorded in the table if it is not there already.
-
-       Args:
-           user_id (str): canonical User ID
-           device_id (str|None): the device ID to associate with the tokens.
-               None to leave the tokens unassociated with a device (deprecated:
-               we should always have a device ID)
-           initial_display_name (str): display name to associate with the
-               device if it needs re-registering
        Returns:
            A tuple of:
+               The user's ID.
                The access token for the user's session.
                The refresh token for the user's session.
        Raises:
            StoreError if there was a problem storing the token.
            LoginError if there was an authentication problem.
        """
-       logger.info("Logging in user %s on device %s", user_id, device_id)
-       access_token = yield self.issue_access_token(user_id, device_id)
-       refresh_token = yield self.issue_refresh_token(user_id, device_id)
-
-       # the device *should* have been registered before we got here; however,
-       # it's possible we raced against a DELETE operation. The thing we
-       # really don't want is active access_tokens without a record of the
-       # device, so we double-check it here.
-       if device_id is not None:
-           yield self.device_handler.check_device_registered(
-               user_id, device_id, initial_display_name
-           )
-
-       defer.returnValue((access_token, refresh_token))
+       if not (yield self._check_password(user_id, password)):
+           logger.warn("Failed password login for user %s", user_id)
+           raise LoginError(403, "", errcode=Codes.FORBIDDEN)
+
+       logger.info("Logging in user %s", user_id)
+       access_token = yield self.issue_access_token(user_id)
+       refresh_token = yield self.issue_refresh_token(user_id)
+       defer.returnValue((user_id, access_token, refresh_token))

    @defer.inlineCallbacks
-   def check_user_exists(self, user_id):
+   def get_login_tuple_for_user_id(self, user_id):
        """
-       Checks to see if a user with the given id exists. Will check case
-       insensitively, but return None if there are multiple inexact matches.
+       Gets login tuple for the user with the given user ID.
+       The user is assumed to have been authenticated by some other
+       mechanism (e.g. CAS)

        Args:
-           (str) user_id: complete @user:id
+           user_id (str): User ID
        Returns:
-           defer.Deferred: (str) canonical_user_id, or None if zero or
-           multiple matches
+           A tuple of:
+               The user's ID.
+               The access token for the user's session.
+               The refresh token for the user's session.
+       Raises:
+           StoreError if there was a problem storing the token.
+           LoginError if there was an authentication problem.
        """
+       user_id, ignored = yield self._find_user_id_and_pwd_hash(user_id)
+
+       logger.info("Logging in user %s", user_id)
+       access_token = yield self.issue_access_token(user_id)
+       refresh_token = yield self.issue_refresh_token(user_id)
+       defer.returnValue((user_id, access_token, refresh_token))
+
+   @defer.inlineCallbacks
+   def does_user_exist(self, user_id):
        try:
-           res = yield self._find_user_id_and_pwd_hash(user_id)
-           defer.returnValue(res[0])
+           yield self._find_user_id_and_pwd_hash(user_id)
+           defer.returnValue(True)
        except LoginError:
-           defer.returnValue(None)
+           defer.returnValue(False)

    @defer.inlineCallbacks
    def _find_user_id_and_pwd_hash(self, user_id):
@@ -465,336 +428,77 @@ class AuthHandler(BaseHandler):
    @defer.inlineCallbacks
    def _check_password(self, user_id, password):
-       """Authenticate a user against the LDAP and local databases.
-
-       user_id is checked case insensitively against the local database, but
-       will throw if there are multiple inexact matches.
-
-       Args:
-           user_id (str): complete @user:id
-       Returns:
-           (str) the canonical_user_id
-       Raises:
-           LoginError if the password was incorrect
-       """
-       valid_ldap = yield self._check_ldap_password(user_id, password)
-       if valid_ldap:
-           defer.returnValue(user_id)
-
-       result = yield self._check_local_password(user_id, password)
-       defer.returnValue(result)
+       defer.returnValue(
+           not (
+               (yield self._check_ldap_password(user_id, password))
+               or
+               (yield self._check_local_password(user_id, password))
+           ))
    @defer.inlineCallbacks
    def _check_local_password(self, user_id, password):
-       """Authenticate a user against the local password database.
-
-       user_id is checked case insensitively, but will throw if there are
-       multiple inexact matches.
-
-       Args:
-           user_id (str): complete @user:id
-       Returns:
-           (str) the canonical_user_id
-       Raises:
-           LoginError if the password was incorrect
-       """
-       user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
-       result = self.validate_hash(password, password_hash)
-       if not result:
-           logger.warn("Failed password login for user %s", user_id)
-           raise LoginError(403, "", errcode=Codes.FORBIDDEN)
-       defer.returnValue(user_id)
+       try:
+           user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
+           defer.returnValue(not self.validate_hash(password, password_hash))
+       except LoginError:
+           defer.returnValue(False)

-   def _ldap_simple_bind(self, server, localpart, password):
-       """ Attempt a simple bind with the credentials
-       given by the user against the LDAP server.
-
-       Returns True, LDAP3Connection
-           if the bind was successful
-       Returns False, None
-           if an error occurred
-       """
-       try:
-           # bind with the local user's ldap credentials
-           bind_dn = "{prop}={value},{base}".format(
-               prop=self.ldap_attributes['uid'],
-               value=localpart,
-               base=self.ldap_base
-           )
-           conn = ldap3.Connection(server, bind_dn, password)
-           logger.debug(
-               "Established LDAP connection in simple bind mode: %s",
-               conn
-           )
-
-           if self.ldap_start_tls:
-               conn.start_tls()
-               logger.debug(
-                   "Upgraded LDAP connection in simple bind mode through StartTLS: %s",
-                   conn
-               )
-
-           if conn.bind():
-               # GOOD: bind okay
-               logger.debug("LDAP Bind successful in simple bind mode.")
-               return True, conn
-
-           # BAD: bind failed
-           logger.info(
-               "Binding against LDAP for '%s' failed: %s",
-               localpart, conn.result['description']
-           )
-           conn.unbind()
-           return False, None
-       except ldap3.core.exceptions.LDAPException as e:
-           logger.warn("Error during LDAP authentication: %s", e)
-           return False, None
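For orientation, here is a minimal standalone sketch of the simple-bind flow removed above, assuming the ldap3 API; the URI, attribute name, and search base are illustrative values rather than real Synapse config:

```python
import ldap3

def simple_bind(uri, uid_attr, base, localpart, password):
    # e.g. uri="ldap://localhost:389", uid_attr="uid", base="ou=users,dc=example,dc=com"
    server = ldap3.Server(uri)
    bind_dn = "%s=%s,%s" % (uid_attr, localpart, base)
    conn = ldap3.Connection(server, bind_dn, password)
    if conn.bind():  # True only if the server accepted the credentials
        return True, conn
    conn.unbind()
    return False, None
```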
def _ldap_authenticated_search(self, server, localpart, password):
""" Attempt to login with the preconfigured bind_dn
and then continue searching and filtering within
the base_dn
Returns (True, LDAP3Connection)
if a single matching DN within the base was found
that matched the filter expression, and with which
a successful bind was achieved
The LDAP3Connection returned is the instance that was used to
verify the password not the one using the configured bind_dn.
Returns (False, None)
if an error occurred
"""
try:
conn = ldap3.Connection(
server,
self.ldap_bind_dn,
self.ldap_bind_password
)
logger.debug(
"Established LDAP connection in search mode: %s",
conn
)
if self.ldap_start_tls:
conn.start_tls()
logger.debug(
"Upgraded LDAP connection in search mode through StartTLS: %s",
conn
)
if not conn.bind():
logger.warn(
"Binding against LDAP with `bind_dn` failed: %s",
conn.result['description']
)
conn.unbind()
return False, None
# construct search_filter like (uid=localpart)
query = "({prop}={value})".format(
prop=self.ldap_attributes['uid'],
value=localpart
)
if self.ldap_filter:
# combine with the AND expression
query = "(&{query}{filter})".format(
query=query,
filter=self.ldap_filter
)
logger.debug(
"LDAP search filter: %s",
query
)
conn.search(
search_base=self.ldap_base,
search_filter=query
)
if len(conn.response) == 1:
# GOOD: found exactly one result
user_dn = conn.response[0]['dn']
logger.debug('LDAP search found dn: %s', user_dn)
# unbind and simple bind with user_dn to verify the password
# Note: do not use rebind(), for some reason it did not verify
# the password for me!
conn.unbind()
return self._ldap_simple_bind(server, localpart, password)
else:
# BAD: found 0 or > 1 results, abort!
if len(conn.response) == 0:
logger.info(
"LDAP search returned no results for '%s'",
localpart
)
else:
logger.info(
"LDAP search returned too many (%s) results for '%s'",
len(conn.response), localpart
)
conn.unbind()
return False, None
except ldap3.core.exceptions.LDAPException as e:
logger.warn("Error during LDAP authentication: %s", e)
return False, None
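The search mode above is the standard "search then bind" pattern. Roughly, under the same ldap3 assumptions (the service-account DN and filter are illustrative):

```python
import ldap3

def search_then_bind(uri, bind_dn, bind_pw, base, uid_attr, localpart, password):
    server = ldap3.Server(uri)
    conn = ldap3.Connection(server, bind_dn, bind_pw)
    if not conn.bind():            # service-account bind failed
        return False
    conn.search(search_base=base, search_filter="(%s=%s)" % (uid_attr, localpart))
    if len(conn.response) != 1:    # require exactly one matching entry
        conn.unbind()
        return False
    user_dn = conn.response[0]['dn']
    conn.unbind()
    # re-bind as the matched DN to verify the user's own password
    return ldap3.Connection(server, user_dn, password).bind()
```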
    @defer.inlineCallbacks
    def _check_ldap_password(self, user_id, password):
-       """ Attempt to authenticate a user against an LDAP Server
-       and register an account if none exists.
-
-       Returns:
-           True if authentication against LDAP was successful
-       """
-       if not ldap3 or not self.ldap_enabled:
+       if self.ldap_enabled is not True:
+           logger.debug("LDAP not configured")
            defer.returnValue(False)

-       localpart = UserID.from_string(user_id).localpart
+       import ldap
+
+       logger.info("Authenticating %s with LDAP" % user_id)
        try:
-           server = ldap3.Server(self.ldap_uri)
-           logger.debug(
-               "Attempting LDAP connection with %s",
-               self.ldap_uri
-           )
-
-           if self.ldap_mode == LDAPMode.SIMPLE:
-               result, conn = self._ldap_simple_bind(
-                   server=server, localpart=localpart, password=password
-               )
-               logger.debug(
-                   'LDAP authentication method simple bind returned: %s (conn: %s)',
-                   result,
-                   conn
-               )
-               if not result:
-                   defer.returnValue(False)
-           elif self.ldap_mode == LDAPMode.SEARCH:
-               result, conn = self._ldap_authenticated_search(
-                   server=server, localpart=localpart, password=password
-               )
-               logger.debug(
-                   'LDAP auth method authenticated search returned: %s (conn: %s)',
-                   result,
-                   conn
-               )
-               if not result:
-                   defer.returnValue(False)
-           else:
-               raise RuntimeError(
-                   'Invalid LDAP mode specified: {mode}'.format(
-                       mode=self.ldap_mode
-                   )
-               )
-
-           try:
-               logger.info(
-                   "User authenticated against LDAP server: %s",
-                   conn
-               )
-           except NameError:
-               logger.warn("Authentication method yielded no LDAP connection, aborting!")
-               defer.returnValue(False)
-
-           # check if user with user_id exists
-           if (yield self.check_user_exists(user_id)):
-               # exists, authentication complete
-               conn.unbind()
-               defer.returnValue(True)
-           else:
-               # does not exist, fetch metadata for account creation from
-               # existing ldap connection
-               query = "({prop}={value})".format(
-                   prop=self.ldap_attributes['uid'],
-                   value=localpart
-               )
-
-               if self.ldap_mode == LDAPMode.SEARCH and self.ldap_filter:
-                   query = "(&{filter}{user_filter})".format(
-                       filter=query,
-                       user_filter=self.ldap_filter
-                   )
-               logger.debug(
-                   "ldap registration filter: %s",
-                   query
-               )
-
-               conn.search(
-                   search_base=self.ldap_base,
-                   search_filter=query,
-                   attributes=[
-                       self.ldap_attributes['name'],
-                       self.ldap_attributes['mail']
-                   ]
-               )
-
-               if len(conn.response) == 1:
-                   attrs = conn.response[0]['attributes']
-                   mail = attrs[self.ldap_attributes['mail']][0]
-                   name = attrs[self.ldap_attributes['name']][0]
-
-                   # create account
-                   registration_handler = self.hs.get_handlers().registration_handler
-                   user_id, access_token = (
-                       yield registration_handler.register(localpart=localpart)
-                   )
-
-                   # TODO: bind email, set displayname with data from ldap directory
-
-                   logger.info(
-                       "Registration based on LDAP data was successful: %s: %s (%s, %s)",
-                       user_id,
-                       localpart,
-                       name,
-                       mail
-                   )
-
-                   defer.returnValue(True)
-               else:
-                   if len(conn.response) == 0:
-                       logger.warn("LDAP registration failed, no result.")
-                   else:
-                       logger.warn(
-                           "LDAP registration failed, too many results (%s)",
-                           len(conn.response)
-                       )
-
-                   defer.returnValue(False)
-
-           defer.returnValue(False)
-       except ldap3.core.exceptions.LDAPException as e:
-           logger.warn("Error during ldap authentication: %s", e)
+           ldap_url = "%s:%s" % (self.ldap_server, self.ldap_port)
+           logger.debug("Connecting LDAP server at %s" % ldap_url)
+           l = ldap.initialize(ldap_url)
+           if self.ldap_tls:
+               logger.debug("Initiating TLS")
+               self._connection.start_tls_s()
+
+           local_name = UserID.from_string(user_id).localpart
+
+           dn = "%s=%s, %s" % (
+               self.ldap_search_property,
+               local_name,
+               self.ldap_search_base)
+           logger.debug("DN for LDAP authentication: %s" % dn)
+
+           l.simple_bind_s(dn.encode('utf-8'), password.encode('utf-8'))
+
+           if not (yield self.does_user_exist(user_id)):
+               handler = self.hs.get_handlers().registration_handler
+               user_id, access_token = (
+                   yield handler.register(localpart=local_name)
+               )
+
+           defer.returnValue(True)
+       except ldap.LDAPError, e:
+           logger.warn("LDAP error: %s", e)
            defer.returnValue(False)
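The right-hand side is the older python-ldap flow; stripped of the Synapse plumbing it amounts to something like this (values illustrative; note that simple_bind_s signals failure by raising rather than returning False):

```python
import ldap  # python-ldap

def check_ldap_password(server, port, search_property, search_base, localpart, password):
    conn = ldap.initialize("%s:%s" % (server, port))  # e.g. "ldap://localhost", 389
    dn = "%s=%s,%s" % (search_property, localpart, search_base)
    try:
        conn.simple_bind_s(dn, password)  # raises ldap.INVALID_CREDENTIALS on a bad password
        return True
    except ldap.LDAPError:
        return False
```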
    @defer.inlineCallbacks
-   def issue_access_token(self, user_id, device_id=None):
+   def issue_access_token(self, user_id):
        access_token = self.generate_access_token(user_id)
-       yield self.store.add_access_token_to_user(user_id, access_token,
-                                                 device_id)
+       yield self.store.add_access_token_to_user(user_id, access_token)
        defer.returnValue(access_token)

    @defer.inlineCallbacks
-   def issue_refresh_token(self, user_id, device_id=None):
+   def issue_refresh_token(self, user_id):
        refresh_token = self.generate_refresh_token(user_id)
-       yield self.store.add_refresh_token_to_user(user_id, refresh_token,
-                                                  device_id)
+       yield self.store.add_refresh_token_to_user(user_id, refresh_token)
        defer.returnValue(refresh_token)

-   def generate_access_token(self, user_id, extra_caveats=None,
-                             duration_in_ms=(60 * 60 * 1000)):
+   def generate_access_token(self, user_id, extra_caveats=None):
        extra_caveats = extra_caveats or []
        macaroon = self._generate_base_macaroon(user_id)
        macaroon.add_first_party_caveat("type = access")
        now = self.hs.get_clock().time_msec()
-       expiry = now + duration_in_ms
+       expiry = now + (60 * 60 * 1000)
        macaroon.add_first_party_caveat("time < %d" % (expiry,))
        for caveat in extra_caveats:
            macaroon.add_first_party_caveat(caveat)
@@ -810,28 +514,22 @@ class AuthHandler(BaseHandler):
        ))
        return m.serialize()

-   def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
+   def generate_short_term_login_token(self, user_id):
        macaroon = self._generate_base_macaroon(user_id)
        macaroon.add_first_party_caveat("type = login")
        now = self.hs.get_clock().time_msec()
-       expiry = now + duration_in_ms
+       expiry = now + (2 * 60 * 1000)
        macaroon.add_first_party_caveat("time < %d" % (expiry,))
        return macaroon.serialize()

-   def generate_delete_pusher_token(self, user_id):
-       macaroon = self._generate_base_macaroon(user_id)
-       macaroon.add_first_party_caveat("type = delete_pusher")
-       return macaroon.serialize()
-
    def validate_short_term_login_token_and_get_user_id(self, login_token):
-       auth_api = self.hs.get_auth()
        try:
            macaroon = pymacaroons.Macaroon.deserialize(login_token)
-           user_id = auth_api.get_user_id_from_macaroon(macaroon)
-           auth_api.validate_macaroon(macaroon, "login", True, user_id)
-           return user_id
-       except Exception:
-           raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
+           auth_api = self.hs.get_auth()
+           auth_api.validate_macaroon(macaroon, "login", True)
+           return self.get_user_from_macaroon(macaroon)
+       except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
+           raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN)

    def _generate_base_macaroon(self, user_id):
        macaroon = pymacaroons.Macaroon(
@@ -842,23 +540,28 @@ class AuthHandler(BaseHandler):
        macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
        return macaroon
def get_user_from_macaroon(self, macaroon):
user_prefix = "user_id = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(user_prefix):
return caveat.caveat_id[len(user_prefix):]
raise AuthError(
self.INVALID_TOKEN_HTTP_STATUS, "No user_id found in token",
errcode=Codes.UNKNOWN_TOKEN
)
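Both sides build tokens the same way: a base macaroon carrying the user_id, plus first-party caveats for type and expiry. A sketch using pymacaroons directly (the location, key, and timestamp are illustrative):

```python
import pymacaroons

macaroon = pymacaroons.Macaroon(
    location="example.com", identifier="key", key="macaroon-secret")
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = @alice:example.com")
macaroon.add_first_party_caveat("type = login")
macaroon.add_first_party_caveat("time < 1475190000000")
token = macaroon.serialize()

# recovering the user_id caveat, as get_user_from_macaroon does
m = pymacaroons.Macaroon.deserialize(token)
user_id = next(
    c.caveat_id[len("user_id = "):]
    for c in m.caveats
    if c.caveat_id.startswith("user_id = ")
)
```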
    @defer.inlineCallbacks
    def set_password(self, user_id, newpassword, requester=None):
        password_hash = self.hash(newpassword)

-       except_access_token_id = requester.access_token_id if requester else None
+       except_access_token_ids = [requester.access_token_id] if requester else []

-       try:
-           yield self.store.user_set_password_hash(user_id, password_hash)
-       except StoreError as e:
-           if e.code == 404:
-               raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
-           raise e
+       yield self.store.user_set_password_hash(user_id, password_hash)
        yield self.store.user_delete_access_tokens(
-           user_id, except_access_token_id
+           user_id, except_access_token_ids
        )
        yield self.hs.get_pusherpool().remove_pushers_by_user(
-           user_id, except_access_token_id
+           user_id, except_access_token_ids
        )

    @defer.inlineCallbacks
@@ -893,8 +596,7 @@ class AuthHandler(BaseHandler):
        Returns:
            Hashed password (str).
        """
-       return bcrypt.hashpw(password + self.hs.config.password_pepper,
-                            bcrypt.gensalt(self.bcrypt_rounds))
+       return bcrypt.hashpw(password, bcrypt.gensalt(self.bcrypt_rounds))

    def validate_hash(self, password, stored_hash):
        """Validates that self.hash(password) == stored_hash.
@@ -906,8 +608,4 @@ class AuthHandler(BaseHandler):
        Returns:
            Whether self.hash(password) == stored_hash (bool).
        """
-       if stored_hash:
-           return bcrypt.hashpw(password + self.hs.config.password_pepper,
-                                stored_hash.encode('utf-8')) == stored_hash
-       else:
-           return False
+       return bcrypt.hashpw(password, stored_hash) == stored_hash
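The removed left-hand variant appends a server-wide pepper to the password before bcrypt hashing; a sketch of that scheme (the pepper value is illustrative):

```python
import bcrypt

PEPPER = "server-wide-secret"  # stands in for hs.config.password_pepper

def hash_password(password, rounds=12):
    return bcrypt.hashpw(password + PEPPER, bcrypt.gensalt(rounds))

def validate_hash(password, stored_hash):
    if not stored_hash:
        return False
    # hashing with the stored hash as salt reproduces it iff the password matches
    return bcrypt.hashpw(password + PEPPER, stored_hash) == stored_hash
```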
@@ -1,181 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api import errors
from synapse.util import stringutils
from twisted.internet import defer
from ._base import BaseHandler
import logging
logger = logging.getLogger(__name__)
class DeviceHandler(BaseHandler):
def __init__(self, hs):
super(DeviceHandler, self).__init__(hs)
@defer.inlineCallbacks
def check_device_registered(self, user_id, device_id,
initial_device_display_name=None):
"""
If the given device has not been registered, register it with the
supplied display name.
If no device_id is supplied, we make one up.
Args:
user_id (str): @user:id
device_id (str | None): device id supplied by client
initial_device_display_name (str | None): device display name from
client
Returns:
str: device id (generated if none was supplied)
"""
if device_id is not None:
yield self.store.store_device(
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
ignore_if_known=True,
)
defer.returnValue(device_id)
# if the device id is not specified, we'll autogen one, but loop a few
# times in case of a clash.
attempts = 0
while attempts < 5:
try:
device_id = stringutils.random_string(10).upper()
yield self.store.store_device(
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
ignore_if_known=False,
)
defer.returnValue(device_id)
except errors.StoreError:
attempts += 1
raise errors.StoreError(500, "Couldn't generate a device ID.")
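A typical caller's view of check_device_registered (hypothetical values): passing device_id=None asks the handler to mint a random ID, retrying on clashes as above:

```python
from twisted.internet import defer

@defer.inlineCallbacks
def ensure_device(device_handler):
    device_id = yield device_handler.check_device_registered(
        user_id="@alice:example.com",
        device_id=None,                           # let the handler generate one
        initial_device_display_name="Alice's phone",
    )
    defer.returnValue(device_id)                  # e.g. "ABCDEFGHIJ"
```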
@defer.inlineCallbacks
def get_devices_by_user(self, user_id):
"""
Retrieve the given user's devices
Args:
user_id (str):
Returns:
defer.Deferred: list[dict[str, X]]: info on each device
"""
device_map = yield self.store.get_devices_by_user(user_id)
ips = yield self.store.get_last_client_ip_by_device(
devices=((user_id, device_id) for device_id in device_map.keys())
)
devices = device_map.values()
for device in devices:
_update_device_from_client_ips(device, ips)
defer.returnValue(devices)
@defer.inlineCallbacks
def get_device(self, user_id, device_id):
""" Retrieve the given device
Args:
user_id (str):
device_id (str):
Returns:
defer.Deferred: dict[str, X]: info on the device
Raises:
errors.NotFoundError: if the device was not found
"""
try:
device = yield self.store.get_device(user_id, device_id)
except errors.StoreError:
raise errors.NotFoundError
ips = yield self.store.get_last_client_ip_by_device(
devices=((user_id, device_id),)
)
_update_device_from_client_ips(device, ips)
defer.returnValue(device)
@defer.inlineCallbacks
def delete_device(self, user_id, device_id):
""" Delete the given device
Args:
user_id (str):
device_id (str):
Returns:
defer.Deferred:
"""
try:
yield self.store.delete_device(user_id, device_id)
except errors.StoreError, e:
if e.code == 404:
# no match
pass
else:
raise
yield self.store.user_delete_access_tokens(
user_id, device_id=device_id,
delete_refresh_tokens=True,
)
yield self.store.delete_e2e_keys_by_device(
user_id=user_id, device_id=device_id
)
@defer.inlineCallbacks
def update_device(self, user_id, device_id, content):
""" Update the given device
Args:
user_id (str):
device_id (str):
content (dict): body of update request
Returns:
defer.Deferred:
"""
try:
yield self.store.update_device(
user_id,
device_id,
new_display_name=content.get("display_name")
)
except errors.StoreError, e:
if e.code == 404:
raise errors.NotFoundError()
else:
raise
def _update_device_from_client_ips(device, client_ips):
ip = client_ips.get((device["user_id"], device["device_id"]), {})
device.update({
"last_seen_ts": ip.get("last_seen"),
"last_seen_ip": ip.get("ip"),
})
@@ -1,117 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.types import get_domain_from_id
from synapse.util.stringutils import random_string
logger = logging.getLogger(__name__)
class DeviceMessageHandler(object):
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
self.federation = hs.get_replication_layer()
self.federation.register_edu_handler(
"m.direct_to_device", self.on_direct_to_device_edu
)
@defer.inlineCallbacks
def on_direct_to_device_edu(self, origin, content):
local_messages = {}
sender_user_id = content["sender"]
if origin != get_domain_from_id(sender_user_id):
logger.warn(
"Dropping device message from %r with spoofed sender %r",
origin, sender_user_id
)
message_type = content["type"]
message_id = content["message_id"]
for user_id, by_device in content["messages"].items():
messages_by_device = {
device_id: {
"content": message_content,
"type": message_type,
"sender": sender_user_id,
}
for device_id, message_content in by_device.items()
}
if messages_by_device:
local_messages[user_id] = messages_by_device
stream_id = yield self.store.add_messages_from_remote_to_device_inbox(
origin, message_id, local_messages
)
self.notifier.on_new_event(
"to_device_key", stream_id, users=local_messages.keys()
)
@defer.inlineCallbacks
def send_device_message(self, sender_user_id, message_type, messages):
local_messages = {}
remote_messages = {}
for user_id, by_device in messages.items():
if self.is_mine_id(user_id):
messages_by_device = {
device_id: {
"content": message_content,
"type": message_type,
"sender": sender_user_id,
}
for device_id, message_content in by_device.items()
}
if messages_by_device:
local_messages[user_id] = messages_by_device
else:
destination = get_domain_from_id(user_id)
remote_messages.setdefault(destination, {})[user_id] = by_device
message_id = random_string(16)
remote_edu_contents = {}
for destination, messages in remote_messages.items():
remote_edu_contents[destination] = {
"messages": messages,
"sender": sender_user_id,
"type": message_type,
"message_id": message_id,
}
stream_id = yield self.store.add_messages_to_device_inbox(
local_messages, remote_edu_contents
)
self.notifier.on_new_event(
"to_device_key", stream_id, users=local_messages.keys()
)
for destination in remote_messages.keys():
# Enqueue a new federation transaction to send the new
# device messages to each remote destination.
self.federation.send_device_messages(destination)
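For reference, the m.direct_to_device EDU assembled by send_device_message has this shape (the IDs, type, and payload are illustrative):

```python
edu_content = {
    "sender": "@alice:example.com",
    "type": "m.new_device",            # the client-supplied message_type
    "message_id": "hiezohf6Hoo7kaev",  # random_string(16), used for de-duplication
    "messages": {
        "@bob:remote.org": {
            "BOBDEVICE": {"body": "..."},  # per-device message content
        },
    },
}
```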
@@ -19,7 +19,7 @@ from ._base import BaseHandler
from synapse.api.errors import SynapseError, Codes, CodeMessageException, AuthError
from synapse.api.constants import EventTypes
-from synapse.types import RoomAlias, UserID, get_domain_from_id
+from synapse.types import RoomAlias, UserID

import logging
import string
@@ -33,7 +33,6 @@ class DirectoryHandler(BaseHandler):
        super(DirectoryHandler, self).__init__(hs)

        self.state = hs.get_state_handler()
-       self.appservice_handler = hs.get_application_service_handler()

        self.federation = hs.get_replication_layer()
        self.federation.register_query_handler(
@@ -55,8 +54,7 @@ class DirectoryHandler(BaseHandler):
        # TODO(erikj): Add transactions.
        # TODO(erikj): Check if there is a current association.
        if not servers:
-           users = yield self.state.get_current_user_in_room(room_id)
-           servers = set(get_domain_from_id(u) for u in users)
+           servers = yield self.store.get_joined_hosts_for_room(room_id)

        if not servers:
            raise SynapseError(400, "Failed to get server list")
@@ -194,8 +192,7 @@ class DirectoryHandler(BaseHandler):
                Codes.NOT_FOUND
            )

-       users = yield self.state.get_current_user_in_room(room_id)
-       extra_servers = set(get_domain_from_id(u) for u in users)
+       extra_servers = yield self.store.get_joined_hosts_for_room(room_id)
        servers = set(extra_servers) | set(servers)

        # If this server is in the list of servers, return it first.
@@ -284,7 +281,7 @@ class DirectoryHandler(BaseHandler):
        )
        if not result:
            # Query AS to see if it exists
-           as_handler = self.appservice_handler
+           as_handler = self.hs.get_handlers().appservice_handler
            result = yield as_handler.query_room_alias_exists(room_alias)
        defer.returnValue(result)
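The get_domain_from_id helper dropped on the right-hand side is essentially a split on the first colon of a Matrix ID:

```python
def get_domain_from_id(string):
    # "@alice:example.com" -> "example.com"
    return string.split(":", 1)[1]
```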
@@ -1,279 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ujson as json
import logging
from canonicaljson import encode_canonical_json
from twisted.internet import defer
from synapse.api.errors import SynapseError, CodeMessageException
from synapse.types import get_domain_from_id
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
logger = logging.getLogger(__name__)
class E2eKeysHandler(object):
def __init__(self, hs):
self.store = hs.get_datastore()
self.federation = hs.get_replication_layer()
self.device_handler = hs.get_device_handler()
self.is_mine_id = hs.is_mine_id
self.clock = hs.get_clock()
# doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the
# "query handler" interface.
self.federation.register_query_handler(
"client_keys", self.on_federation_query_client_keys
)
@defer.inlineCallbacks
def query_devices(self, query_body, timeout):
""" Handle a device key query from a client
{
"device_keys": {
"<user_id>": ["<device_id>"]
}
}
->
{
"device_keys": {
"<user_id>": {
"<device_id>": {
...
}
}
}
}
"""
device_keys_query = query_body.get("device_keys", {})
# separate users by domain.
# make a map from domain to user_id to device_ids
local_query = {}
remote_queries = {}
for user_id, device_ids in device_keys_query.items():
if self.is_mine_id(user_id):
local_query[user_id] = device_ids
else:
domain = get_domain_from_id(user_id)
remote_queries.setdefault(domain, {})[user_id] = device_ids
# do the queries
failures = {}
results = {}
if local_query:
local_result = yield self.query_local_devices(local_query)
for user_id, keys in local_result.items():
if user_id in local_query:
results[user_id] = keys
@defer.inlineCallbacks
def do_remote_query(destination):
destination_query = remote_queries[destination]
try:
limiter = yield get_retry_limiter(
destination, self.clock, self.store
)
with limiter:
remote_result = yield self.federation.query_client_keys(
destination,
{"device_keys": destination_query},
timeout=timeout
)
for user_id, keys in remote_result["device_keys"].items():
if user_id in destination_query:
results[user_id] = keys
except CodeMessageException as e:
failures[destination] = {
"status": e.code, "message": e.message
}
except NotRetryingDestination as e:
failures[destination] = {
"status": 503, "message": "Not ready for retry",
}
yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(do_remote_query)(destination)
for destination in remote_queries
]))
defer.returnValue({
"device_keys": results, "failures": failures,
})
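query_devices thus splits a single client query into one local lookup plus one federation query per remote domain; schematically (domains illustrative):

```python
query = {"@alice:example.com": [], "@bob:remote.org": ["BOBDEVICE"]}

local_query, remote_queries = {}, {}
for user_id, device_ids in query.items():
    domain = user_id.split(":", 1)[1]
    if domain == "example.com":  # what is_mine_id checks for this server
        local_query[user_id] = device_ids
    else:
        remote_queries.setdefault(domain, {})[user_id] = device_ids
# local_query    -> {"@alice:example.com": []}
# remote_queries -> {"remote.org": {"@bob:remote.org": ["BOBDEVICE"]}}
```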
@defer.inlineCallbacks
def query_local_devices(self, query):
"""Get E2E device keys for local users
Args:
query (dict[string, list[string]|None): map from user_id to a list
of devices to query (None for all devices)
Returns:
defer.Deferred: (resolves to dict[string, dict[string, dict]]):
map from user_id -> device_id -> device details
"""
local_query = []
result_dict = {}
for user_id, device_ids in query.items():
if not self.is_mine_id(user_id):
logger.warning("Request for keys for non-local user %s",
user_id)
raise SynapseError(400, "Not a user here")
if not device_ids:
local_query.append((user_id, None))
else:
for device_id in device_ids:
local_query.append((user_id, device_id))
# make sure that each queried user appears in the result dict
result_dict[user_id] = {}
results = yield self.store.get_e2e_device_keys(local_query)
# Build the result structure, un-jsonify the results, and add the
# "unsigned" section
for user_id, device_keys in results.items():
for device_id, device_info in device_keys.items():
r = json.loads(device_info["key_json"])
r["unsigned"] = {}
display_name = device_info["device_display_name"]
if display_name is not None:
r["unsigned"]["device_display_name"] = display_name
result_dict[user_id][device_id] = r
defer.returnValue(result_dict)
@defer.inlineCallbacks
def on_federation_query_client_keys(self, query_body):
""" Handle a device key query from a federated server
"""
device_keys_query = query_body.get("device_keys", {})
res = yield self.query_local_devices(device_keys_query)
defer.returnValue({"device_keys": res})
@defer.inlineCallbacks
def claim_one_time_keys(self, query, timeout):
local_query = []
remote_queries = {}
for user_id, device_keys in query.get("one_time_keys", {}).items():
if self.is_mine_id(user_id):
for device_id, algorithm in device_keys.items():
local_query.append((user_id, device_id, algorithm))
else:
domain = get_domain_from_id(user_id)
remote_queries.setdefault(domain, {})[user_id] = device_keys
results = yield self.store.claim_e2e_one_time_keys(local_query)
json_result = {}
failures = {}
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
for key_id, json_bytes in keys.items():
json_result.setdefault(user_id, {})[device_id] = {
key_id: json.loads(json_bytes)
}
@defer.inlineCallbacks
def claim_client_keys(destination):
device_keys = remote_queries[destination]
try:
limiter = yield get_retry_limiter(
destination, self.clock, self.store
)
with limiter:
remote_result = yield self.federation.claim_client_keys(
destination,
{"one_time_keys": device_keys},
timeout=timeout
)
for user_id, keys in remote_result["one_time_keys"].items():
if user_id in device_keys:
json_result[user_id] = keys
except CodeMessageException as e:
failures[destination] = {
"status": e.code, "message": e.message
}
except NotRetryingDestination as e:
failures[destination] = {
"status": 503, "message": "Not ready for retry",
}
yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(claim_client_keys)(destination)
for destination in remote_queries
]))
defer.returnValue({
"one_time_keys": json_result,
"failures": failures
})
@defer.inlineCallbacks
def upload_keys_for_user(self, user_id, device_id, keys):
time_now = self.clock.time_msec()
# TODO: Validate the JSON to make sure it has the right keys.
device_keys = keys.get("device_keys", None)
if device_keys:
logger.info(
"Updating device_keys for device %r for user %s at %d",
device_id, user_id, time_now
)
# TODO: Sign the JSON with the server key
yield self.store.set_e2e_device_keys(
user_id, device_id, time_now,
encode_canonical_json(device_keys)
)
one_time_keys = keys.get("one_time_keys", None)
if one_time_keys:
logger.info(
"Adding %d one_time_keys for device %r for user %r at %d",
len(one_time_keys), device_id, user_id, time_now
)
key_list = []
for key_id, key_json in one_time_keys.items():
algorithm, key_id = key_id.split(":")
key_list.append((
algorithm, key_id, encode_canonical_json(key_json)
))
yield self.store.add_e2e_one_time_keys(
user_id, device_id, time_now, key_list
)
# the device should have been registered already, but it may have been
# deleted due to a race with a DELETE request. Or we may be using an
# old access_token without an associated device_id. Either way, we
# need to double-check the device is registered to avoid ending up with
# keys without a corresponding device.
self.device_handler.check_device_registered(user_id, device_id)
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
defer.returnValue({"one_time_key_counts": result})
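Note that the one_time_keys dict is keyed by "algorithm:key_id" strings, which upload_keys_for_user splits before storage; for example:

```python
one_time_keys = {"signed_curve25519:AAAAHQ": {"key": "base64encodedbytes"}}
for full_id, key_json in one_time_keys.items():
    algorithm, key_id = full_id.split(":")
    # ("signed_curve25519", "AAAAHQ"), stored alongside the canonical JSON
```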
@@ -47,7 +47,6 @@ class EventStreamHandler(BaseHandler):
        self.clock = hs.get_clock()

        self.notifier = hs.get_notifier()
-       self.state = hs.get_state_handler()

    @defer.inlineCallbacks
    @log_function
@@ -59,7 +58,7 @@ class EventStreamHandler(BaseHandler):
        If `only_keys` is not None, events from keys will be sent down.
        """
        auth_user = UserID.from_string(auth_user_id)
-       presence_handler = self.hs.get_presence_handler()
+       presence_handler = self.hs.get_handlers().presence_handler

        context = yield presence_handler.user_syncing(
            auth_user_id, affect_presence=affect_presence,
@@ -91,7 +90,7 @@ class EventStreamHandler(BaseHandler):
            # Send down presence.
            if event.state_key == auth_user_id:
                # Send down presence for everyone in the room.
-               users = yield self.state.get_current_user_in_room(event.room_id)
+               users = yield self.store.get_users_in_room(event.room_id)
                states = yield presence_handler.get_states(
                    users,
                    as_event=True,
@@ -26,17 +26,14 @@ from synapse.api.errors import (
from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.events.validator import EventValidator
from synapse.util import unwrapFirstError
-from synapse.util.logcontext import (
-    PreserveLoggingContext, preserve_fn, preserve_context_over_deferred
-)
-from synapse.util.metrics import measure_func
+from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor
from synapse.util.frozenutils import unfreeze
from synapse.crypto.event_signing import (
    compute_event_signature, add_hashes_and_signatures,
)
-from synapse.types import UserID, get_domain_from_id
+from synapse.types import UserID
from synapse.events.utils import prune_event
@@ -69,6 +66,10 @@ class FederationHandler(BaseHandler):
        self.hs = hs

+       self.distributor.observe("user_joined_room", self.user_joined_room)
+
+       self.waiting_for_join_list = {}
+
        self.store = hs.get_datastore()
        self.replication_layer = hs.get_replication_layer()
        self.state_handler = hs.get_state_handler()
@@ -101,9 +102,6 @@ class FederationHandler(BaseHandler):
    def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None):
        """ Called by the ReplicationLayer when we have a new pdu. We need to
        do auth checks and put it through the StateHandler.
-
-       auth_chain and state are None if we already have the necessary state
-       and prev_events in the db
        """
        event = pdu
@@ -121,25 +119,16 @@ class FederationHandler(BaseHandler):
        # FIXME (erikj): Awful hack to make the case where we are not currently
        # in the room work
-       # If state and auth_chain are None, then we don't need to do this check
-       # as we already know we have enough state in the DB to handle this
-       # event.
-       if state and auth_chain and not event.internal_metadata.is_outlier():
-           is_in_room = yield self.auth.check_host_in_room(
-               event.room_id,
-               self.server_name
-           )
-       else:
-           is_in_room = True
-       if not is_in_room:
-           logger.info(
-               "Got event for room we're not in: %r %r",
-               event.room_id, event.event_id
-           )
+       is_in_room = yield self.auth.check_host_in_room(
+           event.room_id,
+           self.server_name
+       )
+       if not is_in_room and not event.internal_metadata.is_outlier():
+           logger.debug("Got event for room we're not in.")

            try:
                event_stream_id, max_stream_id = yield self._persist_auth_tree(
-                   origin, auth_chain, state, event
+                   auth_chain, state, event
                )
            except AuthError as e:
                raise FederationError(
@@ -230,28 +219,17 @@ class FederationHandler(BaseHandler):
        if event.type == EventTypes.Member:
            if event.membership == Membership.JOIN:
-               # Only fire user_joined_room if the user has actually
-               # joined the room. Don't bother if the user is just
-               # changing their profile info.
-               newly_joined = True
-               prev_state_id = context.prev_state_ids.get(
-                   (event.type, event.state_key)
-               )
-               if prev_state_id:
-                   prev_state = yield self.store.get_event(
-                       prev_state_id, allow_none=True,
-                   )
-                   if prev_state and prev_state.membership == Membership.JOIN:
-                       newly_joined = False
-
-               if newly_joined:
+               prev_state = context.current_state.get((event.type, event.state_key))
+               if not prev_state or prev_state.membership != Membership.JOIN:
+                   # Only fire user_joined_room if the user has actually
+                   # joined the room. Don't bother if the user is just
+                   # changing their profile info.
                    user = UserID.from_string(event.state_key)
                    yield user_joined_room(self.distributor, user, event.room_id)
-   @measure_func("_filter_events_for_server")
    @defer.inlineCallbacks
    def _filter_events_for_server(self, server_name, room_id, events):
-       event_to_state_ids = yield self.store.get_state_ids_for_events(
+       event_to_state = yield self.store.get_state_for_events(
            frozenset(e.event_id for e in events),
            types=(
                (EventTypes.RoomHistoryVisibility, ""),
@@ -259,30 +237,6 @@ class FederationHandler(BaseHandler):
            )
        )

-       # We only want to pull out member events that correspond to the
-       # server's domain.
-       def check_match(id):
-           try:
-               return server_name == get_domain_from_id(id)
-           except:
-               return False
-
-       event_map = yield self.store.get_events([
-           e_id for key_to_eid in event_to_state_ids.values()
-           for key, e_id in key_to_eid.items()
-           if key[0] != EventTypes.Member or check_match(key[1])
-       ])
-
-       event_to_state = {
-           e_id: {
-               key: event_map[inner_e_id]
-               for key, inner_e_id in key_to_eid.items()
-               if inner_e_id in event_map
-           }
-           for e_id, key_to_eid in event_to_state_ids.items()
-       }
-
        def redact_disallowed(event, state):
            if not state:
                return event
@@ -299,7 +253,7 @@ class FederationHandler(BaseHandler):
                if ev.type != EventTypes.Member:
                    continue
                try:
-                   domain = get_domain_from_id(ev.state_key)
+                   domain = UserID.from_string(ev.state_key).domain
                except:
                    continue
@@ -324,16 +278,15 @@ class FederationHandler(BaseHandler):
    @log_function
    @defer.inlineCallbacks
-   def backfill(self, dest, room_id, limit, extremities):
+   def backfill(self, dest, room_id, limit, extremities=[]):
        """ Trigger a backfill request to `dest` for the given `room_id`
-
-       This will attempt to get more events from the remote. This may be
-       successful and still return no events if the other side has no new
-       events to offer.
        """
        if dest == self.server_name:
            raise SynapseError(400, "Can't backfill from self.")

+       if not extremities:
+           extremities = yield self.store.get_oldest_events_in_room(room_id)
+
        events = yield self.replication_layer.backfill(
            dest,
            room_id,
@@ -341,16 +294,6 @@ class FederationHandler(BaseHandler):
            extremities=extremities,
        )

-       # Don't bother processing events we already have.
-       seen_events = yield self.store.have_events_in_timeline(
-           set(e.event_id for e in events)
-       )
-
-       events = [e for e in events if e.event_id not in seen_events]
-
-       if not events:
-           defer.returnValue([])
-
        event_map = {e.event_id: e for e in events}
        event_ids = set(e.event_id for e in events)
@@ -382,73 +325,40 @@ class FederationHandler(BaseHandler):
            state_events.update({s.event_id: s for s in state})
            events_to_state[e_id] = state

-       required_auth = set(
-           a_id
-           for event in events + state_events.values() + auth_events.values()
-           for a_id, _ in event.auth_events
-       )
-       auth_events.update({
-           e_id: event_map[e_id] for e_id in required_auth if e_id in event_map
-       })
-       missing_auth = required_auth - set(auth_events)
-       failed_to_fetch = set()
-
-       # Try and fetch any missing auth events from both DB and remote servers.
-       # We repeatedly do this until we stop finding new auth events.
-       while missing_auth - failed_to_fetch:
-           logger.info("Missing auth for backfill: %r", missing_auth)
-           ret_events = yield self.store.get_events(missing_auth - failed_to_fetch)
-           auth_events.update(ret_events)
-           required_auth.update(
-               a_id for event in ret_events.values() for a_id, _ in event.auth_events
-           )
-           missing_auth = required_auth - set(auth_events)
-
-           if missing_auth - failed_to_fetch:
-               logger.info(
-                   "Fetching missing auth for backfill: %r",
-                   missing_auth - failed_to_fetch
-               )
-
-               results = yield preserve_context_over_deferred(defer.gatherResults(
-                   [
-                       preserve_fn(self.replication_layer.get_pdu)(
-                           [dest],
-                           event_id,
-                           outlier=True,
-                           timeout=10000,
-                       )
-                       for event_id in missing_auth - failed_to_fetch
-                   ],
-                   consumeErrors=True
-               )).addErrback(unwrapFirstError)
-               auth_events.update({a.event_id: a for a in results if a})
-               required_auth.update(
-                   a_id
-                   for event in results if event
-                   for a_id, _ in event.auth_events
-               )
-               missing_auth = required_auth - set(auth_events)
-
-           failed_to_fetch = missing_auth - set(auth_events)
-
        seen_events = yield self.store.have_events(
            set(auth_events.keys()) | set(state_events.keys())
        )

+       all_events = events + state_events.values() + auth_events.values()
+       required_auth = set(
+           a_id for event in all_events for a_id, _ in event.auth_events
+       )
+
+       missing_auth = required_auth - set(auth_events)
+       results = yield defer.gatherResults(
+           [
+               self.replication_layer.get_pdu(
+                   [dest],
+                   event_id,
+                   outlier=True,
+                   timeout=10000,
+               )
+               for event_id in missing_auth
+           ],
+           consumeErrors=True
+       ).addErrback(unwrapFirstError)
+       auth_events.update({a.event_id: a for a in results})
+
        ev_infos = []
        for a in auth_events.values():
            if a.event_id in seen_events:
                continue
+           a.internal_metadata.outlier = True
            ev_infos.append({
                "event": a,
                "auth_events": {
                    (auth_events[a_id].type, auth_events[a_id].state_key):
                        auth_events[a_id]
                    for a_id, _ in a.auth_events
-                   if a_id in auth_events
                }
            })
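The removed loop above is a fixpoint computation: fetch whatever auth events are still missing, add their own auth_events to the required set, and repeat until nothing new turns up. In the abstract (fetch is a stand-in for the DB and remote lookups):

```python
def fetch_auth_fixpoint(required, auth_events, fetch):
    failed_to_fetch = set()
    while (required - set(auth_events)) - failed_to_fetch:
        want = (required - set(auth_events)) - failed_to_fetch
        fetched = fetch(want)  # dict of event_id -> event, possibly partial
        auth_events.update(fetched)
        for ev in fetched.values():
            # each fetched event may reference further auth events
            required.update(a_id for a_id, _ in ev.auth_events)
        failed_to_fetch |= want - set(fetched)
    return auth_events
```

Each iteration either grows auth_events or grows failed_to_fetch, so the loop terminates.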
@@ -460,27 +370,23 @@ class FederationHandler(BaseHandler):
                    (auth_events[a_id].type, auth_events[a_id].state_key):
                        auth_events[a_id]
                    for a_id, _ in event_map[e_id].auth_events
-                   if a_id in auth_events
                }
            })

-       yield self._handle_new_events(
-           dest, ev_infos,
-           backfilled=True,
-       )
-
        events.sort(key=lambda e: e.depth)

        for event in events:
            if event in events_to_state:
                continue

-           # We store these one at a time since each event depends on the
-           # previous to work out the state.
-           # TODO: We can probably do something more clever here.
-           yield self._handle_new_event(
-               dest, event, backfilled=True,
-           )
+           ev_infos.append({
+               "event": event,
+           })
+
+       yield self._handle_new_events(
+           dest, ev_infos,
+           backfilled=True,
+       )

        defer.returnValue(events)
@@ -504,10 +410,6 @@ class FederationHandler(BaseHandler):
        )
        max_depth = sorted_extremeties_tuple[0][1]

-       # We don't want to specify too many extremities as it causes the backfill
-       # request URI to be too long.
-       extremities = dict(sorted_extremeties_tuple[:5])
-
        if current_depth > max_depth:
            logger.debug(
                "Not backfilling as we don't need to. %d < %d",
@@ -533,7 +435,7 @@ class FederationHandler(BaseHandler):
            joined_domains = {}
            for u, d in joined_users:
                try:
-                   dom = get_domain_from_id(u)
+                   dom = UserID.from_string(u).domain
                    old_d = joined_domains.get(dom)
                    if old_d:
                        joined_domains[dom] = min(d, old_d)
@@ -556,15 +458,11 @@ class FederationHandler(BaseHandler):
            # TODO: Should we try multiple of these at a time?
            for dom in domains:
                try:
-                   yield self.backfill(
+                   events = yield self.backfill(
                        dom, room_id,
                        limit=100,
                        extremities=[e for e in extremities.keys()]
                    )
-
-                   # If this succeeded then we probably already have the
-                   # appropriate stuff.
-                   # TODO: We can probably do something more intelligent here.
-                   defer.returnValue(True)
                except SynapseError as e:
                    logger.info(
                        "Failed to backfill from %s because %s",
@@ -590,6 +488,8 @@ class FederationHandler(BaseHandler):
                    )
                    continue

+               if events:
+                   defer.returnValue(True)
+
            defer.returnValue(False)

        success = yield try_backfill(likely_domains)
@@ -604,24 +504,12 @@ class FederationHandler(BaseHandler):
        event_ids = list(extremities.keys())

-       states = yield preserve_context_over_deferred(defer.gatherResults([
-           preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e])
-           for e in event_ids
-       ]))
+       states = yield defer.gatherResults([
+           self.state_handler.resolve_state_groups(room_id, [e])
+           for e in event_ids
+       ])
        states = dict(zip(event_ids, [s[1] for s in states]))

-       state_map = yield self.store.get_events(
-           [e_id for ids in states.values() for e_id in ids],
-           get_prev_content=False
-       )
-       states = {
-           key: {
-               k: state_map[e_id]
-               for k, e_id in state_dict.items()
-               if e_id in state_map
-           } for key, state_dict in states.items()
-       }
-
        for e_id, _ in sorted_extremeties_tuple:
            likely_domains = get_domains_from_state(states[e_id])
@@ -731,7 +619,7 @@ class FederationHandler(BaseHandler):
            pass

        event_stream_id, max_stream_id = yield self._persist_auth_tree(
-           origin, auth_chain, state, event
+           auth_chain, state, event
        )

        with PreserveLoggingContext():
@@ -773,18 +661,11 @@ class FederationHandler(BaseHandler):
            "state_key": user_id,
        })

-       try:
-           message_handler = self.hs.get_handlers().message_handler
-           event, context = yield message_handler._create_new_client_event(
-               builder=builder,
-           )
-       except AuthError as e:
-           logger.warn("Failed to create join %r because %s", event, e)
-           raise e
-
-       # The remote hasn't signed it yet, obviously. We'll do the full checks
-       # when we get the event back in `on_send_join_request`
-       yield self.auth.check_from_context(event, context, do_sig_check=False)
+       event, context = yield self._create_new_client_event(
+           builder=builder,
+       )
+
+       self.auth.check(event, auth_events=context.current_state)

        defer.returnValue(event)
@@ -832,12 +713,19 @@ class FederationHandler(BaseHandler):
        new_pdu = event

-       users_in_room = yield self.store.get_joined_users_from_context(event, context)
-
-       destinations = set(
-           get_domain_from_id(user_id) for user_id in users_in_room
-           if not self.hs.is_mine_id(user_id)
-       )
+       destinations = set()
+       for k, s in context.current_state.items():
+           try:
+               if k[0] == EventTypes.Member:
+                   if s.content["membership"] == Membership.JOIN:
+                       destinations.add(
+                           UserID.from_string(s.state_key).domain
+                       )
+           except:
+               logger.warn(
+                   "Failed to get destination from event %s", s.event_id
+               )

        destinations.discard(origin)
@@ -849,15 +737,13 @@ class FederationHandler(BaseHandler):
        self.replication_layer.send_pdu(new_pdu, destinations)

-       state_ids = context.prev_state_ids.values()
+       state_ids = [e.event_id for e in context.current_state.values()]
        auth_chain = yield self.store.get_auth_chain(set(
            [event.event_id] + state_ids
        ))

-       state = yield self.store.get_events(context.prev_state_ids.values())
-
        defer.returnValue({
-           "state": state.values(),
+           "state": context.current_state.values(),
            "auth_chain": auth_chain,
        })
@@ -1005,18 +891,11 @@ class FederationHandler(BaseHandler):
            "state_key": user_id,
        })

-       message_handler = self.hs.get_handlers().message_handler
-       event, context = yield message_handler._create_new_client_event(
+       event, context = yield self._create_new_client_event(
            builder=builder,
        )

-       try:
-           # The remote hasn't signed it yet, obviously. We'll do the full checks
-           # when we get the event back in `on_send_leave_request`
-           yield self.auth.check_from_context(event, context, do_sig_check=False)
-       except AuthError as e:
-           logger.warn("Failed to create new leave %r because %s", event, e)
-           raise e
+       self.auth.check(event, auth_events=context.current_state)

        defer.returnValue(event)
@@ -1057,12 +936,20 @@ class FederationHandler(BaseHandler):
        new_pdu = event

-       users_in_room = yield self.store.get_joined_users_from_context(event, context)
-
-       destinations = set(
-           get_domain_from_id(user_id) for user_id in users_in_room
-           if not self.hs.is_mine_id(user_id)
-       )
+       destinations = set()
+       for k, s in context.current_state.items():
+           try:
+               if k[0] == EventTypes.Member:
+                   if s.content["membership"] == Membership.LEAVE:
+                       destinations.add(
+                           UserID.from_string(s.state_key).domain
+                       )
+           except:
+               logger.warn(
+                   "Failed to get destination from event %s", s.event_id
+               )

        destinations.discard(origin)

        logger.debug(
@@ -1076,11 +963,14 @@ class FederationHandler(BaseHandler):
        defer.returnValue(None)

    @defer.inlineCallbacks
-   def get_state_for_pdu(self, room_id, event_id):
-       """Returns the state at the event. i.e. not including said event.
-       """
+   def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True):
        yield run_on_reactor()

+       if do_auth:
+           in_room = yield self.auth.check_host_in_room(room_id, origin)
+           if not in_room:
+               raise AuthError(403, "Host not in room.")
+
        state_groups = yield self.store.get_state_groups(
            room_id, [event_id]
        )
@@ -1104,49 +994,18 @@ class FederationHandler(BaseHandler):
             res = results.values()
             for event in res:
-                # We sign these again because there was a bug where we
-                # incorrectly signed things the first time round
-                if self.hs.is_mine_id(event.event_id):
-                    event.signatures.update(
-                        compute_event_signature(
-                            event,
-                            self.hs.hostname,
-                            self.hs.config.signing_key[0]
-                        )
-                    )
+                event.signatures.update(
+                    compute_event_signature(
+                        event,
+                        self.hs.hostname,
+                        self.hs.config.signing_key[0]
+                    )
+                )
 
             defer.returnValue(res)
         else:
             defer.returnValue([])
-    @defer.inlineCallbacks
-    def get_state_ids_for_pdu(self, room_id, event_id):
-        """Returns the state at the event. i.e. not including said event.
-        """
-        yield run_on_reactor()
-
-        state_groups = yield self.store.get_state_groups_ids(
-            room_id, [event_id]
-        )
-
-        if state_groups:
-            _, state = state_groups.items().pop()
-            results = state
-
-            event = yield self.store.get_event(event_id)
-            if event and event.is_state():
-                # Get previous state
-                if "replaces_state" in event.unsigned:
-                    prev_id = event.unsigned["replaces_state"]
-                    if prev_id != event.event_id:
-                        results[(event.type, event.state_key)] = prev_id
-                else:
-                    del results[(event.type, event.state_key)]
-
-            defer.returnValue(results.values())
-        else:
-            defer.returnValue([])
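The removed get_state_ids_for_pdu above answers "what was the state before this event": if the event is itself a state event, it is swapped for the event it replaced via unsigned["replaces_state"]. A sketch of just that rewind step, under the same assumptions about the event's unsigned metadata:

    def rewind_to_state_before(results, event):
        # results maps (type, state_key) -> event_id at the event itself
        if event and event.is_state():
            prev_id = event.unsigned.get("replaces_state")
            if prev_id and prev_id != event.event_id:
                results[(event.type, event.state_key)] = prev_id
            else:
                # No predecessor: the event introduced this state key.
                results.pop((event.type, event.state_key), None)
        return list(results.values())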
     @defer.inlineCallbacks
     @log_function
     def on_backfill_request(self, origin, room_id, pdu_list, limit):
@@ -1179,17 +1038,16 @@ class FederationHandler(BaseHandler):
         )
 
         if event:
-            if self.hs.is_mine_id(event.event_id):
-                # FIXME: This is a temporary work around where we occasionally
-                # return events slightly differently than when they were
-                # originally signed
-                event.signatures.update(
-                    compute_event_signature(
-                        event,
-                        self.hs.hostname,
-                        self.hs.config.signing_key[0]
-                    )
-                )
+            # FIXME: This is a temporary work around where we occasionally
+            # return events slightly differently than when they were
+            # originally signed
+            event.signatures.update(
+                compute_event_signature(
+                    event,
+                    self.hs.hostname,
+                    self.hs.config.signing_key[0]
+                )
+            )
 
             if do_auth:
                 in_room = yield self.auth.check_host_in_room(
@@ -1199,12 +1057,6 @@ class FederationHandler(BaseHandler):
                 if not in_room:
                     raise AuthError(403, "Host not in room.")
 
-                events = yield self._filter_events_for_server(
-                    origin, event.room_id, [event]
-                )
-                event = events[0]
-
             defer.returnValue(event)
         else:
             defer.returnValue(None)
@@ -1213,10 +1065,18 @@ class FederationHandler(BaseHandler):
     def get_min_depth_for_context(self, context):
         return self.store.get_min_depth(context)
 
+    @log_function
+    def user_joined_room(self, user, room_id):
+        waiters = self.waiting_for_join_list.get(
+            (user.to_string(), room_id),
+            []
+        )
+        while waiters:
+            waiters.pop().callback(None)
     @defer.inlineCallbacks
     @log_function
-    def _handle_new_event(self, origin, event, state=None, auth_events=None,
-                          backfilled=False):
+    def _handle_new_event(self, origin, event, state=None, auth_events=None):
         context = yield self._prep_event(
             origin, event,
             state=state,
@@ -1226,34 +1086,21 @@ class FederationHandler(BaseHandler):
         if not event.internal_metadata.is_outlier():
             action_generator = ActionGenerator(self.hs)
             yield action_generator.handle_push_actions_for_event(
-                event, context
+                event, context, self
             )
 
         event_stream_id, max_stream_id = yield self.store.persist_event(
             event,
             context=context,
-            backfilled=backfilled,
         )
 
-        if not backfilled:
-            # this intentionally does not yield: we don't care about the result
-            # and don't need to wait for it.
-            preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
-                event_stream_id, max_stream_id
-            )
-
         defer.returnValue((context, event_stream_id, max_stream_id))
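The removed block above pokes the pusher pool without yielding. That fire-and-forget shape under Twisted, sketched with preserve_fn so the log context survives (names follow the diff; treat the pusherpool argument as an assumption):

    from synapse.util.logcontext import preserve_fn

    def notify_pushers(pusherpool, event_stream_id, max_stream_id):
        # Intentionally not yielded: the result is not needed, and waiting
        # for push processing would slow down event persistence.
        preserve_fn(pusherpool.on_new_notifications)(
            event_stream_id, max_stream_id
        )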
     @defer.inlineCallbacks
     def _handle_new_events(self, origin, event_infos, backfilled=False):
-        """Creates the appropriate contexts and persists events. The events
-        should not depend on one another, e.g. this should be used to persist
-        a bunch of outliers, but not a chunk of individual events that depend
-        on each other for state calculations.
-        """
-        contexts = yield preserve_context_over_deferred(defer.gatherResults(
+        contexts = yield defer.gatherResults(
             [
-                preserve_fn(self._prep_event)(
+                self._prep_event(
                     origin,
                     ev_info["event"],
                     state=ev_info.get("state"),
@@ -1261,7 +1108,7 @@ class FederationHandler(BaseHandler):
                 )
                 for ev_info in event_infos
             ]
-        ))
+        )
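The removed side of this hunk prepares independent events in parallel; each call is wrapped in preserve_fn and the combined Deferred in preserve_context_over_deferred so Synapse's log contexts survive defer.gatherResults. A minimal sketch of that wrapping, assuming a handler with the _prep_event shown above:

    from twisted.internet import defer
    from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred

    @defer.inlineCallbacks
    def prep_all(handler, origin, event_infos):
        contexts = yield preserve_context_over_deferred(defer.gatherResults([
            preserve_fn(handler._prep_event)(
                origin, ev_info["event"], state=ev_info.get("state"),
            )
            for ev_info in event_infos
        ]))
        defer.returnValue(contexts)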
         yield self.store.persist_events(
             [
@@ -1272,19 +1119,11 @@ class FederationHandler(BaseHandler):
         )
 
     @defer.inlineCallbacks
-    def _persist_auth_tree(self, origin, auth_events, state, event):
+    def _persist_auth_tree(self, auth_events, state, event):
         """Checks the auth chain is valid (and passes auth checks) for the
         state and event. Then persists the auth chain and state atomically.
         Persists the event separately.
-
-        Will attempt to fetch missing auth events.
-
-        Args:
-            origin (str): Where the events came from
-            auth_events (list)
-            state (list)
-            event (Event)
 
         Returns:
             2-tuple of (event_stream_id, max_stream_id) from the persist_event
             call for `event`
@@ -1297,7 +1136,7 @@ class FederationHandler(BaseHandler):
         event_map = {
             e.event_id: e
-            for e in itertools.chain(auth_events, state, [event])
+            for e in auth_events
         }
         create_event = None
@@ -1306,29 +1145,10 @@ class FederationHandler(BaseHandler):
                 create_event = e
                 break
 
-        missing_auth_events = set()
-        for e in itertools.chain(auth_events, state, [event]):
-            for e_id, _ in e.auth_events:
-                if e_id not in event_map:
-                    missing_auth_events.add(e_id)
-
-        for e_id in missing_auth_events:
-            m_ev = yield self.replication_layer.get_pdu(
-                [origin],
-                e_id,
-                outlier=True,
-                timeout=10000,
-            )
-            if m_ev and m_ev.event_id == e_id:
-                event_map[e_id] = m_ev
-            else:
-                logger.info("Failed to find auth event %r", e_id)
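The block removed above backfills auth events that are referenced but locally unknown, pulling each one from the origin as an outlier. The same loop in isolation (a sketch; logger and the replication layer's get_pdu signature are taken from the diff itself):

    import logging
    from twisted.internet import defer

    logger = logging.getLogger(__name__)

    @defer.inlineCallbacks
    def fetch_missing_auth_events(replication_layer, origin, event_map, missing):
        for e_id in missing:
            m_ev = yield replication_layer.get_pdu(
                [origin], e_id, outlier=True, timeout=10000,
            )
            if m_ev and m_ev.event_id == e_id:
                event_map[e_id] = m_ev
            else:
                logger.info("Failed to find auth event %r", e_id)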
         for e in itertools.chain(auth_events, state, [event]):
             auth_for_e = {
                 (event_map[e_id].type, event_map[e_id].state_key): event_map[e_id]
                 for e_id, _ in e.auth_events
-                if e_id in event_map
             }
             if create_event:
                 auth_for_e[(EventTypes.Create, "")] = create_event
@@ -1377,13 +1197,7 @@ class FederationHandler(BaseHandler):
         )
 
         if not auth_events:
-            auth_events_ids = yield self.auth.compute_auth_events(
-                event, context.prev_state_ids, for_verification=True,
-            )
-            auth_events = yield self.store.get_events(auth_events_ids)
-            auth_events = {
-                (e.type, e.state_key): e for e in auth_events.values()
-            }
+            auth_events = context.current_state
         # This is a hack to fix some old rooms where the initial join event
         # didn't reference the create event in its auth events.
@@ -1409,7 +1223,8 @@ class FederationHandler(BaseHandler):
             context.rejected = RejectedReason.AUTH_ERROR
 
         if event.type == EventTypes.GuestAccess:
-            yield self.maybe_kick_guest_users(event)
+            full_context = yield self.store.get_current_state(room_id=event.room_id)
+            yield self.maybe_kick_guest_users(event, full_context)
 
         defer.returnValue(context)
@@ -1477,11 +1292,6 @@ class FederationHandler(BaseHandler):
         current_state = set(e.event_id for e in auth_events.values())
         event_auth_events = set(e_id for e_id, _ in event.auth_events)
 
-        if event.is_state():
-            event_key = (event.type, event.state_key)
-        else:
-            event_key = None
-
         if event_auth_events - current_state:
             have_events = yield self.store.have_events(
                 event_auth_events - current_state
@@ -1555,9 +1365,9 @@ class FederationHandler(BaseHandler):
             # Do auth conflict res.
             logger.info("Different auth: %s", different_auth)
 
-            different_events = yield preserve_context_over_deferred(defer.gatherResults(
+            different_events = yield defer.gatherResults(
                 [
-                    preserve_fn(self.store.get_event)(
+                    self.store.get_event(
                         d,
                         allow_none=True,
                         allow_rejected=False,
@@ -1566,13 +1376,13 @@ class FederationHandler(BaseHandler):
                     if d in have_events and not have_events[d]
                 ],
                 consumeErrors=True
-            )).addErrback(unwrapFirstError)
+            ).addErrback(unwrapFirstError)
 
             if different_events:
                 local_view = dict(auth_events)
                 remote_view = dict(auth_events)
                 remote_view.update({
-                    (d.type, d.state_key): d for d in different_events if d
+                    (d.type, d.state_key): d for d in different_events
                 })
 
                 new_state, prev_state = self.state_handler.resolve_events(
@@ -1585,16 +1395,8 @@ class FederationHandler(BaseHandler):
                 current_state = set(e.event_id for e in auth_events.values())
                 different_auth = event_auth_events - current_state
 
-                context.current_state_ids = dict(context.current_state_ids)
-                context.current_state_ids.update({
-                    k: a.event_id for k, a in auth_events.items()
-                    if k != event_key
-                })
-                context.prev_state_ids = dict(context.prev_state_ids)
-                context.prev_state_ids.update({
-                    k: a.event_id for k, a in auth_events.items()
-                })
-                context.state_group = self.store.get_next_state_group()
+                context.current_state.update(auth_events)
+                context.state_group = None
 
         if different_auth and not event.internal_metadata.is_outlier():
             logger.info("Different auth after resolution: %s", different_auth)
@@ -1615,8 +1417,8 @@ class FederationHandler(BaseHandler):
             if do_resolution:
                 # 1. Get what we think is the auth chain.
-                auth_ids = yield self.auth.compute_auth_events(
-                    event, context.prev_state_ids
+                auth_ids = self.auth.compute_auth_events(
+                    event, context.current_state
                 )
                 local_auth_chain = yield self.store.get_auth_chain(auth_ids)
@@ -1672,22 +1474,13 @@ class FederationHandler(BaseHandler):
                     # 4. Look at rejects and their proofs.
                     # TODO.
 
-                context.current_state_ids = dict(context.current_state_ids)
-                context.current_state_ids.update({
-                    k: a.event_id for k, a in auth_events.items()
-                    if k != event_key
-                })
-                context.prev_state_ids = dict(context.prev_state_ids)
-                context.prev_state_ids.update({
-                    k: a.event_id for k, a in auth_events.items()
-                })
-                context.state_group = self.store.get_next_state_group()
+                context.current_state.update(auth_events)
+                context.state_group = None
 
         try:
             self.auth.check(event, auth_events=auth_events)
-        except AuthError as e:
-            logger.warn("Failed auth resolution for %r because %s", event, e)
-            raise e
+        except AuthError:
+            raise
     @defer.inlineCallbacks
     def construct_auth_difference(self, local_auth, remote_auth):
@@ -1857,22 +1650,14 @@ class FederationHandler(BaseHandler):
         if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
             builder = self.event_builder_factory.new(event_dict)
             EventValidator().validate_new(builder)
-            message_handler = self.hs.get_handlers().message_handler
-            event, context = yield message_handler._create_new_client_event(
-                builder=builder
-            )
+            event, context = yield self._create_new_client_event(builder=builder)
 
             event, context = yield self.add_display_name_to_third_party_invite(
                 event_dict, event, context
             )
 
-            try:
-                yield self.auth.check_from_context(event, context)
-            except AuthError as e:
-                logger.warn("Denying new third party invite %r because %s", event, e)
-                raise e
-            yield self._check_signature(event, context)
+            self.auth.check(event, context.current_state)
+            yield self._check_signature(event, auth_events=context.current_state)
+
             member_handler = self.hs.get_handlers().room_member_handler
             yield member_handler.send_membership_event(None, event, context)
         else:
@@ -1888,8 +1673,7 @@ class FederationHandler(BaseHandler):
     def on_exchange_third_party_invite_request(self, origin, room_id, event_dict):
         builder = self.event_builder_factory.new(event_dict)
 
-        message_handler = self.hs.get_handlers().message_handler
-        event, context = yield message_handler._create_new_client_event(
+        event, context = yield self._create_new_client_event(
             builder=builder,
         )
@@ -1897,12 +1681,8 @@ class FederationHandler(BaseHandler):
             event_dict, event, context
         )
 
-        try:
-            self.auth.check_from_context(event, context)
-        except AuthError as e:
-            logger.warn("Denying third party invite %r because %s", event, e)
-            raise e
-        yield self._check_signature(event, context)
+        self.auth.check(event, auth_events=context.current_state)
+        yield self._check_signature(event, auth_events=context.current_state)
 
         returned_invite = yield self.send_invite(origin, event)
         # TODO: Make sure the signatures actually are correct.
@@ -1916,38 +1696,29 @@ class FederationHandler(BaseHandler):
             EventTypes.ThirdPartyInvite,
             event.content["third_party_invite"]["signed"]["token"]
         )
 
-        original_invite = None
-        original_invite_id = context.prev_state_ids.get(key)
-        if original_invite_id:
-            original_invite = yield self.store.get_event(
-                original_invite_id, allow_none=True
-            )
-        if original_invite:
-            display_name = original_invite.content["display_name"]
-            event_dict["content"]["third_party_invite"]["display_name"] = display_name
-        else:
-            logger.info(
-                "Could not find invite event for third_party_invite: %r",
-                event_dict
-            )
-            # We don't discard here as this is not the appropriate place to do
-            # auth checks. If we need the invite and don't have it then the
-            # auth check code will explode appropriately.
+        original_invite = context.current_state.get(key)
+        if not original_invite:
+            logger.info(
+                "Could not find invite event for third_party_invite - "
+                "discarding: %s" % (event_dict,)
+            )
+            return
+
+        display_name = original_invite.content["display_name"]
+        event_dict["content"]["third_party_invite"]["display_name"] = display_name
 
         builder = self.event_builder_factory.new(event_dict)
         EventValidator().validate_new(builder)
-        message_handler = self.hs.get_handlers().message_handler
-        event, context = yield message_handler._create_new_client_event(builder=builder)
+        event, context = yield self._create_new_client_event(builder=builder)
 
         defer.returnValue((event, context))
     @defer.inlineCallbacks
-    def _check_signature(self, event, context):
+    def _check_signature(self, event, auth_events):
         """
         Checks that the signature in the event is consistent with its invite.
 
         Args:
             event (Event): The m.room.member event to check
-            context (EventContext):
+            auth_events (dict<(event type, state_key), event>):
 
         Raises:
             AuthError: if signature didn't match any keys, or key has been
@@ -1958,14 +1729,10 @@ class FederationHandler(BaseHandler):
         signed = event.content["third_party_invite"]["signed"]
         token = signed["token"]
 
-        invite_event_id = context.prev_state_ids.get(
+        invite_event = auth_events.get(
             (EventTypes.ThirdPartyInvite, token,)
         )
 
-        invite_event = None
-        if invite_event_id:
-            invite_event = yield self.store.get_event(invite_event_id, allow_none=True)
-
         if not invite_event:
             raise AuthError(403, "Could not find invite")

synapse/handlers/identity.py

@@ -21,7 +21,7 @@ from synapse.api.errors import (
 )
 from ._base import BaseHandler
 from synapse.util.async import run_on_reactor
-from synapse.api.errors import SynapseError, Codes
+from synapse.api.errors import SynapseError
 
 import json
 import logging
@@ -41,20 +41,6 @@ class IdentityHandler(BaseHandler):
             hs.config.use_insecure_ssl_client_just_for_testing_do_not_use
         )
 
-    def _should_trust_id_server(self, id_server):
-        if id_server not in self.trusted_id_servers:
-            if self.trust_any_id_server_just_for_testing_do_not_use:
-                logger.warn(
-                    "Trusting untrustworthy ID server %r even though it isn't"
-                    " in the trusted id list for testing because"
-                    " 'use_insecure_ssl_client_just_for_testing_do_not_use'"
-                    " is set in the config",
-                    id_server,
-                )
-            else:
-                return False
-        return True
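The removed helper centralises the "do we trust this identity server" decision that the older code inlines at each call site (see the threepid_from_creds hunk below). A standalone sketch of the same logic, with the config flags passed in explicitly:

    import logging

    logger = logging.getLogger(__name__)

    def should_trust_id_server(trusted_id_servers, id_server,
                               allow_any_for_testing=False):
        if id_server in trusted_id_servers:
            return True
        if allow_any_for_testing:
            # Testing-only escape hatch, mirroring
            # use_insecure_ssl_client_just_for_testing_do_not_use.
            logger.warn(
                "Trusting untrustworthy ID server %r for testing only", id_server
            )
            return True
        return False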
     @defer.inlineCallbacks
     def threepid_from_creds(self, creds):
         yield run_on_reactor()
@@ -73,12 +59,19 @@ class IdentityHandler(BaseHandler):
         else:
             raise SynapseError(400, "No client_secret in creds")
 
-        if not self._should_trust_id_server(id_server):
-            logger.warn(
-                '%s is not a trusted ID server: rejecting 3pid ' +
-                'credentials', id_server
-            )
-            defer.returnValue(None)
+        if id_server not in self.trusted_id_servers:
+            if self.trust_any_id_server_just_for_testing_do_not_use:
+                logger.warn(
+                    "Trusting untrustworthy ID server %r even though it isn't"
+                    " in the trusted id list for testing because"
+                    " 'use_insecure_ssl_client_just_for_testing_do_not_use'"
+                    " is set in the config",
+                    id_server,
+                )
+            else:
+                logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
+                            'credentials', id_server)
+                defer.returnValue(None)
 
         data = {}
         try:
@@ -136,12 +129,6 @@ class IdentityHandler(BaseHandler):
     def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs):
         yield run_on_reactor()
 
-        if not self._should_trust_id_server(id_server):
-            raise SynapseError(
-                400, "Untrusted ID server '%s'" % id_server,
-                Codes.SERVER_NOT_TRUSTED
-            )
-
         params = {
             'email': email,
             'client_secret': client_secret,

synapse/handlers/initial_sync.py

@@ -1,443 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
from synapse.streams.config import PaginationConfig
from synapse.types import (
UserID, StreamToken,
)
from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute
from synapse.util.caches.snapshot_cache import SnapshotCache
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
import logging
logger = logging.getLogger(__name__)
class InitialSyncHandler(BaseHandler):
def __init__(self, hs):
super(InitialSyncHandler, self).__init__(hs)
self.hs = hs
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
self.validator = EventValidator()
self.snapshot_cache = SnapshotCache()
def snapshot_all_rooms(self, user_id=None, pagin_config=None,
as_client_event=True, include_archived=False):
"""Retrieve a snapshot of all rooms the user is invited or has joined.
This snapshot may include messages for all rooms where the user is
joined, depending on the pagination config.
Args:
user_id (str): The ID of the user making the request.
pagin_config (synapse.api.streams.PaginationConfig): The pagination
config used to determine how many messages *PER ROOM* to return.
as_client_event (bool): True to get events in client-server format.
include_archived (bool): True to get rooms that the user has left
Returns:
A list of dicts with "room_id" and "membership" keys for all rooms
the user is currently invited or joined in on. Rooms where the user
is joined on, may return a "messages" key with messages, depending
on the specified PaginationConfig.
"""
key = (
user_id,
pagin_config.from_token,
pagin_config.to_token,
pagin_config.direction,
pagin_config.limit,
as_client_event,
include_archived,
)
now_ms = self.clock.time_msec()
result = self.snapshot_cache.get(now_ms, key)
if result is not None:
return result
return self.snapshot_cache.set(now_ms, key, self._snapshot_all_rooms(
user_id, pagin_config, as_client_event, include_archived
))
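snapshot_all_rooms memoises whole responses in a SnapshotCache keyed on every parameter that affects the result. The exact cache semantics are assumed from its use here: get() may return a still-fresh cached Deferred, and set() stores (and returns) the new one under the same key. The shape of that pattern:

    def cached_snapshot(snapshot_cache, clock, key, compute):
        # compute is a zero-argument callable returning a Deferred.
        now_ms = clock.time_msec()
        result = snapshot_cache.get(now_ms, key)
        if result is not None:
            return result
        return snapshot_cache.set(now_ms, key, compute())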
@defer.inlineCallbacks
def _snapshot_all_rooms(self, user_id=None, pagin_config=None,
as_client_event=True, include_archived=False):
memberships = [Membership.INVITE, Membership.JOIN]
if include_archived:
memberships.append(Membership.LEAVE)
room_list = yield self.store.get_rooms_for_user_where_membership_is(
user_id=user_id, membership_list=memberships
)
user = UserID.from_string(user_id)
rooms_ret = []
now_token = yield self.hs.get_event_sources().get_current_token()
presence_stream = self.hs.get_event_sources().sources["presence"]
pagination_config = PaginationConfig(from_token=now_token)
presence, _ = yield presence_stream.get_pagination_rows(
user, pagination_config.get_source_config("presence"), None
)
receipt_stream = self.hs.get_event_sources().sources["receipt"]
receipt, _ = yield receipt_stream.get_pagination_rows(
user, pagination_config.get_source_config("receipt"), None
)
tags_by_room = yield self.store.get_tags_for_user(user_id)
account_data, account_data_by_room = (
yield self.store.get_account_data_for_user(user_id)
)
public_room_ids = yield self.store.get_public_room_ids()
limit = pagin_config.limit
if limit is None:
limit = 10
@defer.inlineCallbacks
def handle_room(event):
d = {
"room_id": event.room_id,
"membership": event.membership,
"visibility": (
"public" if event.room_id in public_room_ids
else "private"
),
}
if event.membership == Membership.INVITE:
time_now = self.clock.time_msec()
d["inviter"] = event.sender
invite_event = yield self.store.get_event(event.event_id)
d["invite"] = serialize_event(invite_event, time_now, as_client_event)
rooms_ret.append(d)
if event.membership not in (Membership.JOIN, Membership.LEAVE):
return
try:
if event.membership == Membership.JOIN:
room_end_token = now_token.room_key
deferred_room_state = self.state_handler.get_current_state(
event.room_id
)
elif event.membership == Membership.LEAVE:
room_end_token = "s%d" % (event.stream_ordering,)
deferred_room_state = self.store.get_state_for_events(
[event.event_id], None
)
deferred_room_state.addCallback(
lambda states: states[event.event_id]
)
(messages, token), current_state = yield preserve_context_over_deferred(
defer.gatherResults(
[
preserve_fn(self.store.get_recent_events_for_room)(
event.room_id,
limit=limit,
end_token=room_end_token,
),
deferred_room_state,
]
)
).addErrback(unwrapFirstError)
messages = yield filter_events_for_client(
self.store, user_id, messages
)
start_token = now_token.copy_and_replace("room_key", token[0])
end_token = now_token.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec()
d["messages"] = {
"chunk": [
serialize_event(m, time_now, as_client_event)
for m in messages
],
"start": start_token.to_string(),
"end": end_token.to_string(),
}
d["state"] = [
serialize_event(c, time_now, as_client_event)
for c in current_state.values()
]
account_data_events = []
tags = tags_by_room.get(event.room_id)
if tags:
account_data_events.append({
"type": "m.tag",
"content": {"tags": tags},
})
account_data = account_data_by_room.get(event.room_id, {})
for account_data_type, content in account_data.items():
account_data_events.append({
"type": account_data_type,
"content": content,
})
d["account_data"] = account_data_events
except:
logger.exception("Failed to get snapshot")
yield concurrently_execute(handle_room, room_list, 10)
account_data_events = []
for account_data_type, content in account_data.items():
account_data_events.append({
"type": account_data_type,
"content": content,
})
ret = {
"rooms": rooms_ret,
"presence": presence,
"account_data": account_data_events,
"receipts": receipt,
"end": now_token.to_string(),
}
defer.returnValue(ret)
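The per-room work above is fanned out with concurrently_execute(handle_room, room_list, 10), i.e. at most ten rooms are processed at once. An illustrative stand-in for that helper (the real one lives in synapse.util.async; this sketch only shows the limited-concurrency idea):

    from twisted.internet import defer

    def concurrently_execute_sketch(f, args, limit):
        it = iter(args)

        @defer.inlineCallbacks
        def worker():
            # Workers share one iterator, so each argument is handled once.
            for arg in it:
                yield f(arg)

        return defer.gatherResults(
            [worker() for _ in range(limit)], consumeErrors=True,
        )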
@defer.inlineCallbacks
def room_initial_sync(self, requester, room_id, pagin_config=None):
"""Capture the a snapshot of a room. If user is currently a member of
the room this will be what is currently in the room. If the user left
the room this will be what was in the room when they left.
Args:
requester(Requester): The user to get a snapshot for.
room_id(str): The room to get a snapshot of.
pagin_config(synapse.streams.config.PaginationConfig):
The pagination config used to determine how many messages to
return.
Raises:
AuthError if the user wasn't in the room.
Returns:
A JSON serialisable dict with the snapshot of the room.
"""
user_id = requester.user.to_string()
membership, member_event_id = yield self._check_in_room_or_world_readable(
room_id, user_id,
)
is_peeking = member_event_id is None
if membership == Membership.JOIN:
result = yield self._room_initial_sync_joined(
user_id, room_id, pagin_config, membership, is_peeking
)
elif membership == Membership.LEAVE:
result = yield self._room_initial_sync_parted(
user_id, room_id, pagin_config, membership, member_event_id, is_peeking
)
account_data_events = []
tags = yield self.store.get_tags_for_room(user_id, room_id)
if tags:
account_data_events.append({
"type": "m.tag",
"content": {"tags": tags},
})
account_data = yield self.store.get_account_data_for_room(user_id, room_id)
for account_data_type, content in account_data.items():
account_data_events.append({
"type": account_data_type,
"content": content,
})
result["account_data"] = account_data_events
defer.returnValue(result)
@defer.inlineCallbacks
def _room_initial_sync_parted(self, user_id, room_id, pagin_config,
membership, member_event_id, is_peeking):
room_state = yield self.store.get_state_for_events(
[member_event_id], None
)
room_state = room_state[member_event_id]
limit = pagin_config.limit if pagin_config else None
if limit is None:
limit = 10
stream_token = yield self.store.get_stream_token_for_event(
member_event_id
)
messages, token = yield self.store.get_recent_events_for_room(
room_id,
limit=limit,
end_token=stream_token
)
messages = yield filter_events_for_client(
self.store, user_id, messages, is_peeking=is_peeking
)
start_token = StreamToken.START.copy_and_replace("room_key", token[0])
end_token = StreamToken.START.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec()
defer.returnValue({
"membership": membership,
"room_id": room_id,
"messages": {
"chunk": [serialize_event(m, time_now) for m in messages],
"start": start_token.to_string(),
"end": end_token.to_string(),
},
"state": [serialize_event(s, time_now) for s in room_state.values()],
"presence": [],
"receipts": [],
})
@defer.inlineCallbacks
def _room_initial_sync_joined(self, user_id, room_id, pagin_config,
membership, is_peeking):
current_state = yield self.state.get_current_state(
room_id=room_id,
)
# TODO: These concurrently
time_now = self.clock.time_msec()
state = [
serialize_event(x, time_now)
for x in current_state.values()
]
now_token = yield self.hs.get_event_sources().get_current_token()
limit = pagin_config.limit if pagin_config else None
if limit is None:
limit = 10
room_members = [
m for m in current_state.values()
if m.type == EventTypes.Member
and m.content["membership"] == Membership.JOIN
]
presence_handler = self.hs.get_presence_handler()
@defer.inlineCallbacks
def get_presence():
states = yield presence_handler.get_states(
[m.user_id for m in room_members],
as_event=True,
)
defer.returnValue(states)
@defer.inlineCallbacks
def get_receipts():
receipts_handler = self.hs.get_handlers().receipts_handler
receipts = yield receipts_handler.get_receipts_for_room(
room_id,
now_token.receipt_key
)
defer.returnValue(receipts)
presence, receipts, (messages, token) = yield defer.gatherResults(
[
preserve_fn(get_presence)(),
preserve_fn(get_receipts)(),
preserve_fn(self.store.get_recent_events_for_room)(
room_id,
limit=limit,
end_token=now_token.room_key,
)
],
consumeErrors=True,
).addErrback(unwrapFirstError)
messages = yield filter_events_for_client(
self.store, user_id, messages, is_peeking=is_peeking,
)
start_token = now_token.copy_and_replace("room_key", token[0])
end_token = now_token.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec()
ret = {
"room_id": room_id,
"messages": {
"chunk": [serialize_event(m, time_now) for m in messages],
"start": start_token.to_string(),
"end": end_token.to_string(),
},
"state": state,
"presence": presence,
"receipts": receipts,
}
if not is_peeking:
ret["membership"] = membership
defer.returnValue(ret)
@defer.inlineCallbacks
def _check_in_room_or_world_readable(self, room_id, user_id):
try:
# check_user_was_in_room will return the most recent membership
# event for the user if:
# * The user is a non-guest user, and was ever in the room
# * The user is a guest user, and has joined the room
# else it will throw.
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
defer.returnValue((member_event.membership, member_event.event_id))
return
except AuthError:
visibility = yield self.state_handler.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, ""
)
if (
visibility and
visibility.content["history_visibility"] == "world_readable"
):
defer.returnValue((Membership.JOIN, None))
return
raise AuthError(
403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
)

synapse/handlers/message.py

@@ -17,17 +17,13 @@ from twisted.internet import defer
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import AuthError, Codes, SynapseError
-from synapse.crypto.event_signing import add_hashes_and_signatures
+from synapse.streams.config import PaginationConfig
 from synapse.events.utils import serialize_event
 from synapse.events.validator import EventValidator
-from synapse.push.action_generator import ActionGenerator
-from synapse.types import (
-    UserID, RoomAlias, RoomStreamToken, get_domain_from_id
-)
-from synapse.util.async import run_on_reactor, ReadWriteLock
-from synapse.util.logcontext import preserve_fn
-from synapse.util.metrics import measure_func
-from synapse.visibility import filter_events_for_client
+from synapse.util import unwrapFirstError
+from synapse.util.async import concurrently_execute
+from synapse.util.caches.snapshot_cache import SnapshotCache
+from synapse.types import UserID, RoomStreamToken, StreamToken
 
 from ._base import BaseHandler
@@ -46,24 +42,11 @@ class MessageHandler(BaseHandler):
         self.state = hs.get_state_handler()
         self.clock = hs.get_clock()
         self.validator = EventValidator()
-        self.pagination_lock = ReadWriteLock()
-
-    @defer.inlineCallbacks
-    def purge_history(self, room_id, event_id):
-        event = yield self.store.get_event(event_id)
-
-        if event.room_id != room_id:
-            raise SynapseError(400, "Event is for wrong room.")
-
-        depth = event.depth
-
-        with (yield self.pagination_lock.write(room_id)):
-            yield self.store.delete_old_state(room_id, depth)
+        self.snapshot_cache = SnapshotCache()
     @defer.inlineCallbacks
     def get_messages(self, requester, room_id=None, pagin_config=None,
-                     as_client_event=True, event_filter=None):
+                     as_client_event=True):
         """Get messages in a room.
 
         Args:
@@ -72,11 +55,11 @@ class MessageHandler(BaseHandler):
             pagin_config (synapse.api.streams.PaginationConfig): The pagination
                 config rules to apply, if any.
             as_client_event (bool): True to get events in client-server format.
-            event_filter (Filter): Filter to apply to results or None
 
         Returns:
             dict: Pagination API results
         """
         user_id = requester.user.to_string()
+        data_source = self.hs.get_event_sources().sources["room"]
 
         if pagin_config.from_token:
             room_token = pagin_config.from_token.room_key
@@ -96,48 +79,42 @@ class MessageHandler(BaseHandler):
         source_config = pagin_config.get_source_config("room")
 
-        with (yield self.pagination_lock.read(room_id)):
-            membership, member_event_id = yield self._check_in_room_or_world_readable(
-                room_id, user_id
-            )
-
-            if source_config.direction == 'b':
-                # if we're going backwards, we might need to backfill. This
-                # requires that we have a topo token.
-                if room_token.topological:
-                    max_topo = room_token.topological
-                else:
-                    max_topo = yield self.store.get_max_topological_token(
-                        room_id, room_token.stream
-                    )
-
-                if membership == Membership.LEAVE:
-                    # If they have left the room then clamp the token to be before
-                    # they left the room, to save the effort of loading from the
-                    # database.
-                    leave_token = yield self.store.get_topological_token_for_event(
-                        member_event_id
-                    )
-                    leave_token = RoomStreamToken.parse(leave_token)
-                    if leave_token.topological < max_topo:
-                        source_config.from_key = str(leave_token)
-
-                yield self.hs.get_handlers().federation_handler.maybe_backfill(
-                    room_id, max_topo
-                )
-
-            events, next_key = yield self.store.paginate_room_events(
-                room_id=room_id,
-                from_key=source_config.from_key,
-                to_key=source_config.to_key,
-                direction=source_config.direction,
-                limit=source_config.limit,
-                event_filter=event_filter,
-            )
-
-            next_token = pagin_config.from_token.copy_and_replace(
-                "room_key", next_key
-            )
+        membership, member_event_id = yield self._check_in_room_or_world_readable(
+            room_id, user_id
+        )
+
+        if source_config.direction == 'b':
+            # if we're going backwards, we might need to backfill. This
+            # requires that we have a topo token.
+            if room_token.topological:
+                max_topo = room_token.topological
+            else:
+                max_topo = yield self.store.get_max_topological_token_for_stream_and_room(
+                    room_id, room_token.stream
+                )
+
+            if membership == Membership.LEAVE:
+                # If they have left the room then clamp the token to be before
+                # they left the room, to save the effort of loading from the
+                # database.
+                leave_token = yield self.store.get_topological_token_for_event(
+                    member_event_id
+                )
+                leave_token = RoomStreamToken.parse(leave_token)
+                if leave_token.topological < max_topo:
+                    source_config.from_key = str(leave_token)
+
+            yield self.hs.get_handlers().federation_handler.maybe_backfill(
+                room_id, max_topo
+            )
+
+        events, next_key = yield data_source.get_pagination_rows(
+            requester.user, source_config, room_id
+        )
+
+        next_token = pagin_config.from_token.copy_and_replace(
+            "room_key", next_key
+        )
         if not events:
             defer.returnValue({
@@ -146,11 +123,7 @@ class MessageHandler(BaseHandler):
"end": next_token.to_string(), "end": next_token.to_string(),
}) })
if event_filter: events = yield self._filter_events_for_client(
events = event_filter.filter(events)
events = yield filter_events_for_client(
self.store,
user_id, user_id,
events, events,
is_peeking=(member_event_id is None), is_peeking=(member_event_id is None),
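Both sides of the get_messages hunks above clamp backwards pagination for a requester who has already left the room, so it starts no later than their leave event. That clamp in isolation (a sketch; the store methods are those named in the diff):

    from twisted.internet import defer
    from synapse.types import RoomStreamToken

    @defer.inlineCallbacks
    def clamp_token_for_leaver(store, source_config, member_event_id, max_topo):
        leave_token = yield store.get_topological_token_for_event(member_event_id)
        leave_token = RoomStreamToken.parse(leave_token)
        if leave_token.topological < max_topo:
            # Paginate from the leave event rather than the live edge.
            source_config.from_key = str(leave_token)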
@@ -244,7 +217,7 @@ class MessageHandler(BaseHandler):
         assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
 
         if event.is_state():
-            prev_state = yield self.deduplicate_state_event(event, context)
+            prev_state = self.deduplicate_state_event(event, context)
             if prev_state is not None:
                 defer.returnValue(prev_state)
@@ -256,10 +229,9 @@ class MessageHandler(BaseHandler):
         )
 
         if event.type == EventTypes.Message:
-            presence = self.hs.get_presence_handler()
+            presence = self.hs.get_handlers().presence_handler
             yield presence.bump_presence_active_time(user)
-    @defer.inlineCallbacks
     def deduplicate_state_event(self, event, context):
         """
         Checks whether event is in the latest resolved state in context.
@@ -267,17 +239,13 @@ class MessageHandler(BaseHandler):
         If so, returns the version of the event in context.
         Otherwise, returns None.
         """
-        prev_event_id = context.prev_state_ids.get((event.type, event.state_key))
-        prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
-        if not prev_event:
-            return
-
+        prev_event = context.current_state.get((event.type, event.state_key))
         if prev_event and event.user_id == prev_event.user_id:
             prev_content = encode_canonical_json(prev_event.content)
             next_content = encode_canonical_json(event.content)
             if prev_content == next_content:
-                defer.returnValue(prev_event)
-        return
+                return prev_event
+        return None
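On the removed (newer) side, deduplicate_state_event has to be a coroutine: the context only carries event IDs, so the previous state event must be fetched from the store before its content can be compared. A sketch of that ID-based variant, using canonical JSON for the comparison as above:

    from canonicaljson import encode_canonical_json
    from twisted.internet import defer

    @defer.inlineCallbacks
    def deduplicate_state_event(store, event, context):
        prev_event_id = context.prev_state_ids.get((event.type, event.state_key))
        if prev_event_id is None:
            return
        prev_event = yield store.get_event(prev_event_id, allow_none=True)
        if prev_event and event.user_id == prev_event.user_id:
            if (encode_canonical_json(prev_event.content) ==
                    encode_canonical_json(event.content)):
                # Duplicate of the current state: hand back the existing event.
                defer.returnValue(prev_event)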
     @defer.inlineCallbacks
     def create_and_send_nonmember_event(
@@ -388,207 +356,371 @@ class MessageHandler(BaseHandler):
             [serialize_event(c, now) for c in room_state.values()]
         )
 
-    @measure_func("_create_new_client_event")
-    @defer.inlineCallbacks
-    def _create_new_client_event(self, builder, prev_event_ids=None):
-        if prev_event_ids:
-            prev_events = yield self.store.add_event_hashes(prev_event_ids)
-            prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids)
-            depth = prev_max_depth + 1
-        else:
-            latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room(
-                builder.room_id,
-            )
-
-            if latest_ret:
-                depth = max([d for _, _, d in latest_ret]) + 1
-            else:
-                depth = 1
-
-            prev_events = [
-                (event_id, prev_hashes)
-                for event_id, prev_hashes, _ in latest_ret
-            ]
-
-        builder.prev_events = prev_events
-        builder.depth = depth
-
-        state_handler = self.state_handler
-        context = yield state_handler.compute_event_context(builder)
-
-        if builder.is_state():
-            builder.prev_state = yield self.store.add_event_hashes(
-                context.prev_state_events
-            )
-
-        yield self.auth.add_auth_events(builder, context)
-
-        signing_key = self.hs.config.signing_key[0]
-        add_hashes_and_signatures(
-            builder, self.server_name, signing_key
-        )
-
-        event = builder.build()
-
-        logger.debug(
-            "Created event %s with state: %s",
-            event.event_id, context.prev_state_ids,
-        )
-
-        defer.returnValue(
-            (event, context,)
-        )
-
-    @measure_func("handle_new_client_event")
-    @defer.inlineCallbacks
-    def handle_new_client_event(
-        self,
-        requester,
-        event,
-        context,
-        ratelimit=True,
-        extra_users=[]
-    ):
-        # We now need to go and hit out to wherever we need to hit out to.
-
-        if ratelimit:
-            self.ratelimit(requester)
-
-        try:
-            yield self.auth.check_from_context(event, context)
-        except AuthError as err:
-            logger.warn("Denying new event %r because %s", event, err)
-            raise err
-
-        yield self.maybe_kick_guest_users(event, context)
-
-        if event.type == EventTypes.CanonicalAlias:
-            # Check the alias is actually valid (at this time at least)
-            room_alias_str = event.content.get("alias", None)
-            if room_alias_str:
-                room_alias = RoomAlias.from_string(room_alias_str)
-                directory_handler = self.hs.get_handlers().directory_handler
-                mapping = yield directory_handler.get_association(room_alias)
-
-                if mapping["room_id"] != event.room_id:
-                    raise SynapseError(
-                        400,
-                        "Room alias %s does not point to the room" % (
-                            room_alias_str,
-                        )
-                    )
-
-        federation_handler = self.hs.get_handlers().federation_handler
-
-        if event.type == EventTypes.Member:
-            if event.content["membership"] == Membership.INVITE:
-                def is_inviter_member_event(e):
-                    return (
-                        e.type == EventTypes.Member and
-                        e.sender == event.sender
-                    )
-
-                state_to_include_ids = [
-                    e_id
-                    for k, e_id in context.current_state_ids.items()
-                    if k[0] in self.hs.config.room_invite_state_types
-                    or k[0] == EventTypes.Member and k[1] == event.sender
-                ]
-
-                state_to_include = yield self.store.get_events(state_to_include_ids)
-
-                event.unsigned["invite_room_state"] = [
-                    {
-                        "type": e.type,
-                        "state_key": e.state_key,
-                        "content": e.content,
-                        "sender": e.sender,
-                    }
-                    for e in state_to_include.values()
-                ]
-
-                invitee = UserID.from_string(event.state_key)
-                if not self.hs.is_mine(invitee):
-                    # TODO: Can we add signature from remote server in a nicer
-                    # way? If we have been invited by a remote server, we need
-                    # to get them to sign the event.
-                    returned_invite = yield federation_handler.send_invite(
-                        invitee.domain,
-                        event,
-                    )
-
-                    event.unsigned.pop("room_state", None)
-
-                    # TODO: Make sure the signatures actually are correct.
-                    event.signatures.update(
-                        returned_invite.signatures
-                    )
-
-        if event.type == EventTypes.Redaction:
-            auth_events_ids = yield self.auth.compute_auth_events(
-                event, context.prev_state_ids, for_verification=True,
-            )
-            auth_events = yield self.store.get_events(auth_events_ids)
-            auth_events = {
-                (e.type, e.state_key): e for e in auth_events.values()
-            }
-            if self.auth.check_redaction(event, auth_events=auth_events):
-                original_event = yield self.store.get_event(
-                    event.redacts,
-                    check_redacted=False,
-                    get_prev_content=False,
-                    allow_rejected=False,
-                    allow_none=False
-                )
-                if event.user_id != original_event.user_id:
-                    raise AuthError(
-                        403,
-                        "You don't have permission to redact events"
-                    )
-
-        if event.type == EventTypes.Create and context.prev_state_ids:
-            raise AuthError(
-                403,
-                "Changing the room create event is forbidden",
-            )
-
-        action_generator = ActionGenerator(self.hs)
-        yield action_generator.handle_push_actions_for_event(
-            event, context
-        )
-
-        (event_stream_id, max_stream_id) = yield self.store.persist_event(
-            event, context=context
-        )
-
-        # this intentionally does not yield: we don't care about the result
-        # and don't need to wait for it.
-        preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
-            event_stream_id, max_stream_id
-        )
-
-        users_in_room = yield self.store.get_joined_users_from_context(event, context)
-
-        destinations = [
-            get_domain_from_id(user_id) for user_id in users_in_room
-            if not self.hs.is_mine_id(user_id)
-        ]
-
-        @defer.inlineCallbacks
-        def _notify():
-            yield run_on_reactor()
-            yield self.notifier.on_new_room_event(
-                event, event_stream_id, max_stream_id,
-                extra_users=extra_users
-            )
-
-        preserve_fn(_notify)()
-
-        # If invite, remove room_state from unsigned before sending.
-        event.unsigned.pop("invite_room_state", None)
-
-        preserve_fn(federation_handler.handle_new_event)(
-            event, destinations=destinations,
-        )
+    def snapshot_all_rooms(self, user_id=None, pagin_config=None,
+                           as_client_event=True, include_archived=False):
+        """Retrieve a snapshot of all rooms the user is invited or has joined.
+
+        This snapshot may include messages for all rooms where the user is
+        joined, depending on the pagination config.
+
+        Args:
+            user_id (str): The ID of the user making the request.
+            pagin_config (synapse.api.streams.PaginationConfig): The pagination
+                config used to determine how many messages *PER ROOM* to return.
+            as_client_event (bool): True to get events in client-server format.
+            include_archived (bool): True to get rooms that the user has left
+        Returns:
+            A list of dicts with "room_id" and "membership" keys for all rooms
+            the user is currently invited or joined in on. Rooms where the user
+            is joined on, may return a "messages" key with messages, depending
+            on the specified PaginationConfig.
+        """
+        key = (
+            user_id,
+            pagin_config.from_token,
+            pagin_config.to_token,
+            pagin_config.direction,
+            pagin_config.limit,
+            as_client_event,
+            include_archived,
+        )
+        now_ms = self.clock.time_msec()
+        result = self.snapshot_cache.get(now_ms, key)
+        if result is not None:
+            return result
+
+        return self.snapshot_cache.set(now_ms, key, self._snapshot_all_rooms(
+            user_id, pagin_config, as_client_event, include_archived
+        ))
+
+    @defer.inlineCallbacks
+    def _snapshot_all_rooms(self, user_id=None, pagin_config=None,
+                            as_client_event=True, include_archived=False):
+
+        memberships = [Membership.INVITE, Membership.JOIN]
+        if include_archived:
+            memberships.append(Membership.LEAVE)
+
+        room_list = yield self.store.get_rooms_for_user_where_membership_is(
+            user_id=user_id, membership_list=memberships
+        )
+
+        user = UserID.from_string(user_id)
+
+        rooms_ret = []
+
+        now_token = yield self.hs.get_event_sources().get_current_token()
+
+        presence_stream = self.hs.get_event_sources().sources["presence"]
+        pagination_config = PaginationConfig(from_token=now_token)
+        presence, _ = yield presence_stream.get_pagination_rows(
+            user, pagination_config.get_source_config("presence"), None
+        )
+
+        receipt_stream = self.hs.get_event_sources().sources["receipt"]
+        receipt, _ = yield receipt_stream.get_pagination_rows(
+            user, pagination_config.get_source_config("receipt"), None
+        )
+
+        tags_by_room = yield self.store.get_tags_for_user(user_id)
+
+        account_data, account_data_by_room = (
+            yield self.store.get_account_data_for_user(user_id)
+        )
+
+        public_room_ids = yield self.store.get_public_room_ids()
+
+        limit = pagin_config.limit
+        if limit is None:
+            limit = 10
+
+        @defer.inlineCallbacks
+        def handle_room(event):
+            d = {
+                "room_id": event.room_id,
+                "membership": event.membership,
+                "visibility": (
+                    "public" if event.room_id in public_room_ids
+                    else "private"
+                ),
+            }
+
+            if event.membership == Membership.INVITE:
+                time_now = self.clock.time_msec()
+                d["inviter"] = event.sender
+
+                invite_event = yield self.store.get_event(event.event_id)
+                d["invite"] = serialize_event(invite_event, time_now, as_client_event)
+
+            rooms_ret.append(d)
+
+            if event.membership not in (Membership.JOIN, Membership.LEAVE):
+                return
+
+            try:
+                if event.membership == Membership.JOIN:
+                    room_end_token = now_token.room_key
+                    deferred_room_state = self.state_handler.get_current_state(
+                        event.room_id
+                    )
+                elif event.membership == Membership.LEAVE:
+                    room_end_token = "s%d" % (event.stream_ordering,)
+                    deferred_room_state = self.store.get_state_for_events(
+                        [event.event_id], None
+                    )
+                    deferred_room_state.addCallback(
+                        lambda states: states[event.event_id]
+                    )
+
+                (messages, token), current_state = yield defer.gatherResults(
+                    [
+                        self.store.get_recent_events_for_room(
+                            event.room_id,
+                            limit=limit,
+                            end_token=room_end_token,
+                        ),
+                        deferred_room_state,
+                    ]
+                ).addErrback(unwrapFirstError)
+
+                messages = yield self._filter_events_for_client(
+                    user_id, messages
+                )
+
+                start_token = now_token.copy_and_replace("room_key", token[0])
+                end_token = now_token.copy_and_replace("room_key", token[1])
+                time_now = self.clock.time_msec()
+
+                d["messages"] = {
+                    "chunk": [
+                        serialize_event(m, time_now, as_client_event)
+                        for m in messages
+                    ],
+                    "start": start_token.to_string(),
+                    "end": end_token.to_string(),
+                }
+
+                d["state"] = [
+                    serialize_event(c, time_now, as_client_event)
+                    for c in current_state.values()
+                ]
+
+                account_data_events = []
+                tags = tags_by_room.get(event.room_id)
+                if tags:
+                    account_data_events.append({
+                        "type": "m.tag",
+                        "content": {"tags": tags},
+                    })
+
+                account_data = account_data_by_room.get(event.room_id, {})
+                for account_data_type, content in account_data.items():
+                    account_data_events.append({
+                        "type": account_data_type,
+                        "content": content,
+                    })
+
+                d["account_data"] = account_data_events
+            except:
+                logger.exception("Failed to get snapshot")
+
+        yield concurrently_execute(handle_room, room_list, 10)
+
+        account_data_events = []
+        for account_data_type, content in account_data.items():
+            account_data_events.append({
+                "type": account_data_type,
+                "content": content,
+            })
+
+        ret = {
+            "rooms": rooms_ret,
+            "presence": presence,
+            "account_data": account_data_events,
+            "receipts": receipt,
+            "end": now_token.to_string(),
+        }
+
+        defer.returnValue(ret)
+
+    @defer.inlineCallbacks
+    def room_initial_sync(self, requester, room_id, pagin_config=None):
+        """Capture a snapshot of a room. If user is currently a member of
+        the room this will be what is currently in the room. If the user left
+        the room this will be what was in the room when they left.
+
+        Args:
+            requester(Requester): The user to get a snapshot for.
+            room_id(str): The room to get a snapshot of.
+            pagin_config(synapse.streams.config.PaginationConfig):
+                The pagination config used to determine how many messages to
+                return.
+        Raises:
+            AuthError if the user wasn't in the room.
+        Returns:
+            A JSON serialisable dict with the snapshot of the room.
+        """
+        user_id = requester.user.to_string()
+
+        membership, member_event_id = yield self._check_in_room_or_world_readable(
+            room_id, user_id,
+        )
+        is_peeking = member_event_id is None
+
+        if membership == Membership.JOIN:
+            result = yield self._room_initial_sync_joined(
+                user_id, room_id, pagin_config, membership, is_peeking
+            )
+        elif membership == Membership.LEAVE:
+            result = yield self._room_initial_sync_parted(
+                user_id, room_id, pagin_config, membership, member_event_id, is_peeking
+            )
+
+        account_data_events = []
+        tags = yield self.store.get_tags_for_room(user_id, room_id)
+        if tags:
+            account_data_events.append({
+                "type": "m.tag",
+                "content": {"tags": tags},
+            })
+
+        account_data = yield self.store.get_account_data_for_room(user_id, room_id)
+        for account_data_type, content in account_data.items():
+            account_data_events.append({
+                "type": account_data_type,
+                "content": content,
+            })
+
+        result["account_data"] = account_data_events
+
+        defer.returnValue(result)
+
+    @defer.inlineCallbacks
+    def _room_initial_sync_parted(self, user_id, room_id, pagin_config,
+                                  membership, member_event_id, is_peeking):
+        room_state = yield self.store.get_state_for_events(
+            [member_event_id], None
+        )
+
+        room_state = room_state[member_event_id]
+
+        limit = pagin_config.limit if pagin_config else None
+        if limit is None:
+            limit = 10
+
+        stream_token = yield self.store.get_stream_token_for_event(
+            member_event_id
+        )
+
+        messages, token = yield self.store.get_recent_events_for_room(
+            room_id,
+            limit=limit,
+            end_token=stream_token
+        )
+
+        messages = yield self._filter_events_for_client(
+            user_id, messages, is_peeking=is_peeking
+        )
+
+        start_token = StreamToken.START.copy_and_replace("room_key", token[0])
+        end_token = StreamToken.START.copy_and_replace("room_key", token[1])
+
+        time_now = self.clock.time_msec()
+
+        defer.returnValue({
+            "membership": membership,
+            "room_id": room_id,
+            "messages": {
+                "chunk": [serialize_event(m, time_now) for m in messages],
+                "start": start_token.to_string(),
+                "end": end_token.to_string(),
+            },
+            "state": [serialize_event(s, time_now) for s in room_state.values()],
+            "presence": [],
+            "receipts": [],
+        })
+
+    @defer.inlineCallbacks
+    def _room_initial_sync_joined(self, user_id, room_id, pagin_config,
+                                  membership, is_peeking):
+        current_state = yield self.state.get_current_state(
+            room_id=room_id,
+        )
+
+        # TODO: These concurrently
+        time_now = self.clock.time_msec()
+        state = [
+            serialize_event(x, time_now)
+            for x in current_state.values()
+        ]
+
+        now_token = yield self.hs.get_event_sources().get_current_token()
+
+        limit = pagin_config.limit if pagin_config else None
+        if limit is None:
+            limit = 10
+
+        room_members = [
+            m for m in current_state.values()
+            if m.type == EventTypes.Member
+            and m.content["membership"] == Membership.JOIN
+        ]
+
+        presence_handler = self.hs.get_handlers().presence_handler
+
+        @defer.inlineCallbacks
+        def get_presence():
+            states = yield presence_handler.get_states(
+                [m.user_id for m in room_members],
+                as_event=True,
+            )
+            defer.returnValue(states)
+
+        @defer.inlineCallbacks
+        def get_receipts():
+            receipts_handler = self.hs.get_handlers().receipts_handler
+            receipts = yield receipts_handler.get_receipts_for_room(
+                room_id,
+                now_token.receipt_key
+            )
+            defer.returnValue(receipts)
+
+        presence, receipts, (messages, token) = yield defer.gatherResults(
+            [
+                get_presence(),
+                get_receipts(),
+                self.store.get_recent_events_for_room(
+                    room_id,
+                    limit=limit,
+                    end_token=now_token.room_key,
+                )
+            ],
+            consumeErrors=True,
+        ).addErrback(unwrapFirstError)
+
+        messages = yield self._filter_events_for_client(
+            user_id, messages, is_peeking=is_peeking,
+        )
+
+        start_token = now_token.copy_and_replace("room_key", token[0])
+        end_token = now_token.copy_and_replace("room_key", token[1])
+
+        time_now = self.clock.time_msec()
+
+        ret = {
+            "room_id": room_id,
+            "messages": {
+                "chunk": [serialize_event(m, time_now) for m in messages],
+                "start": start_token.to_string(),
+                "end": end_token.to_string(),
+            },
+            "state": state,
+            "presence": presence,
+            "receipts": receipts,
+        }
+        if not is_peeking:
+            ret["membership"] = membership
+
+        defer.returnValue(ret)

synapse/handlers/presence.py

@@ -33,9 +33,11 @@ from synapse.util.logcontext import preserve_fn
 from synapse.util.logutils import log_function
 from synapse.util.metrics import Measure
 from synapse.util.wheel_timer import WheelTimer
-from synapse.types import UserID, get_domain_from_id
+from synapse.types import UserID
 
 import synapse.metrics
 
+from ._base import BaseHandler
+
 import logging
@@ -50,13 +52,6 @@ timers_fired_counter = metrics.register_counter("timers_fired")
 federation_presence_counter = metrics.register_counter("federation_presence")
 bump_active_time_counter = metrics.register_counter("bump_active_time")
 
-get_updates_counter = metrics.register_counter("get_updates", labels=["type"])
-
-notify_reason_counter = metrics.register_counter("notify_reason", labels=["reason"])
-state_transition_counter = metrics.register_counter(
-    "state_transition", labels=["from", "to"]
-)
 # If a user was last active in the last LAST_ACTIVE_GRANULARITY, consider them
 # "currently_active"
@@ -75,26 +70,20 @@ FEDERATION_TIMEOUT = 30 * 60 * 1000
 # How often to resend presence to remote servers
 FEDERATION_PING_INTERVAL = 25 * 60 * 1000
 
-# How long we will wait before assuming that the syncs from an external process
-# are dead.
-EXTERNAL_PROCESS_EXPIRY = 5 * 60 * 1000
-
 assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
 
 
-class PresenceHandler(object):
+class PresenceHandler(BaseHandler):
     def __init__(self, hs):
-        self.is_mine = hs.is_mine
-        self.is_mine_id = hs.is_mine_id
+        super(PresenceHandler, self).__init__(hs)
+        self.hs = hs
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
         self.wheel_timer = WheelTimer()
         self.notifier = hs.get_notifier()
         self.federation = hs.get_replication_layer()
-        self.state = hs.get_state_handler()
 
         self.federation.register_edu_handler(
             "m.presence", self.incoming_presence
         )
@@ -149,7 +138,7 @@ class PresenceHandler(object):
                 obj=state.user_id,
                 then=state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
             )
-            if self.is_mine_id(state.user_id):
+            if self.hs.is_mine_id(state.user_id):
                 self.wheel_timer.insert(
                     now=now,
                     obj=state.user_id,
@@ -171,38 +160,20 @@ class PresenceHandler(object):
         self.serial_to_user = {}
         self._next_serial = 1

-        # Keeps track of the number of *ongoing* syncs on this process. While
-        # this is non zero a user will never go offline.
+        # Keeps track of the number of *ongoing* syncs. While this is non zero
+        # a user will never go offline.
         self.user_to_num_current_syncs = {}

-        # Keeps track of the number of *ongoing* syncs on other processes.
-        # While any sync is ongoing on another process the user will never
-        # go offline.
-        # Each process has a unique identifier and an update frequency. If
-        # no update is received from that process within the update period then
-        # we assume that all the sync requests on that process have stopped.
-        # Stored as a dict from process_id to set of user_id, and a dict of
-        # process_id to millisecond timestamp last updated.
-        self.external_process_to_current_syncs = {}
-        self.external_process_last_updated_ms = {}
-
         # Start a LoopingCall in 30s that fires every 5s.
         # The initial delay is to allow disconnected clients a chance to
         # reconnect before we treat them as offline.
         self.clock.call_later(
-            30,
+            0 * 1000,
             self.clock.looping_call,
             self._handle_timeouts,
             5000,
         )

-        self.clock.call_later(
-            60,
-            self.clock.looping_call,
-            self._persist_unpersisted_changes,
-            60 * 1000,
-        )
-
         metrics.register_callback("wheel_timer_size", lambda: len(self.wheel_timer))

     @defer.inlineCallbacks
@@ -217,7 +188,7 @@ class PresenceHandler(object):
         is some spurious presence changes that will self-correct.
         """
         logger.info(
-            "Performing _on_shutdown. Persisting %d unpersisted changes",
+            "Performing _on_shutdown. Persiting %d unpersisted changes",
             len(self.user_to_current_state)
         )
@@ -228,27 +199,6 @@ class PresenceHandler(object):
         ])
         logger.info("Finished _on_shutdown")

-    @defer.inlineCallbacks
-    def _persist_unpersisted_changes(self):
-        """We periodically persist the unpersisted changes, as otherwise they
-        may stack up and slow down shutdown times.
-        """
-        logger.info(
-            "Performing _persist_unpersisted_changes. Persisting %d unpersisted changes",
-            len(self.unpersisted_users_changes)
-        )
-
-        unpersisted = self.unpersisted_users_changes
-        self.unpersisted_users_changes = set()
-
-        if unpersisted:
-            yield self.store.update_presence([
-                self.user_to_current_state[user_id]
-                for user_id in unpersisted
-            ])
-
-        logger.info("Finished _persist_unpersisted_changes")
-
     @defer.inlineCallbacks
     def _update_states(self, new_states):
         """Updates presence of users. Sets the appropriate timeouts. Pokes
@@ -265,12 +215,6 @@ class PresenceHandler(object):
             to_notify = {}  # Changes we want to notify everyone about
             to_federation_ping = {}  # These need sending keep-alives

-            # Only bother handling the last presence change for each user
-            new_states_dict = {}
-            for new_state in new_states:
-                new_states_dict[new_state.user_id] = new_state
-            new_state = new_states_dict.values()
-
             for new_state in new_states:
                 user_id = new_state.user_id
@@ -284,7 +228,7 @@ class PresenceHandler(object):
                 new_state, should_notify, should_ping = handle_update(
                     prev_state, new_state,
-                    is_mine=self.is_mine_id(user_id),
+                    is_mine=self.hs.is_mine_id(user_id),
                     wheel_timer=self.wheel_timer,
                     now=now
                 )
@@ -324,48 +268,31 @@ class PresenceHandler(object):
         """Checks the presence of users that have timed out and updates as
         appropriate.
         """
-        logger.info("Handling presence timeouts")
         now = self.clock.time_msec()

-        try:
-            with Measure(self.clock, "presence_handle_timeouts"):
-                # Fetch the list of users that *may* have timed out. Things may have
-                # changed since the timeout was set, so we won't necessarily have to
-                # take any action.
-                users_to_check = set(self.wheel_timer.fetch(now))
-
-                # Check whether the lists of syncing processes from an external
-                # process have expired.
-                expired_process_ids = [
-                    process_id for process_id, last_update
-                    in self.external_process_last_updated_ms.items()
-                    if now - last_update > EXTERNAL_PROCESS_EXPIRY
-                ]
-                for process_id in expired_process_ids:
-                    users_to_check.update(
-                        self.external_process_last_updated_ms.pop(process_id, ())
-                    )
-                    self.external_process_last_update.pop(process_id)
-
-                states = [
-                    self.user_to_current_state.get(
-                        user_id, UserPresenceState.default(user_id)
-                    )
-                    for user_id in users_to_check
-                ]
-
-                timers_fired_counter.inc_by(len(states))
-
-                changes = handle_timeouts(
-                    states,
-                    is_mine_fn=self.is_mine_id,
-                    syncing_user_ids=self.get_currently_syncing_users(),
-                    now=now,
-                )
-
-            preserve_fn(self._update_states)(changes)
-        except:
-            logger.exception("Exception in _handle_timeouts loop")
+        with Measure(self.clock, "presence_handle_timeouts"):
+            # Fetch the list of users that *may* have timed out. Things may have
+            # changed since the timeout was set, so we won't necessarily have to
+            # take any action.
+            users_to_check = self.wheel_timer.fetch(now)
+
+            states = [
+                self.user_to_current_state.get(
+                    user_id, UserPresenceState.default(user_id)
+                )
+                for user_id in set(users_to_check)
+            ]
+
+            timers_fired_counter.inc_by(len(states))
+
+            changes = handle_timeouts(
+                states,
+                is_mine_fn=self.hs.is_mine_id,
+                user_to_num_current_syncs=self.user_to_num_current_syncs,
+                now=now,
+            )
+
+        preserve_fn(self._update_states)(changes)
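
Both sides of this hunk lean on the same `WheelTimer` contract: `insert(now, obj, then)` buckets an object by deadline, and `fetch(now)` returns everything whose bucket has passed, possibly including stale entries, which is why the code re-reads `user_to_current_state` before acting. A toy illustration of that contract (an assumption-laden simplification, not synapse's implementation):

class ToyWheelTimer(object):
    """Buckets objects by deadline so fetch(now) is O(buckets), not O(entries)."""
    def __init__(self, bucket_size=5000):
        self.bucket_size = bucket_size
        self.buckets = {}  # bucket start (ms) -> list of objects

    def insert(self, now, obj, then):
        key = then - (then % self.bucket_size)
        self.buckets.setdefault(key, []).append(obj)

    def fetch(self, now):
        due = [k for k in self.buckets if k <= now]
        out = []
        for k in due:
            out.extend(self.buckets.pop(k))
        # May return entries whose timeout was since superseded: callers
        # re-check current state, exactly as _handle_timeouts does above.
        return out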
     @defer.inlineCallbacks
     def bump_presence_active_time(self, user):
@@ -438,74 +365,6 @@ class PresenceHandler(object):
         defer.returnValue(_user_syncing())

-    def get_currently_syncing_users(self):
-        """Get the set of user ids that are currently syncing on this HS.
-        Returns:
-            set(str): A set of user_id strings.
-        """
-        syncing_user_ids = {
-            user_id for user_id, count in self.user_to_num_current_syncs.items()
-            if count
-        }
-        for user_ids in self.external_process_to_current_syncs.values():
-            syncing_user_ids.update(user_ids)
-        return syncing_user_ids
-
-    @defer.inlineCallbacks
-    def update_external_syncs(self, process_id, syncing_user_ids):
-        """Update the syncing users for an external process
-
-        Args:
-            process_id(str): An identifier for the process the users are
-                syncing against. This allows synapse to process updates
-                as user start and stop syncing against a given process.
-            syncing_user_ids(set(str)): The set of user_ids that are
-                currently syncing on that server.
-        """
-        # Grab the previous list of user_ids that were syncing on that process
-        prev_syncing_user_ids = (
-            self.external_process_to_current_syncs.get(process_id, set())
-        )
-        # Grab the current presence state for both the users that are syncing
-        # now and the users that were syncing before this update.
-        prev_states = yield self.current_state_for_users(
-            syncing_user_ids | prev_syncing_user_ids
-        )
-        updates = []
-        time_now_ms = self.clock.time_msec()
-
-        # For each new user that is syncing check if we need to mark them as
-        # being online.
-        for new_user_id in syncing_user_ids - prev_syncing_user_ids:
-            prev_state = prev_states[new_user_id]
-            if prev_state.state == PresenceState.OFFLINE:
-                updates.append(prev_state.copy_and_replace(
-                    state=PresenceState.ONLINE,
-                    last_active_ts=time_now_ms,
-                    last_user_sync_ts=time_now_ms,
-                ))
-            else:
-                updates.append(prev_state.copy_and_replace(
-                    last_user_sync_ts=time_now_ms,
-                ))
-
-        # For each user that is still syncing or stopped syncing update the
-        # last sync time so that we will correctly apply the grace period when
-        # they stop syncing.
-        for old_user_id in prev_syncing_user_ids:
-            prev_state = prev_states[old_user_id]
-            updates.append(prev_state.copy_and_replace(
-                last_user_sync_ts=time_now_ms,
-            ))
-
-        yield self._update_states(updates)
-
-        # Update the last updated time for the process. We expire the entries
-        # if we don't receive an update in the given timeframe.
-        self.external_process_last_updated_ms[process_id] = self.clock.time_msec()
-        self.external_process_to_current_syncs[process_id] = syncing_user_ids
-
     @defer.inlineCallbacks
     def current_state_for_user(self, user_id):
         """Get the current presence state for a user.
@@ -544,7 +403,7 @@ class PresenceHandler(object):
         defer.returnValue(states)

     @defer.inlineCallbacks
-    def _get_interested_parties(self, states, calculate_remote_hosts=True):
+    def _get_interested_parties(self, states):
         """Given a list of states return which entities (rooms, users, servers)
         are interested in the given states.
@@ -567,24 +426,21 @@ class PresenceHandler(object):
             users_to_states.setdefault(state.user_id, []).append(state)

         hosts_to_states = {}
-        if calculate_remote_hosts:
-            for room_id, states in room_ids_to_states.items():
-                local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
-                if not local_states:
-                    continue
-
-                users = yield self.state.get_current_user_in_room(room_id)
-                hosts = set(get_domain_from_id(u) for u in users)
-
-                for host in hosts:
-                    hosts_to_states.setdefault(host, []).extend(local_states)
+        for room_id, states in room_ids_to_states.items():
+            local_states = filter(lambda s: self.hs.is_mine_id(s.user_id), states)
+            if not local_states:
+                continue
+
+            hosts = yield self.store.get_joined_hosts_for_room(room_id)
+            for host in hosts:
+                hosts_to_states.setdefault(host, []).extend(local_states)

         for user_id, states in users_to_states.items():
-            local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
+            local_states = filter(lambda s: self.hs.is_mine_id(s.user_id), states)
             if not local_states:
                 continue

-            host = get_domain_from_id(user_id)
+            host = UserID.from_string(user_id).domain
             hosts_to_states.setdefault(host, []).extend(local_states)

         # TODO: de-dup hosts_to_states, as a single host might have multiple
@@ -609,24 +465,24 @@ class PresenceHandler(object):
         self._push_to_remotes(hosts_to_states)

-    @defer.inlineCallbacks
-    def notify_for_states(self, state, stream_id):
-        parties = yield self._get_interested_parties([state])
-        room_ids_to_states, users_to_states, hosts_to_states = parties
-
-        self.notifier.on_new_event(
-            "presence_key", stream_id, rooms=room_ids_to_states.keys(),
-            users=[UserID.from_string(u) for u in users_to_states.keys()]
-        )
-
     def _push_to_remotes(self, hosts_to_states):
         """Sends state updates to remote servers.

         Args:
             hosts_to_states (dict): Mapping `server_name` -> `[UserPresenceState]`
         """
+        now = self.clock.time_msec()
         for host, states in hosts_to_states.items():
-            self.federation.send_presence(host, states)
+            self.federation.send_edu(
+                destination=host,
+                edu_type="m.presence",
+                content={
+                    "push": [
+                        _format_user_presence_state(state, now)
+                        for state in states
+                    ]
+                }
+            )

     @defer.inlineCallbacks
     def incoming_presence(self, origin, content):
@@ -647,13 +503,6 @@ class PresenceHandler(object):
                 )
                 continue

-            if get_domain_from_id(user_id) != origin:
-                logger.info(
-                    "Got presence update from %r with bad 'user_id': %r",
-                    origin, user_id,
-                )
-                continue
-
             presence_state = push.get("presence", None)
             if not presence_state:
                 logger.info(
@@ -713,17 +562,17 @@ class PresenceHandler(object):
             defer.returnValue([
                 {
                     "type": "m.presence",
-                    "content": format_user_presence_state(state, now),
+                    "content": _format_user_presence_state(state, now),
                 }
                 for state in updates
             ])
         else:
             defer.returnValue([
-                format_user_presence_state(state, now) for state in updates
+                _format_user_presence_state(state, now) for state in updates
             ])

     @defer.inlineCallbacks
-    def set_state(self, target_user, state, ignore_status_msg=False):
+    def set_state(self, target_user, state):
         """Set the presence state of the user.
         """
         status_msg = state.get("status_msg", None)
@@ -740,13 +589,10 @@ class PresenceHandler(object):
         prev_state = yield self.current_state_for_user(user_id)

         new_fields = {
-            "state": presence
+            "state": presence,
+            "status_msg": status_msg if presence != PresenceState.OFFLINE else None
         }

-        if not ignore_status_msg:
-            msg = status_msg if presence != PresenceState.OFFLINE else None
-            new_fields["status_msg"] = msg
-
         if presence == PresenceState.ONLINE:
             new_fields["last_active_ts"] = self.clock.time_msec()
@@ -765,14 +611,14 @@ class PresenceHandler(object):
         # don't need to send to local clients here, as that is done as part
         # of the event stream/sync.
         # TODO: Only send to servers not already in the room.
-        user_ids = yield self.state.get_current_user_in_room(room_id)
-        if self.is_mine(user):
+        if self.hs.is_mine(user):
             state = yield self.current_state_for_user(user.to_string())
-            hosts = set(get_domain_from_id(u) for u in user_ids)
+            hosts = yield self.store.get_joined_hosts_for_room(room_id)
             self._push_to_remotes({host: (state,) for host in hosts})
         else:
-            user_ids = filter(self.is_mine_id, user_ids)
+            user_ids = yield self.store.get_users_in_room(room_id)
+            user_ids = filter(self.hs.is_mine_id, user_ids)
             states = yield self.current_state_for_users(user_ids)
@@ -782,7 +628,7 @@ class PresenceHandler(object):
     def get_presence_list(self, observer_user, accepted=None):
         """Returns the presence for all users in their presence list.
         """
-        if not self.is_mine(observer_user):
+        if not self.hs.is_mine(observer_user):
             raise SynapseError(400, "User is not hosted on this Home Server")

         presence_list = yield self.store.get_presence_list(
@@ -813,7 +659,7 @@ class PresenceHandler(object):
             observer_user.localpart, observed_user.to_string()
         )

-        if self.is_mine(observed_user):
+        if self.hs.is_mine(observed_user):
             yield self.invite_presence(observed_user, observer_user)
         else:
             yield self.federation.send_edu(
@@ -829,11 +675,11 @@ class PresenceHandler(object):
     def invite_presence(self, observed_user, observer_user):
         """Handles new presence invites.
         """
-        if not self.is_mine(observed_user):
+        if not self.hs.is_mine(observed_user):
             raise SynapseError(400, "User is not hosted on this Home Server")

         # TODO: Don't auto accept
-        if self.is_mine(observer_user):
+        if self.hs.is_mine(observer_user):
             yield self.accept_presence(observed_user, observer_user)
         else:
             self.federation.send_edu(
@@ -896,7 +742,7 @@ class PresenceHandler(object):
         Returns:
             A Deferred.
         """
-        if not self.is_mine(observer_user):
+        if not self.hs.is_mine(observer_user):
             raise SynapseError(400, "User is not hosted on this Home Server")

         yield self.store.del_presence_list(
@@ -947,38 +793,28 @@ class PresenceHandler(object):
def should_notify(old_state, new_state): def should_notify(old_state, new_state):
"""Decides if a presence state change should be sent to interested parties. """Decides if a presence state change should be sent to interested parties.
""" """
if old_state == new_state:
return False
if old_state.status_msg != new_state.status_msg: if old_state.status_msg != new_state.status_msg:
notify_reason_counter.inc("status_msg_change")
return True
if old_state.state != new_state.state:
notify_reason_counter.inc("state_change")
state_transition_counter.inc(old_state.state, new_state.state)
return True return True
if old_state.state == PresenceState.ONLINE: if old_state.state == PresenceState.ONLINE:
if new_state.currently_active != old_state.currently_active: if new_state.state != PresenceState.ONLINE:
notify_reason_counter.inc("current_active_change") # Always notify for online -> anything
return True return True
if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY: if new_state.currently_active != old_state.currently_active:
# Only notify about last active bumps if we're not currently acive return True
if not new_state.currently_active:
notify_reason_counter.inc("last_active_change_online")
return True
elif new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY: if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY:
# Always notify for a transition where last active gets bumped. # Always notify for a transition where last active gets bumped.
notify_reason_counter.inc("last_active_change_not_online") return True
if old_state.state != new_state.state:
return True return True
return False return False
-def format_user_presence_state(state, now):
+def _format_user_presence_state(state, now):
     """Convert UserPresenceState to a format that can be sent down to clients
     and to other servers.
     """
@@ -998,14 +834,9 @@ def format_user_presence_state(state, now):

 class PresenceEventSource(object):
     def __init__(self, hs):
-        # We can't call get_presence_handler here because there's a cycle:
-        #
-        #   Presence -> Notifier -> PresenceEventSource -> Presence
-        #
-        self.get_presence_handler = hs.get_presence_handler
+        self.hs = hs
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
-        self.state = hs.get_state_handler()

     @defer.inlineCallbacks
     @log_function
@@ -1029,7 +860,7 @@ class PresenceEventSource(object):
             from_key = int(from_key)
         room_ids = room_ids or []

-        presence = self.get_presence_handler()
+        presence = self.hs.get_handlers().presence_handler
         stream_change_cache = self.store.presence_stream_cache

         if not room_ids:
@@ -1046,13 +877,13 @@ class PresenceEventSource(object):
             user_ids_changed = set()
             changed = None
-            if from_key:
-                changed = stream_change_cache.get_all_entities_changed(from_key)
-
-            if changed is not None and len(changed) < 500:
+            if from_key and max_token - from_key < 100:
                 # For small deltas, its quicker to get all changes and then
                 # work out if we share a room or they're in our presence list
-                get_updates_counter.inc("stream")
-                for other_user_id in changed:
-                    if other_user_id in friends:
-                        user_ids_changed.add(other_user_id)
+                changed = stream_change_cache.get_all_entities_changed(from_key)
+
+                # get_all_entities_changed can return None
+                if changed is not None:
+                    for other_user_id in changed:
+                        if other_user_id in friends:
+                            user_ids_changed.add(other_user_id)
@@ -1064,11 +895,9 @@ class PresenceEventSource(object):
             else:
                 # Too many possible updates. Find all users we can see and check
                 # if any of them have changed.
-                get_updates_counter.inc("full")
-
                 user_ids_to_check = set()
                 for room_id in room_ids:
-                    users = yield self.state.get_current_user_in_room(room_id)
+                    users = yield self.store.get_users_in_room(room_id)
                     user_ids_to_check.update(users)

                 user_ids_to_check.update(friends)
@@ -1091,7 +920,7 @@ class PresenceEventSource(object):
         defer.returnValue(([
             {
                 "type": "m.presence",
-                "content": format_user_presence_state(s, now),
+                "content": _format_user_presence_state(s, now),
             }
             for s in updates.values()
             if include_offline or s.state != PresenceState.OFFLINE
@@ -1104,14 +933,15 @@ class PresenceEventSource(object):
         return self.get_new_events(user, from_key=None, include_offline=False)


-def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now):
+def handle_timeouts(user_states, is_mine_fn, user_to_num_current_syncs, now):
     """Checks the presence of users that have timed out and updates as
     appropriate.

     Args:
         user_states(list): List of UserPresenceState's to check.
         is_mine_fn (fn): Function that returns if a user_id is ours
-        syncing_user_ids (set): Set of user_ids with active syncs.
+        user_to_num_current_syncs (dict): Mapping of user_id to number of currently
+            active syncs.
         now (int): Current time in ms.

     Returns:
@@ -1122,20 +952,21 @@ def handle_timeouts(user_states, is_mine_fn, user_to_num_current_syncs, now):
     for state in user_states:
         is_mine = is_mine_fn(state.user_id)

-        new_state = handle_timeout(state, is_mine, syncing_user_ids, now)
+        new_state = handle_timeout(state, is_mine, user_to_num_current_syncs, now)
         if new_state:
            changes[state.user_id] = new_state

     return changes.values()


-def handle_timeout(state, is_mine, syncing_user_ids, now):
+def handle_timeout(state, is_mine, user_to_num_current_syncs, now):
     """Checks the presence of the user to see if any of the timers have elapsed

     Args:
         state (UserPresenceState)
         is_mine (bool): Whether the user is ours
-        syncing_user_ids (set): Set of user_ids with active syncs.
+        user_to_num_current_syncs (dict): Mapping of user_id to number of currently
+            active syncs.
         now (int): Current time in ms.

     Returns:
@@ -1169,7 +1000,7 @@ def handle_timeout(state, is_mine, user_to_num_current_syncs, now):
         # If there are have been no sync for a while (and none ongoing),
         # set presence to offline
-        if user_id not in syncing_user_ids:
+        if not user_to_num_current_syncs.get(user_id, 0):
             if now - state.last_user_sync_ts > SYNC_ONLINE_TIMEOUT:
                 state = state.copy_and_replace(
                     state=PresenceState.OFFLINE,
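
To see where the two versions of `should_notify` above diverge, consider a last-active bump on an already-online, currently-active user. A worked sketch using a minimal hypothetical stand-in for `UserPresenceState`:

import collections

State = collections.namedtuple(
    "State", ["state", "status_msg", "currently_active", "last_active_ts"]
)

LAST_ACTIVE_GRANULARITY = 60 * 1000  # same constant the handler defines

old = State("online", None, True, 1000)
new = old._replace(last_active_ts=1000 + LAST_ACTIVE_GRANULARITY + 1)

# The older (right-hand) code returns True here: any large enough
# last_active bump notifies. The newer (left-hand) code suppresses the
# notification while currently_active is set, one of the traffic trims
# the later version introduced.
assert new.last_active_ts - old.last_active_ts > LAST_ACTIVE_GRANULARITY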


@@ -13,15 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import logging
-
 from twisted.internet import defer

-import synapse.types
 from synapse.api.errors import SynapseError, AuthError, CodeMessageException
-from synapse.types import UserID
+from synapse.types import UserID, Requester
 from ._base import BaseHandler

+import logging
+
 logger = logging.getLogger(__name__)
@@ -36,6 +36,13 @@ class ProfileHandler(BaseHandler):
             "profile", self.on_profile_query
         )

+        distributor = hs.get_distributor()
+        distributor.observe("registered_user", self.registered_user)
+
+    def registered_user(self, user):
+        return self.store.create_profile(user.localpart)
+
     @defer.inlineCallbacks
     def get_displayname(self, target_user):
         if self.hs.is_mine(target_user):
@@ -165,9 +172,7 @@ class ProfileHandler(BaseHandler):
             try:
                 # Assume the user isn't a guest because we don't let guests set
                 # profile or avatar data.
-                # XXX why are we recreating `requester` here for each room?
-                # what was wrong with the `requester` we were passed?
-                requester = synapse.types.create_requester(user)
+                requester = Requester(user, "", False)
                 yield handler.update_membership(
                     requester,
                     user,
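
The `registered_user` wiring above is synapse's Distributor, a small observer registry: register.py declares and fires the signal, and this handler observes it to create an empty profile row for each new user. A minimal sketch of that observe/fire shape (a toy, not the real class, which also handles Deferreds and logging contexts):

class ToyDistributor(object):
    def __init__(self):
        self.observers = {}

    def declare(self, name):
        self.observers.setdefault(name, [])

    def observe(self, name, func):
        self.observers[name].append(func)

    def fire(self, name, *args):
        # Call every observer registered for the signal, in order.
        return [func(*args) for func in self.observers[name]]

distributor = ToyDistributor()
distributor.declare("registered_user")
distributor.observe("registered_user", lambda user: "create_profile(%s)" % user)
print(distributor.fire("registered_user", "alice"))  # ['create_profile(alice)']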


@@ -18,7 +18,6 @@ from ._base import BaseHandler
 from twisted.internet import defer

 from synapse.util.logcontext import PreserveLoggingContext
-from synapse.types import get_domain_from_id

 import logging
@@ -30,15 +29,12 @@ class ReceiptsHandler(BaseHandler):
     def __init__(self, hs):
         super(ReceiptsHandler, self).__init__(hs)

-        self.server_name = hs.config.server_name
-        self.store = hs.get_datastore()
         self.hs = hs
         self.federation = hs.get_replication_layer()
         self.federation.register_edu_handler(
             "m.receipt", self._received_remote_receipt
         )
         self.clock = self.hs.get_clock()
-        self.state = hs.get_state_handler()

     @defer.inlineCallbacks
     def received_client_receipt(self, room_id, receipt_type, user_id,
@@ -84,9 +80,6 @@ class ReceiptsHandler(BaseHandler):
     def _handle_new_receipts(self, receipts):
         """Takes a list of receipts, stores them and informs the notifier.
         """
-        min_batch_id = None
-        max_batch_id = None
-
         for receipt in receipts:
             room_id = receipt["room_id"]
             receipt_type = receipt["receipt_type"]
@@ -104,21 +97,10 @@ class ReceiptsHandler(BaseHandler):
             stream_id, max_persisted_id = res

-            if min_batch_id is None or stream_id < min_batch_id:
-                min_batch_id = stream_id
-            if max_batch_id is None or max_persisted_id > max_batch_id:
-                max_batch_id = max_persisted_id
-
-        affected_room_ids = list(set([r["room_id"] for r in receipts]))
-
-        with PreserveLoggingContext():
-            self.notifier.on_new_event(
-                "receipt_key", max_batch_id, rooms=affected_room_ids
-            )
-            # Note that the min here shouldn't be relied upon to be accurate.
-            self.hs.get_pusherpool().on_new_receipts(
-                min_batch_id, max_batch_id, affected_room_ids
-            )
+            with PreserveLoggingContext():
+                self.notifier.on_new_event(
+                    "receipt_key", max_persisted_id, rooms=[room_id]
+                )

         defer.returnValue(True)
@@ -135,10 +117,12 @@ class ReceiptsHandler(BaseHandler):
             event_ids = receipt["event_ids"]
             data = receipt["data"]

-            users = yield self.state.get_current_user_in_room(room_id)
-            remotedomains = set(get_domain_from_id(u) for u in users)
-
-            remotedomains = remotedomains.copy()
-            remotedomains.discard(self.server_name)
+            remotedomains = set()
+
+            rm_handler = self.hs.get_handlers().room_member_handler
+            yield rm_handler.fetch_room_distributions_into(
+                room_id, localusers=None, remotedomains=remotedomains
+            )

             logger.debug("Sending receipt to: %r", remotedomains)
@@ -156,7 +140,6 @@ class ReceiptsHandler(BaseHandler):
                     }
                 },
             },
-            key=(room_id, receipt_type, user_id),
         )

     @defer.inlineCallbacks
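
The left-hand side of the `_handle_new_receipts` hunk collapses a whole batch of persisted receipts into one notification covering the [min, max] stream-id range and the set of affected rooms. That reduction can be checked in isolation; the receipt rows below are hypothetical:

def summarise_batch(persisted):
    """persisted: list of (receipt_dict, stream_id, max_persisted_id)."""
    min_batch_id = min(stream_id for _, stream_id, _ in persisted)
    max_batch_id = max(max_id for _, _, max_id in persisted)
    # De-duplicate room ids so the notifier is poked once per room.
    affected_room_ids = list(set(r["room_id"] for r, _, _ in persisted))
    return min_batch_id, max_batch_id, affected_room_ids

batch = [
    ({"room_id": "!a:hs"}, 10, 11),
    ({"room_id": "!b:hs"}, 12, 15),
    ({"room_id": "!a:hs"}, 11, 12),
]
print(summarise_batch(batch))  # (10, 15, ['!a:hs', '!b:hs']) in some order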


@@ -14,19 +14,19 @@
 # limitations under the License.

 """Contains functions for registering clients."""

-import logging
-import urllib
-
 from twisted.internet import defer

-import synapse.types
+from synapse.types import UserID
 from synapse.api.errors import (
     AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
 )
-from synapse.http.client import CaptchaServerHttpClient
-from synapse.types import UserID
-from synapse.util.async import run_on_reactor
 from ._base import BaseHandler
+from synapse.util.async import run_on_reactor
+from synapse.http.client import CaptchaServerHttpClient
+from synapse.util.distributor import registered_user
+
+import logging
+import urllib

 logger = logging.getLogger(__name__)
@@ -37,6 +37,8 @@ class RegistrationHandler(BaseHandler):
         super(RegistrationHandler, self).__init__(hs)

         self.auth = hs.get_auth()
+        self.distributor = hs.get_distributor()
+        self.distributor.declare("registered_user")
         self.captcha_client = CaptchaServerHttpClient(hs)

         self._next_generated_user_id = None
@@ -53,13 +55,6 @@ class RegistrationHandler(BaseHandler):
                 Codes.INVALID_USERNAME
             )

-        if localpart[0] == '_':
-            raise SynapseError(
-                400,
-                "User ID may not begin with _",
-                Codes.INVALID_USERNAME
-            )
-
         user = UserID(localpart, self.hs.hostname)
         user_id = user.to_string()
@@ -98,8 +93,7 @@ class RegistrationHandler(BaseHandler):
         password=None,
         generate_token=True,
         guest_access_token=None,
-        make_guest=False,
-        admin=False,
+        make_guest=False
     ):
         """Registers a new client on the server.
@@ -107,13 +101,8 @@ class RegistrationHandler(BaseHandler):
             localpart : The local part of the user ID to register. If None,
               one will be generated.
             password (str) : The password to assign to this user so they can
              login again. This can be None which means they cannot login again
              via a password (e.g. the user is an application service user).
-            generate_token (bool): Whether a new access token should be
-              generated. Having this be True should be considered deprecated,
-              since it offers no means of associating a device_id with the
-              access_token. Instead you should call auth_handler.issue_access_token
-              after registration.
         Returns:
             A tuple of (user_id, access_token).
         Raises:
@@ -151,12 +140,9 @@ class RegistrationHandler(BaseHandler):
                 password_hash=password_hash,
                 was_guest=was_guest,
                 make_guest=make_guest,
-                create_profile_with_localpart=(
-                    # If the user was a guest then they already have a profile
-                    None if was_guest else user.localpart
-                ),
-                admin=admin,
             )
+
+            yield registered_user(self.distributor, user)
         else:
             # autogen a sequential user ID
             attempts = 0
@@ -174,8 +160,7 @@ class RegistrationHandler(BaseHandler):
                         user_id=user_id,
                         token=token,
                         password_hash=password_hash,
-                        make_guest=make_guest,
-                        create_profile_with_localpart=user.localpart,
+                        make_guest=make_guest
                     )
                 except SynapseError:
                     # if user id is taken, just generate another
@@ -183,6 +168,7 @@ class RegistrationHandler(BaseHandler):
                     user_id = None
                     token = None
                     attempts += 1
+            yield registered_user(self.distributor, user)

         # We used to generate default identicons here, but nowadays
         # we want clients to generate their own as part of their branding
@@ -209,13 +195,15 @@ class RegistrationHandler(BaseHandler):
             user_id, allowed_appservice=service
         )

+        token = self.auth_handler().generate_access_token(user_id)
         yield self.store.register(
             user_id=user_id,
+            token=token,
             password_hash="",
             appservice_id=service_id,
-            create_profile_with_localpart=user.localpart,
         )
-        defer.returnValue(user_id)
+        yield registered_user(self.distributor, user)
+        defer.returnValue((user_id, token))
     @defer.inlineCallbacks
     def check_recaptcha(self, ip, private_key, challenge, response):
@@ -260,9 +248,9 @@ class RegistrationHandler(BaseHandler):
             yield self.store.register(
                 user_id=user_id,
                 token=token,
-                password_hash=None,
-                create_profile_with_localpart=user.localpart,
+                password_hash=None
             )
+            yield registered_user(self.distributor, user)
         except Exception as e:
             yield self.store.add_access_token_to_user(user_id, token)
             # Ignore Registration errors
@@ -370,63 +358,8 @@ class RegistrationHandler(BaseHandler):
             )
         defer.returnValue(data)

-    @defer.inlineCallbacks
-    def get_or_create_user(self, localpart, displayname, duration_in_ms,
-                           password_hash=None):
-        """Creates a new user if the user does not exist,
-        else revokes all previous access tokens and generates a new one.
-
-        Args:
-            localpart : The local part of the user ID to register. If None,
-              one will be randomly generated.
-        Returns:
-            A tuple of (user_id, access_token).
-        Raises:
-            RegistrationError if there was a problem registering.
-        """
-        yield run_on_reactor()
-
-        if localpart is None:
-            raise SynapseError(400, "Request must include user id")
-
-        need_register = True
-
-        try:
-            yield self.check_username(localpart)
-        except SynapseError as e:
-            if e.errcode == Codes.USER_IN_USE:
-                need_register = False
-            else:
-                raise
-
-        user = UserID(localpart, self.hs.hostname)
-        user_id = user.to_string()
-        token = self.auth_handler().generate_access_token(
-            user_id, None, duration_in_ms)
-
-        if need_register:
-            yield self.store.register(
-                user_id=user_id,
-                token=token,
-                password_hash=password_hash,
-                create_profile_with_localpart=user.localpart,
-            )
-        else:
-            yield self.store.user_delete_access_tokens(user_id=user_id)
-            yield self.store.add_access_token_to_user(user_id=user_id, token=token)
-
-        if displayname is not None:
-            logger.info("setting user display name: %s -> %s", user_id, displayname)
-            profile_handler = self.hs.get_handlers().profile_handler
-            requester = synapse.types.create_requester(user)
-            yield profile_handler.set_displayname(
-                user, requester, displayname
-            )
-
-        defer.returnValue((user_id, token))
-
     def auth_handler(self):
-        return self.hs.get_auth_handler()
+        return self.hs.get_handlers().auth_handler

     @defer.inlineCallbacks
     def guest_access_token_for(self, medium, address, inviter_user_id):
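
Both versions of the registration code share the strategy visible in the `# autogen a sequential user ID` branch: take the next numeric localpart, and if registration throws because it is already taken, advance and retry a bounded number of times. Condensed into a standalone sketch, with an in-memory set standing in for the datastore:

taken = {"1", "2"}  # localparts already registered (hypothetical store state)

def find_next_generated_user_id(next_id=1, max_attempts=10):
    attempts = 0
    candidate = next_id
    while attempts < max_attempts:
        localpart = str(candidate)
        if localpart not in taken:
            taken.add(localpart)
            return localpart
        # mirrors the "if user id is taken, just generate another" branch
        candidate += 1
        attempts += 1
    raise RuntimeError("ran out of attempts")

print(find_next_generated_user_id())  # "3"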


@@ -20,11 +20,12 @@ from ._base import BaseHandler
 from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken
 from synapse.api.constants import (
-    EventTypes, JoinRules, RoomCreationPreset
+    EventTypes, JoinRules, RoomCreationPreset,
 )
 from synapse.api.errors import AuthError, StoreError, SynapseError
 from synapse.util import stringutils
-from synapse.visibility import filter_events_for_client
+from synapse.util.async import concurrently_execute
+from synapse.util.caches.response_cache import ResponseCache

 from collections import OrderedDict
@@ -192,11 +193,6 @@ class RoomCreationHandler(BaseHandler):
             },
             ratelimit=False)

-        content = {}
-        is_direct = config.get("is_direct", None)
-        if is_direct:
-            content["is_direct"] = is_direct
-
         for invitee in invite_list:
             yield room_member_handler.update_membership(
                 requester,
@@ -204,7 +200,6 @@ class RoomCreationHandler(BaseHandler):
                 room_id,
                 "invite",
                 ratelimit=False,
-                content=content,
             )

         for invite_3pid in invite_3pid_list:
@@ -344,6 +339,91 @@ class RoomCreationHandler(BaseHandler):
         )


+class RoomListHandler(BaseHandler):
+    def __init__(self, hs):
+        super(RoomListHandler, self).__init__(hs)
+        self.response_cache = ResponseCache()
+
+    def get_public_room_list(self):
+        result = self.response_cache.get(())
+        if not result:
+            result = self.response_cache.set((), self._get_public_room_list())
+        return result
+
+    @defer.inlineCallbacks
+    def _get_public_room_list(self):
+        room_ids = yield self.store.get_public_room_ids()
+
+        results = []
+
+        @defer.inlineCallbacks
+        def handle_room(room_id):
+            aliases = yield self.store.get_aliases_for_room(room_id)
+
+            # We pull each bit of state out indvidually to avoid pulling the
+            # full state into memory. Due to how the caching works this should
+            # be fairly quick, even if not originally in the cache.
+            def get_state(etype, state_key):
+                return self.state_handler.get_current_state(room_id, etype, state_key)
+
+            # Double check that this is actually a public room.
+            join_rules_event = yield get_state(EventTypes.JoinRules, "")
+            if join_rules_event:
+                join_rule = join_rules_event.content.get("join_rule", None)
+                if join_rule and join_rule != JoinRules.PUBLIC:
+                    defer.returnValue(None)
+
+            result = {"room_id": room_id}
+            if aliases:
+                result["aliases"] = aliases
+
+            name_event = yield get_state(EventTypes.Name, "")
+            if name_event:
+                name = name_event.content.get("name", None)
+                if name:
+                    result["name"] = name
+
+            topic_event = yield get_state(EventTypes.Topic, "")
+            if topic_event:
+                topic = topic_event.content.get("topic", None)
+                if topic:
+                    result["topic"] = topic
+
+            canonical_event = yield get_state(EventTypes.CanonicalAlias, "")
+            if canonical_event:
+                canonical_alias = canonical_event.content.get("alias", None)
+                if canonical_alias:
+                    result["canonical_alias"] = canonical_alias
+
+            visibility_event = yield get_state(EventTypes.RoomHistoryVisibility, "")
+            visibility = None
+            if visibility_event:
+                visibility = visibility_event.content.get("history_visibility", None)
+            result["world_readable"] = visibility == "world_readable"
+
+            guest_event = yield get_state(EventTypes.GuestAccess, "")
+            guest = None
+            if guest_event:
+                guest = guest_event.content.get("guest_access", None)
+            result["guest_can_join"] = guest == "can_join"
+
+            avatar_event = yield get_state("m.room.avatar", "")
+            if avatar_event:
+                avatar_url = avatar_event.content.get("url", None)
+                if avatar_url:
+                    result["avatar_url"] = avatar_url
+
+            joined_users = yield self.store.get_users_in_room(room_id)
+            result["num_joined_members"] = len(joined_users)
+
+            results.append(result)
+
+        yield concurrently_execute(handle_room, room_ids, 10)
+
+        # FIXME (erikj): START is no longer a valid value
+        defer.returnValue({"start": "START", "end": "END", "chunk": results})
+
+
 class RoomContextHandler(BaseHandler):
     @defer.inlineCallbacks
     def get_event_context(self, user, room_id, event_id, limit, is_guest):
@@ -366,12 +446,10 @@ class RoomContextHandler(BaseHandler):
         now_token = yield self.hs.get_event_sources().get_current_token()

         def filter_evts(events):
-            return filter_events_for_client(
-                self.store,
+            return self._filter_events_for_client(
                 user.to_string(),
                 events,
-                is_peeking=is_guest
-            )
+                is_peeking=is_guest)

         event = yield self.store.get_event(event_id, get_prev_content=True,
                                            allow_none=True)
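
The new `RoomListHandler` deduplicates concurrent `/publicRooms` requests through `ResponseCache`: the first caller kicks off `_get_public_room_list()` and stores the Deferred; callers arriving before it completes share that result. A rough sketch of the get/set contract, simplified from the real class (which hands out observers on the Deferred rather than the Deferred itself):

from twisted.internet import defer

class ToyResponseCache(object):
    def __init__(self):
        self.pending = {}

    def get(self, key):
        return self.pending.get(key)

    def set(self, key, deferred):
        self.pending[key] = deferred

        def on_done(result):
            # Drop the entry once complete so later calls recompute.
            self.pending.pop(key, None)
            return result
        deferred.addBoth(on_done)
        return deferred

cache = ToyResponseCache()
result = cache.get(())
if not result:
    result = cache.set((), defer.succeed({"chunk": []}))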


@@ -1,403 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.constants import (
EventTypes, JoinRules,
)
from synapse.util.async import concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
from collections import namedtuple
from unpaddedbase64 import encode_base64, decode_base64
import logging
import msgpack
logger = logging.getLogger(__name__)
REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
class RoomListHandler(BaseHandler):
def __init__(self, hs):
super(RoomListHandler, self).__init__(hs)
self.response_cache = ResponseCache(hs)
self.remote_response_cache = ResponseCache(hs, timeout_ms=30 * 1000)
def get_local_public_room_list(self, limit=None, since_token=None,
search_filter=None):
if search_filter:
# We explicitly don't bother caching searches.
return self._get_public_room_list(limit, since_token, search_filter)
result = self.response_cache.get((limit, since_token))
if not result:
result = self.response_cache.set(
(limit, since_token),
self._get_public_room_list(limit, since_token)
)
return result
@defer.inlineCallbacks
def _get_public_room_list(self, limit=None, since_token=None,
search_filter=None):
if since_token and since_token != "END":
since_token = RoomListNextBatch.from_token(since_token)
else:
since_token = None
rooms_to_order_value = {}
rooms_to_num_joined = {}
rooms_to_latest_event_ids = {}
newly_visible = []
newly_unpublished = []
if since_token:
stream_token = since_token.stream_ordering
current_public_id = yield self.store.get_current_public_room_stream_id()
public_room_stream_id = since_token.public_room_stream_id
newly_visible, newly_unpublished = yield self.store.get_public_room_changes(
public_room_stream_id, current_public_id
)
else:
stream_token = yield self.store.get_room_max_stream_ordering()
public_room_stream_id = yield self.store.get_current_public_room_stream_id()
room_ids = yield self.store.get_public_room_ids_at_stream_id(
public_room_stream_id
)
# We want to return rooms in a particular order: the number of joined
# users. We then arbitrarily use the room_id as a tie breaker.
@defer.inlineCallbacks
def get_order_for_room(room_id):
latest_event_ids = rooms_to_latest_event_ids.get(room_id, None)
if not latest_event_ids:
latest_event_ids = yield self.store.get_forward_extremeties_for_room(
room_id, stream_token
)
rooms_to_latest_event_ids[room_id] = latest_event_ids
if not latest_event_ids:
return
joined_users = yield self.state_handler.get_current_user_in_room(
room_id, latest_event_ids,
)
num_joined_users = len(joined_users)
rooms_to_num_joined[room_id] = num_joined_users
if num_joined_users == 0:
return
# We want larger rooms to be first, hence negating num_joined_users
rooms_to_order_value[room_id] = (-num_joined_users, room_id)
yield concurrently_execute(get_order_for_room, room_ids, 10)
sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1])
sorted_rooms = [room_id for room_id, _ in sorted_entries]
# `sorted_rooms` should now be a list of all public room ids that is
# stable across pagination. Therefore, we can use indices into this
# list as our pagination tokens.
# Filter out rooms that we don't want to return
rooms_to_scan = [
r for r in sorted_rooms
if r not in newly_unpublished and rooms_to_num_joined[room_id] > 0
]
total_room_count = len(rooms_to_scan)
if since_token:
# Filter out rooms we've already returned previously
# `since_token.current_limit` is the index of the last room we
# sent down, so we exclude it and everything before/after it.
if since_token.direction_is_forward:
rooms_to_scan = rooms_to_scan[since_token.current_limit + 1:]
else:
rooms_to_scan = rooms_to_scan[:since_token.current_limit]
rooms_to_scan.reverse()
# Actually generate the entries. _generate_room_entry will append to
# chunk but will stop if len(chunk) > limit
chunk = []
if limit and not search_filter:
step = limit + 1
for i in xrange(0, len(rooms_to_scan), step):
# We iterate here because the vast majority of cases we'll stop
# at first iteration, but occaisonally _generate_room_entry
# won't append to the chunk and so we need to loop again.
# We don't want to scan over the entire range either as that
# would potentially waste a lot of work.
yield concurrently_execute(
lambda r: self._generate_room_entry(
r, rooms_to_num_joined[r],
chunk, limit, search_filter
),
rooms_to_scan[i:i + step], 10
)
if len(chunk) >= limit + 1:
break
else:
yield concurrently_execute(
lambda r: self._generate_room_entry(
r, rooms_to_num_joined[r],
chunk, limit, search_filter
),
rooms_to_scan, 5
)
chunk.sort(key=lambda e: (-e["num_joined_members"], e["room_id"]))
# Work out the new limit of the batch for pagination, or None if we
# know there are no more results that would be returned.
# i.e., [since_token.current_limit..new_limit] is the batch of rooms
# we've returned (or the reverse if we paginated backwards)
# We tried to pull out limit + 1 rooms above, so if we have <= limit
# then we know there are no more results to return
new_limit = None
if chunk and (not limit or len(chunk) > limit):
if not since_token or since_token.direction_is_forward:
if limit:
chunk = chunk[:limit]
last_room_id = chunk[-1]["room_id"]
else:
if limit:
chunk = chunk[-limit:]
last_room_id = chunk[0]["room_id"]
new_limit = sorted_rooms.index(last_room_id)
results = {
"chunk": chunk,
"total_room_count_estimate": total_room_count,
}
if since_token:
results["new_rooms"] = bool(newly_visible)
if not since_token or since_token.direction_is_forward:
if new_limit is not None:
results["next_batch"] = RoomListNextBatch(
stream_ordering=stream_token,
public_room_stream_id=public_room_stream_id,
current_limit=new_limit,
direction_is_forward=True,
).to_token()
if since_token:
results["prev_batch"] = since_token.copy_and_replace(
direction_is_forward=False,
current_limit=since_token.current_limit + 1,
).to_token()
else:
if new_limit is not None:
results["prev_batch"] = RoomListNextBatch(
stream_ordering=stream_token,
public_room_stream_id=public_room_stream_id,
current_limit=new_limit,
direction_is_forward=False,
).to_token()
if since_token:
results["next_batch"] = since_token.copy_and_replace(
direction_is_forward=True,
current_limit=since_token.current_limit - 1,
).to_token()
defer.returnValue(results)
@defer.inlineCallbacks
def _generate_room_entry(self, room_id, num_joined_users, chunk, limit,
search_filter):
if limit and len(chunk) > limit + 1:
# We've already got enough, so lets just drop it.
return
result = {
"room_id": room_id,
"num_joined_members": num_joined_users,
}
current_state_ids = yield self.state_handler.get_current_state_ids(room_id)
event_map = yield self.store.get_events([
event_id for key, event_id in current_state_ids.items()
if key[0] in (
EventTypes.JoinRules,
EventTypes.Name,
EventTypes.Topic,
EventTypes.CanonicalAlias,
EventTypes.RoomHistoryVisibility,
EventTypes.GuestAccess,
"m.room.avatar",
)
])
current_state = {
(ev.type, ev.state_key): ev
for ev in event_map.values()
}
# Double check that this is actually a public room.
join_rules_event = current_state.get((EventTypes.JoinRules, ""))
if join_rules_event:
join_rule = join_rules_event.content.get("join_rule", None)
if join_rule and join_rule != JoinRules.PUBLIC:
defer.returnValue(None)
aliases = yield self.store.get_aliases_for_room(room_id)
if aliases:
result["aliases"] = aliases
name_event = yield current_state.get((EventTypes.Name, ""))
if name_event:
name = name_event.content.get("name", None)
if name:
result["name"] = name
topic_event = current_state.get((EventTypes.Topic, ""))
if topic_event:
topic = topic_event.content.get("topic", None)
if topic:
result["topic"] = topic
canonical_event = current_state.get((EventTypes.CanonicalAlias, ""))
if canonical_event:
canonical_alias = canonical_event.content.get("alias", None)
if canonical_alias:
result["canonical_alias"] = canonical_alias
visibility_event = current_state.get((EventTypes.RoomHistoryVisibility, ""))
visibility = None
if visibility_event:
visibility = visibility_event.content.get("history_visibility", None)
result["world_readable"] = visibility == "world_readable"
guest_event = current_state.get((EventTypes.GuestAccess, ""))
guest = None
if guest_event:
guest = guest_event.content.get("guest_access", None)
result["guest_can_join"] = guest == "can_join"
avatar_event = current_state.get(("m.room.avatar", ""))
if avatar_event:
avatar_url = avatar_event.content.get("url", None)
if avatar_url:
result["avatar_url"] = avatar_url
if _matches_room_entry(result, search_filter):
chunk.append(result)
@defer.inlineCallbacks
def get_remote_public_room_list(self, server_name, limit=None, since_token=None,
search_filter=None):
if search_filter:
# We currently don't support searching across federation, so we have
# to do it manually without pagination
limit = None
since_token = None
res = yield self._get_remote_list_cached(
server_name, limit=limit, since_token=since_token,
)
if search_filter:
res = {"chunk": [
entry
for entry in list(res.get("chunk", []))
if _matches_room_entry(entry, search_filter)
]}
defer.returnValue(res)
def _get_remote_list_cached(self, server_name, limit=None, since_token=None,
search_filter=None):
repl_layer = self.hs.get_replication_layer()
if search_filter:
# We can't cache when asking for search
return repl_layer.get_public_rooms(
server_name, limit=limit, since_token=since_token,
search_filter=search_filter,
)
result = self.remote_response_cache.get((server_name, limit, since_token))
if not result:
result = self.remote_response_cache.set(
(server_name, limit, since_token),
repl_layer.get_public_rooms(
server_name, limit=limit, since_token=since_token,
search_filter=search_filter,
)
)
return result
class RoomListNextBatch(namedtuple("RoomListNextBatch", (
"stream_ordering", # stream_ordering of the first public room list
"public_room_stream_id", # public room stream id for first public room list
"current_limit", # The number of previous rooms returned
"direction_is_forward", # Bool if this is a next_batch, false if prev_batch
))):
KEY_DICT = {
"stream_ordering": "s",
"public_room_stream_id": "p",
"current_limit": "n",
"direction_is_forward": "d",
}
REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()}
@classmethod
def from_token(cls, token):
return RoomListNextBatch(**{
cls.REVERSE_KEY_DICT[key]: val
for key, val in msgpack.loads(decode_base64(token)).items()
})
def to_token(self):
return encode_base64(msgpack.dumps({
self.KEY_DICT[key]: val
for key, val in self._asdict().items()
}))
def copy_and_replace(self, **kwds):
return self._replace(
**kwds
)
def _matches_room_entry(room_entry, search_filter):
if search_filter and search_filter.get("generic_search_term", None):
generic_search_term = search_filter["generic_search_term"].upper()
if generic_search_term in room_entry.get("name", "").upper():
return True
elif generic_search_term in room_entry.get("topic", "").upper():
return True
elif generic_search_term in room_entry.get("canonical_alias", "").upper():
return True
else:
return True
return False
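
The pagination token produced by `RoomListNextBatch` above is just the shortened field dict run through msgpack and unpadded base64, so the round trip can be verified in isolation with the real `msgpack` and `unpaddedbase64` packages:

import msgpack
from unpaddedbase64 import encode_base64, decode_base64

KEY_DICT = {"stream_ordering": "s", "public_room_stream_id": "p",
            "current_limit": "n", "direction_is_forward": "d"}
REVERSE_KEY_DICT = dict((v, k) for k, v in KEY_DICT.items())

def to_token(batch):
    # Single-character keys keep the client-visible token compact.
    return encode_base64(msgpack.dumps(
        dict((KEY_DICT[k], v) for k, v in batch.items())
    ))

def from_token(token):
    return dict((REVERSE_KEY_DICT[k], v)
                for k, v in msgpack.loads(decode_base64(token)).items())

batch = {"stream_ordering": 100, "public_room_stream_id": 5,
         "current_limit": 20, "direction_is_forward": True}
assert from_token(to_token(batch)) == batch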


@@ -14,22 +14,24 @@
 # limitations under the License.

-import logging
-
-from signedjson.key import decode_verify_key_bytes
-from signedjson.sign import verify_signed_json
 from twisted.internet import defer
-from unpaddedbase64 import decode_base64

-import synapse.types
+from ._base import BaseHandler
+
+from synapse.types import UserID, RoomID, Requester
 from synapse.api.constants import (
     EventTypes, Membership,
 )
 from synapse.api.errors import AuthError, SynapseError, Codes
-from synapse.types import UserID, RoomID
 from synapse.util.async import Linearizer
 from synapse.util.distributor import user_left_room, user_joined_room
-from ._base import BaseHandler
+
+from signedjson.sign import verify_signed_json
+from signedjson.key import decode_verify_key_bytes
+
+from unpaddedbase64 import decode_base64
+
+import logging

 logger = logging.getLogger(__name__)
@@ -53,19 +55,45 @@ class RoomMemberHandler(BaseHandler):
         self.distributor.declare("user_joined_room")
         self.distributor.declare("user_left_room")

+    @defer.inlineCallbacks
+    def get_room_members(self, room_id):
+        users = yield self.store.get_users_in_room(room_id)
+
+        defer.returnValue([UserID.from_string(u) for u in users])
+
+    @defer.inlineCallbacks
+    def fetch_room_distributions_into(self, room_id, localusers=None,
+                                      remotedomains=None, ignore_user=None):
+        """Fetch the distribution of a room, adding elements to either
+        'localusers' or 'remotedomains', which should be a set() if supplied.
+        If ignore_user is set, ignore that user.
+
+        This function returns nothing; its result is performed by the
+        side-effect on the two passed sets. This allows easy accumulation of
+        member lists of multiple rooms at once if required.
+        """
+        members = yield self.get_room_members(room_id)
+        for member in members:
+            if ignore_user is not None and member == ignore_user:
+                continue
+
+            if self.hs.is_mine(member):
+                if localusers is not None:
+                    localusers.add(member)
+            else:
+                if remotedomains is not None:
+                    remotedomains.add(member.domain)
+
     @defer.inlineCallbacks
     def _local_membership_update(
         self, requester, target, room_id, membership,
         prev_event_ids,
         txn_id=None,
         ratelimit=True,
-        content=None,
     ):
-        if content is None:
-            content = {}
         msg_handler = self.hs.get_handlers().message_handler

-        content["membership"] = membership
+        content = {"membership": membership}
         if requester.is_guest:
             content["kind"] = "guest"
@@ -85,13 +113,7 @@ class RoomMemberHandler(BaseHandler):
             prev_event_ids=prev_event_ids,
         )
 
-        # Check if this event matches the previous membership event for the user.
-        duplicate = yield msg_handler.deduplicate_state_event(event, context)
-        if duplicate is not None:
-            # Discard the new event since this membership change is a no-op.
-            return
-
-        yield msg_handler.handle_new_client_event(
+        yield self.handle_new_client_event(
             requester,
             event,
             context,
@@ -99,26 +121,20 @@ class RoomMemberHandler(BaseHandler):
             ratelimit=ratelimit,
         )
 
-        prev_member_event_id = context.prev_state_ids.get(
+        prev_member_event = context.current_state.get(
             (EventTypes.Member, target.to_string()),
             None
         )
 
         if event.membership == Membership.JOIN:
-            # Only fire user_joined_room if the user has actually joined the
-            # room. Don't bother if the user is just changing their profile
-            # info.
-            newly_joined = True
-            if prev_member_event_id:
-                prev_member_event = yield self.store.get_event(prev_member_event_id)
-                newly_joined = prev_member_event.membership != Membership.JOIN
-            if newly_joined:
+            if not prev_member_event or prev_member_event.membership != Membership.JOIN:
+                # Only fire user_joined_room if the user has actually joined the
+                # room. Don't bother if the user is just changing their profile
+                # info.
                 yield user_joined_room(self.distributor, target, room_id)
         elif event.membership == Membership.LEAVE:
-            if prev_member_event_id:
-                prev_member_event = yield self.store.get_event(prev_member_event_id)
-                if prev_member_event.membership == Membership.JOIN:
-                    user_left_room(self.distributor, target, room_id)
+            if prev_member_event and prev_member_event.membership == Membership.JOIN:
+                user_left_room(self.distributor, target, room_id)
 
     @defer.inlineCallbacks
     def remote_join(self, remote_room_hosts, room_id, user, content):
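Both sides of the previous hunk implement the rule stated in the comment: notify only on a genuine membership transition, since a join-to-join event is merely a profile update. Restated as standalone predicates (a sketch; string literals stand in for synapse.api.constants.Membership):

JOIN, LEAVE = "join", "leave"

def fires_user_joined_room(new_membership, prev_membership):
    # join -> join is only a displayname/avatar change; don't notify.
    return new_membership == JOIN and prev_membership != JOIN

def fires_user_left_room(new_membership, prev_membership):
    return new_membership == LEAVE and prev_membership == JOIN

print(fires_user_joined_room(JOIN, None))  # True: a real join
print(fires_user_joined_room(JOIN, JOIN))  # False: profile update only
print(fires_user_left_room(LEAVE, JOIN))   # True: a real leave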
@@ -155,9 +171,8 @@ class RoomMemberHandler(BaseHandler):
             remote_room_hosts=None,
             third_party_signed=None,
             ratelimit=True,
-            content=None,
     ):
-        key = (room_id,)
+        key = (target, room_id,)
 
         with (yield self.member_linearizer.queue(key)):
             result = yield self._update_membership(
@@ -169,7 +184,6 @@ class RoomMemberHandler(BaseHandler):
                 remote_room_hosts=remote_room_hosts,
                 third_party_signed=third_party_signed,
                 ratelimit=ratelimit,
-                content=content,
             )
 
         defer.returnValue(result)
@@ -185,11 +199,7 @@ class RoomMemberHandler(BaseHandler):
         remote_room_hosts=None,
         third_party_signed=None,
         ratelimit=True,
-        content=None,
     ):
-        if content is None:
-            content = {}
-
         effective_membership_state = action
         if action in ["kick", "unban"]:
             effective_membership_state = "leave"
@@ -207,32 +217,29 @@ class RoomMemberHandler(BaseHandler):
             remote_room_hosts = []
 
         latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
-        current_state_ids = yield self.state_handler.get_current_state_ids(
+        current_state = yield self.state_handler.get_current_state(
             room_id, latest_event_ids=latest_event_ids,
         )
 
-        old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
-        if old_state_id:
-            old_state = yield self.store.get_event(old_state_id, allow_none=True)
-            old_membership = old_state.content.get("membership") if old_state else None
-            if action == "unban" and old_membership != "ban":
-                raise SynapseError(
-                    403,
-                    "Cannot unban user who was not banned"
-                    " (membership=%s)" % old_membership,
-                    errcode=Codes.BAD_STATE
-                )
-            if old_membership == "ban" and action != "unban":
-                raise SynapseError(
-                    403,
-                    "Cannot %s user who was banned" % (action,),
-                    errcode=Codes.BAD_STATE
-                )
+        old_state = current_state.get((EventTypes.Member, target.to_string()))
+        old_membership = old_state.content.get("membership") if old_state else None
+        if action == "unban" and old_membership != "ban":
+            raise SynapseError(
+                403,
+                "Cannot unban user who was not banned (membership=%s)" % old_membership,
+                errcode=Codes.BAD_STATE
+            )
+        if old_membership == "ban" and action != "unban":
+            raise SynapseError(
+                403,
+                "Cannot %s user who was banned" % (action,),
+                errcode=Codes.BAD_STATE
+            )
 
-        is_host_in_room = yield self._is_host_in_room(current_state_ids)
+        is_host_in_room = self.is_host_in_room(current_state)
 
         if effective_membership_state == Membership.JOIN:
-            if requester.is_guest and not self._can_guest_join(current_state_ids):
+            if requester.is_guest and not self._can_guest_join(current_state):
                 # This should be an auth check, but guests are a local concept,
                 # so don't really fit into the general auth process.
                 raise AuthError(403, "Guest access not allowed")
@@ -242,7 +249,7 @@ class RoomMemberHandler(BaseHandler):
             if inviter and not self.hs.is_mine(inviter):
                 remote_room_hosts.append(inviter.domain)
 
-            content["membership"] = Membership.JOIN
+            content = {"membership": Membership.JOIN}
 
             profile = self.hs.get_handlers().profile_handler
             content["displayname"] = yield profile.get_displayname(target)
@@ -296,7 +303,6 @@ class RoomMemberHandler(BaseHandler):
             txn_id=txn_id,
             ratelimit=ratelimit,
             prev_event_ids=latest_event_ids,
-            content=content,
         )
 
     @defer.inlineCallbacks
@@ -338,22 +344,20 @@ class RoomMemberHandler(BaseHandler):
             )
             assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
         else:
-            requester = synapse.types.create_requester(target_user)
+            requester = Requester(target_user, None, False)
 
         message_handler = self.hs.get_handlers().message_handler
-        prev_event = yield message_handler.deduplicate_state_event(event, context)
+        prev_event = message_handler.deduplicate_state_event(event, context)
         if prev_event is not None:
            return
 
         if event.membership == Membership.JOIN:
-            if requester.is_guest:
-                guest_can_join = yield self._can_guest_join(context.prev_state_ids)
-                if not guest_can_join:
-                    # This should be an auth check, but guests are a local concept,
-                    # so don't really fit into the general auth process.
-                    raise AuthError(403, "Guest access not allowed")
+            if requester.is_guest and not self._can_guest_join(context.current_state):
+                # This should be an auth check, but guests are a local concept,
+                # so don't really fit into the general auth process.
+                raise AuthError(403, "Guest access not allowed")
 
-        yield message_handler.handle_new_client_event(
+        yield self.handle_new_client_event(
             requester,
             event,
             context,
@@ -361,39 +365,27 @@ class RoomMemberHandler(BaseHandler):
             ratelimit=ratelimit,
         )
 
-        prev_member_event_id = context.prev_state_ids.get(
-            (EventTypes.Member, event.state_key),
+        prev_member_event = context.current_state.get(
+            (EventTypes.Member, target_user.to_string()),
             None
         )
 
         if event.membership == Membership.JOIN:
-            # Only fire user_joined_room if the user has actually joined the
-            # room. Don't bother if the user is just changing their profile
-            # info.
-            newly_joined = True
-            if prev_member_event_id:
-                prev_member_event = yield self.store.get_event(prev_member_event_id)
-                newly_joined = prev_member_event.membership != Membership.JOIN
-            if newly_joined:
+            if not prev_member_event or prev_member_event.membership != Membership.JOIN:
+                # Only fire user_joined_room if the user has actually joined the
+                # room. Don't bother if the user is just changing their profile
+                # info.
                 yield user_joined_room(self.distributor, target_user, room_id)
         elif event.membership == Membership.LEAVE:
-            if prev_member_event_id:
-                prev_member_event = yield self.store.get_event(prev_member_event_id)
-                if prev_member_event.membership == Membership.JOIN:
-                    user_left_room(self.distributor, target_user, room_id)
+            if prev_member_event and prev_member_event.membership == Membership.JOIN:
+                user_left_room(self.distributor, target_user, room_id)
 
-    @defer.inlineCallbacks
-    def _can_guest_join(self, current_state_ids):
+    def _can_guest_join(self, current_state):
         """
         Returns whether a guest can join a room based on its current state.
         """
-        guest_access_id = current_state_ids.get((EventTypes.GuestAccess, ""), None)
-        if not guest_access_id:
-            defer.returnValue(False)
-
-        guest_access = yield self.store.get_event(guest_access_id)
-        defer.returnValue(
+        guest_access = current_state.get((EventTypes.GuestAccess, ""), None)
+        return (
             guest_access
             and guest_access.content
             and "guest_access" in guest_access.content
@@ -434,6 +426,21 @@ class RoomMemberHandler(BaseHandler):
         if invite:
             defer.returnValue(UserID.from_string(invite.sender))
 
+    @defer.inlineCallbacks
+    def get_joined_rooms_for_user(self, user):
+        """Returns a list of roomids that the user has any of the given
+        membership states in."""
+
+        rooms = yield self.store.get_rooms_for_user(
+            user.to_string(),
+        )
+
+        # For some reason the list of events contains duplicates
+        # TODO(paul): work out why because I really don't think it should
+        room_ids = set(r.room_id for r in rooms)
+
+        defer.returnValue(room_ids)
+
     @defer.inlineCallbacks
     def do_3pid_invite(
         self,
@@ -450,7 +457,8 @@ class RoomMemberHandler(BaseHandler):
         )
 
         if invitee:
-            yield self.update_membership(
+            handler = self.hs.get_handlers().room_member_handler
+            yield handler.update_membership(
                 requester,
                 UserID.from_string(invitee),
                 room_id,
@@ -712,24 +720,3 @@ class RoomMemberHandler(BaseHandler):
 
         if membership:
             yield self.store.forget(user_id, room_id)
-
-    @defer.inlineCallbacks
-    def _is_host_in_room(self, current_state_ids):
-        # Have we just created the room, and is this about to be the very
-        # first member event?
-        create_event_id = current_state_ids.get(("m.room.create", ""))
-        if len(current_state_ids) == 1 and create_event_id:
-            defer.returnValue(self.hs.is_mine_id(create_event_id))
-
-        for (etype, state_key), event_id in current_state_ids.items():
-            if etype != EventTypes.Member or not self.hs.is_mine_id(state_key):
-                continue
-
-            event = yield self.store.get_event(event_id, allow_none=True)
-            if not event:
-                continue
-
-            if event.membership == Membership.JOIN:
-                defer.returnValue(True)
-
-        defer.returnValue(False)

View File

@@ -21,7 +21,6 @@ from synapse.api.constants import Membership, EventTypes
 from synapse.api.filtering import Filter
 from synapse.api.errors import SynapseError
 from synapse.events.utils import serialize_event
-from synapse.visibility import filter_events_for_client
 
 from unpaddedbase64 import decode_base64, encode_base64
 
@@ -173,8 +172,8 @@ class SearchHandler(BaseHandler):
 
         filtered_events = search_filter.filter([r["event"] for r in results])
 
-        events = yield filter_events_for_client(
-            self.store, user.to_string(), filtered_events
+        events = yield self._filter_events_for_client(
+            user.to_string(), filtered_events
         )
 
         events.sort(key=lambda e: -rank_map[e.event_id])
r["event"] for r in results r["event"] for r in results
]) ])
events = yield filter_events_for_client( events = yield self._filter_events_for_client(
self.store, user.to_string(), filtered_events user.to_string(), filtered_events
) )
room_events.extend(events) room_events.extend(events)
@@ -282,12 +281,12 @@ class SearchHandler(BaseHandler):
                     event.room_id, event.event_id, before_limit, after_limit
                 )
 
-                res["events_before"] = yield filter_events_for_client(
-                    self.store, user.to_string(), res["events_before"]
+                res["events_before"] = yield self._filter_events_for_client(
+                    user.to_string(), res["events_before"]
                 )
 
-                res["events_after"] = yield filter_events_for_client(
-                    self.store, user.to_string(), res["events_after"]
+                res["events_after"] = yield self._filter_events_for_client(
+                    user.to_string(), res["events_after"]
                 )
 
                 res["start"] = now_token.copy_and_replace(

Some files were not shown because too many files have changed in this diff.