Compare commits

..

137 Commits

Author SHA1 Message Date
Erik Johnston
7861cfec0a Add authors to changelog 2016-07-28 14:35:05 +01:00
Erik Johnston
019cf013d6 Update changelog 2016-07-28 11:30:08 +01:00
Erik Johnston
a285194021 Merge branch 'erikj/key_client_fix' of github.com:matrix-org/synapse into release-v0.17.0 2016-07-28 10:47:06 +01:00
Erik Johnston
7871790db1 Bump version and changelog 2016-07-28 10:38:56 +01:00
Erik Johnston
18e044628e Merge branch 'develop' of github.com:matrix-org/synapse into release-v0.17.0 2016-07-28 10:37:43 +01:00
Erik Johnston
b557b682d9 Merge pull request #961 from matrix-org/dbkr/fix_push_invite_name
Don't include name of room for invites in push
2016-07-28 10:36:41 +01:00
David Baker
389c890f14 Don't include name of room for invites in push
Avoids insane pushes like, "Bob invited you to invite from Bob"
2016-07-28 10:20:47 +01:00
Richard van der Hoff
cd8738ab63 Merge pull request #960 from matrix-org/rav/support_r0.2
Add r0.2.0 to the "supported versions" list
2016-07-28 10:14:33 +01:00
Richard van der Hoff
f6f8f81a48 Add r0.1.0 to the "supported versions" list 2016-07-28 10:14:07 +01:00
David Baker
ecd5e6bfa4 Typo 2016-07-28 10:04:52 +01:00
Richard van der Hoff
fda078f995 Add r0.2.0 to the "supported versions" list 2016-07-28 09:14:21 +01:00
Richard van der Hoff
40e539683c Merge pull request #956 from matrix-org/rav/check_device_id_on_key_upload
Make the device id on e2e key upload optional
2016-07-27 18:16:57 +01:00
Erik Johnston
5238960850 Bump CHANGES and version 2016-07-27 17:33:09 +01:00
Richard van der Hoff
ccec25e2c6 key upload tweaks
1. Add v2_alpha URL back in, since things seem to be using it.

2. Don't reject the request if the device_id in the upload request fails to
   match that in the access_token.
2016-07-27 16:41:06 +01:00
Mark Haines
c38b7c4104 Merge pull request #957 from matrix-org/markjh/verify
Clean up verify_json_objects_for_server
2016-07-27 15:33:32 +01:00
Mark Haines
29b25d59c6 Merge branch 'develop' into markjh/verify
Conflicts:
	synapse/crypto/keyring.py
2016-07-27 15:11:02 +01:00
Mark Haines
884b800899 Merge pull request #955 from matrix-org/markjh/only_from2
Add a couple more checks to the keyring
2016-07-27 15:08:22 +01:00
Mark Haines
09d31815b4 Merge pull request #954 from matrix-org/markjh/even_more_fixes
Fix a couple of bugs in the transaction and keyring code
2016-07-27 15:08:14 +01:00
Mark Haines
fe1b369946 Clean up verify_json_objects_for_server 2016-07-27 14:10:43 +01:00
Richard van der Hoff
26cb0efa88 SQL syntax fix 2016-07-27 12:30:22 +01:00
Richard van der Hoff
d47115ff8b Delete e2e keys on device delete 2016-07-27 12:24:52 +01:00
Richard van der Hoff
2e3d90d67c Make the device id on e2e key upload optional
We should now be able to get our device_id from the access_token, so the
device_id on the upload request is optional. Where it is supplied, we should
check that it matches.

For active access_tokens without an associated device_id, we ought to register
the device in the devices table.

Also update the table on upgrade so that all of the existing e2e keys are
associated with real devices.
2016-07-26 23:38:12 +01:00
Mark Haines
a4b06b619c Add a couple more checks to the keyring 2016-07-26 19:50:11 +01:00
Mark Haines
c63b1697f4 Merge pull request #952 from matrix-org/markjh/more_fixes
Check if the user is banned when handling 3pid invites
2016-07-26 19:20:56 +01:00
Mark Haines
87ffd21b29 Fix a couple of bugs in the transaction and keyring code 2016-07-26 19:19:08 +01:00
Richard van der Hoff
2452611d0f Merge pull request #953 from matrix-org/rav/requester
Add `create_requester` function
2016-07-26 16:57:53 +01:00
Richard van der Hoff
eb359eced4 Add create_requester function
Wrap the `Requester` constructor with a function which provides sensible
defaults, and use it throughout
2016-07-26 16:46:53 +01:00
Mark Haines
c824b29e77 Check if the user is banned when handling 3pid invites 2016-07-26 16:39:14 +01:00
Richard van der Hoff
33d7776473 Fix typo 2016-07-26 13:32:15 +01:00
Richard van der Hoff
9ad8d9b17c Merge branch 'develop' into rav/delete_refreshtoken_on_delete_device 2016-07-26 13:29:46 +01:00
Richard van der Hoff
5b1825ba5b Merge pull request #951 from matrix-org/rav/flake8
Fix flake8 noise
2016-07-26 13:27:54 +01:00
Mark Haines
9c4cf83259 Merge pull request #948 from matrix-org/markjh/auth_fixes
Don't add rejections to the state_group, persist all rejections
2016-07-26 13:22:57 +01:00
Richard van der Hoff
05e7e5e972 Fix flake8 violation
Apparently flake8 v3 puts the error on a different line to v2. Easiest way to
make sure that happens is by putting the whole statement on one line :)
2016-07-26 11:59:08 +01:00
Richard van der Hoff
db4f823d34 Fix flake8 configuration
Apparently flake8 v3 doesn't like trailing comments on config settings.

Also remove the pep8 config, which didn't work (because it was missing W503)
and duplicated the flake8 config. We don't use pep8 on its own, so the config
was duplicative.
2016-07-26 11:49:40 +01:00
Richard van der Hoff
8e02494166 Delete refresh tokens when deleting devices 2016-07-26 11:10:37 +01:00
Mark Haines
a6f06ce3e2 Fix how push_actions are redacted. 2016-07-26 11:05:39 +01:00
David Baker
d34e9f93b7 Merge pull request #949 from matrix-org/rav/update_devices
Implement updates and deletes for devices
2016-07-26 10:49:55 +01:00
Mark Haines
efeb6176c1 Don't add rejected events if we've seen them befrore. Add some comments to explain what the code is doing mechanically 2016-07-26 10:49:52 +01:00
Matthew Hodgson
1a54513cf1 federation doesn't work over ipv6 yet thanks to twisted 2016-07-26 10:09:37 +02:00
Matthew Hodgson
242c52d607 typo 2016-07-26 10:09:25 +02:00
Richard van der Hoff
012b4c1913 Implement updating devices
You can update the displayname of devices now.
2016-07-26 07:35:48 +01:00
Richard van der Hoff
436bffd15f Implement deleting devices 2016-07-26 07:35:48 +01:00
Mark Haines
1b3c3e6d68 Only update the events and event_json tables for rejected events 2016-07-25 18:44:30 +01:00
Richard van der Hoff
33d08e8433 Log when adding listeners 2016-07-25 17:22:15 +01:00
Mark Haines
8f7f4cb92b Don't add the events to forward extremities if the event is rejected 2016-07-25 17:13:37 +01:00
Mark Haines
2623cec874 Don't add rejections to the state_group, persist all rejections 2016-07-25 16:12:16 +01:00
David Baker
4fcdf7b4b2 Merge pull request #946 from matrix-org/dbkr/log_recaptcha_hostname
Log the hostname the reCAPTCHA was completed on
2016-07-25 16:10:39 +01:00
Mark Haines
955ef1f06c fix: defer.returnValue takes one argument 2016-07-25 16:04:45 +01:00
Richard van der Hoff
2ee4c9ee02 background updates: fix assert again 2016-07-25 16:01:46 +01:00
Richard van der Hoff
9dbd903f41 background updates: Fix assertion to do something 2016-07-25 14:05:23 +01:00
Richard van der Hoff
bf3de7b90b Merge pull request #945 from matrix-org/rav/background_reindex
Create index on user_ips in the background
2016-07-25 14:04:05 +01:00
Richard van der Hoff
e73ad8de3b Merge pull request #947 from matrix-org/rav/unittest_logging
Slightly saner logging for unittests
2016-07-25 13:44:46 +01:00
Richard van der Hoff
42f4feb2b7 PEP8 2016-07-25 12:25:06 +01:00
Richard van der Hoff
f16f0e169d Slightly saner logging for unittests
1. Give the handler used for logging in unit tests a formatter, so that the
output is slightly more meaningful

2. Log some synapse.storage stuff, because it's useful.
2016-07-25 12:12:47 +01:00
Richard van der Hoff
465117d7ca Fix background_update tests
A bit of a cleanup for background_updates, and make sure that the real
background updates have run before we start the unit tests, so that they don't
interfere with the tests.
2016-07-25 12:10:42 +01:00
David Baker
7ed58bb347 Use get to avoid KeyErrors 2016-07-22 17:18:50 +01:00
David Baker
dad2da7e54 Log the hostname the reCAPTCHA was completed on
This could be useful information to have in the logs. Also comment about how & why we don't verify the hostname.
2016-07-22 17:00:56 +01:00
Richard van der Hoff
363786845b PEP8 2016-07-22 13:21:07 +01:00
Richard van der Hoff
ec5717caf5 Create index on user_ips in the background
user_ips is kinda big, so really we want to add the index in the background
once we're running. Replace the schema delta with one which will do that.

I've done this in a way that's reasonably easy to reuse as there a few other
indexes I need, and I don't suppose they will be the last.
2016-07-22 13:16:39 +01:00
Erik Johnston
d26b660aa6 Cache getPeer 2016-07-21 17:38:51 +01:00
David Baker
68a92afcff Merge pull request #944 from matrix-org/rav/devices_returns_list
make /devices return a list
2016-07-21 17:00:26 +01:00
Richard van der Hoff
55abbe1850 make /devices return a list
Turns out I specced this to return a list of devices rather than a dict of them
2016-07-21 15:57:28 +01:00
David Baker
2c28e25bda Merge pull request #943 from matrix-org/rav/get_device_api
Implement GET /device/{deviceId}
2016-07-21 13:41:42 +01:00
David Baker
1e6e370b76 Merge pull request #942 from matrix-org/rav/fix_register_deviceid
Preserve device_id from first call to /register
2016-07-21 13:16:31 +01:00
Richard van der Hoff
1c3c202b96 Fix PEP8 errors 2016-07-21 13:15:15 +01:00
Richard van der Hoff
406f7aa0f6 Implement GET /device/{deviceId} 2016-07-21 12:00:29 +01:00
Richard van der Hoff
34f56b40fd Merge branch 'rav/get_devices_api' into develop 2016-07-21 11:59:30 +01:00
Richard van der Hoff
c445f5fec7 storage/client_ips: remove some dead code 2016-07-21 11:58:47 +01:00
David Baker
44adde498e Merge pull request #939 from matrix-org/rav/get_devices_api
GET /devices endpoint
2016-07-21 11:54:47 +01:00
Erik Johnston
cf94a78872 Set host not path 2016-07-21 11:45:53 +01:00
Richard van der Hoff
1a64dffb00 Preserve device_id from first call to /register
device_id may only be passed in the first call to /register, so make sure we
fish it out of the register `params` rather than the body of the final call.
2016-07-21 11:34:16 +01:00
Erik Johnston
081e5d55e6 Send the correct host header when fetching keys 2016-07-21 11:14:54 +01:00
Richard van der Hoff
40a1c96617 Fix PEP8 errors 2016-07-20 18:06:28 +01:00
Richard van der Hoff
7314bf4682 Merge branch 'develop' into rav/get_devices_api
(pick up PR #938 in the hope of fixing the UTs)
2016-07-20 17:40:00 +01:00
Richard van der Hoff
e9e3eaa67d Merge pull request #938 from matrix-org/rav/add_device_id_to_client_ips
Record device_id in client_ips
2016-07-20 17:38:45 +01:00
Erik Johnston
d36b1d849d Don't explode if we have no snapshots yet 2016-07-20 16:59:52 +01:00
David Baker
742056be0d Merge pull request #937 from matrix-org/rav/register_device_on_register
Register a device_id in the /v2/register flow.
2016-07-20 16:51:27 +01:00
Richard van der Hoff
bc8f265f0a GET /devices endpoint
implement a GET /devices endpoint which lists all of the user's devices.

It also returns the last IP where we saw that device, so there is some dancing
to fish that out of the user_ips table.
2016-07-20 16:42:32 +01:00
Richard van der Hoff
ec041b335e Record device_id in client_ips
Record the device_id when we add a client ip; it's somewhat redundant as we
could get it via the access_token, but it will make querying rather easier.
2016-07-20 16:41:03 +01:00
Richard van der Hoff
053e83dafb More doc-comments
Fix some more comments on some things
2016-07-20 16:40:28 +01:00
Richard van der Hoff
b97a1356b1 Register a device_id in the /v2/register flow.
This doesn't cover *all* of the registration flows, but it does cover the most
common ones: in particular: shared_secret registration, appservice
registration, and normal user/pass registration.

Pull device_id from the registration parameters. Register the device in the
devices table. Associate the device with the returned access and refresh
tokens. Profit.
2016-07-20 16:38:27 +01:00
Erik Johnston
b73dc0ef4d Merge pull request #936 from matrix-org/erikj/log_rss
Add metrics for psutil derived memory usage
2016-07-20 16:32:38 +01:00
Erik Johnston
499e3281e6 Make jenkins install deps on unit tests 2016-07-20 16:09:59 +01:00
Erik Johnston
66868119dc Add metrics for psutil derived memory usage 2016-07-20 16:00:21 +01:00
Erik Johnston
aba0b2a39b Merge pull request #935 from matrix-org/erikj/backfill_notifs
Don't notify pusher pool for backfilled events
2016-07-20 13:39:16 +01:00
Erik Johnston
57dca35692 Don't notify pusher pool for backfilled events 2016-07-20 13:25:06 +01:00
Richard van der Hoff
c68518dfbb Merge pull request #933 from matrix-org/rav/type_annotations
Type annotations
2016-07-20 12:26:32 +01:00
David Baker
e967bc86e7 Merge pull request #932 from matrix-org/rav/register_refactor
Further registration refactoring
2016-07-20 11:03:33 +01:00
Erik Johnston
1e2a7f18a1 Merge pull request #922 from matrix-org/erikj/file_api2
Feature: Add filter to /messages. Add 'contains_url' to filter.
2016-07-20 10:40:48 +01:00
Erik Johnston
f91faf09b3 Comment 2016-07-20 10:18:09 +01:00
Richard van der Hoff
4430b1ceb3 MANIFEST.in: Add *.pyi 2016-07-19 19:01:20 +01:00
Richard van der Hoff
3413f1e284 Type annotations
Add some type annotations to help PyCharm (in particular) to figure out the
types of a bunch of things.
2016-07-19 18:56:16 +01:00
Richard van der Hoff
40cbffb2d2 Further registration refactoring
* `RegistrationHandler.appservice_register` no longer issues an access token:
  instead it is left for the caller to do it. (There are two of these, one in
  `synapse/rest/client/v1/register.py`, which now simply calls
  `AuthHandler.issue_access_token`, and the other in
  `synapse/rest/client/v2_alpha/register.py`, which is covered below).

* In `synapse/rest/client/v2_alpha/register.py`, move the generation of
  access_tokens into `_create_registration_details`. This means that the normal
  flow no longer needs to call `AuthHandler.issue_access_token`; the
  shared-secret flow can tell `RegistrationHandler.register` not to generate a
  token; and the appservice flow continues to work despite the above change.
2016-07-19 18:46:19 +01:00
David Baker
b9e997f561 Merge pull request #931 from matrix-org/rav/refactor_register
rest/client/v2_alpha/register.py: Refactor flow somewhat.
2016-07-19 16:13:45 +01:00
Richard van der Hoff
9a7a77a22a Merge pull request #929 from matrix-org/rav/support_deviceid_in_login
Add device_id support to /login
2016-07-19 15:53:04 +01:00
Richard van der Hoff
8f6281ab0c Don't bind email unless threepid contains expected fields 2016-07-19 15:50:01 +01:00
Richard van der Hoff
0da0d0a29d rest/client/v2_alpha/register.py: Refactor flow somewhat.
This is meant to be an *almost* non-functional change, with the exception that
it fixes what looks a lot like a bug in that it only calls
`auth_handler.add_threepid` and `add_pusher` once instead of three times.

The idea is to move the generation of the `access_token` out of
`registration_handler.register`, because `access_token`s now require a
device_id, and we only want to generate a device_id once registration has been
successful.
2016-07-19 13:12:22 +01:00
Richard van der Hoff
022b9176fe schema fix
device_id should be text, not bigint.
2016-07-19 11:44:05 +01:00
Mark Haines
0c62c958fd Merge pull request #930 from matrix-org/markjh/handlers
Update docstring on Handlers.
2016-07-19 10:48:58 +01:00
Mark Haines
c41d52a042 Summary line 2016-07-19 10:28:27 +01:00
Mark Haines
7e554aac86 Update docstring on Handlers.
To indicate it is deprecated.
2016-07-19 10:20:58 +01:00
Richard van der Hoff
f863a52cea Add device_id support to /login
Add a 'devices' table to the storage, as well as a 'device_id' column to
refresh_tokens.

Allow the client to pass a device_id, and initial_device_display_name, to
/login. If login is successful, then register the device in the devices table
if it wasn't known already. If no device_id was supplied, make one up.

Associate the device_id with the access token and refresh token, so that we can
get at it again later. Ensure that the device_id is copied from the refresh
token to the access_token when the token is refreshed.
2016-07-18 16:39:44 +01:00
Richard van der Hoff
93efcb8526 Merge pull request #928 from matrix-org/rav/refactor_login
Refactor login flow
2016-07-18 16:12:35 +01:00
Richard van der Hoff
dcfd71aa4c Refactor login flow
Make sure that we have the canonical user_id *before* calling
get_login_tuple_for_user_id.

Replace login_with_password with a method which just validates the password,
and have the caller call get_login_tuple_for_user_id. This brings the password
flow into line with the other flows, and will give us a place to register the
device_id if necessary.
2016-07-18 15:23:54 +01:00
Erik Johnston
fca90b3445 Merge pull request #924 from matrix-org/erikj/purge_history
Fix /purge_history bug
2016-07-18 15:11:11 +01:00
Mark Haines
a292454aa1 Merge pull request #925 from matrix-org/markjh/auth_fix
Fix 500 ISE when sending alias event without a state_key
2016-07-18 15:04:47 +01:00
Erik Johnston
4f81edbd4f Merge pull request #927 from Half-Shot/develop
Fall back to 'username' if 'user' is not given for appservice registration.
2016-07-18 10:44:56 +01:00
Richard van der Hoff
6344db659f Fix a doc-comment
The `store` in a handler is a generic DataStore, not just an events.StateStore.
2016-07-18 09:48:10 +01:00
Will Hunt
511a52afc8 Use body.get to check for 'user' 2016-07-16 18:44:08 +01:00
Will Hunt
e885e2a623 Fall back to 'username' if 'user' is not given for appservice reg. 2016-07-16 18:33:48 +01:00
Mark Haines
d137e03231 Fix 500 ISE when sending alias event without a state_key 2016-07-15 18:58:25 +01:00
Erik Johnston
f52565de50 Fix /purge_history bug
This was caused by trying to insert duplicate backward extremeties
2016-07-15 14:23:15 +01:00
Erik Johnston
a2d288c6a9 Merge pull request #923 from matrix-org/erikj/purge_history
Various purge_history fixes
2016-07-15 13:23:29 +01:00
Erik Johnston
bd7c51921d Merge pull request #919 from matrix-org/erikj/auth_fix
Various auth.py fixes.
2016-07-15 11:38:33 +01:00
Erik Johnston
978fa53cc2 Pull out min stream_ordering from ex_outlier_stream 2016-07-15 10:22:30 +01:00
Erik Johnston
eec9609e96 event_backwards_extremeties may not be empty 2016-07-15 10:22:09 +01:00
Erik Johnston
9e1b43bcbf Comment 2016-07-15 09:29:54 +01:00
Erik Johnston
a3036ac37e Merge pull request #921 from matrix-org/erikj/account_deactivate
Feature: Add an /account/deactivate endpoint
2016-07-14 17:25:15 +01:00
Erik Johnston
ebdafd8114 Check sender signed event 2016-07-14 17:03:24 +01:00
Erik Johnston
a98d215204 Add filter param to /messages API 2016-07-14 16:30:56 +01:00
Erik Johnston
d554ca5e1d Add support for filters in paginate_room_events 2016-07-14 15:59:04 +01:00
Erik Johnston
209e04fa11 Merge pull request #918 from negzi/bugfix_for_token_expiry
Bug fix: expire invalid access tokens
2016-07-14 15:51:52 +01:00
Erik Johnston
e5142f65a6 Add 'contains_url' to filter 2016-07-14 15:35:48 +01:00
Erik Johnston
b64aa6d687 Add sender and contains_url field to events table 2016-07-14 15:35:43 +01:00
Erik Johnston
848d3bf2e1 Add hs object 2016-07-14 10:25:52 +01:00
Erik Johnston
b55c770271 Only accept password auth 2016-07-14 10:00:38 +01:00
Erik Johnston
d543b72562 Add an /account/deactivate endpoint 2016-07-14 09:56:53 +01:00
Negar Fazeli
0136a522b1 Bug fix: expire invalid access tokens 2016-07-13 15:00:37 +02:00
Erik Johnston
2cb758ac75 Check if alias event's state_key matches sender's domain 2016-07-13 13:12:25 +01:00
Erik Johnston
560c71c735 Check creation event's room_id domain matches sender's 2016-07-13 13:07:19 +01:00
David Baker
a37ee2293c Merge pull request #915 from matrix-org/dbkr/more_requesttokens
Add requestToken endpoints
2016-07-13 11:51:46 +01:00
David Baker
c55ad2e375 be more pythonic 2016-07-12 14:15:10 +01:00
David Baker
aaa9d9f0e1 on_OPTIONS isn't neccessary 2016-07-12 14:13:14 +01:00
David Baker
75fa7f6b3c Remove other debug logging 2016-07-12 14:08:57 +01:00
David Baker
a5db0026ed Separate out requestTokens to separate handlers 2016-07-11 09:57:07 +01:00
David Baker
9c491366c5 Oops, remove debug logging 2016-07-11 09:07:40 +01:00
David Baker
385aec4010 Implement https://github.com/matrix-org/matrix-doc/pull/346/files 2016-07-08 17:42:48 +01:00
79 changed files with 2643 additions and 897 deletions

View File

@@ -1,3 +1,67 @@
Changes in synapse v0.17.0-rc1 (2016-07-28)
===========================================
This release changes the LDAP configuration format in a backwards incompatible
way, see PR #843 for details.
This release contains significant security bug fixes regarding authenticating
events received over federation. Please upgrade.
Features:
* Add purge_media_cache admin API (PR #902)
* Add deactivate account admin API (PR #903)
* Add optional pepper to password hashing (PR #907, #910 by KentShikama)
* Add an admin option to shared secret registration (breaks backwards compat)
(PR #909)
* Add purge local room history API (PR #911, #923, #924)
* Add requestToken endpoints (PR #915)
* Add an /account/deactivate endpoint (PR #921)
* Add filter param to /messages. Add 'contains_url' to filter. (PR #922)
* Add device_id support to /login (PR #929)
* Add device_id support to /v2/register flow. (PR #937, #942)
* Add GET /devices endpoint (PR #939, #944)
* Add GET /device/{deviceId} (PR #943)
* Add update and delete APIs for devices (PR #949)
Changes:
* Rewrite LDAP Authentication against ldap3 (PR #843 by mweinelt)
* Linearize some federation endpoints based on (origin, room_id) (PR #879)
* Remove the legacy v0 content upload API. (PR #888)
* Use similar naming we use in email notifs for push (PR #894)
* Optionally include password hash in createUser endpoint (PR #905 by
KentShikama)
* Use a query that postgresql optimises better for get_events_around (PR #906)
* Fall back to 'username' if 'user' is not given for appservice registration.
(PR #927 by Half-Shot)
* Add metrics for psutil derived memory usage (PR #936)
* Record device_id in client_ips (PR #938)
* Send the correct host header when fetching keys (PR #941)
* Log the hostname the reCAPTCHA was completed on (PR #946)
* Make the device id on e2e key upload optional (PR #956)
* Add r0.2.0 to the "supported versions" list (PR #960)
* Don't include name of room for invites in push (PR #961)
Bug fixes:
* Fix substitution failure in mail template (PR #887)
* Put most recent 20 messages in email notif (PR #892)
* Ensure that the guest user is in the database when upgrading accounts
(PR #914)
* Fix various edge cases in auth handling (PR #919)
* Fix 500 ISE when sending alias event without a state_key (PR #925)
* Fix bug where we stored rejections in the state_group, persist all
rejections (PR #948)
* Fix lack of check of if the user is banned when handling 3pid invites
(PR #952)
* Fix a couple of bugs in the transaction and keyring code (PR #954, #955)
Changes in synapse v0.16.1-r1 (2016-07-08)
==========================================

View File

@@ -14,6 +14,7 @@ recursive-include docs *
recursive-include res *
recursive-include scripts *
recursive-include scripts-dev *
recursive-include synapse *.pyi
recursive-include tests *.py
recursive-include synapse/static *.css

View File

@@ -445,7 +445,7 @@ You have two choices here, which will influence the form of your Matrix user
IDs:
1) Use the machine's own hostname as available on public DNS in the form of
its A or AAAA records. This is easier to set up initially, perhaps for
its A records. This is easier to set up initially, perhaps for
testing, but lacks the flexibility of SRV.
2) Set up a SRV record for your domain name. This requires you create a SRV

View File

@@ -22,4 +22,8 @@ export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished w
rm .coverage* || echo "No coverage files to remove"
tox --notest -e py27
TOX_BIN=$WORKSPACE/.tox/py27/bin
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
tox -e py27

View File

@@ -16,7 +16,5 @@ ignore =
[flake8]
max-line-length = 90
ignore = W503 ; W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
[pep8]
max-line-length = 90
# W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
ignore = W503

View File

@@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
__version__ = "0.16.1-r1"
__version__ = "0.17.0-rc1"

View File

@@ -13,22 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import pymacaroons
from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json, SignatureVerifyException
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
from synapse.types import Requester, UserID, get_domain_from_id
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.metrics import Measure
from unpaddedbase64 import decode_base64
import logging
import pymacaroons
import synapse.types
from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
from synapse.types import UserID, get_domain_from_id
from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
@@ -63,7 +63,7 @@ class Auth(object):
"user_id = ",
])
def check(self, event, auth_events):
def check(self, event, auth_events, do_sig_check=True):
""" Checks if this event is correctly authed.
Args:
@@ -79,6 +79,13 @@ class Auth(object):
if not hasattr(event, "room_id"):
raise AuthError(500, "Event has no room_id: %s" % event)
sender_domain = get_domain_from_id(event.sender)
# Check the sender's domain has signed the event
if do_sig_check and not event.signatures.get(sender_domain):
raise AuthError(403, "Event not signed by sending server")
if auth_events is None:
# Oh, we don't know what the state of the room was, so we
# are trusting that this is allowed (at least for now)
@@ -86,6 +93,12 @@ class Auth(object):
return True
if event.type == EventTypes.Create:
room_id_domain = get_domain_from_id(event.room_id)
if room_id_domain != sender_domain:
raise AuthError(
403,
"Creation event's room_id domain does not match sender's"
)
# FIXME
return True
@@ -108,6 +121,22 @@ class Auth(object):
# FIXME: Temp hack
if event.type == EventTypes.Aliases:
if not event.is_state():
raise AuthError(
403,
"Alias event must be a state event",
)
if not event.state_key:
raise AuthError(
403,
"Alias event must have non-empty state_key"
)
sender_domain = get_domain_from_id(event.sender)
if event.state_key != sender_domain:
raise AuthError(
403,
"Alias event's state_key does not match sender's domain"
)
return True
logger.debug(
@@ -347,6 +376,10 @@ class Auth(object):
if Membership.INVITE == membership and "third_party_invite" in event.content:
if not self._verify_third_party_invite(event, auth_events):
raise AuthError(403, "You are not invited to this room.")
if target_banned:
raise AuthError(
403, "%s is banned from the room" % (target_user_id,)
)
return True
if Membership.JOIN != membership:
@@ -537,9 +570,7 @@ class Auth(object):
Args:
request - An HTTP request with an access_token query parameter.
Returns:
tuple of:
UserID (str)
Access token ID (str)
defer.Deferred: resolves to a ``synapse.types.Requester`` object
Raises:
AuthError if no user by that token exists or the token is invalid.
"""
@@ -548,9 +579,7 @@ class Auth(object):
user_id = yield self._get_appservice_user_id(request.args)
if user_id:
request.authenticated_entity = user_id
defer.returnValue(
Requester(UserID.from_string(user_id), "", False)
)
defer.returnValue(synapse.types.create_requester(user_id))
access_token = request.args["access_token"][0]
user_info = yield self.get_user_by_access_token(access_token, rights)
@@ -558,6 +587,10 @@ class Auth(object):
token_id = user_info["token_id"]
is_guest = user_info["is_guest"]
# device_id may not be present if get_user_by_access_token has been
# stubbed out.
device_id = user_info.get("device_id")
ip_addr = self.hs.get_ip_from_request(request)
user_agent = request.requestHeaders.getRawHeaders(
"User-Agent",
@@ -569,7 +602,8 @@ class Auth(object):
user=user,
access_token=access_token,
ip=ip_addr,
user_agent=user_agent
user_agent=user_agent,
device_id=device_id,
)
if is_guest and not allow_guest:
@@ -579,7 +613,8 @@ class Auth(object):
request.authenticated_entity = user.to_string()
defer.returnValue(Requester(user, token_id, is_guest))
defer.returnValue(synapse.types.create_requester(
user, token_id, is_guest, device_id))
except KeyError:
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
@@ -629,7 +664,10 @@ class Auth(object):
except AuthError:
# TODO(daniel): Remove this fallback when all existing access tokens
# have been re-issued as macaroons.
if self.hs.config.expire_access_token:
raise
ret = yield self._look_up_user_by_access_token(token)
defer.returnValue(ret)
@defer.inlineCallbacks
@@ -664,6 +702,7 @@ class Auth(object):
"user": user,
"is_guest": True,
"token_id": None,
"device_id": None,
}
elif rights == "delete_pusher":
# We don't store these tokens in the database
@@ -671,13 +710,20 @@ class Auth(object):
"user": user,
"is_guest": False,
"token_id": None,
"device_id": None,
}
else:
# This codepath exists so that we can actually return a
# token ID, because we use token IDs in place of device
# identifiers throughout the codebase.
# TODO(daniel): Remove this fallback when device IDs are
# properly implemented.
# This codepath exists for several reasons:
# * so that we can actually return a token ID, which is used
# in some parts of the schema (where we probably ought to
# use device IDs instead)
# * the only way we currently have to invalidate an
# access_token is by removing it from the database, so we
# have to check here that it is still in the db
# * some attributes (notably device_id) aren't stored in the
# macaroon. They probably should be.
# TODO: build the dictionary from the macaroon once the
# above are fixed
ret = yield self._look_up_user_by_access_token(macaroon_str)
if ret["user"] != user:
logger.error(
@@ -751,10 +797,14 @@ class Auth(object):
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN
)
# we use ret.get() below because *lots* of unit tests stub out
# get_user_by_access_token in a way where it only returns a couple of
# the fields.
user_info = {
"user": UserID.from_string(ret.get("name")),
"token_id": ret.get("token_id", None),
"is_guest": False,
"device_id": ret.get("device_id"),
}
defer.returnValue(user_info)

View File

@@ -43,6 +43,7 @@ class Codes(object):
EXCLUSIVE = "M_EXCLUSIVE"
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
THREEPID_IN_USE = "M_THREEPID_IN_USE"
THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND"
INVALID_USERNAME = "M_INVALID_USERNAME"
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"

View File

@@ -191,6 +191,17 @@ class Filter(object):
def __init__(self, filter_json):
self.filter_json = filter_json
self.types = self.filter_json.get("types", None)
self.not_types = self.filter_json.get("not_types", [])
self.rooms = self.filter_json.get("rooms", None)
self.not_rooms = self.filter_json.get("not_rooms", [])
self.senders = self.filter_json.get("senders", None)
self.not_senders = self.filter_json.get("not_senders", [])
self.contains_url = self.filter_json.get("contains_url", None)
def check(self, event):
"""Checks whether the filter matches the given event.
@@ -209,9 +220,10 @@ class Filter(object):
event.get("room_id", None),
sender,
event.get("type", None),
"url" in event.get("content", {})
)
def check_fields(self, room_id, sender, event_type):
def check_fields(self, room_id, sender, event_type, contains_url):
"""Checks whether the filter matches the given event fields.
Returns:
@@ -225,15 +237,20 @@ class Filter(object):
for name, match_func in literal_keys.items():
not_name = "not_%s" % (name,)
disallowed_values = self.filter_json.get(not_name, [])
disallowed_values = getattr(self, not_name)
if any(map(match_func, disallowed_values)):
return False
allowed_values = self.filter_json.get(name, None)
allowed_values = getattr(self, name)
if allowed_values is not None:
if not any(map(match_func, allowed_values)):
return False
contains_url_filter = self.filter_json.get("contains_url")
if contains_url_filter is not None:
if contains_url_filter != contains_url:
return False
return True
def filter_rooms(self, room_ids):

View File

@@ -16,13 +16,11 @@
import sys
sys.dont_write_bytecode = True
from synapse.python_dependencies import (
check_requirements, MissingRequirementError
) # NOQA
from synapse import python_dependencies # noqa: E402
try:
check_requirements()
except MissingRequirementError as e:
python_dependencies.check_requirements()
except python_dependencies.MissingRequirementError as e:
message = "\n".join([
"Missing Requirement: %s" % (e.message,),
"To install run:",

View File

@@ -51,6 +51,7 @@ from synapse.api.urls import (
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
from synapse.util.logcontext import LoggingContext
from synapse.metrics import register_memory_metrics
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX
from synapse.federation.transport.server import TransportLayerServer
@@ -335,6 +336,8 @@ def setup(config_options):
hs.get_datastore().start_doing_background_updates()
hs.get_replication_layer().start_get_pdu_cache()
register_memory_metrics(hs)
reactor.callWhenRunning(start)
return hs

View File

@@ -77,10 +77,12 @@ class SynapseKeyClientProtocol(HTTPClient):
def __init__(self):
self.remote_key = defer.Deferred()
self.host = None
self._peer = None
def connectionMade(self):
self.host = self.transport.getHost()
logger.debug("Connected to %s", self.host)
self._peer = self.transport.getPeer()
logger.debug("Connected to %s", self._peer)
self.sendCommand(b"GET", self.path)
if self.host:
self.sendHeader(b"Host", self.host)
@@ -124,7 +126,10 @@ class SynapseKeyClientProtocol(HTTPClient):
self.timer.cancel()
def on_timeout(self):
logger.debug("Timeout waiting for response from %s", self.host)
logger.debug(
"Timeout waiting for response from %s: %s",
self.host, self._peer,
)
self.errback(IOError("Timeout waiting for response"))
self.transport.abortConnection()
@@ -133,4 +138,5 @@ class SynapseKeyClientFactory(Factory):
def protocol(self):
protocol = SynapseKeyClientProtocol()
protocol.path = self.path
protocol.host = self.host
return protocol

View File

@@ -44,7 +44,21 @@ import logging
logger = logging.getLogger(__name__)
KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids"))
VerifyKeyRequest = namedtuple("VerifyRequest", (
"server_name", "key_ids", "json_object", "deferred"
))
"""
A request for a verify key to verify a JSON object.
Attributes:
server_name(str): The name of the server to verify against.
key_ids(set(str)): The set of key_ids to that could be used to verify the
JSON object
json_object(dict): The JSON object to verify.
deferred(twisted.internet.defer.Deferred):
A deferred (server_name, key_id, verify_key) tuple that resolves when
a verify key has been fetched
"""
class Keyring(object):
@@ -74,39 +88,32 @@ class Keyring(object):
list of deferreds indicating success or failure to verify each
json object's signature for the given server_name.
"""
group_id_to_json = {}
group_id_to_group = {}
group_ids = []
next_group_id = 0
deferreds = {}
verify_requests = []
for server_name, json_object in server_and_json:
logger.debug("Verifying for %s", server_name)
group_id = next_group_id
next_group_id += 1
group_ids.append(group_id)
key_ids = signature_ids(json_object, server_name)
if not key_ids:
deferreds[group_id] = defer.fail(SynapseError(
deferred = defer.fail(SynapseError(
400,
"Not signed with a supported algorithm",
Codes.UNAUTHORIZED,
))
else:
deferreds[group_id] = defer.Deferred()
deferred = defer.Deferred()
group = KeyGroup(server_name, group_id, key_ids)
verify_request = VerifyKeyRequest(
server_name, key_ids, json_object, deferred
)
group_id_to_group[group_id] = group
group_id_to_json[group_id] = json_object
verify_requests.append(verify_request)
@defer.inlineCallbacks
def handle_key_deferred(group, deferred):
server_name = group.server_name
def handle_key_deferred(verify_request):
server_name = verify_request.server_name
try:
_, _, key_id, verify_key = yield deferred
_, key_id, verify_key = yield verify_request.deferred
except IOError as e:
logger.warn(
"Got IOError when downloading keys for %s: %s %s",
@@ -128,7 +135,7 @@ class Keyring(object):
Codes.UNAUTHORIZED,
)
json_object = group_id_to_json[group.group_id]
json_object = verify_request.json_object
try:
verify_signed_json(json_object, server_name, verify_key)
@@ -157,36 +164,34 @@ class Keyring(object):
# Actually start fetching keys.
wait_on_deferred.addBoth(
lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
lambda _: self.get_server_verify_keys(verify_requests)
)
# When we've finished fetching all the keys for a given server_name,
# resolve the deferred passed to `wait_for_previous_lookups` so that
# any lookups waiting will proceed.
server_to_gids = {}
server_to_request_ids = {}
def remove_deferreds(res, server_name, group_id):
server_to_gids[server_name].discard(group_id)
if not server_to_gids[server_name]:
def remove_deferreds(res, server_name, verify_request):
request_id = id(verify_request)
server_to_request_ids[server_name].discard(request_id)
if not server_to_request_ids[server_name]:
d = server_to_deferred.pop(server_name, None)
if d:
d.callback(None)
return res
for g_id, deferred in deferreds.items():
server_name = group_id_to_group[g_id].server_name
server_to_gids.setdefault(server_name, set()).add(g_id)
deferred.addBoth(remove_deferreds, server_name, g_id)
for verify_request in verify_requests:
server_name = verify_request.server_name
request_id = id(verify_request)
server_to_request_ids.setdefault(server_name, set()).add(request_id)
deferred.addBoth(remove_deferreds, server_name, verify_request)
# Pass those keys to handle_key_deferred so that the json object
# signatures can be verified
return [
preserve_context_over_fn(
handle_key_deferred,
group_id_to_group[g_id],
deferreds[g_id],
)
for g_id in group_ids
preserve_context_over_fn(handle_key_deferred, verify_request)
for verify_request in verify_requests
]
@defer.inlineCallbacks
@@ -220,7 +225,7 @@ class Keyring(object):
d.addBoth(rm, server_name)
def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
def get_server_verify_keys(self, verify_requests):
"""Takes a dict of KeyGroups and tries to find at least one key for
each group.
"""
@@ -237,62 +242,64 @@ class Keyring(object):
merged_results = {}
missing_keys = {}
for group in group_id_to_group.values():
missing_keys.setdefault(group.server_name, set()).update(
group.key_ids
for verify_request in verify_requests:
missing_keys.setdefault(verify_request.server_name, set()).update(
verify_request.key_ids
)
for fn in key_fetch_fns:
results = yield fn(missing_keys.items())
merged_results.update(results)
# We now need to figure out which groups we have keys for
# and which we don't
missing_groups = {}
for group in group_id_to_group.values():
for key_id in group.key_ids:
if key_id in merged_results[group.server_name]:
# We now need to figure out which verify requests we have keys
# for and which we don't
missing_keys = {}
requests_missing_keys = []
for verify_request in verify_requests:
server_name = verify_request.server_name
result_keys = merged_results[server_name]
if verify_request.deferred.called:
# We've already called this deferred, which probably
# means that we've already found a key for it.
continue
for key_id in verify_request.key_ids:
if key_id in result_keys:
with PreserveLoggingContext():
group_id_to_deferred[group.group_id].callback((
group.group_id,
group.server_name,
verify_request.deferred.callback((
server_name,
key_id,
merged_results[group.server_name][key_id],
result_keys[key_id],
))
break
else:
missing_groups.setdefault(
group.server_name, []
).append(group)
# The else block is only reached if the loop above
# doesn't break.
missing_keys.setdefault(server_name, set()).update(
verify_request.key_ids
)
requests_missing_keys.append(verify_request)
if not missing_groups:
if not missing_keys:
break
missing_keys = {
server_name: set(
key_id for group in groups for key_id in group.key_ids
)
for server_name, groups in missing_groups.items()
}
for group in missing_groups.values():
group_id_to_deferred[group.group_id].errback(SynapseError(
for verify_request in requests_missing_keys.values():
verify_request.deferred.errback(SynapseError(
401,
"No key for %s with id %s" % (
group.server_name, group.key_ids,
verify_request.server_name, verify_request.key_ids,
),
Codes.UNAUTHORIZED,
))
def on_err(err):
for deferred in group_id_to_deferred.values():
if not deferred.called:
deferred.errback(err)
for verify_request in verify_requests:
if not verify_request.deferred.called:
verify_request.deferred.errback(err)
do_iterations().addErrback(on_err)
return group_id_to_deferred
@defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids):
res = yield defer.gatherResults(
@@ -447,7 +454,7 @@ class Keyring(object):
)
processed_response = yield self.process_v2_response(
perspective_name, response
perspective_name, response, only_from_server=False
)
for server_name, response_keys in processed_response.items():
@@ -527,7 +534,7 @@ class Keyring(object):
@defer.inlineCallbacks
def process_v2_response(self, from_server, response_json,
requested_ids=[]):
requested_ids=[], only_from_server=True):
time_now_ms = self.clock.time_msec()
response_keys = {}
verify_keys = {}
@@ -551,6 +558,13 @@ class Keyring(object):
results = {}
server_name = response_json["server_name"]
if only_from_server:
if server_name != from_server:
raise ValueError(
"Expected a response for server %r not %r" % (
from_server, server_name
)
)
for key_id in response_json["signatures"].get(server_name, {}):
if key_id not in response_json["verify_keys"]:
raise ValueError(

View File

@@ -31,10 +31,21 @@ from .search import SearchHandler
class Handlers(object):
""" A collection of all the event handlers.
""" Deprecated. A collection of handlers.
There's no need to lazily create these; we'll just make them all eagerly
at construction time.
At some point most of the classes whose name ended "Handler" were
accessed through this class.
However this makes it painful to unit test the handlers and to run cut
down versions of synapse that only use specific handlers because using a
single handler required creating all of the handlers. So some of the
handlers have been lifted out of the Handlers object and are now accessed
directly through the homeserver object itself.
Any new handlers should follow the new pattern of being accessed through
the homeserver object and should not be added to the Handlers object.
The remaining handlers should be moved out of the handlers object.
"""
def __init__(self, hs):

View File

@@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.api.errors import LimitExceededError
import synapse.types
from synapse.api.constants import Membership, EventTypes
from synapse.types import UserID, Requester
import logging
from synapse.api.errors import LimitExceededError
from synapse.types import UserID
logger = logging.getLogger(__name__)
@@ -31,11 +31,15 @@ class BaseHandler(object):
Common base class for the event handlers.
Attributes:
store (synapse.storage.events.StateStore):
store (synapse.storage.DataStore):
state_handler (synapse.state.StateHandler):
"""
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer):
"""
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
@@ -120,7 +124,8 @@ class BaseHandler(object):
# and having homeservers have their own users leave keeps more
# of that decision-making and control local to the guest-having
# homeserver.
requester = Requester(target_user, "", True)
requester = synapse.types.create_requester(
target_user, is_guest=True)
handler = self.hs.get_handlers().room_member_handler
yield handler.update_membership(
requester,

View File

@@ -45,6 +45,10 @@ class AuthHandler(BaseHandler):
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer):
"""
super(AuthHandler, self).__init__(hs)
self.checkers = {
LoginType.PASSWORD: self._check_password_auth,
@@ -73,6 +77,7 @@ class AuthHandler(BaseHandler):
self.ldap_bind_password = hs.config.ldap_bind_password
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.device_handler = hs.get_device_handler()
@defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip):
@@ -230,7 +235,6 @@ class AuthHandler(BaseHandler):
sess = self._get_session_info(session_id)
return sess.setdefault('serverdict', {}).get(key, default)
@defer.inlineCallbacks
def _check_password_auth(self, authdict, _):
if "user" not in authdict or "password" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM)
@@ -240,11 +244,7 @@ class AuthHandler(BaseHandler):
if not user_id.startswith('@'):
user_id = UserID.create(user_id, self.hs.hostname).to_string()
if not (yield self._check_password(user_id, password)):
logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
defer.returnValue(user_id)
return self._check_password(user_id, password)
@defer.inlineCallbacks
def _check_recaptcha(self, authdict, clientip):
@@ -280,8 +280,17 @@ class AuthHandler(BaseHandler):
data = pde.response
resp_body = simplejson.loads(data)
if 'success' in resp_body and resp_body['success']:
defer.returnValue(True)
if 'success' in resp_body:
# Note that we do NOT check the hostname here: we explicitly
# intend the CAPTCHA to be presented by whatever client the
# user is using, we just care that they have completed a CAPTCHA.
logger.info(
"%s reCAPTCHA from hostname %s",
"Successful" if resp_body['success'] else "Failed",
resp_body.get('hostname')
)
if resp_body['success']:
defer.returnValue(True)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
@defer.inlineCallbacks
@@ -348,67 +357,84 @@ class AuthHandler(BaseHandler):
return self.sessions[session_id]
@defer.inlineCallbacks
def login_with_password(self, user_id, password):
def validate_password_login(self, user_id, password):
"""
Authenticates the user with their username and password.
Used only by the v1 login API.
Args:
user_id (str): User ID
user_id (str): complete @user:id
password (str): Password
Returns:
A tuple of:
The user's ID.
The access token for the user's session.
The refresh token for the user's session.
defer.Deferred: (str) canonical user id
Raises:
StoreError if there was a problem storing the token.
StoreError if there was a problem accessing the database
LoginError if there was an authentication problem.
"""
if not (yield self._check_password(user_id, password)):
logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
logger.info("Logging in user %s", user_id)
access_token = yield self.issue_access_token(user_id)
refresh_token = yield self.issue_refresh_token(user_id)
defer.returnValue((user_id, access_token, refresh_token))
return self._check_password(user_id, password)
@defer.inlineCallbacks
def get_login_tuple_for_user_id(self, user_id):
def get_login_tuple_for_user_id(self, user_id, device_id=None,
initial_display_name=None):
"""
Gets login tuple for the user with the given user ID.
Creates a new access/refresh token for the user.
The user is assumed to have been authenticated by some other
machanism (e.g. CAS)
machanism (e.g. CAS), and the user_id converted to the canonical case.
The device will be recorded in the table if it is not there already.
Args:
user_id (str): User ID
user_id (str): canonical User ID
device_id (str|None): the device ID to associate with the tokens.
None to leave the tokens unassociated with a device (deprecated:
we should always have a device ID)
initial_display_name (str): display name to associate with the
device if it needs re-registering
Returns:
A tuple of:
The user's ID.
The access token for the user's session.
The refresh token for the user's session.
Raises:
StoreError if there was a problem storing the token.
LoginError if there was an authentication problem.
"""
user_id, ignored = yield self._find_user_id_and_pwd_hash(user_id)
logger.info("Logging in user %s on device %s", user_id, device_id)
access_token = yield self.issue_access_token(user_id, device_id)
refresh_token = yield self.issue_refresh_token(user_id, device_id)
logger.info("Logging in user %s", user_id)
access_token = yield self.issue_access_token(user_id)
refresh_token = yield self.issue_refresh_token(user_id)
defer.returnValue((user_id, access_token, refresh_token))
# the device *should* have been registered before we got here; however,
# it's possible we raced against a DELETE operation. The thing we
# really don't want is active access_tokens without a record of the
# device, so we double-check it here.
if device_id is not None:
yield self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
defer.returnValue((access_token, refresh_token))
@defer.inlineCallbacks
def does_user_exist(self, user_id):
def check_user_exists(self, user_id):
"""
Checks to see if a user with the given id exists. Will check case
insensitively, but return None if there are multiple inexact matches.
Args:
(str) user_id: complete @user:id
Returns:
defer.Deferred: (str) canonical_user_id, or None if zero or
multiple matches
"""
try:
yield self._find_user_id_and_pwd_hash(user_id)
defer.returnValue(True)
res = yield self._find_user_id_and_pwd_hash(user_id)
defer.returnValue(res[0])
except LoginError:
defer.returnValue(False)
defer.returnValue(None)
@defer.inlineCallbacks
def _find_user_id_and_pwd_hash(self, user_id):
@@ -438,27 +464,45 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks
def _check_password(self, user_id, password):
"""
"""Authenticate a user against the LDAP and local databases.
user_id is checked case insensitively against the local database, but
will throw if there are multiple inexact matches.
Args:
user_id (str): complete @user:id
Returns:
True if the user_id successfully authenticated
(str) the canonical_user_id
Raises:
LoginError if the password was incorrect
"""
valid_ldap = yield self._check_ldap_password(user_id, password)
if valid_ldap:
defer.returnValue(True)
defer.returnValue(user_id)
valid_local_password = yield self._check_local_password(user_id, password)
if valid_local_password:
defer.returnValue(True)
defer.returnValue(False)
result = yield self._check_local_password(user_id, password)
defer.returnValue(result)
@defer.inlineCallbacks
def _check_local_password(self, user_id, password):
try:
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
defer.returnValue(self.validate_hash(password, password_hash))
except LoginError:
defer.returnValue(False)
"""Authenticate a user against the local password database.
user_id is checked case insensitively, but will throw if there are
multiple inexact matches.
Args:
user_id (str): complete @user:id
Returns:
(str) the canonical_user_id
Raises:
LoginError if the password was incorrect
"""
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
result = self.validate_hash(password, password_hash)
if not result:
logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
defer.returnValue(user_id)
@defer.inlineCallbacks
def _check_ldap_password(self, user_id, password):
@@ -570,7 +614,7 @@ class AuthHandler(BaseHandler):
)
# check for existing account, if none exists, create one
if not (yield self.does_user_exist(user_id)):
if not (yield self.check_user_exists(user_id)):
# query user metadata for account creation
query = "({prop}={value})".format(
prop=self.ldap_attributes['uid'],
@@ -626,23 +670,26 @@ class AuthHandler(BaseHandler):
defer.returnValue(False)
@defer.inlineCallbacks
def issue_access_token(self, user_id):
def issue_access_token(self, user_id, device_id=None):
access_token = self.generate_access_token(user_id)
yield self.store.add_access_token_to_user(user_id, access_token)
yield self.store.add_access_token_to_user(user_id, access_token,
device_id)
defer.returnValue(access_token)
@defer.inlineCallbacks
def issue_refresh_token(self, user_id):
def issue_refresh_token(self, user_id, device_id=None):
refresh_token = self.generate_refresh_token(user_id)
yield self.store.add_refresh_token_to_user(user_id, refresh_token)
yield self.store.add_refresh_token_to_user(user_id, refresh_token,
device_id)
defer.returnValue(refresh_token)
def generate_access_token(self, user_id, extra_caveats=None):
def generate_access_token(self, user_id, extra_caveats=None,
duration_in_ms=(60 * 60 * 1000)):
extra_caveats = extra_caveats or []
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = access")
now = self.hs.get_clock().time_msec()
expiry = now + (60 * 60 * 1000)
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
for caveat in extra_caveats:
macaroon.add_first_party_caveat(caveat)

181
synapse/handlers/device.py Normal file
View File

@@ -0,0 +1,181 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api import errors
from synapse.util import stringutils
from twisted.internet import defer
from ._base import BaseHandler
import logging
logger = logging.getLogger(__name__)
class DeviceHandler(BaseHandler):
def __init__(self, hs):
super(DeviceHandler, self).__init__(hs)
@defer.inlineCallbacks
def check_device_registered(self, user_id, device_id,
initial_device_display_name):
"""
If the given device has not been registered, register it with the
supplied display name.
If no device_id is supplied, we make one up.
Args:
user_id (str): @user:id
device_id (str | None): device id supplied by client
initial_device_display_name (str | None): device display name from
client
Returns:
str: device id (generated if none was supplied)
"""
if device_id is not None:
yield self.store.store_device(
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
ignore_if_known=True,
)
defer.returnValue(device_id)
# if the device id is not specified, we'll autogen one, but loop a few
# times in case of a clash.
attempts = 0
while attempts < 5:
try:
device_id = stringutils.random_string_with_symbols(16)
yield self.store.store_device(
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
ignore_if_known=False,
)
defer.returnValue(device_id)
except errors.StoreError:
attempts += 1
raise errors.StoreError(500, "Couldn't generate a device ID.")
@defer.inlineCallbacks
def get_devices_by_user(self, user_id):
"""
Retrieve the given user's devices
Args:
user_id (str):
Returns:
defer.Deferred: list[dict[str, X]]: info on each device
"""
device_map = yield self.store.get_devices_by_user(user_id)
ips = yield self.store.get_last_client_ip_by_device(
devices=((user_id, device_id) for device_id in device_map.keys())
)
devices = device_map.values()
for device in devices:
_update_device_from_client_ips(device, ips)
defer.returnValue(devices)
@defer.inlineCallbacks
def get_device(self, user_id, device_id):
""" Retrieve the given device
Args:
user_id (str):
device_id (str):
Returns:
defer.Deferred: dict[str, X]: info on the device
Raises:
errors.NotFoundError: if the device was not found
"""
try:
device = yield self.store.get_device(user_id, device_id)
except errors.StoreError:
raise errors.NotFoundError
ips = yield self.store.get_last_client_ip_by_device(
devices=((user_id, device_id),)
)
_update_device_from_client_ips(device, ips)
defer.returnValue(device)
@defer.inlineCallbacks
def delete_device(self, user_id, device_id):
""" Delete the given device
Args:
user_id (str):
device_id (str):
Returns:
defer.Deferred:
"""
try:
yield self.store.delete_device(user_id, device_id)
except errors.StoreError, e:
if e.code == 404:
# no match
pass
else:
raise
yield self.store.user_delete_access_tokens(
user_id, device_id=device_id,
delete_refresh_tokens=True,
)
yield self.store.delete_e2e_keys_by_device(
user_id=user_id, device_id=device_id
)
@defer.inlineCallbacks
def update_device(self, user_id, device_id, content):
""" Update the given device
Args:
user_id (str):
device_id (str):
content (dict): body of update request
Returns:
defer.Deferred:
"""
try:
yield self.store.update_device(
user_id,
device_id,
new_display_name=content.get("display_name")
)
except errors.StoreError, e:
if e.code == 404:
raise errors.NotFoundError()
else:
raise
def _update_device_from_client_ips(device, client_ips):
ip = client_ips.get((device["user_id"], device["device_id"]), {})
device.update({
"last_seen_ts": ip.get("last_seen"),
"last_seen_ip": ip.get("ip"),
})

View File

@@ -688,7 +688,9 @@ class FederationHandler(BaseHandler):
logger.warn("Failed to create join %r because %s", event, e)
raise e
self.auth.check(event, auth_events=context.current_state)
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_join_request`
self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
defer.returnValue(event)
@@ -918,7 +920,9 @@ class FederationHandler(BaseHandler):
)
try:
self.auth.check(event, auth_events=context.current_state)
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_leave_request`
self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
except AuthError as e:
logger.warn("Failed to create new leave %r because %s", event, e)
raise e
@@ -1114,11 +1118,12 @@ class FederationHandler(BaseHandler):
backfilled=backfilled,
)
# this intentionally does not yield: we don't care about the result
# and don't need to wait for it.
preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
event_stream_id, max_stream_id
)
if not backfilled:
# this intentionally does not yield: we don't care about the result
# and don't need to wait for it.
preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
event_stream_id, max_stream_id
)
defer.returnValue((context, event_stream_id, max_stream_id))

View File

@@ -66,7 +66,7 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks
def get_messages(self, requester, room_id=None, pagin_config=None,
as_client_event=True):
as_client_event=True, event_filter=None):
"""Get messages in a room.
Args:
@@ -75,11 +75,11 @@ class MessageHandler(BaseHandler):
pagin_config (synapse.api.streams.PaginationConfig): The pagination
config rules to apply, if any.
as_client_event (bool): True to get events in client-server format.
event_filter (Filter): Filter to apply to results or None
Returns:
dict: Pagination API results
"""
user_id = requester.user.to_string()
data_source = self.hs.get_event_sources().sources["room"]
if pagin_config.from_token:
room_token = pagin_config.from_token.room_key
@@ -129,8 +129,13 @@ class MessageHandler(BaseHandler):
room_id, max_topo
)
events, next_key = yield data_source.get_pagination_rows(
requester.user, source_config, room_id
events, next_key = yield self.store.paginate_room_events(
room_id=room_id,
from_key=source_config.from_key,
to_key=source_config.to_key,
direction=source_config.direction,
limit=source_config.limit,
event_filter=event_filter,
)
next_token = pagin_config.from_token.copy_and_replace(
@@ -144,6 +149,9 @@ class MessageHandler(BaseHandler):
"end": next_token.to_string(),
})
if event_filter:
events = event_filter.filter(events)
events = yield filter_events_for_client(
self.store,
user_id,
@@ -164,101 +172,6 @@ class MessageHandler(BaseHandler):
defer.returnValue(chunk)
@defer.inlineCallbacks
def get_files(self, requester, room_id, pagin_config):
"""Get files in a room.
Args:
requester (Requester): The user requesting files.
room_id (str): The room they want files from.
pagin_config (synapse.api.streams.PaginationConfig): The pagination
config rules to apply, if any.
as_client_event (bool): True to get events in client-server format.
Returns:
dict: Pagination API results
"""
user_id = requester.user.to_string()
if pagin_config.from_token:
room_token = pagin_config.from_token.room_key
else:
pagin_config.from_token = (
yield self.hs.get_event_sources().get_current_token(
direction='b'
)
)
room_token = pagin_config.from_token.room_key
room_token = RoomStreamToken.parse(room_token)
pagin_config.from_token = pagin_config.from_token.copy_and_replace(
"room_key", str(room_token)
)
source_config = pagin_config.get_source_config("room")
membership, member_event_id = yield self._check_in_room_or_world_readable(
room_id, user_id
)
if source_config.direction == 'b':
if room_token.topological:
max_topo = room_token.topological
else:
max_topo = yield self.store.get_max_topological_token(
room_id, room_token.stream
)
if membership == Membership.LEAVE:
# If they have left the room then clamp the token to be before
# they left the room, to save the effort of loading from the
# database.
leave_token = yield self.store.get_topological_token_for_event(
member_event_id
)
leave_token = RoomStreamToken.parse(leave_token)
if leave_token.topological < max_topo:
source_config.from_key = str(leave_token)
events, next_key = yield self.store.paginate_room_file_events(
room_id,
from_key=source_config.from_key,
to_key=source_config.to_key,
direction=source_config.direction,
limit=source_config.limit,
)
next_token = pagin_config.from_token.copy_and_replace(
"room_key", next_key
)
if not events:
defer.returnValue({
"chunk": [],
"start": pagin_config.from_token.to_string(),
"end": next_token.to_string(),
})
events = yield filter_events_for_client(
self.store,
user_id,
events,
is_peeking=(member_event_id is None),
)
time_now = self.clock.time_msec()
chunk = {
"chunk": [
serialize_event(e, time_now)
for e in events
],
"start": pagin_config.from_token.to_string(),
"end": next_token.to_string(),
}
defer.returnValue(chunk)
@defer.inlineCallbacks
def create_event(self, event_dict, token_id=None, txn_id=None, prev_event_ids=None):
"""

View File

@@ -13,15 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
import synapse.types
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
from synapse.types import UserID, Requester
from synapse.types import UserID
from ._base import BaseHandler
import logging
logger = logging.getLogger(__name__)
@@ -165,7 +165,9 @@ class ProfileHandler(BaseHandler):
try:
# Assume the user isn't a guest because we don't let guests set
# profile or avatar data.
requester = Requester(user, "", False)
# XXX why are we recreating `requester` here for each room?
# what was wrong with the `requester` we were passed?
requester = synapse.types.create_requester(user)
yield handler.update_membership(
requester,
user,

View File

@@ -14,18 +14,19 @@
# limitations under the License.
"""Contains functions for registering clients."""
import logging
import urllib
from twisted.internet import defer
from synapse.types import UserID, Requester
import synapse.types
from synapse.api.errors import (
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
)
from ._base import BaseHandler
from synapse.util.async import run_on_reactor
from synapse.http.client import CaptchaServerHttpClient
import logging
import urllib
from synapse.types import UserID
from synapse.util.async import run_on_reactor
from ._base import BaseHandler
logger = logging.getLogger(__name__)
@@ -99,8 +100,13 @@ class RegistrationHandler(BaseHandler):
localpart : The local part of the user ID to register. If None,
one will be generated.
password (str) : The password to assign to this user so they can
login again. This can be None which means they cannot login again
via a password (e.g. the user is an application service user).
login again. This can be None which means they cannot login again
via a password (e.g. the user is an application service user).
generate_token (bool): Whether a new access token should be
generated. Having this be True should be considered deprecated,
since it offers no means of associating a device_id with the
access_token. Instead you should call auth_handler.issue_access_token
after registration.
Returns:
A tuple of (user_id, access_token).
Raises:
@@ -196,15 +202,13 @@ class RegistrationHandler(BaseHandler):
user_id, allowed_appservice=service
)
token = self.auth_handler().generate_access_token(user_id)
yield self.store.register(
user_id=user_id,
token=token,
password_hash="",
appservice_id=service_id,
create_profile_with_localpart=user.localpart,
)
defer.returnValue((user_id, token))
defer.returnValue(user_id)
@defer.inlineCallbacks
def check_recaptcha(self, ip, private_key, challenge, response):
@@ -360,7 +364,7 @@ class RegistrationHandler(BaseHandler):
defer.returnValue(data)
@defer.inlineCallbacks
def get_or_create_user(self, localpart, displayname, duration_seconds,
def get_or_create_user(self, localpart, displayname, duration_in_ms,
password_hash=None):
"""Creates a new user if the user does not exist,
else revokes all previous access tokens and generates a new one.
@@ -390,8 +394,8 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
token = self.auth_handler().generate_short_term_login_token(
user_id, duration_seconds)
token = self.auth_handler().generate_access_token(
user_id, None, duration_in_ms)
if need_register:
yield self.store.register(
@@ -407,8 +411,9 @@ class RegistrationHandler(BaseHandler):
if displayname is not None:
logger.info("setting user display name: %s -> %s", user_id, displayname)
profile_handler = self.hs.get_handlers().profile_handler
requester = synapse.types.create_requester(user)
yield profile_handler.set_displayname(
user, Requester(user, token, False), displayname
user, requester, displayname
)
defer.returnValue((user_id, token))

View File

@@ -14,24 +14,22 @@
# limitations under the License.
import logging
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
from twisted.internet import defer
from unpaddedbase64 import decode_base64
from ._base import BaseHandler
from synapse.types import UserID, RoomID, Requester
import synapse.types
from synapse.api.constants import (
EventTypes, Membership,
)
from synapse.api.errors import AuthError, SynapseError, Codes
from synapse.types import UserID, RoomID
from synapse.util.async import Linearizer
from synapse.util.distributor import user_left_room, user_joined_room
from signedjson.sign import verify_signed_json
from signedjson.key import decode_verify_key_bytes
from unpaddedbase64 import decode_base64
import logging
from ._base import BaseHandler
logger = logging.getLogger(__name__)
@@ -315,7 +313,7 @@ class RoomMemberHandler(BaseHandler):
)
assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
else:
requester = Requester(target_user, None, False)
requester = synapse.types.create_requester(target_user)
message_handler = self.hs.get_handlers().message_handler
prev_event = message_handler.deduplicate_state_event(event, context)

View File

@@ -205,6 +205,7 @@ class JsonResource(HttpServer, resource.Resource):
def register_paths(self, method, path_patterns, callback):
for path_pattern in path_patterns:
logger.debug("Registering for %s %s", method, path_pattern.pattern)
self.path_regexs.setdefault(method, []).append(
self._PathEntry(path_pattern, callback)
)

View File

@@ -27,7 +27,8 @@ import gc
from twisted.internet import reactor
from .metric import (
CounterMetric, CallbackMetric, DistributionMetric, CacheMetric
CounterMetric, CallbackMetric, DistributionMetric, CacheMetric,
MemoryUsageMetric,
)
@@ -66,6 +67,12 @@ class Metrics(object):
return self._register(CacheMetric, *args, **kwargs)
def register_memory_metrics(hs):
metric = MemoryUsageMetric(hs)
all_metrics.append(metric)
return metric
def get_metrics_for(pkg_name):
""" Returns a Metrics instance for conveniently creating metrics
namespaced with the given name prefix. """

View File

@@ -16,6 +16,8 @@
from itertools import chain
import psutil
# TODO(paul): I can't believe Python doesn't have one of these
def map_concat(func, items):
@@ -153,3 +155,42 @@ class CacheMetric(object):
"""%s:total{name="%s"} %d""" % (self.name, self.cache_name, total),
"""%s:size{name="%s"} %d""" % (self.name, self.cache_name, size),
]
class MemoryUsageMetric(object):
"""Keeps track of the current memory usage, using psutil.
The class will keep the current min/max/sum/counts of rss over the last
WINDOW_SIZE_SEC, by polling UPDATE_HZ times per second
"""
UPDATE_HZ = 2 # number of times to get memory per second
WINDOW_SIZE_SEC = 30 # the size of the window in seconds
def __init__(self, hs):
clock = hs.get_clock()
self.memory_snapshots = []
self.process = psutil.Process()
clock.looping_call(self._update_curr_values, 1000 / self.UPDATE_HZ)
def _update_curr_values(self):
max_size = self.UPDATE_HZ * self.WINDOW_SIZE_SEC
self.memory_snapshots.append(self.process.memory_info().rss)
self.memory_snapshots[:] = self.memory_snapshots[-max_size:]
def render(self):
if not self.memory_snapshots:
return []
max_rss = max(self.memory_snapshots)
min_rss = min(self.memory_snapshots)
sum_rss = sum(self.memory_snapshots)
len_rss = len(self.memory_snapshots)
return [
"process_psutil_rss:max %d" % max_rss,
"process_psutil_rss:min %d" % min_rss,
"process_psutil_rss:total %d" % sum_rss,
"process_psutil_rss:count %d" % len_rss,
]

View File

@@ -54,7 +54,7 @@ def get_context_for_event(state_handler, ev, user_id):
room_state = yield state_handler.get_current_state(ev.room_id)
# we no longer bother setting room_alias, and make room_name the
# human-readable name instead, be that m.room.namer, an alias or
# human-readable name instead, be that m.room.name, an alias or
# a list of people in the room
name = calculate_room_name(
room_state, user_id, fallback_to_single_member=False

View File

@@ -36,6 +36,7 @@ REQUIREMENTS = {
"blist": ["blist"],
"pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"],
"pymacaroons-pynacl": ["pymacaroons"],
"psutil>=2.0.0": ["psutil>=2.0.0"],
}
CONDITIONAL_REQUIREMENTS = {
"web_client": {

View File

@@ -46,6 +46,7 @@ from synapse.rest.client.v2_alpha import (
account_data,
report_event,
openid,
devices,
)
from synapse.http.server import JsonResource
@@ -90,3 +91,4 @@ class ClientRestResource(JsonResource):
account_data.register_servlets(hs, client_resource)
report_event.register_servlets(hs, client_resource)
openid.register_servlets(hs, client_resource)
devices.register_servlets(hs, client_resource)

View File

@@ -52,6 +52,10 @@ class ClientV1RestServlet(RestServlet):
"""
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer):
"""
self.hs = hs
self.handlers = hs.get_handlers()
self.builder_factory = hs.get_event_builder_factory()

View File

@@ -59,6 +59,7 @@ class LoginRestServlet(ClientV1RestServlet):
self.servername = hs.config.server_name
self.http_client = hs.get_simple_http_client()
self.auth_handler = self.hs.get_auth_handler()
self.device_handler = self.hs.get_device_handler()
def on_GET(self, request):
flows = []
@@ -145,15 +146,23 @@ class LoginRestServlet(ClientV1RestServlet):
).to_string()
auth_handler = self.auth_handler
user_id, access_token, refresh_token = yield auth_handler.login_with_password(
user_id = yield auth_handler.validate_password_login(
user_id=user_id,
password=login_submission["password"])
password=login_submission["password"],
)
device_id = yield self._register_device(user_id, login_submission)
access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(
user_id, device_id,
login_submission.get("initial_device_display_name")
)
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
"device_id": device_id,
}
defer.returnValue((200, result))
@@ -165,14 +174,19 @@ class LoginRestServlet(ClientV1RestServlet):
user_id = (
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
)
user_id, access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(user_id)
device_id = yield self._register_device(user_id, login_submission)
access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(
user_id, device_id,
login_submission.get("initial_device_display_name")
)
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
"device_id": device_id,
}
defer.returnValue((200, result))
@@ -196,13 +210,15 @@ class LoginRestServlet(ClientV1RestServlet):
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id)
if user_exists:
user_id, access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(user_id)
registered_user_id = yield auth_handler.check_user_exists(user_id)
if registered_user_id:
access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(
registered_user_id
)
)
result = {
"user_id": user_id, # may have changed
"user_id": registered_user_id, # may have changed
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
@@ -245,18 +261,27 @@ class LoginRestServlet(ClientV1RestServlet):
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id)
if user_exists:
user_id, access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(user_id)
registered_user_id = yield auth_handler.check_user_exists(user_id)
if registered_user_id:
device_id = yield self._register_device(
registered_user_id, login_submission
)
access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(
registered_user_id, device_id,
login_submission.get("initial_device_display_name")
)
)
result = {
"user_id": user_id, # may have changed
"user_id": registered_user_id,
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
}
else:
# TODO: we should probably check that the register isn't going
# to fonx/change our user_id before registering the device
device_id = yield self._register_device(user_id, login_submission)
user_id, access_token = (
yield self.handlers.registration_handler.register(localpart=user)
)
@@ -295,6 +320,26 @@ class LoginRestServlet(ClientV1RestServlet):
return (user, attributes)
def _register_device(self, user_id, login_submission):
"""Register a device for a user.
This is called after the user's credentials have been validated, but
before the access token has been issued.
Args:
(str) user_id: full canonical @user:id
(object) login_submission: dictionary supplied to /login call, from
which we pull device_id and initial_device_name
Returns:
defer.Deferred: (str) device_id
"""
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get(
"initial_device_display_name")
return self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
class SAML2RestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/login/saml2", releases=())
@@ -414,13 +459,13 @@ class CasTicketServlet(ClientV1RestServlet):
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id)
if not user_exists:
user_id, _ = (
registered_user_id = yield auth_handler.check_user_exists(user_id)
if not registered_user_id:
registered_user_id, _ = (
yield self.handlers.registration_handler.register(localpart=user)
)
login_token = auth_handler.generate_short_term_login_token(user_id)
login_token = auth_handler.generate_short_term_login_token(registered_user_id)
redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
login_token)
request.redirect(redirect_url)

View File

@@ -52,6 +52,10 @@ class RegisterRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/register$", releases=(), include_in_unstable=False)
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super(RegisterRestServlet, self).__init__(hs)
# sessions are stored as:
# self.sessions = {
@@ -60,6 +64,7 @@ class RegisterRestServlet(ClientV1RestServlet):
# TODO: persistent storage
self.sessions = {}
self.enable_registration = hs.config.enable_registration
self.auth_handler = hs.get_auth_handler()
def on_GET(self, request):
if self.hs.config.enable_registration_captcha:
@@ -299,9 +304,10 @@ class RegisterRestServlet(ClientV1RestServlet):
user_localpart = register_json["user"].encode("utf-8")
handler = self.handlers.registration_handler
(user_id, token) = yield handler.appservice_register(
user_id = yield handler.appservice_register(
user_localpart, as_token
)
token = yield self.auth_handler.issue_access_token(user_id)
self._remove_session(session)
defer.returnValue({
"user_id": user_id,
@@ -429,7 +435,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
user_id, token = yield handler.get_or_create_user(
localpart=localpart,
displayname=displayname,
duration_seconds=duration_seconds,
duration_in_ms=(duration_seconds * 1000),
password_hash=password_hash
)

View File

@@ -20,12 +20,14 @@ from .base import ClientV1RestServlet, client_path_patterns
from synapse.api.errors import SynapseError, Codes, AuthError
from synapse.streams.config import PaginationConfig
from synapse.api.constants import EventTypes, Membership
from synapse.api.filtering import Filter
from synapse.types import UserID, RoomID, RoomAlias
from synapse.events.utils import serialize_event
from synapse.http.servlet import parse_json_object_from_request
import logging
import urllib
import ujson as json
logger = logging.getLogger(__name__)
@@ -327,31 +329,19 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
request, default_limit=10,
)
as_client_event = "raw" not in request.args
filter_bytes = request.args.get("filter", None)
if filter_bytes:
filter_json = urllib.unquote(filter_bytes[-1]).decode("UTF-8")
event_filter = Filter(json.loads(filter_json))
else:
event_filter = None
handler = self.handlers.message_handler
msgs = yield handler.get_messages(
room_id=room_id,
requester=requester,
pagin_config=pagination_config,
as_client_event=as_client_event
)
defer.returnValue((200, msgs))
class RoomFileListRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/files$")
@defer.inlineCallbacks
def on_GET(self, request, room_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = PaginationConfig.from_request(
request, default_limit=10, default_dir='b',
)
handler = self.handlers.message_handler
msgs = yield handler.get_files(
room_id=room_id,
requester=requester,
pagin_config=pagination_config,
as_client_event=as_client_event,
event_filter=event_filter,
)
defer.returnValue((200, msgs))
@@ -686,7 +676,6 @@ def register_servlets(hs, http_server):
RoomCreateRestServlet(hs).register(http_server)
RoomMemberListRestServlet(hs).register(http_server)
RoomMessageListRestServlet(hs).register(http_server)
RoomFileListRestServlet(hs).register(http_server)
JoinRoomAliasServlet(hs).register(http_server)
RoomForgetRestServlet(hs).register(http_server)
RoomMembershipRestServlet(hs).register(http_server)

View File

@@ -25,7 +25,9 @@ import logging
logger = logging.getLogger(__name__)
def client_v2_patterns(path_regex, releases=(0,)):
def client_v2_patterns(path_regex, releases=(0,),
v2_alpha=True,
unstable=True):
"""Creates a regex compiled client path with the correct client path
prefix.
@@ -35,9 +37,12 @@ def client_v2_patterns(path_regex, releases=(0,)):
Returns:
SRE_Pattern
"""
patterns = [re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)]
unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable")
patterns.append(re.compile("^" + unstable_prefix + path_regex))
patterns = []
if v2_alpha:
patterns.append(re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex))
if unstable:
unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable")
patterns.append(re.compile("^" + unstable_prefix + path_regex))
for release in releases:
new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release)
patterns.append(re.compile("^" + new_prefix + path_regex))

View File

@@ -28,8 +28,40 @@ import logging
logger = logging.getLogger(__name__)
class PasswordRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/password/email/requestToken$")
def __init__(self, hs):
super(PasswordRequestTokenRestServlet, self).__init__()
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
required = ['id_server', 'client_secret', 'email', 'send_attempt']
absent = []
for k in required:
if k not in body:
absent.append(k)
if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email']
)
if existingUid is None:
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
ret = yield self.identity_handler.requestEmailToken(**body)
defer.returnValue((200, ret))
class PasswordRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/password")
PATTERNS = client_v2_patterns("/account/password$")
def __init__(self, hs):
super(PasswordRestServlet, self).__init__()
@@ -89,8 +121,83 @@ class PasswordRestServlet(RestServlet):
return 200, {}
class DeactivateAccountRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/deactivate$")
def __init__(self, hs):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
super(DeactivateAccountRestServlet, self).__init__()
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
authed, result, params, _ = yield self.auth_handler.check_auth([
[LoginType.PASSWORD],
], body, self.hs.get_ip_from_request(request))
if not authed:
defer.returnValue((401, result))
user_id = None
requester = None
if LoginType.PASSWORD in result:
# if using password, they should also be logged in
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
if user_id != result[LoginType.PASSWORD]:
raise LoginError(400, "", Codes.UNKNOWN)
else:
logger.error("Auth succeeded but no known type!", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN)
# FIXME: Theoretically there is a race here wherein user resets password
# using threepid.
yield self.store.user_delete_access_tokens(user_id)
yield self.store.user_delete_threepids(user_id)
yield self.store.user_set_password_hash(user_id, None)
defer.returnValue((200, {}))
class ThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$")
def __init__(self, hs):
self.hs = hs
super(ThreepidRequestTokenRestServlet, self).__init__()
self.identity_handler = hs.get_handlers().identity_handler
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
required = ['id_server', 'client_secret', 'email', 'send_attempt']
absent = []
for k in required:
if k not in body:
absent.append(k)
if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email']
)
if existingUid is not None:
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
ret = yield self.identity_handler.requestEmailToken(**body)
defer.returnValue((200, ret))
class ThreepidRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid")
PATTERNS = client_v2_patterns("/account/3pid$")
def __init__(self, hs):
super(ThreepidRestServlet, self).__init__()
@@ -157,5 +264,8 @@ class ThreepidRestServlet(RestServlet):
def register_servlets(hs, http_server):
PasswordRequestTokenRestServlet(hs).register(http_server)
PasswordRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server)
ThreepidRequestTokenRestServlet(hs).register(http_server)
ThreepidRestServlet(hs).register(http_server)

View File

@@ -0,0 +1,100 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.http import servlet
from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
class DevicesRestServlet(servlet.RestServlet):
PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False)
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super(DevicesRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request)
devices = yield self.device_handler.get_devices_by_user(
requester.user.to_string()
)
defer.returnValue((200, {"devices": devices}))
class DeviceRestServlet(servlet.RestServlet):
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
releases=[], v2_alpha=False)
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super(DeviceRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
@defer.inlineCallbacks
def on_GET(self, request, device_id):
requester = yield self.auth.get_user_by_req(request)
device = yield self.device_handler.get_device(
requester.user.to_string(),
device_id,
)
defer.returnValue((200, device))
@defer.inlineCallbacks
def on_DELETE(self, request, device_id):
# XXX: it's not completely obvious we want to expose this endpoint.
# It allows the client to delete access tokens, which feels like a
# thing which merits extra auth. But if we want to do the interactive-
# auth dance, we should really make it possible to delete more than one
# device at a time.
requester = yield self.auth.get_user_by_req(request)
yield self.device_handler.delete_device(
requester.user.to_string(),
device_id,
)
defer.returnValue((200, {}))
@defer.inlineCallbacks
def on_PUT(self, request, device_id):
requester = yield self.auth.get_user_by_req(request)
body = servlet.parse_json_object_from_request(request)
yield self.device_handler.update_device(
requester.user.to_string(),
device_id,
body
)
defer.returnValue((200, {}))
def register_servlets(hs, http_server):
DevicesRestServlet(hs).register(http_server)
DeviceRestServlet(hs).register(http_server)

View File

@@ -13,24 +13,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import simplejson as json
from canonicaljson import encode_canonical_json
from twisted.internet import defer
import synapse.api.errors
import synapse.server
import synapse.types
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import UserID
from canonicaljson import encode_canonical_json
from ._base import client_v2_patterns
import logging
import simplejson as json
logger = logging.getLogger(__name__)
class KeyUploadServlet(RestServlet):
"""
POST /keys/upload/<device_id> HTTP/1.1
POST /keys/upload HTTP/1.1
Content-Type: application/json
{
@@ -53,23 +54,45 @@ class KeyUploadServlet(RestServlet):
},
}
"""
PATTERNS = client_v2_patterns("/keys/upload/(?P<device_id>[^/]*)", releases=())
PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$",
releases=())
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super(KeyUploadServlet, self).__init__()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
@defer.inlineCallbacks
def on_POST(self, request, device_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
# TODO: Check that the device_id matches that in the authentication
# or derive the device_id from the authentication instead.
body = parse_json_object_from_request(request)
if device_id is not None:
# passing the device_id here is deprecated; however, we allow it
# for now for compatibility with older clients.
if (requester.device_id is not None and
device_id != requester.device_id):
logger.warning("Client uploading keys for a different device "
"(logged in as %s, uploading for %s)",
requester.device_id, device_id)
else:
device_id = requester.device_id
if device_id is None:
raise synapse.api.errors.SynapseError(
400,
"To upload keys, you must pass device_id when authenticating"
)
time_now = self.clock.time_msec()
# TODO: Validate the JSON to make sure it has the right keys.
@@ -102,13 +125,14 @@ class KeyUploadServlet(RestServlet):
user_id, device_id, time_now, key_list
)
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
defer.returnValue((200, {"one_time_key_counts": result}))
@defer.inlineCallbacks
def on_GET(self, request, device_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
# the device should have been registered already, but it may have been
# deleted due to a race with a DELETE request. Or we may be using an
# old access_token without an associated device_id. Either way, we
# need to double-check the device is registered to avoid ending up with
# keys without a corresponding device.
self.device_handler.check_device_registered(
user_id, device_id, "unknown device"
)
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
defer.returnValue((200, {"one_time_key_counts": result}))

View File

@@ -41,17 +41,59 @@ else:
logger = logging.getLogger(__name__)
class RegisterRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/register")
class RegisterRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/register/email/requestToken$")
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super(RegisterRequestTokenRestServlet, self).__init__()
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
required = ['id_server', 'client_secret', 'email', 'send_attempt']
absent = []
for k in required:
if k not in body:
absent.append(k)
if len(absent) > 0:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email']
)
if existingUid is not None:
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
ret = yield self.identity_handler.requestEmailToken(**body)
defer.returnValue((200, ret))
class RegisterRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/register$")
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super(RegisterRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.auth_handler = hs.get_auth_handler()
self.registration_handler = hs.get_handlers().registration_handler
self.identity_handler = hs.get_handlers().identity_handler
self.device_handler = hs.get_device_handler()
@defer.inlineCallbacks
def on_POST(self, request):
@@ -70,10 +112,6 @@ class RegisterRestServlet(RestServlet):
"Do not understand membership kind: %s" % (kind,)
)
if '/register/email/requestToken' in request.path:
ret = yield self.onEmailTokenRequest(request)
defer.returnValue(ret)
body = parse_json_object_from_request(request)
# we do basic sanity checks here because the auth layer will store these
@@ -104,11 +142,12 @@ class RegisterRestServlet(RestServlet):
# Set the desired user according to the AS API (which uses the
# 'user' key not 'username'). Since this is a new addition, we'll
# fallback to 'username' if they gave one.
if isinstance(body.get("user"), basestring):
desired_username = body["user"]
result = yield self._do_appservice_registration(
desired_username, request.args["access_token"][0]
)
desired_username = body.get("user", desired_username)
if isinstance(desired_username, basestring):
result = yield self._do_appservice_registration(
desired_username, request.args["access_token"][0], body
)
defer.returnValue((200, result)) # we throw for non 200 responses
return
@@ -117,7 +156,7 @@ class RegisterRestServlet(RestServlet):
# FIXME: Should we really be determining if this is shared secret
# auth based purely on the 'mac' key?
result = yield self._do_shared_secret_registration(
desired_username, desired_password, body["mac"]
desired_username, desired_password, body
)
defer.returnValue((200, result)) # we throw for non 200 responses
return
@@ -170,106 +209,58 @@ class RegisterRestServlet(RestServlet):
"Already registered user ID %r for this session",
registered_user_id
)
access_token = yield self.auth_handler.issue_access_token(registered_user_id)
refresh_token = yield self.auth_handler.issue_refresh_token(
registered_user_id
# don't re-register the email address
add_email = False
else:
# NB: This may be from the auth handler and NOT from the POST
if 'password' not in params:
raise SynapseError(400, "Missing password.",
Codes.MISSING_PARAM)
desired_username = params.get("username", None)
new_password = params.get("password", None)
guest_access_token = params.get("guest_access_token", None)
(registered_user_id, _) = yield self.registration_handler.register(
localpart=desired_username,
password=new_password,
guest_access_token=guest_access_token,
generate_token=False,
)
defer.returnValue((200, {
"user_id": registered_user_id,
"access_token": access_token,
"home_server": self.hs.hostname,
"refresh_token": refresh_token,
}))
# NB: This may be from the auth handler and NOT from the POST
if 'password' not in params:
raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM)
# remember that we've now registered that user account, and with
# what user ID (since the user may not have specified)
self.auth_handler.set_session_data(
session_id, "registered_user_id", registered_user_id
)
desired_username = params.get("username", None)
new_password = params.get("password", None)
guest_access_token = params.get("guest_access_token", None)
add_email = True
(user_id, token) = yield self.registration_handler.register(
localpart=desired_username,
password=new_password,
guest_access_token=guest_access_token,
result = yield self._create_registration_details(
registered_user_id, params
)
# remember that we've now registered that user account, and with what
# user ID (since the user may not have specified)
self.auth_handler.set_session_data(
session_id, "registered_user_id", user_id
)
if result and LoginType.EMAIL_IDENTITY in result:
if add_email and result and LoginType.EMAIL_IDENTITY in result:
threepid = result[LoginType.EMAIL_IDENTITY]
yield self._register_email_threepid(
registered_user_id, threepid, result["access_token"],
params.get("bind_email")
)
for reqd in ['medium', 'address', 'validated_at']:
if reqd not in threepid:
logger.info("Can't add incomplete 3pid")
else:
yield self.auth_handler.add_threepid(
user_id,
threepid['medium'],
threepid['address'],
threepid['validated_at'],
)
# And we add an email pusher for them by default, but only
# if email notifications are enabled (so people don't start
# getting mail spam where they weren't before if email
# notifs are set up on a home server)
if (
self.hs.config.email_enable_notifs and
self.hs.config.email_notif_for_new_users
):
# Pull the ID of the access token back out of the db
# It would really make more sense for this to be passed
# up when the access token is saved, but that's quite an
# invasive change I'd rather do separately.
user_tuple = yield self.store.get_user_by_access_token(
token
)
yield self.hs.get_pusherpool().add_pusher(
user_id=user_id,
access_token=user_tuple["token_id"],
kind="email",
app_id="m.email",
app_display_name="Email Notifications",
device_display_name=threepid["address"],
pushkey=threepid["address"],
lang=None, # We don't know a user's language here
data={},
)
if 'bind_email' in params and params['bind_email']:
logger.info("bind_email specified: binding")
emailThreepid = result[LoginType.EMAIL_IDENTITY]
threepid_creds = emailThreepid['threepid_creds']
logger.debug("Binding emails %s to %s" % (
emailThreepid, user_id
))
yield self.identity_handler.bind_threepid(threepid_creds, user_id)
else:
logger.info("bind_email not specified: not binding email")
result = yield self._create_registration_details(user_id, token)
defer.returnValue((200, result))
def on_OPTIONS(self, _):
return 200, {}
@defer.inlineCallbacks
def _do_appservice_registration(self, username, as_token):
(user_id, token) = yield self.registration_handler.appservice_register(
def _do_appservice_registration(self, username, as_token, body):
user_id = yield self.registration_handler.appservice_register(
username, as_token
)
defer.returnValue((yield self._create_registration_details(user_id, token)))
defer.returnValue((yield self._create_registration_details(user_id, body)))
@defer.inlineCallbacks
def _do_shared_secret_registration(self, username, password, mac):
def _do_shared_secret_registration(self, username, password, body):
if not self.hs.config.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled")
@@ -277,7 +268,7 @@ class RegisterRestServlet(RestServlet):
# str() because otherwise hmac complains that 'unicode' does not
# have the buffer interface
got_mac = str(mac)
got_mac = str(body["mac"])
want_mac = hmac.new(
key=self.hs.config.registration_shared_secret,
@@ -290,43 +281,134 @@ class RegisterRestServlet(RestServlet):
403, "HMAC incorrect",
)
(user_id, token) = yield self.registration_handler.register(
localpart=username, password=password
(user_id, _) = yield self.registration_handler.register(
localpart=username, password=password, generate_token=False,
)
defer.returnValue((yield self._create_registration_details(user_id, token)))
result = yield self._create_registration_details(user_id, body)
defer.returnValue(result)
@defer.inlineCallbacks
def _create_registration_details(self, user_id, token):
refresh_token = yield self.auth_handler.issue_refresh_token(user_id)
def _register_email_threepid(self, user_id, threepid, token, bind_email):
"""Add an email address as a 3pid identifier
Also adds an email pusher for the email address, if configured in the
HS config
Also optionally binds emails to the given user_id on the identity server
Args:
user_id (str): id of user
threepid (object): m.login.email.identity auth response
token (str): access_token for the user
bind_email (bool): true if the client requested the email to be
bound at the identity server
Returns:
defer.Deferred:
"""
reqd = ('medium', 'address', 'validated_at')
if any(x not in threepid for x in reqd):
logger.info("Can't add incomplete 3pid")
defer.returnValue()
yield self.auth_handler.add_threepid(
user_id,
threepid['medium'],
threepid['address'],
threepid['validated_at'],
)
# And we add an email pusher for them by default, but only
# if email notifications are enabled (so people don't start
# getting mail spam where they weren't before if email
# notifs are set up on a home server)
if (self.hs.config.email_enable_notifs and
self.hs.config.email_notif_for_new_users):
# Pull the ID of the access token back out of the db
# It would really make more sense for this to be passed
# up when the access token is saved, but that's quite an
# invasive change I'd rather do separately.
user_tuple = yield self.store.get_user_by_access_token(
token
)
token_id = user_tuple["token_id"]
yield self.hs.get_pusherpool().add_pusher(
user_id=user_id,
access_token=token_id,
kind="email",
app_id="m.email",
app_display_name="Email Notifications",
device_display_name=threepid["address"],
pushkey=threepid["address"],
lang=None, # We don't know a user's language here
data={},
)
if bind_email:
logger.info("bind_email specified: binding")
logger.debug("Binding emails %s to %s" % (
threepid, user_id
))
yield self.identity_handler.bind_threepid(
threepid['threepid_creds'], user_id
)
else:
logger.info("bind_email not specified: not binding email")
defer.returnValue()
@defer.inlineCallbacks
def _create_registration_details(self, user_id, params):
"""Complete registration of newly-registered user
Allocates device_id if one was not given; also creates access_token
and refresh_token.
Args:
(str) user_id: full canonical @user:id
(object) params: registration parameters, from which we pull
device_id and initial_device_name
Returns:
defer.Deferred: (object) dictionary for response from /register
"""
device_id = yield self._register_device(user_id, params)
access_token, refresh_token = (
yield self.auth_handler.get_login_tuple_for_user_id(
user_id, device_id=device_id,
initial_display_name=params.get("initial_device_display_name")
)
)
defer.returnValue({
"user_id": user_id,
"access_token": token,
"access_token": access_token,
"home_server": self.hs.hostname,
"refresh_token": refresh_token,
"device_id": device_id,
})
@defer.inlineCallbacks
def onEmailTokenRequest(self, request):
body = parse_json_object_from_request(request)
def _register_device(self, user_id, params):
"""Register a device for a user.
required = ['id_server', 'client_secret', 'email', 'send_attempt']
absent = []
for k in required:
if k not in body:
absent.append(k)
This is called after the user's credentials have been validated, but
before the access token has been issued.
if len(absent) > 0:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email']
Args:
(str) user_id: full canonical @user:id
(object) params: registration parameters, from which we pull
device_id and initial_device_name
Returns:
defer.Deferred: (str) device_id
"""
# register the user's device
device_id = params.get("device_id")
initial_display_name = params.get("initial_device_display_name")
device_id = self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
if existingUid is not None:
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
ret = yield self.identity_handler.requestEmailToken(**body)
defer.returnValue((200, ret))
return device_id
@defer.inlineCallbacks
def _do_guest_registration(self):
@@ -336,7 +418,11 @@ class RegisterRestServlet(RestServlet):
generate_token=False,
make_guest=True
)
access_token = self.auth_handler.generate_access_token(user_id, ["guest = true"])
access_token = self.auth_handler.generate_access_token(
user_id, ["guest = true"]
)
# XXX the "guest" caveat is not copied by /tokenrefresh. That's ok
# so long as we don't return a refresh_token here.
defer.returnValue((200, {
"user_id": user_id,
"access_token": access_token,
@@ -345,4 +431,5 @@ class RegisterRestServlet(RestServlet):
def register_servlets(hs, http_server):
RegisterRequestTokenRestServlet(hs).register(http_server)
RegisterRestServlet(hs).register(http_server)

View File

@@ -39,9 +39,13 @@ class TokenRefreshRestServlet(RestServlet):
try:
old_refresh_token = body["refresh_token"]
auth_handler = self.hs.get_auth_handler()
(user_id, new_refresh_token) = yield self.store.exchange_refresh_token(
old_refresh_token, auth_handler.generate_refresh_token)
new_access_token = yield auth_handler.issue_access_token(user_id)
refresh_result = yield self.store.exchange_refresh_token(
old_refresh_token, auth_handler.generate_refresh_token
)
(user_id, new_refresh_token, device_id) = refresh_result
new_access_token = yield auth_handler.issue_access_token(
user_id, device_id
)
defer.returnValue((200, {
"access_token": new_access_token,
"refresh_token": new_refresh_token,

View File

@@ -26,7 +26,11 @@ class VersionsRestServlet(RestServlet):
def on_GET(self, request):
return (200, {
"versions": ["r0.0.1"]
"versions": [
"r0.0.1",
"r0.1.0",
"r0.2.0",
]
})

View File

@@ -25,6 +25,7 @@ from twisted.enterprise import adbapi
from synapse.appservice.scheduler import ApplicationServiceScheduler
from synapse.appservice.api import ApplicationServiceApi
from synapse.federation import initialize_http_replication
from synapse.handlers.device import DeviceHandler
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
from synapse.notifier import Notifier
from synapse.api.auth import Auth
@@ -92,6 +93,7 @@ class HomeServer(object):
'typing_handler',
'room_list_handler',
'auth_handler',
'device_handler',
'application_service_api',
'application_service_scheduler',
'application_service_handler',
@@ -197,6 +199,9 @@ class HomeServer(object):
def build_auth_handler(self):
return AuthHandler(self)
def build_device_handler(self):
return DeviceHandler(self)
def build_application_service_api(self):
return ApplicationServiceApi(self)

21
synapse/server.pyi Normal file
View File

@@ -0,0 +1,21 @@
import synapse.handlers
import synapse.handlers.auth
import synapse.handlers.device
import synapse.storage
import synapse.state
class HomeServer(object):
def get_auth_handler(self) -> synapse.handlers.auth.AuthHandler:
pass
def get_datastore(self) -> synapse.storage.DataStore:
pass
def get_device_handler(self) -> synapse.handlers.device.DeviceHandler:
pass
def get_handlers(self) -> synapse.handlers.Handlers:
pass
def get_state_handler(self) -> synapse.state.StateHandler:
pass

View File

@@ -379,7 +379,8 @@ class StateHandler(object):
try:
# FIXME: hs.get_auth() is bad style, but we need to do it to
# get around circular deps.
self.hs.get_auth().check(event, auth_events)
# The signatures have already been checked at this point
self.hs.get_auth().check(event, auth_events, do_sig_check=False)
prev_event = event
except AuthError:
return prev_event
@@ -391,7 +392,8 @@ class StateHandler(object):
try:
# FIXME: hs.get_auth() is bad style, but we need to do it to
# get around circular deps.
self.hs.get_auth().check(event, auth_events)
# The signatures have already been checked at this point
self.hs.get_auth().check(event, auth_events, do_sig_check=False)
return event
except AuthError:
pass

View File

@@ -14,6 +14,8 @@
# limitations under the License.
from twisted.internet import defer
from synapse.storage.devices import DeviceStore
from .appservice import (
ApplicationServiceStore, ApplicationServiceTransactionStore
)
@@ -80,6 +82,7 @@ class DataStore(RoomMemberStore, RoomStore,
EventPushActionsStore,
OpenIdStore,
ClientIpStore,
DeviceStore,
):
def __init__(self, db_conn, hs):
@@ -92,7 +95,8 @@ class DataStore(RoomMemberStore, RoomStore,
extra_tables=[("local_invites", "stream_id")]
)
self._backfill_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering", step=-1
db_conn, "events", "stream_ordering", step=-1,
extra_tables=[("ex_outlier_stream", "event_stream_ordering")]
)
self._receipts_id_gen = StreamIdGenerator(
db_conn, "receipts_linearized", "stream_id"

View File

@@ -597,10 +597,13 @@ class SQLBaseStore(object):
more rows, returning the result as a list of dicts.
Args:
table : string giving the table name
keyvalues : dict of column names and values to select the rows with,
or None to not apply a WHERE clause.
retcols : list of strings giving the names of the columns to return
table (str): the table name
keyvalues (dict[str, Any] | None):
column names and values to select the rows with, or None to not
apply a WHERE clause.
retcols (iterable[str]): the names of the columns to return
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
"""
return self.runInteraction(
desc,
@@ -615,9 +618,11 @@ class SQLBaseStore(object):
Args:
txn : Transaction object
table : string giving the table name
keyvalues : dict of column names and values to select the rows with
retcols : list of strings giving the names of the columns to return
table (str): the table name
keyvalues (dict[str, T] | None):
column names and values to select the rows with, or None to not
apply a WHERE clause.
retcols (iterable[str]): the names of the columns to return
"""
if keyvalues:
sql = "SELECT %s FROM %s WHERE %s" % (

View File

@@ -14,6 +14,7 @@
# limitations under the License.
from ._base import SQLBaseStore
from . import engines
from twisted.internet import defer
@@ -87,10 +88,12 @@ class BackgroundUpdateStore(SQLBaseStore):
@defer.inlineCallbacks
def start_doing_background_updates(self):
while True:
if self._background_update_timer is not None:
return
assert self._background_update_timer is None, \
"background updates already running"
logger.info("Starting background schema updates")
while True:
sleep = defer.Deferred()
self._background_update_timer = self._clock.call_later(
self.BACKGROUND_UPDATE_INTERVAL_MS / 1000., sleep.callback, None
@@ -101,22 +104,23 @@ class BackgroundUpdateStore(SQLBaseStore):
self._background_update_timer = None
try:
result = yield self.do_background_update(
result = yield self.do_next_background_update(
self.BACKGROUND_UPDATE_DURATION_MS
)
except:
logger.exception("Error doing update")
if result is None:
logger.info(
"No more background updates to do."
" Unscheduling background update task."
)
return
else:
if result is None:
logger.info(
"No more background updates to do."
" Unscheduling background update task."
)
defer.returnValue(None)
@defer.inlineCallbacks
def do_background_update(self, desired_duration_ms):
"""Does some amount of work on a background update
def do_next_background_update(self, desired_duration_ms):
"""Does some amount of work on the next queued background update
Args:
desired_duration_ms(float): How long we want to spend
updating.
@@ -135,11 +139,21 @@ class BackgroundUpdateStore(SQLBaseStore):
self._background_update_queue.append(update['update_name'])
if not self._background_update_queue:
# no work left to do
defer.returnValue(None)
# pop from the front, and add back to the back
update_name = self._background_update_queue.pop(0)
self._background_update_queue.append(update_name)
res = yield self._do_background_update(update_name, desired_duration_ms)
defer.returnValue(res)
@defer.inlineCallbacks
def _do_background_update(self, update_name, desired_duration_ms):
logger.info("Starting update batch on background update '%s'",
update_name)
update_handler = self._background_update_handlers[update_name]
performance = self._background_update_performance.get(update_name)
@@ -202,6 +216,64 @@ class BackgroundUpdateStore(SQLBaseStore):
"""
self._background_update_handlers[update_name] = update_handler
def register_background_index_update(self, update_name, index_name,
table, columns):
"""Helper for store classes to do a background index addition
To use:
1. use a schema delta file to add a background update. Example:
INSERT INTO background_updates (update_name, progress_json) VALUES
('my_new_index', '{}');
2. In the Store constructor, call this method
Args:
update_name (str): update_name to register for
index_name (str): name of index to add
table (str): table to add index to
columns (list[str]): columns/expressions to include in index
"""
# if this is postgres, we add the indexes concurrently. Otherwise
# we fall back to doing it inline
if isinstance(self.database_engine, engines.PostgresEngine):
conc = True
else:
conc = False
sql = "CREATE INDEX %(conc)s %(name)s ON %(table)s (%(columns)s)" \
% {
"conc": "CONCURRENTLY" if conc else "",
"name": index_name,
"table": table,
"columns": ", ".join(columns),
}
def create_index_concurrently(conn):
conn.rollback()
# postgres insists on autocommit for the index
conn.set_session(autocommit=True)
c = conn.cursor()
c.execute(sql)
conn.set_session(autocommit=False)
def create_index(conn):
c = conn.cursor()
c.execute(sql)
@defer.inlineCallbacks
def updater(progress, batch_size):
logger.info("Adding index %s to %s", index_name, table)
if conc:
yield self.runWithConnection(create_index_concurrently)
else:
yield self.runWithConnection(create_index)
yield self._end_background_update(update_name)
defer.returnValue(1)
self.register_background_update_handler(update_name, updater)
def start_background_update(self, update_name, progress):
"""Starts a background update running.

View File

@@ -13,10 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore, Cache
import logging
from twisted.internet import defer
from ._base import Cache
from . import background_updates
logger = logging.getLogger(__name__)
# Number of msec of granularity to store the user IP 'last seen' time. Smaller
# times give more inserts into the database even for readonly API hits
@@ -24,8 +28,7 @@ from twisted.internet import defer
LAST_SEEN_GRANULARITY = 120 * 1000
class ClientIpStore(SQLBaseStore):
class ClientIpStore(background_updates.BackgroundUpdateStore):
def __init__(self, hs):
self.client_ip_last_seen = Cache(
name="client_ip_last_seen",
@@ -34,8 +37,15 @@ class ClientIpStore(SQLBaseStore):
super(ClientIpStore, self).__init__(hs)
self.register_background_index_update(
"user_ips_device_index",
index_name="user_ips_device_id",
table="user_ips",
columns=["user_id", "device_id", "last_seen"],
)
@defer.inlineCallbacks
def insert_client_ip(self, user, access_token, ip, user_agent):
def insert_client_ip(self, user, access_token, ip, user_agent, device_id):
now = int(self._clock.time_msec())
key = (user.to_string(), access_token, ip)
@@ -59,6 +69,7 @@ class ClientIpStore(SQLBaseStore):
"access_token": access_token,
"ip": ip,
"user_agent": user_agent,
"device_id": device_id,
},
values={
"last_seen": now,
@@ -66,3 +77,69 @@ class ClientIpStore(SQLBaseStore):
desc="insert_client_ip",
lock=False,
)
@defer.inlineCallbacks
def get_last_client_ip_by_device(self, devices):
"""For each device_id listed, give the user_ip it was last seen on
Args:
devices (iterable[(str, str)]): list of (user_id, device_id) pairs
Returns:
defer.Deferred: resolves to a dict, where the keys
are (user_id, device_id) tuples. The values are also dicts, with
keys giving the column names
"""
res = yield self.runInteraction(
"get_last_client_ip_by_device",
self._get_last_client_ip_by_device_txn,
retcols=(
"user_id",
"access_token",
"ip",
"user_agent",
"device_id",
"last_seen",
),
devices=devices
)
ret = {(d["user_id"], d["device_id"]): d for d in res}
defer.returnValue(ret)
@classmethod
def _get_last_client_ip_by_device_txn(cls, txn, devices, retcols):
where_clauses = []
bindings = []
for (user_id, device_id) in devices:
if device_id is None:
where_clauses.append("(user_id = ? AND device_id IS NULL)")
bindings.extend((user_id, ))
else:
where_clauses.append("(user_id = ? AND device_id = ?)")
bindings.extend((user_id, device_id))
inner_select = (
"SELECT MAX(last_seen) mls, user_id, device_id FROM user_ips "
"WHERE %(where)s "
"GROUP BY user_id, device_id"
) % {
"where": " OR ".join(where_clauses),
}
sql = (
"SELECT %(retcols)s FROM user_ips "
"JOIN (%(inner_select)s) ips ON"
" user_ips.last_seen = ips.mls AND"
" user_ips.user_id = ips.user_id AND"
" (user_ips.device_id = ips.device_id OR"
" (user_ips.device_id IS NULL AND ips.device_id IS NULL)"
" )"
) % {
"retcols": ",".join("user_ips." + c for c in retcols),
"inner_select": inner_select,
}
txn.execute(sql, bindings)
return cls.cursor_to_dict(txn)

137
synapse/storage/devices.py Normal file
View File

@@ -0,0 +1,137 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.api.errors import StoreError
from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
class DeviceStore(SQLBaseStore):
@defer.inlineCallbacks
def store_device(self, user_id, device_id,
initial_device_display_name,
ignore_if_known=True):
"""Ensure the given device is known; add it to the store if not
Args:
user_id (str): id of user associated with the device
device_id (str): id of device
initial_device_display_name (str): initial displayname of the
device
ignore_if_known (bool): ignore integrity errors which mean the
device is already known
Returns:
defer.Deferred
Raises:
StoreError: if ignore_if_known is False and the device was already
known
"""
try:
yield self._simple_insert(
"devices",
values={
"user_id": user_id,
"device_id": device_id,
"display_name": initial_device_display_name
},
desc="store_device",
or_ignore=ignore_if_known,
)
except Exception as e:
logger.error("store_device with device_id=%s failed: %s",
device_id, e)
raise StoreError(500, "Problem storing device.")
def get_device(self, user_id, device_id):
"""Retrieve a device.
Args:
user_id (str): The ID of the user which owns the device
device_id (str): The ID of the device to retrieve
Returns:
defer.Deferred for a dict containing the device information
Raises:
StoreError: if the device is not found
"""
return self._simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
retcols=("user_id", "device_id", "display_name"),
desc="get_device",
)
def delete_device(self, user_id, device_id):
"""Delete a device.
Args:
user_id (str): The ID of the user which owns the device
device_id (str): The ID of the device to delete
Returns:
defer.Deferred
"""
return self._simple_delete_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
desc="delete_device",
)
def update_device(self, user_id, device_id, new_display_name=None):
"""Update a device.
Args:
user_id (str): The ID of the user which owns the device
device_id (str): The ID of the device to update
new_display_name (str|None): new displayname for device; None
to leave unchanged
Raises:
StoreError: if the device is not found
Returns:
defer.Deferred
"""
updates = {}
if new_display_name is not None:
updates["display_name"] = new_display_name
if not updates:
return defer.succeed(None)
return self._simple_update_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
updatevalues=updates,
desc="update_device",
)
@defer.inlineCallbacks
def get_devices_by_user(self, user_id):
"""Retrieve all of a user's registered devices.
Args:
user_id (str):
Returns:
defer.Deferred: resolves to a dict from device_id to a dict
containing "device_id", "user_id" and "display_name" for each
device.
"""
devices = yield self._simple_select_list(
table="devices",
keyvalues={"user_id": user_id},
retcols=("user_id", "device_id", "display_name"),
desc="get_devices_by_user"
)
defer.returnValue({d["device_id"]: d for d in devices})

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import twisted.internet.defer
from ._base import SQLBaseStore
@@ -123,3 +125,16 @@ class EndToEndKeyStore(SQLBaseStore):
return self.runInteraction(
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
)
@twisted.internet.defer.inlineCallbacks
def delete_e2e_keys_by_device(self, user_id, device_id):
yield self._simple_delete(
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
desc="delete_e2e_device_keys_by_device"
)
yield self._simple_delete(
table="e2e_one_time_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
desc="delete_e2e_one_time_keys_by_device"
)

View File

@@ -152,6 +152,7 @@ _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
class EventsStore(SQLBaseStore):
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
def __init__(self, hs):
super(EventsStore, self).__init__(hs)
@@ -159,6 +160,10 @@ class EventsStore(SQLBaseStore):
self.register_background_update_handler(
self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
)
self.register_background_update_handler(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME,
self._background_reindex_fields_sender,
)
self._event_persist_queue = _EventPeristenceQueue()
@@ -392,6 +397,12 @@ class EventsStore(SQLBaseStore):
@log_function
def _persist_events_txn(self, txn, events_and_contexts, backfilled):
"""Insert some number of room events into the necessary database tables.
Rejected events are only inserted into the events table, the events_json table,
and the rejections table. Things reading from those table will need to check
whether the event was rejected.
"""
depth_updates = {}
for event, context in events_and_contexts:
# Remove the any existing cache entries for the event_ids
@@ -402,21 +413,11 @@ class EventsStore(SQLBaseStore):
event.room_id, event.internal_metadata.stream_ordering,
)
if not event.internal_metadata.is_outlier():
if not event.internal_metadata.is_outlier() and not context.rejected:
depth_updates[event.room_id] = max(
event.depth, depth_updates.get(event.room_id, event.depth)
)
if context.push_actions:
self._set_push_actions_for_event_and_users_txn(
txn, event, context.push_actions
)
if event.type == EventTypes.Redaction and event.redacts is not None:
self._remove_push_actions_for_event_id_txn(
txn, event.room_id, event.redacts
)
for room_id, depth in depth_updates.items():
self._update_min_depth_for_room_txn(txn, room_id, depth)
@@ -426,14 +427,24 @@ class EventsStore(SQLBaseStore):
),
[event.event_id for event, _ in events_and_contexts]
)
have_persisted = {
event_id: outlier
for event_id, outlier in txn.fetchall()
}
# Remove the events that we've seen before.
event_map = {}
to_remove = set()
for event, context in events_and_contexts:
if context.rejected:
# If the event is rejected then we don't care if the event
# was an outlier or not.
if event.event_id in have_persisted:
# If we have already seen the event then ignore it.
to_remove.add(event)
continue
# Handle the case of the list including the same event multiple
# times. The tricky thing here is when they differ by whether
# they are an outlier.
@@ -458,6 +469,12 @@ class EventsStore(SQLBaseStore):
outlier_persisted = have_persisted[event.event_id]
if not event.internal_metadata.is_outlier() and outlier_persisted:
# We received a copy of an event that we had already stored as
# an outlier in the database. We now have some state at that
# so we need to update the state_groups table with that state.
# insert into the state_group, state_groups_state and
# event_to_state_groups tables.
self._store_mult_state_groups_txn(txn, ((event, context),))
metadata_json = encode_json(
@@ -473,6 +490,8 @@ class EventsStore(SQLBaseStore):
(metadata_json, event.event_id,)
)
# Add an entry to the ex_outlier_stream table to replicate the
# change in outlier status to our workers.
stream_order = event.internal_metadata.stream_ordering
state_group_id = context.state_group or context.new_state_group_id
self._simple_insert_txn(
@@ -494,6 +513,8 @@ class EventsStore(SQLBaseStore):
(False, event.event_id,)
)
# Update the event_backward_extremities table now that this
# event isn't an outlier any more.
self._update_extremeties(txn, [event])
events_and_contexts = [
@@ -501,38 +522,12 @@ class EventsStore(SQLBaseStore):
]
if not events_and_contexts:
# Make sure we don't pass an empty list to functions that expect to
# be storing at least one element.
return
self._store_mult_state_groups_txn(txn, events_and_contexts)
self._handle_mult_prev_events(
txn,
events=[event for event, _ in events_and_contexts],
)
for event, _ in events_and_contexts:
if event.type == EventTypes.Name:
self._store_room_name_txn(txn, event)
elif event.type == EventTypes.Topic:
self._store_room_topic_txn(txn, event)
elif event.type == EventTypes.Message:
self._store_room_message_txn(txn, event)
elif event.type == EventTypes.Redaction:
self._store_redaction(txn, event)
elif event.type == EventTypes.RoomHistoryVisibility:
self._store_history_visibility_txn(txn, event)
elif event.type == EventTypes.GuestAccess:
self._store_guest_access_txn(txn, event)
self._store_room_members_txn(
txn,
[
event
for event, _ in events_and_contexts
if event.type == EventTypes.Member
],
backfilled=backfilled,
)
# From this point onwards the events are only events that we haven't
# seen before.
def event_dict(event):
return {
@@ -576,15 +571,51 @@ class EventsStore(SQLBaseStore):
"content": encode_json(event.content).decode("UTF-8"),
"origin_server_ts": int(event.origin_server_ts),
"received_ts": self._clock.time_msec(),
"sender": event.sender,
"contains_url": (
"url" in event.content
and isinstance(event.content["url"], basestring)
),
}
for event, _ in events_and_contexts
],
)
if context.rejected:
self._store_rejections_txn(
txn, event.event_id, context.rejected
)
# Remove the rejected events from the list now that we've added them
# to the events table and the events_json table.
to_remove = set()
for event, context in events_and_contexts:
if context.rejected:
# Insert the event_id into the rejections table
self._store_rejections_txn(
txn, event.event_id, context.rejected
)
to_remove.add(event)
events_and_contexts = [
ec for ec in events_and_contexts if ec[0] not in to_remove
]
if not events_and_contexts:
# Make sure we don't pass an empty list to functions that expect to
# be storing at least one element.
return
# From this point onwards the events are only ones that weren't rejected.
for event, context in events_and_contexts:
# Insert all the push actions into the event_push_actions table.
if context.push_actions:
self._set_push_actions_for_event_and_users_txn(
txn, event, context.push_actions
)
if event.type == EventTypes.Redaction and event.redacts is not None:
# Remove the entries in the event_push_actions table for the
# redacted event.
self._remove_push_actions_for_event_id_txn(
txn, event.room_id, event.redacts
)
self._simple_insert_many_txn(
txn,
@@ -600,6 +631,49 @@ class EventsStore(SQLBaseStore):
],
)
# Insert into the state_groups, state_groups_state, and
# event_to_state_groups tables.
self._store_mult_state_groups_txn(txn, events_and_contexts)
# Update the event_forward_extremities, event_backward_extremities and
# event_edges tables.
self._handle_mult_prev_events(
txn,
events=[event for event, _ in events_and_contexts],
)
for event, _ in events_and_contexts:
if event.type == EventTypes.Name:
# Insert into the room_names and event_search tables.
self._store_room_name_txn(txn, event)
elif event.type == EventTypes.Topic:
# Insert into the topics table and event_search table.
self._store_room_topic_txn(txn, event)
elif event.type == EventTypes.Message:
# Insert into the event_search table.
self._store_room_message_txn(txn, event)
elif event.type == EventTypes.Redaction:
# Insert into the redactions table.
self._store_redaction(txn, event)
elif event.type == EventTypes.RoomHistoryVisibility:
# Insert into the event_search table.
self._store_history_visibility_txn(txn, event)
elif event.type == EventTypes.GuestAccess:
# Insert into the event_search table.
self._store_guest_access_txn(txn, event)
# Insert into the room_memberships table.
self._store_room_members_txn(
txn,
[
event
for event, _ in events_and_contexts
if event.type == EventTypes.Member
],
backfilled=backfilled,
)
# Insert event_reference_hashes table.
self._store_event_reference_hashes_txn(
txn, [event for event, _ in events_and_contexts]
)
@@ -644,6 +718,7 @@ class EventsStore(SQLBaseStore):
],
)
# Prefill the event cache
self._add_to_cache(txn, events_and_contexts)
if backfilled:
@@ -656,11 +731,6 @@ class EventsStore(SQLBaseStore):
# Outlier events shouldn't clobber the current state.
continue
if context.rejected:
# If the event failed it's auth checks then it shouldn't
# clobbler the current state.
continue
txn.call_after(
self._get_current_state_for_key.invalidate,
(event.room_id, event.type, event.state_key,)
@@ -1115,6 +1185,78 @@ class EventsStore(SQLBaseStore):
ret = yield self.runInteraction("count_messages", _count_messages)
defer.returnValue(ret)
@defer.inlineCallbacks
def _background_reindex_fields_sender(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
INSERT_CLUMP_SIZE = 1000
def reindex_txn(txn):
sql = (
"SELECT stream_ordering, event_id, json FROM events"
" INNER JOIN event_json USING (event_id)"
" WHERE ? <= stream_ordering AND stream_ordering < ?"
" ORDER BY stream_ordering DESC"
" LIMIT ?"
)
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
rows = txn.fetchall()
if not rows:
return 0
min_stream_id = rows[-1][0]
update_rows = []
for row in rows:
try:
event_id = row[1]
event_json = json.loads(row[2])
sender = event_json["sender"]
content = event_json["content"]
contains_url = "url" in content
if contains_url:
contains_url &= isinstance(content["url"], basestring)
except (KeyError, AttributeError):
# If the event is missing a necessary field then
# skip over it.
continue
update_rows.append((sender, contains_url, event_id))
sql = (
"UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"
)
for index in range(0, len(update_rows), INSERT_CLUMP_SIZE):
clump = update_rows[index:index + INSERT_CLUMP_SIZE]
txn.executemany(sql, clump)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
"max_stream_id_exclusive": min_stream_id,
"rows_inserted": rows_inserted + len(rows)
}
self._background_update_progress_txn(
txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress
)
return len(rows)
result = yield self.runInteraction(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
)
if not result:
yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME)
defer.returnValue(result)
@defer.inlineCallbacks
def _background_reindex_origin_server_ts(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
@@ -1343,7 +1485,7 @@ class EventsStore(SQLBaseStore):
# We calculate the new entries for the backward extremeties by finding
# all events that point to events that are to be purged
txn.execute(
"SELECT e.event_id FROM events as e"
"SELECT DISTINCT e.event_id FROM events as e"
" INNER JOIN event_edges as ed ON e.event_id = ed.prev_event_id"
" INNER JOIN events as e2 ON e2.event_id = ed.event_id"
" WHERE e.room_id = ? AND e.topological_ordering < ?"
@@ -1352,6 +1494,20 @@ class EventsStore(SQLBaseStore):
)
new_backwards_extrems = txn.fetchall()
txn.execute(
"DELETE FROM event_backward_extremities WHERE room_id = ?",
(room_id,)
)
# Update backward extremeties
txn.executemany(
"INSERT INTO event_backward_extremities (room_id, event_id)"
" VALUES (?, ?)",
[
(room_id, event_id) for event_id, in new_backwards_extrems
]
)
# Get all state groups that are only referenced by events that are
# to be deleted.
txn.execute(
@@ -1404,20 +1560,12 @@ class EventsStore(SQLBaseStore):
"event_search",
"event_signatures",
"rejections",
"event_backward_extremities",
):
txn.executemany(
"DELETE FROM %s WHERE event_id = ?" % (table,),
to_delete
)
# Update backward extremeties
txn.executemany(
"INSERT INTO event_backward_extremities (room_id, event_id)"
" VALUES (?, ?)",
[(room_id, event_id) for event_id, in new_backwards_extrems]
)
txn.executemany(
"DELETE FROM events WHERE event_id = ?",
to_delete

View File

@@ -18,25 +18,40 @@ import re
from twisted.internet import defer
from synapse.api.errors import StoreError, Codes
from ._base import SQLBaseStore
from synapse.storage import background_updates
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
class RegistrationStore(SQLBaseStore):
class RegistrationStore(background_updates.BackgroundUpdateStore):
def __init__(self, hs):
super(RegistrationStore, self).__init__(hs)
self.clock = hs.get_clock()
self.register_background_index_update(
"access_tokens_device_index",
index_name="access_tokens_device_id",
table="access_tokens",
columns=["user_id", "device_id"],
)
self.register_background_index_update(
"refresh_tokens_device_index",
index_name="refresh_tokens_device_id",
table="refresh_tokens",
columns=["user_id", "device_id"],
)
@defer.inlineCallbacks
def add_access_token_to_user(self, user_id, token):
def add_access_token_to_user(self, user_id, token, device_id=None):
"""Adds an access token for the given user.
Args:
user_id (str): The user ID.
token (str): The new access token to add.
device_id (str): ID of the device to associate with the access
token
Raises:
StoreError if there was a problem adding this.
"""
@@ -47,18 +62,21 @@ class RegistrationStore(SQLBaseStore):
{
"id": next_id,
"user_id": user_id,
"token": token
"token": token,
"device_id": device_id,
},
desc="add_access_token_to_user",
)
@defer.inlineCallbacks
def add_refresh_token_to_user(self, user_id, token):
def add_refresh_token_to_user(self, user_id, token, device_id=None):
"""Adds a refresh token for the given user.
Args:
user_id (str): The user ID.
token (str): The new refresh token to add.
device_id (str): ID of the device to associate with the access
token
Raises:
StoreError if there was a problem adding this.
"""
@@ -69,20 +87,23 @@ class RegistrationStore(SQLBaseStore):
{
"id": next_id,
"user_id": user_id,
"token": token
"token": token,
"device_id": device_id,
},
desc="add_refresh_token_to_user",
)
@defer.inlineCallbacks
def register(self, user_id, token, password_hash,
def register(self, user_id, token=None, password_hash=None,
was_guest=False, make_guest=False, appservice_id=None,
create_profile_with_localpart=None, admin=False):
"""Attempts to register an account.
Args:
user_id (str): The desired user ID to register.
token (str): The desired access token to use for this user.
token (str): The desired access token to use for this user. If this
is not None, the given access token is associated with the user
id.
password_hash (str): Optional. The password hash for this user.
was_guest (bool): Optional. Whether this is a guest account being
upgraded to a non-guest account.
@@ -230,16 +251,37 @@ class RegistrationStore(SQLBaseStore):
self.get_user_by_id.invalidate((user_id,))
@defer.inlineCallbacks
def user_delete_access_tokens(self, user_id, except_token_ids=[]):
def f(txn):
sql = "SELECT token FROM access_tokens WHERE user_id = ?"
def user_delete_access_tokens(self, user_id, except_token_ids=[],
device_id=None,
delete_refresh_tokens=False):
"""
Invalidate access/refresh tokens belonging to a user
Args:
user_id (str): ID of user the tokens belong to
except_token_ids (list[str]): list of access_tokens which should
*not* be deleted
device_id (str|None): ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will
be deleted
delete_refresh_tokens (bool): True to delete refresh tokens as
well as access tokens.
Returns:
defer.Deferred:
"""
def f(txn, table, except_tokens, call_after_delete):
sql = "SELECT token FROM %s WHERE user_id = ?" % table
clauses = [user_id]
if except_token_ids:
if device_id is not None:
sql += " AND device_id = ?"
clauses.append(device_id)
if except_tokens:
sql += " AND id NOT IN (%s)" % (
",".join(["?" for _ in except_token_ids]),
",".join(["?" for _ in except_tokens]),
)
clauses += except_token_ids
clauses += except_tokens
txn.execute(sql, clauses)
@@ -248,16 +290,33 @@ class RegistrationStore(SQLBaseStore):
n = 100
chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)]
for chunk in chunks:
for row in chunk:
txn.call_after(self.get_user_by_access_token.invalidate, (row[0],))
if call_after_delete:
for row in chunk:
txn.call_after(call_after_delete, (row[0],))
txn.execute(
"DELETE FROM access_tokens WHERE token in (%s)" % (
"DELETE FROM %s WHERE token in (%s)" % (
table,
",".join(["?" for _ in chunk]),
), [r[0] for r in chunk]
)
yield self.runInteraction("user_delete_access_tokens", f)
# delete refresh tokens first, to stop new access tokens being
# allocated while our backs are turned
if delete_refresh_tokens:
yield self.runInteraction(
"user_delete_access_tokens", f,
table="refresh_tokens",
except_tokens=[],
call_after_delete=None,
)
yield self.runInteraction(
"user_delete_access_tokens", f,
table="access_tokens",
except_tokens=except_token_ids,
call_after_delete=self.get_user_by_access_token.invalidate,
)
def delete_access_token(self, access_token):
def f(txn):
@@ -280,9 +339,8 @@ class RegistrationStore(SQLBaseStore):
Args:
token (str): The access token of a user.
Returns:
dict: Including the name (user_id) and the ID of their access token.
Raises:
StoreError if no user was found.
defer.Deferred: None, if the token did not match, otherwise dict
including the keys `name`, `is_guest`, `device_id`, `token_id`.
"""
return self.runInteraction(
"get_user_by_access_token",
@@ -291,18 +349,18 @@ class RegistrationStore(SQLBaseStore):
)
def exchange_refresh_token(self, refresh_token, token_generator):
"""Exchange a refresh token for a new access token and refresh token.
"""Exchange a refresh token for a new one.
Doing so invalidates the old refresh token - refresh tokens are single
use.
Args:
token (str): The refresh token of a user.
refresh_token (str): The refresh token of a user.
token_generator (fn: str -> str): Function which, when given a
user ID, returns a unique refresh token for that user. This
function must never return the same value twice.
Returns:
tuple of (user_id, refresh_token)
tuple of (user_id, new_refresh_token, device_id)
Raises:
StoreError if no user was found with that refresh token.
"""
@@ -314,12 +372,13 @@ class RegistrationStore(SQLBaseStore):
)
def _exchange_refresh_token(self, txn, old_token, token_generator):
sql = "SELECT user_id FROM refresh_tokens WHERE token = ?"
sql = "SELECT user_id, device_id FROM refresh_tokens WHERE token = ?"
txn.execute(sql, (old_token,))
rows = self.cursor_to_dict(txn)
if not rows:
raise StoreError(403, "Did not recognize refresh token")
user_id = rows[0]["user_id"]
device_id = rows[0]["device_id"]
# TODO(danielwh): Maybe perform a validation on the macaroon that
# macaroon.user_id == user_id.
@@ -328,7 +387,7 @@ class RegistrationStore(SQLBaseStore):
sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?"
txn.execute(sql, (new_token, old_token,))
return user_id, new_token
return user_id, new_token, device_id
@defer.inlineCallbacks
def is_server_admin(self, user):
@@ -356,7 +415,8 @@ class RegistrationStore(SQLBaseStore):
def _query_for_auth(self, txn, token):
sql = (
"SELECT users.name, users.is_guest, access_tokens.id as token_id"
"SELECT users.name, users.is_guest, access_tokens.id as token_id,"
" access_tokens.device_id"
" FROM users"
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
" WHERE token = ?"

View File

@@ -34,108 +34,6 @@ OpsLevel = collections.namedtuple(
class RoomStore(SQLBaseStore):
EVENT_FILES_UPDATE_NAME = "event_files"
FILE_MSGTYPES = (
"m.image",
"m.video",
"m.file",
"m.audio",
)
def __init__(self, hs):
super(RoomStore, self).__init__(hs)
self.register_background_update_handler(
self.EVENT_FILES_UPDATE_NAME, self._background_reindex_files
)
@defer.inlineCallbacks
def _background_reindex_files(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
def reindex_txn(txn):
sql = (
"SELECT topological_ordering, stream_ordering, event_id, room_id,"
" type, content FROM events"
" WHERE ? <= stream_ordering AND stream_ordering < ?"
" AND type = 'm.room.message'"
" AND content LIKE ?"
" ORDER BY stream_ordering DESC"
" LIMIT ?"
)
txn.execute(sql, (target_min_stream_id, max_stream_id, '%url%', batch_size))
rows = self.cursor_to_dict(txn)
if not rows:
return 0
min_stream_id = rows[-1]["stream_ordering"]
event_files_rows = []
for row in rows:
try:
so = row["stream_ordering"]
to = row["topological_ordering"]
event_id = row["event_id"]
room_id = row["room_id"]
try:
content = json.loads(row["content"])
except:
continue
msgtype = content["msgtype"]
if msgtype not in self.FILE_MSGTYPES:
continue
url = content["url"]
if not isinstance(url, basestring):
continue
if not isinstance(msgtype, basestring):
continue
except (KeyError, AttributeError):
# If the event is missing a necessary field then
# skip over it.
continue
event_files_rows.append({
"topological_ordering": to,
"stream_ordering": so,
"event_id": event_id,
"room_id": room_id,
"msgtype": msgtype,
"url": url,
})
self._simple_insert_many_txn(
txn,
table="event_files",
values=event_files_rows,
)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
"max_stream_id_exclusive": min_stream_id,
"rows_inserted": rows_inserted + len(event_files_rows)
}
self._background_update_progress_txn(
txn, self.EVENT_FILES_UPDATE_NAME, progress
)
return len(rows)
result = yield self.runInteraction(
self.EVENT_FILES_UPDATE_NAME, reindex_txn
)
if not result:
yield self._end_background_update(self.EVENT_FILES_UPDATE_NAME)
defer.returnValue(result)
@defer.inlineCallbacks
def store_room(self, room_id, room_creator_user_id, is_public):
@@ -244,22 +142,6 @@ class RoomStore(SQLBaseStore):
)
def _store_room_message_txn(self, txn, event):
msgtype = event.content.get("msgtype")
url = event.content.get("url")
if msgtype in self.FILE_MSGTYPES and url:
self._simple_insert_txn(
txn,
table="event_files",
values={
"topological_ordering": event.depth,
"stream_ordering": event.internal_metadata.stream_ordering,
"room_id": event.room_id,
"event_id": event.event_id,
"msgtype": msgtype,
"url": url,
}
)
if hasattr(event, "content") and "body" in event.content:
self._store_event_search_txn(
txn, event, "content.body", event.content["body"]

View File

@@ -0,0 +1,17 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
INSERT INTO background_updates (update_name, progress_json) VALUES
('access_tokens_device_index', '{}');

View File

@@ -0,0 +1,21 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE devices (
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
display_name TEXT,
CONSTRAINT device_uniqueness UNIQUE (user_id, device_id)
);

View File

@@ -0,0 +1,19 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- make sure that we have a device record for each set of E2E keys, so that the
-- user can delete them if they like.
INSERT INTO devices
SELECT user_id, device_id, 'unknown device' FROM e2e_device_keys_json;

View File

@@ -20,26 +20,14 @@ import ujson
logger = logging.getLogger(__name__)
CREATE_TABLE = """
CREATE TABLE event_files(
topological_ordering BIGINT NOT NULL,
stream_ordering BIGINT NOT NULL,
room_id TEXT NOT NULL,
event_id TEXT NOT NULL,
msgtype TEXT NOT NULL,
url TEXT NOT NULL
);
CREATE INDEX event_files_rm_id ON event_files(room_id, event_id);
CREATE INDEX event_files_order ON event_files(
room_id, topological_ordering, stream_ordering
);
CREATE INDEX event_files_order_stream ON event_files(room_id, stream_ordering);
ALTER_TABLE = """
ALTER TABLE events ADD COLUMN sender TEXT;
ALTER TABLE events ADD COLUMN contains_url BOOLEAN;
"""
def run_create(cur, database_engine, *args, **kwargs):
for statement in get_statements(CREATE_TABLE.splitlines()):
for statement in get_statements(ALTER_TABLE.splitlines()):
cur.execute(statement)
cur.execute("SELECT MIN(stream_ordering) FROM events")
@@ -65,7 +53,7 @@ def run_create(cur, database_engine, *args, **kwargs):
sql = database_engine.convert_param_style(sql)
cur.execute(sql, ("event_files", progress_json))
cur.execute(sql, ("event_fields_sender_url", progress_json))
def run_upgrade(cur, database_engine, *args, **kwargs):

View File

@@ -0,0 +1,16 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE refresh_tokens ADD COLUMN device_id TEXT;

View File

@@ -0,0 +1,17 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
INSERT INTO background_updates (update_name, progress_json) VALUES
('refresh_tokens_device_index', '{}');

View File

@@ -0,0 +1,17 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
INSERT INTO background_updates (update_name, progress_json) VALUES
('user_ips_device_index', '{}');

View File

@@ -95,6 +95,54 @@ def upper_bound(token, engine, inclusive=True):
)
def filter_to_clause(event_filter):
# NB: This may create SQL clauses that don't optimise well (and we don't
# have indices on all possible clauses). E.g. it may create
# "room_id == X AND room_id != X", which postgres doesn't optimise.
if not event_filter:
return "", []
clauses = []
args = []
if event_filter.types:
clauses.append(
"(%s)" % " OR ".join("type = ?" for _ in event_filter.types)
)
args.extend(event_filter.types)
for typ in event_filter.not_types:
clauses.append("type != ?")
args.append(typ)
if event_filter.senders:
clauses.append(
"(%s)" % " OR ".join("sender = ?" for _ in event_filter.senders)
)
args.extend(event_filter.senders)
for sender in event_filter.not_senders:
clauses.append("sender != ?")
args.append(sender)
if event_filter.rooms:
clauses.append(
"(%s)" % " OR ".join("room_id = ?" for _ in event_filter.rooms)
)
args.extend(event_filter.rooms)
for room_id in event_filter.not_rooms:
clauses.append("room_id != ?")
args.append(room_id)
if event_filter.contains_url:
clauses.append("contains_url = ?")
args.append(event_filter.contains_url)
return " AND ".join(clauses), args
class StreamStore(SQLBaseStore):
@defer.inlineCallbacks
def get_appservice_room_stream(self, service, from_key, to_key, limit=0):
@@ -320,7 +368,7 @@ class StreamStore(SQLBaseStore):
@defer.inlineCallbacks
def paginate_room_events(self, room_id, from_key, to_key=None,
direction='b', limit=-1):
direction='b', limit=-1, event_filter=None):
# Tokens really represent positions between elements, but we use
# the convention of pointing to the event before the gap. Hence
# we have a bit of asymmetry when it comes to equalities.
@@ -344,6 +392,12 @@ class StreamStore(SQLBaseStore):
RoomStreamToken.parse(to_key), self.database_engine
))
filter_clause, filter_args = filter_to_clause(event_filter)
if filter_clause:
bounds += " AND " + filter_clause
args.extend(filter_args)
if int(limit) > 0:
args.append(int(limit))
limit_str = " LIMIT ?"
@@ -394,82 +448,6 @@ class StreamStore(SQLBaseStore):
defer.returnValue((events, token))
@defer.inlineCallbacks
def paginate_room_file_events(self, room_id, from_key, to_key=None,
direction='b', limit=-1):
# Tokens really represent positions between elements, but we use
# the convention of pointing to the event before the gap. Hence
# we have a bit of asymmetry when it comes to equalities.
args = [room_id]
if direction == 'b':
order = "DESC"
bounds = upper_bound(
RoomStreamToken.parse(from_key), self.database_engine
)
if to_key:
bounds = "%s AND %s" % (bounds, lower_bound(
RoomStreamToken.parse(to_key), self.database_engine
))
else:
order = "ASC"
bounds = lower_bound(
RoomStreamToken.parse(from_key), self.database_engine
)
if to_key:
bounds = "%s AND %s" % (bounds, upper_bound(
RoomStreamToken.parse(to_key), self.database_engine
))
if int(limit) > 0:
args.append(int(limit))
limit_str = " LIMIT ?"
else:
limit_str = ""
sql = (
"SELECT * FROM event_files"
" WHERE room_id = ? AND %(bounds)s"
" ORDER BY topological_ordering %(order)s,"
" stream_ordering %(order)s %(limit)s"
) % {
"bounds": bounds,
"order": order,
"limit": limit_str
}
def f(txn):
txn.execute(sql, args)
rows = self.cursor_to_dict(txn)
if rows:
topo = rows[-1]["topological_ordering"]
toke = rows[-1]["stream_ordering"]
if direction == 'b':
# Tokens are positions between events.
# This token points *after* the last event in the chunk.
# We need it to point to the event before it in the chunk
# when we are going backwards so we subtract one from the
# stream part.
toke -= 1
next_token = str(RoomStreamToken(topo, toke))
else:
# TODO (erikj): We should work out what to do here instead.
next_token = to_key if to_key else from_key
return rows, next_token,
rows, token = yield self.runInteraction("paginate_file_events", f)
events = yield self._get_events(
[r["event_id"] for r in rows],
get_prev_content=True
)
self._set_before_and_after(events, rows)
defer.returnValue((events, token))
@defer.inlineCallbacks
def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None):
rows, token = yield self.get_recent_event_ids_for_room(

View File

@@ -24,6 +24,7 @@ from collections import namedtuple
import itertools
import logging
import ujson as json
logger = logging.getLogger(__name__)
@@ -101,7 +102,7 @@ class TransactionStore(SQLBaseStore):
)
if result and result["response_code"]:
return result["response_code"], result["response_json"]
return result["response_code"], json.loads(str(result["response_json"]))
else:
return None

View File

@@ -56,7 +56,7 @@ class PaginationConfig(object):
@classmethod
def from_request(cls, request, raise_invalid_params=True,
default_limit=None, default_dir='f'):
default_limit=None):
def get_param(name, default=None):
lst = request.args.get(name, [])
if len(lst) > 1:
@@ -68,7 +68,7 @@ class PaginationConfig(object):
else:
return default
direction = get_param("dir", default_dir)
direction = get_param("dir", 'f')
if direction not in ['f', 'b']:
raise SynapseError(400, "'dir' parameter is invalid.")

View File

@@ -18,7 +18,38 @@ from synapse.api.errors import SynapseError
from collections import namedtuple
Requester = namedtuple("Requester", ["user", "access_token_id", "is_guest"])
Requester = namedtuple("Requester",
["user", "access_token_id", "is_guest", "device_id"])
"""
Represents the user making a request
Attributes:
user (UserID): id of the user making the request
access_token_id (int|None): *ID* of the access token used for this
request, or None if it came via the appservice API or similar
is_guest (bool): True if the user making this request is a guest user
device_id (str|None): device_id which was set at authentication time
"""
def create_requester(user_id, access_token_id=None, is_guest=False,
device_id=None):
"""
Create a new ``Requester`` object
Args:
user_id (str|UserID): id of the user making the request
access_token_id (int|None): *ID* of the access token used for this
request, or None if it came via the appservice API or similar
is_guest (bool): True if the user making this request is a guest user
device_id (str|None): device_id which was set at authentication time
Returns:
Requester
"""
if not isinstance(user_id, UserID):
user_id = UserID.from_string(user_id)
return Requester(user_id, access_token_id, is_guest, device_id)
def get_domain_from_id(string):

View File

@@ -84,7 +84,7 @@ class Measure(object):
if context != self.start_context:
logger.warn(
"Context have unexpectedly changed from '%s' to '%s'. (%r)",
"Context has unexpectedly changed from '%s' to '%s'. (%r)",
context, self.start_context, self.name
)
return

View File

@@ -83,7 +83,10 @@ def calculate_room_name(room_state, user_id, fallback_to_members=True,
):
if ("m.room.member", my_member_event.sender) in room_state:
inviter_member_event = room_state[("m.room.member", my_member_event.sender)]
return "Invite from %s" % (name_from_member_event(inviter_member_event),)
if fallback_to_single_member:
return "Invite from %s" % (name_from_member_event(inviter_member_event),)
else:
return None
else:
return "Room Invite"

View File

@@ -45,6 +45,7 @@ class AuthTestCase(unittest.TestCase):
user_info = {
"name": self.test_user,
"token_id": "ditto",
"device_id": "device",
}
self.store.get_user_by_access_token = Mock(return_value=user_info)
@@ -143,7 +144,10 @@ class AuthTestCase(unittest.TestCase):
# TODO(danielwh): Remove this mock when we remove the
# get_user_by_access_token fallback.
self.store.get_user_by_access_token = Mock(
return_value={"name": "@baldrick:matrix.org"}
return_value={
"name": "@baldrick:matrix.org",
"device_id": "device",
}
)
user_id = "@baldrick:matrix.org"
@@ -158,6 +162,10 @@ class AuthTestCase(unittest.TestCase):
user = user_info["user"]
self.assertEqual(UserID.from_string(user_id), user)
# TODO: device_id should come from the macaroon, but currently comes
# from the db.
self.assertEqual(user_info["device_id"], "device")
@defer.inlineCallbacks
def test_get_guest_user_from_macaroon(self):
user_id = "@baldrick:matrix.org"
@@ -281,7 +289,7 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user,))
macaroon.add_first_party_caveat("time < 1") # ms
macaroon.add_first_party_caveat("time < -2000") # ms
self.hs.clock.now = 5000 # seconds
self.hs.config.expire_access_token = True
@@ -293,3 +301,32 @@ class AuthTestCase(unittest.TestCase):
yield self.auth.get_user_from_macaroon(macaroon.serialize())
self.assertEqual(401, cm.exception.code)
self.assertIn("Invalid macaroon", cm.exception.msg)
@defer.inlineCallbacks
def test_get_user_from_macaroon_with_valid_duration(self):
# TODO(danielwh): Remove this mock when we remove the
# get_user_by_access_token fallback.
self.store.get_user_by_access_token = Mock(
return_value={"name": "@baldrick:matrix.org"}
)
self.store.get_user_by_access_token = Mock(
return_value={"name": "@baldrick:matrix.org"}
)
user_id = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
key=self.hs.config.macaroon_secret_key)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
macaroon.add_first_party_caveat("time < 900000000") # ms
self.hs.clock.now = 5000 # seconds
self.hs.config.expire_access_token = True
user_info = yield self.auth.get_user_from_macaroon(macaroon.serialize())
user = user_info["user"]
self.assertEqual(UserID.from_string(user_id), user)

View File

@@ -0,0 +1,184 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
import synapse.api.errors
import synapse.handlers.device
import synapse.storage
from synapse import types
from tests import unittest, utils
user1 = "@boris:aaa"
user2 = "@theresa:bbb"
class DeviceTestCase(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(DeviceTestCase, self).__init__(*args, **kwargs)
self.store = None # type: synapse.storage.DataStore
self.handler = None # type: synapse.handlers.device.DeviceHandler
self.clock = None # type: utils.MockClock
@defer.inlineCallbacks
def setUp(self):
hs = yield utils.setup_test_homeserver(handlers=None)
self.handler = synapse.handlers.device.DeviceHandler(hs)
self.store = hs.get_datastore()
self.clock = hs.get_clock()
@defer.inlineCallbacks
def test_device_is_created_if_doesnt_exist(self):
res = yield self.handler.check_device_registered(
user_id="boris",
device_id="fco",
initial_device_display_name="display name"
)
self.assertEqual(res, "fco")
dev = yield self.handler.store.get_device("boris", "fco")
self.assertEqual(dev["display_name"], "display name")
@defer.inlineCallbacks
def test_device_is_preserved_if_exists(self):
res1 = yield self.handler.check_device_registered(
user_id="boris",
device_id="fco",
initial_device_display_name="display name"
)
self.assertEqual(res1, "fco")
res2 = yield self.handler.check_device_registered(
user_id="boris",
device_id="fco",
initial_device_display_name="new display name"
)
self.assertEqual(res2, "fco")
dev = yield self.handler.store.get_device("boris", "fco")
self.assertEqual(dev["display_name"], "display name")
@defer.inlineCallbacks
def test_device_id_is_made_up_if_unspecified(self):
device_id = yield self.handler.check_device_registered(
user_id="theresa",
device_id=None,
initial_device_display_name="display"
)
dev = yield self.handler.store.get_device("theresa", device_id)
self.assertEqual(dev["display_name"], "display")
@defer.inlineCallbacks
def test_get_devices_by_user(self):
yield self._record_users()
res = yield self.handler.get_devices_by_user(user1)
self.assertEqual(3, len(res))
device_map = {
d["device_id"]: d for d in res
}
self.assertDictContainsSubset({
"user_id": user1,
"device_id": "xyz",
"display_name": "display 0",
"last_seen_ip": None,
"last_seen_ts": None,
}, device_map["xyz"])
self.assertDictContainsSubset({
"user_id": user1,
"device_id": "fco",
"display_name": "display 1",
"last_seen_ip": "ip1",
"last_seen_ts": 1000000,
}, device_map["fco"])
self.assertDictContainsSubset({
"user_id": user1,
"device_id": "abc",
"display_name": "display 2",
"last_seen_ip": "ip3",
"last_seen_ts": 3000000,
}, device_map["abc"])
@defer.inlineCallbacks
def test_get_device(self):
yield self._record_users()
res = yield self.handler.get_device(user1, "abc")
self.assertDictContainsSubset({
"user_id": user1,
"device_id": "abc",
"display_name": "display 2",
"last_seen_ip": "ip3",
"last_seen_ts": 3000000,
}, res)
@defer.inlineCallbacks
def test_delete_device(self):
yield self._record_users()
# delete the device
yield self.handler.delete_device(user1, "abc")
# check the device was deleted
with self.assertRaises(synapse.api.errors.NotFoundError):
yield self.handler.get_device(user1, "abc")
# we'd like to check the access token was invalidated, but that's a
# bit of a PITA.
@defer.inlineCallbacks
def test_update_device(self):
yield self._record_users()
update = {"display_name": "new display"}
yield self.handler.update_device(user1, "abc", update)
res = yield self.handler.get_device(user1, "abc")
self.assertEqual(res["display_name"], "new display")
@defer.inlineCallbacks
def test_update_unknown_device(self):
update = {"display_name": "new_display"}
with self.assertRaises(synapse.api.errors.NotFoundError):
yield self.handler.update_device("user_id", "unknown_device_id",
update)
@defer.inlineCallbacks
def _record_users(self):
# check this works for both devices which have a recorded client_ip,
# and those which don't.
yield self._record_user(user1, "xyz", "display 0")
yield self._record_user(user1, "fco", "display 1", "token1", "ip1")
yield self._record_user(user1, "abc", "display 2", "token2", "ip2")
yield self._record_user(user1, "abc", "display 2", "token3", "ip3")
yield self._record_user(user2, "def", "dispkay", "token4", "ip4")
@defer.inlineCallbacks
def _record_user(self, user_id, device_id, display_name,
access_token=None, ip=None):
device_id = yield self.handler.check_device_registered(
user_id=user_id,
device_id=device_id,
initial_device_display_name=display_name
)
if ip is not None:
yield self.store.insert_client_ip(
types.UserID.from_string(user_id),
access_token, ip, "user_agent", device_id)
self.clock.advance_time(1000)

View File

@@ -19,11 +19,12 @@ from twisted.internet import defer
from mock import Mock, NonCallableMock
import synapse.types
from synapse.api.errors import AuthError
from synapse.handlers.profile import ProfileHandler
from synapse.types import UserID
from tests.utils import setup_test_homeserver, requester_for_user
from tests.utils import setup_test_homeserver
class ProfileHandlers(object):
@@ -86,7 +87,7 @@ class ProfileTestCase(unittest.TestCase):
def test_set_my_name(self):
yield self.handler.set_displayname(
self.frank,
requester_for_user(self.frank),
synapse.types.create_requester(self.frank),
"Frank Jr."
)
@@ -99,7 +100,7 @@ class ProfileTestCase(unittest.TestCase):
def test_set_my_name_noauth(self):
d = self.handler.set_displayname(
self.frank,
requester_for_user(self.bob),
synapse.types.create_requester(self.bob),
"Frank Jr."
)
@@ -144,7 +145,8 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_set_my_avatar(self):
yield self.handler.set_avatar_url(
self.frank, requester_for_user(self.frank), "http://my.server/pic.gif"
self.frank, synapse.types.create_requester(self.frank),
"http://my.server/pic.gif"
)
self.assertEquals(

View File

@@ -42,12 +42,12 @@ class RegistrationTestCase(unittest.TestCase):
http_client=None,
expire_access_token=True)
self.auth_handler = Mock(
generate_short_term_login_token=Mock(return_value='secret'))
generate_access_token=Mock(return_value='secret'))
self.hs.handlers = RegistrationHandlers(self.hs)
self.handler = self.hs.get_handlers().registration_handler
self.hs.get_handlers().profile_handler = Mock()
self.mock_handler = Mock(spec=[
"generate_short_term_login_token",
"generate_access_token",
])
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)

View File

@@ -13,15 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.replication.resource import ReplicationResource
from synapse.types import Requester, UserID
from twisted.internet import defer
from tests import unittest
from tests.utils import setup_test_homeserver, requester_for_user
from mock import Mock, NonCallableMock
import json
import contextlib
import json
from mock import Mock, NonCallableMock
from twisted.internet import defer
import synapse.types
from synapse.replication.resource import ReplicationResource
from synapse.types import UserID
from tests import unittest
from tests.utils import setup_test_homeserver
class ReplicationResourceCase(unittest.TestCase):
@@ -61,7 +63,7 @@ class ReplicationResourceCase(unittest.TestCase):
def test_events_and_state(self):
get = self.get(events="-1", state="-1", timeout="0")
yield self.hs.get_handlers().room_creation_handler.create_room(
Requester(self.user, "", False), {}
synapse.types.create_requester(self.user), {}
)
code, body = yield get
self.assertEquals(code, 200)
@@ -144,7 +146,7 @@ class ReplicationResourceCase(unittest.TestCase):
def send_text_message(self, room_id, message):
handler = self.hs.get_handlers().message_handler
event = yield handler.create_and_send_nonmember_event(
requester_for_user(self.user),
synapse.types.create_requester(self.user),
{
"type": "m.room.message",
"content": {"body": "message", "msgtype": "m.text"},
@@ -157,7 +159,7 @@ class ReplicationResourceCase(unittest.TestCase):
@defer.inlineCallbacks
def create_room(self):
result = yield self.hs.get_handlers().room_creation_handler.create_room(
Requester(self.user, "", False), {}
synapse.types.create_requester(self.user), {}
)
defer.returnValue(result["room_id"])

View File

@@ -14,17 +14,14 @@
# limitations under the License.
"""Tests REST events for /profile paths."""
from tests import unittest
from mock import Mock
from twisted.internet import defer
from mock import Mock
from ....utils import MockHttpResource, setup_test_homeserver
import synapse.types
from synapse.api.errors import SynapseError, AuthError
from synapse.types import Requester, UserID
from synapse.rest.client.v1 import profile
from tests import unittest
from ....utils import MockHttpResource, setup_test_homeserver
myid = "@1234ABCD:test"
PATH_PREFIX = "/_matrix/client/api/v1"
@@ -52,7 +49,7 @@ class ProfileTestCase(unittest.TestCase):
)
def _get_user_by_req(request=None, allow_guest=False):
return Requester(UserID.from_string(myid), "", False)
return synapse.types.create_requester(myid)
hs.get_v1auth().get_user_by_req = _get_user_by_req

View File

@@ -30,6 +30,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.registration_handler = Mock()
self.identity_handler = Mock()
self.login_handler = Mock()
self.device_handler = Mock()
# do the dance to hook it up to the hs global
self.handlers = Mock(
@@ -42,6 +43,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.hs.get_auth = Mock(return_value=self.auth)
self.hs.get_handlers = Mock(return_value=self.handlers)
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
self.hs.get_device_handler = Mock(return_value=self.device_handler)
self.hs.config.enable_registration = True
# init the thing we're testing
@@ -61,13 +63,18 @@ class RegisterRestServletTestCase(unittest.TestCase):
"id": "1234"
}
self.registration_handler.appservice_register = Mock(
return_value=(user_id, token)
return_value=user_id
)
self.auth_handler.get_login_tuple_for_user_id = Mock(
return_value=(token, "kermits_refresh_token")
)
(code, result) = yield self.servlet.on_POST(self.request)
self.assertEquals(code, 200)
det_data = {
"user_id": user_id,
"access_token": token,
"refresh_token": "kermits_refresh_token",
"home_server": self.hs.hostname
}
self.assertDictContainsSubset(det_data, result)
@@ -105,26 +112,37 @@ class RegisterRestServletTestCase(unittest.TestCase):
def test_POST_user_valid(self):
user_id = "@kermit:muppet"
token = "kermits_access_token"
device_id = "frogfone"
self.request_data = json.dumps({
"username": "kermit",
"password": "monkey"
"password": "monkey",
"device_id": device_id,
})
self.registration_handler.check_username = Mock(return_value=True)
self.auth_result = (True, None, {
"username": "kermit",
"password": "monkey"
}, None)
self.registration_handler.register = Mock(return_value=(user_id, token))
self.registration_handler.register = Mock(return_value=(user_id, None))
self.auth_handler.get_login_tuple_for_user_id = Mock(
return_value=(token, "kermits_refresh_token")
)
self.device_handler.check_device_registered = \
Mock(return_value=device_id)
(code, result) = yield self.servlet.on_POST(self.request)
self.assertEquals(code, 200)
det_data = {
"user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname
"refresh_token": "kermits_refresh_token",
"home_server": self.hs.hostname,
"device_id": device_id,
}
self.assertDictContainsSubset(det_data, result)
self.assertIn("refresh_token", result)
self.auth_handler.get_login_tuple_for_user_id(
user_id, device_id=device_id, initial_device_display_name=None)
def test_POST_disabled_registration(self):
self.hs.config.enable_registration = False

View File

@@ -30,6 +30,7 @@ class EventInjector:
def create_room(self, room):
builder = self.event_builder_factory.new({
"type": EventTypes.Create,
"sender": "",
"room_id": room.to_string(),
"content": {},
})

View File

@@ -10,7 +10,7 @@ class BackgroundUpdateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
hs = yield setup_test_homeserver()
hs = yield setup_test_homeserver() # type: synapse.server.HomeServer
self.store = hs.get_datastore()
self.clock = hs.get_clock()
@@ -20,11 +20,20 @@ class BackgroundUpdateTestCase(unittest.TestCase):
"test_update", self.update_handler
)
# run the real background updates, to get them out the way
# (perhaps we should run them as part of the test HS setup, since we
# run all of the other schema setup stuff there?)
while True:
res = yield self.store.do_next_background_update(1000)
if res is None:
break
@defer.inlineCallbacks
def test_do_background_update(self):
desired_count = 1000
duration_ms = 42
# first step: make a bit of progress
@defer.inlineCallbacks
def update(progress, count):
self.clock.advance_time_msec(count * duration_ms)
@@ -42,7 +51,7 @@ class BackgroundUpdateTestCase(unittest.TestCase):
yield self.store.start_background_update("test_update", {"my_key": 1})
self.update_handler.reset_mock()
result = yield self.store.do_background_update(
result = yield self.store.do_next_background_update(
duration_ms * desired_count
)
self.assertIsNotNone(result)
@@ -50,15 +59,15 @@ class BackgroundUpdateTestCase(unittest.TestCase):
{"my_key": 1}, self.store.DEFAULT_BACKGROUND_BATCH_SIZE
)
# second step: complete the update
@defer.inlineCallbacks
def update(progress, count):
yield self.store._end_background_update("test_update")
defer.returnValue(count)
self.update_handler.side_effect = update
self.update_handler.reset_mock()
result = yield self.store.do_background_update(
result = yield self.store.do_next_background_update(
duration_ms * desired_count
)
self.assertIsNotNone(result)
@@ -66,8 +75,9 @@ class BackgroundUpdateTestCase(unittest.TestCase):
{"my_key": 2}, desired_count
)
# third step: we don't expect to be called any more
self.update_handler.reset_mock()
result = yield self.store.do_background_update(
result = yield self.store.do_next_background_update(
duration_ms * desired_count
)
self.assertIsNone(result)

View File

@@ -0,0 +1,62 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
import synapse.server
import synapse.storage
import synapse.types
import tests.unittest
import tests.utils
class ClientIpStoreTestCase(tests.unittest.TestCase):
def __init__(self, *args, **kwargs):
super(ClientIpStoreTestCase, self).__init__(*args, **kwargs)
self.store = None # type: synapse.storage.DataStore
self.clock = None # type: tests.utils.MockClock
@defer.inlineCallbacks
def setUp(self):
hs = yield tests.utils.setup_test_homeserver()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
@defer.inlineCallbacks
def test_insert_new_client_ip(self):
self.clock.now = 12345678
user_id = "@user:id"
yield self.store.insert_client_ip(
synapse.types.UserID.from_string(user_id),
"access_token", "ip", "user_agent", "device_id",
)
# deliberately use an iterable here to make sure that the lookup
# method doesn't iterate it twice
device_list = iter(((user_id, "device_id"),))
result = yield self.store.get_last_client_ip_by_device(device_list)
r = result[(user_id, "device_id")]
self.assertDictContainsSubset(
{
"user_id": user_id,
"device_id": "device_id",
"access_token": "access_token",
"ip": "ip",
"user_agent": "user_agent",
"last_seen": 12345678000,
},
r
)

View File

@@ -0,0 +1,105 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
import synapse.api.errors
import tests.unittest
import tests.utils
class DeviceStoreTestCase(tests.unittest.TestCase):
def __init__(self, *args, **kwargs):
super(DeviceStoreTestCase, self).__init__(*args, **kwargs)
self.store = None # type: synapse.storage.DataStore
@defer.inlineCallbacks
def setUp(self):
hs = yield tests.utils.setup_test_homeserver()
self.store = hs.get_datastore()
@defer.inlineCallbacks
def test_store_new_device(self):
yield self.store.store_device(
"user_id", "device_id", "display_name"
)
res = yield self.store.get_device("user_id", "device_id")
self.assertDictContainsSubset({
"user_id": "user_id",
"device_id": "device_id",
"display_name": "display_name",
}, res)
@defer.inlineCallbacks
def test_get_devices_by_user(self):
yield self.store.store_device(
"user_id", "device1", "display_name 1"
)
yield self.store.store_device(
"user_id", "device2", "display_name 2"
)
yield self.store.store_device(
"user_id2", "device3", "display_name 3"
)
res = yield self.store.get_devices_by_user("user_id")
self.assertEqual(2, len(res.keys()))
self.assertDictContainsSubset({
"user_id": "user_id",
"device_id": "device1",
"display_name": "display_name 1",
}, res["device1"])
self.assertDictContainsSubset({
"user_id": "user_id",
"device_id": "device2",
"display_name": "display_name 2",
}, res["device2"])
@defer.inlineCallbacks
def test_update_device(self):
yield self.store.store_device(
"user_id", "device_id", "display_name 1"
)
res = yield self.store.get_device("user_id", "device_id")
self.assertEqual("display_name 1", res["display_name"])
# do a no-op first
yield self.store.update_device(
"user_id", "device_id",
)
res = yield self.store.get_device("user_id", "device_id")
self.assertEqual("display_name 1", res["display_name"])
# do the update
yield self.store.update_device(
"user_id", "device_id",
new_display_name="display_name 2",
)
# check it worked
res = yield self.store.get_device("user_id", "device_id")
self.assertEqual("display_name 2", res["display_name"])
@defer.inlineCallbacks
def test_update_unknown_device(self):
with self.assertRaises(synapse.api.errors.StoreError) as cm:
yield self.store.update_device(
"user_id", "unknown_device_id",
new_display_name="display_name 2",
)
self.assertEqual(404, cm.exception.code)

View File

@@ -37,7 +37,7 @@ class EventsStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_count_daily_messages(self):
self.db_pool.runQuery("DELETE FROM stats_reporting")
yield self.db_pool.runQuery("DELETE FROM stats_reporting")
self.hs.clock.now = 100
@@ -60,7 +60,7 @@ class EventsStoreTestCase(unittest.TestCase):
# it isn't old enough.
count = yield self.store.count_daily_messages()
self.assertIsNone(count)
self._assert_stats_reporting(1, self.hs.clock.now)
yield self._assert_stats_reporting(1, self.hs.clock.now)
# Already reported yesterday, two new events from today.
yield self.event_injector.inject_message(room, user, "Yeah they are!")
@@ -68,21 +68,21 @@ class EventsStoreTestCase(unittest.TestCase):
self.hs.clock.now += 60 * 60 * 24
count = yield self.store.count_daily_messages()
self.assertEqual(2, count) # 2 since yesterday
self._assert_stats_reporting(3, self.hs.clock.now) # 3 ever
yield self._assert_stats_reporting(3, self.hs.clock.now) # 3 ever
# Last reported too recently.
yield self.event_injector.inject_message(room, user, "Who could disagree?")
self.hs.clock.now += 60 * 60 * 22
count = yield self.store.count_daily_messages()
self.assertIsNone(count)
self._assert_stats_reporting(4, self.hs.clock.now)
yield self._assert_stats_reporting(4, self.hs.clock.now)
# Last reported too long ago
yield self.event_injector.inject_message(room, user, "No one.")
self.hs.clock.now += 60 * 60 * 26
count = yield self.store.count_daily_messages()
self.assertIsNone(count)
self._assert_stats_reporting(5, self.hs.clock.now)
yield self._assert_stats_reporting(5, self.hs.clock.now)
# And now let's actually report something
yield self.event_injector.inject_message(room, user, "Indeed.")
@@ -92,7 +92,7 @@ class EventsStoreTestCase(unittest.TestCase):
self.hs.clock.now += (60 * 60 * 24) + 50
count = yield self.store.count_daily_messages()
self.assertEqual(3, count)
self._assert_stats_reporting(8, self.hs.clock.now)
yield self._assert_stats_reporting(8, self.hs.clock.now)
@defer.inlineCallbacks
def _get_last_stream_token(self):

View File

@@ -38,6 +38,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
"BcDeFgHiJkLmNoPqRsTuVwXyZa"
]
self.pwhash = "{xx1}123456789"
self.device_id = "akgjhdjklgshg"
@defer.inlineCallbacks
def test_register(self):
@@ -64,13 +65,15 @@ class RegistrationStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_add_tokens(self):
yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
yield self.store.add_access_token_to_user(self.user_id, self.tokens[1])
yield self.store.add_access_token_to_user(self.user_id, self.tokens[1],
self.device_id)
result = yield self.store.get_user_by_access_token(self.tokens[1])
self.assertDictContainsSubset(
{
"name": self.user_id,
"device_id": self.device_id,
},
result
)
@@ -80,20 +83,24 @@ class RegistrationStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_exchange_refresh_token_valid(self):
uid = stringutils.random_string(32)
device_id = stringutils.random_string(16)
generator = TokenGenerator()
last_token = generator.generate(uid)
self.db_pool.runQuery(
"INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)",
(uid, last_token,))
"INSERT INTO refresh_tokens(user_id, token, device_id) "
"VALUES(?,?,?)",
(uid, last_token, device_id))
(found_user_id, refresh_token) = yield self.store.exchange_refresh_token(
last_token, generator.generate)
(found_user_id, refresh_token, device_id) = \
yield self.store.exchange_refresh_token(last_token,
generator.generate)
self.assertEqual(uid, found_user_id)
rows = yield self.db_pool.runQuery(
"SELECT token FROM refresh_tokens WHERE user_id = ?", (uid, ))
self.assertEqual([(refresh_token,)], rows)
"SELECT token, device_id FROM refresh_tokens WHERE user_id = ?",
(uid, ))
self.assertEqual([(refresh_token, device_id)], rows)
# We issued token 1, then exchanged it for token 2
expected_refresh_token = u"%s-%d" % (uid, 2,)
self.assertEqual(expected_refresh_token, refresh_token)
@@ -121,6 +128,40 @@ class RegistrationStoreTestCase(unittest.TestCase):
with self.assertRaises(StoreError):
yield self.store.exchange_refresh_token(last_token, generator.generate)
@defer.inlineCallbacks
def test_user_delete_access_tokens(self):
# add some tokens
generator = TokenGenerator()
refresh_token = generator.generate(self.user_id)
yield self.store.register(self.user_id, self.tokens[0], self.pwhash)
yield self.store.add_access_token_to_user(self.user_id, self.tokens[1],
self.device_id)
yield self.store.add_refresh_token_to_user(self.user_id, refresh_token,
self.device_id)
# now delete some
yield self.store.user_delete_access_tokens(
self.user_id, device_id=self.device_id, delete_refresh_tokens=True)
# check they were deleted
user = yield self.store.get_user_by_access_token(self.tokens[1])
self.assertIsNone(user, "access token was not deleted by device_id")
with self.assertRaises(StoreError):
yield self.store.exchange_refresh_token(refresh_token,
generator.generate)
# check the one not associated with the device was not deleted
user = yield self.store.get_user_by_access_token(self.tokens[0])
self.assertEqual(self.user_id, user["name"])
# now delete the rest
yield self.store.user_delete_access_tokens(
self.user_id, delete_refresh_tokens=True)
user = yield self.store.get_user_by_access_token(self.tokens[0])
self.assertIsNone(user,
"access token was not deleted without device_id")
class TokenGenerator:
def __init__(self):

View File

@@ -17,13 +17,18 @@ from twisted.trial import unittest
import logging
# logging doesn't have a "don't log anything at all EVARRRR setting,
# but since the highest value is 50, 1000000 should do ;)
NEVER = 1000000
logging.getLogger().addHandler(logging.StreamHandler())
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter(
"%(levelname)s:%(name)s:%(message)s [%(pathname)s:%(lineno)d]"
))
logging.getLogger().addHandler(handler)
logging.getLogger().setLevel(NEVER)
logging.getLogger("synapse.storage.SQL").setLevel(NEVER)
logging.getLogger("synapse.storage.txn").setLevel(NEVER)
def around(target):
@@ -70,8 +75,6 @@ class TestCase(unittest.TestCase):
return ret
logging.getLogger().setLevel(level)
# Don't set SQL logging
logging.getLogger("synapse.storage").setLevel(old_level)
return orig()
def assertObjectHasAttributes(self, attrs, obj):

View File

@@ -20,7 +20,6 @@ from synapse.storage.prepare_database import prepare_database
from synapse.storage.engines import create_engine
from synapse.server import HomeServer
from synapse.federation.transport import server
from synapse.types import Requester
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.logcontext import LoggingContext
@@ -512,7 +511,3 @@ class DeferredMockCallable(object):
"call(%s)" % _format_call(c[0], c[1]) for c in calls
])
)
def requester_for_user(user):
return Requester(user, None, False)