Mirror of https://github.com/element-hq/synapse.git (synced 2025-12-11 01:40:27 +00:00)

Compare commits: erikj/test ... travis/fos (14 commits)

Commits (SHA1):
a06dd1d6b5
69961c7e9f
a01605c136
056327457f
fc2cbce232
28f255d5f3
f7a03e86e0
d9867f1640
7d8cc63e37
19a4821ffc
40f96320a2
e2377bba70
84204f8020
95d7074322

@@ -9,3 +9,5 @@ apt-get update
apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev xmlsec1 zlib1g-dev tox

export LANG="C.UTF-8"

exec tox -e py35-old,combine

17	CHANGES.md
@@ -1,3 +1,20 @@
Synapse 1.26.0rc2 (2021-01-25)
==============================

Bugfixes
--------

- Fix receipts and account data not being sent down sync. Introduced in v1.26.0rc1. ([\#9193](https://github.com/matrix-org/synapse/issues/9193), [\#9195](https://github.com/matrix-org/synapse/issues/9195))
- Fix chain cover update to handle events with duplicate auth events. Introduced in v1.26.0rc1. ([\#9210](https://github.com/matrix-org/synapse/issues/9210))


Internal Changes
----------------

- Add an `oidc-` prefix to any `idp_id`s which are given in the `oidc_providers` configuration. ([\#9189](https://github.com/matrix-org/synapse/issues/9189))
- Bump minimum `psycopg2` version to v2.8. ([\#9204](https://github.com/matrix-org/synapse/issues/9204))


Synapse 1.26.0rc1 (2021-01-20)
==============================

@@ -1 +0,0 @@
Add tests to `test_user.UsersListTestCase` for List Users Admin API.

@@ -1 +0,0 @@
Various improvements to the federation client.

@@ -1 +0,0 @@
Add link to Matrix VoIP tester for turn-howto.

@@ -1 +0,0 @@
Fix a long-standing bug where Synapse would return a 500 error when a thumbnail did not exist (and auto-generation of thumbnails was not enabled).

1	changelog.d/9167.feature	Normal file
@@ -0,0 +1 @@
Add server admin endpoints to join users to legacy groups and manage their flair.

1	changelog.d/9168.feature	Normal file
@@ -0,0 +1 @@
Add an admin API for retrieving the current room state of a room.

@@ -1 +0,0 @@
Speed up chain cover calculation when persisting a batch of state events at once.

@@ -1 +0,0 @@
Add a `long_description_type` to the package metadata.

@@ -1 +0,0 @@
Speed up batch insertion when using PostgreSQL.

@@ -1 +0,0 @@
Emit an error at startup if different Identity Providers are configured with the same `idp_id`.

@@ -1 +0,0 @@
Speed up batch insertion when using PostgreSQL.

@@ -1 +0,0 @@
Add an `oidc-` prefix to any `idp_id`s which are given in the `oidc_providers` configuration.

@@ -1 +0,0 @@
Improve performance of concurrent use of `StreamIDGenerators`.

@@ -1 +0,0 @@
Add some missing source directories to the automatic linting script.

@@ -1 +0,0 @@
Fix receipts or account data not being sent down sync. Introduced in v1.26.0rc1.

@@ -1 +0,0 @@
Fix receipts or account data not being sent down sync. Introduced in v1.26.0rc1.

@@ -367,6 +367,36 @@ Response:
}
```

# Room State API

The Room State admin API allows server admins to get a list of all state events in a room.

The response includes the following fields:

* `state` - The current state of the room at the time of request.

## Usage

A standard request:

```
GET /_synapse/admin/v1/rooms/<room_id>/state

{}
```

Response:

```json
{
  "state": [
    {"type": "m.room.create", "state_key": "", "etc": true},
    {"type": "m.room.power_levels", "state_key": "", "etc": true},
    {"type": "m.room.name", "state_key": "", "etc": true}
  ]
}
```
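
As a concrete illustration, the same request can be made from a script. The following is a minimal sketch only, assuming the `requests` library and placeholder values for the homeserver URL, admin access token and room ID:

```python
import requests

# Placeholder values for illustration only.
HOMESERVER = "https://synapse.example.com"
ADMIN_TOKEN = "<admin_access_token>"
ROOM_ID = "!someroom:example.com"

resp = requests.get(
    f"{HOMESERVER}/_synapse/admin/v1/rooms/{ROOM_ID}/state",
    headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
)
resp.raise_for_status()

# Each entry in "state" is a serialized state event, as in the response above.
for event in resp.json()["state"]:
    print(event["type"], event.get("state_key", ""))
```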

# Delete Room API

The Delete Room admin API allows server admins to remove rooms from server

@@ -232,12 +232,6 @@ Here are a few things to try:

(Understanding the output is beyond the scope of this document!)

* You can test your Matrix homeserver TURN setup with https://test.voip.librepush.net/.
  Note that this test is not fully reliable yet, so don't be discouraged if
  the test fails.
  [Here](https://github.com/matrix-org/voip-tester) is the github repo of the
  source of the tester, where you can file bug reports.

* There is a WebRTC test tool at
  https://webrtc.github.io/samples/src/content/peerconnection/trickle-ice/. To
  use it, you will need a username/password for your TURN server. You can

@@ -80,8 +80,7 @@ else
  # then lint everything!
  if [[ -z ${files+x} ]]; then
    # Lint all source code files and directories
    # Note: this list aims the mirror the one in tox.ini
    files=("synapse" "docker" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py" "synmark" "stubs" ".buildkite")
    files=("synapse" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py" "synmark")
  fi
fi

1	setup.py
@@ -121,7 +121,6 @@ setup(
    include_package_data=True,
    zip_safe=False,
    long_description=long_description,
    long_description_content_type="text/x-rst",
    python_requires="~=3.5",
    classifiers=[
        "Development Status :: 5 - Production/Stable",

@@ -48,7 +48,7 @@ try:
except ImportError:
    pass

__version__ = "1.26.0rc1"
__version__ = "1.26.0rc2"

if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
    # We import here so that we don't have to install a bunch of deps when

@@ -15,7 +15,6 @@
# limitations under the License.

import string
from collections import Counter
from typing import Iterable, Optional, Tuple, Type

import attr
@@ -44,16 +43,6 @@ class OIDCConfig(Config):
        except DependencyException as e:
            raise ConfigError(e.message) from e

        # check we don't have any duplicate idp_ids now. (The SSO handler will also
        # check for duplicates when the REST listeners get registered, but that happens
        # after synapse has forked so doesn't give nice errors.)
        c = Counter([i.idp_id for i in self.oidc_providers])
        for idp_id, count in c.items():
            if count > 1:
                raise ConfigError(
                    "Multiple OIDC providers have the idp_id %r." % idp_id
                )
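
To make the duplicate check above concrete, here is a tiny standalone illustration of how `Counter` flags a repeated `idp_id` (toy values, not taken from a real configuration):

```python
from collections import Counter

# Toy idp_ids; "oidc-github" appears twice, mimicking a misconfiguration.
idp_ids = ["oidc-github", "oidc-google", "oidc-github"]

duplicates = [idp_id for idp_id, count in Counter(idp_ids).items() if count > 1]
assert set(duplicates) == {"oidc-github"}  # this is the condition that raises ConfigError
```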

        public_baseurl = self.public_baseurl
        self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback"

@@ -18,7 +18,6 @@ import copy
|
||||
import itertools
|
||||
import logging
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
@@ -27,6 +26,7 @@ from typing import (
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
@@ -61,9 +61,6 @@ from synapse.util import unwrapFirstError
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
from synapse.util.retryutils import NotRetryingDestination
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["type"])
|
||||
@@ -83,10 +80,10 @@ class InvalidResponseError(RuntimeError):
|
||||
|
||||
|
||||
class FederationClient(FederationBase):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
def __init__(self, hs):
|
||||
super().__init__(hs)
|
||||
|
||||
self.pdu_destination_tried = {} # type: Dict[str, Dict[str, int]]
|
||||
self.pdu_destination_tried = {}
|
||||
self._clock.looping_call(self._clear_tried_cache, 60 * 1000)
|
||||
self.state = hs.get_state_handler()
|
||||
self.transport_layer = hs.get_federation_transport_client()
|
||||
@@ -119,32 +116,33 @@ class FederationClient(FederationBase):
|
||||
self.pdu_destination_tried[event_id] = destination_dict
|
||||
|
||||
@log_function
|
||||
async def make_query(
|
||||
def make_query(
|
||||
self,
|
||||
destination: str,
|
||||
query_type: str,
|
||||
args: dict,
|
||||
retry_on_dns_fail: bool = False,
|
||||
ignore_backoff: bool = False,
|
||||
) -> JsonDict:
|
||||
destination,
|
||||
query_type,
|
||||
args,
|
||||
retry_on_dns_fail=False,
|
||||
ignore_backoff=False,
|
||||
):
|
||||
"""Sends a federation Query to a remote homeserver of the given type
|
||||
and arguments.
|
||||
|
||||
Args:
|
||||
destination: Domain name of the remote homeserver
|
||||
query_type: Category of the query type; should match the
|
||||
destination (str): Domain name of the remote homeserver
|
||||
query_type (str): Category of the query type; should match the
|
||||
handler name used in register_query_handler().
|
||||
args: Mapping of strings to strings containing the details
|
||||
args (dict): Mapping of strings to strings containing the details
|
||||
of the query request.
|
||||
ignore_backoff: true to ignore the historical backoff data
|
||||
ignore_backoff (bool): true to ignore the historical backoff data
|
||||
and try the request anyway.
|
||||
|
||||
Returns:
|
||||
The JSON object from the response
|
||||
a Awaitable which will eventually yield a JSON object from the
|
||||
response
|
||||
"""
|
||||
sent_queries_counter.labels(query_type).inc()
|
||||
|
||||
return await self.transport_layer.make_query(
|
||||
return self.transport_layer.make_query(
|
||||
destination,
|
||||
query_type,
|
||||
args,
|
||||
@@ -153,52 +151,42 @@ class FederationClient(FederationBase):
|
||||
)
|
||||
|
||||
@log_function
|
||||
async def query_client_keys(
|
||||
self, destination: str, content: JsonDict, timeout: int
|
||||
) -> JsonDict:
|
||||
def query_client_keys(self, destination, content, timeout):
|
||||
"""Query device keys for a device hosted on a remote server.
|
||||
|
||||
Args:
|
||||
destination: Domain name of the remote homeserver
|
||||
content: The query content.
|
||||
destination (str): Domain name of the remote homeserver
|
||||
content (dict): The query content.
|
||||
|
||||
Returns:
|
||||
The JSON object from the response
|
||||
an Awaitable which will eventually yield a JSON object from the
|
||||
response
|
||||
"""
|
||||
sent_queries_counter.labels("client_device_keys").inc()
|
||||
return await self.transport_layer.query_client_keys(
|
||||
destination, content, timeout
|
||||
)
|
||||
return self.transport_layer.query_client_keys(destination, content, timeout)
|
||||
|
||||
@log_function
|
||||
async def query_user_devices(
|
||||
self, destination: str, user_id: str, timeout: int = 30000
|
||||
) -> JsonDict:
|
||||
def query_user_devices(self, destination, user_id, timeout=30000):
|
||||
"""Query the device keys for a list of user ids hosted on a remote
|
||||
server.
|
||||
"""
|
||||
sent_queries_counter.labels("user_devices").inc()
|
||||
return await self.transport_layer.query_user_devices(
|
||||
destination, user_id, timeout
|
||||
)
|
||||
return self.transport_layer.query_user_devices(destination, user_id, timeout)
|
||||
|
||||
@log_function
|
||||
async def claim_client_keys(
|
||||
self, destination: str, content: JsonDict, timeout: int
|
||||
) -> JsonDict:
|
||||
def claim_client_keys(self, destination, content, timeout):
|
||||
"""Claims one-time keys for a device hosted on a remote server.
|
||||
|
||||
Args:
|
||||
destination: Domain name of the remote homeserver
|
||||
content: The query content.
|
||||
destination (str): Domain name of the remote homeserver
|
||||
content (dict): The query content.
|
||||
|
||||
Returns:
|
||||
The JSON object from the response
|
||||
an Awaitable which will eventually yield a JSON object from the
|
||||
response
|
||||
"""
|
||||
sent_queries_counter.labels("client_one_time_keys").inc()
|
||||
return await self.transport_layer.claim_client_keys(
|
||||
destination, content, timeout
|
||||
)
|
||||
return self.transport_layer.claim_client_keys(destination, content, timeout)
|
||||
|
||||
async def backfill(
|
||||
self, dest: str, room_id: str, limit: int, extremities: Iterable[str]
|
||||
@@ -207,10 +195,10 @@ class FederationClient(FederationBase):
|
||||
given destination server.
|
||||
|
||||
Args:
|
||||
dest: The remote homeserver to ask.
|
||||
room_id: The room_id to backfill.
|
||||
limit: The maximum number of events to return.
|
||||
extremities: our current backwards extremities, to backfill from
|
||||
dest (str): The remote homeserver to ask.
|
||||
room_id (str): The room_id to backfill.
|
||||
limit (int): The maximum number of events to return.
|
||||
extremities (list): our current backwards extremities, to backfill from
|
||||
"""
|
||||
logger.debug("backfill extrem=%s", extremities)
|
||||
|
||||
@@ -382,7 +370,7 @@ class FederationClient(FederationBase):
|
||||
for events that have failed their checks
|
||||
|
||||
Returns:
|
||||
A list of PDUs that have valid signatures and hashes.
|
||||
Deferred : A list of PDUs that have valid signatures and hashes.
|
||||
"""
|
||||
deferreds = self._check_sigs_and_hashes(room_version, pdus)
|
||||
|
||||
@@ -430,9 +418,7 @@ class FederationClient(FederationBase):
|
||||
else:
|
||||
return [p for p in valid_pdus if p]
|
||||
|
||||
async def get_event_auth(
|
||||
self, destination: str, room_id: str, event_id: str
|
||||
) -> List[EventBase]:
|
||||
async def get_event_auth(self, destination, room_id, event_id):
|
||||
res = await self.transport_layer.get_event_auth(destination, room_id, event_id)
|
||||
|
||||
room_version = await self.store.get_room_version(room_id)
|
||||
@@ -714,16 +700,18 @@ class FederationClient(FederationBase):
|
||||
|
||||
return await self._try_destination_list("send_join", destinations, send_request)
|
||||
|
||||
async def _do_send_join(self, destination: str, pdu: EventBase) -> JsonDict:
|
||||
async def _do_send_join(self, destination: str, pdu: EventBase):
|
||||
time_now = self._clock.time_msec()
|
||||
|
||||
try:
|
||||
return await self.transport_layer.send_join_v2(
|
||||
content = await self.transport_layer.send_join_v2(
|
||||
destination=destination,
|
||||
room_id=pdu.room_id,
|
||||
event_id=pdu.event_id,
|
||||
content=pdu.get_pdu_json(time_now),
|
||||
)
|
||||
|
||||
return content
|
||||
except HttpResponseException as e:
|
||||
if e.code in [400, 404]:
|
||||
err = e.to_synapse_error()
|
||||
@@ -781,7 +769,7 @@ class FederationClient(FederationBase):
|
||||
time_now = self._clock.time_msec()
|
||||
|
||||
try:
|
||||
return await self.transport_layer.send_invite_v2(
|
||||
content = await self.transport_layer.send_invite_v2(
|
||||
destination=destination,
|
||||
room_id=pdu.room_id,
|
||||
event_id=pdu.event_id,
|
||||
@@ -791,6 +779,7 @@ class FederationClient(FederationBase):
|
||||
"invite_room_state": pdu.unsigned.get("invite_room_state", []),
|
||||
},
|
||||
)
|
||||
return content
|
||||
except HttpResponseException as e:
|
||||
if e.code in [400, 404]:
|
||||
err = e.to_synapse_error()
|
||||
@@ -853,16 +842,18 @@ class FederationClient(FederationBase):
|
||||
"send_leave", destinations, send_request
|
||||
)
|
||||
|
||||
async def _do_send_leave(self, destination: str, pdu: EventBase) -> JsonDict:
|
||||
async def _do_send_leave(self, destination, pdu):
|
||||
time_now = self._clock.time_msec()
|
||||
|
||||
try:
|
||||
return await self.transport_layer.send_leave_v2(
|
||||
content = await self.transport_layer.send_leave_v2(
|
||||
destination=destination,
|
||||
room_id=pdu.room_id,
|
||||
event_id=pdu.event_id,
|
||||
content=pdu.get_pdu_json(time_now),
|
||||
)
|
||||
|
||||
return content
|
||||
except HttpResponseException as e:
|
||||
if e.code in [400, 404]:
|
||||
err = e.to_synapse_error()
|
||||
@@ -888,7 +879,7 @@ class FederationClient(FederationBase):
|
||||
# content.
|
||||
return resp[1]
|
||||
|
||||
async def get_public_rooms(
|
||||
def get_public_rooms(
|
||||
self,
|
||||
remote_server: str,
|
||||
limit: Optional[int] = None,
|
||||
@@ -896,7 +887,7 @@ class FederationClient(FederationBase):
|
||||
search_filter: Optional[Dict] = None,
|
||||
include_all_networks: bool = False,
|
||||
third_party_instance_id: Optional[str] = None,
|
||||
) -> JsonDict:
|
||||
):
|
||||
"""Get the list of public rooms from a remote homeserver
|
||||
|
||||
Args:
|
||||
@@ -910,7 +901,8 @@ class FederationClient(FederationBase):
|
||||
party instance
|
||||
|
||||
Returns:
|
||||
The response from the remote server.
|
||||
Awaitable[Dict[str, Any]]: The response from the remote server, or None if
|
||||
`remote_server` is the same as the local server_name
|
||||
|
||||
Raises:
|
||||
HttpResponseException: There was an exception returned from the remote server
|
||||
@@ -918,7 +910,7 @@ class FederationClient(FederationBase):
|
||||
requests over federation
|
||||
|
||||
"""
|
||||
return await self.transport_layer.get_public_rooms(
|
||||
return self.transport_layer.get_public_rooms(
|
||||
remote_server,
|
||||
limit,
|
||||
since_token,
|
||||
@@ -931,7 +923,7 @@ class FederationClient(FederationBase):
|
||||
self,
|
||||
destination: str,
|
||||
room_id: str,
|
||||
earliest_events_ids: Iterable[str],
|
||||
earliest_events_ids: Sequence[str],
|
||||
latest_events: Iterable[EventBase],
|
||||
limit: int,
|
||||
min_depth: int,
|
||||
@@ -982,9 +974,7 @@ class FederationClient(FederationBase):
|
||||
|
||||
return signed_events
|
||||
|
||||
async def forward_third_party_invite(
|
||||
self, destinations: Iterable[str], room_id: str, event_dict: JsonDict
|
||||
) -> None:
|
||||
async def forward_third_party_invite(self, destinations, room_id, event_dict):
|
||||
for destination in destinations:
|
||||
if destination == self.server_name:
|
||||
continue
|
||||
@@ -993,7 +983,7 @@ class FederationClient(FederationBase):
|
||||
await self.transport_layer.exchange_third_party_invite(
|
||||
destination=destination, room_id=room_id, event_dict=event_dict
|
||||
)
|
||||
return
|
||||
return None
|
||||
except CodeMessageException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -1005,7 +995,7 @@ class FederationClient(FederationBase):
|
||||
|
||||
async def get_room_complexity(
|
||||
self, destination: str, room_id: str
|
||||
) -> Optional[JsonDict]:
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Fetch the complexity of a remote room from another server.
|
||||
|
||||
@@ -1018,9 +1008,10 @@ class FederationClient(FederationBase):
|
||||
could not fetch the complexity.
|
||||
"""
|
||||
try:
|
||||
return await self.transport_layer.get_room_complexity(
|
||||
complexity = await self.transport_layer.get_room_complexity(
|
||||
destination=destination, room_id=room_id
|
||||
)
|
||||
return complexity
|
||||
except CodeMessageException as e:
|
||||
# We didn't manage to get it -- probably a 404. We are okay if other
|
||||
# servers don't give it to us.
|
||||
|
||||
@@ -365,6 +365,32 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
||||
|
||||
return {}
|
||||
|
||||
async def force_join_user_to_group(self, group_id, user_id):
|
||||
"""Forces a user to join a group.
|
||||
"""
|
||||
if not self.is_mine_id(group_id):
|
||||
raise SynapseError(400, "Can only affect local groups")
|
||||
|
||||
if not self.is_mine_id(user_id):
|
||||
raise SynapseError(400, "Can only affect local users")
|
||||
|
||||
# Bypass the group server to avoid business logic regarding whether or not
|
||||
# the user can actually join.
|
||||
await self.store.add_user_to_group(group_id, user_id)
|
||||
|
||||
token = await self.store.register_user_group_membership(
|
||||
group_id,
|
||||
user_id,
|
||||
membership="join",
|
||||
is_admin=False,
|
||||
local_attestation=None,
|
||||
remote_attestation=None,
|
||||
is_publicised=False,
|
||||
)
|
||||
self.notifier.on_new_event("groups_key", token, users=[user_id])
|
||||
|
||||
return {}
|
||||
|
||||
async def accept_invite(self, group_id, user_id, content):
|
||||
"""Accept an invite to a group
|
||||
"""
|
||||
|
||||
@@ -174,7 +174,7 @@ class MessageHandler:
            raise NotFoundError("Can't find event for token %s" % (at_token,))

        visible_events = await filter_events_for_client(
            self.storage, user_id, last_events, filter_send_to_client=False
            self.storage, user_id, last_events, filter_send_to_client=False,
        )

        event = last_events[0]

@@ -86,8 +86,8 @@ REQUIREMENTS = [

CONDITIONAL_REQUIREMENTS = {
    "matrix-synapse-ldap3": ["matrix-synapse-ldap3>=0.1"],
    # we use execute_batch, which arrived in psycopg 2.7.
    "postgres": ["psycopg2>=2.7"],
    # we use execute_values with the fetch param, which arrived in psycopg 2.8.
    "postgres": ["psycopg2>=2.8"],
    # ACME support is required to provision TLS certificates from authorities
    # that use the protocol, such as Let's Encrypt.
    "acme": [

@@ -31,7 +31,11 @@ from synapse.rest.admin.event_reports import (
    EventReportDetailRestServlet,
    EventReportsRestServlet,
)
from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
from synapse.rest.admin.groups import (
    DeleteGroupAdminRestServlet,
    ForceJoinGroupAdminRestServlet,
    UpdatePublicityGroupAdminRestServlet,
)
from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
from synapse.rest.admin.rooms import (
@@ -41,6 +45,7 @@ from synapse.rest.admin.rooms import (
    MakeRoomAdminRestServlet,
    RoomMembersRestServlet,
    RoomRestServlet,
    RoomStateRestServlet,
    ShutdownRoomRestServlet,
)
from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
@@ -209,6 +214,7 @@ def register_servlets(hs, http_server):
    """
    register_servlets_for_client_rest_resource(hs, http_server)
    ListRoomRestServlet(hs).register(http_server)
    RoomStateRestServlet(hs).register(http_server)
    RoomRestServlet(hs).register(http_server)
    RoomMembersRestServlet(hs).register(http_server)
    DeleteRoomRestServlet(hs).register(http_server)
@@ -244,6 +250,8 @@ def register_servlets_for_client_rest_resource(hs, http_server):
    ShutdownRoomRestServlet(hs).register(http_server)
    UserRegisterServlet(hs).register(http_server)
    DeleteGroupAdminRestServlet(hs).register(http_server)
    ForceJoinGroupAdminRestServlet(hs).register(http_server)
    UpdatePublicityGroupAdminRestServlet(hs).register(http_server)
    AccountValidityRenewServlet(hs).register(http_server)

    # Load the media repo ones if we're using them. Otherwise load the servlets which

@@ -15,7 +15,11 @@
import logging

from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet
from synapse.http.servlet import (
    RestServlet,
    assert_params_in_dict,
    parse_json_object_from_request,
)
from synapse.rest.admin._base import admin_patterns, assert_user_is_admin

logger = logging.getLogger(__name__)
@@ -41,3 +45,57 @@ class DeleteGroupAdminRestServlet(RestServlet):

        await self.group_server.delete_group(group_id, requester.user.to_string())
        return 200, {}


class ForceJoinGroupAdminRestServlet(RestServlet):
    """Allows a server admin to force-join a local user to a local group.
    """

    PATTERNS = admin_patterns("/group/(?P<group_id>[^/]*)/force_join$")

    def __init__(self, hs):
        self.groups_handler = hs.get_groups_local_handler()
        self.is_mine_id = hs.is_mine_id
        self.auth = hs.get_auth()

    async def on_POST(self, request, group_id):
        requester = await self.auth.get_user_by_req(request)
        await assert_user_is_admin(self.auth, requester.user)

        if not self.is_mine_id(group_id):
            raise SynapseError(400, "Can only affect local groups")

        body = parse_json_object_from_request(request, allow_empty_body=False)
        assert_params_in_dict(body, ["user_id"])
        target_user_id = body["user_id"]
        await self.groups_handler.force_join_user_to_group(group_id, target_user_id)

        return 200, {}


class UpdatePublicityGroupAdminRestServlet(RestServlet):
    """Allows a server admin to update a user's publicity (flair) for a given group.
    """

    PATTERNS = admin_patterns("/group/(?P<group_id>[^/]*)/update_publicity$")

    def __init__(self, hs):
        self.store = hs.get_datastore()
        self.is_mine_id = hs.is_mine_id
        self.auth = hs.get_auth()

    async def on_POST(self, request, group_id):
        requester = await self.auth.get_user_by_req(request)
        await assert_user_is_admin(self.auth, requester.user)

        body = parse_json_object_from_request(request, allow_empty_body=False)
        assert_params_in_dict(body, ["user_id"])
        target_user_id = body["user_id"]
        if not self.is_mine_id(target_user_id):
            raise SynapseError(400, "Can only affect local users")

        # Logic copied from `/self/update_publicity` endpoint.
        publicise = body["publicise"]
        await self.store.update_group_publicity(group_id, target_user_id, publicise)

        return 200, {}
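
For reference, the two new servlets above could be exercised like this once registered (a hedged sketch; it assumes the standard `/_synapse/admin/v1` prefix used by `admin_patterns`, the `requests` library, and placeholder group/user IDs):

```python
import requests

# Placeholder values for illustration only.
HOMESERVER = "https://synapse.example.com"
ADMIN_TOKEN = "<admin_access_token>"
GROUP_ID = "+mygroup:example.com"
USER_ID = "@alice:example.com"
headers = {"Authorization": f"Bearer {ADMIN_TOKEN}"}

# Force-join a local user to a local group.
requests.post(
    f"{HOMESERVER}/_synapse/admin/v1/group/{GROUP_ID}/force_join",
    headers=headers,
    json={"user_id": USER_ID},
).raise_for_status()

# Update that user's publicity (flair) for the group.
requests.post(
    f"{HOMESERVER}/_synapse/admin/v1/group/{GROUP_ID}/update_publicity",
    headers=headers,
    json={"user_id": USER_ID, "publicise": True},
).raise_for_status()
```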
@@ -292,6 +292,45 @@ class RoomMembersRestServlet(RestServlet):
|
||||
return 200, ret
|
||||
|
||||
|
||||
class RoomStateRestServlet(RestServlet):
|
||||
"""
|
||||
Get full state within a room.
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/state")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
self.clock = hs.get_clock()
|
||||
self._event_serializer = hs.get_event_client_serializer()
|
||||
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, room_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
|
||||
ret = await self.store.get_room(room_id)
|
||||
if not ret:
|
||||
raise NotFoundError("Room not found")
|
||||
|
||||
event_ids = await self.store.get_current_state_ids(room_id)
|
||||
events = await self.store.get_events(event_ids.values())
|
||||
now = self.clock.time_msec()
|
||||
room_state = await self._event_serializer.serialize_events(
|
||||
events.values(),
|
||||
now,
|
||||
# We don't bother bundling aggregations in when asked for state
|
||||
# events, as clients won't use them.
|
||||
bundle_aggregations=False,
|
||||
)
|
||||
ret = {"state": room_state}
|
||||
|
||||
return 200, ret
|
||||
|
||||
|
||||
class JoinRoomAliasServlet(RestServlet):
|
||||
|
||||
PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
|
||||
|
||||
@@ -83,32 +83,17 @@ class UsersRestServletV2(RestServlet):
|
||||
The parameter `deactivated` can be used to include deactivated users.
|
||||
"""
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
self.admin_handler = hs.get_admin_handler()
|
||||
|
||||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
async def on_GET(self, request):
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
start = parse_integer(request, "from", default=0)
|
||||
limit = parse_integer(request, "limit", default=100)
|
||||
|
||||
if start < 0:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Query parameter from must be a string representing a positive integer.",
|
||||
errcode=Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
if limit < 0:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Query parameter limit must be a string representing a positive integer.",
|
||||
errcode=Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
user_id = parse_string(request, "user_id", default=None)
|
||||
name = parse_string(request, "name", default=None)
|
||||
guests = parse_boolean(request, "guests", default=True)
|
||||
@@ -118,7 +103,7 @@ class UsersRestServletV2(RestServlet):
            start, limit, user_id, name, guests, deactivated
        )
        ret = {"users": users, "total": total}
        if (start + limit) < total:
        if len(users) >= limit:
            ret["next_token"] = str(start + len(users))

        return 200, ret
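
The change above means `next_token` is returned whenever a full page was served, so callers should keep paging until it disappears. A rough client-side sketch (hypothetical values, assuming the `requests` library and the v2 List Users admin endpoint):

```python
import requests

HOMESERVER = "https://synapse.example.com"   # placeholder
ADMIN_TOKEN = "<admin_access_token>"         # placeholder
headers = {"Authorization": f"Bearer {ADMIN_TOKEN}"}

users = []
params = {"from": 0, "limit": 100}
while True:
    resp = requests.get(
        f"{HOMESERVER}/_synapse/admin/v2/users", headers=headers, params=params
    )
    resp.raise_for_status()
    body = resp.json()
    users.extend(body["users"])
    if "next_token" not in body:
        break  # no more pages
    params["from"] = body["next_token"]

print(f"fetched {len(users)} users in total")
```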
@@ -300,7 +300,6 @@ class FileInfo:
|
||||
thumbnail_height (int)
|
||||
thumbnail_method (str)
|
||||
thumbnail_type (str): Content type of thumbnail, e.g. image/png
|
||||
thumbnail_length (int): The size of the media file, in bytes.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -313,7 +312,6 @@ class FileInfo:
|
||||
thumbnail_height=None,
|
||||
thumbnail_method=None,
|
||||
thumbnail_type=None,
|
||||
thumbnail_length=None,
|
||||
):
|
||||
self.server_name = server_name
|
||||
self.file_id = file_id
|
||||
@@ -323,7 +321,6 @@ class FileInfo:
|
||||
self.thumbnail_height = thumbnail_height
|
||||
self.thumbnail_method = thumbnail_method
|
||||
self.thumbnail_type = thumbnail_type
|
||||
self.thumbnail_length = thumbnail_length
|
||||
|
||||
|
||||
def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]:
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from twisted.web.http import Request
|
||||
|
||||
@@ -106,17 +106,31 @@ class ThumbnailResource(DirectServeJsonResource):
|
||||
return
|
||||
|
||||
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
|
||||
await self._select_and_respond_with_thumbnail(
|
||||
request,
|
||||
width,
|
||||
height,
|
||||
method,
|
||||
m_type,
|
||||
thumbnail_infos,
|
||||
media_id,
|
||||
url_cache=media_info["url_cache"],
|
||||
server_name=None,
|
||||
)
|
||||
|
||||
if thumbnail_infos:
|
||||
thumbnail_info = self._select_thumbnail(
|
||||
width, height, method, m_type, thumbnail_infos
|
||||
)
|
||||
|
||||
file_info = FileInfo(
|
||||
server_name=None,
|
||||
file_id=media_id,
|
||||
url_cache=media_info["url_cache"],
|
||||
thumbnail=True,
|
||||
thumbnail_width=thumbnail_info["thumbnail_width"],
|
||||
thumbnail_height=thumbnail_info["thumbnail_height"],
|
||||
thumbnail_type=thumbnail_info["thumbnail_type"],
|
||||
thumbnail_method=thumbnail_info["thumbnail_method"],
|
||||
)
|
||||
|
||||
t_type = file_info.thumbnail_type
|
||||
t_length = thumbnail_info["thumbnail_length"]
|
||||
|
||||
responder = await self.media_storage.fetch_media(file_info)
|
||||
await respond_with_responder(request, responder, t_type, t_length)
|
||||
else:
|
||||
logger.info("Couldn't find any generated thumbnails")
|
||||
respond_404(request)
|
||||
|
||||
async def _select_or_generate_local_thumbnail(
|
||||
self,
|
||||
@@ -262,64 +276,26 @@ class ThumbnailResource(DirectServeJsonResource):
|
||||
thumbnail_infos = await self.store.get_remote_media_thumbnails(
|
||||
server_name, media_id
|
||||
)
|
||||
await self._select_and_respond_with_thumbnail(
|
||||
request,
|
||||
width,
|
||||
height,
|
||||
method,
|
||||
m_type,
|
||||
thumbnail_infos,
|
||||
media_info["filesystem_id"],
|
||||
url_cache=None,
|
||||
server_name=server_name,
|
||||
)
|
||||
|
||||
async def _select_and_respond_with_thumbnail(
|
||||
self,
|
||||
request: Request,
|
||||
desired_width: int,
|
||||
desired_height: int,
|
||||
desired_method: str,
|
||||
desired_type: str,
|
||||
thumbnail_infos: List[Dict[str, Any]],
|
||||
file_id: str,
|
||||
url_cache: Optional[str] = None,
|
||||
server_name: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Respond to a request with an appropriate thumbnail from the previously generated thumbnails.
|
||||
|
||||
Args:
|
||||
request: The incoming request.
|
||||
desired_width: The desired width, the returned thumbnail may be larger than this.
|
||||
desired_height: The desired height, the returned thumbnail may be larger than this.
|
||||
desired_method: The desired method used to generate the thumbnail.
|
||||
desired_type: The desired content-type of the thumbnail.
|
||||
thumbnail_infos: A list of dictionaries of candidate thumbnails.
|
||||
file_id: The ID of the media that a thumbnail is being requested for.
|
||||
url_cache: The URL cache value.
|
||||
server_name: The server name, if this is a remote thumbnail.
|
||||
"""
|
||||
if thumbnail_infos:
|
||||
file_info = self._select_thumbnail(
|
||||
desired_width,
|
||||
desired_height,
|
||||
desired_method,
|
||||
desired_type,
|
||||
thumbnail_infos,
|
||||
file_id,
|
||||
url_cache,
|
||||
server_name,
|
||||
thumbnail_info = self._select_thumbnail(
|
||||
width, height, method, m_type, thumbnail_infos
|
||||
)
|
||||
if not file_info:
|
||||
logger.info("Couldn't find a thumbnail matching the desired inputs")
|
||||
respond_404(request)
|
||||
return
|
||||
file_info = FileInfo(
|
||||
server_name=server_name,
|
||||
file_id=media_info["filesystem_id"],
|
||||
thumbnail=True,
|
||||
thumbnail_width=thumbnail_info["thumbnail_width"],
|
||||
thumbnail_height=thumbnail_info["thumbnail_height"],
|
||||
thumbnail_type=thumbnail_info["thumbnail_type"],
|
||||
thumbnail_method=thumbnail_info["thumbnail_method"],
|
||||
)
|
||||
|
||||
t_type = file_info.thumbnail_type
|
||||
t_length = thumbnail_info["thumbnail_length"]
|
||||
|
||||
responder = await self.media_storage.fetch_media(file_info)
|
||||
await respond_with_responder(
|
||||
request, responder, file_info.thumbnail_type, file_info.thumbnail_length
|
||||
)
|
||||
await respond_with_responder(request, responder, t_type, t_length)
|
||||
else:
|
||||
logger.info("Failed to find any generated thumbnails")
|
||||
respond_404(request)
|
||||
@@ -330,117 +306,67 @@ class ThumbnailResource(DirectServeJsonResource):
|
||||
desired_height: int,
|
||||
desired_method: str,
|
||||
desired_type: str,
|
||||
thumbnail_infos: List[Dict[str, Any]],
|
||||
file_id: str,
|
||||
url_cache: Optional[str],
|
||||
server_name: Optional[str],
|
||||
) -> Optional[FileInfo]:
|
||||
"""
|
||||
Choose an appropriate thumbnail from the previously generated thumbnails.
|
||||
|
||||
Args:
|
||||
desired_width: The desired width, the returned thumbnail may be larger than this.
|
||||
desired_height: The desired height, the returned thumbnail may be larger than this.
|
||||
desired_method: The desired method used to generate the thumbnail.
|
||||
desired_type: The desired content-type of the thumbnail.
|
||||
thumbnail_infos: A list of dictionaries of candidate thumbnails.
|
||||
file_id: The ID of the media that a thumbnail is being requested for.
|
||||
url_cache: The URL cache value.
|
||||
server_name: The server name, if this is a remote thumbnail.
|
||||
|
||||
Returns:
|
||||
The thumbnail which best matches the desired parameters.
|
||||
"""
|
||||
desired_method = desired_method.lower()
|
||||
|
||||
# The chosen thumbnail.
|
||||
thumbnail_info = None
|
||||
|
||||
thumbnail_infos,
|
||||
) -> dict:
|
||||
d_w = desired_width
|
||||
d_h = desired_height
|
||||
|
||||
if desired_method == "crop":
|
||||
# Thumbnails that match equal or larger sizes of desired width/height.
|
||||
if desired_method.lower() == "crop":
|
||||
crop_info_list = []
|
||||
# Other thumbnails.
|
||||
crop_info_list2 = []
|
||||
for info in thumbnail_infos:
|
||||
# Skip thumbnails generated with different methods.
|
||||
if info["thumbnail_method"] != "crop":
|
||||
continue
|
||||
|
||||
t_w = info["thumbnail_width"]
|
||||
t_h = info["thumbnail_height"]
|
||||
aspect_quality = abs(d_w * t_h - d_h * t_w)
|
||||
min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
|
||||
size_quality = abs((d_w - t_w) * (d_h - t_h))
|
||||
type_quality = desired_type != info["thumbnail_type"]
|
||||
length_quality = info["thumbnail_length"]
|
||||
if t_w >= d_w or t_h >= d_h:
|
||||
crop_info_list.append(
|
||||
(
|
||||
aspect_quality,
|
||||
min_quality,
|
||||
size_quality,
|
||||
type_quality,
|
||||
length_quality,
|
||||
info,
|
||||
t_method = info["thumbnail_method"]
|
||||
if t_method == "crop":
|
||||
aspect_quality = abs(d_w * t_h - d_h * t_w)
|
||||
min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
|
||||
size_quality = abs((d_w - t_w) * (d_h - t_h))
|
||||
type_quality = desired_type != info["thumbnail_type"]
|
||||
length_quality = info["thumbnail_length"]
|
||||
if t_w >= d_w or t_h >= d_h:
|
||||
crop_info_list.append(
|
||||
(
|
||||
aspect_quality,
|
||||
min_quality,
|
||||
size_quality,
|
||||
type_quality,
|
||||
length_quality,
|
||||
info,
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
crop_info_list2.append(
|
||||
(
|
||||
aspect_quality,
|
||||
min_quality,
|
||||
size_quality,
|
||||
type_quality,
|
||||
length_quality,
|
||||
info,
|
||||
else:
|
||||
crop_info_list2.append(
|
||||
(
|
||||
aspect_quality,
|
||||
min_quality,
|
||||
size_quality,
|
||||
type_quality,
|
||||
length_quality,
|
||||
info,
|
||||
)
|
||||
)
|
||||
)
|
||||
if crop_info_list:
|
||||
thumbnail_info = min(crop_info_list)[-1]
|
||||
elif crop_info_list2:
|
||||
thumbnail_info = min(crop_info_list2)[-1]
|
||||
elif desired_method == "scale":
|
||||
# Thumbnails that match equal or larger sizes of desired width/height.
|
||||
return min(crop_info_list)[-1]
|
||||
else:
|
||||
return min(crop_info_list2)[-1]
|
||||
else:
|
||||
info_list = []
|
||||
# Other thumbnails.
|
||||
info_list2 = []
|
||||
|
||||
for info in thumbnail_infos:
|
||||
# Skip thumbnails generated with different methods.
|
||||
if info["thumbnail_method"] != "scale":
|
||||
continue
|
||||
|
||||
t_w = info["thumbnail_width"]
|
||||
t_h = info["thumbnail_height"]
|
||||
t_method = info["thumbnail_method"]
|
||||
size_quality = abs((d_w - t_w) * (d_h - t_h))
|
||||
type_quality = desired_type != info["thumbnail_type"]
|
||||
length_quality = info["thumbnail_length"]
|
||||
if t_w >= d_w or t_h >= d_h:
|
||||
if t_method == "scale" and (t_w >= d_w or t_h >= d_h):
|
||||
info_list.append((size_quality, type_quality, length_quality, info))
|
||||
else:
|
||||
elif t_method == "scale":
|
||||
info_list2.append(
|
||||
(size_quality, type_quality, length_quality, info)
|
||||
)
|
||||
if info_list:
|
||||
thumbnail_info = min(info_list)[-1]
|
||||
elif info_list2:
|
||||
thumbnail_info = min(info_list2)[-1]
|
||||
|
||||
if thumbnail_info:
|
||||
return FileInfo(
|
||||
file_id=file_id,
|
||||
url_cache=url_cache,
|
||||
server_name=server_name,
|
||||
thumbnail=True,
|
||||
thumbnail_width=thumbnail_info["thumbnail_width"],
|
||||
thumbnail_height=thumbnail_info["thumbnail_height"],
|
||||
thumbnail_type=thumbnail_info["thumbnail_type"],
|
||||
thumbnail_method=thumbnail_info["thumbnail_method"],
|
||||
thumbnail_length=thumbnail_info["thumbnail_length"],
|
||||
)
|
||||
|
||||
# No matching thumbnail was found.
|
||||
return None
|
||||
return min(info_list)[-1]
|
||||
else:
|
||||
return min(info_list2)[-1]
|
||||
|
||||
@@ -262,18 +262,13 @@ class LoggingTransaction:
        return self.txn.description

    def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
        """Similar to `executemany`, except `txn.rowcount` will not be correct
        afterwards.

        More efficient than `executemany` on PostgreSQL
        """
        if isinstance(self.database_engine, PostgresEngine):
            from psycopg2.extras import execute_batch  # type: ignore

            self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
        else:
            self.executemany(sql, args)
            for val in args:
                self.execute(sql, val)
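
For context, psycopg2's `execute_batch` helper is what makes the PostgreSQL branch faster: it packs many statements into far fewer network round trips. A minimal standalone sketch, assuming an existing psycopg2 connection `conn` and an illustrative table:

```python
from psycopg2.extras import execute_batch

sql = "UPDATE events SET origin_server_ts = %s WHERE event_id = %s"
rows = [(1610000000000, "$event1"), (1610000001000, "$event2")]

with conn.cursor() as cur:
    # executemany issues one statement per row (one round trip each).
    cur.executemany(sql, rows)

    # execute_batch groups many statements per round trip; note that
    # cur.rowcount is not meaningful afterwards, as the docstring says.
    execute_batch(cur, sql, rows, page_size=100)
```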

    def execute_values(self, sql: str, *args: Any) -> List[Tuple]:
        """Corresponds to psycopg2.extras.execute_values. Only available when
@@ -893,7 +888,7 @@ class DatabasePool:
            ", ".join("?" for _ in keys[0]),
        )

        txn.execute_batch(sql, vals)
        txn.executemany(sql, vals)

    async def simple_upsert(
        self,

@@ -897,7 +897,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
DELETE FROM device_lists_outbound_last_success
|
||||
WHERE destination = ? AND user_id = ?
|
||||
"""
|
||||
txn.execute_batch(sql, ((row[0], row[1]) for row in rows))
|
||||
txn.executemany(sql, ((row[0], row[1]) for row in rows))
|
||||
|
||||
logger.info("Pruned %d device list outbound pokes", count)
|
||||
|
||||
@@ -1343,7 +1343,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||
|
||||
# Delete older entries in the table, as we really only care about
|
||||
# when the latest change happened.
|
||||
txn.execute_batch(
|
||||
txn.executemany(
|
||||
"""
|
||||
DELETE FROM device_lists_stream
|
||||
WHERE user_id = ? AND device_id = ? AND stream_id < ?
|
||||
|
||||
@@ -487,7 +487,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
"""
|
||||
|
||||
txn.execute_batch(
|
||||
txn.executemany(
|
||||
sql,
|
||||
(
|
||||
_gen_entry(user_id, actions)
|
||||
@@ -803,7 +803,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||
],
|
||||
)
|
||||
|
||||
txn.execute_batch(
|
||||
txn.executemany(
|
||||
"""
|
||||
UPDATE event_push_summary
|
||||
SET notif_count = ?, unread_count = ?, stream_ordering = ?
|
||||
|
||||
@@ -473,9 +473,8 @@ class PersistEventsStore:
|
||||
txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@staticmethod
|
||||
def _add_chain_cover_index(
|
||||
cls,
|
||||
txn,
|
||||
db_pool: DatabasePool,
|
||||
event_to_room_id: Dict[str, str],
|
||||
@@ -615,17 +614,60 @@ class PersistEventsStore:
|
||||
if not events_to_calc_chain_id_for:
|
||||
return
|
||||
|
||||
# Allocate chain ID/sequence numbers to each new event.
|
||||
new_chain_tuples = cls._allocate_chain_ids(
|
||||
txn,
|
||||
db_pool,
|
||||
event_to_room_id,
|
||||
event_to_types,
|
||||
event_to_auth_chain,
|
||||
events_to_calc_chain_id_for,
|
||||
chain_map,
|
||||
)
|
||||
chain_map.update(new_chain_tuples)
|
||||
# We now calculate the chain IDs/sequence numbers for the events. We
|
||||
# do this by looking at the chain ID and sequence number of any auth
|
||||
# event with the same type/state_key and incrementing the sequence
|
||||
# number by one. If there was no match or the chain ID/sequence
|
||||
# number is already taken we generate a new chain.
|
||||
#
|
||||
# We need to do this in a topologically sorted order as we want to
|
||||
# generate chain IDs/sequence numbers of an event's auth events
|
||||
# before the event itself.
|
||||
chains_tuples_allocated = set() # type: Set[Tuple[int, int]]
|
||||
new_chain_tuples = {} # type: Dict[str, Tuple[int, int]]
|
||||
for event_id in sorted_topologically(
|
||||
events_to_calc_chain_id_for, event_to_auth_chain
|
||||
):
|
||||
existing_chain_id = None
|
||||
for auth_id in event_to_auth_chain.get(event_id, []):
|
||||
if event_to_types.get(event_id) == event_to_types.get(auth_id):
|
||||
existing_chain_id = chain_map[auth_id]
|
||||
break
|
||||
|
||||
new_chain_tuple = None
|
||||
if existing_chain_id:
|
||||
# We found a chain ID/sequence number candidate, check its
|
||||
# not already taken.
|
||||
proposed_new_id = existing_chain_id[0]
|
||||
proposed_new_seq = existing_chain_id[1] + 1
|
||||
if (proposed_new_id, proposed_new_seq) not in chains_tuples_allocated:
|
||||
already_allocated = db_pool.simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="event_auth_chains",
|
||||
keyvalues={
|
||||
"chain_id": proposed_new_id,
|
||||
"sequence_number": proposed_new_seq,
|
||||
},
|
||||
retcol="event_id",
|
||||
allow_none=True,
|
||||
)
|
||||
if already_allocated:
|
||||
# Mark it as already allocated so we don't need to hit
|
||||
# the DB again.
|
||||
chains_tuples_allocated.add((proposed_new_id, proposed_new_seq))
|
||||
else:
|
||||
new_chain_tuple = (
|
||||
proposed_new_id,
|
||||
proposed_new_seq,
|
||||
)
|
||||
|
||||
if not new_chain_tuple:
|
||||
new_chain_tuple = (db_pool.event_chain_id_gen.get_next_id_txn(txn), 1)
|
||||
|
||||
chains_tuples_allocated.add(new_chain_tuple)
|
||||
|
||||
chain_map[event_id] = new_chain_tuple
|
||||
new_chain_tuples[event_id] = new_chain_tuple
|
||||
|
||||
db_pool.simple_insert_many_txn(
|
||||
txn,
|
||||
@@ -752,137 +794,6 @@ class PersistEventsStore:
|
||||
],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _allocate_chain_ids(
|
||||
txn,
|
||||
db_pool: DatabasePool,
|
||||
event_to_room_id: Dict[str, str],
|
||||
event_to_types: Dict[str, Tuple[str, str]],
|
||||
event_to_auth_chain: Dict[str, List[str]],
|
||||
events_to_calc_chain_id_for: Set[str],
|
||||
chain_map: Dict[str, Tuple[int, int]],
|
||||
) -> Dict[str, Tuple[int, int]]:
|
||||
"""Allocates, but does not persist, chain ID/sequence numbers for the
|
||||
events in `events_to_calc_chain_id_for`. (c.f. _add_chain_cover_index
|
||||
for info on args)
|
||||
"""
|
||||
|
||||
# We now calculate the chain IDs/sequence numbers for the events. We do
|
||||
# this by looking at the chain ID and sequence number of any auth event
|
||||
# with the same type/state_key and incrementing the sequence number by
|
||||
# one. If there was no match or the chain ID/sequence number is already
|
||||
# taken we generate a new chain.
|
||||
#
|
||||
# We try to reduce the number of times that we hit the database by
|
||||
# batching up calls, to make this more efficient when persisting large
|
||||
# numbers of state events (e.g. during joins).
|
||||
#
|
||||
# We do this by:
|
||||
# 1. Calculating for each event which auth event will be used to
|
||||
# inherit the chain ID, i.e. converting the auth chain graph to a
|
||||
# tree that we can allocate chains on. We also keep track of which
|
||||
# existing chain IDs have been referenced.
|
||||
# 2. Fetching the max allocated sequence number for each referenced
|
||||
# existing chain ID, generating a map from chain ID to the max
|
||||
# allocated sequence number.
|
||||
# 3. Iterating over the tree and allocating a chain ID/seq no. to the
|
||||
# new event, by incrementing the sequence number from the
|
||||
# referenced event's chain ID/seq no. and checking that the
|
||||
# incremented sequence number hasn't already been allocated (by
|
||||
# looking in the map generated in the previous step). We generate a
|
||||
# new chain if the sequence number has already been allocated.
|
||||
#
|
||||
|
||||
existing_chains = set() # type: Set[int]
|
||||
tree = [] # type: List[Tuple[str, Optional[str]]]
|
||||
|
||||
# We need to do this in a topologically sorted order as we want to
|
||||
# generate chain IDs/sequence numbers of an event's auth events before
|
||||
# the event itself.
|
||||
for event_id in sorted_topologically(
|
||||
events_to_calc_chain_id_for, event_to_auth_chain
|
||||
):
|
||||
for auth_id in event_to_auth_chain.get(event_id, []):
|
||||
if event_to_types.get(event_id) == event_to_types.get(auth_id):
|
||||
existing_chain_id = chain_map.get(auth_id)
|
||||
if existing_chain_id:
|
||||
existing_chains.add(existing_chain_id[0])
|
||||
|
||||
tree.append((event_id, auth_id))
|
||||
break
|
||||
else:
|
||||
tree.append((event_id, None))
|
||||
|
||||
# Fetch the current max sequence number for each existing referenced chain.
|
||||
sql = """
|
||||
SELECT chain_id, MAX(sequence_number) FROM event_auth_chains
|
||||
WHERE %s
|
||||
GROUP BY chain_id
|
||||
"""
|
||||
clause, args = make_in_list_sql_clause(
|
||||
db_pool.engine, "chain_id", existing_chains
|
||||
)
|
||||
txn.execute(sql % (clause,), args)
|
||||
|
||||
chain_to_max_seq_no = {row[0]: row[1] for row in txn} # type: Dict[Any, int]
|
||||
|
||||
# Allocate the new events chain ID/sequence numbers.
|
||||
#
|
||||
# To reduce the number of calls to the database we don't allocate a
|
||||
# chain ID number in the loop, instead we use a temporary `object()` for
|
||||
# each new chain ID. Once we've done the loop we generate the necessary
|
||||
# number of new chain IDs in one call, replacing all temporary
|
||||
# objects with real allocated chain IDs.
|
||||
|
||||
unallocated_chain_ids = set() # type: Set[object]
|
||||
new_chain_tuples = {} # type: Dict[str, Tuple[Any, int]]
|
||||
for event_id, auth_event_id in tree:
|
||||
# If we reference an auth_event_id we fetch the allocated chain ID,
|
||||
# either from the existing `chain_map` or the newly generated
|
||||
# `new_chain_tuples` map.
|
||||
existing_chain_id = None
|
||||
if auth_event_id:
|
||||
existing_chain_id = new_chain_tuples.get(auth_event_id)
|
||||
if not existing_chain_id:
|
||||
existing_chain_id = chain_map[auth_event_id]
|
||||
|
||||
new_chain_tuple = None # type: Optional[Tuple[Any, int]]
|
||||
if existing_chain_id:
|
||||
# We found a chain ID/sequence number candidate, check its
|
||||
# not already taken.
|
||||
proposed_new_id = existing_chain_id[0]
|
||||
proposed_new_seq = existing_chain_id[1] + 1
|
||||
|
||||
if chain_to_max_seq_no[proposed_new_id] < proposed_new_seq:
|
||||
new_chain_tuple = (
|
||||
proposed_new_id,
|
||||
proposed_new_seq,
|
||||
)
|
||||
|
||||
# If we need to start a new chain we allocate a temporary chain ID.
|
||||
if not new_chain_tuple:
|
||||
new_chain_tuple = (object(), 1)
|
||||
unallocated_chain_ids.add(new_chain_tuple[0])
|
||||
|
||||
new_chain_tuples[event_id] = new_chain_tuple
|
||||
chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1]
|
||||
|
||||
# Generate new chain IDs for all unallocated chain IDs.
|
||||
newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn(
|
||||
txn, len(unallocated_chain_ids)
|
||||
)
|
||||
|
||||
# Map from potentially temporary chain ID to real chain ID
|
||||
chain_id_to_allocated_map = dict(
|
||||
zip(unallocated_chain_ids, newly_allocated_chain_ids)
|
||||
) # type: Dict[Any, int]
|
||||
chain_id_to_allocated_map.update((c, c) for c in existing_chains)
|
||||
|
||||
return {
|
||||
event_id: (chain_id_to_allocated_map[chain_id], seq)
|
||||
for event_id, (chain_id, seq) in new_chain_tuples.items()
|
||||
}
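
To make the allocation rule described in the comments easier to follow, here is a toy standalone sketch of the same idea (purely illustrative, not Synapse's actual implementation, which batches its database lookups as explained above):

```python
# Toy illustration only; names and data are made up.
from typing import Dict, List, Tuple

def toy_allocate(
    sorted_events: List[str],             # already topologically sorted
    auth_events: Dict[str, List[str]],    # event -> its auth events
    types: Dict[str, Tuple[str, str]],    # event -> (type, state_key)
) -> Dict[str, Tuple[int, int]]:
    chain_map: Dict[str, Tuple[int, int]] = {}
    taken: set = set()
    next_chain = 1
    for event_id in sorted_events:
        candidate = None
        for auth_id in auth_events.get(event_id, []):
            # Try to extend the chain of an auth event with the same (type, state_key)...
            if types.get(auth_id) == types[event_id] and auth_id in chain_map:
                chain, seq = chain_map[auth_id]
                if (chain, seq + 1) not in taken:
                    candidate = (chain, seq + 1)
                break
        if candidate is None:
            # ...otherwise start a brand new chain.
            candidate = (next_chain, 1)
            next_chain += 1
        taken.add(candidate)
        chain_map[event_id] = candidate
    return chain_map

# e.g. successive m.room.power_levels events extend one chain (seq 1, 2, 3, ...),
# while an event with no same-typed auth event starts a new chain at sequence 1.
```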
|
||||
|
||||
def _persist_transaction_ids_txn(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
@@ -965,7 +876,7 @@ class PersistEventsStore:
|
||||
WHERE room_id = ? AND type = ? AND state_key = ?
|
||||
)
|
||||
"""
|
||||
txn.execute_batch(
|
||||
txn.executemany(
|
||||
sql,
|
||||
(
|
||||
(
|
||||
@@ -984,7 +895,7 @@ class PersistEventsStore:
|
||||
)
|
||||
# Now we actually update the current_state_events table
|
||||
|
||||
txn.execute_batch(
|
||||
txn.executemany(
|
||||
"DELETE FROM current_state_events"
|
||||
" WHERE room_id = ? AND type = ? AND state_key = ?",
|
||||
(
|
||||
@@ -996,7 +907,7 @@ class PersistEventsStore:
|
||||
# We include the membership in the current state table, hence we do
|
||||
# a lookup when we insert. This assumes that all events have already
|
||||
# been inserted into room_memberships.
|
||||
txn.execute_batch(
|
||||
txn.executemany(
|
||||
"""INSERT INTO current_state_events
|
||||
(room_id, type, state_key, event_id, membership)
|
||||
VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
|
||||
@@ -1016,7 +927,7 @@ class PersistEventsStore:
|
||||
# we have no record of the fact the user *was* a member of the
|
||||
# room but got, say, state reset out of it.
|
||||
if to_delete or to_insert:
|
||||
txn.execute_batch(
|
||||
txn.executemany(
|
||||
"DELETE FROM local_current_membership"
|
||||
" WHERE room_id = ? AND user_id = ?",
|
||||
(
|
||||
@@ -1027,7 +938,7 @@ class PersistEventsStore:
|
||||
)
|
||||
|
||||
if to_insert:
|
||||
txn.execute_batch(
|
||||
txn.executemany(
|
||||
"""INSERT INTO local_current_membership
|
||||
(room_id, user_id, event_id, membership)
|
||||
VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
|
||||
@@ -1827,7 +1738,7 @@ class PersistEventsStore:
|
||||
"""
|
||||
|
||||
if events_and_contexts:
|
||||
txn.execute_batch(
|
||||
txn.executemany(
|
||||
sql,
|
||||
(
|
||||
(
|
||||
@@ -1856,7 +1767,7 @@ class PersistEventsStore:
|
||||
|
||||
# Now we delete the staging area for *all* events that were being
|
||||
# persisted.
|
||||
txn.execute_batch(
|
||||
txn.executemany(
|
||||
"DELETE FROM event_push_actions_staging WHERE event_id = ?",
|
||||
((event.event_id,) for event, _ in all_events_and_contexts),
|
||||
)
|
||||
@@ -1975,7 +1886,7 @@ class PersistEventsStore:
|
||||
" )"
|
||||
)
|
||||
|
||||
txn.execute_batch(
|
||||
txn.executemany(
|
||||
query,
|
||||
[
|
||||
(e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
|
||||
@@ -1989,7 +1900,7 @@ class PersistEventsStore:
|
||||
"DELETE FROM event_backward_extremities"
|
||||
" WHERE event_id = ? AND room_id = ?"
|
||||
)
|
||||
txn.execute_batch(
|
||||
txn.executemany(
|
||||
query,
|
||||
[
|
||||
(ev.event_id, ev.room_id)
|
||||
|
||||
@@ -139,6 +139,8 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
||||
max_stream_id = progress["max_stream_id_exclusive"]
|
||||
rows_inserted = progress.get("rows_inserted", 0)
|
||||
|
||||
INSERT_CLUMP_SIZE = 1000
|
||||
|
||||
def reindex_txn(txn):
|
||||
sql = (
|
||||
"SELECT stream_ordering, event_id, json FROM events"
|
||||
@@ -176,7 +178,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
||||
|
||||
sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"
|
||||
|
||||
txn.execute_batch(sql, update_rows)
|
||||
for index in range(0, len(update_rows), INSERT_CLUMP_SIZE):
|
||||
clump = update_rows[index : index + INSERT_CLUMP_SIZE]
|
||||
txn.executemany(sql, clump)
|
||||
|
||||
progress = {
|
||||
"target_min_stream_id_inclusive": target_min_stream_id,
|
||||
@@ -206,6 +210,8 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
||||
max_stream_id = progress["max_stream_id_exclusive"]
|
||||
rows_inserted = progress.get("rows_inserted", 0)
|
||||
|
||||
INSERT_CLUMP_SIZE = 1000
|
||||
|
||||
def reindex_search_txn(txn):
|
||||
sql = (
|
||||
"SELECT stream_ordering, event_id FROM events"
|
||||
@@ -250,7 +256,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
||||
|
||||
sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?"
|
||||
|
||||
txn.execute_batch(sql, rows_to_update)
|
||||
for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE):
|
||||
clump = rows_to_update[index : index + INSERT_CLUMP_SIZE]
|
||||
txn.executemany(sql, clump)
|
||||
|
||||
progress = {
|
||||
"target_min_stream_id_inclusive": target_min_stream_id,
|
||||
|
||||
@@ -417,7 +417,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
" WHERE media_origin = ? AND media_id = ?"
)

txn.execute_batch(
txn.executemany(
sql,
(
(time_ms, media_origin, media_id)
@@ -430,7 +430,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
" WHERE media_id = ?"
)

txn.execute_batch(sql, ((time_ms, media_id) for media_id in local_media))
txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))

return await self.db_pool.runInteraction(
"update_cached_last_access_time", update_cache_txn
@@ -557,7 +557,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?"

def _delete_url_cache_txn(txn):
txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
txn.executemany(sql, [(media_id,) for media_id in media_ids])

return await self.db_pool.runInteraction(
"delete_url_cache", _delete_url_cache_txn
@@ -586,11 +586,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def _delete_url_cache_media_txn(txn):
sql = "DELETE FROM local_media_repository WHERE media_id = ?"

txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
txn.executemany(sql, [(media_id,) for media_id in media_ids])

sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?"

txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
txn.executemany(sql, [(media_id,) for media_id in media_ids])

return await self.db_pool.runInteraction(
"delete_url_cache_media", _delete_url_cache_media_txn

@@ -172,7 +172,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
)

# Update backward extremities
txn.execute_batch(
txn.executemany(
"INSERT INTO event_backward_extremities (room_id, event_id)"
" VALUES (?, ?)",
[(room_id, event_id) for event_id, in new_backwards_extrems],

@@ -1104,7 +1104,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
FROM user_threepids
"""

txn.execute_batch(sql, [(id_server,) for id_server in id_servers])
txn.executemany(sql, [(id_server,) for id_server in id_servers])

if id_servers:
await self.db_pool.runInteraction(

@@ -873,6 +873,8 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
"max_stream_id_exclusive", self._stream_order_on_start + 1
)

INSERT_CLUMP_SIZE = 1000

def add_membership_profile_txn(txn):
sql = """
SELECT stream_ordering, event_id, events.room_id, event_json.json
@@ -913,7 +915,9 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
UPDATE room_memberships SET display_name = ?, avatar_url = ?
WHERE event_id = ? AND room_id = ?
"""
txn.execute_batch(to_update_sql, to_update)
for index in range(0, len(to_update), INSERT_CLUMP_SIZE):
clump = to_update[index : index + INSERT_CLUMP_SIZE]
txn.executemany(to_update_sql, clump)

progress = {
"target_min_stream_id_inclusive": target_min_stream_id,

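Several of the background updates above chunk their rows into `INSERT_CLUMP_SIZE` batches before calling `executemany`. Here is a minimal, self-contained sketch of that clump pattern; the table, rows, and in-memory SQLite database are stand-ins, not Synapse objects.

```python
# Minimal sketch of the clump pattern used by the background updates above:
# apply an UPDATE in fixed-size batches so one statement never carries an
# unbounded argument list. Uses an in-memory SQLite database as a stand-in.
import sqlite3

INSERT_CLUMP_SIZE = 1000


def update_in_clumps(cursor, sql, rows):
    for index in range(0, len(rows), INSERT_CLUMP_SIZE):
        clump = rows[index : index + INSERT_CLUMP_SIZE]
        cursor.executemany(sql, clump)


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE room_memberships (event_id TEXT, display_name TEXT)")
conn.executemany(
    "INSERT INTO room_memberships VALUES (?, ?)",
    [("$ev%d" % i, None) for i in range(2500)],
)
update_in_clumps(
    conn.cursor(),
    "UPDATE room_memberships SET display_name = ? WHERE event_id = ?",
    [("Name %d" % i, "$ev%d" % i) for i in range(2500)],
)
```
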
@@ -55,7 +55,7 @@ def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs
# { "ignored_users": "@someone:example.org": {} }
ignored_users = content.get("ignored_users", {})
if isinstance(ignored_users, dict) and ignored_users:
cur.execute_batch(insert_sql, [(user_id, u) for u in ignored_users])
cur.executemany(insert_sql, [(user_id, u) for u in ignored_users])

# Add indexes after inserting data for efficiency.
logger.info("Adding constraints to ignored_users table")

@@ -63,7 +63,7 @@ class SearchWorkerStore(SQLBaseStore):
for entry in entries
)

txn.execute_batch(sql, args)
txn.executemany(sql, args)

elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
@@ -75,7 +75,7 @@ class SearchWorkerStore(SQLBaseStore):
for entry in entries
)

txn.execute_batch(sql, args)
txn.executemany(sql, args)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")

@@ -565,11 +565,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
)

logger.info("[purge] removing redundant state groups")
txn.execute_batch(
txn.executemany(
"DELETE FROM state_groups_state WHERE state_group = ?",
((sg,) for sg in state_groups_to_delete),
)
txn.execute_batch(
txn.executemany(
"DELETE FROM state_groups WHERE id = ?",
((sg,) for sg in state_groups_to_delete),
)

@@ -15,11 +15,12 @@
import heapq
import logging
import threading
from collections import OrderedDict
from collections import deque
from contextlib import contextmanager
from typing import Dict, List, Optional, Set, Tuple, Union

import attr
from typing_extensions import Deque

from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.database import DatabasePool, LoggingTransaction
@@ -100,13 +101,7 @@ class StreamIdGenerator:
self._current = (max if step > 0 else min)(
self._current, _load_current_id(db_conn, table, column, step)
)

# We use this as an ordered set, as we want to efficiently append items,
# remove items and get the first item. Since we insert IDs in order, the
# insertion ordering will ensure it's in the correct ordering.
#
# The key and values are the same, but we never look at the values.
self._unfinished_ids = OrderedDict() # type: OrderedDict[int, int]
self._unfinished_ids = deque() # type: Deque[int]

def get_next(self):
"""
@@ -118,7 +113,7 @@ class StreamIdGenerator:
self._current += self._step
next_id = self._current

self._unfinished_ids[next_id] = next_id
self._unfinished_ids.append(next_id)

@contextmanager
def manager():
@@ -126,7 +121,7 @@ class StreamIdGenerator:
yield next_id
finally:
with self._lock:
self._unfinished_ids.pop(next_id)
self._unfinished_ids.remove(next_id)

return _AsyncCtxManagerWrapper(manager())

@@ -145,7 +140,7 @@ class StreamIdGenerator:
self._current += n * self._step

for next_id in next_ids:
self._unfinished_ids[next_id] = next_id
self._unfinished_ids.append(next_id)

@contextmanager
def manager():
@@ -154,7 +149,7 @@ class StreamIdGenerator:
finally:
with self._lock:
for next_id in next_ids:
self._unfinished_ids.pop(next_id)
self._unfinished_ids.remove(next_id)

return _AsyncCtxManagerWrapper(manager())

@@ -167,7 +162,7 @@ class StreamIdGenerator:
"""
with self._lock:
if self._unfinished_ids:
return next(iter(self._unfinished_ids)) - self._step
return self._unfinished_ids[0] - self._step

return self._current

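The `StreamIdGenerator` hunks above trade an `OrderedDict` used as an ordered set of in-flight stream IDs for a plain `deque`. A small sketch of why the two containers are interchangeable for this access pattern (append new IDs in order, drop an arbitrary finished ID, peek at the oldest); the IDs below are made up.

```python
# Sketch contrasting the two containers used above as an "ordered set" of
# in-flight stream IDs. Both behave the same for append / remove / peek-oldest;
# the deque simply avoids the dict overhead.
from collections import OrderedDict, deque

ids_dict = OrderedDict()  # key == value, values never read
ids_deque = deque()

for next_id in (101, 102, 103):
    ids_dict[next_id] = next_id
    ids_deque.append(next_id)

# A middle ID finishes first.
ids_dict.pop(102)
ids_deque.remove(102)

# Oldest still-unfinished ID is the same either way.
assert next(iter(ids_dict)) == ids_deque[0] == 101
```

The deque keeps the "oldest unfinished ID" peek as a plain index lookup, which is presumably where the concurrency win in the changelog comes from.
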
@@ -69,11 +69,6 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
"""Gets the next ID in the sequence"""
...

@abc.abstractmethod
def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
"""Get the next `n` IDs in the sequence"""
...

@abc.abstractmethod
def check_consistency(
self,
@@ -224,17 +219,6 @@ class LocalSequenceGenerator(SequenceGenerator):
self._current_max_id += 1
return self._current_max_id

def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
with self._lock:
if self._current_max_id is None:
assert self._callback is not None
self._current_max_id = self._callback(txn)
self._callback = None

first_id = self._current_max_id + 1
self._current_max_id += n
return [first_id + i for i in range(n)]

def check_consistency(
self,
db_conn: Connection,

@@ -78,7 +78,7 @@ def sorted_topologically(
if node not in degree_map:
continue

for edge in edges:
for edge in set(edges):
if edge in degree_map:
degree_map[node] += 1

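The one-line change above (`for edge in set(edges)`) guards the in-degree count against duplicate edges, matching the "duplicate auth events" fix in the changelog. A toy illustration of the failure mode, with a made-up graph:

```python
# Minimal illustration of why duplicate edges need de-duplication when
# computing in-degrees for a topological sort: counting the same parent twice
# can leave a node with a non-zero degree after that parent has been emitted.
graph = {1: [], 2: [1, 1]}  # node 2 lists its only parent twice

naive_degree = sum(1 for edge in graph[2] if edge in graph)          # counts 2
deduped_degree = sum(1 for edge in set(graph[2]) if edge in graph)   # counts 1
assert (naive_degree, deduped_degree) == (2, 1)
```

With the duplicate counted twice, node 2 would typically only be decremented once when node 1 is emitted and so would never reach degree zero and never be yielded.
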
@@ -1180,6 +1180,21 @@ class RoomTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.json_body["total"], 3)

def test_room_state(self):
"""Test that room state can be requested correctly"""
# Create a test room
room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)

url = "/_synapse/admin/v1/rooms/%s/state" % (room_id,)
channel = self.make_request(
"GET", url.encode("ascii"), access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertIn("state", channel.json_body)
# Testing that the state events match is painful and not done here. We assume that
# create_room already does the right thing, so there is no need to verify that we got
# the state events it created.


class JoinAliasRoomTestCase(unittest.HomeserverTestCase):

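The new `test_room_state` above exercises `GET /_synapse/admin/v1/rooms/<room_id>/state` and asserts only that the response carries a `state` key. A hedged example of calling the same endpoint from outside the test harness; the homeserver URL, access token, and room ID below are placeholders.

```python
# Hedged sketch of hitting the room-state admin endpoint exercised above.
# HOMESERVER, ADMIN_TOKEN, and room_id are placeholders, not real values.
import requests

HOMESERVER = "https://synapse.example.com"
ADMIN_TOKEN = "syt_admin_token"
room_id = "!abc123:example.com"

resp = requests.get(
    f"{HOMESERVER}/_synapse/admin/v1/rooms/{room_id}/state",
    headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
)
resp.raise_for_status()
print(len(resp.json()["state"]), "state events")
```
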
@@ -28,7 +28,6 @@ from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
from synapse.api.room_versions import RoomVersions
from synapse.rest.client.v1 import login, logout, profile, room
from synapse.rest.client.v2_alpha import devices, sync
from synapse.types import JsonDict

from tests import unittest
from tests.test_utils import make_awaitable
@@ -469,6 +468,13 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")

self.user1 = self.register_user(
"user1", "pass1", admin=False, displayname="Name 1"
)
self.user2 = self.register_user(
"user2", "pass2", admin=False, displayname="Name 2"
)

def test_no_auth(self):
"""
Try to list users without authentication.
@@ -482,7 +488,6 @@ class UsersListTestCase(unittest.HomeserverTestCase):
"""
If the user is not a server admin, an error is returned.
"""
self._create_users(1)
other_user_token = self.login("user1", "pass1")

channel = self.make_request("GET", self.url, access_token=other_user_token)
@@ -494,8 +499,6 @@ class UsersListTestCase(unittest.HomeserverTestCase):
"""
List all users, including deactivated users.
"""
self._create_users(2)

channel = self.make_request(
"GET",
self.url + "?deactivated=true",
@@ -508,7 +511,14 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.assertEqual(3, channel.json_body["total"])

# Check that all fields are available
self._check_fields(channel.json_body["users"])
for u in channel.json_body["users"]:
self.assertIn("name", u)
self.assertIn("is_guest", u)
self.assertIn("admin", u)
self.assertIn("user_type", u)
self.assertIn("deactivated", u)
self.assertIn("displayname", u)
self.assertIn("avatar_url", u)

def test_search_term(self):
"""Test that searching for users works correctly"""
@@ -539,7 +549,6 @@ class UsersListTestCase(unittest.HomeserverTestCase):

# Check that users were returned
self.assertTrue("users" in channel.json_body)
self._check_fields(channel.json_body["users"])
users = channel.json_body["users"]

# Check that the expected number of users were returned
@@ -552,30 +561,25 @@ class UsersListTestCase(unittest.HomeserverTestCase):
u = users[0]
self.assertEqual(expected_user_id, u["name"])

self._create_users(2)

user1 = "@user1:test"
user2 = "@user2:test"

# Perform search tests
_search_test(user1, "er1")
_search_test(user1, "me 1")
_search_test(self.user1, "er1")
_search_test(self.user1, "me 1")

_search_test(user2, "er2")
_search_test(user2, "me 2")
_search_test(self.user2, "er2")
_search_test(self.user2, "me 2")

_search_test(user1, "er1", "user_id")
_search_test(user2, "er2", "user_id")
_search_test(self.user1, "er1", "user_id")
_search_test(self.user2, "er2", "user_id")

# Test case insensitive
_search_test(user1, "ER1")
_search_test(user1, "NAME 1")
_search_test(self.user1, "ER1")
_search_test(self.user1, "NAME 1")

_search_test(user2, "ER2")
_search_test(user2, "NAME 2")
_search_test(self.user2, "ER2")
_search_test(self.user2, "NAME 2")

_search_test(user1, "ER1", "user_id")
_search_test(user2, "ER2", "user_id")
_search_test(self.user1, "ER1", "user_id")
_search_test(self.user2, "ER2", "user_id")

_search_test(None, "foo")
_search_test(None, "bar")
@@ -583,179 +587,6 @@ class UsersListTestCase(unittest.HomeserverTestCase):
_search_test(None, "foo", "user_id")
_search_test(None, "bar", "user_id")

def test_invalid_parameter(self):
"""
If parameters are invalid, an error is returned.
"""

# negative limit
channel = self.make_request(
"GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
)

self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])

# negative from
channel = self.make_request(
"GET", self.url + "?from=-5", access_token=self.admin_user_tok,
)

self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])

# invalid guests
channel = self.make_request(
"GET", self.url + "?guests=not_bool", access_token=self.admin_user_tok,
)

self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])

# invalid deactivated
channel = self.make_request(
"GET", self.url + "?deactivated=not_bool", access_token=self.admin_user_tok,
)

self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])

def test_limit(self):
"""
Testing list of users with limit
"""

number_users = 20
# Create one less user (since there's already an admin user).
self._create_users(number_users - 1)

channel = self.make_request(
"GET", self.url + "?limit=5", access_token=self.admin_user_tok,
)

self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 5)
self.assertEqual(channel.json_body["next_token"], "5")
self._check_fields(channel.json_body["users"])

def test_from(self):
"""
Testing list of users with a defined starting point (from)
"""

number_users = 20
# Create one less user (since there's already an admin user).
self._create_users(number_users - 1)

channel = self.make_request(
"GET", self.url + "?from=5", access_token=self.admin_user_tok,
)

self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 15)
self.assertNotIn("next_token", channel.json_body)
self._check_fields(channel.json_body["users"])

def test_limit_and_from(self):
"""
Testing list of users with a defined starting point and limit
"""

number_users = 20
# Create one less user (since there's already an admin user).
self._create_users(number_users - 1)

channel = self.make_request(
"GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
)

self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(channel.json_body["next_token"], "15")
self.assertEqual(len(channel.json_body["users"]), 10)
self._check_fields(channel.json_body["users"])

def test_next_token(self):
"""
Testing that `next_token` appears at the right place
"""

number_users = 20
# Create one less user (since there's already an admin user).
self._create_users(number_users - 1)

# `next_token` does not appear
# Number of results is the number of entries
channel = self.make_request(
"GET", self.url + "?limit=20", access_token=self.admin_user_tok,
)

self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), number_users)
self.assertNotIn("next_token", channel.json_body)

# `next_token` does not appear
# Number of max results is larger than the number of entries
channel = self.make_request(
"GET", self.url + "?limit=21", access_token=self.admin_user_tok,
)

self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), number_users)
self.assertNotIn("next_token", channel.json_body)

# `next_token` does appear
# Number of max results is smaller than the number of entries
channel = self.make_request(
"GET", self.url + "?limit=19", access_token=self.admin_user_tok,
)

self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 19)
self.assertEqual(channel.json_body["next_token"], "19")

# Check
# Set `from` to the value of `next_token` to request the remaining entries
# `next_token` does not appear
channel = self.make_request(
"GET", self.url + "?from=19", access_token=self.admin_user_tok,
)

self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 1)
self.assertNotIn("next_token", channel.json_body)

def _check_fields(self, content: JsonDict):
"""Checks that the expected user attributes are present in content
Args:
content: List that is checked for content
"""
for u in content:
self.assertIn("name", u)
self.assertIn("is_guest", u)
self.assertIn("admin", u)
self.assertIn("user_type", u)
self.assertIn("deactivated", u)
self.assertIn("displayname", u)
self.assertIn("avatar_url", u)

def _create_users(self, number_users: int):
"""
Create a number of users
Args:
number_users: Number of users to be created
"""
for i in range(1, number_users + 1):
self.register_user(
"user%d" % i, "pass%d" % i, admin=False, displayname="Name %d" % i,
)


class DeactivateAccountTestCase(unittest.HomeserverTestCase):

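The `limit`/`from`/`next_token` tests above pin down the pagination contract for the users list: `next_token` is present only while more rows remain, and feeding it back as `from` fetches the rest. A small client-side sketch of that loop; the endpoint URL and token below are placeholders rather than values taken from the diff.

```python
# Sketch of the pagination contract exercised above: keep requesting with
# `from` set to the previous `next_token` until the response omits it.
# base_url and token are placeholders, not values from the diff.
import requests


def fetch_all_users(base_url: str, token: str, page_size: int = 19):
    users = []
    params = {"limit": page_size}
    while True:
        resp = requests.get(
            base_url,
            params=params,
            headers={"Authorization": f"Bearer {token}"},
        )
        resp.raise_for_status()
        body = resp.json()
        users.extend(body["users"])
        if "next_token" not in body:
            return users
        params = {"limit": page_size, "from": body["next_token"]}
```
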
@@ -202,6 +202,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):

config = self.default_config()
config["media_store_path"] = self.media_store_path
config["thumbnail_requirements"] = {}
config["max_image_pixels"] = 2000000

provider_config = {
@@ -312,39 +313,15 @@ class MediaRepoTests(unittest.HomeserverTestCase):
self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)

def test_thumbnail_crop(self):
"""Test that a cropped remote thumbnail is available."""
self._test_thumbnail(
"crop", self.test_image.expected_cropped, self.test_image.expected_found
)

def test_thumbnail_scale(self):
"""Test that a scaled remote thumbnail is available."""
self._test_thumbnail(
"scale", self.test_image.expected_scaled, self.test_image.expected_found
)

def test_invalid_type(self):
"""An invalid thumbnail type is never available."""
self._test_thumbnail("invalid", None, False)

@unittest.override_config(
{"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]}
)
def test_no_thumbnail_crop(self):
"""
Override the config to generate only scaled thumbnails, but request a cropped one.
"""
self._test_thumbnail("crop", None, False)

@unittest.override_config(
{"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]}
)
def test_no_thumbnail_scale(self):
"""
Override the config to generate only cropped thumbnails, but request a scaled one.
"""
self._test_thumbnail("scale", None, False)

def _test_thumbnail(self, method, expected_body, expected_found):
params = "?width=32&height=32&method=" + method
channel = make_request(

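The thumbnail tests above build requests of the form `?width=32&height=32&method=...` against the media repository. As a rough, hedged sketch of the equivalent client request (the homeserver, server name, and media ID below are placeholders, and the `r0` media prefix is an assumption for this Synapse version):

```python
# Hedged sketch of the thumbnail request the tests above simulate.
# HOMESERVER, server_name, and media_id are placeholders; the
# /_matrix/media/r0/thumbnail path is assumed for this era of the spec.
import requests

HOMESERVER = "https://synapse.example.com"
server_name, media_id = "example.com", "abcdefg"

resp = requests.get(
    f"{HOMESERVER}/_matrix/media/r0/thumbnail/{server_name}/{media_id}",
    params={"width": 32, "height": 32, "method": "scale"},
)
print(resp.status_code, resp.headers.get("Content-Type"))
```
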
@@ -92,3 +92,15 @@ class SortTopologically(TestCase):
# Valid orderings are `[1, 3, 2, 4]` or `[1, 2, 3, 4]`, but we should
# always get the same one.
self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])

def test_duplicates(self):
"Test that a graph with duplicate edges works"
graph = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]} # type: Dict[int, List[int]]

self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])

def test_multiple_paths(self):
"Test that a graph with multiple paths between two nodes works"
graph = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]} # type: Dict[int, List[int]]

self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])

11 tox.ini
@@ -24,8 +24,7 @@ deps =
# install the "enum34" dependency of cryptography.
pip>=10

# directories/files we run the linters on.
# if you update this list, make sure to do the same in scripts-dev/lint.sh
# directories/files we run the linters on
lint_targets =
setup.py
synapse
@@ -101,7 +100,7 @@ usedevelop=true

# A test suite for the oldest supported versions of Python libraries, to catch
# any uses of APIs not available in them.
[testenv:py35-{old,old-postgres}]
[testenv:py35-old]
skip_install=True
deps =
# Ensure a version of setuptools that supports Python 3.5 is installed.
@@ -114,17 +113,11 @@ deps =
coverage
coverage-enable-subprocess==1.0

setenv =
postgres: SYNAPSE_POSTGRES = 1

commands =
# Make all greater-thans equals so we test the oldest version of our direct
# dependencies, but pin pyopenssl to 17.0, which can work against an
# OpenSSL 1.1 compiled cryptography (as older ones don't compile on Travis).
#
# Also strip out psycopg2 unless we need it.
/bin/sh -c 'python -m synapse.python_dependencies | sed -e "s/>=/==/g" -e "/psycopg2/d" -e "s/pyopenssl==16.0.0/pyopenssl==17.0.0/" | xargs -d"\n" pip install'
postgres: /bin/sh -c 'python -m synapse.python_dependencies | sed -e "s/>=/==/g" | grep -F "psycopg2" | xargs -d"\n" pip install'

# Install Synapse itself. This won't update any libraries.
pip install -e ".[test]"
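
The `py35-old` commands above turn every `>=` lower bound into an exact pin (and drop `psycopg2`) so the oldest supported dependency set gets installed. A tiny Python illustration of the same transformation the `sed` invocation performs, on made-up requirement strings:

```python
# Toy illustration of the sed pipeline above: pin ">=" lower bounds to "=="
# and drop psycopg2. The requirement strings are made up, not Synapse's list.
deps = ["attrs>=19.2.0", "Jinja2>=2.9", "psycopg2>=2.8"]

pinned = [d.replace(">=", "==") for d in deps if "psycopg2" not in d]
print(pinned)  # ['attrs==19.2.0', 'Jinja2==2.9']
```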