mirror of
https://github.com/element-hq/synapse.git
synced 2025-12-09 01:30:18 +00:00
Compare commits
420 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d330d45e2d | ||
|
|
ccd4290777 | ||
|
|
9cb42507f8 | ||
|
|
d960910c72 | ||
|
|
b46b8a5efb | ||
|
|
27fe3e2d4f | ||
|
|
3410142741 | ||
|
|
e7674eb759 | ||
|
|
7c1a92274c | ||
|
|
f5deaff424 | ||
|
|
5f360182c6 | ||
|
|
46453bfc2f | ||
|
|
6bf6bc1d1d | ||
|
|
f45be05305 | ||
|
|
24f36469bc | ||
|
|
597c79be10 | ||
|
|
e6021c370e | ||
|
|
a8b946decb | ||
|
|
3f5ac150b2 | ||
|
|
5bcccfde6c | ||
|
|
c95dd7a426 | ||
|
|
4d87d3659a | ||
|
|
32fc39fd4c | ||
|
|
93acf49e9b | ||
|
|
b2c290a6e5 | ||
|
|
499dc1b349 | ||
|
|
9377509dcd | ||
|
|
2d4de61fb7 | ||
|
|
87ef315ad5 | ||
|
|
fccadb7719 | ||
|
|
f0fa66f495 | ||
|
|
1515d1b581 | ||
|
|
a2b7102eea | ||
|
|
a5d7968b3e | ||
|
|
b5525c76d1 | ||
|
|
e97648c4e2 | ||
|
|
835ceeee76 | ||
|
|
b3682df2ca | ||
|
|
8ad8490cff | ||
|
|
59fa91fe88 | ||
|
|
1b5436ad78 | ||
|
|
257c41cc2e | ||
|
|
b4e2290d89 | ||
|
|
e3ee63578f | ||
|
|
ea0b767114 | ||
|
|
ab03912e94 | ||
|
|
1fc50712d6 | ||
|
|
7c7786d4e1 | ||
|
|
b0a14bf53e | ||
|
|
f131cd9e53 | ||
|
|
edb33eb163 | ||
|
|
bcc9cda8ca | ||
|
|
05e3354047 | ||
|
|
3364d96c6b | ||
|
|
98385888b8 | ||
|
|
68264d7404 | ||
|
|
91fa69e029 | ||
|
|
4c56bedee3 | ||
|
|
520ee9bd2c | ||
|
|
a60a2eaa02 | ||
|
|
e3a720217a | ||
|
|
530bc862dc | ||
|
|
a6f5cc65d9 | ||
|
|
e555bc6551 | ||
|
|
a843868fe9 | ||
|
|
f5da3bacb2 | ||
|
|
97f072db74 | ||
|
|
80ad710217 | ||
|
|
4fec5e57be | ||
|
|
a8a32d2714 | ||
|
|
921f17f938 | ||
|
|
9a2f296fa2 | ||
|
|
58c9653c6b | ||
|
|
6b58ade2f0 | ||
|
|
9e66c58ceb | ||
|
|
f83f5fbce8 | ||
|
|
aecaec3e10 | ||
|
|
1efee2f52b | ||
|
|
49e047c55e | ||
|
|
59a2c6d60e | ||
|
|
06f812b95c | ||
|
|
c9154b970c | ||
|
|
b3d5c4ad9d | ||
|
|
456544b621 | ||
|
|
d199f2248f | ||
|
|
8f650bd338 | ||
|
|
7b0f6293f2 | ||
|
|
fcde5b2a97 | ||
|
|
342e072024 | ||
|
|
55e8a87888 | ||
|
|
26a8c1d7ab | ||
|
|
54de6a812a | ||
|
|
733bf44290 | ||
|
|
f14cd95342 | ||
|
|
986615b0b2 | ||
|
|
bfeaab6dfc | ||
|
|
b260f92936 | ||
|
|
271d3e7865 | ||
|
|
18b7eb830b | ||
|
|
74106ba171 | ||
|
|
2a9ce8c422 | ||
|
|
cbea0c7044 | ||
|
|
5aa024e501 | ||
|
|
c51a52f300 | ||
|
|
3d13c3a295 | ||
|
|
a679a01dbe | ||
|
|
8dad08a950 | ||
|
|
0a7d3cd00f | ||
|
|
ec8b217722 | ||
|
|
328ad6901d | ||
|
|
76b89d0edb | ||
|
|
370135ad0b | ||
|
|
0fcbca531f | ||
|
|
1e2740caab | ||
|
|
6ede23ff1b | ||
|
|
591ad2268c | ||
|
|
3c3246c078 | ||
|
|
5fb41a955c | ||
|
|
367b594183 | ||
|
|
7861cfec0a | ||
|
|
019cf013d6 | ||
|
|
4329f20e3f | ||
|
|
a285194021 | ||
|
|
bf81e38d36 | ||
|
|
78cac3e594 | ||
|
|
7871790db1 | ||
|
|
18e044628e | ||
|
|
b557b682d9 | ||
|
|
389c890f14 | ||
|
|
cd8738ab63 | ||
|
|
f6f8f81a48 | ||
|
|
ecd5e6bfa4 | ||
|
|
fda078f995 | ||
|
|
ab5580e152 | ||
|
|
e8d212d92e | ||
|
|
40e539683c | ||
|
|
05f6447301 | ||
|
|
5238960850 | ||
|
|
ccec25e2c6 | ||
|
|
c38b7c4104 | ||
|
|
29b25d59c6 | ||
|
|
884b800899 | ||
|
|
09d31815b4 | ||
|
|
fe1b369946 | ||
|
|
26cb0efa88 | ||
|
|
d47115ff8b | ||
|
|
2e3d90d67c | ||
|
|
a4b06b619c | ||
|
|
c63b1697f4 | ||
|
|
87ffd21b29 | ||
|
|
2452611d0f | ||
|
|
eb359eced4 | ||
|
|
c824b29e77 | ||
|
|
33d7776473 | ||
|
|
9ad8d9b17c | ||
|
|
5b1825ba5b | ||
|
|
9c4cf83259 | ||
|
|
05e7e5e972 | ||
|
|
db4f823d34 | ||
|
|
8e02494166 | ||
|
|
a6f06ce3e2 | ||
|
|
d34e9f93b7 | ||
|
|
efeb6176c1 | ||
|
|
1a54513cf1 | ||
|
|
242c52d607 | ||
|
|
012b4c1913 | ||
|
|
436bffd15f | ||
|
|
1b3c3e6d68 | ||
|
|
33d08e8433 | ||
|
|
8f7f4cb92b | ||
|
|
2623cec874 | ||
|
|
4fcdf7b4b2 | ||
|
|
955ef1f06c | ||
|
|
2ee4c9ee02 | ||
|
|
9dbd903f41 | ||
|
|
bf3de7b90b | ||
|
|
e73ad8de3b | ||
|
|
42f4feb2b7 | ||
|
|
f16f0e169d | ||
|
|
465117d7ca | ||
|
|
7ed58bb347 | ||
|
|
dad2da7e54 | ||
|
|
363786845b | ||
|
|
ec5717caf5 | ||
|
|
d26b660aa6 | ||
|
|
aede7248ab | ||
|
|
68a92afcff | ||
|
|
55abbe1850 | ||
|
|
2c28e25bda | ||
|
|
1e6e370b76 | ||
|
|
1c3c202b96 | ||
|
|
406f7aa0f6 | ||
|
|
34f56b40fd | ||
|
|
c445f5fec7 | ||
|
|
44adde498e | ||
|
|
cf94a78872 | ||
|
|
1a64dffb00 | ||
|
|
081e5d55e6 | ||
|
|
248e6770ca | ||
|
|
40a1c96617 | ||
|
|
7314bf4682 | ||
|
|
e9e3eaa67d | ||
|
|
d36b1d849d | ||
|
|
742056be0d | ||
|
|
bc8f265f0a | ||
|
|
ec041b335e | ||
|
|
053e83dafb | ||
|
|
b97a1356b1 | ||
|
|
b73dc0ef4d | ||
|
|
499e3281e6 | ||
|
|
66868119dc | ||
|
|
aba0b2a39b | ||
|
|
57dca35692 | ||
|
|
c68518dfbb | ||
|
|
e967bc86e7 | ||
|
|
1e2a7f18a1 | ||
|
|
f91faf09b3 | ||
|
|
4430b1ceb3 | ||
|
|
3413f1e284 | ||
|
|
40cbffb2d2 | ||
|
|
b9e997f561 | ||
|
|
9a7a77a22a | ||
|
|
8f6281ab0c | ||
|
|
0da0d0a29d | ||
|
|
022b9176fe | ||
|
|
0c62c958fd | ||
|
|
c41d52a042 | ||
|
|
7e554aac86 | ||
|
|
f863a52cea | ||
|
|
93efcb8526 | ||
|
|
dcfd71aa4c | ||
|
|
fca90b3445 | ||
|
|
a292454aa1 | ||
|
|
4f81edbd4f | ||
|
|
6344db659f | ||
|
|
511a52afc8 | ||
|
|
e885e2a623 | ||
|
|
d137e03231 | ||
|
|
f52565de50 | ||
|
|
a2d288c6a9 | ||
|
|
bd7c51921d | ||
|
|
978fa53cc2 | ||
|
|
eec9609e96 | ||
|
|
9e1b43bcbf | ||
|
|
a3036ac37e | ||
|
|
ebdafd8114 | ||
|
|
a98d215204 | ||
|
|
d554ca5e1d | ||
|
|
209e04fa11 | ||
|
|
e5142f65a6 | ||
|
|
b64aa6d687 | ||
|
|
848d3bf2e1 | ||
|
|
b55c770271 | ||
|
|
d543b72562 | ||
|
|
0136a522b1 | ||
|
|
2cb758ac75 | ||
|
|
560c71c735 | ||
|
|
a37ee2293c | ||
|
|
c55ad2e375 | ||
|
|
aaa9d9f0e1 | ||
|
|
75fa7f6b3c | ||
|
|
a5db0026ed | ||
|
|
9c491366c5 | ||
|
|
385aec4010 | ||
|
|
10f4856b0c | ||
|
|
dfde67a6fe | ||
|
|
10c843fcfb | ||
|
|
58930da52b | ||
|
|
0870588c20 | ||
|
|
f90cf150e2 | ||
|
|
067596d341 | ||
|
|
70d650be2b | ||
|
|
b92e7955be | ||
|
|
c98e1479bd | ||
|
|
67f2c901ea | ||
|
|
eef7778af9 | ||
|
|
a17e7caeb7 | ||
|
|
f0c06ac65c | ||
|
|
76b18df3d9 | ||
|
|
0da24cac8b | ||
|
|
2e3c8acc68 | ||
|
|
8d9a884cee | ||
|
|
896bc6cd46 | ||
|
|
be3548f7e1 | ||
|
|
4adf93e0f7 | ||
|
|
651faee698 | ||
|
|
caf33b2d9b | ||
|
|
8f8798bc0d | ||
|
|
7335f0adda | ||
|
|
ef535178ff | ||
|
|
04dee11e97 | ||
|
|
dd2ccee27d | ||
|
|
b6b0132ac7 | ||
|
|
e34cb5e7dc | ||
|
|
252ee2d979 | ||
|
|
14362bf359 | ||
|
|
1ee2584307 | ||
|
|
507b8bb091 | ||
|
|
d44d11d864 | ||
|
|
2d21d43c34 | ||
|
|
0fb76c71ac | ||
|
|
8bdaf5f7af | ||
|
|
a67bf0b074 | ||
|
|
f18d7546c6 | ||
|
|
3de8168343 | ||
|
|
bb069079bb | ||
|
|
2e5a31f197 | ||
|
|
fc8007dbec | ||
|
|
1238203bc4 | ||
|
|
41f072fd0e | ||
|
|
5a6ef20ef6 | ||
|
|
be8be535f7 | ||
|
|
ab71589c0b | ||
|
|
f328d95cef | ||
|
|
aac546c978 | ||
|
|
f52cb4cd78 | ||
|
|
6783534a0f | ||
|
|
a70688445d | ||
|
|
314b146b2e | ||
|
|
7846ac3125 | ||
|
|
56ec5869c9 | ||
|
|
1ea358b28b | ||
|
|
db74dcda5b | ||
|
|
551fe80bed | ||
|
|
63bb8f0df9 | ||
|
|
70d820c875 | ||
|
|
3f7652c56f | ||
|
|
535b6bfacc | ||
|
|
f7fe0e5f67 | ||
|
|
2455ad8468 | ||
|
|
0b640aa56b | ||
|
|
aa3a4944d5 | ||
|
|
46b7362304 | ||
|
|
870c45913e | ||
|
|
05f1a4596a | ||
|
|
13517e2914 | ||
|
|
b5fb7458d5 | ||
|
|
f73fdb04a6 | ||
|
|
3a4120e49a | ||
|
|
9fe894402f | ||
|
|
0a32208e5d | ||
|
|
774f3a692c | ||
|
|
5cc7564c5c | ||
|
|
0fe0b0eeb6 | ||
|
|
0126a5d60a | ||
|
|
13e334506c | ||
|
|
6b40e4f52a | ||
|
|
d5fb561709 | ||
|
|
d8ec81cc31 | ||
|
|
bc72d381b2 | ||
|
|
4d362a61ea | ||
|
|
00c281f6a4 | ||
|
|
c8cd41cdd8 | ||
|
|
41e4b2efea | ||
|
|
0c13d45522 | ||
|
|
9f1800fba8 | ||
|
|
8f4a9bbc16 | ||
|
|
9ba2bf1570 | ||
|
|
120c238705 | ||
|
|
1c1f633b13 | ||
|
|
6660f37558 | ||
|
|
20e5b46b20 | ||
|
|
0113ad36ee | ||
|
|
3e41de05cc | ||
|
|
2884712ca7 | ||
|
|
ded01c3bf6 | ||
|
|
8c75040c25 | ||
|
|
f1073ad43d | ||
|
|
a352b68acf | ||
|
|
364d616792 | ||
|
|
bde13833cb | ||
|
|
80a1bc7db5 | ||
|
|
f1f70bf4b5 | ||
|
|
dbb5a39b64 | ||
|
|
885ee861f7 | ||
|
|
486b9a6a2d | ||
|
|
1f31381611 | ||
|
|
ed5f43a55a | ||
|
|
09a17f965c | ||
|
|
1e9026e484 | ||
|
|
a60169ea09 | ||
|
|
78a16d395c | ||
|
|
5436848955 | ||
|
|
0477368e9a | ||
|
|
0ef0655b83 | ||
|
|
a64dbae90b | ||
|
|
d41a1a91d3 | ||
|
|
15bf3e3376 | ||
|
|
d9f7fa2e57 | ||
|
|
b31c49d676 | ||
|
|
255c229f23 | ||
|
|
d12134ce37 | ||
|
|
36e2aade87 | ||
|
|
33546b58aa | ||
|
|
fdc015c6e9 | ||
|
|
e9892b4b26 | ||
|
|
50f69e2ef2 | ||
|
|
41b35412bf | ||
|
|
7dbb473339 | ||
|
|
16a8884233 | ||
|
|
5ee5b655b2 | ||
|
|
0a43219a27 | ||
|
|
eba4ff1bcb | ||
|
|
2eab219a70 | ||
|
|
95f305c35a | ||
|
|
6e7dc7c7dd | ||
|
|
b7fbc9bd95 | ||
|
|
919a2c74f6 | ||
|
|
81c07a32fd | ||
|
|
b063784b78 | ||
|
|
defa28efa1 | ||
|
|
690029d1a3 | ||
|
|
d88faf92d1 | ||
|
|
958c968d02 | ||
|
|
746b2f5657 | ||
|
|
efeabd3180 | ||
|
|
1fd6eb695d | ||
|
|
128360d4f0 | ||
|
|
17aab5827a | ||
|
|
1a815fb04f |
173
CHANGES.rst
173
CHANGES.rst
@@ -1,3 +1,174 @@
|
||||
Changes in synapse v0.17.0 (2016-08-08)
|
||||
=======================================
|
||||
|
||||
This release contains significant security bug fixes regarding authenticating
|
||||
events received over federation. PLEASE UPGRADE.
|
||||
|
||||
This release changes the LDAP configuration format in a backwards incompatible
|
||||
way, see PR #843 for details.
|
||||
|
||||
|
||||
Changes:
|
||||
|
||||
* Add federation /version API (PR #990)
|
||||
* Make psutil dependency optional (PR #992)
|
||||
|
||||
|
||||
Bug fixes:
|
||||
|
||||
* Fix URL preview API to exclude HTML comments in description (PR #988)
|
||||
* Fix error handling of remote joins (PR #991)
|
||||
|
||||
|
||||
Changes in synapse v0.17.0-rc4 (2016-08-05)
|
||||
===========================================
|
||||
|
||||
Changes:
|
||||
|
||||
* Change the way we summarize URLs when previewing (PR #973)
|
||||
* Add new ``/state_ids/`` federation API (PR #979)
|
||||
* Speed up processing of ``/state/`` response (PR #986)
|
||||
|
||||
Bug fixes:
|
||||
|
||||
* Fix event persistence when event has already been partially persisted
|
||||
(PR #975, #983, #985)
|
||||
* Fix port script to also copy across backfilled events (PR #982)
|
||||
|
||||
|
||||
Changes in synapse v0.17.0-rc3 (2016-08-02)
|
||||
===========================================
|
||||
|
||||
Changes:
|
||||
|
||||
* Forbid non-ASes from registering users whose names begin with '_' (PR #958)
|
||||
* Add some basic admin API docs (PR #963)
|
||||
|
||||
|
||||
Bug fixes:
|
||||
|
||||
* Send the correct host header when fetching keys (PR #941)
|
||||
* Fix joining a room that has missing auth events (PR #964)
|
||||
* Fix various push bugs (PR #966, #970)
|
||||
* Fix adding emails on registration (PR #968)
|
||||
|
||||
|
||||
Changes in synapse v0.17.0-rc2 (2016-08-02)
|
||||
===========================================
|
||||
|
||||
(This release did not include the changes advertised and was identical to RC1)
|
||||
|
||||
|
||||
Changes in synapse v0.17.0-rc1 (2016-07-28)
|
||||
===========================================
|
||||
|
||||
This release changes the LDAP configuration format in a backwards incompatible
|
||||
way, see PR #843 for details.
|
||||
|
||||
|
||||
Features:
|
||||
|
||||
* Add purge_media_cache admin API (PR #902)
|
||||
* Add deactivate account admin API (PR #903)
|
||||
* Add optional pepper to password hashing (PR #907, #910 by KentShikama)
|
||||
* Add an admin option to shared secret registration (breaks backwards compat)
|
||||
(PR #909)
|
||||
* Add purge local room history API (PR #911, #923, #924)
|
||||
* Add requestToken endpoints (PR #915)
|
||||
* Add an /account/deactivate endpoint (PR #921)
|
||||
* Add filter param to /messages. Add 'contains_url' to filter. (PR #922)
|
||||
* Add device_id support to /login (PR #929)
|
||||
* Add device_id support to /v2/register flow. (PR #937, #942)
|
||||
* Add GET /devices endpoint (PR #939, #944)
|
||||
* Add GET /device/{deviceId} (PR #943)
|
||||
* Add update and delete APIs for devices (PR #949)
|
||||
|
||||
|
||||
Changes:
|
||||
|
||||
* Rewrite LDAP Authentication against ldap3 (PR #843 by mweinelt)
|
||||
* Linearize some federation endpoints based on (origin, room_id) (PR #879)
|
||||
* Remove the legacy v0 content upload API. (PR #888)
|
||||
* Use similar naming we use in email notifs for push (PR #894)
|
||||
* Optionally include password hash in createUser endpoint (PR #905 by
|
||||
KentShikama)
|
||||
* Use a query that postgresql optimises better for get_events_around (PR #906)
|
||||
* Fall back to 'username' if 'user' is not given for appservice registration.
|
||||
(PR #927 by Half-Shot)
|
||||
* Add metrics for psutil derived memory usage (PR #936)
|
||||
* Record device_id in client_ips (PR #938)
|
||||
* Send the correct host header when fetching keys (PR #941)
|
||||
* Log the hostname the reCAPTCHA was completed on (PR #946)
|
||||
* Make the device id on e2e key upload optional (PR #956)
|
||||
* Add r0.2.0 to the "supported versions" list (PR #960)
|
||||
* Don't include name of room for invites in push (PR #961)
|
||||
|
||||
|
||||
Bug fixes:
|
||||
|
||||
* Fix substitution failure in mail template (PR #887)
|
||||
* Put most recent 20 messages in email notif (PR #892)
|
||||
* Ensure that the guest user is in the database when upgrading accounts
|
||||
(PR #914)
|
||||
* Fix various edge cases in auth handling (PR #919)
|
||||
* Fix 500 ISE when sending alias event without a state_key (PR #925)
|
||||
* Fix bug where we stored rejections in the state_group, persist all
|
||||
rejections (PR #948)
|
||||
* Fix lack of check of if the user is banned when handling 3pid invites
|
||||
(PR #952)
|
||||
* Fix a couple of bugs in the transaction and keyring code (PR #954, #955)
|
||||
|
||||
|
||||
|
||||
Changes in synapse v0.16.1-r1 (2016-07-08)
|
||||
==========================================
|
||||
|
||||
THIS IS A CRITICAL SECURITY UPDATE.
|
||||
|
||||
This fixes a bug which allowed users' accounts to be accessed by unauthorised
|
||||
users.
|
||||
|
||||
Changes in synapse v0.16.1 (2016-06-20)
|
||||
=======================================
|
||||
|
||||
Bug fixes:
|
||||
|
||||
* Fix assorted bugs in ``/preview_url`` (PR #872)
|
||||
* Fix TypeError when setting unicode passwords (PR #873)
|
||||
|
||||
|
||||
Performance improvements:
|
||||
|
||||
* Turn ``use_frozen_events`` off by default (PR #877)
|
||||
* Disable responding with canonical json for federation (PR #878)
|
||||
|
||||
|
||||
Changes in synapse v0.16.1-rc1 (2016-06-15)
|
||||
===========================================
|
||||
|
||||
Features: None
|
||||
|
||||
Changes:
|
||||
|
||||
* Log requester for ``/publicRoom`` endpoints when possible (PR #856)
|
||||
* 502 on ``/thumbnail`` when can't connect to remote server (PR #862)
|
||||
* Linearize fetching of gaps on incoming events (PR #871)
|
||||
|
||||
|
||||
Bugs fixes:
|
||||
|
||||
* Fix bug where rooms where marked as published by default (PR #857)
|
||||
* Fix bug where joining room with an event with invalid sender (PR #868)
|
||||
* Fix bug where backfilled events were sent down sync streams (PR #869)
|
||||
* Fix bug where outgoing connections could wedge indefinitely, causing push
|
||||
notifications to be unreliable (PR #870)
|
||||
|
||||
|
||||
Performance improvements:
|
||||
|
||||
* Improve ``/publicRooms`` performance(PR #859)
|
||||
|
||||
|
||||
Changes in synapse v0.16.0 (2016-06-09)
|
||||
=======================================
|
||||
|
||||
@@ -28,7 +199,7 @@ Bug fixes:
|
||||
* Fix bug where synapse sent malformed transactions to AS's when retrying
|
||||
transactions (Commits 310197b, 8437906)
|
||||
|
||||
Performance Improvements:
|
||||
Performance improvements:
|
||||
|
||||
* Remove event fetching from DB threads (PR #835)
|
||||
* Change the way we cache events (PR #836)
|
||||
|
||||
@@ -14,6 +14,7 @@ recursive-include docs *
|
||||
recursive-include res *
|
||||
recursive-include scripts *
|
||||
recursive-include scripts-dev *
|
||||
recursive-include synapse *.pyi
|
||||
recursive-include tests *.py
|
||||
|
||||
recursive-include synapse/static *.css
|
||||
@@ -23,5 +24,7 @@ recursive-include synapse/static *.js
|
||||
|
||||
exclude jenkins.sh
|
||||
exclude jenkins*.sh
|
||||
exclude jenkins*
|
||||
recursive-exclude jenkins *.sh
|
||||
|
||||
prune demo/etc
|
||||
|
||||
@@ -11,8 +11,8 @@ VoIP. The basics you need to know to get up and running are:
|
||||
like ``#matrix:matrix.org`` or ``#test:localhost:8448``.
|
||||
|
||||
- Matrix user IDs look like ``@matthew:matrix.org`` (although in the future
|
||||
you will normally refer to yourself and others using a 3PID: email
|
||||
address, phone number, etc rather than manipulating Matrix user IDs)
|
||||
you will normally refer to yourself and others using a third party identifier
|
||||
(3PID): email address, phone number, etc rather than manipulating Matrix user IDs)
|
||||
|
||||
The overall architecture is::
|
||||
|
||||
@@ -445,7 +445,7 @@ You have two choices here, which will influence the form of your Matrix user
|
||||
IDs:
|
||||
|
||||
1) Use the machine's own hostname as available on public DNS in the form of
|
||||
its A or AAAA records. This is easier to set up initially, perhaps for
|
||||
its A records. This is easier to set up initially, perhaps for
|
||||
testing, but lacks the flexibility of SRV.
|
||||
|
||||
2) Set up a SRV record for your domain name. This requires you create a SRV
|
||||
|
||||
@@ -27,7 +27,7 @@ running:
|
||||
# Pull the latest version of the master branch.
|
||||
git pull
|
||||
# Update the versions of synapse's python dependencies.
|
||||
python synapse/python_dependencies.py | xargs -n1 pip install
|
||||
python synapse/python_dependencies.py | xargs -n1 pip install --upgrade
|
||||
|
||||
|
||||
Upgrading to v0.15.0
|
||||
|
||||
12
docs/admin_api/README.rst
Normal file
12
docs/admin_api/README.rst
Normal file
@@ -0,0 +1,12 @@
|
||||
Admin APIs
|
||||
==========
|
||||
|
||||
This directory includes documentation for the various synapse specific admin
|
||||
APIs available.
|
||||
|
||||
Only users that are server admins can use these APIs. A user can be marked as a
|
||||
server admin by updating the database directly, e.g.:
|
||||
|
||||
``UPDATE users SET admin = 1 WHERE name = '@foo:bar.com'``
|
||||
|
||||
Restarting may be required for the changes to register.
|
||||
15
docs/admin_api/purge_history_api.rst
Normal file
15
docs/admin_api/purge_history_api.rst
Normal file
@@ -0,0 +1,15 @@
|
||||
Purge History API
|
||||
=================
|
||||
|
||||
The purge history API allows server admins to purge historic events from their
|
||||
database, reclaiming disk space.
|
||||
|
||||
Depending on the amount of history being purged a call to the API may take
|
||||
several minutes or longer. During this period users will not be able to
|
||||
paginate further back in the room from the point being purged from.
|
||||
|
||||
The API is simply:
|
||||
|
||||
``POST /_matrix/client/r0/admin/purge_history/<room_id>/<event_id>``
|
||||
|
||||
including an ``access_token`` of a server admin.
|
||||
19
docs/admin_api/purge_remote_media.rst
Normal file
19
docs/admin_api/purge_remote_media.rst
Normal file
@@ -0,0 +1,19 @@
|
||||
Purge Remote Media API
|
||||
======================
|
||||
|
||||
The purge remote media API allows server admins to purge old cached remote
|
||||
media.
|
||||
|
||||
The API is::
|
||||
|
||||
POST /_matrix/client/r0/admin/purge_media_cache
|
||||
|
||||
{
|
||||
"before_ts": <unix_timestamp_in_ms>
|
||||
}
|
||||
|
||||
Which will remove all cached media that was last accessed before
|
||||
``<unix_timestamp_in_ms>``.
|
||||
|
||||
If the user re-requests purged remote media, synapse will re-request the media
|
||||
from the originating server.
|
||||
@@ -43,7 +43,10 @@ Basically, PEP8
|
||||
together, or want to deliberately extend or preserve vertical/horizontal
|
||||
space)
|
||||
|
||||
Comments should follow the google code style. This is so that we can generate
|
||||
documentation with sphinx (http://sphinxcontrib-napoleon.readthedocs.org/en/latest/)
|
||||
Comments should follow the `google code style <http://google.github.io/styleguide/pyguide.html?showone=Comments#Comments>`_.
|
||||
This is so that we can generate documentation with
|
||||
`sphinx <http://sphinxcontrib-napoleon.readthedocs.org/en/latest/>`_. See the
|
||||
`examples <http://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html>`_
|
||||
in the sphinx documentation.
|
||||
|
||||
Code should pass pep8 --max-line-length=100 without any warnings.
|
||||
|
||||
@@ -9,31 +9,35 @@ the Home Server to generate credentials that are valid for use on the TURN
|
||||
server through the use of a secret shared between the Home Server and the
|
||||
TURN server.
|
||||
|
||||
This document described how to install coturn
|
||||
(https://code.google.com/p/coturn/) which also supports the TURN REST API,
|
||||
This document describes how to install coturn
|
||||
(https://github.com/coturn/coturn) which also supports the TURN REST API,
|
||||
and integrate it with synapse.
|
||||
|
||||
coturn Setup
|
||||
============
|
||||
|
||||
You may be able to setup coturn via your package manager, or set it up manually using the usual ``configure, make, make install`` process.
|
||||
|
||||
1. Check out coturn::
|
||||
svn checkout http://coturn.googlecode.com/svn/trunk/ coturn
|
||||
|
||||
git clone https://github.com/coturn/coturn.git coturn
|
||||
cd coturn
|
||||
|
||||
2. Configure it::
|
||||
|
||||
./configure
|
||||
|
||||
You may need to install libevent2: if so, you should do so
|
||||
You may need to install ``libevent2``: if so, you should do so
|
||||
in the way recommended by your operating system.
|
||||
You can ignore warnings about lack of database support: a
|
||||
database is unnecessary for this purpose.
|
||||
|
||||
3. Build and install it::
|
||||
|
||||
make
|
||||
make install
|
||||
|
||||
4. Make a config file in /etc/turnserver.conf. You can customise
|
||||
a config file from turnserver.conf.default. The relevant
|
||||
4. Create or edit the config file in ``/etc/turnserver.conf``. The relevant
|
||||
lines, with example values, are::
|
||||
|
||||
lt-cred-mech
|
||||
@@ -41,7 +45,7 @@ coturn Setup
|
||||
static-auth-secret=[your secret key here]
|
||||
realm=turn.myserver.org
|
||||
|
||||
See turnserver.conf.default for explanations of the options.
|
||||
See turnserver.conf for explanations of the options.
|
||||
One way to generate the static-auth-secret is with pwgen::
|
||||
|
||||
pwgen -s 64 1
|
||||
@@ -54,6 +58,7 @@ coturn Setup
|
||||
import your private key and certificate.
|
||||
|
||||
7. Start the turn server::
|
||||
|
||||
bin/turnserver -o
|
||||
|
||||
|
||||
|
||||
@@ -4,81 +4,19 @@ set -eux
|
||||
|
||||
: ${WORKSPACE:="$(pwd)"}
|
||||
|
||||
export WORKSPACE
|
||||
export PYTHONDONTWRITEBYTECODE=yep
|
||||
export SYNAPSE_CACHE_FACTOR=1
|
||||
|
||||
# Output test results as junit xml
|
||||
export TRIAL_FLAGS="--reporter=subunit"
|
||||
export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
|
||||
# Write coverage reports to a separate file for each process
|
||||
export COVERAGE_OPTS="-p"
|
||||
export DUMP_COVERAGE_COMMAND="coverage help"
|
||||
./jenkins/prepare_synapse.sh
|
||||
./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git
|
||||
./jenkins/clone.sh dendron https://github.com/matrix-org/dendron.git
|
||||
./dendron/jenkins/build_dendron.sh
|
||||
./sytest/jenkins/prep_sytest_for_postgres.sh
|
||||
|
||||
# Output flake8 violations to violations.flake8.log
|
||||
# Don't exit with non-0 status code on Jenkins,
|
||||
# so that the build steps continue and a later step can decided whether to
|
||||
# UNSTABLE or FAILURE this build.
|
||||
export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
|
||||
|
||||
rm .coverage* || echo "No coverage files to remove"
|
||||
|
||||
tox --notest -e py27
|
||||
|
||||
TOX_BIN=$WORKSPACE/.tox/py27/bin
|
||||
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
|
||||
$TOX_BIN/pip install psycopg2
|
||||
$TOX_BIN/pip install lxml
|
||||
|
||||
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
|
||||
|
||||
if [[ ! -e .dendron-base ]]; then
|
||||
git clone https://github.com/matrix-org/dendron.git .dendron-base --mirror
|
||||
else
|
||||
(cd .dendron-base; git fetch -p)
|
||||
fi
|
||||
|
||||
rm -rf dendron
|
||||
git clone .dendron-base dendron --shared
|
||||
cd dendron
|
||||
|
||||
: ${GOPATH:=${WORKSPACE}/.gopath}
|
||||
if [[ "${GOPATH}" != *:* ]]; then
|
||||
mkdir -p "${GOPATH}"
|
||||
export PATH="${GOPATH}/bin:${PATH}"
|
||||
fi
|
||||
export GOPATH
|
||||
|
||||
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
|
||||
|
||||
go get github.com/constabulary/gb/...
|
||||
gb generate
|
||||
gb build
|
||||
|
||||
cd ..
|
||||
|
||||
|
||||
if [[ ! -e .sytest-base ]]; then
|
||||
git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
|
||||
else
|
||||
(cd .sytest-base; git fetch -p)
|
||||
fi
|
||||
|
||||
rm -rf sytest
|
||||
git clone .sytest-base sytest --shared
|
||||
cd sytest
|
||||
|
||||
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
|
||||
|
||||
: ${PORT_BASE:=8000}
|
||||
|
||||
./jenkins/prep_sytest_for_postgres.sh
|
||||
|
||||
echo >&2 "Running sytest with PostgreSQL";
|
||||
./jenkins/install_and_run.sh --python $TOX_BIN/python \
|
||||
--synapse-directory $WORKSPACE \
|
||||
--dendron $WORKSPACE/dendron/bin/dendron \
|
||||
--synchrotron \
|
||||
--pusher \
|
||||
--port-base $PORT_BASE
|
||||
|
||||
cd ..
|
||||
./sytest/jenkins/install_and_run.sh \
|
||||
--synapse-directory $WORKSPACE \
|
||||
--dendron $WORKSPACE/dendron/bin/dendron \
|
||||
--pusher \
|
||||
--synchrotron \
|
||||
--federation-reader \
|
||||
|
||||
@@ -4,60 +4,14 @@ set -eux
|
||||
|
||||
: ${WORKSPACE:="$(pwd)"}
|
||||
|
||||
export WORKSPACE
|
||||
export PYTHONDONTWRITEBYTECODE=yep
|
||||
export SYNAPSE_CACHE_FACTOR=1
|
||||
|
||||
# Output test results as junit xml
|
||||
export TRIAL_FLAGS="--reporter=subunit"
|
||||
export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
|
||||
# Write coverage reports to a separate file for each process
|
||||
export COVERAGE_OPTS="-p"
|
||||
export DUMP_COVERAGE_COMMAND="coverage help"
|
||||
./jenkins/prepare_synapse.sh
|
||||
./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git
|
||||
|
||||
# Output flake8 violations to violations.flake8.log
|
||||
# Don't exit with non-0 status code on Jenkins,
|
||||
# so that the build steps continue and a later step can decided whether to
|
||||
# UNSTABLE or FAILURE this build.
|
||||
export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
|
||||
./sytest/jenkins/prep_sytest_for_postgres.sh
|
||||
|
||||
rm .coverage* || echo "No coverage files to remove"
|
||||
|
||||
tox --notest -e py27
|
||||
|
||||
TOX_BIN=$WORKSPACE/.tox/py27/bin
|
||||
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
|
||||
$TOX_BIN/pip install psycopg2
|
||||
$TOX_BIN/pip install lxml
|
||||
|
||||
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
|
||||
|
||||
if [[ ! -e .sytest-base ]]; then
|
||||
git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
|
||||
else
|
||||
(cd .sytest-base; git fetch -p)
|
||||
fi
|
||||
|
||||
rm -rf sytest
|
||||
git clone .sytest-base sytest --shared
|
||||
cd sytest
|
||||
|
||||
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
|
||||
|
||||
: ${PORT_BASE:=8000}
|
||||
|
||||
./jenkins/prep_sytest_for_postgres.sh
|
||||
|
||||
echo >&2 "Running sytest with PostgreSQL";
|
||||
./jenkins/install_and_run.sh --coverage \
|
||||
--python $TOX_BIN/python \
|
||||
--synapse-directory $WORKSPACE \
|
||||
--port-base $PORT_BASE
|
||||
|
||||
cd ..
|
||||
cp sytest/.coverage.* .
|
||||
|
||||
# Combine the coverage reports
|
||||
echo "Combining:" .coverage.*
|
||||
$TOX_BIN/python -m coverage combine
|
||||
# Output coverage to coverage.xml
|
||||
$TOX_BIN/coverage xml -o coverage.xml
|
||||
./sytest/jenkins/install_and_run.sh \
|
||||
--synapse-directory $WORKSPACE \
|
||||
|
||||
@@ -4,54 +4,12 @@ set -eux
|
||||
|
||||
: ${WORKSPACE:="$(pwd)"}
|
||||
|
||||
export WORKSPACE
|
||||
export PYTHONDONTWRITEBYTECODE=yep
|
||||
export SYNAPSE_CACHE_FACTOR=1
|
||||
|
||||
# Output test results as junit xml
|
||||
export TRIAL_FLAGS="--reporter=subunit"
|
||||
export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
|
||||
# Write coverage reports to a separate file for each process
|
||||
export COVERAGE_OPTS="-p"
|
||||
export DUMP_COVERAGE_COMMAND="coverage help"
|
||||
./jenkins/prepare_synapse.sh
|
||||
./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git
|
||||
|
||||
# Output flake8 violations to violations.flake8.log
|
||||
# Don't exit with non-0 status code on Jenkins,
|
||||
# so that the build steps continue and a later step can decided whether to
|
||||
# UNSTABLE or FAILURE this build.
|
||||
export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
|
||||
|
||||
rm .coverage* || echo "No coverage files to remove"
|
||||
|
||||
tox --notest -e py27
|
||||
TOX_BIN=$WORKSPACE/.tox/py27/bin
|
||||
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
|
||||
$TOX_BIN/pip install lxml
|
||||
|
||||
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
|
||||
|
||||
if [[ ! -e .sytest-base ]]; then
|
||||
git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
|
||||
else
|
||||
(cd .sytest-base; git fetch -p)
|
||||
fi
|
||||
|
||||
rm -rf sytest
|
||||
git clone .sytest-base sytest --shared
|
||||
cd sytest
|
||||
|
||||
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
|
||||
|
||||
: ${PORT_BASE:=8500}
|
||||
./jenkins/install_and_run.sh --coverage \
|
||||
--python $TOX_BIN/python \
|
||||
--synapse-directory $WORKSPACE \
|
||||
--port-base $PORT_BASE
|
||||
|
||||
cd ..
|
||||
cp sytest/.coverage.* .
|
||||
|
||||
# Combine the coverage reports
|
||||
echo "Combining:" .coverage.*
|
||||
$TOX_BIN/python -m coverage combine
|
||||
# Output coverage to coverage.xml
|
||||
$TOX_BIN/coverage xml -o coverage.xml
|
||||
./sytest/jenkins/install_and_run.sh \
|
||||
--synapse-directory $WORKSPACE \
|
||||
|
||||
@@ -22,4 +22,8 @@ export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished w
|
||||
|
||||
rm .coverage* || echo "No coverage files to remove"
|
||||
|
||||
tox --notest -e py27
|
||||
TOX_BIN=$WORKSPACE/.tox/py27/bin
|
||||
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
|
||||
|
||||
tox -e py27
|
||||
|
||||
44
jenkins/clone.sh
Executable file
44
jenkins/clone.sh
Executable file
@@ -0,0 +1,44 @@
|
||||
#! /bin/bash
|
||||
|
||||
# This clones a project from github into a named subdirectory
|
||||
# If the project has a branch with the same name as this branch
|
||||
# then it will checkout that branch after cloning.
|
||||
# Otherwise it will checkout "origin/develop."
|
||||
# The first argument is the name of the directory to checkout
|
||||
# the branch into.
|
||||
# The second argument is the URL of the remote repository to checkout.
|
||||
# Usually something like https://github.com/matrix-org/sytest.git
|
||||
|
||||
set -eux
|
||||
|
||||
NAME=$1
|
||||
PROJECT=$2
|
||||
BASE=".$NAME-base"
|
||||
|
||||
# Update our mirror.
|
||||
if [ ! -d ".$NAME-base" ]; then
|
||||
# Create a local mirror of the source repository.
|
||||
# This saves us from having to download the entire repository
|
||||
# when this script is next run.
|
||||
git clone "$PROJECT" "$BASE" --mirror
|
||||
else
|
||||
# Fetch any updates from the source repository.
|
||||
(cd "$BASE"; git fetch -p)
|
||||
fi
|
||||
|
||||
# Remove the existing repository so that we have a clean copy
|
||||
rm -rf "$NAME"
|
||||
# Cloning with --shared means that we will share portions of the
|
||||
# .git directory with our local mirror.
|
||||
git clone "$BASE" "$NAME" --shared
|
||||
|
||||
# Jenkins may have supplied us with the name of the branch in the
|
||||
# environment. Otherwise we will have to guess based on the current
|
||||
# commit.
|
||||
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
|
||||
cd "$NAME"
|
||||
# check out the relevant branch
|
||||
git checkout "${GIT_BRANCH}" || (
|
||||
echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop"
|
||||
git checkout "origin/develop"
|
||||
)
|
||||
19
jenkins/prepare_synapse.sh
Executable file
19
jenkins/prepare_synapse.sh
Executable file
@@ -0,0 +1,19 @@
|
||||
#! /bin/bash
|
||||
|
||||
cd "`dirname $0`/.."
|
||||
|
||||
TOX_DIR=$WORKSPACE/.tox
|
||||
|
||||
mkdir -p $TOX_DIR
|
||||
|
||||
if ! [ $TOX_DIR -ef .tox ]; then
|
||||
ln -s "$TOX_DIR" .tox
|
||||
fi
|
||||
|
||||
# set up the virtualenv
|
||||
tox -e py27 --notest -v
|
||||
|
||||
TOX_BIN=$TOX_DIR/py27/bin
|
||||
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
|
||||
$TOX_BIN/pip install lxml
|
||||
$TOX_BIN/pip install psycopg2
|
||||
@@ -36,7 +36,7 @@
|
||||
<div class="debug">
|
||||
Sending email at {{ reason.now|format_ts("%c") }} due to activity in room {{ reason.room_name }} because
|
||||
an event was received at {{ reason.received_at|format_ts("%c") }}
|
||||
which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} (delay_before_mail_ms) mins ago,
|
||||
which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} ({{ reason.delay_before_mail_ms }}) mins ago,
|
||||
{% if reason.last_sent_ts %}
|
||||
and the last time we sent a mail for this room was {{ reason.last_sent_ts|format_ts("%c") }},
|
||||
which is more than {{ "%.1f"|format(reason.throttle_ms / (60*1000)) }} (current throttle_ms) mins ago.
|
||||
|
||||
@@ -116,17 +116,19 @@ def get_json(origin_name, origin_key, destination, path):
|
||||
authorization_headers = []
|
||||
|
||||
for key, sig in signed_json["signatures"][origin_name].items():
|
||||
authorization_headers.append(bytes(
|
||||
"X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
|
||||
origin_name, key, sig,
|
||||
)
|
||||
))
|
||||
header = "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
|
||||
origin_name, key, sig,
|
||||
)
|
||||
authorization_headers.append(bytes(header))
|
||||
sys.stderr.write(header)
|
||||
sys.stderr.write("\n")
|
||||
|
||||
result = requests.get(
|
||||
lookup(destination, path),
|
||||
headers={"Authorization": authorization_headers[0]},
|
||||
verify=False,
|
||||
)
|
||||
sys.stderr.write("Status Code: %d\n" % (result.status_code,))
|
||||
return result.json()
|
||||
|
||||
|
||||
@@ -141,6 +143,7 @@ def main():
|
||||
)
|
||||
|
||||
json.dump(result, sys.stdout)
|
||||
print ""
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -1,10 +1,16 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import argparse
|
||||
|
||||
import sys
|
||||
|
||||
import bcrypt
|
||||
import getpass
|
||||
|
||||
import yaml
|
||||
|
||||
bcrypt_rounds=12
|
||||
password_pepper = ""
|
||||
|
||||
def prompt_for_pass():
|
||||
password = getpass.getpass("Password: ")
|
||||
@@ -28,12 +34,22 @@ if __name__ == "__main__":
|
||||
default=None,
|
||||
help="New password for user. Will prompt if omitted.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c", "--config",
|
||||
type=argparse.FileType('r'),
|
||||
help="Path to server config file. Used to read in bcrypt_rounds and password_pepper.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
if "config" in args and args.config:
|
||||
config = yaml.safe_load(args.config)
|
||||
bcrypt_rounds = config.get("bcrypt_rounds", bcrypt_rounds)
|
||||
password_config = config.get("password_config", {})
|
||||
password_pepper = password_config.get("pepper", password_pepper)
|
||||
password = args.password
|
||||
|
||||
if not password:
|
||||
password = prompt_for_pass()
|
||||
|
||||
print bcrypt.hashpw(password, bcrypt.gensalt(bcrypt_rounds))
|
||||
print bcrypt.hashpw(password + password_pepper, bcrypt.gensalt(bcrypt_rounds))
|
||||
|
||||
|
||||
@@ -25,18 +25,26 @@ import urllib2
|
||||
import yaml
|
||||
|
||||
|
||||
def request_registration(user, password, server_location, shared_secret):
|
||||
def request_registration(user, password, server_location, shared_secret, admin=False):
|
||||
mac = hmac.new(
|
||||
key=shared_secret,
|
||||
msg=user,
|
||||
digestmod=hashlib.sha1,
|
||||
).hexdigest()
|
||||
)
|
||||
|
||||
mac.update(user)
|
||||
mac.update("\x00")
|
||||
mac.update(password)
|
||||
mac.update("\x00")
|
||||
mac.update("admin" if admin else "notadmin")
|
||||
|
||||
mac = mac.hexdigest()
|
||||
|
||||
data = {
|
||||
"user": user,
|
||||
"password": password,
|
||||
"mac": mac,
|
||||
"type": "org.matrix.login.shared_secret",
|
||||
"admin": admin,
|
||||
}
|
||||
|
||||
server_location = server_location.rstrip("/")
|
||||
@@ -68,7 +76,7 @@ def request_registration(user, password, server_location, shared_secret):
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def register_new_user(user, password, server_location, shared_secret):
|
||||
def register_new_user(user, password, server_location, shared_secret, admin):
|
||||
if not user:
|
||||
try:
|
||||
default_user = getpass.getuser()
|
||||
@@ -99,7 +107,14 @@ def register_new_user(user, password, server_location, shared_secret):
|
||||
print "Passwords do not match"
|
||||
sys.exit(1)
|
||||
|
||||
request_registration(user, password, server_location, shared_secret)
|
||||
if not admin:
|
||||
admin = raw_input("Make admin [no]: ")
|
||||
if admin in ("y", "yes", "true"):
|
||||
admin = True
|
||||
else:
|
||||
admin = False
|
||||
|
||||
request_registration(user, password, server_location, shared_secret, bool(admin))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -119,6 +134,11 @@ if __name__ == "__main__":
|
||||
default=None,
|
||||
help="New password for user. Will prompt if omitted.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-a", "--admin",
|
||||
action="store_true",
|
||||
help="Register new user as an admin. Will prompt if omitted.",
|
||||
)
|
||||
|
||||
group = parser.add_mutually_exclusive_group(required=True)
|
||||
group.add_argument(
|
||||
@@ -151,4 +171,4 @@ if __name__ == "__main__":
|
||||
else:
|
||||
secret = args.shared_secret
|
||||
|
||||
register_new_user(args.user, args.password, args.server_url, secret)
|
||||
register_new_user(args.user, args.password, args.server_url, secret, args.admin)
|
||||
|
||||
@@ -34,7 +34,7 @@ logger = logging.getLogger("synapse_port_db")
|
||||
|
||||
|
||||
BOOLEAN_COLUMNS = {
|
||||
"events": ["processed", "outlier"],
|
||||
"events": ["processed", "outlier", "contains_url"],
|
||||
"rooms": ["is_public"],
|
||||
"event_edges": ["is_state"],
|
||||
"presence_list": ["accepted"],
|
||||
@@ -92,8 +92,12 @@ class Store(object):
|
||||
|
||||
_simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"]
|
||||
_simple_select_onecol = SQLBaseStore.__dict__["_simple_select_onecol"]
|
||||
_simple_select_one = SQLBaseStore.__dict__["_simple_select_one"]
|
||||
_simple_select_one_txn = SQLBaseStore.__dict__["_simple_select_one_txn"]
|
||||
_simple_select_one_onecol = SQLBaseStore.__dict__["_simple_select_one_onecol"]
|
||||
_simple_select_one_onecol_txn = SQLBaseStore.__dict__["_simple_select_one_onecol_txn"]
|
||||
_simple_select_one_onecol_txn = SQLBaseStore.__dict__[
|
||||
"_simple_select_one_onecol_txn"
|
||||
]
|
||||
|
||||
_simple_update_one = SQLBaseStore.__dict__["_simple_update_one"]
|
||||
_simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"]
|
||||
@@ -158,31 +162,40 @@ class Porter(object):
|
||||
def setup_table(self, table):
|
||||
if table in APPEND_ONLY_TABLES:
|
||||
# It's safe to just carry on inserting.
|
||||
next_chunk = yield self.postgres_store._simple_select_one_onecol(
|
||||
row = yield self.postgres_store._simple_select_one(
|
||||
table="port_from_sqlite3",
|
||||
keyvalues={"table_name": table},
|
||||
retcol="rowid",
|
||||
retcols=("forward_rowid", "backward_rowid"),
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
total_to_port = None
|
||||
if next_chunk is None:
|
||||
if row is None:
|
||||
if table == "sent_transactions":
|
||||
next_chunk, already_ported, total_to_port = (
|
||||
forward_chunk, already_ported, total_to_port = (
|
||||
yield self._setup_sent_transactions()
|
||||
)
|
||||
backward_chunk = 0
|
||||
else:
|
||||
yield self.postgres_store._simple_insert(
|
||||
table="port_from_sqlite3",
|
||||
values={"table_name": table, "rowid": 1}
|
||||
values={
|
||||
"table_name": table,
|
||||
"forward_rowid": 1,
|
||||
"backward_rowid": 0,
|
||||
}
|
||||
)
|
||||
|
||||
next_chunk = 1
|
||||
forward_chunk = 1
|
||||
backward_chunk = 0
|
||||
already_ported = 0
|
||||
else:
|
||||
forward_chunk = row["forward_rowid"]
|
||||
backward_chunk = row["backward_rowid"]
|
||||
|
||||
if total_to_port is None:
|
||||
already_ported, total_to_port = yield self._get_total_count_to_port(
|
||||
table, next_chunk
|
||||
table, forward_chunk, backward_chunk
|
||||
)
|
||||
else:
|
||||
def delete_all(txn):
|
||||
@@ -196,46 +209,85 @@ class Porter(object):
|
||||
|
||||
yield self.postgres_store._simple_insert(
|
||||
table="port_from_sqlite3",
|
||||
values={"table_name": table, "rowid": 0}
|
||||
values={
|
||||
"table_name": table,
|
||||
"forward_rowid": 1,
|
||||
"backward_rowid": 0,
|
||||
}
|
||||
)
|
||||
|
||||
next_chunk = 1
|
||||
forward_chunk = 1
|
||||
backward_chunk = 0
|
||||
|
||||
already_ported, total_to_port = yield self._get_total_count_to_port(
|
||||
table, next_chunk
|
||||
table, forward_chunk, backward_chunk
|
||||
)
|
||||
|
||||
defer.returnValue((table, already_ported, total_to_port, next_chunk))
|
||||
defer.returnValue(
|
||||
(table, already_ported, total_to_port, forward_chunk, backward_chunk)
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def handle_table(self, table, postgres_size, table_size, next_chunk):
|
||||
def handle_table(self, table, postgres_size, table_size, forward_chunk,
|
||||
backward_chunk):
|
||||
if not table_size:
|
||||
return
|
||||
|
||||
self.progress.add_table(table, postgres_size, table_size)
|
||||
|
||||
if table == "event_search":
|
||||
yield self.handle_search_table(postgres_size, table_size, next_chunk)
|
||||
yield self.handle_search_table(
|
||||
postgres_size, table_size, forward_chunk, backward_chunk
|
||||
)
|
||||
return
|
||||
|
||||
select = (
|
||||
forward_select = (
|
||||
"SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?"
|
||||
% (table,)
|
||||
)
|
||||
|
||||
backward_select = (
|
||||
"SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid LIMIT ?"
|
||||
% (table,)
|
||||
)
|
||||
|
||||
do_forward = [True]
|
||||
do_backward = [True]
|
||||
|
||||
while True:
|
||||
def r(txn):
|
||||
txn.execute(select, (next_chunk, self.batch_size,))
|
||||
rows = txn.fetchall()
|
||||
headers = [column[0] for column in txn.description]
|
||||
forward_rows = []
|
||||
backward_rows = []
|
||||
if do_forward[0]:
|
||||
txn.execute(forward_select, (forward_chunk, self.batch_size,))
|
||||
forward_rows = txn.fetchall()
|
||||
if not forward_rows:
|
||||
do_forward[0] = False
|
||||
|
||||
return headers, rows
|
||||
if do_backward[0]:
|
||||
txn.execute(backward_select, (backward_chunk, self.batch_size,))
|
||||
backward_rows = txn.fetchall()
|
||||
if not backward_rows:
|
||||
do_backward[0] = False
|
||||
|
||||
headers, rows = yield self.sqlite_store.runInteraction("select", r)
|
||||
if forward_rows or backward_rows:
|
||||
headers = [column[0] for column in txn.description]
|
||||
else:
|
||||
headers = None
|
||||
|
||||
if rows:
|
||||
next_chunk = rows[-1][0] + 1
|
||||
return headers, forward_rows, backward_rows
|
||||
|
||||
headers, frows, brows = yield self.sqlite_store.runInteraction(
|
||||
"select", r
|
||||
)
|
||||
|
||||
if frows or brows:
|
||||
if frows:
|
||||
forward_chunk = max(row[0] for row in frows) + 1
|
||||
if brows:
|
||||
backward_chunk = min(row[0] for row in brows) - 1
|
||||
|
||||
rows = frows + brows
|
||||
self._convert_rows(table, headers, rows)
|
||||
|
||||
def insert(txn):
|
||||
@@ -247,7 +299,10 @@ class Porter(object):
|
||||
txn,
|
||||
table="port_from_sqlite3",
|
||||
keyvalues={"table_name": table},
|
||||
updatevalues={"rowid": next_chunk},
|
||||
updatevalues={
|
||||
"forward_rowid": forward_chunk,
|
||||
"backward_rowid": backward_chunk,
|
||||
},
|
||||
)
|
||||
|
||||
yield self.postgres_store.execute(insert)
|
||||
@@ -259,7 +314,8 @@ class Porter(object):
|
||||
return
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def handle_search_table(self, postgres_size, table_size, next_chunk):
|
||||
def handle_search_table(self, postgres_size, table_size, forward_chunk,
|
||||
backward_chunk):
|
||||
select = (
|
||||
"SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering"
|
||||
" FROM event_search as es"
|
||||
@@ -270,7 +326,7 @@ class Porter(object):
|
||||
|
||||
while True:
|
||||
def r(txn):
|
||||
txn.execute(select, (next_chunk, self.batch_size,))
|
||||
txn.execute(select, (forward_chunk, self.batch_size,))
|
||||
rows = txn.fetchall()
|
||||
headers = [column[0] for column in txn.description]
|
||||
|
||||
@@ -279,7 +335,7 @@ class Porter(object):
|
||||
headers, rows = yield self.sqlite_store.runInteraction("select", r)
|
||||
|
||||
if rows:
|
||||
next_chunk = rows[-1][0] + 1
|
||||
forward_chunk = rows[-1][0] + 1
|
||||
|
||||
# We have to treat event_search differently since it has a
|
||||
# different structure in the two different databases.
|
||||
@@ -312,7 +368,10 @@ class Porter(object):
|
||||
txn,
|
||||
table="port_from_sqlite3",
|
||||
keyvalues={"table_name": "event_search"},
|
||||
updatevalues={"rowid": next_chunk},
|
||||
updatevalues={
|
||||
"forward_rowid": forward_chunk,
|
||||
"backward_rowid": backward_chunk,
|
||||
},
|
||||
)
|
||||
|
||||
yield self.postgres_store.execute(insert)
|
||||
@@ -324,7 +383,6 @@ class Porter(object):
|
||||
else:
|
||||
return
|
||||
|
||||
|
||||
def setup_db(self, db_config, database_engine):
|
||||
db_conn = database_engine.module.connect(
|
||||
**{
|
||||
@@ -395,10 +453,32 @@ class Porter(object):
|
||||
txn.execute(
|
||||
"CREATE TABLE port_from_sqlite3 ("
|
||||
" table_name varchar(100) NOT NULL UNIQUE,"
|
||||
" rowid bigint NOT NULL"
|
||||
" forward_rowid bigint NOT NULL,"
|
||||
" backward_rowid bigint NOT NULL"
|
||||
")"
|
||||
)
|
||||
|
||||
# The old port script created a table with just a "rowid" column.
|
||||
# We want people to be able to rerun this script from an old port
|
||||
# so that they can pick up any missing events that were not
|
||||
# ported across.
|
||||
def alter_table(txn):
|
||||
txn.execute(
|
||||
"ALTER TABLE IF EXISTS port_from_sqlite3"
|
||||
" RENAME rowid TO forward_rowid"
|
||||
)
|
||||
txn.execute(
|
||||
"ALTER TABLE IF EXISTS port_from_sqlite3"
|
||||
" ADD backward_rowid bigint NOT NULL DEFAULT 0"
|
||||
)
|
||||
|
||||
try:
|
||||
yield self.postgres_store.runInteraction(
|
||||
"alter_table", alter_table
|
||||
)
|
||||
except Exception as e:
|
||||
logger.info("Failed to create port table: %s", e)
|
||||
|
||||
try:
|
||||
yield self.postgres_store.runInteraction(
|
||||
"create_port_table", create_port_table
|
||||
@@ -458,7 +538,7 @@ class Porter(object):
|
||||
@defer.inlineCallbacks
|
||||
def _setup_sent_transactions(self):
|
||||
# Only save things from the last day
|
||||
yesterday = int(time.time()*1000) - 86400000
|
||||
yesterday = int(time.time() * 1000) - 86400000
|
||||
|
||||
# And save the max transaction id from each destination
|
||||
select = (
|
||||
@@ -514,7 +594,11 @@ class Porter(object):
|
||||
|
||||
yield self.postgres_store._simple_insert(
|
||||
table="port_from_sqlite3",
|
||||
values={"table_name": "sent_transactions", "rowid": next_chunk}
|
||||
values={
|
||||
"table_name": "sent_transactions",
|
||||
"forward_rowid": next_chunk,
|
||||
"backward_rowid": 0,
|
||||
}
|
||||
)
|
||||
|
||||
def get_sent_table_size(txn):
|
||||
@@ -535,13 +619,18 @@ class Porter(object):
|
||||
defer.returnValue((next_chunk, inserted_rows, total_count))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_remaining_count_to_port(self, table, next_chunk):
|
||||
rows = yield self.sqlite_store.execute_sql(
|
||||
def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk):
|
||||
frows = yield self.sqlite_store.execute_sql(
|
||||
"SELECT count(*) FROM %s WHERE rowid >= ?" % (table,),
|
||||
next_chunk,
|
||||
forward_chunk,
|
||||
)
|
||||
|
||||
defer.returnValue(rows[0][0])
|
||||
brows = yield self.sqlite_store.execute_sql(
|
||||
"SELECT count(*) FROM %s WHERE rowid <= ?" % (table,),
|
||||
backward_chunk,
|
||||
)
|
||||
|
||||
defer.returnValue(frows[0][0] + brows[0][0])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_already_ported_count(self, table):
|
||||
@@ -552,10 +641,10 @@ class Porter(object):
|
||||
defer.returnValue(rows[0][0])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_total_count_to_port(self, table, next_chunk):
|
||||
def _get_total_count_to_port(self, table, forward_chunk, backward_chunk):
|
||||
remaining, done = yield defer.gatherResults(
|
||||
[
|
||||
self._get_remaining_count_to_port(table, next_chunk),
|
||||
self._get_remaining_count_to_port(table, forward_chunk, backward_chunk),
|
||||
self._get_already_ported_count(table),
|
||||
],
|
||||
consumeErrors=True,
|
||||
@@ -686,7 +775,7 @@ class CursesProgress(Progress):
|
||||
color = curses.color_pair(2) if perc == 100 else curses.color_pair(1)
|
||||
|
||||
self.stdscr.addstr(
|
||||
i+2, left_margin + max_len - len(table),
|
||||
i + 2, left_margin + max_len - len(table),
|
||||
table,
|
||||
curses.A_BOLD | color,
|
||||
)
|
||||
@@ -694,18 +783,18 @@ class CursesProgress(Progress):
|
||||
size = 20
|
||||
|
||||
progress = "[%s%s]" % (
|
||||
"#" * int(perc*size/100),
|
||||
" " * (size - int(perc*size/100)),
|
||||
"#" * int(perc * size / 100),
|
||||
" " * (size - int(perc * size / 100)),
|
||||
)
|
||||
|
||||
self.stdscr.addstr(
|
||||
i+2, left_margin + max_len + middle_space,
|
||||
i + 2, left_margin + max_len + middle_space,
|
||||
"%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]),
|
||||
)
|
||||
|
||||
if self.finished:
|
||||
self.stdscr.addstr(
|
||||
rows-1, 0,
|
||||
rows - 1, 0,
|
||||
"Press any key to exit...",
|
||||
)
|
||||
|
||||
|
||||
@@ -16,7 +16,5 @@ ignore =
|
||||
|
||||
[flake8]
|
||||
max-line-length = 90
|
||||
ignore = W503 ; W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
|
||||
|
||||
[pep8]
|
||||
max-line-length = 90
|
||||
# W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
|
||||
ignore = W503
|
||||
|
||||
@@ -16,4 +16,4 @@
|
||||
""" This is a reference implementation of a Matrix home server.
|
||||
"""
|
||||
|
||||
__version__ = "0.16.0"
|
||||
__version__ = "0.17.0"
|
||||
|
||||
@@ -13,22 +13,22 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
import pymacaroons
|
||||
from canonicaljson import encode_canonical_json
|
||||
from signedjson.key import decode_verify_key_bytes
|
||||
from signedjson.sign import verify_signed_json, SignatureVerifyException
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership, JoinRules
|
||||
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
|
||||
from synapse.types import Requester, UserID, get_domain_from_id
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.logcontext import preserve_context_over_fn
|
||||
from synapse.util.metrics import Measure
|
||||
from unpaddedbase64 import decode_base64
|
||||
|
||||
import logging
|
||||
import pymacaroons
|
||||
import synapse.types
|
||||
from synapse.api.constants import EventTypes, Membership, JoinRules
|
||||
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
|
||||
from synapse.types import UserID, get_domain_from_id
|
||||
from synapse.util.logcontext import preserve_context_over_fn
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -63,7 +63,7 @@ class Auth(object):
|
||||
"user_id = ",
|
||||
])
|
||||
|
||||
def check(self, event, auth_events):
|
||||
def check(self, event, auth_events, do_sig_check=True):
|
||||
""" Checks if this event is correctly authed.
|
||||
|
||||
Args:
|
||||
@@ -79,6 +79,13 @@ class Auth(object):
|
||||
|
||||
if not hasattr(event, "room_id"):
|
||||
raise AuthError(500, "Event has no room_id: %s" % event)
|
||||
|
||||
sender_domain = get_domain_from_id(event.sender)
|
||||
|
||||
# Check the sender's domain has signed the event
|
||||
if do_sig_check and not event.signatures.get(sender_domain):
|
||||
raise AuthError(403, "Event not signed by sending server")
|
||||
|
||||
if auth_events is None:
|
||||
# Oh, we don't know what the state of the room was, so we
|
||||
# are trusting that this is allowed (at least for now)
|
||||
@@ -86,6 +93,12 @@ class Auth(object):
|
||||
return True
|
||||
|
||||
if event.type == EventTypes.Create:
|
||||
room_id_domain = get_domain_from_id(event.room_id)
|
||||
if room_id_domain != sender_domain:
|
||||
raise AuthError(
|
||||
403,
|
||||
"Creation event's room_id domain does not match sender's"
|
||||
)
|
||||
# FIXME
|
||||
return True
|
||||
|
||||
@@ -108,6 +121,22 @@ class Auth(object):
|
||||
|
||||
# FIXME: Temp hack
|
||||
if event.type == EventTypes.Aliases:
|
||||
if not event.is_state():
|
||||
raise AuthError(
|
||||
403,
|
||||
"Alias event must be a state event",
|
||||
)
|
||||
if not event.state_key:
|
||||
raise AuthError(
|
||||
403,
|
||||
"Alias event must have non-empty state_key"
|
||||
)
|
||||
sender_domain = get_domain_from_id(event.sender)
|
||||
if event.state_key != sender_domain:
|
||||
raise AuthError(
|
||||
403,
|
||||
"Alias event's state_key does not match sender's domain"
|
||||
)
|
||||
return True
|
||||
|
||||
logger.debug(
|
||||
@@ -347,6 +376,10 @@ class Auth(object):
|
||||
if Membership.INVITE == membership and "third_party_invite" in event.content:
|
||||
if not self._verify_third_party_invite(event, auth_events):
|
||||
raise AuthError(403, "You are not invited to this room.")
|
||||
if target_banned:
|
||||
raise AuthError(
|
||||
403, "%s is banned from the room" % (target_user_id,)
|
||||
)
|
||||
return True
|
||||
|
||||
if Membership.JOIN != membership:
|
||||
@@ -537,9 +570,7 @@ class Auth(object):
|
||||
Args:
|
||||
request - An HTTP request with an access_token query parameter.
|
||||
Returns:
|
||||
tuple of:
|
||||
UserID (str)
|
||||
Access token ID (str)
|
||||
defer.Deferred: resolves to a ``synapse.types.Requester`` object
|
||||
Raises:
|
||||
AuthError if no user by that token exists or the token is invalid.
|
||||
"""
|
||||
@@ -548,9 +579,7 @@ class Auth(object):
|
||||
user_id = yield self._get_appservice_user_id(request.args)
|
||||
if user_id:
|
||||
request.authenticated_entity = user_id
|
||||
defer.returnValue(
|
||||
Requester(UserID.from_string(user_id), "", False)
|
||||
)
|
||||
defer.returnValue(synapse.types.create_requester(user_id))
|
||||
|
||||
access_token = request.args["access_token"][0]
|
||||
user_info = yield self.get_user_by_access_token(access_token, rights)
|
||||
@@ -558,6 +587,10 @@ class Auth(object):
|
||||
token_id = user_info["token_id"]
|
||||
is_guest = user_info["is_guest"]
|
||||
|
||||
# device_id may not be present if get_user_by_access_token has been
|
||||
# stubbed out.
|
||||
device_id = user_info.get("device_id")
|
||||
|
||||
ip_addr = self.hs.get_ip_from_request(request)
|
||||
user_agent = request.requestHeaders.getRawHeaders(
|
||||
"User-Agent",
|
||||
@@ -569,7 +602,8 @@ class Auth(object):
|
||||
user=user,
|
||||
access_token=access_token,
|
||||
ip=ip_addr,
|
||||
user_agent=user_agent
|
||||
user_agent=user_agent,
|
||||
device_id=device_id,
|
||||
)
|
||||
|
||||
if is_guest and not allow_guest:
|
||||
@@ -579,7 +613,8 @@ class Auth(object):
|
||||
|
||||
request.authenticated_entity = user.to_string()
|
||||
|
||||
defer.returnValue(Requester(user, token_id, is_guest))
|
||||
defer.returnValue(synapse.types.create_requester(
|
||||
user, token_id, is_guest, device_id))
|
||||
except KeyError:
|
||||
raise AuthError(
|
||||
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
|
||||
@@ -629,7 +664,10 @@ class Auth(object):
|
||||
except AuthError:
|
||||
# TODO(daniel): Remove this fallback when all existing access tokens
|
||||
# have been re-issued as macaroons.
|
||||
if self.hs.config.expire_access_token:
|
||||
raise
|
||||
ret = yield self._look_up_user_by_access_token(token)
|
||||
|
||||
defer.returnValue(ret)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@@ -637,17 +675,22 @@ class Auth(object):
|
||||
try:
|
||||
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
|
||||
|
||||
self.validate_macaroon(macaroon, rights, self.hs.config.expire_access_token)
|
||||
|
||||
user_prefix = "user_id = "
|
||||
user = None
|
||||
user_id = None
|
||||
guest = False
|
||||
for caveat in macaroon.caveats:
|
||||
if caveat.caveat_id.startswith(user_prefix):
|
||||
user = UserID.from_string(caveat.caveat_id[len(user_prefix):])
|
||||
user_id = caveat.caveat_id[len(user_prefix):]
|
||||
user = UserID.from_string(user_id)
|
||||
elif caveat.caveat_id == "guest = true":
|
||||
guest = True
|
||||
|
||||
self.validate_macaroon(
|
||||
macaroon, rights, self.hs.config.expire_access_token,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
if user is None:
|
||||
raise AuthError(
|
||||
self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
|
||||
@@ -659,6 +702,7 @@ class Auth(object):
|
||||
"user": user,
|
||||
"is_guest": True,
|
||||
"token_id": None,
|
||||
"device_id": None,
|
||||
}
|
||||
elif rights == "delete_pusher":
|
||||
# We don't store these tokens in the database
|
||||
@@ -666,13 +710,20 @@ class Auth(object):
|
||||
"user": user,
|
||||
"is_guest": False,
|
||||
"token_id": None,
|
||||
"device_id": None,
|
||||
}
|
||||
else:
|
||||
# This codepath exists so that we can actually return a
|
||||
# token ID, because we use token IDs in place of device
|
||||
# identifiers throughout the codebase.
|
||||
# TODO(daniel): Remove this fallback when device IDs are
|
||||
# properly implemented.
|
||||
# This codepath exists for several reasons:
|
||||
# * so that we can actually return a token ID, which is used
|
||||
# in some parts of the schema (where we probably ought to
|
||||
# use device IDs instead)
|
||||
# * the only way we currently have to invalidate an
|
||||
# access_token is by removing it from the database, so we
|
||||
# have to check here that it is still in the db
|
||||
# * some attributes (notably device_id) aren't stored in the
|
||||
# macaroon. They probably should be.
|
||||
# TODO: build the dictionary from the macaroon once the
|
||||
# above are fixed
|
||||
ret = yield self._look_up_user_by_access_token(macaroon_str)
|
||||
if ret["user"] != user:
|
||||
logger.error(
|
||||
@@ -692,7 +743,7 @@ class Auth(object):
|
||||
errcode=Codes.UNKNOWN_TOKEN
|
||||
)
|
||||
|
||||
def validate_macaroon(self, macaroon, type_string, verify_expiry):
|
||||
def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id):
|
||||
"""
|
||||
validate that a Macaroon is understood by and was signed by this server.
|
||||
|
||||
@@ -707,7 +758,7 @@ class Auth(object):
|
||||
v = pymacaroons.Verifier()
|
||||
v.satisfy_exact("gen = 1")
|
||||
v.satisfy_exact("type = " + type_string)
|
||||
v.satisfy_general(lambda c: c.startswith("user_id = "))
|
||||
v.satisfy_exact("user_id = %s" % user_id)
|
||||
v.satisfy_exact("guest = true")
|
||||
if verify_expiry:
|
||||
v.satisfy_general(self._verify_expiry)
|
||||
@@ -746,10 +797,14 @@ class Auth(object):
|
||||
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
|
||||
errcode=Codes.UNKNOWN_TOKEN
|
||||
)
|
||||
# we use ret.get() below because *lots* of unit tests stub out
|
||||
# get_user_by_access_token in a way where it only returns a couple of
|
||||
# the fields.
|
||||
user_info = {
|
||||
"user": UserID.from_string(ret.get("name")),
|
||||
"token_id": ret.get("token_id", None),
|
||||
"is_guest": False,
|
||||
"device_id": ret.get("device_id"),
|
||||
}
|
||||
defer.returnValue(user_info)
|
||||
|
||||
|
||||
@@ -42,8 +42,10 @@ class Codes(object):
|
||||
TOO_LARGE = "M_TOO_LARGE"
|
||||
EXCLUSIVE = "M_EXCLUSIVE"
|
||||
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
|
||||
THREEPID_IN_USE = "THREEPID_IN_USE"
|
||||
THREEPID_IN_USE = "M_THREEPID_IN_USE"
|
||||
THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND"
|
||||
INVALID_USERNAME = "M_INVALID_USERNAME"
|
||||
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
|
||||
|
||||
|
||||
class CodeMessageException(RuntimeError):
|
||||
|
||||
@@ -191,6 +191,17 @@ class Filter(object):
|
||||
def __init__(self, filter_json):
|
||||
self.filter_json = filter_json
|
||||
|
||||
self.types = self.filter_json.get("types", None)
|
||||
self.not_types = self.filter_json.get("not_types", [])
|
||||
|
||||
self.rooms = self.filter_json.get("rooms", None)
|
||||
self.not_rooms = self.filter_json.get("not_rooms", [])
|
||||
|
||||
self.senders = self.filter_json.get("senders", None)
|
||||
self.not_senders = self.filter_json.get("not_senders", [])
|
||||
|
||||
self.contains_url = self.filter_json.get("contains_url", None)
|
||||
|
||||
def check(self, event):
|
||||
"""Checks whether the filter matches the given event.
|
||||
|
||||
@@ -209,9 +220,10 @@ class Filter(object):
|
||||
event.get("room_id", None),
|
||||
sender,
|
||||
event.get("type", None),
|
||||
"url" in event.get("content", {})
|
||||
)
|
||||
|
||||
def check_fields(self, room_id, sender, event_type):
|
||||
def check_fields(self, room_id, sender, event_type, contains_url):
|
||||
"""Checks whether the filter matches the given event fields.
|
||||
|
||||
Returns:
|
||||
@@ -225,15 +237,20 @@ class Filter(object):
|
||||
|
||||
for name, match_func in literal_keys.items():
|
||||
not_name = "not_%s" % (name,)
|
||||
disallowed_values = self.filter_json.get(not_name, [])
|
||||
disallowed_values = getattr(self, not_name)
|
||||
if any(map(match_func, disallowed_values)):
|
||||
return False
|
||||
|
||||
allowed_values = self.filter_json.get(name, None)
|
||||
allowed_values = getattr(self, name)
|
||||
if allowed_values is not None:
|
||||
if not any(map(match_func, allowed_values)):
|
||||
return False
|
||||
|
||||
contains_url_filter = self.filter_json.get("contains_url")
|
||||
if contains_url_filter is not None:
|
||||
if contains_url_filter != contains_url:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def filter_rooms(self, room_ids):
|
||||
|
||||
@@ -16,13 +16,11 @@
|
||||
import sys
|
||||
sys.dont_write_bytecode = True
|
||||
|
||||
from synapse.python_dependencies import (
|
||||
check_requirements, MissingRequirementError
|
||||
) # NOQA
|
||||
from synapse import python_dependencies # noqa: E402
|
||||
|
||||
try:
|
||||
check_requirements()
|
||||
except MissingRequirementError as e:
|
||||
python_dependencies.check_requirements()
|
||||
except python_dependencies.MissingRequirementError as e:
|
||||
message = "\n".join([
|
||||
"Missing Requirement: %s" % (e.message,),
|
||||
"To install run:",
|
||||
|
||||
206
synapse/app/federation_reader.py
Normal file
206
synapse/app/federation_reader.py
Normal file
@@ -0,0 +1,206 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import synapse
|
||||
|
||||
from synapse.config._base import ConfigError
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.config.logger import setup_logging
|
||||
from synapse.http.site import SynapseSite
|
||||
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||
from synapse.replication.slave.storage.events import SlavedEventStore
|
||||
from synapse.replication.slave.storage.keys import SlavedKeyStore
|
||||
from synapse.replication.slave.storage.room import RoomStore
|
||||
from synapse.replication.slave.storage.transactions import TransactionStore
|
||||
from synapse.replication.slave.storage.directory import DirectoryStore
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.engines import create_engine
|
||||
from synapse.util.async import sleep
|
||||
from synapse.util.httpresourcetree import create_resource_tree
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
from synapse.util.manhole import manhole
|
||||
from synapse.util.rlimit import change_resource_limit
|
||||
from synapse.util.versionstring import get_version_string
|
||||
from synapse.api.urls import FEDERATION_PREFIX
|
||||
from synapse.federation.transport.server import TransportLayerServer
|
||||
from synapse.crypto import context_factory
|
||||
|
||||
|
||||
from twisted.internet import reactor, defer
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
from daemonize import Daemonize
|
||||
|
||||
import sys
|
||||
import logging
|
||||
import gc
|
||||
|
||||
logger = logging.getLogger("synapse.app.federation_reader")
|
||||
|
||||
|
||||
class FederationReaderSlavedStore(
|
||||
SlavedEventStore,
|
||||
SlavedKeyStore,
|
||||
RoomStore,
|
||||
DirectoryStore,
|
||||
TransactionStore,
|
||||
BaseSlavedStore,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class FederationReaderServer(HomeServer):
|
||||
def get_db_conn(self, run_new_connection=True):
|
||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||
# not be passed to the database engine.
|
||||
db_params = {
|
||||
k: v for k, v in self.db_config.get("args", {}).items()
|
||||
if not k.startswith("cp_")
|
||||
}
|
||||
db_conn = self.database_engine.module.connect(**db_params)
|
||||
|
||||
if run_new_connection:
|
||||
self.database_engine.on_new_connection(db_conn)
|
||||
return db_conn
|
||||
|
||||
def setup(self):
|
||||
logger.info("Setting up.")
|
||||
self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self)
|
||||
logger.info("Finished setting up.")
|
||||
|
||||
def _listen_http(self, listener_config):
|
||||
port = listener_config["port"]
|
||||
bind_address = listener_config.get("bind_address", "")
|
||||
site_tag = listener_config.get("tag", port)
|
||||
resources = {}
|
||||
for res in listener_config["resources"]:
|
||||
for name in res["names"]:
|
||||
if name == "metrics":
|
||||
resources[METRICS_PREFIX] = MetricsResource(self)
|
||||
elif name == "federation":
|
||||
resources.update({
|
||||
FEDERATION_PREFIX: TransportLayerServer(self),
|
||||
})
|
||||
|
||||
root_resource = create_resource_tree(resources, Resource())
|
||||
reactor.listenTCP(
|
||||
port,
|
||||
SynapseSite(
|
||||
"synapse.access.http.%s" % (site_tag,),
|
||||
site_tag,
|
||||
listener_config,
|
||||
root_resource,
|
||||
),
|
||||
interface=bind_address
|
||||
)
|
||||
logger.info("Synapse federation reader now listening on port %d", port)
|
||||
|
||||
def start_listening(self, listeners):
|
||||
for listener in listeners:
|
||||
if listener["type"] == "http":
|
||||
self._listen_http(listener)
|
||||
elif listener["type"] == "manhole":
|
||||
reactor.listenTCP(
|
||||
listener["port"],
|
||||
manhole(
|
||||
username="matrix",
|
||||
password="rabbithole",
|
||||
globals={"hs": self},
|
||||
),
|
||||
interface=listener.get("bind_address", '127.0.0.1')
|
||||
)
|
||||
else:
|
||||
logger.warn("Unrecognized listener type: %s", listener["type"])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def replicate(self):
|
||||
http_client = self.get_simple_http_client()
|
||||
store = self.get_datastore()
|
||||
replication_url = self.config.worker_replication_url
|
||||
|
||||
while True:
|
||||
try:
|
||||
args = store.stream_positions()
|
||||
args["timeout"] = 30000
|
||||
result = yield http_client.get_json(replication_url, args=args)
|
||||
yield store.process_replication(result)
|
||||
except:
|
||||
logger.exception("Error replicating from %r", replication_url)
|
||||
yield sleep(5)
|
||||
|
||||
|
||||
def start(config_options):
|
||||
try:
|
||||
config = HomeServerConfig.load_config(
|
||||
"Synapse federation reader", config_options
|
||||
)
|
||||
except ConfigError as e:
|
||||
sys.stderr.write("\n" + e.message + "\n")
|
||||
sys.exit(1)
|
||||
|
||||
assert config.worker_app == "synapse.app.federation_reader"
|
||||
|
||||
setup_logging(config.worker_log_config, config.worker_log_file)
|
||||
|
||||
database_engine = create_engine(config.database_config)
|
||||
|
||||
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
||||
|
||||
ss = FederationReaderServer(
|
||||
config.server_name,
|
||||
db_config=config.database_config,
|
||||
tls_server_context_factory=tls_server_context_factory,
|
||||
config=config,
|
||||
version_string="Synapse/" + get_version_string(synapse),
|
||||
database_engine=database_engine,
|
||||
)
|
||||
|
||||
ss.setup()
|
||||
ss.get_handlers()
|
||||
ss.start_listening(config.worker_listeners)
|
||||
|
||||
def run():
|
||||
with LoggingContext("run"):
|
||||
logger.info("Running")
|
||||
change_resource_limit(config.soft_file_limit)
|
||||
if config.gc_thresholds:
|
||||
gc.set_threshold(*config.gc_thresholds)
|
||||
reactor.run()
|
||||
|
||||
def start():
|
||||
ss.get_datastore().start_profiling()
|
||||
ss.replicate()
|
||||
|
||||
reactor.callWhenRunning(start)
|
||||
|
||||
if config.worker_daemonize:
|
||||
daemon = Daemonize(
|
||||
app="synapse-federation-reader",
|
||||
pid=config.worker_pid_file,
|
||||
action=run,
|
||||
auto_close_fds=False,
|
||||
verbose=True,
|
||||
logger=logger,
|
||||
)
|
||||
daemon.start()
|
||||
else:
|
||||
run()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
with LoggingContext("main"):
|
||||
start(sys.argv[1:])
|
||||
@@ -51,6 +51,7 @@ from synapse.api.urls import (
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.crypto import context_factory
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
from synapse.metrics import register_memory_metrics
|
||||
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
||||
from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX
|
||||
from synapse.federation.transport.server import TransportLayerServer
|
||||
@@ -147,7 +148,7 @@ class SynapseHomeServer(HomeServer):
|
||||
MEDIA_PREFIX: media_repo,
|
||||
LEGACY_MEDIA_PREFIX: media_repo,
|
||||
CONTENT_REPO_PREFIX: ContentRepoResource(
|
||||
self, self.config.uploads_path, self.auth, self.content_addr
|
||||
self, self.config.uploads_path
|
||||
),
|
||||
})
|
||||
|
||||
@@ -266,10 +267,9 @@ def setup(config_options):
|
||||
HomeServer
|
||||
"""
|
||||
try:
|
||||
config = HomeServerConfig.load_config(
|
||||
config = HomeServerConfig.load_or_generate_config(
|
||||
"Synapse Homeserver",
|
||||
config_options,
|
||||
generate_section="Homeserver"
|
||||
)
|
||||
except ConfigError as e:
|
||||
sys.stderr.write("\n" + e.message + "\n")
|
||||
@@ -285,7 +285,7 @@ def setup(config_options):
|
||||
# check any extra requirements we have now we have a config
|
||||
check_requirements(config)
|
||||
|
||||
version_string = get_version_string("Synapse", synapse)
|
||||
version_string = "Synapse/" + get_version_string(synapse)
|
||||
|
||||
logger.info("Server hostname: %s", config.server_name)
|
||||
logger.info("Server version: %s", version_string)
|
||||
@@ -302,7 +302,6 @@ def setup(config_options):
|
||||
db_config=config.database_config,
|
||||
tls_server_context_factory=tls_server_context_factory,
|
||||
config=config,
|
||||
content_addr=config.content_addr,
|
||||
version_string=version_string,
|
||||
database_engine=database_engine,
|
||||
)
|
||||
@@ -337,6 +336,8 @@ def setup(config_options):
|
||||
hs.get_datastore().start_doing_background_updates()
|
||||
hs.get_replication_layer().start_get_pdu_cache()
|
||||
|
||||
register_memory_metrics(hs)
|
||||
|
||||
reactor.callWhenRunning(start)
|
||||
|
||||
return hs
|
||||
|
||||
@@ -18,10 +18,8 @@ import synapse
|
||||
|
||||
from synapse.server import HomeServer
|
||||
from synapse.config._base import ConfigError
|
||||
from synapse.config.database import DatabaseConfig
|
||||
from synapse.config.logger import LoggingConfig
|
||||
from synapse.config.emailconfig import EmailConfig
|
||||
from synapse.config.key import KeyConfig
|
||||
from synapse.config.logger import setup_logging
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.http.site import SynapseSite
|
||||
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
||||
from synapse.storage.roommember import RoomMemberStore
|
||||
@@ -43,98 +41,13 @@ from twisted.web.resource import Resource
|
||||
|
||||
from daemonize import Daemonize
|
||||
|
||||
import gc
|
||||
import sys
|
||||
import logging
|
||||
import gc
|
||||
|
||||
logger = logging.getLogger("synapse.app.pusher")
|
||||
|
||||
|
||||
class SlaveConfig(DatabaseConfig):
|
||||
def read_config(self, config):
|
||||
self.replication_url = config["replication_url"]
|
||||
self.server_name = config["server_name"]
|
||||
self.use_insecure_ssl_client_just_for_testing_do_not_use = config.get(
|
||||
"use_insecure_ssl_client_just_for_testing_do_not_use", False
|
||||
)
|
||||
self.user_agent_suffix = None
|
||||
self.start_pushers = True
|
||||
self.listeners = config["listeners"]
|
||||
self.soft_file_limit = config.get("soft_file_limit")
|
||||
self.daemonize = config.get("daemonize")
|
||||
self.pid_file = self.abspath(config.get("pid_file"))
|
||||
self.public_baseurl = config["public_baseurl"]
|
||||
|
||||
thresholds = config.get("gc_thresholds", None)
|
||||
if thresholds is not None:
|
||||
try:
|
||||
assert len(thresholds) == 3
|
||||
self.gc_thresholds = (
|
||||
int(thresholds[0]), int(thresholds[1]), int(thresholds[2]),
|
||||
)
|
||||
except:
|
||||
raise ConfigError(
|
||||
"Value of `gc_threshold` must be a list of three integers if set"
|
||||
)
|
||||
else:
|
||||
self.gc_thresholds = None
|
||||
|
||||
# some things used by the auth handler but not actually used in the
|
||||
# pusher codebase
|
||||
self.bcrypt_rounds = None
|
||||
self.ldap_enabled = None
|
||||
self.ldap_server = None
|
||||
self.ldap_port = None
|
||||
self.ldap_tls = None
|
||||
self.ldap_search_base = None
|
||||
self.ldap_search_property = None
|
||||
self.ldap_email_property = None
|
||||
self.ldap_full_name_property = None
|
||||
|
||||
# We would otherwise try to use the registration shared secret as the
|
||||
# macaroon shared secret if there was no macaroon_shared_secret, but
|
||||
# that means pulling in RegistrationConfig too. We don't need to be
|
||||
# backwards compaitible in the pusher codebase so just make people set
|
||||
# macaroon_shared_secret. We set this to None to prevent it referencing
|
||||
# an undefined key.
|
||||
self.registration_shared_secret = None
|
||||
|
||||
def default_config(self, server_name, **kwargs):
|
||||
pid_file = self.abspath("pusher.pid")
|
||||
return """\
|
||||
# Slave configuration
|
||||
|
||||
# The replication listener on the synapse to talk to.
|
||||
#replication_url: https://localhost:{replication_port}/_synapse/replication
|
||||
|
||||
server_name: "%(server_name)s"
|
||||
|
||||
listeners: []
|
||||
# Enable a ssh manhole listener on the pusher.
|
||||
# - type: manhole
|
||||
# port: {manhole_port}
|
||||
# bind_address: 127.0.0.1
|
||||
# Enable a metric listener on the pusher.
|
||||
# - type: http
|
||||
# port: {metrics_port}
|
||||
# bind_address: 127.0.0.1
|
||||
# resources:
|
||||
# - names: ["metrics"]
|
||||
# compress: False
|
||||
|
||||
report_stats: False
|
||||
|
||||
daemonize: False
|
||||
|
||||
pid_file: %(pid_file)s
|
||||
|
||||
""" % locals()
|
||||
|
||||
|
||||
class PusherSlaveConfig(SlaveConfig, LoggingConfig, EmailConfig, KeyConfig):
|
||||
pass
|
||||
|
||||
|
||||
class PusherSlaveStore(
|
||||
SlavedEventStore, SlavedPusherStore, SlavedReceiptsStore,
|
||||
SlavedAccountDataStore
|
||||
@@ -199,7 +112,7 @@ class PusherServer(HomeServer):
|
||||
|
||||
def remove_pusher(self, app_id, push_key, user_id):
|
||||
http_client = self.get_simple_http_client()
|
||||
replication_url = self.config.replication_url
|
||||
replication_url = self.config.worker_replication_url
|
||||
url = replication_url + "/remove_pushers"
|
||||
return http_client.post_json_get_json(url, {
|
||||
"remove": [{
|
||||
@@ -232,8 +145,8 @@ class PusherServer(HomeServer):
|
||||
)
|
||||
logger.info("Synapse pusher now listening on port %d", port)
|
||||
|
||||
def start_listening(self):
|
||||
for listener in self.config.listeners:
|
||||
def start_listening(self, listeners):
|
||||
for listener in listeners:
|
||||
if listener["type"] == "http":
|
||||
self._listen_http(listener)
|
||||
elif listener["type"] == "manhole":
|
||||
@@ -253,7 +166,7 @@ class PusherServer(HomeServer):
|
||||
def replicate(self):
|
||||
http_client = self.get_simple_http_client()
|
||||
store = self.get_datastore()
|
||||
replication_url = self.config.replication_url
|
||||
replication_url = self.config.worker_replication_url
|
||||
pusher_pool = self.get_pusherpool()
|
||||
clock = self.get_clock()
|
||||
|
||||
@@ -329,19 +242,30 @@ class PusherServer(HomeServer):
|
||||
yield sleep(30)
|
||||
|
||||
|
||||
def setup(config_options):
|
||||
def start(config_options):
|
||||
try:
|
||||
config = PusherSlaveConfig.load_config(
|
||||
config = HomeServerConfig.load_config(
|
||||
"Synapse pusher", config_options
|
||||
)
|
||||
except ConfigError as e:
|
||||
sys.stderr.write("\n" + e.message + "\n")
|
||||
sys.exit(1)
|
||||
|
||||
if not config:
|
||||
sys.exit(0)
|
||||
assert config.worker_app == "synapse.app.pusher"
|
||||
|
||||
config.setup_logging()
|
||||
setup_logging(config.worker_log_config, config.worker_log_file)
|
||||
|
||||
if config.start_pushers:
|
||||
sys.stderr.write(
|
||||
"\nThe pushers must be disabled in the main synapse process"
|
||||
"\nbefore they can be run in a separate worker."
|
||||
"\nPlease add ``start_pushers: false`` to the main config"
|
||||
"\n"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Force the pushers to start since they will be disabled in the main config
|
||||
config.start_pushers = True
|
||||
|
||||
database_engine = create_engine(config.database_config)
|
||||
|
||||
@@ -349,16 +273,20 @@ def setup(config_options):
|
||||
config.server_name,
|
||||
db_config=config.database_config,
|
||||
config=config,
|
||||
version_string=get_version_string("Synapse", synapse),
|
||||
version_string="Synapse/" + get_version_string(synapse),
|
||||
database_engine=database_engine,
|
||||
)
|
||||
|
||||
ps.setup()
|
||||
ps.start_listening()
|
||||
ps.start_listening(config.worker_listeners)
|
||||
|
||||
change_resource_limit(ps.config.soft_file_limit)
|
||||
if ps.config.gc_thresholds:
|
||||
gc.set_threshold(*ps.config.gc_thresholds)
|
||||
def run():
|
||||
with LoggingContext("run"):
|
||||
logger.info("Running")
|
||||
change_resource_limit(config.soft_file_limit)
|
||||
if config.gc_thresholds:
|
||||
gc.set_threshold(*config.gc_thresholds)
|
||||
reactor.run()
|
||||
|
||||
def start():
|
||||
ps.replicate()
|
||||
@@ -367,30 +295,20 @@ def setup(config_options):
|
||||
|
||||
reactor.callWhenRunning(start)
|
||||
|
||||
return ps
|
||||
if config.worker_daemonize:
|
||||
daemon = Daemonize(
|
||||
app="synapse-pusher",
|
||||
pid=config.worker_pid_file,
|
||||
action=run,
|
||||
auto_close_fds=False,
|
||||
verbose=True,
|
||||
logger=logger,
|
||||
)
|
||||
daemon.start()
|
||||
else:
|
||||
run()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
with LoggingContext("main"):
|
||||
ps = setup(sys.argv[1:])
|
||||
|
||||
if ps.config.daemonize:
|
||||
def run():
|
||||
with LoggingContext("run"):
|
||||
change_resource_limit(ps.config.soft_file_limit)
|
||||
if ps.config.gc_thresholds:
|
||||
gc.set_threshold(*ps.config.gc_thresholds)
|
||||
reactor.run()
|
||||
|
||||
daemon = Daemonize(
|
||||
app="synapse-pusher",
|
||||
pid=ps.config.pid_file,
|
||||
action=run,
|
||||
auto_close_fds=False,
|
||||
verbose=True,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
daemon.start()
|
||||
else:
|
||||
reactor.run()
|
||||
ps = start(sys.argv[1:])
|
||||
|
||||
@@ -18,9 +18,8 @@ import synapse
|
||||
|
||||
from synapse.api.constants import EventTypes, PresenceState
|
||||
from synapse.config._base import ConfigError
|
||||
from synapse.config.database import DatabaseConfig
|
||||
from synapse.config.logger import LoggingConfig
|
||||
from synapse.config.appservice import AppServiceConfig
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.config.logger import setup_logging
|
||||
from synapse.events import FrozenEvent
|
||||
from synapse.handlers.presence import PresenceHandler
|
||||
from synapse.http.site import SynapseSite
|
||||
@@ -63,70 +62,6 @@ import ujson as json
|
||||
logger = logging.getLogger("synapse.app.synchrotron")
|
||||
|
||||
|
||||
class SynchrotronConfig(DatabaseConfig, LoggingConfig, AppServiceConfig):
|
||||
def read_config(self, config):
|
||||
self.replication_url = config["replication_url"]
|
||||
self.server_name = config["server_name"]
|
||||
self.use_insecure_ssl_client_just_for_testing_do_not_use = config.get(
|
||||
"use_insecure_ssl_client_just_for_testing_do_not_use", False
|
||||
)
|
||||
self.user_agent_suffix = None
|
||||
self.listeners = config["listeners"]
|
||||
self.soft_file_limit = config.get("soft_file_limit")
|
||||
self.daemonize = config.get("daemonize")
|
||||
self.pid_file = self.abspath(config.get("pid_file"))
|
||||
self.macaroon_secret_key = config["macaroon_secret_key"]
|
||||
self.expire_access_token = config.get("expire_access_token", False)
|
||||
|
||||
thresholds = config.get("gc_thresholds", None)
|
||||
if thresholds is not None:
|
||||
try:
|
||||
assert len(thresholds) == 3
|
||||
self.gc_thresholds = (
|
||||
int(thresholds[0]), int(thresholds[1]), int(thresholds[2]),
|
||||
)
|
||||
except:
|
||||
raise ConfigError(
|
||||
"Value of `gc_threshold` must be a list of three integers if set"
|
||||
)
|
||||
else:
|
||||
self.gc_thresholds = None
|
||||
|
||||
def default_config(self, server_name, **kwargs):
|
||||
pid_file = self.abspath("synchroton.pid")
|
||||
return """\
|
||||
# Slave configuration
|
||||
|
||||
# The replication listener on the synapse to talk to.
|
||||
#replication_url: https://localhost:{replication_port}/_synapse/replication
|
||||
|
||||
server_name: "%(server_name)s"
|
||||
|
||||
listeners:
|
||||
# Enable a /sync listener on the synchrontron
|
||||
#- type: http
|
||||
# port: {http_port}
|
||||
# bind_address: ""
|
||||
# Enable a ssh manhole listener on the synchrotron
|
||||
# - type: manhole
|
||||
# port: {manhole_port}
|
||||
# bind_address: 127.0.0.1
|
||||
# Enable a metric listener on the synchrotron
|
||||
# - type: http
|
||||
# port: {metrics_port}
|
||||
# bind_address: 127.0.0.1
|
||||
# resources:
|
||||
# - names: ["metrics"]
|
||||
# compress: False
|
||||
|
||||
report_stats: False
|
||||
|
||||
daemonize: False
|
||||
|
||||
pid_file: %(pid_file)s
|
||||
""" % locals()
|
||||
|
||||
|
||||
class SynchrotronSlavedStore(
|
||||
SlavedPushRuleStore,
|
||||
SlavedEventStore,
|
||||
@@ -163,7 +98,7 @@ class SynchrotronPresence(object):
|
||||
self.http_client = hs.get_simple_http_client()
|
||||
self.store = hs.get_datastore()
|
||||
self.user_to_num_current_syncs = {}
|
||||
self.syncing_users_url = hs.config.replication_url + "/syncing_users"
|
||||
self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users"
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
active_presence = self.store.take_presence_startup_info()
|
||||
@@ -350,8 +285,8 @@ class SynchrotronServer(HomeServer):
|
||||
)
|
||||
logger.info("Synapse synchrotron now listening on port %d", port)
|
||||
|
||||
def start_listening(self):
|
||||
for listener in self.config.listeners:
|
||||
def start_listening(self, listeners):
|
||||
for listener in listeners:
|
||||
if listener["type"] == "http":
|
||||
self._listen_http(listener)
|
||||
elif listener["type"] == "manhole":
|
||||
@@ -371,7 +306,7 @@ class SynchrotronServer(HomeServer):
|
||||
def replicate(self):
|
||||
http_client = self.get_simple_http_client()
|
||||
store = self.get_datastore()
|
||||
replication_url = self.config.replication_url
|
||||
replication_url = self.config.worker_replication_url
|
||||
clock = self.get_clock()
|
||||
notifier = self.get_notifier()
|
||||
presence_handler = self.get_presence_handler()
|
||||
@@ -470,19 +405,18 @@ class SynchrotronServer(HomeServer):
|
||||
return SynchrotronTyping(self)
|
||||
|
||||
|
||||
def setup(config_options):
|
||||
def start(config_options):
|
||||
try:
|
||||
config = SynchrotronConfig.load_config(
|
||||
config = HomeServerConfig.load_config(
|
||||
"Synapse synchrotron", config_options
|
||||
)
|
||||
except ConfigError as e:
|
||||
sys.stderr.write("\n" + e.message + "\n")
|
||||
sys.exit(1)
|
||||
|
||||
if not config:
|
||||
sys.exit(0)
|
||||
assert config.worker_app == "synapse.app.synchrotron"
|
||||
|
||||
config.setup_logging()
|
||||
setup_logging(config.worker_log_config, config.worker_log_file)
|
||||
|
||||
database_engine = create_engine(config.database_config)
|
||||
|
||||
@@ -490,17 +424,21 @@ def setup(config_options):
|
||||
config.server_name,
|
||||
db_config=config.database_config,
|
||||
config=config,
|
||||
version_string=get_version_string("Synapse", synapse),
|
||||
version_string="Synapse/" + get_version_string(synapse),
|
||||
database_engine=database_engine,
|
||||
application_service_handler=SynchrotronApplicationService(),
|
||||
)
|
||||
|
||||
ss.setup()
|
||||
ss.start_listening()
|
||||
ss.start_listening(config.worker_listeners)
|
||||
|
||||
change_resource_limit(ss.config.soft_file_limit)
|
||||
if ss.config.gc_thresholds:
|
||||
ss.set_threshold(*ss.config.gc_thresholds)
|
||||
def run():
|
||||
with LoggingContext("run"):
|
||||
logger.info("Running")
|
||||
change_resource_limit(config.soft_file_limit)
|
||||
if config.gc_thresholds:
|
||||
gc.set_threshold(*config.gc_thresholds)
|
||||
reactor.run()
|
||||
|
||||
def start():
|
||||
ss.get_datastore().start_profiling()
|
||||
@@ -508,30 +446,20 @@ def setup(config_options):
|
||||
|
||||
reactor.callWhenRunning(start)
|
||||
|
||||
return ss
|
||||
if config.worker_daemonize:
|
||||
daemon = Daemonize(
|
||||
app="synapse-synchrotron",
|
||||
pid=config.worker_pid_file,
|
||||
action=run,
|
||||
auto_close_fds=False,
|
||||
verbose=True,
|
||||
logger=logger,
|
||||
)
|
||||
daemon.start()
|
||||
else:
|
||||
run()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
with LoggingContext("main"):
|
||||
ss = setup(sys.argv[1:])
|
||||
|
||||
if ss.config.daemonize:
|
||||
def run():
|
||||
with LoggingContext("run"):
|
||||
change_resource_limit(ss.config.soft_file_limit)
|
||||
if ss.config.gc_thresholds:
|
||||
gc.set_threshold(*ss.config.gc_thresholds)
|
||||
reactor.run()
|
||||
|
||||
daemon = Daemonize(
|
||||
app="synapse-synchrotron",
|
||||
pid=ss.config.pid_file,
|
||||
action=run,
|
||||
auto_close_fds=False,
|
||||
verbose=True,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
daemon.start()
|
||||
else:
|
||||
reactor.run()
|
||||
start(sys.argv[1:])
|
||||
|
||||
@@ -14,11 +14,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
import collections
|
||||
import glob
|
||||
import os
|
||||
import os.path
|
||||
import subprocess
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import yaml
|
||||
|
||||
SYNAPSE = ["python", "-B", "-m", "synapse.app.homeserver"]
|
||||
@@ -28,60 +31,181 @@ RED = "\x1b[1;31m"
|
||||
NORMAL = "\x1b[m"
|
||||
|
||||
|
||||
def write(message, colour=NORMAL, stream=sys.stdout):
|
||||
if colour == NORMAL:
|
||||
stream.write(message + "\n")
|
||||
else:
|
||||
stream.write(colour + message + NORMAL + "\n")
|
||||
|
||||
|
||||
def start(configfile):
|
||||
print ("Starting ...")
|
||||
write("Starting ...")
|
||||
args = SYNAPSE
|
||||
args.extend(["--daemonize", "-c", configfile])
|
||||
|
||||
try:
|
||||
subprocess.check_call(args)
|
||||
print (GREEN + "started" + NORMAL)
|
||||
write("started synapse.app.homeserver(%r)" % (configfile,), colour=GREEN)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print (
|
||||
RED +
|
||||
"error starting (exit code: %d); see above for logs" % e.returncode +
|
||||
NORMAL
|
||||
write(
|
||||
"error starting (exit code: %d); see above for logs" % e.returncode,
|
||||
colour=RED,
|
||||
)
|
||||
|
||||
|
||||
def stop(pidfile):
|
||||
def start_worker(app, configfile, worker_configfile):
|
||||
args = [
|
||||
"python", "-B",
|
||||
"-m", app,
|
||||
"-c", configfile,
|
||||
"-c", worker_configfile
|
||||
]
|
||||
|
||||
try:
|
||||
subprocess.check_call(args)
|
||||
write("started %s(%r)" % (app, worker_configfile), colour=GREEN)
|
||||
except subprocess.CalledProcessError as e:
|
||||
write(
|
||||
"error starting %s(%r) (exit code: %d); see above for logs" % (
|
||||
app, worker_configfile, e.returncode,
|
||||
),
|
||||
colour=RED,
|
||||
)
|
||||
|
||||
|
||||
def stop(pidfile, app):
|
||||
if os.path.exists(pidfile):
|
||||
pid = int(open(pidfile).read())
|
||||
os.kill(pid, signal.SIGTERM)
|
||||
print (GREEN + "stopped" + NORMAL)
|
||||
write("stopped %s" % (app,), colour=GREEN)
|
||||
|
||||
|
||||
Worker = collections.namedtuple("Worker", [
|
||||
"app", "configfile", "pidfile", "cache_factor"
|
||||
])
|
||||
|
||||
|
||||
def main():
|
||||
configfile = sys.argv[2] if len(sys.argv) == 3 else "homeserver.yaml"
|
||||
|
||||
if not os.path.exists(configfile):
|
||||
sys.stderr.write(
|
||||
"No config file found\n"
|
||||
"To generate a config file, run '%s -c %s --generate-config"
|
||||
" --server-name=<server name>'\n" % (
|
||||
" ".join(SYNAPSE), configfile
|
||||
)
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"action",
|
||||
choices=["start", "stop", "restart"],
|
||||
help="whether to start, stop or restart the synapse",
|
||||
)
|
||||
parser.add_argument(
|
||||
"configfile",
|
||||
nargs="?",
|
||||
default="homeserver.yaml",
|
||||
help="the homeserver config file, defaults to homserver.yaml",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-w", "--worker",
|
||||
metavar="WORKERCONFIG",
|
||||
help="start or stop a single worker",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-a", "--all-processes",
|
||||
metavar="WORKERCONFIGDIR",
|
||||
help="start or stop all the workers in the given directory"
|
||||
" and the main synapse process",
|
||||
)
|
||||
|
||||
options = parser.parse_args()
|
||||
|
||||
if options.worker and options.all_processes:
|
||||
write(
|
||||
'Cannot use "--worker" with "--all-processes"',
|
||||
stream=sys.stderr
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
config = yaml.load(open(configfile))
|
||||
configfile = options.configfile
|
||||
|
||||
if not os.path.exists(configfile):
|
||||
write(
|
||||
"No config file found\n"
|
||||
"To generate a config file, run '%s -c %s --generate-config"
|
||||
" --server-name=<server name>'\n" % (
|
||||
" ".join(SYNAPSE), options.configfile
|
||||
),
|
||||
stream=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
with open(configfile) as stream:
|
||||
config = yaml.load(stream)
|
||||
|
||||
pidfile = config["pid_file"]
|
||||
cache_factor = config.get("synctl_cache_factor", None)
|
||||
cache_factor = config.get("synctl_cache_factor")
|
||||
start_stop_synapse = True
|
||||
|
||||
if cache_factor:
|
||||
os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor)
|
||||
|
||||
action = sys.argv[1] if sys.argv[1:] else "usage"
|
||||
if action == "start":
|
||||
start(configfile)
|
||||
elif action == "stop":
|
||||
stop(pidfile)
|
||||
elif action == "restart":
|
||||
stop(pidfile)
|
||||
start(configfile)
|
||||
else:
|
||||
sys.stderr.write("Usage: %s [start|stop|restart] [configfile]\n" % (sys.argv[0],))
|
||||
sys.exit(1)
|
||||
worker_configfiles = []
|
||||
if options.worker:
|
||||
start_stop_synapse = False
|
||||
worker_configfile = options.worker
|
||||
if not os.path.exists(worker_configfile):
|
||||
write(
|
||||
"No worker config found at %r" % (worker_configfile,),
|
||||
stream=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
worker_configfiles.append(worker_configfile)
|
||||
|
||||
if options.all_processes:
|
||||
worker_configdir = options.all_processes
|
||||
if not os.path.isdir(worker_configdir):
|
||||
write(
|
||||
"No worker config directory found at %r" % (worker_configdir,),
|
||||
stream=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
worker_configfiles.extend(sorted(glob.glob(
|
||||
os.path.join(worker_configdir, "*.yaml")
|
||||
)))
|
||||
|
||||
workers = []
|
||||
for worker_configfile in worker_configfiles:
|
||||
with open(worker_configfile) as stream:
|
||||
worker_config = yaml.load(stream)
|
||||
worker_app = worker_config["worker_app"]
|
||||
worker_pidfile = worker_config["worker_pid_file"]
|
||||
worker_daemonize = worker_config["worker_daemonize"]
|
||||
assert worker_daemonize # TODO print something more user friendly
|
||||
worker_cache_factor = worker_config.get("synctl_cache_factor")
|
||||
workers.append(Worker(
|
||||
worker_app, worker_configfile, worker_pidfile, worker_cache_factor,
|
||||
))
|
||||
|
||||
action = options.action
|
||||
|
||||
if action == "stop" or action == "restart":
|
||||
for worker in workers:
|
||||
stop(worker.pidfile, worker.app)
|
||||
|
||||
if start_stop_synapse:
|
||||
stop(pidfile, "synapse.app.homeserver")
|
||||
|
||||
# TODO: Wait for synapse to actually shutdown before starting it again
|
||||
|
||||
if action == "start" or action == "restart":
|
||||
if start_stop_synapse:
|
||||
start(configfile)
|
||||
|
||||
for worker in workers:
|
||||
if worker.cache_factor:
|
||||
os.environ["SYNAPSE_CACHE_FACTOR"] = str(worker.cache_factor)
|
||||
|
||||
start_worker(worker.app, configfile, worker.configfile)
|
||||
|
||||
if cache_factor:
|
||||
os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor)
|
||||
else:
|
||||
os.environ.pop("SYNAPSE_CACHE_FACTOR", None)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -157,9 +157,40 @@ class Config(object):
|
||||
return default_config, config
|
||||
|
||||
@classmethod
|
||||
def load_config(cls, description, argv, generate_section=None):
|
||||
obj = cls()
|
||||
def load_config(cls, description, argv):
|
||||
config_parser = argparse.ArgumentParser(
|
||||
description=description,
|
||||
)
|
||||
config_parser.add_argument(
|
||||
"-c", "--config-path",
|
||||
action="append",
|
||||
metavar="CONFIG_FILE",
|
||||
help="Specify config file. Can be given multiple times and"
|
||||
" may specify directories containing *.yaml files."
|
||||
)
|
||||
|
||||
config_parser.add_argument(
|
||||
"--keys-directory",
|
||||
metavar="DIRECTORY",
|
||||
help="Where files such as certs and signing keys are stored when"
|
||||
" their location is given explicitly in the config."
|
||||
" Defaults to the directory containing the last config file",
|
||||
)
|
||||
|
||||
config_args = config_parser.parse_args(argv)
|
||||
|
||||
config_files = find_config_files(search_paths=config_args.config_path)
|
||||
|
||||
obj = cls()
|
||||
obj.read_config_files(
|
||||
config_files,
|
||||
keys_directory=config_args.keys_directory,
|
||||
generate_keys=False,
|
||||
)
|
||||
return obj
|
||||
|
||||
@classmethod
|
||||
def load_or_generate_config(cls, description, argv):
|
||||
config_parser = argparse.ArgumentParser(add_help=False)
|
||||
config_parser.add_argument(
|
||||
"-c", "--config-path",
|
||||
@@ -176,7 +207,7 @@ class Config(object):
|
||||
config_parser.add_argument(
|
||||
"--report-stats",
|
||||
action="store",
|
||||
help="Stuff",
|
||||
help="Whether the generated config reports anonymized usage statistics",
|
||||
choices=["yes", "no"]
|
||||
)
|
||||
config_parser.add_argument(
|
||||
@@ -197,36 +228,11 @@ class Config(object):
|
||||
)
|
||||
config_args, remaining_args = config_parser.parse_known_args(argv)
|
||||
|
||||
config_files = find_config_files(search_paths=config_args.config_path)
|
||||
|
||||
generate_keys = config_args.generate_keys
|
||||
|
||||
config_files = []
|
||||
if config_args.config_path:
|
||||
for config_path in config_args.config_path:
|
||||
if os.path.isdir(config_path):
|
||||
# We accept specifying directories as config paths, we search
|
||||
# inside that directory for all files matching *.yaml, and then
|
||||
# we apply them in *sorted* order.
|
||||
files = []
|
||||
for entry in os.listdir(config_path):
|
||||
entry_path = os.path.join(config_path, entry)
|
||||
if not os.path.isfile(entry_path):
|
||||
print (
|
||||
"Found subdirectory in config directory: %r. IGNORING."
|
||||
) % (entry_path, )
|
||||
continue
|
||||
|
||||
if not entry.endswith(".yaml"):
|
||||
print (
|
||||
"Found file in config directory that does not"
|
||||
" end in '.yaml': %r. IGNORING."
|
||||
) % (entry_path, )
|
||||
continue
|
||||
|
||||
files.append(entry_path)
|
||||
|
||||
config_files.extend(sorted(files))
|
||||
else:
|
||||
config_files.append(config_path)
|
||||
obj = cls()
|
||||
|
||||
if config_args.generate_config:
|
||||
if config_args.report_stats is None:
|
||||
@@ -299,28 +305,43 @@ class Config(object):
|
||||
" -c CONFIG-FILE\""
|
||||
)
|
||||
|
||||
if config_args.keys_directory:
|
||||
config_dir_path = config_args.keys_directory
|
||||
else:
|
||||
config_dir_path = os.path.dirname(config_args.config_path[-1])
|
||||
config_dir_path = os.path.abspath(config_dir_path)
|
||||
obj.read_config_files(
|
||||
config_files,
|
||||
keys_directory=config_args.keys_directory,
|
||||
generate_keys=generate_keys,
|
||||
)
|
||||
|
||||
if generate_keys:
|
||||
return None
|
||||
|
||||
obj.invoke_all("read_arguments", args)
|
||||
|
||||
return obj
|
||||
|
||||
def read_config_files(self, config_files, keys_directory=None,
|
||||
generate_keys=False):
|
||||
if not keys_directory:
|
||||
keys_directory = os.path.dirname(config_files[-1])
|
||||
|
||||
config_dir_path = os.path.abspath(keys_directory)
|
||||
|
||||
specified_config = {}
|
||||
for config_file in config_files:
|
||||
yaml_config = cls.read_config_file(config_file)
|
||||
yaml_config = self.read_config_file(config_file)
|
||||
specified_config.update(yaml_config)
|
||||
|
||||
if "server_name" not in specified_config:
|
||||
raise ConfigError(MISSING_SERVER_NAME)
|
||||
|
||||
server_name = specified_config["server_name"]
|
||||
_, config = obj.generate_config(
|
||||
_, config = self.generate_config(
|
||||
config_dir_path=config_dir_path,
|
||||
server_name=server_name,
|
||||
is_generating_file=False,
|
||||
)
|
||||
config.pop("log_config")
|
||||
config.update(specified_config)
|
||||
|
||||
if "report_stats" not in config:
|
||||
raise ConfigError(
|
||||
MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" +
|
||||
@@ -328,11 +349,51 @@ class Config(object):
|
||||
)
|
||||
|
||||
if generate_keys:
|
||||
obj.invoke_all("generate_files", config)
|
||||
self.invoke_all("generate_files", config)
|
||||
return
|
||||
|
||||
obj.invoke_all("read_config", config)
|
||||
self.invoke_all("read_config", config)
|
||||
|
||||
obj.invoke_all("read_arguments", args)
|
||||
|
||||
return obj
|
||||
def find_config_files(search_paths):
|
||||
"""Finds config files using a list of search paths. If a path is a file
|
||||
then that file path is added to the list. If a search path is a directory
|
||||
then all the "*.yaml" files in that directory are added to the list in
|
||||
sorted order.
|
||||
|
||||
Args:
|
||||
search_paths(list(str)): A list of paths to search.
|
||||
|
||||
Returns:
|
||||
list(str): A list of file paths.
|
||||
"""
|
||||
|
||||
config_files = []
|
||||
if search_paths:
|
||||
for config_path in search_paths:
|
||||
if os.path.isdir(config_path):
|
||||
# We accept specifying directories as config paths, we search
|
||||
# inside that directory for all files matching *.yaml, and then
|
||||
# we apply them in *sorted* order.
|
||||
files = []
|
||||
for entry in os.listdir(config_path):
|
||||
entry_path = os.path.join(config_path, entry)
|
||||
if not os.path.isfile(entry_path):
|
||||
print (
|
||||
"Found subdirectory in config directory: %r. IGNORING."
|
||||
) % (entry_path, )
|
||||
continue
|
||||
|
||||
if not entry.endswith(".yaml"):
|
||||
print (
|
||||
"Found file in config directory that does not"
|
||||
" end in '.yaml': %r. IGNORING."
|
||||
) % (entry_path, )
|
||||
continue
|
||||
|
||||
files.append(entry_path)
|
||||
|
||||
config_files.extend(sorted(files))
|
||||
else:
|
||||
config_files.append(config_path)
|
||||
return config_files
|
||||
|
||||
@@ -27,6 +27,7 @@ class CaptchaConfig(Config):
|
||||
def default_config(self, **kwargs):
|
||||
return """\
|
||||
## Captcha ##
|
||||
# See docs/CAPTCHA_SETUP for full details of configuring this.
|
||||
|
||||
# This Home Server's ReCAPTCHA public key.
|
||||
recaptcha_public_key: "YOUR_PUBLIC_KEY"
|
||||
|
||||
@@ -32,13 +32,15 @@ from .password import PasswordConfig
|
||||
from .jwt import JWTConfig
|
||||
from .ldap import LDAPConfig
|
||||
from .emailconfig import EmailConfig
|
||||
from .workers import WorkerConfig
|
||||
|
||||
|
||||
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
|
||||
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
|
||||
VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
|
||||
AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
|
||||
JWTConfig, LDAPConfig, PasswordConfig, EmailConfig,):
|
||||
JWTConfig, LDAPConfig, PasswordConfig, EmailConfig,
|
||||
WorkerConfig,):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
@@ -13,40 +13,88 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import Config
|
||||
from ._base import Config, ConfigError
|
||||
|
||||
|
||||
MISSING_LDAP3 = (
|
||||
"Missing ldap3 library. This is required for LDAP Authentication."
|
||||
)
|
||||
|
||||
|
||||
class LDAPMode(object):
|
||||
SIMPLE = "simple",
|
||||
SEARCH = "search",
|
||||
|
||||
LIST = (SIMPLE, SEARCH)
|
||||
|
||||
|
||||
class LDAPConfig(Config):
|
||||
def read_config(self, config):
|
||||
ldap_config = config.get("ldap_config", None)
|
||||
if ldap_config:
|
||||
self.ldap_enabled = ldap_config.get("enabled", False)
|
||||
self.ldap_server = ldap_config["server"]
|
||||
self.ldap_port = ldap_config["port"]
|
||||
self.ldap_tls = ldap_config.get("tls", False)
|
||||
self.ldap_search_base = ldap_config["search_base"]
|
||||
self.ldap_search_property = ldap_config["search_property"]
|
||||
self.ldap_email_property = ldap_config["email_property"]
|
||||
self.ldap_full_name_property = ldap_config["full_name_property"]
|
||||
else:
|
||||
self.ldap_enabled = False
|
||||
self.ldap_server = None
|
||||
self.ldap_port = None
|
||||
self.ldap_tls = False
|
||||
self.ldap_search_base = None
|
||||
self.ldap_search_property = None
|
||||
self.ldap_email_property = None
|
||||
self.ldap_full_name_property = None
|
||||
ldap_config = config.get("ldap_config", {})
|
||||
|
||||
self.ldap_enabled = ldap_config.get("enabled", False)
|
||||
|
||||
if self.ldap_enabled:
|
||||
# verify dependencies are available
|
||||
try:
|
||||
import ldap3
|
||||
ldap3 # to stop unused lint
|
||||
except ImportError:
|
||||
raise ConfigError(MISSING_LDAP3)
|
||||
|
||||
self.ldap_mode = LDAPMode.SIMPLE
|
||||
|
||||
# verify config sanity
|
||||
self.require_keys(ldap_config, [
|
||||
"uri",
|
||||
"base",
|
||||
"attributes",
|
||||
])
|
||||
|
||||
self.ldap_uri = ldap_config["uri"]
|
||||
self.ldap_start_tls = ldap_config.get("start_tls", False)
|
||||
self.ldap_base = ldap_config["base"]
|
||||
self.ldap_attributes = ldap_config["attributes"]
|
||||
|
||||
if "bind_dn" in ldap_config:
|
||||
self.ldap_mode = LDAPMode.SEARCH
|
||||
self.require_keys(ldap_config, [
|
||||
"bind_dn",
|
||||
"bind_password",
|
||||
])
|
||||
|
||||
self.ldap_bind_dn = ldap_config["bind_dn"]
|
||||
self.ldap_bind_password = ldap_config["bind_password"]
|
||||
self.ldap_filter = ldap_config.get("filter", None)
|
||||
|
||||
# verify attribute lookup
|
||||
self.require_keys(ldap_config['attributes'], [
|
||||
"uid",
|
||||
"name",
|
||||
"mail",
|
||||
])
|
||||
|
||||
def require_keys(self, config, required):
|
||||
missing = [key for key in required if key not in config]
|
||||
if missing:
|
||||
raise ConfigError(
|
||||
"LDAP enabled but missing required config values: {}".format(
|
||||
", ".join(missing)
|
||||
)
|
||||
)
|
||||
|
||||
def default_config(self, **kwargs):
|
||||
return """\
|
||||
# ldap_config:
|
||||
# enabled: true
|
||||
# server: "ldap://localhost"
|
||||
# port: 389
|
||||
# tls: false
|
||||
# search_base: "ou=Users,dc=example,dc=com"
|
||||
# search_property: "cn"
|
||||
# email_property: "email"
|
||||
# full_name_property: "givenName"
|
||||
# uri: "ldap://ldap.example.com:389"
|
||||
# start_tls: true
|
||||
# base: "ou=users,dc=example,dc=com"
|
||||
# attributes:
|
||||
# uid: "cn"
|
||||
# mail: "email"
|
||||
# name: "givenName"
|
||||
# #bind_dn:
|
||||
# #bind_password:
|
||||
# #filter: "(objectClass=posixAccount)"
|
||||
"""
|
||||
|
||||
@@ -126,54 +126,58 @@ class LoggingConfig(Config):
|
||||
)
|
||||
|
||||
def setup_logging(self):
|
||||
log_format = (
|
||||
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
|
||||
" - %(message)s"
|
||||
)
|
||||
if self.log_config is None:
|
||||
setup_logging(self.log_config, self.log_file, self.verbosity)
|
||||
|
||||
level = logging.INFO
|
||||
level_for_storage = logging.INFO
|
||||
if self.verbosity:
|
||||
level = logging.DEBUG
|
||||
if self.verbosity > 1:
|
||||
level_for_storage = logging.DEBUG
|
||||
|
||||
# FIXME: we need a logging.WARN for a -q quiet option
|
||||
logger = logging.getLogger('')
|
||||
logger.setLevel(level)
|
||||
def setup_logging(log_config=None, log_file=None, verbosity=None):
|
||||
log_format = (
|
||||
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
|
||||
" - %(message)s"
|
||||
)
|
||||
if log_config is None:
|
||||
|
||||
logging.getLogger('synapse.storage').setLevel(level_for_storage)
|
||||
level = logging.INFO
|
||||
level_for_storage = logging.INFO
|
||||
if verbosity:
|
||||
level = logging.DEBUG
|
||||
if verbosity > 1:
|
||||
level_for_storage = logging.DEBUG
|
||||
|
||||
formatter = logging.Formatter(log_format)
|
||||
if self.log_file:
|
||||
# TODO: Customisable file size / backup count
|
||||
handler = logging.handlers.RotatingFileHandler(
|
||||
self.log_file, maxBytes=(1000 * 1000 * 100), backupCount=3
|
||||
)
|
||||
# FIXME: we need a logging.WARN for a -q quiet option
|
||||
logger = logging.getLogger('')
|
||||
logger.setLevel(level)
|
||||
|
||||
def sighup(signum, stack):
|
||||
logger.info("Closing log file due to SIGHUP")
|
||||
handler.doRollover()
|
||||
logger.info("Opened new log file due to SIGHUP")
|
||||
logging.getLogger('synapse.storage').setLevel(level_for_storage)
|
||||
|
||||
# TODO(paul): obviously this is a terrible mechanism for
|
||||
# stealing SIGHUP, because it means no other part of synapse
|
||||
# can use it instead. If we want to catch SIGHUP anywhere
|
||||
# else as well, I'd suggest we find a nicer way to broadcast
|
||||
# it around.
|
||||
if getattr(signal, "SIGHUP"):
|
||||
signal.signal(signal.SIGHUP, sighup)
|
||||
else:
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(formatter)
|
||||
formatter = logging.Formatter(log_format)
|
||||
if log_file:
|
||||
# TODO: Customisable file size / backup count
|
||||
handler = logging.handlers.RotatingFileHandler(
|
||||
log_file, maxBytes=(1000 * 1000 * 100), backupCount=3
|
||||
)
|
||||
|
||||
handler.addFilter(LoggingContextFilter(request=""))
|
||||
def sighup(signum, stack):
|
||||
logger.info("Closing log file due to SIGHUP")
|
||||
handler.doRollover()
|
||||
logger.info("Opened new log file due to SIGHUP")
|
||||
|
||||
logger.addHandler(handler)
|
||||
# TODO(paul): obviously this is a terrible mechanism for
|
||||
# stealing SIGHUP, because it means no other part of synapse
|
||||
# can use it instead. If we want to catch SIGHUP anywhere
|
||||
# else as well, I'd suggest we find a nicer way to broadcast
|
||||
# it around.
|
||||
if getattr(signal, "SIGHUP"):
|
||||
signal.signal(signal.SIGHUP, sighup)
|
||||
else:
|
||||
with open(self.log_config, 'r') as f:
|
||||
logging.config.dictConfig(yaml.load(f))
|
||||
handler = logging.StreamHandler()
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
observer = PythonLoggingObserver()
|
||||
observer.start()
|
||||
handler.addFilter(LoggingContextFilter(request=""))
|
||||
|
||||
logger.addHandler(handler)
|
||||
else:
|
||||
with open(log_config, 'r') as f:
|
||||
logging.config.dictConfig(yaml.load(f))
|
||||
|
||||
observer = PythonLoggingObserver()
|
||||
observer.start()
|
||||
|
||||
@@ -23,10 +23,14 @@ class PasswordConfig(Config):
|
||||
def read_config(self, config):
|
||||
password_config = config.get("password_config", {})
|
||||
self.password_enabled = password_config.get("enabled", True)
|
||||
self.password_pepper = password_config.get("pepper", "")
|
||||
|
||||
def default_config(self, config_dir_path, server_name, **kwargs):
|
||||
return """
|
||||
# Enable password for login.
|
||||
password_config:
|
||||
enabled: true
|
||||
# Uncomment and change to a secret random string for extra security.
|
||||
# DO NOT CHANGE THIS AFTER INITIAL SETUP!
|
||||
#pepper: ""
|
||||
"""
|
||||
|
||||
@@ -27,7 +27,7 @@ class ServerConfig(Config):
|
||||
self.daemonize = config.get("daemonize")
|
||||
self.print_pidfile = config.get("print_pidfile")
|
||||
self.user_agent_suffix = config.get("user_agent_suffix")
|
||||
self.use_frozen_dicts = config.get("use_frozen_dicts", True)
|
||||
self.use_frozen_dicts = config.get("use_frozen_dicts", False)
|
||||
self.public_baseurl = config.get("public_baseurl")
|
||||
self.secondary_directory_servers = config.get("secondary_directory_servers", [])
|
||||
|
||||
@@ -38,19 +38,7 @@ class ServerConfig(Config):
|
||||
|
||||
self.listeners = config.get("listeners", [])
|
||||
|
||||
thresholds = config.get("gc_thresholds", None)
|
||||
if thresholds is not None:
|
||||
try:
|
||||
assert len(thresholds) == 3
|
||||
self.gc_thresholds = (
|
||||
int(thresholds[0]), int(thresholds[1]), int(thresholds[2]),
|
||||
)
|
||||
except:
|
||||
raise ConfigError(
|
||||
"Value of `gc_threshold` must be a list of three integers if set"
|
||||
)
|
||||
else:
|
||||
self.gc_thresholds = None
|
||||
self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None))
|
||||
|
||||
bind_port = config.get("bind_port")
|
||||
if bind_port:
|
||||
@@ -119,26 +107,6 @@ class ServerConfig(Config):
|
||||
]
|
||||
})
|
||||
|
||||
# Attempt to guess the content_addr for the v0 content repostitory
|
||||
content_addr = config.get("content_addr")
|
||||
if not content_addr:
|
||||
for listener in self.listeners:
|
||||
if listener["type"] == "http" and not listener.get("tls", False):
|
||||
unsecure_port = listener["port"]
|
||||
break
|
||||
else:
|
||||
raise RuntimeError("Could not determine 'content_addr'")
|
||||
|
||||
host = self.server_name
|
||||
if ':' not in host:
|
||||
host = "%s:%d" % (host, unsecure_port)
|
||||
else:
|
||||
host = host.split(':')[0]
|
||||
host = "%s:%d" % (host, unsecure_port)
|
||||
content_addr = "http://%s" % (host,)
|
||||
|
||||
self.content_addr = content_addr
|
||||
|
||||
def default_config(self, server_name, **kwargs):
|
||||
if ":" in server_name:
|
||||
bind_port = int(server_name.split(":")[1])
|
||||
@@ -181,7 +149,6 @@ class ServerConfig(Config):
|
||||
# room directory.
|
||||
# secondary_directory_servers:
|
||||
# - matrix.org
|
||||
# - vector.im
|
||||
|
||||
# List of ports that Synapse should listen on, their purpose and their
|
||||
# configuration.
|
||||
@@ -264,3 +231,20 @@ class ServerConfig(Config):
|
||||
type=int,
|
||||
help="Turn on the twisted telnet manhole"
|
||||
" service on the given port.")
|
||||
|
||||
|
||||
def read_gc_thresholds(thresholds):
|
||||
"""Reads the three integer thresholds for garbage collection. Ensures that
|
||||
the thresholds are integers if thresholds are supplied.
|
||||
"""
|
||||
if thresholds is None:
|
||||
return None
|
||||
try:
|
||||
assert len(thresholds) == 3
|
||||
return (
|
||||
int(thresholds[0]), int(thresholds[1]), int(thresholds[2]),
|
||||
)
|
||||
except:
|
||||
raise ConfigError(
|
||||
"Value of `gc_threshold` must be a list of three integers if set"
|
||||
)
|
||||
|
||||
31
synapse/config/workers.py
Normal file
31
synapse/config/workers.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 matrix.org
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import Config
|
||||
|
||||
|
||||
class WorkerConfig(Config):
|
||||
"""The workers are processes run separately to the main synapse process.
|
||||
They have their own pid_file and listener configuration. They use the
|
||||
replication_url to talk to the main synapse process."""
|
||||
|
||||
def read_config(self, config):
|
||||
self.worker_app = config.get("worker_app")
|
||||
self.worker_listeners = config.get("worker_listeners")
|
||||
self.worker_daemonize = config.get("worker_daemonize")
|
||||
self.worker_pid_file = config.get("worker_pid_file")
|
||||
self.worker_log_file = config.get("worker_log_file")
|
||||
self.worker_log_config = config.get("worker_log_config")
|
||||
self.worker_replication_url = config.get("worker_replication_url")
|
||||
@@ -77,10 +77,12 @@ class SynapseKeyClientProtocol(HTTPClient):
|
||||
def __init__(self):
|
||||
self.remote_key = defer.Deferred()
|
||||
self.host = None
|
||||
self._peer = None
|
||||
|
||||
def connectionMade(self):
|
||||
self.host = self.transport.getHost()
|
||||
logger.debug("Connected to %s", self.host)
|
||||
self._peer = self.transport.getPeer()
|
||||
logger.debug("Connected to %s", self._peer)
|
||||
|
||||
self.sendCommand(b"GET", self.path)
|
||||
if self.host:
|
||||
self.sendHeader(b"Host", self.host)
|
||||
@@ -124,7 +126,10 @@ class SynapseKeyClientProtocol(HTTPClient):
|
||||
self.timer.cancel()
|
||||
|
||||
def on_timeout(self):
|
||||
logger.debug("Timeout waiting for response from %s", self.host)
|
||||
logger.debug(
|
||||
"Timeout waiting for response from %s: %s",
|
||||
self.host, self._peer,
|
||||
)
|
||||
self.errback(IOError("Timeout waiting for response"))
|
||||
self.transport.abortConnection()
|
||||
|
||||
@@ -133,4 +138,5 @@ class SynapseKeyClientFactory(Factory):
|
||||
def protocol(self):
|
||||
protocol = SynapseKeyClientProtocol()
|
||||
protocol.path = self.path
|
||||
protocol.host = self.host
|
||||
return protocol
|
||||
|
||||
@@ -44,7 +44,21 @@ import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids"))
|
||||
VerifyKeyRequest = namedtuple("VerifyRequest", (
|
||||
"server_name", "key_ids", "json_object", "deferred"
|
||||
))
|
||||
"""
|
||||
A request for a verify key to verify a JSON object.
|
||||
|
||||
Attributes:
|
||||
server_name(str): The name of the server to verify against.
|
||||
key_ids(set(str)): The set of key_ids to that could be used to verify the
|
||||
JSON object
|
||||
json_object(dict): The JSON object to verify.
|
||||
deferred(twisted.internet.defer.Deferred):
|
||||
A deferred (server_name, key_id, verify_key) tuple that resolves when
|
||||
a verify key has been fetched
|
||||
"""
|
||||
|
||||
|
||||
class Keyring(object):
|
||||
@@ -74,39 +88,32 @@ class Keyring(object):
|
||||
list of deferreds indicating success or failure to verify each
|
||||
json object's signature for the given server_name.
|
||||
"""
|
||||
group_id_to_json = {}
|
||||
group_id_to_group = {}
|
||||
group_ids = []
|
||||
|
||||
next_group_id = 0
|
||||
deferreds = {}
|
||||
verify_requests = []
|
||||
|
||||
for server_name, json_object in server_and_json:
|
||||
logger.debug("Verifying for %s", server_name)
|
||||
group_id = next_group_id
|
||||
next_group_id += 1
|
||||
group_ids.append(group_id)
|
||||
|
||||
key_ids = signature_ids(json_object, server_name)
|
||||
if not key_ids:
|
||||
deferreds[group_id] = defer.fail(SynapseError(
|
||||
deferred = defer.fail(SynapseError(
|
||||
400,
|
||||
"Not signed with a supported algorithm",
|
||||
Codes.UNAUTHORIZED,
|
||||
))
|
||||
else:
|
||||
deferreds[group_id] = defer.Deferred()
|
||||
deferred = defer.Deferred()
|
||||
|
||||
group = KeyGroup(server_name, group_id, key_ids)
|
||||
verify_request = VerifyKeyRequest(
|
||||
server_name, key_ids, json_object, deferred
|
||||
)
|
||||
|
||||
group_id_to_group[group_id] = group
|
||||
group_id_to_json[group_id] = json_object
|
||||
verify_requests.append(verify_request)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def handle_key_deferred(group, deferred):
|
||||
server_name = group.server_name
|
||||
def handle_key_deferred(verify_request):
|
||||
server_name = verify_request.server_name
|
||||
try:
|
||||
_, _, key_id, verify_key = yield deferred
|
||||
_, key_id, verify_key = yield verify_request.deferred
|
||||
except IOError as e:
|
||||
logger.warn(
|
||||
"Got IOError when downloading keys for %s: %s %s",
|
||||
@@ -128,7 +135,7 @@ class Keyring(object):
|
||||
Codes.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
json_object = group_id_to_json[group.group_id]
|
||||
json_object = verify_request.json_object
|
||||
|
||||
try:
|
||||
verify_signed_json(json_object, server_name, verify_key)
|
||||
@@ -157,36 +164,34 @@ class Keyring(object):
|
||||
|
||||
# Actually start fetching keys.
|
||||
wait_on_deferred.addBoth(
|
||||
lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
|
||||
lambda _: self.get_server_verify_keys(verify_requests)
|
||||
)
|
||||
|
||||
# When we've finished fetching all the keys for a given server_name,
|
||||
# resolve the deferred passed to `wait_for_previous_lookups` so that
|
||||
# any lookups waiting will proceed.
|
||||
server_to_gids = {}
|
||||
server_to_request_ids = {}
|
||||
|
||||
def remove_deferreds(res, server_name, group_id):
|
||||
server_to_gids[server_name].discard(group_id)
|
||||
if not server_to_gids[server_name]:
|
||||
def remove_deferreds(res, server_name, verify_request):
|
||||
request_id = id(verify_request)
|
||||
server_to_request_ids[server_name].discard(request_id)
|
||||
if not server_to_request_ids[server_name]:
|
||||
d = server_to_deferred.pop(server_name, None)
|
||||
if d:
|
||||
d.callback(None)
|
||||
return res
|
||||
|
||||
for g_id, deferred in deferreds.items():
|
||||
server_name = group_id_to_group[g_id].server_name
|
||||
server_to_gids.setdefault(server_name, set()).add(g_id)
|
||||
deferred.addBoth(remove_deferreds, server_name, g_id)
|
||||
for verify_request in verify_requests:
|
||||
server_name = verify_request.server_name
|
||||
request_id = id(verify_request)
|
||||
server_to_request_ids.setdefault(server_name, set()).add(request_id)
|
||||
deferred.addBoth(remove_deferreds, server_name, verify_request)
|
||||
|
||||
# Pass those keys to handle_key_deferred so that the json object
|
||||
# signatures can be verified
|
||||
return [
|
||||
preserve_context_over_fn(
|
||||
handle_key_deferred,
|
||||
group_id_to_group[g_id],
|
||||
deferreds[g_id],
|
||||
)
|
||||
for g_id in group_ids
|
||||
preserve_context_over_fn(handle_key_deferred, verify_request)
|
||||
for verify_request in verify_requests
|
||||
]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@@ -220,7 +225,7 @@ class Keyring(object):
|
||||
|
||||
d.addBoth(rm, server_name)
|
||||
|
||||
def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
|
||||
def get_server_verify_keys(self, verify_requests):
|
||||
"""Takes a dict of KeyGroups and tries to find at least one key for
|
||||
each group.
|
||||
"""
|
||||
@@ -237,62 +242,64 @@ class Keyring(object):
|
||||
merged_results = {}
|
||||
|
||||
missing_keys = {}
|
||||
for group in group_id_to_group.values():
|
||||
missing_keys.setdefault(group.server_name, set()).update(
|
||||
group.key_ids
|
||||
for verify_request in verify_requests:
|
||||
missing_keys.setdefault(verify_request.server_name, set()).update(
|
||||
verify_request.key_ids
|
||||
)
|
||||
|
||||
for fn in key_fetch_fns:
|
||||
results = yield fn(missing_keys.items())
|
||||
merged_results.update(results)
|
||||
|
||||
# We now need to figure out which groups we have keys for
|
||||
# and which we don't
|
||||
missing_groups = {}
|
||||
for group in group_id_to_group.values():
|
||||
for key_id in group.key_ids:
|
||||
if key_id in merged_results[group.server_name]:
|
||||
# We now need to figure out which verify requests we have keys
|
||||
# for and which we don't
|
||||
missing_keys = {}
|
||||
requests_missing_keys = []
|
||||
for verify_request in verify_requests:
|
||||
server_name = verify_request.server_name
|
||||
result_keys = merged_results[server_name]
|
||||
|
||||
if verify_request.deferred.called:
|
||||
# We've already called this deferred, which probably
|
||||
# means that we've already found a key for it.
|
||||
continue
|
||||
|
||||
for key_id in verify_request.key_ids:
|
||||
if key_id in result_keys:
|
||||
with PreserveLoggingContext():
|
||||
group_id_to_deferred[group.group_id].callback((
|
||||
group.group_id,
|
||||
group.server_name,
|
||||
verify_request.deferred.callback((
|
||||
server_name,
|
||||
key_id,
|
||||
merged_results[group.server_name][key_id],
|
||||
result_keys[key_id],
|
||||
))
|
||||
break
|
||||
else:
|
||||
missing_groups.setdefault(
|
||||
group.server_name, []
|
||||
).append(group)
|
||||
# The else block is only reached if the loop above
|
||||
# doesn't break.
|
||||
missing_keys.setdefault(server_name, set()).update(
|
||||
verify_request.key_ids
|
||||
)
|
||||
requests_missing_keys.append(verify_request)
|
||||
|
||||
if not missing_groups:
|
||||
if not missing_keys:
|
||||
break
|
||||
|
||||
missing_keys = {
|
||||
server_name: set(
|
||||
key_id for group in groups for key_id in group.key_ids
|
||||
)
|
||||
for server_name, groups in missing_groups.items()
|
||||
}
|
||||
|
||||
for group in missing_groups.values():
|
||||
group_id_to_deferred[group.group_id].errback(SynapseError(
|
||||
for verify_request in requests_missing_keys.values():
|
||||
verify_request.deferred.errback(SynapseError(
|
||||
401,
|
||||
"No key for %s with id %s" % (
|
||||
group.server_name, group.key_ids,
|
||||
verify_request.server_name, verify_request.key_ids,
|
||||
),
|
||||
Codes.UNAUTHORIZED,
|
||||
))
|
||||
|
||||
def on_err(err):
|
||||
for deferred in group_id_to_deferred.values():
|
||||
if not deferred.called:
|
||||
deferred.errback(err)
|
||||
for verify_request in verify_requests:
|
||||
if not verify_request.deferred.called:
|
||||
verify_request.deferred.errback(err)
|
||||
|
||||
do_iterations().addErrback(on_err)
|
||||
|
||||
return group_id_to_deferred
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_keys_from_store(self, server_name_and_key_ids):
|
||||
res = yield defer.gatherResults(
|
||||
@@ -447,7 +454,7 @@ class Keyring(object):
|
||||
)
|
||||
|
||||
processed_response = yield self.process_v2_response(
|
||||
perspective_name, response
|
||||
perspective_name, response, only_from_server=False
|
||||
)
|
||||
|
||||
for server_name, response_keys in processed_response.items():
|
||||
@@ -527,7 +534,7 @@ class Keyring(object):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def process_v2_response(self, from_server, response_json,
|
||||
requested_ids=[]):
|
||||
requested_ids=[], only_from_server=True):
|
||||
time_now_ms = self.clock.time_msec()
|
||||
response_keys = {}
|
||||
verify_keys = {}
|
||||
@@ -551,6 +558,13 @@ class Keyring(object):
|
||||
|
||||
results = {}
|
||||
server_name = response_json["server_name"]
|
||||
if only_from_server:
|
||||
if server_name != from_server:
|
||||
raise ValueError(
|
||||
"Expected a response for server %r not %r" % (
|
||||
from_server, server_name
|
||||
)
|
||||
)
|
||||
for key_id in response_json["signatures"].get(server_name, {}):
|
||||
if key_id not in response_json["verify_keys"]:
|
||||
raise ValueError(
|
||||
|
||||
@@ -31,6 +31,9 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FederationBase(object):
|
||||
def __init__(self, hs):
|
||||
pass
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
|
||||
include_none=False):
|
||||
|
||||
@@ -52,6 +52,8 @@ sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"])
|
||||
|
||||
|
||||
class FederationClient(FederationBase):
|
||||
def __init__(self, hs):
|
||||
super(FederationClient, self).__init__(hs)
|
||||
|
||||
def start_get_pdu_cache(self):
|
||||
self._get_pdu_cache = ExpiringCache(
|
||||
@@ -234,9 +236,9 @@ class FederationClient(FederationBase):
|
||||
# TODO: Rate limit the number of times we try and get the same event.
|
||||
|
||||
if self._get_pdu_cache:
|
||||
e = self._get_pdu_cache.get(event_id)
|
||||
if e:
|
||||
defer.returnValue(e)
|
||||
ev = self._get_pdu_cache.get(event_id)
|
||||
if ev:
|
||||
defer.returnValue(ev)
|
||||
|
||||
pdu = None
|
||||
for destination in destinations:
|
||||
@@ -267,7 +269,7 @@ class FederationClient(FederationBase):
|
||||
|
||||
break
|
||||
|
||||
except SynapseError:
|
||||
except SynapseError as e:
|
||||
logger.info(
|
||||
"Failed to get PDU %s from %s because %s",
|
||||
event_id, destination, e,
|
||||
@@ -312,6 +314,42 @@ class FederationClient(FederationBase):
|
||||
Deferred: Results in a list of PDUs.
|
||||
"""
|
||||
|
||||
try:
|
||||
# First we try and ask for just the IDs, as thats far quicker if
|
||||
# we have most of the state and auth_chain already.
|
||||
# However, this may 404 if the other side has an old synapse.
|
||||
result = yield self.transport_layer.get_room_state_ids(
|
||||
destination, room_id, event_id=event_id,
|
||||
)
|
||||
|
||||
state_event_ids = result["pdu_ids"]
|
||||
auth_event_ids = result.get("auth_chain_ids", [])
|
||||
|
||||
fetched_events, failed_to_fetch = yield self.get_events(
|
||||
[destination], room_id, set(state_event_ids + auth_event_ids)
|
||||
)
|
||||
|
||||
if failed_to_fetch:
|
||||
logger.warn("Failed to get %r", failed_to_fetch)
|
||||
|
||||
event_map = {
|
||||
ev.event_id: ev for ev in fetched_events
|
||||
}
|
||||
|
||||
pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
|
||||
auth_chain = [
|
||||
event_map[e_id] for e_id in auth_event_ids if e_id in event_map
|
||||
]
|
||||
|
||||
auth_chain.sort(key=lambda e: e.depth)
|
||||
|
||||
defer.returnValue((pdus, auth_chain))
|
||||
except HttpResponseException as e:
|
||||
if e.code == 400 or e.code == 404:
|
||||
logger.info("Failed to use get_room_state_ids API, falling back")
|
||||
else:
|
||||
raise e
|
||||
|
||||
result = yield self.transport_layer.get_room_state(
|
||||
destination, room_id, event_id=event_id,
|
||||
)
|
||||
@@ -325,18 +363,93 @@ class FederationClient(FederationBase):
|
||||
for p in result.get("auth_chain", [])
|
||||
]
|
||||
|
||||
seen_events = yield self.store.get_events([
|
||||
ev.event_id for ev in itertools.chain(pdus, auth_chain)
|
||||
])
|
||||
|
||||
signed_pdus = yield self._check_sigs_and_hash_and_fetch(
|
||||
destination, pdus, outlier=True
|
||||
destination,
|
||||
[p for p in pdus if p.event_id not in seen_events],
|
||||
outlier=True
|
||||
)
|
||||
signed_pdus.extend(
|
||||
seen_events[p.event_id] for p in pdus if p.event_id in seen_events
|
||||
)
|
||||
|
||||
signed_auth = yield self._check_sigs_and_hash_and_fetch(
|
||||
destination, auth_chain, outlier=True
|
||||
destination,
|
||||
[p for p in auth_chain if p.event_id not in seen_events],
|
||||
outlier=True
|
||||
)
|
||||
signed_auth.extend(
|
||||
seen_events[p.event_id] for p in auth_chain if p.event_id in seen_events
|
||||
)
|
||||
|
||||
signed_auth.sort(key=lambda e: e.depth)
|
||||
|
||||
defer.returnValue((signed_pdus, signed_auth))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_events(self, destinations, room_id, event_ids, return_local=True):
|
||||
"""Fetch events from some remote destinations, checking if we already
|
||||
have them.
|
||||
|
||||
Args:
|
||||
destinations (list)
|
||||
room_id (str)
|
||||
event_ids (list)
|
||||
return_local (bool): Whether to include events we already have in
|
||||
the DB in the returned list of events
|
||||
|
||||
Returns:
|
||||
Deferred: A deferred resolving to a 2-tuple where the first is a list of
|
||||
events and the second is a list of event ids that we failed to fetch.
|
||||
"""
|
||||
if return_local:
|
||||
seen_events = yield self.store.get_events(event_ids)
|
||||
signed_events = seen_events.values()
|
||||
else:
|
||||
seen_events = yield self.store.have_events(event_ids)
|
||||
signed_events = []
|
||||
|
||||
failed_to_fetch = set()
|
||||
|
||||
missing_events = set(event_ids)
|
||||
for k in seen_events:
|
||||
missing_events.discard(k)
|
||||
|
||||
if not missing_events:
|
||||
defer.returnValue((signed_events, failed_to_fetch))
|
||||
|
||||
def random_server_list():
|
||||
srvs = list(destinations)
|
||||
random.shuffle(srvs)
|
||||
return srvs
|
||||
|
||||
batch_size = 20
|
||||
missing_events = list(missing_events)
|
||||
for i in xrange(0, len(missing_events), batch_size):
|
||||
batch = set(missing_events[i:i + batch_size])
|
||||
|
||||
deferreds = [
|
||||
self.get_pdu(
|
||||
destinations=random_server_list(),
|
||||
event_id=e_id,
|
||||
)
|
||||
for e_id in batch
|
||||
]
|
||||
|
||||
res = yield defer.DeferredList(deferreds, consumeErrors=True)
|
||||
for success, result in res:
|
||||
if success:
|
||||
signed_events.append(result)
|
||||
batch.discard(result.event_id)
|
||||
|
||||
# We removed all events we successfully fetched from `batch`
|
||||
failed_to_fetch.update(batch)
|
||||
|
||||
defer.returnValue((signed_events, failed_to_fetch))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def get_event_auth(self, destination, room_id, event_id):
|
||||
@@ -412,14 +525,19 @@ class FederationClient(FederationBase):
|
||||
(destination, self.event_from_pdu_json(pdu_dict))
|
||||
)
|
||||
break
|
||||
except CodeMessageException:
|
||||
raise
|
||||
except CodeMessageException as e:
|
||||
if not 500 <= e.code < 600:
|
||||
raise
|
||||
else:
|
||||
logger.warn(
|
||||
"Failed to make_%s via %s: %s",
|
||||
membership, destination, e.message
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warn(
|
||||
"Failed to make_%s via %s: %s",
|
||||
membership, destination, e.message
|
||||
)
|
||||
raise
|
||||
|
||||
raise RuntimeError("Failed to send to any server.")
|
||||
|
||||
@@ -491,8 +609,14 @@ class FederationClient(FederationBase):
|
||||
"auth_chain": signed_auth,
|
||||
"origin": destination,
|
||||
})
|
||||
except CodeMessageException:
|
||||
raise
|
||||
except CodeMessageException as e:
|
||||
if not 500 <= e.code < 600:
|
||||
raise
|
||||
else:
|
||||
logger.exception(
|
||||
"Failed to send_join via %s: %s",
|
||||
destination, e.message
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to send_join via %s: %s",
|
||||
|
||||
@@ -19,11 +19,13 @@ from twisted.internet import defer
|
||||
from .federation_base import FederationBase
|
||||
from .units import Transaction, Edu
|
||||
|
||||
from synapse.util.async import Linearizer
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
from synapse.events import FrozenEvent
|
||||
import synapse.metrics
|
||||
|
||||
from synapse.api.errors import FederationError, SynapseError
|
||||
from synapse.api.errors import AuthError, FederationError, SynapseError
|
||||
|
||||
from synapse.crypto.event_signing import compute_event_signature
|
||||
|
||||
@@ -44,6 +46,18 @@ received_queries_counter = metrics.register_counter("received_queries", labels=[
|
||||
|
||||
|
||||
class FederationServer(FederationBase):
|
||||
def __init__(self, hs):
|
||||
super(FederationServer, self).__init__(hs)
|
||||
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
self._room_pdu_linearizer = Linearizer()
|
||||
self._server_linearizer = Linearizer()
|
||||
|
||||
# We cache responses to state queries, as they take a while and often
|
||||
# come in waves.
|
||||
self._state_resp_cache = ResponseCache(hs, timeout_ms=30000)
|
||||
|
||||
def set_handler(self, handler):
|
||||
"""Sets the handler that the replication layer will use to communicate
|
||||
receipt of new PDUs from other home servers. The required methods are
|
||||
@@ -83,11 +97,14 @@ class FederationServer(FederationBase):
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def on_backfill_request(self, origin, room_id, versions, limit):
|
||||
pdus = yield self.handler.on_backfill_request(
|
||||
origin, room_id, versions, limit
|
||||
)
|
||||
with (yield self._server_linearizer.queue((origin, room_id))):
|
||||
pdus = yield self.handler.on_backfill_request(
|
||||
origin, room_id, versions, limit
|
||||
)
|
||||
|
||||
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
|
||||
res = self._transaction_from_pdus(pdus).get_dict()
|
||||
|
||||
defer.returnValue((200, res))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
@@ -178,15 +195,59 @@ class FederationServer(FederationBase):
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def on_context_state_request(self, origin, room_id, event_id):
|
||||
if event_id:
|
||||
pdus = yield self.handler.get_state_for_pdu(
|
||||
origin, room_id, event_id,
|
||||
)
|
||||
auth_chain = yield self.store.get_auth_chain(
|
||||
[pdu.event_id for pdu in pdus]
|
||||
)
|
||||
if not event_id:
|
||||
raise NotImplementedError("Specify an event")
|
||||
|
||||
for event in auth_chain:
|
||||
in_room = yield self.auth.check_host_in_room(room_id, origin)
|
||||
if not in_room:
|
||||
raise AuthError(403, "Host not in room.")
|
||||
|
||||
result = self._state_resp_cache.get((room_id, event_id))
|
||||
if not result:
|
||||
with (yield self._server_linearizer.queue((origin, room_id))):
|
||||
resp = yield self._state_resp_cache.set(
|
||||
(room_id, event_id),
|
||||
self._on_context_state_request_compute(room_id, event_id)
|
||||
)
|
||||
else:
|
||||
resp = yield result
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_state_ids_request(self, origin, room_id, event_id):
|
||||
if not event_id:
|
||||
raise NotImplementedError("Specify an event")
|
||||
|
||||
in_room = yield self.auth.check_host_in_room(room_id, origin)
|
||||
if not in_room:
|
||||
raise AuthError(403, "Host not in room.")
|
||||
|
||||
pdus = yield self.handler.get_state_for_pdu(
|
||||
room_id, event_id,
|
||||
)
|
||||
auth_chain = yield self.store.get_auth_chain(
|
||||
[pdu.event_id for pdu in pdus]
|
||||
)
|
||||
|
||||
defer.returnValue((200, {
|
||||
"pdu_ids": [pdu.event_id for pdu in pdus],
|
||||
"auth_chain_ids": [pdu.event_id for pdu in auth_chain],
|
||||
}))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _on_context_state_request_compute(self, room_id, event_id):
|
||||
pdus = yield self.handler.get_state_for_pdu(
|
||||
room_id, event_id,
|
||||
)
|
||||
auth_chain = yield self.store.get_auth_chain(
|
||||
[pdu.event_id for pdu in pdus]
|
||||
)
|
||||
|
||||
for event in auth_chain:
|
||||
# We sign these again because there was a bug where we
|
||||
# incorrectly signed things the first time round
|
||||
if self.hs.is_mine_id(event.event_id):
|
||||
event.signatures.update(
|
||||
compute_event_signature(
|
||||
event,
|
||||
@@ -194,13 +255,11 @@ class FederationServer(FederationBase):
|
||||
self.hs.config.signing_key[0]
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Specify an event")
|
||||
|
||||
defer.returnValue((200, {
|
||||
defer.returnValue({
|
||||
"pdus": [pdu.get_pdu_json() for pdu in pdus],
|
||||
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
|
||||
}))
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
@@ -274,14 +333,16 @@ class FederationServer(FederationBase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_event_auth(self, origin, room_id, event_id):
|
||||
time_now = self._clock.time_msec()
|
||||
auth_pdus = yield self.handler.on_event_auth(event_id)
|
||||
defer.returnValue((200, {
|
||||
"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
|
||||
}))
|
||||
with (yield self._server_linearizer.queue((origin, room_id))):
|
||||
time_now = self._clock.time_msec()
|
||||
auth_pdus = yield self.handler.on_event_auth(event_id)
|
||||
res = {
|
||||
"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
|
||||
}
|
||||
defer.returnValue((200, res))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_query_auth_request(self, origin, content, event_id):
|
||||
def on_query_auth_request(self, origin, content, room_id, event_id):
|
||||
"""
|
||||
Content is a dict with keys::
|
||||
auth_chain (list): A list of events that give the auth chain.
|
||||
@@ -300,58 +361,41 @@ class FederationServer(FederationBase):
|
||||
Returns:
|
||||
Deferred: Results in `dict` with the same format as `content`
|
||||
"""
|
||||
auth_chain = [
|
||||
self.event_from_pdu_json(e)
|
||||
for e in content["auth_chain"]
|
||||
]
|
||||
with (yield self._server_linearizer.queue((origin, room_id))):
|
||||
auth_chain = [
|
||||
self.event_from_pdu_json(e)
|
||||
for e in content["auth_chain"]
|
||||
]
|
||||
|
||||
signed_auth = yield self._check_sigs_and_hash_and_fetch(
|
||||
origin, auth_chain, outlier=True
|
||||
)
|
||||
signed_auth = yield self._check_sigs_and_hash_and_fetch(
|
||||
origin, auth_chain, outlier=True
|
||||
)
|
||||
|
||||
ret = yield self.handler.on_query_auth(
|
||||
origin,
|
||||
event_id,
|
||||
signed_auth,
|
||||
content.get("rejects", []),
|
||||
content.get("missing", []),
|
||||
)
|
||||
ret = yield self.handler.on_query_auth(
|
||||
origin,
|
||||
event_id,
|
||||
signed_auth,
|
||||
content.get("rejects", []),
|
||||
content.get("missing", []),
|
||||
)
|
||||
|
||||
time_now = self._clock.time_msec()
|
||||
send_content = {
|
||||
"auth_chain": [
|
||||
e.get_pdu_json(time_now)
|
||||
for e in ret["auth_chain"]
|
||||
],
|
||||
"rejects": ret.get("rejects", []),
|
||||
"missing": ret.get("missing", []),
|
||||
}
|
||||
time_now = self._clock.time_msec()
|
||||
send_content = {
|
||||
"auth_chain": [
|
||||
e.get_pdu_json(time_now)
|
||||
for e in ret["auth_chain"]
|
||||
],
|
||||
"rejects": ret.get("rejects", []),
|
||||
"missing": ret.get("missing", []),
|
||||
}
|
||||
|
||||
defer.returnValue(
|
||||
(200, send_content)
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def on_query_client_keys(self, origin, content):
|
||||
query = []
|
||||
for user_id, device_ids in content.get("device_keys", {}).items():
|
||||
if not device_ids:
|
||||
query.append((user_id, None))
|
||||
else:
|
||||
for device_id in device_ids:
|
||||
query.append((user_id, device_id))
|
||||
|
||||
results = yield self.store.get_e2e_device_keys(query)
|
||||
|
||||
json_result = {}
|
||||
for user_id, device_keys in results.items():
|
||||
for device_id, json_bytes in device_keys.items():
|
||||
json_result.setdefault(user_id, {})[device_id] = json.loads(
|
||||
json_bytes
|
||||
)
|
||||
|
||||
defer.returnValue({"device_keys": json_result})
|
||||
return self.on_query_request("client_keys", content)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
@@ -377,11 +421,24 @@ class FederationServer(FederationBase):
|
||||
@log_function
|
||||
def on_get_missing_events(self, origin, room_id, earliest_events,
|
||||
latest_events, limit, min_depth):
|
||||
missing_events = yield self.handler.on_get_missing_events(
|
||||
origin, room_id, earliest_events, latest_events, limit, min_depth
|
||||
)
|
||||
with (yield self._server_linearizer.queue((origin, room_id))):
|
||||
logger.info(
|
||||
"on_get_missing_events: earliest_events: %r, latest_events: %r,"
|
||||
" limit: %d, min_depth: %d",
|
||||
earliest_events, latest_events, limit, min_depth
|
||||
)
|
||||
missing_events = yield self.handler.on_get_missing_events(
|
||||
origin, room_id, earliest_events, latest_events, limit, min_depth
|
||||
)
|
||||
|
||||
time_now = self._clock.time_msec()
|
||||
if len(missing_events) < 5:
|
||||
logger.info(
|
||||
"Returning %d events: %r", len(missing_events), missing_events
|
||||
)
|
||||
else:
|
||||
logger.info("Returning %d events", len(missing_events))
|
||||
|
||||
time_now = self._clock.time_msec()
|
||||
|
||||
defer.returnValue({
|
||||
"events": [ev.get_pdu_json(time_now) for ev in missing_events],
|
||||
@@ -481,42 +538,59 @@ class FederationServer(FederationBase):
|
||||
pdu.internal_metadata.outlier = True
|
||||
elif min_depth and pdu.depth > min_depth:
|
||||
if get_missing and prevs - seen:
|
||||
latest = yield self.store.get_latest_event_ids_in_room(
|
||||
pdu.room_id
|
||||
)
|
||||
# If we're missing stuff, ensure we only fetch stuff one
|
||||
# at a time.
|
||||
with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
|
||||
# We recalculate seen, since it may have changed.
|
||||
have_seen = yield self.store.have_events(prevs)
|
||||
seen = set(have_seen.keys())
|
||||
|
||||
# We add the prev events that we have seen to the latest
|
||||
# list to ensure the remote server doesn't give them to us
|
||||
latest = set(latest)
|
||||
latest |= seen
|
||||
if prevs - seen:
|
||||
latest = yield self.store.get_latest_event_ids_in_room(
|
||||
pdu.room_id
|
||||
)
|
||||
|
||||
missing_events = yield self.get_missing_events(
|
||||
origin,
|
||||
pdu.room_id,
|
||||
earliest_events_ids=list(latest),
|
||||
latest_events=[pdu],
|
||||
limit=10,
|
||||
min_depth=min_depth,
|
||||
)
|
||||
# We add the prev events that we have seen to the latest
|
||||
# list to ensure the remote server doesn't give them to us
|
||||
latest = set(latest)
|
||||
latest |= seen
|
||||
|
||||
# We want to sort these by depth so we process them and
|
||||
# tell clients about them in order.
|
||||
missing_events.sort(key=lambda x: x.depth)
|
||||
logger.info(
|
||||
"Missing %d events for room %r: %r...",
|
||||
len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
|
||||
)
|
||||
|
||||
for e in missing_events:
|
||||
yield self._handle_new_pdu(
|
||||
origin,
|
||||
e,
|
||||
get_missing=False
|
||||
)
|
||||
missing_events = yield self.get_missing_events(
|
||||
origin,
|
||||
pdu.room_id,
|
||||
earliest_events_ids=list(latest),
|
||||
latest_events=[pdu],
|
||||
limit=10,
|
||||
min_depth=min_depth,
|
||||
)
|
||||
|
||||
have_seen = yield self.store.have_events(
|
||||
[ev for ev, _ in pdu.prev_events]
|
||||
)
|
||||
# We want to sort these by depth so we process them and
|
||||
# tell clients about them in order.
|
||||
missing_events.sort(key=lambda x: x.depth)
|
||||
|
||||
for e in missing_events:
|
||||
yield self._handle_new_pdu(
|
||||
origin,
|
||||
e,
|
||||
get_missing=False
|
||||
)
|
||||
|
||||
have_seen = yield self.store.have_events(
|
||||
[ev for ev, _ in pdu.prev_events]
|
||||
)
|
||||
|
||||
prevs = {e_id for e_id, _ in pdu.prev_events}
|
||||
seen = set(have_seen.keys())
|
||||
if prevs - seen:
|
||||
logger.info(
|
||||
"Still missing %d events for room %r: %r...",
|
||||
len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
|
||||
)
|
||||
fetch_state = True
|
||||
|
||||
if fetch_state:
|
||||
@@ -531,7 +605,7 @@ class FederationServer(FederationBase):
|
||||
origin, pdu.room_id, pdu.event_id,
|
||||
)
|
||||
except:
|
||||
logger.warn("Failed to get state for event: %s", pdu.event_id)
|
||||
logger.exception("Failed to get state for event: %s", pdu.event_id)
|
||||
|
||||
yield self.handler.on_receive_pdu(
|
||||
origin,
|
||||
|
||||
@@ -72,5 +72,7 @@ class ReplicationLayer(FederationClient, FederationServer):
|
||||
|
||||
self.hs = hs
|
||||
|
||||
super(ReplicationLayer, self).__init__(hs)
|
||||
|
||||
def __str__(self):
|
||||
return "<ReplicationLayer(%s)>" % self.server_name
|
||||
|
||||
@@ -54,6 +54,28 @@ class TransportLayerClient(object):
|
||||
destination, path=path, args={"event_id": event_id},
|
||||
)
|
||||
|
||||
@log_function
|
||||
def get_room_state_ids(self, destination, room_id, event_id):
|
||||
""" Requests all state for a given room from the given server at the
|
||||
given event. Returns the state's event_id's
|
||||
|
||||
Args:
|
||||
destination (str): The host name of the remote home server we want
|
||||
to get the state from.
|
||||
context (str): The name of the context we want the state of
|
||||
event_id (str): The event we want the context at.
|
||||
|
||||
Returns:
|
||||
Deferred: Results in a dict received from the remote homeserver.
|
||||
"""
|
||||
logger.debug("get_room_state_ids dest=%s, room=%s",
|
||||
destination, room_id)
|
||||
|
||||
path = PREFIX + "/state_ids/%s/" % room_id
|
||||
return self.client.get_json(
|
||||
destination, path=path, args={"event_id": event_id},
|
||||
)
|
||||
|
||||
@log_function
|
||||
def get_event(self, destination, event_id, timeout=None):
|
||||
""" Requests the pdu with give id and origin from the given server.
|
||||
|
||||
@@ -18,13 +18,14 @@ from twisted.internet import defer
|
||||
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http.server import JsonResource
|
||||
from synapse.http.servlet import parse_json_object_from_request, parse_string
|
||||
from synapse.http.servlet import parse_json_object_from_request
|
||||
from synapse.util.ratelimitutils import FederationRateLimiter
|
||||
from synapse.util.versionstring import get_version_string
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import simplejson as json
|
||||
import re
|
||||
import synapse
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -37,7 +38,7 @@ class TransportLayerServer(JsonResource):
|
||||
self.hs = hs
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
super(TransportLayerServer, self).__init__(hs)
|
||||
super(TransportLayerServer, self).__init__(hs, canonical_json=False)
|
||||
|
||||
self.authenticator = Authenticator(hs)
|
||||
self.ratelimiter = FederationRateLimiter(
|
||||
@@ -60,6 +61,16 @@ class TransportLayerServer(JsonResource):
|
||||
)
|
||||
|
||||
|
||||
class AuthenticationError(SynapseError):
|
||||
"""There was a problem authenticating the request"""
|
||||
pass
|
||||
|
||||
|
||||
class NoAuthenticationError(AuthenticationError):
|
||||
"""The request had no authentication information"""
|
||||
pass
|
||||
|
||||
|
||||
class Authenticator(object):
|
||||
def __init__(self, hs):
|
||||
self.keyring = hs.get_keyring()
|
||||
@@ -67,7 +78,7 @@ class Authenticator(object):
|
||||
|
||||
# A method just so we can pass 'self' as the authenticator to the Servlets
|
||||
@defer.inlineCallbacks
|
||||
def authenticate_request(self, request):
|
||||
def authenticate_request(self, request, content):
|
||||
json_request = {
|
||||
"method": request.method,
|
||||
"uri": request.uri,
|
||||
@@ -75,17 +86,10 @@ class Authenticator(object):
|
||||
"signatures": {},
|
||||
}
|
||||
|
||||
content = None
|
||||
origin = None
|
||||
if content is not None:
|
||||
json_request["content"] = content
|
||||
|
||||
if request.method in ["PUT", "POST"]:
|
||||
# TODO: Handle other method types? other content types?
|
||||
try:
|
||||
content_bytes = request.content.read()
|
||||
content = json.loads(content_bytes)
|
||||
json_request["content"] = content
|
||||
except:
|
||||
raise SynapseError(400, "Unable to parse JSON", Codes.BAD_JSON)
|
||||
origin = None
|
||||
|
||||
def parse_auth_header(header_str):
|
||||
try:
|
||||
@@ -103,14 +107,14 @@ class Authenticator(object):
|
||||
sig = strip_quotes(param_dict["sig"])
|
||||
return (origin, key, sig)
|
||||
except:
|
||||
raise SynapseError(
|
||||
raise AuthenticationError(
|
||||
400, "Malformed Authorization header", Codes.UNAUTHORIZED
|
||||
)
|
||||
|
||||
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
|
||||
|
||||
if not auth_headers:
|
||||
raise SynapseError(
|
||||
raise NoAuthenticationError(
|
||||
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
@@ -121,7 +125,7 @@ class Authenticator(object):
|
||||
json_request["signatures"].setdefault(origin, {})[key] = sig
|
||||
|
||||
if not json_request["signatures"]:
|
||||
raise SynapseError(
|
||||
raise NoAuthenticationError(
|
||||
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
@@ -130,10 +134,12 @@ class Authenticator(object):
|
||||
logger.info("Request from %s", origin)
|
||||
request.authenticated_entity = origin
|
||||
|
||||
defer.returnValue((origin, content))
|
||||
defer.returnValue(origin)
|
||||
|
||||
|
||||
class BaseFederationServlet(object):
|
||||
REQUIRE_AUTH = True
|
||||
|
||||
def __init__(self, handler, authenticator, ratelimiter, server_name,
|
||||
room_list_handler):
|
||||
self.handler = handler
|
||||
@@ -141,29 +147,46 @@ class BaseFederationServlet(object):
|
||||
self.ratelimiter = ratelimiter
|
||||
self.room_list_handler = room_list_handler
|
||||
|
||||
def _wrap(self, code):
|
||||
def _wrap(self, func):
|
||||
authenticator = self.authenticator
|
||||
ratelimiter = self.ratelimiter
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@functools.wraps(code)
|
||||
def new_code(request, *args, **kwargs):
|
||||
@functools.wraps(func)
|
||||
def new_func(request, *args, **kwargs):
|
||||
content = None
|
||||
if request.method in ["PUT", "POST"]:
|
||||
# TODO: Handle other method types? other content types?
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
try:
|
||||
(origin, content) = yield authenticator.authenticate_request(request)
|
||||
with ratelimiter.ratelimit(origin) as d:
|
||||
yield d
|
||||
response = yield code(
|
||||
origin, content, request.args, *args, **kwargs
|
||||
)
|
||||
origin = yield authenticator.authenticate_request(request, content)
|
||||
except NoAuthenticationError:
|
||||
origin = None
|
||||
if self.REQUIRE_AUTH:
|
||||
logger.exception("authenticate_request failed")
|
||||
raise
|
||||
except:
|
||||
logger.exception("authenticate_request failed")
|
||||
raise
|
||||
|
||||
if origin:
|
||||
with ratelimiter.ratelimit(origin) as d:
|
||||
yield d
|
||||
response = yield func(
|
||||
origin, content, request.args, *args, **kwargs
|
||||
)
|
||||
else:
|
||||
response = yield func(
|
||||
origin, content, request.args, *args, **kwargs
|
||||
)
|
||||
|
||||
defer.returnValue(response)
|
||||
|
||||
# Extra logic that functools.wraps() doesn't finish
|
||||
new_code.__self__ = code.__self__
|
||||
new_func.__self__ = func.__self__
|
||||
|
||||
return new_code
|
||||
return new_func
|
||||
|
||||
def register(self, server):
|
||||
pattern = re.compile("^" + PREFIX + self.PATH + "$")
|
||||
@@ -271,6 +294,17 @@ class FederationStateServlet(BaseFederationServlet):
|
||||
)
|
||||
|
||||
|
||||
class FederationStateIdsServlet(BaseFederationServlet):
|
||||
PATH = "/state_ids/(?P<room_id>[^/]*)/"
|
||||
|
||||
def on_GET(self, origin, content, query, room_id):
|
||||
return self.handler.on_state_ids_request(
|
||||
origin,
|
||||
room_id,
|
||||
query.get("event_id", [None])[0],
|
||||
)
|
||||
|
||||
|
||||
class FederationBackfillServlet(BaseFederationServlet):
|
||||
PATH = "/backfill/(?P<context>[^/]*)/"
|
||||
|
||||
@@ -367,10 +401,8 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
|
||||
class FederationClientKeysQueryServlet(BaseFederationServlet):
|
||||
PATH = "/user/keys/query"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, origin, content, query):
|
||||
response = yield self.handler.on_query_client_keys(origin, content)
|
||||
defer.returnValue((200, response))
|
||||
return self.handler.on_query_client_keys(origin, content)
|
||||
|
||||
|
||||
class FederationClientKeysClaimServlet(BaseFederationServlet):
|
||||
@@ -388,7 +420,7 @@ class FederationQueryAuthServlet(BaseFederationServlet):
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, origin, content, query, context, event_id):
|
||||
new_content = yield self.handler.on_query_auth_request(
|
||||
origin, content, event_id
|
||||
origin, content, context, event_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, new_content))
|
||||
@@ -420,9 +452,10 @@ class FederationGetMissingEventsServlet(BaseFederationServlet):
|
||||
class On3pidBindServlet(BaseFederationServlet):
|
||||
PATH = "/3pid/onbind"
|
||||
|
||||
REQUIRE_AUTH = False
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
content = parse_json_object_from_request(request)
|
||||
def on_POST(self, origin, content, query):
|
||||
if "invites" in content:
|
||||
last_exception = None
|
||||
for invite in content["invites"]:
|
||||
@@ -444,11 +477,6 @@ class On3pidBindServlet(BaseFederationServlet):
|
||||
raise last_exception
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
# Avoid doing remote HS authorization checks which are done by default by
|
||||
# BaseFederationServlet.
|
||||
def _wrap(self, code):
|
||||
return code
|
||||
|
||||
|
||||
class OpenIdUserInfo(BaseFederationServlet):
|
||||
"""
|
||||
@@ -469,9 +497,11 @@ class OpenIdUserInfo(BaseFederationServlet):
|
||||
|
||||
PATH = "/openid/userinfo"
|
||||
|
||||
REQUIRE_AUTH = False
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
token = parse_string(request, "access_token")
|
||||
def on_GET(self, origin, content, query):
|
||||
token = query.get("access_token", [None])[0]
|
||||
if token is None:
|
||||
defer.returnValue((401, {
|
||||
"errcode": "M_MISSING_TOKEN", "error": "Access Token required"
|
||||
@@ -488,11 +518,6 @@ class OpenIdUserInfo(BaseFederationServlet):
|
||||
|
||||
defer.returnValue((200, {"sub": user_id}))
|
||||
|
||||
# Avoid doing remote HS authorization checks which are done by default by
|
||||
# BaseFederationServlet.
|
||||
def _wrap(self, code):
|
||||
return code
|
||||
|
||||
|
||||
class PublicRoomList(BaseFederationServlet):
|
||||
"""
|
||||
@@ -528,14 +553,23 @@ class PublicRoomList(BaseFederationServlet):
|
||||
PATH = "/publicRooms"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
def on_GET(self, origin, content, query):
|
||||
data = yield self.room_list_handler.get_local_public_room_list()
|
||||
defer.returnValue((200, data))
|
||||
|
||||
# Avoid doing remote HS authorization checks which are done by default by
|
||||
# BaseFederationServlet.
|
||||
def _wrap(self, code):
|
||||
return code
|
||||
|
||||
class FederationVersionServlet(BaseFederationServlet):
|
||||
PATH = "/version"
|
||||
|
||||
REQUIRE_AUTH = False
|
||||
|
||||
def on_GET(self, origin, content, query):
|
||||
return defer.succeed((200, {
|
||||
"server": {
|
||||
"name": "Synapse",
|
||||
"version": get_version_string(synapse)
|
||||
},
|
||||
}))
|
||||
|
||||
|
||||
SERVLET_CLASSES = (
|
||||
@@ -543,6 +577,7 @@ SERVLET_CLASSES = (
|
||||
FederationPullServlet,
|
||||
FederationEventServlet,
|
||||
FederationStateServlet,
|
||||
FederationStateIdsServlet,
|
||||
FederationBackfillServlet,
|
||||
FederationQueryServlet,
|
||||
FederationMakeJoinServlet,
|
||||
@@ -560,6 +595,7 @@ SERVLET_CLASSES = (
|
||||
On3pidBindServlet,
|
||||
OpenIdUserInfo,
|
||||
PublicRoomList,
|
||||
FederationVersionServlet,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -31,10 +31,21 @@ from .search import SearchHandler
|
||||
|
||||
class Handlers(object):
|
||||
|
||||
""" A collection of all the event handlers.
|
||||
""" Deprecated. A collection of handlers.
|
||||
|
||||
There's no need to lazily create these; we'll just make them all eagerly
|
||||
at construction time.
|
||||
At some point most of the classes whose name ended "Handler" were
|
||||
accessed through this class.
|
||||
|
||||
However this makes it painful to unit test the handlers and to run cut
|
||||
down versions of synapse that only use specific handlers because using a
|
||||
single handler required creating all of the handlers. So some of the
|
||||
handlers have been lifted out of the Handlers object and are now accessed
|
||||
directly through the homeserver object itself.
|
||||
|
||||
Any new handlers should follow the new pattern of being accessed through
|
||||
the homeserver object and should not be added to the Handlers object.
|
||||
|
||||
The remaining handlers should be moved out of the handlers object.
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
|
||||
@@ -13,14 +13,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import LimitExceededError
|
||||
import synapse.types
|
||||
from synapse.api.constants import Membership, EventTypes
|
||||
from synapse.types import UserID, Requester
|
||||
|
||||
|
||||
import logging
|
||||
from synapse.api.errors import LimitExceededError
|
||||
from synapse.types import UserID
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -31,11 +31,15 @@ class BaseHandler(object):
|
||||
Common base class for the event handlers.
|
||||
|
||||
Attributes:
|
||||
store (synapse.storage.events.StateStore):
|
||||
store (synapse.storage.DataStore):
|
||||
state_handler (synapse.state.StateHandler):
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer):
|
||||
"""
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
self.notifier = hs.get_notifier()
|
||||
@@ -120,7 +124,8 @@ class BaseHandler(object):
|
||||
# and having homeservers have their own users leave keeps more
|
||||
# of that decision-making and control local to the guest-having
|
||||
# homeserver.
|
||||
requester = Requester(target_user, "", True)
|
||||
requester = synapse.types.create_requester(
|
||||
target_user, is_guest=True)
|
||||
handler = self.hs.get_handlers().room_member_handler
|
||||
yield handler.update_membership(
|
||||
requester,
|
||||
|
||||
@@ -20,6 +20,7 @@ from synapse.api.constants import LoginType
|
||||
from synapse.types import UserID
|
||||
from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.config.ldap import LDAPMode
|
||||
|
||||
from twisted.web.client import PartialDownloadError
|
||||
|
||||
@@ -28,6 +29,12 @@ import bcrypt
|
||||
import pymacaroons
|
||||
import simplejson
|
||||
|
||||
try:
|
||||
import ldap3
|
||||
except ImportError:
|
||||
ldap3 = None
|
||||
pass
|
||||
|
||||
import synapse.util.stringutils as stringutils
|
||||
|
||||
|
||||
@@ -38,6 +45,10 @@ class AuthHandler(BaseHandler):
|
||||
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer):
|
||||
"""
|
||||
super(AuthHandler, self).__init__(hs)
|
||||
self.checkers = {
|
||||
LoginType.PASSWORD: self._check_password_auth,
|
||||
@@ -50,19 +61,23 @@ class AuthHandler(BaseHandler):
|
||||
self.INVALID_TOKEN_HTTP_STATUS = 401
|
||||
|
||||
self.ldap_enabled = hs.config.ldap_enabled
|
||||
self.ldap_server = hs.config.ldap_server
|
||||
self.ldap_port = hs.config.ldap_port
|
||||
self.ldap_tls = hs.config.ldap_tls
|
||||
self.ldap_search_base = hs.config.ldap_search_base
|
||||
self.ldap_search_property = hs.config.ldap_search_property
|
||||
self.ldap_email_property = hs.config.ldap_email_property
|
||||
self.ldap_full_name_property = hs.config.ldap_full_name_property
|
||||
|
||||
if self.ldap_enabled is True:
|
||||
import ldap
|
||||
logger.info("Import ldap version: %s", ldap.__version__)
|
||||
if self.ldap_enabled:
|
||||
if not ldap3:
|
||||
raise RuntimeError(
|
||||
'Missing ldap3 library. This is required for LDAP Authentication.'
|
||||
)
|
||||
self.ldap_mode = hs.config.ldap_mode
|
||||
self.ldap_uri = hs.config.ldap_uri
|
||||
self.ldap_start_tls = hs.config.ldap_start_tls
|
||||
self.ldap_base = hs.config.ldap_base
|
||||
self.ldap_filter = hs.config.ldap_filter
|
||||
self.ldap_attributes = hs.config.ldap_attributes
|
||||
if self.ldap_mode == LDAPMode.SEARCH:
|
||||
self.ldap_bind_dn = hs.config.ldap_bind_dn
|
||||
self.ldap_bind_password = hs.config.ldap_bind_password
|
||||
|
||||
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
||||
self.device_handler = hs.get_device_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_auth(self, flows, clientdict, clientip):
|
||||
@@ -220,7 +235,6 @@ class AuthHandler(BaseHandler):
|
||||
sess = self._get_session_info(session_id)
|
||||
return sess.setdefault('serverdict', {}).get(key, default)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_password_auth(self, authdict, _):
|
||||
if "user" not in authdict or "password" not in authdict:
|
||||
raise LoginError(400, "", Codes.MISSING_PARAM)
|
||||
@@ -230,11 +244,7 @@ class AuthHandler(BaseHandler):
|
||||
if not user_id.startswith('@'):
|
||||
user_id = UserID.create(user_id, self.hs.hostname).to_string()
|
||||
|
||||
if not (yield self._check_password(user_id, password)):
|
||||
logger.warn("Failed password login for user %s", user_id)
|
||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||
|
||||
defer.returnValue(user_id)
|
||||
return self._check_password(user_id, password)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_recaptcha(self, authdict, clientip):
|
||||
@@ -270,8 +280,17 @@ class AuthHandler(BaseHandler):
|
||||
data = pde.response
|
||||
resp_body = simplejson.loads(data)
|
||||
|
||||
if 'success' in resp_body and resp_body['success']:
|
||||
defer.returnValue(True)
|
||||
if 'success' in resp_body:
|
||||
# Note that we do NOT check the hostname here: we explicitly
|
||||
# intend the CAPTCHA to be presented by whatever client the
|
||||
# user is using, we just care that they have completed a CAPTCHA.
|
||||
logger.info(
|
||||
"%s reCAPTCHA from hostname %s",
|
||||
"Successful" if resp_body['success'] else "Failed",
|
||||
resp_body.get('hostname')
|
||||
)
|
||||
if resp_body['success']:
|
||||
defer.returnValue(True)
|
||||
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@@ -338,67 +357,84 @@ class AuthHandler(BaseHandler):
|
||||
|
||||
return self.sessions[session_id]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def login_with_password(self, user_id, password):
|
||||
def validate_password_login(self, user_id, password):
|
||||
"""
|
||||
Authenticates the user with their username and password.
|
||||
|
||||
Used only by the v1 login API.
|
||||
|
||||
Args:
|
||||
user_id (str): User ID
|
||||
user_id (str): complete @user:id
|
||||
password (str): Password
|
||||
Returns:
|
||||
A tuple of:
|
||||
The user's ID.
|
||||
The access token for the user's session.
|
||||
The refresh token for the user's session.
|
||||
defer.Deferred: (str) canonical user id
|
||||
Raises:
|
||||
StoreError if there was a problem storing the token.
|
||||
StoreError if there was a problem accessing the database
|
||||
LoginError if there was an authentication problem.
|
||||
"""
|
||||
|
||||
if not (yield self._check_password(user_id, password)):
|
||||
logger.warn("Failed password login for user %s", user_id)
|
||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||
|
||||
logger.info("Logging in user %s", user_id)
|
||||
access_token = yield self.issue_access_token(user_id)
|
||||
refresh_token = yield self.issue_refresh_token(user_id)
|
||||
defer.returnValue((user_id, access_token, refresh_token))
|
||||
return self._check_password(user_id, password)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_login_tuple_for_user_id(self, user_id):
|
||||
def get_login_tuple_for_user_id(self, user_id, device_id=None,
|
||||
initial_display_name=None):
|
||||
"""
|
||||
Gets login tuple for the user with the given user ID.
|
||||
|
||||
Creates a new access/refresh token for the user.
|
||||
|
||||
The user is assumed to have been authenticated by some other
|
||||
machanism (e.g. CAS)
|
||||
machanism (e.g. CAS), and the user_id converted to the canonical case.
|
||||
|
||||
The device will be recorded in the table if it is not there already.
|
||||
|
||||
Args:
|
||||
user_id (str): User ID
|
||||
user_id (str): canonical User ID
|
||||
device_id (str|None): the device ID to associate with the tokens.
|
||||
None to leave the tokens unassociated with a device (deprecated:
|
||||
we should always have a device ID)
|
||||
initial_display_name (str): display name to associate with the
|
||||
device if it needs re-registering
|
||||
Returns:
|
||||
A tuple of:
|
||||
The user's ID.
|
||||
The access token for the user's session.
|
||||
The refresh token for the user's session.
|
||||
Raises:
|
||||
StoreError if there was a problem storing the token.
|
||||
LoginError if there was an authentication problem.
|
||||
"""
|
||||
user_id, ignored = yield self._find_user_id_and_pwd_hash(user_id)
|
||||
logger.info("Logging in user %s on device %s", user_id, device_id)
|
||||
access_token = yield self.issue_access_token(user_id, device_id)
|
||||
refresh_token = yield self.issue_refresh_token(user_id, device_id)
|
||||
|
||||
logger.info("Logging in user %s", user_id)
|
||||
access_token = yield self.issue_access_token(user_id)
|
||||
refresh_token = yield self.issue_refresh_token(user_id)
|
||||
defer.returnValue((user_id, access_token, refresh_token))
|
||||
# the device *should* have been registered before we got here; however,
|
||||
# it's possible we raced against a DELETE operation. The thing we
|
||||
# really don't want is active access_tokens without a record of the
|
||||
# device, so we double-check it here.
|
||||
if device_id is not None:
|
||||
yield self.device_handler.check_device_registered(
|
||||
user_id, device_id, initial_display_name
|
||||
)
|
||||
|
||||
defer.returnValue((access_token, refresh_token))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def does_user_exist(self, user_id):
|
||||
def check_user_exists(self, user_id):
|
||||
"""
|
||||
Checks to see if a user with the given id exists. Will check case
|
||||
insensitively, but return None if there are multiple inexact matches.
|
||||
|
||||
Args:
|
||||
(str) user_id: complete @user:id
|
||||
|
||||
Returns:
|
||||
defer.Deferred: (str) canonical_user_id, or None if zero or
|
||||
multiple matches
|
||||
"""
|
||||
try:
|
||||
yield self._find_user_id_and_pwd_hash(user_id)
|
||||
defer.returnValue(True)
|
||||
res = yield self._find_user_id_and_pwd_hash(user_id)
|
||||
defer.returnValue(res[0])
|
||||
except LoginError:
|
||||
defer.returnValue(False)
|
||||
defer.returnValue(None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _find_user_id_and_pwd_hash(self, user_id):
|
||||
@@ -428,84 +464,232 @@ class AuthHandler(BaseHandler):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_password(self, user_id, password):
|
||||
"""
|
||||
"""Authenticate a user against the LDAP and local databases.
|
||||
|
||||
user_id is checked case insensitively against the local database, but
|
||||
will throw if there are multiple inexact matches.
|
||||
|
||||
Args:
|
||||
user_id (str): complete @user:id
|
||||
Returns:
|
||||
True if the user_id successfully authenticated
|
||||
(str) the canonical_user_id
|
||||
Raises:
|
||||
LoginError if the password was incorrect
|
||||
"""
|
||||
valid_ldap = yield self._check_ldap_password(user_id, password)
|
||||
if valid_ldap:
|
||||
defer.returnValue(True)
|
||||
defer.returnValue(user_id)
|
||||
|
||||
valid_local_password = yield self._check_local_password(user_id, password)
|
||||
if valid_local_password:
|
||||
defer.returnValue(True)
|
||||
|
||||
defer.returnValue(False)
|
||||
result = yield self._check_local_password(user_id, password)
|
||||
defer.returnValue(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_local_password(self, user_id, password):
|
||||
try:
|
||||
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
|
||||
defer.returnValue(self.validate_hash(password, password_hash))
|
||||
except LoginError:
|
||||
defer.returnValue(False)
|
||||
"""Authenticate a user against the local password database.
|
||||
|
||||
user_id is checked case insensitively, but will throw if there are
|
||||
multiple inexact matches.
|
||||
|
||||
Args:
|
||||
user_id (str): complete @user:id
|
||||
Returns:
|
||||
(str) the canonical_user_id
|
||||
Raises:
|
||||
LoginError if the password was incorrect
|
||||
"""
|
||||
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
|
||||
result = self.validate_hash(password, password_hash)
|
||||
if not result:
|
||||
logger.warn("Failed password login for user %s", user_id)
|
||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||
defer.returnValue(user_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_ldap_password(self, user_id, password):
|
||||
if not self.ldap_enabled:
|
||||
logger.debug("LDAP not configured")
|
||||
""" Attempt to authenticate a user against an LDAP Server
|
||||
and register an account if none exists.
|
||||
|
||||
Returns:
|
||||
True if authentication against LDAP was successful
|
||||
"""
|
||||
|
||||
if not ldap3 or not self.ldap_enabled:
|
||||
defer.returnValue(False)
|
||||
|
||||
import ldap
|
||||
if self.ldap_mode not in LDAPMode.LIST:
|
||||
raise RuntimeError(
|
||||
'Invalid ldap mode specified: {mode}'.format(
|
||||
mode=self.ldap_mode
|
||||
)
|
||||
)
|
||||
|
||||
logger.info("Authenticating %s with LDAP" % user_id)
|
||||
try:
|
||||
ldap_url = "%s:%s" % (self.ldap_server, self.ldap_port)
|
||||
logger.debug("Connecting LDAP server at %s" % ldap_url)
|
||||
l = ldap.initialize(ldap_url)
|
||||
if self.ldap_tls:
|
||||
logger.debug("Initiating TLS")
|
||||
self._connection.start_tls_s()
|
||||
server = ldap3.Server(self.ldap_uri)
|
||||
logger.debug(
|
||||
"Attempting ldap connection with %s",
|
||||
self.ldap_uri
|
||||
)
|
||||
|
||||
local_name = UserID.from_string(user_id).localpart
|
||||
|
||||
dn = "%s=%s, %s" % (
|
||||
self.ldap_search_property,
|
||||
local_name,
|
||||
self.ldap_search_base)
|
||||
logger.debug("DN for LDAP authentication: %s" % dn)
|
||||
|
||||
l.simple_bind_s(dn.encode('utf-8'), password.encode('utf-8'))
|
||||
|
||||
if not (yield self.does_user_exist(user_id)):
|
||||
handler = self.hs.get_handlers().registration_handler
|
||||
user_id, access_token = (
|
||||
yield handler.register(localpart=local_name)
|
||||
localpart = UserID.from_string(user_id).localpart
|
||||
if self.ldap_mode == LDAPMode.SIMPLE:
|
||||
# bind with the the local users ldap credentials
|
||||
bind_dn = "{prop}={value},{base}".format(
|
||||
prop=self.ldap_attributes['uid'],
|
||||
value=localpart,
|
||||
base=self.ldap_base
|
||||
)
|
||||
conn = ldap3.Connection(server, bind_dn, password)
|
||||
logger.debug(
|
||||
"Established ldap connection in simple mode: %s",
|
||||
conn
|
||||
)
|
||||
|
||||
if self.ldap_start_tls:
|
||||
conn.start_tls()
|
||||
logger.debug(
|
||||
"Upgraded ldap connection in simple mode through StartTLS: %s",
|
||||
conn
|
||||
)
|
||||
|
||||
conn.bind()
|
||||
|
||||
elif self.ldap_mode == LDAPMode.SEARCH:
|
||||
# connect with preconfigured credentials and search for local user
|
||||
conn = ldap3.Connection(
|
||||
server,
|
||||
self.ldap_bind_dn,
|
||||
self.ldap_bind_password
|
||||
)
|
||||
logger.debug(
|
||||
"Established ldap connection in search mode: %s",
|
||||
conn
|
||||
)
|
||||
|
||||
if self.ldap_start_tls:
|
||||
conn.start_tls()
|
||||
logger.debug(
|
||||
"Upgraded ldap connection in search mode through StartTLS: %s",
|
||||
conn
|
||||
)
|
||||
|
||||
conn.bind()
|
||||
|
||||
# find matching dn
|
||||
query = "({prop}={value})".format(
|
||||
prop=self.ldap_attributes['uid'],
|
||||
value=localpart
|
||||
)
|
||||
if self.ldap_filter:
|
||||
query = "(&{query}{filter})".format(
|
||||
query=query,
|
||||
filter=self.ldap_filter
|
||||
)
|
||||
logger.debug("ldap search filter: %s", query)
|
||||
result = conn.search(self.ldap_base, query)
|
||||
|
||||
if result and len(conn.response) == 1:
|
||||
# found exactly one result
|
||||
user_dn = conn.response[0]['dn']
|
||||
logger.debug('ldap search found dn: %s', user_dn)
|
||||
|
||||
# unbind and reconnect, rebind with found dn
|
||||
conn.unbind()
|
||||
conn = ldap3.Connection(
|
||||
server,
|
||||
user_dn,
|
||||
password,
|
||||
auto_bind=True
|
||||
)
|
||||
else:
|
||||
# found 0 or > 1 results, abort!
|
||||
logger.warn(
|
||||
"ldap search returned unexpected (%d!=1) amount of results",
|
||||
len(conn.response)
|
||||
)
|
||||
defer.returnValue(False)
|
||||
|
||||
logger.info(
|
||||
"User authenticated against ldap server: %s",
|
||||
conn
|
||||
)
|
||||
|
||||
# check for existing account, if none exists, create one
|
||||
if not (yield self.check_user_exists(user_id)):
|
||||
# query user metadata for account creation
|
||||
query = "({prop}={value})".format(
|
||||
prop=self.ldap_attributes['uid'],
|
||||
value=localpart
|
||||
)
|
||||
|
||||
if self.ldap_mode == LDAPMode.SEARCH and self.ldap_filter:
|
||||
query = "(&{filter}{user_filter})".format(
|
||||
filter=query,
|
||||
user_filter=self.ldap_filter
|
||||
)
|
||||
logger.debug("ldap registration filter: %s", query)
|
||||
|
||||
result = conn.search(
|
||||
search_base=self.ldap_base,
|
||||
search_filter=query,
|
||||
attributes=[
|
||||
self.ldap_attributes['name'],
|
||||
self.ldap_attributes['mail']
|
||||
]
|
||||
)
|
||||
|
||||
if len(conn.response) == 1:
|
||||
attrs = conn.response[0]['attributes']
|
||||
mail = attrs[self.ldap_attributes['mail']][0]
|
||||
name = attrs[self.ldap_attributes['name']][0]
|
||||
|
||||
# create account
|
||||
registration_handler = self.hs.get_handlers().registration_handler
|
||||
user_id, access_token = (
|
||||
yield registration_handler.register(localpart=localpart)
|
||||
)
|
||||
|
||||
# TODO: bind email, set displayname with data from ldap directory
|
||||
|
||||
logger.info(
|
||||
"ldap registration successful: %d: %s (%s, %)",
|
||||
user_id,
|
||||
localpart,
|
||||
name,
|
||||
mail
|
||||
)
|
||||
else:
|
||||
logger.warn(
|
||||
"ldap registration failed: unexpected (%d!=1) amount of results",
|
||||
len(result)
|
||||
)
|
||||
defer.returnValue(False)
|
||||
|
||||
defer.returnValue(True)
|
||||
except ldap.LDAPError, e:
|
||||
logger.warn("LDAP error: %s", e)
|
||||
except ldap3.core.exceptions.LDAPException as e:
|
||||
logger.warn("Error during ldap authentication: %s", e)
|
||||
defer.returnValue(False)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def issue_access_token(self, user_id):
|
||||
def issue_access_token(self, user_id, device_id=None):
|
||||
access_token = self.generate_access_token(user_id)
|
||||
yield self.store.add_access_token_to_user(user_id, access_token)
|
||||
yield self.store.add_access_token_to_user(user_id, access_token,
|
||||
device_id)
|
||||
defer.returnValue(access_token)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def issue_refresh_token(self, user_id):
|
||||
def issue_refresh_token(self, user_id, device_id=None):
|
||||
refresh_token = self.generate_refresh_token(user_id)
|
||||
yield self.store.add_refresh_token_to_user(user_id, refresh_token)
|
||||
yield self.store.add_refresh_token_to_user(user_id, refresh_token,
|
||||
device_id)
|
||||
defer.returnValue(refresh_token)
|
||||
|
||||
def generate_access_token(self, user_id, extra_caveats=None):
|
||||
def generate_access_token(self, user_id, extra_caveats=None,
|
||||
duration_in_ms=(60 * 60 * 1000)):
|
||||
extra_caveats = extra_caveats or []
|
||||
macaroon = self._generate_base_macaroon(user_id)
|
||||
macaroon.add_first_party_caveat("type = access")
|
||||
now = self.hs.get_clock().time_msec()
|
||||
expiry = now + (60 * 60 * 1000)
|
||||
expiry = now + duration_in_ms
|
||||
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
||||
for caveat in extra_caveats:
|
||||
macaroon.add_first_party_caveat(caveat)
|
||||
@@ -613,7 +797,8 @@ class AuthHandler(BaseHandler):
|
||||
Returns:
|
||||
Hashed password (str).
|
||||
"""
|
||||
return bcrypt.hashpw(password, bcrypt.gensalt(self.bcrypt_rounds))
|
||||
return bcrypt.hashpw(password + self.hs.config.password_pepper,
|
||||
bcrypt.gensalt(self.bcrypt_rounds))
|
||||
|
||||
def validate_hash(self, password, stored_hash):
|
||||
"""Validates that self.hash(password) == stored_hash.
|
||||
@@ -626,6 +811,7 @@ class AuthHandler(BaseHandler):
|
||||
Whether self.hash(password) == stored_hash (bool).
|
||||
"""
|
||||
if stored_hash:
|
||||
return bcrypt.hashpw(password, stored_hash) == stored_hash
|
||||
return bcrypt.hashpw(password + self.hs.config.password_pepper,
|
||||
stored_hash.encode('utf-8')) == stored_hash
|
||||
else:
|
||||
return False
|
||||
|
||||
181
synapse/handlers/device.py
Normal file
181
synapse/handlers/device.py
Normal file
@@ -0,0 +1,181 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.api import errors
|
||||
from synapse.util import stringutils
|
||||
from twisted.internet import defer
|
||||
from ._base import BaseHandler
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DeviceHandler(BaseHandler):
|
||||
def __init__(self, hs):
|
||||
super(DeviceHandler, self).__init__(hs)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_device_registered(self, user_id, device_id,
|
||||
initial_device_display_name=None):
|
||||
"""
|
||||
If the given device has not been registered, register it with the
|
||||
supplied display name.
|
||||
|
||||
If no device_id is supplied, we make one up.
|
||||
|
||||
Args:
|
||||
user_id (str): @user:id
|
||||
device_id (str | None): device id supplied by client
|
||||
initial_device_display_name (str | None): device display name from
|
||||
client
|
||||
Returns:
|
||||
str: device id (generated if none was supplied)
|
||||
"""
|
||||
if device_id is not None:
|
||||
yield self.store.store_device(
|
||||
user_id=user_id,
|
||||
device_id=device_id,
|
||||
initial_device_display_name=initial_device_display_name,
|
||||
ignore_if_known=True,
|
||||
)
|
||||
defer.returnValue(device_id)
|
||||
|
||||
# if the device id is not specified, we'll autogen one, but loop a few
|
||||
# times in case of a clash.
|
||||
attempts = 0
|
||||
while attempts < 5:
|
||||
try:
|
||||
device_id = stringutils.random_string_with_symbols(16)
|
||||
yield self.store.store_device(
|
||||
user_id=user_id,
|
||||
device_id=device_id,
|
||||
initial_device_display_name=initial_device_display_name,
|
||||
ignore_if_known=False,
|
||||
)
|
||||
defer.returnValue(device_id)
|
||||
except errors.StoreError:
|
||||
attempts += 1
|
||||
|
||||
raise errors.StoreError(500, "Couldn't generate a device ID.")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_devices_by_user(self, user_id):
|
||||
"""
|
||||
Retrieve the given user's devices
|
||||
|
||||
Args:
|
||||
user_id (str):
|
||||
Returns:
|
||||
defer.Deferred: list[dict[str, X]]: info on each device
|
||||
"""
|
||||
|
||||
device_map = yield self.store.get_devices_by_user(user_id)
|
||||
|
||||
ips = yield self.store.get_last_client_ip_by_device(
|
||||
devices=((user_id, device_id) for device_id in device_map.keys())
|
||||
)
|
||||
|
||||
devices = device_map.values()
|
||||
for device in devices:
|
||||
_update_device_from_client_ips(device, ips)
|
||||
|
||||
defer.returnValue(devices)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_device(self, user_id, device_id):
|
||||
""" Retrieve the given device
|
||||
|
||||
Args:
|
||||
user_id (str):
|
||||
device_id (str):
|
||||
|
||||
Returns:
|
||||
defer.Deferred: dict[str, X]: info on the device
|
||||
Raises:
|
||||
errors.NotFoundError: if the device was not found
|
||||
"""
|
||||
try:
|
||||
device = yield self.store.get_device(user_id, device_id)
|
||||
except errors.StoreError:
|
||||
raise errors.NotFoundError
|
||||
ips = yield self.store.get_last_client_ip_by_device(
|
||||
devices=((user_id, device_id),)
|
||||
)
|
||||
_update_device_from_client_ips(device, ips)
|
||||
defer.returnValue(device)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_device(self, user_id, device_id):
|
||||
""" Delete the given device
|
||||
|
||||
Args:
|
||||
user_id (str):
|
||||
device_id (str):
|
||||
|
||||
Returns:
|
||||
defer.Deferred:
|
||||
"""
|
||||
|
||||
try:
|
||||
yield self.store.delete_device(user_id, device_id)
|
||||
except errors.StoreError, e:
|
||||
if e.code == 404:
|
||||
# no match
|
||||
pass
|
||||
else:
|
||||
raise
|
||||
|
||||
yield self.store.user_delete_access_tokens(
|
||||
user_id, device_id=device_id,
|
||||
delete_refresh_tokens=True,
|
||||
)
|
||||
|
||||
yield self.store.delete_e2e_keys_by_device(
|
||||
user_id=user_id, device_id=device_id
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_device(self, user_id, device_id, content):
|
||||
""" Update the given device
|
||||
|
||||
Args:
|
||||
user_id (str):
|
||||
device_id (str):
|
||||
content (dict): body of update request
|
||||
|
||||
Returns:
|
||||
defer.Deferred:
|
||||
"""
|
||||
|
||||
try:
|
||||
yield self.store.update_device(
|
||||
user_id,
|
||||
device_id,
|
||||
new_display_name=content.get("display_name")
|
||||
)
|
||||
except errors.StoreError, e:
|
||||
if e.code == 404:
|
||||
raise errors.NotFoundError()
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
def _update_device_from_client_ips(device, client_ips):
|
||||
ip = client_ips.get((device["user_id"], device["device_id"]), {})
|
||||
device.update({
|
||||
"last_seen_ts": ip.get("last_seen"),
|
||||
"last_seen_ip": ip.get("ip"),
|
||||
})
|
||||
139
synapse/handlers/e2e_keys.py
Normal file
139
synapse/handlers/e2e_keys.py
Normal file
@@ -0,0 +1,139 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import collections
|
||||
import json
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api import errors
|
||||
import synapse.types
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class E2eKeysHandler(object):
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
self.federation = hs.get_replication_layer()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self.server_name = hs.hostname
|
||||
|
||||
# doesn't really work as part of the generic query API, because the
|
||||
# query request requires an object POST, but we abuse the
|
||||
# "query handler" interface.
|
||||
self.federation.register_query_handler(
|
||||
"client_keys", self.on_federation_query_client_keys
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def query_devices(self, query_body):
|
||||
""" Handle a device key query from a client
|
||||
|
||||
{
|
||||
"device_keys": {
|
||||
"<user_id>": ["<device_id>"]
|
||||
}
|
||||
}
|
||||
->
|
||||
{
|
||||
"device_keys": {
|
||||
"<user_id>": {
|
||||
"<device_id>": {
|
||||
...
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
device_keys_query = query_body.get("device_keys", {})
|
||||
|
||||
# separate users by domain.
|
||||
# make a map from domain to user_id to device_ids
|
||||
queries_by_domain = collections.defaultdict(dict)
|
||||
for user_id, device_ids in device_keys_query.items():
|
||||
user = synapse.types.UserID.from_string(user_id)
|
||||
queries_by_domain[user.domain][user_id] = device_ids
|
||||
|
||||
# do the queries
|
||||
# TODO: do these in parallel
|
||||
results = {}
|
||||
for destination, destination_query in queries_by_domain.items():
|
||||
if destination == self.server_name:
|
||||
res = yield self.query_local_devices(destination_query)
|
||||
else:
|
||||
res = yield self.federation.query_client_keys(
|
||||
destination, {"device_keys": destination_query}
|
||||
)
|
||||
res = res["device_keys"]
|
||||
for user_id, keys in res.items():
|
||||
if user_id in destination_query:
|
||||
results[user_id] = keys
|
||||
|
||||
defer.returnValue((200, {"device_keys": results}))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def query_local_devices(self, query):
|
||||
"""Get E2E device keys for local users
|
||||
|
||||
Args:
|
||||
query (dict[string, list[string]|None): map from user_id to a list
|
||||
of devices to query (None for all devices)
|
||||
|
||||
Returns:
|
||||
defer.Deferred: (resolves to dict[string, dict[string, dict]]):
|
||||
map from user_id -> device_id -> device details
|
||||
"""
|
||||
local_query = []
|
||||
|
||||
result_dict = {}
|
||||
for user_id, device_ids in query.items():
|
||||
if not self.is_mine_id(user_id):
|
||||
logger.warning("Request for keys for non-local user %s",
|
||||
user_id)
|
||||
raise errors.SynapseError(400, "Not a user here")
|
||||
|
||||
if not device_ids:
|
||||
local_query.append((user_id, None))
|
||||
else:
|
||||
for device_id in device_ids:
|
||||
local_query.append((user_id, device_id))
|
||||
|
||||
# make sure that each queried user appears in the result dict
|
||||
result_dict[user_id] = {}
|
||||
|
||||
results = yield self.store.get_e2e_device_keys(local_query)
|
||||
|
||||
# Build the result structure, un-jsonify the results, and add the
|
||||
# "unsigned" section
|
||||
for user_id, device_keys in results.items():
|
||||
for device_id, device_info in device_keys.items():
|
||||
r = json.loads(device_info["key_json"])
|
||||
r["unsigned"] = {}
|
||||
display_name = device_info["device_display_name"]
|
||||
if display_name is not None:
|
||||
r["unsigned"]["device_display_name"] = display_name
|
||||
result_dict[user_id][device_id] = r
|
||||
|
||||
defer.returnValue(result_dict)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_federation_query_client_keys(self, query_body):
|
||||
""" Handle a device key query from a federated server
|
||||
"""
|
||||
device_keys_query = query_body.get("device_keys", {})
|
||||
res = yield self.query_local_devices(device_keys_query)
|
||||
defer.returnValue({"device_keys": res})
|
||||
@@ -124,7 +124,7 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
try:
|
||||
event_stream_id, max_stream_id = yield self._persist_auth_tree(
|
||||
auth_chain, state, event
|
||||
origin, auth_chain, state, event
|
||||
)
|
||||
except AuthError as e:
|
||||
raise FederationError(
|
||||
@@ -335,30 +335,59 @@ class FederationHandler(BaseHandler):
|
||||
state_events.update({s.event_id: s for s in state})
|
||||
events_to_state[e_id] = state
|
||||
|
||||
required_auth = set(
|
||||
a_id
|
||||
for event in events + state_events.values() + auth_events.values()
|
||||
for a_id, _ in event.auth_events
|
||||
)
|
||||
auth_events.update({
|
||||
e_id: event_map[e_id] for e_id in required_auth if e_id in event_map
|
||||
})
|
||||
missing_auth = required_auth - set(auth_events)
|
||||
failed_to_fetch = set()
|
||||
|
||||
# Try and fetch any missing auth events from both DB and remote servers.
|
||||
# We repeatedly do this until we stop finding new auth events.
|
||||
while missing_auth - failed_to_fetch:
|
||||
logger.info("Missing auth for backfill: %r", missing_auth)
|
||||
ret_events = yield self.store.get_events(missing_auth - failed_to_fetch)
|
||||
auth_events.update(ret_events)
|
||||
|
||||
required_auth.update(
|
||||
a_id for event in ret_events.values() for a_id, _ in event.auth_events
|
||||
)
|
||||
missing_auth = required_auth - set(auth_events)
|
||||
|
||||
if missing_auth - failed_to_fetch:
|
||||
logger.info(
|
||||
"Fetching missing auth for backfill: %r",
|
||||
missing_auth - failed_to_fetch
|
||||
)
|
||||
|
||||
results = yield defer.gatherResults(
|
||||
[
|
||||
self.replication_layer.get_pdu(
|
||||
[dest],
|
||||
event_id,
|
||||
outlier=True,
|
||||
timeout=10000,
|
||||
)
|
||||
for event_id in missing_auth - failed_to_fetch
|
||||
],
|
||||
consumeErrors=True
|
||||
).addErrback(unwrapFirstError)
|
||||
auth_events.update({a.event_id: a for a in results})
|
||||
required_auth.update(
|
||||
a_id for event in results for a_id, _ in event.auth_events
|
||||
)
|
||||
missing_auth = required_auth - set(auth_events)
|
||||
|
||||
failed_to_fetch = missing_auth - set(auth_events)
|
||||
|
||||
seen_events = yield self.store.have_events(
|
||||
set(auth_events.keys()) | set(state_events.keys())
|
||||
)
|
||||
|
||||
all_events = events + state_events.values() + auth_events.values()
|
||||
required_auth = set(
|
||||
a_id for event in all_events for a_id, _ in event.auth_events
|
||||
)
|
||||
|
||||
missing_auth = required_auth - set(auth_events)
|
||||
results = yield defer.gatherResults(
|
||||
[
|
||||
self.replication_layer.get_pdu(
|
||||
[dest],
|
||||
event_id,
|
||||
outlier=True,
|
||||
timeout=10000,
|
||||
)
|
||||
for event_id in missing_auth
|
||||
],
|
||||
consumeErrors=True
|
||||
).addErrback(unwrapFirstError)
|
||||
auth_events.update({a.event_id: a for a in results})
|
||||
|
||||
ev_infos = []
|
||||
for a in auth_events.values():
|
||||
if a.event_id in seen_events:
|
||||
@@ -370,6 +399,7 @@ class FederationHandler(BaseHandler):
|
||||
(auth_events[a_id].type, auth_events[a_id].state_key):
|
||||
auth_events[a_id]
|
||||
for a_id, _ in a.auth_events
|
||||
if a_id in auth_events
|
||||
}
|
||||
})
|
||||
|
||||
@@ -381,6 +411,7 @@ class FederationHandler(BaseHandler):
|
||||
(auth_events[a_id].type, auth_events[a_id].state_key):
|
||||
auth_events[a_id]
|
||||
for a_id, _ in event_map[e_id].auth_events
|
||||
if a_id in auth_events
|
||||
}
|
||||
})
|
||||
|
||||
@@ -399,7 +430,7 @@ class FederationHandler(BaseHandler):
|
||||
# previous to work out the state.
|
||||
# TODO: We can probably do something more clever here.
|
||||
yield self._handle_new_event(
|
||||
dest, event
|
||||
dest, event, backfilled=True,
|
||||
)
|
||||
|
||||
defer.returnValue(events)
|
||||
@@ -635,7 +666,7 @@ class FederationHandler(BaseHandler):
|
||||
pass
|
||||
|
||||
event_stream_id, max_stream_id = yield self._persist_auth_tree(
|
||||
auth_chain, state, event
|
||||
origin, auth_chain, state, event
|
||||
)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
@@ -686,7 +717,9 @@ class FederationHandler(BaseHandler):
|
||||
logger.warn("Failed to create join %r because %s", event, e)
|
||||
raise e
|
||||
|
||||
self.auth.check(event, auth_events=context.current_state)
|
||||
# The remote hasn't signed it yet, obviously. We'll do the full checks
|
||||
# when we get the event back in `on_send_join_request`
|
||||
self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
|
||||
|
||||
defer.returnValue(event)
|
||||
|
||||
@@ -916,7 +949,9 @@ class FederationHandler(BaseHandler):
|
||||
)
|
||||
|
||||
try:
|
||||
self.auth.check(event, auth_events=context.current_state)
|
||||
# The remote hasn't signed it yet, obviously. We'll do the full checks
|
||||
# when we get the event back in `on_send_leave_request`
|
||||
self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
|
||||
except AuthError as e:
|
||||
logger.warn("Failed to create new leave %r because %s", event, e)
|
||||
raise e
|
||||
@@ -985,14 +1020,9 @@ class FederationHandler(BaseHandler):
|
||||
defer.returnValue(None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True):
|
||||
def get_state_for_pdu(self, room_id, event_id):
|
||||
yield run_on_reactor()
|
||||
|
||||
if do_auth:
|
||||
in_room = yield self.auth.check_host_in_room(room_id, origin)
|
||||
if not in_room:
|
||||
raise AuthError(403, "Host not in room.")
|
||||
|
||||
state_groups = yield self.store.get_state_groups(
|
||||
room_id, [event_id]
|
||||
)
|
||||
@@ -1016,13 +1046,16 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
res = results.values()
|
||||
for event in res:
|
||||
event.signatures.update(
|
||||
compute_event_signature(
|
||||
event,
|
||||
self.hs.hostname,
|
||||
self.hs.config.signing_key[0]
|
||||
# We sign these again because there was a bug where we
|
||||
# incorrectly signed things the first time round
|
||||
if self.hs.is_mine_id(event.event_id):
|
||||
event.signatures.update(
|
||||
compute_event_signature(
|
||||
event,
|
||||
self.hs.hostname,
|
||||
self.hs.config.signing_key[0]
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
defer.returnValue(res)
|
||||
else:
|
||||
@@ -1109,11 +1142,12 @@ class FederationHandler(BaseHandler):
|
||||
backfilled=backfilled,
|
||||
)
|
||||
|
||||
# this intentionally does not yield: we don't care about the result
|
||||
# and don't need to wait for it.
|
||||
preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
|
||||
event_stream_id, max_stream_id
|
||||
)
|
||||
if not backfilled:
|
||||
# this intentionally does not yield: we don't care about the result
|
||||
# and don't need to wait for it.
|
||||
preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
|
||||
event_stream_id, max_stream_id
|
||||
)
|
||||
|
||||
defer.returnValue((context, event_stream_id, max_stream_id))
|
||||
|
||||
@@ -1145,11 +1179,19 @@ class FederationHandler(BaseHandler):
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _persist_auth_tree(self, auth_events, state, event):
|
||||
def _persist_auth_tree(self, origin, auth_events, state, event):
|
||||
"""Checks the auth chain is valid (and passes auth checks) for the
|
||||
state and event. Then persists the auth chain and state atomically.
|
||||
Persists the event seperately.
|
||||
|
||||
Will attempt to fetch missing auth events.
|
||||
|
||||
Args:
|
||||
origin (str): Where the events came from
|
||||
auth_events (list)
|
||||
state (list)
|
||||
event (Event)
|
||||
|
||||
Returns:
|
||||
2-tuple of (event_stream_id, max_stream_id) from the persist_event
|
||||
call for `event`
|
||||
@@ -1162,7 +1204,7 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
event_map = {
|
||||
e.event_id: e
|
||||
for e in auth_events
|
||||
for e in itertools.chain(auth_events, state, [event])
|
||||
}
|
||||
|
||||
create_event = None
|
||||
@@ -1171,10 +1213,29 @@ class FederationHandler(BaseHandler):
|
||||
create_event = e
|
||||
break
|
||||
|
||||
missing_auth_events = set()
|
||||
for e in itertools.chain(auth_events, state, [event]):
|
||||
for e_id, _ in e.auth_events:
|
||||
if e_id not in event_map:
|
||||
missing_auth_events.add(e_id)
|
||||
|
||||
for e_id in missing_auth_events:
|
||||
m_ev = yield self.replication_layer.get_pdu(
|
||||
[origin],
|
||||
e_id,
|
||||
outlier=True,
|
||||
timeout=10000,
|
||||
)
|
||||
if m_ev and m_ev.event_id == e_id:
|
||||
event_map[e_id] = m_ev
|
||||
else:
|
||||
logger.info("Failed to find auth event %r", e_id)
|
||||
|
||||
for e in itertools.chain(auth_events, state, [event]):
|
||||
auth_for_e = {
|
||||
(event_map[e_id].type, event_map[e_id].state_key): event_map[e_id]
|
||||
for e_id, _ in e.auth_events
|
||||
if e_id in event_map
|
||||
}
|
||||
if create_event:
|
||||
auth_for_e[(EventTypes.Create, "")] = create_event
|
||||
@@ -1408,7 +1469,7 @@ class FederationHandler(BaseHandler):
|
||||
local_view = dict(auth_events)
|
||||
remote_view = dict(auth_events)
|
||||
remote_view.update({
|
||||
(d.type, d.state_key): d for d in different_events
|
||||
(d.type, d.state_key): d for d in different_events if d
|
||||
})
|
||||
|
||||
new_state, prev_state = self.state_handler.resolve_events(
|
||||
|
||||
@@ -21,7 +21,7 @@ from synapse.api.errors import (
|
||||
)
|
||||
from ._base import BaseHandler
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.errors import SynapseError, Codes
|
||||
|
||||
import json
|
||||
import logging
|
||||
@@ -41,6 +41,20 @@ class IdentityHandler(BaseHandler):
|
||||
hs.config.use_insecure_ssl_client_just_for_testing_do_not_use
|
||||
)
|
||||
|
||||
def _should_trust_id_server(self, id_server):
|
||||
if id_server not in self.trusted_id_servers:
|
||||
if self.trust_any_id_server_just_for_testing_do_not_use:
|
||||
logger.warn(
|
||||
"Trusting untrustworthy ID server %r even though it isn't"
|
||||
" in the trusted id list for testing because"
|
||||
" 'use_insecure_ssl_client_just_for_testing_do_not_use'"
|
||||
" is set in the config",
|
||||
id_server,
|
||||
)
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def threepid_from_creds(self, creds):
|
||||
yield run_on_reactor()
|
||||
@@ -59,19 +73,12 @@ class IdentityHandler(BaseHandler):
|
||||
else:
|
||||
raise SynapseError(400, "No client_secret in creds")
|
||||
|
||||
if id_server not in self.trusted_id_servers:
|
||||
if self.trust_any_id_server_just_for_testing_do_not_use:
|
||||
logger.warn(
|
||||
"Trusting untrustworthy ID server %r even though it isn't"
|
||||
" in the trusted id list for testing because"
|
||||
" 'use_insecure_ssl_client_just_for_testing_do_not_use'"
|
||||
" is set in the config",
|
||||
id_server,
|
||||
)
|
||||
else:
|
||||
logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
|
||||
'credentials', id_server)
|
||||
defer.returnValue(None)
|
||||
if not self._should_trust_id_server(id_server):
|
||||
logger.warn(
|
||||
'%s is not a trusted ID server: rejecting 3pid ' +
|
||||
'credentials', id_server
|
||||
)
|
||||
defer.returnValue(None)
|
||||
|
||||
data = {}
|
||||
try:
|
||||
@@ -129,6 +136,12 @@ class IdentityHandler(BaseHandler):
|
||||
def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs):
|
||||
yield run_on_reactor()
|
||||
|
||||
if not self._should_trust_id_server(id_server):
|
||||
raise SynapseError(
|
||||
400, "Untrusted ID server '%s'" % id_server,
|
||||
Codes.SERVER_NOT_TRUSTED
|
||||
)
|
||||
|
||||
params = {
|
||||
'email': email,
|
||||
'client_secret': client_secret,
|
||||
|
||||
@@ -26,7 +26,7 @@ from synapse.types import (
|
||||
UserID, RoomAlias, RoomStreamToken, StreamToken, get_domain_from_id
|
||||
)
|
||||
from synapse.util import unwrapFirstError
|
||||
from synapse.util.async import concurrently_execute, run_on_reactor
|
||||
from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLock
|
||||
from synapse.util.caches.snapshot_cache import SnapshotCache
|
||||
from synapse.util.logcontext import preserve_fn
|
||||
from synapse.visibility import filter_events_for_client
|
||||
@@ -50,9 +50,23 @@ class MessageHandler(BaseHandler):
|
||||
self.validator = EventValidator()
|
||||
self.snapshot_cache = SnapshotCache()
|
||||
|
||||
self.pagination_lock = ReadWriteLock()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def purge_history(self, room_id, event_id):
|
||||
event = yield self.store.get_event(event_id)
|
||||
|
||||
if event.room_id != room_id:
|
||||
raise SynapseError(400, "Event is for wrong room.")
|
||||
|
||||
depth = event.depth
|
||||
|
||||
with (yield self.pagination_lock.write(room_id)):
|
||||
yield self.store.delete_old_state(room_id, depth)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_messages(self, requester, room_id=None, pagin_config=None,
|
||||
as_client_event=True):
|
||||
as_client_event=True, event_filter=None):
|
||||
"""Get messages in a room.
|
||||
|
||||
Args:
|
||||
@@ -61,11 +75,11 @@ class MessageHandler(BaseHandler):
|
||||
pagin_config (synapse.api.streams.PaginationConfig): The pagination
|
||||
config rules to apply, if any.
|
||||
as_client_event (bool): True to get events in client-server format.
|
||||
event_filter (Filter): Filter to apply to results or None
|
||||
Returns:
|
||||
dict: Pagination API results
|
||||
"""
|
||||
user_id = requester.user.to_string()
|
||||
data_source = self.hs.get_event_sources().sources["room"]
|
||||
|
||||
if pagin_config.from_token:
|
||||
room_token = pagin_config.from_token.room_key
|
||||
@@ -85,42 +99,48 @@ class MessageHandler(BaseHandler):
|
||||
|
||||
source_config = pagin_config.get_source_config("room")
|
||||
|
||||
membership, member_event_id = yield self._check_in_room_or_world_readable(
|
||||
room_id, user_id
|
||||
)
|
||||
|
||||
if source_config.direction == 'b':
|
||||
# if we're going backwards, we might need to backfill. This
|
||||
# requires that we have a topo token.
|
||||
if room_token.topological:
|
||||
max_topo = room_token.topological
|
||||
else:
|
||||
max_topo = yield self.store.get_max_topological_token_for_stream_and_room(
|
||||
room_id, room_token.stream
|
||||
)
|
||||
|
||||
if membership == Membership.LEAVE:
|
||||
# If they have left the room then clamp the token to be before
|
||||
# they left the room, to save the effort of loading from the
|
||||
# database.
|
||||
leave_token = yield self.store.get_topological_token_for_event(
|
||||
member_event_id
|
||||
)
|
||||
leave_token = RoomStreamToken.parse(leave_token)
|
||||
if leave_token.topological < max_topo:
|
||||
source_config.from_key = str(leave_token)
|
||||
|
||||
yield self.hs.get_handlers().federation_handler.maybe_backfill(
|
||||
room_id, max_topo
|
||||
with (yield self.pagination_lock.read(room_id)):
|
||||
membership, member_event_id = yield self._check_in_room_or_world_readable(
|
||||
room_id, user_id
|
||||
)
|
||||
|
||||
events, next_key = yield data_source.get_pagination_rows(
|
||||
requester.user, source_config, room_id
|
||||
)
|
||||
if source_config.direction == 'b':
|
||||
# if we're going backwards, we might need to backfill. This
|
||||
# requires that we have a topo token.
|
||||
if room_token.topological:
|
||||
max_topo = room_token.topological
|
||||
else:
|
||||
max_topo = yield self.store.get_max_topological_token(
|
||||
room_id, room_token.stream
|
||||
)
|
||||
|
||||
next_token = pagin_config.from_token.copy_and_replace(
|
||||
"room_key", next_key
|
||||
)
|
||||
if membership == Membership.LEAVE:
|
||||
# If they have left the room then clamp the token to be before
|
||||
# they left the room, to save the effort of loading from the
|
||||
# database.
|
||||
leave_token = yield self.store.get_topological_token_for_event(
|
||||
member_event_id
|
||||
)
|
||||
leave_token = RoomStreamToken.parse(leave_token)
|
||||
if leave_token.topological < max_topo:
|
||||
source_config.from_key = str(leave_token)
|
||||
|
||||
yield self.hs.get_handlers().federation_handler.maybe_backfill(
|
||||
room_id, max_topo
|
||||
)
|
||||
|
||||
events, next_key = yield self.store.paginate_room_events(
|
||||
room_id=room_id,
|
||||
from_key=source_config.from_key,
|
||||
to_key=source_config.to_key,
|
||||
direction=source_config.direction,
|
||||
limit=source_config.limit,
|
||||
event_filter=event_filter,
|
||||
)
|
||||
|
||||
next_token = pagin_config.from_token.copy_and_replace(
|
||||
"room_key", next_key
|
||||
)
|
||||
|
||||
if not events:
|
||||
defer.returnValue({
|
||||
@@ -129,6 +149,9 @@ class MessageHandler(BaseHandler):
|
||||
"end": next_token.to_string(),
|
||||
})
|
||||
|
||||
if event_filter:
|
||||
events = event_filter.filter(events)
|
||||
|
||||
events = yield filter_events_for_client(
|
||||
self.store,
|
||||
user_id,
|
||||
|
||||
@@ -13,15 +13,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse.types
|
||||
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
|
||||
from synapse.types import UserID, Requester
|
||||
|
||||
from synapse.types import UserID
|
||||
from ._base import BaseHandler
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -36,13 +36,6 @@ class ProfileHandler(BaseHandler):
|
||||
"profile", self.on_profile_query
|
||||
)
|
||||
|
||||
distributor = hs.get_distributor()
|
||||
|
||||
distributor.observe("registered_user", self.registered_user)
|
||||
|
||||
def registered_user(self, user):
|
||||
return self.store.create_profile(user.localpart)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_displayname(self, target_user):
|
||||
if self.hs.is_mine(target_user):
|
||||
@@ -172,7 +165,9 @@ class ProfileHandler(BaseHandler):
|
||||
try:
|
||||
# Assume the user isn't a guest because we don't let guests set
|
||||
# profile or avatar data.
|
||||
requester = Requester(user, "", False)
|
||||
# XXX why are we recreating `requester` here for each room?
|
||||
# what was wrong with the `requester` we were passed?
|
||||
requester = synapse.types.create_requester(user)
|
||||
yield handler.update_membership(
|
||||
requester,
|
||||
user,
|
||||
|
||||
@@ -14,19 +14,19 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""Contains functions for registering clients."""
|
||||
import logging
|
||||
import urllib
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.types import UserID, Requester
|
||||
import synapse.types
|
||||
from synapse.api.errors import (
|
||||
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
|
||||
)
|
||||
from ._base import BaseHandler
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.http.client import CaptchaServerHttpClient
|
||||
from synapse.util.distributor import registered_user
|
||||
|
||||
import logging
|
||||
import urllib
|
||||
from synapse.types import UserID
|
||||
from synapse.util.async import run_on_reactor
|
||||
from ._base import BaseHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -37,8 +37,6 @@ class RegistrationHandler(BaseHandler):
|
||||
super(RegistrationHandler, self).__init__(hs)
|
||||
|
||||
self.auth = hs.get_auth()
|
||||
self.distributor = hs.get_distributor()
|
||||
self.distributor.declare("registered_user")
|
||||
self.captcha_client = CaptchaServerHttpClient(hs)
|
||||
|
||||
self._next_generated_user_id = None
|
||||
@@ -55,6 +53,13 @@ class RegistrationHandler(BaseHandler):
|
||||
Codes.INVALID_USERNAME
|
||||
)
|
||||
|
||||
if localpart[0] == '_':
|
||||
raise SynapseError(
|
||||
400,
|
||||
"User ID may not begin with _",
|
||||
Codes.INVALID_USERNAME
|
||||
)
|
||||
|
||||
user = UserID(localpart, self.hs.hostname)
|
||||
user_id = user.to_string()
|
||||
|
||||
@@ -93,7 +98,8 @@ class RegistrationHandler(BaseHandler):
|
||||
password=None,
|
||||
generate_token=True,
|
||||
guest_access_token=None,
|
||||
make_guest=False
|
||||
make_guest=False,
|
||||
admin=False,
|
||||
):
|
||||
"""Registers a new client on the server.
|
||||
|
||||
@@ -101,8 +107,13 @@ class RegistrationHandler(BaseHandler):
|
||||
localpart : The local part of the user ID to register. If None,
|
||||
one will be generated.
|
||||
password (str) : The password to assign to this user so they can
|
||||
login again. This can be None which means they cannot login again
|
||||
via a password (e.g. the user is an application service user).
|
||||
login again. This can be None which means they cannot login again
|
||||
via a password (e.g. the user is an application service user).
|
||||
generate_token (bool): Whether a new access token should be
|
||||
generated. Having this be True should be considered deprecated,
|
||||
since it offers no means of associating a device_id with the
|
||||
access_token. Instead you should call auth_handler.issue_access_token
|
||||
after registration.
|
||||
Returns:
|
||||
A tuple of (user_id, access_token).
|
||||
Raises:
|
||||
@@ -140,9 +151,12 @@ class RegistrationHandler(BaseHandler):
|
||||
password_hash=password_hash,
|
||||
was_guest=was_guest,
|
||||
make_guest=make_guest,
|
||||
create_profile_with_localpart=(
|
||||
# If the user was a guest then they already have a profile
|
||||
None if was_guest else user.localpart
|
||||
),
|
||||
admin=admin,
|
||||
)
|
||||
|
||||
yield registered_user(self.distributor, user)
|
||||
else:
|
||||
# autogen a sequential user ID
|
||||
attempts = 0
|
||||
@@ -160,7 +174,8 @@ class RegistrationHandler(BaseHandler):
|
||||
user_id=user_id,
|
||||
token=token,
|
||||
password_hash=password_hash,
|
||||
make_guest=make_guest
|
||||
make_guest=make_guest,
|
||||
create_profile_with_localpart=user.localpart,
|
||||
)
|
||||
except SynapseError:
|
||||
# if user id is taken, just generate another
|
||||
@@ -168,7 +183,6 @@ class RegistrationHandler(BaseHandler):
|
||||
user_id = None
|
||||
token = None
|
||||
attempts += 1
|
||||
yield registered_user(self.distributor, user)
|
||||
|
||||
# We used to generate default identicons here, but nowadays
|
||||
# we want clients to generate their own as part of their branding
|
||||
@@ -195,15 +209,13 @@ class RegistrationHandler(BaseHandler):
|
||||
user_id, allowed_appservice=service
|
||||
)
|
||||
|
||||
token = self.auth_handler().generate_access_token(user_id)
|
||||
yield self.store.register(
|
||||
user_id=user_id,
|
||||
token=token,
|
||||
password_hash="",
|
||||
appservice_id=service_id,
|
||||
create_profile_with_localpart=user.localpart,
|
||||
)
|
||||
yield registered_user(self.distributor, user)
|
||||
defer.returnValue((user_id, token))
|
||||
defer.returnValue(user_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_recaptcha(self, ip, private_key, challenge, response):
|
||||
@@ -248,9 +260,9 @@ class RegistrationHandler(BaseHandler):
|
||||
yield self.store.register(
|
||||
user_id=user_id,
|
||||
token=token,
|
||||
password_hash=None
|
||||
password_hash=None,
|
||||
create_profile_with_localpart=user.localpart,
|
||||
)
|
||||
yield registered_user(self.distributor, user)
|
||||
except Exception as e:
|
||||
yield self.store.add_access_token_to_user(user_id, token)
|
||||
# Ignore Registration errors
|
||||
@@ -359,7 +371,8 @@ class RegistrationHandler(BaseHandler):
|
||||
defer.returnValue(data)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_or_create_user(self, localpart, displayname, duration_seconds):
|
||||
def get_or_create_user(self, localpart, displayname, duration_in_ms,
|
||||
password_hash=None):
|
||||
"""Creates a new user if the user does not exist,
|
||||
else revokes all previous access tokens and generates a new one.
|
||||
|
||||
@@ -388,17 +401,16 @@ class RegistrationHandler(BaseHandler):
|
||||
|
||||
user = UserID(localpart, self.hs.hostname)
|
||||
user_id = user.to_string()
|
||||
auth_handler = self.hs.get_handlers().auth_handler
|
||||
token = auth_handler.generate_short_term_login_token(user_id, duration_seconds)
|
||||
token = self.auth_handler().generate_access_token(
|
||||
user_id, None, duration_in_ms)
|
||||
|
||||
if need_register:
|
||||
yield self.store.register(
|
||||
user_id=user_id,
|
||||
token=token,
|
||||
password_hash=None
|
||||
password_hash=password_hash,
|
||||
create_profile_with_localpart=user.localpart,
|
||||
)
|
||||
|
||||
yield registered_user(self.distributor, user)
|
||||
else:
|
||||
yield self.store.user_delete_access_tokens(user_id=user_id)
|
||||
yield self.store.add_access_token_to_user(user_id=user_id, token=token)
|
||||
@@ -406,8 +418,9 @@ class RegistrationHandler(BaseHandler):
|
||||
if displayname is not None:
|
||||
logger.info("setting user display name: %s -> %s", user_id, displayname)
|
||||
profile_handler = self.hs.get_handlers().profile_handler
|
||||
requester = synapse.types.create_requester(user)
|
||||
yield profile_handler.set_displayname(
|
||||
user, Requester(user, token, False), displayname
|
||||
user, requester, displayname
|
||||
)
|
||||
|
||||
defer.returnValue((user_id, token))
|
||||
|
||||
@@ -20,7 +20,7 @@ from ._base import BaseHandler
|
||||
|
||||
from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken
|
||||
from synapse.api.constants import (
|
||||
EventTypes, JoinRules, RoomCreationPreset,
|
||||
EventTypes, JoinRules, RoomCreationPreset, Membership,
|
||||
)
|
||||
from synapse.api.errors import AuthError, StoreError, SynapseError
|
||||
from synapse.util import stringutils
|
||||
@@ -345,8 +345,8 @@ class RoomCreationHandler(BaseHandler):
|
||||
class RoomListHandler(BaseHandler):
|
||||
def __init__(self, hs):
|
||||
super(RoomListHandler, self).__init__(hs)
|
||||
self.response_cache = ResponseCache()
|
||||
self.remote_list_request_cache = ResponseCache()
|
||||
self.response_cache = ResponseCache(hs)
|
||||
self.remote_list_request_cache = ResponseCache(hs)
|
||||
self.remote_list_cache = {}
|
||||
self.fetch_looping_call = hs.get_clock().looping_call(
|
||||
self.fetch_all_remote_lists, REMOTE_ROOM_LIST_POLL_INTERVAL
|
||||
@@ -367,14 +367,10 @@ class RoomListHandler(BaseHandler):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def handle_room(room_id):
|
||||
# We pull each bit of state out indvidually to avoid pulling the
|
||||
# full state into memory. Due to how the caching works this should
|
||||
# be fairly quick, even if not originally in the cache.
|
||||
def get_state(etype, state_key):
|
||||
return self.state_handler.get_current_state(room_id, etype, state_key)
|
||||
current_state = yield self.state_handler.get_current_state(room_id)
|
||||
|
||||
# Double check that this is actually a public room.
|
||||
join_rules_event = yield get_state(EventTypes.JoinRules, "")
|
||||
join_rules_event = current_state.get((EventTypes.JoinRules, ""))
|
||||
if join_rules_event:
|
||||
join_rule = join_rules_event.content.get("join_rule", None)
|
||||
if join_rule and join_rule != JoinRules.PUBLIC:
|
||||
@@ -382,47 +378,51 @@ class RoomListHandler(BaseHandler):
|
||||
|
||||
result = {"room_id": room_id}
|
||||
|
||||
joined_users = yield self.store.get_users_in_room(room_id)
|
||||
if len(joined_users) == 0:
|
||||
num_joined_users = len([
|
||||
1 for _, event in current_state.items()
|
||||
if event.type == EventTypes.Member
|
||||
and event.membership == Membership.JOIN
|
||||
])
|
||||
if num_joined_users == 0:
|
||||
return
|
||||
|
||||
result["num_joined_members"] = len(joined_users)
|
||||
result["num_joined_members"] = num_joined_users
|
||||
|
||||
aliases = yield self.store.get_aliases_for_room(room_id)
|
||||
if aliases:
|
||||
result["aliases"] = aliases
|
||||
|
||||
name_event = yield get_state(EventTypes.Name, "")
|
||||
name_event = yield current_state.get((EventTypes.Name, ""))
|
||||
if name_event:
|
||||
name = name_event.content.get("name", None)
|
||||
if name:
|
||||
result["name"] = name
|
||||
|
||||
topic_event = yield get_state(EventTypes.Topic, "")
|
||||
topic_event = current_state.get((EventTypes.Topic, ""))
|
||||
if topic_event:
|
||||
topic = topic_event.content.get("topic", None)
|
||||
if topic:
|
||||
result["topic"] = topic
|
||||
|
||||
canonical_event = yield get_state(EventTypes.CanonicalAlias, "")
|
||||
canonical_event = current_state.get((EventTypes.CanonicalAlias, ""))
|
||||
if canonical_event:
|
||||
canonical_alias = canonical_event.content.get("alias", None)
|
||||
if canonical_alias:
|
||||
result["canonical_alias"] = canonical_alias
|
||||
|
||||
visibility_event = yield get_state(EventTypes.RoomHistoryVisibility, "")
|
||||
visibility_event = current_state.get((EventTypes.RoomHistoryVisibility, ""))
|
||||
visibility = None
|
||||
if visibility_event:
|
||||
visibility = visibility_event.content.get("history_visibility", None)
|
||||
result["world_readable"] = visibility == "world_readable"
|
||||
|
||||
guest_event = yield get_state(EventTypes.GuestAccess, "")
|
||||
guest_event = current_state.get((EventTypes.GuestAccess, ""))
|
||||
guest = None
|
||||
if guest_event:
|
||||
guest = guest_event.content.get("guest_access", None)
|
||||
result["guest_can_join"] = guest == "can_join"
|
||||
|
||||
avatar_event = yield get_state("m.room.avatar", "")
|
||||
avatar_event = current_state.get(("m.room.avatar", ""))
|
||||
if avatar_event:
|
||||
avatar_url = avatar_event.content.get("url", None)
|
||||
if avatar_url:
|
||||
|
||||
@@ -14,24 +14,22 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import logging
|
||||
|
||||
from signedjson.key import decode_verify_key_bytes
|
||||
from signedjson.sign import verify_signed_json
|
||||
from twisted.internet import defer
|
||||
from unpaddedbase64 import decode_base64
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
from synapse.types import UserID, RoomID, Requester
|
||||
import synapse.types
|
||||
from synapse.api.constants import (
|
||||
EventTypes, Membership,
|
||||
)
|
||||
from synapse.api.errors import AuthError, SynapseError, Codes
|
||||
from synapse.types import UserID, RoomID
|
||||
from synapse.util.async import Linearizer
|
||||
from synapse.util.distributor import user_left_room, user_joined_room
|
||||
|
||||
from signedjson.sign import verify_signed_json
|
||||
from signedjson.key import decode_verify_key_bytes
|
||||
|
||||
from unpaddedbase64 import decode_base64
|
||||
|
||||
import logging
|
||||
from ._base import BaseHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -315,7 +313,7 @@ class RoomMemberHandler(BaseHandler):
|
||||
)
|
||||
assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
|
||||
else:
|
||||
requester = Requester(target_user, None, False)
|
||||
requester = synapse.types.create_requester(target_user)
|
||||
|
||||
message_handler = self.hs.get_handlers().message_handler
|
||||
prev_event = message_handler.deduplicate_state_event(event, context)
|
||||
|
||||
@@ -138,7 +138,7 @@ class SyncHandler(object):
|
||||
self.presence_handler = hs.get_presence_handler()
|
||||
self.event_sources = hs.get_event_sources()
|
||||
self.clock = hs.get_clock()
|
||||
self.response_cache = ResponseCache()
|
||||
self.response_cache = ResponseCache(hs)
|
||||
|
||||
def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
|
||||
full_state=False):
|
||||
|
||||
@@ -221,6 +221,9 @@ class TypingHandler(object):
|
||||
|
||||
def get_all_typing_updates(self, last_id, current_id):
|
||||
# TODO: Work out a way to do this without scanning the entire state.
|
||||
if last_id == current_id:
|
||||
return []
|
||||
|
||||
rows = []
|
||||
for room_id, serial in self._room_serials.items():
|
||||
if last_id < serial and serial <= current_id:
|
||||
|
||||
@@ -24,12 +24,13 @@ from synapse.http.endpoint import SpiderEndpoint
|
||||
|
||||
from canonicaljson import encode_canonical_json
|
||||
|
||||
from twisted.internet import defer, reactor, ssl, protocol
|
||||
from twisted.internet import defer, reactor, ssl, protocol, task
|
||||
from twisted.internet.endpoints import SSL4ClientEndpoint, TCP4ClientEndpoint
|
||||
from twisted.web.client import (
|
||||
BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent,
|
||||
readBody, FileBodyProducer, PartialDownloadError,
|
||||
readBody, PartialDownloadError,
|
||||
)
|
||||
from twisted.web.client import FileBodyProducer as TwistedFileBodyProducer
|
||||
from twisted.web.http import PotentialDataLoss
|
||||
from twisted.web.http_headers import Headers
|
||||
from twisted.web._newclient import ResponseDone
|
||||
@@ -468,3 +469,26 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory):
|
||||
|
||||
def creatorForNetloc(self, hostname, port):
|
||||
return self
|
||||
|
||||
|
||||
class FileBodyProducer(TwistedFileBodyProducer):
|
||||
"""Workaround for https://twistedmatrix.com/trac/ticket/8473
|
||||
|
||||
We override the pauseProducing and resumeProducing methods in twisted's
|
||||
FileBodyProducer so that they do not raise exceptions if the task has
|
||||
already completed.
|
||||
"""
|
||||
|
||||
def pauseProducing(self):
|
||||
try:
|
||||
super(FileBodyProducer, self).pauseProducing()
|
||||
except task.TaskDone:
|
||||
# task has already completed
|
||||
pass
|
||||
|
||||
def resumeProducing(self):
|
||||
try:
|
||||
super(FileBodyProducer, self).resumeProducing()
|
||||
except task.NotPaused:
|
||||
# task was not paused (probably because it had already completed)
|
||||
pass
|
||||
|
||||
@@ -205,6 +205,7 @@ class JsonResource(HttpServer, resource.Resource):
|
||||
|
||||
def register_paths(self, method, path_patterns, callback):
|
||||
for path_pattern in path_patterns:
|
||||
logger.debug("Registering for %s %s", method, path_pattern.pattern)
|
||||
self.path_regexs.setdefault(method, []).append(
|
||||
self._PathEntry(path_pattern, callback)
|
||||
)
|
||||
|
||||
@@ -27,7 +27,8 @@ import gc
|
||||
from twisted.internet import reactor
|
||||
|
||||
from .metric import (
|
||||
CounterMetric, CallbackMetric, DistributionMetric, CacheMetric
|
||||
CounterMetric, CallbackMetric, DistributionMetric, CacheMetric,
|
||||
MemoryUsageMetric,
|
||||
)
|
||||
|
||||
|
||||
@@ -66,6 +67,21 @@ class Metrics(object):
|
||||
return self._register(CacheMetric, *args, **kwargs)
|
||||
|
||||
|
||||
def register_memory_metrics(hs):
|
||||
try:
|
||||
import psutil
|
||||
process = psutil.Process()
|
||||
process.memory_info().rss
|
||||
except (ImportError, AttributeError):
|
||||
logger.warn(
|
||||
"psutil is not installed or incorrect version."
|
||||
" Disabling memory metrics."
|
||||
)
|
||||
return
|
||||
metric = MemoryUsageMetric(hs, psutil)
|
||||
all_metrics.append(metric)
|
||||
|
||||
|
||||
def get_metrics_for(pkg_name):
|
||||
""" Returns a Metrics instance for conveniently creating metrics
|
||||
namespaced with the given name prefix. """
|
||||
|
||||
@@ -153,3 +153,43 @@ class CacheMetric(object):
|
||||
"""%s:total{name="%s"} %d""" % (self.name, self.cache_name, total),
|
||||
"""%s:size{name="%s"} %d""" % (self.name, self.cache_name, size),
|
||||
]
|
||||
|
||||
|
||||
class MemoryUsageMetric(object):
|
||||
"""Keeps track of the current memory usage, using psutil.
|
||||
|
||||
The class will keep the current min/max/sum/counts of rss over the last
|
||||
WINDOW_SIZE_SEC, by polling UPDATE_HZ times per second
|
||||
"""
|
||||
|
||||
UPDATE_HZ = 2 # number of times to get memory per second
|
||||
WINDOW_SIZE_SEC = 30 # the size of the window in seconds
|
||||
|
||||
def __init__(self, hs, psutil):
|
||||
clock = hs.get_clock()
|
||||
self.memory_snapshots = []
|
||||
|
||||
self.process = psutil.Process()
|
||||
|
||||
clock.looping_call(self._update_curr_values, 1000 / self.UPDATE_HZ)
|
||||
|
||||
def _update_curr_values(self):
|
||||
max_size = self.UPDATE_HZ * self.WINDOW_SIZE_SEC
|
||||
self.memory_snapshots.append(self.process.memory_info().rss)
|
||||
self.memory_snapshots[:] = self.memory_snapshots[-max_size:]
|
||||
|
||||
def render(self):
|
||||
if not self.memory_snapshots:
|
||||
return []
|
||||
|
||||
max_rss = max(self.memory_snapshots)
|
||||
min_rss = min(self.memory_snapshots)
|
||||
sum_rss = sum(self.memory_snapshots)
|
||||
len_rss = len(self.memory_snapshots)
|
||||
|
||||
return [
|
||||
"process_psutil_rss:max %d" % max_rss,
|
||||
"process_psutil_rss:min %d" % min_rss,
|
||||
"process_psutil_rss:total %d" % sum_rss,
|
||||
"process_psutil_rss:count %d" % len_rss,
|
||||
]
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
|
||||
|
||||
import logging
|
||||
|
||||
@@ -92,7 +93,11 @@ class EmailPusher(object):
|
||||
|
||||
def on_stop(self):
|
||||
if self.timed_call:
|
||||
self.timed_call.cancel()
|
||||
try:
|
||||
self.timed_call.cancel()
|
||||
except (AlreadyCalled, AlreadyCancelled):
|
||||
pass
|
||||
self.timed_call = None
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
|
||||
@@ -140,9 +145,8 @@ class EmailPusher(object):
|
||||
being run.
|
||||
"""
|
||||
start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
|
||||
unprocessed = yield self.store.get_unread_push_actions_for_user_in_range(
|
||||
self.user_id, start, self.max_stream_ordering
|
||||
)
|
||||
fn = self.store.get_unread_push_actions_for_user_in_range_for_email
|
||||
unprocessed = yield fn(self.user_id, start, self.max_stream_ordering)
|
||||
|
||||
soonest_due_at = None
|
||||
|
||||
@@ -190,7 +194,10 @@ class EmailPusher(object):
|
||||
soonest_due_at = should_notify_at
|
||||
|
||||
if self.timed_call is not None:
|
||||
self.timed_call.cancel()
|
||||
try:
|
||||
self.timed_call.cancel()
|
||||
except (AlreadyCalled, AlreadyCancelled):
|
||||
pass
|
||||
self.timed_call = None
|
||||
|
||||
if soonest_due_at is not None:
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
from synapse.push import PusherConfigException
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
|
||||
|
||||
import logging
|
||||
import push_rule_evaluator
|
||||
@@ -38,6 +39,7 @@ class HttpPusher(object):
|
||||
self.hs = hs
|
||||
self.store = self.hs.get_datastore()
|
||||
self.clock = self.hs.get_clock()
|
||||
self.state_handler = self.hs.get_state_handler()
|
||||
self.user_id = pusherdict['user_name']
|
||||
self.app_id = pusherdict['app_id']
|
||||
self.app_display_name = pusherdict['app_display_name']
|
||||
@@ -108,7 +110,11 @@ class HttpPusher(object):
|
||||
|
||||
def on_stop(self):
|
||||
if self.timed_call:
|
||||
self.timed_call.cancel()
|
||||
try:
|
||||
self.timed_call.cancel()
|
||||
except (AlreadyCalled, AlreadyCancelled):
|
||||
pass
|
||||
self.timed_call = None
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _process(self):
|
||||
@@ -140,7 +146,8 @@ class HttpPusher(object):
|
||||
run once per pusher.
|
||||
"""
|
||||
|
||||
unprocessed = yield self.store.get_unread_push_actions_for_user_in_range(
|
||||
fn = self.store.get_unread_push_actions_for_user_in_range_for_http
|
||||
unprocessed = yield fn(
|
||||
self.user_id, self.last_stream_ordering, self.max_stream_ordering
|
||||
)
|
||||
|
||||
@@ -237,7 +244,9 @@ class HttpPusher(object):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _build_notification_dict(self, event, tweaks, badge):
|
||||
ctx = yield push_tools.get_context_for_event(self.hs.get_datastore(), event)
|
||||
ctx = yield push_tools.get_context_for_event(
|
||||
self.state_handler, event, self.user_id
|
||||
)
|
||||
|
||||
d = {
|
||||
'notification': {
|
||||
@@ -269,8 +278,8 @@ class HttpPusher(object):
|
||||
if 'content' in event:
|
||||
d['notification']['content'] = event.content
|
||||
|
||||
if len(ctx['aliases']):
|
||||
d['notification']['room_alias'] = ctx['aliases'][0]
|
||||
# We no longer send aliases separately, instead, we send the human
|
||||
# readable name of the room, which may be an alias.
|
||||
if 'sender_display_name' in ctx and len(ctx['sender_display_name']) > 0:
|
||||
d['notification']['sender_display_name'] = ctx['sender_display_name']
|
||||
if 'name' in ctx and len(ctx['name']) > 0:
|
||||
|
||||
@@ -273,16 +273,16 @@ class Mailer(object):
|
||||
|
||||
sender_state_event = room_state[("m.room.member", event.sender)]
|
||||
sender_name = name_from_member_event(sender_state_event)
|
||||
sender_avatar_url = None
|
||||
if "avatar_url" in sender_state_event.content:
|
||||
sender_avatar_url = sender_state_event.content["avatar_url"]
|
||||
sender_avatar_url = sender_state_event.content.get("avatar_url")
|
||||
|
||||
# 'hash' for deterministically picking default images: use
|
||||
# sender_hash % the number of default images to choose from
|
||||
sender_hash = string_ordinal_total(event.sender)
|
||||
|
||||
msgtype = event.content.get("msgtype")
|
||||
|
||||
ret = {
|
||||
"msgtype": event.content["msgtype"],
|
||||
"msgtype": msgtype,
|
||||
"is_historical": event.event_id != notif['event_id'],
|
||||
"id": event.event_id,
|
||||
"ts": event.origin_server_ts,
|
||||
@@ -291,9 +291,9 @@ class Mailer(object):
|
||||
"sender_hash": sender_hash,
|
||||
}
|
||||
|
||||
if event.content["msgtype"] == "m.text":
|
||||
if msgtype == "m.text":
|
||||
self.add_text_message_vars(ret, event)
|
||||
elif event.content["msgtype"] == "m.image":
|
||||
elif msgtype == "m.image":
|
||||
self.add_image_message_vars(ret, event)
|
||||
|
||||
if "body" in event.content:
|
||||
@@ -302,16 +302,17 @@ class Mailer(object):
|
||||
return ret
|
||||
|
||||
def add_text_message_vars(self, messagevars, event):
|
||||
if "format" in event.content:
|
||||
msgformat = event.content["format"]
|
||||
else:
|
||||
msgformat = None
|
||||
msgformat = event.content.get("format")
|
||||
|
||||
messagevars["format"] = msgformat
|
||||
|
||||
if msgformat == "org.matrix.custom.html":
|
||||
messagevars["body_text_html"] = safe_markup(event.content["formatted_body"])
|
||||
else:
|
||||
messagevars["body_text_html"] = safe_text(event.content["body"])
|
||||
formatted_body = event.content.get("formatted_body")
|
||||
body = event.content.get("body")
|
||||
|
||||
if msgformat == "org.matrix.custom.html" and formatted_body:
|
||||
messagevars["body_text_html"] = safe_markup(formatted_body)
|
||||
elif body:
|
||||
messagevars["body_text_html"] = safe_text(body)
|
||||
|
||||
return messagevars
|
||||
|
||||
|
||||
@@ -14,6 +14,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
from synapse.util.presentable_names import (
|
||||
calculate_room_name, name_from_member_event
|
||||
)
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@@ -45,24 +48,21 @@ def get_badge_count(store, user_id):
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_context_for_event(store, ev):
|
||||
name_aliases = yield store.get_room_name_and_aliases(
|
||||
ev.room_id
|
||||
)
|
||||
def get_context_for_event(state_handler, ev, user_id):
|
||||
ctx = {}
|
||||
|
||||
ctx = {'aliases': name_aliases[1]}
|
||||
if name_aliases[0] is not None:
|
||||
ctx['name'] = name_aliases[0]
|
||||
room_state = yield state_handler.get_current_state(ev.room_id)
|
||||
|
||||
their_member_events_for_room = yield store.get_current_state(
|
||||
room_id=ev.room_id,
|
||||
event_type='m.room.member',
|
||||
state_key=ev.user_id
|
||||
# we no longer bother setting room_alias, and make room_name the
|
||||
# human-readable name instead, be that m.room.name, an alias or
|
||||
# a list of people in the room
|
||||
name = calculate_room_name(
|
||||
room_state, user_id, fallback_to_single_member=False
|
||||
)
|
||||
for mev in their_member_events_for_room:
|
||||
if mev.content['membership'] == 'join' and 'displayname' in mev.content:
|
||||
dn = mev.content['displayname']
|
||||
if dn is not None:
|
||||
ctx['sender_display_name'] = dn
|
||||
if name:
|
||||
ctx['name'] = name
|
||||
|
||||
sender_state_event = room_state[("m.room.member", ev.sender)]
|
||||
ctx['sender_display_name'] = name_from_member_event(sender_state_event)
|
||||
|
||||
defer.returnValue(ctx)
|
||||
|
||||
@@ -48,6 +48,12 @@ CONDITIONAL_REQUIREMENTS = {
|
||||
"Jinja2>=2.8": ["Jinja2>=2.8"],
|
||||
"bleach>=1.4.2": ["bleach>=1.4.2"],
|
||||
},
|
||||
"ldap": {
|
||||
"ldap3>=1.0": ["ldap3>=1.0"],
|
||||
},
|
||||
"psutil": {
|
||||
"psutil>=2.0.0": ["psutil>=2.0.0"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
||||
23
synapse/replication/slave/storage/directory.py
Normal file
23
synapse/replication/slave/storage/directory.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import BaseSlavedStore
|
||||
from synapse.storage.directory import DirectoryStore
|
||||
|
||||
|
||||
class DirectoryStore(BaseSlavedStore):
|
||||
get_aliases_for_room = DirectoryStore.__dict__[
|
||||
"get_aliases_for_room"
|
||||
].orig
|
||||
@@ -18,7 +18,6 @@ from ._slaved_id_tracker import SlavedIdTracker
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.events import FrozenEvent
|
||||
from synapse.storage import DataStore
|
||||
from synapse.storage.room import RoomStore
|
||||
from synapse.storage.roommember import RoomMemberStore
|
||||
from synapse.storage.event_federation import EventFederationStore
|
||||
from synapse.storage.event_push_actions import EventPushActionsStore
|
||||
@@ -64,7 +63,6 @@ class SlavedEventStore(BaseSlavedStore):
|
||||
|
||||
# Cached functions can't be accessed through a class instance so we need
|
||||
# to reach inside the __dict__ to extract them.
|
||||
get_room_name_and_aliases = RoomStore.__dict__["get_room_name_and_aliases"]
|
||||
get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"]
|
||||
get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"]
|
||||
get_latest_event_ids_in_room = EventFederationStore.__dict__[
|
||||
@@ -95,8 +93,11 @@ class SlavedEventStore(BaseSlavedStore):
|
||||
StreamStore.__dict__["get_recent_event_ids_for_room"]
|
||||
)
|
||||
|
||||
get_unread_push_actions_for_user_in_range = (
|
||||
DataStore.get_unread_push_actions_for_user_in_range.__func__
|
||||
get_unread_push_actions_for_user_in_range_for_http = (
|
||||
DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__
|
||||
)
|
||||
get_unread_push_actions_for_user_in_range_for_email = (
|
||||
DataStore.get_unread_push_actions_for_user_in_range_for_email.__func__
|
||||
)
|
||||
get_push_action_users_in_range = (
|
||||
DataStore.get_push_action_users_in_range.__func__
|
||||
@@ -144,6 +145,15 @@ class SlavedEventStore(BaseSlavedStore):
|
||||
_get_events_around_txn = DataStore._get_events_around_txn.__func__
|
||||
_get_some_state_from_cache = DataStore._get_some_state_from_cache.__func__
|
||||
|
||||
get_backfill_events = DataStore.get_backfill_events.__func__
|
||||
_get_backfill_events = DataStore._get_backfill_events.__func__
|
||||
get_missing_events = DataStore.get_missing_events.__func__
|
||||
_get_missing_events = DataStore._get_missing_events.__func__
|
||||
|
||||
get_auth_chain = DataStore.get_auth_chain.__func__
|
||||
get_auth_chain_ids = DataStore.get_auth_chain_ids.__func__
|
||||
_get_auth_chain_ids_txn = DataStore._get_auth_chain_ids_txn.__func__
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedEventStore, self).stream_positions()
|
||||
result["events"] = self._stream_id_gen.get_current_token()
|
||||
@@ -202,7 +212,6 @@ class SlavedEventStore(BaseSlavedStore):
|
||||
self.get_rooms_for_user.invalidate_all()
|
||||
self.get_users_in_room.invalidate((event.room_id,))
|
||||
# self.get_joined_hosts_for_room.invalidate((event.room_id,))
|
||||
self.get_room_name_and_aliases.invalidate((event.room_id,))
|
||||
|
||||
self._invalidate_get_event_cache(event.event_id)
|
||||
|
||||
@@ -246,9 +255,3 @@ class SlavedEventStore(BaseSlavedStore):
|
||||
self._get_current_state_for_key.invalidate((
|
||||
event.room_id, event.type, event.state_key
|
||||
))
|
||||
|
||||
if event.type in [EventTypes.Name, EventTypes.Aliases]:
|
||||
self.get_room_name_and_aliases.invalidate(
|
||||
(event.room_id,)
|
||||
)
|
||||
pass
|
||||
|
||||
33
synapse/replication/slave/storage/keys.py
Normal file
33
synapse/replication/slave/storage/keys.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import BaseSlavedStore
|
||||
from synapse.storage import DataStore
|
||||
from synapse.storage.keys import KeyStore
|
||||
|
||||
|
||||
class SlavedKeyStore(BaseSlavedStore):
|
||||
_get_server_verify_key = KeyStore.__dict__[
|
||||
"_get_server_verify_key"
|
||||
]
|
||||
|
||||
get_server_verify_keys = DataStore.get_server_verify_keys.__func__
|
||||
store_server_verify_key = DataStore.store_server_verify_key.__func__
|
||||
|
||||
get_server_certificate = DataStore.get_server_certificate.__func__
|
||||
store_server_certificate = DataStore.store_server_certificate.__func__
|
||||
|
||||
get_server_keys_json = DataStore.get_server_keys_json.__func__
|
||||
store_server_keys_json = DataStore.store_server_keys_json.__func__
|
||||
21
synapse/replication/slave/storage/room.py
Normal file
21
synapse/replication/slave/storage/room.py
Normal file
@@ -0,0 +1,21 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import BaseSlavedStore
|
||||
from synapse.storage import DataStore
|
||||
|
||||
|
||||
class RoomStore(BaseSlavedStore):
|
||||
get_public_room_ids = DataStore.get_public_room_ids.__func__
|
||||
30
synapse/replication/slave/storage/transactions.py
Normal file
30
synapse/replication/slave/storage/transactions.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
from ._base import BaseSlavedStore
|
||||
from synapse.storage import DataStore
|
||||
from synapse.storage.transactions import TransactionStore
|
||||
|
||||
|
||||
class TransactionStore(BaseSlavedStore):
|
||||
get_destination_retry_timings = TransactionStore.__dict__[
|
||||
"get_destination_retry_timings"
|
||||
].orig
|
||||
_get_destination_retry_timings = DataStore._get_destination_retry_timings.__func__
|
||||
|
||||
# For now, don't record the destination rety timings
|
||||
def set_destination_retry_timings(*args, **kwargs):
|
||||
return defer.succeed(None)
|
||||
@@ -46,6 +46,7 @@ from synapse.rest.client.v2_alpha import (
|
||||
account_data,
|
||||
report_event,
|
||||
openid,
|
||||
devices,
|
||||
)
|
||||
|
||||
from synapse.http.server import JsonResource
|
||||
@@ -90,3 +91,4 @@ class ClientRestResource(JsonResource):
|
||||
account_data.register_servlets(hs, client_resource)
|
||||
report_event.register_servlets(hs, client_resource)
|
||||
openid.register_servlets(hs, client_resource)
|
||||
devices.register_servlets(hs, client_resource)
|
||||
|
||||
@@ -46,5 +46,82 @@ class WhoisRestServlet(ClientV1RestServlet):
|
||||
defer.returnValue((200, ret))
|
||||
|
||||
|
||||
class PurgeMediaCacheRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/admin/purge_media_cache")
|
||||
|
||||
def __init__(self, hs):
|
||||
self.media_repository = hs.get_media_repository()
|
||||
super(PurgeMediaCacheRestServlet, self).__init__(hs)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
is_admin = yield self.auth.is_server_admin(requester.user)
|
||||
|
||||
if not is_admin:
|
||||
raise AuthError(403, "You are not a server admin")
|
||||
|
||||
before_ts = request.args.get("before_ts", None)
|
||||
if not before_ts:
|
||||
raise SynapseError(400, "Missing 'before_ts' arg")
|
||||
|
||||
logger.info("before_ts: %r", before_ts[0])
|
||||
|
||||
try:
|
||||
before_ts = int(before_ts[0])
|
||||
except Exception:
|
||||
raise SynapseError(400, "Invalid 'before_ts' arg")
|
||||
|
||||
ret = yield self.media_repository.delete_old_remote_media(before_ts)
|
||||
|
||||
defer.returnValue((200, ret))
|
||||
|
||||
|
||||
class PurgeHistoryRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns(
|
||||
"/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, room_id, event_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
is_admin = yield self.auth.is_server_admin(requester.user)
|
||||
|
||||
if not is_admin:
|
||||
raise AuthError(403, "You are not a server admin")
|
||||
|
||||
yield self.handlers.message_handler.purge_history(room_id, event_id)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
class DeactivateAccountRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)")
|
||||
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
super(DeactivateAccountRestServlet, self).__init__(hs)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, target_user_id):
|
||||
UserID.from_string(target_user_id)
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
is_admin = yield self.auth.is_server_admin(requester.user)
|
||||
|
||||
if not is_admin:
|
||||
raise AuthError(403, "You are not a server admin")
|
||||
|
||||
# FIXME: Theoretically there is a race here wherein user resets password
|
||||
# using threepid.
|
||||
yield self.store.user_delete_access_tokens(target_user_id)
|
||||
yield self.store.user_delete_threepids(target_user_id)
|
||||
yield self.store.user_set_password_hash(target_user_id, None)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
WhoisRestServlet(hs).register(http_server)
|
||||
PurgeMediaCacheRestServlet(hs).register(http_server)
|
||||
DeactivateAccountRestServlet(hs).register(http_server)
|
||||
PurgeHistoryRestServlet(hs).register(http_server)
|
||||
|
||||
@@ -52,6 +52,10 @@ class ClientV1RestServlet(RestServlet):
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer):
|
||||
"""
|
||||
self.hs = hs
|
||||
self.handlers = hs.get_handlers()
|
||||
self.builder_factory = hs.get_event_builder_factory()
|
||||
|
||||
@@ -45,30 +45,27 @@ class EventStreamRestServlet(ClientV1RestServlet):
|
||||
raise SynapseError(400, "Guest users must specify room_id param")
|
||||
if "room_id" in request.args:
|
||||
room_id = request.args["room_id"][0]
|
||||
try:
|
||||
handler = self.handlers.event_stream_handler
|
||||
pagin_config = PaginationConfig.from_request(request)
|
||||
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
|
||||
if "timeout" in request.args:
|
||||
try:
|
||||
timeout = int(request.args["timeout"][0])
|
||||
except ValueError:
|
||||
raise SynapseError(400, "timeout must be in milliseconds.")
|
||||
|
||||
as_client_event = "raw" not in request.args
|
||||
handler = self.handlers.event_stream_handler
|
||||
pagin_config = PaginationConfig.from_request(request)
|
||||
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
|
||||
if "timeout" in request.args:
|
||||
try:
|
||||
timeout = int(request.args["timeout"][0])
|
||||
except ValueError:
|
||||
raise SynapseError(400, "timeout must be in milliseconds.")
|
||||
|
||||
chunk = yield handler.get_stream(
|
||||
requester.user.to_string(),
|
||||
pagin_config,
|
||||
timeout=timeout,
|
||||
as_client_event=as_client_event,
|
||||
affect_presence=(not is_guest),
|
||||
room_id=room_id,
|
||||
is_guest=is_guest,
|
||||
)
|
||||
except:
|
||||
logger.exception("Event stream failed")
|
||||
raise
|
||||
as_client_event = "raw" not in request.args
|
||||
|
||||
chunk = yield handler.get_stream(
|
||||
requester.user.to_string(),
|
||||
pagin_config,
|
||||
timeout=timeout,
|
||||
as_client_event=as_client_event,
|
||||
affect_presence=(not is_guest),
|
||||
room_id=room_id,
|
||||
is_guest=is_guest,
|
||||
)
|
||||
|
||||
defer.returnValue((200, chunk))
|
||||
|
||||
|
||||
@@ -59,6 +59,7 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||
self.servername = hs.config.server_name
|
||||
self.http_client = hs.get_simple_http_client()
|
||||
self.auth_handler = self.hs.get_auth_handler()
|
||||
self.device_handler = self.hs.get_device_handler()
|
||||
|
||||
def on_GET(self, request):
|
||||
flows = []
|
||||
@@ -145,15 +146,23 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||
).to_string()
|
||||
|
||||
auth_handler = self.auth_handler
|
||||
user_id, access_token, refresh_token = yield auth_handler.login_with_password(
|
||||
user_id = yield auth_handler.validate_password_login(
|
||||
user_id=user_id,
|
||||
password=login_submission["password"])
|
||||
|
||||
password=login_submission["password"],
|
||||
)
|
||||
device_id = yield self._register_device(user_id, login_submission)
|
||||
access_token, refresh_token = (
|
||||
yield auth_handler.get_login_tuple_for_user_id(
|
||||
user_id, device_id,
|
||||
login_submission.get("initial_device_display_name")
|
||||
)
|
||||
)
|
||||
result = {
|
||||
"user_id": user_id, # may have changed
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"home_server": self.hs.hostname,
|
||||
"device_id": device_id,
|
||||
}
|
||||
|
||||
defer.returnValue((200, result))
|
||||
@@ -165,14 +174,19 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||
user_id = (
|
||||
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
|
||||
)
|
||||
user_id, access_token, refresh_token = (
|
||||
yield auth_handler.get_login_tuple_for_user_id(user_id)
|
||||
device_id = yield self._register_device(user_id, login_submission)
|
||||
access_token, refresh_token = (
|
||||
yield auth_handler.get_login_tuple_for_user_id(
|
||||
user_id, device_id,
|
||||
login_submission.get("initial_device_display_name")
|
||||
)
|
||||
)
|
||||
result = {
|
||||
"user_id": user_id, # may have changed
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"home_server": self.hs.hostname,
|
||||
"device_id": device_id,
|
||||
}
|
||||
|
||||
defer.returnValue((200, result))
|
||||
@@ -196,13 +210,15 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||
|
||||
user_id = UserID.create(user, self.hs.hostname).to_string()
|
||||
auth_handler = self.auth_handler
|
||||
user_exists = yield auth_handler.does_user_exist(user_id)
|
||||
if user_exists:
|
||||
user_id, access_token, refresh_token = (
|
||||
yield auth_handler.get_login_tuple_for_user_id(user_id)
|
||||
registered_user_id = yield auth_handler.check_user_exists(user_id)
|
||||
if registered_user_id:
|
||||
access_token, refresh_token = (
|
||||
yield auth_handler.get_login_tuple_for_user_id(
|
||||
registered_user_id
|
||||
)
|
||||
)
|
||||
result = {
|
||||
"user_id": user_id, # may have changed
|
||||
"user_id": registered_user_id, # may have changed
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"home_server": self.hs.hostname,
|
||||
@@ -245,18 +261,27 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||
|
||||
user_id = UserID.create(user, self.hs.hostname).to_string()
|
||||
auth_handler = self.auth_handler
|
||||
user_exists = yield auth_handler.does_user_exist(user_id)
|
||||
if user_exists:
|
||||
user_id, access_token, refresh_token = (
|
||||
yield auth_handler.get_login_tuple_for_user_id(user_id)
|
||||
registered_user_id = yield auth_handler.check_user_exists(user_id)
|
||||
if registered_user_id:
|
||||
device_id = yield self._register_device(
|
||||
registered_user_id, login_submission
|
||||
)
|
||||
access_token, refresh_token = (
|
||||
yield auth_handler.get_login_tuple_for_user_id(
|
||||
registered_user_id, device_id,
|
||||
login_submission.get("initial_device_display_name")
|
||||
)
|
||||
)
|
||||
result = {
|
||||
"user_id": user_id, # may have changed
|
||||
"user_id": registered_user_id,
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"home_server": self.hs.hostname,
|
||||
}
|
||||
else:
|
||||
# TODO: we should probably check that the register isn't going
|
||||
# to fonx/change our user_id before registering the device
|
||||
device_id = yield self._register_device(user_id, login_submission)
|
||||
user_id, access_token = (
|
||||
yield self.handlers.registration_handler.register(localpart=user)
|
||||
)
|
||||
@@ -295,6 +320,26 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||
|
||||
return (user, attributes)
|
||||
|
||||
def _register_device(self, user_id, login_submission):
|
||||
"""Register a device for a user.
|
||||
|
||||
This is called after the user's credentials have been validated, but
|
||||
before the access token has been issued.
|
||||
|
||||
Args:
|
||||
(str) user_id: full canonical @user:id
|
||||
(object) login_submission: dictionary supplied to /login call, from
|
||||
which we pull device_id and initial_device_name
|
||||
Returns:
|
||||
defer.Deferred: (str) device_id
|
||||
"""
|
||||
device_id = login_submission.get("device_id")
|
||||
initial_display_name = login_submission.get(
|
||||
"initial_device_display_name")
|
||||
return self.device_handler.check_device_registered(
|
||||
user_id, device_id, initial_display_name
|
||||
)
|
||||
|
||||
|
||||
class SAML2RestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/login/saml2", releases=())
|
||||
@@ -414,13 +459,13 @@ class CasTicketServlet(ClientV1RestServlet):
|
||||
|
||||
user_id = UserID.create(user, self.hs.hostname).to_string()
|
||||
auth_handler = self.auth_handler
|
||||
user_exists = yield auth_handler.does_user_exist(user_id)
|
||||
if not user_exists:
|
||||
user_id, _ = (
|
||||
registered_user_id = yield auth_handler.check_user_exists(user_id)
|
||||
if not registered_user_id:
|
||||
registered_user_id, _ = (
|
||||
yield self.handlers.registration_handler.register(localpart=user)
|
||||
)
|
||||
|
||||
login_token = auth_handler.generate_short_term_login_token(user_id)
|
||||
login_token = auth_handler.generate_short_term_login_token(registered_user_id)
|
||||
redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
|
||||
login_token)
|
||||
request.redirect(redirect_url)
|
||||
|
||||
@@ -52,6 +52,10 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/register$", releases=(), include_in_unstable=False)
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer): server
|
||||
"""
|
||||
super(RegisterRestServlet, self).__init__(hs)
|
||||
# sessions are stored as:
|
||||
# self.sessions = {
|
||||
@@ -60,6 +64,7 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||
# TODO: persistent storage
|
||||
self.sessions = {}
|
||||
self.enable_registration = hs.config.enable_registration
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
|
||||
def on_GET(self, request):
|
||||
if self.hs.config.enable_registration_captcha:
|
||||
@@ -299,9 +304,10 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||
user_localpart = register_json["user"].encode("utf-8")
|
||||
|
||||
handler = self.handlers.registration_handler
|
||||
(user_id, token) = yield handler.appservice_register(
|
||||
user_id = yield handler.appservice_register(
|
||||
user_localpart, as_token
|
||||
)
|
||||
token = yield self.auth_handler.issue_access_token(user_id)
|
||||
self._remove_session(session)
|
||||
defer.returnValue({
|
||||
"user_id": user_id,
|
||||
@@ -324,6 +330,14 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||
raise SynapseError(400, "Shared secret registration is not enabled")
|
||||
|
||||
user = register_json["user"].encode("utf-8")
|
||||
password = register_json["password"].encode("utf-8")
|
||||
admin = register_json.get("admin", None)
|
||||
|
||||
# Its important to check as we use null bytes as HMAC field separators
|
||||
if "\x00" in user:
|
||||
raise SynapseError(400, "Invalid user")
|
||||
if "\x00" in password:
|
||||
raise SynapseError(400, "Invalid password")
|
||||
|
||||
# str() because otherwise hmac complains that 'unicode' does not
|
||||
# have the buffer interface
|
||||
@@ -331,17 +345,21 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||
|
||||
want_mac = hmac.new(
|
||||
key=self.hs.config.registration_shared_secret,
|
||||
msg=user,
|
||||
digestmod=sha1,
|
||||
).hexdigest()
|
||||
|
||||
password = register_json["password"].encode("utf-8")
|
||||
)
|
||||
want_mac.update(user)
|
||||
want_mac.update("\x00")
|
||||
want_mac.update(password)
|
||||
want_mac.update("\x00")
|
||||
want_mac.update("admin" if admin else "notadmin")
|
||||
want_mac = want_mac.hexdigest()
|
||||
|
||||
if compare_digest(want_mac, got_mac):
|
||||
handler = self.handlers.registration_handler
|
||||
user_id, token = yield handler.register(
|
||||
localpart=user,
|
||||
password=password,
|
||||
admin=bool(admin),
|
||||
)
|
||||
self._remove_session(session)
|
||||
defer.returnValue({
|
||||
@@ -410,12 +428,15 @@ class CreateUserRestServlet(ClientV1RestServlet):
|
||||
raise SynapseError(400, "Failed to parse 'duration_seconds'")
|
||||
if duration_seconds > self.direct_user_creation_max_duration:
|
||||
duration_seconds = self.direct_user_creation_max_duration
|
||||
password_hash = user_json["password_hash"].encode("utf-8") \
|
||||
if user_json.get("password_hash") else None
|
||||
|
||||
handler = self.handlers.registration_handler
|
||||
user_id, token = yield handler.get_or_create_user(
|
||||
localpart=localpart,
|
||||
displayname=displayname,
|
||||
duration_seconds=duration_seconds
|
||||
duration_in_ms=(duration_seconds * 1000),
|
||||
password_hash=password_hash
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
|
||||
@@ -20,12 +20,14 @@ from .base import ClientV1RestServlet, client_path_patterns
|
||||
from synapse.api.errors import SynapseError, Codes, AuthError
|
||||
from synapse.streams.config import PaginationConfig
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.filtering import Filter
|
||||
from synapse.types import UserID, RoomID, RoomAlias
|
||||
from synapse.events.utils import serialize_event
|
||||
from synapse.http.servlet import parse_json_object_from_request
|
||||
|
||||
import logging
|
||||
import urllib
|
||||
import ujson as json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -72,8 +74,6 @@ class RoomCreateRestServlet(ClientV1RestServlet):
|
||||
|
||||
def get_room_config(self, request):
|
||||
user_supplied_config = parse_json_object_from_request(request)
|
||||
# default visibility
|
||||
user_supplied_config.setdefault("visibility", "public")
|
||||
return user_supplied_config
|
||||
|
||||
def on_OPTIONS(self, request):
|
||||
@@ -279,6 +279,13 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
try:
|
||||
yield self.auth.get_user_by_req(request)
|
||||
except AuthError:
|
||||
# This endpoint isn't authed, but its useful to know who's hitting
|
||||
# it if they *do* supply an access token
|
||||
pass
|
||||
|
||||
handler = self.hs.get_room_list_handler()
|
||||
data = yield handler.get_aggregated_public_room_list()
|
||||
|
||||
@@ -322,12 +329,19 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
|
||||
request, default_limit=10,
|
||||
)
|
||||
as_client_event = "raw" not in request.args
|
||||
filter_bytes = request.args.get("filter", None)
|
||||
if filter_bytes:
|
||||
filter_json = urllib.unquote(filter_bytes[-1]).decode("UTF-8")
|
||||
event_filter = Filter(json.loads(filter_json))
|
||||
else:
|
||||
event_filter = None
|
||||
handler = self.handlers.message_handler
|
||||
msgs = yield handler.get_messages(
|
||||
room_id=room_id,
|
||||
requester=requester,
|
||||
pagin_config=pagination_config,
|
||||
as_client_event=as_client_event
|
||||
as_client_event=as_client_event,
|
||||
event_filter=event_filter,
|
||||
)
|
||||
|
||||
defer.returnValue((200, msgs))
|
||||
|
||||
@@ -25,7 +25,9 @@ import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def client_v2_patterns(path_regex, releases=(0,)):
|
||||
def client_v2_patterns(path_regex, releases=(0,),
|
||||
v2_alpha=True,
|
||||
unstable=True):
|
||||
"""Creates a regex compiled client path with the correct client path
|
||||
prefix.
|
||||
|
||||
@@ -35,9 +37,12 @@ def client_v2_patterns(path_regex, releases=(0,)):
|
||||
Returns:
|
||||
SRE_Pattern
|
||||
"""
|
||||
patterns = [re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)]
|
||||
unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable")
|
||||
patterns.append(re.compile("^" + unstable_prefix + path_regex))
|
||||
patterns = []
|
||||
if v2_alpha:
|
||||
patterns.append(re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex))
|
||||
if unstable:
|
||||
unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable")
|
||||
patterns.append(re.compile("^" + unstable_prefix + path_regex))
|
||||
for release in releases:
|
||||
new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release)
|
||||
patterns.append(re.compile("^" + new_prefix + path_regex))
|
||||
|
||||
@@ -28,8 +28,40 @@ import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PasswordRequestTokenRestServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/account/password/email/requestToken$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(PasswordRequestTokenRestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.identity_handler = hs.get_handlers().identity_handler
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
required = ['id_server', 'client_secret', 'email', 'send_attempt']
|
||||
absent = []
|
||||
for k in required:
|
||||
if k not in body:
|
||||
absent.append(k)
|
||||
|
||||
if absent:
|
||||
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
|
||||
|
||||
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||
'email', body['email']
|
||||
)
|
||||
|
||||
if existingUid is None:
|
||||
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
|
||||
|
||||
ret = yield self.identity_handler.requestEmailToken(**body)
|
||||
defer.returnValue((200, ret))
|
||||
|
||||
|
||||
class PasswordRestServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/account/password")
|
||||
PATTERNS = client_v2_patterns("/account/password$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(PasswordRestServlet, self).__init__()
|
||||
@@ -89,8 +121,83 @@ class PasswordRestServlet(RestServlet):
|
||||
return 200, {}
|
||||
|
||||
|
||||
class DeactivateAccountRestServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/account/deactivate$")
|
||||
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
super(DeactivateAccountRestServlet, self).__init__()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
authed, result, params, _ = yield self.auth_handler.check_auth([
|
||||
[LoginType.PASSWORD],
|
||||
], body, self.hs.get_ip_from_request(request))
|
||||
|
||||
if not authed:
|
||||
defer.returnValue((401, result))
|
||||
|
||||
user_id = None
|
||||
requester = None
|
||||
|
||||
if LoginType.PASSWORD in result:
|
||||
# if using password, they should also be logged in
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
if user_id != result[LoginType.PASSWORD]:
|
||||
raise LoginError(400, "", Codes.UNKNOWN)
|
||||
else:
|
||||
logger.error("Auth succeeded but no known type!", result.keys())
|
||||
raise SynapseError(500, "", Codes.UNKNOWN)
|
||||
|
||||
# FIXME: Theoretically there is a race here wherein user resets password
|
||||
# using threepid.
|
||||
yield self.store.user_delete_access_tokens(user_id)
|
||||
yield self.store.user_delete_threepids(user_id)
|
||||
yield self.store.user_set_password_hash(user_id, None)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
class ThreepidRequestTokenRestServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$")
|
||||
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
super(ThreepidRequestTokenRestServlet, self).__init__()
|
||||
self.identity_handler = hs.get_handlers().identity_handler
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
required = ['id_server', 'client_secret', 'email', 'send_attempt']
|
||||
absent = []
|
||||
for k in required:
|
||||
if k not in body:
|
||||
absent.append(k)
|
||||
|
||||
if absent:
|
||||
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
|
||||
|
||||
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||
'email', body['email']
|
||||
)
|
||||
|
||||
if existingUid is not None:
|
||||
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
|
||||
|
||||
ret = yield self.identity_handler.requestEmailToken(**body)
|
||||
defer.returnValue((200, ret))
|
||||
|
||||
|
||||
class ThreepidRestServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/account/3pid")
|
||||
PATTERNS = client_v2_patterns("/account/3pid$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ThreepidRestServlet, self).__init__()
|
||||
@@ -157,5 +264,8 @@ class ThreepidRestServlet(RestServlet):
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
PasswordRequestTokenRestServlet(hs).register(http_server)
|
||||
PasswordRestServlet(hs).register(http_server)
|
||||
DeactivateAccountRestServlet(hs).register(http_server)
|
||||
ThreepidRequestTokenRestServlet(hs).register(http_server)
|
||||
ThreepidRestServlet(hs).register(http_server)
|
||||
|
||||
100
synapse/rest/client/v2_alpha/devices.py
Normal file
100
synapse/rest/client/v2_alpha/devices.py
Normal file
@@ -0,0 +1,100 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.http import servlet
|
||||
from ._base import client_v2_patterns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DevicesRestServlet(servlet.RestServlet):
|
||||
PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False)
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer): server
|
||||
"""
|
||||
super(DevicesRestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
devices = yield self.device_handler.get_devices_by_user(
|
||||
requester.user.to_string()
|
||||
)
|
||||
defer.returnValue((200, {"devices": devices}))
|
||||
|
||||
|
||||
class DeviceRestServlet(servlet.RestServlet):
|
||||
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
|
||||
releases=[], v2_alpha=False)
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer): server
|
||||
"""
|
||||
super(DeviceRestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, device_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
device = yield self.device_handler.get_device(
|
||||
requester.user.to_string(),
|
||||
device_id,
|
||||
)
|
||||
defer.returnValue((200, device))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_DELETE(self, request, device_id):
|
||||
# XXX: it's not completely obvious we want to expose this endpoint.
|
||||
# It allows the client to delete access tokens, which feels like a
|
||||
# thing which merits extra auth. But if we want to do the interactive-
|
||||
# auth dance, we should really make it possible to delete more than one
|
||||
# device at a time.
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
yield self.device_handler.delete_device(
|
||||
requester.user.to_string(),
|
||||
device_id,
|
||||
)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, device_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
body = servlet.parse_json_object_from_request(request)
|
||||
yield self.device_handler.update_device(
|
||||
requester.user.to_string(),
|
||||
device_id,
|
||||
body
|
||||
)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
DevicesRestServlet(hs).register(http_server)
|
||||
DeviceRestServlet(hs).register(http_server)
|
||||
@@ -13,24 +13,25 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
import simplejson as json
|
||||
from canonicaljson import encode_canonical_json
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse.api.errors
|
||||
import synapse.server
|
||||
import synapse.types
|
||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||
from synapse.types import UserID
|
||||
|
||||
from canonicaljson import encode_canonical_json
|
||||
|
||||
from ._base import client_v2_patterns
|
||||
|
||||
import logging
|
||||
import simplejson as json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KeyUploadServlet(RestServlet):
|
||||
"""
|
||||
POST /keys/upload/<device_id> HTTP/1.1
|
||||
POST /keys/upload HTTP/1.1
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
@@ -53,23 +54,45 @@ class KeyUploadServlet(RestServlet):
|
||||
},
|
||||
}
|
||||
"""
|
||||
PATTERNS = client_v2_patterns("/keys/upload/(?P<device_id>[^/]*)", releases=())
|
||||
PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$",
|
||||
releases=())
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer): server
|
||||
"""
|
||||
super(KeyUploadServlet, self).__init__()
|
||||
self.store = hs.get_datastore()
|
||||
self.clock = hs.get_clock()
|
||||
self.auth = hs.get_auth()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, device_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
user_id = requester.user.to_string()
|
||||
# TODO: Check that the device_id matches that in the authentication
|
||||
# or derive the device_id from the authentication instead.
|
||||
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
if device_id is not None:
|
||||
# passing the device_id here is deprecated; however, we allow it
|
||||
# for now for compatibility with older clients.
|
||||
if (requester.device_id is not None and
|
||||
device_id != requester.device_id):
|
||||
logger.warning("Client uploading keys for a different device "
|
||||
"(logged in as %s, uploading for %s)",
|
||||
requester.device_id, device_id)
|
||||
else:
|
||||
device_id = requester.device_id
|
||||
|
||||
if device_id is None:
|
||||
raise synapse.api.errors.SynapseError(
|
||||
400,
|
||||
"To upload keys, you must pass device_id when authenticating"
|
||||
)
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
|
||||
# TODO: Validate the JSON to make sure it has the right keys.
|
||||
@@ -102,13 +125,12 @@ class KeyUploadServlet(RestServlet):
|
||||
user_id, device_id, time_now, key_list
|
||||
)
|
||||
|
||||
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
|
||||
defer.returnValue((200, {"one_time_key_counts": result}))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, device_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
# the device should have been registered already, but it may have been
|
||||
# deleted due to a race with a DELETE request. Or we may be using an
|
||||
# old access_token without an associated device_id. Either way, we
|
||||
# need to double-check the device is registered to avoid ending up with
|
||||
# keys without a corresponding device.
|
||||
self.device_handler.check_device_registered(user_id, device_id)
|
||||
|
||||
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
|
||||
defer.returnValue((200, {"one_time_key_counts": result}))
|
||||
@@ -162,17 +184,19 @@ class KeyQueryServlet(RestServlet):
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer):
|
||||
"""
|
||||
super(KeyQueryServlet, self).__init__()
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
self.federation = hs.get_replication_layer()
|
||||
self.is_mine = hs.is_mine
|
||||
self.e2e_keys_handler = hs.get_e2e_keys_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, user_id, device_id):
|
||||
yield self.auth.get_user_by_req(request)
|
||||
body = parse_json_object_from_request(request)
|
||||
result = yield self.handle_request(body)
|
||||
result = yield self.e2e_keys_handler.query_devices(body)
|
||||
defer.returnValue(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@@ -181,45 +205,11 @@ class KeyQueryServlet(RestServlet):
|
||||
auth_user_id = requester.user.to_string()
|
||||
user_id = user_id if user_id else auth_user_id
|
||||
device_ids = [device_id] if device_id else []
|
||||
result = yield self.handle_request(
|
||||
result = yield self.e2e_keys_handler.query_devices(
|
||||
{"device_keys": {user_id: device_ids}}
|
||||
)
|
||||
defer.returnValue(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def handle_request(self, body):
|
||||
local_query = []
|
||||
remote_queries = {}
|
||||
for user_id, device_ids in body.get("device_keys", {}).items():
|
||||
user = UserID.from_string(user_id)
|
||||
if self.is_mine(user):
|
||||
if not device_ids:
|
||||
local_query.append((user_id, None))
|
||||
else:
|
||||
for device_id in device_ids:
|
||||
local_query.append((user_id, device_id))
|
||||
else:
|
||||
remote_queries.setdefault(user.domain, {})[user_id] = list(
|
||||
device_ids
|
||||
)
|
||||
results = yield self.store.get_e2e_device_keys(local_query)
|
||||
|
||||
json_result = {}
|
||||
for user_id, device_keys in results.items():
|
||||
for device_id, json_bytes in device_keys.items():
|
||||
json_result.setdefault(user_id, {})[device_id] = json.loads(
|
||||
json_bytes
|
||||
)
|
||||
|
||||
for destination, device_keys in remote_queries.items():
|
||||
remote_result = yield self.federation.query_client_keys(
|
||||
destination, {"device_keys": device_keys}
|
||||
)
|
||||
for user_id, keys in remote_result["device_keys"].items():
|
||||
if user_id in device_keys:
|
||||
json_result[user_id] = keys
|
||||
defer.returnValue((200, {"device_keys": json_result}))
|
||||
|
||||
|
||||
class OneTimeKeyServlet(RestServlet):
|
||||
"""
|
||||
|
||||
@@ -41,17 +41,59 @@ else:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RegisterRestServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/register")
|
||||
class RegisterRequestTokenRestServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/register/email/requestToken$")
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer): server
|
||||
"""
|
||||
super(RegisterRequestTokenRestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.identity_handler = hs.get_handlers().identity_handler
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
required = ['id_server', 'client_secret', 'email', 'send_attempt']
|
||||
absent = []
|
||||
for k in required:
|
||||
if k not in body:
|
||||
absent.append(k)
|
||||
|
||||
if len(absent) > 0:
|
||||
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
|
||||
|
||||
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||
'email', body['email']
|
||||
)
|
||||
|
||||
if existingUid is not None:
|
||||
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
|
||||
|
||||
ret = yield self.identity_handler.requestEmailToken(**body)
|
||||
defer.returnValue((200, ret))
|
||||
|
||||
|
||||
class RegisterRestServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/register$")
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer): server
|
||||
"""
|
||||
super(RegisterRestServlet, self).__init__()
|
||||
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
self.registration_handler = hs.get_handlers().registration_handler
|
||||
self.identity_handler = hs.get_handlers().identity_handler
|
||||
self.device_handler = hs.get_device_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
@@ -70,10 +112,6 @@ class RegisterRestServlet(RestServlet):
|
||||
"Do not understand membership kind: %s" % (kind,)
|
||||
)
|
||||
|
||||
if '/register/email/requestToken' in request.path:
|
||||
ret = yield self.onEmailTokenRequest(request)
|
||||
defer.returnValue(ret)
|
||||
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
# we do basic sanity checks here because the auth layer will store these
|
||||
@@ -104,11 +142,12 @@ class RegisterRestServlet(RestServlet):
|
||||
# Set the desired user according to the AS API (which uses the
|
||||
# 'user' key not 'username'). Since this is a new addition, we'll
|
||||
# fallback to 'username' if they gave one.
|
||||
if isinstance(body.get("user"), basestring):
|
||||
desired_username = body["user"]
|
||||
result = yield self._do_appservice_registration(
|
||||
desired_username, request.args["access_token"][0]
|
||||
)
|
||||
desired_username = body.get("user", desired_username)
|
||||
|
||||
if isinstance(desired_username, basestring):
|
||||
result = yield self._do_appservice_registration(
|
||||
desired_username, request.args["access_token"][0], body
|
||||
)
|
||||
defer.returnValue((200, result)) # we throw for non 200 responses
|
||||
return
|
||||
|
||||
@@ -117,7 +156,7 @@ class RegisterRestServlet(RestServlet):
|
||||
# FIXME: Should we really be determining if this is shared secret
|
||||
# auth based purely on the 'mac' key?
|
||||
result = yield self._do_shared_secret_registration(
|
||||
desired_username, desired_password, body["mac"]
|
||||
desired_username, desired_password, body
|
||||
)
|
||||
defer.returnValue((200, result)) # we throw for non 200 responses
|
||||
return
|
||||
@@ -157,12 +196,12 @@ class RegisterRestServlet(RestServlet):
|
||||
[LoginType.EMAIL_IDENTITY]
|
||||
]
|
||||
|
||||
authed, result, params, session_id = yield self.auth_handler.check_auth(
|
||||
authed, auth_result, params, session_id = yield self.auth_handler.check_auth(
|
||||
flows, body, self.hs.get_ip_from_request(request)
|
||||
)
|
||||
|
||||
if not authed:
|
||||
defer.returnValue((401, result))
|
||||
defer.returnValue((401, auth_result))
|
||||
return
|
||||
|
||||
if registered_user_id is not None:
|
||||
@@ -170,106 +209,58 @@ class RegisterRestServlet(RestServlet):
|
||||
"Already registered user ID %r for this session",
|
||||
registered_user_id
|
||||
)
|
||||
access_token = yield self.auth_handler.issue_access_token(registered_user_id)
|
||||
refresh_token = yield self.auth_handler.issue_refresh_token(
|
||||
registered_user_id
|
||||
# don't re-register the email address
|
||||
add_email = False
|
||||
else:
|
||||
# NB: This may be from the auth handler and NOT from the POST
|
||||
if 'password' not in params:
|
||||
raise SynapseError(400, "Missing password.",
|
||||
Codes.MISSING_PARAM)
|
||||
|
||||
desired_username = params.get("username", None)
|
||||
new_password = params.get("password", None)
|
||||
guest_access_token = params.get("guest_access_token", None)
|
||||
|
||||
(registered_user_id, _) = yield self.registration_handler.register(
|
||||
localpart=desired_username,
|
||||
password=new_password,
|
||||
guest_access_token=guest_access_token,
|
||||
generate_token=False,
|
||||
)
|
||||
defer.returnValue((200, {
|
||||
"user_id": registered_user_id,
|
||||
"access_token": access_token,
|
||||
"home_server": self.hs.hostname,
|
||||
"refresh_token": refresh_token,
|
||||
}))
|
||||
|
||||
# NB: This may be from the auth handler and NOT from the POST
|
||||
if 'password' not in params:
|
||||
raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM)
|
||||
# remember that we've now registered that user account, and with
|
||||
# what user ID (since the user may not have specified)
|
||||
self.auth_handler.set_session_data(
|
||||
session_id, "registered_user_id", registered_user_id
|
||||
)
|
||||
|
||||
desired_username = params.get("username", None)
|
||||
new_password = params.get("password", None)
|
||||
guest_access_token = params.get("guest_access_token", None)
|
||||
add_email = True
|
||||
|
||||
(user_id, token) = yield self.registration_handler.register(
|
||||
localpart=desired_username,
|
||||
password=new_password,
|
||||
guest_access_token=guest_access_token,
|
||||
return_dict = yield self._create_registration_details(
|
||||
registered_user_id, params
|
||||
)
|
||||
|
||||
# remember that we've now registered that user account, and with what
|
||||
# user ID (since the user may not have specified)
|
||||
self.auth_handler.set_session_data(
|
||||
session_id, "registered_user_id", user_id
|
||||
)
|
||||
if add_email and auth_result and LoginType.EMAIL_IDENTITY in auth_result:
|
||||
threepid = auth_result[LoginType.EMAIL_IDENTITY]
|
||||
yield self._register_email_threepid(
|
||||
registered_user_id, threepid, return_dict["access_token"],
|
||||
params.get("bind_email")
|
||||
)
|
||||
|
||||
if result and LoginType.EMAIL_IDENTITY in result:
|
||||
threepid = result[LoginType.EMAIL_IDENTITY]
|
||||
|
||||
for reqd in ['medium', 'address', 'validated_at']:
|
||||
if reqd not in threepid:
|
||||
logger.info("Can't add incomplete 3pid")
|
||||
else:
|
||||
yield self.auth_handler.add_threepid(
|
||||
user_id,
|
||||
threepid['medium'],
|
||||
threepid['address'],
|
||||
threepid['validated_at'],
|
||||
)
|
||||
|
||||
# And we add an email pusher for them by default, but only
|
||||
# if email notifications are enabled (so people don't start
|
||||
# getting mail spam where they weren't before if email
|
||||
# notifs are set up on a home server)
|
||||
if (
|
||||
self.hs.config.email_enable_notifs and
|
||||
self.hs.config.email_notif_for_new_users
|
||||
):
|
||||
# Pull the ID of the access token back out of the db
|
||||
# It would really make more sense for this to be passed
|
||||
# up when the access token is saved, but that's quite an
|
||||
# invasive change I'd rather do separately.
|
||||
user_tuple = yield self.store.get_user_by_access_token(
|
||||
token
|
||||
)
|
||||
|
||||
yield self.hs.get_pusherpool().add_pusher(
|
||||
user_id=user_id,
|
||||
access_token=user_tuple["token_id"],
|
||||
kind="email",
|
||||
app_id="m.email",
|
||||
app_display_name="Email Notifications",
|
||||
device_display_name=threepid["address"],
|
||||
pushkey=threepid["address"],
|
||||
lang=None, # We don't know a user's language here
|
||||
data={},
|
||||
)
|
||||
|
||||
if 'bind_email' in params and params['bind_email']:
|
||||
logger.info("bind_email specified: binding")
|
||||
|
||||
emailThreepid = result[LoginType.EMAIL_IDENTITY]
|
||||
threepid_creds = emailThreepid['threepid_creds']
|
||||
logger.debug("Binding emails %s to %s" % (
|
||||
emailThreepid, user_id
|
||||
))
|
||||
yield self.identity_handler.bind_threepid(threepid_creds, user_id)
|
||||
else:
|
||||
logger.info("bind_email not specified: not binding email")
|
||||
|
||||
result = yield self._create_registration_details(user_id, token)
|
||||
defer.returnValue((200, result))
|
||||
defer.returnValue((200, return_dict))
|
||||
|
||||
def on_OPTIONS(self, _):
|
||||
return 200, {}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _do_appservice_registration(self, username, as_token):
|
||||
(user_id, token) = yield self.registration_handler.appservice_register(
|
||||
def _do_appservice_registration(self, username, as_token, body):
|
||||
user_id = yield self.registration_handler.appservice_register(
|
||||
username, as_token
|
||||
)
|
||||
defer.returnValue((yield self._create_registration_details(user_id, token)))
|
||||
defer.returnValue((yield self._create_registration_details(user_id, body)))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _do_shared_secret_registration(self, username, password, mac):
|
||||
def _do_shared_secret_registration(self, username, password, body):
|
||||
if not self.hs.config.registration_shared_secret:
|
||||
raise SynapseError(400, "Shared secret registration is not enabled")
|
||||
|
||||
@@ -277,7 +268,7 @@ class RegisterRestServlet(RestServlet):
|
||||
|
||||
# str() because otherwise hmac complains that 'unicode' does not
|
||||
# have the buffer interface
|
||||
got_mac = str(mac)
|
||||
got_mac = str(body["mac"])
|
||||
|
||||
want_mac = hmac.new(
|
||||
key=self.hs.config.registration_shared_secret,
|
||||
@@ -290,43 +281,132 @@ class RegisterRestServlet(RestServlet):
|
||||
403, "HMAC incorrect",
|
||||
)
|
||||
|
||||
(user_id, token) = yield self.registration_handler.register(
|
||||
localpart=username, password=password
|
||||
(user_id, _) = yield self.registration_handler.register(
|
||||
localpart=username, password=password, generate_token=False,
|
||||
)
|
||||
defer.returnValue((yield self._create_registration_details(user_id, token)))
|
||||
|
||||
result = yield self._create_registration_details(user_id, body)
|
||||
defer.returnValue(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _create_registration_details(self, user_id, token):
|
||||
refresh_token = yield self.auth_handler.issue_refresh_token(user_id)
|
||||
def _register_email_threepid(self, user_id, threepid, token, bind_email):
|
||||
"""Add an email address as a 3pid identifier
|
||||
|
||||
Also adds an email pusher for the email address, if configured in the
|
||||
HS config
|
||||
|
||||
Also optionally binds emails to the given user_id on the identity server
|
||||
|
||||
Args:
|
||||
user_id (str): id of user
|
||||
threepid (object): m.login.email.identity auth response
|
||||
token (str): access_token for the user
|
||||
bind_email (bool): true if the client requested the email to be
|
||||
bound at the identity server
|
||||
Returns:
|
||||
defer.Deferred:
|
||||
"""
|
||||
reqd = ('medium', 'address', 'validated_at')
|
||||
if any(x not in threepid for x in reqd):
|
||||
logger.info("Can't add incomplete 3pid")
|
||||
defer.returnValue()
|
||||
|
||||
yield self.auth_handler.add_threepid(
|
||||
user_id,
|
||||
threepid['medium'],
|
||||
threepid['address'],
|
||||
threepid['validated_at'],
|
||||
)
|
||||
|
||||
# And we add an email pusher for them by default, but only
|
||||
# if email notifications are enabled (so people don't start
|
||||
# getting mail spam where they weren't before if email
|
||||
# notifs are set up on a home server)
|
||||
if (self.hs.config.email_enable_notifs and
|
||||
self.hs.config.email_notif_for_new_users):
|
||||
# Pull the ID of the access token back out of the db
|
||||
# It would really make more sense for this to be passed
|
||||
# up when the access token is saved, but that's quite an
|
||||
# invasive change I'd rather do separately.
|
||||
user_tuple = yield self.store.get_user_by_access_token(
|
||||
token
|
||||
)
|
||||
token_id = user_tuple["token_id"]
|
||||
|
||||
yield self.hs.get_pusherpool().add_pusher(
|
||||
user_id=user_id,
|
||||
access_token=token_id,
|
||||
kind="email",
|
||||
app_id="m.email",
|
||||
app_display_name="Email Notifications",
|
||||
device_display_name=threepid["address"],
|
||||
pushkey=threepid["address"],
|
||||
lang=None, # We don't know a user's language here
|
||||
data={},
|
||||
)
|
||||
|
||||
if bind_email:
|
||||
logger.info("bind_email specified: binding")
|
||||
logger.debug("Binding emails %s to %s" % (
|
||||
threepid, user_id
|
||||
))
|
||||
yield self.identity_handler.bind_threepid(
|
||||
threepid['threepid_creds'], user_id
|
||||
)
|
||||
else:
|
||||
logger.info("bind_email not specified: not binding email")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _create_registration_details(self, user_id, params):
|
||||
"""Complete registration of newly-registered user
|
||||
|
||||
Allocates device_id if one was not given; also creates access_token
|
||||
and refresh_token.
|
||||
|
||||
Args:
|
||||
(str) user_id: full canonical @user:id
|
||||
(object) params: registration parameters, from which we pull
|
||||
device_id and initial_device_name
|
||||
Returns:
|
||||
defer.Deferred: (object) dictionary for response from /register
|
||||
"""
|
||||
device_id = yield self._register_device(user_id, params)
|
||||
|
||||
access_token, refresh_token = (
|
||||
yield self.auth_handler.get_login_tuple_for_user_id(
|
||||
user_id, device_id=device_id,
|
||||
initial_display_name=params.get("initial_device_display_name")
|
||||
)
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
"user_id": user_id,
|
||||
"access_token": token,
|
||||
"access_token": access_token,
|
||||
"home_server": self.hs.hostname,
|
||||
"refresh_token": refresh_token,
|
||||
"device_id": device_id,
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def onEmailTokenRequest(self, request):
|
||||
body = parse_json_object_from_request(request)
|
||||
def _register_device(self, user_id, params):
|
||||
"""Register a device for a user.
|
||||
|
||||
required = ['id_server', 'client_secret', 'email', 'send_attempt']
|
||||
absent = []
|
||||
for k in required:
|
||||
if k not in body:
|
||||
absent.append(k)
|
||||
This is called after the user's credentials have been validated, but
|
||||
before the access token has been issued.
|
||||
|
||||
if len(absent) > 0:
|
||||
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
|
||||
|
||||
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||
'email', body['email']
|
||||
Args:
|
||||
(str) user_id: full canonical @user:id
|
||||
(object) params: registration parameters, from which we pull
|
||||
device_id and initial_device_name
|
||||
Returns:
|
||||
defer.Deferred: (str) device_id
|
||||
"""
|
||||
# register the user's device
|
||||
device_id = params.get("device_id")
|
||||
initial_display_name = params.get("initial_device_display_name")
|
||||
device_id = self.device_handler.check_device_registered(
|
||||
user_id, device_id, initial_display_name
|
||||
)
|
||||
|
||||
if existingUid is not None:
|
||||
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
|
||||
|
||||
ret = yield self.identity_handler.requestEmailToken(**body)
|
||||
defer.returnValue((200, ret))
|
||||
return device_id
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _do_guest_registration(self):
|
||||
@@ -336,7 +416,11 @@ class RegisterRestServlet(RestServlet):
|
||||
generate_token=False,
|
||||
make_guest=True
|
||||
)
|
||||
access_token = self.auth_handler.generate_access_token(user_id, ["guest = true"])
|
||||
access_token = self.auth_handler.generate_access_token(
|
||||
user_id, ["guest = true"]
|
||||
)
|
||||
# XXX the "guest" caveat is not copied by /tokenrefresh. That's ok
|
||||
# so long as we don't return a refresh_token here.
|
||||
defer.returnValue((200, {
|
||||
"user_id": user_id,
|
||||
"access_token": access_token,
|
||||
@@ -345,4 +429,5 @@ class RegisterRestServlet(RestServlet):
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
RegisterRequestTokenRestServlet(hs).register(http_server)
|
||||
RegisterRestServlet(hs).register(http_server)
|
||||
|
||||
@@ -39,9 +39,13 @@ class TokenRefreshRestServlet(RestServlet):
|
||||
try:
|
||||
old_refresh_token = body["refresh_token"]
|
||||
auth_handler = self.hs.get_auth_handler()
|
||||
(user_id, new_refresh_token) = yield self.store.exchange_refresh_token(
|
||||
old_refresh_token, auth_handler.generate_refresh_token)
|
||||
new_access_token = yield auth_handler.issue_access_token(user_id)
|
||||
refresh_result = yield self.store.exchange_refresh_token(
|
||||
old_refresh_token, auth_handler.generate_refresh_token
|
||||
)
|
||||
(user_id, new_refresh_token, device_id) = refresh_result
|
||||
new_access_token = yield auth_handler.issue_access_token(
|
||||
user_id, device_id
|
||||
)
|
||||
defer.returnValue((200, {
|
||||
"access_token": new_access_token,
|
||||
"refresh_token": new_refresh_token,
|
||||
|
||||
@@ -26,7 +26,11 @@ class VersionsRestServlet(RestServlet):
|
||||
|
||||
def on_GET(self, request):
|
||||
return (200, {
|
||||
"versions": ["r0.0.1"]
|
||||
"versions": [
|
||||
"r0.0.1",
|
||||
"r0.1.0",
|
||||
"r0.2.0",
|
||||
]
|
||||
})
|
||||
|
||||
|
||||
|
||||
@@ -15,14 +15,12 @@
|
||||
|
||||
from synapse.http.server import respond_with_json_bytes, finish_request
|
||||
|
||||
from synapse.util.stringutils import random_string
|
||||
from synapse.api.errors import (
|
||||
cs_exception, SynapseError, CodeMessageException, Codes, cs_error
|
||||
Codes, cs_error
|
||||
)
|
||||
|
||||
from twisted.protocols.basic import FileSender
|
||||
from twisted.web import server, resource
|
||||
from twisted.internet import defer
|
||||
|
||||
import base64
|
||||
import simplejson as json
|
||||
@@ -50,64 +48,10 @@ class ContentRepoResource(resource.Resource):
|
||||
"""
|
||||
isLeaf = True
|
||||
|
||||
def __init__(self, hs, directory, auth, external_addr):
|
||||
def __init__(self, hs, directory):
|
||||
resource.Resource.__init__(self)
|
||||
self.hs = hs
|
||||
self.directory = directory
|
||||
self.auth = auth
|
||||
self.external_addr = external_addr.rstrip('/')
|
||||
self.max_upload_size = hs.config.max_upload_size
|
||||
|
||||
if not os.path.isdir(self.directory):
|
||||
os.mkdir(self.directory)
|
||||
logger.info("ContentRepoResource : Created %s directory.",
|
||||
self.directory)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def map_request_to_name(self, request):
|
||||
# auth the user
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
# namespace all file uploads on the user
|
||||
prefix = base64.urlsafe_b64encode(
|
||||
requester.user.to_string()
|
||||
).replace('=', '')
|
||||
|
||||
# use a random string for the main portion
|
||||
main_part = random_string(24)
|
||||
|
||||
# suffix with a file extension if we can make one. This is nice to
|
||||
# provide a hint to clients on the file information. We will also reuse
|
||||
# this info to spit back the content type to the client.
|
||||
suffix = ""
|
||||
if request.requestHeaders.hasHeader("Content-Type"):
|
||||
content_type = request.requestHeaders.getRawHeaders(
|
||||
"Content-Type")[0]
|
||||
suffix = "." + base64.urlsafe_b64encode(content_type)
|
||||
if (content_type.split("/")[0].lower() in
|
||||
["image", "video", "audio"]):
|
||||
file_ext = content_type.split("/")[-1]
|
||||
# be a little paranoid and only allow a-z
|
||||
file_ext = re.sub("[^a-z]", "", file_ext)
|
||||
suffix += "." + file_ext
|
||||
|
||||
file_name = prefix + main_part + suffix
|
||||
file_path = os.path.join(self.directory, file_name)
|
||||
logger.info("User %s is uploading a file to path %s",
|
||||
request.user.user_id.to_string(),
|
||||
file_path)
|
||||
|
||||
# keep trying to make a non-clashing file, with a sensible max attempts
|
||||
attempts = 0
|
||||
while os.path.exists(file_path):
|
||||
main_part = random_string(24)
|
||||
file_name = prefix + main_part + suffix
|
||||
file_path = os.path.join(self.directory, file_name)
|
||||
attempts += 1
|
||||
if attempts > 25: # really? Really?
|
||||
raise SynapseError(500, "Unable to create file.")
|
||||
|
||||
defer.returnValue(file_path)
|
||||
|
||||
def render_GET(self, request):
|
||||
# no auth here on purpose, to allow anyone to view, even across home
|
||||
@@ -155,58 +99,6 @@ class ContentRepoResource(resource.Resource):
|
||||
|
||||
return server.NOT_DONE_YET
|
||||
|
||||
def render_POST(self, request):
|
||||
self._async_render(request)
|
||||
return server.NOT_DONE_YET
|
||||
|
||||
def render_OPTIONS(self, request):
|
||||
respond_with_json_bytes(request, 200, {}, send_cors=True)
|
||||
return server.NOT_DONE_YET
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _async_render(self, request):
|
||||
try:
|
||||
# TODO: The checks here are a bit late. The content will have
|
||||
# already been uploaded to a tmp file at this point
|
||||
content_length = request.getHeader("Content-Length")
|
||||
if content_length is None:
|
||||
raise SynapseError(
|
||||
msg="Request must specify a Content-Length", code=400
|
||||
)
|
||||
if int(content_length) > self.max_upload_size:
|
||||
raise SynapseError(
|
||||
msg="Upload request body is too large",
|
||||
code=413,
|
||||
)
|
||||
|
||||
fname = yield self.map_request_to_name(request)
|
||||
|
||||
# TODO I have a suspicious feeling this is just going to block
|
||||
with open(fname, "wb") as f:
|
||||
f.write(request.content.read())
|
||||
|
||||
# FIXME (erikj): These should use constants.
|
||||
file_name = os.path.basename(fname)
|
||||
# FIXME: we can't assume what the repo's public mounted path is
|
||||
# ...plus self-signed SSL won't work to remote clients anyway
|
||||
# ...and we can't assume that it's SSL anyway, as we might want to
|
||||
# serve it via the non-SSL listener...
|
||||
url = "%s/_matrix/content/%s" % (
|
||||
self.external_addr, file_name
|
||||
)
|
||||
|
||||
respond_with_json_bytes(request, 200,
|
||||
json.dumps({"content_token": url}),
|
||||
send_cors=True)
|
||||
|
||||
except CodeMessageException as e:
|
||||
logger.exception(e)
|
||||
respond_with_json_bytes(request, e.code,
|
||||
json.dumps(cs_exception(e)))
|
||||
except Exception as e:
|
||||
logger.error("Failed to store file: %s" % e)
|
||||
respond_with_json_bytes(
|
||||
request,
|
||||
500,
|
||||
json.dumps({"error": "Internal server error"}),
|
||||
send_cors=True)
|
||||
|
||||
@@ -65,3 +65,9 @@ class MediaFilePaths(object):
|
||||
file_id[0:2], file_id[2:4], file_id[4:],
|
||||
file_name
|
||||
)
|
||||
|
||||
def remote_media_thumbnail_dir(self, server_name, file_id):
|
||||
return os.path.join(
|
||||
self.base_path, "remote_thumbnail", server_name,
|
||||
file_id[0:2], file_id[2:4], file_id[4:],
|
||||
)
|
||||
|
||||
@@ -26,14 +26,17 @@ from .thumbnailer import Thumbnailer
|
||||
|
||||
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
||||
from synapse.util.stringutils import random_string
|
||||
from synapse.api.errors import SynapseError
|
||||
|
||||
from twisted.internet import defer, threads
|
||||
|
||||
from synapse.util.async import ObservableDeferred
|
||||
from synapse.util.async import Linearizer
|
||||
from synapse.util.stringutils import is_ascii
|
||||
from synapse.util.logcontext import preserve_context_over_fn
|
||||
|
||||
import os
|
||||
import errno
|
||||
import shutil
|
||||
|
||||
import cgi
|
||||
import logging
|
||||
@@ -42,8 +45,11 @@ import urlparse
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
UPDATE_RECENTLY_ACCESSED_REMOTES_TS = 60 * 1000
|
||||
|
||||
|
||||
class MediaRepository(object):
|
||||
def __init__(self, hs, filepaths):
|
||||
def __init__(self, hs):
|
||||
self.auth = hs.get_auth()
|
||||
self.client = MatrixFederationHttpClient(hs)
|
||||
self.clock = hs.get_clock()
|
||||
@@ -51,11 +57,28 @@ class MediaRepository(object):
|
||||
self.store = hs.get_datastore()
|
||||
self.max_upload_size = hs.config.max_upload_size
|
||||
self.max_image_pixels = hs.config.max_image_pixels
|
||||
self.filepaths = filepaths
|
||||
self.downloads = {}
|
||||
self.filepaths = MediaFilePaths(hs.config.media_store_path)
|
||||
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
|
||||
self.thumbnail_requirements = hs.config.thumbnail_requirements
|
||||
|
||||
self.remote_media_linearizer = Linearizer()
|
||||
|
||||
self.recently_accessed_remotes = set()
|
||||
|
||||
self.clock.looping_call(
|
||||
self._update_recently_accessed_remotes,
|
||||
UPDATE_RECENTLY_ACCESSED_REMOTES_TS
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _update_recently_accessed_remotes(self):
|
||||
media = self.recently_accessed_remotes
|
||||
self.recently_accessed_remotes = set()
|
||||
|
||||
yield self.store.update_cached_last_access_time(
|
||||
media, self.clock.time_msec()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _makedirs(filepath):
|
||||
dirname = os.path.dirname(filepath)
|
||||
@@ -92,22 +115,12 @@ class MediaRepository(object):
|
||||
|
||||
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_remote_media(self, server_name, media_id):
|
||||
key = (server_name, media_id)
|
||||
download = self.downloads.get(key)
|
||||
if download is None:
|
||||
download = self._get_remote_media_impl(server_name, media_id)
|
||||
download = ObservableDeferred(
|
||||
download,
|
||||
consumeErrors=True
|
||||
)
|
||||
self.downloads[key] = download
|
||||
|
||||
@download.addBoth
|
||||
def callback(media_info):
|
||||
del self.downloads[key]
|
||||
return media_info
|
||||
return download.observe()
|
||||
with (yield self.remote_media_linearizer.queue(key)):
|
||||
media_info = yield self._get_remote_media_impl(server_name, media_id)
|
||||
defer.returnValue(media_info)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_remote_media_impl(self, server_name, media_id):
|
||||
@@ -118,6 +131,11 @@ class MediaRepository(object):
|
||||
media_info = yield self._download_remote_file(
|
||||
server_name, media_id
|
||||
)
|
||||
else:
|
||||
self.recently_accessed_remotes.add((server_name, media_id))
|
||||
yield self.store.update_cached_last_access_time(
|
||||
[(server_name, media_id)], self.clock.time_msec()
|
||||
)
|
||||
defer.returnValue(media_info)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@@ -134,10 +152,15 @@ class MediaRepository(object):
|
||||
request_path = "/".join((
|
||||
"/_matrix/media/v1/download", server_name, media_id,
|
||||
))
|
||||
length, headers = yield self.client.get_file(
|
||||
server_name, request_path, output_stream=f,
|
||||
max_size=self.max_upload_size,
|
||||
)
|
||||
try:
|
||||
length, headers = yield self.client.get_file(
|
||||
server_name, request_path, output_stream=f,
|
||||
max_size=self.max_upload_size,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warn("Failed to fetch remoted media %r", e)
|
||||
raise SynapseError(502, "Failed to fetch remoted media")
|
||||
|
||||
media_type = headers["Content-Type"][0]
|
||||
time_now_ms = self.clock.time_msec()
|
||||
|
||||
@@ -410,6 +433,41 @@ class MediaRepository(object):
|
||||
"height": m_height,
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_old_remote_media(self, before_ts):
|
||||
old_media = yield self.store.get_remote_media_before(before_ts)
|
||||
|
||||
deleted = 0
|
||||
|
||||
for media in old_media:
|
||||
origin = media["media_origin"]
|
||||
media_id = media["media_id"]
|
||||
file_id = media["filesystem_id"]
|
||||
key = (origin, media_id)
|
||||
|
||||
logger.info("Deleting: %r", key)
|
||||
|
||||
with (yield self.remote_media_linearizer.queue(key)):
|
||||
full_path = self.filepaths.remote_media_filepath(origin, file_id)
|
||||
try:
|
||||
os.remove(full_path)
|
||||
except OSError as e:
|
||||
logger.warn("Failed to remove file: %r", full_path)
|
||||
if e.errno == errno.ENOENT:
|
||||
pass
|
||||
else:
|
||||
continue
|
||||
|
||||
thumbnail_dir = self.filepaths.remote_media_thumbnail_dir(
|
||||
origin, file_id
|
||||
)
|
||||
shutil.rmtree(thumbnail_dir, ignore_errors=True)
|
||||
|
||||
yield self.store.delete_remote_media(origin, media_id)
|
||||
deleted += 1
|
||||
|
||||
defer.returnValue({"deleted": deleted})
|
||||
|
||||
|
||||
class MediaRepositoryResource(Resource):
|
||||
"""File uploading and downloading.
|
||||
@@ -458,9 +516,8 @@ class MediaRepositoryResource(Resource):
|
||||
|
||||
def __init__(self, hs):
|
||||
Resource.__init__(self)
|
||||
filepaths = MediaFilePaths(hs.config.media_store_path)
|
||||
|
||||
media_repo = MediaRepository(hs, filepaths)
|
||||
media_repo = hs.get_media_repository()
|
||||
|
||||
self.putChild("upload", UploadResource(hs, media_repo))
|
||||
self.putChild("download", DownloadResource(hs, media_repo))
|
||||
|
||||
@@ -29,6 +29,8 @@ from synapse.http.server import (
|
||||
from synapse.util.async import ObservableDeferred
|
||||
from synapse.util.stringutils import is_ascii
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
import os
|
||||
import re
|
||||
import fnmatch
|
||||
@@ -252,7 +254,8 @@ class PreviewUrlResource(Resource):
|
||||
|
||||
og = {}
|
||||
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
|
||||
og[tag.attrib['property']] = tag.attrib['content']
|
||||
if 'content' in tag.attrib:
|
||||
og[tag.attrib['property']] = tag.attrib['content']
|
||||
|
||||
# TODO: grab article: meta tags too, e.g.:
|
||||
|
||||
@@ -279,7 +282,7 @@ class PreviewUrlResource(Resource):
|
||||
# TODO: consider inlined CSS styles as well as width & height attribs
|
||||
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
|
||||
images = sorted(images, key=lambda i: (
|
||||
-1 * int(i.attrib['width']) * int(i.attrib['height'])
|
||||
-1 * float(i.attrib['width']) * float(i.attrib['height'])
|
||||
))
|
||||
if not images:
|
||||
images = tree.xpath("//img[@src]")
|
||||
@@ -287,9 +290,9 @@ class PreviewUrlResource(Resource):
|
||||
og['og:image'] = images[0].attrib['src']
|
||||
|
||||
# pre-cache the image for posterity
|
||||
# FIXME: it might be cleaner to use the same flow as the main /preview_url request
|
||||
# itself and benefit from the same caching etc. But for now we just rely on the
|
||||
# caching on the master request to speed things up.
|
||||
# FIXME: it might be cleaner to use the same flow as the main /preview_url
|
||||
# request itself and benefit from the same caching etc. But for now we
|
||||
# just rely on the caching on the master request to speed things up.
|
||||
if 'og:image' in og and og['og:image']:
|
||||
image_info = yield self._download_url(
|
||||
self._rebase_url(og['og:image'], media_info['uri']), requester.user
|
||||
@@ -328,20 +331,24 @@ class PreviewUrlResource(Resource):
|
||||
# ...or if they are within a <script/> or <style/> tag.
|
||||
# This is a very very very coarse approximation to a plain text
|
||||
# render of the page.
|
||||
text_nodes = tree.xpath("//text()[not(ancestor::header | ancestor::nav | "
|
||||
"ancestor::aside | ancestor::footer | "
|
||||
"ancestor::script | ancestor::style)]" +
|
||||
"[ancestor::body]")
|
||||
text = ''
|
||||
for text_node in text_nodes:
|
||||
if len(text) < 500:
|
||||
text += text_node + ' '
|
||||
else:
|
||||
break
|
||||
text = re.sub(r'[\t ]+', ' ', text)
|
||||
text = re.sub(r'[\t \r\n]*[\r\n]+', '\n', text)
|
||||
text = text.strip()[:500]
|
||||
og['og:description'] = text if text else None
|
||||
|
||||
# We don't just use XPATH here as that is slow on some machines.
|
||||
|
||||
# We clone `tree` as we modify it.
|
||||
cloned_tree = deepcopy(tree.find("body"))
|
||||
|
||||
TAGS_TO_REMOVE = ("header", "nav", "aside", "footer", "script", "style",)
|
||||
for el in cloned_tree.iter(TAGS_TO_REMOVE):
|
||||
el.getparent().remove(el)
|
||||
|
||||
# Split all the text nodes into paragraphs (by splitting on new
|
||||
# lines)
|
||||
text_nodes = (
|
||||
re.sub(r'\s+', '\n', el.text).strip()
|
||||
for el in cloned_tree.iter()
|
||||
if el.text and isinstance(el.tag, basestring) # Removes comments
|
||||
)
|
||||
og['og:description'] = summarize_paragraphs(text_nodes)
|
||||
|
||||
# TODO: delete the url downloads to stop diskfilling,
|
||||
# as we only ever cared about its OG
|
||||
@@ -449,3 +456,56 @@ class PreviewUrlResource(Resource):
|
||||
content_type.startswith("application/xhtml")
|
||||
):
|
||||
return True
|
||||
|
||||
|
||||
def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
|
||||
# Try to get a summary of between 200 and 500 words, respecting
|
||||
# first paragraph and then word boundaries.
|
||||
# TODO: Respect sentences?
|
||||
|
||||
description = ''
|
||||
|
||||
# Keep adding paragraphs until we get to the MIN_SIZE.
|
||||
for text_node in text_nodes:
|
||||
if len(description) < min_size:
|
||||
text_node = re.sub(r'[\t \r\n]+', ' ', text_node)
|
||||
description += text_node + '\n\n'
|
||||
else:
|
||||
break
|
||||
|
||||
description = description.strip()
|
||||
description = re.sub(r'[\t ]+', ' ', description)
|
||||
description = re.sub(r'[\t \r\n]*[\r\n]+', '\n\n', description)
|
||||
|
||||
# If the concatenation of paragraphs to get above MIN_SIZE
|
||||
# took us over MAX_SIZE, then we need to truncate mid paragraph
|
||||
if len(description) > max_size:
|
||||
new_desc = ""
|
||||
|
||||
# This splits the paragraph into words, but keeping the
|
||||
# (preceeding) whitespace intact so we can easily concat
|
||||
# words back together.
|
||||
for match in re.finditer("\s*\S+", description):
|
||||
word = match.group()
|
||||
|
||||
# Keep adding words while the total length is less than
|
||||
# MAX_SIZE.
|
||||
if len(word) + len(new_desc) < max_size:
|
||||
new_desc += word
|
||||
else:
|
||||
# At this point the next word *will* take us over
|
||||
# MAX_SIZE, but we also want to ensure that its not
|
||||
# a huge word. If it is add it anyway and we'll
|
||||
# truncate later.
|
||||
if len(new_desc) < min_size:
|
||||
new_desc += word
|
||||
break
|
||||
|
||||
# Double check that we're not over the limit
|
||||
if len(new_desc) > max_size:
|
||||
new_desc = new_desc[:max_size]
|
||||
|
||||
# We always add an ellipsis because at the very least
|
||||
# we chopped mid paragraph.
|
||||
description = new_desc.strip() + "…"
|
||||
return description if description else None
|
||||
|
||||
@@ -19,37 +19,38 @@
|
||||
# partial one for unit test mocking.
|
||||
|
||||
# Imports required for the default HomeServer() implementation
|
||||
from twisted.web.client import BrowserLikePolicyForHTTPS
|
||||
from twisted.enterprise import adbapi
|
||||
|
||||
from synapse.appservice.scheduler import ApplicationServiceScheduler
|
||||
from synapse.appservice.api import ApplicationServiceApi
|
||||
from synapse.federation import initialize_http_replication
|
||||
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
|
||||
from synapse.notifier import Notifier
|
||||
from synapse.api.auth import Auth
|
||||
from synapse.handlers import Handlers
|
||||
from synapse.handlers.presence import PresenceHandler
|
||||
from synapse.handlers.sync import SyncHandler
|
||||
from synapse.handlers.typing import TypingHandler
|
||||
from synapse.handlers.room import RoomListHandler
|
||||
from synapse.handlers.auth import AuthHandler
|
||||
from synapse.handlers.appservice import ApplicationServicesHandler
|
||||
from synapse.state import StateHandler
|
||||
from synapse.storage import DataStore
|
||||
from synapse.util import Clock
|
||||
from synapse.util.distributor import Distributor
|
||||
from synapse.streams.events import EventSources
|
||||
from synapse.api.ratelimiting import Ratelimiter
|
||||
from synapse.crypto.keyring import Keyring
|
||||
from synapse.push.pusherpool import PusherPool
|
||||
from synapse.events.builder import EventBuilderFactory
|
||||
from synapse.api.filtering import Filtering
|
||||
|
||||
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
||||
|
||||
import logging
|
||||
|
||||
from twisted.enterprise import adbapi
|
||||
from twisted.web.client import BrowserLikePolicyForHTTPS
|
||||
|
||||
from synapse.api.auth import Auth
|
||||
from synapse.api.filtering import Filtering
|
||||
from synapse.api.ratelimiting import Ratelimiter
|
||||
from synapse.appservice.api import ApplicationServiceApi
|
||||
from synapse.appservice.scheduler import ApplicationServiceScheduler
|
||||
from synapse.crypto.keyring import Keyring
|
||||
from synapse.events.builder import EventBuilderFactory
|
||||
from synapse.federation import initialize_http_replication
|
||||
from synapse.handlers import Handlers
|
||||
from synapse.handlers.appservice import ApplicationServicesHandler
|
||||
from synapse.handlers.auth import AuthHandler
|
||||
from synapse.handlers.device import DeviceHandler
|
||||
from synapse.handlers.e2e_keys import E2eKeysHandler
|
||||
from synapse.handlers.presence import PresenceHandler
|
||||
from synapse.handlers.room import RoomListHandler
|
||||
from synapse.handlers.sync import SyncHandler
|
||||
from synapse.handlers.typing import TypingHandler
|
||||
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
|
||||
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
||||
from synapse.notifier import Notifier
|
||||
from synapse.push.pusherpool import PusherPool
|
||||
from synapse.rest.media.v1.media_repository import MediaRepository
|
||||
from synapse.state import StateHandler
|
||||
from synapse.storage import DataStore
|
||||
from synapse.streams.events import EventSources
|
||||
from synapse.util import Clock
|
||||
from synapse.util.distributor import Distributor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -91,6 +92,8 @@ class HomeServer(object):
|
||||
'typing_handler',
|
||||
'room_list_handler',
|
||||
'auth_handler',
|
||||
'device_handler',
|
||||
'e2e_keys_handler',
|
||||
'application_service_api',
|
||||
'application_service_scheduler',
|
||||
'application_service_handler',
|
||||
@@ -113,6 +116,7 @@ class HomeServer(object):
|
||||
'filtering',
|
||||
'http_client_context_factory',
|
||||
'simple_http_client',
|
||||
'media_repository',
|
||||
]
|
||||
|
||||
def __init__(self, hostname, **kwargs):
|
||||
@@ -195,6 +199,12 @@ class HomeServer(object):
|
||||
def build_auth_handler(self):
|
||||
return AuthHandler(self)
|
||||
|
||||
def build_device_handler(self):
|
||||
return DeviceHandler(self)
|
||||
|
||||
def build_e2e_keys_handler(self):
|
||||
return E2eKeysHandler(self)
|
||||
|
||||
def build_application_service_api(self):
|
||||
return ApplicationServiceApi(self)
|
||||
|
||||
@@ -233,6 +243,9 @@ class HomeServer(object):
|
||||
**self.db_config.get("args", {})
|
||||
)
|
||||
|
||||
def build_media_repository(self):
|
||||
return MediaRepository(self)
|
||||
|
||||
def remove_pusher(self, app_id, push_key, user_id):
|
||||
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
|
||||
|
||||
|
||||
25
synapse/server.pyi
Normal file
25
synapse/server.pyi
Normal file
@@ -0,0 +1,25 @@
|
||||
import synapse.handlers
|
||||
import synapse.handlers.auth
|
||||
import synapse.handlers.device
|
||||
import synapse.handlers.e2e_keys
|
||||
import synapse.storage
|
||||
import synapse.state
|
||||
|
||||
class HomeServer(object):
|
||||
def get_auth_handler(self) -> synapse.handlers.auth.AuthHandler:
|
||||
pass
|
||||
|
||||
def get_datastore(self) -> synapse.storage.DataStore:
|
||||
pass
|
||||
|
||||
def get_device_handler(self) -> synapse.handlers.device.DeviceHandler:
|
||||
pass
|
||||
|
||||
def get_e2e_keys_handler(self) -> synapse.handlers.e2e_keys.E2eKeysHandler:
|
||||
pass
|
||||
|
||||
def get_handlers(self) -> synapse.handlers.Handlers:
|
||||
pass
|
||||
|
||||
def get_state_handler(self) -> synapse.state.StateHandler:
|
||||
pass
|
||||
@@ -379,7 +379,8 @@ class StateHandler(object):
|
||||
try:
|
||||
# FIXME: hs.get_auth() is bad style, but we need to do it to
|
||||
# get around circular deps.
|
||||
self.hs.get_auth().check(event, auth_events)
|
||||
# The signatures have already been checked at this point
|
||||
self.hs.get_auth().check(event, auth_events, do_sig_check=False)
|
||||
prev_event = event
|
||||
except AuthError:
|
||||
return prev_event
|
||||
@@ -391,7 +392,8 @@ class StateHandler(object):
|
||||
try:
|
||||
# FIXME: hs.get_auth() is bad style, but we need to do it to
|
||||
# get around circular deps.
|
||||
self.hs.get_auth().check(event, auth_events)
|
||||
# The signatures have already been checked at this point
|
||||
self.hs.get_auth().check(event, auth_events, do_sig_check=False)
|
||||
return event
|
||||
except AuthError:
|
||||
pass
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.storage.devices import DeviceStore
|
||||
from .appservice import (
|
||||
ApplicationServiceStore, ApplicationServiceTransactionStore
|
||||
)
|
||||
@@ -80,6 +82,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
EventPushActionsStore,
|
||||
OpenIdStore,
|
||||
ClientIpStore,
|
||||
DeviceStore,
|
||||
):
|
||||
|
||||
def __init__(self, db_conn, hs):
|
||||
@@ -92,7 +95,8 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
extra_tables=[("local_invites", "stream_id")]
|
||||
)
|
||||
self._backfill_id_gen = StreamIdGenerator(
|
||||
db_conn, "events", "stream_ordering", step=-1
|
||||
db_conn, "events", "stream_ordering", step=-1,
|
||||
extra_tables=[("ex_outlier_stream", "event_stream_ordering")]
|
||||
)
|
||||
self._receipts_id_gen = StreamIdGenerator(
|
||||
db_conn, "receipts_linearized", "stream_id"
|
||||
|
||||
@@ -597,10 +597,13 @@ class SQLBaseStore(object):
|
||||
more rows, returning the result as a list of dicts.
|
||||
|
||||
Args:
|
||||
table : string giving the table name
|
||||
keyvalues : dict of column names and values to select the rows with,
|
||||
or None to not apply a WHERE clause.
|
||||
retcols : list of strings giving the names of the columns to return
|
||||
table (str): the table name
|
||||
keyvalues (dict[str, Any] | None):
|
||||
column names and values to select the rows with, or None to not
|
||||
apply a WHERE clause.
|
||||
retcols (iterable[str]): the names of the columns to return
|
||||
Returns:
|
||||
defer.Deferred: resolves to list[dict[str, Any]]
|
||||
"""
|
||||
return self.runInteraction(
|
||||
desc,
|
||||
@@ -615,9 +618,11 @@ class SQLBaseStore(object):
|
||||
|
||||
Args:
|
||||
txn : Transaction object
|
||||
table : string giving the table name
|
||||
keyvalues : dict of column names and values to select the rows with
|
||||
retcols : list of strings giving the names of the columns to return
|
||||
table (str): the table name
|
||||
keyvalues (dict[str, T] | None):
|
||||
column names and values to select the rows with, or None to not
|
||||
apply a WHERE clause.
|
||||
retcols (iterable[str]): the names of the columns to return
|
||||
"""
|
||||
if keyvalues:
|
||||
sql = "SELECT %s FROM %s WHERE %s" % (
|
||||
@@ -807,6 +812,11 @@ class SQLBaseStore(object):
|
||||
if txn.rowcount > 1:
|
||||
raise StoreError(500, "more than one row matched")
|
||||
|
||||
def _simple_delete(self, table, keyvalues, desc):
|
||||
return self.runInteraction(
|
||||
desc, self._simple_delete_txn, table, keyvalues
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _simple_delete_txn(txn, table, keyvalues):
|
||||
sql = "DELETE FROM %s WHERE %s" % (
|
||||
|
||||
@@ -138,6 +138,9 @@ class AccountDataStore(SQLBaseStore):
|
||||
A deferred pair of lists of tuples of stream_id int, user_id string,
|
||||
room_id string, type string, and content string.
|
||||
"""
|
||||
if last_room_id == current_id and last_global_id == current_id:
|
||||
return defer.succeed(([], []))
|
||||
|
||||
def get_updated_account_data_txn(txn):
|
||||
sql = (
|
||||
"SELECT stream_id, user_id, account_data_type, content"
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
from . import engines
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
@@ -87,10 +88,12 @@ class BackgroundUpdateStore(SQLBaseStore):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def start_doing_background_updates(self):
|
||||
while True:
|
||||
if self._background_update_timer is not None:
|
||||
return
|
||||
assert self._background_update_timer is None, \
|
||||
"background updates already running"
|
||||
|
||||
logger.info("Starting background schema updates")
|
||||
|
||||
while True:
|
||||
sleep = defer.Deferred()
|
||||
self._background_update_timer = self._clock.call_later(
|
||||
self.BACKGROUND_UPDATE_INTERVAL_MS / 1000., sleep.callback, None
|
||||
@@ -101,22 +104,23 @@ class BackgroundUpdateStore(SQLBaseStore):
|
||||
self._background_update_timer = None
|
||||
|
||||
try:
|
||||
result = yield self.do_background_update(
|
||||
result = yield self.do_next_background_update(
|
||||
self.BACKGROUND_UPDATE_DURATION_MS
|
||||
)
|
||||
except:
|
||||
logger.exception("Error doing update")
|
||||
|
||||
if result is None:
|
||||
logger.info(
|
||||
"No more background updates to do."
|
||||
" Unscheduling background update task."
|
||||
)
|
||||
return
|
||||
else:
|
||||
if result is None:
|
||||
logger.info(
|
||||
"No more background updates to do."
|
||||
" Unscheduling background update task."
|
||||
)
|
||||
defer.returnValue(None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def do_background_update(self, desired_duration_ms):
|
||||
"""Does some amount of work on a background update
|
||||
def do_next_background_update(self, desired_duration_ms):
|
||||
"""Does some amount of work on the next queued background update
|
||||
|
||||
Args:
|
||||
desired_duration_ms(float): How long we want to spend
|
||||
updating.
|
||||
@@ -135,11 +139,21 @@ class BackgroundUpdateStore(SQLBaseStore):
|
||||
self._background_update_queue.append(update['update_name'])
|
||||
|
||||
if not self._background_update_queue:
|
||||
# no work left to do
|
||||
defer.returnValue(None)
|
||||
|
||||
# pop from the front, and add back to the back
|
||||
update_name = self._background_update_queue.pop(0)
|
||||
self._background_update_queue.append(update_name)
|
||||
|
||||
res = yield self._do_background_update(update_name, desired_duration_ms)
|
||||
defer.returnValue(res)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _do_background_update(self, update_name, desired_duration_ms):
|
||||
logger.info("Starting update batch on background update '%s'",
|
||||
update_name)
|
||||
|
||||
update_handler = self._background_update_handlers[update_name]
|
||||
|
||||
performance = self._background_update_performance.get(update_name)
|
||||
@@ -202,6 +216,64 @@ class BackgroundUpdateStore(SQLBaseStore):
|
||||
"""
|
||||
self._background_update_handlers[update_name] = update_handler
|
||||
|
||||
def register_background_index_update(self, update_name, index_name,
|
||||
table, columns):
|
||||
"""Helper for store classes to do a background index addition
|
||||
|
||||
To use:
|
||||
|
||||
1. use a schema delta file to add a background update. Example:
|
||||
INSERT INTO background_updates (update_name, progress_json) VALUES
|
||||
('my_new_index', '{}');
|
||||
|
||||
2. In the Store constructor, call this method
|
||||
|
||||
Args:
|
||||
update_name (str): update_name to register for
|
||||
index_name (str): name of index to add
|
||||
table (str): table to add index to
|
||||
columns (list[str]): columns/expressions to include in index
|
||||
"""
|
||||
|
||||
# if this is postgres, we add the indexes concurrently. Otherwise
|
||||
# we fall back to doing it inline
|
||||
if isinstance(self.database_engine, engines.PostgresEngine):
|
||||
conc = True
|
||||
else:
|
||||
conc = False
|
||||
|
||||
sql = "CREATE INDEX %(conc)s %(name)s ON %(table)s (%(columns)s)" \
|
||||
% {
|
||||
"conc": "CONCURRENTLY" if conc else "",
|
||||
"name": index_name,
|
||||
"table": table,
|
||||
"columns": ", ".join(columns),
|
||||
}
|
||||
|
||||
def create_index_concurrently(conn):
|
||||
conn.rollback()
|
||||
# postgres insists on autocommit for the index
|
||||
conn.set_session(autocommit=True)
|
||||
c = conn.cursor()
|
||||
c.execute(sql)
|
||||
conn.set_session(autocommit=False)
|
||||
|
||||
def create_index(conn):
|
||||
c = conn.cursor()
|
||||
c.execute(sql)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def updater(progress, batch_size):
|
||||
logger.info("Adding index %s to %s", index_name, table)
|
||||
if conc:
|
||||
yield self.runWithConnection(create_index_concurrently)
|
||||
else:
|
||||
yield self.runWithConnection(create_index)
|
||||
yield self._end_background_update(update_name)
|
||||
defer.returnValue(1)
|
||||
|
||||
self.register_background_update_handler(update_name, updater)
|
||||
|
||||
def start_background_update(self, update_name, progress):
|
||||
"""Starts a background update running.
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user