Compare commits

..

14 Commits

Author SHA1 Message Date
Dan Callahan
a06dd1d6b5 Merge tag 'v1.26.0rc2' into travis/fosdem/hotfixes 2021-01-26 15:38:22 +00:00
Patrick Cloke
69961c7e9f Tweak changes. 2021-01-25 08:26:42 -05:00
Patrick Cloke
a01605c136 1.26.0rc2 2021-01-25 08:25:40 -05:00
Erik Johnston
056327457f Fix chain cover update to handle events with duplicate auth events (#9210) 2021-01-22 19:44:08 +00:00
Travis Ralston
fc2cbce232 Fix state endpoint to be faster 2021-01-22 12:37:42 -07:00
Erik Johnston
28f255d5f3 Bump psycopg2 version (#9204)
As we use `execute_values` with the `fetch` parameter.
2021-01-22 11:14:49 +00:00
Travis Ralston
f7a03e86e0 Merge branch 'travis/fosdem/admin-api-room-state' into travis/fosdem/hotfixes 2021-01-21 12:35:05 -07:00
Travis Ralston
d9867f1640 Merge branch 'travis/fosdem/admin-api-groups' into travis/fosdem/hotfixes 2021-01-21 12:34:59 -07:00
Travis Ralston
7d8cc63e37 Get the right requester object 2021-01-19 14:03:39 -07:00
Travis Ralston
19a4821ffc Changelog 2021-01-19 14:01:08 -07:00
Travis Ralston
40f96320a2 Add an admin API to get the current room state
This could arguably replace the existing admin API for `/members`, however that is out of scope of this change.

This sort of endpoint is ideal for moderation use cases as well as other applications, such as needing to retrieve various bits of information about a room to perform a task (like syncing power levels between two places). This endpoint exposes nothing more than an admin would be able to access with a `select *` query on their database.
2021-01-19 13:59:29 -07:00
Travis Ralston
e2377bba70 Appease the linters 2021-01-19 13:25:10 -07:00
Travis Ralston
84204f8020 Changelog 2021-01-19 13:23:40 -07:00
Travis Ralston
95d7074322 Add admin APIs to force-join users to groups and manage their flair
Fixes https://github.com/matrix-org/synapse/issues/9143

Though the groups API is disappearing soon, these functions are intended to make flair management easier in the short term.
2021-01-19 13:21:17 -07:00
54 changed files with 495 additions and 722 deletions

View File

@@ -9,3 +9,5 @@ apt-get update
apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev xmlsec1 zlib1g-dev tox
export LANG="C.UTF-8"
exec tox -e py35-old,combine

View File

@@ -1,3 +1,20 @@
Synapse 1.26.0rc2 (2021-01-25)
==============================
Bugfixes
--------
- Fix receipts and account data not being sent down sync. Introduced in v1.26.0rc1. ([\#9193](https://github.com/matrix-org/synapse/issues/9193), [\#9195](https://github.com/matrix-org/synapse/issues/9195))
- Fix chain cover update to handle events with duplicate auth events. Introduced in v1.26.0rc1. ([\#9210](https://github.com/matrix-org/synapse/issues/9210))
Internal Changes
----------------
- Add an `oidc-` prefix to any `idp_id`s which are given in the `oidc_providers` configuration. ([\#9189](https://github.com/matrix-org/synapse/issues/9189))
- Bump minimum `psycopg2` version to v2.8. ([\#9204](https://github.com/matrix-org/synapse/issues/9204))
Synapse 1.26.0rc1 (2021-01-20)
==============================

View File

@@ -1 +0,0 @@
Add tests to `test_user.UsersListTestCase` for List Users Admin API.

View File

@@ -1 +0,0 @@
Various improvements to the federation client.

View File

@@ -1 +0,0 @@
Add link to Matrix VoIP tester for turn-howto.

View File

@@ -1 +0,0 @@
Fix a long-standing bug where Synapse would return a 500 error when a thumbnail did not exist (and auto-generation of thumbnails was not enabled).

1
changelog.d/9167.feature Normal file
View File

@@ -0,0 +1 @@
Add server admin endpoints to join users to legacy groups and manage their flair.

1
changelog.d/9168.feature Normal file
View File

@@ -0,0 +1 @@
Add an admin API for retrieving the current room state of a room.

View File

@@ -1 +0,0 @@
Speed up chain cover calculation when persisting a batch of state events at once.

View File

@@ -1 +0,0 @@
Add a `long_description_type` to the package metadata.

View File

@@ -1 +0,0 @@
Speed up batch insertion when using PostgreSQL.

View File

@@ -1 +0,0 @@
Emit an error at startup if different Identity Providers are configured with the same `idp_id`.

View File

@@ -1 +0,0 @@
Speed up batch insertion when using PostgreSQL.

View File

@@ -1 +0,0 @@
Add an `oidc-` prefix to any `idp_id`s which are given in the `oidc_providers` configuration.

View File

@@ -1 +0,0 @@
Improve performance of concurrent use of `StreamIDGenerators`.

View File

@@ -1 +0,0 @@
Add some missing source directories to the automatic linting script.

View File

@@ -1 +0,0 @@
Fix receipts or account data not being sent down sync. Introduced in v1.26.0rc1.

View File

@@ -1 +0,0 @@
Fix receipts or account data not being sent down sync. Introduced in v1.26.0rc1.

View File

@@ -367,6 +367,36 @@ Response:
}
```
# Room State API
The Room State admin API allows server admins to get a list of all state events in a room.
The response includes the following fields:
* `state` - The current state of the room at the time of request.
## Usage
A standard request:
```
GET /_synapse/admin/v1/rooms/<room_id>/state
{}
```
Response:
```json
{
"state": [
{"type": "m.room.create", "state_key": "", "etc": true},
{"type": "m.room.power_levels", "state_key": "", "etc": true},
{"type": "m.room.name", "state_key": "", "etc": true}
]
}
```
# Delete Room API
The Delete Room admin API allows server admins to remove rooms from server

View File

@@ -232,12 +232,6 @@ Here are a few things to try:
(Understanding the output is beyond the scope of this document!)
* You can test your Matrix homeserver TURN setup with https://test.voip.librepush.net/.
Note that this test is not fully reliable yet, so don't be discouraged if
the test fails.
[Here](https://github.com/matrix-org/voip-tester) is the github repo of the
source of the tester, where you can file bug reports.
* There is a WebRTC test tool at
https://webrtc.github.io/samples/src/content/peerconnection/trickle-ice/. To
use it, you will need a username/password for your TURN server. You can

View File

@@ -80,8 +80,7 @@ else
# then lint everything!
if [[ -z ${files+x} ]]; then
# Lint all source code files and directories
# Note: this list aims the mirror the one in tox.ini
files=("synapse" "docker" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py" "synmark" "stubs" ".buildkite")
files=("synapse" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py" "synmark")
fi
fi

View File

@@ -121,7 +121,6 @@ setup(
include_package_data=True,
zip_safe=False,
long_description=long_description,
long_description_content_type="text/x-rst",
python_requires="~=3.5",
classifiers=[
"Development Status :: 5 - Production/Stable",

View File

@@ -48,7 +48,7 @@ try:
except ImportError:
pass
__version__ = "1.26.0rc1"
__version__ = "1.26.0rc2"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when

View File

@@ -15,7 +15,6 @@
# limitations under the License.
import string
from collections import Counter
from typing import Iterable, Optional, Tuple, Type
import attr
@@ -44,16 +43,6 @@ class OIDCConfig(Config):
except DependencyException as e:
raise ConfigError(e.message) from e
# check we don't have any duplicate idp_ids now. (The SSO handler will also
# check for duplicates when the REST listeners get registered, but that happens
# after synapse has forked so doesn't give nice errors.)
c = Counter([i.idp_id for i in self.oidc_providers])
for idp_id, count in c.items():
if count > 1:
raise ConfigError(
"Multiple OIDC providers have the idp_id %r." % idp_id
)
public_baseurl = self.public_baseurl
self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback"

View File

@@ -18,7 +18,6 @@ import copy
import itertools
import logging
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
@@ -27,6 +26,7 @@ from typing import (
List,
Mapping,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
@@ -61,9 +61,6 @@ from synapse.util import unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["type"])
@@ -83,10 +80,10 @@ class InvalidResponseError(RuntimeError):
class FederationClient(FederationBase):
def __init__(self, hs: "HomeServer"):
def __init__(self, hs):
super().__init__(hs)
self.pdu_destination_tried = {} # type: Dict[str, Dict[str, int]]
self.pdu_destination_tried = {}
self._clock.looping_call(self._clear_tried_cache, 60 * 1000)
self.state = hs.get_state_handler()
self.transport_layer = hs.get_federation_transport_client()
@@ -119,32 +116,33 @@ class FederationClient(FederationBase):
self.pdu_destination_tried[event_id] = destination_dict
@log_function
async def make_query(
def make_query(
self,
destination: str,
query_type: str,
args: dict,
retry_on_dns_fail: bool = False,
ignore_backoff: bool = False,
) -> JsonDict:
destination,
query_type,
args,
retry_on_dns_fail=False,
ignore_backoff=False,
):
"""Sends a federation Query to a remote homeserver of the given type
and arguments.
Args:
destination: Domain name of the remote homeserver
query_type: Category of the query type; should match the
destination (str): Domain name of the remote homeserver
query_type (str): Category of the query type; should match the
handler name used in register_query_handler().
args: Mapping of strings to strings containing the details
args (dict): Mapping of strings to strings containing the details
of the query request.
ignore_backoff: true to ignore the historical backoff data
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
Returns:
The JSON object from the response
a Awaitable which will eventually yield a JSON object from the
response
"""
sent_queries_counter.labels(query_type).inc()
return await self.transport_layer.make_query(
return self.transport_layer.make_query(
destination,
query_type,
args,
@@ -153,52 +151,42 @@ class FederationClient(FederationBase):
)
@log_function
async def query_client_keys(
self, destination: str, content: JsonDict, timeout: int
) -> JsonDict:
def query_client_keys(self, destination, content, timeout):
"""Query device keys for a device hosted on a remote server.
Args:
destination: Domain name of the remote homeserver
content: The query content.
destination (str): Domain name of the remote homeserver
content (dict): The query content.
Returns:
The JSON object from the response
an Awaitable which will eventually yield a JSON object from the
response
"""
sent_queries_counter.labels("client_device_keys").inc()
return await self.transport_layer.query_client_keys(
destination, content, timeout
)
return self.transport_layer.query_client_keys(destination, content, timeout)
@log_function
async def query_user_devices(
self, destination: str, user_id: str, timeout: int = 30000
) -> JsonDict:
def query_user_devices(self, destination, user_id, timeout=30000):
"""Query the device keys for a list of user ids hosted on a remote
server.
"""
sent_queries_counter.labels("user_devices").inc()
return await self.transport_layer.query_user_devices(
destination, user_id, timeout
)
return self.transport_layer.query_user_devices(destination, user_id, timeout)
@log_function
async def claim_client_keys(
self, destination: str, content: JsonDict, timeout: int
) -> JsonDict:
def claim_client_keys(self, destination, content, timeout):
"""Claims one-time keys for a device hosted on a remote server.
Args:
destination: Domain name of the remote homeserver
content: The query content.
destination (str): Domain name of the remote homeserver
content (dict): The query content.
Returns:
The JSON object from the response
an Awaitable which will eventually yield a JSON object from the
response
"""
sent_queries_counter.labels("client_one_time_keys").inc()
return await self.transport_layer.claim_client_keys(
destination, content, timeout
)
return self.transport_layer.claim_client_keys(destination, content, timeout)
async def backfill(
self, dest: str, room_id: str, limit: int, extremities: Iterable[str]
@@ -207,10 +195,10 @@ class FederationClient(FederationBase):
given destination server.
Args:
dest: The remote homeserver to ask.
room_id: The room_id to backfill.
limit: The maximum number of events to return.
extremities: our current backwards extremities, to backfill from
dest (str): The remote homeserver to ask.
room_id (str): The room_id to backfill.
limit (int): The maximum number of events to return.
extremities (list): our current backwards extremities, to backfill from
"""
logger.debug("backfill extrem=%s", extremities)
@@ -382,7 +370,7 @@ class FederationClient(FederationBase):
for events that have failed their checks
Returns:
A list of PDUs that have valid signatures and hashes.
Deferred : A list of PDUs that have valid signatures and hashes.
"""
deferreds = self._check_sigs_and_hashes(room_version, pdus)
@@ -430,9 +418,7 @@ class FederationClient(FederationBase):
else:
return [p for p in valid_pdus if p]
async def get_event_auth(
self, destination: str, room_id: str, event_id: str
) -> List[EventBase]:
async def get_event_auth(self, destination, room_id, event_id):
res = await self.transport_layer.get_event_auth(destination, room_id, event_id)
room_version = await self.store.get_room_version(room_id)
@@ -714,16 +700,18 @@ class FederationClient(FederationBase):
return await self._try_destination_list("send_join", destinations, send_request)
async def _do_send_join(self, destination: str, pdu: EventBase) -> JsonDict:
async def _do_send_join(self, destination: str, pdu: EventBase):
time_now = self._clock.time_msec()
try:
return await self.transport_layer.send_join_v2(
content = await self.transport_layer.send_join_v2(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
return content
except HttpResponseException as e:
if e.code in [400, 404]:
err = e.to_synapse_error()
@@ -781,7 +769,7 @@ class FederationClient(FederationBase):
time_now = self._clock.time_msec()
try:
return await self.transport_layer.send_invite_v2(
content = await self.transport_layer.send_invite_v2(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
@@ -791,6 +779,7 @@ class FederationClient(FederationBase):
"invite_room_state": pdu.unsigned.get("invite_room_state", []),
},
)
return content
except HttpResponseException as e:
if e.code in [400, 404]:
err = e.to_synapse_error()
@@ -853,16 +842,18 @@ class FederationClient(FederationBase):
"send_leave", destinations, send_request
)
async def _do_send_leave(self, destination: str, pdu: EventBase) -> JsonDict:
async def _do_send_leave(self, destination, pdu):
time_now = self._clock.time_msec()
try:
return await self.transport_layer.send_leave_v2(
content = await self.transport_layer.send_leave_v2(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
return content
except HttpResponseException as e:
if e.code in [400, 404]:
err = e.to_synapse_error()
@@ -888,7 +879,7 @@ class FederationClient(FederationBase):
# content.
return resp[1]
async def get_public_rooms(
def get_public_rooms(
self,
remote_server: str,
limit: Optional[int] = None,
@@ -896,7 +887,7 @@ class FederationClient(FederationBase):
search_filter: Optional[Dict] = None,
include_all_networks: bool = False,
third_party_instance_id: Optional[str] = None,
) -> JsonDict:
):
"""Get the list of public rooms from a remote homeserver
Args:
@@ -910,7 +901,8 @@ class FederationClient(FederationBase):
party instance
Returns:
The response from the remote server.
Awaitable[Dict[str, Any]]: The response from the remote server, or None if
`remote_server` is the same as the local server_name
Raises:
HttpResponseException: There was an exception returned from the remote server
@@ -918,7 +910,7 @@ class FederationClient(FederationBase):
requests over federation
"""
return await self.transport_layer.get_public_rooms(
return self.transport_layer.get_public_rooms(
remote_server,
limit,
since_token,
@@ -931,7 +923,7 @@ class FederationClient(FederationBase):
self,
destination: str,
room_id: str,
earliest_events_ids: Iterable[str],
earliest_events_ids: Sequence[str],
latest_events: Iterable[EventBase],
limit: int,
min_depth: int,
@@ -982,9 +974,7 @@ class FederationClient(FederationBase):
return signed_events
async def forward_third_party_invite(
self, destinations: Iterable[str], room_id: str, event_dict: JsonDict
) -> None:
async def forward_third_party_invite(self, destinations, room_id, event_dict):
for destination in destinations:
if destination == self.server_name:
continue
@@ -993,7 +983,7 @@ class FederationClient(FederationBase):
await self.transport_layer.exchange_third_party_invite(
destination=destination, room_id=room_id, event_dict=event_dict
)
return
return None
except CodeMessageException:
raise
except Exception as e:
@@ -1005,7 +995,7 @@ class FederationClient(FederationBase):
async def get_room_complexity(
self, destination: str, room_id: str
) -> Optional[JsonDict]:
) -> Optional[dict]:
"""
Fetch the complexity of a remote room from another server.
@@ -1018,9 +1008,10 @@ class FederationClient(FederationBase):
could not fetch the complexity.
"""
try:
return await self.transport_layer.get_room_complexity(
complexity = await self.transport_layer.get_room_complexity(
destination=destination, room_id=room_id
)
return complexity
except CodeMessageException as e:
# We didn't manage to get it -- probably a 404. We are okay if other
# servers don't give it to us.

View File

@@ -365,6 +365,32 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return {}
async def force_join_user_to_group(self, group_id, user_id):
"""Forces a user to join a group.
"""
if not self.is_mine_id(group_id):
raise SynapseError(400, "Can only affect local groups")
if not self.is_mine_id(user_id):
raise SynapseError(400, "Can only affect local users")
# Bypass the group server to avoid business logic regarding whether or not
# the user can actually join.
await self.store.add_user_to_group(group_id, user_id)
token = await self.store.register_user_group_membership(
group_id,
user_id,
membership="join",
is_admin=False,
local_attestation=None,
remote_attestation=None,
is_publicised=False,
)
self.notifier.on_new_event("groups_key", token, users=[user_id])
return {}
async def accept_invite(self, group_id, user_id, content):
"""Accept an invite to a group
"""

View File

@@ -174,7 +174,7 @@ class MessageHandler:
raise NotFoundError("Can't find event for token %s" % (at_token,))
visible_events = await filter_events_for_client(
self.storage, user_id, last_events, filter_send_to_client=False
self.storage, user_id, last_events, filter_send_to_client=False,
)
event = last_events[0]

View File

@@ -86,8 +86,8 @@ REQUIREMENTS = [
CONDITIONAL_REQUIREMENTS = {
"matrix-synapse-ldap3": ["matrix-synapse-ldap3>=0.1"],
# we use execute_batch, which arrived in psycopg 2.7.
"postgres": ["psycopg2>=2.7"],
# we use execute_values with the fetch param, which arrived in psycopg 2.8.
"postgres": ["psycopg2>=2.8"],
# ACME support is required to provision TLS certificates from authorities
# that use the protocol, such as Let's Encrypt.
"acme": [

View File

@@ -31,7 +31,11 @@ from synapse.rest.admin.event_reports import (
EventReportDetailRestServlet,
EventReportsRestServlet,
)
from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
from synapse.rest.admin.groups import (
DeleteGroupAdminRestServlet,
ForceJoinGroupAdminRestServlet,
UpdatePublicityGroupAdminRestServlet,
)
from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
from synapse.rest.admin.rooms import (
@@ -41,6 +45,7 @@ from synapse.rest.admin.rooms import (
MakeRoomAdminRestServlet,
RoomMembersRestServlet,
RoomRestServlet,
RoomStateRestServlet,
ShutdownRoomRestServlet,
)
from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
@@ -209,6 +214,7 @@ def register_servlets(hs, http_server):
"""
register_servlets_for_client_rest_resource(hs, http_server)
ListRoomRestServlet(hs).register(http_server)
RoomStateRestServlet(hs).register(http_server)
RoomRestServlet(hs).register(http_server)
RoomMembersRestServlet(hs).register(http_server)
DeleteRoomRestServlet(hs).register(http_server)
@@ -244,6 +250,8 @@ def register_servlets_for_client_rest_resource(hs, http_server):
ShutdownRoomRestServlet(hs).register(http_server)
UserRegisterServlet(hs).register(http_server)
DeleteGroupAdminRestServlet(hs).register(http_server)
ForceJoinGroupAdminRestServlet(hs).register(http_server)
UpdatePublicityGroupAdminRestServlet(hs).register(http_server)
AccountValidityRenewServlet(hs).register(http_server)
# Load the media repo ones if we're using them. Otherwise load the servlets which

View File

@@ -15,7 +15,11 @@
import logging
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
from synapse.rest.admin._base import admin_patterns, assert_user_is_admin
logger = logging.getLogger(__name__)
@@ -41,3 +45,57 @@ class DeleteGroupAdminRestServlet(RestServlet):
await self.group_server.delete_group(group_id, requester.user.to_string())
return 200, {}
class ForceJoinGroupAdminRestServlet(RestServlet):
"""Allows a server admin to force-join a local user to a local group.
"""
PATTERNS = admin_patterns("/group/(?P<group_id>[^/]*)/force_join$")
def __init__(self, hs):
self.groups_handler = hs.get_groups_local_handler()
self.is_mine_id = hs.is_mine_id
self.auth = hs.get_auth()
async def on_POST(self, request, group_id):
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
if not self.is_mine_id(group_id):
raise SynapseError(400, "Can only affect local groups")
body = parse_json_object_from_request(request, allow_empty_body=False)
assert_params_in_dict(body, ["user_id"])
target_user_id = body["user_id"]
await self.groups_handler.force_join_user_to_group(group_id, target_user_id)
return 200, {}
class UpdatePublicityGroupAdminRestServlet(RestServlet):
"""Allows a server admin to update a user's publicity (flair) for a given group.
"""
PATTERNS = admin_patterns("/group/(?P<group_id>[^/]*)/update_publicity$")
def __init__(self, hs):
self.store = hs.get_datastore()
self.is_mine_id = hs.is_mine_id
self.auth = hs.get_auth()
async def on_POST(self, request, group_id):
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
body = parse_json_object_from_request(request, allow_empty_body=False)
assert_params_in_dict(body, ["user_id"])
target_user_id = body["user_id"]
if not self.is_mine_id(target_user_id):
raise SynapseError(400, "Can only affect local users")
# Logic copied from `/self/update_publicity` endpoint.
publicise = body["publicise"]
await self.store.update_group_publicity(group_id, target_user_id, publicise)
return 200, {}

View File

@@ -292,6 +292,45 @@ class RoomMembersRestServlet(RestServlet):
return 200, ret
class RoomStateRestServlet(RestServlet):
"""
Get full state within a room.
"""
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/state")
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self._event_serializer = hs.get_event_client_serializer()
async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
ret = await self.store.get_room(room_id)
if not ret:
raise NotFoundError("Room not found")
event_ids = await self.store.get_current_state_ids(room_id)
events = await self.store.get_events(event_ids.values())
now = self.clock.time_msec()
room_state = await self._event_serializer.serialize_events(
events.values(),
now,
# We don't bother bundling aggregations in when asked for state
# events, as clients won't use them.
bundle_aggregations=False,
)
ret = {"state": room_state}
return 200, ret
class JoinRoomAliasServlet(RestServlet):
PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")

View File

@@ -83,32 +83,17 @@ class UsersRestServletV2(RestServlet):
The parameter `deactivated` can be used to include deactivated users.
"""
def __init__(self, hs: "HomeServer"):
def __init__(self, hs):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
async def on_GET(self, request):
await assert_requester_is_admin(self.auth, request)
start = parse_integer(request, "from", default=0)
limit = parse_integer(request, "limit", default=100)
if start < 0:
raise SynapseError(
400,
"Query parameter from must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
if limit < 0:
raise SynapseError(
400,
"Query parameter limit must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
user_id = parse_string(request, "user_id", default=None)
name = parse_string(request, "name", default=None)
guests = parse_boolean(request, "guests", default=True)
@@ -118,7 +103,7 @@ class UsersRestServletV2(RestServlet):
start, limit, user_id, name, guests, deactivated
)
ret = {"users": users, "total": total}
if (start + limit) < total:
if len(users) >= limit:
ret["next_token"] = str(start + len(users))
return 200, ret

View File

@@ -300,7 +300,6 @@ class FileInfo:
thumbnail_height (int)
thumbnail_method (str)
thumbnail_type (str): Content type of thumbnail, e.g. image/png
thumbnail_length (int): The size of the media file, in bytes.
"""
def __init__(
@@ -313,7 +312,6 @@ class FileInfo:
thumbnail_height=None,
thumbnail_method=None,
thumbnail_type=None,
thumbnail_length=None,
):
self.server_name = server_name
self.file_id = file_id
@@ -323,7 +321,6 @@ class FileInfo:
self.thumbnail_height = thumbnail_height
self.thumbnail_method = thumbnail_method
self.thumbnail_type = thumbnail_type
self.thumbnail_length = thumbnail_length
def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]:

View File

@@ -16,7 +16,7 @@
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING
from twisted.web.http import Request
@@ -106,17 +106,31 @@ class ThumbnailResource(DirectServeJsonResource):
return
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
await self._select_and_respond_with_thumbnail(
request,
width,
height,
method,
m_type,
thumbnail_infos,
media_id,
url_cache=media_info["url_cache"],
server_name=None,
)
if thumbnail_infos:
thumbnail_info = self._select_thumbnail(
width, height, method, m_type, thumbnail_infos
)
file_info = FileInfo(
server_name=None,
file_id=media_id,
url_cache=media_info["url_cache"],
thumbnail=True,
thumbnail_width=thumbnail_info["thumbnail_width"],
thumbnail_height=thumbnail_info["thumbnail_height"],
thumbnail_type=thumbnail_info["thumbnail_type"],
thumbnail_method=thumbnail_info["thumbnail_method"],
)
t_type = file_info.thumbnail_type
t_length = thumbnail_info["thumbnail_length"]
responder = await self.media_storage.fetch_media(file_info)
await respond_with_responder(request, responder, t_type, t_length)
else:
logger.info("Couldn't find any generated thumbnails")
respond_404(request)
async def _select_or_generate_local_thumbnail(
self,
@@ -262,64 +276,26 @@ class ThumbnailResource(DirectServeJsonResource):
thumbnail_infos = await self.store.get_remote_media_thumbnails(
server_name, media_id
)
await self._select_and_respond_with_thumbnail(
request,
width,
height,
method,
m_type,
thumbnail_infos,
media_info["filesystem_id"],
url_cache=None,
server_name=server_name,
)
async def _select_and_respond_with_thumbnail(
self,
request: Request,
desired_width: int,
desired_height: int,
desired_method: str,
desired_type: str,
thumbnail_infos: List[Dict[str, Any]],
file_id: str,
url_cache: Optional[str] = None,
server_name: Optional[str] = None,
) -> None:
"""
Respond to a request with an appropriate thumbnail from the previously generated thumbnails.
Args:
request: The incoming request.
desired_width: The desired width, the returned thumbnail may be larger than this.
desired_height: The desired height, the returned thumbnail may be larger than this.
desired_method: The desired method used to generate the thumbnail.
desired_type: The desired content-type of the thumbnail.
thumbnail_infos: A list of dictionaries of candidate thumbnails.
file_id: The ID of the media that a thumbnail is being requested for.
url_cache: The URL cache value.
server_name: The server name, if this is a remote thumbnail.
"""
if thumbnail_infos:
file_info = self._select_thumbnail(
desired_width,
desired_height,
desired_method,
desired_type,
thumbnail_infos,
file_id,
url_cache,
server_name,
thumbnail_info = self._select_thumbnail(
width, height, method, m_type, thumbnail_infos
)
if not file_info:
logger.info("Couldn't find a thumbnail matching the desired inputs")
respond_404(request)
return
file_info = FileInfo(
server_name=server_name,
file_id=media_info["filesystem_id"],
thumbnail=True,
thumbnail_width=thumbnail_info["thumbnail_width"],
thumbnail_height=thumbnail_info["thumbnail_height"],
thumbnail_type=thumbnail_info["thumbnail_type"],
thumbnail_method=thumbnail_info["thumbnail_method"],
)
t_type = file_info.thumbnail_type
t_length = thumbnail_info["thumbnail_length"]
responder = await self.media_storage.fetch_media(file_info)
await respond_with_responder(
request, responder, file_info.thumbnail_type, file_info.thumbnail_length
)
await respond_with_responder(request, responder, t_type, t_length)
else:
logger.info("Failed to find any generated thumbnails")
respond_404(request)
@@ -330,117 +306,67 @@ class ThumbnailResource(DirectServeJsonResource):
desired_height: int,
desired_method: str,
desired_type: str,
thumbnail_infos: List[Dict[str, Any]],
file_id: str,
url_cache: Optional[str],
server_name: Optional[str],
) -> Optional[FileInfo]:
"""
Choose an appropriate thumbnail from the previously generated thumbnails.
Args:
desired_width: The desired width, the returned thumbnail may be larger than this.
desired_height: The desired height, the returned thumbnail may be larger than this.
desired_method: The desired method used to generate the thumbnail.
desired_type: The desired content-type of the thumbnail.
thumbnail_infos: A list of dictionaries of candidate thumbnails.
file_id: The ID of the media that a thumbnail is being requested for.
url_cache: The URL cache value.
server_name: The server name, if this is a remote thumbnail.
Returns:
The thumbnail which best matches the desired parameters.
"""
desired_method = desired_method.lower()
# The chosen thumbnail.
thumbnail_info = None
thumbnail_infos,
) -> dict:
d_w = desired_width
d_h = desired_height
if desired_method == "crop":
# Thumbnails that match equal or larger sizes of desired width/height.
if desired_method.lower() == "crop":
crop_info_list = []
# Other thumbnails.
crop_info_list2 = []
for info in thumbnail_infos:
# Skip thumbnails generated with different methods.
if info["thumbnail_method"] != "crop":
continue
t_w = info["thumbnail_width"]
t_h = info["thumbnail_height"]
aspect_quality = abs(d_w * t_h - d_h * t_w)
min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
size_quality = abs((d_w - t_w) * (d_h - t_h))
type_quality = desired_type != info["thumbnail_type"]
length_quality = info["thumbnail_length"]
if t_w >= d_w or t_h >= d_h:
crop_info_list.append(
(
aspect_quality,
min_quality,
size_quality,
type_quality,
length_quality,
info,
t_method = info["thumbnail_method"]
if t_method == "crop":
aspect_quality = abs(d_w * t_h - d_h * t_w)
min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
size_quality = abs((d_w - t_w) * (d_h - t_h))
type_quality = desired_type != info["thumbnail_type"]
length_quality = info["thumbnail_length"]
if t_w >= d_w or t_h >= d_h:
crop_info_list.append(
(
aspect_quality,
min_quality,
size_quality,
type_quality,
length_quality,
info,
)
)
)
else:
crop_info_list2.append(
(
aspect_quality,
min_quality,
size_quality,
type_quality,
length_quality,
info,
else:
crop_info_list2.append(
(
aspect_quality,
min_quality,
size_quality,
type_quality,
length_quality,
info,
)
)
)
if crop_info_list:
thumbnail_info = min(crop_info_list)[-1]
elif crop_info_list2:
thumbnail_info = min(crop_info_list2)[-1]
elif desired_method == "scale":
# Thumbnails that match equal or larger sizes of desired width/height.
return min(crop_info_list)[-1]
else:
return min(crop_info_list2)[-1]
else:
info_list = []
# Other thumbnails.
info_list2 = []
for info in thumbnail_infos:
# Skip thumbnails generated with different methods.
if info["thumbnail_method"] != "scale":
continue
t_w = info["thumbnail_width"]
t_h = info["thumbnail_height"]
t_method = info["thumbnail_method"]
size_quality = abs((d_w - t_w) * (d_h - t_h))
type_quality = desired_type != info["thumbnail_type"]
length_quality = info["thumbnail_length"]
if t_w >= d_w or t_h >= d_h:
if t_method == "scale" and (t_w >= d_w or t_h >= d_h):
info_list.append((size_quality, type_quality, length_quality, info))
else:
elif t_method == "scale":
info_list2.append(
(size_quality, type_quality, length_quality, info)
)
if info_list:
thumbnail_info = min(info_list)[-1]
elif info_list2:
thumbnail_info = min(info_list2)[-1]
if thumbnail_info:
return FileInfo(
file_id=file_id,
url_cache=url_cache,
server_name=server_name,
thumbnail=True,
thumbnail_width=thumbnail_info["thumbnail_width"],
thumbnail_height=thumbnail_info["thumbnail_height"],
thumbnail_type=thumbnail_info["thumbnail_type"],
thumbnail_method=thumbnail_info["thumbnail_method"],
thumbnail_length=thumbnail_info["thumbnail_length"],
)
# No matching thumbnail was found.
return None
return min(info_list)[-1]
else:
return min(info_list2)[-1]

View File

@@ -262,18 +262,13 @@ class LoggingTransaction:
return self.txn.description
def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
"""Similar to `executemany`, except `txn.rowcount` will not be correct
afterwards.
More efficient than `executemany` on PostgreSQL
"""
if isinstance(self.database_engine, PostgresEngine):
from psycopg2.extras import execute_batch # type: ignore
self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
else:
self.executemany(sql, args)
for val in args:
self.execute(sql, val)
def execute_values(self, sql: str, *args: Any) -> List[Tuple]:
"""Corresponds to psycopg2.extras.execute_values. Only available when
@@ -893,7 +888,7 @@ class DatabasePool:
", ".join("?" for _ in keys[0]),
)
txn.execute_batch(sql, vals)
txn.executemany(sql, vals)
async def simple_upsert(
self,

View File

@@ -897,7 +897,7 @@ class DeviceWorkerStore(SQLBaseStore):
DELETE FROM device_lists_outbound_last_success
WHERE destination = ? AND user_id = ?
"""
txn.execute_batch(sql, ((row[0], row[1]) for row in rows))
txn.executemany(sql, ((row[0], row[1]) for row in rows))
logger.info("Pruned %d device list outbound pokes", count)
@@ -1343,7 +1343,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
# Delete older entries in the table, as we really only care about
# when the latest change happened.
txn.execute_batch(
txn.executemany(
"""
DELETE FROM device_lists_stream
WHERE user_id = ? AND device_id = ? AND stream_id < ?

View File

@@ -487,7 +487,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
VALUES (?, ?, ?, ?, ?, ?)
"""
txn.execute_batch(
txn.executemany(
sql,
(
_gen_entry(user_id, actions)
@@ -803,7 +803,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
],
)
txn.execute_batch(
txn.executemany(
"""
UPDATE event_push_summary
SET notif_count = ?, unread_count = ?, stream_ordering = ?

View File

@@ -473,9 +473,8 @@ class PersistEventsStore:
txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
)
@classmethod
@staticmethod
def _add_chain_cover_index(
cls,
txn,
db_pool: DatabasePool,
event_to_room_id: Dict[str, str],
@@ -615,17 +614,60 @@ class PersistEventsStore:
if not events_to_calc_chain_id_for:
return
# Allocate chain ID/sequence numbers to each new event.
new_chain_tuples = cls._allocate_chain_ids(
txn,
db_pool,
event_to_room_id,
event_to_types,
event_to_auth_chain,
events_to_calc_chain_id_for,
chain_map,
)
chain_map.update(new_chain_tuples)
# We now calculate the chain IDs/sequence numbers for the events. We
# do this by looking at the chain ID and sequence number of any auth
# event with the same type/state_key and incrementing the sequence
# number by one. If there was no match or the chain ID/sequence
# number is already taken we generate a new chain.
#
# We need to do this in a topologically sorted order as we want to
# generate chain IDs/sequence numbers of an event's auth events
# before the event itself.
chains_tuples_allocated = set() # type: Set[Tuple[int, int]]
new_chain_tuples = {} # type: Dict[str, Tuple[int, int]]
for event_id in sorted_topologically(
events_to_calc_chain_id_for, event_to_auth_chain
):
existing_chain_id = None
for auth_id in event_to_auth_chain.get(event_id, []):
if event_to_types.get(event_id) == event_to_types.get(auth_id):
existing_chain_id = chain_map[auth_id]
break
new_chain_tuple = None
if existing_chain_id:
# We found a chain ID/sequence number candidate, check its
# not already taken.
proposed_new_id = existing_chain_id[0]
proposed_new_seq = existing_chain_id[1] + 1
if (proposed_new_id, proposed_new_seq) not in chains_tuples_allocated:
already_allocated = db_pool.simple_select_one_onecol_txn(
txn,
table="event_auth_chains",
keyvalues={
"chain_id": proposed_new_id,
"sequence_number": proposed_new_seq,
},
retcol="event_id",
allow_none=True,
)
if already_allocated:
# Mark it as already allocated so we don't need to hit
# the DB again.
chains_tuples_allocated.add((proposed_new_id, proposed_new_seq))
else:
new_chain_tuple = (
proposed_new_id,
proposed_new_seq,
)
if not new_chain_tuple:
new_chain_tuple = (db_pool.event_chain_id_gen.get_next_id_txn(txn), 1)
chains_tuples_allocated.add(new_chain_tuple)
chain_map[event_id] = new_chain_tuple
new_chain_tuples[event_id] = new_chain_tuple
db_pool.simple_insert_many_txn(
txn,
@@ -752,137 +794,6 @@ class PersistEventsStore:
],
)
@staticmethod
def _allocate_chain_ids(
txn,
db_pool: DatabasePool,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
event_to_auth_chain: Dict[str, List[str]],
events_to_calc_chain_id_for: Set[str],
chain_map: Dict[str, Tuple[int, int]],
) -> Dict[str, Tuple[int, int]]:
"""Allocates, but does not persist, chain ID/sequence numbers for the
events in `events_to_calc_chain_id_for`. (c.f. _add_chain_cover_index
for info on args)
"""
# We now calculate the chain IDs/sequence numbers for the events. We do
# this by looking at the chain ID and sequence number of any auth event
# with the same type/state_key and incrementing the sequence number by
# one. If there was no match or the chain ID/sequence number is already
# taken we generate a new chain.
#
# We try to reduce the number of times that we hit the database by
# batching up calls, to make this more efficient when persisting large
# numbers of state events (e.g. during joins).
#
# We do this by:
# 1. Calculating for each event which auth event will be used to
# inherit the chain ID, i.e. converting the auth chain graph to a
# tree that we can allocate chains on. We also keep track of which
# existing chain IDs have been referenced.
# 2. Fetching the max allocated sequence number for each referenced
# existing chain ID, generating a map from chain ID to the max
# allocated sequence number.
# 3. Iterating over the tree and allocating a chain ID/seq no. to the
# new event, by incrementing the sequence number from the
# referenced event's chain ID/seq no. and checking that the
# incremented sequence number hasn't already been allocated (by
# looking in the map generated in the previous step). We generate a
# new chain if the sequence number has already been allocated.
#
existing_chains = set() # type: Set[int]
tree = [] # type: List[Tuple[str, Optional[str]]]
# We need to do this in a topologically sorted order as we want to
# generate chain IDs/sequence numbers of an event's auth events before
# the event itself.
for event_id in sorted_topologically(
events_to_calc_chain_id_for, event_to_auth_chain
):
for auth_id in event_to_auth_chain.get(event_id, []):
if event_to_types.get(event_id) == event_to_types.get(auth_id):
existing_chain_id = chain_map.get(auth_id)
if existing_chain_id:
existing_chains.add(existing_chain_id[0])
tree.append((event_id, auth_id))
break
else:
tree.append((event_id, None))
# Fetch the current max sequence number for each existing referenced chain.
sql = """
SELECT chain_id, MAX(sequence_number) FROM event_auth_chains
WHERE %s
GROUP BY chain_id
"""
clause, args = make_in_list_sql_clause(
db_pool.engine, "chain_id", existing_chains
)
txn.execute(sql % (clause,), args)
chain_to_max_seq_no = {row[0]: row[1] for row in txn} # type: Dict[Any, int]
# Allocate the new events chain ID/sequence numbers.
#
# To reduce the number of calls to the database we don't allocate a
# chain ID number in the loop, instead we use a temporary `object()` for
# each new chain ID. Once we've done the loop we generate the necessary
# number of new chain IDs in one call, replacing all temporary
# objects with real allocated chain IDs.
unallocated_chain_ids = set() # type: Set[object]
new_chain_tuples = {} # type: Dict[str, Tuple[Any, int]]
for event_id, auth_event_id in tree:
# If we reference an auth_event_id we fetch the allocated chain ID,
# either from the existing `chain_map` or the newly generated
# `new_chain_tuples` map.
existing_chain_id = None
if auth_event_id:
existing_chain_id = new_chain_tuples.get(auth_event_id)
if not existing_chain_id:
existing_chain_id = chain_map[auth_event_id]
new_chain_tuple = None # type: Optional[Tuple[Any, int]]
if existing_chain_id:
# We found a chain ID/sequence number candidate, check its
# not already taken.
proposed_new_id = existing_chain_id[0]
proposed_new_seq = existing_chain_id[1] + 1
if chain_to_max_seq_no[proposed_new_id] < proposed_new_seq:
new_chain_tuple = (
proposed_new_id,
proposed_new_seq,
)
# If we need to start a new chain we allocate a temporary chain ID.
if not new_chain_tuple:
new_chain_tuple = (object(), 1)
unallocated_chain_ids.add(new_chain_tuple[0])
new_chain_tuples[event_id] = new_chain_tuple
chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1]
# Generate new chain IDs for all unallocated chain IDs.
newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn(
txn, len(unallocated_chain_ids)
)
# Map from potentially temporary chain ID to real chain ID
chain_id_to_allocated_map = dict(
zip(unallocated_chain_ids, newly_allocated_chain_ids)
) # type: Dict[Any, int]
chain_id_to_allocated_map.update((c, c) for c in existing_chains)
return {
event_id: (chain_id_to_allocated_map[chain_id], seq)
for event_id, (chain_id, seq) in new_chain_tuples.items()
}
def _persist_transaction_ids_txn(
self,
txn: LoggingTransaction,
@@ -965,7 +876,7 @@ class PersistEventsStore:
WHERE room_id = ? AND type = ? AND state_key = ?
)
"""
txn.execute_batch(
txn.executemany(
sql,
(
(
@@ -984,7 +895,7 @@ class PersistEventsStore:
)
# Now we actually update the current_state_events table
txn.execute_batch(
txn.executemany(
"DELETE FROM current_state_events"
" WHERE room_id = ? AND type = ? AND state_key = ?",
(
@@ -996,7 +907,7 @@ class PersistEventsStore:
# We include the membership in the current state table, hence we do
# a lookup when we insert. This assumes that all events have already
# been inserted into room_memberships.
txn.execute_batch(
txn.executemany(
"""INSERT INTO current_state_events
(room_id, type, state_key, event_id, membership)
VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
@@ -1016,7 +927,7 @@ class PersistEventsStore:
# we have no record of the fact the user *was* a member of the
# room but got, say, state reset out of it.
if to_delete or to_insert:
txn.execute_batch(
txn.executemany(
"DELETE FROM local_current_membership"
" WHERE room_id = ? AND user_id = ?",
(
@@ -1027,7 +938,7 @@ class PersistEventsStore:
)
if to_insert:
txn.execute_batch(
txn.executemany(
"""INSERT INTO local_current_membership
(room_id, user_id, event_id, membership)
VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
@@ -1827,7 +1738,7 @@ class PersistEventsStore:
"""
if events_and_contexts:
txn.execute_batch(
txn.executemany(
sql,
(
(
@@ -1856,7 +1767,7 @@ class PersistEventsStore:
# Now we delete the staging area for *all* events that were being
# persisted.
txn.execute_batch(
txn.executemany(
"DELETE FROM event_push_actions_staging WHERE event_id = ?",
((event.event_id,) for event, _ in all_events_and_contexts),
)
@@ -1975,7 +1886,7 @@ class PersistEventsStore:
" )"
)
txn.execute_batch(
txn.executemany(
query,
[
(e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
@@ -1989,7 +1900,7 @@ class PersistEventsStore:
"DELETE FROM event_backward_extremities"
" WHERE event_id = ? AND room_id = ?"
)
txn.execute_batch(
txn.executemany(
query,
[
(ev.event_id, ev.room_id)

View File

@@ -139,6 +139,8 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
INSERT_CLUMP_SIZE = 1000
def reindex_txn(txn):
sql = (
"SELECT stream_ordering, event_id, json FROM events"
@@ -176,7 +178,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"
txn.execute_batch(sql, update_rows)
for index in range(0, len(update_rows), INSERT_CLUMP_SIZE):
clump = update_rows[index : index + INSERT_CLUMP_SIZE]
txn.executemany(sql, clump)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
@@ -206,6 +210,8 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
INSERT_CLUMP_SIZE = 1000
def reindex_search_txn(txn):
sql = (
"SELECT stream_ordering, event_id FROM events"
@@ -250,7 +256,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?"
txn.execute_batch(sql, rows_to_update)
for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE):
clump = rows_to_update[index : index + INSERT_CLUMP_SIZE]
txn.executemany(sql, clump)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,

View File

@@ -417,7 +417,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
" WHERE media_origin = ? AND media_id = ?"
)
txn.execute_batch(
txn.executemany(
sql,
(
(time_ms, media_origin, media_id)
@@ -430,7 +430,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
" WHERE media_id = ?"
)
txn.execute_batch(sql, ((time_ms, media_id) for media_id in local_media))
txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
return await self.db_pool.runInteraction(
"update_cached_last_access_time", update_cache_txn
@@ -557,7 +557,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?"
def _delete_url_cache_txn(txn):
txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
txn.executemany(sql, [(media_id,) for media_id in media_ids])
return await self.db_pool.runInteraction(
"delete_url_cache", _delete_url_cache_txn
@@ -586,11 +586,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def _delete_url_cache_media_txn(txn):
sql = "DELETE FROM local_media_repository WHERE media_id = ?"
txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
txn.executemany(sql, [(media_id,) for media_id in media_ids])
sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?"
txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
txn.executemany(sql, [(media_id,) for media_id in media_ids])
return await self.db_pool.runInteraction(
"delete_url_cache_media", _delete_url_cache_media_txn

View File

@@ -172,7 +172,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
)
# Update backward extremeties
txn.execute_batch(
txn.executemany(
"INSERT INTO event_backward_extremities (room_id, event_id)"
" VALUES (?, ?)",
[(room_id, event_id) for event_id, in new_backwards_extrems],

View File

@@ -1104,7 +1104,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
FROM user_threepids
"""
txn.execute_batch(sql, [(id_server,) for id_server in id_servers])
txn.executemany(sql, [(id_server,) for id_server in id_servers])
if id_servers:
await self.db_pool.runInteraction(

View File

@@ -873,6 +873,8 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
"max_stream_id_exclusive", self._stream_order_on_start + 1
)
INSERT_CLUMP_SIZE = 1000
def add_membership_profile_txn(txn):
sql = """
SELECT stream_ordering, event_id, events.room_id, event_json.json
@@ -913,7 +915,9 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
UPDATE room_memberships SET display_name = ?, avatar_url = ?
WHERE event_id = ? AND room_id = ?
"""
txn.execute_batch(to_update_sql, to_update)
for index in range(0, len(to_update), INSERT_CLUMP_SIZE):
clump = to_update[index : index + INSERT_CLUMP_SIZE]
txn.executemany(to_update_sql, clump)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,

View File

@@ -55,7 +55,7 @@ def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs
# { "ignored_users": "@someone:example.org": {} }
ignored_users = content.get("ignored_users", {})
if isinstance(ignored_users, dict) and ignored_users:
cur.execute_batch(insert_sql, [(user_id, u) for u in ignored_users])
cur.executemany(insert_sql, [(user_id, u) for u in ignored_users])
# Add indexes after inserting data for efficiency.
logger.info("Adding constraints to ignored_users table")

View File

@@ -63,7 +63,7 @@ class SearchWorkerStore(SQLBaseStore):
for entry in entries
)
txn.execute_batch(sql, args)
txn.executemany(sql, args)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
@@ -75,7 +75,7 @@ class SearchWorkerStore(SQLBaseStore):
for entry in entries
)
txn.execute_batch(sql, args)
txn.executemany(sql, args)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")

View File

@@ -565,11 +565,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
)
logger.info("[purge] removing redundant state groups")
txn.execute_batch(
txn.executemany(
"DELETE FROM state_groups_state WHERE state_group = ?",
((sg,) for sg in state_groups_to_delete),
)
txn.execute_batch(
txn.executemany(
"DELETE FROM state_groups WHERE id = ?",
((sg,) for sg in state_groups_to_delete),
)

View File

@@ -15,11 +15,12 @@
import heapq
import logging
import threading
from collections import OrderedDict
from collections import deque
from contextlib import contextmanager
from typing import Dict, List, Optional, Set, Tuple, Union
import attr
from typing_extensions import Deque
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.database import DatabasePool, LoggingTransaction
@@ -100,13 +101,7 @@ class StreamIdGenerator:
self._current = (max if step > 0 else min)(
self._current, _load_current_id(db_conn, table, column, step)
)
# We use this as an ordered set, as we want to efficiently append items,
# remove items and get the first item. Since we insert IDs in order, the
# insertion ordering will ensure its in the correct ordering.
#
# The key and values are the same, but we never look at the values.
self._unfinished_ids = OrderedDict() # type: OrderedDict[int, int]
self._unfinished_ids = deque() # type: Deque[int]
def get_next(self):
"""
@@ -118,7 +113,7 @@ class StreamIdGenerator:
self._current += self._step
next_id = self._current
self._unfinished_ids[next_id] = next_id
self._unfinished_ids.append(next_id)
@contextmanager
def manager():
@@ -126,7 +121,7 @@ class StreamIdGenerator:
yield next_id
finally:
with self._lock:
self._unfinished_ids.pop(next_id)
self._unfinished_ids.remove(next_id)
return _AsyncCtxManagerWrapper(manager())
@@ -145,7 +140,7 @@ class StreamIdGenerator:
self._current += n * self._step
for next_id in next_ids:
self._unfinished_ids[next_id] = next_id
self._unfinished_ids.append(next_id)
@contextmanager
def manager():
@@ -154,7 +149,7 @@ class StreamIdGenerator:
finally:
with self._lock:
for next_id in next_ids:
self._unfinished_ids.pop(next_id)
self._unfinished_ids.remove(next_id)
return _AsyncCtxManagerWrapper(manager())
@@ -167,7 +162,7 @@ class StreamIdGenerator:
"""
with self._lock:
if self._unfinished_ids:
return next(iter(self._unfinished_ids)) - self._step
return self._unfinished_ids[0] - self._step
return self._current

View File

@@ -69,11 +69,6 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
"""Gets the next ID in the sequence"""
...
@abc.abstractmethod
def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
"""Get the next `n` IDs in the sequence"""
...
@abc.abstractmethod
def check_consistency(
self,
@@ -224,17 +219,6 @@ class LocalSequenceGenerator(SequenceGenerator):
self._current_max_id += 1
return self._current_max_id
def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
with self._lock:
if self._current_max_id is None:
assert self._callback is not None
self._current_max_id = self._callback(txn)
self._callback = None
first_id = self._current_max_id + 1
self._current_max_id += n
return [first_id + i for i in range(n)]
def check_consistency(
self,
db_conn: Connection,

View File

@@ -78,7 +78,7 @@ def sorted_topologically(
if node not in degree_map:
continue
for edge in edges:
for edge in set(edges):
if edge in degree_map:
degree_map[node] += 1

View File

@@ -1180,6 +1180,21 @@ class RoomTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.json_body["total"], 3)
def test_room_state(self):
"""Test that room state can be requested correctly"""
# Create two test rooms
room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
url = "/_synapse/admin/v1/rooms/%s/state" % (room_id,)
channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("state", channel.json_body)
# testing that the state events match is painful and not done here. We assume that
# the create_room already does the right thing, so no need to verify that we got
# the state events it created.
class JoinAliasRoomTestCase(unittest.HomeserverTestCase):

View File

@@ -28,7 +28,6 @@ from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
from synapse.api.room_versions import RoomVersions
from synapse.rest.client.v1 import login, logout, profile, room
from synapse.rest.client.v2_alpha import devices, sync
from synapse.types import JsonDict
from tests import unittest
from tests.test_utils import make_awaitable
@@ -469,6 +468,13 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
self.user1 = self.register_user(
"user1", "pass1", admin=False, displayname="Name 1"
)
self.user2 = self.register_user(
"user2", "pass2", admin=False, displayname="Name 2"
)
def test_no_auth(self):
"""
Try to list users without authentication.
@@ -482,7 +488,6 @@ class UsersListTestCase(unittest.HomeserverTestCase):
"""
If the user is not a server admin, an error is returned.
"""
self._create_users(1)
other_user_token = self.login("user1", "pass1")
channel = self.make_request("GET", self.url, access_token=other_user_token)
@@ -494,8 +499,6 @@ class UsersListTestCase(unittest.HomeserverTestCase):
"""
List all users, including deactivated users.
"""
self._create_users(2)
channel = self.make_request(
"GET",
self.url + "?deactivated=true",
@@ -508,7 +511,14 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.assertEqual(3, channel.json_body["total"])
# Check that all fields are available
self._check_fields(channel.json_body["users"])
for u in channel.json_body["users"]:
self.assertIn("name", u)
self.assertIn("is_guest", u)
self.assertIn("admin", u)
self.assertIn("user_type", u)
self.assertIn("deactivated", u)
self.assertIn("displayname", u)
self.assertIn("avatar_url", u)
def test_search_term(self):
"""Test that searching for a users works correctly"""
@@ -539,7 +549,6 @@ class UsersListTestCase(unittest.HomeserverTestCase):
# Check that users were returned
self.assertTrue("users" in channel.json_body)
self._check_fields(channel.json_body["users"])
users = channel.json_body["users"]
# Check that the expected number of users were returned
@@ -552,30 +561,25 @@ class UsersListTestCase(unittest.HomeserverTestCase):
u = users[0]
self.assertEqual(expected_user_id, u["name"])
self._create_users(2)
user1 = "@user1:test"
user2 = "@user2:test"
# Perform search tests
_search_test(user1, "er1")
_search_test(user1, "me 1")
_search_test(self.user1, "er1")
_search_test(self.user1, "me 1")
_search_test(user2, "er2")
_search_test(user2, "me 2")
_search_test(self.user2, "er2")
_search_test(self.user2, "me 2")
_search_test(user1, "er1", "user_id")
_search_test(user2, "er2", "user_id")
_search_test(self.user1, "er1", "user_id")
_search_test(self.user2, "er2", "user_id")
# Test case insensitive
_search_test(user1, "ER1")
_search_test(user1, "NAME 1")
_search_test(self.user1, "ER1")
_search_test(self.user1, "NAME 1")
_search_test(user2, "ER2")
_search_test(user2, "NAME 2")
_search_test(self.user2, "ER2")
_search_test(self.user2, "NAME 2")
_search_test(user1, "ER1", "user_id")
_search_test(user2, "ER2", "user_id")
_search_test(self.user1, "ER1", "user_id")
_search_test(self.user2, "ER2", "user_id")
_search_test(None, "foo")
_search_test(None, "bar")
@@ -583,179 +587,6 @@ class UsersListTestCase(unittest.HomeserverTestCase):
_search_test(None, "foo", "user_id")
_search_test(None, "bar", "user_id")
def test_invalid_parameter(self):
"""
If parameters are invalid, an error is returned.
"""
# negative limit
channel = self.make_request(
"GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from
channel = self.make_request(
"GET", self.url + "?from=-5", access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid guests
channel = self.make_request(
"GET", self.url + "?guests=not_bool", access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# invalid deactivated
channel = self.make_request(
"GET", self.url + "?deactivated=not_bool", access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
def test_limit(self):
"""
Testing list of users with limit
"""
number_users = 20
# Create one less user (since there's already an admin user).
self._create_users(number_users - 1)
channel = self.make_request(
"GET", self.url + "?limit=5", access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 5)
self.assertEqual(channel.json_body["next_token"], "5")
self._check_fields(channel.json_body["users"])
def test_from(self):
"""
Testing list of users with a defined starting point (from)
"""
number_users = 20
# Create one less user (since there's already an admin user).
self._create_users(number_users - 1)
channel = self.make_request(
"GET", self.url + "?from=5", access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 15)
self.assertNotIn("next_token", channel.json_body)
self._check_fields(channel.json_body["users"])
def test_limit_and_from(self):
"""
Testing list of users with a defined starting point and limit
"""
number_users = 20
# Create one less user (since there's already an admin user).
self._create_users(number_users - 1)
channel = self.make_request(
"GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(channel.json_body["next_token"], "15")
self.assertEqual(len(channel.json_body["users"]), 10)
self._check_fields(channel.json_body["users"])
def test_next_token(self):
"""
Testing that `next_token` appears at the right place
"""
number_users = 20
# Create one less user (since there's already an admin user).
self._create_users(number_users - 1)
# `next_token` does not appear
# Number of results is the number of entries
channel = self.make_request(
"GET", self.url + "?limit=20", access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), number_users)
self.assertNotIn("next_token", channel.json_body)
# `next_token` does not appear
# Number of max results is larger than the number of entries
channel = self.make_request(
"GET", self.url + "?limit=21", access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), number_users)
self.assertNotIn("next_token", channel.json_body)
# `next_token` does appear
# Number of max results is smaller than the number of entries
channel = self.make_request(
"GET", self.url + "?limit=19", access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 19)
self.assertEqual(channel.json_body["next_token"], "19")
# Check
# Set `from` to value of `next_token` for request remaining entries
# `next_token` does not appear
channel = self.make_request(
"GET", self.url + "?from=19", access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 1)
self.assertNotIn("next_token", channel.json_body)
def _check_fields(self, content: JsonDict):
"""Checks that the expected user attributes are present in content
Args:
content: List that is checked for content
"""
for u in content:
self.assertIn("name", u)
self.assertIn("is_guest", u)
self.assertIn("admin", u)
self.assertIn("user_type", u)
self.assertIn("deactivated", u)
self.assertIn("displayname", u)
self.assertIn("avatar_url", u)
def _create_users(self, number_users: int):
"""
Create a number of users
Args:
number_users: Number of users to be created
"""
for i in range(1, number_users + 1):
self.register_user(
"user%d" % i, "pass%d" % i, admin=False, displayname="Name %d" % i,
)
class DeactivateAccountTestCase(unittest.HomeserverTestCase):

View File

@@ -202,6 +202,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
config = self.default_config()
config["media_store_path"] = self.media_store_path
config["thumbnail_requirements"] = {}
config["max_image_pixels"] = 2000000
provider_config = {
@@ -312,39 +313,15 @@ class MediaRepoTests(unittest.HomeserverTestCase):
self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
def test_thumbnail_crop(self):
"""Test that a cropped remote thumbnail is available."""
self._test_thumbnail(
"crop", self.test_image.expected_cropped, self.test_image.expected_found
)
def test_thumbnail_scale(self):
"""Test that a scaled remote thumbnail is available."""
self._test_thumbnail(
"scale", self.test_image.expected_scaled, self.test_image.expected_found
)
def test_invalid_type(self):
"""An invalid thumbnail type is never available."""
self._test_thumbnail("invalid", None, False)
@unittest.override_config(
{"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]}
)
def test_no_thumbnail_crop(self):
"""
Override the config to generate only scaled thumbnails, but request a cropped one.
"""
self._test_thumbnail("crop", None, False)
@unittest.override_config(
{"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]}
)
def test_no_thumbnail_scale(self):
"""
Override the config to generate only cropped thumbnails, but request a scaled one.
"""
self._test_thumbnail("scale", None, False)
def _test_thumbnail(self, method, expected_body, expected_found):
params = "?width=32&height=32&method=" + method
channel = make_request(

View File

@@ -92,3 +92,15 @@ class SortTopologically(TestCase):
# Valid orderings are `[1, 3, 2, 4]` or `[1, 2, 3, 4]`, but we should
# always get the same one.
self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
def test_duplicates(self):
"Test that a graph with duplicate edges work"
graph = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]} # type: Dict[int, List[int]]
self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
def test_multiple_paths(self):
"Test that a graph with multiple paths between two nodes work"
graph = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]} # type: Dict[int, List[int]]
self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])

11
tox.ini
View File

@@ -24,8 +24,7 @@ deps =
# install the "enum34" dependency of cryptography.
pip>=10
# directories/files we run the linters on.
# if you update this list, make sure to do the same in scripts-dev/lint.sh
# directories/files we run the linters on
lint_targets =
setup.py
synapse
@@ -101,7 +100,7 @@ usedevelop=true
# A test suite for the oldest supported versions of Python libraries, to catch
# any uses of APIs not available in them.
[testenv:py35-{old,old-postgres}]
[testenv:py35-old]
skip_install=True
deps =
# Ensure a version of setuptools that supports Python 3.5 is installed.
@@ -114,17 +113,11 @@ deps =
coverage
coverage-enable-subprocess==1.0
setenv =
postgres: SYNAPSE_POSTGRES = 1
commands =
# Make all greater-thans equals so we test the oldest version of our direct
# dependencies, but make the pyopenssl 17.0, which can work against an
# OpenSSL 1.1 compiled cryptography (as older ones don't compile on Travis).
#
# Also strip out psycopg2 unless we need it.
/bin/sh -c 'python -m synapse.python_dependencies | sed -e "s/>=/==/g" -e "/psycopg2/d" -e "s/pyopenssl==16.0.0/pyopenssl==17.0.0/" | xargs -d"\n" pip install'
postgres: /bin/sh -c 'python -m synapse.python_dependencies | sed -e "s/>=/==/g" | grep -F "psycopg2" | xargs -d"\n" pip install'
# Install Synapse itself. This won't update any libraries.
pip install -e ".[test]"