Compare commits

...

252 Commits

Author SHA1 Message Date
Erik Johnston
6affde172c Replace some ujson with simplejson to make it work 2018-03-15 23:38:43 +00:00
Richard van der Hoff
d627174da2 Fix log message in purge_history
(we don't just remove remote events)
2018-02-13 16:51:21 +00:00
Richard van der Hoff
bfdf7b9237 Merge pull request #2864 from matrix-org/rav/persist_event_caching
Use StateResolutionHandler to resolve state in persist_events
2018-02-13 14:45:57 +00:00
Richard van der Hoff
630caf8a70 style nit 2018-02-13 14:29:22 +00:00
Richard van der Hoff
8fd1a32456 Fix typos in purge api & doc
* It's supposed to be purge_local_events, not ..._history
* Fix the doc to have valid json
2018-02-13 13:09:39 +00:00
Richard van der Hoff
4d09366656 Merge pull request #2695 from okurz/feature/allow_recent_pysaml
Allow use of higher versions of saml2
2018-02-13 12:24:08 +00:00
Erik Johnston
1026690cd2 Merge pull request #2857 from matrix-org/erikj/upload_store
Tell storage providers about new file so they can upload
2018-02-12 13:52:58 +00:00
Richard van der Hoff
10b34dbb9a Merge pull request #2858 from matrix-org/rav/purge_updates
delete_local_events for purge_room_history
2018-02-09 14:11:00 +00:00
Richard van der Hoff
39a6b35496 purge: move room_depth update to end
... to avoid locking the table for too long
2018-02-09 13:07:41 +00:00
Richard van der Hoff
74fcbf741b delete_local_events for purge_history
Add a flag which makes the purger delete local events
2018-02-09 13:07:41 +00:00
Richard van der Hoff
e571aef06d purge: Move cache invalidation to more appropriate place
it was a bit of a non-sequitur there
2018-02-09 13:07:41 +00:00
Richard van der Hoff
61ffaa8137 bump purge logging to info
this thing takes ages and the only sign of any progress is the logs, so having
some logs is useful.
2018-02-09 13:07:41 +00:00
Richard van der Hoff
671540dccf rename delete_old_state -> purge_history
(because it deletes more than state)
2018-02-09 13:07:41 +00:00
Erik Johnston
5fa571a91b Tell storage providers about new file so they can upload 2018-02-07 13:35:08 +00:00
Erik Johnston
053255f36c Merge pull request #2856 from matrix-org/erikj/remove_ratelimit
Remove pointless ratelimit check
2018-02-07 11:12:14 +00:00
Erik Johnston
e3624fad5f Remove pointless ratelimit check
The intention was for the check to be called as early as possible in the
request, but it was actually called just before the main ratelimit check,
so it was fairly pointless.
2018-02-07 10:30:25 +00:00
Erik Johnston
617199d73d Merge pull request #2847 from matrix-org/erikj/separate_event_creation
Split event creation into a separate handler
2018-02-06 17:01:17 +00:00
Erik Johnston
3e1e69ccaf Update copyright 2018-02-06 16:40:38 +00:00
Erik Johnston
770b2252ca s/_create_new_client_event/create_new_client_event/ 2018-02-06 16:40:30 +00:00
Erik Johnston
3d33eef6fc Store state groups separately from events (#2784)
* Split state group persist into separate storage func

* Add per database engine code for state group id gen

* Move store_state_group to StateReadStore

This allows other workers to use it, and so resolve state.

* Hook up store_state_group

* Fix tests

* Rename _store_mult_state_groups_txn

* Rename StateGroupReadStore

* Remove redundant _have_persisted_state_group_txn

* Update comments

* Comment compute_event_context

* Set start val for state_group_id_seq

... otherwise we try to recreate old state groups

* Update comments

* Don't store state for outliers

* Update comment

* Update docstring as state groups are ints
2018-02-06 14:31:24 +00:00
Richard van der Hoff
b31bf0bb51 Merge pull request #2849 from matrix-org/rav/clean_up_state_delta
Remove redundant return value from _calculate_state_delta
2018-02-05 17:42:20 +01:00
Richard van der Hoff
9a304ef2b0 Merge pull request #2848 from matrix-org/rav/refactor_search_insert
Factor out common code for search insert
2018-02-05 17:42:09 +01:00
Richard van der Hoff
ebfe64e3d6 Use StateResolutionHandler to resolve state in persist events
... and thus benefit (hopefully) from its cache.
2018-02-05 16:23:26 +00:00
Richard van der Hoff
225dc3b4cb Flatten _get_new_state_after_events
rejig the if statements to simplify the logic and reduce indentation
2018-02-05 16:23:25 +00:00
Richard van der Hoff
9fcbbe8e7d Check that events being persisted have state_group 2018-02-05 16:23:25 +00:00
Richard van der Hoff
447aed42d2 Add event_map param to resolve_state_groups 2018-02-05 16:23:25 +00:00
Richard van der Hoff
ee6fb4cf85 Remove redundant return value from _calculate_state_delta
we already have the state from _get_new_state_after_events, so returning it
from _calculate_state_delta is just confusing.
2018-02-05 16:23:20 +00:00
Richard van der Hoff
3c7b480ba3 Factor out common code for search insert
we can reuse the same code as is used for event insert, for doing the
background index population.
2018-02-05 16:12:14 +00:00
Erik Johnston
25c0a020f4 Updates tests 2018-02-05 16:01:48 +00:00
Erik Johnston
3fa362502c Update places where we create events 2018-02-05 16:01:48 +00:00
Erik Johnston
5ff3d23564 Split event creation into a separate handler 2018-02-05 16:01:48 +00:00
Richard van der Hoff
c46e75d3d8 Move store_event_search_txn to SearchStore
... as a precursor to making event storing and doing the bg update share some
code.
2018-02-05 15:43:22 +00:00
Richard van der Hoff
db91e72ade Merge pull request #2844 from matrix-org/rav/evicted_metrics
monitoring metrics for number of cache evictions
2018-02-05 16:34:36 +01:00
Richard van der Hoff
bc496df192 report metrics on number of cache evictions 2018-02-05 15:34:01 +00:00
Erik Johnston
a1beca0e25 Fix broken unit test for media storage 2018-02-05 12:44:03 +00:00
Erik Johnston
b5049d2e5c Add .vscode to gitignore 2018-02-05 12:06:46 +00:00
Erik Johnston
1f881e0746 Merge pull request #2791 from matrix-org/erikj/media_storage_refactor
Ensure media is in local cache before thumbnailing
2018-02-05 11:28:52 +00:00
Richard van der Hoff
9c9356512e Merge pull request #2845 from matrix-org/rav/urlcache_error_handling
Handle url_previews with no content-type
2018-02-02 15:27:52 +01:00
Richard van der Hoff
18eae413af Merge pull request #2842 from matrix-org/rav/state_resolution_handler
Factor out resolve_state_groups to a separate handler
2018-02-02 15:27:35 +01:00
Richard van der Hoff
78d6ddba86 Merge pull request #2841 from matrix-org/rav/refactor_calc_state_delta
factor _get_new_state_after_events out of _calculate_state_delta
2018-02-02 15:27:15 +01:00
Richard van der Hoff
9dcd667ac2 Merge pull request #2836 from matrix-org/rav/resolve_state_events_docstring
Docstring fixes
2018-02-02 15:26:49 +01:00
Richard van der Hoff
33cac3dc29 Merge pull request #2818 from turt2live/travis/admin-list-media
Add an admin route to get all the media in a room
2018-02-02 11:41:03 +01:00
Travis Ralston
6e87b34f7b Merge branch 'develop' into travis/admin-list-media 2018-02-01 18:05:47 -07:00
Richard van der Hoff
d5352cbba8 Handle url_previews with no content-type
avoid failing with an exception if the remote server doesn't give us a
Content-Type header.

Also, clean up the exception handling a bit.
2018-02-02 00:53:46 +00:00
Richard van der Hoff
14737ba495 doc arg types for _seperate 2018-02-01 12:41:34 +00:00
Richard van der Hoff
e15d4ea248 More docstring fixes
Fix a couple of errors in docstrings
2018-02-01 12:41:34 +00:00
Richard van der Hoff
a18828c129 Fix docstring for StateHandler.resolve_state_groups
The return type was a complete lie, so fix it
2018-02-01 12:41:34 +00:00
Richard van der Hoff
6da4c4d3bd Factor out resolve_state_groups to a separate handler
We extract the storage-independent bits of the state group resolution out to a
separate function, and stick it in a new handler, in preparation for its use
from the storage layer.
2018-02-01 12:40:04 +00:00
Richard van der Hoff
0cbda53819 Rename resolve_state_groups -> resolve_state_groups_for_events
(to make way for a method that actually just does the state group resolution)
2018-02-01 12:40:00 +00:00
Richard van der Hoff
77c0629ebc Merge pull request #2837 from matrix-org/rav/fix_quarantine_media
Fix sql error in quarantine_media
2018-02-01 11:45:58 +01:00
Travis Ralston
e16e45b1b4 pep8
Signed-off-by: Travis Ralston <travpc@gmail.com>
2018-01-31 15:30:38 -07:00
Richard van der Hoff
e1e4ec9f9d factor _get_new_state_after_events out of _calculate_state_delta
This reduces the scope of a bunch of variables
2018-01-31 21:32:09 +00:00
Richard van der Hoff
78e7e05188 Merge pull request #2838 from matrix-org/rav/fix_logging_on_dns_fail
Remove spurious log argument
2018-01-31 22:18:46 +01:00
Richard van der Hoff
ad48dfe73d Merge pull request #2835 from matrix-org/rav/remove_event_type_param
Remove unused "event_type" param on state.get_current_state_ids
2018-01-31 22:17:57 +01:00
Richard van der Hoff
518a74586c Merge pull request #2834 from matrix-org/rav/better_persist_event_exception_handling
Improve exception handling in persist_event
2018-01-31 22:17:38 +01:00
Richard van der Hoff
d1fe4db882 Merge pull request #2833 from matrix-org/rav/pusher_hacks
Logging and metrics for the http pusher
2018-01-31 22:16:59 +01:00
Richard van der Hoff
421d68ca8c Merge pull request #2817 from matrix-org/rav/http_conn_pool
Use a connection pool for the SimpleHttpClient
2018-01-31 22:14:22 +01:00
Richard van der Hoff
326189c25a Script to move remote media to another media store 2018-01-31 18:45:10 +00:00
Travis Ralston
3af53c183a Add admin api documentation for list media endpoint
Signed-off-by: Travis Ralston <travpc@gmail.com>
2018-01-31 08:15:59 -07:00
Travis Ralston
63c4383927 Documentation and naming
Signed-off-by: Travis Ralston <travpc@gmail.com>
2018-01-31 08:07:52 -07:00
Richard van der Hoff
af19f5e9aa Remove spurious log argument
... which would cause scary-looking and unhelpful errors in the log on dns fail
2018-01-30 17:52:03 +00:00
Richard van der Hoff
773f0eed1e Fix sql error in quarantine_media 2018-01-30 15:02:51 +00:00
Richard van der Hoff
adfc0c9539 docstring for get_current_state_ids 2018-01-29 17:39:55 +00:00
Richard van der Hoff
d413a2ba98 Remove unused "event_type" param on state.get_current_state_ids
this param doesn't seem to be used, and is a bit pointless anyway because it
can easily be replicated by the caller. It is also horrible, because it changes
the return type of the method.
2018-01-29 17:06:57 +00:00
Richard van der Hoff
b387ee17b6 Improve exception handling in persist_event
1. use `deferred.errback()` instead of `deferred.errback(e)`, which means that
a Failure object will be constructed using the current exception state,
*including* its stack trace - so the stack trace is saved in the Failure,
leading to better exception reports.

2. Set `consumeErrors=True` on the ObservableDeferred, because we know that
there will always be at least one observer - which avoids a spurious "CRITICAL:
unhandled exception in Deferred" error in the logs
2018-01-29 17:05:33 +00:00
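To illustrate the first point, here is a minimal sketch (not the actual Synapse code) of the errback pattern described above; `do_work` stands in for the persistence step, and the resulting deferred would then be wrapped in an ObservableDeferred with consumeErrors=True:
```
from twisted.internet import defer


def sketch_persist(do_work):
    d = defer.Deferred()
    try:
        do_work()
    except Exception:
        # errback() with no argument builds the Failure from the *current*
        # exception state, including its traceback, unlike errback(e)
        d.errback()
    else:
        d.callback(None)
    return d
```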
Richard van der Hoff
03dd745fe2 Better logging when pushes fail 2018-01-29 15:49:06 +00:00
Richard van der Hoff
e051abd20b add appid/device_display_name to pusher logging 2018-01-29 15:04:16 +00:00
Richard van der Hoff
02ba118f81 Increase http conn pool size 2018-01-29 14:30:15 +00:00
Richard van der Hoff
4c65b98e4a Merge pull request #2831 from matrix-org/rav/fix-userdir-search-again
Fix SQL for user search
2018-01-27 22:58:24 +01:00
Richard van der Hoff
d1f3490e75 Add tests for user directory search 2018-01-27 17:21:57 +00:00
Richard van der Hoff
46022025ea Fix SQL for user search
fix some syntax errors for user search when search_all_users is enabled

fixes #2801, hopefully
2018-01-27 17:21:57 +00:00
Richard van der Hoff
2186d7c06e Merge pull request #2829 from matrix-org/rav/postgres_in_tests
Make it possible to run tests against postgres
2018-01-27 18:20:55 +01:00
Richard van der Hoff
88b9c5cbf0 Make it possible to run tests against postgres 2018-01-27 17:15:24 +00:00
Richard van der Hoff
d7eacc4f87 Create dbpool as normal in tests
... instead of creating our own special SQLiteMemoryDbPool, whose purpose was a
bit of a mystery.

For some reason this makes one of the tests run slightly slower, so bump the
sleep(). Sorry.
2018-01-27 17:15:15 +00:00
Richard van der Hoff
b178eca261 Run on_new_connection for unit tests
Configure the connectionpool used for unit tests to run the `on_new_connection`
function.
2018-01-27 17:06:04 +00:00
Richard van der Hoff
d8f90c4208 Merge pull request #2828 from matrix-org/rav/kill_memory_datastore
Remove unused/bitrotted MemoryDataStore
2018-01-27 17:57:02 +01:00
Richard van der Hoff
4b0f06e99c Merge pull request #2830 from matrix-org/rav/factor_out_get_conn
Factor out get_db_conn to HomeServer base class
2018-01-27 17:56:25 +01:00
Neil Johnson
e98f0f9112 Merge pull request #2827 from matrix-org/fix_server_500_on_public_rooms_call_when_no_rooms_exist
Fix server 500 on public rooms call when no rooms exist
2018-01-26 20:22:40 +00:00
Richard van der Hoff
25adde9a04 Factor out get_db_conn to HomeServer base class
This function is identical in all subclasses, so we may as well push it up to
the base class to reduce duplication (and make use of it in the tests)
2018-01-26 00:56:49 +00:00
Richard van der Hoff
6e9bf67f18 Remove unused/bitrotted MemoryDataStore
This isn't used, and looks thoroughly bitrotted.
2018-01-26 00:35:15 +00:00
Richard van der Hoff
2b91846497 Remove spurious unittest.DEBUG 2018-01-26 00:34:27 +00:00
Neil Johnson
73560237d6 add white space line 2018-01-26 00:15:10 +00:00
Neil Johnson
86c4f49a31 rather than try to reconstruct the results object, better to guard against the xrange step argument being 0 2018-01-26 00:12:02 +00:00
Neil Johnson
f632083576 fix return type, should be a dict 2018-01-25 23:52:17 +00:00
Neil Johnson
6c6e197b0a fix PEP8 violation 2018-01-25 23:47:46 +00:00
Neil Johnson
d02e43b15f remove white space 2018-01-25 23:29:46 +00:00
Neil Johnson
349c739966 synapse 500s on a call to publicRooms when the number of public rooms is zero; the specific cause is xrange being given a step value of zero, but if the total room count really is zero then it makes sense to just bail and save the extra processing 2018-01-25 23:28:44 +00:00
Matthew Hodgson
9a72b70630 fix thinko on 3pid whitelisting 2018-01-24 11:07:47 +01:00
Richard van der Hoff
25e2456ee7 Merge pull request #2816 from matrix-org/rav/metrics_comments
Add some comments about the reactor tick time metric
2018-01-23 19:39:26 +00:00
Matthew Hodgson
d32385336f add ?ts massaging for ASes (#2754)
blindly implement ?ts for AS. untested
2018-01-23 09:59:06 +01:00
Richard van der Hoff
b2da272b77 Merge pull request #2821 from matrix-org/rav/matthew_test_fixes
Matthew's fixes to the unit tests
2018-01-22 20:43:56 +00:00
Richard van der Hoff
4528dd2443 Fix logging and add user_id 2018-01-22 20:15:42 +00:00
Richard van der Hoff
93efd7eb04 logging and debug for http pusher 2018-01-22 18:14:10 +00:00
Matthew Hodgson
ab9f844aaf Add federation_domain_whitelist option (#2820)
Add federation_domain_whitelist

gives a way to restrict which domains your HS is allowed to federate with.
useful mainly for gracefully preventing a private but internet-connected HS from trying to federate to the wider public Matrix network
2018-01-22 19:11:18 +01:00
Richard van der Hoff
5c431f421c Matthew's fixes to the unit tests
Extracted from https://github.com/matrix-org/synapse/pull/2820
2018-01-22 16:46:16 +00:00
Matthew Hodgson
d84f65255e Merge pull request #2813 from matrix-org/matthew/registrations_require_3pid
add registrations_require_3pid and allow_local_3pids
2018-01-22 13:57:22 +00:00
Travis Ralston
a94d9b6b82 Appease the linter
These are ids anyways, not mxc uris.

Signed-off-by: Travis Ralston <travpc@gmail.com>
2018-01-20 22:49:46 -07:00
Travis Ralston
5552ed9a7f Add an admin route to get all the media in a room
This is intended to be used by administrators to monitor the media that is passing through their server, if they wish.

Signed-off-by: Travis Ralston <travpc@gmail.com>
2018-01-20 22:37:53 -07:00
Richard van der Hoff
2c8526cac7 Use a connection pool for the SimpleHttpClient
In particular I hope this will help the pusher, which makes many requests to
sygnal, and is currently negotiating SSL for each one.
2018-01-20 00:55:44 +00:00
Richard van der Hoff
87b7d72760 Add some comments about the reactor tick time metric 2018-01-19 23:51:04 +00:00
Matthew Hodgson
49fce04624 fix typo (thanks sytest) 2018-01-19 19:55:38 +00:00
Richard van der Hoff
b0d9e633ee Merge pull request #2814 from matrix-org/rav/fix_urlcache_thumbs
Use the right path for url_preview thumbnails
2018-01-19 18:57:15 +00:00
Richard van der Hoff
ad7ec63d08 Use the right path for url_preview thumbnails
This was introduced by #2627: we were overwriting the original media for url
previews with the thumbnails :/

(fixes https://github.com/vector-im/riot-web/issues/6012, hopefully)
2018-01-19 18:29:39 +00:00
Matthew Hodgson
62d7d66ae5 oops, check all login types 2018-01-19 18:23:56 +00:00
Matthew Hodgson
8fe253f19b fix PR nitpicking 2018-01-19 18:23:45 +00:00
Matthew Hodgson
293380bef7 trailing commas 2018-01-19 15:38:53 +00:00
Matthew Hodgson
447f4f0d5f rewrite based on PR feedback:
* [ ] split config options into allowed_local_3pids and registrations_require_3pid
 * [ ] simplify and comment logic for picking registration flows
 * [ ] fix docstring and move check_3pid_allowed into a new util module
 * [ ] use check_3pid_allowed everywhere

@erikjohnston PTAL
2018-01-19 15:33:55 +00:00
Matthew Hodgson
9d332e0f79 fix up v1, and improve errors 2018-01-19 00:53:58 +00:00
Matthew Hodgson
0af58f14ee fix pep8 2018-01-19 00:33:51 +00:00
Matthew Hodgson
81d037dbd8 mock registrations_require_3pid 2018-01-19 00:28:08 +00:00
Matthew Hodgson
28a6ccb49c add registrations_require_3pid
lets homeservers specify a whitelist for 3PIDs that users are allowed to associate with.
Typically useful for stopping people from registering with non-work emails
2018-01-19 00:19:58 +00:00
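For illustration only, a hedged sketch of what such a whitelist check could look like (the helper is called check_3pid_allowed in the later commits; the config shape used here, a list of medium/pattern entries, is an assumption):
```
import re


def check_3pid_allowed_sketch(allowed_local_3pids, medium, address):
    # no whitelist configured: any 3pid may be registered
    if not allowed_local_3pids:
        return True
    # otherwise the 3pid must match one of the configured medium/pattern pairs
    for constraint in allowed_local_3pids:
        if medium == constraint["medium"] and re.match(constraint["pattern"], address):
            return True
    return False
```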
Erik Johnston
cd871a3057 Fix storage provider bug introduced when renamed to store_local 2018-01-18 18:37:59 +00:00
Erik Johnston
8ff6726c0d Merge pull request #2812 from matrix-org/erikj/media_storage_provider_config
Make storage providers configurable
2018-01-18 18:33:57 +00:00
Erik Johnston
d69768348f Fix passing wrong config to provider constructor 2018-01-18 17:14:05 +00:00
Erik Johnston
8e85220373 Remove duplicate directory test 2018-01-18 17:12:35 +00:00
Erik Johnston
3fe2bae857 Missing staticmethod 2018-01-18 17:11:45 +00:00
Erik Johnston
aae77da73f Fixup comments 2018-01-18 17:11:29 +00:00
Erik Johnston
ce4f66133e Add unit tests 2018-01-18 16:32:11 +00:00
Erik Johnston
b6dc7044a9 Merge pull request #2804 from matrix-org/erikj/file_consumer
Add decent impl of a FileConsumer
2018-01-18 16:31:33 +00:00
Erik Johnston
9a89dae8c5 Fix typo in thumbnail resource causing access times to be incorrect 2018-01-18 15:06:24 +00:00
Erik Johnston
0af5dc63a8 Make storage providers more configurable 2018-01-18 14:07:21 +00:00
Richard van der Hoff
5a4da21d58 Merge pull request #2810 from matrix-org/rav/metrics_fixes
Fix bugs in block metrics
2018-01-18 13:35:28 +00:00
Richard van der Hoff
d57765fc8a Fix bugs in block metrics
... which I introduced in #2785
2018-01-18 12:24:42 +00:00
Erik Johnston
2cf6a7bc20 Use better file consumer 2018-01-18 12:00:46 +00:00
Erik Johnston
4a53f3a3e8 Ensure media is in local cache before thumbnailing 2018-01-18 12:00:46 +00:00
Erik Johnston
be0dfcd4a2 Do logcontexts correctly 2018-01-18 11:57:57 +00:00
Erik Johnston
1432f7ccd5 Move test stuff to tests 2018-01-18 11:57:57 +00:00
Erik Johnston
2f18a2647b Make all fields private 2018-01-18 11:57:54 +00:00
Richard van der Hoff
d6af5512bb Merge pull request #2809 from matrix-org/rav/metrics_errors
better exception logging in callbackmetrics
2018-01-18 11:46:37 +00:00
Richard van der Hoff
ce236f8ac8 better exception logging in callbackmetrics
when we fail to render a metric, give a clue as to which metric it was
2018-01-18 11:30:49 +00:00
Erik Johnston
dc519602ac Ensure registerProducer isn't called twice 2018-01-18 11:07:17 +00:00
Erik Johnston
17b54389fe Fix _notify_empty typo 2018-01-18 11:05:34 +00:00
Erik Johnston
28b338ed9b Move definition of paused_producer to __init__ 2018-01-18 11:04:41 +00:00
Erik Johnston
a177325b49 Fix comments 2018-01-18 11:02:43 +00:00
Richard van der Hoff
36da256cc6 Merge pull request #2805 from matrix-org/rav/log_state_res
Log room when doing state resolution
2018-01-17 18:05:04 +00:00
Richard van der Hoff
1224612a79 Log room when doing state resolution
Mostly because it helps figure out what is prompting the resolution
2018-01-17 17:11:59 +00:00
Erik Johnston
bc67e7d260 Add decent impl of a FileConsumer
Twisted core doesn't have a general purpose one, so we need to write one
ourselves.

Features:
- All writing happens in background thread
- Supports both push and pull producers
- Push producers get paused if the consumer falls behind
2018-01-17 16:43:03 +00:00
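As a rough illustration of the consumer interface involved, here is a simplified sketch (not the implementation added in this PR) of a file consumer that writes in a background thread and applies naive backpressure to push producers:
```
from twisted.internet import threads
from twisted.internet.interfaces import IConsumer
from zope.interface import implementer


@implementer(IConsumer)
class SketchFileConsumer(object):
    """Writes incoming chunks to file_obj via the reactor's thread pool."""

    def __init__(self, file_obj):
        self._file_obj = file_obj
        self._producer = None
        self._streaming = False

    def registerProducer(self, producer, streaming):
        self._producer = producer
        self._streaming = streaming
        if not streaming:
            producer.resumeProducing()  # pull producer: request the first chunk

    def unregisterProducer(self):
        self._producer = None

    def write(self, data):
        if self._streaming:
            # pause the push producer until this chunk has been written
            self._producer.pauseProducing()
        d = threads.deferToThread(self._file_obj.write, data)
        d.addCallback(self._write_done)

    def _write_done(self, _):
        if self._producer is not None:
            # resume a paused push producer, or ask a pull producer for more
            self._producer.resumeProducing()
```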
Erik Johnston
a87006f9c7 Merge pull request #2783 from matrix-org/erikj/media_last_accessed
Keep track of last access time for local media
2018-01-17 16:39:02 +00:00
Matthew Hodgson
06db5c4b76 Merge pull request #2803 from matrix-org/matthew/fix-userdir-sql
fix SQL when searching all users
2018-01-17 16:27:54 +00:00
Richard van der Hoff
8716eb4920 Merge pull request #2802 from matrix-org/rav/clean_up_resolve_events
Split resolve_events into two functions
2018-01-17 16:01:33 +00:00
Matthew Hodgson
2d9ab533f9 fix SQL when searching all users 2018-01-17 15:58:52 +00:00
Richard van der Hoff
390093d45e Split resolve_events into two functions
... so that the return type doesn't depend on the arg types
2018-01-17 15:44:31 +00:00
Erik Johnston
2fb3a28c98 Remove lost comment 2018-01-17 14:59:44 +00:00
Richard van der Hoff
a7e4ff9cca Merge pull request #2795 from matrix-org/rav/track_db_scheduling
Track DB scheduling delay per-request
2018-01-17 14:29:37 +00:00
Richard van der Hoff
f884cfffb9 Merge pull request #2797 from matrix-org/rav/user_id_checking
Sanity checking for user ids
2018-01-17 14:29:12 +00:00
Richard van der Hoff
a5213df1f7 Sanity checking for user ids
Check the user_id passed to a couple of APIs for validity, to avoid an
"IndexError: list index out of range" exception, which looks scary and results
in a 500 rather than a more useful error.

Fixes #1432, among other things
2018-01-17 14:28:54 +00:00
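A hedged illustration of the kind of check meant here (the function name and error type are placeholders, not the actual Synapse helper):
```
def assert_valid_user_id(user_id):
    # a well-formed Matrix user id looks like "@localpart:domain"
    if not user_id.startswith("@") or ":" not in user_id:
        raise ValueError("Invalid user id: %r" % (user_id,))
    localpart, domain = user_id[1:].split(":", 1)
    if not localpart or not domain:
        raise ValueError("Invalid user id: %r" % (user_id,))
    return localpart, domain
```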
Richard van der Hoff
e8f7541d3f Merge remote-tracking branch 'origin/develop' into rav/track_db_scheduling 2018-01-17 14:01:57 +00:00
Richard van der Hoff
fb6563b4be Merge pull request #2793 from matrix-org/rav/db_txn_time_in_millis
Track db txn time in millisecs
2018-01-17 13:52:42 +00:00
Richard van der Hoff
1954e867b4 Merge pull request #2796 from matrix-org/rav/fix_closed_connection_errors
Fix 'NoneType' object has no attribute 'writeHeaders'
2018-01-17 13:52:24 +00:00
Richard van der Hoff
f23b4078c0 Merge pull request #2794 from matrix-org/rav/rework_run_interaction
rework runInteraction in terms of runConnection
2018-01-17 13:50:42 +00:00
Richard van der Hoff
11ab2f56f5 Merge branch 'develop' into rav/fix_closed_connection_errors 2018-01-17 11:25:11 +00:00
Richard van der Hoff
0486a7814a Merge branch 'develop' into rav/rework_run_interaction 2018-01-17 11:24:38 +00:00
Richard van der Hoff
90c14da992 Merge branch 'develop' into rav/db_txn_time_in_millis 2018-01-17 11:19:55 +00:00
Richard van der Hoff
1067b96364 Merge pull request #2792 from matrix-org/rav/optimise_logging_context
Optimise LoggingContext creation and copying
2018-01-17 11:19:28 +00:00
Richard van der Hoff
38506773eb Merge remote-tracking branch 'origin/develop' into rav/optimise_logging_context 2018-01-17 10:57:13 +00:00
Erik Johnston
300edc2348 Update last access time when thumbnails are viewed 2018-01-17 10:24:43 +00:00
Erik Johnston
05f98a2224 Keep track of last access time for local media 2018-01-17 10:24:43 +00:00
Erik Johnston
3cb2dabaad Merge pull request #2789 from matrix-org/erikj/fix_thumbnails
Fix thumbnailing remote files
2018-01-17 10:22:20 +00:00
Erik Johnston
d728c47142 Add docstring 2018-01-17 10:06:14 +00:00
Richard van der Hoff
936482d507 Fix 'NoneType' object has no attribute 'writeHeaders'
Avoid throwing a (harmless) exception when we try to write an error response to
an http request where the client has disconnected.

This comes up as a CRITICAL error in the logs which tends to mislead people
into thinking there's an actual problem
2018-01-16 17:58:16 +00:00
Richard van der Hoff
3d12d97415 Track DB scheduling delay per-request
For each request, track the amount of time spent waiting for a db
connection. This entails adding it to the LoggingContext and we may as well add
metrics for it while we are passing.
2018-01-16 17:23:32 +00:00
Richard van der Hoff
0f5d2cc37c Merge branch 'rav/db_txn_time_in_millis' into rav/track_db_scheduling 2018-01-16 17:09:15 +00:00
Richard van der Hoff
8615f19d20 rework runInteraction in terms of runConnection
... so that we can share the code
2018-01-16 17:08:29 +00:00
Matthew Hodgson
5e97ca7ee6 fix typo 2018-01-16 16:52:35 +00:00
Erik Johnston
d863f68cab Use local vars 2018-01-16 16:24:15 +00:00
Erik Johnston
6368e5c0ab Change _generate_thumbnails to take media_type 2018-01-16 16:17:38 +00:00
Erik Johnston
0a90d9ede4 Move setting of file_id up to caller 2018-01-16 16:03:05 +00:00
Richard van der Hoff
6324b65f08 Track db txn time in millisecs
... to reduce the amount of floating-point foo we do.
2018-01-16 15:53:18 +00:00
Richard van der Hoff
44a498418c Optimise LoggingContext creation and copying
It turns out that the only thing we use the __dict__ of LoggingContext for is
`request`, and given we create lots of LoggingContexts and then copy them every
time we do a db transaction or log line, using the __dict__ seems a bit
redundant. Let's try to optimise things by making the request attribute
explicit.
2018-01-16 15:49:42 +00:00
Erik Johnston
5dfc83704b Fix typo 2018-01-16 14:32:56 +00:00
Richard van der Hoff
febdca4b37 Merge pull request #2785 from matrix-org/rav/reorganise_metrics_again
Reorganise request and block metrics
2018-01-16 14:10:21 +00:00
Richard van der Hoff
f5f89fda21 Merge pull request #2790 from matrix-org/rav/preserve_event_logcontext_leak
Fix a logcontext leak in persist_events
2018-01-16 14:09:13 +00:00
Erik Johnston
307f88dfb6 Fix up log lines 2018-01-16 13:53:52 +00:00
Richard van der Hoff
807e848f0f Merge pull request #2787 from matrix-org/rav/worker_event_counts
Metrics for events processed in appservice and fed sender
2018-01-16 13:14:09 +00:00
Richard van der Hoff
4a31a61ef9 Merge pull request #2786 from matrix-org/rav/rdata_metrics
Metrics for number of RDATA commands received
2018-01-16 13:13:53 +00:00
Richard van der Hoff
ee7a1cabd8 document metrics changes 2018-01-16 13:04:01 +00:00
Erik Johnston
9795b9ebb1 Correctly use server_name/file_id when generating/fetching remote thumbnails 2018-01-16 12:02:06 +00:00
Erik Johnston
c5b589f2e8 Log when we respond with 404 2018-01-16 12:01:40 +00:00
Richard van der Hoff
64ddec1bc0 Fix a logcontext leak in persist_events
ObservableDeferred expects its callbacks to be called without any
logcontexts, whereas it turns out we were calling them with the logcontext of
the request which initiated the persistence loop.

It seems wrong that we are attributing work done in the persistence loop to the
request that happened to initiate it, so let's solve this by dropping the
logcontext for it.

(I'm not sure this actually causes any real problems other than messages in the
debug log, but let's clean it up anyway)
2018-01-16 11:47:36 +00:00
Erik Johnston
a4c5e4a645 Fix thumbnailing remote files 2018-01-16 11:37:50 +00:00
Erik Johnston
1159abbdd2 Merge pull request #2767 from matrix-org/erikj/media_storage_refactor
Refactor MediaRepository to separate out storage
2018-01-16 10:23:50 +00:00
Richard van der Hoff
a027c2af8d Metrics for events processed in appservice and fed sender
More metrics I wished I'd had
2018-01-15 18:23:24 +00:00
Richard van der Hoff
5c3c32f16f Metrics for number of RDATA commands received
I found myself wishing we had this.
2018-01-15 17:45:55 +00:00
Richard van der Hoff
39f4e29d01 Reorganise request and block metrics
To stop the number of duplicate foo:count metrics increasing without
bound, it's time for a rearrangement.

The following are all deprecated, and replaced with synapse_util_metrics_block_count:
  synapse_util_metrics_block_timer:count
  synapse_util_metrics_block_ru_utime:count
  synapse_util_metrics_block_ru_stime:count
  synapse_util_metrics_block_db_txn_count:count
  synapse_util_metrics_block_db_txn_duration:count

The following are all deprecated, and replaced with synapse_http_server_response_count:
   synapse_http_server_requests
   synapse_http_server_response_time:count
   synapse_http_server_response_ru_utime:count
   synapse_http_server_response_ru_stime:count
   synapse_http_server_response_db_txn_count:count
   synapse_http_server_response_db_txn_duration:count

The following are renamed (the old metrics are kept for now, but deprecated):

  synapse_util_metrics_block_timer:total ->
     synapse_util_metrics_block_time_seconds

  synapse_util_metrics_block_ru_utime:total ->
     synapse_util_metrics_block_ru_utime_seconds

  synapse_util_metrics_block_ru_stime:total ->
     synapse_util_metrics_block_ru_stime_seconds

  synapse_util_metrics_block_db_txn_count:total ->
     synapse_util_metrics_block_db_txn_count

  synapse_util_metrics_block_db_txn_duration:total ->
     synapse_util_metrics_block_db_txn_duration_seconds

  synapse_http_server_response_time:total ->
     synapse_http_server_response_time_seconds

  synapse_http_server_response_ru_utime:total ->
     synapse_http_server_response_ru_utime_seconds

  synapse_http_server_response_ru_stime:total ->
     synapse_http_server_response_ru_stime_seconds

   synapse_http_server_response_db_txn_count:total ->
      synapse_http_server_response_db_txn_count

   synapse_http_server_response_db_txn_duration:total ->
      synapse_http_server_response_db_txn_duration_seconds
2018-01-15 17:09:44 +00:00
Richard van der Hoff
992018d1c0 mechanism to render metrics with alternative names 2018-01-15 17:04:39 +00:00
Richard van der Hoff
80fa610f9c Add some comments to metrics classes 2018-01-15 16:52:52 +00:00
Richard van der Hoff
5e16c1dc8c Merge pull request #2778 from matrix-org/rav/counters_should_be_floats
Make Counter render floats
2018-01-15 14:09:53 +00:00
Richard van der Hoff
19d274085f Make Counter render floats
Prometheus handles all metrics as floats, and sometimes we store non-integer
values in them (notably, durations in seconds), so let's render them as floats
too.

(Note that the standard client libraries also treat Counters as floats.)
2018-01-12 23:49:44 +00:00
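A small illustration of why this matters (metric name and value are examples only): an integer format silently truncates fractional values such as durations in seconds, while a float format preserves them.
```
value = 1.75  # e.g. seconds accumulated in a Counter
print("example_response_time_seconds %d" % value)     # renders as 1 (wrong)
print("example_response_time_seconds %.12g" % value)  # renders as 1.75
```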
Richard van der Hoff
0fc2362d37 Merge pull request #2777 from matrix-org/rav/fix_remote_thumbnails
Reinstate media download on thumbnail request
2018-01-12 19:12:31 +00:00
Richard van der Hoff
21bf87a146 Reinstate media download on thumbnail request
We need to actually download the remote media when we get a request for a
thumbnail.
2018-01-12 15:38:06 +00:00
Erik Johnston
694f1c1b18 Fix up comments 2018-01-12 15:02:46 +00:00
Erik Johnston
e21370ba54 Correctly reraise exception 2018-01-12 14:44:02 +00:00
Erik Johnston
85a4d78213 Make Responder a context manager 2018-01-12 13:32:03 +00:00
Erik Johnston
dcc8eded41 Add missing class var 2018-01-12 13:16:27 +00:00
Erik Johnston
fefeb0ab0e Merge pull request #2774 from matrix-org/erikj/synctl
When using synctl with workers, don't start the main synapse automatically
2018-01-12 12:00:07 +00:00
Erik Johnston
81391fa162 Merge branch 'develop' of github.com:matrix-org/synapse into erikj/media_storage_refactor 2018-01-12 11:28:49 +00:00
Erik Johnston
1e4edd1717 Remove unnecessary condition 2018-01-12 11:28:32 +00:00
Erik Johnston
c6c009603c Remove unused variables 2018-01-12 11:24:05 +00:00
Erik Johnston
4d88958cf6 Make class var local 2018-01-12 11:23:54 +00:00
Erik Johnston
227c491510 Comments 2018-01-12 11:22:41 +00:00
Erik Johnston
f4d93ae424 Actually make it work 2018-01-12 10:39:27 +00:00
Erik Johnston
f68e4cf690 Refactor 2018-01-12 10:11:12 +00:00
Richard van der Hoff
5f23b6d5ea Merge pull request #2766 from matrix-org/rav/room_event
Add /room/{id}/event/{id} to synapse
2018-01-11 14:47:18 +00:00
Erik Johnston
7cd34512d8 When using synctl with workers, don't start the main synapse automatically 2018-01-11 11:37:39 +00:00
Erik Johnston
07ab948c38 Merge branch 'master' of github.com:matrix-org/synapse into develop 2018-01-11 11:23:40 +00:00
Erik Johnston
825a07a974 Merge pull request #2773 from matrix-org/erikj/hash_bg
Do bcrypt hashing in a background thread
2018-01-10 18:11:41 +00:00
Erik Johnston
f8e1ab5fee Do bcrypt hashing in a background thread 2018-01-10 18:01:28 +00:00
Erik Johnston
b9e4a97922 Merge pull request #2772 from matrix-org/t3chguy/patch-1
Fix publicised groups GET API (singular) over federation
2018-01-10 15:15:48 +00:00
Michael Telatynski
5f07f5694c fix order of operations derp and also use .get to default to {}
Signed-off-by: Michael Telatynski <7t3chguy@gmail.com>
2018-01-10 15:11:35 +00:00
Michael Telatynski
8c9d5b4873 Fix publicised groups API (singular) over federation
which was missing its fed client API; since there is no other API,
it might as well reuse the bulk one and unwrap it

Signed-off-by: Michael Telatynski <7t3chguy@gmail.com>
2018-01-10 15:04:51 +00:00
Richard van der Hoff
c175a5f0f2 Merge pull request #2770 from matrix-org/rav/fix_request_metrics
Update http request metrics before calling servlet
2018-01-10 14:53:40 +00:00
Richard van der Hoff
d90e8ea444 Update http request metrics before calling servlet
Make sure that we set the servlet name in the metrics object *before* calling
the servlet, in case the servlet throws an exception.
2018-01-09 18:27:35 +00:00
Erik Johnston
8f03aa9f61 Add StorageProvider concept 2018-01-09 16:16:12 +00:00
Erik Johnston
2442e9876c Make PreviewUrlResource use MediaStorage 2018-01-09 16:15:07 +00:00
Erik Johnston
9d30a7691c Make ThumbnailResource use MediaStorage 2018-01-09 16:15:07 +00:00
Erik Johnston
9e20840e02 Use MediaStorage for remote media 2018-01-09 16:15:07 +00:00
Erik Johnston
dd3092c3a3 Use MediaStorage for local files 2018-01-09 16:15:07 +00:00
Erik Johnston
ada470bccb Add MediaStorage class 2018-01-09 16:15:07 +00:00
Erik Johnston
1ee787912b Add some helper classes 2018-01-09 16:15:07 +00:00
Erik Johnston
47ca5eb882 Split out add_file_headers 2018-01-09 16:15:07 +00:00
Erik Johnston
ce3a726fc0 Merge pull request #2764 from matrix-org/erikj/remove_dead_thumbnail_code
Remove dead code related to default thumbnails
2018-01-09 16:01:02 +00:00
Erik Johnston
b6c9deffda Remove dead TODO 2018-01-09 15:53:23 +00:00
Richard van der Hoff
51c9d9ed65 Add /room/{id}/event/{id} to synapse
Turns out that there is a valid use case for retrieving an event by id (notably
having received a push), but event ids should be scoped to a room, so /event/{id}
is wrong.
2018-01-09 14:39:12 +00:00
Erik Johnston
b30cd5b107 Remove dead code related to default thumbnails 2018-01-09 14:38:33 +00:00
Richard van der Hoff
a767f06e3f Merge pull request #2765 from matrix-org/rav/fix_room_uts
Fix flaky test_rooms UTs
2018-01-09 14:21:03 +00:00
Richard van der Hoff
cb66a2d387 Merge pull request #2763 from matrix-org/rav/fix_config_uts
Fix broken config UTs
2018-01-09 12:08:08 +00:00
Richard van der Hoff
aed4e4ecdd Merge pull request #2762 from matrix-org/rav/fix_logconfig_indenting
Make indentation of generated log config consistent
2018-01-09 12:07:56 +00:00
Richard van der Hoff
f8fa5ae4af enable twisted delayedcall debugging in UTs 2018-01-09 12:06:45 +00:00
Richard van der Hoff
374c4d4ced Remove dead code
pointless function is pointless
2018-01-09 12:06:45 +00:00
Richard van der Hoff
142fb0a7d4 Disable user_directory updates for UTs
Fix flakiness in the UTs caused by the user_directory being updated in the
background
2018-01-09 12:06:45 +00:00
Richard van der Hoff
0211464ba2 Fix broken config UTs
https://github.com/matrix-org/synapse/pull/2755 broke log-config generation,
which in turn broke the unit tests.
2018-01-09 11:28:33 +00:00
Richard van der Hoff
3a556f1ea0 Make indentation of generated log config consistent
(we had a mix of 2- and 4-space indents)
2018-01-09 11:27:19 +00:00
Erik Johnston
e9f7677170 Merge pull request #2761 from turt2live/patch-1
Fix templating error with unban permission message
2018-01-08 10:10:49 +00:00
Travis Ralston
eccfc8e928 Fix templating error with unban permission message
Fixes https://github.com/matrix-org/synapse/issues/2759

Signed-off-by: Travis Ralston <travpc@gmail.com>
2018-01-07 19:52:58 -07:00
Richard van der Hoff
e6b24663e4 Merge pull request #2755 from matrix-org/richvdh-patch-1
Remove 'verbosity'/'log_file' from generated cfg
2018-01-05 14:17:46 +00:00
Richard van der Hoff
840f72356e Remove 'verbosity'/'log_file' from generated cfg
... because these only really exist to confuse people nowadays.

Also bring log config more into line with the generated log config, by making `level_for_storage`
apply to the `synapse.storage.SQL` logger rather than `synapse.storage`.
2018-01-05 12:30:28 +00:00
Richard van der Hoff
6e375f4597 Merge pull request #2744 from matrix-org/rav/login_logging
Better logging when login can't find a 3pid
2018-01-05 09:54:39 +00:00
Richard van der Hoff
efdfd5c835 Merge pull request #2745 from matrix-org/rav/assert_params
Check missing fields in event_from_pdu_json
2018-01-05 09:52:39 +00:00
Richard van der Hoff
bd91857028 Check missing fields in event_from_pdu_json
Return a 400 rather than a 500 when somebody messes up their send_join
2017-12-30 18:40:19 +00:00
Richard van der Hoff
3079f80d4a Factor out event_from_pdu_json
turns out we have two copies of this, and neither needs to be an instance
method
2017-12-30 18:40:19 +00:00
Richard van der Hoff
65abc90fb6 federation_server: clean up imports 2017-12-30 18:40:19 +00:00
Richard van der Hoff
a7b726ad18 federation_client: clean up imports 2017-12-30 18:40:19 +00:00
Richard van der Hoff
75c1b8df01 Better logging when login can't find a 3pid 2017-12-20 19:31:00 +00:00
Richard van der Hoff
3f9f1c50f3 Merge pull request #2683 from seckrv/fix_pwd_auth_prov_typo
synapse/config/password_auth_providers: Fixed bracket typo
2017-12-18 22:37:21 +00:00
Richard van der Hoff
48fa4e1e5b Merge pull request #2435 from silkeh/listen-ipv6-default
Adapt the default config to bind on both IPv4 and IPv6 on all platforms
2017-12-18 22:34:50 +00:00
Silke
df0f602796 Implement listen_tcp method in remaining workers
Signed-off-by: Silke <silke@slxh.eu>
2017-12-18 20:00:42 +01:00
Silke
26cd3f5690 Remove logger argument and do not catch replication listener
Signed-off-by: Silke <silke@slxh.eu>
2017-12-18 20:00:42 +01:00
Silke Hofstra
ed48ecc58c Add methods for listening on multiple addresses
Add listen_tcp and listen_ssl which implement Twisted's reactor.listenTCP
and reactor.listenSSL for multiple addresses.

Signed-off-by: Silke Hofstra <silke@slxh.eu>
2017-12-17 13:15:48 +01:00
Silke Hofstra
37d1a90025 Allow binds to both :: and 0.0.0.0
Binding on 0.0.0.0 when :: is specified in the bind_addresses is now allowed.
This causes a warning explaining the behaviour.
Configuration changed to match.

See #2232

Signed-off-by: Silke Hofstra <silke@slxh.eu>
2017-12-17 13:10:31 +01:00
Willem Mulder
3e59143ba8 Adapt the default config to bind on IPv6.
Most deployments are on Linux (or Mac OS), so this would actually bind
on both IPv4 and IPv6.

Resolves #1886.

Signed-off-by: Willem Mulder <willemmaster@hotmail.com>
2017-12-17 13:07:37 +01:00
Oliver Kurz
83d8d4d8cd Allow use of higher versions of saml2
The package was pinned to <4.0 with 07cf96eb because "from saml2 import
config" did not work. This seems to have been fixed in the mean time in the
saml2 package and therefore should not stop to use a more recent version.

Signed-off-by: Oliver Kurz <okurz@suse.de>
2017-11-20 11:14:39 +01:00
Richard von Seck
6f05de0e5e synapse/config/password_auth_providers: Fixed bracket typo
Signed-off-by: Richard von Seck <richard.von-seck@gmx.net>
2017-11-16 15:59:38 +01:00
118 changed files with 4452 additions and 2115 deletions

.gitignore vendored
View File

@@ -46,3 +46,5 @@ static/client/register/register_config.js
env/
*.config
.vscode/

View File

@@ -1,3 +1,13 @@
Unreleased
==========
synctl no longer starts the main synapse when using the ``-a`` option with workers.
A new worker file should be added with ``worker_app: synapse.app.homeserver``.
This release also begins the process of renaming a number of the metrics
reported to prometheus. See `docs/metrics-howto.rst <docs/metrics-howto.rst#block-and-response-metrics-renamed-for-0-27-0>`_.
Changes in synapse v0.26.0 (2018-01-05)
=======================================

View File

@@ -0,0 +1,23 @@
# List all media in a room
This API gets a list of known media in a room.
The API is:
```
GET /_matrix/client/r0/admin/room/<room_id>/media
```
including an `access_token` of a server admin.
It returns a JSON body like the following:
```
{
  "local": [
    "mxc://localhost/xwvutsrqponmlkjihgfedcba",
    "mxc://localhost/abcdefghijklmnopqrstuvwx"
  ],
  "remote": [
    "mxc://matrix.org/xwvutsrqponmlkjihgfedcba",
    "mxc://matrix.org/abcdefghijklmnopqrstuvwx"
  ]
}
```
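A hypothetical usage sketch using the Python `requests` library (homeserver URL, room id and token are placeholders; the room id must be URL-encoded):
```
import requests

resp = requests.get(
    "https://homeserver.example/_matrix/client/r0/admin"
    "/room/%21someroom%3Aexample.org/media",
    params={"access_token": "ADMIN_TOKEN"},
)
body = resp.json()
print(body["local"], body["remote"])
```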

View File

@@ -4,8 +4,6 @@ Purge History API
The purge history API allows server admins to purge historic events from their
database, reclaiming disk space.
**NB!** This will not delete local events (locally sent message content, etc.) from the database, but it will remove lots of the metadata about them and does dramatically reduce the on-disk space usage
Depending on the amount of history being purged a call to the API may take
several minutes or longer. During this period users will not be able to
paginate further back in the room from the point being purged from.
@@ -15,3 +13,15 @@ The API is simply:
``POST /_matrix/client/r0/admin/purge_history/<room_id>/<event_id>``
including an ``access_token`` of a server admin.
By default, events sent by local users are not deleted, as they may represent
the only copies of this content in existence. (Events sent by remote users are
deleted, and room state data before the cutoff is always removed).
To delete local events as well, set ``delete_local_events`` in the body:
.. code:: json
{
"delete_local_events": true
}
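For example, a hedged sketch of calling the purge API with the new flag via Python's `requests` (URL, identifiers and token are placeholders):
```
import requests

resp = requests.post(
    "https://homeserver.example/_matrix/client/r0/admin/purge_history"
    "/%21someroom%3Aexample.org/%24someevent%3Aexample.org",
    params={"access_token": "ADMIN_TOKEN"},
    json={"delete_local_events": True},
)
resp.raise_for_status()
```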

View File

@@ -16,7 +16,7 @@ How to monitor Synapse metrics using Prometheus
metrics_port: 9092
Also ensure that ``enable_metrics`` is set to ``True``.
Restart synapse.
3. Add a prometheus target for synapse.
@@ -28,11 +28,58 @@ How to monitor Synapse metrics using Prometheus
static_configs:
- targets: ["my.server.here:9092"]
If your prometheus is older than 1.5.2, you will need to replace
``static_configs`` in the above with ``target_groups``.
Restart prometheus.
Block and response metrics renamed for 0.27.0
---------------------------------------------
Synapse 0.27.0 begins the process of rationalising the duplicate ``*:count``
metrics reported for the resource tracking for code blocks and HTTP requests.
At the same time, the corresponding ``*:total`` metrics are being renamed, as
the ``:total`` suffix no longer makes sense in the absence of a corresponding
``:count`` metric.
To enable a graceful migration path, this release just adds new names for the
metrics being renamed. A future release will remove the old ones.
The following table shows the new metrics, and the old metrics which they are
replacing.
==================================================== ===================================================
New name Old name
==================================================== ===================================================
synapse_util_metrics_block_count synapse_util_metrics_block_timer:count
synapse_util_metrics_block_count synapse_util_metrics_block_ru_utime:count
synapse_util_metrics_block_count synapse_util_metrics_block_ru_stime:count
synapse_util_metrics_block_count synapse_util_metrics_block_db_txn_count:count
synapse_util_metrics_block_count synapse_util_metrics_block_db_txn_duration:count
synapse_util_metrics_block_time_seconds synapse_util_metrics_block_timer:total
synapse_util_metrics_block_ru_utime_seconds synapse_util_metrics_block_ru_utime:total
synapse_util_metrics_block_ru_stime_seconds synapse_util_metrics_block_ru_stime:total
synapse_util_metrics_block_db_txn_count synapse_util_metrics_block_db_txn_count:total
synapse_util_metrics_block_db_txn_duration_seconds synapse_util_metrics_block_db_txn_duration:total
synapse_http_server_response_count synapse_http_server_requests
synapse_http_server_response_count synapse_http_server_response_time:count
synapse_http_server_response_count synapse_http_server_response_ru_utime:count
synapse_http_server_response_count synapse_http_server_response_ru_stime:count
synapse_http_server_response_count synapse_http_server_response_db_txn_count:count
synapse_http_server_response_count synapse_http_server_response_db_txn_duration:count
synapse_http_server_response_time_seconds synapse_http_server_response_time:total
synapse_http_server_response_ru_utime_seconds synapse_http_server_response_ru_utime:total
synapse_http_server_response_ru_stime_seconds synapse_http_server_response_ru_stime:total
synapse_http_server_response_db_txn_count synapse_http_server_response_db_txn_count:total
synapse_http_server_response_db_txn_duration_seconds synapse_http_server_response_db_txn_duration:total
==================================================== ===================================================
Standard Metric Names
---------------------
@@ -42,7 +89,7 @@ have been changed to seconds, from milliseconds.
================================== =============================
New name Old name
---------------------------------- -----------------------------
================================== =============================
process_cpu_user_seconds_total process_resource_utime / 1000
process_cpu_system_seconds_total process_resource_stime / 1000
process_open_fds (no 'type' label) process_fds
@@ -52,8 +99,8 @@ The python-specific counts of garbage collector performance have been renamed.
=========================== ======================
New name Old name
--------------------------- ----------------------
python_gc_time reactor_gc_time
=========================== ======================
python_gc_time reactor_gc_time
python_gc_unreachable_total reactor_gc_unreachable
python_gc_counts reactor_gc_counts
=========================== ======================
@@ -62,7 +109,7 @@ The twisted-specific reactor metrics have been renamed.
==================================== =====================
New name Old name
------------------------------------ ---------------------
==================================== =====================
python_twisted_reactor_pending_calls reactor_pending_calls
python_twisted_reactor_tick_time reactor_tick_time
==================================== =====================

View File

@@ -0,0 +1,133 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Moves a list of remote media from one media store to another.
The input should be a list of media files to be moved, one per line. Each line
should be formatted::
<origin server>|<file id>
This can be extracted from postgres with::
psql --tuples-only -A -c "select media_origin, filesystem_id from
matrix.remote_media_cache where ..."
To use, pipe the above into::
PYTHON_PATH=. ./scripts/move_remote_media_to_new_store.py <source repo> <dest repo>
"""
from __future__ import print_function
import argparse
import logging
import sys
import os
import shutil
from synapse.rest.media.v1.filepath import MediaFilePaths
logger = logging.getLogger()
def main(src_repo, dest_repo):
    src_paths = MediaFilePaths(src_repo)
    dest_paths = MediaFilePaths(dest_repo)

    for line in sys.stdin:
        line = line.strip()
        parts = line.split('|')
        if len(parts) != 2:
            print("Unable to parse input line %s" % line, file=sys.stderr)
            exit(1)

        move_media(parts[0], parts[1], src_paths, dest_paths)


def move_media(origin_server, file_id, src_paths, dest_paths):
    """Move the given file, and any thumbnails, to the dest repo

    Args:
        origin_server (str):
        file_id (str):
        src_paths (MediaFilePaths):
        dest_paths (MediaFilePaths):
    """
    logger.info("%s/%s", origin_server, file_id)

    # check that the original exists
    original_file = src_paths.remote_media_filepath(origin_server, file_id)
    if not os.path.exists(original_file):
        logger.warn(
            "Original for %s/%s (%s) does not exist",
            origin_server, file_id, original_file,
        )
    else:
        mkdir_and_move(
            original_file,
            dest_paths.remote_media_filepath(origin_server, file_id),
        )

    # now look for thumbnails
    original_thumb_dir = src_paths.remote_media_thumbnail_dir(
        origin_server, file_id,
    )
    if not os.path.exists(original_thumb_dir):
        return

    mkdir_and_move(
        original_thumb_dir,
        dest_paths.remote_media_thumbnail_dir(origin_server, file_id)
    )


def mkdir_and_move(original_file, dest_file):
    dirname = os.path.dirname(dest_file)
    if not os.path.exists(dirname):
        logger.debug("mkdir %s", dirname)
        os.makedirs(dirname)
    logger.debug("mv %s %s", original_file, dest_file)
    shutil.move(original_file, dest_file)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class = argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "-v", action='store_true', help='enable debug logging')
    parser.add_argument(
        "src_repo",
        help="Path to source content repo",
    )
    parser.add_argument(
        "dest_repo",
        help="Path to destination content repo",
    )
    args = parser.parse_args()

    logging_config = {
        "level": logging.DEBUG if args.v else logging.INFO,
        "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s"
    }
    logging.basicConfig(**logging_config)

    main(args.src_repo, args.dest_repo)

View File

@@ -46,6 +46,7 @@ class Codes(object):
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
THREEPID_IN_USE = "M_THREEPID_IN_USE"
THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND"
THREEPID_DENIED = "M_THREEPID_DENIED"
INVALID_USERNAME = "M_INVALID_USERNAME"
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
@@ -140,6 +141,32 @@ class RegistrationError(SynapseError):
pass
class FederationDeniedError(SynapseError):
    """An error raised when the server tries to federate with a server which
    is not on its federation whitelist.

    Attributes:
        destination (str): The destination which has been denied
    """

    def __init__(self, destination):
        """Raised by federation client or server to indicate that we are
        deliberately not attempting to contact a given server because it is
        not on our federation whitelist.

        Args:
            destination (str): the domain in question
        """
        self.destination = destination

        super(FederationDeniedError, self).__init__(
            code=403,
            msg="Federation denied with %s." % (self.destination,),
            errcode=Codes.FORBIDDEN,
        )
class InteractiveAuthIncompleteError(Exception):
"""An error raised when UI auth is not yet complete

View File

@@ -25,7 +25,9 @@ except Exception:
from daemonize import Daemonize
from synapse.util import PreserveLoggingContext
from synapse.util.rlimit import change_resource_limit
from twisted.internet import reactor
from twisted.internet import error, reactor
logger = logging.getLogger(__name__)
def start_worker_reactor(appname, config):
@@ -120,3 +122,57 @@ def quit_with_error(error_string):
sys.stderr.write(" %s\n" % (line.rstrip(),))
sys.stderr.write("*" * line_length + '\n')
sys.exit(1)
def listen_tcp(bind_addresses, port, factory, backlog=50):
    """
    Create a TCP socket for a port and several addresses
    """
    for address in bind_addresses:
        try:
            reactor.listenTCP(
                port,
                factory,
                backlog,
                address
            )
        except error.CannotListenError as e:
            check_bind_error(e, address, bind_addresses)


def listen_ssl(bind_addresses, port, factory, context_factory, backlog=50):
    """
    Create an SSL socket for a port and several addresses
    """
    for address in bind_addresses:
        try:
            reactor.listenSSL(
                port,
                factory,
                context_factory,
                backlog,
                address
            )
        except error.CannotListenError as e:
            check_bind_error(e, address, bind_addresses)


def check_bind_error(e, address, bind_addresses):
    """
    This method checks an exception that occurred while binding on 0.0.0.0.
    If :: is specified in the bind addresses, a warning is shown.
    The exception is still raised otherwise.

    Binding on both 0.0.0.0 and :: causes an exception on Linux and macOS
    because :: binds on both IPv4 and IPv6 (as per RFC 3493).
    When binding on 0.0.0.0 after :: this can safely be ignored.

    Args:
        e (Exception): Exception that was caught.
        address (str): Address on which binding was attempted.
        bind_addresses (list): Addresses on which the service listens.
    """
    if address == '0.0.0.0' and '::' in bind_addresses:
        logger.warn('Failed to listen on 0.0.0.0, continuing because listening on [::]')
    else:
        raise e

View File

@@ -49,19 +49,6 @@ class AppserviceSlaveStore(
class AppserviceServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = AppserviceSlaveStore(self.get_db_conn(), self)
@@ -79,17 +66,16 @@ class AppserviceServer(HomeServer):
root_resource = create_resource_tree(resources, Resource())
for address in bind_addresses:
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=address
_base.listen_tcp(
bind_addresses,
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
)
)
logger.info("Synapse appservice now listening on port %d", port)
@@ -98,18 +84,15 @@ class AppserviceServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=address
_base.listen_tcp(
listener["bind_addresses"],
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
)
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])

View File

@@ -64,19 +64,6 @@ class ClientReaderSlavedStore(
class ClientReaderServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = ClientReaderSlavedStore(self.get_db_conn(), self)
@@ -103,17 +90,16 @@ class ClientReaderServer(HomeServer):
root_resource = create_resource_tree(resources, Resource())
for address in bind_addresses:
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=address
_base.listen_tcp(
bind_addresses,
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
)
)
logger.info("Synapse client reader now listening on port %d", port)
@@ -122,18 +108,16 @@ class ClientReaderServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=address
_base.listen_tcp(
listener["bind_addresses"],
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
)
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])

View File

@@ -58,19 +58,6 @@ class FederationReaderSlavedStore(
class FederationReaderServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self)
@@ -92,17 +79,16 @@ class FederationReaderServer(HomeServer):
root_resource = create_resource_tree(resources, Resource())
for address in bind_addresses:
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=address
_base.listen_tcp(
bind_addresses,
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
)
)
logger.info("Synapse federation reader now listening on port %d", port)
@@ -111,18 +97,15 @@ class FederationReaderServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=address
_base.listen_tcp(
listener["bind_addresses"],
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
)
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])

View File

@@ -76,19 +76,6 @@ class FederationSenderSlaveStore(
class FederationSenderServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = FederationSenderSlaveStore(self.get_db_conn(), self)
@@ -106,17 +93,16 @@ class FederationSenderServer(HomeServer):
root_resource = create_resource_tree(resources, Resource())
for address in bind_addresses:
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=address
_base.listen_tcp(
bind_addresses,
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
)
)
logger.info("Synapse federation_sender now listening on port %d", port)
@@ -125,18 +111,15 @@ class FederationSenderServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=address
_base.listen_tcp(
listener["bind_addresses"],
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
)
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])

View File

@@ -118,19 +118,6 @@ class FrontendProxySlavedStore(
class FrontendProxyServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = FrontendProxySlavedStore(self.get_db_conn(), self)
@@ -157,17 +144,16 @@ class FrontendProxyServer(HomeServer):
root_resource = create_resource_tree(resources, Resource())
for address in bind_addresses:
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=address
_base.listen_tcp(
bind_addresses,
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
)
)
logger.info("Synapse client reader now listening on port %d", port)
@@ -176,18 +162,15 @@ class FrontendProxyServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=address
_base.listen_tcp(
listener["bind_addresses"],
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
)
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])

View File

@@ -25,7 +25,7 @@ from synapse.api.urls import CONTENT_REPO_PREFIX, FEDERATION_PREFIX, \
LEGACY_MEDIA_PREFIX, MEDIA_PREFIX, SERVER_KEY_PREFIX, SERVER_KEY_V2_PREFIX, \
STATIC_PREFIX, WEB_CLIENT_PREFIX
from synapse.app import _base
from synapse.app._base import quit_with_error
from synapse.app._base import quit_with_error, listen_ssl, listen_tcp
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
@@ -130,30 +130,29 @@ class SynapseHomeServer(HomeServer):
root_resource = create_resource_tree(resources, root_resource)
if tls:
for address in bind_addresses:
reactor.listenSSL(
port,
SynapseSite(
"synapse.access.https.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
self.tls_server_context_factory,
interface=address
)
listen_ssl(
bind_addresses,
port,
SynapseSite(
"synapse.access.https.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
self.tls_server_context_factory,
)
else:
for address in bind_addresses:
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=address
listen_tcp(
bind_addresses,
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
)
)
logger.info("Synapse now listening on port %d", port)
def _configure_named_resource(self, name, compress=False):
@@ -229,18 +228,15 @@ class SynapseHomeServer(HomeServer):
if listener["type"] == "http":
self._listener_http(config, listener)
elif listener["type"] == "manhole":
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=address
listen_tcp(
listener["bind_addresses"],
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
)
)
elif listener["type"] == "replication":
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
@@ -270,19 +266,6 @@ class SynapseHomeServer(HomeServer):
except IncorrectDatabaseSetup as e:
quit_with_error(e.message)
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(config_options):
"""

View File

@@ -60,19 +60,6 @@ class MediaRepositorySlavedStore(
class MediaRepositoryServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self)
@@ -99,17 +86,16 @@ class MediaRepositoryServer(HomeServer):
root_resource = create_resource_tree(resources, Resource())
for address in bind_addresses:
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=address
_base.listen_tcp(
bind_addresses,
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
)
)
logger.info("Synapse media repository now listening on port %d", port)
@@ -118,18 +104,15 @@ class MediaRepositoryServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=address
_base.listen_tcp(
listener["bind_addresses"],
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
)
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])

View File

@@ -81,19 +81,6 @@ class PusherSlaveStore(
class PusherServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = PusherSlaveStore(self.get_db_conn(), self)
@@ -114,17 +101,16 @@ class PusherServer(HomeServer):
root_resource = create_resource_tree(resources, Resource())
for address in bind_addresses:
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=address
_base.listen_tcp(
bind_addresses,
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
)
)
logger.info("Synapse pusher now listening on port %d", port)
@@ -133,18 +119,15 @@ class PusherServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=address
_base.listen_tcp(
listener["bind_addresses"],
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
)
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])

View File

@@ -246,19 +246,6 @@ class SynchrotronApplicationService(object):
class SynchrotronServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = SynchrotronSlavedStore(self.get_db_conn(), self)
@@ -288,17 +275,16 @@ class SynchrotronServer(HomeServer):
root_resource = create_resource_tree(resources, Resource())
for address in bind_addresses:
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=address
_base.listen_tcp(
bind_addresses,
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
)
)
logger.info("Synapse synchrotron now listening on port %d", port)
@@ -307,18 +293,15 @@ class SynchrotronServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=address
_base.listen_tcp(
listener["bind_addresses"],
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
)
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])

View File

@@ -184,6 +184,9 @@ def main():
worker_configfiles.append(worker_configfile)
if options.all_processes:
# To start the main synapse with -a you need to add a worker file
# with worker_app == "synapse.app.homeserver"
start_stop_synapse = False
worker_configdir = options.all_processes
if not os.path.isdir(worker_configdir):
write(
@@ -200,11 +203,29 @@ def main():
with open(worker_configfile) as stream:
worker_config = yaml.load(stream)
worker_app = worker_config["worker_app"]
worker_pidfile = worker_config["worker_pid_file"]
worker_daemonize = worker_config["worker_daemonize"]
assert worker_daemonize, "In config %r: expected '%s' to be True" % (
worker_configfile, "worker_daemonize")
worker_cache_factor = worker_config.get("synctl_cache_factor")
if worker_app == "synapse.app.homeserver":
# We need to special case all of this to pick up options that may
# be set in the main config file or in this worker config file.
worker_pidfile = (
worker_config.get("pid_file")
or pidfile
)
worker_cache_factor = worker_config.get("synctl_cache_factor") or cache_factor
daemonize = worker_config.get("daemonize") or config.get("daemonize")
assert daemonize, "Main process must have daemonize set to true"
# The master process doesn't support using worker_* config.
for key in worker_config:
if key == "worker_app": # But we allow worker_app
continue
assert not key.startswith("worker_"), \
"Main process cannot use worker_* config"
else:
worker_pidfile = worker_config["worker_pid_file"]
worker_daemonize = worker_config["worker_daemonize"]
assert worker_daemonize, "In config %r: expected '%s' to be True" % (
worker_configfile, "worker_daemonize")
worker_cache_factor = worker_config.get("synctl_cache_factor")
workers.append(Worker(
worker_app, worker_configfile, worker_pidfile, worker_cache_factor,
))
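
With this change synctl -a can also manage the main process, provided one of the worker config files sets worker_app to "synapse.app.homeserver". Per the checks above, such a file must not use any other worker_* option; daemonize has to be true either there or in the main config, and pid_file and synctl_cache_factor fall back to the main config's values. A hypothetical example of such a file (path and values are illustrative):

    # workers/master.yaml -- lets synctl -a start/stop the main synapse too
    worker_app: synapse.app.homeserver
    daemonize: true
    # Optional overrides; otherwise the values from the main homeserver.yaml apply:
    # pid_file: /var/run/matrix-synapse.pid
    # synctl_cache_factor: 2.0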

View File

@@ -92,19 +92,6 @@ class UserDirectorySlaveStore(
class UserDirectoryServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = UserDirectorySlaveStore(self.get_db_conn(), self)
@@ -131,17 +118,16 @@ class UserDirectoryServer(HomeServer):
root_resource = create_resource_tree(resources, Resource())
for address in bind_addresses:
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=address
_base.listen_tcp(
bind_addresses,
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
)
)
logger.info("Synapse user_dir now listening on port %d", port)
@@ -150,18 +136,15 @@ class UserDirectoryServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=address
_base.listen_tcp(
listener["bind_addresses"],
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
)
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])

View File

@@ -28,27 +28,27 @@ DEFAULT_LOG_CONFIG = Template("""
version: 1
formatters:
precise:
format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s\
- %(message)s'
precise:
format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - \
%(request)s - %(message)s'
filters:
context:
(): synapse.util.logcontext.LoggingContextFilter
request: ""
context:
(): synapse.util.logcontext.LoggingContextFilter
request: ""
handlers:
file:
class: logging.handlers.RotatingFileHandler
formatter: precise
filename: ${log_file}
maxBytes: 104857600
backupCount: 10
filters: [context]
console:
class: logging.StreamHandler
formatter: precise
filters: [context]
file:
class: logging.handlers.RotatingFileHandler
formatter: precise
filename: ${log_file}
maxBytes: 104857600
backupCount: 10
filters: [context]
console:
class: logging.StreamHandler
formatter: precise
filters: [context]
loggers:
synapse:
@@ -74,17 +74,10 @@ class LoggingConfig(Config):
self.log_file = self.abspath(config.get("log_file"))
def default_config(self, config_dir_path, server_name, **kwargs):
log_file = self.abspath("homeserver.log")
log_config = self.abspath(
os.path.join(config_dir_path, server_name + ".log.config")
)
return """
# Logging verbosity level. Ignored if log_config is specified.
verbose: 0
# File to write logging to. Ignored if log_config is specified.
log_file: "%(log_file)s"
# A yaml python logging config file
log_config: "%(log_config)s"
""" % locals()
@@ -123,9 +116,10 @@ class LoggingConfig(Config):
def generate_files(self, config):
log_config = config.get("log_config")
if log_config and not os.path.exists(log_config):
log_file = self.abspath("homeserver.log")
with open(log_config, "wb") as log_config_file:
log_config_file.write(
DEFAULT_LOG_CONFIG.substitute(log_file=config["log_file"])
DEFAULT_LOG_CONFIG.substitute(log_file=log_file)
)
@@ -150,6 +144,9 @@ def setup_logging(config, use_worker_options=False):
)
if log_config is None:
# We don't have a logfile, so fall back to the 'verbosity' param from
# the config or cmdline. (Note that we generate a log config for new
# installs, so this will be an unusual case)
level = logging.INFO
level_for_storage = logging.INFO
if config.verbosity:
@@ -157,11 +154,10 @@ def setup_logging(config, use_worker_options=False):
if config.verbosity > 1:
level_for_storage = logging.DEBUG
# FIXME: we need a logging.WARN for a -q quiet option
logger = logging.getLogger('')
logger.setLevel(level)
logging.getLogger('synapse.storage').setLevel(level_for_storage)
logging.getLogger('synapse.storage.SQL').setLevel(level_for_storage)
formatter = logging.Formatter(log_format)
if log_file:

View File

@@ -29,10 +29,10 @@ class PasswordAuthProviderConfig(Config):
# param.
ldap_config = config.get("ldap_config", {})
if ldap_config.get("enabled", False):
providers.append[{
providers.append({
'module': LDAP_PROVIDER,
'config': ldap_config,
}]
})
providers.extend(config.get("password_providers", []))
for provider in providers:
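
The fix above turns an accidental index operation (providers.append[...]) into the intended call. That branch is only reached when the legacy ldap_config section is enabled, for example (values illustrative; any further keys are handed to the LDAP provider unchanged as its config):

    ldap_config:
      enabled: true
      # e.g. uri, base, attributes, ... as understood by the LDAP auth provider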

View File

@@ -31,6 +31,8 @@ class RegistrationConfig(Config):
strtobool(str(config["disable_registration"]))
)
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
self.allowed_local_3pids = config.get("allowed_local_3pids", [])
self.registration_shared_secret = config.get("registration_shared_secret")
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
@@ -52,6 +54,23 @@ class RegistrationConfig(Config):
# Enable registration for new users.
enable_registration: False
# The user must provide all of the below types of 3PID when registering.
#
# registrations_require_3pid:
# - email
# - msisdn
# Mandate that users are only allowed to associate certain formats of
# 3PIDs with accounts on this server.
#
# allowed_local_3pids:
# - medium: email
# pattern: ".*@matrix\\.org"
# - medium: email
# pattern: ".*@vector\\.im"
# - medium: msisdn
# pattern: "\\+44"
# If set, allows registration by anyone who also has the shared
# secret, even if registration is otherwise disabled.
registration_shared_secret: "%(registration_shared_secret)s"

View File

@@ -16,6 +16,8 @@
from ._base import Config, ConfigError
from collections import namedtuple
from synapse.util.module_loader import load_module
MISSING_NETADDR = (
"Missing netaddr library. This is required for URL preview API."
@@ -36,6 +38,14 @@ ThumbnailRequirement = namedtuple(
"ThumbnailRequirement", ["width", "height", "method", "media_type"]
)
MediaStorageProviderConfig = namedtuple(
"MediaStorageProviderConfig", (
"store_local", # Whether to store newly uploaded local files
"store_remote", # Whether to store newly downloaded remote files
"store_synchronous", # Whether to wait for successful storage for local uploads
),
)
def parse_thumbnail_requirements(thumbnail_sizes):
""" Takes a list of dictionaries with "width", "height", and "method" keys
@@ -73,16 +83,61 @@ class ContentRepositoryConfig(Config):
self.media_store_path = self.ensure_directory(config["media_store_path"])
self.backup_media_store_path = config.get("backup_media_store_path")
if self.backup_media_store_path:
self.backup_media_store_path = self.ensure_directory(
self.backup_media_store_path
)
backup_media_store_path = config.get("backup_media_store_path")
self.synchronous_backup_media_store = config.get(
synchronous_backup_media_store = config.get(
"synchronous_backup_media_store", False
)
storage_providers = config.get("media_storage_providers", [])
if backup_media_store_path:
if storage_providers:
raise ConfigError(
"Cannot use both 'backup_media_store_path' and 'storage_providers'"
)
storage_providers = [{
"module": "file_system",
"store_local": True,
"store_synchronous": synchronous_backup_media_store,
"store_remote": True,
"config": {
"directory": backup_media_store_path,
}
}]
# This is a list of config that can be used to create the storage
# providers. The entries are tuples of (Class, class_config,
# MediaStorageProviderConfig), where Class is the class of the provider,
# the class_config the config to pass to it, and
# MediaStorageProviderConfig are options for StorageProviderWrapper.
#
# We don't create the storage providers here as not all workers need
# them to be started.
self.media_storage_providers = []
for provider_config in storage_providers:
# We special case the module "file_system" so as not to need to
# expose FileStorageProviderBackend
if provider_config["module"] == "file_system":
provider_config["module"] = (
"synapse.rest.media.v1.storage_provider"
".FileStorageProviderBackend"
)
provider_class, parsed_config = load_module(provider_config)
wrapper_config = MediaStorageProviderConfig(
provider_config.get("store_local", False),
provider_config.get("store_remote", False),
provider_config.get("store_synchronous", False),
)
self.media_storage_providers.append(
(provider_class, parsed_config, wrapper_config,)
)
self.uploads_path = self.ensure_directory(config["uploads_path"])
self.dynamic_thumbnails = config["dynamic_thumbnails"]
self.thumbnail_requirements = parse_thumbnail_requirements(
@@ -127,13 +182,19 @@ class ContentRepositoryConfig(Config):
# Directory where uploaded images and attachments are stored.
media_store_path: "%(media_store)s"
# A secondary directory where uploaded images and attachments are
# stored as a backup.
# backup_media_store_path: "%(media_store)s"
# Whether to wait for successful write to backup media store before
# returning successfully.
# synchronous_backup_media_store: false
# Media storage providers allow media to be stored in different
# locations.
# media_storage_providers:
# - module: file_system
# # Whether to write new local files.
# store_local: false
# # Whether to write new remote media
# store_remote: false
# # Whether to block upload requests waiting for write to this
# # provider to complete
# store_synchronous: false
# config:
# directory: /mnt/some/other/directory
# Directory where in-progress uploads are stored.
uploads_path: "%(uploads_path)s"
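
As the parsing code above shows, a legacy backup_media_store_path is rewritten into an equivalent media_storage_providers entry, which is why the two options cannot be combined. In YAML terms the translation is roughly (directory value illustrative):

    # Old style
    # backup_media_store_path: /data/media_store_backup
    # synchronous_backup_media_store: true

    # Equivalent new style
    media_storage_providers:
      - module: file_system
        store_local: true
        store_remote: true
        store_synchronous: true
        config:
          directory: /data/media_store_backup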

View File

@@ -55,6 +55,17 @@ class ServerConfig(Config):
"block_non_admin_invites", False,
)
# FIXME: federation_domain_whitelist needs sytests
self.federation_domain_whitelist = None
federation_domain_whitelist = config.get(
"federation_domain_whitelist", None
)
# turn the whitelist into a hash for speed of lookup
if federation_domain_whitelist is not None:
self.federation_domain_whitelist = {}
for domain in federation_domain_whitelist:
self.federation_domain_whitelist[domain] = True
if self.public_baseurl is not None:
if self.public_baseurl[-1] != '/':
self.public_baseurl += '/'
@@ -210,6 +221,17 @@ class ServerConfig(Config):
# (except those sent by local server admins). The default is False.
# block_non_admin_invites: True
# Restrict federation to the following whitelist of domains.
# N.B. we recommend also firewalling your federation listener to limit
# inbound federation traffic as early as possible, rather than relying
# purely on this application-layer restriction. If not specified, the
# default is to whitelist everything.
#
# federation_domain_whitelist:
# - lon.example.com
# - nyc.example.com
# - syd.example.com
# List of ports that Synapse should listen on, their purpose and their
# configuration.
listeners:
@@ -220,13 +242,12 @@ class ServerConfig(Config):
port: %(bind_port)s
# Local addresses to listen on.
# This will listen on all IPv4 addresses by default.
# On Linux and Mac OS, `::` will listen on all IPv4 and IPv6
# addresses by default. For most other OSes, this will only listen
# on IPv6.
bind_addresses:
- '::'
- '0.0.0.0'
# Uncomment to listen on all IPv6 interfaces
# N.B: On at least Linux this will also listen on all IPv4
# addresses, so you will need to comment out the line above.
# - '::'
# This is a 'http' listener, allows us to specify 'resources'.
type: http
@@ -264,7 +285,7 @@ class ServerConfig(Config):
# For when matrix traffic passes through loadbalancer that unwraps TLS.
- port: %(unsecure_port)s
tls: false
bind_addresses: ['0.0.0.0']
bind_addresses: ['::', '0.0.0.0']
type: http
x_forwarded: false
@@ -278,7 +299,7 @@ class ServerConfig(Config):
# Turn on the twisted ssh manhole service on localhost on the given
# port.
# - port: 9000
# bind_address: 127.0.0.1
# bind_addresses: ['::1', '127.0.0.1']
# type: manhole
""" % locals()

View File

@@ -96,7 +96,7 @@ class TlsConfig(Config):
# certificates returned by this server match one of the fingerprints.
#
# Synapse automatically adds the fingerprint of its own certificate
# to the list. So if federation traffic is handle directly by synapse
# to the list. So if federation traffic is handled directly by synapse
# then no modification to the list is required.
#
# If synapse is run behind a load balancer that handles the TLS then it

View File

@@ -23,6 +23,11 @@ class WorkerConfig(Config):
def read_config(self, config):
self.worker_app = config.get("worker_app")
# Canonicalise worker_app so that master always has None
if self.worker_app == "synapse.app.homeserver":
self.worker_app = None
self.worker_listeners = config.get("worker_listeners")
self.worker_daemonize = config.get("worker_daemonize")
self.worker_pid_file = config.get("worker_pid_file")

View File

@@ -319,7 +319,7 @@ def _is_membership_change_allowed(event, auth_events):
# TODO (erikj): Implement kicks.
if target_banned and user_level < ban_level:
raise AuthError(
403, "You cannot unban user &s." % (target_user_id,)
403, "You cannot unban user %s." % (target_user_id,)
)
elif target_user_id != event.user_id:
kick_level = _get_named_level(auth_events, "kick", 50)

View File

@@ -25,7 +25,9 @@ class EventContext(object):
The current state map excluding the current event.
(type, state_key) -> event_id
state_group (int): state group id
state_group (int|None): state group id, if the state has been stored
as a state group. This is usually only None if e.g. the event is
an outlier.
rejected (bool|str): A rejection reason if the event was rejected, else
False

View File

@@ -16,7 +16,9 @@ import logging
from synapse.api.errors import SynapseError
from synapse.crypto.event_signing import check_event_content_hash
from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
from synapse.http.servlet import assert_params_in_request
from synapse.util import unwrapFirstError, logcontext
from twisted.internet import defer
@@ -169,3 +171,28 @@ class FederationBase(object):
)
return deferreds
def event_from_pdu_json(pdu_json, outlier=False):
"""Construct a FrozenEvent from an event json received over federation
Args:
pdu_json (object): pdu as received over federation
outlier (bool): True to mark this event as an outlier
Returns:
FrozenEvent
Raises:
SynapseError: if the pdu is missing required fields
"""
# we could probably enforce a bunch of other fields here (room_id, sender,
# origin, etc etc)
assert_params_in_request(pdu_json, ('event_id', 'type'))
event = FrozenEvent(
pdu_json
)
event.internal_metadata.outlier = outlier
return event

View File

@@ -14,28 +14,28 @@
# limitations under the License.
from twisted.internet import defer
from .federation_base import FederationBase
from synapse.api.constants import Membership
from synapse.api.errors import (
CodeMessageException, HttpResponseException, SynapseError,
)
from synapse.util import unwrapFirstError, logcontext
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
from synapse.events import FrozenEvent, builder
import synapse.metrics
from synapse.util.retryutils import NotRetryingDestination
import copy
import itertools
import logging
import random
from twisted.internet import defer
from synapse.api.constants import Membership
from synapse.api.errors import (
CodeMessageException, HttpResponseException, SynapseError, FederationDeniedError
)
from synapse.events import builder
from synapse.federation.federation_base import (
FederationBase,
event_from_pdu_json,
)
import synapse.metrics
from synapse.util import logcontext, unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
from synapse.util.logutils import log_function
from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__)
@@ -184,7 +184,7 @@ class FederationClient(FederationBase):
logger.debug("backfill transaction_data=%s", repr(transaction_data))
pdus = [
self.event_from_pdu_json(p, outlier=False)
event_from_pdu_json(p, outlier=False)
for p in transaction_data["pdus"]
]
@@ -244,7 +244,7 @@ class FederationClient(FederationBase):
logger.debug("transaction_data %r", transaction_data)
pdu_list = [
self.event_from_pdu_json(p, outlier=outlier)
event_from_pdu_json(p, outlier=outlier)
for p in transaction_data["pdus"]
]
@@ -266,6 +266,9 @@ class FederationClient(FederationBase):
except NotRetryingDestination as e:
logger.info(e.message)
continue
except FederationDeniedError as e:
logger.info(e.message)
continue
except Exception as e:
pdu_attempts[destination] = now
@@ -336,11 +339,11 @@ class FederationClient(FederationBase):
)
pdus = [
self.event_from_pdu_json(p, outlier=True) for p in result["pdus"]
event_from_pdu_json(p, outlier=True) for p in result["pdus"]
]
auth_chain = [
self.event_from_pdu_json(p, outlier=True)
event_from_pdu_json(p, outlier=True)
for p in result.get("auth_chain", [])
]
@@ -441,7 +444,7 @@ class FederationClient(FederationBase):
)
auth_chain = [
self.event_from_pdu_json(p, outlier=True)
event_from_pdu_json(p, outlier=True)
for p in res["auth_chain"]
]
@@ -570,12 +573,12 @@ class FederationClient(FederationBase):
logger.debug("Got content: %s", content)
state = [
self.event_from_pdu_json(p, outlier=True)
event_from_pdu_json(p, outlier=True)
for p in content.get("state", [])
]
auth_chain = [
self.event_from_pdu_json(p, outlier=True)
event_from_pdu_json(p, outlier=True)
for p in content.get("auth_chain", [])
]
@@ -650,7 +653,7 @@ class FederationClient(FederationBase):
logger.debug("Got response to send_invite: %s", pdu_dict)
pdu = self.event_from_pdu_json(pdu_dict)
pdu = event_from_pdu_json(pdu_dict)
# Check signatures are correct.
pdu = yield self._check_sigs_and_hash(pdu)
@@ -740,7 +743,7 @@ class FederationClient(FederationBase):
)
auth_chain = [
self.event_from_pdu_json(e)
event_from_pdu_json(e)
for e in content["auth_chain"]
]
@@ -788,7 +791,7 @@ class FederationClient(FederationBase):
)
events = [
self.event_from_pdu_json(e)
event_from_pdu_json(e)
for e in content.get("events", [])
]
@@ -805,15 +808,6 @@ class FederationClient(FederationBase):
defer.returnValue(signed_events)
def event_from_pdu_json(self, pdu_json, outlier=False):
event = FrozenEvent(
pdu_json
)
event.internal_metadata.outlier = outlier
return event
@defer.inlineCallbacks
def forward_third_party_invite(self, destinations, room_id, event_dict):
for destination in destinations:

View File

@@ -12,25 +12,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from .federation_base import FederationBase
from .units import Transaction, Edu
from synapse.util import async
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
from synapse.util.logutils import log_function
from synapse.util.caches.response_cache import ResponseCache
from synapse.events import FrozenEvent
from synapse.types import get_domain_from_id
import synapse.metrics
from synapse.api.errors import AuthError, FederationError, SynapseError
from synapse.crypto.event_signing import compute_event_signature
import logging
import simplejson as json
import logging
from twisted.internet import defer
from synapse.api.errors import AuthError, FederationError, SynapseError
from synapse.crypto.event_signing import compute_event_signature
from synapse.federation.federation_base import (
FederationBase,
event_from_pdu_json,
)
from synapse.federation.units import Edu, Transaction
import synapse.metrics
from synapse.types import get_domain_from_id
from synapse.util import async
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
from synapse.util.logutils import log_function
# when processing incoming transactions, we try to handle multiple rooms in
# parallel, up to this limit.
@@ -172,7 +171,7 @@ class FederationServer(FederationBase):
p["age_ts"] = request_time - int(p["age"])
del p["age"]
event = self.event_from_pdu_json(p)
event = event_from_pdu_json(p)
room_id = event.room_id
pdus_by_room.setdefault(room_id, []).append(event)
@@ -346,7 +345,7 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
def on_invite_request(self, origin, content):
pdu = self.event_from_pdu_json(content)
pdu = event_from_pdu_json(content)
ret_pdu = yield self.handler.on_invite_request(origin, pdu)
time_now = self._clock.time_msec()
defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)}))
@@ -354,7 +353,7 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
def on_send_join_request(self, origin, content):
logger.debug("on_send_join_request: content: %s", content)
pdu = self.event_from_pdu_json(content)
pdu = event_from_pdu_json(content)
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
res_pdus = yield self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec()
@@ -374,7 +373,7 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
def on_send_leave_request(self, origin, content):
logger.debug("on_send_leave_request: content: %s", content)
pdu = self.event_from_pdu_json(content)
pdu = event_from_pdu_json(content)
logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
yield self.handler.on_send_leave_request(origin, pdu)
defer.returnValue((200, {}))
@@ -411,7 +410,7 @@ class FederationServer(FederationBase):
"""
with (yield self._server_linearizer.queue((origin, room_id))):
auth_chain = [
self.event_from_pdu_json(e)
event_from_pdu_json(e)
for e in content["auth_chain"]
]
@@ -586,15 +585,6 @@ class FederationServer(FederationBase):
def __str__(self):
return "<ReplicationLayer(%s)>" % self.server_name
def event_from_pdu_json(self, pdu_json, outlier=False):
event = FrozenEvent(
pdu_json
)
event.internal_metadata.outlier = outlier
return event
@defer.inlineCallbacks
def exchange_third_party_invite(
self,

View File

@@ -19,7 +19,7 @@ from twisted.internet import defer
from .persistence import TransactionActions
from .units import Transaction, Edu
from synapse.api.errors import HttpResponseException
from synapse.api.errors import HttpResponseException, FederationDeniedError
from synapse.util import logcontext, PreserveLoggingContext
from synapse.util.async import run_on_reactor
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
@@ -42,6 +42,8 @@ sent_edus_counter = client_metrics.register_counter("sent_edus")
sent_transactions_counter = client_metrics.register_counter("sent_transactions")
events_processed_counter = client_metrics.register_counter("events_processed")
class TransactionQueue(object):
"""This class makes sure we only have one transaction in flight at
@@ -205,6 +207,8 @@ class TransactionQueue(object):
self._send_pdu(event, destinations)
events_processed_counter.inc_by(len(events))
yield self.store.update_federation_out_pos(
"events", next_token
)
@@ -486,6 +490,8 @@ class TransactionQueue(object):
(e.retry_last_ts + e.retry_interval) / 1000.0
),
)
except FederationDeniedError as e:
logger.info(e)
except Exception as e:
logger.warn(
"TX [%s] Failed to send transaction: %s",

View File

@@ -212,6 +212,9 @@ class TransportLayerClient(object):
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
Fails with ``FederationDeniedError`` if the remote destination
is not in our federation whitelist
"""
valid_memberships = {Membership.JOIN, Membership.LEAVE}
if membership not in valid_memberships:

View File

@@ -16,7 +16,7 @@
from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError
from synapse.api.errors import Codes, SynapseError, FederationDeniedError
from synapse.http.server import JsonResource
from synapse.http.servlet import (
parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
@@ -81,6 +81,7 @@ class Authenticator(object):
self.keyring = hs.get_keyring()
self.server_name = hs.hostname
self.store = hs.get_datastore()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
# A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks
@@ -92,6 +93,12 @@ class Authenticator(object):
"signatures": {},
}
if (
self.federation_domain_whitelist is not None and
self.server_name not in self.federation_domain_whitelist
):
raise FederationDeniedError(self.server_name)
if content is not None:
json_request["content"] = content
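
The whitelist check above rejects a request by raising FederationDeniedError, the same exception the other diffs in this set catch around outbound federation calls. Its definition lives in synapse.api.errors and is not shown here; a plausible shape, assuming it subclasses SynapseError and uses Codes.FORBIDDEN from the same module:

    class FederationDeniedError(SynapseError):
        """Raised when federation with a destination is blocked by the
        federation_domain_whitelist (a sketch; the real definition may differ).
        """

        def __init__(self, destination):
            self.destination = destination

            super(FederationDeniedError, self).__init__(
                code=403,
                msg="Federation denied with %s." % (self.destination,),
                errcode=Codes.FORBIDDEN,
            )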

View File

@@ -15,6 +15,7 @@
from twisted.internet import defer
import synapse
from synapse.api.constants import EventTypes
from synapse.util.metrics import Measure
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
@@ -23,6 +24,10 @@ import logging
logger = logging.getLogger(__name__)
metrics = synapse.metrics.get_metrics_for(__name__)
events_processed_counter = metrics.register_counter("events_processed")
def log_failure(failure):
logger.error(
@@ -103,6 +108,8 @@ class ApplicationServicesHandler(object):
service, event
)
events_processed_counter.inc_by(len(events))
yield self.store.set_appservice_last_pos(upper_bound)
finally:
self.is_processing = False

View File

@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from twisted.internet import defer, threads
from ._base import BaseHandler
from synapse.api.constants import LoginType
@@ -25,6 +25,7 @@ from synapse.module_api import ModuleApi
from synapse.types import UserID
from synapse.util.async import run_on_reactor
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logcontext import make_deferred_yieldable
from twisted.web.client import PartialDownloadError
@@ -714,7 +715,7 @@ class AuthHandler(BaseHandler):
if not lookupres:
defer.returnValue(None)
(user_id, password_hash) = lookupres
result = self.validate_hash(password, password_hash)
result = yield self.validate_hash(password, password_hash)
if not result:
logger.warn("Failed password login for user %s", user_id)
defer.returnValue(None)
@@ -842,10 +843,13 @@ class AuthHandler(BaseHandler):
password (str): Password to hash.
Returns:
Hashed password (str).
Deferred(str): Hashed password.
"""
return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
bcrypt.gensalt(self.bcrypt_rounds))
def _do_hash():
return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
bcrypt.gensalt(self.bcrypt_rounds))
return make_deferred_yieldable(threads.deferToThread(_do_hash))
def validate_hash(self, password, stored_hash):
"""Validates that self.hash(password) == stored_hash.
@@ -855,13 +859,17 @@ class AuthHandler(BaseHandler):
stored_hash (str): Expected hash value.
Returns:
Whether self.hash(password) == stored_hash (bool).
Deferred(bool): Whether self.hash(password) == stored_hash.
"""
if stored_hash:
def _do_validate_hash():
return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
stored_hash.encode('utf8')) == stored_hash
if stored_hash:
return make_deferred_yieldable(threads.deferToThread(_do_validate_hash))
else:
return False
return defer.succeed(False)
class MacaroonGeneartor(object):
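
Because hash() and validate_hash() now return Deferreds (the bcrypt work is pushed onto a thread pool), every caller has to yield them, as the validate_password change above already does. A minimal illustration of the new calling convention (everything apart from the handler method is hypothetical):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def check_login(auth_handler, password, stored_hash):
        # validate_hash now resolves off the reactor thread; yield for the bool.
        valid = yield auth_handler.validate_hash(password, stored_hash)
        defer.returnValue(valid)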

View File

@@ -14,6 +14,7 @@
# limitations under the License.
from synapse.api import errors
from synapse.api.constants import EventTypes
from synapse.api.errors import FederationDeniedError
from synapse.util import stringutils
from synapse.util.async import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
@@ -513,6 +514,9 @@ class DeviceListEduUpdater(object):
# This makes it more likely that the device lists will
# eventually become consistent.
return
except FederationDeniedError as e:
logger.info(e)
return
except Exception:
# TODO: Remember that we are now out of sync and try again
# later

View File

@@ -17,7 +17,8 @@ import logging
from twisted.internet import defer
from synapse.types import get_domain_from_id
from synapse.api.errors import SynapseError
from synapse.types import get_domain_from_id, UserID
from synapse.util.stringutils import random_string
@@ -33,7 +34,7 @@ class DeviceMessageHandler(object):
"""
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
self.is_mine = hs.is_mine
self.federation = hs.get_federation_sender()
hs.get_replication_layer().register_edu_handler(
@@ -52,6 +53,12 @@ class DeviceMessageHandler(object):
message_type = content["type"]
message_id = content["message_id"]
for user_id, by_device in content["messages"].items():
# we use UserID.from_string to catch invalid user ids
if not self.is_mine(UserID.from_string(user_id)):
logger.warning("Request for keys for non-local user %s",
user_id)
raise SynapseError(400, "Not a user here")
messages_by_device = {
device_id: {
"content": message_content,
@@ -77,7 +84,8 @@ class DeviceMessageHandler(object):
local_messages = {}
remote_messages = {}
for user_id, by_device in messages.items():
if self.is_mine_id(user_id):
# we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
messages_by_device = {
device_id: {
"content": message_content,

View File

@@ -34,6 +34,7 @@ class DirectoryHandler(BaseHandler):
self.state = hs.get_state_handler()
self.appservice_handler = hs.get_application_service_handler()
self.event_creation_handler = hs.get_event_creation_handler()
self.federation = hs.get_replication_layer()
self.federation.register_query_handler(
@@ -249,8 +250,7 @@ class DirectoryHandler(BaseHandler):
def send_room_alias_update_event(self, requester, user_id, room_id):
aliases = yield self.store.get_aliases_for_room(room_id)
msg_handler = self.hs.get_handlers().message_handler
yield msg_handler.create_and_send_nonmember_event(
yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Aliases,
@@ -272,8 +272,7 @@ class DirectoryHandler(BaseHandler):
if not alias_event or alias_event.content.get("alias", "") != alias_str:
return
msg_handler = self.hs.get_handlers().message_handler
yield msg_handler.create_and_send_nonmember_event(
yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.CanonicalAlias,

View File

@@ -19,8 +19,10 @@ import logging
from canonicaljson import encode_canonical_json
from twisted.internet import defer
from synapse.api.errors import SynapseError, CodeMessageException
from synapse.types import get_domain_from_id
from synapse.api.errors import (
SynapseError, CodeMessageException, FederationDeniedError,
)
from synapse.types import get_domain_from_id, UserID
from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
from synapse.util.retryutils import NotRetryingDestination
@@ -32,7 +34,7 @@ class E2eKeysHandler(object):
self.store = hs.get_datastore()
self.federation = hs.get_replication_layer()
self.device_handler = hs.get_device_handler()
self.is_mine_id = hs.is_mine_id
self.is_mine = hs.is_mine
self.clock = hs.get_clock()
# doesn't really work as part of the generic query API, because the
@@ -70,7 +72,8 @@ class E2eKeysHandler(object):
remote_queries = {}
for user_id, device_ids in device_keys_query.items():
if self.is_mine_id(user_id):
# we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
local_query[user_id] = device_ids
else:
remote_queries[user_id] = device_ids
@@ -139,6 +142,10 @@ class E2eKeysHandler(object):
failures[destination] = {
"status": 503, "message": "Not ready for retry",
}
except FederationDeniedError as e:
failures[destination] = {
"status": 403, "message": "Federation Denied",
}
except Exception as e:
# include ConnectionRefused and other errors
failures[destination] = {
@@ -170,7 +177,8 @@ class E2eKeysHandler(object):
result_dict = {}
for user_id, device_ids in query.items():
if not self.is_mine_id(user_id):
# we use UserID.from_string to catch invalid user ids
if not self.is_mine(UserID.from_string(user_id)):
logger.warning("Request for keys for non-local user %s",
user_id)
raise SynapseError(400, "Not a user here")
@@ -213,7 +221,8 @@ class E2eKeysHandler(object):
remote_queries = {}
for user_id, device_keys in query.get("one_time_keys", {}).items():
if self.is_mine_id(user_id):
# we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
for device_id, algorithm in device_keys.items():
local_query.append((user_id, device_id, algorithm))
else:

View File

@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,6 +23,7 @@ from ._base import BaseHandler
from synapse.api.errors import (
AuthError, FederationError, StoreError, CodeMessageException, SynapseError,
FederationDeniedError,
)
from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.events.validator import EventValidator
@@ -74,6 +76,7 @@ class FederationHandler(BaseHandler):
self.is_mine_id = hs.is_mine_id
self.pusher_pool = hs.get_pusherpool()
self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
self.replication_layer.set_handler(self)
@@ -782,6 +785,9 @@ class FederationHandler(BaseHandler):
except NotRetryingDestination as e:
logger.info(e.message)
continue
except FederationDeniedError as e:
logger.info(e)
continue
except Exception as e:
logger.exception(
"Failed to backfill from %s because %s",
@@ -804,13 +810,12 @@ class FederationHandler(BaseHandler):
event_ids = list(extremities.keys())
logger.debug("calling resolve_state_groups in _maybe_backfill")
resolve = logcontext.preserve_fn(
self.state_handler.resolve_state_groups_for_events
)
states = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
logcontext.preserve_fn(self.state_handler.resolve_state_groups)(
room_id, [e]
)
for e in event_ids
], consumeErrors=True,
[resolve(room_id, [e]) for e in event_ids],
consumeErrors=True,
))
states = dict(zip(event_ids, [s.state for s in states]))
@@ -1004,8 +1009,7 @@ class FederationHandler(BaseHandler):
})
try:
message_handler = self.hs.get_handlers().message_handler
event, context = yield message_handler._create_new_client_event(
event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder,
)
except AuthError as e:
@@ -1245,8 +1249,7 @@ class FederationHandler(BaseHandler):
"state_key": user_id,
})
message_handler = self.hs.get_handlers().message_handler
event, context = yield message_handler._create_new_client_event(
event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder,
)
@@ -1828,8 +1831,8 @@ class FederationHandler(BaseHandler):
current_state = set(e.event_id for e in auth_events.values())
different_auth = event_auth_events - current_state
self._update_context_for_auth_events(
context, auth_events, event_key,
yield self._update_context_for_auth_events(
event, context, auth_events, event_key,
)
if different_auth and not event.internal_metadata.is_outlier():
@@ -1910,8 +1913,8 @@ class FederationHandler(BaseHandler):
# 4. Look at rejects and their proofs.
# TODO.
self._update_context_for_auth_events(
context, auth_events, event_key,
yield self._update_context_for_auth_events(
event, context, auth_events, event_key,
)
try:
@@ -1920,11 +1923,15 @@ class FederationHandler(BaseHandler):
logger.warn("Failed auth resolution for %r because %s", event, e)
raise e
def _update_context_for_auth_events(self, context, auth_events,
@defer.inlineCallbacks
def _update_context_for_auth_events(self, event, context, auth_events,
event_key):
"""Update the state_ids in an event context after auth event resolution
"""Update the state_ids in an event context after auth event resolution,
storing the changes as a new state group.
Args:
event (Event): The event we're handling the context for
context (synapse.events.snapshot.EventContext): event context
to be updated
@@ -1947,7 +1954,13 @@ class FederationHandler(BaseHandler):
context.prev_state_ids.update({
k: a.event_id for k, a in auth_events.iteritems()
})
context.state_group = self.store.get_next_state_group()
context.state_group = yield self.store.store_state_group(
event.event_id,
event.room_id,
prev_group=context.prev_group,
delta_ids=context.delta_ids,
current_state_ids=context.current_state_ids,
)
@defer.inlineCallbacks
def construct_auth_difference(self, local_auth, remote_auth):
@@ -2117,8 +2130,7 @@ class FederationHandler(BaseHandler):
if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
builder = self.event_builder_factory.new(event_dict)
EventValidator().validate_new(builder)
message_handler = self.hs.get_handlers().message_handler
event, context = yield message_handler._create_new_client_event(
event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder
)
@@ -2156,8 +2168,7 @@ class FederationHandler(BaseHandler):
"""
builder = self.event_builder_factory.new(event_dict)
message_handler = self.hs.get_handlers().message_handler
event, context = yield message_handler._create_new_client_event(
event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder,
)
@@ -2207,8 +2218,9 @@ class FederationHandler(BaseHandler):
builder = self.event_builder_factory.new(event_dict)
EventValidator().validate_new(builder)
message_handler = self.hs.get_handlers().message_handler
event, context = yield message_handler._create_new_client_event(builder=builder)
event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder,
)
defer.returnValue((event, context))
@defer.inlineCallbacks

View File

@@ -383,11 +383,12 @@ class GroupsLocalHandler(object):
defer.returnValue({"groups": result})
else:
result = yield self.transport_client.get_publicised_groups_for_user(
get_domain_from_id(user_id), user_id
bulk_result = yield self.transport_client.bulk_get_publicised_groups(
get_domain_from_id(user_id), [user_id],
)
result = bulk_result.get("users", {}).get(user_id)
# TODO: Verify attestations
defer.returnValue(result)
defer.returnValue({"groups": result})
@defer.inlineCallbacks
def bulk_get_publicised_groups(self, user_ids, proxy=True):
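
For reference, the bulk endpoint keys its results by user ID, which is why the single-user path above now picks its entry out of the response; the shape is roughly as follows (IDs illustrative):

    bulk_result = {
        "users": {
            "@alice:example.com": ["+astronomy:example.com", "+chess:example.com"],
        },
    }
    result = bulk_result.get("users", {}).get("@alice:example.com")
    # -> the user's publicised group IDs, returned to the caller as {"groups": result}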

View File

@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
# Copyright 2017 New Vector Ltd
# Copyright 2017 - 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -47,23 +47,11 @@ class MessageHandler(BaseHandler):
self.hs = hs
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
self.validator = EventValidator()
self.profile_handler = hs.get_profile_handler()
self.pagination_lock = ReadWriteLock()
self.pusher_pool = hs.get_pusherpool()
# We arbitrarily limit concurrent event creation for a room to 5.
# This is to stop us from diverging history *too* much.
self.limiter = Limiter(max_count=5)
self.action_generator = hs.get_action_generator()
self.spam_checker = hs.get_spam_checker()
@defer.inlineCallbacks
def purge_history(self, room_id, event_id):
def purge_history(self, room_id, event_id, delete_local_events=False):
event = yield self.store.get_event(event_id)
if event.room_id != room_id:
@@ -72,7 +60,7 @@ class MessageHandler(BaseHandler):
depth = event.depth
with (yield self.pagination_lock.write(room_id)):
yield self.store.delete_old_state(room_id, depth)
yield self.store.purge_history(room_id, depth, delete_local_events)
@defer.inlineCallbacks
def get_messages(self, requester, room_id=None, pagin_config=None,
@@ -182,166 +170,6 @@ class MessageHandler(BaseHandler):
defer.returnValue(chunk)
@defer.inlineCallbacks
def create_event(self, requester, event_dict, token_id=None, txn_id=None,
prev_event_ids=None):
"""
Given a dict from a client, create a new event.
Creates an FrozenEvent object, filling out auth_events, prev_events,
etc.
Adds display names to Join membership events.
Args:
requester
event_dict (dict): An entire event
token_id (str)
txn_id (str)
prev_event_ids (list): The prev event ids to use when creating the event
Returns:
Tuple of created event (FrozenEvent), Context
"""
builder = self.event_builder_factory.new(event_dict)
with (yield self.limiter.queue(builder.room_id)):
self.validator.validate_new(builder)
if builder.type == EventTypes.Member:
membership = builder.content.get("membership", None)
target = UserID.from_string(builder.state_key)
if membership in {Membership.JOIN, Membership.INVITE}:
# If event doesn't include a display name, add one.
profile = self.profile_handler
content = builder.content
try:
if "displayname" not in content:
content["displayname"] = yield profile.get_displayname(target)
if "avatar_url" not in content:
content["avatar_url"] = yield profile.get_avatar_url(target)
except Exception as e:
logger.info(
"Failed to get profile information for %r: %s",
target, e
)
if token_id is not None:
builder.internal_metadata.token_id = token_id
if txn_id is not None:
builder.internal_metadata.txn_id = txn_id
event, context = yield self._create_new_client_event(
builder=builder,
requester=requester,
prev_event_ids=prev_event_ids,
)
defer.returnValue((event, context))
@defer.inlineCallbacks
def send_nonmember_event(self, requester, event, context, ratelimit=True):
"""
Persists and notifies local clients and federation of an event.
Args:
event (FrozenEvent) the event to send.
context (Context) the context of the event.
ratelimit (bool): Whether to rate limit this send.
is_guest (bool): Whether the sender is a guest.
"""
if event.type == EventTypes.Member:
raise SynapseError(
500,
"Tried to send member event through non-member codepath"
)
# We check here if we are currently being rate limited, so that we
# don't do unnecessary work. We check again just before we actually
# send the event.
yield self.ratelimit(requester, update=False)
user = UserID.from_string(event.sender)
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if event.is_state():
prev_state = yield self.deduplicate_state_event(event, context)
if prev_state is not None:
defer.returnValue(prev_state)
yield self.handle_new_client_event(
requester=requester,
event=event,
context=context,
ratelimit=ratelimit,
)
if event.type == EventTypes.Message:
presence = self.hs.get_presence_handler()
# We don't want to block sending messages on any presence code. This
# matters as sometimes presence code can take a while.
preserve_fn(presence.bump_presence_active_time)(user)
@defer.inlineCallbacks
def deduplicate_state_event(self, event, context):
"""
Checks whether event is in the latest resolved state in context.
If so, returns the version of the event in context.
Otherwise, returns None.
"""
prev_event_id = context.prev_state_ids.get((event.type, event.state_key))
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
if not prev_event:
return
if prev_event and event.user_id == prev_event.user_id:
prev_content = encode_canonical_json(prev_event.content)
next_content = encode_canonical_json(event.content)
if prev_content == next_content:
defer.returnValue(prev_event)
return
@defer.inlineCallbacks
def create_and_send_nonmember_event(
self,
requester,
event_dict,
ratelimit=True,
txn_id=None
):
"""
Creates an event, then sends it.
See self.create_event and self.send_nonmember_event.
"""
event, context = yield self.create_event(
requester,
event_dict,
token_id=requester.access_token_id,
txn_id=txn_id
)
spam_error = self.spam_checker.check_event_for_spam(event)
if spam_error:
if not isinstance(spam_error, basestring):
spam_error = "Spam is not permitted here"
raise SynapseError(
403, spam_error, Codes.FORBIDDEN
)
yield self.send_nonmember_event(
requester,
event,
context,
ratelimit=ratelimit,
)
defer.returnValue(event)
@defer.inlineCallbacks
def get_room_data(self, user_id=None, room_id=None,
event_type=None, state_key="", is_guest=False):
@@ -470,9 +298,192 @@ class MessageHandler(BaseHandler):
for user_id, profile in users_with_profile.iteritems()
})
@measure_func("_create_new_client_event")
class EventCreationHandler(object):
def __init__(self, hs):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
self.validator = EventValidator()
self.profile_handler = hs.get_profile_handler()
self.event_builder_factory = hs.get_event_builder_factory()
self.server_name = hs.hostname
self.ratelimiter = hs.get_ratelimiter()
self.notifier = hs.get_notifier()
# This is only used to get at ratelimit function, and maybe_kick_guest_users
self.base_handler = BaseHandler(hs)
self.pusher_pool = hs.get_pusherpool()
# We arbitrarily limit concurrent event creation for a room to 5.
# This is to stop us from diverging history *too* much.
self.limiter = Limiter(max_count=5)
self.action_generator = hs.get_action_generator()
self.spam_checker = hs.get_spam_checker()
@defer.inlineCallbacks
def _create_new_client_event(self, builder, requester=None, prev_event_ids=None):
def create_event(self, requester, event_dict, token_id=None, txn_id=None,
prev_event_ids=None):
"""
Given a dict from a client, create a new event.
Creates a FrozenEvent object, filling out auth_events, prev_events,
etc.
Adds display names to Join membership events.
Args:
requester
event_dict (dict): An entire event
token_id (str)
txn_id (str)
prev_event_ids (list): The prev event ids to use when creating the event
Returns:
Tuple of created event (FrozenEvent), Context
"""
builder = self.event_builder_factory.new(event_dict)
with (yield self.limiter.queue(builder.room_id)):
self.validator.validate_new(builder)
if builder.type == EventTypes.Member:
membership = builder.content.get("membership", None)
target = UserID.from_string(builder.state_key)
if membership in {Membership.JOIN, Membership.INVITE}:
# If event doesn't include a display name, add one.
profile = self.profile_handler
content = builder.content
try:
if "displayname" not in content:
content["displayname"] = yield profile.get_displayname(target)
if "avatar_url" not in content:
content["avatar_url"] = yield profile.get_avatar_url(target)
except Exception as e:
logger.info(
"Failed to get profile information for %r: %s",
target, e
)
if token_id is not None:
builder.internal_metadata.token_id = token_id
if txn_id is not None:
builder.internal_metadata.txn_id = txn_id
event, context = yield self.create_new_client_event(
builder=builder,
requester=requester,
prev_event_ids=prev_event_ids,
)
defer.returnValue((event, context))
@defer.inlineCallbacks
def send_nonmember_event(self, requester, event, context, ratelimit=True):
"""
Persists and notifies local clients and federation of an event.
Args:
event (FrozenEvent): the event to send.
context (Context): the context of the event.
ratelimit (bool): Whether to rate limit this send.
is_guest (bool): Whether the sender is a guest.
"""
if event.type == EventTypes.Member:
raise SynapseError(
500,
"Tried to send member event through non-member codepath"
)
user = UserID.from_string(event.sender)
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if event.is_state():
prev_state = yield self.deduplicate_state_event(event, context)
if prev_state is not None:
defer.returnValue(prev_state)
yield self.handle_new_client_event(
requester=requester,
event=event,
context=context,
ratelimit=ratelimit,
)
if event.type == EventTypes.Message:
presence = self.hs.get_presence_handler()
# We don't want to block sending messages on any presence code. This
# matters as sometimes presence code can take a while.
preserve_fn(presence.bump_presence_active_time)(user)
@defer.inlineCallbacks
def deduplicate_state_event(self, event, context):
"""
Checks whether event is in the latest resolved state in context.
If so, returns the version of the event in context.
Otherwise, returns None.
"""
prev_event_id = context.prev_state_ids.get((event.type, event.state_key))
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
if not prev_event:
return
if prev_event and event.user_id == prev_event.user_id:
prev_content = encode_canonical_json(prev_event.content)
next_content = encode_canonical_json(event.content)
if prev_content == next_content:
defer.returnValue(prev_event)
return
@defer.inlineCallbacks
def create_and_send_nonmember_event(
self,
requester,
event_dict,
ratelimit=True,
txn_id=None
):
"""
Creates an event, then sends it.
See self.create_event and self.send_nonmember_event.
"""
event, context = yield self.create_event(
requester,
event_dict,
token_id=requester.access_token_id,
txn_id=txn_id
)
spam_error = self.spam_checker.check_event_for_spam(event)
if spam_error:
if not isinstance(spam_error, basestring):
spam_error = "Spam is not permitted here"
raise SynapseError(
403, spam_error, Codes.FORBIDDEN
)
yield self.send_nonmember_event(
requester,
event,
context,
ratelimit=ratelimit,
)
defer.returnValue(event)
@measure_func("create_new_client_event")
@defer.inlineCallbacks
def create_new_client_event(self, builder, requester=None, prev_event_ids=None):
if prev_event_ids:
prev_events = yield self.store.add_event_hashes(prev_event_ids)
prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids)
@@ -509,9 +520,7 @@ class MessageHandler(BaseHandler):
builder.prev_events = prev_events
builder.depth = depth
state_handler = self.state_handler
context = yield state_handler.compute_event_context(builder)
context = yield self.state.compute_event_context(builder)
if requester:
context.app_service = requester.app_service
@@ -551,7 +560,7 @@ class MessageHandler(BaseHandler):
# We now need to go and hit out to wherever we need to hit out to.
if ratelimit:
yield self.ratelimit(requester)
yield self.base_handler.ratelimit(requester)
try:
yield self.auth.check_from_context(event, context)
@@ -567,7 +576,7 @@ class MessageHandler(BaseHandler):
logger.exception("Failed to encode content: %r", event.content)
raise
yield self.maybe_kick_guest_users(event, context)
yield self.base_handler.maybe_kick_guest_users(event, context)
if event.type == EventTypes.CanonicalAlias:
# Check the alias is actually valid (at this time at least)
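With MessageHandler's event-creation paths moved into the new EventCreationHandler, callers now fetch the handler from the HomeServer instead of going through get_handlers().message_handler. A minimal sketch of a caller, where the helper name and event values are invented for illustration and not taken from this changeset:
from twisted.internet import defer
from synapse.api.constants import EventTypes
@defer.inlineCallbacks
def send_notice(hs, requester, room_id, body):
    # the handler introduced by this changeset replaces message_handler here
    event_creation_handler = hs.get_event_creation_handler()
    event = yield event_creation_handler.create_and_send_nonmember_event(
        requester,
        {
            "type": EventTypes.Message,
            "room_id": room_id,
            "sender": requester.user.to_string(),
            "content": {"msgtype": "m.notice", "body": body},
        },
        ratelimit=False,
    )
    defer.returnValue(event)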

View File

@@ -25,6 +25,7 @@ from synapse.http.client import CaptchaServerHttpClient
from synapse import types
from synapse.types import UserID
from synapse.util.async import run_on_reactor
from synapse.util.threepids import check_3pid_allowed
from ._base import BaseHandler
logger = logging.getLogger(__name__)
@@ -131,7 +132,7 @@ class RegistrationHandler(BaseHandler):
yield run_on_reactor()
password_hash = None
if password:
password_hash = self.auth_handler().hash(password)
password_hash = yield self.auth_handler().hash(password)
if localpart:
yield self.check_username(localpart, guest_access_token=guest_access_token)
@@ -293,7 +294,7 @@ class RegistrationHandler(BaseHandler):
"""
for c in threepidCreds:
logger.info("validating theeepidcred sid %s on id server %s",
logger.info("validating threepidcred sid %s on id server %s",
c['sid'], c['idServer'])
try:
identity_handler = self.hs.get_handlers().identity_handler
@@ -307,6 +308,11 @@ class RegistrationHandler(BaseHandler):
logger.info("got threepid with medium '%s' and address '%s'",
threepid['medium'], threepid['address'])
if not check_3pid_allowed(self.hs, threepid['medium'], threepid['address']):
raise RegistrationError(
403, "Third party identifier is not allowed"
)
@defer.inlineCallbacks
def bind_emails(self, user_id, threepidCreds):
"""Links emails with a user ID and informs an identity server.

View File

@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -64,6 +65,7 @@ class RoomCreationHandler(BaseHandler):
super(RoomCreationHandler, self).__init__(hs)
self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
@defer.inlineCallbacks
def create_room(self, requester, config, ratelimit=True):
@@ -163,13 +165,11 @@ class RoomCreationHandler(BaseHandler):
creation_content = config.get("creation_content", {})
msg_handler = self.hs.get_handlers().message_handler
room_member_handler = self.hs.get_handlers().room_member_handler
yield self._send_events_for_new_room(
requester,
room_id,
msg_handler,
room_member_handler,
preset_config=preset_config,
invite_list=invite_list,
@@ -181,7 +181,7 @@ class RoomCreationHandler(BaseHandler):
if "name" in config:
name = config["name"]
yield msg_handler.create_and_send_nonmember_event(
yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Name,
@@ -194,7 +194,7 @@ class RoomCreationHandler(BaseHandler):
if "topic" in config:
topic = config["topic"]
yield msg_handler.create_and_send_nonmember_event(
yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Topic,
@@ -249,7 +249,6 @@ class RoomCreationHandler(BaseHandler):
self,
creator, # A Requester object.
room_id,
msg_handler,
room_member_handler,
preset_config,
invite_list,
@@ -272,7 +271,7 @@ class RoomCreationHandler(BaseHandler):
@defer.inlineCallbacks
def send(etype, content, **kwargs):
event = create(etype, content, **kwargs)
yield msg_handler.create_and_send_nonmember_event(
yield self.event_creation_handler.create_and_send_nonmember_event(
creator,
event,
ratelimit=False

View File

@@ -203,7 +203,8 @@ class RoomListHandler(BaseHandler):
if limit:
step = limit + 1
else:
step = len(rooms_to_scan)
# step cannot be zero
step = len(rooms_to_scan) if len(rooms_to_scan) != 0 else 1
chunk = []
for i in xrange(0, len(rooms_to_scan), step):
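The guard matters because xrange raises ValueError when its step is zero, so an unlimited query over an empty room list could previously blow up the paging loop. A small sketch of the chunking behaviour; the helper name chunk_rooms is invented for illustration:
def chunk_rooms(rooms_to_scan, limit):
    # mirror the handler: scan in slices of limit + 1, but never hand a zero
    # step to xrange when there are no rooms at all
    if limit:
        step = limit + 1
    else:
        step = len(rooms_to_scan) if len(rooms_to_scan) != 0 else 1
    return [rooms_to_scan[i:i + step] for i in xrange(0, len(rooms_to_scan), step)]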

View File

@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -46,6 +47,7 @@ class RoomMemberHandler(BaseHandler):
super(RoomMemberHandler, self).__init__(hs)
self.profile_handler = hs.get_profile_handler()
self.event_creation_hander = hs.get_event_creation_handler()
self.member_linearizer = Linearizer(name="member")
@@ -66,13 +68,12 @@ class RoomMemberHandler(BaseHandler):
):
if content is None:
content = {}
msg_handler = self.hs.get_handlers().message_handler
content["membership"] = membership
if requester.is_guest:
content["kind"] = "guest"
event, context = yield msg_handler.create_event(
event, context = yield self.event_creation_hander.create_event(
requester,
{
"type": EventTypes.Member,
@@ -90,12 +91,14 @@ class RoomMemberHandler(BaseHandler):
)
# Check if this event matches the previous membership event for the user.
duplicate = yield msg_handler.deduplicate_state_event(event, context)
duplicate = yield self.event_creation_hander.deduplicate_state_event(
event, context,
)
if duplicate is not None:
# Discard the new event since this membership change is a no-op.
defer.returnValue(duplicate)
yield msg_handler.handle_new_client_event(
yield self.event_creation_hander.handle_new_client_event(
requester,
event,
context,
@@ -394,8 +397,9 @@ class RoomMemberHandler(BaseHandler):
else:
requester = synapse.types.create_requester(target_user)
message_handler = self.hs.get_handlers().message_handler
prev_event = yield message_handler.deduplicate_state_event(event, context)
prev_event = yield self.event_creation_hander.deduplicate_state_event(
event, context,
)
if prev_event is not None:
return
@@ -412,7 +416,7 @@ class RoomMemberHandler(BaseHandler):
if is_blocked:
raise SynapseError(403, "This room has been blocked on this server")
yield message_handler.handle_new_client_event(
yield self.event_creation_hander.handle_new_client_event(
requester,
event,
context,
@@ -644,8 +648,7 @@ class RoomMemberHandler(BaseHandler):
)
)
msg_handler = self.hs.get_handlers().message_handler
yield msg_handler.create_and_send_nonmember_event(
yield self.event_creation_hander.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.ThirdPartyInvite,

View File

@@ -31,7 +31,7 @@ class SetPasswordHandler(BaseHandler):
@defer.inlineCallbacks
def set_password(self, user_id, newpassword, requester=None):
password_hash = self._auth_handler.hash(newpassword)
password_hash = yield self._auth_handler.hash(newpassword)
except_device_id = requester.device_id if requester else None
except_access_token_id = requester.access_token_id if requester else None

View File

@@ -18,6 +18,7 @@ from OpenSSL.SSL import VERIFY_NONE
from synapse.api.errors import (
CodeMessageException, MatrixCodeMessageException, SynapseError, Codes,
)
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.logcontext import make_deferred_yieldable
from synapse.util import logcontext
import synapse.metrics
@@ -30,6 +31,7 @@ from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.web.client import (
BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent,
readBody, PartialDownloadError,
HTTPConnectionPool,
)
from twisted.web.client import FileBodyProducer as TwistedFileBodyProducer
from twisted.web.http import PotentialDataLoss
@@ -64,13 +66,23 @@ class SimpleHttpClient(object):
"""
def __init__(self, hs):
self.hs = hs
pool = HTTPConnectionPool(reactor)
# the pusher makes lots of concurrent SSL connections to sygnal, and
# tends to do so in batches, so we need to allow the pool to keep lots
# of idle connections around.
pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5))
pool.cachedConnectionTimeout = 2 * 60
# The default context factory in Twisted 14.0.0 (which we require) is
# BrowserLikePolicyForHTTPS which will do regular cert validation
# 'like a browser'
self.agent = Agent(
reactor,
connectTimeout=15,
contextFactory=hs.get_http_client_context_factory()
contextFactory=hs.get_http_client_context_factory(),
pool=pool,
)
self.user_agent = hs.version_string
self.clock = hs.get_clock()

View File

@@ -357,8 +357,7 @@ def _get_hosts_for_srv_record(dns_client, host):
def eb(res, record_type):
if res.check(DNSNameError):
return []
logger.warn("Error looking up %s for %s: %s",
record_type, host, res, res.value)
logger.warn("Error looking up %s for %s: %s", record_type, host, res)
return res
# no logcontexts here, so we can safely fire these off and gatherResults

View File

@@ -27,7 +27,7 @@ import synapse.metrics
from canonicaljson import encode_canonical_json
from synapse.api.errors import (
SynapseError, Codes, HttpResponseException,
SynapseError, Codes, HttpResponseException, FederationDeniedError,
)
from signedjson.sign import sign_json
@@ -123,11 +123,22 @@ class MatrixFederationHttpClient(object):
Fails with ``HTTPRequestException``: if we get an HTTP response
code >= 300.
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
Fails with ``FederationDeniedError`` if this destination
is not on our federation whitelist
(May also fail with plenty of other Exceptions for things like DNS
failures, connection failures, SSL failures.)
"""
if (
self.hs.config.federation_domain_whitelist and
destination not in self.hs.config.federation_domain_whitelist
):
raise FederationDeniedError(destination)
limiter = yield synapse.util.retryutils.get_retry_limiter(
destination,
self.clock,
@@ -308,6 +319,9 @@ class MatrixFederationHttpClient(object):
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
Fails with ``FederationDeniedError`` if this destination
is not on our federation whitelist
"""
if not json_data_callback:
@@ -368,6 +382,9 @@ class MatrixFederationHttpClient(object):
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
Fails with ``FederationDeniedError`` if this destination
is not on our federation whitelist
"""
def body_callback(method, url_bytes, headers_dict):
@@ -422,6 +439,9 @@ class MatrixFederationHttpClient(object):
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
Fails with ``FederationDeniedError`` if this destination
is not on our federation whitelist
"""
logger.debug("get_json args: %s", args)
@@ -475,6 +495,9 @@ class MatrixFederationHttpClient(object):
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
Fails with ``FederationDeniedError`` if this destination
is not on our federation whitelist
"""
response = yield self._request(
@@ -518,6 +541,9 @@ class MatrixFederationHttpClient(object):
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
Fails with ``FederationDeniedError`` if this destination
is not on our federation whitelist
"""
encoded_args = {}

View File

@@ -37,41 +37,76 @@ import collections
import logging
import urllib
import ujson
import simplejson
logger = logging.getLogger(__name__)
metrics = synapse.metrics.get_metrics_for(__name__)
incoming_requests_counter = metrics.register_counter(
"requests",
# total number of responses served, split by method/servlet/tag
response_count = metrics.register_counter(
"response_count",
labels=["method", "servlet", "tag"],
alternative_names=(
# the following are all deprecated aliases for the same metric
metrics.name_prefix + x for x in (
"_requests",
"_response_time:count",
"_response_ru_utime:count",
"_response_ru_stime:count",
"_response_db_txn_count:count",
"_response_db_txn_duration:count",
)
)
)
outgoing_responses_counter = metrics.register_counter(
"responses",
labels=["method", "code"],
)
response_timer = metrics.register_distribution(
"response_time",
labels=["method", "servlet", "tag"]
response_timer = metrics.register_counter(
"response_time_seconds",
labels=["method", "servlet", "tag"],
alternative_names=(
metrics.name_prefix + "_response_time:total",
),
)
response_ru_utime = metrics.register_distribution(
"response_ru_utime", labels=["method", "servlet", "tag"]
response_ru_utime = metrics.register_counter(
"response_ru_utime_seconds", labels=["method", "servlet", "tag"],
alternative_names=(
metrics.name_prefix + "_response_ru_utime:total",
),
)
response_ru_stime = metrics.register_distribution(
"response_ru_stime", labels=["method", "servlet", "tag"]
response_ru_stime = metrics.register_counter(
"response_ru_stime_seconds", labels=["method", "servlet", "tag"],
alternative_names=(
metrics.name_prefix + "_response_ru_stime:total",
),
)
response_db_txn_count = metrics.register_distribution(
"response_db_txn_count", labels=["method", "servlet", "tag"]
response_db_txn_count = metrics.register_counter(
"response_db_txn_count", labels=["method", "servlet", "tag"],
alternative_names=(
metrics.name_prefix + "_response_db_txn_count:total",
),
)
response_db_txn_duration = metrics.register_distribution(
"response_db_txn_duration", labels=["method", "servlet", "tag"]
# seconds spent waiting for db txns, excluding scheduling time, when processing
# this request
response_db_txn_duration = metrics.register_counter(
"response_db_txn_duration_seconds", labels=["method", "servlet", "tag"],
alternative_names=(
metrics.name_prefix + "_response_db_txn_duration:total",
),
)
# seconds spent waiting for a db connection, when processing this request
response_db_sched_duration = metrics.register_counter(
"response_db_sched_duration_seconds", labels=["method", "servlet", "tag"]
)
_next_request_id = 0
@@ -107,6 +142,10 @@ def wrap_request_handler(request_handler, include_metrics=False):
with LoggingContext(request_id) as request_context:
with Measure(self.clock, "wrapped_request_handler"):
request_metrics = RequestMetrics()
# we start the request metrics timer here with an initial stab
# at the servlet name. For most requests that name will be
# JsonResource (or a subclass), and JsonResource._async_render
# will update it once it picks a servlet.
request_metrics.start(self.clock, name=self.__class__.__name__)
request_context.request = request_id
@@ -249,12 +288,23 @@ class JsonResource(HttpServer, resource.Resource):
if not m:
continue
# We found a match! Trigger callback and then return the
# returned response. We pass both the request and any
# matched groups from the regex to the callback.
# We found a match! First update the metrics object to indicate
# which servlet is handling the request.
callback = path_entry.callback
servlet_instance = getattr(callback, "__self__", None)
if servlet_instance is not None:
servlet_classname = servlet_instance.__class__.__name__
else:
servlet_classname = "%r" % callback
request_metrics.name = servlet_classname
# Now trigger the callback. If it returns a response, we send it
# here. If it throws an exception, that is handled by the wrapper
# installed by @request_handler.
kwargs = intern_dict({
name: urllib.unquote(value).decode("UTF-8") if value else value
for name, value in m.groupdict().items()
@@ -265,30 +315,14 @@ class JsonResource(HttpServer, resource.Resource):
code, response = callback_return
self._send_response(request, code, response)
servlet_instance = getattr(callback, "__self__", None)
if servlet_instance is not None:
servlet_classname = servlet_instance.__class__.__name__
else:
servlet_classname = "%r" % callback
request_metrics.name = servlet_classname
return
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
request_metrics.name = self.__class__.__name__ + ".UnrecognizedRequest"
raise UnrecognizedRequestError()
def _send_response(self, request, code, response_json_object,
response_code_message=None):
# could alternatively use request.notifyFinish() and flip a flag when
# the Deferred fires, but since the flag is RIGHT THERE it seems like
# a waste.
if request._disconnected:
logger.warn(
"Not sending response to request %s, already disconnected.",
request)
return
outgoing_responses_counter.inc(request.method, str(code))
# TODO: Only enable CORS for the requests that need it.
@@ -322,7 +356,7 @@ class RequestMetrics(object):
)
return
incoming_requests_counter.inc(request.method, self.name, tag)
response_count.inc(request.method, self.name, tag)
response_timer.inc_by(
clock.time_msec() - self.start, request.method,
@@ -341,7 +375,10 @@ class RequestMetrics(object):
context.db_txn_count, request.method, self.name, tag
)
response_db_txn_duration.inc_by(
context.db_txn_duration, request.method, self.name, tag
context.db_txn_duration_ms / 1000., request.method, self.name, tag
)
response_db_sched_duration.inc_by(
context.db_sched_duration_ms / 1000., request.method, self.name, tag
)
@@ -364,6 +401,15 @@ class RootRedirect(resource.Resource):
def respond_with_json(request, code, json_object, send_cors=False,
response_code_message=None, pretty_print=False,
version_string="", canonical_json=True):
# could alternatively use request.notifyFinish() and flip a flag when
# the Deferred fires, but since the flag is RIGHT THERE it seems like
# a waste.
if request._disconnected:
logger.warn(
"Not sending response to request %s, already disconnected.",
request)
return
if pretty_print:
json_bytes = encode_pretty_printed_json(json_object) + "\n"
else:
@@ -371,7 +417,7 @@ def respond_with_json(request, code, json_object, send_cors=False,
json_bytes = encode_canonical_json(json_object)
else:
# ujson doesn't like frozen_dicts.
json_bytes = ujson.dumps(json_object, ensure_ascii=False)
json_bytes = simplejson.dumps(json_object)
return respond_with_json_bytes(
request, code, json_bytes,

View File

@@ -148,11 +148,13 @@ def parse_string_from_args(args, name, default=None, required=False,
return default
def parse_json_value_from_request(request):
def parse_json_value_from_request(request, allow_empty_body=False):
"""Parse a JSON value from the body of a twisted HTTP request.
Args:
request: the twisted HTTP request.
allow_empty_body (bool): if True, an empty body will be accepted and
turned into None
Returns:
The JSON value.
@@ -165,6 +167,9 @@ def parse_json_value_from_request(request):
except Exception:
raise SynapseError(400, "Error reading JSON content.")
if not content_bytes and allow_empty_body:
return None
try:
content = simplejson.loads(content_bytes)
except Exception as e:
@@ -174,17 +179,24 @@ def parse_json_value_from_request(request):
return content
def parse_json_object_from_request(request):
def parse_json_object_from_request(request, allow_empty_body=False):
"""Parse a JSON object from the body of a twisted HTTP request.
Args:
request: the twisted HTTP request.
allow_empty_body (bool): if True, an empty body will be accepted and
turned into an empty dict.
Raises:
SynapseError if the request body couldn't be decoded as JSON or
if it wasn't a JSON object.
"""
content = parse_json_value_from_request(request)
content = parse_json_value_from_request(
request, allow_empty_body=allow_empty_body,
)
if allow_empty_body and content is None:
return {}
if type(content) != dict:
message = "Content must be a JSON object."

View File

@@ -66,14 +66,15 @@ class SynapseRequest(Request):
context = LoggingContext.current_context()
ru_utime, ru_stime = context.get_resource_usage()
db_txn_count = context.db_txn_count
db_txn_duration = context.db_txn_duration
db_txn_duration_ms = context.db_txn_duration_ms
db_sched_duration_ms = context.db_sched_duration_ms
except Exception:
ru_utime, ru_stime = (0, 0)
db_txn_count, db_txn_duration = (0, 0)
db_txn_count, db_txn_duration_ms = (0, 0)
self.site.access_logger.info(
"%s - %s - {%s}"
" Processed request: %dms (%dms, %dms) (%dms/%d)"
" Processed request: %dms (%dms, %dms) (%dms/%dms/%d)"
" %sB %s \"%s %s %s\" \"%s\"",
self.getClientIP(),
self.site.site_tag,
@@ -81,7 +82,8 @@ class SynapseRequest(Request):
int(time.time() * 1000) - self.start_time,
int(ru_utime * 1000),
int(ru_stime * 1000),
int(db_txn_duration * 1000),
db_sched_duration_ms,
db_txn_duration_ms,
int(db_txn_count),
self.sentLength,
self.code,
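Read left to right, the timing groups in the new access-log line are: total wallclock time for the request, then (CPU user ms, CPU system ms), then (db scheduling ms / db transaction ms / transaction count), where the scheduling figure is the time spent waiting for a database connection. A hypothetical fragment, with every number invented for illustration, would end in something like: Processed request: 312ms (45ms, 12ms) (3ms/27ms/5).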

View File

@@ -146,10 +146,15 @@ def runUntilCurrentTimer(func):
num_pending += 1
num_pending += len(reactor.threadCallQueue)
start = time.time() * 1000
ret = func(*args, **kwargs)
end = time.time() * 1000
# record the amount of wallclock time spent running pending calls.
# This is a proxy for the actual amount of time between reactor polls,
# since about 25% of time is actually spent running things triggered by
# I/O events, but that is harder to capture without rewriting half the
# reactor.
tick_time.inc_by(end - start)
pending_calls_metric.inc_by(num_pending)

View File

@@ -15,18 +15,38 @@
from itertools import chain
import logging
logger = logging.getLogger(__name__)
# TODO(paul): I can't believe Python doesn't have one of these
def map_concat(func, items):
# flatten a list-of-lists
return list(chain.from_iterable(map(func, items)))
def flatten(items):
"""Flatten a list of lists
Args:
items: iterable[iterable[X]]
Returns:
list[X]: flattened list
"""
return list(chain.from_iterable(items))
class BaseMetric(object):
"""Base class for metrics which report a single value per label set
"""
def __init__(self, name, labels=[]):
self.name = name
def __init__(self, name, labels=[], alternative_names=[]):
"""
Args:
name (str): principal name for this metric
labels (list(str)): names of the labels which will be reported
for this metric
alternative_names (iterable(str)): list of alternative names for
this metric. This can be useful to provide a migration path
when renaming metrics.
"""
self._names = [name] + list(alternative_names)
self.labels = labels # OK not to clone as we never write it
def dimension(self):
@@ -36,7 +56,7 @@ class BaseMetric(object):
return not len(self.labels)
def _render_labelvalue(self, value):
# TODO: some kind of value escape
# TODO: escape backslashes, quotes and newlines
return '"%s"' % (value)
def _render_key(self, values):
@@ -47,19 +67,60 @@ class BaseMetric(object):
for k, v in zip(self.labels, values)])
)
def _render_for_labels(self, label_values, value):
"""Render this metric for a single set of labels
Args:
label_values (list[str]): values for each of the labels
value: value of the metric at with these labels
Returns:
iterable[str]: rendered metric
"""
rendered_labels = self._render_key(label_values)
return (
"%s%s %.12g" % (name, rendered_labels, value)
for name in self._names
)
def render(self):
"""Render this metric
Each metric is rendered as:
name{label1="val1",label2="val2"} value
https://prometheus.io/docs/instrumenting/exposition_formats/#text-format-details
Returns:
iterable[str]: rendered metrics
"""
raise NotImplementedError()
class CounterMetric(BaseMetric):
"""The simplest kind of metric; one that stores a monotonically-increasing
integer that counts events."""
value that counts events or running totals.
Example use cases for Counters:
- Number of requests processed
- Number of items that were inserted into a queue
- Total amount of data that a system has processed
Counters can only go up (and be reset when the process restarts).
"""
def __init__(self, *args, **kwargs):
super(CounterMetric, self).__init__(*args, **kwargs)
# dict[list[str]]: value for each set of label values. the keys are the
# label values, in the same order as the labels in self.labels.
#
# (if the metric is a scalar, the (single) key is the empty list).
self.counts = {}
# Scalar metrics are never empty
if self.is_scalar():
self.counts[()] = 0
self.counts[()] = 0.
def inc_by(self, incr, *values):
if len(values) != self.dimension():
@@ -77,11 +138,11 @@ class CounterMetric(BaseMetric):
def inc(self, *values):
self.inc_by(1, *values)
def render_item(self, k):
return ["%s%s %d" % (self.name, self._render_key(k), self.counts[k])]
def render(self):
return map_concat(self.render_item, sorted(self.counts.keys()))
return flatten(
self._render_for_labels(k, self.counts[k])
for k in sorted(self.counts.keys())
)
class CallbackMetric(BaseMetric):
@@ -95,13 +156,19 @@ class CallbackMetric(BaseMetric):
self.callback = callback
def render(self):
value = self.callback()
try:
value = self.callback()
except Exception:
logger.exception("Failed to render %s", self.name)
return ["# FAILED to render " + self.name]
if self.is_scalar():
return ["%s %.12g" % (self.name, value)]
return list(self._render_for_labels([], value))
return ["%s%s %.12g" % (self.name, self._render_key(k), value[k])
for k in sorted(value.keys())]
return flatten(
self._render_for_labels(k, value[k])
for k in sorted(value.keys())
)
class DistributionMetric(object):
@@ -126,7 +193,9 @@ class DistributionMetric(object):
class CacheMetric(object):
__slots__ = ("name", "cache_name", "hits", "misses", "size_callback")
__slots__ = (
"name", "cache_name", "hits", "misses", "evicted_size", "size_callback",
)
def __init__(self, name, size_callback, cache_name):
self.name = name
@@ -134,6 +203,7 @@ class CacheMetric(object):
self.hits = 0
self.misses = 0
self.evicted_size = 0
self.size_callback = size_callback
@@ -143,6 +213,9 @@ class CacheMetric(object):
def inc_misses(self):
self.misses += 1
def inc_evictions(self, size=1):
self.evicted_size += size
def render(self):
size = self.size_callback()
hits = self.hits
@@ -152,6 +225,9 @@ class CacheMetric(object):
"""%s:hits{name="%s"} %d""" % (self.name, self.cache_name, hits),
"""%s:total{name="%s"} %d""" % (self.name, self.cache_name, total),
"""%s:size{name="%s"} %d""" % (self.name, self.cache_name, size),
"""%s:evicted_size{name="%s"} %d""" % (
self.name, self.cache_name, self.evicted_size
),
]
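To make alternative_names concrete: each sample is rendered once under the principal name and once under every alias, so existing scrapes keep working across a rename. A small sketch, with the metric names and import path assumed for illustration:
from synapse.metrics.metric import CounterMetric
counter = CounterMetric(
    "response_count",
    labels=["method"],
    alternative_names=["requests"],
)
counter.inc("GET")
print "\n".join(counter.render())
# response_count{method="GET"} 1
# requests{method="GET"} 1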

View File

@@ -13,21 +13,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.push import PusherConfigException
import logging
from twisted.internet import defer, reactor
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
import logging
import push_rule_evaluator
import push_tools
import synapse
from synapse.push import PusherConfigException
from synapse.util.logcontext import LoggingContext
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
metrics = synapse.metrics.get_metrics_for(__name__)
http_push_processed_counter = metrics.register_counter(
"http_pushes_processed",
)
http_push_failed_counter = metrics.register_counter(
"http_pushes_failed",
)
class HttpPusher(object):
INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes
@@ -152,9 +161,16 @@ class HttpPusher(object):
self.user_id, self.last_stream_ordering, self.max_stream_ordering
)
logger.info(
"Processing %i unprocessed push actions for %s starting at "
"stream_ordering %s",
len(unprocessed), self.name, self.last_stream_ordering,
)
for push_action in unprocessed:
processed = yield self._process_one(push_action)
if processed:
http_push_processed_counter.inc()
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action['stream_ordering']
yield self.store.update_pusher_last_stream_ordering_and_success(
@@ -169,6 +185,7 @@ class HttpPusher(object):
self.failing_since
)
else:
http_push_failed_counter.inc()
if not self.failing_since:
self.failing_since = self.clock.time_msec()
yield self.store.update_pusher_failing_since(
@@ -316,7 +333,10 @@ class HttpPusher(object):
try:
resp = yield self.http_client.post_json_get_json(self.url, notification_dict)
except Exception:
logger.warn("Failed to push %s ", self.url)
logger.warn(
"Failed to push event %s to %s",
event.event_id, self.name, exc_info=True,
)
defer.returnValue(False)
rejected = []
if 'rejected' in resp:
@@ -325,7 +345,7 @@ class HttpPusher(object):
@defer.inlineCallbacks
def _send_badge(self, badge):
logger.info("Sending updated badge count %d to %r", badge, self.user_id)
logger.info("Sending updated badge count %d to %s", badge, self.name)
d = {
'notification': {
'id': '',
@@ -347,7 +367,10 @@ class HttpPusher(object):
try:
resp = yield self.http_client.post_json_get_json(self.url, d)
except Exception:
logger.exception("Failed to push %s ", self.url)
logger.warn(
"Failed to send badge count to %s",
self.name, exc_info=True,
)
defer.returnValue(False)
rejected = []
if 'rejected' in resp:

View File

@@ -36,7 +36,7 @@ REQUIREMENTS = {
"pydenticon": ["pydenticon"],
"ujson": ["ujson"],
"blist": ["blist"],
"pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"],
"pysaml2>=3.0.0": ["saml2>=3.0.0"],
"pymacaroons-pynacl": ["pymacaroons"],
"msgpack-python>=0.3.0": ["msgpack"],
"phonenumbers>=8.2.0": ["phonenumbers"],

View File

@@ -19,7 +19,7 @@ from synapse.storage import DataStore
from synapse.storage.event_federation import EventFederationStore
from synapse.storage.event_push_actions import EventPushActionsStore
from synapse.storage.roommember import RoomMemberStore
from synapse.storage.state import StateGroupReadStore
from synapse.storage.state import StateGroupWorkerStore
from synapse.storage.stream import StreamStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import BaseSlavedStore
@@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
# the method descriptor on the DataStore and chuck them into our class.
class SlavedEventStore(StateGroupReadStore, BaseSlavedStore):
class SlavedEventStore(StateGroupWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedEventStore, self).__init__(db_conn, hs)

View File

@@ -517,25 +517,28 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.send_error("Wrong remote")
def on_RDATA(self, cmd):
stream_name = cmd.stream_name
inbound_rdata_count.inc(stream_name)
try:
row = STREAMS_MAP[cmd.stream_name].ROW_TYPE(*cmd.row)
row = STREAMS_MAP[stream_name].ROW_TYPE(*cmd.row)
except Exception:
logger.exception(
"[%s] Failed to parse RDATA: %r %r",
self.id(), cmd.stream_name, cmd.row
self.id(), stream_name, cmd.row
)
raise
if cmd.token is None:
# I.e. this is part of a batch of updates for this stream. Batch
# until we get an update for the stream with a non None token
self.pending_batches.setdefault(cmd.stream_name, []).append(row)
self.pending_batches.setdefault(stream_name, []).append(row)
else:
# Check if this is the last of a batch of updates
rows = self.pending_batches.pop(cmd.stream_name, [])
rows = self.pending_batches.pop(stream_name, [])
rows.append(row)
self.handler.on_rdata(cmd.stream_name, cmd.token, rows)
self.handler.on_rdata(stream_name, cmd.token, rows)
def on_POSITION(self, cmd):
self.handler.on_position(cmd.stream_name, cmd.token)
@@ -644,3 +647,9 @@ metrics.register_callback(
},
labels=["command", "name", "conn_id"],
)
# number of updates received for each RDATA stream
inbound_rdata_count = metrics.register_counter(
"inbound_rdata_count",
labels=["stream_name"],
)

View File

@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -128,7 +129,14 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
if not is_admin:
raise AuthError(403, "You are not a server admin")
yield self.handlers.message_handler.purge_history(room_id, event_id)
body = parse_json_object_from_request(request, allow_empty_body=True)
delete_local_events = bool(body.get("delete_local_events", False))
yield self.handlers.message_handler.purge_history(
room_id, event_id,
delete_local_events=delete_local_events,
)
defer.returnValue((200, {}))
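Since the body is parsed with allow_empty_body=True, an admin can POST with no body at all to keep the old behaviour, or opt in to deleting local events with a JSON body along these lines (illustrative only):
{
    "delete_local_events": true
}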
@@ -171,6 +179,7 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
self.store = hs.get_datastore()
self.handlers = hs.get_handlers()
self.state = hs.get_state_handler()
self.event_creation_handler = hs.get_event_creation_handler()
@defer.inlineCallbacks
def on_POST(self, request, room_id):
@@ -203,8 +212,7 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
)
new_room_id = info["room_id"]
msg_handler = self.handlers.message_handler
yield msg_handler.create_and_send_nonmember_event(
yield self.event_creation_handler.create_and_send_nonmember_event(
room_creator_requester,
{
"type": "m.room.message",
@@ -289,6 +297,27 @@ class QuarantineMediaInRoom(ClientV1RestServlet):
defer.returnValue((200, {"num_quarantined": num_quarantined}))
class ListMediaInRoom(ClientV1RestServlet):
"""Lists all of the media in a given room.
"""
PATTERNS = client_path_patterns("/admin/room/(?P<room_id>[^/]+)/media")
def __init__(self, hs):
super(ListMediaInRoom, self).__init__(hs)
self.store = hs.get_datastore()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
requester = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(requester.user)
if not is_admin:
raise AuthError(403, "You are not a server admin")
local_mxcs, remote_mxcs = yield self.store.get_media_mxcs_in_room(room_id)
defer.returnValue((200, {"local": local_mxcs, "remote": remote_mxcs}))
class ResetPasswordRestServlet(ClientV1RestServlet):
"""Post request to allow an administrator reset password for a user.
This needs user to have administrator access in Synapse.
@@ -487,3 +516,4 @@ def register_servlets(hs, http_server):
SearchUsersRestServlet(hs).register(http_server)
ShutdownRoomRestServlet(hs).register(http_server)
QuarantineMediaInRoom(hs).register(http_server)
ListMediaInRoom(hs).register(http_server)

View File

@@ -191,19 +191,25 @@ class LoginRestServlet(ClientV1RestServlet):
# convert threepid identifiers to user IDs
if identifier["type"] == "m.id.thirdparty":
if 'medium' not in identifier or 'address' not in identifier:
address = identifier.get('address')
medium = identifier.get('medium')
if medium is None or address is None:
raise SynapseError(400, "Invalid thirdparty identifier")
address = identifier['address']
if identifier['medium'] == 'email':
if medium == 'email':
# For emails, transform the address to lowercase.
# We store all email addresses as lowercase in the DB.
# (See add_threepid in synapse/handlers/auth.py)
address = address.lower()
user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
identifier['medium'], address
medium, address,
)
if not user_id:
logger.warn(
"unknown 3pid identifier medium %s, address %r",
medium, address,
)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
identifier = {

View File

@@ -70,10 +70,15 @@ class RegisterRestServlet(ClientV1RestServlet):
self.handlers = hs.get_handlers()
def on_GET(self, request):
require_email = 'email' in self.hs.config.registrations_require_3pid
require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
flows = []
if self.hs.config.enable_registration_captcha:
return (
200,
{"flows": [
# only support the email-only flow if we don't require MSISDN 3PIDs
if not require_msisdn:
flows.extend([
{
"type": LoginType.RECAPTCHA,
"stages": [
@@ -82,27 +87,34 @@ class RegisterRestServlet(ClientV1RestServlet):
LoginType.PASSWORD
]
},
])
# only support 3PIDless registration if no 3PIDs are required
if not require_email and not require_msisdn:
flows.extend([
{
"type": LoginType.RECAPTCHA,
"stages": [LoginType.RECAPTCHA, LoginType.PASSWORD]
}
]}
)
])
else:
return (
200,
{"flows": [
# only support the email-only flow if we don't require MSISDN 3PIDs
if require_email or not require_msisdn:
flows.extend([
{
"type": LoginType.EMAIL_IDENTITY,
"stages": [
LoginType.EMAIL_IDENTITY, LoginType.PASSWORD
]
},
}
])
# only support 3PIDless registration if no 3PIDs are required
if not require_email and not require_msisdn:
flows.extend([
{
"type": LoginType.PASSWORD
}
]}
)
])
return (200, {"flows": flows})
@defer.inlineCallbacks
def on_POST(self, request):

View File

@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -82,6 +83,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomStateEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
self.event_creation_hander = hs.get_event_creation_handler()
def register(self, http_server):
# /room/$roomid/state/$eventtype
@@ -162,15 +164,16 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
content=content,
)
else:
msg_handler = self.handlers.message_handler
event, context = yield msg_handler.create_event(
event, context = yield self.event_creation_hander.create_event(
requester,
event_dict,
token_id=requester.access_token_id,
txn_id=txn_id,
)
yield msg_handler.send_nonmember_event(requester, event, context)
yield self.event_creation_hander.send_nonmember_event(
requester, event, context,
)
ret = {}
if event:
@@ -184,6 +187,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomSendEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
self.event_creation_hander = hs.get_event_creation_handler()
def register(self, http_server):
# /rooms/$roomid/send/$event_type[/$txn_id]
@@ -195,15 +199,19 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
msg_handler = self.handlers.message_handler
event = yield msg_handler.create_and_send_nonmember_event(
event_dict = {
"type": event_type,
"content": content,
"room_id": room_id,
"sender": requester.user.to_string(),
}
if 'ts' in request.args and requester.app_service:
event_dict['origin_server_ts'] = parse_integer(request, "ts", 0)
event = yield self.event_creation_hander.create_and_send_nonmember_event(
requester,
{
"type": event_type,
"content": content,
"room_id": room_id,
"sender": requester.user.to_string(),
},
event_dict,
txn_id=txn_id,
)
@@ -487,13 +495,35 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
defer.returnValue((200, content))
class RoomEventContext(ClientV1RestServlet):
class RoomEventServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns(
"/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$"
)
def __init__(self, hs):
super(RoomEventServlet, self).__init__(hs)
self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
@defer.inlineCallbacks
def on_GET(self, request, room_id, event_id):
requester = yield self.auth.get_user_by_req(request)
event = yield self.event_handler.get_event(requester.user, event_id)
time_now = self.clock.time_msec()
if event:
defer.returnValue((200, serialize_event(event, time_now)))
else:
defer.returnValue((404, "Event not found."))
class RoomEventContextServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns(
"/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$"
)
def __init__(self, hs):
super(RoomEventContext, self).__init__(hs)
super(RoomEventContextServlet, self).__init__(hs)
self.clock = hs.get_clock()
self.handlers = hs.get_handlers()
@@ -643,6 +673,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomRedactEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
self.event_creation_handler = hs.get_event_creation_handler()
def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
@@ -653,8 +684,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
msg_handler = self.handlers.message_handler
event = yield msg_handler.create_and_send_nonmember_event(
event = yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Redaction,
@@ -803,4 +833,5 @@ def register_servlets(hs, http_server):
RoomTypingRestServlet(hs).register(http_server)
SearchRestServlet(hs).register(http_server)
JoinedRoomsRestServlet(hs).register(http_server)
RoomEventContext(hs).register(http_server)
RoomEventServlet(hs).register(http_server)
RoomEventContextServlet(hs).register(http_server)

View File

@@ -26,6 +26,7 @@ from synapse.http.servlet import (
)
from synapse.util.async import run_on_reactor
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import check_3pid_allowed
from ._base import client_v2_patterns, interactive_auth_handler
logger = logging.getLogger(__name__)
@@ -47,6 +48,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
'id_server', 'client_secret', 'email', 'send_attempt'
])
if not check_3pid_allowed(self.hs, "email", body['email']):
raise SynapseError(
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email']
)
@@ -78,6 +84,11 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
raise SynapseError(
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
)
existingUid = yield self.datastore.get_user_id_by_threepid(
'msisdn', msisdn
)
@@ -217,6 +228,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
if not check_3pid_allowed(self.hs, "email", body['email']):
raise SynapseError(
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
)
existingUid = yield self.datastore.get_user_id_by_threepid(
'email', body['email']
)
@@ -255,6 +271,11 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
raise SynapseError(
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
)
existingUid = yield self.datastore.get_user_id_by_threepid(
'msisdn', msisdn
)

View File

@@ -26,6 +26,7 @@ from synapse.http.servlet import (
RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string
)
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import check_3pid_allowed
from ._base import client_v2_patterns, interactive_auth_handler
@@ -70,6 +71,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
'id_server', 'client_secret', 'email', 'send_attempt'
])
if not check_3pid_allowed(self.hs, "email", body['email']):
raise SynapseError(
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email']
)
@@ -105,6 +111,11 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
raise SynapseError(
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'msisdn', msisdn
)
@@ -305,31 +316,67 @@ class RegisterRestServlet(RestServlet):
if 'x_show_msisdn' in body and body['x_show_msisdn']:
show_msisdn = True
# FIXME: need a better error than "no auth flow found" for scenarios
# where we required 3PID for registration but the user didn't give one
require_email = 'email' in self.hs.config.registrations_require_3pid
require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
flows = []
if self.hs.config.enable_registration_captcha:
flows = [
[LoginType.RECAPTCHA],
[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
]
# only support 3PIDless registration if no 3PIDs are required
if not require_email and not require_msisdn:
flows.extend([[LoginType.RECAPTCHA]])
# only support the email-only flow if we don't require MSISDN 3PIDs
if not require_msisdn:
flows.extend([[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]])
if show_msisdn:
# only support the MSISDN-only flow if we don't require email 3PIDs
if not require_email:
flows.extend([[LoginType.MSISDN, LoginType.RECAPTCHA]])
# always let users provide both MSISDN & email
flows.extend([
[LoginType.MSISDN, LoginType.RECAPTCHA],
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
])
else:
flows = [
[LoginType.DUMMY],
[LoginType.EMAIL_IDENTITY],
]
# only support 3PIDless registration if no 3PIDs are required
if not require_email and not require_msisdn:
flows.extend([[LoginType.DUMMY]])
# only support the email-only flow if we don't require MSISDN 3PIDs
if not require_msisdn:
flows.extend([[LoginType.EMAIL_IDENTITY]])
if show_msisdn:
# only support the MSISDN-only flow if we don't require email 3PIDs
if not require_email or require_msisdn:
flows.extend([[LoginType.MSISDN]])
# always let users provide both MSISDN & email
flows.extend([
[LoginType.MSISDN],
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY]
])
auth_result, params, session_id = yield self.auth_handler.check_auth(
flows, body, self.hs.get_ip_from_request(request)
)
# Check that we're not trying to register a denied 3pid.
#
# the user-facing checks will probably already have happened in
# /register/email/requestToken when we requested a 3pid, but that's not
# guaranteed.
if auth_result:
for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
if login_type in auth_result:
medium = auth_result[login_type]['medium']
address = auth_result[login_type]['address']
if not check_3pid_allowed(self.hs, medium, address):
raise SynapseError(
403, "Third party identifier is not allowed",
Codes.THREEPID_DENIED,
)
if registered_user_id is not None:
logger.info(
"Already registered user ID %r for this session",

View File

@@ -33,7 +33,7 @@ from ._base import set_timeline_upper_limit
import itertools
import logging
import ujson as json
import simplejson as json
logger = logging.getLogger(__name__)

View File

@@ -93,6 +93,7 @@ class RemoteKey(Resource):
self.store = hs.get_datastore()
self.version_string = hs.version_string
self.clock = hs.get_clock()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
def render_GET(self, request):
self.async_render_GET(request)
@@ -137,6 +138,13 @@ class RemoteKey(Resource):
logger.info("Handling query for keys %r", query)
store_queries = []
for server_name, key_ids in query.items():
if (
self.federation_domain_whitelist is not None and
server_name not in self.federation_domain_whitelist
):
logger.debug("Federation denied with %s", server_name)
continue
if not key_ids:
key_ids = (None,)
for key_id in key_ids:

View File

@@ -70,38 +70,11 @@ def respond_with_file(request, media_type, file_path,
logger.debug("Responding with %r", file_path)
if os.path.isfile(file_path):
request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
if upload_name:
if is_ascii(upload_name):
request.setHeader(
b"Content-Disposition",
b"inline; filename=%s" % (
urllib.quote(upload_name.encode("utf-8")),
),
)
else:
request.setHeader(
b"Content-Disposition",
b"inline; filename*=utf-8''%s" % (
urllib.quote(upload_name.encode("utf-8")),
),
)
# cache for at least a day.
# XXX: we might want to turn this off for data we don't want to
# recommend caching as it's sensitive or private - or at least
# select private. don't bother setting Expires as all our
# clients are smart enough to be happy with Cache-Control
request.setHeader(
b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
)
if file_size is None:
stat = os.stat(file_path)
file_size = stat.st_size
request.setHeader(
b"Content-Length", b"%d" % (file_size,)
)
add_file_headers(request, media_type, file_size, upload_name)
with open(file_path, "rb") as f:
yield logcontext.make_deferred_yieldable(
@@ -111,3 +84,118 @@ def respond_with_file(request, media_type, file_path,
finish_request(request)
else:
respond_404(request)
def add_file_headers(request, media_type, file_size, upload_name):
"""Adds the correct response headers in preparation for responding with the
media.
Args:
request (twisted.web.http.Request)
media_type (str): The media/content type.
file_size (int): Size in bytes of the media, if known.
upload_name (str): The name of the requested file, if any.
"""
request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
if upload_name:
if is_ascii(upload_name):
request.setHeader(
b"Content-Disposition",
b"inline; filename=%s" % (
urllib.quote(upload_name.encode("utf-8")),
),
)
else:
request.setHeader(
b"Content-Disposition",
b"inline; filename*=utf-8''%s" % (
urllib.quote(upload_name.encode("utf-8")),
),
)
# cache for at least a day.
# XXX: we might want to turn this off for data we don't want to
# recommend caching as it's sensitive or private - or at least
# select private. don't bother setting Expires as all our
# clients are smart enough to be happy with Cache-Control
request.setHeader(
b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
)
request.setHeader(
b"Content-Length", b"%d" % (file_size,)
)
@defer.inlineCallbacks
def respond_with_responder(request, responder, media_type, file_size, upload_name=None):
"""Responds to the request with given responder. If responder is None then
returns 404.
Args:
request (twisted.web.http.Request)
responder (Responder|None)
media_type (str): The media/content type.
file_size (int|None): Size in bytes of the media. If not known it should be None
upload_name (str|None): The name of the requested file, if any.
"""
if not responder:
respond_404(request)
return
add_file_headers(request, media_type, file_size, upload_name)
with responder:
yield responder.write_to_consumer(request)
finish_request(request)
class Responder(object):
"""Represents a response that can be streamed to the requester.
Responder is a context manager which *must* be used, so that any resources
held can be cleaned up.
"""
def write_to_consumer(self, consumer):
"""Stream response into consumer
Args:
consumer (IConsumer)
Returns:
Deferred: Resolves once the response has finished being written
"""
pass
def __enter__(self):
pass
def __exit__(self, exc_type, exc_val, exc_tb):
pass
class FileInfo(object):
"""Details about a requested/uploaded file.
Attributes:
server_name (str): The server name where the media originated from,
or None if local.
file_id (str): The local ID of the file. For local files this is the
same as the media_id
url_cache (bool): If the file is for the url preview cache
thumbnail (bool): Whether the file is a thumbnail or not.
thumbnail_width (int)
thumbnail_height (int)
thumbnail_method (str)
thumbnail_type (str): Content type of thumbnail, e.g. image/png
"""
def __init__(self, server_name, file_id, url_cache=False,
thumbnail=False, thumbnail_width=None, thumbnail_height=None,
thumbnail_method=None, thumbnail_type=None):
self.server_name = server_name
self.file_id = file_id
self.url_cache = url_cache
self.thumbnail = thumbnail
self.thumbnail_width = thumbnail_width
self.thumbnail_height = thumbnail_height
self.thumbnail_method = thumbnail_method
self.thumbnail_type = thumbnail_type
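A minimal sketch of the Responder contract documented above, assuming in-memory data and a toy consumer object with a write() method (not part of the diff):
from io import BytesIO

class BytesResponder(object):
    """Hypothetical Responder that streams an in-memory byte string."""

    def __init__(self, data):
        self._file = BytesIO(data)

    def write_to_consumer(self, consumer):
        # The real interface returns a Deferred; this sketch writes directly.
        consumer.write(self._file.read())

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always used as a context manager so the buffer is released.
        self._file.close()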

View File

@@ -14,7 +14,7 @@
# limitations under the License.
import synapse.http.servlet
from ._base import parse_media_id, respond_with_file, respond_404
from ._base import parse_media_id, respond_404
from twisted.web.resource import Resource
from synapse.http.server import request_handler, set_cors_headers
@@ -32,12 +32,12 @@ class DownloadResource(Resource):
def __init__(self, hs, media_repo):
Resource.__init__(self)
self.filepaths = media_repo.filepaths
self.media_repo = media_repo
self.server_name = hs.hostname
self.store = hs.get_datastore()
self.version_string = hs.version_string
# Both of these are expected by @request_handler()
self.clock = hs.get_clock()
self.version_string = hs.version_string
def render_GET(self, request):
self._async_render_GET(request)
@@ -57,59 +57,16 @@ class DownloadResource(Resource):
)
server_name, media_id, name = parse_media_id(request)
if server_name == self.server_name:
yield self._respond_local_file(request, media_id, name)
yield self.media_repo.get_local_media(request, media_id, name)
else:
yield self._respond_remote_file(
request, server_name, media_id, name
)
allow_remote = synapse.http.servlet.parse_boolean(
request, "allow_remote", default=True)
if not allow_remote:
logger.info(
"Rejecting request for remote media %s/%s due to allow_remote",
server_name, media_id,
)
respond_404(request)
return
@defer.inlineCallbacks
def _respond_local_file(self, request, media_id, name):
media_info = yield self.store.get_local_media(media_id)
if not media_info or media_info["quarantined_by"]:
respond_404(request)
return
media_type = media_info["media_type"]
media_length = media_info["media_length"]
upload_name = name if name else media_info["upload_name"]
if media_info["url_cache"]:
# TODO: Check the file still exists, if it doesn't we can redownload
# it from the url `media_info["url_cache"]`
file_path = self.filepaths.url_cache_filepath(media_id)
else:
file_path = self.filepaths.local_media_filepath(media_id)
yield respond_with_file(
request, media_type, file_path, media_length,
upload_name=upload_name,
)
@defer.inlineCallbacks
def _respond_remote_file(self, request, server_name, media_id, name):
# don't forward requests for remote media if allow_remote is false
allow_remote = synapse.http.servlet.parse_boolean(
request, "allow_remote", default=True)
if not allow_remote:
logger.info(
"Rejecting request for remote media %s/%s due to allow_remote",
server_name, media_id,
)
respond_404(request)
return
media_info = yield self.media_repo.get_remote_media(server_name, media_id)
media_type = media_info["media_type"]
media_length = media_info["media_length"]
filesystem_id = media_info["filesystem_id"]
upload_name = name if name else media_info["upload_name"]
file_path = self.filepaths.remote_media_filepath(
server_name, filesystem_id
)
yield respond_with_file(
request, media_type, file_path, media_length,
upload_name=upload_name,
)
yield self.media_repo.get_remote_media(request, server_name, media_id, name)
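The download URL handled above has the shape /_matrix/media/v1/download/<server_name>/<media_id>[/<name>] (the trailing filename segment is optional); parse_media_id() extracts those parts from the request. A toy splitter under that assumption, for illustration only:
def split_media_path(path):
    parts = path.strip("/").split("/")
    # ['_matrix', 'media', 'v1', 'download', server_name, media_id, (name)]
    server_name, media_id = parts[4], parts[5]
    name = parts[6] if len(parts) > 6 else None
    return server_name, media_id, name

print(split_media_path("/_matrix/media/v1/download/example.org/abcDEF123/cat.png"))
# ('example.org', 'abcDEF123', 'cat.png')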

View File

@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,6 +19,7 @@ import twisted.internet.error
import twisted.web.http
from twisted.web.resource import Resource
from ._base import respond_404, FileInfo, respond_with_responder
from .upload_resource import UploadResource
from .download_resource import DownloadResource
from .thumbnail_resource import ThumbnailResource
@@ -25,15 +27,18 @@ from .identicon_resource import IdenticonResource
from .preview_url_resource import PreviewUrlResource
from .filepath import MediaFilePaths
from .thumbnailer import Thumbnailer
from .storage_provider import StorageProviderWrapper
from .media_storage import MediaStorage
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.util.stringutils import random_string
from synapse.api.errors import SynapseError, HttpResponseException, \
NotFoundError
from synapse.api.errors import (
SynapseError, HttpResponseException, NotFoundError, FederationDeniedError,
)
from synapse.util.async import Linearizer
from synapse.util.stringutils import is_ascii
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
from synapse.util.logcontext import make_deferred_yieldable
from synapse.util.retryutils import NotRetryingDestination
import os
@@ -47,7 +52,7 @@ import urlparse
logger = logging.getLogger(__name__)
UPDATE_RECENTLY_ACCESSED_REMOTES_TS = 60 * 1000
UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
class MediaRepository(object):
@@ -63,96 +68,62 @@ class MediaRepository(object):
self.primary_base_path = hs.config.media_store_path
self.filepaths = MediaFilePaths(self.primary_base_path)
self.backup_base_path = hs.config.backup_media_store_path
self.synchronous_backup_media_store = hs.config.synchronous_backup_media_store
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.thumbnail_requirements = hs.config.thumbnail_requirements
self.remote_media_linearizer = Linearizer(name="media_remote")
self.recently_accessed_remotes = set()
self.recently_accessed_locals = set()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
# List of StorageProviders where we should search for media and
# potentially upload to.
storage_providers = []
for clz, provider_config, wrapper_config in hs.config.media_storage_providers:
backend = clz(hs, provider_config)
provider = StorageProviderWrapper(
backend,
store_local=wrapper_config.store_local,
store_remote=wrapper_config.store_remote,
store_synchronous=wrapper_config.store_synchronous,
)
storage_providers.append(provider)
self.media_storage = MediaStorage(
self.primary_base_path, self.filepaths, storage_providers,
)
self.clock.looping_call(
self._update_recently_accessed_remotes,
UPDATE_RECENTLY_ACCESSED_REMOTES_TS
self._update_recently_accessed,
UPDATE_RECENTLY_ACCESSED_TS,
)
@defer.inlineCallbacks
def _update_recently_accessed_remotes(self):
media = self.recently_accessed_remotes
def _update_recently_accessed(self):
remote_media = self.recently_accessed_remotes
self.recently_accessed_remotes = set()
local_media = self.recently_accessed_locals
self.recently_accessed_locals = set()
yield self.store.update_cached_last_access_time(
media, self.clock.time_msec()
local_media, remote_media, self.clock.time_msec()
)
@staticmethod
def _makedirs(filepath):
dirname = os.path.dirname(filepath)
if not os.path.exists(dirname):
os.makedirs(dirname)
@staticmethod
def _write_file_synchronously(source, fname):
"""Write `source` to the path `fname` synchronously. Should be called
from a thread.
def mark_recently_accessed(self, server_name, media_id):
"""Mark the given media as recently accessed.
Args:
source: A file like object to be written
fname (str): Path to write to
server_name (str|None): Origin server of media, or None if local
media_id (str): The media ID of the content
"""
MediaRepository._makedirs(fname)
source.seek(0) # Ensure we read from the start of the file
with open(fname, "wb") as f:
shutil.copyfileobj(source, f)
@defer.inlineCallbacks
def write_to_file_and_backup(self, source, path):
"""Write `source` to the on disk media store, and also the backup store
if configured.
Args:
source: A file like object that should be written
path (str): Relative path to write file to
Returns:
Deferred[str]: the file path written to in the primary media store
"""
fname = os.path.join(self.primary_base_path, path)
# Write to the main repository
yield make_deferred_yieldable(threads.deferToThread(
self._write_file_synchronously, source, fname,
))
# Write to backup repository
yield self.copy_to_backup(path)
defer.returnValue(fname)
@defer.inlineCallbacks
def copy_to_backup(self, path):
"""Copy a file from the primary to backup media store, if configured.
Args:
path(str): Relative path to write file to
"""
if self.backup_base_path:
primary_fname = os.path.join(self.primary_base_path, path)
backup_fname = os.path.join(self.backup_base_path, path)
# We can either wait for successful writing to the backup repository
# or write in the background and immediately return
if self.synchronous_backup_media_store:
yield make_deferred_yieldable(threads.deferToThread(
shutil.copyfile, primary_fname, backup_fname,
))
else:
preserve_fn(threads.deferToThread)(
shutil.copyfile, primary_fname, backup_fname,
)
if server_name:
self.recently_accessed_remotes.add((server_name, media_id))
else:
self.recently_accessed_locals.add(media_id)
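mark_recently_accessed() together with _update_recently_accessed() above batches accesses in memory and flushes them to the database on a timer, so a burst of downloads becomes a single storage call per interval. A standalone sketch of that batching pattern (toy class, synchronous, not part of the diff):
import time

class RecentAccessBatcher(object):
    """Toy stand-in for the looping-call batching used by MediaRepository."""

    def __init__(self, flush_cb):
        self._pending = set()
        self._flush_cb = flush_cb

    def mark(self, key):
        self._pending.add(key)

    def flush(self):
        # Called periodically, e.g. every UPDATE_RECENTLY_ACCESSED_TS ms.
        pending, self._pending = self._pending, set()
        if pending:
            self._flush_cb(pending, int(time.time() * 1000))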
@defer.inlineCallbacks
def create_content(self, media_type, upload_name, content, content_length,
@@ -171,10 +142,13 @@ class MediaRepository(object):
"""
media_id = random_string(24)
fname = yield self.write_to_file_and_backup(
content, self.filepaths.local_media_filepath_rel(media_id)
file_info = FileInfo(
server_name=None,
file_id=media_id,
)
fname = yield self.media_storage.store_file(content, file_info)
logger.info("Stored local media in file %r", fname)
yield self.store.store_local_media(
@@ -185,134 +159,275 @@ class MediaRepository(object):
media_length=content_length,
user_id=auth_user,
)
media_info = {
"media_type": media_type,
"media_length": content_length,
}
yield self._generate_thumbnails(None, media_id, media_info)
yield self._generate_thumbnails(
None, media_id, media_id, media_type,
)
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
@defer.inlineCallbacks
def get_remote_media(self, server_name, media_id):
def get_local_media(self, request, media_id, name):
"""Responds to reqests for local media, if exists, or returns 404.
Args:
request(twisted.web.http.Request)
media_id (str): The media ID of the content. (This is the same as
the file_id for local content.)
name (str|None): Optional name that, if specified, will be used as
the filename in the Content-Disposition header of the response.
Returns:
Deferred: Resolves once a response has successfully been written
to request
"""
media_info = yield self.store.get_local_media(media_id)
if not media_info or media_info["quarantined_by"]:
respond_404(request)
return
self.mark_recently_accessed(None, media_id)
media_type = media_info["media_type"]
media_length = media_info["media_length"]
upload_name = name if name else media_info["upload_name"]
url_cache = media_info["url_cache"]
file_info = FileInfo(
None, media_id,
url_cache=url_cache,
)
responder = yield self.media_storage.fetch_media(file_info)
yield respond_with_responder(
request, responder, media_type, media_length, upload_name,
)
@defer.inlineCallbacks
def get_remote_media(self, request, server_name, media_id, name):
"""Respond to requests for remote media.
Args:
request(twisted.web.http.Request)
server_name (str): Remote server_name where the media originated.
media_id (str): The media ID of the content (as defined by the
remote server).
name (str|None): Optional name that, if specified, will be used as
the filename in the Content-Disposition header of the response.
Returns:
Deferred: Resolves once a response has successfully been written
to request
"""
if (
self.federation_domain_whitelist is not None and
server_name not in self.federation_domain_whitelist
):
raise FederationDeniedError(server_name)
self.mark_recently_accessed(server_name, media_id)
# We linearize here to ensure that we don't try and download remote
# media multiple times concurrently
key = (server_name, media_id)
with (yield self.remote_media_linearizer.queue(key)):
media_info = yield self._get_remote_media_impl(server_name, media_id)
responder, media_info = yield self._get_remote_media_impl(
server_name, media_id,
)
# We deliberately stream the file outside the lock
if responder:
media_type = media_info["media_type"]
media_length = media_info["media_length"]
upload_name = name if name else media_info["upload_name"]
yield respond_with_responder(
request, responder, media_type, media_length, upload_name,
)
else:
respond_404(request)
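A standalone sketch of the whitelist rule applied in get_remote_media() and get_remote_media_info(): a whitelist of None means no restriction, otherwise unknown origin servers are rejected (DeniedError is a stand-in for FederationDeniedError):
class DeniedError(Exception):
    """Stand-in for FederationDeniedError in this sketch."""

def check_whitelisted(federation_domain_whitelist, server_name):
    if (
        federation_domain_whitelist is not None and
        server_name not in federation_domain_whitelist
    ):
        raise DeniedError(server_name)

check_whitelisted(None, "example.org")                    # no whitelist: allowed
check_whitelisted(["friend.example"], "friend.example")   # whitelisted: allowed
# check_whitelisted(["friend.example"], "other.example")  # would raise DeniedError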
@defer.inlineCallbacks
def get_remote_media_info(self, server_name, media_id):
"""Gets the media info associated with the remote file, downloading
if necessary.
Args:
server_name (str): Remote server_name where the media originated.
media_id (str): The media ID of the content (as defined by the
remote server).
Returns:
Deferred[dict]: The media_info of the file
"""
if (
self.federation_domain_whitelist is not None and
server_name not in self.federation_domain_whitelist
):
raise FederationDeniedError(server_name)
# We linearize here to ensure that we don't try and download remote
# media multiple times concurrently
key = (server_name, media_id)
with (yield self.remote_media_linearizer.queue(key)):
responder, media_info = yield self._get_remote_media_impl(
server_name, media_id,
)
# Ensure we actually use the responder so that it releases resources
if responder:
with responder:
pass
defer.returnValue(media_info)
@defer.inlineCallbacks
def _get_remote_media_impl(self, server_name, media_id):
"""Looks for media in local cache, if not there then attempt to
download from remote server.
Args:
server_name (str): Remote server_name where the media originated.
media_id (str): The media ID of the content (as defined by the
remote server).
Returns:
Deferred[(Responder, media_info)]
"""
media_info = yield self.store.get_cached_remote_media(
server_name, media_id
)
if not media_info:
media_info = yield self._download_remote_file(
server_name, media_id
)
elif media_info["quarantined_by"]:
raise NotFoundError()
# file_id is the ID we use to track the file locally. If we've already
seen the file then reuse the existing ID, otherwise generate a new
# one.
if media_info:
file_id = media_info["filesystem_id"]
else:
self.recently_accessed_remotes.add((server_name, media_id))
yield self.store.update_cached_last_access_time(
[(server_name, media_id)], self.clock.time_msec()
)
defer.returnValue(media_info)
file_id = random_string(24)
file_info = FileInfo(server_name, file_id)
# If we have an entry in the DB, try and look for it
if media_info:
if media_info["quarantined_by"]:
logger.info("Media is quarantined")
raise NotFoundError()
responder = yield self.media_storage.fetch_media(file_info)
if responder:
defer.returnValue((responder, media_info))
# Failed to find the file anywhere, lets download it.
media_info = yield self._download_remote_file(
server_name, media_id, file_id
)
responder = yield self.media_storage.fetch_media(file_info)
defer.returnValue((responder, media_info))
@defer.inlineCallbacks
def _download_remote_file(self, server_name, media_id):
file_id = random_string(24)
def _download_remote_file(self, server_name, media_id, file_id):
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.
fpath = self.filepaths.remote_media_filepath_rel(
server_name, file_id
Args:
server_name (str): Originating server
media_id (str): The media ID of the content (as defined by the
remote server). This is different than the file_id, which is
locally generated.
file_id (str): Local file ID
Returns:
Deferred[MediaInfo]
"""
file_info = FileInfo(
server_name=server_name,
file_id=file_id,
)
fname = os.path.join(self.primary_base_path, fpath)
self._makedirs(fname)
try:
with open(fname, "wb") as f:
request_path = "/".join((
"/_matrix/media/v1/download", server_name, media_id,
))
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
request_path = "/".join((
"/_matrix/media/v1/download", server_name, media_id,
))
try:
length, headers = yield self.client.get_file(
server_name, request_path, output_stream=f,
max_size=self.max_upload_size, args={
# tell the remote server to 404 if it doesn't
# recognise the server_name, to make sure we don't
# end up with a routing loop.
"allow_remote": "false",
}
)
except twisted.internet.error.DNSLookupError as e:
logger.warn("HTTP error fetching remote media %s/%s: %r",
server_name, media_id, e)
raise NotFoundError()
except HttpResponseException as e:
logger.warn("HTTP error fetching remote media %s/%s: %s",
server_name, media_id, e.response)
if e.code == twisted.web.http.NOT_FOUND:
raise SynapseError.from_http_response_exception(e)
raise SynapseError(502, "Failed to fetch remote media")
except SynapseError:
logger.exception("Failed to fetch remote media %s/%s",
server_name, media_id)
raise
except NotRetryingDestination:
logger.warn("Not retrying destination %r", server_name)
raise SynapseError(502, "Failed to fetch remote media")
except Exception:
logger.exception("Failed to fetch remote media %s/%s",
server_name, media_id)
raise SynapseError(502, "Failed to fetch remote media")
yield finish()
media_type = headers["Content-Type"][0]
time_now_ms = self.clock.time_msec()
content_disposition = headers.get("Content-Disposition", None)
if content_disposition:
_, params = cgi.parse_header(content_disposition[0],)
upload_name = None
# First check if there is a valid UTF-8 filename
upload_name_utf8 = params.get("filename*", None)
if upload_name_utf8:
if upload_name_utf8.lower().startswith("utf-8''"):
upload_name = upload_name_utf8[7:]
# If there isn't check for an ascii name.
if not upload_name:
upload_name_ascii = params.get("filename", None)
if upload_name_ascii and is_ascii(upload_name_ascii):
upload_name = upload_name_ascii
if upload_name:
upload_name = urlparse.unquote(upload_name)
try:
length, headers = yield self.client.get_file(
server_name, request_path, output_stream=f,
max_size=self.max_upload_size, args={
# tell the remote server to 404 if it doesn't
# recognise the server_name, to make sure we don't
# end up with a routing loop.
"allow_remote": "false",
}
)
except twisted.internet.error.DNSLookupError as e:
logger.warn("HTTP error fetching remote media %s/%s: %r",
server_name, media_id, e)
raise NotFoundError()
upload_name = upload_name.decode("utf-8")
except UnicodeDecodeError:
upload_name = None
else:
upload_name = None
except HttpResponseException as e:
logger.warn("HTTP error fetching remote media %s/%s: %s",
server_name, media_id, e.response)
if e.code == twisted.web.http.NOT_FOUND:
raise SynapseError.from_http_response_exception(e)
raise SynapseError(502, "Failed to fetch remote media")
logger.info("Stored remote media in file %r", fname)
except SynapseError:
logger.exception("Failed to fetch remote media %s/%s",
server_name, media_id)
raise
except NotRetryingDestination:
logger.warn("Not retrying destination %r", server_name)
raise SynapseError(502, "Failed to fetch remote media")
except Exception:
logger.exception("Failed to fetch remote media %s/%s",
server_name, media_id)
raise SynapseError(502, "Failed to fetch remote media")
yield self.copy_to_backup(fpath)
media_type = headers["Content-Type"][0]
time_now_ms = self.clock.time_msec()
content_disposition = headers.get("Content-Disposition", None)
if content_disposition:
_, params = cgi.parse_header(content_disposition[0],)
upload_name = None
# First check if there is a valid UTF-8 filename
upload_name_utf8 = params.get("filename*", None)
if upload_name_utf8:
if upload_name_utf8.lower().startswith("utf-8''"):
upload_name = upload_name_utf8[7:]
# If there isn't check for an ascii name.
if not upload_name:
upload_name_ascii = params.get("filename", None)
if upload_name_ascii and is_ascii(upload_name_ascii):
upload_name = upload_name_ascii
if upload_name:
upload_name = urlparse.unquote(upload_name)
try:
upload_name = upload_name.decode("utf-8")
except UnicodeDecodeError:
upload_name = None
else:
upload_name = None
logger.info("Stored remote media in file %r", fname)
yield self.store.store_cached_remote_media(
origin=server_name,
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
upload_name=upload_name,
media_length=length,
filesystem_id=file_id,
)
except Exception:
os.remove(fname)
raise
yield self.store.store_cached_remote_media(
origin=server_name,
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
upload_name=upload_name,
media_length=length,
filesystem_id=file_id,
)
media_info = {
"media_type": media_type,
@@ -323,7 +438,7 @@ class MediaRepository(object):
}
yield self._generate_thumbnails(
server_name, media_id, media_info
server_name, media_id, file_id, media_type,
)
defer.returnValue(media_info)
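The upload_name handling above prefers an RFC 5987 filename* parameter and falls back to a plain ASCII filename. A standalone sketch of that parsing (cgi.parse_header is the same stdlib helper the code uses; the example header is invented):
import cgi
try:
    from urllib import unquote         # Python 2, matching the code above
except ImportError:
    from urllib.parse import unquote   # Python 3 fallback for this sketch

def parse_upload_name(content_disposition):
    _, params = cgi.parse_header(content_disposition)
    upload_name = None
    utf8_name = params.get("filename*")
    if utf8_name and utf8_name.lower().startswith("utf-8''"):
        upload_name = utf8_name[7:]
    if not upload_name:
        upload_name = params.get("filename")
    return unquote(upload_name) if upload_name else None

print(parse_upload_name("inline; filename*=utf-8''annual%20report.pdf"))
# annual report.pdf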
@@ -357,8 +472,10 @@ class MediaRepository(object):
@defer.inlineCallbacks
def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
t_method, t_type):
input_path = self.filepaths.local_media_filepath(media_id)
t_method, t_type, url_cache):
input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
None, media_id, url_cache=url_cache,
))
thumbnailer = Thumbnailer(input_path)
t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
@@ -368,11 +485,19 @@ class MediaRepository(object):
if t_byte_source:
try:
output_path = yield self.write_to_file_and_backup(
t_byte_source,
self.filepaths.local_media_thumbnail_rel(
media_id, t_width, t_height, t_type, t_method
)
file_info = FileInfo(
server_name=None,
file_id=media_id,
url_cache=url_cache,
thumbnail=True,
thumbnail_width=t_width,
thumbnail_height=t_height,
thumbnail_method=t_method,
thumbnail_type=t_type,
)
output_path = yield self.media_storage.store_file(
t_byte_source, file_info,
)
finally:
t_byte_source.close()
@@ -390,7 +515,9 @@ class MediaRepository(object):
@defer.inlineCallbacks
def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
t_width, t_height, t_method, t_type):
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
server_name, file_id, url_cache=False,
))
thumbnailer = Thumbnailer(input_path)
t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
@@ -400,11 +527,18 @@ class MediaRepository(object):
if t_byte_source:
try:
output_path = yield self.write_to_file_and_backup(
t_byte_source,
self.filepaths.remote_media_thumbnail_rel(
server_name, file_id, t_width, t_height, t_type, t_method
)
file_info = FileInfo(
server_name=server_name,
file_id=media_id,
thumbnail=True,
thumbnail_width=t_width,
thumbnail_height=t_height,
thumbnail_method=t_method,
thumbnail_type=t_type,
)
output_path = yield self.media_storage.store_file(
t_byte_source, file_info,
)
finally:
t_byte_source.close()
@@ -421,31 +555,29 @@ class MediaRepository(object):
defer.returnValue(output_path)
@defer.inlineCallbacks
def _generate_thumbnails(self, server_name, media_id, media_info, url_cache=False):
def _generate_thumbnails(self, server_name, media_id, file_id, media_type,
url_cache=False):
"""Generate and store thumbnails for an image.
Args:
server_name(str|None): The server name if remote media, else None if local
media_id(str)
media_info(dict)
url_cache(bool): If we are thumbnailing images downloaded for the URL cache,
server_name (str|None): The server name if remote media, else None if local
media_id (str): The media ID of the content. (This is the same as
the file_id for local content)
file_id (str): Local file ID
media_type (str): The content type of the file
url_cache (bool): If we are thumbnailing images downloaded for the URL cache,
used exclusively by the url previewer
Returns:
Deferred[dict]: Dict with "width" and "height" keys of original image
"""
media_type = media_info["media_type"]
file_id = media_info.get("filesystem_id")
requirements = self._get_thumbnail_requirements(media_type)
if not requirements:
return
if server_name:
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
elif url_cache:
input_path = self.filepaths.url_cache_filepath(media_id)
else:
input_path = self.filepaths.local_media_filepath(media_id)
input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
server_name, file_id, url_cache=url_cache,
))
thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width
@@ -472,20 +604,6 @@ class MediaRepository(object):
# Now we generate the thumbnails for each dimension, store it
for (t_width, t_height, t_type), t_method in thumbnails.iteritems():
# Work out the correct file name for thumbnail
if server_name:
file_path = self.filepaths.remote_media_thumbnail_rel(
server_name, file_id, t_width, t_height, t_type, t_method
)
elif url_cache:
file_path = self.filepaths.url_cache_thumbnail_rel(
media_id, t_width, t_height, t_type, t_method
)
else:
file_path = self.filepaths.local_media_thumbnail_rel(
media_id, t_width, t_height, t_type, t_method
)
# Generate the thumbnail
if t_method == "crop":
t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
@@ -505,9 +623,19 @@ class MediaRepository(object):
continue
try:
# Write to disk
output_path = yield self.write_to_file_and_backup(
t_byte_source, file_path,
file_info = FileInfo(
server_name=server_name,
file_id=file_id,
thumbnail=True,
thumbnail_width=t_width,
thumbnail_height=t_height,
thumbnail_method=t_method,
thumbnail_type=t_type,
url_cache=url_cache,
)
output_path = yield self.media_storage.store_file(
t_byte_source, file_info,
)
finally:
t_byte_source.close()
@@ -620,7 +748,11 @@ class MediaRepositoryResource(Resource):
self.putChild("upload", UploadResource(hs, media_repo))
self.putChild("download", DownloadResource(hs, media_repo))
self.putChild("thumbnail", ThumbnailResource(hs, media_repo))
self.putChild("thumbnail", ThumbnailResource(
hs, media_repo, media_repo.media_storage,
))
self.putChild("identicon", IdenticonResource())
if hs.config.url_preview_enabled:
self.putChild("preview_url", PreviewUrlResource(hs, media_repo))
self.putChild("preview_url", PreviewUrlResource(
hs, media_repo, media_repo.media_storage,
))

View File

@@ -0,0 +1,274 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer, threads
from twisted.protocols.basic import FileSender
from ._base import Responder
from synapse.util.file_consumer import BackgroundFileConsumer
from synapse.util.logcontext import make_deferred_yieldable
import contextlib
import os
import logging
import shutil
import sys
logger = logging.getLogger(__name__)
class MediaStorage(object):
"""Responsible for storing/fetching files from local sources.
Args:
local_media_directory (str): Base path where we store media on disk
filepaths (MediaFilePaths)
storage_providers ([StorageProvider]): List of StorageProvider that are
used to fetch and store files.
"""
def __init__(self, local_media_directory, filepaths, storage_providers):
self.local_media_directory = local_media_directory
self.filepaths = filepaths
self.storage_providers = storage_providers
@defer.inlineCallbacks
def store_file(self, source, file_info):
"""Write `source` to the on disk media store, and also any other
configured storage providers
Args:
source: A file like object that should be written
file_info (FileInfo): Info about the file to store
Returns:
Deferred[str]: the file path written to in the primary media store
"""
path = self._file_info_to_path(file_info)
fname = os.path.join(self.local_media_directory, path)
dirname = os.path.dirname(fname)
if not os.path.exists(dirname):
os.makedirs(dirname)
# Write to the main repository
yield make_deferred_yieldable(threads.deferToThread(
_write_file_synchronously, source, fname,
))
# Tell the storage providers about the new file. They'll decide
# if they should upload it and whether to do so synchronously
# or not.
for provider in self.storage_providers:
yield provider.store_file(path, file_info)
defer.returnValue(fname)
@contextlib.contextmanager
def store_into_file(self, file_info):
"""Context manager used to get a file like object to write into, as
described by file_info.
Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
like object that can be written to, fname is the absolute path of file
on disk, and finish_cb is a function that returns a Deferred.
fname can be used to read the contents from after upload, e.g. to
generate thumbnails.
finish_cb must be called and waited on after the file has been
successfully written to. Should not be called if there was an
error.
Args:
file_info (FileInfo): Info about the file to store
Example:
with media_storage.store_into_file(info) as (f, fname, finish_cb):
# .. write into f ...
yield finish_cb()
"""
path = self._file_info_to_path(file_info)
fname = os.path.join(self.local_media_directory, path)
dirname = os.path.dirname(fname)
if not os.path.exists(dirname):
os.makedirs(dirname)
finished_called = [False]
@defer.inlineCallbacks
def finish():
for provider in self.storage_providers:
yield provider.store_file(path, file_info)
finished_called[0] = True
try:
with open(fname, "wb") as f:
yield f, fname, finish
except Exception:
t, v, tb = sys.exc_info()
try:
os.remove(fname)
except Exception:
pass
raise t, v, tb
if not finished_called[0]:
raise Exception("Finished callback not called")
@defer.inlineCallbacks
def fetch_media(self, file_info):
"""Attempts to fetch media described by file_info from the local cache
and configured storage providers.
Args:
file_info (FileInfo)
Returns:
Deferred[Responder|None]: Returns a Responder if the file was found,
otherwise None.
"""
path = self._file_info_to_path(file_info)
local_path = os.path.join(self.local_media_directory, path)
if os.path.exists(local_path):
defer.returnValue(FileResponder(open(local_path, "rb")))
for provider in self.storage_providers:
res = yield provider.fetch(path, file_info)
if res:
defer.returnValue(res)
defer.returnValue(None)
@defer.inlineCallbacks
def ensure_media_is_in_local_cache(self, file_info):
"""Ensures that the given file is in the local cache. Attempts to
download it from storage providers if it isn't.
Args:
file_info (FileInfo)
Returns:
Deferred[str]: Full path to local file
"""
path = self._file_info_to_path(file_info)
local_path = os.path.join(self.local_media_directory, path)
if os.path.exists(local_path):
defer.returnValue(local_path)
dirname = os.path.dirname(local_path)
if not os.path.exists(dirname):
os.makedirs(dirname)
for provider in self.storage_providers:
res = yield provider.fetch(path, file_info)
if res:
with res:
consumer = BackgroundFileConsumer(open(local_path, "wb"))
yield res.write_to_consumer(consumer)
yield consumer.wait()
defer.returnValue(local_path)
raise Exception("file could not be found")
def _file_info_to_path(self, file_info):
"""Converts file_info into a relative path.
The path is suitable for storing files under a directory, e.g. it is used to
store files on the local filesystem under the base media repository directory.
Args:
file_info (FileInfo)
Returns:
str
"""
if file_info.url_cache:
if file_info.thumbnail:
return self.filepaths.url_cache_thumbnail_rel(
media_id=file_info.file_id,
width=file_info.thumbnail_width,
height=file_info.thumbnail_height,
content_type=file_info.thumbnail_type,
method=file_info.thumbnail_method,
)
return self.filepaths.url_cache_filepath_rel(file_info.file_id)
if file_info.server_name:
if file_info.thumbnail:
return self.filepaths.remote_media_thumbnail_rel(
server_name=file_info.server_name,
file_id=file_info.file_id,
width=file_info.thumbnail_width,
height=file_info.thumbnail_height,
content_type=file_info.thumbnail_type,
method=file_info.thumbnail_method
)
return self.filepaths.remote_media_filepath_rel(
file_info.server_name, file_info.file_id,
)
if file_info.thumbnail:
return self.filepaths.local_media_thumbnail_rel(
media_id=file_info.file_id,
width=file_info.thumbnail_width,
height=file_info.thumbnail_height,
content_type=file_info.thumbnail_type,
method=file_info.thumbnail_method
)
return self.filepaths.local_media_filepath_rel(
file_info.file_id,
)
def _write_file_synchronously(source, fname):
"""Write `source` to the path `fname` synchronously. Should be called
from a thread.
Args:
source: A file like object to be written
fname (str): Path to write to
"""
dirname = os.path.dirname(fname)
if not os.path.exists(dirname):
os.makedirs(dirname)
source.seek(0) # Ensure we read from the start of the file
with open(fname, "wb") as f:
shutil.copyfileobj(source, f)
class FileResponder(Responder):
"""Wraps an open file that can be sent to a request.
Args:
open_file (file): A file like object to be streamed to the client,
is closed when finished streaming.
"""
def __init__(self, open_file):
self.open_file = open_file
def write_to_consumer(self, consumer):
return FileSender().beginFileTransfer(self.open_file, consumer)
def __exit__(self, exc_type, exc_val, exc_tb):
self.open_file.close()

View File

@@ -12,11 +12,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cgi
import datetime
import errno
import fnmatch
import itertools
import logging
import os
import re
import shutil
import sys
import traceback
import ujson as json
import urlparse
from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
from twisted.web.resource import Resource
from ._base import FileInfo
from synapse.api.errors import (
SynapseError, Codes,
)
@@ -31,25 +46,13 @@ from synapse.http.server import (
from synapse.util.async import ObservableDeferred
from synapse.util.stringutils import is_ascii
import os
import re
import fnmatch
import cgi
import ujson as json
import urlparse
import itertools
import datetime
import errno
import shutil
import logging
logger = logging.getLogger(__name__)
class PreviewUrlResource(Resource):
isLeaf = True
def __init__(self, hs, media_repo):
def __init__(self, hs, media_repo, media_storage):
Resource.__init__(self)
self.auth = hs.get_auth()
@@ -62,6 +65,7 @@ class PreviewUrlResource(Resource):
self.client = SpiderHttpClient(hs)
self.media_repo = media_repo
self.primary_base_path = media_repo.primary_base_path
self.media_storage = media_storage
self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
@@ -182,8 +186,10 @@ class PreviewUrlResource(Resource):
logger.debug("got media_info of '%s'" % media_info)
if _is_media(media_info['media_type']):
file_id = media_info['filesystem_id']
dims = yield self.media_repo._generate_thumbnails(
None, media_info['filesystem_id'], media_info, url_cache=True,
None, file_id, file_id, media_info["media_type"],
url_cache=True,
)
og = {
@@ -228,8 +234,10 @@ class PreviewUrlResource(Resource):
if _is_media(image_info['media_type']):
# TODO: make sure we don't choke on white-on-transparent images
file_id = image_info['filesystem_id']
dims = yield self.media_repo._generate_thumbnails(
None, image_info['filesystem_id'], image_info, url_cache=True,
None, file_id, file_id, image_info["media_type"],
url_cache=True,
)
if dims:
og["og:image:width"] = dims['width']
@@ -273,21 +281,34 @@ class PreviewUrlResource(Resource):
file_id = datetime.date.today().isoformat() + '_' + random_string(16)
fpath = self.filepaths.url_cache_filepath_rel(file_id)
fname = os.path.join(self.primary_base_path, fpath)
self.media_repo._makedirs(fname)
file_info = FileInfo(
server_name=None,
file_id=file_id,
url_cache=True,
)
try:
with open(fname, "wb") as f:
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
try:
logger.debug("Trying to get url '%s'" % url)
length, headers, uri, code = yield self.client.get_file(
url, output_stream=f, max_size=self.max_spider_size,
)
except Exception as e:
# FIXME: pass through 404s and other error messages nicely
logger.warn("Error downloading %s: %r", url, e)
raise SynapseError(
500, "Failed to download content: %s" % (
traceback.format_exception_only(sys.exc_type, e),
),
Codes.UNKNOWN,
)
yield finish()
yield self.media_repo.copy_to_backup(fpath)
media_type = headers["Content-Type"][0]
try:
if "Content-Type" in headers:
media_type = headers["Content-Type"][0]
else:
media_type = "application/octet-stream"
time_now_ms = self.clock.time_msec()
content_disposition = headers.get("Content-Disposition", None)
@@ -327,11 +348,11 @@ class PreviewUrlResource(Resource):
)
except Exception as e:
os.remove(fname)
raise SynapseError(
500, ("Failed to download content: %s" % e),
Codes.UNKNOWN
)
logger.error("Error handling downloaded %s: %r", url, e)
# TODO: we really ought to delete the downloaded file in this
# case, since we won't have recorded it in the db, and will
# therefore not expire it.
raise
defer.returnValue({
"media_type": media_type,

View File

@@ -0,0 +1,140 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer, threads
from .media_storage import FileResponder
from synapse.config._base import Config
from synapse.util.logcontext import preserve_fn
import logging
import os
import shutil
logger = logging.getLogger(__name__)
class StorageProvider(object):
"""A storage provider is a service that can store uploaded media and
retrieve it.
"""
def store_file(self, path, file_info):
"""Store the file described by file_info. The actual contents can be
retrieved by reading the file in file_info.upload_path.
Args:
path (str): Relative path of file in local cache
file_info (FileInfo)
Returns:
Deferred
"""
pass
def fetch(self, path, file_info):
"""Attempt to fetch the file described by file_info and stream it
into writer.
Args:
path (str): Relative path of file in local cache
file_info (FileInfo)
Returns:
Deferred(Responder): Returns a Responder if the provider has the file,
otherwise returns None.
"""
pass
class StorageProviderWrapper(StorageProvider):
"""Wraps a storage provider and provides various config options
Args:
backend (StorageProvider)
store_local (bool): Whether to store new local files or not.
store_synchronous (bool): Whether to wait for the file to be successfully
uploaded, or to do the upload in the background.
store_remote (bool): Whether remote media should be uploaded
"""
def __init__(self, backend, store_local, store_synchronous, store_remote):
self.backend = backend
self.store_local = store_local
self.store_synchronous = store_synchronous
self.store_remote = store_remote
def store_file(self, path, file_info):
if not file_info.server_name and not self.store_local:
return defer.succeed(None)
if file_info.server_name and not self.store_remote:
return defer.succeed(None)
if self.store_synchronous:
return self.backend.store_file(path, file_info)
else:
# TODO: Handle errors.
preserve_fn(self.backend.store_file)(path, file_info)
return defer.succeed(None)
def fetch(self, path, file_info):
return self.backend.fetch(path, file_info)
class FileStorageProviderBackend(StorageProvider):
"""A storage provider that stores files in a directory on a filesystem.
Args:
hs (HomeServer)
config: The config returned by `parse_config`.
"""
def __init__(self, hs, config):
self.cache_directory = hs.config.media_store_path
self.base_directory = config
def store_file(self, path, file_info):
"""See StorageProvider.store_file"""
primary_fname = os.path.join(self.cache_directory, path)
backup_fname = os.path.join(self.base_directory, path)
dirname = os.path.dirname(backup_fname)
if not os.path.exists(dirname):
os.makedirs(dirname)
return threads.deferToThread(
shutil.copyfile, primary_fname, backup_fname,
)
def fetch(self, path, file_info):
"""See StorageProvider.fetch"""
backup_fname = os.path.join(self.base_directory, path)
if os.path.isfile(backup_fname):
return FileResponder(open(backup_fname, "rb"))
@staticmethod
def parse_config(config):
"""Called on startup to parse config supplied. This should parse
the config and raise if there is a problem.
The returned value is passed into the constructor.
In this case we only care about a single param, the directory, so let's
just pull that out.
"""
return Config.ensure_directory(config["directory"])
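To make the StorageProvider interface above concrete, here is a hypothetical no-op backend (illustrative only): it acknowledges every store request and never serves a file, so fetch_media() falls through to other providers.
class NullStorageProvider(StorageProvider):
    """Hypothetical backend that accepts uploads but stores nothing."""

    def __init__(self, hs, config):
        pass

    def store_file(self, path, file_info):
        # `defer` is already imported at the top of this module.
        return defer.succeed(None)  # pretend the upload completed

    def fetch(self, path, file_info):
        return None  # we never have the file

    @staticmethod
    def parse_config(config):
        return config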

View File

@@ -14,7 +14,10 @@
# limitations under the License.
from ._base import parse_media_id, respond_404, respond_with_file
from ._base import (
parse_media_id, respond_404, respond_with_file, FileInfo,
respond_with_responder,
)
from twisted.web.resource import Resource
from synapse.http.servlet import parse_string, parse_integer
from synapse.http.server import request_handler, set_cors_headers
@@ -30,12 +33,12 @@ logger = logging.getLogger(__name__)
class ThumbnailResource(Resource):
isLeaf = True
def __init__(self, hs, media_repo):
def __init__(self, hs, media_repo, media_storage):
Resource.__init__(self)
self.store = hs.get_datastore()
self.filepaths = media_repo.filepaths
self.media_repo = media_repo
self.media_storage = media_storage
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.server_name = hs.hostname
self.version_string = hs.version_string
@@ -64,6 +67,7 @@ class ThumbnailResource(Resource):
yield self._respond_local_thumbnail(
request, media_id, width, height, method, m_type
)
self.media_repo.mark_recently_accessed(None, media_id)
else:
if self.dynamic_thumbnails:
yield self._select_or_generate_remote_thumbnail(
@@ -75,20 +79,20 @@ class ThumbnailResource(Resource):
request, server_name, media_id,
width, height, method, m_type
)
self.media_repo.mark_recently_accessed(server_name, media_id)
@defer.inlineCallbacks
def _respond_local_thumbnail(self, request, media_id, width, height,
method, m_type):
media_info = yield self.store.get_local_media(media_id)
if not media_info or media_info["quarantined_by"]:
if not media_info:
respond_404(request)
return
if media_info["quarantined_by"]:
logger.info("Media is quarantined")
respond_404(request)
return
# if media_info["media_type"] == "image/svg+xml":
# file_path = self.filepaths.local_media_filepath(media_id)
# yield respond_with_file(request, media_info["media_type"], file_path)
# return
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
@@ -96,42 +100,39 @@ class ThumbnailResource(Resource):
thumbnail_info = self._select_thumbnail(
width, height, method, m_type, thumbnail_infos
)
t_width = thumbnail_info["thumbnail_width"]
t_height = thumbnail_info["thumbnail_height"]
t_type = thumbnail_info["thumbnail_type"]
t_method = thumbnail_info["thumbnail_method"]
if media_info["url_cache"]:
# TODO: Check the file still exists, if it doesn't we can redownload
# it from the url `media_info["url_cache"]`
file_path = self.filepaths.url_cache_thumbnail(
media_id, t_width, t_height, t_type, t_method,
)
else:
file_path = self.filepaths.local_media_thumbnail(
media_id, t_width, t_height, t_type, t_method,
)
yield respond_with_file(request, t_type, file_path)
else:
yield self._respond_default_thumbnail(
request, media_info, width, height, method, m_type,
file_info = FileInfo(
server_name=None, file_id=media_id,
url_cache=media_info["url_cache"],
thumbnail=True,
thumbnail_width=thumbnail_info["thumbnail_width"],
thumbnail_height=thumbnail_info["thumbnail_height"],
thumbnail_type=thumbnail_info["thumbnail_type"],
thumbnail_method=thumbnail_info["thumbnail_method"],
)
t_type = file_info.thumbnail_type
t_length = thumbnail_info["thumbnail_length"]
responder = yield self.media_storage.fetch_media(file_info)
yield respond_with_responder(request, responder, t_type, t_length)
else:
logger.info("Couldn't find any generated thumbnails")
respond_404(request)
@defer.inlineCallbacks
def _select_or_generate_local_thumbnail(self, request, media_id, desired_width,
desired_height, desired_method,
desired_type):
media_info = yield self.store.get_local_media(media_id)
if not media_info or media_info["quarantined_by"]:
if not media_info:
respond_404(request)
return
if media_info["quarantined_by"]:
logger.info("Media is quarantined")
respond_404(request)
return
# if media_info["media_type"] == "image/svg+xml":
# file_path = self.filepaths.local_media_filepath(media_id)
# yield respond_with_file(request, media_info["media_type"], file_path)
# return
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
for info in thumbnail_infos:
@@ -141,46 +142,43 @@ class ThumbnailResource(Resource):
t_type = info["thumbnail_type"] == desired_type
if t_w and t_h and t_method and t_type:
if media_info["url_cache"]:
# TODO: Check the file still exists, if it doesn't we can redownload
# it from the url `media_info["url_cache"]`
file_path = self.filepaths.url_cache_thumbnail(
media_id, desired_width, desired_height, desired_type,
desired_method,
)
else:
file_path = self.filepaths.local_media_thumbnail(
media_id, desired_width, desired_height, desired_type,
desired_method,
)
yield respond_with_file(request, desired_type, file_path)
return
file_info = FileInfo(
server_name=None, file_id=media_id,
url_cache=media_info["url_cache"],
thumbnail=True,
thumbnail_width=info["thumbnail_width"],
thumbnail_height=info["thumbnail_height"],
thumbnail_type=info["thumbnail_type"],
thumbnail_method=info["thumbnail_method"],
)
logger.debug("We don't have a local thumbnail of that size. Generating")
t_type = file_info.thumbnail_type
t_length = info["thumbnail_length"]
responder = yield self.media_storage.fetch_media(file_info)
if responder:
yield respond_with_responder(request, responder, t_type, t_length)
return
logger.debug("We don't have a thumbnail of that size. Generating")
# Okay, so we generate one.
file_path = yield self.media_repo.generate_local_exact_thumbnail(
media_id, desired_width, desired_height, desired_method, desired_type
media_id, desired_width, desired_height, desired_method, desired_type,
url_cache=media_info["url_cache"],
)
if file_path:
yield respond_with_file(request, desired_type, file_path)
else:
yield self._respond_default_thumbnail(
request, media_info, desired_width, desired_height,
desired_method, desired_type,
)
logger.warn("Failed to generate thumbnail")
respond_404(request)
@defer.inlineCallbacks
def _select_or_generate_remote_thumbnail(self, request, server_name, media_id,
desired_width, desired_height,
desired_method, desired_type):
media_info = yield self.media_repo.get_remote_media(server_name, media_id)
# if media_info["media_type"] == "image/svg+xml":
# file_path = self.filepaths.remote_media_filepath(server_name, media_id)
# yield respond_with_file(request, media_info["media_type"], file_path)
# return
media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
server_name, media_id,
@@ -195,14 +193,24 @@ class ThumbnailResource(Resource):
t_type = info["thumbnail_type"] == desired_type
if t_w and t_h and t_method and t_type:
file_path = self.filepaths.remote_media_thumbnail(
server_name, file_id, desired_width, desired_height,
desired_type, desired_method,
file_info = FileInfo(
server_name=server_name, file_id=media_info["filesystem_id"],
thumbnail=True,
thumbnail_width=info["thumbnail_width"],
thumbnail_height=info["thumbnail_height"],
thumbnail_type=info["thumbnail_type"],
thumbnail_method=info["thumbnail_method"],
)
yield respond_with_file(request, desired_type, file_path)
return
logger.debug("We don't have a local thumbnail of that size. Generating")
t_type = file_info.thumbnail_type
t_length = info["thumbnail_length"]
responder = yield self.media_storage.fetch_media(file_info)
if responder:
yield respond_with_responder(request, responder, t_type, t_length)
return
logger.debug("We don't have a thumbnail of that size. Generating")
# Okay, so we generate one.
file_path = yield self.media_repo.generate_remote_exact_thumbnail(
@@ -213,22 +221,16 @@ class ThumbnailResource(Resource):
if file_path:
yield respond_with_file(request, desired_type, file_path)
else:
yield self._respond_default_thumbnail(
request, media_info, desired_width, desired_height,
desired_method, desired_type,
)
logger.warn("Failed to generate thumbnail")
respond_404(request)
@defer.inlineCallbacks
def _respond_remote_thumbnail(self, request, server_name, media_id, width,
height, method, m_type):
# TODO: Don't download the whole remote file
# We should proxy the thumbnail from the remote server instead.
media_info = yield self.media_repo.get_remote_media(server_name, media_id)
# if media_info["media_type"] == "image/svg+xml":
# file_path = self.filepaths.remote_media_filepath(server_name, media_id)
# yield respond_with_file(request, media_info["media_type"], file_path)
# return
# We should proxy the thumbnail from the remote server instead of
# downloading the remote file and generating our own thumbnails.
media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
server_name, media_id,
@@ -238,59 +240,23 @@ class ThumbnailResource(Resource):
thumbnail_info = self._select_thumbnail(
width, height, method, m_type, thumbnail_infos
)
t_width = thumbnail_info["thumbnail_width"]
t_height = thumbnail_info["thumbnail_height"]
t_type = thumbnail_info["thumbnail_type"]
t_method = thumbnail_info["thumbnail_method"]
file_id = thumbnail_info["filesystem_id"]
file_info = FileInfo(
server_name=server_name, file_id=media_info["filesystem_id"],
thumbnail=True,
thumbnail_width=thumbnail_info["thumbnail_width"],
thumbnail_height=thumbnail_info["thumbnail_height"],
thumbnail_type=thumbnail_info["thumbnail_type"],
thumbnail_method=thumbnail_info["thumbnail_method"],
)
t_type = file_info.thumbnail_type
t_length = thumbnail_info["thumbnail_length"]
file_path = self.filepaths.remote_media_thumbnail(
server_name, file_id, t_width, t_height, t_type, t_method,
)
yield respond_with_file(request, t_type, file_path, t_length)
responder = yield self.media_storage.fetch_media(file_info)
yield respond_with_responder(request, responder, t_type, t_length)
else:
yield self._respond_default_thumbnail(
request, media_info, width, height, method, m_type,
)
@defer.inlineCallbacks
def _respond_default_thumbnail(self, request, media_info, width, height,
method, m_type):
# XXX: how is this meant to work? store.get_default_thumbnails
# appears to always return [] so won't this always 404?
media_type = media_info["media_type"]
top_level_type = media_type.split("/")[0]
sub_type = media_type.split("/")[-1].split(";")[0]
thumbnail_infos = yield self.store.get_default_thumbnails(
top_level_type, sub_type,
)
if not thumbnail_infos:
thumbnail_infos = yield self.store.get_default_thumbnails(
top_level_type, "_default",
)
if not thumbnail_infos:
thumbnail_infos = yield self.store.get_default_thumbnails(
"_default", "_default",
)
if not thumbnail_infos:
logger.info("Failed to find any generated thumbnails")
respond_404(request)
return
thumbnail_info = self._select_thumbnail(
width, height, "crop", m_type, thumbnail_infos
)
t_width = thumbnail_info["thumbnail_width"]
t_height = thumbnail_info["thumbnail_height"]
t_type = thumbnail_info["thumbnail_type"]
t_method = thumbnail_info["thumbnail_method"]
t_length = thumbnail_info["thumbnail_length"]
file_path = self.filepaths.default_thumbnail(
top_level_type, sub_type, t_width, t_height, t_type, t_method,
)
yield respond_with_file(request, t_type, file_path, t_length)
def _select_thumbnail(self, desired_width, desired_height, desired_method,
desired_type, thumbnail_infos):

View File

@@ -55,6 +55,7 @@ from synapse.handlers.read_marker import ReadMarkerHandler
from synapse.handlers.user_directory import UserDirectoryHandler
from synapse.handlers.groups_local import GroupsLocalHandler
from synapse.handlers.profile import ProfileHandler
from synapse.handlers.message import EventCreationHandler
from synapse.groups.groups_server import GroupsServerHandler
from synapse.groups.attestations import GroupAttestionRenewer, GroupAttestationSigning
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
@@ -66,7 +67,7 @@ from synapse.rest.media.v1.media_repository import (
MediaRepository,
MediaRepositoryResource,
)
from synapse.state import StateHandler
from synapse.state import StateHandler, StateResolutionHandler
from synapse.storage import DataStore
from synapse.streams.events import EventSources
from synapse.util import Clock
@@ -102,6 +103,7 @@ class HomeServer(object):
'v1auth',
'auth',
'state_handler',
'state_resolution_handler',
'presence_handler',
'sync_handler',
'typing_handler',
@@ -117,6 +119,7 @@ class HomeServer(object):
'application_service_handler',
'device_message_handler',
'profile_handler',
'event_creation_handler',
'deactivate_account_handler',
'set_password_handler',
'notifier',
@@ -224,6 +227,9 @@ class HomeServer(object):
def build_state_handler(self):
return StateHandler(self)
def build_state_resolution_handler(self):
return StateResolutionHandler(self)
def build_presence_handler(self):
return PresenceHandler(self)
@@ -272,6 +278,9 @@ class HomeServer(object):
def build_profile_handler(self):
return ProfileHandler(self)
def build_event_creation_handler(self):
return EventCreationHandler(self)
def build_deactivate_account_handler(self):
return DeactivateAccountHandler(self)
@@ -307,6 +316,23 @@ class HomeServer(object):
**self.db_config.get("args", {})
)
def get_db_conn(self, run_new_connection=True):
"""Makes a new connection to the database, skipping the db pool
Returns:
Connection: a connection object implementing the PEP-249 spec
"""
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
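The build_* methods above back the matching get_* accessors, so each dependency (state handler, event creation handler, media repository, ...) is constructed once and then reused. A simplified, hypothetical sketch of such a lazily-cached builder pattern (not Synapse's actual implementation):
class LazyBuilders(object):
    """Toy illustration: build_<name>() runs once, the result is cached."""

    def __init__(self):
        self._cache = {}

    def _get(self, name):
        if name not in self._cache:
            self._cache[name] = getattr(self, "build_" + name)()
        return self._cache[name]

    def build_clock(self):
        return object()  # stand-in for a real dependency such as Clock

    def get_clock(self):
        return self._get("clock")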
def build_media_repository_resource(self):
# build the media repo resource. This indirects through the HomeServer
# to ensure that we only have a single instance of

View File

@@ -34,6 +34,9 @@ class HomeServer(object):
def get_state_handler(self) -> synapse.state.StateHandler:
pass
def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler:
pass
def get_deactivate_account_handler(self) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
pass

View File

@@ -58,7 +58,11 @@ class _StateCacheEntry(object):
__slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
def __init__(self, state, state_group, prev_group=None, delta_ids=None):
# dict[(str, str), str] map from (type, state_key) to event_id
self.state = frozendict(state)
# the ID of a state group if one and only one is involved.
# otherwise, None.
self.state_group = state_group
self.prev_group = prev_group
@@ -81,31 +85,19 @@ class _StateCacheEntry(object):
class StateHandler(object):
""" Responsible for doing state conflict resolution.
"""Fetches bits of state from the stores, and does state resolution
where necessary
"""
def __init__(self, hs):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.hs = hs
# dict of set of event_ids -> _StateCacheEntry.
self._state_cache = None
self.resolve_linearizer = Linearizer(name="state_resolve_lock")
self._state_resolution_handler = hs.get_state_resolution_handler()
def start_caching(self):
logger.debug("start_caching")
self._state_cache = ExpiringCache(
cache_name="state_cache",
clock=self.clock,
max_len=SIZE_OF_CACHE,
expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
iterable=True,
reset_expiry_on_get=True,
)
self._state_cache.start()
# TODO: remove this shim
self._state_resolution_handler.start_caching()
@defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key="",
@@ -127,7 +119,7 @@ class StateHandler(object):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_state")
ret = yield self.resolve_state_groups(room_id, latest_event_ids)
ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state
if event_type:
@@ -146,19 +138,27 @@ class StateHandler(object):
defer.returnValue(state)
@defer.inlineCallbacks
def get_current_state_ids(self, room_id, event_type=None, state_key="",
latest_event_ids=None):
def get_current_state_ids(self, room_id, latest_event_ids=None):
"""Get the current state, or the state at a set of events, for a room
Args:
room_id (str):
latest_event_ids (iterable[str]|None): if given, the forward
extremities to resolve. If None, we look them up from the
database (via a cache)
Returns:
Deferred[dict[(str, str), str)]]: the state dict, mapping from
(event_type, state_key) -> event_id
"""
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_state_ids")
ret = yield self.resolve_state_groups(room_id, latest_event_ids)
ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state
if event_type:
defer.returnValue(state.get((event_type, state_key)))
return
defer.returnValue(state)
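For reference, the map returned by get_current_state_ids has this shape (the event IDs below are invented):
current_state_ids = {
    ("m.room.create", ""): "$create:example.com",
    ("m.room.member", "@alice:example.com"): "$alice_join:example.com",
}
# a caller wanting a single entry indexes by (event_type, state_key)
print(current_state_ids.get(("m.room.member", "@alice:example.com")))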
@defer.inlineCallbacks
@@ -166,7 +166,7 @@ class StateHandler(object):
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_user_in_room")
entry = yield self.resolve_state_groups(room_id, latest_event_ids)
entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
joined_users = yield self.store.get_joined_users_from_state(room_id, entry)
defer.returnValue(joined_users)
@@ -175,7 +175,7 @@ class StateHandler(object):
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_hosts_in_room")
entry = yield self.resolve_state_groups(room_id, latest_event_ids)
entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
joined_hosts = yield self.store.get_joined_hosts(room_id, entry)
defer.returnValue(joined_hosts)
@@ -183,8 +183,15 @@ class StateHandler(object):
def compute_event_context(self, event, old_state=None):
"""Build an EventContext structure for the event.
This works out what the current state should be for the event, and
generates a new state group if necessary.
Args:
event (synapse.events.EventBase):
old_state (dict|None): The state at the event if it can't be
calculated from existing events. This is normally only specified
when receiving an event from federation for which we don't have the
prev events, e.g. when backfilling.
Returns:
synapse.events.snapshot.EventContext:
"""
@@ -208,15 +215,22 @@ class StateHandler(object):
context.current_state_ids = {}
context.prev_state_ids = {}
context.prev_state_events = []
context.state_group = self.store.get_next_state_group()
# We don't store state for outliers, so we don't generate a state
# group for it.
context.state_group = None
defer.returnValue(context)
if old_state:
# We already have the state, so we don't need to calculate it.
# Let's just correctly fill out the context and create a
# new state group for it.
context = EventContext()
context.prev_state_ids = {
(s.type, s.state_key): s.event_id for s in old_state
}
context.state_group = self.store.get_next_state_group()
if event.is_state():
key = (event.type, event.state_key)
@@ -229,11 +243,19 @@ class StateHandler(object):
else:
context.current_state_ids = context.prev_state_ids
context.state_group = yield self.store.store_state_group(
event.event_id,
event.room_id,
prev_group=None,
delta_ids=None,
current_state_ids=context.current_state_ids,
)
context.prev_state_events = []
defer.returnValue(context)
logger.debug("calling resolve_state_groups from compute_event_context")
entry = yield self.resolve_state_groups(
entry = yield self.resolve_state_groups_for_events(
event.room_id, [e for e, _ in event.prev_events],
)
@@ -242,7 +264,8 @@ class StateHandler(object):
context = EventContext()
context.prev_state_ids = curr_state
if event.is_state():
context.state_group = self.store.get_next_state_group()
# If this is a state event then we need to create a new state
# group for the state after this event.
key = (event.type, event.state_key)
if key in context.prev_state_ids:
@@ -253,38 +276,57 @@ class StateHandler(object):
context.current_state_ids[key] = event.event_id
if entry.state_group:
# If the state at the event has a state group assigned then
# we can use that as the prev group
context.prev_group = entry.state_group
context.delta_ids = {
key: event.event_id
}
elif entry.prev_group:
# If the state at the event only has a prev group, then we can
# use that as a prev group too.
context.prev_group = entry.prev_group
context.delta_ids = dict(entry.delta_ids)
context.delta_ids[key] = event.event_id
else:
if entry.state_group is None:
entry.state_group = self.store.get_next_state_group()
entry.state_id = entry.state_group
context.state_group = entry.state_group
context.state_group = yield self.store.store_state_group(
event.event_id,
event.room_id,
prev_group=context.prev_group,
delta_ids=context.delta_ids,
current_state_ids=context.current_state_ids,
)
else:
context.current_state_ids = context.prev_state_ids
context.prev_group = entry.prev_group
context.delta_ids = entry.delta_ids
if entry.state_group is None:
entry.state_group = yield self.store.store_state_group(
event.event_id,
event.room_id,
prev_group=entry.prev_group,
delta_ids=entry.delta_ids,
current_state_ids=context.current_state_ids,
)
entry.state_id = entry.state_group
context.state_group = entry.state_group
context.prev_state_events = []
defer.returnValue(context)
@defer.inlineCallbacks
@log_function
def resolve_state_groups(self, room_id, event_ids):
def resolve_state_groups_for_events(self, room_id, event_ids):
""" Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them.
Args:
room_id (str):
event_ids (list[str]):
Returns:
a Deferred tuple of (`state_group`, `state`, `prev_state`).
`state_group` is the name of a state group if one and only one is
involved. `state` is a map from (type, state_key) to event, and
`prev_state` is a list of event ids.
Deferred[_StateCacheEntry]: resolved state
"""
logger.debug("resolve_state_groups event_ids %s", event_ids)
@@ -295,13 +337,7 @@ class StateHandler(object):
room_id, event_ids
)
logger.debug(
"resolve_state_groups state_groups %s",
state_groups_ids.keys()
)
group_names = frozenset(state_groups_ids.keys())
if len(group_names) == 1:
if len(state_groups_ids) == 1:
name, state_list = state_groups_ids.items().pop()
prev_group, delta_ids = yield self.store.get_state_group_delta(name)
@@ -313,6 +349,102 @@ class StateHandler(object):
delta_ids=delta_ids,
))
result = yield self._state_resolution_handler.resolve_state_groups(
room_id, state_groups_ids, None, self._state_map_factory,
)
defer.returnValue(result)
def _state_map_factory(self, ev_ids):
return self.store.get_events(
ev_ids, get_prev_content=False, check_redacted=False,
)
def resolve_events(self, state_sets, event):
logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
)
state_set_ids = [{
(ev.type, ev.state_key): ev.event_id
for ev in st
} for st in state_sets]
state_map = {
ev.event_id: ev
for st in state_sets
for ev in st
}
with Measure(self.clock, "state._resolve_events"):
new_state = resolve_events_with_state_map(state_set_ids, state_map)
new_state = {
key: state_map[ev_id] for key, ev_id in new_state.items()
}
return new_state
class StateResolutionHandler(object):
"""Responsible for doing state conflict resolution.
Note that the storage layer depends on this handler, so all functions must
be storage-independent.
"""
def __init__(self, hs):
self.clock = hs.get_clock()
# dict of set of event_ids -> _StateCacheEntry.
self._state_cache = None
self.resolve_linearizer = Linearizer(name="state_resolve_lock")
def start_caching(self):
logger.debug("start_caching")
self._state_cache = ExpiringCache(
cache_name="state_cache",
clock=self.clock,
max_len=SIZE_OF_CACHE,
expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
iterable=True,
reset_expiry_on_get=True,
)
self._state_cache.start()
@defer.inlineCallbacks
@log_function
def resolve_state_groups(
self, room_id, state_groups_ids, event_map, state_map_factory,
):
"""Resolves conflicts between a set of state groups
Always generates a new state group (unless we hit the cache), so should
not be called for a single state group
Args:
room_id (str): room we are resolving for (used for logging)
state_groups_ids (dict[int, dict[(str, str), str]]):
map from state group id to the state in that state group
(where 'state' is a map from state key to event id)
event_map(dict[str,FrozenEvent]|None):
a dict from event_id to event, for any events that we happen to
have in flight (eg, those currently being persisted). This will be
used as a starting point for finding the state we need; any missing
events will be requested via state_map_factory.
If None, all events will be fetched via state_map_factory.
Returns:
Deferred[_StateCacheEntry]: resolved state
"""
logger.debug(
"resolve_state_groups state_groups %s",
state_groups_ids.keys()
)
group_names = frozenset(state_groups_ids.keys())
with (yield self.resolve_linearizer.queue(group_names)):
if self._state_cache is not None:
cache = self._state_cache.get(group_names, None)
@@ -341,17 +473,20 @@ class StateHandler(object):
if conflicted_state:
logger.info("Resolving conflicted state for %r", room_id)
with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events(
new_state = yield resolve_events_with_factory(
state_groups_ids.values(),
state_map_factory=lambda ev_ids: self.store.get_events(
ev_ids, get_prev_content=False, check_redacted=False,
),
event_map=event_map,
state_map_factory=state_map_factory,
)
else:
new_state = {
key: e_ids.pop() for key, e_ids in state.items()
}
# if the new state matches any of the input state groups, we can
# use that state group again. Otherwise we will generate a state_id
# which will be used as a cache key for future resolutions, but
# not get persisted.
state_group = None
new_state_event_ids = frozenset(new_state.values())
for sg, events in state_groups_ids.items():
@@ -388,30 +523,6 @@ class StateHandler(object):
defer.returnValue(cache)
def resolve_events(self, state_sets, event):
logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
)
state_set_ids = [{
(ev.type, ev.state_key): ev.event_id
for ev in st
} for st in state_sets]
state_map = {
ev.event_id: ev
for st in state_sets
for ev in st
}
with Measure(self.clock, "state._resolve_events"):
new_state = resolve_events(state_set_ids, state_map)
new_state = {
key: state_map[ev_id] for key, ev_id in new_state.items()
}
return new_state
def _ordered_events(events):
def key_func(e):
@@ -420,19 +531,17 @@ def _ordered_events(events):
return sorted(events, key=key_func)
def resolve_events(state_sets, state_map_factory):
def resolve_events_with_state_map(state_sets, state_map):
"""
Args:
state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
state_map_factory(dict|callable): If callable, then will be called
with a list of event_ids that are needed, and should return with
a Deferred of dict of event_id to event. Otherwise, should be
a dict from event_id to event of all events in state_sets.
state_map(dict): a dict from event_id to event, for all events in
state_sets.
Returns
dict[(str, str), synapse.events.FrozenEvent] is a map from
(type, state_key) to event.
dict[(str, str), str]:
a map from (type, state_key) to event_id.
"""
if len(state_sets) == 1:
return state_sets[0]
@@ -441,13 +550,6 @@ def resolve_events(state_sets, state_map_factory):
state_sets,
)
if callable(state_map_factory):
return _resolve_with_state_fac(
unconflicted_state, conflicted_state, state_map_factory
)
state_map = state_map_factory
auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map
)
@@ -461,6 +563,21 @@ def _seperate(state_sets):
"""Takes the state_sets and figures out which keys are conflicted and
which aren't. i.e., which have multiple different event_ids associated
with them in different state sets.
Args:
state_sets(list[dict[(str, str), str]]):
List of dicts of (type, state_key) -> event_id, which are the
different state groups to resolve.
Returns:
(dict[(str, str), str], dict[(str, str), set[str]]):
A tuple of (unconflicted_state, conflicted_state), where:
unconflicted_state is a dict mapping (type, state_key)->event_id
for unconflicted state keys.
conflicted_state is a dict mapping (type, state_key) to a set of
event ids for conflicted state keys.
"""
unconflicted_state = dict(state_sets[0])
conflicted_state = {}
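A simplified, self-contained sketch of the separation step described in the docstring (keys missing from some state sets may be handled slightly differently by the real _seperate, and the event IDs are illustrative):
def separate_sketch(state_sets):
    # a key is unconflicted if every state set that has it agrees on the event_id
    unconflicted, conflicted = {}, {}
    all_keys = set().union(*state_sets)
    for key in all_keys:
        event_ids = {s[key] for s in state_sets if key in s}
        if len(event_ids) == 1:
            unconflicted[key] = event_ids.pop()
        else:
            conflicted[key] = event_ids
    return unconflicted, conflicted

state_sets = [
    {("m.room.name", ""): "$name_v1", ("m.room.topic", ""): "$topic"},
    {("m.room.name", ""): "$name_v2", ("m.room.topic", ""): "$topic"},
]
print(separate_sketch(state_sets))
# ({('m.room.topic', ''): '$topic'}, {('m.room.name', ''): {'$name_v1', '$name_v2'}})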
@@ -491,19 +608,50 @@ def _seperate(state_sets):
@defer.inlineCallbacks
def _resolve_with_state_fac(unconflicted_state, conflicted_state,
state_map_factory):
def resolve_events_with_factory(state_sets, event_map, state_map_factory):
"""
Args:
state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
event_map(dict[str,FrozenEvent]|None):
a dict from event_id to event, for any events that we happen to
have in flight (eg, those currently being persisted). This will be
used as a starting point for finding the state we need; any missing
events will be requested via state_map_factory.
If None, all events will be fetched via state_map_factory.
state_map_factory(func): will be called
with a list of event_ids that are needed, and should return with
a Deferred of dict of event_id to event.
Returns
Deferred[dict[(str, str), str]]:
a map from (type, state_key) to event_id.
"""
if len(state_sets) == 1:
defer.returnValue(state_sets[0])
unconflicted_state, conflicted_state = _seperate(
state_sets,
)
needed_events = set(
event_id
for event_ids in conflicted_state.itervalues()
for event_id in event_ids
)
if event_map is not None:
needed_events -= set(event_map.iterkeys())
logger.info("Asking for %d conflicted events", len(needed_events))
# dict[str, FrozenEvent]: a map from state event id to event. Only includes
# the state events which are in conflict.
# the state events which are in conflict (and those in event_map)
state_map = yield state_map_factory(needed_events)
if event_map is not None:
state_map.update(event_map)
# get the ids of the auth events which allow us to authenticate the
# conflicted state, picking only from the unconflicting state.
@@ -515,6 +663,8 @@ def _resolve_with_state_fac(unconflicted_state, conflicted_state,
new_needed_events = set(auth_events.itervalues())
new_needed_events -= needed_events
if event_map is not None:
new_needed_events -= set(event_map.iterkeys())
logger.info("Asking for %d auth events", len(new_needed_events))

View File

@@ -124,7 +124,6 @@ class DataStore(RoomMemberStore, RoomStore,
)
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")

View File

@@ -291,33 +291,33 @@ class SQLBaseStore(object):
@defer.inlineCallbacks
def runInteraction(self, desc, func, *args, **kwargs):
"""Wraps the .runInteraction() method on the underlying db_pool."""
current_context = LoggingContext.current_context()
"""Starts a transaction on the database and runs a given function
start_time = time.time() * 1000
Arguments:
desc (str): description of the transaction, for logging and metrics
func (func): callback function, which will be called with a
database transaction (twisted.enterprise.adbapi.Transaction) as
its first argument, followed by `args` and `kwargs`.
args (list): positional args to pass to `func`
kwargs (dict): named args to pass to `func`
Returns:
Deferred: The result of func
"""
current_context = LoggingContext.current_context()
after_callbacks = []
final_callbacks = []
def inner_func(conn, *args, **kwargs):
with LoggingContext("runInteraction") as context:
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
if self.database_engine.is_connection_closed(conn):
logger.debug("Reconnecting closed database connection")
conn.reconnect()
current_context.copy_to(context)
return self._new_transaction(
conn, desc, after_callbacks, final_callbacks, current_context,
func, *args, **kwargs
)
return self._new_transaction(
conn, desc, after_callbacks, final_callbacks, current_context,
func, *args, **kwargs
)
try:
with PreserveLoggingContext():
result = yield self._db_pool.runWithConnection(
inner_func, *args, **kwargs
)
result = yield self.runWithConnection(inner_func, *args, **kwargs)
for after_callback, after_args, after_kwargs in after_callbacks:
after_callback(*after_args, **after_kwargs)
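The callback contract described in the docstring can be sketched against a throwaway sqlite connection; the table and values are made up, and the final comment shows how the same callback would be handed to the real store:
import sqlite3

def count_events_txn(txn, room_id):
    # the shape runInteraction expects: a cursor/transaction first, then extra args
    txn.execute("SELECT count(*) FROM events WHERE room_id = ?", (room_id,))
    return txn.fetchone()[0]

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE events (event_id TEXT, room_id TEXT)")
conn.execute("INSERT INTO events VALUES ('$ev1', '!room:example.com')")
print(count_events_txn(conn.cursor(), "!room:example.com"))
# against a real store: yield self.runInteraction("count_events", count_events_txn, room_id)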
@@ -329,14 +329,27 @@ class SQLBaseStore(object):
@defer.inlineCallbacks
def runWithConnection(self, func, *args, **kwargs):
"""Wraps the .runInteraction() method on the underlying db_pool."""
"""Wraps the .runWithConnection() method on the underlying db_pool.
Arguments:
func (func): callback function, which will be called with a
database connection (twisted.enterprise.adbapi.Connection) as
its first argument, followed by `args` and `kwargs`.
args (list): positional args to pass to `func`
kwargs (dict): named args to pass to `func`
Returns:
Deferred: The result of func
"""
current_context = LoggingContext.current_context()
start_time = time.time() * 1000
def inner_func(conn, *args, **kwargs):
with LoggingContext("runWithConnection") as context:
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
sched_duration_ms = time.time() * 1000 - start_time
sql_scheduling_timer.inc_by(sched_duration_ms)
current_context.add_database_scheduled(sched_duration_ms)
if self.database_engine.is_connection_closed(conn):
logger.debug("Reconnecting closed database connection")

View File

@@ -62,3 +62,9 @@ class PostgresEngine(object):
def lock_table(self, txn, table):
txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))
def get_next_state_group_id(self, txn):
"""Returns an int that can be used as a new state_group ID
"""
txn.execute("SELECT nextval('state_group_id_seq')")
return txn.fetchone()[0]

View File

@@ -16,6 +16,7 @@
from synapse.storage.prepare_database import prepare_database
import struct
import threading
class Sqlite3Engine(object):
@@ -24,6 +25,11 @@ class Sqlite3Engine(object):
def __init__(self, database_module, database_config):
self.module = database_module
# The current max state_group, or None if we haven't looked
# in the DB yet.
self._current_state_group_id = None
self._current_state_group_id_lock = threading.Lock()
def check_database(self, txn):
pass
@@ -43,6 +49,19 @@ class Sqlite3Engine(object):
def lock_table(self, txn, table):
return
def get_next_state_group_id(self, txn):
"""Returns an int that can be used as a new state_group ID
"""
# We do application locking here because, if we're using sqlite, we
# are a single-process synapse.
with self._current_state_group_id_lock:
if self._current_state_group_id is None:
txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
self._current_state_group_id = txn.fetchone()[0]
self._current_state_group_id += 1
return self._current_state_group_id
# Following functions taken from: https://github.com/coleifer/peewee

View File

@@ -27,7 +27,6 @@ from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
from synapse.state import resolve_events
from synapse.util.caches.descriptors import cached
from synapse.types import get_domain_from_id
@@ -38,7 +37,7 @@ from functools import wraps
import synapse.metrics
import logging
import ujson as json
import simplejson as json
# these are only included to make the type annotations work
from synapse.events import EventBase # noqa: F401
@@ -110,7 +109,7 @@ class _EventPeristenceQueue(object):
end_item.events_and_contexts.extend(events_and_contexts)
return end_item.deferred.observe()
deferred = ObservableDeferred(defer.Deferred())
deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
queue.append(self._EventPersistQueueItem(
events_and_contexts=events_and_contexts,
@@ -146,18 +145,25 @@ class _EventPeristenceQueue(object):
try:
queue = self._get_drainining_queue(room_id)
for item in queue:
# handle_queue_loop runs in the sentinel logcontext, so
# there is no need to preserve_fn when running the
# callbacks on the deferred.
try:
ret = yield per_item_callback(item)
item.deferred.callback(ret)
except Exception as e:
item.deferred.errback(e)
except Exception:
item.deferred.errback()
finally:
queue = self._event_persist_queues.pop(room_id, None)
if queue:
self._event_persist_queues[room_id] = queue
self._currently_persisting_rooms.discard(room_id)
preserve_fn(handle_queue_loop)()
# set handle_queue_loop off in the background. We don't want to
# attribute work done in it to the current request, so we drop the
# logcontext altogether.
with PreserveLoggingContext():
handle_queue_loop()
def _get_drainining_queue(self, room_id):
queue = self._event_persist_queues.setdefault(room_id, deque())
@@ -230,6 +236,8 @@ class EventsStore(SQLBaseStore):
self._event_persist_queue = _EventPeristenceQueue()
self._state_resolution_handler = hs.get_state_resolution_handler()
def persist_events(self, events_and_contexts, backfilled=False):
"""
Write events to the database
@@ -335,8 +343,20 @@ class EventsStore(SQLBaseStore):
# NB: Assumes that we are only persisting events for one room
# at a time.
# map room_id->list[event_ids] giving the new forward
# extremities in each room
new_forward_extremeties = {}
# map room_id->(type,state_key)->event_id tracking the full
# state in each room after adding these events
current_state_for_room = {}
# map room_id->(to_delete, to_insert) where each entry is
# a map (type,key)->event_id giving the state delta in each
# room
state_delta_for_room = {}
if not backfilled:
with Measure(self._clock, "_calculate_state_and_extrem"):
# Work out the new "current state" for each room.
@@ -379,11 +399,20 @@ class EventsStore(SQLBaseStore):
if all_single_prev_not_state:
continue
state = yield self._calculate_state_delta(
room_id, ev_ctx_rm, new_latest_event_ids
logger.info(
"Calculating state delta for room %s", room_id,
)
if state:
current_state_for_room[room_id] = state
current_state = yield self._get_new_state_after_events(
room_id,
ev_ctx_rm, new_latest_event_ids,
)
if current_state is not None:
current_state_for_room[room_id] = current_state
delta = yield self._calculate_state_delta(
room_id, current_state,
)
if delta is not None:
state_delta_for_room[room_id] = delta
yield self.runInteraction(
"persist_events",
@@ -391,7 +420,7 @@ class EventsStore(SQLBaseStore):
events_and_contexts=chunk,
backfilled=backfilled,
delete_existing=delete_existing,
current_state_for_room=current_state_for_room,
state_delta_for_room=state_delta_for_room,
new_forward_extremeties=new_forward_extremeties,
)
persist_event_counter.inc_by(len(chunk))
@@ -408,7 +437,7 @@ class EventsStore(SQLBaseStore):
event_counter.inc(event.type, origin_type, origin_entity)
for room_id, (_, _, new_state) in current_state_for_room.iteritems():
for room_id, new_state in current_state_for_room.iteritems():
self.get_current_state_ids.prefill(
(room_id, ), new_state
)
@@ -460,22 +489,31 @@ class EventsStore(SQLBaseStore):
defer.returnValue(new_latest_event_ids)
@defer.inlineCallbacks
def _calculate_state_delta(self, room_id, events_context, new_latest_event_ids):
"""Calculate the new state deltas for a room.
def _get_new_state_after_events(self, room_id, events_context, new_latest_event_ids):
"""Calculate the current state dict after adding some new events to
a room
Assumes that we are only persisting events for one room at a time.
Args:
room_id (str):
room to which the events are being added. Used for logging etc
events_context (list[(EventBase, EventContext)]):
events and contexts which are being added to the room
new_latest_event_ids (iterable[str]):
the new forward extremities for the room.
Returns:
3-tuple (to_delete, to_insert, new_state) where both are state dicts,
i.e. (type, state_key) -> event_id. `to_delete` are the entries to
first be deleted from current_state_events, `to_insert` are entries
to insert. `new_state` is the full set of state.
May return None if there are no changes to be applied.
Deferred[dict[(str,str), str]|None]:
None if there are no changes to the room state, or
a dict of (type, state_key) -> event_id.
"""
# Now we need to work out the different state sets for
# each state extremities
state_sets = []
state_groups = set()
if not new_latest_event_ids:
defer.returnValue({})
# map from state_group to ((type, key) -> event_id) state map
state_groups = {}
missing_event_ids = []
was_updated = False
for event_id in new_latest_event_ids:
@@ -486,16 +524,19 @@ class EventsStore(SQLBaseStore):
if ctx.current_state_ids is None:
raise Exception("Unknown current state")
if ctx.state_group is None:
# I don't think this can happen, but let's double-check
raise Exception(
"Context for new extremity event %s has no state "
"group" % (event_id, ),
)
# If we've already seen the state group don't bother adding
# it to the state sets again
if ctx.state_group not in state_groups:
state_sets.append(ctx.current_state_ids)
state_groups[ctx.state_group] = ctx.current_state_ids
if ctx.delta_ids or hasattr(ev, "state_key"):
was_updated = True
if ctx.state_group:
# Add this as a seen state group (if it has a state
# group)
state_groups.add(ctx.state_group)
break
else:
# If we couldn't find it, then we'll need to pull
@@ -503,60 +544,50 @@ class EventsStore(SQLBaseStore):
was_updated = True
missing_event_ids.append(event_id)
if not was_updated:
return
if missing_event_ids:
# Now pull out the state for any missing events from DB
event_to_groups = yield self._get_state_group_for_events(
missing_event_ids,
)
groups = set(event_to_groups.itervalues()) - state_groups
groups = set(event_to_groups.itervalues()) - set(state_groups.iterkeys())
if groups:
group_to_state = yield self._get_state_for_groups(groups)
state_sets.extend(group_to_state.itervalues())
state_groups.update(group_to_state)
if not new_latest_event_ids:
current_state = {}
elif was_updated:
if len(state_sets) == 1:
# If there is only one state set, then we know what the current
# state is.
current_state = state_sets[0]
else:
# We work out the current state by passing the state sets to the
# state resolution algorithm. It may ask for some events, including
# the events we have yet to persist, so we need a slightly more
# complicated event lookup function than simply looking the events
# up in the db.
events_map = {ev.event_id: ev for ev, _ in events_context}
if len(state_groups) == 1:
# If there is only one state group, then we know what the current
# state is.
defer.returnValue(state_groups.values()[0])
@defer.inlineCallbacks
def get_events(ev_ids):
# We get the events by first looking at the list of events we
# are trying to persist, and then fetching the rest from the DB.
db = []
to_return = {}
for ev_id in ev_ids:
ev = events_map.get(ev_id, None)
if ev:
to_return[ev_id] = ev
else:
db.append(ev_id)
def get_events(ev_ids):
return self.get_events(
ev_ids, get_prev_content=False, check_redacted=False,
)
events_map = {ev.event_id: ev for ev, _ in events_context}
logger.debug("calling resolve_state_groups from preserve_events")
res = yield self._state_resolution_handler.resolve_state_groups(
room_id, state_groups, events_map, get_events
)
if db:
evs = yield self.get_events(
ev_ids, get_prev_content=False, check_redacted=False,
)
to_return.update(evs)
defer.returnValue(to_return)
defer.returnValue(res.state)
current_state = yield resolve_events(
state_sets,
state_map_factory=get_events,
)
else:
return
@defer.inlineCallbacks
def _calculate_state_delta(self, room_id, current_state):
"""Calculate the new state deltas for a room.
Assumes that we are only persisting events for one room at a time.
Returns:
2-tuple (to_delete, to_insert) where both are state dicts,
i.e. (type, state_key) -> event_id. `to_delete` are the entries to
first be deleted from current_state_events, `to_insert` are entries
to insert.
"""
existing_state = yield self.get_current_state_ids(room_id)
existing_events = set(existing_state.itervalues())
@@ -576,7 +607,7 @@ class EventsStore(SQLBaseStore):
if ev_id in events_to_insert
}
defer.returnValue((to_delete, to_insert, current_state))
defer.returnValue((to_delete, to_insert))
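Worked through with made-up event IDs, the delta calculation reduces to two dict comprehensions over the existing and new state maps:
existing_state = {
    ("m.room.name", ""): "$old_name",
    ("m.room.topic", ""): "$topic",
}
current_state = {
    ("m.room.name", ""): "$new_name",
    ("m.room.topic", ""): "$topic",
    ("m.room.avatar", ""): "$avatar",
}

existing_events = set(existing_state.values())
new_events = set(current_state.values())

to_delete = {k: v for k, v in existing_state.items() if v not in new_events}
to_insert = {k: v for k, v in current_state.items() if v not in existing_events}
print(to_delete)  # {('m.room.name', ''): '$old_name'}
print(to_insert)  # {('m.room.name', ''): '$new_name', ('m.room.avatar', ''): '$avatar'}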
@defer.inlineCallbacks
def get_event(self, event_id, check_redacted=True,
@@ -636,7 +667,7 @@ class EventsStore(SQLBaseStore):
@log_function
def _persist_events_txn(self, txn, events_and_contexts, backfilled,
delete_existing=False, current_state_for_room={},
delete_existing=False, state_delta_for_room={},
new_forward_extremeties={}):
"""Insert some number of room events into the necessary database tables.
@@ -652,7 +683,7 @@ class EventsStore(SQLBaseStore):
delete_existing (bool): True to purge existing table rows for the
events from the database. This is useful when retrying due to
IntegrityError.
current_state_for_room (dict[str, (list[str], list[str])]):
state_delta_for_room (dict[str, (list[str], list[str])]):
The current-state delta for each room. For each room, a tuple
(to_delete, to_insert), being a list of event ids to be removed
from the current state, and a list of event ids to be added to
@@ -664,7 +695,7 @@ class EventsStore(SQLBaseStore):
"""
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
self._update_current_state_txn(txn, current_state_for_room, max_stream_order)
self._update_current_state_txn(txn, state_delta_for_room, max_stream_order)
self._update_forward_extremities_txn(
txn,
@@ -708,9 +739,8 @@ class EventsStore(SQLBaseStore):
events_and_contexts=events_and_contexts,
)
# Insert into the state_groups, state_groups_state, and
# event_to_state_groups tables.
self._store_mult_state_groups_txn(txn, events_and_contexts)
# Insert into event_to_state_groups.
self._store_event_state_mappings_txn(txn, events_and_contexts)
# _store_rejected_events_txn filters out any events which were
# rejected, and returns the filtered list.
@@ -730,7 +760,7 @@ class EventsStore(SQLBaseStore):
def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order):
for room_id, current_state_tuple in state_delta_by_room.iteritems():
to_delete, to_insert, _ = current_state_tuple
to_delete, to_insert = current_state_tuple
txn.executemany(
"DELETE FROM current_state_events WHERE event_id = ?",
[(ev_id,) for ev_id in to_delete.itervalues()],
@@ -945,10 +975,9 @@ class EventsStore(SQLBaseStore):
# an outlier in the database. We now have some state at that
# so we need to update the state_groups table with that state.
# insert into the state_group, state_groups_state and
# event_to_state_groups tables.
# insert into event_to_state_groups.
try:
self._store_mult_state_groups_txn(txn, ((event, context),))
self._store_event_state_mappings_txn(txn, ((event, context),))
except Exception:
logger.exception("")
raise
@@ -2018,16 +2047,32 @@ class EventsStore(SQLBaseStore):
)
return self.runInteraction("get_all_new_events", get_all_new_events_txn)
def delete_old_state(self, room_id, topological_ordering):
return self.runInteraction(
"delete_old_state",
self._delete_old_state_txn, room_id, topological_ordering
)
def purge_history(
self, room_id, topological_ordering, delete_local_events,
):
"""Deletes room history before a certain point
def _delete_old_state_txn(self, txn, room_id, topological_ordering):
"""Deletes old room state
Args:
room_id (str):
topological_ordering (int):
minimum topo ordering to preserve
delete_local_events (bool):
if True, we will delete local events as well as remote ones
(instead of just marking them as outliers and deleting their
state groups).
"""
return self.runInteraction(
"purge_history",
self._purge_history_txn, room_id, topological_ordering,
delete_local_events,
)
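The delete_local_events flag only changes which non-state events qualify for deletion; a standalone sketch of that predicate follows (server names and event IDs invented, and is_mine_id stands in for hs.is_mine_id):
def is_mine_id(event_id, server_name="example.com"):
    # stand-in for hs.is_mine_id: treat ids ending in our server name as local
    return event_id.split(":", 1)[1] == server_name

event_rows = [
    ("$remote1:other.org", None),              # remote non-state event
    ("$local1:example.com", None),             # local non-state event
    ("$member1:other.org", "@bob:other.org"),  # state event (has a state_key)
]
delete_local_events = False

to_delete = [
    (event_id,) for event_id, state_key in event_rows
    if state_key is None and (delete_local_events or not is_mine_id(event_id))
]
print(to_delete)  # only the remote non-state event; set the flag to also delete local ones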
def _purge_history_txn(
self, txn, room_id, topological_ordering, delete_local_events,
):
# Tables that should be pruned:
# event_auth
# event_backward_extremities
@@ -2068,7 +2113,7 @@ class EventsStore(SQLBaseStore):
400, "topological_ordering is greater than forward extremeties"
)
logger.debug("[purge] looking for events to delete")
logger.info("[purge] looking for events to delete")
txn.execute(
"SELECT event_id, state_key FROM events"
@@ -2080,16 +2125,16 @@ class EventsStore(SQLBaseStore):
to_delete = [
(event_id,) for event_id, state_key in event_rows
if state_key is None and not self.hs.is_mine_id(event_id)
if state_key is None and (
delete_local_events or not self.hs.is_mine_id(event_id)
)
]
logger.info(
"[purge] found %i events before cutoff, of which %i are remote"
" non-state events to delete", len(event_rows), len(to_delete))
"[purge] found %i events before cutoff, of which %i can be deleted",
len(event_rows), len(to_delete),
)
for event_id, state_key in event_rows:
txn.call_after(self._get_state_group_for_event.invalidate, (event_id,))
logger.debug("[purge] Finding new backward extremities")
logger.info("[purge] Finding new backward extremities")
# We calculate the new entries for the backward extremeties by finding
# all events that point to events that are to be purged
@@ -2103,7 +2148,7 @@ class EventsStore(SQLBaseStore):
)
new_backwards_extrems = txn.fetchall()
logger.debug("[purge] replacing backward extremities: %r", new_backwards_extrems)
logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems)
txn.execute(
"DELETE FROM event_backward_extremities WHERE room_id = ?",
@@ -2119,7 +2164,7 @@ class EventsStore(SQLBaseStore):
]
)
logger.debug("[purge] finding redundant state groups")
logger.info("[purge] finding redundant state groups")
# Get all state groups that are only referenced by events that are
# to be deleted.
@@ -2136,15 +2181,15 @@ class EventsStore(SQLBaseStore):
)
state_rows = txn.fetchall()
logger.debug("[purge] found %i redundant state groups", len(state_rows))
logger.info("[purge] found %i redundant state groups", len(state_rows))
# make a set of the redundant state groups, so that we can look them up
# efficiently
state_groups_to_delete = set([sg for sg, in state_rows])
# Now we get all the state groups that rely on these state groups
logger.debug("[purge] finding state groups which depend on redundant"
" state groups")
logger.info("[purge] finding state groups which depend on redundant"
" state groups")
remaining_state_groups = []
for i in xrange(0, len(state_rows), 100):
chunk = [sg for sg, in state_rows[i:i + 100]]
@@ -2169,7 +2214,7 @@ class EventsStore(SQLBaseStore):
# Now we turn the state groups that reference to-be-deleted state
# groups to non delta versions.
for sg in remaining_state_groups:
logger.debug("[purge] de-delta-ing remaining state group %s", sg)
logger.info("[purge] de-delta-ing remaining state group %s", sg)
curr_state = self._get_state_groups_from_groups_txn(
txn, [sg], types=None
)
@@ -2206,7 +2251,7 @@ class EventsStore(SQLBaseStore):
],
)
logger.debug("[purge] removing redundant state groups")
logger.info("[purge] removing redundant state groups")
txn.executemany(
"DELETE FROM state_groups_state WHERE state_group = ?",
state_rows
@@ -2216,18 +2261,15 @@ class EventsStore(SQLBaseStore):
state_rows
)
# Delete all non-state
logger.debug("[purge] removing events from event_to_state_groups")
logger.info("[purge] removing events from event_to_state_groups")
txn.executemany(
"DELETE FROM event_to_state_groups WHERE event_id = ?",
[(event_id,) for event_id, _ in event_rows]
)
logger.debug("[purge] updating room_depth")
txn.execute(
"UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
(topological_ordering, room_id,)
)
for event_id, _ in event_rows:
txn.call_after(self._get_state_group_for_event.invalidate, (
event_id,
))
# Delete all remote non-state events
for table in (
@@ -2245,7 +2287,7 @@ class EventsStore(SQLBaseStore):
"event_signatures",
"rejections",
):
logger.debug("[purge] removing remote non-state events from %s", table)
logger.info("[purge] removing events from %s", table)
txn.executemany(
"DELETE FROM %s WHERE event_id = ?" % (table,),
@@ -2253,16 +2295,30 @@ class EventsStore(SQLBaseStore):
)
# Mark all state and own events as outliers
logger.debug("[purge] marking remaining events as outliers")
logger.info("[purge] marking remaining events as outliers")
txn.executemany(
"UPDATE events SET outlier = ?"
" WHERE event_id = ?",
[
(True, event_id,) for event_id, state_key in event_rows
if state_key is not None or self.hs.is_mine_id(event_id)
if state_key is not None or (
not delete_local_events and self.hs.is_mine_id(event_id)
)
]
)
# synapse tries to take out an exclusive lock on room_depth whenever it
# persists events (because upsert), and once we run this update, we
# will block that for the rest of our transaction.
#
# So, let's stick it at the end so that we don't block event
# persistence.
logger.info("[purge] updating room_depth")
txn.execute(
"UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
(topological_ordering, room_id,)
)
logger.info("[purge] done")
@defer.inlineCallbacks

View File

@@ -29,9 +29,6 @@ class MediaRepositoryStore(BackgroundUpdateStore):
where_clause='url_cache IS NOT NULL',
)
def get_default_thumbnails(self, top_level_type, sub_type):
return []
def get_local_media(self, media_id):
"""Get the metadata for a local piece of media
Returns:
@@ -176,7 +173,14 @@ class MediaRepositoryStore(BackgroundUpdateStore):
desc="store_cached_remote_media",
)
def update_cached_last_access_time(self, origin_id_tuples, time_ts):
def update_cached_last_access_time(self, local_media, remote_media, time_ms):
"""Updates the last access time of the given media
Args:
local_media (iterable[str]): Set of media_ids
remote_media (iterable[(str, str)]): Set of (server_name, media_id)
time_ms: Current time in milliseconds
"""
def update_cache_txn(txn):
sql = (
"UPDATE remote_media_cache SET last_access_ts = ?"
@@ -184,8 +188,18 @@ class MediaRepositoryStore(BackgroundUpdateStore):
)
txn.executemany(sql, (
(time_ts, media_origin, media_id)
for media_origin, media_id in origin_id_tuples
(time_ms, media_origin, media_id)
for media_origin, media_id in remote_media
))
sql = (
"UPDATE local_media_repository SET last_access_ts = ?"
" WHERE media_id = ?"
)
txn.executemany(sql, (
(time_ms, media_id)
for media_id in local_media
))
return self.runInteraction("update_cached_last_access_time", update_cache_txn)

View File

@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 46
SCHEMA_VERSION = 47
dir_path = os.path.abspath(os.path.dirname(__file__))

View File

@@ -16,11 +16,9 @@
from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.storage.search import SearchStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from ._base import SQLBaseStore
from .engines import PostgresEngine, Sqlite3Engine
import collections
import logging
import ujson as json
@@ -40,7 +38,7 @@ RatelimitOverride = collections.namedtuple(
)
class RoomStore(SQLBaseStore):
class RoomStore(SearchStore):
@defer.inlineCallbacks
def store_room(self, room_id, room_creator_user_id, is_public):
@@ -263,8 +261,8 @@ class RoomStore(SQLBaseStore):
},
)
self._store_event_search_txn(
txn, event, "content.topic", event.content["topic"]
self.store_event_search_txn(
txn, event, "content.topic", event.content["topic"],
)
def _store_room_name_txn(self, txn, event):
@@ -279,14 +277,14 @@ class RoomStore(SQLBaseStore):
}
)
self._store_event_search_txn(
txn, event, "content.name", event.content["name"]
self.store_event_search_txn(
txn, event, "content.name", event.content["name"],
)
def _store_room_message_txn(self, txn, event):
if hasattr(event, "content") and "body" in event.content:
self._store_event_search_txn(
txn, event, "content.body", event.content["body"]
self.store_event_search_txn(
txn, event, "content.body", event.content["body"],
)
def _store_history_visibility_txn(self, txn, event):
@@ -308,31 +306,6 @@ class RoomStore(SQLBaseStore):
event.content[key]
))
def _store_event_search_txn(self, txn, event, key, value):
if isinstance(self.database_engine, PostgresEngine):
sql = (
"INSERT INTO event_search"
" (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
" VALUES (?,?,?,to_tsvector('english', ?),?,?)"
)
txn.execute(
sql,
(
event.event_id, event.room_id, key, value,
event.internal_metadata.stream_ordering,
event.origin_server_ts,
)
)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
"INSERT INTO event_search (event_id, room_id, key, value)"
" VALUES (?,?,?,?)"
)
txn.execute(sql, (event.event_id, event.room_id, key, value,))
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
def add_event_report(self, room_id, event_id, user_id, reason, content,
received_ts):
next_id = self._event_reports_id_gen.get_next()
@@ -533,73 +506,114 @@ class RoomStore(SQLBaseStore):
)
self.is_room_blocked.invalidate((room_id,))
def get_media_mxcs_in_room(self, room_id):
"""Retrieves all the local and remote media MXC URIs in a given room
Args:
room_id (str)
Returns:
The local and remote media as lists of tuples where the key is
the hostname and the value is the media ID.
"""
def _get_media_mxcs_in_room_txn(txn):
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
local_media_mxcs = []
remote_media_mxcs = []
# Convert the IDs to MXC URIs
for media_id in local_mxcs:
local_media_mxcs.append("mxc://%s/%s" % (self.hostname, media_id))
for hostname, media_id in remote_mxcs:
remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id))
return local_media_mxcs, remote_media_mxcs
return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn)
def quarantine_media_ids_in_room(self, room_id, quarantined_by):
"""For a room loops through all events with media and quarantines
the associated media
"""
def _get_media_ids_in_room(txn):
mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
next_token = self.get_current_events_token() + 1
def _quarantine_media_in_room_txn(txn):
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
total_media_quarantined = 0
while next_token:
sql = """
SELECT stream_ordering, content FROM events
WHERE room_id = ?
AND stream_ordering < ?
AND contains_url = ? AND outlier = ?
ORDER BY stream_ordering DESC
LIMIT ?
# Now update all the tables to set the quarantined_by flag
txn.executemany("""
UPDATE local_media_repository
SET quarantined_by = ?
WHERE media_id = ?
""", ((quarantined_by, media_id) for media_id in local_mxcs))
txn.executemany(
"""
txn.execute(sql, (room_id, next_token, True, False, 100))
next_token = None
local_media_mxcs = []
remote_media_mxcs = []
for stream_ordering, content_json in txn:
next_token = stream_ordering
content = json.loads(content_json)
content_url = content.get("url")
thumbnail_url = content.get("info", {}).get("thumbnail_url")
for url in (content_url, thumbnail_url):
if not url:
continue
matches = mxc_re.match(url)
if matches:
hostname = matches.group(1)
media_id = matches.group(2)
if hostname == self.hostname:
local_media_mxcs.append(media_id)
else:
remote_media_mxcs.append((hostname, media_id))
# Now update all the tables to set the quarantined_by flag
txn.executemany("""
UPDATE local_media_repository
UPDATE remote_media_cache
SET quarantined_by = ?
WHERE media_id = ?
""", ((quarantined_by, media_id) for media_id in local_media_mxcs))
txn.executemany(
"""
UPDATE remote_media_cache
SET quarantined_by = ?
WHERE media_origin AND media_id = ?
""",
(
(quarantined_by, origin, media_id)
for origin, media_id in remote_media_mxcs
)
WHERE media_origin = ? AND media_id = ?
""",
(
(quarantined_by, origin, media_id)
for origin, media_id in remote_mxcs
)
)
total_media_quarantined += len(local_media_mxcs)
total_media_quarantined += len(remote_media_mxcs)
total_media_quarantined += len(local_mxcs)
total_media_quarantined += len(remote_mxcs)
return total_media_quarantined
return self.runInteraction("get_media_ids_in_room", _get_media_ids_in_room)
return self.runInteraction(
"quarantine_media_in_room",
_quarantine_media_in_room_txn,
)
def _get_media_mxcs_in_room_txn(self, txn, room_id):
"""Retrieves all the local and remote media MXC URIs in a given room
Args:
txn (cursor)
room_id (str)
Returns:
The local and remote media as lists of tuples where the key is
the hostname and the value is the media ID.
"""
mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
next_token = self.get_current_events_token() + 1
local_media_mxcs = []
remote_media_mxcs = []
while next_token:
sql = """
SELECT stream_ordering, content FROM events
WHERE room_id = ?
AND stream_ordering < ?
AND contains_url = ? AND outlier = ?
ORDER BY stream_ordering DESC
LIMIT ?
"""
txn.execute(sql, (room_id, next_token, True, False, 100))
next_token = None
for stream_ordering, content_json in txn:
next_token = stream_ordering
content = json.loads(content_json)
content_url = content.get("url")
thumbnail_url = content.get("info", {}).get("thumbnail_url")
for url in (content_url, thumbnail_url):
if not url:
continue
matches = mxc_re.match(url)
if matches:
hostname = matches.group(1)
media_id = matches.group(2)
if hostname == self.hostname:
local_media_mxcs.append(media_id)
else:
remote_media_mxcs.append((hostname, media_id))
return local_media_mxcs, remote_media_mxcs
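The mxc_re pattern splits an MXC URI into hostname and media ID; a self-contained run with invented URIs:
import re

mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
my_hostname = "example.com"  # stands in for self.hostname

local_media, remote_media = [], []
for url in ("mxc://example.com/abc123", "mxc://other.org/def456?width=64"):
    match = mxc_re.match(url)
    if not match:
        continue
    hostname, media_id = match.group(1), match.group(2)
    if hostname == my_hostname:
        local_media.append(media_id)
    else:
        remote_media.append((hostname, media_id))
print(local_media, remote_media)  # ['abc123'] [('other.org', 'def456')]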

View File

@@ -0,0 +1,16 @@
/* Copyright 2018 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE local_media_repository ADD COLUMN last_access_ts BIGINT;

View File

@@ -0,0 +1,37 @@
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage.engines import PostgresEngine
def run_create(cur, database_engine, *args, **kwargs):
if isinstance(database_engine, PostgresEngine):
# if we already have some state groups, we want to start making new
# ones with a higher id.
cur.execute("SELECT max(id) FROM state_groups")
row = cur.fetchone()
if row[0] is None:
start_val = 1
else:
start_val = row[0] + 1
cur.execute(
"CREATE SEQUENCE state_group_id_seq START WITH %s",
(start_val, ),
)
def run_upgrade(*args, **kwargs):
pass

View File

@@ -13,19 +13,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
import logging
import re
import ujson as json
from twisted.internet import defer
from .background_updates import BackgroundUpdateStore
from synapse.api.errors import SynapseError
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
import logging
import re
import ujson as json
logger = logging.getLogger(__name__)
SearchEntry = namedtuple('SearchEntry', [
'key', 'value', 'event_id', 'room_id', 'stream_ordering',
'origin_server_ts',
])
class SearchStore(BackgroundUpdateStore):
@@ -49,16 +55,17 @@ class SearchStore(BackgroundUpdateStore):
@defer.inlineCallbacks
def _background_reindex_search(self, progress, batch_size):
# we work through the events table from highest stream id to lowest
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
INSERT_CLUMP_SIZE = 1000
TYPES = ["m.room.name", "m.room.message", "m.room.topic"]
def reindex_search_txn(txn):
sql = (
"SELECT stream_ordering, event_id, room_id, type, content FROM events"
"SELECT stream_ordering, event_id, room_id, type, content, "
" origin_server_ts FROM events"
" WHERE ? <= stream_ordering AND stream_ordering < ?"
" AND (%s)"
" ORDER BY stream_ordering DESC"
@@ -67,6 +74,10 @@ class SearchStore(BackgroundUpdateStore):
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
# we could stream straight from the results into
# store_search_entries_txn with a generator function, but that
# would mean having two cursors open on the database at once.
# Instead we just build a list of results.
rows = self.cursor_to_dict(txn)
if not rows:
return 0
@@ -79,6 +90,8 @@ class SearchStore(BackgroundUpdateStore):
event_id = row["event_id"]
room_id = row["room_id"]
etype = row["type"]
stream_ordering = row["stream_ordering"]
origin_server_ts = row["origin_server_ts"]
try:
content = json.loads(row["content"])
except Exception:
@@ -93,6 +106,8 @@ class SearchStore(BackgroundUpdateStore):
elif etype == "m.room.name":
key = "content.name"
value = content["name"]
else:
raise Exception("unexpected event type %s" % etype)
except (KeyError, AttributeError):
# If the event is missing a necessary field then
# skip over it.
@@ -103,25 +118,16 @@ class SearchStore(BackgroundUpdateStore):
# then skip over it
continue
event_search_rows.append((event_id, room_id, key, value))
event_search_rows.append(SearchEntry(
key=key,
value=value,
event_id=event_id,
room_id=room_id,
stream_ordering=stream_ordering,
origin_server_ts=origin_server_ts,
))
if isinstance(self.database_engine, PostgresEngine):
sql = (
"INSERT INTO event_search (event_id, room_id, key, vector)"
" VALUES (?,?,?,to_tsvector('english', ?))"
)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
"INSERT INTO event_search (event_id, room_id, key, value)"
" VALUES (?,?,?,?)"
)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
for index in range(0, len(event_search_rows), INSERT_CLUMP_SIZE):
clump = event_search_rows[index:index + INSERT_CLUMP_SIZE]
txn.executemany(sql, clump)
self.store_search_entries_txn(txn, event_search_rows)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
@@ -242,6 +248,62 @@ class SearchStore(BackgroundUpdateStore):
defer.returnValue(num_rows)
def store_event_search_txn(self, txn, event, key, value):
"""Add event to the search table
Args:
txn (cursor):
event (EventBase):
key (str):
value (str):
"""
self.store_search_entries_txn(
txn,
(SearchEntry(
key=key,
value=value,
event_id=event.event_id,
room_id=event.room_id,
stream_ordering=event.internal_metadata.stream_ordering,
origin_server_ts=event.origin_server_ts,
),),
)
def store_search_entries_txn(self, txn, entries):
"""Add entries to the search table
Args:
txn (cursor):
entries (iterable[SearchEntry]):
entries to be added to the table
"""
if isinstance(self.database_engine, PostgresEngine):
sql = (
"INSERT INTO event_search"
" (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
" VALUES (?,?,?,to_tsvector('english', ?),?,?)"
)
args = ((
entry.event_id, entry.room_id, entry.key, entry.value,
entry.stream_ordering, entry.origin_server_ts,
) for entry in entries)
txn.executemany(sql, args)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
"INSERT INTO event_search (event_id, room_id, key, value)"
" VALUES (?,?,?,?)"
)
args = ((
entry.event_id, entry.room_id, entry.key, entry.value,
) for entry in entries)
txn.executemany(sql, args)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
@defer.inlineCallbacks
def search_msgs(self, room_ids, search_term, keys):
"""Performs a full text search over events with given keys.

View File

@@ -42,11 +42,8 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt
return len(self.delta_ids) if self.delta_ids else 0
class StateGroupReadStore(SQLBaseStore):
"""The read-only parts of StateGroupStore
None of these functions write to the state tables, so are suitable for
including in the SlavedStores.
class StateGroupWorkerStore(SQLBaseStore):
"""The parts of StateGroupStore that can be called from workers.
"""
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
@@ -54,7 +51,7 @@ class StateGroupReadStore(SQLBaseStore):
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
def __init__(self, db_conn, hs):
super(StateGroupReadStore, self).__init__(db_conn, hs)
super(StateGroupWorkerStore, self).__init__(db_conn, hs)
self._state_group_cache = DictionaryCache(
"*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
@@ -549,8 +546,117 @@ class StateGroupReadStore(SQLBaseStore):
defer.returnValue(results)
def store_state_group(self, event_id, room_id, prev_group, delta_ids,
current_state_ids):
"""Store a new set of state, returning a newly assigned state group.
class StateStore(StateGroupReadStore, BackgroundUpdateStore):
Args:
event_id (str): The event ID for which the state was calculated
room_id (str)
prev_group (int|None): A previous state group for the room, optional.
delta_ids (dict|None): The delta between state at `prev_group` and
`current_state_ids`, if `prev_group` was given. Same format as
`current_state_ids`.
current_state_ids (dict): The state to store. Map of (type, state_key)
to event_id.
Returns:
Deferred[int]: The state group ID
"""
def _store_state_group_txn(txn):
if current_state_ids is None:
# AFAIK, this can never happen
raise Exception("current_state_ids cannot be None")
state_group = self.database_engine.get_next_state_group_id(txn)
self._simple_insert_txn(
txn,
table="state_groups",
values={
"id": state_group,
"room_id": room_id,
"event_id": event_id,
},
)
# We persist as a delta if we can, while also ensuring the chain
# of deltas isn't too long, as otherwise read performance degrades.
if prev_group:
is_in_db = self._simple_select_one_onecol_txn(
txn,
table="state_groups",
keyvalues={"id": prev_group},
retcol="id",
allow_none=True,
)
if not is_in_db:
raise Exception(
"Trying to persist state with unpersisted prev_group: %r"
% (prev_group,)
)
potential_hops = self._count_state_group_hops_txn(
txn, prev_group
)
if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
self._simple_insert_txn(
txn,
table="state_group_edges",
values={
"state_group": state_group,
"prev_state_group": prev_group,
},
)
self._simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
{
"state_group": state_group,
"room_id": room_id,
"type": key[0],
"state_key": key[1],
"event_id": state_id,
}
for key, state_id in delta_ids.iteritems()
],
)
else:
self._simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
{
"state_group": state_group,
"room_id": room_id,
"type": key[0],
"state_key": key[1],
"event_id": state_id,
}
for key, state_id in current_state_ids.iteritems()
],
)
# Prefill the state group cache with this group.
# It's fine to use the sequence like this as the state group map
# is immutable. (If the map wasn't immutable then this prefill could
# race with another update)
txn.call_after(
self._state_group_cache.update,
self._state_group_cache.sequence,
key=state_group,
value=dict(current_state_ids),
full=True,
)
return state_group
return self.runInteraction("store_state_group", _store_state_group_txn)
class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
""" Keeps track of the state at a given event.
This is done by the concept of `state groups`. Every event is assigned
@@ -591,27 +697,12 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
where_clause="type='m.room.member'",
)
def _have_persisted_state_group_txn(self, txn, state_group):
txn.execute(
"SELECT count(*) FROM state_groups WHERE id = ?",
(state_group,)
)
row = txn.fetchone()
return row and row[0]
def _store_mult_state_groups_txn(self, txn, events_and_contexts):
def _store_event_state_mappings_txn(self, txn, events_and_contexts):
state_groups = {}
for event, context in events_and_contexts:
if event.internal_metadata.is_outlier():
continue
if context.current_state_ids is None:
# AFAIK, this can never happen
logger.error(
"Non-outlier event %s had current_state_ids==None",
event.event_id)
continue
# if the event was rejected, just give it the same state as its
# predecessor.
if context.rejected:
@@ -620,90 +711,6 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
state_groups[event.event_id] = context.state_group
if self._have_persisted_state_group_txn(txn, context.state_group):
continue
self._simple_insert_txn(
txn,
table="state_groups",
values={
"id": context.state_group,
"room_id": event.room_id,
"event_id": event.event_id,
},
)
# We persist as a delta if we can, while also ensuring the chain
# of deltas isn't too long, as otherwise read performance degrades.
if context.prev_group:
is_in_db = self._simple_select_one_onecol_txn(
txn,
table="state_groups",
keyvalues={"id": context.prev_group},
retcol="id",
allow_none=True,
)
if not is_in_db:
raise Exception(
"Trying to persist state with unpersisted prev_group: %r"
% (context.prev_group,)
)
potential_hops = self._count_state_group_hops_txn(
txn, context.prev_group
)
if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
self._simple_insert_txn(
txn,
table="state_group_edges",
values={
"state_group": context.state_group,
"prev_state_group": context.prev_group,
},
)
self._simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
{
"state_group": context.state_group,
"room_id": event.room_id,
"type": key[0],
"state_key": key[1],
"event_id": state_id,
}
for key, state_id in context.delta_ids.iteritems()
],
)
else:
self._simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
{
"state_group": context.state_group,
"room_id": event.room_id,
"type": key[0],
"state_key": key[1],
"event_id": state_id,
}
for key, state_id in context.current_state_ids.iteritems()
],
)
# Prefill the state group cache with this group.
# It's fine to use the sequence like this as the state group map
# is immutable. (If the map wasn't immutable then this prefill could
# race with another update)
txn.call_after(
self._state_group_cache.update,
self._state_group_cache.sequence,
key=context.state_group,
value=dict(context.current_state_ids),
full=True,
)
self._simple_insert_many_txn(
txn,
table="event_to_state_groups",
@@ -763,9 +770,6 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
return count
def get_next_state_group(self):
return self._state_groups_id_gen.get_next()
@defer.inlineCallbacks
def _background_deduplicate_state(self, progress, batch_size):
"""This background update will slowly deduplicate state by reencoding


@@ -641,8 +641,12 @@ class UserDirectoryStore(SQLBaseStore):
"""
if self.hs.config.user_directory_search_all_users:
join_clause = ""
where_clause = "?<>''" # naughty hack to keep the same number of binds
# make s.user_id null to keep the ordering algorithm happy
join_clause = """
CROSS JOIN (SELECT NULL as user_id) AS s
"""
join_args = ()
where_clause = "1=1"
else:
join_clause = """
LEFT JOIN users_in_public_rooms AS p USING (user_id)
@@ -651,6 +655,7 @@ class UserDirectoryStore(SQLBaseStore):
WHERE user_id = ? AND share_private
) AS s USING (user_id)
"""
join_args = (user_id,)
where_clause = "(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)"
if isinstance(self.database_engine, PostgresEngine):
@@ -692,7 +697,7 @@ class UserDirectoryStore(SQLBaseStore):
avatar_url IS NULL
LIMIT ?
""" % (join_clause, where_clause)
args = (user_id, full_query, exact_query, prefix_query, limit + 1,)
args = join_args + (full_query, exact_query, prefix_query, limit + 1,)
elif isinstance(self.database_engine, Sqlite3Engine):
search_query = _parse_query_sqlite(search_term)
@@ -710,7 +715,7 @@ class UserDirectoryStore(SQLBaseStore):
avatar_url IS NULL
LIMIT ?
""" % (join_clause, where_clause)
args = (user_id, search_query, limit + 1)
args = join_args + (search_query, limit + 1)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")


@@ -75,6 +75,7 @@ class Cache(object):
self.cache = LruCache(
max_size=max_entries, keylen=keylen, cache_type=cache_type,
size_callback=(lambda d: len(d)) if iterable else None,
evicted_callback=self._on_evicted,
)
self.name = name
@@ -83,6 +84,9 @@ class Cache(object):
self.thread = None
self.metrics = register_cache(name, self.cache)
def _on_evicted(self, evicted_count):
self.metrics.inc_evictions(evicted_count)
def check_thread(self):
expected_thread = self.thread
if expected_thread is None:


@@ -79,7 +79,11 @@ class ExpiringCache(object):
while self._max_len and len(self) > self._max_len:
_key, value = self._cache.popitem(last=False)
if self.iterable:
self._size_estimate -= len(value.value)
removed_len = len(value.value)
self.metrics.inc_evictions(removed_len)
self._size_estimate -= removed_len
else:
self.metrics.inc_evictions()
def __getitem__(self, key):
try:


@@ -49,7 +49,24 @@ class LruCache(object):
Can also set callbacks on objects when getting/setting which are fired
when that key gets invalidated/evicted.
"""
def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None):
def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None,
evicted_callback=None):
"""
Args:
max_size (int):
keylen (int):
cache_type (type):
type of underlying cache to be used. Typically one of dict
or TreeCache.
size_callback (func(V) -> int | None):
evicted_callback (func(int)|None):
if not None, called on eviction with the size of the evicted
entry
"""
cache = cache_type()
self.cache = cache # Used for introspection.
list_root = _Node(None, None, None, None)
@@ -61,8 +78,10 @@ class LruCache(object):
def evict():
while cache_len() > max_size:
todelete = list_root.prev_node
delete_node(todelete)
evicted_len = delete_node(todelete)
cache.pop(todelete.key, None)
if evicted_callback:
evicted_callback(evicted_len)
def synchronized(f):
@wraps(f)
@@ -111,12 +130,15 @@ class LruCache(object):
prev_node.next_node = next_node
next_node.prev_node = prev_node
deleted_len = 1
if size_callback:
cached_cache_len[0] -= size_callback(node.value)
deleted_len = size_callback(node.value)
cached_cache_len[0] -= deleted_len
for cb in node.callbacks:
cb()
node.callbacks.clear()
return deleted_len
@synchronized
def cache_get(key, default=None, callbacks=[]):
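The evicted_callback plumbing is what lets the wrapping caches count evictions in their metrics: delete_node now reports the size of the entry it removed, and the callback receives that size. A self-contained toy of the same contract (not Synapse's LruCache, just the pattern):

    # Toy LRU demonstrating the eviction-callback contract.
    from collections import OrderedDict

    class TinyLru(object):
        def __init__(self, max_size, evicted_callback=None):
            self.max_size = max_size
            self.evicted_callback = evicted_callback
            self.data = OrderedDict()

        def set(self, key, value):
            self.data.pop(key, None)             # refresh recency on overwrite
            self.data[key] = value
            while len(self.data) > self.max_size:
                self.data.popitem(last=False)    # evict the least-recent entry
                if self.evicted_callback:
                    self.evicted_callback(1)     # report the evicted entry's size

    evicted_sizes = []

    def on_evicted(size):
        evicted_sizes.append(size)

    cache = TinyLru(2, evicted_callback=on_evicted)
    cache.set("a", 1)
    cache.set("b", 2)
    cache.set("c", 3)            # "a" is evicted
    print(sum(evicted_sizes))    # -> 1

In the real caches, the callback simply increments the eviction counter registered for that cache's metrics, as the Cache and ExpiringCache hunks above show.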


@@ -0,0 +1,139 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import threads, reactor
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
import Queue
class BackgroundFileConsumer(object):
"""A consumer that writes to a file like object. Supports both push
and pull producers
Args:
file_obj (file): The file like object to write to. Closed when
finished.
"""
# For PushProducers pause if we have this many unwritten slices
_PAUSE_ON_QUEUE_SIZE = 5
# And resume once the size of the queue is less than this
_RESUME_ON_QUEUE_SIZE = 2
def __init__(self, file_obj):
self._file_obj = file_obj
# Producer we're registered with
self._producer = None
# True if PushProducer, false if PullProducer
self.streaming = False
# For PushProducers, indicates whether we've paused the producer and
# need to call resumeProducing before we get more data.
self._paused_producer = False
# Queue of slices of bytes to be written. When the producer is
# unregistered, a final None is enqueued as a sentinel.
self._bytes_queue = Queue.Queue()
# Deferred that is resolved when finished writing
self._finished_deferred = None
# If the _writer thread throws an exception it gets stored here.
self._write_exception = None
def registerProducer(self, producer, streaming):
"""Part of IConsumer interface
Args:
producer (IProducer)
streaming (bool): True if push based producer, False if pull
based.
"""
if self._producer:
raise Exception("registerProducer called twice")
self._producer = producer
self.streaming = streaming
self._finished_deferred = preserve_fn(threads.deferToThread)(self._writer)
if not streaming:
self._producer.resumeProducing()
def unregisterProducer(self):
"""Part of IProducer interface
"""
self._producer = None
if not self._finished_deferred.called:
self._bytes_queue.put_nowait(None)
def write(self, bytes):
"""Part of IProducer interface
"""
if self._write_exception:
raise self._write_exception
if self._finished_deferred.called:
raise Exception("consumer has closed")
self._bytes_queue.put_nowait(bytes)
# If this is a PushProducer and the queue is getting behind
# then we pause the producer.
if self.streaming and self._bytes_queue.qsize() >= self._PAUSE_ON_QUEUE_SIZE:
self._paused_producer = True
self._producer.pauseProducing()
def _writer(self):
"""This is run in a background thread to write to the file.
"""
try:
while self._producer or not self._bytes_queue.empty():
# If we've paused the producer check if we should resume the
# producer.
if self._producer and self._paused_producer:
if self._bytes_queue.qsize() <= self._RESUME_ON_QUEUE_SIZE:
reactor.callFromThread(self._resume_paused_producer)
bytes = self._bytes_queue.get()
# If we get a None (or other falsy value) then that's the signal
# to check whether we should stop.
if bytes:
self._file_obj.write(bytes)
# If it's a pull producer then we need to explicitly ask for
# more stuff.
if not self.streaming and self._producer:
reactor.callFromThread(self._producer.resumeProducing)
except Exception as e:
self._write_exception = e
raise
finally:
self._file_obj.close()
def wait(self):
"""Returns a deferred that resolves when finished writing to file
"""
return make_deferred_yieldable(self._finished_deferred)
def _resume_paused_producer(self):
"""Gets called if we should resume producing after being paused
"""
if self._paused_producer and self._producer:
self._paused_producer = False
self._producer.resumeProducing()
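A hedged sketch of how this consumer might be driven by a pull producer (it must run inside a Twisted reactor; the producer class below is a stand-in for whatever actually feeds the data, such as an HTTP response body):

    # Hypothetical driver code; assumes a running Twisted reactor and that
    # BackgroundFileConsumer (defined above) is in scope.
    from twisted.internet import defer

    class ListPullProducer(object):
        """A toy pull producer that feeds fixed chunks to a consumer."""
        def __init__(self, consumer, chunks):
            self.consumer = consumer
            self.chunks = list(chunks)

        def resumeProducing(self):
            if self.chunks:
                self.consumer.write(self.chunks.pop(0))
            else:
                self.consumer.unregisterProducer()

        def stopProducing(self):
            self.chunks = []

    @defer.inlineCallbacks
    def write_chunks_to_file(file_obj, chunks):
        consumer = BackgroundFileConsumer(file_obj)
        producer = ListPullProducer(consumer, chunks)
        # streaming=False marks this as a pull producer: the consumer will
        # call resumeProducing whenever it wants more data.
        consumer.registerProducer(producer, False)
        yield consumer.wait()

    # e.g. yield write_chunks_to_file(open("out.bin", "wb"), [b"hello ", b"world"])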


@@ -52,13 +52,17 @@ except Exception:
class LoggingContext(object):
"""Additional context for log formatting. Contexts are scoped within a
"with" block.
Args:
name (str): Name for the context for debugging.
"""
__slots__ = [
"previous_context", "name", "usage_start", "usage_end", "main_thread",
"__dict__", "tag", "alive",
"previous_context", "name", "ru_stime", "ru_utime",
"db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms",
"usage_start", "usage_end",
"main_thread", "alive",
"request", "tag",
]
thread_local = threading.local()
@@ -83,6 +87,9 @@ class LoggingContext(object):
def add_database_transaction(self, duration_ms):
pass
def add_database_scheduled(self, sched_ms):
pass
def __nonzero__(self):
return False
@@ -94,9 +101,17 @@ class LoggingContext(object):
self.ru_stime = 0.
self.ru_utime = 0.
self.db_txn_count = 0
self.db_txn_duration = 0.
# ms spent waiting for db txns, excluding scheduling time
self.db_txn_duration_ms = 0
# ms spent waiting for db txns to be scheduled
self.db_sched_duration_ms = 0
self.usage_start = None
self.usage_end = None
self.main_thread = threading.current_thread()
self.request = None
self.tag = ""
self.alive = True
@@ -105,7 +120,11 @@ class LoggingContext(object):
@classmethod
def current_context(cls):
"""Get the current logging context from thread local storage"""
"""Get the current logging context from thread local storage
Returns:
LoggingContext: the current logging context
"""
return getattr(cls.thread_local, "current_context", cls.sentinel)
@classmethod
@@ -155,11 +174,13 @@ class LoggingContext(object):
self.alive = False
def copy_to(self, record):
"""Copy fields from this context to the record"""
for key, value in self.__dict__.items():
setattr(record, key, value)
"""Copy logging fields from this context to a log record or
another LoggingContext
"""
record.ru_utime, record.ru_stime = self.get_resource_usage()
# 'request' is the only field we currently use in the logger, so that's
# all we need to copy
record.request = self.request
def start(self):
if threading.current_thread() is not self.main_thread:
@@ -194,7 +215,16 @@ class LoggingContext(object):
def add_database_transaction(self, duration_ms):
self.db_txn_count += 1
self.db_txn_duration += duration_ms / 1000.
self.db_txn_duration_ms += duration_ms
def add_database_scheduled(self, sched_ms):
"""Record a use of the database pool
Args:
sched_ms (int): number of milliseconds it took us to get a
connection
"""
self.db_sched_duration_ms += sched_ms
class LoggingContextFilter(logging.Filter):
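Splitting db_txn_duration_ms from db_sched_duration_ms lets a context distinguish time actually spent inside transactions from time spent waiting for a connection from the pool. A rough sketch of how a database layer could feed both counters (the acquire/run helpers here are illustrative, not Synapse's runInteraction):

    # Illustrative only: reporting both durations to the current context.
    import time

    def run_txn_with_accounting(context, acquire_conn, run_txn):
        start_ms = time.time() * 1000.0
        conn = acquire_conn()                 # may block waiting for the pool
        acquired_ms = time.time() * 1000.0
        context.add_database_scheduled(acquired_ms - start_ms)

        try:
            return run_txn(conn)
        finally:
            finished_ms = time.time() * 1000.0
            context.add_database_transaction(finished_ms - acquired_ms)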


@@ -27,25 +27,62 @@ logger = logging.getLogger(__name__)
metrics = synapse.metrics.get_metrics_for(__name__)
block_timer = metrics.register_distribution(
"block_timer",
labels=["block_name"]
# total number of times we have hit this block
block_counter = metrics.register_counter(
"block_count",
labels=["block_name"],
alternative_names=(
# the following are all deprecated aliases for the same metric
metrics.name_prefix + x for x in (
"_block_timer:count",
"_block_ru_utime:count",
"_block_ru_stime:count",
"_block_db_txn_count:count",
"_block_db_txn_duration:count",
)
)
)
block_ru_utime = metrics.register_distribution(
"block_ru_utime", labels=["block_name"]
block_timer = metrics.register_counter(
"block_time_seconds",
labels=["block_name"],
alternative_names=(
metrics.name_prefix + "_block_timer:total",
),
)
block_ru_stime = metrics.register_distribution(
"block_ru_stime", labels=["block_name"]
block_ru_utime = metrics.register_counter(
"block_ru_utime_seconds", labels=["block_name"],
alternative_names=(
metrics.name_prefix + "_block_ru_utime:total",
),
)
block_db_txn_count = metrics.register_distribution(
"block_db_txn_count", labels=["block_name"]
block_ru_stime = metrics.register_counter(
"block_ru_stime_seconds", labels=["block_name"],
alternative_names=(
metrics.name_prefix + "_block_ru_stime:total",
),
)
block_db_txn_duration = metrics.register_distribution(
"block_db_txn_duration", labels=["block_name"]
block_db_txn_count = metrics.register_counter(
"block_db_txn_count", labels=["block_name"],
alternative_names=(
metrics.name_prefix + "_block_db_txn_count:total",
),
)
# seconds spent waiting for db txns, excluding scheduling time, in this block
block_db_txn_duration = metrics.register_counter(
"block_db_txn_duration_seconds", labels=["block_name"],
alternative_names=(
metrics.name_prefix + "_block_db_txn_duration:total",
),
)
# seconds spent waiting for a db connection, in this block
block_db_sched_duration = metrics.register_counter(
"block_db_sched_duration_seconds", labels=["block_name"],
)
@@ -64,7 +101,9 @@ def measure_func(name):
class Measure(object):
__slots__ = [
"clock", "name", "start_context", "start", "new_context", "ru_utime",
"ru_stime", "db_txn_count", "db_txn_duration", "created_context"
"ru_stime",
"db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms",
"created_context",
]
def __init__(self, clock, name):
@@ -84,13 +123,16 @@ class Measure(object):
self.ru_utime, self.ru_stime = self.start_context.get_resource_usage()
self.db_txn_count = self.start_context.db_txn_count
self.db_txn_duration = self.start_context.db_txn_duration
self.db_txn_duration_ms = self.start_context.db_txn_duration_ms
self.db_sched_duration_ms = self.start_context.db_sched_duration_ms
def __exit__(self, exc_type, exc_val, exc_tb):
if isinstance(exc_type, Exception) or not self.start_context:
return
duration = self.clock.time_msec() - self.start
block_counter.inc(self.name)
block_timer.inc_by(duration, self.name)
context = LoggingContext.current_context()
@@ -114,7 +156,12 @@ class Measure(object):
context.db_txn_count - self.db_txn_count, self.name
)
block_db_txn_duration.inc_by(
context.db_txn_duration - self.db_txn_duration, self.name
(context.db_txn_duration_ms - self.db_txn_duration_ms) / 1000.,
self.name
)
block_db_sched_duration.inc_by(
(context.db_sched_duration_ms - self.db_sched_duration_ms) / 1000.,
self.name
)
if self.created_context:
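With the distributions replaced by counters, each block now exports a plain hit count plus cumulative totals in seconds, and dividing a block's block_time_seconds by its block_count recovers the mean duration the old distribution used to report. A hedged usage sketch of the Measure context manager itself (hs and perform_expensive_step are assumptions, not code from this diff):

    # Hedged usage sketch; anything run inside the block is attributed to
    # the "my_background_work" block_name label.
    from synapse.util.metrics import Measure

    def do_some_work(hs):
        clock = hs.get_clock()
        with Measure(clock, "my_background_work"):
            # Accounts wall-clock time, ru_utime/ru_stime, db transaction
            # count and duration, and db scheduling time for this block.
            perform_expensive_step()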


@@ -26,6 +26,18 @@ logger = logging.getLogger(__name__)
class NotRetryingDestination(Exception):
def __init__(self, retry_last_ts, retry_interval, destination):
"""Raised by the limiter (and federation client) to indicate that we are
are deliberately not attempting to contact a given server.
Args:
retry_last_ts (int): the unix ts in milliseconds of our last attempt
to contact the server. 0 indicates that the last attempt was
successful or that we've never actually attempted to connect.
retry_interval (int): the time in milliseconds to wait until the next
attempt.
destination (str): the domain in question
"""
msg = "Not retrying server %s." % (destination,)
super(NotRetryingDestination, self).__init__(msg)
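A short sketch of how callers typically treat this exception: a request to a destination that is still in backoff fails fast with NotRetryingDestination, and the caller decides whether that is an error or just something to skip. The send_to_remote callable below is a placeholder.

    # Placeholder caller; send_to_remote stands in for a federation request.
    import logging

    from twisted.internet import defer

    from synapse.util.retryutils import NotRetryingDestination

    logger = logging.getLogger(__name__)

    @defer.inlineCallbacks
    def try_send(destination, send_to_remote):
        try:
            result = yield send_to_remote(destination)
        except NotRetryingDestination as e:
            # The retry limiter is still backing off this server; treat it
            # as a skip rather than a hard failure.
            logger.info("Not sending to %s: %s", destination, e)
            defer.returnValue(None)
        defer.returnValue(result)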

synapse/util/threepids.py

@@ -0,0 +1,48 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
logger = logging.getLogger(__name__)
def check_3pid_allowed(hs, medium, address):
"""Checks whether a given format of 3PID is allowed to be used on this HS
Args:
hs (synapse.server.HomeServer): server
medium (str): 3pid medium - e.g. email, msisdn
address (str): address within that medium (e.g. "wotan@matrix.org")
msisdns need to first have been canonicalised
Returns:
bool: whether the 3PID medium/address is allowed to be added to this HS
"""
if hs.config.allowed_local_3pids:
for constraint in hs.config.allowed_local_3pids:
logger.debug(
"Checking 3PID %s (%s) against %s (%s)",
address, medium, constraint['pattern'], constraint['medium'],
)
if (
medium == constraint['medium'] and
re.match(constraint['pattern'], address)
):
return True
else:
return True
return False
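For instance, with a hypothetical allowed_local_3pids configuration containing an email pattern and an msisdn prefix, only matching addresses are accepted. A self-contained illustration (the config objects are fakes standing in for hs.config):

    # Self-contained illustration; assumes check_3pid_allowed (above) is in scope.
    class FakeConfig(object):
        allowed_local_3pids = [
            {"medium": "email", "pattern": r".*@example\.com"},
            {"medium": "msisdn", "pattern": r"\+44"},
        ]

    class FakeHS(object):
        config = FakeConfig()

    hs = FakeHS()
    print(check_3pid_allowed(hs, "email", "wotan@example.com"))   # True
    print(check_3pid_allowed(hs, "email", "wotan@evil.org"))      # False
    print(check_3pid_allowed(hs, "msisdn", "+447700900000"))      # True

With allowed_local_3pids left unset, the function accepts every medium and address.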


@@ -12,9 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path
import re
import shutil
import tempfile
from synapse.config.homeserver import HomeServerConfig
from tests import unittest
@@ -23,7 +26,6 @@ class ConfigGenerationTestCase(unittest.TestCase):
def setUp(self):
self.dir = tempfile.mkdtemp()
print self.dir
self.file = os.path.join(self.dir, "homeserver.yaml")
def tearDown(self):
@@ -48,3 +50,16 @@ class ConfigGenerationTestCase(unittest.TestCase):
]),
set(os.listdir(self.dir))
)
self.assert_log_filename_is(
os.path.join(self.dir, "lemurs.win.log.config"),
os.path.join(os.getcwd(), "homeserver.log"),
)
def assert_log_filename_is(self, log_config_file, expected):
with open(log_config_file) as f:
config = f.read()
# find the 'filename' line
matches = re.findall("^\s*filename:\s*(.*)$", config, re.M)
self.assertEqual(1, len(matches))
self.assertEqual(matches[0], expected)
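For reference, the regular expression in assert_log_filename_is pulls the path out of a line like "filename: /path/to/homeserver.log" inside the generated log config. A standalone check of the same pattern against a made-up snippet:

    # Standalone check of the regex against sample (made-up) log config content.
    import re

    sample = "handlers:\n  file:\n    filename: /home/user/synapse/homeserver.log\n"
    print(re.findall(r"^\s*filename:\s*(.*)$", sample, re.M))
    # -> ['/home/user/synapse/homeserver.log']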
